Remove truncation_mode from DPO #5372
Changes from all commits
a444d39
a9706c6
01cfab1
4bf9351
**`trl/trainer/dpo_trainer.py`**

```diff
@@ -110,9 +110,6 @@ class DataCollatorForPreference(DataCollatorMixin):
         max_length (`int`, *optional*):
             Maximum length of the sequences after concatenation. Sequences longer than `max_length` are truncated
             before padding, which avoids allocating oversized tensors for batches containing very long sequences.
-        truncation_mode (`str`, *optional*, defaults to `"keep_start"`):
-            Truncation mode when a concatenated sequence exceeds `max_length`. Possible values are `"keep_end"` and
-            `"keep_start"`.
         pad_to_multiple_of (`int`, *optional*):
             If set, the sequences will be padded to a multiple of this value.
         return_tensors (`str`, *optional*, defaults to `"pt"`):
```
```diff
@@ -145,7 +142,6 @@ class DataCollatorForPreference(DataCollatorMixin):
     pad_token_id: int
     max_length: int | None = None
-    truncation_mode: str = "keep_start"
     pad_to_multiple_of: int | None = None
     return_tensors: str = "pt"
```
```diff
@@ -156,14 +152,7 @@ def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
         rejected_mask = [[0] * len(example["prompt_ids"]) + [1] * len(example["rejected_ids"]) for example in examples]

         if self.max_length is not None:
-            if self.truncation_mode == "keep_start":
-                sl = slice(None, self.max_length)
-            elif self.truncation_mode == "keep_end":
-                sl = slice(-self.max_length, None)
-            else:
-                raise ValueError(
-                    f"Unsupported truncation mode: {self.truncation_mode}, expected 'keep_start' or 'keep_end'"
-                )
+            sl = slice(None, self.max_length)
             prompt_chosen_ids = [ids[sl] for ids in prompt_chosen_ids]
             prompt_rejected_ids = [ids[sl] for ids in prompt_rejected_ids]
             chosen_mask = [m[sl] for m in chosen_mask]
```

**Cursor Bugbot** commented:

> **Truncation mode removal not propagated to SFT trainer** (Low Severity)
>
> Additional Locations (1). Triggered by project rule: `BUGBOT.md`

**Member (Author)** replied:

> As commented in the description, if there is agreement, I'd be happy to follow up and apply the same change across other trainers for consistency.
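For reference, the two modes being collapsed here differ only in which end of the concatenated sequence survives truncation. A minimal standalone sketch of the slicing semantics (plain Python, no TRL dependency; the `truncate` helper is hypothetical):

```python
# Sketch of the two truncation modes the collator used to support.
# "keep_start" keeps the first max_length tokens; "keep_end" keeps the last.
def truncate(ids: list[int], max_length: int, mode: str = "keep_start") -> list[int]:
    if mode == "keep_start":
        sl = slice(None, max_length)
    elif mode == "keep_end":
        sl = slice(-max_length, None)
    else:
        raise ValueError(f"Unsupported truncation mode: {mode}")
    return ids[sl]

ids = [1, 2, 3, 4, 5, 6]
print(truncate(ids, 4, "keep_start"))  # [1, 2, 3, 4]
print(truncate(ids, 4, "keep_end"))    # [3, 4, 5, 6]
```

After this PR, only the `slice(None, self.max_length)` branch (i.e. `keep_start`) remains in the collator.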
```diff
@@ -623,13 +612,6 @@ def __init__(
                 "The dataset appears to be vision-related (contains 'image' or 'images' keys), but the provided "
                 "model does not seem to be a vision-language model. Please check your model and dataset."
             )
-        if self._is_vision_dataset and args.max_length is not None and args.truncation_mode == "keep_end":
-            raise ValueError(
-                "truncation_mode='keep_end' is not supported for vision-language models. Image tokens reside "
-                "inside the prompt portion of the sequence; depending on the example, keep_end may silently "
-                "drop them, causing pixel_values to be forwarded to the model with no corresponding visual "
-                "tokens in input_ids. Use truncation_mode='keep_start' (the default) or set max_length=None."
-            )
         if data_collator is None and not self._is_vision_dataset:
             # Get the pad token: if not provided, use the one from the processing class or the eos token
             # if the processing class does not have a pad token.
```
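The guard removed above describes a real hazard with `keep_end` on vision-language inputs: image placeholder tokens sit in the prompt prefix, so tail-keeping truncation can silently discard all of them while `pixel_values` are still forwarded to the model. A hypothetical toy illustration (the token ids, including the `IMG` placeholder id, are made up):

```python
# Image placeholder tokens live at the start of the sequence (the prompt),
# so keep_end truncation can drop every one of them.
IMG = 32000  # assumed image-placeholder token id
prompt_ids = [IMG, IMG, 101, 102]           # image tokens + text prompt
completion_ids = [201, 202, 203, 204, 205]
seq = prompt_ids + completion_ids

max_length = 5
kept = seq[-max_length:]  # keep_end: keep only the last max_length tokens
print(IMG in kept)  # False: no image tokens remain in input_ids
```

With the `keep_end` option gone, only `keep_start` behavior remains, which keeps the prompt prefix (and its image tokens) intact.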
```diff
@@ -644,7 +626,6 @@ def __init__(
             data_collator = DataCollatorForPreference(
                 pad_token_id=pad_token_id,
                 max_length=args.max_length,
-                truncation_mode=args.truncation_mode,
                 pad_to_multiple_of=args.pad_to_multiple_of,
             )
         elif data_collator is None and self._is_vision_dataset:
```
**Cursor Bugbot** commented:

> **Inconsistent truncation_mode removal across trainers violates consistency rule** (Medium Severity)
>
> The `truncation_mode` parameter and its associated logic (config field, collator field, truncation branching, VLM guard) are duplicated across the DPO and SFT trainers. This PR removes `truncation_mode` from `DPOConfig`, `DataCollatorForPreference`, and `DPOTrainer`, but the identical pattern remains in `SFTConfig`, `DataCollatorForLanguageModeling`, and `SFTTrainer`. Per the project's `AGENTS.md` consistency rules, changes to duplicated logic across trainers must be propagated to all copies.
>
> Additional Locations (1): `trl/trainer/dpo_trainer.py#L154-L155`. Triggered by project rule: `../.ai/AGENTS.md`

Reviewed by Cursor Bugbot for commit 4bf9351.