DPO transformers v0.29 fixes #3560
Conversation
📝 Walkthrough
Removing deprecated DPO/RL configuration parameters.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 2 passed, ❌ 1 failed (1 warning)
Actionable comments posted: 1
🧹 Nitpick comments (2)
src/axolotl/utils/data/utils.py (1)
354-363: In-place mutation and type assumption on dict values.

The function mutates `example` in-place by reassigning `example[key]`. This is fine if callers expect mutation, but could be surprising. Additionally, the loop assumes all values in `example` are sliceable (lists). If `example` contains non-list metadata (e.g., a scalar `length` field), this will raise a `TypeError`. Consider either:
- Documenting that mutation is intentional and all values must be lists, or
- Adding a safeguard for non-list values
💡 Optional safeguard for non-list values
```diff
 def remove_double_bos_token(example: dict[str, list], bos_token_id: int | None):
     """Remove double bos tokens that may occur when retokenizing preprocessed data
-    for tokenizers and chat templates that have a bos_token - eg. DPO + Llama.
+    for tokenizers and chat templates that have a bos_token - eg. DPO + Llama.
+
+    Note: Mutates `example` in-place. All values must be list-like.
     """
     if bos_token_id is not None:
         input_ids = example["input_ids"]
         if len(input_ids) >= 2 and input_ids[0] == input_ids[1] == bos_token_id:
             for key in example:
-                example[key] = example[key][1:]
+                if isinstance(example[key], list):
+                    example[key] = example[key][1:]
     return example
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/utils/data/utils.py` around lines 354 - 363: The function `remove_double_bos_token` mutates `example` in-place and assumes every `example[key]` is a sliceable list, which can raise `TypeError` for scalar metadata; change it to either return a new dict or guard mutations: when `bos_token_id` is not None and a double-BOS is detected, iterate keys and only slice/modify values that are instances of `list` (or `collections.abc.Sequence`) and leave non-list values unchanged (or copy them into the new dict if you choose to return a new object); ensure the function's docstring is updated to state whether mutation is intentional and that only list-like fields are affected, and reference the `input_ids` check and keys loop (`example["input_ids"]` and `for key in example`) when applying the guard.

tests/utils/data/test_utils.py (1)
544-582: Consider adding edge case tests for short sequences.

The tests cover the main scenarios well. Consider adding tests for edge cases:
- Empty `input_ids` list (would fail on the `len(input_ids) >= 2` check)
- Single-element `input_ids` list
- Exactly two elements where both are `bos_token_id` (result would be a single-element list)

Also, these tests use `assert` statements while the rest of the file uses `self.assertEqual` - a minor style inconsistency.

💡 Suggested additional test case
```python
def test_remove_bos_token_boundary_length_two(self):
    """Test when input_ids has exactly two elements both being bos_token_id."""
    input_ids = [0, 0]
    labels = [1, 2]
    example = {
        "input_ids": input_ids,
        "labels": labels,
    }
    example = remove_double_bos_token(example, 0)
    self.assertEqual(example["input_ids"], [0])
    self.assertEqual(example["labels"], [2])

def test_short_input_ids_no_error(self):
    """Test that short input_ids (len < 2) don't cause errors."""
    example = {"input_ids": [0], "labels": [1]}
    result = remove_double_bos_token(example, 0)
    self.assertEqual(result["input_ids"], [0])
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/utils/data/test_utils.py` around lines 544 - 582, Add unit tests in TestRemoveDoubleBOSToken to cover short-sequence edge cases and fix style: add three new test methods that call remove_double_bos_token to verify behavior for (1) empty input_ids and labels (ensure it returns unchanged and does not error), (2) single-element input_ids (len==1) with bos_token_id and non-bos and assert it returns the same sequence, and (3) boundary case of exactly two elements both equal to bos_token_id to assert it collapses to a single-element result; use self.assertEqual instead of bare assert to match existing style and reference the existing TestRemoveDoubleBOSToken class and remove_double_bos_token function to locate where to add these tests.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/axolotl/core/trainers/dpo/trainer.py`:
- Around line 58-74: The _tokenize method currently assumes processing_class has
bos_token_id which is only guaranteed on PreTrainedTokenizerBase; update
_tokenize (and rename the parameter input to inputs to avoid shadowing) to first
resolve a tokenizer that exposes bos_token_id—e.g., if
isinstance(processing_class, PreTrainedTokenizerBase) use processing_class, else
try getattr(processing_class, "tokenizer", None) or getattr(processing_class,
"processor", None), and then check hasattr(tokenizer,
"bos_token_id"); only call remove_double_bos_token(result, bos_id) when bos_id
is present, otherwise return result unchanged; keep references to the existing
_tokenize method, ProcessorMixin, PreTrainedTokenizerBase,
remove_double_bos_token, and bos_token_id to locate the change.
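The tokenizer-resolution logic this comment asks for could look roughly like the sketch below. It uses duck typing in place of the `isinstance(processing_class, PreTrainedTokenizerBase)` check the comment suggests (to stay self-contained, without importing transformers); the fake classes are illustrative stand-ins, not real Axolotl or TRL types:

```python
def resolve_bos_token_id(processing_class):
    """Resolve bos_token_id from a tokenizer, or from a processor wrapping one.

    Processor classes (e.g. ProcessorMixin subclasses) typically expose the
    underlying tokenizer as a `.tokenizer` attribute.
    """
    if hasattr(processing_class, "bos_token_id"):
        return processing_class.bos_token_id
    tokenizer = getattr(processing_class, "tokenizer", None)
    return getattr(tokenizer, "bos_token_id", None)

# Illustrative stand-ins for a tokenizer and a processor wrapping one
class FakeTokenizer:
    bos_token_id = 0

class FakeProcessor:
    tokenizer = FakeTokenizer()

print(resolve_bos_token_id(FakeTokenizer()))  # 0
print(resolve_bos_token_id(FakeProcessor()))  # 0
print(resolve_bos_token_id(object()))         # None
```

Per the comment, `remove_double_bos_token(result, bos_id)` would then only be called when the resolved `bos_id` is not `None`.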
---
ℹ️ Review info
⚙️ Run configuration
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 07aa515d-4a0e-437c-94b8-b7cb6c06969d
📒 Files selected for processing (13)
- src/axolotl/core/builders/rl.py
- src/axolotl/core/trainers/base.py
- src/axolotl/core/trainers/dpo/__init__.py
- src/axolotl/core/trainers/dpo/args.py
- src/axolotl/core/trainers/dpo/trainer.py
- src/axolotl/prompt_strategies/bradley_terry/chat_template.py
- src/axolotl/prompt_strategies/orpo/chat_template.py
- src/axolotl/utils/data/utils.py
- src/axolotl/utils/schemas/config.py
- src/axolotl/utils/schemas/deprecated.py
- tests/e2e/test_dpo.py
- tests/test_prompt_tokenizers.py
- tests/utils/data/test_utils.py
💤 Files with no reviewable changes (2)
- src/axolotl/core/builders/rl.py
- tests/e2e/test_dpo.py
NanoCode012 left a comment
Thanks for the cleanup, took a glance and noted the below.
src/axolotl/utils/schemas/config.py (outdated)

```python
dpo_norm_loss: bool | None = Field(
    default=None,
    deprecated="Deprecated in v0.15.1 due to breaking changes in TRL >=v0.29.0. Will be readded upon TRL support.",
)
```
We should remove this, as this class inherits the deprecated parameters and this would be a duplicate. Same for the other change below.
```python
@with_temp_dir
def test_dpo_nll_lora(self, temp_dir):
    cfg = DictDefault(
```
I couldn't remember if this test was specifically for `rpo_alpha`. If it is, we'd need to adjust to support it, or remove this?
Thanks for catching this: the configs are identical without the parameter, so you're probably right. Will remove.
Codecov Report: ❌ Patch coverage is
Description
Summary of changes:
- Remove `dpo_norm_loss`. As outlined in #3548 (DPO `dpo_norm_loss` no longer works on `trl==0.29.0`), the 0.29 refactors in TRL's DPOTrainer break Axolotl's implementation. The goal is to add this back in once TRL natively supports it; PR already here: huggingface/trl#5406 (Add length-normalized sigmoid loss type to DPO trainer).
- Rename input keys from `chosen/rejected_input_ids` to `chosen/rejected_ids` to be consistent with TRL (huggingface/trl#5179: Rename input keys in RewardTrainer collator from `chosen/rejected_input_ids` to `chosen/rejected_ids`).
- Remove `rpo_alpha`. RPO is now configured by passing the list `loss_type=["sigmoid", "sft"]`: https://github.com/huggingface/trl/blob/main/docs/source/paper_index.md#iterative-reasoning-preference-optimization
- Replace the `tokenize_row` override (deprecated) with `_tokenize` to handle bos tokens. Previously, this override handled bos token bugs. The only bug that still exists is the double bos token bug for tokenizers with bos_tokens, such as Llama. The new `_tokenize` method handles this.
- Change `loss_type` to a list (was previously a string). In TRL 0.29, DPOTrainer's `loss_type` now takes a list of strings rather than a single string, allowing multiple losses to be combined. Note: this still needs to be supported; I have made a separate issue here: #3565 (Support DPO `loss_type` and `loss_weights`).

I recommend reviewing commit by commit.
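The `rpo_alpha` change above can be illustrated with a small sketch. The `migrate_rpo` helper below is hypothetical (it is not part of this PR or of Axolotl); it only shows the shape of the old and new configurations, where RPO moves from a single loss plus a weight to a list of combined losses:

```python
# Old-style RPO (pre TRL 0.29): a single loss plus an rpo_alpha weight
old_cfg = {"loss_type": "sigmoid", "rpo_alpha": 1.0}

# New-style RPO (TRL >= 0.29): combine losses by listing them
new_cfg = {"loss_type": ["sigmoid", "sft"]}

def migrate_rpo(cfg: dict) -> dict:
    """Hypothetical helper: translate an old rpo_alpha config to the list form."""
    out = dict(cfg)
    if out.pop("rpo_alpha", None) is not None:
        loss = out.get("loss_type", "sigmoid")
        out["loss_type"] = [loss, "sft"] if isinstance(loss, str) else list(loss) + ["sft"]
    return out

print(migrate_rpo(old_cfg))  # {'loss_type': ['sigmoid', 'sft']}
```

Note that the loss weights previously expressed via `rpo_alpha` have no equivalent here; per the description, weighting combined losses is tracked separately in #3565.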
Motivation and Context
Breaking changes were introduced in TRL v0.29.0 for DPO, parts of Axolotl need to be updated to interface with the new code. eg. #3548.
How has this been tested?
Unit tests
AI Usage Disclaimer
All fixes written completely by me. Claude helped find some but not all of the bugs.