From c761d6b8107ff455277bf6c31ccf458b3125974a Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Wed, 29 Apr 2026 13:21:44 -0700 Subject: [PATCH 01/14] Add Gemma 4 31B-IT model, export, and quantization framework for ExecuTorch Text-only export of Gemma 4 31B-IT to ExecuTorch with the CUDA backend and INT4/INT8 weight quantization via a new packing-agnostic quant/ framework. The quant/ package separates quantization into four concerns: - recipe.py: declarative QuantRecipe with regex FQN matching - quantize.py: produces CanonicalQuantizedWeight (min_max, HQQ) - serialize.py: save/load to safetensors with versioned headers - pack.py + pack_cuda.py: per-module packer dispatch for CUDA Two production recipes: "default" (INT4 min_max + INT8 embedding) and "sensitive" (INT8 for edge-layer v_proj/down_proj, INT4 HQQ elsewhere). Sliding window attention uses a ring-buffer KV cache (2x window size) for the 50 sliding layers, saving memory for long sequences. The 10 full-attention layers use a standard flat KV cache. Includes C++ runner (main.cpp), eager inference script, and 60+ unit and integration tests across quant/ and pipeline test files. --- .github/workflows/cuda.yml | 3 + Makefile | 12 +- .../models/gemma4/text_decoder/__init__.py | 9 + .../models/gemma4/text_decoder/gemma4_norm.py | 39 +- examples/models/gemma4_31b/CMakeLists.txt | 67 ++ examples/models/gemma4_31b/CMakePresets.json | 52 ++ examples/models/gemma4_31b/README.md | 93 +++ examples/models/gemma4_31b/__init__.py | 5 + examples/models/gemma4_31b/export.py | 297 ++++++++ examples/models/gemma4_31b/inference.py | 192 +++++ examples/models/gemma4_31b/main.cpp | 340 +++++++++ examples/models/gemma4_31b/model.md | 197 +++++ examples/models/gemma4_31b/model.py | 700 ++++++++++++++++++ examples/models/gemma4_31b/quant/README.md | 88 +++ examples/models/gemma4_31b/quant/__init__.py | 24 + examples/models/gemma4_31b/quant/pack.py | 85 +++ examples/models/gemma4_31b/quant/pack_cuda.py | 210 ++++++ examples/models/gemma4_31b/quant/quantize.py | 225 ++++++ examples/models/gemma4_31b/quant/recipe.py | 58 ++ examples/models/gemma4_31b/quant/serialize.py | 215 ++++++ .../models/gemma4_31b/quant/test_pack_cuda.py | 360 +++++++++ .../models/gemma4_31b/quant/test_quantize.py | 194 +++++ .../models/gemma4_31b/quant/test_recipe.py | 163 ++++ .../models/gemma4_31b/quant/test_serialize.py | 207 ++++++ .../models/gemma4_31b/quantize_and_save.py | 135 ++++ examples/models/gemma4_31b/sampler.py | 41 + .../models/gemma4_31b/test_cuda_pipeline.py | 121 +++ examples/models/gemma4_31b/test_pipeline.py | 210 ++++++ 28 files changed, 4340 insertions(+), 2 deletions(-) create mode 100644 examples/models/gemma4_31b/CMakeLists.txt create mode 100644 examples/models/gemma4_31b/CMakePresets.json create mode 100644 examples/models/gemma4_31b/README.md create mode 100644 examples/models/gemma4_31b/__init__.py create mode 100644 examples/models/gemma4_31b/export.py create mode 100644 examples/models/gemma4_31b/inference.py create mode 100644 examples/models/gemma4_31b/main.cpp create mode 100644 examples/models/gemma4_31b/model.md create mode 100644 examples/models/gemma4_31b/model.py create mode 100644 examples/models/gemma4_31b/quant/README.md create mode 100644 examples/models/gemma4_31b/quant/__init__.py create mode 100644 examples/models/gemma4_31b/quant/pack.py create mode 100644 examples/models/gemma4_31b/quant/pack_cuda.py create mode 100644 examples/models/gemma4_31b/quant/quantize.py create mode 100644 examples/models/gemma4_31b/quant/recipe.py create mode 100644 examples/models/gemma4_31b/quant/serialize.py create mode 100644 examples/models/gemma4_31b/quant/test_pack_cuda.py create mode 100644 examples/models/gemma4_31b/quant/test_quantize.py create mode 100644 examples/models/gemma4_31b/quant/test_recipe.py create mode 100644 examples/models/gemma4_31b/quant/test_serialize.py create mode 100644 examples/models/gemma4_31b/quantize_and_save.py create mode 100644 examples/models/gemma4_31b/sampler.py create mode 100644 examples/models/gemma4_31b/test_cuda_pipeline.py create mode 100644 examples/models/gemma4_31b/test_pipeline.py diff --git a/.github/workflows/cuda.yml b/.github/workflows/cuda.yml index c3b7c058ee6..d1b954820ef 100644 --- a/.github/workflows/cuda.yml +++ b/.github/workflows/cuda.yml @@ -148,6 +148,9 @@ jobs: # Run Qwen 3.5 MoE tests (quantize roundtrip + TurboQuant KV cache + sampler) python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py examples/models/qwen3_5_moe/test_sampler.py -v -o "addopts=" + # Run Gemma 4 31B tests (quant unit tests + pipeline integration tests) + python -m pytest examples/models/gemma4_31b/quant/ examples/models/gemma4_31b/test_pipeline.py examples/models/gemma4_31b/test_cuda_pipeline.py -v -o "addopts=" + export-model-cuda-artifact: name: export-model-cuda-artifact # Skip this job if the pull request is from a fork (HuggingFace secrets are not available) diff --git a/Makefile b/Makefile index 3c0eac14bce..ba61dddce44 100644 --- a/Makefile +++ b/Makefile @@ -91,7 +91,7 @@ # # ============================================================================== -.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu qwen3_5_moe-cuda qwen3_5_moe-metal clean help +.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu gemma4_31b-cuda qwen3_5_moe-cuda qwen3_5_moe-metal clean help help: @echo "This Makefile adds targets to build runners for various models on various backends. Run using \`make \`. Available targets:" @@ -126,6 +126,7 @@ help: @echo " llava-cpu - Build Llava runner with CPU backend" @echo " gemma3-cuda - Build Gemma3 runner with CUDA backend" @echo " gemma3-cpu - Build Gemma3 runner with CPU backend" + @echo " gemma4_31b-cuda - Build Gemma 4 31B runner with CUDA backend" @echo " qwen3_5_moe-cuda - Build Qwen3.5 MoE runner with CUDA backend" @echo " qwen3_5_moe-metal - Build Qwen3.5 MoE runner with Metal backend" @echo " clean - Clean build artifacts" @@ -425,6 +426,15 @@ qwen3_5_moe-cuda: @echo "✓ Build complete!" @echo " Binary: cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_runner" +gemma4_31b-cuda: + @echo "==> Building and installing ExecuTorch with CUDA..." + cmake --workflow --preset llm-release-cuda + @echo "==> Building Gemma 4 31B runner with CUDA..." + cd examples/models/gemma4_31b && cmake --workflow --preset gemma4-31b-cuda + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/gemma4_31b/gemma4_31b_runner" + qwen3_5_moe-metal: @echo "==> Building and installing ExecuTorch with Metal..." cmake --workflow --preset llm-release-metal diff --git a/examples/models/gemma4/text_decoder/__init__.py b/examples/models/gemma4/text_decoder/__init__.py index 25d7c5c7a16..51c96f0717f 100644 --- a/examples/models/gemma4/text_decoder/__init__.py +++ b/examples/models/gemma4/text_decoder/__init__.py @@ -6,5 +6,14 @@ # LICENSE file in the root directory of this source tree. from .convert_weights import convert_hf_to_custom # noqa: F401 +from .gemma4_attention import ( # noqa: F401 + apply_rotary_emb, + apply_rotary_emb_single, + Gemma4KVCache, + precompute_freqs_cis, + rotate_half, +) from .gemma4_config import Gemma4Config # noqa: F401 +from .gemma4_decoder_layer import Gemma4MLP # noqa: F401 from .gemma4_model import create_gemma4_model, Gemma4Model # noqa: F401 +from .gemma4_norm import RMSNorm, RMSNormNoWeight # noqa: F401 diff --git a/examples/models/gemma4/text_decoder/gemma4_norm.py b/examples/models/gemma4/text_decoder/gemma4_norm.py index 17e42a43ca1..2c8fec67525 100644 --- a/examples/models/gemma4/text_decoder/gemma4_norm.py +++ b/examples/models/gemma4/text_decoder/gemma4_norm.py @@ -5,9 +5,46 @@ # pyre-unsafe # LICENSE file in the root directory of this source tree. +"""Gemma 4 RMSNorm — self-contained re-implementation. + +Numerically identical to ``transformers.models.gemma4.modeling_gemma4.Gemma4RMSNorm`` +(same float32 upcast and ``pow(mean_squared, -0.5)`` normalization), but +without the transformers import so this module is exportable and dep-light. +""" + from functools import partial -from transformers.models.gemma4.modeling_gemma4 import Gemma4RMSNorm as RMSNorm +import torch +from torch import nn + + +class RMSNorm(nn.Module): + """Gemma4 RMSNorm: ``y = (x / rms(x)) * weight``, computed in float32. + + Unlike Gemma 2/3 (``(1 + weight)``) Gemma 4 multiplies by ``weight`` directly. + Pass ``with_scale=False`` for the v-norm and the (unused-here) router norm, + which omit the learnable weight entirely. + """ + + def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True): + super().__init__() + self.eps = eps + self.with_scale = with_scale + if with_scale: + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x: torch.Tensor) -> torch.Tensor: + # Match transformers' use of pow(mean_squared, -0.5) over rsqrt; + # the comment there cites Torch/JAX compiler differences. + mean_squared = x.pow(2).mean(-1, keepdim=True) + self.eps + return x * torch.pow(mean_squared, -0.5) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + normed = self._norm(x.float()) + if self.with_scale: + normed = normed * self.weight.float() + return normed.type_as(x) + # V-norm in attention uses RMSNorm without learnable weight. RMSNormNoWeight = partial(RMSNorm, with_scale=False) diff --git a/examples/models/gemma4_31b/CMakeLists.txt b/examples/models/gemma4_31b/CMakeLists.txt new file mode 100644 index 00000000000..8d536a47fc5 --- /dev/null +++ b/examples/models/gemma4_31b/CMakeLists.txt @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.24) +project(gemma4_31b) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..) + +include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) + +set(_common_include_directories ${EXECUTORCH_ROOT}/..) + +# gflags +set(gflags_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/gflags) +find_package(gflags REQUIRED) + +# executorch +list(APPEND CMAKE_FIND_ROOT_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../..) +find_package(executorch CONFIG REQUIRED FIND_ROOT_PATH_BOTH) +executorch_target_link_options_shared_lib(executorch) + +set(link_libraries executorch gflags) + +# CPU ops (for the host-side helpers that aren't delegated to CUDA) +list(APPEND link_libraries optimized_native_cpu_ops_lib cpublas eigen_blas) +executorch_target_link_options_shared_lib(optimized_native_cpu_ops_lib) + +# Extensions +list( + APPEND + link_libraries + extension_llm_runner + extension_module + extension_data_loader + extension_tensor + extension_flat_tensor +) + +# CUDA backend (the only supported backend for this example for now) +if(EXECUTORCH_BUILD_CUDA) + find_package(CUDAToolkit REQUIRED) + list(APPEND link_libraries aoti_cuda_backend) + executorch_target_link_options_shared_lib(aoti_cuda_backend) + add_compile_definitions(EXECUTORCH_BUILD_CUDA) +else() + message(FATAL_ERROR "Set EXECUTORCH_BUILD_CUDA=ON") +endif() + +# Tokenizer (HuggingFace tokenizer.json) +list(APPEND link_libraries tokenizers::tokenizers) + +add_executable(gemma4_31b_runner main.cpp) +target_include_directories( + gemma4_31b_runner PUBLIC ${_common_include_directories} +) +target_link_libraries(gemma4_31b_runner PUBLIC ${link_libraries}) + +if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") + target_link_options_gc_sections(gemma4_31b_runner) + target_link_options(gemma4_31b_runner PRIVATE "LINKER:-s") +endif() diff --git a/examples/models/gemma4_31b/CMakePresets.json b/examples/models/gemma4_31b/CMakePresets.json new file mode 100644 index 00000000000..97ba7f4c57a --- /dev/null +++ b/examples/models/gemma4_31b/CMakePresets.json @@ -0,0 +1,52 @@ +{ + "version": 6, + "configurePresets": [ + { + "name": "gemma4-31b-base", + "hidden": true, + "binaryDir": "${sourceDir}/../../../cmake-out/examples/models/gemma4_31b", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_FIND_ROOT_PATH": "${sourceDir}/../../../cmake-out", + "CMAKE_PREFIX_PATH": "${sourceDir}/../../../cmake-out" + } + }, + { + "name": "gemma4-31b-cuda", + "displayName": "Gemma 4 31B runner (CUDA)", + "inherits": ["gemma4-31b-base"], + "cacheVariables": { + "EXECUTORCH_BUILD_CUDA": "ON" + }, + "condition": { + "type": "inList", + "string": "${hostSystemName}", + "list": ["Linux", "Windows"] + } + } + ], + "buildPresets": [ + { + "name": "gemma4-31b-cuda", + "displayName": "Build Gemma 4 31B runner (CUDA)", + "configurePreset": "gemma4-31b-cuda", + "targets": ["gemma4_31b_runner"] + } + ], + "workflowPresets": [ + { + "name": "gemma4-31b-cuda", + "displayName": "Configure and build Gemma 4 31B runner (CUDA)", + "steps": [ + { + "type": "configure", + "name": "gemma4-31b-cuda" + }, + { + "type": "build", + "name": "gemma4-31b-cuda" + } + ] + } + ] +} diff --git a/examples/models/gemma4_31b/README.md b/examples/models/gemma4_31b/README.md new file mode 100644 index 00000000000..85fa67844be --- /dev/null +++ b/examples/models/gemma4_31b/README.md @@ -0,0 +1,93 @@ +# Gemma 4 31B-IT + +Text-only export of Google's Gemma 4 31B-IT to ExecuTorch with INT4/INT8 +weight quantization. Currently supports the CUDA backend. + +For architecture and design notes see [model.md](model.md). + +## When to use which script + +The full bf16 weights for 31B (~62 GB) often don't fit in available RAM. The +recommended flow is to quantize once and reuse the quantized checkpoint for +both export and eager inference: + +| Script | Purpose | Peak memory | +|---|---|---| +| `quantize_and_save.py` | bf16 HF checkpoint → quantized checkpoint (one-time) | ~30 GB CPU | +| `export.py --prequantized ` | quantized checkpoint → `model.pte` + `model.ptd` | ~24 GB CPU + CUDA for packing | +| `inference.py --prequantized ` | quantized checkpoint → eager generation under `torch.compile` | ~24 GB GPU | +| `export.py --model-dir ` | one-shot bf16 → quantize → export (no intermediate file) | ~30 GB CPU + CUDA for packing | + +The quantized checkpoint is a safetensors file with int values + per-group +scales and a JSON header describing each weight's `QuantConfig`. No tensor +subclass or backend-specific packing — packing for the target backend happens +at load time via `quant.pack_model()`. + +## Quantization recipes + +Two built-in recipes (see `quantize_and_save.py`): + +| Recipe | Description | +|---|---| +| `default` | INT4 min_max linears, INT8 per-axis embedding | +| `sensitive` | INT8 for edge-layer v_proj/down_proj, INT4 hqq elsewhere, INT8 per-axis embedding | + +## Quantize once + +```bash +python examples/models/gemma4_31b/quantize_and_save.py \ + --model-dir ~/local/scripts/models/gemma-4-31B-it \ + --output ./gemma4_31b_int4 \ + --quant-recipe default +``` + +Writes `model.safetensors`, `config.json`, and +`tokenizer.json` into `--output`. + +## Export to ExecuTorch + +```bash +python examples/models/gemma4_31b/export.py \ + --prequantized ./gemma4_31b_int4 \ + --output-dir ./gemma4_31b_exports \ + --max-seq-len 4096 \ + --backend cuda +``` + +Writes `model.pte` and `model.ptd` into `--output-dir`. + +## Eager inference + +```bash +python examples/models/gemma4_31b/inference.py \ + --prequantized ./gemma4_31b_int4 \ + --prompt "Write a short joke about saving RAM." \ + --max-new-tokens 128 \ + --temperature 0.8 +``` + +Useful before spending the export+lowering time to confirm the quantized +model produces sensible text. + +## Build the runner + +```bash +make gemma4_31b-cuda +``` + +The binary lands at `cmake-out/examples/models/gemma4_31b/gemma4_31b_runner`. + +## Run the .pte + +```bash +./gemma4_31b_runner \ + --model_path ./gemma4_31b_exports/model.pte \ + --data_path ./gemma4_31b_exports/aoti_cuda_blob.ptd \ + --tokenizer_path ./gemma4_31b_int4/tokenizer.json \ + --prompt "Write a short joke about saving RAM." \ + --max_new_tokens 128 \ + --temperature 0.8 +``` + +For benchmarking, add `--cuda_graph` to capture the decode method in a CUDA +graph (decode is fully static — `T=1`). diff --git a/examples/models/gemma4_31b/__init__.py b/examples/models/gemma4_31b/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/examples/models/gemma4_31b/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/examples/models/gemma4_31b/export.py b/examples/models/gemma4_31b/export.py new file mode 100644 index 00000000000..53e66fcd646 --- /dev/null +++ b/examples/models/gemma4_31b/export.py @@ -0,0 +1,297 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Export Gemma 4 31B-IT to ExecuTorch (.pte + .ptd). + +Two methods are exported and lowered together so they share KV-cache buffers: + - "decode": T=1, static shape, returns the next sampled token. + - "prefill": T>=2, dynamic shape, returns the next sampled token. + +Two input paths: + --prequantized Load a quantized checkpoint (from quantize_and_save.py) + and pack for the target backend. No re-quantization. + --model-dir Load bf16 checkpoint, quantize, pack, and export + in one shot. + +Backends: + --backend cuda (default) CUDA via tinygemm INT4 + CudaPartitioner. +""" + +import argparse +import os + +import torch +import torch.nn as nn + +from executorch.examples.models.gemma4_31b.model import ( + Gemma4_31B, + Gemma4_31BConfig, + materialize_runtime_buffers, +) + + +# --------------------------------------------------------------------------- +# Load paths + + +def load_prequantized_model( + prequantized_dir: str, + max_seq_len: int = 4096, + backend: str = "cuda", +) -> tuple[Gemma4_31B, Gemma4_31BConfig]: + """Load a quantized checkpoint and pack for the target backend.""" + config = Gemma4_31BConfig.from_hf_config( + os.path.join(prequantized_dir, "config.json") + ) + config.max_seq_len = max_seq_len + + print("Building model on meta device...") + with torch.device("meta"): + model = Gemma4_31B(config) + + safetensors_path = os.path.join(prequantized_dir, "model.safetensors") + print(f"Loading quantized checkpoint from {safetensors_path}...") + _pack_for_backend(model, safetensors_path, backend) + model.eval() + + print(f"Model: {config.num_hidden_layers} layers, hidden={config.hidden_size}") + return model, config + + +def load_and_quantize( + model_dir: str, + recipe_name: str, + max_seq_len: int = 4096, + backend: str = "cuda", +) -> tuple[Gemma4_31B, Gemma4_31BConfig]: + """Load bf16 checkpoint, quantize, pack — one shot.""" + from executorch.examples.models.gemma4_31b.quant import pack_model, quantize_model + from executorch.examples.models.gemma4_31b.quantize_and_save import _RECIPES + + recipe = _RECIPES[recipe_name] + + print("Loading checkpoint (lazy, shard-by-shard)...") + model, config = Gemma4_31B.from_hf_checkpoint(model_dir, max_seq_len=max_seq_len) + + if model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr(): + print("Untying embed_tokens / lm_head...") + model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + + print(f"Quantizing with recipe '{recipe_name}'...") + quantized, unquantized = quantize_model(model, recipe) + + print(f"Packing for {backend}...") + with torch.device("meta"): + model = Gemma4_31B(config) + pack_model(model, quantized, unquantized, packers=_get_packers(backend)) + model.eval() + + print(f"Model: {config.num_hidden_layers} layers, hidden={config.hidden_size}") + return model, config + + +# --------------------------------------------------------------------------- +# Backend dispatch helpers + + +def _get_packers(backend: str) -> dict: + if backend == "cuda": + from executorch.examples.models.gemma4_31b.quant import DEFAULT_CUDA_PACKERS + + return DEFAULT_CUDA_PACKERS + raise ValueError(f"Unsupported backend: {backend!r}. Supported: 'cuda'.") + + +def _pack_for_backend(model: nn.Module, path: str, backend: str) -> None: + if backend == "cuda": + from executorch.examples.models.gemma4_31b.quant import load_and_pack_for_cuda + + load_and_pack_for_cuda(path, model) + else: + raise ValueError(f"Unsupported backend: {backend!r}. Supported: 'cuda'.") + + +# --------------------------------------------------------------------------- +# Export + lower + + +def export_and_lower( + model: Gemma4_31B, + config: Gemma4_31BConfig, + output_dir: str, + backend: str = "cuda", +) -> None: + """Export and lower the model to ExecuTorch for the given backend.""" + if backend == "cuda": + _export_cuda(model, config, output_dir) + else: + raise ValueError(f"Unsupported backend: {backend!r}. Supported: 'cuda'.") + + +def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -> None: + import torch._inductor.config as inductor_config + + from executorch.backends.cuda.cuda_backend import CudaBackend + from executorch.backends.cuda.cuda_partitioner import CudaPartitioner + from executorch.exir import ( + EdgeCompileConfig, + ExecutorchBackendConfig, + to_edge_transform_and_lower, + ) + from executorch.exir.passes import MemoryPlanningPass + from torch.export import Dim, export + + inductor_config.coordinate_descent_tuning = False + inductor_config.aot_inductor.compile_wrapper_opt_level = "O0" + + materialize_runtime_buffers(model, dtype=torch.bfloat16) + + print("Exporting decode (T=1)...") + with torch.no_grad(): + decode_ep = export( + model, + ( + torch.tensor([[0]], dtype=torch.long), + torch.tensor([0], dtype=torch.long), + torch.tensor([1.0], dtype=torch.float32), + ), + strict=True, + ) + + max_prefill = config.max_seq_len - 1 + seq_dim = Dim("seq_len", min=2, max=max_prefill) + print(f"Exporting prefill (T in [2, {max_prefill}])...") + with torch.no_grad(): + prefill_ep = export( + model, + ( + torch.zeros((1, max_prefill), dtype=torch.long), + torch.arange(max_prefill, dtype=torch.long), + torch.tensor([1.0], dtype=torch.float32), + ), + dynamic_shapes=({1: seq_dim}, {0: seq_dim}, None), + strict=True, + ) + + print("Lowering to ExecuTorch with CUDA backend...") + et_prog = to_edge_transform_and_lower( + {"decode": decode_ep, "prefill": prefill_ep}, + partitioner={ + "decode": [ + CudaPartitioner( + [CudaBackend.generate_method_name_compile_spec("decode")] + ) + ], + "prefill": [ + CudaPartitioner( + [CudaBackend.generate_method_name_compile_spec("prefill")] + ) + ], + }, + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ), + constant_methods={ + "get_max_seq_len": config.max_seq_len, + "get_vocab_size": config.vocab_size, + "get_n_layers": config.num_hidden_layers, + "use_kv_cache": True, + "use_sdpa_with_kv_cache": False, + "enable_dynamic_shape": True, + }, + ) + et_program = et_prog.to_executorch( + config=ExecutorchBackendConfig( + extract_delegate_segments=True, + do_quant_fusion_and_const_prop=True, + memory_planning_pass=MemoryPlanningPass( + alloc_graph_input=False, + share_mutable_buffers=True, + ), + emit_mutable_buffer_names=True, + ), + ) + + os.makedirs(output_dir, exist_ok=True) + pte_path = os.path.join(output_dir, "model.pte") + print(f"Saving to {pte_path}...") + with open(pte_path, "wb") as f: + et_program.write_to_file(f) + print(f" {os.path.getsize(pte_path) / 1024**2:.1f} MB") + + if et_program._tensor_data: + et_program.write_tensor_data_to_file(output_dir) + print(f" Saved tensor data (.ptd) to {output_dir}/") + print("Done.") + + +# --------------------------------------------------------------------------- +# CLI + + +def main() -> None: + from executorch.examples.models.gemma4_31b.quantize_and_save import _RECIPES + + parser = argparse.ArgumentParser(description="Export Gemma 4 31B-IT to ExecuTorch.") + src = parser.add_mutually_exclusive_group(required=True) + src.add_argument( + "--model-dir", + default=None, + help="HuggingFace model dir. Triggers load + quantize + export.", + ) + src.add_argument( + "--prequantized", + default=None, + help="Path to a quantized checkpoint directory. Skips quantization.", + ) + parser.add_argument( + "--output-dir", + default="./gemma4_31b_exports", + help="Output directory for model.pte / model.ptd.", + ) + parser.add_argument( + "--max-seq-len", + type=int, + default=4096, + help="KV cache size.", + ) + parser.add_argument( + "--quant-recipe", + default="default", + choices=list(_RECIPES), + help="Quantization recipe (only with --model-dir).", + ) + parser.add_argument( + "--backend", + default="cuda", + choices=["cuda"], + help="Target backend for export.", + ) + args = parser.parse_args() + + if args.backend == "cuda" and not torch.cuda.is_available(): + parser.error("CUDA is required for the cuda backend.") + + if args.prequantized: + model, config = load_prequantized_model( + args.prequantized, + max_seq_len=args.max_seq_len, + backend=args.backend, + ) + else: + model, config = load_and_quantize( + args.model_dir, + args.quant_recipe, + max_seq_len=args.max_seq_len, + backend=args.backend, + ) + + export_and_lower(model, config, args.output_dir, backend=args.backend) + + +if __name__ == "__main__": + main() diff --git a/examples/models/gemma4_31b/inference.py b/examples/models/gemma4_31b/inference.py new file mode 100644 index 00000000000..59418f3b746 --- /dev/null +++ b/examples/models/gemma4_31b/inference.py @@ -0,0 +1,192 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Eager inference on a prequantized Gemma 4 31B-IT model (CUDA + torch.compile). + +Loads a quantized checkpoint (from ``quantize_and_save.py``), packs for CUDA, +materializes runtime buffers, optionally compiles with ``torch.compile``, and +generates text autoregressively. The model performs Gumbel-max sampling +on-device, so each forward returns the next token ID as a float tensor of +shape ``[B, 1]``. + +Usage: + python inference.py \\ + --prequantized ./gemma4_31b_int4 \\ + --prompt "Write a short joke about saving RAM." \\ + --max-new-tokens 128 \\ + --temperature 0.8 +""" + +import argparse +import os +import time + +import torch + +from executorch.examples.models.gemma4_31b.export import load_prequantized_model +from executorch.examples.models.gemma4_31b.model import materialize_runtime_buffers + + +def _move_to_cuda(model, config) -> None: + """Move the prequantized model to CUDA and materialize runtime buffers there. + + Parameters are moved individually (not via ``model.cuda()``) to preserve + ``Int4TilePackedTo4dTensor`` subclass identity. Non-meta buffers (e.g. + ``layer_scalar``) are moved to CUDA. Meta-device buffers (KV cache, RoPE, + constants) are materialized directly on CUDA via + ``materialize_runtime_buffers``. + """ + for name, p in model.named_parameters(): + parts = name.rsplit(".", 1) + parent = model.get_submodule(parts[0]) if len(parts) > 1 else model + setattr( + parent, + parts[-1], + torch.nn.Parameter(p.data.to("cuda"), requires_grad=False), + ) + + for fqn, buf in list(model.named_buffers()): + if buf.device.type != "meta": + parts = fqn.rsplit(".", 1) + parent = model.get_submodule(parts[0]) if len(parts) > 1 else model + parent.register_buffer(parts[-1], buf.to("cuda"), persistent=False) + + materialize_runtime_buffers(model, dtype=torch.bfloat16, device="cuda") + + +def generate( + model, + tokenizer, + prompt: str, + max_new_tokens: int = 128, + temperature: float = 0.0, + eos_token_ids=None, + bos_token_id: int = 2, +) -> str: + """Autoregressive generation. Prefill is one-token-at-a-time so a single + compiled graph handles every step; the exported PTE uses a separate + multi-token prefill method, but for eager+compile a uniform decode-shape + forward is simpler and benefits from CUDA-graph friendly shapes. + + ``tokenizers.Tokenizer.from_file`` does not auto-prepend BOS — and Gemma 4 + is unusable without it (the model's logits collapse to a single + high-frequency vocab token if the very first input isn't BOS). We prepend + explicitly here; pass ``bos_token_id=None`` to disable. + """ + if eos_token_ids is None: + eos_token_ids = set() + + input_ids = tokenizer.encode(prompt).ids + if bos_token_id is not None and (not input_ids or input_ids[0] != bos_token_id): + input_ids = [bos_token_id] + input_ids + + temp_val = max(temperature, 1e-6) # avoid div-by-zero in the on-device sampler + temp_tensor = torch.tensor([temp_val], dtype=torch.float32, device="cuda") + + sampled = None + with torch.no_grad(): + # Prefill, one token at a time. + for i, tok_id in enumerate(input_ids): + tok = torch.tensor([[tok_id]], dtype=torch.long, device="cuda") + pos = torch.tensor([i], dtype=torch.long, device="cuda") + sampled = model(tok, pos, temp_tensor) + + # First generated token from the last prefill step. + next_id = int(sampled.item()) + generated = [next_id] + + # Decode loop. + seq_len = len(input_ids) + for i in range(max_new_tokens - 1): + tok = torch.tensor([[next_id]], dtype=torch.long, device="cuda") + pos = torch.tensor([seq_len + i], dtype=torch.long, device="cuda") + sampled = model(tok, pos, temp_tensor) + next_id = int(sampled.item()) + generated.append(next_id) + if next_id in eos_token_ids: + break + + return tokenizer.decode(generated) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Eager inference on prequantized Gemma 4 31B-IT (CUDA)." + ) + parser.add_argument( + "--prequantized", + required=True, + help="Path to a quantized checkpoint directory.", + ) + parser.add_argument("--prompt", default="Hello", help="Input prompt.") + parser.add_argument( + "--max-new-tokens", + type=int, + default=128, + help="Maximum tokens to generate.", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.8, + help="Sampling temperature (0 = near-greedy).", + ) + parser.add_argument( + "--max-seq-len", + type=int, + default=4096, + help="KV cache length to allocate for this run.", + ) + parser.add_argument( + "--no-compile", + action="store_true", + help="Skip torch.compile (slower, but easier to debug).", + ) + args = parser.parse_args() + + if not torch.cuda.is_available(): + parser.error("CUDA is required for inference.") + + print(f"Loading prequantized model from {args.prequantized}...") + model, config = load_prequantized_model( + args.prequantized, max_seq_len=args.max_seq_len + ) + _move_to_cuda(model, config) + model.eval() + + if not args.no_compile: + print("Compiling model with torch.compile...") + model = torch.compile(model, mode="default") + + tokenizer_path = os.path.join(args.prequantized, "tokenizer.json") + from tokenizers import Tokenizer + + tokenizer = Tokenizer.from_file(tokenizer_path) + + # Gemma 4 EOS tokens (from generation_config.json: ids 1, 50, 106). + eos_token_ids = {1, 50, 106} + + print(f"\nPrompt: {args.prompt}") + print("-" * 40) + + t0 = time.perf_counter() + output = generate( + model, + tokenizer, + args.prompt, + max_new_tokens=args.max_new_tokens, + temperature=args.temperature, + eos_token_ids=eos_token_ids, + ) + elapsed = time.perf_counter() - t0 + + print(output) + print("-" * 40) + print(f"Generated in {elapsed:.2f}s") + + +if __name__ == "__main__": + main() diff --git a/examples/models/gemma4_31b/main.cpp b/examples/models/gemma4_31b/main.cpp new file mode 100644 index 00000000000..23526119e7b --- /dev/null +++ b/examples/models/gemma4_31b/main.cpp @@ -0,0 +1,340 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Gemma 4 31B-IT runner for the CUDA ExecuTorch backend. +// +// Drives the prefill + decode methods produced by export.py. +// The exported model performs Gumbel-max sampling on-device and returns a +// single float token ID per call, so this runner only has to feed tokens +// in and decode them via the HuggingFace tokenizer. + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#ifdef EXECUTORCH_BUILD_CUDA +#include +#endif + +DEFINE_string(model_path, "", "Path to model.pte."); +DEFINE_string(data_path, "", "Path to model.ptd (CUDA tensor data)."); +DEFINE_string(tokenizer_path, "", "HuggingFace tokenizer.json path."); +DEFINE_string(prompt, "Hello", "Prompt text."); +DEFINE_string( + prompt_file, + "", + "Optional path to a file with the prompt text (overrides --prompt)."); +DEFINE_double(temperature, 0.8, "Sampling temperature (0 = near-greedy)."); +DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate."); +DEFINE_bool( + cuda_graph, + false, + "Enable CUDA graph capture for the decode method."); + +namespace llm = ::executorch::extension::llm; +using ::executorch::extension::from_blob; +using ::executorch::extension::Module; +using ::executorch::runtime::Error; +using ::executorch::runtime::EValue; + +using SizesType = executorch::aten::SizesType; + +// The model performs sampling on-device and returns a [B, 1] float tensor +// holding a token ID. Copy it to host and convert to uint64. +static uint64_t read_token(const executorch::aten::Tensor& output) { + const void* ptr = output.const_data_ptr(); + float val = 0.0f; + +#ifdef EXECUTORCH_BUILD_CUDA + cudaPointerAttributes attrs{}; + bool on_device = cudaPointerGetAttributes(&attrs, ptr) == cudaSuccess && + attrs.type == cudaMemoryTypeDevice; + if (on_device) { + cudaError_t err = + cudaMemcpy(&val, ptr, sizeof(float), cudaMemcpyDeviceToHost); + if (err != cudaSuccess) { + ET_LOG( + Error, + "read_token: cudaMemcpy D2H failed: %s", + cudaGetErrorString(err)); + return 0; + } + } else { + memcpy(&val, ptr, sizeof(float)); + } +#else + memcpy(&val, ptr, sizeof(float)); +#endif + + return static_cast(val); +} + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + + if (FLAGS_model_path.empty()) { + ET_LOG(Error, "Must specify --model_path"); + return 1; + } + if (FLAGS_tokenizer_path.empty()) { + ET_LOG(Error, "Must specify --tokenizer_path"); + return 1; + } + + llm::Stats stats; + +#ifdef EXECUTORCH_BUILD_CUDA + size_t gpu_free_bytes = 0, gpu_total_bytes = 0; + cudaMemGetInfo(&gpu_free_bytes, &gpu_total_bytes); + stats.gpu_total_bytes = gpu_total_bytes; + stats.gpu_free_before_load_bytes = gpu_free_bytes; +#endif + + stats.model_load_start_ms = llm::time_in_ms(); + + // Tokenizer + auto tokenizer = std::make_unique(); + if (tokenizer->load(FLAGS_tokenizer_path) != tokenizers::Error::Ok) { + ET_LOG( + Error, + "Failed to load tokenizer from %s", + FLAGS_tokenizer_path.c_str()); + return 1; + } + + // Module: share_memory_arenas=true so prefill and decode see the same + // KV-cache memory (we exported with share_mutable_buffers=True). + std::vector data_files; + if (!FLAGS_data_path.empty()) { + data_files.push_back(FLAGS_data_path); + } + auto module = std::make_unique( + FLAGS_model_path, + data_files, + Module::LoadMode::File, + /*event_tracer=*/nullptr, + /*memory_allocator=*/nullptr, + /*temp_allocator=*/nullptr, + /*share_memory_arenas=*/true); + + auto metadata_result = llm::get_llm_metadata(tokenizer.get(), module.get()); + if (metadata_result.error() != Error::Ok) { + ET_LOG(Error, "Failed to read model metadata"); + return 1; + } + + if (FLAGS_cuda_graph) { + executorch::runtime::BackendOptions<2> cuda_opts; + cuda_opts.set_option("enable_cuda_graph_for_method", "decode"); + executorch::runtime::set_option("CudaBackend", cuda_opts.view()); + printf("CUDA graph enabled for decode method\n"); + } + + // Cross-method per-FQN weight sharing: prefill + decode share the same + // weight tensors and (more importantly) the same KV-cache buffers, so + // without this flag we would allocate them twice. MUST be set before + // load_method. + { + executorch::runtime::BackendOptions<1> backend_options; + if (backend_options.set_option("weight_sharing_across_methods", true) != + Error::Ok || + executorch::runtime::set_option( + "CudaBackend", backend_options.view()) != Error::Ok) { + ET_LOG(Error, "Failed to enable weight_sharing_across_methods"); + return 1; + } + } + + printf("Loading methods...\n"); + if (module->load_method("prefill") != Error::Ok) { + ET_LOG(Error, "Failed to load prefill method"); + return 1; + } + if (module->load_method("decode") != Error::Ok) { + ET_LOG(Error, "Failed to load decode method"); + return 1; + } + stats.model_load_end_ms = llm::time_in_ms(); + +#ifdef EXECUTORCH_BUILD_CUDA + cudaMemGetInfo(&gpu_free_bytes, &gpu_total_bytes); + stats.gpu_free_after_load_bytes = gpu_free_bytes; +#endif + + auto eos_ids = llm::get_eos_ids(tokenizer.get(), module.get()); + + std::string prompt_text = FLAGS_prompt; + if (!FLAGS_prompt_file.empty()) { + std::ifstream f(FLAGS_prompt_file); + if (!f.is_open()) { + ET_LOG( + Error, "Failed to open prompt file: %s", FLAGS_prompt_file.c_str()); + return 1; + } + prompt_text.assign( + (std::istreambuf_iterator(f)), std::istreambuf_iterator()); + } + + auto encode_result = tokenizer->encode(prompt_text); + if (!encode_result.ok()) { + ET_LOG(Error, "Failed to encode prompt"); + return 1; + } + auto prompt_tokens = std::move(*encode_result); + int64_t num_prompt_tokens = static_cast(prompt_tokens.size()); + printf("Prompt tokens: %" PRId64 "\n", num_prompt_tokens); + stats.num_prompt_tokens = num_prompt_tokens; + + stats.inference_start_ms = llm::time_in_ms(); + + auto S = [](int64_t v) -> SizesType { return static_cast(v); }; + + // Temperature: clamp 0 to a tiny epsilon so the divide in the exported + // sampler stays well-defined. Gumbel noise then becomes negligible + // relative to logit gaps and we get effectively-greedy sampling. + float temp_val = + FLAGS_temperature <= 0.0 ? 1e-6f : static_cast(FLAGS_temperature); + auto temp_tensor = + from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float); + + // --------------------------------------------------------------- + // Prefill + // --------------------------------------------------------------- + std::string run_method = "prefill"; + if (num_prompt_tokens == 1) { + // prefill was exported with min seq_len=2; decode handles T==1. + run_method = "decode"; + } + + std::vector token_data(prompt_tokens.begin(), prompt_tokens.end()); + std::vector pos_data(num_prompt_tokens); + for (int64_t i = 0; i < num_prompt_tokens; i++) { + pos_data[i] = i; + } + auto tokens_tensor = from_blob( + token_data.data(), + {1, S(num_prompt_tokens)}, + executorch::aten::ScalarType::Long); + auto pos_tensor = from_blob( + pos_data.data(), + {S(num_prompt_tokens)}, + executorch::aten::ScalarType::Long); + + std::vector prefill_inputs = { + EValue(tokens_tensor), + EValue(pos_tensor), + EValue(temp_tensor), + }; + + auto prefill_result = module->execute(run_method, prefill_inputs); + if (prefill_result.error() != Error::Ok) { + ET_LOG(Error, "%s failed", run_method.c_str()); + return 1; + } + uint64_t cur_token = read_token(prefill_result.get()[0].toTensor()); + + stats.prompt_eval_end_ms = llm::time_in_ms(); + double prefill_ms = + static_cast(stats.prompt_eval_end_ms - stats.inference_start_ms); + printf( + "Prefill: %" PRId64 " tokens in %.1f ms (%.1f tok/s)\n", + num_prompt_tokens, + prefill_ms, + num_prompt_tokens * 1000.0 / prefill_ms); + +#ifdef EXECUTORCH_BUILD_CUDA + // Make prefill's writes to the shared KV cache visible before decode + // potentially runs on a different stream. + cudaDeviceSynchronize(); +#endif + + // --------------------------------------------------------------- + // Decode loop + // --------------------------------------------------------------- + int64_t pos = num_prompt_tokens; + std::vector decode_token_data = {static_cast(cur_token)}; + std::vector decode_pos_data = {pos}; + auto decode_tokens = from_blob( + decode_token_data.data(), {1, 1}, executorch::aten::ScalarType::Long); + auto decode_pos = from_blob( + decode_pos_data.data(), {1}, executorch::aten::ScalarType::Long); + + uint64_t prev_token = cur_token; + for (int32_t step = 0; step < FLAGS_max_new_tokens; step++) { + decode_token_data[0] = static_cast(cur_token); + decode_pos_data[0] = pos; + + std::vector decode_inputs = { + EValue(decode_tokens), + EValue(decode_pos), + EValue(temp_tensor), + }; + + auto decode_result = module->execute("decode", decode_inputs); + if (decode_result.error() != Error::Ok) { + ET_LOG(Error, "Decode step %d failed", step); + return 1; + } + + prev_token = cur_token; + cur_token = read_token(decode_result.get()[0].toTensor()); + + if (step == 0) { + stats.first_token_ms = llm::time_in_ms(); + } + pos++; + + auto decode_str = tokenizer->decode(prev_token, cur_token); + if (decode_str.ok()) { + printf("%s", decode_str->c_str()); + fflush(stdout); + } + + if (eos_ids.find(cur_token) != eos_ids.end()) { + printf("\n"); + break; + } + } + + stats.inference_end_ms = llm::time_in_ms(); + printf("\n"); + + int64_t num_generated = pos - num_prompt_tokens; + stats.num_generated_tokens = num_generated; + double decode_ms = + static_cast(stats.inference_end_ms - stats.prompt_eval_end_ms); + printf( + "Decode: %" PRId64 " tokens in %.1f ms (%.1f tok/s)\n", + num_generated, + decode_ms, + num_generated * 1000.0 / decode_ms); + printf("Prompt tokens: %" PRId64 "\n", num_prompt_tokens); + +#ifdef EXECUTORCH_BUILD_CUDA + cudaMemGetInfo(&gpu_free_bytes, &gpu_total_bytes); + stats.gpu_free_after_generate_bytes = gpu_free_bytes; + stats.gpu_peak_usage_mb = + (stats.gpu_total_bytes - gpu_free_bytes) / 1024.0 / 1024.0; +#endif + + llm::print_report(stats); + return 0; +} diff --git a/examples/models/gemma4_31b/model.md b/examples/models/gemma4_31b/model.md new file mode 100644 index 00000000000..c6a20d5c306 --- /dev/null +++ b/examples/models/gemma4_31b/model.md @@ -0,0 +1,197 @@ +# Gemma 4 31B-IT — Architecture & Design Notes + +Developer reference for `model.py` and the `quant/` package. For +export/build/run instructions see [README.md](README.md). + +The model mirrors the `Gemma4ForConditionalGeneration` text stack from +HuggingFace transformers / vLLM, with the ExecuTorch customizations needed +for `torch.export(strict=True)`. + +## Architecture + +``` +Input tokens (B, T) + | + v +Embedding (vocab=262144, dim=5376) -> *= sqrt(hidden_size) (normalizer) + | + v ++--- Decoder Layer x60 -----------------------------------------+ +| | +| residual = x | +| RMSNorm -> Attention (sliding | full) -> RMSNorm -> +residual | +| residual = x | +| RMSNorm -> MLP (gate_proj, up_proj, down_proj, GELU-tanh) | +| -> RMSNorm -> +residual | +| x *= layer_scalar (per-layer buffer) | +| | ++----------------------------------------------------------------+ + | + v +RMSNorm -> LM Head (tied with embed) -> tanh(logits/30) * 30 + | + v +Gumbel-max sample(temperature) -> next token (B, 1) +``` + +Layer pattern (`5 sliding + 1 full`, repeated 10x — the last layer is full): + +``` +S S S S S F S S S S S F ... S S S S S F (S = sliding, F = full) +``` + +## Attention details + +Two attention flavors, selected by `config.layer_types[layer_idx]`: + +| Property | Sliding (50 layers) | Full (10 layers, idx 5,11,...,59) | +|---------------------|--------------------|-----------------------------------| +| `head_dim` | 256 | 512 | +| `num_kv_heads` | 16 | 4 | +| `num_heads` | 32 | 32 | +| RoPE θ | 10 000 | 1 000 000 | +| RoPE flavor | full neox | proportional, partial=0.25 | +| K = V | no | yes (no `v_proj`) | +| Causal mask | causal | causal | +| Window restriction | 1024 tokens | none | +| Q-norm / K-norm | RMSNorm w/ weight | RMSNorm w/ weight | +| V-norm | RMSNorm no weight | RMSNorm no weight | +| `scaling` | 1.0 | 1.0 | + +Notes: + +- **Proportional partial RoPE**: the inv_freq vector for full-attention layers + has the first `head_dim * partial_rotary_factor / 2 = 64` frequencies real + (computed with denominator `head_dim`, not `rotary_dim` — that's the + proportional part) and the remaining `head_dim/2 - 64 = 192` zero so cos=1 + and sin=0 (identity rotation) for the non-rotated dims. +- **K = V**: on full-attention layers `v_proj` is absent in the checkpoint + and `V` is taken from the pre-norm `K` projection. After `k_norm` / + RoPE on K and `v_norm` (weightless) on V the two diverge, so the cache + still stores them separately. +- **Mask construction**: a single boolean `(1, 1, T_q, T_kv)` mask is built + once per forward at the model level — one for sliding (causal AND + pos_q - pos_k < 1024), one for full (just causal). Layers pick whichever + matches their type and pass it to `F.scaled_dot_product_attention(..., + enable_gqa=True)`. +- **Gemma `scaling=1.0`**: unlike Gemma 2/3, Gemma 4 does not scale Q by + `query_pre_attn_scalar`; QK-norm handles attention magnitude. + +## Model parameters (text stack) + +| Parameter | Value | +|---------------------------------|------------| +| `vocab_size` | 262 144 | +| `hidden_size` | 5 376 | +| `intermediate_size` | 21 504 | +| `num_hidden_layers` | 60 | +| `num_attention_heads` | 32 | +| `num_key_value_heads` (sliding) | 16 | +| `head_dim` (sliding) | 256 | +| `num_global_key_value_heads` | 4 | +| `global_head_dim` | 512 | +| `sliding_window` | 1024 | +| `rms_norm_eps` | 1e-6 | +| `final_logit_softcapping` | 30.0 | +| `tie_word_embeddings` | true | +| `max_position_embeddings` | 262 144 | + +Decoder norms per layer: `input_layernorm`, `post_attention_layernorm`, +`pre_feedforward_layernorm`, `post_feedforward_layernorm` — all +`RMSNorm` (multiplies by `weight` directly, not `(1 + weight)`). + +## Methods exported (`export.py`) + +| Method | Input | Output (sampled) | +|-----------|------------------------------------------------------------|------------------| +| `decode` | tokens `(1, 1)` + input_pos `(1,)` + temperature `(1,)` | `(1, 1)` float | +| `prefill` | tokens `(1, T)` + input_pos `(T,)` + temperature `(1,)`, T∈[2, max_seq_len-1] | `(1, 1)` float | + +Both methods share the same KV-cache buffers via +`MemoryPlanningPass(share_mutable_buffers=True)` and +`emit_mutable_buffer_names=True`. The exported program performs Gumbel-max +sampling on-device and returns a single token ID per call so the C++ runner +only has to feed tokens. + +## Quantization + +Three modules in `quant/`: + +- **Recipe** (`recipe.py`): `QuantConfig` (bits, group_size, symmetric, + method) + `QuantRule` (regex pattern, config, optional layer filter) + + `QuantRecipe` (ordered rules, first match wins). Declares what to + quantize and how — says nothing about packing or backends. +- **Serialize** (`serialize.py`): `CanonicalQuantizedWeight` (int8 qdata + + bf16 scale + optional zero). `save()` / `load()` persist to safetensors + with a JSON header per weight. Packing-agnostic — any backend can read + the file. +- **Packer** (`pack_cuda.py`): converts `CanonicalQuantizedWeight` to + backend runtime format at load time via `pack_model()`. Dispatches per + parent module type (`nn.Linear` → `Int4TilePackedTo4dTensor` for + tinygemm). Extensible via a packers dict. + +The quantize-once flow: + +``` +quantize_and_save.py export.py / inference.py + | | + bf16 weights quantized checkpoint (safetensors) + | | + quantize_weight() load() + | | + CanonicalQuantizedWeight CanonicalQuantizedWeight + | | + save() pack_model() + | | + model.safetensors Int4TilePackedTo4dTensor (runtime) +``` + +`embed_tokens` and `lm_head` start tied; they are untied before +quantization so `lm_head` (a 5376→262 144 matmul, very expensive at decode) +gets quantized. The embedding gets INT8 per-axis quantization (nearly +lossless for index lookup). + +## Runtime buffer materialization + +After weight loading (via `pack_model()` or `from_hf_checkpoint()`), the +model's KV caches, RoPE tables, and scalar constants are still on the meta +device. `materialize_runtime_buffers(model, dtype, device)` in `model.py` +replaces them with real tensors: + +- KV caches → zeros in `dtype` (bf16 for inference, bf16 for export) +- RoPE tables → computed per-layer (sliding vs full, different θ and head_dim) +- `embed_normalizer`, `logit_softcap`, `cache_positions` → scalar constants + +Called by `export.py` (device="cpu" for tracing) and `inference.py` +(device="cuda" for eager execution). Having one function avoids duplicating +the RoPE computation and constant setup across scripts. + +## Customizations vs. vLLM / transformers reference + +These exist solely to make the model exportable / efficient under ExecuTorch: + +- **Boolean attention mask** built once per forward and shared across layers + of the same type, instead of HF's per-layer `_create_causal_mask`. +- **Ring-buffer KV cache** for sliding layers (`RingKVCache`, sized to + `2 × sliding_window`) saves memory for long sequences — positions wrap + via modulo and the attention mask reconstructs which slots are valid. + Full-attention layers use a flat `Gemma4KVCache` sized to `max_seq_len`. + Both use `index_copy_(dim=2, ...)` for trace-friendly updates. +- **Per-layer RoPE tables** registered as `persistent=False` buffers (sliding + uses full RoPE, full uses proportional partial RoPE — head_dim and θ + differ, so the table is not shared). +- **On-device Gumbel-max sampling** so the exported program emits a token + rather than a full logits tensor — keeps the runner GPU↔CPU traffic to a + single float per step. +- **Final-logit softcap baked into the graph**, applied before sampling. +- **Meta-device construction + assign-load** keeps peak memory small enough + to load the 31B-parameter checkpoint on one machine. + +## Shared primitives + +The numerically-sensitive math primitives are imported from +`examples.models.gemma4.text_decoder` and shared with the Gemma 4 E2B/E4B +example: `RMSNorm`, `RMSNormNoWeight`, `Gemma4MLP`, `Gemma4KVCache`, +`precompute_freqs_cis`, `apply_rotary_emb`. The 31B-specific pieces +(attention with K=V branch, decoder layer, top-level model with softcap + +sampling, checkpoint loader) live in `model.py`. diff --git a/examples/models/gemma4_31b/model.py b/examples/models/gemma4_31b/model.py new file mode 100644 index 00000000000..f9d4d4c9060 --- /dev/null +++ b/examples/models/gemma4_31b/model.py @@ -0,0 +1,700 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Gemma 4 31B-IT — export-friendly reference implementation for ExecuTorch. + +Model definition designed for torch.export(strict=True) with the CUDA backend. +All stateful buffers (KV cache, RoPE inv_freq) are registered buffers so they +are captured by share_mutable_buffers across prefill/decode. The numerically +sensitive primitives — RMSNorm, GELU-tanh MLP, proportional/full RoPE, and +the BHSD KV cache — are imported from ``examples.models.gemma4.text_decoder`` +so the 31B and E2B/E4B paths share them. + +Reference: + - HF transformers: src/transformers/models/gemma4/modeling_gemma4.py + - vLLM: vllm/model_executor/models/gemma4.py + +Architecture highlights for the 31B dense variant: + - 60 decoder layers with hybrid attention: every 6th layer is "full" attention + (idx 5, 11, ..., 59 — 10 layers); the remaining 50 use sliding-window + attention with window=1024. + - Sliding layers: head_dim=256, num_kv_heads=16, full RoPE, theta=10000. + - Full layers: head_dim=512, num_kv_heads=4, K=V (no v_proj), and + "proportional" partial RoPE (factor=0.25, theta=1_000_000). + - Q-norm and K-norm with learnable scale; V-norm without scale. + - Per-layer scalar (loaded buffer) multiplied at the end of each layer. + - Final logits are soft-capped: tanh(logits / 30) * 30. + - Embedding is scaled by sqrt(hidden_size) before layer 0. + - Embedding and lm_head are tied (a single weight, untied for quantization + in the export step so lm_head can be 4-bit). +""" + +import json +import os +import re +from dataclasses import dataclass, field +from typing import Optional + +import torch +import torch.nn as nn + +# Shared primitives lifted out of the gemma4 (E2B/E4B) example. These are the +# bits whose semantics are identical for both variants — RMSNorm, the GELU-tanh +# MLP, the proportional/full RoPE table builder, and the BHSD KV cache. +from executorch.examples.models.gemma4.text_decoder import ( + apply_rotary_emb, + Gemma4KVCache, + Gemma4MLP, + precompute_freqs_cis, + RMSNorm, + RMSNormNoWeight, +) +from executorch.examples.models.gemma4_31b.sampler import sample +from torch.nn import functional as F + + +# --------------------------------------------------------------------------- +# Ring-buffer KV cache for sliding window attention + + +class RingKVCache(nn.Module): + """Ring-buffer KV cache for sliding window attention. + + Sized to ``window_size * 2`` (not ``max_seq_len``), saving memory for + long sequences. Positions wrap via modulo; old entries outside the + window are masked out by ``_build_masks``. + """ + + def __init__( + self, + max_batch_size: int, + window_size: int, + num_kv_heads: int, + head_dim: int, + ): + super().__init__() + self.window_size = window_size + self.buf_size = window_size * 2 + cache_shape = (max_batch_size, num_kv_heads, self.buf_size, head_dim) + self.register_buffer("k_cache", torch.zeros(cache_shape), persistent=False) + self.register_buffer("v_cache", torch.zeros(cache_shape), persistent=False) + + def update( + self, + input_pos: torch.Tensor, + k_val: torch.Tensor, + v_val: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + wrapped = input_pos % self.buf_size + self.k_cache.index_copy_(2, wrapped, k_val) + self.v_cache.index_copy_(2, wrapped, v_val) + return self.k_cache, self.v_cache + + +# --------------------------------------------------------------------------- +# Config + + +@dataclass +class Gemma4_31BConfig: + # Embedding / shape + vocab_size: int = 262144 + hidden_size: int = 5376 + intermediate_size: int = 21504 + num_hidden_layers: int = 60 + + # Attention shape (sliding layers — also the "default" path) + num_attention_heads: int = 32 + num_key_value_heads: int = 16 + head_dim: int = 256 + + # Attention shape (full-attention layers) + num_global_key_value_heads: int = 4 + global_head_dim: int = 512 + attention_k_eq_v: bool = ( + True # full layers: V is derived from the same projection as K + ) + + # RoPE — split per layer type + sliding_rope_theta: float = 10_000.0 + full_rope_theta: float = 1_000_000.0 + full_partial_rotary_factor: float = 0.25 # proportional RoPE for full attention + + # Norm / activation + rms_norm_eps: float = 1e-6 + hidden_activation: str = "gelu_pytorch_tanh" + + # Sampling / output + final_logit_softcapping: float = 30.0 + tie_word_embeddings: bool = True + + # Sliding window + sliding_window: int = 1024 + + # Hybrid attention pattern + layer_types: list = field(default_factory=list) + + # Runtime + max_seq_len: int = 4096 + + def __post_init__(self): + if not self.layer_types: + # Default hybrid pattern: 5 sliding then 1 full, repeated. + self.layer_types = [ + "full_attention" if (i + 1) % 6 == 0 else "sliding_attention" + for i in range(self.num_hidden_layers) + ] + if len(self.layer_types) != self.num_hidden_layers: + raise ValueError( + f"layer_types length {len(self.layer_types)} != " + f"num_hidden_layers {self.num_hidden_layers}" + ) + + @staticmethod + def from_hf_config(config_path: str) -> "Gemma4_31BConfig": + with open(config_path, "r") as f: + cfg = json.load(f) + if "text_config" in cfg: + cfg = cfg["text_config"] + + rope_params = cfg.get("rope_parameters", {}) + sliding_rope = rope_params.get("sliding_attention", {}) + full_rope = rope_params.get("full_attention", {}) + + return Gemma4_31BConfig( + vocab_size=cfg.get("vocab_size", 262144), + hidden_size=cfg.get("hidden_size", 5376), + intermediate_size=cfg.get("intermediate_size", 21504), + num_hidden_layers=cfg.get("num_hidden_layers", 60), + num_attention_heads=cfg.get("num_attention_heads", 32), + num_key_value_heads=cfg.get("num_key_value_heads", 16), + head_dim=cfg.get("head_dim", 256), + num_global_key_value_heads=cfg.get("num_global_key_value_heads", 4), + global_head_dim=cfg.get("global_head_dim", 512), + attention_k_eq_v=cfg.get("attention_k_eq_v", True), + sliding_rope_theta=sliding_rope.get("rope_theta", 10_000.0), + full_rope_theta=full_rope.get("rope_theta", 1_000_000.0), + full_partial_rotary_factor=full_rope.get("partial_rotary_factor", 0.25), + rms_norm_eps=cfg.get("rms_norm_eps", 1e-6), + hidden_activation=cfg.get("hidden_activation", "gelu_pytorch_tanh"), + final_logit_softcapping=cfg.get("final_logit_softcapping", 30.0), + tie_word_embeddings=cfg.get("tie_word_embeddings", True), + sliding_window=cfg.get("sliding_window", 1024), + layer_types=cfg.get("layer_types", []), + ) + + +# --------------------------------------------------------------------------- +# Attention — single class, branches on layer type via config +# +# RMSNorm, Gemma4MLP, the RoPE helpers, and Gemma4KVCache are imported from +# examples.models.gemma4.text_decoder so the two Gemma 4 variants share their +# numerically-sensitive primitives. + + +class Gemma4Attention(nn.Module): + """Gemma 4 attention with QK-norm, per-layer head_dim, RoPE, KV cache, and SDPA. + + The same class handles both sliding and full attention; the per-layer + config picks head_dim, num_kv_heads, RoPE flavor, and the K=V optimization. + """ + + def __init__(self, config: Gemma4_31BConfig, layer_idx: int): + super().__init__() + self.layer_idx = layer_idx + layer_type = config.layer_types[layer_idx] + self.is_sliding = layer_type == "sliding_attention" + + if self.is_sliding: + self.head_dim = config.head_dim + self.n_kv_heads = config.num_key_value_heads + self.rope_theta = config.sliding_rope_theta + self.partial_rotary = 1.0 + self.k_eq_v = False + else: + self.head_dim = config.global_head_dim + self.n_kv_heads = config.num_global_key_value_heads + self.rope_theta = config.full_rope_theta + self.partial_rotary = config.full_partial_rotary_factor + self.k_eq_v = config.attention_k_eq_v + + self.n_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.scaling = 1.0 # Gemma 4 uses scale=1; QK-norm handles normalization. + + # Linear projections. v_proj is omitted on K=V layers to match the checkpoint. + self.q_proj = nn.Linear( + self.hidden_size, self.n_heads * self.head_dim, bias=False + ) + self.k_proj = nn.Linear( + self.hidden_size, self.n_kv_heads * self.head_dim, bias=False + ) + if not self.k_eq_v: + self.v_proj = nn.Linear( + self.hidden_size, self.n_kv_heads * self.head_dim, bias=False + ) + self.o_proj = nn.Linear( + self.n_heads * self.head_dim, self.hidden_size, bias=False + ) + + # Q/K norm have learnable weight; V norm is weightless. + self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.v_norm = RMSNormNoWeight(self.head_dim, eps=config.rms_norm_eps) + + # Precomputed RoPE table for this layer (per-layer because head_dim + # and theta differ between sliding and full attention). For full + # attention layers we pass freq_base_dim=head_dim so the zero-padded + # inv_freq matches HF's "proportional" partial RoPE. + if self.is_sliding: + rotary_dim = self.head_dim + freq_base_dim = None + else: + rotary_dim = int(self.head_dim * self.partial_rotary) + freq_base_dim = self.head_dim + freqs_cos, freqs_sin = precompute_freqs_cis( + rotary_dim, + config.max_seq_len, + theta=self.rope_theta, + freq_base_dim=freq_base_dim, + ) + self.register_buffer("freqs_cos", freqs_cos, persistent=False) + self.register_buffer("freqs_sin", freqs_sin, persistent=False) + + # KV cache. Sliding layers use a ring buffer (2x window) to save + # memory; full layers use a flat buffer (max_seq_len). + if self.is_sliding: + self.kv_cache = RingKVCache( + max_batch_size=1, + window_size=config.sliding_window, + num_kv_heads=self.n_kv_heads, + head_dim=self.head_dim, + ) + else: + self.kv_cache = Gemma4KVCache( + max_batch_size=1, + max_seq_len=config.max_seq_len, + num_kv_heads=self.n_kv_heads, + head_dim=self.head_dim, + use_index_copy=True, + ) + + def forward( + self, + x: torch.Tensor, + input_pos: torch.Tensor, + attn_mask: torch.Tensor, + ) -> torch.Tensor: + B, T, _ = x.shape + + q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim) + # raw_kv is the linear output before any norm — needed for K=V layers + # so V can be derived from the same tensor as K (post-norm differently). + raw_k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim) + if self.k_eq_v: + raw_v = raw_k + else: + raw_v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim) + + # Norms applied per-head (HF unflatten -> norm -> flatten pattern). + q = self.q_norm(q) + k = self.k_norm(raw_k) + v = self.v_norm(raw_v) + + # Move to BHSD for SDPA / KV cache. + q = q.transpose(1, 2) # (B, H, T, D) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # RoPE on Q and K only (V is not rotated). cos/sin are gathered for + # the current positions to avoid baking the full table into the graph. + cos = self.freqs_cos[input_pos] + sin = self.freqs_sin[input_pos] + q, k = apply_rotary_emb(q, k, cos, sin) + + # Update cache and read back full K/V. + k, v = self.kv_cache.update(input_pos, k, v) + + # SDPA with explicit additive mask (already includes causal + + # sliding-window masking; built once per forward at the model level). + # `scale=1.0` matches HF Gemma 4 — Q-norm/K-norm have absorbed the + # 1/sqrt(d) factor into their trained weights, so the standard SDPA + # default of 1/sqrt(head_dim) would over-divide. enable_gqa lets the + # kernel handle the head ratio without us materializing expanded K/V. + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attn_mask, + is_causal=False, + enable_gqa=True, + scale=self.scaling, + ) + y = y.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim) + return self.o_proj(y) + + +# --------------------------------------------------------------------------- +# Decoder block — Gemma's "norm sandwich" pattern. + + +class Gemma4DecoderLayer(nn.Module): + def __init__(self, config: Gemma4_31BConfig, layer_idx: int): + super().__init__() + self.layer_idx = layer_idx + self.is_sliding = config.layer_types[layer_idx] == "sliding_attention" + + self.self_attn = Gemma4Attention(config, layer_idx) + self.mlp = Gemma4MLP(config.hidden_size, config.intermediate_size) + + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.pre_feedforward_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_feedforward_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + # Per-layer scalar (loaded from checkpoint) — multiplied at the end of + # each layer. Kept as a buffer (not nn.Parameter) so it isn't quantized. + self.register_buffer("layer_scalar", torch.ones(1)) + + def forward( + self, + x: torch.Tensor, + input_pos: torch.Tensor, + sliding_mask: torch.Tensor, + full_mask: torch.Tensor, + ) -> torch.Tensor: + attn_mask = sliding_mask if self.is_sliding else full_mask + + residual = x + h = self.input_layernorm(x) + h = self.self_attn(h, input_pos, attn_mask) + h = self.post_attention_layernorm(h) + x = residual + h + + residual = x + h = self.pre_feedforward_layernorm(x) + h = self.mlp(h) + h = self.post_feedforward_layernorm(h) + x = residual + h + + return x * self.layer_scalar + + +# --------------------------------------------------------------------------- +# Top-level model + + +class Gemma4_31B(nn.Module): + def __init__(self, config: Gemma4_31BConfig): + super().__init__() + self.config = config + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList( + [Gemma4DecoderLayer(config, i) for i in range(config.num_hidden_layers)] + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + # Held separately so it can be untied + quantized at export time. + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Constants (registered as buffers so they move with .to(device)). + self.register_buffer( + "embed_normalizer", + torch.tensor(config.hidden_size**0.5), + persistent=False, + ) + self.register_buffer( + "logit_softcap", + torch.tensor(config.final_logit_softcapping), + persistent=False, + ) + # cache_positions[i] = i — used to build attention masks without + # introducing dynamic-shape tensors at runtime. + self.register_buffer( + "cache_positions", + torch.arange(config.max_seq_len, dtype=torch.long), + persistent=False, + ) + + def _build_masks( + self, input_pos: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """Build boolean (B=1, H=1, T_q, T_kv) masks for full and sliding attention. + + True = attend. Built once per forward, shared across layers of the + same type. Full mask is (T_q, max_seq_len); sliding mask is + (T_q, buf_size) where buf_size = 2 * sliding_window. + """ + # Full attention mask: (T_q, max_seq_len) + cache_pos = self.cache_positions # (max_seq_len,) + q_pos = input_pos.unsqueeze(1) # (T_q, 1) + causal = q_pos >= cache_pos.unsqueeze(0) + full_mask = causal.unsqueeze(0).unsqueeze(0) # (1, 1, T_q, max_seq_len) + + # Sliding attention mask over ring buffer: (T_q, buf_size) + buf_size = self.config.sliding_window * 2 + seq_len = input_pos.shape[0] + total_written = input_pos[0] + seq_len + j = torch.arange(buf_size, dtype=torch.long, device=input_pos.device) + ring_pos = j + ((total_written - 1 - j) // buf_size) * buf_size + delta = q_pos - ring_pos.unsqueeze(0) + sliding = (ring_pos >= 0) & (delta >= 0) & (delta < self.config.sliding_window) + sliding_mask = sliding.unsqueeze(0).unsqueeze(0) # (1, 1, T_q, buf_size) + + return sliding_mask, full_mask + + def forward( + self, + tokens: torch.LongTensor, + input_pos: torch.LongTensor, + temperature: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Run the model. + + Args: + tokens: (B, T) token IDs. + input_pos: (T,) absolute positions for RoPE / KV cache. + temperature: optional 1-D float tensor controlling on-device sampling. + When provided, returns sampled tokens (B, 1) via Gumbel-max; + when None (e.g. eager eval), returns full logits (B, T, V) with + soft-capping applied so callers see post-cap values. + + Returns: + (B, 1) token IDs when sampling, else (B, T, V) float32 logits. + """ + x = self.embed_tokens(tokens) * self.embed_normalizer + + sliding_mask, full_mask = self._build_masks(input_pos) + for layer in self.layers: + x = layer(x, input_pos, sliding_mask, full_mask) + + x = self.norm(x) + + if temperature is None: + logits = self.lm_head(x).float() + cap = self.logit_softcap.float() + return torch.tanh(logits / cap) * cap + + # Decode-time fast path: only materialize logits for the last token. + last = self.lm_head(x[:, -1, :]).float() + cap = self.logit_softcap.float() + last = torch.tanh(last / cap) * cap + return sample(last, temperature) + + # ---------------- checkpoint loading ---------------- + + @staticmethod + def from_hf_checkpoint( + model_dir: str, max_seq_len: int = 4096 + ) -> tuple["Gemma4_31B", Gemma4_31BConfig]: + """Build the model on `meta` and load weights from the HF safetensors checkpoint. + + Uses lazy shard-by-shard loading + assign=True so peak memory stays at + roughly one shard's worth of weights. + """ + config = Gemma4_31BConfig.from_hf_config(os.path.join(model_dir, "config.json")) + config.max_seq_len = max_seq_len + + print( + f"Building Gemma4_31B on meta (layers={config.num_hidden_layers}, " + f"hidden={config.hidden_size}, max_seq_len={max_seq_len})..." + ) + with torch.device("meta"): + model = Gemma4_31B(config) + + print(f"Loading weights from {model_dir}...") + state_dict = _load_and_remap_checkpoint(model_dir, config) + + # Tied embeddings: copy embedding weight into lm_head when missing. + if "lm_head.weight" not in state_dict and "embed_tokens.weight" in state_dict: + state_dict["lm_head.weight"] = state_dict["embed_tokens.weight"] + + missing, unexpected = model.load_state_dict( + state_dict, strict=False, assign=True + ) + + # Runtime buffers (KV caches, RoPE tables, masks) are zero-initialized + # and not in the checkpoint — those are the "expected" missing keys. + runtime_prefixes = ( + ".kv_cache.", + ".freqs_cos", + ".freqs_sin", + "embed_normalizer", + "logit_softcap", + "cache_positions", + ) + actual_missing = set(missing) + expected = {k for k in actual_missing if any(p in k for p in runtime_prefixes)} + extra = actual_missing - expected + if extra: + print(f" WARNING: missing weight keys: {sorted(extra)[:10]}") + if unexpected: + print(f" WARNING: unexpected keys: {sorted(unexpected)[:10]}") + print( + f" Loaded {len(state_dict)} tensors " + f"({len(expected)} runtime buffers OK)" + ) + return model, config + + +# --------------------------------------------------------------------------- +# Weight loading utilities + + +# HuggingFace key -> our model key. Patterns use `{}` for the layer index. +_HF_KEY_MAP = { + "model.embed_tokens.weight": "embed_tokens.weight", + "model.norm.weight": "norm.weight", + "lm_head.weight": "lm_head.weight", + # Per-layer norms + "model.layers.{}.input_layernorm.weight": "layers.{}.input_layernorm.weight", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.post_attention_layernorm.weight", + "model.layers.{}.pre_feedforward_layernorm.weight": "layers.{}.pre_feedforward_layernorm.weight", + "model.layers.{}.post_feedforward_layernorm.weight": "layers.{}.post_feedforward_layernorm.weight", + "model.layers.{}.layer_scalar": "layers.{}.layer_scalar", + # Attention projections + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.self_attn.q_proj.weight", + "model.layers.{}.self_attn.k_proj.weight": "layers.{}.self_attn.k_proj.weight", + "model.layers.{}.self_attn.v_proj.weight": "layers.{}.self_attn.v_proj.weight", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.self_attn.o_proj.weight", + "model.layers.{}.self_attn.q_norm.weight": "layers.{}.self_attn.q_norm.weight", + "model.layers.{}.self_attn.k_norm.weight": "layers.{}.self_attn.k_norm.weight", + # MLP + "model.layers.{}.mlp.gate_proj.weight": "layers.{}.mlp.gate_proj.weight", + "model.layers.{}.mlp.up_proj.weight": "layers.{}.mlp.up_proj.weight", + "model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.down_proj.weight", +} + +# Multimodal keys we deliberately ignore for the text-only export. +_IGNORED_PREFIXES = ( + "model.vision_tower.", + "model.embed_vision.", +) + + +def _hf_to_model_key(hf_key: str) -> Optional[str]: + # Gemma4ForConditionalGeneration stores the LM under model.language_model.* + norm = hf_key + if norm.startswith("model.language_model."): + norm = norm.replace("model.language_model.", "model.", 1) + + if norm.startswith(_IGNORED_PREFIXES): + return None + + for hf_pat, model_pat in _HF_KEY_MAP.items(): + if "{}" not in hf_pat: + if norm == hf_pat: + return model_pat + continue + regex = re.escape(hf_pat).replace(r"\{\}", r"(\d+)") + m = re.fullmatch(regex, norm) + if m: + return model_pat.replace("{}", m.group(1), 1) + return None + + +def _load_and_remap_checkpoint(model_dir: str, config: Gemma4_31BConfig) -> dict: + """Stream-load safetensors shards and remap keys to model state_dict keys.""" + from safetensors import safe_open + + index_path = os.path.join(model_dir, "model.safetensors.index.json") + if os.path.exists(index_path): + with open(index_path, "r") as f: + index = json.load(f) + shard_files = sorted(set(index["weight_map"].values())) + elif os.path.exists(os.path.join(model_dir, "model.safetensors")): + shard_files = ["model.safetensors"] + else: + raise FileNotFoundError(f"No safetensors checkpoint in {model_dir}") + + state_dict: dict[str, torch.Tensor] = {} + skipped = 0 + for shard_file in shard_files: + shard_path = os.path.join(model_dir, shard_file) + with safe_open(shard_path, framework="pt", device="cpu") as f: + for ckpt_key in f.keys(): + model_key = _hf_to_model_key(ckpt_key) + if model_key is None: + skipped += 1 + continue + tensor = f.get_tensor(ckpt_key) + # layer_scalar in checkpoint is shape (1,) bf16 — keep as-is. + state_dict[model_key] = tensor + if skipped > 0: + print(f" Skipped {skipped} non-text keys (vision tower, etc.)") + return state_dict + + +# --------------------------------------------------------------------------- +# Runtime buffer materialization + + +def materialize_runtime_buffers( + model: Gemma4_31B, + dtype: torch.dtype, + device: str = "cpu", +) -> None: + """Replace meta-device buffers with real tensors and set runtime constants. + + Called after weight loading to fill in KV caches (zeros), RoPE tables + (computed), and scalar constants. Only touches buffers still on the meta + device — loaded (non-meta) buffers are left in place. + """ + config = model.config + + for fqn, buf in list(model.named_buffers()): + if buf.device.type != "meta": + continue + parts = fqn.rsplit(".", 1) + parent = model.get_submodule(parts[0]) if len(parts) > 1 else model + is_kv = ".kv_cache." in fqn + target_dtype = dtype if is_kv else torch.float32 + if buf.dtype == torch.bool: + target_dtype = torch.bool + parent.register_buffer( + parts[-1], + torch.zeros(buf.shape, dtype=target_dtype, device=device), + persistent=False, + ) + + for layer in model.layers: + attn = layer.self_attn + if attn.is_sliding: + rotary_dim, freq_base_dim = attn.head_dim, None + else: + rotary_dim = int(attn.head_dim * attn.partial_rotary) + freq_base_dim = attn.head_dim + cos, sin = precompute_freqs_cis( + rotary_dim, + config.max_seq_len, + theta=attn.rope_theta, + freq_base_dim=freq_base_dim, + ) + attn.register_buffer("freqs_cos", cos.to(device), persistent=False) + attn.register_buffer("freqs_sin", sin.to(device), persistent=False) + + model.register_buffer( + "embed_normalizer", + torch.tensor(config.hidden_size**0.5, device=device), + persistent=False, + ) + model.register_buffer( + "logit_softcap", + torch.tensor(config.final_logit_softcapping, device=device), + persistent=False, + ) + model.register_buffer( + "cache_positions", + torch.arange(config.max_seq_len, dtype=torch.long, device=device), + persistent=False, + ) diff --git a/examples/models/gemma4_31b/quant/README.md b/examples/models/gemma4_31b/quant/README.md new file mode 100644 index 00000000000..01a74434487 --- /dev/null +++ b/examples/models/gemma4_31b/quant/README.md @@ -0,0 +1,88 @@ +# quant/ + +Packing-agnostic quantization framework: **recipe → quantize → serialize → pack**. + +## Files + +| File | Concern | Depends on | +|---|---|---| +| `recipe.py` | **Policy** — what to quantize, what precision, which layers | nothing | +| `quantize.py` | **Computation** — produces canonical weights from fp weights | recipe, torchao | +| `serialize.py` | **Data format** — saves/loads canonical weights to safetensors | recipe | +| `pack.py` | **Packing dispatch** — walks model, dispatches to per-module packers | serialize | +| `pack_cuda.py` | **CUDA packing** — converts canonical to tinygemm/intx runtime format | pack, serialize | + +## Data flow + +``` +QuantRecipe → quantize_model() → CanonicalQuantizedWeight → save() → file → load() → CanonicalQuantizedWeight → pack_model() → runtime model +``` + +`CanonicalQuantizedWeight` is the interchange point — int8 qdata + bf16 +scale + optional zero + config. Everything left of it is backend-agnostic. +Everything right is backend-specific. + +## Adding a new backend + +Write a `pack_.py` with per-module packers and a default registry: + +```python +def pack_linear_for_metal(module, weights): ... +DEFAULT_METAL_PACKERS = {nn.Linear: pack_linear_for_metal} +``` + +Call `pack_model(model, quantized, unquantized, packers=DEFAULT_METAL_PACKERS)`. +No changes to recipe, quantize, or serialize. + +Things to consider: + +- **Recipes may need to be backend-aware.** Each backend's kernels have + different constraints (e.g., Metal's `fpa4w` is INT4-only — no INT8 linear + kernel, so the sensitive recipe's 8-bit edge layers would need to be INT4 + or dequantized to bf16). Define per-backend recipes or validate recipe + compatibility at pack time. +- **Source transforms before packing.** Some backends replace model modules + (e.g., MLX swaps `FusedMoEExperts` → `SwitchMLP`, Metal swaps to + `MetalMoEExperts`). These transforms change the module types that + packers dispatch on, so they must run before `pack_model()`. For dense + models (no MoE) this is not needed. +- **Embedding quantization.** Not all backends have a quantized embedding + gather kernel. The packer can dequantize to bf16 at load time — the + disk savings from the canonical format still apply. + +## Adding a new model + +1. Define a `QuantRecipe` with rules for the model's FQN patterns. +2. If the model has custom module types (e.g., `FusedMoEExperts`), write a + per-module packer and extend the packers dict: + ```python + packers = {**DEFAULT_CUDA_PACKERS, FusedMoEExperts: pack_moe_experts} + ``` +3. No changes to the quant package itself. + +## On-disk format + +Safetensors with a `format_version` in the header. Per quantized weight: +`{fqn}.qdata` (int8, nibble-packed for 4-bit), `{fqn}.scale` (bf16), +optionally `{fqn}.zero` (bf16). Header JSON records bits, group_size, +symmetric, and method per weight. Unquantized weights stored as-is. + +## TODO + +- `pack_metal.py` — Metal backend packer. Convert canonical INT4 to + `UIntxWeightOnlyConfig` subclass (torchao experimental) for the + `torchao::_linear_fp_act_4bit_weight` kernel. For MoE models, pack + expert weights into Metal's `gather_qmv` format (asymmetric, unsigned + INT4 with scale + bias buffers). + +- `pack_mlx.py` — MLX backend packer. Convert canonical INT4 to + `IntxWeightOnlyConfig` subclass for the `mlx::gather_qmm` kernel. + For MoE models, stack per-expert weights into `SwitchLinear` format. + +- `gguf.py` — read a GGUF file and convert to `CanonicalQuantizedWeight` + dicts, enabling `load() → pack_model()` from community-quantized GGUF + checkpoints without re-quantizing from bf16. Maps GGUF quant types + (Q4_K, Q6_K, Q8_0, etc.) to `QuantConfig` and unpacks super-blocks + into the canonical qdata + scale + zero layout. For CUDA packing, + Q6_K would be widened to 8-bit (`pack_int8_for_cuda`) since there is + no 6-bit CUDA kernel — lossless, ~33% more memory than true 6-bit. diff --git a/examples/models/gemma4_31b/quant/__init__.py b/examples/models/gemma4_31b/quant/__init__.py new file mode 100644 index 00000000000..23d321f0c0b --- /dev/null +++ b/examples/models/gemma4_31b/quant/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .pack import ModulePackerFn, pack_model # noqa: F401 +from .pack_cuda import ( # noqa: F401 + DEFAULT_CUDA_PACKERS, + load_and_pack_for_cuda, + pack_embedding_for_cuda, + pack_int4_for_cuda, + pack_int8_for_cuda, + pack_linear_for_cuda, +) +from .quantize import quantize_model, quantize_weight # noqa: F401 +from .recipe import QuantConfig, QuantRecipe, QuantRule # noqa: F401 +from .serialize import ( # noqa: F401 + CanonicalQuantizedWeight, + deserialize, + load, + save, + serialize, +) diff --git a/examples/models/gemma4_31b/quant/pack.py b/examples/models/gemma4_31b/quant/pack.py new file mode 100644 index 00000000000..5a1e792b56d --- /dev/null +++ b/examples/models/gemma4_31b/quant/pack.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Backend-agnostic model packing: canonical weights → runtime model. + +``pack_model`` walks a model's quantized weights, groups them by parent +module, and dispatches to per-module packer functions. Each backend +(``pack_cuda.py``, future ``pack_metal.py``) provides its own packers dict +mapping module types to packer functions. + +Pure logic — no file I/O, no backend imports. +""" + +from collections import defaultdict +from typing import Callable + +import torch +import torch.nn as nn + +from .serialize import CanonicalQuantizedWeight + +# Packer signature: receives the module + a dict of its quantized weights +# (keyed by attribute name, e.g., {"weight": CQW}), modifies module in-place. +ModulePackerFn = Callable[[nn.Module, dict[str, CanonicalQuantizedWeight]], None] + + +def _assign_unquantized(model: nn.Module, unquantized: dict[str, torch.Tensor]) -> None: + """Assign plain (unquantized) tensors to model parameters and buffers.""" + model_sd_keys = set(model.state_dict().keys()) + for fqn, tensor in unquantized.items(): + if fqn not in model_sd_keys: + continue + parts = fqn.rsplit(".", 1) + parent = model.get_submodule(parts[0]) if len(parts) > 1 else model + attr_name = parts[-1] + if isinstance(getattr(parent, attr_name, None), nn.Parameter): + setattr(parent, attr_name, nn.Parameter(tensor, requires_grad=False)) + else: + parent.register_buffer(attr_name, tensor) + + +def pack_model( + model: nn.Module, + quantized: dict[str, CanonicalQuantizedWeight], + unquantized: dict[str, torch.Tensor], + packers: dict[type, ModulePackerFn], +) -> None: + """Pack canonical weights into ``model`` using the given packers. + + Groups quantized weights by their parent module, then dispatches to the + appropriate per-module packer based on the module's type. Models with + custom module types (e.g., ``FusedMoEExperts``) extend ``packers``. + + Pure logic — no file I/O, no backend dependency. + """ + + _assign_unquantized(model, unquantized) + + module_weights: dict[str, dict[str, CanonicalQuantizedWeight]] = defaultdict(dict) + for fqn, cw in quantized.items(): + parent_fqn, attr = fqn.rsplit(".", 1) + module_weights[parent_fqn][attr] = cw + + for parent_fqn, weights in module_weights.items(): + module = model.get_submodule(parent_fqn) + packer = packers.get(type(module)) + if packer is None: + raise ValueError( + f"No packer registered for {type(module).__name__} at '{parent_fqn}'. " + f"Registered types: {[t.__name__ for t in packers]}." + ) + packer(module, weights) + + for fqn, p in model.named_parameters(): + if p.device.type == "meta": + raise RuntimeError( + f"Weight '{fqn}' not found in checkpoint " + f"(model/checkpoint version mismatch?)" + ) + + for p in model.parameters(): + p.requires_grad_(False) diff --git a/examples/models/gemma4_31b/quant/pack_cuda.py b/examples/models/gemma4_31b/quant/pack_cuda.py new file mode 100644 index 00000000000..039f2cbf7ba --- /dev/null +++ b/examples/models/gemma4_31b/quant/pack_cuda.py @@ -0,0 +1,210 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""CUDA packer: canonical → CUDA runtime format. + +Provides per-module packers for the CUDA backend (INT4 via tinygemm, +INT8 via ``IntxUnpackedToInt8Tensor``) and ``load_and_pack_for_cuda`` +as a convenience I/O wrapper. + +The backend-agnostic ``pack_model`` dispatcher lives in ``pack.py``. +""" + +import torch +import torch.nn as nn + +from .pack import ModulePackerFn, pack_model # noqa: F401 +from .serialize import CanonicalQuantizedWeight, load + + +# --------------------------------------------------------------------------- +# Low-level: canonical → Int4TilePackedTo4dTensor (one weight at a time) + + +def pack_int4_for_cuda( + cw: CanonicalQuantizedWeight, + device: str = "cuda", +) -> nn.Parameter: + """Convert a canonical 4-bit weight to ``Int4TilePackedTo4dTensor``. + + Pads K to a multiple of 1024 and N to a multiple of 8 (tinygemm + requirements), nibble-packs, then tile-packs via the CUDA kernel. + Returns an ``nn.Parameter`` wrapping the subclass tensor **on CUDA**. + """ + from torchao.quantization.quantize_.workflows.int4.int4_tile_packed_to_4d_tensor import ( + Int4TilePackedTo4dTensor, + ) + from torchao.quantization.utils import pack_tinygemm_scales_and_zeros + from torchao.utils import find_multiple + + assert cw.config.bits == 4, f"Expected 4-bit, got {cw.config.bits}" + assert cw.qdata.ndim == 2, ( + f"pack_int4_for_cuda requires 2D weight (nn.Linear), got {cw.qdata.ndim}D " + f"shape {tuple(cw.qdata.shape)}." + ) + + original_shape = cw.qdata.shape + N, K = original_shape + gs = cw.config.group_size + inner_k_tiles = 8 + + K_padded = find_multiple(K, 1024) + N_padded = find_multiple(N, 8) + + int_data = cw.qdata.to(torch.int32) + if K_padded != K or N_padded != N: + int_data = torch.nn.functional.pad(int_data, (0, K_padded - K, 0, N_padded - N)) + + scale = cw.scale + n_groups_orig = K // gs + n_groups_padded = K_padded // gs + if n_groups_padded != n_groups_orig or N_padded != N: + scale = torch.nn.functional.pad( + scale, (0, n_groups_padded - n_groups_orig, 0, N_padded - N) + ) + + if cw.zero is not None: + zero = cw.zero + if n_groups_padded != n_groups_orig or N_padded != N: + zero = torch.nn.functional.pad( + zero, (0, n_groups_padded - n_groups_orig, 0, N_padded - N) + ) + else: + # Symmetric: qdata is unsigned [0, 15] (shifted +8 from signed [-8, 7]). + # Standard convention: weight = (q - zp_std) * scale, so zp_std = 8. + zero = torch.full_like(scale, 8.0) + + int_data = int_data.to(device) + scale = scale.to(device) + zero = zero.to(device) + + # Convert zero from standard convention (weight = (q - zp_std) * scale) + # to tinygemm convention (weight = (q - 8) * scale + zp_tg). + # Derivation: (q - zp_std) * scale = (q - 8) * scale + zp_tg + # → zp_tg = (8 - zp_std) * scale + tinygemm_zero = (8 - zero.to(torch.float32)) * scale.to(torch.float32) + + # Tinygemm nibble convention: even index in HIGH nibble, odd in LOW. + # (This differs from serialize.py's _nibble_pack which uses the opposite + # convention for on-disk storage — both are valid, they serve different + # consumers.) + int_data_u8 = (int_data[:, ::2] << 4 | int_data[:, 1::2]).to(torch.uint8) + packed_weight = torch.ops.aten._convert_weight_to_int4pack( + int_data_u8.contiguous(), inner_k_tiles + ) + + scale_and_zero = pack_tinygemm_scales_and_zeros( + scale.to(torch.bfloat16), tinygemm_zero.to(torch.bfloat16), torch.bfloat16 + ) + + subclass = Int4TilePackedTo4dTensor( + qdata=packed_weight, + scale_and_zero=scale_and_zero, + block_size=[1, gs], + shape=torch.Size(original_shape), + ) + return nn.Parameter(subclass, requires_grad=False) + + +# --------------------------------------------------------------------------- +# Per-module packers + + +def pack_int8_for_cuda( + cw: CanonicalQuantizedWeight, +) -> nn.Parameter: + """Convert a canonical 8-bit weight to ``IntxUnpackedToInt8Tensor``. + + Unlike INT4 (which needs tinygemm tile packing), INT8 weights are stored + unpacked. The subclass carries int8 qdata + scales and dequantizes during + matmul — AOTI fuses the ``dequantize → mm`` pattern in the compiled graph. + """ + from torchao.quantization import IntxUnpackedToInt8Tensor + + assert cw.config.bits == 8, f"Expected 8-bit, got {cw.config.bits}" + assert cw.qdata.ndim == 2, f"Expected 2D weight, got {cw.qdata.ndim}D" + + N, K = cw.qdata.shape + n_groups = K // cw.config.group_size + scale = cw.scale.to(torch.bfloat16).reshape(N, n_groups) + zero_point = ( + cw.zero.to(torch.int8).reshape(N, n_groups) + if cw.zero is not None + else torch.zeros(N, n_groups, dtype=torch.int8) + ) + + subclass = IntxUnpackedToInt8Tensor( + qdata=cw.qdata, + scale=scale, + zero_point=zero_point, + target_dtype=torch.int8, + block_size=(1, cw.config.group_size), + dtype=torch.bfloat16, + activation_quantization=None, + ) + return nn.Parameter(subclass, requires_grad=False) + + +def pack_linear_for_cuda( + module: nn.Module, weights: dict[str, CanonicalQuantizedWeight] +) -> None: + """Pack a quantized ``nn.Linear`` for CUDA. + + 4-bit weights use ``Int4TilePackedTo4dTensor`` (tinygemm kernel, requires + CUDA for packing). 8-bit weights use ``IntxUnpackedToInt8Tensor`` (AOTI + fuses the dequantize-matmul pattern). Both stay as tensor subclasses so + the export graph captures quantized ops. + """ + cw = weights["weight"] + if cw.config.bits == 4: + packed = pack_int4_for_cuda(cw, device="cuda") + module.weight = nn.Parameter(packed.data.to("cpu"), requires_grad=False) + torch.cuda.empty_cache() + elif cw.config.bits == 8: + module.weight = pack_int8_for_cuda(cw) + else: + raise ValueError(f"Unsupported bit width: {cw.config.bits}") + + +def pack_embedding_for_cuda( + module: nn.Module, weights: dict[str, CanonicalQuantizedWeight] +) -> None: + """Pack a quantized ``nn.Embedding`` for CUDA. + + Uses ``IntxUnpackedToInt8Tensor`` which supports embedding gather. + Only INT8 is supported — ``Int4TilePackedTo4dTensor`` does not + implement the embedding op. + """ + cw = weights["weight"] + if cw.config.bits != 8: + raise ValueError( + f"Only 8-bit embedding quantization is supported on CUDA, " + f"got {cw.config.bits}-bit." + ) + module.weight = pack_int8_for_cuda(cw) + + +DEFAULT_CUDA_PACKERS: dict[type, ModulePackerFn] = { + nn.Linear: pack_linear_for_cuda, + nn.Embedding: pack_embedding_for_cuda, +} + + +# --------------------------------------------------------------------------- +# Load + pack (I/O wrapper) + + +def load_and_pack_for_cuda( + path: str, + model: nn.Module, + packers: dict[type, ModulePackerFn] | None = None, +) -> None: + """Read a quantized safetensors file and pack into ``model`` for CUDA. + + Thin wrapper: ``load`` + ``pack_model``. + """ + quantized, unquantized = load(path) + pack_model(model, quantized, unquantized, packers or DEFAULT_CUDA_PACKERS) diff --git a/examples/models/gemma4_31b/quant/quantize.py b/examples/models/gemma4_31b/quant/quantize.py new file mode 100644 index 00000000000..0ebfd032681 --- /dev/null +++ b/examples/models/gemma4_31b/quant/quantize.py @@ -0,0 +1,225 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Quantize weights to canonical form. + +``quantize_weight`` quantizes a single tensor given a ``QuantConfig``, +dispatching to the appropriate algorithm based on ``config.method``: + + - ``"min_max"``: standard symmetric/asymmetric quantization via torchao's + ``choose_qparams_affine`` + ``quantize_affine``. Runs on CPU or CUDA. + - ``"hqq"``: Half-Quadratic Quantization — iteratively refines scales via + a proximal solver for better accuracy. ``symmetric=False`` optimizes both + scale and zero (requires CUDA). ``symmetric=True`` optimizes scale only + (CPU or CUDA). + +``quantize_model`` walks a model's parameters, applies a ``QuantRecipe``, +and returns two dicts: quantized weights as ``CanonicalQuantizedWeight`` +and unquantized weights as plain tensors. + +Both are model-agnostic — they work for any ``nn.Module`` and any weight +shape (2D linears, 3D fused-expert stacks, etc.). +""" + +import torch +import torch.nn as nn + +from .recipe import QuantConfig, QuantRecipe + +from .serialize import CanonicalQuantizedWeight + + +# --------------------------------------------------------------------------- +# Per-weight quantization + + +def _quantize_min_max( + weight: torch.Tensor, + config: QuantConfig, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Standard min/max quantization. Returns (int_data, scale, zero_point).""" + from torchao.quantization.quant_primitives import ( + choose_qparams_affine, + MappingType, + quantize_affine, + ) + + if config.bits == 4: + qmin, qmax = (-8, 7) if config.symmetric else (0, 15) + elif config.bits == 8: + qmin, qmax = -128, 127 + else: + raise ValueError(f"Unsupported bits={config.bits}") + + mapping = MappingType.SYMMETRIC if config.symmetric else MappingType.ASYMMETRIC + block_size = tuple([1] * (weight.ndim - 1) + [config.group_size]) + + scale, zero_point = choose_qparams_affine( + weight.float(), + mapping, + block_size, + target_dtype=torch.int8, + quant_min=qmin, + quant_max=qmax, + scale_dtype=torch.bfloat16, + zero_point_dtype=torch.bfloat16, + ) + int_data = quantize_affine( + weight.float(), + block_size, + scale, + zero_point, + output_dtype=torch.int8, + quant_min=qmin, + quant_max=qmax, + ) + return int_data, scale, zero_point + + +def _quantize_hqq_asymmetric( + weight: torch.Tensor, + config: QuantConfig, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Full HQQ (asymmetric, optimizes scale + zero). Requires CUDA. + + Returns (int_data, scale, zero_point) in canonical layout. + """ + from torchao.quantization.quant_primitives import ( + _choose_qparams_and_quantize_affine_hqq, + ) + + device = weight.device + if device.type != "cuda": + device = torch.device("cuda") + + W_q, scale, zero, _shape = _choose_qparams_and_quantize_affine_hqq( + weight, + nbits=config.bits, + group_size=config.group_size, + axis=1, + compute_dtype=torch.bfloat16, + device=str(device), + raw_output=True, + ) + + int_data = W_q.to(torch.int8) + scale = scale.to(torch.bfloat16).reshape(*weight.shape[:-1], -1) + zero = zero.to(torch.bfloat16).reshape(*weight.shape[:-1], -1) + + return int_data, scale, zero + + +def _quantize_hqq_symmetric( + weight: torch.Tensor, + config: QuantConfig, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Scale-only HQQ (symmetric, optimizes scale only). Runs on CPU or CUDA. + + Returns (int_data, scale, zero_point) where zero_point is all zeros. + """ + from torchao.quantization.quant_primitives import ( + _choose_qparams_and_quantize_scale_only_hqq, + ) + + if config.bits == 4: + qmin, qmax = -8, 7 + elif config.bits == 8: + qmin, qmax = -128, 127 + else: + raise ValueError(f"Unsupported bits={config.bits}") + + # scale_only_hqq requires 2D. For 3D+, flatten → quantize → reshape. + orig_shape = weight.shape + weight_2d = weight.reshape(-1, weight.shape[-1]) if weight.ndim > 2 else weight + + qdata, scale = _choose_qparams_and_quantize_scale_only_hqq( + weight_2d, + [1, config.group_size], + qmin, + qmax, + ) + + int_data = qdata.to(torch.int8).reshape(orig_shape) + scale = scale.to(torch.bfloat16).reshape(*orig_shape[:-1], -1) + zero_point = torch.zeros_like(scale) + + return int_data, scale, zero_point + + +def quantize_weight( + weight: torch.Tensor, + config: QuantConfig, +) -> CanonicalQuantizedWeight: + """Quantize ``weight`` to canonical form. + + Dispatches to the algorithm specified by ``config.method``. The input is + processed in float32 internally for numerical stability. Does NOT pad or + pack for any backend. + """ + if config.method == "min_max": + int_data, scale, zero_point = _quantize_min_max(weight, config) + elif config.method == "hqq": + if config.symmetric: + int_data, scale, zero_point = _quantize_hqq_symmetric(weight, config) + else: + int_data, scale, zero_point = _quantize_hqq_asymmetric(weight, config) + else: + raise ValueError( + f"Unknown quantization method: {config.method!r}. " + f"Supported: 'min_max', 'hqq'." + ) + + # Normalize 4-bit to unsigned [0, 15] for uniform storage and nibble + # packing. Symmetric min_max produces [-8, 7]; shift to [0, 15]. + # HQQ already produces [0, 15] (asymmetric internally). + if config.bits == 4 and config.symmetric: + int_data = int_data + 8 + + return CanonicalQuantizedWeight( + qdata=int_data.to(torch.int8), + scale=scale.to(torch.bfloat16), + zero=zero_point.to(torch.bfloat16) if not config.symmetric else None, + config=config, + ) + + +# --------------------------------------------------------------------------- +# Per-model quantization + + +def quantize_model( + model: nn.Module, + recipe: QuantRecipe, + dtype: torch.dtype = torch.bfloat16, +) -> tuple[dict[str, CanonicalQuantizedWeight], dict[str, torch.Tensor]]: + """Walk model parameters + persistent buffers, apply recipe. + + For each parameter matched by a recipe rule: quantize to canonical. + Parameters that match ``None`` (skip) rules and persistent buffers go + into the unquantized dict (cast to ``dtype``). Non-persistent buffers + (KV cache, RoPE tables, etc.) are excluded. + + Returns ``(quantized, unquantized)`` dicts keyed by FQN. + """ + quantized: dict[str, CanonicalQuantizedWeight] = {} + unquantized: dict[str, torch.Tensor] = {} + persistent_keys = set(model.state_dict().keys()) + + n_params = sum(1 for _ in model.named_parameters()) + for i, (fqn, param) in enumerate(model.named_parameters()): + config = recipe.get_config(fqn) + if config is None: + unquantized[fqn] = param.data.to(dtype) + else: + quantized[fqn] = quantize_weight(param.data, config) + print(f" Quantized {i + 1}/{n_params}: {fqn}", end="\r") + print() + + for fqn, buf in model.named_buffers(): + if fqn in persistent_keys and fqn not in quantized: + unquantized[fqn] = buf.data + + return quantized, unquantized diff --git a/examples/models/gemma4_31b/quant/recipe.py b/examples/models/gemma4_31b/quant/recipe.py new file mode 100644 index 00000000000..0edb0491640 --- /dev/null +++ b/examples/models/gemma4_31b/quant/recipe.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Quantization recipe: declares what to quantize and how. + +A ``QuantRecipe`` is an ordered list of ``QuantRule`` objects matched against +weight FQNs. First match wins. The recipe says nothing about packing format, +tensor subclass, or target backend. +""" + +import re +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass(frozen=True) +class QuantConfig: + """Per-weight quantization parameters.""" + + bits: int # 4, 6, 8 + group_size: int # 32, 64, 128 + symmetric: bool # True = no zero point + method: str # "min_max" | "hqq" + + +@dataclass +class QuantRule: + """A single recipe rule: regex pattern + config + optional layer filter.""" + + pattern: str # regex matched against weight FQN + config: Optional[QuantConfig] # None = skip (leave unquantized) + layers: Optional[set[int]] = field(default=None, repr=False) # None = all layers + + +@dataclass +class QuantRecipe: + """Ordered list of rules. First match wins.""" + + rules: list[QuantRule] + + def get_config(self, fqn: str) -> Optional[QuantConfig]: + """Return the ``QuantConfig`` for a weight FQN, or ``None`` to skip.""" + layer_idx = self._extract_layer_idx(fqn) + for rule in self.rules: + if rule.layers is not None: + if layer_idx is None or layer_idx not in rule.layers: + continue + if re.fullmatch(rule.pattern, fqn): + return rule.config + return None + + @staticmethod + def _extract_layer_idx(fqn: str) -> Optional[int]: + m = re.search(r"layers\.(\d+)\.", fqn) + return int(m.group(1)) if m else None diff --git a/examples/models/gemma4_31b/quant/serialize.py b/examples/models/gemma4_31b/quant/serialize.py new file mode 100644 index 00000000000..5996599ad90 --- /dev/null +++ b/examples/models/gemma4_31b/quant/serialize.py @@ -0,0 +1,215 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Serialize and persist quantized weights. + +Two layers: + + - **serialize / deserialize** — convert between ``CanonicalQuantizedWeight`` + objects and plain tensors + JSON metadata. Pure logic, no I/O. The output + is a ``(tensors_dict, metadata_dict)`` pair that any file writer can + consume. + - **save / load** — write/read the serialized form to/from safetensors on + disk. Thin I/O wrappers around ``safetensors.save_file`` / + ``safetensors.safe_open``. + +For 4-bit weights, qdata is nibble-packed (two values per byte) during +serialization to keep file size at ~0.5 bytes/param. +""" + +import json +from dataclasses import dataclass +from typing import Optional + +import torch +from safetensors import safe_open +from safetensors.torch import save_file + +from .recipe import QuantConfig + +# Bump when the on-disk layout changes in a backward-incompatible way +# (e.g., different nibble-pack convention, renamed keys, new required fields). +# The loader rejects files with an unsupported version rather than silently +# producing corrupt data. +FORMAT_VERSION = "1" +_SUPPORTED_VERSIONS = {FORMAT_VERSION} + + +@dataclass +class CanonicalQuantizedWeight: + """Packing-free quantized weight representation. + + ``qdata`` int8 values: [0, 15] for 4-bit (both symmetric and asymmetric + are stored as unsigned after shifting), [-128, 127] for 8-bit. + ``scale`` bf16 per-group scales, shape ``[*weight_shape[:-1], K // group_size]``. + ``zero`` bf16 per-group zero points (``None`` when symmetric). + ``config`` the ``QuantConfig`` that produced this. + """ + + qdata: torch.Tensor + scale: torch.Tensor + zero: Optional[torch.Tensor] + config: QuantConfig + + +# --------------------------------------------------------------------------- +# Nibble packing for 4-bit on-disk storage. +# +# Two 4-bit values are packed into one byte to halve file size. The +# convention is: even-indexed values go into the LOW nibble (bits 0-3), +# odd-indexed values go into the HIGH nibble (bits 4-7). +# +# values: [v0, v1, v2, v3, ...] (each in [0, 15]) +# packed: [v0 | (v1 << 4), v2 | (v3 << 4), ...] +# byte 0: bits 0-3 = v0, bits 4-7 = v1 +# +# To unpack: low = byte & 0x0F, high = (byte >> 4) & 0x0F. +# +# This matches the Triton fused_moe kernel's unpack convention +# ((byte >> (k%2)*4) & 0xF) and Qwen's _quantize_experts_int4 packing. +# Note: tinygemm uses the OPPOSITE convention (even=HIGH, odd=LOW) — the +# CUDA packer in pack_cuda.py handles that conversion separately. + + +def _nibble_pack(qdata: torch.Tensor) -> torch.Tensor: + """Pack int8 values (each in [0, 15]) into half the last dim. + + Even-indexed values → low nibble, odd-indexed → high nibble. + """ + assert qdata.shape[-1] % 2 == 0, f"Last dim must be even, got {qdata.shape}" + low = qdata[..., ::2].to(torch.uint8) + high = qdata[..., 1::2].to(torch.uint8) + return (low | (high << 4)).to(torch.int8).contiguous() + + +def _nibble_unpack(packed: torch.Tensor, orig_last_dim: int) -> torch.Tensor: + """Unpack nibble-packed int8 → original last dim. + + Low nibble (bits 0-3) → even indices, high nibble (bits 4-7) → odd indices. + """ + p = packed.to(torch.uint8) + low = (p & 0x0F).to(torch.int8) + high = ((p >> 4) & 0x0F).to(torch.int8) + return torch.stack([low, high], dim=-1).reshape(*packed.shape[:-1], orig_last_dim) + + +# --------------------------------------------------------------------------- +# Serialize / deserialize (pure logic, no I/O) + + +def serialize( + quantized: dict[str, CanonicalQuantizedWeight], + unquantized: dict[str, torch.Tensor], +) -> tuple[dict[str, torch.Tensor], dict[str, str]]: + """Convert quantized + unquantized weights to plain tensors + metadata. + + Returns ``(tensors, header)`` ready for any file writer. Quantized + weights become ``{fqn}.qdata``, ``{fqn}.scale``, optionally + ``{fqn}.zero``. For 4-bit, qdata is nibble-packed. + """ + tensors: dict[str, torch.Tensor] = {} + quant_meta: dict[str, dict] = {} + + for fqn, cw in quantized.items(): + qdata = cw.qdata + if cw.config.bits == 4: + qdata = _nibble_pack(qdata) + tensors[f"{fqn}.qdata"] = qdata.contiguous() + tensors[f"{fqn}.scale"] = cw.scale.contiguous() + if cw.zero is not None: + tensors[f"{fqn}.zero"] = cw.zero.contiguous() + quant_meta[fqn] = { + "bits": cw.config.bits, + "group_size": cw.config.group_size, + "symmetric": cw.config.symmetric, + "method": cw.config.method, + "shape": list(cw.qdata.shape), + } + + for fqn, tensor in unquantized.items(): + tensors[fqn] = tensor.contiguous() + + header = {"format_version": FORMAT_VERSION} + if quant_meta: + header["quant"] = json.dumps(quant_meta) + + return tensors, header + + +def deserialize( + tensors: dict[str, torch.Tensor], + header: dict[str, str], +) -> tuple[dict[str, CanonicalQuantizedWeight], dict[str, torch.Tensor]]: + """Reconstruct quantized + unquantized weights from plain tensors + metadata. + + Inverse of ``serialize``. Returns ``(quantized, unquantized)`` dicts. + """ + version = header.get("format_version", "1") + if version not in _SUPPORTED_VERSIONS: + raise ValueError( + f"Unsupported format version {version!r}. " + f"This code supports {sorted(_SUPPORTED_VERSIONS)}. " + f"Update the quant package or re-quantize the model." + ) + + quant_meta = json.loads(header.get("quant", "{}")) + + quantized: dict[str, CanonicalQuantizedWeight] = {} + consumed_keys: set[str] = set() + + for fqn, meta in quant_meta.items(): + config = QuantConfig( + bits=meta["bits"], + group_size=meta["group_size"], + symmetric=meta["symmetric"], + method=meta["method"], + ) + qdata = tensors[f"{fqn}.qdata"] + consumed_keys.add(f"{fqn}.qdata") + + original_shape = meta["shape"] + if config.bits == 4: + qdata = _nibble_unpack(qdata, original_shape[-1]) + + scale = tensors[f"{fqn}.scale"] + consumed_keys.add(f"{fqn}.scale") + + zero = tensors.get(f"{fqn}.zero") + if zero is not None: + consumed_keys.add(f"{fqn}.zero") + + quantized[fqn] = CanonicalQuantizedWeight( + qdata=qdata, scale=scale, zero=zero, config=config + ) + + unquantized = {k: v for k, v in tensors.items() if k not in consumed_keys} + + return quantized, unquantized + + +# --------------------------------------------------------------------------- +# Save / load (I/O wrappers) + + +def save( + quantized: dict[str, CanonicalQuantizedWeight], + unquantized: dict[str, torch.Tensor], + path: str, +) -> int: + """Serialize and write to safetensors. Returns the number of tensors written.""" + tensors, header = serialize(quantized, unquantized) + save_file(tensors, path, metadata=header) + return len(tensors) + + +def load( + path: str, +) -> tuple[dict[str, CanonicalQuantizedWeight], dict[str, torch.Tensor]]: + """Read safetensors and deserialize. Returns ``(quantized, unquantized)``.""" + with safe_open(path, framework="pt", device="cpu") as f: + header = f.metadata() + tensors = {k: f.get_tensor(k) for k in f.keys()} + return deserialize(tensors, header) diff --git a/examples/models/gemma4_31b/quant/test_pack_cuda.py b/examples/models/gemma4_31b/quant/test_pack_cuda.py new file mode 100644 index 00000000000..5a20d02998b --- /dev/null +++ b/examples/models/gemma4_31b/quant/test_pack_cuda.py @@ -0,0 +1,360 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Unit tests for quant/pack_cuda.py. Requires CUDA.""" + +import os +import tempfile +import unittest + +import torch +import torch.nn as nn + +from .pack_cuda import ( + DEFAULT_CUDA_PACKERS, + load_and_pack_for_cuda, + pack_embedding_for_cuda, + pack_int4_for_cuda, + pack_int8_for_cuda, + pack_linear_for_cuda, + pack_model, +) +from .quantize import quantize_weight +from .recipe import QuantConfig + +from .serialize import save + + +class TestPackInt4ForCuda(unittest.TestCase): + def setUp(self): + if not torch.cuda.is_available(): + self.skipTest("CUDA required") + + def test_symmetric_works(self): + config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") + cw = quantize_weight(torch.randn(128, 256, dtype=torch.bfloat16), config) + self.assertEqual(pack_int4_for_cuda(cw).shape, torch.Size([128, 256])) + + def test_rejects_1d(self): + config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") + cw = quantize_weight(torch.randn(1, 128, dtype=torch.bfloat16), config) + cw.qdata = cw.qdata.squeeze(0) + with self.assertRaises(AssertionError): + pack_int4_for_cuda(cw) + + def test_rejects_8bit(self): + config = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") + cw = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) + with self.assertRaises(AssertionError): + pack_int4_for_cuda(cw) + + def test_different_group_sizes(self): + for gs in (32, 64, 128): + with self.subTest(group_size=gs): + config = QuantConfig( + bits=4, group_size=gs, symmetric=False, method="min_max" + ) + cw = quantize_weight( + torch.randn(128, 256, dtype=torch.bfloat16), config + ) + self.assertEqual(pack_int4_for_cuda(cw).shape, torch.Size([128, 256])) + + def test_matmul_approximates_original(self): + """Packed weight produces matmul output close to the original.""" + torch.manual_seed(0) + # Use dimensions already aligned to tinygemm requirements + # (K multiple of 1024, N multiple of 8) to avoid padding effects. + weight = torch.randn(256, 1024, dtype=torch.bfloat16) + x = torch.randn(1, 1024, dtype=torch.bfloat16) + + original_out = torch.nn.functional.linear(x.cuda(), weight.cuda()) + + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + cw = quantize_weight(weight, config) + packed = pack_int4_for_cuda(cw) + + packed_out = torch.nn.functional.linear(x.cuda(), packed.cuda()) + + rel_error = ( + packed_out.float() - original_out.float() + ).abs().mean() / original_out.float().abs().mean() + self.assertLess(rel_error.item(), 0.15) + + def test_symmetric_matmul_approximates_original(self): + """Symmetric 4-bit (e.g. HQQ) packs correctly for tinygemm.""" + torch.manual_seed(0) + weight = torch.randn(256, 1024, dtype=torch.bfloat16) + x = torch.randn(1, 1024, dtype=torch.bfloat16) + + original_out = torch.nn.functional.linear(x.cuda(), weight.cuda()) + + config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") + cw = quantize_weight(weight, config) + packed = pack_int4_for_cuda(cw) + + packed_out = torch.nn.functional.linear(x.cuda(), packed.cuda()) + + rel_error = ( + packed_out.float() - original_out.float() + ).abs().mean() / original_out.float().abs().mean() + self.assertLess(rel_error.item(), 0.15) + + +class TestPackInt8ForCuda(unittest.TestCase): + def setUp(self): + if not torch.cuda.is_available(): + self.skipTest("CUDA required") + + def test_rejects_4bit(self): + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + cw = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) + with self.assertRaises(AssertionError): + pack_int8_for_cuda(cw) + + def test_matmul_approximates_original(self): + torch.manual_seed(0) + weight = torch.randn(256, 128, dtype=torch.bfloat16) + x = torch.randn(1, 128, dtype=torch.bfloat16) + + original_out = torch.nn.functional.linear(x.cuda(), weight.cuda()) + + config = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") + cw = quantize_weight(weight, config) + packed = pack_int8_for_cuda(cw) + + packed_out = torch.nn.functional.linear(x.cuda(), packed.cuda()) + + rel_error = ( + packed_out.float() - original_out.float() + ).abs().mean() / original_out.float().abs().mean() + self.assertLess(rel_error.item(), 0.02) + + def test_per_axis_gather_approximates_original(self): + """Per-axis INT8 (group_size == K) works for embedding gather.""" + torch.manual_seed(0) + weight = torch.randn(1000, 64, dtype=torch.bfloat16) + ids = torch.tensor([0, 1, 42, 500, 999]) + + original = weight[ids] + + config = QuantConfig(bits=8, group_size=64, symmetric=True, method="min_max") + cw = quantize_weight(weight, config) + packed = pack_int8_for_cuda(cw) + + emb = nn.Embedding(1000, 64) + emb.weight = nn.Parameter(packed, requires_grad=False) + emb.to("cuda") + packed_out = emb(ids.cuda()) + + rel_error = ( + packed_out.cpu().float() - original.float() + ).abs().mean() / original.float().abs().mean() + self.assertLess(rel_error.item(), 0.02) + + +class TestPackLinearForCuda(unittest.TestCase): + def setUp(self): + if not torch.cuda.is_available(): + self.skipTest("CUDA required") + + def test_4bit_modifies_module_in_place(self): + module = nn.Linear(128, 256, bias=False) + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + cw = quantize_weight(torch.randn(256, 128, dtype=torch.bfloat16), config) + pack_linear_for_cuda(module, {"weight": cw}) + self.assertEqual(module.weight.device.type, "cpu") + self.assertEqual(module.weight.shape, torch.Size([256, 128])) + + def test_8bit_modifies_module_in_place(self): + module = nn.Linear(128, 64, bias=False) + config = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") + cw = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) + pack_linear_for_cuda(module, {"weight": cw}) + self.assertEqual(module.weight.shape, torch.Size([64, 128])) + + +class TestPackEmbeddingForCuda(unittest.TestCase): + def setUp(self): + if not torch.cuda.is_available(): + self.skipTest("CUDA required") + + def test_gather_approximates_original(self): + """INT8 quantized embedding gather matches bf16 gather.""" + torch.manual_seed(0) + weight = torch.randn(1000, 64, dtype=torch.bfloat16) + ids = torch.tensor([0, 1, 42, 500, 999]) + + original = weight[ids] + + config = QuantConfig(bits=8, group_size=64, symmetric=True, method="min_max") + cw = quantize_weight(weight, config) + + module = nn.Embedding(1000, 64) + pack_embedding_for_cuda(module, {"weight": cw}) + module.to("cuda") + packed_out = module(ids.cuda()) + + rel_error = ( + packed_out.cpu().float() - original.float() + ).abs().mean() / original.float().abs().mean() + self.assertLess(rel_error.item(), 0.02) + + def test_rejects_4bit(self): + config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") + cw = quantize_weight(torch.randn(100, 64, dtype=torch.bfloat16), config) + module = nn.Embedding(100, 64) + with self.assertRaises(ValueError): + pack_embedding_for_cuda(module, {"weight": cw}) + + +class TestLoadAndPackForCuda(unittest.TestCase): + def setUp(self): + if not torch.cuda.is_available(): + self.skipTest("CUDA required") + + def test_pack_model_in_memory(self): + """pack_model works with in-memory dicts (no file I/O).""" + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + cw = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) + unq = {"norm.weight": torch.randn(64, dtype=torch.bfloat16)} + + with torch.device("meta"): + model = nn.ModuleDict( + { + "proj": nn.Linear(128, 64, bias=False), + "norm": nn.LayerNorm(64, bias=False), + } + ) + pack_model(model, {"proj.weight": cw}, unq, DEFAULT_CUDA_PACKERS) + + self.assertEqual(model.proj.weight.shape, torch.Size([64, 128])) + self.assertEqual(model.norm.weight.shape, torch.Size([64])) + + def test_pack_model_mixed_precision(self): + """pack_model handles 4-bit and 8-bit weights in the same model.""" + q4_config = QuantConfig( + bits=4, group_size=32, symmetric=False, method="min_max" + ) + q8_config = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") + cw4 = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), q4_config) + cw8 = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), q8_config) + + with torch.device("meta"): + model = nn.ModuleDict( + { + "q_proj": nn.Linear(128, 64, bias=False), + "v_proj": nn.Linear(128, 64, bias=False), + } + ) + pack_model( + model, + {"q_proj.weight": cw4, "v_proj.weight": cw8}, + {}, + DEFAULT_CUDA_PACKERS, + ) + + self.assertEqual(model.q_proj.weight.shape, torch.Size([64, 128])) + self.assertEqual(model.v_proj.weight.shape, torch.Size([64, 128])) + # Verify different subclass types + self.assertNotEqual( + type(model.q_proj.weight.data).__name__, + type(model.v_proj.weight.data).__name__, + ) + + def test_dispatches_by_module_type(self): + """load_and_pack_for_cuda reads from disk and dispatches.""" + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + cw = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) + + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save({"proj.weight": cw}, {}, path) + + with torch.device("meta"): + model2 = nn.ModuleDict({"proj": nn.Linear(128, 64, bias=False)}) + load_and_pack_for_cuda(path, model2) + + self.assertEqual(model2.proj.weight.shape, torch.Size([64, 128])) + self.assertEqual(model2.proj.weight.device.type, "cpu") + + def test_unknown_module_type_raises(self): + """Unregistered module types get a clear error.""" + + class CustomModule(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(32, 64)) + + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + cw = quantize_weight(torch.randn(32, 64, dtype=torch.bfloat16), config) + + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save({"custom.weight": cw}, {}, path) + + with torch.device("meta"): + model2 = nn.ModuleDict({"custom": CustomModule()}) + with self.assertRaises(ValueError) as ctx: + load_and_pack_for_cuda(path, model2) + self.assertIn("CustomModule", str(ctx.exception)) + + def test_missing_weight_raises(self): + """A meta-device parameter after loading means the checkpoint was incomplete.""" + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + cw = quantize_weight(torch.randn(32, 64, dtype=torch.bfloat16), config) + + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + # Only save weight for 'a', not 'b' + save({"a.weight": cw}, {}, path) + + with torch.device("meta"): + model2 = nn.ModuleDict( + { + "a": nn.Linear(64, 32, bias=False), + "b": nn.Linear(64, 32, bias=False), + } + ) + with self.assertRaises(RuntimeError) as ctx: + load_and_pack_for_cuda(path, model2) + self.assertIn("b.weight", str(ctx.exception)) + + def test_custom_packer_via_dict(self): + """Models can extend the packer dict with custom module types.""" + call_log = [] + + class MyModule(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(32, 64)) + + def my_packer(module, weights): + call_log.append(("my_packer", list(weights.keys()))) + cw = weights["weight"] + module.weight = nn.Parameter( + cw.qdata.to(torch.bfloat16), requires_grad=False + ) + + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + cw = quantize_weight(torch.randn(32, 64, dtype=torch.bfloat16), config) + + custom_packers = {**DEFAULT_CUDA_PACKERS, MyModule: my_packer} + + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save({"m.weight": cw}, {}, path) + + with torch.device("meta"): + model2 = nn.ModuleDict({"m": MyModule()}) + load_and_pack_for_cuda(path, model2, packers=custom_packers) + + self.assertEqual(len(call_log), 1) + self.assertEqual(call_log[0], ("my_packer", ["weight"])) + self.assertEqual(model2.m.weight.device.type, "cpu") + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/models/gemma4_31b/quant/test_quantize.py b/examples/models/gemma4_31b/quant/test_quantize.py new file mode 100644 index 00000000000..214a22f718b --- /dev/null +++ b/examples/models/gemma4_31b/quant/test_quantize.py @@ -0,0 +1,194 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Unit tests for quant/quantize.py. + +Tests the public API: ``quantize_weight`` and ``quantize_model``. Organized +by resource requirement (CPU vs CUDA), not by internal codepath. +""" + +import unittest + +import torch +import torch.nn as nn +from parameterized import parameterized + +from .quantize import quantize_model, quantize_weight +from .recipe import QuantConfig, QuantRecipe, QuantRule + + +# --------------------------------------------------------------------------- +# quantize_weight — CPU (uses min_max; tests the output contract) + + +class TestQuantizeWeight(unittest.TestCase): + @parameterized.expand( + [ + ("4bit_asym", 4, 32, False, (64, 128), 0, 15), + ("4bit_sym", 4, 32, True, (64, 128), 0, 15), + ("4bit_gs64", 4, 64, False, (32, 128), 0, 15), + ("8bit_sym", 8, 32, True, (32, 64), -128, 127), + ("3d_expert", 4, 32, False, (8, 64, 128), 0, 15), + ] + ) + def test_output_structure(self, _name, bits, gs, sym, shape, qmin, qmax): + config = QuantConfig(bits=bits, group_size=gs, symmetric=sym, method="min_max") + cw = quantize_weight(torch.randn(*shape, dtype=torch.bfloat16), config) + + self.assertEqual(cw.qdata.shape, shape) + self.assertEqual(cw.qdata.dtype, torch.int8) + self.assertEqual(cw.scale.shape, (*shape[:-1], shape[-1] // gs)) + self.assertGreaterEqual(cw.qdata.min().item(), qmin) + self.assertLessEqual(cw.qdata.max().item(), qmax) + + if sym: + self.assertIsNone(cw.zero) + else: + self.assertIsNotNone(cw.zero) + self.assertEqual(cw.zero.shape, cw.scale.shape) + + self.assertEqual(cw.config, config) + + def test_fp32_input(self): + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + cw = quantize_weight(torch.randn(32, 64, dtype=torch.float32), config) + self.assertEqual(cw.qdata.shape, (32, 64)) + + def test_dequant_approximates_original(self): + torch.manual_seed(0) + weight = torch.randn(64, 128, dtype=torch.float32) + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + cw = quantize_weight(weight, config) + scale = cw.scale.float().repeat_interleave(config.group_size, dim=-1) + zero = cw.zero.float().repeat_interleave(config.group_size, dim=-1) + dequant = (cw.qdata.float() - zero) * scale + rel_error = (dequant - weight).abs().mean() / weight.abs().mean() + self.assertLess(rel_error.item(), 0.15) + + @parameterized.expand( + [ + ("unknown_method", QuantConfig(4, 32, False, "bogus"), "bogus"), + ("unsupported_bits", QuantConfig(3, 32, False, "min_max"), None), + ] + ) + def test_invalid_config_raises(self, _name, config, expected_substr): + with self.assertRaises(ValueError) as ctx: + quantize_weight(torch.randn(32, 64), config) + if expected_substr: + self.assertIn(expected_substr, str(ctx.exception)) + + +# --------------------------------------------------------------------------- +# quantize_weight — CUDA (HQQ-specific behavior only) + + +class TestQuantizeWeightHQQ(unittest.TestCase): + def setUp(self): + if not torch.cuda.is_available(): + self.skipTest("CUDA required for HQQ") + + def test_dequant_approximates_original(self): + torch.manual_seed(0) + weight = torch.randn(64, 128, dtype=torch.float32, device="cuda") + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="hqq") + cw = quantize_weight(weight, config) + scale = cw.scale.cpu().float().repeat_interleave(config.group_size, dim=-1) + zero = cw.zero.cpu().float().repeat_interleave(config.group_size, dim=-1) + dequant = (cw.qdata.cpu().float() - zero) * scale + rel_error = (dequant - weight.cpu()).abs().mean() / weight.cpu().abs().mean() + self.assertLess(rel_error.item(), 0.15) + + def test_symmetric_scale_only(self): + """symmetric=True dispatches to scale-only HQQ (no zero).""" + config = QuantConfig(bits=4, group_size=32, symmetric=True, method="hqq") + cw = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) + self.assertIsNone(cw.zero) + self.assertGreaterEqual(cw.qdata.min().item(), 0) + self.assertLessEqual(cw.qdata.max().item(), 15) + + def test_cpu_input_accepted(self): + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="hqq") + cw = quantize_weight(torch.randn(32, 64, dtype=torch.bfloat16), config) + self.assertEqual(cw.qdata.shape, (32, 64)) + + +# --------------------------------------------------------------------------- +# quantize_model + + +class TestQuantizeModel(unittest.TestCase): + def test_applies_recipe(self): + model = nn.ModuleDict( + { + "embed": nn.Embedding(32, 16), + "proj": nn.Linear(16, 32, bias=False), + "norm": nn.LayerNorm(32), + } + ) + model.to(dtype=torch.bfloat16) + for p in model.parameters(): + p.data.normal_(0, 0.02) + + recipe = QuantRecipe( + rules=[ + QuantRule(r"embed\.weight", None), + QuantRule(r"norm\.weight", None), + QuantRule(r".*\.weight", QuantConfig(4, 16, False, "min_max")), + ] + ) + + quantized, unquantized = quantize_model(model, recipe) + + self.assertIn("proj.weight", quantized) + self.assertEqual(quantized["proj.weight"].qdata.shape, (32, 16)) + self.assertIn("embed.weight", unquantized) + self.assertIn("norm.weight", unquantized) + self.assertNotIn("embed.weight", quantized) + self.assertNotIn("norm.weight", quantized) + + def test_persistent_buffers_included(self): + model = nn.Module() + model.weight = nn.Parameter(torch.randn(16, 32, dtype=torch.bfloat16)) + model.register_buffer("scalar", torch.ones(1)) + model.register_buffer("temp", torch.zeros(4), persistent=False) + + recipe = QuantRecipe(rules=[QuantRule(r".*", None)]) + _, unquantized = quantize_model(model, recipe) + + self.assertIn("scalar", unquantized) + self.assertNotIn("temp", unquantized) + + def test_unquantized_cast_to_dtype(self): + model = nn.ModuleDict({"proj": nn.Linear(16, 8, bias=False)}) + model.proj.weight.data = torch.randn(8, 16, dtype=torch.float32) + + recipe = QuantRecipe(rules=[QuantRule(r".*", None)]) + _, unquantized = quantize_model(model, recipe, dtype=torch.float16) + + self.assertEqual(unquantized["proj.weight"].dtype, torch.float16) + + def test_empty_model(self): + quantized, unquantized = quantize_model(nn.Module(), QuantRecipe(rules=[])) + self.assertEqual(len(quantized), 0) + self.assertEqual(len(unquantized), 0) + + def test_all_quantized(self): + model = nn.ModuleDict({"a": nn.Linear(32, 16, bias=False)}) + model.to(dtype=torch.bfloat16) + for p in model.parameters(): + p.data.normal_(0, 0.02) + + config = QuantConfig(bits=4, group_size=16, symmetric=False, method="min_max") + quantized, unquantized = quantize_model( + model, QuantRecipe(rules=[QuantRule(r".*", config)]) + ) + self.assertEqual(len(quantized), 1) + self.assertIn("a.weight", quantized) + self.assertEqual(len(unquantized), 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/models/gemma4_31b/quant/test_recipe.py b/examples/models/gemma4_31b/quant/test_recipe.py new file mode 100644 index 00000000000..6bd04a936a3 --- /dev/null +++ b/examples/models/gemma4_31b/quant/test_recipe.py @@ -0,0 +1,163 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Unit tests for quant/recipe.py. CPU only — no CUDA, no model, no torchao.""" + +import unittest + +from parameterized import parameterized + +from .recipe import QuantConfig, QuantRecipe, QuantRule + +_Q4 = QuantConfig(4, 32, True, "min_max") +_Q6 = QuantConfig(6, 32, False, "min_max") + + +class TestQuantRecipeGetConfig(unittest.TestCase): + """Tests for ``QuantRecipe.get_config`` — the core matching logic.""" + + @parameterized.expand( + [ + ( + "first_match_wins", + [QuantRule(r".*v_proj\.weight", _Q6), QuantRule(r".*\.weight", _Q4)], + "layers.0.self_attn.v_proj.weight", + 6, + ), + ( + "fallthrough_to_catchall", + [QuantRule(r".*v_proj\.weight", _Q6), QuantRule(r".*\.weight", _Q4)], + "layers.0.self_attn.q_proj.weight", + 4, + ), + ( + "none_rule_skips", + [ + QuantRule(r"embed_tokens\.weight", None), + QuantRule(r".*\.weight", _Q4), + ], + "embed_tokens.weight", + None, + ), + ( + "unmatched_returns_none", + [QuantRule(r"foo", _Q4)], + "bar.weight", + None, + ), + ( + "empty_recipe", + [], + "anything", + None, + ), + ( + "fullmatch_not_partial", + [QuantRule(r"foo", _Q4)], + "foo.bar", + None, + ), + ( + "fullmatch_exact", + [QuantRule(r"foo", _Q4)], + "foo", + 4, + ), + ] + ) + def test_get_config(self, _name, rules, fqn, expected_bits): + recipe = QuantRecipe(rules=rules) + config = recipe.get_config(fqn) + if expected_bits is None: + self.assertIsNone(config) + else: + self.assertEqual(config.bits, expected_bits) + + +class TestQuantRecipeLayerFilter(unittest.TestCase): + """Tests for the ``layers`` field on ``QuantRule``.""" + + def test_layer_filter(self): + edge = set(range(5)) | set(range(55, 60)) + recipe = QuantRecipe( + rules=[ + QuantRule(r".*norm\.weight", None), + QuantRule(r".*\.(v_proj|down_proj)\.weight", _Q6, layers=edge), + QuantRule(r".*\.weight", _Q4), + ] + ) + # Edge v_proj → 6-bit + self.assertEqual(recipe.get_config("layers.0.self_attn.v_proj.weight").bits, 6) + self.assertEqual(recipe.get_config("layers.58.self_attn.v_proj.weight").bits, 6) + # Middle v_proj → falls through → 4-bit + self.assertEqual(recipe.get_config("layers.30.self_attn.v_proj.weight").bits, 4) + # q_proj always 4-bit + self.assertEqual(recipe.get_config("layers.0.self_attn.q_proj.weight").bits, 4) + # Non-layer FQN skips layer-filtered rule, hits catch-all + self.assertEqual(recipe.get_config("lm_head.weight").bits, 4) + + def test_layer_filter_with_none_config(self): + """Skip rule scoped to specific layers.""" + recipe = QuantRecipe( + rules=[ + QuantRule(r".*\.weight", None, layers={0}), + QuantRule(r".*\.weight", _Q4), + ] + ) + self.assertIsNone(recipe.get_config("layers.0.mlp.gate_proj.weight")) + self.assertEqual(recipe.get_config("layers.1.mlp.gate_proj.weight").bits, 4) + + +class TestProductionRecipes(unittest.TestCase): + """Regression tests for the production recipes in quantize_and_save.py.""" + + def test_default_recipe(self): + from executorch.examples.models.gemma4_31b.quantize_and_save import ( + GEMMA4_31B_DEFAULT_RECIPE, + ) + + r = GEMMA4_31B_DEFAULT_RECIPE + self.assertIsNone(r.get_config("layers.0.input_layernorm.weight")) + self.assertIsNone(r.get_config("layers.5.self_attn.q_norm.weight")) + self.assertIsNone(r.get_config("norm.weight")) + embed_cfg = r.get_config("embed_tokens.weight") + self.assertEqual(embed_cfg.bits, 8) + self.assertEqual(embed_cfg.group_size, 5376) + for fqn in ( + "layers.0.self_attn.q_proj.weight", + "layers.0.self_attn.v_proj.weight", + "layers.0.mlp.gate_proj.weight", + "layers.0.mlp.down_proj.weight", + "lm_head.weight", + ): + cfg = r.get_config(fqn) + self.assertEqual(cfg.bits, 4, fqn) + self.assertEqual(cfg.method, "min_max", fqn) + + def test_sensitive_recipe(self): + from executorch.examples.models.gemma4_31b.quantize_and_save import ( + GEMMA4_31B_SENSITIVE_RECIPE, + ) + + r = GEMMA4_31B_SENSITIVE_RECIPE + self.assertIsNone(r.get_config("layers.0.input_layernorm.weight")) + embed_cfg = r.get_config("embed_tokens.weight") + self.assertEqual(embed_cfg.bits, 8) + self.assertEqual(embed_cfg.group_size, 5376) + # Edge v_proj/down_proj → int8 + self.assertEqual(r.get_config("layers.0.self_attn.v_proj.weight").bits, 8) + self.assertEqual(r.get_config("layers.0.mlp.down_proj.weight").bits, 8) + self.assertEqual(r.get_config("layers.58.self_attn.v_proj.weight").bits, 8) + # Middle v_proj/down_proj → int4 + self.assertEqual(r.get_config("layers.30.self_attn.v_proj.weight").bits, 4) + self.assertEqual(r.get_config("layers.30.mlp.down_proj.weight").bits, 4) + # q_proj always int4 + self.assertEqual(r.get_config("layers.0.self_attn.q_proj.weight").bits, 4) + self.assertEqual(r.get_config("layers.30.self_attn.q_proj.weight").bits, 4) + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/models/gemma4_31b/quant/test_serialize.py b/examples/models/gemma4_31b/quant/test_serialize.py new file mode 100644 index 00000000000..302c38647ed --- /dev/null +++ b/examples/models/gemma4_31b/quant/test_serialize.py @@ -0,0 +1,207 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Unit tests for quant/serialize.py — data format and I/O only. + +Tests nibble pack/unpack and save/load. Does NOT test +quantize_weight (that lives in test_quantize.py). Bundle tests use +hand-built CanonicalQuantizedWeight fixtures to avoid coupling to the +quantizer. +""" + +import json +import os +import tempfile +import unittest + +import torch +from safetensors import safe_open + +from .recipe import QuantConfig + +from .serialize import ( + _nibble_pack, + _nibble_unpack, + CanonicalQuantizedWeight, + deserialize, + load, + save, + serialize, +) + + +def _make_cqw( + shape: tuple[int, ...], + config: QuantConfig, +) -> CanonicalQuantizedWeight: + """Build a CanonicalQuantizedWeight with random data (no actual quantization).""" + K = shape[-1] + n_groups = K // config.group_size + scale_shape = (*shape[:-1], n_groups) + + if config.bits == 4: + qdata = torch.randint(0, 16, shape, dtype=torch.int8) + else: + qdata = torch.randint(-128, 128, shape, dtype=torch.int8) + + return CanonicalQuantizedWeight( + qdata=qdata, + scale=torch.randn(scale_shape, dtype=torch.bfloat16), + zero=( + torch.randn(scale_shape, dtype=torch.bfloat16) + if not config.symmetric + else None + ), + config=config, + ) + + +# --------------------------------------------------------------------------- +# Nibble pack / unpack + + +class TestNibblePack(unittest.TestCase): + def test_roundtrip(self): + qdata = torch.randint(0, 16, (8, 64), dtype=torch.int8) + packed = _nibble_pack(qdata) + self.assertEqual(packed.shape, (8, 32)) + self.assertTrue(torch.equal(qdata, _nibble_unpack(packed, 64))) + + def test_rejects_odd_last_dim(self): + with self.assertRaises(AssertionError): + _nibble_pack(torch.zeros(4, 33, dtype=torch.int8)) + + def test_3d(self): + """Nibble pack works for 3D tensors (MoE expert weights).""" + qdata = torch.randint(0, 16, (4, 32, 64), dtype=torch.int8) + packed = _nibble_pack(qdata) + self.assertEqual(packed.shape, (4, 32, 32)) + self.assertTrue(torch.equal(qdata, _nibble_unpack(packed, 64))) + + +# --------------------------------------------------------------------------- +# save / load + + +class TestSerializeDeserialize(unittest.TestCase): + """Pure logic layer — no disk I/O.""" + + def test_roundtrip(self): + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + cw = _make_cqw((64, 128), config) + unq = {"embed": torch.randn(8, 8, dtype=torch.bfloat16)} + + tensors, header = serialize({"w": cw}, unq) + q, u = deserialize(tensors, header) + + self.assertTrue(torch.equal(cw.qdata, q["w"].qdata)) + self.assertTrue(torch.equal(cw.scale, q["w"].scale)) + self.assertTrue(torch.equal(cw.zero, q["w"].zero)) + self.assertEqual(cw.config, q["w"].config) + self.assertTrue(torch.equal(unq["embed"], u["embed"])) + + def test_rejects_unsupported_version(self): + tensors, header = serialize({}, {"w": torch.randn(4, 4)}) + header["format_version"] = "99" + with self.assertRaises(ValueError) as ctx: + deserialize(tensors, header) + self.assertIn("99", str(ctx.exception)) + + +class TestSaveLoad(unittest.TestCase): + """I/O layer — roundtrip through safetensors on disk.""" + + def test_roundtrip_asymmetric(self): + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + cw = _make_cqw((64, 128), config) + unq = {"embed.weight": torch.randn(32, 64, dtype=torch.bfloat16)} + + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save({"w": cw}, unq, path) + q, u = load(path) + + self.assertTrue(torch.equal(cw.qdata, q["w"].qdata)) + self.assertTrue(torch.equal(cw.scale, q["w"].scale)) + self.assertTrue(torch.equal(cw.zero, q["w"].zero)) + self.assertEqual(cw.config, q["w"].config) + self.assertTrue(torch.equal(unq["embed.weight"], u["embed.weight"])) + + def test_roundtrip_symmetric(self): + config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") + cw = _make_cqw((32, 64), config) + + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save({"w": cw}, {}, path) + q, _ = load(path) + + self.assertIsNone(q["w"].zero) + self.assertTrue(torch.equal(cw.qdata, q["w"].qdata)) + + def test_roundtrip_3d(self): + """3D quantized weights (MoE experts) roundtrip correctly.""" + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + cw = _make_cqw((8, 64, 128), config) + + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save({"experts.w1": cw}, {}, path) + q, _ = load(path) + + self.assertTrue(torch.equal(cw.qdata, q["experts.w1"].qdata)) + self.assertEqual(q["experts.w1"].scale.shape, (8, 64, 4)) + + def test_4bit_nibble_packed_on_disk(self): + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + cw = _make_cqw((64, 128), config) + + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save({"w": cw}, {}, path) + with safe_open(path, framework="pt", device="cpu") as f: + on_disk = f.get_tensor("w.qdata") + self.assertEqual(on_disk.shape, (64, 64)) + + def test_8bit_not_nibble_packed(self): + config = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") + cw = _make_cqw((32, 64), config) + + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save({"w": cw}, {}, path) + with safe_open(path, framework="pt", device="cpu") as f: + on_disk = f.get_tensor("w.qdata") + self.assertEqual(on_disk.shape, (32, 64)) # no packing for 8-bit + + def test_header_metadata(self): + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + cw = _make_cqw((32, 64), config) + + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save({"foo.weight": cw}, {}, path) + with safe_open(path, framework="pt", device="cpu") as f: + meta = json.loads(f.metadata()["quant"]) + + self.assertIn("foo.weight", meta) + self.assertEqual(meta["foo.weight"]["bits"], 4) + self.assertEqual(meta["foo.weight"]["group_size"], 32) + self.assertFalse(meta["foo.weight"]["symmetric"]) + self.assertEqual(meta["foo.weight"]["method"], "min_max") + + def test_empty_quantized(self): + unq = {"w": torch.randn(8, 8, dtype=torch.bfloat16)} + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save({}, unq, path) + q, u = load(path) + self.assertEqual(len(q), 0) + self.assertTrue(torch.equal(unq["w"], u["w"])) + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/models/gemma4_31b/quantize_and_save.py b/examples/models/gemma4_31b/quantize_and_save.py new file mode 100644 index 00000000000..7a9eb9900f2 --- /dev/null +++ b/examples/models/gemma4_31b/quantize_and_save.py @@ -0,0 +1,135 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Quantize Gemma 4 31B-IT and save as a quantized checkpoint. + +Produces a packing-agnostic safetensors file (int values + per-group scales + +JSON header) that can later be loaded and packed for any backend via +``quant.load()`` and ``quant.pack_model()``. + +No CUDA is needed — quantization runs on CPU. CUDA is only required at +load-and-pack time. + +Usage: + python quantize_and_save.py \\ + --model-dir ~/local/scripts/models/gemma-4-31B-it \\ + --output ./gemma4_31b_int4 \\ + --quant-recipe default +""" + +import argparse +import os +import shutil + +import torch.nn as nn + +from executorch.examples.models.gemma4_31b.model import Gemma4_31B +from executorch.examples.models.gemma4_31b.quant import ( + QuantConfig, + quantize_model, + QuantRecipe, + QuantRule, + save, +) + +# --------------------------------------------------------------------------- +# Production recipes for Gemma 4 31B. +# +# Layer sensitivity: +# - v_proj and down_proj are the most sensitive to quantization error +# (first/last quarter of layers especially so). +# - q_proj, k_proj, o_proj, gate_proj, up_proj tolerate 4-bit well. +# - embed_tokens is an index lookup — INT8 per-axis is nearly lossless. +# - Norms and layer_scalar are tiny and must stay unquantized. + +_INT4 = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") +_INT4_HQQ = QuantConfig(bits=4, group_size=32, symmetric=True, method="hqq") +_INT8 = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") +_INT8_PER_AXIS = QuantConfig(bits=8, group_size=5376, symmetric=True, method="min_max") +_EDGE_LAYERS = set(range(15)) | set(range(45, 60)) + +GEMMA4_31B_DEFAULT_RECIPE = QuantRecipe( + rules=[ + QuantRule(r"embed_tokens\.weight", _INT8_PER_AXIS), + QuantRule(r".*norm\.weight", None), + QuantRule(r".*\.weight", _INT4), + ] +) + +GEMMA4_31B_SENSITIVE_RECIPE = QuantRecipe( + rules=[ + QuantRule(r"embed_tokens\.weight", _INT8_PER_AXIS), + QuantRule(r".*norm\.weight", None), + QuantRule(r".*\.(v_proj|down_proj)\.weight", _INT8, layers=_EDGE_LAYERS), + QuantRule(r".*\.weight", _INT4_HQQ), + ] +) + +_RECIPES = { + "default": GEMMA4_31B_DEFAULT_RECIPE, + "sensitive": GEMMA4_31B_SENSITIVE_RECIPE, +} + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Quantize Gemma 4 31B-IT and save as a quantized checkpoint." + ) + parser.add_argument( + "--model-dir", + required=True, + help="HuggingFace Gemma 4 31B-IT model dir.", + ) + parser.add_argument( + "--output", + default="./gemma4_31b_int4", + help="Output directory.", + ) + parser.add_argument( + "--quant-recipe", + default="default", + choices=list(_RECIPES), + help="'default': int4 min_max linears + int8 per-axis embedding. " + "'sensitive': int8 for edge-layer v_proj/down_proj, int4 hqq elsewhere.", + ) + parser.add_argument( + "--backend", + default="cuda", + choices=["cuda"], + help="Target backend (the quantized checkpoint is backend-agnostic, " + "but this may influence default recipe selection in the future).", + ) + args = parser.parse_args() + + recipe = _RECIPES[args.quant_recipe] + + print("Loading checkpoint (lazy, shard-by-shard)...") + model, _ = Gemma4_31B.from_hf_checkpoint(args.model_dir) + + if model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr(): + print("Untying embed_tokens / lm_head...") + model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + + print(f"Quantizing with recipe '{args.quant_recipe}'...") + quantized, unquantized = quantize_model(model, recipe) + + os.makedirs(args.output, exist_ok=True) + safetensors_path = os.path.join(args.output, "model.safetensors") + print("Saving quantized checkpoint...") + n_tensors = save(quantized, unquantized, safetensors_path) + + for filename in ("config.json", "tokenizer.json", "tokenizer_config.json"): + src = os.path.join(args.model_dir, filename) + if os.path.exists(src): + shutil.copy2(src, os.path.join(args.output, filename)) + + size_mb = os.path.getsize(safetensors_path) / (1024 * 1024) + print(f"Saved {n_tensors} tensors ({size_mb:.1f} MB) to {args.output}/") + print(f"Done. Use with: python export.py --prequantized {args.output}") + + +if __name__ == "__main__": + main() diff --git a/examples/models/gemma4_31b/sampler.py b/examples/models/gemma4_31b/sampler.py new file mode 100644 index 00000000000..45e4e17887a --- /dev/null +++ b/examples/models/gemma4_31b/sampler.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""GPU-side Gumbel-max sampler. + +Mirrors ``examples/models/qwen3_5_moe/sampler.py``: a single-output sampler +that lets one exported program be re-driven with different temperatures +without re-export. ``temperature=None`` is a no-op (returns logits). +""" + +from typing import Optional + +import torch + + +def sample( + logits: torch.Tensor, + temperature: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Draw a single token per batch row using the Gumbel-max trick. + + Args: + logits: ``[B, V]`` float32 logits (already soft-capped if applicable). + temperature: 0-D or 1-D float tensor; clamped to >= 1e-6 so a 0 + temperature still works ("near-greedy"). When ``None`` the call + short-circuits and returns ``logits`` unchanged. + + Returns: + ``[B, 1]`` float32 token IDs (``argmax(logits/T + gumbel_noise)``), + or the unmodified logits when ``temperature`` is ``None``. + """ + if temperature is None: + return logits + + logits = logits / temperature.clamp(min=1e-6) + noise = torch.rand_like(logits) + gumbel = -torch.log(-torch.log(noise + 1e-20) + 1e-20) + return (logits + gumbel).argmax(dim=-1, keepdim=True).float() diff --git a/examples/models/gemma4_31b/test_cuda_pipeline.py b/examples/models/gemma4_31b/test_cuda_pipeline.py new file mode 100644 index 00000000000..a7df5d9818c --- /dev/null +++ b/examples/models/gemma4_31b/test_cuda_pipeline.py @@ -0,0 +1,121 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""CUDA-specific integration tests for the Gemma 4 31B-IT pipeline. + +Tests pack → inference → export on a tiny model using the CUDA backend. +Backend-agnostic tests (quantize, save, load) live in ``test_pipeline.py``. + +Requires CUDA. + +Usage: + python -m pytest examples/models/gemma4_31b/test_cuda_pipeline.py -v +""" + +import os +import tempfile +import unittest + +import torch +import torch.nn as nn + +from executorch.examples.models.gemma4_31b.export import ( + export_and_lower, + load_prequantized_model, +) +from executorch.examples.models.gemma4_31b.inference import _move_to_cuda, generate +from executorch.examples.models.gemma4_31b.model import Gemma4_31B +from executorch.examples.models.gemma4_31b.quant import ( + DEFAULT_CUDA_PACKERS, + pack_model, + quantize_model, +) +from executorch.examples.models.gemma4_31b.test_pipeline import ( + build_hf_checkpoint, + DEFAULT_RECIPE, + MockTokenizer, + save_checkpoint, + TINY_CONFIG, +) + + +def _require_cuda(testcase: unittest.TestCase) -> None: + if not torch.cuda.is_available(): + testcase.skipTest("CUDA required") + + +class TestCudaInference(unittest.TestCase): + def setUp(self): + _require_cuda(self) + + def test_generate(self): + """save → load → pack → generate (sampling + greedy).""" + with tempfile.TemporaryDirectory() as tmpdir: + save_checkpoint(tmpdir) + model, config = load_prequantized_model( + tmpdir, max_seq_len=TINY_CONFIG.max_seq_len + ) + _move_to_cuda(model, config) + model.eval() + tokenizer = MockTokenizer(TINY_CONFIG.vocab_size) + + torch.manual_seed(0) + out = generate(model, tokenizer, prompt="hi", max_new_tokens=5, temperature=1.0) + self.assertIsInstance(out, str) + ids_part = out[len("" + + +def config_dict() -> dict: + cfg = TINY_CONFIG + return { + "vocab_size": cfg.vocab_size, + "hidden_size": cfg.hidden_size, + "intermediate_size": cfg.intermediate_size, + "num_hidden_layers": cfg.num_hidden_layers, + "num_attention_heads": cfg.num_attention_heads, + "num_key_value_heads": cfg.num_key_value_heads, + "head_dim": cfg.head_dim, + "num_global_key_value_heads": cfg.num_global_key_value_heads, + "global_head_dim": cfg.global_head_dim, + "attention_k_eq_v": cfg.attention_k_eq_v, + "rope_parameters": { + "sliding_attention": {"rope_theta": cfg.sliding_rope_theta}, + "full_attention": { + "rope_theta": cfg.full_rope_theta, + "partial_rotary_factor": cfg.full_partial_rotary_factor, + }, + }, + "rms_norm_eps": cfg.rms_norm_eps, + "hidden_activation": cfg.hidden_activation, + "final_logit_softcapping": cfg.final_logit_softcapping, + "tie_word_embeddings": cfg.tie_word_embeddings, + "sliding_window": cfg.sliding_window, + "layer_types": cfg.layer_types, + } + + +def build_random_tiny_model() -> Gemma4_31B: + torch.manual_seed(42) + model = Gemma4_31B(TINY_CONFIG) + model.to(dtype=torch.bfloat16) + for p in model.parameters(): + if p.device.type != "meta": + p.data.normal_(0, 0.02) + model.eval() + return model + + +def save_checkpoint(output_dir: str): + model = build_random_tiny_model() + model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + quantized, unquantized = quantize_model(model, DEFAULT_RECIPE) + os.makedirs(output_dir, exist_ok=True) + save(quantized, unquantized, os.path.join(output_dir, "model.safetensors")) + with open(os.path.join(output_dir, "config.json"), "w") as f: + json.dump(config_dict(), f) + + +def build_hf_checkpoint(output_dir: str) -> None: + model = build_random_tiny_model() + sd = model.state_dict() + sd.pop("lm_head.weight", None) + hf_sd = {f"model.language_model.{k}": v.contiguous() for k, v in sd.items()} + save_file(hf_sd, os.path.join(output_dir, "model.safetensors")) + with open(os.path.join(output_dir, "config.json"), "w") as f: + json.dump(config_dict(), f) + + +# --------------------------------------------------------------------------- +# Tests (CPU only, no backend dependency) + + +class TestQuantizeSaveLoadRoundtrip(unittest.TestCase): + def test_roundtrip_preserves_weights(self): + """quantize → save → load recovers all weights and configs.""" + model = build_random_tiny_model() + model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + quantized, unquantized = quantize_model(model, DEFAULT_RECIPE) + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "model.safetensors") + save(quantized, unquantized, path) + q_loaded, u_loaded = load(path) + + self.assertEqual(set(quantized.keys()), set(q_loaded.keys())) + for fqn in quantized: + self.assertEqual(quantized[fqn].config, q_loaded[fqn].config) + self.assertTrue(torch.equal(quantized[fqn].qdata, q_loaded[fqn].qdata)) + self.assertTrue(torch.equal(quantized[fqn].scale, q_loaded[fqn].scale)) + + self.assertEqual(set(unquantized.keys()), set(u_loaded.keys())) + for fqn in unquantized: + self.assertTrue(torch.equal(unquantized[fqn], u_loaded[fqn])) + + def test_embedding_quantized_as_int8(self): + """embed_tokens is quantized to INT8 per-axis, not skipped.""" + model = build_random_tiny_model() + model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + quantized, unquantized = quantize_model(model, DEFAULT_RECIPE) + + self.assertIn("embed_tokens.weight", quantized) + self.assertNotIn("embed_tokens.weight", unquantized) + self.assertEqual(quantized["embed_tokens.weight"].config.bits, 8) + + def test_corrupted_checkpoint_detected(self): + """Renaming a key in the safetensors file causes a load-time error.""" + from safetensors import safe_open + + with tempfile.TemporaryDirectory() as tmpdir: + save_checkpoint(tmpdir) + path = os.path.join(tmpdir, "model.safetensors") + + with safe_open(path, framework="pt", device="cpu") as f: + header = f.metadata() + tensors = {k: f.get_tensor(k) for k in f.keys()} + tensors["norm.BOGUS"] = tensors.pop("norm.weight") + save_file(tensors, path, metadata=header) + + q, u = load(path) + # norm.weight is now missing from unquantized, norm.BOGUS is unexpected. + # pack_model would fail, but we can verify at the load level: + self.assertNotIn("norm.weight", u) + + +if __name__ == "__main__": + unittest.main() From cecb42d89bc84df502735fd2894a9827d9feb0e6 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Wed, 29 Apr 2026 14:48:34 -0700 Subject: [PATCH 02/14] Ring-buffer KV cache, chunked prefill, INT8 embedding, and cleanup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Sliding window layers use RingKVCache (2×window) instead of flat max_seq_len buffer, reducing KV cache memory for long sequences. - Prefill is capped to ring buffer size; the C++ runner chunks longer prompts automatically via get_max_prefill_chunk metadata. - Both recipes now quantize embed_tokens to INT8 per-axis (~1.4 GB savings vs bf16). Embedding packer uses IntxUnpackedToInt8Tensor which supports gather. - pack_model handles top-level FQNs (no parent module). - C++ runner aligned with Qwen patterns: #ifdef guards for non-CUDA builds, better weight_sharing error handling, cudaDeviceSynchronize between prefill and decode. - Test suite split into test_pipeline.py (CPU) and test_cuda_pipeline.py (CUDA) with shared fixtures. New chunked prefill correctness test. - Prequantized checkpoint available at huggingface.co/SocialLocalMobile/gemma-4-31B-it-HQQ-INT4. - Added Gemma 4 31B tests to cuda.yml CI workflow. - Cleaned up stale terminology, docstrings, and comments throughout. --- examples/models/gemma4_31b/README.md | 29 +++- examples/models/gemma4_31b/export.py | 5 +- examples/models/gemma4_31b/main.cpp | 141 +++++++++++------- examples/models/gemma4_31b/model.md | 8 +- examples/models/gemma4_31b/model.py | 3 + examples/models/gemma4_31b/quant/pack.py | 6 +- examples/models/gemma4_31b/quant/recipe.py | 2 +- .../models/gemma4_31b/quant/test_recipe.py | 16 +- .../models/gemma4_31b/quant/test_serialize.py | 2 +- .../models/gemma4_31b/test_cuda_pipeline.py | 51 +++++++ examples/models/gemma4_31b/test_pipeline.py | 7 +- 11 files changed, 195 insertions(+), 75 deletions(-) diff --git a/examples/models/gemma4_31b/README.md b/examples/models/gemma4_31b/README.md index 85fa67844be..3dcb958d8cc 100644 --- a/examples/models/gemma4_31b/README.md +++ b/examples/models/gemma4_31b/README.md @@ -32,17 +32,36 @@ Two built-in recipes (see `quantize_and_save.py`): | `default` | INT4 min_max linears, INT8 per-axis embedding | | `sensitive` | INT8 for edge-layer v_proj/down_proj, INT4 hqq elsewhere, INT8 per-axis embedding | -## Quantize once +## Prequantized checkpoint + +A prequantized checkpoint (sensitive recipe) is available on HuggingFace: + +```bash +huggingface-cli download SocialLocalMobile/gemma-4-31B-it-HQQ-INT4 --local-dir gemma-4-31B-it-HQQ-INT4 +``` + +> **Note**: This checkpoint is intended for development and testing of the +> ExecuTorch CUDA export pipeline. Output quality has not been formally +> evaluated against the base model. + +Use it directly with `--prequantized` in the export and inference scripts +below — no need to run `quantize_and_save.py`. + +## Quantize from scratch (optional) + +To quantize from the original bf16 checkpoint instead, pass +`--quant-recipe` to select a recipe (`default` or `sensitive`): ```bash python examples/models/gemma4_31b/quantize_and_save.py \ - --model-dir ~/local/scripts/models/gemma-4-31B-it \ + --model-dir /path/to/gemma-4-31B-it \ --output ./gemma4_31b_int4 \ - --quant-recipe default + --quant-recipe sensitive ``` -Writes `model.safetensors`, `config.json`, and -`tokenizer.json` into `--output`. +See [Quantization recipes](#quantization-recipes) above for details on each +recipe. Writes `model.safetensors`, `config.json`, and `tokenizer.json` into +`--output`. ## Export to ExecuTorch diff --git a/examples/models/gemma4_31b/export.py b/examples/models/gemma4_31b/export.py index 53e66fcd646..f2bf054015e 100644 --- a/examples/models/gemma4_31b/export.py +++ b/examples/models/gemma4_31b/export.py @@ -161,7 +161,9 @@ def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) - strict=True, ) - max_prefill = config.max_seq_len - 1 + # Cap prefill length to the ring-buffer KV cache size (2×sliding_window). + # Longer prompts are chunked by the runner. + max_prefill = min(config.max_seq_len - 1, config.sliding_window * 2) seq_dim = Dim("seq_len", min=2, max=max_prefill) print(f"Exporting prefill (T in [2, {max_prefill}])...") with torch.no_grad(): @@ -199,6 +201,7 @@ def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) - "get_max_seq_len": config.max_seq_len, "get_vocab_size": config.vocab_size, "get_n_layers": config.num_hidden_layers, + "get_max_prefill_chunk": max_prefill, "use_kv_cache": True, "use_sdpa_with_kv_cache": False, "enable_dynamic_shape": True, diff --git a/examples/models/gemma4_31b/main.cpp b/examples/models/gemma4_31b/main.cpp index 23526119e7b..b12e1b87db9 100644 --- a/examples/models/gemma4_31b/main.cpp +++ b/examples/models/gemma4_31b/main.cpp @@ -34,20 +34,20 @@ #include #endif -DEFINE_string(model_path, "", "Path to model.pte."); -DEFINE_string(data_path, "", "Path to model.ptd (CUDA tensor data)."); +DEFINE_string(model_path, "", "Model .pte file path."); +DEFINE_string(data_path, "", "Data file (.ptd) for CUDA backend."); DEFINE_string(tokenizer_path, "", "HuggingFace tokenizer.json path."); DEFINE_string(prompt, "Hello", "Prompt text."); DEFINE_string( prompt_file, "", - "Optional path to a file with the prompt text (overrides --prompt)."); + "Path to file containing prompt text (overrides --prompt)."); DEFINE_double(temperature, 0.8, "Sampling temperature (0 = near-greedy)."); DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate."); DEFINE_bool( cuda_graph, false, - "Enable CUDA graph capture for the decode method."); + "Enable CUDA graph capture for the decode method. CUDA only."); namespace llm = ::executorch::extension::llm; using ::executorch::extension::from_blob; @@ -57,8 +57,6 @@ using ::executorch::runtime::EValue; using SizesType = executorch::aten::SizesType; -// The model performs sampling on-device and returns a [B, 1] float tensor -// holding a token ID. Copy it to host and convert to uint64. static uint64_t read_token(const executorch::aten::Tensor& output) { const void* ptr = output.const_data_ptr(); float val = 0.0f; @@ -135,12 +133,14 @@ int main(int argc, char** argv) { /*temp_allocator=*/nullptr, /*share_memory_arenas=*/true); + // Get metadata auto metadata_result = llm::get_llm_metadata(tokenizer.get(), module.get()); if (metadata_result.error() != Error::Ok) { ET_LOG(Error, "Failed to read model metadata"); return 1; } +#ifdef EXECUTORCH_BUILD_CUDA if (FLAGS_cuda_graph) { executorch::runtime::BackendOptions<2> cuda_opts; cuda_opts.set_option("enable_cuda_graph_for_method", "decode"); @@ -154,14 +154,30 @@ int main(int argc, char** argv) { // load_method. { executorch::runtime::BackendOptions<1> backend_options; - if (backend_options.set_option("weight_sharing_across_methods", true) != - Error::Ok || - executorch::runtime::set_option( - "CudaBackend", backend_options.view()) != Error::Ok) { - ET_LOG(Error, "Failed to enable weight_sharing_across_methods"); + auto set_err = + backend_options.set_option("weight_sharing_across_methods", true); + if (set_err != Error::Ok) { + ET_LOG( + Error, + "Failed to construct weight_sharing_across_methods option: %d", + static_cast(set_err)); + return 1; + } + auto opt_err = + executorch::runtime::set_option("CudaBackend", backend_options.view()); + if (opt_err != Error::Ok) { + ET_LOG( + Error, + "Failed to enable weight_sharing_across_methods: %d", + static_cast(opt_err)); return 1; } } +#else + if (FLAGS_cuda_graph) { + ET_LOG(Info, "--cuda_graph ignored on non-CUDA build"); + } +#endif printf("Loading methods...\n"); if (module->load_method("prefill") != Error::Ok) { @@ -181,6 +197,7 @@ int main(int argc, char** argv) { auto eos_ids = llm::get_eos_ids(tokenizer.get(), module.get()); + // Read prompt from file or flag std::string prompt_text = FLAGS_prompt; if (!FLAGS_prompt_file.empty()) { std::ifstream f(FLAGS_prompt_file); @@ -189,10 +206,11 @@ int main(int argc, char** argv) { Error, "Failed to open prompt file: %s", FLAGS_prompt_file.c_str()); return 1; } - prompt_text.assign( + prompt_text = std::string( (std::istreambuf_iterator(f)), std::istreambuf_iterator()); } + // Encode prompt auto encode_result = tokenizer->encode(prompt_text); if (!encode_result.ok()) { ET_LOG(Error, "Failed to encode prompt"); @@ -207,49 +225,66 @@ int main(int argc, char** argv) { auto S = [](int64_t v) -> SizesType { return static_cast(v); }; - // Temperature: clamp 0 to a tiny epsilon so the divide in the exported - // sampler stays well-defined. Gumbel noise then becomes negligible - // relative to logit gaps and we get effectively-greedy sampling. +#ifdef EXECUTORCH_BUILD_CUDA + // CUDA build: model fuses the sampler. Pass temperature as a third input. float temp_val = FLAGS_temperature <= 0.0 ? 1e-6f : static_cast(FLAGS_temperature); auto temp_tensor = from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float); +#endif // --------------------------------------------------------------- - // Prefill + // Prefill (chunked to respect ring-buffer KV cache limit) // --------------------------------------------------------------- - std::string run_method = "prefill"; - if (num_prompt_tokens == 1) { - // prefill was exported with min seq_len=2; decode handles T==1. - run_method = "decode"; + // Sliding layers use a ring buffer sized to 2×sliding_window. A single + // prefill call must not exceed this size, otherwise index_copy_ with + // wrapped indices produces non-deterministic results on CUDA. + int64_t max_prefill_chunk = (*metadata_result)[llm::kMaxSeqLen] - 1; + { + auto get_result = module->get("get_max_prefill_chunk"); + if (get_result.ok()) { + max_prefill_chunk = get_result->toScalar().to(); + } } - std::vector token_data(prompt_tokens.begin(), prompt_tokens.end()); - std::vector pos_data(num_prompt_tokens); - for (int64_t i = 0; i < num_prompt_tokens; i++) { - pos_data[i] = i; - } - auto tokens_tensor = from_blob( - token_data.data(), - {1, S(num_prompt_tokens)}, - executorch::aten::ScalarType::Long); - auto pos_tensor = from_blob( - pos_data.data(), - {S(num_prompt_tokens)}, - executorch::aten::ScalarType::Long); - - std::vector prefill_inputs = { - EValue(tokens_tensor), - EValue(pos_tensor), - EValue(temp_tensor), - }; - - auto prefill_result = module->execute(run_method, prefill_inputs); - if (prefill_result.error() != Error::Ok) { - ET_LOG(Error, "%s failed", run_method.c_str()); - return 1; + uint64_t cur_token = 0; + int64_t prefill_pos = 0; + while (prefill_pos < num_prompt_tokens) { + int64_t chunk_len = + std::min(num_prompt_tokens - prefill_pos, max_prefill_chunk); + + std::string run_method = (chunk_len == 1) ? "decode" : "prefill"; + + std::vector token_data( + prompt_tokens.begin() + prefill_pos, + prompt_tokens.begin() + prefill_pos + chunk_len); + std::vector pos_data(chunk_len); + for (int64_t i = 0; i < chunk_len; i++) { + pos_data[i] = prefill_pos + i; + } + auto tokens_tensor = from_blob( + token_data.data(), + {1, S(chunk_len)}, + executorch::aten::ScalarType::Long); + auto pos_tensor = from_blob( + pos_data.data(), {S(chunk_len)}, executorch::aten::ScalarType::Long); + + std::vector prefill_inputs; + prefill_inputs.push_back(EValue(tokens_tensor)); + prefill_inputs.push_back(EValue(pos_tensor)); +#ifdef EXECUTORCH_BUILD_CUDA + prefill_inputs.push_back(EValue(temp_tensor)); +#endif + + auto prefill_result = module->execute(run_method, prefill_inputs); + if (prefill_result.error() != Error::Ok) { + ET_LOG( + Error, "%s failed at pos %" PRId64, run_method.c_str(), prefill_pos); + return 1; + } + cur_token = read_token(prefill_result.get()[0].toTensor()); + prefill_pos += chunk_len; } - uint64_t cur_token = read_token(prefill_result.get()[0].toTensor()); stats.prompt_eval_end_ms = llm::time_in_ms(); double prefill_ms = @@ -261,8 +296,9 @@ int main(int argc, char** argv) { num_prompt_tokens * 1000.0 / prefill_ms); #ifdef EXECUTORCH_BUILD_CUDA - // Make prefill's writes to the shared KV cache visible before decode - // potentially runs on a different stream. + // Synchronize CUDA device to ensure prefill's writes to shared mutable + // buffers (KV cache) are visible to the decode method, which may run on + // a different CUDA stream. cudaDeviceSynchronize(); #endif @@ -282,11 +318,12 @@ int main(int argc, char** argv) { decode_token_data[0] = static_cast(cur_token); decode_pos_data[0] = pos; - std::vector decode_inputs = { - EValue(decode_tokens), - EValue(decode_pos), - EValue(temp_tensor), - }; + std::vector decode_inputs; + decode_inputs.push_back(EValue(decode_tokens)); + decode_inputs.push_back(EValue(decode_pos)); +#ifdef EXECUTORCH_BUILD_CUDA + decode_inputs.push_back(EValue(temp_tensor)); +#endif auto decode_result = module->execute("decode", decode_inputs); if (decode_result.error() != Error::Ok) { diff --git a/examples/models/gemma4_31b/model.md b/examples/models/gemma4_31b/model.md index c6a20d5c306..1f8ccb16f6f 100644 --- a/examples/models/gemma4_31b/model.md +++ b/examples/models/gemma4_31b/model.md @@ -105,7 +105,7 @@ Decoder norms per layer: `input_layernorm`, `post_attention_layernorm`, | Method | Input | Output (sampled) | |-----------|------------------------------------------------------------|------------------| | `decode` | tokens `(1, 1)` + input_pos `(1,)` + temperature `(1,)` | `(1, 1)` float | -| `prefill` | tokens `(1, T)` + input_pos `(T,)` + temperature `(1,)`, T∈[2, max_seq_len-1] | `(1, 1)` float | +| `prefill` | tokens `(1, T)` + input_pos `(T,)` + temperature `(1,)`, T∈[2, min(max_seq_len-1, 2×sliding_window)] | `(1, 1)` float | Both methods share the same KV-cache buffers via `MemoryPlanningPass(share_mutable_buffers=True)` and @@ -113,6 +113,12 @@ Both methods share the same KV-cache buffers via sampling on-device and returns a single token ID per call so the C++ runner only has to feed tokens. +Prefill length is capped to the ring-buffer KV cache size +(`2 × sliding_window`) to avoid duplicate wrapped indices in +`index_copy_`. The C++ runner chunks longer prompts automatically using +the `get_max_prefill_chunk` constant method. Chunked prefill produces +identical logits to sequential one-token-at-a-time prefill. + ## Quantization Three modules in `quant/`: diff --git a/examples/models/gemma4_31b/model.py b/examples/models/gemma4_31b/model.py index f9d4d4c9060..7366a57bf46 100644 --- a/examples/models/gemma4_31b/model.py +++ b/examples/models/gemma4_31b/model.py @@ -89,6 +89,9 @@ def update( k_val: torch.Tensor, v_val: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: + # seq_len must not exceed buf_size, otherwise wrapped indices contain + # duplicates and index_copy_ is non-deterministic on CUDA. The C++ + # runner must chunk prefill to respect this limit. wrapped = input_pos % self.buf_size self.k_cache.index_copy_(2, wrapped, k_val) self.v_cache.index_copy_(2, wrapped, v_val) diff --git a/examples/models/gemma4_31b/quant/pack.py b/examples/models/gemma4_31b/quant/pack.py index 5a1e792b56d..544e96287e9 100644 --- a/examples/models/gemma4_31b/quant/pack.py +++ b/examples/models/gemma4_31b/quant/pack.py @@ -61,11 +61,13 @@ def pack_model( module_weights: dict[str, dict[str, CanonicalQuantizedWeight]] = defaultdict(dict) for fqn, cw in quantized.items(): - parent_fqn, attr = fqn.rsplit(".", 1) + parts = fqn.rsplit(".", 1) + parent_fqn = parts[0] if len(parts) > 1 else "" + attr = parts[-1] module_weights[parent_fqn][attr] = cw for parent_fqn, weights in module_weights.items(): - module = model.get_submodule(parent_fqn) + module = model.get_submodule(parent_fqn) if parent_fqn else model packer = packers.get(type(module)) if packer is None: raise ValueError( diff --git a/examples/models/gemma4_31b/quant/recipe.py b/examples/models/gemma4_31b/quant/recipe.py index 0edb0491640..49294c9b579 100644 --- a/examples/models/gemma4_31b/quant/recipe.py +++ b/examples/models/gemma4_31b/quant/recipe.py @@ -20,7 +20,7 @@ class QuantConfig: """Per-weight quantization parameters.""" - bits: int # 4, 6, 8 + bits: int # 4, 8 group_size: int # 32, 64, 128 symmetric: bool # True = no zero point method: str # "min_max" | "hqq" diff --git a/examples/models/gemma4_31b/quant/test_recipe.py b/examples/models/gemma4_31b/quant/test_recipe.py index 6bd04a936a3..5b7afd992e0 100644 --- a/examples/models/gemma4_31b/quant/test_recipe.py +++ b/examples/models/gemma4_31b/quant/test_recipe.py @@ -13,7 +13,7 @@ from .recipe import QuantConfig, QuantRecipe, QuantRule _Q4 = QuantConfig(4, 32, True, "min_max") -_Q6 = QuantConfig(6, 32, False, "min_max") +_Q8 = QuantConfig(8, 32, True, "min_max") class TestQuantRecipeGetConfig(unittest.TestCase): @@ -23,13 +23,13 @@ class TestQuantRecipeGetConfig(unittest.TestCase): [ ( "first_match_wins", - [QuantRule(r".*v_proj\.weight", _Q6), QuantRule(r".*\.weight", _Q4)], + [QuantRule(r".*v_proj\.weight", _Q8), QuantRule(r".*\.weight", _Q4)], "layers.0.self_attn.v_proj.weight", - 6, + 8, ), ( "fallthrough_to_catchall", - [QuantRule(r".*v_proj\.weight", _Q6), QuantRule(r".*\.weight", _Q4)], + [QuantRule(r".*v_proj\.weight", _Q8), QuantRule(r".*\.weight", _Q4)], "layers.0.self_attn.q_proj.weight", 4, ), @@ -85,13 +85,13 @@ def test_layer_filter(self): recipe = QuantRecipe( rules=[ QuantRule(r".*norm\.weight", None), - QuantRule(r".*\.(v_proj|down_proj)\.weight", _Q6, layers=edge), + QuantRule(r".*\.(v_proj|down_proj)\.weight", _Q8, layers=edge), QuantRule(r".*\.weight", _Q4), ] ) - # Edge v_proj → 6-bit - self.assertEqual(recipe.get_config("layers.0.self_attn.v_proj.weight").bits, 6) - self.assertEqual(recipe.get_config("layers.58.self_attn.v_proj.weight").bits, 6) + # Edge v_proj → 8-bit + self.assertEqual(recipe.get_config("layers.0.self_attn.v_proj.weight").bits, 8) + self.assertEqual(recipe.get_config("layers.58.self_attn.v_proj.weight").bits, 8) # Middle v_proj → falls through → 4-bit self.assertEqual(recipe.get_config("layers.30.self_attn.v_proj.weight").bits, 4) # q_proj always 4-bit diff --git a/examples/models/gemma4_31b/quant/test_serialize.py b/examples/models/gemma4_31b/quant/test_serialize.py index 302c38647ed..d84e53d0a0b 100644 --- a/examples/models/gemma4_31b/quant/test_serialize.py +++ b/examples/models/gemma4_31b/quant/test_serialize.py @@ -7,7 +7,7 @@ """Unit tests for quant/serialize.py — data format and I/O only. Tests nibble pack/unpack and save/load. Does NOT test -quantize_weight (that lives in test_quantize.py). Bundle tests use +quantize_weight (that lives in test_quantize.py). Save/load tests use hand-built CanonicalQuantizedWeight fixtures to avoid coupling to the quantizer. """ diff --git a/examples/models/gemma4_31b/test_cuda_pipeline.py b/examples/models/gemma4_31b/test_cuda_pipeline.py index a7df5d9818c..faae59f160f 100644 --- a/examples/models/gemma4_31b/test_cuda_pipeline.py +++ b/examples/models/gemma4_31b/test_cuda_pipeline.py @@ -79,6 +79,57 @@ def test_generate(self): self.assertGreater(len(out_greedy), 0) +class TestChunkedPrefill(unittest.TestCase): + """Verify that chunked prefill matches one-token-at-a-time prefill.""" + + def setUp(self): + _require_cuda(self) + + def test_chunked_prefill_matches_sequential(self): + """Long prompt chunked across ring buffer gives same logits as sequential.""" + with tempfile.TemporaryDirectory() as tmpdir: + save_checkpoint(tmpdir) + model_seq, config = load_prequantized_model( + tmpdir, max_seq_len=TINY_CONFIG.max_seq_len + ) + model_chunk, _ = load_prequantized_model( + tmpdir, max_seq_len=TINY_CONFIG.max_seq_len + ) + + _move_to_cuda(model_seq, config) + _move_to_cuda(model_chunk, config) + model_seq.eval() + model_chunk.eval() + + buf_size = config.sliding_window * 2 + prompt_len = buf_size + 8 # exceeds buf_size + torch.manual_seed(0) + prompt = torch.randint(0, config.vocab_size, (1, prompt_len), device="cuda") + + # Sequential: one token at a time (temperature=None returns logits) + with torch.no_grad(): + for i in range(prompt_len): + tok = prompt[:, i : i + 1] + pos = torch.tensor([i], dtype=torch.long, device="cuda") + logits_seq = model_seq(tok, pos, None) + + # Chunked: two chunks respecting buf_size + with torch.no_grad(): + chunk1 = prompt[:, :buf_size] + pos1 = torch.arange(buf_size, dtype=torch.long, device="cuda") + model_chunk(chunk1, pos1, None) + + chunk2 = prompt[:, buf_size:] + pos2 = torch.arange(buf_size, prompt_len, dtype=torch.long, device="cuda") + logits_chunk = model_chunk(chunk2, pos2, None) + + # Compare last-token logits (skip sampling to avoid RNG differences) + self.assertTrue( + torch.equal(logits_seq[0, -1], logits_chunk[0, -1]), + f"Max diff: {(logits_seq[0, -1] - logits_chunk[0, -1]).abs().max()}", + ) + + class TestCudaExport(unittest.TestCase): def setUp(self): _require_cuda(self) diff --git a/examples/models/gemma4_31b/test_pipeline.py b/examples/models/gemma4_31b/test_pipeline.py index 317c1b3a0ff..85ec3a671f0 100644 --- a/examples/models/gemma4_31b/test_pipeline.py +++ b/examples/models/gemma4_31b/test_pipeline.py @@ -186,8 +186,8 @@ def test_embedding_quantized_as_int8(self): self.assertNotIn("embed_tokens.weight", unquantized) self.assertEqual(quantized["embed_tokens.weight"].config.bits, 8) - def test_corrupted_checkpoint_detected(self): - """Renaming a key in the safetensors file causes a load-time error.""" + def test_corrupted_checkpoint_missing_key(self): + """Renaming a key in the safetensors file makes it absent after load.""" from safetensors import safe_open with tempfile.TemporaryDirectory() as tmpdir: @@ -201,9 +201,8 @@ def test_corrupted_checkpoint_detected(self): save_file(tensors, path, metadata=header) q, u = load(path) - # norm.weight is now missing from unquantized, norm.BOGUS is unexpected. - # pack_model would fail, but we can verify at the load level: self.assertNotIn("norm.weight", u) + self.assertIn("norm.BOGUS", u) if __name__ == "__main__": From 8768019ae00ff1bfc2dceedf9e53022873c1c7e2 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Thu, 30 Apr 2026 06:59:54 -0700 Subject: [PATCH 03/14] Add GGUF import, dequantize_weight, pack_one, and test reorganization - quant/gguf.py: unpack Q4_K/Q6_K GGUF blocks to CanonicalQuantizedWeight, with iter_gguf_tensors for streaming (low peak memory). Validated against original bf16 weights (Q4_K: 7.9%, Q6_K: 1.9% error). - gguf_loader.py: Gemma 4 31B GGUF key mapping + load_gguf_model. Handles tied embed/lm_head: embedding dequantized to bf16 (gather), lm_head keeps Q4_K (tinygemm matmul). - export.py and inference.py: --gguf flag for direct GGUF file loading. - quant/quantize.py: dequantize_weight (inverse of quantize_weight). - quant/pack.py: pack_one for single-weight streaming; pack_model delegates to pack_one for unquantized, groups quantized by parent for multi-weight modules (MoE-compatible). - quant/serialize.py: CanonicalQuantizedWeight.__post_init__ validation (dtype, shape, symmetric/zero consistency). - Tests moved to tests/ folders (quant/tests/ and tests/). --- .github/workflows/cuda.yml | 2 +- examples/models/gemma4_31b/README.md | 10 + examples/models/gemma4_31b/export.py | 14 +- examples/models/gemma4_31b/gguf_loader.py | 165 ++++++++++++ examples/models/gemma4_31b/inference.py | 69 +++-- examples/models/gemma4_31b/model.md | 23 +- examples/models/gemma4_31b/quant/README.md | 17 +- examples/models/gemma4_31b/quant/__init__.py | 4 +- examples/models/gemma4_31b/quant/gguf.py | 219 ++++++++++++++++ examples/models/gemma4_31b/quant/pack.py | 56 ++-- examples/models/gemma4_31b/quant/quantize.py | 17 ++ examples/models/gemma4_31b/quant/serialize.py | 20 ++ .../gemma4_31b/quant/tests/test_gguf.py | 239 ++++++++++++++++++ .../quant/{ => tests}/test_pack_cuda.py | 79 +++++- .../quant/{ => tests}/test_quantize.py | 52 ++-- .../quant/{ => tests}/test_recipe.py | 8 +- .../quant/{ => tests}/test_serialize.py | 6 +- .../{ => tests}/test_cuda_pipeline.py | 2 +- .../gemma4_31b/{ => tests}/test_pipeline.py | 0 19 files changed, 917 insertions(+), 85 deletions(-) create mode 100644 examples/models/gemma4_31b/gguf_loader.py create mode 100644 examples/models/gemma4_31b/quant/gguf.py create mode 100644 examples/models/gemma4_31b/quant/tests/test_gguf.py rename examples/models/gemma4_31b/quant/{ => tests}/test_pack_cuda.py (82%) rename examples/models/gemma4_31b/quant/{ => tests}/test_quantize.py (80%) rename examples/models/gemma4_31b/quant/{ => tests}/test_recipe.py (98%) rename examples/models/gemma4_31b/quant/{ => tests}/test_serialize.py (98%) rename examples/models/gemma4_31b/{ => tests}/test_cuda_pipeline.py (98%) rename examples/models/gemma4_31b/{ => tests}/test_pipeline.py (100%) diff --git a/.github/workflows/cuda.yml b/.github/workflows/cuda.yml index d1b954820ef..e9d4d3db89e 100644 --- a/.github/workflows/cuda.yml +++ b/.github/workflows/cuda.yml @@ -149,7 +149,7 @@ jobs: python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py examples/models/qwen3_5_moe/test_sampler.py -v -o "addopts=" # Run Gemma 4 31B tests (quant unit tests + pipeline integration tests) - python -m pytest examples/models/gemma4_31b/quant/ examples/models/gemma4_31b/test_pipeline.py examples/models/gemma4_31b/test_cuda_pipeline.py -v -o "addopts=" + python -m pytest examples/models/gemma4_31b/quant/tests/ examples/models/gemma4_31b/tests/ -v -o "addopts=" export-model-cuda-artifact: name: export-model-cuda-artifact diff --git a/examples/models/gemma4_31b/README.md b/examples/models/gemma4_31b/README.md index 3dcb958d8cc..2b1c92e31d8 100644 --- a/examples/models/gemma4_31b/README.md +++ b/examples/models/gemma4_31b/README.md @@ -16,6 +16,7 @@ both export and eager inference: | `quantize_and_save.py` | bf16 HF checkpoint → quantized checkpoint (one-time) | ~30 GB CPU | | `export.py --prequantized ` | quantized checkpoint → `model.pte` + `model.ptd` | ~24 GB CPU + CUDA for packing | | `inference.py --prequantized ` | quantized checkpoint → eager generation under `torch.compile` | ~24 GB GPU | +| `inference.py --gguf ` | GGUF file (Q4_K_M, etc.) → eager generation | ~24 GB GPU | | `export.py --model-dir ` | one-shot bf16 → quantize → export (no intermediate file) | ~30 GB CPU + CUDA for packing | The quantized checkpoint is a safetensors file with int values + per-group @@ -85,6 +86,15 @@ python examples/models/gemma4_31b/inference.py \ --temperature 0.8 ``` +GGUF files from the community (e.g., Q4_K_M) can also be used directly: + +```bash +python examples/models/gemma4_31b/inference.py \ + --gguf ./gemma-4-31B-it-Q4_K_M.gguf \ + --tokenizer-path /path/to/tokenizer.json \ + --prompt "Hello" +``` + Useful before spending the export+lowering time to confirm the quantized model produces sensible text. diff --git a/examples/models/gemma4_31b/export.py b/examples/models/gemma4_31b/export.py index f2bf054015e..69eabd6b70c 100644 --- a/examples/models/gemma4_31b/export.py +++ b/examples/models/gemma4_31b/export.py @@ -10,9 +10,10 @@ - "decode": T=1, static shape, returns the next sampled token. - "prefill": T>=2, dynamic shape, returns the next sampled token. -Two input paths: +Three input paths: --prequantized Load a quantized checkpoint (from quantize_and_save.py) and pack for the target backend. No re-quantization. + --gguf Load a GGUF file (e.g., Q4_K_M from the community). --model-dir Load bf16 checkpoint, quantize, pack, and export in one shot. @@ -251,6 +252,11 @@ def main() -> None: default=None, help="Path to a quantized checkpoint directory. Skips quantization.", ) + src.add_argument( + "--gguf", + default=None, + help="Path to a GGUF file (e.g., gemma-4-31B-it-Q4_K_M.gguf).", + ) parser.add_argument( "--output-dir", default="./gemma4_31b_exports", @@ -285,6 +291,12 @@ def main() -> None: max_seq_len=args.max_seq_len, backend=args.backend, ) + elif args.gguf: + from executorch.examples.models.gemma4_31b.gguf_loader import load_gguf_model + + model, config = load_gguf_model( + args.gguf, max_seq_len=args.max_seq_len, backend=args.backend + ) else: model, config = load_and_quantize( args.model_dir, diff --git a/examples/models/gemma4_31b/gguf_loader.py b/examples/models/gemma4_31b/gguf_loader.py new file mode 100644 index 00000000000..c55e15be888 --- /dev/null +++ b/examples/models/gemma4_31b/gguf_loader.py @@ -0,0 +1,165 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Load a GGUF file into a Gemma 4 31B model. + +Streams tensors one at a time via ``iter_gguf_tensors`` for low peak +memory, remaps GGUF names to model FQNs, handles tied embed/lm_head, +and packs for the target backend. + +Usage: + model, config = load_gguf_model("model.gguf", backend="cuda") +""" + +from typing import Optional + +import torch + +# GGUF pattern → model FQN pattern. ``{}`` is the layer index. +_KEY_MAP = { + "token_embd.weight": "embed_tokens.weight", + "output_norm.weight": "norm.weight", + # Per-layer attention + "blk.{}.attn_q.weight": "layers.{}.self_attn.q_proj.weight", + "blk.{}.attn_k.weight": "layers.{}.self_attn.k_proj.weight", + "blk.{}.attn_v.weight": "layers.{}.self_attn.v_proj.weight", + "blk.{}.attn_output.weight": "layers.{}.self_attn.o_proj.weight", + "blk.{}.attn_q_norm.weight": "layers.{}.self_attn.q_norm.weight", + "blk.{}.attn_k_norm.weight": "layers.{}.self_attn.k_norm.weight", + # Per-layer norms + "blk.{}.attn_norm.weight": "layers.{}.input_layernorm.weight", + "blk.{}.post_attention_norm.weight": "layers.{}.post_attention_layernorm.weight", + "blk.{}.ffn_norm.weight": "layers.{}.pre_feedforward_layernorm.weight", + "blk.{}.post_ffw_norm.weight": "layers.{}.post_feedforward_layernorm.weight", + # Per-layer MLP + "blk.{}.ffn_gate.weight": "layers.{}.mlp.gate_proj.weight", + "blk.{}.ffn_up.weight": "layers.{}.mlp.up_proj.weight", + "blk.{}.ffn_down.weight": "layers.{}.mlp.down_proj.weight", + # Per-layer scalar + "blk.{}.layer_output_scale.weight": "layers.{}.layer_scalar", +} + +_IGNORED_KEYS = {"rope_freqs.weight"} + + +def gguf_to_model_key(gguf_key: str) -> Optional[str]: + """Map a GGUF tensor name to a model FQN, or ``None`` to skip.""" + if gguf_key in _IGNORED_KEYS: + return None + + for gguf_pat, model_pat in _KEY_MAP.items(): + if "{}" not in gguf_pat: + if gguf_key == gguf_pat: + return model_pat + continue + prefix, suffix = gguf_pat.split("{}") + if gguf_key.startswith(prefix) and gguf_key.endswith(suffix): + layer_str = gguf_key[len(prefix) : len(gguf_key) - len(suffix)] + if layer_str.isdigit(): + return model_pat.replace("{}", layer_str) + + return None + + +def _resolve_tied_lm_head(model, embed_cw, packers): + """Handle tied embed/lm_head after streaming all tensors.""" + from executorch.examples.models.gemma4_31b.quant import pack_one + + lm_head = getattr(model.lm_head, "weight", None) + if lm_head is None or lm_head.device.type != "meta": + return + if embed_cw is not None: + pack_one(model, "lm_head.weight", embed_cw, packers) + else: + pack_one( + model, + "lm_head.weight", + model.embed_tokens.weight.data.clone(), + packers, + ) + + +def _validate_no_meta(model): + """Ensure all parameters have been loaded.""" + for fqn, p in model.named_parameters(): + if p.device.type == "meta": + raise RuntimeError( + f"Weight '{fqn}' not found in GGUF file " + f"(model/checkpoint version mismatch?)" + ) + for p in model.parameters(): + p.requires_grad_(False) + + +def load_gguf_model( + gguf_path: str, + max_seq_len: int = 4096, + backend: str = "cuda", +) -> tuple: + """Load a GGUF file, remap keys, and pack for the target backend. + + Streams tensors one at a time for low peak memory. + + GGUF ties ``embed_tokens`` and ``lm_head`` into a single Q4_K tensor. + We untie them: the embedding is dequantized to bf16 (``nn.Embedding`` + needs gather, which ``Int4TilePackedTo4dTensor`` does not support), + while ``lm_head`` keeps the original Q4_K quantization (``nn.Linear`` + matmul via tinygemm). + + Returns ``(model, config)``. + """ + from executorch.examples.models.gemma4_31b.model import Gemma4_31B, Gemma4_31BConfig + from executorch.examples.models.gemma4_31b.quant import dequantize_weight, pack_one + from executorch.examples.models.gemma4_31b.quant.gguf import iter_gguf_tensors + from executorch.examples.models.gemma4_31b.quant.serialize import ( + CanonicalQuantizedWeight, + ) + + if backend == "cuda": + from executorch.examples.models.gemma4_31b.quant import DEFAULT_CUDA_PACKERS + + packers = DEFAULT_CUDA_PACKERS + else: + raise ValueError(f"Unsupported backend: {backend!r}. Supported: 'cuda'.") + + config = Gemma4_31BConfig(max_seq_len=max_seq_len) + + print("Building model on meta device...") + with torch.device("meta"): + model = Gemma4_31B(config) + + embed_cw = None + n_processed = 0 + + print(f"Streaming GGUF from {gguf_path}...") + for gguf_name, result in iter_gguf_tensors(gguf_path): + model_key = gguf_to_model_key(gguf_name) + if model_key is None: + continue + + if isinstance(result, torch.Tensor) and result.dtype == torch.float32: + result = result.to(torch.bfloat16) + + if model_key == "embed_tokens.weight" and isinstance( + result, CanonicalQuantizedWeight + ): + embed_cw = result + result = dequantize_weight(result, torch.bfloat16) + + pack_one(model, model_key, result, packers) + + n_processed += 1 + if n_processed % 100 == 0: + print(f" Processed {n_processed} tensors...") + + _resolve_tied_lm_head(model, embed_cw, packers) + del embed_cw + + _validate_no_meta(model) + model.eval() + + print(f"Model: {config.num_hidden_layers} layers, hidden={config.hidden_size}") + return model, config diff --git a/examples/models/gemma4_31b/inference.py b/examples/models/gemma4_31b/inference.py index 59418f3b746..a5c51ac66a9 100644 --- a/examples/models/gemma4_31b/inference.py +++ b/examples/models/gemma4_31b/inference.py @@ -4,13 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""Eager inference on a prequantized Gemma 4 31B-IT model (CUDA + torch.compile). +"""Eager inference on Gemma 4 31B-IT (CUDA + torch.compile). -Loads a quantized checkpoint (from ``quantize_and_save.py``), packs for CUDA, -materializes runtime buffers, optionally compiles with ``torch.compile``, and -generates text autoregressively. The model performs Gumbel-max sampling -on-device, so each forward returns the next token ID as a float tensor of -shape ``[B, 1]``. +Two input paths: + --prequantized Load a quantized checkpoint (from quantize_and_save.py). + --gguf Load a GGUF file (e.g., Q4_K_M from the community). + +Packs for the target backend (--backend cuda), materializes runtime buffers, +optionally compiles with ``torch.compile``, and generates text autoregressively. Usage: python inference.py \\ @@ -18,6 +19,11 @@ --prompt "Write a short joke about saving RAM." \\ --max-new-tokens 128 \\ --temperature 0.8 + + python inference.py \\ + --gguf ./gemma-4-31B-it-Q4_K_M.gguf \\ + --tokenizer-path ./tokenizer.json \\ + --prompt "Hello" """ import argparse @@ -113,14 +119,23 @@ def generate( def main() -> None: - parser = argparse.ArgumentParser( - description="Eager inference on prequantized Gemma 4 31B-IT (CUDA)." - ) - parser.add_argument( + parser = argparse.ArgumentParser(description="Eager inference on Gemma 4 31B-IT.") + src = parser.add_mutually_exclusive_group(required=True) + src.add_argument( "--prequantized", - required=True, + default=None, help="Path to a quantized checkpoint directory.", ) + src.add_argument( + "--gguf", + default=None, + help="Path to a GGUF file (e.g., gemma-4-31B-it-Q4_K_M.gguf).", + ) + parser.add_argument( + "--tokenizer-path", + default=None, + help="Path to tokenizer.json (required with --gguf, optional with --prequantized).", + ) parser.add_argument("--prompt", default="Hello", help="Input prompt.") parser.add_argument( "--max-new-tokens", @@ -145,15 +160,28 @@ def main() -> None: action="store_true", help="Skip torch.compile (slower, but easier to debug).", ) + parser.add_argument( + "--backend", + default="cuda", + choices=["cuda"], + help="Target backend.", + ) args = parser.parse_args() - if not torch.cuda.is_available(): - parser.error("CUDA is required for inference.") + if args.backend == "cuda" and not torch.cuda.is_available(): + parser.error("CUDA is required for the cuda backend.") - print(f"Loading prequantized model from {args.prequantized}...") - model, config = load_prequantized_model( - args.prequantized, max_seq_len=args.max_seq_len - ) + if args.gguf: + from executorch.examples.models.gemma4_31b.gguf_loader import load_gguf_model + + model, config = load_gguf_model( + args.gguf, args.max_seq_len, backend=args.backend + ) + else: + print(f"Loading prequantized model from {args.prequantized}...") + model, config = load_prequantized_model( + args.prequantized, max_seq_len=args.max_seq_len, backend=args.backend + ) _move_to_cuda(model, config) model.eval() @@ -161,7 +189,12 @@ def main() -> None: print("Compiling model with torch.compile...") model = torch.compile(model, mode="default") - tokenizer_path = os.path.join(args.prequantized, "tokenizer.json") + if args.tokenizer_path: + tokenizer_path = args.tokenizer_path + elif args.prequantized: + tokenizer_path = os.path.join(args.prequantized, "tokenizer.json") + else: + parser.error("--tokenizer-path is required with --gguf.") from tokenizers import Tokenizer tokenizer = Tokenizer.from_file(tokenizer_path) diff --git a/examples/models/gemma4_31b/model.md b/examples/models/gemma4_31b/model.md index 1f8ccb16f6f..622ff657b42 100644 --- a/examples/models/gemma4_31b/model.md +++ b/examples/models/gemma4_31b/model.md @@ -121,20 +121,19 @@ identical logits to sequential one-token-at-a-time prefill. ## Quantization -Three modules in `quant/`: +Modules in `quant/`: -- **Recipe** (`recipe.py`): `QuantConfig` (bits, group_size, symmetric, - method) + `QuantRule` (regex pattern, config, optional layer filter) + - `QuantRecipe` (ordered rules, first match wins). Declares what to - quantize and how — says nothing about packing or backends. +- **Recipe** (`recipe.py`): `QuantConfig` + `QuantRule` + `QuantRecipe`. + Declares what to quantize — says nothing about packing or backends. +- **Quantize** (`quantize.py`): `quantize_weight` / `dequantize_weight` / + `quantize_model`. Produces `CanonicalQuantizedWeight` from fp weights. - **Serialize** (`serialize.py`): `CanonicalQuantizedWeight` (int8 qdata + - bf16 scale + optional zero). `save()` / `load()` persist to safetensors - with a JSON header per weight. Packing-agnostic — any backend can read - the file. -- **Packer** (`pack_cuda.py`): converts `CanonicalQuantizedWeight` to - backend runtime format at load time via `pack_model()`. Dispatches per - parent module type (`nn.Linear` → `Int4TilePackedTo4dTensor` for - tinygemm). Extensible via a packers dict. + bf16 scale + optional zero). `save()` / `load()` persist to safetensors. +- **Pack** (`pack.py` + `pack_cuda.py`): `pack_model` groups weights by + parent module, `pack_one` handles single weights. Per-module packers + dispatch by module type (`nn.Linear`, `nn.Embedding`, extensible for MoE). +- **GGUF** (`gguf.py`): `unpack_gguf_tensor` / `iter_gguf_tensors` for + loading community-quantized GGUF files (Q4_K, Q6_K). The quantize-once flow: diff --git a/examples/models/gemma4_31b/quant/README.md b/examples/models/gemma4_31b/quant/README.md index 01a74434487..f14d02af2cf 100644 --- a/examples/models/gemma4_31b/quant/README.md +++ b/examples/models/gemma4_31b/quant/README.md @@ -7,10 +7,11 @@ Packing-agnostic quantization framework: **recipe → quantize → serialize → | File | Concern | Depends on | |---|---|---| | `recipe.py` | **Policy** — what to quantize, what precision, which layers | nothing | -| `quantize.py` | **Computation** — produces canonical weights from fp weights | recipe, torchao | +| `quantize.py` | **Computation** — produces/dequantizes canonical weights | recipe, torchao | | `serialize.py` | **Data format** — saves/loads canonical weights to safetensors | recipe | -| `pack.py` | **Packing dispatch** — walks model, dispatches to per-module packers | serialize | +| `pack.py` | **Packing dispatch** — `pack_model` (bulk) and `pack_one` (streaming) | serialize | | `pack_cuda.py` | **CUDA packing** — converts canonical to tinygemm/intx runtime format | pack, serialize | +| `gguf.py` | **GGUF import** — unpacks Q4_K/Q6_K blocks to canonical form | recipe, serialize | ## Data flow @@ -79,10 +80,8 @@ symmetric, and method per weight. Unquantized weights stored as-is. `IntxWeightOnlyConfig` subclass for the `mlx::gather_qmm` kernel. For MoE models, stack per-expert weights into `SwitchLinear` format. -- `gguf.py` — read a GGUF file and convert to `CanonicalQuantizedWeight` - dicts, enabling `load() → pack_model()` from community-quantized GGUF - checkpoints without re-quantizing from bf16. Maps GGUF quant types - (Q4_K, Q6_K, Q8_0, etc.) to `QuantConfig` and unpacks super-blocks - into the canonical qdata + scale + zero layout. For CUDA packing, - Q6_K would be widened to 8-bit (`pack_int8_for_cuda`) since there is - no 6-bit CUDA kernel — lossless, ~33% more memory than true 6-bit. +- `gguf.py` — extend with Q5_K, Q8_0, and other GGUF quant types. + Currently supports Q4_K and Q6_K. Some Q4_K_M files also contain + Q5_K or Q8_0 tensors (for sensitive layers on certain architectures) + which will raise — add support as needed. Q6_K is widened to 8-bit + for CUDA packing since there is no 6-bit CUDA kernel. diff --git a/examples/models/gemma4_31b/quant/__init__.py b/examples/models/gemma4_31b/quant/__init__.py index 23d321f0c0b..96a227eb966 100644 --- a/examples/models/gemma4_31b/quant/__init__.py +++ b/examples/models/gemma4_31b/quant/__init__.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from .pack import ModulePackerFn, pack_model # noqa: F401 +from .pack import ModulePackerFn, pack_model, pack_one # noqa: F401 from .pack_cuda import ( # noqa: F401 DEFAULT_CUDA_PACKERS, load_and_pack_for_cuda, @@ -13,7 +13,7 @@ pack_int8_for_cuda, pack_linear_for_cuda, ) -from .quantize import quantize_model, quantize_weight # noqa: F401 +from .quantize import dequantize_weight, quantize_model, quantize_weight # noqa: F401 from .recipe import QuantConfig, QuantRecipe, QuantRule # noqa: F401 from .serialize import ( # noqa: F401 CanonicalQuantizedWeight, diff --git a/examples/models/gemma4_31b/quant/gguf.py b/examples/models/gemma4_31b/quant/gguf.py new file mode 100644 index 00000000000..27f244ed2c1 --- /dev/null +++ b/examples/models/gemma4_31b/quant/gguf.py @@ -0,0 +1,219 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Unpack GGUF quantized tensors to CanonicalQuantizedWeight. + +Supports Q4_K, Q6_K, F32, and F16 tensor types. Two public APIs: + + - ``unpack_gguf_tensor`` — convert a single tensor + - ``iter_gguf_tensors`` — stream all tensors from a file (low peak memory) + +Model-agnostic. For Gemma 4 31B key mapping and model loading, see +``gguf_loader.py``. +""" + +from collections.abc import Iterator + +import torch + +from .recipe import QuantConfig +from .serialize import CanonicalQuantizedWeight + +QK_K = 256 # super-block size for k-quants +Q4_K_GROUPS = 8 # sub-blocks per Q4_K super-block +Q4_K_GROUP_SIZE = QK_K // Q4_K_GROUPS # 32 +Q6_K_GROUPS = 16 # sub-blocks per Q6_K super-block +Q6_K_GROUP_SIZE = QK_K // Q6_K_GROUPS # 16 + + +def _raw_tensor(data): + """Wrap a numpy mmap view as a uint8 torch tensor (zero-copy).""" + return torch.frombuffer(memoryview(data), dtype=torch.uint8) + + +def _read_f16(raw, col_start, col_end): + """Read fp16 field from block bytes, return float32.""" + return raw[:, col_start:col_end].contiguous().view(torch.float16).float() + + +def _unpack_q4_k(data, shape: list[int]) -> CanonicalQuantizedWeight: + """Unpack Q4_K super-blocks into canonical form. + + Q4_K block layout (144 bytes per 256 values): + - d (2B, fp16): super-block scale + - dmin (2B, fp16): super-block min + - scales (12B): 8 sub-block scales + 8 sub-block mins, 6-bit packed + - qs (128B): 256 4-bit values, two per byte + + Dequant: weight = d * sub_scale * q - dmin * sub_min + """ + N, K = shape + assert K % QK_K == 0, f"Q4_K requires K divisible by {QK_K}, got {K}" + n_blocks = N * (K // QK_K) + block_bytes = 2 + 2 + 12 + QK_K // 2 # 144 + raw = _raw_tensor(data).reshape(n_blocks, block_bytes) + + d = _read_f16(raw, 0, 2) # (n_blocks, 1) + dmin = _read_f16(raw, 2, 4) # (n_blocks, 1) + s = raw[:, 4:16] # (n_blocks, 12) + qs = raw[:, 16:144] # (n_blocks, 128) + + # Unpack 6-bit scales/mins and compute effective scale/zero directly + sc = torch.empty(n_blocks, 8, dtype=torch.float32) + mn = torch.empty(n_blocks, 8, dtype=torch.float32) + sc[:, :4] = (s[:, :4] & 0x3F).float() + mn[:, :4] = (s[:, 4:8] & 0x3F).float() + sc[:, 4:] = ((s[:, 8:12] & 0xF) | ((s[:, :4] >> 6) << 4)).float() + mn[:, 4:] = ((s[:, 8:12] >> 4) | ((s[:, 4:8] >> 6) << 4)).float() + del s + + eff_scale = (d * sc).reshape(N, -1) + eff_min = (dmin * mn).reshape(N, -1) + del d, dmin, sc, mn + + zero_std = torch.where( + eff_scale != 0, eff_min / eff_scale, torch.zeros_like(eff_min) + ) + del eff_min + + # GGUF Q4_K nibble order: for each 32-byte group, 32 low nibbles come + # first (positions 0..31), then 32 high nibbles (positions 32..63). + low = (qs & 0x0F).to(torch.int8) # (n_blocks, 128) + high = ((qs >> 4) & 0x0F).to(torch.int8) + qdata = torch.cat( + [ + low[:, :32], + high[:, :32], + low[:, 32:64], + high[:, 32:64], + low[:, 64:96], + high[:, 64:96], + low[:, 96:128], + high[:, 96:128], + ], + dim=-1, + ) # (n_blocks, 256) + del qs, low, high + + return CanonicalQuantizedWeight( + qdata=qdata.reshape(N, K), + scale=eff_scale.to(torch.bfloat16), + zero=zero_std.to(torch.bfloat16), + config=QuantConfig( + bits=4, group_size=Q4_K_GROUP_SIZE, symmetric=False, method="gguf_q4_k" + ), + ) + + +def _unpack_q6_k(data, shape: list[int]) -> CanonicalQuantizedWeight: + """Unpack Q6_K super-blocks into canonical form as INT8. + + Q6_K block layout (210 bytes per 256 values): + - ql (128B): lower 4 bits of 256 6-bit values + - qh (64B): upper 2 bits of 256 6-bit values + - scales (16B): 16 int8 sub-block scales (groups of 16) + - d (2B, fp16): super-block scale + + Dequant: weight = d * scale_j * (q - 32) + Values are 6-bit [-32, 31], widened to INT8 for canonical storage. + """ + N, K = shape + assert K % QK_K == 0, f"Q6_K requires K divisible by {QK_K}, got {K}" + n_blocks = N * (K // QK_K) + block_bytes = 2 + QK_K // 2 + QK_K // 4 + QK_K // 16 # 210 + raw = _raw_tensor(data).reshape(n_blocks, block_bytes) + + ql = raw[:, 0:128] + qh = raw[:, 128:192] + sc = raw[:, 192:208] + d = _read_f16(raw, 208, 210) + + # Combine 4-bit low + 2-bit high into 6-bit, center to [-32, 31]. + # ggml processes 128 values at a time: ql[0..63] + qh[0..31] for the + # first half, ql[64..127] + qh[32..63] for the second half. + qh0 = qh[:, :32] # first 32 qh bytes → first 128 values + qh1 = qh[:, 32:64] # second 32 qh bytes → next 128 values + qdata = torch.empty(n_blocks, QK_K, dtype=torch.int16) + # First 128 values + qdata[:, 0:32] = (ql[:, :32] & 0x0F) | ((qh0 & 0x03) << 4) + qdata[:, 32:64] = (ql[:, 32:64] & 0x0F) | (((qh0 >> 2) & 0x03) << 4) + qdata[:, 64:96] = ((ql[:, :32] >> 4) & 0x0F) | (((qh0 >> 4) & 0x03) << 4) + qdata[:, 96:128] = ((ql[:, 32:64] >> 4) & 0x0F) | (((qh0 >> 6) & 0x03) << 4) + # Second 128 values + qdata[:, 128:160] = (ql[:, 64:96] & 0x0F) | ((qh1 & 0x03) << 4) + qdata[:, 160:192] = (ql[:, 96:128] & 0x0F) | (((qh1 >> 2) & 0x03) << 4) + qdata[:, 192:224] = ((ql[:, 64:96] >> 4) & 0x0F) | (((qh1 >> 4) & 0x03) << 4) + qdata[:, 224:256] = ((ql[:, 96:128] >> 4) & 0x0F) | (((qh1 >> 6) & 0x03) << 4) + qdata -= 32 + del ql, qh, qh0, qh1 + + eff_scale = (d * sc.to(torch.int8).float()).reshape(N, -1) + del d, sc + + return CanonicalQuantizedWeight( + qdata=qdata.reshape(N, K).to(torch.int8), + scale=eff_scale.to(torch.bfloat16), + zero=None, + config=QuantConfig( + bits=8, group_size=Q6_K_GROUP_SIZE, symmetric=True, method="gguf_q6_k" + ), + ) + + +def unpack_gguf_tensor( + tensor_data, + tensor_type, + shape: list[int], +) -> CanonicalQuantizedWeight | torch.Tensor: + """Unpack a single GGUF tensor into canonical form or a plain tensor. + + Args: + tensor_data: raw numpy/mmap data from GGUFReader + tensor_type: GGMLQuantizationType enum value + shape: tensor shape in PyTorch convention [out_features, in_features] + + Returns: + ``CanonicalQuantizedWeight`` for quantized types (Q4_K, Q6_K), + ``torch.Tensor`` for unquantized types (F32, F16). + """ + from gguf import GGMLQuantizationType + + if tensor_type == GGMLQuantizationType.Q4_K: + return _unpack_q4_k(tensor_data, shape) + elif tensor_type == GGMLQuantizationType.Q6_K: + return _unpack_q6_k(tensor_data, shape) + elif tensor_type == GGMLQuantizationType.F32: + return _raw_tensor(tensor_data).view(torch.float32).reshape(shape).clone() + elif tensor_type == GGMLQuantizationType.F16: + return ( + _raw_tensor(tensor_data) + .view(torch.float16) + .reshape(shape) + .to(torch.bfloat16) + ) + else: + raise ValueError(f"Unsupported GGUF quant type: {tensor_type}") + + +def iter_gguf_tensors( + path: str, +) -> Iterator[tuple[str, CanonicalQuantizedWeight | torch.Tensor]]: + """Yield ``(name, result)`` for each tensor in a GGUF file. + + Processes one tensor at a time for low peak memory. ``result`` is a + ``CanonicalQuantizedWeight`` for quantized types or a ``torch.Tensor`` + for F32/F16. Tensor names are GGUF names (e.g., ``blk.0.attn_q.weight``); + the caller handles key remapping. + + GGUF shapes are reversed to PyTorch convention automatically. + """ + from gguf import GGUFReader + + reader = GGUFReader(path) + for tensor in reader.tensors: + shape = list(reversed(tensor.shape.tolist())) + result = unpack_gguf_tensor(tensor.data, tensor.tensor_type, shape) + yield tensor.name, result diff --git a/examples/models/gemma4_31b/quant/pack.py b/examples/models/gemma4_31b/quant/pack.py index 544e96287e9..6d218387f4f 100644 --- a/examples/models/gemma4_31b/quant/pack.py +++ b/examples/models/gemma4_31b/quant/pack.py @@ -14,7 +14,6 @@ Pure logic — no file I/O, no backend imports. """ -from collections import defaultdict from typing import Callable import torch @@ -27,21 +26,6 @@ ModulePackerFn = Callable[[nn.Module, dict[str, CanonicalQuantizedWeight]], None] -def _assign_unquantized(model: nn.Module, unquantized: dict[str, torch.Tensor]) -> None: - """Assign plain (unquantized) tensors to model parameters and buffers.""" - model_sd_keys = set(model.state_dict().keys()) - for fqn, tensor in unquantized.items(): - if fqn not in model_sd_keys: - continue - parts = fqn.rsplit(".", 1) - parent = model.get_submodule(parts[0]) if len(parts) > 1 else model - attr_name = parts[-1] - if isinstance(getattr(parent, attr_name, None), nn.Parameter): - setattr(parent, attr_name, nn.Parameter(tensor, requires_grad=False)) - else: - parent.register_buffer(attr_name, tensor) - - def pack_model( model: nn.Module, quantized: dict[str, CanonicalQuantizedWeight], @@ -57,7 +41,13 @@ def pack_model( Pure logic — no file I/O, no backend dependency. """ - _assign_unquantized(model, unquantized) + for fqn, tensor in unquantized.items(): + pack_one(model, fqn, tensor, packers) + + # Group quantized weights by parent module so packers that need + # multiple weights at once (e.g., FusedMoEExperts with w1 + w2) + # receive them in a single call. + from collections import defaultdict module_weights: dict[str, dict[str, CanonicalQuantizedWeight]] = defaultdict(dict) for fqn, cw in quantized.items(): @@ -85,3 +75,35 @@ def pack_model( for p in model.parameters(): p.requires_grad_(False) + + +def pack_one( + model: nn.Module, + fqn: str, + value: CanonicalQuantizedWeight | torch.Tensor, + packers: dict[type, ModulePackerFn], +) -> None: + """Pack a single weight into ``model``. + + If ``value`` is a ``CanonicalQuantizedWeight``, dispatches to the + packer for the parent module's type. If it's a plain tensor, assigns + directly as a parameter or buffer. + """ + parts = fqn.rsplit(".", 1) + parent_fqn = parts[0] if len(parts) > 1 else "" + attr = parts[-1] + parent = model.get_submodule(parent_fqn) if parent_fqn else model + + if isinstance(value, CanonicalQuantizedWeight): + packer = packers.get(type(parent)) + if packer is None: + raise ValueError( + f"No packer registered for {type(parent).__name__} at '{parent_fqn}'. " + f"Registered types: {[t.__name__ for t in packers]}." + ) + packer(parent, {attr: value}) + else: + if isinstance(getattr(parent, attr, None), nn.Parameter): + setattr(parent, attr, nn.Parameter(value, requires_grad=False)) + else: + parent.register_buffer(attr, value) diff --git a/examples/models/gemma4_31b/quant/quantize.py b/examples/models/gemma4_31b/quant/quantize.py index 0ebfd032681..365c1ab0176 100644 --- a/examples/models/gemma4_31b/quant/quantize.py +++ b/examples/models/gemma4_31b/quant/quantize.py @@ -186,6 +186,23 @@ def quantize_weight( ) +def dequantize_weight( + cw: CanonicalQuantizedWeight, + dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """Dequantize a ``CanonicalQuantizedWeight`` back to a dense tensor. + + Inverse of ``quantize_weight``. Useful for embedding lookups (which + need dense weights) or for inspecting quantized values. + """ + gs = cw.config.group_size + scale = cw.scale.float().repeat_interleave(gs, dim=-1) + if cw.zero is not None: + zero = cw.zero.float().repeat_interleave(gs, dim=-1) + return ((cw.qdata.float() - zero) * scale).to(dtype) + return (cw.qdata.float() * scale).to(dtype) + + # --------------------------------------------------------------------------- # Per-model quantization diff --git a/examples/models/gemma4_31b/quant/serialize.py b/examples/models/gemma4_31b/quant/serialize.py index 5996599ad90..ee3cb594cce 100644 --- a/examples/models/gemma4_31b/quant/serialize.py +++ b/examples/models/gemma4_31b/quant/serialize.py @@ -54,6 +54,26 @@ class CanonicalQuantizedWeight: zero: Optional[torch.Tensor] config: QuantConfig + def __post_init__(self): + if self.qdata.dtype != torch.int8: + raise ValueError(f"qdata must be int8, got {self.qdata.dtype}") + K = self.qdata.shape[-1] + if K % self.config.group_size != 0: + raise ValueError( + f"Last dim ({K}) must be divisible by group_size ({self.config.group_size})" + ) + n_groups = K // self.config.group_size + expected_numel = self.qdata[..., 0:1].numel() * n_groups + if self.scale.numel() != expected_numel: + raise ValueError( + f"scale has {self.scale.numel()} elements, expected {expected_numel} " + f"(from qdata {tuple(self.qdata.shape)}, group_size={self.config.group_size})" + ) + if self.config.symmetric and self.zero is not None: + raise ValueError("symmetric config must have zero=None") + if not self.config.symmetric and self.zero is None: + raise ValueError("asymmetric config must have zero (not None)") + # --------------------------------------------------------------------------- # Nibble packing for 4-bit on-disk storage. diff --git a/examples/models/gemma4_31b/quant/tests/test_gguf.py b/examples/models/gemma4_31b/quant/tests/test_gguf.py new file mode 100644 index 00000000000..7d6134b4de3 --- /dev/null +++ b/examples/models/gemma4_31b/quant/tests/test_gguf.py @@ -0,0 +1,239 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Unit tests for quant/gguf.py — Q4_K and Q6_K unpacking. + +Tests verify the API contract: dequantized canonical weights match the +original GGUF dequantization formula. Uses synthetic blocks — no GGUF +file required. +""" + +import struct +import unittest + +import numpy as np +import torch + +from executorch.examples.models.gemma4_31b.quant.gguf import unpack_gguf_tensor +from executorch.examples.models.gemma4_31b.quant.quantize import dequantize_weight +from executorch.examples.models.gemma4_31b.quant.serialize import deserialize, serialize + +try: + from gguf import GGMLQuantizationType + + _Q4_K = GGMLQuantizationType.Q4_K + _Q6_K = GGMLQuantizationType.Q6_K +except ImportError: + _Q4_K = None + _Q6_K = None + + +def _make_q4_k_block(d, dmin, sub_scales, sub_mins, qvals): + """Build one Q4_K block (144 bytes) from components.""" + buf = bytearray(144) + struct.pack_into("> 4) << 6 + scales_bytes[j] |= (sub_mins[j] >> 4) << 6 + buf[4:16] = scales_bytes + # GGUF Q4_K nibble order: 32 lows then 32 highs per sub-block pair + for g in range(4): + for i in range(32): + lo_val = qvals[g * 64 + i] + hi_val = qvals[g * 64 + 32 + i] + buf[16 + g * 32 + i] = (lo_val & 0xF) | ((hi_val & 0xF) << 4) + return buf + + +def _make_q6_k_block(d, scales_16, qvals_256): + """Build one Q6_K block (210 bytes) from components. + + ggml processes 128 values at a time. For each 128-value half: + ql: 64 bytes (two groups of 32, low/high nibbles) + qh: 32 bytes (2 bits each for 4 sub-positions) + The qvals_256 array is in output order (position 0..255). + """ + buf = bytearray(210) + # First half (positions 0..127): ql bytes 0..63, qh bytes 0..31 + for i in range(32): + buf[i] = (qvals_256[i] & 0x0F) | ((qvals_256[i + 64] & 0x0F) << 4) + for i in range(32): + buf[32 + i] = (qvals_256[i + 32] & 0x0F) | ((qvals_256[i + 96] & 0x0F) << 4) + for i in range(32): + h0 = (qvals_256[i] >> 4) & 0x03 + h1 = (qvals_256[i + 32] >> 4) & 0x03 + h2 = (qvals_256[i + 64] >> 4) & 0x03 + h3 = (qvals_256[i + 96] >> 4) & 0x03 + buf[128 + i] = h0 | (h1 << 2) | (h2 << 4) | (h3 << 6) + # Second half (positions 128..255): ql bytes 64..127, qh bytes 32..63 + for i in range(32): + buf[64 + i] = (qvals_256[i + 128] & 0x0F) | ((qvals_256[i + 192] & 0x0F) << 4) + for i in range(32): + buf[96 + i] = (qvals_256[i + 160] & 0x0F) | ((qvals_256[i + 224] & 0x0F) << 4) + for i in range(32): + h0 = (qvals_256[i + 128] >> 4) & 0x03 + h1 = (qvals_256[i + 160] >> 4) & 0x03 + h2 = (qvals_256[i + 192] >> 4) & 0x03 + h3 = (qvals_256[i + 224] >> 4) & 0x03 + buf[160 + i] = h0 | (h1 << 2) | (h2 << 4) | (h3 << 6) + # Scales and d + for i in range(16): + buf[192 + i] = scales_16[i] & 0xFF + struct.pack_into(" Date: Thu, 30 Apr 2026 07:19:59 -0700 Subject: [PATCH 04/14] Fix symmetric INT4 dequantization and harden tests - dequantize_weight now subtracts 8 from symmetric 4-bit qdata (stored as unsigned [0,15]) before scaling, matching the quantize_weight shift - Guard test_gguf.py with skipUnless so CI doesn't break without gguf - Install gguf in cuda.yml for GGUF test coverage - Use torch.allclose instead of torch.equal for chunked prefill logit comparison to avoid CUDA FP flakiness - Fix Usage docblock paths in test_pipeline.py and test_cuda_pipeline.py --- .github/workflows/cuda.yml | 1 + examples/models/gemma4_31b/quant/quantize.py | 10 ++++-- .../gemma4_31b/quant/tests/test_gguf.py | 34 +++++++++---------- .../gemma4_31b/quant/tests/test_quantize.py | 8 ++++- .../gemma4_31b/tests/test_cuda_pipeline.py | 16 ++++++--- .../models/gemma4_31b/tests/test_pipeline.py | 2 +- 6 files changed, 46 insertions(+), 25 deletions(-) diff --git a/.github/workflows/cuda.yml b/.github/workflows/cuda.yml index e9d4d3db89e..087917c1116 100644 --- a/.github/workflows/cuda.yml +++ b/.github/workflows/cuda.yml @@ -149,6 +149,7 @@ jobs: python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py examples/models/qwen3_5_moe/test_sampler.py -v -o "addopts=" # Run Gemma 4 31B tests (quant unit tests + pipeline integration tests) + pip install gguf python -m pytest examples/models/gemma4_31b/quant/tests/ examples/models/gemma4_31b/tests/ -v -o "addopts=" export-model-cuda-artifact: diff --git a/examples/models/gemma4_31b/quant/quantize.py b/examples/models/gemma4_31b/quant/quantize.py index 365c1ab0176..247ccc3bc24 100644 --- a/examples/models/gemma4_31b/quant/quantize.py +++ b/examples/models/gemma4_31b/quant/quantize.py @@ -197,10 +197,16 @@ def dequantize_weight( """ gs = cw.config.group_size scale = cw.scale.float().repeat_interleave(gs, dim=-1) + qdata = cw.qdata.float() + # Symmetric 4-bit qdata is stored as unsigned [0, 15] (shifted +8 in + # quantize_weight). Undo the shift to recover signed [-8, 7] before + # scaling. (Q4_K is asymmetric and uses a zero field instead.) + if cw.config.bits == 4 and cw.zero is None: + qdata = qdata - 8 if cw.zero is not None: zero = cw.zero.float().repeat_interleave(gs, dim=-1) - return ((cw.qdata.float() - zero) * scale).to(dtype) - return (cw.qdata.float() * scale).to(dtype) + return ((qdata - zero) * scale).to(dtype) + return (qdata * scale).to(dtype) # --------------------------------------------------------------------------- diff --git a/examples/models/gemma4_31b/quant/tests/test_gguf.py b/examples/models/gemma4_31b/quant/tests/test_gguf.py index 7d6134b4de3..503ab2099f0 100644 --- a/examples/models/gemma4_31b/quant/tests/test_gguf.py +++ b/examples/models/gemma4_31b/quant/tests/test_gguf.py @@ -17,18 +17,18 @@ import numpy as np import torch -from executorch.examples.models.gemma4_31b.quant.gguf import unpack_gguf_tensor -from executorch.examples.models.gemma4_31b.quant.quantize import dequantize_weight -from executorch.examples.models.gemma4_31b.quant.serialize import deserialize, serialize - try: from gguf import GGMLQuantizationType - _Q4_K = GGMLQuantizationType.Q4_K - _Q6_K = GGMLQuantizationType.Q6_K + _HAS_GGUF = True except ImportError: - _Q4_K = None - _Q6_K = None + _HAS_GGUF = False + +if _HAS_GGUF: + from executorch.examples.models.gemma4_31b.quant.gguf import unpack_gguf_tensor + +from executorch.examples.models.gemma4_31b.quant.quantize import dequantize_weight +from executorch.examples.models.gemma4_31b.quant.serialize import deserialize, serialize def _make_q4_k_block(d, dmin, sub_scales, sub_mins, qvals): @@ -113,6 +113,7 @@ def _q6_k_reference_dequant(d, scales_16, qvals_256): return result +@unittest.skipUnless(_HAS_GGUF, "gguf package not installed") class TestQ4KDequant(unittest.TestCase): def test_dequant_matches_reference(self): """Canonical dequant reproduces the GGUF Q4_K formula across all sub-blocks.""" @@ -123,7 +124,7 @@ def test_dequant_matches_reference(self): block = _make_q4_k_block(d, dmin, sub_scales, sub_mins, qvals) data = np.frombuffer(bytes(block), dtype=np.uint8).reshape(1, 144) - cw = unpack_gguf_tensor(data, _Q4_K, [1, 256]) + cw = unpack_gguf_tensor(data, GGMLQuantizationType.Q4_K, [1, 256]) actual = dequantize_weight(cw)[0] expected = torch.tensor( @@ -139,7 +140,7 @@ def test_zero_scale_produces_zero(self): """Scale=0 produces zero dequantized values (not dmin*min).""" block = _make_q4_k_block(0.0, 1.0, [0] * 8, [1] * 8, [7] * 256) data = np.frombuffer(bytes(block), dtype=np.uint8).reshape(1, 144) - cw = unpack_gguf_tensor(data, _Q4_K, [1, 256]) + cw = unpack_gguf_tensor(data, GGMLQuantizationType.Q4_K, [1, 256]) dequant = dequantize_weight(cw) self.assertFalse(torch.isnan(dequant).any()) self.assertFalse(torch.isinf(dequant).any()) @@ -149,6 +150,7 @@ def test_zero_scale_produces_zero(self): self.assertTrue((dequant == 0).all()) +@unittest.skipUnless(_HAS_GGUF, "gguf package not installed") class TestQ6KDequant(unittest.TestCase): def test_dequant_matches_reference(self): """Canonical dequant reproduces the GGUF Q6_K formula.""" @@ -158,7 +160,7 @@ def test_dequant_matches_reference(self): block = _make_q6_k_block(d, scales_16, qvals) data = np.frombuffer(bytes(block), dtype=np.uint8).reshape(1, 210) - cw = unpack_gguf_tensor(data, _Q6_K, [1, 256]) + cw = unpack_gguf_tensor(data, GGMLQuantizationType.Q6_K, [1, 256]) actual = dequantize_weight(cw)[0] expected = torch.tensor(_q6_k_reference_dequant(d, scales_16, qvals)) @@ -169,6 +171,7 @@ def test_dequant_matches_reference(self): ) +@unittest.skipUnless(_HAS_GGUF, "gguf package not installed") class TestGgufSerializeRoundtrip(unittest.TestCase): def test_q4_k_survives_serialize_roundtrip(self): """unpack → serialize → deserialize → dequant matches original.""" @@ -179,7 +182,7 @@ def test_q4_k_survives_serialize_roundtrip(self): block = _make_q4_k_block(d, dmin, sub_scales, sub_mins, qvals) data = np.frombuffer(bytes(block), dtype=np.uint8).reshape(1, 144) - cw = unpack_gguf_tensor(data, _Q4_K, [1, 256]) + cw = unpack_gguf_tensor(data, GGMLQuantizationType.Q4_K, [1, 256]) dequant_before = dequantize_weight(cw) @@ -200,7 +203,7 @@ def test_q6_k_survives_serialize_roundtrip(self): block = _make_q6_k_block(d, scales_16, qvals) data = np.frombuffer(bytes(block), dtype=np.uint8).reshape(1, 210) - cw = unpack_gguf_tensor(data, _Q6_K, [1, 256]) + cw = unpack_gguf_tensor(data, GGMLQuantizationType.Q6_K, [1, 256]) dequant_before = dequantize_weight(cw) @@ -214,12 +217,11 @@ def test_q6_k_survives_serialize_roundtrip(self): ) +@unittest.skipUnless(_HAS_GGUF, "gguf package not installed") class TestUnpackGgufTensor(unittest.TestCase): """Tests for the public ``unpack_gguf_tensor`` API.""" def test_f32_returns_tensor(self): - from gguf import GGMLQuantizationType - data = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32) result = unpack_gguf_tensor(data, GGMLQuantizationType.F32, [4]) self.assertIsInstance(result, torch.Tensor) @@ -227,8 +229,6 @@ def test_f32_returns_tensor(self): self.assertEqual(result.tolist(), [1.0, 2.0, 3.0, 4.0]) def test_unsupported_type_raises(self): - from gguf import GGMLQuantizationType - with self.assertRaises(ValueError): unpack_gguf_tensor( np.zeros(10, dtype=np.uint8), GGMLQuantizationType.Q5_K, [1, 10] diff --git a/examples/models/gemma4_31b/quant/tests/test_quantize.py b/examples/models/gemma4_31b/quant/tests/test_quantize.py index 11b5132cc2b..0dbbbff3167 100644 --- a/examples/models/gemma4_31b/quant/tests/test_quantize.py +++ b/examples/models/gemma4_31b/quant/tests/test_quantize.py @@ -84,11 +84,17 @@ def test_dequantize_output_dtype(self): self.assertEqual(dequantize_weight(cw, torch.float16).dtype, torch.float16) def test_dequantize_symmetric(self): + torch.manual_seed(1) + weight = torch.randn(32, 64, dtype=torch.bfloat16) config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") - cw = quantize_weight(torch.randn(32, 64, dtype=torch.bfloat16), config) + cw = quantize_weight(weight, config) self.assertIsNone(cw.zero) dequant = dequantize_weight(cw) self.assertEqual(dequant.shape, (32, 64)) + rel_error = ( + dequant.float() - weight.float() + ).abs().mean() / weight.float().abs().mean() + self.assertLess(rel_error.item(), 0.15) @parameterized.expand( [ diff --git a/examples/models/gemma4_31b/tests/test_cuda_pipeline.py b/examples/models/gemma4_31b/tests/test_cuda_pipeline.py index 81501eea522..228b34d9261 100644 --- a/examples/models/gemma4_31b/tests/test_cuda_pipeline.py +++ b/examples/models/gemma4_31b/tests/test_cuda_pipeline.py @@ -12,7 +12,7 @@ Requires CUDA. Usage: - python -m pytest examples/models/gemma4_31b/test_cuda_pipeline.py -v + python -m pytest examples/models/gemma4_31b/tests/test_cuda_pipeline.py -v """ import os @@ -123,10 +123,18 @@ def test_chunked_prefill_matches_sequential(self): pos2 = torch.arange(buf_size, prompt_len, dtype=torch.long, device="cuda") logits_chunk = model_chunk(chunk2, pos2, None) - # Compare last-token logits (skip sampling to avoid RNG differences) + # Compare last-token logits (skip sampling to avoid RNG differences). + # Use allclose rather than equal — CUDA kernels can produce small FP + # differences across execution shapes. + max_diff = (logits_seq[0, -1].float() - logits_chunk[0, -1].float()).abs().max() self.assertTrue( - torch.equal(logits_seq[0, -1], logits_chunk[0, -1]), - f"Max diff: {(logits_seq[0, -1] - logits_chunk[0, -1]).abs().max()}", + torch.allclose( + logits_seq[0, -1].float(), + logits_chunk[0, -1].float(), + atol=1e-2, + rtol=1e-3, + ), + f"Chunked prefill diverged: max_diff={max_diff:.4g}", ) diff --git a/examples/models/gemma4_31b/tests/test_pipeline.py b/examples/models/gemma4_31b/tests/test_pipeline.py index 85ec3a671f0..d718a2811cb 100644 --- a/examples/models/gemma4_31b/tests/test_pipeline.py +++ b/examples/models/gemma4_31b/tests/test_pipeline.py @@ -11,7 +11,7 @@ ``test_cuda_pipeline.py``. Usage: - python -m pytest examples/models/gemma4_31b/test_pipeline.py -v + python -m pytest examples/models/gemma4_31b/tests/test_pipeline.py -v """ import json From 313799bc7dddb370c72caf686c4f50bd248ff193 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Thu, 30 Apr 2026 07:49:51 -0700 Subject: [PATCH 05/14] Harden edge cases and expand test coverage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix float→uint64 truncation in main.cpp read_token (use llrintf) - Add assert in RingKVCache.update to catch seq_len > buf_size misuse - Add RingKVCache unit tests (sequential, wraparound, multi-token, assert) - Add CanonicalQuantizedWeight __post_init__ validation error path tests - Add GGUF Q4_K through tinygemm pack pipeline test (asymmetric) - Add 8-bit asymmetric matmul test - Add F16 GGUF tensor type test - Document QuantConfig.bits as storage width and _INT8_PER_AXIS coupling --- examples/models/gemma4_31b/main.cpp | 2 +- examples/models/gemma4_31b/model.py | 3 + examples/models/gemma4_31b/quant/recipe.py | 4 +- .../gemma4_31b/quant/tests/test_gguf.py | 7 +++ .../gemma4_31b/quant/tests/test_pack_cuda.py | 51 +++++++++++++++ .../gemma4_31b/quant/tests/test_serialize.py | 61 ++++++++++++++++++ .../models/gemma4_31b/quantize_and_save.py | 4 +- .../models/gemma4_31b/tests/test_pipeline.py | 63 ++++++++++++++++++- 8 files changed, 190 insertions(+), 5 deletions(-) diff --git a/examples/models/gemma4_31b/main.cpp b/examples/models/gemma4_31b/main.cpp index b12e1b87db9..f86538a27c2 100644 --- a/examples/models/gemma4_31b/main.cpp +++ b/examples/models/gemma4_31b/main.cpp @@ -82,7 +82,7 @@ static uint64_t read_token(const executorch::aten::Tensor& output) { memcpy(&val, ptr, sizeof(float)); #endif - return static_cast(val); + return static_cast(llrintf(val)); } int main(int argc, char** argv) { diff --git a/examples/models/gemma4_31b/model.py b/examples/models/gemma4_31b/model.py index 7366a57bf46..4c5d2f5b97e 100644 --- a/examples/models/gemma4_31b/model.py +++ b/examples/models/gemma4_31b/model.py @@ -92,6 +92,9 @@ def update( # seq_len must not exceed buf_size, otherwise wrapped indices contain # duplicates and index_copy_ is non-deterministic on CUDA. The C++ # runner must chunk prefill to respect this limit. + assert ( + input_pos.shape[0] <= self.buf_size + ), f"seq_len {input_pos.shape[0]} > buf_size {self.buf_size}" wrapped = input_pos % self.buf_size self.k_cache.index_copy_(2, wrapped, k_val) self.v_cache.index_copy_(2, wrapped, v_val) diff --git a/examples/models/gemma4_31b/quant/recipe.py b/examples/models/gemma4_31b/quant/recipe.py index 49294c9b579..e207e268c38 100644 --- a/examples/models/gemma4_31b/quant/recipe.py +++ b/examples/models/gemma4_31b/quant/recipe.py @@ -20,10 +20,10 @@ class QuantConfig: """Per-weight quantization parameters.""" - bits: int # 4, 8 + bits: int # storage width: 4 or 8 (6-bit formats like Q6_K are widened to 8) group_size: int # 32, 64, 128 symmetric: bool # True = no zero point - method: str # "min_max" | "hqq" + method: str # "min_max" | "hqq" | "gguf_q4_k" | "gguf_q6_k" @dataclass diff --git a/examples/models/gemma4_31b/quant/tests/test_gguf.py b/examples/models/gemma4_31b/quant/tests/test_gguf.py index 503ab2099f0..80e6af5001a 100644 --- a/examples/models/gemma4_31b/quant/tests/test_gguf.py +++ b/examples/models/gemma4_31b/quant/tests/test_gguf.py @@ -228,6 +228,13 @@ def test_f32_returns_tensor(self): self.assertEqual(result.dtype, torch.float32) self.assertEqual(result.tolist(), [1.0, 2.0, 3.0, 4.0]) + def test_f16_returns_bf16_tensor(self): + data = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16) + result = unpack_gguf_tensor(data, GGMLQuantizationType.F16, [2, 2]) + self.assertIsInstance(result, torch.Tensor) + self.assertEqual(result.dtype, torch.bfloat16) + self.assertEqual(result.shape, (2, 2)) + def test_unsupported_type_raises(self): with self.assertRaises(ValueError): unpack_gguf_tensor( diff --git a/examples/models/gemma4_31b/quant/tests/test_pack_cuda.py b/examples/models/gemma4_31b/quant/tests/test_pack_cuda.py index c1eb855adb2..bc429ee6372 100644 --- a/examples/models/gemma4_31b/quant/tests/test_pack_cuda.py +++ b/examples/models/gemma4_31b/quant/tests/test_pack_cuda.py @@ -103,6 +103,38 @@ def test_symmetric_matmul_approximates_original(self): ).abs().mean() / original_out.float().abs().mean() self.assertLess(rel_error.item(), 0.15) + def test_asymmetric_gguf_q4_k_matmul(self): + """Asymmetric 4-bit (GGUF Q4_K style) packs and produces correct matmul.""" + torch.manual_seed(0) + weight = torch.randn(256, 1024, dtype=torch.bfloat16) + x = torch.randn(1, 1024, dtype=torch.bfloat16) + + original_out = torch.nn.functional.linear(x.cuda(), weight.cuda()) + + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + cw = quantize_weight(weight, config) + # Mimic GGUF Q4_K: asymmetric with a non-standard method name + from executorch.examples.models.gemma4_31b.quant.serialize import ( + CanonicalQuantizedWeight, + ) + + cw_gguf = CanonicalQuantizedWeight( + qdata=cw.qdata, + scale=cw.scale, + zero=cw.zero, + config=QuantConfig( + bits=4, group_size=32, symmetric=False, method="gguf_q4_k" + ), + ) + packed = pack_int4_for_cuda(cw_gguf) + + packed_out = torch.nn.functional.linear(x.cuda(), packed.cuda()) + + rel_error = ( + packed_out.float() - original_out.float() + ).abs().mean() / original_out.float().abs().mean() + self.assertLess(rel_error.item(), 0.15) + class TestPackInt8ForCuda(unittest.TestCase): def setUp(self): @@ -133,6 +165,25 @@ def test_matmul_approximates_original(self): ).abs().mean() / original_out.float().abs().mean() self.assertLess(rel_error.item(), 0.02) + def test_asymmetric_matmul_approximates_original(self): + """8-bit asymmetric quantization packs and produces correct matmul.""" + torch.manual_seed(0) + weight = torch.randn(256, 128, dtype=torch.bfloat16) + x = torch.randn(1, 128, dtype=torch.bfloat16) + + original_out = torch.nn.functional.linear(x.cuda(), weight.cuda()) + + config = QuantConfig(bits=8, group_size=32, symmetric=False, method="min_max") + cw = quantize_weight(weight, config) + packed = pack_int8_for_cuda(cw) + + packed_out = torch.nn.functional.linear(x.cuda(), packed.cuda()) + + rel_error = ( + packed_out.float() - original_out.float() + ).abs().mean() / original_out.float().abs().mean() + self.assertLess(rel_error.item(), 0.02) + def test_per_axis_gather_approximates_original(self): """Per-axis INT8 (group_size == K) works for embedding gather.""" torch.manual_seed(0) diff --git a/examples/models/gemma4_31b/quant/tests/test_serialize.py b/examples/models/gemma4_31b/quant/tests/test_serialize.py index e75876fa4cd..776e4a4d7a0 100644 --- a/examples/models/gemma4_31b/quant/tests/test_serialize.py +++ b/examples/models/gemma4_31b/quant/tests/test_serialize.py @@ -59,6 +59,67 @@ def _make_cqw( ) +# --------------------------------------------------------------------------- +# CanonicalQuantizedWeight validation + + +class TestCanonicalQuantizedWeightValidation(unittest.TestCase): + def test_rejects_non_int8_qdata(self): + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + with self.assertRaises(ValueError) as ctx: + CanonicalQuantizedWeight( + qdata=torch.randint(0, 16, (8, 64), dtype=torch.int32), + scale=torch.randn(8, 2, dtype=torch.bfloat16), + zero=torch.randn(8, 2, dtype=torch.bfloat16), + config=config, + ) + self.assertIn("int8", str(ctx.exception)) + + def test_rejects_indivisible_group_size(self): + config = QuantConfig(bits=4, group_size=33, symmetric=False, method="min_max") + with self.assertRaises(ValueError) as ctx: + CanonicalQuantizedWeight( + qdata=torch.randint(0, 16, (8, 64), dtype=torch.int8), + scale=torch.randn(8, 2, dtype=torch.bfloat16), + zero=torch.randn(8, 2, dtype=torch.bfloat16), + config=config, + ) + self.assertIn("divisible", str(ctx.exception)) + + def test_rejects_wrong_scale_numel(self): + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + with self.assertRaises(ValueError) as ctx: + CanonicalQuantizedWeight( + qdata=torch.randint(0, 16, (8, 64), dtype=torch.int8), + scale=torch.randn(8, 4, dtype=torch.bfloat16), # should be (8, 2) + zero=torch.randn(8, 4, dtype=torch.bfloat16), + config=config, + ) + self.assertIn("scale", str(ctx.exception)) + + def test_rejects_symmetric_with_zero(self): + config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") + with self.assertRaises(ValueError) as ctx: + CanonicalQuantizedWeight( + qdata=torch.randint(0, 16, (8, 64), dtype=torch.int8), + scale=torch.randn(8, 2, dtype=torch.bfloat16), + zero=torch.randn(8, 2, dtype=torch.bfloat16), + config=config, + ) + self.assertIn("symmetric", str(ctx.exception)) + + def test_rejects_asymmetric_without_zero(self): + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + with self.assertRaises(ValueError) as ctx: + CanonicalQuantizedWeight( + qdata=torch.randint(0, 16, (8, 64), dtype=torch.int8), + scale=torch.randn(8, 2, dtype=torch.bfloat16), + zero=None, + config=config, + ) + self.assertIn("asymmetric", str(ctx.exception)) + + # --------------------------------------------------------------------------- # Nibble pack / unpack diff --git a/examples/models/gemma4_31b/quantize_and_save.py b/examples/models/gemma4_31b/quantize_and_save.py index 7a9eb9900f2..6d048d1e912 100644 --- a/examples/models/gemma4_31b/quantize_and_save.py +++ b/examples/models/gemma4_31b/quantize_and_save.py @@ -48,7 +48,9 @@ _INT4 = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") _INT4_HQQ = QuantConfig(bits=4, group_size=32, symmetric=True, method="hqq") _INT8 = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") -_INT8_PER_AXIS = QuantConfig(bits=8, group_size=5376, symmetric=True, method="min_max") +_INT8_PER_AXIS = QuantConfig( # group_size = hidden_size (5376) for Gemma 4 31B + bits=8, group_size=5376, symmetric=True, method="min_max" +) _EDGE_LAYERS = set(range(15)) | set(range(45, 60)) GEMMA4_31B_DEFAULT_RECIPE = QuantRecipe( diff --git a/examples/models/gemma4_31b/tests/test_pipeline.py b/examples/models/gemma4_31b/tests/test_pipeline.py index d718a2811cb..8bc1b36bfb7 100644 --- a/examples/models/gemma4_31b/tests/test_pipeline.py +++ b/examples/models/gemma4_31b/tests/test_pipeline.py @@ -22,7 +22,11 @@ import torch import torch.nn as nn -from executorch.examples.models.gemma4_31b.model import Gemma4_31B, Gemma4_31BConfig +from executorch.examples.models.gemma4_31b.model import ( + Gemma4_31B, + Gemma4_31BConfig, + RingKVCache, +) from executorch.examples.models.gemma4_31b.quant import ( load, QuantConfig, @@ -205,5 +209,62 @@ def test_corrupted_checkpoint_missing_key(self): self.assertIn("norm.BOGUS", u) +class TestRingKVCache(unittest.TestCase): + """Unit tests for the ring-buffer KV cache (CPU, no model needed).""" + + def _make_cache(self, window=4, heads=2, head_dim=8): + return RingKVCache( + max_batch_size=1, window_size=window, num_kv_heads=heads, head_dim=head_dim + ) + + def test_sequential_write_read(self): + """Writing positions 0..buf_size-1 fills every slot exactly once.""" + cache = self._make_cache(window=4) + buf_size = cache.buf_size # 8 + for i in range(buf_size): + pos = torch.tensor([i], dtype=torch.long) + k = torch.full((1, 2, 1, 8), float(i)) + v = torch.full((1, 2, 1, 8), float(i + 100)) + k_out, v_out = cache.update(pos, k, v) + for i in range(buf_size): + slot = i % buf_size + self.assertEqual(k_out[0, 0, slot, 0].item(), float(i)) + self.assertEqual(v_out[0, 0, slot, 0].item(), float(i + 100)) + + def test_wraparound_overwrites_oldest(self): + """Position buf_size overwrites slot 0 (the oldest entry).""" + cache = self._make_cache(window=4) + buf_size = cache.buf_size # 8 + for i in range(buf_size + 1): + pos = torch.tensor([i], dtype=torch.long) + k = torch.full((1, 2, 1, 8), float(i)) + v = torch.full((1, 2, 1, 8), float(i)) + k_out, _ = cache.update(pos, k, v) + # Slot 0 should now contain position buf_size (not 0) + self.assertEqual(k_out[0, 0, 0, 0].item(), float(buf_size)) + # Slot 1 should still contain position 1 + self.assertEqual(k_out[0, 0, 1, 0].item(), 1.0) + + def test_multi_token_prefill(self): + """Writing multiple positions in one call places them correctly.""" + cache = self._make_cache(window=4) + pos = torch.arange(4, dtype=torch.long) + k = torch.arange(4).float().view(1, 1, 4, 1).expand(1, 2, 4, 8) + v = torch.zeros(1, 2, 4, 8) + k_out, _ = cache.update(pos, k, v) + for i in range(4): + self.assertEqual(k_out[0, 0, i, 0].item(), float(i)) + + def test_assert_on_oversized_prefill(self): + """seq_len > buf_size raises AssertionError.""" + cache = self._make_cache(window=4) + buf_size = cache.buf_size + pos = torch.arange(buf_size + 1, dtype=torch.long) + k = torch.zeros(1, 2, buf_size + 1, 8) + v = torch.zeros(1, 2, buf_size + 1, 8) + with self.assertRaises(AssertionError): + cache.update(pos, k, v) + + if __name__ == "__main__": unittest.main() From 89cd615a854062a80f0ce1d8601b3939f7cbdd3b Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Thu, 30 Apr 2026 14:53:09 -0400 Subject: [PATCH 06/14] Add streaming iter_load and tighten quant public API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - serialize.py: add iter_load() generator that streams weights one at a time from safetensors, keeping peak memory proportional to the largest single weight instead of loading all weights into memory at once. - pack_cuda.py: rewrite load_and_pack_for_cuda to use iter_load for streaming — avoids ~40 GB peak memory when loading the 31B checkpoint. - __init__.py: remove low-level CUDA packer internals (pack_int4_for_cuda, pack_int8_for_cuda, pack_linear_for_cuda, pack_embedding_for_cuda) from the public API. Tests import these directly from pack_cuda.py. --- examples/models/gemma4_31b/quant/__init__.py | 9 +-- examples/models/gemma4_31b/quant/pack_cuda.py | 24 +++++-- examples/models/gemma4_31b/quant/serialize.py | 48 +++++++++++++- .../gemma4_31b/quant/tests/test_serialize.py | 65 +++++++++++++++++++ 4 files changed, 132 insertions(+), 14 deletions(-) diff --git a/examples/models/gemma4_31b/quant/__init__.py b/examples/models/gemma4_31b/quant/__init__.py index 96a227eb966..2f4ad98c864 100644 --- a/examples/models/gemma4_31b/quant/__init__.py +++ b/examples/models/gemma4_31b/quant/__init__.py @@ -5,14 +5,7 @@ # LICENSE file in the root directory of this source tree. from .pack import ModulePackerFn, pack_model, pack_one # noqa: F401 -from .pack_cuda import ( # noqa: F401 - DEFAULT_CUDA_PACKERS, - load_and_pack_for_cuda, - pack_embedding_for_cuda, - pack_int4_for_cuda, - pack_int8_for_cuda, - pack_linear_for_cuda, -) +from .pack_cuda import DEFAULT_CUDA_PACKERS, load_and_pack_for_cuda # noqa: F401 from .quantize import dequantize_weight, quantize_model, quantize_weight # noqa: F401 from .recipe import QuantConfig, QuantRecipe, QuantRule # noqa: F401 from .serialize import ( # noqa: F401 diff --git a/examples/models/gemma4_31b/quant/pack_cuda.py b/examples/models/gemma4_31b/quant/pack_cuda.py index 039f2cbf7ba..a3a6257ae0c 100644 --- a/examples/models/gemma4_31b/quant/pack_cuda.py +++ b/examples/models/gemma4_31b/quant/pack_cuda.py @@ -17,7 +17,7 @@ import torch.nn as nn from .pack import ModulePackerFn, pack_model # noqa: F401 -from .serialize import CanonicalQuantizedWeight, load +from .serialize import CanonicalQuantizedWeight # --------------------------------------------------------------------------- @@ -202,9 +202,23 @@ def load_and_pack_for_cuda( model: nn.Module, packers: dict[type, ModulePackerFn] | None = None, ) -> None: - """Read a quantized safetensors file and pack into ``model`` for CUDA. + """Stream weights from a quantized safetensors file and pack for CUDA. - Thin wrapper: ``load`` + ``pack_model``. + Uses ``iter_load`` to process one weight at a time, keeping peak + memory proportional to the largest single weight instead of loading + all weights into memory at once. """ - quantized, unquantized = load(path) - pack_model(model, quantized, unquantized, packers or DEFAULT_CUDA_PACKERS) + from .pack import pack_one + from .serialize import iter_load + + _packers = packers or DEFAULT_CUDA_PACKERS + + for fqn, value in iter_load(path): + pack_one(model, fqn, value, _packers) + + for fqn, p in model.named_parameters(): + if p.device.type == "meta": + raise RuntimeError( + f"Weight '{fqn}' not found in checkpoint " + f"(model/checkpoint version mismatch?)" + ) diff --git a/examples/models/gemma4_31b/quant/serialize.py b/examples/models/gemma4_31b/quant/serialize.py index ee3cb594cce..35eae615fba 100644 --- a/examples/models/gemma4_31b/quant/serialize.py +++ b/examples/models/gemma4_31b/quant/serialize.py @@ -22,7 +22,7 @@ import json from dataclasses import dataclass -from typing import Optional +from typing import Iterator, Optional import torch from safetensors import safe_open @@ -233,3 +233,49 @@ def load( header = f.metadata() tensors = {k: f.get_tensor(k) for k in f.keys()} return deserialize(tensors, header) + + +def iter_load( + path: str, +) -> Iterator[tuple[str, CanonicalQuantizedWeight | torch.Tensor]]: + """Stream weights from a safetensors file one at a time. + + Yields ``(fqn, value)`` where *value* is a ``CanonicalQuantizedWeight`` + for quantized weights or a plain ``torch.Tensor`` for unquantized ones. + Only one weight's tensors are resident in memory at a time, keeping peak + memory proportional to the largest single weight. + """ + with safe_open(path, framework="pt", device="cpu") as f: + header = f.metadata() + quant_meta = json.loads(header.get("quant", "{}")) + all_keys = set(f.keys()) + consumed: set[str] = set() + + for fqn, meta in quant_meta.items(): + config = QuantConfig( + bits=meta["bits"], + group_size=meta["group_size"], + symmetric=meta["symmetric"], + method=meta["method"], + ) + qdata = f.get_tensor(f"{fqn}.qdata") + consumed.add(f"{fqn}.qdata") + if config.bits == 4: + qdata = _nibble_unpack(qdata, meta["shape"][-1]) + + scale = f.get_tensor(f"{fqn}.scale") + consumed.add(f"{fqn}.scale") + + zero_key = f"{fqn}.zero" + zero = None + if zero_key in all_keys: + zero = f.get_tensor(zero_key) + consumed.add(zero_key) + + yield fqn, CanonicalQuantizedWeight( + qdata=qdata, scale=scale, zero=zero, config=config + ) + + for key in all_keys: + if key not in consumed: + yield key, f.get_tensor(key) diff --git a/examples/models/gemma4_31b/quant/tests/test_serialize.py b/examples/models/gemma4_31b/quant/tests/test_serialize.py index 776e4a4d7a0..ddbcb257faa 100644 --- a/examples/models/gemma4_31b/quant/tests/test_serialize.py +++ b/examples/models/gemma4_31b/quant/tests/test_serialize.py @@ -26,6 +26,7 @@ _nibble_unpack, CanonicalQuantizedWeight, deserialize, + iter_load, load, save, serialize, @@ -264,5 +265,69 @@ def test_empty_quantized(self): self.assertTrue(torch.equal(unq["w"], u["w"])) +class TestIterLoad(unittest.TestCase): + """Streaming load — one weight at a time from disk.""" + + def test_yields_all_weights(self): + """iter_load yields every quantized and unquantized weight.""" + q4 = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + q8 = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") + cw4 = _make_cqw((64, 128), q4) + cw8 = _make_cqw((32, 64), q8) + unq = {"norm.weight": torch.randn(64, dtype=torch.bfloat16)} + + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save({"proj.weight": cw4, "embed.weight": cw8}, unq, path) + items = list(iter_load(path)) + + fqns = {fqn for fqn, _ in items} + self.assertIn("proj.weight", fqns) + self.assertIn("embed.weight", fqns) + self.assertIn("norm.weight", fqns) + self.assertEqual(len(items), 3) + + def test_quantized_matches_load(self): + """Streaming yields identical CQW to batch load.""" + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + cw = _make_cqw((64, 128), config) + + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save({"w": cw}, {}, path) + + q_batch, _ = load(path) + items = dict(iter_load(path)) + + batch_cw = q_batch["w"] + stream_cw = items["w"] + self.assertIsInstance(stream_cw, CanonicalQuantizedWeight) + self.assertTrue(torch.equal(batch_cw.qdata, stream_cw.qdata)) + self.assertTrue(torch.equal(batch_cw.scale, stream_cw.scale)) + self.assertTrue(torch.equal(batch_cw.zero, stream_cw.zero)) + self.assertEqual(batch_cw.config, stream_cw.config) + + def test_unquantized_matches_load(self): + """Streaming yields identical plain tensors to batch load.""" + unq = {"a": torch.randn(8, 16, dtype=torch.bfloat16)} + + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save({}, unq, path) + + _, u_batch = load(path) + items = dict(iter_load(path)) + + self.assertTrue(torch.equal(u_batch["a"], items["a"])) + + def test_empty_file(self): + """Streaming an empty checkpoint yields nothing.""" + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save({}, {}, path) + items = list(iter_load(path)) + self.assertEqual(len(items), 0) + + if __name__ == "__main__": unittest.main() From 94c0d4cad47ab596f5fd7347bcd97d5fbb9cdc69 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Thu, 30 Apr 2026 15:04:59 -0400 Subject: [PATCH 07/14] Add BOS/EOS token handling to C++ runner Gemma's HuggingFace tokenizer does not auto-prepend BOS. Without it the model's logits collapse. Add --bos_id (default 2) to prepend and --eos_id (default 1) as a fallback stop token. --- examples/models/gemma4_31b/main.cpp | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/examples/models/gemma4_31b/main.cpp b/examples/models/gemma4_31b/main.cpp index f86538a27c2..351adc03a33 100644 --- a/examples/models/gemma4_31b/main.cpp +++ b/examples/models/gemma4_31b/main.cpp @@ -30,6 +30,22 @@ #include #include +#include +#include +extern "C" void et_pal_emit_log_message( + ET_UNUSED et_timestamp_t timestamp, + et_pal_log_level_t level, + const char* filename, + ET_UNUSED const char* function, + size_t line, + const char* message, + ET_UNUSED size_t length) { + if (level < 'W') { + return; + } + fprintf(stderr, "%c [%s:%zu] %s\n", (char)level, filename, line, message); +} + #ifdef EXECUTORCH_BUILD_CUDA #include #endif @@ -44,6 +60,8 @@ DEFINE_string( "Path to file containing prompt text (overrides --prompt)."); DEFINE_double(temperature, 0.8, "Sampling temperature (0 = near-greedy)."); DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate."); +DEFINE_int32(bos_id, 2, "BOS token id to prepend (Gemma convention: 2)."); +DEFINE_int32(eos_id, 1, "EOS token id (Gemma convention: 1)."); DEFINE_bool( cuda_graph, false, @@ -196,6 +214,7 @@ int main(int argc, char** argv) { #endif auto eos_ids = llm::get_eos_ids(tokenizer.get(), module.get()); + eos_ids.insert(static_cast(FLAGS_eos_id)); // Read prompt from file or flag std::string prompt_text = FLAGS_prompt; @@ -217,6 +236,9 @@ int main(int argc, char** argv) { return 1; } auto prompt_tokens = std::move(*encode_result); + // Gemma models require BOS at the start of the sequence. + prompt_tokens.insert( + prompt_tokens.begin(), static_cast(FLAGS_bos_id)); int64_t num_prompt_tokens = static_cast(prompt_tokens.size()); printf("Prompt tokens: %" PRId64 "\n", num_prompt_tokens); stats.num_prompt_tokens = num_prompt_tokens; From ad57fc9e6c73a7ec1436a0c3e51524c1a9457f62 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Fri, 1 May 2026 09:46:51 -0700 Subject: [PATCH 08/14] Replace CanonicalQuantizedWeight with torchao tensor subclasses MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Delete the custom CanonicalQuantizedWeight dataclass and serialize.py format. Quantized weights are now stored as torchao's native Int4Tensor (4-bit) and IntxUnpackedToInt8Tensor (8-bit) subclasses, serialized via torchao's safetensors integration. Key changes: - quantize_weight returns Int4Tensor or IntxUnpackedToInt8Tensor - quantize_model returns a single state_dict (not two dicts) - 8-bit quantization done in float32 to avoid bf16 precision loss (manual quantize + direct IntxUnpackedToInt8Tensor construction) - Sensitive recipe uses HQQ asymmetric INT4 (scale + zero optimization) - pack_model takes a single state_dict, dispatches by isinstance - pack.py uses TorchAOBaseTensor for quantized weight detection - GGUF unpacker produces Int4Tensor/IntxUnpackedToInt8Tensor directly - serialize.py dissolved — callers inline torchao safetensors directly Breaking change: existing prequantized checkpoints (old format) must be regenerated with quantize_and_save.py. --- examples/models/gemma4_31b/README.md | 9 +- examples/models/gemma4_31b/export.py | 4 +- examples/models/gemma4_31b/gguf_loader.py | 24 +- examples/models/gemma4_31b/model.md | 13 +- examples/models/gemma4_31b/quant/README.md | 81 ++-- examples/models/gemma4_31b/quant/__init__.py | 7 - examples/models/gemma4_31b/quant/gguf.py | 106 +++-- examples/models/gemma4_31b/quant/pack.py | 73 ++-- examples/models/gemma4_31b/quant/pack_cuda.py | 193 ++++----- examples/models/gemma4_31b/quant/quantize.py | 236 +++++++---- examples/models/gemma4_31b/quant/recipe.py | 10 +- examples/models/gemma4_31b/quant/serialize.py | 281 ------------- .../gemma4_31b/quant/tests/test_gguf.py | 72 +++- .../gemma4_31b/quant/tests/test_pack_cuda.py | 378 +++++------------- .../gemma4_31b/quant/tests/test_quantize.py | 194 ++++----- .../quant/tests/test_safetensors_roundtrip.py | 143 +++++++ .../gemma4_31b/quant/tests/test_serialize.py | 333 --------------- .../models/gemma4_31b/quantize_and_save.py | 24 +- .../gemma4_31b/tests/test_cuda_pipeline.py | 4 +- .../models/gemma4_31b/tests/test_pipeline.py | 122 ++++-- 20 files changed, 857 insertions(+), 1450 deletions(-) delete mode 100644 examples/models/gemma4_31b/quant/serialize.py create mode 100644 examples/models/gemma4_31b/quant/tests/test_safetensors_roundtrip.py delete mode 100644 examples/models/gemma4_31b/quant/tests/test_serialize.py diff --git a/examples/models/gemma4_31b/README.md b/examples/models/gemma4_31b/README.md index 2b1c92e31d8..6f567d739b7 100644 --- a/examples/models/gemma4_31b/README.md +++ b/examples/models/gemma4_31b/README.md @@ -19,10 +19,11 @@ both export and eager inference: | `inference.py --gguf ` | GGUF file (Q4_K_M, etc.) → eager generation | ~24 GB GPU | | `export.py --model-dir ` | one-shot bf16 → quantize → export (no intermediate file) | ~30 GB CPU + CUDA for packing | -The quantized checkpoint is a safetensors file with int values + per-group -scales and a JSON header describing each weight's `QuantConfig`. No tensor -subclass or backend-specific packing — packing for the target backend happens -at load time via `quant.pack_model()`. +The quantized checkpoint is a safetensors file containing torchao tensor +subclasses (`Int4Tensor`, `IntxUnpackedToInt8Tensor`) and plain tensors. +Metadata records each subclass's type and attributes. No backend-specific +packing — packing for the target backend happens at load time via +`quant.pack_model()`. ## Quantization recipes diff --git a/examples/models/gemma4_31b/export.py b/examples/models/gemma4_31b/export.py index 69eabd6b70c..78668c55118 100644 --- a/examples/models/gemma4_31b/export.py +++ b/examples/models/gemma4_31b/export.py @@ -82,12 +82,12 @@ def load_and_quantize( model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) print(f"Quantizing with recipe '{recipe_name}'...") - quantized, unquantized = quantize_model(model, recipe) + state_dict = quantize_model(model, recipe) print(f"Packing for {backend}...") with torch.device("meta"): model = Gemma4_31B(config) - pack_model(model, quantized, unquantized, packers=_get_packers(backend)) + pack_model(model, state_dict, packers=_get_packers(backend)) model.eval() print(f"Model: {config.num_hidden_layers} layers, hidden={config.hidden_size}") diff --git a/examples/models/gemma4_31b/gguf_loader.py b/examples/models/gemma4_31b/gguf_loader.py index c55e15be888..3e50991e553 100644 --- a/examples/models/gemma4_31b/gguf_loader.py +++ b/examples/models/gemma4_31b/gguf_loader.py @@ -64,15 +64,15 @@ def gguf_to_model_key(gguf_key: str) -> Optional[str]: return None -def _resolve_tied_lm_head(model, embed_cw, packers): +def _resolve_tied_lm_head(model, embed_quant, packers): """Handle tied embed/lm_head after streaming all tensors.""" from executorch.examples.models.gemma4_31b.quant import pack_one lm_head = getattr(model.lm_head, "weight", None) if lm_head is None or lm_head.device.type != "meta": return - if embed_cw is not None: - pack_one(model, "lm_head.weight", embed_cw, packers) + if embed_quant is not None: + pack_one(model, "lm_head.weight", embed_quant, packers) else: pack_one( model, @@ -114,9 +114,7 @@ def load_gguf_model( from executorch.examples.models.gemma4_31b.model import Gemma4_31B, Gemma4_31BConfig from executorch.examples.models.gemma4_31b.quant import dequantize_weight, pack_one from executorch.examples.models.gemma4_31b.quant.gguf import iter_gguf_tensors - from executorch.examples.models.gemma4_31b.quant.serialize import ( - CanonicalQuantizedWeight, - ) + from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor if backend == "cuda": from executorch.examples.models.gemma4_31b.quant import DEFAULT_CUDA_PACKERS @@ -131,7 +129,7 @@ def load_gguf_model( with torch.device("meta"): model = Gemma4_31B(config) - embed_cw = None + embed_quant = None n_processed = 0 print(f"Streaming GGUF from {gguf_path}...") @@ -140,13 +138,11 @@ def load_gguf_model( if model_key is None: continue - if isinstance(result, torch.Tensor) and result.dtype == torch.float32: + if type(result) is torch.Tensor and result.dtype == torch.float32: result = result.to(torch.bfloat16) - if model_key == "embed_tokens.weight" and isinstance( - result, CanonicalQuantizedWeight - ): - embed_cw = result + if model_key == "embed_tokens.weight" and isinstance(result, Int4Tensor): + embed_quant = result result = dequantize_weight(result, torch.bfloat16) pack_one(model, model_key, result, packers) @@ -155,8 +151,8 @@ def load_gguf_model( if n_processed % 100 == 0: print(f" Processed {n_processed} tensors...") - _resolve_tied_lm_head(model, embed_cw, packers) - del embed_cw + _resolve_tied_lm_head(model, embed_quant, packers) + del embed_quant _validate_no_meta(model) model.eval() diff --git a/examples/models/gemma4_31b/model.md b/examples/models/gemma4_31b/model.md index 622ff657b42..9a8c6d84e5f 100644 --- a/examples/models/gemma4_31b/model.md +++ b/examples/models/gemma4_31b/model.md @@ -126,9 +126,10 @@ Modules in `quant/`: - **Recipe** (`recipe.py`): `QuantConfig` + `QuantRule` + `QuantRecipe`. Declares what to quantize — says nothing about packing or backends. - **Quantize** (`quantize.py`): `quantize_weight` / `dequantize_weight` / - `quantize_model`. Produces `CanonicalQuantizedWeight` from fp weights. -- **Serialize** (`serialize.py`): `CanonicalQuantizedWeight` (int8 qdata + - bf16 scale + optional zero). `save()` / `load()` persist to safetensors. + `quantize_model`. Produces torchao tensor subclasses (`Int4Tensor`, + `IntxUnpackedToInt8Tensor`) from fp weights. +- **Serialization**: callers use torchao's safetensors integration + (`torchao.prototype.safetensors`) directly — no wrapper module needed. - **Pack** (`pack.py` + `pack_cuda.py`): `pack_model` groups weights by parent module, `pack_one` handles single weights. Per-module packers dispatch by module type (`nn.Linear`, `nn.Embedding`, extensible for MoE). @@ -142,11 +143,11 @@ quantize_and_save.py export.py / inference.py | | bf16 weights quantized checkpoint (safetensors) | | - quantize_weight() load() + quantize_weight() load (torchao safetensors) | | - CanonicalQuantizedWeight CanonicalQuantizedWeight + Int4Tensor / IntxUnpacked Int4Tensor / IntxUnpacked | | - save() pack_model() + save (torchao safetensors) pack_model() | | model.safetensors Int4TilePackedTo4dTensor (runtime) ``` diff --git a/examples/models/gemma4_31b/quant/README.md b/examples/models/gemma4_31b/quant/README.md index f14d02af2cf..31b1c43d574 100644 --- a/examples/models/gemma4_31b/quant/README.md +++ b/examples/models/gemma4_31b/quant/README.md @@ -1,27 +1,29 @@ # quant/ -Packing-agnostic quantization framework: **recipe → quantize → serialize → pack**. +Quantization framework: **recipe → quantize → pack**. ## Files | File | Concern | Depends on | |---|---|---| | `recipe.py` | **Policy** — what to quantize, what precision, which layers | nothing | -| `quantize.py` | **Computation** — produces/dequantizes canonical weights | recipe, torchao | -| `serialize.py` | **Data format** — saves/loads canonical weights to safetensors | recipe | -| `pack.py` | **Packing dispatch** — `pack_model` (bulk) and `pack_one` (streaming) | serialize | -| `pack_cuda.py` | **CUDA packing** — converts canonical to tinygemm/intx runtime format | pack, serialize | -| `gguf.py` | **GGUF import** — unpacks Q4_K/Q6_K blocks to canonical form | recipe, serialize | +| `quantize.py` | **Computation** — produces torchao subclass tensors | recipe, torchao | +| `pack.py` | **Packing dispatch** — `pack_model` (bulk) and `pack_one` (streaming) | — | +| `pack_cuda.py` | **CUDA packing** — converts Int4Tensor to tinygemm format | pack | +| `gguf.py` | **GGUF import** — unpacks Q4_K/Q6_K blocks to torchao subclasses | torchao | ## Data flow ``` -QuantRecipe → quantize_model() → CanonicalQuantizedWeight → save() → file → load() → CanonicalQuantizedWeight → pack_model() → runtime model +QuantRecipe → quantize_model() → state_dict{Int4Tensor, IntxUnpackedToInt8Tensor, Tensor} → safetensors → state_dict → pack_model() → runtime model ``` -`CanonicalQuantizedWeight` is the interchange point — int8 qdata + bf16 -scale + optional zero + config. Everything left of it is backend-agnostic. -Everything right is backend-specific. +Quantized weights are stored as torchao tensor subclasses: +- **Int4Tensor** — 4-bit weights (nibble-packed qdata + transposed scale/zero_point) +- **IntxUnpackedToInt8Tensor** — 8-bit weights (int8 qdata + scale + zero_point) + +These are the canonical interchange formats from torchao. Everything left +of `save()` is backend-agnostic. Everything right is backend-specific. ## Adding a new backend @@ -32,56 +34,21 @@ def pack_linear_for_metal(module, weights): ... DEFAULT_METAL_PACKERS = {nn.Linear: pack_linear_for_metal} ``` -Call `pack_model(model, quantized, unquantized, packers=DEFAULT_METAL_PACKERS)`. -No changes to recipe, quantize, or serialize. - -Things to consider: - -- **Recipes may need to be backend-aware.** Each backend's kernels have - different constraints (e.g., Metal's `fpa4w` is INT4-only — no INT8 linear - kernel, so the sensitive recipe's 8-bit edge layers would need to be INT4 - or dequantized to bf16). Define per-backend recipes or validate recipe - compatibility at pack time. -- **Source transforms before packing.** Some backends replace model modules - (e.g., MLX swaps `FusedMoEExperts` → `SwitchMLP`, Metal swaps to - `MetalMoEExperts`). These transforms change the module types that - packers dispatch on, so they must run before `pack_model()`. For dense - models (no MoE) this is not needed. -- **Embedding quantization.** Not all backends have a quantized embedding - gather kernel. The packer can dequantize to bf16 at load time — the - disk savings from the canonical format still apply. - -## Adding a new model - -1. Define a `QuantRecipe` with rules for the model's FQN patterns. -2. If the model has custom module types (e.g., `FusedMoEExperts`), write a - per-module packer and extend the packers dict: - ```python - packers = {**DEFAULT_CUDA_PACKERS, FusedMoEExperts: pack_moe_experts} - ``` -3. No changes to the quant package itself. +Call `pack_model(model, state_dict, packers=DEFAULT_METAL_PACKERS)`. +No changes to recipe or quantize. ## On-disk format -Safetensors with a `format_version` in the header. Per quantized weight: -`{fqn}.qdata` (int8, nibble-packed for 4-bit), `{fqn}.scale` (bf16), -optionally `{fqn}.zero` (bf16). Header JSON records bits, group_size, -symmetric, and method per weight. Unquantized weights stored as-is. +Uses torchao's safetensors integration (`torchao.prototype.safetensors`). +Each tensor subclass is decomposed into its inner tensors +(e.g., `layer._weight_qdata`, `layer._weight_scale`) plus JSON metadata +recording the subclass type and attributes. Plain tensors are stored as-is. +The format is compatible with torchao's `save_pretrained` / `load_pretrained`. ## TODO -- `pack_metal.py` — Metal backend packer. Convert canonical INT4 to - `UIntxWeightOnlyConfig` subclass (torchao experimental) for the - `torchao::_linear_fp_act_4bit_weight` kernel. For MoE models, pack - expert weights into Metal's `gather_qmv` format (asymmetric, unsigned - INT4 with scale + bias buffers). - -- `pack_mlx.py` — MLX backend packer. Convert canonical INT4 to - `IntxWeightOnlyConfig` subclass for the `mlx::gather_qmm` kernel. - For MoE models, stack per-expert weights into `SwitchLinear` format. - -- `gguf.py` — extend with Q5_K, Q8_0, and other GGUF quant types. - Currently supports Q4_K and Q6_K. Some Q4_K_M files also contain - Q5_K or Q8_0 tensors (for sensitive layers on certain architectures) - which will raise — add support as needed. Q6_K is widened to 8-bit - for CUDA packing since there is no 6-bit CUDA kernel. +- `pack_metal.py` — Metal backend packer. +- `pack_mlx.py` — MLX backend packer. +- `gguf.py` — extend with Q5_K, Q8_0 GGUF quant types. +- Upstream `Int4TilePackedTo4dTensor.from_int4_tensor()` to torchao + to replace the manual conversion in `pack_int4_for_cuda`. diff --git a/examples/models/gemma4_31b/quant/__init__.py b/examples/models/gemma4_31b/quant/__init__.py index 2f4ad98c864..93efb69865f 100644 --- a/examples/models/gemma4_31b/quant/__init__.py +++ b/examples/models/gemma4_31b/quant/__init__.py @@ -8,10 +8,3 @@ from .pack_cuda import DEFAULT_CUDA_PACKERS, load_and_pack_for_cuda # noqa: F401 from .quantize import dequantize_weight, quantize_model, quantize_weight # noqa: F401 from .recipe import QuantConfig, QuantRecipe, QuantRule # noqa: F401 -from .serialize import ( # noqa: F401 - CanonicalQuantizedWeight, - deserialize, - load, - save, - serialize, -) diff --git a/examples/models/gemma4_31b/quant/gguf.py b/examples/models/gemma4_31b/quant/gguf.py index 27f244ed2c1..78c3aa3d8f9 100644 --- a/examples/models/gemma4_31b/quant/gguf.py +++ b/examples/models/gemma4_31b/quant/gguf.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""Unpack GGUF quantized tensors to CanonicalQuantizedWeight. +"""Unpack GGUF quantized tensors to torchao tensor subclasses. Supports Q4_K, Q6_K, F32, and F16 tensor types. Two public APIs: @@ -19,9 +19,6 @@ import torch -from .recipe import QuantConfig -from .serialize import CanonicalQuantizedWeight - QK_K = 256 # super-block size for k-quants Q4_K_GROUPS = 8 # sub-blocks per Q4_K super-block Q4_K_GROUP_SIZE = QK_K // Q4_K_GROUPS # 32 @@ -29,18 +26,18 @@ Q6_K_GROUP_SIZE = QK_K // Q6_K_GROUPS # 16 -def _raw_tensor(data): +def _raw_tensor(data: bytes) -> torch.Tensor: """Wrap a numpy mmap view as a uint8 torch tensor (zero-copy).""" return torch.frombuffer(memoryview(data), dtype=torch.uint8) -def _read_f16(raw, col_start, col_end): +def _read_f16(raw: torch.Tensor, col_start: int, col_end: int) -> torch.Tensor: """Read fp16 field from block bytes, return float32.""" return raw[:, col_start:col_end].contiguous().view(torch.float16).float() -def _unpack_q4_k(data, shape: list[int]) -> CanonicalQuantizedWeight: - """Unpack Q4_K super-blocks into canonical form. +def _unpack_q4_k(data, shape: list[int]) -> torch.Tensor: + """Unpack Q4_K super-blocks into an ``Int4Tensor``. Q4_K block layout (144 bytes per 256 values): - d (2B, fp16): super-block scale @@ -50,18 +47,19 @@ def _unpack_q4_k(data, shape: list[int]) -> CanonicalQuantizedWeight: Dequant: weight = d * sub_scale * q - dmin * sub_min """ + from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + N, K = shape assert K % QK_K == 0, f"Q4_K requires K divisible by {QK_K}, got {K}" n_blocks = N * (K // QK_K) block_bytes = 2 + 2 + 12 + QK_K // 2 # 144 raw = _raw_tensor(data).reshape(n_blocks, block_bytes) - d = _read_f16(raw, 0, 2) # (n_blocks, 1) - dmin = _read_f16(raw, 2, 4) # (n_blocks, 1) - s = raw[:, 4:16] # (n_blocks, 12) - qs = raw[:, 16:144] # (n_blocks, 128) + d = _read_f16(raw, 0, 2) + dmin = _read_f16(raw, 2, 4) + s = raw[:, 4:16] + qs = raw[:, 16:144] - # Unpack 6-bit scales/mins and compute effective scale/zero directly sc = torch.empty(n_blocks, 8, dtype=torch.float32) mn = torch.empty(n_blocks, 8, dtype=torch.float32) sc[:, :4] = (s[:, :4] & 0x3F).float() @@ -79,11 +77,10 @@ def _unpack_q4_k(data, shape: list[int]) -> CanonicalQuantizedWeight: ) del eff_min - # GGUF Q4_K nibble order: for each 32-byte group, 32 low nibbles come - # first (positions 0..31), then 32 high nibbles (positions 32..63). - low = (qs & 0x0F).to(torch.int8) # (n_blocks, 128) - high = ((qs >> 4) & 0x0F).to(torch.int8) - qdata = torch.cat( + # GGUF Q4_K nibble order: 32 lows then 32 highs per sub-block pair + low = (qs & 0x0F).to(torch.uint8) + high = ((qs >> 4) & 0x0F).to(torch.uint8) + qdata_unpacked = torch.cat( [ low[:, :32], high[:, :32], @@ -95,21 +92,24 @@ def _unpack_q4_k(data, shape: list[int]) -> CanonicalQuantizedWeight: high[:, 96:128], ], dim=-1, - ) # (n_blocks, 256) + ).reshape(N, K) del qs, low, high - return CanonicalQuantizedWeight( - qdata=qdata.reshape(N, K), - scale=eff_scale.to(torch.bfloat16), - zero=zero_std.to(torch.bfloat16), - config=QuantConfig( - bits=4, group_size=Q4_K_GROUP_SIZE, symmetric=False, method="gguf_q4_k" - ), + # Nibble-pack for Int4Tensor: even=LOW, odd=HIGH + packed = qdata_unpacked[:, ::2] | (qdata_unpacked[:, 1::2] << 4) + + # Int4Tensor scale/zero layout: (K//gs, N) — transposed + return Int4Tensor( + qdata=packed, + scale=eff_scale.to(torch.bfloat16).t().contiguous(), + zero_point=zero_std.to(torch.bfloat16).t().contiguous(), + block_size=[1, Q4_K_GROUP_SIZE], + shape=torch.Size([N, K]), ) -def _unpack_q6_k(data, shape: list[int]) -> CanonicalQuantizedWeight: - """Unpack Q6_K super-blocks into canonical form as INT8. +def _unpack_q6_k(data, shape: list[int]) -> torch.Tensor: + """Unpack Q6_K super-blocks into an ``IntxUnpackedToInt8Tensor``. Q6_K block layout (210 bytes per 256 values): - ql (128B): lower 4 bits of 256 6-bit values @@ -118,8 +118,10 @@ def _unpack_q6_k(data, shape: list[int]) -> CanonicalQuantizedWeight: - d (2B, fp16): super-block scale Dequant: weight = d * scale_j * (q - 32) - Values are 6-bit [-32, 31], widened to INT8 for canonical storage. + Values are 6-bit [-32, 31], widened to INT8. """ + from torchao.quantization import IntxUnpackedToInt8Tensor + N, K = shape assert K % QK_K == 0, f"Q6_K requires K divisible by {QK_K}, got {K}" n_blocks = N * (K // QK_K) @@ -131,18 +133,13 @@ def _unpack_q6_k(data, shape: list[int]) -> CanonicalQuantizedWeight: sc = raw[:, 192:208] d = _read_f16(raw, 208, 210) - # Combine 4-bit low + 2-bit high into 6-bit, center to [-32, 31]. - # ggml processes 128 values at a time: ql[0..63] + qh[0..31] for the - # first half, ql[64..127] + qh[32..63] for the second half. - qh0 = qh[:, :32] # first 32 qh bytes → first 128 values - qh1 = qh[:, 32:64] # second 32 qh bytes → next 128 values + qh0 = qh[:, :32] + qh1 = qh[:, 32:64] qdata = torch.empty(n_blocks, QK_K, dtype=torch.int16) - # First 128 values qdata[:, 0:32] = (ql[:, :32] & 0x0F) | ((qh0 & 0x03) << 4) qdata[:, 32:64] = (ql[:, 32:64] & 0x0F) | (((qh0 >> 2) & 0x03) << 4) qdata[:, 64:96] = ((ql[:, :32] >> 4) & 0x0F) | (((qh0 >> 4) & 0x03) << 4) qdata[:, 96:128] = ((ql[:, 32:64] >> 4) & 0x0F) | (((qh0 >> 6) & 0x03) << 4) - # Second 128 values qdata[:, 128:160] = (ql[:, 64:96] & 0x0F) | ((qh1 & 0x03) << 4) qdata[:, 160:192] = (ql[:, 96:128] & 0x0F) | (((qh1 >> 2) & 0x03) << 4) qdata[:, 192:224] = ((ql[:, 64:96] >> 4) & 0x0F) | (((qh1 >> 4) & 0x03) << 4) @@ -150,16 +147,18 @@ def _unpack_q6_k(data, shape: list[int]) -> CanonicalQuantizedWeight: qdata -= 32 del ql, qh, qh0, qh1 + # sc bytes are signed int8 scales; reinterpret from uint8 eff_scale = (d * sc.to(torch.int8).float()).reshape(N, -1) del d, sc - return CanonicalQuantizedWeight( + return IntxUnpackedToInt8Tensor( qdata=qdata.reshape(N, K).to(torch.int8), scale=eff_scale.to(torch.bfloat16), - zero=None, - config=QuantConfig( - bits=8, group_size=Q6_K_GROUP_SIZE, symmetric=True, method="gguf_q6_k" - ), + zero_point=torch.zeros_like(eff_scale, dtype=torch.int8), + target_dtype=torch.int8, + block_size=(1, Q6_K_GROUP_SIZE), + dtype=torch.bfloat16, + activation_quantization=None, ) @@ -167,17 +166,11 @@ def unpack_gguf_tensor( tensor_data, tensor_type, shape: list[int], -) -> CanonicalQuantizedWeight | torch.Tensor: - """Unpack a single GGUF tensor into canonical form or a plain tensor. - - Args: - tensor_data: raw numpy/mmap data from GGUFReader - tensor_type: GGMLQuantizationType enum value - shape: tensor shape in PyTorch convention [out_features, in_features] +) -> torch.Tensor: + """Unpack a single GGUF tensor. - Returns: - ``CanonicalQuantizedWeight`` for quantized types (Q4_K, Q6_K), - ``torch.Tensor`` for unquantized types (F32, F16). + Returns an ``Int4Tensor`` for Q4_K, ``IntxUnpackedToInt8Tensor`` for Q6_K, + or a plain ``torch.Tensor`` for F32/F16. """ from gguf import GGMLQuantizationType @@ -200,15 +193,12 @@ def unpack_gguf_tensor( def iter_gguf_tensors( path: str, -) -> Iterator[tuple[str, CanonicalQuantizedWeight | torch.Tensor]]: +) -> Iterator[tuple[str, torch.Tensor]]: """Yield ``(name, result)`` for each tensor in a GGUF file. - Processes one tensor at a time for low peak memory. ``result`` is a - ``CanonicalQuantizedWeight`` for quantized types or a ``torch.Tensor`` - for F32/F16. Tensor names are GGUF names (e.g., ``blk.0.attn_q.weight``); - the caller handles key remapping. - - GGUF shapes are reversed to PyTorch convention automatically. + Processes one tensor at a time for low peak memory. Tensor names are + GGUF names (e.g., ``blk.0.attn_q.weight``); the caller handles key + remapping. GGUF shapes are reversed to PyTorch convention automatically. """ from gguf import GGUFReader diff --git a/examples/models/gemma4_31b/quant/pack.py b/examples/models/gemma4_31b/quant/pack.py index 6d218387f4f..95abc43546a 100644 --- a/examples/models/gemma4_31b/quant/pack.py +++ b/examples/models/gemma4_31b/quant/pack.py @@ -4,57 +4,55 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""Backend-agnostic model packing: canonical weights → runtime model. +"""Backend-agnostic model packing: quantized state dict → runtime model. -``pack_model`` walks a model's quantized weights, groups them by parent +``pack_model`` walks a state dict, groups quantized weights by parent module, and dispatches to per-module packer functions. Each backend -(``pack_cuda.py``, future ``pack_metal.py``) provides its own packers dict -mapping module types to packer functions. - -Pure logic — no file I/O, no backend imports. +(``pack_cuda.py``, future ``pack_metal.py``) provides its own packers dict. """ +from collections import defaultdict from typing import Callable import torch import torch.nn as nn -from .serialize import CanonicalQuantizedWeight - # Packer signature: receives the module + a dict of its quantized weights -# (keyed by attribute name, e.g., {"weight": CQW}), modifies module in-place. -ModulePackerFn = Callable[[nn.Module, dict[str, CanonicalQuantizedWeight]], None] +# (keyed by attribute name), modifies module in-place. +ModulePackerFn = Callable[[nn.Module, dict[str, torch.Tensor]], None] + + +def _is_quantized(value: torch.Tensor) -> bool: + """Check if a tensor is a torchao quantized subclass.""" + from torchao.utils import TorchAOBaseTensor + + return isinstance(value, TorchAOBaseTensor) def pack_model( model: nn.Module, - quantized: dict[str, CanonicalQuantizedWeight], - unquantized: dict[str, torch.Tensor], + state_dict: dict[str, torch.Tensor], packers: dict[type, ModulePackerFn], ) -> None: - """Pack canonical weights into ``model`` using the given packers. + """Pack a state dict into ``model`` using the given packers. - Groups quantized weights by their parent module, then dispatches to the - appropriate per-module packer based on the module's type. Models with - custom module types (e.g., ``FusedMoEExperts``) extend ``packers``. - - Pure logic — no file I/O, no backend dependency. + Quantized weights (torchao tensor subclasses) are grouped by parent + module and dispatched to per-module packers. Plain tensors are assigned + directly as parameters or buffers. """ - - for fqn, tensor in unquantized.items(): - pack_one(model, fqn, tensor, packers) - - # Group quantized weights by parent module so packers that need - # multiple weights at once (e.g., FusedMoEExperts with w1 + w2) - # receive them in a single call. - from collections import defaultdict - - module_weights: dict[str, dict[str, CanonicalQuantizedWeight]] = defaultdict(dict) - for fqn, cw in quantized.items(): - parts = fqn.rsplit(".", 1) - parent_fqn = parts[0] if len(parts) > 1 else "" - attr = parts[-1] - module_weights[parent_fqn][attr] = cw + # Separate quantized and unquantized + for fqn, value in state_dict.items(): + if not _is_quantized(value): + pack_one(model, fqn, value, packers) + + # Group quantized weights by parent module + module_weights: dict[str, dict[str, torch.Tensor]] = defaultdict(dict) + for fqn, value in state_dict.items(): + if _is_quantized(value): + parts = fqn.rsplit(".", 1) + parent_fqn = parts[0] if len(parts) > 1 else "" + attr = parts[-1] + module_weights[parent_fqn][attr] = value for parent_fqn, weights in module_weights.items(): module = model.get_submodule(parent_fqn) if parent_fqn else model @@ -80,21 +78,20 @@ def pack_model( def pack_one( model: nn.Module, fqn: str, - value: CanonicalQuantizedWeight | torch.Tensor, + value: torch.Tensor, packers: dict[type, ModulePackerFn], ) -> None: """Pack a single weight into ``model``. - If ``value`` is a ``CanonicalQuantizedWeight``, dispatches to the - packer for the parent module's type. If it's a plain tensor, assigns - directly as a parameter or buffer. + Quantized subclass tensors are dispatched to the packer for the parent + module's type. Plain tensors are assigned directly. """ parts = fqn.rsplit(".", 1) parent_fqn = parts[0] if len(parts) > 1 else "" attr = parts[-1] parent = model.get_submodule(parent_fqn) if parent_fqn else model - if isinstance(value, CanonicalQuantizedWeight): + if _is_quantized(value): packer = packers.get(type(parent)) if packer is None: raise ValueError( diff --git a/examples/models/gemma4_31b/quant/pack_cuda.py b/examples/models/gemma4_31b/quant/pack_cuda.py index a3a6257ae0c..21949024fcc 100644 --- a/examples/models/gemma4_31b/quant/pack_cuda.py +++ b/examples/models/gemma4_31b/quant/pack_cuda.py @@ -4,35 +4,38 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""CUDA packer: canonical → CUDA runtime format. +"""CUDA packer: torchao quantized tensors → CUDA runtime format. -Provides per-module packers for the CUDA backend (INT4 via tinygemm, -INT8 via ``IntxUnpackedToInt8Tensor``) and ``load_and_pack_for_cuda`` -as a convenience I/O wrapper. +Converts ``Int4Tensor`` to ``Int4TilePackedTo4dTensor`` (tinygemm) and +passes ``IntxUnpackedToInt8Tensor`` through unchanged (AOTI fuses +the dequantize-matmul pattern). The backend-agnostic ``pack_model`` dispatcher lives in ``pack.py``. """ +import json + import torch import torch.nn as nn from .pack import ModulePackerFn, pack_model # noqa: F401 -from .serialize import CanonicalQuantizedWeight # --------------------------------------------------------------------------- -# Low-level: canonical → Int4TilePackedTo4dTensor (one weight at a time) +# Low-level converters def pack_int4_for_cuda( - cw: CanonicalQuantizedWeight, + weight: torch.Tensor, device: str = "cuda", ) -> nn.Parameter: - """Convert a canonical 4-bit weight to ``Int4TilePackedTo4dTensor``. + """Convert an ``Int4Tensor`` to ``Int4TilePackedTo4dTensor`` for tinygemm. + + Unpacks nibbles, pads to tinygemm alignment, tile-packs via CUDA kernel, + and builds the combined scale_and_zero tensor. - Pads K to a multiple of 1024 and N to a multiple of 8 (tinygemm - requirements), nibble-packs, then tile-packs via the CUDA kernel. - Returns an ``nn.Parameter`` wrapping the subclass tensor **on CUDA**. + TODO: replace with ``Int4TilePackedTo4dTensor.from_int4_tensor()`` once + that's upstreamed to torchao. """ from torchao.quantization.quantize_.workflows.int4.int4_tile_packed_to_4d_tensor import ( Int4TilePackedTo4dTensor, @@ -40,57 +43,43 @@ def pack_int4_for_cuda( from torchao.quantization.utils import pack_tinygemm_scales_and_zeros from torchao.utils import find_multiple - assert cw.config.bits == 4, f"Expected 4-bit, got {cw.config.bits}" - assert cw.qdata.ndim == 2, ( - f"pack_int4_for_cuda requires 2D weight (nn.Linear), got {cw.qdata.ndim}D " - f"shape {tuple(cw.qdata.shape)}." - ) - - original_shape = cw.qdata.shape + original_shape = weight.shape N, K = original_shape - gs = cw.config.group_size + gs = weight.block_size[-1] inner_k_tiles = 8 + # Unpack Int4Tensor nibbles to int32 + p = weight.qdata.to(torch.uint8) + low = (p & 0x0F).to(torch.int32) + high = ((p >> 4) & 0x0F).to(torch.int32) + int_data = torch.stack([low, high], dim=-1).reshape(N, K) + + # Scale/zero: Int4Tensor stores (K//gs, N), transpose to (N, K//gs) + scale = weight.scale.t().contiguous() + zero = weight.zero_point.t().contiguous() + + # Pad to tinygemm alignment K_padded = find_multiple(K, 1024) N_padded = find_multiple(N, 8) - - int_data = cw.qdata.to(torch.int32) if K_padded != K or N_padded != N: int_data = torch.nn.functional.pad(int_data, (0, K_padded - K, 0, N_padded - N)) - - scale = cw.scale - n_groups_orig = K // gs - n_groups_padded = K_padded // gs - if n_groups_padded != n_groups_orig or N_padded != N: + n_groups_padded = K_padded // gs + n_groups_orig = K // gs scale = torch.nn.functional.pad( scale, (0, n_groups_padded - n_groups_orig, 0, N_padded - N) ) - - if cw.zero is not None: - zero = cw.zero - if n_groups_padded != n_groups_orig or N_padded != N: - zero = torch.nn.functional.pad( - zero, (0, n_groups_padded - n_groups_orig, 0, N_padded - N) - ) - else: - # Symmetric: qdata is unsigned [0, 15] (shifted +8 from signed [-8, 7]). - # Standard convention: weight = (q - zp_std) * scale, so zp_std = 8. - zero = torch.full_like(scale, 8.0) + zero = torch.nn.functional.pad( + zero, (0, n_groups_padded - n_groups_orig, 0, N_padded - N) + ) int_data = int_data.to(device) scale = scale.to(device) zero = zero.to(device) - # Convert zero from standard convention (weight = (q - zp_std) * scale) - # to tinygemm convention (weight = (q - 8) * scale + zp_tg). - # Derivation: (q - zp_std) * scale = (q - 8) * scale + zp_tg - # → zp_tg = (8 - zp_std) * scale - tinygemm_zero = (8 - zero.to(torch.float32)) * scale.to(torch.float32) + # Convert zero-point convention: tinygemm uses zp_tg = (8 - zp_std) * scale + tinygemm_zero = (8 - zero.float()) * scale.float() - # Tinygemm nibble convention: even index in HIGH nibble, odd in LOW. - # (This differs from serialize.py's _nibble_pack which uses the opposite - # convention for on-disk storage — both are valid, they serve different - # consumers.) + # Tinygemm nibble convention: even=HIGH, odd=LOW int_data_u8 = (int_data[:, ::2] << 4 | int_data[:, 1::2]).to(torch.uint8) packed_weight = torch.ops.aten._convert_weight_to_int4pack( int_data_u8.contiguous(), inner_k_tiles @@ -113,78 +102,37 @@ def pack_int4_for_cuda( # Per-module packers -def pack_int8_for_cuda( - cw: CanonicalQuantizedWeight, -) -> nn.Parameter: - """Convert a canonical 8-bit weight to ``IntxUnpackedToInt8Tensor``. - - Unlike INT4 (which needs tinygemm tile packing), INT8 weights are stored - unpacked. The subclass carries int8 qdata + scales and dequantizes during - matmul — AOTI fuses the ``dequantize → mm`` pattern in the compiled graph. - """ +def pack_linear_for_cuda(module: nn.Module, weights: dict[str, torch.Tensor]) -> None: + """Pack a quantized ``nn.Linear`` for CUDA.""" from torchao.quantization import IntxUnpackedToInt8Tensor + from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor - assert cw.config.bits == 8, f"Expected 8-bit, got {cw.config.bits}" - assert cw.qdata.ndim == 2, f"Expected 2D weight, got {cw.qdata.ndim}D" - - N, K = cw.qdata.shape - n_groups = K // cw.config.group_size - scale = cw.scale.to(torch.bfloat16).reshape(N, n_groups) - zero_point = ( - cw.zero.to(torch.int8).reshape(N, n_groups) - if cw.zero is not None - else torch.zeros(N, n_groups, dtype=torch.int8) - ) - - subclass = IntxUnpackedToInt8Tensor( - qdata=cw.qdata, - scale=scale, - zero_point=zero_point, - target_dtype=torch.int8, - block_size=(1, cw.config.group_size), - dtype=torch.bfloat16, - activation_quantization=None, - ) - return nn.Parameter(subclass, requires_grad=False) - - -def pack_linear_for_cuda( - module: nn.Module, weights: dict[str, CanonicalQuantizedWeight] -) -> None: - """Pack a quantized ``nn.Linear`` for CUDA. - - 4-bit weights use ``Int4TilePackedTo4dTensor`` (tinygemm kernel, requires - CUDA for packing). 8-bit weights use ``IntxUnpackedToInt8Tensor`` (AOTI - fuses the dequantize-matmul pattern). Both stay as tensor subclasses so - the export graph captures quantized ops. - """ - cw = weights["weight"] - if cw.config.bits == 4: - packed = pack_int4_for_cuda(cw, device="cuda") + w = weights["weight"] + if isinstance(w, Int4Tensor): + # Pack on CUDA (required by _convert_weight_to_int4pack), move back + # to CPU for assembly. The model moves to CUDA later at runtime. + packed = pack_int4_for_cuda(w, device="cuda") module.weight = nn.Parameter(packed.data.to("cpu"), requires_grad=False) torch.cuda.empty_cache() - elif cw.config.bits == 8: - module.weight = pack_int8_for_cuda(cw) + elif isinstance(w, IntxUnpackedToInt8Tensor): + module.weight = nn.Parameter(w, requires_grad=False) else: - raise ValueError(f"Unsupported bit width: {cw.config.bits}") + raise ValueError(f"Unsupported weight type: {type(w).__name__}") def pack_embedding_for_cuda( - module: nn.Module, weights: dict[str, CanonicalQuantizedWeight] + module: nn.Module, weights: dict[str, torch.Tensor] ) -> None: - """Pack a quantized ``nn.Embedding`` for CUDA. + """Pack a quantized ``nn.Embedding`` for CUDA (INT8 only).""" + from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor - Uses ``IntxUnpackedToInt8Tensor`` which supports embedding gather. - Only INT8 is supported — ``Int4TilePackedTo4dTensor`` does not - implement the embedding op. - """ - cw = weights["weight"] - if cw.config.bits != 8: + w = weights["weight"] + if isinstance(w, Int4Tensor): raise ValueError( - f"Only 8-bit embedding quantization is supported on CUDA, " - f"got {cw.config.bits}-bit." + "Only 8-bit embedding quantization is supported on CUDA. " + "Int4TilePackedTo4dTensor does not implement the embedding op." ) - module.weight = pack_int8_for_cuda(cw) + module.weight = nn.Parameter(w, requires_grad=False) DEFAULT_CUDA_PACKERS: dict[type, ModulePackerFn] = { @@ -202,19 +150,34 @@ def load_and_pack_for_cuda( model: nn.Module, packers: dict[type, ModulePackerFn] | None = None, ) -> None: - """Stream weights from a quantized safetensors file and pack for CUDA. + """Load a quantized safetensors file and pack for CUDA.""" + from safetensors import safe_open + from torchao.prototype.safetensors.safetensors_support import ( + unflatten_tensor_state_dict, + ) - Uses ``iter_load`` to process one weight at a time, keeping peak - memory proportional to the largest single weight instead of loading - all weights into memory at once. - """ from .pack import pack_one - from .serialize import iter_load _packers = packers or DEFAULT_CUDA_PACKERS - - for fqn, value in iter_load(path): - pack_one(model, fqn, value, _packers) + with safe_open(path, framework="pt", device="cpu") as f: + metadata = f.metadata() + all_keys = list(f.keys()) + tensor_names = json.loads(metadata.get("tensor_names", "[]")) + + # Stream one logical weight at a time: load its inner tensors, + # reconstruct the subclass, pack, then release before the next. + loaded_keys: set[str] = set() + for name in tensor_names: + module_fqn, weight_name = name.rsplit(".", 1) + prefix = f"{module_fqn}._{weight_name}_" + partial = {} + for key in all_keys: + if key.startswith(prefix) or key == name: + partial[key] = f.get_tensor(key) + loaded_keys.add(key) + result, _ = unflatten_tensor_state_dict(partial, metadata) + for fqn, value in result.items(): + pack_one(model, fqn, value, _packers) for fqn, p in model.named_parameters(): if p.device.type == "meta": diff --git a/examples/models/gemma4_31b/quant/quantize.py b/examples/models/gemma4_31b/quant/quantize.py index 247ccc3bc24..4e4b993d496 100644 --- a/examples/models/gemma4_31b/quant/quantize.py +++ b/examples/models/gemma4_31b/quant/quantize.py @@ -4,24 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""Quantize weights to canonical form. +"""Quantize weights to torchao tensor subclasses. ``quantize_weight`` quantizes a single tensor given a ``QuantConfig``, -dispatching to the appropriate algorithm based on ``config.method``: - - - ``"min_max"``: standard symmetric/asymmetric quantization via torchao's - ``choose_qparams_affine`` + ``quantize_affine``. Runs on CPU or CUDA. - - ``"hqq"``: Half-Quadratic Quantization — iteratively refines scales via - a proximal solver for better accuracy. ``symmetric=False`` optimizes both - scale and zero (requires CUDA). ``symmetric=True`` optimizes scale only - (CPU or CUDA). +returning an ``Int4Tensor`` (4-bit) or ``IntxUnpackedToInt8Tensor`` (8-bit). ``quantize_model`` walks a model's parameters, applies a ``QuantRecipe``, -and returns two dicts: quantized weights as ``CanonicalQuantizedWeight`` -and unquantized weights as plain tensors. - -Both are model-agnostic — they work for any ``nn.Module`` and any weight -shape (2D linears, 3D fused-expert stacks, etc.). +and returns a single state dict containing both quantized subclass tensors +and unquantized plain tensors. """ import torch @@ -29,8 +19,6 @@ from .recipe import QuantConfig, QuantRecipe -from .serialize import CanonicalQuantizedWeight - # --------------------------------------------------------------------------- # Per-weight quantization @@ -40,19 +28,14 @@ def _quantize_min_max( weight: torch.Tensor, config: QuantConfig, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Standard min/max quantization. Returns (int_data, scale, zero_point).""" + """Standard min/max 4-bit quantization. Returns (int_data, scale, zero_point).""" from torchao.quantization.quant_primitives import ( choose_qparams_affine, MappingType, quantize_affine, ) - if config.bits == 4: - qmin, qmax = (-8, 7) if config.symmetric else (0, 15) - elif config.bits == 8: - qmin, qmax = -128, 127 - else: - raise ValueError(f"Unsupported bits={config.bits}") + qmin, qmax = (-8, 7) if config.symmetric else (0, 15) mapping = MappingType.SYMMETRIC if config.symmetric else MappingType.ASYMMETRIC block_size = tuple([1] * (weight.ndim - 1) + [config.group_size]) @@ -83,10 +66,7 @@ def _quantize_hqq_asymmetric( weight: torch.Tensor, config: QuantConfig, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Full HQQ (asymmetric, optimizes scale + zero). Requires CUDA. - - Returns (int_data, scale, zero_point) in canonical layout. - """ + """Full HQQ (asymmetric, optimizes scale + zero). Requires CUDA.""" from torchao.quantization.quant_primitives import ( _choose_qparams_and_quantize_affine_hqq, ) @@ -116,22 +96,13 @@ def _quantize_hqq_symmetric( weight: torch.Tensor, config: QuantConfig, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Scale-only HQQ (symmetric, optimizes scale only). Runs on CPU or CUDA. - - Returns (int_data, scale, zero_point) where zero_point is all zeros. - """ + """Scale-only HQQ (symmetric 4-bit, optimizes scale only). Runs on CPU or CUDA.""" from torchao.quantization.quant_primitives import ( _choose_qparams_and_quantize_scale_only_hqq, ) - if config.bits == 4: - qmin, qmax = -8, 7 - elif config.bits == 8: - qmin, qmax = -128, 127 - else: - raise ValueError(f"Unsupported bits={config.bits}") + qmin, qmax = -8, 7 - # scale_only_hqq requires 2D. For 3D+, flatten → quantize → reshape. orig_shape = weight.shape weight_2d = weight.reshape(-1, weight.shape[-1]) if weight.ndim > 2 else weight @@ -149,16 +120,116 @@ def _quantize_hqq_symmetric( return int_data, scale, zero_point -def quantize_weight( +def _to_int4_tensor( + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + config: QuantConfig, +) -> torch.Tensor: + """Wrap quantized 4-bit data into an Int4Tensor.""" + from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + + # Normalize 4-bit signed [-8, 7] to unsigned [0, 15] for storage. + if config.symmetric: + int_data = int_data + 8 + zero_point = torch.full_like(scale, 8.0) + + # Int4Tensor stores qdata as nibble-packed uint8 (N, K//2) + q = int_data.to(torch.uint8) + packed = q[..., ::2] | (q[..., 1::2] << 4) + + # Int4Tensor stores scale/zero as (K//gs, N) — transposed from our (N, K//gs) + return Int4Tensor( + qdata=packed, + scale=scale.t().contiguous(), + zero_point=zero_point.t().contiguous(), + block_size=[1, config.group_size], + shape=torch.Size(int_data.shape), + ) + + +def _to_intx_tensor( weight: torch.Tensor, config: QuantConfig, -) -> CanonicalQuantizedWeight: - """Quantize ``weight`` to canonical form. +) -> torch.Tensor: + """Quantize 8-bit and wrap in IntxUnpackedToInt8Tensor. - Dispatches to the algorithm specified by ``config.method``. The input is - processed in float32 internally for numerical stability. Does NOT pad or - pack for any backend. + Quantizes in float32 for numerical precision, then constructs the + subclass directly. We avoid ``from_hp`` because it quantizes in the + input dtype (bf16), which loses precision for small-magnitude weights. """ + from torchao.quantization import IntxUnpackedToInt8Tensor + from torchao.quantization.quant_primitives import ( + choose_qparams_affine, + MappingType, + quantize_affine, + ) + + if config.method == "hqq": + if not config.symmetric: + raise ValueError( + "8-bit HQQ only supports symmetric quantization " + "(HQQ_SCALE_ONLY). Use method='min_max' for asymmetric 8-bit." + ) + from torchao.quantization.quant_primitives import ( + _choose_qparams_and_quantize_scale_only_hqq, + ) + + w2d = weight.float().reshape(-1, weight.shape[-1]) + qdata, scale = _choose_qparams_and_quantize_scale_only_hqq( + w2d, [1, config.group_size], -128, 127 + ) + qdata = qdata.to(torch.int8).reshape(weight.shape) + scale = scale.to(torch.bfloat16).reshape(weight.shape[0], -1) + zero_point = torch.zeros_like(scale, dtype=torch.int8) + else: + mapping = MappingType.SYMMETRIC if config.symmetric else MappingType.ASYMMETRIC + block_size = (1, config.group_size) + scale, zero_point = choose_qparams_affine( + weight.float(), + mapping, + block_size, + target_dtype=torch.int8, + quant_min=-128, + quant_max=127, + scale_dtype=torch.bfloat16, + zero_point_dtype=torch.int8, + ) + qdata = quantize_affine( + weight.float(), + block_size, + scale, + zero_point, + output_dtype=torch.int8, + quant_min=-128, + quant_max=127, + ) + N, n_groups = weight.shape[0], weight.shape[-1] // config.group_size + scale = scale.reshape(N, n_groups) + zero_point = zero_point.reshape(N, n_groups) + + return IntxUnpackedToInt8Tensor( + qdata=qdata, + scale=scale, + zero_point=zero_point, + target_dtype=torch.int8, + block_size=(1, config.group_size), + dtype=torch.bfloat16, + activation_quantization=None, + ) + + +def quantize_weight(weight: torch.Tensor, config: QuantConfig) -> torch.Tensor: + """Quantize ``weight`` to a torchao tensor subclass. + + Returns ``Int4Tensor`` for 4-bit or ``IntxUnpackedToInt8Tensor`` for 8-bit. + """ + if config.bits == 8: + return _to_intx_tensor(weight, config) + + if config.bits != 4: + raise ValueError(f"Unsupported bits={config.bits}") + if config.method == "min_max": int_data, scale, zero_point = _quantize_min_max(weight, config) elif config.method == "hqq": @@ -172,41 +243,36 @@ def quantize_weight( f"Supported: 'min_max', 'hqq'." ) - # Normalize 4-bit to unsigned [0, 15] for uniform storage and nibble - # packing. Symmetric min_max produces [-8, 7]; shift to [0, 15]. - # HQQ already produces [0, 15] (asymmetric internally). - if config.bits == 4 and config.symmetric: - int_data = int_data + 8 - - return CanonicalQuantizedWeight( - qdata=int_data.to(torch.int8), - scale=scale.to(torch.bfloat16), - zero=zero_point.to(torch.bfloat16) if not config.symmetric else None, - config=config, - ) + return _to_int4_tensor(int_data, scale, zero_point, config) def dequantize_weight( - cw: CanonicalQuantizedWeight, + weight: torch.Tensor, dtype: torch.dtype = torch.float32, ) -> torch.Tensor: - """Dequantize a ``CanonicalQuantizedWeight`` back to a dense tensor. - - Inverse of ``quantize_weight``. Useful for embedding lookups (which - need dense weights) or for inspecting quantized values. - """ - gs = cw.config.group_size - scale = cw.scale.float().repeat_interleave(gs, dim=-1) - qdata = cw.qdata.float() - # Symmetric 4-bit qdata is stored as unsigned [0, 15] (shifted +8 in - # quantize_weight). Undo the shift to recover signed [-8, 7] before - # scaling. (Q4_K is asymmetric and uses a zero field instead.) - if cw.config.bits == 4 and cw.zero is None: - qdata = qdata - 8 - if cw.zero is not None: - zero = cw.zero.float().repeat_interleave(gs, dim=-1) + """Dequantize a torchao quantized tensor back to a dense tensor.""" + from torchao.quantization import IntxUnpackedToInt8Tensor + from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + + if isinstance(weight, Int4Tensor): + # Unpack nibbles + p = weight.qdata.to(torch.uint8) + low = (p & 0x0F).float() + high = ((p >> 4) & 0x0F).float() + qdata = torch.stack([low, high], dim=-1).reshape(weight.shape) + # Scale is (K//gs, N), transpose to (N, K//gs) for broadcast + gs = weight.block_size[-1] + scale = weight.scale.t().float().repeat_interleave(gs, dim=-1) + zero = weight.zero_point.t().float().repeat_interleave(gs, dim=-1) return ((qdata - zero) * scale).to(dtype) - return (qdata * scale).to(dtype) + + if isinstance(weight, IntxUnpackedToInt8Tensor): + gs = weight.block_size[-1] + scale = weight.scale.float().repeat_interleave(gs, dim=-1) + zero = weight.zero_point.float().repeat_interleave(gs, dim=-1) + return ((weight.qdata.float() - zero) * scale).to(dtype) + + raise TypeError(f"Cannot dequantize {type(weight).__name__}") # --------------------------------------------------------------------------- @@ -217,32 +283,28 @@ def quantize_model( model: nn.Module, recipe: QuantRecipe, dtype: torch.dtype = torch.bfloat16, -) -> tuple[dict[str, CanonicalQuantizedWeight], dict[str, torch.Tensor]]: +) -> dict[str, torch.Tensor]: """Walk model parameters + persistent buffers, apply recipe. - For each parameter matched by a recipe rule: quantize to canonical. - Parameters that match ``None`` (skip) rules and persistent buffers go - into the unquantized dict (cast to ``dtype``). Non-persistent buffers - (KV cache, RoPE tables, etc.) are excluded. - - Returns ``(quantized, unquantized)`` dicts keyed by FQN. + Returns a single state dict containing quantized tensor subclasses + (``Int4Tensor``, ``IntxUnpackedToInt8Tensor``) and unquantized plain + tensors. Non-persistent buffers (KV cache, RoPE tables) are excluded. """ - quantized: dict[str, CanonicalQuantizedWeight] = {} - unquantized: dict[str, torch.Tensor] = {} + state: dict[str, torch.Tensor] = {} persistent_keys = set(model.state_dict().keys()) n_params = sum(1 for _ in model.named_parameters()) for i, (fqn, param) in enumerate(model.named_parameters()): config = recipe.get_config(fqn) if config is None: - unquantized[fqn] = param.data.to(dtype) + state[fqn] = param.data.to(dtype) else: - quantized[fqn] = quantize_weight(param.data, config) + state[fqn] = quantize_weight(param.data, config) print(f" Quantized {i + 1}/{n_params}: {fqn}", end="\r") print() for fqn, buf in model.named_buffers(): - if fqn in persistent_keys and fqn not in quantized: - unquantized[fqn] = buf.data + if fqn in persistent_keys and fqn not in state: + state[fqn] = buf.data - return quantized, unquantized + return state diff --git a/examples/models/gemma4_31b/quant/recipe.py b/examples/models/gemma4_31b/quant/recipe.py index e207e268c38..9ffafeafc5f 100644 --- a/examples/models/gemma4_31b/quant/recipe.py +++ b/examples/models/gemma4_31b/quant/recipe.py @@ -18,12 +18,16 @@ @dataclass(frozen=True) class QuantConfig: - """Per-weight quantization parameters.""" + """Per-weight quantization parameters (quantization-time only). - bits: int # storage width: 4 or 8 (6-bit formats like Q6_K are widened to 8) + Not stored in the serialized checkpoint — torchao tensor subclasses + carry their own metadata. This is purely for driving ``quantize_weight``. + """ + + bits: int # 4 or 8 group_size: int # 32, 64, 128 symmetric: bool # True = no zero point - method: str # "min_max" | "hqq" | "gguf_q4_k" | "gguf_q6_k" + method: str # "min_max" | "hqq" @dataclass diff --git a/examples/models/gemma4_31b/quant/serialize.py b/examples/models/gemma4_31b/quant/serialize.py deleted file mode 100644 index 35eae615fba..00000000000 --- a/examples/models/gemma4_31b/quant/serialize.py +++ /dev/null @@ -1,281 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -"""Serialize and persist quantized weights. - -Two layers: - - - **serialize / deserialize** — convert between ``CanonicalQuantizedWeight`` - objects and plain tensors + JSON metadata. Pure logic, no I/O. The output - is a ``(tensors_dict, metadata_dict)`` pair that any file writer can - consume. - - **save / load** — write/read the serialized form to/from safetensors on - disk. Thin I/O wrappers around ``safetensors.save_file`` / - ``safetensors.safe_open``. - -For 4-bit weights, qdata is nibble-packed (two values per byte) during -serialization to keep file size at ~0.5 bytes/param. -""" - -import json -from dataclasses import dataclass -from typing import Iterator, Optional - -import torch -from safetensors import safe_open -from safetensors.torch import save_file - -from .recipe import QuantConfig - -# Bump when the on-disk layout changes in a backward-incompatible way -# (e.g., different nibble-pack convention, renamed keys, new required fields). -# The loader rejects files with an unsupported version rather than silently -# producing corrupt data. -FORMAT_VERSION = "1" -_SUPPORTED_VERSIONS = {FORMAT_VERSION} - - -@dataclass -class CanonicalQuantizedWeight: - """Packing-free quantized weight representation. - - ``qdata`` int8 values: [0, 15] for 4-bit (both symmetric and asymmetric - are stored as unsigned after shifting), [-128, 127] for 8-bit. - ``scale`` bf16 per-group scales, shape ``[*weight_shape[:-1], K // group_size]``. - ``zero`` bf16 per-group zero points (``None`` when symmetric). - ``config`` the ``QuantConfig`` that produced this. - """ - - qdata: torch.Tensor - scale: torch.Tensor - zero: Optional[torch.Tensor] - config: QuantConfig - - def __post_init__(self): - if self.qdata.dtype != torch.int8: - raise ValueError(f"qdata must be int8, got {self.qdata.dtype}") - K = self.qdata.shape[-1] - if K % self.config.group_size != 0: - raise ValueError( - f"Last dim ({K}) must be divisible by group_size ({self.config.group_size})" - ) - n_groups = K // self.config.group_size - expected_numel = self.qdata[..., 0:1].numel() * n_groups - if self.scale.numel() != expected_numel: - raise ValueError( - f"scale has {self.scale.numel()} elements, expected {expected_numel} " - f"(from qdata {tuple(self.qdata.shape)}, group_size={self.config.group_size})" - ) - if self.config.symmetric and self.zero is not None: - raise ValueError("symmetric config must have zero=None") - if not self.config.symmetric and self.zero is None: - raise ValueError("asymmetric config must have zero (not None)") - - -# --------------------------------------------------------------------------- -# Nibble packing for 4-bit on-disk storage. -# -# Two 4-bit values are packed into one byte to halve file size. The -# convention is: even-indexed values go into the LOW nibble (bits 0-3), -# odd-indexed values go into the HIGH nibble (bits 4-7). -# -# values: [v0, v1, v2, v3, ...] (each in [0, 15]) -# packed: [v0 | (v1 << 4), v2 | (v3 << 4), ...] -# byte 0: bits 0-3 = v0, bits 4-7 = v1 -# -# To unpack: low = byte & 0x0F, high = (byte >> 4) & 0x0F. -# -# This matches the Triton fused_moe kernel's unpack convention -# ((byte >> (k%2)*4) & 0xF) and Qwen's _quantize_experts_int4 packing. -# Note: tinygemm uses the OPPOSITE convention (even=HIGH, odd=LOW) — the -# CUDA packer in pack_cuda.py handles that conversion separately. - - -def _nibble_pack(qdata: torch.Tensor) -> torch.Tensor: - """Pack int8 values (each in [0, 15]) into half the last dim. - - Even-indexed values → low nibble, odd-indexed → high nibble. - """ - assert qdata.shape[-1] % 2 == 0, f"Last dim must be even, got {qdata.shape}" - low = qdata[..., ::2].to(torch.uint8) - high = qdata[..., 1::2].to(torch.uint8) - return (low | (high << 4)).to(torch.int8).contiguous() - - -def _nibble_unpack(packed: torch.Tensor, orig_last_dim: int) -> torch.Tensor: - """Unpack nibble-packed int8 → original last dim. - - Low nibble (bits 0-3) → even indices, high nibble (bits 4-7) → odd indices. - """ - p = packed.to(torch.uint8) - low = (p & 0x0F).to(torch.int8) - high = ((p >> 4) & 0x0F).to(torch.int8) - return torch.stack([low, high], dim=-1).reshape(*packed.shape[:-1], orig_last_dim) - - -# --------------------------------------------------------------------------- -# Serialize / deserialize (pure logic, no I/O) - - -def serialize( - quantized: dict[str, CanonicalQuantizedWeight], - unquantized: dict[str, torch.Tensor], -) -> tuple[dict[str, torch.Tensor], dict[str, str]]: - """Convert quantized + unquantized weights to plain tensors + metadata. - - Returns ``(tensors, header)`` ready for any file writer. Quantized - weights become ``{fqn}.qdata``, ``{fqn}.scale``, optionally - ``{fqn}.zero``. For 4-bit, qdata is nibble-packed. - """ - tensors: dict[str, torch.Tensor] = {} - quant_meta: dict[str, dict] = {} - - for fqn, cw in quantized.items(): - qdata = cw.qdata - if cw.config.bits == 4: - qdata = _nibble_pack(qdata) - tensors[f"{fqn}.qdata"] = qdata.contiguous() - tensors[f"{fqn}.scale"] = cw.scale.contiguous() - if cw.zero is not None: - tensors[f"{fqn}.zero"] = cw.zero.contiguous() - quant_meta[fqn] = { - "bits": cw.config.bits, - "group_size": cw.config.group_size, - "symmetric": cw.config.symmetric, - "method": cw.config.method, - "shape": list(cw.qdata.shape), - } - - for fqn, tensor in unquantized.items(): - tensors[fqn] = tensor.contiguous() - - header = {"format_version": FORMAT_VERSION} - if quant_meta: - header["quant"] = json.dumps(quant_meta) - - return tensors, header - - -def deserialize( - tensors: dict[str, torch.Tensor], - header: dict[str, str], -) -> tuple[dict[str, CanonicalQuantizedWeight], dict[str, torch.Tensor]]: - """Reconstruct quantized + unquantized weights from plain tensors + metadata. - - Inverse of ``serialize``. Returns ``(quantized, unquantized)`` dicts. - """ - version = header.get("format_version", "1") - if version not in _SUPPORTED_VERSIONS: - raise ValueError( - f"Unsupported format version {version!r}. " - f"This code supports {sorted(_SUPPORTED_VERSIONS)}. " - f"Update the quant package or re-quantize the model." - ) - - quant_meta = json.loads(header.get("quant", "{}")) - - quantized: dict[str, CanonicalQuantizedWeight] = {} - consumed_keys: set[str] = set() - - for fqn, meta in quant_meta.items(): - config = QuantConfig( - bits=meta["bits"], - group_size=meta["group_size"], - symmetric=meta["symmetric"], - method=meta["method"], - ) - qdata = tensors[f"{fqn}.qdata"] - consumed_keys.add(f"{fqn}.qdata") - - original_shape = meta["shape"] - if config.bits == 4: - qdata = _nibble_unpack(qdata, original_shape[-1]) - - scale = tensors[f"{fqn}.scale"] - consumed_keys.add(f"{fqn}.scale") - - zero = tensors.get(f"{fqn}.zero") - if zero is not None: - consumed_keys.add(f"{fqn}.zero") - - quantized[fqn] = CanonicalQuantizedWeight( - qdata=qdata, scale=scale, zero=zero, config=config - ) - - unquantized = {k: v for k, v in tensors.items() if k not in consumed_keys} - - return quantized, unquantized - - -# --------------------------------------------------------------------------- -# Save / load (I/O wrappers) - - -def save( - quantized: dict[str, CanonicalQuantizedWeight], - unquantized: dict[str, torch.Tensor], - path: str, -) -> int: - """Serialize and write to safetensors. Returns the number of tensors written.""" - tensors, header = serialize(quantized, unquantized) - save_file(tensors, path, metadata=header) - return len(tensors) - - -def load( - path: str, -) -> tuple[dict[str, CanonicalQuantizedWeight], dict[str, torch.Tensor]]: - """Read safetensors and deserialize. Returns ``(quantized, unquantized)``.""" - with safe_open(path, framework="pt", device="cpu") as f: - header = f.metadata() - tensors = {k: f.get_tensor(k) for k in f.keys()} - return deserialize(tensors, header) - - -def iter_load( - path: str, -) -> Iterator[tuple[str, CanonicalQuantizedWeight | torch.Tensor]]: - """Stream weights from a safetensors file one at a time. - - Yields ``(fqn, value)`` where *value* is a ``CanonicalQuantizedWeight`` - for quantized weights or a plain ``torch.Tensor`` for unquantized ones. - Only one weight's tensors are resident in memory at a time, keeping peak - memory proportional to the largest single weight. - """ - with safe_open(path, framework="pt", device="cpu") as f: - header = f.metadata() - quant_meta = json.loads(header.get("quant", "{}")) - all_keys = set(f.keys()) - consumed: set[str] = set() - - for fqn, meta in quant_meta.items(): - config = QuantConfig( - bits=meta["bits"], - group_size=meta["group_size"], - symmetric=meta["symmetric"], - method=meta["method"], - ) - qdata = f.get_tensor(f"{fqn}.qdata") - consumed.add(f"{fqn}.qdata") - if config.bits == 4: - qdata = _nibble_unpack(qdata, meta["shape"][-1]) - - scale = f.get_tensor(f"{fqn}.scale") - consumed.add(f"{fqn}.scale") - - zero_key = f"{fqn}.zero" - zero = None - if zero_key in all_keys: - zero = f.get_tensor(zero_key) - consumed.add(zero_key) - - yield fqn, CanonicalQuantizedWeight( - qdata=qdata, scale=scale, zero=zero, config=config - ) - - for key in all_keys: - if key not in consumed: - yield key, f.get_tensor(key) diff --git a/examples/models/gemma4_31b/quant/tests/test_gguf.py b/examples/models/gemma4_31b/quant/tests/test_gguf.py index 80e6af5001a..89a7099d6f0 100644 --- a/examples/models/gemma4_31b/quant/tests/test_gguf.py +++ b/examples/models/gemma4_31b/quant/tests/test_gguf.py @@ -6,12 +6,13 @@ """Unit tests for quant/gguf.py — Q4_K and Q6_K unpacking. -Tests verify the API contract: dequantized canonical weights match the -original GGUF dequantization formula. Uses synthetic blocks — no GGUF -file required. +Tests verify the API contract: dequantized weights match the original +GGUF dequantization formula. Uses synthetic blocks — no GGUF file required. """ +import os import struct +import tempfile import unittest import numpy as np @@ -28,7 +29,12 @@ from executorch.examples.models.gemma4_31b.quant.gguf import unpack_gguf_tensor from executorch.examples.models.gemma4_31b.quant.quantize import dequantize_weight -from executorch.examples.models.gemma4_31b.quant.serialize import deserialize, serialize +from safetensors import safe_open +from safetensors.torch import save_file +from torchao.prototype.safetensors.safetensors_support import ( + flatten_tensor_state_dict, + unflatten_tensor_state_dict, +) def _make_q4_k_block(d, dmin, sub_scales, sub_mins, qvals): @@ -173,8 +179,8 @@ def test_dequant_matches_reference(self): @unittest.skipUnless(_HAS_GGUF, "gguf package not installed") class TestGgufSerializeRoundtrip(unittest.TestCase): - def test_q4_k_survives_serialize_roundtrip(self): - """unpack → serialize → deserialize → dequant matches original.""" + def test_q4_k_survives_save_load_roundtrip(self): + """unpack → save → load → dequant matches original.""" d, dmin = 0.5, 0.25 sub_scales = [3, 7, 1, 15, 20, 10, 31, 5] sub_mins = [1, 2, 0, 4, 8, 3, 12, 6] @@ -182,34 +188,46 @@ def test_q4_k_survives_serialize_roundtrip(self): block = _make_q4_k_block(d, dmin, sub_scales, sub_mins, qvals) data = np.frombuffer(bytes(block), dtype=np.uint8).reshape(1, 144) - cw = unpack_gguf_tensor(data, GGMLQuantizationType.Q4_K, [1, 256]) + q = unpack_gguf_tensor(data, GGMLQuantizationType.Q4_K, [1, 256]) - dequant_before = dequantize_weight(cw) + dequant_before = dequantize_weight(q) - tensors, header = serialize({"w": cw}, {}) - q_loaded, _ = deserialize(tensors, header) - dequant_after = dequantize_weight(q_loaded["w"]) + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "m.safetensors") + td, md = flatten_tensor_state_dict({"layer.weight": q}) + save_file(td, path, metadata=md) + with safe_open(path, framework="pt", device="cpu") as sf: + loaded_meta = sf.metadata() + loaded_tensors = {k: sf.get_tensor(k) for k in sf.keys()} + loaded, _ = unflatten_tensor_state_dict(loaded_tensors, loaded_meta) + dequant_after = dequantize_weight(loaded["layer.weight"]) self.assertTrue( torch.allclose(dequant_before, dequant_after, atol=0.01), f"Max diff: {(dequant_before - dequant_after).abs().max():.6f}", ) - def test_q6_k_survives_serialize_roundtrip(self): - """unpack → serialize → deserialize → dequant matches original.""" + def test_q6_k_survives_save_load_roundtrip(self): + """unpack → save → load → dequant matches original.""" d = 0.5 scales_16 = [i + 1 for i in range(16)] qvals = [(i * 3 + 5) % 64 for i in range(256)] block = _make_q6_k_block(d, scales_16, qvals) data = np.frombuffer(bytes(block), dtype=np.uint8).reshape(1, 210) - cw = unpack_gguf_tensor(data, GGMLQuantizationType.Q6_K, [1, 256]) + q = unpack_gguf_tensor(data, GGMLQuantizationType.Q6_K, [1, 256]) - dequant_before = dequantize_weight(cw) + dequant_before = dequantize_weight(q) - tensors, header = serialize({"w": cw}, {}) - q_loaded, _ = deserialize(tensors, header) - dequant_after = dequantize_weight(q_loaded["w"]) + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "m.safetensors") + td, md = flatten_tensor_state_dict({"layer.weight": q}) + save_file(td, path, metadata=md) + with safe_open(path, framework="pt", device="cpu") as sf: + loaded_meta = sf.metadata() + loaded_tensors = {k: sf.get_tensor(k) for k in sf.keys()} + loaded, _ = unflatten_tensor_state_dict(loaded_tensors, loaded_meta) + dequant_after = dequantize_weight(loaded["layer.weight"]) self.assertTrue( torch.allclose(dequant_before, dequant_after, atol=0.01), @@ -221,6 +239,24 @@ def test_q6_k_survives_serialize_roundtrip(self): class TestUnpackGgufTensor(unittest.TestCase): """Tests for the public ``unpack_gguf_tensor`` API.""" + def test_q4_k_returns_int4_tensor(self): + from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + + block = _make_q4_k_block(0.5, 0.25, [1] * 8, [1] * 8, [7] * 256) + data = np.frombuffer(bytes(block), dtype=np.uint8).reshape(1, 144) + result = unpack_gguf_tensor(data, GGMLQuantizationType.Q4_K, [1, 256]) + self.assertIsInstance(result, Int4Tensor) + self.assertEqual(result.shape, torch.Size([1, 256])) + + def test_q6_k_returns_intx_tensor(self): + from torchao.quantization import IntxUnpackedToInt8Tensor + + block = _make_q6_k_block(0.5, list(range(1, 17)), [32] * 256) + data = np.frombuffer(bytes(block), dtype=np.uint8).reshape(1, 210) + result = unpack_gguf_tensor(data, GGMLQuantizationType.Q6_K, [1, 256]) + self.assertIsInstance(result, IntxUnpackedToInt8Tensor) + self.assertEqual(result.shape, torch.Size([1, 256])) + def test_f32_returns_tensor(self): data = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32) result = unpack_gguf_tensor(data, GGMLQuantizationType.F32, [4]) diff --git a/examples/models/gemma4_31b/quant/tests/test_pack_cuda.py b/examples/models/gemma4_31b/quant/tests/test_pack_cuda.py index bc429ee6372..89ebbcbab56 100644 --- a/examples/models/gemma4_31b/quant/tests/test_pack_cuda.py +++ b/examples/models/gemma4_31b/quant/tests/test_pack_cuda.py @@ -19,14 +19,13 @@ load_and_pack_for_cuda, pack_embedding_for_cuda, pack_int4_for_cuda, - pack_int8_for_cuda, pack_linear_for_cuda, pack_model, ) from executorch.examples.models.gemma4_31b.quant.quantize import quantize_weight from executorch.examples.models.gemma4_31b.quant.recipe import QuantConfig - -from executorch.examples.models.gemma4_31b.quant.serialize import save +from safetensors.torch import save_file +from torchao.prototype.safetensors.safetensors_support import flatten_tensor_state_dict class TestPackInt4ForCuda(unittest.TestCase): @@ -34,23 +33,10 @@ def setUp(self): if not torch.cuda.is_available(): self.skipTest("CUDA required") - def test_symmetric_works(self): - config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") - cw = quantize_weight(torch.randn(128, 256, dtype=torch.bfloat16), config) - self.assertEqual(pack_int4_for_cuda(cw).shape, torch.Size([128, 256])) - - def test_rejects_1d(self): + def test_basic(self): config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") - cw = quantize_weight(torch.randn(1, 128, dtype=torch.bfloat16), config) - cw.qdata = cw.qdata.squeeze(0) - with self.assertRaises(AssertionError): - pack_int4_for_cuda(cw) - - def test_rejects_8bit(self): - config = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") - cw = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) - with self.assertRaises(AssertionError): - pack_int4_for_cuda(cw) + q = quantize_weight(torch.randn(128, 256, dtype=torch.bfloat16), config) + self.assertEqual(pack_int4_for_cuda(q).shape, torch.Size([128, 256])) def test_different_group_sizes(self): for gs in (32, 64, 128): @@ -58,25 +44,18 @@ def test_different_group_sizes(self): config = QuantConfig( bits=4, group_size=gs, symmetric=False, method="min_max" ) - cw = quantize_weight( - torch.randn(128, 256, dtype=torch.bfloat16), config - ) - self.assertEqual(pack_int4_for_cuda(cw).shape, torch.Size([128, 256])) + q = quantize_weight(torch.randn(128, 256, dtype=torch.bfloat16), config) + self.assertEqual(pack_int4_for_cuda(q).shape, torch.Size([128, 256])) def test_matmul_approximates_original(self): - """Packed weight produces matmul output close to the original.""" torch.manual_seed(0) - # Use dimensions already aligned to tinygemm requirements - # (K multiple of 1024, N multiple of 8) to avoid padding effects. weight = torch.randn(256, 1024, dtype=torch.bfloat16) x = torch.randn(1, 1024, dtype=torch.bfloat16) - original_out = torch.nn.functional.linear(x.cuda(), weight.cuda()) config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") - cw = quantize_weight(weight, config) - packed = pack_int4_for_cuda(cw) - + q = quantize_weight(weight, config) + packed = pack_int4_for_cuda(q) packed_out = torch.nn.functional.linear(x.cuda(), packed.cuda()) rel_error = ( @@ -84,50 +63,15 @@ def test_matmul_approximates_original(self): ).abs().mean() / original_out.float().abs().mean() self.assertLess(rel_error.item(), 0.15) - def test_symmetric_matmul_approximates_original(self): - """Symmetric 4-bit (e.g. HQQ) packs correctly for tinygemm.""" + def test_symmetric_matmul(self): torch.manual_seed(0) weight = torch.randn(256, 1024, dtype=torch.bfloat16) x = torch.randn(1, 1024, dtype=torch.bfloat16) - original_out = torch.nn.functional.linear(x.cuda(), weight.cuda()) config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") - cw = quantize_weight(weight, config) - packed = pack_int4_for_cuda(cw) - - packed_out = torch.nn.functional.linear(x.cuda(), packed.cuda()) - - rel_error = ( - packed_out.float() - original_out.float() - ).abs().mean() / original_out.float().abs().mean() - self.assertLess(rel_error.item(), 0.15) - - def test_asymmetric_gguf_q4_k_matmul(self): - """Asymmetric 4-bit (GGUF Q4_K style) packs and produces correct matmul.""" - torch.manual_seed(0) - weight = torch.randn(256, 1024, dtype=torch.bfloat16) - x = torch.randn(1, 1024, dtype=torch.bfloat16) - - original_out = torch.nn.functional.linear(x.cuda(), weight.cuda()) - - config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") - cw = quantize_weight(weight, config) - # Mimic GGUF Q4_K: asymmetric with a non-standard method name - from executorch.examples.models.gemma4_31b.quant.serialize import ( - CanonicalQuantizedWeight, - ) - - cw_gguf = CanonicalQuantizedWeight( - qdata=cw.qdata, - scale=cw.scale, - zero=cw.zero, - config=QuantConfig( - bits=4, group_size=32, symmetric=False, method="gguf_q4_k" - ), - ) - packed = pack_int4_for_cuda(cw_gguf) - + q = quantize_weight(weight, config) + packed = pack_int4_for_cuda(q) packed_out = torch.nn.functional.linear(x.cuda(), packed.cuda()) rel_error = ( @@ -136,68 +80,40 @@ def test_asymmetric_gguf_q4_k_matmul(self): self.assertLess(rel_error.item(), 0.15) -class TestPackInt8ForCuda(unittest.TestCase): +class TestPackInt8OnCuda(unittest.TestCase): def setUp(self): if not torch.cuda.is_available(): self.skipTest("CUDA required") - def test_rejects_4bit(self): - config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") - cw = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) - with self.assertRaises(AssertionError): - pack_int8_for_cuda(cw) - def test_matmul_approximates_original(self): torch.manual_seed(0) weight = torch.randn(256, 128, dtype=torch.bfloat16) x = torch.randn(1, 128, dtype=torch.bfloat16) - original_out = torch.nn.functional.linear(x.cuda(), weight.cuda()) config = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") - cw = quantize_weight(weight, config) - packed = pack_int8_for_cuda(cw) - - packed_out = torch.nn.functional.linear(x.cuda(), packed.cuda()) - - rel_error = ( - packed_out.float() - original_out.float() - ).abs().mean() / original_out.float().abs().mean() - self.assertLess(rel_error.item(), 0.02) - - def test_asymmetric_matmul_approximates_original(self): - """8-bit asymmetric quantization packs and produces correct matmul.""" - torch.manual_seed(0) - weight = torch.randn(256, 128, dtype=torch.bfloat16) - x = torch.randn(1, 128, dtype=torch.bfloat16) - - original_out = torch.nn.functional.linear(x.cuda(), weight.cuda()) - - config = QuantConfig(bits=8, group_size=32, symmetric=False, method="min_max") - cw = quantize_weight(weight, config) - packed = pack_int8_for_cuda(cw) - - packed_out = torch.nn.functional.linear(x.cuda(), packed.cuda()) + q = quantize_weight(weight, config) + # IntxUnpackedToInt8Tensor is already the CUDA format + emb = nn.Linear(128, 256, bias=False) + emb.weight = nn.Parameter(q, requires_grad=False) + emb.to("cuda") + packed_out = emb(x.cuda()) rel_error = ( packed_out.float() - original_out.float() ).abs().mean() / original_out.float().abs().mean() self.assertLess(rel_error.item(), 0.02) - def test_per_axis_gather_approximates_original(self): - """Per-axis INT8 (group_size == K) works for embedding gather.""" + def test_per_axis_embedding_gather(self): torch.manual_seed(0) weight = torch.randn(1000, 64, dtype=torch.bfloat16) ids = torch.tensor([0, 1, 42, 500, 999]) - original = weight[ids] config = QuantConfig(bits=8, group_size=64, symmetric=True, method="min_max") - cw = quantize_weight(weight, config) - packed = pack_int8_for_cuda(cw) - + q = quantize_weight(weight, config) emb = nn.Embedding(1000, 64) - emb.weight = nn.Parameter(packed, requires_grad=False) + emb.weight = nn.Parameter(q, requires_grad=False) emb.to("cuda") packed_out = emb(ids.cuda()) @@ -212,19 +128,18 @@ def setUp(self): if not torch.cuda.is_available(): self.skipTest("CUDA required") - def test_4bit_modifies_module_in_place(self): - module = nn.Linear(128, 256, bias=False) + def test_4bit(self): config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") - cw = quantize_weight(torch.randn(256, 128, dtype=torch.bfloat16), config) - pack_linear_for_cuda(module, {"weight": cw}) - self.assertEqual(module.weight.device.type, "cpu") + q = quantize_weight(torch.randn(256, 128, dtype=torch.bfloat16), config) + module = nn.Linear(128, 256, bias=False) + pack_linear_for_cuda(module, {"weight": q}) self.assertEqual(module.weight.shape, torch.Size([256, 128])) - def test_8bit_modifies_module_in_place(self): - module = nn.Linear(128, 64, bias=False) + def test_8bit(self): config = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") - cw = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) - pack_linear_for_cuda(module, {"weight": cw}) + q = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) + module = nn.Linear(128, 64, bias=False) + pack_linear_for_cuda(module, {"weight": q}) self.assertEqual(module.weight.shape, torch.Size([64, 128])) @@ -233,19 +148,16 @@ def setUp(self): if not torch.cuda.is_available(): self.skipTest("CUDA required") - def test_gather_approximates_original(self): - """INT8 quantized embedding gather matches bf16 gather.""" + def test_int8_gather(self): torch.manual_seed(0) weight = torch.randn(1000, 64, dtype=torch.bfloat16) ids = torch.tensor([0, 1, 42, 500, 999]) - original = weight[ids] config = QuantConfig(bits=8, group_size=64, symmetric=True, method="min_max") - cw = quantize_weight(weight, config) - + q = quantize_weight(weight, config) module = nn.Embedding(1000, 64) - pack_embedding_for_cuda(module, {"weight": cw}) + pack_embedding_for_cuda(module, {"weight": q}) module.to("cuda") packed_out = module(ids.cuda()) @@ -256,43 +168,25 @@ def test_gather_approximates_original(self): def test_rejects_4bit(self): config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") - cw = quantize_weight(torch.randn(100, 64, dtype=torch.bfloat16), config) + q = quantize_weight(torch.randn(100, 64, dtype=torch.bfloat16), config) module = nn.Embedding(100, 64) with self.assertRaises(ValueError): - pack_embedding_for_cuda(module, {"weight": cw}) + pack_embedding_for_cuda(module, {"weight": q}) -class TestLoadAndPackForCuda(unittest.TestCase): +class TestPackModel(unittest.TestCase): def setUp(self): if not torch.cuda.is_available(): self.skipTest("CUDA required") - def test_pack_model_in_memory(self): - """pack_model works with in-memory dicts (no file I/O).""" - config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") - cw = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) - unq = {"norm.weight": torch.randn(64, dtype=torch.bfloat16)} - - with torch.device("meta"): - model = nn.ModuleDict( - { - "proj": nn.Linear(128, 64, bias=False), - "norm": nn.LayerNorm(64, bias=False), - } - ) - pack_model(model, {"proj.weight": cw}, unq, DEFAULT_CUDA_PACKERS) - - self.assertEqual(model.proj.weight.shape, torch.Size([64, 128])) - self.assertEqual(model.norm.weight.shape, torch.Size([64])) - - def test_pack_model_mixed_precision(self): + def test_mixed_precision(self): """pack_model handles 4-bit and 8-bit weights in the same model.""" q4_config = QuantConfig( bits=4, group_size=32, symmetric=False, method="min_max" ) q8_config = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") - cw4 = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), q4_config) - cw8 = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), q8_config) + q4 = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), q4_config) + q8 = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), q8_config) with torch.device("meta"): model = nn.ModuleDict( @@ -302,148 +196,35 @@ def test_pack_model_mixed_precision(self): } ) pack_model( - model, - {"q_proj.weight": cw4, "v_proj.weight": cw8}, - {}, - DEFAULT_CUDA_PACKERS, + model, {"q_proj.weight": q4, "v_proj.weight": q8}, DEFAULT_CUDA_PACKERS ) - self.assertEqual(model.q_proj.weight.shape, torch.Size([64, 128])) self.assertEqual(model.v_proj.weight.shape, torch.Size([64, 128])) - # Verify different subclass types - self.assertNotEqual( - type(model.q_proj.weight.data).__name__, - type(model.v_proj.weight.data).__name__, - ) - - def test_dispatches_by_module_type(self): - """load_and_pack_for_cuda reads from disk and dispatches.""" - config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") - cw = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) - - with tempfile.TemporaryDirectory() as d: - path = os.path.join(d, "m.safetensors") - save({"proj.weight": cw}, {}, path) - - with torch.device("meta"): - model2 = nn.ModuleDict({"proj": nn.Linear(128, 64, bias=False)}) - load_and_pack_for_cuda(path, model2) - - self.assertEqual(model2.proj.weight.shape, torch.Size([64, 128])) - self.assertEqual(model2.proj.weight.device.type, "cpu") - - def test_unknown_module_type_raises(self): - """Unregistered module types get a clear error.""" - - class CustomModule(nn.Module): - def __init__(self): - super().__init__() - self.weight = nn.Parameter(torch.randn(32, 64)) + def test_load_and_pack(self): + """load_and_pack_for_cuda reads from disk and packs.""" config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") - cw = quantize_weight(torch.randn(32, 64, dtype=torch.bfloat16), config) + q = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) with tempfile.TemporaryDirectory() as d: path = os.path.join(d, "m.safetensors") - save({"custom.weight": cw}, {}, path) + state = { + "proj.weight": q, + "norm.weight": torch.randn(64, dtype=torch.bfloat16), + } + td, md = flatten_tensor_state_dict(state) + save_file(td, path, metadata=md) with torch.device("meta"): - model2 = nn.ModuleDict({"custom": CustomModule()}) - with self.assertRaises(ValueError) as ctx: - load_and_pack_for_cuda(path, model2) - self.assertIn("CustomModule", str(ctx.exception)) - - def test_missing_weight_raises(self): - """A meta-device parameter after loading means the checkpoint was incomplete.""" - config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") - cw = quantize_weight(torch.randn(32, 64, dtype=torch.bfloat16), config) - - with tempfile.TemporaryDirectory() as d: - path = os.path.join(d, "m.safetensors") - # Only save weight for 'a', not 'b' - save({"a.weight": cw}, {}, path) - - with torch.device("meta"): - model2 = nn.ModuleDict( + model = nn.ModuleDict( { - "a": nn.Linear(64, 32, bias=False), - "b": nn.Linear(64, 32, bias=False), + "proj": nn.Linear(128, 64, bias=False), + "norm": nn.LayerNorm(64, bias=False), } ) - with self.assertRaises(RuntimeError) as ctx: - load_and_pack_for_cuda(path, model2) - self.assertIn("b.weight", str(ctx.exception)) - - def test_custom_packer_via_dict(self): - """Models can extend the packer dict with custom module types.""" - call_log = [] - - class MyModule(nn.Module): - def __init__(self): - super().__init__() - self.weight = nn.Parameter(torch.randn(32, 64)) - - def my_packer(module, weights): - call_log.append(("my_packer", list(weights.keys()))) - cw = weights["weight"] - module.weight = nn.Parameter( - cw.qdata.to(torch.bfloat16), requires_grad=False - ) - - config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") - cw = quantize_weight(torch.randn(32, 64, dtype=torch.bfloat16), config) - - custom_packers = {**DEFAULT_CUDA_PACKERS, MyModule: my_packer} - - with tempfile.TemporaryDirectory() as d: - path = os.path.join(d, "m.safetensors") - save({"m.weight": cw}, {}, path) - - with torch.device("meta"): - model2 = nn.ModuleDict({"m": MyModule()}) - load_and_pack_for_cuda(path, model2, packers=custom_packers) - - self.assertEqual(len(call_log), 1) - self.assertEqual(call_log[0], ("my_packer", ["weight"])) - self.assertEqual(model2.m.weight.device.type, "cpu") - - def test_multi_weight_module_grouped(self): - """pack_model groups multiple weights per module (MoE-style).""" - call_log = [] - - class FusedExperts(nn.Module): - def __init__(self): - super().__init__() - self.w1 = nn.Parameter(torch.randn(32, 64)) - self.w2 = nn.Parameter(torch.randn(32, 64)) - - def moe_packer(module, weights): - call_log.append(sorted(weights.keys())) - for attr, cw in weights.items(): - setattr( - module, - attr, - nn.Parameter(cw.qdata.to(torch.bfloat16), requires_grad=False), - ) - - config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") - cw1 = quantize_weight(torch.randn(32, 64, dtype=torch.bfloat16), config) - cw2 = quantize_weight(torch.randn(32, 64, dtype=torch.bfloat16), config) - - with torch.device("meta"): - model = nn.ModuleDict({"experts": FusedExperts()}) - - packers = {**DEFAULT_CUDA_PACKERS, FusedExperts: moe_packer} - pack_model( - model, - {"experts.w1": cw1, "experts.w2": cw2}, - {}, - packers, - ) - - # Packer should be called ONCE with both weights - self.assertEqual(len(call_log), 1) - self.assertEqual(call_log[0], ["w1", "w2"]) + load_and_pack_for_cuda(path, model) + self.assertEqual(model.proj.weight.shape, torch.Size([64, 128])) + self.assertEqual(model.norm.weight.shape, torch.Size([64])) class TestPackOne(unittest.TestCase): @@ -452,19 +233,15 @@ def setUp(self): self.skipTest("CUDA required") def test_quantized_weight(self): - """pack_one dispatches CQW to the module packer.""" config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") - cw = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) + q = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) with torch.device("meta"): model = nn.ModuleDict({"proj": nn.Linear(128, 64, bias=False)}) - pack_one(model, "proj.weight", cw, DEFAULT_CUDA_PACKERS) - + pack_one(model, "proj.weight", q, DEFAULT_CUDA_PACKERS) self.assertNotEqual(model.proj.weight.device.type, "meta") - self.assertEqual(model.proj.weight.shape, torch.Size([64, 128])) def test_plain_tensor(self): - """pack_one assigns a plain tensor as a parameter or buffer.""" with torch.device("meta"): model = nn.ModuleDict({"norm": nn.LayerNorm(64, bias=False)}) pack_one( @@ -473,10 +250,47 @@ def test_plain_tensor(self): torch.randn(64, dtype=torch.bfloat16), DEFAULT_CUDA_PACKERS, ) - - self.assertEqual(model.norm.weight.shape, torch.Size([64])) self.assertEqual(model.norm.weight.dtype, torch.bfloat16) +class TestPackErrorPaths(unittest.TestCase): + def setUp(self): + if not torch.cuda.is_available(): + self.skipTest("CUDA required") + + def test_unregistered_module_type(self): + """pack_model raises for module types not in packers dict.""" + + class CustomModule(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(32, 64)) + + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + q = quantize_weight(torch.randn(32, 64, dtype=torch.bfloat16), config) + + with torch.device("meta"): + model = nn.ModuleDict({"custom": CustomModule()}) + with self.assertRaises(ValueError) as ctx: + pack_model(model, {"custom.weight": q}, DEFAULT_CUDA_PACKERS) + self.assertIn("CustomModule", str(ctx.exception)) + + def test_missing_weight_detected(self): + """pack_model raises when a parameter stays on meta after packing.""" + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + q = quantize_weight(torch.randn(32, 64, dtype=torch.bfloat16), config) + + with torch.device("meta"): + model = nn.ModuleDict( + { + "a": nn.Linear(64, 32, bias=False), + "b": nn.Linear(64, 32, bias=False), + } + ) + with self.assertRaises(RuntimeError) as ctx: + pack_model(model, {"a.weight": q}, DEFAULT_CUDA_PACKERS) + self.assertIn("b.weight", str(ctx.exception)) + + if __name__ == "__main__": unittest.main() diff --git a/examples/models/gemma4_31b/quant/tests/test_quantize.py b/examples/models/gemma4_31b/quant/tests/test_quantize.py index 0dbbbff3167..43970c2bef2 100644 --- a/examples/models/gemma4_31b/quant/tests/test_quantize.py +++ b/examples/models/gemma4_31b/quant/tests/test_quantize.py @@ -4,11 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""Unit tests for quant/quantize.py. - -Tests the public API: ``quantize_weight`` and ``quantize_model``. Organized -by resource requirement (CPU vs CUDA), not by internal codepath. -""" +"""Unit tests for quant/quantize.py.""" import unittest @@ -26,51 +22,35 @@ QuantRule, ) from parameterized import parameterized - - -# --------------------------------------------------------------------------- -# quantize_weight — CPU (uses min_max; tests the output contract) +from torchao.quantization import IntxUnpackedToInt8Tensor +from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor class TestQuantizeWeight(unittest.TestCase): @parameterized.expand( [ - ("4bit_asym", 4, 32, False, (64, 128), 0, 15), - ("4bit_sym", 4, 32, True, (64, 128), 0, 15), - ("4bit_gs64", 4, 64, False, (32, 128), 0, 15), - ("8bit_sym", 8, 32, True, (32, 64), -128, 127), - ("3d_expert", 4, 32, False, (8, 64, 128), 0, 15), + ("4bit_asym", 4, 32, False), + ("4bit_sym", 4, 32, True), + ("4bit_gs64", 4, 64, False), + ("8bit_sym", 8, 32, True), ] ) - def test_output_structure(self, _name, bits, gs, sym, shape, qmin, qmax): + def test_output_type(self, _name, bits, gs, sym): config = QuantConfig(bits=bits, group_size=gs, symmetric=sym, method="min_max") - cw = quantize_weight(torch.randn(*shape, dtype=torch.bfloat16), config) - - self.assertEqual(cw.qdata.shape, shape) - self.assertEqual(cw.qdata.dtype, torch.int8) - self.assertEqual(cw.scale.shape, (*shape[:-1], shape[-1] // gs)) - self.assertGreaterEqual(cw.qdata.min().item(), qmin) - self.assertLessEqual(cw.qdata.max().item(), qmax) - - if sym: - self.assertIsNone(cw.zero) + result = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) + if bits == 4: + self.assertIsInstance(result, Int4Tensor) + self.assertEqual(result.shape, torch.Size([64, 128])) else: - self.assertIsNotNone(cw.zero) - self.assertEqual(cw.zero.shape, cw.scale.shape) - - self.assertEqual(cw.config, config) - - def test_fp32_input(self): - config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") - cw = quantize_weight(torch.randn(32, 64, dtype=torch.float32), config) - self.assertEqual(cw.qdata.shape, (32, 64)) + self.assertIsInstance(result, IntxUnpackedToInt8Tensor) + self.assertEqual(result.shape, torch.Size([64, 128])) def test_quantize_dequantize_roundtrip(self): torch.manual_seed(0) weight = torch.randn(64, 128, dtype=torch.bfloat16) config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") - cw = quantize_weight(weight, config) - dequant = dequantize_weight(cw, dtype=torch.bfloat16) + q = quantize_weight(weight, config) + dequant = dequantize_weight(q, dtype=torch.bfloat16) rel_error = ( dequant.float() - weight.float() ).abs().mean() / weight.float().abs().mean() @@ -78,28 +58,79 @@ def test_quantize_dequantize_roundtrip(self): def test_dequantize_output_dtype(self): config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") - cw = quantize_weight(torch.randn(32, 64, dtype=torch.bfloat16), config) - self.assertEqual(dequantize_weight(cw, torch.float32).dtype, torch.float32) - self.assertEqual(dequantize_weight(cw, torch.bfloat16).dtype, torch.bfloat16) - self.assertEqual(dequantize_weight(cw, torch.float16).dtype, torch.float16) + q = quantize_weight(torch.randn(32, 64, dtype=torch.bfloat16), config) + self.assertEqual(dequantize_weight(q, torch.float32).dtype, torch.float32) + self.assertEqual(dequantize_weight(q, torch.bfloat16).dtype, torch.bfloat16) - def test_dequantize_symmetric(self): + def test_dequantize_symmetric_4bit(self): torch.manual_seed(1) weight = torch.randn(32, 64, dtype=torch.bfloat16) config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") - cw = quantize_weight(weight, config) - self.assertIsNone(cw.zero) - dequant = dequantize_weight(cw) + q = quantize_weight(weight, config) + dequant = dequantize_weight(q) self.assertEqual(dequant.shape, (32, 64)) rel_error = ( dequant.float() - weight.float() ).abs().mean() / weight.float().abs().mean() self.assertLess(rel_error.item(), 0.15) + def test_dequantize_int8(self): + torch.manual_seed(2) + weight = torch.randn(32, 64, dtype=torch.bfloat16) + config = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") + q = quantize_weight(weight, config) + dequant = dequantize_weight(q, dtype=torch.bfloat16) + rel_error = ( + dequant.float() - weight.float() + ).abs().mean() / weight.float().abs().mean() + self.assertLess(rel_error.item(), 0.02) + + def test_int8_small_weights_bf16_precision(self): + """INT8 quantization of small bf16 weights must use full int8 range. + + Regression: IntxUnpackedToInt8Tensor.from_hp quantizes in bf16, + which collapses per-group scales to a single value for weights + with abs_mean ~0.01 (e.g., Gemma 4 v_proj). Our _to_intx_tensor + casts to float32 first to avoid this. + """ + torch.manual_seed(42) + weight = torch.randn(64, 128, dtype=torch.bfloat16) * 0.01 + config = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") + q = quantize_weight(weight, config) + dequant = dequantize_weight(q, dtype=torch.bfloat16) + rel_error = ( + dequant.float() - weight.float() + ).abs().mean() / weight.float().abs().mean() + self.assertLess(rel_error.item(), 0.02) + + def test_dequantize_int8_asymmetric(self): + torch.manual_seed(3) + weight = torch.randn(32, 64, dtype=torch.bfloat16) + config = QuantConfig(bits=8, group_size=32, symmetric=False, method="min_max") + q = quantize_weight(weight, config) + dequant = dequantize_weight(q, dtype=torch.bfloat16) + rel_error = ( + dequant.float() - weight.float() + ).abs().mean() / weight.float().abs().mean() + self.assertLess(rel_error.item(), 0.02) + + def test_int8_per_axis(self): + """Per-axis (group_size == K) used for embeddings.""" + weight = torch.randn(256, 64, dtype=torch.bfloat16) + config = QuantConfig(bits=8, group_size=64, symmetric=True, method="min_max") + q = quantize_weight(weight, config) + self.assertIsInstance(q, IntxUnpackedToInt8Tensor) + dequant = dequantize_weight(q, dtype=torch.bfloat16) + rel_error = ( + dequant.float() - weight.float() + ).abs().mean() / weight.float().abs().mean() + self.assertLess(rel_error.item(), 0.01) + @parameterized.expand( [ ("unknown_method", QuantConfig(4, 32, False, "bogus"), "bogus"), ("unsupported_bits", QuantConfig(3, 32, False, "min_max"), None), + ("hqq_8bit_asym", QuantConfig(8, 32, False, "hqq"), "symmetric"), ] ) def test_invalid_config_raises(self, _name, config, expected_substr): @@ -109,10 +140,6 @@ def test_invalid_config_raises(self, _name, config, expected_substr): self.assertIn(expected_substr, str(ctx.exception)) -# --------------------------------------------------------------------------- -# quantize_weight — CUDA (HQQ-specific behavior only) - - class TestQuantizeWeightHQQ(unittest.TestCase): def setUp(self): if not torch.cuda.is_available(): @@ -122,29 +149,29 @@ def test_quantize_dequantize_roundtrip(self): torch.manual_seed(0) weight = torch.randn(64, 128, dtype=torch.bfloat16, device="cuda") config = QuantConfig(bits=4, group_size=32, symmetric=False, method="hqq") - cw = quantize_weight(weight, config) - dequant = dequantize_weight(cw, dtype=torch.bfloat16).cpu() + q = quantize_weight(weight, config) + dequant = dequantize_weight(q, dtype=torch.bfloat16).cpu() rel_error = ( dequant.float() - weight.cpu().float() ).abs().mean() / weight.cpu().float().abs().mean() self.assertLess(rel_error.item(), 0.15) def test_symmetric_scale_only(self): - """symmetric=True dispatches to scale-only HQQ (no zero).""" config = QuantConfig(bits=4, group_size=32, symmetric=True, method="hqq") - cw = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) - self.assertIsNone(cw.zero) - self.assertGreaterEqual(cw.qdata.min().item(), 0) - self.assertLessEqual(cw.qdata.max().item(), 15) + q = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) + self.assertIsInstance(q, Int4Tensor) - def test_cpu_input_accepted(self): - config = QuantConfig(bits=4, group_size=32, symmetric=False, method="hqq") - cw = quantize_weight(torch.randn(32, 64, dtype=torch.bfloat16), config) - self.assertEqual(cw.qdata.shape, (32, 64)) - - -# --------------------------------------------------------------------------- -# quantize_model + def test_int8_hqq_roundtrip(self): + torch.manual_seed(0) + weight = torch.randn(64, 128, dtype=torch.bfloat16) + config = QuantConfig(bits=8, group_size=32, symmetric=True, method="hqq") + q = quantize_weight(weight, config) + self.assertIsInstance(q, IntxUnpackedToInt8Tensor) + dequant = dequantize_weight(q, dtype=torch.bfloat16) + rel_error = ( + dequant.float() - weight.float() + ).abs().mean() / weight.float().abs().mean() + self.assertLess(rel_error.item(), 0.02) class TestQuantizeModel(unittest.TestCase): @@ -167,15 +194,11 @@ def test_applies_recipe(self): QuantRule(r".*\.weight", QuantConfig(4, 16, False, "min_max")), ] ) + state = quantize_model(model, recipe) - quantized, unquantized = quantize_model(model, recipe) - - self.assertIn("proj.weight", quantized) - self.assertEqual(quantized["proj.weight"].qdata.shape, (32, 16)) - self.assertIn("embed.weight", unquantized) - self.assertIn("norm.weight", unquantized) - self.assertNotIn("embed.weight", quantized) - self.assertNotIn("norm.weight", quantized) + self.assertIsInstance(state["proj.weight"], Int4Tensor) + self.assertIs(type(state["embed.weight"]), torch.Tensor) + self.assertIs(type(state["norm.weight"]), torch.Tensor) def test_persistent_buffers_included(self): model = nn.Module() @@ -184,38 +207,23 @@ def test_persistent_buffers_included(self): model.register_buffer("temp", torch.zeros(4), persistent=False) recipe = QuantRecipe(rules=[QuantRule(r".*", None)]) - _, unquantized = quantize_model(model, recipe) + state = quantize_model(model, recipe) - self.assertIn("scalar", unquantized) - self.assertNotIn("temp", unquantized) + self.assertIn("scalar", state) + self.assertNotIn("temp", state) def test_unquantized_cast_to_dtype(self): model = nn.ModuleDict({"proj": nn.Linear(16, 8, bias=False)}) model.proj.weight.data = torch.randn(8, 16, dtype=torch.float32) recipe = QuantRecipe(rules=[QuantRule(r".*", None)]) - _, unquantized = quantize_model(model, recipe, dtype=torch.float16) + state = quantize_model(model, recipe, dtype=torch.float16) - self.assertEqual(unquantized["proj.weight"].dtype, torch.float16) + self.assertEqual(state["proj.weight"].dtype, torch.float16) def test_empty_model(self): - quantized, unquantized = quantize_model(nn.Module(), QuantRecipe(rules=[])) - self.assertEqual(len(quantized), 0) - self.assertEqual(len(unquantized), 0) - - def test_all_quantized(self): - model = nn.ModuleDict({"a": nn.Linear(32, 16, bias=False)}) - model.to(dtype=torch.bfloat16) - for p in model.parameters(): - p.data.normal_(0, 0.02) - - config = QuantConfig(bits=4, group_size=16, symmetric=False, method="min_max") - quantized, unquantized = quantize_model( - model, QuantRecipe(rules=[QuantRule(r".*", config)]) - ) - self.assertEqual(len(quantized), 1) - self.assertIn("a.weight", quantized) - self.assertEqual(len(unquantized), 0) + state = quantize_model(nn.Module(), QuantRecipe(rules=[])) + self.assertEqual(len(state), 0) if __name__ == "__main__": diff --git a/examples/models/gemma4_31b/quant/tests/test_safetensors_roundtrip.py b/examples/models/gemma4_31b/quant/tests/test_safetensors_roundtrip.py new file mode 100644 index 00000000000..7c4fa8decfa --- /dev/null +++ b/examples/models/gemma4_31b/quant/tests/test_safetensors_roundtrip.py @@ -0,0 +1,143 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Smoke tests: torchao subclasses survive safetensors roundtrip.""" + +import os +import tempfile +import unittest + +import torch + +from safetensors import safe_open +from safetensors.torch import save_file +from torchao.prototype.safetensors.safetensors_support import ( + flatten_tensor_state_dict, + unflatten_tensor_state_dict, +) + + +def save(state_dict, path): + tensors_data, metadata = flatten_tensor_state_dict(state_dict) + save_file(tensors_data, path, metadata=metadata) + + +def load(path): + with safe_open(path, framework="pt", device="cpu") as f: + metadata = f.metadata() + tensors = {k: f.get_tensor(k) for k in f.keys()} + result, _ = unflatten_tensor_state_dict(tensors, metadata) + return result + + +from torchao.quantization import IntxUnpackedToInt8Tensor +from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + + +def _make_int4(shape, group_size=32): + """Build a random Int4Tensor.""" + N, K = shape + packed = torch.randint(0, 255, (N, K // 2), dtype=torch.uint8) + scale = torch.randn(K // group_size, N, dtype=torch.bfloat16) + zp = torch.zeros(K // group_size, N, dtype=torch.bfloat16) + return Int4Tensor( + qdata=packed, + scale=scale, + zero_point=zp, + block_size=[1, group_size], + shape=torch.Size([N, K]), + ) + + +def _make_int8(shape, group_size=32): + """Build a random IntxUnpackedToInt8Tensor.""" + N, K = shape + return IntxUnpackedToInt8Tensor( + qdata=torch.randint(-128, 127, (N, K), dtype=torch.int8), + scale=torch.randn(N, K // group_size, dtype=torch.bfloat16), + zero_point=torch.zeros(N, K // group_size, dtype=torch.int8), + target_dtype=torch.int8, + block_size=(1, group_size), + dtype=torch.bfloat16, + activation_quantization=None, + ) + + +class TestSaveLoad(unittest.TestCase): + def test_int4_roundtrip(self): + """Int4Tensor survives save/load.""" + t = _make_int4((64, 128)) + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save({"layer.weight": t}, path) + loaded = load(path) + + self.assertIn("layer.weight", loaded) + self.assertIsInstance(loaded["layer.weight"], Int4Tensor) + self.assertTrue(torch.equal(t.qdata, loaded["layer.weight"].qdata)) + self.assertTrue(torch.equal(t.scale, loaded["layer.weight"].scale)) + + def test_int8_roundtrip(self): + """IntxUnpackedToInt8Tensor survives save/load.""" + t = _make_int8((64, 128)) + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save({"layer.weight": t}, path) + loaded = load(path) + + self.assertIn("layer.weight", loaded) + self.assertIsInstance(loaded["layer.weight"], IntxUnpackedToInt8Tensor) + self.assertTrue(torch.equal(t.qdata, loaded["layer.weight"].qdata)) + + def test_mixed_state_dict(self): + """Mixed Int4 + Int8 + plain tensor roundtrip.""" + state = { + "linear.weight": _make_int4((64, 128)), + "embed.weight": _make_int8((100, 64)), + "norm.weight": torch.randn(64, dtype=torch.bfloat16), + } + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save(state, path) + loaded = load(path) + + self.assertEqual(set(state.keys()), set(loaded.keys())) + self.assertIsInstance(loaded["linear.weight"], Int4Tensor) + self.assertIsInstance(loaded["embed.weight"], IntxUnpackedToInt8Tensor) + self.assertIsInstance(loaded["norm.weight"], torch.Tensor) + self.assertTrue(torch.equal(state["norm.weight"], loaded["norm.weight"])) + + def test_plain_tensor_only(self): + """State dict with only plain tensors roundtrips.""" + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save({"model.norm.weight": torch.randn(64, dtype=torch.bfloat16)}, path) + loaded = load(path) + self.assertIn("model.norm.weight", loaded) + + def test_3d_int4(self): + """3D Int4Tensor (MoE expert weights) roundtrips.""" + # 3D: (num_experts, N, K//2) packed + N, K, gs = 32, 64, 32 + packed = torch.randint(0, 255, (4, N, K // 2), dtype=torch.uint8) + scale = torch.randn(4, K // gs, N, dtype=torch.bfloat16) + zp = torch.zeros(4, K // gs, N, dtype=torch.bfloat16) + t = Int4Tensor( + qdata=packed, + scale=scale, + zero_point=zp, + block_size=[1, 1, gs], + shape=torch.Size([4, N, K]), + ) + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save({"experts.w1": t}, path) + loaded = load(path) + self.assertTrue(torch.equal(t.qdata, loaded["experts.w1"].qdata)) + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/models/gemma4_31b/quant/tests/test_serialize.py b/examples/models/gemma4_31b/quant/tests/test_serialize.py deleted file mode 100644 index ddbcb257faa..00000000000 --- a/examples/models/gemma4_31b/quant/tests/test_serialize.py +++ /dev/null @@ -1,333 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -"""Unit tests for quant/serialize.py — data format and I/O only. - -Tests nibble pack/unpack and save/load. Does NOT test -quantize_weight (that lives in test_quantize.py). Save/load tests use -hand-built CanonicalQuantizedWeight fixtures to avoid coupling to the -quantizer. -""" - -import json -import os -import tempfile -import unittest - -import torch - -from executorch.examples.models.gemma4_31b.quant.recipe import QuantConfig - -from executorch.examples.models.gemma4_31b.quant.serialize import ( - _nibble_pack, - _nibble_unpack, - CanonicalQuantizedWeight, - deserialize, - iter_load, - load, - save, - serialize, -) -from safetensors import safe_open - - -def _make_cqw( - shape: tuple[int, ...], - config: QuantConfig, -) -> CanonicalQuantizedWeight: - """Build a CanonicalQuantizedWeight with random data (no actual quantization).""" - K = shape[-1] - n_groups = K // config.group_size - scale_shape = (*shape[:-1], n_groups) - - if config.bits == 4: - qdata = torch.randint(0, 16, shape, dtype=torch.int8) - else: - qdata = torch.randint(-128, 128, shape, dtype=torch.int8) - - return CanonicalQuantizedWeight( - qdata=qdata, - scale=torch.randn(scale_shape, dtype=torch.bfloat16), - zero=( - torch.randn(scale_shape, dtype=torch.bfloat16) - if not config.symmetric - else None - ), - config=config, - ) - - -# --------------------------------------------------------------------------- -# CanonicalQuantizedWeight validation - - -class TestCanonicalQuantizedWeightValidation(unittest.TestCase): - def test_rejects_non_int8_qdata(self): - config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") - with self.assertRaises(ValueError) as ctx: - CanonicalQuantizedWeight( - qdata=torch.randint(0, 16, (8, 64), dtype=torch.int32), - scale=torch.randn(8, 2, dtype=torch.bfloat16), - zero=torch.randn(8, 2, dtype=torch.bfloat16), - config=config, - ) - self.assertIn("int8", str(ctx.exception)) - - def test_rejects_indivisible_group_size(self): - config = QuantConfig(bits=4, group_size=33, symmetric=False, method="min_max") - with self.assertRaises(ValueError) as ctx: - CanonicalQuantizedWeight( - qdata=torch.randint(0, 16, (8, 64), dtype=torch.int8), - scale=torch.randn(8, 2, dtype=torch.bfloat16), - zero=torch.randn(8, 2, dtype=torch.bfloat16), - config=config, - ) - self.assertIn("divisible", str(ctx.exception)) - - def test_rejects_wrong_scale_numel(self): - config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") - with self.assertRaises(ValueError) as ctx: - CanonicalQuantizedWeight( - qdata=torch.randint(0, 16, (8, 64), dtype=torch.int8), - scale=torch.randn(8, 4, dtype=torch.bfloat16), # should be (8, 2) - zero=torch.randn(8, 4, dtype=torch.bfloat16), - config=config, - ) - self.assertIn("scale", str(ctx.exception)) - - def test_rejects_symmetric_with_zero(self): - config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") - with self.assertRaises(ValueError) as ctx: - CanonicalQuantizedWeight( - qdata=torch.randint(0, 16, (8, 64), dtype=torch.int8), - scale=torch.randn(8, 2, dtype=torch.bfloat16), - zero=torch.randn(8, 2, dtype=torch.bfloat16), - config=config, - ) - self.assertIn("symmetric", str(ctx.exception)) - - def test_rejects_asymmetric_without_zero(self): - config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") - with self.assertRaises(ValueError) as ctx: - CanonicalQuantizedWeight( - qdata=torch.randint(0, 16, (8, 64), dtype=torch.int8), - scale=torch.randn(8, 2, dtype=torch.bfloat16), - zero=None, - config=config, - ) - self.assertIn("asymmetric", str(ctx.exception)) - - -# --------------------------------------------------------------------------- -# Nibble pack / unpack - - -class TestNibblePack(unittest.TestCase): - def test_roundtrip(self): - qdata = torch.randint(0, 16, (8, 64), dtype=torch.int8) - packed = _nibble_pack(qdata) - self.assertEqual(packed.shape, (8, 32)) - self.assertTrue(torch.equal(qdata, _nibble_unpack(packed, 64))) - - def test_rejects_odd_last_dim(self): - with self.assertRaises(AssertionError): - _nibble_pack(torch.zeros(4, 33, dtype=torch.int8)) - - def test_3d(self): - """Nibble pack works for 3D tensors (MoE expert weights).""" - qdata = torch.randint(0, 16, (4, 32, 64), dtype=torch.int8) - packed = _nibble_pack(qdata) - self.assertEqual(packed.shape, (4, 32, 32)) - self.assertTrue(torch.equal(qdata, _nibble_unpack(packed, 64))) - - -# --------------------------------------------------------------------------- -# save / load - - -class TestSerializeDeserialize(unittest.TestCase): - """Pure logic layer — no disk I/O.""" - - def test_roundtrip(self): - config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") - cw = _make_cqw((64, 128), config) - unq = {"embed": torch.randn(8, 8, dtype=torch.bfloat16)} - - tensors, header = serialize({"w": cw}, unq) - q, u = deserialize(tensors, header) - - self.assertTrue(torch.equal(cw.qdata, q["w"].qdata)) - self.assertTrue(torch.equal(cw.scale, q["w"].scale)) - self.assertTrue(torch.equal(cw.zero, q["w"].zero)) - self.assertEqual(cw.config, q["w"].config) - self.assertTrue(torch.equal(unq["embed"], u["embed"])) - - def test_rejects_unsupported_version(self): - tensors, header = serialize({}, {"w": torch.randn(4, 4)}) - header["format_version"] = "99" - with self.assertRaises(ValueError) as ctx: - deserialize(tensors, header) - self.assertIn("99", str(ctx.exception)) - - -class TestSaveLoad(unittest.TestCase): - """I/O layer — roundtrip through safetensors on disk.""" - - def test_roundtrip_asymmetric(self): - config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") - cw = _make_cqw((64, 128), config) - unq = {"embed.weight": torch.randn(32, 64, dtype=torch.bfloat16)} - - with tempfile.TemporaryDirectory() as d: - path = os.path.join(d, "m.safetensors") - save({"w": cw}, unq, path) - q, u = load(path) - - self.assertTrue(torch.equal(cw.qdata, q["w"].qdata)) - self.assertTrue(torch.equal(cw.scale, q["w"].scale)) - self.assertTrue(torch.equal(cw.zero, q["w"].zero)) - self.assertEqual(cw.config, q["w"].config) - self.assertTrue(torch.equal(unq["embed.weight"], u["embed.weight"])) - - def test_roundtrip_symmetric(self): - config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") - cw = _make_cqw((32, 64), config) - - with tempfile.TemporaryDirectory() as d: - path = os.path.join(d, "m.safetensors") - save({"w": cw}, {}, path) - q, _ = load(path) - - self.assertIsNone(q["w"].zero) - self.assertTrue(torch.equal(cw.qdata, q["w"].qdata)) - - def test_roundtrip_3d(self): - """3D quantized weights (MoE experts) roundtrip correctly.""" - config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") - cw = _make_cqw((8, 64, 128), config) - - with tempfile.TemporaryDirectory() as d: - path = os.path.join(d, "m.safetensors") - save({"experts.w1": cw}, {}, path) - q, _ = load(path) - - self.assertTrue(torch.equal(cw.qdata, q["experts.w1"].qdata)) - self.assertEqual(q["experts.w1"].scale.shape, (8, 64, 4)) - - def test_4bit_nibble_packed_on_disk(self): - config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") - cw = _make_cqw((64, 128), config) - - with tempfile.TemporaryDirectory() as d: - path = os.path.join(d, "m.safetensors") - save({"w": cw}, {}, path) - with safe_open(path, framework="pt", device="cpu") as f: - on_disk = f.get_tensor("w.qdata") - self.assertEqual(on_disk.shape, (64, 64)) - - def test_8bit_not_nibble_packed(self): - config = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") - cw = _make_cqw((32, 64), config) - - with tempfile.TemporaryDirectory() as d: - path = os.path.join(d, "m.safetensors") - save({"w": cw}, {}, path) - with safe_open(path, framework="pt", device="cpu") as f: - on_disk = f.get_tensor("w.qdata") - self.assertEqual(on_disk.shape, (32, 64)) # no packing for 8-bit - - def test_header_metadata(self): - config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") - cw = _make_cqw((32, 64), config) - - with tempfile.TemporaryDirectory() as d: - path = os.path.join(d, "m.safetensors") - save({"foo.weight": cw}, {}, path) - with safe_open(path, framework="pt", device="cpu") as f: - meta = json.loads(f.metadata()["quant"]) - - self.assertIn("foo.weight", meta) - self.assertEqual(meta["foo.weight"]["bits"], 4) - self.assertEqual(meta["foo.weight"]["group_size"], 32) - self.assertFalse(meta["foo.weight"]["symmetric"]) - self.assertEqual(meta["foo.weight"]["method"], "min_max") - - def test_empty_quantized(self): - unq = {"w": torch.randn(8, 8, dtype=torch.bfloat16)} - with tempfile.TemporaryDirectory() as d: - path = os.path.join(d, "m.safetensors") - save({}, unq, path) - q, u = load(path) - self.assertEqual(len(q), 0) - self.assertTrue(torch.equal(unq["w"], u["w"])) - - -class TestIterLoad(unittest.TestCase): - """Streaming load — one weight at a time from disk.""" - - def test_yields_all_weights(self): - """iter_load yields every quantized and unquantized weight.""" - q4 = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") - q8 = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") - cw4 = _make_cqw((64, 128), q4) - cw8 = _make_cqw((32, 64), q8) - unq = {"norm.weight": torch.randn(64, dtype=torch.bfloat16)} - - with tempfile.TemporaryDirectory() as d: - path = os.path.join(d, "m.safetensors") - save({"proj.weight": cw4, "embed.weight": cw8}, unq, path) - items = list(iter_load(path)) - - fqns = {fqn for fqn, _ in items} - self.assertIn("proj.weight", fqns) - self.assertIn("embed.weight", fqns) - self.assertIn("norm.weight", fqns) - self.assertEqual(len(items), 3) - - def test_quantized_matches_load(self): - """Streaming yields identical CQW to batch load.""" - config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") - cw = _make_cqw((64, 128), config) - - with tempfile.TemporaryDirectory() as d: - path = os.path.join(d, "m.safetensors") - save({"w": cw}, {}, path) - - q_batch, _ = load(path) - items = dict(iter_load(path)) - - batch_cw = q_batch["w"] - stream_cw = items["w"] - self.assertIsInstance(stream_cw, CanonicalQuantizedWeight) - self.assertTrue(torch.equal(batch_cw.qdata, stream_cw.qdata)) - self.assertTrue(torch.equal(batch_cw.scale, stream_cw.scale)) - self.assertTrue(torch.equal(batch_cw.zero, stream_cw.zero)) - self.assertEqual(batch_cw.config, stream_cw.config) - - def test_unquantized_matches_load(self): - """Streaming yields identical plain tensors to batch load.""" - unq = {"a": torch.randn(8, 16, dtype=torch.bfloat16)} - - with tempfile.TemporaryDirectory() as d: - path = os.path.join(d, "m.safetensors") - save({}, unq, path) - - _, u_batch = load(path) - items = dict(iter_load(path)) - - self.assertTrue(torch.equal(u_batch["a"], items["a"])) - - def test_empty_file(self): - """Streaming an empty checkpoint yields nothing.""" - with tempfile.TemporaryDirectory() as d: - path = os.path.join(d, "m.safetensors") - save({}, {}, path) - items = list(iter_load(path)) - self.assertEqual(len(items), 0) - - -if __name__ == "__main__": - unittest.main() diff --git a/examples/models/gemma4_31b/quantize_and_save.py b/examples/models/gemma4_31b/quantize_and_save.py index 6d048d1e912..959540f2c45 100644 --- a/examples/models/gemma4_31b/quantize_and_save.py +++ b/examples/models/gemma4_31b/quantize_and_save.py @@ -6,12 +6,12 @@ """Quantize Gemma 4 31B-IT and save as a quantized checkpoint. -Produces a packing-agnostic safetensors file (int values + per-group scales + -JSON header) that can later be loaded and packed for any backend via -``quant.load()`` and ``quant.pack_model()``. +Produces a safetensors file containing torchao tensor subclasses +(``Int4Tensor``, ``IntxUnpackedToInt8Tensor``) that can be loaded and +packed for any backend via ``load_and_pack_for_cuda`` or ``pack_model``. -No CUDA is needed — quantization runs on CPU. CUDA is only required at -load-and-pack time. +The default recipe runs on CPU. The sensitive recipe requires CUDA for +HQQ asymmetric quantization. CUDA is also required at load-and-pack time. Usage: python quantize_and_save.py \\ @@ -32,7 +32,6 @@ quantize_model, QuantRecipe, QuantRule, - save, ) # --------------------------------------------------------------------------- @@ -46,7 +45,7 @@ # - Norms and layer_scalar are tiny and must stay unquantized. _INT4 = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") -_INT4_HQQ = QuantConfig(bits=4, group_size=32, symmetric=True, method="hqq") +_INT4_HQQ = QuantConfig(bits=4, group_size=32, symmetric=False, method="hqq") _INT8 = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") _INT8_PER_AXIS = QuantConfig( # group_size = hidden_size (5376) for Gemma 4 31B bits=8, group_size=5376, symmetric=True, method="min_max" @@ -116,12 +115,19 @@ def main() -> None: model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) print(f"Quantizing with recipe '{args.quant_recipe}'...") - quantized, unquantized = quantize_model(model, recipe) + state_dict = quantize_model(model, recipe) os.makedirs(args.output, exist_ok=True) safetensors_path = os.path.join(args.output, "model.safetensors") print("Saving quantized checkpoint...") - n_tensors = save(quantized, unquantized, safetensors_path) + from safetensors.torch import save_file + from torchao.prototype.safetensors.safetensors_support import ( + flatten_tensor_state_dict, + ) + + tensors_data, metadata = flatten_tensor_state_dict(state_dict) + save_file(tensors_data, safetensors_path, metadata=metadata) + n_tensors = len(state_dict) for filename in ("config.json", "tokenizer.json", "tokenizer_config.json"): src = os.path.join(args.model_dir, filename) diff --git a/examples/models/gemma4_31b/tests/test_cuda_pipeline.py b/examples/models/gemma4_31b/tests/test_cuda_pipeline.py index 228b34d9261..d83e76fd630 100644 --- a/examples/models/gemma4_31b/tests/test_cuda_pipeline.py +++ b/examples/models/gemma4_31b/tests/test_cuda_pipeline.py @@ -162,11 +162,11 @@ def test_export_from_hf_checkpoint(self): ckpt_dir, max_seq_len=TINY_CONFIG.max_seq_len ) model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) - quantized, unquantized = quantize_model(model, DEFAULT_RECIPE) + state_dict = quantize_model(model, DEFAULT_RECIPE) with torch.device("meta"): model = Gemma4_31B(config) - pack_model(model, quantized, unquantized, DEFAULT_CUDA_PACKERS) + pack_model(model, state_dict, DEFAULT_CUDA_PACKERS) model.eval() params = dict(model.named_parameters()) diff --git a/examples/models/gemma4_31b/tests/test_pipeline.py b/examples/models/gemma4_31b/tests/test_pipeline.py index 8bc1b36bfb7..a8d9d9cbe34 100644 --- a/examples/models/gemma4_31b/tests/test_pipeline.py +++ b/examples/models/gemma4_31b/tests/test_pipeline.py @@ -28,14 +28,17 @@ RingKVCache, ) from executorch.examples.models.gemma4_31b.quant import ( - load, QuantConfig, quantize_model, QuantRecipe, QuantRule, - save, ) +from safetensors import safe_open from safetensors.torch import save_file +from torchao.prototype.safetensors.safetensors_support import ( + flatten_tensor_state_dict, + unflatten_tensor_state_dict, +) # --------------------------------------------------------------------------- @@ -137,9 +140,10 @@ def build_random_tiny_model() -> Gemma4_31B: def save_checkpoint(output_dir: str): model = build_random_tiny_model() model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) - quantized, unquantized = quantize_model(model, DEFAULT_RECIPE) + state_dict = quantize_model(model, DEFAULT_RECIPE) os.makedirs(output_dir, exist_ok=True) - save(quantized, unquantized, os.path.join(output_dir, "model.safetensors")) + td, md = flatten_tensor_state_dict(state_dict) + save_file(td, os.path.join(output_dir, "model.safetensors"), metadata=md) with open(os.path.join(output_dir, "config.json"), "w") as f: json.dump(config_dict(), f) @@ -160,53 +164,49 @@ def build_hf_checkpoint(output_dir: str) -> None: class TestQuantizeSaveLoadRoundtrip(unittest.TestCase): def test_roundtrip_preserves_weights(self): - """quantize → save → load recovers all weights and configs.""" + """quantize → save → load recovers all weights.""" + from torchao.quantization import IntxUnpackedToInt8Tensor + from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + model = build_random_tiny_model() model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) - quantized, unquantized = quantize_model(model, DEFAULT_RECIPE) + state_dict = quantize_model(model, DEFAULT_RECIPE) with tempfile.TemporaryDirectory() as tmpdir: path = os.path.join(tmpdir, "model.safetensors") - save(quantized, unquantized, path) - q_loaded, u_loaded = load(path) - - self.assertEqual(set(quantized.keys()), set(q_loaded.keys())) - for fqn in quantized: - self.assertEqual(quantized[fqn].config, q_loaded[fqn].config) - self.assertTrue(torch.equal(quantized[fqn].qdata, q_loaded[fqn].qdata)) - self.assertTrue(torch.equal(quantized[fqn].scale, q_loaded[fqn].scale)) - - self.assertEqual(set(unquantized.keys()), set(u_loaded.keys())) - for fqn in unquantized: - self.assertTrue(torch.equal(unquantized[fqn], u_loaded[fqn])) + td, md = flatten_tensor_state_dict(state_dict) + save_file(td, path, metadata=md) + with safe_open(path, framework="pt", device="cpu") as f: + loaded_meta = f.metadata() + loaded_tensors = {k: f.get_tensor(k) for k in f.keys()} + loaded, _ = unflatten_tensor_state_dict(loaded_tensors, loaded_meta) + + self.assertEqual(set(state_dict.keys()), set(loaded.keys())) + for fqn in state_dict: + orig = state_dict[fqn] + got = loaded[fqn] + self.assertEqual(type(orig).__name__, type(got).__name__) + if isinstance(orig, Int4Tensor): + self.assertTrue(torch.equal(orig.qdata, got.qdata)) + self.assertTrue(torch.equal(orig.scale, got.scale)) + elif isinstance(orig, IntxUnpackedToInt8Tensor): + self.assertTrue(torch.equal(orig.qdata, got.qdata)) + self.assertTrue(torch.equal(orig.scale, got.scale)) + elif isinstance(orig, torch.Tensor): + self.assertTrue(torch.equal(orig, got)) def test_embedding_quantized_as_int8(self): - """embed_tokens is quantized to INT8 per-axis, not skipped.""" + """embed_tokens is quantized to INT8 (IntxUnpackedToInt8Tensor).""" + from torchao.quantization import IntxUnpackedToInt8Tensor + model = build_random_tiny_model() model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) - quantized, unquantized = quantize_model(model, DEFAULT_RECIPE) - - self.assertIn("embed_tokens.weight", quantized) - self.assertNotIn("embed_tokens.weight", unquantized) - self.assertEqual(quantized["embed_tokens.weight"].config.bits, 8) - - def test_corrupted_checkpoint_missing_key(self): - """Renaming a key in the safetensors file makes it absent after load.""" - from safetensors import safe_open + state_dict = quantize_model(model, DEFAULT_RECIPE) - with tempfile.TemporaryDirectory() as tmpdir: - save_checkpoint(tmpdir) - path = os.path.join(tmpdir, "model.safetensors") - - with safe_open(path, framework="pt", device="cpu") as f: - header = f.metadata() - tensors = {k: f.get_tensor(k) for k in f.keys()} - tensors["norm.BOGUS"] = tensors.pop("norm.weight") - save_file(tensors, path, metadata=header) - - q, u = load(path) - self.assertNotIn("norm.weight", u) - self.assertIn("norm.BOGUS", u) + self.assertIn("embed_tokens.weight", state_dict) + self.assertIsInstance( + state_dict["embed_tokens.weight"], IntxUnpackedToInt8Tensor + ) class TestRingKVCache(unittest.TestCase): @@ -266,5 +266,45 @@ def test_assert_on_oversized_prefill(self): cache.update(pos, k, v) +class TestGgufKeyMapping(unittest.TestCase): + """Unit tests for gguf_loader.gguf_to_model_key (CPU, no GGUF file needed).""" + + def test_attention_keys(self): + from executorch.examples.models.gemma4_31b.gguf_loader import gguf_to_model_key + + self.assertEqual( + gguf_to_model_key("blk.0.attn_q.weight"), + "layers.0.self_attn.q_proj.weight", + ) + self.assertEqual( + gguf_to_model_key("blk.59.attn_output.weight"), + "layers.59.self_attn.o_proj.weight", + ) + + def test_mlp_keys(self): + from executorch.examples.models.gemma4_31b.gguf_loader import gguf_to_model_key + + self.assertEqual( + gguf_to_model_key("blk.5.ffn_gate.weight"), + "layers.5.mlp.gate_proj.weight", + ) + + def test_global_keys(self): + from executorch.examples.models.gemma4_31b.gguf_loader import gguf_to_model_key + + self.assertEqual(gguf_to_model_key("token_embd.weight"), "embed_tokens.weight") + self.assertEqual(gguf_to_model_key("output_norm.weight"), "norm.weight") + + def test_unknown_key_returns_none(self): + from executorch.examples.models.gemma4_31b.gguf_loader import gguf_to_model_key + + self.assertIsNone(gguf_to_model_key("blk.0.some_unknown.weight")) + + def test_ignored_key_returns_none(self): + from executorch.examples.models.gemma4_31b.gguf_loader import gguf_to_model_key + + self.assertIsNone(gguf_to_model_key("rope_freqs.weight")) + + if __name__ == "__main__": unittest.main() From 48d83c7bb9d917963a7077c2313cb16c161d5ee2 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Fri, 1 May 2026 14:24:23 -0700 Subject: [PATCH 09/14] Fix minor issues in pack_cuda.py - Use .detach() instead of .data when moving packed INT4 weight to CPU to preserve tensor subclass identity safely - Remove unused loaded_keys set in load_and_pack_for_cuda - Handle top-level tensor keys (no dot) in load_and_pack_for_cuda --- examples/models/gemma4_31b/quant/pack_cuda.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/models/gemma4_31b/quant/pack_cuda.py b/examples/models/gemma4_31b/quant/pack_cuda.py index 21949024fcc..bfaa5d955ae 100644 --- a/examples/models/gemma4_31b/quant/pack_cuda.py +++ b/examples/models/gemma4_31b/quant/pack_cuda.py @@ -112,7 +112,7 @@ def pack_linear_for_cuda(module: nn.Module, weights: dict[str, torch.Tensor]) -> # Pack on CUDA (required by _convert_weight_to_int4pack), move back # to CPU for assembly. The model moves to CUDA later at runtime. packed = pack_int4_for_cuda(w, device="cuda") - module.weight = nn.Parameter(packed.data.to("cpu"), requires_grad=False) + module.weight = nn.Parameter(packed.detach().to("cpu"), requires_grad=False) torch.cuda.empty_cache() elif isinstance(w, IntxUnpackedToInt8Tensor): module.weight = nn.Parameter(w, requires_grad=False) @@ -166,15 +166,17 @@ def load_and_pack_for_cuda( # Stream one logical weight at a time: load its inner tensors, # reconstruct the subclass, pack, then release before the next. - loaded_keys: set[str] = set() for name in tensor_names: - module_fqn, weight_name = name.rsplit(".", 1) - prefix = f"{module_fqn}._{weight_name}_" + parts = name.rsplit(".", 1) + module_fqn = parts[0] if len(parts) > 1 else "" + weight_name = parts[-1] + prefix = ( + f"{module_fqn}._{weight_name}_" if module_fqn else f"_{weight_name}_" + ) partial = {} for key in all_keys: if key.startswith(prefix) or key == name: partial[key] = f.get_tensor(key) - loaded_keys.add(key) result, _ = unflatten_tensor_state_dict(partial, metadata) for fqn, value in result.items(): pack_one(model, fqn, value, _packers) From 869a2ac881a01f72f74803c231979ea55e7d5891 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Mon, 4 May 2026 08:47:31 -0700 Subject: [PATCH 10/14] Add split-K flash-decoding for decode SDPA in CUDA backend MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extend ReplaceEdgeOpWithTritonOpPass to select triton::sdpa_decode_splitk for SDPA nodes where L_q=1 (decode) and L_kv exceeds 2048 (large KV cache). This dramatically improves GPU utilization for full-attention layers at long context lengths — standard SDPA launches only a handful of CTAs (proportional to H_kv), while split-K partitions the KV sequence across up to 128 CTAs. Benchmarked on A100 with Gemma4 31B shapes at 128K context: Full-attention decode (H_kv=4, D=512, L_kv=131072): standard SDPA: 15.7ms/layer → split-K: 0.7ms/layer (22x) Sliding-attention decode (H_kv=16, D=256, L_kv=2048): unchanged (standard SDPA is faster for small L_kv) The threshold of 2048 is chosen to match the sliding-window ring buffer size — anything above is a full-attention cache where split-K wins. No changes to model code — the pass inspects Q/K shapes in the exported graph and selects the kernel automatically. --- .../tests/test_sdpa_splitk_replacement.py | 163 ++++++++++++++++++ backends/cuda/triton/replacement_pass.py | 33 ++++ 2 files changed, 196 insertions(+) create mode 100644 backends/cuda/tests/test_sdpa_splitk_replacement.py diff --git a/backends/cuda/tests/test_sdpa_splitk_replacement.py b/backends/cuda/tests/test_sdpa_splitk_replacement.py new file mode 100644 index 00000000000..414a1308777 --- /dev/null +++ b/backends/cuda/tests/test_sdpa_splitk_replacement.py @@ -0,0 +1,163 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Test ReplaceEdgeOpWithTritonOpPass split-K SDPA kernel selection. + +Exports a minimal model containing F.scaled_dot_product_attention through +the CUDA backend and verifies that the pass routes to split-K for decode +(L_q=1, large L_kv) and standard SDPA otherwise. +""" + +import logging +import unittest + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def _require_cuda(tc: unittest.TestCase) -> None: + if not torch.cuda.is_available(): + tc.skipTest("CUDA required") + + +class SDPAModule(nn.Module): + """Single-layer model with SDPA and a static KV cache buffer.""" + + def __init__(self, n_heads, n_kv_heads, head_dim, kv_len): + super().__init__() + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.head_dim = head_dim + hidden = n_heads * head_dim + self.q_proj = nn.Linear(hidden, n_heads * head_dim, bias=False) + self.k_proj = nn.Linear(hidden, n_kv_heads * head_dim, bias=False) + self.v_proj = nn.Linear(hidden, n_kv_heads * head_dim, bias=False) + self.register_buffer( + "k_cache", torch.zeros(1, n_kv_heads, kv_len, head_dim), persistent=False + ) + self.register_buffer( + "v_cache", torch.zeros(1, n_kv_heads, kv_len, head_dim), persistent=False + ) + + def forward(self, x: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor: + B, T, _ = x.shape + q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2) + k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) + v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) + self.k_cache.index_copy_(2, input_pos, k) + self.v_cache.index_copy_(2, input_pos, v) + y = F.scaled_dot_product_attention( + q, + self.k_cache, + self.v_cache, + enable_gqa=True, + ) + return y.transpose(1, 2).contiguous().view(B, T, -1) + + +def _export_through_cuda_backend(model, example_args): + """Export and lower through the CUDA backend (stops before to_executorch).""" + from executorch.backends.cuda.cuda_backend import CudaBackend + from executorch.backends.cuda.cuda_partitioner import CudaPartitioner + from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower + from torch.export import export + + with torch.no_grad(): + ep = export(model, example_args, strict=True) + + return to_edge_transform_and_lower( + {"decode": ep}, + partitioner={ + "decode": [ + CudaPartitioner( + [CudaBackend.generate_method_name_compile_spec("decode")] + ) + ], + }, + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ), + ) + + +def _capture_pass_logs(fn): + """Run fn and return replacement pass log messages.""" + pass_logger = logging.getLogger("executorch.backends.cuda.triton.replacement_pass") + prev_level = pass_logger.level + pass_logger.setLevel(logging.INFO) + messages = [] + handler = logging.Handler() + handler.emit = lambda record: messages.append(record.getMessage()) + pass_logger.addHandler(handler) + try: + return fn(), messages + finally: + pass_logger.removeHandler(handler) + pass_logger.setLevel(prev_level) + + +class TestSplitKReplacement(unittest.TestCase): + + def setUp(self): + _require_cuda(self) + + def test_large_kv_cache_uses_splitk(self): + """L_kv=4096 > threshold → split-K selected for decode.""" + model = SDPAModule(n_heads=4, n_kv_heads=2, head_dim=64, kv_len=4096).to( + torch.bfloat16 + ) + args = ( + torch.zeros(1, 1, 256, dtype=torch.bfloat16), + torch.tensor([0], dtype=torch.long), + ) + + _, msgs = _capture_pass_logs(lambda: _export_through_cuda_backend(model, args)) + + splitk = [m for m in msgs if "split-K" in m] + self.assertEqual(len(splitk), 1, f"Expected 1 split-K selection. Log: {msgs}") + self.assertIn("L_kv=4096", splitk[0]) + + def test_small_kv_cache_uses_standard(self): + """L_kv=512 <= threshold → standard SDPA, no split-K.""" + model = SDPAModule(n_heads=4, n_kv_heads=2, head_dim=64, kv_len=512).to( + torch.bfloat16 + ) + args = ( + torch.zeros(1, 1, 256, dtype=torch.bfloat16), + torch.tensor([0], dtype=torch.long), + ) + + _, msgs = _capture_pass_logs(lambda: _export_through_cuda_backend(model, args)) + + splitk = [m for m in msgs if "split-K" in m] + self.assertEqual(len(splitk), 0, f"Expected no split-K. Got: {splitk}") + + replaced = [m for m in msgs if "Replaced" in m] + self.assertTrue( + any("1 nodes" in m for m in replaced), + f"Expected 1 SDPA replaced with standard kernel. Log: {msgs}", + ) + + def test_non_pow2_head_dim_uses_standard(self): + """Non-power-of-2 head_dim → standard SDPA even with large L_kv.""" + model = SDPAModule(n_heads=4, n_kv_heads=2, head_dim=96, kv_len=8192).to( + torch.bfloat16 + ) + args = ( + torch.zeros(1, 1, 384, dtype=torch.bfloat16), + torch.tensor([0], dtype=torch.long), + ) + + _, msgs = _capture_pass_logs(lambda: _export_through_cuda_backend(model, args)) + + splitk = [m for m in msgs if "split-K" in m] + self.assertEqual(len(splitk), 0, f"Expected no split-K for D=96. Got: {splitk}") + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/cuda/triton/replacement_pass.py b/backends/cuda/triton/replacement_pass.py index b14628d29cf..628222e46f7 100644 --- a/backends/cuda/triton/replacement_pass.py +++ b/backends/cuda/triton/replacement_pass.py @@ -27,6 +27,8 @@ exir_ops.edge.aten.topk.default: triton.topk, } +_SPLITK_LKV_THRESHOLD = 2048 + class ReplaceEdgeOpWithTritonOpPass(PassBase): """ @@ -83,6 +85,34 @@ def call(self, graph_module: GraphModule) -> PassResult: # for rows larger than this threshold. _TOPK_MAX_N = 4096 + @staticmethod + def _pick_sdpa_kernel(node: Node): + """Choose between standard SDPA and split-K flash-decoding. + + Split-K partitions the KV sequence across many CTAs for better GPU + utilization at decode time (L_q=1). It wins when L_kv is large + (full-attention KV caches) but loses to the standard kernel for + small L_kv (sliding-window ring buffers) due to the overhead of + allocating partial buffers and running the reduction kernel. + """ + q_shape = node.args[0].meta["val"].shape + k_shape = node.args[1].meta["val"].shape + L_q, D = q_shape[2], q_shape[3] + L_kv = k_shape[2] + + if ( + isinstance(L_q, int) + and L_q == 1 + and isinstance(L_kv, int) + and L_kv > _SPLITK_LKV_THRESHOLD + and D > 0 + and (D & (D - 1)) == 0 # power of 2 + ): + logger.info(f"Using split-K decode SDPA (L_kv={L_kv}, D={D})") + return triton.sdpa_decode_splitk + + return triton.sdpa + def _should_replace_node(self, node: Node) -> bool: """ Check if a node should be replaced with a Triton kernel. @@ -128,6 +158,9 @@ def _replace_node_with_triton(self, graph_module: GraphModule, node: Node) -> No triton_kernel_fn = EDGE_TO_TRITON_KERNELS[target] + if target == exir_ops.edge.aten.scaled_dot_product_attention.default: + triton_kernel_fn = self._pick_sdpa_kernel(node) + # Create a new node with the Triton kernel with graph_module.graph.inserting_before(node): # The triton_kernel_fn is already registered as a custom op via @triton_op From cd50af263ec47819fab440067a9a6d28649fbcee Mon Sep 17 00:00:00 2001 From: mnachin Date: Fri, 8 May 2026 11:35:26 -0700 Subject: [PATCH 11/14] INT4 plain matmul: dp4a decode kernel + dequant dispatch Adds executorch_cuda::int4_plain_mm custom op that reads Int4Tensor's plain [N, K//2] nibble-packed format directly. C shim (.pte runtime): W4A8 dp4a matvec with dynamic INT8 activation quantization, 16-byte vectorized loads, warp-cooperative quantization. No cuBLAS dependency. Eager dispatch: M<=4 routes through the custom op (dp4a in .pte, dequant + F.linear in eager). M>4 uses inline dequant + F.linear, which AOTI compiles into the .so using inductor's own cuBLAS codegen. --- backends/cuda/CMakeLists.txt | 3 +- backends/cuda/cuda_backend.py | 16 + backends/cuda/int4_dispatch.py | 101 +++++ backends/cuda/runtime/shims/int4_plain_mm.cu | 81 ++++ backends/cuda/runtime/shims/int4_plain_mm.cuh | 266 ++++++++++++ backends/cuda/runtime/shims/int4_plain_mm.h | 53 +++ .../cuda/runtime/shims/tests/CMakeLists.txt | 23 + .../test_aoti_torch_cuda_int4_plain_mm.cpp | 397 ++++++++++++++++++ backends/cuda/tests/test_int4_dispatch.py | 204 +++++++++ examples/models/gemma4_31b/export.py | 57 ++- examples/models/gemma4_31b/inference.py | 2 + examples/models/gemma4_31b/main.cpp | 5 +- examples/models/gemma4_31b/quant/pack_cuda.py | 106 +---- examples/models/gemma4_31b/quant/quantize.py | 7 +- examples/models/gemma4_31b/quant/recipe.py | 2 +- .../gemma4_31b/quant/tests/test_pack_cuda.py | 258 ++++++------ .../models/gemma4_31b/quantize_and_save.py | 4 +- .../gemma4_31b/tests/test_cuda_pipeline.py | 74 +++- 18 files changed, 1402 insertions(+), 257 deletions(-) create mode 100644 backends/cuda/int4_dispatch.py create mode 100644 backends/cuda/runtime/shims/int4_plain_mm.cu create mode 100644 backends/cuda/runtime/shims/int4_plain_mm.cuh create mode 100644 backends/cuda/runtime/shims/int4_plain_mm.h create mode 100644 backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_int4_plain_mm.cpp create mode 100644 backends/cuda/tests/test_int4_dispatch.py diff --git a/backends/cuda/CMakeLists.txt b/backends/cuda/CMakeLists.txt index 157cc05a54f..217c893efe5 100644 --- a/backends/cuda/CMakeLists.txt +++ b/backends/cuda/CMakeLists.txt @@ -110,7 +110,8 @@ set(_aoti_cuda_shim_sources runtime/shims/memory.cpp # Only build CUDA shims when CUDA language/toolchain is available. if(CMAKE_CUDA_COMPILER) list(APPEND _aoti_cuda_shim_sources runtime/shims/int4mm.cu - runtime/shims/sort.cu runtime/shims/rand.cu + runtime/shims/int4_plain_mm.cu runtime/shims/sort.cu + runtime/shims/rand.cu ) endif() diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index a3169680b6d..d732a12a8fe 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -226,6 +226,8 @@ def get_supported_fallback_kernels(cls) -> Dict[str, Any]: "at::_ops::_weight_int4pack_mm::call": None, "at::_ops::sort_stable::call": None, "aoti_torch_cuda_randint_low_out": None, + "executorch_cuda::int4_plain_mm": None, + "aoti_torch_cuda_int4_plain_mm": None, } @classmethod @@ -298,6 +300,20 @@ def get_aoti_compile_options( "aot_inductor.emit_multi_arch_kernel": emit_multi_arch_kernel, } + try: + import torch + + options["aot_inductor.custom_ops_to_c_shims"] = { + torch.ops.executorch_cuda.int4_plain_mm.default: [ + "AOTITorchError aoti_torch_cuda_int4_plain_mm(" + "AtenTensorHandle, AtenTensorHandle, AtenTensorHandle, " + "AtenTensorHandle, int64_t, AtenTensorHandle*)" + ], + } + except AttributeError: + # int4_dispatch.py not imported — op not registered, skip C shim mapping + pass + # Parse compile_specs to check for platform platform = "linux" diff --git a/backends/cuda/int4_dispatch.py b/backends/cuda/int4_dispatch.py new file mode 100644 index 00000000000..b4fc01c9105 --- /dev/null +++ b/backends/cuda/int4_dispatch.py @@ -0,0 +1,101 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Int4Tensor F.linear dispatch for CUDA. + +Decode (M<=4): Custom op ``executorch_cuda::int4_plain_mm`` — in eager this + dequants + calls F.linear; in .pte runtime the C shim runs a + W4A8 dp4a matvec kernel. +Prefill (M>4): Inline dequant + F.linear — AOTI compiles this into the .so + using inductor's own cuBLAS codegen, so no explicit cuBLAS + dependency in our shim library. + +Import this module before using nn.Linear with Int4Tensor weights:: + + import executorch.backends.cuda.int4_dispatch # noqa: F401 +""" + +import torch +import torch.nn.functional as F +from torch.library import impl, Library +from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + +# --------------------------------------------------------------------------- +# Custom op for decode (M=1): dp4a matvec in C shim, dequant+F.linear in eager +# --------------------------------------------------------------------------- + +_lib = Library("executorch_cuda", "DEF") +_lib.define( + "int4_plain_mm(Tensor self, Tensor qdata, Tensor scale, Tensor zero, int group_size) -> Tensor" +) + + +@impl(_lib, "int4_plain_mm", "Meta") +def _meta(self, qdata, scale, zero, group_size): + return torch.empty( + self.shape[0], qdata.shape[0], dtype=self.dtype, device=self.device + ) + + +@impl(_lib, "int4_plain_mm", "CUDA") +def _cuda(self, qdata, scale, zero, group_size): + return _dequant_matmul(self, qdata, scale, zero, group_size) + + +def _dequant_matmul(x, qdata, scale, zero, group_size): + """Dequant INT4 weights to input dtype and call F.linear.""" + N, K_half = qdata.shape + K = K_half * 2 + n_groups = K // group_size + gs_half = group_size // 2 + dtype = x.dtype + + p = qdata.to(torch.uint8).reshape(N, n_groups, gs_half) + low = (p & 0x0F).to(dtype) + high = ((p >> 4) & 0x0F).to(dtype) + data = torch.stack([low, high], dim=-1).reshape(N, n_groups, group_size) + + s = scale.to(dtype).t().unsqueeze(-1) + z = zero.to(dtype).t().unsqueeze(-1) + w_deq = ((data - z) * s).reshape(N, K) + + return F.linear(x, w_deq) + + +# --------------------------------------------------------------------------- +# Int4Tensor F.linear dispatch +# --------------------------------------------------------------------------- + +aten = torch.ops.aten +_implements = Int4Tensor.implements +_implements_torch_function = Int4Tensor.implements_torch_function + + +@_implements([aten.linear.default]) +@_implements_torch_function([F.linear]) +def _(func, types, args, kwargs): + input_tensor = args[0] + weight_tensor = args[1] + bias = args[2] if len(args) > 2 else None + + orig_shape = input_tensor.shape + x_2d = input_tensor.reshape(-1, orig_shape[-1]) + + qdata = weight_tensor.qdata + scale = weight_tensor.scale + zero = weight_tensor.zero_point + gs = weight_tensor.block_size[-1] + + M = x_2d.shape[0] + if M <= 4: + out = torch.ops.executorch_cuda.int4_plain_mm(x_2d, qdata, scale, zero, gs) + else: + out = _dequant_matmul(x_2d, qdata, scale, zero, gs) + + out = out.reshape(*orig_shape[:-1], -1) + if bias is not None: + out = out + bias + return out diff --git a/backends/cuda/runtime/shims/int4_plain_mm.cu b/backends/cuda/runtime/shims/int4_plain_mm.cu new file mode 100644 index 00000000000..fd8fe3b0c3b --- /dev/null +++ b/backends/cuda/runtime/shims/int4_plain_mm.cu @@ -0,0 +1,81 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include +#include +#include +#include + +namespace executorch::backends::cuda { +#ifdef __cplusplus +extern "C" { +#endif + +AOTITorchError aoti_torch_cuda_int4_plain_mm( + Tensor* self, + Tensor* qdata, + Tensor* scale, + Tensor* zero, + int64_t group_size, + Tensor** ret0) { + ET_CHECK_OR_RETURN_ERROR( + self != nullptr, + InvalidArgument, + "aoti_torch_cuda_int4_plain_mm: self is null"); + + ET_CHECK_OR_RETURN_ERROR( + qdata != nullptr, + InvalidArgument, + "aoti_torch_cuda_int4_plain_mm: qdata is null"); + + ET_CHECK_OR_RETURN_ERROR( + scale != nullptr, + InvalidArgument, + "aoti_torch_cuda_int4_plain_mm: scale is null"); + + ET_CHECK_OR_RETURN_ERROR( + zero != nullptr, + InvalidArgument, + "aoti_torch_cuda_int4_plain_mm: zero is null"); + + ET_CHECK_OR_RETURN_ERROR( + ret0 != nullptr, + InvalidArgument, + "aoti_torch_cuda_int4_plain_mm: ret0 is null"); + + int32_t M = self->size(0); + int32_t N = qdata->size(0); + Tensor* C = nullptr; + std::array c_shape = {M, N}; + std::array c_stride = {N, 1}; + aoti_torch_empty_strided( + 2, + c_shape.data(), + c_stride.data(), + static_cast( + executorch::backends::aoti::slim::c10::ScalarType::BFloat16), + static_cast( + executorch::backends::aoti::slim::c10::DeviceType::CUDA), + 0, + &C); + + _int4_plain_mm_cuda(*self, *qdata, *scale, *zero, group_size, C); + ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR(); + + *ret0 = C; + return Error::Ok; +} + +#ifdef __cplusplus +} +#endif +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/int4_plain_mm.cuh b/backends/cuda/runtime/shims/int4_plain_mm.cuh new file mode 100644 index 00000000000..64fccb7c093 --- /dev/null +++ b/backends/cuda/runtime/shims/int4_plain_mm.cuh @@ -0,0 +1,266 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// W4A8 dp4a matvec for INT4 decode (M <= 4). +// +// Reads plain nibble-packed [N, K//2] weights (Int4Tensor format). +// Scale/zero layout: [K//gs, N] (Int4Tensor's native layout). +// +// Dynamically quantizes bf16 activations to INT8 (per-32-element blocks), +// then uses dp4a for fused int4×int8 dot products with 16-byte vectorized +// loads and warp-cooperative quantization. + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace executorch::backends::cuda { + +using executorch::backends::aoti::Tensor; +namespace c10 = executorch::backends::aoti::slim::c10; + +// --------------------------------------------------------------------------- +// Constants +// --------------------------------------------------------------------------- + +constexpr int32_t MV_NWARPS = 8; +constexpr int32_t MV_WARP_SIZE = 32; +constexpr int32_t MV_THREADS = MV_NWARPS * MV_WARP_SIZE; +constexpr int32_t Q8_BLOCK_SIZE = 32; + +__host__ __forceinline__ int32_t log2_pow2(int32_t v) { + int32_t r = 0; + while (v > 1) { + v >>= 1; + r++; + } + return r; +} + +// --------------------------------------------------------------------------- +// Activation quantization: bf16 → int8 (warp-cooperative, per-32-element blocks) +// --------------------------------------------------------------------------- + +struct Q8Block { + int8_t qs_even[Q8_BLOCK_SIZE / 2]; + int8_t qs_odd[Q8_BLOCK_SIZE / 2]; + float d; // scale +}; + +__global__ void quantize_activations_q8_kernel( + const __nv_bfloat16* __restrict__ A, + Q8Block* __restrict__ q8, + int32_t K) { + const int32_t m = blockIdx.y; + const int32_t block_id = blockIdx.x * blockDim.y + threadIdx.y; + const int32_t n_blocks = K / Q8_BLOCK_SIZE; + if (block_id >= n_blocks) + return; + + const int32_t lane = threadIdx.x; + const __nv_bfloat16* src = + A + static_cast(m) * K + block_id * Q8_BLOCK_SIZE; + Q8Block* dst = q8 + static_cast(m) * n_blocks + block_id; + + float val = __bfloat162float(src[lane]); + + float amax = fabsf(val); + for (int offset = 16; offset > 0; offset >>= 1) + amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, offset)); + + float d = amax / 127.0f; + float id = (d > 0.0f) ? 1.0f / d : 0.0f; + int32_t q = __float2int_rn(val * id); + q = max(-128, min(127, q)); + + if (lane % 2 == 0) + dst->qs_even[lane / 2] = static_cast(q); + else + dst->qs_odd[lane / 2] = static_cast(q); + + if (lane == 0) { + dst->d = d; + } +} + +// --------------------------------------------------------------------------- +// W4A8 dp4a matvec kernel +// --------------------------------------------------------------------------- + +__global__ void __launch_bounds__(MV_THREADS) + int4_w4a8_matvec_kernel( + const uint8_t* __restrict__ qdata, + const __nv_bfloat16* __restrict__ w_scale, + const __nv_bfloat16* __restrict__ w_zero, + const Q8Block* __restrict__ q8, + __nv_bfloat16* __restrict__ out, + int32_t N, + int32_t K, + int32_t gs_shift) { + const int32_t n = blockIdx.x * MV_NWARPS + threadIdx.y; + const int32_t m = blockIdx.y; + if (n >= N) + return; + + const int32_t K_half = K / 2; + const int32_t lane_id = threadIdx.x; + const int32_t n_q8_blocks = K / Q8_BLOCK_SIZE; + + const uint8_t* qrow = qdata + static_cast(n) * K_half; + const __nv_bfloat16* scale_base = w_scale + n; + const __nv_bfloat16* zero_base = w_zero + n; + const int32_t scale_stride = N; + const Q8Block* q8_row = q8 + static_cast(m) * n_q8_blocks; + + const uint4* qrow16 = reinterpret_cast(qrow); + const int32_t K_half_16 = K_half / 16; + + float sum = 0.0f; + + for (int32_t i = lane_id; i < K_half_16; i += MV_WARP_SIZE) { + uint4 packed16 = __ldg(&qrow16[i]); + int32_t k_base = i * 32; + uint32_t words[4] = {packed16.x, packed16.y, packed16.z, packed16.w}; + +#pragma unroll + for (int32_t w = 0; w < 4; w++) { + uint32_t packed = words[w]; + int32_t k_word = k_base + w * 8; + int32_t g = k_word >> gs_shift; + + int32_t vi_lo = packed & 0x0F0F0F0F; + int32_t vi_hi = (packed >> 4) & 0x0F0F0F0F; + + int32_t q8_block_idx = k_word / Q8_BLOCK_SIZE; + int32_t q8_half_offset = (k_word % Q8_BLOCK_SIZE) / 2; + const Q8Block* qb = &q8_row[q8_block_idx]; + + int32_t a_even = *reinterpret_cast( + qb->qs_even + q8_half_offset); + int32_t a_odd = *reinterpret_cast( + qb->qs_odd + q8_half_offset); + + int32_t dp = __dp4a(vi_lo, a_even, 0); + dp = __dp4a(vi_hi, a_odd, dp); + + float ws = __bfloat162float(__ldg(&scale_base[g * scale_stride])); + float wz = __bfloat162float(__ldg(&zero_base[g * scale_stride])); + float a_scale = qb->d; + + int32_t a_sum8 = __dp4a(0x01010101, a_even, 0); + a_sum8 = __dp4a(0x01010101, a_odd, a_sum8); + + sum += ws * a_scale * + (static_cast(dp) - wz * static_cast(a_sum8)); + } + } + + for (int offset = MV_WARP_SIZE / 2; offset > 0; offset >>= 1) + sum += __shfl_xor_sync(0xffffffff, sum, offset); + + if (lane_id == 0) + out[static_cast(m) * N + n] = __float2bfloat16(sum); +} + +// --------------------------------------------------------------------------- +// Persistent Q8 buffer (lazy init, not thread-safe — single-stream only) +// --------------------------------------------------------------------------- + +static Q8Block* g_q8_buf = nullptr; +static size_t g_q8_buf_size = 0; + +static Q8Block* get_q8_buffer(size_t needed) { + if (g_q8_buf_size < needed) { + if (g_q8_buf) + cudaFree(g_q8_buf); + cudaError_t err = cudaMalloc(&g_q8_buf, needed); + ET_CHECK_MSG( + err == cudaSuccess, + "cudaMalloc failed for Q8 buffer: %s", + cudaGetErrorString(err)); + g_q8_buf_size = needed; + } + return g_q8_buf; +} + +// --------------------------------------------------------------------------- +// Main entry point +// --------------------------------------------------------------------------- + +void _int4_plain_mm_cuda( + const Tensor& A, // [M, K] bf16 + const Tensor& qdata, // [N, K//2] uint8 + const Tensor& scale, // [K//gs, N] bf16 + const Tensor& zero, // [K//gs, N] bf16 + int64_t group_size, + Tensor* output) { // [M, N] bf16, pre-allocated + int32_t M = A.size(0); + int32_t K = A.size(1); + int32_t N = qdata.size(0); + + ET_CHECK(A.dtype() == c10::ScalarType::BFloat16); + ET_CHECK(A.dim() == 2); + ET_CHECK(qdata.dim() == 2); + ET_CHECK(qdata.size(1) == K / 2); + ET_CHECK(scale.dim() == 2); + ET_CHECK(scale.size(1) == N); + ET_CHECK(zero.dim() == 2); + ET_CHECK(zero.size(1) == N); + + int32_t gs = static_cast(group_size); + ET_CHECK_MSG( + gs > 0 && (gs & (gs - 1)) == 0, + "group_size=%d must be a power of 2", + gs); + ET_CHECK_MSG( + K >= Q8_BLOCK_SIZE && K % Q8_BLOCK_SIZE == 0, + "K=%d must be a positive multiple of %d for dp4a kernel", + K, + Q8_BLOCK_SIZE); + + auto stream_result = getCurrentCUDAStream(0); + ET_CHECK_MSG(stream_result.ok(), "Failed to get CUDA stream"); + cudaStream_t stream = stream_result.get(); + + int32_t gs_shift = log2_pow2(gs); + + // Quantize activations to INT8 + int32_t n_q8_blocks = K / Q8_BLOCK_SIZE; + size_t q8_bytes = static_cast(M) * n_q8_blocks * sizeof(Q8Block); + Q8Block* q8_buf = get_q8_buffer(q8_bytes); + + constexpr int32_t Q8_WARPS = 8; + int32_t blocks_per_m = (n_q8_blocks + Q8_WARPS - 1) / Q8_WARPS; + dim3 q8_grid(blocks_per_m, M); + dim3 q8_block(MV_WARP_SIZE, Q8_WARPS); + quantize_activations_q8_kernel<<>>( + reinterpret_cast(A.data_ptr()), + q8_buf, + K); + + // dp4a matvec + dim3 grid((N + MV_NWARPS - 1) / MV_NWARPS, M); + dim3 block(MV_WARP_SIZE, MV_NWARPS); + int4_w4a8_matvec_kernel<<>>( + reinterpret_cast(qdata.data_ptr()), + reinterpret_cast(scale.data_ptr()), + reinterpret_cast(zero.data_ptr()), + q8_buf, + reinterpret_cast<__nv_bfloat16*>(output->data_ptr()), + N, K, gs_shift); +} + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/int4_plain_mm.h b/backends/cuda/runtime/shims/int4_plain_mm.h new file mode 100644 index 00000000000..0935937cd7a --- /dev/null +++ b/backends/cuda/runtime/shims/int4_plain_mm.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +namespace executorch::backends::cuda { + +using executorch::backends::aoti::AOTITorchError; +using executorch::backends::aoti::Tensor; + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * INT4 matrix multiplication reading plain nibble-packed weights. + * + * Weight format: [N, K//2] uint8, two INT4 values per byte + * (low nibble = even k, high nibble = odd k). + * Scale: [K//group_size, N] bf16 per-group scales (Int4Tensor layout). + * Zero: [K//group_size, N] bf16 per-group zero points. + * W4A8 dp4a matvec: dynamically quantizes activations to INT8, + * then uses dp4a for fused int4×int8 dot products. + * + * @param self Input activation [M, K] bf16 + * @param qdata Packed weights [N, K//2] uint8 + * @param scale Per-group scales [K//group_size, N] bf16 + * @param zero Per-group zero points [K//group_size, N] bf16 + * @param group_size Quantization group size (32, 64, 128) + * @param ret0 Output [M, N] bf16 + */ +AOTI_SHIM_EXPORT AOTITorchError aoti_torch_cuda_int4_plain_mm( + Tensor* self, + Tensor* qdata, + Tensor* scale, + Tensor* zero, + int64_t group_size, + Tensor** ret0); + +#ifdef __cplusplus +} +#endif + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/tests/CMakeLists.txt b/backends/cuda/runtime/shims/tests/CMakeLists.txt index aec5219d680..62e9180d603 100644 --- a/backends/cuda/runtime/shims/tests/CMakeLists.txt +++ b/backends/cuda/runtime/shims/tests/CMakeLists.txt @@ -48,6 +48,11 @@ set(CUDA_SHIM_TESTS test_aoti_torch_assign_tensors_out ) +# CUDA-specific tests requiring GPU kernels +set(CUDA_KERNEL_TESTS test_aoti_torch_cuda__weight_int4pack_mm + test_aoti_torch_cuda_int4_plain_mm +) + enable_testing() foreach(test_name ${CUDA_SHIM_TESTS}) @@ -67,3 +72,21 @@ foreach(test_name ${CUDA_SHIM_TESTS}) add_test(NAME ${test_name} COMMAND ${test_name}) endforeach() + +foreach(test_name ${CUDA_KERNEL_TESTS}) + add_executable(${test_name} ${test_name}.cpp) + + target_include_directories( + ${test_name} PRIVATE ${EXECUTORCH_ROOT}/.. ${EXECUTORCH_ROOT} + ${CUDAToolkit_INCLUDE_DIRS} + ) + + target_compile_definitions(${test_name} PRIVATE CUDA_AVAILABLE=1) + + target_link_libraries( + ${test_name} PRIVATE GTest::gtest GTest::gtest_main aoti_cuda_shims + executorch_core CUDA::cudart + ) + + add_test(NAME ${test_name} COMMAND ${test_name}) +endforeach() diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_int4_plain_mm.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_int4_plain_mm.cpp new file mode 100644 index 00000000000..ab18e33c713 --- /dev/null +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_int4_plain_mm.cpp @@ -0,0 +1,397 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using executorch::backends::cuda::aoti_torch_cuda_int4_plain_mm; +using executorch::backends::cuda::aoti_torch_empty_strided; +using executorch::backends::cuda::AOTITorchError; +using executorch::runtime::Error; +namespace slim_c10 = executorch::backends::aoti::slim::c10; + +using Tensor = executorch::backends::aoti::slim::SlimTensor; + +class AOTITorchInt4PlainMMTest : public ::testing::Test { + protected: + void SetUp() override { + et_pal_init(); + + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available"; + } + } + + Tensor* create_tensor( + const std::vector& sizes, + slim_c10::ScalarType dtype) { + Tensor* tensor; + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + nullptr, + static_cast(dtype), + static_cast(slim_c10::DeviceType::CUDA), + 0, + &tensor); + return (error == Error::Ok) ? tensor : nullptr; + } + + Tensor* create_bf16(const std::vector& sizes) { + return create_tensor(sizes, slim_c10::ScalarType::BFloat16); + } + + Tensor* create_uint8(const std::vector& sizes) { + return create_tensor(sizes, slim_c10::ScalarType::Byte); + } + + // Upload raw bytes to a CUDA tensor. + void upload(Tensor* t, const void* host_data, size_t bytes) { + cudaMemcpy(t->data_ptr(), host_data, bytes, cudaMemcpyHostToDevice); + } + + // Download CUDA tensor to host buffer. + void download(const Tensor* t, void* host_data, size_t bytes) { + cudaMemcpy(host_data, t->data_ptr(), bytes, cudaMemcpyDeviceToHost); + } + + // Run the shim and return the output tensor (asserts success). + Tensor* run( + Tensor* A, + Tensor* qdata, + Tensor* scale, + Tensor* zero, + int64_t group_size) { + Tensor* output = nullptr; + AOTITorchError error = aoti_torch_cuda_int4_plain_mm( + A, qdata, scale, zero, group_size, &output); + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(output, nullptr); + return output; + } + + // Check output bf16 values against expected, with absolute tolerance. + void check_bf16_output( + Tensor* output, + const uint16_t* expected_data, + int64_t count, + float atol = 0.1f) { + std::vector actual(count); + download(output, actual.data(), count * sizeof(uint16_t)); + cudaDeviceSynchronize(); + + for (int64_t i = 0; i < count; i++) { + // Convert bf16 raw bits to float: bf16 is the upper 16 bits of float32. + uint32_t actual_bits = static_cast(actual[i]) << 16; + uint32_t expected_bits = static_cast(expected_data[i]) << 16; + float actual_f, expected_f; + memcpy(&actual_f, &actual_bits, sizeof(float)); + memcpy(&expected_f, &expected_bits, sizeof(float)); + + EXPECT_NEAR(actual_f, expected_f, atol) + << "Mismatch at index " << i << ": actual=" << actual_f + << " expected=" << expected_f; + } + } +}; + +// MultiGroupRandom: M=1, N=4, K=32, gs=16 +// scale/zero layout: [K//gs=2, N=4] +TEST_F(AOTITorchInt4PlainMMTest, MultiGroupRandom) { + int64_t M = 1, K = 32, N = 4, gs = 16; + + // clang-format off + uint8_t qdata_host[] = { + 0x36, 0xEC, 0x7A, 0x4C, 0x96, 0x62, 0xAA, 0x47, + 0x73, 0x27, 0x45, 0x71, 0xDB, 0x15, 0xBF, 0x04, + 0x9B, 0xC5, 0x8B, 0xA0, 0xEA, 0xF9, 0xBB, 0xEF, + 0xDD, 0xDE, 0xB2, 0x36, 0x8F, 0x42, 0x62, 0x84, + 0x16, 0x83, 0xDB, 0x91, 0x98, 0x14, 0xB3, 0xBE, + 0xB6, 0x7C, 0x2E, 0x0D, 0x13, 0x37, 0xD1, 0x55, + 0x39, 0xC5, 0x1E, 0xB9, 0x91, 0x3D, 0xED, 0xEF, + 0xD7, 0xB6, 0xD8, 0x47, 0xCF, 0xE1, 0x74, 0x89 + }; + uint16_t scale_host[] = {0xBD86, 0x3DB0, 0xBE26, 0xBE0F, 0xBCC3, 0xBD4E, 0xBE7D, 0xBDBE}; + uint16_t zero_host[] = {0x4100, 0x4100, 0x4100, 0x4100, 0x4100, 0x4100, 0x4100, 0x4100}; + uint16_t A_host[] = {0x3FAD, 0xBF3D, 0x3F9E, 0x4002, 0x3F34, 0x3F9B, 0x3F49, 0x3F8F, 0x3FAD, 0x3DD8, 0x3DFA, 0xBFA5, 0xBF02, 0xBE45, 0x3F97, 0x3F5F, 0xBF85, 0x3DFD, 0x3EDE, 0x3E42, 0xBF86, 0xBE84, 0xBF06, 0x3F9E, 0xBF22, 0x3FDE, 0xBF2E, 0x3E6B, 0x3F72, 0xBEE5, 0x3EBB, 0xC00F}; + uint16_t expected[] = {0xBFCC, 0x3FB5, 0x4046, 0xC01E}; + // clang-format on + + Tensor* A = create_bf16({M, K}); + Tensor* qdata = create_uint8({N, K / 2}); + Tensor* scale = create_bf16({K / gs, N}); + Tensor* zero = create_bf16({K / gs, N}); + upload(A, A_host, sizeof(A_host)); + upload(qdata, qdata_host, sizeof(qdata_host)); + upload(scale, scale_host, sizeof(scale_host)); + upload(zero, zero_host, sizeof(zero_host)); + + Tensor* output = run(A, qdata, scale, zero, gs); + ASSERT_NE(output, nullptr); + EXPECT_EQ(output->size(0), M); + EXPECT_EQ(output->size(1), N); + check_bf16_output(output, expected, M * N, 0.5f); +} + +// SingleGroup: M=1, N=8, K=32, gs=32 +// scale/zero layout: [K//gs=1, N=8] +TEST_F(AOTITorchInt4PlainMMTest, SingleGroup) { + int64_t M = 1, K = 32, N = 8, gs = 32; + + // clang-format off + uint8_t qdata_host[] = { + 0x31, 0x89, 0x89, 0x42, 0x45, 0x71, 0x3E, 0x17, + 0x01, 0xBD, 0xB6, 0x74, 0x02, 0x8C, 0x48, 0xB9, + 0xF7, 0xFA, 0xEB, 0xE5, 0xC4, 0xE9, 0x91, 0x50, + 0x9F, 0x33, 0xA6, 0xB2, 0xC5, 0xC0, 0xB5, 0xC1, + 0x2B, 0xDB, 0x1F, 0xB9, 0xC1, 0xCD, 0x83, 0x98, + 0x92, 0xB8, 0x70, 0xBD, 0x23, 0x60, 0x0D, 0xB2, + 0x3A, 0xC2, 0xB8, 0x3A, 0x5D, 0x5D, 0xC9, 0x14, + 0xDD, 0xEF, 0xBF, 0xBE, 0x4C, 0x79, 0xE6, 0xBB, + 0x75, 0xBA, 0x05, 0x73, 0xDC, 0x9B, 0xD5, 0x77, + 0x88, 0xE0, 0x32, 0x04, 0xB8, 0xE0, 0xA9, 0x80, + 0xB4, 0xD1, 0x70, 0x29, 0xFA, 0x7A, 0xA6, 0x1C, + 0x24, 0x86, 0xD2, 0xDB, 0x2E, 0x27, 0xF3, 0xEF, + 0xAD, 0xA2, 0x16, 0xEB, 0x6E, 0xFF, 0x3F, 0xAB, + 0x6C, 0x47, 0x94, 0x29, 0xB7, 0x59, 0xE5, 0x51, + 0x20, 0xC7, 0x60, 0x27, 0x68, 0x4B, 0x52, 0xFD, + 0x10, 0x07, 0xB5, 0x53, 0x89, 0x3E, 0xA1, 0xDE + }; + uint16_t scale_host[] = {0xBD36, 0x3B22, 0xBDD0, 0x3C6E, 0x3D9A, 0xBE63, 0xBE50, 0x3D28}; + uint16_t zero_host[] = {0x4100, 0x4100, 0x4100, 0x4100, 0x4100, 0x4100, 0x4100, 0x4100}; + uint16_t A_host[] = {0xBFDD, 0x3F43, 0x3EBF, 0xBF3D, 0xBF8E, 0xBE61, 0xBFB3, 0xBF32, 0xBF06, 0xBF8E, 0xBD93, 0x3E29, 0x3F96, 0x3E1D, 0xBFEC, 0xBEA5, 0xBF44, 0xC01C, 0xBF14, 0x3E92, 0xBF08, 0x3EA5, 0xBF08, 0x3E05, 0xBDC4, 0xBD97, 0xBFA1, 0xBE62, 0xBEDF, 0xBFFC, 0xBD87, 0xBFA5}; + uint16_t expected[] = {0xC031, 0x3BF8, 0x3E81, 0xBF19, 0x3FCB, 0xBF56, 0x4076, 0x3F20}; + // clang-format on + + Tensor* A = create_bf16({M, K}); + Tensor* qdata = create_uint8({N, K / 2}); + Tensor* scale = create_bf16({K / gs, N}); + Tensor* zero = create_bf16({K / gs, N}); + upload(A, A_host, sizeof(A_host)); + upload(qdata, qdata_host, sizeof(qdata_host)); + upload(scale, scale_host, sizeof(scale_host)); + upload(zero, zero_host, sizeof(zero_host)); + + Tensor* output = run(A, qdata, scale, zero, gs); + ASSERT_NE(output, nullptr); + EXPECT_EQ(output->size(0), M); + EXPECT_EQ(output->size(1), N); + check_bf16_output(output, expected, M * N, 0.5f); +} + +// PrefillBatch: M=8, N=4, K=64, gs=32 +// scale/zero layout: [K//gs=2, N=4] +TEST_F(AOTITorchInt4PlainMMTest, PrefillBatch) { + int64_t M = 8, K = 64, N = 4, gs = 32; + + // clang-format off + uint8_t qdata_host[] = { + 0xAD, 0x87, 0xDD, 0x5D, 0x57, 0x1B, 0x06, 0xE3, + 0xDE, 0xED, 0x8C, 0x7F, 0x1F, 0x75, 0x38, 0xDA, + 0xD7, 0x0B, 0xE7, 0xDB, 0x2B, 0x81, 0xE0, 0xA8, + 0xBC, 0xCB, 0xC9, 0x48, 0xCD, 0xD5, 0x4E, 0xA9, + 0x1D, 0x8D, 0x02, 0x7D, 0xEB, 0xE2, 0xD8, 0x0A, + 0x5D, 0xAC, 0x36, 0xA8, 0x27, 0x31, 0xCD, 0xE5, + 0xA3, 0x29, 0x08, 0x3D, 0x2B, 0x1F, 0x2A, 0xB0, + 0x45, 0x73, 0xD4, 0x02, 0x38, 0xEA, 0x0D, 0xA0, + 0xFA, 0x9A, 0xA4, 0x6E, 0x69, 0x35, 0x15, 0x7D, + 0xB5, 0x39, 0x26, 0x62, 0x0D, 0x8D, 0x1E, 0x27, + 0x9E, 0x01, 0x19, 0xAB, 0x17, 0xD2, 0xB3, 0x24, + 0x87, 0x34, 0x2E, 0xDD, 0x4E, 0x64, 0x6B, 0x20, + 0xA3, 0xAA, 0xED, 0x24, 0x80, 0xD0, 0x47, 0x90, + 0x6A, 0x45, 0x1E, 0x1C, 0xBD, 0x7D, 0xA4, 0x04, + 0x48, 0x1A, 0xD9, 0xCF, 0x29, 0xBC, 0x01, 0x07, + 0xB9, 0x00, 0x39, 0xB6, 0xC8, 0x2A, 0xE8, 0x17 + }; + uint16_t scale_host[] = {0xBE06, 0x3E0B, 0xBD82, 0x3DFA, 0x3D5E, 0xBE25, 0x3CBE, 0xBDD2}; + uint16_t zero_host[] = {0x4100, 0x4100, 0x4100, 0x4100, 0x4100, 0x4100, 0x4100, 0x4100}; + uint16_t A_host[] = {0xBF37, 0x3FB7, 0xBF20, 0xBCC0, 0x3F88, 0xBD3F, 0xC02C, 0x3F73, 0xBF9F, 0x3FCA, 0x3E04, 0xBE88, 0x3F5F, 0x4002, 0xBF52, 0x3F1A, 0x3F2B, 0x3F35, 0xBF20, 0xBFF0, 0xBEB0, 0x3F90, 0x3F67, 0xBF85, 0x3F8F, 0x3FEB, 0x3F3A, 0xBEF4, 0xBF31, 0x3CE2, 0xBF74, 0x3EBF, 0xBF4D, 0x400C, 0xBF9E, 0xBD45, 0x3E8E, 0x3FE6, 0x3F7C, 0xBEEB, 0x4027, 0xBF0F, 0x3F5E, 0x3E15, 0x3E69, 0x3F82, 0x3FB3, 0x3E10, 0xBF17, 0x3F88, 0xBFDB, 0x3FA5, 0x3F1B, 0xBE50, 0x3E64, 0xBF5A, 0x3E78, 0x3F1A, 0x3F06, 0xBF51, 0x0000, 0xBF25, 0x3F80, 0x3E34, 0x3EA8, 0xBE9F, 0xBF67, 0x3DF1, 0xBF5C, 0xC020, 0xBEA6, 0x3E7D, 0xBF51, 0x3F70, 0x3F2C, 0xBE25, 0xBEB4, 0xBDEB, 0x3EE4, 0x3E29, 0xBFE5, 0x3E1F, 0x3F03, 0xBF6C, 0xBE8D, 0xBEB9, 0x3FB0, 0xBD7F, 0xBFBB, 0xBF18, 0x3F28, 0xBF0F, 0xBEF5, 0x3F97, 0x3FA8, 0x3FAC, 0x3E51, 0x3F84, 0xBF81, 0xBF5B, 0xBF2E, 0x3FBF, 0xBFE6, 0xBFC7, 0x3F53, 0x3F30, 0xC00C, 0x3F24, 0xBE79, 0xBFB3, 0x3F73, 0x3F62, 0xBF41, 0x3D93, 0xBF8C, 0x3FF3, 0x3F17, 0x3F10, 0x3F1C, 0x3F1E, 0xBF88, 0x3F33, 0xBEAA, 0xBFE3, 0x3EB4, 0xBEAA, 0x3E3E, 0x3F37, 0xC013, 0x3F27, 0xBEF8, 0xBDAD, 0xBF02, 0x3F3E, 0x3EA5, 0xBE6C, 0xBF3D, 0xBF3C, 0x3F82, 0xBFC1, 0x3FC4, 0xBF32, 0xBFD2, 0xBE9B, 0x3EAD, 0x3FA5, 0x3F67, 0x3F10, 0x3F2C, 0xBFCD, 0x3BED, 0xBF91, 0xBF92, 0x3F25, 0x3EDB, 0x3EAB, 0x3F14, 0x3FB9, 0xBF92, 0xBE6E, 0x3F9E, 0x3EC5, 0xC01F, 0x3F90, 0x400E, 0xBFF4, 0xBEC4, 0x3D2E, 0x0000, 0xBF07, 0xBF0D, 0x3FD8, 0x3EC5, 0x3F78, 0xBF45, 0xBED8, 0xBE3D, 0xBF84, 0x3F44, 0xBF70, 0x3E40, 0x3F34, 0x3FCA, 0xBF7C, 0x3E8F, 0x3E87, 0x3F7B, 0x3FBC, 0xBF92, 0xBF77, 0x3F80, 0xBFCB, 0xC006, 0xBF23, 0x3FA6, 0x3F5A, 0x3E86, 0x3F65, 0xBF7E, 0x3D96, 0xBFCE, 0xBF2C, 0xBF44, 0x3DD7, 0x3F96, 0x3F08, 0xBEEC, 0x3EA8, 0x3F4C, 0xBF5F, 0x3EFA, 0xBF97, 0x3E89, 0x3FFE, 0x3FA8, 0xBF89, 0xBEC0, 0xBE90, 0x3EEF, 0x3F88, 0x3F60, 0x3F52, 0xBFD8, 0x3F1B, 0xBF44, 0x3F13, 0xBF09, 0x3FAE, 0xBF38, 0xBEBF, 0x3EE0, 0xBEF9, 0xBE7D, 0xBFDE, 0x3F11, 0xBFFE, 0x3E49, 0xBF78, 0x3F08, 0x3F30, 0x3D99, 0xBF8B, 0xBFB9, 0xBEE6, 0x3E43, 0x3E46, 0x4003, 0x3FBF, 0xBF3E, 0xBEDA, 0xBE98, 0x3F8C, 0xBE0D, 0xBD4B, 0xBF3C, 0x3E98, 0xBF34, 0xBFFC, 0xBF1F, 0xBF54, 0x3BC2, 0xBF90, 0xBE9F, 0xBE83, 0x3F88, 0xBF00, 0xBFBD, 0x3F88, 0x3E00, 0x3DDC, 0x3F1F, 0xBEAD, 0x3FB8, 0x3E57, 0x3F7C, 0xBE8D, 0x3F03, 0xC002, 0xBF1F, 0xBFE1, 0xBFAC, 0xBF6A, 0xBFE4, 0xBF28, 0x3E58, 0xBF73, 0xBFAD, 0xBFDE, 0xBFE1, 0xBEC3, 0xBEB9, 0xBF40, 0x3E80, 0x3F7B, 0x3E99, 0xBF49, 0x3F12, 0x3DC7, 0xBFFE, 0x3DC4, 0xBD03, 0xBE00, 0xBFE9, 0xBEFC, 0x3F2F, 0xBE76, 0x3F9C, 0x3F0C, 0x3F3E, 0x3FAE, 0xBF91, 0x3EC5, 0x3EE9, 0x3F49, 0x3F39, 0xBF35, 0x3F66, 0xBF31, 0x3F83, 0x3F6F, 0xBEDC, 0x3F24, 0x3F82, 0x3F09, 0xBEF2, 0xBFB6, 0xBF00, 0xBED8, 0xBFAE, 0x3F76, 0xBFCC, 0xBE58, 0x3CB9, 0x3E38, 0x3FD2, 0x3FDC, 0xBFA8, 0xBE3E, 0xBFB0, 0xBD7D, 0x3F2A, 0xBFD0, 0xBF30, 0xBFE0, 0xBFA7, 0xBF82, 0xBF9A, 0xBED2, 0xBF2A, 0x3FBC, 0xBF3F, 0xBF48, 0xBEB6, 0xBF0D, 0xBDE5, 0xBF18, 0xBF57, 0x3F18, 0x3F54, 0x3F1A, 0x3FA3, 0xBF9A, 0xBF1D, 0xBF64, 0x3EB1, 0xBF89, 0x3F54, 0xBFC0, 0x3F56, 0x3F09, 0x3FE2, 0xBD9D, 0x3F17, 0x3FAD, 0xBF0B, 0xBF43, 0xBE24, 0xBF1A, 0xBF32, 0x3FD4, 0x3E8F, 0x3F1A, 0xBF80, 0x3E08, 0xBF88, 0xBF1B, 0xBE8B, 0x3F43, 0xBFBD, 0x3F9D, 0xBEB3, 0xBFA5, 0xBDCB, 0x3FB0, 0xBE72, 0x3F9A, 0x3F40, 0x3EAD, 0x3F27, 0x3F3B, 0xBFD0, 0xBF56, 0x3E7F, 0x3E99, 0x4004, 0x3F4B, 0xBEFF, 0xBE7F, 0x3FE3, 0x3E8B, 0x3F41, 0xBFC3, 0x3EEC, 0x3ECC, 0xBFB6, 0x3FA2, 0xBF9F, 0xBF8A, 0xBF97, 0x3FE9, 0xBFF9, 0xBFCC, 0x3CFA, 0x3EFD, 0xBEC1, 0x3E1F, 0xBFF2, 0xBF07, 0xBEAC, 0xBED0, 0xBEE2, 0x3EDA, 0x3FBA, 0xBF2B, 0xBE80, 0x3CB6, 0x3E99, 0x3F32, 0xBDEC, 0x3E82, 0xBD46, 0xBF47, 0x3F82, 0xBEA4, 0x3F9A, 0xBFA1, 0x3FD2, 0x3FA4, 0x3F95, 0x3FB5, 0xBF18, 0x3EDC, 0xBFC6, 0x3FD5, 0x3F83, 0xBF75, 0x3F80, 0xBDBC, 0xBD63, 0xBD61, 0xBE5F, 0xBE88, 0x3EAC, 0x3E96, 0xBF20, 0xBEA5, 0x3FC3, 0x3F2B, 0x3E58, 0xBF10, 0x3E82, 0x3F3B, 0x3E95, 0x39A4, 0xBEAF, 0x3EE3, 0x3E29, 0xBEC9, 0x3EFF, 0x3D13, 0x3F60, 0xBF34, 0xBF9C, 0x3E4E, 0x3F33, 0x3F29, 0x3EE2, 0xBF95, 0xBF87, 0x3F2E, 0xBE9E, 0xBF9D, 0xBDB7, 0x3F85, 0xBF07, 0x3F5E, 0x3F15, 0x3EA3, 0xBF94, 0x3F7C, 0xBF99, 0xBE3F, 0x3F7C, 0xBF89, 0x3F00, 0xBDD4, 0x3F5D, 0x3F07, 0x3F1C, 0xBFC8, 0xBFF6, 0xBF3E}; + uint16_t expected[] = {0x40BD, 0xC0E3, 0x4037, 0x40A9, 0x406F, 0x4116, 0x3F8D, 0xC01F, 0xC039, 0xC043, 0x3F86, 0x410A, 0x3F07, 0xC100, 0x4019, 0x40D7, 0x40A9, 0x40F1, 0xBF89, 0x406F, 0x40FE, 0xBFB8, 0xBF88, 0x406A, 0x4004, 0x3EDE, 0x3E17, 0x4102, 0xC081, 0xC0BA, 0xBFFB, 0x3F25}; + // clang-format on + + Tensor* A = create_bf16({M, K}); + Tensor* qdata = create_uint8({N, K / 2}); + Tensor* scale = create_bf16({K / gs, N}); + Tensor* zero = create_bf16({K / gs, N}); + upload(A, A_host, sizeof(A_host)); + upload(qdata, qdata_host, sizeof(qdata_host)); + upload(scale, scale_host, sizeof(scale_host)); + upload(zero, zero_host, sizeof(zero_host)); + + Tensor* output = run(A, qdata, scale, zero, gs); + ASSERT_NE(output, nullptr); + EXPECT_EQ(output->size(0), M); + EXPECT_EQ(output->size(1), N); + check_bf16_output(output, expected, M * N, 0.5f); +} + +// GroupSize128: M=1, N=2, K=256, gs=128 +// scale/zero layout: [K//gs=2, N=2] +TEST_F(AOTITorchInt4PlainMMTest, GroupSize128) { + int64_t M = 1, K = 256, N = 2, gs = 128; + + // clang-format off + uint8_t qdata_host[] = { + 0xDE, 0x2E, 0x2C, 0x16, 0xA3, 0x9B, 0x16, 0x10, + 0xFE, 0x09, 0x0E, 0x9F, 0xE3, 0x4D, 0x00, 0x14, + 0x37, 0x42, 0x27, 0xF4, 0xD8, 0x70, 0x39, 0xCC, + 0x64, 0x51, 0xE6, 0x2B, 0xC1, 0x38, 0x5A, 0xB0, + 0xA2, 0x6A, 0x2D, 0xF4, 0xBB, 0xCD, 0x4E, 0xD6, + 0xA3, 0x60, 0xFE, 0x74, 0x6B, 0x17, 0xAB, 0x75, + 0x29, 0x84, 0xC1, 0x12, 0x31, 0x5C, 0x09, 0xB8, + 0x61, 0x33, 0x5B, 0x79, 0x29, 0xB3, 0x33, 0xE8, + 0x96, 0xE7, 0x36, 0x69, 0x6C, 0x6B, 0xD1, 0xAE, + 0x43, 0x13, 0xDF, 0x50, 0xD8, 0xE6, 0xBF, 0x98, + 0x1D, 0x30, 0x1D, 0x43, 0x7E, 0x6D, 0x1C, 0xE4, + 0x3C, 0x3C, 0x67, 0x68, 0xCD, 0xFC, 0x44, 0x07, + 0x90, 0x88, 0xA4, 0xAF, 0xDD, 0xE8, 0x16, 0x6E, + 0x78, 0xCA, 0x9C, 0xBA, 0x71, 0xCD, 0x1B, 0x97, + 0x8E, 0xF7, 0x31, 0x81, 0x7E, 0x15, 0x52, 0x22, + 0xDE, 0x39, 0xB2, 0x6E, 0x97, 0xE1, 0xB3, 0xCA, + 0xB8, 0x3A, 0xAD, 0xBA, 0x97, 0x9B, 0xBE, 0x33, + 0x5E, 0x6B, 0x80, 0x77, 0x44, 0x05, 0xC8, 0x29, + 0x15, 0xC5, 0xF9, 0xCB, 0xA2, 0x34, 0x30, 0xB7, + 0x27, 0x15, 0x57, 0x19, 0x2A, 0xAD, 0x58, 0x90, + 0x33, 0x13, 0x67, 0x13, 0x27, 0x6C, 0x95, 0x98, + 0xA4, 0x87, 0x95, 0x42, 0xCC, 0x33, 0x71, 0xCF, + 0x8D, 0x75, 0xE7, 0x7E, 0xCE, 0x05, 0xE0, 0xE8, + 0x1F, 0xF0, 0xEE, 0xB4, 0xAF, 0x45, 0x05, 0x17, + 0xA2, 0x72, 0x7A, 0xA3, 0x16, 0x48, 0xD1, 0xE6, + 0x95, 0xFA, 0x30, 0x31, 0x7E, 0x77, 0x35, 0xE6, + 0x3D, 0x15, 0x95, 0x31, 0x9D, 0x51, 0x6D, 0xDA, + 0x51, 0xE0, 0x07, 0xCE, 0x3A, 0xC0, 0x26, 0xA7, + 0xE5, 0x01, 0x20, 0x56, 0xEF, 0xED, 0xCD, 0x19, + 0xE5, 0xA3, 0x46, 0x7A, 0x1D, 0x6E, 0x30, 0x31, + 0x80, 0xEE, 0xED, 0x15, 0x34, 0x22, 0x0D, 0x2E, + 0xAB, 0xEE, 0x20, 0x97, 0xE0, 0xF3, 0xB9, 0xF7 + }; + uint16_t scale_host[] = {0xBB98, 0xBD63, 0xBCBE, 0xBD87}; + uint16_t zero_host[] = {0x4100, 0x4100, 0x4100, 0x4100}; + uint16_t A_host[] = {0x3F02, 0x3EB3, 0x3F22, 0xBD3F, 0x3F91, 0x3EFF, 0xBFD2, 0xC026, 0xBF3D, 0xBEBD, 0x3EFD, 0x4002, 0x3EF0, 0xBF2F, 0xBD4B, 0xBEE6, 0xBEA5, 0x3F78, 0x3FC3, 0xBE08, 0xBFC5, 0xBFFE, 0xBE4F, 0x3FA8, 0x3EE9, 0x3F60, 0xC03E, 0x3F88, 0x3F1C, 0xBF35, 0xBF8E, 0x0000, 0x3F03, 0x3ED9, 0xBE3D, 0x3ED0, 0xBF90, 0x3FF8, 0xBEDF, 0x3E62, 0x3F45, 0x3E68, 0xBF3E, 0xBDA0, 0x3F98, 0xC003, 0x3E51, 0xBDF8, 0xBED1, 0x3E78, 0x3FA4, 0xBEAD, 0x3F6C, 0x3E1F, 0x4000, 0xBED1, 0x3ECF, 0x3EC4, 0xBF50, 0x3F8E, 0x3FC5, 0xBF97, 0x3E18, 0x3EA1, 0xBFBD, 0xBFA5, 0x3EB0, 0xBF02, 0x3FD7, 0x3F6A, 0xBFEF, 0x3F9E, 0xBF3F, 0xBF90, 0xBFC0, 0xBFCE, 0x3F80, 0x3FFA, 0xBDB0, 0xBECD, 0xBF06, 0x3F75, 0xBFEC, 0x3E5E, 0xC00E, 0xBE63, 0xBF9A, 0x3FAB, 0xBEC8, 0xBF1B, 0x4017, 0xBE03, 0x3F4C, 0x3FA3, 0x3F43, 0xBF13, 0xBF4C, 0x3D7D, 0x3F28, 0x3EC5, 0x3F5A, 0x3F39, 0xBEED, 0x4011, 0x3DD0, 0x3F5E, 0x3F6E, 0x3FA1, 0xC008, 0x3F83, 0x3CB5, 0x3EE7, 0xBED1, 0x3F2D, 0x3A68, 0x3D21, 0x3DE7, 0xBE6B, 0x3DEE, 0x3EF5, 0xBFA6, 0x4042, 0x3FEA, 0xBDF3, 0xBF30, 0x3FC5, 0x3DCD, 0x3EA3, 0xBF0A, 0xBF1A, 0xBF41, 0x3F27, 0xBE1F, 0xBEFE, 0x3F25, 0xBE14, 0x3E33, 0xBFDC, 0x3EAE, 0xBF96, 0xBEFC, 0xBFC9, 0xC035, 0xBF2B, 0x3DE1, 0xBF3D, 0x0000, 0xC002, 0x3E77, 0xBEAB, 0x3EC7, 0xBEBB, 0x3F89, 0x3EAB, 0x0000, 0x3E84, 0xBEDF, 0xBE67, 0x3E47, 0x3DE5, 0x3FA6, 0xBF42, 0x3E58, 0x3E8C, 0x4007, 0x3F0A, 0xC00A, 0xBE0D, 0xBEC1, 0x3F62, 0x3D58, 0xBFD5, 0xBED0, 0xBEE3, 0xBF62, 0x3F4B, 0x3FC0, 0xBF34, 0x3F18, 0x3F73, 0x3F18, 0x3FE6, 0x3E6F, 0x3CD9, 0x3DE4, 0x3FA8, 0x3FC6, 0x3F7E, 0xBD1E, 0xBFA6, 0x3E84, 0x3F8E, 0xBE94, 0x3F63, 0xBE0F, 0x3F49, 0x3E16, 0x3EC0, 0xBF90, 0x401A, 0xBDE8, 0xBF06, 0xBEE2, 0x3FC6, 0xBFBA, 0x3EEA, 0x3F4A, 0xBFE0, 0x4009, 0xBFAA, 0xBF04, 0x3F9D, 0xBF9A, 0x3F06, 0x3FD0, 0x3FAB, 0x3EDB, 0xBF6C, 0x3FD7, 0xBEB6, 0xBF09, 0x0000, 0x3F78, 0x3FAB, 0x3F95, 0xBD5C, 0xBF66, 0xBDF9, 0xBD42, 0xBFDE, 0xBF11, 0xBE46, 0xBF76, 0xBF75, 0x3F31, 0xBEC5, 0xBFF3, 0xBF0F, 0x3EEE, 0xBED8, 0x3E2C, 0xBF3E, 0xBD82, 0x3F33, 0x3F24, 0xBFFF, 0xBF23, 0xBF8A, 0x3E0D, 0xBEC0, 0x3FAF, 0xBF76, 0xBF94, 0x3FAC, 0xBF21, 0x3FA0}; + uint16_t expected[] = {0xC013, 0xBF05}; + // clang-format on + + Tensor* A = create_bf16({M, K}); + Tensor* qdata = create_uint8({N, K / 2}); + Tensor* scale = create_bf16({K / gs, N}); + Tensor* zero = create_bf16({K / gs, N}); + upload(A, A_host, sizeof(A_host)); + upload(qdata, qdata_host, sizeof(qdata_host)); + upload(scale, scale_host, sizeof(scale_host)); + upload(zero, zero_host, sizeof(zero_host)); + + Tensor* output = run(A, qdata, scale, zero, gs); + ASSERT_NE(output, nullptr); + EXPECT_EQ(output->size(0), M); + EXPECT_EQ(output->size(1), N); + check_bf16_output(output, expected, M * N, 0.5f); +} + +TEST_F(AOTITorchInt4PlainMMTest, NullInputHandling) { + int64_t M = 2, K = 128, N = 64, gs = 32; + + Tensor* A = create_bf16({M, K}); + Tensor* qdata = create_uint8({N, K / 2}); + Tensor* scale = create_bf16({K / gs, N}); + Tensor* zero = create_bf16({K / gs, N}); + Tensor* output = nullptr; + + EXPECT_EQ( + aoti_torch_cuda_int4_plain_mm(nullptr, qdata, scale, zero, gs, &output), + Error::InvalidArgument); + EXPECT_EQ( + aoti_torch_cuda_int4_plain_mm(A, nullptr, scale, zero, gs, &output), + Error::InvalidArgument); + EXPECT_EQ( + aoti_torch_cuda_int4_plain_mm(A, qdata, nullptr, zero, gs, &output), + Error::InvalidArgument); + EXPECT_EQ( + aoti_torch_cuda_int4_plain_mm(A, qdata, scale, nullptr, gs, &output), + Error::InvalidArgument); + EXPECT_EQ( + aoti_torch_cuda_int4_plain_mm(A, qdata, scale, zero, gs, nullptr), + Error::InvalidArgument); +} + +TEST_F(AOTITorchInt4PlainMMTest, RealInt4TensorLayout) { + int64_t M = 1, K = 64, N = 8, gs = 32; + int64_t n_groups = K / gs; // 2 + + // clang-format off + // Data from quantize_weight(randn(8,64), QuantConfig(bits=4, gs=32, symmetric=False)) + uint8_t qdata_host[] = { + 0x04, 0x9A, 0x97, 0x63, 0x9B, 0x74, 0x4D, 0x9F, 0x4C, 0x2C, 0x88, 0x58, + 0x56, 0x8D, 0x51, 0x58, 0x87, 0xF5, 0x6A, 0xC7, 0x6C, 0x65, 0x30, 0x84, + 0xB6, 0xA1, 0x37, 0x48, 0x5B, 0x36, 0x68, 0xE7, 0x8E, 0x6A, 0x88, 0x82, + 0xAA, 0x9D, 0xAB, 0x0D, 0xB5, 0x81, 0xBE, 0xA3, 0xE4, 0x9F, 0x99, 0xC8, + 0x86, 0xC8, 0x5D, 0xAA, 0x86, 0x46, 0xBA, 0x9D, 0xDA, 0x06, 0xCA, 0xB7, + 0x53, 0xCD, 0xBF, 0x37, 0x25, 0xD4, 0x04, 0x36, 0xAF, 0x79, 0x57, 0x54, + 0x2A, 0xC9, 0x98, 0x98, 0x5A, 0x05, 0x43, 0x89, 0x84, 0x9A, 0x74, 0xC6, + 0xE6, 0x96, 0x6B, 0x09, 0xAF, 0xFB, 0x3C, 0xB3, 0x88, 0x63, 0x68, 0xAC, + 0x48, 0xB9, 0xC9, 0x34, 0xDC, 0x77, 0x8A, 0x8C, 0xFC, 0x75, 0xC7, 0x95, + 0xAD, 0xF5, 0x70, 0x9C, 0x4A, 0x79, 0x7C, 0x67, 0xAA, 0xAA, 0x0B, 0x8C, + 0xF0, 0x28, 0x91, 0xCD, 0xDA, 0x95, 0x3A, 0x84, 0xD9, 0x45, 0x89, 0x33, + 0x5B, 0x63, 0xB4, 0x39, 0xE9, 0xBF, 0x54, 0x40, 0xAB, 0xC8, 0x88, 0xCB, + 0x48, 0xBA, 0x7A, 0x03, 0xCB, 0x35, 0x74, 0x85, 0x67, 0x58, 0x12, 0xDC, + 0x5B, 0x02, 0x58, 0xF7, 0x8C, 0xC8, 0xA5, 0xFA, 0xAA, 0x8E, 0x4C, 0x1F, + 0xBB, 0x27, 0xC7, 0xEC, 0xB8, 0x69, 0x6F, 0x9F, 0x69, 0x69, 0x55, 0x79, + 0x34, 0x64, 0x56, 0x85, 0x67, 0x3F, 0xA8, 0x80, 0x7A, 0x77, 0x79, 0x05, + 0xA9, 0x10, 0xA7, 0x55, 0x4A, 0x48, 0xF8, 0x59, 0xB6, 0x5A, 0xBD, 0x55, + 0x8C, 0x96, 0x48, 0x6B, 0x9A, 0xC7, 0x97, 0x4B, 0x46, 0x65, 0xF7, 0x7B, + 0x78, 0x5C, 0x8A, 0xC5, 0x98, 0x0C, 0x45, 0x3B, 0x75, 0x9C, 0xC7, 0x58, + 0x63, 0x9A, 0x95, 0x78, 0x95, 0x69, 0xF8, 0x58, 0x65, 0x0A, 0x6B, 0x47, + 0x9C, 0x5C, 0x6A, 0x35, 0xA2, 0x8A, 0x74, 0x93, 0x28, 0x6D, 0xF0, 0xAB, + 0x23, 0xA6, 0xA6, 0x3A}; + // scale/zero are [K//gs, N] = [2, 8] — Int4Tensor's native layout + uint16_t scale_host[] = { + 0x3E46, 0x3E94, 0x3E8F, 0x3E94, 0x3E94, 0x3E8D, 0x3EA5, 0x3EA5, + 0x3E9F, 0x3EAD, 0x3E91, 0x3EA0, 0x3E88, 0x3EB7, 0x3E89, 0x3E92}; + uint16_t zero_host[] = { + 0x4100, 0x4110, 0x40A0, 0x4100, 0x4100, 0x4130, 0x4100, 0x40C0, + 0x40C0, 0x4100, 0x4100, 0x4100, 0x40E0, 0x40E0, 0x4110, 0x40C0}; + uint16_t A_host[] = { + 0x3E47, 0x400A, 0xBE30, 0x3F59, 0xBFF6, 0x3F27, 0xBF26, 0xBF51, + 0x3F07, 0xBFA3, 0xBFD5, 0xBE9B, 0xBDBE, 0x3E4C, 0xBF8F, 0x3FEE, + 0xBF37, 0x3F30, 0x3F4C, 0xBD09, 0x3FBF, 0xBF04, 0xBE82, 0x3FBD, + 0xBEA7, 0xBF94, 0x4017, 0xBF31, 0x3E3C, 0xBF97, 0xBFE7, 0xBFCA, + 0x3F57, 0x3FB6, 0x3F26, 0x3EDA, 0xBFCB, 0x3F1F, 0x3FD8, 0xBF2A, + 0x3F71, 0x3DA0, 0x3DAD, 0xBE10, 0x3EAA, 0xBF17, 0xBF89, 0x3DC3, + 0xBEAB, 0xBF07, 0xBF61, 0x3ECA, 0x3E28, 0xBE4A, 0x3F81, 0xBFAD, + 0xBEB3, 0xBF25, 0x3EE5, 0xBF0A, 0x3F9F, 0xBF51, 0x3E80, 0xBEDB}; + // Reference from Python: bf16 dequant + F.linear + uint16_t expected[] = { + 0x40B7, 0xC100, 0xC0E2, 0xC158, 0xBF29, 0xC11F, 0x4079, 0x407D}; + // clang-format on + + Tensor* A = create_bf16({M, K}); + Tensor* qdata = create_uint8({N, K / 2}); + // Note: scale/zero shape is [n_groups, N], NOT [N, n_groups] + Tensor* scale = create_bf16({n_groups, N}); + Tensor* zero = create_bf16({n_groups, N}); + upload(A, A_host, sizeof(A_host)); + upload(qdata, qdata_host, sizeof(qdata_host)); + upload(scale, scale_host, sizeof(scale_host)); + upload(zero, zero_host, sizeof(zero_host)); + + Tensor* output = run(A, qdata, scale, zero, gs); + ASSERT_NE(output, nullptr); + EXPECT_EQ(output->size(0), M); + EXPECT_EQ(output->size(1), N); + // W4A8 adds quantization noise vs bf16 reference — use wider tolerance + check_bf16_output(output, expected, M * N, 0.5f); +} diff --git a/backends/cuda/tests/test_int4_dispatch.py b/backends/cuda/tests/test_int4_dispatch.py new file mode 100644 index 00000000000..bf9ef01518b --- /dev/null +++ b/backends/cuda/tests/test_int4_dispatch.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for Int4Tensor F.linear dispatch via int4_dispatch. + +The API contract: after importing int4_dispatch, F.linear and nn.Linear +with Int4Tensor weights produce numerically correct results. Tests verify +this across decode (M=1), prefill (M>1), batched (3D), bias, group sizes, +and symmetric/asymmetric quantization. + +Usage: + python -m pytest backends/cuda/tests/test_int4_dispatch.py -v +""" + +import unittest + +import executorch.backends.cuda.int4_dispatch # noqa: F401 + +import torch +import torch.nn as nn +import torch.nn.functional as F +from executorch.examples.models.gemma4_31b.quant.quantize import quantize_weight +from executorch.examples.models.gemma4_31b.quant.recipe import QuantConfig + + +def _require_cuda(tc: unittest.TestCase) -> None: + if not torch.cuda.is_available(): + tc.skipTest("CUDA required") + + +def _make_int4_linear(N, K, group_size=128, symmetric=False, bias=False): + """Build an nn.Linear with Int4Tensor weight and return (module, bf16_ref_weight). + + The bf16 reference is the original unquantized weight, so tests can + measure quantization error against the true value. + """ + w_bf16 = torch.randn(N, K, dtype=torch.bfloat16) + config = QuantConfig( + bits=4, group_size=group_size, symmetric=symmetric, method="min_max" + ) + int4_w = quantize_weight(w_bf16, config) + + module = nn.Linear(K, N, bias=bias, dtype=torch.bfloat16, device="cuda") + module.weight = nn.Parameter(int4_w.cuda(), requires_grad=False) + return module, w_bf16.cuda() + + +class TestFLinearDispatch(unittest.TestCase): + """F.linear with Int4Tensor weight produces correct results.""" + + def setUp(self): + _require_cuda(self) + torch.manual_seed(42) + + def _check(self, out, ref, tol=0.15): + rel_err = (out.float() - ref.float()).abs().mean() / ref.float().abs().mean() + self.assertLess(rel_err.item(), tol) + + def test_decode_m1(self): + module, w_ref = _make_int4_linear(256, 512) + x = torch.randn(1, 512, dtype=torch.bfloat16, device="cuda") + self._check(module(x), F.linear(x, w_ref)) + + def test_prefill_m64(self): + module, w_ref = _make_int4_linear(256, 512) + x = torch.randn(64, 512, dtype=torch.bfloat16, device="cuda") + self._check(module(x), F.linear(x, w_ref)) + + def test_3d_batched_input(self): + module, w_ref = _make_int4_linear(256, 512) + x = torch.randn(2, 32, 512, dtype=torch.bfloat16, device="cuda") + out = module(x) + self.assertEqual(out.shape, (2, 32, 256)) + self._check(out, F.linear(x, w_ref)) + + def test_with_bias(self): + module, w_ref = _make_int4_linear(256, 512, bias=True) + x = torch.randn(4, 512, dtype=torch.bfloat16, device="cuda") + self._check(module(x), F.linear(x, w_ref, module.bias)) + + def test_group_size_32(self): + module, w_ref = _make_int4_linear(128, 256, group_size=32) + x = torch.randn(1, 256, dtype=torch.bfloat16, device="cuda") + self._check(module(x), F.linear(x, w_ref)) + + def test_symmetric(self): + module, w_ref = _make_int4_linear(256, 512, symmetric=True) + x = torch.randn(1, 512, dtype=torch.bfloat16, device="cuda") + self._check(module(x), F.linear(x, w_ref)) + + +class TestMultiLayer(unittest.TestCase): + """Dispatch works across multiple Int4 linear modules in a model.""" + + def setUp(self): + _require_cuda(self) + torch.manual_seed(42) + + def _check(self, out, ref, tol=0.15): + rel_err = (out.float() - ref.float()).abs().mean() / ref.float().abs().mean() + self.assertLess(rel_err.item(), tol) + + def test_two_layer_mlp(self): + up, w_up = _make_int4_linear(512, 256) + down, w_down = _make_int4_linear(256, 512) + x = torch.randn(4, 256, dtype=torch.bfloat16, device="cuda") + out = down(F.silu(up(x))) + ref = F.linear(F.silu(F.linear(x, w_up)), w_down) + self._check(out, ref) + + def test_sequential_decode_steps(self): + module, w_ref = _make_int4_linear(256, 512) + for _ in range(4): + x = torch.randn(1, 512, dtype=torch.bfloat16, device="cuda") + self._check(module(x), F.linear(x, w_ref)) + + +class TestCompile(unittest.TestCase): + """Dispatch works under torch.compile.""" + + def setUp(self): + _require_cuda(self) + torch.manual_seed(42) + + def _check(self, out, ref, tol=0.15): + rel_err = (out.float() - ref.float()).abs().mean() / ref.float().abs().mean() + self.assertLess(rel_err.item(), tol) + + def test_compile_decode(self): + module, w_ref = _make_int4_linear(256, 512) + compiled = torch.compile(module, fullgraph=True) + x = torch.randn(1, 512, dtype=torch.bfloat16, device="cuda") + self._check(compiled(x), F.linear(x, w_ref)) + + def test_compile_prefill(self): + module, w_ref = _make_int4_linear(256, 512) + compiled = torch.compile(module, fullgraph=True) + x = torch.randn(64, 512, dtype=torch.bfloat16, device="cuda") + self._check(compiled(x), F.linear(x, w_ref)) + + def test_compile_matches_eager(self): + module, _ = _make_int4_linear(256, 512) + compiled = torch.compile(module, fullgraph=True) + x = torch.randn(4, 512, dtype=torch.bfloat16, device="cuda") + out_eager = module(x) + out_compiled = compiled(x) + self.assertTrue(torch.allclose(out_eager, out_compiled, atol=0.5)) + + +class TestDeviceMovement(unittest.TestCase): + """Int4Tensor weight survives device movement and still dispatches.""" + + def setUp(self): + _require_cuda(self) + torch.manual_seed(42) + + def _check(self, out, ref, tol=0.15): + rel_err = (out.float() - ref.float()).abs().mean() / ref.float().abs().mean() + self.assertLess(rel_err.item(), tol) + + def test_to_cuda(self): + w_bf16 = torch.randn(256, 512, dtype=torch.bfloat16) + config = QuantConfig(bits=4, group_size=128, symmetric=False, method="min_max") + int4_w = quantize_weight(w_bf16, config) + module = nn.Linear(512, 256, bias=False) + module.weight = nn.Parameter(int4_w, requires_grad=False) + module = module.to("cuda") + x = torch.randn(1, 512, dtype=torch.bfloat16, device="cuda") + self._check(module(x), F.linear(x, w_bf16.cuda())) + + +class TestLargeShapes(unittest.TestCase): + """Correctness at large production-scale layer shapes.""" + + def setUp(self): + _require_cuda(self) + torch.manual_seed(42) + + def _check(self, out, ref, tol=0.15): + rel_err = (out.float() - ref.float()).abs().mean() / ref.float().abs().mean() + self.assertLess(rel_err.item(), tol) + + def test_4096x5376_decode(self): + module, w_ref = _make_int4_linear(4096, 5376) + x = torch.randn(1, 5376, dtype=torch.bfloat16, device="cuda") + self._check(module(x), F.linear(x, w_ref)) + + def test_21504x5376_decode(self): + module, w_ref = _make_int4_linear(21504, 5376) + x = torch.randn(1, 5376, dtype=torch.bfloat16, device="cuda") + self._check(module(x), F.linear(x, w_ref)) + + def test_21504x5376_prefill(self): + module, w_ref = _make_int4_linear(21504, 5376) + x = torch.randn(128, 5376, dtype=torch.bfloat16, device="cuda") + self._check(module(x), F.linear(x, w_ref)) + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/models/gemma4_31b/export.py b/examples/models/gemma4_31b/export.py index 78668c55118..362265f1c37 100644 --- a/examples/models/gemma4_31b/export.py +++ b/examples/models/gemma4_31b/export.py @@ -82,7 +82,7 @@ def load_and_quantize( model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) print(f"Quantizing with recipe '{recipe_name}'...") - state_dict = quantize_model(model, recipe) + state_dict = quantize_model(model, recipe, verbose=True) print(f"Packing for {backend}...") with torch.device("meta"): @@ -133,6 +133,8 @@ def export_and_lower( def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -> None: + import gc + import torch._inductor.config as inductor_config from executorch.backends.cuda.cuda_backend import CudaBackend @@ -142,28 +144,23 @@ def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) - ExecutorchBackendConfig, to_edge_transform_and_lower, ) + from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.passes import MemoryPlanningPass from torch.export import Dim, export inductor_config.coordinate_descent_tuning = False inductor_config.aot_inductor.compile_wrapper_opt_level = "O0" + # Register Int4Tensor dispatch → executorch_cuda::int4_plain_mm shim + import executorch.backends.cuda.int4_dispatch # noqa: F401 + materialize_runtime_buffers(model, dtype=torch.bfloat16) - print("Exporting decode (T=1)...") - with torch.no_grad(): - decode_ep = export( - model, - ( - torch.tensor([[0]], dtype=torch.long), - torch.tensor([0], dtype=torch.long), - torch.tensor([1.0], dtype=torch.float32), - ), - strict=True, - ) + # Int4Tensor weights are used directly — no format conversion. + # F.linear dispatches to executorch_cuda::int4_plain_mm (CUDA shim). + # Both decode and prefill share the same nibble-packed weights. - # Cap prefill length to the ring-buffer KV cache size (2×sliding_window). - # Longer prompts are chunked by the runner. + # Prefill (T>=2): shim does dequant+cuBLAS (optimal for large M). max_prefill = min(config.max_seq_len - 1, config.sliding_window * 2) seq_dim = Dim("seq_len", min=2, max=max_prefill) print(f"Exporting prefill (T in [2, {max_prefill}])...") @@ -179,18 +176,40 @@ def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) - strict=True, ) + # Decode (T=1): same Int4Tensor weights, same format. No transform needed. + print("Exporting decode (T=1)...") + with torch.no_grad(): + decode_ep = export( + model, + ( + torch.tensor([[0]], dtype=torch.long), + torch.tensor([0], dtype=torch.long), + torch.tensor([1.0], dtype=torch.float32), + ), + strict=True, + ) + + del model + gc.collect() + print("Lowering to ExecuTorch with CUDA backend...") et_prog = to_edge_transform_and_lower( {"decode": decode_ep, "prefill": prefill_ep}, partitioner={ "decode": [ CudaPartitioner( - [CudaBackend.generate_method_name_compile_spec("decode")] + [ + CudaBackend.generate_method_name_compile_spec("decode"), + CompileSpec("low_memory_mode", b"ON"), + ] ) ], "prefill": [ CudaPartitioner( - [CudaBackend.generate_method_name_compile_spec("prefill")] + [ + CudaBackend.generate_method_name_compile_spec("prefill"), + CompileSpec("low_memory_mode", b"ON"), + ] ) ], }, @@ -208,6 +227,9 @@ def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) - "enable_dynamic_shape": True, }, ) + del decode_ep, prefill_ep + gc.collect() + et_program = et_prog.to_executorch( config=ExecutorchBackendConfig( extract_delegate_segments=True, @@ -220,6 +242,9 @@ def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) - ), ) + del et_prog + gc.collect() + os.makedirs(output_dir, exist_ok=True) pte_path = os.path.join(output_dir, "model.pte") print(f"Saving to {pte_path}...") diff --git a/examples/models/gemma4_31b/inference.py b/examples/models/gemma4_31b/inference.py index a5c51ac66a9..12785450d8c 100644 --- a/examples/models/gemma4_31b/inference.py +++ b/examples/models/gemma4_31b/inference.py @@ -185,6 +185,8 @@ def main() -> None: _move_to_cuda(model, config) model.eval() + import executorch.backends.cuda.int4_dispatch # noqa: F401 + if not args.no_compile: print("Compiling model with torch.compile...") model = torch.compile(model, mode="default") diff --git a/examples/models/gemma4_31b/main.cpp b/examples/models/gemma4_31b/main.cpp index 351adc03a33..0be2fef517c 100644 --- a/examples/models/gemma4_31b/main.cpp +++ b/examples/models/gemma4_31b/main.cpp @@ -26,6 +26,9 @@ #include #include +#include +#include +#include #include #include #include @@ -40,7 +43,7 @@ extern "C" void et_pal_emit_log_message( size_t line, const char* message, ET_UNUSED size_t length) { - if (level < 'W') { + if (level == 'D' || level == 'I') { return; } fprintf(stderr, "%c [%s:%zu] %s\n", (char)level, filename, line, message); diff --git a/examples/models/gemma4_31b/quant/pack_cuda.py b/examples/models/gemma4_31b/quant/pack_cuda.py index bfaa5d955ae..7c834505d36 100644 --- a/examples/models/gemma4_31b/quant/pack_cuda.py +++ b/examples/models/gemma4_31b/quant/pack_cuda.py @@ -4,13 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""CUDA packer: torchao quantized tensors → CUDA runtime format. +"""CUDA packer: assign quantized weights to model modules. -Converts ``Int4Tensor`` to ``Int4TilePackedTo4dTensor`` (tinygemm) and -passes ``IntxUnpackedToInt8Tensor`` through unchanged (AOTI fuses -the dequantize-matmul pattern). +Passes ``Int4Tensor`` and ``IntxUnpackedToInt8Tensor`` through as +``nn.Parameter`` without conversion. The Int4Tensor dispatch override +(``int4_dispatch.py``) handles F.linear at runtime. -The backend-agnostic ``pack_model`` dispatcher lives in ``pack.py``. +No CUDA is required for packing. The backend-agnostic ``pack_model`` +dispatcher lives in ``pack.py``. """ import json @@ -21,100 +22,17 @@ from .pack import ModulePackerFn, pack_model # noqa: F401 -# --------------------------------------------------------------------------- -# Low-level converters - - -def pack_int4_for_cuda( - weight: torch.Tensor, - device: str = "cuda", -) -> nn.Parameter: - """Convert an ``Int4Tensor`` to ``Int4TilePackedTo4dTensor`` for tinygemm. - - Unpacks nibbles, pads to tinygemm alignment, tile-packs via CUDA kernel, - and builds the combined scale_and_zero tensor. - - TODO: replace with ``Int4TilePackedTo4dTensor.from_int4_tensor()`` once - that's upstreamed to torchao. - """ - from torchao.quantization.quantize_.workflows.int4.int4_tile_packed_to_4d_tensor import ( - Int4TilePackedTo4dTensor, - ) - from torchao.quantization.utils import pack_tinygemm_scales_and_zeros - from torchao.utils import find_multiple - - original_shape = weight.shape - N, K = original_shape - gs = weight.block_size[-1] - inner_k_tiles = 8 - - # Unpack Int4Tensor nibbles to int32 - p = weight.qdata.to(torch.uint8) - low = (p & 0x0F).to(torch.int32) - high = ((p >> 4) & 0x0F).to(torch.int32) - int_data = torch.stack([low, high], dim=-1).reshape(N, K) - - # Scale/zero: Int4Tensor stores (K//gs, N), transpose to (N, K//gs) - scale = weight.scale.t().contiguous() - zero = weight.zero_point.t().contiguous() - - # Pad to tinygemm alignment - K_padded = find_multiple(K, 1024) - N_padded = find_multiple(N, 8) - if K_padded != K or N_padded != N: - int_data = torch.nn.functional.pad(int_data, (0, K_padded - K, 0, N_padded - N)) - n_groups_padded = K_padded // gs - n_groups_orig = K // gs - scale = torch.nn.functional.pad( - scale, (0, n_groups_padded - n_groups_orig, 0, N_padded - N) - ) - zero = torch.nn.functional.pad( - zero, (0, n_groups_padded - n_groups_orig, 0, N_padded - N) - ) - - int_data = int_data.to(device) - scale = scale.to(device) - zero = zero.to(device) - - # Convert zero-point convention: tinygemm uses zp_tg = (8 - zp_std) * scale - tinygemm_zero = (8 - zero.float()) * scale.float() - - # Tinygemm nibble convention: even=HIGH, odd=LOW - int_data_u8 = (int_data[:, ::2] << 4 | int_data[:, 1::2]).to(torch.uint8) - packed_weight = torch.ops.aten._convert_weight_to_int4pack( - int_data_u8.contiguous(), inner_k_tiles - ) - - scale_and_zero = pack_tinygemm_scales_and_zeros( - scale.to(torch.bfloat16), tinygemm_zero.to(torch.bfloat16), torch.bfloat16 - ) - - subclass = Int4TilePackedTo4dTensor( - qdata=packed_weight, - scale_and_zero=scale_and_zero, - block_size=[1, gs], - shape=torch.Size(original_shape), - ) - return nn.Parameter(subclass, requires_grad=False) - - # --------------------------------------------------------------------------- # Per-module packers def pack_linear_for_cuda(module: nn.Module, weights: dict[str, torch.Tensor]) -> None: - """Pack a quantized ``nn.Linear`` for CUDA.""" + """Assign a quantized weight to an ``nn.Linear`` module.""" from torchao.quantization import IntxUnpackedToInt8Tensor from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor w = weights["weight"] - if isinstance(w, Int4Tensor): - # Pack on CUDA (required by _convert_weight_to_int4pack), move back - # to CPU for assembly. The model moves to CUDA later at runtime. - packed = pack_int4_for_cuda(w, device="cuda") - module.weight = nn.Parameter(packed.detach().to("cpu"), requires_grad=False) - torch.cuda.empty_cache() - elif isinstance(w, IntxUnpackedToInt8Tensor): + if isinstance(w, (Int4Tensor, IntxUnpackedToInt8Tensor)): module.weight = nn.Parameter(w, requires_grad=False) else: raise ValueError(f"Unsupported weight type: {type(w).__name__}") @@ -123,14 +41,14 @@ def pack_linear_for_cuda(module: nn.Module, weights: dict[str, torch.Tensor]) -> def pack_embedding_for_cuda( module: nn.Module, weights: dict[str, torch.Tensor] ) -> None: - """Pack a quantized ``nn.Embedding`` for CUDA (INT8 only).""" + """Assign a quantized weight to an ``nn.Embedding`` (INT8 only).""" from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor w = weights["weight"] if isinstance(w, Int4Tensor): raise ValueError( "Only 8-bit embedding quantization is supported on CUDA. " - "Int4TilePackedTo4dTensor does not implement the embedding op." + "INT4 does not implement the embedding op." ) module.weight = nn.Parameter(w, requires_grad=False) @@ -150,7 +68,7 @@ def load_and_pack_for_cuda( model: nn.Module, packers: dict[type, ModulePackerFn] | None = None, ) -> None: - """Load a quantized safetensors file and pack for CUDA.""" + """Load a quantized safetensors file and assign weights to the model.""" from safetensors import safe_open from torchao.prototype.safetensors.safetensors_support import ( unflatten_tensor_state_dict, @@ -164,8 +82,6 @@ def load_and_pack_for_cuda( all_keys = list(f.keys()) tensor_names = json.loads(metadata.get("tensor_names", "[]")) - # Stream one logical weight at a time: load its inner tensors, - # reconstruct the subclass, pack, then release before the next. for name in tensor_names: parts = name.rsplit(".", 1) module_fqn = parts[0] if len(parts) > 1 else "" diff --git a/examples/models/gemma4_31b/quant/quantize.py b/examples/models/gemma4_31b/quant/quantize.py index 4e4b993d496..ade85efd788 100644 --- a/examples/models/gemma4_31b/quant/quantize.py +++ b/examples/models/gemma4_31b/quant/quantize.py @@ -283,6 +283,7 @@ def quantize_model( model: nn.Module, recipe: QuantRecipe, dtype: torch.dtype = torch.bfloat16, + verbose: bool = False, ) -> dict[str, torch.Tensor]: """Walk model parameters + persistent buffers, apply recipe. @@ -300,8 +301,10 @@ def quantize_model( state[fqn] = param.data.to(dtype) else: state[fqn] = quantize_weight(param.data, config) - print(f" Quantized {i + 1}/{n_params}: {fqn}", end="\r") - print() + if verbose: + print(f" Quantized {i + 1}/{n_params}: {fqn}", end="\r") + if verbose: + print() for fqn, buf in model.named_buffers(): if fqn in persistent_keys and fqn not in state: diff --git a/examples/models/gemma4_31b/quant/recipe.py b/examples/models/gemma4_31b/quant/recipe.py index 9ffafeafc5f..6e29a93ba3e 100644 --- a/examples/models/gemma4_31b/quant/recipe.py +++ b/examples/models/gemma4_31b/quant/recipe.py @@ -25,7 +25,7 @@ class QuantConfig: """ bits: int # 4 or 8 - group_size: int # 32, 64, 128 + group_size: int symmetric: bool # True = no zero point method: str # "min_max" | "hqq" diff --git a/examples/models/gemma4_31b/quant/tests/test_pack_cuda.py b/examples/models/gemma4_31b/quant/tests/test_pack_cuda.py index 89ebbcbab56..0e525e65158 100644 --- a/examples/models/gemma4_31b/quant/tests/test_pack_cuda.py +++ b/examples/models/gemma4_31b/quant/tests/test_pack_cuda.py @@ -4,21 +4,26 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""Unit tests for quant/pack_cuda.py. Requires CUDA.""" +"""Unit tests for quant/pack_cuda.py. + +Tests the public API contract: after packing, modules produce correct +output via F.linear / nn.Embedding at various batch sizes and configs. +""" import os import tempfile import unittest +# Register Int4Tensor F.linear dispatch before any test uses it +import executorch.backends.cuda.int4_dispatch # noqa: F401 + import torch import torch.nn as nn - from executorch.examples.models.gemma4_31b.quant.pack import pack_one from executorch.examples.models.gemma4_31b.quant.pack_cuda import ( DEFAULT_CUDA_PACKERS, load_and_pack_for_cuda, pack_embedding_for_cuda, - pack_int4_for_cuda, pack_linear_for_cuda, pack_model, ) @@ -28,142 +33,120 @@ from torchao.prototype.safetensors.safetensors_support import flatten_tensor_state_dict -class TestPackInt4ForCuda(unittest.TestCase): - def setUp(self): - if not torch.cuda.is_available(): - self.skipTest("CUDA required") +def _require_cuda(tc: unittest.TestCase) -> None: + if not torch.cuda.is_available(): + tc.skipTest("CUDA required") - def test_basic(self): - config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") - q = quantize_weight(torch.randn(128, 256, dtype=torch.bfloat16), config) - self.assertEqual(pack_int4_for_cuda(q).shape, torch.Size([128, 256])) - def test_different_group_sizes(self): - for gs in (32, 64, 128): - with self.subTest(group_size=gs): - config = QuantConfig( - bits=4, group_size=gs, symmetric=False, method="min_max" - ) - q = quantize_weight(torch.randn(128, 256, dtype=torch.bfloat16), config) - self.assertEqual(pack_int4_for_cuda(q).shape, torch.Size([128, 256])) +class TestPackLinearInt4(unittest.TestCase): + """pack_linear_for_cuda with INT4 weights produces correct F.linear output.""" - def test_matmul_approximates_original(self): + def setUp(self): + _require_cuda(self) torch.manual_seed(0) - weight = torch.randn(256, 1024, dtype=torch.bfloat16) - x = torch.randn(1, 1024, dtype=torch.bfloat16) - original_out = torch.nn.functional.linear(x.cuda(), weight.cuda()) + self.weight = torch.randn(256, 1024, dtype=torch.bfloat16) - config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") - q = quantize_weight(weight, config) - packed = pack_int4_for_cuda(q) - packed_out = torch.nn.functional.linear(x.cuda(), packed.cuda()) + def _pack(self, symmetric=False, group_size=32): + config = QuantConfig( + bits=4, group_size=group_size, symmetric=symmetric, method="min_max" + ) + q = quantize_weight(self.weight, config) + module = nn.Linear(1024, 256, bias=False) + pack_linear_for_cuda(module, {"weight": q}) + module.cuda() + return module + + def test_shape_preserved(self): + module = self._pack() + self.assertEqual(module.weight.shape, torch.Size([256, 1024])) + + def test_asymmetric_decode(self): + module = self._pack(symmetric=False) + x = torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda") + ref = torch.nn.functional.linear(x, self.weight.cuda()) + out = module(x) + rel_error = (out.float() - ref.float()).abs().mean() / ref.float().abs().mean() + self.assertLess(rel_error.item(), 0.15) - rel_error = ( - packed_out.float() - original_out.float() - ).abs().mean() / original_out.float().abs().mean() + def test_symmetric_decode(self): + module = self._pack(symmetric=True) + x = torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda") + ref = torch.nn.functional.linear(x, self.weight.cuda()) + out = module(x) + rel_error = (out.float() - ref.float()).abs().mean() / ref.float().abs().mean() self.assertLess(rel_error.item(), 0.15) - def test_symmetric_matmul(self): - torch.manual_seed(0) - weight = torch.randn(256, 1024, dtype=torch.bfloat16) - x = torch.randn(1, 1024, dtype=torch.bfloat16) - original_out = torch.nn.functional.linear(x.cuda(), weight.cuda()) + def test_prefill_batch(self): + module = self._pack(symmetric=False) + x = torch.randn(64, 1024, dtype=torch.bfloat16, device="cuda") + ref = torch.nn.functional.linear(x, self.weight.cuda()) + out = module(x) + rel_error = (out.float() - ref.float()).abs().mean() / ref.float().abs().mean() + self.assertLess(rel_error.item(), 0.15) - config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") - q = quantize_weight(weight, config) - packed = pack_int4_for_cuda(q) - packed_out = torch.nn.functional.linear(x.cuda(), packed.cuda()) + def test_different_group_sizes(self): + for gs in (32, 64, 128): + with self.subTest(group_size=gs): + module = self._pack(group_size=gs) + x = torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda") + ref = torch.nn.functional.linear(x, self.weight.cuda()) + out = module(x) + rel_error = ( + out.float() - ref.float() + ).abs().mean() / ref.float().abs().mean() + self.assertLess(rel_error.item(), 0.15) - rel_error = ( - packed_out.float() - original_out.float() - ).abs().mean() / original_out.float().abs().mean() - self.assertLess(rel_error.item(), 0.15) +class TestPackLinearInt8(unittest.TestCase): + """pack_linear_for_cuda with INT8 weights produces correct F.linear output.""" -class TestPackInt8OnCuda(unittest.TestCase): def setUp(self): - if not torch.cuda.is_available(): - self.skipTest("CUDA required") + _require_cuda(self) - def test_matmul_approximates_original(self): + def test_matmul_correct(self): torch.manual_seed(0) weight = torch.randn(256, 128, dtype=torch.bfloat16) x = torch.randn(1, 128, dtype=torch.bfloat16) - original_out = torch.nn.functional.linear(x.cuda(), weight.cuda()) + ref = torch.nn.functional.linear(x.cuda(), weight.cuda()) config = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") q = quantize_weight(weight, config) - # IntxUnpackedToInt8Tensor is already the CUDA format - emb = nn.Linear(128, 256, bias=False) - emb.weight = nn.Parameter(q, requires_grad=False) - emb.to("cuda") - packed_out = emb(x.cuda()) - - rel_error = ( - packed_out.float() - original_out.float() - ).abs().mean() / original_out.float().abs().mean() - self.assertLess(rel_error.item(), 0.02) - - def test_per_axis_embedding_gather(self): - torch.manual_seed(0) - weight = torch.randn(1000, 64, dtype=torch.bfloat16) - ids = torch.tensor([0, 1, 42, 500, 999]) - original = weight[ids] - - config = QuantConfig(bits=8, group_size=64, symmetric=True, method="min_max") - q = quantize_weight(weight, config) - emb = nn.Embedding(1000, 64) - emb.weight = nn.Parameter(q, requires_grad=False) - emb.to("cuda") - packed_out = emb(ids.cuda()) + module = nn.Linear(128, 256, bias=False) + pack_linear_for_cuda(module, {"weight": q}) + module.cuda() + out = module(x.cuda()) - rel_error = ( - packed_out.cpu().float() - original.float() - ).abs().mean() / original.float().abs().mean() + rel_error = (out.float() - ref.float()).abs().mean() / ref.float().abs().mean() self.assertLess(rel_error.item(), 0.02) + def test_unsupported_type_raises(self): + module = nn.Linear(64, 32, bias=False) + with self.assertRaises(ValueError): + pack_linear_for_cuda(module, {"weight": torch.randn(32, 64)}) -class TestPackLinearForCuda(unittest.TestCase): - def setUp(self): - if not torch.cuda.is_available(): - self.skipTest("CUDA required") - - def test_4bit(self): - config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") - q = quantize_weight(torch.randn(256, 128, dtype=torch.bfloat16), config) - module = nn.Linear(128, 256, bias=False) - pack_linear_for_cuda(module, {"weight": q}) - self.assertEqual(module.weight.shape, torch.Size([256, 128])) - - def test_8bit(self): - config = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") - q = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) - module = nn.Linear(128, 64, bias=False) - pack_linear_for_cuda(module, {"weight": q}) - self.assertEqual(module.weight.shape, torch.Size([64, 128])) +class TestPackEmbedding(unittest.TestCase): + """pack_embedding_for_cuda with INT8 per-axis weights.""" -class TestPackEmbeddingForCuda(unittest.TestCase): def setUp(self): - if not torch.cuda.is_available(): - self.skipTest("CUDA required") + _require_cuda(self) - def test_int8_gather(self): + def test_gather_correct(self): torch.manual_seed(0) weight = torch.randn(1000, 64, dtype=torch.bfloat16) ids = torch.tensor([0, 1, 42, 500, 999]) - original = weight[ids] + ref = weight[ids] config = QuantConfig(bits=8, group_size=64, symmetric=True, method="min_max") q = quantize_weight(weight, config) module = nn.Embedding(1000, 64) pack_embedding_for_cuda(module, {"weight": q}) - module.to("cuda") - packed_out = module(ids.cuda()) + module.cuda() + out = module(ids.cuda()) rel_error = ( - packed_out.cpu().float() - original.float() - ).abs().mean() / original.float().abs().mean() + out.cpu().float() - ref.float() + ).abs().mean() / ref.float().abs().mean() self.assertLess(rel_error.item(), 0.02) def test_rejects_4bit(self): @@ -175,19 +158,23 @@ def test_rejects_4bit(self): class TestPackModel(unittest.TestCase): + """pack_model handles mixed-precision models and disk loading.""" + def setUp(self): - if not torch.cuda.is_available(): - self.skipTest("CUDA required") + _require_cuda(self) def test_mixed_precision(self): - """pack_model handles 4-bit and 8-bit weights in the same model.""" - q4_config = QuantConfig( - bits=4, group_size=32, symmetric=False, method="min_max" + torch.manual_seed(0) + w4 = torch.randn(64, 128, dtype=torch.bfloat16) + w8 = torch.randn(64, 128, dtype=torch.bfloat16) + q4 = quantize_weight( + w4, + QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max"), + ) + q8 = quantize_weight( + w8, + QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max"), ) - q8_config = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") - q4 = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), q4_config) - q8 = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), q8_config) - with torch.device("meta"): model = nn.ModuleDict( { @@ -198,13 +185,30 @@ def test_mixed_precision(self): pack_model( model, {"q_proj.weight": q4, "v_proj.weight": q8}, DEFAULT_CUDA_PACKERS ) - self.assertEqual(model.q_proj.weight.shape, torch.Size([64, 128])) - self.assertEqual(model.v_proj.weight.shape, torch.Size([64, 128])) + model.cuda() + x = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda") + + ref4 = torch.nn.functional.linear(x, w4.cuda()) + out4 = model.q_proj(x) + self.assertLess( + (out4.float() - ref4.float()).abs().mean().item() + / ref4.float().abs().mean().item(), + 0.15, + ) - def test_load_and_pack(self): - """load_and_pack_for_cuda reads from disk and packs.""" + ref8 = torch.nn.functional.linear(x, w8.cuda()) + out8 = model.v_proj(x) + self.assertLess( + (out8.float() - ref8.float()).abs().mean().item() + / ref8.float().abs().mean().item(), + 0.02, + ) + + def test_load_and_pack_from_disk(self): + torch.manual_seed(0) + weight = torch.randn(64, 128, dtype=torch.bfloat16) config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") - q = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) + q = quantize_weight(weight, config) with tempfile.TemporaryDirectory() as d: path = os.path.join(d, "m.safetensors") @@ -223,25 +227,26 @@ def test_load_and_pack(self): } ) load_and_pack_for_cuda(path, model) + self.assertEqual(model.proj.weight.shape, torch.Size([64, 128])) self.assertEqual(model.norm.weight.shape, torch.Size([64])) + model.proj.cuda() + x = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda") + ref = torch.nn.functional.linear(x, weight.cuda()) + out = model.proj(x) + rel_error = (out.float() - ref.float()).abs().mean() / ref.float().abs().mean() + self.assertLess(rel_error.item(), 0.15) -class TestPackOne(unittest.TestCase): - def setUp(self): - if not torch.cuda.is_available(): - self.skipTest("CUDA required") - - def test_quantized_weight(self): + def test_pack_one_quantized(self): config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") q = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) - with torch.device("meta"): model = nn.ModuleDict({"proj": nn.Linear(128, 64, bias=False)}) pack_one(model, "proj.weight", q, DEFAULT_CUDA_PACKERS) self.assertNotEqual(model.proj.weight.device.type, "meta") - def test_plain_tensor(self): + def test_pack_one_plain_tensor(self): with torch.device("meta"): model = nn.ModuleDict({"norm": nn.LayerNorm(64, bias=False)}) pack_one( @@ -254,13 +259,11 @@ def test_plain_tensor(self): class TestPackErrorPaths(unittest.TestCase): + def setUp(self): - if not torch.cuda.is_available(): - self.skipTest("CUDA required") + _require_cuda(self) def test_unregistered_module_type(self): - """pack_model raises for module types not in packers dict.""" - class CustomModule(nn.Module): def __init__(self): super().__init__() @@ -276,7 +279,6 @@ def __init__(self): self.assertIn("CustomModule", str(ctx.exception)) def test_missing_weight_detected(self): - """pack_model raises when a parameter stays on meta after packing.""" config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") q = quantize_weight(torch.randn(32, 64, dtype=torch.bfloat16), config) diff --git a/examples/models/gemma4_31b/quantize_and_save.py b/examples/models/gemma4_31b/quantize_and_save.py index 959540f2c45..e654e12f637 100644 --- a/examples/models/gemma4_31b/quantize_and_save.py +++ b/examples/models/gemma4_31b/quantize_and_save.py @@ -11,7 +11,7 @@ packed for any backend via ``load_and_pack_for_cuda`` or ``pack_model``. The default recipe runs on CPU. The sensitive recipe requires CUDA for -HQQ asymmetric quantization. CUDA is also required at load-and-pack time. +HQQ asymmetric quantization. Usage: python quantize_and_save.py \\ @@ -115,7 +115,7 @@ def main() -> None: model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) print(f"Quantizing with recipe '{args.quant_recipe}'...") - state_dict = quantize_model(model, recipe) + state_dict = quantize_model(model, recipe, verbose=True) os.makedirs(args.output, exist_ok=True) safetensors_path = os.path.join(args.output, "model.safetensors") diff --git a/examples/models/gemma4_31b/tests/test_cuda_pipeline.py b/examples/models/gemma4_31b/tests/test_cuda_pipeline.py index d83e76fd630..0ff28aac415 100644 --- a/examples/models/gemma4_31b/tests/test_cuda_pipeline.py +++ b/examples/models/gemma4_31b/tests/test_cuda_pipeline.py @@ -19,9 +19,11 @@ import tempfile import unittest +# Register Int4Tensor dispatch before any model usage +import executorch.backends.cuda.int4_dispatch # noqa: F401 + import torch import torch.nn as nn - from executorch.examples.models.gemma4_31b.export import ( export_and_lower, load_prequantized_model, @@ -52,7 +54,7 @@ def setUp(self): _require_cuda(self) def test_generate(self): - """save → load → pack → generate (sampling + greedy).""" + """save → load → pack → generate.""" with tempfile.TemporaryDirectory() as tmpdir: save_checkpoint(tmpdir) model, config = load_prequantized_model( @@ -102,18 +104,16 @@ def test_chunked_prefill_matches_sequential(self): model_chunk.eval() buf_size = config.sliding_window * 2 - prompt_len = buf_size + 8 # exceeds buf_size + prompt_len = buf_size + 8 torch.manual_seed(0) prompt = torch.randint(0, config.vocab_size, (1, prompt_len), device="cuda") - # Sequential: one token at a time (temperature=None returns logits) with torch.no_grad(): for i in range(prompt_len): tok = prompt[:, i : i + 1] pos = torch.tensor([i], dtype=torch.long, device="cuda") logits_seq = model_seq(tok, pos, None) - # Chunked: two chunks respecting buf_size with torch.no_grad(): chunk1 = prompt[:, :buf_size] pos1 = torch.arange(buf_size, dtype=torch.long, device="cuda") @@ -123,9 +123,6 @@ def test_chunked_prefill_matches_sequential(self): pos2 = torch.arange(buf_size, prompt_len, dtype=torch.long, device="cuda") logits_chunk = model_chunk(chunk2, pos2, None) - # Compare last-token logits (skip sampling to avoid RNG differences). - # Use allclose rather than equal — CUDA kernels can produce small FP - # differences across execution shapes. max_diff = (logits_seq[0, -1].float() - logits_chunk[0, -1].float()).abs().max() self.assertTrue( torch.allclose( @@ -169,12 +166,67 @@ def test_export_from_hf_checkpoint(self): pack_model(model, state_dict, DEFAULT_CUDA_PACKERS) model.eval() - params = dict(model.named_parameters()) - self.assertIn("lm_head.weight", params) - self.assertNotIn("layers.5.self_attn.v_proj.weight", params) export_and_lower(model, config, out_dir) self.assertTrue(os.path.exists(os.path.join(out_dir, "model.pte"))) +class TestInt4Inference(unittest.TestCase): + """Test Int4Tensor passthrough with dispatch override.""" + + def setUp(self): + _require_cuda(self) + with tempfile.TemporaryDirectory() as tmpdir: + save_checkpoint(tmpdir) + self.model, self.config = load_prequantized_model( + tmpdir, max_seq_len=TINY_CONFIG.max_seq_len + ) + _move_to_cuda(self.model, self.config) + self.model.eval() + + def _forward(self): + with torch.no_grad(): + tok = torch.tensor([[1]], dtype=torch.long, device="cuda") + pos = torch.tensor([0], dtype=torch.long, device="cuda") + temp = torch.tensor([1.0], dtype=torch.float32, device="cuda") + return self.model(tok, pos, temp) + + def test_int4_weights_preserved(self): + """Packing passes Int4Tensor through without conversion.""" + from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + + w = self.model.layers[0].mlp.gate_proj.weight.data + self.assertIsInstance(w, Int4Tensor) + + def test_inference_produces_valid_output(self): + out = self._forward() + self.assertEqual(out.shape, torch.Size([1, 1])) + self.assertFalse(out.isnan().any()) + + def test_deterministic(self): + """Same seed produces same output.""" + torch.manual_seed(99) + out1 = self._forward() + # Reset KV cache by reloading + with tempfile.TemporaryDirectory() as tmpdir: + save_checkpoint(tmpdir) + model2, config2 = load_prequantized_model( + tmpdir, max_seq_len=TINY_CONFIG.max_seq_len + ) + _move_to_cuda(model2, config2) + model2.eval() + with torch.no_grad(): + tok = torch.tensor([[1]], dtype=torch.long, device="cuda") + pos = torch.tensor([0], dtype=torch.long, device="cuda") + temp = torch.tensor([1.0], dtype=torch.float32, device="cuda") + torch.manual_seed(99) + out2 = model2(tok, pos, temp) + self.assertEqual(int(out1.item()), int(out2.item())) + + def test_embedding_works(self): + tok = torch.tensor([[1]], dtype=torch.long, device="cuda") + emb = self.model.embed_tokens(tok) + self.assertFalse(emb.isnan().any()) + + if __name__ == "__main__": unittest.main() From 644ec6eed5e2fffb27733bd4291668b778468daf Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Fri, 8 May 2026 11:49:41 -0700 Subject: [PATCH 12/14] Minor fix --- examples/models/gemma4_31b/export.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/models/gemma4_31b/export.py b/examples/models/gemma4_31b/export.py index 362265f1c37..a96dba0d512 100644 --- a/examples/models/gemma4_31b/export.py +++ b/examples/models/gemma4_31b/export.py @@ -162,7 +162,7 @@ def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) - # Prefill (T>=2): shim does dequant+cuBLAS (optimal for large M). max_prefill = min(config.max_seq_len - 1, config.sliding_window * 2) - seq_dim = Dim("seq_len", min=2, max=max_prefill) + seq_dim = Dim("seq_len", min=5, max=max_prefill) print(f"Exporting prefill (T in [2, {max_prefill}])...") with torch.no_grad(): prefill_ep = export( From 6273bb234e9ad03c27384f31d517ff3b722e0b4f Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Tue, 12 May 2026 08:18:06 -0700 Subject: [PATCH 13/14] On-the-fly RoPE and fix imports after gemma4 upstream rebase MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adapt gemma4_31b to upstream gemma4 changes (33419c04) that removed precompute_freqs_cis in favor of on-the-fly RoPE computation: - Store inv_freq buffer instead of precomputed [max_seq_len, head_dim] cos/sin tables — saves memory, matches qwen3_5_moe and gemma4 E2B - Compute cos/sin per forward via torch.outer(positions, inv_freq) - Fix gemma4/text_decoder/__init__.py to remove stale precompute_freqs_cis re-export - Update model.md to reflect current architecture --- .../models/gemma4/text_decoder/__init__.py | 1 - examples/models/gemma4_31b/model.md | 34 ++++++------- examples/models/gemma4_31b/model.py | 48 +++++++------------ 3 files changed, 35 insertions(+), 48 deletions(-) diff --git a/examples/models/gemma4/text_decoder/__init__.py b/examples/models/gemma4/text_decoder/__init__.py index 51c96f0717f..5f21130e27d 100644 --- a/examples/models/gemma4/text_decoder/__init__.py +++ b/examples/models/gemma4/text_decoder/__init__.py @@ -10,7 +10,6 @@ apply_rotary_emb, apply_rotary_emb_single, Gemma4KVCache, - precompute_freqs_cis, rotate_half, ) from .gemma4_config import Gemma4Config # noqa: F401 diff --git a/examples/models/gemma4_31b/model.md b/examples/models/gemma4_31b/model.md index 9a8c6d84e5f..8233b6d430e 100644 --- a/examples/models/gemma4_31b/model.md +++ b/examples/models/gemma4_31b/model.md @@ -105,7 +105,7 @@ Decoder norms per layer: `input_layernorm`, `post_attention_layernorm`, | Method | Input | Output (sampled) | |-----------|------------------------------------------------------------|------------------| | `decode` | tokens `(1, 1)` + input_pos `(1,)` + temperature `(1,)` | `(1, 1)` float | -| `prefill` | tokens `(1, T)` + input_pos `(T,)` + temperature `(1,)`, T∈[2, min(max_seq_len-1, 2×sliding_window)] | `(1, 1)` float | +| `prefill` | tokens `(1, T)` + input_pos `(T,)` + temperature `(1,)`, T∈[5, min(max_seq_len-1, 2×sliding_window)] | `(1, 1)` float | Both methods share the same KV-cache buffers via `MemoryPlanningPass(share_mutable_buffers=True)` and @@ -145,11 +145,11 @@ quantize_and_save.py export.py / inference.py | | quantize_weight() load (torchao safetensors) | | - Int4Tensor / IntxUnpacked Int4Tensor / IntxUnpacked + Int4Tensor / IntxUnpacked Int4Tensor / IntxUnpacked (used directly) | | - save (torchao safetensors) pack_model() + save (torchao safetensors) int4_dispatch routes to int4_plain_mm | | - model.safetensors Int4TilePackedTo4dTensor (runtime) + model.safetensors dp4a decode / dequant+cuBLAS prefill ``` `embed_tokens` and `lm_head` start tied; they are untied before @@ -159,18 +159,17 @@ lossless for index lookup). ## Runtime buffer materialization -After weight loading (via `pack_model()` or `from_hf_checkpoint()`), the -model's KV caches, RoPE tables, and scalar constants are still on the meta -device. `materialize_runtime_buffers(model, dtype, device)` in `model.py` -replaces them with real tensors: +After weight loading (via `from_hf_checkpoint()`), the model's KV caches, +RoPE inv_freq buffers, and scalar constants are still on the meta device. +`materialize_runtime_buffers(model, dtype, device)` in `model.py` replaces +them with real tensors: - KV caches → zeros in `dtype` (bf16 for inference, bf16 for export) -- RoPE tables → computed per-layer (sliding vs full, different θ and head_dim) +- `inv_freq` → moved to target device (cos/sin computed on the fly per forward) - `embed_normalizer`, `logit_softcap`, `cache_positions` → scalar constants Called by `export.py` (device="cpu" for tracing) and `inference.py` -(device="cuda" for eager execution). Having one function avoids duplicating -the RoPE computation and constant setup across scripts. +(device="cuda" for eager execution). ## Customizations vs. vLLM / transformers reference @@ -183,9 +182,10 @@ These exist solely to make the model exportable / efficient under ExecuTorch: via modulo and the attention mask reconstructs which slots are valid. Full-attention layers use a flat `Gemma4KVCache` sized to `max_seq_len`. Both use `index_copy_(dim=2, ...)` for trace-friendly updates. -- **Per-layer RoPE tables** registered as `persistent=False` buffers (sliding - uses full RoPE, full uses proportional partial RoPE — head_dim and θ - differ, so the table is not shared). +- **On-the-fly RoPE**: stores only `inv_freq` per layer, computes cos/sin + via `torch.outer(positions, inv_freq)` each forward. Saves memory vs + precomputed `[max_seq_len, head_dim]` tables (sliding uses full RoPE, + full uses proportional partial RoPE — head_dim and θ differ). - **On-device Gumbel-max sampling** so the exported program emits a token rather than a full logits tensor — keeps the runner GPU↔CPU traffic to a single float per step. @@ -198,6 +198,6 @@ These exist solely to make the model exportable / efficient under ExecuTorch: The numerically-sensitive math primitives are imported from `examples.models.gemma4.text_decoder` and shared with the Gemma 4 E2B/E4B example: `RMSNorm`, `RMSNormNoWeight`, `Gemma4MLP`, `Gemma4KVCache`, -`precompute_freqs_cis`, `apply_rotary_emb`. The 31B-specific pieces -(attention with K=V branch, decoder layer, top-level model with softcap + -sampling, checkpoint loader) live in `model.py`. +`apply_rotary_emb`. The 31B-specific pieces (attention with K=V branch, +decoder layer, top-level model with softcap + sampling, checkpoint loader) +live in `model.py`. diff --git a/examples/models/gemma4_31b/model.py b/examples/models/gemma4_31b/model.py index 4c5d2f5b97e..b0eb4004c52 100644 --- a/examples/models/gemma4_31b/model.py +++ b/examples/models/gemma4_31b/model.py @@ -49,7 +49,6 @@ apply_rotary_emb, Gemma4KVCache, Gemma4MLP, - precompute_freqs_cis, RMSNorm, RMSNormNoWeight, ) @@ -255,21 +254,22 @@ def __init__(self, config: Gemma4_31BConfig, layer_idx: int): # Precomputed RoPE table for this layer (per-layer because head_dim # and theta differ between sliding and full attention). For full # attention layers we pass freq_base_dim=head_dim so the zero-padded - # inv_freq matches HF's "proportional" partial RoPE. + # On-the-fly RoPE: store only inv_freq, compute cos/sin per forward. + # Saves memory vs precomputed [max_seq_len, head_dim] tables. if self.is_sliding: rotary_dim = self.head_dim - freq_base_dim = None else: rotary_dim = int(self.head_dim * self.partial_rotary) - freq_base_dim = self.head_dim - freqs_cos, freqs_sin = precompute_freqs_cis( - rotary_dim, - config.max_seq_len, - theta=self.rope_theta, - freq_base_dim=freq_base_dim, + rope_angles = rotary_dim // 2 + inv_freq_rotated = 1.0 / ( + self.rope_theta ** (torch.arange(0, rotary_dim, 2).float() / self.head_dim) ) - self.register_buffer("freqs_cos", freqs_cos, persistent=False) - self.register_buffer("freqs_sin", freqs_sin, persistent=False) + nope_angles = self.head_dim // 2 - rope_angles + if nope_angles > 0: + inv_freq = torch.cat([inv_freq_rotated, torch.zeros(nope_angles)]) + else: + inv_freq = inv_freq_rotated + self.register_buffer("inv_freq", inv_freq, persistent=False) # KV cache. Sliding layers use a ring buffer (2x window) to save # memory; full layers use a flat buffer (max_seq_len). @@ -316,10 +316,11 @@ def forward( k = k.transpose(1, 2) v = v.transpose(1, 2) - # RoPE on Q and K only (V is not rotated). cos/sin are gathered for - # the current positions to avoid baking the full table into the graph. - cos = self.freqs_cos[input_pos] - sin = self.freqs_sin[input_pos] + # RoPE on Q and K only (V is not rotated). cos/sin computed on the fly. + freqs = torch.outer(input_pos.float(), self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + cos = torch.cos(emb) + sin = torch.sin(emb) q, k = apply_rotary_emb(q, k, cos, sin) # Update cache and read back full K/V. @@ -533,8 +534,7 @@ def from_hf_checkpoint( # and not in the checkpoint — those are the "expected" missing keys. runtime_prefixes = ( ".kv_cache.", - ".freqs_cos", - ".freqs_sin", + ".inv_freq", "embed_normalizer", "logit_softcap", "cache_positions", @@ -675,19 +675,7 @@ def materialize_runtime_buffers( for layer in model.layers: attn = layer.self_attn - if attn.is_sliding: - rotary_dim, freq_base_dim = attn.head_dim, None - else: - rotary_dim = int(attn.head_dim * attn.partial_rotary) - freq_base_dim = attn.head_dim - cos, sin = precompute_freqs_cis( - rotary_dim, - config.max_seq_len, - theta=attn.rope_theta, - freq_base_dim=freq_base_dim, - ) - attn.register_buffer("freqs_cos", cos.to(device), persistent=False) - attn.register_buffer("freqs_sin", sin.to(device), persistent=False) + attn.inv_freq = attn.inv_freq.to(device) model.register_buffer( "embed_normalizer", From e7375a1f325cbfd94ff7853762ee4872a82820d7 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Tue, 12 May 2026 11:46:07 -0700 Subject: [PATCH 14/14] Harden int4_plain_mm: dtype checks, scale hoist, docstrings - Add dtype checks for qdata (uint8/int8), scale (bf16), zero (bf16) in C shim - Hoist weight scale/zero loads outside inner loop (reload only on group change) - Clarify int4_dispatch.py docblock: runs at eager/trace time, not .pte runtime - Clarify test docblock: tests eager dispatch, not C shim runtime --- backends/cuda/int4_dispatch.py | 24 ++++++++++++------- backends/cuda/runtime/shims/int4_plain_mm.cuh | 16 +++++++++++-- backends/cuda/tests/test_int4_dispatch.py | 12 ++++++++-- 3 files changed, 40 insertions(+), 12 deletions(-) diff --git a/backends/cuda/int4_dispatch.py b/backends/cuda/int4_dispatch.py index b4fc01c9105..d8bcb1acbd0 100644 --- a/backends/cuda/int4_dispatch.py +++ b/backends/cuda/int4_dispatch.py @@ -4,14 +4,22 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""Int4Tensor F.linear dispatch for CUDA. - -Decode (M<=4): Custom op ``executorch_cuda::int4_plain_mm`` — in eager this - dequants + calls F.linear; in .pte runtime the C shim runs a - W4A8 dp4a matvec kernel. -Prefill (M>4): Inline dequant + F.linear — AOTI compiles this into the .so - using inductor's own cuBLAS codegen, so no explicit cuBLAS - dependency in our shim library. +"""Int4Tensor F.linear dispatch for CUDA — runs at eager / export trace time. + +This module overrides Int4Tensor's F.linear dispatch so that torch.export +traces through our custom op and dequant logic instead of torchao's default +(mslk/tinygemm). The code here executes during eager inference and during +AOTI export tracing — it does NOT run at .pte runtime. + +At .pte runtime, the captured graph is executed by the AOTI-generated .so: + - The custom op ``executorch_cuda::int4_plain_mm`` maps to a C shim that + runs the W4A8 dp4a matvec kernel (backends/cuda/runtime/shims/). + - The inline dequant + F.linear is compiled by inductor into fused Triton + dequant + cuBLAS matmul kernels. + +Dispatch strategy (determines what gets captured in the export graph): + Decode (M<=4): Custom op ``executorch_cuda::int4_plain_mm`` + Prefill (M>4): Inline dequant + F.linear (standard PyTorch ops) Import this module before using nn.Linear with Int4Tensor weights:: diff --git a/backends/cuda/runtime/shims/int4_plain_mm.cuh b/backends/cuda/runtime/shims/int4_plain_mm.cuh index 64fccb7c093..ea236e8d069 100644 --- a/backends/cuda/runtime/shims/int4_plain_mm.cuh +++ b/backends/cuda/runtime/shims/int4_plain_mm.cuh @@ -130,6 +130,9 @@ __global__ void __launch_bounds__(MV_THREADS) float sum = 0.0f; + int32_t prev_g = -1; + float ws = 0.0f, wz = 0.0f; + for (int32_t i = lane_id; i < K_half_16; i += MV_WARP_SIZE) { uint4 packed16 = __ldg(&qrow16[i]); int32_t k_base = i * 32; @@ -141,6 +144,12 @@ __global__ void __launch_bounds__(MV_THREADS) int32_t k_word = k_base + w * 8; int32_t g = k_word >> gs_shift; + if (g != prev_g) { + ws = __bfloat162float(__ldg(&scale_base[g * scale_stride])); + wz = __bfloat162float(__ldg(&zero_base[g * scale_stride])); + prev_g = g; + } + int32_t vi_lo = packed & 0x0F0F0F0F; int32_t vi_hi = (packed >> 4) & 0x0F0F0F0F; @@ -156,8 +165,6 @@ __global__ void __launch_bounds__(MV_THREADS) int32_t dp = __dp4a(vi_lo, a_even, 0); dp = __dp4a(vi_hi, a_odd, dp); - float ws = __bfloat162float(__ldg(&scale_base[g * scale_stride])); - float wz = __bfloat162float(__ldg(&zero_base[g * scale_stride])); float a_scale = qb->d; int32_t a_sum8 = __dp4a(0x01010101, a_even, 0); @@ -212,6 +219,11 @@ void _int4_plain_mm_cuda( int32_t N = qdata.size(0); ET_CHECK(A.dtype() == c10::ScalarType::BFloat16); + ET_CHECK( + qdata.dtype() == c10::ScalarType::Byte || + qdata.dtype() == c10::ScalarType::Char); + ET_CHECK(scale.dtype() == c10::ScalarType::BFloat16); + ET_CHECK(zero.dtype() == c10::ScalarType::BFloat16); ET_CHECK(A.dim() == 2); ET_CHECK(qdata.dim() == 2); ET_CHECK(qdata.size(1) == K / 2); diff --git a/backends/cuda/tests/test_int4_dispatch.py b/backends/cuda/tests/test_int4_dispatch.py index bf9ef01518b..c793544ad48 100644 --- a/backends/cuda/tests/test_int4_dispatch.py +++ b/backends/cuda/tests/test_int4_dispatch.py @@ -7,10 +7,18 @@ """Tests for Int4Tensor F.linear dispatch via int4_dispatch. +These tests validate the eager / trace-time dispatch path — the same code +that torch.export traces through when building the AOTI graph. They do NOT +test the .pte runtime C shim (dp4a kernel); that is covered by +test_aoti_torch_cuda_int4_plain_mm.cpp (C++ unit tests) and +test_cuda_pipeline.py::TestCudaExport (end-to-end export + lower). + The API contract: after importing int4_dispatch, F.linear and nn.Linear with Int4Tensor weights produce numerically correct results. Tests verify -this across decode (M=1), prefill (M>1), batched (3D), bias, group sizes, -and symmetric/asymmetric quantization. +this across decode (M<=4), prefill (M>4), batched (3D), bias, group sizes, +and symmetric/asymmetric quantization. Correctness is measured as mean +relative error against the unquantized bf16 reference (not per-element +atol/rtol, which is too strict for INT4 quantization noise). Usage: python -m pytest backends/cuda/tests/test_int4_dispatch.py -v