Add DistillationTrainer for efficient on-policy distillation (#5407)
cmpatino merged 44 commits into huggingface:main from …o kd-distillation-trainer
Conversation
qgallouedec
left a comment
IMO it looks good and can be merged once we've moved the trainer to experimental. After that, we should
- add documentation
- add tests
- constrain the use of vLLM 0.18 (to avoid trl-vllm-client/trl-vllm-server)
- align internals with other trainers
Either we merge this before v1.0, or we merge it after and release v1.1 shortly afterwards; in both cases users can just pip install trl to use it.
albertvillanova
left a comment
Thanks for the awesome work! 🚀 🤗
I made a first pass: some comments and questions below.
```python
decoded = VLLMClient._decode_binary_logprobs(resp)
all_logprobs.extend(decoded["logprobs"])
all_token_ids.extend(decoded["logprob_token_ids"])
if "actual_logprobs" in decoded:
```
Shouldn't we enforce here that either ALL decoded responses contain the "actual_logprobs" key, or NONE of them do? With the current implementation, a misalignment (where actual_logprobs ends up shorter than logprobs) can occur if there is any bug in the server.
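A minimal sketch of that all-or-none check (merge_decoded_responses is an illustrative helper, not the PR's actual function):

```python
# Hypothetical sketch: fail fast if the decoded responses disagree on
# whether they carry the optional "actual_logprobs" key, instead of
# silently extending misaligned lists.
def merge_decoded_responses(decoded_responses):
    has_actual = [("actual_logprobs" in d) for d in decoded_responses]
    if any(has_actual) and not all(has_actual):
        raise ValueError(
            "Inconsistent server responses: some chunks contain "
            "'actual_logprobs' and others do not."
        )
    merged = {"logprobs": [], "logprob_token_ids": []}
    if has_actual and all(has_actual):
        merged["actual_logprobs"] = []
    for d in decoded_responses:
        merged["logprobs"].extend(d["logprobs"])
        merged["logprob_token_ids"].extend(d["logprob_token_ids"])
        if "actual_logprobs" in merged:
            merged["actual_logprobs"].extend(d["actual_logprobs"])
    return merged
```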
I see you have addressed this, but I think the implementation is complicated. I'm commenting inline.
albertvillanova
left a comment
Just a suggestion to simplify the implementation of _merge_binary_responses.
```python
labels_mask = inputs["labels"] != -100
masked_input_ids = torch.where(labels_mask, inputs["input_ids"], torch.full_like(inputs["input_ids"], -100))
true_labels = masked_input_ids[:, 1:].contiguous().reshape(-1)
```
Liger loss path ignores prompt_length slicing, wasting compute
Low Severity
_compute_liger_loss passes the full sequence's hidden states (minus only the last token) and constructs true_labels from the entire sequence. Unlike the non-Liger path which slices from prompt_length, this forces the Liger kernel to process and materialize logits for all prompt positions (which are subsequently masked by -100). For long prompts with short completions, this significantly wastes GPU memory and compute.
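The fix direction is to slice before the fused kernel sees the prompt. A toy sketch with plain lists (the same indexing applies to tensors along the sequence dimension; slice_completion_only is a hypothetical helper, not the PR's code):

```python
# Hypothetical sketch: keep only the positions that predict completion
# tokens, so prompt logits are never materialized by the fused loss.
# hidden_states: one vector per position; input_ids: one token id per position.
def slice_completion_only(hidden_states, input_ids, prompt_length):
    # Position i predicts token i + 1, so the first useful hidden state is at
    # prompt_length - 1 and the last useful one is the second-to-last position.
    comp_hidden = hidden_states[prompt_length - 1 : -1]
    comp_labels = input_ids[prompt_length:]
    return comp_hidden, comp_labels
```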
```python
# Process completions into the buffer
self._store_completions_in_buffer(
    slices, on_policy_indices, local_slice_indices, local_prompts, completion_ids
)
```
vLLM deduplication breaks with misaligned repeat groups
Medium Severity
When num_generations > 1, use_vllm=True, and lmbda < 1.0, _generate_student_completions collects prompts only from on-policy slices and passes them to VLLMGeneration.generate. That method deduplicates via all_prompts[::num_generations], which assumes prompts are repeated num_generations times consecutively. When per_device_train_batch_size is not divisible by num_generations, split_tensor_dict can split across repeat group boundaries, breaking the consecutive-repeat invariant. This causes the deduplication to skip real unique prompts, generating completions for wrong prompts or producing a count mismatch.
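One defensive option (a sketch, not the PR's implementation) is to verify the consecutive-repeat invariant before striding, so a split across repeat-group boundaries raises instead of silently generating for the wrong prompts:

```python
# Hypothetical sketch: the striding dedup (all_prompts[::num_generations])
# only works if prompts appear in consecutive blocks of num_generations
# repeats. Checking that invariant turns silent corruption into an error.
def dedup_repeated_prompts(all_prompts, num_generations):
    if len(all_prompts) % num_generations != 0:
        raise ValueError("prompt count is not a multiple of num_generations")
    unique = all_prompts[::num_generations]
    for i, prompt in enumerate(all_prompts):
        if prompt != unique[i // num_generations]:
            raise ValueError("prompts are not grouped in consecutive repeats")
    return unique
```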
Reviewed by Cursor Bugbot for commit 088db42.
```python
    return Response(content=orjson.dumps(payload), media_type="application/json")
except ImportError:
    return payload
```
Starlette import inside thread-pool executor is unnecessary overhead
Low Severity
_format_logprob_response is invoked via loop.run_in_executor (a thread pool). The binary path unconditionally does from starlette.responses import Response even when orjson is not installed, in which case the Response class is imported but never used. Additionally, the function has two completely separate formatting paths (binary vs JSON) that could benefit from being split for clarity, since the binary path alone is ~70 lines.
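A possible restructuring, sketched under the assumption that the optional dependencies can be probed once at import time (helper names here are illustrative, not the PR's API):

```python
# Hypothetical sketch: probe the optional dependencies once at module
# import, and keep the JSON and binary formatting in separate helpers so
# the thread-pool worker never re-imports modules per request.
try:
    import orjson
    from starlette.responses import Response
except ImportError:
    orjson = None
    Response = None

def format_json_payload(payload):
    # Plain dicts are serialized by the framework itself.
    return payload

def format_binary_payload(payload):
    if orjson is None or Response is None:
        # Fall back to plain JSON when orjson/starlette are unavailable.
        return payload
    return Response(content=orjson.dumps(payload), media_type="application/json")
```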
Reviewed by Cursor Bugbot for commit 088db42.
albertvillanova
left a comment
Thanks. Great work!
Maybe @qgallouedec could give a final review before merging. What do you think?
trl/scripts/vllm_serve.py (outdated)
```python
    future.set_result((outputs_slice, prompt_lengths, top_logprobs, response_format))
except Exception as e:
    # Signal error to all waiting requests in this execution-parameter group
    for _, _, _, future in items:
```
Align with the code below:
```diff
- for _, _, _, future in items:
+ for *_, future in items:
```
```python
    """
    if len(request.sequences) != len(request.prompt_lengths):
        raise ValueError("sequences and prompt_lengths must have the same length.")
```
Shouldn't we validate here as well that 0 < prompt_lengths[i] < len(sequences[i])?
An invalid value produces a zero or negative comp_lengths[i], which is silently skipped (if comp_lengths[i] == 0: continue), returning an all-padding row with no error raised.
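The suggested check could be sketched like this (validate_request is an illustrative name; the bounds shown are the reviewer's strict proposal):

```python
# Hypothetical sketch: reject prompt lengths that would yield an empty or
# negative completion, instead of silently returning an all-padding row.
def validate_request(sequences, prompt_lengths):
    if len(sequences) != len(prompt_lengths):
        raise ValueError("sequences and prompt_lengths must have the same length.")
    for i, (seq, plen) in enumerate(zip(sequences, prompt_lengths)):
        if not 0 < plen < len(seq):
            raise ValueError(
                f"prompt_lengths[{i}]={plen} must be in (0, {len(seq)}) so that "
                "at least one completion token remains."
            )
```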
I addressed this by validating that the prompt lengths are non-negative and no longer than the sequences.
qgallouedec
left a comment
Looks good. I recommend adding a documentation page in a next PR.
Cursor Bugbot has reviewed your changes and found 2 potential issues.
There are 5 total unresolved issues (including 3 from previous reviews).
Reviewed by Cursor Bugbot for commit 864ba11.
```python
connection.send({"type": "call", "method": "generate", "kwargs": kwargs})
all_outputs = [connection.recv() for connection in connections]
all_outputs = [output for output, chunk in zip(all_outputs, chunked_prompts, strict=True) if chunk]
return list(chain.from_iterable(all_outputs))
```
Race condition on shared connections between endpoints
Low Severity
_run_prompt_logprobs accesses the shared connections list and runs in a thread pool via loop.run_in_executor. The existing generate endpoint also accesses connections directly in the event loop thread. If both endpoints are called on the same server, the generate handler can execute synchronous pipe I/O on connections while the executor thread concurrently runs _run_prompt_logprobs, causing data corruption on the multiprocessing.Pipe objects which are not thread-safe.
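One way to serialize the two endpoints, sketched with an illustrative helper and a plain threading.Lock (the real server might prefer an asyncio-level solution such as funneling all pipe I/O through one executor):

```python
import threading

# Hypothetical sketch: guard the shared worker pipes with a lock so the
# event-loop generate handler and the executor-thread logprobs path cannot
# interleave send/recv on the same (non-thread-safe) Pipe connections.
connections_lock = threading.Lock()

def run_on_connections(connections, message):
    with connections_lock:  # only one endpoint touches the pipes at a time
        for connection in connections:
            connection.send(message)
        return [connection.recv() for connection in connections]
```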
Reviewed by Cursor Bugbot for commit 864ba11.
```python
return {
    "logprobs": json_response["logprobs"],
    "logprob_token_ids": json_response["logprob_token_ids"],
}
```
Non-binary server format missing actual_logprobs causes KeyError
Low Severity
The JSON (non-binary) server response format only returns logprobs and logprob_token_ids, omitting actual_logprobs and actual_token_ids. The client's non-binary path mirrors this. However, _get_teacher_token_logprobs_from_server unconditionally accesses result["actual_logprobs"], which raises a KeyError when use_binary=False. The binary format is the default, but the two code paths are inconsistent.
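A defensive access pattern on the client side (a sketch; extract_teacher_logprobs is an illustrative name) that replaces the bare KeyError with an actionable error:

```python
# Hypothetical sketch: read the optional key with .get so the JSON
# (non-binary) format, which omits actual_logprobs, fails with a clear
# message instead of a bare KeyError deep in the trainer.
def extract_teacher_logprobs(result):
    actual = result.get("actual_logprobs")
    if actual is None:
        raise ValueError(
            "Server response lacks 'actual_logprobs'; use the binary response "
            "format or extend the JSON format to include it."
        )
    return actual
```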
Reviewed by Cursor Bugbot for commit 864ba11.
What does this PR do?
Introduces a new DistillationTrainer with the following capabilities:
Before submitting
AI writing disclosure
We welcome the use of AI tools to help with contributions. For transparency and to help us improve our review process, please indicate the level of AI involvement in this PR.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
Note
High Risk
High risk due to introducing a new, complex training path (DistillationTrainer) with custom batching/generation and new vLLM server/client APIs for token-level logprobs, which could affect training correctness and server stability/performance.
Overview
Introduces a new experimental DistillationTrainer/DistillationConfig that implements generalized JSD distillation (forward/reverse KL via beta), mixes on- and off-policy data via buffered generation across gradient accumulation, and optionally generates student rollouts with vLLM.
Adds optional external teacher support: a new vLLM server endpoint
/get_sequence_logprobs/ (with request batching and a binary base64 format) and a matching VLLMClient.get_sequence_logprobs() API, enabling distillation using teacher token-level logprobs without loading the teacher locally.
Updates the distillation documentation examples to reference experimental.distillation and renames the config field from max_new_tokens to max_completion_length in the shown setup.
Reviewed by Cursor Bugbot for commit 0f65a7b. Bugbot is set up for automated code reviews on this repo.
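For reference, the generalized JSD objective mentioned in the overview can be sketched per token distribution as follows (a hedged sketch following the GKD-style formulation; the trainer's exact masking, log-space computation, and reduction may differ):

```python
import math

# Sketch of generalized JSD between teacher distribution p and student
# distribution q, interpolated by beta: beta=0 gives forward KL(p || q),
# beta=1 gives reverse KL(q || p), and intermediate beta mixes both
# against the mixture m = beta * p + (1 - beta) * q.
def kl(p, q):
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def generalized_jsd(p_teacher, q_student, beta):
    if beta == 0.0:
        return kl(p_teacher, q_student)  # forward KL
    if beta == 1.0:
        return kl(q_student, p_teacher)  # reverse KL
    m = [beta * pi + (1 - beta) * qi for pi, qi in zip(p_teacher, q_student)]
    return beta * kl(p_teacher, m) + (1 - beta) * kl(q_student, m)
```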