diff --git a/apps/llm/app/llm/index.tsx b/apps/llm/app/llm/index.tsx index 901b74de43..e8090f7660 100644 --- a/apps/llm/app/llm/index.tsx +++ b/apps/llm/app/llm/index.tsx @@ -11,7 +11,7 @@ import { View, } from 'react-native'; import SendIcon from '../../assets/icons/send_icon.svg'; -import { useLLM, LLAMA3_2_1B_SPINQUANT } from 'react-native-executorch'; +import { useLLM, QWEN3_0_6B_QUANTIZED } from 'react-native-executorch'; import { ModelPicker } from '../../components/ModelPicker'; import { LLM_MODELS, LLMModelSources } from '../../components/llmModels'; import PauseIcon from '../../assets/icons/pause_icon.svg'; @@ -42,9 +42,8 @@ function LLMScreen() { const { bottom } = useSafeAreaInsets(); const [isTextInputFocused, setIsTextInputFocused] = useState(false); const [userInput, setUserInput] = useState(''); - const [selectedModel, setSelectedModel] = useState( - LLAMA3_2_1B_SPINQUANT - ); + const [selectedModel, setSelectedModel] = + useState(QWEN3_0_6B_QUANTIZED); const textInputRef = useRef(null); const { setGlobalGenerating } = useContext(GeneratingContext); @@ -76,6 +75,7 @@ function LLMScreen() { } }; + console.log(llm.messageHistory) return !llm.isReady && !llm.error ? ( (null); const { setGlobalGenerating } = useContext(GeneratingContext); - // Added error state + const [audioBuffer, setAudioBuffer] = useState(null); + const [audioLabel, setAudioLabel] = useState(null); + const [audioUrl, setAudioUrl] = useState(''); + const [isFetchingAudio, setIsFetchingAudio] = useState(false); + const [isRecording, setIsRecording] = useState(false); + const [hasMicPermission, setHasMicPermission] = useState(false); + const recorder = useRef(new AudioRecorder()); + const recordChunks = useRef([]); + const [error, setError] = useState(null); const vlm = useLLM({ @@ -68,6 +82,87 @@ function MultimodalLLMScreen() { if (vlm.error) setError(String(vlm.error)); }, [vlm.error]); + useEffect(() => { + AudioManager.setAudioSessionOptions({ + iosCategory: 'playAndRecord', + iosMode: 'spokenAudio', + iosOptions: ['allowBluetoothHFP', 'defaultToSpeaker'], + }); + (async () => { + const status = await AudioManager.requestRecordingPermissions(); + setHasMicPermission(status === 'Granted'); + })(); + }, []); + + const loadAudioFromUrl = async () => { + const url = audioUrl.trim(); + if (!url) return; + setIsFetchingAudio(true); + try { + const ctx = new AudioContext({ sampleRate: 16000 }); + const decoded = await ctx.decodeAudioData(url); + const pcm = decoded.getChannelData(0); + const name = url.split('/').pop() || 'audio'; + setAudioBuffer(pcm); + setAudioLabel(`${name} ยท ${(pcm.length / 16000).toFixed(1)}s`); + } catch (e) { + setError(e instanceof Error ? e.message : String(e)); + } finally { + setIsFetchingAudio(false); + } + }; + + const startRecording = async () => { + if (!hasMicPermission) { + setError('Microphone permission denied. Please enable it in Settings.'); + return; + } + recordChunks.current = []; + const sampleRate = 16000; + recorder.current.onAudioReady( + { sampleRate, bufferLength: 0.1 * sampleRate, channelCount: 1 }, + ({ buffer }) => { + recordChunks.current.push(new Float32Array(buffer.getChannelData(0))); + } + ); + try { + const ok = await AudioManager.setAudioSessionActivity(true); + if (!ok) { + setError('Cannot start audio session'); + return; + } + const result = recorder.current.start(); + if (result.status === 'error') { + setError(`Recording problems: ${result.message}`); + return; + } + setIsRecording(true); + } catch (e) { + setError(e instanceof Error ? e.message : String(e)); + } + }; + + const stopRecording = () => { + recorder.current.stop(); + setIsRecording(false); + const total = recordChunks.current.reduce((n, c) => n + c.length, 0); + if (total === 0) return; + const pcm = new Float32Array(total); + let off = 0; + for (const c of recordChunks.current) { + pcm.set(c, off); + off += c.length; + } + recordChunks.current = []; + setAudioBuffer(pcm); + setAudioLabel(`Recording ยท ${(pcm.length / 16000).toFixed(1)}s`); + }; + + const clearAudio = () => { + setAudioBuffer(null); + setAudioLabel(null); + }; + const pickImage = async () => { try { const result = await launchImageLibrary({ mediaType: 'photo' }); @@ -81,19 +176,27 @@ function MultimodalLLMScreen() { }; const sendMessage = async () => { - if (!userInput.trim() || vlm.isGenerating) return; + if (!(imageUri || audioBuffer || userInput.trim()) || vlm.isGenerating) + return; onMessageSend(); const text = userInput.trim(); setUserInput(''); textInputRef.current?.clear(); Keyboard.dismiss(); const currentImageUri = imageUri; + const currentAudio = audioBuffer; setImageUri(null); + setAudioBuffer(null); + setAudioLabel(null); try { - await vlm.sendMessage( - text, - currentImageUri ? { imagePath: currentImageUri } : undefined - ); + const media = + currentImageUri || currentAudio + ? { + ...(currentImageUri ? { imagePath: currentImageUri } : {}), + ...(currentAudio ? { audioBuffer: currentAudio } : {}), + } + : undefined; + await vlm.sendMessage(text, media); } catch (e) { // Updated to set UI error instead of just console.error setError(e instanceof Error ? e.message : String(e)); @@ -159,6 +262,42 @@ function MultimodalLLMScreen() { )} + {/* Audio URL input */} + + + + + {isFetchingAudio ? 'โ€ฆ' : 'Load'} + + + + + {/* Audio attachment strip */} + {audioLabel && ( + + ๐ŸŽต {audioLabel} + + โœ• + + + )} + ๐Ÿ“ท + {/* Mic record / stop button */} + + + {isRecording ? 'โน๏ธ' : '๐ŸŽค'} + + + - {userInput.trim() && !vlm.isGenerating && ( - - - - )} + {(imageUri || audioBuffer || userInput.trim()) && + !vlm.isGenerating && ( + + + + )} {vlm.isGenerating && ( [] = [ { label: 'Qwen3 0.6B', value: QWEN3_0_6B }, { label: 'Qwen3 0.6B Quantized', value: QWEN3_0_6B_QUANTIZED }, { label: 'Qwen3 1.7B', value: QWEN3_1_7B }, + { label: 'Gemma4 e2b Quantized', value: GEMMA4_E2B_QUANTIZED }, { label: 'Qwen3 1.7B Quantized', value: QWEN3_1_7B_QUANTIZED }, { label: 'Qwen3 4B', value: QWEN3_4B }, { label: 'Qwen3 4B Quantized', value: QWEN3_4B_QUANTIZED }, diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h index 077d426c8f..42c9eca8ab 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h @@ -223,6 +223,22 @@ inline std::vector getValue>(const jsi::Value &val, return getArrayAsVector(val, runtime); } +template <> +inline std::vector> +getValue>>(const jsi::Value &val, + jsi::Runtime &runtime) { + jsi::Array array = val.asObject(runtime).asArray(runtime); + const size_t length = array.size(runtime); + std::vector> result; + result.reserve(length); + for (size_t i = 0; i < length; ++i) { + auto span = + getTypedArrayAsSpan(array.getValueAtIndex(runtime, i), runtime); + result.emplace_back(span.begin(), span.end()); + } + return result; +} + template <> inline std::vector getValue>(const jsi::Value &val, jsi::Runtime &runtime) { diff --git a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp index 7e0fa4b26e..331624270a 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp @@ -4,8 +4,8 @@ #include #include #include -#include #include +#include #include #include #include @@ -21,7 +21,6 @@ LLM::LLM(const std::string &modelSource, const std::string &tokenizerSource, std::vector capabilities, std::shared_ptr callInvoker) : BaseModel(modelSource, callInvoker, Module::LoadMode::Mmap) { - if (capabilities.empty()) { runner_ = std::make_unique(std::move(module_), tokenizerSource); @@ -31,6 +30,9 @@ LLM::LLM(const std::string &modelSource, const std::string &tokenizerSource, if (cap == "vision") { encoders[llm::MultimodalType::Image] = std::make_unique(*module_); + } else if (cap == "audio") { + encoders[llm::MultimodalType::Audio] = + std::make_unique(*module_); } } runner_ = std::make_unique( @@ -74,63 +76,68 @@ std::string LLM::generate(std::string input, return output; } -std::string LLM::generateMultimodal(std::string prompt, - std::vector imagePaths, - std::string imageToken, - std::shared_ptr callback) { +std::string LLM::generateMultimodal( + std::string prompt, std::shared_ptr callback, + std::vector imagePaths, std::string imageToken, + std::vector> audioWaveforms, std::string audioToken) { if (!runner_ || !runner_->is_loaded()) { throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded, "Runner is not loaded"); } if (!runner_->is_multimodal()) { - throw RnExecutorchError( - RnExecutorchErrorCode::InvalidUserInput, - "This model does not support multimodal input. Use generate(prompt, " - "callback) for text-only generation."); + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, + "This model does not support multimodal input."); } - if (imageToken.empty()) { + if (imageToken.empty() && audioToken.empty()) { throw RnExecutorchError( RnExecutorchErrorCode::InvalidUserInput, - "imageToken must not be empty. Pass the model's image token (e.g. " - "from tokenizer_config.json)."); + "At least one of imageToken/audioToken must be non-empty"); } - const size_t kImageTokenLen = imageToken.size(); - + // Scan the prompt once, splitting at the earliest placeholder at each step + // so that image/audio placeholders can be freely interleaved in the prompt. std::vector inputs; - size_t imageIdx = 0; - size_t searchPos = 0; - - while (true) { - size_t found = prompt.find(imageToken, searchPos); - if (found == std::string::npos) { - if (searchPos < prompt.size()) { - inputs.push_back(llm::make_text_input(prompt.substr(searchPos))); - } + size_t imageIdx = 0, audioIdx = 0, pos = 0; + constexpr int32_t kAudioSampleRate = 16000; + while (pos < prompt.size()) { + size_t imgAt = + imageToken.empty() ? std::string::npos : prompt.find(imageToken, pos); + size_t audAt = + audioToken.empty() ? std::string::npos : prompt.find(audioToken, pos); + if (imgAt == std::string::npos && audAt == std::string::npos) { + inputs.push_back(llm::make_text_input(prompt.substr(pos))); break; } - // Text segment before this placeholder - if (found > searchPos) { - inputs.push_back( - llm::make_text_input(prompt.substr(searchPos, found - searchPos))); + const bool imageFirst = imgAt != std::string::npos && + (audAt == std::string::npos || imgAt < audAt); + size_t at = imageFirst ? imgAt : audAt; + if (at > pos) { + inputs.push_back(llm::make_text_input(prompt.substr(pos, at - pos))); } - // Image at this position - if (imageIdx >= imagePaths.size()) { - throw RnExecutorchError( - RnExecutorchErrorCode::InvalidUserInput, - "More '" + imageToken + - "' placeholders in prompt than image paths provided"); + if (imageFirst) { + if (imageIdx >= imagePaths.size()) { + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, + "More '" + imageToken + + "' placeholders than image paths"); + } + inputs.push_back(llm::make_image_input(imagePaths[imageIdx++])); + pos = at + imageToken.size(); + } else { + if (audioIdx >= audioWaveforms.size()) { + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, + "More '" + audioToken + + "' placeholders than audio waveforms"); + } + inputs.push_back(llm::make_audio_input( + std::move(audioWaveforms[audioIdx++]), kAudioSampleRate)); + pos = at + audioToken.size(); } - inputs.push_back(llm::make_image_input(imagePaths[imageIdx++])); - searchPos = found + kImageTokenLen; } - - if (imageIdx < imagePaths.size()) { - throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, - "More image paths provided than '" + imageToken + - "' placeholders in prompt"); + if (imageIdx < imagePaths.size() || audioIdx < audioWaveforms.size()) { + throw RnExecutorchError( + RnExecutorchErrorCode::InvalidUserInput, + "More image/audio paths provided than placeholders in prompt"); } - if (inputs.empty()) { throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, "No inputs to generate from"); @@ -150,7 +157,6 @@ std::string LLM::generateMultimodal(std::string prompt, if (error != Error::Ok) { throw RnExecutorchError(error, "Failed to generate multimodal response"); } - return output; } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h index 222b5bc62f..bf1c44313d 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h @@ -22,10 +22,16 @@ class LLM : public BaseModel { std::string generate(std::string prompt, std::shared_ptr callback); - std::string generateMultimodal(std::string prompt, - std::vector imagePaths, - std::string imageToken, - std::shared_ptr callback); + // Audio variant: `audioWaveforms` is a parallel vector of fp32 mono 16 kHz + // PCM buffers (decoded upstream, same contract as SpeechToText::transcribe). + // The prompt is scanned for `imageToken` and/or `audioToken` placeholders; + // each placeholder consumes the next entry from its respective vector in + // order. Either set of paths/waveforms/token may be empty. + std::string generateMultimodal( + std::string prompt, std::shared_ptr callback, + std::vector imagePaths = {}, std::string imageToken = "", + std::vector> audioWaveforms = {}, + std::string audioToken = ""); void interrupt(); void reset(); diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt b/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt index 06a30a13f7..f97527cb53 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt +++ b/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt @@ -292,6 +292,7 @@ add_rn_test(LLMTests integration/LLMTest.cpp ${COMMON_DIR}/runner/sampler.cpp ${COMMON_DIR}/runner/arange_util.cpp ${COMMON_DIR}/runner/encoders/vision_encoder.cpp + ${COMMON_DIR}/runner/encoders/audio_encoder.cpp ${IMAGE_UTILS_SOURCES} LIBS tokenizers_deps opencv_deps ) diff --git a/packages/react-native-executorch/common/runner/base_llm_runner.cpp b/packages/react-native-executorch/common/runner/base_llm_runner.cpp index a021040807..375341ca61 100644 --- a/packages/react-native-executorch/common/runner/base_llm_runner.cpp +++ b/packages/react-native-executorch/common/runner/base_llm_runner.cpp @@ -69,6 +69,7 @@ Error BaseLLMRunner::load() { eos_ids_->emplace(static_cast(eos_id.toScalar().to())); } } + if (eos_ids_->empty()) { throw rnexecutorch::RnExecutorchError( rnexecutorch::RnExecutorchErrorCode::InvalidModelOutput, @@ -149,6 +150,8 @@ void BaseLLMRunner::set_repetition_penalty(float repetition_penalty) noexcept { config_.repetition_penalty = repetition_penalty; } +void BaseLLMRunner::set_topk(int32_t topk) noexcept { config_.topk = topk; } + void BaseLLMRunner::set_count_interval(size_t count_interval) { config_.output_token_batch_size = count_interval; } diff --git a/packages/react-native-executorch/common/runner/base_llm_runner.h b/packages/react-native-executorch/common/runner/base_llm_runner.h index 9710f5ae70..82de49bea3 100644 --- a/packages/react-native-executorch/common/runner/base_llm_runner.h +++ b/packages/react-native-executorch/common/runner/base_llm_runner.h @@ -55,6 +55,7 @@ class BaseLLMRunner { void set_topp(float topp) noexcept; void set_min_p(float min_p) noexcept; void set_repetition_penalty(float repetition_penalty) noexcept; + void set_topk(int32_t topk) noexcept; void set_count_interval(size_t count_interval); void set_time_interval(size_t time_interval); diff --git a/packages/react-native-executorch/common/runner/constants.h b/packages/react-native-executorch/common/runner/constants.h index f1fee23471..ab4a125962 100644 --- a/packages/react-native-executorch/common/runner/constants.h +++ b/packages/react-native-executorch/common/runner/constants.h @@ -17,13 +17,22 @@ inline constexpr auto kMaxSeqLen = "get_max_seq_len"; inline constexpr auto kMaxContextLen = "get_max_context_len"; inline constexpr auto kVocabSize = "get_vocab_size"; inline constexpr auto kUseKVCache = "use_kv_cache"; +// PLE models only: token id that marks image placeholder slots in input_ids. +// token_embedding run on this id produces the per-layer PLE signal for image +// positions; the inputs_embeds output for those positions is discarded (the +// vision encoder output replaces it). +inline constexpr auto kImagePlaceholderId = "image_placeholder_id"; // Multimodal method name conventions inline constexpr auto kVisionEncoderMethod = "vision_encoder"; inline constexpr auto kAudioEncoderMethod = "audio_encoder"; inline constexpr auto kTokenEmbeddingMethod = "token_embedding"; inline constexpr auto kTextModelMethod = "text_decoder"; - +// Absolute ceiling on prefill length (in tokens) and the fallback value used +// when a PTE doesn't bake `get_max_seq_len`. 2048 matches Gemma4 iter201's +// PREFILL_LEN / get_max_context_len; legacy PTEs (e.g. LFM2-VL) typically +// bake their own get_max_seq_len so this ceiling does not affect them. +inline constexpr auto kMaxPrefillLen = 2048; inline constexpr auto numOfAddedBoSTokens = 0; inline constexpr auto numOfAddedEoSTokens = 0; diff --git a/packages/react-native-executorch/common/runner/encoders/audio_encoder.cpp b/packages/react-native-executorch/common/runner/encoders/audio_encoder.cpp new file mode 100644 index 0000000000..4e4f2c879c --- /dev/null +++ b/packages/react-native-executorch/common/runner/encoders/audio_encoder.cpp @@ -0,0 +1,130 @@ +// common/runner/encoders/audio_encoder.cpp +#include "audio_encoder.h" + +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + +namespace executorch::extension::llm { + +using ::executorch::aten::SizesType; +using ::executorch::runtime::Error; +using ::executorch::runtime::EValue; +using ::executorch::runtime::Result; + +namespace { +// Matches AUDIO_SAMPLES_PER_BLOCK in gemma_export/experiments_vulkan/ +// op_bisect/iter201_mm_4method_dynaudio_prefill2048_export.py. +// The PTE's audio_samples dim was exported as `7680 * audio_blocks`. +constexpr int32_t kSamplesPerBlock = 7680; +// k โˆˆ [kAudioBlockKMin, kAudioBlockKMax] from MODEL_INTERFACE.md ยง6. +// k=62 == 29.76 s @ 16 kHz is the SDPA mask + rel-shift bake point. +constexpr int64_t kAudioBlockKMin = 1; +constexpr int64_t kAudioBlockKMax = 62; +} // namespace + +AudioEncoder::AudioEncoder(::executorch::extension::Module &module) + : module_(&module) {} + +Error AudioEncoder::load() { + if (is_loaded()) { + return Error::Ok; + } + auto method_names_result = module_->method_names(); + if (!method_names_result.ok()) { + return method_names_result.error(); + } + if (method_names_result->count(kAudioEncoderMethod) == 0) { + throw rnexecutorch::RnExecutorchError( + rnexecutorch::RnExecutorchErrorCode::InvalidConfig, + "Model does not support audio: 'audio_encoder' method not found. " + "Check that the .pte file matches the declared capabilities."); + } + return module_->load_method(kAudioEncoderMethod); +} + +bool AudioEncoder::is_loaded() const noexcept { + return module_->is_method_loaded(kAudioEncoderMethod); +} + +int32_t AudioEncoder::encoderTokenCount() const { return last_token_count_; } + +Result AudioEncoder::encode(const MultimodalInput &input) { + if (!is_loaded()) { + return Error::InvalidState; + } + if (!input.is_audio()) { + return Error::InvalidArgument; + } + + const auto &wav = input.get_audio(); + ET_CHECK_OR_RETURN_ERROR(!wav.samples.empty(), InvalidArgument, + "AudioEncoder: empty waveform"); + ET_CHECK_OR_RETURN_ERROR( + wav.sample_rate == 16000, InvalidArgument, + "AudioEncoder: expected 16000 Hz waveform, got %d Hz", wav.sample_rate); + + const int64_t n_valid = static_cast(wav.samples.size()); + const int64_t k_blocks = (n_valid + kSamplesPerBlock - 1) / kSamplesPerBlock; + ET_CHECK_OR_RETURN_ERROR( + k_blocks >= kAudioBlockKMin && k_blocks <= kAudioBlockKMax, + InvalidArgument, + "AudioEncoder: waveform of %lld samples needs k_blocks=%lld; " + "audio_encoder accepts k in [%lld, %lld] (block=%d samples; max %.2f s " + "@ 16 kHz)", + static_cast(n_valid), static_cast(k_blocks), + static_cast(kAudioBlockKMin), + static_cast(kAudioBlockKMax), + static_cast(kSamplesPerBlock), + static_cast(kSamplesPerBlock) * + static_cast(kAudioBlockKMax) / 16000.0); + const int64_t n_padded = k_blocks * kSamplesPerBlock; + + // Own the padded waveform and the attention_mask buffers for the lifetime + // of this call; from_blob below borrows without copying. Mask is bool + // (1 byte per element): true at the first n_valid samples (real PCM), + // false at the zero-padded tail. Matches the iter191+ export at + // iter201_mm_4method_dynaudio_prefill2048_export.py:484-486 โ€” `forward( + // self, waveform[1,N] fp32, attention_mask[1,N] bool)`. + padded_wav_.assign(static_cast(n_padded), 0.0f); + std::memcpy(padded_wav_.data(), wav.samples.data(), + static_cast(n_valid) * sizeof(float)); + + padded_mask_.assign(static_cast(n_padded), uint8_t{0}); + if (n_valid > 0) { + std::memset(padded_mask_.data(), 1, static_cast(n_valid)); + } + + auto wav_tensor = ::executorch::extension::from_blob( + padded_wav_.data(), {1, static_cast(n_padded)}, + ::executorch::aten::ScalarType::Float); + + auto mask_tensor = ::executorch::extension::from_blob( + padded_mask_.data(), {1, static_cast(n_padded)}, + ::executorch::aten::ScalarType::Bool); + + std::vector args = {EValue(*wav_tensor), EValue(*mask_tensor)}; + auto exec_result = ET_UNWRAP(module_->execute(kAudioEncoderMethod, args)); + ET_CHECK_OR_RETURN_ERROR(!exec_result.empty(), InvalidState, + "audio_encoder returned no outputs"); + auto audio_tensor = exec_result[0].toTensor(); + ET_CHECK_OR_RETURN_ERROR(audio_tensor.dim() == 3, InvalidState, + "audio_encoder output rank=%zd, expected 3", + audio_tensor.dim()); + last_token_count_ = static_cast(audio_tensor.size(1)); + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "AudioEncoder: valid_samples=", n_valid, + " padded_samples=", n_padded, " k_blocks=", k_blocks, + " audio_tokens=", last_token_count_); + return exec_result[0]; +} + +} // namespace executorch::extension::llm diff --git a/packages/react-native-executorch/common/runner/encoders/audio_encoder.h b/packages/react-native-executorch/common/runner/encoders/audio_encoder.h new file mode 100644 index 0000000000..3b3b9cfb55 --- /dev/null +++ b/packages/react-native-executorch/common/runner/encoders/audio_encoder.h @@ -0,0 +1,40 @@ +// common/runner/encoders/audio_encoder.h +#pragma once + +#include "iencoder.h" +#include +#include +#include + +#include +#include + +namespace executorch::extension::llm { + +// Runs the Gemma4 `audio_encoder` PTE method. +// +// Contract mirrors SpeechToText (Whisper): JS hands in fp32 mono 16 kHz PCM +// via `MultimodalInput::get_audio()`; the PTE owns the log-mel frontend so +// this class just wraps the samples in a `[1, N_samples]` Float tensor and +// executes. Resampling and WAV/MP3 decoding are the caller's responsibility +// (e.g. react-native-audio-api). +class AudioEncoder : public IEncoder { +public: + explicit AudioEncoder(::executorch::extension::Module &module); + + ::executorch::runtime::Error load() override; + bool is_loaded() const noexcept override; + ::executorch::runtime::Result<::executorch::runtime::EValue> + encode(const MultimodalInput &input) override; + // Number of audio embedding tokens produced per encode() call. 0 until first + // encode, since Gemma4's audio_encoder has a dynamic T dim. + int32_t encoderTokenCount() const override; + +private: + ::executorch::extension::Module *module_; + int32_t last_token_count_ = 0; + std::vector padded_wav_; + std::vector padded_mask_; +}; + +} // namespace executorch::extension::llm diff --git a/packages/react-native-executorch/common/runner/encoders/vision_encoder.cpp b/packages/react-native-executorch/common/runner/encoders/vision_encoder.cpp index de3e196c1f..59fee53e11 100644 --- a/packages/react-native-executorch/common/runner/encoders/vision_encoder.cpp +++ b/packages/react-native-executorch/common/runner/encoders/vision_encoder.cpp @@ -2,7 +2,6 @@ #include "vision_encoder.h" #include -#include #include #include diff --git a/packages/react-native-executorch/common/runner/irunner.h b/packages/react-native-executorch/common/runner/irunner.h index 54b14c354f..4e5b14444a 100644 --- a/packages/react-native-executorch/common/runner/irunner.h +++ b/packages/react-native-executorch/common/runner/irunner.h @@ -73,6 +73,11 @@ struct GenerationConfig { size_t output_token_batch_size = 10; size_t batch_time_interval_ms = 120; + // Top-k sampling โ€“ keep only the k highest-logit tokens before softmax. + // 0 (default) disables top-k filtering. Stacks with topp: temperature -> + // top-k -> top-p -> softmax -> multinomial. + int32_t topk = 0; + // Enable dynamic input shapes (if implemented) or not // Impacts the prefill phase and causes TextPrefiller to pass all the tokens // at once if set to true. diff --git a/packages/react-native-executorch/common/runner/multimodal_decoder_runner.h b/packages/react-native-executorch/common/runner/multimodal_decoder_runner.h index 071b193539..69e36641d8 100644 --- a/packages/react-native-executorch/common/runner/multimodal_decoder_runner.h +++ b/packages/react-native-executorch/common/runner/multimodal_decoder_runner.h @@ -14,21 +14,47 @@ #include "text_decoder_runner.h" namespace executorch::extension::llm { +// Supports two PTE contracts, selected automatically at load time from +// `token_embedding`'s output arity: +// +// * Legacy (default): +// token_embedding(ids) -> inputs_embeds +// text_decoder(inputs_embeds, input_pos) +// +// * Gemma-style PLE (when token_embedding emits 2 outputs): +// token_embedding(ids) -> (inputs_embeds, ple_tok) +// text_decoder(inputs_embeds, ple_tok, input_pos) +// ple_tok carries Gemma4's per-layer PLE signal keyed on input_ids. It's +// computed once in token_embedding and threaded through every decoder call +// so PLE fires at every position (including multimodal placeholder slots). class MultimodalDecoderRunner : public TextDecoderRunner { public: explicit MultimodalDecoderRunner(Module &module, IOManager *io_manager, const GenerationConfig &config) : TextDecoderRunner(module, io_manager, config) {} + // True iff the loaded PTE uses the Gemma-style PLE contract above. + // Meaningful only after load() has been called. + bool uses_ple() const { return uses_ple_; } + inline ::executorch::runtime::Result<::executorch::aten::Tensor> step(TensorPtr &tokens, int64_t start_pos) override { auto embed_result = module_->execute(kTokenEmbeddingMethod, tokens); if (!embed_result.ok()) { return embed_result.error(); } - return decode((*embed_result)[0], start_pos); + auto &embed_outputs = *embed_result; + if (uses_ple_) { + ET_CHECK_MSG(embed_outputs.size() == 2, + "Expected 2 outputs (inputs_embeds, ple_tok) from " + "token_embedding, got %zu", + embed_outputs.size()); + return decode(embed_outputs[0], embed_outputs[1], start_pos); + } + return decode(embed_outputs[0], start_pos); } + // Legacy 2-input text_decoder(inputs_embeds, input_pos). inline ::executorch::runtime::Result<::executorch::aten::Tensor> decode(const ::executorch::runtime::EValue &embeddings, int64_t start_pos) { auto start_pos_tensor = ::executorch::extension::from_blob( @@ -46,12 +72,35 @@ class MultimodalDecoderRunner : public TextDecoderRunner { return outputs[0].toTensor(); } + // PLE 3-input text_decoder(inputs_embeds, ple_tok, input_pos). + inline ::executorch::runtime::Result<::executorch::aten::Tensor> + decode(const ::executorch::runtime::EValue &embeddings, + const ::executorch::runtime::EValue &ple_tok, int64_t start_pos) { + auto start_pos_tensor = ::executorch::extension::from_blob( + &start_pos, {1}, ::executorch::aten::ScalarType::Long); + auto outputs_result = module_->execute( + kTextModelMethod, {embeddings, ple_tok, start_pos_tensor}); + if (!outputs_result.ok()) { + return outputs_result.error(); + } + auto &outputs = *outputs_result; + ET_CHECK_MSG(outputs.size() == 1, + "Expected 1 output from text_decoder, got %zu", + outputs.size()); + ET_CHECK_MSG(outputs[0].isTensor(), "text_decoder output is not a tensor"); + return outputs[0].toTensor(); + } + inline ::executorch::runtime::Error load() override { if (is_method_loaded()) { return ::executorch::runtime::Error::Ok; } ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kTokenEmbeddingMethod)); ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kTextModelMethod)); + + auto meta = module_->method_meta(kTokenEmbeddingMethod); + ET_CHECK_OK_OR_RETURN_ERROR(meta.error()); + uses_ple_ = (meta->num_outputs() == 2); return ::executorch::runtime::Error::Ok; } @@ -59,6 +108,9 @@ class MultimodalDecoderRunner : public TextDecoderRunner { return module_->is_method_loaded(kTokenEmbeddingMethod) && module_->is_method_loaded(kTextModelMethod); } + +private: + bool uses_ple_ = false; }; } // namespace executorch::extension::llm diff --git a/packages/react-native-executorch/common/runner/multimodal_input.h b/packages/react-native-executorch/common/runner/multimodal_input.h index 6b7de35014..b515c866d4 100644 --- a/packages/react-native-executorch/common/runner/multimodal_input.h +++ b/packages/react-native-executorch/common/runner/multimodal_input.h @@ -19,6 +19,15 @@ namespace executorch::extension::llm { struct ImagePath { std::string path; }; +// In-memory raw audio (fp32, mono). Pattern mirrors SpeechToText: the JS +// layer decodes WAV/MP3 via react-native-audio-api and passes Float32Array +// samples; the PTE has the log-mel frontend baked in, so the runner only +// needs the waveform itself. sample_rate is expected to match the PTE's +// mel-extractor (Gemma4: 16000 Hz). +struct AudioWaveform { + std::vector samples; + int32_t sample_rate; +}; class MultimodalInput { public: @@ -27,6 +36,7 @@ class MultimodalInput { : data_(std::move(tokens)) {} explicit MultimodalInput(ImagePath image_path) : data_(std::move(image_path)) {} + explicit MultimodalInput(AudioWaveform audio) : data_(std::move(audio)) {} MultimodalInput(const MultimodalInput &) = default; MultimodalInput &operator=(const MultimodalInput &) = default; @@ -42,6 +52,9 @@ class MultimodalInput { bool is_image() const noexcept { return std::holds_alternative(data_); } + bool is_audio() const noexcept { + return std::holds_alternative(data_); + } const std::string &get_text() const & { return std::get(data_); } const std::vector &get_tokens() const & { @@ -50,9 +63,13 @@ class MultimodalInput { const std::string &get_image_path() const & { return std::get(data_).path; } + const AudioWaveform &get_audio() const & { + return std::get(data_); + } private: - std::variant, ImagePath> data_; + std::variant, ImagePath, AudioWaveform> + data_; }; inline MultimodalInput make_text_input(const std::string &text) noexcept { @@ -64,5 +81,9 @@ inline MultimodalInput make_text_input(std::string &&text) noexcept { inline MultimodalInput make_image_input(std::string path) noexcept { return MultimodalInput(ImagePath{std::move(path)}); } +inline MultimodalInput make_audio_input(std::vector samples, + int32_t sample_rate = 16000) noexcept { + return MultimodalInput(AudioWaveform{std::move(samples), sample_rate}); +} } // namespace executorch::extension::llm diff --git a/packages/react-native-executorch/common/runner/multimodal_prefiller.cpp b/packages/react-native-executorch/common/runner/multimodal_prefiller.cpp index 83a1a7f79c..698002a5cf 100644 --- a/packages/react-native-executorch/common/runner/multimodal_prefiller.cpp +++ b/packages/react-native-executorch/common/runner/multimodal_prefiller.cpp @@ -8,11 +8,23 @@ // Ported from executorch/extension/llm/runner/multimodal_prefiller.cpp // with our token-embedding padding fix and LFM2-VL adaptations. +// +// Supports two PTE shapes, selected from MultimodalDecoderRunner::uses_ple() +// (auto-detected at load time): +// * Legacy : token_embedding -> inputs_embeds; +// text_decoder(inputs_embeds, cache_positions). +// * PLE : token_embedding -> (inputs_embeds, ple_tok); +// text_decoder(inputs_embeds, ple_tok, cache_positions). #include "multimodal_prefiller.h" #include "constants.h" #include "util.h" #include +#include +#include +#include +#include +#include namespace executorch::extension::llm { @@ -23,91 +35,633 @@ using ::executorch::runtime::Result; MultimodalPrefiller::MultimodalPrefiller( Module &module, MultimodalDecoderRunner &decoder_runner, - tokenizers::HFTokenizer &tokenizer, IEncoder *image_encoder) + tokenizers::HFTokenizer &tokenizer, IEncoder *image_encoder, + IEncoder *audio_encoder) : module_(&module), decoder_runner_(&decoder_runner), - tokenizer_(&tokenizer), image_encoder_(image_encoder) {} - -Result MultimodalPrefiller::prefill(const MultimodalInput &input, - int64_t &start_pos) { - EValue encoder_output; - std::vector padded_tokens_storage; - TensorPtr sliced_embed_storage; - - if (input.is_image()) { - ET_CHECK_OR_RETURN_ERROR(image_encoder_ != nullptr, InvalidState, - "No image encoder registered"); - auto encode_result = image_encoder_->encode(input); + tokenizer_(&tokenizer), image_encoder_(image_encoder), + audio_encoder_(audio_encoder) {} + +Result +MultimodalPrefiller::prefill(const std::vector &inputs, + int64_t &start_pos) { + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, "kappa multimodal prefill"); + + const bool uses_ple = decoder_runner_->uses_ple(); + const long t_prefill_begin = time_in_ms(); + + ET_CHECK_OR_RETURN_ERROR(!inputs.empty(), InvalidArgument, + "prefill: empty input list"); + + // ------------------------------------------------------------ + // Capacity & shape policy from PTE metadata. + // + // Three knobs drive prefill: + // * get_max_seq_len โ€” text_decoder S cap. In dynamic-shape PTEs this + // is the per-call chunk size (Gemma4 iter201 = 128); + // in static-shape PTEs (LFM2-VL) it is also the + // single-shot prefill cap. + // * get_max_context_len โ€” total KV budget (Gemma4 iter201 = 2048). Only + // materially used by the dynamic-shape path. + // * enable_dynamic_shape โ€” selects between chunked (true) and single-shot + // padded (false) prefill. + // ------------------------------------------------------------ + int64_t max_seq_len = -1; + { + auto r = module_->get(kMaxSeqLen); + if (r.error() == Error::Ok) { + max_seq_len = r->toScalar().to(); + } + } + if (max_seq_len <= 0) { + max_seq_len = kMaxPrefillLen; + } + + int64_t max_context_len = max_seq_len; + { + auto r = module_->get(kMaxContextLen); + if (r.error() == Error::Ok) { + max_context_len = r->toScalar().to(); + } + } + + bool enable_dynamic_shape = false; + { + auto r = module_->get(kEnableDynamicShape); + if (r.error() == Error::Ok) { + enable_dynamic_shape = r->toScalar().to(); + } + } + + const int64_t prefill_total_cap = + enable_dynamic_shape ? max_context_len : max_seq_len; + const int64_t decoder_chunk_size = max_seq_len; + + // ------------------------------------------------------------ + // Pass 1: build a fused input_ids buffer spanning all inputs. + // + // Mirrors gemma_export/experiments/infer_image.py::prefill_single_shot: + // llm_ids = prefix_ids + [0] * num_soft + suffix_ids + // Image positions use pad_token_id=0, matching HF modeling_gemma4.py:2190 + // (placeholder_id is rewritten to 0 before PLE lookup). The decoder embeds + // at those positions are then overwritten with the vision encoder output + // in pass 2. + // ------------------------------------------------------------ + struct ImageSlot { + const MultimodalInput *input; // non-owning, valid for duration of call + int64_t slot_start; + int64_t num_visual; + }; + // Audio tokens are dynamic per clip, so we encode first and remember a + // BYTE SNAPSHOT of the encoder output + count + dtype; pass 2 splices + // from the snapshot. + // + // We can NOT stash the EValue here. EValue holds an aten::Tensor which is + // just a TensorImpl*; `Method::get_output(i)` returns `const EValue&` to + // Method-internal storage and Module::execute copies that EValue into the + // returned vector. The copy shares the underlying TensorImpl, so a later + // execute() on the same method โ€” a second audio input in this prefill, + // a Vulkan backend output-buffer reuse across methods, or a load-time + // warm-up โ€” mutates `sizes()` in place under our feet. The original error + // ("audio encoder returned 96 tokens, expected 60") is exactly this: + // slot.num_audio was captured from the FIRST encode, slot.encoded.size(1) + // reflected the SECOND. Mirrors main_mm.cpp:604-675's copy-on-encode. + struct AudioSlot { + std::vector bytes; + ::executorch::aten::ScalarType dtype; + int64_t slot_start; + int64_t num_audio; + int64_t audio_hidden; + }; + + std::vector ids; + ids.reserve(static_cast(prefill_total_cap)); + std::vector image_slots; + std::vector audio_slots; + long audio_encode_ms = 0; + int audio_calls = 0; + + for (const auto &input : inputs) { + + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, "kappa input type", input.is_audio() ? "audio" : "no audio", input.is_image() ? "image" : "no image"); + + if (input.is_image()) { + ET_CHECK_OR_RETURN_ERROR(image_encoder_ != nullptr, InvalidState, + "No image encoder registered"); + const int32_t num_visual = image_encoder_->encoderTokenCount(); + ET_CHECK_OR_RETURN_ERROR(num_visual > 0, InvalidState, + "Image encoder reports 0 visual tokens"); + image_slots.push_back(ImageSlot{&input, static_cast(ids.size()), + static_cast(num_visual)}); + ids.insert(ids.end(), static_cast(num_visual), 0); + } else if (input.is_audio()) { + ET_CHECK_OR_RETURN_ERROR(audio_encoder_ != nullptr, InvalidState, + "No audio encoder registered"); + const long t_aud_begin = time_in_ms(); + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, "kappa encoding audio"); + auto enc = audio_encoder_->encode(input); + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, "kappa encoded audio"); + ET_CHECK_OK_OR_RETURN_ERROR(enc.error(), "Audio encoding failed"); + audio_encode_ms += time_in_ms() - t_aud_begin; + audio_calls += 1; + // Snapshot the encoder output NOW โ€” see AudioSlot comment above for + // why the returned EValue's tensor metadata can't survive past the + // next module_->execute(). num_audio and audio_hidden are read from + // the tensor directly rather than from encoderTokenCount() so they + // are guaranteed to reflect THIS encode call. + auto audio_tensor = enc->toTensor(); + ET_CHECK_OR_RETURN_ERROR(audio_tensor.dim() == 3, InvalidState, + "audio_encoder output rank=%zd, expected 3", + audio_tensor.dim()); + const int64_t num_audio = static_cast(audio_tensor.size(1)); + const int64_t audio_hidden = static_cast(audio_tensor.size(2)); + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, "kappa encoded audio", num_audio, audio_hidden); + ET_CHECK_OR_RETURN_ERROR(num_audio > 0, InvalidState, + "Audio encoder produced 0 tokens"); + std::vector bytes(audio_tensor.nbytes()); + std::memcpy(bytes.data(), audio_tensor.const_data_ptr(), + audio_tensor.nbytes()); + audio_slots.push_back(AudioSlot{ + std::move(bytes), audio_tensor.scalar_type(), + static_cast(ids.size()), num_audio, audio_hidden}); + ids.insert(ids.end(), static_cast(num_audio), 0); + + // Diagnostic: dump audio encoder output magnitude + first values so + // the user can see in the JS console whether the encoder is producing + // real embeddings or zeros/noise. slot.bytes can be either fp32 or + // fp16 depending on exporter; handle both. + { + const AudioSlot &slot = audio_slots.back(); + const size_t nrows = static_cast(slot.num_audio); + const size_t hidden = static_cast(slot.audio_hidden); + const size_t nfloats = nrows * hidden; + const size_t bytes_per_elem = + nfloats > 0 ? slot.bytes.size() / nfloats : 0; + float maxabs = 0.0f; + double sumsq = 0.0; + std::string head; + const size_t headN = std::min(16, nfloats); + if (slot.dtype == ::executorch::aten::ScalarType::Float) { + const float *e = reinterpret_cast(slot.bytes.data()); + for (size_t i = 0; i < nfloats; ++i) { + const float a = std::fabs(e[i]); + if (a > maxabs) + maxabs = a; + sumsq += static_cast(e[i]) * static_cast(e[i]); + } + for (size_t i = 0; i < headN; ++i) { + if (i) + head += ", "; + head += std::to_string(e[i]); + } + } else if (slot.dtype == ::executorch::aten::ScalarType::Half) { + const auto *e = reinterpret_cast( + slot.bytes.data()); + for (size_t i = 0; i < nfloats; ++i) { + const float v = static_cast(e[i]); + const float a = std::fabs(v); + if (a > maxabs) + maxabs = a; + sumsq += static_cast(v) * static_cast(v); + } + for (size_t i = 0; i < headN; ++i) { + if (i) + head += ", "; + head += std::to_string(static_cast(e[i])); + } + } else { + head = ""; + } + const double rms = + nfloats > 0 ? std::sqrt(sumsq / static_cast(nfloats)) : 0.0; + rnexecutorch::log( + rnexecutorch::LOG_LEVEL::Info, + "kappa [AudioEmbed] num_audio=", nrows, " hidden=", hidden, + " slot_start=", slot.slot_start, + " dtype=", static_cast(slot.dtype), + " bytes_per_elem=", bytes_per_elem, " maxabs=", maxabs, + " rms=", rms, " first16=", head); + } + } else if (input.is_text() || input.is_tokens()) { + std::vector tokens; + if (input.is_text()) { + auto encode_result = tokenizer_->encode(input.get_text()); + if (!encode_result.ok()) { + ET_LOG(Error, "Tokenizer encode error %d", + static_cast(encode_result.error())); + return Error::InvalidArgument; + } + tokens = std::move(*encode_result); + } else { + tokens = input.get_tokens(); + } + for (auto t : tokens) { + ids.push_back(static_cast(t)); + } + } else { + ET_LOG(Error, "Unsupported MultimodalInput type"); + return Error::NotSupported; + } + } + + const int64_t total_len = static_cast(ids.size()); + ET_CHECK_OR_RETURN_ERROR(total_len > 0, InvalidArgument, + "prefill produced zero tokens"); + + ET_CHECK_OR_RETURN_ERROR( + total_len <= prefill_total_cap, InvalidArgument, + "Prefill length %lld exceeds %s (%lld)", + static_cast(total_len), + enable_dynamic_shape ? "get_max_context_len" : "get_max_seq_len", + static_cast(prefill_total_cap)); + if (!enable_dynamic_shape) { + // Static-shape token_embedding needs fixed-length input; trailing pad + // zeros are inert because we copy only `total_len` rows out of the + // embedding output below. + ids.resize(static_cast(max_seq_len), 0); + } + + // ------------------------------------------------------------ + // Splice diagnostics: dump the fused id stream + slot map so the JS + // console can confirm placeholder positions (audio/image use id=0) + // line up with where the encoder outputs will be spliced in pass 2/2b. + // ------------------------------------------------------------ + { + std::string ids_dump; + const size_t dump_n = std::min(ids.size(), size_t{512}); + ids_dump.reserve(dump_n * 6); + for (size_t i = 0; i < dump_n; ++i) { + if (i) ids_dump += ","; + ids_dump += std::to_string(ids[i]); + } + if (ids.size() > dump_n) ids_dump += ",โ€ฆ"; + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "[Splice] total_len=", total_len, + " ids_buf=", ids.size(), + " image_slots=", image_slots.size(), + " audio_slots=", audio_slots.size(), + " ids[0..", dump_n, "]=", ids_dump); + for (size_t i = 0; i < image_slots.size(); ++i) { + const auto &s = image_slots[i]; + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "[Splice] image_slot[", i, + "] start=", s.slot_start, + " num=", s.num_visual); + } + for (size_t i = 0; i < audio_slots.size(); ++i) { + const auto &s = audio_slots[i]; + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "[Splice] audio_slot[", i, + "] start=", s.slot_start, + " num=", s.num_audio, + " hidden=", s.audio_hidden, + " dtype=", static_cast(s.dtype), + " bytes=", s.bytes.size()); + } + } + + // ------------------------------------------------------------ + // Single token_embedding call over the fused id buffer. + // ------------------------------------------------------------ + const int64_t tok_buf_len = static_cast(ids.size()); + auto token_tensor = ::executorch::extension::from_blob( + ids.data(), {1, static_cast(tok_buf_len)}, + ::executorch::aten::ScalarType::Long); + + const long t_tokembed_begin = time_in_ms(); + auto embed_result = module_->execute(kTokenEmbeddingMethod, token_tensor); + ET_CHECK_OK_OR_RETURN_ERROR(embed_result.error()); + auto &embed_outputs = *embed_result; + const long t_tokembed_end = time_in_ms(); + + const size_t expected_outputs = uses_ple ? 2u : 1u; + ET_CHECK_OR_RETURN_ERROR(embed_outputs.size() == expected_outputs, + InvalidState, + "Expected %zu output(s) from token_embedding, " + "got %zu", + expected_outputs, embed_outputs.size()); + + auto full_embed = embed_outputs[0].toTensor(); + const auto hidden = static_cast(full_embed.size(2)); + + // Own the embeds for the live prefix โ€” subsequent vision_encoder.execute + // calls may reuse the token_embedding output buffer in the runtime. + // Dtype is whatever the exporter chose (fp32 baseline, fp16 + // s16k_jitmask_fp16); copy bytes through nbytes/numel so we don't assume the + // scalar type. + const ::executorch::aten::ScalarType embeds_dtype = full_embed.scalar_type(); + const size_t embeds_total_numel = static_cast(full_embed.numel()); + ET_CHECK_OR_RETURN_ERROR(embeds_total_numel > 0, InvalidState, + "token_embedding returned zero elements"); + const size_t embeds_elem_size = full_embed.nbytes() / embeds_total_numel; + const size_t embeds_prefix_bytes = static_cast(total_len) * + static_cast(hidden) * + embeds_elem_size; + std::vector embeds_buf(embeds_prefix_bytes); + std::memcpy(embeds_buf.data(), full_embed.mutable_data_ptr(), + embeds_prefix_bytes); + + // Own the ple_tok prefix similarly. Dtype is whatever the exporter chose + // (commonly bf16/int8); we copy bytes through nbytes/numel without + // assuming the scalar type. `ple_elem_size` is hoisted so the chunked + // text_decoder loop below can use it for byte-offset slicing. + std::vector ple_tok_buf; + SizesType num_layers = 0; + SizesType ple_dim = 0; + size_t ple_elem_size = 0; + ::executorch::aten::ScalarType ple_tok_dtype = + ::executorch::aten::ScalarType::Float; + if (uses_ple) { + auto full_ple_tok = embed_outputs[1].toTensor(); + num_layers = static_cast(full_ple_tok.size(2)); + ple_dim = static_cast(full_ple_tok.size(3)); + ple_tok_dtype = full_ple_tok.scalar_type(); + const size_t total_numel = static_cast(full_ple_tok.numel()); + const size_t total_bytes = full_ple_tok.nbytes(); + ET_CHECK_OR_RETURN_ERROR(total_numel > 0, InvalidState, + "ple_tok has zero elements"); + ple_elem_size = total_bytes / total_numel; + const size_t prefix_bytes = static_cast(total_len) * + static_cast(num_layers) * + static_cast(ple_dim) * ple_elem_size; + ple_tok_buf.resize(prefix_bytes); + std::memcpy(ple_tok_buf.data(), full_ple_tok.mutable_data_ptr(), + prefix_bytes); + } + + // ------------------------------------------------------------ + // Pass 2: encode images and splice their outputs into embeds_buf. + // ------------------------------------------------------------ + long vision_total_ms = 0; + int vision_calls = 0; + for (const auto &slot : image_slots) { + const long t_vis_begin = time_in_ms(); + auto encode_result = image_encoder_->encode(*slot.input); ET_CHECK_OK_OR_RETURN_ERROR(encode_result.error(), "Image encoding failed"); - encoder_output = *encode_result; - - } else if (input.is_text() || input.is_tokens()) { - std::vector tokens; - if (input.is_text()) { - auto encode_result = tokenizer_->encode(input.get_text()); - if (!encode_result.ok()) { - ET_LOG(Error, "Tokenizer encode error %d", - static_cast(encode_result.error())); - return Error::InvalidArgument; + auto vision_tensor = encode_result->toTensor(); + vision_total_ms += time_in_ms() - t_vis_begin; + vision_calls += 1; + ET_CHECK_OR_RETURN_ERROR( + static_cast(vision_tensor.size(1)) == slot.num_visual, + InvalidState, "vision encoder returned %lld tokens, expected %lld", + static_cast(vision_tensor.size(1)), + static_cast(slot.num_visual)); + ET_CHECK_OR_RETURN_ERROR( + static_cast(vision_tensor.size(2)) == + static_cast(hidden), + InvalidState, "vision encoder hidden %lld != text_embed hidden %lld", + static_cast(vision_tensor.size(2)), + static_cast(hidden)); + + const auto vision_dtype = vision_tensor.scalar_type(); + const size_t visual_elems = + static_cast(slot.num_visual) * static_cast(hidden); + uint8_t *dst = embeds_buf.data() + static_cast(slot.slot_start) * + static_cast(hidden) * + embeds_elem_size; + if (vision_dtype == embeds_dtype) { + const uint8_t *src = + static_cast(vision_tensor.const_data_ptr()); + std::memcpy(dst, src, visual_elems * embeds_elem_size); + } else if (vision_dtype == ::executorch::aten::ScalarType::Float && + embeds_dtype == ::executorch::aten::ScalarType::Half) { + const float *src = vision_tensor.const_data_ptr(); + auto *dst_h = reinterpret_cast<::executorch::aten::Half *>(dst); + for (size_t i = 0; i < visual_elems; ++i) { + dst_h[i] = ::executorch::aten::Half(src[i]); + } + } else if (vision_dtype == ::executorch::aten::ScalarType::Half && + embeds_dtype == ::executorch::aten::ScalarType::Float) { + const auto *src = + vision_tensor.const_data_ptr<::executorch::aten::Half>(); + auto *dst_f = reinterpret_cast(dst); + for (size_t i = 0; i < visual_elems; ++i) { + dst_f[i] = static_cast(src[i]); } - tokens = std::move(*encode_result); } else { - tokens = input.get_tokens(); + ET_CHECK_OR_RETURN_ERROR( + false, InvalidState, + "unsupported vision/text dtype pair: vision=%hhd text=%hhd", + static_cast(vision_dtype), static_cast(embeds_dtype)); } + } - const auto actual_seq_len = static_cast(tokens.size()); + // ------------------------------------------------------------ + // Pass 2b: splice encoded audio tokens into embeds_buf. Reads from the + // byte snapshot taken at encode time so post-encode execute() calls can't + // invalidate slot state. Same dtype-conversion matrix as vision. + // ------------------------------------------------------------ - // The token_embedding PTE has a fixed MAX_SEQ_LEN input buffer. - // Pad with zeros, run embedding, then slice output back to actual length. - int64_t max_seq_len = actual_seq_len; // fallback: no padding needed - auto max_seq_len_result = module_->get(kMaxSeqLen); - if (max_seq_len_result.error() == Error::Ok) { - max_seq_len = max_seq_len_result->toScalar().to(); + // Diagnostic helper: stringify the first N elements of a byte buffer + // interpreted as the given scalar dtype (Float or Half supported). Used + // for the before/src/after snapshots logged per audio slot so the user + // can verify (a) the destination held pad-token embeds prior to splice, + // (b) the source bytes carry real audio encoder output, and (c) the + // destination contains the audio values post-splice. + auto stringify_floats = [](const uint8_t *p, size_t n, + ::executorch::aten::ScalarType dt) -> std::string { + std::string out; + out.reserve(n * 10); + for (size_t i = 0; i < n; ++i) { + if (i) out += ","; + float v = 0.0f; + if (dt == ::executorch::aten::ScalarType::Float) { + v = reinterpret_cast(p)[i]; + } else if (dt == ::executorch::aten::ScalarType::Half) { + v = static_cast( + reinterpret_cast(p)[i]); + } + char buf[24]; + std::snprintf(buf, sizeof(buf), "%.4f", v); + out += buf; } + return out; + }; - padded_tokens_storage.assign(max_seq_len, 0); - std::ranges::copy(tokens, padded_tokens_storage.begin()); + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "kappa [Splice] pass 2b begin: audio_slots=", + audio_slots.size(), " embeds_buf_bytes=", + embeds_buf.size(), " hidden=", hidden, + " embeds_dtype=", static_cast(embeds_dtype), + " embeds_elem_size=", embeds_elem_size); - auto text_tensor = ::executorch::extension::from_blob( - padded_tokens_storage.data(), {1, static_cast(max_seq_len)}, - ::executorch::aten::ScalarType::Long); + for (auto &slot : audio_slots) { + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, "kappa splice loop!!!"); + + ET_CHECK_OR_RETURN_ERROR( + slot.audio_hidden == static_cast(hidden), InvalidState, + "audio encoder hidden %lld != text_embed hidden %lld", + static_cast(slot.audio_hidden), + static_cast(hidden)); - auto embed_result = module_->execute(kTokenEmbeddingMethod, text_tensor); - ET_CHECK_OK_OR_RETURN_ERROR(embed_result.error()); + const auto audio_dtype = slot.dtype; + const size_t audio_elems = + static_cast(slot.num_audio) * static_cast(hidden); + const size_t audio_elem_size = + audio_elems > 0 ? slot.bytes.size() / audio_elems : 0; + ET_CHECK_OR_RETURN_ERROR( + audio_elem_size > 0 && + audio_elem_size * audio_elems == slot.bytes.size(), + InvalidState, + "audio slot bytes %zu inconsistent with num_audio=%lld hidden=%lld", + slot.bytes.size(), static_cast(slot.num_audio), + static_cast(hidden)); + + uint8_t *dst = embeds_buf.data() + static_cast(slot.slot_start) * + static_cast(hidden) * + embeds_elem_size; + + // Snapshot the first row before splice. These bytes should hold the + // pad-token (id=0) embedding because ids.insert pushed zeros into + // these slot positions before the token_embedding call. If `before` + // already equals `src` you are looking at a re-run; if `after` differs + // from `src` the splice/dtype-cast path is wrong. + const size_t dbg_n = std::min(8, static_cast(hidden)); + const std::string before_str = + stringify_floats(dst, dbg_n, embeds_dtype); + const std::string src_str = + stringify_floats(slot.bytes.data(), dbg_n, audio_dtype); + + if (audio_dtype == embeds_dtype) { + std::memcpy(dst, slot.bytes.data(), audio_elems * embeds_elem_size); + } else if (audio_dtype == ::executorch::aten::ScalarType::Float && + embeds_dtype == ::executorch::aten::ScalarType::Half) { + const float *src = reinterpret_cast(slot.bytes.data()); + auto *dst_h = reinterpret_cast<::executorch::aten::Half *>(dst); + for (size_t i = 0; i < audio_elems; ++i) { + dst_h[i] = ::executorch::aten::Half(src[i]); + } + } else if (audio_dtype == ::executorch::aten::ScalarType::Half && + embeds_dtype == ::executorch::aten::ScalarType::Float) { + const auto *src = reinterpret_cast( + slot.bytes.data()); + auto *dst_f = reinterpret_cast(dst); + for (size_t i = 0; i < audio_elems; ++i) { + dst_f[i] = static_cast(src[i]); + } + } else { + ET_CHECK_OR_RETURN_ERROR( + false, InvalidState, + "unsupported audio/text dtype pair: audio=%hhd text=%hhd", + static_cast(audio_dtype), static_cast(embeds_dtype)); + } - auto full_embed = (*embed_result)[0].toTensor(); - const auto embed_dim = static_cast(full_embed.size(2)); - sliced_embed_storage = ::executorch::extension::from_blob( - full_embed.mutable_data_ptr(), {1, actual_seq_len, embed_dim}, - ::executorch::aten::ScalarType::Float); - encoder_output = EValue(*sliced_embed_storage); + // Post-splice: re-read the same dst row + compute magnitude over the + // entire spliced region so the user can confirm the bytes flipped to + // real audio embeds (post should equal src, post_maxabs > 0 always). + const std::string after_str = + stringify_floats(dst, dbg_n, embeds_dtype); + float maxabs = 0.0f; + double sumsq = 0.0; + for (size_t i = 0; i < audio_elems; ++i) { + float v = 0.0f; + if (embeds_dtype == ::executorch::aten::ScalarType::Float) { + v = reinterpret_cast(dst)[i]; + } else if (embeds_dtype == ::executorch::aten::ScalarType::Half) { + v = static_cast( + reinterpret_cast(dst)[i]); + } + const float av = std::fabs(v); + if (av > maxabs) maxabs = av; + sumsq += static_cast(v) * static_cast(v); + } + const double rms = + audio_elems > 0 + ? std::sqrt(sumsq / static_cast(audio_elems)) + : 0.0; - } else { - ET_LOG(Error, "Unsupported MultimodalInput type"); - return Error::NotSupported; + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "kappa [Splice] audio slot start=", slot.slot_start, + " num=", slot.num_audio, + " dtype src=", static_cast(audio_dtype), + " dst=", static_cast(embeds_dtype), + " pre[0..", dbg_n, "]=", before_str, + " src[0..", dbg_n, "]=", src_str, + " post[0..", dbg_n, "]=", after_str, + " post_maxabs=", maxabs, " post_rms=", rms); } + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, "kappa post splice", embeds_buf); - // Run text_decoder for prefill. - int64_t seq_len = encoder_output.toTensor().size(1); - if (seq_len == 0) { - ET_LOG(Error, "Encoder returned empty output"); - return Error::InvalidState; + // ------------------------------------------------------------ + // Chunked text_decoder calls. + // + // Some PTEs (Gemma4 iter201) hard-cap text_decoder's S dim at + // get_max_seq_len (128) while the prefill budget extends to + // get_max_context_len (2048). KV cache state persists across calls via the + // absolute input_pos vector, so chunking is functionally transparent to + // the model. For single-shot static-shape PTEs (LFM2-VL) chunk_cap == + // total_len so the loop iterates exactly once โ€” preserving prior behavior. + // ------------------------------------------------------------ + const int64_t chunk_cap = + decoder_chunk_size > 0 ? decoder_chunk_size : total_len; + std::vector cache_positions(static_cast(total_len)); + for (int64_t i = 0; i < total_len; ++i) { + cache_positions[static_cast(i)] = start_pos + i; } - std::vector cache_positions; - auto cache_pos_result = populate_start_pos_or_cache_position( - module_, start_pos, cache_positions, seq_len, kTextModelMethod); - ET_CHECK_OK_OR_RETURN_ERROR(cache_pos_result.error()); + const long t_textdec_begin = time_in_ms(); + std::vector last_outs; + const int64_t num_chunks = (total_len + chunk_cap - 1) / chunk_cap; + for (int64_t ci = 0; ci < num_chunks; ++ci) { + const int64_t cs = ci * chunk_cap; + const int64_t ce = std::min(cs + chunk_cap, total_len); + const int64_t chunk_len = ce - cs; - auto prefill_result = - module_->execute(kTextModelMethod, {encoder_output, *cache_pos_result}); - ET_CHECK_OK_OR_RETURN_ERROR(prefill_result.error()); + uint8_t *embeds_chunk_ptr = + embeds_buf.data() + static_cast(cs) * + static_cast(hidden) * embeds_elem_size; + auto embeds_chunk = ::executorch::extension::from_blob( + embeds_chunk_ptr, {1, static_cast(chunk_len), hidden}, + embeds_dtype); + + TensorPtr ple_chunk; + if (uses_ple) { + uint8_t *ple_chunk_ptr = + ple_tok_buf.data() + + static_cast(cs) * static_cast(num_layers) * + static_cast(ple_dim) * ple_elem_size; + ple_chunk = ::executorch::extension::from_blob( + ple_chunk_ptr, + {1, static_cast(chunk_len), num_layers, ple_dim}, + ple_tok_dtype); + } + + auto pos_chunk = ::executorch::extension::from_blob( + cache_positions.data() + cs, {static_cast(chunk_len)}, + ::executorch::aten::ScalarType::Long); - auto &prefill_outputs = *prefill_result; - ET_CHECK_OR_RETURN_ERROR(!prefill_outputs.empty(), InvalidState, + auto res = + uses_ple + ? module_->execute(kTextModelMethod, + {EValue(*embeds_chunk), EValue(*ple_chunk), + EValue(*pos_chunk)}) + : module_->execute(kTextModelMethod, + {EValue(*embeds_chunk), EValue(*pos_chunk)}); + ET_CHECK_OK_OR_RETURN_ERROR(res.error()); + last_outs = std::move(*res); + } + const long t_textdec_end = time_in_ms(); + + ET_CHECK_OR_RETURN_ERROR(!last_outs.empty(), InvalidState, "text_decoder returned no outputs during prefill"); - auto logits = prefill_outputs[0].toTensor(); - start_pos += seq_len; + auto logits = last_outs[0].toTensor(); + const long t_logits_end = time_in_ms(); + start_pos += total_len; + + const long prefill_total = t_logits_end - t_prefill_begin; + const long tokembed_ms = t_tokembed_end - t_tokembed_begin; + const long textdec_ms = t_textdec_end - t_textdec_begin; + const long sample_ms = t_logits_end - t_textdec_end; + const long overhead_ms = + prefill_total - tokembed_ms - vision_total_ms - textdec_ms - sample_ms; + rnexecutorch::log( + rnexecutorch::LOG_LEVEL::Info, "prefill splits ms: total=", prefill_total, + " token_embed=", tokembed_ms, " vision(x", vision_calls, + ")=", vision_total_ms, " audio(x", audio_calls, ")=", audio_encode_ms, + " text_decoder=", textdec_ms, " logits->token=", sample_ms, + " overhead=", overhead_ms, " total_len=", total_len, + " chunks=", num_chunks, " chunk_cap=", chunk_cap, + " dynamic=", static_cast(enable_dynamic_shape)); return static_cast(decoder_runner_->logits_to_token(logits)); } @@ -127,6 +681,9 @@ Error MultimodalPrefiller::load() { if (methods.find(kVisionEncoderMethod) != methods.end()) { ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kVisionEncoderMethod)); } + if (methods.find(kAudioEncoderMethod) != methods.end()) { + ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kAudioEncoderMethod)); + } return Error::Ok; } @@ -140,8 +697,13 @@ bool MultimodalPrefiller::is_method_loaded() { return false; } const auto &methods = *methods_res; - if (methods.find(kVisionEncoderMethod) != methods.end()) { - return module_->is_method_loaded(kVisionEncoderMethod); + if (methods.find(kVisionEncoderMethod) != methods.end() && + !module_->is_method_loaded(kVisionEncoderMethod)) { + return false; + } + if (methods.find(kAudioEncoderMethod) != methods.end() && + !module_->is_method_loaded(kAudioEncoderMethod)) { + return false; } return true; } diff --git a/packages/react-native-executorch/common/runner/multimodal_prefiller.h b/packages/react-native-executorch/common/runner/multimodal_prefiller.h index d9b5a9bf5c..9676205252 100644 --- a/packages/react-native-executorch/common/runner/multimodal_prefiller.h +++ b/packages/react-native-executorch/common/runner/multimodal_prefiller.h @@ -23,12 +23,16 @@ class MultimodalPrefiller { explicit MultimodalPrefiller(Module &module, MultimodalDecoderRunner &decoder_runner, tokenizers::HFTokenizer &tokenizer, - IEncoder *image_encoder = nullptr); + IEncoder *image_encoder = nullptr, + IEncoder *audio_encoder = nullptr); - // Prefill one input segment. Updates start_pos in-place. - // Returns the first predicted token after this segment. - ::executorch::runtime::Result prefill(const MultimodalInput &input, - int64_t &start_pos); + // Single-shot prefill: fuses all inputs into one token_embedding call and + // one text_decoder call. Image slots are filled with pad_token_id=0 (HF + // modeling_gemma4.py behavior); vision encoder output overwrites the embeds + // at those slots before the decoder runs. Updates start_pos in-place. + // Returns the first predicted token after the fused prefill. + ::executorch::runtime::Result + prefill(const std::vector &inputs, int64_t &start_pos); ::executorch::runtime::Error load(); bool is_method_loaded(); @@ -38,6 +42,7 @@ class MultimodalPrefiller { MultimodalDecoderRunner *decoder_runner_; tokenizers::HFTokenizer *tokenizer_; IEncoder *image_encoder_; + IEncoder *audio_encoder_; }; } // namespace executorch::extension::llm diff --git a/packages/react-native-executorch/common/runner/multimodal_runner.cpp b/packages/react-native-executorch/common/runner/multimodal_runner.cpp index 767fef9f38..4c587e1b7e 100644 --- a/packages/react-native-executorch/common/runner/multimodal_runner.cpp +++ b/packages/react-native-executorch/common/runner/multimodal_runner.cpp @@ -19,6 +19,9 @@ MultimodalRunner::MultimodalRunner( int32_t MultimodalRunner::get_visual_token_count() const { auto it = encoders_.find(MultimodalType::Image); + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "MultimodalRunner::get_visual_token_count"); + if (it == encoders_.end()) { return 0; } @@ -54,8 +57,13 @@ Error MultimodalRunner::load_subcomponents() { if (enc_it != encoders_.end()) { image_encoder = enc_it->second.get(); } + IEncoder *audio_encoder = nullptr; + auto aud_it = encoders_.find(MultimodalType::Audio); + if (aud_it != encoders_.end()) { + audio_encoder = aud_it->second.get(); + } mm_prefiller_ = std::make_unique( - *module_, *mm_decoder_runner_, *tokenizer_, image_encoder); + *module_, *mm_decoder_runner_, *tokenizer_, image_encoder, audio_encoder); mm_token_generator_ = std::make_unique( tokenizer_.get(), mm_decoder_runner_.get(), /*use_kv_cache=*/true, std::move(eos_ids_), stats_ptr, config_); @@ -78,14 +86,15 @@ Error MultimodalRunner::generate_internal( } stats_.inference_start_ms = time_in_ms(); - - uint64_t prefill_next_token = 0; - for (const auto &input : inputs) { - auto prefill_result = mm_prefiller_->prefill(input, pos_); - if (!prefill_result.ok()) - return prefill_result.error(); - prefill_next_token = prefill_result.get(); - } + const long t_gen_begin = stats_.inference_start_ms; + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, "inputs count", inputs.size()); + auto prefill_result = mm_prefiller_->prefill(inputs, pos_); + if (!prefill_result.ok()) + return prefill_result.error(); + uint64_t prefill_next_token = prefill_result.get(); + const long t_prefill_done = time_in_ms(); + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, "prefill result", + prefill_next_token); stats_.first_token_ms = time_in_ms(); stats_.prompt_eval_end_ms = time_in_ms(); @@ -96,13 +105,25 @@ Error MultimodalRunner::generate_internal( config_.max_context_length, config_.max_new_tokens); std::vector seed_tokens = {prefill_next_token}; + bool first_cb_fired = false; + long t_first_cb = 0; auto wrapped_callback = [&](const std::string &piece) { + if (!first_cb_fired) { + t_first_cb = time_in_ms(); + first_cb_fired = true; + rnexecutorch::log( + rnexecutorch::LOG_LEVEL::Info, + "TTFT splits ms: gen_entry->prefill_done=", + t_prefill_done - t_gen_begin, + " prefill_done->first_token_cb=", t_first_cb - t_prefill_done, + " total=", t_first_cb - t_gen_begin); + } safe_printf(piece.c_str()); fflush(stdout); if (token_callback) token_callback(piece); }; - + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, "seed_tokens", seed_tokens); auto generate_result = mm_token_generator_->generate( seed_tokens, pos_, static_cast(std::max(0, resolved_max_new - 1)), diff --git a/packages/react-native-executorch/common/runner/multimodal_runner.h b/packages/react-native-executorch/common/runner/multimodal_runner.h index d24e0b40c2..c6180c54f0 100644 --- a/packages/react-native-executorch/common/runner/multimodal_runner.h +++ b/packages/react-native-executorch/common/runner/multimodal_runner.h @@ -10,7 +10,7 @@ namespace executorch::extension::llm { -enum class MultimodalType { Image }; +enum class MultimodalType { Image, Audio }; class MultimodalRunner : public BaseLLMRunner { public: diff --git a/packages/react-native-executorch/common/runner/sampler.cpp b/packages/react-native-executorch/common/runner/sampler.cpp index 26c75d4dd5..e29b5a5c63 100644 --- a/packages/react-native-executorch/common/runner/sampler.cpp +++ b/packages/react-native-executorch/common/runner/sampler.cpp @@ -35,6 +35,7 @@ #include "sampler.h" #include #include +#include #include namespace executorch { @@ -91,33 +92,26 @@ int32_t Sampler::sample_topp(T *probabilities, float coin) { n0++; } } +} - auto compare = [](const ProbIndex &a, const ProbIndex &b) { - return a.prob > b.prob; - }; - std::sort(probindex.get(), probindex.get() + n0, compare); - - // truncate the list where cumulative probability exceeds topp - T cumulative_prob = 0; - int last_idx = n0 - 1; // in case of rounding errors consider all elements - for (int i = 0; i < n0; i++) { - cumulative_prob += probindex[i].prob; - if (cumulative_prob > topp_) { - last_idx = i; - break; // we've exceeded topp by including last_idx - } +// Mask logits outside the top-k by rank to -inf. Ties at the k-th boundary +// are kept (matches HuggingFace TopKLogitsWarper). +template void Sampler::mask_topk(T *logits) { + if (topk_ <= 0 || topk_ >= vocab_size_) { + return; } - - // sample from the truncated list - const T &r = coin * cumulative_prob; - T cdf = 0; - for (int i = 0; i <= last_idx; i++) { - cdf += probindex[i].prob; - if (r < cdf) { - return probindex[i].index; + // Partial-select the (topk_-th largest) threshold using nth_element on a + // copy of logits; O(n) average. + std::vector scratch(logits, logits + vocab_size_); + std::nth_element(scratch.begin(), scratch.begin() + (topk_ - 1), + scratch.end(), std::greater()); + const T threshold = scratch[topk_ - 1]; + const T neg_inf = std::numeric_limits::lowest(); + for (int i = 0; i < vocab_size_; i++) { + if (logits[i] < threshold) { + logits[i] = neg_inf; } } - return probindex[last_idx].index; // in case of rounding errors } Sampler::Sampler(int32_t vocab_size, float temperature, float topp, @@ -126,10 +120,80 @@ Sampler::Sampler(int32_t vocab_size, float temperature, float topp, : vocab_size_(vocab_size), inv_temperature_((temperature != 0.0f) ? (1.0f / temperature) : 0.0f), topp_(topp), min_p_(min_p), repetition_penalty_(repetition_penalty), + topk_(0), rng_state_(rng_seed) {} + +// Mask logits whose softmax-prob falls outside the top-p nucleus to -inf. +// Keeps the token that crosses the threshold (HuggingFace convention). +template void Sampler::mask_topp(T *logits) { + if (topp_ <= 0.0f || topp_ >= 1.0f) { + return; + } + // Softmax into a scratch probs[] (do not mutate logits yet). + T max_val = logits[0]; + for (int i = 1; i < vocab_size_; i++) { + if (logits[i] > max_val) { + max_val = logits[i]; + } + } + std::unique_ptr[]> probindex = + std::make_unique[]>(vocab_size_); + T sum = 0; + for (int i = 0; i < vocab_size_; i++) { + T e = static_cast(expf(static_cast(logits[i] - max_val))); + probindex[i].prob = e; + probindex[i].index = i; + sum += e; + } + if (sum <= T(0)) { + return; + } + for (int i = 0; i < vocab_size_; i++) { + probindex[i].prob = probindex[i].prob / sum; + } + std::sort(probindex.get(), probindex.get() + vocab_size_, + [](const ProbIndex &a, const ProbIndex &b) { + return a.prob > b.prob; + }); + + // Find the smallest prefix whose cumulative probability >= topp_. + T cumulative = 0; + int last_idx = vocab_size_ - 1; + for (int i = 0; i < vocab_size_; i++) { + cumulative += probindex[i].prob; + if (static_cast(cumulative) >= topp_) { + last_idx = i; + break; + } + } + // Mark kept indices, then -inf the rest. + std::vector keep(vocab_size_, false); + for (int i = 0; i <= last_idx; i++) { + keep[probindex[i].index] = true; + } + const T neg_inf = std::numeric_limits::lowest(); + for (int i = 0; i < vocab_size_; i++) { + if (!keep[i]) { + logits[i] = neg_inf; + } + } +} + +Sampler::Sampler(int vocab_size, float temperature, float topp, int32_t topk, + unsigned long long rng_seed) + : vocab_size_(vocab_size), + inv_temperature_((temperature != 0.0f) ? (1.0f / temperature) : 0.0f), + topp_(topp), min_p_(0.0f), repetition_penalty_(1.0f), topk_(topk), rng_state_(rng_seed) {} +Sampler::Sampler(int vocab_size, float temperature, float topp, int32_t topk) + : Sampler(vocab_size, temperature, topp, topk, std::time(nullptr)) {} + +Sampler::Sampler(int vocab_size, float temperature, float topp, + unsigned long long rng_seed) + : Sampler(vocab_size, temperature, topp, 0, rng_seed) {} + Sampler::Sampler(int vocab_size, float temperature, float topp) - : Sampler(vocab_size, temperature, topp, std::time(nullptr), 0.0f, 1.0f) {} + : Sampler(vocab_size, temperature, topp, 0, std::time(nullptr)) {} template static void softmax(T *x, int size) { // find max value (for numerical stability) @@ -175,9 +239,11 @@ int32_t Sampler::sample(T *logits, const std::vector &recent_tokens) { apply_repetition_penalty(logits, vocab_size_, recent_tokens); // 2. apply the temperature to the logits apply_temperature(logits, vocab_size_); - // 3. apply softmax to the logits to get the probabilities for next token + // 3. mask out logits outside top-k by rank (pre-softmax, becomes 0 mass) + mask_topk(logits); + // 4. apply softmax to the logits to get the probabilities for next token softmax(logits, vocab_size_); - // 4. apply min_p truncation + // 5. apply min_p truncation apply_min_p(logits, vocab_size_); // flip a (float) coin (this is our source of entropy for sampling) float coin = random_f32(&rng_state_); diff --git a/packages/react-native-executorch/common/runner/sampler.h b/packages/react-native-executorch/common/runner/sampler.h index 16811297ef..cd57e0524e 100644 --- a/packages/react-native-executorch/common/runner/sampler.h +++ b/packages/react-native-executorch/common/runner/sampler.h @@ -41,7 +41,18 @@ class Sampler { Sampler(int32_t vocab_size, float temperature, float topp, unsigned long long rng_seed, float min_p = 0.0f, float repetition_penalty = 1.0f); + // topk <= 0 disables top-k filtering. topp <= 0 || topp >= 1 disables top-p. + // Pipeline when temperature != 0: temperature -> top-k mask -> top-p mask + // -> softmax -> multinomial. Note: topk == 1 with temperature != 0 collapses + // to greedy; pass topk = 0 to keep full-vocab temperature sampling. + Sampler(int32_t vocab_size, float temperature, float topp, int32_t topk, + unsigned long long rng_seed); + Sampler(int32_t vocab_size, float temperature, float topp, int32_t topk); + + // Back-compat overloads (topk = 0 => disabled). + Sampler(int32_t vocab_size, float temperature, float topp, + unsigned long long rng_seed); Sampler(int32_t vocab_size, float temperature, float topp); template int32_t sample(T *logits); @@ -53,6 +64,9 @@ class Sampler { template int32_t sample_topp(T *probabilities, float coin); template int32_t sample_mult(T *probabilities, float coin); template int32_t sample_argmax(T *probabilities); + // In-place logit warpers: set excluded indices to -inf. + template void mask_topk(T *logits); + template void mask_topp(T *logits); template inline void apply_temperature(T *logits, int32_t vocab_size) { @@ -110,6 +124,7 @@ class Sampler { float topp_; float min_p_; float repetition_penalty_; + int32_t topk_; unsigned long long rng_state_; }; diff --git a/packages/react-native-executorch/common/runner/text_decoder_runner.cpp b/packages/react-native-executorch/common/runner/text_decoder_runner.cpp index e67d3e41fb..47cbb39b2d 100644 --- a/packages/react-native-executorch/common/runner/text_decoder_runner.cpp +++ b/packages/react-native-executorch/common/runner/text_decoder_runner.cpp @@ -31,8 +31,18 @@ TextDecoderRunner::TextDecoderRunner(Module &module, IOManager *io_manager, // outer loop (call site) is responsible for managing state. ::executorch::runtime::Result TextDecoderRunner::step(TensorPtr &tokens, int64_t start_pos) { - // ET_LOG(Info, "Input token %" PRIu64, input_token); - auto method_meta_result = module_->method_meta("forward"); + // Pick method by phase: parallel-prefill chunks (>1 token) go through the + // dynamic "forward"; single-token decode steps go through the static + // "forward_decode" method when present (avoids dynamic-shape re-encode + // overhead โ€” measured 8ร— faster on gemma4 Vulkan). + const char* method_name = "forward"; + // if (tokens->numel() == 1) { + // auto decode_meta = module_->method_meta("forward_decode"); + // if (decode_meta.ok()) { + // method_name = "forward_decode"; + // } + // } + auto method_meta_result = module_->method_meta(method_name); if (!method_meta_result.ok()) { return method_meta_result.error(); } @@ -44,7 +54,7 @@ TextDecoderRunner::step(TensorPtr &tokens, int64_t start_pos) { if (use_kv_cache) { auto start_pos_tensor_result = populate_start_pos_or_cache_position( - module_, start_pos, cache_positions, tokens->numel(), "forward"); + module_, start_pos, cache_positions, tokens->numel(), method_name); if (!start_pos_tensor_result.ok()) { return start_pos_tensor_result.error(); } @@ -54,7 +64,7 @@ TextDecoderRunner::step(TensorPtr &tokens, int64_t start_pos) { auto inputs_res = io_manager_->prepare_decode(tokens, start_pos_tensor); ET_CHECK_OK_OR_RETURN_ERROR(inputs_res.error()); inputs = inputs_res.get(); - auto outputs_res = module_->forward(inputs); + auto outputs_res = module_->execute(method_name, inputs); ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error()); auto update_err = io_manager_->update_decode(outputs_res.get()); diff --git a/packages/react-native-executorch/common/runner/text_decoder_runner.h b/packages/react-native-executorch/common/runner/text_decoder_runner.h index bffc254bd6..b4a3186161 100644 --- a/packages/react-native-executorch/common/runner/text_decoder_runner.h +++ b/packages/react-native-executorch/common/runner/text_decoder_runner.h @@ -10,6 +10,7 @@ #pragma once +#include "constants.h" #include "io_manager.h" #include "sampler.h" @@ -40,11 +41,22 @@ class TextDecoderRunner { step(TensorPtr &input, int64_t start_pos); /** - * Load the Module for text decode purpose. - * @return The error code. + * Load the Module for text decode purpose. Always loads `forward` (used for + * prefill, may be static-shape e.g. [1,256]). Also loads `forward_decode` + * if the PTE exposes it (single-token decode path). */ virtual ::executorch::runtime::Error load() { - return module_->load_method("forward"); + auto err = module_->load_method("forward"); + if (err != ::executorch::runtime::Error::Ok) { + return err; + } + if (module_->method_meta("forward_decode").ok()) { + auto derr = module_->load_method("forward_decode"); + if (derr != ::executorch::runtime::Error::Ok) { + return derr; + } + } + return ::executorch::runtime::Error::Ok; } /** @@ -52,7 +64,65 @@ class TextDecoderRunner { * @return True if the Module is loaded, false otherwise. */ virtual bool is_method_loaded() { - return module_->is_method_loaded("forward"); + if (!module_->is_method_loaded("forward")) { + return false; + } + if (module_->method_meta("forward_decode").ok() && + !module_->is_method_loaded("forward_decode")) { + return false; + } + return true; + } + + /** + * If `forward` declares a static prompt length (input 0 size [1, N]), return + * N. Returns 0 when the prefill method is dynamic-shape or the size cannot + * be determined. Used by TextPrefiller to pad prompt chunks to the exact + * static length the PTE expects. + * + * Detection order: + * 1) If the PTE exposes the `enable_dynamic_shape` constant_method and it + * reads as truthy (bool true OR int!=0), the PTE is dynamic โ€” return 0 + * so the prefiller sends the actual prompt length, no padding. This + * covers iter170+ PTEs whose `forward` input TensorSpec stores the + * dynamic upper bound (e.g. 128) that the meta-based heuristic below + * would mis-detect as static. + * 2) Otherwise (method missing or read failed), fall back to the legacy + * TensorSpec-based heuristic: assume the last dim of input 0 is the + * fixed prompt length when > 1. Preserves behavior for older PTEs. + */ + int64_t prefill_static_len() { + // Step 1: query constant_methods["enable_dynamic_shape"] when present. + // Use module_->get() โ€” this is the canonical pattern in base_llm_runner.cpp + // for reading metadata methods. The value is serialized as a Scalar (bool + // or int); toScalar().to() yields 1 for True, 0 for False. + auto dyn_res = module_->get(kEnableDynamicShape); + if (dyn_res.ok()) { + const auto& evalue = dyn_res.get(); + if (evalue.isScalar()) { + if (evalue.toScalar().to() != 0) { + return 0; // PTE declares itself dynamic โ€” no padding needed. + } + // Explicitly False โ€” fall through to meta heuristic (static PTE). + } + } + // Step 2: legacy fallback โ€” derive static length from input 0's TensorSpec. + auto meta_res = module_->method_meta("forward"); + if (!meta_res.ok()) { + return 0; + } + auto meta = std::move(*meta_res); + auto in0 = meta.input_tensor_meta(0); + if (!in0.ok()) { + return 0; + } + auto sizes = in0->sizes(); + // Expect tokens tensor of rank >=2 with batch dim = 1 and a fixed seq dim. + if (sizes.size() < 2) { + return 0; + } + int64_t seq = sizes[sizes.size() - 1]; + return seq > 1 ? seq : 0; } inline void stop() { should_stop_ = true; } diff --git a/packages/react-native-executorch/common/runner/text_prefiller.cpp b/packages/react-native-executorch/common/runner/text_prefiller.cpp index dc961158b7..0f99821819 100644 --- a/packages/react-native-executorch/common/runner/text_prefiller.cpp +++ b/packages/react-native-executorch/common/runner/text_prefiller.cpp @@ -10,6 +10,7 @@ // LLM. #include "text_prefiller.h" +#include "rnexecutorch/Log.h" #include namespace executorch { @@ -21,7 +22,16 @@ TextPrefiller::TextPrefiller(TextDecoderRunner *text_decoder_runner, int64_t max_seq_len) : text_decoder_runner_(text_decoder_runner), use_kv_cache_(use_kv_cache), enable_parallel_prefill_(enable_parallel_prefill), - max_seq_len_(max_seq_len > 0 ? max_seq_len : 128) {} + max_seq_len_(max_seq_len > 0 ? max_seq_len : 2048) { + // Auto-detect static-shape prefill: when `forward` declares input 0 as + // [1, N] with N>1, we must pad every prefill call to exactly N tokens. + prefill_static_len_ = text_decoder_runner_->prefill_static_len(); + if (prefill_static_len_ > 0) { + rnexecutorch::log( + rnexecutorch::LOG_LEVEL::Info, + "TextPrefiller: static prefill len detected =", prefill_static_len_); + } +} ::executorch::runtime::Result TextPrefiller::prefill(std::vector &prompt_tokens, @@ -34,14 +44,20 @@ TextPrefiller::prefill(std::vector &prompt_tokens, // Check if we need to chunk the prompt tokens int32_t num_prompt_tokens = prompt_tokens.size(); - // If prompt tokens exceed max_seq_len_, we need to chunk them - if (num_prompt_tokens > max_seq_len_) { + // When the PTE's `forward` is static-shape (e.g. [1, 256]), the chunk size + // is fixed at prefill_static_len_; otherwise fall back to max_seq_len_. + const int32_t chunk_size = prefill_static_len_ > 0 + ? static_cast(prefill_static_len_) + : static_cast(max_seq_len_); + + // If prompt tokens exceed chunk_size, we need to chunk them + if (num_prompt_tokens > chunk_size) { uint64_t cur_token = 0; int num_tokens_to_process = 0; while (num_tokens_to_process < num_prompt_tokens) { - auto num_tokens_to_prefill_with = std::min( - num_prompt_tokens - num_tokens_to_process, max_seq_len_); + auto num_tokens_to_prefill_with = + std::min(num_prompt_tokens - num_tokens_to_process, chunk_size); std::vector prompt_tokens_to_process( num_tokens_to_prefill_with); @@ -75,17 +91,34 @@ TextPrefiller::prefill_chunk(std::vector &prompt_tokens, // store the token uint64_t cur_token; if (enable_parallel_prefill_ || !use_kv_cache_) { - // initialize tensor wrappers - auto tokens = from_blob(prompt_tokens.data(), {1, num_prompt_tokens}, + // Static-shape `forward` (e.g. [1, 256]): pad the prompt chunk to exactly + // prefill_static_len_ with 0, but only count `num_prompt_tokens` real + // tokens for sampling/start_pos. Padded slots' KV writes are overwritten + // by the next prefill chunk or decode step before being attended to. + std::vector padded; + uint64_t *tokens_ptr = prompt_tokens.data(); + int32_t tensor_len = num_prompt_tokens; + if (prefill_static_len_ > 0 && num_prompt_tokens < prefill_static_len_) { + padded.assign(prefill_static_len_, 0); + std::copy(prompt_tokens.begin(), prompt_tokens.end(), padded.begin()); + tokens_ptr = padded.data(); + tensor_len = static_cast(prefill_static_len_); + } + + auto tokens = from_blob(tokens_ptr, {1, tensor_len}, executorch::aten::ScalarType::Long); + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, "prefill effective_len", + num_prompt_tokens, "tensor_len", tensor_len, "start_pos", + start_pos); auto outputs_res = text_decoder_runner_->step(tokens, start_pos); ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error()); ET_LOG(Info, "Prefill token result numel(): %zu", outputs_res.get().numel()); - start_pos += num_prompt_tokens; + start_pos += num_prompt_tokens; // advance only by REAL tokens + // Sample from the row corresponding to the last real prompt token. cur_token = text_decoder_runner_->logits_to_token(outputs_res.get()); } else { // sequential prefill int64_t pos = 0; // position in the sequence diff --git a/packages/react-native-executorch/common/runner/text_prefiller.h b/packages/react-native-executorch/common/runner/text_prefiller.h index 7929fe9c7f..f6b7bac408 100644 --- a/packages/react-native-executorch/common/runner/text_prefiller.h +++ b/packages/react-native-executorch/common/runner/text_prefiller.h @@ -70,6 +70,11 @@ class TextPrefiller { bool use_kv_cache_; bool enable_parallel_prefill_; int64_t max_seq_len_; + // If >0, the underlying `forward` method is static-shape: every parallel + // prefill call must pass exactly this many tokens (pad with 0 if shorter). + // Sampled logit row is then (effective_len - 1), and start_pos advances by + // effective_len rather than the padded length. + int64_t prefill_static_len_{0}; }; } // namespace llm diff --git a/packages/react-native-executorch/common/runner/text_runner.cpp b/packages/react-native-executorch/common/runner/text_runner.cpp index 5a75e00b4a..b324df6eb6 100644 --- a/packages/react-native-executorch/common/runner/text_runner.cpp +++ b/packages/react-native-executorch/common/runner/text_runner.cpp @@ -1,6 +1,7 @@ // common/runner/text_runner.cpp #include "text_runner.h" #include "constants.h" +#include "rnexecutorch/Log.h" #include "util.h" #include #include @@ -16,9 +17,14 @@ TextRunner::TextRunner(std::unique_ptr module, : BaseLLMRunner(std::move(module), tokenizer_path, config) {} bool TextRunner::is_loaded() const { +#ifdef RNEX_BYPASS_TOKENIZER + return module_ && module_->is_loaded() && text_decoder_runner_ && + text_prefiller_ && text_token_generator_; +#else return module_ && module_->is_loaded() && tokenizer_ && tokenizer_->is_loaded() && text_decoder_runner_ && text_prefiller_ && text_token_generator_; +#endif } Error TextRunner::load_subcomponents() { @@ -26,8 +32,8 @@ Error TextRunner::load_subcomponents() { Stats *stats_ptr = &stats_; - text_decoder_runner_ = std::make_unique( - *module_, io_manager_.get(), config_); + text_decoder_runner_ = + std::make_unique(*module_, io_manager_.get(), config_); text_prefiller_ = std::make_unique( text_decoder_runner_.get(), config_.enable_kv_cache, config_.enable_dynamic_shape, config_.max_seq_len); @@ -65,9 +71,25 @@ Error TextRunner::generate_internal( stats_.inference_start_ms = time_in_ms(); + // Multi-turn: JS re-renders the full chat history each call, so reset KV + // position to 0 and re-prefill from scratch. + pos_ = 0; + int64_t context_len_left = static_cast(config_.max_context_length) - pos_; +#ifdef RNEX_BYPASS_TOKENIZER + // Llama 3.2 Instruct chat template wrapping "Hello" โ€” decode offline: + // <|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n + // Hello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n + (void)prompt; + std::vector prompt_tokens = { + 128000, 128006, 882, 128007, 271, 128000, 9906, 11, 3371, 757, 264, 3446, 128009, 128006, 78191, 128007, 271}; + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "RNEX_BYPASS_TOKENIZER: hardcoded Llama-3.2 prompt tokens, " + "count = " + + std::to_string(prompt_tokens.size())); +#else auto encodeResult = tokenizer_->encode(prompt, numOfAddedBoSTokens, numOfAddedEoSTokens); if (!encodeResult.ok()) { @@ -77,6 +99,7 @@ Error TextRunner::generate_internal( std::to_string(static_cast(encodeResult.error()))); } std::vector prompt_tokens = encodeResult.get(); +#endif int num_prompt_tokens = prompt_tokens.size(); ET_CHECK_OR_RETURN_ERROR(num_prompt_tokens >= 1, InvalidArgument, @@ -90,6 +113,8 @@ Error TextRunner::generate_internal( num_prompt_tokens, config_.max_seq_len, static_cast(context_len_left), config_.max_new_tokens); + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, "kappa1", max_new_tokens, num_prompt_tokens, config_.max_seq_len, config_.max_new_tokens); + ET_CHECK_OR_RETURN_ERROR(max_new_tokens > 0, InvalidArgument, "Max new tokens %d is <= 0", max_new_tokens); @@ -102,6 +127,11 @@ Error TextRunner::generate_internal( ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); uint64_t cur_token = prefill_res.get(); +#ifdef RNEX_BYPASS_TOKENIZER + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "RNEX_BYPASS_TOKENIZER: prefill first token id = " + + std::to_string(cur_token)); +#else auto decodeResult = tokenizer_->decode({cur_token}); if (!decodeResult.ok()) { throw rnexecutorch::RnExecutorchError( @@ -109,6 +139,7 @@ Error TextRunner::generate_internal( "Unexpected issue occurred while decoding: " + std::to_string(static_cast(decodeResult.error()))); } +#endif prompt_tokens.push_back(cur_token); int64_t num_generated = ET_UNWRAP(text_token_generator_->generate( diff --git a/packages/react-native-executorch/common/runner/text_token_generator.h b/packages/react-native-executorch/common/runner/text_token_generator.h index 7ecf6177a9..32b17ee409 100644 --- a/packages/react-native-executorch/common/runner/text_token_generator.h +++ b/packages/react-native-executorch/common/runner/text_token_generator.h @@ -10,6 +10,7 @@ #pragma once #include "irunner.h" +#include "rnexecutorch/Log.h" #include "stats.h" #include "text_decoder_runner.h" #include "util.h" @@ -90,6 +91,7 @@ class TextTokenGenerator { timestamp_ = std::chrono::high_resolution_clock::now(); // Generate our tokens + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, "kappa generator", pos, start_pos, max_new_tokens); while (pos < start_pos + max_new_tokens) { // Run the model auto logits_res = text_decoder_runner_->step(tokens_managed, pos); @@ -100,13 +102,18 @@ class TextTokenGenerator { prev_token = cur_token; stats_->on_sampling_begin(); - cur_token = - text_decoder_runner_->logits_to_token(logits_tensor, generated_tokens); + cur_token = text_decoder_runner_->logits_to_token(logits_tensor, + generated_tokens); stats_->on_sampling_end(); + // rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, "Generated token id:", + // static_cast(cur_token)); + pos++; generated_tokens.push_back(cur_token); + const bool eos_reached_now = eos_ids_->find(cur_token) != eos_ids_->end(); + if (use_kv_cache_) { // update the token tensor. token_data will not be empty. // NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds) @@ -118,8 +125,36 @@ class TextTokenGenerator { tokens_managed, {1, static_cast(token_data.size())})); } + // Don't include the terminal EOS/EOT token in the streamed text โ€” it + // would otherwise be appended to the assistant message stored in chat + // history and corrupt the next turn's chat-template rendering + // (e.g. duplicated ). + if (eos_reached_now) { + printf("\n"); + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "Reached end of generation"); +#ifndef RNEX_BYPASS_TOKENIZER + if (!token_cache.empty()) { + auto flush = tokenizer_->decode(token_cache, false); + if (flush.ok() && !flush.get().empty() && + !flush.get().ends_with("๏ฟฝ") && token_callback) { + token_callback(flush.get()); + } + token_cache.clear(); + } +#else + token_cache.clear(); +#endif + break; + } + token_cache.push_back(static_cast(cur_token)); +#ifdef RNEX_BYPASS_TOKENIZER + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "gen_token=" + std::to_string(cur_token)); + std::string cache_decoded = std::to_string(cur_token) + " "; +#else // print the token as string, decode it with the Tokenizer object // We pass false, as we want don't want to skip special tokens e.g. // @@ -133,6 +168,7 @@ class TextTokenGenerator { std::to_string(static_cast(decodeResult.error()))); } std::string cache_decoded = decodeResult.get(); +#endif const auto timeIntervalElapsed = std::chrono::duration_cast( std::chrono::high_resolution_clock::now() - timestamp_) @@ -142,8 +178,7 @@ class TextTokenGenerator { const auto eos_reached = eos_ids_->contains(cur_token); if (!cache_decoded.ends_with("๏ฟฝ") && - (countIntervalElapsed || timeIntervalElapsed || should_stop_ || - eos_reached)) { + (countIntervalElapsed || timeIntervalElapsed || should_stop_)) { token_callback(cache_decoded); token_cache.clear(); timestamp_ = std::chrono::high_resolution_clock::now(); @@ -152,13 +187,6 @@ class TextTokenGenerator { if (should_stop_) { break; } - - // data-dependent terminating condition: we have n_eos_ number of EOS - if (eos_ids_->find(cur_token) != eos_ids_->end()) { - printf("\n"); - ET_LOG(Info, "\nReached to the end of generation"); - break; - } } return pos - start_pos; } diff --git a/packages/react-native-executorch/common/runner/util.h b/packages/react-native-executorch/common/runner/util.h index 640b96319f..4a37121069 100644 --- a/packages/react-native-executorch/common/runner/util.h +++ b/packages/react-native-executorch/common/runner/util.h @@ -107,25 +107,27 @@ size_t inline get_rss_bytes() { // (when the method_name [`text_decoder` or `forward`] expects a tensor with // size 1 because model will populate the cache position tensor underneath), or // a populated tensor for cache position, for the given start_pos and seq_len. -inline runtime::Result -populate_start_pos_or_cache_position(Module *module, int64_t &start_pos, - std::vector &cache_positions_vec, - int seq_len, - const char *method_name = "forward") { - // Get expected shape of cache position tensor, which should be the second - // argument +// +// `pos_input_index` tells the helper which input slot holds the cache-position +// tensor. Defaults to 1 for legacy `forward(tokens, cache_positions)` and +// `text_decoder(embeddings, cache_positions)` signatures. The multimodal +// `text_decoder(inputs_embeds, ple_tok, input_pos)` layout passes 2. +inline runtime::Result populate_start_pos_or_cache_position( + Module *module, int64_t &start_pos, + std::vector &cache_positions_vec, int seq_len, + const char *method_name = "forward", size_t pos_input_index = 1) { auto method_meta_result = module->method_meta(method_name); if (!method_meta_result.ok()) { return method_meta_result.error(); } auto method_meta = std::move(*method_meta_result); - auto second_input_info_result = method_meta.input_tensor_meta(1); - if (!second_input_info_result.ok()) { - return second_input_info_result.error(); + auto pos_input_info_result = method_meta.input_tensor_meta(pos_input_index); + if (!pos_input_info_result.ok()) { + return pos_input_info_result.error(); } - auto second_input_info = std::move(*second_input_info_result); - auto second_input_sizes = second_input_info.sizes(); - auto numel = second_input_sizes[0]; + auto pos_input_info = std::move(*pos_input_info_result); + auto pos_input_sizes = pos_input_info.sizes(); + auto numel = pos_input_sizes[0]; TensorPtr start_pos_tensor; if (numel > 1) { diff --git a/packages/react-native-executorch/src/constants/llmDefaults.ts b/packages/react-native-executorch/src/constants/llmDefaults.ts index a27a2f7a4f..77a60fe311 100644 --- a/packages/react-native-executorch/src/constants/llmDefaults.ts +++ b/packages/react-native-executorch/src/constants/llmDefaults.ts @@ -6,7 +6,7 @@ import { SlidingWindowContextStrategy } from '../utils/llms/context_strategy'; * @category Utilities - LLM */ export const DEFAULT_SYSTEM_PROMPT = - "You are a knowledgeable, efficient, and direct AI assistant. Provide concise answers, focusing on the key information needed. Offer suggestions tactfully when appropriate to improve outcomes. Engage in productive collaboration with the user. Don't return too much text."; + "You are a knowledgeable, efficient, and direct AI assistant. Provide concise answers, focusing on the key information needed. Offer suggestions tactfully when appropriate to improve outcomes. Engage in productive collaboration with the user. Don't return too much text. If provided with audio samples treat it with at most importance"; /** * Generates a default structured output prompt based on the provided JSON schema. diff --git a/packages/react-native-executorch/src/constants/modelUrls.ts b/packages/react-native-executorch/src/constants/modelUrls.ts index 159396add8..1a022f56bb 100644 --- a/packages/react-native-executorch/src/constants/modelUrls.ts +++ b/packages/react-native-executorch/src/constants/modelUrls.ts @@ -52,9 +52,11 @@ export const LLAMA3_2_3B_SPINQUANT = { */ export const LLAMA3_2_1B = { modelName: 'llama-3.2-1b', - modelSource: LLAMA3_2_1B_MODEL, - tokenizerSource: LLAMA3_2_TOKENIZER, - tokenizerConfigSource: LLAMA3_2_TOKENIZER_CONFIG, + modelSource: + 'http://localhost:9001/Llama3.2-1B-Instruct_vulkan_8da4w_g64_c2048.pte', + tokenizerSource: 'http://localhost:9001/gemma4_e2b/llama_tokenizer.json', + tokenizerConfigSource: + 'http://localhost:9001/gemma4_e2b/llama_tokenizer_config.json', } as const; /** @@ -117,6 +119,20 @@ export const QWEN3_0_6B_QUANTIZED = { generationConfig: QWEN3_GENERATION_CONFIG, } as const; +/** + * @category Models - VLM + */ +export const GEMMA4_E2B_QUANTIZED = { + modelName: 'gemma4-e2b-quantized', + modelSource: + 'http://localhost:9001/ptes/iter206_xnnpack_audio_rsqrt_baked_mask_relshift_idx_prefill2048.pte', + // 'http://localhost:9001/experiments/outputs/exp107f_mm_s4048_noqkvo_wav.pte', + tokenizerSource: 'http://localhost:9001/gemma4_e2b/tokenizer.json', + tokenizerConfigSource: + 'http://localhost:9001/gemma4_e2b/gemma4_tokenizer_config.json', + capabilities: ['vision', 'audio'], +} as const; + /** * @category Models - LLM */ @@ -1303,6 +1319,7 @@ export const MODEL_REGISTRY = { QWEN3_0_6B, QWEN3_0_6B_QUANTIZED, QWEN3_1_7B, + GEMMA4_E2B_QUANTIZED, QWEN3_1_7B_QUANTIZED, QWEN3_4B, QWEN3_4B_QUANTIZED, diff --git a/packages/react-native-executorch/src/controllers/LLMController.ts b/packages/react-native-executorch/src/controllers/LLMController.ts index 5eb8edaed0..303b0cd105 100644 --- a/packages/react-native-executorch/src/controllers/LLMController.ts +++ b/packages/react-native-executorch/src/controllers/LLMController.ts @@ -16,6 +16,13 @@ import { Logger } from '../common/Logger'; import { RnExecutorchError, parseUnknownError } from '../errors/errorUtils'; import { RnExecutorchErrorCode } from '../errors/ErrorCodes'; +// Audio soft-token expansion constants for Gemma4's audio_encoder. +// Mirrors AUDIO_SAMPLES_PER_BLOCK (kSamplesPerBlock=7680) and the per-block +// soft-token rate in audio_encoder.cpp; used to size the context budget so +// long audio doesn't silently overflow get_max_seq_len during prefill. +const AUDIO_SAMPLES_PER_BLOCK = 7680; +const AUDIO_TOKENS_PER_BLOCK = 12; + export class LLMController { private nativeModule: any; private chatConfig: ChatConfig = DEFAULT_CHAT_CONFIG; @@ -236,6 +243,17 @@ export class LLMController { return token; } + private getAudioToken(): string { + const token = this.tokenizerConfig.audio_token; + if (!token) { + throw new RnExecutorchError( + RnExecutorchErrorCode.InvalidConfig, + "Tokenizer config is missing 'audio_token'. Audio-capable models require tokenizerConfigSource with an 'audio_token' field." + ); + } + return token; + } + private filterSpecialTokens(text: string): string { let filtered = text; if ( @@ -244,6 +262,12 @@ export class LLMController { ) { filtered = filtered.replaceAll(this.tokenizerConfig.eos_token, ''); } + if ( + SPECIAL_TOKENS.EOT_TOKEN in this.tokenizerConfig && + this.tokenizerConfig.eot_token + ) { + filtered = filtered.replaceAll(this.tokenizerConfig.eot_token, ''); + } if ( SPECIAL_TOKENS.PAD_TOKEN in this.tokenizerConfig && this.tokenizerConfig.pad_token @@ -269,25 +293,35 @@ export class LLMController { this.isGeneratingCallback(false); } - public async forward(input: string, imagePaths?: string[]): Promise { + public async forward( + input: string, + imagePaths?: string[], + audioWaveforms?: Float32Array[] + ): Promise { if (!this._isReady) { throw new RnExecutorchError(RnExecutorchErrorCode.ModuleNotLoaded); } if (this._isGenerating) { throw new RnExecutorchError(RnExecutorchErrorCode.ModelGenerating); } + const hasImages = !!imagePaths && imagePaths.length > 0; + const hasAudio = !!audioWaveforms && audioWaveforms.length > 0; try { this.isGeneratingCallback(true); this.nativeModule.reset(); - const response = - imagePaths && imagePaths.length > 0 - ? await this.nativeModule.generateMultimodal( - input, - imagePaths.map(normalizeImagePath), - this.getImageToken(), - this.onToken - ) - : await this.nativeModule.generate(input, this.onToken); + let response: string; + if (hasImages || hasAudio) { + response = await this.nativeModule.generateMultimodal( + input, + this.onToken, + hasImages ? imagePaths!.map(normalizeImagePath) : [], + hasImages ? this.getImageToken() : '', + hasAudio ? audioWaveforms! : [], + hasAudio ? this.getAudioToken() : '' + ); + } else { + response = await this.nativeModule.generate(input, this.onToken); + } return this.filterSpecialTokens(response); } catch (e) { throw parseUnknownError(e); @@ -355,6 +389,9 @@ export class LLMController { const imagePaths = messages .filter((m) => m.mediaPath) .map((m) => m.mediaPath!); + const audioWaveforms = messages + .filter((m) => m.audioWaveform) + .map((m) => m.audioWaveform!); const renderedChat: string = this.applyChatTemplate( messages, @@ -366,19 +403,22 @@ export class LLMController { return await this.forward( renderedChat, - imagePaths.length > 0 ? imagePaths : undefined + imagePaths.length > 0 ? imagePaths : undefined, + audioWaveforms.length > 0 ? audioWaveforms : undefined ); } public async sendMessage( message: string, - media?: { imagePath?: string } + media?: { imagePath?: string; audioBuffer?: Float32Array } ): Promise { const mediaPath = media?.imagePath; + const audioBuffer = media?.audioBuffer; const newMessage: Message = { content: message, role: 'user', ...(mediaPath ? { mediaPath } : {}), + ...(audioBuffer ? { audioWaveform: audioBuffer } : {}), }; const updatedHistory = [...this._messageHistory, newMessage]; this.messageHistoryCallback(updatedHistory); @@ -394,7 +434,22 @@ export class LLMController { ); const textTokens = this.nativeModule.countTextTokens(rendered); const imageCount = messages.filter((m) => m.mediaPath).length; - return textTokens + imageCount * (visualTokenCount - 1); + // Audio soft-token expansion: Gemma4's audio_encoder pads samples to + // multiples of AUDIO_SAMPLES_PER_BLOCK (7680 @ 16 kHz) and emits + // AUDIO_TOKENS_PER_BLOCK (~12) soft tokens per padded block. The + // rendered template only contributes 1 token for the audio placeholder, + // so add (expansion - 1) per audio message to match prefill consumption. + const audioTokenExpansion = messages.reduce((acc, m) => { + if (!m.audioWaveform) return acc; + const kBlocks = Math.max( + 1, + Math.ceil(m.audioWaveform.length / AUDIO_SAMPLES_PER_BLOCK) + ); + return acc + (AUDIO_TOKENS_PER_BLOCK * kBlocks - 1); + }, 0); + return ( + textTokens + imageCount * (visualTokenCount - 1) + audioTokenExpansion + ); }; const maxContextLength = this.nativeModule.getMaxContextLength(); const messageHistoryWithPrompt = @@ -497,12 +552,15 @@ function normalizeImagePath(path: string): string { * @returns Messages with image-bearing turns rewritten to structured content. */ function messagesForChatTemplate(messages: Message[]): any[] { - return messages.map((m) => - m.mediaPath && typeof m.content === 'string' - ? { - ...m, - content: [{ type: 'image' }, { type: 'text', text: m.content }], - } - : m - ); + return messages.map((m) => { + if (typeof m.content !== 'string') return m; + const hasImage = !!m.mediaPath; + const hasAudio = !!m.audioWaveform; + if (!hasImage && !hasAudio) return m; + const parts: any[] = []; + if (hasImage) parts.push({ type: 'image' }); + if (hasAudio) parts.push({ type: 'audio' }); + parts.push({ type: 'text', text: m.content }); + return { ...m, content: parts }; + }); } diff --git a/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts b/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts index 027e237997..a434011c9e 100644 --- a/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts +++ b/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts @@ -106,7 +106,10 @@ export function useLLM({ ); const sendMessage = useCallback( - (message: string, media?: { imagePath?: string }) => { + ( + message: string, + media?: { imagePath?: string; audioBuffer?: Float32Array } + ) => { setResponse(''); return controllerInstance.sendMessage(message, media); }, diff --git a/packages/react-native-executorch/src/types/llm.ts b/packages/react-native-executorch/src/types/llm.ts index 6254775c15..ee0a0282ff 100644 --- a/packages/react-native-executorch/src/types/llm.ts +++ b/packages/react-native-executorch/src/types/llm.ts @@ -5,20 +5,22 @@ import { ResourceSource } from './common'; * Capabilities a multimodal LLM can have. * @category Types */ -export type LLMCapability = 'vision'; +export type LLMCapability = 'vision' | 'audio'; /** * Derives the media argument shape for `sendMessage` from a capabilities tuple. * @category Types */ export type MediaArg = - 'vision' extends C[number] ? { imagePath?: string } : object; + ('vision' extends C[number] ? { imagePath?: string } : object) & + ('audio' extends C[number] ? { audioBuffer?: Float32Array } : object); /** * Union of all built-in LLM model names. * @category Types */ export type LLMModelName = + | 'gemma4-e2b-quantized' | 'llama-3.2-3b' | 'llama-3.2-3b-qlora' | 'llama-3.2-3b-spinquant' @@ -289,6 +291,12 @@ export interface Message { * controller normalizes the path before passing it to native code. */ mediaPath?: string; + /** + * Optional fp32 mono 16 kHz PCM buffer. Only valid on `user` messages for + * models with the `'audio'` capability. The controller forwards it to the + * native `generateMultimodal` path. + */ + audioWaveform?: Float32Array; } /** @@ -386,6 +394,7 @@ export interface ContextStrategy { export const SPECIAL_TOKENS = { BOS_TOKEN: 'bos_token', EOS_TOKEN: 'eos_token', + EOT_TOKEN: 'eot_token', UNK_TOKEN: 'unk_token', SEP_TOKEN: 'sep_token', PAD_TOKEN: 'pad_token',