diff --git a/trl/data_utils.py b/trl/data_utils.py index f72c9e5693..7b49132f05 100644 --- a/trl/data_utils.py +++ b/trl/data_utils.py @@ -44,7 +44,7 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], images: list | N List of image objects to insert in the messages. Returns: - `list[dict[str, Any]]`: A deep-copied list of messages where every `"content"` value is a list of structured + `list[dict[str, Any]]`: A new list of messages where every `"content"` value is a list of structured content blocks, and all `"image"` placeholders are populated with the corresponding image objects. If the assistant turns contains `"tool_calls"`, then the `"content"` might be empty. @@ -72,36 +72,38 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], images: list | N ``` """ images = images or [] - messages = copy.deepcopy(messages) # avoid modifying the original messages - # First, convert all messages to the structured format if needed, and insert image placeholders if needed + # First, convert all messages to the structured format if needed, and insert image placeholders if needed. + # Build new message dicts only when transforming string content to avoid modifying the originals. + new_messages = [] images_included = False for message in messages: if message["role"] == "system": if isinstance(message["content"], str): # if already prepared, the content will be a list - message["content"] = [{"type": "text", "text": message["content"]}] + message = {**message, "content": [{"type": "text", "text": message["content"]}]} elif message["role"] == "user": if isinstance(message["content"], str) and not images_included: image_entries = [{"type": "image"} for _ in range(len(images))] - message["content"] = [*image_entries, {"type": "text", "text": message["content"]}] + message = {**message, "content": [*image_entries, {"type": "text", "text": message["content"]}]} images_included = True - elif isinstance(message["content"], str) and images_included: - message["content"] = [{"type": "text", "text": message["content"]}] + elif isinstance(message["content"], str): + message = {**message, "content": [{"type": "text", "text": message["content"]}]} elif message["role"] == "assistant": if message.get("content") and isinstance(message["content"], str): - message["content"] = [{"type": "text", "text": message["content"]}] + message = {**message, "content": [{"type": "text", "text": message["content"]}]} elif message["role"] == "tool": if message.get("content") and isinstance(message["content"], str): - message["content"] = [{"type": "text", "text": message["content"]}] + message = {**message, "content": [{"type": "text", "text": message["content"]}]} else: raise ValueError( f"Invalid role in message: {message['role']}. Expected 'system', 'user', 'assistant', or 'tool'." ) + new_messages.append(message) # Then, check that the number of image placeholders matches the number of images provided num_placeholders = sum( sum(1 for part in message["content"] if part["type"] == "image") - for message in messages + for message in new_messages if message.get("content") and message["role"] != "tool" ) if num_placeholders != len(images): @@ -111,15 +113,19 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], images: list | N # Then, fill in the actual images in the placeholders img_idx = 0 - for message in messages: + for i, message in enumerate(new_messages): if not message.get("content") or message["role"] == "tool": continue + new_content = [] for part in message["content"]: if part["type"] == "image": - part["image"] = images[img_idx] + new_content.append({**part, "image": images[img_idx]}) img_idx += 1 + else: + new_content.append(part) + new_messages[i] = {**message, "content": new_content} - return messages + return new_messages def prepare_multimodal_messages_vllm(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: