diff --git a/tests/experimental/test_gspo_token_trainer.py b/tests/experimental/test_gspo_token_trainer.py
index 19391d55a18..959aded87ba 100644
--- a/tests/experimental/test_gspo_token_trainer.py
+++ b/tests/experimental/test_gspo_token_trainer.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 
+import unittest
+
 import torch
 from datasets import load_dataset
 from transformers.utils import is_peft_available
 
@@ -27,6 +29,67 @@
     pass
 
 
+class TestDAPOLossAggregation(unittest.TestCase):
+    def test_dapo_prompt_level_vs_batch_level_differ_for_unequal_lengths(self):
+        """
+        Verifies that prompt-level averaging (as in the DAPO paper) produces a different
+        loss than batch-level averaging when prompts have unequal completion lengths.
+
+        Prompt-level averaging: average tokens within each prompt, then average across prompts.
+        Batch-level averaging (old, incorrect): sum all tokens and divide by the total token count.
+
+        With equal-length completions the two methods agree; with unequal lengths they
+        diverge because batch-level averaging implicitly up-weights longer sequences.
+        """
+        batch_size = 3
+
+        # Per-token losses: arbitrary positive values
+        per_token_loss = torch.tensor(
+            [
+                [1.0, 2.0, 3.0, 0.0, 0.0, 0.0],  # prompt 0: 3 tokens
+                [4.0, 5.0, 0.0, 0.0, 0.0, 0.0],  # prompt 1: 2 tokens
+                [6.0, 7.0, 8.0, 9.0, 10.0, 11.0],  # prompt 2: 6 tokens
+            ]
+        )
+        completion_mask = torch.tensor(
+            [
+                [1, 1, 1, 0, 0, 0],
+                [1, 1, 0, 0, 0, 0],
+                [1, 1, 1, 1, 1, 1],
+            ],
+            dtype=torch.float,
+        )
+
+        num_items_in_batch = completion_mask.sum()  # total completion tokens (single process)
+
+        # --- New (correct): prompt-level averaging via .mean() ---
+        per_prompt_loss = (per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
+        loss_prompt_level = per_prompt_loss.mean()
+
+        # --- Old (incorrect): batch-level averaging ---
+        loss_batch_level = (per_token_loss * completion_mask).sum() / num_items_in_batch
+
+        # The two values must differ because the sequence lengths are unequal
+        self.assertNotAlmostEqual(
+            loss_prompt_level.item(),
+            loss_batch_level.item(),
+            places=5,
+            msg="Expected prompt-level and batch-level losses to differ for unequal-length completions",
+        )
+
+        # Sanity-check the prompt-level value manually:
+        # prompt 0: (1+2+3)/3 = 2.0
+        # prompt 1: (4+5)/2 = 4.5
+        # prompt 2: (6+7+8+9+10+11)/6 = 8.5
+        # mean of [2.0, 4.5, 8.5] == (2.0 + 4.5 + 8.5) / 3 = 5.0
+        expected_prompt_level = (2.0 + 4.5 + 8.5) / batch_size
+        self.assertAlmostEqual(loss_prompt_level.item(), expected_prompt_level, places=5)
+
+        # Batch-level: (1+2+3+4+5+6+7+8+9+10+11) / 11 = 66/11 = 6.0
+        expected_batch_level = (1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + 10 + 11) / 11
+        self.assertAlmostEqual(loss_batch_level.item(), expected_batch_level, places=5)
+
+
 class TestGSPOTokenTrainer(TrlTestCase):
     def test_training(self):
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
diff --git a/trl/experimental/dppo/dppo_trainer.py b/trl/experimental/dppo/dppo_trainer.py
index 30bbbdd945e..a783a098ffc 100644
--- a/trl/experimental/dppo/dppo_trainer.py
+++ b/trl/experimental/dppo/dppo_trainer.py
@@ -1338,8 +1338,9 @@ def _compute_loss(self, model, inputs):
             per_token_loss = per_token_loss + self.beta * per_token_kl
 
         mode = "train" if self.model.training else "eval"
-        normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
-        loss = (per_token_loss * mask).sum() / normalizer
+        # Average tokens within each prompt first, then average across prompts
+        per_prompt_loss = (per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
+        loss = per_prompt_loss.mean()
 
         # Log metrics
         completion_token_count = mask.sum().clamp(min=1.0)
diff --git a/trl/experimental/gspo_token/grpo_trainer.py b/trl/experimental/gspo_token/grpo_trainer.py
index 0522058bc6e..5abd0dab911 100644
--- a/trl/experimental/gspo_token/grpo_trainer.py
+++ b/trl/experimental/gspo_token/grpo_trainer.py
@@ -116,8 +116,9 @@ def _compute_loss(self, model, inputs):
             normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0  # no accum in eval
             loss = loss / normalizer
         elif self.loss_type == "dapo":
-            normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
-            loss = (per_token_loss * completion_mask).sum() / normalizer
+            # Average tokens within each prompt first, then average across prompts
+            per_prompt_loss = (per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
+            loss = per_prompt_loss.mean()
         else:
             raise ValueError(f"Unknown loss type: {self.loss_type}")
 
diff --git a/trl/experimental/papo/papo_trainer.py b/trl/experimental/papo/papo_trainer.py
index 931fee3efe9..5af4956f823 100644
--- a/trl/experimental/papo/papo_trainer.py
+++ b/trl/experimental/papo/papo_trainer.py
@@ -287,8 +287,9 @@ def _compute_loss(self, model, inputs):
             loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
             loss = loss / self.current_gradient_accumulation_steps
         elif self.loss_type == "dapo":
-            normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
-            loss = (per_token_loss * completion_mask).sum() / normalizer
+            # Average tokens within each prompt first, then average across prompts
+            per_prompt_loss = (per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
+            loss = per_prompt_loss.mean() / self.current_gradient_accumulation_steps
         else:
             raise ValueError(f"Unknown loss type: {self.loss_type}")
         # >>> 2. Implicit Perception Loss
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index a631de07918..1173fee67d1 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -2397,7 +2397,13 @@ def _compute_loss(self, model, inputs):
             loss = (per_token_loss * mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
             normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0  # no accum in eval
             loss = loss / normalizer
-        elif self.loss_type in ["cispo", "dapo", "vespo"]:
+        elif self.loss_type == "dapo":
+            # Average tokens within each prompt first, then average across prompts
+            per_prompt_loss = (per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
+            loss = per_prompt_loss.mean()
+            normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0  # no accum in eval
+            loss = loss / normalizer
+        elif self.loss_type in ["cispo", "vespo"]:
             normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
             loss = (per_token_loss * mask).sum() / normalizer
         elif self.loss_type == "luspo":