
Fix DAPO token-level loss to use prompt-level aggregation#5381

Open
matdou wants to merge 1 commit into huggingface:main from matdou:fix/dapo-token-loss-aggregation
Conversation

@matdou matdou commented Mar 26, 2026

What does this PR do?

Fix the DAPO token-level loss aggregation to use prompt-level averaging instead
of batch-level averaging, as specified in the DAPO paper.

The dapo loss branch normalized over all tokens in the batch globally:

normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
loss = (per_token_loss * completion_mask).sum() / normalizer

This up-weights longer sequences since a prompt with 6 tokens contributes twice
as much as a prompt with 3 tokens. The DAPO paper specifies prompt-level
averaging: average token losses within each prompt first, then average across
prompts. The fix follows the same pattern already used by the grpo loss type
in the same file:

per_prompt_loss = (per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
loss = per_prompt_loss.mean()
normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0
loss = loss / normalizer

Fixes #5375
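As a minimal standalone sketch (not TRL code; the tensor values and shapes are illustrative assumptions), the two aggregation schemes can be compared directly. With completions of 3 and 6 unmasked tokens, the batch-level sum weights the longer completion's tokens into the result twice as heavily, while the prompt-level scheme gives each prompt equal weight:

```python
import torch

# Two completions of unequal length: 3 and 6 unmasked tokens (illustrative values).
per_token_loss = torch.tensor([
    [1.0, 1.0, 1.0, 0.0, 0.0, 0.0],
    [2.0, 2.0, 2.0, 2.0, 2.0, 2.0],
])
mask = torch.tensor([
    [1.0, 1.0, 1.0, 0.0, 0.0, 0.0],
    [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
])

# Batch-level (old behavior): every token is weighted equally, so the
# 6-token completion contributes twice as many terms as the 3-token one.
batch_level = (per_token_loss * mask).sum() / mask.sum()  # (3*1 + 6*2) / 9 = 15/9

# Prompt-level (the fix): average within each prompt, then across prompts.
per_prompt = (per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)  # [1.0, 2.0]
prompt_level = per_prompt.mean()  # (1.0 + 2.0) / 2 = 1.5

print(batch_level.item(), prompt_level.item())  # ~1.667 vs 1.5
```

This mirrors the unit test described below: with unequal completion lengths the two schemes provably diverge, and only the second matches the DAPO paper's prompt-level averaging.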

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case. → DAPO token level loss is incorrectly implemented #5375
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests? → Added unit test with unequal prompt
    lengths verifying prompt-level vs batch-level aggregation produce different values.

Who can review?

Anyone in the community is free to review the PR once the tests have passed.


Note

Medium Risk
Changes the dapo loss computation in multiple trainers, which will alter optimization dynamics and reported losses, especially with variable-length completions. Risk is moderate because it affects core training math but is a localized, well-defined normalization change with a targeted unit test.

Overview
Fixes dapo loss aggregation to match the paper: average token losses per prompt, then average across prompts, instead of globally summing tokens and normalizing by num_items_in_batch (which over-weights longer completions).

Applies this prompt-level averaging in trl/trainer/grpo_trainer.py and the experimental trainers gspo_token/grpo_trainer.py, papo_trainer.py, and dppo_trainer.py, and adds a unit test (tests/experimental/test_gspo_token_trainer.py) that asserts prompt-level vs batch-level averaging diverge for unequal completion lengths.

Written by Cursor Bugbot for commit c763811.

@matdou matdou force-pushed the fix/dapo-token-loss-aggregation branch from 9ecde08 to 1755a4e on March 26, 2026 at 17:39
@matdou matdou force-pushed the fix/dapo-token-loss-aggregation branch from 1755a4e to b28e1ce on March 26, 2026 at 17:43
@matdou matdou force-pushed the fix/dapo-token-loss-aggregation branch from b28e1ce to 2035933 on March 26, 2026 at 17:51
@matdou matdou force-pushed the fix/dapo-token-loss-aggregation branch 2 times, most recently from fb406ac to 30f26bc on March 26, 2026 at 18:03
@matdou matdou force-pushed the fix/dapo-token-loss-aggregation branch from 30f26bc to c763811 on March 26, 2026 at 18:09

@cursor cursor bot left a comment

Cursor Bugbot has reviewed your changes and found 1 potential issue.


- loss = (per_token_loss * mask).sum() / normalizer
+ # Average tokens within each prompt first, then average across prompts
+ per_prompt_loss = (per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
+ loss = per_prompt_loss.mean()

Missing gradient accumulation normalizer in dppo trainer

High Severity

The dppo trainer's loss computation now diverges from the main grpo_trainer.py dapo branch, which was previously identical. The main trainer divides by current_gradient_accumulation_steps during training, but the dppo change omits this normalizer entirely. The old code's num_items_in_batch / num_processes normalizer implicitly handled gradient accumulation since num_items_in_batch spanned all micro-batches; .mean() only averages within a single micro-batch, so accumulated gradients will be too large by a factor of gradient_accumulation_steps.
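The effect Bugbot describes can be shown with a toy sketch (illustrative, not TRL code): under gradient accumulation, calling .backward() on each micro-batch mean without dividing by the number of accumulation steps sums the per-micro-batch gradients instead of averaging them, inflating the accumulated gradient by a factor of accum_steps:

```python
import torch

accum_steps = 2
micro_batches = [torch.tensor([1.0, 3.0]), torch.tensor([2.0, 4.0])]

# Missing normalizer: gradients from each micro-batch simply add up.
w = torch.tensor(1.0, requires_grad=True)
for mb in micro_batches:
    loss = (w * mb).mean()          # no "/ accum_steps"
    loss.backward()
# d/dw of (w * mb).mean() is mb.mean(): 2.0 + 3.0 accumulate to 5.0
print(w.grad.item())                # 5.0 (too large by accum_steps)

# With the normalizer: the accumulated gradient equals the mean over
# all examples across micro-batches.
w2 = torch.tensor(1.0, requires_grad=True)
for mb in micro_batches:
    loss = (w2 * mb).mean() / accum_steps
    loss.backward()
print(w2.grad.item())               # 2.5 = mean of [1, 3, 2, 4]
```

This is why the main grpo_trainer.py dapo branch divides by current_gradient_accumulation_steps during training; omitting that division in the dppo trainer leaves its accumulated gradients scaled up by gradient_accumulation_steps.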

Additional Locations (1)
