Skip redundant forward pass for on-policy vLLM importance sampling#5413
Skip redundant forward pass for on-policy vLLM importance sampling#5413GJ98 wants to merge 3 commits intohuggingface:mainfrom
Conversation
…olicy vLLM training When using vLLM with vllm_importance_sampling_correction=True, the extra no_grad forward pass to compute old_per_token_logps is redundant in on-policy settings since old_per_token_logps == per_token_logps.detach(). Move IS correction computation from _generate_and_score_completions to _compute_loss for the non-liger path, using old_per_token_logps which falls back to per_token_logps.detach() on-policy. For liger kernel path, pre-compute IS ratio as before since liger computes logprobs internally. Fixes huggingface#5412
The inline IS correction in _compute_loss was computed after the loss type switch, so vespo's get_gamma_weights could not access the IS ratio. This silently dropped IS correction for vespo + non-liger + on-policy. Move IS ratio computation before the loss type switch and store it in inputs["importance_sampling_ratio"] so both vespo and non-vespo paths can access it. Align variable names, comments, and control flow with _generate_and_score_completions per AGENTS.md consistency rules. Fixes huggingface#5412
| f"Unknown vLLM importance sampling level: {self.vllm_importance_sampling_mode}. Possible values are 'token_truncate', 'token_mask', 'sequence_truncate', and 'sequence_mask'." | ||
| ) | ||
|
|
||
| inputs["importance_sampling_ratio"] = vllm_importance_sampling_ratio |
There was a problem hiding this comment.
Stale IS ratio reused across iterations via dict mutation
Medium Severity
When num_iterations > 1, _compute_loss is called multiple times for the same buffered inputs dict. The first call computes the inline IS ratio and writes it back via inputs["importance_sampling_ratio"] = vllm_importance_sampling_ratio. On subsequent iterations, the guard inputs.get("importance_sampling_ratio") is None evaluates to False, so the ratio from iteration 1 is silently reused. Because old_per_token_logps equals per_token_logps.detach() (which changes each iteration) and coef_1 is always 1 in the on-policy path, the total IS correction becomes stale — using iteration-1 logprobs instead of the current iteration's. The old code avoided this because the pre-computed ratio and the standard IS ratio (coef_1) together always yielded the up-to-date total correction exp(per_token_logps - sampling).
Additional Locations (1)
There was a problem hiding this comment.
Both old_per_token_logps and importance_sampling_ratio are stored back into inputs on the first iteration, so subsequent iterations (num_iterations > 1) reuse the same generation-time values.
This means old_per_token_logps does not change each iteration — it is fixed from the first call. So coef_1 = exp(per_token_logps - old_per_token_logps) correctly reflects model drift, and the total correction coef_1 * IS_ratio = exp(per_token_logps - old) * exp(old - sampling) = exp(per_token_logps - sampling) stays up-to-date, same as the original code.
…ons > 1 On the first iteration, store old_per_token_logps and importance_sampling_ratio back into inputs so that subsequent iterations reuse the same generation-time values. This keeps coef_1 and IS correction consistent with the original pre-computed path. Fixes huggingface#5412
| needs_old_logps_for_liger = ( | ||
| self.use_liger_kernel and self.use_vllm and self.vllm_importance_sampling_correction | ||
| ) | ||
| if self.args.gradient_accumulation_steps % generate_every != 0 or needs_old_logps_for_liger: |
There was a problem hiding this comment.
PPO clipping disabled when num_iterations exceeds one
Medium Severity
When num_iterations > 1 with on-policy + non-liger + vLLM IS, skipping the old_per_token_logps pre-computation causes _compute_loss to always set old_per_token_logps = per_token_logps.detach(). This makes coef_1 = exp(per_token_logps - per_token_logps.detach()) = 1 on every iteration, so PPO clipping never activates across iterations. Previously, the stored generation-time old_per_token_logps allowed coef_1 to reflect model drift and trigger clipping. The unclipped gradient product is mathematically equivalent, but the clipping safety mechanism is lost.
Additional Locations (1)
There was a problem hiding this comment.
The _compute_loss method stores old_per_token_logps back into inputs on the first iteration (inputs["old_per_token_logps"] = per_token_logps.detach()), so subsequent iterations reuse the same generation-time logprobs instead of recomputing per_token_logps.detach(). This means coef_1 = exp(per_token_logps - old_per_token_logps) correctly reflects model drift across iterations, and PPO clipping remains active.
|
cursor review |
There was a problem hiding this comment.
Cursor Bugbot has reviewed your changes and found 1 potential issue.
There are 3 total unresolved issues (including 2 from previous reviews).
Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.
| ) | ||
| self._metrics[mode]["sampling/importance_sampling_ratio/max"].append( | ||
| nanmax(self.accelerator.gather(max_importance_sampling_ratio)).item() | ||
| ) |
There was a problem hiding this comment.
Significant IS ratio code duplication within same file
Medium Severity
The IS ratio computation (~15 lines) and metrics logging (~35 lines) are now fully duplicated between _generate_and_score_completions (liger path) and _compute_loss (non-liger path) within the same file. If someone later updates the capping logic, metric names, or IS ratio formula in one copy but not the other, the two code paths will silently diverge. The AGENTS.md rules explicitly warn that duplicated logic must stay aligned and that "not propagating a change is a bug." While duplication across trainers is accepted by design, this intra-file duplication doubles the maintenance surface for the same feature within a single trainer.
Additional Locations (2)
Triggered by project rule: ../.ai/AGENTS.md
There was a problem hiding this comment.
The duplicated parts are IS ratio computation (~15 lines) and metrics logging (~35 lines). I kept them inline to follow the project's simplicity policy (avoid adding layers of indirection), but I can extract them into helper methods (_compute_vllm_is_ratio, _log_vllm_is_metrics) if preferred.


