diff --git a/examples/scripts/gold_vlm.py b/examples/scripts/gold_vlm.py
new file mode 100644
index 00000000000..f50ad67a691
--- /dev/null
+++ b/examples/scripts/gold_vlm.py
@@ -0,0 +1,181 @@
+# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+GOLD VLM distillation on MMK12.
+
+# Example 1 — Same-family distillation (SmolVLM-500M → SmolVLM-256M)
+# Uses JSD loss. Same architecture and tokenizer, so standard distillation works directly.
+# vLLM enabled for faster on-policy generation.
+accelerate launch examples/scripts/gold_vlm.py \
+ --student_model_name HuggingFaceTB/SmolVLM-256M-Instruct \
+ --teacher_model_name HuggingFaceTB/SmolVLM-500M-Instruct \
+ --lmbda 0.5 \
+ --use_vllm \
+ --vllm_mode colocate
+
+# Example 2 — Cross-family distillation (Qwen2.5-VL-3B → SmolVLM-256M)
+# Different architectures have incompatible tokenizers and image token formats,
+# so ULD (Universal Logit Distillation) loss is required to align logits across vocabularies.
+accelerate launch examples/scripts/gold_vlm.py \
+ --student_model_name HuggingFaceTB/SmolVLM-256M-Instruct \
+ --teacher_model_name Qwen/Qwen2.5-VL-3B-Instruct \
+ --use_uld_loss \
+ --lmbda 0.0
+"""
+
+import argparse
+
+import torch
+from datasets import load_dataset
+from peft import LoraConfig
+from transformers import AutoModelForImageTextToText, AutoProcessor
+
+from trl.experimental.gold import GOLDConfig, GOLDTrainer
+
+
+SYSTEM_PROMPT = (
+ "You are a helpful AI Assistant that provides well-reasoned and detailed responses. "
+ "You first think about the reasoning process as an internal monologue and then provide the user with the answer. "
+ "Respond in the following format: \n...\n\n\n...\n"
+)
+
+
+def make_conversation(example):
+ """Convert MMK12 row into the chat format expected by TRL VLM trainers."""
+ return {
+ "prompt": [
+ {
+ "role": "system",
+ "content": [{"type": "text", "text": SYSTEM_PROMPT}],
+ },
+ {
+ "role": "user",
+ "content": [
+ {"type": "image"},
+ {"type": "text", "text": example["question"]},
+ ],
+ },
+ ],
+ "completion": [
+ {
+ "role": "assistant",
+ "content": [{"type": "text", "text": str(example["answer"])}],
+ },
+ ],
+ "image": example["image"],
+ }
+
+
+def filter_big_images(example):
+ image = example["image"]
+ return image.size[0] < 512 and image.size[1] < 512
+
+
+def convert_to_rgb(example):
+ image = example["image"]
+ if image.mode != "RGB":
+ image = image.convert("RGB")
+ example["image"] = image
+ return example
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--student_model_name", type=str, default="HuggingFaceTB/SmolVLM-256M-Instruct")
+ parser.add_argument("--teacher_model_name", type=str, default="HuggingFaceTB/SmolVLM-500M-Instruct")
+ parser.add_argument("--use_uld_loss", action="store_true")
+ parser.add_argument("--lmbda", type=float, default=0.5)
+ parser.add_argument("--use_vllm", action="store_true")
+ parser.add_argument("--vllm_mode", type=str, default="colocate")
+ cli_args = parser.parse_args()
+
+ # ──────────────────────────────────────────────
+ # Models
+ # ──────────────────────────────────────────────
+ student_model = AutoModelForImageTextToText.from_pretrained(cli_args.student_model_name, dtype=torch.bfloat16)
+ teacher_model = AutoModelForImageTextToText.from_pretrained(cli_args.teacher_model_name, dtype=torch.bfloat16)
+
+ # Freeze everything except the language model head
+ for name, param in student_model.named_parameters():
+ if "language_model" not in name:
+ param.requires_grad = False
+
+ processor = AutoProcessor.from_pretrained(cli_args.student_model_name, padding_side="left")
+
+ # toy example to fit small GPUs
+ peft_config = LoraConfig(
+ r=4,
+ lora_alpha=8,
+ lora_dropout=0.05,
+ target_modules=["q_proj"],
+ )
+
+ # ──────────────────────────────────────────────
+ # Dataset
+ # ──────────────────────────────────────────────
+ dataset = load_dataset("FanqingM/MMK12", split="train[:5%]")
+ dataset = dataset.filter(filter_big_images)
+ dataset = dataset.map(convert_to_rgb)
+ dataset = dataset.map(make_conversation)
+
+ # ──────────────────────────────────────────────
+ # Training config
+ # ──────────────────────────────────────────────
+ args = GOLDConfig(
+ output_dir="gold-vlm-distillation",
+ # GOLD-specific
+ lmbda=cli_args.lmbda,
+ beta=0.5,
+ temperature=0.9,
+ max_completion_length=256,
+ teacher_model_name_or_path=cli_args.teacher_model_name,
+ num_generations=1,
+ use_uld_loss=cli_args.use_uld_loss,
+ # vLLM
+ use_vllm=cli_args.use_vllm,
+ vllm_mode=cli_args.vllm_mode,
+ vllm_gpu_memory_utilization=0.5,
+ vllm_max_model_length=8192,
+ # VLM image tokens expand during processing, so the default max_length (1024) is often too small.
+        # Otherwise truncation would cause shifted_student_logits to become an empty tensor.
+ max_length=2048,
+ # Training schedule
+ per_device_train_batch_size=2,
+ gradient_accumulation_steps=4,
+ max_steps=100,
+ learning_rate=2e-5,
+ warmup_steps=10,
+ # Precision
+ bf16=True,
+ # Logging
+ logging_steps=1,
+ log_completions=True,
+ report_to="none",
+ )
+
+ # ──────────────────────────────────────────────
+ # Trainer
+ # ──────────────────────────────────────────────
+ trainer = GOLDTrainer(
+ model=student_model,
+ teacher_model=teacher_model,
+ args=args,
+ train_dataset=dataset,
+ processing_class=processor,
+ peft_config=peft_config,
+ )
+
+ trainer.train()
+ trainer.save_model(args.output_dir)
diff --git a/tests/experimental/test_gold_trainer.py b/tests/experimental/test_gold_trainer.py
index d7e32056323..76df35cc03c 100644
--- a/tests/experimental/test_gold_trainer.py
+++ b/tests/experimental/test_gold_trainer.py
@@ -16,12 +16,13 @@
import pytest
import torch
-from datasets import load_dataset
-from transformers import AutoTokenizer
+from datasets import Dataset, load_dataset
+from transformers import AutoProcessor, AutoTokenizer
from trl.experimental.gold import gold_trainer as gold_trainer_module
from trl.experimental.gold.gold_trainer import GOLDTrainer, ULDLoss, build_teacher_inputs_from_texts
-from trl.experimental.utils import DataCollatorForChatML
+from trl.experimental.utils import DataCollatorForChatML, DataCollatorForVisionLanguageChatML
+from trl.trainer.utils import identity
@pytest.fixture(scope="module")
@@ -271,6 +272,35 @@ def smollm_tokenizer():
return tokenizer
+@pytest.fixture(scope="session")
+def smolvlm_processor():
+ processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct")
+ if processor.tokenizer.pad_token is None:
+ processor.tokenizer.pad_token = processor.tokenizer.eos_token
+ return processor
+
+
+@pytest.fixture(scope="session")
+def qwen3_vl_processor():
+ processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-2B-Instruct")
+ if processor.tokenizer.pad_token is None:
+ processor.tokenizer.pad_token = processor.tokenizer.eos_token
+ return processor
+
+
+@pytest.fixture(scope="module")
+def vlm_examples():
+ try:
+ dataset = load_dataset(
+ "trl-internal-testing/zen-image",
+ "conversational_prompt_completion",
+ split="train[:3]",
+ )
+ except Exception as exc: # pragma: no cover - network/environment dependent
+ pytest.skip(f"zen-image dataset unavailable: {exc}")
+ return [dict(row) for row in dataset]
+
+
def encode_prompt_completion(tokenizer, prompt, completion):
prompt_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
completion_ids = tokenizer(completion, add_special_tokens=False)["input_ids"]
@@ -302,6 +332,7 @@ def batch_decode(self, sequences, skip_special_tokens=False, clean_up_tokenizati
trainer = GOLDTrainer.__new__(GOLDTrainer)
trainer.accelerator = SimpleNamespace(device=torch.device("cpu"))
trainer.processing_class = RecordingTokenizer()
+ trainer.pad_token_id = RecordingTokenizer.pad_token_id # __new__ bypasses __init__, set manually
trainer.args = SimpleNamespace(max_length=None)
trainer._buffered_inputs = [None]
trainer._buffered_text_logs = [None]
@@ -449,6 +480,7 @@ def batch_decode(self, sequences, skip_special_tokens=False, clean_up_tokenizati
trainer = GOLDTrainer.__new__(GOLDTrainer)
trainer.accelerator = SimpleNamespace(device=torch.device("cpu"), is_main_process=True)
trainer.processing_class = RecordingTokenizer()
+ trainer.pad_token_id = RecordingTokenizer.pad_token_id # __new__ bypasses __init__, set manually
trainer.args = SimpleNamespace(max_length=None, report_to=[])
trainer.use_vllm = True
trainer.vllm_generation = RecordingVLLMGeneration()
@@ -498,6 +530,9 @@ def resize_token_embeddings(self, vocab_size):
self.resized_to = vocab_size
class DummyProcessingClass:
+ # GOLDTrainer.__init__ extracts tokenizer pad token (like GRPOTrainer),
+ # so the dummy must provide both pad_token and pad_token_id.
+ pad_token = ""
pad_token_id = 0
def fake_sft_init(
@@ -575,12 +610,17 @@ def __init__(self, **kwargs):
vllm_sync_frequency=1,
)
+ # A minimal dataset is required because GOLDTrainer inspects the first example at init
+    # to detect whether the dataset contains images for VLM, so passing None doesn't work
+ dummy_dataset = Dataset.from_dict({"messages": [["dummy"]]})
+
teacher_model = DummyTeacherModel()
GOLDTrainer(
model=DummyStudentModel(),
teacher_model=teacher_model,
args=args,
data_collator=object(),
+ train_dataset=dummy_dataset,
processing_class=DummyProcessingClass(),
)
@@ -942,3 +982,764 @@ def test_uldloss_hybrid_config_beta_zero(llama_tokenizer, qwen_tokenizer):
expected = config.uld_hybrid_unmatched_weight * loss_fn.last_unmatched_loss
torch.testing.assert_close(loss, expected, atol=1e-6, rtol=1e-5)
+
+
+# ──────────────────────────────────────────────────────────────────────────────
+# VLM tests
+# ──────────────────────────────────────────────────────────────────────────────
+
+
+def test_vlm_alignment_groups_cover_all_tokens_smolvlm_qwen3vl(smolvlm_processor, qwen3_vl_processor, vlm_examples):
+ student_tokenizer = smolvlm_processor.tokenizer
+ teacher_tokenizer = qwen3_vl_processor.tokenizer
+
+ collator = DataCollatorForVisionLanguageChatML(processor=smolvlm_processor, max_length=2048)
+ batch = collator(vlm_examples)
+
+ config = build_config()
+ loss = ULDLoss(config, student_tokenizer=student_tokenizer, teacher_tokenizer=teacher_tokenizer)
+
+ teacher_input_ids, teacher_labels, _ = _teacher_inputs_from_collator(student_tokenizer, teacher_tokenizer, batch)
+
+ _assert_alignment_covers_completion(loss, batch, teacher_input_ids, teacher_labels)
+
+
+def test_gold_trainer_init_rejects_llm_with_vision_dataset(monkeypatch):
+ """GOLDTrainer should raise ValueError when a text-only model receives a vision dataset."""
+
+ class DummyStudentModel:
+ def __init__(self):
+ self.config = SimpleNamespace(_name_or_path="student", vocab_size=17)
+ self.generation_config = SimpleNamespace(eos_token_id=2)
+ self.name_or_path = "student"
+
+ class DummyTeacherModel:
+ def __init__(self):
+ self.resized_to = None
+
+ def resize_token_embeddings(self, vocab_size):
+ self.resized_to = vocab_size
+
+ def fake_sft_init(
+ self,
+ model,
+ args=None,
+ data_collator=None,
+ train_dataset=None,
+ eval_dataset=None,
+ processing_class=None,
+ compute_metrics=None,
+ callbacks=None,
+ optimizers=None,
+ preprocess_logits_for_metrics=None,
+ peft_config=None,
+ ):
+ del data_collator, train_dataset, eval_dataset, compute_metrics, callbacks, optimizers
+ del preprocess_logits_for_metrics, peft_config
+ self.model = model
+ self.args = args
+ self.processing_class = processing_class
+ self.accelerator = SimpleNamespace(
+ device=torch.device("cpu"),
+ num_processes=1,
+ prepare_model=lambda module, evaluation_mode=True: module,
+ )
+ self.is_deepspeed_enabled = False
+ self.is_fsdp_enabled = False
+
+ monkeypatch.setattr(gold_trainer_module.SFTTrainer, "__init__", fake_sft_init)
+
+ tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM3-3B")
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+
+ # Dataset with an "image" key triggers vision detection
+ vision_dataset = Dataset.from_dict({"messages": [["dummy"]], "image": ["fake_image"]})
+
+ args = SimpleNamespace(
+ model_init_kwargs=None,
+ max_length=128,
+ use_liger_kernel=False,
+ teacher_model_init_kwargs=None,
+ use_uld_loss=False,
+ teacher_tokenizer_name_or_path=None,
+ teacher_model_revision=None,
+ disable_dropout=False,
+ lmbda=1.0,
+ beta=0.5,
+ temperature=1.0,
+ top_p=1.0,
+ seq_kd=False,
+ num_generations=1,
+ use_transformers_paged=False,
+ max_completion_length=16,
+ top_k=0,
+ log_completions=False,
+ log_completions_steps=100,
+ wandb_log_unique_prompts=True,
+ num_completions_to_print=None,
+ per_device_train_batch_size=1,
+ gradient_accumulation_steps=1,
+ use_vllm=False,
+ )
+
+ with pytest.raises(ValueError, match="vision-related"):
+ GOLDTrainer(
+ model=DummyStudentModel(),
+ teacher_model=DummyTeacherModel(),
+ args=args,
+ train_dataset=vision_dataset,
+ processing_class=tokenizer,
+ )
+
+
+def _get_assistant_texts(examples):
+ """Extract assistant text content from examples, handling both plain string and multimodal format."""
+ texts = []
+ for example in examples:
+ content = example["completion"][-1]["content"]
+ if isinstance(content, list):
+ texts.append("".join(part["text"] for part in content if "text" in part))
+ else:
+ texts.append(content)
+ return texts
+
+
+def test_vlm_chatml_collator_preserves_completion_smolvlm(smolvlm_processor, qwen3_vl_processor, vlm_examples):
+ # 2048 to not truncate the completion tokens
+ collator = DataCollatorForVisionLanguageChatML(processor=smolvlm_processor, max_length=2048)
+ batch = collator(vlm_examples)
+
+ # Verify basic batch structure
+ assert "input_ids" in batch
+ assert "labels" in batch
+ assert "prompts" in batch
+ assert "prompt_attention_mask" in batch
+ assert "pixel_values" in batch
+ assert "original_prompt_text" in batch
+ assert "original_completion_text" in batch
+
+ # Verify completions are preserved in decoded output
+ assistant_texts = _get_assistant_texts(vlm_examples)
+ decoded_batch = smolvlm_processor.tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=False)
+ for decoded, assistant in zip(decoded_batch, assistant_texts, strict=True):
+ assert assistant in decoded
+
+ # Verify ULD cross-tokenizer distillation with teacher inputs
+ student_tokenizer = smolvlm_processor.tokenizer
+ teacher_tokenizer = qwen3_vl_processor.tokenizer
+
+ teacher_input_ids, teacher_labels, completion_texts = _teacher_inputs_from_collator(
+ student_tokenizer, teacher_tokenizer, batch
+ )
+ for completion, assistant in zip(completion_texts, assistant_texts, strict=True):
+ assert assistant.strip() in completion
+ assert completion.strip()
+
+ config = build_config(
+ uld_use_hybrid_loss=True,
+ uld_hybrid_matched_weight=0.6,
+ uld_hybrid_unmatched_weight=0.4,
+ )
+ loss_fn = ULDLoss(config, student_tokenizer=student_tokenizer, teacher_tokenizer=teacher_tokenizer)
+
+ _assert_alignment_covers_completion(loss_fn, batch, teacher_input_ids, teacher_labels)
+
+ torch.manual_seed(42)
+ student_vocab = len(student_tokenizer)
+ teacher_vocab = len(teacher_tokenizer)
+ batch_size, seq_len = batch["input_ids"].shape
+ student_logits = torch.randn(batch_size, seq_len, student_vocab)
+ teacher_logits = torch.randn(batch_size, teacher_input_ids.shape[1], teacher_vocab)
+
+ loss = loss_fn(
+ student_logits=student_logits,
+ teacher_logits=teacher_logits,
+ student_labels=batch["labels"],
+ teacher_labels=teacher_labels,
+ student_input_ids=batch["input_ids"],
+ teacher_input_ids=teacher_input_ids,
+ )
+
+ assert torch.isfinite(loss)
+
+
+@pytest.mark.slow
+def test_vlm_chatml_collator_preserves_completion_qwen3vl(smolvlm_processor, qwen3_vl_processor, vlm_examples):
+ collator = DataCollatorForVisionLanguageChatML(processor=qwen3_vl_processor, max_length=2048)
+ batch = collator(vlm_examples)
+
+ # Verify basic batch structure
+ assert "input_ids" in batch
+ assert "labels" in batch
+ assert "prompts" in batch
+ assert "pixel_values" in batch
+
+ # Verify completions are preserved in decoded output
+ assistant_texts = _get_assistant_texts(vlm_examples)
+ decoded_batch = qwen3_vl_processor.tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=False)
+ for decoded, assistant in zip(decoded_batch, assistant_texts, strict=True):
+ assert assistant in decoded
+
+ # Verify ULD cross-tokenizer distillation with teacher inputs
+ student_tokenizer = qwen3_vl_processor.tokenizer
+ teacher_tokenizer = smolvlm_processor.tokenizer
+
+ teacher_input_ids, teacher_labels, completion_texts = _teacher_inputs_from_collator(
+ student_tokenizer, teacher_tokenizer, batch
+ )
+ for completion, assistant in zip(completion_texts, assistant_texts, strict=True):
+ assert assistant.strip() in completion
+ assert completion.strip()
+
+ config = build_config(
+ uld_use_hybrid_loss=True,
+ uld_hybrid_matched_weight=0.6,
+ uld_hybrid_unmatched_weight=0.4,
+ )
+ loss_fn = ULDLoss(config, student_tokenizer=student_tokenizer, teacher_tokenizer=teacher_tokenizer)
+
+ _assert_alignment_covers_completion(loss_fn, batch, teacher_input_ids, teacher_labels)
+
+ torch.manual_seed(43)
+ student_vocab = len(student_tokenizer)
+ teacher_vocab = len(teacher_tokenizer)
+ batch_size, seq_len = batch["input_ids"].shape
+ student_logits = torch.randn(batch_size, seq_len, student_vocab)
+ teacher_logits = torch.randn(batch_size, teacher_input_ids.shape[1], teacher_vocab)
+
+ loss = loss_fn(
+ student_logits=student_logits,
+ teacher_logits=teacher_logits,
+ student_labels=batch["labels"],
+ teacher_labels=teacher_labels,
+ student_input_ids=batch["input_ids"],
+ teacher_input_ids=teacher_input_ids,
+ )
+
+ assert torch.isfinite(loss)
+
+
+def test_vlm_collator_label_masking(smolvlm_processor, vlm_examples):
+ """Verify that the VLM collator masks prompt and padding tokens in labels and leaves completion tokens unmasked."""
+ collator = DataCollatorForVisionLanguageChatML(processor=smolvlm_processor, max_length=2048)
+ batch = collator(vlm_examples)
+
+ input_ids = batch["input_ids"]
+ labels = batch["labels"]
+ attention_mask = batch["attention_mask"]
+
+ for i in range(input_ids.shape[0]):
+ # Padding tokens (attention_mask == 0) must be masked in labels
+ padding_positions = attention_mask[i] == 0
+ assert (labels[i][padding_positions] == -100).all(), "Padding tokens should be masked with -100"
+
+ # There must be at least one non-masked label (completion token)
+ completion_positions = labels[i] != -100
+ assert completion_positions.any(), "Each example must have at least one completion token in labels"
+
+ # Completion labels must match the corresponding input_ids
+ assert (labels[i][completion_positions] == input_ids[i][completion_positions]).all(), (
+ "Unmasked labels must match input_ids"
+ )
+
+ # Prompt tokens (attended but masked in labels) must exist — the prompt is never empty
+ prompt_positions = (attention_mask[i] == 1) & (labels[i] == -100)
+ assert prompt_positions.any(), "Each example must have masked prompt tokens"
+
+
+def test_gold_trainer_init_rejects_non_vlm_teacher(monkeypatch):
+ """GOLDTrainer should raise ValueError when the student is a VLM but the teacher is not."""
+
+ class DummyStudentModel:
+ def __init__(self):
+ self.config = SimpleNamespace(_name_or_path="student", vocab_size=17)
+ self.generation_config = SimpleNamespace(eos_token_id=2)
+ self.name_or_path = "student"
+
+ class DummyTeacherModel:
+ def __init__(self):
+ # No vision_config — looks like a text-only model
+ self.config = SimpleNamespace()
+ self.resized_to = None
+
+ def resize_token_embeddings(self, vocab_size):
+ self.resized_to = vocab_size
+
+ def fake_sft_init(
+ self,
+ model,
+ args=None,
+ data_collator=None,
+ train_dataset=None,
+ eval_dataset=None,
+ processing_class=None,
+ compute_metrics=None,
+ callbacks=None,
+ optimizers=None,
+ preprocess_logits_for_metrics=None,
+ peft_config=None,
+ ):
+ del data_collator, train_dataset, eval_dataset, compute_metrics, callbacks, optimizers
+ del preprocess_logits_for_metrics, peft_config
+ self.model = model
+ self.args = args
+ self.processing_class = processing_class
+ self.accelerator = SimpleNamespace(
+ device=torch.device("cpu"),
+ num_processes=1,
+ prepare_model=lambda module, evaluation_mode=True: module,
+ )
+ self.is_deepspeed_enabled = False
+ self.is_fsdp_enabled = False
+
+ monkeypatch.setattr(gold_trainer_module.SFTTrainer, "__init__", fake_sft_init)
+
+ processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct")
+
+ vision_dataset = Dataset.from_dict({"messages": [["dummy"]], "image": ["fake_image"]})
+
+ args = SimpleNamespace(
+ model_init_kwargs=None,
+ max_length=128,
+ use_liger_kernel=False,
+ teacher_model_init_kwargs=None,
+ use_uld_loss=False,
+ teacher_tokenizer_name_or_path=None,
+ teacher_model_revision=None,
+ disable_dropout=False,
+ lmbda=1.0,
+ beta=0.5,
+ temperature=1.0,
+ top_p=1.0,
+ seq_kd=False,
+ num_generations=1,
+ use_transformers_paged=False,
+ max_completion_length=16,
+ top_k=0,
+ log_completions=False,
+ log_completions_steps=100,
+ wandb_log_unique_prompts=True,
+ num_completions_to_print=None,
+ per_device_train_batch_size=1,
+ gradient_accumulation_steps=1,
+ use_vllm=False,
+ )
+
+ with pytest.raises(ValueError, match="VLM distillation requires both student and teacher"):
+ GOLDTrainer(
+ model=DummyStudentModel(),
+ teacher_model=DummyTeacherModel(),
+ args=args,
+ train_dataset=vision_dataset,
+ processing_class=processor,
+ )
+
+
+def test_gold_trainer_vlm_vllm_init_uses_identity_collator(monkeypatch):
+ """When a VLM processor is used with lmbda > 0 and use_vllm=True, GOLDTrainer should use the identity collator
+ and store a _vlm_collator for on-the-fly collation. vLLM should be initialized with max_model_length from args."""
+ captured = {}
+
+ class DummyStudentModel:
+ def __init__(self):
+ self.config = SimpleNamespace(
+ _name_or_path="student", vocab_size=17, vision_config=True, model_type="dummy_vlm"
+ )
+ self.generation_config = SimpleNamespace(eos_token_id=2)
+ self.name_or_path = "student"
+
+ class DummyTeacherModel:
+ def __init__(self):
+ self.config = SimpleNamespace(vision_config=True, model_type="dummy_vlm")
+ self.resized_to = None
+
+ def resize_token_embeddings(self, vocab_size):
+ self.resized_to = vocab_size
+
+ def fake_sft_init(
+ self,
+ model,
+ args=None,
+ data_collator=None,
+ train_dataset=None,
+ eval_dataset=None,
+ processing_class=None,
+ compute_metrics=None,
+ callbacks=None,
+ optimizers=None,
+ preprocess_logits_for_metrics=None,
+ peft_config=None,
+ ):
+ self.data_collator = data_collator
+ del train_dataset, eval_dataset, compute_metrics, callbacks, optimizers
+ del preprocess_logits_for_metrics, peft_config
+ self.model = model
+ self.args = args
+ self.processing_class = processing_class
+ self.accelerator = SimpleNamespace(
+ device=torch.device("cpu"),
+ num_processes=1,
+ prepare_model=lambda module, evaluation_mode=True: module,
+ )
+ self.is_deepspeed_enabled = False
+ self.is_fsdp_enabled = False
+
+ class CapturingVLLMGeneration:
+ def __init__(self, **kwargs):
+ captured.update(kwargs)
+
+ monkeypatch.setattr(gold_trainer_module.SFTTrainer, "__init__", fake_sft_init)
+ monkeypatch.setattr(gold_trainer_module, "is_vllm_available", lambda: True)
+ monkeypatch.setattr(gold_trainer_module, "VLLMGeneration", CapturingVLLMGeneration)
+
+ processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct")
+ if processor.tokenizer.pad_token is None:
+ processor.tokenizer.pad_token = processor.tokenizer.eos_token
+
+ vision_dataset = Dataset.from_dict({"messages": [["dummy"]], "image": ["fake_image"]})
+
+ args = SimpleNamespace(
+ model_init_kwargs=None,
+ max_length=128,
+ use_liger_kernel=False,
+ teacher_model_init_kwargs=None,
+ use_uld_loss=False,
+ teacher_tokenizer_name_or_path=None,
+ teacher_model_revision=None,
+ disable_dropout=False,
+ lmbda=1.0,
+ beta=0.5,
+ temperature=1.0,
+ top_p=1.0,
+ seq_kd=False,
+ num_generations=1,
+ use_transformers_paged=False,
+ max_completion_length=16,
+ top_k=0,
+ log_completions=False,
+ log_completions_steps=100,
+ wandb_log_unique_prompts=True,
+ num_completions_to_print=None,
+ per_device_train_batch_size=1,
+ gradient_accumulation_steps=1,
+ use_vllm=True,
+ vllm_mode="colocate",
+ vllm_structured_outputs_regex=None,
+ vllm_server_base_url=None,
+ vllm_server_host="0.0.0.0",
+ vllm_server_port=8001,
+ vllm_group_port=51216,
+ vllm_server_timeout=240.0,
+ vllm_tensor_parallel_size=1,
+ vllm_gpu_memory_utilization=0.2,
+ vllm_max_model_length=None,
+ vllm_enable_sleep_mode=False,
+ vllm_model_impl="vllm",
+ vllm_sync_frequency=1,
+ )
+
+ teacher_model = DummyTeacherModel()
+ trainer = GOLDTrainer(
+ model=DummyStudentModel(),
+ teacher_model=teacher_model,
+ args=args,
+ train_dataset=vision_dataset,
+ processing_class=processor,
+ )
+
+ # Same assertions as text-only vLLM test
+ assert teacher_model.resized_to == 17
+ assert captured["max_model_length"] == 128
+
+ # VLM-specific: identity collator + _vlm_collator for on-the-fly use
+ assert trainer.data_collator is identity
+ assert trainer._vlm_collator is not None
+ assert isinstance(trainer._vlm_collator, DataCollatorForVisionLanguageChatML)
+
+
+def _make_dummy_vlm_models(student_model_type, teacher_model_type):
+ """Helper to create dummy student/teacher VLM models with specified model_type."""
+
+ class DummyStudentModel:
+ def __init__(self):
+ self.config = SimpleNamespace(
+ _name_or_path="student", vocab_size=17, vision_config=True, model_type=student_model_type
+ )
+ self.generation_config = SimpleNamespace(eos_token_id=2)
+ self.name_or_path = "student"
+
+ class DummyTeacherModel:
+ def __init__(self):
+ self.config = SimpleNamespace(_name_or_path="teacher", vision_config=True, model_type=teacher_model_type)
+ self.resized_to = None
+
+ def resize_token_embeddings(self, vocab_size):
+ self.resized_to = vocab_size
+
+ return DummyStudentModel(), DummyTeacherModel()
+
+
+def _make_vlm_trainer_args(use_vllm=False):
+ """Helper to create minimal GOLDTrainer args for VLM tests."""
+ return SimpleNamespace(
+ model_init_kwargs=None,
+ max_length=128,
+ use_liger_kernel=False,
+ teacher_model_init_kwargs=None,
+ use_uld_loss=False,
+ teacher_tokenizer_name_or_path=None,
+ teacher_model_revision=None,
+ disable_dropout=False,
+ lmbda=0.5,
+ beta=0.5,
+ temperature=1.0,
+ top_p=1.0,
+ seq_kd=False,
+ num_generations=1,
+ use_transformers_paged=False,
+ max_completion_length=16,
+ top_k=0,
+ log_completions=False,
+ log_completions_steps=100,
+ wandb_log_unique_prompts=True,
+ num_completions_to_print=None,
+ per_device_train_batch_size=1,
+ gradient_accumulation_steps=1,
+ use_vllm=use_vllm,
+ vllm_mode="colocate",
+ vllm_structured_outputs_regex=None,
+ vllm_server_base_url=None,
+ vllm_server_host="0.0.0.0",
+ vllm_server_port=8001,
+ vllm_group_port=51216,
+ vllm_server_timeout=240.0,
+ vllm_tensor_parallel_size=1,
+ vllm_gpu_memory_utilization=0.2,
+ vllm_max_model_length=None,
+ vllm_enable_sleep_mode=False,
+ vllm_model_impl="vllm",
+ vllm_sync_frequency=1,
+ # ULD-specific defaults (needed when use_uld_loss=True)
+ uld_crossentropy_weight=0.5,
+ uld_distillation_weight=0.5,
+ uld_student_temperature=1.0,
+ uld_teacher_temperature=1.0,
+ uld_skip_student_eos=False,
+ uld_skip_teacher_eos=False,
+ use_extended_uld=False,
+ )
+
+
+def test_cross_architecture_vlm_without_uld_raises_error(monkeypatch):
+ """When student and teacher have different model_type and use_uld_loss=False, GOLDTrainer should raise
+ a ValueError telling the user to enable ULD loss."""
+
+ def fake_sft_init(
+ self,
+ model,
+ args=None,
+ data_collator=None,
+ train_dataset=None,
+ eval_dataset=None,
+ processing_class=None,
+ compute_metrics=None,
+ callbacks=None,
+ optimizers=None,
+ preprocess_logits_for_metrics=None,
+ peft_config=None,
+ ):
+ self.data_collator = data_collator
+ self.model = model
+ self.args = args
+ self.processing_class = processing_class
+ self.accelerator = SimpleNamespace(
+ device=torch.device("cpu"),
+ num_processes=1,
+ prepare_model=lambda module, evaluation_mode=True: module,
+ )
+ self.is_deepspeed_enabled = False
+ self.is_fsdp_enabled = False
+
+ monkeypatch.setattr(gold_trainer_module.SFTTrainer, "__init__", fake_sft_init)
+
+ processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct")
+ if processor.tokenizer.pad_token is None:
+ processor.tokenizer.pad_token = processor.tokenizer.eos_token
+
+ sentinel_processor = SimpleNamespace(_is_sentinel=True)
+ real_auto_processor_from_pretrained = AutoProcessor.from_pretrained
+
+ def patched_auto_processor(name, **kwargs):
+ if name == "teacher":
+ return sentinel_processor
+ return real_auto_processor_from_pretrained(name, **kwargs)
+
+ monkeypatch.setattr(gold_trainer_module.AutoProcessor, "from_pretrained", staticmethod(patched_auto_processor))
+
+ vision_dataset = Dataset.from_dict({"messages": [["dummy"]], "image": ["fake_image"]})
+ student, teacher = _make_dummy_vlm_models("smolvlm", "qwen2_5_vl")
+ args = _make_vlm_trainer_args() # use_uld_loss=False by default
+
+ with pytest.raises(ValueError, match="Cross-architecture VLM distillation.*use_uld_loss=True"):
+ GOLDTrainer(
+ model=student,
+ teacher_model=teacher,
+ args=args,
+ train_dataset=vision_dataset,
+ processing_class=processor,
+ )
+
+
+def test_cross_architecture_vlm_with_uld_sets_teacher_processor(monkeypatch):
+ """When student and teacher have different model_type and use_uld_loss=True, GOLDTrainer should store
+ a separate _teacher_processor and emit a warning."""
+
+ def fake_sft_init(
+ self,
+ model,
+ args=None,
+ data_collator=None,
+ train_dataset=None,
+ eval_dataset=None,
+ processing_class=None,
+ compute_metrics=None,
+ callbacks=None,
+ optimizers=None,
+ preprocess_logits_for_metrics=None,
+ peft_config=None,
+ ):
+ self.data_collator = data_collator
+ self.model = model
+ self.args = args
+ self.processing_class = processing_class
+ self.accelerator = SimpleNamespace(
+ device=torch.device("cpu"),
+ num_processes=1,
+ prepare_model=lambda module, evaluation_mode=True: module,
+ )
+ self.is_deepspeed_enabled = False
+ self.is_fsdp_enabled = False
+
+ monkeypatch.setattr(gold_trainer_module.SFTTrainer, "__init__", fake_sft_init)
+
+ processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct")
+ if processor.tokenizer.pad_token is None:
+ processor.tokenizer.pad_token = processor.tokenizer.eos_token
+
+ sentinel_processor = SimpleNamespace(_is_sentinel=True)
+ real_auto_processor_from_pretrained = AutoProcessor.from_pretrained
+
+ def patched_auto_processor(name, **kwargs):
+ if name == "teacher":
+ return sentinel_processor
+ return real_auto_processor_from_pretrained(name, **kwargs)
+
+ monkeypatch.setattr(gold_trainer_module.AutoProcessor, "from_pretrained", staticmethod(patched_auto_processor))
+
+ # Monkeypatch AutoTokenizer.from_pretrained for ULD teacher tokenizer loading
+    sentinel_tokenizer = SimpleNamespace(pad_token="<pad>", eos_token="</s>")
+ real_auto_tokenizer_from_pretrained = AutoTokenizer.from_pretrained
+
+ def patched_auto_tokenizer(name, **kwargs):
+ if name == "teacher":
+ return sentinel_tokenizer
+ return real_auto_tokenizer_from_pretrained(name, **kwargs)
+
+ monkeypatch.setattr(gold_trainer_module.AutoTokenizer, "from_pretrained", staticmethod(patched_auto_tokenizer))
+
+ vision_dataset = Dataset.from_dict({"messages": [["dummy"]], "image": ["fake_image"]})
+ student, teacher = _make_dummy_vlm_models("smolvlm", "qwen2_5_vl")
+ args = _make_vlm_trainer_args()
+ args.use_uld_loss = True
+ args.teacher_tokenizer_name_or_path = "teacher"
+
+ import warnings
+
+ with warnings.catch_warnings(record=True) as caught:
+ warnings.simplefilter("always")
+ trainer = GOLDTrainer(
+ model=student,
+ teacher_model=teacher,
+ args=args,
+ train_dataset=vision_dataset,
+ processing_class=processor,
+ )
+
+ # _teacher_processor should be set for cross-architecture
+ assert trainer._teacher_processor is not None
+ assert trainer._teacher_processor is sentinel_processor
+
+ # A cross-architecture warning should have been emitted
+ cross_arch_warnings = [w for w in caught if "Cross-architecture VLM distillation" in str(w.message)]
+ assert len(cross_arch_warnings) == 1
+ assert "smolvlm" in str(cross_arch_warnings[0].message)
+ assert "qwen2_5_vl" in str(cross_arch_warnings[0].message)
+
+ # Identity collator and VLM collator should still be set
+ assert trainer.data_collator is identity
+ assert trainer._vlm_collator is not None
+
+
+def test_same_architecture_vlm_no_teacher_processor(monkeypatch):
+ """When student and teacher have the same model_type, GOLDTrainer should NOT store a _teacher_processor
+ (zero overhead -- both models share the same forward_kwargs)."""
+
+ def fake_sft_init(
+ self,
+ model,
+ args=None,
+ data_collator=None,
+ train_dataset=None,
+ eval_dataset=None,
+ processing_class=None,
+ compute_metrics=None,
+ callbacks=None,
+ optimizers=None,
+ preprocess_logits_for_metrics=None,
+ peft_config=None,
+ ):
+ self.data_collator = data_collator
+ self.model = model
+ self.args = args
+ self.processing_class = processing_class
+ self.accelerator = SimpleNamespace(
+ device=torch.device("cpu"),
+ num_processes=1,
+ prepare_model=lambda module, evaluation_mode=True: module,
+ )
+ self.is_deepspeed_enabled = False
+ self.is_fsdp_enabled = False
+
+ monkeypatch.setattr(gold_trainer_module.SFTTrainer, "__init__", fake_sft_init)
+
+ processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct")
+ if processor.tokenizer.pad_token is None:
+ processor.tokenizer.pad_token = processor.tokenizer.eos_token
+
+ vision_dataset = Dataset.from_dict({"messages": [["dummy"]], "image": ["fake_image"]})
+ student, teacher = _make_dummy_vlm_models("smolvlm", "smolvlm")
+ args = _make_vlm_trainer_args()
+
+ import warnings
+
+ with warnings.catch_warnings(record=True) as caught:
+ warnings.simplefilter("always")
+ trainer = GOLDTrainer(
+ model=student,
+ teacher_model=teacher,
+ args=args,
+ train_dataset=vision_dataset,
+ processing_class=processor,
+ )
+
+ # _teacher_processor should be None for same architecture (zero overhead)
+ assert trainer._teacher_processor is None
+
+ # No cross-architecture warning should have been emitted
+ cross_arch_warnings = [w for w in caught if "Cross-architecture VLM distillation" in str(w.message)]
+ assert len(cross_arch_warnings) == 0
+
+ # Identity collator and VLM collator should still be set
+ assert trainer.data_collator is identity
+ assert trainer._vlm_collator is not None
diff --git a/trl/experimental/gold/gold_config.py b/trl/experimental/gold/gold_config.py
index 1af9eeae332..09f6689292d 100644
--- a/trl/experimental/gold/gold_config.py
+++ b/trl/experimental/gold/gold_config.py
@@ -119,6 +119,15 @@ class GOLDConfig(SFTConfig):
default=1e-7,
metadata={"help": "The initial learning rate for AdamW."},
)
+ # The default value of `remove_unused_columns` is overridden from the parent class, because in GOLD we
+ # usually rely on additional columns to compute the loss.
+ remove_unused_columns: bool | None = field(
+ default=False,
+ metadata={
+ "help": "Whether to only keep the columns 'prompt' and 'completion' in the dataset. If you use a custom "
+ "dataset that requires additional columns, you should keep this to `False`."
+ },
+ )
# GOLD-specific parameters
temperature: float = field(
diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py
index caaeb9cdc9e..ac2e574409c 100644
--- a/trl/experimental/gold/gold_trainer.py
+++ b/trl/experimental/gold/gold_trainer.py
@@ -29,7 +29,7 @@
from accelerate.utils import DistributedType, broadcast_object_list, gather_object
from datasets import Dataset, IterableDataset
from torch.utils.data import DataLoader
-from transformers import AutoTokenizer, TrainerCallback
+from transformers import AutoConfig, AutoProcessor, AutoTokenizer, TrainerCallback
from transformers.data.data_collator import DataCollator
from transformers.feature_extraction_utils import FeatureExtractionMixin
from transformers.generation.configuration_utils import GenerationConfig
@@ -47,7 +47,12 @@
is_rich_available,
)
-from ...data_utils import is_conversational, maybe_convert_to_chatml, pack_dataset
+from ...data_utils import (
+ is_conversational,
+ maybe_convert_to_chatml,
+ pack_dataset,
+ prepare_multimodal_messages,
+)
from ...extras.profiling import profiling_decorator
from ...generation.vllm_generation import VLLMGeneration
from ...import_utils import is_vllm_available
@@ -58,10 +63,12 @@
RepeatSampler,
create_model_from_path,
disable_dropout_in_model,
+ get_config_model_id,
+ identity,
pad,
split_tensor_dict,
)
-from ..utils import DataCollatorForChatML, empty_cache, truncate_dataset
+from ..utils import DataCollatorForChatML, DataCollatorForVisionLanguageChatML, empty_cache, truncate_dataset
from .gold_config import GOLDConfig
@@ -219,7 +226,7 @@ def build_teacher_inputs_from_texts(
last_idx = valid.nonzero(as_tuple=True)[0][-1]
teacher_attention_mask[row, last_idx + 1 :] = False
- teacher_prompt_length = max(prompt_lengths) if prompt_lengths else 0
+ teacher_prompt_length = min(prompt_lengths) if prompt_lengths else 0
return teacher_input_ids, teacher_labels, teacher_attention_mask, teacher_prompt_length
@@ -778,14 +785,102 @@ def __init__(
):
self.model_name_or_path = model if isinstance(model, str) else model.config._name_or_path
self.model_revision = (args.model_init_kwargs or {}).get("revision")
+ if train_dataset is None:
+ raise ValueError("`train_dataset` is required")
+ dataset_sample = next(iter(train_dataset))
+ if processing_class is None:
+ processing_class = AutoProcessor.from_pretrained(get_config_model_id(model.config))
+ # simplified logic from SFTTrainer
+ # Handle pad token for processors or tokenizers
+ if isinstance(processing_class, ProcessorMixin):
+ tokenizer = processing_class.tokenizer
+ self._is_vlm = True
+ else:
+ tokenizer = processing_class
+ self._is_vlm = False
+
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+
+ self.pad_token_id = tokenizer.pad_token_id
+
+ # VLM distillation: only VLM-to-VLM is supported. Both student and teacher must be
+ # VLMs so that both receive images and multimodal inputs.
+ self._teacher_processor = None
+ if self._is_vlm and isinstance(teacher_model, str):
+ # Teacher not yet instantiated -- validate it's a VLM
+ teacher_proc = AutoProcessor.from_pretrained(teacher_model)
+ if not isinstance(teacher_proc, ProcessorMixin):
+ raise ValueError(
+ "VLM distillation requires both student and teacher to be vision-language models. "
+ "The student has a `ProcessorMixin` but the teacher does not."
+ )
+ # Check for cross-architecture VLM distillation
+ student_model_type = model.config.model_type if not isinstance(model, str) else None
+ teacher_model_type = AutoConfig.from_pretrained(teacher_model).model_type
+ if student_model_type and teacher_model_type != student_model_type:
+ warnings.warn(
+ f"Cross-architecture VLM distillation detected: student is '{student_model_type}', "
+ f"teacher is '{teacher_model_type}'. Images will be processed separately through each "
+ "model's processor, which may increase memory usage and computation time."
+ )
+ self._teacher_processor = teacher_proc
+ elif self._is_vlm and not isinstance(teacher_model, str):
+ # Teacher already instantiated — detect whether it is a VLM by looking for a vision config
+ if not hasattr(teacher_model, "config") or not hasattr(teacher_model.config, "vision_config"):
+ raise ValueError(
+ "VLM distillation requires both student and teacher to be vision-language models. "
+ "The student has a `ProcessorMixin` but the teacher model does not appear to be a VLM "
+ "(missing `vision_config`)."
+ )
+ # Check for cross-architecture VLM distillation
+ student_model_type = model.config.model_type if not isinstance(model, str) else None
+ teacher_model_type = teacher_model.config.model_type
+ if student_model_type and teacher_model_type != student_model_type:
+ warnings.warn(
+ f"Cross-architecture VLM distillation detected: student is '{student_model_type}', "
+ f"teacher is '{teacher_model_type}'. Images will be processed separately through each "
+ "model's processor, which may increase memory usage and computation time."
+ )
+ self._teacher_processor = AutoProcessor.from_pretrained(teacher_model.config._name_or_path)
+ if self._teacher_processor is not None and not args.use_uld_loss:
+ raise ValueError(
+ "Cross-architecture VLM distillation (student and teacher have different `model_type`) is not "
+ "supported with the standard JSD loss because the models require different image token formats "
+ "and tokenizers. Please set `use_uld_loss=True` in your GOLDConfig to enable cross-tokenizer "
+ "alignment via ULD loss."
+ )
+ self._is_vision_dataset = "image" in dataset_sample or "images" in dataset_sample
+ if self._is_vision_dataset and not self._is_vlm:
+ raise ValueError(
+ "The dataset appears to be vision-related (contains 'image' or 'images' keys), but the provided "
+ "model does not seem to be a vision-language model. Please check your model and dataset."
+ )
- # Respect a user-provided data_collator; otherwise, provide a ChatML collator that
+ # Respect a user-provided data_collator; otherwise, pick the right collator based on modality.
+ # For VLMs, always use identity collator to preserve raw PIL images in the dataloader.
+ # Raw images are needed for: (1) vLLM generation, (2) cross-architecture teacher processing.
+ # A separate _vlm_collator is stored for on-the-fly collation inside _fill_buffer.
+ self._vlm_collator = None
if data_collator is None:
- data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length)
+ if self._is_vision_dataset:
+ self._vlm_collator = DataCollatorForVisionLanguageChatML(
+ processor=processing_class,
+ max_length=args.max_length,
+ )
+ data_collator = identity
+ else:
+ data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length)
# Liger fused GKD loss (JSD)
self.use_liger_gkd_loss = False
if args.use_liger_kernel:
+ if self._is_vlm:
+ raise ValueError(
+ "Liger fused GKD loss is not supported with VLMs. The fused kernel operates on base decoder "
+ "hidden states, which is incompatible with VLM multimodal inputs (pixel_values, etc.). "
+ "Please set `use_liger_kernel=False`."
+ )
self.liger_jsd_loss = LigerFusedLinearJSDLoss(
beta=args.beta,
ignore_index=-100,
@@ -813,6 +908,8 @@ def __init__(
if args.use_uld_loss and args.teacher_tokenizer_name_or_path is None:
if isinstance(teacher_model, str):
args.teacher_tokenizer_name_or_path = teacher_model
+ elif hasattr(teacher_model, "config") and getattr(teacher_model.config, "_name_or_path", None):
+ args.teacher_tokenizer_name_or_path = teacher_model.config._name_or_path
else:
raise ValueError(
"`teacher_tokenizer_name_or_path` must be set when using ULD loss with a pre-instantiated teacher model."
@@ -889,7 +986,7 @@ def __init__(
if self.use_uld_loss:
self.uld_loss_fn = ULDLoss(
config=args,
- student_tokenizer=processing_class,
+ student_tokenizer=tokenizer,
teacher_tokenizer=self.teacher_tokenizer,
device=self.accelerator.device,
)
@@ -900,7 +997,7 @@ def __init__(
"top_p": args.top_p,
"do_sample": True,
"top_k": args.top_k,
- "pad_token_id": self.processing_class.pad_token_id,
+ "pad_token_id": self.pad_token_id,
}
self.generation_config = GenerationConfig(**generation_kwargs)
# Keep training-specific generation kwargs to overwrite model's original generation config
@@ -967,6 +1064,8 @@ def __init__(
def _set_signature_columns_if_needed(self):
super()._set_signature_columns_if_needed()
required_columns = [
+ "prompt",
+ "completion",
"prompts",
"prompt_attention_mask",
"messages",
@@ -974,6 +1073,15 @@ def _set_signature_columns_if_needed(self):
"tools",
"original_prompt_text",
"original_completion_text",
+ "images",
+ "image",
+ "pixel_values",
+ "image_grid_thw",
+ "image_position_ids",
+ "pixel_attention_mask",
+ "image_sizes",
+ "token_type_ids",
+ "mm_token_type_ids",
]
if self._signature_columns is None:
self._signature_columns = required_columns
@@ -1058,8 +1166,8 @@ def _decode_completion_texts_from_labels(self, slice_inputs: dict[str, torch.Ten
decoded_completion_tokens: list[list[int]] = []
for row in labels_cpu:
token_ids = row[row != -100].tolist()
- if self.processing_class.pad_token_id is not None:
- token_ids = [tok for tok in token_ids if tok != self.processing_class.pad_token_id]
+ if self.pad_token_id is not None:
+ token_ids = [tok for tok in token_ids if tok != self.pad_token_id]
decoded_completion_tokens.append(token_ids)
return self.processing_class.batch_decode(
@@ -1122,8 +1230,16 @@ def _build_sequence_batch(
return new_attention_mask, new_labels
@profiling_decorator
- def _fill_buffer(self, generation_batch: dict[str, torch.Tensor | Any], buffer_steps: int):
- slices = split_tensor_dict(generation_batch, buffer_steps)
+ def _fill_buffer(self, generation_batch: dict[str, torch.Tensor | Any] | list[dict], buffer_steps: int):
+ if self._vlm_collator is not None:
+ # Identity collator path: generation_batch is list[dict] with raw PIL images.
+ # Split into chunks via list slicing, then collate on-the-fly per slice.
+ chunk_size = len(generation_batch) // buffer_steps
+ raw_slices = [generation_batch[i * chunk_size : (i + 1) * chunk_size] for i in range(buffer_steps)]
+ slices = None # not used in this path
+ else:
+ raw_slices = None # not used in this path
+ slices = split_tensor_dict(generation_batch, buffer_steps)
if self.accelerator.is_main_process:
on_policy_flags = [random.random() <= self.lmbda for _ in range(buffer_steps)]
@@ -1139,7 +1255,33 @@ def _fill_buffer(self, generation_batch: dict[str, torch.Tensor | Any], buffer_s
for i, flag in enumerate(on_policy_flags):
if not flag:
- slice_inputs = slices[i]
+ if self._vlm_collator is not None:
+ # Extract raw images and prompts BEFORE collation, since the collator
+ # mutates examples in place (pops "image", overwrites "prompt").
+ raw_images = None
+ raw_prompts = None
+ if self._teacher_processor is not None:
+ raw_images = [
+ ex.get("images") or ([ex["image"]] if "image" in ex else None) for ex in raw_slices[i]
+ ]
+ raw_prompts = [
+ prepare_multimodal_messages(ex["prompt"], images=imgs)
+ if imgs is not None
+ else ex.get("prompt")
+ for ex, imgs in zip(raw_slices[i], raw_images, strict=True)
+ ]
+ # Collate raw examples on-the-fly for off-policy slices
+ slice_inputs = self._vlm_collator(raw_slices[i])
+ slice_inputs = {
+ k: v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v
+ for k, v in slice_inputs.items()
+ }
+ # Preserve raw PIL images and prompts for cross-architecture teacher processing
+ if self._teacher_processor is not None:
+ slice_inputs["_raw_images"] = raw_images
+ slice_inputs["_raw_prompts"] = raw_prompts
+ else:
+ slice_inputs = slices[i]
if self.use_uld_loss and self.teacher_tokenizer is not None:
slice_inputs = self._ensure_original_text_fields(slice_inputs)
@@ -1153,7 +1295,10 @@ def _fill_buffer(self, generation_batch: dict[str, torch.Tensor | Any], buffer_s
self._buffered_inputs[i] = slice_inputs
if on_policy_indices:
- self._generate_on_policy_for_slices(slices, on_policy_indices)
+ if self._vlm_collator is not None:
+ self._generate_on_policy_vlm_raw(raw_slices, on_policy_indices)
+ else:
+ self._generate_on_policy_for_slices(slices, on_policy_indices)
@profiling_decorator
def _generate_on_policy_for_slices(
@@ -1190,6 +1335,8 @@ def _generate_on_policy_for_slices(
self.vllm_generation.sync_weights()
self._last_vllm_sync_step = self.state.global_step
+ # Text-only vLLM generation. VLM on-policy generation with raw images
+ # is handled by _generate_on_policy_vlm_raw (routed from _fill_buffer).
_, completion_ids, _, _ = self.vllm_generation.generate(
prompts=prompt_ids_list,
images=None,
@@ -1220,7 +1367,7 @@ def _generate_non_vllm_for_slices(self, slices: list[dict[str, torch.Tensor | An
unwrapped_model,
slice_inputs,
self.generation_config,
- self.processing_class.pad_token_id,
+ self.pad_token_id,
)
new_input_ids, new_attention_mask, new_labels, prompt_texts, completion_texts = result
@@ -1228,12 +1375,210 @@ def _generate_non_vllm_for_slices(self, slices: list[dict[str, torch.Tensor | An
updated_slice["input_ids"] = new_input_ids
updated_slice["attention_mask"] = new_attention_mask
updated_slice["labels"] = new_labels
+ # Rebuild sequence-length-dependent keys to match new input_ids shape
+ new_seq_len = new_input_ids.shape[1]
+ prompt_seq_len = slice_inputs["prompts"].shape[1]
+ for k in ("token_type_ids", "mm_token_type_ids"):
+ if k in updated_slice:
+ prompt_part = updated_slice[k][:, :prompt_seq_len]
+ comp_part = torch.zeros(
+ new_input_ids.shape[0],
+ new_seq_len - prompt_seq_len,
+ dtype=updated_slice[k].dtype,
+ device=new_input_ids.device,
+ )
+ updated_slice[k] = torch.cat([prompt_part, comp_part], dim=1)
updated_slice["original_prompt_text"] = prompt_texts
updated_slice["original_completion_text"] = completion_texts
self._buffered_inputs[slice_idx] = updated_slice
self._buffered_text_logs[slice_idx] = (prompt_texts, completion_texts)
+ def _generate_on_policy_vlm_raw(self, raw_slices: list[list[dict]], on_policy_indices: list[int]):
+ """On-policy generation from raw VLM examples, preserving PIL images for vLLM.
+
+ ``raw_slices`` holds uncollated examples (the identity-collator path), so
+ raw images survive until generation time. For each on-policy slice this
+ method generates completions — via one batched vLLM call across all
+ slices, or per-slice via ``model.generate`` when vLLM is disabled —
+ collates prompt + completion into model inputs, and writes the results
+ into ``self._buffered_inputs`` / ``self._buffered_text_logs``.
+ """
+ device = self.accelerator.device
+
+ # Phase 1: Collect prompts, images, and raw examples across all on-policy slices
+ all_prompt_ids = []
+ all_images = []
+ all_prompts = [] # prepared multimodal messages
+ all_raw_examples = []
+ local_slice_indices = []
+ slice_raw_data = {} # per-slice raw data for non-vLLM path
+
+ for slice_idx in on_policy_indices:
+ raw_examples = raw_slices[slice_idx]
+
+ # Extract raw PIL images from examples (like GRPOTrainer)
+ if "images" in raw_examples[0]:
+ images = [example.get("images") for example in raw_examples]
+ elif "image" in raw_examples[0]:
+ images = [
+ [example.get("image")] if example.get("image") is not None else None for example in raw_examples
+ ]
+ else:
+ images = None
+ # Treat an all-empty image batch the same as a text-only batch.
+ if images is not None and all(img_list is None or img_list == [] for img_list in images):
+ images = None
+
+ # Extract prompts and prepare multimodal messages
+ prompts = [example["prompt"] for example in raw_examples]
+ if images is not None:
+ prompts = [
+ prepare_multimodal_messages(prompt, images=img_list)
+ for prompt, img_list in zip(prompts, images, strict=True)
+ ]
+
+ # Normalize string content to content blocks for VLM processors that don't handle plain strings
+ # copied from GRPOTrainer
+ prompts = [
+ [
+ {**msg, "content": [{"type": "text", "text": msg["content"]}]}
+ if isinstance(msg.get("content"), str)
+ else msg
+ for msg in prompt
+ ]
+ for prompt in prompts
+ ]
+
+ # Tokenize prompts to get prompt token IDs
+ # TODO: add self.tools support
+ tokenized = self.processing_class.apply_chat_template(
+ conversation=prompts,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ padding=True,
+ )
+ # Strip padding using the attention mask so each prompt keeps only real tokens.
+ prompt_ids_list = [
+ [tok for tok, m in zip(ids, mask, strict=True) if m]
+ for ids, mask in zip(tokenized["input_ids"], tokenized["attention_mask"], strict=True)
+ ]
+
+ slice_raw_data[slice_idx] = (raw_examples, images, prompts, prompt_ids_list)
+
+ for i, example in enumerate(raw_examples):
+ all_prompt_ids.append(prompt_ids_list[i])
+ all_images.append(images[i] if images is not None else None)
+ all_prompts.append(prompts[i])
+ all_raw_examples.append(example)
+ local_slice_indices.append(slice_idx)
+
+ # Decode prompts twice: without special tokens for text logging, and with
+ # special tokens for `original_prompt_text` (consumed by ULD teacher-input building).
+ all_prompts_text = self.processing_class.batch_decode(all_prompt_ids, skip_special_tokens=True)
+ all_prompts_text_with_special = self.processing_class.batch_decode(all_prompt_ids, skip_special_tokens=False)
+
+ if not self.use_vllm:
+ # Non-vLLM path: generate per-slice using model.generate
+ for slice_idx in on_policy_indices:
+ raw_examples, images, prompts, _ = slice_raw_data[slice_idx]
+ collated = self._vlm_collator(raw_examples)
+ collated = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in collated.items()}
+ with unwrap_model_for_generation(
+ self.model, self.accelerator, generation_kwargs=self.generation_kwargs
+ ) as unwrapped_model:
+ result = self.generate_on_policy_outputs(
+ unwrapped_model, collated, self.generation_config, self.pad_token_id
+ )
+ new_input_ids, new_attention_mask, new_labels, prompt_texts, completion_texts = result
+
+ updated_slice = dict(collated)
+ updated_slice["input_ids"] = new_input_ids
+ updated_slice["attention_mask"] = new_attention_mask
+ updated_slice["labels"] = new_labels
+ # Rebuild sequence-length-dependent keys to match new input_ids shape
+ new_seq_len = new_input_ids.shape[1]
+ prompt_seq_len = collated["prompts"].shape[1]
+ for k in ("token_type_ids", "mm_token_type_ids"):
+ if k in updated_slice:
+ # Keep the prompt portion; pad the completion portion with zeros
+ # (completion tokens are text-only, so type id 0).
+ prompt_part = updated_slice[k][:, :prompt_seq_len]
+ comp_part = torch.zeros(
+ new_input_ids.shape[0],
+ new_seq_len - prompt_seq_len,
+ dtype=updated_slice[k].dtype,
+ device=new_input_ids.device,
+ )
+ updated_slice[k] = torch.cat([prompt_part, comp_part], dim=1)
+ updated_slice["original_prompt_text"] = prompt_texts
+ updated_slice["original_completion_text"] = completion_texts
+ if self._teacher_processor is not None:
+ updated_slice["_raw_images"] = images
+ updated_slice["_raw_prompts"] = prompts
+
+ self._buffered_inputs[slice_idx] = updated_slice
+ self._buffered_text_logs[slice_idx] = (prompt_texts, completion_texts)
+ return
+
+ # vLLM path: one batched generate call across all slices
+ # Re-sync student weights into the vLLM engine at most once per vllm_sync_frequency steps.
+ if (
+ self.state.global_step != self._last_vllm_sync_step
+ and self.state.global_step >= self._last_vllm_sync_step + self.vllm_sync_frequency
+ ):
+ self.vllm_generation.sync_weights()
+ self._last_vllm_sync_step = self.state.global_step
+
+ # Pass None for images if all entries are None
+ generate_images = all_images if any(img is not None for img in all_images) else None
+ _, completion_ids, _, _ = self.vllm_generation.generate(
+ prompts=all_prompt_ids,
+ images=generate_images,
+ num_generations=self.num_generations,
+ )
+
+ # Decode completions
+ max_completion_length = self.generation_config.max_new_tokens
+ all_completion_texts = []
+ for comp_ids in completion_ids:
+ if len(comp_ids) > max_completion_length:
+ comp_ids = comp_ids[:max_completion_length]
+ all_completion_texts.append(
+ self.processing_class.decode(comp_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+ )
+
+ # Redistribute completions to slices. With num_generations > 1, each prompt produces
+ # multiple completions, so we repeat each raw example/prompt/image to match.
+ slice_completions = {idx: [] for idx in on_policy_indices}
+ slice_raw = {idx: [] for idx in on_policy_indices}
+ slice_images = {idx: [] for idx in on_policy_indices}
+ slice_prompts = {idx: [] for idx in on_policy_indices}
+ slice_prompts_text = {idx: [] for idx in on_policy_indices}
+ slice_prompts_text_special = {idx: [] for idx in on_policy_indices}
+
+ for i, slice_idx in enumerate(local_slice_indices):
+ for g in range(self.num_generations):
+ # NOTE(review): assumes completion_ids are grouped per prompt in input
+ # order (prompt i's generations occupy i*num_generations..i*num_generations+g);
+ # confirm against VLLMGeneration.generate's output layout.
+ comp_idx = i * self.num_generations + g
+ slice_completions[slice_idx].append(all_completion_texts[comp_idx])
+ slice_raw[slice_idx].append(all_raw_examples[i])
+ slice_images[slice_idx].append(all_images[i])
+ slice_prompts[slice_idx].append(all_prompts[i])
+ slice_prompts_text[slice_idx].append(all_prompts_text[i])
+ slice_prompts_text_special[slice_idx].append(all_prompts_text_with_special[i])
+
+ for slice_idx in on_policy_indices:
+ completion_texts = slice_completions[slice_idx]
+ raw_for_slice = slice_raw[slice_idx]
+ images_for_slice = slice_images[slice_idx]
+ prompts_for_slice = slice_prompts[slice_idx]
+
+ # Build synthetic examples: original prompt + generated completion
+ synthetic_examples = []
+ for i, example in enumerate(raw_for_slice):
+ synthetic = dict(example)
+ synthetic["completion"] = [{"role": "assistant", "content": completion_texts[i]}]
+ synthetic_examples.append(synthetic)
+
+ # Collate synthetic examples to get pixel_values + properly tokenized input_ids/labels
+ collated = self._vlm_collator(synthetic_examples)
+ collated = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in collated.items()}
+ collated["original_prompt_text"] = slice_prompts_text_special[slice_idx]
+ collated["original_completion_text"] = completion_texts
+ if self._teacher_processor is not None:
+ has_images = any(img is not None for img in images_for_slice)
+ collated["_raw_images"] = images_for_slice if has_images else None
+ collated["_raw_prompts"] = prompts_for_slice
+
+ self._buffered_inputs[slice_idx] = collated
+ self._buffered_text_logs[slice_idx] = (slice_prompts_text[slice_idx], completion_texts)
+
def _process_completions_to_buffer(
self,
slices: list[dict[str, torch.Tensor | Any]],
@@ -1249,7 +1594,7 @@ def _process_completions_to_buffer(
Process vLLM completions and update buffered inputs for on-policy slices.
"""
device = self.accelerator.device
- pad_token_id = self.processing_class.pad_token_id if self.processing_class.pad_token_id is not None else 0
+ pad_token_id = self.pad_token_id if self.pad_token_id is not None else 0
slice_completions = {idx: [] for idx in on_policy_indices}
slice_prompt_ids = {idx: [] for idx in on_policy_indices}
@@ -1367,6 +1712,11 @@ def _prepare_dataset(
dataset_name: str,
) -> Dataset | IterableDataset:
"""Preserve original text fields for ULD when needed."""
+ # For VLM datasets, skip dataset preparation entirely — the VLM collator handles tokenization
+ # and image processing on the fly, similar to how SFTTrainer skips prep for vision datasets.
+ if self._is_vision_dataset:
+ return dataset
+
column_names = list(next(iter(dataset)).keys())
is_processed = "input_ids" in column_names
@@ -1689,7 +2039,23 @@ def generalized_jsd_loss(
else:
return jsd
+ # Vision/multimodal tensor keys produced by VLM processors. These are pulled
+ # from collated inputs and forwarded to model(...) alongside input_ids.
+ # token_type_ids / mm_token_type_ids are sequence-length-dependent variants
+ # used by some architectures (e.g. token_type_ids for Gemma, mm_token_type_ids
+ # for ERNIE-VL) and get special handling around generation.
+ _MULTIMODAL_KEYS = (
+ "pixel_values",
+ "image_grid_thw",
+ "image_position_ids",
+ "pixel_attention_mask",
+ "image_sizes",
+ "token_type_ids",
+ "mm_token_type_ids",
+ )
+
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+ # Extract multimodal fields for student forward passes
+ student_forward_kwargs = {k: inputs[k] for k in self._MULTIMODAL_KEYS if k in inputs}
+ # For a same-architecture teacher, reuse the student's vision tensors.
+ # For cross-architecture VLMs, this gets overridden in the ULD branch below.
+ teacher_forward_kwargs = student_forward_kwargs
+
if self.use_uld_loss and self.teacher_tokenizer is not None:
if "original_prompt_text" in inputs and "original_completion_text" in inputs:
prompt_texts = inputs["original_prompt_text"]
@@ -1708,16 +2074,60 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
full.replace(prompt, "", 1) for full, prompt in zip(full_texts, prompt_texts, strict=True)
]
- (
- teacher_input_ids,
- teacher_labels,
- teacher_attention_mask,
- teacher_prompt_length,
- ) = build_teacher_inputs_from_texts(
- self.teacher_tokenizer,
- prompt_texts,
- completion_texts,
- )
+ # For cross-architecture VLMs, build teacher inputs with image placeholders by processing
+ # prompts through the teacher's processor with raw images, then appending completions.
+ if self._teacher_processor is not None and "_raw_images" in inputs:
+ raw_images = inputs["_raw_images"]
+ raw_prompts = inputs["_raw_prompts"]
+ # Apply teacher's chat template to get prompt text with correct image placeholders
+ teacher_prompt_texts = self._teacher_processor.apply_chat_template(
+ raw_prompts, tokenize=False, add_generation_prompt=True
+ )
+ # Build full text (prompt + completion) and process in one call so all tensors
+ # (input_ids, attention_mask, mm_token_type_ids, pixel_values, ...) are aligned.
+ teacher_full_texts = [p + c for p, c in zip(teacher_prompt_texts, completion_texts, strict=True)]
+ teacher_full_processed = self._teacher_processor(
+ images=raw_images,
+ text=teacher_full_texts,
+ padding=True,
+ return_tensors="pt",
+ )
+ teacher_input_ids = teacher_full_processed["input_ids"]
+ teacher_attention_mask = teacher_full_processed["attention_mask"]
+ # Determine prompt lengths after image token expansion to build labels.
+ # Derive prompt lengths from total sequence length minus completion length.
+ # Completions are pure text (no images), so the tokenizer gives exact counts.
+ # This avoids a second image-processing pass through the teacher processor.
+ teacher_completion_token_lengths = [
+ len(self._teacher_processor.tokenizer(ct, add_special_tokens=False)["input_ids"])
+ for ct in completion_texts
+ ]
+ total_lengths = teacher_attention_mask.sum(dim=1)
+ teacher_prompt_token_lengths = [
+ int(total_lengths[i].item()) - cl for i, cl in enumerate(teacher_completion_token_lengths)
+ ]
+ teacher_labels = teacher_input_ids.clone()
+ teacher_labels[teacher_attention_mask == 0] = -100
+ for i, pl in enumerate(teacher_prompt_token_lengths):
+ teacher_labels[i, :pl] = -100
+ teacher_prompt_length = min(teacher_prompt_token_lengths)
+ # Override teacher_forward_kwargs with all multimodal keys from teacher processing
+ teacher_forward_kwargs = {
+ k: teacher_full_processed[k].to(self.accelerator.device)
+ for k in self._MULTIMODAL_KEYS
+ if k in teacher_full_processed
+ }
+ else:
+ (
+ teacher_input_ids,
+ teacher_labels,
+ teacher_attention_mask,
+ teacher_prompt_length,
+ ) = build_teacher_inputs_from_texts(
+ self.teacher_tokenizer,
+ prompt_texts,
+ completion_texts,
+ )
teacher_input_ids = teacher_input_ids.to(self.accelerator.device)
teacher_labels = teacher_labels.to(self.accelerator.device)
@@ -1727,6 +2137,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
use_cache=False,
+ **student_forward_kwargs,
)
self.teacher_model.eval()
@@ -1734,10 +2145,16 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
outputs_teacher = self.teacher_model(
input_ids=teacher_input_ids,
attention_mask=teacher_attention_mask,
+ **teacher_forward_kwargs,
)
# These are not used for ULD loss but are needed if JSD loss were to be used in this branch
- student_prompt_length = inputs["prompts"].shape[1]
+ # For VLMs, prompts are left-padded but input_ids are flushed left, so prompts.shape[1]
+ # would overcount. Derive prompt length from the flushed labels instead.
+ if self._is_vlm:
+ student_prompt_length = (inputs["labels"] != -100).long().argmax(dim=1).min().item()
+ else:
+ student_prompt_length = inputs["prompts"].shape[1]
shifted_student_logits = outputs_student.logits[:, student_prompt_length - 1 : -1, :]
shifted_teacher_logits = outputs_teacher.logits[:, teacher_prompt_length - 1 : -1, :]
shifted_labels = inputs["labels"][:, student_prompt_length:]
@@ -1756,6 +2173,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
use_cache=False,
+ **student_forward_kwargs,
)
self.teacher_model.eval()
@@ -1771,6 +2189,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
use_cache=False,
+ **teacher_forward_kwargs,
)
student_hidden = student_outputs.last_hidden_state[:, :-1]
@@ -1805,6 +2224,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
outputs_student = model(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
+ **student_forward_kwargs,
)
self.teacher_model.eval()
@@ -1812,9 +2232,14 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
outputs_teacher = self.teacher_model(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
+ **teacher_forward_kwargs,
)
-
- prompt_lengths = inputs["prompts"].shape[1]
+ # Using the same prompt_lengths for teacher and student, since JSD can only be
+ # used with same-family VLMs (shared tokenizer).
+ if self._is_vlm:
+ prompt_lengths = (inputs["labels"] != -100).long().argmax(dim=1).min().item()
+ else:
+ prompt_lengths = inputs["prompts"].shape[1]
shifted_student_logits = outputs_student.logits[:, prompt_lengths - 1 : -1, :]
shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths - 1 : -1, :]
shifted_labels = inputs["labels"][:, prompt_lengths:]
@@ -1833,8 +2258,8 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
teacher_input_ids_for_loss = teacher_input_ids if "teacher_input_ids" in locals() else inputs["input_ids"]
student_labels = inputs["labels"].clone()
- if hasattr(self.processing_class, "pad_token_id") and self.processing_class.pad_token_id is not None:
- student_labels[student_labels == self.processing_class.pad_token_id] = -100
+ if self.pad_token_id is not None:
+ student_labels[student_labels == self.pad_token_id] = -100
if (
hasattr(self, "teacher_tokenizer")
@@ -1898,11 +2323,19 @@ def generate_on_policy_outputs(self, model, inputs, generation_config, pad_token
completion_ids = [output.generated_tokens for output in generated_outputs.values()]
generated_tokens = torch.stack([torch.tensor(ids, device=model.device) for ids in completion_ids])
else:
+ generate_kwargs = {k: inputs[k] for k in self._MULTIMODAL_KEYS if k in inputs}
+ # Slice sequence-length-dependent keys to prompt-only length (e.g. token_type_ids for Gemma,
+ # mm_token_type_ids for ERNIE-VL) since model.generate receives prompt-only input_ids
+ prompt_seq_len = inputs["prompts"].shape[1]
+ for k in ("token_type_ids", "mm_token_type_ids"):
+ if k in generate_kwargs:
+ generate_kwargs[k] = generate_kwargs[k][:, :prompt_seq_len]
generated_outputs = model.generate(
input_ids=inputs["prompts"],
attention_mask=inputs.get("prompt_attention_mask", None),
generation_config=generation_config,
return_dict_in_generate=True,
+ **generate_kwargs,
)
# Get the generated token IDs
generated_tokens = generated_outputs.sequences
@@ -1911,7 +2344,7 @@ def generate_on_policy_outputs(self, model, inputs, generation_config, pad_token
device = generated_tokens.device
prompt_mask = inputs.get("prompt_attention_mask")
- pad_token_id = pad_token_id if pad_token_id is not None else self.processing_class.pad_token_id
+ pad_token_id = pad_token_id if pad_token_id is not None else self.pad_token_id
if self.use_transformers_paged:
# generate_batch() returns completion-only tokens, so the entire tensor is completion.
diff --git a/trl/experimental/utils.py b/trl/experimental/utils.py
index 5c2057d6f20..3c18412beb9 100644
--- a/trl/experimental/utils.py
+++ b/trl/experimental/utils.py
@@ -29,7 +29,9 @@
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from transformers import PreTrainedModel, PreTrainedTokenizerBase, TrainingArguments
+from transformers.data.data_collator import DataCollatorMixin
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
+from transformers.processing_utils import ProcessorMixin
from transformers.utils import (
is_peft_available,
is_torch_mlu_available,
@@ -37,8 +39,14 @@
is_torch_xpu_available,
)
-from ..data_utils import DatasetType, _get_dataset_format
-from ..trainer.utils import pad
+from ..data_utils import (
+ DatasetType,
+ _get_dataset_format,
+ apply_chat_template,
+ is_conversational,
+ prepare_multimodal_messages,
+)
+from ..trainer.utils import flush_left, pad
if is_peft_available():
@@ -249,7 +257,11 @@ def __call__(self, examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
prompts_input_ids = [torch.tensor(ids, dtype=torch.long) for ids in prompts_input_ids]
prompt_attention_mask = [torch.tensor(mask, dtype=torch.long) for mask in prompt_attention_mask]
- prompts_input_ids = pad(prompts_input_ids, padding_side="left", padding_value=self.tokenizer.pad_token_id)
+ prompts_input_ids = pad(
+ prompts_input_ids,
+ padding_side="left",
+ padding_value=self.tokenizer.pad_token_id,
+ )
prompt_attention_mask = pad(prompt_attention_mask, padding_side="left", padding_value=0)
return {
@@ -261,6 +273,171 @@ def __call__(self, examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
}
+@dataclass
+class DataCollatorForVisionLanguageChatML(DataCollatorMixin):
+ """
+ Data collator for GOLD VLM training.
+
+ Combines image processing from [`~trainer.sft_trainer.DataCollatorForVisionLanguageModeling`] with the
+ prompt-separation logic that GOLD needs for on-policy generation. Each input example should be a dictionary
+ containing at least:
+ - An `"images"` key holding a list of images, or an `"image"` key holding a single image.
+ - Keys `"prompt"` and `"completion"` for the prompt and completion (conversational or plain text).
+
+ The collator outputs a dictionary including:
+ - `"input_ids"`: Tensor of token IDs (prompt + completion, concatenated).
+ - `"attention_mask"`: Tensor indicating attention mask.
+ - `"labels"`: Tensor for training labels (prompt tokens masked with -100).
+ - `"prompts"`: Tensor of prompt-only token IDs (left-padded), used for on-policy generation.
+ - `"prompt_attention_mask"`: Attention mask for prompts.
+ - `"original_prompt_text"`: List of raw prompt text strings, used for ULD cross-tokenizer distillation.
+ - `"original_completion_text"`: List of raw completion text strings, used for ULD cross-tokenizer distillation.
+ - `"pixel_values"`: Tensor representing image pixel values.
+
+ Additional keys may be present depending on the processor, such as `"image_grid_thw"` or `"image_position_ids"`.
+
+ Args:
+ processor ([`~transformers.ProcessorMixin`]):
+ The processor used to tokenize text and process images.
+ max_length (`int` or `None`, *optional*):
+ Maximum sequence length for input tokens. If `None`, no truncation is applied.
+ return_tensors (`str`, *optional*, defaults to `"pt"`):
+ The tensor type to return.
+ """
+
+ processor: ProcessorMixin
+ max_length: int | None = None
+ return_tensors: str = "pt"
+
+ def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
+ if "prompt" not in examples[0] or "completion" not in examples[0]:
+ raise KeyError(
+ "DataCollatorForVisionLanguageChatML requires 'prompt' and 'completion' keys in examples. "
+ f"Got keys: {list(examples[0].keys())}."
+ )
+
+ # Normalize single image to list
+ if "image" in examples[0]:
+ for example in examples:
+ example["images"] = [example.pop("image")]
+ images = [example.get("images", []) for example in examples]
+ if all(img_list == [] for img_list in images):
+ images = None
+
+ # Apply chat template for conversational data
+ if is_conversational(examples[0]):
+ for example in examples:
+ example["prompt"] = prepare_multimodal_messages(example["prompt"], images=example["images"])
+ example["completion"] = prepare_multimodal_messages(example["completion"])
+ examples = [apply_chat_template(example, self.processor) for example in examples]
+
+ prompts = [example["prompt"] for example in examples]
+ completions = [example["completion"] for example in examples]
+
+ # Process prompts (with images) and completions (text only) separately
+ processed_prompts = self.processor(
+ images=images,
+ text=prompts,
+ padding=True,
+ padding_side="left",
+ return_tensors=self.return_tensors,
+ add_special_tokens=False,
+ )
+ processed_completions = self.processor(
+ text=completions,
+ padding=True,
+ padding_side="right",
+ return_tensors=self.return_tensors,
+ add_special_tokens=False,
+ )
+
+ # Concatenate prompts and completions
+ prompt_ids, prompt_mask = (
+ processed_prompts["input_ids"],
+ processed_prompts["attention_mask"],
+ )
+ completion_ids, completion_mask = (
+ processed_completions["input_ids"],
+ processed_completions["attention_mask"],
+ )
+ input_ids = torch.cat((prompt_ids, completion_ids), dim=1)
+ attention_mask = torch.cat((prompt_mask, completion_mask), dim=1)
+ completion_mask = torch.cat((torch.zeros_like(prompt_mask), completion_mask), dim=1)
+ if "token_type_ids" in processed_prompts:
+ prompt_token_type_ids = processed_prompts["token_type_ids"]
+ completion_token_type_ids = processed_completions["token_type_ids"]
+ token_type_ids = torch.cat((prompt_token_type_ids, completion_token_type_ids), dim=1)
+ if "mm_token_type_ids" in processed_prompts:
+ prompt_mm_token_type_ids = processed_prompts["mm_token_type_ids"]
+ completion_mm_token_type_ids = processed_completions.get(
+ "mm_token_type_ids", torch.zeros_like(completion_ids)
+ )
+ mm_token_type_ids = torch.cat((prompt_mm_token_type_ids, completion_mm_token_type_ids), dim=1)
+
+ # Flush left to reduce padding
+ if "token_type_ids" in processed_prompts and "mm_token_type_ids" in processed_prompts:
+ (
+ attention_mask,
+ input_ids,
+ completion_mask,
+ token_type_ids,
+ mm_token_type_ids,
+ ) = flush_left(
+ attention_mask,
+ input_ids,
+ completion_mask,
+ token_type_ids,
+ mm_token_type_ids,
+ )
+ elif "token_type_ids" in processed_prompts:
+ attention_mask, input_ids, completion_mask, token_type_ids = flush_left(
+ attention_mask, input_ids, completion_mask, token_type_ids
+ )
+ elif "mm_token_type_ids" in processed_prompts:
+ attention_mask, input_ids, completion_mask, mm_token_type_ids = flush_left(
+ attention_mask, input_ids, completion_mask, mm_token_type_ids
+ )
+ else:
+ attention_mask, input_ids, completion_mask = flush_left(attention_mask, input_ids, completion_mask)
+
+ # Truncate if necessary
+ if self.max_length is not None:
+ input_ids = input_ids[:, : self.max_length]
+ attention_mask = attention_mask[:, : self.max_length]
+ completion_mask = completion_mask[:, : self.max_length]
+ if "token_type_ids" in processed_prompts:
+ token_type_ids = token_type_ids[:, : self.max_length]
+ if "mm_token_type_ids" in processed_prompts:
+ mm_token_type_ids = mm_token_type_ids[:, : self.max_length]
+
+ # Create labels: mask padding and prompt tokens
+ labels = input_ids.clone()
+ labels[attention_mask == 0] = -100
+ labels[completion_mask == 0] = -100
+
+ # Build output with vision keys from processed_prompts (pixel_values, image_grid_thw, etc.)
+ output = processed_prompts
+ output["input_ids"] = input_ids
+ output["attention_mask"] = attention_mask
+ output["labels"] = labels
+ if "token_type_ids" in processed_prompts:
+ output["token_type_ids"] = token_type_ids
+ if (
+ "mm_token_type_ids" in processed_prompts
+        ):  # mm_token_type_ids is an ERNIE-VL special case, mirroring DataCollatorForVisionLanguageModeling
+ output["mm_token_type_ids"] = mm_token_type_ids
+
+ # GOLD-specific: separate prompt tensors for on-policy generation
+ output["prompts"] = prompt_ids
+ output["prompt_attention_mask"] = prompt_mask
+
+ # GOLD-specific: raw text for ULD cross-tokenizer distillation
+ output["original_prompt_text"] = prompts
+ output["original_completion_text"] = completions
+
+ return output
+
+
def truncate_right(
input_ids: torch.Tensor, stop_token_id: int, pad_token_id: int
) -> tuple[torch.Tensor, torch.Tensor]:
@@ -316,7 +493,9 @@ def add_bos_token_if_needed(
def add_eos_token_if_needed(
- eos_token_id: int, chosen_tokens: dict[str, list[int]], rejected_tokens: dict[str, list[int]]
+ eos_token_id: int,
+ chosen_tokens: dict[str, list[int]],
+ rejected_tokens: dict[str, list[int]],
):
if len(chosen_tokens["input_ids"]) == 0 or eos_token_id != chosen_tokens["input_ids"][-1]:
chosen_tokens["input_ids"].append(eos_token_id)
@@ -351,7 +530,10 @@ def first_true_indices(bools: torch.Tensor, dtype=torch.long) -> torch.Tensor:
def get_reward(
- model: torch.nn.Module, query_responses: torch.Tensor, pad_token_id: int, context_length: int
+ model: torch.nn.Module,
+ query_responses: torch.Tensor,
+ pad_token_id: int,
+ context_length: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Computes the reward logits and the rewards for a given model and query responses.