Skip to content
3 changes: 0 additions & 3 deletions src/axolotl/core/builders/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,6 @@ def _build_training_arguments(self, total_num_steps):
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha

if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha

if self.cfg.use_wandb:
training_args_kwargs["run_name"] = self.cfg.wandb_name

Expand Down
10 changes: 4 additions & 6 deletions src/axolotl/core/trainers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,15 +405,13 @@ def evaluate(self, *args, **kwargs):
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
concatenated_batch = {}

max_length = max(
inputs["input_ids"].shape[1], inputs["rejected_input_ids"].shape[1]
)
max_length = max(inputs["input_ids"].shape[1], inputs["rejected_ids"].shape[1])
# Concatenate positive and negative inputs
concatenated_batch["input_ids"] = pad_to_length(
inputs["input_ids"], max_length, pad_token
)
concatenated_batch["rejected_input_ids"] = pad_to_length(
inputs["rejected_input_ids"], max_length, pad_token
concatenated_batch["rejected_ids"] = pad_to_length(
inputs["rejected_ids"], max_length, pad_token
)
concatenated_batch["labels"] = pad_to_length(
inputs["labels"], max_length, label_pad_token
Expand All @@ -432,7 +430,7 @@ def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=No
).to(device=device)

input_ids = torch.cat(
[concatenated_batch["input_ids"], concatenated_batch["rejected_input_ids"]],
[concatenated_batch["input_ids"], concatenated_batch["rejected_ids"]],
dim=0,
).to(device=device)
attention_mask = torch.cat(
Expand Down
4 changes: 1 addition & 3 deletions src/axolotl/core/trainers/dpo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def get_training_args_class(cls):
def set_training_args_kwargs(cls, cfg):
training_args_kwargs = {}
if cfg.rl is RLType.IPO:
training_args_kwargs["loss_type"] = "ipo"
training_args_kwargs["loss_type"] = ["ipo"]
# Label smoothing is not compatible with IPO
if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
Expand All @@ -30,8 +30,6 @@ def set_training_args_kwargs(cls, cfg):
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
if cfg.dpo_padding_free is not None:
training_args_kwargs["padding_free"] = cfg.dpo_padding_free
if cfg.dpo_norm_loss is not None:
training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
if cfg.dpo_use_liger_kernel is not None:
training_args_kwargs["use_liger_kernel"] = cfg.dpo_use_liger_kernel
return training_args_kwargs
6 changes: 1 addition & 5 deletions src/axolotl/core/trainers/dpo/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
Axolotl specific DPO args
"""

from dataclasses import dataclass, field
from typing import Optional
from dataclasses import dataclass

from trl import DPOConfig

Expand All @@ -15,6 +14,3 @@ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
"""
DPO config for DPO training
"""

dpo_norm_loss: bool | None = False
rpo_alpha: Optional[float] = field(default=None)
68 changes: 24 additions & 44 deletions src/axolotl/core/trainers/dpo/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import torch
from torch import nn
from transformers import PreTrainedTokenizerBase, ProcessorMixin
from trl import DPOTrainer

from axolotl.core.trainers.mixins import (
Expand All @@ -18,6 +19,7 @@
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
)
from axolotl.utils.data.utils import remove_double_bos_token


class AxolotlDPOTrainer(
Expand Down Expand Up @@ -53,36 +55,31 @@ def push_to_hub(self, *args, **kwargs) -> str:

return super().push_to_hub(*args, **kwargs)

@staticmethod
def tokenize_row(
features,
processing_class,
max_prompt_length: int | None = None,
max_completion_length: int | None = None,
add_special_tokens: bool = True,
is_chat: bool = False,
) -> Dict:
res = DPOTrainer.tokenize_row(
features,
processing_class,
max_prompt_length=max_prompt_length,
max_completion_length=max_completion_length,
add_special_tokens=add_special_tokens,
is_chat=is_chat,
def _tokenize(
self,
processing_class: PreTrainedTokenizerBase | ProcessorMixin,
input: str | list,
**kwargs,
) -> dict[str, list]:
"""
Override TRL's tokenization in DPO trainer to fix double bos_token bug (eg. llama).
"""
result = super()._tokenize(
processing_class=processing_class, input=input, **kwargs
)

# Handle multimodal models
tokenizer = (
getattr(processing_class, "tokenizer", None)
if isinstance(processing_class, ProcessorMixin)
else processing_class
)
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
for key in res.keys():
res[key] = res[key][1:]

if processing_class.bos_token and processing_class.bos_token_id is not None:
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
bos_token_id = getattr(tokenizer, "bos_token_id", None) if tokenizer else None
if bos_token_id is not None:
result = remove_double_bos_token(result, bos_token_id)

return res
return result

def training_step(
self,
Expand All @@ -94,20 +91,3 @@ def training_step(
gc.collect()
torch.cuda.empty_cache()
return loss

def concatenated_forward(
self,
model: nn.Module,
batch: dict[str, Union[list, torch.LongTensor]],
is_ref_model: bool = False,
) -> dict[str, torch.Tensor]:
if self.args.dpo_norm_loss:
# fmt: off
loss_type: list[str] = self.loss_type # type: ignore[has-type]
# fmt: on
# concatenated_forward handles avg token logprob for ipo case already
self.loss_type = ["ipo"]
res = super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
self.loss_type = loss_type
return res
return super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
4 changes: 2 additions & 2 deletions src/axolotl/prompt_strategies/bradley_terry/chat_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ def _tokenize_single_prompt(self, prompt):
]

return {
"chosen_input_ids": chosen_tokenized["input_ids"],
"chosen_ids": chosen_tokenized["input_ids"],
"attention_mask_chosen": chosen_tokenized["attention_mask"],
"labels_chosen": 1.0,
"rejected_input_ids": rejected_tokenized["input_ids"],
"rejected_ids": rejected_tokenized["input_ids"],
"attention_mask_rejected": rejected_tokenized["attention_mask"],
"labels_rejected": 0.0,
}
Expand Down
6 changes: 3 additions & 3 deletions src/axolotl/prompt_strategies/orpo/chat_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def get_rejected(self, prompt) -> MessageList:

class ORPOTokenizingStrategy(PromptTokenizingStrategy):
"""
rejected_input_ids
rejected_ids
input_ids
rejected_attention_mask
attention_mask
Expand Down Expand Up @@ -169,7 +169,7 @@ def tokenize_prompt(self, prompt):
labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx)
prompt_len = len(input_ids)
# remap the input_ids, attention_mask and labels
rejected_input_ids = input_ids
rejected_ids = input_ids
rejected_labels = labels
# pass the chosen prompt/row to the Prompter to get the formatted prompt
chosen_message_list: MessageList = (
Expand All @@ -191,7 +191,7 @@ def tokenize_prompt(self, prompt):
labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx)

return {
"rejected_input_ids": rejected_input_ids,
"rejected_ids": rejected_ids,
"rejected_labels": rejected_labels,
"rejected_attention_mask": [1] * len(rejected_labels),
"input_ids": input_ids,
Expand Down
11 changes: 11 additions & 0 deletions src/axolotl/utils/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,3 +349,14 @@ def handle_long_seq_in_dataset(
)

return dataset


def remove_double_bos_token(example: dict[str, list], bos_token_id: int):
    """Strip one duplicated leading BOS token from a tokenized example.

    Retokenizing preprocessed data can prepend a second bos token for
    tokenizers / chat templates that already emit one (eg. DPO + Llama).
    When the first two ids of ``example["input_ids"]`` are both
    ``bos_token_id``, every sequence in the example is shifted left by one
    element so the parallel lists (ids, attention mask, labels, ...) stay
    aligned. The example is modified in place and also returned.
    """
    ids = example["input_ids"]
    has_double_bos = (
        len(ids) >= 2 and ids[0] == bos_token_id and ids[1] == bos_token_id
    )
    if has_double_bos:
        for field in example:
            example[field] = example[field][1:]
    return example
7 changes: 0 additions & 7 deletions src/axolotl/utils/schemas/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,6 @@ class AxolotlInputConfig(
},
)
dpo_label_smoothing: float | None = None
dpo_norm_loss: bool | None = None

dpo_use_liger_kernel: bool | None = Field(
default=None,
Expand Down Expand Up @@ -1111,12 +1110,6 @@ class AxolotlInputConfig(
"description": "Parameter controlling the relative ratio loss weight in the ORPO loss. Passed to `beta` in `ORPOConfig` due to trl mapping."
},
)
rpo_alpha: float | None = Field(
default=None,
json_schema_extra={
"description": "Weighting of NLL term in loss from RPO paper"
},
)
simpo_gamma: float | None = Field(
default=None,
json_schema_extra={"description": "Target reward margin for the SimPO loss"},
Expand Down
22 changes: 22 additions & 0 deletions src/axolotl/utils/schemas/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ class DeprecatedParameters(BaseModel):
eval_max_new_tokens: int | None = None
dpo_use_logits_to_keep: bool | None = None
dpo_generate_during_eval: bool | None = None
dpo_norm_loss: bool | None = None
rpo_alpha: float | None = None

@field_validator("max_packed_sequence_len")
@classmethod
Expand Down Expand Up @@ -100,6 +102,26 @@ def validate_dpo_generate_during_eval(cls, dpo_generate_during_eval):
)
return dpo_generate_during_eval

@field_validator("dpo_norm_loss")
@classmethod
def validate_dpo_norm_loss(cls, dpo_norm_loss):
    """Reject any explicitly-set `dpo_norm_loss` value — the option was removed."""
    if dpo_norm_loss is None:
        return dpo_norm_loss
    raise DeprecationWarning(
        "`dpo_norm_loss` is no longer supported, "
        "due to breaking changes in TRL >= 0.29.0"
    )

@field_validator("rpo_alpha")
@classmethod
def validate_rpo_alpha(cls, rpo_alpha):
    """Reject any explicitly-set `rpo_alpha` value — the option was removed."""
    if rpo_alpha is None:
        return rpo_alpha
    # TODO: change this warning once multiple dpo loss types are supported.
    raise DeprecationWarning(
        "`rpo_alpha` has been deprecated in TRL >= 0.29.0, "
        "and now requires passing multiple loss types, which is not yet supported by Axolotl."
    )


class RemappedParameters(BaseModel):
"""Parameters that have been remapped to other names"""
Expand Down
49 changes: 0 additions & 49 deletions tests/e2e/test_dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,55 +67,6 @@ def test_dpo_lora(self, temp_dir):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)

