DPO support loss types#3566

Open
BrownianNotion wants to merge 5 commits intoaxolotl-ai-cloud:mainfrom
BrownianNotion:dpo-support-loss-types

@BrownianNotion
Contributor

@BrownianNotion BrownianNotion commented Mar 31, 2026

Description

Support multiple loss types and weights as part of TRL >= v0.29.0 updates. #3565

Key changes:

  1. Adds two new config parameters dpo_loss_type and dpo_loss_weights which are passed to Axolotl's DPO Trainer via DPOStrategy. Note that dpo_loss_type expects a list of strings.
  2. Partially deprecates RLType.IPO/rl: ipo in favour of passing rl: dpo, dpo_loss_type: ["ipo"]. Warns users of upcoming deprecation.
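
As a rough illustration of key change 1, the sketch below shows how the two config fields might be forwarded as the `loss_type` and `loss_weights` training arguments. The helper name `dpo_training_kwargs` and the plain-dict config are assumptions made for illustration; this is not the PR's actual `DPOStrategy` code.

```python
from typing import Any


def dpo_training_kwargs(cfg: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical helper: map Axolotl-style config keys to trainer kwargs."""
    kwargs: dict[str, Any] = {}
    # Only forward the DPO-specific fields when DPO is the selected RL method.
    if cfg.get("rl") == "dpo":
        if cfg.get("dpo_loss_type") is not None:
            kwargs["loss_type"] = cfg["dpo_loss_type"]
        if cfg.get("dpo_loss_weights") is not None:
            kwargs["loss_weights"] = cfg["dpo_loss_weights"]
    return kwargs


# IPO expressed through the new fields instead of `rl: ipo`.
ipo_cfg = {"rl": "dpo", "dpo_loss_type": ["ipo"]}

# Two losses combined with custom weights (one weight per loss).
mixed_cfg = {
    "rl": "dpo",
    "dpo_loss_type": ["sigmoid", "ipo"],
    "dpo_loss_weights": [1.0, 0.5],
}
```

Leaving both fields unset yields no extra kwargs, so the trainer's own defaults would apply in that case.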

Motivation and Context

This PR exposes the full list of loss functions available in TRL and the ability to combine multiple dpo losses with custom weightings.
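
To make "combining multiple DPO losses with custom weightings" concrete, here is a minimal scalar sketch of a weighted sum of two losses. It assumes `logits` is the policy-versus-reference log-ratio margin for one preference pair and `beta` is the usual DPO temperature; the function names and the default weight of 1.0 are illustrative assumptions, and TRL's tensor implementation may differ in details such as length normalisation.

```python
import math


def sigmoid_loss(logits: float, beta: float) -> float:
    # Standard DPO objective: -log(sigmoid(beta * logits)).
    return math.log1p(math.exp(-beta * logits))


def ipo_loss(logits: float, beta: float) -> float:
    # IPO objective: squared distance of the margin from 1 / (2 * beta).
    return (logits - 1.0 / (2.0 * beta)) ** 2


# Illustrative registry; only the two losses named in this thread are included.
LOSSES = {"sigmoid": sigmoid_loss, "ipo": ipo_loss}


def combined_loss(logits, beta, loss_types, loss_weights=None):
    # Assumption: every loss gets weight 1.0 when no weights are given.
    weights = loss_weights if loss_weights is not None else [1.0] * len(loss_types)
    return sum(w * LOSSES[name](logits, beta) for name, w in zip(loss_types, weights))
```

With `loss_types=["sigmoid", "ipo"]` and `loss_weights=[1.0, 0.5]`, the result is the sigmoid loss plus half the IPO loss for the same margin.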

This will also restore losses that were previously available in Axolotl, such as RPO (and hopefully length-normalised DPO once TRL supports it), but are now broken on TRL >= 0.29 due to refactors and parameter deprecations [see #3548 #3560]. This PR aims to update Axolotl to work with the newer TRL version and restore this functionality.

This PR also begins cleaning up RLType.IPO, which is no longer needed as a separate RL type.

How has this been tested?

Unit tests.

AI Usage Disclaimer

Claude for light review, updates/fixes written by me.

Types of changes

Config additions, new losses for DPO, partial deprecation of RLType.IPO.

Summary by CodeRabbit

  • New Features

    • Added configurable DPO loss types and weights via dpo_loss_type and dpo_loss_weights parameters for flexible training customization.
  • Deprecation

    • The rl: ipo configuration is now deprecated; use rl: dpo with dpo_loss_type: ["ipo"] instead.
  • Documentation

    • Updated training method guidance to reflect the new DPO configuration approach and IPO deprecation notice.

@coderabbitai
Contributor

coderabbitai bot commented Mar 31, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

Walkthrough

This PR consolidates DPO and IPO training modes by introducing dpo_loss_type and dpo_loss_weights configuration fields that enable DPO mode to handle both standard DPO and IPO loss functions via a single rl: dpo setting, while deprecating the separate rl: ipo option.

Changes

| Cohort / File(s) | Summary |
|---|---|
| Documentation Updates: AGENTS.md, docs/agents/preference_tuning.md, docs/rlhf.qmd | Updated method selection guidance and IPO configuration documentation to use rl: dpo, dpo_loss_type: ["ipo"] instead of rl: ipo. Includes deprecation notice for the old rl: ipo setting. |
| Configuration Schema: src/axolotl/utils/schemas/config.py | Added two new optional fields to AxolotlInputConfig: `dpo_loss_type: list[str]` … |
| Validation Logic: src/axolotl/utils/schemas/validation.py | Added check_dpo validator to RLValidationMixin that emits a deprecation warning for rl: ipo, enforces matching lengths between dpo_loss_type and dpo_loss_weights when both provided, and restricts these fields to DPO mode only. |
| Trainer Implementation: src/axolotl/core/trainers/dpo/__init__.py | Updated DPOStrategy.set_training_args_kwargs to conditionally inject loss_type and loss_weights from the new config fields when rl is RLType.DPO. |
| Tests: tests/core/test_builders.py, tests/e2e/test_dpo.py | Updated DPO test fixtures to include dpo_loss_type and dpo_loss_weights. Modified test_ipo_lora to use rl: dpo with dpo_loss_type: ["ipo"]. Added new test_rpo E2E test for composite loss configurations. |

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Suggested labels

documentation, ready to merge

Suggested reviewers

  • winglian
  • djsaunde
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title 'DPO support loss types' directly and clearly describes the main objective of the PR: adding support for multiple DPO loss types and weighted combinations. |
| Description Check | ✅ Passed | Check skipped - CodeRabbit's high-level summary is enabled. |

@BrownianNotion
Contributor Author

@NanoCode012 when you have some time, I would love to get your thoughts on the current state of the changes :) In particular, are we happy to support all the loss_types and loss_weights that TRL does? These should be reasonably well tested upstream, so hopefully there isn't too much risk in supporting them.

@BrownianNotion BrownianNotion force-pushed the dpo-support-loss-types branch from f484ccf to 729a2e6 Compare April 1, 2026 18:57
@NanoCode012
Collaborator

NanoCode012 commented Apr 2, 2026

@BrownianNotion, what other types do you have in mind? Or do you mean a passthrough like you've done?

@BrownianNotion
Contributor Author

> @BrownianNotion, what other types do you have in mind? Or do you mean a passthrough like you've done?

I mean passing through the loss_type from config to Axolotl's DPOTrainer. Based on my understanding, this wasn't previously exposed as a config parameter, or at least, it wasn't plumbed through (correct me if I'm wrong). So all loss types except sigmoid (default) and ipo (RLType.IPO) weren't available. https://huggingface.co/docs/trl/dpo_trainer#loss-types

Collaborator

@NanoCode012 NanoCode012 left a comment

Sorry this took a while; just a bit of cleanup would do. The config naming is OK with me since it follows the existing convention.

@model_validator(mode="before")
@classmethod
def check_dpo(cls, data):
    if data.get("rl") == "dpo":
Collaborator

Can we add validation here? If dpo_loss_type or dpo_loss_weights is set without rl: dpo, we should also raise a ValueError, given that both are DPO-specific.
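
A minimal sketch of the requested check, written as a standalone function over the raw config dict rather than as the PR's actual validator. The name `check_dpo_only_fields` and the error wording are assumptions for illustration.

```python
def check_dpo_only_fields(data: dict) -> dict:
    """Hypothetical check: reject DPO-only options when `rl` is not `dpo`."""
    dpo_only = ("dpo_loss_type", "dpo_loss_weights")
    # Collect the DPO-specific fields the user actually provided.
    set_fields = [name for name in dpo_only if data.get(name) is not None]
    if set_fields and data.get("rl") != "dpo":
        raise ValueError(
            f"{', '.join(set_fields)} can only be set when `rl: dpo`, "
            f"but got rl={data.get('rl')!r}"
        )
    return data
```

Returning `data` unchanged on success mirrors how a "before" validator is expected to hand the input back for further processing.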

loss_types = data.get("dpo_loss_type")
loss_weights = data.get("dpo_loss_weights")

if loss_types and loss_weights and len(loss_types) != len(loss_weights):
Collaborator

Re: loss type validation mentioned by the other user on the Issue, does trl not validate that config upstream and error appropriately?

I'm a bit hesitant on hardcoding any constant/lists for loss types.

Contributor Author

Your call - I figured we would want to error early rather than have the user wait for the runtime trl error. But I understand that this would have added maintenance costs given the hard coding (there isn't an enum we can just import from trl if I remember correctly)
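
The trade-off under discussion can be sketched as follows. The allowlist below is an illustrative subset only, not TRL's full list of loss types, and `unknown_loss_types` is a hypothetical helper name.

```python
# Illustrative subset only; TRL supports more loss types than are listed here.
KNOWN_LOSS_TYPES = frozenset({"sigmoid", "hinge", "ipo"})


def unknown_loss_types(loss_types: list[str]) -> list[str]:
    """Return the names that are missing from the hardcoded allowlist."""
    return [name for name in loss_types if name not in KNOWN_LOSS_TYPES]


# A typo is caught at config-validation time rather than at trainer start-up.
typos = unknown_loss_types(["ipo", "sigmiod"])
```

The early error comes at the cost noted above: the hardcoded set has to be kept in sync with every upstream addition or rename, whereas deferring to TRL's own validation needs no maintenance.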

@NanoCode012
Collaborator

Remove IPO RLType and use DPO loss_type instead.

We can recommend the new method in the IPO docs, while keeping backwards compatibility by re-routing rl: ipo to the equivalent DPO loss type. I think this is a good transition.
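
One way the backwards-compatible re-routing could look, as a sketch over a plain config dict. The function name `reroute_ipo` and the warning text are assumptions, not code from this PR.

```python
import warnings


def reroute_ipo(data: dict) -> dict:
    """Hypothetical shim: translate legacy `rl: ipo` into the DPO equivalent."""
    if data.get("rl") != "ipo":
        return data
    warnings.warn(
        "`rl: ipo` is deprecated; use `rl: dpo` with `dpo_loss_type: ['ipo']`.",
        DeprecationWarning,
        stacklevel=2,
    )
    # Return a copy so the caller's original config is left unmodified.
    return {**data, "rl": "dpo", "dpo_loss_type": ["ipo"]}
```

Existing configs keep working and users see the migration hint, while everything downstream only has to handle `rl: dpo`.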

@BrownianNotion BrownianNotion force-pushed the dpo-support-loss-types branch from 729a2e6 to ebdbb2e Compare April 9, 2026 10:37
@BrownianNotion BrownianNotion marked this pull request as ready for review April 9, 2026 16:47
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/axolotl/utils/schemas/config.py`:
- Around line 312-320: The schema currently allows empty lists for dpo_loss_type
and dpo_loss_weights which lets invalid configs pass; update both Field
declarations (dpo_loss_type and dpo_loss_weights) to require at least one
element (e.g., add min_length=1 to Field) so a non-empty list is enforced when
provided, and add a schema-level validator (e.g., a model_validator or a method
named validate_dpo_loss_consistency) to ensure that when both lists are present
their lengths match and neither is empty.

In `@src/axolotl/utils/schemas/validation.py`:
- Around line 788-797: The validation for rl=="dpo" incorrectly uses truthiness
so empty lists bypass the length check and the error message has a typo; update
the condition to explicitly check for None (e.g., "if dpo_loss_type is not None
and dpo_loss_weights is not None and len(dpo_loss_type) !=
len(dpo_loss_weights):") so empty lists are validated, and fix the error text to
reference the correct field name `dpo_loss_weights` instead of
`dpo_dpo_loss_weights`, keeping references to rl, dpo_loss_type,
dpo_loss_weights and the existing ValueError raise.
📥 Commits

Reviewing files that changed from the base of the PR and between 7daf7d9 and 07a99ce.

📒 Files selected for processing (8)
  • AGENTS.md
  • docs/agents/preference_tuning.md
  • docs/rlhf.qmd
  • src/axolotl/core/trainers/dpo/__init__.py
  • src/axolotl/utils/schemas/config.py
  • src/axolotl/utils/schemas/validation.py
  • tests/core/test_builders.py
  • tests/e2e/test_dpo.py

Comment on lines +312 to +320
dpo_loss_type: list[str] | None = Field(
    default=None,
    json_schema_extra={"description": "List of DPO losses to use."},
)

dpo_loss_weights: list[float] | None = Field(
    default=None,
    json_schema_extra={"description": "Weights for each DPO loss."},
)

⚠️ Potential issue | 🟡 Minor

Prevent empty DPO loss lists at schema level.

Line 312 and Line 317 currently accept empty lists, which permits invalid no-op/misaligned configs to pass initial validation.

Suggested schema hardening
-    dpo_loss_type: list[str] | None = Field(
+    dpo_loss_type: Annotated[list[str], MinLen(1)] | None = Field(
         default=None,
         json_schema_extra={"description": "List of DPO losses to use."},
     )

-    dpo_loss_weights: list[float] | None = Field(
+    dpo_loss_weights: Annotated[list[float], MinLen(1)] | None = Field(
         default=None,
         json_schema_extra={"description": "Weights for each DPO loss."},
     )

Comment on lines +788 to +797
if rl == "dpo":
    if (
        dpo_loss_type
        and dpo_loss_weights
        and len(dpo_loss_type) != len(dpo_loss_weights)
    ):
        raise ValueError(
            f"`dpo_loss_type` and `dpo_dpo_loss_weights` must be the same length, "
            f"but got {len(dpo_loss_type)} losses and {len(dpo_loss_weights)} weights"
        )

⚠️ Potential issue | 🟡 Minor

Fix falsy-list validation gap and typo in the mismatch error.

Line 790 uses truthiness, so empty-list inputs can bypass the length check. Also, Line 795 has a typo in the field name (dpo_dpo_loss_weights).

Suggested fix
-        if rl == "dpo":
-            if (
-                dpo_loss_type
-                and dpo_loss_weights
-                and len(dpo_loss_type) != len(dpo_loss_weights)
-            ):
+        if rl == "dpo":
+            if (
+                dpo_loss_type is not None
+                and dpo_loss_weights is not None
+                and len(dpo_loss_type) != len(dpo_loss_weights)
+            ):
                 raise ValueError(
-                    f"`dpo_loss_type` and `dpo_dpo_loss_weights` must be the same length, "
+                    f"`dpo_loss_type` and `dpo_loss_weights` must be the same length, "
                     f"but got {len(dpo_loss_type)} losses and {len(dpo_loss_weights)} weights"
                 )
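
The gap can be reproduced outside the validator. The sketch below isolates the two conditions as standalone predicates; the function names are hypothetical and exist only to compare the original and suggested checks side by side.

```python
def mismatch_truthy(loss_types, loss_weights) -> bool:
    # Original condition: empty lists are falsy, so the length check is skipped.
    return bool(loss_types and loss_weights and len(loss_types) != len(loss_weights))


def mismatch_explicit(loss_types, loss_weights) -> bool:
    # Suggested condition: only None (field not provided) skips the check.
    return (
        loss_types is not None
        and loss_weights is not None
        and len(loss_types) != len(loss_weights)
    )


# An empty loss list paired with one weight slips past the truthiness check.
slips_through = not mismatch_truthy([], [1.0])
# The explicit None check flags the same input as a length mismatch.
now_caught = mismatch_explicit([], [1.0])
```

Two empty lists still have equal lengths, so the explicit check alone does not reject them; that case is what the schema-level minimum-length constraint is meant to cover.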

@codecov

codecov bot commented Apr 9, 2026

Codecov Report

❌ Patch coverage is 87.50000% with 3 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/axolotl/utils/schemas/validation.py | 82.35% | 3 Missing ⚠️ |
