Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ axolotl config-schema # Dump config JSON schema
| Method | Config Key | When to Use |
|--------|-----------|-------------|
| SFT | *(default)* | Input-output pairs, instruction tuning |
| DPO/IPO | `rl: dpo` / `rl: ipo` | Paired preference data (chosen vs rejected) |
| DPO/IPO | `rl: dpo` / `rl: dpo, dpo_loss_type: ["ipo"]` | Paired preference data (chosen vs rejected) |
| KTO | `rl: kto` | Unpaired binary preference labels |
| ORPO | `rl: orpo` | Single-stage alignment, no ref model |
| GRPO | `rl: grpo` | RL with verifiable reward functions (math, code) |
Expand Down
2 changes: 1 addition & 1 deletion docs/agents/preference_tuning.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ No vLLM server needed (unlike GRPO). Offline RL with pre-collected preference da

1. Paired preference data (chosen + rejected)?
- Default → `rl: dpo`
- Overfitting → `rl: ipo`
- Overfitting → `rl: dpo, dpo_loss_type: ["ipo"]`
- VRAM-limited → `rl: orpo` (no ref model)
- Length-sensitive → `rl: simpo` (no ref model)
2. Only binary labels (good/bad)? → `rl: kto`
Expand Down
4 changes: 3 additions & 1 deletion docs/rlhf.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,10 @@ The input format is a simple JSON input with customizable fields based on the ab
As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO.

```yaml
rl: ipo
rl: dpo
dpo_loss_type: ["ipo"]
```
*Note:* Passing `rl: ipo` directly is still supported, but will soon be deprecated.

### ORPO

Expand Down
8 changes: 8 additions & 0 deletions src/axolotl/core/trainers/dpo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,16 @@ def get_training_args_class(cls):
@classmethod
def set_training_args_kwargs(cls, cfg):
training_args_kwargs = {}
if cfg.rl is RLType.DPO:
if cfg.dpo_loss_type is not None:
training_args_kwargs["loss_type"] = cfg.dpo_loss_type

if cfg.dpo_loss_weights is not None:
training_args_kwargs["loss_weights"] = cfg.dpo_loss_weights

if cfg.rl is RLType.IPO:
training_args_kwargs["loss_type"] = ["ipo"]

