Skip redundant forward pass for on-policy vLLM importance sampling #5413

Open

GJ98 wants to merge 3 commits into huggingface:main from GJ98:feat/grpo-skip-no-grad-forward-on-policy
Conversation


@GJ98 GJ98 commented Mar 31, 2026

What does this PR do?

When GRPOTrainer uses vLLM with importance sampling correction enabled, it unconditionally performs a no_grad forward pass to compute old_per_token_logps, even in on-policy training scenarios where old_per_token_logps == per_token_logps.detach(). This wastes computational resources.

This PR skips the unnecessary forward pass for non-liger, on-policy configurations by:

  • Pre-computing old_per_token_logps and the IS ratio in _generate_and_score_completions only when using the liger kernel (which requires pre-computed ratios) or in off-policy scenarios.
  • For non-liger paths, computing the IS correction inline in _compute_loss, leveraging the existing fallback that uses per_token_logps.detach() when old_per_token_logps is not provided.

This optimization applies solely to on-policy + non-liger configurations without affecting training outcomes.
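The existing fallback this PR leverages can be sketched in isolation (a simplified, list-based illustration; the real trainer operates on tensors and uses `.detach()` to cut gradients):

```python
import math

def per_token_is_terms(per_token_logps, old_per_token_logps=None):
    # Existing fallback: with no precomputed old logps (the on-policy
    # case), reuse the current logps; in the trainer this is .detach().
    if old_per_token_logps is None:
        old_per_token_logps = list(per_token_logps)
    # PPO-style ratio coef_1 = exp(logp - old_logp); identically 1.0
    # in the on-policy fallback, so the extra no_grad forward pass
    # adds no information.
    coef_1 = [math.exp(p - o) for p, o in zip(per_token_logps, old_per_token_logps)]
    return old_per_token_logps, coef_1

old, coef_1 = per_token_is_terms([-1.2, -0.7, -2.3])
assert coef_1 == [1.0, 1.0, 1.0]
```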

Behavior change summary

| Scenario | Before | After |
| --- | --- | --- |
| On-policy + non-liger + vLLM IS | no_grad forward pass (redundant) | Skipped; IS computed inline in _compute_loss |
| On-policy + liger + vLLM IS | no_grad forward pass | no_grad forward pass (unchanged) |
| Off-policy | no_grad forward pass | no_grad forward pass (unchanged) |

Modified file: trl/trainer/grpo_trainer.py

Fixes #5412

Tests

Ran pytest tests/test_grpo_trainer.py on a single H200 GPU:

  • 90 passed, 7 skipped, 5 xfailed, 1 failed (21m 58s)
  • The 1 failure (test_vlm_training[HuggingFaceTB/SmolVLM-Instruct]) is a pre-existing issue unrelated to this PR. The same test also fails on main with the same ValueError (chat template / conversational prompt mismatch at grpo_trainer.py:1734), outside the scope of this PR's changes, which start at line 1863.
  • All importance sampling related tests passed:
    • test_training_vllm_importance_sampling_correction (skipped — no vLLM server)
    • test_compute_liger_loss_passes_vllm_is_ratio
    • test_training_sequence_importance_sampling
    • test_liger_grpo_kernel_importance_sampling
    • test_training_vlm_and_importance_sampling
  • make precommit passed (ruff check, ruff format, doc-builder style)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests? No new tests needed

AI writing disclosure

  • No AI usage: the PR was written entirely by a human.
  • AI-assisted: some parts were suggested or improved by AI, but the PR was written and reviewed by a human.
  • AI-generated: the PR was mostly or fully generated by an AI tool.

Who can review?

@qgallouedec


Note

Medium Risk
Touches GRPO loss/importance-sampling computation paths for vLLM (including VESPO), so mistakes could silently change training dynamics despite being intended as an optimization.

Overview
Reduces redundant computation in GRPO when using vLLM importance-sampling correction by only precomputing old_per_token_logps/IS ratios during _generate_and_score_completions for off-policy misalignment or the liger kernel path.

For the common on-policy + non-liger case, the trainer now falls back to per_token_logps.detach() for old_per_token_logps and computes/caps/masks importance_sampling_ratio inline in _compute_loss (ensuring it is available for vespo), with corresponding IS-delta/ratio metrics logged from the inline path.

Written by Cursor Bugbot for commit 1960957. This will update automatically on new commits.

…olicy vLLM training

When using vLLM with vllm_importance_sampling_correction=True, the extra
no_grad forward pass to compute old_per_token_logps is redundant in
on-policy settings since old_per_token_logps == per_token_logps.detach().

Move IS correction computation from _generate_and_score_completions to
_compute_loss for the non-liger path, using old_per_token_logps which
falls back to per_token_logps.detach() on-policy. For liger kernel path,
pre-compute IS ratio as before since liger computes logprobs internally.

Fixes huggingface#5412
The inline IS correction in _compute_loss was computed after the loss
type switch, so vespo's get_gamma_weights could not access the IS ratio.
This silently dropped IS correction for vespo + non-liger + on-policy.

Move IS ratio computation before the loss type switch and store it in
inputs["importance_sampling_ratio"] so both vespo and non-vespo paths
can access it. Align variable names, comments, and control flow with
_generate_and_score_completions per AGENTS.md consistency rules.

Fixes huggingface#5412
f"Unknown vLLM importance sampling level: {self.vllm_importance_sampling_mode}. Possible values are 'token_truncate', 'token_mask', 'sequence_truncate', and 'sequence_mask'."
)

inputs["importance_sampling_ratio"] = vllm_importance_sampling_ratio

Stale IS ratio reused across iterations via dict mutation

Medium Severity

When num_iterations > 1, _compute_loss is called multiple times for the same buffered inputs dict. The first call computes the inline IS ratio and writes it back via inputs["importance_sampling_ratio"] = vllm_importance_sampling_ratio. On subsequent iterations, the guard inputs.get("importance_sampling_ratio") is None evaluates to False, so the ratio from iteration 1 is silently reused. Because old_per_token_logps equals per_token_logps.detach() (which changes each iteration) and coef_1 is always 1 in the on-policy path, the total IS correction becomes stale — using iteration-1 logprobs instead of the current iteration's. The old code avoided this because the pre-computed ratio and the standard IS ratio (coef_1) together always yielded the up-to-date total correction exp(per_token_logps - sampling).

Additional Locations (1)

GJ98 (Author)

Both old_per_token_logps and importance_sampling_ratio are stored back into inputs on the first iteration, so subsequent iterations (num_iterations > 1) reuse the same generation-time values.

This means old_per_token_logps does not change each iteration — it is fixed from the first call. So coef_1 = exp(per_token_logps - old_per_token_logps) correctly reflects model drift, and the total correction coef_1 * IS_ratio = exp(per_token_logps - old) * exp(old - sampling) = exp(per_token_logps - sampling) stays up-to-date, same as the original code.
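A quick numeric check of the identity above, with arbitrary example log-probabilities (a plain-Python sketch, not the trainer's code):

```python
import math

# current-model, cached old (generation-time), and sampling-time logps
per_token_logp, old_logp, sampling_logp = -0.8, -1.1, -1.5

coef_1 = math.exp(per_token_logp - old_logp)   # PPO ratio
is_ratio = math.exp(old_logp - sampling_logp)  # cached vLLM IS ratio
total = coef_1 * is_ratio

# coef_1 * IS_ratio == exp(per_token_logp - sampling_logp)
assert math.isclose(total, math.exp(per_token_logp - sampling_logp))
```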

…ons > 1

On the first iteration, store old_per_token_logps and importance_sampling_ratio
back into inputs so that subsequent iterations reuse the same generation-time
values. This keeps coef_1 and IS correction consistent with the original
pre-computed path.

Fixes huggingface#5412
needs_old_logps_for_liger = (
self.use_liger_kernel and self.use_vllm and self.vllm_importance_sampling_correction
)
if self.args.gradient_accumulation_steps % generate_every != 0 or needs_old_logps_for_liger:

PPO clipping disabled when num_iterations exceeds one

Medium Severity

When num_iterations > 1 with on-policy + non-liger + vLLM IS, skipping the old_per_token_logps pre-computation causes _compute_loss to always set old_per_token_logps = per_token_logps.detach(). This makes coef_1 = exp(per_token_logps - per_token_logps.detach()) = 1 on every iteration, so PPO clipping never activates across iterations. Previously, the stored generation-time old_per_token_logps allowed coef_1 to reflect model drift and trigger clipping. The unclipped gradient product is mathematically equivalent, but the clipping safety mechanism is lost.
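The pinned-ratio behavior described here can be shown with a minimal sketch (plain Python standing in for tensors; `old = list(current)` plays the role of per_token_logps.detach() being recomputed every iteration):

```python
import math

def ppo_coef(per_token_logps, old_per_token_logps):
    # coef_1 = exp(logp - old_logp), the PPO clipping ratio
    return [math.exp(p - o) for p, o in zip(per_token_logps, old_per_token_logps)]

# Without caching, every iteration sets old = current (detached), so the
# ratio is pinned to 1 and the clip range can never be hit.
for current in ([-1.0], [-0.7], [-0.4]):  # drifting logps across iterations
    old = list(current)  # stands in for per_token_logps.detach()
    assert ppo_coef(current, old) == [1.0]
```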

Additional Locations (1)

GJ98 (Author)

The _compute_loss method stores old_per_token_logps back into inputs on the first iteration (inputs["old_per_token_logps"] = per_token_logps.detach()), so subsequent iterations reuse the same generation-time logprobs instead of recomputing per_token_logps.detach(). This means coef_1 = exp(per_token_logps - old_per_token_logps) correctly reflects model drift across iterations, and PPO clipping remains active.
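The cache-on-first-iteration behavior can be sketched as follows (simplified scalars and an illustrative clip range; not the trainer's actual code):

```python
import math

inputs = {}
epsilon = 0.2  # illustrative PPO clip range, not a TRL default

for current_logp in (-1.0, -0.6, -0.2):  # model drifts across iterations
    # First iteration stores the detached logps back into inputs;
    # later iterations reuse the stored generation-time value.
    if inputs.get("old_per_token_logps") is None:
        inputs["old_per_token_logps"] = current_logp
    coef_1 = math.exp(current_logp - inputs["old_per_token_logps"])
    clipped = min(max(coef_1, 1 - epsilon), 1 + epsilon)

# By the last iteration the ratio reflects drift and clipping engages.
assert coef_1 > 1.0
assert clipped == 1 + epsilon
```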


GJ98 commented Apr 1, 2026

cursor review


@cursor (cursor bot) left a comment

Cursor Bugbot has reviewed your changes and found 1 potential issue.

There are 3 total unresolved issues (including 2 from previous reviews).


)
self._metrics[mode]["sampling/importance_sampling_ratio/max"].append(
nanmax(self.accelerator.gather(max_importance_sampling_ratio)).item()
)

Significant IS ratio code duplication within same file

Medium Severity

The IS ratio computation (~15 lines) and metrics logging (~35 lines) are now fully duplicated between _generate_and_score_completions (liger path) and _compute_loss (non-liger path) within the same file. If someone later updates the capping logic, metric names, or IS ratio formula in one copy but not the other, the two code paths will silently diverge. The AGENTS.md rules explicitly warn that duplicated logic must stay aligned and that "not propagating a change is a bug." While duplication across trainers is accepted by design, this intra-file duplication doubles the maintenance surface for the same feature within a single trainer.

Additional Locations (2)

Triggered by project rule: ../.ai/AGENTS.md

GJ98 (Author)

The duplicated parts are IS ratio computation (~15 lines) and metrics logging (~35 lines). I kept them inline to follow the project's simplicity policy (avoid adding layers of indirection), but I can extract them into helper methods (_compute_vllm_is_ratio, _log_vllm_is_metrics) if preferred.



Development

Successfully merging this pull request may close these issues.

Skip unnecessary no_grad forward pass for IS correction in on-policy vLLM training
