Skip to content

Commit 537cc13

Browse files
authored
Merge branch 'main' into gold_vlm_support
2 parents 23784b0 + 4a3cae0 commit 537cc13

File tree

4 files changed

+166
-37
lines changed

4 files changed

+166
-37
lines changed

docs/source/grpo_trainer.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -716,6 +716,26 @@ trainer.train()
716716

717717
`reset` can return either `None` or a string. In GRPO, when it returns a string, that string is appended to the last user message before generation.
718718

719+
### Multimodal Tool Responses
720+
721+
Tools can return images alongside text by returning a list of content blocks. This is useful for VLM agent training where the tool provides visual feedback (e.g., screenshots, plots, camera captures).
722+
723+
```python
724+
from PIL import Image
725+
726+
def take_screenshot() -> list:
727+
"""
728+
Takes a screenshot of the current screen.
729+
730+
Returns:
731+
The screenshot image with a description.
732+
"""
733+
img = Image.open("screenshot.png")
734+
return [{"type": "image", "image": img}, {"type": "text", "text": "Here is the screenshot."}]
735+
```
736+
737+
The returned images are automatically injected into the conversation and passed to the VLM for subsequent generation turns.
738+
719739
### Supported Models
720740

721741
Tested with:

tests/test_grpo_trainer.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2446,6 +2446,90 @@ def fake_generate(input_ids, **kwargs):
24462446
new_param = trainer.model.get_parameter(n)
24472447
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
24482448

2449+
@pytest.mark.xfail(
2450+
condition=Version(transformers.__version__) < Version("5.2.0"),
2451+
reason="Qwen3.5 models were introduced in transformers-5.2.0",
2452+
strict=True,
2453+
)
2454+
@require_jmespath
2455+
@require_vision
2456+
def test_training_with_tools_multimodal_response(self):
2457+
# Test that tools returning images (multimodal responses) work correctly with a VLM.
2458+
# The tool returns a list of content blocks including an image.
2459+
from PIL import Image as PILImage
2460+
2461+
def screenshot_tool() -> list:
2462+
"""
2463+
Takes a screenshot and returns it.
2464+
2465+
Returns:
2466+
A list of content blocks with the screenshot image.
2467+
"""
2468+
img = PILImage.new("RGB", (64, 64), color="red")
2469+
return [{"type": "image", "image": img}, {"type": "text", "text": "Here is the screenshot"}]
2470+
2471+
dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train")
2472+
2473+
training_args = GRPOConfig(
2474+
output_dir=self.tmp_dir,
2475+
learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates
2476+
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
2477+
num_generations=3, # reduce the number of generations to reduce memory usage
2478+
max_completion_length=512,
2479+
report_to="none",
2480+
)
2481+
trainer = GRPOTrainer(
2482+
model="trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration",
2483+
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
2484+
args=training_args,
2485+
train_dataset=dataset,
2486+
tools=[screenshot_tool],
2487+
)
2488+
2489+
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
2490+
2491+
def fake_generate(input_ids, **kwargs):
2492+
if input_ids.shape[0] == 3: # first call
2493+
# fmt: off
2494+
completion_ids = torch.tensor(
2495+
[
2496+
# '<tool_call>\n<function=screenshot_tool>\n</function>\n</tool_call><|im_end|>'
2497+
[248058, 198, 27, 1628, 13744, 30091, 22076, 29, 198, 510, 1628, 29, 198, 248059, 248046],
2498+
# "I don't know any tool<|im_end|>" + padding
2499+
[40, 1459, 914, 1366, 866, 5224, 248046, 248044, 248044, 248044, 248044, 248044, 248044, 248044, 248044],
2500+
# '<tool_call>\n<function=screenshot_tool>\n</function>\n</tool_call><|im_end|>'
2501+
[248058, 198, 27, 1628, 13744, 30091, 22076, 29, 198, 510, 1628, 29, 198, 248059, 248046],
2502+
],
2503+
device=input_ids.device,
2504+
)
2505+
# fmt: on
2506+
else: # second call: 2 tool calls succeeded
2507+
completion_ids = torch.tensor(
2508+
[
2509+
# 'Done!<|im_end|>'
2510+
[16936, 0, 248046],
2511+
# 'Done!<|im_end|>'
2512+
[16936, 0, 248046],
2513+
],
2514+
device=input_ids.device,
2515+
)
2516+
return torch.cat([input_ids, completion_ids], dim=-1)
2517+
2518+
with patch.object(trainer.model, "generate", side_effect=fake_generate):
2519+
trainer.train()
2520+
2521+
assert trainer.state.log_history[-1]["train_loss"] is not None
2522+
assert trainer.state.log_history[-1]["tools/call_frequency"] == pytest.approx(2 / 3)
2523+
assert trainer.state.log_history[-1]["tools/failure_frequency"] == pytest.approx(0.0)
2524+
2525+
# Check that the params have changed (skip vision parts, see test_training_vlm)
2526+
params_to_skip = ("model.visual.",)
2527+
for n, param in previous_trainable_params.items():
2528+
if n.startswith(params_to_skip):
2529+
continue
2530+
new_param = trainer.model.get_parameter(n)
2531+
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
2532+
24492533
@pytest.mark.xfail(
24502534
condition=Version(transformers.__version__) < Version("5.2.0"),
24512535
reason="Environment factory support is not available in transformers versions below 5.2.0",

tests/test_utils.py

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import transformers
2525
from packaging.version import Version
2626
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
27+
from transformers.testing_utils import torch_device
2728
from transformers.utils import is_peft_available
2829

2930
from trl import ModelConfig
@@ -48,7 +49,7 @@
4849
use_adapter,
4950
)
5051

