Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 83 additions & 11 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1863,12 +1863,13 @@ def _generate_and_score_completions(
# samples may come from an earlier version of the model. In that case, we need to track old_per_token_logps
# for importance sampling. If the steps are aligned, importance sampling isn't necessary and we set
# old_per_token_logps to None.
# When using vLLM, we always compute old_per_token_logps for importance sampling, as it has been shown
# that the distribution mismatch between vLLM and the training model can be large and harm the training.
# When using liger kernel with vLLM, we need to pre-compute old_per_token_logps for IS correction
# since liger computes logprobs internally. For non-liger path, IS correction is computed in _compute_loss.
generate_every = self.args.steps_per_generation * self.num_iterations # generation frequency
if self.args.gradient_accumulation_steps % generate_every != 0 or (
self.use_vllm and self.vllm_importance_sampling_correction
):
needs_old_logps_for_liger = (
self.use_liger_kernel and self.use_vllm and self.vllm_importance_sampling_correction
)
if self.args.gradient_accumulation_steps % generate_every != 0 or needs_old_logps_for_liger:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PPO clipping disabled when num_iterations exceeds one

Medium Severity

When num_iterations > 1 with on-policy + non-liger + vLLM IS, skipping the old_per_token_logps pre-computation causes _compute_loss to always set old_per_token_logps = per_token_logps.detach(). This makes coef_1 = exp(per_token_logps - per_token_logps.detach()) = 1 on every iteration, so PPO clipping never activates across iterations. Previously, the stored generation-time old_per_token_logps allowed coef_1 to reflect model drift and trigger clipping. The unclipped gradient product is mathematically equivalent, but the clipping safety mechanism is lost.

Additional Locations (1)
Fix in Cursor Fix in Web

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The _compute_loss method stores old_per_token_logps back into inputs on the first iteration (inputs["old_per_token_logps"] = per_token_logps.detach()), so subsequent iterations reuse the same generation-time logprobs instead of recomputing per_token_logps.detach(). This means coef_1 = exp(per_token_logps - old_per_token_logps) correctly reflects model drift across iterations, and PPO clipping remains active.

old_per_token_logps, _ = self._get_per_token_logps_and_entropies(
self.model,
prompt_completion_ids,
Expand All @@ -1881,8 +1882,9 @@ def _generate_and_score_completions(
else:
old_per_token_logps = None

# Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch
if self.use_vllm and self.vllm_importance_sampling_correction:
# Pre-compute IS ratio for liger kernel path only; non-liger path computes it in _compute_loss
vllm_importance_sampling_ratio = None
if needs_old_logps_for_liger:
mask = completion_mask if tool_mask is None else completion_mask * tool_mask
per_token_logps_diff = (old_per_token_logps - sampling_per_token_logps) * mask

Expand Down Expand Up @@ -2052,7 +2054,7 @@ def _generate_and_score_completions(
if images is not None:
self._logs["images"].extend(gather_object(images))

if self.use_vllm and self.vllm_importance_sampling_correction:
if needs_old_logps_for_liger:
delta = torch.abs(old_per_token_logps - sampling_per_token_logps)
mask = completion_mask.bool() if tool_mask is None else (completion_mask * tool_mask).bool()
delta = delta[mask]
Expand Down Expand Up @@ -2295,8 +2297,7 @@ def _compute_loss(self, model, inputs):
# When num_iterations == 1 and steps_per_generation <= gradient_accumulation_steps,
# old_per_token_logps == per_token_logps. In this case we can skip its computation
# (see _generate_and_score_completions) and instead use per_token_logps.detach().
# The exception is when using vLLM, where we always compute old_per_token_logps
# for importance sampling
# When using vLLM without liger, IS correction is computed inline below using old_per_token_logps.
old_per_token_logps = inputs.get("old_per_token_logps")
old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps

Expand Down Expand Up @@ -2378,8 +2379,79 @@ def _compute_loss(self, model, inputs):
if entropy_mask is not None:
per_token_loss = per_token_loss * entropy_mask

# Compute IS correction inline for non-liger path.
# For liger path, IS ratio is pre-computed in _generate_and_score_completions.
if self.use_vllm and self.vllm_importance_sampling_correction and self.loss_type != "vespo":
per_token_loss = per_token_loss * inputs["importance_sampling_ratio"]
sampling_per_token_logps = inputs.get("sampling_per_token_logps")
if sampling_per_token_logps is not None:
per_token_logps_diff = (old_per_token_logps - sampling_per_token_logps) * mask

sequence_level_is = self.vllm_importance_sampling_mode in ["sequence_mask", "sequence_truncate"]
if sequence_level_is:
logps_diff = per_token_logps_diff.sum(dim=-1, keepdim=True)
else:
logps_diff = per_token_logps_diff

vllm_is_ratio = torch.exp(logps_diff)

if self.vllm_importance_sampling_mode in ["sequence_truncate", "token_truncate"]:
vllm_is_ratio = torch.clamp(vllm_is_ratio, max=self.vllm_importance_sampling_cap)
elif self.vllm_importance_sampling_mode in ["sequence_mask", "token_mask"]:
vllm_is_ratio = vllm_is_ratio.masked_fill(
vllm_is_ratio > self.vllm_importance_sampling_cap, value=0.0
)

per_token_loss = per_token_loss * vllm_is_ratio

# Log IS correction metrics
_mode = "train" if self.model.training else "eval"
delta = torch.abs(per_token_logps_diff)
bool_mask = mask.bool()
delta_masked = delta[bool_mask]
mean_delta = (
torch.mean(delta_masked)
if delta_masked.numel() > 0
else torch.tensor(0.0, device=per_token_logps.device)
)
max_delta = (
torch.max(delta_masked)
if delta_masked.numel() > 0
else torch.tensor(0.0, device=per_token_logps.device)
)
self._metrics[_mode]["sampling/sampling_logp_difference/mean"].append(
self.accelerator.gather(mean_delta).mean().item()
)
self._metrics[_mode]["sampling/sampling_logp_difference/max"].append(
self.accelerator.gather(max_delta).max().item()
)
if sequence_level_is:
flat_is_ratio = vllm_is_ratio.flatten()
else:
flat_is_ratio = vllm_is_ratio[bool_mask]
min_is = (
torch.min(flat_is_ratio)
if flat_is_ratio.numel() > 0
else torch.tensor(0.0, device=per_token_logps.device)
)
mean_is = (
torch.mean(flat_is_ratio)
if flat_is_ratio.numel() > 0
else torch.tensor(0.0, device=per_token_logps.device)
)
max_is = (
torch.max(flat_is_ratio)
if flat_is_ratio.numel() > 0
else torch.tensor(0.0, device=per_token_logps.device)
)
self._metrics[_mode]["sampling/importance_sampling_ratio/min"].append(
nanmin(self.accelerator.gather(min_is)).item()
)
self._metrics[_mode]["sampling/importance_sampling_ratio/mean"].append(
self.accelerator.gather(mean_is).nanmean().item()
)
self._metrics[_mode]["sampling/importance_sampling_ratio/max"].append(
nanmax(self.accelerator.gather(max_is)).item()
)

if self.beta != 0.0:
per_token_loss = per_token_loss + self.beta * per_token_kl
Expand Down
Loading