diff --git a/extension/llm/runner/targets.bzl b/extension/llm/runner/targets.bzl
index 0d4ed99308d..d3e12266adc 100644
--- a/extension/llm/runner/targets.bzl
+++ b/extension/llm/runner/targets.bzl
@@ -68,6 +68,7 @@ def define_common_targets():
             visibility = ["PUBLIC"],
             exported_deps = [
                 ":text_decoder_runner" + aten_suffix,
+                "//executorch/extension/llm/sampler:sampler" + aten_suffix,
                 "//pytorch/tokenizers:headers",
                 "//executorch/extension/module:module" + aten_suffix,
                 "//executorch/extension/tensor:tensor" + aten_suffix,
diff --git a/extension/llm/runner/text_token_generator.h b/extension/llm/runner/text_token_generator.h
index 7e7fbbf1341..dd5ce9e0d89 100644
--- a/extension/llm/runner/text_token_generator.h
+++ b/extension/llm/runner/text_token_generator.h
@@ -10,9 +10,13 @@
 #pragma once
 
 #include <cstdint>
+#include <limits>
+#include <memory>
+#include <vector>
 
 #include <executorch/extension/llm/runner/stats.h>
 #include <executorch/extension/llm/runner/text_decoder_runner.h>
+#include <executorch/extension/llm/sampler/logit_processor.h>
 #include <executorch/runtime/platform/compiler.h>
 #include <pytorch/tokenizers/tokenizer.h>
 
@@ -38,6 +42,21 @@ class ET_EXPERIMENTAL TextTokenGenerator {
     ignore_eos_ = ignore_eos;
   }
 
+  // Not safe to call while generate() is running on another thread.
+  void add_logit_processor(std::shared_ptr<LogitProcessor> processor) {
+    if (processor) {
+      logit_processors_.push_back(std::move(processor));
+    }
+  }
+
+  void clear_logit_processors() {
+    logit_processors_.clear();
+  }
+
+  size_t num_logit_processors() const {
+    return logit_processors_.size();
+  }
+
   virtual ~TextTokenGenerator() = default;
 
   /**
@@ -109,6 +128,10 @@
 
     prev_token = cur_token;
 
+    if (!logit_processors_.empty()) {
+      ET_CHECK_OK_OR_RETURN_ERROR(apply_logit_processors_(logits_tensor));
+    }
+
     stats_->on_sampling_begin();
     cur_token =
         text_decoder_runner_->logits_to_token(logits_tensor, temperature);
@@ -177,6 +200,40 @@
   }
 
  private:
+  inline ::executorch::runtime::Error apply_logit_processors_(
+      ::executorch::aten::Tensor& logits_tensor) {
+    ET_CHECK_OR_RETURN_ERROR(
+        logits_tensor.dim() >= 2,
+        InvalidArgument,
+        "LogitProcessor expects logits with dim >= 2, got %d",
+        static_cast<int>(logits_tensor.dim()));
+    ET_CHECK_OR_RETURN_ERROR(
+        logits_tensor.scalar_type() == ::executorch::aten::ScalarType::Float,
+        InvalidArgument,
+        "LogitProcessor chain only supports Float logits; got dtype %d",
+        static_cast<int>(logits_tensor.scalar_type()));
+
+    auto* logits = logits_tensor.mutable_data_ptr<float>();
+    const ssize_t vocab_size = logits_tensor.size(logits_tensor.dim() - 1);
+    ET_CHECK_OR_RETURN_ERROR(
+        vocab_size > 0 && vocab_size <= std::numeric_limits<int32_t>::max(),
+        InvalidArgument,
+        "vocab_size %zd out of range for LogitProcessor",
+        vocab_size);
+    if (logits_tensor.dim() == 3) {
+      const ssize_t num_tokens = logits_tensor.size(1);
+      ET_CHECK_OR_RETURN_ERROR(
+          num_tokens > 0,
+          InvalidArgument,
+          "LogitProcessor expects non-empty sequence dimension");
+      logits += (num_tokens - 1) * vocab_size;
+    }
+    for (auto& processor : logit_processors_) {
+      processor->process(logits, static_cast<int32_t>(vocab_size));
+    }
+    return ::executorch::runtime::Error::Ok;
+  }
+
   /**
    * Note: TextTokenGenerator does not own the tokenizer_ and
    * text_decoder_runner_. The lifecycle of these objects should be managed
@@ -189,6 +246,8 @@ class ET_EXPERIMENTAL TextTokenGenerator {
   bool use_kv_cache_;
   bool ignore_eos_ = false;
 
+  std::vector<std::shared_ptr<LogitProcessor>> logit_processors_;
+
   // state machine
   std::atomic<bool> should_stop_{false};
diff --git a/extension/llm/sampler/logit_processor.h b/extension/llm/sampler/logit_processor.h
new file mode 100644
index 00000000000..6313bc14b5c
--- /dev/null
+++ b/extension/llm/sampler/logit_processor.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <cstdint>
+
+#include <executorch/runtime/platform/compiler.h>
+
+namespace executorch {
+namespace extension {
+namespace llm {
+
+/**
+ * Interface for in-place logit transformations applied between the model's
+ * forward pass and the sampler. Examples include:
+ * - Grammar / constrained-decoding masks (set disallowed tokens to -inf)
+ * - Logit bias (additive per-token bias)
+ * - Custom debug instrumentation
+ *
+ * A `TextTokenGenerator` may be configured with a chain of processors. They
+ * are invoked in order on every decoding step, before the sampler sees the
+ * logits. Each processor mutates the buffer in place; later processors
+ * observe earlier processors' modifications.
+ *
+ * Implementations must be cheap to call repeatedly; `process()` runs on the
+ * critical path of every generated token.
+ */
+class ET_EXPERIMENTAL LogitProcessor {
+ public:
+  virtual ~LogitProcessor() = default;
+
+  /**
+   * Modify logits in place for the current decoding step.
+   *
+   * @param logits Mutable pointer to the logits buffer for the current
+   *     step. Must contain at least `vocab_size` elements.
+   * @param vocab_size Number of logits in the buffer (size of the model's
+   *     output vocabulary for the current step).
+   */
+  virtual void process(float* logits, int32_t vocab_size) = 0;
+};
+
+} // namespace llm
+} // namespace extension
+} // namespace executorch
diff --git a/extension/llm/sampler/targets.bzl b/extension/llm/sampler/targets.bzl
index 42551e248e5..94a62745d6a 100644
--- a/extension/llm/sampler/targets.bzl
+++ b/extension/llm/sampler/targets.bzl
@@ -7,6 +7,7 @@ def define_common_targets():
     runtime.cxx_library(
         name = "sampler" + aten_suffix,
         exported_headers = [
+            "logit_processor.h",
             "sampler.h",
             "util.h",
         ],
diff --git a/extension/llm/sampler/test/targets.bzl b/extension/llm/sampler/test/targets.bzl
index 83b3d31e4cb..05138bea0d8 100644
--- a/extension/llm/sampler/test/targets.bzl
+++ b/extension/llm/sampler/test/targets.bzl
@@ -22,3 +22,13 @@ def define_common_targets():
             "//caffe2:torch-cpp",
         ],
     )
+
+    runtime.cxx_test(
+        name = "test_logit_processor",
+        srcs = [
+            "test_logit_processor.cpp",
+        ],
+        deps = [
+            "//executorch/extension/llm/sampler:sampler",
+        ],
+    )
diff --git a/extension/llm/sampler/test/test_logit_processor.cpp b/extension/llm/sampler/test/test_logit_processor.cpp
new file mode 100644
index 00000000000..786c1a76d20
--- /dev/null
+++ b/extension/llm/sampler/test/test_logit_processor.cpp
@@ -0,0 +1,130 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/extension/llm/sampler/logit_processor.h>
+
+#include <limits>
+#include <memory>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+using ::executorch::extension::llm::LogitProcessor;
+
+namespace {
+
+// Adds a fixed bias to every logit slot. Records how many times it was
+// invoked so tests can verify chain ordering.
+class AddBiasProcessor : public LogitProcessor {
+ public:
+  explicit AddBiasProcessor(float bias) : bias_(bias) {}
+
+  void process(float* logits, int32_t vocab_size) override {
+    ++call_count_;
+    for (int32_t i = 0; i < vocab_size; ++i) {
+      logits[i] += bias_;
+    }
+  }
+
+  int call_count() const {
+    return call_count_;
+  }
+
+ private:
+  float bias_;
+  int call_count_ = 0;
+};
+
+class MultiplyProcessor : public LogitProcessor {
+ public:
+  explicit MultiplyProcessor(float factor) : factor_(factor) {}
+
+  void process(float* logits, int32_t vocab_size) override {
+    for (int32_t i = 0; i < vocab_size; ++i) {
+      logits[i] *= factor_;
+    }
+  }
+
+ private:
+  float factor_;
+};
+
+class MaskTokenProcessor : public LogitProcessor {
+ public:
+  explicit MaskTokenProcessor(int32_t banned_token)
+      : banned_token_(banned_token) {}
+
+  void process(float* logits, int32_t vocab_size) override {
+    if (banned_token_ >= 0 && banned_token_ < vocab_size) {
+      logits[banned_token_] = -std::numeric_limits<float>::infinity();
+    }
+  }
+
+ private:
+  int32_t banned_token_;
+};
+
+} // namespace
+
+// A single processor sees the buffer and may mutate it in place.
+TEST(LogitProcessorTest, SingleProcessorMutatesLogits) {
+  std::vector<float> logits = {1.0f, 2.0f, 3.0f, 4.0f};
+  AddBiasProcessor bias{10.0f};
+
+  bias.process(logits.data(), static_cast<int32_t>(logits.size()));
+
+  const std::vector<float> expected = {11.0f, 12.0f, 13.0f, 14.0f};
+  EXPECT_EQ(logits, expected);
+  EXPECT_EQ(bias.call_count(), 1);
+}
+
+// Multiply(×2) then Add(+1) gives (x*2)+1, which differs from
+// Add(+1) then Multiply(×2) = (x+1)*2. Non-commutative operations
+// verify that processors run in registration order.
+TEST(LogitProcessorTest, ProcessorChainAppliesInOrder) {
+  std::vector<float> logits = {1.0f, 2.0f, 3.0f, 4.0f};
+
+  std::vector<std::shared_ptr<LogitProcessor>> chain;
+  chain.push_back(std::make_shared<MultiplyProcessor>(2.0f));
+  chain.push_back(std::make_shared<AddBiasProcessor>(1.0f));
+
+  for (auto& p : chain) {
+    // NOLINTNEXTLINE(facebook-hte-Deprecated)
+    p->process(logits.data(), static_cast<int32_t>(logits.size()));
+  }
+
+  // (x*2)+1, NOT (x+1)*2
+  const std::vector<float> expected = {3.0f, 5.0f, 7.0f, 9.0f};
+  EXPECT_EQ(logits, expected);
+}
+
+// A masking processor sets a specific token slot to -inf. This is
+// the pattern grammar processors will follow.
+TEST(LogitProcessorTest, MaskTokenDrivesArgmaxAway) {
+  std::vector<float> logits = {0.1f, 0.2f, 0.99f, 0.4f}; // argmax = 2
+
+  MaskTokenProcessor mask{/*banned_token=*/2};
+  mask.process(logits.data(), static_cast<int32_t>(logits.size()));
+
+  const std::vector<float> expected = {
+      0.1f, 0.2f, -std::numeric_limits<float>::infinity(), 0.4f};
+  EXPECT_EQ(logits, expected);
+}
+
+TEST(LogitProcessorTest, MaskTokenOutOfRangeIsNoOp) {
+  std::vector<float> logits = {1.0f, 2.0f, 3.0f};
+  const std::vector<float> snapshot = logits;
+
+  MaskTokenProcessor mask_over{/*banned_token=*/99};
+  mask_over.process(logits.data(), static_cast<int32_t>(logits.size()));
+  EXPECT_EQ(logits, snapshot);
+
+  MaskTokenProcessor mask_neg{/*banned_token=*/-1};
+  mask_neg.process(logits.data(), static_cast<int32_t>(logits.size()));
+  EXPECT_EQ(logits, snapshot);
+}
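
Usage note (illustrative, not part of the diff): the sketch below shows how a caller might wire a processor into a `TextTokenGenerator`, using only the accessors this diff adds (`add_logit_processor`, `clear_logit_processors`). The `configure` helper and `BanTokenProcessor` class are hypothetical names for illustration; `generator` is assumed to be fully constructed elsewhere.

#include <executorch/extension/llm/runner/text_token_generator.h>
#include <executorch/extension/llm/sampler/logit_processor.h>

#include <cstdint>
#include <limits>
#include <memory>

namespace {

// Minimal mask processor; same pattern as MaskTokenProcessor in the test.
class BanTokenProcessor : public ::executorch::extension::llm::LogitProcessor {
 public:
  explicit BanTokenProcessor(int32_t token) : token_(token) {}

  void process(float* logits, int32_t vocab_size) override {
    if (token_ >= 0 && token_ < vocab_size) {
      logits[token_] = -std::numeric_limits<float>::infinity();
    }
  }

 private:
  int32_t token_;
};

} // namespace

// Hypothetical setup function; `generator` is assumed to be a fully
// constructed TextTokenGenerator owned elsewhere in the application.
void configure(::executorch::extension::llm::TextTokenGenerator& generator) {
  // Processors run in registration order on every decoding step, after the
  // forward pass and before the sampler sees the logits.
  generator.add_logit_processor(std::make_shared<BanTokenProcessor>(/*token=*/42));
  // ... run generate() as usual; drop the chain once constrained decoding
  // is no longer needed (not thread-safe while generate() is running):
  generator.clear_logit_processors();
}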
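
A second hypothetical sketch, covering the "logit bias" use case named in the interface doc: an additive per-token bias built from a map. `LogitBiasProcessor` and its constructor are inventions for illustration, not part of this diff; out-of-range token ids are ignored, mirroring the no-op behavior tested for MaskTokenProcessor.

#include <executorch/extension/llm/sampler/logit_processor.h>

#include <cstdint>
#include <unordered_map>
#include <utility>

namespace {

// Hypothetical example: applies an additive bias to selected token ids.
class LogitBiasProcessor
    : public ::executorch::extension::llm::LogitProcessor {
 public:
  explicit LogitBiasProcessor(std::unordered_map<int32_t, float> bias)
      : bias_(std::move(bias)) {}

  void process(float* logits, int32_t vocab_size) override {
    for (const auto& [token, bias] : bias_) {
      // Skip ids outside the current vocabulary rather than failing.
      if (token >= 0 && token < vocab_size) {
        logits[token] += bias;
      }
    }
  }

 private:
  std::unordered_map<int32_t, float> bias_;
};

} // namespace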