51-
from .testing_utils import TrlTestCase, require_peft, require_rich
52+
from .testing_utils import TrlTestCase, require_peft, require_rich, require_torch_accelerator
5253

5354

5455
if is_peft_available():
@@ -960,6 +961,23 @@ def test_multi_images(self):
960961
assert torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 1, 2]]))
961962
assert torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2], [1, 2, 1]]))
962963

964+
def test_split_by_image_position_ids(self):
965+
# Gemma-style: no image_grid_thw, split by num_images using image_position_ids
966+
batch = {
967+
"num_images": [1, 2],
968+
"pixel_values": torch.arange(3 * 4).reshape(3, 4),
969+
"image_position_ids": torch.tensor([[0, 1], [2, 3], [4, 5]]),
970+
}
971+
result = split_pixel_values_by_grid(batch)
972+
assert isinstance(result["pixel_values"], list)
973+
assert len(result["pixel_values"]) == 2
974+
assert torch.equal(result["pixel_values"][0], batch["pixel_values"][:1])
975+
assert torch.equal(result["pixel_values"][1], batch["pixel_values"][1:])
976+
assert isinstance(result["image_position_ids"], list)
977+
assert len(result["image_position_ids"]) == 2
978+
assert torch.equal(result["image_position_ids"][0], batch["image_position_ids"][:1])
979+
assert torch.equal(result["image_position_ids"][1], batch["image_position_ids"][1:])
980+
963981

964982
class TestUnsplitPixelValuesByGrid(TrlTestCase):
965983
def test_unsplit_correctly(self):
@@ -975,13 +993,23 @@ def test_unsplit_correctly(self):
975993
assert torch.equal(result["image_grid_thw"], image_grid_thw_merged)
976994
assert "other_key" in result
977995

996+
def test_unsplit_image_position_ids(self):
997+
image_position_ids = [torch.tensor([[0, 1]]), torch.tensor([[2, 3], [4, 5]])]
998+
image_position_ids_merged = torch.cat(image_position_ids, dim=0)
999+
pixel_values = [torch.randn(1, 4), torch.randn(2, 4)]
1000+
batch = {"pixel_values": pixel_values, "image_position_ids": image_position_ids}
1001+
result = unsplit_pixel_values_by_grid(batch)
1002+
assert isinstance(result["image_position_ids"], torch.Tensor)
1003+
assert torch.equal(result["image_position_ids"], image_position_ids_merged)
1004+
9781005
def test_no_op_if_not_list(self):
9791006
original = torch.randn(5, 3)
9801007
batch = {"pixel_values": original}
9811008
result = unsplit_pixel_values_by_grid(batch)
9821009
assert torch.equal(result["pixel_values"], original)
9831010

