Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
ffde2d0
Use `VLLMGeneration` in `GOLDTrainer`
cmpatino Mar 18, 2026
1797fc1
Update with precommit
cmpatino Mar 19, 2026
f723677
Initial `DistillationTrainer` implementation
cmpatino Mar 23, 2026
b629987
Fix how we handle padding and special tokens
cmpatino Mar 23, 2026
6746237
Initial implementation of distillation trainer
cmpatino Mar 23, 2026
cdc3196
Address concern about vllm weight sync
cmpatino Mar 23, 2026
f4c193e
Run precommit
cmpatino Mar 23, 2026
2b41f84
Fix max len behavior for generation
cmpatino Mar 23, 2026
91715cb
Format docstring
cmpatino Mar 23, 2026
0ded0db
Merge branch 'kd-vllm-generation' into kd-distillation-trainer
cmpatino Mar 24, 2026
b8754db
Fix data collation issue
cmpatino Mar 24, 2026
b94fc1f
Remove decode -> re-tokenization roundtrip
cmpatino Mar 24, 2026
dcfce59
Run precommit
cmpatino Mar 24, 2026
ff81a89
Add check for tokenizers and prompt length
cmpatino Mar 24, 2026
d075a1e
Merge branch 'kd-vllm-generation' into kd-distillation-trainer
cmpatino Mar 24, 2026
6eeaa8f
Merge branch 'kd-distillation-trainer' of github.com:cmpatino/trl int…
cmpatino Mar 24, 2026
f5fc947
Implement efficient logprob generation
cmpatino Mar 27, 2026
42f3d4d
Fix top-k implementation
cmpatino Mar 30, 2026
4a00c68
Merge branch 'main' into kd-distillation-trainer
cmpatino Mar 30, 2026
d13bc4a
Migrate trainer to experimental
cmpatino Mar 30, 2026
64fdafd
Fix reverse KL calculation for top-1
cmpatino Mar 30, 2026
2730c58
Merge branch 'main' into kd-distillation-trainer
cmpatino Mar 30, 2026
0d4bd05
Run precommit
cmpatino Mar 30, 2026
59aa007
Address cursor comments
cmpatino Mar 30, 2026
91bffa4
Fix reverse KL computation
cmpatino Mar 30, 2026
6d0fe13
Add `DistillationTrainer` to table of contents
cmpatino Mar 30, 2026
b55996e
Remove `DistillationTrainer` from toc
cmpatino Mar 30, 2026
9727ec7
Tighten logic for different top-k scenarios
cmpatino Mar 30, 2026
939b53b
Add tail bucket to reverse KL + server case
cmpatino Mar 30, 2026
47eacd5
Merge branch 'main' into kd-distillation-trainer
cmpatino Mar 31, 2026
bcf21d2
Fix dead code when using full vocab for external teacher
cmpatino Mar 31, 2026
aaf47ad
Remove unused function
cmpatino Mar 31, 2026
89a7b88
Add guard in config for liger + external teacher
cmpatino Mar 31, 2026
16a5380
Tighten alignment logic
cmpatino Mar 31, 2026
1c7b9de
Run precommit
cmpatino Mar 31, 2026
3cb4bc3
Correct completion logic
cmpatino Mar 31, 2026
d1b97a8
Remove wandb config params
cmpatino Mar 31, 2026
691f46c
Address Albert's comments
cmpatino Apr 1, 2026
9ef9192
Run precommit
cmpatino Apr 1, 2026
bb2eee6
Address PR comments
cmpatino Apr 2, 2026
088db42
Match behavior between local and external teacher when top-1 and bet…
cmpatino Apr 7, 2026
90c6e81
Run precommit
cmpatino Apr 7, 2026
864ba11
Address latest comments
cmpatino Apr 9, 2026
0f65a7b
Merge branch 'main' into kd-distillation-trainer
cmpatino Apr 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
319 changes: 268 additions & 51 deletions tests/experimental/test_gold_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from datasets import load_dataset
from transformers import AutoTokenizer

