Skip to content

Commit 1e667d8

Browse files
Avoid image deepcopy in prepare_multimodal_messages (#5475)
1 parent dd071d7 commit 1e667d8

File tree

1 file changed

+19
-13
lines changed

1 file changed

+19
-13
lines changed

trl/data_utils.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], images: list | N
4444
List of image objects to insert in the messages.
4545
4646
Returns:
47-
`list[dict[str, Any]]`: A deep-copied list of messages where every `"content"` value is a list of structured
47+
`list[dict[str, Any]]`: A new list of messages where every `"content"` value is a list of structured
4848
content blocks, and all `"image"` placeholders are populated with the corresponding image objects. If the
4949
assistant turn contains `"tool_calls"`, then the `"content"` might be empty.
5050
@@ -70,36 +70,38 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], images: list | N
7070
```
7171
"""
7272
images = images or []
73-
messages = copy.deepcopy(messages) # avoid modifying the original messages
7473

75-
# First, convert all messages to the structured format if needed, and insert image placeholders if needed
74+
# First, convert all messages to the structured format if needed, and insert image placeholders if needed.
75+
# Build new message dicts only when transforming string content to avoid modifying the originals.
76+
new_messages = []
7677
images_included = False
7778
for message in messages:
7879
if message["role"] == "system":
7980
if isinstance(message["content"], str): # if already prepared, the content will be a list
80-
message["content"] = [{"type": "text", "text": message["content"]}]
81+
message = {**message, "content": [{"type": "text", "text": message["content"]}]}
8182
elif message["role"] == "user":
8283
if isinstance(message["content"], str) and not images_included:
8384
image_entries = [{"type": "image"} for _ in range(len(images))]
84-
message["content"] = [*image_entries, {"type": "text", "text": message["content"]}]
85+
message = {**message, "content": [*image_entries, {"type": "text", "text": message["content"]}]}
8586
images_included = True
86-
elif isinstance(message["content"], str) and images_included:
87-
message["content"] = [{"type": "text", "text": message["content"]}]
87+
elif isinstance(message["content"], str):
88+
message = {**message, "content": [{"type": "text", "text": message["content"]}]}
8889
elif message["role"] == "assistant":
8990
if message.get("content") and isinstance(message["content"], str):
90-
message["content"] = [{"type": "text", "text": message["content"]}]
91+
message = {**message, "content": [{"type": "text", "text": message["content"]}]}
9192
elif message["role"] == "tool":
9293
if message.get("content") and isinstance(message["content"], str):
93-
message["content"] = [{"type": "text", "text": message["content"]}]
94+
message = {**message, "content": [{"type": "text", "text": message["content"]}]}
9495
else:
9596
raise ValueError(
9697
f"Invalid role in message: {message['role']}. Expected 'system', 'user', 'assistant', or 'tool'."
9798
)
99+
new_messages.append(message)
98100

99101
# Then, check that the number of image placeholders matches the number of images provided
100102
num_placeholders = sum(
101103
sum(1 for part in message["content"] if part["type"] == "image")
102-
for message in messages
104+
for message in new_messages
103105
if message.get("content") and message["role"] != "tool"
104106
)
105107
if num_placeholders != len(images):
@@ -109,15 +111,19 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], images: list | N
109111

110112
# Then, fill in the actual images in the placeholders
111113
img_idx = 0
112-
for message in messages:
114+
for i, message in enumerate(new_messages):
113115
if not message.get("content") or message["role"] == "tool":
114116
continue
117+
new_content = []
115118
for part in message["content"]:
116119
if part["type"] == "image":
117-
part["image"] = images[img_idx]
120+
new_content.append({**part, "image": images[img_idx]})
118121
img_idx += 1
122+
else:
123+
new_content.append(part)
124+
new_messages[i] = {**message, "content": new_content}
119125

120-
return messages
126+
return new_messages
121127

122128

123129
def prepare_multimodal_messages_vllm(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:

0 commit comments

Comments
 (0)