diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md index 9fe2857b08..c368fd3286 100644 --- a/docs/source/paper_index.md +++ b/docs/source/paper_index.md @@ -1563,13 +1563,13 @@ Papers relating to training a student model with the help of a teacher model. **📜 Paper**: https://huggingface.co/papers/2306.13649 -Introduces Generalized Knowledge Distillation (GKD), which addresses distribution mismatch in KD for auto-regressive models by training the student on its own generated outputs with teacher feedback, instead of a fixed set of sequences. GKD supports flexible loss functions (e.g. beyond KL when the student cannot match the teacher) and integrates with RL fine-tuning (RLHF). The paper reports results on summarization, translation, arithmetic reasoning, and instruction-tuning. Used in TRL via [`experimental.gkd.GKDTrainer`]. To reproduce the paper's setting, use this configuration: +Introduces Generalized Knowledge Distillation (GKD), which addresses distribution mismatch in KD for auto-regressive models by training the student on its own generated outputs with teacher feedback, instead of a fixed set of sequences. GKD supports flexible loss functions (e.g. beyond KL when the student cannot match the teacher) and integrates with RL fine-tuning (RLHF). The paper reports results on summarization, translation, arithmetic reasoning, and instruction-tuning. Used in TRL via [`experimental.distillation.DistillationTrainer`] and [`experimental.gkd.GKDTrainer`]. 
To reproduce the paper's setting, use this configuration: ```python -from trl.experimental.gkd import GKDConfig +from trl.experimental.distillation import DistillationConfig # XSum summarization task (Table A.1 of the paper) -training_args = GKDConfig( +training_args = DistillationConfig( lmbda=0.5, # λ student data fraction (Section 3 of the paper) beta=0.5, # β Generalized JSD interpolation, 0=KL, 1=reverse KL (Section 3 of the paper) temperature=1.0, # student training temperature (Appendix A of the paper) @@ -1577,7 +1577,7 @@ training_args = GKDConfig( learning_rate=3e-4, # learning rate (Table A.1 of the paper) per_device_train_batch_size=32, # batch size (Table A.1 of the paper) warmup_steps=2000, # warm-up steps (Table A.1 of the paper) - max_new_tokens=64, # max output tokens (Table A.1 of the paper) + max_completion_length=64, # max output tokens (Table A.1 of the paper) ) ``` @@ -1597,20 +1597,31 @@ On-Policy Distillation has been shown to outperform SFT, GRPO and can be used to Additionally on-policy distillation is more compute efficient and is less prone to overfitting when trained with limited data. 
-To train a model with on-policy distillation using TRL, you can use the following configuration, with the [`experimental.gkd.GKDTrainer`] and [`experimental.gkd.GKDConfig`]: +To train a model with on-policy distillation using TRL, you can use the following configuration, with the [`experimental.distillation.DistillationTrainer`] and [`experimental.distillation.DistillationConfig`]: + +```python +from trl.experimental.distillation import DistillationConfig + +training_args = DistillationConfig( + lmbda=1.0, # student produces rollouts for all batches + beta=1.0, # to ensure reverse-kl as the loss function + teacher_model_name_or_path="teacher-model", # specify the teacher model +) +``` + +Alternatively, you can use the [`experimental.gkd.GKDTrainer`] and [`experimental.gkd.GKDConfig`]: ```python from trl.experimental.gkd import GKDConfig training_args = GKDConfig( - lmbda=1.0, # student produces rollouts for all batches - beta=1.0, # to ensure reverse-kl as the loss function - teacher_model_name_or_path="teacher-model", # specify the teacher model - + lmbda=1.0, # student produces rollouts for all batches + beta=1.0, # to ensure reverse-kl as the loss function + teacher_model_name_or_path="teacher-model", # specify the teacher model ) ``` -Alternatively, you can use the [`GOLDTrainer`] and [`GOLDConfig`] to perform on-policy distillation with a similar configuration: +You can also use the [`GOLDTrainer`] and [`GOLDConfig`] to perform on-policy distillation with a similar configuration: ```python from trl.experimental import GOLDConfig diff --git a/trl/experimental/distillation/__init__.py b/trl/experimental/distillation/__init__.py new file mode 100644 index 0000000000..894333c1d0 --- /dev/null +++ b/trl/experimental/distillation/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .distillation_config import DistillationConfig +from .distillation_trainer import DistillationTrainer + + +__all__ = ["DistillationConfig", "DistillationTrainer"] diff --git a/trl/experimental/distillation/distillation.py b/trl/experimental/distillation/distillation.py new file mode 100644 index 0000000000..368cf4ed03 --- /dev/null +++ b/trl/experimental/distillation/distillation.py @@ -0,0 +1,179 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# /// script +# dependencies = [ +# "trl", +# "peft", +# "trackio", +# "kernels", +# ] +# /// + +# docstyle-ignore +""" +# Full training (off-policy only, lmbda=0): +``` +python trl/experimental/distillation/distillation.py \ + --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \ + --teacher_model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \ + --dataset_name trl-lib/chatbot_arena_completions \ + --learning_rate 2e-5 \ + --per_device_train_batch_size 4 \ + --gradient_accumulation_steps 8 \ + --lmbda 0.0 \ + --output_dir distilled-model \ + --num_train_epochs 1 +``` + +# Mixed on/off-policy (lmbda=0.5): +``` +python trl/experimental/distillation/distillation.py \ + --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \ + --teacher_model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \ + --dataset_name trl-lib/chatbot_arena_completions \ + --learning_rate 2e-5 \ + --per_device_train_batch_size 4 \ + --gradient_accumulation_steps 8 \ + --lmbda 0.5 \ + --beta 0.5 \ + --output_dir distilled-model \ + --num_train_epochs 1 +``` + +# LoRA: +``` +python trl/experimental/distillation/distillation.py \ + --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \ + --teacher_model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \ + --dataset_name trl-lib/chatbot_arena_completions \ + --learning_rate 2e-4 \ + --per_device_train_batch_size 4 \ + --gradient_accumulation_steps 8 \ + --lmbda 0.0 \ + --output_dir distilled-model \ + --num_train_epochs 1 \ + --use_peft \ + --lora_r 64 \ + --lora_alpha 16 +``` +""" + +import argparse +import os + + +# Enable logging in a Hugging Face Space +os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio") + + +def main(script_args, training_args, model_args): + from datasets import load_dataset + from transformers import GenerationConfig + + from trl import ( + LogCompletionsCallback, + get_kbit_device_map, + get_peft_config, + get_quantization_config, + ) + from trl.experimental.distillation import DistillationTrainer + + ################ + # Model init kwargs + 
################ + quantization_config = get_quantization_config(model_args) + model_kwargs = dict( + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + attn_implementation=model_args.attn_implementation, + torch_dtype=model_args.dtype, + use_cache=False if training_args.gradient_checkpointing else True, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + training_args.model_init_kwargs = model_kwargs + + teacher_model_kwargs = dict( + revision=training_args.teacher_model_revision, + trust_remote_code=model_args.trust_remote_code, + attn_implementation=model_args.attn_implementation, + torch_dtype=model_args.dtype, + use_cache=True, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + if training_args.teacher_model_init_kwargs is not None: + teacher_model_kwargs.update(training_args.teacher_model_init_kwargs) + training_args.teacher_model_init_kwargs = teacher_model_kwargs + + ################ + # Dataset + ################ + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + + ################ + # Training + ################ + eval_dataset = None + if training_args.eval_strategy != "no": + if script_args.dataset_test_split in dataset: + eval_dataset = dataset[script_args.dataset_test_split] + elif "validation" in dataset: + eval_dataset = dataset["validation"] + elif "dev" in dataset: + eval_dataset = dataset["dev"] + + trainer = DistillationTrainer( + model=model_args.model_name_or_path, + teacher_model=training_args.teacher_model_name_or_path, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=eval_dataset, + peft_config=get_peft_config(model_args), + ) + + if training_args.eval_strategy != "no": + generation_config = GenerationConfig( + max_new_tokens=training_args.max_completion_length, 
do_sample=True, temperature=training_args.temperature + ) + completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8) + trainer.add_callback(completions_callback) + + trainer.train() + + # Save and push to Hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + + +def make_parser(subparsers: argparse._SubParsersAction | None = None, prog: str | None = None): + from trl import ModelConfig, ScriptArguments, TrlParser + from trl.experimental.distillation import DistillationConfig + + dataclass_types = (ScriptArguments, DistillationConfig, ModelConfig) + if subparsers is not None: + parser = subparsers.add_parser( + "distillation", help="Run the distillation training script", dataclass_types=dataclass_types + ) + else: + parser = TrlParser(dataclass_types, prog=prog) + return parser + + +if __name__ == "__main__": + parser = make_parser() + script_args, training_args, model_args = parser.parse_args_and_config(fail_with_unknown_args=False) + main(script_args, training_args, model_args) diff --git a/trl/experimental/distillation/distillation_config.py b/trl/experimental/distillation/distillation_config.py new file mode 100644 index 0000000000..8e1386f113 --- /dev/null +++ b/trl/experimental/distillation/distillation_config.py @@ -0,0 +1,445 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import warnings +from dataclasses import dataclass, field +from typing import Any + +from ...trainer.base_config import _BaseConfig + + +@dataclass +class DistillationConfig(_BaseConfig): + # docstyle-ignore + r""" + Configuration class for the [`DistillationTrainer`]. + + Extends [`~transformers.TrainingArguments`] with parameters specific to knowledge distillation. This config is + independent of [`SFTConfig`] — all necessary fields are declared here. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + > Parameters that control the model + + model_init_kwargs (`dict[str, Any]`, *optional*): + Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of the + trainer is provided as a string. + max_length (`int` or `None`, *optional*, defaults to `1024`): + Maximum total sequence length (prompt + completion) for tokenization and truncation. + + > Parameters that control the distillation + + temperature (`float`, *optional*, defaults to `1.0`): + Temperature for sampling during generation and for computing the distillation loss. Higher values produce + softer probability distributions. + lmbda (`float`, *optional*, defaults to `0.5`): + Probability of using on-policy (student-generated) data for each gradient accumulation slice. A value of + `0.0` means fully off-policy (dataset completions only), `1.0` means fully on-policy. + beta (`float`, *optional*, defaults to `0.5`): + Interpolation coefficient for the Generalized Jensen-Shannon Divergence loss. When `0.0`, the loss is the + forward KL divergence. When `1.0`, the loss is the reverse KL divergence. When `0.5`, it is the standard + JSD. + reverse_kl_top_1_mode (`str`, *optional*, defaults to `"sampled"`): + Selection rule for the reverse-KL top-1 token when `beta > 0` and `loss_top_k == 1`. 
`"sampled"` uses the + actual completion token in the batch. `"argmax"` uses the student's highest-probability token. This + setting does not affect the forward-KL support, which always uses the teacher's top-1 token. Ignored when + `beta == 0` or `loss_top_k != 1`. + max_completion_length (`int`, *optional*, defaults to `256`): + Maximum number of tokens to generate per completion during on-policy generation. + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the student model during training. + + > Parameters that control the teacher model + + teacher_model_name_or_path (`str` or `None`, *optional*): + Model name or path for the teacher model. Used when the teacher is loaded locally. + teacher_model_revision (`str` or `None`, *optional*): + Model revision of the teacher model (e.g., branch name, tag, or commit hash). + teacher_model_init_kwargs (`dict[str, Any]` or `None`, *optional*): + Keyword arguments passed to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model + from a string. + use_teacher_server (`bool`, *optional*, defaults to `False`): + Whether to use an external vLLM teacher server instead of a local teacher model. + teacher_model_server_url (`str` or `None`, *optional*): + Base URL of a vLLM server hosting the teacher model (e.g., `"http://localhost:8000"`). When set, teacher + logprobs are fetched from the server instead of running a local forward pass when `use_teacher_server=True`. + loss_top_k (`int`, *optional*, defaults to `0`): + Number of top tokens to use when computing the JSD/KL loss. Both student and teacher distributions are + restricted to these K tokens and re-normalized before computing divergence. If 0, the full vocabulary + is used. For local teachers, the general support rule is teacher top-k for forward KL, student top-k for + reverse KL, and the union for mixed JSD. 
When `beta > 0` and `loss_top_k == 1`, the forward support still + uses the teacher's top-1 token, while the reverse top-1 token is controlled by `reverse_kl_top_1_mode`. + When `use_teacher_server=True`, the pure forward path (`beta=0`) requires this to be positive and uses the + teacher's top-k logprobs for the forward term. When `beta > 0`, server-backed distillation requires + `loss_top_k == 1` and only supports `"sampled"` reverse top-1 tokens. + loss_add_tail (`bool`, *optional*, defaults to `True`): + Whether to append a tail bucket that represents the remaining probability mass outside the selected top-k + support when computing the loss. + + > Parameters that control on-policy generation + + num_generations (`int`, *optional*, defaults to `1`): + Number of completions to generate per prompt during on-policy generation. + generation_batch_size (`int` or `None`, *optional*): + Number of unique prompts per worker per optimizer step. If `None`, computed from + `(per_device_train_batch_size * gradient_accumulation_steps) // num_generations`. + top_p (`float`, *optional*, defaults to `0.95`): + Top-p (nucleus) sampling parameter for on-policy generation. + top_k (`int`, *optional*, defaults to `0`): + Top-k sampling parameter for on-policy generation. `0` disables top-k filtering. + + > Parameters that control vLLM for student generation + + use_vllm (`bool`, *optional*, defaults to `False`): + Whether to use vLLM for generating on-policy completions from the student model. + vllm_mode (`str`, *optional*, defaults to `"colocate"`): + Mode for student vLLM integration. Either `"server"` or `"colocate"`. + vllm_server_base_url (`str` or `None`, *optional*): + Base URL for the student vLLM server. If provided, `vllm_server_host` and `vllm_server_port` are ignored. + vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`): + Host of the student vLLM server. + vllm_server_port (`int`, *optional*, defaults to `8001`): + Port of the student vLLM server. 
+ vllm_server_timeout (`float`, *optional*, defaults to `240.0`): + Timeout for connecting to the student vLLM server. + vllm_group_port (`int`, *optional*, defaults to `51216`): + Port for the vLLM weight-update group (NCCL communicator). + vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`): + GPU memory utilization for the colocated student vLLM engine. + vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`): + Tensor parallel size for the colocated student vLLM engine. + vllm_max_model_length (`int` or `None`, *optional*): + Maximum model sequence length for the colocated vLLM engine. + vllm_model_impl (`str`, *optional*, defaults to `"vllm"`): + Model implementation backend for vLLM. Use `"vllm"` or `"transformers"`. + vllm_structured_outputs_regex (`str` or `None`, *optional*): + Regex pattern for vLLM structured outputs. + vllm_sync_frequency (`int`, *optional*, defaults to `1`): + Frequency (in training steps) to synchronize student model weights to the vLLM engine. + vllm_enable_sleep_mode (`bool`, *optional*, defaults to `False`): + Enable vLLM sleep mode to offload student weights during the optimizer step. + + > Parameters that control logging + + log_completions (`bool`, *optional*, defaults to `False`): + Whether to log a sample of (prompt, completion) pairs periodically. + log_completions_steps (`int`, *optional*, defaults to `100`): + Number of steps between logging completions. Only used if `log_completions` is `True`. + num_completions_to_print (`int` or `None`, *optional*): + Number of completions to print. If `None`, all completions are logged. + """ + + _VALID_DICT_FIELDS = _BaseConfig._VALID_DICT_FIELDS + ["model_init_kwargs", "teacher_model_init_kwargs"] + + # Model + model_init_kwargs: dict[str, Any] | str | None = field( + default=None, + metadata={ + "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument " + "of the trainer is provided as a string." 
+ }, + ) + max_length: int | None = field( + default=1024, + metadata={"help": "Maximum total sequence length (prompt + completion) for tokenization and truncation."}, + ) + + # Overridden defaults + learning_rate: float = field( + default=1e-7, + metadata={"help": "The initial learning rate for AdamW."}, + ) + + # Distillation core + temperature: float = field( + default=1.0, + metadata={ + "help": "Temperature for sampling and loss computation. Higher values produce softer distributions." + }, + ) + lmbda: float = field( + default=0.5, + metadata={ + "help": "Probability of using on-policy (student-generated) data per gradient accumulation slice. " + "0.0 = fully off-policy, 1.0 = fully on-policy." + }, + ) + beta: float = field( + default=0.5, + metadata={ + "help": "Interpolation coefficient for the Generalized JSD loss. " + "0.0 = forward KL, 0.5 = JSD, 1.0 = reverse KL." + }, + ) + reverse_kl_top_1_mode: str = field( + default="sampled", + metadata={ + "help": "Reverse-KL top-1 token selection when beta > 0 and loss_top_k == 1. " + "Use 'sampled' for the actual completion token or 'argmax' for the student's top-1 token. " + "The forward-KL support always uses the teacher's top-1 token. Ignored when beta == 0 or loss_top_k != 1." + }, + ) + max_completion_length: int = field( + default=256, + metadata={"help": "Maximum number of tokens to generate per completion."}, + ) + max_prompt_length: int | None = field( + default=None, + metadata={ + "help": "Maximum number of tokens for the prompt. If None, auto-computed as " + "max_length - max_completion_length. Prompts are truncated according to the " + "tokenizer's truncation_side setting." 
+ }, + ) + disable_dropout: bool = field( + default=True, + metadata={"help": "Whether to disable dropout in the student model during training."}, + ) + + # Teacher model (local) + teacher_model_name_or_path: str | None = field( + default=None, + metadata={"help": "Model name or path for the teacher model."}, + ) + teacher_model_revision: str | None = field( + default=None, + metadata={"help": "Model revision of the teacher model (e.g., branch name, tag, or commit hash)."}, + ) + teacher_model_init_kwargs: dict[str, Any] | str | None = field( + default=None, + metadata={ + "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained` when instantiating the teacher." + }, + ) + + # Teacher model (external vLLM server) + use_teacher_server: bool = field( + default=False, + metadata={"help": "Whether to use an external vLLM teacher server instead of a local teacher model."}, + ) + teacher_model_server_url: str | None = field( + default=None, + metadata={ + "help": 'Base URL of a vLLM server hosting the teacher model (e.g., "http://localhost:8000"). ' + "Required when use_teacher_server=True." + }, + ) + loss_top_k: int = field( + default=0, + metadata={ + "help": "Number of top tokens to use when computing the JSD/KL loss. " + "Both student and teacher distributions are restricted to these K tokens " + "(selected based on beta: teacher's top-k for forward KL, student's top-k for reverse KL, " + "union of both for JSD) and re-normalized before computing divergence. " + "If 0, the full vocabulary is used (slower but exact). " + "When beta > 0 and loss_top_k == 1, the forward support still uses the teacher's top-1 token, " + "while the reverse top-1 token is controlled by reverse_kl_top_1_mode. " + "When use_teacher_server=True, beta=0 requires loss_top_k > 0 and uses the teacher's top-k " + "logprobs for the forward term. When beta > 0, server-backed distillation requires loss_top_k == 1 " + "and only supports 'sampled' reverse top-1 tokens." 
+ }, + ) + loss_add_tail: bool = field( + default=True, + metadata={ + "help": "Whether to append a tail bucket representing the remaining probability mass outside the selected top-k support." + }, + ) + + # On-policy generation + num_generations: int = field( + default=1, + metadata={"help": "Number of completions to generate per prompt during on-policy generation."}, + ) + generation_batch_size: int | None = field( + default=None, + metadata={ + "help": "Number of unique prompts per worker per optimizer step. " + "If None, computed from (per_device_train_batch_size * gradient_accumulation_steps) // num_generations." + }, + ) + top_p: float = field( + default=0.95, + metadata={"help": "Top-p (nucleus) sampling parameter for on-policy generation."}, + ) + top_k: int = field( + default=0, + metadata={"help": "Top-k sampling parameter for on-policy generation. 0 disables top-k filtering."}, + ) + + # vLLM for student generation + use_vllm: bool = field( + default=False, + metadata={"help": "Whether to use vLLM for generating on-policy completions from the student model."}, + ) + vllm_mode: str = field( + default="colocate", + metadata={"help": 'Mode for student vLLM integration. 
Either "server" or "colocate".'}, + ) + vllm_server_base_url: str | None = field( + default=None, + metadata={"help": "Base URL for the student vLLM server."}, + ) + vllm_server_host: str = field( + default="0.0.0.0", + metadata={"help": "Host of the student vLLM server."}, + ) + vllm_server_port: int = field( + default=8001, + metadata={"help": "Port of the student vLLM server."}, + ) + vllm_server_timeout: float = field( + default=240.0, + metadata={"help": "Timeout for connecting to the student vLLM server."}, + ) + vllm_group_port: int = field( + default=51216, + metadata={"help": "Port for the vLLM weight-update group (NCCL communicator)."}, + ) + vllm_gpu_memory_utilization: float = field( + default=0.9, + metadata={"help": "GPU memory utilization for the colocated student vLLM engine."}, + ) + vllm_tensor_parallel_size: int = field( + default=1, + metadata={"help": "Tensor parallel size for the colocated student vLLM engine."}, + ) + vllm_max_model_length: int | None = field( + default=None, + metadata={"help": "Maximum model sequence length for the colocated vLLM engine."}, + ) + vllm_model_impl: str = field( + default="vllm", + metadata={"help": 'Model implementation backend for vLLM. 
Use "vllm" or "transformers".'}, + ) + vllm_structured_outputs_regex: str | None = field( + default=None, + metadata={"help": "Regex pattern for vLLM structured outputs."}, + ) + vllm_sync_frequency: int = field( + default=1, + metadata={"help": "Frequency (in training steps) to synchronize student weights to the vLLM engine."}, + ) + vllm_enable_sleep_mode: bool = field( + default=False, + metadata={"help": "Enable vLLM sleep mode to offload student weights during the optimizer step."}, + ) + + # W&B + wandb_entity: str | None = field( + default=None, + metadata={"help": "The W&B entity to store runs under."}, + ) + wandb_project: str | None = field( + default=None, + metadata={"help": "The W&B project to store runs under."}, + ) + wandb_run_group: str | None = field( + default=None, + metadata={"help": "The W&B group to store runs under."}, + ) + + # Logging + log_completions: bool = field( + default=False, + metadata={"help": "Whether to log a sample of (prompt, completion) pairs periodically."}, + ) + log_completions_steps: int = field( + default=100, + metadata={"help": "Number of steps between logging completions."}, + ) + num_completions_to_print: int | None = field( + default=None, + metadata={"help": "Number of completions to print. If None, all completions are logged."}, + ) + + def __post_init__(self): + super().__post_init__() + + if self.lmbda < 0.0 or self.lmbda > 1.0: + raise ValueError(f"lmbda must be in [0.0, 1.0], got {self.lmbda}.") + if self.beta < 0.0 or self.beta > 1.0: + raise ValueError(f"beta must be in [0.0, 1.0], got {self.beta}.") + if self.reverse_kl_top_1_mode not in {"sampled", "argmax"}: + raise ValueError("reverse_kl_top_1_mode must be one of: 'sampled', 'argmax'") + + if self.max_length is not None and self.max_completion_length >= self.max_length: + raise ValueError( + f"max_completion_length ({self.max_completion_length}) must be smaller than " + f"max_length ({self.max_length}) to leave room for the prompt." 
+ ) + + if self.max_prompt_length is None and self.max_length is not None: + self.max_prompt_length = self.max_length - self.max_completion_length + + if self.num_generations < 1: + raise ValueError(f"num_generations must be at least 1, got {self.num_generations}.") + + local_sequence_batch_size = self.per_device_train_batch_size * self.gradient_accumulation_steps + if self.generation_batch_size is None: + self.generation_batch_size = local_sequence_batch_size // self.num_generations + if self.generation_batch_size < 1: + raise ValueError(f"generation_batch_size must be at least 1, got {self.generation_batch_size}.") + if self.generation_batch_size * self.num_generations != local_sequence_batch_size: + raise ValueError( + "generation_batch_size * num_generations must equal per_device_train_batch_size * " + f"gradient_accumulation_steps. Got {self.generation_batch_size} * {self.num_generations} != " + f"{self.per_device_train_batch_size} * {self.gradient_accumulation_steps}." + ) + + if self.use_teacher_server and self.use_liger_kernel: + raise ValueError( + "use_liger_kernel=True is not supported with use_teacher_server=True because the Liger loss path " + "requires a local teacher model." + ) + if self.use_teacher_server and ( + self.teacher_model_server_url is None or not self.teacher_model_server_url.strip() + ): + raise ValueError("teacher_model_server_url must be set when use_teacher_server=True.") + + if self.use_teacher_server and self.beta == 0 and self.loss_top_k < 1: + raise ValueError( + f"loss_top_k must be positive when using use_teacher_server=True with beta=0 " + f"(got loss_top_k={self.loss_top_k}). The pure forward server path only has access to the " + f"teacher's top-k logprobs, so it cannot compute the exact full-vocabulary loss when loss_top_k=0." 
+ ) + if self.use_teacher_server and self.reverse_kl_top_1_mode == "argmax": + raise ValueError( + "reverse_kl_top_1_mode='argmax' is not supported with use_teacher_server=True because the server " + "cannot provide teacher logprobs for arbitrary student-selected tokens." + ) + if self.use_teacher_server and self.beta > 0 and self.loss_top_k != 1: + raise ValueError( + f"loss_top_k must be 1 when using use_teacher_server=True with beta>0 " + f"(got loss_top_k={self.loss_top_k}). Mixed forward/reverse distillation with an external teacher " + "is only implemented for top-1 support." + ) + if self.reverse_kl_top_1_mode != "sampled" and (self.beta == 0 or self.loss_top_k != 1): + warnings.warn( + f"reverse_kl_top_1_mode='{self.reverse_kl_top_1_mode}' has no effect when beta={self.beta} " + f"and loss_top_k={self.loss_top_k}. It only applies when beta > 0 and loss_top_k == 1.", + UserWarning, + stacklevel=2, + ) + + if self.num_generations > 1 and self.lmbda < 1.0: + warnings.warn( + f"num_generations={self.num_generations} with lmbda={self.lmbda} means off-policy batches include " + f"{self.num_generations} copies of each sample. Consider lmbda=1.0 when num_generations > 1.", + UserWarning, + stacklevel=2, + ) diff --git a/trl/experimental/distillation/distillation_trainer.py b/trl/experimental/distillation/distillation_trainer.py new file mode 100644 index 0000000000..479e03e8b7 --- /dev/null +++ b/trl/experimental/distillation/distillation_trainer.py @@ -0,0 +1,1687 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import textwrap +import warnings +from collections import defaultdict +from collections.abc import Callable +from contextlib import nullcontext +from functools import partial +from typing import Any, Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from accelerate.utils import DistributedType, broadcast_object_list, gather_object +from datasets import Dataset +from torch.utils.data import DataLoader +from transformers import AutoTokenizer, TrainerCallback +from transformers.data.data_collator import DataCollator +from transformers.feature_extraction_utils import FeatureExtractionMixin +from transformers.generation.configuration_utils import GenerationConfig +from transformers.image_processing_utils import BaseImageProcessor +from transformers.modeling_utils import PreTrainedModel +from transformers.processing_utils import ProcessorMixin +from transformers.tokenization_utils_base import PreTrainedTokenizerBase +from transformers.trainer_utils import EvalPrediction, seed_worker +from transformers.utils import ( + is_liger_kernel_available, + is_peft_available, + is_rich_available, +) + +from ...extras.profiling import profiling_decorator +from ...generation.vllm_generation import VLLMGeneration +from ...import_utils import is_vllm_available +from ...models import prepare_deepspeed +from ...models.utils import unwrap_model_for_generation +from ...trainer.base_trainer import _BaseTrainer +from ...trainer.utils import ( + RepeatSampler, + create_model_from_path, + 
disable_dropout_in_model, + pad, + split_tensor_dict, +) +from .distillation_config import DistillationConfig + + +if is_peft_available(): + from peft import PeftConfig, get_peft_model + +if is_liger_kernel_available(): + from liger_kernel.chunked_loss import LigerFusedLinearJSDLoss + +if is_rich_available(): + from rich.console import Console + from rich.panel import Panel + from rich.table import Table + from rich.text import Text + + +def _print_completions_sample(prompts: list[str], completions: list[str], step: int, num_samples: int = None) -> None: + """Print a sample of prompt-completion pairs using rich.""" + if not is_rich_available(): + return + + console = Console() + table = Table(show_header=True, header_style="bold white", expand=True) + table.add_column("Prompt", style="bright_yellow") + table.add_column("Completion", style="bright_green") + + if num_samples is not None: + if num_samples >= len(prompts): + num_samples = None + elif num_samples <= 0: + return + + if num_samples is not None: + indices = random.sample(range(len(prompts)), num_samples) + prompts = [prompts[i] for i in indices] + completions = [completions[i] for i in indices] + + for prompt, completion in zip(prompts, completions, strict=True): + table.add_row(Text(prompt), Text(completion)) + table.add_section() + + panel = Panel(table, expand=False, title=f"Step {step}", border_style="bold white") + console.print(panel) + + +def _add_tail_bucket(log_probs, valid_mask): + """Append a (K+1)-th tail element: log(1 - sum(exp(top_k_logps))). + + This creates a proper probability distribution over K+1 elements, preventing trivial zero loss when top_k is small + (especially top_k=1). 
+ """ + log_sum = torch.logsumexp(log_probs, dim=-1, keepdim=True) + log_sum = torch.clamp(log_sum, max=-1e-7) # ensure sum < 1 + tail = torch.log(-torch.expm1(log_sum)) # log(1 - exp(log_sum)) + tail_mask = torch.ones_like(valid_mask[..., :1], dtype=torch.bool) + return torch.cat([log_probs, tail], dim=-1), torch.cat([valid_mask, tail_mask], dim=-1) + + +def _jsd_divergence(student_log_probs, teacher_log_probs, beta, support_mask=None): + """Compute JSD (or forward/reverse KL) from log-probability tensors. + + When *support_mask* is not None, uses manual computation with masked positions zeroed. When None, uses + ``F.kl_div``. + """ + if support_mask is not None: + safe_student = torch.where(support_mask, student_log_probs, torch.zeros_like(student_log_probs)) + safe_teacher = torch.where(support_mask, teacher_log_probs, torch.zeros_like(teacher_log_probs)) + student_probs = torch.where(support_mask, student_log_probs.exp(), torch.zeros_like(student_log_probs)) + teacher_probs = torch.where(support_mask, teacher_log_probs.exp(), torch.zeros_like(teacher_log_probs)) + + if beta == 0: + return torch.nan_to_num(teacher_probs * (safe_teacher - safe_student), nan=0.0) + elif beta == 1: + return torch.nan_to_num(student_probs * (safe_student - safe_teacher), nan=0.0) + else: + beta_t = torch.tensor(beta, dtype=student_log_probs.dtype, device=student_log_probs.device) + tiny = torch.finfo(student_probs.dtype).tiny + mixture_probs = (1 - beta_t) * student_probs + beta_t * teacher_probs + safe_mixture = torch.where( + support_mask, + torch.log(mixture_probs.clamp_min(tiny)), + torch.zeros_like(student_log_probs), + ) + kl_teacher = torch.nan_to_num(teacher_probs * (safe_teacher - safe_mixture), nan=0.0) + kl_student = torch.nan_to_num(student_probs * (safe_student - safe_mixture), nan=0.0) + return beta_t * kl_teacher + (1 - beta_t) * kl_student + else: + if beta == 0: + return F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True) + elif beta == 
1: + return F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True) + else: + beta_t = torch.tensor(beta, dtype=student_log_probs.dtype, device=student_log_probs.device) + mixture_log_probs = torch.logsumexp( + torch.stack([student_log_probs + torch.log1p(-beta_t), teacher_log_probs + torch.log(beta_t)]), + dim=0, + ) + kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True) + kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True) + return beta_t * kl_teacher + (1 - beta_t) * kl_student + + +def build_teacher_request_inputs( + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + prompt_attention_mask: torch.Tensor | None = None, + labels: torch.Tensor | None = None, +) -> tuple[list[list[int]], list[int], list[int]]: + """Trim padded batch tensors into per-sample sequences for teacher-server requests.""" + + if input_ids.shape != attention_mask.shape: + raise ValueError( + f"input_ids and attention_mask must have the same shape, got {input_ids.shape} and {attention_mask.shape}." 
+ ) + + input_ids_cpu = input_ids.detach().cpu() + attention_mask_cpu = attention_mask.detach().cpu().bool() + + if prompt_attention_mask is not None: + prompt_lengths = prompt_attention_mask.detach().cpu().sum(dim=1).to(torch.long) + else: + if labels is None: + raise ValueError("labels are required when prompt_attention_mask is not provided.") + if labels.shape != input_ids.shape: + raise ValueError(f"labels must match input_ids shape, got {labels.shape} and {input_ids.shape}.") + full_lengths = attention_mask_cpu.sum(dim=1).to(torch.long) + completion_lengths = (labels.detach().cpu() != -100).sum(dim=1).to(torch.long) + prompt_lengths = full_lengths - completion_lengths + + trimmed_input_ids: list[list[int]] = [] + prompt_lengths_list: list[int] = [] + completion_lengths_list: list[int] = [] + + for row, mask, prompt_length in zip(input_ids_cpu, attention_mask_cpu, prompt_lengths, strict=True): + trimmed_row = row[mask] + prompt_len = int(prompt_length.item()) + if prompt_len < 0 or prompt_len > trimmed_row.numel(): + raise ValueError( + f"Invalid prompt length {prompt_len} for trimmed sequence of length {trimmed_row.numel()}." + ) + trimmed_input_ids.append(trimmed_row.tolist()) + prompt_lengths_list.append(prompt_len) + completion_lengths_list.append(int(trimmed_row.numel()) - prompt_len) + + return trimmed_input_ids, prompt_lengths_list, completion_lengths_list + + +class _DistillationCollator: + """Data collator for the distillation trainer with independent prompt/completion budgets. + + Unlike ``DataCollatorForChatML``, this collator tokenizes prompts and completions separately so that long + completions can never truncate the prompt to empty. It also handles prompt-only data (no assistant completions) for + pure on-policy distillation (``lmbda=1``). 
class _DistillationCollator:
    """Data collator for the distillation trainer with independent prompt/completion budgets.

    Unlike ``DataCollatorForChatML``, this collator tokenizes prompts and completions separately so that long
    completions can never truncate the prompt to empty. It also handles prompt-only data (no assistant completions) for
    pure on-policy distillation (``lmbda=1``).
    """

    def __init__(
        self,
        tokenizer: "PreTrainedTokenizerBase",
        max_length: int,
        max_prompt_length: int,
        messages_key: str = "messages",
        ignore_index: int = -100,
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.max_prompt_length = max_prompt_length
        self.messages_key = messages_key
        self.ignore_index = ignore_index

        if tokenizer.pad_token_id is None:
            raise ValueError("The tokenizer does not have a pad token. Please set `pad_token_id` in the tokenizer.")

    def __call__(self, examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
        batch_ids: list[list[int]] = []
        batch_labels: list[list[int]] = []
        batch_prompts: list[list[int]] = []

        for example in examples:
            messages = example[self.messages_key]

            # The last assistant turn (if any) is the completion; everything before it is the prompt.
            has_completion = len(messages) > 1 and messages[-1].get("role") == "assistant"
            prompt_messages = messages[:-1] if has_completion else messages

            # The prompt gets its own budget (tokenizer's truncation side applies), so a long completion
            # can never squeeze the prompt down to nothing.
            prompt_text = self.tokenizer.apply_chat_template(
                prompt_messages, tokenize=False, add_generation_prompt=True
            )
            prompt_ids = self.tokenizer(
                prompt_text,
                truncation=True,
                max_length=self.max_prompt_length,
                padding=False,
                add_special_tokens=False,
            )["input_ids"]

            if has_completion:
                # Tokenize prompt+completion together, untruncated, then split off the completion using the
                # *un-truncated* prompt length as the boundary.
                full_text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
                full_ids = self.tokenizer(full_text, truncation=False, padding=False, add_special_tokens=False)[
                    "input_ids"
                ]
                untruncated_prompt_ids = self.tokenizer(
                    prompt_text, truncation=False, padding=False, add_special_tokens=False
                )["input_ids"]
                completion_ids = full_ids[len(untruncated_prompt_ids) :]

                # Keep prompt + completion within the overall max_length budget.
                budget = self.max_length - len(prompt_ids)
                if budget <= 0:
                    completion_ids = []
                elif len(completion_ids) > budget:
                    completion_ids = completion_ids[:budget]

                ids = prompt_ids + completion_ids
                labels = [self.ignore_index] * len(prompt_ids) + list(completion_ids)
            else:
                # Prompt-only example: nothing to supervise here (on-policy generation supplies the completion).
                ids = list(prompt_ids)
                labels = [self.ignore_index] * len(prompt_ids)

            batch_ids.append(ids)
            batch_labels.append(labels)
            batch_prompts.append(list(prompt_ids))

        # Left-pad everything so prompts end at a common boundary.
        pad_id = self.tokenizer.pad_token_id
        return {
            "input_ids": pad(
                [torch.tensor(x, dtype=torch.long) for x in batch_ids],
                padding_side="left",
                padding_value=pad_id,
            ),
            "attention_mask": pad(
                [torch.ones(len(x), dtype=torch.long) for x in batch_ids],
                padding_side="left",
                padding_value=0,
            ),
            "labels": pad(
                [torch.tensor(x, dtype=torch.long) for x in batch_labels],
                padding_side="left",
                padding_value=self.ignore_index,
            ),
            "prompts": pad(
                [torch.tensor(x, dtype=torch.long) for x in batch_prompts],
                padding_side="left",
                padding_value=pad_id,
            ),
            "prompt_attention_mask": pad(
                [torch.ones(len(x), dtype=torch.long) for x in batch_prompts],
                padding_side="left",
                padding_value=0,
            ),
        }
+ + ``RepeatSampler`` with ``repeat_count > 1`` causes the DataLoader to re-collate (re-tokenize) the same examples on + every repeat, which is wasteful. This wrapper instead keeps ``repeat_count=1`` in the sampler and repeats the + already-collated tensor dict, avoiding redundant tokenization. + """ + + def __init__(self, dataloader, repeat_count: int): + self.dataloader = dataloader + self.repeat_count = repeat_count + + def __iter__(self): + for batch in self.dataloader: + for _ in range(self.repeat_count): + yield batch + + def __len__(self): + return len(self.dataloader) * self.repeat_count + + def set_epoch(self, epoch: int): + if hasattr(self.dataloader, "set_epoch"): + self.dataloader.set_epoch(epoch) + + def __getattr__(self, attr): + return getattr(self.dataloader, attr) + + +class DistillationTrainer(_BaseTrainer): + """ + Trainer for knowledge distillation from a teacher model to a student model. + + Supports: + - Generalized JSD loss (forward KL, reverse KL, or interpolated JSD via `beta`) + - On-policy / off-policy mixing via `lmbda` (buffered across gradient accumulation) + - Local teacher model or external teacher via vLLM server + - Student on-policy generation via vLLM or model.generate() + - Liger kernel for memory-efficient fused JSD loss + """ + + _tag_names = ["trl", "distillation"] + _name = "Distillation" + _paper = { + "title": "On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes", + "id": "2306.13649", + # docstyle-ignore + "citation": textwrap.dedent("""\ + @inproceedings{agarwal2024on-policy, + title = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}}, + author = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem}, + year = 2024, + booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024}, + publisher = {OpenReview.net}, + url 
    def __init__(
        self,
        model: PreTrainedModel | nn.Module | str | None = None,
        teacher_model: PreTrainedModel | nn.Module | str | None = None,
        args: DistillationConfig | None = None,
        data_collator: DataCollator | None = None,  # type: ignore
        train_dataset: Dataset | None = None,
        eval_dataset: Dataset | dict[str, Dataset] | None = None,
        processing_class: PreTrainedTokenizerBase
        | BaseImageProcessor
        | FeatureExtractionMixin
        | ProcessorMixin
        | None = None,
        compute_metrics: Callable[[EvalPrediction], dict] | None = None,
        callbacks: list[TrainerCallback] | None = None,
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
        peft_config: Optional["PeftConfig"] = None,
    ):
        """Set up student/teacher models, the collator, loss config, buffering state, and (optionally) vLLM.

        `model` and `teacher_model` may be instantiated models or hub paths; string paths are loaded via
        `create_model_from_path` with the respective `*_init_kwargs` from `args`. With `use_teacher_server=True`
        the teacher is an external vLLM server and no local teacher model is loaded.
        """
        if args is None:
            args = DistillationConfig(output_dir="tmp_distillation")

        # ── Student model loading ──
        # *_init_kwargs may arrive as a JSON string (e.g. from a CLI); decode before use.
        model_init_kwargs = args.model_init_kwargs or {}
        if isinstance(model_init_kwargs, str):
            import json

            model_init_kwargs = json.loads(model_init_kwargs)
        teacher_model_init_kwargs = args.teacher_model_init_kwargs or {}
        if isinstance(teacher_model_init_kwargs, str):
            import json

            teacher_model_init_kwargs = json.loads(teacher_model_init_kwargs)
        if isinstance(model, str):
            model_name_or_path = model
            model = create_model_from_path(model, **model_init_kwargs)
        else:
            model_name_or_path = model.config._name_or_path if model is not None else None

        # ── Processing class (tokenizer) ──
        if processing_class is None and model_name_or_path is not None:
            processing_class = AutoTokenizer.from_pretrained(model_name_or_path)
        if processing_class is not None:
            # Generation and padding both require a pad token; fall back to EOS.
            if getattr(processing_class, "pad_token", None) is None:
                processing_class.pad_token = processing_class.eos_token

        # ── PEFT ──
        if peft_config is not None:
            model = get_peft_model(model, peft_config)

        # ── Data collator ──
        if data_collator is None:
            data_collator = _DistillationCollator(
                tokenizer=processing_class,
                max_length=args.max_length,
                max_prompt_length=args.max_prompt_length,
            )

        # ── Liger fused JSD loss ──
        self.use_liger_loss = False
        if args.use_liger_kernel:
            self.liger_jsd_loss = LigerFusedLinearJSDLoss(
                beta=args.beta,
                ignore_index=-100,
                temperature=args.temperature,
                compiled=False,
                weight_hard_loss=0.0,  # pure distillation: no hard-label CE component
                weight_soft_loss=1.0,
            )
            self.use_liger_loss = True

        # ── Teacher model setup ──
        self.teacher_client = None
        self.use_teacher_server = args.use_teacher_server
        self.teacher_model_server_url = args.teacher_model_server_url
        self._local_teacher_tokenizer_matches_student = True
        if self.use_teacher_server:
            from ...generation.vllm_client import VLLMClient

            self.teacher_client = VLLMClient(base_url=self.teacher_model_server_url, connection_timeout=60.0)
            teacher_model = None
        elif teacher_model is not None:
            if args.teacher_model_init_kwargs is not None and not isinstance(teacher_model, str):
                raise ValueError(
                    "You passed teacher_model_init_kwargs to the config, but your teacher_model is already "
                    "instantiated."
                )

            # The name/path is needed to load the teacher tokenizer for the vocab-compatibility check below.
            teacher_model_name_or_path = (
                teacher_model
                if isinstance(teacher_model, str)
                else getattr(getattr(teacher_model, "config", None), "_name_or_path", None)
            )
            if teacher_model_name_or_path is None:
                raise ValueError(
                    "DistillationTrainer requires a local teacher model with `config._name_or_path` set so its "
                    "tokenizer can be validated against the student tokenizer."
                )

            teacher_tokenizer_kwargs = {}
            teacher_revision = teacher_model_init_kwargs.get("revision", args.teacher_model_revision)
            if teacher_revision is not None:
                teacher_tokenizer_kwargs["revision"] = teacher_revision
            if teacher_model_init_kwargs.get("trust_remote_code") is not None:
                teacher_tokenizer_kwargs["trust_remote_code"] = teacher_model_init_kwargs["trust_remote_code"]
            teacher_processing_class = AutoTokenizer.from_pretrained(
                teacher_model_name_or_path, **teacher_tokenizer_kwargs
            )
            if getattr(teacher_processing_class, "pad_token", None) is None:
                teacher_processing_class.pad_token = teacher_processing_class.eos_token
            self._local_teacher_tokenizer_matches_student = self._local_teacher_tokenizers_match(
                processing_class, teacher_processing_class
            )
            if not self._local_teacher_tokenizer_matches_student:
                # Mismatched vocabularies: warn here, hard-fail later in the loss path
                # (see _raise_if_local_teacher_tokenizer_mismatch) so subclasses can still override.
                warnings.warn(
                    "DistillationTrainer's built-in local-teacher loss assumes the student and teacher share the "
                    "same tokenizer. The loaded local teacher tokenizer does not match the student tokenizer, so "
                    "the teacher weights will be left unchanged for subclass overrides. Direct use of the base "
                    "DistillationTrainer with this local teacher remains unsupported.",
                    UserWarning,
                    stacklevel=2,
                )

            if isinstance(teacher_model, str):
                # Translate a string dtype (e.g. "bfloat16") into the torch dtype object; "auto"/None pass through.
                torch_dtype = teacher_model_init_kwargs.get("torch_dtype")
                teacher_model_init_kwargs["torch_dtype"] = (
                    torch_dtype if torch_dtype in ["auto", None] else getattr(torch, torch_dtype)
                )

            if isinstance(teacher_model, str):
                init_kwargs = dict(teacher_model_init_kwargs)
                if args.teacher_model_revision is not None:
                    init_kwargs.setdefault("revision", args.teacher_model_revision)
                # Newer transformers expects "dtype" rather than "torch_dtype".
                if "torch_dtype" in init_kwargs and "dtype" not in init_kwargs:
                    init_kwargs["dtype"] = init_kwargs.pop("torch_dtype")
                teacher_model = create_model_from_path(teacher_model, **init_kwargs)

        # Trainer does not need to remove unused columns — the collator handles raw data
        args.remove_unused_columns = False

        # ── Call _BaseTrainer.__init__ (which is transformers.Trainer.__init__) ──
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        # ── Prepare teacher model (after super().__init__ so accelerator is ready) ──
        if teacher_model is not None:
            if self._local_teacher_tokenizer_matches_student:
                # Align embedding sizes so student/teacher logits share the same vocab dimension.
                teacher_model.resize_token_embeddings(self.model.config.vocab_size)
            if self.is_deepspeed_enabled:
                self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator)
            else:
                self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)
        else:
            self.teacher_model = None

        if args.disable_dropout:
            disable_dropout_in_model(self.model)

        # ── Store config values ──
        self.lmbda = args.lmbda
        self.beta = args.beta
        self.temperature = args.temperature
        self.top_p = args.top_p
        self.num_generations = args.num_generations
        self.reverse_kl_top_1_mode = args.reverse_kl_top_1_mode
        self.loss_top_k = args.loss_top_k
        self.loss_add_tail = args.loss_add_tail

        # ── Buffer state ──
        # One optimizer window's worth of micro-batches, filled by _fill_buffer / consumed by _prepare_inputs.
        self._buffered_inputs = None
        self._buffered_on_policy_flags = None
        self._buffered_text_logs = None
        self._buffer_step = 0

        # ── Loss tracking ──
        self._on_policy_loss_total = 0.0
        self._off_policy_loss_total = 0.0
        self._on_policy_step_equiv = 0.0
        self._off_policy_step_equiv = 0.0

        # ── Generation config ──
        generation_kwargs = {
            "max_new_tokens": args.max_completion_length,
            "temperature": args.temperature,
            "top_p": args.top_p,
            "do_sample": True,
            "top_k": args.top_k,
            "pad_token_id": self.processing_class.pad_token_id,
        }
        self.generation_config = GenerationConfig(**generation_kwargs)
        self.generation_kwargs = generation_kwargs
        if (
            hasattr(self.model.generation_config, "eos_token_id")
            and self.model.generation_config.eos_token_id is not None
        ):
            self.generation_config.eos_token_id = self.model.generation_config.eos_token_id

        # ── Metrics & Logging ──
        self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
        self._total_train_tokens = 0
        self.log_completions = args.log_completions
        self.log_completions_steps = args.log_completions_steps
        self.num_completions_to_print = args.num_completions_to_print

        self._textual_logs = {
            "prompt": [],
            "completion": [],
        }

        # ── vLLM for student generation ──
        self.use_vllm = args.use_vllm
        if self.use_vllm:
            if not is_vllm_available():
                raise ImportError(
                    "vLLM is not available and use_vllm is set to True. Please install vLLM with "
                    "`pip install vllm` to use it."
                )
            self.vllm_generation = VLLMGeneration(
                model=self.model,
                accelerator=self.accelerator,
                is_fsdp_enabled=self.is_fsdp_enabled,
                processing_class=self.processing_class,
                mode=args.vllm_mode,
                structured_outputs_regex=args.vllm_structured_outputs_regex,
                server_base_url=args.vllm_server_base_url,
                server_host=args.vllm_server_host,
                server_port=args.vllm_server_port,
                group_port=args.vllm_group_port,
                server_timeout=args.vllm_server_timeout,
                tensor_parallel_size=args.vllm_tensor_parallel_size,
                gpu_memory_utilization=args.vllm_gpu_memory_utilization,
                max_model_length=args.vllm_max_model_length,
                max_num_seqs=args.per_device_train_batch_size * args.gradient_accumulation_steps,
                enable_sleep_mode=args.vllm_enable_sleep_mode,
                model_impl=args.vllm_model_impl,
                temperature=args.temperature,
                top_p=args.top_p,
                top_k=args.top_k,
                max_completion_length=args.max_completion_length,
                logprobs=None,
            )
            # NOTE(review): these two attributes only exist when use_vllm=True; code that reads
            # self.vllm_sync_frequency must stay behind a use_vllm check — verify call sites.
            self.vllm_sync_frequency = args.vllm_sync_frequency
            self._last_vllm_sync_step = -1

    @staticmethod
    def _local_teacher_tokenizers_match(
        student_processing_class: PreTrainedTokenizerBase,
        teacher_processing_class: PreTrainedTokenizerBase,
    ) -> bool:
        """Check whether the student and local teacher tokenizers share the same vocabulary."""
        return student_processing_class.get_vocab() == teacher_processing_class.get_vocab()
    def _raise_if_local_teacher_tokenizer_mismatch(self) -> None:
        """Guard the base local-teacher JSD path, while still allowing subclass overrides."""
        if self.teacher_model is not None and not self._local_teacher_tokenizer_matches_student:
            raise ValueError(
                "DistillationTrainer's built-in local-teacher loss only supports student/teacher pairs that use "
                "the same tokenizer. Use a same-tokenizer local teacher, set `use_teacher_server=True`, or "
                "override the local teacher loss path in a subclass."
            )

    def _compute_prompt_length(self, inputs: dict[str, torch.Tensor | Any]) -> int:
        """Compute the earliest prompt boundary that still includes every completion token in the batch."""
        if inputs.get("labels") is not None:
            attention_mask = inputs["attention_mask"]
            labels = inputs["labels"]
            # Per sample: prompt length = real tokens minus supervised (label != -100) tokens.
            # Taking the minimum over the batch guarantees no completion token lies left of the boundary.
            full_lengths = attention_mask.sum(dim=1)
            completion_lengths = (labels != -100).sum(dim=1)
            return int((full_lengths - completion_lengths).min().item())
        # No labels: prompts are left-padded to a common width, which is the boundary.
        return inputs["prompts"].shape[1]

    def _get_completion_lengths(self, generated_tokens: torch.Tensor, prompt_width: int) -> torch.Tensor:
        """Infer per-sample completion lengths from generated tokens.

        Returns a `(batch,)` long tensor of completion lengths for the tokens after `prompt_width`.
        """
        completion_tokens = generated_tokens[:, prompt_width:]
        pad_token_id = self.processing_class.pad_token_id
        eos_token_id = self.generation_config.eos_token_id
        # eos_token_id can be None, a single id, or a list of ids — normalize to a set.
        if eos_token_id is None:
            eos_token_ids = set()
        elif isinstance(eos_token_id, int):
            eos_token_ids = {eos_token_id}
        else:
            eos_token_ids = set(eos_token_id)
        # When pad == eos, trailing "padding" is indistinguishable from a real EOS, so count up to the
        # first EOS instead of stripping trailing pads.
        pad_equals_eos = pad_token_id is not None and pad_token_id in eos_token_ids

        completion_lengths = []
        for row in completion_tokens.tolist():
            if pad_equals_eos and eos_token_ids:
                # Length runs through the first EOS (inclusive); no EOS means the full row.
                completion_length = len(row)
                for idx, token_id in enumerate(row):
                    if token_id in eos_token_ids:
                        completion_length = idx + 1
                        break
            elif pad_token_id is not None:
                # Distinct pad token: strip right-padding.
                completion_length = len(row)
                while completion_length > 0 and row[completion_length - 1] == pad_token_id:
                    completion_length -= 1
            else:
                completion_length = len(row)
            completion_lengths.append(completion_length)

        return torch.tensor(completion_lengths, device=generated_tokens.device, dtype=torch.long)
    # ──────────────────────────────────────────────────────────────────────
    # Dataset / Dataloader
    # ──────────────────────────────────────────────────────────────────────

    def _set_signature_columns_if_needed(self):
        # Keep the distillation-specific columns so the Trainer does not strip them from the dataset.
        super()._set_signature_columns_if_needed()
        extra_columns = ["prompts", "prompt_attention_mask", "messages", "chat_template_kwargs", "tools"]
        if self._signature_columns is None:
            self._signature_columns = extra_columns
        else:
            for col in extra_columns:
                if col not in self._signature_columns:
                    self._signature_columns.append(col)

    def _get_train_sampler(self, dataset=None):
        """Build the sampler: duplicate each example `num_generations` times within a generation batch."""
        if dataset is None:
            dataset = self.train_dataset
        # repeat_count stays 1 — window-level repetition is done by _RepeatBatchDataLoader so the same
        # examples are not re-collated (re-tokenized) on every repeat.
        return RepeatSampler(
            data_source=dataset,
            mini_repeat_count=self.num_generations,
            batch_size=self.args.generation_batch_size * self.accelerator.num_processes,
            repeat_count=1,
            shuffle=True,
            seed=self.args.seed,
        )

    def get_train_dataloader(self):
        """
        Override to load one generation batch per optimizer window.

        The dataloader yields batches of size `per_device_train_batch_size * gradient_accumulation_steps`.
        `_RepeatBatchDataLoader` repeats each collated batch `gradient_accumulation_steps` times (the sampler itself
        uses `repeat_count=1`) so the Trainer's loop iterates the correct number of times.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        train_dataset = self.train_dataset
        data_collator = self.data_collator

        dataloader_params = {
            "batch_size": self._train_batch_size * self.args.gradient_accumulation_steps,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
        }

        # Sampler options only apply to map-style datasets; iterable datasets manage their own order.
        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_train_sampler()
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["worker_init_fn"] = partial(
                seed_worker,
                num_workers=self.args.dataloader_num_workers,
                rank=self.args.process_index,
            )
            if self.args.dataloader_num_workers > 0:
                dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

        base_dataloader = self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
        return _RepeatBatchDataLoader(base_dataloader, repeat_count=self.args.gradient_accumulation_steps)

    # ──────────────────────────────────────────────────────────────────────
    # Buffering: on/off-policy mixing across gradient accumulation steps
    # ──────────────────────────────────────────────────────────────────────

    @profiling_decorator
    def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]:
        """Return the micro-batch for the current accumulation step, refilling the buffer once per window."""
        # Buffering only applies while training; eval batches pass straight through.
        if not self.model.training:
            return generation_batch

        buffer_steps = self.args.gradient_accumulation_steps
        if self._buffer_step % buffer_steps == 0 or self._buffered_inputs is None:
            self._fill_buffer(generation_batch, buffer_steps)

        slice_idx = self._buffer_step % buffer_steps
        inputs = self._buffered_inputs[slice_idx]
        self._buffer_step += 1
        return inputs
    @profiling_decorator
    def _fill_buffer(self, generation_batch: dict[str, torch.Tensor | Any], buffer_steps: int):
        """Split batch into slices and decide which are on-policy (student-generated) vs off-policy."""
        slices = split_tensor_dict(generation_batch, buffer_steps)

        # Decide on-policy flags (synchronized across processes)
        # Only rank 0 draws the randomness, then broadcasts, so every process agrees on the split.
        if self.accelerator.is_main_process:
            on_policy_flags = [random.random() <= self.lmbda for _ in range(buffer_steps)]
        else:
            on_policy_flags = [False] * buffer_steps
        on_policy_flags = broadcast_object_list(on_policy_flags, from_process=0)

        self._buffered_inputs = [None] * buffer_steps
        self._buffered_on_policy_flags = on_policy_flags
        self._buffered_text_logs = [None] * buffer_steps

        # Store off-policy slices directly
        on_policy_indices = []
        for i, is_on_policy in enumerate(on_policy_flags):
            if is_on_policy:
                on_policy_indices.append(i)
            else:
                self._buffered_inputs[i] = slices[i]

        # Generate student completions for on-policy slices
        if on_policy_indices:
            self._generate_student_completions(slices, on_policy_indices)

        # Gather on-policy text logs once per optimizer step (all processes must participate)
        if self.log_completions:
            on_policy_prompts = []
            on_policy_completions = []
            for i in on_policy_indices:
                if self._buffered_text_logs[i] is not None:
                    prompts, completions = self._buffered_text_logs[i]
                    on_policy_prompts.extend(prompts)
                    on_policy_completions.extend(completions)
            # gather_object is a collective op: it must run on every rank even with empty lists.
            self._textual_logs["prompt"].extend(gather_object(on_policy_prompts))
            self._textual_logs["completion"].extend(gather_object(on_policy_completions))

    @profiling_decorator
    def _generate_student_completions(self, slices: list[dict[str, torch.Tensor | Any]], on_policy_indices: list[int]):
        """Generate completions from the student model for on-policy training."""
        if not self.use_vllm:
            self._generate_with_model(slices, on_policy_indices)
            return

        # Collect all prompts across on-policy slices, stripping left-padding so vLLM
        # receives only real prompt token IDs (like GRPO trainer).
        local_prompts = []
        local_slice_indices = []
        pad_token_id = self.processing_class.pad_token_id
        for slice_idx in on_policy_indices:
            prompt_mask = slices[slice_idx].get("prompt_attention_mask")
            for i, prompt in enumerate(slices[slice_idx]["prompts"]):
                if prompt_mask is not None:
                    prompt = prompt[prompt_mask[i].bool()]
                elif pad_token_id is not None:
                    # No mask available: drop leading pad tokens up to the first real token.
                    first_non_pad = (prompt != pad_token_id).nonzero(as_tuple=True)[0]
                    if len(first_non_pad) > 0:
                        prompt = prompt[first_non_pad[0] :]
                local_prompts.append(prompt)
                local_slice_indices.append(slice_idx)

        # Sync student weights to vLLM if needed
        if (
            self.state.global_step != self._last_vllm_sync_step
            and self.state.global_step % self.vllm_sync_frequency == 0
        ):
            self.vllm_generation.sync_weights()
            self._last_vllm_sync_step = self.state.global_step

        # Generate completions — pass token IDs directly, no text decoding
        # NOTE(review): _store_completions_in_buffer pairs completion_ids[i] with local_slice_indices[i],
        # which assumes one completion per input prompt (prompts are already duplicated by RepeatSampler's
        # mini_repeat_count) — verify this holds when num_generations > 1.
        prompt_ids_list = [p.tolist() for p in local_prompts]
        _, completion_ids, _, _ = self.vllm_generation.generate(
            prompts=prompt_ids_list, images=None, num_generations=self.num_generations
        )

        # Process completions into the buffer
        self._store_completions_in_buffer(
            slices, on_policy_indices, local_slice_indices, local_prompts, completion_ids
        )
int(completion_lengths[idx].item()) + completion_texts.append( + self.processing_class.decode( + generated_tokens[idx, length : length + completion_length].tolist(), + skip_special_tokens=False, + ) + ) + + updated = dict(slice_inputs) + updated["input_ids"] = generated_tokens + updated["attention_mask"] = new_attention_mask + updated["labels"] = new_labels + + self._buffered_inputs[slice_idx] = updated + self._buffered_text_logs[slice_idx] = (prompt_texts, completion_texts) + + def _store_completions_in_buffer( + self, + slices: list[dict[str, torch.Tensor | Any]], + on_policy_indices: list[int], + local_slice_indices: list[int], + local_prompts: list[torch.Tensor], + completion_ids: list, + ): + """Process vLLM completions and store them in the buffer. + + Uses original prompt token IDs directly (no decode/re-encode roundtrip), following the same approach as + GRPOTrainer. + """ + device = self.accelerator.device + pad_token_id = self.processing_class.pad_token_id if self.processing_class.pad_token_id is not None else 0 + max_completion_length = self.generation_config.max_new_tokens + + # Group completions and prompt token IDs by slice + slice_completions = {idx: [] for idx in on_policy_indices} + slice_prompt_ids = {idx: [] for idx in on_policy_indices} + for i, slice_idx in enumerate(local_slice_indices): + slice_completions[slice_idx].append(completion_ids[i]) + slice_prompt_ids[slice_idx].append(local_prompts[i]) + + for slice_idx in on_policy_indices: + slice_inputs = slices[slice_idx] + prompt_id_tensors = slice_prompt_ids[slice_idx] + prompt_width = max(len(p) for p in prompt_id_tensors) + prompt_token_lengths = torch.tensor([len(p) for p in prompt_id_tensors], device=device, dtype=torch.long) + prompt_attention_mask = ( + torch.arange(prompt_width, device=device).unsqueeze(0) + >= (prompt_width - prompt_token_lengths).unsqueeze(1) + ).long() + + # Left-pad prompt token IDs to the longest prompt in this slice + prompt_ids = torch.stack( + [F.pad(p, 
(prompt_width - len(p), 0), value=pad_token_id) for p in prompt_id_tensors] + ).to(device) + + # Pad/truncate completions (right-pad to max_completion_length) + completion_tensors = [] + completion_ids_for_text = [] + completion_lengths = [] + for comp_ids in slice_completions[slice_idx]: + t = torch.tensor(comp_ids, device=device) + if len(t) > max_completion_length: + t = t[:max_completion_length] + completion_ids_for_text.append(t.tolist()) + completion_lengths.append(len(t)) + if len(t) < max_completion_length: + padding = torch.full((max_completion_length - len(t),), pad_token_id, device=device, dtype=t.dtype) + t = torch.cat([t, padding]) + completion_tensors.append(t) + + completion_ids_padded = torch.stack(completion_tensors) + new_input_ids = torch.cat([prompt_ids, completion_ids_padded], dim=1) + completion_lengths = torch.tensor(completion_lengths, device=device, dtype=torch.long) + new_attention_mask, new_labels = self._build_sequence_batch( + new_input_ids, prompt_width, prompt_token_lengths, completion_lengths + ) + + # Decode for logging only + prompt_texts = self.processing_class.batch_decode( + prompt_id_tensors, skip_special_tokens=False, clean_up_tokenization_spaces=False + ) + completion_texts = self.processing_class.batch_decode( + completion_ids_for_text, skip_special_tokens=False, clean_up_tokenization_spaces=False + ) + + updated = dict(slice_inputs) + updated["input_ids"] = new_input_ids + updated["attention_mask"] = new_attention_mask + updated["labels"] = new_labels + # Update prompts to match the new padding width so prompt_length is consistent + updated["prompts"] = prompt_ids + updated["prompt_attention_mask"] = prompt_attention_mask + + self._buffered_inputs[slice_idx] = updated + self._buffered_text_logs[slice_idx] = (prompt_texts, completion_texts) + + @staticmethod + def _build_sequence_batch( + new_input_ids: torch.Tensor, + prompt_width: int, + prompt_token_lengths: torch.Tensor, + completion_lengths: torch.Tensor, + ) -> 
tuple[torch.Tensor, torch.Tensor]: + """Build attention mask and labels from prompt/completion lengths.""" + prompt_token_lengths = prompt_token_lengths.to(device=new_input_ids.device, dtype=torch.long) + completion_lengths = completion_lengths.to(device=new_input_ids.device, dtype=torch.long) + positions = torch.arange(new_input_ids.shape[1], device=new_input_ids.device).unsqueeze(0) + prompt_mask = (positions < prompt_width) & (positions >= (prompt_width - prompt_token_lengths).unsqueeze(1)) + completion_mask = (positions >= prompt_width) & (positions < (prompt_width + completion_lengths).unsqueeze(1)) + new_attention_mask = (prompt_mask | completion_mask).long() + + new_labels = torch.full_like(new_input_ids, -100) + new_labels[completion_mask] = new_input_ids[completion_mask] + + return new_attention_mask, new_labels + + # ────────────────────────────────────────────────────────────────────── + # Loss computation + # ────────────────────────────────────────────────────────────────────── + + @staticmethod + def _reduce_divergence_loss(jsd, labels=None, reduction="batchmean"): + """Reduce a per-token divergence tensor using the trainer's label mask semantics.""" + mask = None + if labels is not None: + mask = labels != -100 + jsd = jsd[mask] + + if reduction == "batchmean": + if labels is not None: + num_tokens = mask.sum() + if num_tokens == 0: + return jsd.sum() * 0.0 # no completion tokens — return zero-grad scalar + return jsd.sum() / num_tokens + return jsd.sum() / jsd.size(0) + elif reduction == "sum": + return jsd.sum() + elif reduction == "mean": + return jsd.mean() + else: + return jsd + + @staticmethod + def generalized_jsd_loss( + student_logits, + teacher_logits, + labels=None, + beta=0.5, + temperature=1.0, + reduction="batchmean", + top_k=0, + add_tail=True, + ): + """ + Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation. + + Args: + student_logits: Tensor of shape (batch_size, sequence_length, vocab_size). 
+ teacher_logits: Tensor of shape (batch_size, sequence_length, vocab_size). + labels: Tensor of shape (batch_size, sequence_length) with -100 for positions to ignore. + beta: Interpolation coefficient. 0.0 = forward KL, 1.0 = reverse KL. + temperature: Softmax temperature. + reduction: 'batchmean', 'sum', 'mean', or 'none'. + top_k: Number of top tokens to restrict the loss to. The support set depends on beta: + beta=0 (forward KL) uses teacher's top-k, beta=1 (reverse KL) uses student's top-k, 0 0 and student_logits.size(-1) > top_k: + neg_inf = torch.full((), float("-inf"), dtype=student_logits.dtype, device=student_logits.device) + student_log_probs_full = F.log_softmax(student_logits, dim=-1) + teacher_log_probs_full = F.log_softmax(teacher_logits, dim=-1) + + if beta == 0: + # Forward KL: teacher-selected support + _, support = teacher_logits.topk(top_k, dim=-1) + support_mask = torch.ones_like(support, dtype=torch.bool) + elif beta == 1: + # Reverse KL: student-selected support + _, support = student_logits.topk(top_k, dim=-1) + support_mask = torch.ones_like(support, dtype=torch.bool) + else: + # JSD: union of both supports (concatenate + deduplicate) + _, student_top = student_logits.topk(top_k, dim=-1) + _, teacher_top = teacher_logits.topk(top_k, dim=-1) + support = torch.cat([teacher_top, student_top], dim=-1) + support_mask = torch.ones(support.shape, dtype=torch.bool, device=support.device) + for i in range(1, support.shape[-1]): + prev_matches = support[..., i : i + 1] == support[..., :i] + prev_valid = support_mask[..., :i] + support_mask[..., i] &= ~(prev_matches & prev_valid).any(dim=-1) + support = torch.where(support_mask, support, torch.zeros_like(support)) + + student_support_logps = student_log_probs_full.gather(-1, support) + teacher_support_logps = teacher_log_probs_full.gather(-1, support) + + # Mask invalid (duplicate) positions with -inf + student_topk_logps = torch.where(support_mask, student_support_logps, neg_inf) + teacher_topk_logps 
= torch.where(support_mask, teacher_support_logps, neg_inf) + + if add_tail: + # Add tail bucket: append log(1 - sum(exp(top_k_logps))) to preserve + # the remaining probability mass outside the top-k. This prevents trivial + # zero loss when top_k is small (especially top_k=1). + base_support_mask = support_mask + student_log_probs, support_mask = _add_tail_bucket(student_topk_logps, base_support_mask) + teacher_log_probs, _ = _add_tail_bucket(teacher_topk_logps, base_support_mask) + else: + student_log_probs = student_topk_logps - torch.logsumexp(student_topk_logps, dim=-1, keepdim=True) + teacher_log_probs = teacher_topk_logps - torch.logsumexp(teacher_topk_logps, dim=-1, keepdim=True) + else: + student_log_probs = F.log_softmax(student_logits, dim=-1) + teacher_log_probs = F.log_softmax(teacher_logits, dim=-1) + + jsd = _jsd_divergence(student_log_probs, teacher_log_probs, beta, support_mask) + return DistillationTrainer._reduce_divergence_loss(jsd, labels=labels, reduction=reduction) + + def _get_reverse_kl_top_1_tokens( + self, student_scores: torch.Tensor, completion_tokens: torch.Tensor + ) -> torch.Tensor: + """Return the reverse-KL top-1 token IDs for the mixed top-1 loss path. + + Args: + student_scores: Any (B, T, V) tensor whose argmax selects the student's top token + (logits or log-probs — both are order-preserving). + completion_tokens: (B, T) actual token IDs in the completion. 
+ """ + if self.reverse_kl_top_1_mode == "argmax": + return student_scores.argmax(dim=-1) + return completion_tokens + + def _compute_sparse_top_1_divergence_loss( + self, + student_log_probs: torch.Tensor, + teacher_top1_token_ids: torch.Tensor, + teacher_top1_logprobs: torch.Tensor, + reverse_token_ids: torch.Tensor, + reverse_teacher_logprobs: torch.Tensor, + labels: torch.Tensor, + ) -> torch.Tensor: + """Compute exact generalized JSD/KL on top-1 support for the mixed beta>0 path.""" + neg_inf = torch.full((), float("-inf"), dtype=student_log_probs.dtype, device=student_log_probs.device) + + if self.beta == 1: + support = reverse_token_ids.unsqueeze(-1) + support_mask = torch.ones_like(support, dtype=torch.bool) + teacher_support_logprobs = reverse_teacher_logprobs.unsqueeze(-1) + else: + teacher_support = teacher_top1_token_ids.unsqueeze(-1) + reverse_support = reverse_token_ids.unsqueeze(-1) + support = torch.cat([teacher_support, reverse_support], dim=-1) + support_mask = torch.ones_like(support, dtype=torch.bool) + support_mask[..., 1] = support[..., 1] != support[..., 0] + teacher_support_logprobs = torch.stack([teacher_top1_logprobs, reverse_teacher_logprobs], dim=-1) + support = torch.where(support_mask, support, torch.zeros_like(support)) + + student_support_logprobs = student_log_probs.gather(-1, support) + student_support_logprobs = torch.where(support_mask, student_support_logprobs, neg_inf) + teacher_support_logprobs = torch.where(support_mask, teacher_support_logprobs, neg_inf) + + if self.loss_add_tail: + base_support_mask = support_mask + student_sparse_log_probs, support_mask = _add_tail_bucket(student_support_logprobs, base_support_mask) + teacher_sparse_log_probs, _ = _add_tail_bucket(teacher_support_logprobs, base_support_mask) + else: + student_sparse_log_probs = student_support_logprobs - torch.logsumexp( + student_support_logprobs, dim=-1, keepdim=True + ) + teacher_sparse_log_probs = teacher_support_logprobs - torch.logsumexp( + 
teacher_support_logprobs, dim=-1, keepdim=True + ) + + jsd = _jsd_divergence(student_sparse_log_probs, teacher_sparse_log_probs, self.beta, support_mask) + return self._reduce_divergence_loss(jsd, labels=labels, reduction="batchmean") + + def _compute_local_sparse_top_1_divergence_loss( + self, + student_logits: torch.Tensor, + teacher_logits: torch.Tensor, + completion_tokens: torch.Tensor, + labels: torch.Tensor, + ) -> torch.Tensor: + """Compute the mixed top-1 loss for a local teacher using gathered full-logit probabilities.""" + student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1) + teacher_log_probs = F.log_softmax(teacher_logits / self.temperature, dim=-1) + + teacher_top1_token_ids = teacher_logits.argmax(dim=-1) + teacher_top1_logprobs = teacher_log_probs.gather(dim=-1, index=teacher_top1_token_ids.unsqueeze(-1)).squeeze( + -1 + ) + reverse_token_ids = self._get_reverse_kl_top_1_tokens(student_logits, completion_tokens) + reverse_teacher_logprobs = teacher_log_probs.gather(dim=-1, index=reverse_token_ids.unsqueeze(-1)).squeeze(-1) + + return self._compute_sparse_top_1_divergence_loss( + student_log_probs=student_log_probs, + teacher_top1_token_ids=teacher_top1_token_ids, + teacher_top1_logprobs=teacher_top1_logprobs, + reverse_token_ids=reverse_token_ids, + reverse_teacher_logprobs=reverse_teacher_logprobs, + labels=labels, + ) + + def _get_teacher_logits(self, inputs: dict[str, torch.Tensor | Any]) -> torch.Tensor: + """Get teacher logits — dispatches between local model and external server.""" + if self.teacher_model is not None: + self.teacher_model.eval() + with torch.no_grad(): + return self.teacher_model( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + ).logits + elif self.use_teacher_server: + raise NotImplementedError( + "Fetching full teacher logits with use_teacher_server=True is not supported. 
" + "Server-backed distillation only supports per-token logprobs via " + "`_get_teacher_token_logprobs_from_server`." + ) + else: + raise ValueError("No teacher model or teacher server configured.") + + def _get_teacher_token_logprobs_from_server( + self, + inputs: dict[str, torch.Tensor | Any], + aligned_prompt_length: int, + ) -> dict[str, torch.Tensor]: + """Fetch per-token teacher logprobs from an external vLLM server. + + Returns a dict with: + ``actual_logprobs`` – (batch, completion_length) teacher log-prob for the actual + token at each position (for reverse KL). + ``topk_logprobs`` – (batch, completion_length, K) teacher top-k sorted logprobs + (for forward KL). + ``topk_token_ids`` – (batch, completion_length, K) corresponding token IDs. + """ + import numpy as np + + input_ids = inputs["input_ids"] + batch_size = input_ids.shape[0] + sequences, prompt_lengths, completion_lengths = build_teacher_request_inputs( + input_ids, + inputs["attention_mask"], + prompt_attention_mask=inputs.get("prompt_attention_mask"), + labels=inputs.get("labels"), + ) + + # The pure forward server path can use the requested teacher top-k support. + # When beta > 0, config validation restricts the server-backed path to top-1. + requested_top_k = self.loss_top_k + result = self.teacher_client.get_sequence_logprobs( + sequences=sequences, + prompt_lengths=prompt_lengths, + top_logprobs=requested_top_k, + temperature=self.temperature, + ) + K = requested_top_k + + device = input_ids.device + labels = inputs.get("labels") + if labels is None: + raise ValueError("labels are required to align teacher-server logprobs with the student loss tensors.") + + # The student loss slices tensors in padded-sequence coordinates starting at `aligned_prompt_length`. + # Place each teacher completion into that same coordinate system by locating the first non-masked completion + # token in `labels`. 
This works for both left-padded off-policy batches and on-policy batches where + # completions are right-padded after a fixed-width prompt block. + completion_offsets = [] + label_mask = labels != -100 + for sample_mask, comp_len in zip(label_mask, completion_lengths, strict=True): + if comp_len == 0: + completion_offsets.append(0) + continue + completion_start = int(torch.nonzero(sample_mask, as_tuple=False)[0].item()) + completion_offsets.append(completion_start - aligned_prompt_length) + + # Size the output tensors to tightly fit the teacher logprobs. Using the full padded + # sequence length would include padding positions with -inf teacher logprobs, producing + # inf in the forward pass and NaN gradients in the backward pass (0 * inf = NaN). + completion_length = max( + (offset + len(lps) for offset, lps in zip(completion_offsets, result["logprobs"], strict=True)), + default=0, + ) + + # actual_logprobs: (B, T) — teacher logprob for the actual token + def _actual_to_tensor(key): + arr = np.full((batch_size, completion_length), float("-inf"), dtype=np.float32) + for i, (offset, seq_lps) in enumerate(zip(completion_offsets, result[key], strict=True)): + if seq_lps: + vals = np.array(seq_lps, dtype=np.float32) # (comp_len_i, 1) + arr[i, offset : offset + vals.shape[0]] = vals[:, 0] + return torch.from_numpy(arr).to(device) + + # topk: (B, T, K) + def _topk_to_tensor(key, k, np_dtype, fill): + arr = np.full((batch_size, completion_length, k), fill, dtype=np_dtype) + for i, (offset, seq_vals) in enumerate(zip(completion_offsets, result[key], strict=True)): + if seq_vals: + vals = np.array(seq_vals, dtype=np_dtype) # (comp_len_i, k) + arr[i, offset : offset + vals.shape[0], :] = vals + return torch.from_numpy(arr).to(device) + + return { + "actual_logprobs": _actual_to_tensor("actual_logprobs"), + "topk_logprobs": _topk_to_tensor("logprobs", K, np.float32, float("-inf")), + "topk_token_ids": _topk_to_tensor("logprob_token_ids", K, np.int64, 0), + } + + def 
_compute_server_sparse_top_1_divergence_loss( + self, + teacher_result: dict[str, torch.Tensor], + student_log_probs: torch.Tensor, + completion_tokens: torch.Tensor, + labels: torch.Tensor, + ) -> torch.Tensor: + """Compute exact sparse top-1 generalized JSD/KL from server-provided teacher logprobs. + + Args: + teacher_result: dict with ``actual_logprobs`` (B, T), ``topk_logprobs`` (B, T, K), + ``topk_token_ids`` (B, T, K). + student_log_probs: (B, T, V) student log-softmax over vocabulary. + completion_tokens: (B, T) actual token IDs in the completion. + labels: (B, T) with -100 for positions to ignore. + """ + topk_teacher_lps = teacher_result["topk_logprobs"] # (B, T, 1) + topk_token_ids = teacher_result["topk_token_ids"] # (B, T, 1) + actual_teacher_lps = teacher_result["actual_logprobs"] # (B, T) + required = labels != -100 + + missing_actual = required & ~torch.isfinite(actual_teacher_lps) + if missing_actual.any(): + missing_count = int(missing_actual.sum().item()) + total_required = int(required.sum().item()) + raise ValueError( + "Teacher server is missing actual-token logprobs for required reverse-KL positions: " + f"{missing_count}/{total_required}." + ) + if self.beta < 1: + teacher_top1_logprobs = topk_teacher_lps.squeeze(-1) + missing_top1 = required & ~torch.isfinite(teacher_top1_logprobs) + if missing_top1.any(): + missing_count = int(missing_top1.sum().item()) + total_required = int(required.sum().item()) + raise ValueError( + "Teacher server is missing top-1 logprobs for required forward-KL positions: " + f"{missing_count}/{total_required}." + ) + + # Server path only supports "sampled" mode — config validation enforces this, but we guard + # explicitly so future relaxations of the config check don't silently change behaviour. 
+ reverse_token_ids = self._get_reverse_kl_top_1_tokens(student_log_probs, completion_tokens) + return self._compute_sparse_top_1_divergence_loss( + student_log_probs=student_log_probs, + teacher_top1_token_ids=topk_token_ids.squeeze(-1), + teacher_top1_logprobs=topk_teacher_lps.squeeze(-1), + reverse_token_ids=reverse_token_ids, + reverse_teacher_logprobs=actual_teacher_lps, + labels=labels, + ) + + def _compute_server_forward_kl_loss( + self, + teacher_result: dict[str, torch.Tensor], + student_log_probs: torch.Tensor, + labels: torch.Tensor, + ) -> torch.Tensor: + """Compute sparse forward KL from server-provided teacher top-k logprobs (beta==0 path). + + Args: + teacher_result: dict with ``topk_logprobs`` (B, T, K) and ``topk_token_ids`` (B, T, K). + student_log_probs: (B, T, V) student log-softmax over vocabulary. + labels: (B, T) with -100 for positions to ignore. + """ + teacher_topk_logprobs = teacher_result["topk_logprobs"] + teacher_topk_token_ids = teacher_result["topk_token_ids"] + valid = teacher_topk_logprobs > float("-inf") + neg_inf = torch.full((), float("-inf"), dtype=student_log_probs.dtype, device=student_log_probs.device) + student_topk_logprobs = student_log_probs.gather(dim=-1, index=teacher_topk_token_ids) + student_topk_logprobs = torch.where(valid, student_topk_logprobs, neg_inf) + teacher_topk_logprobs = torch.where(valid, teacher_topk_logprobs, neg_inf) + + if self.loss_add_tail: + base_support_mask = valid + student_sparse_log_probs, support_mask = _add_tail_bucket(student_topk_logprobs, base_support_mask) + teacher_sparse_log_probs, _ = _add_tail_bucket(teacher_topk_logprobs, base_support_mask) + else: + support_mask = valid + student_sparse_log_probs = student_topk_logprobs - torch.logsumexp( + student_topk_logprobs, dim=-1, keepdim=True + ) + teacher_sparse_log_probs = teacher_topk_logprobs - torch.logsumexp( + teacher_topk_logprobs, dim=-1, keepdim=True + ) + + jsd = _jsd_divergence( + student_sparse_log_probs, + 
teacher_sparse_log_probs, + beta=0.0, + support_mask=support_mask, + ) + return self._reduce_divergence_loss(jsd, labels=labels, reduction="batchmean") + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + self._raise_if_local_teacher_tokenizer_mismatch() + + if self.use_liger_loss: + loss = self._compute_liger_loss(model, inputs) + return (loss, None) if return_outputs else loss + + # Student forward pass + student_outputs = model( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + ) + prompt_length = self._compute_prompt_length(inputs) + labels = inputs["labels"][:, prompt_length:] + completion_tokens = inputs["input_ids"][:, prompt_length:] + + if self.use_teacher_server: + # Server path: token-level divergence using teacher logprobs. + # The server returns: + # actual_logprobs – (B, T) teacher log p(x_actual) (for reverse KL) + # topk_logprobs – (B, T, K) teacher top-k sorted logprobs (for forward KL) + # topk_token_ids – (B, T, K) corresponding token IDs + teacher_result = self._get_teacher_token_logprobs_from_server(inputs, prompt_length) + + student_logits = student_outputs.logits[:, prompt_length - 1 : -1, :] + student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1) + + comp_len = teacher_result["actual_logprobs"].shape[1] + completion_tokens = completion_tokens[:, :comp_len] + trimmed_labels = labels[:, :comp_len] + + if self.beta > 0: + loss = self._compute_server_sparse_top_1_divergence_loss( + teacher_result=teacher_result, + student_log_probs=student_log_probs[:, :comp_len, :], + completion_tokens=completion_tokens, + labels=trimmed_labels, + ) + else: + loss = self._compute_server_forward_kl_loss( + teacher_result=teacher_result, + student_log_probs=student_log_probs[:, :comp_len, :], + labels=trimmed_labels, + ) + else: + # Local teacher: exact full-vocabulary loss except for the shared mixed top-1 path. 
+ teacher_logits = self._get_teacher_logits(inputs) + student_logits = student_outputs.logits[:, prompt_length - 1 : -1, :] + teacher_logits = teacher_logits[:, prompt_length - 1 : -1, :] + if self.beta > 0 and self.loss_top_k == 1: + loss = self._compute_local_sparse_top_1_divergence_loss( + student_logits=student_logits, + teacher_logits=teacher_logits, + completion_tokens=completion_tokens, + labels=labels, + ) + else: + loss = self.generalized_jsd_loss( + student_logits=student_logits, + teacher_logits=teacher_logits, + labels=labels, + beta=self.beta, + temperature=self.temperature, + top_k=self.loss_top_k, + add_tail=self.loss_add_tail, + ) + + return (loss, student_outputs) if return_outputs else loss + + def _compute_liger_loss(self, model, inputs): + """Memory-efficient JSD using Liger kernel (operates on hidden states, not full logits).""" + unwrapped_student = self.accelerator.unwrap_model(model) + if hasattr(unwrapped_student, "get_decoder") and unwrapped_student.get_decoder() is not None: + base_student = unwrapped_student.get_decoder() + else: + base_student = getattr( + unwrapped_student, getattr(unwrapped_student, "base_model_prefix", "model"), unwrapped_student + ) + + student_outputs = base_student( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + use_cache=False, + ) + + self.teacher_model.eval() + unwrapped_teacher = self.accelerator.unwrap_model(self.teacher_model) + if hasattr(unwrapped_teacher, "get_decoder") and unwrapped_teacher.get_decoder() is not None: + base_teacher = unwrapped_teacher.get_decoder() + else: + base_teacher = getattr( + unwrapped_teacher, getattr(unwrapped_teacher, "base_model_prefix", "model"), unwrapped_teacher + ) + with torch.no_grad(): + teacher_outputs = base_teacher( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + use_cache=False, + ) + + student_hidden = student_outputs.last_hidden_state[:, :-1] + teacher_hidden = teacher_outputs.last_hidden_state[:, :-1] 
+ del student_outputs, teacher_outputs + + student_hidden = student_hidden.reshape(-1, student_hidden.shape[-1]) + teacher_hidden = teacher_hidden.reshape(-1, teacher_hidden.shape[-1]) + + labels_mask = inputs["labels"] != -100 + masked_input_ids = torch.where(labels_mask, inputs["input_ids"], torch.full_like(inputs["input_ids"], -100)) + true_labels = masked_input_ids[:, 1:].contiguous().reshape(-1) + + student_head = unwrapped_student.get_output_embeddings() + teacher_head = unwrapped_teacher.get_output_embeddings() + + loss = self.liger_jsd_loss( + student_input=student_hidden, + student_weight=student_head.weight, + teacher_input=teacher_hidden, + teacher_weight=teacher_head.weight, + true_labels=true_labels, + student_bias=getattr(student_head, "bias", None), + teacher_bias=getattr(teacher_head, "bias", None), + ) + + del student_hidden, teacher_hidden, true_labels + return loss + + def _get_liger_zero3_lm_head_gather_ctx(self, model: nn.Module): + """Context manager for gathering lm_head parameters under Liger + ZeRO-3.""" + if not self.use_liger_loss: + return nullcontext() + + deepspeed_plugin = self.accelerator.state.deepspeed_plugin + if deepspeed_plugin is None or deepspeed_plugin.zero_stage != 3: + return nullcontext() + + import deepspeed + + unwrapped_student = self.accelerator.unwrap_model(model) + unwrapped_teacher = self.accelerator.unwrap_model(self.teacher_model) + student_head = unwrapped_student.get_output_embeddings() + teacher_head = unwrapped_teacher.get_output_embeddings() + params = [student_head.weight, teacher_head.weight] + if student_head.bias is not None: + params.append(student_head.bias) + if teacher_head.bias is not None: + params.append(teacher_head.bias) + return deepspeed.zero.GatheredParameters(params, modifier_rank=None) + + # ────────────────────────────────────────────────────────────────────── + # Training step & Logging + # ────────────────────────────────────────────────────────────────────── + + @profiling_decorator + 
def training_step( + self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch: int | None = None + ) -> torch.Tensor: + """Training step with on/off-policy loss tracking and completion stats.""" + buffer_steps = self.args.gradient_accumulation_steps + + with self._get_liger_zero3_lm_head_gather_ctx(model): + loss = super().training_step(model, inputs, num_items_in_batch) + + slice_idx = (self._buffer_step - 1) % buffer_steps + + # Determine if this slice is on-policy + is_on_policy = False + if self._buffered_on_policy_flags is not None and slice_idx < len(self._buffered_on_policy_flags): + is_on_policy = self._buffered_on_policy_flags[slice_idx] + + # Track completion length stats — read from buffered inputs (which reflect on-policy generation) + actual_inputs = self._buffered_inputs[slice_idx] if self._buffered_inputs is not None else inputs + labels = actual_inputs.get("labels") + if labels is not None: + completion_lengths = (labels != -100).sum(dim=1).float() + gathered_lengths = self.accelerator.gather(completion_lengths) + mode = "train" + prefix = "on_policy" if is_on_policy else "off_policy" + self._metrics[mode][f"completions/{prefix}_mean_length"].append(gathered_lengths.mean().item()) + self._metrics[mode][f"completions/{prefix}_max_length"].append(gathered_lengths.max().item()) + self._metrics[mode][f"completions/{prefix}_min_length"].append(gathered_lengths.min().item()) + + # Log fraction of completions that hit max_completion_length (truncated) + max_comp_len = getattr(self.generation_config, "max_new_tokens", None) + if is_on_policy and max_comp_len is not None: + truncated_frac = (gathered_lengths >= max_comp_len).float().mean().item() + self._metrics[mode]["completions/truncated_fraction"].append(truncated_frac) + + # Track loss per policy type + loss_scalar = float(loss.detach()) + step_equiv = 1.0 / self.args.gradient_accumulation_steps + if is_on_policy: + self._on_policy_loss_total += loss_scalar + 
self._on_policy_step_equiv += step_equiv + else: + self._off_policy_loss_total += loss_scalar + self._off_policy_step_equiv += step_equiv + + return loss + + def log(self, logs: dict[str, float], start_time: float | None = None) -> None: + mode = "train" if self.model.training else "eval" + metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} + + if mode == "train": + # Aggregate on/off-policy losses across distributed processes + device = self.accelerator.device if hasattr(self.accelerator, "device") else torch.device("cpu") + vec = torch.tensor( + [ + self._on_policy_loss_total, + self._off_policy_loss_total, + self._on_policy_step_equiv, + self._off_policy_step_equiv, + ], + dtype=torch.float64, + device=device, + ) + + if ( + getattr(self.accelerator, "distributed_type", DistributedType.NO) != DistributedType.NO + and dist.is_available() + and dist.is_initialized() + ): + dist.all_reduce(vec, op=dist.ReduceOp.SUM) + + on_sum, off_sum, on_eq, off_eq = vec.tolist() + if on_eq > 0: + logs["on_policy_loss"] = round(on_sum / on_eq, 4) + if off_eq > 0: + logs["off_policy_loss"] = round(off_sum / off_eq, 4) + + self._on_policy_loss_total = self._off_policy_loss_total = 0.0 + self._on_policy_step_equiv = self._off_policy_step_equiv = 0.0 + + if mode == "eval": + metrics = {f"eval_{key}": val for key, val in metrics.items()} + + logs = {**logs, **metrics} + super().log(logs, start_time) + self._metrics[mode].clear() + + # Log completions to console and wandb + should_log_completions = ( + self.log_completions + and self.state.global_step > 0 + and self.state.global_step % self.log_completions_steps == 0 + ) + + if should_log_completions and self.accelerator.is_main_process: + prompts = list(self._textual_logs["prompt"]) + completions = list(self._textual_logs["completion"]) + + if prompts: + _print_completions_sample(prompts, completions, self.state.global_step, self.num_completions_to_print) + + # Log as a wandb Table + if 
self.args.report_to and "wandb" in self.args.report_to: + try: + import wandb + + if wandb.run is not None: + import pandas as pd + + table_data = { + "step": [str(self.state.global_step)] * len(prompts), + "prompt": prompts, + "completion": completions, + } + df = pd.DataFrame(table_data) + if self.num_completions_to_print and len(df) > self.num_completions_to_print: + df = df.sample(n=self.num_completions_to_print, random_state=42) + wandb.log({"completions": wandb.Table(dataframe=df)}) + except ImportError: + pass + + # Clear text logs on all processes after the logging interval + if should_log_completions: + self._textual_logs["prompt"].clear() + self._textual_logs["completion"].clear() diff --git a/trl/generation/vllm_client.py b/trl/generation/vllm_client.py index d993b283b7..cb05ae5a55 100644 --- a/trl/generation/vllm_client.py +++ b/trl/generation/vllm_client.py @@ -523,6 +523,200 @@ def update_model_params(self, model: nn.Module): # Update each parameter individually self.update_named_param(name, param.data) + def get_sequence_logprobs( + self, + sequences: list[list[int]], + prompt_lengths: list[int], + top_logprobs: int = 100, + temperature: float = 1.0, + use_binary: bool = True, + chunk_size: int = 0, + max_concurrent_requests: int = 4, + ) -> dict[str, list]: + """ + Computes teacher logprobs for existing token sequences without generating new tokens. + + Sends full sequences (prompt + completion) to the vLLM server and retrieves per-token top-k logprobs for the + completion region only. This is used for knowledge distillation where the teacher model evaluates existing + sequences rather than generating new ones. + + When `chunk_size > 0`, splits the batch into chunks and dispatches them concurrently via a thread pool, keeping + the server's data-parallel workers busy. + + When `use_binary=True`, uses base64-encoded numpy arrays for fast serialization instead of nested JSON lists. 
+ + Args: + sequences (`list[list[int]]`): + List of full token ID sequences (prompt + completion). + prompt_lengths (`list[int]`): + Number of prompt tokens in each sequence. Logprobs are returned starting from this position. + top_logprobs (`int`, *optional*, defaults to `100`): + Number of top logprobs to return per token position. + temperature (`float`, *optional*, defaults to `1.0`): + Temperature used when scoring the teacher distribution. + use_binary (`bool`, *optional*, defaults to `True`): + Use binary (base64 numpy) response format for faster serialization. + chunk_size (`int`, *optional*, defaults to `0`): + If > 0, split batch into chunks of this size and dispatch concurrently. If 0, send the entire batch in + a single request. + max_concurrent_requests (`int`, *optional*, defaults to `4`): + Maximum number of concurrent requests when using chunked dispatch. + + Returns: + `dict` with keys: + - `logprobs` (`list[list[list[float]]]`): + Per-token logprobs of shape (batch, completion_len, top_logprobs), sorted by descending + probability. + - `logprob_token_ids` (`list[list[list[int]]]`): + Token IDs corresponding to each logprob, same shape as `logprobs`. 
+ """ + from concurrent.futures import ThreadPoolExecutor, as_completed + + if temperature <= 0: + raise ValueError(f"temperature must be positive, got {temperature}") + + url = f"{self.base_url}/get_sequence_logprobs/" + response_format = "binary" if use_binary else "json" + + if chunk_size > 0 and len(sequences) > chunk_size: + # Chunked concurrent dispatch + n = len(sequences) + chunks = [] + for i in range(0, n, chunk_size): + chunks.append((sequences[i : i + chunk_size], prompt_lengths[i : i + chunk_size])) + + responses = [None] * len(chunks) + + def _send_chunk(idx, seqs, plens): + resp = self.session.post( + url, + json={ + "sequences": seqs, + "prompt_lengths": plens, + "top_logprobs": top_logprobs, + "temperature": temperature, + "response_format": response_format, + }, + ) + if resp.status_code != 200: + raise Exception(f"Request failed: {resp.status_code}, {resp.text}") + return idx, resp.json() + + with ThreadPoolExecutor(max_workers=min(max_concurrent_requests, len(chunks))) as executor: + futures = { + executor.submit(_send_chunk, idx, seqs, plens): idx for idx, (seqs, plens) in enumerate(chunks) + } + for future in as_completed(futures): + idx, result = future.result() + responses[idx] = result + + # Merge results + if use_binary: + return self._merge_binary_responses(responses, top_logprobs) + else: + all_logprobs = [] + all_token_ids = [] + for resp in responses: + all_logprobs.extend(resp["logprobs"]) + all_token_ids.extend(resp["logprob_token_ids"]) + return {"logprobs": all_logprobs, "logprob_token_ids": all_token_ids} + else: + # Single request + response = self.session.post( + url, + json={ + "sequences": sequences, + "prompt_lengths": prompt_lengths, + "top_logprobs": top_logprobs, + "temperature": temperature, + "response_format": response_format, + }, + ) + if response.status_code != 200: + raise Exception(f"Request failed: {response.status_code}, {response.text}") + + json_response = response.json() + if use_binary: + return 
self._decode_binary_logprobs(json_response) + else: + return { + "logprobs": json_response["logprobs"], + "logprob_token_ids": json_response["logprob_token_ids"], + } + + @staticmethod + def _decode_binary_logprobs(response: dict) -> dict[str, list]: + """Decode base64-encoded numpy arrays back to nested lists. + + Returns a dict with: + ``logprobs`` / ``logprob_token_ids`` — teacher's sorted top-k logprobs and + token IDs (shape per sequence: ``(comp_len, top_k)``). Used for the forward KL term. + ``actual_logprobs`` / ``actual_token_ids`` — teacher logprob for the actual + token at each position (shape per sequence: ``(comp_len, 1)``). Used for the reverse KL term. + """ + import numpy as np + + shape = response["shape"] # [batch, max_comp_len, top_k] + comp_lengths = response["completion_lengths"] + + logprobs_arr = np.frombuffer(base64.b64decode(response["logprobs_b64"]), dtype=np.float32).reshape(shape) + token_ids_arr = np.frombuffer(base64.b64decode(response["token_ids_b64"]), dtype=np.int32).reshape(shape) + + # Convert back to nested lists, trimming padding + all_logprobs = [] + all_token_ids = [] + for i, comp_len in enumerate(comp_lengths): + all_logprobs.append(logprobs_arr[i, :comp_len, :].tolist()) + all_token_ids.append(token_ids_arr[i, :comp_len, :].tolist()) + + result = {"logprobs": all_logprobs, "logprob_token_ids": all_token_ids} + + # Decode actual-token logprobs (for reverse KL) + if "actual_logprobs_b64" in response: + actual_shape = [shape[0], shape[1], 1] + actual_lp = np.frombuffer(base64.b64decode(response["actual_logprobs_b64"]), dtype=np.float32).reshape( + actual_shape + ) + actual_ids = np.frombuffer(base64.b64decode(response["actual_token_ids_b64"]), dtype=np.int32).reshape( + actual_shape + ) + all_actual_lps = [] + all_actual_ids = [] + for i, comp_len in enumerate(comp_lengths): + all_actual_lps.append(actual_lp[i, :comp_len, :].tolist()) + all_actual_ids.append(actual_ids[i, :comp_len, :].tolist()) + result["actual_logprobs"] = 
all_actual_lps + result["actual_token_ids"] = all_actual_ids + + return result + + @staticmethod + def _merge_binary_responses(responses: list[dict], top_logprobs: int) -> dict[str, list]: + """Merge binary responses from multiple chunks into a single result.""" + + all_logprobs = [] + all_token_ids = [] + all_actual_lps = [] + all_actual_ids = [] + for resp in responses: + decoded = VLLMClient._decode_binary_logprobs(resp) + all_logprobs.extend(decoded["logprobs"]) + all_token_ids.extend(decoded["logprob_token_ids"]) + if "actual_logprobs" in decoded: + all_actual_lps.extend(decoded["actual_logprobs"]) + all_actual_ids.extend(decoded["actual_token_ids"]) + + result = {"logprobs": all_logprobs, "logprob_token_ids": all_token_ids} + if all_actual_lps: + if len(all_actual_lps) != len(all_logprobs): + raise ValueError( + f"Inconsistent chunks: {len(all_actual_lps)} actual_logprobs entries " + f"but {len(all_logprobs)} logprobs entries." + ) + result["actual_logprobs"] = all_actual_lps + result["actual_token_ids"] = all_actual_ids + return result + def reset_prefix_cache(self): """ Resets the prefix cache for the model. diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index 2134937c25..faf9bf865b 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -15,6 +15,7 @@ import argparse import base64 import logging +import math import os from collections.abc import Sequence from contextlib import asynccontextmanager @@ -209,6 +210,10 @@ class ScriptArguments: log_level (`str`, *optional*, defaults to `"info"`): Log level for uvicorn. Possible choices: `"critical"`, `"error"`, `"warning"`, `"info"`, `"debug"`, `"trace"`. + distributed_executor_backend (`str` or `None`, *optional*): + Distributed executor backend for vLLM. Set to `"ray"` to distribute tensor parallel workers across multiple + nodes via a Ray cluster. Required when `tensor_parallel_size` exceeds the number of local GPUs. 
If not set, + vLLM defaults to the multiproc backend (single-node only). """ model: str = field( @@ -305,6 +310,14 @@ class ScriptArguments: "model implementation." }, ) + distributed_executor_backend: str | None = field( + default=None, + metadata={ + "help": "Distributed executor backend for vLLM. When set to 'ray', vLLM uses Ray to distribute tensor " + "parallel workers across multiple nodes. Required when tensor_parallel_size exceeds the number of local " + "GPUs. If not set, vLLM defaults to the multiproc backend (single-node only)." + }, + ) def llm_worker( @@ -334,6 +347,7 @@ def llm_worker( worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension", trust_remote_code=script_args.trust_remote_code, model_impl=script_args.vllm_model_impl, + distributed_executor_backend=script_args.distributed_executor_backend, # Important so temperature scaling/logit tweaking affects the TIS log probs logprobs_mode="processed_logprobs", ) @@ -382,6 +396,8 @@ def chunk_list(lst: list, n: int) -> list[list]: def main(script_args: ScriptArguments): + import asyncio + from packaging.version import Version from transformers import is_vision_available @@ -449,8 +465,13 @@ async def lifespan(app: FastAPI): if isinstance(msg, dict) and msg.get("status") == "ready": ready_connections.add(connection) + # Start the logprob request batcher background task + batcher_task = asyncio.create_task(_logprob_batcher()) + yield + batcher_task.cancel() + # Wait for processes to terminate for process in processes: process.join(timeout=10) # Wait for 10 seconds for the process to terminate @@ -634,6 +655,313 @@ async def generate(request: GenerateRequest): "logprob_token_ids": logprob_token_ids, } + class SequenceLogprobsRequest(BaseModel): + sequences: list[list[int]] + prompt_lengths: list[int] + top_logprobs: int = 100 + temperature: float = 1.0 + response_format: str = "json" # "json" (legacy) or "binary" (base64 numpy arrays) + + class SequenceLogprobsResponse(BaseModel): + 
logprobs: list[list[list[float | None]]] | None = None + logprob_token_ids: list[list[list[int]]] | None = None + # Binary format fields (base64-encoded numpy arrays) + logprobs_b64: str | None = None + token_ids_b64: str | None = None + actual_logprobs_b64: str | None = None + actual_token_ids_b64: str | None = None + shape: list[int] | None = None # [batch_size, max_completion_len, top_logprobs] + completion_lengths: list[int] | None = None # actual completion length per sample + + def _run_prompt_logprobs(prompts, sampling_params): + """Send prompts to DP workers and collect outputs.""" + chunked_prompts = chunk_list(prompts, script_args.data_parallel_size) + for connection, chunk in zip(connections, chunked_prompts, strict=True): + if not chunk: + chunk = [{"prompt_token_ids": [0]}] + kwargs = {"prompts": chunk, "sampling_params": sampling_params} + connection.send({"type": "call", "method": "generate", "kwargs": kwargs}) + all_outputs = [connection.recv() for connection in connections] + all_outputs = [output for output, chunk in zip(all_outputs, chunked_prompts, strict=True) if chunk] + return list(chain.from_iterable(all_outputs)) + + # ── Request batching for get_sequence_logprobs ── + # Collects concurrent requests into batches and dispatches them together so that + # all DP workers stay busy. Without this, async endpoint handlers block the event + # loop during pipe I/O, serializing requests and leaving DP workers idle. + _logprob_queue: asyncio.Queue = asyncio.Queue() + + # Maximum time (seconds) to wait for more requests before dispatching a batch. + _BATCH_WAIT_S = 0.005 # 5ms - short enough to not add much latency when lightly loaded + # Maximum number of HTTP requests to collect per batcher cycle + _MAX_BATCH_REQUESTS = max(script_args.data_parallel_size * 4, 16) + # Maximum total tokens per batch. prompt_logprobs materializes full-vocab logits + # during the forward pass, so each worker can safely handle ~1 max-length sequence. 
+ # Budget = max_model_len * dp_size gives ~1 sequence per worker at max length. + _max_model_len = script_args.max_model_len or 8192 + _MAX_BATCH_TOKENS = _max_model_len * script_args.data_parallel_size + + async def _logprob_batcher(): + """Background task that continuously drains the queue, batches requests, and dispatches.""" + loop = asyncio.get_running_loop() + + while True: + batch = [] + try: + # Wait for the first request + batch_tokens = 0 + item = await _logprob_queue.get() + batch.append(item) + # Count tokens in this item's sequences + for prompt in item[0]: + batch_tokens += len(prompt.get("prompt_token_ids", [])) + + # Collect more requests up to batch limit, timeout, or token budget + deadline = loop.time() + _BATCH_WAIT_S + while len(batch) < _MAX_BATCH_REQUESTS and batch_tokens < _MAX_BATCH_TOKENS: + remaining = deadline - loop.time() + if remaining <= 0: + break + try: + item = await asyncio.wait_for(_logprob_queue.get(), timeout=remaining) + # Check if adding this item would exceed the token budget + item_tokens = sum(len(p.get("prompt_token_ids", [])) for p in item[0]) + if batch_tokens + item_tokens > _MAX_BATCH_TOKENS and len(batch) > 0: + # Put it back and dispatch current batch + await _logprob_queue.put(item) + break + batch.append(item) + batch_tokens += item_tokens + except asyncio.TimeoutError: + break + + # batch is a list of (prompts, prompt_lengths, top_logprobs, temperature, response_format, future) + # All items in a batch must share the same (top_logprobs, temperature) pair. + # Group by those execution parameters to handle mixed requests. 
+ groups = {} + for prompts, prompt_lengths, top_logprobs, temperature, response_format, future in batch: + key = (top_logprobs, temperature) + if key not in groups: + groups[key] = [] + groups[key].append((prompts, prompt_lengths, response_format, future)) + + for (top_logprobs, temperature), items in groups.items(): + # Merge all sequences into a single batch + all_prompts = [] + all_prompt_lengths = [] + offsets = [] # (start_idx, count) per original request + for prompts, prompt_lengths, _response_format, _future in items: + start = len(all_prompts) + all_prompts.extend(prompts) + all_prompt_lengths.extend(prompt_lengths) + offsets.append((start, len(prompts))) + + sampling_params = SamplingParams( + max_tokens=1, + temperature=temperature, + prompt_logprobs=top_logprobs, + ) + + # Dispatch to workers in a thread to avoid blocking the event loop + try: + all_outputs = await loop.run_in_executor( + None, _run_prompt_logprobs, all_prompts, sampling_params + ) + + # Split results back to individual requests + for (start, count), (_, prompt_lengths, response_format, future) in zip( + offsets, items, strict=True + ): + outputs_slice = all_outputs[start : start + count] + if not future.done(): + future.set_result((outputs_slice, prompt_lengths, top_logprobs, response_format)) + except Exception as e: + # Signal error to all waiting requests in this execution-parameter group + for *_, future in items: + if not future.done(): + future.set_exception(e) + except Exception as e: + # Prevent killing the batcher task — signal error to all unfulfilled futures + for *_, future in batch: + if not future.done(): + future.set_exception(e) + + def _format_logprob_response(all_outputs, prompt_lengths, top_k, response_format): + """Format vLLM outputs into the response dict (runs in any thread).""" + import numpy as np + + batch_size = len(all_outputs) + use_binary = response_format == "binary" + + if use_binary: + from starlette.responses import Response + + comp_lengths = [] + 
for output, prompt_length in zip(all_outputs, prompt_lengths, strict=True): + prompt_lps = output.prompt_logprobs + if prompt_lps is None: + raise ValueError("prompt_logprobs is None.") + comp_lengths.append(len(prompt_lps) - prompt_length) + + max_comp_len = max(comp_lengths) if comp_lengths else 0 + + # logprobs_arr / token_ids_arr: teacher's sorted top-k logprobs + token ids (for forward KL). + # actual_logprobs_arr / actual_token_ids_arr: actual token's teacher logprob (for reverse KL). + logprobs_arr = np.full((batch_size, max_comp_len, top_k), float("-inf"), dtype=np.float32) + token_ids_arr = np.zeros((batch_size, max_comp_len, top_k), dtype=np.int32) + actual_logprobs_arr = np.full((batch_size, max_comp_len, 1), float("-inf"), dtype=np.float32) + actual_token_ids_arr = np.zeros((batch_size, max_comp_len, 1), dtype=np.int32) + + for i, (output, prompt_length) in enumerate(zip(all_outputs, prompt_lengths, strict=True)): + prompt_lps = output.prompt_logprobs + seq_tokens = output.prompt_token_ids + if comp_lengths[i] == 0: + continue + + for pos in range(prompt_length, len(prompt_lps)): + lp = prompt_lps[pos] + if lp is None: + continue + t = pos - prompt_length + actual_token = seq_tokens[pos] + + # Actual token's logprob (for reverse KL) + if actual_token in lp: + val = lp[actual_token].logprob + if not math.isnan(val): + actual_logprobs_arr[i, t, 0] = val + actual_token_ids_arr[i, t, 0] = actual_token + + # Teacher's top-k logprobs (for forward KL) + if top_k == 1: + # Fast path: find rank-1 directly instead of sorting + for token_id, logprob_obj in lp.items(): + if logprob_obj.rank == 1: + val = logprob_obj.logprob + if not math.isnan(val): + logprobs_arr[i, t, 0] = val + token_ids_arr[i, t, 0] = token_id + break + else: + sorted_items = sorted(lp.items(), key=lambda x: x[1].rank) + for k_idx, (token_id, logprob_obj) in enumerate(sorted_items[:top_k]): + val = logprob_obj.logprob + if not math.isnan(val): + logprobs_arr[i, t, k_idx] = val + 
token_ids_arr[i, t, k_idx] = token_id + + payload = { + "logprobs_b64": base64.b64encode(logprobs_arr.tobytes()).decode("ascii"), + "token_ids_b64": base64.b64encode(token_ids_arr.tobytes()).decode("ascii"), + "actual_logprobs_b64": base64.b64encode(actual_logprobs_arr.tobytes()).decode("ascii"), + "actual_token_ids_b64": base64.b64encode(actual_token_ids_arr.tobytes()).decode("ascii"), + "shape": [batch_size, max_comp_len, top_k], + "completion_lengths": comp_lengths, + } + + try: + import orjson + + return Response(content=orjson.dumps(payload), media_type="application/json") + except ImportError: + return payload + else: + all_logprobs = [] + all_token_ids = [] + for output, prompt_length in zip(all_outputs, prompt_lengths, strict=True): + prompt_lps = output.prompt_logprobs + if prompt_lps is None: + raise ValueError("prompt_logprobs is None.") + seq_logprobs = [] + seq_token_ids = [] + for pos in range(prompt_length, len(prompt_lps)): + lp = prompt_lps[pos] + if lp is None: + seq_logprobs.append([]) + seq_token_ids.append([]) + continue + sorted_items = sorted(lp.items(), key=lambda x: x[1].rank) + seq_token_ids.append([token_id for token_id, _ in sorted_items]) + seq_logprobs.append( + [None if math.isnan(item.logprob) else item.logprob for _, item in sorted_items] + ) + all_logprobs.append(seq_logprobs) + all_token_ids.append(seq_token_ids) + return {"logprobs": all_logprobs, "logprob_token_ids": all_token_ids} + + @app.post("/get_sequence_logprobs/", response_model=SequenceLogprobsResponse) + async def get_sequence_logprobs(request: SequenceLogprobsRequest): + """ + Computes teacher logprobs for existing token sequences without generating new tokens. + + Concurrent requests are automatically batched and dispatched together to maximize GPU utilization across DP + workers. This avoids the event-loop-blocking problem where synchronous pipe I/O serializes requests despite + having multiple DP workers. 
+ + Args: + request (`SequenceLogprobsRequest`): + - `sequences` (list of list of `int`): Full token sequences (prompt + completion) per sample. + - `prompt_lengths` (list of `int`): Number of prompt tokens per sequence; completion logprobs start + after each prompt. + - `top_logprobs` (`int`, *optional*, defaults to `100`): Number of top teacher logprobs to return per + completion position (sorted by vLLM rank). + - `temperature` (`float`, *optional*, defaults to `1.0`): Sampling temperature passed to vLLM for + logprob computation. + - `response_format` (`str`, *optional*, defaults to `"json"`): Either `"json"` (nested lists, + backward-compatible) or `"binary"` (base64-encoded numpy arrays for fast serialization). + + Returns: + `SequenceLogprobsResponse` or Starlette `Response`: + When `response_format` is `"json"`, a JSON object with: + - `logprobs` (list of list of list of `float` or `None`): Top-k teacher logprobs per completion token. + - `logprob_token_ids` (list of list of list of `int`): Token IDs aligned with `logprobs`. + When `response_format` is `"binary"`, a JSON response (Starlette `Response` if `orjson` is installed) + whose body is a JSON object with base64-encoded float32/int32 arrays: `logprobs_b64`, `token_ids_b64`, + `actual_logprobs_b64`, `actual_token_ids_b64`, plus `shape` (`list[int]`, `[batch_size, + max_completion_len, top_k]`) and `completion_lengths` (`list[int]`). + """ + if len(request.sequences) != len(request.prompt_lengths): + raise ValueError("sequences and prompt_lengths must have the same length.") + + for i, (seq, pl) in enumerate(zip(request.sequences, request.prompt_lengths, strict=True)): + if pl < 0 or pl > len(seq): + raise ValueError( + f"Sequence {i} has prompt_length={pl} which is out of range [0, {len(seq)}]. " + f"prompt_length must be between 0 and the sequence length inclusive." 
+ ) + + # Validate sequence lengths against max_model_len to prevent worker OOM crashes + if _max_model_len: + for i, seq in enumerate(request.sequences): + if len(seq) > _max_model_len: + raise ValueError( + f"Sequence {i} has length {len(seq)} which exceeds max_model_len={_max_model_len}. " + f"Truncate sequences or increase --max-model-len." + ) + + prompts = [{"prompt_token_ids": seq} for seq in request.sequences] + + # Submit to the batching queue and await result + loop = asyncio.get_running_loop() + future = loop.create_future() + await _logprob_queue.put( + ( + prompts, + list(request.prompt_lengths), + request.top_logprobs, + request.temperature, + request.response_format, + future, + ) + ) + + # Wait for the batcher to process our request + all_outputs, prompt_lengths, top_k, response_format = await future + + return await loop.run_in_executor( + None, _format_logprob_response, all_outputs, prompt_lengths, top_k, response_format + ) + class ChatRequest(BaseModel): messages: list[list[dict]] n: int = 1