Add DistillationTrainer for efficient on-policy distillation#5407

Merged

cmpatino merged 44 commits into huggingface:main from cmpatino:kd-distillation-trainer on Apr 9, 2026

Conversation

@cmpatino (Collaborator) commented Mar 30, 2026

What does this PR do?

Introduces a new DistillationTrainer with the following capabilities:

  • Generalized JSD loss supporting different types of KL divergence.
  • The possibility to mix on- and off-policy data.
  • Buffering to speed up the on-policy generation process.
  • Support for a teacher served on an external server.
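The generalized JSD objective interpolates between forward and reverse KL via a beta parameter. The sketch below shows one common parameterization (GKD-style distillation); it is illustrative only, and the exact convention, names, and reduction used in this PR may differ.

```python
import math

import torch
import torch.nn.functional as F


def generalized_jsd_loss(student_logits, teacher_logits, beta=0.5):
    """Sketch of a generalized JSD distillation loss.

    beta = 0 recovers forward KL(teacher || student), beta = 1 recovers
    reverse KL(student || teacher), and beta = 0.5 gives the symmetric JSD.
    Hypothetical implementation, not the code merged in this PR.
    """
    student_logp = F.log_softmax(student_logits, dim=-1)
    teacher_logp = F.log_softmax(teacher_logits, dim=-1)
    if beta == 0.0:  # forward KL: KL(teacher || student)
        return F.kl_div(student_logp, teacher_logp, log_target=True, reduction="batchmean")
    if beta == 1.0:  # reverse KL: KL(student || teacher)
        return F.kl_div(teacher_logp, student_logp, log_target=True, reduction="batchmean")
    # Mixture M = (1 - beta) * student + beta * teacher, computed in log space.
    mixture_logp = torch.logsumexp(
        torch.stack([student_logp + math.log(1 - beta), teacher_logp + math.log(beta)]),
        dim=0,
    )
    return (1 - beta) * F.kl_div(mixture_logp, teacher_logp, log_target=True, reduction="batchmean") + beta * F.kl_div(mixture_logp, student_logp, log_target=True, reduction="batchmean")
```

When student and teacher agree, every KL term vanishes, so the loss is zero for any beta.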

Before submitting

  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the documentation with your changes?

AI writing disclosure

We welcome the use of AI tools to help with contributions. For transparency and to help us improve our review process, please indicate the level of AI involvement in this PR.

  • No AI usage: the PR was written entirely by a human.
  • AI-assisted: some parts were suggested or improved by AI, but the PR was written and reviewed by a human.
  • AI-generated: the PR was mostly or fully generated by an AI tool.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.


Note

High Risk
High risk due to introducing a new, complex training path (DistillationTrainer) with custom batching/generation and new vLLM server/client APIs for token-level logprobs, which could affect training correctness and server stability/performance.

Overview
Introduces a new experimental DistillationTrainer/DistillationConfig that implements generalized JSD distillation (forward/reverse KL via beta), mixes on- and off-policy data via buffered generation across gradient accumulation, and optionally generates student rollouts with vLLM.

Adds optional external teacher support: a new vLLM server endpoint /get_sequence_logprobs/ (with request batching and a binary base64 format) and a matching VLLMClient.get_sequence_logprobs() API, enabling distillation using teacher token-level logprobs without loading the teacher locally.
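The thread does not spell out the binary base64 wire format, but the general idea of packing token-level logprobs compactly can be sketched as follows. Everything here (function names, little-endian float32 layout) is an assumption for illustration; the actual /get_sequence_logprobs/ payload also carries token ids and per-sequence shapes.

```python
import base64
import struct


def encode_logprobs(logprobs: list[float]) -> str:
    """Pack a flat list of logprobs as little-endian float32, base64-encoded.

    Illustrative sketch only; the real endpoint's format may differ.
    """
    raw = struct.pack(f"<{len(logprobs)}f", *logprobs)
    return base64.b64encode(raw).decode("ascii")


def decode_logprobs(payload: str) -> list[float]:
    """Inverse of encode_logprobs: base64 -> bytes -> float32 list."""
    raw = base64.b64decode(payload)
    return list(struct.unpack(f"<{len(raw) // 4}f", raw))
```

Compared with JSON arrays of floats, a packed binary field like this cuts payload size and parsing cost, which matters when shipping per-token teacher logprobs for long sequences.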

Updates the distillation documentation examples to reference experimental.distillation and renames the config field from max_new_tokens to max_completion_length in the shown setup.

Reviewed by Cursor Bugbot for commit 0f65a7b. Bugbot is set up for automated code reviews on this repo.

@cmpatino cmpatino changed the title Kd distillation trainer Add DistillationTrainer for efficient on-policy distillation Mar 30, 2026
@qgallouedec (Member) left a comment

IMO it looks good and can be merged once we've moved the trainer to experimental. After that, we should

  • add documentation
  • add tests
  • constrain to vLLM 0.18 (to avoid trl-vllm-client/trl-vllm-server)
  • align internals with other trainers

Either we merge this before v1.0, or we merge it after and release v1.1 shortly afterward, so that in both cases users can just pip install trl to use it.

@cmpatino cmpatino marked this pull request as ready for review March 30, 2026 16:52
@cmpatino cmpatino marked this pull request as draft March 30, 2026 16:55
@cmpatino cmpatino marked this pull request as ready for review March 30, 2026 17:04
@albertvillanova (Member) left a comment

Thanks for the awesome work! 🚀 🤗

I made a first pass: some comments and questions below.

decoded = VLLMClient._decode_binary_logprobs(resp)
all_logprobs.extend(decoded["logprobs"])
all_token_ids.extend(decoded["logprob_token_ids"])
if "actual_logprobs" in decoded:
Member

Shouldn't we force here that either ALL decoded resp in responses contain the "actual_logprobs" key, or NONE of them do? With the current implementation, a potential misalignment (where actual_logprobs is shorter than logprobs) may occur if there is a bug in the server.
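The all-or-none check suggested above could look like the following. This is a hypothetical helper (the name and the exact error message are illustrative), not what was ultimately merged.

```python
def check_actual_logprobs_consistency(responses: list[dict]) -> bool:
    """Require that "actual_logprobs" is present in all responses or in none.

    Returns True when the key is present everywhere, False when it is absent
    everywhere, and raises on a mixed (likely buggy) server reply.
    """
    has_key = ["actual_logprobs" in r for r in responses]
    if any(has_key) and not all(has_key):
        raise ValueError(
            "Inconsistent server responses: 'actual_logprobs' present in some "
            "responses but not others."
        )
    return all(has_key)
```

Failing fast here turns a silent length misalignment between logprobs and actual_logprobs into an immediate, debuggable error.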

Member

I see you have addressed this, but I think the implementation is complicated. I'm commenting inline.

@albertvillanova (Member) left a comment

Just a suggestion to simplify the implementation of _merge_binary_responses.


labels_mask = inputs["labels"] != -100
masked_input_ids = torch.where(labels_mask, inputs["input_ids"], torch.full_like(inputs["input_ids"], -100))
true_labels = masked_input_ids[:, 1:].contiguous().reshape(-1)

Liger loss path ignores prompt_length slicing, wasting compute

Low Severity

_compute_liger_loss passes the full sequence's hidden states (minus only the last token) and constructs true_labels from the entire sequence. Unlike the non-Liger path which slices from prompt_length, this forces the Liger kernel to process and materialize logits for all prompt positions (which are subsequently masked by -100). For long prompts with short completions, this significantly wastes GPU memory and compute.
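The fix suggested above amounts to slicing before the fused kernel. The helper below is a hypothetical sketch (the name slice_completion_only and the exact tensor layout are assumptions, not the PR's code), assuming right-appended completions and standard next-token alignment:

```python
import torch


def slice_completion_only(hidden_states, input_ids, prompt_length):
    """Keep only completion positions before handing tensors to a fused loss.

    Positions predict the NEXT token, so the slice starts one position before
    the first completion token; this avoids materializing logits for prompt
    positions whose labels would be -100 anyway.
    """
    start = prompt_length - 1
    # (batch, seq_len, hidden) -> (batch, completion_len, hidden)
    sliced_hidden = hidden_states[:, start:-1, :]
    # Labels are the completion token ids themselves.
    labels = input_ids[:, prompt_length:]
    return sliced_hidden, labels
```

For a long prompt with a short completion, this shrinks the logit tensor the kernel must build from O(seq_len x vocab) to O(completion_len x vocab).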


@cmpatino cmpatino requested a review from albertvillanova April 2, 2026 13:57
# Process completions into the buffer
self._store_completions_in_buffer(
slices, on_policy_indices, local_slice_indices, local_prompts, completion_ids
)

vLLM deduplication breaks with misaligned repeat groups

Medium Severity

When num_generations > 1, use_vllm=True, and lmbda < 1.0, _generate_student_completions collects prompts only from on-policy slices and passes them to VLLMGeneration.generate. That method deduplicates via all_prompts[::num_generations], which assumes prompts are repeated num_generations times consecutively. When per_device_train_batch_size is not divisible by num_generations, split_tensor_dict can split across repeat group boundaries, breaking the consecutive-repeat invariant. This causes the deduplication to skip real unique prompts, generating completions for wrong prompts or producing a count mismatch.
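The broken invariant is easy to reproduce in isolation. The snippet below is a minimal, self-contained repro of the stride-based deduplication assumption (names are illustrative, not the actual TRL internals):

```python
num_generations = 2

# Each unique prompt repeated num_generations times, consecutively.
all_prompts = ["A", "A", "B", "B", "C", "C"]

# Stride-based dedup works while repeat groups stay aligned:
assert all_prompts[::num_generations] == ["A", "B", "C"]

# Splitting into chunks of size 3 (not divisible by num_generations)
# cuts across the "B" repeat group:
chunks = [all_prompts[i : i + 3] for i in range(0, len(all_prompts), 3)]
deduped = [chunk[::num_generations] for chunk in chunks]
print(deduped)  # [['A', 'B'], ['B', 'C']]
```

"B" now appears as a "unique" prompt in both chunks, so generating num_generations completions per deduped prompt yields 8 completions for 6 prompts: exactly the count mismatch described above.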


Reviewed by Cursor Bugbot for commit 088db42.


return Response(content=orjson.dumps(payload), media_type="application/json")
except ImportError:
return payload

Starlette import inside thread-pool executor is unnecessary overhead

Low Severity

_format_logprob_response is invoked via loop.run_in_executor (a thread pool). The binary path unconditionally does from starlette.responses import Response even when orjson is not installed, in which case the Response class is imported but never used. Additionally, the function has two completely separate formatting paths (binary vs JSON) that could benefit from being split for clarity, since the binary path alone is ~70 lines.
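One conventional remedy is to resolve both optional dependencies once at module import time, so the hot path never imports anything. This is a generic sketch of the pattern, not the PR's actual module layout:

```python
# Resolve optional deps once; the except arm covers either package missing.
try:
    import orjson
    from starlette.responses import Response
except ImportError:
    orjson = None
    Response = None


def format_binary(payload: dict):
    """Illustrative formatter: take the fast binary/orjson path only when
    both optional dependencies are available, otherwise return the plain
    payload for the framework's default JSON serialization."""
    if orjson is not None and Response is not None:
        return Response(content=orjson.dumps(payload), media_type="application/json")
    return payload
```

Splitting the ~70-line binary path into its own function, as the comment suggests, would pair naturally with this: the dispatch stays tiny and each path is testable on its own.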


Reviewed by Cursor Bugbot for commit 088db42.

@albertvillanova (Member) left a comment

Thanks. Great work!

Maybe @qgallouedec could give a final review before merging. What do you think?

future.set_result((outputs_slice, prompt_lengths, top_logprobs, response_format))
except Exception as e:
# Signal error to all waiting requests in this execution-parameter group
for _, _, _, future in items:
Member

Align with the code below:

Suggested change
for _, _, _, future in items:
for *_, future in items:

"""
if len(request.sequences) != len(request.prompt_lengths):
raise ValueError("sequences and prompt_lengths must have the same length.")

Member

Shouldn't we validate here as well that 0 < prompt_lengths[i] < len(sequences[i])?

An invalid value produces a zero or negative comp_lengths[i], which is silently skipped (if comp_lengths[i] == 0: continue), returning an all-padding row with no error raised.
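The bound check proposed above can be written as an up-front validation pass. This is a hypothetical helper sketching the reviewer's suggestion (0 < prompt_length < len(sequence)); the merged code may enforce slightly different bounds:

```python
def validate_prompt_lengths(sequences, prompt_lengths):
    """Reject prompt lengths that would yield empty or negative completion
    spans, instead of silently emitting all-padding rows downstream."""
    if len(sequences) != len(prompt_lengths):
        raise ValueError("sequences and prompt_lengths must have the same length.")
    for i, (seq, plen) in enumerate(zip(sequences, prompt_lengths)):
        if not 0 < plen < len(seq):
            raise ValueError(
                f"prompt_lengths[{i}]={plen} must satisfy "
                f"0 < prompt_length < len(sequences[{i}])={len(seq)}."
            )
```

With this in place, the `if comp_lengths[i] == 0: continue` branch can never hide a malformed request.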

Collaborator (Author)

I addressed this by validating that the prompt lengths are non-negative and no longer than the sequences.

@qgallouedec (Member) left a comment

Looks good. I recommend adding a page in the docs in a follow-up PR.

@cursor (bot) left a comment

Cursor Bugbot has reviewed your changes and found 2 potential issues.

There are 5 total unresolved issues (including 3 from previous reviews).


Reviewed by Cursor Bugbot for commit 864ba11.

connection.send({"type": "call", "method": "generate", "kwargs": kwargs})
all_outputs = [connection.recv() for connection in connections]
all_outputs = [output for output, chunk in zip(all_outputs, chunked_prompts, strict=True) if chunk]
return list(chain.from_iterable(all_outputs))

Race condition on shared connections between endpoints

Low Severity

_run_prompt_logprobs accesses the shared connections list and runs in a thread pool via loop.run_in_executor. The existing generate endpoint also accesses connections directly in the event loop thread. If both endpoints are called on the same server, the generate handler can execute synchronous pipe I/O on connections while the executor thread concurrently runs _run_prompt_logprobs, causing data corruption on the multiprocessing.Pipe objects which are not thread-safe.
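A standard way to remove this race is to serialize all pipe I/O behind one lock shared by every endpoint, whichever thread they run on. The sketch below is illustrative (the function name and the fake connection in the usage are assumptions; the real server uses multiprocessing.Pipe objects):

```python
import threading

# One lock guarding every synchronous send/recv on the worker connections,
# whether the caller is the event-loop thread or a thread-pool executor.
_connections_lock = threading.Lock()


def send_to_workers(connections, message):
    """Round-trip a message to all worker connections under the lock, so two
    endpoints can never interleave pipe I/O on the same connection."""
    with _connections_lock:
        for connection in connections:
            connection.send(message)
        return [connection.recv() for connection in connections]
```

Holding the lock across the full send/recv round trip matters: locking only the sends would still let another thread's recv interleave with this one's.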


Reviewed by Cursor Bugbot for commit 864ba11.

return {
"logprobs": json_response["logprobs"],
"logprob_token_ids": json_response["logprob_token_ids"],
}

Non-binary server format missing actual_logprobs causes KeyError

Low Severity

The JSON (non-binary) server response format only returns logprobs and logprob_token_ids, omitting actual_logprobs and actual_token_ids. The client's non-binary path mirrors this. However, _get_teacher_token_logprobs_from_server unconditionally accesses result["actual_logprobs"], which raises a KeyError when use_binary=False. The binary format is the default, but the two code paths are inconsistent.
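One way to reconcile the two paths is to normalize every decoded response to the same key set before the trainer touches it. This is a hypothetical fix sketch (names and the empty-list defaults are assumptions), not the resolution adopted in the PR:

```python
def normalize_logprob_response(resp: dict) -> dict:
    """Give the JSON path the same keys as the binary path, so downstream
    code can index result["actual_logprobs"] without a KeyError."""
    return {
        "logprobs": resp["logprobs"],
        "logprob_token_ids": resp["logprob_token_ids"],
        "actual_logprobs": resp.get("actual_logprobs", []),
        "actual_token_ids": resp.get("actual_token_ids", []),
    }
```

Whether defaulting or raising is the right behavior depends on whether actual_logprobs is genuinely optional in the protocol; if it is required, the normalizer should raise instead of defaulting.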


Reviewed by Cursor Bugbot for commit 864ba11.

@cmpatino cmpatino merged commit c475b97 into huggingface:main Apr 9, 2026
14 of 15 checks passed
@cmpatino cmpatino deleted the kd-distillation-trainer branch April 9, 2026 15:27

4 participants