diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index c5eed09419..3836173395 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1863,12 +1863,13 @@ def _generate_and_score_completions( # samples may come from an earlier version of the model. In that case, we need to track old_per_token_logps # for importance sampling. If the steps are aligned, importance sampling isn't necessary and we set # old_per_token_logps to None. - # When using vLLM, we always compute old_per_token_logps for importance sampling, it was shown that the - # distribution mismatch between vLLM and the training model can be large and harm the training. + # When using liger kernel with vLLM, we need to pre-compute old_per_token_logps for IS correction + # since liger computes logprobs internally. For non-liger path, IS correction is computed in _compute_loss. generate_every = self.args.steps_per_generation * self.num_iterations # generation frequency - if self.args.gradient_accumulation_steps % generate_every != 0 or ( - self.use_vllm and self.vllm_importance_sampling_correction - ): + needs_old_logps_for_liger = ( + self.use_liger_kernel and self.use_vllm and self.vllm_importance_sampling_correction + ) + if self.args.gradient_accumulation_steps % generate_every != 0 or needs_old_logps_for_liger: old_per_token_logps, _ = self._get_per_token_logps_and_entropies( self.model, prompt_completion_ids, @@ -1881,8 +1882,9 @@ def _generate_and_score_completions( else: old_per_token_logps = None - # Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch - if self.use_vllm and self.vllm_importance_sampling_correction: + # Pre-compute IS ratio for liger kernel path only; non-liger path computes it in _compute_loss + vllm_importance_sampling_ratio = None + if needs_old_logps_for_liger: mask = completion_mask if tool_mask is None else completion_mask * tool_mask per_token_logps_diff = (old_per_token_logps - sampling_per_token_logps) * mask @@ -2052,7 +2054,7 @@ def _generate_and_score_completions( if images is not None: self._logs["images"].extend(gather_object(images)) - if self.use_vllm and self.vllm_importance_sampling_correction: + if needs_old_logps_for_liger: delta = torch.abs(old_per_token_logps - sampling_per_token_logps) mask = completion_mask.bool() if tool_mask is None else (completion_mask * tool_mask).bool() delta = delta[mask] @@ -2292,13 +2294,50 @@ def _compute_loss(self, model, inputs): # provide advantages with shape (B, T) (e.g., MiniLLM), we *conditionally* unsqueeze the tensor. if advantages.dim() == 1: advantages = advantages.unsqueeze(1) - # When num_iterations == 1 and steps_per_generation <= gradient_accumulation_steps, - # old_per_token_logps == per_token_logps. In this case we can skip its computation - # (see _generate_and_score_completions) and instead use per_token_logps.detach(). - # The exception is when using vLLM, where we always compute old_per_token_logps - # for importance sampling + # When gradient_accumulation_steps % generate_every == 0 (on-policy), + # old_per_token_logps == per_token_logps on the first iteration. In this case we can skip + # its computation (see _generate_and_score_completions) and instead use per_token_logps.detach(). + # When using vLLM without liger, IS correction is computed inline below using old_per_token_logps. + if inputs.get("old_per_token_logps") is None: + inputs["old_per_token_logps"] = per_token_logps.detach() old_per_token_logps = inputs.get("old_per_token_logps") - old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps + + # Compute inline IS ratio for non-liger vLLM path. This must happen before the loss type switch + # because vespo needs it in get_gamma_weights. For the liger path, IS ratio is pre-computed in + # _generate_and_score_completions and passed via inputs["importance_sampling_ratio"]. + vllm_importance_sampling_ratio = None + if self.use_vllm and self.vllm_importance_sampling_correction: + sampling_per_token_logps = inputs.get("sampling_per_token_logps") + if sampling_per_token_logps is not None and inputs.get("importance_sampling_ratio") is None: + per_token_logps_diff = (old_per_token_logps - sampling_per_token_logps) * mask + + sequence_level_is = self.vllm_importance_sampling_mode in ["sequence_mask", "sequence_truncate"] + if sequence_level_is: + per_sequence_logps_diff = per_token_logps_diff.sum(dim=-1, keepdim=True) + logps_diff = per_sequence_logps_diff + else: + logps_diff = per_token_logps_diff + + vllm_importance_sampling_ratio = torch.exp(logps_diff) + + # vllm_importance_sampling_ratio.shape: + # token_* modes: (B, T) (per-token ratio) + # sequence_* modes: (B, 1) (per-sequence ratio) + + if self.vllm_importance_sampling_mode in ["sequence_truncate", "token_truncate"]: + vllm_importance_sampling_ratio = torch.clamp( + vllm_importance_sampling_ratio, max=self.vllm_importance_sampling_cap + ) + elif self.vllm_importance_sampling_mode in ["sequence_mask", "token_mask"]: + vllm_importance_sampling_ratio = vllm_importance_sampling_ratio.masked_fill( + vllm_importance_sampling_ratio > self.vllm_importance_sampling_cap, value=0.0 + ) + else: + raise ValueError( + f"Unknown vLLM importance sampling level: {self.vllm_importance_sampling_mode}. Possible values are 'token_truncate', 'token_mask', 'sequence_truncate', and 'sequence_mask'." + ) + + inputs["importance_sampling_ratio"] = vllm_importance_sampling_ratio if self.off_policy_mask_threshold is not None: # OPSM should use inference-time logprobs to detect both sources of off-policyness: @@ -2381,6 +2420,43 @@ def _compute_loss(self, model, inputs): if self.use_vllm and self.vllm_importance_sampling_correction and self.loss_type != "vespo": per_token_loss = per_token_loss * inputs["importance_sampling_ratio"] + # Log IS correction metrics for inline-computed IS ratio (non-liger path) + if self.use_vllm and self.vllm_importance_sampling_correction and vllm_importance_sampling_ratio is not None: + mode = "train" if self.model.training else "eval" + delta = torch.abs(old_per_token_logps - sampling_per_token_logps) + delta = delta[mask.bool()] + mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=per_token_logps.device) + max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=per_token_logps.device) + self._metrics[mode]["sampling/sampling_logp_difference/mean"].append( + self.accelerator.gather(mean_delta).mean().item() + ) + self._metrics[mode]["sampling/sampling_logp_difference/max"].append( + self.accelerator.gather(max_delta).max().item() + ) + if sequence_level_is: + flat_is_ratio = vllm_importance_sampling_ratio.flatten() + else: + flat_is_ratio = vllm_importance_sampling_ratio[mask.bool()] + + min_importance_sampling_ratio = ( + torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=per_token_logps.device) + ) + mean_importance_sampling_ratio = ( + torch.mean(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=per_token_logps.device) + ) + max_importance_sampling_ratio = ( + torch.max(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=per_token_logps.device) + ) + self._metrics[mode]["sampling/importance_sampling_ratio/min"].append( + nanmin(self.accelerator.gather(min_importance_sampling_ratio)).item() + ) + self._metrics[mode]["sampling/importance_sampling_ratio/mean"].append( + self.accelerator.gather(mean_importance_sampling_ratio).nanmean().item() + ) + self._metrics[mode]["sampling/importance_sampling_ratio/max"].append( + nanmax(self.accelerator.gather(max_importance_sampling_ratio)).item() + ) + if self.beta != 0.0: per_token_loss = per_token_loss + self.beta * per_token_kl