9841011

1012+
@require_torch_accelerator
9851013
class TestForwardMaskedLogits:
9861014
@pytest.mark.parametrize(
9871015
"model_id",
@@ -1005,12 +1033,11 @@ class TestForwardMaskedLogits:
10051033
],
10061034
)
10071035
def test_llm(self, model_id):
1008-
device = torch.device("cuda")
1009-
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", device_map=device)
1010-
input_ids = torch.randint(0, model.config.vocab_size, (2, 8), device=device)
1036+
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", device_map=torch_device)
1037+
input_ids = torch.randint(0, model.config.vocab_size, (2, 8), device=torch_device)
10111038
logits_mask = torch.tensor(
10121039
[[1, 1, 0, 0, 1, 0, 1, 0], [0, 1, 1, 0, 0, 1, 0, 1]],
1013-
device=device,
1040+
device=torch_device,
10141041
)
10151042

10161043
full_outputs = model(input_ids=input_ids)
@@ -1051,12 +1078,11 @@ def test_llm(self, model_id):
10511078
],
10521079
)
10531080
def test_vlm(self, model_id):
1054-
device = torch.device("cuda")
1055-
model = AutoModelForImageTextToText.from_pretrained(model_id, dtype="auto", device_map=device)
1056-
input_ids = torch.randint(0, model.config.text_config.vocab_size, (2, 8), device=device)
1081+
model = AutoModelForImageTextToText.from_pretrained(model_id, dtype="auto", device_map=torch_device)
1082+
input_ids = torch.randint(0, model.config.text_config.vocab_size, (2, 8), device=torch_device)
10571083
logits_mask = torch.tensor(
10581084
[[1, 1, 0, 0, 1, 0, 1, 0], [0, 1, 1, 0, 0, 1, 0, 1]],
1059-
device=device,
1085+
device=torch_device,
10601086
)
10611087

10621088
full_outputs = model(input_ids=input_ids)
@@ -1203,6 +1229,7 @@ def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
12031229
]
12041230

12051231

1232+
@require_torch_accelerator
12061233
class TestPatchChunkedLMHead:
12071234
B, S = 4, 16 # batch size, sequence length (including prompt + completion)
12081235
H, V = 32, 128
@@ -1285,15 +1312,14 @@ def test_dummy_model_chunked_forward_completion_mask_backward(self, temperature)
12851312
@pytest.mark.parametrize("model_id", _CHUNKED_LM_HEAD_MODEL_IDS)
12861313
@pytest.mark.parametrize("temperature", [1.0, 0.7])
12871314
def test_forward(self, model_id, temperature):
1288-
device = torch.device("cuda")
1289-
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)
1315+
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(torch_device)
12901316
if getattr(model.config, "final_logit_softcapping", None) is not None:
12911317
pytest.skip("model uses final_logit_softcapping, not supported by chunked LM head")
12921318
model.eval()
12931319

12941320
B, S, chunk_size = 2, 8, 32
12951321
torch.manual_seed(42)
1296-
input_ids = torch.randint(0, model.config.vocab_size, (B, S), device=device)
1322+
input_ids = torch.randint(0, model.config.vocab_size, (B, S), device=torch_device)
12971323
labels = input_ids.clone()
12981324

12991325
# Reference: standard forward → shifted logits → logprobs & entropy
@@ -1316,15 +1342,14 @@ def test_forward(self, model_id, temperature):
13161342
@pytest.mark.parametrize("model_id", _CHUNKED_LM_HEAD_MODEL_IDS)
13171343
@pytest.mark.parametrize("temperature", [1.0, 0.7])
13181344
def test_backward(self, model_id, temperature):
1319-
device = torch.device("cuda")
1320-
model_ref = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)
1345+
model_ref = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(torch_device)
13211346
if getattr(model_ref.config, "final_logit_softcapping", None) is not None:
13221347
pytest.skip("model uses final_logit_softcapping, not supported by chunked LM head")
13231348
model_chunked = copy.deepcopy(model_ref)
13241349

