Fix DAPO token-level loss to use prompt-level aggregation #5381
matdou wants to merge 1 commit into huggingface:main
Conversation
Cursor Bugbot has reviewed your changes and found 1 potential issue.
```diff
- loss = (per_token_loss * mask).sum() / normalizer
+ # Average tokens within each prompt first, then average across prompts
+ per_prompt_loss = (per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
+ loss = per_prompt_loss.mean()
```
Missing gradient accumulation normalizer in dppo trainer
High Severity
The dppo trainer's loss computation now diverges from the main grpo_trainer.py dapo branch, which was previously identical. The main trainer divides by current_gradient_accumulation_steps during training, but the dppo change omits this normalizer entirely. The old code's num_items_in_batch / num_processes normalizer implicitly handled gradient accumulation since num_items_in_batch spanned all micro-batches; .mean() only averages within a single micro-batch, so accumulated gradients will be too large by a factor of gradient_accumulation_steps.
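The scaling effect Bugbot describes can be demonstrated with a toy sketch (the tensors and names below are illustrative, not TRL code): accumulating per-micro-batch `.mean()` losses over G micro-batches without dividing by G yields a gradient G times larger than a single full-batch mean.

```python
import torch

G = 4  # gradient accumulation steps (assumed for illustration)
w = torch.tensor(1.0, requires_grad=True)
micro_losses = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0]),
                torch.tensor([5.0, 6.0]), torch.tensor([7.0, 8.0])]

# Accumulate per-micro-batch means with no 1/G normalizer
for batch in micro_losses:
    (w * batch).mean().backward()
grad_unnormalized = w.grad.item()

w.grad = None
# Reference: one mean over the full batch
(w * torch.cat(micro_losses)).mean().backward()
grad_reference = w.grad.item()

# Accumulated gradient is G times the reference
print(grad_unnormalized, grad_reference)
```

Dividing each micro-batch loss by `current_gradient_accumulation_steps`, as the main trainer does, removes this factor.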


What does this PR do?
Fix the DAPO token-level loss aggregation to use prompt-level averaging instead
of batch-level averaging, as specified in the DAPO paper.
The `dapo` loss type was normalizing over all tokens in the batch globally. This up-weights longer sequences: a prompt with 6 tokens contributes twice as much as a prompt with 3 tokens. The DAPO paper specifies prompt-level averaging: average token losses within each prompt first, then average across prompts. The fix follows the same pattern already used by the `grpo` loss type in the same file.
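A minimal numeric illustration of the difference (toy tensors, assumed shapes; the real implementation lives in `grpo_trainer.py`):

```python
import torch

# Assumed shapes: (num_prompts, seq_len); mask zeros out padding tokens.
# Prompt 0 has 3 valid tokens, prompt 1 has 6.
per_token_loss = torch.tensor([[1.0, 1.0, 1.0, 0.0, 0.0, 0.0],
                               [2.0, 2.0, 2.0, 2.0, 2.0, 2.0]])
mask = torch.tensor([[1.0, 1.0, 1.0, 0.0, 0.0, 0.0],
                     [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]])

# Batch-level (old): the 6-token prompt gets double weight -> (3 + 12) / 9
batch_level = (per_token_loss * mask).sum() / mask.sum()

# Prompt-level (DAPO paper): per-prompt mean, then mean over prompts -> (1 + 2) / 2
per_prompt = (per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
prompt_level = per_prompt.mean()

print(batch_level.item(), prompt_level.item())
```

With equal-length completions the two aggregations coincide; they diverge exactly when completion lengths differ, which is what the unit test checks.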
Fixes #5375
Before submitting
- Linked issue: DAPO token level loss is incorrectly implemented #5375
- Added a unit test with unequal completion lengths verifying that prompt-level and batch-level aggregation produce different values.
Who can review?
Anyone in the community is free to review the PR once the tests have passed.
Note
Medium Risk
Changes the `dapo` loss computation in multiple trainers, which will alter optimization dynamics and reported losses, especially with variable-length completions. Risk is moderate because it affects core training math but is a localized, well-defined normalization change with a targeted unit test.
Overview
Fixes the `dapo` loss aggregation to match the paper: average token losses per prompt, then average across prompts, instead of globally summing tokens and normalizing by `num_items_in_batch` (which over-weights longer completions). Applies this prompt-level averaging in `trl/trainer/grpo_trainer.py` and the experimental trainers `gspo_token/grpo_trainer.py`, `papo_trainer.py`, and `dppo_trainer.py`, and adds a unit test (`tests/experimental/test_gspo_token_trainer.py`) that asserts prompt-level and batch-level averaging diverge for unequal completion lengths.
Written by Cursor Bugbot for commit c763811.
trl/trainer/grpo_trainer.pyand the experimental trainersgspo_token/grpo_trainer.py,papo_trainer.py, anddppo_trainer.py, and adds a unit test (tests/experimental/test_gspo_token_trainer.py) that asserts prompt-level vs batch-level averaging diverge for unequal completion lengths.Written by Cursor Bugbot for commit c763811. This will update automatically on new commits. Configure here.