Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions trl/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], images: list | N
List of image objects to insert in the messages.

Returns:
`list[dict[str, Any]]`: A deep-copied list of messages where every `"content"` value is a list of structured
`list[dict[str, Any]]`: A new list of messages where every `"content"` value is a list of structured
content blocks, and all `"image"` placeholders are populated with the corresponding image objects. If the
assistant turns contains `"tool_calls"`, then the `"content"` might be empty.

Expand Down Expand Up @@ -72,36 +72,38 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], images: list | N
```
"""
images = images or []
messages = copy.deepcopy(messages) # avoid modifying the original messages

# First, convert all messages to the structured format if needed, and insert image placeholders if needed
# First, convert all messages to the structured format if needed, and insert image placeholders if needed.
# Build new message dicts only when transforming string content to avoid modifying the originals.
new_messages = []
images_included = False
for message in messages:
if message["role"] == "system":
if isinstance(message["content"], str): # if already prepared, the content will be a list
message["content"] = [{"type": "text", "text": message["content"]}]
message = {**message, "content": [{"type": "text", "text": message["content"]}]}
elif message["role"] == "user":
if isinstance(message["content"], str) and not images_included:
image_entries = [{"type": "image"} for _ in range(len(images))]
message["content"] = [*image_entries, {"type": "text", "text": message["content"]}]
message = {**message, "content": [*image_entries, {"type": "text", "text": message["content"]}]}
images_included = True
elif isinstance(message["content"], str) and images_included:
message["content"] = [{"type": "text", "text": message["content"]}]
elif isinstance(message["content"], str):
message = {**message, "content": [{"type": "text", "text": message["content"]}]}
elif message["role"] == "assistant":
if message.get("content") and isinstance(message["content"], str):
message["content"] = [{"type": "text", "text": message["content"]}]
message = {**message, "content": [{"type": "text", "text": message["content"]}]}
elif message["role"] == "tool":
if message.get("content") and isinstance(message["content"], str):
message["content"] = [{"type": "text", "text": message["content"]}]
message = {**message, "content": [{"type": "text", "text": message["content"]}]}
else:
raise ValueError(
f"Invalid role in message: {message['role']}. Expected 'system', 'user', 'assistant', or 'tool'."
)
new_messages.append(message)

# Then, check that the number of image placeholders matches the number of images provided
num_placeholders = sum(
sum(1 for part in message["content"] if part["type"] == "image")
for message in messages
for message in new_messages
if message.get("content") and message["role"] != "tool"
)
if num_placeholders != len(images):
Expand All @@ -111,15 +113,15 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], images: list | N

# Then, fill in the actual images in the placeholders
img_idx = 0
for message in messages:
for message in new_messages:
if not message.get("content") or message["role"] == "tool":
continue
for part in message["content"]:
if part["type"] == "image":
part["image"] = images[img_idx]
img_idx += 1

return messages
return new_messages


def prepare_multimodal_messages_vllm(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
Expand Down
Loading