from trl.experimental.gold import gold_trainer as gold_trainer_module
from trl.experimental.gold.gold_trainer import GOLDTrainer, ULDLoss, build_teacher_inputs_from_texts
from trl.experimental.utils import DataCollatorForChatML

Expand Down Expand Up @@ -289,58 +290,11 @@ def pad_labels(labels, target_length):
return labels + [-100] * (target_length - len(labels))


def test_process_completions_to_buffer_left_pads_prompt_retokenization():
class DummyBatch:
def __init__(self, input_ids):
self.input_ids = input_ids

def to(self, device):
self.input_ids = self.input_ids.to(device)
return self

def test_process_completions_to_buffer_left_pads_prompt_ids():
class RecordingTokenizer:
pad_token_id = 0
pad_token = "<pad>"

def __init__(self):
self.padding_side = "right"
self.calls = []
self._prompt_ids = {
"short": [11],
"longer": [21, 22],
}

def __call__(
self,
texts,
return_tensors,
padding,
truncation,
max_length,
add_special_tokens,
padding_side=None,
):
assert return_tensors == "pt"
assert padding == "longest"
assert not truncation
assert max_length is None
assert not add_special_tokens
self.calls.append(padding_side)

side = padding_side or self.padding_side
encoded = [torch.tensor(self._prompt_ids[text], dtype=torch.long) for text in texts]
max_len = max(len(ids) for ids in encoded)

padded = []
for ids in encoded:
pad_width = max_len - len(ids)
if pad_width:
pad = torch.full((pad_width,), self.pad_token_id, dtype=torch.long)
ids = torch.cat([pad, ids]) if side == "left" else torch.cat([ids, pad])
padded.append(ids)

return DummyBatch(torch.stack(padded))

def batch_decode(self, sequences, skip_special_tokens=False, clean_up_tokenization_spaces=False):
del skip_special_tokens, clean_up_tokenization_spaces
return [" ".join(str(token) for token in sequence) for sequence in sequences]
Expand All @@ -358,19 +312,282 @@ def batch_decode(self, sequences, skip_special_tokens=False, clean_up_tokenizati
on_policy_indices=[0],
local_slice_indices=[0, 0],
completion_ids=[[31], [41]],
prompts_text=["short", "longer"],
prompts_text_with_special=["short", "longer"],
prompt_ids_list=[[11], [21, 22]],
prompts_text=["short", "longer"],
max_completion_length=1,
)

buffered_inputs = trainer._buffered_inputs[0]
assert trainer.processing_class.calls == ["left"]
assert trainer.processing_class.padding_side == "right"
assert torch.equal(buffered_inputs["input_ids"], torch.tensor([[0, 11, 31], [21, 22, 41]], dtype=torch.long))
assert torch.equal(buffered_inputs["attention_mask"], torch.tensor([[0, 1, 1], [1, 1, 1]], dtype=torch.long))
assert torch.equal(buffered_inputs["labels"], torch.tensor([[-100, -100, 31], [-100, -100, 41]]))


