diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py
index 72532b1f64..a423273973 100644
--- a/tests/test_data_utils.py
+++ b/tests/test_data_utils.py
@@ -235,7 +235,7 @@ def test_message_with_tool_calling_turns(self):
                     }
                 ],
             },
-            {"role": "tool", "name": "get_current_weather", "content": "22.0"},
+            {"role": "tool", "name": "get_current_weather", "content": [{"type": "text", "text": "22.0"}]},
             {
                 "role": "assistant",
                 "content": [{"type": "text", "text": "The current weather in New York is 22.0 degrees Celsius."}],
diff --git a/trl/data_utils.py b/trl/data_utils.py
index 44a524ab01..f72c9e5693 100644
--- a/trl/data_utils.py
+++ b/trl/data_utils.py
@@ -91,9 +91,8 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], images: list | N
             if message.get("content") and isinstance(message["content"], str):
                 message["content"] = [{"type": "text", "text": message["content"]}]
         elif message["role"] == "tool":
-            # NOTE: `tool` contains `name` (name of the tool used) and `content` (output of the tool call as a string)
-            # but there's no need to prepare it for multimodal specifically but rather leave it as-is
-            continue
+            if message.get("content") and isinstance(message["content"], str):
+                message["content"] = [{"type": "text", "text": message["content"]}]
         else:
             raise ValueError(
                 f"Invalid role in message: {message['role']}. Expected 'system', 'user', 'assistant', or 'tool'."
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index e949181e20..d775baefd3 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -1417,6 +1417,8 @@ def _get_tool_suffix_ids(self, tool_messages):
         ]
         if self._is_vlm:
             dummy_messages = prepare_multimodal_messages(dummy_messages)
+            tool_messages = prepare_multimodal_messages(tool_messages)
+
         prefix_ids = self.processing_class.apply_chat_template(
             dummy_messages,
             add_generation_prompt=False,
@@ -1425,66 +1427,18 @@
             return_dict=False,
             **self.chat_template_kwargs,
         )
+        full_ids = self.processing_class.apply_chat_template(
+            dummy_messages + tool_messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            chat_template=self.chat_template,
+            return_dict=False,
+            **self.chat_template_kwargs,
+        )
         # VLM processors return batched output (list of lists), unbatch for single conversation
         if self._is_vlm:
             prefix_ids = prefix_ids[0]
-
-        # Check if tool messages contain images (multimodal tool responses)
-        tool_images = []
-        for msg in tool_messages:
-            if isinstance(msg.get("content"), list):
-                for part in msg["content"]:
-                    if part["type"] == "image":
-                        tool_images.append(part["image"])
-
-        # Normalize string content in tool messages for VLM processors before either path.
-        # Use copies to avoid mutating the original completions data.
-        if self._is_vlm:
-            tool_messages = [
-                {**msg, "content": [{"type": "text", "text": msg["content"]}]}
-                if isinstance(msg.get("content"), str)
-                else msg
-                for msg in tool_messages
-            ]
-
-        if tool_images and self._is_vlm:
-            # For VLMs with images: use processor.__call__ to get correctly expanded image tokens.
-            # apply_chat_template only inserts a single <|image_pad|> placeholder per image,
-            # but the model needs N tokens per image (based on resolution). The processor's
-            # __call__ handles this expansion.
-            # Use the same tokenization method (processor.__call__) for both prefix and full to
-            # avoid mismatches from different tokenization paths.
-            prefix_text = self.processing_class.apply_chat_template(
-                dummy_messages,
-                add_generation_prompt=False,
-                tokenize=False,
-                chat_template=self.chat_template,
-                **self.chat_template_kwargs,
-            )
-            prefix_ids = self.processing_class(text=prefix_text, return_tensors="pt")["input_ids"][0].tolist()
-            full_text = self.processing_class.apply_chat_template(
-                dummy_messages + tool_messages,
-                add_generation_prompt=True,
-                tokenize=False,
-                chat_template=self.chat_template,
-                **self.chat_template_kwargs,
-            )
-            # We only need input_ids (for suffix token extraction). pixel_values and image_grid_thw
-            # are computed separately in the forward pass via image_processor to avoid mismatches.
-            full_ids = self.processing_class(text=full_text, images=tool_images, return_tensors="pt")["input_ids"][
-                0
-            ].tolist()
-        else:
-            full_ids = self.processing_class.apply_chat_template(
-                dummy_messages + tool_messages,
-                add_generation_prompt=True,
-                tokenize=True,
-                chat_template=self.chat_template,
-                return_dict=False,
-                **self.chat_template_kwargs,
-            )
-            if self._is_vlm:
-                full_ids = full_ids[0]
+            full_ids = full_ids[0]
 
         # Some chat templates (notably Qwen3/Qwen3.5) render "...<|im_end|>\n" after an assistant/tool block.
         # When we compute `suffix_ids` by slicing `full_ids`, we must align the slicing boundary to