diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index b0b85c541e..f05662a720 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -1283,21 +1283,3 @@ def test_peft_with_quantization(self): assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" else: raise ValueError(f"Unexpected parameter {n} in model: {trainer.model}") - - @require_vision - def test_train_vlm_keep_end_raises(self): - # Regression test for #5285: keep_end with a VLM must raise at init time, not silently corrupt training. - # Image tokens live at the start of the sequence (in the prompt); keep_end would drop them. - dataset = load_dataset("trl-internal-testing/zen-image", "conversational_preference", split="train") - training_args = DPOConfig( - output_dir=self.tmp_dir, - max_length=32, - truncation_mode="keep_end", - report_to="none", - ) - with pytest.raises(ValueError, match="truncation_mode='keep_end' is not supported for vision-language models"): - DPOTrainer( - model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", - args=training_args, - train_dataset=dataset, - ) diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index 0363d574a1..b77ad4ce45 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -49,11 +49,8 @@ class DPOConfig(_BaseConfig): Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`, it falls back to `processing_class.eos_token`. max_length (`int` or `None`, *optional*, defaults to `1024`): - Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the left or - right depending on the `truncation_mode`. If `None`, no truncation is applied. - truncation_mode (`str`, *optional*, defaults to `"keep_start"`): - Truncation mode to use when the sequence exceeds `max_length`. Possible values are `"keep_end"` and - `"keep_start"`. + Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the end. + If `None`, no truncation is applied. padding_free (`bool`, *optional*, defaults to `False`): Whether to perform forward passes without padding by flattening all sequences in the batch into a single continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this is only @@ -164,15 +161,7 @@ class DPOConfig(_BaseConfig): default=1024, metadata={ "help": "Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from " - "the left or right depending on the `truncation_mode`. If `None`, no truncation is applied." - }, - ) - truncation_mode: str = field( - default="keep_start", - metadata={ - "help": "Truncation mode to use when the sequence exceeds `max_length`. Possible values are `'keep_end'` " - "and `'keep_start'`.", - "choices": ["keep_end", "keep_start"], + "the end. If `None`, no truncation is applied." }, ) padding_free: bool = field( diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index d688a87367..89e9451bf1 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -110,9 +110,6 @@ class DataCollatorForPreference(DataCollatorMixin): max_length (`int`, *optional*): Maximum length of the sequences after concatenation. Sequences longer than `max_length` are truncated before padding, which avoids allocating oversized tensors for batches containing very long sequences. - truncation_mode (`str`, *optional*, defaults to `"keep_start"`): - Truncation mode when a concatenated sequence exceeds `max_length`. Possible values are `"keep_end"` and - `"keep_start"`. pad_to_multiple_of (`int`, *optional*): If set, the sequences will be padded to a multiple of this value. return_tensors (`str`, *optional*, defaults to `"pt"`): @@ -145,7 +142,6 @@ class DataCollatorForPreference(DataCollatorMixin): pad_token_id: int max_length: int | None = None - truncation_mode: str = "keep_start" pad_to_multiple_of: int | None = None return_tensors: str = "pt" @@ -156,14 +152,7 @@ def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]: rejected_mask = [[0] * len(example["prompt_ids"]) + [1] * len(example["rejected_ids"]) for example in examples] if self.max_length is not None: - if self.truncation_mode == "keep_start": - sl = slice(None, self.max_length) - elif self.truncation_mode == "keep_end": - sl = slice(-self.max_length, None) - else: - raise ValueError( - f"Unsupported truncation mode: {self.truncation_mode}, expected 'keep_start' or 'keep_end'" - ) + sl = slice(None, self.max_length) prompt_chosen_ids = [ids[sl] for ids in prompt_chosen_ids] prompt_rejected_ids = [ids[sl] for ids in prompt_rejected_ids] chosen_mask = [m[sl] for m in chosen_mask] @@ -623,13 +612,6 @@ def __init__( "The dataset appears to be vision-related (contains 'image' or 'images' keys), but the provided " "model does not seem to be a vision-language model. Please check your model and dataset." ) - if self._is_vision_dataset and args.max_length is not None and args.truncation_mode == "keep_end": - raise ValueError( - "truncation_mode='keep_end' is not supported for vision-language models. Image tokens reside " - "inside the prompt portion of the sequence; depending on the example, keep_end may silently " - "drop them, causing pixel_values to be forwarded to the model with no corresponding visual " - "tokens in input_ids. Use truncation_mode='keep_start' (the default) or set max_length=None." - ) if data_collator is None and not self._is_vision_dataset: # Get the pad token: if not provided, use the one from the processing class or the eos token # if the processing class does not have a pad token. @@ -644,7 +626,6 @@ def __init__( data_collator = DataCollatorForPreference( pad_token_id=pad_token_id, max_length=args.max_length, - truncation_mode=args.truncation_mode, pad_to_multiple_of=args.pad_to_multiple_of, ) elif data_collator is None and self._is_vision_dataset: