Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 19 additions & 13 deletions trl/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], images: list | N
List of image objects to insert in the messages.

Returns:
`list[dict[str, Any]]`: A deep-copied list of messages where every `"content"` value is a list of structured
`list[dict[str, Any]]`: A new list of messages where every `"content"` value is a list of structured
content blocks, and all `"image"` placeholders are populated with the corresponding image objects. If the
assistant turns contains `"tool_calls"`, then the `"content"` might be empty.

Expand Down Expand Up @@ -72,36 +72,38 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], images: list | N
```
"""
images = images or []
messages = copy.deepcopy(messages) # avoid modifying the original messages

# First, convert all messages to the structured format if needed, and insert image placeholders if needed
# First, convert all messages to the structured format if needed, and insert image placeholders if needed.
# Build new message dicts only when transforming string content to avoid modifying the originals.
new_messages = []
images_included = False
for message in messages:
if message["role"] == "system":
if isinstance(message["content"], str): # if already prepared, the content will be a list
message["content"] = [{"type": "text", "text": message["content"]}]
message = {**message, "content": [{"type": "text", "text": message["content"]}]}
elif message["role"] == "user":
if isinstance(message["content"], str) and not images_included:
image_entries = [{"type": "image"} for _ in range(len(images))]
message["content"] = [*image_entries, {"type": "text", "text": message["content"]}]
message = {**message, "content": [*image_entries, {"type": "text", "text": message["content"]}]}
images_included = True
elif isinstance(message["content"], str) and images_included:
message["content"] = [{"type": "text", "text": message["content"]}]
elif isinstance(message["content"], str):
message = {**message, "content": [{"type": "text", "text": message["content"]}]}
elif message["role"] == "assistant":
if message.get("content") and isinstance(message["content"], str):
message["content"] = [{"type": "text", "text": message["content"]}]
message = {**message, "content": [{"type": "text", "text": message["content"]}]}
elif message["role"] == "tool":
if message.get("content") and isinstance(message["content"], str):
message["content"] = [{"type": "text", "text": message["content"]}]
message = {**message, "content": [{"type": "text", "text": message["content"]}]}
else:
raise ValueError(
f"Invalid role in message: {message['role']}. Expected 'system', 'user', 'assistant', or 'tool'."
)
new_messages.append(message)

# Then, check that the number of image placeholders matches the number of images provided
num_placeholders = sum(
sum(1 for part in message["content"] if part["type"] == "image")
for message in messages
for message in new_messages
if message.get("content") and message["role"] != "tool"
)
if num_placeholders != len(images):
Expand All @@ -111,15 +113,19 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], images: list | N

# Then, fill in the actual images in the placeholders
img_idx = 0
for message in messages:
for i, message in enumerate(new_messages):
if not message.get("content") or message["role"] == "tool":
continue
new_content = []
for part in message["content"]:
if part["type"] == "image":
part["image"] = images[img_idx]
new_content.append({**part, "image": images[img_idx]})
img_idx += 1
else:
new_content.append(part)
Comment on lines +124 to +125
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Deep-copy non-image blocks when rebuilding content

This branch reuses the original block dict object for every non-image part, so the returned structure aliases nested objects from the caller’s input when messages are already in structured format. Any downstream in-place edit of prepare_multimodal_messages(...) output (for example, adding keys to text/tool blocks) will mutate the original messages, which is a regression from the previous deep-copy behavior and breaks practical immutability expectations for this helper.

Useful? React with 👍 / 👎.

new_messages[i] = {**message, "content": new_content}

return messages
return new_messages


def prepare_multimodal_messages_vllm(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
Expand Down
Loading