Add LogitProcessor interface for pre-sampling logit transforms (#19517) #19517
kirklandsign wants to merge 1 commit into
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19517
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV
There is 1 currently active SEV. If your PR is affected, please view it below:
❌ 4 New Failures, 2 Unrelated Failures, 1 Unclassified Failure
As of commit 6ebfdf6 with merge base 9e36d62:
NEW FAILURES - The following jobs have failed:
UNCLASSIFIED FAILURE - DrCI could not classify the following job because the workflow did not run on the merge base. The failure may be pre-existing on trunk or introduced by this PR:
BROKEN TRUNK - The following jobs failed but were present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@kirklandsign has exported this pull request. If you are a Meta employee, you can view the originating Diff in D104767967.
This PR needs a
@claude review
Claude finished @kirklandsign's task in 1m 31s
Code Review: Add LogitProcessor interface for pre-sampling logit transforms
Overall this is a clean, well-scoped addition. The interface is minimal, the integration point is correct, and tests cover the key behaviors. A few observations:
Design: looks good
Potential issues
Buck targets
Nits
Summary
Clean PR, well-designed interface. The main suggestion is adding an integration test for the 3D tensor offset logic in
Pull request overview
This PR introduces a LogitProcessor interface to enable in-place mutation of model logits between the decoder forward pass and sampling, allowing features like constrained decoding and logit biasing without changing the core generation loop.
Changes:
- Adds `LogitProcessor` interface under `extension/llm/sampler/`.
- Extends `TextTokenGenerator` with a configurable processor chain applied pre-sampling.
- Updates Buck targets to export the new header and adds a unit test for the interface.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| extension/llm/sampler/test/test_logit_processor.cpp | Adds unit tests validating basic LogitProcessor behavior and ordering semantics. |
| extension/llm/sampler/test/targets.bzl | Adds a Buck test target for the new logit processor tests. |
| extension/llm/sampler/targets.bzl | Exports logit_processor.h from the sampler library target. |
| extension/llm/sampler/logit_processor.h | Introduces the LogitProcessor pure virtual interface. |
| extension/llm/runner/text_token_generator.h | Adds processor registration APIs and applies processor chain to logits before sampling. |
| extension/llm/runner/targets.bzl | Adds runner dependency on the sampler target (for LogitProcessor). |
    auto* logits = logits_tensor.mutable_data_ptr<float>();
    const auto vocab_size = logits_tensor.size(logits_tensor.dim() - 1);
    if (logits_tensor.dim() == 3) {
      const auto num_tokens = logits_tensor.size(1);
      logits += (num_tokens - 1) * vocab_size;
    }
    for (auto& processor : logit_processors_) {
      processor->process(logits, static_cast<int32_t>(vocab_size));
    }
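
The offset step in the snippet above can be sketched in isolation. This is a minimal illustration, not the PR's code: `last_position_logits` is a hypothetical helper, the shape is passed as a plain vector instead of a tensor, and it assumes a contiguous float buffer with batch size 1, as the pointer arithmetic in the snippet implies.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of the PR): returns a pointer to the logits
// of the last sequence position.
// `sizes` is the tensor shape, either [batch, vocab] or
// [batch, num_tokens, vocab]. Assumes contiguous storage and batch == 1.
inline float* last_position_logits(
    float* data,
    const std::vector<int64_t>& sizes) {
  assert(sizes.size() == 2 || sizes.size() == 3);
  const int64_t vocab_size = sizes.back();
  if (sizes.size() == 3) {
    // Skip the logits of every position except the last one.
    const int64_t num_tokens = sizes[1];
    data += (num_tokens - 1) * vocab_size;
  }
  return data;
}
```

A check of this shape is roughly what an integration test for the 3D case would assert: the processor must see the final position's row, not the first.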
    ET_CHECK_OR_RETURN_ERROR(
        logits_tensor.scalar_type() == ::executorch::aten::ScalarType::Float,
        InvalidArgument,
        "LogitProcessor chain only supports Float logits; got dtype %d",
        static_cast<int>(logits_tensor.scalar_type()));
    if (!logit_processors_.empty()) {
      ET_CHECK_OK_OR_RETURN_ERROR(apply_logit_processors_(logits_tensor));
    }
     * @param vocab_size Number of logits in the buffer (size of the model's
     * output vocabulary for the current step).
     */
    virtual void process(float* logits, int32_t vocab_size) = 0;
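
For illustration, a processor written against this signature might look like the sketch below. The `LogitProcessor` class here is a local stand-in that copies only the `process(float*, int32_t)` signature shown above (the virtual destructor is an assumption), and `BanTokensProcessor` is a hypothetical example that does not appear in the PR.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <utility>
#include <vector>

// Local stand-in for the interface; only the process() signature is taken
// from the PR. The real header may differ in other details.
class LogitProcessor {
 public:
  virtual ~LogitProcessor() = default;
  virtual void process(float* logits, int32_t vocab_size) = 0;
};

// Hypothetical processor: forbids a fixed set of token ids by setting their
// logits to negative infinity, so a sampler can never pick them.
class BanTokensProcessor : public LogitProcessor {
 public:
  explicit BanTokensProcessor(std::vector<int32_t> banned)
      : banned_(std::move(banned)) {}

  void process(float* logits, int32_t vocab_size) override {
    for (int32_t id : banned_) {
      // Ignore ids outside the vocabulary instead of writing out of bounds.
      if (id >= 0 && id < vocab_size) {
        logits[id] = -std::numeric_limits<float>::infinity();
      }
    }
  }

 private:
  std::vector<int32_t> banned_;
};
```

Because the interface mutates the buffer in place, the processor needs no allocation per step; the bounds check matters because `vocab_size` is the only size information it receives.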
Summary: Introduces a `LogitProcessor` abstract interface that allows callers to mutate logits in place between the model forward pass and the sampler. This enables grammar-constrained decoding, logit biasing, repetition penalties, and similar pre-sampling transforms without modifying the core generation loop.

Changes:
- `LogitProcessor` (new): pure virtual interface with a single `process(float*, int32_t)` method, placed in `extension/llm/sampler/`.
- `TextTokenGenerator`: gains `add_logit_processor()`, `clear_logit_processors()`, and `num_logit_processors()`. The processor chain runs after the model step and before `logits_to_token()`. When no processors are registered, behavior is identical to before.
- `apply_logit_processors_()`: private helper that validates Float dtype, advances to the last-position logits for 3D tensors (mirroring `logits_to_token`), and invokes each processor in order.
- Buck: `logit_processor.h` exported from the sampler target; `text_token_generator` gains a direct dep on sampler; test target added.

Processors must be configured before calling `generate()`; concurrent modification during generation is not safe.

Differential Revision: D104767967
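
Since processors run in registration order, the order can change the result. The sketch below shows that with a small stand-in chain. `ProcessorChain`, `AddBias`, and `Scale` are hypothetical names; only the method names `add_logit_processor()`, `clear_logit_processors()`, and `num_logit_processors()` come from the summary, and their parameter and return types (including the `std::shared_ptr` ownership) are assumptions.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

// Local stand-in for the interface described in the summary.
class LogitProcessor {
 public:
  virtual ~LogitProcessor() = default;
  virtual void process(float* logits, int32_t vocab_size) = 0;
};

// Hypothetical processor: adds a constant bias to one token's logit.
class AddBias : public LogitProcessor {
 public:
  AddBias(int32_t token, float bias) : token_(token), bias_(bias) {}
  void process(float* logits, int32_t vocab_size) override {
    if (token_ >= 0 && token_ < vocab_size) {
      logits[token_] += bias_;
    }
  }

 private:
  int32_t token_;
  float bias_;
};

// Hypothetical processor: multiplies every logit by a constant factor.
class Scale : public LogitProcessor {
 public:
  explicit Scale(float factor) : factor_(factor) {}
  void process(float* logits, int32_t vocab_size) override {
    for (int32_t i = 0; i < vocab_size; ++i) {
      logits[i] *= factor_;
    }
  }

 private:
  float factor_;
};

// Stand-in for the registration methods named in the summary; signatures and
// ownership model are assumed, not taken from the PR.
class ProcessorChain {
 public:
  void add_logit_processor(std::shared_ptr<LogitProcessor> p) {
    if (p) {
      processors_.push_back(std::move(p));
    }
  }
  void clear_logit_processors() { processors_.clear(); }
  std::size_t num_logit_processors() const { return processors_.size(); }

  // Invokes each processor in registration order on the same buffer.
  void apply(float* logits, int32_t vocab_size) {
    for (auto& p : processors_) {
      p->process(logits, vocab_size);
    }
  }

 private:
  std::vector<std::shared_ptr<LogitProcessor>> processors_;
};
```

With logits `{1, 2}`, registering `AddBias(0, 1)` then `Scale(2)` turns the first entry into `(1 + 1) * 2`, while the reverse order gives `1 * 2 + 1`. An empty chain leaves the buffer untouched, matching the "identical to before" guarantee for callers that register nothing.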
3b3862f to 6ebfdf6