DPO support loss types#3566
Conversation
e813169 to
f484ccf
Compare
|
Important Review skippedAuto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
📝 WalkthroughWalkthroughThis PR consolidates DPO and IPO training modes by introducing Changes
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes Possibly related issues
Possibly related PRs
Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
|
@NanoCode012 when you have some time, would love to get your thoughts on the current state of the changes :) In particular, are we happy to support all loss_types and loss_weights that TRL does? these should reasonably well-tested upstream, so hopefully there isn't too much risk in supporting them. |
f484ccf to
729a2e6
Compare
|
@BrownianNotion , what other types do you have in mind? or do you mean a passthrough like you've done? |
I mean passing through the |
NanoCode012
left a comment
There was a problem hiding this comment.
Sorry this took a while, just a bit of cleanup would do. I think this config naming is ok with me since it follows existing convention.
| @model_validator(mode="before") | ||
| @classmethod | ||
| def check_dpo(cls, data): | ||
| if data.get("rl") == "dpo": |
There was a problem hiding this comment.
Can we add validation, if dpo_loss_type or the other is set, and not rl: dpo, we should also raise ValueError given it's dpo specific.
| loss_types = data.get("dpo_loss_type") | ||
| loss_weights = data.get("dpo_loss_weights") | ||
|
|
||
| if loss_types and loss_weights and len(loss_types) != len(loss_weights): |
There was a problem hiding this comment.
Re: loss type validation mentioned by the other user on the Issue, does trl not validate that config upstream and error appropriately?
I'm a bit hesitant on hardcoding any constant/lists for loss types.
There was a problem hiding this comment.
Your call - I figured we would want to error early rather than have the user wait for the runtime trl error. But I understand that this would have added maintenance costs given the hard coding (there isn't an enum we can just import from trl if I remember correctly)
We can recommend people to use the new method instead on the ipo docs, while backwards compat re-routing to passing loss types if rl: ipo is passed. I think this is a good transition. |
729a2e6 to
ebdbb2e
Compare
There was a problem hiding this comment.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/axolotl/utils/schemas/config.py`:
- Around line 312-320: The schema currently allows empty lists for dpo_loss_type
and dpo_loss_weights which lets invalid configs pass; update both Field
declarations (dpo_loss_type and dpo_loss_weights) to require at least one
element (e.g., add min_items=1 to Field) so a non-empty list is enforced when
provided, and add a schema-level validator (e.g., a root_validator or a method
named validate_dpo_loss_consistency) to ensure that when both lists are present
their lengths match and neither is empty.
In `@src/axolotl/utils/schemas/validation.py`:
- Around line 788-797: The validation for rl=="dpo" incorrectly uses truthiness
so empty lists bypass the length check and the error message has a typo; update
the condition to explicitly check for None (e.g., "if dpo_loss_type is not None
and dpo_loss_weights is not None and len(dpo_loss_type) !=
len(dpo_loss_weights):") so empty lists are validated, and fix the error text to
reference the correct field name `dpo_loss_weights` instead of
`dpo_dpo_loss_weights`, keeping references to rl, dpo_loss_type,
dpo_loss_weights and the existing ValueError raise.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 30afffad-dc94-4c61-8169-f54007ab6879
📒 Files selected for processing (8)
AGENTS.mddocs/agents/preference_tuning.mddocs/rlhf.qmdsrc/axolotl/core/trainers/dpo/__init__.pysrc/axolotl/utils/schemas/config.pysrc/axolotl/utils/schemas/validation.pytests/core/test_builders.pytests/e2e/test_dpo.py
src/axolotl/utils/schemas/config.py
Outdated
| dpo_loss_type: list[str] | None = Field( | ||
| default=None, | ||
| json_schema_extra={"description": "List of DPO losses to use."}, | ||
| ) | ||
|
|
||
| dpo_loss_weights: list[float] | None = Field( | ||
| default=None, | ||
| json_schema_extra={"description": "Weights for each DPO loss."}, | ||
| ) |
There was a problem hiding this comment.
Prevent empty DPO loss lists at schema level.
Line 312 and Line 317 currently accept empty lists, which permits invalid no-op/misaligned configs to pass initial validation.
Suggested schema hardening
- dpo_loss_type: list[str] | None = Field(
+ dpo_loss_type: Annotated[list[str], MinLen(1)] | None = Field(
default=None,
json_schema_extra={"description": "List of DPO losses to use."},
)
- dpo_loss_weights: list[float] | None = Field(
+ dpo_loss_weights: Annotated[list[float], MinLen(1)] | None = Field(
default=None,
json_schema_extra={"description": "Weights for each DPO loss."},
)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| dpo_loss_type: list[str] | None = Field( | |
| default=None, | |
| json_schema_extra={"description": "List of DPO losses to use."}, | |
| ) | |
| dpo_loss_weights: list[float] | None = Field( | |
| default=None, | |
| json_schema_extra={"description": "Weights for each DPO loss."}, | |
| ) | |
| dpo_loss_type: Annotated[list[str], MinLen(1)] | None = Field( | |
| default=None, | |
| json_schema_extra={"description": "List of DPO losses to use."}, | |
| ) | |
| dpo_loss_weights: Annotated[list[float], MinLen(1)] | None = Field( | |
| default=None, | |
| json_schema_extra={"description": "Weights for each DPO loss."}, | |
| ) |
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/utils/schemas/config.py` around lines 312 - 320, The schema
currently allows empty lists for dpo_loss_type and dpo_loss_weights which lets
invalid configs pass; update both Field declarations (dpo_loss_type and
dpo_loss_weights) to require at least one element (e.g., add min_items=1 to
Field) so a non-empty list is enforced when provided, and add a schema-level
validator (e.g., a root_validator or a method named
validate_dpo_loss_consistency) to ensure that when both lists are present their
lengths match and neither is empty.
| if rl == "dpo": | ||
| if ( | ||
| dpo_loss_type | ||
| and dpo_loss_weights | ||
| and len(dpo_loss_type) != len(dpo_loss_weights) | ||
| ): | ||
| raise ValueError( | ||
| f"`dpo_loss_type` and `dpo_dpo_loss_weights` must be the same length, " | ||
| f"but got {len(dpo_loss_type)} losses and {len(dpo_loss_weights)} weights" | ||
| ) |
There was a problem hiding this comment.
Fix falsy-list validation gap and typo in the mismatch error.
Line 790 uses truthiness, so empty-list inputs can bypass the length check. Also, Line 795 has a typo in the field name (dpo_dpo_loss_weights).
Suggested fix
- if rl == "dpo":
- if (
- dpo_loss_type
- and dpo_loss_weights
- and len(dpo_loss_type) != len(dpo_loss_weights)
- ):
+ if rl == "dpo":
+ if (
+ dpo_loss_type is not None
+ and dpo_loss_weights is not None
+ and len(dpo_loss_type) != len(dpo_loss_weights)
+ ):
raise ValueError(
- f"`dpo_loss_type` and `dpo_dpo_loss_weights` must be the same length, "
+ f"`dpo_loss_type` and `dpo_loss_weights` must be the same length, "
f"but got {len(dpo_loss_type)} losses and {len(dpo_loss_weights)} weights"
)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/utils/schemas/validation.py` around lines 788 - 797, The
validation for rl=="dpo" incorrectly uses truthiness so empty lists bypass the
length check and the error message has a typo; update the condition to
explicitly check for None (e.g., "if dpo_loss_type is not None and
dpo_loss_weights is not None and len(dpo_loss_type) != len(dpo_loss_weights):")
so empty lists are validated, and fix the error text to reference the correct
field name `dpo_loss_weights` instead of `dpo_dpo_loss_weights`, keeping
references to rl, dpo_loss_type, dpo_loss_weights and the existing ValueError
raise.
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
Description
Support multiple loss types and weights as part of TRL >= v0.29.0 updates. #3565
Key changes:
dpo_loss_typeanddpo_loss_weightswhich are passed to Axolotl's DPO Trainer viaDPOStrategy. Note thatdpo_loss_typeexpects a list of strings.RLType.IPO/rl: ipoin favour of passingrl: dpo, dpo_loss_type: ["ipo"]. Warns users of upcoming deprecation.Motivation and Context
This PR exposes the full list of loss functions available in TRL and the ability to combine multiple dpo losses with custom weightings.
This will also restore losses previously available in axolotl such as RPO (and hopefully length-normalised DPO once TRL supports this), but are now broken on TRL >= 0.29 due to refactors/parameter deprecation [see #3548 #3560]. This PR aims to update Axolotl to work with the newer TRL version and restore this functionality.
This PR also begins cleaning up
RLType.IPO, which is no longer needed as a separate RL type.How has this been tested?
Unit tests.
AI Usage Disclaimer
Claude for light review, updates/fixes written by me.
Types of changes
Config additions, new losses for DPO, partial deprecation of
RLType.IPO.Summary by CodeRabbit
New Features
dpo_loss_typeanddpo_loss_weightsparameters for flexible training customization.Deprecation
rl: ipoconfiguration is now deprecated; userl: dpowithdpo_loss_type: ["ipo"]instead.Documentation