13251350
B, S, chunk_size = 2, 8, 32
13261351
torch.manual_seed(42)
1327-
input_ids = torch.randint(0, model_ref.config.vocab_size, (B, S), device=device)
1352+
input_ids = torch.randint(0, model_ref.config.vocab_size, (B, S), device=torch_device)
13281353
labels = input_ids.clone()
13291354
shifted_labels = labels[:, 1:]
13301355

trl/trainer/utils.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -887,39 +887,39 @@ def identity(x):
887887

888888
def split_pixel_values_by_grid(batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor | list[torch.Tensor]]:
889889
"""
890-
Splits `batch["pixel_values"]` into a list of tensors based on the product of each row in `batch["image_grid_thw"]`
891-
and batch["num_images"] while keeping other entries unchanged. For models without `image_grid_thw` (e.g. Gemma),
892-
splits by `num_images` directly.
890+
Splits `batch["pixel_values"]` into a list of tensors, one per sample, based on `batch["num_images"]`.
891+
892+
For models with `image_grid_thw` (e.g. Qwen), the grid dimensions determine how many rows of `pixel_values` belong
893+
to each image. For models with `image_position_ids` instead (e.g. Gemma), `pixel_values` is indexed directly by
894+
image count.
893895
"""
894896
if "pixel_values" not in batch or "num_images" not in batch:
895897
return batch
896898

899+
num_images = batch["num_images"]
900+
pixel_values = batch["pixel_values"] # [total, feature_dim]
901+
897902
if "image_grid_thw" in batch:
898903
lengths = batch["image_grid_thw"].prod(-1).tolist() # [num_images]
899-
pixel_values = batch["pixel_values"] # [total, feature_dim]
900-
901904
if sum(lengths) != pixel_values.size(0):
902905
raise ValueError(
903906
f"Mismatch: sum(lengths) = {sum(lengths)} != pixel_values.size(0) = {pixel_values.size(0)}"
904907
)
905908

906-
boundaries = [0, *accumulate(batch["num_images"])] # [3, 4, 5] -> [0, 3, 7, 12]
907-
sections = [sum(lengths[boundaries[i] : boundaries[i + 1]]) for i in range(len(batch["num_images"]))]
908-
split_values = list(torch.split(batch["pixel_values"], sections, dim=0))
909-
image_grid_thw = list(torch.split(batch["image_grid_thw"], batch["num_images"], dim=0))
910-
return {**batch, "pixel_values": split_values, "image_grid_thw": image_grid_thw}
911-
else:
912-
# Models without image_grid_thw (e.g. Gemma): split pixel_values by num_images per sample
913-
num_images = batch["num_images"]
914-
num_images_list = num_images.tolist() if isinstance(num_images, torch.Tensor) else list(num_images)
915-
split_values = list(torch.split(batch["pixel_values"], [int(n) for n in num_images_list], dim=0))
916-
result = {**batch, "pixel_values": split_values}
917-
# Also split image_position_ids if present (indexed by image, same as pixel_values)
918-
if "image_position_ids" in batch:
919-
result["image_position_ids"] = list(
920-
torch.split(batch["image_position_ids"], [int(n) for n in num_images_list], dim=0)
921-
)
922-
return result
909+
boundaries = [0, *accumulate(num_images)]
910+
image_grid_thw = batch["image_grid_thw"] # [total, 3]
911+
sections = [sum(lengths[boundaries[i] : boundaries[i + 1]]) for i in range(len(num_images))]
912+
split_pixel_values = list(torch.split(pixel_values, sections, dim=0))
913+
split_image_grid_thw = list(torch.split(image_grid_thw, num_images, dim=0))
914+
return {**batch, "pixel_values": split_pixel_values, "image_grid_thw": split_image_grid_thw}
915+
916+
if "image_position_ids" in batch:
917+
image_position_ids = batch["image_position_ids"] # [total]
918+
split_pixel_values = list(torch.split(pixel_values, num_images, dim=0))
919+
split_image_position_ids = list(torch.split(image_position_ids, num_images, dim=0))
920+
return {**batch, "pixel_values": split_pixel_values, "image_position_ids": split_image_position_ids}
921+
922+
return batch
923923

924924

925925
def unsplit_pixel_values_by_grid(batch: dict[str, torch.Tensor | list[torch.Tensor]]) -> dict[str, torch.Tensor]:

0 commit comments

Comments
 (0)