diff --git a/examples/scripts/gold_vlm.py b/examples/scripts/gold_vlm.py new file mode 100644 index 00000000000..f50ad67a691 --- /dev/null +++ b/examples/scripts/gold_vlm.py @@ -0,0 +1,181 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +GOLD VLM distillation on MMK12. + +# Example 1 — Same-family distillation (SmolVLM-500M → SmolVLM-256M) +# Uses JSD loss. Same architecture and tokenizer, so standard distillation works directly. +# vLLM enabled for faster on-policy generation. +accelerate launch examples/scripts/gold_vlm.py \ + --student_model_name HuggingFaceTB/SmolVLM-256M-Instruct \ + --teacher_model_name HuggingFaceTB/SmolVLM-500M-Instruct \ + --lmbda 0.5 \ + --use_vllm \ + --vllm_mode colocate + +# Example 2 — Cross-family distillation (Qwen2.5-VL-3B → SmolVLM-256M) +# Different architectures have incompatible tokenizers and image token formats, +# so ULD (Universal Logit Distillation) loss is required to align logits across vocabularies. 
+accelerate launch examples/scripts/gold_vlm.py \ + --student_model_name HuggingFaceTB/SmolVLM-256M-Instruct \ + --teacher_model_name Qwen/Qwen2.5-VL-3B-Instruct \ + --use_uld_loss \ + --lmbda 0.0 +""" + +import argparse + +import torch +from datasets import load_dataset +from peft import LoraConfig +from transformers import AutoModelForImageTextToText, AutoProcessor + +from trl.experimental.gold import GOLDConfig, GOLDTrainer + + +SYSTEM_PROMPT = ( + "You are a helpful AI Assistant that provides well-reasoned and detailed responses. " + "You first think about the reasoning process as an internal monologue and then provide the user with the answer. " + "Respond in the following format: \n...\n\n\n...\n" +) + + +def make_conversation(example): + """Convert MMK12 row into the chat format expected by TRL VLM trainers.""" + return { + "prompt": [ + { + "role": "system", + "content": [{"type": "text", "text": SYSTEM_PROMPT}], + }, + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": example["question"]}, + ], + }, + ], + "completion": [ + { + "role": "assistant", + "content": [{"type": "text", "text": str(example["answer"])}], + }, + ], + "image": example["image"], + } + + +def filter_big_images(example): + image = example["image"] + return image.size[0] < 512 and image.size[1] < 512 + + +def convert_to_rgb(example): + image = example["image"] + if image.mode != "RGB": + image = image.convert("RGB") + example["image"] = image + return example + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--student_model_name", type=str, default="HuggingFaceTB/SmolVLM-256M-Instruct") + parser.add_argument("--teacher_model_name", type=str, default="HuggingFaceTB/SmolVLM-500M-Instruct") + parser.add_argument("--use_uld_loss", action="store_true") + parser.add_argument("--lmbda", type=float, default=0.5) + parser.add_argument("--use_vllm", action="store_true") + parser.add_argument("--vllm_mode", type=str, 
default="colocate") + cli_args = parser.parse_args() + + # ────────────────────────────────────────────── + # Models + # ────────────────────────────────────────────── + student_model = AutoModelForImageTextToText.from_pretrained(cli_args.student_model_name, dtype=torch.bfloat16) + teacher_model = AutoModelForImageTextToText.from_pretrained(cli_args.teacher_model_name, dtype=torch.bfloat16) + + # Freeze everything except the language model head + for name, param in student_model.named_parameters(): + if "language_model" not in name: + param.requires_grad = False + + processor = AutoProcessor.from_pretrained(cli_args.student_model_name, padding_side="left") + + # toy example to fit small GPUs + peft_config = LoraConfig( + r=4, + lora_alpha=8, + lora_dropout=0.05, + target_modules=["q_proj"], + ) + + # ────────────────────────────────────────────── + # Dataset + # ────────────────────────────────────────────── + dataset = load_dataset("FanqingM/MMK12", split="train[:5%]") + dataset = dataset.filter(filter_big_images) + dataset = dataset.map(convert_to_rgb) + dataset = dataset.map(make_conversation) + + # ────────────────────────────────────────────── + # Training config + # ────────────────────────────────────────────── + args = GOLDConfig( + output_dir="gold-vlm-distillation", + # GOLD-specific + lmbda=cli_args.lmbda, + beta=0.5, + temperature=0.9, + max_completion_length=256, + teacher_model_name_or_path=cli_args.teacher_model_name, + num_generations=1, + use_uld_loss=cli_args.use_uld_loss, + # vLLM + use_vllm=cli_args.use_vllm, + vllm_mode=cli_args.vllm_mode, + vllm_gpu_memory_utilization=0.5, + vllm_max_model_length=8192, + # VLM image tokens expand during processing, so the default max_length (1024) is often too small. + # Which will lead to shifted_student_logits become an empty Tensor. 
+ max_length=2048, + # Training schedule + per_device_train_batch_size=2, + gradient_accumulation_steps=4, + max_steps=100, + learning_rate=2e-5, + warmup_steps=10, + # Precision + bf16=True, + # Logging + logging_steps=1, + log_completions=True, + report_to="none", + ) + + # ────────────────────────────────────────────── + # Trainer + # ────────────────────────────────────────────── + trainer = GOLDTrainer( + model=student_model, + teacher_model=teacher_model, + args=args, + train_dataset=dataset, + processing_class=processor, + peft_config=peft_config, + ) + + trainer.train() + trainer.save_model(args.output_dir) diff --git a/tests/experimental/test_gold_trainer.py b/tests/experimental/test_gold_trainer.py index d7e32056323..76df35cc03c 100644 --- a/tests/experimental/test_gold_trainer.py +++ b/tests/experimental/test_gold_trainer.py @@ -16,12 +16,13 @@ import pytest import torch -from datasets import load_dataset -from transformers import AutoTokenizer +from datasets import Dataset, load_dataset +from transformers import AutoProcessor, AutoTokenizer from trl.experimental.gold import gold_trainer as gold_trainer_module from trl.experimental.gold.gold_trainer import GOLDTrainer, ULDLoss, build_teacher_inputs_from_texts -from trl.experimental.utils import DataCollatorForChatML +from trl.experimental.utils import DataCollatorForChatML, DataCollatorForVisionLanguageChatML +from trl.trainer.utils import identity @pytest.fixture(scope="module") @@ -271,6 +272,35 @@ def smollm_tokenizer(): return tokenizer +@pytest.fixture(scope="session") +def smolvlm_processor(): + processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct") + if processor.tokenizer.pad_token is None: + processor.tokenizer.pad_token = processor.tokenizer.eos_token + return processor + + +@pytest.fixture(scope="session") +def qwen3_vl_processor(): + processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-2B-Instruct") + if processor.tokenizer.pad_token is None: + 
processor.tokenizer.pad_token = processor.tokenizer.eos_token + return processor + + +@pytest.fixture(scope="module") +def vlm_examples(): + try: + dataset = load_dataset( + "trl-internal-testing/zen-image", + "conversational_prompt_completion", + split="train[:3]", + ) + except Exception as exc: # pragma: no cover - network/environment dependent + pytest.skip(f"zen-image dataset unavailable: {exc}") + return [dict(row) for row in dataset] + + def encode_prompt_completion(tokenizer, prompt, completion): prompt_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"] completion_ids = tokenizer(completion, add_special_tokens=False)["input_ids"] @@ -302,6 +332,7 @@ def batch_decode(self, sequences, skip_special_tokens=False, clean_up_tokenizati trainer = GOLDTrainer.__new__(GOLDTrainer) trainer.accelerator = SimpleNamespace(device=torch.device("cpu")) trainer.processing_class = RecordingTokenizer() + trainer.pad_token_id = RecordingTokenizer.pad_token_id # __new__ bypasses __init__, set manually trainer.args = SimpleNamespace(max_length=None) trainer._buffered_inputs = [None] trainer._buffered_text_logs = [None] @@ -449,6 +480,7 @@ def batch_decode(self, sequences, skip_special_tokens=False, clean_up_tokenizati trainer = GOLDTrainer.__new__(GOLDTrainer) trainer.accelerator = SimpleNamespace(device=torch.device("cpu"), is_main_process=True) trainer.processing_class = RecordingTokenizer() + trainer.pad_token_id = RecordingTokenizer.pad_token_id # __new__ bypasses __init__, set manually trainer.args = SimpleNamespace(max_length=None, report_to=[]) trainer.use_vllm = True trainer.vllm_generation = RecordingVLLMGeneration() @@ -498,6 +530,9 @@ def resize_token_embeddings(self, vocab_size): self.resized_to = vocab_size class DummyProcessingClass: + # GOLDTrainer.__init__ extracts tokenizer pad token (like GRPOTrainer), + # so the dummy must provide both pad_token and pad_token_id. 
+ pad_token = "" pad_token_id = 0 def fake_sft_init( @@ -575,12 +610,17 @@ def __init__(self, **kwargs): vllm_sync_frequency=1, ) + # A minimal dataset is required because GOLDTrainer inspects the first example at init + # to detect whether the dataset contains images for VLM, so None dosn't pass + dummy_dataset = Dataset.from_dict({"messages": [["dummy"]]}) + teacher_model = DummyTeacherModel() GOLDTrainer( model=DummyStudentModel(), teacher_model=teacher_model, args=args, data_collator=object(), + train_dataset=dummy_dataset, processing_class=DummyProcessingClass(), ) @@ -942,3 +982,764 @@ def test_uldloss_hybrid_config_beta_zero(llama_tokenizer, qwen_tokenizer): expected = config.uld_hybrid_unmatched_weight * loss_fn.last_unmatched_loss torch.testing.assert_close(loss, expected, atol=1e-6, rtol=1e-5) + + +# ────────────────────────────────────────────────────────────────────────────── +# VLM tests +# ────────────────────────────────────────────────────────────────────────────── + + +def test_vlm_alignment_groups_cover_all_tokens_smolvlm_qwen3vl(smolvlm_processor, qwen3_vl_processor, vlm_examples): + student_tokenizer = smolvlm_processor.tokenizer + teacher_tokenizer = qwen3_vl_processor.tokenizer + + collator = DataCollatorForVisionLanguageChatML(processor=smolvlm_processor, max_length=2048) + batch = collator(vlm_examples) + + config = build_config() + loss = ULDLoss(config, student_tokenizer=student_tokenizer, teacher_tokenizer=teacher_tokenizer) + + teacher_input_ids, teacher_labels, _ = _teacher_inputs_from_collator(student_tokenizer, teacher_tokenizer, batch) + + _assert_alignment_covers_completion(loss, batch, teacher_input_ids, teacher_labels) + + +def test_gold_trainer_init_rejects_llm_with_vision_dataset(monkeypatch): + """GOLDTrainer should raise ValueError when a text-only model receives a vision dataset.""" + + class DummyStudentModel: + def __init__(self): + self.config = SimpleNamespace(_name_or_path="student", vocab_size=17) + 
self.generation_config = SimpleNamespace(eos_token_id=2) + self.name_or_path = "student" + + class DummyTeacherModel: + def __init__(self): + self.resized_to = None + + def resize_token_embeddings(self, vocab_size): + self.resized_to = vocab_size + + def fake_sft_init( + self, + model, + args=None, + data_collator=None, + train_dataset=None, + eval_dataset=None, + processing_class=None, + compute_metrics=None, + callbacks=None, + optimizers=None, + preprocess_logits_for_metrics=None, + peft_config=None, + ): + del data_collator, train_dataset, eval_dataset, compute_metrics, callbacks, optimizers + del preprocess_logits_for_metrics, peft_config + self.model = model + self.args = args + self.processing_class = processing_class + self.accelerator = SimpleNamespace( + device=torch.device("cpu"), + num_processes=1, + prepare_model=lambda module, evaluation_mode=True: module, + ) + self.is_deepspeed_enabled = False + self.is_fsdp_enabled = False + + monkeypatch.setattr(gold_trainer_module.SFTTrainer, "__init__", fake_sft_init) + + tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM3-3B") + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # Dataset with an "image" key triggers vision detection + vision_dataset = Dataset.from_dict({"messages": [["dummy"]], "image": ["fake_image"]}) + + args = SimpleNamespace( + model_init_kwargs=None, + max_length=128, + use_liger_kernel=False, + teacher_model_init_kwargs=None, + use_uld_loss=False, + teacher_tokenizer_name_or_path=None, + teacher_model_revision=None, + disable_dropout=False, + lmbda=1.0, + beta=0.5, + temperature=1.0, + top_p=1.0, + seq_kd=False, + num_generations=1, + use_transformers_paged=False, + max_completion_length=16, + top_k=0, + log_completions=False, + log_completions_steps=100, + wandb_log_unique_prompts=True, + num_completions_to_print=None, + per_device_train_batch_size=1, + gradient_accumulation_steps=1, + use_vllm=False, + ) + + with pytest.raises(ValueError, 
match="vision-related"): + GOLDTrainer( + model=DummyStudentModel(), + teacher_model=DummyTeacherModel(), + args=args, + train_dataset=vision_dataset, + processing_class=tokenizer, + ) + + +def _get_assistant_texts(examples): + """Extract assistant text content from examples, handling both plain string and multimodal format.""" + texts = [] + for example in examples: + content = example["completion"][-1]["content"] + if isinstance(content, list): + texts.append("".join(part["text"] for part in content if "text" in part)) + else: + texts.append(content) + return texts + + +def test_vlm_chatml_collator_preserves_completion_smolvlm(smolvlm_processor, qwen3_vl_processor, vlm_examples): + # 2048 to not truncate the completion tokens + collator = DataCollatorForVisionLanguageChatML(processor=smolvlm_processor, max_length=2048) + batch = collator(vlm_examples) + + # Verify basic batch structure + assert "input_ids" in batch + assert "labels" in batch + assert "prompts" in batch + assert "prompt_attention_mask" in batch + assert "pixel_values" in batch + assert "original_prompt_text" in batch + assert "original_completion_text" in batch + + # Verify completions are preserved in decoded output + assistant_texts = _get_assistant_texts(vlm_examples) + decoded_batch = smolvlm_processor.tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=False) + for decoded, assistant in zip(decoded_batch, assistant_texts, strict=True): + assert assistant in decoded + + # Verify ULD cross-tokenizer distillation with teacher inputs + student_tokenizer = smolvlm_processor.tokenizer + teacher_tokenizer = qwen3_vl_processor.tokenizer + + teacher_input_ids, teacher_labels, completion_texts = _teacher_inputs_from_collator( + student_tokenizer, teacher_tokenizer, batch + ) + for completion, assistant in zip(completion_texts, assistant_texts, strict=True): + assert assistant.strip() in completion + assert completion.strip() + + config = build_config( + uld_use_hybrid_loss=True, + 
uld_hybrid_matched_weight=0.6, + uld_hybrid_unmatched_weight=0.4, + ) + loss_fn = ULDLoss(config, student_tokenizer=student_tokenizer, teacher_tokenizer=teacher_tokenizer) + + _assert_alignment_covers_completion(loss_fn, batch, teacher_input_ids, teacher_labels) + + torch.manual_seed(42) + student_vocab = len(student_tokenizer) + teacher_vocab = len(teacher_tokenizer) + batch_size, seq_len = batch["input_ids"].shape + student_logits = torch.randn(batch_size, seq_len, student_vocab) + teacher_logits = torch.randn(batch_size, teacher_input_ids.shape[1], teacher_vocab) + + loss = loss_fn( + student_logits=student_logits, + teacher_logits=teacher_logits, + student_labels=batch["labels"], + teacher_labels=teacher_labels, + student_input_ids=batch["input_ids"], + teacher_input_ids=teacher_input_ids, + ) + + assert torch.isfinite(loss) + + +@pytest.mark.slow +def test_vlm_chatml_collator_preserves_completion_qwen3vl(smolvlm_processor, qwen3_vl_processor, vlm_examples): + collator = DataCollatorForVisionLanguageChatML(processor=qwen3_vl_processor, max_length=2048) + batch = collator(vlm_examples) + + # Verify basic batch structure + assert "input_ids" in batch + assert "labels" in batch + assert "prompts" in batch + assert "pixel_values" in batch + + # Verify completions are preserved in decoded output + assistant_texts = _get_assistant_texts(vlm_examples) + decoded_batch = qwen3_vl_processor.tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=False) + for decoded, assistant in zip(decoded_batch, assistant_texts, strict=True): + assert assistant in decoded + + # Verify ULD cross-tokenizer distillation with teacher inputs + student_tokenizer = qwen3_vl_processor.tokenizer + teacher_tokenizer = smolvlm_processor.tokenizer + + teacher_input_ids, teacher_labels, completion_texts = _teacher_inputs_from_collator( + student_tokenizer, teacher_tokenizer, batch + ) + for completion, assistant in zip(completion_texts, assistant_texts, strict=True): + assert 
assistant.strip() in completion + assert completion.strip() + + config = build_config( + uld_use_hybrid_loss=True, + uld_hybrid_matched_weight=0.6, + uld_hybrid_unmatched_weight=0.4, + ) + loss_fn = ULDLoss(config, student_tokenizer=student_tokenizer, teacher_tokenizer=teacher_tokenizer) + + _assert_alignment_covers_completion(loss_fn, batch, teacher_input_ids, teacher_labels) + + torch.manual_seed(43) + student_vocab = len(student_tokenizer) + teacher_vocab = len(teacher_tokenizer) + batch_size, seq_len = batch["input_ids"].shape + student_logits = torch.randn(batch_size, seq_len, student_vocab) + teacher_logits = torch.randn(batch_size, teacher_input_ids.shape[1], teacher_vocab) + + loss = loss_fn( + student_logits=student_logits, + teacher_logits=teacher_logits, + student_labels=batch["labels"], + teacher_labels=teacher_labels, + student_input_ids=batch["input_ids"], + teacher_input_ids=teacher_input_ids, + ) + + assert torch.isfinite(loss) + + +def test_vlm_collator_label_masking(smolvlm_processor, vlm_examples): + """Verify that the VLM collator masks prompt and padding tokens in labels and leaves completion tokens unmasked.""" + collator = DataCollatorForVisionLanguageChatML(processor=smolvlm_processor, max_length=2048) + batch = collator(vlm_examples) + + input_ids = batch["input_ids"] + labels = batch["labels"] + attention_mask = batch["attention_mask"] + + for i in range(input_ids.shape[0]): + # Padding tokens (attention_mask == 0) must be masked in labels + padding_positions = attention_mask[i] == 0 + assert (labels[i][padding_positions] == -100).all(), "Padding tokens should be masked with -100" + + # There must be at least one non-masked label (completion token) + completion_positions = labels[i] != -100 + assert completion_positions.any(), "Each example must have at least one completion token in labels" + + # Completion labels must match the corresponding input_ids + assert (labels[i][completion_positions] == input_ids[i][completion_positions]).all(), 
( + "Unmasked labels must match input_ids" + ) + + # Prompt tokens (attended but masked in labels) must exist — the prompt is never empty + prompt_positions = (attention_mask[i] == 1) & (labels[i] == -100) + assert prompt_positions.any(), "Each example must have masked prompt tokens" + + +def test_gold_trainer_init_rejects_non_vlm_teacher(monkeypatch): + """GOLDTrainer should raise ValueError when the student is a VLM but the teacher is not.""" + + class DummyStudentModel: + def __init__(self): + self.config = SimpleNamespace(_name_or_path="student", vocab_size=17) + self.generation_config = SimpleNamespace(eos_token_id=2) + self.name_or_path = "student" + + class DummyTeacherModel: + def __init__(self): + # No vision_config — looks like a text-only model + self.config = SimpleNamespace() + self.resized_to = None + + def resize_token_embeddings(self, vocab_size): + self.resized_to = vocab_size + + def fake_sft_init( + self, + model, + args=None, + data_collator=None, + train_dataset=None, + eval_dataset=None, + processing_class=None, + compute_metrics=None, + callbacks=None, + optimizers=None, + preprocess_logits_for_metrics=None, + peft_config=None, + ): + del data_collator, train_dataset, eval_dataset, compute_metrics, callbacks, optimizers + del preprocess_logits_for_metrics, peft_config + self.model = model + self.args = args + self.processing_class = processing_class + self.accelerator = SimpleNamespace( + device=torch.device("cpu"), + num_processes=1, + prepare_model=lambda module, evaluation_mode=True: module, + ) + self.is_deepspeed_enabled = False + self.is_fsdp_enabled = False + + monkeypatch.setattr(gold_trainer_module.SFTTrainer, "__init__", fake_sft_init) + + processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct") + + vision_dataset = Dataset.from_dict({"messages": [["dummy"]], "image": ["fake_image"]}) + + args = SimpleNamespace( + model_init_kwargs=None, + max_length=128, + use_liger_kernel=False, + 
teacher_model_init_kwargs=None, + use_uld_loss=False, + teacher_tokenizer_name_or_path=None, + teacher_model_revision=None, + disable_dropout=False, + lmbda=1.0, + beta=0.5, + temperature=1.0, + top_p=1.0, + seq_kd=False, + num_generations=1, + use_transformers_paged=False, + max_completion_length=16, + top_k=0, + log_completions=False, + log_completions_steps=100, + wandb_log_unique_prompts=True, + num_completions_to_print=None, + per_device_train_batch_size=1, + gradient_accumulation_steps=1, + use_vllm=False, + ) + + with pytest.raises(ValueError, match="VLM distillation requires both student and teacher"): + GOLDTrainer( + model=DummyStudentModel(), + teacher_model=DummyTeacherModel(), + args=args, + train_dataset=vision_dataset, + processing_class=processor, + ) + + +def test_gold_trainer_vlm_vllm_init_uses_identity_collator(monkeypatch): + """When a VLM processor is used with lmbda > 0 and use_vllm=True, GOLDTrainer should use the identity collator + and store a _vlm_collator for on-the-fly collation. 
vLLM should be initialized with max_model_length from args.""" + captured = {} + + class DummyStudentModel: + def __init__(self): + self.config = SimpleNamespace( + _name_or_path="student", vocab_size=17, vision_config=True, model_type="dummy_vlm" + ) + self.generation_config = SimpleNamespace(eos_token_id=2) + self.name_or_path = "student" + + class DummyTeacherModel: + def __init__(self): + self.config = SimpleNamespace(vision_config=True, model_type="dummy_vlm") + self.resized_to = None + + def resize_token_embeddings(self, vocab_size): + self.resized_to = vocab_size + + def fake_sft_init( + self, + model, + args=None, + data_collator=None, + train_dataset=None, + eval_dataset=None, + processing_class=None, + compute_metrics=None, + callbacks=None, + optimizers=None, + preprocess_logits_for_metrics=None, + peft_config=None, + ): + self.data_collator = data_collator + del train_dataset, eval_dataset, compute_metrics, callbacks, optimizers + del preprocess_logits_for_metrics, peft_config + self.model = model + self.args = args + self.processing_class = processing_class + self.accelerator = SimpleNamespace( + device=torch.device("cpu"), + num_processes=1, + prepare_model=lambda module, evaluation_mode=True: module, + ) + self.is_deepspeed_enabled = False + self.is_fsdp_enabled = False + + class CapturingVLLMGeneration: + def __init__(self, **kwargs): + captured.update(kwargs) + + monkeypatch.setattr(gold_trainer_module.SFTTrainer, "__init__", fake_sft_init) + monkeypatch.setattr(gold_trainer_module, "is_vllm_available", lambda: True) + monkeypatch.setattr(gold_trainer_module, "VLLMGeneration", CapturingVLLMGeneration) + + processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct") + if processor.tokenizer.pad_token is None: + processor.tokenizer.pad_token = processor.tokenizer.eos_token + + vision_dataset = Dataset.from_dict({"messages": [["dummy"]], "image": ["fake_image"]}) + + args = SimpleNamespace( + model_init_kwargs=None, + 
max_length=128, + use_liger_kernel=False, + teacher_model_init_kwargs=None, + use_uld_loss=False, + teacher_tokenizer_name_or_path=None, + teacher_model_revision=None, + disable_dropout=False, + lmbda=1.0, + beta=0.5, + temperature=1.0, + top_p=1.0, + seq_kd=False, + num_generations=1, + use_transformers_paged=False, + max_completion_length=16, + top_k=0, + log_completions=False, + log_completions_steps=100, + wandb_log_unique_prompts=True, + num_completions_to_print=None, + per_device_train_batch_size=1, + gradient_accumulation_steps=1, + use_vllm=True, + vllm_mode="colocate", + vllm_structured_outputs_regex=None, + vllm_server_base_url=None, + vllm_server_host="0.0.0.0", + vllm_server_port=8001, + vllm_group_port=51216, + vllm_server_timeout=240.0, + vllm_tensor_parallel_size=1, + vllm_gpu_memory_utilization=0.2, + vllm_max_model_length=None, + vllm_enable_sleep_mode=False, + vllm_model_impl="vllm", + vllm_sync_frequency=1, + ) + + teacher_model = DummyTeacherModel() + trainer = GOLDTrainer( + model=DummyStudentModel(), + teacher_model=teacher_model, + args=args, + train_dataset=vision_dataset, + processing_class=processor, + ) + + # Same assertions as text-only vLLM test + assert teacher_model.resized_to == 17 + assert captured["max_model_length"] == 128 + + # VLM-specific: identity collator + _vlm_collator for on-the-fly use + assert trainer.data_collator is identity + assert trainer._vlm_collator is not None + assert isinstance(trainer._vlm_collator, DataCollatorForVisionLanguageChatML) + + +def _make_dummy_vlm_models(student_model_type, teacher_model_type): + """Helper to create dummy student/teacher VLM models with specified model_type.""" + + class DummyStudentModel: + def __init__(self): + self.config = SimpleNamespace( + _name_or_path="student", vocab_size=17, vision_config=True, model_type=student_model_type + ) + self.generation_config = SimpleNamespace(eos_token_id=2) + self.name_or_path = "student" + + class DummyTeacherModel: + def __init__(self): + 
self.config = SimpleNamespace(_name_or_path="teacher", vision_config=True, model_type=teacher_model_type) + self.resized_to = None + + def resize_token_embeddings(self, vocab_size): + self.resized_to = vocab_size + + return DummyStudentModel(), DummyTeacherModel() + + +def _make_vlm_trainer_args(use_vllm=False): + """Helper to create minimal GOLDTrainer args for VLM tests.""" + return SimpleNamespace( + model_init_kwargs=None, + max_length=128, + use_liger_kernel=False, + teacher_model_init_kwargs=None, + use_uld_loss=False, + teacher_tokenizer_name_or_path=None, + teacher_model_revision=None, + disable_dropout=False, + lmbda=0.5, + beta=0.5, + temperature=1.0, + top_p=1.0, + seq_kd=False, + num_generations=1, + use_transformers_paged=False, + max_completion_length=16, + top_k=0, + log_completions=False, + log_completions_steps=100, + wandb_log_unique_prompts=True, + num_completions_to_print=None, + per_device_train_batch_size=1, + gradient_accumulation_steps=1, + use_vllm=use_vllm, + vllm_mode="colocate", + vllm_structured_outputs_regex=None, + vllm_server_base_url=None, + vllm_server_host="0.0.0.0", + vllm_server_port=8001, + vllm_group_port=51216, + vllm_server_timeout=240.0, + vllm_tensor_parallel_size=1, + vllm_gpu_memory_utilization=0.2, + vllm_max_model_length=None, + vllm_enable_sleep_mode=False, + vllm_model_impl="vllm", + vllm_sync_frequency=1, + # ULD-specific defaults (needed when use_uld_loss=True) + uld_crossentropy_weight=0.5, + uld_distillation_weight=0.5, + uld_student_temperature=1.0, + uld_teacher_temperature=1.0, + uld_skip_student_eos=False, + uld_skip_teacher_eos=False, + use_extended_uld=False, + ) + + +def test_cross_architecture_vlm_without_uld_raises_error(monkeypatch): + """When student and teacher have different model_type and use_uld_loss=False, GOLDTrainer should raise + a ValueError telling the user to enable ULD loss.""" + + def fake_sft_init( + self, + model, + args=None, + data_collator=None, + train_dataset=None, + 
eval_dataset=None, + processing_class=None, + compute_metrics=None, + callbacks=None, + optimizers=None, + preprocess_logits_for_metrics=None, + peft_config=None, + ): + self.data_collator = data_collator + self.model = model + self.args = args + self.processing_class = processing_class + self.accelerator = SimpleNamespace( + device=torch.device("cpu"), + num_processes=1, + prepare_model=lambda module, evaluation_mode=True: module, + ) + self.is_deepspeed_enabled = False + self.is_fsdp_enabled = False + + monkeypatch.setattr(gold_trainer_module.SFTTrainer, "__init__", fake_sft_init) + + processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct") + if processor.tokenizer.pad_token is None: + processor.tokenizer.pad_token = processor.tokenizer.eos_token + + sentinel_processor = SimpleNamespace(_is_sentinel=True) + real_auto_processor_from_pretrained = AutoProcessor.from_pretrained + + def patched_auto_processor(name, **kwargs): + if name == "teacher": + return sentinel_processor + return real_auto_processor_from_pretrained(name, **kwargs) + + monkeypatch.setattr(gold_trainer_module.AutoProcessor, "from_pretrained", staticmethod(patched_auto_processor)) + + vision_dataset = Dataset.from_dict({"messages": [["dummy"]], "image": ["fake_image"]}) + student, teacher = _make_dummy_vlm_models("smolvlm", "qwen2_5_vl") + args = _make_vlm_trainer_args() # use_uld_loss=False by default + + with pytest.raises(ValueError, match="Cross-architecture VLM distillation.*use_uld_loss=True"): + GOLDTrainer( + model=student, + teacher_model=teacher, + args=args, + train_dataset=vision_dataset, + processing_class=processor, + ) + + +def test_cross_architecture_vlm_with_uld_sets_teacher_processor(monkeypatch): + """When student and teacher have different model_type and use_uld_loss=True, GOLDTrainer should store + a separate _teacher_processor and emit a warning.""" + + def fake_sft_init( + self, + model, + args=None, + data_collator=None, + train_dataset=None, + 
eval_dataset=None, + processing_class=None, + compute_metrics=None, + callbacks=None, + optimizers=None, + preprocess_logits_for_metrics=None, + peft_config=None, + ): + self.data_collator = data_collator + self.model = model + self.args = args + self.processing_class = processing_class + self.accelerator = SimpleNamespace( + device=torch.device("cpu"), + num_processes=1, + prepare_model=lambda module, evaluation_mode=True: module, + ) + self.is_deepspeed_enabled = False + self.is_fsdp_enabled = False + + monkeypatch.setattr(gold_trainer_module.SFTTrainer, "__init__", fake_sft_init) + + processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct") + if processor.tokenizer.pad_token is None: + processor.tokenizer.pad_token = processor.tokenizer.eos_token + + sentinel_processor = SimpleNamespace(_is_sentinel=True) + real_auto_processor_from_pretrained = AutoProcessor.from_pretrained + + def patched_auto_processor(name, **kwargs): + if name == "teacher": + return sentinel_processor + return real_auto_processor_from_pretrained(name, **kwargs) + + monkeypatch.setattr(gold_trainer_module.AutoProcessor, "from_pretrained", staticmethod(patched_auto_processor)) + + # Monkeypatch AutoTokenizer.from_pretrained for ULD teacher tokenizer loading + sentinel_tokenizer = SimpleNamespace(pad_token="", eos_token="") + real_auto_tokenizer_from_pretrained = AutoTokenizer.from_pretrained + + def patched_auto_tokenizer(name, **kwargs): + if name == "teacher": + return sentinel_tokenizer + return real_auto_tokenizer_from_pretrained(name, **kwargs) + + monkeypatch.setattr(gold_trainer_module.AutoTokenizer, "from_pretrained", staticmethod(patched_auto_tokenizer)) + + vision_dataset = Dataset.from_dict({"messages": [["dummy"]], "image": ["fake_image"]}) + student, teacher = _make_dummy_vlm_models("smolvlm", "qwen2_5_vl") + args = _make_vlm_trainer_args() + args.use_uld_loss = True + args.teacher_tokenizer_name_or_path = "teacher" + + import warnings + + with 
warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + trainer = GOLDTrainer( + model=student, + teacher_model=teacher, + args=args, + train_dataset=vision_dataset, + processing_class=processor, + ) + + # _teacher_processor should be set for cross-architecture + assert trainer._teacher_processor is not None + assert trainer._teacher_processor is sentinel_processor + + # A cross-architecture warning should have been emitted + cross_arch_warnings = [w for w in caught if "Cross-architecture VLM distillation" in str(w.message)] + assert len(cross_arch_warnings) == 1 + assert "smolvlm" in str(cross_arch_warnings[0].message) + assert "qwen2_5_vl" in str(cross_arch_warnings[0].message) + + # Identity collator and VLM collator should still be set + assert trainer.data_collator is identity + assert trainer._vlm_collator is not None + + +def test_same_architecture_vlm_no_teacher_processor(monkeypatch): + """When student and teacher have the same model_type, GOLDTrainer should NOT store a _teacher_processor + (zero overhead -- both models share the same forward_kwargs).""" + + def fake_sft_init( + self, + model, + args=None, + data_collator=None, + train_dataset=None, + eval_dataset=None, + processing_class=None, + compute_metrics=None, + callbacks=None, + optimizers=None, + preprocess_logits_for_metrics=None, + peft_config=None, + ): + self.data_collator = data_collator + self.model = model + self.args = args + self.processing_class = processing_class + self.accelerator = SimpleNamespace( + device=torch.device("cpu"), + num_processes=1, + prepare_model=lambda module, evaluation_mode=True: module, + ) + self.is_deepspeed_enabled = False + self.is_fsdp_enabled = False + + monkeypatch.setattr(gold_trainer_module.SFTTrainer, "__init__", fake_sft_init) + + processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct") + if processor.tokenizer.pad_token is None: + processor.tokenizer.pad_token = processor.tokenizer.eos_token + + 
vision_dataset = Dataset.from_dict({"messages": [["dummy"]], "image": ["fake_image"]}) + student, teacher = _make_dummy_vlm_models("smolvlm", "smolvlm") + args = _make_vlm_trainer_args() + + import warnings + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + trainer = GOLDTrainer( + model=student, + teacher_model=teacher, + args=args, + train_dataset=vision_dataset, + processing_class=processor, + ) + + # _teacher_processor should be None for same architecture (zero overhead) + assert trainer._teacher_processor is None + + # No cross-architecture warning should have been emitted + cross_arch_warnings = [w for w in caught if "Cross-architecture VLM distillation" in str(w.message)] + assert len(cross_arch_warnings) == 0 + + # Identity collator and VLM collator should still be set + assert trainer.data_collator is identity + assert trainer._vlm_collator is not None diff --git a/trl/experimental/gold/gold_config.py b/trl/experimental/gold/gold_config.py index 1af9eeae332..09f6689292d 100644 --- a/trl/experimental/gold/gold_config.py +++ b/trl/experimental/gold/gold_config.py @@ -119,6 +119,15 @@ class GOLDConfig(SFTConfig): default=1e-7, metadata={"help": "The initial learning rate for AdamW."}, ) + # The default value remove_unused_columns is overwritten from the parent class, because in GOLD we usually rely on + # additional columns to compute the loss + remove_unused_columns: bool | None = field( + default=False, + metadata={ + "help": "Whether to only keep the columns 'prompt' and 'completion' in the dataset. If you use a custom " + "dataset that requires additional columns, you should keep this to `False`." 
+ }, + ) # GOLD-specific parameters temperature: float = field( diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index caaeb9cdc9e..ac2e574409c 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -29,7 +29,7 @@ from accelerate.utils import DistributedType, broadcast_object_list, gather_object from datasets import Dataset, IterableDataset from torch.utils.data import DataLoader -from transformers import AutoTokenizer, TrainerCallback +from transformers import AutoConfig, AutoProcessor, AutoTokenizer, TrainerCallback from transformers.data.data_collator import DataCollator from transformers.feature_extraction_utils import FeatureExtractionMixin from transformers.generation.configuration_utils import GenerationConfig @@ -47,7 +47,12 @@ is_rich_available, ) -from ...data_utils import is_conversational, maybe_convert_to_chatml, pack_dataset +from ...data_utils import ( + is_conversational, + maybe_convert_to_chatml, + pack_dataset, + prepare_multimodal_messages, +) from ...extras.profiling import profiling_decorator from ...generation.vllm_generation import VLLMGeneration from ...import_utils import is_vllm_available @@ -58,10 +63,12 @@ RepeatSampler, create_model_from_path, disable_dropout_in_model, + get_config_model_id, + identity, pad, split_tensor_dict, ) -from ..utils import DataCollatorForChatML, empty_cache, truncate_dataset +from ..utils import DataCollatorForChatML, DataCollatorForVisionLanguageChatML, empty_cache, truncate_dataset from .gold_config import GOLDConfig @@ -219,7 +226,7 @@ def build_teacher_inputs_from_texts( last_idx = valid.nonzero(as_tuple=True)[0][-1] teacher_attention_mask[row, last_idx + 1 :] = False - teacher_prompt_length = max(prompt_lengths) if prompt_lengths else 0 + teacher_prompt_length = min(prompt_lengths) if prompt_lengths else 0 return teacher_input_ids, teacher_labels, teacher_attention_mask, teacher_prompt_length @@ -778,14 +785,102 @@ def 
__init__( ): self.model_name_or_path = model if isinstance(model, str) else model.config._name_or_path self.model_revision = (args.model_init_kwargs or {}).get("revision") + if train_dataset is None: + raise ValueError("`train_dataset` is required") + dataset_sample = next(iter(train_dataset)) + if processing_class is None: + processing_class = AutoProcessor.from_pretrained(get_config_model_id(model.config)) + # simplified logic from SFTTrainer + # Handle pad token for processors or tokenizers + if isinstance(processing_class, ProcessorMixin): + tokenizer = processing_class.tokenizer + self._is_vlm = True + else: + tokenizer = processing_class + self._is_vlm = False + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + self.pad_token_id = tokenizer.pad_token_id + + # VLM distillation: only VLM-to-VLM is supported. Both student and teacher must be + # VLMs so that both receive images and multimodal inputs. + self._teacher_processor = None + if self._is_vlm and isinstance(teacher_model, str): + # Teacher not yet instantiated -- validate it's a VLM + teacher_proc = AutoProcessor.from_pretrained(teacher_model) + if not isinstance(teacher_proc, ProcessorMixin): + raise ValueError( + "VLM distillation requires both student and teacher to be vision-language models. " + "The student has a `ProcessorMixin` but the teacher does not." + ) + # Check for cross-architecture VLM distillation + student_model_type = model.config.model_type if not isinstance(model, str) else None + teacher_model_type = AutoConfig.from_pretrained(teacher_model).model_type + if student_model_type and teacher_model_type != student_model_type: + warnings.warn( + f"Cross-architecture VLM distillation detected: student is '{student_model_type}', " + f"teacher is '{teacher_model_type}'. Images will be processed separately through each " + "model's processor, which may increase memory usage and computation time." 
+ ) + self._teacher_processor = teacher_proc + elif self._is_vlm and not isinstance(teacher_model, str): + # Teacher already instantiated — check if it looks like a VLM by checking for a vision config + if not hasattr(teacher_model, "config") or not hasattr(teacher_model.config, "vision_config"): + raise ValueError( + "VLM distillation requires both student and teacher to be vision-language models. " + "The student has a `ProcessorMixin` but the teacher model does not appear to be a VLM " + "(missing `vision_config`)." + ) + # Check for cross-architecture VLM distillation + student_model_type = model.config.model_type if not isinstance(model, str) else None + teacher_model_type = teacher_model.config.model_type + if student_model_type and teacher_model_type != student_model_type: + warnings.warn( + f"Cross-architecture VLM distillation detected: student is '{student_model_type}', " + f"teacher is '{teacher_model_type}'. Images will be processed separately through each " + "model's processor, which may increase memory usage and computation time." + ) + self._teacher_processor = AutoProcessor.from_pretrained(teacher_model.config._name_or_path) + if self._teacher_processor is not None and not args.use_uld_loss: + raise ValueError( + "Cross-architecture VLM distillation (student and teacher have different `model_type`) is not " + "supported with the standard JSD loss because the models require different image token formats " + "and tokenizers. Please set `use_uld_loss=True` in your GOLDConfig to enable cross-tokenizer " + "alignment via ULD loss." + ) + self._is_vision_dataset = "image" in dataset_sample or "images" in dataset_sample + if self._is_vision_dataset and not self._is_vlm: + raise ValueError( + "The dataset appears to be vision-related (contains 'image' or 'images' keys), but the provided " + "model does not seem to be a vision-language model. Please check your model and dataset." 
+ ) - # Respect a user-provided data_collator; otherwise, provide a ChatML collator that + # Respect a user-provided data_collator; otherwise, pick the right collator based on modality. + # For VLMs, always use identity collator to preserve raw PIL images in the dataloader. + # Raw images are needed for: (1) vLLM generation, (2) cross-architecture teacher processing. + # A separate _vlm_collator is stored for on-the-fly collation inside _fill_buffer. + self._vlm_collator = None if data_collator is None: - data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length) + if self._is_vision_dataset: + self._vlm_collator = DataCollatorForVisionLanguageChatML( + processor=processing_class, + max_length=args.max_length, + ) + data_collator = identity + else: + data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length) # Liger fused GKD loss (JSD) self.use_liger_gkd_loss = False if args.use_liger_kernel: + if self._is_vlm: + raise ValueError( + "Liger fused GKD loss is not supported with VLMs. The fused kernel operates on base decoder " + "hidden states, which is incompatible with VLM multimodal inputs (pixel_values, etc.). " + "Please set `use_liger_kernel=False`." + ) self.liger_jsd_loss = LigerFusedLinearJSDLoss( beta=args.beta, ignore_index=-100, @@ -813,6 +908,8 @@ def __init__( if args.use_uld_loss and args.teacher_tokenizer_name_or_path is None: if isinstance(teacher_model, str): args.teacher_tokenizer_name_or_path = teacher_model + elif hasattr(teacher_model, "config") and getattr(teacher_model.config, "_name_or_path", None): + args.teacher_tokenizer_name_or_path = teacher_model.config._name_or_path else: raise ValueError( "`teacher_tokenizer_name_or_path` must be set when using ULD loss with a pre-instantiated teacher model." 
@@ -889,7 +986,7 @@ def __init__( if self.use_uld_loss: self.uld_loss_fn = ULDLoss( config=args, - student_tokenizer=processing_class, + student_tokenizer=tokenizer, teacher_tokenizer=self.teacher_tokenizer, device=self.accelerator.device, ) @@ -900,7 +997,7 @@ def __init__( "top_p": args.top_p, "do_sample": True, "top_k": args.top_k, - "pad_token_id": self.processing_class.pad_token_id, + "pad_token_id": self.pad_token_id, } self.generation_config = GenerationConfig(**generation_kwargs) # Keep training-specific generation kwargs to overwrite model's original generation config @@ -967,6 +1064,8 @@ def __init__( def _set_signature_columns_if_needed(self): super()._set_signature_columns_if_needed() required_columns = [ + "prompt", + "completion", "prompts", "prompt_attention_mask", "messages", @@ -974,6 +1073,15 @@ def _set_signature_columns_if_needed(self): "tools", "original_prompt_text", "original_completion_text", + "images", + "image", + "pixel_values", + "image_grid_thw", + "image_position_ids", + "pixel_attention_mask", + "image_sizes", + "token_type_ids", + "mm_token_type_ids", ] if self._signature_columns is None: self._signature_columns = required_columns @@ -1058,8 +1166,8 @@ def _decode_completion_texts_from_labels(self, slice_inputs: dict[str, torch.Ten decoded_completion_tokens: list[list[int]] = [] for row in labels_cpu: token_ids = row[row != -100].tolist() - if self.processing_class.pad_token_id is not None: - token_ids = [tok for tok in token_ids if tok != self.processing_class.pad_token_id] + if self.pad_token_id is not None: + token_ids = [tok for tok in token_ids if tok != self.pad_token_id] decoded_completion_tokens.append(token_ids) return self.processing_class.batch_decode( @@ -1122,8 +1230,16 @@ def _build_sequence_batch( return new_attention_mask, new_labels @profiling_decorator - def _fill_buffer(self, generation_batch: dict[str, torch.Tensor | Any], buffer_steps: int): - slices = split_tensor_dict(generation_batch, buffer_steps) + def 
_fill_buffer(self, generation_batch: dict[str, torch.Tensor | Any] | list[dict], buffer_steps: int): + if self._vlm_collator is not None: + # Identity collator path: generation_batch is list[dict] with raw PIL images. + # Split into chunks via list slicing, then collate on-the-fly per slice. + chunk_size = len(generation_batch) // buffer_steps + raw_slices = [generation_batch[i * chunk_size : (i + 1) * chunk_size] for i in range(buffer_steps)] + slices = None # not used in this path + else: + raw_slices = None # not used in this path + slices = split_tensor_dict(generation_batch, buffer_steps) if self.accelerator.is_main_process: on_policy_flags = [random.random() <= self.lmbda for _ in range(buffer_steps)] @@ -1139,7 +1255,33 @@ def _fill_buffer(self, generation_batch: dict[str, torch.Tensor | Any], buffer_s for i, flag in enumerate(on_policy_flags): if not flag: - slice_inputs = slices[i] + if self._vlm_collator is not None: + # Extract raw images and prompts BEFORE collation, since the collator + # mutates examples in place (pops "image", overwrites "prompt"). 
+ raw_images = None + raw_prompts = None + if self._teacher_processor is not None: + raw_images = [ + ex.get("images") or ([ex["image"]] if "image" in ex else None) for ex in raw_slices[i] + ] + raw_prompts = [ + prepare_multimodal_messages(ex["prompt"], images=imgs) + if imgs is not None + else ex.get("prompt") + for ex, imgs in zip(raw_slices[i], raw_images, strict=True) + ] + # Collate raw examples on-the-fly for off-policy slices + slice_inputs = self._vlm_collator(raw_slices[i]) + slice_inputs = { + k: v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v + for k, v in slice_inputs.items() + } + # Preserve raw PIL images and prompts for cross-architecture teacher processing + if self._teacher_processor is not None: + slice_inputs["_raw_images"] = raw_images + slice_inputs["_raw_prompts"] = raw_prompts + else: + slice_inputs = slices[i] if self.use_uld_loss and self.teacher_tokenizer is not None: slice_inputs = self._ensure_original_text_fields(slice_inputs) @@ -1153,7 +1295,10 @@ def _fill_buffer(self, generation_batch: dict[str, torch.Tensor | Any], buffer_s self._buffered_inputs[i] = slice_inputs if on_policy_indices: - self._generate_on_policy_for_slices(slices, on_policy_indices) + if self._vlm_collator is not None: + self._generate_on_policy_vlm_raw(raw_slices, on_policy_indices) + else: + self._generate_on_policy_for_slices(slices, on_policy_indices) @profiling_decorator def _generate_on_policy_for_slices( @@ -1190,6 +1335,8 @@ def _generate_on_policy_for_slices( self.vllm_generation.sync_weights() self._last_vllm_sync_step = self.state.global_step + # Text-only vLLM generation. VLM on-policy generation with raw images + # is handled by _generate_on_policy_vlm_raw (routed from _fill_buffer). 
_, completion_ids, _, _ = self.vllm_generation.generate( prompts=prompt_ids_list, images=None, @@ -1220,7 +1367,7 @@ def _generate_non_vllm_for_slices(self, slices: list[dict[str, torch.Tensor | An unwrapped_model, slice_inputs, self.generation_config, - self.processing_class.pad_token_id, + self.pad_token_id, ) new_input_ids, new_attention_mask, new_labels, prompt_texts, completion_texts = result @@ -1228,12 +1375,210 @@ def _generate_non_vllm_for_slices(self, slices: list[dict[str, torch.Tensor | An updated_slice["input_ids"] = new_input_ids updated_slice["attention_mask"] = new_attention_mask updated_slice["labels"] = new_labels + # Rebuild sequence-length-dependent keys to match new input_ids shape + new_seq_len = new_input_ids.shape[1] + prompt_seq_len = slice_inputs["prompts"].shape[1] + for k in ("token_type_ids", "mm_token_type_ids"): + if k in updated_slice: + prompt_part = updated_slice[k][:, :prompt_seq_len] + comp_part = torch.zeros( + new_input_ids.shape[0], + new_seq_len - prompt_seq_len, + dtype=updated_slice[k].dtype, + device=new_input_ids.device, + ) + updated_slice[k] = torch.cat([prompt_part, comp_part], dim=1) updated_slice["original_prompt_text"] = prompt_texts updated_slice["original_completion_text"] = completion_texts self._buffered_inputs[slice_idx] = updated_slice self._buffered_text_logs[slice_idx] = (prompt_texts, completion_texts) + def _generate_on_policy_vlm_raw(self, raw_slices: list[list[dict]], on_policy_indices: list[int]): + """On-policy generation from raw VLM examples, preserving PIL images for vLLM.""" + device = self.accelerator.device + + # Phase 1: Collect prompts, images, and raw examples across all on-policy slices + all_prompt_ids = [] + all_images = [] + all_prompts = [] # prepared multimodal messages + all_raw_examples = [] + local_slice_indices = [] + slice_raw_data = {} # per-slice raw data for non-vLLM path + + for slice_idx in on_policy_indices: + raw_examples = raw_slices[slice_idx] + + # Extract raw PIL images 
from examples (like GRPOTrainer) + if "images" in raw_examples[0]: + images = [example.get("images") for example in raw_examples] + elif "image" in raw_examples[0]: + images = [ + [example.get("image")] if example.get("image") is not None else None for example in raw_examples + ] + else: + images = None + if images is not None and all(img_list is None or img_list == [] for img_list in images): + images = None + + # Extract prompts and prepare multimodal messages + prompts = [example["prompt"] for example in raw_examples] + if images is not None: + prompts = [ + prepare_multimodal_messages(prompt, images=img_list) + for prompt, img_list in zip(prompts, images, strict=True) + ] + + # Normalize string content to content blocks for VLM processors that don't handle plain strings + # copied from GRPOTrainer + prompts = [ + [ + {**msg, "content": [{"type": "text", "text": msg["content"]}]} + if isinstance(msg.get("content"), str) + else msg + for msg in prompt + ] + for prompt in prompts + ] + + # Tokenize prompts to get prompt token IDs + # TODO: add self.tools support + tokenized = self.processing_class.apply_chat_template( + conversation=prompts, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + padding=True, + ) + prompt_ids_list = [ + [tok for tok, m in zip(ids, mask, strict=True) if m] + for ids, mask in zip(tokenized["input_ids"], tokenized["attention_mask"], strict=True) + ] + + slice_raw_data[slice_idx] = (raw_examples, images, prompts, prompt_ids_list) + + for i, example in enumerate(raw_examples): + all_prompt_ids.append(prompt_ids_list[i]) + all_images.append(images[i] if images is not None else None) + all_prompts.append(prompts[i]) + all_raw_examples.append(example) + local_slice_indices.append(slice_idx) + + all_prompts_text = self.processing_class.batch_decode(all_prompt_ids, skip_special_tokens=True) + all_prompts_text_with_special = self.processing_class.batch_decode(all_prompt_ids, skip_special_tokens=False) + + if not self.use_vllm: 
+ # Non-vLLM path: generate per-slice using model.generate + for slice_idx in on_policy_indices: + raw_examples, images, prompts, _ = slice_raw_data[slice_idx] + collated = self._vlm_collator(raw_examples) + collated = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in collated.items()} + with unwrap_model_for_generation( + self.model, self.accelerator, generation_kwargs=self.generation_kwargs + ) as unwrapped_model: + result = self.generate_on_policy_outputs( + unwrapped_model, collated, self.generation_config, self.pad_token_id + ) + new_input_ids, new_attention_mask, new_labels, prompt_texts, completion_texts = result + + updated_slice = dict(collated) + updated_slice["input_ids"] = new_input_ids + updated_slice["attention_mask"] = new_attention_mask + updated_slice["labels"] = new_labels + # Rebuild sequence-length-dependent keys to match new input_ids shape + new_seq_len = new_input_ids.shape[1] + prompt_seq_len = collated["prompts"].shape[1] + for k in ("token_type_ids", "mm_token_type_ids"): + if k in updated_slice: + prompt_part = updated_slice[k][:, :prompt_seq_len] + comp_part = torch.zeros( + new_input_ids.shape[0], + new_seq_len - prompt_seq_len, + dtype=updated_slice[k].dtype, + device=new_input_ids.device, + ) + updated_slice[k] = torch.cat([prompt_part, comp_part], dim=1) + updated_slice["original_prompt_text"] = prompt_texts + updated_slice["original_completion_text"] = completion_texts + if self._teacher_processor is not None: + updated_slice["_raw_images"] = images + updated_slice["_raw_prompts"] = prompts + + self._buffered_inputs[slice_idx] = updated_slice + self._buffered_text_logs[slice_idx] = (prompt_texts, completion_texts) + return + + # vLLM path: one batched generate call across all slices + if ( + self.state.global_step != self._last_vllm_sync_step + and self.state.global_step >= self._last_vllm_sync_step + self.vllm_sync_frequency + ): + self.vllm_generation.sync_weights() + self._last_vllm_sync_step = 
self.state.global_step + + # Pass None for images if all entries are None + generate_images = all_images if any(img is not None for img in all_images) else None + _, completion_ids, _, _ = self.vllm_generation.generate( + prompts=all_prompt_ids, + images=generate_images, + num_generations=self.num_generations, + ) + + # Decode completions + max_completion_length = self.generation_config.max_new_tokens + all_completion_texts = [] + for comp_ids in completion_ids: + if len(comp_ids) > max_completion_length: + comp_ids = comp_ids[:max_completion_length] + all_completion_texts.append( + self.processing_class.decode(comp_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) + ) + + # Redistribute completions to slices. With num_generations > 1, each prompt produces + # multiple completions, so we repeat each raw example/prompt/image to match. + slice_completions = {idx: [] for idx in on_policy_indices} + slice_raw = {idx: [] for idx in on_policy_indices} + slice_images = {idx: [] for idx in on_policy_indices} + slice_prompts = {idx: [] for idx in on_policy_indices} + slice_prompts_text = {idx: [] for idx in on_policy_indices} + slice_prompts_text_special = {idx: [] for idx in on_policy_indices} + + for i, slice_idx in enumerate(local_slice_indices): + for g in range(self.num_generations): + comp_idx = i * self.num_generations + g + slice_completions[slice_idx].append(all_completion_texts[comp_idx]) + slice_raw[slice_idx].append(all_raw_examples[i]) + slice_images[slice_idx].append(all_images[i]) + slice_prompts[slice_idx].append(all_prompts[i]) + slice_prompts_text[slice_idx].append(all_prompts_text[i]) + slice_prompts_text_special[slice_idx].append(all_prompts_text_with_special[i]) + + for slice_idx in on_policy_indices: + completion_texts = slice_completions[slice_idx] + raw_for_slice = slice_raw[slice_idx] + images_for_slice = slice_images[slice_idx] + prompts_for_slice = slice_prompts[slice_idx] + + # Build synthetic examples: original prompt + 
generated completion + synthetic_examples = [] + for i, example in enumerate(raw_for_slice): + synthetic = dict(example) + synthetic["completion"] = [{"role": "assistant", "content": completion_texts[i]}] + synthetic_examples.append(synthetic) + + # Collate synthetic examples to get pixel_values + properly tokenized input_ids/labels + collated = self._vlm_collator(synthetic_examples) + collated = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in collated.items()} + collated["original_prompt_text"] = slice_prompts_text_special[slice_idx] + collated["original_completion_text"] = completion_texts + if self._teacher_processor is not None: + has_images = any(img is not None for img in images_for_slice) + collated["_raw_images"] = images_for_slice if has_images else None + collated["_raw_prompts"] = prompts_for_slice + + self._buffered_inputs[slice_idx] = collated + self._buffered_text_logs[slice_idx] = (slice_prompts_text[slice_idx], completion_texts) + def _process_completions_to_buffer( self, slices: list[dict[str, torch.Tensor | Any]], @@ -1249,7 +1594,7 @@ def _process_completions_to_buffer( Process vLLM completions and update buffered inputs for on-policy slices. """ device = self.accelerator.device - pad_token_id = self.processing_class.pad_token_id if self.processing_class.pad_token_id is not None else 0 + pad_token_id = self.pad_token_id if self.pad_token_id is not None else 0 slice_completions = {idx: [] for idx in on_policy_indices} slice_prompt_ids = {idx: [] for idx in on_policy_indices} @@ -1367,6 +1712,11 @@ def _prepare_dataset( dataset_name: str, ) -> Dataset | IterableDataset: """Preserve original text fields for ULD when needed.""" + # For VLM datasets, skip dataset preparation entirely — the VLM collator handles tokenization + # and image processing on the fly, similar to how SFTTrainer skips prep for vision datasets. 
+ if self._is_vision_dataset: + return dataset + column_names = list(next(iter(dataset)).keys()) is_processed = "input_ids" in column_names @@ -1689,7 +2039,23 @@ def generalized_jsd_loss( else: return jsd + _MULTIMODAL_KEYS = ( + "pixel_values", + "image_grid_thw", + "image_position_ids", + "pixel_attention_mask", + "image_sizes", + "token_type_ids", + "mm_token_type_ids", + ) + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + # Extract multimodal fields for student forward passes + student_forward_kwargs = {k: inputs[k] for k in self._MULTIMODAL_KEYS if k in inputs} + # For same-architecture teacher reuses student vision tensors. + # For cross-architecture VLMs, this gets overridden in the ULD branch below. + teacher_forward_kwargs = student_forward_kwargs + if self.use_uld_loss and self.teacher_tokenizer is not None: if "original_prompt_text" in inputs and "original_completion_text" in inputs: prompt_texts = inputs["original_prompt_text"] @@ -1708,16 +2074,60 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N full.replace(prompt, "", 1) for full, prompt in zip(full_texts, prompt_texts, strict=True) ] - ( - teacher_input_ids, - teacher_labels, - teacher_attention_mask, - teacher_prompt_length, - ) = build_teacher_inputs_from_texts( - self.teacher_tokenizer, - prompt_texts, - completion_texts, - ) + # For cross-architecture VLMs, build teacher inputs with image placeholders by processing + # prompts through the teacher's processor with raw images, then appending completions. 
+ if self._teacher_processor is not None and "_raw_images" in inputs: + raw_images = inputs["_raw_images"] + raw_prompts = inputs["_raw_prompts"] + # Apply teacher's chat template to get prompt text with correct image placeholders + teacher_prompt_texts = self._teacher_processor.apply_chat_template( + raw_prompts, tokenize=False, add_generation_prompt=True + ) + # Build full text (prompt + completion) and process in one call so all tensors + # (input_ids, attention_mask, mm_token_type_ids, pixel_values, ...) are aligned. + teacher_full_texts = [p + c for p, c in zip(teacher_prompt_texts, completion_texts, strict=True)] + teacher_full_processed = self._teacher_processor( + images=raw_images, + text=teacher_full_texts, + padding=True, + return_tensors="pt", + ) + teacher_input_ids = teacher_full_processed["input_ids"] + teacher_attention_mask = teacher_full_processed["attention_mask"] + # Determine prompt lengths after image token expansion to build labels. + # Derive prompt lengths from total sequence length minus completion length. + # Completions are pure text (no images), so the tokenizer gives exact counts. + # This avoids a second image-processing pass through the teacher processor. 
+ teacher_completion_token_lengths = [ + len(self._teacher_processor.tokenizer(ct, add_special_tokens=False)["input_ids"]) + for ct in completion_texts + ] + total_lengths = teacher_attention_mask.sum(dim=1) + teacher_prompt_token_lengths = [ + int(total_lengths[i].item()) - cl for i, cl in enumerate(teacher_completion_token_lengths) + ] + teacher_labels = teacher_input_ids.clone() + teacher_labels[teacher_attention_mask == 0] = -100 + for i, pl in enumerate(teacher_prompt_token_lengths): + teacher_labels[i, :pl] = -100 + teacher_prompt_length = min(teacher_prompt_token_lengths) + # Override teacher_forward_kwargs with all multimodal keys from teacher processing + teacher_forward_kwargs = { + k: teacher_full_processed[k].to(self.accelerator.device) + for k in self._MULTIMODAL_KEYS + if k in teacher_full_processed + } + else: + ( + teacher_input_ids, + teacher_labels, + teacher_attention_mask, + teacher_prompt_length, + ) = build_teacher_inputs_from_texts( + self.teacher_tokenizer, + prompt_texts, + completion_texts, + ) teacher_input_ids = teacher_input_ids.to(self.accelerator.device) teacher_labels = teacher_labels.to(self.accelerator.device) @@ -1727,6 +2137,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], use_cache=False, + **student_forward_kwargs, ) self.teacher_model.eval() @@ -1734,10 +2145,16 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N outputs_teacher = self.teacher_model( input_ids=teacher_input_ids, attention_mask=teacher_attention_mask, + **teacher_forward_kwargs, ) # These are not used for ULD loss but are needed if JSD loss were to be used in this branch - student_prompt_length = inputs["prompts"].shape[1] + # For VLMs, prompts are left-padded but input_ids are flushed left, so prompts.shape[1] + # would overcount. Derive prompt length from the flushed labels instead. 
+ if self._is_vlm: + student_prompt_length = (inputs["labels"] != -100).long().argmax(dim=1).min().item() + else: + student_prompt_length = inputs["prompts"].shape[1] shifted_student_logits = outputs_student.logits[:, student_prompt_length - 1 : -1, :] shifted_teacher_logits = outputs_teacher.logits[:, teacher_prompt_length - 1 : -1, :] shifted_labels = inputs["labels"][:, student_prompt_length:] @@ -1756,6 +2173,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], use_cache=False, + **student_forward_kwargs, ) self.teacher_model.eval() @@ -1771,6 +2189,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], use_cache=False, + **teacher_forward_kwargs, ) student_hidden = student_outputs.last_hidden_state[:, :-1] @@ -1805,6 +2224,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N outputs_student = model( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], + **student_forward_kwargs, ) self.teacher_model.eval() @@ -1812,9 +2232,14 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N outputs_teacher = self.teacher_model( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], + **teacher_forward_kwargs, ) - - prompt_lengths = inputs["prompts"].shape[1] + # Using the same prompt_lengths for teacher and student, since JSD can only be + # used with same-family VLMs (shared tokenizer). 
+ if self._is_vlm: + prompt_lengths = (inputs["labels"] != -100).long().argmax(dim=1).min().item() + else: + prompt_lengths = inputs["prompts"].shape[1] shifted_student_logits = outputs_student.logits[:, prompt_lengths - 1 : -1, :] shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths - 1 : -1, :] shifted_labels = inputs["labels"][:, prompt_lengths:] @@ -1833,8 +2258,8 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N teacher_input_ids_for_loss = teacher_input_ids if "teacher_input_ids" in locals() else inputs["input_ids"] student_labels = inputs["labels"].clone() - if hasattr(self.processing_class, "pad_token_id") and self.processing_class.pad_token_id is not None: - student_labels[student_labels == self.processing_class.pad_token_id] = -100 + if self.pad_token_id is not None: + student_labels[student_labels == self.pad_token_id] = -100 if ( hasattr(self, "teacher_tokenizer") @@ -1898,11 +2323,19 @@ def generate_on_policy_outputs(self, model, inputs, generation_config, pad_token completion_ids = [output.generated_tokens for output in generated_outputs.values()] generated_tokens = torch.stack([torch.tensor(ids, device=model.device) for ids in completion_ids]) else: + generate_kwargs = {k: inputs[k] for k in self._MULTIMODAL_KEYS if k in inputs} + # Slice sequence-length-dependent keys to prompt-only length (e.g. 
token_type_ids for Gemma, + # mm_token_type_ids for ERNIE-VL) since model.generate receives prompt-only input_ids + prompt_seq_len = inputs["prompts"].shape[1] + for k in ("token_type_ids", "mm_token_type_ids"): + if k in generate_kwargs: + generate_kwargs[k] = generate_kwargs[k][:, :prompt_seq_len] generated_outputs = model.generate( input_ids=inputs["prompts"], attention_mask=inputs.get("prompt_attention_mask", None), generation_config=generation_config, return_dict_in_generate=True, + **generate_kwargs, ) # Get the generated token IDs generated_tokens = generated_outputs.sequences @@ -1911,7 +2344,7 @@ def generate_on_policy_outputs(self, model, inputs, generation_config, pad_token device = generated_tokens.device prompt_mask = inputs.get("prompt_attention_mask") - pad_token_id = pad_token_id if pad_token_id is not None else self.processing_class.pad_token_id + pad_token_id = pad_token_id if pad_token_id is not None else self.pad_token_id if self.use_transformers_paged: # generate_batch() returns completion-only tokens, so the entire tensor is completion. 
@dataclass
class DataCollatorForVisionLanguageChatML(DataCollatorMixin):
    """
    Collator that prepares prompt/completion vision-language batches for GOLD training.

    Each example must provide `"prompt"` and `"completion"` keys (conversational messages or
    plain text) and may carry images under `"image"` (single image) or `"images"` (list).
    Prompts are processed together with the images and left-padded; completions are tokenized
    text-only and right-padded; the two are concatenated and flushed left to minimize padding.

    The returned batch contains:

    - `"input_ids"` / `"attention_mask"`: full prompt+completion sequences.
    - `"labels"`: copy of `input_ids` with padding and prompt positions masked to -100.
    - `"prompts"` / `"prompt_attention_mask"`: prompt-only tensors (left-padded), used for
      on-policy generation.
    - `"original_prompt_text"` / `"original_completion_text"`: raw text, used for ULD
      cross-tokenizer distillation.
    - Vision tensors emitted by the processor (e.g. `"pixel_values"`, `"image_grid_thw"`),
      plus `"token_type_ids"` / `"mm_token_type_ids"` when the processor produces them.

    Args:
        processor ([`~transformers.ProcessorMixin`]):
            Processor used for both tokenization and image preprocessing.
        max_length (`int` or `None`, *optional*):
            If set, joint sequences are right-truncated to this length.
        return_tensors (`str`, *optional*, defaults to `"pt"`):
            Tensor type returned by the processor.
    """

    processor: ProcessorMixin
    max_length: int | None = None
    return_tensors: str = "pt"

    def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
        first = examples[0]
        if "prompt" not in first or "completion" not in first:
            raise KeyError(
                "DataCollatorForVisionLanguageChatML requires 'prompt' and 'completion' keys in examples. "
                f"Got keys: {list(examples[0].keys())}."
            )

        # Normalize a single "image" entry into an "images" list.
        if "image" in first:
            for example in examples:
                example["images"] = [example.pop("image")]
        images = [example.get("images", []) for example in examples]
        if all(img_list == [] for img_list in images):
            images = None

        # Conversational data: inject image placeholders, then render with the chat template.
        # NOTE(review): this reads example["images"] directly, so a conversational batch with
        # no images at all (images is None above) would raise KeyError here — confirm callers
        # always supply images for conversational VLM data.
        if is_conversational(first):
            for example in examples:
                example["prompt"] = prepare_multimodal_messages(example["prompt"], images=example["images"])
                example["completion"] = prepare_multimodal_messages(example["completion"])
            examples = [apply_chat_template(example, self.processor) for example in examples]

        prompts = [example["prompt"] for example in examples]
        completions = [example["completion"] for example in examples]

        # Prompts go through the processor with the images and are left-padded so the
        # generation start is aligned; completions are tokenized text-only, right-padded.
        processed_prompts = self.processor(
            images=images,
            text=prompts,
            padding=True,
            padding_side="left",
            return_tensors=self.return_tensors,
            add_special_tokens=False,
        )
        processed_completions = self.processor(
            text=completions,
            padding=True,
            padding_side="right",
            return_tensors=self.return_tensors,
            add_special_tokens=False,
        )

        prompt_ids = processed_prompts["input_ids"]
        prompt_mask = processed_prompts["attention_mask"]
        completion_ids = processed_completions["input_ids"]
        completion_mask = processed_completions["attention_mask"]

        # Stitch prompt and completion together; completion_mask marks which positions of
        # the joint sequence belong to the completion.
        input_ids = torch.cat((prompt_ids, completion_ids), dim=1)
        attention_mask = torch.cat((prompt_mask, completion_mask), dim=1)
        completion_mask = torch.cat((torch.zeros_like(prompt_mask), completion_mask), dim=1)

        # Optional per-token streams some processors emit (token_type_ids for Gemma-style
        # models, mm_token_type_ids for ERNIE-VL). Kept in an ordered dict so flushing and
        # truncation below handle any combination uniformly.
        extras: dict[str, torch.Tensor] = {}
        if "token_type_ids" in processed_prompts:
            extras["token_type_ids"] = torch.cat(
                (processed_prompts["token_type_ids"], processed_completions["token_type_ids"]), dim=1
            )
        if "mm_token_type_ids" in processed_prompts:
            extras["mm_token_type_ids"] = torch.cat(
                (
                    processed_prompts["mm_token_type_ids"],
                    processed_completions.get("mm_token_type_ids", torch.zeros_like(completion_ids)),
                ),
                dim=1,
            )

        # Flush left to drop the shared left-padding introduced by the prompts.
        flushed = flush_left(attention_mask, input_ids, completion_mask, *extras.values())
        attention_mask, input_ids, completion_mask = flushed[:3]
        extras = dict(zip(extras.keys(), flushed[3:]))

        # Right-truncate every sequence-length-dependent tensor if a cap is configured.
        if self.max_length is not None:
            input_ids = input_ids[:, : self.max_length]
            attention_mask = attention_mask[:, : self.max_length]
            completion_mask = completion_mask[:, : self.max_length]
            extras = {key: value[:, : self.max_length] for key, value in extras.items()}

        # Labels supervise completion tokens only: padding and prompt positions become -100.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
        labels[completion_mask == 0] = -100

        # Reuse processed_prompts as the output container so vision keys
        # (pixel_values, image_grid_thw, ...) ride along untouched.
        output = processed_prompts
        output["input_ids"] = input_ids
        output["attention_mask"] = attention_mask
        output["labels"] = labels
        for key, value in extras.items():
            output[key] = value

        # GOLD-specific: prompt-only tensors for on-policy generation.
        output["prompts"] = prompt_ids
        output["prompt_attention_mask"] = prompt_mask

        # GOLD-specific: raw text for ULD cross-tokenizer distillation.
        output["original_prompt_text"] = prompts
        output["original_completion_text"] = completions

        return output