17 changes: 3 additions & 14 deletions trl/trainer/dpo_config.py
@@ -49,11 +49,8 @@ class DPOConfig(_BaseConfig):
             Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`,
             it falls back to `processing_class.eos_token`.
         max_length (`int` or `None`, *optional*, defaults to `1024`):
-            Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the left or
-            right depending on the `truncation_mode`. If `None`, no truncation is applied.
-        truncation_mode (`str`, *optional*, defaults to `"keep_start"`):
-            Truncation mode to use when the sequence exceeds `max_length`. Possible values are `"keep_end"` and
-            `"keep_start"`.
+            Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the end.
+            If `None`, no truncation is applied.
Inconsistent truncation_mode removal across trainers violates consistency rule

Medium Severity

The truncation_mode parameter and its associated logic (config field, collator field, truncation branching, VLM guard) are duplicated across the DPO and SFT trainers. This PR removes truncation_mode from DPOConfig, DataCollatorForPreference, and DPOTrainer, but the identical pattern remains in SFTConfig, DataCollatorForLanguageModeling, and SFTTrainer. Per the project's AGENTS.md consistency rules, changes to duplicated logic across trainers must be propagated to all copies.


Triggered by project rule: ../.ai/AGENTS.md

Reviewed by Cursor Bugbot for commit 4bf9351.
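For reference, a minimal sketch of what the duplicated `keep_start`/`keep_end` branching computes. This assumes plain Python lists of token ids; `truncate` is a hypothetical helper for illustration, not TRL's actual API:

```python
def truncate(ids: list[int], max_length: int, truncation_mode: str) -> list[int]:
    """Toy version of the branching duplicated across the DPO and SFT collators."""
    if truncation_mode == "keep_start":
        return ids[:max_length]   # keep the beginning, drop tokens from the end
    elif truncation_mode == "keep_end":
        return ids[-max_length:]  # keep the end, drop tokens from the start
    raise ValueError(f"Unsupported truncation mode: {truncation_mode}")

print(truncate([1, 2, 3, 4, 5], 3, "keep_start"))  # [1, 2, 3]
print(truncate([1, 2, 3, 4, 5], 3, "keep_end"))    # [3, 4, 5]
```

With the `truncation_mode` field gone, only the `keep_start` branch survives, which is why propagating (or deliberately not propagating) the removal to the SFT copies matters.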

padding_free (`bool`, *optional*, defaults to `False`):
Whether to perform forward passes without padding by flattening all sequences in the batch into a single
continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this is only
@@ -164,15 +161,7 @@ class DPOConfig(_BaseConfig):
         default=1024,
         metadata={
             "help": "Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from "
-            "the left or right depending on the `truncation_mode`. If `None`, no truncation is applied."
-        },
-    )
-    truncation_mode: str = field(
-        default="keep_start",
-        metadata={
-            "help": "Truncation mode to use when the sequence exceeds `max_length`. Possible values are `'keep_end'` "
-            "and `'keep_start'`.",
-            "choices": ["keep_end", "keep_start"],
+            "the end. If `None`, no truncation is applied."
         },
     )
     padding_free: bool = field(
padding_free: bool = field(
21 changes: 1 addition & 20 deletions trl/trainer/dpo_trainer.py
@@ -111,9 +111,6 @@ class DataCollatorForPreference(DataCollatorMixin):
         max_length (`int`, *optional*):
             Maximum length of the sequences after concatenation. Sequences longer than `max_length` are truncated
             before padding, which avoids allocating oversized tensors for batches containing very long sequences.
-        truncation_mode (`str`, *optional*, defaults to `"keep_start"`):
-            Truncation mode when a concatenated sequence exceeds `max_length`. Possible values are `"keep_end"` and
-            `"keep_start"`.
         pad_to_multiple_of (`int`, *optional*):
             If set, the sequences will be padded to a multiple of this value.
         return_tensors (`str`, *optional*, defaults to `"pt"`):
@@ -146,7 +143,6 @@ class DataCollatorForPreference(DataCollatorMixin):

     pad_token_id: int
     max_length: int | None = None
-    truncation_mode: str = "keep_start"
     pad_to_multiple_of: int | None = None
     return_tensors: str = "pt"

@@ -157,14 +153,7 @@ def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
         rejected_mask = [[0] * len(example["prompt_ids"]) + [1] * len(example["rejected_ids"]) for example in examples]

         if self.max_length is not None:
-            if self.truncation_mode == "keep_start":
Truncation mode removal not propagated to SFT trainer

Low Severity

The truncation_mode removal from DataCollatorForPreference and DPOConfig was not propagated to the SFT trainer, which has the identical pattern in DataCollatorForLanguageModeling (including the same keep_start/keep_end if/elif/else block) and SFTConfig. The VLM keep_end guard is also duplicated verbatim in SFTTrainer.__init__. This creates an inconsistency between the two main-code trainers.


Triggered by project rule: BUGBOT.md

Member Author


As commented in the description, if there is agreement, I'd be happy to follow up and apply the same change across other trainers for consistency.

-                sl = slice(None, self.max_length)
-            elif self.truncation_mode == "keep_end":
-                sl = slice(-self.max_length, None)
-            else:
-                raise ValueError(
-                    f"Unsupported truncation mode: {self.truncation_mode}, expected 'keep_start' or 'keep_end'"
-                )
+            sl = slice(None, self.max_length)
             prompt_chosen_ids = [ids[sl] for ids in prompt_chosen_ids]
             prompt_rejected_ids = [ids[sl] for ids in prompt_rejected_ids]
             chosen_mask = [m[sl] for m in chosen_mask]
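After this change, the collator's truncation reduces to a single right-truncation slice applied to every parallel per-example list. A toy sketch with illustrative values (variable names echo the collator, but this is not TRL's actual code):

```python
max_length = 4
sl = slice(None, max_length)  # keep the first max_length tokens of each sequence

# Two examples: one longer than max_length, one shorter (left untouched).
prompt_chosen_ids = [[1, 2, 3, 4, 5, 6], [7, 8]]
chosen_mask = [[0, 0, 1, 1, 1, 1], [0, 1]]

# One slice object truncates the ids and their mask consistently.
prompt_chosen_ids = [ids[sl] for ids in prompt_chosen_ids]
chosen_mask = [m[sl] for m in chosen_mask]

print(prompt_chosen_ids)  # [[1, 2, 3, 4], [7, 8]]
print(chosen_mask)        # [[0, 0, 1, 1], [0, 1]]
```

Because the same `sl` is applied to ids and masks alike, the two stay aligned regardless of which sequences were actually truncated.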
@@ -624,13 +613,6 @@ def __init__(
                 "The dataset appears to be vision-related (contains 'image' or 'images' keys), but the provided "
                 "model does not seem to be a vision-language model. Please check your model and dataset."
             )
-        if self._is_vision_dataset and args.max_length is not None and args.truncation_mode == "keep_end":
-            raise ValueError(
-                "truncation_mode='keep_end' is not supported for vision-language models. Image tokens reside "
-                "inside the prompt portion of the sequence; depending on the example, keep_end may silently "
-                "drop them, causing pixel_values to be forwarded to the model with no corresponding visual "
-                "tokens in input_ids. Use truncation_mode='keep_start' (the default) or set max_length=None."
-            )
         if data_collator is None and not self._is_vision_dataset:
             # Get the pad token: if not provided, use the one from the processing class or the eos token
             # if the processing class does not have a pad token.
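The guard deleted above existed because `keep_end` truncation could silently cut image placeholder tokens out of a vision-language prompt. A toy illustration of the failure mode it prevented (the token id is hypothetical, not from any real processor):

```python
IMAGE_TOKEN = 32000  # hypothetical image placeholder id

# In VLM inputs the image tokens sit at the start of the prompt,
# followed by the text tokens.
input_ids = [IMAGE_TOKEN, IMAGE_TOKEN, 5, 6, 7, 8, 9]
max_length = 4

kept = input_ids[-max_length:]  # keep_end truncation
print(kept)                  # [6, 7, 8, 9]
print(IMAGE_TOKEN in kept)   # False: visual tokens silently dropped,
                             # while pixel_values would still be forwarded
```

Once `keep_start` is the only behavior, truncation can only drop tokens from the end of the sequence, so the prompt-leading image tokens survive and the guard becomes dead code.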
@@ -645,7 +627,6 @@ def __init__(
             data_collator = DataCollatorForPreference(
                 pad_token_id=pad_token_id,
                 max_length=args.max_length,
-                truncation_mode=args.truncation_mode,
                 pad_to_multiple_of=args.pad_to_multiple_of,
             )
         elif data_collator is None and self._is_vision_dataset: