2424import transformers
2525from packaging .version import Version
2626from transformers import AutoModelForCausalLM , AutoModelForImageTextToText
27+ from transformers .testing_utils import torch_device
2728from transformers .utils import is_peft_available
2829
2930from trl import ModelConfig
4849 use_adapter ,
4950)
5051
51- from .testing_utils import TrlTestCase , require_peft , require_rich
52+ from .testing_utils import TrlTestCase , require_peft , require_rich , require_torch_accelerator
5253
5354
5455if is_peft_available ():
@@ -960,6 +961,23 @@ def test_multi_images(self):
960961 assert torch .equal (result ["image_grid_thw" ][0 ], torch .tensor ([[1 , 1 , 2 ]]))
961962 assert torch .equal (result ["image_grid_thw" ][1 ], torch .tensor ([[1 , 2 , 2 ], [1 , 2 , 1 ]]))
962963
def test_split_by_image_position_ids(self):
    """Split a batch that has ``image_position_ids`` but no ``image_grid_thw``.

    Per the original inline note this is the Gemma-style layout, where the
    split is driven by ``num_images``. With ``num_images == [1, 2]`` and three
    rows per tensor, the first chunk is expected to be row 0 and the second
    chunk rows 1-2, for both ``pixel_values`` and ``image_position_ids``.
    """
    batch = {
        "num_images": [1, 2],
        "pixel_values": torch.arange(12).reshape(3, 4),
        "image_position_ids": torch.tensor([[0, 1], [2, 3], [4, 5]]),
    }

    result = split_pixel_values_by_grid(batch)

    # Both tensors must be split the same way: one list entry per sample.
    for key in ("pixel_values", "image_position_ids"):
        assert isinstance(result[key], list)
        assert len(result[key]) == 2
        assert torch.equal(result[key][0], batch[key][:1])
        assert torch.equal(result[key][1], batch[key][1:])

980+
963981
964982class TestUnsplitPixelValuesByGrid (TrlTestCase ):
965983 def test_unsplit_correctly (self ):
@@ -975,13 +993,23 @@ def test_unsplit_correctly(self):
975993 assert torch .equal (result ["image_grid_thw" ], image_grid_thw_merged )
976994 assert "other_key" in result
977995
def test_unsplit_image_position_ids(self):
    """Check that a list of per-sample ``image_position_ids`` is merged back.

    The result is expected to be a single tensor equal to the concatenation
    of the per-sample tensors along dimension 0.
    """
    per_sample_ids = [torch.tensor([[0, 1]]), torch.tensor([[2, 3], [4, 5]])]
    expected_ids = torch.cat(per_sample_ids, dim=0)
    # Pixel values are random filler; their row counts (1 and 2) match the ids.
    per_sample_pixels = [torch.randn(rows, 4) for rows in (1, 2)]

    result = unsplit_pixel_values_by_grid(
        {"pixel_values": per_sample_pixels, "image_position_ids": per_sample_ids}
    )

    merged_ids = result["image_position_ids"]
    assert isinstance(merged_ids, torch.Tensor)
    assert torch.equal(merged_ids, expected_ids)

1004+
def test_no_op_if_not_list(self):
    """Check that ``pixel_values`` given as a plain tensor comes back equal."""
    pixel_tensor = torch.randn(5, 3)
    result = unsplit_pixel_values_by_grid({"pixel_values": pixel_tensor})
    assert torch.equal(result["pixel_values"], pixel_tensor)
