Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions trl/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,8 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], images: list | N
if message.get("content") and isinstance(message["content"], str):
message["content"] = [{"type": "text", "text": message["content"]}]
elif message["role"] == "tool":
# NOTE: `tool` contains `name` (name of the tool used) and `content` (output of the tool call as a string)
# but there's no need to prepare it for multimodal specifically but rather leave it as-is
continue
if message.get("content") and isinstance(message["content"], str):
message["content"] = [{"type": "text", "text": message["content"]}]
else:
raise ValueError(
f"Invalid role in message: {message['role']}. Expected 'system', 'user', 'assistant', or 'tool'."
Expand All @@ -103,7 +102,7 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], images: list | N
num_placeholders = sum(
sum(1 for part in message["content"] if part["type"] == "image")
for message in messages
if message.get("content") and message["role"] != "tool"
if message.get("content") and isinstance(message["content"], list)
)
if num_placeholders != len(images):
raise ValueError(
Expand All @@ -113,7 +112,7 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], images: list | N
# Then, fill in the actual images in the placeholders
img_idx = 0
for message in messages:
if not message.get("content") or message["role"] == "tool":
if not message.get("content") or not isinstance(message["content"], list):
continue
for part in message["content"]:
if part["type"] == "image":
Expand Down
68 changes: 11 additions & 57 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1407,6 +1407,8 @@ def _get_tool_suffix_ids(self, tool_messages):
dummy_messages = [{"role": "user", "content": "dummy"}, {"role": "assistant", "content": "dummy"}]
if self._is_vlm:
dummy_messages = prepare_multimodal_messages(dummy_messages)
tool_messages = prepare_multimodal_messages(tool_messages)

prefix_ids = self.processing_class.apply_chat_template(
dummy_messages,
add_generation_prompt=False,
Expand All @@ -1415,66 +1417,18 @@ def _get_tool_suffix_ids(self, tool_messages):
return_dict=False,
**self.chat_template_kwargs,
)
full_ids = self.processing_class.apply_chat_template(
dummy_messages + tool_messages,
add_generation_prompt=True,
tokenize=True,
chat_template=self.chat_template,
return_dict=False,
**self.chat_template_kwargs,
)
# VLM processors return batched output (list of lists), unbatch for single conversation
if self._is_vlm:
prefix_ids = prefix_ids[0]

# Check if tool messages contain images (multimodal tool responses)
tool_images = []
for msg in tool_messages:
if isinstance(msg.get("content"), list):
for part in msg["content"]:
if part["type"] == "image":
tool_images.append(part["image"])

# Normalize string content in tool messages for VLM processors before either path.
# Use copies to avoid mutating the original completions data.
if self._is_vlm:
tool_messages = [
{**msg, "content": [{"type": "text", "text": msg["content"]}]}
if isinstance(msg.get("content"), str)
else msg
for msg in tool_messages
]

if tool_images and self._is_vlm:
# For VLMs with images: use processor.__call__ to get correctly expanded image tokens.
# apply_chat_template only inserts a single <|image_pad|> placeholder per image,
# but the model needs N tokens per image (based on resolution). The processor's
# __call__ handles this expansion.
# Use the same tokenization method (processor.__call__) for both prefix and full to
# avoid mismatches from different tokenization paths.
prefix_text = self.processing_class.apply_chat_template(
dummy_messages,
add_generation_prompt=False,
tokenize=False,
chat_template=self.chat_template,
**self.chat_template_kwargs,
)
prefix_ids = self.processing_class(text=prefix_text, return_tensors="pt")["input_ids"][0].tolist()
full_text = self.processing_class.apply_chat_template(
dummy_messages + tool_messages,
add_generation_prompt=True,
tokenize=False,
chat_template=self.chat_template,
**self.chat_template_kwargs,
)
# We only need input_ids (for suffix token extraction). pixel_values and image_grid_thw
# are computed separately in the forward pass via image_processor to avoid mismatches.
full_ids = self.processing_class(text=full_text, images=tool_images, return_tensors="pt")["input_ids"][
0
].tolist()
else:
full_ids = self.processing_class.apply_chat_template(
dummy_messages + tool_messages,
add_generation_prompt=True,
tokenize=True,
chat_template=self.chat_template,
return_dict=False,
**self.chat_template_kwargs,
)
if self._is_vlm:
full_ids = full_ids[0]
full_ids = full_ids[0]

# Some chat templates (notably Qwen3/Qwen3.5) render "...<|im_end|>\n" after an assistant/tool block.
# When we compute `suffix_ids` by slicing `full_ids`, we must align the slicing boundary to
Expand Down
Loading