
Commit c763811

Fix DAPO token-level loss to use prompt-level aggregation
1 parent b27d9de commit c763811
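
In short: for loss_type="dapo", per-token losses are now averaged within each completion first and then averaged across completions (prompt-level aggregation, following the DAPO paper), instead of being summed over the whole batch and divided by a single global normalizer. Below is a minimal standalone sketch of the two schemes; the toy tensors are mine for illustration, and only the aggregation math mirrors the diffs in this commit:

import torch

# Toy values: three completions with 3, 2, and 4 valid tokens
per_token_loss = torch.tensor([[1.0, 2.0, 3.0, 0.0],
                               [4.0, 5.0, 0.0, 0.0],
                               [6.0, 7.0, 8.0, 9.0]])
mask = torch.tensor([[1.0, 1.0, 1.0, 0.0],
                     [1.0, 1.0, 0.0, 0.0],
                     [1.0, 1.0, 1.0, 1.0]])

# New (prompt-level): mean over each completion's tokens, then mean over completions
per_prompt = (per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
print(per_prompt.mean().item())  # (2.0 + 4.5 + 7.5) / 3 = 4.666...

# Old (batch-level): one global sum divided by one global token count
print(((per_token_loss * mask).sum() / mask.sum()).item())  # 45 / 9 = 5.0

With equal-length completions the two coincide; with unequal lengths the batch-level form up-weights longer completions, which is the behavior this commit removes.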

File tree

5 files changed: +81 −7 lines

5 files changed

+81
-7
lines changed

tests/experimental/test_gspo_token_trainer.py

Lines changed: 65 additions & 0 deletions
@@ -13,6 +13,8 @@
 # limitations under the License.
 
 
+import unittest
+
 import torch
 from datasets import load_dataset
 from transformers.utils import is_peft_available
@@ -27,6 +29,69 @@
     pass
 
 
+class TestDAPOLossAggregation(unittest.TestCase):
+    def test_dapo_prompt_level_vs_batch_level_differ_for_unequal_lengths(self):
+        """
+        Verifies that prompt-level averaging (DAPO paper) produces a different loss
+        than batch-level averaging when prompts have unequal completion lengths.
+
+        Prompt-level averaging: average tokens within each prompt, then average across prompts.
+        Batch-level averaging (old, incorrect): sum all tokens and divide by total token count.
+
+        With equal-length completions both methods agree; with unequal lengths they diverge
+        because batch-level averaging implicitly up-weights longer sequences.
+        """
+        torch.manual_seed(0)
+        batch_size = 3
+        max_len = 6
+
+        # Per-token losses: arbitrary positive values
+        per_token_loss = torch.tensor(
+            [
+                [1.0, 2.0, 3.0, 0.0, 0.0, 0.0],  # prompt 0: 3 tokens
+                [4.0, 5.0, 0.0, 0.0, 0.0, 0.0],  # prompt 1: 2 tokens
+                [6.0, 7.0, 8.0, 9.0, 10.0, 11.0],  # prompt 2: 6 tokens
+            ]
+        )
+        completion_mask = torch.tensor(
+            [
+                [1, 1, 1, 0, 0, 0],
+                [1, 1, 0, 0, 0, 0],
+                [1, 1, 1, 1, 1, 1],
+            ],
+            dtype=torch.float,
+        )
+
+        num_items_in_batch = batch_size  # single process
+
+        # --- New (correct): prompt-level averaging via .mean() ---
+        per_prompt_loss = (per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
+        loss_prompt_level = per_prompt_loss.mean()
+
+        # --- Old (incorrect): batch-level averaging ---
+        loss_batch_level = (per_token_loss * completion_mask).sum() / num_items_in_batch
+
+        # The two values must differ because sequence lengths are unequal
+        self.assertNotAlmostEqual(
+            loss_prompt_level.item(),
+            loss_batch_level.item(),
+            places=5,
+            msg="Expected prompt-level and batch-level losses to differ for unequal-length completions",
+        )
+
+        # Sanity-check the prompt-level value manually:
+        # prompt 0: (1+2+3)/3 = 2.0
+        # prompt 1: (4+5)/2 = 4.5
+        # prompt 2: (6+7+8+9+10+11)/6 = 8.5
+        # mean of [2.0, 4.5, 8.5] == (2.0 + 4.5 + 8.5) / 3
+        expected_prompt_level = (2.0 + 4.5 + 8.5) / 3
+        self.assertAlmostEqual(loss_prompt_level.item(), expected_prompt_level, places=5)
+
+        # Batch-level: (1+2+3+4+5+6+7+8+9+10+11) / 3 = 66 / 3 = 22.0
+        expected_batch_level = (1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + 10 + 11) / 3
+        self.assertAlmostEqual(loss_batch_level.item(), expected_batch_level, places=5)
+
+
 class TestGSPOTokenTrainer(TrlTestCase):
     def test_training(self):
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
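
The new test depends only on torch, so it runs in isolation; with a standard pytest setup, something like pytest tests/experimental/test_gspo_token_trainer.py -k dapo should select it (the exact invocation depends on the local environment).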

trl/experimental/dppo/dppo_trainer.py

Lines changed: 3 additions & 2 deletions
@@ -1338,8 +1338,9 @@ def _compute_loss(self, model, inputs):
             per_token_loss = per_token_loss + self.beta * per_token_kl
 
         mode = "train" if self.model.training else "eval"
-        normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
-        loss = (per_token_loss * mask).sum() / normalizer
+        # Average tokens within each prompt first, then average across prompts
+        per_prompt_loss = (per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
+        loss = per_prompt_loss.mean()
 
         # Log metrics
         completion_token_count = mask.sum().clamp(min=1.0)

trl/experimental/gspo_token/grpo_trainer.py

Lines changed: 3 additions & 2 deletions
@@ -116,8 +116,9 @@ def _compute_loss(self, model, inputs):
             normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0  # no accum in eval
             loss = loss / normalizer
         elif self.loss_type == "dapo":
-            normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
-            loss = (per_token_loss * completion_mask).sum() / normalizer
+            # Average tokens within each prompt first, then average across prompts
+            per_prompt_loss = (per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
+            loss = per_prompt_loss.mean()
         else:
             raise ValueError(f"Unknown loss type: {self.loss_type}")

trl/experimental/papo/papo_trainer.py

Lines changed: 3 additions & 2 deletions
@@ -287,8 +287,9 @@ def _compute_loss(self, model, inputs):
             loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
             loss = loss / self.current_gradient_accumulation_steps
         elif self.loss_type == "dapo":
-            normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
-            loss = (per_token_loss * completion_mask).sum() / normalizer
+            # Average tokens within each prompt first, then average across prompts
+            per_prompt_loss = (per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
+            loss = per_prompt_loss.mean() / self.current_gradient_accumulation_steps
         else:
             raise ValueError(f"Unknown loss type: {self.loss_type}")
         # >>> 2. Implicit Perception Loss

trl/trainer/grpo_trainer.py

Lines changed: 7 additions & 1 deletion
@@ -2397,7 +2397,13 @@ def _compute_loss(self, model, inputs):
             loss = (per_token_loss * mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
             normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0  # no accum in eval
             loss = loss / normalizer
-        elif self.loss_type in ["cispo", "dapo", "vespo"]:
+        elif self.loss_type == "dapo":
+            # Average tokens within each prompt first, then average across prompts
+            per_prompt_loss = (per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
+            loss = per_prompt_loss.mean()
+            normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0  # no accum in eval
+            loss = loss / normalizer
+        elif self.loss_type in ["cispo", "vespo"]:
             normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
             loss = (per_token_loss * mask).sum() / normalizer
         elif self.loss_type == "luspo":
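
For reference, here is the new dapo branch of the core trainer condensed into a standalone function; the name dapo_loss and the free-function form are mine for illustration, while the arithmetic matches the hunk above:

import torch

def dapo_loss(per_token_loss: torch.Tensor, mask: torch.Tensor,
              grad_accum_steps: int, training: bool = True) -> torch.Tensor:
    # Mean over each prompt's tokens; the clamp guards against empty completions
    per_prompt_loss = (per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
    loss = per_prompt_loss.mean()  # then mean across prompts
    normalizer = grad_accum_steps if training else 1.0  # no accumulation in eval
    return loss / normalizer

Note that PAPO folds the same gradient-accumulation division directly into its dapo branch, while in these hunks the experimental DPPO and GSPO-token trainers return the prompt-level mean without it.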
