diff --git a/.github/workflows/cuda.yml b/.github/workflows/cuda.yml index c3b7c058ee6..087917c1116 100644 --- a/.github/workflows/cuda.yml +++ b/.github/workflows/cuda.yml @@ -148,6 +148,10 @@ jobs: # Run Qwen 3.5 MoE tests (quantize roundtrip + TurboQuant KV cache + sampler) python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py examples/models/qwen3_5_moe/test_sampler.py -v -o "addopts=" + # Run Gemma 4 31B tests (quant unit tests + pipeline integration tests) + pip install gguf + python -m pytest examples/models/gemma4_31b/quant/tests/ examples/models/gemma4_31b/tests/ -v -o "addopts=" + export-model-cuda-artifact: name: export-model-cuda-artifact # Skip this job if the pull request is from a fork (HuggingFace secrets are not available) diff --git a/Makefile b/Makefile index 3c0eac14bce..ba61dddce44 100644 --- a/Makefile +++ b/Makefile @@ -91,7 +91,7 @@ # # ============================================================================== -.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu qwen3_5_moe-cuda qwen3_5_moe-metal clean help +.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu gemma4_31b-cuda qwen3_5_moe-cuda qwen3_5_moe-metal clean help help: @echo "This Makefile adds targets to build runners for various models on various backends. Run using \`make \`. Available targets:" @@ -126,6 +126,7 @@ help: @echo " llava-cpu - Build Llava runner with CPU backend" @echo " gemma3-cuda - Build Gemma3 runner with CUDA backend" @echo " gemma3-cpu - Build Gemma3 runner with CPU backend" + @echo " gemma4_31b-cuda - Build Gemma 4 31B runner with CUDA backend" @echo " qwen3_5_moe-cuda - Build Qwen3.5 MoE runner with CUDA backend" @echo " qwen3_5_moe-metal - Build Qwen3.5 MoE runner with Metal backend" @echo " clean - Clean build artifacts" @@ -425,6 +426,15 @@ qwen3_5_moe-cuda: @echo "✓ Build complete!" @echo " Binary: cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_runner" +gemma4_31b-cuda: + @echo "==> Building and installing ExecuTorch with CUDA..." + cmake --workflow --preset llm-release-cuda + @echo "==> Building Gemma 4 31B runner with CUDA..." + cd examples/models/gemma4_31b && cmake --workflow --preset gemma4-31b-cuda + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/gemma4_31b/gemma4_31b_runner" + qwen3_5_moe-metal: @echo "==> Building and installing ExecuTorch with Metal..." cmake --workflow --preset llm-release-metal diff --git a/backends/cuda/CMakeLists.txt b/backends/cuda/CMakeLists.txt index 157cc05a54f..217c893efe5 100644 --- a/backends/cuda/CMakeLists.txt +++ b/backends/cuda/CMakeLists.txt @@ -110,7 +110,8 @@ set(_aoti_cuda_shim_sources runtime/shims/memory.cpp # Only build CUDA shims when CUDA language/toolchain is available. if(CMAKE_CUDA_COMPILER) list(APPEND _aoti_cuda_shim_sources runtime/shims/int4mm.cu - runtime/shims/sort.cu runtime/shims/rand.cu + runtime/shims/int4_plain_mm.cu runtime/shims/sort.cu + runtime/shims/rand.cu ) endif() diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index a3169680b6d..d732a12a8fe 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -226,6 +226,8 @@ def get_supported_fallback_kernels(cls) -> Dict[str, Any]: "at::_ops::_weight_int4pack_mm::call": None, "at::_ops::sort_stable::call": None, "aoti_torch_cuda_randint_low_out": None, + "executorch_cuda::int4_plain_mm": None, + "aoti_torch_cuda_int4_plain_mm": None, } @classmethod @@ -298,6 +300,20 @@ def get_aoti_compile_options( "aot_inductor.emit_multi_arch_kernel": emit_multi_arch_kernel, } + try: + import torch + + options["aot_inductor.custom_ops_to_c_shims"] = { + torch.ops.executorch_cuda.int4_plain_mm.default: [ + "AOTITorchError aoti_torch_cuda_int4_plain_mm(" + "AtenTensorHandle, AtenTensorHandle, AtenTensorHandle, " + "AtenTensorHandle, int64_t, AtenTensorHandle*)" + ], + } + except AttributeError: + # int4_dispatch.py not imported — op not registered, skip C shim mapping + pass + # Parse compile_specs to check for platform platform = "linux" diff --git a/backends/cuda/int4_dispatch.py b/backends/cuda/int4_dispatch.py new file mode 100644 index 00000000000..d8bcb1acbd0 --- /dev/null +++ b/backends/cuda/int4_dispatch.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Int4Tensor F.linear dispatch for CUDA — runs at eager / export trace time. + +This module overrides Int4Tensor's F.linear dispatch so that torch.export +traces through our custom op and dequant logic instead of torchao's default +(mslk/tinygemm). The code here executes during eager inference and during +AOTI export tracing — it does NOT run at .pte runtime. + +At .pte runtime, the captured graph is executed by the AOTI-generated .so: + - The custom op ``executorch_cuda::int4_plain_mm`` maps to a C shim that + runs the W4A8 dp4a matvec kernel (backends/cuda/runtime/shims/). + - The inline dequant + F.linear is compiled by inductor into fused Triton + dequant + cuBLAS matmul kernels. + +Dispatch strategy (determines what gets captured in the export graph): + Decode (M<=4): Custom op ``executorch_cuda::int4_plain_mm`` + Prefill (M>4): Inline dequant + F.linear (standard PyTorch ops) + +Import this module before using nn.Linear with Int4Tensor weights:: + + import executorch.backends.cuda.int4_dispatch # noqa: F401 +""" + +import torch +import torch.nn.functional as F +from torch.library import impl, Library +from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + +# --------------------------------------------------------------------------- +# Custom op for decode (M=1): dp4a matvec in C shim, dequant+F.linear in eager +# --------------------------------------------------------------------------- + +_lib = Library("executorch_cuda", "DEF") +_lib.define( + "int4_plain_mm(Tensor self, Tensor qdata, Tensor scale, Tensor zero, int group_size) -> Tensor" +) + + +@impl(_lib, "int4_plain_mm", "Meta") +def _meta(self, qdata, scale, zero, group_size): + return torch.empty( + self.shape[0], qdata.shape[0], dtype=self.dtype, device=self.device + ) + + +@impl(_lib, "int4_plain_mm", "CUDA") +def _cuda(self, qdata, scale, zero, group_size): + return _dequant_matmul(self, qdata, scale, zero, group_size) + + +def _dequant_matmul(x, qdata, scale, zero, group_size): + """Dequant INT4 weights to input dtype and call F.linear.""" + N, K_half = qdata.shape + K = K_half * 2 + n_groups = K // group_size + gs_half = group_size // 2 + dtype = x.dtype + + p = qdata.to(torch.uint8).reshape(N, n_groups, gs_half) + low = (p & 0x0F).to(dtype) + high = ((p >> 4) & 0x0F).to(dtype) + data = torch.stack([low, high], dim=-1).reshape(N, n_groups, group_size) + + s = scale.to(dtype).t().unsqueeze(-1) + z = zero.to(dtype).t().unsqueeze(-1) + w_deq = ((data - z) * s).reshape(N, K) + + return F.linear(x, w_deq) + + +# --------------------------------------------------------------------------- +# Int4Tensor F.linear dispatch +# --------------------------------------------------------------------------- + +aten = torch.ops.aten +_implements = Int4Tensor.implements +_implements_torch_function = Int4Tensor.implements_torch_function + + +@_implements([aten.linear.default]) +@_implements_torch_function([F.linear]) +def _(func, types, args, kwargs): + input_tensor = args[0] + weight_tensor = args[1] + bias = args[2] if len(args) > 2 else None + + orig_shape = input_tensor.shape + x_2d = input_tensor.reshape(-1, orig_shape[-1]) + + qdata = weight_tensor.qdata + scale = weight_tensor.scale + zero = weight_tensor.zero_point + gs = weight_tensor.block_size[-1] + + M = x_2d.shape[0] + if M <= 4: + out = torch.ops.executorch_cuda.int4_plain_mm(x_2d, qdata, scale, zero, gs) + else: + out = _dequant_matmul(x_2d, qdata, scale, zero, gs) + + out = out.reshape(*orig_shape[:-1], -1) + if bias is not None: + out = out + bias + return out diff --git a/backends/cuda/runtime/shims/int4_plain_mm.cu b/backends/cuda/runtime/shims/int4_plain_mm.cu new file mode 100644 index 00000000000..fd8fe3b0c3b --- /dev/null +++ b/backends/cuda/runtime/shims/int4_plain_mm.cu @@ -0,0 +1,81 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include +#include +#include +#include + +namespace executorch::backends::cuda { +#ifdef __cplusplus +extern "C" { +#endif + +AOTITorchError aoti_torch_cuda_int4_plain_mm( + Tensor* self, + Tensor* qdata, + Tensor* scale, + Tensor* zero, + int64_t group_size, + Tensor** ret0) { + ET_CHECK_OR_RETURN_ERROR( + self != nullptr, + InvalidArgument, + "aoti_torch_cuda_int4_plain_mm: self is null"); + + ET_CHECK_OR_RETURN_ERROR( + qdata != nullptr, + InvalidArgument, + "aoti_torch_cuda_int4_plain_mm: qdata is null"); + + ET_CHECK_OR_RETURN_ERROR( + scale != nullptr, + InvalidArgument, + "aoti_torch_cuda_int4_plain_mm: scale is null"); + + ET_CHECK_OR_RETURN_ERROR( + zero != nullptr, + InvalidArgument, + "aoti_torch_cuda_int4_plain_mm: zero is null"); + + ET_CHECK_OR_RETURN_ERROR( + ret0 != nullptr, + InvalidArgument, + "aoti_torch_cuda_int4_plain_mm: ret0 is null"); + + int32_t M = self->size(0); + int32_t N = qdata->size(0); + Tensor* C = nullptr; + std::array c_shape = {M, N}; + std::array c_stride = {N, 1}; + aoti_torch_empty_strided( + 2, + c_shape.data(), + c_stride.data(), + static_cast( + executorch::backends::aoti::slim::c10::ScalarType::BFloat16), + static_cast( + executorch::backends::aoti::slim::c10::DeviceType::CUDA), + 0, + &C); + + _int4_plain_mm_cuda(*self, *qdata, *scale, *zero, group_size, C); + ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR(); + + *ret0 = C; + return Error::Ok; +} + +#ifdef __cplusplus +} +#endif +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/int4_plain_mm.cuh b/backends/cuda/runtime/shims/int4_plain_mm.cuh new file mode 100644 index 00000000000..ea236e8d069 --- /dev/null +++ b/backends/cuda/runtime/shims/int4_plain_mm.cuh @@ -0,0 +1,278 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// W4A8 dp4a matvec for INT4 decode (M <= 4). +// +// Reads plain nibble-packed [N, K//2] weights (Int4Tensor format). +// Scale/zero layout: [K//gs, N] (Int4Tensor's native layout). +// +// Dynamically quantizes bf16 activations to INT8 (per-32-element blocks), +// then uses dp4a for fused int4×int8 dot products with 16-byte vectorized +// loads and warp-cooperative quantization. + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace executorch::backends::cuda { + +using executorch::backends::aoti::Tensor; +namespace c10 = executorch::backends::aoti::slim::c10; + +// --------------------------------------------------------------------------- +// Constants +// --------------------------------------------------------------------------- + +constexpr int32_t MV_NWARPS = 8; +constexpr int32_t MV_WARP_SIZE = 32; +constexpr int32_t MV_THREADS = MV_NWARPS * MV_WARP_SIZE; +constexpr int32_t Q8_BLOCK_SIZE = 32; + +__host__ __forceinline__ int32_t log2_pow2(int32_t v) { + int32_t r = 0; + while (v > 1) { + v >>= 1; + r++; + } + return r; +} + +// --------------------------------------------------------------------------- +// Activation quantization: bf16 → int8 (warp-cooperative, per-32-element blocks) +// --------------------------------------------------------------------------- + +struct Q8Block { + int8_t qs_even[Q8_BLOCK_SIZE / 2]; + int8_t qs_odd[Q8_BLOCK_SIZE / 2]; + float d; // scale +}; + +__global__ void quantize_activations_q8_kernel( + const __nv_bfloat16* __restrict__ A, + Q8Block* __restrict__ q8, + int32_t K) { + const int32_t m = blockIdx.y; + const int32_t block_id = blockIdx.x * blockDim.y + threadIdx.y; + const int32_t n_blocks = K / Q8_BLOCK_SIZE; + if (block_id >= n_blocks) + return; + + const int32_t lane = threadIdx.x; + const __nv_bfloat16* src = + A + static_cast(m) * K + block_id * Q8_BLOCK_SIZE; + Q8Block* dst = q8 + static_cast(m) * n_blocks + block_id; + + float val = __bfloat162float(src[lane]); + + float amax = fabsf(val); + for (int offset = 16; offset > 0; offset >>= 1) + amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, offset)); + + float d = amax / 127.0f; + float id = (d > 0.0f) ? 1.0f / d : 0.0f; + int32_t q = __float2int_rn(val * id); + q = max(-128, min(127, q)); + + if (lane % 2 == 0) + dst->qs_even[lane / 2] = static_cast(q); + else + dst->qs_odd[lane / 2] = static_cast(q); + + if (lane == 0) { + dst->d = d; + } +} + +// --------------------------------------------------------------------------- +// W4A8 dp4a matvec kernel +// --------------------------------------------------------------------------- + +__global__ void __launch_bounds__(MV_THREADS) + int4_w4a8_matvec_kernel( + const uint8_t* __restrict__ qdata, + const __nv_bfloat16* __restrict__ w_scale, + const __nv_bfloat16* __restrict__ w_zero, + const Q8Block* __restrict__ q8, + __nv_bfloat16* __restrict__ out, + int32_t N, + int32_t K, + int32_t gs_shift) { + const int32_t n = blockIdx.x * MV_NWARPS + threadIdx.y; + const int32_t m = blockIdx.y; + if (n >= N) + return; + + const int32_t K_half = K / 2; + const int32_t lane_id = threadIdx.x; + const int32_t n_q8_blocks = K / Q8_BLOCK_SIZE; + + const uint8_t* qrow = qdata + static_cast(n) * K_half; + const __nv_bfloat16* scale_base = w_scale + n; + const __nv_bfloat16* zero_base = w_zero + n; + const int32_t scale_stride = N; + const Q8Block* q8_row = q8 + static_cast(m) * n_q8_blocks; + + const uint4* qrow16 = reinterpret_cast(qrow); + const int32_t K_half_16 = K_half / 16; + + float sum = 0.0f; + + int32_t prev_g = -1; + float ws = 0.0f, wz = 0.0f; + + for (int32_t i = lane_id; i < K_half_16; i += MV_WARP_SIZE) { + uint4 packed16 = __ldg(&qrow16[i]); + int32_t k_base = i * 32; + uint32_t words[4] = {packed16.x, packed16.y, packed16.z, packed16.w}; + +#pragma unroll + for (int32_t w = 0; w < 4; w++) { + uint32_t packed = words[w]; + int32_t k_word = k_base + w * 8; + int32_t g = k_word >> gs_shift; + + if (g != prev_g) { + ws = __bfloat162float(__ldg(&scale_base[g * scale_stride])); + wz = __bfloat162float(__ldg(&zero_base[g * scale_stride])); + prev_g = g; + } + + int32_t vi_lo = packed & 0x0F0F0F0F; + int32_t vi_hi = (packed >> 4) & 0x0F0F0F0F; + + int32_t q8_block_idx = k_word / Q8_BLOCK_SIZE; + int32_t q8_half_offset = (k_word % Q8_BLOCK_SIZE) / 2; + const Q8Block* qb = &q8_row[q8_block_idx]; + + int32_t a_even = *reinterpret_cast( + qb->qs_even + q8_half_offset); + int32_t a_odd = *reinterpret_cast( + qb->qs_odd + q8_half_offset); + + int32_t dp = __dp4a(vi_lo, a_even, 0); + dp = __dp4a(vi_hi, a_odd, dp); + + float a_scale = qb->d; + + int32_t a_sum8 = __dp4a(0x01010101, a_even, 0); + a_sum8 = __dp4a(0x01010101, a_odd, a_sum8); + + sum += ws * a_scale * + (static_cast(dp) - wz * static_cast(a_sum8)); + } + } + + for (int offset = MV_WARP_SIZE / 2; offset > 0; offset >>= 1) + sum += __shfl_xor_sync(0xffffffff, sum, offset); + + if (lane_id == 0) + out[static_cast(m) * N + n] = __float2bfloat16(sum); +} + +// --------------------------------------------------------------------------- +// Persistent Q8 buffer (lazy init, not thread-safe — single-stream only) +// --------------------------------------------------------------------------- + +static Q8Block* g_q8_buf = nullptr; +static size_t g_q8_buf_size = 0; + +static Q8Block* get_q8_buffer(size_t needed) { + if (g_q8_buf_size < needed) { + if (g_q8_buf) + cudaFree(g_q8_buf); + cudaError_t err = cudaMalloc(&g_q8_buf, needed); + ET_CHECK_MSG( + err == cudaSuccess, + "cudaMalloc failed for Q8 buffer: %s", + cudaGetErrorString(err)); + g_q8_buf_size = needed; + } + return g_q8_buf; +} + +// --------------------------------------------------------------------------- +// Main entry point +// --------------------------------------------------------------------------- + +void _int4_plain_mm_cuda( + const Tensor& A, // [M, K] bf16 + const Tensor& qdata, // [N, K//2] uint8 + const Tensor& scale, // [K//gs, N] bf16 + const Tensor& zero, // [K//gs, N] bf16 + int64_t group_size, + Tensor* output) { // [M, N] bf16, pre-allocated + int32_t M = A.size(0); + int32_t K = A.size(1); + int32_t N = qdata.size(0); + + ET_CHECK(A.dtype() == c10::ScalarType::BFloat16); + ET_CHECK( + qdata.dtype() == c10::ScalarType::Byte || + qdata.dtype() == c10::ScalarType::Char); + ET_CHECK(scale.dtype() == c10::ScalarType::BFloat16); + ET_CHECK(zero.dtype() == c10::ScalarType::BFloat16); + ET_CHECK(A.dim() == 2); + ET_CHECK(qdata.dim() == 2); + ET_CHECK(qdata.size(1) == K / 2); + ET_CHECK(scale.dim() == 2); + ET_CHECK(scale.size(1) == N); + ET_CHECK(zero.dim() == 2); + ET_CHECK(zero.size(1) == N); + + int32_t gs = static_cast(group_size); + ET_CHECK_MSG( + gs > 0 && (gs & (gs - 1)) == 0, + "group_size=%d must be a power of 2", + gs); + ET_CHECK_MSG( + K >= Q8_BLOCK_SIZE && K % Q8_BLOCK_SIZE == 0, + "K=%d must be a positive multiple of %d for dp4a kernel", + K, + Q8_BLOCK_SIZE); + + auto stream_result = getCurrentCUDAStream(0); + ET_CHECK_MSG(stream_result.ok(), "Failed to get CUDA stream"); + cudaStream_t stream = stream_result.get(); + + int32_t gs_shift = log2_pow2(gs); + + // Quantize activations to INT8 + int32_t n_q8_blocks = K / Q8_BLOCK_SIZE; + size_t q8_bytes = static_cast(M) * n_q8_blocks * sizeof(Q8Block); + Q8Block* q8_buf = get_q8_buffer(q8_bytes); + + constexpr int32_t Q8_WARPS = 8; + int32_t blocks_per_m = (n_q8_blocks + Q8_WARPS - 1) / Q8_WARPS; + dim3 q8_grid(blocks_per_m, M); + dim3 q8_block(MV_WARP_SIZE, Q8_WARPS); + quantize_activations_q8_kernel<<>>( + reinterpret_cast(A.data_ptr()), + q8_buf, + K); + + // dp4a matvec + dim3 grid((N + MV_NWARPS - 1) / MV_NWARPS, M); + dim3 block(MV_WARP_SIZE, MV_NWARPS); + int4_w4a8_matvec_kernel<<>>( + reinterpret_cast(qdata.data_ptr()), + reinterpret_cast(scale.data_ptr()), + reinterpret_cast(zero.data_ptr()), + q8_buf, + reinterpret_cast<__nv_bfloat16*>(output->data_ptr()), + N, K, gs_shift); +} + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/int4_plain_mm.h b/backends/cuda/runtime/shims/int4_plain_mm.h new file mode 100644 index 00000000000..0935937cd7a --- /dev/null +++ b/backends/cuda/runtime/shims/int4_plain_mm.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +namespace executorch::backends::cuda { + +using executorch::backends::aoti::AOTITorchError; +using executorch::backends::aoti::Tensor; + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * INT4 matrix multiplication reading plain nibble-packed weights. + * + * Weight format: [N, K//2] uint8, two INT4 values per byte + * (low nibble = even k, high nibble = odd k). + * Scale: [K//group_size, N] bf16 per-group scales (Int4Tensor layout). + * Zero: [K//group_size, N] bf16 per-group zero points. + * W4A8 dp4a matvec: dynamically quantizes activations to INT8, + * then uses dp4a for fused int4×int8 dot products. + * + * @param self Input activation [M, K] bf16 + * @param qdata Packed weights [N, K//2] uint8 + * @param scale Per-group scales [K//group_size, N] bf16 + * @param zero Per-group zero points [K//group_size, N] bf16 + * @param group_size Quantization group size (32, 64, 128) + * @param ret0 Output [M, N] bf16 + */ +AOTI_SHIM_EXPORT AOTITorchError aoti_torch_cuda_int4_plain_mm( + Tensor* self, + Tensor* qdata, + Tensor* scale, + Tensor* zero, + int64_t group_size, + Tensor** ret0); + +#ifdef __cplusplus +} +#endif + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/tests/CMakeLists.txt b/backends/cuda/runtime/shims/tests/CMakeLists.txt index aec5219d680..62e9180d603 100644 --- a/backends/cuda/runtime/shims/tests/CMakeLists.txt +++ b/backends/cuda/runtime/shims/tests/CMakeLists.txt @@ -48,6 +48,11 @@ set(CUDA_SHIM_TESTS test_aoti_torch_assign_tensors_out ) +# CUDA-specific tests requiring GPU kernels +set(CUDA_KERNEL_TESTS test_aoti_torch_cuda__weight_int4pack_mm + test_aoti_torch_cuda_int4_plain_mm +) + enable_testing() foreach(test_name ${CUDA_SHIM_TESTS}) @@ -67,3 +72,21 @@ foreach(test_name ${CUDA_SHIM_TESTS}) add_test(NAME ${test_name} COMMAND ${test_name}) endforeach() + +foreach(test_name ${CUDA_KERNEL_TESTS}) + add_executable(${test_name} ${test_name}.cpp) + + target_include_directories( + ${test_name} PRIVATE ${EXECUTORCH_ROOT}/.. ${EXECUTORCH_ROOT} + ${CUDAToolkit_INCLUDE_DIRS} + ) + + target_compile_definitions(${test_name} PRIVATE CUDA_AVAILABLE=1) + + target_link_libraries( + ${test_name} PRIVATE GTest::gtest GTest::gtest_main aoti_cuda_shims + executorch_core CUDA::cudart + ) + + add_test(NAME ${test_name} COMMAND ${test_name}) +endforeach() diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_int4_plain_mm.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_int4_plain_mm.cpp new file mode 100644 index 00000000000..ab18e33c713 --- /dev/null +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_int4_plain_mm.cpp @@ -0,0 +1,397 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using executorch::backends::cuda::aoti_torch_cuda_int4_plain_mm; +using executorch::backends::cuda::aoti_torch_empty_strided; +using executorch::backends::cuda::AOTITorchError; +using executorch::runtime::Error; +namespace slim_c10 = executorch::backends::aoti::slim::c10; + +using Tensor = executorch::backends::aoti::slim::SlimTensor; + +class AOTITorchInt4PlainMMTest : public ::testing::Test { + protected: + void SetUp() override { + et_pal_init(); + + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available"; + } + } + + Tensor* create_tensor( + const std::vector& sizes, + slim_c10::ScalarType dtype) { + Tensor* tensor; + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + nullptr, + static_cast(dtype), + static_cast(slim_c10::DeviceType::CUDA), + 0, + &tensor); + return (error == Error::Ok) ? tensor : nullptr; + } + + Tensor* create_bf16(const std::vector& sizes) { + return create_tensor(sizes, slim_c10::ScalarType::BFloat16); + } + + Tensor* create_uint8(const std::vector& sizes) { + return create_tensor(sizes, slim_c10::ScalarType::Byte); + } + + // Upload raw bytes to a CUDA tensor. + void upload(Tensor* t, const void* host_data, size_t bytes) { + cudaMemcpy(t->data_ptr(), host_data, bytes, cudaMemcpyHostToDevice); + } + + // Download CUDA tensor to host buffer. + void download(const Tensor* t, void* host_data, size_t bytes) { + cudaMemcpy(host_data, t->data_ptr(), bytes, cudaMemcpyDeviceToHost); + } + + // Run the shim and return the output tensor (asserts success). + Tensor* run( + Tensor* A, + Tensor* qdata, + Tensor* scale, + Tensor* zero, + int64_t group_size) { + Tensor* output = nullptr; + AOTITorchError error = aoti_torch_cuda_int4_plain_mm( + A, qdata, scale, zero, group_size, &output); + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(output, nullptr); + return output; + } + + // Check output bf16 values against expected, with absolute tolerance. + void check_bf16_output( + Tensor* output, + const uint16_t* expected_data, + int64_t count, + float atol = 0.1f) { + std::vector actual(count); + download(output, actual.data(), count * sizeof(uint16_t)); + cudaDeviceSynchronize(); + + for (int64_t i = 0; i < count; i++) { + // Convert bf16 raw bits to float: bf16 is the upper 16 bits of float32. + uint32_t actual_bits = static_cast(actual[i]) << 16; + uint32_t expected_bits = static_cast(expected_data[i]) << 16; + float actual_f, expected_f; + memcpy(&actual_f, &actual_bits, sizeof(float)); + memcpy(&expected_f, &expected_bits, sizeof(float)); + + EXPECT_NEAR(actual_f, expected_f, atol) + << "Mismatch at index " << i << ": actual=" << actual_f + << " expected=" << expected_f; + } + } +}; + +// MultiGroupRandom: M=1, N=4, K=32, gs=16 +// scale/zero layout: [K//gs=2, N=4] +TEST_F(AOTITorchInt4PlainMMTest, MultiGroupRandom) { + int64_t M = 1, K = 32, N = 4, gs = 16; + + // clang-format off + uint8_t qdata_host[] = { + 0x36, 0xEC, 0x7A, 0x4C, 0x96, 0x62, 0xAA, 0x47, + 0x73, 0x27, 0x45, 0x71, 0xDB, 0x15, 0xBF, 0x04, + 0x9B, 0xC5, 0x8B, 0xA0, 0xEA, 0xF9, 0xBB, 0xEF, + 0xDD, 0xDE, 0xB2, 0x36, 0x8F, 0x42, 0x62, 0x84, + 0x16, 0x83, 0xDB, 0x91, 0x98, 0x14, 0xB3, 0xBE, + 0xB6, 0x7C, 0x2E, 0x0D, 0x13, 0x37, 0xD1, 0x55, + 0x39, 0xC5, 0x1E, 0xB9, 0x91, 0x3D, 0xED, 0xEF, + 0xD7, 0xB6, 0xD8, 0x47, 0xCF, 0xE1, 0x74, 0x89 + }; + uint16_t scale_host[] = {0xBD86, 0x3DB0, 0xBE26, 0xBE0F, 0xBCC3, 0xBD4E, 0xBE7D, 0xBDBE}; + uint16_t zero_host[] = {0x4100, 0x4100, 0x4100, 0x4100, 0x4100, 0x4100, 0x4100, 0x4100}; + uint16_t A_host[] = {0x3FAD, 0xBF3D, 0x3F9E, 0x4002, 0x3F34, 0x3F9B, 0x3F49, 0x3F8F, 0x3FAD, 0x3DD8, 0x3DFA, 0xBFA5, 0xBF02, 0xBE45, 0x3F97, 0x3F5F, 0xBF85, 0x3DFD, 0x3EDE, 0x3E42, 0xBF86, 0xBE84, 0xBF06, 0x3F9E, 0xBF22, 0x3FDE, 0xBF2E, 0x3E6B, 0x3F72, 0xBEE5, 0x3EBB, 0xC00F}; + uint16_t expected[] = {0xBFCC, 0x3FB5, 0x4046, 0xC01E}; + // clang-format on + + Tensor* A = create_bf16({M, K}); + Tensor* qdata = create_uint8({N, K / 2}); + Tensor* scale = create_bf16({K / gs, N}); + Tensor* zero = create_bf16({K / gs, N}); + upload(A, A_host, sizeof(A_host)); + upload(qdata, qdata_host, sizeof(qdata_host)); + upload(scale, scale_host, sizeof(scale_host)); + upload(zero, zero_host, sizeof(zero_host)); + + Tensor* output = run(A, qdata, scale, zero, gs); + ASSERT_NE(output, nullptr); + EXPECT_EQ(output->size(0), M); + EXPECT_EQ(output->size(1), N); + check_bf16_output(output, expected, M * N, 0.5f); +} + +// SingleGroup: M=1, N=8, K=32, gs=32 +// scale/zero layout: [K//gs=1, N=8] +TEST_F(AOTITorchInt4PlainMMTest, SingleGroup) { + int64_t M = 1, K = 32, N = 8, gs = 32; + + // clang-format off + uint8_t qdata_host[] = { + 0x31, 0x89, 0x89, 0x42, 0x45, 0x71, 0x3E, 0x17, + 0x01, 0xBD, 0xB6, 0x74, 0x02, 0x8C, 0x48, 0xB9, + 0xF7, 0xFA, 0xEB, 0xE5, 0xC4, 0xE9, 0x91, 0x50, + 0x9F, 0x33, 0xA6, 0xB2, 0xC5, 0xC0, 0xB5, 0xC1, + 0x2B, 0xDB, 0x1F, 0xB9, 0xC1, 0xCD, 0x83, 0x98, + 0x92, 0xB8, 0x70, 0xBD, 0x23, 0x60, 0x0D, 0xB2, + 0x3A, 0xC2, 0xB8, 0x3A, 0x5D, 0x5D, 0xC9, 0x14, + 0xDD, 0xEF, 0xBF, 0xBE, 0x4C, 0x79, 0xE6, 0xBB, + 0x75, 0xBA, 0x05, 0x73, 0xDC, 0x9B, 0xD5, 0x77, + 0x88, 0xE0, 0x32, 0x04, 0xB8, 0xE0, 0xA9, 0x80, + 0xB4, 0xD1, 0x70, 0x29, 0xFA, 0x7A, 0xA6, 0x1C, + 0x24, 0x86, 0xD2, 0xDB, 0x2E, 0x27, 0xF3, 0xEF, + 0xAD, 0xA2, 0x16, 0xEB, 0x6E, 0xFF, 0x3F, 0xAB, + 0x6C, 0x47, 0x94, 0x29, 0xB7, 0x59, 0xE5, 0x51, + 0x20, 0xC7, 0x60, 0x27, 0x68, 0x4B, 0x52, 0xFD, + 0x10, 0x07, 0xB5, 0x53, 0x89, 0x3E, 0xA1, 0xDE + }; + uint16_t scale_host[] = {0xBD36, 0x3B22, 0xBDD0, 0x3C6E, 0x3D9A, 0xBE63, 0xBE50, 0x3D28}; + uint16_t zero_host[] = {0x4100, 0x4100, 0x4100, 0x4100, 0x4100, 0x4100, 0x4100, 0x4100}; + uint16_t A_host[] = {0xBFDD, 0x3F43, 0x3EBF, 0xBF3D, 0xBF8E, 0xBE61, 0xBFB3, 0xBF32, 0xBF06, 0xBF8E, 0xBD93, 0x3E29, 0x3F96, 0x3E1D, 0xBFEC, 0xBEA5, 0xBF44, 0xC01C, 0xBF14, 0x3E92, 0xBF08, 0x3EA5, 0xBF08, 0x3E05, 0xBDC4, 0xBD97, 0xBFA1, 0xBE62, 0xBEDF, 0xBFFC, 0xBD87, 0xBFA5}; + uint16_t expected[] = {0xC031, 0x3BF8, 0x3E81, 0xBF19, 0x3FCB, 0xBF56, 0x4076, 0x3F20}; + // clang-format on + + Tensor* A = create_bf16({M, K}); + Tensor* qdata = create_uint8({N, K / 2}); + Tensor* scale = create_bf16({K / gs, N}); + Tensor* zero = create_bf16({K / gs, N}); + upload(A, A_host, sizeof(A_host)); + upload(qdata, qdata_host, sizeof(qdata_host)); + upload(scale, scale_host, sizeof(scale_host)); + upload(zero, zero_host, sizeof(zero_host)); + + Tensor* output = run(A, qdata, scale, zero, gs); + ASSERT_NE(output, nullptr); + EXPECT_EQ(output->size(0), M); + EXPECT_EQ(output->size(1), N); + check_bf16_output(output, expected, M * N, 0.5f); +} + +// PrefillBatch: M=8, N=4, K=64, gs=32 +// scale/zero layout: [K//gs=2, N=4] +TEST_F(AOTITorchInt4PlainMMTest, PrefillBatch) { + int64_t M = 8, K = 64, N = 4, gs = 32; + + // clang-format off + uint8_t qdata_host[] = { + 0xAD, 0x87, 0xDD, 0x5D, 0x57, 0x1B, 0x06, 0xE3, + 0xDE, 0xED, 0x8C, 0x7F, 0x1F, 0x75, 0x38, 0xDA, + 0xD7, 0x0B, 0xE7, 0xDB, 0x2B, 0x81, 0xE0, 0xA8, + 0xBC, 0xCB, 0xC9, 0x48, 0xCD, 0xD5, 0x4E, 0xA9, + 0x1D, 0x8D, 0x02, 0x7D, 0xEB, 0xE2, 0xD8, 0x0A, + 0x5D, 0xAC, 0x36, 0xA8, 0x27, 0x31, 0xCD, 0xE5, + 0xA3, 0x29, 0x08, 0x3D, 0x2B, 0x1F, 0x2A, 0xB0, + 0x45, 0x73, 0xD4, 0x02, 0x38, 0xEA, 0x0D, 0xA0, + 0xFA, 0x9A, 0xA4, 0x6E, 0x69, 0x35, 0x15, 0x7D, + 0xB5, 0x39, 0x26, 0x62, 0x0D, 0x8D, 0x1E, 0x27, + 0x9E, 0x01, 0x19, 0xAB, 0x17, 0xD2, 0xB3, 0x24, + 0x87, 0x34, 0x2E, 0xDD, 0x4E, 0x64, 0x6B, 0x20, + 0xA3, 0xAA, 0xED, 0x24, 0x80, 0xD0, 0x47, 0x90, + 0x6A, 0x45, 0x1E, 0x1C, 0xBD, 0x7D, 0xA4, 0x04, + 0x48, 0x1A, 0xD9, 0xCF, 0x29, 0xBC, 0x01, 0x07, + 0xB9, 0x00, 0x39, 0xB6, 0xC8, 0x2A, 0xE8, 0x17 + }; + uint16_t scale_host[] = {0xBE06, 0x3E0B, 0xBD82, 0x3DFA, 0x3D5E, 0xBE25, 0x3CBE, 0xBDD2}; + uint16_t zero_host[] = {0x4100, 0x4100, 0x4100, 0x4100, 0x4100, 0x4100, 0x4100, 0x4100}; + uint16_t A_host[] = {0xBF37, 0x3FB7, 0xBF20, 0xBCC0, 0x3F88, 0xBD3F, 0xC02C, 0x3F73, 0xBF9F, 0x3FCA, 0x3E04, 0xBE88, 0x3F5F, 0x4002, 0xBF52, 0x3F1A, 0x3F2B, 0x3F35, 0xBF20, 0xBFF0, 0xBEB0, 0x3F90, 0x3F67, 0xBF85, 0x3F8F, 0x3FEB, 0x3F3A, 0xBEF4, 0xBF31, 0x3CE2, 0xBF74, 0x3EBF, 0xBF4D, 0x400C, 0xBF9E, 0xBD45, 0x3E8E, 0x3FE6, 0x3F7C, 0xBEEB, 0x4027, 0xBF0F, 0x3F5E, 0x3E15, 0x3E69, 0x3F82, 0x3FB3, 0x3E10, 0xBF17, 0x3F88, 0xBFDB, 0x3FA5, 0x3F1B, 0xBE50, 0x3E64, 0xBF5A, 0x3E78, 0x3F1A, 0x3F06, 0xBF51, 0x0000, 0xBF25, 0x3F80, 0x3E34, 0x3EA8, 0xBE9F, 0xBF67, 0x3DF1, 0xBF5C, 0xC020, 0xBEA6, 0x3E7D, 0xBF51, 0x3F70, 0x3F2C, 0xBE25, 0xBEB4, 0xBDEB, 0x3EE4, 0x3E29, 0xBFE5, 0x3E1F, 0x3F03, 0xBF6C, 0xBE8D, 0xBEB9, 0x3FB0, 0xBD7F, 0xBFBB, 0xBF18, 0x3F28, 0xBF0F, 0xBEF5, 0x3F97, 0x3FA8, 0x3FAC, 0x3E51, 0x3F84, 0xBF81, 0xBF5B, 0xBF2E, 0x3FBF, 0xBFE6, 0xBFC7, 0x3F53, 0x3F30, 0xC00C, 0x3F24, 0xBE79, 0xBFB3, 0x3F73, 0x3F62, 0xBF41, 0x3D93, 0xBF8C, 0x3FF3, 0x3F17, 0x3F10, 0x3F1C, 0x3F1E, 0xBF88, 0x3F33, 0xBEAA, 0xBFE3, 0x3EB4, 0xBEAA, 0x3E3E, 0x3F37, 0xC013, 0x3F27, 0xBEF8, 0xBDAD, 0xBF02, 0x3F3E, 0x3EA5, 0xBE6C, 0xBF3D, 0xBF3C, 0x3F82, 0xBFC1, 0x3FC4, 0xBF32, 0xBFD2, 0xBE9B, 0x3EAD, 0x3FA5, 0x3F67, 0x3F10, 0x3F2C, 0xBFCD, 0x3BED, 0xBF91, 0xBF92, 0x3F25, 0x3EDB, 0x3EAB, 0x3F14, 0x3FB9, 0xBF92, 0xBE6E, 0x3F9E, 0x3EC5, 0xC01F, 0x3F90, 0x400E, 0xBFF4, 0xBEC4, 0x3D2E, 0x0000, 0xBF07, 0xBF0D, 0x3FD8, 0x3EC5, 0x3F78, 0xBF45, 0xBED8, 0xBE3D, 0xBF84, 0x3F44, 0xBF70, 0x3E40, 0x3F34, 0x3FCA, 0xBF7C, 0x3E8F, 0x3E87, 0x3F7B, 0x3FBC, 0xBF92, 0xBF77, 0x3F80, 0xBFCB, 0xC006, 0xBF23, 0x3FA6, 0x3F5A, 0x3E86, 0x3F65, 0xBF7E, 0x3D96, 0xBFCE, 0xBF2C, 0xBF44, 0x3DD7, 0x3F96, 0x3F08, 0xBEEC, 0x3EA8, 0x3F4C, 0xBF5F, 0x3EFA, 0xBF97, 0x3E89, 0x3FFE, 0x3FA8, 0xBF89, 0xBEC0, 0xBE90, 0x3EEF, 0x3F88, 0x3F60, 0x3F52, 0xBFD8, 0x3F1B, 0xBF44, 0x3F13, 0xBF09, 0x3FAE, 0xBF38, 0xBEBF, 0x3EE0, 0xBEF9, 0xBE7D, 0xBFDE, 0x3F11, 0xBFFE, 0x3E49, 0xBF78, 0x3F08, 0x3F30, 0x3D99, 0xBF8B, 0xBFB9, 0xBEE6, 0x3E43, 0x3E46, 0x4003, 0x3FBF, 0xBF3E, 0xBEDA, 0xBE98, 0x3F8C, 0xBE0D, 0xBD4B, 0xBF3C, 0x3E98, 0xBF34, 0xBFFC, 0xBF1F, 0xBF54, 0x3BC2, 0xBF90, 0xBE9F, 0xBE83, 0x3F88, 0xBF00, 0xBFBD, 0x3F88, 0x3E00, 0x3DDC, 0x3F1F, 0xBEAD, 0x3FB8, 0x3E57, 0x3F7C, 0xBE8D, 0x3F03, 0xC002, 0xBF1F, 0xBFE1, 0xBFAC, 0xBF6A, 0xBFE4, 0xBF28, 0x3E58, 0xBF73, 0xBFAD, 0xBFDE, 0xBFE1, 0xBEC3, 0xBEB9, 0xBF40, 0x3E80, 0x3F7B, 0x3E99, 0xBF49, 0x3F12, 0x3DC7, 0xBFFE, 0x3DC4, 0xBD03, 0xBE00, 0xBFE9, 0xBEFC, 0x3F2F, 0xBE76, 0x3F9C, 0x3F0C, 0x3F3E, 0x3FAE, 0xBF91, 0x3EC5, 0x3EE9, 0x3F49, 0x3F39, 0xBF35, 0x3F66, 0xBF31, 0x3F83, 0x3F6F, 0xBEDC, 0x3F24, 0x3F82, 0x3F09, 0xBEF2, 0xBFB6, 0xBF00, 0xBED8, 0xBFAE, 0x3F76, 0xBFCC, 0xBE58, 0x3CB9, 0x3E38, 0x3FD2, 0x3FDC, 0xBFA8, 0xBE3E, 0xBFB0, 0xBD7D, 0x3F2A, 0xBFD0, 0xBF30, 0xBFE0, 0xBFA7, 0xBF82, 0xBF9A, 0xBED2, 0xBF2A, 0x3FBC, 0xBF3F, 0xBF48, 0xBEB6, 0xBF0D, 0xBDE5, 0xBF18, 0xBF57, 0x3F18, 0x3F54, 0x3F1A, 0x3FA3, 0xBF9A, 0xBF1D, 0xBF64, 0x3EB1, 0xBF89, 0x3F54, 0xBFC0, 0x3F56, 0x3F09, 0x3FE2, 0xBD9D, 0x3F17, 0x3FAD, 0xBF0B, 0xBF43, 0xBE24, 0xBF1A, 0xBF32, 0x3FD4, 0x3E8F, 0x3F1A, 0xBF80, 0x3E08, 0xBF88, 0xBF1B, 0xBE8B, 0x3F43, 0xBFBD, 0x3F9D, 0xBEB3, 0xBFA5, 0xBDCB, 0x3FB0, 0xBE72, 0x3F9A, 0x3F40, 0x3EAD, 0x3F27, 0x3F3B, 0xBFD0, 0xBF56, 0x3E7F, 0x3E99, 0x4004, 0x3F4B, 0xBEFF, 0xBE7F, 0x3FE3, 0x3E8B, 0x3F41, 0xBFC3, 0x3EEC, 0x3ECC, 0xBFB6, 0x3FA2, 0xBF9F, 0xBF8A, 0xBF97, 0x3FE9, 0xBFF9, 0xBFCC, 0x3CFA, 0x3EFD, 0xBEC1, 0x3E1F, 0xBFF2, 0xBF07, 0xBEAC, 0xBED0, 0xBEE2, 0x3EDA, 0x3FBA, 0xBF2B, 0xBE80, 0x3CB6, 0x3E99, 0x3F32, 0xBDEC, 0x3E82, 0xBD46, 0xBF47, 0x3F82, 0xBEA4, 0x3F9A, 0xBFA1, 0x3FD2, 0x3FA4, 0x3F95, 0x3FB5, 0xBF18, 0x3EDC, 0xBFC6, 0x3FD5, 0x3F83, 0xBF75, 0x3F80, 0xBDBC, 0xBD63, 0xBD61, 0xBE5F, 0xBE88, 0x3EAC, 0x3E96, 0xBF20, 0xBEA5, 0x3FC3, 0x3F2B, 0x3E58, 0xBF10, 0x3E82, 0x3F3B, 0x3E95, 0x39A4, 0xBEAF, 0x3EE3, 0x3E29, 0xBEC9, 0x3EFF, 0x3D13, 0x3F60, 0xBF34, 0xBF9C, 0x3E4E, 0x3F33, 0x3F29, 0x3EE2, 0xBF95, 0xBF87, 0x3F2E, 0xBE9E, 0xBF9D, 0xBDB7, 0x3F85, 0xBF07, 0x3F5E, 0x3F15, 0x3EA3, 0xBF94, 0x3F7C, 0xBF99, 0xBE3F, 0x3F7C, 0xBF89, 0x3F00, 0xBDD4, 0x3F5D, 0x3F07, 0x3F1C, 0xBFC8, 0xBFF6, 0xBF3E}; + uint16_t expected[] = {0x40BD, 0xC0E3, 0x4037, 0x40A9, 0x406F, 0x4116, 0x3F8D, 0xC01F, 0xC039, 0xC043, 0x3F86, 0x410A, 0x3F07, 0xC100, 0x4019, 0x40D7, 0x40A9, 0x40F1, 0xBF89, 0x406F, 0x40FE, 0xBFB8, 0xBF88, 0x406A, 0x4004, 0x3EDE, 0x3E17, 0x4102, 0xC081, 0xC0BA, 0xBFFB, 0x3F25}; + // clang-format on + + Tensor* A = create_bf16({M, K}); + Tensor* qdata = create_uint8({N, K / 2}); + Tensor* scale = create_bf16({K / gs, N}); + Tensor* zero = create_bf16({K / gs, N}); + upload(A, A_host, sizeof(A_host)); + upload(qdata, qdata_host, sizeof(qdata_host)); + upload(scale, scale_host, sizeof(scale_host)); + upload(zero, zero_host, sizeof(zero_host)); + + Tensor* output = run(A, qdata, scale, zero, gs); + ASSERT_NE(output, nullptr); + EXPECT_EQ(output->size(0), M); + EXPECT_EQ(output->size(1), N); + check_bf16_output(output, expected, M * N, 0.5f); +} + +// GroupSize128: M=1, N=2, K=256, gs=128 +// scale/zero layout: [K//gs=2, N=2] +TEST_F(AOTITorchInt4PlainMMTest, GroupSize128) { + int64_t M = 1, K = 256, N = 2, gs = 128; + + // clang-format off + uint8_t qdata_host[] = { + 0xDE, 0x2E, 0x2C, 0x16, 0xA3, 0x9B, 0x16, 0x10, + 0xFE, 0x09, 0x0E, 0x9F, 0xE3, 0x4D, 0x00, 0x14, + 0x37, 0x42, 0x27, 0xF4, 0xD8, 0x70, 0x39, 0xCC, + 0x64, 0x51, 0xE6, 0x2B, 0xC1, 0x38, 0x5A, 0xB0, + 0xA2, 0x6A, 0x2D, 0xF4, 0xBB, 0xCD, 0x4E, 0xD6, + 0xA3, 0x60, 0xFE, 0x74, 0x6B, 0x17, 0xAB, 0x75, + 0x29, 0x84, 0xC1, 0x12, 0x31, 0x5C, 0x09, 0xB8, + 0x61, 0x33, 0x5B, 0x79, 0x29, 0xB3, 0x33, 0xE8, + 0x96, 0xE7, 0x36, 0x69, 0x6C, 0x6B, 0xD1, 0xAE, + 0x43, 0x13, 0xDF, 0x50, 0xD8, 0xE6, 0xBF, 0x98, + 0x1D, 0x30, 0x1D, 0x43, 0x7E, 0x6D, 0x1C, 0xE4, + 0x3C, 0x3C, 0x67, 0x68, 0xCD, 0xFC, 0x44, 0x07, + 0x90, 0x88, 0xA4, 0xAF, 0xDD, 0xE8, 0x16, 0x6E, + 0x78, 0xCA, 0x9C, 0xBA, 0x71, 0xCD, 0x1B, 0x97, + 0x8E, 0xF7, 0x31, 0x81, 0x7E, 0x15, 0x52, 0x22, + 0xDE, 0x39, 0xB2, 0x6E, 0x97, 0xE1, 0xB3, 0xCA, + 0xB8, 0x3A, 0xAD, 0xBA, 0x97, 0x9B, 0xBE, 0x33, + 0x5E, 0x6B, 0x80, 0x77, 0x44, 0x05, 0xC8, 0x29, + 0x15, 0xC5, 0xF9, 0xCB, 0xA2, 0x34, 0x30, 0xB7, + 0x27, 0x15, 0x57, 0x19, 0x2A, 0xAD, 0x58, 0x90, + 0x33, 0x13, 0x67, 0x13, 0x27, 0x6C, 0x95, 0x98, + 0xA4, 0x87, 0x95, 0x42, 0xCC, 0x33, 0x71, 0xCF, + 0x8D, 0x75, 0xE7, 0x7E, 0xCE, 0x05, 0xE0, 0xE8, + 0x1F, 0xF0, 0xEE, 0xB4, 0xAF, 0x45, 0x05, 0x17, + 0xA2, 0x72, 0x7A, 0xA3, 0x16, 0x48, 0xD1, 0xE6, + 0x95, 0xFA, 0x30, 0x31, 0x7E, 0x77, 0x35, 0xE6, + 0x3D, 0x15, 0x95, 0x31, 0x9D, 0x51, 0x6D, 0xDA, + 0x51, 0xE0, 0x07, 0xCE, 0x3A, 0xC0, 0x26, 0xA7, + 0xE5, 0x01, 0x20, 0x56, 0xEF, 0xED, 0xCD, 0x19, + 0xE5, 0xA3, 0x46, 0x7A, 0x1D, 0x6E, 0x30, 0x31, + 0x80, 0xEE, 0xED, 0x15, 0x34, 0x22, 0x0D, 0x2E, + 0xAB, 0xEE, 0x20, 0x97, 0xE0, 0xF3, 0xB9, 0xF7 + }; + uint16_t scale_host[] = {0xBB98, 0xBD63, 0xBCBE, 0xBD87}; + uint16_t zero_host[] = {0x4100, 0x4100, 0x4100, 0x4100}; + uint16_t A_host[] = {0x3F02, 0x3EB3, 0x3F22, 0xBD3F, 0x3F91, 0x3EFF, 0xBFD2, 0xC026, 0xBF3D, 0xBEBD, 0x3EFD, 0x4002, 0x3EF0, 0xBF2F, 0xBD4B, 0xBEE6, 0xBEA5, 0x3F78, 0x3FC3, 0xBE08, 0xBFC5, 0xBFFE, 0xBE4F, 0x3FA8, 0x3EE9, 0x3F60, 0xC03E, 0x3F88, 0x3F1C, 0xBF35, 0xBF8E, 0x0000, 0x3F03, 0x3ED9, 0xBE3D, 0x3ED0, 0xBF90, 0x3FF8, 0xBEDF, 0x3E62, 0x3F45, 0x3E68, 0xBF3E, 0xBDA0, 0x3F98, 0xC003, 0x3E51, 0xBDF8, 0xBED1, 0x3E78, 0x3FA4, 0xBEAD, 0x3F6C, 0x3E1F, 0x4000, 0xBED1, 0x3ECF, 0x3EC4, 0xBF50, 0x3F8E, 0x3FC5, 0xBF97, 0x3E18, 0x3EA1, 0xBFBD, 0xBFA5, 0x3EB0, 0xBF02, 0x3FD7, 0x3F6A, 0xBFEF, 0x3F9E, 0xBF3F, 0xBF90, 0xBFC0, 0xBFCE, 0x3F80, 0x3FFA, 0xBDB0, 0xBECD, 0xBF06, 0x3F75, 0xBFEC, 0x3E5E, 0xC00E, 0xBE63, 0xBF9A, 0x3FAB, 0xBEC8, 0xBF1B, 0x4017, 0xBE03, 0x3F4C, 0x3FA3, 0x3F43, 0xBF13, 0xBF4C, 0x3D7D, 0x3F28, 0x3EC5, 0x3F5A, 0x3F39, 0xBEED, 0x4011, 0x3DD0, 0x3F5E, 0x3F6E, 0x3FA1, 0xC008, 0x3F83, 0x3CB5, 0x3EE7, 0xBED1, 0x3F2D, 0x3A68, 0x3D21, 0x3DE7, 0xBE6B, 0x3DEE, 0x3EF5, 0xBFA6, 0x4042, 0x3FEA, 0xBDF3, 0xBF30, 0x3FC5, 0x3DCD, 0x3EA3, 0xBF0A, 0xBF1A, 0xBF41, 0x3F27, 0xBE1F, 0xBEFE, 0x3F25, 0xBE14, 0x3E33, 0xBFDC, 0x3EAE, 0xBF96, 0xBEFC, 0xBFC9, 0xC035, 0xBF2B, 0x3DE1, 0xBF3D, 0x0000, 0xC002, 0x3E77, 0xBEAB, 0x3EC7, 0xBEBB, 0x3F89, 0x3EAB, 0x0000, 0x3E84, 0xBEDF, 0xBE67, 0x3E47, 0x3DE5, 0x3FA6, 0xBF42, 0x3E58, 0x3E8C, 0x4007, 0x3F0A, 0xC00A, 0xBE0D, 0xBEC1, 0x3F62, 0x3D58, 0xBFD5, 0xBED0, 0xBEE3, 0xBF62, 0x3F4B, 0x3FC0, 0xBF34, 0x3F18, 0x3F73, 0x3F18, 0x3FE6, 0x3E6F, 0x3CD9, 0x3DE4, 0x3FA8, 0x3FC6, 0x3F7E, 0xBD1E, 0xBFA6, 0x3E84, 0x3F8E, 0xBE94, 0x3F63, 0xBE0F, 0x3F49, 0x3E16, 0x3EC0, 0xBF90, 0x401A, 0xBDE8, 0xBF06, 0xBEE2, 0x3FC6, 0xBFBA, 0x3EEA, 0x3F4A, 0xBFE0, 0x4009, 0xBFAA, 0xBF04, 0x3F9D, 0xBF9A, 0x3F06, 0x3FD0, 0x3FAB, 0x3EDB, 0xBF6C, 0x3FD7, 0xBEB6, 0xBF09, 0x0000, 0x3F78, 0x3FAB, 0x3F95, 0xBD5C, 0xBF66, 0xBDF9, 0xBD42, 0xBFDE, 0xBF11, 0xBE46, 0xBF76, 0xBF75, 0x3F31, 0xBEC5, 0xBFF3, 0xBF0F, 0x3EEE, 0xBED8, 0x3E2C, 0xBF3E, 0xBD82, 0x3F33, 0x3F24, 0xBFFF, 0xBF23, 0xBF8A, 0x3E0D, 0xBEC0, 0x3FAF, 0xBF76, 0xBF94, 0x3FAC, 0xBF21, 0x3FA0}; + uint16_t expected[] = {0xC013, 0xBF05}; + // clang-format on + + Tensor* A = create_bf16({M, K}); + Tensor* qdata = create_uint8({N, K / 2}); + Tensor* scale = create_bf16({K / gs, N}); + Tensor* zero = create_bf16({K / gs, N}); + upload(A, A_host, sizeof(A_host)); + upload(qdata, qdata_host, sizeof(qdata_host)); + upload(scale, scale_host, sizeof(scale_host)); + upload(zero, zero_host, sizeof(zero_host)); + + Tensor* output = run(A, qdata, scale, zero, gs); + ASSERT_NE(output, nullptr); + EXPECT_EQ(output->size(0), M); + EXPECT_EQ(output->size(1), N); + check_bf16_output(output, expected, M * N, 0.5f); +} + +TEST_F(AOTITorchInt4PlainMMTest, NullInputHandling) { + int64_t M = 2, K = 128, N = 64, gs = 32; + + Tensor* A = create_bf16({M, K}); + Tensor* qdata = create_uint8({N, K / 2}); + Tensor* scale = create_bf16({K / gs, N}); + Tensor* zero = create_bf16({K / gs, N}); + Tensor* output = nullptr; + + EXPECT_EQ( + aoti_torch_cuda_int4_plain_mm(nullptr, qdata, scale, zero, gs, &output), + Error::InvalidArgument); + EXPECT_EQ( + aoti_torch_cuda_int4_plain_mm(A, nullptr, scale, zero, gs, &output), + Error::InvalidArgument); + EXPECT_EQ( + aoti_torch_cuda_int4_plain_mm(A, qdata, nullptr, zero, gs, &output), + Error::InvalidArgument); + EXPECT_EQ( + aoti_torch_cuda_int4_plain_mm(A, qdata, scale, nullptr, gs, &output), + Error::InvalidArgument); + EXPECT_EQ( + aoti_torch_cuda_int4_plain_mm(A, qdata, scale, zero, gs, nullptr), + Error::InvalidArgument); +} + +TEST_F(AOTITorchInt4PlainMMTest, RealInt4TensorLayout) { + int64_t M = 1, K = 64, N = 8, gs = 32; + int64_t n_groups = K / gs; // 2 + + // clang-format off + // Data from quantize_weight(randn(8,64), QuantConfig(bits=4, gs=32, symmetric=False)) + uint8_t qdata_host[] = { + 0x04, 0x9A, 0x97, 0x63, 0x9B, 0x74, 0x4D, 0x9F, 0x4C, 0x2C, 0x88, 0x58, + 0x56, 0x8D, 0x51, 0x58, 0x87, 0xF5, 0x6A, 0xC7, 0x6C, 0x65, 0x30, 0x84, + 0xB6, 0xA1, 0x37, 0x48, 0x5B, 0x36, 0x68, 0xE7, 0x8E, 0x6A, 0x88, 0x82, + 0xAA, 0x9D, 0xAB, 0x0D, 0xB5, 0x81, 0xBE, 0xA3, 0xE4, 0x9F, 0x99, 0xC8, + 0x86, 0xC8, 0x5D, 0xAA, 0x86, 0x46, 0xBA, 0x9D, 0xDA, 0x06, 0xCA, 0xB7, + 0x53, 0xCD, 0xBF, 0x37, 0x25, 0xD4, 0x04, 0x36, 0xAF, 0x79, 0x57, 0x54, + 0x2A, 0xC9, 0x98, 0x98, 0x5A, 0x05, 0x43, 0x89, 0x84, 0x9A, 0x74, 0xC6, + 0xE6, 0x96, 0x6B, 0x09, 0xAF, 0xFB, 0x3C, 0xB3, 0x88, 0x63, 0x68, 0xAC, + 0x48, 0xB9, 0xC9, 0x34, 0xDC, 0x77, 0x8A, 0x8C, 0xFC, 0x75, 0xC7, 0x95, + 0xAD, 0xF5, 0x70, 0x9C, 0x4A, 0x79, 0x7C, 0x67, 0xAA, 0xAA, 0x0B, 0x8C, + 0xF0, 0x28, 0x91, 0xCD, 0xDA, 0x95, 0x3A, 0x84, 0xD9, 0x45, 0x89, 0x33, + 0x5B, 0x63, 0xB4, 0x39, 0xE9, 0xBF, 0x54, 0x40, 0xAB, 0xC8, 0x88, 0xCB, + 0x48, 0xBA, 0x7A, 0x03, 0xCB, 0x35, 0x74, 0x85, 0x67, 0x58, 0x12, 0xDC, + 0x5B, 0x02, 0x58, 0xF7, 0x8C, 0xC8, 0xA5, 0xFA, 0xAA, 0x8E, 0x4C, 0x1F, + 0xBB, 0x27, 0xC7, 0xEC, 0xB8, 0x69, 0x6F, 0x9F, 0x69, 0x69, 0x55, 0x79, + 0x34, 0x64, 0x56, 0x85, 0x67, 0x3F, 0xA8, 0x80, 0x7A, 0x77, 0x79, 0x05, + 0xA9, 0x10, 0xA7, 0x55, 0x4A, 0x48, 0xF8, 0x59, 0xB6, 0x5A, 0xBD, 0x55, + 0x8C, 0x96, 0x48, 0x6B, 0x9A, 0xC7, 0x97, 0x4B, 0x46, 0x65, 0xF7, 0x7B, + 0x78, 0x5C, 0x8A, 0xC5, 0x98, 0x0C, 0x45, 0x3B, 0x75, 0x9C, 0xC7, 0x58, + 0x63, 0x9A, 0x95, 0x78, 0x95, 0x69, 0xF8, 0x58, 0x65, 0x0A, 0x6B, 0x47, + 0x9C, 0x5C, 0x6A, 0x35, 0xA2, 0x8A, 0x74, 0x93, 0x28, 0x6D, 0xF0, 0xAB, + 0x23, 0xA6, 0xA6, 0x3A}; + // scale/zero are [K//gs, N] = [2, 8] — Int4Tensor's native layout + uint16_t scale_host[] = { + 0x3E46, 0x3E94, 0x3E8F, 0x3E94, 0x3E94, 0x3E8D, 0x3EA5, 0x3EA5, + 0x3E9F, 0x3EAD, 0x3E91, 0x3EA0, 0x3E88, 0x3EB7, 0x3E89, 0x3E92}; + uint16_t zero_host[] = { + 0x4100, 0x4110, 0x40A0, 0x4100, 0x4100, 0x4130, 0x4100, 0x40C0, + 0x40C0, 0x4100, 0x4100, 0x4100, 0x40E0, 0x40E0, 0x4110, 0x40C0}; + uint16_t A_host[] = { + 0x3E47, 0x400A, 0xBE30, 0x3F59, 0xBFF6, 0x3F27, 0xBF26, 0xBF51, + 0x3F07, 0xBFA3, 0xBFD5, 0xBE9B, 0xBDBE, 0x3E4C, 0xBF8F, 0x3FEE, + 0xBF37, 0x3F30, 0x3F4C, 0xBD09, 0x3FBF, 0xBF04, 0xBE82, 0x3FBD, + 0xBEA7, 0xBF94, 0x4017, 0xBF31, 0x3E3C, 0xBF97, 0xBFE7, 0xBFCA, + 0x3F57, 0x3FB6, 0x3F26, 0x3EDA, 0xBFCB, 0x3F1F, 0x3FD8, 0xBF2A, + 0x3F71, 0x3DA0, 0x3DAD, 0xBE10, 0x3EAA, 0xBF17, 0xBF89, 0x3DC3, + 0xBEAB, 0xBF07, 0xBF61, 0x3ECA, 0x3E28, 0xBE4A, 0x3F81, 0xBFAD, + 0xBEB3, 0xBF25, 0x3EE5, 0xBF0A, 0x3F9F, 0xBF51, 0x3E80, 0xBEDB}; + // Reference from Python: bf16 dequant + F.linear + uint16_t expected[] = { + 0x40B7, 0xC100, 0xC0E2, 0xC158, 0xBF29, 0xC11F, 0x4079, 0x407D}; + // clang-format on + + Tensor* A = create_bf16({M, K}); + Tensor* qdata = create_uint8({N, K / 2}); + // Note: scale/zero shape is [n_groups, N], NOT [N, n_groups] + Tensor* scale = create_bf16({n_groups, N}); + Tensor* zero = create_bf16({n_groups, N}); + upload(A, A_host, sizeof(A_host)); + upload(qdata, qdata_host, sizeof(qdata_host)); + upload(scale, scale_host, sizeof(scale_host)); + upload(zero, zero_host, sizeof(zero_host)); + + Tensor* output = run(A, qdata, scale, zero, gs); + ASSERT_NE(output, nullptr); + EXPECT_EQ(output->size(0), M); + EXPECT_EQ(output->size(1), N); + // W4A8 adds quantization noise vs bf16 reference — use wider tolerance + check_bf16_output(output, expected, M * N, 0.5f); +} diff --git a/backends/cuda/tests/test_int4_dispatch.py b/backends/cuda/tests/test_int4_dispatch.py new file mode 100644 index 00000000000..c793544ad48 --- /dev/null +++ b/backends/cuda/tests/test_int4_dispatch.py @@ -0,0 +1,212 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for Int4Tensor F.linear dispatch via int4_dispatch. + +These tests validate the eager / trace-time dispatch path — the same code +that torch.export traces through when building the AOTI graph. They do NOT +test the .pte runtime C shim (dp4a kernel); that is covered by +test_aoti_torch_cuda_int4_plain_mm.cpp (C++ unit tests) and +test_cuda_pipeline.py::TestCudaExport (end-to-end export + lower). + +The API contract: after importing int4_dispatch, F.linear and nn.Linear +with Int4Tensor weights produce numerically correct results. Tests verify +this across decode (M<=4), prefill (M>4), batched (3D), bias, group sizes, +and symmetric/asymmetric quantization. Correctness is measured as mean +relative error against the unquantized bf16 reference (not per-element +atol/rtol, which is too strict for INT4 quantization noise). + +Usage: + python -m pytest backends/cuda/tests/test_int4_dispatch.py -v +""" + +import unittest + +import executorch.backends.cuda.int4_dispatch # noqa: F401 + +import torch +import torch.nn as nn +import torch.nn.functional as F +from executorch.examples.models.gemma4_31b.quant.quantize import quantize_weight +from executorch.examples.models.gemma4_31b.quant.recipe import QuantConfig + + +def _require_cuda(tc: unittest.TestCase) -> None: + if not torch.cuda.is_available(): + tc.skipTest("CUDA required") + + +def _make_int4_linear(N, K, group_size=128, symmetric=False, bias=False): + """Build an nn.Linear with Int4Tensor weight and return (module, bf16_ref_weight). + + The bf16 reference is the original unquantized weight, so tests can + measure quantization error against the true value. + """ + w_bf16 = torch.randn(N, K, dtype=torch.bfloat16) + config = QuantConfig( + bits=4, group_size=group_size, symmetric=symmetric, method="min_max" + ) + int4_w = quantize_weight(w_bf16, config) + + module = nn.Linear(K, N, bias=bias, dtype=torch.bfloat16, device="cuda") + module.weight = nn.Parameter(int4_w.cuda(), requires_grad=False) + return module, w_bf16.cuda() + + +class TestFLinearDispatch(unittest.TestCase): + """F.linear with Int4Tensor weight produces correct results.""" + + def setUp(self): + _require_cuda(self) + torch.manual_seed(42) + + def _check(self, out, ref, tol=0.15): + rel_err = (out.float() - ref.float()).abs().mean() / ref.float().abs().mean() + self.assertLess(rel_err.item(), tol) + + def test_decode_m1(self): + module, w_ref = _make_int4_linear(256, 512) + x = torch.randn(1, 512, dtype=torch.bfloat16, device="cuda") + self._check(module(x), F.linear(x, w_ref)) + + def test_prefill_m64(self): + module, w_ref = _make_int4_linear(256, 512) + x = torch.randn(64, 512, dtype=torch.bfloat16, device="cuda") + self._check(module(x), F.linear(x, w_ref)) + + def test_3d_batched_input(self): + module, w_ref = _make_int4_linear(256, 512) + x = torch.randn(2, 32, 512, dtype=torch.bfloat16, device="cuda") + out = module(x) + self.assertEqual(out.shape, (2, 32, 256)) + self._check(out, F.linear(x, w_ref)) + + def test_with_bias(self): + module, w_ref = _make_int4_linear(256, 512, bias=True) + x = torch.randn(4, 512, dtype=torch.bfloat16, device="cuda") + self._check(module(x), F.linear(x, w_ref, module.bias)) + + def test_group_size_32(self): + module, w_ref = _make_int4_linear(128, 256, group_size=32) + x = torch.randn(1, 256, dtype=torch.bfloat16, device="cuda") + self._check(module(x), F.linear(x, w_ref)) + + def test_symmetric(self): + module, w_ref = _make_int4_linear(256, 512, symmetric=True) + x = torch.randn(1, 512, dtype=torch.bfloat16, device="cuda") + self._check(module(x), F.linear(x, w_ref)) + + +class TestMultiLayer(unittest.TestCase): + """Dispatch works across multiple Int4 linear modules in a model.""" + + def setUp(self): + _require_cuda(self) + torch.manual_seed(42) + + def _check(self, out, ref, tol=0.15): + rel_err = (out.float() - ref.float()).abs().mean() / ref.float().abs().mean() + self.assertLess(rel_err.item(), tol) + + def test_two_layer_mlp(self): + up, w_up = _make_int4_linear(512, 256) + down, w_down = _make_int4_linear(256, 512) + x = torch.randn(4, 256, dtype=torch.bfloat16, device="cuda") + out = down(F.silu(up(x))) + ref = F.linear(F.silu(F.linear(x, w_up)), w_down) + self._check(out, ref) + + def test_sequential_decode_steps(self): + module, w_ref = _make_int4_linear(256, 512) + for _ in range(4): + x = torch.randn(1, 512, dtype=torch.bfloat16, device="cuda") + self._check(module(x), F.linear(x, w_ref)) + + +class TestCompile(unittest.TestCase): + """Dispatch works under torch.compile.""" + + def setUp(self): + _require_cuda(self) + torch.manual_seed(42) + + def _check(self, out, ref, tol=0.15): + rel_err = (out.float() - ref.float()).abs().mean() / ref.float().abs().mean() + self.assertLess(rel_err.item(), tol) + + def test_compile_decode(self): + module, w_ref = _make_int4_linear(256, 512) + compiled = torch.compile(module, fullgraph=True) + x = torch.randn(1, 512, dtype=torch.bfloat16, device="cuda") + self._check(compiled(x), F.linear(x, w_ref)) + + def test_compile_prefill(self): + module, w_ref = _make_int4_linear(256, 512) + compiled = torch.compile(module, fullgraph=True) + x = torch.randn(64, 512, dtype=torch.bfloat16, device="cuda") + self._check(compiled(x), F.linear(x, w_ref)) + + def test_compile_matches_eager(self): + module, _ = _make_int4_linear(256, 512) + compiled = torch.compile(module, fullgraph=True) + x = torch.randn(4, 512, dtype=torch.bfloat16, device="cuda") + out_eager = module(x) + out_compiled = compiled(x) + self.assertTrue(torch.allclose(out_eager, out_compiled, atol=0.5)) + + +class TestDeviceMovement(unittest.TestCase): + """Int4Tensor weight survives device movement and still dispatches.""" + + def setUp(self): + _require_cuda(self) + torch.manual_seed(42) + + def _check(self, out, ref, tol=0.15): + rel_err = (out.float() - ref.float()).abs().mean() / ref.float().abs().mean() + self.assertLess(rel_err.item(), tol) + + def test_to_cuda(self): + w_bf16 = torch.randn(256, 512, dtype=torch.bfloat16) + config = QuantConfig(bits=4, group_size=128, symmetric=False, method="min_max") + int4_w = quantize_weight(w_bf16, config) + module = nn.Linear(512, 256, bias=False) + module.weight = nn.Parameter(int4_w, requires_grad=False) + module = module.to("cuda") + x = torch.randn(1, 512, dtype=torch.bfloat16, device="cuda") + self._check(module(x), F.linear(x, w_bf16.cuda())) + + +class TestLargeShapes(unittest.TestCase): + """Correctness at large production-scale layer shapes.""" + + def setUp(self): + _require_cuda(self) + torch.manual_seed(42) + + def _check(self, out, ref, tol=0.15): + rel_err = (out.float() - ref.float()).abs().mean() / ref.float().abs().mean() + self.assertLess(rel_err.item(), tol) + + def test_4096x5376_decode(self): + module, w_ref = _make_int4_linear(4096, 5376) + x = torch.randn(1, 5376, dtype=torch.bfloat16, device="cuda") + self._check(module(x), F.linear(x, w_ref)) + + def test_21504x5376_decode(self): + module, w_ref = _make_int4_linear(21504, 5376) + x = torch.randn(1, 5376, dtype=torch.bfloat16, device="cuda") + self._check(module(x), F.linear(x, w_ref)) + + def test_21504x5376_prefill(self): + module, w_ref = _make_int4_linear(21504, 5376) + x = torch.randn(128, 5376, dtype=torch.bfloat16, device="cuda") + self._check(module(x), F.linear(x, w_ref)) + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/cuda/tests/test_sdpa_splitk_replacement.py b/backends/cuda/tests/test_sdpa_splitk_replacement.py new file mode 100644 index 00000000000..414a1308777 --- /dev/null +++ b/backends/cuda/tests/test_sdpa_splitk_replacement.py @@ -0,0 +1,163 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Test ReplaceEdgeOpWithTritonOpPass split-K SDPA kernel selection. + +Exports a minimal model containing F.scaled_dot_product_attention through +the CUDA backend and verifies that the pass routes to split-K for decode +(L_q=1, large L_kv) and standard SDPA otherwise. +""" + +import logging +import unittest + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def _require_cuda(tc: unittest.TestCase) -> None: + if not torch.cuda.is_available(): + tc.skipTest("CUDA required") + + +class SDPAModule(nn.Module): + """Single-layer model with SDPA and a static KV cache buffer.""" + + def __init__(self, n_heads, n_kv_heads, head_dim, kv_len): + super().__init__() + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.head_dim = head_dim + hidden = n_heads * head_dim + self.q_proj = nn.Linear(hidden, n_heads * head_dim, bias=False) + self.k_proj = nn.Linear(hidden, n_kv_heads * head_dim, bias=False) + self.v_proj = nn.Linear(hidden, n_kv_heads * head_dim, bias=False) + self.register_buffer( + "k_cache", torch.zeros(1, n_kv_heads, kv_len, head_dim), persistent=False + ) + self.register_buffer( + "v_cache", torch.zeros(1, n_kv_heads, kv_len, head_dim), persistent=False + ) + + def forward(self, x: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor: + B, T, _ = x.shape + q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2) + k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) + v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) + self.k_cache.index_copy_(2, input_pos, k) + self.v_cache.index_copy_(2, input_pos, v) + y = F.scaled_dot_product_attention( + q, + self.k_cache, + self.v_cache, + enable_gqa=True, + ) + return y.transpose(1, 2).contiguous().view(B, T, -1) + + +def _export_through_cuda_backend(model, example_args): + """Export and lower through the CUDA backend (stops before to_executorch).""" + from executorch.backends.cuda.cuda_backend import CudaBackend + from executorch.backends.cuda.cuda_partitioner import CudaPartitioner + from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower + from torch.export import export + + with torch.no_grad(): + ep = export(model, example_args, strict=True) + + return to_edge_transform_and_lower( + {"decode": ep}, + partitioner={ + "decode": [ + CudaPartitioner( + [CudaBackend.generate_method_name_compile_spec("decode")] + ) + ], + }, + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ), + ) + + +def _capture_pass_logs(fn): + """Run fn and return replacement pass log messages.""" + pass_logger = logging.getLogger("executorch.backends.cuda.triton.replacement_pass") + prev_level = pass_logger.level + pass_logger.setLevel(logging.INFO) + messages = [] + handler = logging.Handler() + handler.emit = lambda record: messages.append(record.getMessage()) + pass_logger.addHandler(handler) + try: + return fn(), messages + finally: + pass_logger.removeHandler(handler) + pass_logger.setLevel(prev_level) + + +class TestSplitKReplacement(unittest.TestCase): + + def setUp(self): + _require_cuda(self) + + def test_large_kv_cache_uses_splitk(self): + """L_kv=4096 > threshold → split-K selected for decode.""" + model = SDPAModule(n_heads=4, n_kv_heads=2, head_dim=64, kv_len=4096).to( + torch.bfloat16 + ) + args = ( + torch.zeros(1, 1, 256, dtype=torch.bfloat16), + torch.tensor([0], dtype=torch.long), + ) + + _, msgs = _capture_pass_logs(lambda: _export_through_cuda_backend(model, args)) + + splitk = [m for m in msgs if "split-K" in m] + self.assertEqual(len(splitk), 1, f"Expected 1 split-K selection. Log: {msgs}") + self.assertIn("L_kv=4096", splitk[0]) + + def test_small_kv_cache_uses_standard(self): + """L_kv=512 <= threshold → standard SDPA, no split-K.""" + model = SDPAModule(n_heads=4, n_kv_heads=2, head_dim=64, kv_len=512).to( + torch.bfloat16 + ) + args = ( + torch.zeros(1, 1, 256, dtype=torch.bfloat16), + torch.tensor([0], dtype=torch.long), + ) + + _, msgs = _capture_pass_logs(lambda: _export_through_cuda_backend(model, args)) + + splitk = [m for m in msgs if "split-K" in m] + self.assertEqual(len(splitk), 0, f"Expected no split-K. Got: {splitk}") + + replaced = [m for m in msgs if "Replaced" in m] + self.assertTrue( + any("1 nodes" in m for m in replaced), + f"Expected 1 SDPA replaced with standard kernel. Log: {msgs}", + ) + + def test_non_pow2_head_dim_uses_standard(self): + """Non-power-of-2 head_dim → standard SDPA even with large L_kv.""" + model = SDPAModule(n_heads=4, n_kv_heads=2, head_dim=96, kv_len=8192).to( + torch.bfloat16 + ) + args = ( + torch.zeros(1, 1, 384, dtype=torch.bfloat16), + torch.tensor([0], dtype=torch.long), + ) + + _, msgs = _capture_pass_logs(lambda: _export_through_cuda_backend(model, args)) + + splitk = [m for m in msgs if "split-K" in m] + self.assertEqual(len(splitk), 0, f"Expected no split-K for D=96. Got: {splitk}") + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/cuda/triton/replacement_pass.py b/backends/cuda/triton/replacement_pass.py index b14628d29cf..628222e46f7 100644 --- a/backends/cuda/triton/replacement_pass.py +++ b/backends/cuda/triton/replacement_pass.py @@ -27,6 +27,8 @@ exir_ops.edge.aten.topk.default: triton.topk, } +_SPLITK_LKV_THRESHOLD = 2048 + class ReplaceEdgeOpWithTritonOpPass(PassBase): """ @@ -83,6 +85,34 @@ def call(self, graph_module: GraphModule) -> PassResult: # for rows larger than this threshold. _TOPK_MAX_N = 4096 + @staticmethod + def _pick_sdpa_kernel(node: Node): + """Choose between standard SDPA and split-K flash-decoding. + + Split-K partitions the KV sequence across many CTAs for better GPU + utilization at decode time (L_q=1). It wins when L_kv is large + (full-attention KV caches) but loses to the standard kernel for + small L_kv (sliding-window ring buffers) due to the overhead of + allocating partial buffers and running the reduction kernel. + """ + q_shape = node.args[0].meta["val"].shape + k_shape = node.args[1].meta["val"].shape + L_q, D = q_shape[2], q_shape[3] + L_kv = k_shape[2] + + if ( + isinstance(L_q, int) + and L_q == 1 + and isinstance(L_kv, int) + and L_kv > _SPLITK_LKV_THRESHOLD + and D > 0 + and (D & (D - 1)) == 0 # power of 2 + ): + logger.info(f"Using split-K decode SDPA (L_kv={L_kv}, D={D})") + return triton.sdpa_decode_splitk + + return triton.sdpa + def _should_replace_node(self, node: Node) -> bool: """ Check if a node should be replaced with a Triton kernel. @@ -128,6 +158,9 @@ def _replace_node_with_triton(self, graph_module: GraphModule, node: Node) -> No triton_kernel_fn = EDGE_TO_TRITON_KERNELS[target] + if target == exir_ops.edge.aten.scaled_dot_product_attention.default: + triton_kernel_fn = self._pick_sdpa_kernel(node) + # Create a new node with the Triton kernel with graph_module.graph.inserting_before(node): # The triton_kernel_fn is already registered as a custom op via @triton_op diff --git a/examples/models/gemma4/text_decoder/__init__.py b/examples/models/gemma4/text_decoder/__init__.py index 25d7c5c7a16..5f21130e27d 100644 --- a/examples/models/gemma4/text_decoder/__init__.py +++ b/examples/models/gemma4/text_decoder/__init__.py @@ -6,5 +6,13 @@ # LICENSE file in the root directory of this source tree. from .convert_weights import convert_hf_to_custom # noqa: F401 +from .gemma4_attention import ( # noqa: F401 + apply_rotary_emb, + apply_rotary_emb_single, + Gemma4KVCache, + rotate_half, +) from .gemma4_config import Gemma4Config # noqa: F401 +from .gemma4_decoder_layer import Gemma4MLP # noqa: F401 from .gemma4_model import create_gemma4_model, Gemma4Model # noqa: F401 +from .gemma4_norm import RMSNorm, RMSNormNoWeight # noqa: F401 diff --git a/examples/models/gemma4/text_decoder/gemma4_norm.py b/examples/models/gemma4/text_decoder/gemma4_norm.py index 17e42a43ca1..2c8fec67525 100644 --- a/examples/models/gemma4/text_decoder/gemma4_norm.py +++ b/examples/models/gemma4/text_decoder/gemma4_norm.py @@ -5,9 +5,46 @@ # pyre-unsafe # LICENSE file in the root directory of this source tree. +"""Gemma 4 RMSNorm — self-contained re-implementation. + +Numerically identical to ``transformers.models.gemma4.modeling_gemma4.Gemma4RMSNorm`` +(same float32 upcast and ``pow(mean_squared, -0.5)`` normalization), but +without the transformers import so this module is exportable and dep-light. +""" + from functools import partial -from transformers.models.gemma4.modeling_gemma4 import Gemma4RMSNorm as RMSNorm +import torch +from torch import nn + + +class RMSNorm(nn.Module): + """Gemma4 RMSNorm: ``y = (x / rms(x)) * weight``, computed in float32. + + Unlike Gemma 2/3 (``(1 + weight)``) Gemma 4 multiplies by ``weight`` directly. + Pass ``with_scale=False`` for the v-norm and the (unused-here) router norm, + which omit the learnable weight entirely. + """ + + def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True): + super().__init__() + self.eps = eps + self.with_scale = with_scale + if with_scale: + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x: torch.Tensor) -> torch.Tensor: + # Match transformers' use of pow(mean_squared, -0.5) over rsqrt; + # the comment there cites Torch/JAX compiler differences. + mean_squared = x.pow(2).mean(-1, keepdim=True) + self.eps + return x * torch.pow(mean_squared, -0.5) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + normed = self._norm(x.float()) + if self.with_scale: + normed = normed * self.weight.float() + return normed.type_as(x) + # V-norm in attention uses RMSNorm without learnable weight. RMSNormNoWeight = partial(RMSNorm, with_scale=False) diff --git a/examples/models/gemma4_31b/CMakeLists.txt b/examples/models/gemma4_31b/CMakeLists.txt new file mode 100644 index 00000000000..8d536a47fc5 --- /dev/null +++ b/examples/models/gemma4_31b/CMakeLists.txt @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.24) +project(gemma4_31b) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..) + +include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) + +set(_common_include_directories ${EXECUTORCH_ROOT}/..) + +# gflags +set(gflags_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/gflags) +find_package(gflags REQUIRED) + +# executorch +list(APPEND CMAKE_FIND_ROOT_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../..) +find_package(executorch CONFIG REQUIRED FIND_ROOT_PATH_BOTH) +executorch_target_link_options_shared_lib(executorch) + +set(link_libraries executorch gflags) + +# CPU ops (for the host-side helpers that aren't delegated to CUDA) +list(APPEND link_libraries optimized_native_cpu_ops_lib cpublas eigen_blas) +executorch_target_link_options_shared_lib(optimized_native_cpu_ops_lib) + +# Extensions +list( + APPEND + link_libraries + extension_llm_runner + extension_module + extension_data_loader + extension_tensor + extension_flat_tensor +) + +# CUDA backend (the only supported backend for this example for now) +if(EXECUTORCH_BUILD_CUDA) + find_package(CUDAToolkit REQUIRED) + list(APPEND link_libraries aoti_cuda_backend) + executorch_target_link_options_shared_lib(aoti_cuda_backend) + add_compile_definitions(EXECUTORCH_BUILD_CUDA) +else() + message(FATAL_ERROR "Set EXECUTORCH_BUILD_CUDA=ON") +endif() + +# Tokenizer (HuggingFace tokenizer.json) +list(APPEND link_libraries tokenizers::tokenizers) + +add_executable(gemma4_31b_runner main.cpp) +target_include_directories( + gemma4_31b_runner PUBLIC ${_common_include_directories} +) +target_link_libraries(gemma4_31b_runner PUBLIC ${link_libraries}) + +if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") + target_link_options_gc_sections(gemma4_31b_runner) + target_link_options(gemma4_31b_runner PRIVATE "LINKER:-s") +endif() diff --git a/examples/models/gemma4_31b/CMakePresets.json b/examples/models/gemma4_31b/CMakePresets.json new file mode 100644 index 00000000000..97ba7f4c57a --- /dev/null +++ b/examples/models/gemma4_31b/CMakePresets.json @@ -0,0 +1,52 @@ +{ + "version": 6, + "configurePresets": [ + { + "name": "gemma4-31b-base", + "hidden": true, + "binaryDir": "${sourceDir}/../../../cmake-out/examples/models/gemma4_31b", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_FIND_ROOT_PATH": "${sourceDir}/../../../cmake-out", + "CMAKE_PREFIX_PATH": "${sourceDir}/../../../cmake-out" + } + }, + { + "name": "gemma4-31b-cuda", + "displayName": "Gemma 4 31B runner (CUDA)", + "inherits": ["gemma4-31b-base"], + "cacheVariables": { + "EXECUTORCH_BUILD_CUDA": "ON" + }, + "condition": { + "type": "inList", + "string": "${hostSystemName}", + "list": ["Linux", "Windows"] + } + } + ], + "buildPresets": [ + { + "name": "gemma4-31b-cuda", + "displayName": "Build Gemma 4 31B runner (CUDA)", + "configurePreset": "gemma4-31b-cuda", + "targets": ["gemma4_31b_runner"] + } + ], + "workflowPresets": [ + { + "name": "gemma4-31b-cuda", + "displayName": "Configure and build Gemma 4 31B runner (CUDA)", + "steps": [ + { + "type": "configure", + "name": "gemma4-31b-cuda" + }, + { + "type": "build", + "name": "gemma4-31b-cuda" + } + ] + } + ] +} diff --git a/examples/models/gemma4_31b/README.md b/examples/models/gemma4_31b/README.md new file mode 100644 index 00000000000..6f567d739b7 --- /dev/null +++ b/examples/models/gemma4_31b/README.md @@ -0,0 +1,123 @@ +# Gemma 4 31B-IT + +Text-only export of Google's Gemma 4 31B-IT to ExecuTorch with INT4/INT8 +weight quantization. Currently supports the CUDA backend. + +For architecture and design notes see [model.md](model.md). + +## When to use which script + +The full bf16 weights for 31B (~62 GB) often don't fit in available RAM. The +recommended flow is to quantize once and reuse the quantized checkpoint for +both export and eager inference: + +| Script | Purpose | Peak memory | +|---|---|---| +| `quantize_and_save.py` | bf16 HF checkpoint → quantized checkpoint (one-time) | ~30 GB CPU | +| `export.py --prequantized ` | quantized checkpoint → `model.pte` + `model.ptd` | ~24 GB CPU + CUDA for packing | +| `inference.py --prequantized ` | quantized checkpoint → eager generation under `torch.compile` | ~24 GB GPU | +| `inference.py --gguf ` | GGUF file (Q4_K_M, etc.) → eager generation | ~24 GB GPU | +| `export.py --model-dir ` | one-shot bf16 → quantize → export (no intermediate file) | ~30 GB CPU + CUDA for packing | + +The quantized checkpoint is a safetensors file containing torchao tensor +subclasses (`Int4Tensor`, `IntxUnpackedToInt8Tensor`) and plain tensors. +Metadata records each subclass's type and attributes. No backend-specific +packing — packing for the target backend happens at load time via +`quant.pack_model()`. + +## Quantization recipes + +Two built-in recipes (see `quantize_and_save.py`): + +| Recipe | Description | +|---|---| +| `default` | INT4 min_max linears, INT8 per-axis embedding | +| `sensitive` | INT8 for edge-layer v_proj/down_proj, INT4 hqq elsewhere, INT8 per-axis embedding | + +## Prequantized checkpoint + +A prequantized checkpoint (sensitive recipe) is available on HuggingFace: + +```bash +huggingface-cli download SocialLocalMobile/gemma-4-31B-it-HQQ-INT4 --local-dir gemma-4-31B-it-HQQ-INT4 +``` + +> **Note**: This checkpoint is intended for development and testing of the +> ExecuTorch CUDA export pipeline. Output quality has not been formally +> evaluated against the base model. + +Use it directly with `--prequantized` in the export and inference scripts +below — no need to run `quantize_and_save.py`. + +## Quantize from scratch (optional) + +To quantize from the original bf16 checkpoint instead, pass +`--quant-recipe` to select a recipe (`default` or `sensitive`): + +```bash +python examples/models/gemma4_31b/quantize_and_save.py \ + --model-dir /path/to/gemma-4-31B-it \ + --output ./gemma4_31b_int4 \ + --quant-recipe sensitive +``` + +See [Quantization recipes](#quantization-recipes) above for details on each +recipe. Writes `model.safetensors`, `config.json`, and `tokenizer.json` into +`--output`. + +## Export to ExecuTorch + +```bash +python examples/models/gemma4_31b/export.py \ + --prequantized ./gemma4_31b_int4 \ + --output-dir ./gemma4_31b_exports \ + --max-seq-len 4096 \ + --backend cuda +``` + +Writes `model.pte` and `model.ptd` into `--output-dir`. + +## Eager inference + +```bash +python examples/models/gemma4_31b/inference.py \ + --prequantized ./gemma4_31b_int4 \ + --prompt "Write a short joke about saving RAM." \ + --max-new-tokens 128 \ + --temperature 0.8 +``` + +GGUF files from the community (e.g., Q4_K_M) can also be used directly: + +```bash +python examples/models/gemma4_31b/inference.py \ + --gguf ./gemma-4-31B-it-Q4_K_M.gguf \ + --tokenizer-path /path/to/tokenizer.json \ + --prompt "Hello" +``` + +Useful before spending the export+lowering time to confirm the quantized +model produces sensible text. + +## Build the runner + +```bash +make gemma4_31b-cuda +``` + +The binary lands at `cmake-out/examples/models/gemma4_31b/gemma4_31b_runner`. + +## Run the .pte + +```bash +./gemma4_31b_runner \ + --model_path ./gemma4_31b_exports/model.pte \ + --data_path ./gemma4_31b_exports/aoti_cuda_blob.ptd \ + --tokenizer_path ./gemma4_31b_int4/tokenizer.json \ + --prompt "Write a short joke about saving RAM." \ + --max_new_tokens 128 \ + --temperature 0.8 +``` + +For benchmarking, add `--cuda_graph` to capture the decode method in a CUDA +graph (decode is fully static — `T=1`). diff --git a/examples/models/gemma4_31b/__init__.py b/examples/models/gemma4_31b/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/examples/models/gemma4_31b/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/examples/models/gemma4_31b/export.py b/examples/models/gemma4_31b/export.py new file mode 100644 index 00000000000..a96dba0d512 --- /dev/null +++ b/examples/models/gemma4_31b/export.py @@ -0,0 +1,337 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Export Gemma 4 31B-IT to ExecuTorch (.pte + .ptd). + +Two methods are exported and lowered together so they share KV-cache buffers: + - "decode": T=1, static shape, returns the next sampled token. + - "prefill": T>=2, dynamic shape, returns the next sampled token. + +Three input paths: + --prequantized Load a quantized checkpoint (from quantize_and_save.py) + and pack for the target backend. No re-quantization. + --gguf Load a GGUF file (e.g., Q4_K_M from the community). + --model-dir Load bf16 checkpoint, quantize, pack, and export + in one shot. + +Backends: + --backend cuda (default) CUDA via tinygemm INT4 + CudaPartitioner. +""" + +import argparse +import os + +import torch +import torch.nn as nn + +from executorch.examples.models.gemma4_31b.model import ( + Gemma4_31B, + Gemma4_31BConfig, + materialize_runtime_buffers, +) + + +# --------------------------------------------------------------------------- +# Load paths + + +def load_prequantized_model( + prequantized_dir: str, + max_seq_len: int = 4096, + backend: str = "cuda", +) -> tuple[Gemma4_31B, Gemma4_31BConfig]: + """Load a quantized checkpoint and pack for the target backend.""" + config = Gemma4_31BConfig.from_hf_config( + os.path.join(prequantized_dir, "config.json") + ) + config.max_seq_len = max_seq_len + + print("Building model on meta device...") + with torch.device("meta"): + model = Gemma4_31B(config) + + safetensors_path = os.path.join(prequantized_dir, "model.safetensors") + print(f"Loading quantized checkpoint from {safetensors_path}...") + _pack_for_backend(model, safetensors_path, backend) + model.eval() + + print(f"Model: {config.num_hidden_layers} layers, hidden={config.hidden_size}") + return model, config + + +def load_and_quantize( + model_dir: str, + recipe_name: str, + max_seq_len: int = 4096, + backend: str = "cuda", +) -> tuple[Gemma4_31B, Gemma4_31BConfig]: + """Load bf16 checkpoint, quantize, pack — one shot.""" + from executorch.examples.models.gemma4_31b.quant import pack_model, quantize_model + from executorch.examples.models.gemma4_31b.quantize_and_save import _RECIPES + + recipe = _RECIPES[recipe_name] + + print("Loading checkpoint (lazy, shard-by-shard)...") + model, config = Gemma4_31B.from_hf_checkpoint(model_dir, max_seq_len=max_seq_len) + + if model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr(): + print("Untying embed_tokens / lm_head...") + model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + + print(f"Quantizing with recipe '{recipe_name}'...") + state_dict = quantize_model(model, recipe, verbose=True) + + print(f"Packing for {backend}...") + with torch.device("meta"): + model = Gemma4_31B(config) + pack_model(model, state_dict, packers=_get_packers(backend)) + model.eval() + + print(f"Model: {config.num_hidden_layers} layers, hidden={config.hidden_size}") + return model, config + + +# --------------------------------------------------------------------------- +# Backend dispatch helpers + + +def _get_packers(backend: str) -> dict: + if backend == "cuda": + from executorch.examples.models.gemma4_31b.quant import DEFAULT_CUDA_PACKERS + + return DEFAULT_CUDA_PACKERS + raise ValueError(f"Unsupported backend: {backend!r}. Supported: 'cuda'.") + + +def _pack_for_backend(model: nn.Module, path: str, backend: str) -> None: + if backend == "cuda": + from executorch.examples.models.gemma4_31b.quant import load_and_pack_for_cuda + + load_and_pack_for_cuda(path, model) + else: + raise ValueError(f"Unsupported backend: {backend!r}. Supported: 'cuda'.") + + +# --------------------------------------------------------------------------- +# Export + lower + + +def export_and_lower( + model: Gemma4_31B, + config: Gemma4_31BConfig, + output_dir: str, + backend: str = "cuda", +) -> None: + """Export and lower the model to ExecuTorch for the given backend.""" + if backend == "cuda": + _export_cuda(model, config, output_dir) + else: + raise ValueError(f"Unsupported backend: {backend!r}. Supported: 'cuda'.") + + +def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -> None: + import gc + + import torch._inductor.config as inductor_config + + from executorch.backends.cuda.cuda_backend import CudaBackend + from executorch.backends.cuda.cuda_partitioner import CudaPartitioner + from executorch.exir import ( + EdgeCompileConfig, + ExecutorchBackendConfig, + to_edge_transform_and_lower, + ) + from executorch.exir.backend.compile_spec_schema import CompileSpec + from executorch.exir.passes import MemoryPlanningPass + from torch.export import Dim, export + + inductor_config.coordinate_descent_tuning = False + inductor_config.aot_inductor.compile_wrapper_opt_level = "O0" + + # Register Int4Tensor dispatch → executorch_cuda::int4_plain_mm shim + import executorch.backends.cuda.int4_dispatch # noqa: F401 + + materialize_runtime_buffers(model, dtype=torch.bfloat16) + + # Int4Tensor weights are used directly — no format conversion. + # F.linear dispatches to executorch_cuda::int4_plain_mm (CUDA shim). + # Both decode and prefill share the same nibble-packed weights. + + # Prefill (T>=2): shim does dequant+cuBLAS (optimal for large M). + max_prefill = min(config.max_seq_len - 1, config.sliding_window * 2) + seq_dim = Dim("seq_len", min=5, max=max_prefill) + print(f"Exporting prefill (T in [2, {max_prefill}])...") + with torch.no_grad(): + prefill_ep = export( + model, + ( + torch.zeros((1, max_prefill), dtype=torch.long), + torch.arange(max_prefill, dtype=torch.long), + torch.tensor([1.0], dtype=torch.float32), + ), + dynamic_shapes=({1: seq_dim}, {0: seq_dim}, None), + strict=True, + ) + + # Decode (T=1): same Int4Tensor weights, same format. No transform needed. + print("Exporting decode (T=1)...") + with torch.no_grad(): + decode_ep = export( + model, + ( + torch.tensor([[0]], dtype=torch.long), + torch.tensor([0], dtype=torch.long), + torch.tensor([1.0], dtype=torch.float32), + ), + strict=True, + ) + + del model + gc.collect() + + print("Lowering to ExecuTorch with CUDA backend...") + et_prog = to_edge_transform_and_lower( + {"decode": decode_ep, "prefill": prefill_ep}, + partitioner={ + "decode": [ + CudaPartitioner( + [ + CudaBackend.generate_method_name_compile_spec("decode"), + CompileSpec("low_memory_mode", b"ON"), + ] + ) + ], + "prefill": [ + CudaPartitioner( + [ + CudaBackend.generate_method_name_compile_spec("prefill"), + CompileSpec("low_memory_mode", b"ON"), + ] + ) + ], + }, + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ), + constant_methods={ + "get_max_seq_len": config.max_seq_len, + "get_vocab_size": config.vocab_size, + "get_n_layers": config.num_hidden_layers, + "get_max_prefill_chunk": max_prefill, + "use_kv_cache": True, + "use_sdpa_with_kv_cache": False, + "enable_dynamic_shape": True, + }, + ) + del decode_ep, prefill_ep + gc.collect() + + et_program = et_prog.to_executorch( + config=ExecutorchBackendConfig( + extract_delegate_segments=True, + do_quant_fusion_and_const_prop=True, + memory_planning_pass=MemoryPlanningPass( + alloc_graph_input=False, + share_mutable_buffers=True, + ), + emit_mutable_buffer_names=True, + ), + ) + + del et_prog + gc.collect() + + os.makedirs(output_dir, exist_ok=True) + pte_path = os.path.join(output_dir, "model.pte") + print(f"Saving to {pte_path}...") + with open(pte_path, "wb") as f: + et_program.write_to_file(f) + print(f" {os.path.getsize(pte_path) / 1024**2:.1f} MB") + + if et_program._tensor_data: + et_program.write_tensor_data_to_file(output_dir) + print(f" Saved tensor data (.ptd) to {output_dir}/") + print("Done.") + + +# --------------------------------------------------------------------------- +# CLI + + +def main() -> None: + from executorch.examples.models.gemma4_31b.quantize_and_save import _RECIPES + + parser = argparse.ArgumentParser(description="Export Gemma 4 31B-IT to ExecuTorch.") + src = parser.add_mutually_exclusive_group(required=True) + src.add_argument( + "--model-dir", + default=None, + help="HuggingFace model dir. Triggers load + quantize + export.", + ) + src.add_argument( + "--prequantized", + default=None, + help="Path to a quantized checkpoint directory. Skips quantization.", + ) + src.add_argument( + "--gguf", + default=None, + help="Path to a GGUF file (e.g., gemma-4-31B-it-Q4_K_M.gguf).", + ) + parser.add_argument( + "--output-dir", + default="./gemma4_31b_exports", + help="Output directory for model.pte / model.ptd.", + ) + parser.add_argument( + "--max-seq-len", + type=int, + default=4096, + help="KV cache size.", + ) + parser.add_argument( + "--quant-recipe", + default="default", + choices=list(_RECIPES), + help="Quantization recipe (only with --model-dir).", + ) + parser.add_argument( + "--backend", + default="cuda", + choices=["cuda"], + help="Target backend for export.", + ) + args = parser.parse_args() + + if args.backend == "cuda" and not torch.cuda.is_available(): + parser.error("CUDA is required for the cuda backend.") + + if args.prequantized: + model, config = load_prequantized_model( + args.prequantized, + max_seq_len=args.max_seq_len, + backend=args.backend, + ) + elif args.gguf: + from executorch.examples.models.gemma4_31b.gguf_loader import load_gguf_model + + model, config = load_gguf_model( + args.gguf, max_seq_len=args.max_seq_len, backend=args.backend + ) + else: + model, config = load_and_quantize( + args.model_dir, + args.quant_recipe, + max_seq_len=args.max_seq_len, + backend=args.backend, + ) + + export_and_lower(model, config, args.output_dir, backend=args.backend) + + +if __name__ == "__main__": + main() diff --git a/examples/models/gemma4_31b/gguf_loader.py b/examples/models/gemma4_31b/gguf_loader.py new file mode 100644 index 00000000000..3e50991e553 --- /dev/null +++ b/examples/models/gemma4_31b/gguf_loader.py @@ -0,0 +1,161 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Load a GGUF file into a Gemma 4 31B model. + +Streams tensors one at a time via ``iter_gguf_tensors`` for low peak +memory, remaps GGUF names to model FQNs, handles tied embed/lm_head, +and packs for the target backend. + +Usage: + model, config = load_gguf_model("model.gguf", backend="cuda") +""" + +from typing import Optional + +import torch + +# GGUF pattern → model FQN pattern. ``{}`` is the layer index. +_KEY_MAP = { + "token_embd.weight": "embed_tokens.weight", + "output_norm.weight": "norm.weight", + # Per-layer attention + "blk.{}.attn_q.weight": "layers.{}.self_attn.q_proj.weight", + "blk.{}.attn_k.weight": "layers.{}.self_attn.k_proj.weight", + "blk.{}.attn_v.weight": "layers.{}.self_attn.v_proj.weight", + "blk.{}.attn_output.weight": "layers.{}.self_attn.o_proj.weight", + "blk.{}.attn_q_norm.weight": "layers.{}.self_attn.q_norm.weight", + "blk.{}.attn_k_norm.weight": "layers.{}.self_attn.k_norm.weight", + # Per-layer norms + "blk.{}.attn_norm.weight": "layers.{}.input_layernorm.weight", + "blk.{}.post_attention_norm.weight": "layers.{}.post_attention_layernorm.weight", + "blk.{}.ffn_norm.weight": "layers.{}.pre_feedforward_layernorm.weight", + "blk.{}.post_ffw_norm.weight": "layers.{}.post_feedforward_layernorm.weight", + # Per-layer MLP + "blk.{}.ffn_gate.weight": "layers.{}.mlp.gate_proj.weight", + "blk.{}.ffn_up.weight": "layers.{}.mlp.up_proj.weight", + "blk.{}.ffn_down.weight": "layers.{}.mlp.down_proj.weight", + # Per-layer scalar + "blk.{}.layer_output_scale.weight": "layers.{}.layer_scalar", +} + +_IGNORED_KEYS = {"rope_freqs.weight"} + + +def gguf_to_model_key(gguf_key: str) -> Optional[str]: + """Map a GGUF tensor name to a model FQN, or ``None`` to skip.""" + if gguf_key in _IGNORED_KEYS: + return None + + for gguf_pat, model_pat in _KEY_MAP.items(): + if "{}" not in gguf_pat: + if gguf_key == gguf_pat: + return model_pat + continue + prefix, suffix = gguf_pat.split("{}") + if gguf_key.startswith(prefix) and gguf_key.endswith(suffix): + layer_str = gguf_key[len(prefix) : len(gguf_key) - len(suffix)] + if layer_str.isdigit(): + return model_pat.replace("{}", layer_str) + + return None + + +def _resolve_tied_lm_head(model, embed_quant, packers): + """Handle tied embed/lm_head after streaming all tensors.""" + from executorch.examples.models.gemma4_31b.quant import pack_one + + lm_head = getattr(model.lm_head, "weight", None) + if lm_head is None or lm_head.device.type != "meta": + return + if embed_quant is not None: + pack_one(model, "lm_head.weight", embed_quant, packers) + else: + pack_one( + model, + "lm_head.weight", + model.embed_tokens.weight.data.clone(), + packers, + ) + + +def _validate_no_meta(model): + """Ensure all parameters have been loaded.""" + for fqn, p in model.named_parameters(): + if p.device.type == "meta": + raise RuntimeError( + f"Weight '{fqn}' not found in GGUF file " + f"(model/checkpoint version mismatch?)" + ) + for p in model.parameters(): + p.requires_grad_(False) + + +def load_gguf_model( + gguf_path: str, + max_seq_len: int = 4096, + backend: str = "cuda", +) -> tuple: + """Load a GGUF file, remap keys, and pack for the target backend. + + Streams tensors one at a time for low peak memory. + + GGUF ties ``embed_tokens`` and ``lm_head`` into a single Q4_K tensor. + We untie them: the embedding is dequantized to bf16 (``nn.Embedding`` + needs gather, which ``Int4TilePackedTo4dTensor`` does not support), + while ``lm_head`` keeps the original Q4_K quantization (``nn.Linear`` + matmul via tinygemm). + + Returns ``(model, config)``. + """ + from executorch.examples.models.gemma4_31b.model import Gemma4_31B, Gemma4_31BConfig + from executorch.examples.models.gemma4_31b.quant import dequantize_weight, pack_one + from executorch.examples.models.gemma4_31b.quant.gguf import iter_gguf_tensors + from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + + if backend == "cuda": + from executorch.examples.models.gemma4_31b.quant import DEFAULT_CUDA_PACKERS + + packers = DEFAULT_CUDA_PACKERS + else: + raise ValueError(f"Unsupported backend: {backend!r}. Supported: 'cuda'.") + + config = Gemma4_31BConfig(max_seq_len=max_seq_len) + + print("Building model on meta device...") + with torch.device("meta"): + model = Gemma4_31B(config) + + embed_quant = None + n_processed = 0 + + print(f"Streaming GGUF from {gguf_path}...") + for gguf_name, result in iter_gguf_tensors(gguf_path): + model_key = gguf_to_model_key(gguf_name) + if model_key is None: + continue + + if type(result) is torch.Tensor and result.dtype == torch.float32: + result = result.to(torch.bfloat16) + + if model_key == "embed_tokens.weight" and isinstance(result, Int4Tensor): + embed_quant = result + result = dequantize_weight(result, torch.bfloat16) + + pack_one(model, model_key, result, packers) + + n_processed += 1 + if n_processed % 100 == 0: + print(f" Processed {n_processed} tensors...") + + _resolve_tied_lm_head(model, embed_quant, packers) + del embed_quant + + _validate_no_meta(model) + model.eval() + + print(f"Model: {config.num_hidden_layers} layers, hidden={config.hidden_size}") + return model, config diff --git a/examples/models/gemma4_31b/inference.py b/examples/models/gemma4_31b/inference.py new file mode 100644 index 00000000000..12785450d8c --- /dev/null +++ b/examples/models/gemma4_31b/inference.py @@ -0,0 +1,227 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Eager inference on Gemma 4 31B-IT (CUDA + torch.compile). + +Two input paths: + --prequantized Load a quantized checkpoint (from quantize_and_save.py). + --gguf Load a GGUF file (e.g., Q4_K_M from the community). + +Packs for the target backend (--backend cuda), materializes runtime buffers, +optionally compiles with ``torch.compile``, and generates text autoregressively. + +Usage: + python inference.py \\ + --prequantized ./gemma4_31b_int4 \\ + --prompt "Write a short joke about saving RAM." \\ + --max-new-tokens 128 \\ + --temperature 0.8 + + python inference.py \\ + --gguf ./gemma-4-31B-it-Q4_K_M.gguf \\ + --tokenizer-path ./tokenizer.json \\ + --prompt "Hello" +""" + +import argparse +import os +import time + +import torch + +from executorch.examples.models.gemma4_31b.export import load_prequantized_model +from executorch.examples.models.gemma4_31b.model import materialize_runtime_buffers + + +def _move_to_cuda(model, config) -> None: + """Move the prequantized model to CUDA and materialize runtime buffers there. + + Parameters are moved individually (not via ``model.cuda()``) to preserve + ``Int4TilePackedTo4dTensor`` subclass identity. Non-meta buffers (e.g. + ``layer_scalar``) are moved to CUDA. Meta-device buffers (KV cache, RoPE, + constants) are materialized directly on CUDA via + ``materialize_runtime_buffers``. + """ + for name, p in model.named_parameters(): + parts = name.rsplit(".", 1) + parent = model.get_submodule(parts[0]) if len(parts) > 1 else model + setattr( + parent, + parts[-1], + torch.nn.Parameter(p.data.to("cuda"), requires_grad=False), + ) + + for fqn, buf in list(model.named_buffers()): + if buf.device.type != "meta": + parts = fqn.rsplit(".", 1) + parent = model.get_submodule(parts[0]) if len(parts) > 1 else model + parent.register_buffer(parts[-1], buf.to("cuda"), persistent=False) + + materialize_runtime_buffers(model, dtype=torch.bfloat16, device="cuda") + + +def generate( + model, + tokenizer, + prompt: str, + max_new_tokens: int = 128, + temperature: float = 0.0, + eos_token_ids=None, + bos_token_id: int = 2, +) -> str: + """Autoregressive generation. Prefill is one-token-at-a-time so a single + compiled graph handles every step; the exported PTE uses a separate + multi-token prefill method, but for eager+compile a uniform decode-shape + forward is simpler and benefits from CUDA-graph friendly shapes. + + ``tokenizers.Tokenizer.from_file`` does not auto-prepend BOS — and Gemma 4 + is unusable without it (the model's logits collapse to a single + high-frequency vocab token if the very first input isn't BOS). We prepend + explicitly here; pass ``bos_token_id=None`` to disable. + """ + if eos_token_ids is None: + eos_token_ids = set() + + input_ids = tokenizer.encode(prompt).ids + if bos_token_id is not None and (not input_ids or input_ids[0] != bos_token_id): + input_ids = [bos_token_id] + input_ids + + temp_val = max(temperature, 1e-6) # avoid div-by-zero in the on-device sampler + temp_tensor = torch.tensor([temp_val], dtype=torch.float32, device="cuda") + + sampled = None + with torch.no_grad(): + # Prefill, one token at a time. + for i, tok_id in enumerate(input_ids): + tok = torch.tensor([[tok_id]], dtype=torch.long, device="cuda") + pos = torch.tensor([i], dtype=torch.long, device="cuda") + sampled = model(tok, pos, temp_tensor) + + # First generated token from the last prefill step. + next_id = int(sampled.item()) + generated = [next_id] + + # Decode loop. + seq_len = len(input_ids) + for i in range(max_new_tokens - 1): + tok = torch.tensor([[next_id]], dtype=torch.long, device="cuda") + pos = torch.tensor([seq_len + i], dtype=torch.long, device="cuda") + sampled = model(tok, pos, temp_tensor) + next_id = int(sampled.item()) + generated.append(next_id) + if next_id in eos_token_ids: + break + + return tokenizer.decode(generated) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Eager inference on Gemma 4 31B-IT.") + src = parser.add_mutually_exclusive_group(required=True) + src.add_argument( + "--prequantized", + default=None, + help="Path to a quantized checkpoint directory.", + ) + src.add_argument( + "--gguf", + default=None, + help="Path to a GGUF file (e.g., gemma-4-31B-it-Q4_K_M.gguf).", + ) + parser.add_argument( + "--tokenizer-path", + default=None, + help="Path to tokenizer.json (required with --gguf, optional with --prequantized).", + ) + parser.add_argument("--prompt", default="Hello", help="Input prompt.") + parser.add_argument( + "--max-new-tokens", + type=int, + default=128, + help="Maximum tokens to generate.", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.8, + help="Sampling temperature (0 = near-greedy).", + ) + parser.add_argument( + "--max-seq-len", + type=int, + default=4096, + help="KV cache length to allocate for this run.", + ) + parser.add_argument( + "--no-compile", + action="store_true", + help="Skip torch.compile (slower, but easier to debug).", + ) + parser.add_argument( + "--backend", + default="cuda", + choices=["cuda"], + help="Target backend.", + ) + args = parser.parse_args() + + if args.backend == "cuda" and not torch.cuda.is_available(): + parser.error("CUDA is required for the cuda backend.") + + if args.gguf: + from executorch.examples.models.gemma4_31b.gguf_loader import load_gguf_model + + model, config = load_gguf_model( + args.gguf, args.max_seq_len, backend=args.backend + ) + else: + print(f"Loading prequantized model from {args.prequantized}...") + model, config = load_prequantized_model( + args.prequantized, max_seq_len=args.max_seq_len, backend=args.backend + ) + _move_to_cuda(model, config) + model.eval() + + import executorch.backends.cuda.int4_dispatch # noqa: F401 + + if not args.no_compile: + print("Compiling model with torch.compile...") + model = torch.compile(model, mode="default") + + if args.tokenizer_path: + tokenizer_path = args.tokenizer_path + elif args.prequantized: + tokenizer_path = os.path.join(args.prequantized, "tokenizer.json") + else: + parser.error("--tokenizer-path is required with --gguf.") + from tokenizers import Tokenizer + + tokenizer = Tokenizer.from_file(tokenizer_path) + + # Gemma 4 EOS tokens (from generation_config.json: ids 1, 50, 106). + eos_token_ids = {1, 50, 106} + + print(f"\nPrompt: {args.prompt}") + print("-" * 40) + + t0 = time.perf_counter() + output = generate( + model, + tokenizer, + args.prompt, + max_new_tokens=args.max_new_tokens, + temperature=args.temperature, + eos_token_ids=eos_token_ids, + ) + elapsed = time.perf_counter() - t0 + + print(output) + print("-" * 40) + print(f"Generated in {elapsed:.2f}s") + + +if __name__ == "__main__": + main() diff --git a/examples/models/gemma4_31b/main.cpp b/examples/models/gemma4_31b/main.cpp new file mode 100644 index 00000000000..0be2fef517c --- /dev/null +++ b/examples/models/gemma4_31b/main.cpp @@ -0,0 +1,402 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Gemma 4 31B-IT runner for the CUDA ExecuTorch backend. +// +// Drives the prefill + decode methods produced by export.py. +// The exported model performs Gumbel-max sampling on-device and returns a +// single float token ID per call, so this runner only has to feed tokens +// in and decode them via the HuggingFace tokenizer. + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +extern "C" void et_pal_emit_log_message( + ET_UNUSED et_timestamp_t timestamp, + et_pal_log_level_t level, + const char* filename, + ET_UNUSED const char* function, + size_t line, + const char* message, + ET_UNUSED size_t length) { + if (level == 'D' || level == 'I') { + return; + } + fprintf(stderr, "%c [%s:%zu] %s\n", (char)level, filename, line, message); +} + +#ifdef EXECUTORCH_BUILD_CUDA +#include +#endif + +DEFINE_string(model_path, "", "Model .pte file path."); +DEFINE_string(data_path, "", "Data file (.ptd) for CUDA backend."); +DEFINE_string(tokenizer_path, "", "HuggingFace tokenizer.json path."); +DEFINE_string(prompt, "Hello", "Prompt text."); +DEFINE_string( + prompt_file, + "", + "Path to file containing prompt text (overrides --prompt)."); +DEFINE_double(temperature, 0.8, "Sampling temperature (0 = near-greedy)."); +DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate."); +DEFINE_int32(bos_id, 2, "BOS token id to prepend (Gemma convention: 2)."); +DEFINE_int32(eos_id, 1, "EOS token id (Gemma convention: 1)."); +DEFINE_bool( + cuda_graph, + false, + "Enable CUDA graph capture for the decode method. CUDA only."); + +namespace llm = ::executorch::extension::llm; +using ::executorch::extension::from_blob; +using ::executorch::extension::Module; +using ::executorch::runtime::Error; +using ::executorch::runtime::EValue; + +using SizesType = executorch::aten::SizesType; + +static uint64_t read_token(const executorch::aten::Tensor& output) { + const void* ptr = output.const_data_ptr(); + float val = 0.0f; + +#ifdef EXECUTORCH_BUILD_CUDA + cudaPointerAttributes attrs{}; + bool on_device = cudaPointerGetAttributes(&attrs, ptr) == cudaSuccess && + attrs.type == cudaMemoryTypeDevice; + if (on_device) { + cudaError_t err = + cudaMemcpy(&val, ptr, sizeof(float), cudaMemcpyDeviceToHost); + if (err != cudaSuccess) { + ET_LOG( + Error, + "read_token: cudaMemcpy D2H failed: %s", + cudaGetErrorString(err)); + return 0; + } + } else { + memcpy(&val, ptr, sizeof(float)); + } +#else + memcpy(&val, ptr, sizeof(float)); +#endif + + return static_cast(llrintf(val)); +} + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + + if (FLAGS_model_path.empty()) { + ET_LOG(Error, "Must specify --model_path"); + return 1; + } + if (FLAGS_tokenizer_path.empty()) { + ET_LOG(Error, "Must specify --tokenizer_path"); + return 1; + } + + llm::Stats stats; + +#ifdef EXECUTORCH_BUILD_CUDA + size_t gpu_free_bytes = 0, gpu_total_bytes = 0; + cudaMemGetInfo(&gpu_free_bytes, &gpu_total_bytes); + stats.gpu_total_bytes = gpu_total_bytes; + stats.gpu_free_before_load_bytes = gpu_free_bytes; +#endif + + stats.model_load_start_ms = llm::time_in_ms(); + + // Tokenizer + auto tokenizer = std::make_unique(); + if (tokenizer->load(FLAGS_tokenizer_path) != tokenizers::Error::Ok) { + ET_LOG( + Error, + "Failed to load tokenizer from %s", + FLAGS_tokenizer_path.c_str()); + return 1; + } + + // Module: share_memory_arenas=true so prefill and decode see the same + // KV-cache memory (we exported with share_mutable_buffers=True). + std::vector data_files; + if (!FLAGS_data_path.empty()) { + data_files.push_back(FLAGS_data_path); + } + auto module = std::make_unique( + FLAGS_model_path, + data_files, + Module::LoadMode::File, + /*event_tracer=*/nullptr, + /*memory_allocator=*/nullptr, + /*temp_allocator=*/nullptr, + /*share_memory_arenas=*/true); + + // Get metadata + auto metadata_result = llm::get_llm_metadata(tokenizer.get(), module.get()); + if (metadata_result.error() != Error::Ok) { + ET_LOG(Error, "Failed to read model metadata"); + return 1; + } + +#ifdef EXECUTORCH_BUILD_CUDA + if (FLAGS_cuda_graph) { + executorch::runtime::BackendOptions<2> cuda_opts; + cuda_opts.set_option("enable_cuda_graph_for_method", "decode"); + executorch::runtime::set_option("CudaBackend", cuda_opts.view()); + printf("CUDA graph enabled for decode method\n"); + } + + // Cross-method per-FQN weight sharing: prefill + decode share the same + // weight tensors and (more importantly) the same KV-cache buffers, so + // without this flag we would allocate them twice. MUST be set before + // load_method. + { + executorch::runtime::BackendOptions<1> backend_options; + auto set_err = + backend_options.set_option("weight_sharing_across_methods", true); + if (set_err != Error::Ok) { + ET_LOG( + Error, + "Failed to construct weight_sharing_across_methods option: %d", + static_cast(set_err)); + return 1; + } + auto opt_err = + executorch::runtime::set_option("CudaBackend", backend_options.view()); + if (opt_err != Error::Ok) { + ET_LOG( + Error, + "Failed to enable weight_sharing_across_methods: %d", + static_cast(opt_err)); + return 1; + } + } +#else + if (FLAGS_cuda_graph) { + ET_LOG(Info, "--cuda_graph ignored on non-CUDA build"); + } +#endif + + printf("Loading methods...\n"); + if (module->load_method("prefill") != Error::Ok) { + ET_LOG(Error, "Failed to load prefill method"); + return 1; + } + if (module->load_method("decode") != Error::Ok) { + ET_LOG(Error, "Failed to load decode method"); + return 1; + } + stats.model_load_end_ms = llm::time_in_ms(); + +#ifdef EXECUTORCH_BUILD_CUDA + cudaMemGetInfo(&gpu_free_bytes, &gpu_total_bytes); + stats.gpu_free_after_load_bytes = gpu_free_bytes; +#endif + + auto eos_ids = llm::get_eos_ids(tokenizer.get(), module.get()); + eos_ids.insert(static_cast(FLAGS_eos_id)); + + // Read prompt from file or flag + std::string prompt_text = FLAGS_prompt; + if (!FLAGS_prompt_file.empty()) { + std::ifstream f(FLAGS_prompt_file); + if (!f.is_open()) { + ET_LOG( + Error, "Failed to open prompt file: %s", FLAGS_prompt_file.c_str()); + return 1; + } + prompt_text = std::string( + (std::istreambuf_iterator(f)), std::istreambuf_iterator()); + } + + // Encode prompt + auto encode_result = tokenizer->encode(prompt_text); + if (!encode_result.ok()) { + ET_LOG(Error, "Failed to encode prompt"); + return 1; + } + auto prompt_tokens = std::move(*encode_result); + // Gemma models require BOS at the start of the sequence. + prompt_tokens.insert( + prompt_tokens.begin(), static_cast(FLAGS_bos_id)); + int64_t num_prompt_tokens = static_cast(prompt_tokens.size()); + printf("Prompt tokens: %" PRId64 "\n", num_prompt_tokens); + stats.num_prompt_tokens = num_prompt_tokens; + + stats.inference_start_ms = llm::time_in_ms(); + + auto S = [](int64_t v) -> SizesType { return static_cast(v); }; + +#ifdef EXECUTORCH_BUILD_CUDA + // CUDA build: model fuses the sampler. Pass temperature as a third input. + float temp_val = + FLAGS_temperature <= 0.0 ? 1e-6f : static_cast(FLAGS_temperature); + auto temp_tensor = + from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float); +#endif + + // --------------------------------------------------------------- + // Prefill (chunked to respect ring-buffer KV cache limit) + // --------------------------------------------------------------- + // Sliding layers use a ring buffer sized to 2×sliding_window. A single + // prefill call must not exceed this size, otherwise index_copy_ with + // wrapped indices produces non-deterministic results on CUDA. + int64_t max_prefill_chunk = (*metadata_result)[llm::kMaxSeqLen] - 1; + { + auto get_result = module->get("get_max_prefill_chunk"); + if (get_result.ok()) { + max_prefill_chunk = get_result->toScalar().to(); + } + } + + uint64_t cur_token = 0; + int64_t prefill_pos = 0; + while (prefill_pos < num_prompt_tokens) { + int64_t chunk_len = + std::min(num_prompt_tokens - prefill_pos, max_prefill_chunk); + + std::string run_method = (chunk_len == 1) ? "decode" : "prefill"; + + std::vector token_data( + prompt_tokens.begin() + prefill_pos, + prompt_tokens.begin() + prefill_pos + chunk_len); + std::vector pos_data(chunk_len); + for (int64_t i = 0; i < chunk_len; i++) { + pos_data[i] = prefill_pos + i; + } + auto tokens_tensor = from_blob( + token_data.data(), + {1, S(chunk_len)}, + executorch::aten::ScalarType::Long); + auto pos_tensor = from_blob( + pos_data.data(), {S(chunk_len)}, executorch::aten::ScalarType::Long); + + std::vector prefill_inputs; + prefill_inputs.push_back(EValue(tokens_tensor)); + prefill_inputs.push_back(EValue(pos_tensor)); +#ifdef EXECUTORCH_BUILD_CUDA + prefill_inputs.push_back(EValue(temp_tensor)); +#endif + + auto prefill_result = module->execute(run_method, prefill_inputs); + if (prefill_result.error() != Error::Ok) { + ET_LOG( + Error, "%s failed at pos %" PRId64, run_method.c_str(), prefill_pos); + return 1; + } + cur_token = read_token(prefill_result.get()[0].toTensor()); + prefill_pos += chunk_len; + } + + stats.prompt_eval_end_ms = llm::time_in_ms(); + double prefill_ms = + static_cast(stats.prompt_eval_end_ms - stats.inference_start_ms); + printf( + "Prefill: %" PRId64 " tokens in %.1f ms (%.1f tok/s)\n", + num_prompt_tokens, + prefill_ms, + num_prompt_tokens * 1000.0 / prefill_ms); + +#ifdef EXECUTORCH_BUILD_CUDA + // Synchronize CUDA device to ensure prefill's writes to shared mutable + // buffers (KV cache) are visible to the decode method, which may run on + // a different CUDA stream. + cudaDeviceSynchronize(); +#endif + + // --------------------------------------------------------------- + // Decode loop + // --------------------------------------------------------------- + int64_t pos = num_prompt_tokens; + std::vector decode_token_data = {static_cast(cur_token)}; + std::vector decode_pos_data = {pos}; + auto decode_tokens = from_blob( + decode_token_data.data(), {1, 1}, executorch::aten::ScalarType::Long); + auto decode_pos = from_blob( + decode_pos_data.data(), {1}, executorch::aten::ScalarType::Long); + + uint64_t prev_token = cur_token; + for (int32_t step = 0; step < FLAGS_max_new_tokens; step++) { + decode_token_data[0] = static_cast(cur_token); + decode_pos_data[0] = pos; + + std::vector decode_inputs; + decode_inputs.push_back(EValue(decode_tokens)); + decode_inputs.push_back(EValue(decode_pos)); +#ifdef EXECUTORCH_BUILD_CUDA + decode_inputs.push_back(EValue(temp_tensor)); +#endif + + auto decode_result = module->execute("decode", decode_inputs); + if (decode_result.error() != Error::Ok) { + ET_LOG(Error, "Decode step %d failed", step); + return 1; + } + + prev_token = cur_token; + cur_token = read_token(decode_result.get()[0].toTensor()); + + if (step == 0) { + stats.first_token_ms = llm::time_in_ms(); + } + pos++; + + auto decode_str = tokenizer->decode(prev_token, cur_token); + if (decode_str.ok()) { + printf("%s", decode_str->c_str()); + fflush(stdout); + } + + if (eos_ids.find(cur_token) != eos_ids.end()) { + printf("\n"); + break; + } + } + + stats.inference_end_ms = llm::time_in_ms(); + printf("\n"); + + int64_t num_generated = pos - num_prompt_tokens; + stats.num_generated_tokens = num_generated; + double decode_ms = + static_cast(stats.inference_end_ms - stats.prompt_eval_end_ms); + printf( + "Decode: %" PRId64 " tokens in %.1f ms (%.1f tok/s)\n", + num_generated, + decode_ms, + num_generated * 1000.0 / decode_ms); + printf("Prompt tokens: %" PRId64 "\n", num_prompt_tokens); + +#ifdef EXECUTORCH_BUILD_CUDA + cudaMemGetInfo(&gpu_free_bytes, &gpu_total_bytes); + stats.gpu_free_after_generate_bytes = gpu_free_bytes; + stats.gpu_peak_usage_mb = + (stats.gpu_total_bytes - gpu_free_bytes) / 1024.0 / 1024.0; +#endif + + llm::print_report(stats); + return 0; +} diff --git a/examples/models/gemma4_31b/model.md b/examples/models/gemma4_31b/model.md new file mode 100644 index 00000000000..8233b6d430e --- /dev/null +++ b/examples/models/gemma4_31b/model.md @@ -0,0 +1,203 @@ +# Gemma 4 31B-IT — Architecture & Design Notes + +Developer reference for `model.py` and the `quant/` package. For +export/build/run instructions see [README.md](README.md). + +The model mirrors the `Gemma4ForConditionalGeneration` text stack from +HuggingFace transformers / vLLM, with the ExecuTorch customizations needed +for `torch.export(strict=True)`. + +## Architecture + +``` +Input tokens (B, T) + | + v +Embedding (vocab=262144, dim=5376) -> *= sqrt(hidden_size) (normalizer) + | + v ++--- Decoder Layer x60 -----------------------------------------+ +| | +| residual = x | +| RMSNorm -> Attention (sliding | full) -> RMSNorm -> +residual | +| residual = x | +| RMSNorm -> MLP (gate_proj, up_proj, down_proj, GELU-tanh) | +| -> RMSNorm -> +residual | +| x *= layer_scalar (per-layer buffer) | +| | ++----------------------------------------------------------------+ + | + v +RMSNorm -> LM Head (tied with embed) -> tanh(logits/30) * 30 + | + v +Gumbel-max sample(temperature) -> next token (B, 1) +``` + +Layer pattern (`5 sliding + 1 full`, repeated 10x — the last layer is full): + +``` +S S S S S F S S S S S F ... S S S S S F (S = sliding, F = full) +``` + +## Attention details + +Two attention flavors, selected by `config.layer_types[layer_idx]`: + +| Property | Sliding (50 layers) | Full (10 layers, idx 5,11,...,59) | +|---------------------|--------------------|-----------------------------------| +| `head_dim` | 256 | 512 | +| `num_kv_heads` | 16 | 4 | +| `num_heads` | 32 | 32 | +| RoPE θ | 10 000 | 1 000 000 | +| RoPE flavor | full neox | proportional, partial=0.25 | +| K = V | no | yes (no `v_proj`) | +| Causal mask | causal | causal | +| Window restriction | 1024 tokens | none | +| Q-norm / K-norm | RMSNorm w/ weight | RMSNorm w/ weight | +| V-norm | RMSNorm no weight | RMSNorm no weight | +| `scaling` | 1.0 | 1.0 | + +Notes: + +- **Proportional partial RoPE**: the inv_freq vector for full-attention layers + has the first `head_dim * partial_rotary_factor / 2 = 64` frequencies real + (computed with denominator `head_dim`, not `rotary_dim` — that's the + proportional part) and the remaining `head_dim/2 - 64 = 192` zero so cos=1 + and sin=0 (identity rotation) for the non-rotated dims. +- **K = V**: on full-attention layers `v_proj` is absent in the checkpoint + and `V` is taken from the pre-norm `K` projection. After `k_norm` / + RoPE on K and `v_norm` (weightless) on V the two diverge, so the cache + still stores them separately. +- **Mask construction**: a single boolean `(1, 1, T_q, T_kv)` mask is built + once per forward at the model level — one for sliding (causal AND + pos_q - pos_k < 1024), one for full (just causal). Layers pick whichever + matches their type and pass it to `F.scaled_dot_product_attention(..., + enable_gqa=True)`. +- **Gemma `scaling=1.0`**: unlike Gemma 2/3, Gemma 4 does not scale Q by + `query_pre_attn_scalar`; QK-norm handles attention magnitude. + +## Model parameters (text stack) + +| Parameter | Value | +|---------------------------------|------------| +| `vocab_size` | 262 144 | +| `hidden_size` | 5 376 | +| `intermediate_size` | 21 504 | +| `num_hidden_layers` | 60 | +| `num_attention_heads` | 32 | +| `num_key_value_heads` (sliding) | 16 | +| `head_dim` (sliding) | 256 | +| `num_global_key_value_heads` | 4 | +| `global_head_dim` | 512 | +| `sliding_window` | 1024 | +| `rms_norm_eps` | 1e-6 | +| `final_logit_softcapping` | 30.0 | +| `tie_word_embeddings` | true | +| `max_position_embeddings` | 262 144 | + +Decoder norms per layer: `input_layernorm`, `post_attention_layernorm`, +`pre_feedforward_layernorm`, `post_feedforward_layernorm` — all +`RMSNorm` (multiplies by `weight` directly, not `(1 + weight)`). + +## Methods exported (`export.py`) + +| Method | Input | Output (sampled) | +|-----------|------------------------------------------------------------|------------------| +| `decode` | tokens `(1, 1)` + input_pos `(1,)` + temperature `(1,)` | `(1, 1)` float | +| `prefill` | tokens `(1, T)` + input_pos `(T,)` + temperature `(1,)`, T∈[5, min(max_seq_len-1, 2×sliding_window)] | `(1, 1)` float | + +Both methods share the same KV-cache buffers via +`MemoryPlanningPass(share_mutable_buffers=True)` and +`emit_mutable_buffer_names=True`. The exported program performs Gumbel-max +sampling on-device and returns a single token ID per call so the C++ runner +only has to feed tokens. + +Prefill length is capped to the ring-buffer KV cache size +(`2 × sliding_window`) to avoid duplicate wrapped indices in +`index_copy_`. The C++ runner chunks longer prompts automatically using +the `get_max_prefill_chunk` constant method. Chunked prefill produces +identical logits to sequential one-token-at-a-time prefill. + +## Quantization + +Modules in `quant/`: + +- **Recipe** (`recipe.py`): `QuantConfig` + `QuantRule` + `QuantRecipe`. + Declares what to quantize — says nothing about packing or backends. +- **Quantize** (`quantize.py`): `quantize_weight` / `dequantize_weight` / + `quantize_model`. Produces torchao tensor subclasses (`Int4Tensor`, + `IntxUnpackedToInt8Tensor`) from fp weights. +- **Serialization**: callers use torchao's safetensors integration + (`torchao.prototype.safetensors`) directly — no wrapper module needed. +- **Pack** (`pack.py` + `pack_cuda.py`): `pack_model` groups weights by + parent module, `pack_one` handles single weights. Per-module packers + dispatch by module type (`nn.Linear`, `nn.Embedding`, extensible for MoE). +- **GGUF** (`gguf.py`): `unpack_gguf_tensor` / `iter_gguf_tensors` for + loading community-quantized GGUF files (Q4_K, Q6_K). + +The quantize-once flow: + +``` +quantize_and_save.py export.py / inference.py + | | + bf16 weights quantized checkpoint (safetensors) + | | + quantize_weight() load (torchao safetensors) + | | + Int4Tensor / IntxUnpacked Int4Tensor / IntxUnpacked (used directly) + | | + save (torchao safetensors) int4_dispatch routes to int4_plain_mm + | | + model.safetensors dp4a decode / dequant+cuBLAS prefill +``` + +`embed_tokens` and `lm_head` start tied; they are untied before +quantization so `lm_head` (a 5376→262 144 matmul, very expensive at decode) +gets quantized. The embedding gets INT8 per-axis quantization (nearly +lossless for index lookup). + +## Runtime buffer materialization + +After weight loading (via `from_hf_checkpoint()`), the model's KV caches, +RoPE inv_freq buffers, and scalar constants are still on the meta device. +`materialize_runtime_buffers(model, dtype, device)` in `model.py` replaces +them with real tensors: + +- KV caches → zeros in `dtype` (bf16 for inference, bf16 for export) +- `inv_freq` → moved to target device (cos/sin computed on the fly per forward) +- `embed_normalizer`, `logit_softcap`, `cache_positions` → scalar constants + +Called by `export.py` (device="cpu" for tracing) and `inference.py` +(device="cuda" for eager execution). + +## Customizations vs. vLLM / transformers reference + +These exist solely to make the model exportable / efficient under ExecuTorch: + +- **Boolean attention mask** built once per forward and shared across layers + of the same type, instead of HF's per-layer `_create_causal_mask`. +- **Ring-buffer KV cache** for sliding layers (`RingKVCache`, sized to + `2 × sliding_window`) saves memory for long sequences — positions wrap + via modulo and the attention mask reconstructs which slots are valid. + Full-attention layers use a flat `Gemma4KVCache` sized to `max_seq_len`. + Both use `index_copy_(dim=2, ...)` for trace-friendly updates. +- **On-the-fly RoPE**: stores only `inv_freq` per layer, computes cos/sin + via `torch.outer(positions, inv_freq)` each forward. Saves memory vs + precomputed `[max_seq_len, head_dim]` tables (sliding uses full RoPE, + full uses proportional partial RoPE — head_dim and θ differ). +- **On-device Gumbel-max sampling** so the exported program emits a token + rather than a full logits tensor — keeps the runner GPU↔CPU traffic to a + single float per step. +- **Final-logit softcap baked into the graph**, applied before sampling. +- **Meta-device construction + assign-load** keeps peak memory small enough + to load the 31B-parameter checkpoint on one machine. + +## Shared primitives + +The numerically-sensitive math primitives are imported from +`examples.models.gemma4.text_decoder` and shared with the Gemma 4 E2B/E4B +example: `RMSNorm`, `RMSNormNoWeight`, `Gemma4MLP`, `Gemma4KVCache`, +`apply_rotary_emb`. The 31B-specific pieces (attention with K=V branch, +decoder layer, top-level model with softcap + sampling, checkpoint loader) +live in `model.py`. diff --git a/examples/models/gemma4_31b/model.py b/examples/models/gemma4_31b/model.py new file mode 100644 index 00000000000..b0eb4004c52 --- /dev/null +++ b/examples/models/gemma4_31b/model.py @@ -0,0 +1,694 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Gemma 4 31B-IT — export-friendly reference implementation for ExecuTorch. + +Model definition designed for torch.export(strict=True) with the CUDA backend. +All stateful buffers (KV cache, RoPE inv_freq) are registered buffers so they +are captured by share_mutable_buffers across prefill/decode. The numerically +sensitive primitives — RMSNorm, GELU-tanh MLP, proportional/full RoPE, and +the BHSD KV cache — are imported from ``examples.models.gemma4.text_decoder`` +so the 31B and E2B/E4B paths share them. + +Reference: + - HF transformers: src/transformers/models/gemma4/modeling_gemma4.py + - vLLM: vllm/model_executor/models/gemma4.py + +Architecture highlights for the 31B dense variant: + - 60 decoder layers with hybrid attention: every 6th layer is "full" attention + (idx 5, 11, ..., 59 — 10 layers); the remaining 50 use sliding-window + attention with window=1024. + - Sliding layers: head_dim=256, num_kv_heads=16, full RoPE, theta=10000. + - Full layers: head_dim=512, num_kv_heads=4, K=V (no v_proj), and + "proportional" partial RoPE (factor=0.25, theta=1_000_000). + - Q-norm and K-norm with learnable scale; V-norm without scale. + - Per-layer scalar (loaded buffer) multiplied at the end of each layer. + - Final logits are soft-capped: tanh(logits / 30) * 30. + - Embedding is scaled by sqrt(hidden_size) before layer 0. + - Embedding and lm_head are tied (a single weight, untied for quantization + in the export step so lm_head can be 4-bit). +""" + +import json +import os +import re +from dataclasses import dataclass, field +from typing import Optional + +import torch +import torch.nn as nn + +# Shared primitives lifted out of the gemma4 (E2B/E4B) example. These are the +# bits whose semantics are identical for both variants — RMSNorm, the GELU-tanh +# MLP, the proportional/full RoPE table builder, and the BHSD KV cache. +from executorch.examples.models.gemma4.text_decoder import ( + apply_rotary_emb, + Gemma4KVCache, + Gemma4MLP, + RMSNorm, + RMSNormNoWeight, +) +from executorch.examples.models.gemma4_31b.sampler import sample +from torch.nn import functional as F + + +# --------------------------------------------------------------------------- +# Ring-buffer KV cache for sliding window attention + + +class RingKVCache(nn.Module): + """Ring-buffer KV cache for sliding window attention. + + Sized to ``window_size * 2`` (not ``max_seq_len``), saving memory for + long sequences. Positions wrap via modulo; old entries outside the + window are masked out by ``_build_masks``. + """ + + def __init__( + self, + max_batch_size: int, + window_size: int, + num_kv_heads: int, + head_dim: int, + ): + super().__init__() + self.window_size = window_size + self.buf_size = window_size * 2 + cache_shape = (max_batch_size, num_kv_heads, self.buf_size, head_dim) + self.register_buffer("k_cache", torch.zeros(cache_shape), persistent=False) + self.register_buffer("v_cache", torch.zeros(cache_shape), persistent=False) + + def update( + self, + input_pos: torch.Tensor, + k_val: torch.Tensor, + v_val: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + # seq_len must not exceed buf_size, otherwise wrapped indices contain + # duplicates and index_copy_ is non-deterministic on CUDA. The C++ + # runner must chunk prefill to respect this limit. + assert ( + input_pos.shape[0] <= self.buf_size + ), f"seq_len {input_pos.shape[0]} > buf_size {self.buf_size}" + wrapped = input_pos % self.buf_size + self.k_cache.index_copy_(2, wrapped, k_val) + self.v_cache.index_copy_(2, wrapped, v_val) + return self.k_cache, self.v_cache + + +# --------------------------------------------------------------------------- +# Config + + +@dataclass +class Gemma4_31BConfig: + # Embedding / shape + vocab_size: int = 262144 + hidden_size: int = 5376 + intermediate_size: int = 21504 + num_hidden_layers: int = 60 + + # Attention shape (sliding layers — also the "default" path) + num_attention_heads: int = 32 + num_key_value_heads: int = 16 + head_dim: int = 256 + + # Attention shape (full-attention layers) + num_global_key_value_heads: int = 4 + global_head_dim: int = 512 + attention_k_eq_v: bool = ( + True # full layers: V is derived from the same projection as K + ) + + # RoPE — split per layer type + sliding_rope_theta: float = 10_000.0 + full_rope_theta: float = 1_000_000.0 + full_partial_rotary_factor: float = 0.25 # proportional RoPE for full attention + + # Norm / activation + rms_norm_eps: float = 1e-6 + hidden_activation: str = "gelu_pytorch_tanh" + + # Sampling / output + final_logit_softcapping: float = 30.0 + tie_word_embeddings: bool = True + + # Sliding window + sliding_window: int = 1024 + + # Hybrid attention pattern + layer_types: list = field(default_factory=list) + + # Runtime + max_seq_len: int = 4096 + + def __post_init__(self): + if not self.layer_types: + # Default hybrid pattern: 5 sliding then 1 full, repeated. + self.layer_types = [ + "full_attention" if (i + 1) % 6 == 0 else "sliding_attention" + for i in range(self.num_hidden_layers) + ] + if len(self.layer_types) != self.num_hidden_layers: + raise ValueError( + f"layer_types length {len(self.layer_types)} != " + f"num_hidden_layers {self.num_hidden_layers}" + ) + + @staticmethod + def from_hf_config(config_path: str) -> "Gemma4_31BConfig": + with open(config_path, "r") as f: + cfg = json.load(f) + if "text_config" in cfg: + cfg = cfg["text_config"] + + rope_params = cfg.get("rope_parameters", {}) + sliding_rope = rope_params.get("sliding_attention", {}) + full_rope = rope_params.get("full_attention", {}) + + return Gemma4_31BConfig( + vocab_size=cfg.get("vocab_size", 262144), + hidden_size=cfg.get("hidden_size", 5376), + intermediate_size=cfg.get("intermediate_size", 21504), + num_hidden_layers=cfg.get("num_hidden_layers", 60), + num_attention_heads=cfg.get("num_attention_heads", 32), + num_key_value_heads=cfg.get("num_key_value_heads", 16), + head_dim=cfg.get("head_dim", 256), + num_global_key_value_heads=cfg.get("num_global_key_value_heads", 4), + global_head_dim=cfg.get("global_head_dim", 512), + attention_k_eq_v=cfg.get("attention_k_eq_v", True), + sliding_rope_theta=sliding_rope.get("rope_theta", 10_000.0), + full_rope_theta=full_rope.get("rope_theta", 1_000_000.0), + full_partial_rotary_factor=full_rope.get("partial_rotary_factor", 0.25), + rms_norm_eps=cfg.get("rms_norm_eps", 1e-6), + hidden_activation=cfg.get("hidden_activation", "gelu_pytorch_tanh"), + final_logit_softcapping=cfg.get("final_logit_softcapping", 30.0), + tie_word_embeddings=cfg.get("tie_word_embeddings", True), + sliding_window=cfg.get("sliding_window", 1024), + layer_types=cfg.get("layer_types", []), + ) + + +# --------------------------------------------------------------------------- +# Attention — single class, branches on layer type via config +# +# RMSNorm, Gemma4MLP, the RoPE helpers, and Gemma4KVCache are imported from +# examples.models.gemma4.text_decoder so the two Gemma 4 variants share their +# numerically-sensitive primitives. + + +class Gemma4Attention(nn.Module): + """Gemma 4 attention with QK-norm, per-layer head_dim, RoPE, KV cache, and SDPA. + + The same class handles both sliding and full attention; the per-layer + config picks head_dim, num_kv_heads, RoPE flavor, and the K=V optimization. + """ + + def __init__(self, config: Gemma4_31BConfig, layer_idx: int): + super().__init__() + self.layer_idx = layer_idx + layer_type = config.layer_types[layer_idx] + self.is_sliding = layer_type == "sliding_attention" + + if self.is_sliding: + self.head_dim = config.head_dim + self.n_kv_heads = config.num_key_value_heads + self.rope_theta = config.sliding_rope_theta + self.partial_rotary = 1.0 + self.k_eq_v = False + else: + self.head_dim = config.global_head_dim + self.n_kv_heads = config.num_global_key_value_heads + self.rope_theta = config.full_rope_theta + self.partial_rotary = config.full_partial_rotary_factor + self.k_eq_v = config.attention_k_eq_v + + self.n_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.scaling = 1.0 # Gemma 4 uses scale=1; QK-norm handles normalization. + + # Linear projections. v_proj is omitted on K=V layers to match the checkpoint. + self.q_proj = nn.Linear( + self.hidden_size, self.n_heads * self.head_dim, bias=False + ) + self.k_proj = nn.Linear( + self.hidden_size, self.n_kv_heads * self.head_dim, bias=False + ) + if not self.k_eq_v: + self.v_proj = nn.Linear( + self.hidden_size, self.n_kv_heads * self.head_dim, bias=False + ) + self.o_proj = nn.Linear( + self.n_heads * self.head_dim, self.hidden_size, bias=False + ) + + # Q/K norm have learnable weight; V norm is weightless. + self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.v_norm = RMSNormNoWeight(self.head_dim, eps=config.rms_norm_eps) + + # Precomputed RoPE table for this layer (per-layer because head_dim + # and theta differ between sliding and full attention). For full + # attention layers we pass freq_base_dim=head_dim so the zero-padded + # On-the-fly RoPE: store only inv_freq, compute cos/sin per forward. + # Saves memory vs precomputed [max_seq_len, head_dim] tables. + if self.is_sliding: + rotary_dim = self.head_dim + else: + rotary_dim = int(self.head_dim * self.partial_rotary) + rope_angles = rotary_dim // 2 + inv_freq_rotated = 1.0 / ( + self.rope_theta ** (torch.arange(0, rotary_dim, 2).float() / self.head_dim) + ) + nope_angles = self.head_dim // 2 - rope_angles + if nope_angles > 0: + inv_freq = torch.cat([inv_freq_rotated, torch.zeros(nope_angles)]) + else: + inv_freq = inv_freq_rotated + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # KV cache. Sliding layers use a ring buffer (2x window) to save + # memory; full layers use a flat buffer (max_seq_len). + if self.is_sliding: + self.kv_cache = RingKVCache( + max_batch_size=1, + window_size=config.sliding_window, + num_kv_heads=self.n_kv_heads, + head_dim=self.head_dim, + ) + else: + self.kv_cache = Gemma4KVCache( + max_batch_size=1, + max_seq_len=config.max_seq_len, + num_kv_heads=self.n_kv_heads, + head_dim=self.head_dim, + use_index_copy=True, + ) + + def forward( + self, + x: torch.Tensor, + input_pos: torch.Tensor, + attn_mask: torch.Tensor, + ) -> torch.Tensor: + B, T, _ = x.shape + + q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim) + # raw_kv is the linear output before any norm — needed for K=V layers + # so V can be derived from the same tensor as K (post-norm differently). + raw_k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim) + if self.k_eq_v: + raw_v = raw_k + else: + raw_v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim) + + # Norms applied per-head (HF unflatten -> norm -> flatten pattern). + q = self.q_norm(q) + k = self.k_norm(raw_k) + v = self.v_norm(raw_v) + + # Move to BHSD for SDPA / KV cache. + q = q.transpose(1, 2) # (B, H, T, D) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # RoPE on Q and K only (V is not rotated). cos/sin computed on the fly. + freqs = torch.outer(input_pos.float(), self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + cos = torch.cos(emb) + sin = torch.sin(emb) + q, k = apply_rotary_emb(q, k, cos, sin) + + # Update cache and read back full K/V. + k, v = self.kv_cache.update(input_pos, k, v) + + # SDPA with explicit additive mask (already includes causal + + # sliding-window masking; built once per forward at the model level). + # `scale=1.0` matches HF Gemma 4 — Q-norm/K-norm have absorbed the + # 1/sqrt(d) factor into their trained weights, so the standard SDPA + # default of 1/sqrt(head_dim) would over-divide. enable_gqa lets the + # kernel handle the head ratio without us materializing expanded K/V. + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attn_mask, + is_causal=False, + enable_gqa=True, + scale=self.scaling, + ) + y = y.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim) + return self.o_proj(y) + + +# --------------------------------------------------------------------------- +# Decoder block — Gemma's "norm sandwich" pattern. + + +class Gemma4DecoderLayer(nn.Module): + def __init__(self, config: Gemma4_31BConfig, layer_idx: int): + super().__init__() + self.layer_idx = layer_idx + self.is_sliding = config.layer_types[layer_idx] == "sliding_attention" + + self.self_attn = Gemma4Attention(config, layer_idx) + self.mlp = Gemma4MLP(config.hidden_size, config.intermediate_size) + + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.pre_feedforward_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_feedforward_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + # Per-layer scalar (loaded from checkpoint) — multiplied at the end of + # each layer. Kept as a buffer (not nn.Parameter) so it isn't quantized. + self.register_buffer("layer_scalar", torch.ones(1)) + + def forward( + self, + x: torch.Tensor, + input_pos: torch.Tensor, + sliding_mask: torch.Tensor, + full_mask: torch.Tensor, + ) -> torch.Tensor: + attn_mask = sliding_mask if self.is_sliding else full_mask + + residual = x + h = self.input_layernorm(x) + h = self.self_attn(h, input_pos, attn_mask) + h = self.post_attention_layernorm(h) + x = residual + h + + residual = x + h = self.pre_feedforward_layernorm(x) + h = self.mlp(h) + h = self.post_feedforward_layernorm(h) + x = residual + h + + return x * self.layer_scalar + + +# --------------------------------------------------------------------------- +# Top-level model + + +class Gemma4_31B(nn.Module): + def __init__(self, config: Gemma4_31BConfig): + super().__init__() + self.config = config + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList( + [Gemma4DecoderLayer(config, i) for i in range(config.num_hidden_layers)] + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + # Held separately so it can be untied + quantized at export time. + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Constants (registered as buffers so they move with .to(device)). + self.register_buffer( + "embed_normalizer", + torch.tensor(config.hidden_size**0.5), + persistent=False, + ) + self.register_buffer( + "logit_softcap", + torch.tensor(config.final_logit_softcapping), + persistent=False, + ) + # cache_positions[i] = i — used to build attention masks without + # introducing dynamic-shape tensors at runtime. + self.register_buffer( + "cache_positions", + torch.arange(config.max_seq_len, dtype=torch.long), + persistent=False, + ) + + def _build_masks( + self, input_pos: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """Build boolean (B=1, H=1, T_q, T_kv) masks for full and sliding attention. + + True = attend. Built once per forward, shared across layers of the + same type. Full mask is (T_q, max_seq_len); sliding mask is + (T_q, buf_size) where buf_size = 2 * sliding_window. + """ + # Full attention mask: (T_q, max_seq_len) + cache_pos = self.cache_positions # (max_seq_len,) + q_pos = input_pos.unsqueeze(1) # (T_q, 1) + causal = q_pos >= cache_pos.unsqueeze(0) + full_mask = causal.unsqueeze(0).unsqueeze(0) # (1, 1, T_q, max_seq_len) + + # Sliding attention mask over ring buffer: (T_q, buf_size) + buf_size = self.config.sliding_window * 2 + seq_len = input_pos.shape[0] + total_written = input_pos[0] + seq_len + j = torch.arange(buf_size, dtype=torch.long, device=input_pos.device) + ring_pos = j + ((total_written - 1 - j) // buf_size) * buf_size + delta = q_pos - ring_pos.unsqueeze(0) + sliding = (ring_pos >= 0) & (delta >= 0) & (delta < self.config.sliding_window) + sliding_mask = sliding.unsqueeze(0).unsqueeze(0) # (1, 1, T_q, buf_size) + + return sliding_mask, full_mask + + def forward( + self, + tokens: torch.LongTensor, + input_pos: torch.LongTensor, + temperature: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Run the model. + + Args: + tokens: (B, T) token IDs. + input_pos: (T,) absolute positions for RoPE / KV cache. + temperature: optional 1-D float tensor controlling on-device sampling. + When provided, returns sampled tokens (B, 1) via Gumbel-max; + when None (e.g. eager eval), returns full logits (B, T, V) with + soft-capping applied so callers see post-cap values. + + Returns: + (B, 1) token IDs when sampling, else (B, T, V) float32 logits. + """ + x = self.embed_tokens(tokens) * self.embed_normalizer + + sliding_mask, full_mask = self._build_masks(input_pos) + for layer in self.layers: + x = layer(x, input_pos, sliding_mask, full_mask) + + x = self.norm(x) + + if temperature is None: + logits = self.lm_head(x).float() + cap = self.logit_softcap.float() + return torch.tanh(logits / cap) * cap + + # Decode-time fast path: only materialize logits for the last token. + last = self.lm_head(x[:, -1, :]).float() + cap = self.logit_softcap.float() + last = torch.tanh(last / cap) * cap + return sample(last, temperature) + + # ---------------- checkpoint loading ---------------- + + @staticmethod + def from_hf_checkpoint( + model_dir: str, max_seq_len: int = 4096 + ) -> tuple["Gemma4_31B", Gemma4_31BConfig]: + """Build the model on `meta` and load weights from the HF safetensors checkpoint. + + Uses lazy shard-by-shard loading + assign=True so peak memory stays at + roughly one shard's worth of weights. + """ + config = Gemma4_31BConfig.from_hf_config(os.path.join(model_dir, "config.json")) + config.max_seq_len = max_seq_len + + print( + f"Building Gemma4_31B on meta (layers={config.num_hidden_layers}, " + f"hidden={config.hidden_size}, max_seq_len={max_seq_len})..." + ) + with torch.device("meta"): + model = Gemma4_31B(config) + + print(f"Loading weights from {model_dir}...") + state_dict = _load_and_remap_checkpoint(model_dir, config) + + # Tied embeddings: copy embedding weight into lm_head when missing. + if "lm_head.weight" not in state_dict and "embed_tokens.weight" in state_dict: + state_dict["lm_head.weight"] = state_dict["embed_tokens.weight"] + + missing, unexpected = model.load_state_dict( + state_dict, strict=False, assign=True + ) + + # Runtime buffers (KV caches, RoPE tables, masks) are zero-initialized + # and not in the checkpoint — those are the "expected" missing keys. + runtime_prefixes = ( + ".kv_cache.", + ".inv_freq", + "embed_normalizer", + "logit_softcap", + "cache_positions", + ) + actual_missing = set(missing) + expected = {k for k in actual_missing if any(p in k for p in runtime_prefixes)} + extra = actual_missing - expected + if extra: + print(f" WARNING: missing weight keys: {sorted(extra)[:10]}") + if unexpected: + print(f" WARNING: unexpected keys: {sorted(unexpected)[:10]}") + print( + f" Loaded {len(state_dict)} tensors " + f"({len(expected)} runtime buffers OK)" + ) + return model, config + + +# --------------------------------------------------------------------------- +# Weight loading utilities + + +# HuggingFace key -> our model key. Patterns use `{}` for the layer index. +_HF_KEY_MAP = { + "model.embed_tokens.weight": "embed_tokens.weight", + "model.norm.weight": "norm.weight", + "lm_head.weight": "lm_head.weight", + # Per-layer norms + "model.layers.{}.input_layernorm.weight": "layers.{}.input_layernorm.weight", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.post_attention_layernorm.weight", + "model.layers.{}.pre_feedforward_layernorm.weight": "layers.{}.pre_feedforward_layernorm.weight", + "model.layers.{}.post_feedforward_layernorm.weight": "layers.{}.post_feedforward_layernorm.weight", + "model.layers.{}.layer_scalar": "layers.{}.layer_scalar", + # Attention projections + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.self_attn.q_proj.weight", + "model.layers.{}.self_attn.k_proj.weight": "layers.{}.self_attn.k_proj.weight", + "model.layers.{}.self_attn.v_proj.weight": "layers.{}.self_attn.v_proj.weight", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.self_attn.o_proj.weight", + "model.layers.{}.self_attn.q_norm.weight": "layers.{}.self_attn.q_norm.weight", + "model.layers.{}.self_attn.k_norm.weight": "layers.{}.self_attn.k_norm.weight", + # MLP + "model.layers.{}.mlp.gate_proj.weight": "layers.{}.mlp.gate_proj.weight", + "model.layers.{}.mlp.up_proj.weight": "layers.{}.mlp.up_proj.weight", + "model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.down_proj.weight", +} + +# Multimodal keys we deliberately ignore for the text-only export. +_IGNORED_PREFIXES = ( + "model.vision_tower.", + "model.embed_vision.", +) + + +def _hf_to_model_key(hf_key: str) -> Optional[str]: + # Gemma4ForConditionalGeneration stores the LM under model.language_model.* + norm = hf_key + if norm.startswith("model.language_model."): + norm = norm.replace("model.language_model.", "model.", 1) + + if norm.startswith(_IGNORED_PREFIXES): + return None + + for hf_pat, model_pat in _HF_KEY_MAP.items(): + if "{}" not in hf_pat: + if norm == hf_pat: + return model_pat + continue + regex = re.escape(hf_pat).replace(r"\{\}", r"(\d+)") + m = re.fullmatch(regex, norm) + if m: + return model_pat.replace("{}", m.group(1), 1) + return None + + +def _load_and_remap_checkpoint(model_dir: str, config: Gemma4_31BConfig) -> dict: + """Stream-load safetensors shards and remap keys to model state_dict keys.""" + from safetensors import safe_open + + index_path = os.path.join(model_dir, "model.safetensors.index.json") + if os.path.exists(index_path): + with open(index_path, "r") as f: + index = json.load(f) + shard_files = sorted(set(index["weight_map"].values())) + elif os.path.exists(os.path.join(model_dir, "model.safetensors")): + shard_files = ["model.safetensors"] + else: + raise FileNotFoundError(f"No safetensors checkpoint in {model_dir}") + + state_dict: dict[str, torch.Tensor] = {} + skipped = 0 + for shard_file in shard_files: + shard_path = os.path.join(model_dir, shard_file) + with safe_open(shard_path, framework="pt", device="cpu") as f: + for ckpt_key in f.keys(): + model_key = _hf_to_model_key(ckpt_key) + if model_key is None: + skipped += 1 + continue + tensor = f.get_tensor(ckpt_key) + # layer_scalar in checkpoint is shape (1,) bf16 — keep as-is. + state_dict[model_key] = tensor + if skipped > 0: + print(f" Skipped {skipped} non-text keys (vision tower, etc.)") + return state_dict + + +# --------------------------------------------------------------------------- +# Runtime buffer materialization + + +def materialize_runtime_buffers( + model: Gemma4_31B, + dtype: torch.dtype, + device: str = "cpu", +) -> None: + """Replace meta-device buffers with real tensors and set runtime constants. + + Called after weight loading to fill in KV caches (zeros), RoPE tables + (computed), and scalar constants. Only touches buffers still on the meta + device — loaded (non-meta) buffers are left in place. + """ + config = model.config + + for fqn, buf in list(model.named_buffers()): + if buf.device.type != "meta": + continue + parts = fqn.rsplit(".", 1) + parent = model.get_submodule(parts[0]) if len(parts) > 1 else model + is_kv = ".kv_cache." in fqn + target_dtype = dtype if is_kv else torch.float32 + if buf.dtype == torch.bool: + target_dtype = torch.bool + parent.register_buffer( + parts[-1], + torch.zeros(buf.shape, dtype=target_dtype, device=device), + persistent=False, + ) + + for layer in model.layers: + attn = layer.self_attn + attn.inv_freq = attn.inv_freq.to(device) + + model.register_buffer( + "embed_normalizer", + torch.tensor(config.hidden_size**0.5, device=device), + persistent=False, + ) + model.register_buffer( + "logit_softcap", + torch.tensor(config.final_logit_softcapping, device=device), + persistent=False, + ) + model.register_buffer( + "cache_positions", + torch.arange(config.max_seq_len, dtype=torch.long, device=device), + persistent=False, + ) diff --git a/examples/models/gemma4_31b/quant/README.md b/examples/models/gemma4_31b/quant/README.md new file mode 100644 index 00000000000..31b1c43d574 --- /dev/null +++ b/examples/models/gemma4_31b/quant/README.md @@ -0,0 +1,54 @@ +# quant/ + +Quantization framework: **recipe → quantize → pack**. + +## Files + +| File | Concern | Depends on | +|---|---|---| +| `recipe.py` | **Policy** — what to quantize, what precision, which layers | nothing | +| `quantize.py` | **Computation** — produces torchao subclass tensors | recipe, torchao | +| `pack.py` | **Packing dispatch** — `pack_model` (bulk) and `pack_one` (streaming) | — | +| `pack_cuda.py` | **CUDA packing** — converts Int4Tensor to tinygemm format | pack | +| `gguf.py` | **GGUF import** — unpacks Q4_K/Q6_K blocks to torchao subclasses | torchao | + +## Data flow + +``` +QuantRecipe → quantize_model() → state_dict{Int4Tensor, IntxUnpackedToInt8Tensor, Tensor} → safetensors → state_dict → pack_model() → runtime model +``` + +Quantized weights are stored as torchao tensor subclasses: +- **Int4Tensor** — 4-bit weights (nibble-packed qdata + transposed scale/zero_point) +- **IntxUnpackedToInt8Tensor** — 8-bit weights (int8 qdata + scale + zero_point) + +These are the canonical interchange formats from torchao. Everything left +of `save()` is backend-agnostic. Everything right is backend-specific. + +## Adding a new backend + +Write a `pack_.py` with per-module packers and a default registry: + +```python +def pack_linear_for_metal(module, weights): ... +DEFAULT_METAL_PACKERS = {nn.Linear: pack_linear_for_metal} +``` + +Call `pack_model(model, state_dict, packers=DEFAULT_METAL_PACKERS)`. +No changes to recipe or quantize. + +## On-disk format + +Uses torchao's safetensors integration (`torchao.prototype.safetensors`). +Each tensor subclass is decomposed into its inner tensors +(e.g., `layer._weight_qdata`, `layer._weight_scale`) plus JSON metadata +recording the subclass type and attributes. Plain tensors are stored as-is. +The format is compatible with torchao's `save_pretrained` / `load_pretrained`. + +## TODO + +- `pack_metal.py` — Metal backend packer. +- `pack_mlx.py` — MLX backend packer. +- `gguf.py` — extend with Q5_K, Q8_0 GGUF quant types. +- Upstream `Int4TilePackedTo4dTensor.from_int4_tensor()` to torchao + to replace the manual conversion in `pack_int4_for_cuda`. diff --git a/examples/models/gemma4_31b/quant/__init__.py b/examples/models/gemma4_31b/quant/__init__.py new file mode 100644 index 00000000000..93efb69865f --- /dev/null +++ b/examples/models/gemma4_31b/quant/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .pack import ModulePackerFn, pack_model, pack_one # noqa: F401 +from .pack_cuda import DEFAULT_CUDA_PACKERS, load_and_pack_for_cuda # noqa: F401 +from .quantize import dequantize_weight, quantize_model, quantize_weight # noqa: F401 +from .recipe import QuantConfig, QuantRecipe, QuantRule # noqa: F401 diff --git a/examples/models/gemma4_31b/quant/gguf.py b/examples/models/gemma4_31b/quant/gguf.py new file mode 100644 index 00000000000..78c3aa3d8f9 --- /dev/null +++ b/examples/models/gemma4_31b/quant/gguf.py @@ -0,0 +1,209 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Unpack GGUF quantized tensors to torchao tensor subclasses. + +Supports Q4_K, Q6_K, F32, and F16 tensor types. Two public APIs: + + - ``unpack_gguf_tensor`` — convert a single tensor + - ``iter_gguf_tensors`` — stream all tensors from a file (low peak memory) + +Model-agnostic. For Gemma 4 31B key mapping and model loading, see +``gguf_loader.py``. +""" + +from collections.abc import Iterator + +import torch + +QK_K = 256 # super-block size for k-quants +Q4_K_GROUPS = 8 # sub-blocks per Q4_K super-block +Q4_K_GROUP_SIZE = QK_K // Q4_K_GROUPS # 32 +Q6_K_GROUPS = 16 # sub-blocks per Q6_K super-block +Q6_K_GROUP_SIZE = QK_K // Q6_K_GROUPS # 16 + + +def _raw_tensor(data: bytes) -> torch.Tensor: + """Wrap a numpy mmap view as a uint8 torch tensor (zero-copy).""" + return torch.frombuffer(memoryview(data), dtype=torch.uint8) + + +def _read_f16(raw: torch.Tensor, col_start: int, col_end: int) -> torch.Tensor: + """Read fp16 field from block bytes, return float32.""" + return raw[:, col_start:col_end].contiguous().view(torch.float16).float() + + +def _unpack_q4_k(data, shape: list[int]) -> torch.Tensor: + """Unpack Q4_K super-blocks into an ``Int4Tensor``. + + Q4_K block layout (144 bytes per 256 values): + - d (2B, fp16): super-block scale + - dmin (2B, fp16): super-block min + - scales (12B): 8 sub-block scales + 8 sub-block mins, 6-bit packed + - qs (128B): 256 4-bit values, two per byte + + Dequant: weight = d * sub_scale * q - dmin * sub_min + """ + from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + + N, K = shape + assert K % QK_K == 0, f"Q4_K requires K divisible by {QK_K}, got {K}" + n_blocks = N * (K // QK_K) + block_bytes = 2 + 2 + 12 + QK_K // 2 # 144 + raw = _raw_tensor(data).reshape(n_blocks, block_bytes) + + d = _read_f16(raw, 0, 2) + dmin = _read_f16(raw, 2, 4) + s = raw[:, 4:16] + qs = raw[:, 16:144] + + sc = torch.empty(n_blocks, 8, dtype=torch.float32) + mn = torch.empty(n_blocks, 8, dtype=torch.float32) + sc[:, :4] = (s[:, :4] & 0x3F).float() + mn[:, :4] = (s[:, 4:8] & 0x3F).float() + sc[:, 4:] = ((s[:, 8:12] & 0xF) | ((s[:, :4] >> 6) << 4)).float() + mn[:, 4:] = ((s[:, 8:12] >> 4) | ((s[:, 4:8] >> 6) << 4)).float() + del s + + eff_scale = (d * sc).reshape(N, -1) + eff_min = (dmin * mn).reshape(N, -1) + del d, dmin, sc, mn + + zero_std = torch.where( + eff_scale != 0, eff_min / eff_scale, torch.zeros_like(eff_min) + ) + del eff_min + + # GGUF Q4_K nibble order: 32 lows then 32 highs per sub-block pair + low = (qs & 0x0F).to(torch.uint8) + high = ((qs >> 4) & 0x0F).to(torch.uint8) + qdata_unpacked = torch.cat( + [ + low[:, :32], + high[:, :32], + low[:, 32:64], + high[:, 32:64], + low[:, 64:96], + high[:, 64:96], + low[:, 96:128], + high[:, 96:128], + ], + dim=-1, + ).reshape(N, K) + del qs, low, high + + # Nibble-pack for Int4Tensor: even=LOW, odd=HIGH + packed = qdata_unpacked[:, ::2] | (qdata_unpacked[:, 1::2] << 4) + + # Int4Tensor scale/zero layout: (K//gs, N) — transposed + return Int4Tensor( + qdata=packed, + scale=eff_scale.to(torch.bfloat16).t().contiguous(), + zero_point=zero_std.to(torch.bfloat16).t().contiguous(), + block_size=[1, Q4_K_GROUP_SIZE], + shape=torch.Size([N, K]), + ) + + +def _unpack_q6_k(data, shape: list[int]) -> torch.Tensor: + """Unpack Q6_K super-blocks into an ``IntxUnpackedToInt8Tensor``. + + Q6_K block layout (210 bytes per 256 values): + - ql (128B): lower 4 bits of 256 6-bit values + - qh (64B): upper 2 bits of 256 6-bit values + - scales (16B): 16 int8 sub-block scales (groups of 16) + - d (2B, fp16): super-block scale + + Dequant: weight = d * scale_j * (q - 32) + Values are 6-bit [-32, 31], widened to INT8. + """ + from torchao.quantization import IntxUnpackedToInt8Tensor + + N, K = shape + assert K % QK_K == 0, f"Q6_K requires K divisible by {QK_K}, got {K}" + n_blocks = N * (K // QK_K) + block_bytes = 2 + QK_K // 2 + QK_K // 4 + QK_K // 16 # 210 + raw = _raw_tensor(data).reshape(n_blocks, block_bytes) + + ql = raw[:, 0:128] + qh = raw[:, 128:192] + sc = raw[:, 192:208] + d = _read_f16(raw, 208, 210) + + qh0 = qh[:, :32] + qh1 = qh[:, 32:64] + qdata = torch.empty(n_blocks, QK_K, dtype=torch.int16) + qdata[:, 0:32] = (ql[:, :32] & 0x0F) | ((qh0 & 0x03) << 4) + qdata[:, 32:64] = (ql[:, 32:64] & 0x0F) | (((qh0 >> 2) & 0x03) << 4) + qdata[:, 64:96] = ((ql[:, :32] >> 4) & 0x0F) | (((qh0 >> 4) & 0x03) << 4) + qdata[:, 96:128] = ((ql[:, 32:64] >> 4) & 0x0F) | (((qh0 >> 6) & 0x03) << 4) + qdata[:, 128:160] = (ql[:, 64:96] & 0x0F) | ((qh1 & 0x03) << 4) + qdata[:, 160:192] = (ql[:, 96:128] & 0x0F) | (((qh1 >> 2) & 0x03) << 4) + qdata[:, 192:224] = ((ql[:, 64:96] >> 4) & 0x0F) | (((qh1 >> 4) & 0x03) << 4) + qdata[:, 224:256] = ((ql[:, 96:128] >> 4) & 0x0F) | (((qh1 >> 6) & 0x03) << 4) + qdata -= 32 + del ql, qh, qh0, qh1 + + # sc bytes are signed int8 scales; reinterpret from uint8 + eff_scale = (d * sc.to(torch.int8).float()).reshape(N, -1) + del d, sc + + return IntxUnpackedToInt8Tensor( + qdata=qdata.reshape(N, K).to(torch.int8), + scale=eff_scale.to(torch.bfloat16), + zero_point=torch.zeros_like(eff_scale, dtype=torch.int8), + target_dtype=torch.int8, + block_size=(1, Q6_K_GROUP_SIZE), + dtype=torch.bfloat16, + activation_quantization=None, + ) + + +def unpack_gguf_tensor( + tensor_data, + tensor_type, + shape: list[int], +) -> torch.Tensor: + """Unpack a single GGUF tensor. + + Returns an ``Int4Tensor`` for Q4_K, ``IntxUnpackedToInt8Tensor`` for Q6_K, + or a plain ``torch.Tensor`` for F32/F16. + """ + from gguf import GGMLQuantizationType + + if tensor_type == GGMLQuantizationType.Q4_K: + return _unpack_q4_k(tensor_data, shape) + elif tensor_type == GGMLQuantizationType.Q6_K: + return _unpack_q6_k(tensor_data, shape) + elif tensor_type == GGMLQuantizationType.F32: + return _raw_tensor(tensor_data).view(torch.float32).reshape(shape).clone() + elif tensor_type == GGMLQuantizationType.F16: + return ( + _raw_tensor(tensor_data) + .view(torch.float16) + .reshape(shape) + .to(torch.bfloat16) + ) + else: + raise ValueError(f"Unsupported GGUF quant type: {tensor_type}") + + +def iter_gguf_tensors( + path: str, +) -> Iterator[tuple[str, torch.Tensor]]: + """Yield ``(name, result)`` for each tensor in a GGUF file. + + Processes one tensor at a time for low peak memory. Tensor names are + GGUF names (e.g., ``blk.0.attn_q.weight``); the caller handles key + remapping. GGUF shapes are reversed to PyTorch convention automatically. + """ + from gguf import GGUFReader + + reader = GGUFReader(path) + for tensor in reader.tensors: + shape = list(reversed(tensor.shape.tolist())) + result = unpack_gguf_tensor(tensor.data, tensor.tensor_type, shape) + yield tensor.name, result diff --git a/examples/models/gemma4_31b/quant/pack.py b/examples/models/gemma4_31b/quant/pack.py new file mode 100644 index 00000000000..95abc43546a --- /dev/null +++ b/examples/models/gemma4_31b/quant/pack.py @@ -0,0 +1,106 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Backend-agnostic model packing: quantized state dict → runtime model. + +``pack_model`` walks a state dict, groups quantized weights by parent +module, and dispatches to per-module packer functions. Each backend +(``pack_cuda.py``, future ``pack_metal.py``) provides its own packers dict. +""" + +from collections import defaultdict +from typing import Callable + +import torch +import torch.nn as nn + +# Packer signature: receives the module + a dict of its quantized weights +# (keyed by attribute name), modifies module in-place. +ModulePackerFn = Callable[[nn.Module, dict[str, torch.Tensor]], None] + + +def _is_quantized(value: torch.Tensor) -> bool: + """Check if a tensor is a torchao quantized subclass.""" + from torchao.utils import TorchAOBaseTensor + + return isinstance(value, TorchAOBaseTensor) + + +def pack_model( + model: nn.Module, + state_dict: dict[str, torch.Tensor], + packers: dict[type, ModulePackerFn], +) -> None: + """Pack a state dict into ``model`` using the given packers. + + Quantized weights (torchao tensor subclasses) are grouped by parent + module and dispatched to per-module packers. Plain tensors are assigned + directly as parameters or buffers. + """ + # Separate quantized and unquantized + for fqn, value in state_dict.items(): + if not _is_quantized(value): + pack_one(model, fqn, value, packers) + + # Group quantized weights by parent module + module_weights: dict[str, dict[str, torch.Tensor]] = defaultdict(dict) + for fqn, value in state_dict.items(): + if _is_quantized(value): + parts = fqn.rsplit(".", 1) + parent_fqn = parts[0] if len(parts) > 1 else "" + attr = parts[-1] + module_weights[parent_fqn][attr] = value + + for parent_fqn, weights in module_weights.items(): + module = model.get_submodule(parent_fqn) if parent_fqn else model + packer = packers.get(type(module)) + if packer is None: + raise ValueError( + f"No packer registered for {type(module).__name__} at '{parent_fqn}'. " + f"Registered types: {[t.__name__ for t in packers]}." + ) + packer(module, weights) + + for fqn, p in model.named_parameters(): + if p.device.type == "meta": + raise RuntimeError( + f"Weight '{fqn}' not found in checkpoint " + f"(model/checkpoint version mismatch?)" + ) + + for p in model.parameters(): + p.requires_grad_(False) + + +def pack_one( + model: nn.Module, + fqn: str, + value: torch.Tensor, + packers: dict[type, ModulePackerFn], +) -> None: + """Pack a single weight into ``model``. + + Quantized subclass tensors are dispatched to the packer for the parent + module's type. Plain tensors are assigned directly. + """ + parts = fqn.rsplit(".", 1) + parent_fqn = parts[0] if len(parts) > 1 else "" + attr = parts[-1] + parent = model.get_submodule(parent_fqn) if parent_fqn else model + + if _is_quantized(value): + packer = packers.get(type(parent)) + if packer is None: + raise ValueError( + f"No packer registered for {type(parent).__name__} at '{parent_fqn}'. " + f"Registered types: {[t.__name__ for t in packers]}." + ) + packer(parent, {attr: value}) + else: + if isinstance(getattr(parent, attr, None), nn.Parameter): + setattr(parent, attr, nn.Parameter(value, requires_grad=False)) + else: + parent.register_buffer(attr, value) diff --git a/examples/models/gemma4_31b/quant/pack_cuda.py b/examples/models/gemma4_31b/quant/pack_cuda.py new file mode 100644 index 00000000000..7c834505d36 --- /dev/null +++ b/examples/models/gemma4_31b/quant/pack_cuda.py @@ -0,0 +1,105 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""CUDA packer: assign quantized weights to model modules. + +Passes ``Int4Tensor`` and ``IntxUnpackedToInt8Tensor`` through as +``nn.Parameter`` without conversion. The Int4Tensor dispatch override +(``int4_dispatch.py``) handles F.linear at runtime. + +No CUDA is required for packing. The backend-agnostic ``pack_model`` +dispatcher lives in ``pack.py``. +""" + +import json + +import torch +import torch.nn as nn + +from .pack import ModulePackerFn, pack_model # noqa: F401 + + +# --------------------------------------------------------------------------- +# Per-module packers + + +def pack_linear_for_cuda(module: nn.Module, weights: dict[str, torch.Tensor]) -> None: + """Assign a quantized weight to an ``nn.Linear`` module.""" + from torchao.quantization import IntxUnpackedToInt8Tensor + from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + + w = weights["weight"] + if isinstance(w, (Int4Tensor, IntxUnpackedToInt8Tensor)): + module.weight = nn.Parameter(w, requires_grad=False) + else: + raise ValueError(f"Unsupported weight type: {type(w).__name__}") + + +def pack_embedding_for_cuda( + module: nn.Module, weights: dict[str, torch.Tensor] +) -> None: + """Assign a quantized weight to an ``nn.Embedding`` (INT8 only).""" + from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + + w = weights["weight"] + if isinstance(w, Int4Tensor): + raise ValueError( + "Only 8-bit embedding quantization is supported on CUDA. " + "INT4 does not implement the embedding op." + ) + module.weight = nn.Parameter(w, requires_grad=False) + + +DEFAULT_CUDA_PACKERS: dict[type, ModulePackerFn] = { + nn.Linear: pack_linear_for_cuda, + nn.Embedding: pack_embedding_for_cuda, +} + + +# --------------------------------------------------------------------------- +# Load + pack (I/O wrapper) + + +def load_and_pack_for_cuda( + path: str, + model: nn.Module, + packers: dict[type, ModulePackerFn] | None = None, +) -> None: + """Load a quantized safetensors file and assign weights to the model.""" + from safetensors import safe_open + from torchao.prototype.safetensors.safetensors_support import ( + unflatten_tensor_state_dict, + ) + + from .pack import pack_one + + _packers = packers or DEFAULT_CUDA_PACKERS + with safe_open(path, framework="pt", device="cpu") as f: + metadata = f.metadata() + all_keys = list(f.keys()) + tensor_names = json.loads(metadata.get("tensor_names", "[]")) + + for name in tensor_names: + parts = name.rsplit(".", 1) + module_fqn = parts[0] if len(parts) > 1 else "" + weight_name = parts[-1] + prefix = ( + f"{module_fqn}._{weight_name}_" if module_fqn else f"_{weight_name}_" + ) + partial = {} + for key in all_keys: + if key.startswith(prefix) or key == name: + partial[key] = f.get_tensor(key) + result, _ = unflatten_tensor_state_dict(partial, metadata) + for fqn, value in result.items(): + pack_one(model, fqn, value, _packers) + + for fqn, p in model.named_parameters(): + if p.device.type == "meta": + raise RuntimeError( + f"Weight '{fqn}' not found in checkpoint " + f"(model/checkpoint version mismatch?)" + ) diff --git a/examples/models/gemma4_31b/quant/quantize.py b/examples/models/gemma4_31b/quant/quantize.py new file mode 100644 index 00000000000..ade85efd788 --- /dev/null +++ b/examples/models/gemma4_31b/quant/quantize.py @@ -0,0 +1,313 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Quantize weights to torchao tensor subclasses. + +``quantize_weight`` quantizes a single tensor given a ``QuantConfig``, +returning an ``Int4Tensor`` (4-bit) or ``IntxUnpackedToInt8Tensor`` (8-bit). + +``quantize_model`` walks a model's parameters, applies a ``QuantRecipe``, +and returns a single state dict containing both quantized subclass tensors +and unquantized plain tensors. +""" + +import torch +import torch.nn as nn + +from .recipe import QuantConfig, QuantRecipe + + +# --------------------------------------------------------------------------- +# Per-weight quantization + + +def _quantize_min_max( + weight: torch.Tensor, + config: QuantConfig, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Standard min/max 4-bit quantization. Returns (int_data, scale, zero_point).""" + from torchao.quantization.quant_primitives import ( + choose_qparams_affine, + MappingType, + quantize_affine, + ) + + qmin, qmax = (-8, 7) if config.symmetric else (0, 15) + + mapping = MappingType.SYMMETRIC if config.symmetric else MappingType.ASYMMETRIC + block_size = tuple([1] * (weight.ndim - 1) + [config.group_size]) + + scale, zero_point = choose_qparams_affine( + weight.float(), + mapping, + block_size, + target_dtype=torch.int8, + quant_min=qmin, + quant_max=qmax, + scale_dtype=torch.bfloat16, + zero_point_dtype=torch.bfloat16, + ) + int_data = quantize_affine( + weight.float(), + block_size, + scale, + zero_point, + output_dtype=torch.int8, + quant_min=qmin, + quant_max=qmax, + ) + return int_data, scale, zero_point + + +def _quantize_hqq_asymmetric( + weight: torch.Tensor, + config: QuantConfig, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Full HQQ (asymmetric, optimizes scale + zero). Requires CUDA.""" + from torchao.quantization.quant_primitives import ( + _choose_qparams_and_quantize_affine_hqq, + ) + + device = weight.device + if device.type != "cuda": + device = torch.device("cuda") + + W_q, scale, zero, _shape = _choose_qparams_and_quantize_affine_hqq( + weight, + nbits=config.bits, + group_size=config.group_size, + axis=1, + compute_dtype=torch.bfloat16, + device=str(device), + raw_output=True, + ) + + int_data = W_q.to(torch.int8) + scale = scale.to(torch.bfloat16).reshape(*weight.shape[:-1], -1) + zero = zero.to(torch.bfloat16).reshape(*weight.shape[:-1], -1) + + return int_data, scale, zero + + +def _quantize_hqq_symmetric( + weight: torch.Tensor, + config: QuantConfig, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Scale-only HQQ (symmetric 4-bit, optimizes scale only). Runs on CPU or CUDA.""" + from torchao.quantization.quant_primitives import ( + _choose_qparams_and_quantize_scale_only_hqq, + ) + + qmin, qmax = -8, 7 + + orig_shape = weight.shape + weight_2d = weight.reshape(-1, weight.shape[-1]) if weight.ndim > 2 else weight + + qdata, scale = _choose_qparams_and_quantize_scale_only_hqq( + weight_2d, + [1, config.group_size], + qmin, + qmax, + ) + + int_data = qdata.to(torch.int8).reshape(orig_shape) + scale = scale.to(torch.bfloat16).reshape(*orig_shape[:-1], -1) + zero_point = torch.zeros_like(scale) + + return int_data, scale, zero_point + + +def _to_int4_tensor( + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + config: QuantConfig, +) -> torch.Tensor: + """Wrap quantized 4-bit data into an Int4Tensor.""" + from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + + # Normalize 4-bit signed [-8, 7] to unsigned [0, 15] for storage. + if config.symmetric: + int_data = int_data + 8 + zero_point = torch.full_like(scale, 8.0) + + # Int4Tensor stores qdata as nibble-packed uint8 (N, K//2) + q = int_data.to(torch.uint8) + packed = q[..., ::2] | (q[..., 1::2] << 4) + + # Int4Tensor stores scale/zero as (K//gs, N) — transposed from our (N, K//gs) + return Int4Tensor( + qdata=packed, + scale=scale.t().contiguous(), + zero_point=zero_point.t().contiguous(), + block_size=[1, config.group_size], + shape=torch.Size(int_data.shape), + ) + + +def _to_intx_tensor( + weight: torch.Tensor, + config: QuantConfig, +) -> torch.Tensor: + """Quantize 8-bit and wrap in IntxUnpackedToInt8Tensor. + + Quantizes in float32 for numerical precision, then constructs the + subclass directly. We avoid ``from_hp`` because it quantizes in the + input dtype (bf16), which loses precision for small-magnitude weights. + """ + from torchao.quantization import IntxUnpackedToInt8Tensor + from torchao.quantization.quant_primitives import ( + choose_qparams_affine, + MappingType, + quantize_affine, + ) + + if config.method == "hqq": + if not config.symmetric: + raise ValueError( + "8-bit HQQ only supports symmetric quantization " + "(HQQ_SCALE_ONLY). Use method='min_max' for asymmetric 8-bit." + ) + from torchao.quantization.quant_primitives import ( + _choose_qparams_and_quantize_scale_only_hqq, + ) + + w2d = weight.float().reshape(-1, weight.shape[-1]) + qdata, scale = _choose_qparams_and_quantize_scale_only_hqq( + w2d, [1, config.group_size], -128, 127 + ) + qdata = qdata.to(torch.int8).reshape(weight.shape) + scale = scale.to(torch.bfloat16).reshape(weight.shape[0], -1) + zero_point = torch.zeros_like(scale, dtype=torch.int8) + else: + mapping = MappingType.SYMMETRIC if config.symmetric else MappingType.ASYMMETRIC + block_size = (1, config.group_size) + scale, zero_point = choose_qparams_affine( + weight.float(), + mapping, + block_size, + target_dtype=torch.int8, + quant_min=-128, + quant_max=127, + scale_dtype=torch.bfloat16, + zero_point_dtype=torch.int8, + ) + qdata = quantize_affine( + weight.float(), + block_size, + scale, + zero_point, + output_dtype=torch.int8, + quant_min=-128, + quant_max=127, + ) + N, n_groups = weight.shape[0], weight.shape[-1] // config.group_size + scale = scale.reshape(N, n_groups) + zero_point = zero_point.reshape(N, n_groups) + + return IntxUnpackedToInt8Tensor( + qdata=qdata, + scale=scale, + zero_point=zero_point, + target_dtype=torch.int8, + block_size=(1, config.group_size), + dtype=torch.bfloat16, + activation_quantization=None, + ) + + +def quantize_weight(weight: torch.Tensor, config: QuantConfig) -> torch.Tensor: + """Quantize ``weight`` to a torchao tensor subclass. + + Returns ``Int4Tensor`` for 4-bit or ``IntxUnpackedToInt8Tensor`` for 8-bit. + """ + if config.bits == 8: + return _to_intx_tensor(weight, config) + + if config.bits != 4: + raise ValueError(f"Unsupported bits={config.bits}") + + if config.method == "min_max": + int_data, scale, zero_point = _quantize_min_max(weight, config) + elif config.method == "hqq": + if config.symmetric: + int_data, scale, zero_point = _quantize_hqq_symmetric(weight, config) + else: + int_data, scale, zero_point = _quantize_hqq_asymmetric(weight, config) + else: + raise ValueError( + f"Unknown quantization method: {config.method!r}. " + f"Supported: 'min_max', 'hqq'." + ) + + return _to_int4_tensor(int_data, scale, zero_point, config) + + +def dequantize_weight( + weight: torch.Tensor, + dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """Dequantize a torchao quantized tensor back to a dense tensor.""" + from torchao.quantization import IntxUnpackedToInt8Tensor + from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + + if isinstance(weight, Int4Tensor): + # Unpack nibbles + p = weight.qdata.to(torch.uint8) + low = (p & 0x0F).float() + high = ((p >> 4) & 0x0F).float() + qdata = torch.stack([low, high], dim=-1).reshape(weight.shape) + # Scale is (K//gs, N), transpose to (N, K//gs) for broadcast + gs = weight.block_size[-1] + scale = weight.scale.t().float().repeat_interleave(gs, dim=-1) + zero = weight.zero_point.t().float().repeat_interleave(gs, dim=-1) + return ((qdata - zero) * scale).to(dtype) + + if isinstance(weight, IntxUnpackedToInt8Tensor): + gs = weight.block_size[-1] + scale = weight.scale.float().repeat_interleave(gs, dim=-1) + zero = weight.zero_point.float().repeat_interleave(gs, dim=-1) + return ((weight.qdata.float() - zero) * scale).to(dtype) + + raise TypeError(f"Cannot dequantize {type(weight).__name__}") + + +# --------------------------------------------------------------------------- +# Per-model quantization + + +def quantize_model( + model: nn.Module, + recipe: QuantRecipe, + dtype: torch.dtype = torch.bfloat16, + verbose: bool = False, +) -> dict[str, torch.Tensor]: + """Walk model parameters + persistent buffers, apply recipe. + + Returns a single state dict containing quantized tensor subclasses + (``Int4Tensor``, ``IntxUnpackedToInt8Tensor``) and unquantized plain + tensors. Non-persistent buffers (KV cache, RoPE tables) are excluded. + """ + state: dict[str, torch.Tensor] = {} + persistent_keys = set(model.state_dict().keys()) + + n_params = sum(1 for _ in model.named_parameters()) + for i, (fqn, param) in enumerate(model.named_parameters()): + config = recipe.get_config(fqn) + if config is None: + state[fqn] = param.data.to(dtype) + else: + state[fqn] = quantize_weight(param.data, config) + if verbose: + print(f" Quantized {i + 1}/{n_params}: {fqn}", end="\r") + if verbose: + print() + + for fqn, buf in model.named_buffers(): + if fqn in persistent_keys and fqn not in state: + state[fqn] = buf.data + + return state diff --git a/examples/models/gemma4_31b/quant/recipe.py b/examples/models/gemma4_31b/quant/recipe.py new file mode 100644 index 00000000000..6e29a93ba3e --- /dev/null +++ b/examples/models/gemma4_31b/quant/recipe.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Quantization recipe: declares what to quantize and how. + +A ``QuantRecipe`` is an ordered list of ``QuantRule`` objects matched against +weight FQNs. First match wins. The recipe says nothing about packing format, +tensor subclass, or target backend. +""" + +import re +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass(frozen=True) +class QuantConfig: + """Per-weight quantization parameters (quantization-time only). + + Not stored in the serialized checkpoint — torchao tensor subclasses + carry their own metadata. This is purely for driving ``quantize_weight``. + """ + + bits: int # 4 or 8 + group_size: int + symmetric: bool # True = no zero point + method: str # "min_max" | "hqq" + + +@dataclass +class QuantRule: + """A single recipe rule: regex pattern + config + optional layer filter.""" + + pattern: str # regex matched against weight FQN + config: Optional[QuantConfig] # None = skip (leave unquantized) + layers: Optional[set[int]] = field(default=None, repr=False) # None = all layers + + +@dataclass +class QuantRecipe: + """Ordered list of rules. First match wins.""" + + rules: list[QuantRule] + + def get_config(self, fqn: str) -> Optional[QuantConfig]: + """Return the ``QuantConfig`` for a weight FQN, or ``None`` to skip.""" + layer_idx = self._extract_layer_idx(fqn) + for rule in self.rules: + if rule.layers is not None: + if layer_idx is None or layer_idx not in rule.layers: + continue + if re.fullmatch(rule.pattern, fqn): + return rule.config + return None + + @staticmethod + def _extract_layer_idx(fqn: str) -> Optional[int]: + m = re.search(r"layers\.(\d+)\.", fqn) + return int(m.group(1)) if m else None diff --git a/examples/models/gemma4_31b/quant/tests/test_gguf.py b/examples/models/gemma4_31b/quant/tests/test_gguf.py new file mode 100644 index 00000000000..89a7099d6f0 --- /dev/null +++ b/examples/models/gemma4_31b/quant/tests/test_gguf.py @@ -0,0 +1,282 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Unit tests for quant/gguf.py — Q4_K and Q6_K unpacking. + +Tests verify the API contract: dequantized weights match the original +GGUF dequantization formula. Uses synthetic blocks — no GGUF file required. +""" + +import os +import struct +import tempfile +import unittest + +import numpy as np +import torch + +try: + from gguf import GGMLQuantizationType + + _HAS_GGUF = True +except ImportError: + _HAS_GGUF = False + +if _HAS_GGUF: + from executorch.examples.models.gemma4_31b.quant.gguf import unpack_gguf_tensor + +from executorch.examples.models.gemma4_31b.quant.quantize import dequantize_weight +from safetensors import safe_open +from safetensors.torch import save_file +from torchao.prototype.safetensors.safetensors_support import ( + flatten_tensor_state_dict, + unflatten_tensor_state_dict, +) + + +def _make_q4_k_block(d, dmin, sub_scales, sub_mins, qvals): + """Build one Q4_K block (144 bytes) from components.""" + buf = bytearray(144) + struct.pack_into("> 4) << 6 + scales_bytes[j] |= (sub_mins[j] >> 4) << 6 + buf[4:16] = scales_bytes + # GGUF Q4_K nibble order: 32 lows then 32 highs per sub-block pair + for g in range(4): + for i in range(32): + lo_val = qvals[g * 64 + i] + hi_val = qvals[g * 64 + 32 + i] + buf[16 + g * 32 + i] = (lo_val & 0xF) | ((hi_val & 0xF) << 4) + return buf + + +def _make_q6_k_block(d, scales_16, qvals_256): + """Build one Q6_K block (210 bytes) from components. + + ggml processes 128 values at a time. For each 128-value half: + ql: 64 bytes (two groups of 32, low/high nibbles) + qh: 32 bytes (2 bits each for 4 sub-positions) + The qvals_256 array is in output order (position 0..255). + """ + buf = bytearray(210) + # First half (positions 0..127): ql bytes 0..63, qh bytes 0..31 + for i in range(32): + buf[i] = (qvals_256[i] & 0x0F) | ((qvals_256[i + 64] & 0x0F) << 4) + for i in range(32): + buf[32 + i] = (qvals_256[i + 32] & 0x0F) | ((qvals_256[i + 96] & 0x0F) << 4) + for i in range(32): + h0 = (qvals_256[i] >> 4) & 0x03 + h1 = (qvals_256[i + 32] >> 4) & 0x03 + h2 = (qvals_256[i + 64] >> 4) & 0x03 + h3 = (qvals_256[i + 96] >> 4) & 0x03 + buf[128 + i] = h0 | (h1 << 2) | (h2 << 4) | (h3 << 6) + # Second half (positions 128..255): ql bytes 64..127, qh bytes 32..63 + for i in range(32): + buf[64 + i] = (qvals_256[i + 128] & 0x0F) | ((qvals_256[i + 192] & 0x0F) << 4) + for i in range(32): + buf[96 + i] = (qvals_256[i + 160] & 0x0F) | ((qvals_256[i + 224] & 0x0F) << 4) + for i in range(32): + h0 = (qvals_256[i + 128] >> 4) & 0x03 + h1 = (qvals_256[i + 160] >> 4) & 0x03 + h2 = (qvals_256[i + 192] >> 4) & 0x03 + h3 = (qvals_256[i + 224] >> 4) & 0x03 + buf[160 + i] = h0 | (h1 << 2) | (h2 << 4) | (h3 << 6) + # Scales and d + for i in range(16): + buf[192 + i] = scales_16[i] & 0xFF + struct.pack_into(" None: + if not torch.cuda.is_available(): + tc.skipTest("CUDA required") + + +class TestPackLinearInt4(unittest.TestCase): + """pack_linear_for_cuda with INT4 weights produces correct F.linear output.""" + + def setUp(self): + _require_cuda(self) + torch.manual_seed(0) + self.weight = torch.randn(256, 1024, dtype=torch.bfloat16) + + def _pack(self, symmetric=False, group_size=32): + config = QuantConfig( + bits=4, group_size=group_size, symmetric=symmetric, method="min_max" + ) + q = quantize_weight(self.weight, config) + module = nn.Linear(1024, 256, bias=False) + pack_linear_for_cuda(module, {"weight": q}) + module.cuda() + return module + + def test_shape_preserved(self): + module = self._pack() + self.assertEqual(module.weight.shape, torch.Size([256, 1024])) + + def test_asymmetric_decode(self): + module = self._pack(symmetric=False) + x = torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda") + ref = torch.nn.functional.linear(x, self.weight.cuda()) + out = module(x) + rel_error = (out.float() - ref.float()).abs().mean() / ref.float().abs().mean() + self.assertLess(rel_error.item(), 0.15) + + def test_symmetric_decode(self): + module = self._pack(symmetric=True) + x = torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda") + ref = torch.nn.functional.linear(x, self.weight.cuda()) + out = module(x) + rel_error = (out.float() - ref.float()).abs().mean() / ref.float().abs().mean() + self.assertLess(rel_error.item(), 0.15) + + def test_prefill_batch(self): + module = self._pack(symmetric=False) + x = torch.randn(64, 1024, dtype=torch.bfloat16, device="cuda") + ref = torch.nn.functional.linear(x, self.weight.cuda()) + out = module(x) + rel_error = (out.float() - ref.float()).abs().mean() / ref.float().abs().mean() + self.assertLess(rel_error.item(), 0.15) + + def test_different_group_sizes(self): + for gs in (32, 64, 128): + with self.subTest(group_size=gs): + module = self._pack(group_size=gs) + x = torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda") + ref = torch.nn.functional.linear(x, self.weight.cuda()) + out = module(x) + rel_error = ( + out.float() - ref.float() + ).abs().mean() / ref.float().abs().mean() + self.assertLess(rel_error.item(), 0.15) + + +class TestPackLinearInt8(unittest.TestCase): + """pack_linear_for_cuda with INT8 weights produces correct F.linear output.""" + + def setUp(self): + _require_cuda(self) + + def test_matmul_correct(self): + torch.manual_seed(0) + weight = torch.randn(256, 128, dtype=torch.bfloat16) + x = torch.randn(1, 128, dtype=torch.bfloat16) + ref = torch.nn.functional.linear(x.cuda(), weight.cuda()) + + config = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") + q = quantize_weight(weight, config) + module = nn.Linear(128, 256, bias=False) + pack_linear_for_cuda(module, {"weight": q}) + module.cuda() + out = module(x.cuda()) + + rel_error = (out.float() - ref.float()).abs().mean() / ref.float().abs().mean() + self.assertLess(rel_error.item(), 0.02) + + def test_unsupported_type_raises(self): + module = nn.Linear(64, 32, bias=False) + with self.assertRaises(ValueError): + pack_linear_for_cuda(module, {"weight": torch.randn(32, 64)}) + + +class TestPackEmbedding(unittest.TestCase): + """pack_embedding_for_cuda with INT8 per-axis weights.""" + + def setUp(self): + _require_cuda(self) + + def test_gather_correct(self): + torch.manual_seed(0) + weight = torch.randn(1000, 64, dtype=torch.bfloat16) + ids = torch.tensor([0, 1, 42, 500, 999]) + ref = weight[ids] + + config = QuantConfig(bits=8, group_size=64, symmetric=True, method="min_max") + q = quantize_weight(weight, config) + module = nn.Embedding(1000, 64) + pack_embedding_for_cuda(module, {"weight": q}) + module.cuda() + out = module(ids.cuda()) + + rel_error = ( + out.cpu().float() - ref.float() + ).abs().mean() / ref.float().abs().mean() + self.assertLess(rel_error.item(), 0.02) + + def test_rejects_4bit(self): + config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") + q = quantize_weight(torch.randn(100, 64, dtype=torch.bfloat16), config) + module = nn.Embedding(100, 64) + with self.assertRaises(ValueError): + pack_embedding_for_cuda(module, {"weight": q}) + + +class TestPackModel(unittest.TestCase): + """pack_model handles mixed-precision models and disk loading.""" + + def setUp(self): + _require_cuda(self) + + def test_mixed_precision(self): + torch.manual_seed(0) + w4 = torch.randn(64, 128, dtype=torch.bfloat16) + w8 = torch.randn(64, 128, dtype=torch.bfloat16) + q4 = quantize_weight( + w4, + QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max"), + ) + q8 = quantize_weight( + w8, + QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max"), + ) + with torch.device("meta"): + model = nn.ModuleDict( + { + "q_proj": nn.Linear(128, 64, bias=False), + "v_proj": nn.Linear(128, 64, bias=False), + } + ) + pack_model( + model, {"q_proj.weight": q4, "v_proj.weight": q8}, DEFAULT_CUDA_PACKERS + ) + model.cuda() + x = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda") + + ref4 = torch.nn.functional.linear(x, w4.cuda()) + out4 = model.q_proj(x) + self.assertLess( + (out4.float() - ref4.float()).abs().mean().item() + / ref4.float().abs().mean().item(), + 0.15, + ) + + ref8 = torch.nn.functional.linear(x, w8.cuda()) + out8 = model.v_proj(x) + self.assertLess( + (out8.float() - ref8.float()).abs().mean().item() + / ref8.float().abs().mean().item(), + 0.02, + ) + + def test_load_and_pack_from_disk(self): + torch.manual_seed(0) + weight = torch.randn(64, 128, dtype=torch.bfloat16) + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + q = quantize_weight(weight, config) + + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + state = { + "proj.weight": q, + "norm.weight": torch.randn(64, dtype=torch.bfloat16), + } + td, md = flatten_tensor_state_dict(state) + save_file(td, path, metadata=md) + + with torch.device("meta"): + model = nn.ModuleDict( + { + "proj": nn.Linear(128, 64, bias=False), + "norm": nn.LayerNorm(64, bias=False), + } + ) + load_and_pack_for_cuda(path, model) + + self.assertEqual(model.proj.weight.shape, torch.Size([64, 128])) + self.assertEqual(model.norm.weight.shape, torch.Size([64])) + + model.proj.cuda() + x = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda") + ref = torch.nn.functional.linear(x, weight.cuda()) + out = model.proj(x) + rel_error = (out.float() - ref.float()).abs().mean() / ref.float().abs().mean() + self.assertLess(rel_error.item(), 0.15) + + def test_pack_one_quantized(self): + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + q = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) + with torch.device("meta"): + model = nn.ModuleDict({"proj": nn.Linear(128, 64, bias=False)}) + pack_one(model, "proj.weight", q, DEFAULT_CUDA_PACKERS) + self.assertNotEqual(model.proj.weight.device.type, "meta") + + def test_pack_one_plain_tensor(self): + with torch.device("meta"): + model = nn.ModuleDict({"norm": nn.LayerNorm(64, bias=False)}) + pack_one( + model, + "norm.weight", + torch.randn(64, dtype=torch.bfloat16), + DEFAULT_CUDA_PACKERS, + ) + self.assertEqual(model.norm.weight.dtype, torch.bfloat16) + + +class TestPackErrorPaths(unittest.TestCase): + + def setUp(self): + _require_cuda(self) + + def test_unregistered_module_type(self): + class CustomModule(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(32, 64)) + + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + q = quantize_weight(torch.randn(32, 64, dtype=torch.bfloat16), config) + + with torch.device("meta"): + model = nn.ModuleDict({"custom": CustomModule()}) + with self.assertRaises(ValueError) as ctx: + pack_model(model, {"custom.weight": q}, DEFAULT_CUDA_PACKERS) + self.assertIn("CustomModule", str(ctx.exception)) + + def test_missing_weight_detected(self): + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + q = quantize_weight(torch.randn(32, 64, dtype=torch.bfloat16), config) + + with torch.device("meta"): + model = nn.ModuleDict( + { + "a": nn.Linear(64, 32, bias=False), + "b": nn.Linear(64, 32, bias=False), + } + ) + with self.assertRaises(RuntimeError) as ctx: + pack_model(model, {"a.weight": q}, DEFAULT_CUDA_PACKERS) + self.assertIn("b.weight", str(ctx.exception)) + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/models/gemma4_31b/quant/tests/test_quantize.py b/examples/models/gemma4_31b/quant/tests/test_quantize.py new file mode 100644 index 00000000000..43970c2bef2 --- /dev/null +++ b/examples/models/gemma4_31b/quant/tests/test_quantize.py @@ -0,0 +1,230 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Unit tests for quant/quantize.py.""" + +import unittest + +import torch +import torch.nn as nn + +from executorch.examples.models.gemma4_31b.quant.quantize import ( + dequantize_weight, + quantize_model, + quantize_weight, +) +from executorch.examples.models.gemma4_31b.quant.recipe import ( + QuantConfig, + QuantRecipe, + QuantRule, +) +from parameterized import parameterized +from torchao.quantization import IntxUnpackedToInt8Tensor +from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + + +class TestQuantizeWeight(unittest.TestCase): + @parameterized.expand( + [ + ("4bit_asym", 4, 32, False), + ("4bit_sym", 4, 32, True), + ("4bit_gs64", 4, 64, False), + ("8bit_sym", 8, 32, True), + ] + ) + def test_output_type(self, _name, bits, gs, sym): + config = QuantConfig(bits=bits, group_size=gs, symmetric=sym, method="min_max") + result = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) + if bits == 4: + self.assertIsInstance(result, Int4Tensor) + self.assertEqual(result.shape, torch.Size([64, 128])) + else: + self.assertIsInstance(result, IntxUnpackedToInt8Tensor) + self.assertEqual(result.shape, torch.Size([64, 128])) + + def test_quantize_dequantize_roundtrip(self): + torch.manual_seed(0) + weight = torch.randn(64, 128, dtype=torch.bfloat16) + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + q = quantize_weight(weight, config) + dequant = dequantize_weight(q, dtype=torch.bfloat16) + rel_error = ( + dequant.float() - weight.float() + ).abs().mean() / weight.float().abs().mean() + self.assertLess(rel_error.item(), 0.15) + + def test_dequantize_output_dtype(self): + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + q = quantize_weight(torch.randn(32, 64, dtype=torch.bfloat16), config) + self.assertEqual(dequantize_weight(q, torch.float32).dtype, torch.float32) + self.assertEqual(dequantize_weight(q, torch.bfloat16).dtype, torch.bfloat16) + + def test_dequantize_symmetric_4bit(self): + torch.manual_seed(1) + weight = torch.randn(32, 64, dtype=torch.bfloat16) + config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") + q = quantize_weight(weight, config) + dequant = dequantize_weight(q) + self.assertEqual(dequant.shape, (32, 64)) + rel_error = ( + dequant.float() - weight.float() + ).abs().mean() / weight.float().abs().mean() + self.assertLess(rel_error.item(), 0.15) + + def test_dequantize_int8(self): + torch.manual_seed(2) + weight = torch.randn(32, 64, dtype=torch.bfloat16) + config = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") + q = quantize_weight(weight, config) + dequant = dequantize_weight(q, dtype=torch.bfloat16) + rel_error = ( + dequant.float() - weight.float() + ).abs().mean() / weight.float().abs().mean() + self.assertLess(rel_error.item(), 0.02) + + def test_int8_small_weights_bf16_precision(self): + """INT8 quantization of small bf16 weights must use full int8 range. + + Regression: IntxUnpackedToInt8Tensor.from_hp quantizes in bf16, + which collapses per-group scales to a single value for weights + with abs_mean ~0.01 (e.g., Gemma 4 v_proj). Our _to_intx_tensor + casts to float32 first to avoid this. + """ + torch.manual_seed(42) + weight = torch.randn(64, 128, dtype=torch.bfloat16) * 0.01 + config = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") + q = quantize_weight(weight, config) + dequant = dequantize_weight(q, dtype=torch.bfloat16) + rel_error = ( + dequant.float() - weight.float() + ).abs().mean() / weight.float().abs().mean() + self.assertLess(rel_error.item(), 0.02) + + def test_dequantize_int8_asymmetric(self): + torch.manual_seed(3) + weight = torch.randn(32, 64, dtype=torch.bfloat16) + config = QuantConfig(bits=8, group_size=32, symmetric=False, method="min_max") + q = quantize_weight(weight, config) + dequant = dequantize_weight(q, dtype=torch.bfloat16) + rel_error = ( + dequant.float() - weight.float() + ).abs().mean() / weight.float().abs().mean() + self.assertLess(rel_error.item(), 0.02) + + def test_int8_per_axis(self): + """Per-axis (group_size == K) used for embeddings.""" + weight = torch.randn(256, 64, dtype=torch.bfloat16) + config = QuantConfig(bits=8, group_size=64, symmetric=True, method="min_max") + q = quantize_weight(weight, config) + self.assertIsInstance(q, IntxUnpackedToInt8Tensor) + dequant = dequantize_weight(q, dtype=torch.bfloat16) + rel_error = ( + dequant.float() - weight.float() + ).abs().mean() / weight.float().abs().mean() + self.assertLess(rel_error.item(), 0.01) + + @parameterized.expand( + [ + ("unknown_method", QuantConfig(4, 32, False, "bogus"), "bogus"), + ("unsupported_bits", QuantConfig(3, 32, False, "min_max"), None), + ("hqq_8bit_asym", QuantConfig(8, 32, False, "hqq"), "symmetric"), + ] + ) + def test_invalid_config_raises(self, _name, config, expected_substr): + with self.assertRaises(ValueError) as ctx: + quantize_weight(torch.randn(32, 64), config) + if expected_substr: + self.assertIn(expected_substr, str(ctx.exception)) + + +class TestQuantizeWeightHQQ(unittest.TestCase): + def setUp(self): + if not torch.cuda.is_available(): + self.skipTest("CUDA required for HQQ") + + def test_quantize_dequantize_roundtrip(self): + torch.manual_seed(0) + weight = torch.randn(64, 128, dtype=torch.bfloat16, device="cuda") + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="hqq") + q = quantize_weight(weight, config) + dequant = dequantize_weight(q, dtype=torch.bfloat16).cpu() + rel_error = ( + dequant.float() - weight.cpu().float() + ).abs().mean() / weight.cpu().float().abs().mean() + self.assertLess(rel_error.item(), 0.15) + + def test_symmetric_scale_only(self): + config = QuantConfig(bits=4, group_size=32, symmetric=True, method="hqq") + q = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) + self.assertIsInstance(q, Int4Tensor) + + def test_int8_hqq_roundtrip(self): + torch.manual_seed(0) + weight = torch.randn(64, 128, dtype=torch.bfloat16) + config = QuantConfig(bits=8, group_size=32, symmetric=True, method="hqq") + q = quantize_weight(weight, config) + self.assertIsInstance(q, IntxUnpackedToInt8Tensor) + dequant = dequantize_weight(q, dtype=torch.bfloat16) + rel_error = ( + dequant.float() - weight.float() + ).abs().mean() / weight.float().abs().mean() + self.assertLess(rel_error.item(), 0.02) + + +class TestQuantizeModel(unittest.TestCase): + def test_applies_recipe(self): + model = nn.ModuleDict( + { + "embed": nn.Embedding(32, 16), + "proj": nn.Linear(16, 32, bias=False), + "norm": nn.LayerNorm(32), + } + ) + model.to(dtype=torch.bfloat16) + for p in model.parameters(): + p.data.normal_(0, 0.02) + + recipe = QuantRecipe( + rules=[ + QuantRule(r"embed\.weight", None), + QuantRule(r"norm\.weight", None), + QuantRule(r".*\.weight", QuantConfig(4, 16, False, "min_max")), + ] + ) + state = quantize_model(model, recipe) + + self.assertIsInstance(state["proj.weight"], Int4Tensor) + self.assertIs(type(state["embed.weight"]), torch.Tensor) + self.assertIs(type(state["norm.weight"]), torch.Tensor) + + def test_persistent_buffers_included(self): + model = nn.Module() + model.weight = nn.Parameter(torch.randn(16, 32, dtype=torch.bfloat16)) + model.register_buffer("scalar", torch.ones(1)) + model.register_buffer("temp", torch.zeros(4), persistent=False) + + recipe = QuantRecipe(rules=[QuantRule(r".*", None)]) + state = quantize_model(model, recipe) + + self.assertIn("scalar", state) + self.assertNotIn("temp", state) + + def test_unquantized_cast_to_dtype(self): + model = nn.ModuleDict({"proj": nn.Linear(16, 8, bias=False)}) + model.proj.weight.data = torch.randn(8, 16, dtype=torch.float32) + + recipe = QuantRecipe(rules=[QuantRule(r".*", None)]) + state = quantize_model(model, recipe, dtype=torch.float16) + + self.assertEqual(state["proj.weight"].dtype, torch.float16) + + def test_empty_model(self): + state = quantize_model(nn.Module(), QuantRecipe(rules=[])) + self.assertEqual(len(state), 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/models/gemma4_31b/quant/tests/test_recipe.py b/examples/models/gemma4_31b/quant/tests/test_recipe.py new file mode 100644 index 00000000000..199b13f3bd5 --- /dev/null +++ b/examples/models/gemma4_31b/quant/tests/test_recipe.py @@ -0,0 +1,167 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Unit tests for quant/recipe.py. CPU only — no CUDA, no model, no torchao.""" + +import unittest + +from executorch.examples.models.gemma4_31b.quant.recipe import ( + QuantConfig, + QuantRecipe, + QuantRule, +) + +from parameterized import parameterized + +_Q4 = QuantConfig(4, 32, True, "min_max") +_Q8 = QuantConfig(8, 32, True, "min_max") + + +class TestQuantRecipeGetConfig(unittest.TestCase): + """Tests for ``QuantRecipe.get_config`` — the core matching logic.""" + + @parameterized.expand( + [ + ( + "first_match_wins", + [QuantRule(r".*v_proj\.weight", _Q8), QuantRule(r".*\.weight", _Q4)], + "layers.0.self_attn.v_proj.weight", + 8, + ), + ( + "fallthrough_to_catchall", + [QuantRule(r".*v_proj\.weight", _Q8), QuantRule(r".*\.weight", _Q4)], + "layers.0.self_attn.q_proj.weight", + 4, + ), + ( + "none_rule_skips", + [ + QuantRule(r"embed_tokens\.weight", None), + QuantRule(r".*\.weight", _Q4), + ], + "embed_tokens.weight", + None, + ), + ( + "unmatched_returns_none", + [QuantRule(r"foo", _Q4)], + "bar.weight", + None, + ), + ( + "empty_recipe", + [], + "anything", + None, + ), + ( + "fullmatch_not_partial", + [QuantRule(r"foo", _Q4)], + "foo.bar", + None, + ), + ( + "fullmatch_exact", + [QuantRule(r"foo", _Q4)], + "foo", + 4, + ), + ] + ) + def test_get_config(self, _name, rules, fqn, expected_bits): + recipe = QuantRecipe(rules=rules) + config = recipe.get_config(fqn) + if expected_bits is None: + self.assertIsNone(config) + else: + self.assertEqual(config.bits, expected_bits) + + +class TestQuantRecipeLayerFilter(unittest.TestCase): + """Tests for the ``layers`` field on ``QuantRule``.""" + + def test_layer_filter(self): + edge = set(range(5)) | set(range(55, 60)) + recipe = QuantRecipe( + rules=[ + QuantRule(r".*norm\.weight", None), + QuantRule(r".*\.(v_proj|down_proj)\.weight", _Q8, layers=edge), + QuantRule(r".*\.weight", _Q4), + ] + ) + # Edge v_proj → 8-bit + self.assertEqual(recipe.get_config("layers.0.self_attn.v_proj.weight").bits, 8) + self.assertEqual(recipe.get_config("layers.58.self_attn.v_proj.weight").bits, 8) + # Middle v_proj → falls through → 4-bit + self.assertEqual(recipe.get_config("layers.30.self_attn.v_proj.weight").bits, 4) + # q_proj always 4-bit + self.assertEqual(recipe.get_config("layers.0.self_attn.q_proj.weight").bits, 4) + # Non-layer FQN skips layer-filtered rule, hits catch-all + self.assertEqual(recipe.get_config("lm_head.weight").bits, 4) + + def test_layer_filter_with_none_config(self): + """Skip rule scoped to specific layers.""" + recipe = QuantRecipe( + rules=[ + QuantRule(r".*\.weight", None, layers={0}), + QuantRule(r".*\.weight", _Q4), + ] + ) + self.assertIsNone(recipe.get_config("layers.0.mlp.gate_proj.weight")) + self.assertEqual(recipe.get_config("layers.1.mlp.gate_proj.weight").bits, 4) + + +class TestProductionRecipes(unittest.TestCase): + """Regression tests for the production recipes in quantize_and_save.py.""" + + def test_default_recipe(self): + from executorch.examples.models.gemma4_31b.quantize_and_save import ( + GEMMA4_31B_DEFAULT_RECIPE, + ) + + r = GEMMA4_31B_DEFAULT_RECIPE + self.assertIsNone(r.get_config("layers.0.input_layernorm.weight")) + self.assertIsNone(r.get_config("layers.5.self_attn.q_norm.weight")) + self.assertIsNone(r.get_config("norm.weight")) + embed_cfg = r.get_config("embed_tokens.weight") + self.assertEqual(embed_cfg.bits, 8) + self.assertEqual(embed_cfg.group_size, 5376) + for fqn in ( + "layers.0.self_attn.q_proj.weight", + "layers.0.self_attn.v_proj.weight", + "layers.0.mlp.gate_proj.weight", + "layers.0.mlp.down_proj.weight", + "lm_head.weight", + ): + cfg = r.get_config(fqn) + self.assertEqual(cfg.bits, 4, fqn) + self.assertEqual(cfg.method, "min_max", fqn) + + def test_sensitive_recipe(self): + from executorch.examples.models.gemma4_31b.quantize_and_save import ( + GEMMA4_31B_SENSITIVE_RECIPE, + ) + + r = GEMMA4_31B_SENSITIVE_RECIPE + self.assertIsNone(r.get_config("layers.0.input_layernorm.weight")) + embed_cfg = r.get_config("embed_tokens.weight") + self.assertEqual(embed_cfg.bits, 8) + self.assertEqual(embed_cfg.group_size, 5376) + # Edge v_proj/down_proj → int8 + self.assertEqual(r.get_config("layers.0.self_attn.v_proj.weight").bits, 8) + self.assertEqual(r.get_config("layers.0.mlp.down_proj.weight").bits, 8) + self.assertEqual(r.get_config("layers.58.self_attn.v_proj.weight").bits, 8) + # Middle v_proj/down_proj → int4 + self.assertEqual(r.get_config("layers.30.self_attn.v_proj.weight").bits, 4) + self.assertEqual(r.get_config("layers.30.mlp.down_proj.weight").bits, 4) + # q_proj always int4 + self.assertEqual(r.get_config("layers.0.self_attn.q_proj.weight").bits, 4) + self.assertEqual(r.get_config("layers.30.self_attn.q_proj.weight").bits, 4) + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/models/gemma4_31b/quant/tests/test_safetensors_roundtrip.py b/examples/models/gemma4_31b/quant/tests/test_safetensors_roundtrip.py new file mode 100644 index 00000000000..7c4fa8decfa --- /dev/null +++ b/examples/models/gemma4_31b/quant/tests/test_safetensors_roundtrip.py @@ -0,0 +1,143 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Smoke tests: torchao subclasses survive safetensors roundtrip.""" + +import os +import tempfile +import unittest + +import torch + +from safetensors import safe_open +from safetensors.torch import save_file +from torchao.prototype.safetensors.safetensors_support import ( + flatten_tensor_state_dict, + unflatten_tensor_state_dict, +) + + +def save(state_dict, path): + tensors_data, metadata = flatten_tensor_state_dict(state_dict) + save_file(tensors_data, path, metadata=metadata) + + +def load(path): + with safe_open(path, framework="pt", device="cpu") as f: + metadata = f.metadata() + tensors = {k: f.get_tensor(k) for k in f.keys()} + result, _ = unflatten_tensor_state_dict(tensors, metadata) + return result + + +from torchao.quantization import IntxUnpackedToInt8Tensor +from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + + +def _make_int4(shape, group_size=32): + """Build a random Int4Tensor.""" + N, K = shape + packed = torch.randint(0, 255, (N, K // 2), dtype=torch.uint8) + scale = torch.randn(K // group_size, N, dtype=torch.bfloat16) + zp = torch.zeros(K // group_size, N, dtype=torch.bfloat16) + return Int4Tensor( + qdata=packed, + scale=scale, + zero_point=zp, + block_size=[1, group_size], + shape=torch.Size([N, K]), + ) + + +def _make_int8(shape, group_size=32): + """Build a random IntxUnpackedToInt8Tensor.""" + N, K = shape + return IntxUnpackedToInt8Tensor( + qdata=torch.randint(-128, 127, (N, K), dtype=torch.int8), + scale=torch.randn(N, K // group_size, dtype=torch.bfloat16), + zero_point=torch.zeros(N, K // group_size, dtype=torch.int8), + target_dtype=torch.int8, + block_size=(1, group_size), + dtype=torch.bfloat16, + activation_quantization=None, + ) + + +class TestSaveLoad(unittest.TestCase): + def test_int4_roundtrip(self): + """Int4Tensor survives save/load.""" + t = _make_int4((64, 128)) + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save({"layer.weight": t}, path) + loaded = load(path) + + self.assertIn("layer.weight", loaded) + self.assertIsInstance(loaded["layer.weight"], Int4Tensor) + self.assertTrue(torch.equal(t.qdata, loaded["layer.weight"].qdata)) + self.assertTrue(torch.equal(t.scale, loaded["layer.weight"].scale)) + + def test_int8_roundtrip(self): + """IntxUnpackedToInt8Tensor survives save/load.""" + t = _make_int8((64, 128)) + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save({"layer.weight": t}, path) + loaded = load(path) + + self.assertIn("layer.weight", loaded) + self.assertIsInstance(loaded["layer.weight"], IntxUnpackedToInt8Tensor) + self.assertTrue(torch.equal(t.qdata, loaded["layer.weight"].qdata)) + + def test_mixed_state_dict(self): + """Mixed Int4 + Int8 + plain tensor roundtrip.""" + state = { + "linear.weight": _make_int4((64, 128)), + "embed.weight": _make_int8((100, 64)), + "norm.weight": torch.randn(64, dtype=torch.bfloat16), + } + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save(state, path) + loaded = load(path) + + self.assertEqual(set(state.keys()), set(loaded.keys())) + self.assertIsInstance(loaded["linear.weight"], Int4Tensor) + self.assertIsInstance(loaded["embed.weight"], IntxUnpackedToInt8Tensor) + self.assertIsInstance(loaded["norm.weight"], torch.Tensor) + self.assertTrue(torch.equal(state["norm.weight"], loaded["norm.weight"])) + + def test_plain_tensor_only(self): + """State dict with only plain tensors roundtrips.""" + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save({"model.norm.weight": torch.randn(64, dtype=torch.bfloat16)}, path) + loaded = load(path) + self.assertIn("model.norm.weight", loaded) + + def test_3d_int4(self): + """3D Int4Tensor (MoE expert weights) roundtrips.""" + # 3D: (num_experts, N, K//2) packed + N, K, gs = 32, 64, 32 + packed = torch.randint(0, 255, (4, N, K // 2), dtype=torch.uint8) + scale = torch.randn(4, K // gs, N, dtype=torch.bfloat16) + zp = torch.zeros(4, K // gs, N, dtype=torch.bfloat16) + t = Int4Tensor( + qdata=packed, + scale=scale, + zero_point=zp, + block_size=[1, 1, gs], + shape=torch.Size([4, N, K]), + ) + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.safetensors") + save({"experts.w1": t}, path) + loaded = load(path) + self.assertTrue(torch.equal(t.qdata, loaded["experts.w1"].qdata)) + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/models/gemma4_31b/quantize_and_save.py b/examples/models/gemma4_31b/quantize_and_save.py new file mode 100644 index 00000000000..e654e12f637 --- /dev/null +++ b/examples/models/gemma4_31b/quantize_and_save.py @@ -0,0 +1,143 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Quantize Gemma 4 31B-IT and save as a quantized checkpoint. + +Produces a safetensors file containing torchao tensor subclasses +(``Int4Tensor``, ``IntxUnpackedToInt8Tensor``) that can be loaded and +packed for any backend via ``load_and_pack_for_cuda`` or ``pack_model``. + +The default recipe runs on CPU. The sensitive recipe requires CUDA for +HQQ asymmetric quantization. + +Usage: + python quantize_and_save.py \\ + --model-dir ~/local/scripts/models/gemma-4-31B-it \\ + --output ./gemma4_31b_int4 \\ + --quant-recipe default +""" + +import argparse +import os +import shutil + +import torch.nn as nn + +from executorch.examples.models.gemma4_31b.model import Gemma4_31B +from executorch.examples.models.gemma4_31b.quant import ( + QuantConfig, + quantize_model, + QuantRecipe, + QuantRule, +) + +# --------------------------------------------------------------------------- +# Production recipes for Gemma 4 31B. +# +# Layer sensitivity: +# - v_proj and down_proj are the most sensitive to quantization error +# (first/last quarter of layers especially so). +# - q_proj, k_proj, o_proj, gate_proj, up_proj tolerate 4-bit well. +# - embed_tokens is an index lookup — INT8 per-axis is nearly lossless. +# - Norms and layer_scalar are tiny and must stay unquantized. + +_INT4 = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") +_INT4_HQQ = QuantConfig(bits=4, group_size=32, symmetric=False, method="hqq") +_INT8 = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") +_INT8_PER_AXIS = QuantConfig( # group_size = hidden_size (5376) for Gemma 4 31B + bits=8, group_size=5376, symmetric=True, method="min_max" +) +_EDGE_LAYERS = set(range(15)) | set(range(45, 60)) + +GEMMA4_31B_DEFAULT_RECIPE = QuantRecipe( + rules=[ + QuantRule(r"embed_tokens\.weight", _INT8_PER_AXIS), + QuantRule(r".*norm\.weight", None), + QuantRule(r".*\.weight", _INT4), + ] +) + +GEMMA4_31B_SENSITIVE_RECIPE = QuantRecipe( + rules=[ + QuantRule(r"embed_tokens\.weight", _INT8_PER_AXIS), + QuantRule(r".*norm\.weight", None), + QuantRule(r".*\.(v_proj|down_proj)\.weight", _INT8, layers=_EDGE_LAYERS), + QuantRule(r".*\.weight", _INT4_HQQ), + ] +) + +_RECIPES = { + "default": GEMMA4_31B_DEFAULT_RECIPE, + "sensitive": GEMMA4_31B_SENSITIVE_RECIPE, +} + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Quantize Gemma 4 31B-IT and save as a quantized checkpoint." + ) + parser.add_argument( + "--model-dir", + required=True, + help="HuggingFace Gemma 4 31B-IT model dir.", + ) + parser.add_argument( + "--output", + default="./gemma4_31b_int4", + help="Output directory.", + ) + parser.add_argument( + "--quant-recipe", + default="default", + choices=list(_RECIPES), + help="'default': int4 min_max linears + int8 per-axis embedding. " + "'sensitive': int8 for edge-layer v_proj/down_proj, int4 hqq elsewhere.", + ) + parser.add_argument( + "--backend", + default="cuda", + choices=["cuda"], + help="Target backend (the quantized checkpoint is backend-agnostic, " + "but this may influence default recipe selection in the future).", + ) + args = parser.parse_args() + + recipe = _RECIPES[args.quant_recipe] + + print("Loading checkpoint (lazy, shard-by-shard)...") + model, _ = Gemma4_31B.from_hf_checkpoint(args.model_dir) + + if model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr(): + print("Untying embed_tokens / lm_head...") + model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + + print(f"Quantizing with recipe '{args.quant_recipe}'...") + state_dict = quantize_model(model, recipe, verbose=True) + + os.makedirs(args.output, exist_ok=True) + safetensors_path = os.path.join(args.output, "model.safetensors") + print("Saving quantized checkpoint...") + from safetensors.torch import save_file + from torchao.prototype.safetensors.safetensors_support import ( + flatten_tensor_state_dict, + ) + + tensors_data, metadata = flatten_tensor_state_dict(state_dict) + save_file(tensors_data, safetensors_path, metadata=metadata) + n_tensors = len(state_dict) + + for filename in ("config.json", "tokenizer.json", "tokenizer_config.json"): + src = os.path.join(args.model_dir, filename) + if os.path.exists(src): + shutil.copy2(src, os.path.join(args.output, filename)) + + size_mb = os.path.getsize(safetensors_path) / (1024 * 1024) + print(f"Saved {n_tensors} tensors ({size_mb:.1f} MB) to {args.output}/") + print(f"Done. Use with: python export.py --prequantized {args.output}") + + +if __name__ == "__main__": + main() diff --git a/examples/models/gemma4_31b/sampler.py b/examples/models/gemma4_31b/sampler.py new file mode 100644 index 00000000000..45e4e17887a --- /dev/null +++ b/examples/models/gemma4_31b/sampler.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""GPU-side Gumbel-max sampler. + +Mirrors ``examples/models/qwen3_5_moe/sampler.py``: a single-output sampler +that lets one exported program be re-driven with different temperatures +without re-export. ``temperature=None`` is a no-op (returns logits). +""" + +from typing import Optional + +import torch + + +def sample( + logits: torch.Tensor, + temperature: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Draw a single token per batch row using the Gumbel-max trick. + + Args: + logits: ``[B, V]`` float32 logits (already soft-capped if applicable). + temperature: 0-D or 1-D float tensor; clamped to >= 1e-6 so a 0 + temperature still works ("near-greedy"). When ``None`` the call + short-circuits and returns ``logits`` unchanged. + + Returns: + ``[B, 1]`` float32 token IDs (``argmax(logits/T + gumbel_noise)``), + or the unmodified logits when ``temperature`` is ``None``. + """ + if temperature is None: + return logits + + logits = logits / temperature.clamp(min=1e-6) + noise = torch.rand_like(logits) + gumbel = -torch.log(-torch.log(noise + 1e-20) + 1e-20) + return (logits + gumbel).argmax(dim=-1, keepdim=True).float() diff --git a/examples/models/gemma4_31b/tests/test_cuda_pipeline.py b/examples/models/gemma4_31b/tests/test_cuda_pipeline.py new file mode 100644 index 00000000000..0ff28aac415 --- /dev/null +++ b/examples/models/gemma4_31b/tests/test_cuda_pipeline.py @@ -0,0 +1,232 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""CUDA-specific integration tests for the Gemma 4 31B-IT pipeline. + +Tests pack → inference → export on a tiny model using the CUDA backend. +Backend-agnostic tests (quantize, save, load) live in ``test_pipeline.py``. + +Requires CUDA. + +Usage: + python -m pytest examples/models/gemma4_31b/tests/test_cuda_pipeline.py -v +""" + +import os +import tempfile +import unittest + +# Register Int4Tensor dispatch before any model usage +import executorch.backends.cuda.int4_dispatch # noqa: F401 + +import torch +import torch.nn as nn +from executorch.examples.models.gemma4_31b.export import ( + export_and_lower, + load_prequantized_model, +) +from executorch.examples.models.gemma4_31b.inference import _move_to_cuda, generate +from executorch.examples.models.gemma4_31b.model import Gemma4_31B +from executorch.examples.models.gemma4_31b.quant import ( + DEFAULT_CUDA_PACKERS, + pack_model, + quantize_model, +) +from executorch.examples.models.gemma4_31b.tests.test_pipeline import ( + build_hf_checkpoint, + DEFAULT_RECIPE, + MockTokenizer, + save_checkpoint, + TINY_CONFIG, +) + + +def _require_cuda(testcase: unittest.TestCase) -> None: + if not torch.cuda.is_available(): + testcase.skipTest("CUDA required") + + +class TestCudaInference(unittest.TestCase): + def setUp(self): + _require_cuda(self) + + def test_generate(self): + """save → load → pack → generate.""" + with tempfile.TemporaryDirectory() as tmpdir: + save_checkpoint(tmpdir) + model, config = load_prequantized_model( + tmpdir, max_seq_len=TINY_CONFIG.max_seq_len + ) + _move_to_cuda(model, config) + model.eval() + tokenizer = MockTokenizer(TINY_CONFIG.vocab_size) + + torch.manual_seed(0) + out = generate(model, tokenizer, prompt="hi", max_new_tokens=5, temperature=1.0) + self.assertIsInstance(out, str) + ids_part = out[len("" + + +def config_dict() -> dict: + cfg = TINY_CONFIG + return { + "vocab_size": cfg.vocab_size, + "hidden_size": cfg.hidden_size, + "intermediate_size": cfg.intermediate_size, + "num_hidden_layers": cfg.num_hidden_layers, + "num_attention_heads": cfg.num_attention_heads, + "num_key_value_heads": cfg.num_key_value_heads, + "head_dim": cfg.head_dim, + "num_global_key_value_heads": cfg.num_global_key_value_heads, + "global_head_dim": cfg.global_head_dim, + "attention_k_eq_v": cfg.attention_k_eq_v, + "rope_parameters": { + "sliding_attention": {"rope_theta": cfg.sliding_rope_theta}, + "full_attention": { + "rope_theta": cfg.full_rope_theta, + "partial_rotary_factor": cfg.full_partial_rotary_factor, + }, + }, + "rms_norm_eps": cfg.rms_norm_eps, + "hidden_activation": cfg.hidden_activation, + "final_logit_softcapping": cfg.final_logit_softcapping, + "tie_word_embeddings": cfg.tie_word_embeddings, + "sliding_window": cfg.sliding_window, + "layer_types": cfg.layer_types, + } + + +def build_random_tiny_model() -> Gemma4_31B: + torch.manual_seed(42) + model = Gemma4_31B(TINY_CONFIG) + model.to(dtype=torch.bfloat16) + for p in model.parameters(): + if p.device.type != "meta": + p.data.normal_(0, 0.02) + model.eval() + return model + + +def save_checkpoint(output_dir: str): + model = build_random_tiny_model() + model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + state_dict = quantize_model(model, DEFAULT_RECIPE) + os.makedirs(output_dir, exist_ok=True) + td, md = flatten_tensor_state_dict(state_dict) + save_file(td, os.path.join(output_dir, "model.safetensors"), metadata=md) + with open(os.path.join(output_dir, "config.json"), "w") as f: + json.dump(config_dict(), f) + + +def build_hf_checkpoint(output_dir: str) -> None: + model = build_random_tiny_model() + sd = model.state_dict() + sd.pop("lm_head.weight", None) + hf_sd = {f"model.language_model.{k}": v.contiguous() for k, v in sd.items()} + save_file(hf_sd, os.path.join(output_dir, "model.safetensors")) + with open(os.path.join(output_dir, "config.json"), "w") as f: + json.dump(config_dict(), f) + + +# --------------------------------------------------------------------------- +# Tests (CPU only, no backend dependency) + + +class TestQuantizeSaveLoadRoundtrip(unittest.TestCase): + def test_roundtrip_preserves_weights(self): + """quantize → save → load recovers all weights.""" + from torchao.quantization import IntxUnpackedToInt8Tensor + from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + + model = build_random_tiny_model() + model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + state_dict = quantize_model(model, DEFAULT_RECIPE) + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "model.safetensors") + td, md = flatten_tensor_state_dict(state_dict) + save_file(td, path, metadata=md) + with safe_open(path, framework="pt", device="cpu") as f: + loaded_meta = f.metadata() + loaded_tensors = {k: f.get_tensor(k) for k in f.keys()} + loaded, _ = unflatten_tensor_state_dict(loaded_tensors, loaded_meta) + + self.assertEqual(set(state_dict.keys()), set(loaded.keys())) + for fqn in state_dict: + orig = state_dict[fqn] + got = loaded[fqn] + self.assertEqual(type(orig).__name__, type(got).__name__) + if isinstance(orig, Int4Tensor): + self.assertTrue(torch.equal(orig.qdata, got.qdata)) + self.assertTrue(torch.equal(orig.scale, got.scale)) + elif isinstance(orig, IntxUnpackedToInt8Tensor): + self.assertTrue(torch.equal(orig.qdata, got.qdata)) + self.assertTrue(torch.equal(orig.scale, got.scale)) + elif isinstance(orig, torch.Tensor): + self.assertTrue(torch.equal(orig, got)) + + def test_embedding_quantized_as_int8(self): + """embed_tokens is quantized to INT8 (IntxUnpackedToInt8Tensor).""" + from torchao.quantization import IntxUnpackedToInt8Tensor + + model = build_random_tiny_model() + model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + state_dict = quantize_model(model, DEFAULT_RECIPE) + + self.assertIn("embed_tokens.weight", state_dict) + self.assertIsInstance( + state_dict["embed_tokens.weight"], IntxUnpackedToInt8Tensor + ) + + +class TestRingKVCache(unittest.TestCase): + """Unit tests for the ring-buffer KV cache (CPU, no model needed).""" + + def _make_cache(self, window=4, heads=2, head_dim=8): + return RingKVCache( + max_batch_size=1, window_size=window, num_kv_heads=heads, head_dim=head_dim + ) + + def test_sequential_write_read(self): + """Writing positions 0..buf_size-1 fills every slot exactly once.""" + cache = self._make_cache(window=4) + buf_size = cache.buf_size # 8 + for i in range(buf_size): + pos = torch.tensor([i], dtype=torch.long) + k = torch.full((1, 2, 1, 8), float(i)) + v = torch.full((1, 2, 1, 8), float(i + 100)) + k_out, v_out = cache.update(pos, k, v) + for i in range(buf_size): + slot = i % buf_size + self.assertEqual(k_out[0, 0, slot, 0].item(), float(i)) + self.assertEqual(v_out[0, 0, slot, 0].item(), float(i + 100)) + + def test_wraparound_overwrites_oldest(self): + """Position buf_size overwrites slot 0 (the oldest entry).""" + cache = self._make_cache(window=4) + buf_size = cache.buf_size # 8 + for i in range(buf_size + 1): + pos = torch.tensor([i], dtype=torch.long) + k = torch.full((1, 2, 1, 8), float(i)) + v = torch.full((1, 2, 1, 8), float(i)) + k_out, _ = cache.update(pos, k, v) + # Slot 0 should now contain position buf_size (not 0) + self.assertEqual(k_out[0, 0, 0, 0].item(), float(buf_size)) + # Slot 1 should still contain position 1 + self.assertEqual(k_out[0, 0, 1, 0].item(), 1.0) + + def test_multi_token_prefill(self): + """Writing multiple positions in one call places them correctly.""" + cache = self._make_cache(window=4) + pos = torch.arange(4, dtype=torch.long) + k = torch.arange(4).float().view(1, 1, 4, 1).expand(1, 2, 4, 8) + v = torch.zeros(1, 2, 4, 8) + k_out, _ = cache.update(pos, k, v) + for i in range(4): + self.assertEqual(k_out[0, 0, i, 0].item(), float(i)) + + def test_assert_on_oversized_prefill(self): + """seq_len > buf_size raises AssertionError.""" + cache = self._make_cache(window=4) + buf_size = cache.buf_size + pos = torch.arange(buf_size + 1, dtype=torch.long) + k = torch.zeros(1, 2, buf_size + 1, 8) + v = torch.zeros(1, 2, buf_size + 1, 8) + with self.assertRaises(AssertionError): + cache.update(pos, k, v) + + +class TestGgufKeyMapping(unittest.TestCase): + """Unit tests for gguf_loader.gguf_to_model_key (CPU, no GGUF file needed).""" + + def test_attention_keys(self): + from executorch.examples.models.gemma4_31b.gguf_loader import gguf_to_model_key + + self.assertEqual( + gguf_to_model_key("blk.0.attn_q.weight"), + "layers.0.self_attn.q_proj.weight", + ) + self.assertEqual( + gguf_to_model_key("blk.59.attn_output.weight"), + "layers.59.self_attn.o_proj.weight", + ) + + def test_mlp_keys(self): + from executorch.examples.models.gemma4_31b.gguf_loader import gguf_to_model_key + + self.assertEqual( + gguf_to_model_key("blk.5.ffn_gate.weight"), + "layers.5.mlp.gate_proj.weight", + ) + + def test_global_keys(self): + from executorch.examples.models.gemma4_31b.gguf_loader import gguf_to_model_key + + self.assertEqual(gguf_to_model_key("token_embd.weight"), "embed_tokens.weight") + self.assertEqual(gguf_to_model_key("output_norm.weight"), "norm.weight") + + def test_unknown_key_returns_none(self): + from executorch.examples.models.gemma4_31b.gguf_loader import gguf_to_model_key + + self.assertIsNone(gguf_to_model_key("blk.0.some_unknown.weight")) + + def test_ignored_key_returns_none(self): + from executorch.examples.models.gemma4_31b.gguf_loader import gguf_to_model_key + + self.assertIsNone(gguf_to_model_key("rope_freqs.weight")) + + +if __name__ == "__main__": + unittest.main()