Skip to content

Commit f484ccf

Browse files
Support loss_type/loss_weights DPO
1 parent b477929 commit f484ccf

File tree

5 files changed

+87
-0
lines changed

5 files changed

+87
-0
lines changed

src/axolotl/core/trainers/dpo/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@ def set_training_args_kwargs(cls, cfg):
2222
training_args_kwargs = {}
2323
if cfg.rl is RLType.IPO:
2424
training_args_kwargs["loss_type"] = ["ipo"]
25+
else:
26+
if cfg.dpo_loss_type is not None:
27+
training_args_kwargs["loss_type"] = cfg.dpo_loss_type
28+
29+
if cfg.dpo_loss_weights is not None:
30+
training_args_kwargs["loss_weights"] = cfg.dpo_loss_weights
31+
2532
# Label smoothing is not compatible with IPO
2633
if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:
2734
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing

src/axolotl/utils/schemas/config.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,16 @@ class AxolotlInputConfig(
302302

303303
dpo_padding_free: bool | None = None
304304

305+
dpo_loss_type: list[str] | None = Field(
306+
default=None,
307+
json_schema_extra={"description": "List of DPO losses to use."},
308+
)
309+
310+
dpo_loss_weights: list[float] | None = Field(
311+
default=None,
312+
json_schema_extra={"description": "Weights for each DPO loss."},
313+
)
314+
305315
datasets: (
306316
Annotated[
307317
list[

src/axolotl/utils/schemas/validation.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -770,6 +770,20 @@ def check_gdpo(cls, data):
770770
)
771771
return data
772772

773+
@model_validator(mode="before")
774+
@classmethod
775+
def check_dpo(cls, data):
776+
if data.get("rl") == "dpo":
777+
loss_types = data.get("dpo_loss_type")
778+
loss_weights = data.get("dpo_loss_weights")
779+
780+
if loss_types and loss_weights and len(loss_types) != len(loss_weights):
781+
raise ValueError(
782+
f"`dpo_loss_type` and `dpo_loss_weights` must be the same length, "
783+
f"but got {len(loss_types)} losses and {len(loss_weights)} weights"
784+
)
785+
return data
786+
773787

774788
class OptimizationValidationMixin:
775789
"""Validation methods related to optimization and performance."""

tests/core/test_builders.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ def fixture_dpo_cfg(base_cfg):
9696
"dpo_use_weighting": True,
9797
"dpo_label_smoothing": 0.1,
9898
"beta": 0.1, # DPO beta
99+
"dpo_loss_type": ["sigmoid", "sft"],
100+
"dpo_loss_weights": [1.0, 0.5],
99101
}
100102
)
101103
return cfg
@@ -298,6 +300,8 @@ def test_dpo_training_arguments(self, dpo_cfg, model, tokenizer):
298300
assert hasattr(training_arguments, "use_weighting")
299301
assert training_arguments.use_weighting is True
300302
assert training_arguments.label_smoothing == 0.1
303+
assert training_arguments.loss_type == ["sigmoid", "sft"]
304+
assert training_arguments.loss_weights == [1.0, 0.5]
301305

302306
def test_orpo_training_arguments(self, orpo_cfg, model, tokenizer):
303307
builder = HFRLTrainerBuilder(orpo_cfg, model, tokenizer)

tests/e2e/test_dpo.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,58 @@ def test_dpo_use_weighting(self, temp_dir):
116116
train(cfg=cfg, dataset_meta=dataset_meta)
117117
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
118118

119+
@with_temp_dir
120+
def test_rpo(self, temp_dir):
121+
# For TRL >= 0.29, loss_type=["sigmoid", "sft"], loss_weights=[1, alpha]
122+
# replaces loss_type="rpo", rpo_alpha=alpha.
123+
cfg = DictDefault(
124+
{
125+
"base_model": "HuggingFaceTB/SmolLM2-135M",
126+
"tokenizer_type": "AutoTokenizer",
127+
"sequence_len": 1024,
128+
"load_in_8bit": True,
129+
"adapter": "lora",
130+
"lora_r": 64,
131+
"lora_alpha": 32,
132+
"lora_dropout": 0.1,
133+
"lora_target_linear": True,
134+
"special_tokens": {
135+
"pad_token": "<|endoftext|>",
136+
},
137+
"rl": "dpo",
138+
"dpo_loss_type": ["sigmoid", "sft"],
139+
"dpo_loss_weights": [1.0, 1.0],
140+
"datasets": [
141+
{
142+
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
143+
"type": "chatml.ultra",
144+
"split": "train",
145+
},
146+
],
147+
"num_epochs": 1,
148+
"micro_batch_size": 4,
149+
"gradient_accumulation_steps": 1,
150+
"output_dir": temp_dir,
151+
"learning_rate": 0.00001,
152+
"optimizer": "paged_adamw_8bit",
153+
"lr_scheduler": "cosine",
154+
"max_steps": 20,
155+
"save_steps": 10,
156+
"warmup_steps": 5,
157+
"gradient_checkpointing": True,
158+
"gradient_checkpointing_kwargs": {"use_reentrant": True},
159+
"save_first_step": False,
160+
}
161+
)
162+
163+
cfg = validate_config(cfg)
164+
normalize_config(cfg)
165+
cli_args = TrainerCliArgs()
166+
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
167+
168+
train(cfg=cfg, dataset_meta=dataset_meta)
169+
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
170+
119171
@pytest.mark.skip("kto_pair no longer supported in trl")
120172
@with_temp_dir
121173
def test_kto_pair_lora(self, temp_dir):

0 commit comments

Comments (0)