def test_generate_on_policy_for_slices_uses_prompt_attention_mask_for_vllm_prompts():
    """vLLM must receive only the attended prompt tokens: padding positions (mask == 0)
    are stripped, while a pad-id token that the attention mask marks as real is kept."""

    class SpyVLLMGeneration:
        # Stand-in for the vLLM wrapper: records the prompts it is asked to generate
        # from and returns a single fixed completion ([42]).
        def __init__(self):
            self.prompts = None
            self.sync_calls = 0

        def sync_weights(self):
            self.sync_calls += 1

        def generate(self, prompts, images, num_generations):
            self.prompts = prompts
            assert images is None
            assert num_generations == 1
            return None, [[42]], None, None

    class SpyTokenizer:
        pad_token_id = 9
        pad_token = "<eos>"

        def batch_decode(self, sequences, skip_special_tokens=False, clean_up_tokenization_spaces=False):
            del clean_up_tokenization_spaces
            # Tiny fixed vocabulary; token 9 doubles as the pad/eos "special" token.
            vocab = {5: "A", 6: "B", 9: "<eos>"}
            return [
                " ".join(
                    vocab[int(token)]
                    for token in sequence
                    if not (skip_special_tokens and int(token) == 9)
                )
                for sequence in sequences
            ]

    captured = {}

    def capture_process_completions(
        slices,
        on_policy_indices,
        local_slice_indices,
        completion_ids,
        prompt_ids_list,
        prompts_text_with_special,
        prompts_text,
        max_completion_length,
    ):
        # Record every forwarded argument so the assertions below can inspect them.
        captured.update(
            slices=slices,
            on_policy_indices=on_policy_indices,
            local_slice_indices=local_slice_indices,
            completion_ids=completion_ids,
            prompt_ids_list=prompt_ids_list,
            prompts_text=prompts_text,
            prompts_text_with_special=prompts_text_with_special,
            max_completion_length=max_completion_length,
        )

    # Bypass __init__ and wire up only the attributes the method under test reads.
    trainer = GOLDTrainer.__new__(GOLDTrainer)
    trainer.accelerator = SimpleNamespace(is_main_process=True)
    trainer.args = SimpleNamespace(report_to=[])
    trainer.processing_class = SpyTokenizer()
    trainer.use_vllm = True
    trainer.vllm_generation = SpyVLLMGeneration()
    trainer.vllm_sync_frequency = 1
    trainer._last_vllm_sync_step = -1
    trainer.state = SimpleNamespace(global_step=0)
    trainer.num_generations = 1
    trainer.generation_config = SimpleNamespace(max_new_tokens=1)
    trainer._process_completions_to_buffer = capture_process_completions

    # Two left-pad positions (mask 0) followed by real tokens 5, 9, 6 — the middle 9
    # is the pad id but is attended to, so it must survive into the prompt.
    slices = [
        {
            "prompts": torch.tensor([[9, 9, 5, 9, 6]], dtype=torch.long),
            "prompt_attention_mask": torch.tensor([[0, 0, 1, 1, 1]], dtype=torch.long),
        }
    ]

    GOLDTrainer._generate_on_policy_for_slices(trainer, slices, [0])

    assert trainer.vllm_generation.prompts == [[5, 9, 6]]
    assert trainer.vllm_generation.sync_calls == 1
    assert captured["completion_ids"] == [[42]]
    assert captured["prompt_ids_list"] == [[5, 9, 6]]
    # Plain text drops the special token; the "with special" variant renders it.
    assert captured["prompts_text"] == ["A B"]
    assert captured["prompts_text_with_special"] == ["A <eos> B"]


