@@ -875,6 +875,12 @@ def __init__(
875875 # Liger fused GKD loss (JSD)
876876 self .use_liger_gkd_loss = False
877877 if args .use_liger_kernel :
878+ if self ._is_vlm :
879+ raise ValueError (
880+ "Liger fused GKD loss is not supported with VLMs. The fused kernel operates on base decoder "
881+ "hidden states, which is incompatible with VLM multimodal inputs (pixel_values, etc.). "
882+ "Please set `use_liger_kernel=False`."
883+ )
878884 self .liger_jsd_loss = LigerFusedLinearJSDLoss (
879885 beta = args .beta ,
880886 ignore_index = - 100 ,
@@ -1248,6 +1254,15 @@ def _fill_buffer(self, generation_batch: dict[str, torch.Tensor | Any] | list[di
12481254 for i , flag in enumerate (on_policy_flags ):
12491255 if not flag :
12501256 if self ._vlm_collator is not None :
1257+ # Extract raw images and prompts BEFORE collation, since the collator
1258+ # mutates examples in place (pops "image", overwrites "prompt").
1259+ raw_images = None
1260+ raw_prompts = None
1261+ if self ._teacher_processor is not None :
1262+ raw_images = [
1263+ ex .get ("images" ) or ([ex ["image" ]] if "image" in ex else None ) for ex in raw_slices [i ]
1264+ ]
1265+ raw_prompts = [ex .get ("prompt" ) for ex in raw_slices [i ]]
12511266 # Collate raw examples on-the-fly for off-policy slices
12521267 slice_inputs = self ._vlm_collator (raw_slices [i ])
12531268 slice_inputs = {
@@ -1256,10 +1271,8 @@ def _fill_buffer(self, generation_batch: dict[str, torch.Tensor | Any] | list[di
12561271 }
12571272 # Preserve raw PIL images and prompts for cross-architecture teacher processing
12581273 if self ._teacher_processor is not None :
1259- slice_inputs ["_raw_images" ] = [
1260- ex .get ("images" ) or ([ex ["image" ]] if "image" in ex else None ) for ex in raw_slices [i ]
1261- ]
1262- slice_inputs ["_raw_prompts" ] = [ex .get ("prompt" ) for ex in raw_slices [i ]]
1274+ slice_inputs ["_raw_images" ] = raw_images
1275+ slice_inputs ["_raw_prompts" ] = raw_prompts
12631276 else :
12641277 slice_inputs = slices [i ]
12651278
0 commit comments