diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py
index b0b85c541e..192afeccb2 100644
--- a/tests/test_dpo_trainer.py
+++ b/tests/test_dpo_trainer.py
@@ -1289,12 +1289,13 @@ def test_train_vlm_keep_end_raises(self):
         # Regression test for #5285: keep_end with a VLM must raise at init time, not silently corrupt training.
         # Image tokens live at the start of the sequence (in the prompt); keep_end would drop them.
         dataset = load_dataset("trl-internal-testing/zen-image", "conversational_preference", split="train")
-        training_args = DPOConfig(
-            output_dir=self.tmp_dir,
-            max_length=32,
-            truncation_mode="keep_end",
-            report_to="none",
-        )
+        with pytest.warns(FutureWarning, match="keep_end.*deprecated"):
+            training_args = DPOConfig(
+                output_dir=self.tmp_dir,
+                max_length=32,
+                truncation_mode="keep_end",
+                report_to="none",
+            )
         with pytest.raises(ValueError, match="truncation_mode='keep_end' is not supported for vision-language models"):
             DPOTrainer(
                 model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index 679d237788..5363262d4e 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -987,13 +987,14 @@ def tokenize_example(example):
 
     def test_skip_prepare_dataset_passes_truncation_to_text_collator(self):
         dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train[:2]")
-        training_args = SFTConfig(
-            output_dir=self.tmp_dir,
-            max_length=16,
-            truncation_mode="keep_end",
-            dataset_kwargs={"skip_prepare_dataset": True},
-            report_to="none",
-        )
+        with pytest.warns(FutureWarning, match="keep_end.*deprecated"):
+            training_args = SFTConfig(
+                output_dir=self.tmp_dir,
+                max_length=16,
+                truncation_mode="keep_end",
+                dataset_kwargs={"skip_prepare_dataset": True},
+                report_to="none",
+            )
         trainer = SFTTrainer(
             model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
         )
diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py
index 0363d574a1..eb1343b3f4 100644
--- a/trl/trainer/dpo_config.py
+++ b/trl/trainer/dpo_config.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
 from dataclasses import dataclass, field
 from typing import Any
 
@@ -52,8 +53,8 @@ class DPOConfig(_BaseConfig):
             Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the left
             or right depending on the `truncation_mode`. If `None`, no truncation is applied.
         truncation_mode (`str`, *optional*, defaults to `"keep_start"`):
-            Truncation mode to use when the sequence exceeds `max_length`. Possible values are `"keep_end"` and
-            `"keep_start"`.
+            Truncation mode to use when the sequence exceeds `max_length`. The only supported value is
+            `"keep_start"`. The `"keep_end"` value is deprecated and will be removed in v2.0.0.
         padding_free (`bool`, *optional*, defaults to `False`):
             Whether to perform forward passes without padding by flattening all sequences in the batch into a single
             continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this is only
@@ -170,8 +171,8 @@
     truncation_mode: str = field(
         default="keep_start",
         metadata={
-            "help": "Truncation mode to use when the sequence exceeds `max_length`. Possible values are `'keep_end'` "
-            "and `'keep_start'`.",
+            "help": "Truncation mode to use when the sequence exceeds `max_length`. The only supported value is "
+            "`'keep_start'`. The `'keep_end'` value is deprecated and will be removed in v2.0.0.",
             "choices": ["keep_end", "keep_start"],
         },
     )
@@ -316,4 +317,12 @@ def __post_init__(self):
                 f"Got {len(self.loss_weights)} weights for {len(self.loss_type)} loss types."
             )
 
+        if self.truncation_mode == "keep_end":
+            warnings.warn(
+                "The `'keep_end'` truncation mode is deprecated and will be removed in v2.0.0. "
+                "Use `truncation_mode='keep_start'` (the default) instead.",
+                FutureWarning,
+                stacklevel=3,
+            )
+
         super().__post_init__()
diff --git a/trl/trainer/sft_config.py b/trl/trainer/sft_config.py
index f279846d43..5e925c25b6 100644
--- a/trl/trainer/sft_config.py
+++ b/trl/trainer/sft_config.py
@@ -68,8 +68,8 @@ class SFTConfig(_BaseConfig):
             or right depending on `truncation_mode`. If `None`, no truncation is applied. When packing is enabled,
             this value sets the sequence length.
         truncation_mode (`str`, *optional*, defaults to `"keep_start"`):
-            Truncation mode to use when the sequence exceeds `max_length`. Possible values are `"keep_end"` and
-            `"keep_start"`.
+            Truncation mode to use when the sequence exceeds `max_length`. The only supported value is
+            `"keep_start"`. The `"keep_end"` value is deprecated and will be removed in v2.0.0.
         shuffle_dataset (`bool`, *optional*, defaults to `False`):
             Whether to shuffle the dataset.
         packing (`bool`, *optional*, defaults to `False`):
@@ -185,8 +185,8 @@
     truncation_mode: str = field(
         default="keep_start",
         metadata={
-            "help": "Truncation mode to use when the sequence exceeds `max_length`. Possible values are `'keep_end'` "
-            "and `'keep_start'`.",
+            "help": "Truncation mode to use when the sequence exceeds `max_length`. The only supported value is "
+            "`'keep_start'`. The `'keep_end'` value is deprecated and will be removed in v2.0.0.",
             "choices": ["keep_end", "keep_start"],
         },
     )
@@ -269,6 +269,14 @@ class SFTConfig(_BaseConfig):
     def __post_init__(self):
         super().__post_init__()
 
+        if self.truncation_mode == "keep_end":
+            warnings.warn(
+                "The `'keep_end'` truncation mode is deprecated and will be removed in v2.0.0. "
+                "Use `truncation_mode='keep_start'` (the default) instead.",
+                FutureWarning,
+                stacklevel=3,
+            )
+
         if self.packing_strategy == "bfd-requeue":
             warnings.warn(
                 "The `bfd-requeue` packing strategy has been renamed to `bfd_split`. Please update your configuration accordingly. "