Conversation
Force-pushed from 537cc13 to 9a1f345
```python
attention_mask=inputs.get("prompt_attention_mask", None),
generation_config=generation_config,
return_dict_in_generate=True,
**generate_kwargs,
```
Sequence-length multimodal keys mismatch during VLM generation
Medium Severity
In generate_on_policy_outputs, all _MULTIMODAL_KEYS are extracted from the collated batch and passed to model.generate() with prompt-only input_ids. The collator concatenates token_type_ids and mm_token_type_ids to full prompt+completion length, but model.generate() receives only inputs["prompts"] (prompt-only). This dimension mismatch causes errors for VLM architectures that produce these keys (e.g., ERNIE-VL).
Additional Locations (1)
Reviewed by Cursor Bugbot for commit 9a1f345. Configure here.
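One way to address the mismatch Bugbot describes could look like the sketch below: slice any full-length (prompt + completion) per-token tensors down to the prompt before calling `model.generate()`. The key names are taken from the report; the helper name and exact integration point are assumptions, not the trainer's actual code.

```python
def trim_multimodal_keys(generate_kwargs, prompt_len,
                         keys=("token_type_ids", "mm_token_type_ids")):
    """Return a copy of generate_kwargs with sequence-length-dependent
    multimodal tensors cut to prompt_len along the token dimension.

    Tensors are sliced with value[:, :prompt_len]; a nested-list fallback
    keeps this sketch runnable without torch.
    """
    trimmed = dict(generate_kwargs)
    for key in keys:
        value = trimmed.get(key)
        if value is None:
            continue
        if hasattr(value, "shape"):          # torch.Tensor path
            if value.shape[1] > prompt_len:
                trimmed[key] = value[:, :prompt_len]
        elif len(value[0]) > prompt_len:     # plain nested-list fallback
            trimmed[key] = [row[:prompt_len] for row in value]
    return trimmed
```

The original batch is left untouched, so the full-length tensors remain available for the later loss computation.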
Will take a closer look at it a bit later; this part I've copied from SFTTrainer.
```python
# Models
# ──────────────────────────────────────────────
student_model = AutoModelForImageTextToText.from_pretrained(cli_args.student_model_name, dtype=torch.bfloat16)
teacher_model = AutoModelForImageTextToText.from_pretrained(cli_args.teacher_model_name, dtype=torch.bfloat16)
```
Example script uses wrong dtype parameter name
Low Severity
AutoModelForImageTextToText.from_pretrained is called with dtype=torch.bfloat16 instead of the correct torch_dtype=torch.bfloat16. The dtype kwarg is not a recognized parameter for from_pretrained, so the models will silently load in their default precision (float32) instead of bfloat16, increasing memory usage and potentially causing dtype mismatches during training.
Reviewed by Cursor Bugbot for commit 9a1f345. Configure here.
Not a bug: `torch_dtype` is deprecated (everybody knows this warning).
Maybe I should add version checking, like here.
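The version check mentioned above could be a small helper that picks the right keyword for the installed transformers, assuming the `torch_dtype` → `dtype` rename landed in v4.56 (worth double-checking against the release notes). The helper name and signature are illustrative only.

```python
def dtype_kwarg(dtype, transformers_version):
    """Return the from_pretrained dtype keyword appropriate for the given
    transformers version string (assumes `dtype` replaced `torch_dtype`
    in v4.56)."""
    major, minor = (int(p) for p in transformers_version.split(".")[:2])
    if (major, minor) >= (4, 56):
        return {"dtype": dtype}
    return {"torch_dtype": dtype}
```

It would be used as `from_pretrained(name, **dtype_kwarg(torch.bfloat16, transformers.__version__))`, keeping the example script warning-free on both old and new transformers.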
Cursor Bugbot has reviewed your changes and found 1 potential issue.
There are 3 total unresolved issues (including 2 from previous reviews).
Reviewed by Cursor Bugbot for commit 7c96055. Configure here.
```python
updated_slice["_raw_prompts"] = prompts

self._buffered_inputs[slice_idx] = updated_slice
self._buffered_text_logs[slice_idx] = (prompt_texts, completion_texts)
```
Stale sequence-length-dependent tensors after VLM on-policy generation
Medium Severity
In the non-vLLM VLM on-policy path, the original collated output from _vlm_collator is shallow-copied into updated_slice, but only input_ids, attention_mask, and labels are replaced with the generated sequences. Sequence-length-dependent multimodal tensors like token_type_ids and mm_token_type_ids retain their original shape from the collation (prompt + original completion length). When compute_loss later extracts these via _MULTIMODAL_KEYS and passes them to the model, they mismatch the new input_ids shape. The same mismatch also occurs earlier in generate_on_policy_outputs, where generate_kwargs includes full-sequence token_type_ids but model.generate receives prompt-only input_ids. This causes a runtime shape error for VLM architectures whose processor produces token_type_ids or mm_token_type_ids (e.g., ERNIE-VL).
Additional Locations (1)
Reviewed by Cursor Bugbot for commit 7c96055. Configure here.
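A minimal sketch of one possible fix for the stale-tensor issue: when installing the regenerated sequences into the buffered slice, drop the sequence-length-dependent multimodal tensors instead of carrying them over at their old length. Key names come from the report; whether dropping (versus re-collating) them is acceptable depends on the architecture, so this is illustrative only.

```python
SEQ_LEN_DEPENDENT_KEYS = {"token_type_ids", "mm_token_type_ids"}  # assumed names

def refresh_slice(collated, new_input_ids, new_attention_mask, new_labels):
    """Copy a collated batch, discarding per-token tensors whose length no
    longer matches the freshly generated sequences, then install the new
    input_ids / attention_mask / labels."""
    updated = {k: v for k, v in collated.items()
               if k not in SEQ_LEN_DEPENDENT_KEYS}
    updated["input_ids"] = new_input_ids
    updated["attention_mask"] = new_attention_mask
    updated["labels"] = new_labels
    return updated
```

Length-independent tensors such as `pixel_values` are preserved unchanged, since only the token dimension grew during generation.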


What does this PR do?
Adds VLM support to GOLDTrainer:
Motivation
The GOLD algorithm has no theoretical constraints against VLM-to-VLM distillation -- the barriers were purely engineering (incompatible image token formats, different tokenizers, raw image handling through the dataloader).
Key changes
- `_teacher_processor` is stored and used in `compute_loss` to build teacher-compatible vision tensors from raw images
- `teacher_tokenizer_name_or_path`
- `examples/scripts/gold_vlm.py` with two documented usage examples (same-family JSD + vLLM, cross-family ULD)

Note
Looking for feedback:
- `_fill_buffer`), and the overall design choice with two different collators, as well as two separate generation flows (`_generate_on_policy_vlm_raw` vs `_generate_on_policy_for_slices`). Would appreciate feedback from anyone with more experience in this area.
- `docs/source/gold_trainer.md` -- will add if that's desirable, just let me know.

Before submitting
AI writing disclosure
We welcome the use of AI tools to help with contributions. For transparency and to help us improve our review process, please indicate the level of AI involvement in this PR.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
Note
High Risk
Substantial changes to `GOLDTrainer` batching/generation and `compute_loss` to support multimodal inputs (raw images, processor-based tokenization, cross-architecture teacher processing) introduce risk of subtle shape/masking and memory/perf regressions during training, especially with vLLM and ULD paths.

Overview
Enables VLM-to-VLM distillation in `GOLDTrainer` by detecting vision datasets, enforcing VLM student/teacher compatibility, and introducing a VLM data path that preserves raw images through the dataloader via an `identity` collator plus an internal `DataCollatorForVisionLanguageChatML`.

Adds cross-architecture VLM handling: when student/teacher `model_type` differs, the trainer loads/stores a separate `_teacher_processor`, requires `use_uld_loss=True`, and builds teacher inputs (including vision tensors) from raw prompts/images during `compute_loss`.

Extends on-policy generation to support multimodal vLLM/non-vLLM flows (`_generate_on_policy_vlm_raw`), updates prompt-length/label masking to work with flushed-left VLM batches, defaults `GOLDConfig.remove_unused_columns` to `False`, and adds a runnable `examples/scripts/gold_vlm.py` plus extensive new VLM-focused tests (collator correctness, init validation, cross-arch behavior, and vLLM init).

Reviewed by Cursor Bugbot for commit 7c96055. Bugbot is set up for automated code reviews on this repo. Configure here.
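Two of the mechanisms named in the overview can be sketched as follows. The identity collator simply defers collation so raw images survive the DataLoader; the second helper illustrates the cross-architecture path of re-tokenizing raw prompts/images with the teacher's own processor. Everything except the idea of an identity collator (function names, signature, processor kwargs) is an assumption for illustration, not the trainer's actual code.

```python
def identity_collator(features):
    """Pass raw examples (including PIL images) through the DataLoader
    unchanged; the trainer collates later with the appropriate processor."""
    return features

def build_teacher_inputs(teacher_processor, raw_prompts, raw_images):
    """Cross-architecture sketch: re-tokenize raw text/images with the
    teacher's own processor, since the student's token IDs and image-token
    format do not transfer between model families."""
    return teacher_processor(
        text=raw_prompts,
        images=raw_images,
        padding=True,
        return_tensors="pt",
    )
```

With this split, the student-side `DataCollatorForVisionLanguageChatML` and the teacher-side re-processing can each apply their own chat template and image handling to the same raw batch.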