diff --git a/src/axolotl/core/builders/rl.py b/src/axolotl/core/builders/rl.py index c5cdd37920..b4d8b4d47b 100644 --- a/src/axolotl/core/builders/rl.py +++ b/src/axolotl/core/builders/rl.py @@ -127,9 +127,6 @@ def _build_training_arguments(self, total_num_steps): # trl does some odd mapping of alpha to beta to reuse the beta parameter ??? training_args_kwargs["beta"] = self.cfg.orpo_alpha - if self.cfg.rpo_alpha is not None: - training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha - if self.cfg.use_wandb: training_args_kwargs["run_name"] = self.cfg.wandb_name diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 8dc1a0239f..5bc44a1ddf 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -405,15 +405,13 @@ def evaluate(self, *args, **kwargs): def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None): concatenated_batch = {} - max_length = max( - inputs["input_ids"].shape[1], inputs["rejected_input_ids"].shape[1] - ) + max_length = max(inputs["input_ids"].shape[1], inputs["rejected_ids"].shape[1]) # Concatenate positive and negative inputs concatenated_batch["input_ids"] = pad_to_length( inputs["input_ids"], max_length, pad_token ) - concatenated_batch["rejected_input_ids"] = pad_to_length( - inputs["rejected_input_ids"], max_length, pad_token + concatenated_batch["rejected_ids"] = pad_to_length( + inputs["rejected_ids"], max_length, pad_token ) concatenated_batch["labels"] = pad_to_length( inputs["labels"], max_length, label_pad_token @@ -432,7 +430,7 @@ def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=No ).to(device=device) input_ids = torch.cat( - [concatenated_batch["input_ids"], concatenated_batch["rejected_input_ids"]], + [concatenated_batch["input_ids"], concatenated_batch["rejected_ids"]], dim=0, ).to(device=device) attention_mask = torch.cat( diff --git a/src/axolotl/core/trainers/dpo/__init__.py b/src/axolotl/core/trainers/dpo/__init__.py index 93634f64bd..6d5251de13 100644 --- a/src/axolotl/core/trainers/dpo/__init__.py +++ b/src/axolotl/core/trainers/dpo/__init__.py @@ -21,7 +21,7 @@ def get_training_args_class(cls): def set_training_args_kwargs(cls, cfg): training_args_kwargs = {} if cfg.rl is RLType.IPO: - training_args_kwargs["loss_type"] = "ipo" + training_args_kwargs["loss_type"] = ["ipo"] # Label smoothing is not compatible with IPO if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing: training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing @@ -30,8 +30,6 @@ def set_training_args_kwargs(cls, cfg): training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting if cfg.dpo_padding_free is not None: training_args_kwargs["padding_free"] = cfg.dpo_padding_free - if cfg.dpo_norm_loss is not None: - training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss if cfg.dpo_use_liger_kernel is not None: training_args_kwargs["use_liger_kernel"] = cfg.dpo_use_liger_kernel return training_args_kwargs diff --git a/src/axolotl/core/trainers/dpo/args.py b/src/axolotl/core/trainers/dpo/args.py index a0af69c4c1..de1758ed09 100644 --- a/src/axolotl/core/trainers/dpo/args.py +++ b/src/axolotl/core/trainers/dpo/args.py @@ -2,8 +2,7 @@ Axolotl specific DPO args """ -from dataclasses import dataclass, field -from typing import Optional +from dataclasses import dataclass from trl import DPOConfig @@ -15,6 +14,3 @@ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig): """ DPO config for DPO training """ - - dpo_norm_loss: bool | None = False - rpo_alpha: Optional[float] = field(default=None) diff --git a/src/axolotl/core/trainers/dpo/trainer.py b/src/axolotl/core/trainers/dpo/trainer.py index 3c0bca3d42..2021ebeb13 100644 --- a/src/axolotl/core/trainers/dpo/trainer.py +++ b/src/axolotl/core/trainers/dpo/trainer.py @@ -6,6 +6,7 @@ import torch from torch import nn +from transformers import PreTrainedTokenizerBase, ProcessorMixin from trl import DPOTrainer from axolotl.core.trainers.mixins import ( @@ -18,6 +19,7 @@ sanitize_kwargs_for_ds_tagging, sanitize_kwargs_for_tagging, ) +from axolotl.utils.data.utils import remove_double_bos_token class AxolotlDPOTrainer( @@ -53,36 +55,31 @@ def push_to_hub(self, *args, **kwargs) -> str: return super().push_to_hub(*args, **kwargs) - @staticmethod - def tokenize_row( - features, - processing_class, - max_prompt_length: int | None = None, - max_completion_length: int | None = None, - add_special_tokens: bool = True, - is_chat: bool = False, - ) -> Dict: - res = DPOTrainer.tokenize_row( - features, - processing_class, - max_prompt_length=max_prompt_length, - max_completion_length=max_completion_length, - add_special_tokens=add_special_tokens, - is_chat=is_chat, + def _tokenize( + self, + processing_class: PreTrainedTokenizerBase | ProcessorMixin, + input: str | list, + **kwargs, + ) -> dict[str, list]: + """ + Override TRL's tokenization in DPO trainer to fix double bos_token bug (eg. llama). + """ + result = super()._tokenize( + processing_class=processing_class, input=input, **kwargs + ) + + # Handle multimodal models + tokenizer = ( + getattr(processing_class, "tokenizer", None) + if isinstance(processing_class, ProcessorMixin) + else processing_class ) - # fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen - if processing_class.bos_token is None and res["prompt_input_ids"][0] is None: - for key in res.keys(): - res[key] = res[key][1:] - if processing_class.bos_token and processing_class.bos_token_id is not None: - # dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs - if res["chosen_input_ids"][0] == processing_class.bos_token_id: - res["chosen_input_ids"] = res["chosen_input_ids"][1:] - if res["rejected_input_ids"][0] == processing_class.bos_token_id: - res["rejected_input_ids"] = res["rejected_input_ids"][1:] + bos_token_id = getattr(tokenizer, "bos_token_id", None) if tokenizer else None + if bos_token_id is not None: + result = remove_double_bos_token(result, bos_token_id) - return res + return result def training_step( self, @@ -94,20 +91,3 @@ def training_step( gc.collect() torch.cuda.empty_cache() return loss - - def concatenated_forward( - self, - model: nn.Module, - batch: dict[str, Union[list, torch.LongTensor]], - is_ref_model: bool = False, - ) -> dict[str, torch.Tensor]: - if self.args.dpo_norm_loss: - # fmt: off - loss_type: list[str] = self.loss_type # type: ignore[has-type] - # fmt: on - # concatenated_forward handles avg token logprob for ipo case already - self.loss_type = ["ipo"] - res = super().concatenated_forward(model, batch, is_ref_model=is_ref_model) - self.loss_type = loss_type - return res - return super().concatenated_forward(model, batch, is_ref_model=is_ref_model) diff --git a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py index 03336b3ef9..e3f34ad8f6 100644 --- a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py +++ b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py @@ -71,10 +71,10 @@ def _tokenize_single_prompt(self, prompt): ] return { - "chosen_input_ids": chosen_tokenized["input_ids"], + "chosen_ids": chosen_tokenized["input_ids"], "attention_mask_chosen": chosen_tokenized["attention_mask"], "labels_chosen": 1.0, - "rejected_input_ids": rejected_tokenized["input_ids"], + "rejected_ids": rejected_tokenized["input_ids"], "attention_mask_rejected": rejected_tokenized["attention_mask"], "labels_rejected": 0.0, } diff --git a/src/axolotl/prompt_strategies/orpo/chat_template.py b/src/axolotl/prompt_strategies/orpo/chat_template.py index b655bc9703..259f1037dc 100644 --- a/src/axolotl/prompt_strategies/orpo/chat_template.py +++ b/src/axolotl/prompt_strategies/orpo/chat_template.py @@ -130,7 +130,7 @@ def get_rejected(self, prompt) -> MessageList: class ORPOTokenizingStrategy(PromptTokenizingStrategy): """ - rejected_input_ids + rejected_ids input_ids rejected_attention_mask attention_mask @@ -169,7 +169,7 @@ def tokenize_prompt(self, prompt): labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx) prompt_len = len(input_ids) # remap the input_ids, attention_mask and labels - rejected_input_ids = input_ids + rejected_ids = input_ids rejected_labels = labels # pass the chosen prompt/row to the Prompter to get the formatted prompt chosen_message_list: MessageList = ( @@ -191,7 +191,7 @@ def tokenize_prompt(self, prompt): labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx) return { - "rejected_input_ids": rejected_input_ids, + "rejected_ids": rejected_ids, "rejected_labels": rejected_labels, "rejected_attention_mask": [1] * len(rejected_labels), "input_ids": input_ids, diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py index f2cdcac388..e141713e79 100644 --- a/src/axolotl/utils/data/utils.py +++ b/src/axolotl/utils/data/utils.py @@ -349,3 +349,14 @@ def handle_long_seq_in_dataset( ) return dataset + + +def remove_double_bos_token(example: dict[str, list], bos_token_id: int): + """Remove double bos tokens that may occur when retokenizing preprocessed data + for tokenizers and chat templates that have a bos_token - eg. DPO + Llama. + """ + input_ids = example["input_ids"] + if len(input_ids) >= 2 and input_ids[0] == input_ids[1] == bos_token_id: + for key in example: + example[key] = example[key][1:] + return example diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 982e3e4199..2a85d6c735 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -294,7 +294,6 @@ class AxolotlInputConfig( }, ) dpo_label_smoothing: float | None = None - dpo_norm_loss: bool | None = None dpo_use_liger_kernel: bool | None = Field( default=None, @@ -1111,12 +1110,6 @@ class AxolotlInputConfig( "description": "Parameter controlling the relative ratio loss weight in the ORPO loss. Passed to `beta` in `ORPOConfig` due to trl mapping." }, ) - rpo_alpha: float | None = Field( - default=None, - json_schema_extra={ - "description": "Weighting of NLL term in loss from RPO paper" - }, - ) simpo_gamma: float | None = Field( default=None, json_schema_extra={"description": "Target reward margin for the SimPO loss"}, diff --git a/src/axolotl/utils/schemas/deprecated.py b/src/axolotl/utils/schemas/deprecated.py index 62b26949e2..d87eb9d964 100644 --- a/src/axolotl/utils/schemas/deprecated.py +++ b/src/axolotl/utils/schemas/deprecated.py @@ -21,6 +21,8 @@ class DeprecatedParameters(BaseModel): eval_max_new_tokens: int | None = None dpo_use_logits_to_keep: bool | None = None dpo_generate_during_eval: bool | None = None + dpo_norm_loss: bool | None = None + rpo_alpha: float | None = None @field_validator("max_packed_sequence_len") @classmethod @@ -100,6 +102,26 @@ def validate_dpo_generate_during_eval(cls, dpo_generate_during_eval): ) return dpo_generate_during_eval + @field_validator("dpo_norm_loss") + @classmethod + def validate_dpo_norm_loss(cls, dpo_norm_loss): + if dpo_norm_loss is not None: + raise DeprecationWarning( + "`dpo_norm_loss` is no longer supported, " + "due to breaking changes in TRL >= 0.29.0" + ) + return dpo_norm_loss + + @field_validator("rpo_alpha") + @classmethod + def validate_rpo_alpha(cls, rpo_alpha): + if rpo_alpha is not None: + raise DeprecationWarning( + "`rpo_alpha` has been deprecated in TRL >= 0.29.0, " + "and now requires passing multiple loss types, which is not yet supported by Axolotl." + ) # TODO: change this warning once multiple dpo loss types are supported. + return rpo_alpha + class RemappedParameters(BaseModel): """Parameters that have been remapped to other names""" diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py index fc6fb73675..0aca1807cd 100644 --- a/tests/e2e/test_dpo.py +++ b/tests/e2e/test_dpo.py @@ -67,55 +67,6 @@ def test_dpo_lora(self, temp_dir): train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) - @with_temp_dir - def test_dpo_nll_lora(self, temp_dir): - cfg = DictDefault( - { - "base_model": "HuggingFaceTB/SmolLM2-135M", - "tokenizer_type": "AutoTokenizer", - "sequence_len": 1024, - "load_in_8bit": True, - "adapter": "lora", - "lora_r": 64, - "lora_alpha": 32, - "lora_dropout": 0.1, - "lora_target_linear": True, - "special_tokens": { - "pad_token": "<|endoftext|>", - }, - "rl": "dpo", - "rpo_alpha": 0.5, - "datasets": [ - { - "path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized", - "type": "chatml.ultra", - "split": "train", - }, - ], - "num_epochs": 1, - "micro_batch_size": 4, - "gradient_accumulation_steps": 1, - "output_dir": temp_dir, - "learning_rate": 0.00001, - "optimizer": "paged_adamw_8bit", - "lr_scheduler": "cosine", - "max_steps": 20, - "save_steps": 10, - "warmup_steps": 5, - "gradient_checkpointing": True, - "gradient_checkpointing_kwargs": {"use_reentrant": True}, - "save_first_step": False, - } - ) - - cfg = validate_config(cfg) - normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) - - train(cfg=cfg, dataset_meta=dataset_meta) - check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) - @with_temp_dir def test_dpo_use_weighting(self, temp_dir): cfg = DictDefault( diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py index 672643a923..181fddbc08 100644 --- a/tests/test_prompt_tokenizers.py +++ b/tests/test_prompt_tokenizers.py @@ -223,18 +223,18 @@ def test_orpo_integration( DictDefault({"chat_template": "chatml"}), ) res = strat.tokenize_prompt(ds[0]) - assert "rejected_input_ids" in res + assert "rejected_ids" in res assert "rejected_labels" in res assert "input_ids" in res assert "labels" in res assert "prompt_attention_mask" in res - assert len(res["rejected_input_ids"]) == len(res["rejected_labels"]) + assert len(res["rejected_ids"]) == len(res["rejected_labels"]) assert len(res["input_ids"]) == len(res["labels"]) assert len(res["input_ids"]) == len(res["prompt_attention_mask"]) assert res["rejected_labels"][0] == -100 - assert res["rejected_input_ids"][-1] == res["rejected_labels"][-1] + assert res["rejected_ids"][-1] == res["rejected_labels"][-1] assert res["labels"][0] == -100 assert res["input_ids"][-1] == res["labels"][-1] diff --git a/tests/utils/data/test_utils.py b/tests/utils/data/test_utils.py index 357447b477..4b6afec3db 100644 --- a/tests/utils/data/test_utils.py +++ b/tests/utils/data/test_utils.py @@ -7,7 +7,7 @@ from datasets import Dataset -from axolotl.utils.data.utils import handle_long_seq_in_dataset +from axolotl.utils.data.utils import handle_long_seq_in_dataset, remove_double_bos_token from axolotl.utils.dict import DictDefault @@ -541,5 +541,33 @@ def test_invalid_strategy_falls_through_to_drop(self): self.assertEqual(len(result[0]["input_ids"]), 3) +class TestRemoveDoubleBOSToken(unittest.TestCase): + def test_no_remove_bos_token(self): + input_ids = [0, 1, 2] + labels = [1, 2, 3] + + example = { + "input_ids": input_ids, + "labels": labels, + } + + example = remove_double_bos_token(example, 0) + assert example["input_ids"] == input_ids + assert example["labels"] == labels + + def test_remove_bos_token(self): + input_ids = [0, 0, 1] + labels = [0, 1, 2] + + example = { + "input_ids": input_ids, + "labels": labels, + } + + example = remove_double_bos_token(example, 0) + assert example["input_ids"] == [0, 1] + assert example["labels"] == [1, 2] + + if __name__ == "__main__": unittest.main()