@with_temp_dir
def test_dpo_nll_lora(self, temp_dir):
    # End-to-end check that DPO + LoRA training runs to completion when the
    # RPO NLL term is enabled via `rpo_alpha` (weighting of the NLL loss
    # from the RPO paper), and that a checkpoint is written out.
    cfg = DictDefault(
        {
            # small base model keeps the e2e run fast
            "base_model": "HuggingFaceTB/SmolLM2-135M",
            "tokenizer_type": "AutoTokenizer",
            "sequence_len": 1024,
            "load_in_8bit": True,
            # LoRA adapter settings
            "adapter": "lora",
            "lora_r": 64,
            "lora_alpha": 32,
            "lora_dropout": 0.1,
            "lora_target_linear": True,
            "special_tokens": {
                "pad_token": "<|endoftext|>",
            },
            # DPO with the RPO NLL-loss weighting under test
            "rl": "dpo",
            "rpo_alpha": 0.5,
            "datasets": [
                {
                    "path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
                    "type": "chatml.ultra",
                    "split": "train",
                },
            ],
            # short run: 20 steps, checkpoint every 10
            "num_epochs": 1,
            "micro_batch_size": 4,
            "gradient_accumulation_steps": 1,
            "output_dir": temp_dir,
            "learning_rate": 0.00001,
            "optimizer": "paged_adamw_8bit",
            "lr_scheduler": "cosine",
            "max_steps": 20,
            "save_steps": 10,
            "warmup_steps": 5,
            "gradient_checkpointing": True,
            "gradient_checkpointing_kwargs": {"use_reentrant": True},
            "save_first_step": False,
        }
    )

    # validate/normalize the config and load the preference dataset,
    # mirroring the normal CLI training flow
    cfg = validate_config(cfg)
    normalize_config(cfg)
    cli_args = TrainerCliArgs()
    dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)

    train(cfg=cfg, dataset_meta=dataset_meta)
    # the final checkpoint (step 20) must exist on disk
    check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)

@with_temp_dir
def test_dpo_use_weighting(self, temp_dir):
cfg = DictDefault(
Expand Down
6 changes: 3 additions & 3 deletions tests/test_prompt_tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,18 +223,18 @@ def test_orpo_integration(
DictDefault({"chat_template": "chatml"}),
)
res = strat.tokenize_prompt(ds[0])
assert "rejected_input_ids" in res
assert "rejected_ids" in res
assert "rejected_labels" in res
assert "input_ids" in res
assert "labels" in res
assert "prompt_attention_mask" in res

assert len(res["rejected_input_ids"]) == len(res["rejected_labels"])
assert len(res["rejected_ids"]) == len(res["rejected_labels"])
assert len(res["input_ids"]) == len(res["labels"])
assert len(res["input_ids"]) == len(res["prompt_attention_mask"])

assert res["rejected_labels"][0] == -100
assert res["rejected_input_ids"][-1] == res["rejected_labels"][-1]
assert res["rejected_ids"][-1] == res["rejected_labels"][-1]

assert res["labels"][0] == -100
assert res["input_ids"][-1] == res["labels"][-1]
Expand Down
Loading
Loading