From ffde2d002ab617a870a983fde4bf8ead6f5d0f31 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Wed, 18 Mar 2026 15:32:53 +0100 Subject: [PATCH 01/37] Use `VLLMGeneration` in `GOLDTrainer` --- trl/experimental/gold/gold_config.py | 31 ++ trl/experimental/gold/gold_trainer.py | 441 +++----------------------- 2 files changed, 76 insertions(+), 396 deletions(-) diff --git a/trl/experimental/gold/gold_config.py b/trl/experimental/gold/gold_config.py index d2c5fe72fb5..c32ef45a3e8 100644 --- a/trl/experimental/gold/gold_config.py +++ b/trl/experimental/gold/gold_config.py @@ -93,6 +93,17 @@ class GOLDConfig(SFTConfig): Tensor parallel size for the colocated student vLLM engine (if `vllm_mode="colocate"`). vllm_structured_outputs_regex (`str`, *optional*): Regex for vLLM structured outputs for the student model. + vllm_server_base_url (`str`, *optional*): + Base URL for the vLLM server (e.g., `"http://localhost:8001"`). If provided, `vllm_server_host` and + `vllm_server_port` are ignored. + vllm_group_port (`int`, *optional*, defaults to `51216`): + Port for the vLLM weight-update group (NCCL communicator). Unless the port is occupied, there is no need + to change it. + vllm_max_model_length (`int`, *optional*): + Maximum model sequence length for the colocated vLLM engine when `vllm_mode="colocate"`. Defaults to the + model's maximum context length. + vllm_model_impl (`str`, *optional*, defaults to `"vllm"`): + Model implementation backend to use in vLLM. Use `"vllm"` (default) or `"transformers"`. vllm_sync_frequency (`int`, *optional*, defaults to `1`): Frequency (in training steps) to synchronize student model weights to vLLM engine. Set to 1 to sync after every step. @@ -296,6 +307,12 @@ class GOLDConfig(SFTConfig): "help": 'Mode for vLLM integration. Either "server" (connect to a running TRL vLLM server) or "colocate" (run vLLM in the same process).' }, ) + vllm_server_base_url: str | None = field( + default=None, + metadata={ + "help": 'Base URL for the vLLM server (e.g., "http://localhost:8001"). If provided, vllm_server_host and vllm_server_port are ignored.' + }, + ) vllm_server_host: str = field( default="0.0.0.0", metadata={"help": 'Host of the vLLM server when `vllm_mode="server"`.'}, @@ -308,6 +325,10 @@ class GOLDConfig(SFTConfig): default=240.0, metadata={"help": 'Timeout (in seconds) for connecting to the vLLM server when `vllm_mode="server"`.'}, ) + vllm_group_port: int = field( + default=51216, + metadata={"help": "Port for the vLLM weight-update group (NCCL communicator)."}, + ) vllm_gpu_memory_utilization: float = field( default=0.9, metadata={ @@ -318,6 +339,16 @@ class GOLDConfig(SFTConfig): default=1, metadata={"help": 'Tensor parallel size for the colocated vLLM engine when `vllm_mode="colocate"`.'}, ) + vllm_max_model_length: int | None = field( + default=None, + metadata={ + "help": 'Maximum model sequence length for the colocated vLLM engine when `vllm_mode="colocate"`. Defaults to the model\'s maximum context length.' + }, + ) + vllm_model_impl: str = field( + default="vllm", + metadata={"help": 'Model implementation backend to use in vLLM. 
Use "vllm" (default) or "transformers".'}, + ) vllm_structured_outputs_regex: str | None = field( default=None, metadata={"help": "Regex pattern used for vLLM structured outputs (optional)."}, diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index 26c0f72dddd..352a6bf5cd0 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import random import textwrap import warnings @@ -29,9 +28,8 @@ from accelerate import PartialState from accelerate.utils import DistributedType, broadcast_object_list, gather_object, is_peft_model from datasets import Dataset, IterableDataset -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.utils.data import DataLoader -from transformers import AutoTokenizer, TrainerCallback, TrainerControl, TrainerState, is_bitsandbytes_available +from transformers import AutoTokenizer, TrainerCallback, TrainerControl, TrainerState from transformers.data.data_collator import DataCollator from transformers.feature_extraction_utils import FeatureExtractionMixin from transformers.generation.configuration_utils import GenerationConfig @@ -51,7 +49,7 @@ from ...data_utils import is_conversational, maybe_convert_to_chatml, pack_dataset, truncate_dataset from ...extras.profiling import profiling_decorator -from ...generation.vllm_client import VLLMClient +from ...generation.vllm_generation import VLLMGeneration from ...import_utils import is_vllm_available from ...models import prepare_deepspeed from ...models.utils import unwrap_model_for_generation @@ -60,7 +58,6 @@ RepeatSampler, create_model_from_path, disable_dropout_in_model, - ensure_master_addr_port, pad, split_tensor_dict, ) @@ -74,9 +71,6 @@ if is_wandb_available(): import wandb -if is_vllm_available(): - from vllm import LLM, SamplingParams - from vllm.sampling_params import StructuredOutputsParams if is_liger_kernel_available(): from liger_kernel.chunked_loss import LigerFusedLinearJSDLoss @@ -87,9 +81,6 @@ from rich.table import Table from rich.text import Text -if is_bitsandbytes_available(): - import bitsandbytes as bnb - def print_prompt_completions_sample_uld( prompts: list[str], @@ -751,25 +742,6 @@ def _get_start_and_size_answers(self, answer_tensors): return answers_index, answers_size -class GOLDVLLMSyncCallback(TrainerCallback): - """Sync the model weights to vLLM after training steps when it's safe to do so.""" - - def __init__(self, trainer): - self.trainer = trainer - - def on_step_end(self, args, state: TrainerState, control: TrainerControl, **kwargs): - """Sync weights after training step when DeepSpeed is stable.""" - if ( - self.trainer.use_vllm - and state.global_step != self.trainer._last_vllm_sync_step - and state.global_step % self.trainer.vllm_sync_frequency == 0 - ): - # Check if this is a step where gradients are synchronized - # This happens at the end of gradient accumulation cycles - if hasattr(self.trainer.accelerator, "sync_gradients") and self.trainer.accelerator.sync_gradients: - self.trainer._move_model_to_vllm() - self.trainer._last_vllm_sync_step = state.global_step - class GOLDTrainer(SFTTrainer): _tag_names = ["trl", "gold"] @@ -964,86 +936,35 @@ def __init__( "vLLM is not available and use_vllm is set to True. Please install vLLM with " "`pip install vllm` to use it." 
) - self.vllm_mode = args.vllm_mode - self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size - self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization - self.vllm_enable_sleep_mode = args.vllm_enable_sleep_mode - if self.vllm_mode == "server": - if self.accelerator.is_main_process: - self.vllm_client = VLLMClient( - host=args.vllm_server_host, - server_port=args.vllm_server_port, - connection_timeout=args.vllm_server_timeout, - ) - self.vllm_client.init_communicator() - elif self.vllm_mode == "colocate": - student_model_name_or_path = self.model_name_or_path - - # Make sure tensor_parallel_size divides world size evenly - if not self.accelerator.num_processes % self.vllm_tensor_parallel_size == 0: - raise ValueError( - f"vllm_tensor_parallel_size ({self.vllm_tensor_parallel_size}) must divide world size " - f"({self.accelerator.num_processes}) evenly." - ) - - if self.vllm_tensor_parallel_size > 1: - # Create subgroups of ranks for TP - self.vllm_tp_group, _ = torch.distributed.new_subgroups_by_enumeration( - [ - list( - range( - i * self.vllm_tensor_parallel_size, - (i + 1) * self.vllm_tensor_parallel_size, - ) - ) - for i in range(self.accelerator.num_processes // self.vllm_tensor_parallel_size) - ] - ) - - # vLLM requires the environment variables to be set for distributed training. - os.environ["RANK"] = str(self.accelerator.process_index) - os.environ["LOCAL_RANK"] = str(self.accelerator.local_process_index) - os.environ["WORLD_SIZE"] = str(self.accelerator.num_processes) - ensure_master_addr_port() - - vllm_quantization = None - if is_bitsandbytes_available(): - for _, module in model.named_modules(): - if isinstance(module, bnb.nn.Linear4bit): - vllm_quantization = "bitsandbytes" - break - elif isinstance(module, bnb.nn.Linear8bitLt): - raise ValueError("vLLM does not support in-flight 8-bit quantization.") - - self.vllm_engine = LLM( - model=student_model_name_or_path, - revision=self.model_revision, - tensor_parallel_size=self.vllm_tensor_parallel_size, - gpu_memory_utilization=self.vllm_gpu_memory_utilization, - max_num_seqs=self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps, - max_model_len=args.max_length, - distributed_executor_backend="external_launcher", - # Feed identical seed for tp groups to ensure sampling results are the same across workers - seed=self.accelerator.process_index // self.vllm_tensor_parallel_size, - enable_sleep_mode=self.vllm_enable_sleep_mode, - quantization=vllm_quantization, - ) - - if self.vllm_enable_sleep_mode: - self.vllm_engine.sleep(level=2) - - # When using vLLM, the main process is responsible for loading the model weights. This can cause process - # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we - # synchronize all processes after vLLM has been fully initialized. 
- self.accelerator.wait_for_everyone() - else: - raise ValueError(f"Unknown vllm_mode: {self.vllm_mode}") - self.vllm_structured_outputs_regex = args.vllm_structured_outputs_regex + self.vllm_generation = VLLMGeneration( + model=self.model, + accelerator=self.accelerator, + is_fsdp_enabled=self.is_fsdp_enabled, + processing_class=self.processing_class, + mode=args.vllm_mode, + structured_outputs_regex=args.vllm_structured_outputs_regex, + server_base_url=args.vllm_server_base_url, + server_host=args.vllm_server_host, + server_port=args.vllm_server_port, + group_port=args.vllm_group_port, + server_timeout=args.vllm_server_timeout, + tensor_parallel_size=args.vllm_tensor_parallel_size, + gpu_memory_utilization=args.vllm_gpu_memory_utilization, + max_model_length=args.vllm_max_model_length, + max_num_seqs=args.per_device_train_batch_size * args.gradient_accumulation_steps, + enable_sleep_mode=args.vllm_enable_sleep_mode, + model_impl=args.vllm_model_impl, + repetition_penalty=getattr(args, "repetition_penalty", 1.0), + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + min_p=getattr(args, "min_p", 0.0), + max_completion_length=args.max_completion_length, + logprobs=None, + ) self.vllm_sync_frequency = args.vllm_sync_frequency self._last_vllm_sync_step = -1 - self.add_callback(GOLDVLLMSyncCallback(self)) - def _set_signature_columns_if_needed(self): super()._set_signature_columns_if_needed() required_columns = [ @@ -1239,209 +1160,39 @@ def _generate_on_policy_for_slices( local_prompts.append(prompt) local_slice_indices.append(slice_idx) - prompts_text_for_vllm = self.processing_class.batch_decode( - torch.stack(local_prompts) if local_prompts else torch.empty(0, dtype=torch.long), - skip_special_tokens=True, - ) - if self.processing_class.pad_token: - prompts_text_for_vllm = [p.replace(self.processing_class.pad_token, "") for p in prompts_text_for_vllm] - prompts_text_with_special = self.processing_class.batch_decode( torch.stack(local_prompts) if local_prompts else torch.empty(0, dtype=torch.long), skip_special_tokens=False, ) - if self.use_vllm: - self._wake_vllm_if_needed() - - max_completion_length = self.generation_config.max_new_tokens - temperature = self.generation_config.temperature - top_k = ( - self.generation_config.top_k if self.generation_config.top_k and self.generation_config.top_k > 0 else -1 - ) - top_p = self.args.top_p if hasattr(self.args, "top_p") else 1.0 - repetition_penalty = self.args.repetition_penalty if hasattr(self.args, "repetition_penalty") else 1.0 - min_p = self.args.min_p if hasattr(self.args, "min_p") else 0.0 - - if self.use_vllm and self.vllm_mode == "server": - completion_ids = self._generate_vllm_server_global( - prompts_text_for_vllm, - max_completion_length, - temperature, - top_k, - top_p, - repetition_penalty, - min_p, - n=self.num_generations, - ) - elif self.use_vllm and self.vllm_mode == "colocate": - completion_ids = self._generate_vllm_colocate( - prompts_text_for_vllm, - max_completion_length, - temperature, - top_k, - top_p, - repetition_penalty, - min_p, - n=self.num_generations, - ) - else: + if not self.use_vllm: self._generate_non_vllm_for_slices(slices, on_policy_indices) return + if ( + self.state.global_step != self._last_vllm_sync_step + and self.state.global_step % self.vllm_sync_frequency == 0 + ): + self.vllm_generation.sync_weights() + self._last_vllm_sync_step = self.state.global_step + + prompt_ids_list = [p.tolist() for p in local_prompts] + _, completion_ids, _, _ = self.vllm_generation.generate( + 
prompts=prompt_ids_list, + images=None, + num_generations=self.num_generations, + ) + self._process_completions_to_buffer( slices, on_policy_indices, local_slice_indices, completion_ids, - prompts_text_for_vllm, prompts_text_with_special, - max_completion_length, - ) - - @staticmethod - def _deduplicate_prompts( - prompts: list[str], num_generations: int - ) -> tuple[list[str], list[tuple[int, int]]] | None: - """Deduplicate prompts and build a completion remapping.""" - seen: dict[str, list[int]] = {} - unique_prompts: list[str] = [] - dedup_mapping: list[tuple[int, int]] = [] - - for prompt in prompts: - if prompt not in seen: - seen[prompt] = [len(unique_prompts), 0] - unique_prompts.append(prompt) - entry = seen[prompt] - if entry[1] >= num_generations: - return None - dedup_mapping.append((entry[0], entry[1])) - entry[1] += 1 - - return unique_prompts, dedup_mapping - - def _generate_vllm_server_global( - self, - prompts_text: list[str], - max_tokens: int, - temperature: float, - top_k: int, - top_p: float, - repetition_penalty: float, - min_p: float, - n: int = 1, - ) -> list: - all_prompts_text = gather_object(prompts_text) - local_count = len(prompts_text) - - if self.accelerator.is_main_process: - if all_prompts_text: - dedup_mapping = None - if n > 1: - dedup_result = self._deduplicate_prompts(all_prompts_text, n) - if dedup_result is not None: - gen_prompts, dedup_mapping = dedup_result - gen_n = n - else: - gen_prompts = all_prompts_text - gen_n = 1 - else: - gen_prompts = all_prompts_text - gen_n = 1 - - completion_ids = self.vllm_client.generate( - prompts=gen_prompts, - n=gen_n, - repetition_penalty=repetition_penalty, - temperature=temperature, - top_p=top_p, - top_k=top_k, - min_p=min_p, - max_tokens=max_tokens, - structured_outputs_regex=self.vllm_structured_outputs_regex, - )["completion_ids"] - - if dedup_mapping is not None: - completion_ids = [completion_ids[uid * gen_n + gid] for uid, gid in dedup_mapping] - else: - completion_ids = [] - else: - completion_ids = [None] * len(all_prompts_text) if all_prompts_text else [] - - completion_ids = broadcast_object_list(completion_ids, from_process=0) - process_slice = slice( - self.accelerator.process_index * local_count, - (self.accelerator.process_index + 1) * local_count, - ) - return completion_ids[process_slice] - - def _generate_vllm_colocate( - self, - prompts_text: list[str], - max_tokens: int, - temperature: float, - top_k: int, - top_p: float, - repetition_penalty: float, - min_p: float, - n: int = 1, - ) -> list: - if self.vllm_structured_outputs_regex: - structured_outputs = StructuredOutputsParams(backend="outlines", regex=self.vllm_structured_outputs_regex) - else: - structured_outputs = None - - if hasattr(self, "vllm_tp_group") and self.vllm_tensor_parallel_size > 1: - orig_size = len(prompts_text) - gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.vllm_tp_group) - all_prompts_text = [p for sublist in gathered_prompts for p in sublist] - else: - all_prompts_text = prompts_text - - dedup_mapping = None - if n > 1 and all_prompts_text: - dedup_result = self._deduplicate_prompts(all_prompts_text, n) - if dedup_result is not None: - gen_prompts, dedup_mapping = dedup_result - gen_n = n - else: - gen_prompts = all_prompts_text - gen_n = 1 - else: - gen_prompts = all_prompts_text - gen_n = 1 - - sampling_params = SamplingParams( - n=gen_n, - repetition_penalty=repetition_penalty, - temperature=temperature, - 
top_p=top_p, - top_k=top_k, - min_p=min_p, - max_tokens=max_tokens, - structured_outputs=structured_outputs, + prompts_text_with_special, + self.generation_config.max_new_tokens, ) - if gen_prompts: - all_outputs = self.vllm_engine.generate(gen_prompts, sampling_params=sampling_params, use_tqdm=False) - completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] - else: - completion_ids = [] - - if dedup_mapping is not None: - completion_ids = [completion_ids[uid * gen_n + gid] for uid, gid in dedup_mapping] - - if hasattr(self, "vllm_tp_group") and self.vllm_tensor_parallel_size > 1: - local_rank_in_group = torch.distributed.get_rank(group=self.vllm_tp_group) - tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) - completion_ids = completion_ids[tp_slice] - - if self.vllm_enable_sleep_mode: - self.vllm_engine.sleep(level=2) - - return completion_ids - def _generate_non_vllm_for_slices(self, slices: list[dict[str, torch.Tensor | Any]], on_policy_indices: list[int]): """Fallback generation without vLLM (uses model.generate per slice).""" with unwrap_model_for_generation( @@ -2160,108 +1911,6 @@ def generate_on_policy_outputs(self, model, inputs, generation_config, pad_token return new_input_ids, new_attention_mask, new_labels, prompt_texts, completion_texts - def _sync_fsdp_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None): - """Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with student vLLM.""" - if visited is None: - visited = set() - - for child_name, child_module in module.named_children(): - child_prefix = f"{prefix}.{child_name}" if prefix else child_name - # recurse into the child - self._sync_fsdp_params_to_vllm(child_module, prefix=child_prefix, visited=visited) - - if isinstance(module, FSDP): - with FSDP.summon_full_params(module, recurse=False, writeback=False): - for param_name, param in module.named_parameters(): - full_name = f"{prefix}.{param_name}" if prefix else param_name - for extra in ("_fsdp_wrapped_module.", "_checkpoint_wrapped_module."): - full_name = full_name.replace(extra, "") - - if full_name in visited: - continue # skip FSDP subtrees already traversed - visited.add(full_name) - - if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.update_named_param(full_name, param.data) - elif self.vllm_mode == "colocate": - llm_model = self.vllm_engine.llm_engine.model_executor.driver_worker.model_runner.model - llm_model.load_weights([(full_name, param.data)]) - - def _move_model_to_vllm(self): - """Synchronize student model weights to vLLM engine.""" - # For DeepSpeed ZeRO-3 and FSDP, we need to gather all parameters before operations - deepspeed_plugin = self.accelerator.state.deepspeed_plugin - zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3 - if zero_stage_3: - import deepspeed - - gather_if_zero3 = deepspeed.zero.GatheredParameters - else: - gather_if_zero3 = nullcontext - - if self.vllm_mode == "colocate" and self.vllm_enable_sleep_mode: - empty_cache() - self.vllm_engine.wake_up(tags=["weights"]) - # Work around for https://github.com/vllm-project/vllm/issues/29341 - self.vllm_engine.collective_rpc("reload_weights") - - if is_peft_model(self.model): - # With PEFT and FSDP/DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as - # merging adapters in a sharded manner is not supported. 
- with gather_if_zero3(list(self.model.parameters())): - self.model.merge_adapter() - - # Update vLLM weights while parameters are gathered - if self.is_fsdp_enabled: # note if using FSDP, gather_if_zero3 is nullcontext - # Update vLLM weights while parameters are gathered - # For PEFT with FSDP we need to use the memory efficient post-order traversal - self._sync_fsdp_params_to_vllm(self.model) - else: - # DeepSpeed ZeRO-3 with PEFT - for name, param in self.model.named_parameters(): - # When using PEFT, we need to recover the original parameter name and discard some parameters - name = name.removeprefix("base_model.model.").replace(".base_layer", "") - if self.model.prefix in name: - continue - # When module to save, remove its prefix and discard the original module - if "original_module" in name: - continue - name = name.replace("modules_to_save.default.", "") - - if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.update_named_param(name, param.data) - elif self.vllm_mode == "colocate": - llm_model = self.vllm_engine.llm_engine.model_executor.driver_worker.model_runner.model - llm_model.load_weights([(name, param.data)]) - # Unmerge adapters while parameters are still gathered - self.model.unmerge_adapter() - # Parameters will automatically be repartitioned when exiting the context - else: - # For non-PEFT models, simply gather (if needed) and update each parameter individually. - if self.is_fsdp_enabled: - # use memory-efficient post-order traversal for FSDP - self._sync_fsdp_params_to_vllm(self.model) - else: - # For DeepSpeed ZeRO-3, gather each parameter individually like GRPO trainer - for name, param in self.model.named_parameters(): - with gather_if_zero3([param]): - if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.update_named_param(name, param.data) - elif self.vllm_mode == "colocate": - llm_model = self.vllm_engine.llm_engine.model_executor.driver_worker.model_runner.model - llm_model.load_weights([(name, param.data)]) - - # Reset cache on vLLM - if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.reset_prefix_cache() - elif self.vllm_mode == "colocate": - self.vllm_engine.reset_prefix_cache() - - def _wake_vllm_if_needed(self): - if self.vllm_mode == "colocate" and self.vllm_enable_sleep_mode: - empty_cache() - self.vllm_engine.wake_up(tags=["kv_cache"]) - def _get_liger_zero3_lm_head_gather_ctx(self, model: nn.Module): if not self.use_liger_gkd_loss: return nullcontext() From 1797fc14e364479f4ecb5c258002acf3c6e70139 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Thu, 19 Mar 2026 11:22:37 +0100 Subject: [PATCH 02/37] Update with precommit --- trl/experimental/gold/gold_trainer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index 352a6bf5cd0..0c7d3ae8bbe 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -26,10 +26,10 @@ import torch.nn as nn import torch.nn.functional as F from accelerate import PartialState -from accelerate.utils import DistributedType, broadcast_object_list, gather_object, is_peft_model +from accelerate.utils import DistributedType, broadcast_object_list, gather_object from datasets import Dataset, IterableDataset from torch.utils.data import DataLoader -from transformers import AutoTokenizer, TrainerCallback, TrainerControl, TrainerState +from transformers import AutoTokenizer, TrainerCallback from 
transformers.data.data_collator import DataCollator from transformers.feature_extraction_utils import FeatureExtractionMixin from transformers.generation.configuration_utils import GenerationConfig @@ -742,7 +742,6 @@ def _get_start_and_size_answers(self, answer_tensors): return answers_index, answers_size - class GOLDTrainer(SFTTrainer): _tag_names = ["trl", "gold"] _name = "GOLD" From f723677162c9afb25b7ccd0119137fc3a7c56fca Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 23 Mar 2026 14:48:16 +0100 Subject: [PATCH 03/37] Initial `DistillationTrainer` implementation --- trl/__init__.py | 4 + trl/generation/vllm_client.py | 47 ++++ trl/scripts/distillation.py | 182 +++++++++++++++ trl/scripts/vllm_serve.py | 82 +++++++ trl/trainer/__init__.py | 4 + trl/trainer/distillation_config.py | 352 +++++++++++++++++++++++++++++ 6 files changed, 671 insertions(+) create mode 100644 trl/scripts/distillation.py create mode 100644 trl/trainer/distillation_config.py diff --git a/trl/__init__.py b/trl/__init__.py index ade232ac2da..3a5ef029f83 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -46,6 +46,8 @@ "scripts": ["DatasetMixtureConfig", "ScriptArguments", "TrlParser", "get_dataset", "init_zero_verbose"], "trainer": [ "BEMACallback", + "DistillationConfig", + "DistillationTrainer", "DPOConfig", "DPOTrainer", "GRPOConfig", @@ -90,6 +92,8 @@ from .scripts import DatasetMixtureConfig, ScriptArguments, TrlParser, get_dataset, init_zero_verbose from .trainer import ( BEMACallback, + DistillationConfig, + DistillationTrainer, DPOConfig, DPOTrainer, GRPOConfig, diff --git a/trl/generation/vllm_client.py b/trl/generation/vllm_client.py index d993b283b7f..8f7680869bb 100644 --- a/trl/generation/vllm_client.py +++ b/trl/generation/vllm_client.py @@ -523,6 +523,53 @@ def update_model_params(self, model: nn.Module): # Update each parameter individually self.update_named_param(name, param.data) + def get_sequence_logprobs( + self, + sequences: list[list[int]], + prompt_lengths: list[int], + top_logprobs: int = 100, + ) -> dict[str, list]: + """ + Computes teacher logprobs for existing token sequences without generating new tokens. + + Sends full sequences (prompt + completion) to the vLLM server and retrieves per-token top-k logprobs for the + completion region only. This is used for knowledge distillation where the teacher model evaluates existing + sequences rather than generating new ones. + + Args: + sequences (`list[list[int]]`): + List of full token ID sequences (prompt + completion). + prompt_lengths (`list[int]`): + Number of prompt tokens in each sequence. Logprobs are returned starting from this position. + top_logprobs (`int`, *optional*, defaults to `100`): + Number of top logprobs to return per token position. + + Returns: + `dict` with keys: + - `logprobs` (`list[list[list[float]]]`): + Per-token logprobs of shape (batch, completion_len, top_logprobs), sorted by descending + probability. + - `logprob_token_ids` (`list[list[list[int]]]`): + Token IDs corresponding to each logprob, same shape as `logprobs`. 
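+
+        Example (illustrative sketch; assumes a TRL vLLM server is reachable at the given URL, and the token
+        ids below are arbitrary placeholders):
+
+        ```python
+        >>> client = VLLMClient(base_url="http://localhost:8000")
+        >>> # Score a 2-token completion appended to a 3-token prompt (5 tokens total)
+        >>> result = client.get_sequence_logprobs(
+        ...     sequences=[[1, 2, 3, 4, 5]], prompt_lengths=[3], top_logprobs=5
+        ... )
+        >>> len(result["logprobs"][0])  # one entry per completion token
+        2
+        ```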
+ """ + url = f"{self.base_url}/get_sequence_logprobs/" + response = self.session.post( + url, + json={ + "sequences": sequences, + "prompt_lengths": prompt_lengths, + "top_logprobs": top_logprobs, + }, + ) + if response.status_code == 200: + json_response = response.json() + return { + "logprobs": json_response["logprobs"], + "logprob_token_ids": json_response["logprob_token_ids"], + } + else: + raise Exception(f"Request failed: {response.status_code}, {response.text}") + def reset_prefix_cache(self): """ Resets the prefix cache for the model. diff --git a/trl/scripts/distillation.py b/trl/scripts/distillation.py new file mode 100644 index 00000000000..ca0ed4dde70 --- /dev/null +++ b/trl/scripts/distillation.py @@ -0,0 +1,182 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "trl", +# "peft", +# "trackio", +# "kernels", +# ] +# /// + +# docstyle-ignore +""" +# Full training (off-policy only, lmbda=0): +``` +python trl/scripts/distillation.py \ + --model_name_or_path meta-llama/Llama-3.2-1B-Instruct \ + --teacher_model_name_or_path Qwen/Qwen2-1.5B-Instruct \ + --dataset_name trl-lib/chatbot_arena_completions \ + --learning_rate 2e-5 \ + --per_device_train_batch_size 4 \ + --gradient_accumulation_steps 8 \ + --lmbda 0.0 \ + --output_dir distilled-model \ + --num_train_epochs 1 +``` + +# Mixed on/off-policy (lmbda=0.5): +``` +python trl/scripts/distillation.py \ + --model_name_or_path meta-llama/Llama-3.2-1B-Instruct \ + --teacher_model_name_or_path Qwen/Qwen2-1.5B-Instruct \ + --dataset_name trl-lib/chatbot_arena_completions \ + --learning_rate 2e-5 \ + --per_device_train_batch_size 4 \ + --gradient_accumulation_steps 8 \ + --lmbda 0.5 \ + --beta 0.5 \ + --output_dir distilled-model \ + --num_train_epochs 1 +``` + +# LoRA: +``` +python trl/scripts/distillation.py \ + --model_name_or_path meta-llama/Llama-3.2-1B-Instruct \ + --teacher_model_name_or_path Qwen/Qwen2-1.5B-Instruct \ + --dataset_name trl-lib/chatbot_arena_completions \ + --learning_rate 2e-4 \ + --per_device_train_batch_size 4 \ + --gradient_accumulation_steps 8 \ + --lmbda 0.0 \ + --output_dir distilled-model \ + --num_train_epochs 1 \ + --use_peft \ + --lora_r 64 \ + --lora_alpha 16 +``` +""" + +import argparse +import os + + +# Enable logging in a Hugging Face Space +os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio") + + +def main(script_args, training_args, model_args): + from accelerate import logging + from datasets import load_dataset + from transformers import AutoConfig, AutoModelForCausalLM, GenerationConfig + + from trl import ( + DistillationTrainer, + LogCompletionsCallback, + ModelConfig, + get_kbit_device_map, + get_peft_config, + get_quantization_config, + ) + + logger = logging.get_logger(__name__) + + ################ + # Model init kwargs + ################ + quantization_config = get_quantization_config(model_args) + model_kwargs = dict( + revision=model_args.model_revision, + 
trust_remote_code=model_args.trust_remote_code, + attn_implementation=model_args.attn_implementation, + torch_dtype=model_args.dtype, + use_cache=False if training_args.gradient_checkpointing else True, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + training_args.model_init_kwargs = model_kwargs + + teacher_model_kwargs = dict( + revision=training_args.teacher_model_revision, + trust_remote_code=model_args.trust_remote_code, + attn_implementation=model_args.attn_implementation, + torch_dtype=model_args.dtype, + use_cache=True, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + if training_args.teacher_model_init_kwargs is not None: + teacher_model_kwargs.update(training_args.teacher_model_init_kwargs) + training_args.teacher_model_init_kwargs = teacher_model_kwargs + + ################ + # Dataset + ################ + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + + ################ + # Training + ################ + eval_dataset = None + if training_args.eval_strategy != "no": + if script_args.dataset_test_split in dataset: + eval_dataset = dataset[script_args.dataset_test_split] + elif "validation" in dataset: + eval_dataset = dataset["validation"] + elif "dev" in dataset: + eval_dataset = dataset["dev"] + + trainer = DistillationTrainer( + model=model_args.model_name_or_path, + teacher_model=training_args.teacher_model_name_or_path, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=eval_dataset, + peft_config=get_peft_config(model_args), + ) + + if training_args.eval_strategy != "no": + generation_config = GenerationConfig( + max_new_tokens=training_args.max_completion_length, do_sample=True, temperature=training_args.temperature + ) + completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8) + trainer.add_callback(completions_callback) + + trainer.train() + + # Save and push to Hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + + +def make_parser(subparsers: argparse._SubParsersAction | None = None, prog: str | None = None): + from trl import DistillationConfig, ModelConfig, ScriptArguments, TrlParser + + dataclass_types = (ScriptArguments, DistillationConfig, ModelConfig) + if subparsers is not None: + parser = subparsers.add_parser( + "distillation", help="Run the distillation training script", dataclass_types=dataclass_types + ) + else: + parser = TrlParser(dataclass_types, prog=prog) + return parser + + +if __name__ == "__main__": + parser = make_parser() + script_args, training_args, model_args = parser.parse_args_and_config(fail_with_unknown_args=False) + main(script_args, training_args, model_args) diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index 370e3a1e5d2..99013a2da49 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -15,6 +15,7 @@ import argparse import base64 import logging +import math import os from collections.abc import Sequence from contextlib import asynccontextmanager @@ -641,6 +642,87 @@ async def generate(request: GenerateRequest): "logprob_token_ids": logprob_token_ids, } + class SequenceLogprobsRequest(BaseModel): + sequences: list[list[int]] + prompt_lengths: list[int] + top_logprobs: int = 100 + + class SequenceLogprobsResponse(BaseModel): + logprobs: 
list[list[list[float | None]]] + logprob_token_ids: list[list[list[int]]] + + @app.post("/get_sequence_logprobs/", response_model=SequenceLogprobsResponse) + async def get_sequence_logprobs(request: SequenceLogprobsRequest): + """ + Computes teacher logprobs for existing token sequences without generating new tokens. + + Sends the full sequence (prompt + completion) as the vLLM prompt with `max_tokens=1` and + `prompt_logprobs=top_logprobs`. Returns logprobs only for the completion region (positions + from `prompt_length` onwards) for each sequence. + + Args: + request (`SequenceLogprobsRequest`): + - `sequences` (list of list of `int`): Full token ID sequences (prompt + completion). + - `prompt_lengths` (list of `int`): Number of prompt tokens in each sequence. Logprobs + are returned starting from this position. + - `top_logprobs` (`int`, *optional*, defaults to `100`): Number of top logprobs per position. + + Returns: + `SequenceLogprobsResponse`: + - `logprobs` (list of list of list of `float`): Per-token logprobs of shape + (batch, completion_len, top_logprobs), sorted by descending probability. + - `logprob_token_ids` (list of list of list of `int`): Token IDs corresponding to each + logprob, same shape as `logprobs`. + """ + if len(request.sequences) != len(request.prompt_lengths): + raise ValueError("sequences and prompt_lengths must have the same length.") + + prompts = [{"prompt_token_ids": seq} for seq in request.sequences] + sampling_params = SamplingParams( + max_tokens=1, + temperature=1.0, + prompt_logprobs=request.top_logprobs, + ) + + chunked_prompts = chunk_list(prompts, script_args.data_parallel_size) + + for connection, chunk in zip(connections, chunked_prompts, strict=True): + if not chunk: + chunk = [{"prompt_token_ids": [0]}] + kwargs = {"prompts": chunk, "sampling_params": sampling_params} + connection.send({"type": "call", "method": "generate", "kwargs": kwargs}) + + all_outputs = [connection.recv() for connection in connections] + all_outputs = [output for output, chunk in zip(all_outputs, chunked_prompts, strict=True) if chunk] + all_outputs = list(chain.from_iterable(all_outputs)) + + all_logprobs = [] + all_token_ids = [] + for output, prompt_length in zip(all_outputs, request.prompt_lengths, strict=True): + # prompt_logprobs is a list of dicts, one per prompt token (first token is None) + prompt_lps = output.prompt_logprobs + if prompt_lps is None: + raise ValueError("prompt_logprobs is None. 
Ensure the vLLM server supports prompt_logprobs.") + + seq_logprobs = [] + seq_token_ids = [] + # Extract logprobs only for the completion region + for pos in range(prompt_length, len(prompt_lps)): + lp = prompt_lps[pos] + if lp is None: + seq_logprobs.append([]) + seq_token_ids.append([]) + continue + sorted_items = sorted(lp.items(), key=lambda x: x[1].rank) + seq_token_ids.append([token_id for token_id, _ in sorted_items]) + seq_logprobs.append( + [None if math.isnan(item.logprob) else item.logprob for _, item in sorted_items] + ) + all_logprobs.append(seq_logprobs) + all_token_ids.append(seq_token_ids) + + return {"logprobs": all_logprobs, "logprob_token_ids": all_token_ids} + class ChatRequest(BaseModel): messages: list[list[dict]] n: int = 1 diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py index f24ea415072..c4f2482e56a 100644 --- a/trl/trainer/__init__.py +++ b/trl/trainer/__init__.py @@ -25,6 +25,8 @@ "SyncRefModelCallback", "WeaveCallback", ], + "distillation_config": ["DistillationConfig"], + "distillation_trainer": ["DistillationTrainer"], "dpo_config": ["DPOConfig"], "dpo_trainer": ["DPOTrainer"], "grpo_config": ["GRPOConfig"], @@ -55,6 +57,8 @@ SyncRefModelCallback, WeaveCallback, ) + from .distillation_config import DistillationConfig + from .distillation_trainer import DistillationTrainer from .dpo_config import DPOConfig from .dpo_trainer import DPOTrainer from .grpo_config import GRPOConfig diff --git a/trl/trainer/distillation_config.py b/trl/trainer/distillation_config.py new file mode 100644 index 00000000000..42ebd1755a8 --- /dev/null +++ b/trl/trainer/distillation_config.py @@ -0,0 +1,352 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from dataclasses import dataclass, field +from typing import Any + +from .base_config import _BaseConfig + + +@dataclass +class DistillationConfig(_BaseConfig): + # docstyle-ignore + r""" + Configuration class for the [`DistillationTrainer`]. + + Extends [`~transformers.TrainingArguments`] with parameters specific to knowledge distillation. This config is + independent of [`SFTConfig`] — all necessary fields are declared here. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + > Parameters that control the model + + model_init_kwargs (`dict[str, Any]`, *optional*): + Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of the + trainer is provided as a string. + max_length (`int` or `None`, *optional*, defaults to `1024`): + Maximum total sequence length (prompt + completion) for tokenization and truncation. + + > Parameters that control the distillation + + temperature (`float`, *optional*, defaults to `1.0`): + Temperature for sampling during generation and for computing the distillation loss. 
Higher values produce + softer probability distributions. + lmbda (`float`, *optional*, defaults to `0.5`): + Probability of using on-policy (student-generated) data for each gradient accumulation slice. A value of + `0.0` means fully off-policy (dataset completions only), `1.0` means fully on-policy. + beta (`float`, *optional*, defaults to `0.5`): + Interpolation coefficient for the Generalized Jensen-Shannon Divergence loss. When `0.0`, the loss is the + forward KL divergence. When `1.0`, the loss is the reverse KL divergence. When `0.5`, it is the standard + JSD. + max_completion_length (`int`, *optional*, defaults to `256`): + Maximum number of tokens to generate per completion during on-policy generation. + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the student model during training. + + > Parameters that control the teacher model + + teacher_model_name_or_path (`str` or `None`, *optional*): + Model name or path for the teacher model. Used when the teacher is loaded locally. Mutually exclusive with + `teacher_model_server_url`. + teacher_model_revision (`str` or `None`, *optional*): + Model revision of the teacher model (e.g., branch name, tag, or commit hash). + teacher_model_init_kwargs (`dict[str, Any]` or `None`, *optional*): + Keyword arguments passed to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model + from a string. + teacher_model_server_url (`str` or `None`, *optional*): + Base URL of a vLLM server hosting the teacher model (e.g., `"http://localhost:8000"`). When set, teacher + logprobs are fetched from the server instead of running a local forward pass. Mutually exclusive with + passing a `teacher_model` object to the trainer. + teacher_server_top_logprobs (`int`, *optional*, defaults to `1`): + Number of top logprobs to request from the teacher server per token position. Only used when + `teacher_model_server_url` is set. Currently only `1` is supported — the server path uses a per-token + logprob approximation of the divergence. Full-vocabulary divergence is only available with a local teacher. + + > Parameters that control on-policy generation + + num_generations (`int`, *optional*, defaults to `1`): + Number of completions to generate per prompt during on-policy generation. + generation_batch_size (`int` or `None`, *optional*): + Number of unique prompts per worker per optimizer step. If `None`, computed from + `(per_device_train_batch_size * gradient_accumulation_steps) // num_generations`. + top_p (`float`, *optional*, defaults to `0.95`): + Top-p (nucleus) sampling parameter for on-policy generation. + top_k (`int`, *optional*, defaults to `0`): + Top-k sampling parameter for on-policy generation. `0` disables top-k filtering. + + > Parameters that control vLLM for student generation + + use_vllm (`bool`, *optional*, defaults to `False`): + Whether to use vLLM for generating on-policy completions from the student model. + vllm_mode (`str`, *optional*, defaults to `"colocate"`): + Mode for student vLLM integration. Either `"server"` or `"colocate"`. + vllm_server_base_url (`str` or `None`, *optional*): + Base URL for the student vLLM server. If provided, `vllm_server_host` and `vllm_server_port` are ignored. + vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`): + Host of the student vLLM server. + vllm_server_port (`int`, *optional*, defaults to `8001`): + Port of the student vLLM server. 
+ vllm_server_timeout (`float`, *optional*, defaults to `240.0`): + Timeout for connecting to the student vLLM server. + vllm_group_port (`int`, *optional*, defaults to `51216`): + Port for the vLLM weight-update group (NCCL communicator). + vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`): + GPU memory utilization for the colocated student vLLM engine. + vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`): + Tensor parallel size for the colocated student vLLM engine. + vllm_max_model_length (`int` or `None`, *optional*): + Maximum model sequence length for the colocated vLLM engine. + vllm_model_impl (`str`, *optional*, defaults to `"vllm"`): + Model implementation backend for vLLM. Use `"vllm"` or `"transformers"`. + vllm_structured_outputs_regex (`str` or `None`, *optional*): + Regex pattern for vLLM structured outputs. + vllm_sync_frequency (`int`, *optional*, defaults to `1`): + Frequency (in training steps) to synchronize student model weights to the vLLM engine. + vllm_enable_sleep_mode (`bool`, *optional*, defaults to `False`): + Enable vLLM sleep mode to offload student weights during the optimizer step. + + > Parameters that control logging + + log_completions (`bool`, *optional*, defaults to `False`): + Whether to log a sample of (prompt, completion) pairs periodically. + log_completions_steps (`int`, *optional*, defaults to `100`): + Number of steps between logging completions. Only used if `log_completions` is `True`. + num_completions_to_print (`int` or `None`, *optional*): + Number of completions to print. If `None`, all completions are logged. + """ + + _VALID_DICT_FIELDS = _BaseConfig._VALID_DICT_FIELDS + ["model_init_kwargs", "teacher_model_init_kwargs"] + + # Model + model_init_kwargs: dict[str, Any] | str | None = field( + default=None, + metadata={ + "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument " + "of the trainer is provided as a string." + }, + ) + max_length: int | None = field( + default=1024, + metadata={ + "help": "Maximum total sequence length (prompt + completion) for tokenization and truncation." + }, + ) + + # Overridden defaults + learning_rate: float = field( + default=1e-7, + metadata={"help": "The initial learning rate for AdamW."}, + ) + + # Distillation core + temperature: float = field( + default=1.0, + metadata={ + "help": "Temperature for sampling and loss computation. Higher values produce softer distributions." + }, + ) + lmbda: float = field( + default=0.5, + metadata={ + "help": "Probability of using on-policy (student-generated) data per gradient accumulation slice. " + "0.0 = fully off-policy, 1.0 = fully on-policy." + }, + ) + beta: float = field( + default=0.5, + metadata={ + "help": "Interpolation coefficient for the Generalized JSD loss. " + "0.0 = forward KL, 0.5 = JSD, 1.0 = reverse KL." 
+ }, + ) + max_completion_length: int = field( + default=256, + metadata={"help": "Maximum number of tokens to generate per completion."}, + ) + disable_dropout: bool = field( + default=True, + metadata={"help": "Whether to disable dropout in the student model during training."}, + ) + + # Teacher model (local) + teacher_model_name_or_path: str | None = field( + default=None, + metadata={"help": "Model name or path for the teacher model."}, + ) + teacher_model_revision: str | None = field( + default=None, + metadata={"help": "Model revision of the teacher model (e.g., branch name, tag, or commit hash)."}, + ) + teacher_model_init_kwargs: dict[str, Any] | str | None = field( + default=None, + metadata={ + "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained` when instantiating the teacher." + }, + ) + + # Teacher model (external vLLM server) + teacher_model_server_url: str | None = field( + default=None, + metadata={ + "help": 'Base URL of a vLLM server hosting the teacher model (e.g., "http://localhost:8000"). ' + "When set, teacher logprobs are fetched from the server." + }, + ) + teacher_server_top_logprobs: int = field( + default=1, + metadata={"help": "Number of top logprobs to request from the teacher server per token position. " + "Currently only `1` is supported for the server path."}, + ) + + # On-policy generation + num_generations: int = field( + default=1, + metadata={"help": "Number of completions to generate per prompt during on-policy generation."}, + ) + generation_batch_size: int | None = field( + default=None, + metadata={ + "help": "Number of unique prompts per worker per optimizer step. " + "If None, computed from (per_device_train_batch_size * gradient_accumulation_steps) // num_generations." + }, + ) + top_p: float = field( + default=0.95, + metadata={"help": "Top-p (nucleus) sampling parameter for on-policy generation."}, + ) + top_k: int = field( + default=0, + metadata={"help": "Top-k sampling parameter for on-policy generation. 0 disables top-k filtering."}, + ) + + # vLLM for student generation + use_vllm: bool = field( + default=False, + metadata={"help": "Whether to use vLLM for generating on-policy completions from the student model."}, + ) + vllm_mode: str = field( + default="colocate", + metadata={"help": 'Mode for student vLLM integration. Either "server" or "colocate".'}, + ) + vllm_server_base_url: str | None = field( + default=None, + metadata={"help": "Base URL for the student vLLM server."}, + ) + vllm_server_host: str = field( + default="0.0.0.0", + metadata={"help": "Host of the student vLLM server."}, + ) + vllm_server_port: int = field( + default=8001, + metadata={"help": "Port of the student vLLM server."}, + ) + vllm_server_timeout: float = field( + default=240.0, + metadata={"help": "Timeout for connecting to the student vLLM server."}, + ) + vllm_group_port: int = field( + default=51216, + metadata={"help": "Port for the vLLM weight-update group (NCCL communicator)."}, + ) + vllm_gpu_memory_utilization: float = field( + default=0.9, + metadata={"help": "GPU memory utilization for the colocated student vLLM engine."}, + ) + vllm_tensor_parallel_size: int = field( + default=1, + metadata={"help": "Tensor parallel size for the colocated student vLLM engine."}, + ) + vllm_max_model_length: int | None = field( + default=None, + metadata={"help": "Maximum model sequence length for the colocated vLLM engine."}, + ) + vllm_model_impl: str = field( + default="vllm", + metadata={"help": 'Model implementation backend for vLLM. 
Use "vllm" or "transformers".'}, + ) + vllm_structured_outputs_regex: str | None = field( + default=None, + metadata={"help": "Regex pattern for vLLM structured outputs."}, + ) + vllm_sync_frequency: int = field( + default=1, + metadata={"help": "Frequency (in training steps) to synchronize student weights to the vLLM engine."}, + ) + vllm_enable_sleep_mode: bool = field( + default=False, + metadata={"help": "Enable vLLM sleep mode to offload student weights during the optimizer step."}, + ) + + # Logging + log_completions: bool = field( + default=False, + metadata={"help": "Whether to log a sample of (prompt, completion) pairs periodically."}, + ) + log_completions_steps: int = field( + default=100, + metadata={"help": "Number of steps between logging completions."}, + ) + num_completions_to_print: int | None = field( + default=None, + metadata={"help": "Number of completions to print. If None, all completions are logged."}, + ) + + def __post_init__(self): + super().__post_init__() + + if self.lmbda < 0.0 or self.lmbda > 1.0: + raise ValueError(f"lmbda must be in [0.0, 1.0], got {self.lmbda}.") + if self.beta < 0.0 or self.beta > 1.0: + raise ValueError(f"beta must be in [0.0, 1.0], got {self.beta}.") + + if self.max_length is not None and self.max_completion_length >= self.max_length: + raise ValueError( + f"max_completion_length ({self.max_completion_length}) must be smaller than " + f"max_length ({self.max_length}) to leave room for the prompt." + ) + + if self.num_generations < 1: + raise ValueError(f"num_generations must be at least 1, got {self.num_generations}.") + + local_sequence_batch_size = self.per_device_train_batch_size * self.gradient_accumulation_steps + if self.generation_batch_size is None: + self.generation_batch_size = local_sequence_batch_size // self.num_generations + if self.generation_batch_size < 1: + raise ValueError(f"generation_batch_size must be at least 1, got {self.generation_batch_size}.") + if self.generation_batch_size * self.num_generations != local_sequence_batch_size: + raise ValueError( + "generation_batch_size * num_generations must equal per_device_train_batch_size * " + f"gradient_accumulation_steps. Got {self.generation_batch_size} * {self.num_generations} != " + f"{self.per_device_train_batch_size} * {self.gradient_accumulation_steps}." + ) + + if self.teacher_model_server_url is not None and self.teacher_server_top_logprobs != 1: + raise ValueError( + f"When using a teacher server (`teacher_model_server_url`), only `teacher_server_top_logprobs=1` is " + f"supported (got {self.teacher_server_top_logprobs}). The server computes a per-token logprob " + f"approximation of the divergence loss. Full-vocabulary divergence computation is only supported with " + f"a local teacher model." + ) + + if self.num_generations > 1 and self.lmbda < 1.0: + warnings.warn( + f"num_generations={self.num_generations} with lmbda={self.lmbda} means off-policy batches include " + f"{self.num_generations} copies of each sample. 
Consider lmbda=1.0 when num_generations > 1.", + UserWarning, + stacklevel=2, + ) From b629987a0dbc24ffbacef7ca331c5d46ba626e36 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 23 Mar 2026 14:55:53 +0100 Subject: [PATCH 04/37] Fix how we handle padding and special tokens --- trl/experimental/gold/gold_trainer.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index 0c7d3ae8bbe..d41a146661c 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -1159,11 +1159,20 @@ def _generate_on_policy_for_slices( local_prompts.append(prompt) local_slice_indices.append(slice_idx) + stacked_prompts = torch.stack(local_prompts) if local_prompts else torch.empty(0, dtype=torch.long) + prompts_text_with_special = self.processing_class.batch_decode( - torch.stack(local_prompts) if local_prompts else torch.empty(0, dtype=torch.long), + stacked_prompts, skip_special_tokens=False, ) + prompts_text = self.processing_class.batch_decode( + stacked_prompts, + skip_special_tokens=True, + ) + if self.processing_class.pad_token: + prompts_text = [p.replace(self.processing_class.pad_token, "") for p in prompts_text] + if not self.use_vllm: self._generate_non_vllm_for_slices(slices, on_policy_indices) return @@ -1175,7 +1184,10 @@ def _generate_on_policy_for_slices( self.vllm_generation.sync_weights() self._last_vllm_sync_step = self.state.global_step - prompt_ids_list = [p.tolist() for p in local_prompts] + pad_token_id = self.processing_class.pad_token_id + prompt_ids_list = [ + [tok for tok in p.tolist() if tok != pad_token_id] for p in local_prompts + ] _, completion_ids, _, _ = self.vllm_generation.generate( prompts=prompt_ids_list, images=None, @@ -1187,7 +1199,7 @@ def _generate_on_policy_for_slices( on_policy_indices, local_slice_indices, completion_ids, - prompts_text_with_special, + prompts_text, prompts_text_with_special, self.generation_config.max_new_tokens, ) From 6746237d2ceb44b3fb227b016c892a047175105f Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 23 Mar 2026 15:13:41 +0100 Subject: [PATCH 05/37] Initial implementation of distillation trainer --- trl/trainer/distillation_trainer.py | 1049 +++++++++++++++++++++++++++ 1 file changed, 1049 insertions(+) create mode 100644 trl/trainer/distillation_trainer.py diff --git a/trl/trainer/distillation_trainer.py b/trl/trainer/distillation_trainer.py new file mode 100644 index 00000000000..ebe7e87dd25 --- /dev/null +++ b/trl/trainer/distillation_trainer.py @@ -0,0 +1,1049 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import random +import textwrap +from collections import defaultdict, deque +from collections.abc import Callable +from contextlib import nullcontext +from functools import partial +from typing import Any, Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from accelerate.utils import DistributedType, broadcast_object_list, gather_object +from datasets import Dataset, IterableDataset +from torch.utils.data import DataLoader +from transformers import AutoProcessor, AutoTokenizer, TrainerCallback +from transformers.data.data_collator import DataCollator +from transformers.feature_extraction_utils import FeatureExtractionMixin +from transformers.generation.configuration_utils import GenerationConfig +from transformers.image_processing_utils import BaseImageProcessor +from transformers.modeling_utils import PreTrainedModel +from transformers.processing_utils import ProcessorMixin +from transformers.tokenization_utils_base import PreTrainedTokenizerBase +from transformers.trainer_utils import EvalPrediction, seed_worker +from transformers.utils import ( + is_datasets_available, + is_liger_kernel_available, + is_peft_available, + is_rich_available, +) + +from ..experimental.utils import DataCollatorForChatML +from ..extras.profiling import profiling_decorator +from ..generation.vllm_generation import VLLMGeneration +from ..import_utils import is_vllm_available +from ..models import prepare_deepspeed +from ..models.utils import unwrap_model_for_generation +from .base_trainer import _BaseTrainer +from .distillation_config import DistillationConfig +from .utils import ( + RepeatSampler, + create_model_from_path, + disable_dropout_in_model, + pad, + split_tensor_dict, +) + + +if is_peft_available(): + from peft import PeftConfig, get_peft_model + +if is_liger_kernel_available(): + from liger_kernel.chunked_loss import LigerFusedLinearJSDLoss + +if is_rich_available(): + from rich.console import Console + from rich.panel import Panel + from rich.table import Table + from rich.text import Text + + +def _print_completions_sample(prompts: list[str], completions: list[str], step: int, num_samples: int = None) -> None: + """Print a sample of prompt-completion pairs using rich.""" + if not is_rich_available(): + return + + console = Console() + table = Table(show_header=True, header_style="bold white", expand=True) + table.add_column("Prompt", style="bright_yellow") + table.add_column("Completion", style="bright_green") + + if num_samples is not None: + if num_samples >= len(prompts): + num_samples = None + elif num_samples <= 0: + return + + if num_samples is not None: + indices = random.sample(range(len(prompts)), num_samples) + prompts = [prompts[i] for i in indices] + completions = [completions[i] for i in indices] + + for prompt, completion in zip(prompts, completions, strict=True): + table.add_row(Text(prompt), Text(completion)) + table.add_section() + + panel = Panel(table, expand=False, title=f"Step {step}", border_style="bold white") + console.print(panel) + + +class DistillationTrainer(_BaseTrainer): + """ + Trainer for knowledge distillation from a teacher model to a student model. 
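+
+    Example (a minimal illustrative sketch; the student checkpoint name is a placeholder, while the teacher
+    checkpoint and dataset are taken from the accompanying script examples):
+
+    ```python
+    from datasets import load_dataset
+
+    from trl import DistillationTrainer
+
+    # Conversational dataset; the default DataCollatorForChatML tokenizes the raw messages.
+    dataset = load_dataset("trl-lib/chatbot_arena_completions", split="train")
+
+    trainer = DistillationTrainer(
+        model="Qwen/Qwen2-0.5B-Instruct",
+        teacher_model="Qwen/Qwen2-1.5B-Instruct",
+        train_dataset=dataset,
+    )
+    trainer.train()
+    ```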
+ + Supports: + - Generalized JSD loss (forward KL, reverse KL, or interpolated JSD via `beta`) + - On-policy / off-policy mixing via `lmbda` (buffered across gradient accumulation) + - Local teacher model or external teacher via vLLM server + - Student on-policy generation via vLLM or model.generate() + - Liger kernel for memory-efficient fused JSD loss + """ + + _tag_names = ["trl", "distillation"] + _name = "Distillation" + _paper = { + "title": "On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes", + "id": "2306.13649", + # docstyle-ignore + "citation": textwrap.dedent("""\ + @inproceedings{agarwal2024on-policy, + title = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}}, + author = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem}, + year = 2024, + booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024}, + publisher = {OpenReview.net}, + url = {https://openreview.net/forum?id=3zKtaqxLhW}, + }"""), + } + + def __init__( + self, + model: PreTrainedModel | nn.Module | str | None = None, + teacher_model: PreTrainedModel | nn.Module | str = None, + args: DistillationConfig | None = None, + data_collator: DataCollator | None = None, # type: ignore + train_dataset: Dataset | None = None, + eval_dataset: Dataset | dict[str, Dataset] | None = None, + processing_class: PreTrainedTokenizerBase + | BaseImageProcessor + | FeatureExtractionMixin + | ProcessorMixin + | None = None, + compute_metrics: Callable[[EvalPrediction], dict] | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, + peft_config: Optional["PeftConfig"] = None, + ): + if args is None: + args = DistillationConfig(output_dir="tmp_distillation") + + # ── Student model loading ── + model_init_kwargs = args.model_init_kwargs or {} + if isinstance(model_init_kwargs, str): + import json + + model_init_kwargs = json.loads(model_init_kwargs) + if isinstance(model, str): + model_name_or_path = model + model = create_model_from_path(model, **model_init_kwargs) + else: + model_name_or_path = model.config._name_or_path if model is not None else None + + # ── Processing class (tokenizer) ── + if processing_class is None and model_name_or_path is not None: + processing_class = AutoTokenizer.from_pretrained(model_name_or_path) + if processing_class is not None: + if getattr(processing_class, "pad_token", None) is None: + processing_class.pad_token = processing_class.eos_token + + # ── PEFT ── + if peft_config is not None: + model = get_peft_model(model, peft_config) + + # ── Data collator ── + if data_collator is None: + data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length) + + # ── Liger fused JSD loss ── + self.use_liger_loss = False + if args.use_liger_kernel: + self.liger_jsd_loss = LigerFusedLinearJSDLoss( + beta=args.beta, + ignore_index=-100, + temperature=args.temperature, + compiled=False, + weight_hard_loss=0.0, + weight_soft_loss=1.0, + ) + self.use_liger_loss = True + + # ── Teacher model setup ── + self.teacher_client = None + if args.teacher_model_server_url is not None: + from ..generation.vllm_client import VLLMClient + + self.teacher_client = 
VLLMClient(base_url=args.teacher_model_server_url, connection_timeout=60.0) + teacher_model = None + elif teacher_model is not None: + if args.teacher_model_init_kwargs is None: + teacher_model_init_kwargs = {} + elif not isinstance(teacher_model, str): + raise ValueError( + "You passed teacher_model_init_kwargs to the config, but your teacher_model is already " + "instantiated." + ) + else: + teacher_model_init_kwargs = args.teacher_model_init_kwargs + teacher_model_init_kwargs["torch_dtype"] = ( + teacher_model_init_kwargs["torch_dtype"] + if teacher_model_init_kwargs.get("torch_dtype") in ["auto", None] + else getattr(torch, teacher_model_init_kwargs["torch_dtype"]) + ) + + if isinstance(teacher_model, str): + init_kwargs = dict(teacher_model_init_kwargs) + if args.teacher_model_revision is not None: + init_kwargs.setdefault("revision", args.teacher_model_revision) + if "torch_dtype" in init_kwargs and "dtype" not in init_kwargs: + init_kwargs["dtype"] = init_kwargs.pop("torch_dtype") + teacher_model = create_model_from_path(teacher_model, **init_kwargs) + + # Trainer does not need to remove unused columns — DataCollatorForChatML handles raw data + args.remove_unused_columns = False + + # ── Call _BaseTrainer.__init__ (which is transformers.Trainer.__init__) ── + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + # ── Prepare teacher model (after super().__init__ so accelerator is ready) ── + if teacher_model is not None: + teacher_model.resize_token_embeddings(self.model.config.vocab_size) + if self.is_deepspeed_enabled: + self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator) + else: + self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True) + else: + self.teacher_model = None + + if args.disable_dropout: + disable_dropout_in_model(self.model) + + # ── Store config values ── + self.lmbda = args.lmbda + self.beta = args.beta + self.temperature = args.temperature + self.top_p = args.top_p + self.num_generations = args.num_generations + self.teacher_server_top_logprobs = args.teacher_server_top_logprobs + + # ── Buffer state ── + self._buffered_inputs = None + self._buffered_on_policy_flags = None + self._buffered_text_logs = None + self._buffer_step = 0 + + # ── Loss tracking ── + self._on_policy_loss_total = 0.0 + self._off_policy_loss_total = 0.0 + self._on_policy_step_equiv = 0.0 + self._off_policy_step_equiv = 0.0 + + # ── Generation config ── + generation_kwargs = { + "max_new_tokens": args.max_completion_length, + "temperature": args.temperature, + "top_p": args.top_p, + "do_sample": True, + "top_k": args.top_k, + "pad_token_id": self.processing_class.pad_token_id, + } + self.generation_config = GenerationConfig(**generation_kwargs) + self.generation_kwargs = generation_kwargs + if ( + hasattr(self.model.generation_config, "eos_token_id") + and self.model.generation_config.eos_token_id is not None + ): + self.generation_config.eos_token_id = self.model.generation_config.eos_token_id + + # ── Metrics & Logging ── + self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} + self._total_train_tokens = 0 + self.log_completions = args.log_completions + self.log_completions_steps = args.log_completions_steps + self.num_completions_to_print = 
args.num_completions_to_print + + maxlen = self.accelerator.num_processes * args.per_device_train_batch_size * args.gradient_accumulation_steps + self._textual_logs = { + "prompt": deque(maxlen=maxlen), + "completion": deque(maxlen=maxlen), + } + + # ── vLLM for student generation ── + self.use_vllm = args.use_vllm + if self.use_vllm: + if not is_vllm_available(): + raise ImportError( + "vLLM is not available and use_vllm is set to True. Please install vLLM with " + "`pip install vllm` to use it." + ) + self.vllm_generation = VLLMGeneration( + model=self.model, + accelerator=self.accelerator, + is_fsdp_enabled=self.is_fsdp_enabled, + processing_class=self.processing_class, + mode=args.vllm_mode, + structured_outputs_regex=args.vllm_structured_outputs_regex, + server_base_url=args.vllm_server_base_url, + server_host=args.vllm_server_host, + server_port=args.vllm_server_port, + group_port=args.vllm_group_port, + server_timeout=args.vllm_server_timeout, + tensor_parallel_size=args.vllm_tensor_parallel_size, + gpu_memory_utilization=args.vllm_gpu_memory_utilization, + max_model_length=args.vllm_max_model_length, + max_num_seqs=args.per_device_train_batch_size * args.gradient_accumulation_steps, + enable_sleep_mode=args.vllm_enable_sleep_mode, + model_impl=args.vllm_model_impl, + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + max_completion_length=args.max_completion_length, + logprobs=None, + ) + self.vllm_sync_frequency = args.vllm_sync_frequency + self._last_vllm_sync_step = -1 + + # ────────────────────────────────────────────────────────────────────── + # Dataset / Dataloader + # ────────────────────────────────────────────────────────────────────── + + def _set_signature_columns_if_needed(self): + super()._set_signature_columns_if_needed() + extra_columns = ["prompts", "prompt_attention_mask", "messages", "chat_template_kwargs", "tools"] + if self._signature_columns is None: + self._signature_columns = extra_columns + else: + for col in extra_columns: + if col not in self._signature_columns: + self._signature_columns.append(col) + + def _get_train_sampler(self, dataset=None): + if dataset is None: + dataset = self.train_dataset + return RepeatSampler( + data_source=dataset, + mini_repeat_count=self.num_generations, + batch_size=self.args.generation_batch_size * self.accelerator.num_processes, + repeat_count=self.args.gradient_accumulation_steps, + shuffle=True, + seed=self.args.seed, + ) + + def get_train_dataloader(self): + """ + Override to load one generation batch per optimizer window. + + The dataloader yields batches of size `per_device_train_batch_size * gradient_accumulation_steps`. + RepeatSampler ensures each generation batch is repeated `gradient_accumulation_steps` times so the + Trainer's loop iterates the correct number of times. 
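+
+        For example, with `per_device_train_batch_size=2` and `gradient_accumulation_steps=4`, each yielded
+        batch holds 8 rows, which `_fill_buffer` later splits into 4 micro-batches of 2.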
+ """ + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + train_dataset = self.train_dataset + data_collator = self.data_collator + + dataloader_params = { + "batch_size": self._train_batch_size * self.args.gradient_accumulation_steps, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "persistent_workers": self.args.dataloader_persistent_workers, + } + + if not isinstance(train_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_train_sampler() + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["worker_init_fn"] = partial( + seed_worker, + num_workers=self.args.dataloader_num_workers, + rank=self.args.process_index, + ) + if self.args.dataloader_num_workers > 0: + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + + return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + + # ────────────────────────────────────────────────────────────────────── + # Buffering: on/off-policy mixing across gradient accumulation steps + # ────────────────────────────────────────────────────────────────────── + + @profiling_decorator + def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]: + if not self.model.training: + return generation_batch + + buffer_steps = self.args.gradient_accumulation_steps + if self._buffer_step % buffer_steps == 0 or self._buffered_inputs is None: + self._fill_buffer(generation_batch, buffer_steps) + + slice_idx = self._buffer_step % buffer_steps + inputs = self._buffered_inputs[slice_idx] + self._buffer_step += 1 + return inputs + + @profiling_decorator + def _fill_buffer(self, generation_batch: dict[str, torch.Tensor | Any], buffer_steps: int): + """Split batch into slices and decide which are on-policy (student-generated) vs off-policy.""" + slices = split_tensor_dict(generation_batch, buffer_steps) + + # Decide on-policy flags (synchronized across processes) + if self.accelerator.is_main_process: + on_policy_flags = [random.random() <= self.lmbda for _ in range(buffer_steps)] + else: + on_policy_flags = [False] * buffer_steps + on_policy_flags = broadcast_object_list(on_policy_flags, from_process=0) + + self._buffered_inputs = [None] * buffer_steps + self._buffered_on_policy_flags = on_policy_flags + self._buffered_text_logs = [None] * buffer_steps + + # Store off-policy slices directly + on_policy_indices = [] + for i, is_on_policy in enumerate(on_policy_flags): + if is_on_policy: + on_policy_indices.append(i) + else: + self._buffered_inputs[i] = slices[i] + + # Generate student completions for on-policy slices + if on_policy_indices: + self._generate_student_completions(slices, on_policy_indices) + + @profiling_decorator + def _generate_student_completions( + self, slices: list[dict[str, torch.Tensor | Any]], on_policy_indices: list[int] + ): + """Generate completions from the student model for on-policy training.""" + if not self.use_vllm: + self._generate_with_model(slices, on_policy_indices) + return + + # Collect all prompts across on-policy slices + local_prompts = [] + local_slice_indices = [] + for slice_idx in on_policy_indices: + for prompt in slices[slice_idx]["prompts"]: + local_prompts.append(prompt) + local_slice_indices.append(slice_idx) + + prompts_text = self.processing_class.batch_decode( + torch.stack(local_prompts) if local_prompts else torch.empty(0, 
dtype=torch.long), + skip_special_tokens=False, + ) + + # Sync student weights to vLLM if needed + if ( + self.state.global_step != self._last_vllm_sync_step + and self.state.global_step % self.vllm_sync_frequency == 0 + ): + self.vllm_generation.sync_weights() + self._last_vllm_sync_step = self.state.global_step + + # Generate completions + prompt_ids_list = [p.tolist() for p in local_prompts] + _, completion_ids, _, _ = self.vllm_generation.generate( + prompts=prompt_ids_list, images=None, num_generations=self.num_generations + ) + + # Process completions into the buffer + self._store_completions_in_buffer( + slices, on_policy_indices, local_slice_indices, completion_ids, prompts_text + ) + + def _generate_with_model(self, slices: list[dict[str, torch.Tensor | Any]], on_policy_indices: list[int]): + """Fallback generation using model.generate() (no vLLM).""" + with unwrap_model_for_generation( + self.model, self.accelerator, generation_kwargs=self.generation_kwargs + ) as unwrapped_model: + for slice_idx in on_policy_indices: + slice_inputs = slices[slice_idx] + generated_outputs = unwrapped_model.generate( + input_ids=slice_inputs["prompts"], + attention_mask=slice_inputs.get("prompt_attention_mask", None), + generation_config=self.generation_config, + return_dict_in_generate=True, + ) + generated_tokens = generated_outputs.sequences + batch_size = generated_tokens.size(0) + device = generated_tokens.device + pad_token_id = self.processing_class.pad_token_id + + prompt_lengths = torch.full( + (batch_size,), slice_inputs["prompts"].shape[1], dtype=torch.long, device=device + ) + new_attention_mask, new_labels = self._build_sequence_batch( + generated_tokens, prompt_lengths, pad_token_id + ) + + # Decode for logging + prompt_texts = [] + completion_texts = [] + prompt_mask = slice_inputs.get("prompt_attention_mask") + for idx in range(batch_size): + prompt_tokens = slice_inputs["prompts"][idx] + if prompt_mask is not None: + prompt_tokens = prompt_tokens[prompt_mask[idx].bool()] + elif pad_token_id is not None: + prompt_tokens = prompt_tokens[prompt_tokens != pad_token_id] + prompt_texts.append( + self.processing_class.decode(prompt_tokens.tolist(), skip_special_tokens=False) + ) + length = int(prompt_lengths[idx].item()) + completion_texts.append( + self.processing_class.decode(generated_tokens[idx, length:].tolist(), skip_special_tokens=False) + ) + + updated = dict(slice_inputs) + updated["input_ids"] = generated_tokens + updated["attention_mask"] = new_attention_mask + updated["labels"] = new_labels + + self._buffered_inputs[slice_idx] = updated + self._buffered_text_logs[slice_idx] = (prompt_texts, completion_texts) + + def _store_completions_in_buffer( + self, + slices: list[dict[str, torch.Tensor | Any]], + on_policy_indices: list[int], + local_slice_indices: list[int], + completion_ids: list, + prompts_text: list[str], + ): + """Process vLLM completions and store them in the buffer.""" + device = self.accelerator.device + pad_token_id = self.processing_class.pad_token_id if self.processing_class.pad_token_id is not None else 0 + max_completion_length = self.generation_config.max_new_tokens + + # Group completions by slice + slice_completions = {idx: [] for idx in on_policy_indices} + slice_prompts = {idx: [] for idx in on_policy_indices} + for i, slice_idx in enumerate(local_slice_indices): + slice_completions[slice_idx].append(completion_ids[i]) + slice_prompts[slice_idx].append(prompts_text[i]) + + for slice_idx in on_policy_indices: + slice_inputs = slices[slice_idx] + 
prompt_txts = slice_prompts[slice_idx] + + # Tokenize prompts + prompt_max_length = max(1, self.args.max_length - max_completion_length) if self.args.max_length else None + prompt_tokenized = self.processing_class( + prompt_txts, + return_tensors="pt", + padding="longest", + padding_side="left", + truncation=True if prompt_max_length else False, + max_length=prompt_max_length, + add_special_tokens=False, + ).to(device) + prompt_ids = prompt_tokenized.input_ids + + # Pad/truncate completions + completion_tensors = [] + completion_ids_for_text = [] + for comp_ids in slice_completions[slice_idx]: + t = torch.tensor(comp_ids, device=device) + if len(t) > max_completion_length: + t = t[:max_completion_length] + completion_ids_for_text.append(t.tolist()) + if len(t) < max_completion_length: + padding = torch.full( + (max_completion_length - len(t),), pad_token_id, device=device, dtype=t.dtype + ) + t = torch.cat([t, padding]) + completion_tensors.append(t) + + completion_ids_padded = torch.stack(completion_tensors) + new_input_ids = torch.cat([prompt_ids, completion_ids_padded], dim=1) + prompt_lengths = torch.full((prompt_ids.shape[0],), prompt_ids.shape[1], device=device) + new_attention_mask, new_labels = self._build_sequence_batch(new_input_ids, prompt_lengths, pad_token_id) + + completion_texts = self.processing_class.batch_decode( + completion_ids_for_text, skip_special_tokens=False, clean_up_tokenization_spaces=False + ) + + updated = dict(slice_inputs) + updated["input_ids"] = new_input_ids + updated["attention_mask"] = new_attention_mask + updated["labels"] = new_labels + + self._buffered_inputs[slice_idx] = updated + self._buffered_text_logs[slice_idx] = (prompt_txts, completion_texts) + + @staticmethod + def _build_sequence_batch( + new_input_ids: torch.Tensor, prompt_lengths: torch.Tensor, pad_token_id: int | None + ) -> tuple[torch.Tensor, torch.Tensor]: + """Build attention mask and labels from full sequences and prompt lengths.""" + prompt_lengths = prompt_lengths.to(device=new_input_ids.device, dtype=torch.long) + positions = torch.arange(new_input_ids.shape[1], device=new_input_ids.device).unsqueeze(0) + completion_mask = positions >= prompt_lengths.unsqueeze(1) + + new_attention_mask = torch.ones_like(new_input_ids) + if pad_token_id is not None: + new_attention_mask[new_input_ids == pad_token_id] = 0 + + new_labels = torch.full_like(new_input_ids, -100) + new_labels[completion_mask] = new_input_ids[completion_mask] + if pad_token_id is not None: + new_labels[new_input_ids == pad_token_id] = -100 + + return new_attention_mask, new_labels + + # ────────────────────────────────────────────────────────────────────── + # Loss computation + # ────────────────────────────────────────────────────────────────────── + + @staticmethod + def generalized_jsd_loss( + student_logits, + teacher_logits, + labels=None, + beta=0.5, + temperature=1.0, + reduction="batchmean", + ): + """ + Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation. + + Args: + student_logits: Tensor of shape (batch_size, sequence_length, vocab_size). + teacher_logits: Tensor of shape (batch_size, sequence_length, vocab_size). + labels: Tensor of shape (batch_size, sequence_length) with -100 for positions to ignore. + beta: Interpolation coefficient. 0.0 = forward KL, 1.0 = reverse KL. + temperature: Softmax temperature. + reduction: 'batchmean', 'sum', 'mean', or 'none'. + + Returns: + Scalar loss tensor. 
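+
+        Example (illustrative; shapes and values are arbitrary):
+
+            >>> import torch
+            >>> student = torch.randn(2, 4, 8)
+            >>> teacher = torch.randn(2, 4, 8)
+            >>> labels = torch.tensor([[-100, 1, 1, 1], [-100, -100, 1, 1]])
+            >>> loss = DistillationTrainer.generalized_jsd_loss(student, teacher, labels=labels, beta=0.5)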
+ """ + student_logits = student_logits / temperature + teacher_logits = teacher_logits / temperature + student_log_probs = F.log_softmax(student_logits, dim=-1) + teacher_log_probs = F.log_softmax(teacher_logits, dim=-1) + + if beta == 0: + jsd = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True) + elif beta == 1: + jsd = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True) + else: + beta_t = torch.tensor(beta, dtype=student_log_probs.dtype, device=student_log_probs.device) + mixture_log_probs = torch.logsumexp( + torch.stack([student_log_probs + torch.log1p(-beta_t), teacher_log_probs + torch.log(beta_t)]), + dim=0, + ) + kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True) + kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True) + jsd = beta_t * kl_teacher + (1 - beta_t) * kl_student + + if labels is not None: + mask = labels != -100 + jsd = jsd[mask] + + if reduction == "batchmean": + return jsd.sum() / mask.sum() if labels is not None else jsd.sum() / jsd.size(0) + elif reduction == "sum": + return jsd.sum() + elif reduction == "mean": + return jsd.mean() + else: + return jsd + + @staticmethod + def token_level_divergence_loss( + student_logprobs, + teacher_logprobs, + labels=None, + beta=0.5, + ): + """ + Compute a per-token approximation of the generalized JSD loss using only the sampled token's logprobs. + + This is used when the teacher is an external server and only top-1 logprobs are available. + For each token position, we have log p_student(token) and log p_teacher(token) and compute: + - beta=0 (forward KL): exp(log_teacher) * (log_teacher - log_student) + - beta=1 (reverse KL): exp(log_student) * (log_student - log_teacher) + - 0 < beta < 1 (JSD): weighted combination of forward and reverse token-level KL terms + + Args: + student_logprobs: Tensor of shape (batch_size, completion_length) — student's log-prob per token. + teacher_logprobs: Tensor of shape (batch_size, completion_length) — teacher's log-prob per token. + labels: Tensor of shape (batch_size, completion_length) with -100 for positions to ignore. + beta: Interpolation coefficient. 0.0 = forward KL, 0.5 = JSD, 1.0 = reverse KL. + + Returns: + Scalar loss tensor. 
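+
+        Example (illustrative; values are arbitrary):
+
+            >>> import torch
+            >>> student_lp = torch.log(torch.tensor([[0.5, 0.2, 0.1]]))
+            >>> teacher_lp = torch.log(torch.tensor([[0.6, 0.3, 0.1]]))
+            >>> labels = torch.tensor([[1, 1, -100]])
+            >>> loss = DistillationTrainer.token_level_divergence_loss(student_lp, teacher_lp, labels=labels, beta=1.0)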
+        """
+        if beta == 0:
+            # Forward KL: p_teacher * (log_teacher - log_student)
+            loss = torch.exp(teacher_logprobs) * (teacher_logprobs - student_logprobs)
+        elif beta == 1:
+            # Reverse KL: p_student * (log_student - log_teacher)
+            loss = torch.exp(student_logprobs) * (student_logprobs - teacher_logprobs)
+        else:
+            # Token-level JSD approximation, weighted so the endpoints stay consistent
+            # with the beta == 0 (forward) and beta == 1 (reverse) branches above.
+            forward_kl = torch.exp(teacher_logprobs) * (teacher_logprobs - student_logprobs)
+            reverse_kl = torch.exp(student_logprobs) * (student_logprobs - teacher_logprobs)
+            loss = (1 - beta) * forward_kl + beta * reverse_kl
+
+        if labels is not None:
+            mask = labels != -100
+            loss = loss[mask]
+            return loss.sum() / mask.sum()
+
+        return loss.mean()
+
+    def _get_prompt_length(self, inputs: dict[str, torch.Tensor | Any]) -> int:
+        """Compute the effective prompt length from labels and attention mask."""
+        labels = inputs.get("labels")
+        attention_mask = inputs.get("attention_mask")
+        if labels is not None and attention_mask is not None:
+            total_valid = attention_mask.sum(dim=1)
+            completion_len = (labels != -100).sum(dim=1)
+            return int((total_valid - completion_len).min().item())
+        return inputs["prompts"].shape[1]
+
+    def _get_teacher_logits(self, inputs: dict[str, torch.Tensor | Any]) -> torch.Tensor:
+        """Get teacher logits — dispatches between local model and external server."""
+        if self.teacher_model is not None:
+            self.teacher_model.eval()
+            with torch.no_grad():
+                return self.teacher_model(
+                    input_ids=inputs["input_ids"],
+                    attention_mask=inputs["attention_mask"],
+                ).logits
+        elif self.teacher_client is not None:
+            # An external server only exposes per-token logprobs, not full-vocabulary logits;
+            # `compute_loss` routes server teachers through the token-level path instead.
+            raise ValueError(
+                "Full-vocabulary teacher logits are not available from an external teacher server. "
+                "Use the token-level divergence path instead."
+            )
+        else:
+            raise ValueError("No teacher model or teacher server configured.")
+
+    def _get_teacher_token_logprobs_from_server(self, inputs: dict[str, torch.Tensor | Any]) -> torch.Tensor:
+        """Fetch per-token teacher logprobs from an external vLLM server.
+
+        Returns a tensor of shape (batch_size, completion_length) containing the teacher's log-probability
+        for the token present at each completion position in the input sequence.
+        """
+        input_ids = inputs["input_ids"]
+        attention_mask = inputs["attention_mask"]
+        batch_size = input_ids.shape[0]
+        prompt_length = self._get_prompt_length(inputs)
+
+        # Extract unpadded sequences
+        sequences = []
+        prompt_lengths = []
+        for i in range(batch_size):
+            valid_mask = attention_mask[i].bool()
+            seq = input_ids[i][valid_mask].tolist()
+            sequences.append(seq)
+            prompt_lengths.append(prompt_length)
+
+        # Request top-1 logprobs from the teacher server. The server returns the logprob of the token at each
+        # position (i.e., the token in the sequence we sent), which is exactly what we need for all divergences.
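+        # Each entry of `result["logprobs"]` below is a per-position list whose first element,
+        # when present, is the teacher's logprob for the token actually observed at that position.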
+ result = self.teacher_client.get_sequence_logprobs( + sequences=sequences, + prompt_lengths=prompt_lengths, + top_logprobs=1, + ) + + # Build a (batch_size, completion_length) tensor of teacher logprobs for the sequence tokens + completion_length = max(len(lps) for lps in result["logprobs"]) + device = input_ids.device + teacher_logprobs = torch.full( + (batch_size, completion_length), float("-inf"), dtype=torch.float32, device=device + ) + for i, seq_lps in enumerate(result["logprobs"]): + for pos, lps in enumerate(seq_lps): + if lps and lps[0] is not None: + teacher_logprobs[i, pos] = lps[0] + + return teacher_logprobs + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + if self.use_liger_loss: + loss = self._compute_liger_loss(model, inputs) + return (loss, None) if return_outputs else loss + + # Student forward pass + student_outputs = model( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + ) + prompt_length = self._get_prompt_length(inputs) + labels = inputs["labels"][:, prompt_length:] + + if self.teacher_client is not None: + # Server path: token-level divergence using top-1 logprobs + teacher_token_logprobs = self._get_teacher_token_logprobs_from_server(inputs) + + # Extract student logprobs for the same tokens in the completion region + student_logits = student_outputs.logits[:, prompt_length - 1 : -1, :] + student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1) + completion_tokens = inputs["input_ids"][:, prompt_length:] + # Trim to match completion length from server + comp_len = teacher_token_logprobs.shape[1] + completion_tokens = completion_tokens[:, :comp_len] + student_token_logprobs = student_log_probs[:, :comp_len, :].gather( + dim=-1, index=completion_tokens.unsqueeze(-1) + ).squeeze(-1) + + loss = self.token_level_divergence_loss( + student_logprobs=student_token_logprobs, + teacher_logprobs=teacher_token_logprobs, + labels=labels[:, :comp_len], + beta=self.beta, + ) + else: + # Local teacher: full-vocabulary generalized JSD + teacher_logits = self._get_teacher_logits(inputs) + student_logits = student_outputs.logits[:, prompt_length - 1 : -1, :] + teacher_logits = teacher_logits[:, prompt_length - 1 : -1, :] + + loss = self.generalized_jsd_loss( + student_logits=student_logits, + teacher_logits=teacher_logits, + labels=labels, + beta=self.beta, + temperature=self.temperature, + ) + + return (loss, student_outputs) if return_outputs else loss + + def _compute_liger_loss(self, model, inputs): + """Memory-efficient JSD using Liger kernel (operates on hidden states, not full logits).""" + unwrapped_student = self.accelerator.unwrap_model(model) + if hasattr(unwrapped_student, "get_decoder") and unwrapped_student.get_decoder() is not None: + base_student = unwrapped_student.get_decoder() + else: + base_student = getattr( + unwrapped_student, getattr(unwrapped_student, "base_model_prefix", "model"), unwrapped_student + ) + + student_outputs = base_student( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + use_cache=False, + ) + + self.teacher_model.eval() + unwrapped_teacher = self.accelerator.unwrap_model(self.teacher_model) + if hasattr(unwrapped_teacher, "get_decoder") and unwrapped_teacher.get_decoder() is not None: + base_teacher = unwrapped_teacher.get_decoder() + else: + base_teacher = getattr( + unwrapped_teacher, getattr(unwrapped_teacher, "base_model_prefix", "model"), unwrapped_teacher + ) + with torch.no_grad(): + teacher_outputs = base_teacher( 
+ input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + use_cache=False, + ) + + student_hidden = student_outputs.last_hidden_state[:, :-1] + teacher_hidden = teacher_outputs.last_hidden_state[:, :-1] + del student_outputs, teacher_outputs + + student_hidden = student_hidden.reshape(-1, student_hidden.shape[-1]) + teacher_hidden = teacher_hidden.reshape(-1, teacher_hidden.shape[-1]) + + labels_mask = inputs["labels"] != -100 + masked_input_ids = torch.where( + labels_mask, inputs["input_ids"], torch.full_like(inputs["input_ids"], -100) + ) + true_labels = masked_input_ids[:, 1:].contiguous().reshape(-1) + + student_head = unwrapped_student.get_output_embeddings() + teacher_head = unwrapped_teacher.get_output_embeddings() + + loss = self.liger_jsd_loss( + student_input=student_hidden, + student_weight=student_head.weight, + teacher_input=teacher_hidden, + teacher_weight=teacher_head.weight, + true_labels=true_labels, + student_bias=getattr(student_head, "bias", None), + teacher_bias=getattr(teacher_head, "bias", None), + ) + + del student_hidden, teacher_hidden, true_labels + return loss + + def _get_liger_zero3_lm_head_gather_ctx(self, model: nn.Module): + """Context manager for gathering lm_head parameters under Liger + ZeRO-3.""" + if not self.use_liger_loss: + return nullcontext() + + deepspeed_plugin = self.accelerator.state.deepspeed_plugin + if deepspeed_plugin is None or deepspeed_plugin.zero_stage != 3: + return nullcontext() + + import deepspeed + + unwrapped_student = self.accelerator.unwrap_model(model) + unwrapped_teacher = self.accelerator.unwrap_model(self.teacher_model) + student_head = unwrapped_student.get_output_embeddings() + teacher_head = unwrapped_teacher.get_output_embeddings() + params = [student_head.weight, teacher_head.weight] + if student_head.bias is not None: + params.append(student_head.bias) + if teacher_head.bias is not None: + params.append(teacher_head.bias) + return deepspeed.zero.GatheredParameters(params, modifier_rank=None) + + # ────────────────────────────────────────────────────────────────────── + # Training step & Logging + # ────────────────────────────────────────────────────────────────────── + + @profiling_decorator + def training_step( + self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch: int | None = None + ) -> torch.Tensor: + """Training step with on/off-policy loss tracking and completion stats.""" + buffer_steps = self.args.gradient_accumulation_steps + + with self._get_liger_zero3_lm_head_gather_ctx(model): + loss = super().training_step(model, inputs, num_items_in_batch) + + slice_idx = (self._buffer_step - 1) % buffer_steps + + # Track on-policy text logs + is_on_policy = False + if self._buffered_on_policy_flags is not None and slice_idx < len(self._buffered_on_policy_flags): + is_on_policy = self._buffered_on_policy_flags[slice_idx] + + if is_on_policy and self._buffered_text_logs is not None and self._buffered_text_logs[slice_idx] is not None: + prompt_texts, completion_texts = self._buffered_text_logs[slice_idx] + self._textual_logs["prompt"].extend(gather_object(prompt_texts)) + self._textual_logs["completion"].extend(gather_object(completion_texts)) + + # Track completion length stats + labels = inputs.get("labels") + if labels is not None: + completion_lengths = (labels != -100).sum(dim=1).float() + gathered_lengths = self.accelerator.gather(completion_lengths) + mode = "train" + prefix = "on_policy" if is_on_policy else "off_policy" + 
self._metrics[mode][f"completions/{prefix}_mean_length"].append(gathered_lengths.mean().item()) + self._metrics[mode][f"completions/{prefix}_max_length"].append(gathered_lengths.max().item()) + self._metrics[mode][f"completions/{prefix}_min_length"].append(gathered_lengths.min().item()) + + # Track loss per policy type + loss_scalar = float(loss.detach()) + step_equiv = 1.0 / self.args.gradient_accumulation_steps + if is_on_policy: + self._on_policy_loss_total += loss_scalar + self._on_policy_step_equiv += step_equiv + else: + self._off_policy_loss_total += loss_scalar + self._off_policy_step_equiv += step_equiv + + return loss + + def log(self, logs: dict[str, float], start_time: float | None = None) -> None: + mode = "train" if self.model.training else "eval" + metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} + + if mode == "train": + # Aggregate on/off-policy losses across distributed processes + device = self.accelerator.device if hasattr(self.accelerator, "device") else torch.device("cpu") + vec = torch.tensor( + [ + self._on_policy_loss_total, + self._off_policy_loss_total, + self._on_policy_step_equiv, + self._off_policy_step_equiv, + ], + dtype=torch.float64, + device=device, + ) + + if ( + getattr(self.accelerator, "distributed_type", DistributedType.NO) != DistributedType.NO + and dist.is_available() + and dist.is_initialized() + ): + dist.all_reduce(vec, op=dist.ReduceOp.SUM) + + on_sum, off_sum, on_eq, off_eq = vec.tolist() + if on_eq > 0: + logs["on_policy_loss"] = round(on_sum / on_eq, 4) + if off_eq > 0: + logs["off_policy_loss"] = round(off_sum / off_eq, 4) + + self._on_policy_loss_total = self._off_policy_loss_total = 0.0 + self._on_policy_step_equiv = self._off_policy_step_equiv = 0.0 + + if mode == "eval": + metrics = {f"eval_{key}": val for key, val in metrics.items()} + + logs = {**logs, **metrics} + super().log(logs, start_time) + self._metrics[mode].clear() + + # Log completions to console and wandb + if ( + self.accelerator.is_main_process + and self.log_completions + and self.state.global_step % self.log_completions_steps == 0 + ): + prompts = list(self._textual_logs["prompt"]) + completions = list(self._textual_logs["completion"]) + + _print_completions_sample(prompts, completions, self.state.global_step, self.num_completions_to_print) + + # Log as a wandb Table + if self.args.report_to and "wandb" in self.args.report_to: + try: + import wandb + + if wandb.run is not None: + import pandas as pd + + table_data = { + "step": [str(self.state.global_step)] * len(prompts), + "prompt": prompts, + "completion": completions, + } + df = pd.DataFrame(table_data) + if self.num_completions_to_print and len(df) > self.num_completions_to_print: + df = df.sample(n=self.num_completions_to_print, random_state=42) + wandb.log({"completions": wandb.Table(dataframe=df)}) + except ImportError: + pass From cdc31965e71d8a1c3a3e4a1a980115cff0c6797a Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 23 Mar 2026 15:19:35 +0100 Subject: [PATCH 06/37] Address concern about vllm weight sync --- trl/experimental/gold/gold_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index d41a146661c..719dec860e9 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -962,7 +962,7 @@ def __init__( logprobs=None, ) self.vllm_sync_frequency = args.vllm_sync_frequency - self._last_vllm_sync_step = -1 + self._last_vllm_sync_step 
= -self.vllm_sync_frequency def _set_signature_columns_if_needed(self): super()._set_signature_columns_if_needed() @@ -1179,7 +1179,7 @@ def _generate_on_policy_for_slices( if ( self.state.global_step != self._last_vllm_sync_step - and self.state.global_step % self.vllm_sync_frequency == 0 + and self.state.global_step >= self._last_vllm_sync_step + self.vllm_sync_frequency ): self.vllm_generation.sync_weights() self._last_vllm_sync_step = self.state.global_step From f4c193e36f940ea33b3e348492db41eda5a64c64 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 23 Mar 2026 15:20:09 +0100 Subject: [PATCH 07/37] Run precommit --- trl/experimental/gold/gold_trainer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index 719dec860e9..d76056ffad4 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -1185,9 +1185,7 @@ def _generate_on_policy_for_slices( self._last_vllm_sync_step = self.state.global_step pad_token_id = self.processing_class.pad_token_id - prompt_ids_list = [ - [tok for tok in p.tolist() if tok != pad_token_id] for p in local_prompts - ] + prompt_ids_list = [[tok for tok in p.tolist() if tok != pad_token_id] for p in local_prompts] _, completion_ids, _, _ = self.vllm_generation.generate( prompts=prompt_ids_list, images=None, From 2b41f84603274c2e9619777cc7187f3c388e3229 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 23 Mar 2026 15:32:35 +0100 Subject: [PATCH 08/37] Fix max len behavior for generation --- trl/experimental/gold/gold_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index d76056ffad4..49f9fdcb07c 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -949,7 +949,7 @@ def __init__( server_timeout=args.vllm_server_timeout, tensor_parallel_size=args.vllm_tensor_parallel_size, gpu_memory_utilization=args.vllm_gpu_memory_utilization, - max_model_length=args.vllm_max_model_length, + max_model_length=args.vllm_max_model_length or args.max_length, max_num_seqs=args.per_device_train_batch_size * args.gradient_accumulation_steps, enable_sleep_mode=args.vllm_enable_sleep_mode, model_impl=args.vllm_model_impl, From 91715cb5285274c025fc2a8e9096988215192451 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 23 Mar 2026 15:33:17 +0100 Subject: [PATCH 09/37] Format docstring --- trl/experimental/gold/gold_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trl/experimental/gold/gold_config.py b/trl/experimental/gold/gold_config.py index c32ef45a3e8..1af9eeae332 100644 --- a/trl/experimental/gold/gold_config.py +++ b/trl/experimental/gold/gold_config.py @@ -97,8 +97,8 @@ class GOLDConfig(SFTConfig): Base URL for the vLLM server (e.g., `"http://localhost:8001"`). If provided, `vllm_server_host` and `vllm_server_port` are ignored. vllm_group_port (`int`, *optional*, defaults to `51216`): - Port for the vLLM weight-update group (NCCL communicator). Unless the port is occupied, there is no need - to change it. + Port for the vLLM weight-update group (NCCL communicator). Unless the port is occupied, there is no need to + change it. vllm_max_model_length (`int`, *optional*): Maximum model sequence length for the colocated vLLM engine when `vllm_mode="colocate"`. Defaults to the model's maximum context length. 
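PATCH 06 above swaps the modulo check for an elapsed-steps check, so a sync can no longer be skipped when generation happens to miss an aligned step. A standalone comparison of the two policies; the frequency and step sequence are hypothetical:

    FREQ = 4

    def should_sync_modulo(step: int, last_sync: int) -> bool:
        # Old check: fires only when the global step is an exact multiple of FREQ.
        return step != last_sync and step % FREQ == 0

    def should_sync_elapsed(step: int, last_sync: int) -> bool:
        # New check: fires once at least FREQ steps have elapsed since the last sync.
        return step != last_sync and step >= last_sync + FREQ

    last_mod = last_el = -FREQ  # mirrors the new initializer, -self.vllm_sync_frequency
    for step in [1, 3, 5, 9, 13]:  # generation does not run at every global step
        if should_sync_modulo(step, last_mod):
            last_mod = step
        if should_sync_elapsed(step, last_el):
            last_el = step

    # The modulo policy never fires here (no step is a multiple of 4), while the
    # elapsed policy syncs at steps 1, 5, 9, and 13.
    print(last_mod, last_el)  # -4 13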
From b8754dbad1c4b0fcd7cd0a9780c42c2fd491e234 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Tue, 24 Mar 2026 14:36:02 +0000 Subject: [PATCH 10/37] Fix data collation issue --- trl/trainer/distillation_config.py | 11 ++ trl/trainer/distillation_trainer.py | 206 ++++++++++++++++++++++------ 2 files changed, 172 insertions(+), 45 deletions(-) diff --git a/trl/trainer/distillation_config.py b/trl/trainer/distillation_config.py index 42ebd1755a8..b2d70309830 100644 --- a/trl/trainer/distillation_config.py +++ b/trl/trainer/distillation_config.py @@ -178,6 +178,14 @@ class DistillationConfig(_BaseConfig): default=256, metadata={"help": "Maximum number of tokens to generate per completion."}, ) + max_prompt_length: int | None = field( + default=None, + metadata={ + "help": "Maximum number of tokens for the prompt. If None, auto-computed as " + "max_length - max_completion_length. Prompts are truncated from the left to preserve " + "the most recent context near the generation point." + }, + ) disable_dropout: bool = field( default=True, metadata={"help": "Whether to disable dropout in the student model during training."}, @@ -320,6 +328,9 @@ def __post_init__(self): f"max_length ({self.max_length}) to leave room for the prompt." ) + if self.max_prompt_length is None and self.max_length is not None: + self.max_prompt_length = self.max_length - self.max_completion_length + if self.num_generations < 1: raise ValueError(f"num_generations must be at least 1, got {self.num_generations}.") diff --git a/trl/trainer/distillation_trainer.py b/trl/trainer/distillation_trainer.py index ebe7e87dd25..2605155eb49 100644 --- a/trl/trainer/distillation_trainer.py +++ b/trl/trainer/distillation_trainer.py @@ -43,7 +43,6 @@ is_rich_available, ) -from ..experimental.utils import DataCollatorForChatML from ..extras.profiling import profiling_decorator from ..generation.vllm_generation import VLLMGeneration from ..import_utils import is_vllm_available @@ -102,6 +101,121 @@ def _print_completions_sample(prompts: list[str], completions: list[str], step: console.print(panel) +class _DistillationCollator: + """Data collator for the distillation trainer with independent prompt/completion budgets. + + Unlike ``DataCollatorForChatML``, this collator tokenizes prompts and completions separately so + that long completions can never truncate the prompt to empty. It also handles prompt-only data + (no assistant completions) for pure on-policy distillation (``lmbda=1``). + """ + + def __init__( + self, + tokenizer: "PreTrainedTokenizerBase", + max_length: int, + max_prompt_length: int, + messages_key: str = "messages", + ignore_index: int = -100, + ): + self.tokenizer = tokenizer + self.max_length = max_length + self.max_prompt_length = max_prompt_length + self.messages_key = messages_key + self.ignore_index = ignore_index + + if tokenizer.pad_token_id is None: + raise ValueError("The tokenizer does not have a pad token. 
Please set `pad_token_id` in the tokenizer.") + + def __call__(self, examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]: + all_input_ids: list[list[int]] = [] + all_labels: list[list[int]] = [] + all_prompt_ids: list[list[int]] = [] + + for example in examples: + messages = example[self.messages_key] + + # Split: prompt = everything before the last assistant turn, completion = last assistant turn + has_completion = len(messages) > 1 and messages[-1].get("role") == "assistant" + prompt_messages = messages[:-1] if has_completion else messages + + # Tokenize prompt with its own budget (truncate from the left to keep recent context) + formatted_prompt = self.tokenizer.apply_chat_template( + prompt_messages, tokenize=False, add_generation_prompt=True + ) + prompt_ids = self.tokenizer( + formatted_prompt, + truncation=True, + max_length=self.max_prompt_length, + padding=False, + add_special_tokens=False, + )["input_ids"] + + if has_completion: + # Tokenize the full message (prompt + completion) without truncation first + formatted_full = self.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=False + ) + full_ids = self.tokenizer( + formatted_full, truncation=False, padding=False, add_special_tokens=False + )["input_ids"] + + # Identify completion tokens: everything after the prompt in the full sequence. + # Use the un-truncated prompt length as the split point. + formatted_prompt_ids = self.tokenizer( + formatted_prompt, truncation=False, padding=False, add_special_tokens=False + )["input_ids"] + completion_ids = full_ids[len(formatted_prompt_ids):] + + # Trim completion so prompt + completion <= max_length + max_comp = self.max_length - len(prompt_ids) + if max_comp > 0 and len(completion_ids) > max_comp: + completion_ids = completion_ids[:max_comp] + elif max_comp <= 0: + completion_ids = [] + + input_ids = prompt_ids + completion_ids + labels = [self.ignore_index] * len(prompt_ids) + list(completion_ids) + else: + # Prompt-only: no completion to train on (on-policy will generate one) + input_ids = list(prompt_ids) + labels = [self.ignore_index] * len(prompt_ids) + + all_input_ids.append(input_ids) + all_labels.append(labels) + all_prompt_ids.append(list(prompt_ids)) + + # Convert to tensors and left-pad + pad_id = self.tokenizer.pad_token_id + input_ids_t = pad( + [torch.tensor(ids, dtype=torch.long) for ids in all_input_ids], + padding_side="left", padding_value=pad_id, + ) + attention_mask_t = pad( + [torch.ones(len(ids), dtype=torch.long) for ids in all_input_ids], + padding_side="left", padding_value=0, + ) + labels_t = pad( + [torch.tensor(lab, dtype=torch.long) for lab in all_labels], + padding_side="left", padding_value=self.ignore_index, + ) + prompts_t = pad( + [torch.tensor(ids, dtype=torch.long) for ids in all_prompt_ids], + padding_side="left", padding_value=pad_id, + ) + prompt_mask_t = pad( + [torch.ones(len(ids), dtype=torch.long) for ids in all_prompt_ids], + padding_side="left", padding_value=0, + ) + + return { + "input_ids": input_ids_t, + "attention_mask": attention_mask_t, + "labels": labels_t, + "prompts": prompts_t, + "prompt_attention_mask": prompt_mask_t, + } + + class DistillationTrainer(_BaseTrainer): """ Trainer for knowledge distillation from a teacher model to a student model. 
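The collator above gives prompts and completions independent token budgets, so a long completion can never truncate the prompt to empty. A usage sketch; `_DistillationCollator` is module-private, the import path simply mirrors the file added in this patch, and the Qwen checkpoint is an arbitrary example of a tokenizer with a chat template:

    from transformers import AutoTokenizer

    from trl.trainer.distillation_trainer import _DistillationCollator

    tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    collator = _DistillationCollator(tokenizer=tok, max_length=512, max_prompt_length=384)

    batch = collator(
        [
            # Prompt + completion: the final assistant turn is the supervised region.
            {
                "messages": [
                    {"role": "user", "content": "What is 2 + 2?"},
                    {"role": "assistant", "content": "4"},
                ]
            },
            # Prompt-only: labels stay -100; on-policy generation fills the completion.
            {"messages": [{"role": "user", "content": "Name the capital of France."}]},
        ]
    )
    print(batch["input_ids"].shape)              # left-padded to the longest row
    print((batch["labels"] != -100).sum(dim=1))  # e.g. tensor([n, 0]) with n > 0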
@@ -178,7 +292,11 @@ def __init__( # ── Data collator ── if data_collator is None: - data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length) + data_collator = _DistillationCollator( + tokenizer=processing_class, + max_length=args.max_length, + max_prompt_length=args.max_prompt_length, + ) # ── Liger fused JSD loss ── self.use_liger_loss = False @@ -224,7 +342,7 @@ def __init__( init_kwargs["dtype"] = init_kwargs.pop("torch_dtype") teacher_model = create_model_from_path(teacher_model, **init_kwargs) - # Trainer does not need to remove unused columns — DataCollatorForChatML handles raw data + # Trainer does not need to remove unused columns — the collator handles raw data args.remove_unused_columns = False # ── Call _BaseTrainer.__init__ (which is transformers.Trainer.__init__) ── @@ -455,19 +573,23 @@ def _generate_student_completions( self._generate_with_model(slices, on_policy_indices) return - # Collect all prompts across on-policy slices + # Collect all prompts across on-policy slices, stripping left-padding so vLLM + # receives only real prompt token IDs (like GRPO trainer). local_prompts = [] local_slice_indices = [] + pad_token_id = self.processing_class.pad_token_id for slice_idx in on_policy_indices: - for prompt in slices[slice_idx]["prompts"]: + prompt_mask = slices[slice_idx].get("prompt_attention_mask") + for i, prompt in enumerate(slices[slice_idx]["prompts"]): + if prompt_mask is not None: + prompt = prompt[prompt_mask[i].bool()] + elif pad_token_id is not None: + first_non_pad = (prompt != pad_token_id).nonzero(as_tuple=True)[0] + if len(first_non_pad) > 0: + prompt = prompt[first_non_pad[0] :] local_prompts.append(prompt) local_slice_indices.append(slice_idx) - prompts_text = self.processing_class.batch_decode( - torch.stack(local_prompts) if local_prompts else torch.empty(0, dtype=torch.long), - skip_special_tokens=False, - ) - # Sync student weights to vLLM if needed if ( self.state.global_step != self._last_vllm_sync_step @@ -476,7 +598,7 @@ def _generate_student_completions( self.vllm_generation.sync_weights() self._last_vllm_sync_step = self.state.global_step - # Generate completions + # Generate completions — pass token IDs directly, no text decoding prompt_ids_list = [p.tolist() for p in local_prompts] _, completion_ids, _, _ = self.vllm_generation.generate( prompts=prompt_ids_list, images=None, num_generations=self.num_generations @@ -484,7 +606,7 @@ def _generate_student_completions( # Process completions into the buffer self._store_completions_in_buffer( - slices, on_policy_indices, local_slice_indices, completion_ids, prompts_text + slices, on_policy_indices, local_slice_indices, local_prompts, completion_ids ) def _generate_with_model(self, slices: list[dict[str, torch.Tensor | Any]], on_policy_indices: list[int]): @@ -543,39 +665,37 @@ def _store_completions_in_buffer( slices: list[dict[str, torch.Tensor | Any]], on_policy_indices: list[int], local_slice_indices: list[int], + local_prompts: list[torch.Tensor], completion_ids: list, - prompts_text: list[str], ): - """Process vLLM completions and store them in the buffer.""" + """Process vLLM completions and store them in the buffer. + + Uses original prompt token IDs directly (no decode/re-encode roundtrip), following the same + approach as GRPOTrainer. 
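+
+        Completions are right-padded to `max_completion_length` and prompts left-padded per slice, so each
+        buffered batch keeps a single prompt/completion boundary at `prompt_ids.shape[1]`.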
+ """ device = self.accelerator.device pad_token_id = self.processing_class.pad_token_id if self.processing_class.pad_token_id is not None else 0 max_completion_length = self.generation_config.max_new_tokens - # Group completions by slice + # Group completions and prompt token IDs by slice slice_completions = {idx: [] for idx in on_policy_indices} - slice_prompts = {idx: [] for idx in on_policy_indices} + slice_prompt_ids = {idx: [] for idx in on_policy_indices} for i, slice_idx in enumerate(local_slice_indices): slice_completions[slice_idx].append(completion_ids[i]) - slice_prompts[slice_idx].append(prompts_text[i]) + slice_prompt_ids[slice_idx].append(local_prompts[i]) for slice_idx in on_policy_indices: slice_inputs = slices[slice_idx] - prompt_txts = slice_prompts[slice_idx] - - # Tokenize prompts - prompt_max_length = max(1, self.args.max_length - max_completion_length) if self.args.max_length else None - prompt_tokenized = self.processing_class( - prompt_txts, - return_tensors="pt", - padding="longest", - padding_side="left", - truncation=True if prompt_max_length else False, - max_length=prompt_max_length, - add_special_tokens=False, - ).to(device) - prompt_ids = prompt_tokenized.input_ids + prompt_id_tensors = slice_prompt_ids[slice_idx] + + # Left-pad prompt token IDs to the longest prompt in this slice + max_prompt_len = max(len(p) for p in prompt_id_tensors) + prompt_ids = torch.stack([ + F.pad(p, (max_prompt_len - len(p), 0), value=pad_token_id) + for p in prompt_id_tensors + ]).to(device) - # Pad/truncate completions + # Pad/truncate completions (right-pad to max_completion_length) completion_tensors = [] completion_ids_for_text = [] for comp_ids in slice_completions[slice_idx]: @@ -595,6 +715,10 @@ def _store_completions_in_buffer( prompt_lengths = torch.full((prompt_ids.shape[0],), prompt_ids.shape[1], device=device) new_attention_mask, new_labels = self._build_sequence_batch(new_input_ids, prompt_lengths, pad_token_id) + # Decode for logging only + prompt_texts = self.processing_class.batch_decode( + prompt_id_tensors, skip_special_tokens=False, clean_up_tokenization_spaces=False + ) completion_texts = self.processing_class.batch_decode( completion_ids_for_text, skip_special_tokens=False, clean_up_tokenization_spaces=False ) @@ -603,9 +727,11 @@ def _store_completions_in_buffer( updated["input_ids"] = new_input_ids updated["attention_mask"] = new_attention_mask updated["labels"] = new_labels + # Update prompts to match the new padding width so prompt_length is consistent + updated["prompts"] = prompt_ids self._buffered_inputs[slice_idx] = updated - self._buffered_text_logs[slice_idx] = (prompt_txts, completion_texts) + self._buffered_text_logs[slice_idx] = (prompt_texts, completion_texts) @staticmethod def _build_sequence_batch( @@ -730,16 +856,6 @@ def token_level_divergence_loss( return loss.mean() - def _get_prompt_length(self, inputs: dict[str, torch.Tensor | Any]) -> int: - """Compute the effective prompt length from labels and attention mask.""" - labels = inputs.get("labels") - attention_mask = inputs.get("attention_mask") - if labels is not None and attention_mask is not None: - total_valid = attention_mask.sum(dim=1) - completion_len = (labels != -100).sum(dim=1) - return int((total_valid - completion_len).min().item()) - return inputs["prompts"].shape[1] - def _get_teacher_logits(self, inputs: dict[str, torch.Tensor | Any]) -> torch.Tensor: """Get teacher logits — dispatches between local model and external server.""" if self.teacher_model is not None: @@ -763,7 
+879,7 @@ def _get_teacher_token_logprobs_from_server(self, inputs: dict[str, torch.Tensor input_ids = inputs["input_ids"] attention_mask = inputs["attention_mask"] batch_size = input_ids.shape[0] - prompt_length = self._get_prompt_length(inputs) + prompt_length = inputs["prompts"].shape[1] # Extract unpadded sequences sequences = [] @@ -805,7 +921,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], ) - prompt_length = self._get_prompt_length(inputs) + prompt_length = inputs["prompts"].shape[1] labels = inputs["labels"][:, prompt_length:] if self.teacher_client is not None: From b94fc1fd36f350119ea915507c3eb51102c3a3c7 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Tue, 24 Mar 2026 16:08:17 +0100 Subject: [PATCH 11/37] Remove decode -> re-tokenization roundtrip --- tests/experimental/test_gold_trainer.py | 319 ++++++++++++++++++++---- trl/experimental/gold/gold_trainer.py | 100 +++++--- 2 files changed, 333 insertions(+), 86 deletions(-) diff --git a/tests/experimental/test_gold_trainer.py b/tests/experimental/test_gold_trainer.py index 50800c0a136..d7e32056323 100644 --- a/tests/experimental/test_gold_trainer.py +++ b/tests/experimental/test_gold_trainer.py @@ -19,6 +19,7 @@ from datasets import load_dataset from transformers import AutoTokenizer +from trl.experimental.gold import gold_trainer as gold_trainer_module from trl.experimental.gold.gold_trainer import GOLDTrainer, ULDLoss, build_teacher_inputs_from_texts from trl.experimental.utils import DataCollatorForChatML @@ -289,58 +290,11 @@ def pad_labels(labels, target_length): return labels + [-100] * (target_length - len(labels)) -def test_process_completions_to_buffer_left_pads_prompt_retokenization(): - class DummyBatch: - def __init__(self, input_ids): - self.input_ids = input_ids - - def to(self, device): - self.input_ids = self.input_ids.to(device) - return self - +def test_process_completions_to_buffer_left_pads_prompt_ids(): class RecordingTokenizer: pad_token_id = 0 pad_token = "" - def __init__(self): - self.padding_side = "right" - self.calls = [] - self._prompt_ids = { - "short": [11], - "longer": [21, 22], - } - - def __call__( - self, - texts, - return_tensors, - padding, - truncation, - max_length, - add_special_tokens, - padding_side=None, - ): - assert return_tensors == "pt" - assert padding == "longest" - assert not truncation - assert max_length is None - assert not add_special_tokens - self.calls.append(padding_side) - - side = padding_side or self.padding_side - encoded = [torch.tensor(self._prompt_ids[text], dtype=torch.long) for text in texts] - max_len = max(len(ids) for ids in encoded) - - padded = [] - for ids in encoded: - pad_width = max_len - len(ids) - if pad_width: - pad = torch.full((pad_width,), self.pad_token_id, dtype=torch.long) - ids = torch.cat([pad, ids]) if side == "left" else torch.cat([ids, pad]) - padded.append(ids) - - return DummyBatch(torch.stack(padded)) - def batch_decode(self, sequences, skip_special_tokens=False, clean_up_tokenization_spaces=False): del skip_special_tokens, clean_up_tokenization_spaces return [" ".join(str(token) for token in sequence) for sequence in sequences] @@ -358,19 +312,282 @@ def batch_decode(self, sequences, skip_special_tokens=False, clean_up_tokenizati on_policy_indices=[0], local_slice_indices=[0, 0], completion_ids=[[31], [41]], - prompts_text=["short", "longer"], prompts_text_with_special=["short", "longer"], + prompt_ids_list=[[11], [21, 22]], + 
prompts_text=["short", "longer"], max_completion_length=1, ) buffered_inputs = trainer._buffered_inputs[0] - assert trainer.processing_class.calls == ["left"] - assert trainer.processing_class.padding_side == "right" assert torch.equal(buffered_inputs["input_ids"], torch.tensor([[0, 11, 31], [21, 22, 41]], dtype=torch.long)) assert torch.equal(buffered_inputs["attention_mask"], torch.tensor([[0, 1, 1], [1, 1, 1]], dtype=torch.long)) assert torch.equal(buffered_inputs["labels"], torch.tensor([[-100, -100, 31], [-100, -100, 41]])) +def test_generate_on_policy_for_slices_uses_prompt_attention_mask_for_vllm_prompts(): + class RecordingVLLMGeneration: + def __init__(self): + self.prompts = None + self.sync_calls = 0 + + def sync_weights(self): + self.sync_calls += 1 + + def generate(self, prompts, images, num_generations): + self.prompts = prompts + assert images is None + assert num_generations == 1 + return None, [[42]], None, None + + class RecordingTokenizer: + pad_token_id = 9 + pad_token = "" + + def batch_decode(self, sequences, skip_special_tokens=False, clean_up_tokenization_spaces=False): + del clean_up_tokenization_spaces + decoded = [] + token_map = {5: "A", 6: "B", 9: ""} + for sequence in sequences: + tokens = [] + for token in sequence: + token = int(token) + if skip_special_tokens and token == 9: + continue + tokens.append(token_map[token]) + decoded.append(" ".join(tokens)) + return decoded + + captured = {} + + def capture_process_completions( + slices, + on_policy_indices, + local_slice_indices, + completion_ids, + prompt_ids_list, + prompts_text_with_special, + prompts_text, + max_completion_length, + ): + captured["slices"] = slices + captured["on_policy_indices"] = on_policy_indices + captured["local_slice_indices"] = local_slice_indices + captured["completion_ids"] = completion_ids + captured["prompt_ids_list"] = prompt_ids_list + captured["prompts_text"] = prompts_text + captured["prompts_text_with_special"] = prompts_text_with_special + captured["max_completion_length"] = max_completion_length + + trainer = GOLDTrainer.__new__(GOLDTrainer) + trainer.accelerator = SimpleNamespace(is_main_process=True) + trainer.args = SimpleNamespace(report_to=[]) + trainer.processing_class = RecordingTokenizer() + trainer.use_vllm = True + trainer.vllm_generation = RecordingVLLMGeneration() + trainer.vllm_sync_frequency = 1 + trainer._last_vllm_sync_step = -1 + trainer.state = SimpleNamespace(global_step=0) + trainer.num_generations = 1 + trainer.generation_config = SimpleNamespace(max_new_tokens=1) + trainer._process_completions_to_buffer = capture_process_completions + + slices = [ + { + "prompts": torch.tensor([[9, 9, 5, 9, 6]], dtype=torch.long), + "prompt_attention_mask": torch.tensor([[0, 0, 1, 1, 1]], dtype=torch.long), + } + ] + + GOLDTrainer._generate_on_policy_for_slices(trainer, slices, [0]) + + assert trainer.vllm_generation.prompts == [[5, 9, 6]] + assert trainer.vllm_generation.sync_calls == 1 + assert captured["completion_ids"] == [[42]] + assert captured["prompt_ids_list"] == [[5, 9, 6]] + assert captured["prompts_text"] == ["A B"] + assert captured["prompts_text_with_special"] == ["A B"] + + +def test_generate_on_policy_for_slices_reconstructs_prompt_with_special_tokens(): + class RecordingVLLMGeneration: + def __init__(self): + self.prompts = None + self.sync_calls = 0 + + def sync_weights(self): + self.sync_calls += 1 + + def generate(self, prompts, images, num_generations): + self.prompts = prompts + assert images is None + assert num_generations == 1 + return None, 
[[42]], None, None + + class RecordingTokenizer: + pad_token_id = 0 + pad_token = "" + + def __init__(self): + self.truncation_side = "right" + + def batch_decode(self, sequences, skip_special_tokens=False, clean_up_tokenization_spaces=False): + del clean_up_tokenization_spaces + token_map = {0: "", 5: "A", 6: "B", 13: "", 42: "C"} + decoded = [] + for sequence in sequences: + tokens = [] + for token in sequence: + token = int(token) + if skip_special_tokens and token == 13: + continue + if token == 0: + continue + tokens.append(token_map[token]) + decoded.append(" ".join(tokens)) + return decoded + + trainer = GOLDTrainer.__new__(GOLDTrainer) + trainer.accelerator = SimpleNamespace(device=torch.device("cpu"), is_main_process=True) + trainer.processing_class = RecordingTokenizer() + trainer.args = SimpleNamespace(max_length=None, report_to=[]) + trainer.use_vllm = True + trainer.vllm_generation = RecordingVLLMGeneration() + trainer.vllm_sync_frequency = 1 + trainer._last_vllm_sync_step = -1 + trainer.state = SimpleNamespace(global_step=0) + trainer.num_generations = 1 + trainer.generation_config = SimpleNamespace(max_new_tokens=1) + trainer._buffered_inputs = [None] + trainer._buffered_text_logs = [None] + + slices = [ + { + "slice": "original", + "prompts": torch.tensor([[0, 0, 5, 13, 6]], dtype=torch.long), + "prompt_attention_mask": torch.tensor([[0, 0, 1, 1, 1]], dtype=torch.long), + } + ] + + GOLDTrainer._generate_on_policy_for_slices(trainer, slices, [0]) + + buffered_inputs = trainer._buffered_inputs[0] + assert trainer.vllm_generation.prompts == [[5, 13, 6]] + assert trainer.vllm_generation.sync_calls == 1 + assert torch.equal(buffered_inputs["input_ids"], torch.tensor([[5, 13, 6, 42]], dtype=torch.long)) + assert torch.equal(buffered_inputs["attention_mask"], torch.tensor([[1, 1, 1, 1]], dtype=torch.long)) + assert torch.equal(buffered_inputs["labels"], torch.tensor([[-100, -100, -100, 42]], dtype=torch.long)) + assert buffered_inputs["original_prompt_text"] == ["A B"] + assert buffered_inputs["original_completion_text"] == ["C"] + assert trainer._buffered_text_logs[0] == (["A B"], ["C"]) + + +def test_gold_trainer_init_defaults_vllm_max_model_length_to_max_length(monkeypatch): + captured = {} + + class DummyStudentModel: + def __init__(self): + self.config = SimpleNamespace(_name_or_path="student", vocab_size=17) + self.generation_config = SimpleNamespace(eos_token_id=2) + self.name_or_path = "student" + + class DummyTeacherModel: + def __init__(self): + self.resized_to = None + + def resize_token_embeddings(self, vocab_size): + self.resized_to = vocab_size + + class DummyProcessingClass: + pad_token_id = 0 + + def fake_sft_init( + self, + model, + args=None, + data_collator=None, + train_dataset=None, + eval_dataset=None, + processing_class=None, + compute_metrics=None, + callbacks=None, + optimizers=None, + preprocess_logits_for_metrics=None, + peft_config=None, + ): + del data_collator, train_dataset, eval_dataset, compute_metrics, callbacks, optimizers + del preprocess_logits_for_metrics, peft_config + self.model = model + self.args = args + self.processing_class = processing_class + self.accelerator = SimpleNamespace( + device=torch.device("cpu"), + num_processes=1, + prepare_model=lambda module, evaluation_mode=True: module, + ) + self.is_deepspeed_enabled = False + self.is_fsdp_enabled = False + + class CapturingVLLMGeneration: + def __init__(self, **kwargs): + captured.update(kwargs) + + monkeypatch.setattr(gold_trainer_module.SFTTrainer, "__init__", fake_sft_init) + 
monkeypatch.setattr(gold_trainer_module, "is_vllm_available", lambda: True) + monkeypatch.setattr(gold_trainer_module, "VLLMGeneration", CapturingVLLMGeneration) + + args = SimpleNamespace( + model_init_kwargs=None, + max_length=128, + use_liger_kernel=False, + teacher_model_init_kwargs=None, + use_uld_loss=False, + teacher_tokenizer_name_or_path=None, + teacher_model_revision=None, + disable_dropout=False, + lmbda=1.0, + beta=0.5, + temperature=1.0, + top_p=1.0, + seq_kd=False, + num_generations=1, + use_transformers_paged=False, + max_completion_length=16, + top_k=0, + log_completions=False, + log_completions_steps=100, + wandb_log_unique_prompts=True, + num_completions_to_print=None, + per_device_train_batch_size=1, + gradient_accumulation_steps=1, + use_vllm=True, + vllm_mode="colocate", + vllm_structured_outputs_regex=None, + vllm_server_base_url=None, + vllm_server_host="0.0.0.0", + vllm_server_port=8001, + vllm_group_port=51216, + vllm_server_timeout=240.0, + vllm_tensor_parallel_size=1, + vllm_gpu_memory_utilization=0.2, + vllm_max_model_length=None, + vllm_enable_sleep_mode=False, + vllm_model_impl="vllm", + vllm_sync_frequency=1, + ) + + teacher_model = DummyTeacherModel() + GOLDTrainer( + model=DummyStudentModel(), + teacher_model=teacher_model, + args=args, + data_collator=object(), + processing_class=DummyProcessingClass(), + ) + + assert teacher_model.resized_to == 17 + assert captured["max_model_length"] == 128 + + def test_alignment_groups_cover_all_tokens(llama_tokenizer, qwen_tokenizer): config = build_config() loss = ULDLoss(config, student_tokenizer=llama_tokenizer, teacher_tokenizer=qwen_tokenizer) diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index 49f9fdcb07c..39462fbaa1e 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -1095,20 +1095,26 @@ def _ensure_original_text_fields( @staticmethod def _build_sequence_batch( - new_input_ids: torch.Tensor, prompt_lengths: torch.Tensor, pad_token_id: int | None + new_input_ids: torch.Tensor, + prompt_lengths: torch.Tensor, + pad_token_id: int | None, + attention_mask: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Build attention mask and labels from full sequences and prompt lengths.""" prompt_lengths = prompt_lengths.to(device=new_input_ids.device, dtype=torch.long) positions = torch.arange(new_input_ids.shape[1], device=new_input_ids.device).unsqueeze(0) completion_mask = positions >= prompt_lengths.unsqueeze(1) - new_attention_mask = torch.ones_like(new_input_ids) - if pad_token_id is not None: - new_attention_mask[new_input_ids == pad_token_id] = 0 + if attention_mask is not None: + new_attention_mask = attention_mask.to(device=new_input_ids.device, dtype=new_input_ids.dtype) + else: + new_attention_mask = torch.ones_like(new_input_ids) + if pad_token_id is not None: + new_attention_mask[new_input_ids == pad_token_id] = 0 new_labels = torch.full_like(new_input_ids, -100) - new_labels[completion_mask] = new_input_ids[completion_mask] - if pad_token_id is not None: + new_labels[completion_mask & new_attention_mask.bool()] = new_input_ids[completion_mask & new_attention_mask.bool()] + if attention_mask is None and pad_token_id is not None: new_labels[new_input_ids == pad_token_id] = -100 return new_attention_mask, new_labels @@ -1151,27 +1157,25 @@ def _fill_buffer(self, generation_batch: dict[str, torch.Tensor | Any], buffer_s def _generate_on_policy_for_slices( self, slices: list[dict[str, torch.Tensor | 
Any]], on_policy_indices: list[int] ): - local_prompts = [] + prompt_ids_list = [] local_slice_indices = [] for slice_idx in on_policy_indices: slice_inputs = slices[slice_idx] - for prompt in slice_inputs["prompts"]: - local_prompts.append(prompt) + prompt_attention_mask = slice_inputs.get("prompt_attention_mask") + for prompt_idx, prompt in enumerate(slice_inputs["prompts"]): + if prompt_attention_mask is not None: + prompt = prompt[prompt_attention_mask[prompt_idx].bool()] + prompt_ids_list.append(prompt.tolist()) local_slice_indices.append(slice_idx) - stacked_prompts = torch.stack(local_prompts) if local_prompts else torch.empty(0, dtype=torch.long) - - prompts_text_with_special = self.processing_class.batch_decode( - stacked_prompts, - skip_special_tokens=False, - ) - prompts_text = self.processing_class.batch_decode( - stacked_prompts, + prompt_ids_list, skip_special_tokens=True, ) - if self.processing_class.pad_token: - prompts_text = [p.replace(self.processing_class.pad_token, "") for p in prompts_text] + prompts_text_with_special = self.processing_class.batch_decode( + prompt_ids_list, + skip_special_tokens=False, + ) if not self.use_vllm: self._generate_non_vllm_for_slices(slices, on_policy_indices) @@ -1184,8 +1188,6 @@ def _generate_on_policy_for_slices( self.vllm_generation.sync_weights() self._last_vllm_sync_step = self.state.global_step - pad_token_id = self.processing_class.pad_token_id - prompt_ids_list = [[tok for tok in p.tolist() if tok != pad_token_id] for p in local_prompts] _, completion_ids, _, _ = self.vllm_generation.generate( prompts=prompt_ids_list, images=None, @@ -1197,8 +1199,9 @@ def _generate_on_policy_for_slices( on_policy_indices, local_slice_indices, completion_ids, - prompts_text, + prompt_ids_list, prompts_text_with_special, + prompts_text, self.generation_config.max_new_tokens, ) @@ -1235,8 +1238,9 @@ def _process_completions_to_buffer( on_policy_indices: list[int], local_slice_indices: list[int], completion_ids: list, - prompts_text: list[str], + prompt_ids_list: list[list[int]], prompts_text_with_special: list[str], + prompts_text: list[str], max_completion_length: int, ): """ @@ -1246,40 +1250,50 @@ def _process_completions_to_buffer( pad_token_id = self.processing_class.pad_token_id if self.processing_class.pad_token_id is not None else 0 slice_completions = {idx: [] for idx in on_policy_indices} + slice_prompt_ids = {idx: [] for idx in on_policy_indices} slice_prompts = {idx: [] for idx in on_policy_indices} slice_prompts_special = {idx: [] for idx in on_policy_indices} for i, slice_idx in enumerate(local_slice_indices): slice_completions[slice_idx].append(completion_ids[i]) - slice_prompts[slice_idx].append(prompts_text[i]) + slice_prompt_ids[slice_idx].append(prompt_ids_list[i]) slice_prompts_special[slice_idx].append(prompts_text_with_special[i]) + slice_prompts[slice_idx].append(prompts_text[i]) for slice_idx in on_policy_indices: slice_inputs = slices[slice_idx] completion_ids_for_slice = slice_completions[slice_idx] + prompt_ids_for_slice = slice_prompt_ids[slice_idx] prompt_txts = slice_prompts[slice_idx] prompt_txts_with_special = slice_prompts_special[slice_idx] prompt_max_length = max(1, self.args.max_length - max_completion_length) if self.args.max_length else None - prompt_tokenized = self.processing_class( - prompt_txts, - return_tensors="pt", - padding="longest", - padding_side="left", - truncation=True if prompt_max_length else False, - max_length=prompt_max_length, - add_special_tokens=False, - ).to(device) - prompt_ids = 
prompt_tokenized.input_ids + truncated_prompt_ids = [] + prompt_attention_masks = [] + truncation_side = getattr(self.processing_class, "truncation_side", "right") + for prompt_ids in prompt_ids_for_slice: + if prompt_max_length and len(prompt_ids) > prompt_max_length: + if truncation_side == "left": + prompt_ids = prompt_ids[-prompt_max_length:] + else: + prompt_ids = prompt_ids[:prompt_max_length] + prompt_tensor = torch.tensor(prompt_ids, device=device, dtype=torch.long) + truncated_prompt_ids.append(prompt_tensor) + prompt_attention_masks.append(torch.ones(len(prompt_ids), device=device, dtype=torch.long)) + + prompt_ids = pad(truncated_prompt_ids, padding_side="left", padding_value=pad_token_id) + prompt_attention_mask = pad(prompt_attention_masks, padding_side="left", padding_value=0) completion_ids_tensors = [torch.tensor(ids, device=device) for ids in completion_ids_for_slice] completion_ids_for_text: list[list[int]] = [] padded_completion_ids_list = [] + completion_attention_masks = [] for completion_tensor in completion_ids_tensors: if len(completion_tensor) > max_completion_length: truncated_completion_tensor = completion_tensor[:max_completion_length] padded_completion_ids_list.append(truncated_completion_tensor) completion_ids_for_text.append(truncated_completion_tensor.tolist()) + completion_attention_masks.append(torch.ones(len(truncated_completion_tensor), device=device, dtype=torch.long)) elif len(completion_tensor) < max_completion_length: padding_needed = max_completion_length - len(completion_tensor) padded_tensor = torch.cat( @@ -1295,15 +1309,31 @@ def _process_completions_to_buffer( ) padded_completion_ids_list.append(padded_tensor) completion_ids_for_text.append(completion_tensor.tolist()) + completion_attention_masks.append( + torch.cat( + [ + torch.ones(len(completion_tensor), device=device, dtype=torch.long), + torch.zeros(padding_needed, device=device, dtype=torch.long), + ] + ) + ) else: padded_completion_ids_list.append(completion_tensor) completion_ids_for_text.append(completion_tensor.tolist()) + completion_attention_masks.append(torch.ones(len(completion_tensor), device=device, dtype=torch.long)) completion_ids_padded = torch.stack(padded_completion_ids_list) + completion_attention_mask = torch.stack(completion_attention_masks) new_input_ids = torch.cat([prompt_ids, completion_ids_padded], dim=1) + new_attention_mask = torch.cat([prompt_attention_mask, completion_attention_mask], dim=1) prompt_lengths = torch.full((prompt_ids.shape[0],), prompt_ids.shape[1], device=device) - new_attention_mask, new_labels = self._build_sequence_batch(new_input_ids, prompt_lengths, pad_token_id) + new_attention_mask, new_labels = self._build_sequence_batch( + new_input_ids, + prompt_lengths, + pad_token_id, + attention_mask=new_attention_mask, + ) completion_texts = self.processing_class.batch_decode( completion_ids_for_text, From dcfce594fc5fea7bf3faee793cd1fc7296455e22 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Tue, 24 Mar 2026 16:10:11 +0100 Subject: [PATCH 12/37] Run precommit --- trl/experimental/gold/gold_trainer.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index 39462fbaa1e..bfabbd7e1bb 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -1113,7 +1113,9 @@ def _build_sequence_batch( new_attention_mask[new_input_ids == pad_token_id] = 0 new_labels = torch.full_like(new_input_ids, -100) - 
new_labels[completion_mask & new_attention_mask.bool()] = new_input_ids[completion_mask & new_attention_mask.bool()] + new_labels[completion_mask & new_attention_mask.bool()] = new_input_ids[ + completion_mask & new_attention_mask.bool() + ] if attention_mask is None and pad_token_id is not None: new_labels[new_input_ids == pad_token_id] = -100 @@ -1293,7 +1295,9 @@ def _process_completions_to_buffer( truncated_completion_tensor = completion_tensor[:max_completion_length] padded_completion_ids_list.append(truncated_completion_tensor) completion_ids_for_text.append(truncated_completion_tensor.tolist()) - completion_attention_masks.append(torch.ones(len(truncated_completion_tensor), device=device, dtype=torch.long)) + completion_attention_masks.append( + torch.ones(len(truncated_completion_tensor), device=device, dtype=torch.long) + ) elif len(completion_tensor) < max_completion_length: padding_needed = max_completion_length - len(completion_tensor) padded_tensor = torch.cat( @@ -1320,7 +1324,9 @@ def _process_completions_to_buffer( else: padded_completion_ids_list.append(completion_tensor) completion_ids_for_text.append(completion_tensor.tolist()) - completion_attention_masks.append(torch.ones(len(completion_tensor), device=device, dtype=torch.long)) + completion_attention_masks.append( + torch.ones(len(completion_tensor), device=device, dtype=torch.long) + ) completion_ids_padded = torch.stack(padded_completion_ids_list) completion_attention_mask = torch.stack(completion_attention_masks) From ff81a898c250a1e4e5ee66452e96ddb5c869b9d5 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Tue, 24 Mar 2026 16:12:12 +0100 Subject: [PATCH 13/37] Add check for tokenizers and prompt length --- trl/scripts/distillation.py | 12 +- trl/trainer/distillation_trainer.py | 172 +++++++++++++++++++++++----- 2 files changed, 147 insertions(+), 37 deletions(-) diff --git a/trl/scripts/distillation.py b/trl/scripts/distillation.py index ca0ed4dde70..69321da6291 100644 --- a/trl/scripts/distillation.py +++ b/trl/scripts/distillation.py @@ -26,8 +26,8 @@ # Full training (off-policy only, lmbda=0): ``` python trl/scripts/distillation.py \ - --model_name_or_path meta-llama/Llama-3.2-1B-Instruct \ - --teacher_model_name_or_path Qwen/Qwen2-1.5B-Instruct \ + --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \ + --teacher_model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \ --dataset_name trl-lib/chatbot_arena_completions \ --learning_rate 2e-5 \ --per_device_train_batch_size 4 \ @@ -40,8 +40,8 @@ # Mixed on/off-policy (lmbda=0.5): ``` python trl/scripts/distillation.py \ - --model_name_or_path meta-llama/Llama-3.2-1B-Instruct \ - --teacher_model_name_or_path Qwen/Qwen2-1.5B-Instruct \ + --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \ + --teacher_model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \ --dataset_name trl-lib/chatbot_arena_completions \ --learning_rate 2e-5 \ --per_device_train_batch_size 4 \ @@ -55,8 +55,8 @@ # LoRA: ``` python trl/scripts/distillation.py \ - --model_name_or_path meta-llama/Llama-3.2-1B-Instruct \ - --teacher_model_name_or_path Qwen/Qwen2-1.5B-Instruct \ + --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \ + --teacher_model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \ --dataset_name trl-lib/chatbot_arena_completions \ --learning_rate 2e-4 \ --per_device_train_batch_size 4 \ diff --git a/trl/trainer/distillation_trainer.py b/trl/trainer/distillation_trainer.py index 2605155eb49..bd0b7e8ae78 100644 --- a/trl/trainer/distillation_trainer.py +++ b/trl/trainer/distillation_trainer.py @@ -14,6 +14,7 @@ import 
random import textwrap +import warnings from collections import defaultdict, deque from collections.abc import Callable from contextlib import nullcontext @@ -273,6 +274,11 @@ def __init__( import json model_init_kwargs = json.loads(model_init_kwargs) + teacher_model_init_kwargs = args.teacher_model_init_kwargs or {} + if isinstance(teacher_model_init_kwargs, str): + import json + + teacher_model_init_kwargs = json.loads(teacher_model_init_kwargs) if isinstance(model, str): model_name_or_path = model model = create_model_from_path(model, **model_init_kwargs) @@ -313,25 +319,58 @@ def __init__( # ── Teacher model setup ── self.teacher_client = None + self._local_teacher_tokenizer_matches_student = True if args.teacher_model_server_url is not None: from ..generation.vllm_client import VLLMClient self.teacher_client = VLLMClient(base_url=args.teacher_model_server_url, connection_timeout=60.0) teacher_model = None elif teacher_model is not None: - if args.teacher_model_init_kwargs is None: - teacher_model_init_kwargs = {} - elif not isinstance(teacher_model, str): + if args.teacher_model_init_kwargs is not None and not isinstance(teacher_model, str): raise ValueError( "You passed teacher_model_init_kwargs to the config, but your teacher_model is already " "instantiated." ) - else: - teacher_model_init_kwargs = args.teacher_model_init_kwargs + + teacher_model_name_or_path = ( + teacher_model + if isinstance(teacher_model, str) + else getattr(getattr(teacher_model, "config", None), "_name_or_path", None) + ) + if teacher_model_name_or_path is None: + raise ValueError( + "DistillationTrainer requires a local teacher model with `config._name_or_path` set so its " + "tokenizer can be validated against the student tokenizer." + ) + + teacher_tokenizer_kwargs = {} + teacher_revision = teacher_model_init_kwargs.get("revision", args.teacher_model_revision) + if teacher_revision is not None: + teacher_tokenizer_kwargs["revision"] = teacher_revision + if teacher_model_init_kwargs.get("trust_remote_code") is not None: + teacher_tokenizer_kwargs["trust_remote_code"] = teacher_model_init_kwargs["trust_remote_code"] + teacher_processing_class = AutoTokenizer.from_pretrained( + teacher_model_name_or_path, **teacher_tokenizer_kwargs + ) + if getattr(teacher_processing_class, "pad_token", None) is None: + teacher_processing_class.pad_token = teacher_processing_class.eos_token + self._local_teacher_tokenizer_matches_student = self._local_teacher_tokenizers_match( + processing_class, teacher_processing_class + ) + if not self._local_teacher_tokenizer_matches_student: + warnings.warn( + "DistillationTrainer's built-in local-teacher loss assumes the student and teacher share the " + "same tokenizer. The loaded local teacher tokenizer does not match the student tokenizer, so " + "the teacher weights will be left unchanged for subclass overrides. 
Direct use of the base " + "DistillationTrainer with this local teacher remains unsupported.", + UserWarning, + stacklevel=2, + ) + + if isinstance(teacher_model, str): + torch_dtype = teacher_model_init_kwargs.get("torch_dtype") teacher_model_init_kwargs["torch_dtype"] = ( - teacher_model_init_kwargs["torch_dtype"] - if teacher_model_init_kwargs.get("torch_dtype") in ["auto", None] - else getattr(torch, teacher_model_init_kwargs["torch_dtype"]) + torch_dtype if torch_dtype in ["auto", None] else getattr(torch, torch_dtype) ) if isinstance(teacher_model, str): @@ -361,7 +400,8 @@ def __init__( # ── Prepare teacher model (after super().__init__ so accelerator is ready) ── if teacher_model is not None: - teacher_model.resize_token_embeddings(self.model.config.vocab_size) + if self._local_teacher_tokenizer_matches_student: + teacher_model.resize_token_embeddings(self.model.config.vocab_size) if self.is_deepspeed_enabled: self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator) else: @@ -457,6 +497,54 @@ def __init__( self.vllm_sync_frequency = args.vllm_sync_frequency self._last_vllm_sync_step = -1 + @staticmethod + def _local_teacher_tokenizers_match( + student_processing_class: PreTrainedTokenizerBase, + teacher_processing_class: PreTrainedTokenizerBase, + ) -> bool: + """Check whether the student and local teacher tokenizers share the same vocabulary.""" + return student_processing_class.get_vocab() == teacher_processing_class.get_vocab() + + def _raise_if_local_teacher_tokenizer_mismatch(self) -> None: + """Guard the base local-teacher JSD path, while still allowing subclass overrides.""" + if self.teacher_model is not None and not self._local_teacher_tokenizer_matches_student: + raise ValueError( + "DistillationTrainer's built-in local-teacher loss only supports student/teacher pairs that use " + "the same tokenizer. Use a same-tokenizer local teacher, use `teacher_model_server_url`, or " + "override the local teacher loss path in a subclass." 
+ ) + + def _get_completion_lengths(self, generated_tokens: torch.Tensor, prompt_width: int) -> torch.Tensor: + """Infer per-sample completion lengths from generated tokens.""" + completion_tokens = generated_tokens[:, prompt_width:] + pad_token_id = self.processing_class.pad_token_id + eos_token_id = self.generation_config.eos_token_id + if eos_token_id is None: + eos_token_ids = set() + elif isinstance(eos_token_id, int): + eos_token_ids = {eos_token_id} + else: + eos_token_ids = set(eos_token_id) + pad_equals_eos = pad_token_id is not None and pad_token_id in eos_token_ids + + completion_lengths = [] + for row in completion_tokens.tolist(): + if pad_equals_eos and eos_token_ids: + completion_length = len(row) + for idx, token_id in enumerate(row): + if token_id in eos_token_ids: + completion_length = idx + 1 + break + elif pad_token_id is not None: + completion_length = len(row) + while completion_length > 0 and row[completion_length - 1] == pad_token_id: + completion_length -= 1 + else: + completion_length = len(row) + completion_lengths.append(completion_length) + + return torch.tensor(completion_lengths, device=generated_tokens.device, dtype=torch.long) + # ────────────────────────────────────────────────────────────────────── # Dataset / Dataloader # ────────────────────────────────────────────────────────────────────── @@ -626,18 +714,20 @@ def _generate_with_model(self, slices: list[dict[str, torch.Tensor | Any]], on_p batch_size = generated_tokens.size(0) device = generated_tokens.device pad_token_id = self.processing_class.pad_token_id - - prompt_lengths = torch.full( - (batch_size,), slice_inputs["prompts"].shape[1], dtype=torch.long, device=device - ) + prompt_width = slice_inputs["prompts"].shape[1] + prompt_mask = slice_inputs.get("prompt_attention_mask") + if prompt_mask is not None: + prompt_token_lengths = prompt_mask.sum(dim=1) + else: + prompt_token_lengths = torch.full((batch_size,), prompt_width, dtype=torch.long, device=device) + completion_lengths = self._get_completion_lengths(generated_tokens, prompt_width) new_attention_mask, new_labels = self._build_sequence_batch( - generated_tokens, prompt_lengths, pad_token_id + generated_tokens, prompt_width, prompt_token_lengths, completion_lengths ) # Decode for logging prompt_texts = [] completion_texts = [] - prompt_mask = slice_inputs.get("prompt_attention_mask") for idx in range(batch_size): prompt_tokens = slice_inputs["prompts"][idx] if prompt_mask is not None: @@ -647,9 +737,13 @@ def _generate_with_model(self, slices: list[dict[str, torch.Tensor | Any]], on_p prompt_texts.append( self.processing_class.decode(prompt_tokens.tolist(), skip_special_tokens=False) ) - length = int(prompt_lengths[idx].item()) + length = prompt_width + completion_length = int(completion_lengths[idx].item()) completion_texts.append( - self.processing_class.decode(generated_tokens[idx, length:].tolist(), skip_special_tokens=False) + self.processing_class.decode( + generated_tokens[idx, length : length + completion_length].tolist(), + skip_special_tokens=False, + ) ) updated = dict(slice_inputs) @@ -687,22 +781,29 @@ def _store_completions_in_buffer( for slice_idx in on_policy_indices: slice_inputs = slices[slice_idx] prompt_id_tensors = slice_prompt_ids[slice_idx] + prompt_width = max(len(p) for p in prompt_id_tensors) + prompt_token_lengths = torch.tensor([len(p) for p in prompt_id_tensors], device=device, dtype=torch.long) + prompt_attention_mask = ( + torch.arange(prompt_width, device=device).unsqueeze(0) + >= (prompt_width - 
prompt_token_lengths).unsqueeze(1) + ).long() # Left-pad prompt token IDs to the longest prompt in this slice - max_prompt_len = max(len(p) for p in prompt_id_tensors) prompt_ids = torch.stack([ - F.pad(p, (max_prompt_len - len(p), 0), value=pad_token_id) + F.pad(p, (prompt_width - len(p), 0), value=pad_token_id) for p in prompt_id_tensors ]).to(device) # Pad/truncate completions (right-pad to max_completion_length) completion_tensors = [] completion_ids_for_text = [] + completion_lengths = [] for comp_ids in slice_completions[slice_idx]: t = torch.tensor(comp_ids, device=device) if len(t) > max_completion_length: t = t[:max_completion_length] completion_ids_for_text.append(t.tolist()) + completion_lengths.append(len(t)) if len(t) < max_completion_length: padding = torch.full( (max_completion_length - len(t),), pad_token_id, device=device, dtype=t.dtype @@ -712,8 +813,10 @@ def _store_completions_in_buffer( completion_ids_padded = torch.stack(completion_tensors) new_input_ids = torch.cat([prompt_ids, completion_ids_padded], dim=1) - prompt_lengths = torch.full((prompt_ids.shape[0],), prompt_ids.shape[1], device=device) - new_attention_mask, new_labels = self._build_sequence_batch(new_input_ids, prompt_lengths, pad_token_id) + completion_lengths = torch.tensor(completion_lengths, device=device, dtype=torch.long) + new_attention_mask, new_labels = self._build_sequence_batch( + new_input_ids, prompt_width, prompt_token_lengths, completion_lengths + ) # Decode for logging only prompt_texts = self.processing_class.batch_decode( @@ -729,27 +832,32 @@ def _store_completions_in_buffer( updated["labels"] = new_labels # Update prompts to match the new padding width so prompt_length is consistent updated["prompts"] = prompt_ids + updated["prompt_attention_mask"] = prompt_attention_mask self._buffered_inputs[slice_idx] = updated self._buffered_text_logs[slice_idx] = (prompt_texts, completion_texts) @staticmethod def _build_sequence_batch( - new_input_ids: torch.Tensor, prompt_lengths: torch.Tensor, pad_token_id: int | None + new_input_ids: torch.Tensor, + prompt_width: int, + prompt_token_lengths: torch.Tensor, + completion_lengths: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - """Build attention mask and labels from full sequences and prompt lengths.""" - prompt_lengths = prompt_lengths.to(device=new_input_ids.device, dtype=torch.long) + """Build attention mask and labels from prompt/completion lengths.""" + prompt_token_lengths = prompt_token_lengths.to(device=new_input_ids.device, dtype=torch.long) + completion_lengths = completion_lengths.to(device=new_input_ids.device, dtype=torch.long) positions = torch.arange(new_input_ids.shape[1], device=new_input_ids.device).unsqueeze(0) - completion_mask = positions >= prompt_lengths.unsqueeze(1) - - new_attention_mask = torch.ones_like(new_input_ids) - if pad_token_id is not None: - new_attention_mask[new_input_ids == pad_token_id] = 0 + prompt_mask = (positions < prompt_width) & ( + positions >= (prompt_width - prompt_token_lengths).unsqueeze(1) + ) + completion_mask = (positions >= prompt_width) & ( + positions < (prompt_width + completion_lengths).unsqueeze(1) + ) + new_attention_mask = (prompt_mask | completion_mask).long() new_labels = torch.full_like(new_input_ids, -100) new_labels[completion_mask] = new_input_ids[completion_mask] - if pad_token_id is not None: - new_labels[new_input_ids == pad_token_id] = -100 return new_attention_mask, new_labels @@ -912,6 +1020,8 @@ def _get_teacher_token_logprobs_from_server(self, inputs: dict[str, 
torch.Tensor return teacher_logprobs def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + self._raise_if_local_teacher_tokenizer_mismatch() + if self.use_liger_loss: loss = self._compute_liger_loss(model, inputs) return (loss, None) if return_outputs else loss From f5fc947eb98df4457925a81fe3dd112a79bfaa44 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Fri, 27 Mar 2026 15:47:32 +0000 Subject: [PATCH 14/37] Implement efficient logprob generation --- trl/generation/vllm_client.py | 176 +++++++++++++-- trl/scripts/vllm_serve.py | 322 ++++++++++++++++++++++------ trl/trainer/distillation_trainer.py | 258 ++++++++++++++++------ 3 files changed, 610 insertions(+), 146 deletions(-) diff --git a/trl/generation/vllm_client.py b/trl/generation/vllm_client.py index 8f7680869bb..9f2a6d3b696 100644 --- a/trl/generation/vllm_client.py +++ b/trl/generation/vllm_client.py @@ -528,6 +528,9 @@ def get_sequence_logprobs( sequences: list[list[int]], prompt_lengths: list[int], top_logprobs: int = 100, + use_binary: bool = True, + chunk_size: int = 0, + max_concurrent_requests: int = 4, ) -> dict[str, list]: """ Computes teacher logprobs for existing token sequences without generating new tokens. @@ -536,6 +539,11 @@ def get_sequence_logprobs( completion region only. This is used for knowledge distillation where the teacher model evaluates existing sequences rather than generating new ones. + When `chunk_size > 0`, splits the batch into chunks and dispatches them concurrently via a thread pool, keeping + the server's data-parallel workers busy. + + When `use_binary=True`, uses base64-encoded numpy arrays for fast serialization instead of nested JSON lists. + Args: sequences (`list[list[int]]`): List of full token ID sequences (prompt + completion). @@ -543,6 +551,13 @@ def get_sequence_logprobs( Number of prompt tokens in each sequence. Logprobs are returned starting from this position. top_logprobs (`int`, *optional*, defaults to `100`): Number of top logprobs to return per token position. + use_binary (`bool`, *optional*, defaults to `True`): + Use binary (base64 numpy) response format for faster serialization. + chunk_size (`int`, *optional*, defaults to `0`): + If > 0, split batch into chunks of this size and dispatch concurrently. + If 0, send the entire batch in a single request. + max_concurrent_requests (`int`, *optional*, defaults to `4`): + Maximum number of concurrent requests when using chunked dispatch. Returns: `dict` with keys: @@ -552,23 +567,154 @@ def get_sequence_logprobs( - `logprob_token_ids` (`list[list[list[int]]]`): Token IDs corresponding to each logprob, same shape as `logprobs`. 
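+
+        Example (illustrative sketch, not part of the patch; assumes a TRL vLLM server is already
+        running, `client` is a connected `VLLMClient`, and the token IDs are made up):
+
+        ```python
+        result = client.get_sequence_logprobs(
+            sequences=[[1, 5, 9, 7, 2]],  # prompt + completion token IDs
+            prompt_lengths=[3],  # the first 3 tokens are the prompt
+            top_logprobs=8,
+            use_binary=True,
+            chunk_size=0,  # single request; set > 0 to fan chunks out concurrently
+        )
+        # result["logprobs"][0] is a (completion_len, top_logprobs) nested list, here 2 x 8
+        ```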
""" + import base64 + from concurrent.futures import ThreadPoolExecutor, as_completed + + import numpy as np + url = f"{self.base_url}/get_sequence_logprobs/" - response = self.session.post( - url, - json={ - "sequences": sequences, - "prompt_lengths": prompt_lengths, - "top_logprobs": top_logprobs, - }, - ) - if response.status_code == 200: - json_response = response.json() - return { - "logprobs": json_response["logprobs"], - "logprob_token_ids": json_response["logprob_token_ids"], - } + response_format = "binary" if use_binary else "json" + + if chunk_size > 0 and len(sequences) > chunk_size: + # Chunked concurrent dispatch + n = len(sequences) + chunks = [] + for i in range(0, n, chunk_size): + chunks.append((sequences[i : i + chunk_size], prompt_lengths[i : i + chunk_size])) + + responses = [None] * len(chunks) + + def _send_chunk(idx, seqs, plens): + resp = self.session.post( + url, + json={ + "sequences": seqs, + "prompt_lengths": plens, + "top_logprobs": top_logprobs, + "response_format": response_format, + }, + ) + if resp.status_code != 200: + raise Exception(f"Request failed: {resp.status_code}, {resp.text}") + return idx, resp.json() + + with ThreadPoolExecutor(max_workers=min(max_concurrent_requests, len(chunks))) as executor: + futures = { + executor.submit(_send_chunk, idx, seqs, plens): idx + for idx, (seqs, plens) in enumerate(chunks) + } + for future in as_completed(futures): + idx, result = future.result() + responses[idx] = result + + # Merge results + if use_binary: + return self._merge_binary_responses(responses, top_logprobs) + else: + all_logprobs = [] + all_token_ids = [] + for resp in responses: + all_logprobs.extend(resp["logprobs"]) + all_token_ids.extend(resp["logprob_token_ids"]) + return {"logprobs": all_logprobs, "logprob_token_ids": all_token_ids} else: - raise Exception(f"Request failed: {response.status_code}, {response.text}") + # Single request + response = self.session.post( + url, + json={ + "sequences": sequences, + "prompt_lengths": prompt_lengths, + "top_logprobs": top_logprobs, + "response_format": response_format, + }, + ) + if response.status_code != 200: + raise Exception(f"Request failed: {response.status_code}, {response.text}") + + json_response = response.json() + if use_binary: + return self._decode_binary_logprobs(json_response) + else: + return { + "logprobs": json_response["logprobs"], + "logprob_token_ids": json_response["logprob_token_ids"], + } + + @staticmethod + def _decode_binary_logprobs(response: dict) -> dict[str, list]: + """Decode base64-encoded numpy arrays back to nested lists. + + Returns a dict with: + ``logprobs`` / ``logprob_token_ids`` — teacher's sorted top-k logprobs and + token IDs (shape per sequence: ``(comp_len, top_k)``). Used for the + forward KL term. + ``actual_logprobs`` / ``actual_token_ids`` — teacher logprob for the actual + token at each position (shape per sequence: ``(comp_len, 1)``). Used + for the reverse KL term. 
+ """ + import base64 + + import numpy as np + + shape = response["shape"] # [batch, max_comp_len, top_k] + comp_lengths = response["completion_lengths"] + + logprobs_arr = np.frombuffer(base64.b64decode(response["logprobs_b64"]), dtype=np.float32).reshape(shape) + token_ids_arr = np.frombuffer(base64.b64decode(response["token_ids_b64"]), dtype=np.int32).reshape(shape) + + # Convert back to nested lists, trimming padding + all_logprobs = [] + all_token_ids = [] + for i, comp_len in enumerate(comp_lengths): + all_logprobs.append(logprobs_arr[i, :comp_len, :].tolist()) + all_token_ids.append(token_ids_arr[i, :comp_len, :].tolist()) + + result = {"logprobs": all_logprobs, "logprob_token_ids": all_token_ids} + + # Decode actual-token logprobs (for reverse KL) + if "actual_logprobs_b64" in response: + actual_shape = [shape[0], shape[1], 1] + actual_lp = np.frombuffer( + base64.b64decode(response["actual_logprobs_b64"]), dtype=np.float32 + ).reshape(actual_shape) + actual_ids = np.frombuffer( + base64.b64decode(response["actual_token_ids_b64"]), dtype=np.int32 + ).reshape(actual_shape) + all_actual_lps = [] + all_actual_ids = [] + for i, comp_len in enumerate(comp_lengths): + all_actual_lps.append(actual_lp[i, :comp_len, :].tolist()) + all_actual_ids.append(actual_ids[i, :comp_len, :].tolist()) + result["actual_logprobs"] = all_actual_lps + result["actual_token_ids"] = all_actual_ids + + return result + + @staticmethod + def _merge_binary_responses(responses: list[dict], top_logprobs: int) -> dict[str, list]: + """Merge binary responses from multiple chunks into a single result.""" + import base64 + + import numpy as np + + all_logprobs = [] + all_token_ids = [] + all_actual_lps = [] + all_actual_ids = [] + has_actual = False + for resp in responses: + decoded = VLLMClient._decode_binary_logprobs(resp) + all_logprobs.extend(decoded["logprobs"]) + all_token_ids.extend(decoded["logprob_token_ids"]) + if "actual_logprobs" in decoded: + has_actual = True + all_actual_lps.extend(decoded["actual_logprobs"]) + all_actual_ids.extend(decoded["actual_token_ids"]) + result = {"logprobs": all_logprobs, "logprob_token_ids": all_token_ids} + if has_actual: + result["actual_logprobs"] = all_actual_lps + result["actual_token_ids"] = all_actual_ids + return result def reset_prefix_cache(self): """ diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index 99013a2da49..9c819e4b5a9 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -441,6 +441,8 @@ def main(script_args: ScriptArguments): @asynccontextmanager async def lifespan(app: FastAPI): + import asyncio as _asyncio + # Wait for all workers to send "ready" ready_connections = set() while len(ready_connections) < script_args.data_parallel_size: @@ -449,8 +451,13 @@ async def lifespan(app: FastAPI): if isinstance(msg, dict) and msg.get("status") == "ready": ready_connections.add(connection) + # Start the logprob request batcher background task + batcher_task = _asyncio.create_task(_logprob_batcher()) + yield + batcher_task.cancel() + # Wait for processes to terminate for process in processes: process.join(timeout=10) # Wait for 10 seconds for the process to terminate @@ -646,82 +653,273 @@ class SequenceLogprobsRequest(BaseModel): sequences: list[list[int]] prompt_lengths: list[int] top_logprobs: int = 100 + response_format: str = "json" # "json" (legacy) or "binary" (base64 numpy arrays) class SequenceLogprobsResponse(BaseModel): - logprobs: list[list[list[float | None]]] - logprob_token_ids: list[list[list[int]]] + 
logprobs: list[list[list[float | None]]] | None = None + logprob_token_ids: list[list[list[int]]] | None = None + # Binary format fields (base64-encoded numpy arrays) + logprobs_b64: str | None = None + token_ids_b64: str | None = None + shape: list[int] | None = None # [batch_size, max_completion_len, top_logprobs] + completion_lengths: list[int] | None = None # actual completion length per sample + + def _run_prompt_logprobs(prompts, sampling_params): + """Send prompts to DP workers and collect outputs.""" + chunked_prompts = chunk_list(prompts, script_args.data_parallel_size) + for connection, chunk in zip(connections, chunked_prompts, strict=True): + if not chunk: + chunk = [{"prompt_token_ids": [0]}] + kwargs = {"prompts": chunk, "sampling_params": sampling_params} + connection.send({"type": "call", "method": "generate", "kwargs": kwargs}) + all_outputs = [connection.recv() for connection in connections] + all_outputs = [output for output, chunk in zip(all_outputs, chunked_prompts, strict=True) if chunk] + return list(chain.from_iterable(all_outputs)) + + # ── Request batching for get_sequence_logprobs ── + # Collects concurrent requests into batches and dispatches them together so that + # all DP workers stay busy. Without this, async endpoint handlers block the event + # loop during pipe I/O, serializing requests and leaving DP workers idle. + import asyncio + import threading + + _logprob_queue: asyncio.Queue = asyncio.Queue() + + # Maximum time (seconds) to wait for more requests before dispatching a batch. + _BATCH_WAIT_S = 0.005 # 5ms - short enough to not add much latency when lightly loaded + # Maximum number of sequences per batch (set to DP size so each worker gets sequences) + _MAX_BATCH_SEQS = max(script_args.data_parallel_size * 4, 16) + # Maximum total tokens per batch. prompt_logprobs materializes full-vocab logits + # during the forward pass, so each worker can safely handle ~1 max-length sequence. + # Budget = max_model_len * dp_size gives ~1 sequence per worker at max length. 
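+    # Worked example (illustrative numbers): with max_model_len=8192 and data_parallel_size=4,
+    # _MAX_BATCH_TOKENS = 8192 * 4 = 32768, i.e. roughly four max-length sequences per batch.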
+ _max_model_len = script_args.max_model_len or 8192 + _MAX_BATCH_TOKENS = _max_model_len * script_args.data_parallel_size + + async def _logprob_batcher(): + """Background task that continuously drains the queue, batches requests, and dispatches.""" + loop = asyncio.get_event_loop() + + while True: + # Wait for the first request + batch = [] + batch_tokens = 0 + item = await _logprob_queue.get() + batch.append(item) + # Count tokens in this item's sequences + for prompt in item[0]: + batch_tokens += len(prompt.get("prompt_token_ids", [])) + + # Collect more requests up to batch limit, timeout, or token budget + deadline = loop.time() + _BATCH_WAIT_S + while len(batch) < _MAX_BATCH_SEQS and batch_tokens < _MAX_BATCH_TOKENS: + remaining = deadline - loop.time() + if remaining <= 0: + break + try: + item = await asyncio.wait_for(_logprob_queue.get(), timeout=remaining) + # Check if adding this item would exceed the token budget + item_tokens = sum(len(p.get("prompt_token_ids", [])) for p in item[0]) + if batch_tokens + item_tokens > _MAX_BATCH_TOKENS and len(batch) > 0: + # Put it back and dispatch current batch + await _logprob_queue.put(item) + break + batch.append(item) + batch_tokens += item_tokens + except asyncio.TimeoutError: + break + + # batch is a list of (prompts, prompt_lengths, top_logprobs, response_format, future) + # All items in a batch must share the same top_logprobs (enforced at dispatch time) + # Group by top_logprobs to handle mixed requests + groups = {} + for prompts, prompt_lengths, top_logprobs, response_format, future in batch: + key = top_logprobs + if key not in groups: + groups[key] = [] + groups[key].append((prompts, prompt_lengths, response_format, future)) + + for top_logprobs, items in groups.items(): + # Merge all sequences into a single batch + all_prompts = [] + all_prompt_lengths = [] + offsets = [] # (start_idx, count) per original request + for prompts, prompt_lengths, response_format, future in items: + start = len(all_prompts) + all_prompts.extend(prompts) + all_prompt_lengths.extend(prompt_lengths) + offsets.append((start, len(prompts))) + + sampling_params = SamplingParams( + max_tokens=1, + temperature=1.0, + prompt_logprobs=top_logprobs, + ) + + # Dispatch to workers in a thread to avoid blocking the event loop + try: + all_outputs = await loop.run_in_executor( + None, _run_prompt_logprobs, all_prompts, sampling_params + ) + except Exception as e: + # Signal error to all waiting requests + for _, _, _, future in items: + if not future.done(): + future.set_exception(e) + continue + + # Split results back to individual requests + for (start, count), (_, prompt_lengths, response_format, future) in zip(offsets, items): + outputs_slice = all_outputs[start : start + count] + if not future.done(): + future.set_result((outputs_slice, prompt_lengths, top_logprobs, response_format)) + + def _format_logprob_response(all_outputs, prompt_lengths, top_k, response_format): + """Format vLLM outputs into the response dict (runs in any thread).""" + import numpy as np + + batch_size = len(all_outputs) + use_binary = response_format == "binary" + + if use_binary: + from starlette.responses import Response + + comp_lengths = [] + for output, prompt_length in zip(all_outputs, prompt_lengths, strict=True): + prompt_lps = output.prompt_logprobs + if prompt_lps is None: + raise ValueError("prompt_logprobs is None.") + comp_lengths.append(len(prompt_lps) - prompt_length) + + max_comp_len = max(comp_lengths) if comp_lengths else 0 + + # logprobs_arr / token_ids_arr: 
teacher's sorted top-k logprobs + token ids (for forward KL). + # actual_logprobs_arr / actual_token_ids_arr: actual token's teacher logprob (for reverse KL). + logprobs_arr = np.full((batch_size, max_comp_len, top_k), float("-inf"), dtype=np.float32) + token_ids_arr = np.zeros((batch_size, max_comp_len, top_k), dtype=np.int32) + actual_logprobs_arr = np.full((batch_size, max_comp_len, 1), float("-inf"), dtype=np.float32) + actual_token_ids_arr = np.zeros((batch_size, max_comp_len, 1), dtype=np.int32) + + for i, (output, prompt_length) in enumerate( + zip(all_outputs, prompt_lengths, strict=True) + ): + prompt_lps = output.prompt_logprobs + seq_tokens = output.prompt_token_ids + if comp_lengths[i] == 0: + continue + + for pos in range(prompt_length, len(prompt_lps)): + lp = prompt_lps[pos] + if lp is None: + continue + t = pos - prompt_length + actual_token = seq_tokens[pos] + + # Actual token's logprob (for reverse KL) + if actual_token in lp: + val = lp[actual_token].logprob + if not math.isnan(val): + actual_logprobs_arr[i, t, 0] = val + actual_token_ids_arr[i, t, 0] = actual_token + + # Teacher's top-k logprobs (for forward KL) + if top_k == 1: + # Fast path: find rank-1 directly instead of sorting + for token_id, logprob_obj in lp.items(): + if logprob_obj.rank == 1: + val = logprob_obj.logprob + if not math.isnan(val): + logprobs_arr[i, t, 0] = val + token_ids_arr[i, t, 0] = token_id + break + else: + sorted_items = sorted(lp.items(), key=lambda x: x[1].rank) + for k_idx, (token_id, logprob_obj) in enumerate(sorted_items[:top_k]): + val = logprob_obj.logprob + if not math.isnan(val): + logprobs_arr[i, t, k_idx] = val + token_ids_arr[i, t, k_idx] = token_id + + payload = { + "logprobs_b64": base64.b64encode(logprobs_arr.tobytes()).decode("ascii"), + "token_ids_b64": base64.b64encode(token_ids_arr.tobytes()).decode("ascii"), + "actual_logprobs_b64": base64.b64encode(actual_logprobs_arr.tobytes()).decode("ascii"), + "actual_token_ids_b64": base64.b64encode(actual_token_ids_arr.tobytes()).decode("ascii"), + "shape": [batch_size, max_comp_len, top_k], + "completion_lengths": comp_lengths, + } + + try: + import orjson - @app.post("/get_sequence_logprobs/", response_model=SequenceLogprobsResponse) + return Response(content=orjson.dumps(payload), media_type="application/json") + except ImportError: + return payload + else: + all_logprobs = [] + all_token_ids = [] + for output, prompt_length in zip(all_outputs, prompt_lengths, strict=True): + prompt_lps = output.prompt_logprobs + if prompt_lps is None: + raise ValueError("prompt_logprobs is None.") + seq_logprobs = [] + seq_token_ids = [] + for pos in range(prompt_length, len(prompt_lps)): + lp = prompt_lps[pos] + if lp is None: + seq_logprobs.append([]) + seq_token_ids.append([]) + continue + sorted_items = sorted(lp.items(), key=lambda x: x[1].rank) + seq_token_ids.append([token_id for token_id, _ in sorted_items]) + seq_logprobs.append( + [None if math.isnan(item.logprob) else item.logprob for _, item in sorted_items] + ) + all_logprobs.append(seq_logprobs) + all_token_ids.append(seq_token_ids) + return {"logprobs": all_logprobs, "logprob_token_ids": all_token_ids} + + @app.post("/get_sequence_logprobs/") async def get_sequence_logprobs(request: SequenceLogprobsRequest): """ Computes teacher logprobs for existing token sequences without generating new tokens. - Sends the full sequence (prompt + completion) as the vLLM prompt with `max_tokens=1` and - `prompt_logprobs=top_logprobs`. 
Returns logprobs only for the completion region (positions - from `prompt_length` onwards) for each sequence. - - Args: - request (`SequenceLogprobsRequest`): - - `sequences` (list of list of `int`): Full token ID sequences (prompt + completion). - - `prompt_lengths` (list of `int`): Number of prompt tokens in each sequence. Logprobs - are returned starting from this position. - - `top_logprobs` (`int`, *optional*, defaults to `100`): Number of top logprobs per position. + Concurrent requests are automatically batched and dispatched together to maximize + GPU utilization across DP workers. This avoids the event-loop-blocking problem where + synchronous pipe I/O serializes requests despite having multiple DP workers. - Returns: - `SequenceLogprobsResponse`: - - `logprobs` (list of list of list of `float`): Per-token logprobs of shape - (batch, completion_len, top_logprobs), sorted by descending probability. - - `logprob_token_ids` (list of list of list of `int`): Token IDs corresponding to each - logprob, same shape as `logprobs`. + Supports two response formats: + - `"json"` (default): Nested lists, backward-compatible with existing clients. + - `"binary"`: Base64-encoded numpy arrays for fast serialization/deserialization. """ if len(request.sequences) != len(request.prompt_lengths): raise ValueError("sequences and prompt_lengths must have the same length.") - prompts = [{"prompt_token_ids": seq} for seq in request.sequences] - sampling_params = SamplingParams( - max_tokens=1, - temperature=1.0, - prompt_logprobs=request.top_logprobs, - ) - - chunked_prompts = chunk_list(prompts, script_args.data_parallel_size) + # Validate sequence lengths against max_model_len to prevent worker OOM crashes + if _max_model_len: + for i, seq in enumerate(request.sequences): + if len(seq) > _max_model_len: + raise ValueError( + f"Sequence {i} has length {len(seq)} which exceeds max_model_len={_max_model_len}. " + f"Truncate sequences or increase --max-model-len." + ) - for connection, chunk in zip(connections, chunked_prompts, strict=True): - if not chunk: - chunk = [{"prompt_token_ids": [0]}] - kwargs = {"prompts": chunk, "sampling_params": sampling_params} - connection.send({"type": "call", "method": "generate", "kwargs": kwargs}) - - all_outputs = [connection.recv() for connection in connections] - all_outputs = [output for output, chunk in zip(all_outputs, chunked_prompts, strict=True) if chunk] - all_outputs = list(chain.from_iterable(all_outputs)) - - all_logprobs = [] - all_token_ids = [] - for output, prompt_length in zip(all_outputs, request.prompt_lengths, strict=True): - # prompt_logprobs is a list of dicts, one per prompt token (first token is None) - prompt_lps = output.prompt_logprobs - if prompt_lps is None: - raise ValueError("prompt_logprobs is None. 
Ensure the vLLM server supports prompt_logprobs.") - - seq_logprobs = [] - seq_token_ids = [] - # Extract logprobs only for the completion region - for pos in range(prompt_length, len(prompt_lps)): - lp = prompt_lps[pos] - if lp is None: - seq_logprobs.append([]) - seq_token_ids.append([]) - continue - sorted_items = sorted(lp.items(), key=lambda x: x[1].rank) - seq_token_ids.append([token_id for token_id, _ in sorted_items]) - seq_logprobs.append( - [None if math.isnan(item.logprob) else item.logprob for _, item in sorted_items] - ) - all_logprobs.append(seq_logprobs) - all_token_ids.append(seq_token_ids) + prompts = [{"prompt_token_ids": seq} for seq in request.sequences] - return {"logprobs": all_logprobs, "logprob_token_ids": all_token_ids} + # Submit to the batching queue and await result + loop = asyncio.get_event_loop() + future = loop.create_future() + await _logprob_queue.put(( + prompts, + list(request.prompt_lengths), + request.top_logprobs, + request.response_format, + future, + )) + + # Wait for the batcher to process our request + all_outputs, prompt_lengths, top_k, response_format = await future + + return _format_logprob_response(all_outputs, prompt_lengths, top_k, response_format) class ChatRequest(BaseModel): messages: list[list[dict]] diff --git a/trl/trainer/distillation_trainer.py b/trl/trainer/distillation_trainer.py index bd0b7e8ae78..a579f5ef1bf 100644 --- a/trl/trainer/distillation_trainer.py +++ b/trl/trainer/distillation_trainer.py @@ -15,7 +15,7 @@ import random import textwrap import warnings -from collections import defaultdict, deque +from collections import defaultdict from collections.abc import Callable from contextlib import nullcontext from functools import partial @@ -217,6 +217,35 @@ def __call__(self, examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]: } +class _RepeatBatchDataLoader: + """Repeats each collated batch ``repeat_count`` times without re-collation. + + ``RepeatSampler`` with ``repeat_count > 1`` causes the DataLoader to re-collate + (re-tokenize) the same examples on every repeat, which is wasteful. This wrapper + instead keeps ``repeat_count=1`` in the sampler and repeats the already-collated + tensor dict, avoiding redundant tokenization. + """ + + def __init__(self, dataloader, repeat_count: int): + self.dataloader = dataloader + self.repeat_count = repeat_count + + def __iter__(self): + for batch in self.dataloader: + for _ in range(self.repeat_count): + yield batch + + def __len__(self): + return len(self.dataloader) * self.repeat_count + + def set_epoch(self, epoch: int): + if hasattr(self.dataloader, "set_epoch"): + self.dataloader.set_epoch(epoch) + + def __getattr__(self, attr): + return getattr(self.dataloader, attr) + + class DistillationTrainer(_BaseTrainer): """ Trainer for knowledge distillation from a teacher model to a student model. 
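A quick behavioral sketch of the `_RepeatBatchDataLoader` wrapper introduced in the hunk above (illustrative only: the toy dataset is made up, and the class is assumed to be in scope):

```python
from torch.utils.data import DataLoader

base = DataLoader([1, 2, 3, 4], batch_size=2)  # collates once: [1, 2], then [3, 4]
loader = _RepeatBatchDataLoader(base, repeat_count=2)

assert [batch.tolist() for batch in loader] == [[1, 2], [1, 2], [3, 4], [3, 4]]
assert len(loader) == 4  # len(base) * repeat_count
```

Each already-collated batch is re-yielded `repeat_count` times, so the collator (and any tokenization it performs) runs only once per unique batch.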
@@ -456,10 +485,9 @@ def __init__( self.log_completions_steps = args.log_completions_steps self.num_completions_to_print = args.num_completions_to_print - maxlen = self.accelerator.num_processes * args.per_device_train_batch_size * args.gradient_accumulation_steps self._textual_logs = { - "prompt": deque(maxlen=maxlen), - "completion": deque(maxlen=maxlen), + "prompt": [], + "completion": [], } # ── vLLM for student generation ── @@ -566,7 +594,7 @@ def _get_train_sampler(self, dataset=None): data_source=dataset, mini_repeat_count=self.num_generations, batch_size=self.args.generation_batch_size * self.accelerator.num_processes, - repeat_count=self.args.gradient_accumulation_steps, + repeat_count=1, shuffle=True, seed=self.args.seed, ) @@ -604,7 +632,8 @@ def get_train_dataloader(self): if self.args.dataloader_num_workers > 0: dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor - return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + base_dataloader = self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + return _RepeatBatchDataLoader(base_dataloader, repeat_count=self.args.gradient_accumulation_steps) # ────────────────────────────────────────────────────────────────────── # Buffering: on/off-policy mixing across gradient accumulation steps @@ -652,6 +681,18 @@ def _fill_buffer(self, generation_batch: dict[str, torch.Tensor | Any], buffer_s if on_policy_indices: self._generate_student_completions(slices, on_policy_indices) + # Gather on-policy text logs once per optimizer step (all processes must participate) + if self.log_completions: + on_policy_prompts = [] + on_policy_completions = [] + for i in on_policy_indices: + if self._buffered_text_logs[i] is not None: + prompts, completions = self._buffered_text_logs[i] + on_policy_prompts.extend(prompts) + on_policy_completions.extend(completions) + self._textual_logs["prompt"].extend(gather_object(on_policy_prompts)) + self._textual_logs["completion"].extend(gather_object(on_policy_completions)) + @profiling_decorator def _generate_student_completions( self, slices: list[dict[str, torch.Tensor | Any]], on_policy_indices: list[int] @@ -978,12 +1019,20 @@ def _get_teacher_logits(self, inputs: dict[str, torch.Tensor | Any]) -> torch.Te else: raise ValueError("No teacher model or teacher server configured.") - def _get_teacher_token_logprobs_from_server(self, inputs: dict[str, torch.Tensor | Any]) -> torch.Tensor: + def _get_teacher_token_logprobs_from_server( + self, inputs: dict[str, torch.Tensor | Any] + ) -> dict[str, torch.Tensor]: """Fetch per-token teacher logprobs from an external vLLM server. - Returns a tensor of shape (batch_size, completion_length) containing the teacher's log-probability - for the token present at each completion position in the input sequence. + Returns a dict with: + ``actual_logprobs`` – (batch, completion_length) teacher log-prob for the actual + token at each position (for reverse KL). + ``topk_logprobs`` – (batch, completion_length, K) teacher top-k sorted logprobs + (for forward KL). + ``topk_token_ids`` – (batch, completion_length, K) corresponding token IDs. """ + import numpy as np + input_ids = inputs["input_ids"] attention_mask = inputs["attention_mask"] batch_size = input_ids.shape[0] @@ -998,26 +1047,90 @@ def _get_teacher_token_logprobs_from_server(self, inputs: dict[str, torch.Tensor sequences.append(seq) prompt_lengths.append(prompt_length) - # Request top-1 logprobs from the teacher server. 
The server returns the logprob of the token at each - # position (i.e., the token in the sequence we sent), which is exactly what we need for all divergences. + K = self.teacher_server_top_logprobs result = self.teacher_client.get_sequence_logprobs( sequences=sequences, prompt_lengths=prompt_lengths, - top_logprobs=1, + top_logprobs=K, ) - # Build a (batch_size, completion_length) tensor of teacher logprobs for the sequence tokens - completion_length = max(len(lps) for lps in result["logprobs"]) device = input_ids.device - teacher_logprobs = torch.full( - (batch_size, completion_length), float("-inf"), dtype=torch.float32, device=device - ) - for i, seq_lps in enumerate(result["logprobs"]): - for pos, lps in enumerate(seq_lps): - if lps and lps[0] is not None: - teacher_logprobs[i, pos] = lps[0] + completion_length = max(len(lps) for lps in result["logprobs"]) + + # actual_logprobs: (B, T) — teacher logprob for the actual token + def _actual_to_tensor(key): + arr = np.full((batch_size, completion_length), float("-inf"), dtype=np.float32) + for i, seq_lps in enumerate(result[key]): + if seq_lps: + vals = np.array(seq_lps, dtype=np.float32) # (comp_len_i, 1) + arr[i, : vals.shape[0]] = vals[:, 0] + return torch.from_numpy(arr).to(device) + + # topk: (B, T, K) + def _topk_to_tensor(key, k, np_dtype, fill): + arr = np.full((batch_size, completion_length, k), fill, dtype=np_dtype) + for i, seq_vals in enumerate(result[key]): + if seq_vals: + vals = np.array(seq_vals, dtype=np_dtype) # (comp_len_i, k) + arr[i, : vals.shape[0], :] = vals + return torch.from_numpy(arr).to(device) + + return { + "actual_logprobs": _actual_to_tensor("actual_logprobs"), + "topk_logprobs": _topk_to_tensor("logprobs", K, np.float32, float("-inf")), + "topk_token_ids": _topk_to_tensor("logprob_token_ids", K, np.int64, 0), + } - return teacher_logprobs + def _compute_server_divergence_loss( + self, + teacher_result: dict[str, torch.Tensor], + student_log_probs: torch.Tensor, + completion_tokens: torch.Tensor, + labels: torch.Tensor, + ) -> torch.Tensor: + """Compute forward/reverse KL or JSD using teacher logprobs from the server. + + The forward KL term sums over the teacher's top-k tokens (better approximation + as k increases). The reverse KL term always uses the actual (student-sampled) + token only, because the teacher's top-k may not cover the student's high-probability + tokens. + + Args: + teacher_result: dict with ``actual_logprobs`` (B, T), ``topk_logprobs`` (B, T, K), + ``topk_token_ids`` (B, T, K). + student_log_probs: (B, T, V) student log-softmax over vocabulary. + completion_tokens: (B, T) actual token IDs in the completion. + labels: (B, T) with -100 for positions to ignore. + """ + topk_teacher_lps = teacher_result["topk_logprobs"] # (B, T, K) + topk_token_ids = teacher_result["topk_token_ids"] # (B, T, K) + actual_teacher_lps = teacher_result["actual_logprobs"] # (B, T) + + # ── Forward KL term: sum over teacher's top-k tokens ── + # Gather student logprobs for each of the teacher's top-k tokens. + student_topk_lps = student_log_probs.gather(dim=-1, index=topk_token_ids) # (B, T, K) + + # Mask out -inf padding (positions where top-k slot was not filled). 
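+ # Padded top-k slots (filled with -inf) only occur past a sequence's own completion length; + # those positions carry label -100, so the final labels mask drops them from the loss.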
+ valid = topk_teacher_lps > float("-inf") + fwd_per_k = torch.exp(topk_teacher_lps) * (topk_teacher_lps - student_topk_lps) # (B, T, K) + fwd_per_token = (fwd_per_k * valid).sum(dim=-1) # (B, T) + + # ── Reverse KL term: actual token only ── + student_actual_lps = student_log_probs.gather( + dim=-1, index=completion_tokens.unsqueeze(-1) + ).squeeze(-1) # (B, T) + rev_per_token = torch.exp(student_actual_lps) * (student_actual_lps - actual_teacher_lps) # (B, T) + + # ── Combine according to beta ── + if self.beta == 0: + loss_per_token = fwd_per_token + elif self.beta == 1: + loss_per_token = rev_per_token + else: + loss_per_token = self.beta * fwd_per_token + (1 - self.beta) * rev_per_token + + mask = labels != -100 + return loss_per_token[mask].sum() / mask.sum() def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): self._raise_if_local_teacher_tokenizer_mismatch() @@ -1035,25 +1148,26 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N labels = inputs["labels"][:, prompt_length:] if self.teacher_client is not None: - # Server path: token-level divergence using top-1 logprobs - teacher_token_logprobs = self._get_teacher_token_logprobs_from_server(inputs) + # Server path: token-level divergence using teacher logprobs. + # The server returns: + # actual_logprobs – (B, T) teacher log p(x_actual) (for reverse KL) + # topk_logprobs – (B, T, K) teacher top-k sorted logprobs (for forward KL) + # topk_token_ids – (B, T, K) corresponding token IDs + teacher_result = self._get_teacher_token_logprobs_from_server(inputs) - # Extract student logprobs for the same tokens in the completion region student_logits = student_outputs.logits[:, prompt_length - 1 : -1, :] student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1) completion_tokens = inputs["input_ids"][:, prompt_length:] - # Trim to match completion length from server - comp_len = teacher_token_logprobs.shape[1] + + comp_len = teacher_result["actual_logprobs"].shape[1] completion_tokens = completion_tokens[:, :comp_len] - student_token_logprobs = student_log_probs[:, :comp_len, :].gather( - dim=-1, index=completion_tokens.unsqueeze(-1) - ).squeeze(-1) - - loss = self.token_level_divergence_loss( - student_logprobs=student_token_logprobs, - teacher_logprobs=teacher_token_logprobs, - labels=labels[:, :comp_len], - beta=self.beta, + trimmed_labels = labels[:, :comp_len] + + loss = self._compute_server_divergence_loss( + teacher_result=teacher_result, + student_log_probs=student_log_probs[:, :comp_len, :], + completion_tokens=completion_tokens, + labels=trimmed_labels, ) else: # Local teacher: full-vocabulary generalized JSD @@ -1169,18 +1283,14 @@ def training_step( slice_idx = (self._buffer_step - 1) % buffer_steps - # Track on-policy text logs + # Determine if this slice is on-policy is_on_policy = False if self._buffered_on_policy_flags is not None and slice_idx < len(self._buffered_on_policy_flags): is_on_policy = self._buffered_on_policy_flags[slice_idx] - if is_on_policy and self._buffered_text_logs is not None and self._buffered_text_logs[slice_idx] is not None: - prompt_texts, completion_texts = self._buffered_text_logs[slice_idx] - self._textual_logs["prompt"].extend(gather_object(prompt_texts)) - self._textual_logs["completion"].extend(gather_object(completion_texts)) - - # Track completion length stats - labels = inputs.get("labels") + # Track completion length stats — read from buffered inputs (which reflect on-policy generation) + actual_inputs = 
self._buffered_inputs[slice_idx] if self._buffered_inputs is not None else inputs + labels = actual_inputs.get("labels") if labels is not None: completion_lengths = (labels != -100).sum(dim=1).float() gathered_lengths = self.accelerator.gather(completion_lengths) @@ -1244,32 +1354,42 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: self._metrics[mode].clear() # Log completions to console and wandb - if ( - self.accelerator.is_main_process - and self.log_completions + should_log_completions = ( + self.log_completions + and self.state.global_step > 0 and self.state.global_step % self.log_completions_steps == 0 - ): + ) + + if should_log_completions and self.accelerator.is_main_process: prompts = list(self._textual_logs["prompt"]) completions = list(self._textual_logs["completion"]) - _print_completions_sample(prompts, completions, self.state.global_step, self.num_completions_to_print) - - # Log as a wandb Table - if self.args.report_to and "wandb" in self.args.report_to: - try: - import wandb - - if wandb.run is not None: - import pandas as pd - - table_data = { - "step": [str(self.state.global_step)] * len(prompts), - "prompt": prompts, - "completion": completions, - } - df = pd.DataFrame(table_data) - if self.num_completions_to_print and len(df) > self.num_completions_to_print: - df = df.sample(n=self.num_completions_to_print, random_state=42) - wandb.log({"completions": wandb.Table(dataframe=df)}) - except ImportError: - pass + if prompts: + _print_completions_sample( + prompts, completions, self.state.global_step, self.num_completions_to_print + ) + + # Log as a wandb Table + if self.args.report_to and "wandb" in self.args.report_to: + try: + import wandb + + if wandb.run is not None: + import pandas as pd + + table_data = { + "step": [str(self.state.global_step)] * len(prompts), + "prompt": prompts, + "completion": completions, + } + df = pd.DataFrame(table_data) + if self.num_completions_to_print and len(df) > self.num_completions_to_print: + df = df.sample(n=self.num_completions_to_print, random_state=42) + wandb.log({"completions": wandb.Table(dataframe=df)}) + except ImportError: + pass + + # Clear text logs on all processes after the logging interval + if should_log_completions: + self._textual_logs["prompt"].clear() + self._textual_logs["completion"].clear() From 42f3d4d40c1d7060f199f481bfcfd1c1d87fc40a Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 30 Mar 2026 09:34:29 +0000 Subject: [PATCH 15/37] Fix top-k implementation --- trl/scripts/distillation.py | 10 +++ trl/trainer/distillation_config.py | 56 +++++++++--- trl/trainer/distillation_trainer.py | 135 ++++++++++++++++++++++++---- 3 files changed, 171 insertions(+), 30 deletions(-) diff --git a/trl/scripts/distillation.py b/trl/scripts/distillation.py index 69321da6291..2ba82e318f1 100644 --- a/trl/scripts/distillation.py +++ b/trl/scripts/distillation.py @@ -94,6 +94,16 @@ def main(script_args, training_args, model_args): logger = logging.get_logger(__name__) + ################ + # W&B env vars + ################ + if training_args.wandb_project is not None: + os.environ["WANDB_PROJECT"] = training_args.wandb_project + if training_args.wandb_entity is not None: + os.environ["WANDB_ENTITY"] = training_args.wandb_entity + if training_args.wandb_run_group is not None: + os.environ["WANDB_RUN_GROUP"] = training_args.wandb_run_group + ################ # Model init kwargs ################ diff --git a/trl/trainer/distillation_config.py b/trl/trainer/distillation_config.py index 
b2d70309830..70184e4dda8 100644 --- a/trl/trainer/distillation_config.py +++ b/trl/trainer/distillation_config.py @@ -72,10 +72,10 @@ class DistillationConfig(_BaseConfig): Base URL of a vLLM server hosting the teacher model (e.g., `"http://localhost:8000"`). When set, teacher logprobs are fetched from the server instead of running a local forward pass. Mutually exclusive with passing a `teacher_model` object to the trainer. - teacher_server_top_logprobs (`int`, *optional*, defaults to `1`): - Number of top logprobs to request from the teacher server per token position. Only used when - `teacher_model_server_url` is set. Currently only `1` is supported — the server path uses a per-token - logprob approximation of the divergence. Full-vocabulary divergence is only available with a local teacher. + loss_top_k (`int`, *optional*, defaults to `0`): + Number of top tokens to use when computing the JSD/KL loss. Both student and teacher distributions are + restricted to these K tokens and re-normalized before computing divergence. If 0, the full vocabulary + is used. When using `teacher_model_server_url` with `beta > 0`, only `loss_top_k=1` is supported. > Parameters that control on-policy generation @@ -120,6 +120,15 @@ class DistillationConfig(_BaseConfig): vllm_enable_sleep_mode (`bool`, *optional*, defaults to `False`): Enable vLLM sleep mode to offload student weights during the optimizer step. + > Parameters that control W&B logging + + wandb_entity (`str` or `None`, *optional*): + The W&B entity to store runs under. + wandb_project (`str` or `None`, *optional*): + The W&B project to store runs under. + wandb_run_group (`str` or `None`, *optional*): + The W&B group to store runs under. + > Parameters that control logging log_completions (`bool`, *optional*, defaults to `False`): @@ -215,10 +224,16 @@ class DistillationConfig(_BaseConfig): "When set, teacher logprobs are fetched from the server." }, ) - teacher_server_top_logprobs: int = field( - default=1, - metadata={"help": "Number of top logprobs to request from the teacher server per token position. " - "Currently only `1` is supported for the server path."}, + loss_top_k: int = field( + default=0, + metadata={ + "help": "Number of top tokens to use when computing the JSD/KL loss. " + "Both student and teacher distributions are restricted to these K tokens " + "(selected based on beta: teacher's top-k for forward KL, student's top-k for reverse KL, " + "union of both for JSD) and re-normalized before computing divergence. " + "If 0, the full vocabulary is used (slower but exact). " + "Only supported with a local teacher — not with teacher_model_server_url." + }, ) # On-policy generation @@ -300,6 +315,20 @@ class DistillationConfig(_BaseConfig): metadata={"help": "Enable vLLM sleep mode to offload student weights during the optimizer step."}, ) + # W&B + wandb_entity: str | None = field( + default=None, + metadata={"help": "The W&B entity to store runs under."}, + ) + wandb_project: str | None = field( + default=None, + metadata={"help": "The W&B project to store runs under."}, + ) + wandb_run_group: str | None = field( + default=None, + metadata={"help": "The W&B group to store runs under."}, + ) + # Logging log_completions: bool = field( default=False, @@ -346,12 +375,13 @@ def __post_init__(self): f"{self.per_device_train_batch_size} * {self.gradient_accumulation_steps}." 
) - if self.teacher_model_server_url is not None and self.teacher_server_top_logprobs != 1: + if self.teacher_model_server_url is not None and self.beta > 0 and self.loss_top_k != 1: raise ValueError( - f"When using a teacher server (`teacher_model_server_url`), only `teacher_server_top_logprobs=1` is " - f"supported (got {self.teacher_server_top_logprobs}). The server computes a per-token logprob " - f"approximation of the divergence loss. Full-vocabulary divergence computation is only supported with " - f"a local teacher model." + f"loss_top_k != 1 with beta > 0 is not supported with teacher_model_server_url " + f"(got loss_top_k={self.loss_top_k}, beta={self.beta}). The server path always uses top-1 " + f"logprobs: any reverse KL component (beta > 0) requires the teacher's logprobs at the " + f"student's top-k tokens, which the server cannot provide for k > 1 or full vocabulary (k=0). " + f"Use a local teacher, set loss_top_k=1, or set beta=0 (pure forward KL)." ) if self.num_generations > 1 and self.lmbda < 1.0: diff --git a/trl/trainer/distillation_trainer.py b/trl/trainer/distillation_trainer.py index a579f5ef1bf..a05b4f1b7e9 100644 --- a/trl/trainer/distillation_trainer.py +++ b/trl/trainer/distillation_trainer.py @@ -102,6 +102,63 @@ def _print_completions_sample(prompts: list[str], completions: list[str], step: console.print(panel) +def _add_tail_bucket(log_probs, valid_mask): + """Append a (K+1)-th tail element: log(1 - sum(exp(top_k_logps))). + + This creates a proper probability distribution over K+1 elements, preventing + trivial zero loss when top_k is small (especially top_k=1). + """ + log_sum = torch.logsumexp(log_probs, dim=-1, keepdim=True) + log_sum = torch.clamp(log_sum, max=-1e-7) # ensure sum < 1 + tail = torch.log(-torch.expm1(log_sum)) # log(1 - exp(log_sum)) + tail_mask = torch.ones_like(valid_mask[..., :1], dtype=torch.bool) + return torch.cat([log_probs, tail], dim=-1), torch.cat([valid_mask, tail_mask], dim=-1) + + +def _jsd_divergence(student_log_probs, teacher_log_probs, beta, support_mask=None): + """Compute JSD (or forward/reverse KL) from log-probability tensors. + + When *support_mask* is not None, uses manual computation with masked + positions zeroed. When None, uses ``F.kl_div``. 
+ """ + if support_mask is not None: + safe_student = torch.where(support_mask, student_log_probs, torch.zeros_like(student_log_probs)) + safe_teacher = torch.where(support_mask, teacher_log_probs, torch.zeros_like(teacher_log_probs)) + student_probs = torch.where(support_mask, student_log_probs.exp(), torch.zeros_like(student_log_probs)) + teacher_probs = torch.where(support_mask, teacher_log_probs.exp(), torch.zeros_like(teacher_log_probs)) + + if beta == 0: + return teacher_probs * (safe_teacher - safe_student) + elif beta == 1: + return student_probs * (safe_student - safe_teacher) + else: + beta_t = torch.tensor(beta, dtype=student_log_probs.dtype, device=student_log_probs.device) + tiny = torch.finfo(student_probs.dtype).tiny + mixture_probs = (1 - beta_t) * student_probs + beta_t * teacher_probs + safe_mixture = torch.where( + support_mask, + torch.log(mixture_probs.clamp_min(tiny)), + torch.zeros_like(student_log_probs), + ) + kl_teacher = teacher_probs * (safe_teacher - safe_mixture) + kl_student = student_probs * (safe_student - safe_mixture) + return beta_t * kl_teacher + (1 - beta_t) * kl_student + else: + if beta == 0: + return F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True) + elif beta == 1: + return F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True) + else: + beta_t = torch.tensor(beta, dtype=student_log_probs.dtype, device=student_log_probs.device) + mixture_log_probs = torch.logsumexp( + torch.stack([student_log_probs + torch.log1p(-beta_t), teacher_log_probs + torch.log(beta_t)]), + dim=0, + ) + kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True) + kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True) + return beta_t * kl_teacher + (1 - beta_t) * kl_student + + class _DistillationCollator: """Data collator for the distillation trainer with independent prompt/completion budgets. @@ -447,7 +504,7 @@ def __init__( self.temperature = args.temperature self.top_p = args.top_p self.num_generations = args.num_generations - self.teacher_server_top_logprobs = args.teacher_server_top_logprobs + self.loss_top_k = args.loss_top_k # ── Buffer state ── self._buffered_inputs = None @@ -914,6 +971,7 @@ def generalized_jsd_loss( beta=0.5, temperature=1.0, reduction="batchmean", + top_k=0, ): """ Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation. @@ -925,28 +983,62 @@ def generalized_jsd_loss( beta: Interpolation coefficient. 0.0 = forward KL, 1.0 = reverse KL. temperature: Softmax temperature. reduction: 'batchmean', 'sum', 'mean', or 'none'. + top_k: Number of top tokens to restrict the loss to. 
The support set depends on beta: + beta=0 (forward KL) uses teacher's top-k, beta=1 (reverse KL) uses student's top-k, + 0<beta<1 (JSD) uses the union of both top-k supports. If 0, the full vocabulary is used. """ + support_mask = None + if top_k > 0 and student_logits.size(-1) > top_k: + neg_inf = torch.full((), float("-inf"), dtype=student_logits.dtype, device=student_logits.device) + student_log_probs_full = F.log_softmax(student_logits, dim=-1) + teacher_log_probs_full = F.log_softmax(teacher_logits, dim=-1) + + if beta == 0: + # Forward KL: teacher-selected support + _, support = teacher_logits.topk(top_k, dim=-1) + support_mask = torch.ones_like(support, dtype=torch.bool) + elif beta == 1: + # Reverse KL: student-selected support + _, support = student_logits.topk(top_k, dim=-1) + support_mask = torch.ones_like(support, dtype=torch.bool) + else: + # JSD: union of both supports (concatenate + deduplicate) + _, student_top = student_logits.topk(top_k, dim=-1) + _, teacher_top = teacher_logits.topk(top_k, dim=-1) + support = torch.cat([teacher_top, student_top], dim=-1) + support_mask = torch.ones(support.shape, dtype=torch.bool, device=support.device) + for i in range(1, support.shape[-1]): + prev_matches = support[..., i:i + 1] == support[..., :i] + prev_valid = support_mask[..., :i] + support_mask[..., i] &= ~(prev_matches & prev_valid).any(dim=-1) + support = torch.where(support_mask, support, torch.zeros_like(support)) + + student_support_logps = student_log_probs_full.gather(-1, support) + teacher_support_logps = teacher_log_probs_full.gather(-1, support) + + # Mask invalid (duplicate) positions with -inf + student_topk_logps = torch.where(support_mask, student_support_logps, neg_inf) + teacher_topk_logps = torch.where(support_mask, teacher_support_logps, neg_inf) + + # Add tail bucket: append log(1 - sum(exp(top_k_logps))) to preserve + # the remaining probability mass outside the top-k. This prevents trivial + # zero loss when top_k is small (especially top_k=1). + base_support_mask = support_mask + student_log_probs, support_mask = _add_tail_bucket(student_topk_logps, base_support_mask) + teacher_log_probs, _ = _add_tail_bucket(teacher_topk_logps, base_support_mask) else: - beta_t = torch.tensor(beta, dtype=student_log_probs.dtype, device=student_log_probs.device) - mixture_log_probs = torch.logsumexp( - torch.stack([student_log_probs + torch.log1p(-beta_t), teacher_log_probs + torch.log(beta_t)]), - dim=0, - ) - kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True) - kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True) - jsd = beta_t * kl_teacher + (1 - beta_t) * kl_student + student_log_probs = F.log_softmax(student_logits, dim=-1) + teacher_log_probs = F.log_softmax(teacher_logits, dim=-1) + + jsd = _jsd_divergence(student_log_probs, teacher_log_probs, beta, support_mask) if labels is not None: mask = labels != -100 @@ -1047,12 +1139,14 @@ def _get_teacher_token_logprobs_from_server( sequences.append(seq) prompt_lengths.append(prompt_length) - K = self.teacher_server_top_logprobs + # Server path always uses top-1 logprobs. Top-k loss requires a local teacher + # since reverse KL needs the teacher's logprobs at the student's top-k tokens.
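+ # The reverse term reads only actual_logprobs (the teacher's logprob of the token + # already present in the sequence), which the server returns alongside its top-1 entry.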
result = self.teacher_client.get_sequence_logprobs( sequences=sequences, prompt_lengths=prompt_lengths, - top_logprobs=K, + top_logprobs=1, ) + K = 1 device = input_ids.device completion_length = max(len(lps) for lps in result["logprobs"]) @@ -1181,6 +1275,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N labels=labels, beta=self.beta, temperature=self.temperature, + top_k=self.loss_top_k, ) return (loss, student_outputs) if return_outputs else loss @@ -1300,6 +1395,12 @@ def training_step( self._metrics[mode][f"completions/{prefix}_max_length"].append(gathered_lengths.max().item()) self._metrics[mode][f"completions/{prefix}_min_length"].append(gathered_lengths.min().item()) + # Log fraction of completions that hit max_completion_length (truncated) + max_comp_len = getattr(self.generation_config, "max_new_tokens", None) + if is_on_policy and max_comp_len is not None: + truncated_frac = (gathered_lengths >= max_comp_len).float().mean().item() + self._metrics[mode]["completions/truncated_fraction"].append(truncated_frac) + # Track loss per policy type loss_scalar = float(loss.detach()) step_equiv = 1.0 / self.args.gradient_accumulation_steps From d13bc4a48cd02dae92c0c1bd2ecc6399025b61fa Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 30 Mar 2026 14:49:44 +0000 Subject: [PATCH 16/37] Migrate trainer to experimental --- trl/__init__.py | 4 ---- trl/experimental/distillation/__init__.py | 19 +++++++++++++++++++ .../distillation}/distillation.py | 11 ++++++----- .../distillation}/distillation_config.py | 2 +- .../distillation}/distillation_trainer.py | 16 ++++++++-------- trl/trainer/__init__.py | 4 ---- 6 files changed, 34 insertions(+), 22 deletions(-) create mode 100644 trl/experimental/distillation/__init__.py rename trl/{scripts => experimental/distillation}/distillation.py (94%) rename trl/{trainer => experimental/distillation}/distillation_config.py (99%) rename trl/{trainer => experimental/distillation}/distillation_trainer.py (99%) diff --git a/trl/__init__.py b/trl/__init__.py index ea20438dc89..897b4fef917 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -45,8 +45,6 @@ "scripts": ["DatasetMixtureConfig", "ScriptArguments", "TrlParser", "get_dataset", "init_zero_verbose"], "trainer": [ "BEMACallback", - "DistillationConfig", - "DistillationTrainer", "DPOConfig", "DPOTrainer", "GRPOConfig", @@ -90,8 +88,6 @@ from .scripts import DatasetMixtureConfig, ScriptArguments, TrlParser, get_dataset, init_zero_verbose from .trainer import ( BEMACallback, - DistillationConfig, - DistillationTrainer, DPOConfig, DPOTrainer, GRPOConfig, diff --git a/trl/experimental/distillation/__init__.py b/trl/experimental/distillation/__init__.py new file mode 100644 index 00000000000..894333c1d01 --- /dev/null +++ b/trl/experimental/distillation/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .distillation_config import DistillationConfig +from .distillation_trainer import DistillationTrainer + + +__all__ = ["DistillationConfig", "DistillationTrainer"] diff --git a/trl/scripts/distillation.py b/trl/experimental/distillation/distillation.py similarity index 94% rename from trl/scripts/distillation.py rename to trl/experimental/distillation/distillation.py index 2ba82e318f1..6e61fed660c 100644 --- a/trl/scripts/distillation.py +++ b/trl/experimental/distillation/distillation.py @@ -25,7 +25,7 @@ """ # Full training (off-policy only, lmbda=0): ``` -python trl/scripts/distillation.py \ +python trl/experimental/distillation/distillation.py \ --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \ --teacher_model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \ --dataset_name trl-lib/chatbot_arena_completions \ @@ -39,7 +39,7 @@ # Mixed on/off-policy (lmbda=0.5): ``` -python trl/scripts/distillation.py \ +python trl/experimental/distillation/distillation.py \ --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \ --teacher_model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \ --dataset_name trl-lib/chatbot_arena_completions \ @@ -54,7 +54,7 @@ # LoRA: ``` -python trl/scripts/distillation.py \ +python trl/experimental/distillation/distillation.py \ --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \ --teacher_model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \ --dataset_name trl-lib/chatbot_arena_completions \ @@ -84,13 +84,13 @@ def main(script_args, training_args, model_args): from transformers import AutoConfig, AutoModelForCausalLM, GenerationConfig from trl import ( - DistillationTrainer, LogCompletionsCallback, ModelConfig, get_kbit_device_map, get_peft_config, get_quantization_config, ) + from trl.experimental.distillation import DistillationTrainer logger = logging.get_logger(__name__) @@ -174,7 +174,8 @@ def main(script_args, training_args, model_args): def make_parser(subparsers: argparse._SubParsersAction | None = None, prog: str | None = None): - from trl import DistillationConfig, ModelConfig, ScriptArguments, TrlParser + from trl import ModelConfig, ScriptArguments, TrlParser + from trl.experimental.distillation import DistillationConfig dataclass_types = (ScriptArguments, DistillationConfig, ModelConfig) if subparsers is not None: diff --git a/trl/trainer/distillation_config.py b/trl/experimental/distillation/distillation_config.py similarity index 99% rename from trl/trainer/distillation_config.py rename to trl/experimental/distillation/distillation_config.py index 70184e4dda8..56fe685e49d 100644 --- a/trl/trainer/distillation_config.py +++ b/trl/experimental/distillation/distillation_config.py @@ -16,7 +16,7 @@ from dataclasses import dataclass, field from typing import Any -from .base_config import _BaseConfig +from ...trainer.base_config import _BaseConfig @dataclass diff --git a/trl/trainer/distillation_trainer.py b/trl/experimental/distillation/distillation_trainer.py similarity index 99% rename from trl/trainer/distillation_trainer.py rename to trl/experimental/distillation/distillation_trainer.py index a05b4f1b7e9..23c73dde3bc 100644 --- a/trl/trainer/distillation_trainer.py +++ b/trl/experimental/distillation/distillation_trainer.py @@ -44,20 +44,20 @@ is_rich_available, ) -from ..extras.profiling import profiling_decorator -from ..generation.vllm_generation import VLLMGeneration -from ..import_utils import is_vllm_available -from ..models import prepare_deepspeed -from ..models.utils import unwrap_model_for_generation -from .base_trainer import _BaseTrainer -from 
.distillation_config import DistillationConfig -from .utils import ( +from ...extras.profiling import profiling_decorator +from ...generation.vllm_generation import VLLMGeneration +from ...import_utils import is_vllm_available +from ...models import prepare_deepspeed +from ...models.utils import unwrap_model_for_generation +from ...trainer.base_trainer import _BaseTrainer +from ...trainer.utils import ( RepeatSampler, create_model_from_path, disable_dropout_in_model, pad, split_tensor_dict, ) +from .distillation_config import DistillationConfig if is_peft_available(): diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py index c4f2482e56a..f24ea415072 100644 --- a/trl/trainer/__init__.py +++ b/trl/trainer/__init__.py @@ -25,8 +25,6 @@ "SyncRefModelCallback", "WeaveCallback", ], - "distillation_config": ["DistillationConfig"], - "distillation_trainer": ["DistillationTrainer"], "dpo_config": ["DPOConfig"], "dpo_trainer": ["DPOTrainer"], "grpo_config": ["GRPOConfig"], @@ -57,8 +55,6 @@ SyncRefModelCallback, WeaveCallback, ) - from .distillation_config import DistillationConfig - from .distillation_trainer import DistillationTrainer from .dpo_config import DPOConfig from .dpo_trainer import DPOTrainer from .grpo_config import GRPOConfig From 64fdafd80310620344a1bd43bb4000c444869770 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 30 Mar 2026 16:45:25 +0000 Subject: [PATCH 17/37] Fix reverse KL calculation for top-1 --- .../distillation/distillation_config.py | 11 +- .../distillation/distillation_trainer.py | 161 ++++++++++++++---- trl/scripts/vllm_serve.py | 13 ++ 3 files changed, 149 insertions(+), 36 deletions(-) diff --git a/trl/experimental/distillation/distillation_config.py b/trl/experimental/distillation/distillation_config.py index 56fe685e49d..0d32edd6f8d 100644 --- a/trl/experimental/distillation/distillation_config.py +++ b/trl/experimental/distillation/distillation_config.py @@ -76,6 +76,9 @@ class DistillationConfig(_BaseConfig): Number of top tokens to use when computing the JSD/KL loss. Both student and teacher distributions are restricted to these K tokens and re-normalized before computing divergence. If 0, the full vocabulary is used. When using `teacher_model_server_url` with `beta > 0`, only `loss_top_k=1` is supported. + loss_add_tail (`bool`, *optional*, defaults to `True`): + Whether to append a tail bucket that represents the remaining probability mass outside the selected top-k + support when computing the loss. > Parameters that control on-policy generation @@ -232,7 +235,13 @@ class DistillationConfig(_BaseConfig): "(selected based on beta: teacher's top-k for forward KL, student's top-k for reverse KL, " "union of both for JSD) and re-normalized before computing divergence. " "If 0, the full vocabulary is used (slower but exact). " - "Only supported with a local teacher — not with teacher_model_server_url." + "When using teacher_model_server_url with beta > 0, only loss_top_k=1 is supported." + }, + ) + loss_add_tail: bool = field( + default=True, + metadata={ + "help": "Whether to append a tail bucket representing the remaining probability mass outside the selected top-k support." 
}, ) diff --git a/trl/experimental/distillation/distillation_trainer.py b/trl/experimental/distillation/distillation_trainer.py index 23c73dde3bc..d57c033b21d 100644 --- a/trl/experimental/distillation/distillation_trainer.py +++ b/trl/experimental/distillation/distillation_trainer.py @@ -159,6 +159,51 @@ def _jsd_divergence(student_log_probs, teacher_log_probs, beta, support_mask=Non return beta_t * kl_teacher + (1 - beta_t) * kl_student +def build_teacher_request_inputs( + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + prompt_attention_mask: torch.Tensor | None = None, + labels: torch.Tensor | None = None, +) -> tuple[list[list[int]], list[int], list[int]]: + """Trim padded batch tensors into per-sample sequences for teacher-server requests.""" + + if input_ids.shape != attention_mask.shape: + raise ValueError( + f"input_ids and attention_mask must have the same shape, got {input_ids.shape} and {attention_mask.shape}." + ) + + input_ids_cpu = input_ids.detach().cpu() + attention_mask_cpu = attention_mask.detach().cpu().bool() + + if prompt_attention_mask is not None: + prompt_lengths = prompt_attention_mask.detach().cpu().sum(dim=1).to(torch.long) + else: + if labels is None: + raise ValueError("labels are required when prompt_attention_mask is not provided.") + if labels.shape != input_ids.shape: + raise ValueError(f"labels must match input_ids shape, got {labels.shape} and {input_ids.shape}.") + full_lengths = attention_mask_cpu.sum(dim=1).to(torch.long) + completion_lengths = (labels.detach().cpu() != -100).sum(dim=1).to(torch.long) + prompt_lengths = full_lengths - completion_lengths + + trimmed_input_ids: list[list[int]] = [] + prompt_lengths_list: list[int] = [] + completion_lengths_list: list[int] = [] + + for row, mask, prompt_length in zip(input_ids_cpu, attention_mask_cpu, prompt_lengths, strict=True): + trimmed_row = row[mask] + prompt_len = int(prompt_length.item()) + if prompt_len < 0 or prompt_len > trimmed_row.numel(): + raise ValueError( + f"Invalid prompt length {prompt_len} for trimmed sequence of length {trimmed_row.numel()}." + ) + trimmed_input_ids.append(trimmed_row.tolist()) + prompt_lengths_list.append(prompt_len) + completion_lengths_list.append(int(trimmed_row.numel()) - prompt_len) + + return trimmed_input_ids, prompt_lengths_list, completion_lengths_list + + class _DistillationCollator: """Data collator for the distillation trainer with independent prompt/completion budgets. @@ -407,7 +452,7 @@ def __init__( self.teacher_client = None self._local_teacher_tokenizer_matches_student = True if args.teacher_model_server_url is not None: - from ..generation.vllm_client import VLLMClient + from ...generation.vllm_client import VLLMClient self.teacher_client = VLLMClient(base_url=args.teacher_model_server_url, connection_timeout=60.0) teacher_model = None @@ -505,6 +550,7 @@ def __init__( self.top_p = args.top_p self.num_generations = args.num_generations self.loss_top_k = args.loss_top_k + self.loss_add_tail = args.loss_add_tail # ── Buffer state ── self._buffered_inputs = None @@ -599,6 +645,16 @@ def _raise_if_local_teacher_tokenizer_mismatch(self) -> None: "override the local teacher loss path in a subclass." 
) + + def _compute_prompt_length(self, inputs: dict[str, torch.Tensor | Any]) -> int: + """Compute the earliest prompt boundary that still includes every completion token in the batch.""" + if inputs.get("labels") is not None: + attention_mask = inputs["attention_mask"] + labels = inputs["labels"] + full_lengths = attention_mask.sum(dim=1) + completion_lengths = (labels != -100).sum(dim=1) + return int((full_lengths - completion_lengths).min().item()) + return inputs["prompts"].shape[1] + def _get_completion_lengths(self, generated_tokens: torch.Tensor, prompt_width: int) -> torch.Tensor: """Infer per-sample completion lengths from generated tokens.""" completion_tokens = generated_tokens[:, prompt_width:] @@ -972,6 +1028,7 @@ def generalized_jsd_loss( temperature=1.0, reduction="batchmean", top_k=0, + add_tail=True, ): """ Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation. @@ -987,6 +1044,8 @@ beta=0 (forward KL) uses teacher's top-k, beta=1 (reverse KL) uses student's top-k, 0<beta<1 (JSD) uses the union of both top-k supports. If 0, the full vocabulary is used. + add_tail: Whether to append a tail bucket preserving the remaining probability + mass outside the top-k support. """ @@ ... @@ def _get_teacher_logits(self, inputs: dict[str, torch.Tensor | Any]) -> torch.Te raise ValueError("No teacher model or teacher server configured.") def _get_teacher_token_logprobs_from_server( - self, inputs: dict[str, torch.Tensor | Any] + self, + inputs: dict[str, torch.Tensor | Any], + aligned_prompt_length: int, ) -> dict[str, torch.Tensor]: """Fetch per-token teacher logprobs from an external vLLM server. @@ -1126,18 +1191,13 @@ def _get_teacher_token_logprobs_from_server( import numpy as np input_ids = inputs["input_ids"] - attention_mask = inputs["attention_mask"] batch_size = input_ids.shape[0] - prompt_length = inputs["prompts"].shape[1] - - # Extract unpadded sequences - sequences = [] - prompt_lengths = [] - for i in range(batch_size): - valid_mask = attention_mask[i].bool() - seq = input_ids[i][valid_mask].tolist() - sequences.append(seq) - prompt_lengths.append(prompt_length) + sequences, prompt_lengths, _ = build_teacher_request_inputs( + input_ids, + inputs["attention_mask"], + prompt_attention_mask=inputs.get("prompt_attention_mask"), + labels=inputs.get("labels"), + ) # Server path always uses top-1 logprobs. Top-k loss requires a local teacher # since reverse KL needs the teacher's logprobs at the student's top-k tokens.
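A toy sanity check of the trimming contract (all values invented for the demo), mirroring the label-based prompt-length recovery that `build_teacher_request_inputs` and `_compute_prompt_length` rely on:

```python
import torch

# Toy left-padded batch: 0 is the pad token; completion tokens have labels != -100.
input_ids = torch.tensor([
    [0, 0, 11, 12, 21, 22, 23],   # 2 pads, 2 prompt tokens, 3 completion tokens
    [0, 13, 14, 15, 31, 32, 0],   # 1 pad, 3 prompt tokens, 2 completion tokens, 1 right pad
])
attention_mask = torch.tensor([
    [0, 0, 1, 1, 1, 1, 1],
    [0, 1, 1, 1, 1, 1, 0],
])
labels = torch.tensor([
    [-100, -100, -100, -100, 21, 22, 23],
    [-100, -100, -100, -100, 31, 32, -100],
])

full_lengths = attention_mask.sum(dim=1)            # real tokens per sample
completion_lengths = (labels != -100).sum(dim=1)    # completion tokens per sample
prompt_lengths = full_lengths - completion_lengths  # prompt tokens per sample

sequences = [row[mask.bool()].tolist() for row, mask in zip(input_ids, attention_mask)]

assert sequences == [[11, 12, 21, 22, 23], [13, 14, 15, 31, 32]]
assert prompt_lengths.tolist() == [2, 3]
```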
@@ -1149,24 +1209,28 @@ def _get_teacher_token_logprobs_from_server( K = 1 device = input_ids.device - completion_length = max(len(lps) for lps in result["logprobs"]) + completion_offsets = [prompt_length - aligned_prompt_length for prompt_length in prompt_lengths] + completion_length = max( + (offset + len(lps) for offset, lps in zip(completion_offsets, result["logprobs"], strict=True)), + default=0, + ) # actual_logprobs: (B, T) — teacher logprob for the actual token def _actual_to_tensor(key): arr = np.full((batch_size, completion_length), float("-inf"), dtype=np.float32) - for i, seq_lps in enumerate(result[key]): + for i, (offset, seq_lps) in enumerate(zip(completion_offsets, result[key], strict=True)): if seq_lps: vals = np.array(seq_lps, dtype=np.float32) # (comp_len_i, 1) - arr[i, : vals.shape[0]] = vals[:, 0] + arr[i, offset : offset + vals.shape[0]] = vals[:, 0] return torch.from_numpy(arr).to(device) # topk: (B, T, K) def _topk_to_tensor(key, k, np_dtype, fill): arr = np.full((batch_size, completion_length, k), fill, dtype=np_dtype) - for i, seq_vals in enumerate(result[key]): + for i, (offset, seq_vals) in enumerate(zip(completion_offsets, result[key], strict=True)): if seq_vals: vals = np.array(seq_vals, dtype=np_dtype) # (comp_len_i, k) - arr[i, : vals.shape[0], :] = vals + arr[i, offset : offset + vals.shape[0], :] = vals return torch.from_numpy(arr).to(device) return { @@ -1184,10 +1248,11 @@ def _compute_server_divergence_loss( ) -> torch.Tensor: """Compute forward/reverse KL or JSD using teacher logprobs from the server. - The forward KL term sums over the teacher's top-k tokens (better approximation - as k increases). The reverse KL term always uses the actual (student-sampled) - token only, because the teacher's top-k may not cover the student's high-probability - tokens. + The forward KL term uses the teacher's top-k support, optionally collapsed with + a tail bucket. The reverse KL term uses the actual completion token only, because + the server cannot provide teacher logprobs at arbitrary student-selected token IDs. + When the completion is student-sampled, this reverse term is a Monte Carlo + approximation of reverse KL on the available token support. Args: teacher_result: dict with ``actual_logprobs`` (B, T), ``topk_logprobs`` (B, T, K), @@ -1196,8 +1261,8 @@ def _compute_server_divergence_loss( completion_tokens: (B, T) actual token IDs in the completion. labels: (B, T) with -100 for positions to ignore. """ - topk_teacher_lps = teacher_result["topk_logprobs"] # (B, T, K) - topk_token_ids = teacher_result["topk_token_ids"] # (B, T, K) + topk_teacher_lps = teacher_result["topk_logprobs"] # (B, T, K) + topk_token_ids = teacher_result["topk_token_ids"] # (B, T, K) actual_teacher_lps = teacher_result["actual_logprobs"] # (B, T) # ── Forward KL term: sum over teacher's top-k tokens ── @@ -1206,14 +1271,39 @@ def _compute_server_divergence_loss( # Mask out -inf padding (positions where top-k slot was not filled). 
valid = topk_teacher_lps > float("-inf") - fwd_per_k = torch.exp(topk_teacher_lps) * (topk_teacher_lps - student_topk_lps) # (B, T, K) - fwd_per_token = (fwd_per_k * valid).sum(dim=-1) # (B, T) + if self.loss_add_tail: + neg_inf = torch.full((), float("-inf"), dtype=student_log_probs.dtype, device=student_log_probs.device) + student_topk_lps = torch.where(valid, student_topk_lps, neg_inf) + teacher_topk_lps = torch.where(valid, topk_teacher_lps, neg_inf) + forward_student_log_probs, forward_support_mask = _add_tail_bucket(student_topk_lps, valid) + forward_teacher_log_probs, _ = _add_tail_bucket(teacher_topk_lps, valid) + fwd_per_token = _jsd_divergence( + forward_student_log_probs, + forward_teacher_log_probs, + beta=0.0, + support_mask=forward_support_mask, + ).sum(dim=-1) + else: + fwd_per_k = torch.exp(topk_teacher_lps) * (topk_teacher_lps - student_topk_lps) # (B, T, K) + fwd_per_token = (fwd_per_k * valid).sum(dim=-1) # (B, T) # ── Reverse KL term: actual token only ── student_actual_lps = student_log_probs.gather( dim=-1, index=completion_tokens.unsqueeze(-1) ).squeeze(-1) # (B, T) - rev_per_token = torch.exp(student_actual_lps) * (student_actual_lps - actual_teacher_lps) # (B, T) + required = labels != -100 + missing_actual = required & ~torch.isfinite(actual_teacher_lps) + if missing_actual.any(): + missing_count = int(missing_actual.sum().item()) + total_required = int(required.sum().item()) + raise ValueError( + "Teacher server is missing actual-token logprobs for required reverse-KL positions: " + f"{missing_count}/{total_required}." + ) + # Use the sampled token's logprob difference directly. Multiplying by p_student + # again would overweight high-probability tokens because the token itself is + # already drawn from the student distribution in the on-policy path. + rev_per_token = student_actual_lps - actual_teacher_lps # (B, T) # ── Combine according to beta ── if self.beta == 0: @@ -1238,7 +1328,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], ) - prompt_length = inputs["prompts"].shape[1] + prompt_length = self._compute_prompt_length(inputs) labels = inputs["labels"][:, prompt_length:] if self.teacher_client is not None: @@ -1247,7 +1337,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N # actual_logprobs – (B, T) teacher log p(x_actual) (for reverse KL) # topk_logprobs – (B, T, K) teacher top-k sorted logprobs (for forward KL) # topk_token_ids – (B, T, K) corresponding token IDs - teacher_result = self._get_teacher_token_logprobs_from_server(inputs) + teacher_result = self._get_teacher_token_logprobs_from_server(inputs, prompt_length) student_logits = student_outputs.logits[:, prompt_length - 1 : -1, :] student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1) @@ -1276,6 +1366,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N beta=self.beta, temperature=self.temperature, top_k=self.loss_top_k, + add_tail=self.loss_add_tail, ) return (loss, student_outputs) if return_outputs else loss diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index 9c819e4b5a9..6f04a8da742 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -210,6 +210,10 @@ class ScriptArguments: log_level (`str`, *optional*, defaults to `"info"`): Log level for uvicorn. Possible choices: `"critical"`, `"error"`, `"warning"`, `"info"`, `"debug"`, `"trace"`. 
+ distributed_executor_backend (`str` or `None`, *optional*): + Distributed executor backend for vLLM. Set to `"ray"` to distribute tensor parallel workers across + multiple nodes via a Ray cluster. Required when `tensor_parallel_size` exceeds the number of local GPUs. + If not set, vLLM defaults to the multiproc backend (single-node only). """ model: str = field( @@ -306,6 +310,14 @@ class ScriptArguments: "model implementation." }, ) + distributed_executor_backend: str | None = field( + default=None, + metadata={ + "help": "Distributed executor backend for vLLM. When set to 'ray', vLLM uses Ray to distribute tensor " + "parallel workers across multiple nodes. Required when tensor_parallel_size exceeds the number of local " + "GPUs. If not set, vLLM defaults to the multiproc backend (single-node only)." + }, + ) def llm_worker( @@ -335,6 +347,7 @@ def llm_worker( worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension", trust_remote_code=script_args.trust_remote_code, model_impl=script_args.vllm_model_impl, + distributed_executor_backend=script_args.distributed_executor_backend, # Important so temperature scaling/logit tweaking affects the TIS log probs logprobs_mode="processed_logprobs", ) From 0d4bd05008daf8a39205286fe1abc01a44c59464 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 30 Mar 2026 19:01:38 +0200 Subject: [PATCH 18/37] Run precommit --- trl/experimental/distillation/distillation.py | 6 +- .../distillation/distillation_config.py | 4 +- .../distillation/distillation_trainer.py | 125 ++++++++---------- trl/generation/vllm_client.py | 31 ++--- 4 files changed, 70 insertions(+), 96 deletions(-) diff --git a/trl/experimental/distillation/distillation.py b/trl/experimental/distillation/distillation.py index 6e61fed660c..da27587e1af 100644 --- a/trl/experimental/distillation/distillation.py +++ b/trl/experimental/distillation/distillation.py @@ -79,21 +79,17 @@ def main(script_args, training_args, model_args): - from accelerate import logging from datasets import load_dataset - from transformers import AutoConfig, AutoModelForCausalLM, GenerationConfig + from transformers import GenerationConfig from trl import ( LogCompletionsCallback, - ModelConfig, get_kbit_device_map, get_peft_config, get_quantization_config, ) from trl.experimental.distillation import DistillationTrainer - logger = logging.get_logger(__name__) - ################ # W&B env vars ################ diff --git a/trl/experimental/distillation/distillation_config.py b/trl/experimental/distillation/distillation_config.py index 0d32edd6f8d..470da24da89 100644 --- a/trl/experimental/distillation/distillation_config.py +++ b/trl/experimental/distillation/distillation_config.py @@ -154,9 +154,7 @@ class DistillationConfig(_BaseConfig): ) max_length: int | None = field( default=1024, - metadata={ - "help": "Maximum total sequence length (prompt + completion) for tokenization and truncation." 
- }, + metadata={"help": "Maximum total sequence length (prompt + completion) for tokenization and truncation."}, ) # Overridden defaults diff --git a/trl/experimental/distillation/distillation_trainer.py b/trl/experimental/distillation/distillation_trainer.py index d57c033b21d..56b30c36cb0 100644 --- a/trl/experimental/distillation/distillation_trainer.py +++ b/trl/experimental/distillation/distillation_trainer.py @@ -26,9 +26,9 @@ import torch.nn as nn import torch.nn.functional as F from accelerate.utils import DistributedType, broadcast_object_list, gather_object -from datasets import Dataset, IterableDataset +from datasets import Dataset from torch.utils.data import DataLoader -from transformers import AutoProcessor, AutoTokenizer, TrainerCallback +from transformers import AutoTokenizer, TrainerCallback from transformers.data.data_collator import DataCollator from transformers.feature_extraction_utils import FeatureExtractionMixin from transformers.generation.configuration_utils import GenerationConfig @@ -38,7 +38,6 @@ from transformers.tokenization_utils_base import PreTrainedTokenizerBase from transformers.trainer_utils import EvalPrediction, seed_worker from transformers.utils import ( - is_datasets_available, is_liger_kernel_available, is_peft_available, is_rich_available, @@ -105,8 +104,8 @@ def _print_completions_sample(prompts: list[str], completions: list[str], step: def _add_tail_bucket(log_probs, valid_mask): """Append a (K+1)-th tail element: log(1 - sum(exp(top_k_logps))). - This creates a proper probability distribution over K+1 elements, preventing - trivial zero loss when top_k is small (especially top_k=1). + This creates a proper probability distribution over K+1 elements, preventing trivial zero loss when top_k is small + (especially top_k=1). """ log_sum = torch.logsumexp(log_probs, dim=-1, keepdim=True) log_sum = torch.clamp(log_sum, max=-1e-7) # ensure sum < 1 @@ -118,8 +117,8 @@ def _add_tail_bucket(log_probs, valid_mask): def _jsd_divergence(student_log_probs, teacher_log_probs, beta, support_mask=None): """Compute JSD (or forward/reverse KL) from log-probability tensors. - When *support_mask* is not None, uses manual computation with masked - positions zeroed. When None, uses ``F.kl_div``. + When *support_mask* is not None, uses manual computation with masked positions zeroed. When None, uses + ``F.kl_div``. """ if support_mask is not None: safe_student = torch.where(support_mask, student_log_probs, torch.zeros_like(student_log_probs)) @@ -207,9 +206,9 @@ def build_teacher_request_inputs( class _DistillationCollator: """Data collator for the distillation trainer with independent prompt/completion budgets. - Unlike ``DataCollatorForChatML``, this collator tokenizes prompts and completions separately so - that long completions can never truncate the prompt to empty. It also handles prompt-only data - (no assistant completions) for pure on-policy distillation (``lmbda=1``). + Unlike ``DataCollatorForChatML``, this collator tokenizes prompts and completions separately so that long + completions can never truncate the prompt to empty. It also handles prompt-only data (no assistant completions) for + pure on-policy distillation (``lmbda=1``). 
""" def __init__( @@ -258,16 +257,16 @@ def __call__(self, examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]: formatted_full = self.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=False ) - full_ids = self.tokenizer( - formatted_full, truncation=False, padding=False, add_special_tokens=False - )["input_ids"] + full_ids = self.tokenizer(formatted_full, truncation=False, padding=False, add_special_tokens=False)[ + "input_ids" + ] # Identify completion tokens: everything after the prompt in the full sequence. # Use the un-truncated prompt length as the split point. formatted_prompt_ids = self.tokenizer( formatted_prompt, truncation=False, padding=False, add_special_tokens=False )["input_ids"] - completion_ids = full_ids[len(formatted_prompt_ids):] + completion_ids = full_ids[len(formatted_prompt_ids) :] # Trim completion so prompt + completion <= max_length max_comp = self.max_length - len(prompt_ids) @@ -291,23 +290,28 @@ def __call__(self, examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]: pad_id = self.tokenizer.pad_token_id input_ids_t = pad( [torch.tensor(ids, dtype=torch.long) for ids in all_input_ids], - padding_side="left", padding_value=pad_id, + padding_side="left", + padding_value=pad_id, ) attention_mask_t = pad( [torch.ones(len(ids), dtype=torch.long) for ids in all_input_ids], - padding_side="left", padding_value=0, + padding_side="left", + padding_value=0, ) labels_t = pad( [torch.tensor(lab, dtype=torch.long) for lab in all_labels], - padding_side="left", padding_value=self.ignore_index, + padding_side="left", + padding_value=self.ignore_index, ) prompts_t = pad( [torch.tensor(ids, dtype=torch.long) for ids in all_prompt_ids], - padding_side="left", padding_value=pad_id, + padding_side="left", + padding_value=pad_id, ) prompt_mask_t = pad( [torch.ones(len(ids), dtype=torch.long) for ids in all_prompt_ids], - padding_side="left", padding_value=0, + padding_side="left", + padding_value=0, ) return { @@ -322,10 +326,9 @@ def __call__(self, examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]: class _RepeatBatchDataLoader: """Repeats each collated batch ``repeat_count`` times without re-collation. - ``RepeatSampler`` with ``repeat_count > 1`` causes the DataLoader to re-collate - (re-tokenize) the same examples on every repeat, which is wasteful. This wrapper - instead keeps ``repeat_count=1`` in the sampler and repeats the already-collated - tensor dict, avoiding redundant tokenization. + ``RepeatSampler`` with ``repeat_count > 1`` causes the DataLoader to re-collate (re-tokenize) the same examples on + every repeat, which is wasteful. This wrapper instead keeps ``repeat_count=1`` in the sampler and repeats the + already-collated tensor dict, avoiding redundant tokenization. """ def __init__(self, dataloader, repeat_count: int): @@ -717,8 +720,8 @@ def get_train_dataloader(self): Override to load one generation batch per optimizer window. The dataloader yields batches of size `per_device_train_batch_size * gradient_accumulation_steps`. - RepeatSampler ensures each generation batch is repeated `gradient_accumulation_steps` times so the - Trainer's loop iterates the correct number of times. + RepeatSampler ensures each generation batch is repeated `gradient_accumulation_steps` times so the Trainer's + loop iterates the correct number of times. 
""" if self.train_dataset is None: raise ValueError("Trainer: training requires a train_dataset.") @@ -807,9 +810,7 @@ def _fill_buffer(self, generation_batch: dict[str, torch.Tensor | Any], buffer_s self._textual_logs["completion"].extend(gather_object(on_policy_completions)) @profiling_decorator - def _generate_student_completions( - self, slices: list[dict[str, torch.Tensor | Any]], on_policy_indices: list[int] - ): + def _generate_student_completions(self, slices: list[dict[str, torch.Tensor | Any]], on_policy_indices: list[int]): """Generate completions from the student model for on-policy training.""" if not self.use_vllm: self._generate_with_model(slices, on_policy_indices) @@ -918,8 +919,8 @@ def _store_completions_in_buffer( ): """Process vLLM completions and store them in the buffer. - Uses original prompt token IDs directly (no decode/re-encode roundtrip), following the same - approach as GRPOTrainer. + Uses original prompt token IDs directly (no decode/re-encode roundtrip), following the same approach as + GRPOTrainer. """ device = self.accelerator.device pad_token_id = self.processing_class.pad_token_id if self.processing_class.pad_token_id is not None else 0 @@ -943,10 +944,9 @@ def _store_completions_in_buffer( ).long() # Left-pad prompt token IDs to the longest prompt in this slice - prompt_ids = torch.stack([ - F.pad(p, (prompt_width - len(p), 0), value=pad_token_id) - for p in prompt_id_tensors - ]).to(device) + prompt_ids = torch.stack( + [F.pad(p, (prompt_width - len(p), 0), value=pad_token_id) for p in prompt_id_tensors] + ).to(device) # Pad/truncate completions (right-pad to max_completion_length) completion_tensors = [] @@ -959,9 +959,7 @@ def _store_completions_in_buffer( completion_ids_for_text.append(t.tolist()) completion_lengths.append(len(t)) if len(t) < max_completion_length: - padding = torch.full( - (max_completion_length - len(t),), pad_token_id, device=device, dtype=t.dtype - ) + padding = torch.full((max_completion_length - len(t),), pad_token_id, device=device, dtype=t.dtype) t = torch.cat([t, padding]) completion_tensors.append(t) @@ -1002,12 +1000,8 @@ def _build_sequence_batch( prompt_token_lengths = prompt_token_lengths.to(device=new_input_ids.device, dtype=torch.long) completion_lengths = completion_lengths.to(device=new_input_ids.device, dtype=torch.long) positions = torch.arange(new_input_ids.shape[1], device=new_input_ids.device).unsqueeze(0) - prompt_mask = (positions < prompt_width) & ( - positions >= (prompt_width - prompt_token_lengths).unsqueeze(1) - ) - completion_mask = (positions >= prompt_width) & ( - positions < (prompt_width + completion_lengths).unsqueeze(1) - ) + prompt_mask = (positions < prompt_width) & (positions >= (prompt_width - prompt_token_lengths).unsqueeze(1)) + completion_mask = (positions >= prompt_width) & (positions < (prompt_width + completion_lengths).unsqueeze(1)) new_attention_mask = (prompt_mask | completion_mask).long() new_labels = torch.full_like(new_input_ids, -100) @@ -1041,9 +1035,9 @@ def generalized_jsd_loss( temperature: Softmax temperature. reduction: 'batchmean', 'sum', 'mean', or 'none'. top_k: Number of top tokens to restrict the loss to. The support set depends on beta: - beta=0 (forward KL) uses teacher's top-k, beta=1 (reverse KL) uses student's top-k, - 0 torch.Tensor: """Compute forward/reverse KL or JSD using teacher logprobs from the server. - The forward KL term uses the teacher's top-k support, optionally collapsed with - a tail bucket. 
The reverse KL term uses the actual completion token only, because - the server cannot provide teacher logprobs at arbitrary student-selected token IDs. - When the completion is student-sampled, this reverse term is a Monte Carlo + The forward KL term uses the teacher's top-k support, optionally collapsed with a tail bucket. The reverse KL + term uses the actual completion token only, because the server cannot provide teacher logprobs at arbitrary + student-selected token IDs. When the completion is student-sampled, this reverse term is a Monte Carlo approximation of reverse KL on the available token support. Args: @@ -1288,9 +1281,9 @@ def _compute_server_divergence_loss( fwd_per_token = (fwd_per_k * valid).sum(dim=-1) # (B, T) # ── Reverse KL term: actual token only ── - student_actual_lps = student_log_probs.gather( - dim=-1, index=completion_tokens.unsqueeze(-1) - ).squeeze(-1) # (B, T) + student_actual_lps = student_log_probs.gather(dim=-1, index=completion_tokens.unsqueeze(-1)).squeeze( + -1 + ) # (B, T) required = labels != -100 missing_actual = required & ~torch.isfinite(actual_teacher_lps) if missing_actual.any(): @@ -1410,9 +1403,7 @@ def _compute_liger_loss(self, model, inputs): teacher_hidden = teacher_hidden.reshape(-1, teacher_hidden.shape[-1]) labels_mask = inputs["labels"] != -100 - masked_input_ids = torch.where( - labels_mask, inputs["input_ids"], torch.full_like(inputs["input_ids"], -100) - ) + masked_input_ids = torch.where(labels_mask, inputs["input_ids"], torch.full_like(inputs["input_ids"], -100)) true_labels = masked_input_ids[:, 1:].contiguous().reshape(-1) student_head = unwrapped_student.get_output_embeddings() @@ -1557,9 +1548,7 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: completions = list(self._textual_logs["completion"]) if prompts: - _print_completions_sample( - prompts, completions, self.state.global_step, self.num_completions_to_print - ) + _print_completions_sample(prompts, completions, self.state.global_step, self.num_completions_to_print) # Log as a wandb Table if self.args.report_to and "wandb" in self.args.report_to: diff --git a/trl/generation/vllm_client.py b/trl/generation/vllm_client.py index 9f2a6d3b696..c7126f6b09d 100644 --- a/trl/generation/vllm_client.py +++ b/trl/generation/vllm_client.py @@ -554,8 +554,8 @@ def get_sequence_logprobs( use_binary (`bool`, *optional*, defaults to `True`): Use binary (base64 numpy) response format for faster serialization. chunk_size (`int`, *optional*, defaults to `0`): - If > 0, split batch into chunks of this size and dispatch concurrently. - If 0, send the entire batch in a single request. + If > 0, split batch into chunks of this size and dispatch concurrently. If 0, send the entire batch in + a single request. max_concurrent_requests (`int`, *optional*, defaults to `4`): Maximum number of concurrent requests when using chunked dispatch. @@ -567,11 +567,8 @@ def get_sequence_logprobs( - `logprob_token_ids` (`list[list[list[int]]]`): Token IDs corresponding to each logprob, same shape as `logprobs`. 
""" - import base64 from concurrent.futures import ThreadPoolExecutor, as_completed - import numpy as np - url = f"{self.base_url}/get_sequence_logprobs/" response_format = "binary" if use_binary else "json" @@ -600,8 +597,7 @@ def _send_chunk(idx, seqs, plens): with ThreadPoolExecutor(max_workers=min(max_concurrent_requests, len(chunks))) as executor: futures = { - executor.submit(_send_chunk, idx, seqs, plens): idx - for idx, (seqs, plens) in enumerate(chunks) + executor.submit(_send_chunk, idx, seqs, plens): idx for idx, (seqs, plens) in enumerate(chunks) } for future in as_completed(futures): idx, result = future.result() @@ -646,11 +642,9 @@ def _decode_binary_logprobs(response: dict) -> dict[str, list]: Returns a dict with: ``logprobs`` / ``logprob_token_ids`` — teacher's sorted top-k logprobs and - token IDs (shape per sequence: ``(comp_len, top_k)``). Used for the - forward KL term. + token IDs (shape per sequence: ``(comp_len, top_k)``). Used for the forward KL term. ``actual_logprobs`` / ``actual_token_ids`` — teacher logprob for the actual - token at each position (shape per sequence: ``(comp_len, 1)``). Used - for the reverse KL term. + token at each position (shape per sequence: ``(comp_len, 1)``). Used for the reverse KL term. """ import base64 @@ -674,12 +668,12 @@ def _decode_binary_logprobs(response: dict) -> dict[str, list]: # Decode actual-token logprobs (for reverse KL) if "actual_logprobs_b64" in response: actual_shape = [shape[0], shape[1], 1] - actual_lp = np.frombuffer( - base64.b64decode(response["actual_logprobs_b64"]), dtype=np.float32 - ).reshape(actual_shape) - actual_ids = np.frombuffer( - base64.b64decode(response["actual_token_ids_b64"]), dtype=np.int32 - ).reshape(actual_shape) + actual_lp = np.frombuffer(base64.b64decode(response["actual_logprobs_b64"]), dtype=np.float32).reshape( + actual_shape + ) + actual_ids = np.frombuffer(base64.b64decode(response["actual_token_ids_b64"]), dtype=np.int32).reshape( + actual_shape + ) all_actual_lps = [] all_actual_ids = [] for i, comp_len in enumerate(comp_lengths): @@ -693,9 +687,6 @@ def _decode_binary_logprobs(response: dict) -> dict[str, list]: @staticmethod def _merge_binary_responses(responses: list[dict], top_logprobs: int) -> dict[str, list]: """Merge binary responses from multiple chunks into a single result.""" - import base64 - - import numpy as np all_logprobs = [] all_token_ids = [] From 59aa0075345b07de4f78dd66b32ea211eb138f1d Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 30 Mar 2026 19:10:08 +0200 Subject: [PATCH 19/37] Address cursor comments --- docs/source/paper_index.md | 31 ++++++++++++++++++---------- trl/scripts/vllm_serve.py | 41 ++++++++++++++++++-------------------- 2 files changed, 40 insertions(+), 32 deletions(-) diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md index 9fe2857b085..c368fd3286a 100644 --- a/docs/source/paper_index.md +++ b/docs/source/paper_index.md @@ -1563,13 +1563,13 @@ Papers relating to training a student model with the help of a teacher model. **📜 Paper**: https://huggingface.co/papers/2306.13649 -Introduces Generalized Knowledge Distillation (GKD), which addresses distribution mismatch in KD for auto-regressive models by training the student on its own generated outputs with teacher feedback, instead of a fixed set of sequences. GKD supports flexible loss functions (e.g. beyond KL when the student cannot match the teacher) and integrates with RL fine-tuning (RLHF). 
The paper reports results on summarization, translation, arithmetic reasoning, and instruction-tuning. Used in TRL via [`experimental.gkd.GKDTrainer`]. To reproduce the paper's setting, use this configuration: +Introduces Generalized Knowledge Distillation (GKD), which addresses distribution mismatch in KD for auto-regressive models by training the student on its own generated outputs with teacher feedback, instead of a fixed set of sequences. GKD supports flexible loss functions (e.g. beyond KL when the student cannot match the teacher) and integrates with RL fine-tuning (RLHF). The paper reports results on summarization, translation, arithmetic reasoning, and instruction-tuning. Used in TRL via [`experimental.distillation.DistillationTrainer`] and [`experimental.gkd.GKDTrainer`]. To reproduce the paper's setting, use this configuration: ```python -from trl.experimental.gkd import GKDConfig +from trl.experimental.distillation import DistillationConfig # XSum summarization task (Table A.1 of the paper) -training_args = GKDConfig( +training_args = DistillationConfig( lmbda=0.5, # λ student data fraction (Section 3 of the paper) beta=0.5, # β Generalized JSD interpolation, 0=KL, 1=reverse KL (Section 3 of the paper) temperature=1.0, # student training temperature (Appendix A of the paper) @@ -1577,7 +1577,7 @@ training_args = GKDConfig( learning_rate=3e-4, # learning rate (Table A.1 of the paper) per_device_train_batch_size=32, # batch size (Table A.1 of the paper) warmup_steps=2000, # warm-up steps (Table A.1 of the paper) - max_new_tokens=64, # max output tokens (Table A.1 of the paper) + max_completion_length=64, # max output tokens (Table A.1 of the paper) ) ``` @@ -1597,20 +1597,31 @@ On-Policy Distillation has been shown to outperform SFT, GRPO and can be used to Additionally on-policy distillation is more compute efficient and is less prone to overfitting when trained with limited data. 
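Concretely, `lmbda` sets the fraction of batches where the student samples its own completions, while `beta` selects the divergence minimized against the teacher. A minimal sketch of that interpolation (illustrative only — the in-repo `generalized_jsd_loss` additionally handles temperature scaling, top-k support restriction, and an optional tail bucket; the mixture convention here follows the GKD paper and may not match every code path):

```python
import math

import torch
import torch.nn.functional as F


def interpolated_divergence(student_logits, teacher_logits, beta: float):
    # beta=0 -> forward KL(teacher || student); beta=1 -> reverse KL(student || teacher)
    log_p_s = F.log_softmax(student_logits, dim=-1)
    log_p_t = F.log_softmax(teacher_logits, dim=-1)
    if beta == 0:
        return (log_p_t.exp() * (log_p_t - log_p_s)).sum(-1).mean()
    if beta == 1:
        return (log_p_s.exp() * (log_p_s - log_p_t)).sum(-1).mean()
    # 0 < beta < 1: generalized JSD against the mixture m = beta * p_t + (1 - beta) * p_s
    log_m = torch.logaddexp(log_p_t + math.log(beta), log_p_s + math.log(1 - beta))
    kl_t = (log_p_t.exp() * (log_p_t - log_m)).sum(-1).mean()
    kl_s = (log_p_s.exp() * (log_p_s - log_m)).sum(-1).mean()
    return beta * kl_t + (1 - beta) * kl_s
```

With `lmbda=1.0` every batch is on-policy, and `beta=1.0` collapses the sketch to the reverse-KL objective used in the recipes below.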
-To train a model with on-policy distillation using TRL, you can use the following configuration, with the [`experimental.gkd.GKDTrainer`] and [`experimental.gkd.GKDConfig`]: +To train a model with on-policy distillation using TRL, you can use the following configuration, with the [`experimental.distillation.DistillationTrainer`] and [`experimental.distillation.DistillationConfig`]: + +```python +from trl.experimental.distillation import DistillationConfig + +training_args = DistillationConfig( + lmbda=1.0, # student produces rollouts for all batches + beta=1.0, # to ensure reverse-kl as the loss function + teacher_model_name_or_path="teacher-model", # specify the teacher model +) +``` + +Alternatively, you can use the [`experimental.gkd.GKDTrainer`] and [`experimental.gkd.GKDConfig`]: ```python from trl.experimental.gkd import GKDConfig training_args = GKDConfig( - lmbda=1.0, # student produces rollouts for all batches - beta=1.0, # to ensure reverse-kl as the loss function - teacher_model_name_or_path="teacher-model", # specify the teacher model - + lmbda=1.0, # student produces rollouts for all batches + beta=1.0, # to ensure reverse-kl as the loss function + teacher_model_name_or_path="teacher-model", # specify the teacher model ) ``` -Alternatively, you can use the [`GOLDTrainer`] and [`GOLDConfig`] to perform on-policy distillation with a similar configuration: +You can also use the [`GOLDTrainer`] and [`GOLDConfig`] to perform on-policy distillation with a similar configuration: ```python from trl.experimental import GOLDConfig diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index 6f04a8da742..f1adafaaefa 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -211,9 +211,9 @@ class ScriptArguments: Log level for uvicorn. Possible choices: `"critical"`, `"error"`, `"warning"`, `"info"`, `"debug"`, `"trace"`. distributed_executor_backend (`str` or `None`, *optional*): - Distributed executor backend for vLLM. Set to `"ray"` to distribute tensor parallel workers across - multiple nodes via a Ray cluster. Required when `tensor_parallel_size` exceeds the number of local GPUs. - If not set, vLLM defaults to the multiproc backend (single-node only). + Distributed executor backend for vLLM. Set to `"ray"` to distribute tensor parallel workers across multiple + nodes via a Ray cluster. Required when `tensor_parallel_size` exceeds the number of local GPUs. If not set, + vLLM defaults to the multiproc backend (single-node only). """ model: str = field( @@ -694,7 +694,6 @@ def _run_prompt_logprobs(prompts, sampling_params): # all DP workers stay busy. Without this, async endpoint handlers block the event # loop during pipe I/O, serializing requests and leaving DP workers idle. 
import asyncio - import threading _logprob_queue: asyncio.Queue = asyncio.Queue() @@ -756,7 +755,7 @@ async def _logprob_batcher(): all_prompts = [] all_prompt_lengths = [] offsets = [] # (start_idx, count) per original request - for prompts, prompt_lengths, response_format, future in items: + for prompts, prompt_lengths, _response_format, _future in items: start = len(all_prompts) all_prompts.extend(prompts) all_prompt_lengths.extend(prompt_lengths) @@ -770,9 +769,7 @@ async def _logprob_batcher(): # Dispatch to workers in a thread to avoid blocking the event loop try: - all_outputs = await loop.run_in_executor( - None, _run_prompt_logprobs, all_prompts, sampling_params - ) + all_outputs = await loop.run_in_executor(None, _run_prompt_logprobs, all_prompts, sampling_params) except Exception as e: # Signal error to all waiting requests for _, _, _, future in items: @@ -781,7 +778,7 @@ async def _logprob_batcher(): continue # Split results back to individual requests - for (start, count), (_, prompt_lengths, response_format, future) in zip(offsets, items): + for (start, count), (_, prompt_lengths, response_format, future) in zip(offsets, items, strict=False): outputs_slice = all_outputs[start : start + count] if not future.done(): future.set_result((outputs_slice, prompt_lengths, top_logprobs, response_format)) @@ -812,9 +809,7 @@ def _format_logprob_response(all_outputs, prompt_lengths, top_k, response_format actual_logprobs_arr = np.full((batch_size, max_comp_len, 1), float("-inf"), dtype=np.float32) actual_token_ids_arr = np.zeros((batch_size, max_comp_len, 1), dtype=np.int32) - for i, (output, prompt_length) in enumerate( - zip(all_outputs, prompt_lengths, strict=True) - ): + for i, (output, prompt_length) in enumerate(zip(all_outputs, prompt_lengths, strict=True)): prompt_lps = output.prompt_logprobs seq_tokens = output.prompt_token_ids if comp_lengths[i] == 0: @@ -896,9 +891,9 @@ async def get_sequence_logprobs(request: SequenceLogprobsRequest): """ Computes teacher logprobs for existing token sequences without generating new tokens. - Concurrent requests are automatically batched and dispatched together to maximize - GPU utilization across DP workers. This avoids the event-loop-blocking problem where - synchronous pipe I/O serializes requests despite having multiple DP workers. + Concurrent requests are automatically batched and dispatched together to maximize GPU utilization across DP + workers. This avoids the event-loop-blocking problem where synchronous pipe I/O serializes requests despite + having multiple DP workers. Supports two response formats: - `"json"` (default): Nested lists, backward-compatible with existing clients. 
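As a consumer-side reference, a hypothetical call against this endpoint through `VLLMClient.get_sequence_logprobs` might look as follows (the server address, constructor arguments, and token IDs are placeholders, not values taken from this patch):

```python
from trl.generation.vllm_client import VLLMClient

client = VLLMClient(base_url="http://localhost:8000")  # assumed constructor/address
prompt, completion = [101, 2023, 2003], [1037, 2742, 102]  # illustrative token IDs
out = client.get_sequence_logprobs(
    sequences=[prompt + completion],  # full sequences; no new tokens are generated
    prompt_lengths=[len(prompt)],     # logprobs are returned from this position on
    top_logprobs=8,
    use_binary=True,                  # base64 numpy payload instead of nested JSON
)
# out["logprobs"][0] has shape (len(completion), 8): teacher top-k logprobs per
# position, and, when the server includes them, out["actual_logprobs"][0] holds
# the teacher logprob of each realized completion token (reverse-KL term).
```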
@@ -921,13 +916,15 @@ async def get_sequence_logprobs(request: SequenceLogprobsRequest): # Submit to the batching queue and await result loop = asyncio.get_event_loop() future = loop.create_future() - await _logprob_queue.put(( - prompts, - list(request.prompt_lengths), - request.top_logprobs, - request.response_format, - future, - )) + await _logprob_queue.put( + ( + prompts, + list(request.prompt_lengths), + request.top_logprobs, + request.response_format, + future, + ) + ) # Wait for the batcher to process our request all_outputs, prompt_lengths, top_k, response_format = await future From 91bffa49b8d1f0810fb71c97c56d45a27a3a02a1 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 30 Mar 2026 18:59:15 +0000 Subject: [PATCH 20/37] Fix reverse KL computation --- .../distillation/distillation_trainer.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/trl/experimental/distillation/distillation_trainer.py b/trl/experimental/distillation/distillation_trainer.py index 56b30c36cb0..c42efd1599c 100644 --- a/trl/experimental/distillation/distillation_trainer.py +++ b/trl/experimental/distillation/distillation_trainer.py @@ -1244,8 +1244,8 @@ def _compute_server_divergence_loss( The forward KL term uses the teacher's top-k support, optionally collapsed with a tail bucket. The reverse KL term uses the actual completion token only, because the server cannot provide teacher logprobs at arbitrary - student-selected token IDs. When the completion is student-sampled, this reverse term is a Monte Carlo - approximation of reverse KL on the available token support. + student-selected token IDs. This makes the reverse term a sparse token-level surrogate rather than the exact + reverse KL used by the local-teacher path. Args: teacher_result: dict with ``actual_logprobs`` (B, T), ``topk_logprobs`` (B, T, K), @@ -1293,10 +1293,8 @@ def _compute_server_divergence_loss( "Teacher server is missing actual-token logprobs for required reverse-KL positions: " f"{missing_count}/{total_required}." ) - # Use the sampled token's logprob difference directly. Multiplying by p_student - # again would overweight high-probability tokens because the token itself is - # already drawn from the student distribution in the on-policy path. 
- rev_per_token = student_actual_lps - actual_teacher_lps # (B, T) + student_actual_ps = torch.exp(student_actual_lps) + rev_per_token = student_actual_ps * (student_actual_lps - actual_teacher_lps) # (B, T) # ── Combine according to beta ── if self.beta == 0: @@ -1304,7 +1302,7 @@ def _compute_server_divergence_loss( elif self.beta == 1: loss_per_token = rev_per_token else: - loss_per_token = self.beta * fwd_per_token + (1 - self.beta) * rev_per_token + loss_per_token = (1 - self.beta) * fwd_per_token + self.beta * rev_per_token mask = labels != -100 return loss_per_token[mask].sum() / mask.sum() From 6d0fe13f2c7c15eb77c42cd35c49ec49d8a27c7e Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 30 Mar 2026 19:00:29 +0000 Subject: [PATCH 21/37] Add `DistillationTrainer` to table of contents --- docs/source/_toctree.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 624ba22b91b..ec3b927858c 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -95,6 +95,8 @@ title: BCO - local: cpo_trainer title: CPO + - local: distillation_trainer + title: Distillation - local: gfpo title: GFPO - local: gkd_trainer From b55996eb0e204d6f65e45aa8211031d48deb75e8 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 30 Mar 2026 21:09:31 +0200 Subject: [PATCH 22/37] Remove `DistillationTrainer` from toc --- docs/source/_toctree.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index ec3b927858c..624ba22b91b 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -95,8 +95,6 @@ title: BCO - local: cpo_trainer title: CPO - - local: distillation_trainer - title: Distillation - local: gfpo title: GFPO - local: gkd_trainer From 9727ec7125e9cf879ba1019f06209844967c7efd Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 30 Mar 2026 19:45:38 +0000 Subject: [PATCH 23/37] Tighten logic for different top-k scenarios --- .../distillation/distillation_config.py | 18 +++++++------ .../distillation/distillation_trainer.py | 25 +++++++++++-------- trl/generation/vllm_client.py | 8 ++++++ trl/scripts/vllm_serve.py | 16 ++++++------ 4 files changed, 42 insertions(+), 25 deletions(-) diff --git a/trl/experimental/distillation/distillation_config.py b/trl/experimental/distillation/distillation_config.py index 470da24da89..587a4b6c77b 100644 --- a/trl/experimental/distillation/distillation_config.py +++ b/trl/experimental/distillation/distillation_config.py @@ -75,7 +75,9 @@ class DistillationConfig(_BaseConfig): loss_top_k (`int`, *optional*, defaults to `0`): Number of top tokens to use when computing the JSD/KL loss. Both student and teacher distributions are restricted to these K tokens and re-normalized before computing divergence. If 0, the full vocabulary - is used. When using `teacher_model_server_url` with `beta > 0`, only `loss_top_k=1` is supported. + is used. When using `teacher_model_server_url`, the pure forward path (`beta=0`) requires this to be + positive and uses the teacher's top-k logprobs for the forward term. Any reverse component (`beta>0`) + uses the realized token's logprob for the reverse term and top-1 teacher support for the forward term. loss_add_tail (`bool`, *optional*, defaults to `True`): Whether to append a tail bucket that represents the remaining probability mass outside the selected top-k support when computing the loss. 
@@ -233,7 +235,9 @@ class DistillationConfig(_BaseConfig): "(selected based on beta: teacher's top-k for forward KL, student's top-k for reverse KL, " "union of both for JSD) and re-normalized before computing divergence. " "If 0, the full vocabulary is used (slower but exact). " - "When using teacher_model_server_url with beta > 0, only loss_top_k=1 is supported." + "When using teacher_model_server_url, beta=0 requires loss_top_k > 0 and uses the teacher's top-k " + "logprobs for the forward term. Any reverse component uses the realized token's logprob for the reverse " + "term and top-1 teacher support for the forward term." }, ) loss_add_tail: bool = field( @@ -382,13 +386,11 @@ def __post_init__(self): f"{self.per_device_train_batch_size} * {self.gradient_accumulation_steps}." ) - if self.teacher_model_server_url is not None and self.beta > 0 and self.loss_top_k != 1: + if self.teacher_model_server_url is not None and self.beta == 0 and self.loss_top_k < 1: raise ValueError( - f"loss_top_k != 1 with beta > 0 is not supported with teacher_model_server_url " - f"(got loss_top_k={self.loss_top_k}, beta={self.beta}). The server path always uses top-1 " - f"logprobs: any reverse KL component (beta > 0) requires the teacher's logprobs at the " - f"student's top-k tokens, which the server cannot provide for k > 1 or full vocabulary (k=0). " - f"Use a local teacher, set loss_top_k=1, or set beta=0 (pure forward KL)." + f"loss_top_k must be positive when using teacher_model_server_url with beta=0 " + f"(got loss_top_k={self.loss_top_k}). The pure forward server path only has access to the " + f"teacher's top-k logprobs, so it cannot compute the exact full-vocabulary loss when loss_top_k=0." ) if self.num_generations > 1 and self.lmbda < 1.0: diff --git a/trl/experimental/distillation/distillation_trainer.py b/trl/experimental/distillation/distillation_trainer.py index c42efd1599c..122513310d1 100644 --- a/trl/experimental/distillation/distillation_trainer.py +++ b/trl/experimental/distillation/distillation_trainer.py @@ -1120,8 +1120,8 @@ def token_level_divergence_loss( """ Compute a per-token approximation of the generalized JSD loss using only the sampled token's logprobs. - This is used when the teacher is an external server and only top-1 logprobs are available. For each token - position, we have log p_student(token) and log p_teacher(token) and compute: + This is used when the teacher is an external server and the reverse component only has access to the realized + token's logprob. For each token position, we have log p_student(token) and log p_teacher(token) and compute: - beta=0 (forward KL): exp(log_teacher) * (log_teacher - log_student) - beta=1 (reverse KL): exp(log_student) * (log_student - log_teacher) - 0 < beta < 1 (JSD): weighted combination of forward and reverse token-level KL terms @@ -1193,14 +1193,18 @@ def _get_teacher_token_logprobs_from_server( labels=inputs.get("labels"), ) - # Server path always uses top-1 logprobs. Top-k loss requires a local teacher - # since reverse KL needs the teacher's logprobs at the student's top-k tokens. + # The external-teacher path can use the teacher's configured top-k support for + # pure forward KL. Once a reverse component is present, the reverse term only + # has access to the realized token's logprob, so we also limit the forward term + # to top-1 teacher support to keep the mixed objective balanced. 
+ requested_top_k = self.loss_top_k if self.beta == 0 else 1 result = self.teacher_client.get_sequence_logprobs( sequences=sequences, prompt_lengths=prompt_lengths, - top_logprobs=1, + top_logprobs=requested_top_k, + temperature=self.temperature, ) - K = 1 + K = requested_top_k device = input_ids.device completion_offsets = [prompt_length - aligned_prompt_length for prompt_length in prompt_lengths] @@ -1242,10 +1246,11 @@ def _compute_server_divergence_loss( ) -> torch.Tensor: """Compute forward/reverse KL or JSD using teacher logprobs from the server. - The forward KL term uses the teacher's top-k support, optionally collapsed with a tail bucket. The reverse KL - term uses the actual completion token only, because the server cannot provide teacher logprobs at arbitrary - student-selected token IDs. This makes the reverse term a sparse token-level surrogate rather than the exact - reverse KL used by the local-teacher path. + The forward KL term uses the teacher's top-k support for pure forward KL and top-1 teacher support whenever a + reverse component is present, optionally collapsed with a tail bucket. The reverse KL term uses the actual + completion token only, because the server cannot provide teacher logprobs at arbitrary student-selected token + IDs. This makes the reverse term a sparse token-level surrogate rather than the exact reverse KL used by the + local-teacher path. Args: teacher_result: dict with ``actual_logprobs`` (B, T), ``topk_logprobs`` (B, T, K), diff --git a/trl/generation/vllm_client.py b/trl/generation/vllm_client.py index c7126f6b09d..92edb892a34 100644 --- a/trl/generation/vllm_client.py +++ b/trl/generation/vllm_client.py @@ -528,6 +528,7 @@ def get_sequence_logprobs( sequences: list[list[int]], prompt_lengths: list[int], top_logprobs: int = 100, + temperature: float = 1.0, use_binary: bool = True, chunk_size: int = 0, max_concurrent_requests: int = 4, @@ -551,6 +552,8 @@ def get_sequence_logprobs( Number of prompt tokens in each sequence. Logprobs are returned starting from this position. top_logprobs (`int`, *optional*, defaults to `100`): Number of top logprobs to return per token position. + temperature (`float`, *optional*, defaults to `1.0`): + Temperature used when scoring the teacher distribution. use_binary (`bool`, *optional*, defaults to `True`): Use binary (base64 numpy) response format for faster serialization. 
chunk_size (`int`, *optional*, defaults to `0`): @@ -569,6 +572,9 @@ def get_sequence_logprobs( """ from concurrent.futures import ThreadPoolExecutor, as_completed + if temperature <= 0: + raise ValueError(f"temperature must be positive, got {temperature}") + url = f"{self.base_url}/get_sequence_logprobs/" response_format = "binary" if use_binary else "json" @@ -588,6 +594,7 @@ def _send_chunk(idx, seqs, plens): "sequences": seqs, "prompt_lengths": plens, "top_logprobs": top_logprobs, + "temperature": temperature, "response_format": response_format, }, ) @@ -621,6 +628,7 @@ def _send_chunk(idx, seqs, plens): "sequences": sequences, "prompt_lengths": prompt_lengths, "top_logprobs": top_logprobs, + "temperature": temperature, "response_format": response_format, }, ) diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index f1adafaaefa..1f1d858f0bf 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -666,6 +666,7 @@ class SequenceLogprobsRequest(BaseModel): sequences: list[list[int]] prompt_lengths: list[int] top_logprobs: int = 100 + temperature: float = 1.0 response_format: str = "json" # "json" (legacy) or "binary" (base64 numpy arrays) class SequenceLogprobsResponse(BaseModel): @@ -740,17 +741,17 @@ async def _logprob_batcher(): except asyncio.TimeoutError: break - # batch is a list of (prompts, prompt_lengths, top_logprobs, response_format, future) - # All items in a batch must share the same top_logprobs (enforced at dispatch time) - # Group by top_logprobs to handle mixed requests + # batch is a list of (prompts, prompt_lengths, top_logprobs, temperature, response_format, future) + # All items in a batch must share the same (top_logprobs, temperature) pair. + # Group by those execution parameters to handle mixed requests. groups = {} - for prompts, prompt_lengths, top_logprobs, response_format, future in batch: - key = top_logprobs + for prompts, prompt_lengths, top_logprobs, temperature, response_format, future in batch: + key = (top_logprobs, temperature) if key not in groups: groups[key] = [] groups[key].append((prompts, prompt_lengths, response_format, future)) - for top_logprobs, items in groups.items(): + for (top_logprobs, temperature), items in groups.items(): # Merge all sequences into a single batch all_prompts = [] all_prompt_lengths = [] @@ -763,7 +764,7 @@ async def _logprob_batcher(): sampling_params = SamplingParams( max_tokens=1, - temperature=1.0, + temperature=temperature, prompt_logprobs=top_logprobs, ) @@ -921,6 +922,7 @@ async def get_sequence_logprobs(request: SequenceLogprobsRequest): prompts, list(request.prompt_lengths), request.top_logprobs, + request.temperature, request.response_format, future, ) From 939b53b759b8c29c74193d10a43b6479088eda5f Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 30 Mar 2026 21:54:09 +0000 Subject: [PATCH 24/37] Add tail bucket to reverse KL + server case --- .../distillation/distillation_trainer.py | 39 ++++++++++++++----- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/trl/experimental/distillation/distillation_trainer.py b/trl/experimental/distillation/distillation_trainer.py index 122513310d1..58a84ed5e6f 100644 --- a/trl/experimental/distillation/distillation_trainer.py +++ b/trl/experimental/distillation/distillation_trainer.py @@ -1118,19 +1118,22 @@ def token_level_divergence_loss( beta=0.5, ): """ - Compute a per-token approximation of the generalized JSD loss using only the sampled token's logprobs. 
+ Compute a per-token weighted forward/reverse KL surrogate using only the sampled token's logprobs. - This is used when the teacher is an external server and the reverse component only has access to the realized - token's logprob. For each token position, we have log p_student(token) and log p_teacher(token) and compute: + This helper keeps the legacy sparse approximation where only the realized token's logprob is available. For + each token position, we have log p_student(token) and log p_teacher(token) and compute: - beta=0 (forward KL): exp(log_teacher) * (log_teacher - log_student) - beta=1 (reverse KL): exp(log_student) * (log_student - log_teacher) - - 0 < beta < 1 (JSD): weighted combination of forward and reverse token-level KL terms + - 0 < beta < 1: weighted combination of forward and reverse token-level KL terms + + This is not the exact generalized JSD used by `generalized_jsd_loss` when `0 < beta < 1`. Args: student_logprobs: Tensor of shape (batch_size, completion_length) — student's log-prob per token. teacher_logprobs: Tensor of shape (batch_size, completion_length) — teacher's log-prob per token. labels: Tensor of shape (batch_size, completion_length) with -100 for positions to ignore. - beta: Interpolation coefficient. 0.0 = forward KL, 0.5 = JSD, 1.0 = reverse KL. + beta: Interpolation coefficient. 0.0 = forward KL surrogate, 1.0 = reverse KL surrogate, and intermediate + values interpolate between them. Returns: Scalar loss tensor. @@ -1244,13 +1247,16 @@ def _compute_server_divergence_loss( completion_tokens: torch.Tensor, labels: torch.Tensor, ) -> torch.Tensor: - """Compute forward/reverse KL or JSD using teacher logprobs from the server. + """Compute sparse server-side forward/reverse KL surrogates using teacher logprobs from the server. The forward KL term uses the teacher's top-k support for pure forward KL and top-1 teacher support whenever a reverse component is present, optionally collapsed with a tail bucket. The reverse KL term uses the actual completion token only, because the server cannot provide teacher logprobs at arbitrary student-selected token - IDs. This makes the reverse term a sparse token-level surrogate rather than the exact reverse KL used by the - local-teacher path. + IDs. When `self.loss_add_tail` is enabled, the reverse term becomes a two-bucket KL over + `{actual_token, residual_tail}`. This matches `generalized_jsd_loss(..., beta=1, top_k=1, add_tail=True)` + when the sampled token is the same token used in the top-1 support and the logprob values agree. For + `0 < beta < 1`, this method still returns a weighted combination of the forward and reverse surrogates rather + than the exact generalized JSD used by the local-teacher path. Args: teacher_result: dict with ``actual_logprobs`` (B, T), ``topk_logprobs`` (B, T, K), @@ -1298,8 +1304,21 @@ def _compute_server_divergence_loss( "Teacher server is missing actual-token logprobs for required reverse-KL positions: " f"{missing_count}/{total_required}." 
) - student_actual_ps = torch.exp(student_actual_lps) - rev_per_token = student_actual_ps * (student_actual_lps - actual_teacher_lps) # (B, T) + if self.loss_add_tail: + rev_student_lps = student_actual_lps.unsqueeze(-1) # (B, T, 1) + rev_teacher_lps = actual_teacher_lps.unsqueeze(-1) # (B, T, 1) + rev_valid = torch.ones_like(rev_student_lps, dtype=torch.bool) + rev_student_log_probs, rev_support_mask = _add_tail_bucket(rev_student_lps, rev_valid) + rev_teacher_log_probs, _ = _add_tail_bucket(rev_teacher_lps, rev_valid) + rev_per_token = _jsd_divergence( + rev_student_log_probs, + rev_teacher_log_probs, + beta=1.0, + support_mask=rev_support_mask, + ).sum(dim=-1) + else: + student_actual_ps = torch.exp(student_actual_lps) + rev_per_token = student_actual_ps * (student_actual_lps - actual_teacher_lps) # (B, T) # ── Combine according to beta ── if self.beta == 0: From bcf21d22b1f6937f3dce8877a4e067a485ff1c93 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Tue, 31 Mar 2026 11:20:05 +0200 Subject: [PATCH 25/37] Fix dead code when using full vocab for external teacher --- trl/experimental/distillation/distillation_trainer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/trl/experimental/distillation/distillation_trainer.py b/trl/experimental/distillation/distillation_trainer.py index 58a84ed5e6f..b60bfff481b 100644 --- a/trl/experimental/distillation/distillation_trainer.py +++ b/trl/experimental/distillation/distillation_trainer.py @@ -1167,7 +1167,11 @@ def _get_teacher_logits(self, inputs: dict[str, torch.Tensor | Any]) -> torch.Te attention_mask=inputs["attention_mask"], ).logits elif self.teacher_client is not None: - return self._get_teacher_logits_from_server(inputs) + raise NotImplementedError( + "Fetching full teacher logits from `teacher_model_server_url` is not supported. " + "Server-backed distillation only supports per-token logprobs via " + "`_get_teacher_token_logprobs_from_server`." + ) else: raise ValueError("No teacher model or teacher server configured.") From aaf47ad6a3b7004408de824dec80f4b79b8baa79 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Tue, 31 Mar 2026 11:39:07 +0200 Subject: [PATCH 26/37] Remove unused function --- .../distillation/distillation_trainer.py | 47 ------------------- 1 file changed, 47 deletions(-) diff --git a/trl/experimental/distillation/distillation_trainer.py b/trl/experimental/distillation/distillation_trainer.py index b60bfff481b..45cd1a675a7 100644 --- a/trl/experimental/distillation/distillation_trainer.py +++ b/trl/experimental/distillation/distillation_trainer.py @@ -1110,53 +1110,6 @@ def generalized_jsd_loss( else: return jsd - @staticmethod - def token_level_divergence_loss( - student_logprobs, - teacher_logprobs, - labels=None, - beta=0.5, - ): - """ - Compute a per-token weighted forward/reverse KL surrogate using only the sampled token's logprobs. - - This helper keeps the legacy sparse approximation where only the realized token's logprob is available. For - each token position, we have log p_student(token) and log p_teacher(token) and compute: - - beta=0 (forward KL): exp(log_teacher) * (log_teacher - log_student) - - beta=1 (reverse KL): exp(log_student) * (log_student - log_teacher) - - 0 < beta < 1: weighted combination of forward and reverse token-level KL terms - - This is not the exact generalized JSD used by `generalized_jsd_loss` when `0 < beta < 1`. - - Args: - student_logprobs: Tensor of shape (batch_size, completion_length) — student's log-prob per token. 
- teacher_logprobs: Tensor of shape (batch_size, completion_length) — teacher's log-prob per token. - labels: Tensor of shape (batch_size, completion_length) with -100 for positions to ignore. - beta: Interpolation coefficient. 0.0 = forward KL surrogate, 1.0 = reverse KL surrogate, and intermediate - values interpolate between them. - - Returns: - Scalar loss tensor. - """ - if beta == 0: - # Forward KL: p_teacher * (log_teacher - log_student) - loss = torch.exp(teacher_logprobs) * (teacher_logprobs - student_logprobs) - elif beta == 1: - # Reverse KL: p_student * (log_student - log_teacher) - loss = torch.exp(student_logprobs) * (student_logprobs - teacher_logprobs) - else: - # Token-level JSD approximation - forward_kl = torch.exp(teacher_logprobs) * (teacher_logprobs - student_logprobs) - reverse_kl = torch.exp(student_logprobs) * (student_logprobs - teacher_logprobs) - loss = beta * forward_kl + (1 - beta) * reverse_kl - - if labels is not None: - mask = labels != -100 - loss = loss[mask] - return loss.sum() / mask.sum() - - return loss.mean() - def _get_teacher_logits(self, inputs: dict[str, torch.Tensor | Any]) -> torch.Tensor: """Get teacher logits — dispatches between local model and external server.""" if self.teacher_model is not None: From 89a7b883236f58b15e806275e66ca84d79b84f63 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Tue, 31 Mar 2026 11:49:04 +0200 Subject: [PATCH 27/37] Add guard in config for liger + external teacher --- trl/experimental/distillation/distillation_config.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/trl/experimental/distillation/distillation_config.py b/trl/experimental/distillation/distillation_config.py index 587a4b6c77b..e409c0850c8 100644 --- a/trl/experimental/distillation/distillation_config.py +++ b/trl/experimental/distillation/distillation_config.py @@ -386,6 +386,12 @@ def __post_init__(self): f"{self.per_device_train_batch_size} * {self.gradient_accumulation_steps}." ) + if self.teacher_model_server_url is not None and self.use_liger_kernel: + raise ValueError( + "use_liger_kernel=True is not supported with teacher_model_server_url because the Liger loss path " + "requires a local teacher model." 
+ ) + if self.teacher_model_server_url is not None and self.beta == 0 and self.loss_top_k < 1: raise ValueError( f"loss_top_k must be positive when using teacher_model_server_url with beta=0 " From 16a538058828f521e8f30108979a378a8d09d784 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Tue, 31 Mar 2026 12:27:20 +0200 Subject: [PATCH 28/37] Tighten alignment logic --- .../distillation/distillation_trainer.py | 24 ++++++++++++++----- 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/trl/experimental/distillation/distillation_trainer.py b/trl/experimental/distillation/distillation_trainer.py index 45cd1a675a7..f65dc84b3d1 100644 --- a/trl/experimental/distillation/distillation_trainer.py +++ b/trl/experimental/distillation/distillation_trainer.py @@ -1146,7 +1146,7 @@ def _get_teacher_token_logprobs_from_server( input_ids = inputs["input_ids"] batch_size = input_ids.shape[0] - sequences, prompt_lengths, _ = build_teacher_request_inputs( + sequences, prompt_lengths, completion_lengths = build_teacher_request_inputs( input_ids, inputs["attention_mask"], prompt_attention_mask=inputs.get("prompt_attention_mask"), @@ -1167,11 +1167,23 @@ def _get_teacher_token_logprobs_from_server( K = requested_top_k device = input_ids.device - completion_offsets = [prompt_length - aligned_prompt_length for prompt_length in prompt_lengths] - completion_length = max( - (offset + len(lps) for offset, lps in zip(completion_offsets, result["logprobs"], strict=True)), - default=0, - ) + completion_length = input_ids.shape[1] - aligned_prompt_length + labels = inputs.get("labels") + if labels is None: + raise ValueError("labels are required to align teacher-server logprobs with the student loss tensors.") + + # The student loss slices tensors in padded-sequence coordinates starting at `aligned_prompt_length`. + # Place each teacher completion into that same coordinate system by locating the first non-masked completion + # token in `labels`. This works for both left-padded off-policy batches and on-policy batches where + # completions are right-padded after a fixed-width prompt block. + completion_offsets = [] + label_mask = labels != -100 + for sample_mask, comp_len in zip(label_mask, completion_lengths, strict=True): + if comp_len == 0: + completion_offsets.append(0) + continue + completion_start = int(torch.nonzero(sample_mask, as_tuple=False)[0].item()) + completion_offsets.append(completion_start - aligned_prompt_length) # actual_logprobs: (B, T) — teacher logprob for the actual token def _actual_to_tensor(key): From 1c7b9de703a79bcce737b4a543b8e291d897aa62 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Tue, 31 Mar 2026 12:27:50 +0200 Subject: [PATCH 29/37] Run precommit --- trl/experimental/distillation/distillation_trainer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/trl/experimental/distillation/distillation_trainer.py b/trl/experimental/distillation/distillation_trainer.py index f65dc84b3d1..5fdbb9b756b 100644 --- a/trl/experimental/distillation/distillation_trainer.py +++ b/trl/experimental/distillation/distillation_trainer.py @@ -1221,11 +1221,11 @@ def _compute_server_divergence_loss( The forward KL term uses the teacher's top-k support for pure forward KL and top-1 teacher support whenever a reverse component is present, optionally collapsed with a tail bucket. The reverse KL term uses the actual completion token only, because the server cannot provide teacher logprobs at arbitrary student-selected token - IDs. 
When `self.loss_add_tail` is enabled, the reverse term becomes a two-bucket KL over - `{actual_token, residual_tail}`. This matches `generalized_jsd_loss(..., beta=1, top_k=1, add_tail=True)` - when the sampled token is the same token used in the top-1 support and the logprob values agree. For - `0 < beta < 1`, this method still returns a weighted combination of the forward and reverse surrogates rather - than the exact generalized JSD used by the local-teacher path. + IDs. When `self.loss_add_tail` is enabled, the reverse term becomes a two-bucket KL over `{actual_token, + residual_tail}`. This matches `generalized_jsd_loss(..., beta=1, top_k=1, add_tail=True)` when the sampled + token is the same token used in the top-1 support and the logprob values agree. For `0 < beta < 1`, this method + still returns a weighted combination of the forward and reverse surrogates rather than the exact generalized + JSD used by the local-teacher path. Args: teacher_result: dict with ``actual_logprobs`` (B, T), ``topk_logprobs`` (B, T, K), From 3cb4bc3827315669aeb5e60f31a40e990ed935d6 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Tue, 31 Mar 2026 13:16:19 +0000 Subject: [PATCH 30/37] Correct completion logic --- .../distillation/distillation_trainer.py | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/trl/experimental/distillation/distillation_trainer.py b/trl/experimental/distillation/distillation_trainer.py index 5fdbb9b756b..888c6477c8f 100644 --- a/trl/experimental/distillation/distillation_trainer.py +++ b/trl/experimental/distillation/distillation_trainer.py @@ -127,9 +127,9 @@ def _jsd_divergence(student_log_probs, teacher_log_probs, beta, support_mask=Non teacher_probs = torch.where(support_mask, teacher_log_probs.exp(), torch.zeros_like(teacher_log_probs)) if beta == 0: - return teacher_probs * (safe_teacher - safe_student) + return torch.nan_to_num(teacher_probs * (safe_teacher - safe_student), nan=0.0) elif beta == 1: - return student_probs * (safe_student - safe_teacher) + return torch.nan_to_num(student_probs * (safe_student - safe_teacher), nan=0.0) else: beta_t = torch.tensor(beta, dtype=student_log_probs.dtype, device=student_log_probs.device) tiny = torch.finfo(student_probs.dtype).tiny @@ -139,8 +139,8 @@ def _jsd_divergence(student_log_probs, teacher_log_probs, beta, support_mask=Non torch.log(mixture_probs.clamp_min(tiny)), torch.zeros_like(student_log_probs), ) - kl_teacher = teacher_probs * (safe_teacher - safe_mixture) - kl_student = student_probs * (safe_student - safe_mixture) + kl_teacher = torch.nan_to_num(teacher_probs * (safe_teacher - safe_mixture), nan=0.0) + kl_student = torch.nan_to_num(student_probs * (safe_student - safe_mixture), nan=0.0) return beta_t * kl_teacher + (1 - beta_t) * kl_student else: if beta == 0: @@ -1167,7 +1167,6 @@ def _get_teacher_token_logprobs_from_server( K = requested_top_k device = input_ids.device - completion_length = input_ids.shape[1] - aligned_prompt_length labels = inputs.get("labels") if labels is None: raise ValueError("labels are required to align teacher-server logprobs with the student loss tensors.") @@ -1185,6 +1184,14 @@ def _get_teacher_token_logprobs_from_server( completion_start = int(torch.nonzero(sample_mask, as_tuple=False)[0].item()) completion_offsets.append(completion_start - aligned_prompt_length) + # Size the output tensors to tightly fit the teacher logprobs. 
Using the full padded + # sequence length would include padding positions with -inf teacher logprobs, producing + # inf in the forward pass and NaN gradients in the backward pass (0 * inf = NaN). + completion_length = max( + (offset + len(lps) for offset, lps in zip(completion_offsets, result["logprobs"], strict=True)), + default=0, + ) + # actual_logprobs: (B, T) — teacher logprob for the actual token def _actual_to_tensor(key): arr = np.full((batch_size, completion_length), float("-inf"), dtype=np.float32) @@ -1298,7 +1305,10 @@ def _compute_server_divergence_loss( loss_per_token = (1 - self.beta) * fwd_per_token + self.beta * rev_per_token mask = labels != -100 - return loss_per_token[mask].sum() / mask.sum() + num_tokens = mask.sum() + if num_tokens == 0: + return loss_per_token.sum() * 0.0 # no completion tokens — return zero-grad scalar + return loss_per_token[mask].sum() / num_tokens def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): self._raise_if_local_teacher_tokenizer_mismatch() From d1b97a8fd2cc9701e10cc94a0e41234d8bcc74c8 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Tue, 31 Mar 2026 14:20:38 +0000 Subject: [PATCH 31/37] Remove wandb config params --- trl/experimental/distillation/distillation.py | 10 ---------- trl/experimental/distillation/distillation_config.py | 9 --------- 2 files changed, 19 deletions(-) diff --git a/trl/experimental/distillation/distillation.py b/trl/experimental/distillation/distillation.py index da27587e1af..368cf4ed03e 100644 --- a/trl/experimental/distillation/distillation.py +++ b/trl/experimental/distillation/distillation.py @@ -90,16 +90,6 @@ def main(script_args, training_args, model_args): ) from trl.experimental.distillation import DistillationTrainer - ################ - # W&B env vars - ################ - if training_args.wandb_project is not None: - os.environ["WANDB_PROJECT"] = training_args.wandb_project - if training_args.wandb_entity is not None: - os.environ["WANDB_ENTITY"] = training_args.wandb_entity - if training_args.wandb_run_group is not None: - os.environ["WANDB_RUN_GROUP"] = training_args.wandb_run_group - ################ # Model init kwargs ################ diff --git a/trl/experimental/distillation/distillation_config.py b/trl/experimental/distillation/distillation_config.py index e409c0850c8..6ac2feae8f5 100644 --- a/trl/experimental/distillation/distillation_config.py +++ b/trl/experimental/distillation/distillation_config.py @@ -125,15 +125,6 @@ class DistillationConfig(_BaseConfig): vllm_enable_sleep_mode (`bool`, *optional*, defaults to `False`): Enable vLLM sleep mode to offload student weights during the optimizer step. - > Parameters that control W&B logging - - wandb_entity (`str` or `None`, *optional*): - The W&B entity to store runs under. - wandb_project (`str` or `None`, *optional*): - The W&B project to store runs under. - wandb_run_group (`str` or `None`, *optional*): - The W&B group to store runs under. 
- > Parameters that control logging log_completions (`bool`, *optional*, defaults to `False`): From 691f46cb71ccba2f3e467ee25afff4d18332bacc Mon Sep 17 00:00:00 2001 From: cmpatino Date: Wed, 1 Apr 2026 12:06:12 +0000 Subject: [PATCH 32/37] Address Albert's comments --- trl/generation/vllm_client.py | 18 ++++++++++++------ trl/scripts/vllm_serve.py | 16 +++++++--------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/trl/generation/vllm_client.py b/trl/generation/vllm_client.py index 92edb892a34..4a5cf8d7de0 100644 --- a/trl/generation/vllm_client.py +++ b/trl/generation/vllm_client.py @@ -654,8 +654,6 @@ def _decode_binary_logprobs(response: dict) -> dict[str, list]: ``actual_logprobs`` / ``actual_token_ids`` — teacher logprob for the actual token at each position (shape per sequence: ``(comp_len, 1)``). Used for the reverse KL term. """ - import base64 - import numpy as np shape = response["shape"] # [batch, max_comp_len, top_k] @@ -700,17 +698,25 @@ def _merge_binary_responses(responses: list[dict], top_logprobs: int) -> dict[st all_token_ids = [] all_actual_lps = [] all_actual_ids = [] - has_actual = False + has_actual_flags = [] for resp in responses: decoded = VLLMClient._decode_binary_logprobs(resp) all_logprobs.extend(decoded["logprobs"]) all_token_ids.extend(decoded["logprob_token_ids"]) - if "actual_logprobs" in decoded: - has_actual = True + has_actual = "actual_logprobs" in decoded + has_actual_flags.append(has_actual) + if has_actual: all_actual_lps.extend(decoded["actual_logprobs"]) all_actual_ids.extend(decoded["actual_token_ids"]) + + if any(has_actual_flags) and not all(has_actual_flags): + raise ValueError( + "Inconsistent responses: some chunks contain 'actual_logprobs' while others do not. " + "All responses in a batch must either all include or all exclude actual token logprobs." + ) + result = {"logprobs": all_logprobs, "logprob_token_ids": all_token_ids} - if has_actual: + if all(has_actual_flags) and has_actual_flags: result["actual_logprobs"] = all_actual_lps result["actual_token_ids"] = all_actual_ids return result diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index 1f1d858f0bf..08a36411d90 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -396,6 +396,8 @@ def chunk_list(lst: list, n: int) -> list[list]: def main(script_args: ScriptArguments): + import asyncio + from packaging.version import Version from transformers import is_vision_available @@ -454,8 +456,6 @@ def main(script_args: ScriptArguments): @asynccontextmanager async def lifespan(app: FastAPI): - import asyncio as _asyncio - # Wait for all workers to send "ready" ready_connections = set() while len(ready_connections) < script_args.data_parallel_size: @@ -465,7 +465,7 @@ async def lifespan(app: FastAPI): ready_connections.add(connection) # Start the logprob request batcher background task - batcher_task = _asyncio.create_task(_logprob_batcher()) + batcher_task = asyncio.create_task(_logprob_batcher()) yield @@ -694,8 +694,6 @@ def _run_prompt_logprobs(prompts, sampling_params): # Collects concurrent requests into batches and dispatches them together so that # all DP workers stay busy. Without this, async endpoint handlers block the event # loop during pipe I/O, serializing requests and leaving DP workers idle. - import asyncio - _logprob_queue: asyncio.Queue = asyncio.Queue() # Maximum time (seconds) to wait for more requests before dispatching a batch. 
@@ -710,7 +708,7 @@ def _run_prompt_logprobs(prompts, sampling_params): async def _logprob_batcher(): """Background task that continuously drains the queue, batches requests, and dispatches.""" - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() while True: # Wait for the first request @@ -779,7 +777,7 @@ async def _logprob_batcher(): continue # Split results back to individual requests - for (start, count), (_, prompt_lengths, response_format, future) in zip(offsets, items, strict=False): + for (start, count), (_, prompt_lengths, response_format, future) in zip(offsets, items, strict=True): outputs_slice = all_outputs[start : start + count] if not future.done(): future.set_result((outputs_slice, prompt_lengths, top_logprobs, response_format)) @@ -915,7 +913,7 @@ async def get_sequence_logprobs(request: SequenceLogprobsRequest): prompts = [{"prompt_token_ids": seq} for seq in request.sequences] # Submit to the batching queue and await result - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() future = loop.create_future() await _logprob_queue.put( ( @@ -931,7 +929,7 @@ async def get_sequence_logprobs(request: SequenceLogprobsRequest): # Wait for the batcher to process our request all_outputs, prompt_lengths, top_k, response_format = await future - return _format_logprob_response(all_outputs, prompt_lengths, top_k, response_format) + return await loop.run_in_executor(None, _format_logprob_response, all_outputs, prompt_lengths, top_k, response_format) class ChatRequest(BaseModel): messages: list[list[dict]] From 9ef9192a25e972a0bad3e96009863fe2906f3248 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Wed, 1 Apr 2026 16:18:09 +0200 Subject: [PATCH 33/37] Run precommit --- trl/scripts/vllm_serve.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index 08a36411d90..6800dbe32f1 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -929,7 +929,9 @@ async def get_sequence_logprobs(request: SequenceLogprobsRequest): # Wait for the batcher to process our request all_outputs, prompt_lengths, top_k, response_format = await future - return await loop.run_in_executor(None, _format_logprob_response, all_outputs, prompt_lengths, top_k, response_format) + return await loop.run_in_executor( + None, _format_logprob_response, all_outputs, prompt_lengths, top_k, response_format + ) class ChatRequest(BaseModel): messages: list[list[dict]] From bb2eee6f74198591471aa974740092e2fde2b6aa Mon Sep 17 00:00:00 2001 From: cmpatino Date: Thu, 2 Apr 2026 12:37:45 +0200 Subject: [PATCH 34/37] Address PR comments --- .../distillation/distillation_trainer.py | 7 +- trl/generation/vllm_client.py | 18 +- trl/scripts/vllm_serve.py | 171 ++++++++++-------- 3 files changed, 113 insertions(+), 83 deletions(-) diff --git a/trl/experimental/distillation/distillation_trainer.py b/trl/experimental/distillation/distillation_trainer.py index 888c6477c8f..a6997527aaf 100644 --- a/trl/experimental/distillation/distillation_trainer.py +++ b/trl/experimental/distillation/distillation_trainer.py @@ -1102,7 +1102,12 @@ def generalized_jsd_loss( jsd = jsd[mask] if reduction == "batchmean": - return jsd.sum() / mask.sum() if labels is not None else jsd.sum() / jsd.size(0) + if labels is not None: + num_tokens = mask.sum() + if num_tokens == 0: + return jsd.sum() * 0.0 # no completion tokens — return zero-grad scalar + return jsd.sum() / num_tokens + return jsd.sum() / jsd.size(0) elif reduction == "sum": return 
jsd.sum() elif reduction == "mean": diff --git a/trl/generation/vllm_client.py b/trl/generation/vllm_client.py index 4a5cf8d7de0..cb05ae5a551 100644 --- a/trl/generation/vllm_client.py +++ b/trl/generation/vllm_client.py @@ -698,25 +698,21 @@ def _merge_binary_responses(responses: list[dict], top_logprobs: int) -> dict[st all_token_ids = [] all_actual_lps = [] all_actual_ids = [] - has_actual_flags = [] for resp in responses: decoded = VLLMClient._decode_binary_logprobs(resp) all_logprobs.extend(decoded["logprobs"]) all_token_ids.extend(decoded["logprob_token_ids"]) - has_actual = "actual_logprobs" in decoded - has_actual_flags.append(has_actual) - if has_actual: + if "actual_logprobs" in decoded: all_actual_lps.extend(decoded["actual_logprobs"]) all_actual_ids.extend(decoded["actual_token_ids"]) - if any(has_actual_flags) and not all(has_actual_flags): - raise ValueError( - "Inconsistent responses: some chunks contain 'actual_logprobs' while others do not. " - "All responses in a batch must either all include or all exclude actual token logprobs." - ) - result = {"logprobs": all_logprobs, "logprob_token_ids": all_token_ids} - if all(has_actual_flags) and has_actual_flags: + if all_actual_lps: + if len(all_actual_lps) != len(all_logprobs): + raise ValueError( + f"Inconsistent chunks: {len(all_actual_lps)} actual_logprobs entries " + f"but {len(all_logprobs)} logprobs entries." + ) result["actual_logprobs"] = all_actual_lps result["actual_token_ids"] = all_actual_ids return result diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index 6800dbe32f1..4341d22c129 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -675,6 +675,8 @@ class SequenceLogprobsResponse(BaseModel): # Binary format fields (base64-encoded numpy arrays) logprobs_b64: str | None = None token_ids_b64: str | None = None + actual_logprobs_b64: str | None = None + actual_token_ids_b64: str | None = None shape: list[int] | None = None # [batch_size, max_completion_len, top_logprobs] completion_lengths: list[int] | None = None # actual completion length per sample @@ -698,8 +700,8 @@ def _run_prompt_logprobs(prompts, sampling_params): # Maximum time (seconds) to wait for more requests before dispatching a batch. _BATCH_WAIT_S = 0.005 # 5ms - short enough to not add much latency when lightly loaded - # Maximum number of sequences per batch (set to DP size so each worker gets sequences) - _MAX_BATCH_SEQS = max(script_args.data_parallel_size * 4, 16) + # Maximum number of HTTP requests to collect per batcher cycle + _MAX_BATCH_REQUESTS = max(script_args.data_parallel_size * 4, 16) # Maximum total tokens per batch. prompt_logprobs materializes full-vocab logits # during the forward pass, so each worker can safely handle ~1 max-length sequence. # Budget = max_model_len * dp_size gives ~1 sequence per worker at max length. 
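The restructured batcher below is easier to follow against a stripped-down sketch of the same queue-and-futures coalescing pattern (all names and the doubled-value "work" are placeholders for the real merged dispatch):

```python
import asyncio

queue: asyncio.Queue = asyncio.Queue()

async def batcher(wait_s: float = 0.005, max_items: int = 16):
    loop = asyncio.get_running_loop()
    while True:
        items = [await queue.get()]  # block until the first request arrives
        deadline = loop.time() + wait_s
        while len(items) < max_items:  # keep draining until the window closes
            remaining = deadline - loop.time()
            if remaining <= 0:
                break
            try:
                items.append(await asyncio.wait_for(queue.get(), timeout=remaining))
            except asyncio.TimeoutError:
                break
        merged = [x for payload, _ in items for x in payload]
        results = [x * 2 for x in merged]  # stand-in for the batched model call
        start = 0
        for payload, future in items:  # split results back to each caller
            future.set_result(results[start : start + len(payload)])
            start += len(payload)

async def request(payload: list[int]) -> list[int]:
    future = asyncio.get_running_loop().create_future()
    await queue.put((payload, future))
    return await future

async def main():
    asyncio.get_running_loop().create_task(batcher())
    print(await asyncio.gather(request([1, 2]), request([3])))  # -> [[2, 4], [6]]

asyncio.run(main())
```

The real implementation layers the token budget, per-`(top_logprobs, temperature)` grouping, and error propagation shown in the hunk below on top of this skeleton.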
@@ -711,76 +713,85 @@ async def _logprob_batcher(): loop = asyncio.get_running_loop() while True: - # Wait for the first request batch = [] - batch_tokens = 0 - item = await _logprob_queue.get() - batch.append(item) - # Count tokens in this item's sequences - for prompt in item[0]: - batch_tokens += len(prompt.get("prompt_token_ids", [])) - - # Collect more requests up to batch limit, timeout, or token budget - deadline = loop.time() + _BATCH_WAIT_S - while len(batch) < _MAX_BATCH_SEQS and batch_tokens < _MAX_BATCH_TOKENS: - remaining = deadline - loop.time() - if remaining <= 0: - break - try: - item = await asyncio.wait_for(_logprob_queue.get(), timeout=remaining) - # Check if adding this item would exceed the token budget - item_tokens = sum(len(p.get("prompt_token_ids", [])) for p in item[0]) - if batch_tokens + item_tokens > _MAX_BATCH_TOKENS and len(batch) > 0: - # Put it back and dispatch current batch - await _logprob_queue.put(item) + try: + # Wait for the first request + batch_tokens = 0 + item = await _logprob_queue.get() + batch.append(item) + # Count tokens in this item's sequences + for prompt in item[0]: + batch_tokens += len(prompt.get("prompt_token_ids", [])) + + # Collect more requests up to batch limit, timeout, or token budget + deadline = loop.time() + _BATCH_WAIT_S + while len(batch) < _MAX_BATCH_REQUESTS and batch_tokens < _MAX_BATCH_TOKENS: + remaining = deadline - loop.time() + if remaining <= 0: + break + try: + item = await asyncio.wait_for(_logprob_queue.get(), timeout=remaining) + # Check if adding this item would exceed the token budget + item_tokens = sum(len(p.get("prompt_token_ids", [])) for p in item[0]) + if batch_tokens + item_tokens > _MAX_BATCH_TOKENS and len(batch) > 0: + # Put it back and dispatch current batch + await _logprob_queue.put(item) + break + batch.append(item) + batch_tokens += item_tokens + except asyncio.TimeoutError: break - batch.append(item) - batch_tokens += item_tokens - except asyncio.TimeoutError: - break - - # batch is a list of (prompts, prompt_lengths, top_logprobs, temperature, response_format, future) - # All items in a batch must share the same (top_logprobs, temperature) pair. - # Group by those execution parameters to handle mixed requests. - groups = {} - for prompts, prompt_lengths, top_logprobs, temperature, response_format, future in batch: - key = (top_logprobs, temperature) - if key not in groups: - groups[key] = [] - groups[key].append((prompts, prompt_lengths, response_format, future)) - - for (top_logprobs, temperature), items in groups.items(): - # Merge all sequences into a single batch - all_prompts = [] - all_prompt_lengths = [] - offsets = [] # (start_idx, count) per original request - for prompts, prompt_lengths, _response_format, _future in items: - start = len(all_prompts) - all_prompts.extend(prompts) - all_prompt_lengths.extend(prompt_lengths) - offsets.append((start, len(prompts))) - - sampling_params = SamplingParams( - max_tokens=1, - temperature=temperature, - prompt_logprobs=top_logprobs, - ) - # Dispatch to workers in a thread to avoid blocking the event loop - try: - all_outputs = await loop.run_in_executor(None, _run_prompt_logprobs, all_prompts, sampling_params) - except Exception as e: - # Signal error to all waiting requests - for _, _, _, future in items: - if not future.done(): - future.set_exception(e) - continue + # batch is a list of (prompts, prompt_lengths, top_logprobs, temperature, response_format, future) + # All items in a batch must share the same (top_logprobs, temperature) pair. 
+ # Group by those execution parameters to handle mixed requests. + groups = {} + for prompts, prompt_lengths, top_logprobs, temperature, response_format, future in batch: + key = (top_logprobs, temperature) + if key not in groups: + groups[key] = [] + groups[key].append((prompts, prompt_lengths, response_format, future)) + + for (top_logprobs, temperature), items in groups.items(): + # Merge all sequences into a single batch + all_prompts = [] + all_prompt_lengths = [] + offsets = [] # (start_idx, count) per original request + for prompts, prompt_lengths, _response_format, _future in items: + start = len(all_prompts) + all_prompts.extend(prompts) + all_prompt_lengths.extend(prompt_lengths) + offsets.append((start, len(prompts))) + + sampling_params = SamplingParams( + max_tokens=1, + temperature=temperature, + prompt_logprobs=top_logprobs, + ) - # Split results back to individual requests - for (start, count), (_, prompt_lengths, response_format, future) in zip(offsets, items, strict=True): - outputs_slice = all_outputs[start : start + count] + # Dispatch to workers in a thread to avoid blocking the event loop + try: + all_outputs = await loop.run_in_executor( + None, _run_prompt_logprobs, all_prompts, sampling_params + ) + + # Split results back to individual requests + for (start, count), (_, prompt_lengths, response_format, future) in zip( + offsets, items, strict=True + ): + outputs_slice = all_outputs[start : start + count] + if not future.done(): + future.set_result((outputs_slice, prompt_lengths, top_logprobs, response_format)) + except Exception as e: + # Signal error to all waiting requests in this execution-parameter group + for _, _, _, future in items: + if not future.done(): + future.set_exception(e) + except Exception as e: + # Prevent killing the batcher task — signal error to all unfulfilled futures + for *_, future in batch: if not future.done(): - future.set_result((outputs_slice, prompt_lengths, top_logprobs, response_format)) + future.set_exception(e) def _format_logprob_response(all_outputs, prompt_lengths, top_k, response_format): """Format vLLM outputs into the response dict (runs in any thread).""" @@ -885,7 +896,7 @@ def _format_logprob_response(all_outputs, prompt_lengths, top_k, response_format all_token_ids.append(seq_token_ids) return {"logprobs": all_logprobs, "logprob_token_ids": all_token_ids} - @app.post("/get_sequence_logprobs/") + @app.post("/get_sequence_logprobs/", response_model=SequenceLogprobsResponse) async def get_sequence_logprobs(request: SequenceLogprobsRequest): """ Computes teacher logprobs for existing token sequences without generating new tokens. @@ -894,9 +905,27 @@ async def get_sequence_logprobs(request: SequenceLogprobsRequest): workers. This avoids the event-loop-blocking problem where synchronous pipe I/O serializes requests despite having multiple DP workers. - Supports two response formats: - - `"json"` (default): Nested lists, backward-compatible with existing clients. - - `"binary"`: Base64-encoded numpy arrays for fast serialization/deserialization. + Args: + request (`SequenceLogprobsRequest`): + - `sequences` (list of list of `int`): Full token sequences (prompt + completion) per sample. + - `prompt_lengths` (list of `int`): Number of prompt tokens per sequence; completion logprobs start + after each prompt. + - `top_logprobs` (`int`, *optional*, defaults to `100`): Number of top teacher logprobs to return per + completion position (sorted by vLLM rank). 
+ - `temperature` (`float`, *optional*, defaults to `1.0`): Sampling temperature passed to vLLM for + logprob computation. + - `response_format` (`str`, *optional*, defaults to `"json"`): Either `"json"` (nested lists, + backward-compatible) or `"binary"` (base64-encoded numpy arrays for fast serialization). + + Returns: + `SequenceLogprobsResponse` or Starlette `Response`: + When `response_format` is `"json"`, a JSON object with: + - `logprobs` (list of list of list of `float` or `None`): Top-k teacher logprobs per completion token. + - `logprob_token_ids` (list of list of list of `int`): Token IDs aligned with `logprobs`. + When `response_format` is `"binary"`, a JSON response (Starlette `Response` if `orjson` is installed) + whose body is a JSON object with base64-encoded float32/int32 arrays: `logprobs_b64`, `token_ids_b64`, + `actual_logprobs_b64`, `actual_token_ids_b64`, plus `shape` (`list[int]`, `[batch_size, + max_completion_len, top_k]`) and `completion_lengths` (`list[int]`). """ if len(request.sequences) != len(request.prompt_lengths): raise ValueError("sequences and prompt_lengths must have the same length.") From 088db425f1a8a46f52f0001a2735dafd26c74bff Mon Sep 17 00:00:00 2001 From: cmpatino Date: Tue, 7 Apr 2026 16:53:09 +0200 Subject: [PATCH 35/37] Match behavior between local and external teacher when top-1 and beta > 0 --- .../distillation/distillation_config.py | 78 ++++- .../distillation/distillation_trainer.py | 323 ++++++++++++------ 2 files changed, 276 insertions(+), 125 deletions(-) diff --git a/trl/experimental/distillation/distillation_config.py b/trl/experimental/distillation/distillation_config.py index 6ac2feae8f5..bb379a350d9 100644 --- a/trl/experimental/distillation/distillation_config.py +++ b/trl/experimental/distillation/distillation_config.py @@ -53,6 +53,11 @@ class DistillationConfig(_BaseConfig): Interpolation coefficient for the Generalized Jensen-Shannon Divergence loss. When `0.0`, the loss is the forward KL divergence. When `1.0`, the loss is the reverse KL divergence. When `0.5`, it is the standard JSD. + reverse_kl_top_1_mode (`str`, *optional*, defaults to `"sampled"`): + Selection rule for the reverse-KL top-1 token when `beta > 0` and `loss_top_k == 1`. `"sampled"` uses the + actual completion token in the batch. `"argmax"` uses the student's highest-probability token. This + setting does not affect the forward-KL support, which always uses the teacher's top-1 token. Ignored when + `beta == 0` or `loss_top_k != 1`. max_completion_length (`int`, *optional*, defaults to `256`): Maximum number of tokens to generate per completion during on-policy generation. disable_dropout (`bool`, *optional*, defaults to `True`): @@ -61,23 +66,26 @@ class DistillationConfig(_BaseConfig): > Parameters that control the teacher model teacher_model_name_or_path (`str` or `None`, *optional*): - Model name or path for the teacher model. Used when the teacher is loaded locally. Mutually exclusive with - `teacher_model_server_url`. + Model name or path for the teacher model. Used when the teacher is loaded locally. teacher_model_revision (`str` or `None`, *optional*): Model revision of the teacher model (e.g., branch name, tag, or commit hash). teacher_model_init_kwargs (`dict[str, Any]` or `None`, *optional*): Keyword arguments passed to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model from a string. 
+ use_teacher_server (`bool`, *optional*, defaults to `False`): + Whether to use an external vLLM teacher server instead of a local teacher model. teacher_model_server_url (`str` or `None`, *optional*): Base URL of a vLLM server hosting the teacher model (e.g., `"http://localhost:8000"`). When set, teacher - logprobs are fetched from the server instead of running a local forward pass. Mutually exclusive with - passing a `teacher_model` object to the trainer. + logprobs are fetched from the server instead of running a local forward pass when `use_teacher_server=True`. loss_top_k (`int`, *optional*, defaults to `0`): Number of top tokens to use when computing the JSD/KL loss. Both student and teacher distributions are restricted to these K tokens and re-normalized before computing divergence. If 0, the full vocabulary - is used. When using `teacher_model_server_url`, the pure forward path (`beta=0`) requires this to be - positive and uses the teacher's top-k logprobs for the forward term. Any reverse component (`beta>0`) - uses the realized token's logprob for the reverse term and top-1 teacher support for the forward term. + is used. For local teachers, the general support rule is teacher top-k for forward KL, student top-k for + reverse KL, and the union for mixed JSD. When `beta > 0` and `loss_top_k == 1`, the forward support still + uses the teacher's top-1 token, while the reverse top-1 token is controlled by `reverse_kl_top_1_mode`. + When `use_teacher_server=True`, the pure forward path (`beta=0`) requires this to be positive and uses the + teacher's top-k logprobs for the forward term. When `beta > 0`, server-backed distillation requires + `loss_top_k == 1` and only supports `"sampled"` reverse top-1 tokens. loss_add_tail (`bool`, *optional*, defaults to `True`): Whether to append a tail bucket that represents the remaining probability mass outside the selected top-k support when computing the loss. @@ -177,6 +185,14 @@ class DistillationConfig(_BaseConfig): "0.0 = forward KL, 0.5 = JSD, 1.0 = reverse KL." }, ) + reverse_kl_top_1_mode: str = field( + default="sampled", + metadata={ + "help": "Reverse-KL top-1 token selection when beta > 0 and loss_top_k == 1. " + "Use 'sampled' for the actual completion token or 'argmax' for the student's top-1 token. " + "The forward-KL support always uses the teacher's top-1 token. Ignored when beta == 0 or loss_top_k != 1." + }, + ) max_completion_length: int = field( default=256, metadata={"help": "Maximum number of tokens to generate per completion."}, @@ -211,11 +227,17 @@ class DistillationConfig(_BaseConfig): ) # Teacher model (external vLLM server) + use_teacher_server: bool = field( + default=False, + metadata={ + "help": "Whether to use an external vLLM teacher server instead of a local teacher model." + }, + ) teacher_model_server_url: str | None = field( default=None, metadata={ "help": 'Base URL of a vLLM server hosting the teacher model (e.g., "http://localhost:8000"). ' - "When set, teacher logprobs are fetched from the server." + "Required when use_teacher_server=True." }, ) loss_top_k: int = field( @@ -226,9 +248,11 @@ class DistillationConfig(_BaseConfig): "(selected based on beta: teacher's top-k for forward KL, student's top-k for reverse KL, " "union of both for JSD) and re-normalized before computing divergence. " "If 0, the full vocabulary is used (slower but exact). " - "When using teacher_model_server_url, beta=0 requires loss_top_k > 0 and uses the teacher's top-k " - "logprobs for the forward term. 
Any reverse component uses the realized token's logprob for the reverse " - "term and top-1 teacher support for the forward term." + "When beta > 0 and loss_top_k == 1, the forward support still uses the teacher's top-1 token, " + "while the reverse top-1 token is controlled by reverse_kl_top_1_mode. " + "When use_teacher_server=True, beta=0 requires loss_top_k > 0 and uses the teacher's top-k " + "logprobs for the forward term. When beta > 0, server-backed distillation requires loss_top_k == 1 " + "and only supports 'sampled' reverse top-1 tokens." }, ) loss_add_tail: bool = field( @@ -352,6 +376,8 @@ def __post_init__(self): raise ValueError(f"lmbda must be in [0.0, 1.0], got {self.lmbda}.") if self.beta < 0.0 or self.beta > 1.0: raise ValueError(f"beta must be in [0.0, 1.0], got {self.beta}.") + if self.reverse_kl_top_1_mode not in {"sampled", "argmax"}: + raise ValueError("reverse_kl_top_1_mode must be one of: 'sampled', 'argmax'") if self.max_length is not None and self.max_completion_length >= self.max_length: raise ValueError( @@ -377,18 +403,40 @@ def __post_init__(self): f"{self.per_device_train_batch_size} * {self.gradient_accumulation_steps}." ) - if self.teacher_model_server_url is not None and self.use_liger_kernel: + if self.use_teacher_server and self.use_liger_kernel: raise ValueError( - "use_liger_kernel=True is not supported with teacher_model_server_url because the Liger loss path " + "use_liger_kernel=True is not supported with use_teacher_server=True because the Liger loss path " "requires a local teacher model." ) + if self.use_teacher_server and ( + self.teacher_model_server_url is None or not self.teacher_model_server_url.strip() + ): + raise ValueError("teacher_model_server_url must be set when use_teacher_server=True.") - if self.teacher_model_server_url is not None and self.beta == 0 and self.loss_top_k < 1: + if self.use_teacher_server and self.beta == 0 and self.loss_top_k < 1: raise ValueError( - f"loss_top_k must be positive when using teacher_model_server_url with beta=0 " + f"loss_top_k must be positive when using use_teacher_server=True with beta=0 " f"(got loss_top_k={self.loss_top_k}). The pure forward server path only has access to the " f"teacher's top-k logprobs, so it cannot compute the exact full-vocabulary loss when loss_top_k=0." ) + if self.use_teacher_server and self.reverse_kl_top_1_mode == "argmax": + raise ValueError( + "reverse_kl_top_1_mode='argmax' is not supported with use_teacher_server=True because the server " + "cannot provide teacher logprobs for arbitrary student-selected tokens." + ) + if self.use_teacher_server and self.beta > 0 and self.loss_top_k != 1: + raise ValueError( + f"loss_top_k must be 1 when using use_teacher_server=True with beta>0 " + f"(got loss_top_k={self.loss_top_k}). Mixed forward/reverse distillation with an external teacher " + "is only implemented for top-1 support." + ) + if self.reverse_kl_top_1_mode != "sampled" and (self.beta == 0 or self.loss_top_k != 1): + warnings.warn( + f"reverse_kl_top_1_mode='{self.reverse_kl_top_1_mode}' has no effect when beta={self.beta} " + f"and loss_top_k={self.loss_top_k}. 
It only applies when beta > 0 and loss_top_k == 1.", + UserWarning, + stacklevel=2, + ) if self.num_generations > 1 and self.lmbda < 1.0: warnings.warn( diff --git a/trl/experimental/distillation/distillation_trainer.py b/trl/experimental/distillation/distillation_trainer.py index a6997527aaf..5564f42a2fa 100644 --- a/trl/experimental/distillation/distillation_trainer.py +++ b/trl/experimental/distillation/distillation_trainer.py @@ -453,11 +453,13 @@ def __init__( # ── Teacher model setup ── self.teacher_client = None + self.use_teacher_server = args.use_teacher_server + self.teacher_model_server_url = args.teacher_model_server_url self._local_teacher_tokenizer_matches_student = True - if args.teacher_model_server_url is not None: + if self.use_teacher_server: from ...generation.vllm_client import VLLMClient - self.teacher_client = VLLMClient(base_url=args.teacher_model_server_url, connection_timeout=60.0) + self.teacher_client = VLLMClient(base_url=self.teacher_model_server_url, connection_timeout=60.0) teacher_model = None elif teacher_model is not None: if args.teacher_model_init_kwargs is not None and not isinstance(teacher_model, str): @@ -552,6 +554,7 @@ def __init__( self.temperature = args.temperature self.top_p = args.top_p self.num_generations = args.num_generations + self.reverse_kl_top_1_mode = args.reverse_kl_top_1_mode self.loss_top_k = args.loss_top_k self.loss_add_tail = args.loss_add_tail @@ -644,7 +647,7 @@ def _raise_if_local_teacher_tokenizer_mismatch(self) -> None: if self.teacher_model is not None and not self._local_teacher_tokenizer_matches_student: raise ValueError( "DistillationTrainer's built-in local-teacher loss only supports student/teacher pairs that use " - "the same tokenizer. Use a same-tokenizer local teacher, use `teacher_model_server_url`, or " + "the same tokenizer. Use a same-tokenizer local teacher, set `use_teacher_server=True`, or " "override the local teacher loss path in a subclass." ) @@ -1013,6 +1016,28 @@ def _build_sequence_batch( # Loss computation # ────────────────────────────────────────────────────────────────────── + @staticmethod + def _reduce_divergence_loss(jsd, labels=None, reduction="batchmean"): + """Reduce a per-token divergence tensor using the trainer's label mask semantics.""" + mask = None + if labels is not None: + mask = labels != -100 + jsd = jsd[mask] + + if reduction == "batchmean": + if labels is not None: + num_tokens = mask.sum() + if num_tokens == 0: + return jsd.sum() * 0.0 # no completion tokens — return zero-grad scalar + return jsd.sum() / num_tokens + return jsd.sum() / jsd.size(0) + elif reduction == "sum": + return jsd.sum() + elif reduction == "mean": + return jsd.mean() + else: + return jsd + @staticmethod def generalized_jsd_loss( student_logits, @@ -1096,24 +1121,90 @@ def generalized_jsd_loss( teacher_log_probs = F.log_softmax(teacher_logits, dim=-1) jsd = _jsd_divergence(student_log_probs, teacher_log_probs, beta, support_mask) + return DistillationTrainer._reduce_divergence_loss(jsd, labels=labels, reduction=reduction) - if labels is not None: - mask = labels != -100 - jsd = jsd[mask] + def _get_reverse_kl_top_1_tokens( + self, student_scores: torch.Tensor, completion_tokens: torch.Tensor + ) -> torch.Tensor: + """Return the reverse-KL top-1 token IDs for the mixed top-1 loss path. 
- if reduction == "batchmean": - if labels is not None: - num_tokens = mask.sum() - if num_tokens == 0: - return jsd.sum() * 0.0 # no completion tokens — return zero-grad scalar - return jsd.sum() / num_tokens - return jsd.sum() / jsd.size(0) - elif reduction == "sum": - return jsd.sum() - elif reduction == "mean": - return jsd.mean() + Args: + student_scores: Any (B, T, V) tensor whose argmax selects the student's top token + (logits or log-probs — both are order-preserving). + completion_tokens: (B, T) actual token IDs in the completion. + """ + if self.reverse_kl_top_1_mode == "argmax": + return student_scores.argmax(dim=-1) + return completion_tokens + + def _compute_sparse_top_1_divergence_loss( + self, + student_log_probs: torch.Tensor, + teacher_top1_token_ids: torch.Tensor, + teacher_top1_logprobs: torch.Tensor, + reverse_token_ids: torch.Tensor, + reverse_teacher_logprobs: torch.Tensor, + labels: torch.Tensor, + ) -> torch.Tensor: + """Compute exact generalized JSD/KL on top-1 support for the mixed beta>0 path.""" + neg_inf = torch.full((), float("-inf"), dtype=student_log_probs.dtype, device=student_log_probs.device) + + if self.beta == 1: + support = reverse_token_ids.unsqueeze(-1) + support_mask = torch.ones_like(support, dtype=torch.bool) + teacher_support_logprobs = reverse_teacher_logprobs.unsqueeze(-1) else: - return jsd + teacher_support = teacher_top1_token_ids.unsqueeze(-1) + reverse_support = reverse_token_ids.unsqueeze(-1) + support = torch.cat([teacher_support, reverse_support], dim=-1) + support_mask = torch.ones_like(support, dtype=torch.bool) + support_mask[..., 1] = support[..., 1] != support[..., 0] + teacher_support_logprobs = torch.stack([teacher_top1_logprobs, reverse_teacher_logprobs], dim=-1) + support = torch.where(support_mask, support, torch.zeros_like(support)) + + student_support_logprobs = student_log_probs.gather(-1, support) + student_support_logprobs = torch.where(support_mask, student_support_logprobs, neg_inf) + teacher_support_logprobs = torch.where(support_mask, teacher_support_logprobs, neg_inf) + + if self.loss_add_tail: + base_support_mask = support_mask + student_sparse_log_probs, support_mask = _add_tail_bucket(student_support_logprobs, base_support_mask) + teacher_sparse_log_probs, _ = _add_tail_bucket(teacher_support_logprobs, base_support_mask) + else: + student_sparse_log_probs = student_support_logprobs - torch.logsumexp( + student_support_logprobs, dim=-1, keepdim=True + ) + teacher_sparse_log_probs = teacher_support_logprobs - torch.logsumexp( + teacher_support_logprobs, dim=-1, keepdim=True + ) + + jsd = _jsd_divergence(student_sparse_log_probs, teacher_sparse_log_probs, self.beta, support_mask) + return self._reduce_divergence_loss(jsd, labels=labels, reduction="batchmean") + + def _compute_local_sparse_top_1_divergence_loss( + self, + student_logits: torch.Tensor, + teacher_logits: torch.Tensor, + completion_tokens: torch.Tensor, + labels: torch.Tensor, + ) -> torch.Tensor: + """Compute the mixed top-1 loss for a local teacher using gathered full-logit probabilities.""" + student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1) + teacher_log_probs = F.log_softmax(teacher_logits / self.temperature, dim=-1) + + teacher_top1_token_ids = teacher_logits.argmax(dim=-1) + teacher_top1_logprobs = teacher_log_probs.gather(dim=-1, index=teacher_top1_token_ids.unsqueeze(-1)).squeeze(-1) + reverse_token_ids = self._get_reverse_kl_top_1_tokens(student_logits, completion_tokens) + reverse_teacher_logprobs = 
teacher_log_probs.gather(dim=-1, index=reverse_token_ids.unsqueeze(-1)).squeeze(-1) + + return self._compute_sparse_top_1_divergence_loss( + student_log_probs=student_log_probs, + teacher_top1_token_ids=teacher_top1_token_ids, + teacher_top1_logprobs=teacher_top1_logprobs, + reverse_token_ids=reverse_token_ids, + reverse_teacher_logprobs=reverse_teacher_logprobs, + labels=labels, + ) def _get_teacher_logits(self, inputs: dict[str, torch.Tensor | Any]) -> torch.Tensor: """Get teacher logits — dispatches between local model and external server.""" @@ -1124,9 +1215,9 @@ def _get_teacher_logits(self, inputs: dict[str, torch.Tensor | Any]) -> torch.Te input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], ).logits - elif self.teacher_client is not None: + elif self.use_teacher_server: raise NotImplementedError( - "Fetching full teacher logits from `teacher_model_server_url` is not supported. " + "Fetching full teacher logits with use_teacher_server=True is not supported. " "Server-backed distillation only supports per-token logprobs via " "`_get_teacher_token_logprobs_from_server`." ) @@ -1158,11 +1249,9 @@ def _get_teacher_token_logprobs_from_server( labels=inputs.get("labels"), ) - # The external-teacher path can use the teacher's configured top-k support for - # pure forward KL. Once a reverse component is present, the reverse term only - # has access to the realized token's logprob, so we also limit the forward term - # to top-1 teacher support to keep the mixed objective balanced. - requested_top_k = self.loss_top_k if self.beta == 0 else 1 + # The pure forward server path can use the requested teacher top-k support. + # When beta > 0, config validation restricts the server-backed path to top-1. + requested_top_k = self.loss_top_k result = self.teacher_client.get_sequence_logprobs( sequences=sequences, prompt_lengths=prompt_lengths, @@ -1221,23 +1310,14 @@ def _topk_to_tensor(key, k, np_dtype, fill): "topk_token_ids": _topk_to_tensor("logprob_token_ids", K, np.int64, 0), } - def _compute_server_divergence_loss( + def _compute_server_sparse_top_1_divergence_loss( self, teacher_result: dict[str, torch.Tensor], student_log_probs: torch.Tensor, completion_tokens: torch.Tensor, labels: torch.Tensor, ) -> torch.Tensor: - """Compute sparse server-side forward/reverse KL surrogates using teacher logprobs from the server. - - The forward KL term uses the teacher's top-k support for pure forward KL and top-1 teacher support whenever a - reverse component is present, optionally collapsed with a tail bucket. The reverse KL term uses the actual - completion token only, because the server cannot provide teacher logprobs at arbitrary student-selected token - IDs. When `self.loss_add_tail` is enabled, the reverse term becomes a two-bucket KL over `{actual_token, - residual_tail}`. This matches `generalized_jsd_loss(..., beta=1, top_k=1, add_tail=True)` when the sampled - token is the same token used in the top-1 support and the logprob values agree. For `0 < beta < 1`, this method - still returns a weighted combination of the forward and reverse surrogates rather than the exact generalized - JSD used by the local-teacher path. + """Compute exact sparse top-1 generalized JSD/KL from server-provided teacher logprobs. Args: teacher_result: dict with ``actual_logprobs`` (B, T), ``topk_logprobs`` (B, T, K), @@ -1246,37 +1326,11 @@ def _compute_server_divergence_loss( completion_tokens: (B, T) actual token IDs in the completion. labels: (B, T) with -100 for positions to ignore. 
""" - topk_teacher_lps = teacher_result["topk_logprobs"] # (B, T, K) - topk_token_ids = teacher_result["topk_token_ids"] # (B, T, K) + topk_teacher_lps = teacher_result["topk_logprobs"] # (B, T, 1) + topk_token_ids = teacher_result["topk_token_ids"] # (B, T, 1) actual_teacher_lps = teacher_result["actual_logprobs"] # (B, T) - - # ── Forward KL term: sum over teacher's top-k tokens ── - # Gather student logprobs for each of the teacher's top-k tokens. - student_topk_lps = student_log_probs.gather(dim=-1, index=topk_token_ids) # (B, T, K) - - # Mask out -inf padding (positions where top-k slot was not filled). - valid = topk_teacher_lps > float("-inf") - if self.loss_add_tail: - neg_inf = torch.full((), float("-inf"), dtype=student_log_probs.dtype, device=student_log_probs.device) - student_topk_lps = torch.where(valid, student_topk_lps, neg_inf) - teacher_topk_lps = torch.where(valid, topk_teacher_lps, neg_inf) - forward_student_log_probs, forward_support_mask = _add_tail_bucket(student_topk_lps, valid) - forward_teacher_log_probs, _ = _add_tail_bucket(teacher_topk_lps, valid) - fwd_per_token = _jsd_divergence( - forward_student_log_probs, - forward_teacher_log_probs, - beta=0.0, - support_mask=forward_support_mask, - ).sum(dim=-1) - else: - fwd_per_k = torch.exp(topk_teacher_lps) * (topk_teacher_lps - student_topk_lps) # (B, T, K) - fwd_per_token = (fwd_per_k * valid).sum(dim=-1) # (B, T) - - # ── Reverse KL term: actual token only ── - student_actual_lps = student_log_probs.gather(dim=-1, index=completion_tokens.unsqueeze(-1)).squeeze( - -1 - ) # (B, T) required = labels != -100 + missing_actual = required & ~torch.isfinite(actual_teacher_lps) if missing_actual.any(): missing_count = int(missing_actual.sum().item()) @@ -1285,35 +1339,70 @@ def _compute_server_divergence_loss( "Teacher server is missing actual-token logprobs for required reverse-KL positions: " f"{missing_count}/{total_required}." ) + if self.beta < 1: + teacher_top1_logprobs = topk_teacher_lps.squeeze(-1) + missing_top1 = required & ~torch.isfinite(teacher_top1_logprobs) + if missing_top1.any(): + missing_count = int(missing_top1.sum().item()) + total_required = int(required.sum().item()) + raise ValueError( + "Teacher server is missing top-1 logprobs for required forward-KL positions: " + f"{missing_count}/{total_required}." + ) + + # Server path only supports "sampled" mode — config validation enforces this, but we guard + # explicitly so future relaxations of the config check don't silently change behaviour. + reverse_token_ids = self._get_reverse_kl_top_1_tokens(student_log_probs, completion_tokens) + return self._compute_sparse_top_1_divergence_loss( + student_log_probs=student_log_probs, + teacher_top1_token_ids=topk_token_ids.squeeze(-1), + teacher_top1_logprobs=topk_teacher_lps.squeeze(-1), + reverse_token_ids=reverse_token_ids, + reverse_teacher_logprobs=actual_teacher_lps, + labels=labels, + ) + + def _compute_server_forward_kl_loss( + self, + teacher_result: dict[str, torch.Tensor], + student_log_probs: torch.Tensor, + labels: torch.Tensor, + ) -> torch.Tensor: + """Compute sparse forward KL from server-provided teacher top-k logprobs (beta==0 path). + + Args: + teacher_result: dict with ``topk_logprobs`` (B, T, K) and ``topk_token_ids`` (B, T, K). + student_log_probs: (B, T, V) student log-softmax over vocabulary. + labels: (B, T) with -100 for positions to ignore. 
+ """ + teacher_topk_logprobs = teacher_result["topk_logprobs"] + teacher_topk_token_ids = teacher_result["topk_token_ids"] + valid = teacher_topk_logprobs > float("-inf") + neg_inf = torch.full((), float("-inf"), dtype=student_log_probs.dtype, device=student_log_probs.device) + student_topk_logprobs = student_log_probs.gather(dim=-1, index=teacher_topk_token_ids) + student_topk_logprobs = torch.where(valid, student_topk_logprobs, neg_inf) + teacher_topk_logprobs = torch.where(valid, teacher_topk_logprobs, neg_inf) + if self.loss_add_tail: - rev_student_lps = student_actual_lps.unsqueeze(-1) # (B, T, 1) - rev_teacher_lps = actual_teacher_lps.unsqueeze(-1) # (B, T, 1) - rev_valid = torch.ones_like(rev_student_lps, dtype=torch.bool) - rev_student_log_probs, rev_support_mask = _add_tail_bucket(rev_student_lps, rev_valid) - rev_teacher_log_probs, _ = _add_tail_bucket(rev_teacher_lps, rev_valid) - rev_per_token = _jsd_divergence( - rev_student_log_probs, - rev_teacher_log_probs, - beta=1.0, - support_mask=rev_support_mask, - ).sum(dim=-1) - else: - student_actual_ps = torch.exp(student_actual_lps) - rev_per_token = student_actual_ps * (student_actual_lps - actual_teacher_lps) # (B, T) - - # ── Combine according to beta ── - if self.beta == 0: - loss_per_token = fwd_per_token - elif self.beta == 1: - loss_per_token = rev_per_token + base_support_mask = valid + student_sparse_log_probs, support_mask = _add_tail_bucket(student_topk_logprobs, base_support_mask) + teacher_sparse_log_probs, _ = _add_tail_bucket(teacher_topk_logprobs, base_support_mask) else: - loss_per_token = (1 - self.beta) * fwd_per_token + self.beta * rev_per_token + support_mask = valid + student_sparse_log_probs = student_topk_logprobs - torch.logsumexp( + student_topk_logprobs, dim=-1, keepdim=True + ) + teacher_sparse_log_probs = teacher_topk_logprobs - torch.logsumexp( + teacher_topk_logprobs, dim=-1, keepdim=True + ) - mask = labels != -100 - num_tokens = mask.sum() - if num_tokens == 0: - return loss_per_token.sum() * 0.0 # no completion tokens — return zero-grad scalar - return loss_per_token[mask].sum() / num_tokens + jsd = _jsd_divergence( + student_sparse_log_probs, + teacher_sparse_log_probs, + beta=0.0, + support_mask=support_mask, + ) + return self._reduce_divergence_loss(jsd, labels=labels, reduction="batchmean") def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): self._raise_if_local_teacher_tokenizer_mismatch() @@ -1329,8 +1418,9 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N ) prompt_length = self._compute_prompt_length(inputs) labels = inputs["labels"][:, prompt_length:] + completion_tokens = inputs["input_ids"][:, prompt_length:] - if self.teacher_client is not None: + if self.use_teacher_server: # Server path: token-level divergence using teacher logprobs. 
# The server returns: # actual_logprobs – (B, T) teacher log p(x_actual) (for reverse KL) @@ -1340,33 +1430,46 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N student_logits = student_outputs.logits[:, prompt_length - 1 : -1, :] student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1) - completion_tokens = inputs["input_ids"][:, prompt_length:] comp_len = teacher_result["actual_logprobs"].shape[1] completion_tokens = completion_tokens[:, :comp_len] trimmed_labels = labels[:, :comp_len] - loss = self._compute_server_divergence_loss( - teacher_result=teacher_result, - student_log_probs=student_log_probs[:, :comp_len, :], - completion_tokens=completion_tokens, - labels=trimmed_labels, - ) + if self.beta > 0: + loss = self._compute_server_sparse_top_1_divergence_loss( + teacher_result=teacher_result, + student_log_probs=student_log_probs[:, :comp_len, :], + completion_tokens=completion_tokens, + labels=trimmed_labels, + ) + else: + loss = self._compute_server_forward_kl_loss( + teacher_result=teacher_result, + student_log_probs=student_log_probs[:, :comp_len, :], + labels=trimmed_labels, + ) else: - # Local teacher: full-vocabulary generalized JSD + # Local teacher: exact full-vocabulary loss except for the shared mixed top-1 path. teacher_logits = self._get_teacher_logits(inputs) student_logits = student_outputs.logits[:, prompt_length - 1 : -1, :] teacher_logits = teacher_logits[:, prompt_length - 1 : -1, :] - - loss = self.generalized_jsd_loss( - student_logits=student_logits, - teacher_logits=teacher_logits, - labels=labels, - beta=self.beta, - temperature=self.temperature, - top_k=self.loss_top_k, - add_tail=self.loss_add_tail, - ) + if self.beta > 0 and self.loss_top_k == 1: + loss = self._compute_local_sparse_top_1_divergence_loss( + student_logits=student_logits, + teacher_logits=teacher_logits, + completion_tokens=completion_tokens, + labels=labels, + ) + else: + loss = self.generalized_jsd_loss( + student_logits=student_logits, + teacher_logits=teacher_logits, + labels=labels, + beta=self.beta, + temperature=self.temperature, + top_k=self.loss_top_k, + add_tail=self.loss_add_tail, + ) return (loss, student_outputs) if return_outputs else loss From 90c6e810b50359cfa48656023ff3b176cd76788b Mon Sep 17 00:00:00 2001 From: cmpatino Date: Tue, 7 Apr 2026 17:26:33 +0200 Subject: [PATCH 36/37] Run precommit --- trl/experimental/distillation/distillation_config.py | 4 +--- trl/experimental/distillation/distillation_trainer.py | 4 +++- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/trl/experimental/distillation/distillation_config.py b/trl/experimental/distillation/distillation_config.py index bb379a350d9..d050665f7aa 100644 --- a/trl/experimental/distillation/distillation_config.py +++ b/trl/experimental/distillation/distillation_config.py @@ -229,9 +229,7 @@ class DistillationConfig(_BaseConfig): # Teacher model (external vLLM server) use_teacher_server: bool = field( default=False, - metadata={ - "help": "Whether to use an external vLLM teacher server instead of a local teacher model." 
- }, + metadata={"help": "Whether to use an external vLLM teacher server instead of a local teacher model."}, ) teacher_model_server_url: str | None = field( default=None, diff --git a/trl/experimental/distillation/distillation_trainer.py b/trl/experimental/distillation/distillation_trainer.py index 5564f42a2fa..df7837ff433 100644 --- a/trl/experimental/distillation/distillation_trainer.py +++ b/trl/experimental/distillation/distillation_trainer.py @@ -1193,7 +1193,9 @@ def _compute_local_sparse_top_1_divergence_loss( teacher_log_probs = F.log_softmax(teacher_logits / self.temperature, dim=-1) teacher_top1_token_ids = teacher_logits.argmax(dim=-1) - teacher_top1_logprobs = teacher_log_probs.gather(dim=-1, index=teacher_top1_token_ids.unsqueeze(-1)).squeeze(-1) + teacher_top1_logprobs = teacher_log_probs.gather(dim=-1, index=teacher_top1_token_ids.unsqueeze(-1)).squeeze( + -1 + ) reverse_token_ids = self._get_reverse_kl_top_1_tokens(student_logits, completion_tokens) reverse_teacher_logprobs = teacher_log_probs.gather(dim=-1, index=reverse_token_ids.unsqueeze(-1)).squeeze(-1) From 864ba1116b4bc89bb6969cb24a1341075b7b6012 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Thu, 9 Apr 2026 14:49:08 +0200 Subject: [PATCH 37/37] Address latest comments --- trl/experimental/distillation/distillation_config.py | 4 ++-- trl/experimental/distillation/distillation_trainer.py | 2 +- trl/scripts/vllm_serve.py | 9 ++++++++- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/trl/experimental/distillation/distillation_config.py b/trl/experimental/distillation/distillation_config.py index d050665f7aa..8e1386f1134 100644 --- a/trl/experimental/distillation/distillation_config.py +++ b/trl/experimental/distillation/distillation_config.py @@ -201,8 +201,8 @@ class DistillationConfig(_BaseConfig): default=None, metadata={ "help": "Maximum number of tokens for the prompt. If None, auto-computed as " - "max_length - max_completion_length. Prompts are truncated from the left to preserve " - "the most recent context near the generation point." + "max_length - max_completion_length. Prompts are truncated according to the " + "tokenizer's truncation_side setting." 
}, ) disable_dropout: bool = field( diff --git a/trl/experimental/distillation/distillation_trainer.py b/trl/experimental/distillation/distillation_trainer.py index df7837ff433..479e03e8b71 100644 --- a/trl/experimental/distillation/distillation_trainer.py +++ b/trl/experimental/distillation/distillation_trainer.py @@ -240,7 +240,7 @@ def __call__(self, examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]: has_completion = len(messages) > 1 and messages[-1].get("role") == "assistant" prompt_messages = messages[:-1] if has_completion else messages - # Tokenize prompt with its own budget (truncate from the left to keep recent context) + # Tokenize prompt with its own budget using the tokenizer's truncation side formatted_prompt = self.tokenizer.apply_chat_template( prompt_messages, tokenize=False, add_generation_prompt=True ) diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index 4341d22c129..573a8ee60cd 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -784,7 +784,7 @@ async def _logprob_batcher(): future.set_result((outputs_slice, prompt_lengths, top_logprobs, response_format)) except Exception as e: # Signal error to all waiting requests in this execution-parameter group - for _, _, _, future in items: + for *_, future in items: if not future.done(): future.set_exception(e) except Exception as e: @@ -930,6 +930,13 @@ async def get_sequence_logprobs(request: SequenceLogprobsRequest): if len(request.sequences) != len(request.prompt_lengths): raise ValueError("sequences and prompt_lengths must have the same length.") + for i, (seq, pl) in enumerate(zip(request.sequences, request.prompt_lengths, strict=True)): + if pl < 0 or pl > len(seq): + raise ValueError( + f"Sequence {i} has prompt_length={pl} which is out of range [0, {len(seq)}]. " + f"prompt_length must be between 0 and the sequence length inclusive." + ) + # Validate sequence lengths against max_model_len to prevent worker OOM crashes if _max_model_len: for i, seq in enumerate(request.sequences):
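To make the loss changes in this series concrete, here is a minimal sketch of the top-1 support construction behind `_compute_sparse_top_1_divergence_loss`. The `add_tail_bucket` helper is an assumed stand-in for TRL's `_add_tail_bucket` (appending one bucket that holds the probability mass outside the support), and the tensors are random placeholders:

```python
import torch
import torch.nn.functional as F


def add_tail_bucket(logprobs, mask):
    # Assumed semantics: append log(1 - sum of on-support probability) as an
    # extra slot so the sparse distribution sums to 1. The real helper may differ.
    masked = torch.where(mask, logprobs, torch.full_like(logprobs, float("-inf")))
    mass = masked.exp().sum(dim=-1, keepdim=True).clamp(max=1 - 1e-6)
    tail = torch.log1p(-mass)
    out_mask = torch.cat([mask, torch.ones_like(tail, dtype=torch.bool)], dim=-1)
    return torch.cat([masked, tail], dim=-1), out_mask


B, T, V = 2, 4, 32
student_log_probs = F.log_softmax(torch.randn(B, T, V), dim=-1)
teacher_log_probs = F.log_softmax(torch.randn(B, T, V), dim=-1)
completion_tokens = torch.randint(0, V, (B, T))  # "sampled" reverse-KL tokens

teacher_top1 = teacher_log_probs.argmax(dim=-1)  # forward-KL support token
support = torch.stack([teacher_top1, completion_tokens], dim=-1)  # (B, T, 2)
support_mask = torch.ones_like(support, dtype=torch.bool)
support_mask[..., 1] = support[..., 1] != support[..., 0]  # dedup equal slots
safe_support = torch.where(support_mask, support, torch.zeros_like(support))

neg_inf = float("-inf")
student_sup = student_log_probs.gather(-1, safe_support)
student_sup = torch.where(support_mask, student_sup, torch.full_like(student_sup, neg_inf))
teacher_sup = teacher_log_probs.gather(-1, safe_support)
teacher_sup = torch.where(support_mask, teacher_sup, torch.full_like(teacher_sup, neg_inf))

# Renormalize over {teacher top-1, sampled token, tail}; a generalized JSD on
# these two sparse distributions then yields the mixed beta > 0 top-1 loss.
student_sparse, sparse_mask = add_tail_bucket(student_sup, support_mask)
teacher_sparse, _ = add_tail_bucket(teacher_sup, support_mask)
```

Deduplicating the second slot when the sampled token equals the teacher's top-1 token keeps the support a proper set, which is why the trainer masks that slot rather than counting the token twice.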
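Similarly, a small sketch of the base64 round-trip implied by the binary response format documented in the `get_sequence_logprobs` docstring (float32/int32 arrays plus a `shape` field). The helper names here are hypothetical, and the client's actual `_decode_binary_logprobs` may decode differently:

```python
import base64

import numpy as np


def encode_array(arr: np.ndarray) -> str:
    # Raw array bytes, base64-encoded for transport inside a JSON body.
    return base64.b64encode(arr.tobytes()).decode("ascii")


def decode_array(b64: str, dtype, shape) -> np.ndarray:
    return np.frombuffer(base64.b64decode(b64), dtype=dtype).reshape(shape)


logprobs = np.random.randn(2, 3, 4).astype(np.float32)  # (B, T, top_k)
payload = {"logprobs_b64": encode_array(logprobs), "shape": list(logprobs.shape)}
decoded = decode_array(payload["logprobs_b64"], np.float32, payload["shape"])
assert np.array_equal(decoded, logprobs)  # lossless round-trip
```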