Merged
44 commits
- `ffde2d0` Use `VLLMGeneration` in `GOLDTrainer` (cmpatino, Mar 18, 2026)
- `1797fc1` Update with precommit (cmpatino, Mar 19, 2026)
- `f723677` Initial `DistillationTrainer` implementation (cmpatino, Mar 23, 2026)
- `b629987` Fix how we handle padding and special tokens (cmpatino, Mar 23, 2026)
- `6746237` Initial implementation of distillation trainer (cmpatino, Mar 23, 2026)
- `cdc3196` Address concern about vllm weight sync (cmpatino, Mar 23, 2026)
- `f4c193e` Run precommit (cmpatino, Mar 23, 2026)
- `2b41f84` Fix max len behavior for generation (cmpatino, Mar 23, 2026)
- `91715cb` Format docstring (cmpatino, Mar 23, 2026)
- `0ded0db` Merge branch 'kd-vllm-generation' into kd-distillation-trainer (cmpatino, Mar 24, 2026)
- `b8754db` Fix data collation issue (cmpatino, Mar 24, 2026)
- `b94fc1f` Remove decode -> re-tokenization roundtrip (cmpatino, Mar 24, 2026)
- `dcfce59` Run precommit (cmpatino, Mar 24, 2026)
- `ff81a89` Add check for tokenizers and prompt length (cmpatino, Mar 24, 2026)
- `d075a1e` Merge branch 'kd-vllm-generation' into kd-distillation-trainer (cmpatino, Mar 24, 2026)
- `6eeaa8f` Merge branch 'kd-distillation-trainer' of github.com:cmpatino/trl int… (cmpatino, Mar 24, 2026)
- `f5fc947` Implement efficient logprob generation (cmpatino, Mar 27, 2026)
- `42f3d4d` Fix top-k implementation (cmpatino, Mar 30, 2026)
- `4a00c68` Merge branch 'main' into kd-distillation-trainer (cmpatino, Mar 30, 2026)
- `d13bc4a` Migrate trainer to experimental (cmpatino, Mar 30, 2026)
- `64fdafd` Fix reverse KL calculation for top-1 (cmpatino, Mar 30, 2026)
- `2730c58` Merge branch 'main' into kd-distillation-trainer (cmpatino, Mar 30, 2026)
- `0d4bd05` Run precommit (cmpatino, Mar 30, 2026)
- `59aa007` Address cursor comments (cmpatino, Mar 30, 2026)
- `91bffa4` Fix reverse KL computation (cmpatino, Mar 30, 2026)
- `6d0fe13` Add `DistillationTrainer` to table of contents (cmpatino, Mar 30, 2026)
- `b55996e` Remove `DistillationTrainer` from toc (cmpatino, Mar 30, 2026)
- `9727ec7` Tighten logic for different top-k scenarios (cmpatino, Mar 30, 2026)
- `939b53b` Add tail bucket to reverse KL + server case (cmpatino, Mar 30, 2026)
- `47eacd5` Merge branch 'main' into kd-distillation-trainer (cmpatino, Mar 31, 2026)
- `bcf21d2` Fix dead code when using full vocab for external teacher (cmpatino, Mar 31, 2026)
- `aaf47ad` Remove unused function (cmpatino, Mar 31, 2026)
- `89a7b88` Add guard in config for liger + external teacher (cmpatino, Mar 31, 2026)
- `16a5380` Tighten alignment logic (cmpatino, Mar 31, 2026)
- `1c7b9de` Run precommit (cmpatino, Mar 31, 2026)
- `3cb4bc3` Correct completion logic (cmpatino, Mar 31, 2026)
- `d1b97a8` Remove wandb config params (cmpatino, Mar 31, 2026)
- `691f46c` Address Albert's comments (cmpatino, Apr 1, 2026)
- `9ef9192` Run precommit (cmpatino, Apr 1, 2026)
- `bb2eee6` Address PR comments (cmpatino, Apr 2, 2026)
- `088db42` Match behavior between local and external teacher when top-1 and bet… (cmpatino, Apr 7, 2026)
- `90c6e81` Run precommit (cmpatino, Apr 7, 2026)
- `864ba11` Address latest comments (cmpatino, Apr 9, 2026)
- `0f65a7b` Merge branch 'main' into kd-distillation-trainer (cmpatino, Apr 9, 2026)
31 changes: 21 additions & 10 deletions docs/source/paper_index.md
@@ -1563,21 +1563,21 @@ Papers relating to training a student model with the help of a teacher model.

**📜 Paper**: https://huggingface.co/papers/2306.13649

Introduces Generalized Knowledge Distillation (GKD), which addresses distribution mismatch in KD for auto-regressive models by training the student on its own generated outputs with teacher feedback, instead of a fixed set of sequences. GKD supports flexible loss functions (e.g. beyond KL when the student cannot match the teacher) and integrates with RL fine-tuning (RLHF). The paper reports results on summarization, translation, arithmetic reasoning, and instruction-tuning. Used in TRL via [`experimental.distillation.DistillationTrainer`] and [`experimental.gkd.GKDTrainer`]. To reproduce the paper's setting, use this configuration:

```python
from trl.experimental.distillation import DistillationConfig

# XSum summarization task (Table A.1 of the paper)
training_args = DistillationConfig(
    lmbda=0.5,  # λ student data fraction (Section 3 of the paper)
    beta=0.5,  # β Generalized JSD interpolation, 0=KL, 1=reverse KL (Section 3 of the paper)
    temperature=1.0,  # student training temperature (Appendix A of the paper)
    max_steps=40000,  # training steps (Table A.1 of the paper)
    learning_rate=3e-4,  # learning rate (Table A.1 of the paper)
    per_device_train_batch_size=32,  # batch size (Table A.1 of the paper)
    warmup_steps=2000,  # warm-up steps (Table A.1 of the paper)
    max_completion_length=64,  # max output tokens (Table A.1 of the paper)
)
```
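Since `beta` interpolates between forward and reverse KL, here is a minimal pure-Python sketch of the generalized JSD from the GKD paper on toy discrete distributions. The helper names are hypothetical and this is schematic: TRL's actual loss works on per-token logits with temperature scaling, not probability lists.

```python
import math


def kl_div(p, q):
    """KL(p || q) for discrete distributions given as probability lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def generalized_jsd(teacher, student, beta):
    """Generalized JSD: beta=0 recovers forward KL(teacher || student),
    beta=1 recovers reverse KL(student || teacher); in between, both
    distributions are compared against their beta-weighted mixture."""
    if beta == 0:
        return kl_div(teacher, student)
    if beta == 1:
        return kl_div(student, teacher)
    mixture = [beta * t + (1 - beta) * s for t, s in zip(teacher, student)]
    return beta * kl_div(teacher, mixture) + (1 - beta) * kl_div(student, mixture)


teacher = [0.7, 0.2, 0.1]
student = [0.5, 0.3, 0.2]
forward = generalized_jsd(teacher, student, beta=0.0)  # forward KL(teacher || student)
reverse = generalized_jsd(teacher, student, beta=1.0)  # reverse KL(student || teacher)
mid = generalized_jsd(teacher, student, beta=0.5)      # symmetric JSD-style divergence
```

At `beta=0.5` the value is bounded by log 2, which is one reason the interpolated loss is better behaved than either KL when student and teacher disagree strongly.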

@@ -1597,20 +1597,31 @@ On-Policy Distillation has been shown to outperform SFT, GRPO and can be used to

Additionally on-policy distillation is more compute efficient and is less prone to overfitting when trained with limited data.

To train a model with on-policy distillation using TRL, you can use the following configuration, with the [`experimental.distillation.DistillationTrainer`] and [`experimental.distillation.DistillationConfig`]:

```python
from trl.experimental.distillation import DistillationConfig

training_args = DistillationConfig(
    lmbda=1.0,  # student produces rollouts for all batches
    beta=1.0,  # to ensure reverse-KL as the loss function
    teacher_model_name_or_path="teacher-model",  # specify the teacher model
)
```

Alternatively, you can use the [`experimental.gkd.GKDTrainer`] and [`experimental.gkd.GKDConfig`]:

```python
from trl.experimental.gkd import GKDConfig

training_args = GKDConfig(
    lmbda=1.0,  # student produces rollouts for all batches
    beta=1.0,  # to ensure reverse-KL as the loss function
    teacher_model_name_or_path="teacher-model",  # specify the teacher model
)
```
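The `lmbda` knob above decides, per batch, whether training uses student rollouts (on-policy) or the fixed dataset (off-policy). A schematic of that per-batch coin flip, with a hypothetical helper name (this is not TRL API, just an illustration of Section 3's sampling rule):

```python
import random


def pick_batch_source(lmbda: float, rng: random.Random) -> str:
    # With probability lmbda, train on student-generated rollouts
    # (on-policy); otherwise use the fixed ground-truth dataset.
    return "on_policy" if rng.random() < lmbda else "off_policy"


rng = random.Random(0)
# lmbda=1.0 as in the configs above: every batch is on-policy.
sources = [pick_batch_source(1.0, rng) for _ in range(100)]
```

With `lmbda=0.5`, as in the GKD paper's XSum setting, roughly half the batches come from each source.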

You can also use the [`GOLDTrainer`] and [`GOLDConfig`] to perform on-policy distillation with a similar configuration:

```python
from trl.experimental import GOLDConfig
# … (remaining lines collapsed in the diff view)
```
19 changes: 19 additions & 0 deletions trl/experimental/distillation/__init__.py
@@ -0,0 +1,19 @@
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .distillation_config import DistillationConfig
from .distillation_trainer import DistillationTrainer


__all__ = ["DistillationConfig", "DistillationTrainer"]
179 changes: 179 additions & 0 deletions trl/experimental/distillation/distillation.py
@@ -0,0 +1,179 @@
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# /// script
# dependencies = [
# "trl",
# "peft",
# "trackio",
# "kernels",
# ]
# ///

# docstyle-ignore
"""
# Full training (off-policy only, lmbda=0):
```
python trl/experimental/distillation/distillation.py \
--model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
--teacher_model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
--dataset_name trl-lib/chatbot_arena_completions \
--learning_rate 2e-5 \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 8 \
--lmbda 0.0 \
--output_dir distilled-model \
--num_train_epochs 1
```

# Mixed on/off-policy (lmbda=0.5):
```
python trl/experimental/distillation/distillation.py \
--model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
--teacher_model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
--dataset_name trl-lib/chatbot_arena_completions \
--learning_rate 2e-5 \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 8 \
--lmbda 0.5 \
--beta 0.5 \
--output_dir distilled-model \
--num_train_epochs 1
```

# LoRA:
```
python trl/experimental/distillation/distillation.py \
--model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
--teacher_model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
--dataset_name trl-lib/chatbot_arena_completions \
--learning_rate 2e-4 \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 8 \
--lmbda 0.0 \
--output_dir distilled-model \
--num_train_epochs 1 \
--use_peft \
--lora_r 64 \
--lora_alpha 16
```
"""

import argparse
import os


# Enable logging in a Hugging Face Space
os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")


def main(script_args, training_args, model_args):
    from datasets import load_dataset
    from transformers import GenerationConfig

    from trl import (
        LogCompletionsCallback,
        get_kbit_device_map,
        get_peft_config,
        get_quantization_config,
    )
    from trl.experimental.distillation import DistillationTrainer

    ################
    # Model init kwargs
    ################
    quantization_config = get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=model_args.dtype,
        use_cache=not training_args.gradient_checkpointing,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    training_args.model_init_kwargs = model_kwargs

    teacher_model_kwargs = dict(
        revision=training_args.teacher_model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=model_args.dtype,
        use_cache=True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    if training_args.teacher_model_init_kwargs is not None:
        teacher_model_kwargs.update(training_args.teacher_model_init_kwargs)
    training_args.teacher_model_init_kwargs = teacher_model_kwargs

    ################
    # Dataset
    ################
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    ################
    # Training
    ################
    eval_dataset = None
    if training_args.eval_strategy != "no":
        if script_args.dataset_test_split in dataset:
            eval_dataset = dataset[script_args.dataset_test_split]
        elif "validation" in dataset:
            eval_dataset = dataset["validation"]
        elif "dev" in dataset:
            eval_dataset = dataset["dev"]

    trainer = DistillationTrainer(
        model=model_args.model_name_or_path,
        teacher_model=training_args.teacher_model_name_or_path,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=eval_dataset,
        peft_config=get_peft_config(model_args),
    )

    if training_args.eval_strategy != "no":
        generation_config = GenerationConfig(
            max_new_tokens=training_args.max_completion_length, do_sample=True, temperature=training_args.temperature
        )
        completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
        trainer.add_callback(completions_callback)

    trainer.train()

    # Save and push to Hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)


def make_parser(subparsers: argparse._SubParsersAction | None = None, prog: str | None = None):
    from trl import ModelConfig, ScriptArguments, TrlParser
    from trl.experimental.distillation import DistillationConfig

    dataclass_types = (ScriptArguments, DistillationConfig, ModelConfig)
    if subparsers is not None:
        parser = subparsers.add_parser(
            "distillation", help="Run the distillation training script", dataclass_types=dataclass_types
        )
    else:
        parser = TrlParser(dataclass_types, prog=prog)
    return parser


if __name__ == "__main__":
    parser = make_parser()
    script_args, training_args, model_args = parser.parse_args_and_config(fail_with_unknown_args=False)
    main(script_args, training_args, model_args)
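`make_parser` supports both standalone invocation and registration under a parent CLI's subcommand tree. A stdlib-only sketch of that dual-mode pattern, with plain `argparse` standing in for `TrlParser` and hypothetical flags in place of the dataclass-driven arguments:

```python
import argparse


def make_distillation_parser(subparsers=None):
    # Mirror the script's pattern: attach to an existing subcommand
    # tree when one is given, otherwise build a standalone parser.
    if subparsers is not None:
        parser = subparsers.add_parser("distillation", help="Run distillation")
    else:
        parser = argparse.ArgumentParser(prog="distillation")
    parser.add_argument("--lmbda", type=float, default=0.0)
    parser.add_argument("--beta", type=float, default=0.5)
    return parser


root = argparse.ArgumentParser(prog="trl")
subparsers = root.add_subparsers(dest="command")
make_distillation_parser(subparsers)
args = root.parse_args(["distillation", "--lmbda", "0.5"])
```

This is why the real script can run both as `python distillation.py …` and as a subcommand of a larger `trl` CLI.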