Skip to content
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 10 additions & 57 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1425,6 +1425,7 @@ def _get_tool_suffix_ids(self, tool_messages):
dummy_messages = [{"role": "user", "content": "dummy"}, {"role": "assistant", "content": "dummy"}]
if self._is_vlm:
dummy_messages = prepare_multimodal_messages(dummy_messages)

prefix_ids = self.processing_class.apply_chat_template(
dummy_messages,
add_generation_prompt=False,
Expand All @@ -1433,66 +1434,18 @@ def _get_tool_suffix_ids(self, tool_messages):
return_dict=False,
**self.chat_template_kwargs,
)
full_ids = self.processing_class.apply_chat_template(
dummy_messages + tool_messages,
add_generation_prompt=True,
tokenize=True,
chat_template=self.chat_template,
return_dict=False,
**self.chat_template_kwargs,
)
# VLM processors return batched output (list of lists), unbatch for single conversation
if self._is_vlm:
prefix_ids = prefix_ids[0]

# Check if tool messages contain images (multimodal tool responses)
tool_images = []
for msg in tool_messages:
if isinstance(msg.get("content"), list):
for part in msg["content"]:
if part["type"] == "image":
tool_images.append(part["image"])

# Normalize string content in tool messages for VLM processors before either path.
# Use copies to avoid mutating the original completions data.
if self._is_vlm:
tool_messages = [
{**msg, "content": [{"type": "text", "text": msg["content"]}]}
if isinstance(msg.get("content"), str)
else msg
for msg in tool_messages
]

if tool_images and self._is_vlm:
# For VLMs with images: use processor.__call__ to get correctly expanded image tokens.
# apply_chat_template only inserts a single <|image_pad|> placeholder per image,
# but the model needs N tokens per image (based on resolution). The processor's
# __call__ handles this expansion.
# Use the same tokenization method (processor.__call__) for both prefix and full to
# avoid mismatches from different tokenization paths.
prefix_text = self.processing_class.apply_chat_template(
dummy_messages,
add_generation_prompt=False,
tokenize=False,
chat_template=self.chat_template,
**self.chat_template_kwargs,
)
prefix_ids = self.processing_class(text=prefix_text, return_tensors="pt")["input_ids"][0].tolist()
full_text = self.processing_class.apply_chat_template(
dummy_messages + tool_messages,
add_generation_prompt=True,
tokenize=False,
chat_template=self.chat_template,
**self.chat_template_kwargs,
)
# We only need input_ids (for suffix token extraction). pixel_values and image_grid_thw
# are computed separately in the forward pass via image_processor to avoid mismatches.
full_ids = self.processing_class(text=full_text, images=tool_images, return_tensors="pt")["input_ids"][
0
].tolist()
else:
full_ids = self.processing_class.apply_chat_template(
dummy_messages + tool_messages,
add_generation_prompt=True,
tokenize=True,
chat_template=self.chat_template,
return_dict=False,
**self.chat_template_kwargs,
)
if self._is_vlm:
full_ids = full_ids[0]
full_ids = full_ids[0]

# Some chat templates (notably Qwen3/Qwen3.5) render "...<|im_end|>\n" after an assistant/tool block.
# When we compute `suffix_ids` by slicing `full_ids`, we must align the slicing boundary to
Expand Down
Loading