Skip to content

GOLDTrainer VLM support#5461

Open
Strongich wants to merge 10 commits intohuggingface:mainfrom
Strongich:gold_vlm_support
Open

GOLDTrainer VLM support#5461
Strongich wants to merge 10 commits intohuggingface:mainfrom
Strongich:gold_vlm_support

Conversation

@Strongich
Copy link
Copy Markdown

@Strongich Strongich commented Apr 6, 2026

What does this PR do?

Adds VLM support to GOLDTrainer:

  • Same-family VLM distillation: same <image_pad> tokens, same vision encoder's family -> JSD loss
  • Cross-architecture VLM distillation: images are processed separately through each model's processor to handle different image token formats
  • vLLM support for both

Motivation

The GOLD algorithm has no theoretical constraints against VLM-to-VLM distillation -- the barriers were purely engineering (incompatible image token formats, different tokenizers, raw image handling through the dataloader).

Key changes

  • GOLDTrainer detects VLM datasets and uses an identity collator to preserve raw PIL images through the dataloader
  • For cross-architecture pairs, a _teacher_processor is stored and used in compute_loss to build teacher-compatible vision tensors from raw images
  • Auto-resolves teacher_tokenizer_name_or_path
  • Added examples/scripts/gold_vlm.py with two documented usage examples (same-family JSD + vLLM, cross-family ULD)
  • Added tests for VLM collator (label masking, completion preservation), cross-architecture detection (rejects JSD, stores teacher processor for different archs, skips it for same arch), VLM + vLLM init (copied from the LLM example), rejects LLM teacher with vision dataset
  • VLM handling (identity collator, raw image storage, vLLM multimodal path) is borrowed (where it was possible) from SFTTrainer and GRPOTrainer

Note

Looking for feedback:

  • I'm not fully confident the current approach for storing and passing raw images through the pipeline is optimal (especially in _fill_buffer), and the overall design choice with two different collators, as well as two separate generation flows (_generate_on_policy_vlm_raw vs _generate_on_policy_for_slices). Would appreciate feedback from anyone with more experience in this area.
  • I didn't add VLM usage examples to docs/source/gold_trainer.md -- will add if that's desirable, just let me know.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

AI writing disclosure

We welcome the use of AI tools to help with contributions. For transparency and to help us improve our review process, please indicate the level of AI involvement in this PR.

  • No AI usage: the PR was written entirely by a human.
  • AI-assisted: some parts were suggested or improved by AI, but the PR was written and reviewed by a human.
  • AI-generated: the PR was mostly or fully generated by an AI tool.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.


Note

High Risk
Substantial changes to GOLDTrainer batching/generation and compute_loss to support multimodal inputs (raw images, processor-based tokenization, cross-architecture teacher processing) introduce risk of subtle shape/masking and memory/perf regressions during training, especially with vLLM and ULD paths.

Overview
Enables VLM-to-VLM distillation in GOLDTrainer by detecting vision datasets, enforcing VLM student/teacher compatibility, and introducing a VLM data path that preserves raw images through the dataloader via an identity collator plus an internal DataCollatorForVisionLanguageChatML.

Adds cross-architecture VLM handling: when student/teacher model_type differs, the trainer loads/stores a separate _teacher_processor, requires use_uld_loss=True, and builds teacher inputs (including vision tensors) from raw prompts/images during compute_loss.

Extends on-policy generation to support multimodal vLLM/non-vLLM flows (_generate_on_policy_vlm_raw), updates prompt-length/label masking to work with flushed-left VLM batches, defaults GOLDConfig.remove_unused_columns to False, and adds a runnable examples/scripts/gold_vlm.py plus extensive new VLM-focused tests (collator correctness, init validation, cross-arch behavior, and vLLM init).

Reviewed by Cursor Bugbot for commit 7c96055. Bugbot is set up for automated code reviews on this repo. Configure here.

attention_mask=inputs.get("prompt_attention_mask", None),
generation_config=generation_config,
return_dict_in_generate=True,
**generate_kwargs,
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sequence-length multimodal keys mismatch during VLM generation

Medium Severity

In generate_on_policy_outputs, all _MULTIMODAL_KEYS are extracted from the collated batch and passed to model.generate() with prompt-only input_ids. The collator concatenates token_type_ids and mm_token_type_ids to full prompt+completion length, but model.generate() receives only inputs["prompts"] (prompt-only). This dimension mismatch causes errors for VLM architectures that produce these keys (e.g., ERNIE-VL).

Additional Locations (1)
Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 9a1f345. Configure here.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will take a closer look at it a bit later, this part I've copied from SFTTrainer

# Models
# ──────────────────────────────────────────────
student_model = AutoModelForImageTextToText.from_pretrained(cli_args.student_model_name, dtype=torch.bfloat16)
teacher_model = AutoModelForImageTextToText.from_pretrained(cli_args.teacher_model_name, dtype=torch.bfloat16)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Example script uses wrong dtype parameter name

Low Severity

AutoModelForImageTextToText.from_pretrained is called with dtype=torch.bfloat16 instead of the correct torch_dtype=torch.bfloat16. The dtype kwarg is not a recognized parameter for from_pretrained, so the models will silently load in their default precision (float32) instead of bfloat16, increasing memory usage and potentially causing dtype mismatches during training.

Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 9a1f345. Configure here.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a bug, torch_dtype is deprecated (everybody knows this warning)
Maybe I should add version checking, like here

Copy link
Copy Markdown

@cursor cursor bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cursor Bugbot has reviewed your changes and found 1 potential issue.

There are 3 total unresolved issues (including 2 from previous reviews).

Fix All in Cursor

❌ Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.

Reviewed by Cursor Bugbot for commit 7c96055. Configure here.

updated_slice["_raw_prompts"] = prompts

self._buffered_inputs[slice_idx] = updated_slice
self._buffered_text_logs[slice_idx] = (prompt_texts, completion_texts)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stale sequence-length-dependent tensors after VLM on-policy generation

Medium Severity

In the non-vLLM VLM on-policy path, the original collated output from _vlm_collator is shallow-copied into updated_slice, but only input_ids, attention_mask, and labels are replaced with the generated sequences. Sequence-length-dependent multimodal tensors like token_type_ids and mm_token_type_ids retain their original shape from the collation (prompt + original completion length). When compute_loss later extracts these via _MULTIMODAL_KEYS and passes them to the model, they mismatch the new input_ids shape. The same mismatch also occurs earlier in generate_on_policy_outputs, where generate_kwargs includes full-sequence token_type_ids but model.generate receives prompt-only input_ids. This causes a runtime shape error for VLM architectures whose processor produces token_type_ids or mm_token_type_ids (e.g., ERNIE-VL).

Additional Locations (1)
Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 7c96055. Configure here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant