|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 |
|
| 16 | +import unittest |
| 17 | + |
16 | 18 | import torch |
17 | 19 | from datasets import load_dataset |
18 | 20 | from transformers.utils import is_peft_available |
|
27 | 29 | pass |
28 | 30 |
|
29 | 31 |
|
class TestDAPOLossAggregation(unittest.TestCase):
    def test_dapo_prompt_level_vs_batch_level_differ_for_unequal_lengths(self):
        """
        Verify that prompt-level averaging (DAPO paper) produces a different loss
        than batch-level normalization when prompts have unequal completion lengths.

        Prompt-level averaging (correct): average token losses within each prompt,
        then average across prompts.
        Batch-level normalization (old, incorrect): sum all token losses and divide
        by the number of prompts in the batch (``num_items_in_batch``).

        With equal-length completions the two methods differ only by the shared
        sequence-length factor; with unequal lengths they diverge because batch-level
        normalization implicitly up-weights longer sequences.
        """
        torch.manual_seed(0)  # cheap determinism guard; no random ops below
        batch_size = 3

        # Per-token losses: arbitrary positive values; zeros fill padded positions.
        per_token_loss = torch.tensor(
            [
                [1.0, 2.0, 3.0, 0.0, 0.0, 0.0],  # prompt 0: 3 tokens
                [4.0, 5.0, 0.0, 0.0, 0.0, 0.0],  # prompt 1: 2 tokens
                [6.0, 7.0, 8.0, 9.0, 10.0, 11.0],  # prompt 2: 6 tokens
            ]
        )
        completion_mask = torch.tensor(
            [
                [1, 1, 1, 0, 0, 0],
                [1, 1, 0, 0, 0, 0],
                [1, 1, 1, 1, 1, 1],
            ],
            dtype=torch.float,
        )

        num_items_in_batch = batch_size  # single process

        # --- New (correct): prompt-level averaging via .mean() ---
        # clamp(min=1.0) guards against division by zero for fully-masked rows.
        per_prompt_loss = (per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
        loss_prompt_level = per_prompt_loss.mean()

        # --- Old (incorrect): batch-level normalization ---
        loss_batch_level = (per_token_loss * completion_mask).sum() / num_items_in_batch

        # The two values must differ because sequence lengths are unequal.
        self.assertNotAlmostEqual(
            loss_prompt_level.item(),
            loss_batch_level.item(),
            places=5,
            msg="Expected prompt-level and batch-level losses to differ for unequal-length completions",
        )

        # Sanity-check the prompt-level value manually:
        # prompt 0: (1+2+3)/3 = 2.0
        # prompt 1: (4+5)/2 = 4.5
        # prompt 2: (6+7+8+9+10+11)/6 = 8.5
        # mean([2.0, 4.5, 8.5]) = (2.0 + 4.5 + 8.5) / 3 = 5.0
        expected_prompt_level = (2.0 + 4.5 + 8.5) / 3
        self.assertAlmostEqual(loss_prompt_level.item(), expected_prompt_level, places=5)

        # Batch-level: (1+2+...+11) / 3 = 66 / 3 = 22.0
        expected_batch_level = (1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + 10 + 11) / 3
        self.assertAlmostEqual(loss_batch_level.item(), expected_batch_level, places=5)
| 94 | + |
30 | 95 | class TestGSPOTokenTrainer(TrlTestCase): |
31 | 96 | def test_training(self): |
32 | 97 | dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") |
|
0 commit comments