def test_generate_on_policy_for_slices_reconstructs_prompt_with_special_tokens():
    """Buffered on-policy inputs must keep special tokens in the prompt ids (only padding
    is dropped), while the text logs use the special-token-free decoding."""

    class RecordingVLLMGeneration:
        # Stand-in for the vLLM wrapper: records the prompts it receives and returns a
        # single fixed completion ([42]).
        def __init__(self):
            self.prompts = None
            self.sync_calls = 0

        def sync_weights(self):
            self.sync_calls += 1

        def generate(self, prompts, images, num_generations):
            self.prompts = prompts
            assert images is None
            assert num_generations == 1
            return None, [[42]], None, None

    class RecordingTokenizer:
        # Tiny fixed-vocabulary tokenizer: 0 is padding (always dropped from decoded
        # text), 13 is the only "special" token.
        pad_token_id = 0
        pad_token = "<pad>"

        def __init__(self):
            self.truncation_side = "right"

        def batch_decode(self, sequences, skip_special_tokens=False, clean_up_tokenization_spaces=False):
            del clean_up_tokenization_spaces
            token_map = {0: "<pad>", 5: "A", 6: "B", 13: "<special>", 42: "C"}
            decoded = []
            for sequence in sequences:
                tokens = []
                for token in sequence:
                    token = int(token)
                    # Special token 13 is skipped only when requested; padding always is.
                    if skip_special_tokens and token == 13:
                        continue
                    if token == 0:
                        continue
                    tokens.append(token_map[token])
                decoded.append(" ".join(tokens))
            return decoded

    # Bypass __init__ and wire up only the attributes the method under test reads.
    trainer = GOLDTrainer.__new__(GOLDTrainer)
    trainer.accelerator = SimpleNamespace(device=torch.device("cpu"), is_main_process=True)
    trainer.processing_class = RecordingTokenizer()
    trainer.args = SimpleNamespace(max_length=None, report_to=[])
    trainer.use_vllm = True
    trainer.vllm_generation = RecordingVLLMGeneration()
    trainer.vllm_sync_frequency = 1
    trainer._last_vllm_sync_step = -1
    trainer.state = SimpleNamespace(global_step=0)
    trainer.num_generations = 1
    trainer.generation_config = SimpleNamespace(max_new_tokens=1)
    trainer._buffered_inputs = [None]
    trainer._buffered_text_logs = [None]

    # Two left-pad positions (mask 0) followed by real tokens 5, 13, 6; 13 is special
    # but attended to, so it must be preserved in the reconstructed prompt.
    slices = [
        {
            "slice": "original",
            "prompts": torch.tensor([[0, 0, 5, 13, 6]], dtype=torch.long),
            "prompt_attention_mask": torch.tensor([[0, 0, 1, 1, 1]], dtype=torch.long),
        }
    ]

    GOLDTrainer._generate_on_policy_for_slices(trainer, slices, [0])

    buffered_inputs = trainer._buffered_inputs[0]
    assert trainer.vllm_generation.prompts == [[5, 13, 6]]
    assert trainer.vllm_generation.sync_calls == 1
    # Buffered ids = unpadded prompt + generated completion; labels mask the prompt.
    assert torch.equal(buffered_inputs["input_ids"], torch.tensor([[5, 13, 6, 42]], dtype=torch.long))
    assert torch.equal(buffered_inputs["attention_mask"], torch.tensor([[1, 1, 1, 1]], dtype=torch.long))
    assert torch.equal(buffered_inputs["labels"], torch.tensor([[-100, -100, -100, 42]], dtype=torch.long))
    # Original texts retain the special token; the text logs decode without it.
    assert buffered_inputs["original_prompt_text"] == ["A <special> B"]
    assert buffered_inputs["original_completion_text"] == ["C"]
    assert trainer._buffered_text_logs[0] == (["A B"], ["C"])


