
Commit c475b97

Add DistillationTrainer for efficient on-policy distillation (#5407)
1 parent 89c5ed6 commit c475b97

File tree

7 files changed: +2873 −10 lines

docs/source/paper_index.md

Lines changed: 21 additions & 10 deletions
@@ -1563,21 +1563,21 @@ Papers relating to training a student model with the help of a teacher model.

 **📜 Paper**: https://huggingface.co/papers/2306.13649

-Introduces Generalized Knowledge Distillation (GKD), which addresses distribution mismatch in KD for auto-regressive models by training the student on its own generated outputs with teacher feedback, instead of a fixed set of sequences. GKD supports flexible loss functions (e.g. beyond KL when the student cannot match the teacher) and integrates with RL fine-tuning (RLHF). The paper reports results on summarization, translation, arithmetic reasoning, and instruction-tuning. Used in TRL via [`experimental.gkd.GKDTrainer`]. To reproduce the paper's setting, use this configuration:
+Introduces Generalized Knowledge Distillation (GKD), which addresses distribution mismatch in KD for auto-regressive models by training the student on its own generated outputs with teacher feedback, instead of a fixed set of sequences. GKD supports flexible loss functions (e.g. beyond KL when the student cannot match the teacher) and integrates with RL fine-tuning (RLHF). The paper reports results on summarization, translation, arithmetic reasoning, and instruction-tuning. Used in TRL via [`experimental.distillation.DistillationTrainer`] and [`experimental.gkd.GKDTrainer`]. To reproduce the paper's setting, use this configuration:

 ```python
-from trl.experimental.gkd import GKDConfig
+from trl.experimental.distillation import DistillationConfig

 # XSum summarization task (Table A.1 of the paper)
-training_args = GKDConfig(
+training_args = DistillationConfig(
     lmbda=0.5,  # λ student data fraction (Section 3 of the paper)
     beta=0.5,  # β Generalized JSD interpolation, 0=KL, 1=reverse KL (Section 3 of the paper)
     temperature=1.0,  # student training temperature (Appendix A of the paper)
     max_steps=40000,  # training steps (Table A.1 of the paper)
     learning_rate=3e-4,  # learning rate (Table A.1 of the paper)
     per_device_train_batch_size=32,  # batch size (Table A.1 of the paper)
     warmup_steps=2000,  # warm-up steps (Table A.1 of the paper)
-    max_new_tokens=64,  # max output tokens (Table A.1 of the paper)
+    max_completion_length=64,  # max output tokens (Table A.1 of the paper)
 )
 ```
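The `beta` and `temperature` parameters in the config above control how the student and teacher token distributions are compared. As a rough pure-Python illustration (a sketch of the GKD paper's formula D_JSD(β)(P‖Q) = β·KL(P‖M) + (1−β)·KL(Q‖M) with M = βP + (1−β)Q, not TRL's actual implementation):

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax: higher temperature flattens the distribution."""
    scaled = [x / temperature for x in logits]
    top = max(scaled)
    exps = [math.exp(x - top) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl(p, q):
    """KL(p || q) for discrete distributions, skipping zero-probability terms of p."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def generalized_jsd(p_teacher, q_student, beta):
    """Generalized JSD: beta*KL(P||M) + (1-beta)*KL(Q||M), with mixture M = beta*P + (1-beta)*Q.

    Note the endpoints degenerate to 0 under this formula; the "0=KL, 1=reverse KL"
    behavior comes from the normalized limits (D/beta -> KL(P||Q) as beta -> 0,
    D/(1-beta) -> KL(Q||P) as beta -> 1), which implementations typically special-case.
    """
    m = [beta * pi + (1 - beta) * qi for pi, qi in zip(p_teacher, q_student)]
    return beta * kl(p_teacher, m) + (1 - beta) * kl(q_student, m)

teacher = softmax([2.0, 1.0, 0.1], temperature=1.0)
student = softmax([1.0, 1.5, 0.2], temperature=1.0)

print(generalized_jsd(teacher, student, beta=0.5))  # positive; symmetric at beta=0.5
print(generalized_jsd(teacher, teacher, beta=0.5))  # 0.0 for identical distributions
```

At `beta=0.5` the divergence is symmetric in student and teacher, which is why intermediate values are used when neither forward nor reverse KL alone is desirable.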

@@ -1597,20 +1597,31 @@ On-Policy Distillation has been shown to outperform SFT, GRPO and can be used to

 Additionally, on-policy distillation is more compute-efficient and is less prone to overfitting when trained with limited data.

-To train a model with on-policy distillation using TRL, you can use the following configuration, with the [`experimental.gkd.GKDTrainer`] and [`experimental.gkd.GKDConfig`]:
+To train a model with on-policy distillation using TRL, you can use the following configuration, with the [`experimental.distillation.DistillationTrainer`] and [`experimental.distillation.DistillationConfig`]:
+
+```python
+from trl.experimental.distillation import DistillationConfig
+
+training_args = DistillationConfig(
+    lmbda=1.0,  # student produces rollouts for all batches
+    beta=1.0,  # use reverse KL as the loss function
+    teacher_model_name_or_path="teacher-model",  # specify the teacher model
+)
+```
+
+Alternatively, you can use the [`experimental.gkd.GKDTrainer`] and [`experimental.gkd.GKDConfig`]:

 ```python
 from trl.experimental.gkd import GKDConfig

 training_args = GKDConfig(
-lmbda=1.0,  # student produces rollouts for all batches
-beta=1.0,  # use reverse KL as the loss function
-teacher_model_name_or_path="teacher-model",  # specify the teacher model
+    lmbda=1.0,  # student produces rollouts for all batches
+    beta=1.0,  # use reverse KL as the loss function
+    teacher_model_name_or_path="teacher-model",  # specify the teacher model
 )
 ```
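A rough intuition for the two settings above, as a toy pure-Python sketch (not TRL's implementation): `lmbda=1.0` means every training batch is generated by the student itself, and `beta=1.0` scores those rollouts with reverse KL, KL(student‖teacher), which is mode-seeking — it heavily penalizes the student for placing probability mass where the teacher places little:

```python
import math
import random

def reverse_kl(student_probs, teacher_probs):
    """KL(student || teacher): large when the student puts mass where the teacher puts little."""
    return sum(s * math.log(s / t) for s, t in zip(student_probs, teacher_probs) if s > 0)

def pick_batch_source(lmbda, rng):
    """GKD-style data mixing: with probability lmbda train on student rollouts, else on dataset completions."""
    return "student_rollout" if rng.random() < lmbda else "dataset"

rng = random.Random(0)
sources = [pick_batch_source(1.0, rng) for _ in range(100)]
assert all(s == "student_rollout" for s in sources)  # lmbda=1.0: fully on-policy

teacher = [0.7, 0.2, 0.1]
sharp_student = [0.9, 0.05, 0.05]   # concentrates on the teacher's mode
spread_student = [0.1, 0.1, 0.8]    # puts mass where the teacher does not

# The mode-seeking student incurs a much lower reverse-KL penalty.
print(reverse_kl(sharp_student, teacher) < reverse_kl(spread_student, teacher))  # True
```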

-Alternatively, you can use the [`GOLDTrainer`] and [`GOLDConfig`] to perform on-policy distillation with a similar configuration:
+You can also use the [`GOLDTrainer`] and [`GOLDConfig`] to perform on-policy distillation with a similar configuration:

 ```python
 from trl.experimental import GOLDConfig
Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .distillation_config import DistillationConfig
+from .distillation_trainer import DistillationTrainer
+
+
+__all__ = ["DistillationConfig", "DistillationTrainer"]
Lines changed: 179 additions & 0 deletions
@@ -0,0 +1,179 @@
+# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# /// script
+# dependencies = [
+#     "trl",
+#     "peft",
+#     "trackio",
+#     "kernels",
+# ]
+# ///
+
+# docstyle-ignore
+"""
+# Full training (off-policy only, lmbda=0):
+```
+python trl/experimental/distillation/distillation.py \
+    --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
+    --teacher_model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
+    --dataset_name trl-lib/chatbot_arena_completions \
+    --learning_rate 2e-5 \
+    --per_device_train_batch_size 4 \
+    --gradient_accumulation_steps 8 \
+    --lmbda 0.0 \
+    --output_dir distilled-model \
+    --num_train_epochs 1
+```
+
+# Mixed on/off-policy (lmbda=0.5):
+```
+python trl/experimental/distillation/distillation.py \
+    --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
+    --teacher_model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
+    --dataset_name trl-lib/chatbot_arena_completions \
+    --learning_rate 2e-5 \
+    --per_device_train_batch_size 4 \
+    --gradient_accumulation_steps 8 \
+    --lmbda 0.5 \
+    --beta 0.5 \
+    --output_dir distilled-model \
+    --num_train_epochs 1
+```
+
+# LoRA:
+```
+python trl/experimental/distillation/distillation.py \
+    --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
+    --teacher_model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
+    --dataset_name trl-lib/chatbot_arena_completions \
+    --learning_rate 2e-4 \
+    --per_device_train_batch_size 4 \
+    --gradient_accumulation_steps 8 \
+    --lmbda 0.0 \
+    --output_dir distilled-model \
+    --num_train_epochs 1 \
+    --use_peft \
+    --lora_r 64 \
+    --lora_alpha 16
+```
+"""
+
+import argparse
+import os
+
+
+# Enable logging in a Hugging Face Space
+os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
+
+
+def main(script_args, training_args, model_args):
+    from datasets import load_dataset
+    from transformers import GenerationConfig
+
+    from trl import (
+        LogCompletionsCallback,
+        get_kbit_device_map,
+        get_peft_config,
+        get_quantization_config,
+    )
+    from trl.experimental.distillation import DistillationTrainer
+
+    ################
+    # Model init kwargs
+    ################
+    quantization_config = get_quantization_config(model_args)
+    model_kwargs = dict(
+        revision=model_args.model_revision,
+        trust_remote_code=model_args.trust_remote_code,
+        attn_implementation=model_args.attn_implementation,
+        torch_dtype=model_args.dtype,
+        use_cache=False if training_args.gradient_checkpointing else True,
+        device_map=get_kbit_device_map() if quantization_config is not None else None,
+        quantization_config=quantization_config,
+    )
+    training_args.model_init_kwargs = model_kwargs
+
+    teacher_model_kwargs = dict(
+        revision=training_args.teacher_model_revision,
+        trust_remote_code=model_args.trust_remote_code,
+        attn_implementation=model_args.attn_implementation,
+        torch_dtype=model_args.dtype,
+        use_cache=True,
+        device_map=get_kbit_device_map() if quantization_config is not None else None,
+        quantization_config=quantization_config,
+    )
+    if training_args.teacher_model_init_kwargs is not None:
+        teacher_model_kwargs.update(training_args.teacher_model_init_kwargs)
+    training_args.teacher_model_init_kwargs = teacher_model_kwargs
+
+    ################
+    # Dataset
+    ################
+    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
+
+    ################
+    # Training
+    ################
+    eval_dataset = None
+    if training_args.eval_strategy != "no":
+        if script_args.dataset_test_split in dataset:
+            eval_dataset = dataset[script_args.dataset_test_split]
+        elif "validation" in dataset:
+            eval_dataset = dataset["validation"]
+        elif "dev" in dataset:
+            eval_dataset = dataset["dev"]
+
+    trainer = DistillationTrainer(
+        model=model_args.model_name_or_path,
+        teacher_model=training_args.teacher_model_name_or_path,
+        args=training_args,
+        train_dataset=dataset[script_args.dataset_train_split],
+        eval_dataset=eval_dataset,
+        peft_config=get_peft_config(model_args),
+    )
+
+    if training_args.eval_strategy != "no":
+        generation_config = GenerationConfig(
+            max_new_tokens=training_args.max_completion_length, do_sample=True, temperature=training_args.temperature
+        )
+        completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
+        trainer.add_callback(completions_callback)
+
+    trainer.train()
+
+    # Save and push to Hub
+    trainer.save_model(training_args.output_dir)
+    if training_args.push_to_hub:
+        trainer.push_to_hub(dataset_name=script_args.dataset_name)
+
+
+def make_parser(subparsers: argparse._SubParsersAction | None = None, prog: str | None = None):
+    from trl import ModelConfig, ScriptArguments, TrlParser
+    from trl.experimental.distillation import DistillationConfig
+
+    dataclass_types = (ScriptArguments, DistillationConfig, ModelConfig)
+    if subparsers is not None:
+        parser = subparsers.add_parser(
+            "distillation", help="Run the distillation training script", dataclass_types=dataclass_types
+        )
+    else:
+        parser = TrlParser(dataclass_types, prog=prog)
+    return parser
+
+
+if __name__ == "__main__":
+    parser = make_parser()
+    script_args, training_args, model_args = parser.parse_args_and_config(fail_with_unknown_args=False)
+    main(script_args, training_args, model_args)
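
The teacher kwargs handling in the script above builds a dict of defaults and then lets any user-supplied `teacher_model_init_kwargs` win via `dict.update`. A minimal sketch of that defaults-then-overrides pattern (toy keys and values, not the real kwargs):

```python
# Build library defaults first, then overlay user-supplied overrides.
teacher_defaults = dict(use_cache=True, revision="main", torch_dtype="auto")
user_kwargs = {"use_cache": False}  # hypothetical user-provided teacher_model_init_kwargs

merged = dict(teacher_defaults)
merged.update(user_kwargs)  # the later update wins, so user values override defaults

print(merged["use_cache"])  # False — user override took effect
print(merged["revision"])   # main — untouched defaults remain
```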
