Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions tests/experimental/test_gspo_token_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.


import unittest

import torch
from datasets import load_dataset
from transformers.utils import is_peft_available
Expand All @@ -27,6 +29,69 @@
pass


class TestDAPOLossAggregation(unittest.TestCase):
    """Tests for the DAPO loss-aggregation change: prompt-level vs batch-level averaging."""

    def test_dapo_prompt_level_vs_batch_level_differ_for_unequal_lengths(self):
        """
        Verifies that prompt-level averaging (DAPO paper) produces a different loss
        than batch-level averaging when prompts have unequal completion lengths.

        Prompt-level averaging: average tokens within each prompt, then average across prompts.
        Batch-level averaging (old, incorrect): sum all token losses and divide by the
        global normalizer (num_items_in_batch, i.e. the number of prompts).

        With equal-length completions both methods agree; with unequal lengths they diverge
        because batch-level averaging implicitly up-weights longer sequences.
        """
        batch_size = 3

        # Per-token losses: arbitrary positive values. Zeros pad positions beyond each
        # prompt's completion length; the mask below excludes them from the averages.
        per_token_loss = torch.tensor(
            [
                [1.0, 2.0, 3.0, 0.0, 0.0, 0.0],  # prompt 0: 3 tokens
                [4.0, 5.0, 0.0, 0.0, 0.0, 0.0],  # prompt 1: 2 tokens
                [6.0, 7.0, 8.0, 9.0, 10.0, 11.0],  # prompt 2: 6 tokens
            ]
        )
        completion_mask = torch.tensor(
            [
                [1, 1, 1, 0, 0, 0],
                [1, 1, 0, 0, 0, 0],
                [1, 1, 1, 1, 1, 1],
            ],
            dtype=torch.float,
        )

        # On a single process, num_items_in_batch equals the batch size.
        num_items_in_batch = batch_size

        # --- New (correct): prompt-level averaging via .mean() ---
        # clamp(min=1.0) guards against division by zero for an all-masked row.
        per_prompt_loss = (per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
        loss_prompt_level = per_prompt_loss.mean()

        # --- Old (incorrect): batch-level averaging ---
        loss_batch_level = (per_token_loss * completion_mask).sum() / num_items_in_batch

        # The two values must differ because sequence lengths are unequal
        self.assertNotAlmostEqual(
            loss_prompt_level.item(),
            loss_batch_level.item(),
            places=5,
            msg="Expected prompt-level and batch-level losses to differ for unequal-length completions",
        )

        # Sanity-check the prompt-level value manually:
        # prompt 0: (1+2+3)/3 = 2.0
        # prompt 1: (4+5)/2 = 4.5
        # prompt 2: (6+7+8+9+10+11)/6 = 8.5
        # mean of [2.0, 4.5, 8.5] == (2.0 + 4.5 + 8.5) / 3
        expected_prompt_level = (2.0 + 4.5 + 8.5) / 3
        self.assertAlmostEqual(loss_prompt_level.item(), expected_prompt_level, places=5)

        # Batch-level: (1+2+3+4+5+6+7+8+9+10+11) / 3 = 66/3 = 22.0
        expected_batch_level = (1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + 10 + 11) / 3
        self.assertAlmostEqual(loss_batch_level.item(), expected_batch_level, places=5)


class TestGSPOTokenTrainer(TrlTestCase):
def test_training(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
Expand Down
5 changes: 3 additions & 2 deletions trl/experimental/dppo/dppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1338,8 +1338,9 @@ def _compute_loss(self, model, inputs):
per_token_loss = per_token_loss + self.beta * per_token_kl

mode = "train" if self.model.training else "eval"
normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
loss = (per_token_loss * mask).sum() / normalizer
# Average tokens within each prompt first, then average across prompts
per_prompt_loss = (per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
loss = per_prompt_loss.mean()
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing gradient accumulation normalizer in dppo trainer

High Severity

The dppo trainer's loss computation now diverges from the main grpo_trainer.py dapo branch, which was previously identical. The main trainer divides by current_gradient_accumulation_steps during training, but the dppo change omits this normalizer entirely. The old code's num_items_in_batch / num_processes normalizer implicitly handled gradient accumulation since num_items_in_batch spanned all micro-batches; .mean() only averages within a single micro-batch, so accumulated gradients will be too large by a factor of gradient_accumulation_steps.

Additional Locations (1)
Fix in Cursor Fix in Web

Triggered by project rule: ../.ai/AGENTS.md


# Log metrics
completion_token_count = mask.sum().clamp(min=1.0)
Expand Down
5 changes: 3 additions & 2 deletions trl/experimental/gspo_token/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,9 @@ def _compute_loss(self, model, inputs):
normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval
loss = loss / normalizer
elif self.loss_type == "dapo":
normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
loss = (per_token_loss * completion_mask).sum() / normalizer
# Average tokens within each prompt first, then average across prompts
per_prompt_loss = (per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
loss = per_prompt_loss.mean()
else:
raise ValueError(f"Unknown loss type: {self.loss_type}")

Expand Down
5 changes: 3 additions & 2 deletions trl/experimental/papo/papo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,9 @@ def _compute_loss(self, model, inputs):
loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
loss = loss / self.current_gradient_accumulation_steps
elif self.loss_type == "dapo":
normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
loss = (per_token_loss * completion_mask).sum() / normalizer
# Average tokens within each prompt first, then average across prompts
per_prompt_loss = (per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
loss = per_prompt_loss.mean() / self.current_gradient_accumulation_steps
else:
raise ValueError(f"Unknown loss type: {self.loss_type}")
# >>> 2. Implicit Perception Loss
Expand Down
8 changes: 7 additions & 1 deletion trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2397,7 +2397,13 @@ def _compute_loss(self, model, inputs):
loss = (per_token_loss * mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval
loss = loss / normalizer
elif self.loss_type in ["cispo", "dapo", "vespo"]:
elif self.loss_type == "dapo":
# Average tokens within each prompt first, then average across prompts
per_prompt_loss = (per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
loss = per_prompt_loss.mean()
normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval
loss = loss / normalizer
elif self.loss_type in ["cispo", "vespo"]:
normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
loss = (per_token_loss * mask).sum() / normalizer
elif self.loss_type == "luspo":
Expand Down