Skip to content

Commit 9a1f345

Browse files
committed
fix collator mutation bug and reject Liger kernel for VLMs
1 parent 23784b0 commit 9a1f345

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

trl/experimental/gold/gold_trainer.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -875,6 +875,12 @@ def __init__(
875875
# Liger fused GKD loss (JSD)
876876
self.use_liger_gkd_loss = False
877877
if args.use_liger_kernel:
878+
if self._is_vlm:
879+
raise ValueError(
880+
"Liger fused GKD loss is not supported with VLMs. The fused kernel operates on base decoder "
881+
"hidden states, which is incompatible with VLM multimodal inputs (pixel_values, etc.). "
882+
"Please set `use_liger_kernel=False`."
883+
)
878884
self.liger_jsd_loss = LigerFusedLinearJSDLoss(
879885
beta=args.beta,
880886
ignore_index=-100,
@@ -1248,6 +1254,15 @@ def _fill_buffer(self, generation_batch: dict[str, torch.Tensor | Any] | list[di
12481254
for i, flag in enumerate(on_policy_flags):
12491255
if not flag:
12501256
if self._vlm_collator is not None:
1257+
# Extract raw images and prompts BEFORE collation, since the collator
1258+
# mutates examples in place (pops "image", overwrites "prompt").
1259+
raw_images = None
1260+
raw_prompts = None
1261+
if self._teacher_processor is not None:
1262+
raw_images = [
1263+
ex.get("images") or ([ex["image"]] if "image" in ex else None) for ex in raw_slices[i]
1264+
]
1265+
raw_prompts = [ex.get("prompt") for ex in raw_slices[i]]
12511266
# Collate raw examples on-the-fly for off-policy slices
12521267
slice_inputs = self._vlm_collator(raw_slices[i])
12531268
slice_inputs = {
@@ -1256,10 +1271,8 @@ def _fill_buffer(self, generation_batch: dict[str, torch.Tensor | Any] | list[di
12561271
}
12571272
# Preserve raw PIL images and prompts for cross-architecture teacher processing
12581273
if self._teacher_processor is not None:
1259-
slice_inputs["_raw_images"] = [
1260-
ex.get("images") or ([ex["image"]] if "image" in ex else None) for ex in raw_slices[i]
1261-
]
1262-
slice_inputs["_raw_prompts"] = [ex.get("prompt") for ex in raw_slices[i]]
1274+
slice_inputs["_raw_images"] = raw_images
1275+
slice_inputs["_raw_prompts"] = raw_prompts
12631276
else:
12641277
slice_inputs = slices[i]
12651278

0 commit comments

Comments
 (0)