# Label smoothing is not compatible with IPO
if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
Expand Down
10 changes: 10 additions & 0 deletions src/axolotl/utils/schemas/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,16 @@ class AxolotlInputConfig(

dpo_padding_free: bool | None = None

dpo_loss_type: Annotated[list[str], MinLen(1)] | None = Field(
default=None,
json_schema_extra={"description": "List of DPO losses to use."},
)

dpo_loss_weights: Annotated[list[float], MinLen(1)] | None = Field(
default=None,
json_schema_extra={"description": "Weights for each DPO loss."},
)

datasets: (
Annotated[
list[
Expand Down
37 changes: 37 additions & 0 deletions src/axolotl/utils/schemas/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import sys
import tempfile
import warnings
from pathlib import Path

from pydantic import (
Expand Down Expand Up @@ -770,6 +771,42 @@ def check_gdpo(cls, data):
)
return data

@model_validator(mode="before")
@classmethod
def check_dpo(cls, data):
dpo_loss_type = data.get("dpo_loss_type")
dpo_loss_weights = data.get("dpo_loss_weights")
rl = data.get("rl")

if rl == "ipo":
warnings.warn(
"rl: ipo will soon be deprecated. Use `rl: dpo` with `dpo_loss_type: ['ipo']` instead.",
DeprecationWarning,
stacklevel=2,
)

if rl == "dpo":
if dpo_loss_weights is not None and dpo_loss_type is None:
raise ValueError(
"`dpo_loss_weights` requires `dpo_loss_type` to be set"
)
if (
dpo_loss_type is not None
and dpo_loss_weights is not None
and len(dpo_loss_type) != len(dpo_loss_weights)
):
raise ValueError(
f"`dpo_loss_type` and `dpo_loss_weights` must be the same length, "
f"but got {len(dpo_loss_type)} losses and {len(dpo_loss_weights)} weights"
)
Comment on lines +788 to +801
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Fix falsy-list validation gap and typo in the mismatch error.

Line 790 uses truthiness, so empty-list inputs can bypass the length check. Also, Line 795 has a typo in the field name (dpo_dpo_loss_weights).

Suggested fix
-        if rl == "dpo":
-            if (
-                dpo_loss_type
-                and dpo_loss_weights
-                and len(dpo_loss_type) != len(dpo_loss_weights)
-            ):
+        if rl == "dpo":
+            if (
+                dpo_loss_type is not None
+                and dpo_loss_weights is not None
+                and len(dpo_loss_type) != len(dpo_loss_weights)
+            ):
                 raise ValueError(
-                    f"`dpo_loss_type` and `dpo_dpo_loss_weights` must be the same length, "
+                    f"`dpo_loss_type` and `dpo_loss_weights` must be the same length, "
                     f"but got {len(dpo_loss_type)} losses and {len(dpo_loss_weights)} weights"
                 )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/utils/schemas/validation.py` around lines 788 - 797, The
validation for rl=="dpo" incorrectly uses truthiness so empty lists bypass the
length check and the error message has a typo; update the condition to
explicitly check for None (e.g., "if dpo_loss_type is not None and
dpo_loss_weights is not None and len(dpo_loss_type) != len(dpo_loss_weights):")
so empty lists are validated, and fix the error text to reference the correct
field name `dpo_loss_weights` instead of `dpo_dpo_loss_weights`, keeping
references to rl, dpo_loss_type, dpo_loss_weights and the existing ValueError
raise.

elif dpo_loss_type is not None or dpo_loss_weights is not None:
raise ValueError(
f"`dpo_loss_type` and `dpo_loss_weights` are for DPO only, "
f"but got {rl=}, {dpo_loss_type=} and {dpo_loss_weights=}"
)

return data


class OptimizationValidationMixin:
"""Validation methods related to optimization and performance."""
Expand Down
7 changes: 6 additions & 1 deletion tests/core/test_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ def fixture_dpo_cfg(base_cfg):
"dpo_use_weighting": True,
"dpo_label_smoothing": 0.1,
"beta": 0.1, # DPO beta
"dpo_loss_type": ["sigmoid", "sft"],
"dpo_loss_weights": [1.0, 0.5],
}
)
return cfg
Expand Down Expand Up @@ -164,7 +166,8 @@ def fixture_ipo_cfg(base_cfg):
cfg = base_cfg.copy()
cfg.update(
{
"rl": RLType.IPO,
"rl": RLType.DPO,
"dpo_loss_type": ["ipo"],
"dpo_label_smoothing": 0,
"beta": 0.1,
}
Expand Down Expand Up @@ -300,6 +303,8 @@ def test_dpo_training_arguments(self, dpo_cfg, model, tokenizer):
assert training_arguments.use_weighting is True
assert training_arguments.label_smoothing == 0.1
assert training_arguments.precompute_ref_log_probs is True
assert training_arguments.loss_type == ["sigmoid", "sft"]
assert training_arguments.loss_weights == [1.0, 0.5]

def test_orpo_training_arguments(self, orpo_cfg, model, tokenizer):
builder = HFRLTrainerBuilder(orpo_cfg, model, tokenizer)
Expand Down
55 changes: 54 additions & 1 deletion tests/e2e/test_dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,58 @@ def test_dpo_use_weighting(self, temp_dir):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)

@with_temp_dir
def test_rpo(self, temp_dir):
# For TRL >= 0.29, loss_type=["sigmoid", "sft"], loss_weights=[1, alpha]
# replaces loss_type="rpo", rpo_alpha=alpha.
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 64,
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_linear": True,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"rl": "dpo",
"dpo_loss_type": ["sigmoid", "sft"],
"dpo_loss_weights": [1.0, 1.0],
"datasets": [
{
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
"type": "chatml.ultra",
"split": "train",
},
],
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "paged_adamw_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
"save_first_step": False,
}
)

cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)

@pytest.mark.skip("kto_pair no longer supported in trl")
@with_temp_dir
def test_kto_pair_lora(self, temp_dir):
Expand Down Expand Up @@ -181,7 +233,8 @@ def test_ipo_lora(self, temp_dir):
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"rl": "ipo",
"rl": "dpo",
"dpo_loss_type": ["ipo"],
"datasets": [
{
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
Expand Down
Loading