diff --git a/arctic_training/__init__.py b/arctic_training/__init__.py index a425f61d..1b5dec8a 100644 --- a/arctic_training/__init__.py +++ b/arctic_training/__init__.py @@ -25,6 +25,7 @@ from arctic_training.config.data import DataConfig from arctic_training.config.logger import LoggerConfig from arctic_training.config.model import ModelConfig +from arctic_training.config.on_policy_distillation import OnPolicyDistillationTrainerConfig from arctic_training.config.optimizer import OptimizerConfig from arctic_training.config.scheduler import SchedulerConfig from arctic_training.config.tokenizer import TokenizerConfig @@ -36,6 +37,7 @@ from arctic_training.data.hf_causal_source import HFDataSourceCausal from arctic_training.data.hf_instruct_source import HFDataSourceInstruct from arctic_training.data.hf_source import HFDataSource +from arctic_training.data.on_policy_distillation_factory import OnPolicyDistillationDataFactory from arctic_training.data.sft_factory import SFTDataFactory from arctic_training.data.snowflake_source import SnowflakeDataSource from arctic_training.data.source import DataSource @@ -53,5 +55,6 @@ from arctic_training.trainer.causal_trainer import CausalTrainer from arctic_training.trainer.dpo_trainer import DPOTrainer from arctic_training.trainer.dpo_trainer import DPOTrainerConfig +from arctic_training.trainer.on_policy_distillation_trainer import OnPolicyDistillationTrainer from arctic_training.trainer.sft_trainer import SFTTrainer from arctic_training.trainer.trainer import Trainer diff --git a/arctic_training/config/on_policy_distillation.py b/arctic_training/config/on_policy_distillation.py new file mode 100644 index 00000000..d1882848 --- /dev/null +++ b/arctic_training/config/on_policy_distillation.py @@ -0,0 +1,151 @@ +# Copyright 2025 Snowflake Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Configuration for On-Policy Distillation Trainer. + +On-Policy Distillation trains a student model by having it generate its own +trajectories, then using a teacher model to provide per-token supervision via +reverse KL divergence. This contrasts with traditional (off-policy) distillation +where the teacher generates trajectories for the student to imitate. +""" + +from typing import Dict +from typing import Union +from typing import cast + +from pydantic import Field +from pydantic import ValidationInfo +from pydantic import field_validator +from pydantic import model_validator +from typing_extensions import Self + +from arctic_training.config.model import ModelConfig +from arctic_training.config.trainer import TrainerConfig +from arctic_training.config.utils import HumanInt +from arctic_training.registry import get_registered_model_factory + + +class OnPolicyDistillationTrainerConfig(TrainerConfig): + """Configuration for On-Policy Distillation Trainer. + + On-policy distillation trains the student on its own generated trajectories, + with the teacher providing dense per-token feedback via reverse KL divergence. + """ + + teacher_model: ModelConfig + """ + Configuration for the teacher model used in on-policy distillation. + The teacher model provides per-token log probabilities for computing + the reverse KL divergence loss against student-generated trajectories. + """ + + teacher_deepspeed: Dict = {} + """ + DeepSpeed configuration for the teacher model. This is automatically + computed based on the main model's DeepSpeed config and should not + be provided by the user. + """ + + disable_teacher_dropout: bool = True + """ + Whether to disable dropout in the teacher model during training. + Recommended to keep True for stable distillation signal. + """ + + num_rollouts_per_prompt: int = Field(default=4, ge=1) + """ + Number of trajectory samples to generate from the student per prompt. + Higher values provide more diverse on-policy samples but increase compute. + """ + + max_new_tokens: HumanInt = Field(default=2048, ge=1) + """ + Maximum number of new tokens to generate for each student trajectory. + Should be set based on expected response length for the task. + """ + + generation_temperature: float = Field(default=1.0, gt=0.0) + """ + Temperature for student trajectory generation. + Higher values produce more diverse samples but may reduce quality. + """ + + generation_top_p: float = Field(default=1.0, gt=0.0, le=1.0) + """ + Top-p (nucleus) sampling parameter for student generation. + """ + + generation_top_k: int = Field(default=0, ge=0) + """ + Top-k sampling parameter for student generation. 0 means no top-k filtering. + """ + + beta: float = Field(default=1.0, gt=0.0) + """ + Coefficient for the reverse KL divergence loss. + Controls the strength of the distillation signal. + """ + + generation_batch_size: int = Field(default=0, ge=0) + """ + Batch size for trajectory generation. If 0, uses micro_batch_size. + May need to be smaller than micro_batch_size due to memory constraints + during generation. + """ + + @field_validator("teacher_model", mode="before") + @classmethod + def init_teacher_model_config(cls, v: Union[Dict, ModelConfig], info: ValidationInfo) -> ModelConfig: + """Initialize teacher model config from dict or ModelConfig.""" + subconfig = cls._get_subconfig_object( + v=v, + info=info, + get_class_fn=get_registered_model_factory, + attr_name="teacher_model_factory", + ) + return cast(ModelConfig, subconfig) + + @model_validator(mode="after") + def build_teacher_deepspeed_config(self) -> Self: + """Build DeepSpeed config for teacher model.""" + if len(self.teacher_deepspeed) != 0: + raise ValueError( + "Teacher model DeepSpeed config is computed based on the main model " + "DeepSpeed config and should not be passed by the user." + ) + + teacher_deepspeed = dict( + train_batch_size=self.deepspeed["train_batch_size"], + train_micro_batch_size_per_gpu=self.deepspeed["train_micro_batch_size_per_gpu"], + steps_per_print=self.deepspeed["steps_per_print"], + zero_optimization=dict( + stage=3 if self.deepspeed["zero_optimization"]["stage"] == 3 else 0, + stage3_param_persistence_threshold=1e4, + memory_efficient_linear=False, + ), + bfloat16=dict(enabled=True), + gradient_clipping=1.0, + prescale_gradients=False, + wall_clock_breakdown=False, + ) + self.teacher_deepspeed = teacher_deepspeed + return self + + @model_validator(mode="after") + def set_generation_batch_size(self) -> Self: + """Set generation batch size to micro_batch_size if not specified.""" + if self.generation_batch_size == 0: + self.generation_batch_size = self.micro_batch_size + return self diff --git a/arctic_training/data/on_policy_distillation_factory.py b/arctic_training/data/on_policy_distillation_factory.py new file mode 100644 index 00000000..5e67f7f4 --- /dev/null +++ b/arctic_training/data/on_policy_distillation_factory.py @@ -0,0 +1,319 @@ +# Copyright 2025 Snowflake Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data factory for On-Policy Distillation. + +This module provides a data factory that loads prompts for on-policy distillation +training. Unlike SFT which loads full conversations, this factory loads only the +prompts that the student model will use to generate trajectories. +""" + +from typing import Dict +from typing import List + +import torch +from pydantic import Field +from torch.utils.data import DataLoader +from transformers import BatchEncoding +from transformers import PreTrainedTokenizerBase + +from arctic_training.config.data import DataConfig +from arctic_training.config.utils import HumanInt +from arctic_training.data.factory import DataFactory +from arctic_training.data.hf_instruct_source import HFDataSourceInstruct +from arctic_training.data.utils import DatasetType + + +class OnPolicyDistillationDataConfig(DataConfig): + """Configuration for On-Policy Distillation data loading.""" + + div_length: HumanInt = 256 + """The number that the length of the prompt sequence should be divisible by.""" + + max_prompt_length: HumanInt = Field(default=0, ge=0) + """ + Maximum length of the prompt. If 0, uses max_length. + Prompts longer than this will be truncated. + """ + + filter_long_prompts: bool = True + """Whether to filter out prompts longer than max_prompt_length.""" + + include_system_prompt: bool = True + """Whether to include system prompts in the tokenized prompt.""" + + +def pad_prompts( + tensors: List[torch.Tensor], + padding_value: int = 0, + padding_side: str = "left", + divisible_by: int = 256, +) -> torch.Tensor: + """Pad a list of 1D tensors to the same length. + + For generation, we typically want left-padding so that the generated + tokens are appended to the right. + + Args: + tensors: List of 1D tensors to pad + padding_value: Value to use for padding + padding_side: Side to pad on ('left' or 'right') + divisible_by: Pad to length divisible by this value + + Returns: + Padded tensor of shape (batch_size, padded_length) + """ + import math + + max_len = max(t.size(0) for t in tensors) + padded_len = math.ceil(max_len / divisible_by) * divisible_by + + output = torch.full( + (len(tensors), padded_len), + padding_value, + dtype=tensors[0].dtype, + ) + + for i, t in enumerate(tensors): + if padding_side == "left": + output[i, padded_len - t.size(0) :] = t + else: + output[i, : t.size(0)] = t + + return output + + +class DataCollatorForOnPolicyDistillation: + """Data collator for on-policy distillation training. + + This collator pads prompts with left-padding (for generation) and + tracks the original prompt lengths for later use. + """ + + def __init__(self, tokenizer: PreTrainedTokenizerBase, config: OnPolicyDistillationDataConfig): + self.tokenizer = tokenizer + self.config = config + + def __call__(self, instances: List[Dict]) -> Dict[str, torch.Tensor]: + input_ids = [torch.tensor(example["input_ids"]) for example in instances] + prompt_lengths = [len(example["input_ids"]) for example in instances] + + # Left-pad for generation + padded_input_ids = pad_prompts( + input_ids, + padding_value=self.tokenizer.pad_token_id or self.tokenizer.eos_token_id, + padding_side="left", + divisible_by=self.config.div_length, + ) + + # Create attention mask (1 for real tokens, 0 for padding) + attention_mask = (padded_input_ids != (self.tokenizer.pad_token_id or self.tokenizer.eos_token_id)).long() + + return { + "input_ids": padded_input_ids, + "attention_mask": attention_mask, + "prompt_lengths": torch.tensor(prompt_lengths), + } + + +def filter_by_prompt_length(self, dataset: DatasetType) -> DatasetType: + """Filter dataset to remove prompts longer than max_prompt_length.""" + if not self.config.filter_long_prompts: + return dataset + + max_len = self.config.max_prompt_length or self.config.max_length + + dataset = dataset.filter( + lambda x: len(x["input_ids"]) <= max_len, + num_proc=self.config.num_proc, + desc="Filtering prompts by max length", + ) + + if len(dataset) < 1: + raise ValueError( + f"No data left after filtering by max prompt length {max_len}. " + "Consider increasing max_prompt_length or max_length." + ) + + return dataset + + +class OnPolicyDistillationDataFactory(DataFactory): + """Data factory for On-Policy Distillation training. + + This factory loads prompts (without completions) for on-policy distillation. + The prompts are tokenized and left-padded for efficient batch generation + by the student model. + + Expected input format: + - Dataset with 'messages' column containing conversation turns + - Only user/system messages are used; assistant messages are ignored + """ + + name = "on_policy_distillation" + config: OnPolicyDistillationDataConfig + default_source_cls = HFDataSourceInstruct + + callbacks = [ + ("post-load", filter_by_prompt_length), + ] + + def process(self, dataset: DatasetType) -> DatasetType: + """Process the dataset by tokenizing prompts. + + Extracts the prompt (system + user messages) from each conversation + and tokenizes it for generation. + """ + if "messages" not in dataset.column_names: + raise ValueError("Dataset must have 'messages' column for OnPolicyDistillationDataFactory.") + + dataset = dataset.select_columns(["messages"]) + + return dataset.map( + lambda ex: { + **self.tokenize_prompt( + ex["messages"], + self.tokenizer, + include_system=self.config.include_system_prompt, + ) + }, + remove_columns=dataset.column_names, + num_proc=self.config.num_proc, + desc="Tokenizing prompts", + ) + + @classmethod + def tokenize_prompt( + cls, + messages: List[Dict[str, str]], + tokenizer: PreTrainedTokenizerBase, + include_system: bool = True, + ) -> BatchEncoding: + """Tokenize the prompt portion of a conversation. + + Extracts system and user messages (excluding assistant responses) + and tokenizes them using the chat template. + + Args: + messages: List of message dicts with 'role' and 'content' keys + tokenizer: Tokenizer to use + include_system: Whether to include system messages + + Returns: + BatchEncoding with input_ids for the prompt + """ + # Extract prompt messages (everything before the first assistant response) + prompt_messages = [] + for msg in messages: + role = msg.get("role", "") + if role == "assistant": + break # Stop at first assistant message + if role == "system" and not include_system: + continue + prompt_messages.append(msg) + + # If no prompt messages found, use the first user message + if not prompt_messages: + for msg in messages: + if msg.get("role") == "user": + prompt_messages = [msg] + break + + if not prompt_messages: + raise ValueError("No prompt messages found in conversation") + + # Apply chat template to get the prompt text + # add_generation_prompt=True adds the assistant prefix for generation + prompt_text = tokenizer.apply_chat_template( + conversation=prompt_messages, + tokenize=False, + add_generation_prompt=True, + ) + + # Tokenize the prompt + encoded = tokenizer( + prompt_text, + add_special_tokens=False, + return_attention_mask=False, + ) + + return encoded + + @classmethod + def extract_prompt_and_completion( + cls, + messages: List[Dict[str, str]], + tokenizer: PreTrainedTokenizerBase, + include_system: bool = True, + ) -> Dict[str, List[int]]: + """Extract and tokenize both prompt and completion from messages. + + This is useful for evaluation or when you have ground-truth completions. + + Args: + messages: Full conversation messages + tokenizer: Tokenizer to use + include_system: Whether to include system messages + + Returns: + Dict with 'prompt_ids' and 'completion_ids' + """ + prompt_messages = [] + completion_messages = [] + seen_assistant = False + + for msg in messages: + role = msg.get("role", "") + if role == "assistant": + seen_assistant = True + completion_messages.append(msg) + elif not seen_assistant: + if role == "system" and not include_system: + continue + prompt_messages.append(msg) + + # Tokenize prompt + prompt_text = tokenizer.apply_chat_template( + conversation=prompt_messages, + tokenize=False, + add_generation_prompt=True, + ) + prompt_ids = tokenizer(prompt_text, add_special_tokens=False)["input_ids"] + + # Tokenize completion (if any) + completion_ids = [] + if completion_messages: + # Get full conversation and subtract prompt + full_text = tokenizer.apply_chat_template( + conversation=prompt_messages + completion_messages, + tokenize=False, + add_generation_prompt=False, + ) + full_ids = tokenizer(full_text, add_special_tokens=False)["input_ids"] + completion_ids = full_ids[len(prompt_ids) :] + + return { + "prompt_ids": prompt_ids, + "completion_ids": completion_ids, + } + + def create_dataloader(self, dataset: DatasetType) -> DataLoader: + """Create a DataLoader with the appropriate collator.""" + dataloader = super().create_dataloader(dataset) + dataloader.collate_fn = DataCollatorForOnPolicyDistillation( + tokenizer=self.tokenizer, + config=self.config, + ) + return dataloader diff --git a/arctic_training/metrics.py b/arctic_training/metrics.py index 4e2885e8..8dddb37c 100644 --- a/arctic_training/metrics.py +++ b/arctic_training/metrics.py @@ -152,6 +152,38 @@ def print_summary(self, prefix: str = "train") -> None: losses = gather_object(self.values["loss/eval"], self.trainer.world_size) self.summary_dict["loss/eval"] = sum(losses) / len(losses) + # On-policy distillation metrics (average across ranks) + # Named with / prefix for W&B panel grouping + distillation_metrics = [ + # logprob group + "logprob/student", + "logprob/teacher", + "logprob/gap", + # perplexity group + "perplexity/student", + "perplexity/teacher", + # distill group + "distill/reverse_kl", + "distill/prob_ratio", + "distill/advantage", + # generation group + "generation/avg_length", + # Eval versions (eval/ prefix for W&B grouping) + "eval/logprob_student", + "eval/logprob_teacher", + "eval/logprob_gap", + "eval/perplexity_student", + "eval/perplexity_teacher", + "eval/reverse_kl", + "eval/prob_ratio", + "eval/advantage", + "eval/avg_completion_length", + ] + for metric_name in distillation_metrics: + if metric_name in self.values: + values = gather_object(self.values[metric_name], self.trainer.world_size) + self.summary_dict[metric_name] = sum(values) / len(values) + if "iter_time" in self.values: iter_time_total = sum(gather_object(self.values["iter_time"], self.trainer.world_size)) self.summary_dict["iter_time"] = iter_time_total / self.trainer.world_size diff --git a/arctic_training/trainer/on_policy_distillation_trainer.py b/arctic_training/trainer/on_policy_distillation_trainer.py new file mode 100644 index 00000000..d1e6630e --- /dev/null +++ b/arctic_training/trainer/on_policy_distillation_trainer.py @@ -0,0 +1,520 @@ +# Copyright 2025 Snowflake Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""On-Policy Distillation Trainer. + +This module implements on-policy distillation training where: +1. The student model generates its own trajectories +2. The teacher model provides per-token supervision via reverse KL divergence +3. The student is updated to minimize the reverse KL against the teacher + +This approach combines the on-policy relevance of RL with the dense reward +signal of distillation, enabling efficient training of smaller models. + +Reference: https://thinkingmachines.ai/blog/on-policy-distillation/ +""" + +from typing import Dict +from typing import Tuple +from typing import Union + +import deepspeed +import torch +import torch.nn.functional as F + +from arctic_training.checkpoint.ds_engine import DSCheckpointEngine +from arctic_training.checkpoint.hf_engine import HFCheckpointEngine +from arctic_training.config.on_policy_distillation import OnPolicyDistillationTrainerConfig +from arctic_training.data.on_policy_distillation_factory import OnPolicyDistillationDataFactory +from arctic_training.logging import logger +from arctic_training.model.hf_factory import HFModelFactory +from arctic_training.model.liger_factory import LigerModelFactory +from arctic_training.optimizer.adam_factory import CPUAdamOptimizerFactory +from arctic_training.optimizer.adam_factory import FusedAdamOptimizerFactory +from arctic_training.scheduler.hf_factory import HFSchedulerFactory +from arctic_training.tokenizer.hf_factory import HFTokenizerFactory +from arctic_training.trainer.trainer import Trainer +from arctic_training.trainer.utils import disable_dropout_in_model +from arctic_training.trainer.utils import to_device + + +def init_teacher_model(self: "OnPolicyDistillationTrainer") -> None: + """Initialize the in-memory teacher model for logprob computation. + + This callback is called post-init to load the teacher model and + wrap it with DeepSpeed for efficient inference. + """ + config = self.config + + # Create teacher model using the same factory pattern as DPO's ref_model + teacher_model_factory = config.teacher_model.factory(trainer=self, model_config=config.teacher_model) + self.teacher_model = teacher_model_factory() + + # Wrap with DeepSpeed for efficient inference + self.teacher_model, *_ = deepspeed.initialize( + model=self.teacher_model, + config=config.teacher_deepspeed, + ) + + # Disable dropout for stable distillation signal + if config.disable_teacher_dropout: + disable_dropout_in_model(self.teacher_model) + + logger.info("Teacher model initialized for on-policy distillation") + + +class OnPolicyDistillationTrainer(Trainer): + """Trainer for On-Policy Distillation. + + On-policy distillation trains a student model by having it generate + trajectories and using a teacher model to provide per-token feedback + via reverse KL divergence. + + The loss function is: + L = E[log π_student(x) - log π_teacher(x)] + + where x are tokens sampled from π_student. + + This is "mode-seeking" - the student learns to approximate the teacher's + behavior specifically in the states the student visits, making it robust + to compounding errors in generation. + + Attributes: + teacher_model: In-memory teacher model for computing logprobs + """ + + name = "on_policy_distillation" + config: OnPolicyDistillationTrainerConfig + data_factory: OnPolicyDistillationDataFactory + model_factory: Union[HFModelFactory, LigerModelFactory] + teacher_model_factory: Union[HFModelFactory, LigerModelFactory] + checkpoint_engine: Union[DSCheckpointEngine, HFCheckpointEngine] + optimizer_factory: Union[FusedAdamOptimizerFactory, CPUAdamOptimizerFactory] + scheduler_factory: HFSchedulerFactory + tokenizer_factory: HFTokenizerFactory + + teacher_model: torch.nn.Module + + callbacks = [ + ("post-init", init_teacher_model), + ] + + def generate_trajectories( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + prompt_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate trajectories from the student model. + + Generates `num_rollouts_per_prompt` completions for each prompt to increase + sample diversity and GPU utilization. + + Args: + input_ids: Prompt input IDs (batch_size, prompt_len) + attention_mask: Attention mask for prompts + prompt_lengths: Original prompt lengths before padding + + Returns: + Tuple of: + - generated_ids: Full sequences (batch_size * num_rollouts, seq_len) + - labels: Token IDs with -100 for prompt/padding positions + - attention_mask: Attention mask for generated sequence + """ + num_rollouts = self.config.num_rollouts_per_prompt + + # Repeat each prompt num_rollouts times for multiple completions per prompt + # [p1, p2, p3, p4] with num_rollouts=2 -> [p1, p1, p2, p2, p3, p3, p4, p4] + if num_rollouts > 1: + input_ids = input_ids.repeat_interleave(num_rollouts, dim=0) + attention_mask = attention_mask.repeat_interleave(num_rollouts, dim=0) + prompt_lengths = prompt_lengths.repeat_interleave(num_rollouts, dim=0) + + # Put model in eval mode for generation (no dropout) + self.model.eval() + + with torch.no_grad(): + # Generate using the model + generated_ids = self.model_unwrapped.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_new_tokens=self.config.max_new_tokens, + do_sample=True, + temperature=self.config.generation_temperature, + top_p=self.config.generation_top_p, + top_k=self.config.generation_top_k if self.config.generation_top_k > 0 else None, + pad_token_id=self.tokenizer.pad_token_id or self.tokenizer.eos_token_id, + eos_token_id=self.tokenizer.eos_token_id, + use_cache=True, + ) + + # Put model back in train mode + self.model.train() + + # Create labels tensor with -100 for prompt positions (like TRL) + # This allows using cross_entropy's ignore_index for masking + padded_prompt_len = input_ids.size(1) + gen_seq_len = generated_ids.size(1) + + # Start with all -100 (ignore all) + labels = torch.full_like(generated_ids, -100) + + # Create attention mask for generated sequence + pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id + gen_attention_mask = (generated_ids != pad_token_id).long() + + # Fill in completion tokens (after prompt) using vectorized operations + # Create position indices [1, gen_seq_len] broadcast to [batch, gen_seq_len] + positions = torch.arange(gen_seq_len, device=generated_ids.device).unsqueeze(0) + + # Mask for completion region (after prompt) + in_completion = positions >= padded_prompt_len + + # Find pad and EOS positions in completion region + is_pad = (generated_ids == pad_token_id) & in_completion + is_eos = (generated_ids == self.tokenizer.eos_token_id) & in_completion + + # Find first occurrence of each (use gen_seq_len as "not found" sentinel) + pad_positions = torch.where(is_pad, positions.expand_as(generated_ids), gen_seq_len) + first_pad = pad_positions.min(dim=1).values # [batch_size] + + eos_positions = torch.where(is_eos, positions.expand_as(generated_ids), gen_seq_len) + first_eos = eos_positions.min(dim=1).values # [batch_size] + + # comp_end = min(first_pad, first_eos + 1) - EOS is included, pad is not + comp_end = torch.minimum(first_pad, first_eos + 1) # [batch_size] + + # Create completion mask: True where padded_prompt_len <= position < comp_end + completion_mask = in_completion & (positions < comp_end.unsqueeze(1)) + + # Set labels using the mask (vectorized assignment) + labels = torch.where(completion_mask, generated_ids, labels) + + return generated_ids, labels, gen_attention_mask + + def loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: + """Compute the on-policy distillation loss. + + This is the main loss function that: + 1. Generates trajectories from the student + 2. Computes student and teacher log probabilities + 3. Computes reverse KL divergence + + Uses label masking (-100) for clean handling of variable-length completions. + + Args: + batch: Batch containing input_ids, attention_mask, prompt_lengths + + Returns: + Scalar loss value + """ + batch = to_device(batch, self.device) + + input_ids = batch["input_ids"] + attention_mask = batch["attention_mask"] + prompt_lengths = batch["prompt_lengths"] + + # Step 1: Generate trajectories from student (returns labels with -100 for prompts) + generated_ids, labels, gen_attention_mask = self.generate_trajectories( + input_ids=input_ids, + attention_mask=attention_mask, + prompt_lengths=prompt_lengths, + ) + + # Check if any completions were generated + mask = labels != -100 + num_completion_tokens = mask.sum() + if num_completion_tokens == 0: + logger.warning("No completions generated, returning zero loss") + return torch.tensor(0.0, device=self.device, requires_grad=True) + + # Step 2: Forward pass through student model + student_outputs = self.model( + input_ids=generated_ids, + attention_mask=gen_attention_mask, + use_cache=False, + ) + student_logits = student_outputs.logits + + # Step 3: Forward pass through teacher model (no grad) + self.teacher_model.eval() + with torch.no_grad(): + teacher_outputs = self.teacher_model( + input_ids=generated_ids, + attention_mask=gen_attention_mask, + use_cache=False, + ) + teacher_logits = teacher_outputs.logits + + # Step 4: Compute logprobs using fused cross_entropy (shifted for next-token prediction) + # Shift: logits[t] predicts token at position t+1, so we compare logits[:-1] with labels[1:] + shift_student_logits = student_logits[:, :-1, :].contiguous() + shift_teacher_logits = teacher_logits[:, :-1, :].contiguous() + shift_labels = labels[:, 1:].contiguous() + + # Flatten for cross_entropy + vocab_size = shift_student_logits.size(-1) + flat_student_logits = shift_student_logits.view(-1, vocab_size) + flat_teacher_logits = shift_teacher_logits.view(-1, vocab_size) + flat_labels = shift_labels.view(-1) + + # Compute per-token logprobs using fused cross_entropy + # cross_entropy returns -log_prob, so negate; ignore_index=-100 handles masking + student_logprobs = -F.cross_entropy(flat_student_logits, flat_labels, ignore_index=-100, reduction="none") + with torch.no_grad(): + teacher_logprobs = -F.cross_entropy(flat_teacher_logits, flat_labels, ignore_index=-100, reduction="none") + + # Step 5: Compute on-policy distillation loss using policy gradient + # Reference: https://thinkingmachines.ai/blog/on-policy-distillation/ + # + # The key insight: use -reverse_kl as ADVANTAGE in policy gradient, not as direct loss. + # + # reverse_kl = student_logprob - teacher_logprob + # advantage = -reverse_kl = teacher_logprob - student_logprob + # + # Policy gradient loss: loss = -advantage * log_prob + # = -(teacher_logprob - student_logprob) * student_logprob + # + # This gives correct gradients: + # - When teacher > student (advantage > 0): increase student_logprob + # - When teacher < student (advantage < 0): decrease student_logprob + + shift_mask = flat_labels != -100 + num_tokens = shift_mask.sum() + + if num_tokens > 0: + valid_student_logprobs = student_logprobs[shift_mask] + valid_teacher_logprobs = teacher_logprobs[shift_mask] + + # Advantage = teacher's preference - student's confidence (detached for stable training) + # Positive advantage: teacher likes this token more than student expects + # Negative advantage: student is overconfident relative to teacher + advantage = valid_teacher_logprobs - valid_student_logprobs.detach() + + # Policy gradient loss: maximize student logprob weighted by advantage + # loss = -E[advantage * log p_student] + loss = -(advantage * valid_student_logprobs).mean() * self.config.beta + + # Track reverse KL for monitoring + per_token_kl = valid_student_logprobs - valid_teacher_logprobs + else: + loss = torch.tensor(0.0, device=self.device, requires_grad=True) + per_token_kl = torch.zeros(1, device=self.device) + + # Log detailed metrics for monitoring training progress + with torch.no_grad(): + batch_size = generated_ids.size(0) + avg_completion_len = num_completion_tokens.float() / batch_size + + if num_tokens > 0: + # Log probability metrics (key for understanding training dynamics) + student_logprob_mean = valid_student_logprobs.mean() + teacher_logprob_mean = valid_teacher_logprobs.mean() + logprob_gap = student_logprob_mean - teacher_logprob_mean + + # Group metrics by concept using / prefix for W&B panel grouping + # logprob/ group - raw log probabilities + self.metrics.record("logprob/student", student_logprob_mean.item()) + self.metrics.record("logprob/teacher", teacher_logprob_mean.item()) + self.metrics.record("logprob/gap", logprob_gap.item()) + + # perplexity/ group - easier to interpret + student_ppl = torch.exp(-student_logprob_mean) + teacher_ppl = torch.exp(-teacher_logprob_mean) + self.metrics.record("perplexity/student", student_ppl.item()) + self.metrics.record("perplexity/teacher", teacher_ppl.item()) + + # distill/ group - distillation-specific metrics + mean_kl = per_token_kl.mean() + prob_ratio = torch.exp(logprob_gap) + mean_advantage = (valid_teacher_logprobs - valid_student_logprobs).mean() + self.metrics.record("distill/reverse_kl", mean_kl.item()) + self.metrics.record("distill/prob_ratio", prob_ratio.item()) + self.metrics.record("distill/advantage", mean_advantage.item()) + + # generation/ group + self.metrics.record("generation/avg_length", avg_completion_len.item()) + else: + self.metrics.record("generation/avg_length", 0.0) + + return loss + + def step(self, batch: Dict[str, torch.Tensor]) -> None: + """Execute a single training step. + + Overrides the base step to handle the unique requirements of + on-policy distillation (generation + training). + """ + self.model.train() + + loss = self.loss(batch) + + self.backward(loss) + + def maybe_item(v): + return v.item() if torch.is_tensor(v) else v + + self.metrics.record("loss", maybe_item(loss)) + + self.model.step() + + self.checkpoint() + + # Update step counters + self.global_step = self.model.global_steps + self.global_step_this_run = self.global_step - self.global_step_at_start_this_run + + def evaluate(self) -> None: + """Evaluation loop with detailed metrics for on-policy distillation. + + Note: We intentionally don't use @callback_wrapper here to avoid + the base class's evaluate() being called instead of this override. + + Computes loss and various metrics on the validation set to track: + - Whether student is learning to match teacher + - Generation quality metrics + - KL divergence trends + """ + if self.eval_dataloader is None: + logger.warning("No evaluation dataloader, skipping evaluation") + return + + self.model.eval() + self.teacher_model.eval() + + # Accumulators for metrics + total_loss = 0.0 + total_kl = 0.0 + total_student_logprob = 0.0 + total_teacher_logprob = 0.0 + total_advantage = 0.0 + total_completion_len = 0.0 + total_tokens = 0 + num_batches = 0 + + with torch.no_grad(): + for eval_batch in self.eval_batches: + eval_batch = to_device(eval_batch, self.device) + + input_ids = eval_batch["input_ids"] + attention_mask = eval_batch["attention_mask"] + prompt_lengths = eval_batch["prompt_lengths"] + + # Generate trajectories + generated_ids, labels, gen_attention_mask = self.generate_trajectories( + input_ids=input_ids, + attention_mask=attention_mask, + prompt_lengths=prompt_lengths, + ) + + mask = labels != -100 + num_completion_tokens = mask.sum() + if num_completion_tokens == 0: + continue + + # Forward passes + student_logits = self.model( + input_ids=generated_ids, + attention_mask=gen_attention_mask, + use_cache=False, + ).logits + + teacher_logits = self.teacher_model( + input_ids=generated_ids, + attention_mask=gen_attention_mask, + use_cache=False, + ).logits + + # Compute logprobs + shift_student_logits = student_logits[:, :-1, :].contiguous() + shift_teacher_logits = teacher_logits[:, :-1, :].contiguous() + shift_labels = labels[:, 1:].contiguous() + + vocab_size = shift_student_logits.size(-1) + flat_student_logits = shift_student_logits.view(-1, vocab_size) + flat_teacher_logits = shift_teacher_logits.view(-1, vocab_size) + flat_labels = shift_labels.view(-1) + + student_logprobs = -F.cross_entropy( + flat_student_logits, flat_labels, ignore_index=-100, reduction="none" + ) + teacher_logprobs = -F.cross_entropy( + flat_teacher_logits, flat_labels, ignore_index=-100, reduction="none" + ) + + shift_mask = flat_labels != -100 + num_tokens = shift_mask.sum().item() + + if num_tokens > 0: + valid_student_logprobs = student_logprobs[shift_mask] + valid_teacher_logprobs = teacher_logprobs[shift_mask] + + # Compute policy gradient loss (same as training) + advantage = valid_teacher_logprobs - valid_student_logprobs + batch_loss = -(advantage * valid_student_logprobs).mean() * self.config.beta + total_loss += batch_loss.item() + + # Track reverse KL for monitoring + per_token_kl = valid_student_logprobs - valid_teacher_logprobs + + total_kl += per_token_kl.sum().item() + total_student_logprob += valid_student_logprobs.sum().item() + total_teacher_logprob += valid_teacher_logprobs.sum().item() + total_advantage += advantage.sum().item() + total_tokens += num_tokens + total_completion_len += num_completion_tokens.item() + num_batches += 1 + + # Clean up large tensors to prevent OOM during eval + del student_logits, teacher_logits + del shift_student_logits, shift_teacher_logits + del flat_student_logits, flat_teacher_logits + del student_logprobs, teacher_logprobs + del generated_ids, labels, gen_attention_mask + torch.cuda.empty_cache() + + # Log aggregated metrics + if num_batches > 0 and total_tokens > 0: + avg_loss = total_loss / num_batches + avg_kl = total_kl / total_tokens + avg_student_logprob = total_student_logprob / total_tokens + avg_teacher_logprob = total_teacher_logprob / total_tokens + avg_advantage = total_advantage / total_tokens + avg_completion_len = total_completion_len / ( + num_batches * self.config.micro_batch_size * self.config.num_rollouts_per_prompt + ) + + # Perplexity + student_ppl = torch.exp(torch.tensor(-avg_student_logprob)).item() + teacher_ppl = torch.exp(torch.tensor(-avg_teacher_logprob)).item() + prob_ratio = torch.exp(torch.tensor(avg_student_logprob - avg_teacher_logprob)).item() + + # Record all eval metrics with eval/ prefix (W&B groups by first segment) + self.metrics.record("loss/eval", avg_loss) + self.metrics.record("eval/logprob_student", avg_student_logprob) + self.metrics.record("eval/logprob_teacher", avg_teacher_logprob) + self.metrics.record("eval/logprob_gap", avg_student_logprob - avg_teacher_logprob) + self.metrics.record("eval/perplexity_student", student_ppl) + self.metrics.record("eval/perplexity_teacher", teacher_ppl) + self.metrics.record("eval/reverse_kl", avg_kl) + self.metrics.record("eval/prob_ratio", prob_ratio) + self.metrics.record("eval/advantage", avg_advantage) + self.metrics.record("eval/avg_completion_length", avg_completion_len) + + logger.info( + f"Eval | loss: {avg_loss:.4f} | kl: {avg_kl:.4f} | adv: {avg_advantage:.4f} | " + f"student_ppl: {student_ppl:.2f} | teacher_ppl: {teacher_ppl:.2f} | " + f"prob_ratio: {prob_ratio:.4f} | comp_len: {avg_completion_len:.1f}" + ) diff --git a/arctic_training/trainer/trainer.py b/arctic_training/trainer/trainer.py index 7035a4b3..0fccbdcd 100644 --- a/arctic_training/trainer/trainer.py +++ b/arctic_training/trainer/trainer.py @@ -513,7 +513,12 @@ def epoch(self) -> None: self.metrics.print_summary(prefix="eval") if self.wandb_experiment is not None: - metrics = {k: self.metrics.summary_dict[k] for k in ["loss/eval"]} + # Log all eval metrics (those starting with eval/ or ending with /eval) + metrics = { + k: v + for k, v in self.metrics.summary_dict.items() + if k.startswith("eval/") or k.endswith("/eval") + } self.wandb_experiment.log(metrics, step=self.global_step) self.metrics.stop_timer("iter") diff --git a/arctic_training/trainer/utils.py b/arctic_training/trainer/utils.py index 66f479c7..ee9f5fbc 100644 --- a/arctic_training/trainer/utils.py +++ b/arctic_training/trainer/utils.py @@ -16,6 +16,7 @@ from typing import Dict import torch +import torch.nn as nn def to_device(batch: Dict, device: str) -> Dict: @@ -24,3 +25,17 @@ def to_device(batch: Dict, device: str) -> Dict: if isinstance(v, torch.Tensor): output[k] = v.to(device) return output + + +def disable_dropout_in_model(model: nn.Module) -> None: + """Disable dropout in a model by setting dropout probability to 0. + + This is useful for teacher/reference models during distillation or DPO + training where we want deterministic outputs. + + Args: + model: The PyTorch model to modify in-place. + """ + for module in model.modules(): + if isinstance(module, nn.Dropout): + module.p = 0.0 diff --git a/projects/on_policy_distillation/README.md b/projects/on_policy_distillation/README.md new file mode 100644 index 00000000..5160d003 --- /dev/null +++ b/projects/on_policy_distillation/README.md @@ -0,0 +1,58 @@ +# On-Policy Distillation + +Train a student model using on-policy distillation where: +- Student generates its own trajectories +- Teacher provides per-token supervision via reverse KL divergence + +This differs from traditional distillation where the teacher generates trajectories. + +## Usage + +The teacher model is loaded in-memory alongside the student, using DeepSpeed for efficient memory management. Both models must fit in GPU memory (with ZeRO-2/3). + +```bash +arctic_training projects/on_policy_distillation/distill-qwen3.yaml +``` + +## Config Options + +| Parameter | Description | +|-----------|-------------| +| `teacher_model` | Teacher model config (same format as `model`) | +| `disable_teacher_dropout` | Disable dropout in teacher (default: true) | +| `num_rollouts_per_prompt` | Number of student samples per prompt | +| `max_new_tokens` | Maximum generation length | +| `generation_temperature` | Student sampling temperature | +| `beta` | Reverse KL coefficient (higher = stronger teacher signal) | + +## Example Config + +```yaml +type: on_policy_distillation + +# Student model +model: + name_or_path: Qwen/Qwen3-1.7B + dtype: bf16 + +# Teacher model (loaded in-memory) +teacher_model: + name_or_path: Qwen/Qwen3-8B + dtype: bf16 + +# Distillation settings +num_rollouts_per_prompt: 4 +max_new_tokens: 1024 +beta: 1.0 +``` + +## Memory Considerations + +Both student and teacher models are loaded in GPU memory. Options to reduce memory: +- Use DeepSpeed ZeRO-3 for both models +- Use lower precision for teacher (e.g., `dtype: fp16`) +- Use a smaller teacher model + +## Reference + +- [On-Policy Distillation of Language Models](https://arxiv.org/abs/2306.13649) diff --git a/projects/on_policy_distillation/distill-qwen3.yaml b/projects/on_policy_distillation/distill-qwen3.yaml new file mode 100644 index 00000000..d3db0b5b --- /dev/null +++ b/projects/on_policy_distillation/distill-qwen3.yaml @@ -0,0 +1,99 @@ +# On-Policy Distillation Configuration +# +# This config demonstrates training a student model (Qwen3-1.7B) using +# on-policy distillation from a teacher model (Qwen3-8B) loaded in-memory. +# +# The key difference from traditional (off-policy) distillation: +# - Off-policy: Teacher generates trajectories, student imitates +# - On-policy: Student generates trajectories, teacher corrects via reverse KL +# +# On-policy distillation is more compute-efficient and avoids compounding errors +# from distribution mismatch. +# +# Memory requirement: Both student and teacher must fit in GPU memory. +# With DeepSpeed ZeRO-2/3, this is typically feasible for most model combinations. + +type: on_policy_distillation + +# Student model configuration +model: + type: huggingface + name_or_path: Qwen/Qwen3-1.7B + dtype: bf16 + attn_implementation: flash_attention_2 + +# Teacher model configuration (loaded in-memory for fast logprob computation) +teacher_model: + type: huggingface + name_or_path: Qwen/Qwen3-8B + dtype: bf16 + attn_implementation: flash_attention_2 + +# Disable dropout in teacher for stable distillation signal +disable_teacher_dropout: true + +# On-policy distillation hyperparameters +num_rollouts_per_prompt: 8 # Number of completions per prompt (increases GPU utilization) +max_new_tokens: 128 # Math solutions are short +generation_temperature: 0.7 # Lower temp for more focused math responses +beta: 1.0 # Reverse KL coefficient + +# Training configuration - OPTIMIZED FOR H200 (143GB per GPU) +world_size: 4 +micro_batch_size: 16 # 16 prompts * 8 rollouts = 128 completions per step +gradient_accumulation_steps: 1 # Single accumulation for max throughput +epochs: 10 + +# Data configuration - GSM8K math reasoning (small dataset, short responses) +data: + sources: + - type: huggingface_instruct + name_or_path: openai/gsm8k + split: train + kwargs: + name: main + role_mapping: + user: question + assistant: answer + max_length: 512 # Math solutions are short + max_prompt_length: 256 + filter_long_prompts: true + train_eval_split: [0.9, 0.1] # 90% train, 10% validation + cache_dir: ./data-cache + num_proc: 16 + +# Evaluation +eval_interval: 10 # Evaluate every 10 steps + +# DeepSpeed configuration +deepspeed: + zero_optimization: + stage: 2 + +# Logging +wandb: + enable: true + project: on-policy-distillation-gsm8k + name: distill-qwen3-1.7b + +logger: + level: INFO + output_dir: ./logs + +# Optimizer +scheduler: + name: cosine + warmup_ratio: 0.1 + +optimizer: + type: fused_adam + betas: [0.9, 0.999] + weight_decay: 0.1 + lr: 1e-5 + +# Checkpointing +checkpoint: + - type: huggingface + save_every_n_steps: 100 + output_dir: ./checkpoints/distill-qwen3-8b + save_end_of_training: true diff --git a/tests/trainer/test_on_policy_distillation_trainer.py b/tests/trainer/test_on_policy_distillation_trainer.py new file mode 100644 index 00000000..80d44e04 --- /dev/null +++ b/tests/trainer/test_on_policy_distillation_trainer.py @@ -0,0 +1,505 @@ +# Copyright 2025 Snowflake Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for On-Policy Distillation Trainer components.""" + +import pytest +import torch +import torch.nn as nn + +# Import the modules to register them +from arctic_training.config.on_policy_distillation import OnPolicyDistillationTrainerConfig +from arctic_training.data.on_policy_distillation_factory import DataCollatorForOnPolicyDistillation +from arctic_training.data.on_policy_distillation_factory import OnPolicyDistillationDataFactory +from arctic_training.data.on_policy_distillation_factory import pad_prompts +from arctic_training.trainer.on_policy_distillation_trainer import OnPolicyDistillationTrainer # noqa: F401 +from arctic_training.trainer.utils import disable_dropout_in_model + + +@pytest.mark.skip(reason="Config tests require full distributed setup - tested via integration tests") +class TestOnPolicyDistillationConfig: + """Tests for OnPolicyDistillationTrainerConfig. + + Note: These tests are skipped because they require the full distributed + setup to properly resolve type hints. The config is tested via the + integration tests at the end of this file. + """ + + def test_config_requires_teacher(self, model_name): + """Test that config validation requires teacher_model.""" + with pytest.raises(ValueError): + OnPolicyDistillationTrainerConfig( + model={"name_or_path": model_name}, + data={"sources": ["test-data"], "max_length": 512}, + skip_validation=True, + # No teacher_model + ) + + def test_config_with_teacher_model(self, model_name): + """Test config with in-memory teacher model.""" + config = OnPolicyDistillationTrainerConfig( + model={"name_or_path": model_name}, + data={"sources": ["test-data"], "max_length": 512}, + teacher_model={"name_or_path": model_name}, + skip_validation=True, + ) + assert config.teacher_model.name_or_path == model_name + assert config.disable_teacher_dropout is True + + def test_config_defaults(self, model_name): + """Test default config values.""" + config = OnPolicyDistillationTrainerConfig( + model={"name_or_path": model_name}, + data={"sources": ["test-data"], "max_length": 512}, + teacher_model={"name_or_path": model_name}, + skip_validation=True, + ) + assert config.num_rollouts_per_prompt == 4 + assert config.max_new_tokens == 2048 + assert config.generation_temperature == 1.0 + assert config.beta == 1.0 + assert config.disable_teacher_dropout is True + + def test_generation_batch_size_default(self, model_name): + """Test that generation_batch_size defaults to micro_batch_size.""" + config = OnPolicyDistillationTrainerConfig( + model={"name_or_path": model_name}, + data={"sources": ["test-data"], "max_length": 512}, + teacher_model={"name_or_path": model_name}, + micro_batch_size=4, + skip_validation=True, + ) + assert config.generation_batch_size == 4 + + def test_teacher_deepspeed_config_auto_generated(self, model_name): + """Test that teacher DeepSpeed config is auto-generated.""" + config = OnPolicyDistillationTrainerConfig( + model={"name_or_path": model_name}, + data={"sources": ["test-data"], "max_length": 512}, + teacher_model={"name_or_path": model_name}, + skip_validation=True, + ) + # Teacher deepspeed should be auto-generated + assert len(config.teacher_deepspeed) > 0 + assert "zero_optimization" in config.teacher_deepspeed + + +class TestPadPrompts: + """Tests for the pad_prompts utility function.""" + + def test_left_padding(self): + """Test left padding of prompts.""" + tensors = [ + torch.tensor([1, 2, 3]), + torch.tensor([4, 5]), + torch.tensor([6, 7, 8, 9]), + ] + padded = pad_prompts(tensors, padding_value=0, padding_side="left", divisible_by=4) + + assert padded.shape == (3, 4) # Padded to divisible by 4 + assert padded[0].tolist() == [0, 1, 2, 3] # Left-padded + assert padded[1].tolist() == [0, 0, 4, 5] + assert padded[2].tolist() == [6, 7, 8, 9] + + def test_right_padding(self): + """Test right padding of prompts.""" + tensors = [ + torch.tensor([1, 2, 3]), + torch.tensor([4, 5]), + ] + padded = pad_prompts(tensors, padding_value=0, padding_side="right", divisible_by=4) + + assert padded.shape == (2, 4) + assert padded[0].tolist() == [1, 2, 3, 0] # Right-padded + assert padded[1].tolist() == [4, 5, 0, 0] + + def test_padding_divisibility(self): + """Test that output length is divisible by specified value.""" + tensors = [torch.tensor([1, 2, 3, 4, 5])] # Length 5 + padded = pad_prompts(tensors, padding_value=0, divisible_by=8) + + assert padded.shape == (1, 8) # Padded to 8 (divisible by 8) + + +class TestDataCollator: + """Tests for DataCollatorForOnPolicyDistillation.""" + + def test_collator_output_structure(self): + """Test that collator produces expected output structure.""" + from unittest.mock import MagicMock + + # Create mock tokenizer + tokenizer = MagicMock() + tokenizer.pad_token_id = 0 + tokenizer.eos_token_id = 1 + + # Create mock config + config = MagicMock() + config.div_length = 4 + + collator = DataCollatorForOnPolicyDistillation(tokenizer, config) + + instances = [ + {"input_ids": [1, 2, 3]}, + {"input_ids": [4, 5]}, + ] + + result = collator(instances) + + assert "input_ids" in result + assert "attention_mask" in result + assert "prompt_lengths" in result + assert result["prompt_lengths"].tolist() == [3, 2] + + def test_collator_left_padding(self): + """Test that collator applies left padding.""" + from unittest.mock import MagicMock + + tokenizer = MagicMock() + tokenizer.pad_token_id = 0 + tokenizer.eos_token_id = 1 + + config = MagicMock() + config.div_length = 4 + + collator = DataCollatorForOnPolicyDistillation(tokenizer, config) + + instances = [ + {"input_ids": [1, 2, 3]}, + {"input_ids": [4, 5]}, + ] + + result = collator(instances) + + # Check left padding (real tokens at the end) + assert result["input_ids"][0, -1].item() == 3 + assert result["input_ids"][1, -1].item() == 5 + + +class TestPolicyGradientLoss: + """Tests for the policy gradient loss computation. + + The on-policy distillation loss uses policy gradient with advantage: + advantage = teacher_logprob - student_logprob + loss = -(advantage * student_logprob).mean() + + This ensures: + - When teacher > student: positive advantage -> increase student_logprob + - When teacher < student: negative advantage -> decrease student_logprob + """ + + def test_advantage_computation(self): + """Test that advantage is computed correctly.""" + student_logprobs = torch.tensor([-1.0, -2.0, -3.0]) + teacher_logprobs = torch.tensor([-0.5, -2.0, -4.0]) + + advantage = teacher_logprobs - student_logprobs + + # Position 0: teacher (-0.5) > student (-1.0) -> positive advantage + assert advantage[0] > 0 + # Position 1: teacher == student -> zero advantage + assert advantage[1] == 0 + # Position 2: teacher (-4.0) < student (-3.0) -> negative advantage + assert advantage[2] < 0 + + def test_policy_gradient_loss_direction(self): + """Test that policy gradient loss has correct gradient direction.""" + # Student logprobs (requires grad) + student_logprobs = torch.tensor([-1.0, -2.0, -3.0], requires_grad=True) + teacher_logprobs = torch.tensor([-0.5, -2.0, -4.0]) + + # Compute advantage (detached as in actual implementation) + advantage = teacher_logprobs - student_logprobs.detach() + + # Policy gradient loss + loss = -(advantage * student_logprobs).mean() + loss.backward() + + # Check gradient directions + # Position 0: advantage > 0, so gradient should be negative (to increase logprob) + # Position 2: advantage < 0, so gradient should be positive (to decrease logprob) + assert student_logprobs.grad[0] < 0 # Will increase student_logprob + assert student_logprobs.grad[2] > 0 # Will decrease student_logprob + + def test_loss_zero_when_distributions_match(self): + """Test that loss is zero when student matches teacher.""" + student_logprobs = torch.tensor([-1.0, -2.0, -3.0]) + teacher_logprobs = torch.tensor([-1.0, -2.0, -3.0]) + + advantage = teacher_logprobs - student_logprobs + loss = -(advantage * student_logprobs).mean() + + assert torch.isclose(loss, torch.tensor(0.0)) + + def test_loss_with_masking(self): + """Test loss computation with masked positions.""" + student_logprobs = torch.tensor([-1.0, -2.0, -999.0]) # Last is padding + teacher_logprobs = torch.tensor([-0.5, -1.0, -999.0]) + mask = torch.tensor([True, True, False]) + + # Apply mask + valid_student = student_logprobs[mask] + valid_teacher = teacher_logprobs[mask] + + advantage = valid_teacher - valid_student + loss = -(advantage * valid_student).mean() + + # Manual calculation for first two positions only + adv_0 = -0.5 - (-1.0) # = 0.5 + adv_1 = -1.0 - (-2.0) # = 1.0 + expected_loss = -((adv_0 * -1.0) + (adv_1 * -2.0)) / 2 + + assert torch.isclose(loss, torch.tensor(expected_loss)) + + +class TestReverseKLMetric: + """Tests for reverse KL metric computation (for monitoring, not loss).""" + + def test_reverse_kl_identical_distributions(self): + """Test that KL is zero when distributions are identical.""" + student_logprobs = torch.tensor([-1.0, -2.0, -3.0]) + teacher_logprobs = torch.tensor([-1.0, -2.0, -3.0]) + + # Reverse KL: student - teacher + reverse_kl = (student_logprobs - teacher_logprobs).mean() + + assert torch.isclose(reverse_kl, torch.tensor(0.0)) + + def test_reverse_kl_student_worse(self): + """Test KL when student is worse (lower logprobs) than teacher.""" + student_logprobs = torch.tensor([-2.0, -3.0, -4.0]) # Lower (worse) + teacher_logprobs = torch.tensor([-1.0, -2.0, -3.0]) # Higher (better) + + reverse_kl = (student_logprobs - teacher_logprobs).mean() + + # Student logprobs are lower, so student - teacher is negative + assert reverse_kl < 0 + + def test_reverse_kl_student_overconfident(self): + """Test KL when student is overconfident (higher logprobs than teacher).""" + student_logprobs = torch.tensor([-0.5, -1.0, -1.5]) # Higher (overconfident) + teacher_logprobs = torch.tensor([-1.0, -2.0, -3.0]) # Lower + + reverse_kl = (student_logprobs - teacher_logprobs).mean() + + # Student logprobs are higher, so student - teacher is positive + assert reverse_kl > 0 + + +class TestDataFactoryTokenization: + """Tests for OnPolicyDistillationDataFactory tokenization.""" + + def test_tokenize_prompt_extracts_user_message(self): + """Test that tokenize_prompt extracts prompt correctly.""" + from unittest.mock import MagicMock + + tokenizer = MagicMock() + tokenizer.apply_chat_template.return_value = "User: Hello" + tokenizer.return_value = {"input_ids": [1, 2, 3, 4]} + + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there"}, + ] + + OnPolicyDistillationDataFactory.tokenize_prompt(messages, tokenizer, include_system=True) + + # Should only include user message in template call + call_args = tokenizer.apply_chat_template.call_args + assert len(call_args[1]["conversation"]) == 1 + assert call_args[1]["conversation"][0]["role"] == "user" + + def test_tokenize_prompt_includes_system(self): + """Test that system messages are included when specified.""" + from unittest.mock import MagicMock + + tokenizer = MagicMock() + tokenizer.apply_chat_template.return_value = "System: Be helpful\nUser: Hello" + tokenizer.return_value = {"input_ids": [1, 2, 3, 4, 5]} + + messages = [ + {"role": "system", "content": "Be helpful"}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi"}, + ] + + OnPolicyDistillationDataFactory.tokenize_prompt(messages, tokenizer, include_system=True) + + call_args = tokenizer.apply_chat_template.call_args + assert len(call_args[1]["conversation"]) == 2 + + def test_tokenize_prompt_excludes_system(self): + """Test that system messages can be excluded.""" + from unittest.mock import MagicMock + + tokenizer = MagicMock() + tokenizer.apply_chat_template.return_value = "User: Hello" + tokenizer.return_value = {"input_ids": [1, 2, 3]} + + messages = [ + {"role": "system", "content": "Be helpful"}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi"}, + ] + + OnPolicyDistillationDataFactory.tokenize_prompt(messages, tokenizer, include_system=False) + + call_args = tokenizer.apply_chat_template.call_args + assert len(call_args[1]["conversation"]) == 1 + assert call_args[1]["conversation"][0]["role"] == "user" + + +class TestDisableDropout: + """Tests for disable_dropout_in_model utility.""" + + def test_disable_dropout(self): + """Test that dropout is disabled in model.""" + # Create a simple model with dropout + model = nn.Sequential( + nn.Linear(10, 10), + nn.Dropout(p=0.5), + nn.Linear(10, 10), + nn.Dropout(p=0.3), + ) + + # Verify dropout is enabled + dropouts = [m for m in model.modules() if isinstance(m, nn.Dropout)] + assert len(dropouts) == 2 + assert dropouts[0].p == 0.5 + assert dropouts[1].p == 0.3 + + # Disable dropout + disable_dropout_in_model(model) + + # Verify dropout is disabled + for dropout in dropouts: + assert dropout.p == 0.0 + + def test_disable_dropout_nested(self): + """Test that dropout is disabled in nested modules.""" + model = nn.Sequential( + nn.Sequential( + nn.Linear(10, 10), + nn.Dropout(p=0.5), + ), + nn.Dropout(p=0.3), + ) + + disable_dropout_in_model(model) + + dropouts = [m for m in model.modules() if isinstance(m, nn.Dropout)] + for dropout in dropouts: + assert dropout.p == 0.0 + + +class TestNumRolloutsPerPrompt: + """Tests for num_rollouts_per_prompt functionality.""" + + def test_repeat_interleave_prompts(self): + """Test that prompts are correctly repeated for multiple rollouts.""" + # Simulate input tensors + input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]) # 2 prompts + prompt_lengths = torch.tensor([3, 3]) + num_rollouts = 3 + + # Repeat interleave + expanded_input_ids = input_ids.repeat_interleave(num_rollouts, dim=0) + expanded_lengths = prompt_lengths.repeat_interleave(num_rollouts, dim=0) + + # Should have 6 rows (2 prompts * 3 rollouts) + assert expanded_input_ids.shape[0] == 6 + assert expanded_lengths.shape[0] == 6 + + # Check pattern: [p1, p1, p1, p2, p2, p2] + assert torch.equal(expanded_input_ids[0], input_ids[0]) + assert torch.equal(expanded_input_ids[1], input_ids[0]) + assert torch.equal(expanded_input_ids[2], input_ids[0]) + assert torch.equal(expanded_input_ids[3], input_ids[1]) + + +# Integration tests +@pytest.mark.parametrize( + "run_on_cpu", + [ + True, + pytest.param(False, marks=pytest.mark.gpu), + ], +) +def test_on_policy_distillation_trainer(model_name, run_on_cpu): + """Test full on-policy distillation training loop.""" + from tests.utils import run_dummy_training + + run_dummy_training( + { + "type": "on_policy_distillation", + "model": { + "type": "random-weight-hf", + "name_or_path": model_name, + "dtype": "float32", + }, + "teacher_model": { + "type": "random-weight-hf", + "name_or_path": model_name, + "dtype": "float32", + }, + "data": { + "max_length": 512, + "max_prompt_length": 128, + "sources": ["HuggingFaceH4/ultrachat_200k:train[:10]"], + }, + "max_new_tokens": 32, + "num_rollouts_per_prompt": 1, + "beta": 1.0, + }, + run_on_cpu=run_on_cpu, + ) + + +@pytest.mark.parametrize( + "run_on_cpu", + [ + True, + pytest.param(False, marks=pytest.mark.gpu), + ], +) +def test_on_policy_distillation_with_multiple_rollouts(model_name, run_on_cpu): + """Test on-policy distillation with multiple rollouts per prompt.""" + from tests.utils import run_dummy_training + + run_dummy_training( + { + "type": "on_policy_distillation", + "model": { + "type": "random-weight-hf", + "name_or_path": model_name, + "dtype": "float32", + }, + "teacher_model": { + "type": "random-weight-hf", + "name_or_path": model_name, + "dtype": "float32", + }, + "data": { + "max_length": 512, + "max_prompt_length": 128, + "sources": ["HuggingFaceH4/ultrachat_200k:train[:10]"], + }, + "max_new_tokens": 32, + "num_rollouts_per_prompt": 2, # Multiple rollouts + "micro_batch_size": 2, + "beta": 1.0, + }, + run_on_cpu=run_on_cpu, + )