9831010
9841011
1012+ @require_torch_accelerator
9851013class TestForwardMaskedLogits :
9861014 @pytest .mark .parametrize (
9871015 "model_id" ,
@@ -1005,12 +1033,11 @@ class TestForwardMaskedLogits:
10051033 ],
10061034 )
10071035 def test_llm (self , model_id ):
1008- device = torch .device ("cuda" )
1009- model = AutoModelForCausalLM .from_pretrained (model_id , dtype = "auto" , device_map = device )
1010- input_ids = torch .randint (0 , model .config .vocab_size , (2 , 8 ), device = device )
1036+ model = AutoModelForCausalLM .from_pretrained (model_id , dtype = "auto" , device_map = torch_device )
1037+ input_ids = torch .randint (0 , model .config .vocab_size , (2 , 8 ), device = torch_device )
10111038 logits_mask = torch .tensor (
10121039 [[1 , 1 , 0 , 0 , 1 , 0 , 1 , 0 ], [0 , 1 , 1 , 0 , 0 , 1 , 0 , 1 ]],
1013- device = device ,
1040+ device = torch_device ,
10141041 )
10151042
10161043 full_outputs = model (input_ids = input_ids )
@@ -1051,12 +1078,11 @@ def test_llm(self, model_id):
10511078 ],
10521079 )
10531080 def test_vlm (self , model_id ):
1054- device = torch .device ("cuda" )
1055- model = AutoModelForImageTextToText .from_pretrained (model_id , dtype = "auto" , device_map = device )
1056- input_ids = torch .randint (0 , model .config .text_config .vocab_size , (2 , 8 ), device = device )
1081+ model = AutoModelForImageTextToText .from_pretrained (model_id , dtype = "auto" , device_map = torch_device )
1082+ input_ids = torch .randint (0 , model .config .text_config .vocab_size , (2 , 8 ), device = torch_device )
10571083 logits_mask = torch .tensor (
10581084 [[1 , 1 , 0 , 0 , 1 , 0 , 1 , 0 ], [0 , 1 , 1 , 0 , 0 , 1 , 0 , 1 ]],
1059- device = device ,
1085+ device = torch_device ,
10601086 )
10611087
10621088 full_outputs = model (input_ids = input_ids )
@@ -1203,6 +1229,7 @@ def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
12031229]
12041230
12051231
1232+ @require_torch_accelerator
12061233class TestPatchChunkedLMHead :
12071234 B , S = 4 , 16 # batch size, sequence length (including prompt + completion)
12081235 H , V = 32 , 128
@@ -1285,15 +1312,14 @@ def test_dummy_model_chunked_forward_completion_mask_backward(self, temperature)
12851312 @pytest .mark .parametrize ("model_id" , _CHUNKED_LM_HEAD_MODEL_IDS )
12861313 @pytest .mark .parametrize ("temperature" , [1.0 , 0.7 ])
12871314 def test_forward (self , model_id , temperature ):
1288- device = torch .device ("cuda" )
1289- model = AutoModelForCausalLM .from_pretrained (model_id , torch_dtype = torch .bfloat16 ).to (device )
1315+ model = AutoModelForCausalLM .from_pretrained (model_id , torch_dtype = torch .bfloat16 ).to (torch_device )
12901316 if getattr (model .config , "final_logit_softcapping" , None ) is not None :
12911317 pytest .skip ("model uses final_logit_softcapping, not supported by chunked LM head" )
12921318 model .eval ()
12931319
12941320 B , S , chunk_size = 2 , 8 , 32
12951321 torch .manual_seed (42 )
1296- input_ids = torch .randint (0 , model .config .vocab_size , (B , S ), device = device )
1322+ input_ids = torch .randint (0 , model .config .vocab_size , (B , S ), device = torch_device )
12971323 labels = input_ids .clone ()
12981324
12991325 # Reference: standard forward → shifted logits → logprobs & entropy
@@ -1316,15 +1342,14 @@ def test_forward(self, model_id, temperature):
13161342 @pytest .mark .parametrize ("model_id" , _CHUNKED_LM_HEAD_MODEL_IDS )
13171343 @pytest .mark .parametrize ("temperature" , [1.0 , 0.7 ])
13181344 def test_backward (self , model_id , temperature ):
1319- device = torch .device ("cuda" )
1320- model_ref = AutoModelForCausalLM .from_pretrained (model_id , torch_dtype = torch .bfloat16 ).to (device )
1345+ model_ref = AutoModelForCausalLM .from_pretrained (model_id , torch_dtype = torch .bfloat16 ).to (torch_device )
13211346 if getattr (model_ref .config , "final_logit_softcapping" , None ) is not None :
13221347 pytest .skip ("model uses final_logit_softcapping, not supported by chunked LM head" )
13231348 model_chunked = copy .deepcopy (model_ref )
13241349
13251350 B , S , chunk_size = 2 , 8 , 32
13261351 torch .manual_seed (42 )
1327- input_ids = torch .randint (0 , model_ref .config .vocab_size , (B , S ), device = device )
1352+ input_ids = torch .randint (0 , model_ref .config .vocab_size , (B , S ), device = torch_device )
13281353 labels = input_ids .clone ()
13291354 shifted_labels = labels [:, 1 :]
13301355
0 commit comments