Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 90 additions & 14 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1863,12 +1863,13 @@ def _generate_and_score_completions(
# samples may come from an earlier version of the model. In that case, we need to track old_per_token_logps
# for importance sampling. If the steps are aligned, importance sampling isn't necessary and we set
# old_per_token_logps to None.
# When using vLLM, we always compute old_per_token_logps for importance sampling, it was shown that the
# distribution mismatch between vLLM and the training model can be large and harm the training.
# When using liger kernel with vLLM, we need to pre-compute old_per_token_logps for IS correction
# since liger computes logprobs internally. For non-liger path, IS correction is computed in _compute_loss.
generate_every = self.args.steps_per_generation * self.num_iterations # generation frequency
if self.args.gradient_accumulation_steps % generate_every != 0 or (
self.use_vllm and self.vllm_importance_sampling_correction
):
needs_old_logps_for_liger = (
self.use_liger_kernel and self.use_vllm and self.vllm_importance_sampling_correction
)
if self.args.gradient_accumulation_steps % generate_every != 0 or needs_old_logps_for_liger:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PPO clipping disabled when num_iterations exceeds one

Medium Severity

When num_iterations > 1 with on-policy + non-liger + vLLM IS, skipping the old_per_token_logps pre-computation causes _compute_loss to always set old_per_token_logps = per_token_logps.detach(). This makes coef_1 = exp(per_token_logps - per_token_logps.detach()) = 1 on every iteration, so PPO clipping never activates across iterations. Previously, the stored generation-time old_per_token_logps allowed coef_1 to reflect model drift and trigger clipping. The unclipped gradient product is mathematically equivalent, but the clipping safety mechanism is lost.

Additional Locations (1)
Fix in Cursor Fix in Web

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The _compute_loss method stores old_per_token_logps back into inputs on the first iteration (inputs["old_per_token_logps"] = per_token_logps.detach()), so subsequent iterations reuse the same generation-time logprobs instead of recomputing per_token_logps.detach(). This means coef_1 = exp(per_token_logps - old_per_token_logps) correctly reflects model drift across iterations, and PPO clipping remains active.

old_per_token_logps, _ = self._get_per_token_logps_and_entropies(
self.model,
prompt_completion_ids,
Expand All @@ -1881,8 +1882,9 @@ def _generate_and_score_completions(
else:
old_per_token_logps = None

# Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch
if self.use_vllm and self.vllm_importance_sampling_correction:
# Pre-compute IS ratio for liger kernel path only; non-liger path computes it in _compute_loss
vllm_importance_sampling_ratio = None
if needs_old_logps_for_liger:
mask = completion_mask if tool_mask is None else completion_mask * tool_mask
per_token_logps_diff = (old_per_token_logps - sampling_per_token_logps) * mask

Expand Down Expand Up @@ -2052,7 +2054,7 @@ def _generate_and_score_completions(
if images is not None:
self._logs["images"].extend(gather_object(images))

if self.use_vllm and self.vllm_importance_sampling_correction:
if needs_old_logps_for_liger:
delta = torch.abs(old_per_token_logps - sampling_per_token_logps)
mask = completion_mask.bool() if tool_mask is None else (completion_mask * tool_mask).bool()
delta = delta[mask]
Expand Down Expand Up @@ -2292,13 +2294,50 @@ def _compute_loss(self, model, inputs):
# provide advantages with shape (B, T) (e.g., MiniLLM), we *conditionally* unsqueeze the tensor.
if advantages.dim() == 1:
advantages = advantages.unsqueeze(1)
# When num_iterations == 1 and steps_per_generation <= gradient_accumulation_steps,
# old_per_token_logps == per_token_logps. In this case we can skip its computation
# (see _generate_and_score_completions) and instead use per_token_logps.detach().
# The exception is when using vLLM, where we always compute old_per_token_logps
# for importance sampling
# When gradient_accumulation_steps % generate_every == 0 (on-policy),
# old_per_token_logps == per_token_logps on the first iteration. In this case we can skip
# its computation (see _generate_and_score_completions) and instead use per_token_logps.detach().
# When using vLLM without liger, IS correction is computed inline below using old_per_token_logps.
if inputs.get("old_per_token_logps") is None:
inputs["old_per_token_logps"] = per_token_logps.detach()
old_per_token_logps = inputs.get("old_per_token_logps")
old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps

# Compute inline IS ratio for non-liger vLLM path. This must happen before the loss type switch
# because vespo needs it in get_gamma_weights. For the liger path, IS ratio is pre-computed in
# _generate_and_score_completions and passed via inputs["importance_sampling_ratio"].
vllm_importance_sampling_ratio = None
if self.use_vllm and self.vllm_importance_sampling_correction:
sampling_per_token_logps = inputs.get("sampling_per_token_logps")
if sampling_per_token_logps is not None and inputs.get("importance_sampling_ratio") is None:
per_token_logps_diff = (old_per_token_logps - sampling_per_token_logps) * mask

sequence_level_is = self.vllm_importance_sampling_mode in ["sequence_mask", "sequence_truncate"]
if sequence_level_is:
per_sequence_logps_diff = per_token_logps_diff.sum(dim=-1, keepdim=True)
logps_diff = per_sequence_logps_diff
else:
logps_diff = per_token_logps_diff

vllm_importance_sampling_ratio = torch.exp(logps_diff)

# vllm_importance_sampling_ratio.shape:
# token_* modes: (B, T) (per-token ratio)
# sequence_* modes: (B, 1) (per-sequence ratio)

if self.vllm_importance_sampling_mode in ["sequence_truncate", "token_truncate"]:
vllm_importance_sampling_ratio = torch.clamp(
vllm_importance_sampling_ratio, max=self.vllm_importance_sampling_cap
)
elif self.vllm_importance_sampling_mode in ["sequence_mask", "token_mask"]:
vllm_importance_sampling_ratio = vllm_importance_sampling_ratio.masked_fill(
vllm_importance_sampling_ratio > self.vllm_importance_sampling_cap, value=0.0
)
else:
raise ValueError(
f"Unknown vLLM importance sampling level: {self.vllm_importance_sampling_mode}. Possible values are 'token_truncate', 'token_mask', 'sequence_truncate', and 'sequence_mask'."
)

inputs["importance_sampling_ratio"] = vllm_importance_sampling_ratio
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stale IS ratio reused across iterations via dict mutation

Medium Severity

When num_iterations > 1, _compute_loss is called multiple times for the same buffered inputs dict. The first call computes the inline IS ratio and writes it back via inputs["importance_sampling_ratio"] = vllm_importance_sampling_ratio. On subsequent iterations, the guard inputs.get("importance_sampling_ratio") is None evaluates to False, so the ratio from iteration 1 is silently reused. Because old_per_token_logps equals per_token_logps.detach() (which changes each iteration) and coef_1 is always 1 in the on-policy path, the total IS correction becomes stale — using iteration-1 logprobs instead of the current iteration's. The old code avoided this because the pre-computed ratio and the standard IS ratio (coef_1) together always yielded the up-to-date total correction exp(per_token_logps - sampling).

Additional Locations (1)
Fix in Cursor Fix in Web

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both old_per_token_logps and importance_sampling_ratio are stored back into inputs on the first iteration, so subsequent iterations (num_iterations > 1) reuse the same generation-time values.

This means old_per_token_logps does not change each iteration — it is fixed from the first call. So coef_1 = exp(per_token_logps - old_per_token_logps) correctly reflects model drift, and the total correction coef_1 * IS_ratio = exp(per_token_logps - old) * exp(old - sampling) = exp(per_token_logps - sampling) stays up-to-date, same as the original code.


if self.off_policy_mask_threshold is not None:
# OPSM should use inference-time logprobs to detect both sources of off-policyness:
Expand Down Expand Up @@ -2381,6 +2420,43 @@ def _compute_loss(self, model, inputs):
if self.use_vllm and self.vllm_importance_sampling_correction and self.loss_type != "vespo":
per_token_loss = per_token_loss * inputs["importance_sampling_ratio"]

# Log IS correction metrics for inline-computed IS ratio (non-liger path)
if self.use_vllm and self.vllm_importance_sampling_correction and vllm_importance_sampling_ratio is not None:
mode = "train" if self.model.training else "eval"
delta = torch.abs(old_per_token_logps - sampling_per_token_logps)
delta = delta[mask.bool()]
mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=per_token_logps.device)
max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=per_token_logps.device)
self._metrics[mode]["sampling/sampling_logp_difference/mean"].append(
self.accelerator.gather(mean_delta).mean().item()
)
self._metrics[mode]["sampling/sampling_logp_difference/max"].append(
self.accelerator.gather(max_delta).max().item()
)
if sequence_level_is:
flat_is_ratio = vllm_importance_sampling_ratio.flatten()
else:
flat_is_ratio = vllm_importance_sampling_ratio[mask.bool()]

min_importance_sampling_ratio = (
torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=per_token_logps.device)
)
mean_importance_sampling_ratio = (
torch.mean(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=per_token_logps.device)
)
max_importance_sampling_ratio = (
torch.max(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=per_token_logps.device)
)
self._metrics[mode]["sampling/importance_sampling_ratio/min"].append(
nanmin(self.accelerator.gather(min_importance_sampling_ratio)).item()
)
self._metrics[mode]["sampling/importance_sampling_ratio/mean"].append(
self.accelerator.gather(mean_importance_sampling_ratio).nanmean().item()
)
self._metrics[mode]["sampling/importance_sampling_ratio/max"].append(
nanmax(self.accelerator.gather(max_importance_sampling_ratio)).item()
)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Significant IS ratio code duplication within same file

Medium Severity

The IS ratio computation (~15 lines) and metrics logging (~35 lines) are now fully duplicated between _generate_and_score_completions (liger path) and _compute_loss (non-liger path) within the same file. If someone later updates the capping logic, metric names, or IS ratio formula in one copy but not the other, the two code paths will silently diverge. The AGENTS.md rules explicitly warn that duplicated logic must stay aligned and that "not propagating a change is a bug." While duplication across trainers is accepted by design, this intra-file duplication doubles the maintenance surface for the same feature within a single trainer.

Additional Locations (2)
Fix in Cursor Fix in Web

Triggered by project rule: ../.ai/AGENTS.md

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The duplicated parts are IS ratio computation (~15 lines) and metrics logging (~35 lines). I kept them inline to follow the project's simplicity policy (avoid adding layers of indirection), but I can extract them into helper methods (_compute_vllm_is_ratio, _log_vllm_is_metrics) if preferred.


if self.beta != 0.0:
per_token_loss = per_token_loss + self.beta * per_token_kl

Expand Down
Loading