def test_gold_trainer_init_defaults_vllm_max_model_length_to_max_length(monkeypatch):
    """When `vllm_max_model_length` is None, GOLDTrainer.__init__ should pass
    `args.max_length` as `max_model_length` to `VLLMGeneration`, and resize the
    teacher's token embeddings to the student's vocab size."""
    captured = {}

    class DummyStudentModel:
        # Bare-bones student exposing only what __init__ reads: config (with vocab
        # size 17), generation_config, and the model name/path.
        def __init__(self):
            self.config = SimpleNamespace(_name_or_path="student", vocab_size=17)
            self.generation_config = SimpleNamespace(eos_token_id=2)
            self.name_or_path = "student"

    class DummyTeacherModel:
        def __init__(self):
            self.resized_to = None

        def resize_token_embeddings(self, vocab_size):
            # Record the requested size so the test can assert it matches the student.
            self.resized_to = vocab_size

    class DummyProcessingClass:
        pad_token_id = 0

    def fake_sft_init(
        self,
        model,
        args=None,
        data_collator=None,
        train_dataset=None,
        eval_dataset=None,
        processing_class=None,
        compute_metrics=None,
        callbacks=None,
        optimizers=None,
        preprocess_logits_for_metrics=None,
        peft_config=None,
    ):
        # Replacement for SFTTrainer.__init__: skips real model/dataset setup and only
        # assigns the attributes GOLDTrainer.__init__ relies on afterwards.
        del data_collator, train_dataset, eval_dataset, compute_metrics, callbacks, optimizers
        del preprocess_logits_for_metrics, peft_config
        self.model = model
        self.args = args
        self.processing_class = processing_class
        self.accelerator = SimpleNamespace(
            device=torch.device("cpu"),
            num_processes=1,
            prepare_model=lambda module, evaluation_mode=True: module,
        )
        self.is_deepspeed_enabled = False
        self.is_fsdp_enabled = False

    class CapturingVLLMGeneration:
        # Captures the keyword arguments GOLDTrainer uses to build VLLMGeneration.
        def __init__(self, **kwargs):
            captured.update(kwargs)

    monkeypatch.setattr(gold_trainer_module.SFTTrainer, "__init__", fake_sft_init)
    monkeypatch.setattr(gold_trainer_module, "is_vllm_available", lambda: True)
    monkeypatch.setattr(gold_trainer_module, "VLLMGeneration", CapturingVLLMGeneration)

    # Full config surface GOLDTrainer.__init__ reads; the case under test is
    # max_length=128 with vllm_max_model_length=None.
    args = SimpleNamespace(
        model_init_kwargs=None,
        max_length=128,
        use_liger_kernel=False,
        teacher_model_init_kwargs=None,
        use_uld_loss=False,
        teacher_tokenizer_name_or_path=None,
        teacher_model_revision=None,
        disable_dropout=False,
        lmbda=1.0,
        beta=0.5,
        temperature=1.0,
        top_p=1.0,
        seq_kd=False,
        num_generations=1,
        use_transformers_paged=False,
        max_completion_length=16,
        top_k=0,
        log_completions=False,
        log_completions_steps=100,
        wandb_log_unique_prompts=True,
        num_completions_to_print=None,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=1,
        use_vllm=True,
        vllm_mode="colocate",
        vllm_structured_outputs_regex=None,
        vllm_server_base_url=None,
        vllm_server_host="0.0.0.0",
        vllm_server_port=8001,
        vllm_group_port=51216,
        vllm_server_timeout=240.0,
        vllm_tensor_parallel_size=1,
        vllm_gpu_memory_utilization=0.2,
        vllm_max_model_length=None,
        vllm_enable_sleep_mode=False,
        vllm_model_impl="vllm",
        vllm_sync_frequency=1,
    )

    teacher_model = DummyTeacherModel()
    GOLDTrainer(
        model=DummyStudentModel(),
        teacher_model=teacher_model,
        args=args,
        data_collator=object(),
        processing_class=DummyProcessingClass(),
    )

    # Teacher resized to student vocab (17); max_model_length fell back to max_length.
    assert teacher_model.resized_to == 17
    assert captured["max_model_length"] == 128


def test_alignment_groups_cover_all_tokens(llama_tokenizer, qwen_tokenizer):
config = build_config()
loss = ULDLoss(config, student_tokenizer=llama_tokenizer, teacher_tokenizer=qwen_tokenizer)
Expand Down
4 changes: 4 additions & 0 deletions trl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
"scripts": ["DatasetMixtureConfig", "ScriptArguments", "TrlParser", "get_dataset", "init_zero_verbose"],
"trainer": [
"BEMACallback",
"DistillationConfig",
"DistillationTrainer",
"DPOConfig",
"DPOTrainer",
"GRPOConfig",
Expand Down Expand Up @@ -90,6 +92,8 @@
from .scripts import DatasetMixtureConfig, ScriptArguments, TrlParser, get_dataset, init_zero_verbose
from .trainer import (
BEMACallback,
DistillationConfig,
DistillationTrainer,
DPOConfig,
DPOTrainer,
GRPOConfig,
Expand Down
Loading
Loading