# Skip redundant forward pass for on-policy vLLM importance sampling #5413
Base branch: `main`. Changes from all commits.
```diff
@@ -1863,12 +1863,13 @@ def _generate_and_score_completions(
         # samples may come from an earlier version of the model. In that case, we need to track old_per_token_logps
         # for importance sampling. If the steps are aligned, importance sampling isn't necessary and we set
         # old_per_token_logps to None.
-        # When using vLLM, we always compute old_per_token_logps for importance sampling, it was shown that the
-        # distribution mismatch between vLLM and the training model can be large and harm the training.
+        # When using liger kernel with vLLM, we need to pre-compute old_per_token_logps for IS correction
+        # since liger computes logprobs internally. For non-liger path, IS correction is computed in _compute_loss.
         generate_every = self.args.steps_per_generation * self.num_iterations  # generation frequency
-        if self.args.gradient_accumulation_steps % generate_every != 0 or (
-            self.use_vllm and self.vllm_importance_sampling_correction
-        ):
+        needs_old_logps_for_liger = (
+            self.use_liger_kernel and self.use_vllm and self.vllm_importance_sampling_correction
+        )
+        if self.args.gradient_accumulation_steps % generate_every != 0 or needs_old_logps_for_liger:
             old_per_token_logps, _ = self._get_per_token_logps_and_entropies(
                 self.model,
                 prompt_completion_ids,
```
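The gating condition above can be sketched in isolation. This is an illustrative pure-Python sketch, not the TRL implementation: the standalone function name `needs_old_logps` is hypothetical, while the parameter names mirror the attributes in the diff. Old logprobs are pre-computed only when training is off-policy (`gradient_accumulation_steps` is not a multiple of `generate_every`) or when the liger-kernel vLLM IS path needs them.

```python
def needs_old_logps(
    gradient_accumulation_steps: int,
    steps_per_generation: int,
    num_iterations: int,
    use_liger_kernel: bool,
    use_vllm: bool,
    is_correction: bool,
) -> bool:
    # Hypothetical standalone version of the gating logic in the diff above.
    generate_every = steps_per_generation * num_iterations
    # Off-policy: some optimizer steps consume samples from an earlier model version.
    off_policy = gradient_accumulation_steps % generate_every != 0
    # Liger computes logprobs internally, so the IS correction needs them pre-computed.
    needs_old_logps_for_liger = use_liger_kernel and use_vllm and is_correction
    return off_policy or needs_old_logps_for_liger
```

Under this sketch, the on-policy non-liger vLLM case (e.g. `gradient_accumulation_steps=8`, `steps_per_generation=4`, `num_iterations=2`) skips the extra forward pass, which is the point of this PR.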
```diff
@@ -1881,8 +1882,9 @@ def _generate_and_score_completions(
         else:
             old_per_token_logps = None

-        # Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch
-        if self.use_vllm and self.vllm_importance_sampling_correction:
+        # Pre-compute IS ratio for liger kernel path only; non-liger path computes it in _compute_loss
+        vllm_importance_sampling_ratio = None
+        if needs_old_logps_for_liger:
             mask = completion_mask if tool_mask is None else completion_mask * tool_mask
             per_token_logps_diff = (old_per_token_logps - sampling_per_token_logps) * mask
```
```diff
@@ -2052,7 +2054,7 @@ def _generate_and_score_completions(
         if images is not None:
             self._logs["images"].extend(gather_object(images))

-        if self.use_vllm and self.vllm_importance_sampling_correction:
+        if needs_old_logps_for_liger:
             delta = torch.abs(old_per_token_logps - sampling_per_token_logps)
             mask = completion_mask.bool() if tool_mask is None else (completion_mask * tool_mask).bool()
             delta = delta[mask]
```
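The `delta` statistic above measures how far the sampler's logprobs drift from the trainer's, restricted to unmasked tokens. A minimal pure-Python sketch of that masked statistic (the function name is assumed for illustration; the real code operates on torch tensors):

```python
def masked_abs_diff(old_logps, sampling_logps, mask):
    # |old - sampling| over tokens where the mask is nonzero,
    # mirroring delta = torch.abs(old - sampling); delta = delta[mask].
    return [abs(o - s) for o, s, m in zip(old_logps, sampling_logps, mask) if m]
```

The mean and max of this list are what get logged as `sampling/sampling_logp_difference/{mean,max}`.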
```diff
@@ -2292,13 +2294,50 @@ def _compute_loss(self, model, inputs):
         # provide advantages with shape (B, T) (e.g., MiniLLM), we *conditionally* unsqueeze the tensor.
         if advantages.dim() == 1:
             advantages = advantages.unsqueeze(1)
-        # When num_iterations == 1 and steps_per_generation <= gradient_accumulation_steps,
-        # old_per_token_logps == per_token_logps. In this case we can skip its computation
-        # (see _generate_and_score_completions) and instead use per_token_logps.detach().
-        # The exception is when using vLLM, where we always compute old_per_token_logps
-        # for importance sampling
-        if inputs.get("old_per_token_logps") is None:
-            inputs["old_per_token_logps"] = per_token_logps.detach()
+        # When gradient_accumulation_steps % generate_every == 0 (on-policy),
+        # old_per_token_logps == per_token_logps on the first iteration. In this case we can skip
+        # its computation (see _generate_and_score_completions) and instead use per_token_logps.detach().
+        # When using vLLM without liger, IS correction is computed inline below using old_per_token_logps.
+        old_per_token_logps = inputs.get("old_per_token_logps")
+        old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps

+        # Compute inline IS ratio for non-liger vLLM path. This must happen before the loss type switch
+        # because vespo needs it in get_gamma_weights. For the liger path, IS ratio is pre-computed in
+        # _generate_and_score_completions and passed via inputs["importance_sampling_ratio"].
+        vllm_importance_sampling_ratio = None
+        if self.use_vllm and self.vllm_importance_sampling_correction:
+            sampling_per_token_logps = inputs.get("sampling_per_token_logps")
+            if sampling_per_token_logps is not None and inputs.get("importance_sampling_ratio") is None:
+                per_token_logps_diff = (old_per_token_logps - sampling_per_token_logps) * mask

+                sequence_level_is = self.vllm_importance_sampling_mode in ["sequence_mask", "sequence_truncate"]
+                if sequence_level_is:
+                    per_sequence_logps_diff = per_token_logps_diff.sum(dim=-1, keepdim=True)
+                    logps_diff = per_sequence_logps_diff
+                else:
+                    logps_diff = per_token_logps_diff

+                vllm_importance_sampling_ratio = torch.exp(logps_diff)

+                # vllm_importance_sampling_ratio.shape:
+                #   token_* modes: (B, T) (per-token ratio)
+                #   sequence_* modes: (B, 1) (per-sequence ratio)
+                if self.vllm_importance_sampling_mode in ["sequence_truncate", "token_truncate"]:
+                    vllm_importance_sampling_ratio = torch.clamp(
+                        vllm_importance_sampling_ratio, max=self.vllm_importance_sampling_cap
+                    )
+                elif self.vllm_importance_sampling_mode in ["sequence_mask", "token_mask"]:
+                    vllm_importance_sampling_ratio = vllm_importance_sampling_ratio.masked_fill(
+                        vllm_importance_sampling_ratio > self.vllm_importance_sampling_cap, value=0.0
+                    )
+                else:
+                    raise ValueError(
+                        f"Unknown vLLM importance sampling level: {self.vllm_importance_sampling_mode}. "
+                        "Possible values are 'token_truncate', 'token_mask', 'sequence_truncate', and 'sequence_mask'."
+                    )

+                inputs["importance_sampling_ratio"] = vllm_importance_sampling_ratio

         if self.off_policy_mask_threshold is not None:
             # OPSM should use inference-time logprobs to detect both sources of off-policyness:
```

**cursor[bot]** (Medium Severity): Stale IS ratio reused across iterations via dict mutation. Additional Locations (1).

**Author:** Both `old_per_token_logps` and `importance_sampling_ratio` are stored back into `inputs` on the first iteration, so subsequent iterations (`num_iterations > 1`) reuse the same generation-time values. This means `old_per_token_logps` does not change each iteration; it is fixed from the first call. So `coef_1 = exp(per_token_logps - old_per_token_logps)` correctly reflects model drift, and the total correction `coef_1 * IS_ratio = exp(per_token_logps - old) * exp(old - sampling) = exp(per_token_logps - sampling)` stays up-to-date, same as the original code.
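The four `vllm_importance_sampling_mode` options in the hunk above differ along two axes: the granularity of the ratio (per-token vs. per-sequence, where sequence modes sum the logprob diffs before exponentiating) and the treatment of outliers (clamp at the cap vs. zero out). A minimal pure-Python sketch of that behavior for a single sequence (assumed semantics mirroring the diff; the real code is vectorized torch over a batch):

```python
import math

def is_ratios(logps_diff_per_token, mode, cap):
    # sequence_* modes: sum per-token logprob diffs -> one ratio per sequence.
    # token_* modes: one ratio per token.
    if mode.startswith("sequence"):
        diffs = [sum(logps_diff_per_token)]
    else:
        diffs = list(logps_diff_per_token)
    ratios = [math.exp(d) for d in diffs]
    if mode.endswith("truncate"):
        # Clamp outliers at the cap (torch.clamp in the real code).
        return [min(r, cap) for r in ratios]
    elif mode.endswith("mask"):
        # Zero out outliers entirely (masked_fill in the real code).
        return [0.0 if r > cap else r for r in ratios]
    raise ValueError(f"Unknown mode: {mode}")
```

Truncation keeps a bounded contribution from drifted tokens, while masking discards them; the sequence modes apply one decision to the whole completion.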
```diff
@@ -2381,6 +2420,43 @@ def _compute_loss(self, model, inputs):
         if self.use_vllm and self.vllm_importance_sampling_correction and self.loss_type != "vespo":
             per_token_loss = per_token_loss * inputs["importance_sampling_ratio"]

+        # Log IS correction metrics for inline-computed IS ratio (non-liger path)
+        if self.use_vllm and self.vllm_importance_sampling_correction and vllm_importance_sampling_ratio is not None:
+            mode = "train" if self.model.training else "eval"
+            delta = torch.abs(old_per_token_logps - sampling_per_token_logps)
+            delta = delta[mask.bool()]
+            mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=per_token_logps.device)
+            max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=per_token_logps.device)
+            self._metrics[mode]["sampling/sampling_logp_difference/mean"].append(
+                self.accelerator.gather(mean_delta).mean().item()
+            )
+            self._metrics[mode]["sampling/sampling_logp_difference/max"].append(
+                self.accelerator.gather(max_delta).max().item()
+            )
+            if sequence_level_is:
+                flat_is_ratio = vllm_importance_sampling_ratio.flatten()
+            else:
+                flat_is_ratio = vllm_importance_sampling_ratio[mask.bool()]

+            min_importance_sampling_ratio = (
+                torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=per_token_logps.device)
+            )
+            mean_importance_sampling_ratio = (
+                torch.mean(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=per_token_logps.device)
+            )
+            max_importance_sampling_ratio = (
+                torch.max(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=per_token_logps.device)
+            )
+            self._metrics[mode]["sampling/importance_sampling_ratio/min"].append(
+                nanmin(self.accelerator.gather(min_importance_sampling_ratio)).item()
+            )
+            self._metrics[mode]["sampling/importance_sampling_ratio/mean"].append(
+                self.accelerator.gather(mean_importance_sampling_ratio).nanmean().item()
+            )
+            self._metrics[mode]["sampling/importance_sampling_ratio/max"].append(
+                nanmax(self.accelerator.gather(max_importance_sampling_ratio)).item()
+            )

         if self.beta != 0.0:
             per_token_loss = per_token_loss + self.beta * per_token_kl
```

**cursor[bot]** (Medium Severity): Significant IS ratio code duplication within same file. The IS ratio computation (~15 lines) and metrics logging (~35 lines) are now fully duplicated between `_generate_and_score_completions` and `_compute_loss`. Additional Locations (2). Triggered by project rule: `../.ai/AGENTS.md`.

**Author:** The duplicated parts are IS ratio computation (~15 lines) and metrics logging (~35 lines). I kept them inline to follow the project's simplicity policy (avoid adding layers of indirection), but I can extract them into helper methods (`_compute_vllm_is_ratio`, `_log_vllm_is_metrics`) if preferred.
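The metrics block above aggregates per-process extrema with `nanmin`/`nanmax` helpers so that a NaN on one process does not poison the gathered min/max. A pure-Python sketch of what such helpers do (these stand-ins are assumptions for illustration; the real helpers operate on gathered torch tensors):

```python
import math

def nanmin(values):
    # Minimum ignoring NaN entries; NaN only if everything is NaN.
    vals = [v for v in values if not math.isnan(v)]
    return min(vals) if vals else float("nan")

def nanmax(values):
    # Maximum ignoring NaN entries.
    vals = [v for v in values if not math.isnan(v)]
    return max(vals) if vals else float("nan")
```

A plain `min`/`max` would instead propagate any NaN produced by an empty mask or a degenerate ratio into the logged metric.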
**cursor[bot]** (Medium Severity): PPO clipping disabled when num_iterations exceeds one. When `num_iterations > 1` with on-policy + non-liger + vLLM IS, skipping the `old_per_token_logps` pre-computation causes `_compute_loss` to always set `old_per_token_logps = per_token_logps.detach()`. This makes `coef_1 = exp(per_token_logps - per_token_logps.detach()) = 1` on every iteration, so PPO clipping never activates across iterations. Previously, the stored generation-time `old_per_token_logps` allowed `coef_1` to reflect model drift and trigger clipping. The unclipped gradient product is mathematically equivalent, but the clipping safety mechanism is lost. Additional Locations (1): `trl/trainer/grpo_trainer.py#L2300-L2303`.

**Author:** The `_compute_loss` method stores `old_per_token_logps` back into `inputs` on the first iteration (`inputs["old_per_token_logps"] = per_token_logps.detach()`), so subsequent iterations reuse the same generation-time logprobs instead of recomputing `per_token_logps.detach()`. This means `coef_1 = exp(per_token_logps - old_per_token_logps)` correctly reflects model drift across iterations, and PPO clipping remains active.
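The telescoping identity invoked in these replies can be checked numerically. The logprob values below are hypothetical, chosen only to exercise the algebra: with the generation-time `old` fixed, the PPO ratio times the vLLM IS correction collapses to a single ratio against the sampler.

```python
import math

# Hypothetical per-token logprobs: current model, fixed old (generation-time),
# and the vLLM sampler's logprob for the same token.
cur, old, sampling = -1.3, -1.7, -2.0

coef_1 = math.exp(cur - old)          # PPO ratio against the fixed old policy
is_ratio = math.exp(old - sampling)   # vLLM IS correction, fixed per generation

# exp(cur - old) * exp(old - sampling) == exp(cur - sampling): the combined
# correction tracks the current model against the sampler, as the author argues.
assert math.isclose(coef_1 * is_ratio, math.exp(cur - sampling))
```

Note the identity only holds when `old` is the same value in both factors, which is exactly why the two replies hinge on whether `old_per_token_logps` stays fixed across iterations.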