What does this PR do?
When GRPOTrainer uses vLLM with importance sampling correction enabled, it unconditionally performs a
no_gradforward pass to computeold_per_token_logps, even in on-policy training scenarios whereold_per_token_logps == per_token_logps.detach(). This wastes computational resources.This PR skips the unnecessary forward pass for non-liger, on-policy configurations by:
old_per_token_logpsand IS ratio in_generate_and_score_completionswhen using liger kernel (which requires pre-computed ratios) or off-policy scenarios._compute_loss, leveraging the existing fallback that usesper_token_logps.detach()whenold_per_token_logpsis not provided.This optimization applies solely to on-policy + non-liger configurations without affecting training outcomes.
Behavior change summary
no_gradforward pass (redundant)_compute_lossno_gradforward passno_gradforward pass (unchanged)no_gradforward passno_gradforward pass (unchanged)Modified file:
trl/trainer/grpo_trainer.pyFixes #5412
Tests
Ran
pytest tests/test_grpo_trainer.pyon a single H200 GPU:test_vlm_training[HuggingFaceTB/SmolVLM-Instruct]) is a pre-existing issue unrelated to this PR — verified by running the same test onmain, which also fails with the sameValueError(chat template / conversational prompt mismatch atgrpo_trainer.py:1734, outside the scope of this PR's changes which start at line 1863).test_training_vllm_importance_sampling_correction(skipped — no vLLM server)test_compute_liger_loss_passes_vllm_is_ratio✅test_training_sequence_importance_sampling✅test_liger_grpo_kernel_importance_sampling✅test_training_vlm_and_importance_sampling✅make precommitpassed (ruff check, ruff format, doc-builder style)Before submitting
AI writing disclosure
Who can review?
@qgallouedec
Note
Medium Risk
Touches GRPO loss/importance-sampling computation paths for vLLM (including VESPO), so mistakes could silently change training dynamics despite being intended as an optimization.
Overview
Reduces redundant computation in GRPO when using vLLM importance-sampling correction by only precomputing
old_per_token_logps/IS ratios during_generate_and_score_completionsfor off-policy misalignment or the liger kernel path.For the common on-policy + non-liger case, the trainer now falls back to
per_token_logps.detach()forold_per_token_logpsand computes/caps/masksimportance_sampling_ratioinline in_compute_loss(ensuring it is available forvespo), with corresponding IS-delta/ratio metrics logged from the inline path.Written by Cursor Bugbot for commit 1960957. This will update automatically on new commits. Configure here.