Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/test_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def test_message_with_tool_calling_turns(self):
}
],
},
{"role": "tool", "name": "get_current_weather", "content": "22.0"},
{"role": "tool", "name": "get_current_weather", "content": [{"type": "text", "text": "22.0"}]},
{"role": "assistant", "content": "The current weather in New York is 22.0 degrees Celsius."},
]

Expand Down
5 changes: 2 additions & 3 deletions trl/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,8 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], images: list | N
if message.get("content") and isinstance(message["content"], str):
message["content"] = [{"type": "text", "text": message["content"]}]
elif message["role"] == "tool":
# NOTE: `tool` contains `name` (name of the tool used) and `content` (output of the tool call as a string)
# but there's no need to prepare it for multimodal specifically but rather leave it as-is
continue
if message.get("content") and isinstance(message["content"], str):
message["content"] = [{"type": "text", "text": message["content"]}]
else:
raise ValueError(
f"Invalid role in message: {message['role']}. Expected 'system', 'user', 'assistant', or 'tool'."
Expand Down
68 changes: 11 additions & 57 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1417,6 +1417,8 @@ def _get_tool_suffix_ids(self, tool_messages):
]
if self._is_vlm:
dummy_messages = prepare_multimodal_messages(dummy_messages)
tool_messages = prepare_multimodal_messages(tool_messages)

prefix_ids = self.processing_class.apply_chat_template(
dummy_messages,
add_generation_prompt=False,
Expand All @@ -1425,66 +1427,18 @@ def _get_tool_suffix_ids(self, tool_messages):
return_dict=False,
**self.chat_template_kwargs,
)
full_ids = self.processing_class.apply_chat_template(
dummy_messages + tool_messages,
add_generation_prompt=True,
tokenize=True,
chat_template=self.chat_template,
return_dict=False,
**self.chat_template_kwargs,
)
# VLM processors return batched output (list of lists), unbatch for single conversation
if self._is_vlm:
prefix_ids = prefix_ids[0]

# Check if tool messages contain images (multimodal tool responses)
tool_images = []
for msg in tool_messages:
if isinstance(msg.get("content"), list):
for part in msg["content"]:
if part["type"] == "image":
tool_images.append(part["image"])

# Normalize string content in tool messages for VLM processors before either path.
# Use copies to avoid mutating the original completions data.
if self._is_vlm:
tool_messages = [
{**msg, "content": [{"type": "text", "text": msg["content"]}]}
if isinstance(msg.get("content"), str)
else msg
for msg in tool_messages
]

if tool_images and self._is_vlm:
# For VLMs with images: use processor.__call__ to get correctly expanded image tokens.
# apply_chat_template only inserts a single <|image_pad|> placeholder per image,
# but the model needs N tokens per image (based on resolution). The processor's
# __call__ handles this expansion.
# Use the same tokenization method (processor.__call__) for both prefix and full to
# avoid mismatches from different tokenization paths.
prefix_text = self.processing_class.apply_chat_template(
dummy_messages,
add_generation_prompt=False,
tokenize=False,
chat_template=self.chat_template,
**self.chat_template_kwargs,
)
prefix_ids = self.processing_class(text=prefix_text, return_tensors="pt")["input_ids"][0].tolist()
full_text = self.processing_class.apply_chat_template(
dummy_messages + tool_messages,
add_generation_prompt=True,
tokenize=False,
chat_template=self.chat_template,
**self.chat_template_kwargs,
)
# We only need input_ids (for suffix token extraction). pixel_values and image_grid_thw
# are computed separately in the forward pass via image_processor to avoid mismatches.
full_ids = self.processing_class(text=full_text, images=tool_images, return_tensors="pt")["input_ids"][
0
].tolist()
else:
full_ids = self.processing_class.apply_chat_template(
dummy_messages + tool_messages,
add_generation_prompt=True,
tokenize=True,
chat_template=self.chat_template,
return_dict=False,
**self.chat_template_kwargs,
)
if self._is_vlm:
full_ids = full_ids[0]
full_ids = full_ids[0]

# Some chat templates (notably Qwen3/Qwen3.5) render "...<|im_end|>\n" after an assistant/tool block.
# When we compute `suffix_ids` by slicing `full_ids`, we must align the slicing boundary to
Expand Down
Loading