
Add Gemma 4 31B-IT model, export, and quantization framework for ExecuTorch #19213

Merged
mergennachin merged 14 commits into main from gemma4-31b-quant-framework on May 12, 2026

Conversation

@mergennachin
Contributor

@mergennachin mergennachin commented Apr 29, 2026

Text-only export of Gemma 4 31B-IT to ExecuTorch with INT4/INT8 weight quantization. Quantized weights use torchao's native tensor subclasses (Int4Tensor, IntxUnpackedToInt8Tensor) for serialization, aligning with the torchao ecosystem.

quant/ package separates quantization into independent modules:

  • recipe.py: declarative QuantRecipe with regex FQN matching and per-layer overrides
  • quantize.py: quantize_weight / dequantize_weight / quantize_model — returns torchao subclasses directly. 8-bit fully delegates to IntxUnpackedToInt8Tensor.from_hp (min_max and HQQ). 4-bit uses torchao primitives + manual Int4Tensor construction (pending mslk availability for from_hp)
  • pack.py: pack_model (bulk, groups by parent for MoE) and pack_one (streaming). Dispatches via isinstance(_, TorchAOBaseTensor)
  • pack_cuda.py: converts Int4Tensor to IntxUnpackedToInt8Tensor (int4 values unpacked to int8) and passes INT8 IntxUnpackedToInt8Tensor through unchanged. No CUDA required for packing — the CUDA-specific tinygemm conversion is a source transform applied at export time
  • gguf.py: unpack Q4_K/Q6_K GGUF blocks directly to Int4Tensor/IntxUnpackedToInt8Tensor, with streaming iterator

Serialization uses torchao's safetensors integration (torchao.prototype.safetensors) — no custom format. Checkpoints are compatible with torchao's save_pretrained/load_pretrained and can be loaded by vLLM.

This framework is designed to be promoted and reused for Qwen 3.5 MoE and other models — adding a new model requires only a QuantRecipe and optionally a custom packer.

Quantization recipes: "default" (INT4 min_max linears + INT8 per-axis embedding) and "sensitive" (INT8 for edge-layer v_proj/down_proj, INT4 HQQ asymmetric elsewhere).
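
For a concrete feel of the recipe layer, here is a minimal, self-contained sketch of regex-FQN matching with a per-layer override. The class, rule, and layer names are illustrative, not the actual recipe.py API:

import re
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Rule:
    pattern: str       # regex matched against the parameter FQN
    bits: int          # storage width: 4 or 8
    group_size: int
    method: str = "min_max"

@dataclass
class Recipe:
    rules: List[Rule]

    def match(self, fqn: str) -> Optional[Rule]:
        # First matching rule wins; unmatched weights stay in bf16.
        for rule in self.rules:
            if re.search(rule.pattern, fqn):
                return rule
        return None

# Toy analogue of the "sensitive" recipe: INT8 for edge-layer v_proj/down_proj,
# INT4 HQQ elsewhere (the edge-layer indices below are purely illustrative).
sensitive = Recipe(rules=[
    Rule(r"layers\.(0|1)\.(self_attn\.v_proj|mlp\.down_proj)", bits=8, group_size=128),
    Rule(r"(q|k|v|o)_proj|gate_proj|up_proj|down_proj", bits=4, group_size=128, method="hqq"),
])

print(sensitive.match("model.layers.0.self_attn.v_proj.weight").bits)  # 8
print(sensitive.match("model.layers.30.mlp.down_proj.weight").bits)    # 4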

Dual-path INT4 linear dispatch: IntxUnpackedToInt8Tensor's F.linear dispatch dequantizes to bf16 and calls cuBLAS, optimal for prefill (12x faster than tinygemm at T=2048). For decode, a model-agnostic source transform (backends/cuda/transforms/int4_linear_dispatch.py) converts to Int4TilePackedTo4dTensor (tinygemm), optimal for M=1. Export flow: prefill first (dequant+cuBLAS), then tinygemm transform, then decode export. inference.py applies the tinygemm transform for fast eager decode.

Split-K flash-decoding: ReplaceEdgeOpWithTritonOpPass in the CUDA backend selects triton::sdpa_decode_splitk for SDPA nodes where L_q=1 and L_kv exceeds 2048. At 128K context, full-attention decode SDPA improves from 15.7ms/layer to 0.7ms/layer (22x). Sliding-window layers (ring buffer <= 2048) use standard triton::sdpa. No model code changes — the pass inspects Q/K shapes in the exported graph automatically.

GGUF support: inference.py --gguf and export.py --gguf load community-quantized GGUF files directly. Tied embed/lm_head is untied — embedding dequantized to bf16 for gather, lm_head keeps INT4 for matmul.

Ring-buffer KV cache: Sliding window layers use RingKVCache (2x window) instead of flat max_seq_len buffers. The C++ runner chunks long prompts automatically via get_max_prefill_chunk metadata. Chunked prefill produces identical logits to sequential (verified by test).
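
For intuition, a minimal sketch of the ring-buffer write pattern follows; names and shapes are illustrative and this is not the actual RingKVCache API (the real cache also builds the sliding-window masks and is written to be export-friendly):

import torch

class RingKV:
    """Toy ring-buffer KV cache sized to 2x the sliding window."""
    def __init__(self, buf_size: int, n_heads: int, head_dim: int):
        self.buf_size = buf_size
        self.k = torch.zeros(1, n_heads, buf_size, head_dim)
        self.v = torch.zeros(1, n_heads, buf_size, head_dim)

    def update(self, input_pos: torch.Tensor, k_new: torch.Tensor, v_new: torch.Tensor):
        # input_pos holds absolute token positions; writes wrap modulo buf_size.
        # A prefill chunk longer than buf_size would make slots collide, hence the assert.
        assert k_new.shape[2] <= self.buf_size, "prefill chunk exceeds ring buffer"
        slots = input_pos % self.buf_size
        self.k.index_copy_(2, slots, k_new)
        self.v.index_copy_(2, slots, v_new)
        return self.k, self.v

cache = RingKV(buf_size=8, n_heads=1, head_dim=4)
pos = torch.arange(10, 13)  # absolute positions 10..12 land in slots 2..4
cache.update(pos, torch.randn(1, 1, 3, 4), torch.randn(1, 1, 3, 4))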

Includes: C++ runner with BOS/EOS handling, chunked prefill, and #ifdef guards for non-CUDA builds; eager inference with torch.compile; unit and integration tests across quant/tests/, tests/, and backends/cuda/tests/.

┌──────────────────┬───────────────────┐
│ Metric           │ Value             │
├──────────────────┼───────────────────┤
│ Prompt tokens    │ 513               │
│ Generated tokens │ 128               │
│ Prefill          │ 766 tok/s (670ms) │
│ Decode           │ 21.5 tok/s        │
│ TTFT             │ 89ms              │
│ GPU peak         │ 25.1GB            │
│ Model load       │ 28.8s             │
└──────────────────┴───────────────────┘

Copilot AI review requested due to automatic review settings April 29, 2026 21:06
@pytorch-bot

pytorch-bot Bot commented Apr 29, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19213

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 154 Pending, 1 Unclassified Failure

As of commit e7375a1 with merge base f1062a7:

UNCLASSIFIED FAILURE - DrCI could not classify the following job because the workflow did not run on the merge base. The failure may be pre-existing on trunk or introduced by this PR:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 29, 2026
@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Contributor

Copilot AI left a comment


Pull request overview

Adds a new Gemma 4 31B-IT example pipeline for ExecuTorch (CUDA backend), including a packing-agnostic quantization format + recipes, CUDA packers, export/inference scripts, a C++ runner, and CI coverage.

Changes:

  • Introduces examples/models/gemma4_31b/quant/ with recipe → quantize → serialize → pack flow plus unit tests.
  • Adds Gemma 4 31B model implementation with hybrid attention and a sliding-window KV cache, plus export + eager inference entrypoints.
  • Adds CUDA runner build targets and runs Gemma 4 31B tests in the CUDA GitHub Actions workflow.

Reviewed changes

Copilot reviewed 28 out of 28 changed files in this pull request and generated 5 comments.

Show a summary per file
File Description
examples/models/gemma4_31b/test_pipeline.py CPU-only integration tests for quantize/save/load roundtrip and tiny checkpoint fixtures.
examples/models/gemma4_31b/test_cuda_pipeline.py CUDA integration tests for pack/infer/export on a tiny model.
examples/models/gemma4_31b/sampler.py GPU-side Gumbel-max sampler used by the exported model.
examples/models/gemma4_31b/quantize_and_save.py CLI to quantize HF checkpoints and write packing-agnostic safetensors bundles + production recipes.
examples/models/gemma4_31b/quant/test_serialize.py Unit tests for nibble packing and safetensors serialization format.
examples/models/gemma4_31b/quant/test_recipe.py Unit tests for regex/layer-filter recipe matching + production recipe regression tests.
examples/models/gemma4_31b/quant/test_quantize.py Unit tests for quantize_weight and quantize_model APIs (CPU + CUDA/HQQ paths).
examples/models/gemma4_31b/quant/test_pack_cuda.py CUDA unit tests for int4/int8 packers and load-and-pack dispatcher behavior.
examples/models/gemma4_31b/quant/serialize.py Canonical quantized weight format + safetensors save/load with versioned metadata.
examples/models/gemma4_31b/quant/recipe.py Declarative quantization recipe/rule objects with regex FQN matching and optional layer filters.
examples/models/gemma4_31b/quant/quantize.py Implements min-max and HQQ quantization into canonical (packing-free) representations.
examples/models/gemma4_31b/quant/pack_cuda.py CUDA-specific packers converting canonical weights into torchao runtime tensor subclasses.
examples/models/gemma4_31b/quant/pack.py Backend-agnostic pack dispatcher that assigns weights/buffers and calls module-type packers.
examples/models/gemma4_31b/quant/__init__.py Public API re-exports for quant/ package.
examples/models/gemma4_31b/quant/README.md Documentation of the quant framework, data flow, and backend extension points.
examples/models/gemma4_31b/model.py Gemma 4 31B model definition, HF checkpoint loader, ring KV cache for sliding layers, runtime buffer materialization.
examples/models/gemma4_31b/model.md Architecture/design notes for model + quant pipeline.
examples/models/gemma4_31b/main.cpp ExecuTorch CUDA runner driving exported prefill/decode and HF tokenizer decoding.
examples/models/gemma4_31b/inference.py Eager CUDA inference script loading prequantized weights, packing, and generating text.
examples/models/gemma4_31b/export.py Export + lowering pipeline (decode + prefill methods) targeting the CUDA backend.
examples/models/gemma4_31b/__init__.py Package marker for the new model example.
examples/models/gemma4_31b/README.md User-facing instructions for quantize/export/inference/build/run workflows.
examples/models/gemma4_31b/CMakePresets.json CMake preset for building the Gemma 4 31B CUDA runner.
examples/models/gemma4_31b/CMakeLists.txt CMake build for the Gemma 4 31B runner, linking ExecuTorch + CUDA backend + tokenizer.
examples/models/gemma4/text_decoder/gemma4_norm.py Replaces transformers RMSNorm dependency with a self-contained implementation.
examples/models/gemma4/text_decoder/__init__.py Exposes attention/norm/MLP primitives used by gemma4_31b for shared numerically-sensitive ops.
Makefile Adds gemma4_31b-cuda build target.
.github/workflows/cuda.yml Adds Gemma 4 31B quant + pipeline tests to the CUDA CI job.


Comment thread examples/models/gemma4_31b/model.py
Comment thread examples/models/gemma4_31b/model.py
Comment thread examples/models/gemma4_31b/quant/pack.py Outdated
Comment thread examples/models/gemma4_31b/test_pipeline.py Outdated
Comment thread examples/models/gemma4_31b/quant/recipe.py Outdated
Copilot AI review requested due to automatic review settings April 30, 2026 14:01
Contributor

Copilot AI left a comment


Pull request overview

Adds a full Gemma 4 31B-IT example to ExecuTorch, including a new packing-agnostic quantization framework, CUDA packing/export/inference tooling, GGUF import support, a C++ CUDA runner, and a comprehensive test suite integrated into CI.

Changes:

  • Introduce examples/models/gemma4_31b/quant/ canonical quantization framework (recipe → quantize → serialize → pack) with CUDA packers and safetensors persistence.
  • Add Gemma 4 31B-IT model implementation with ring-buffer KV cache for sliding-window layers, plus export/eager inference/runner scripts.
  • Add unit + integration tests (CPU and CUDA) and run them in the CUDA CI workflow.

Reviewed changes

Copilot reviewed 31 out of 31 changed files in this pull request and generated 6 comments.

Show a summary per file
File Description
examples/models/gemma4_31b/tests/test_pipeline.py CPU-only integration tests for quantize→save→load roundtrip and fixtures for CUDA tests.
examples/models/gemma4_31b/tests/test_cuda_pipeline.py CUDA integration tests for packing, generation, chunked prefill, and export.
examples/models/gemma4_31b/sampler.py GPU-side Gumbel-max sampler used by the exported model for on-device sampling.
examples/models/gemma4_31b/quantize_and_save.py CLI to quantize HF checkpoints and persist packing-agnostic safetensors checkpoints.
examples/models/gemma4_31b/quant/tests/test_serialize.py Unit tests for canonical serialize/deserialize and nibble pack/unpack.
examples/models/gemma4_31b/quant/tests/test_recipe.py Unit tests for regex/layer-filter recipe matching + production recipe regression tests.
examples/models/gemma4_31b/quant/tests/test_quantize.py Unit tests for canonical quantize/dequantize APIs and model-walking quantization.
examples/models/gemma4_31b/quant/tests/test_pack_cuda.py CUDA unit tests for packing canonical weights into CUDA runtime formats and dispatch.
examples/models/gemma4_31b/quant/tests/test_gguf.py Unit tests validating GGUF Q4_K/Q6_K unpacking against reference formulas.
examples/models/gemma4_31b/quant/serialize.py Canonical CQW representation + safetensors format, nibble packing, save/load.
examples/models/gemma4_31b/quant/recipe.py Declarative quantization recipe/rule/config structures and matching logic.
examples/models/gemma4_31b/quant/quantize.py Canonical quantization implementations (min_max, HQQ) + per-model quantization walk.
examples/models/gemma4_31b/quant/pack_cuda.py CUDA packers from canonical weights to tinygemm/int8 subclass runtime formats.
examples/models/gemma4_31b/quant/pack.py Backend-agnostic pack dispatcher grouping weights per module and applying packers.
examples/models/gemma4_31b/quant/gguf.py GGUF tensor unpacker/streamer to canonical CQW or dense tensors.
examples/models/gemma4_31b/quant/__init__.py Public API re-exports for quant/ package.
examples/models/gemma4_31b/quant/README.md Documentation of the quant framework layers, data flow, and on-disk format.
examples/models/gemma4_31b/model.py Gemma 4 31B-IT model definition, ring-buffer KV cache, HF load/remap, runtime buffer materialization.
examples/models/gemma4_31b/model.md Architecture/design notes including attention flavors, caching strategy, and export methods.
examples/models/gemma4_31b/main.cpp CUDA ExecuTorch runner driving exported prefill/decode with tokenizer integration.
examples/models/gemma4_31b/inference.py Eager CUDA inference path (load/pack/materialize/compile + generate loop).
examples/models/gemma4_31b/gguf_loader.py Streams GGUF weights, remaps keys, handles tied weights, and packs into runtime model.
examples/models/gemma4_31b/export.py Export+lower pipeline producing .pte/.ptd with decode/prefill methods and CUDA lowering.
examples/models/gemma4_31b/__init__.py Package marker for the Gemma 4 31B example.
examples/models/gemma4_31b/README.md User-facing docs for quantize/export/inference/build/run flows.
examples/models/gemma4_31b/CMakePresets.json CMake workflow presets to build the CUDA runner.
examples/models/gemma4_31b/CMakeLists.txt CMake build for the Gemma 4 31B CUDA runner binary.
examples/models/gemma4/text_decoder/gemma4_norm.py Replace transformers RMSNorm dependency with a self-contained implementation.
examples/models/gemma4/text_decoder/__init__.py Export shared primitives (attention helpers, KV cache, MLP, norm) for reuse by gemma4_31b.
Makefile Add gemma4_31b-cuda build target and help entry.
.github/workflows/cuda.yml Run Gemma 4 31B quant and pipeline tests as part of CUDA unit test job.


Comment thread examples/models/gemma4_31b/quant/quantize.py Outdated
Comment thread examples/models/gemma4_31b/quant/tests/test_gguf.py Outdated
Comment thread .github/workflows/cuda.yml
Comment thread examples/models/gemma4_31b/tests/test_cuda_pipeline.py Outdated
Comment thread examples/models/gemma4_31b/tests/test_pipeline.py
Comment thread examples/models/gemma4_31b/tests/test_cuda_pipeline.py Outdated
@mergennachin
Contributor Author

@claude Review this PR

@claude

claude Bot commented Apr 30, 2026

Claude finished @mergennachin's task in 2m 54s —— View job


PR Review: Gemma 4 31B-IT Model, Export, and Quantization Framework

  • Gather context and understand the PR structure
  • Review quantization framework (quant/)
  • Review model implementation (model.py)
  • Review export and inference scripts
  • Review GGUF loader
  • Review C++ runner and build system
  • Review tests
  • Review CI integration
  • Post comprehensive review feedback

Overall Assessment

This is a well-structured, substantial addition (~5300 lines) that introduces a full Gemma 4 31B-IT pipeline. The architecture is clean — the quantization framework separates concerns nicely (recipe → quantize → serialize → pack), and the code is well-organized for reuse with future models. The test coverage is good (78 tests across multiple files), and the CI integration is appropriate.

I've reviewed the full diff and have the following feedback:


Architecture & Design (Positive)

  1. Clean separation of concerns in quant/: recipe declares intent, quantize produces canonical form, serialize handles persistence, pack handles backend-specific conversion. This layering is excellent for extensibility.

  2. Ring-buffer KV cache (model.py:64-98) is a good memory optimization for sliding-window layers. The buf_size = window_size * 2 sizing combined with chunked prefill in the runner is sound.

  3. GGUF support with streaming iteration (iter_gguf_tensors) keeps peak memory low. The tied embed/lm_head handling (dequantize embedding to bf16 for gather, keep lm_head as INT4) is a pragmatic choice.

  4. The pack_one API (pack.py:80-109) enables streaming GGUF loading without buffering all weights before packing — nice design for 31B-scale models.


Issues

1. QuantConfig.bits documentation vs implementation mismatch (quant/recipe.py:21)

The comment says # 4, 8 but the Copilot reviewer noted that tests use bits=6 (via GGUF Q6_K) and the QuantConfig doesn't validate. In practice, _unpack_q6_k produces a CQW with bits=8 and method="gguf_q6_k" (since 6-bit values are widened to int8), so this is mostly a documentation issue. But it would be clearer to document the contract as "storage width: 4 or 8" since that's what the rest of the pipeline actually supports.

2. Sliding-window mask correctness for multi-token prefill (model.py:448-454)

The ring_pos calculation:

ring_pos = j + ((total_written - 1 - j) // buf_size) * buf_size

This computes the "logical position" for each ring-buffer slot. When input_pos[0] == 0 and seq_len == buf_size, total_written = buf_size, and ring_pos[j] = j for all j, which is correct (all slots were just written). The mask delta >= 0 ensures causality, and delta < window clips the sliding window.
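
A quick numeric spot check of the formula (buf_size = 4, assuming floor-division semantics for // on integer tensors, as in current PyTorch):

import torch

buf_size = 4
for total_written in (1, 3, 4, 6):
    j = torch.arange(buf_size)
    ring_pos = j + ((total_written - 1 - j) // buf_size) * buf_size
    print(total_written, ring_pos.tolist())
# 1 -> [0, -3, -2, -1]  only slot 0 holds a real position; negatives are masked out
# 3 -> [0, 1, 2, -1]
# 4 -> [0, 1, 2, 3]     every slot was just written
# 6 -> [4, 5, 2, 3]     slots 0 and 1 were overwritten by positions 4 and 5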

However, I'd recommend adding a runtime assert or at least a more prominent comment that seq_len <= buf_size is a hard invariant. The C++ runner respects get_max_prefill_chunk, but nothing in the Python model's forward prevents violation. If someone calls model(tokens, input_pos) with seq_len > buf_size, the index_copy_ on line 96 has non-deterministic behavior (as the comment notes) but no error is raised. Consider:

# In RingKVCache.update:
assert input_pos.shape[0] <= self.buf_size, f"seq_len {input_pos.shape[0]} > buf_size {self.buf_size}"

This would catch misuse in eager mode (export would strip it).

3. C++ runner: read_token casts float to uint64 (main.cpp:60-86)

float val = 0.0f;
// ... read float from tensor ...
return static_cast<uint64_t>(val);

The sampler returns argmax(...).float(), so val is a token ID stored as float. The cast float → uint64_t truncates. For token IDs up to 262144 (Gemma 4 vocab), float32 has sufficient precision (exact up to 2^24 = 16M). This is fine in practice but worth noting — if vocab ever exceeds 16M, this would silently produce wrong IDs. A comment like "// safe: vocab_size < 2^24 so float→int is exact" would help.
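
A quick check of that precision claim (vocab-sized IDs are exact in float32; the first failure is at 2^24 + 1):

import torch
print(int(torch.tensor(262143, dtype=torch.float32)))     # 262143: exact
print(int(torch.tensor(2**24 + 1, dtype=torch.float32)))  # 16777216: 2^24 + 1 is not representable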

4. _INT8_PER_AXIS group_size is hardcoded to hidden_size (quantize_and_save.py:51)

_INT8_PER_AXIS = QuantConfig(bits=8, group_size=5376, symmetric=True, method="min_max")

This hardcodes group_size=5376 (the hidden_size). If someone reuses this recipe with a model whose embedding dim differs, the quantization will produce wrong-shaped scales. Since the recipe is specifically for Gemma 4 31B this is acceptable, but a comment noting the coupling would be helpful.
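
For illustration, per-axis symmetric INT8 quantization collapses to one scale per output row once the group spans the whole last dimension. A toy sketch (not the actual quantize.py code; shapes are made up):

import torch

w = torch.randn(8, 16)                      # toy (rows, hidden); hidden is 5376 for Gemma 4 31B
group_size = w.shape[-1]                    # "per-axis" == one group covering the full row
scale = w.abs().amax(dim=-1, keepdim=True) / 127.0
q = torch.clamp((w / scale).round(), -128, 127).to(torch.int8)
print(q.shape, scale.shape)                 # torch.Size([8, 16]) torch.Size([8, 1])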

5. GGUF test import handling (quant/tests/test_gguf.py:28-29)

if _HAS_GGUF:
    from executorch.examples.models.gemma4_31b.quant.gguf import unpack_gguf_tensor

The @unittest.skipUnless(_HAS_GGUF, ...) decorators handle skipping correctly. However, dequantize_weight and serialize/deserialize are imported unconditionally (line 30-31), which is fine since they don't depend on gguf. The CI workflow correctly installs gguf before running tests (line 152 of cuda.yml). This is properly handled.

6. Chunked prefill test uses allclose appropriately (tests/test_cuda_pipeline.py:129-138)

The test uses atol=1e-2, rtol=1e-3 which is reasonable for comparing across different CUDA kernel execution paths. Good fix from the earlier Copilot suggestion about torch.equal being too strict.


Minor Suggestions

  1. model.py:486-495 — The temperature is None branch returns full (B, T, V) logits, while the temperature branch returns (B, 1) token IDs. The return type depends on runtime input, which makes the type signature unclear. This is documented in the docstring, but a @torch.no_grad() decorator on forward would also be appropriate since this model is inference-only.

  2. inference.py:98-101 — Eager prefill is one-token-at-a-time. This is intentional (comment explains it's simpler for torch.compile), but for a 31B model with long prompts, this could be slow. The exported PTE path handles this properly via chunked prefill. Consider adding a note in the README about this trade-off.

  3. pack_cuda.py:164 — pack_int4_for_cuda is called with device="cuda" then immediately moved to CPU:

    module.weight = nn.Parameter(packed.data.to("cpu"), requires_grad=False)
    torch.cuda.empty_cache()

    This is because tinygemm packing requires CUDA, but the model is built on CPU for export. The torch.cuda.empty_cache() after each linear is a good practice for 31B-scale models to avoid OOM during the packing phase.

  4. gguf_loader.py:128 — config = Gemma4_31BConfig(max_seq_len=max_seq_len) uses all defaults. This is correct for the community GGUF files (which are the standard 31B architecture), but if GGUF metadata contains different hyperparameters, they'd be ignored. Fine for now since GGUF only stores the standard config.

  5. CMakeLists.txt:51-53 — FATAL_ERROR when EXECUTORCH_BUILD_CUDA is not set is appropriate since this runner is CUDA-only.


Test Coverage Assessment

The test suite is comprehensive:

  • test_recipe.py: Regex matching, layer filters, production recipe regression tests
  • test_quantize.py: Output structure, roundtrips, error cases, HQQ paths
  • test_serialize.py: Nibble packing, save/load, 3D tensors, metadata
  • test_pack_cuda.py: Module-type dispatch, matmul accuracy, mixed precision, MoE grouping
  • test_gguf.py: Q4_K/Q6_K unpacking vs reference formulas, serialize roundtrip
  • test_pipeline.py: Quantize→save→load integration
  • test_cuda_pipeline.py: Pack→generate, chunked prefill correctness, export e2e

Missing: No test for the RingKVCache in isolation (though it's exercised by test_chunked_prefill_matches_sequential).


Summary

This is a high-quality PR with clean architecture, good test coverage, and solid engineering choices for memory efficiency at 31B scale. The main actionable items are:

  1. Add a runtime assert in RingKVCache.update for seq_len <= buf_size
  2. Update the QuantConfig.bits comment to reflect actual supported values (storage widths 4 and 8)
  3. Comment the float→uint64 cast safety in main.cpp

None of these are blocking — the code is correct as written, and the invariants are upheld by the runner/export logic. These are defensive hardening suggestions.


Copilot AI review requested due to automatic review settings April 30, 2026 14:50
@mergennachin
Contributor Author

@claude Review this PR

@claude

claude Bot commented Apr 30, 2026

Claude finished @mergennachin's task in 2m 30s —— View job


PR Review: Gemma 4 31B-IT Model, Export, and Quantization Framework

  • Gather context and understand the PR structure
  • Review quantization framework (quant/)
  • Review model implementation (model.py)
  • Review export and inference scripts
  • Review GGUF loader
  • Review C++ runner and build system
  • Review tests
  • Review CI integration
  • Post comprehensive review feedback

Overall Assessment

This is a high-quality, well-architected PR (~5500 lines) adding a full Gemma 4 31B-IT pipeline. The separation of concerns is clean — recipe declares intent, quantize produces canonical form, serialize persists, pack converts to backend format. Test coverage is comprehensive (78 tests), and the PR addresses the previous review's concerns well (added the RingKVCache assert, fixed symmetric INT4 dequant, added GGUF tests with proper skip handling).


Architecture & Design (Strengths)

  1. Layered quantization abstraction — quant/ cleanly separates recipe → quantize → serialize → pack. Adding a new backend (e.g., Metal) requires only a new pack_metal.py; adding a new model requires only a QuantRecipe. This will pay dividends for Qwen 3.5 MoE and future models.

  2. Ring-buffer KV cache (model.py:64-101) — Sizing to window_size * 2 with the runtime assert on line 96 is a good defensive choice. The sliding mask logic in _build_masks correctly handles wraparound by computing ring_pos from total_written.

  3. Streaming GGUF loading (gguf_loader.py) — Processing one tensor at a time via iter_gguf_tensors + pack_one keeps peak memory low during loading of a 31B model. The untied embed/lm_head handling (dequantize embedding for gather, keep lm_head quantized for matmul) is pragmatic and well-documented.

  4. Chunked prefill in the C++ runner (main.cpp:242-287) — Correctly queries get_max_prefill_chunk metadata and respects the ring-buffer limit. The fallback to kMaxSeqLen - 1 when metadata is absent is safe.


Issues & Suggestions

1. read_token float→uint64 cast should use llrintf (Fixed — good)

main.cpp:85 uses llrintf(val) to round the float to the nearest integer before casting. This is the right approach — raw static_cast<uint64_t>(val) truncates toward zero, which could produce off-by-one errors for token IDs. Well done.

2. _INT8_PER_AXIS hardcodes group_size=5376 (quantize_and_save.py:51-53)

_INT8_PER_AXIS = QuantConfig(bits=8, group_size=5376, symmetric=True, method="min_max")

This couples the recipe to Gemma 4 31B's hidden_size. Since the recipe is named GEMMA4_31B_* and lives in the model-specific file, this is acceptable. But if this framework is promoted to be shared across models, consider deriving from the config:

# Future: group_size=config.hidden_size for per-axis quantization


3. Potential integer overflow in _build_masks ring_pos calculation (model.py:455)

ring_pos = j + ((total_written - 1 - j) // buf_size) * buf_size

When total_written is 0 (i.e., input_pos[0] == 0 and seq_len == 0), total_written - 1 evaluates to -1 in the signed int64 tensor math. In practice this can't happen (forward is never called with empty input), but it's worth noting. The actual execution path always has seq_len >= 1, so total_written >= 1.

4. GGUF test skip handling is correct (quant/tests/test_gguf.py:21-28)

The earlier review flagged that GGUF tests might fail without the gguf package. The current code handles this correctly:

  • Lines 21-25: try/except guards _HAS_GGUF
  • Line 28: Conditional import of unpack_gguf_tensor only when _HAS_GGUF=True
  • Lines 30-31: dequantize_weight and serialize/deserialize are imported unconditionally (they don't need gguf)
  • All test classes use @unittest.skipUnless(_HAS_GGUF, ...) decorators
  • CI (cuda.yml:152) installs gguf before running tests

This is properly handled now.

5. _move_to_cuda preserves tensor subclass identity (inference.py:48-55)

for name, p in model.named_parameters():
    parts = name.rsplit(".", 1)
    parent = model.get_submodule(parts[0]) if len(parts) > 1 else model
    setattr(parent, parts[-1],
            torch.nn.Parameter(p.data.to("cuda"), requires_grad=False))

This iterates named_parameters() while mutating them via setattr. In CPython this is safe because named_parameters() yields from a snapshot of the module tree (it doesn't lazily walk), but it's fragile. More importantly, p.data.to("cuda") on an Int4TilePackedTo4dTensor subclass relies on the subclass implementing __torch_dispatch__ for the to op correctly. The docstring correctly notes this intent — just flagging that if torchao changes the subclass dispatch behavior, this could silently break. A more defensive approach would be to check type(p.data) after the move, but this is minor.

6. Export caps prefill at min(max_seq_len - 1, sliding_window * 2) (export.py:167)

max_prefill = min(config.max_seq_len - 1, config.sliding_window * 2)

For the default config (max_seq_len=4096, sliding_window=1024), this gives max_prefill=2048. The full-attention layers use a flat Gemma4KVCache of size max_seq_len=4096, so they have no issue. The sliding layers use a ring buffer of size 2048, matching exactly. This is correct.

However, consider that Dim("seq_len", min=2, max=max_prefill) means the dynamic shape for prefill ranges from 2 to 2048. If a user passes max_seq_len=512 (< 2*sliding_window), max_prefill=511, and the ring buffer is still sized to 2048 but the dynamic dim is capped at 511. This is fine — it's a conservative limit.

7. Q4_K zero handling when scale=0 (quant/gguf.py:77-79)

zero_std = torch.where(
    eff_scale != 0, eff_min / eff_scale, torch.zeros_like(eff_min)
)

The test_zero_scale_produces_zero test validates this edge case — when scale=0, zero_std is forced to 0, avoiding NaN from division. The comment in the test references this as a fixed regression. Good.

8. Sampler numerical stability (sampler.py:40-41)

noise = torch.rand_like(logits)
gumbel = -torch.log(-torch.log(noise + 1e-20) + 1e-20)

The double-log Gumbel trick with 1e-20 epsilon is standard. However, torch.rand_like can produce exactly 0.0 (uniform on [0, 1)), and -log(1e-20) ≈ 46, so -log(46 + 1e-20) ≈ -3.8. This is fine numerically — no NaN or Inf can occur given the epsilon. The temperature clamp at 1e-6 also prevents division issues.
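
A one-line check of the worst case (rand returning exactly 0), just to confirm the arithmetic above:

import math
eps = 1e-20
print(-math.log(-math.log(0.0 + eps) + eps))  # ~ -3.83: finite, no NaN or Inf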


Minor Observations

  1. model.py:486-498 — The forward return type changes based on temperature (returns (B, T, V) logits when None, (B, 1) token IDs otherwise). This is documented but means type annotations can't express the full contract. Fine for an inference-only model.

  2. inference.py:96-101 — One-token-at-a-time prefill is intentionally chosen for torch.compile friendliness (uniform shape = one graph). The docstring explains the tradeoff clearly. For a 31B model with 4096-token prompts this is slow (~4s on A100), but the PTE path with chunked multi-token prefill is the production path.

  3. pack_cuda.py:164 — Moving INT4 packed weights from CUDA back to CPU (packed.data.to("cpu")) with torch.cuda.empty_cache() is good memory hygiene for the packing phase. Each linear is packed individually on CUDA to avoid accumulating GPU memory for all 60 layers simultaneously.

  4. quant/tests/test_pack_cuda.py — I didn't read this fully but its presence (482 lines) indicates thorough CUDA packer testing including matmul accuracy validation.


Test Coverage Assessment

The test suite is comprehensive and well-organized:

| File | Scope | Key tests |
| --- | --- | --- |
| test_recipe.py | Regex matching, layer filters, production recipe regression | 14 test cases |
| test_quantize.py | quantize_weight/quantize_model API contract, HQQ paths | 12 test cases |
| test_serialize.py | Nibble packing, save/load, 3D tensors, versioned headers | ~15 test cases |
| test_pack_cuda.py | Module-type dispatch, matmul accuracy, mixed precision, MoE | ~20 test cases |
| test_gguf.py | Q4_K/Q6_K unpacking vs reference formulas, serialize roundtrip | 7 test cases |
| test_pipeline.py | Quantize→save→load integration, RingKVCache isolation | 7 test cases |
| test_cuda_pipeline.py | Pack→generate, chunked prefill correctness, export e2e | 4 test cases |

Good coverage. The RingKVCache unit tests (test_pipeline.py:212-266) are a welcome addition addressing the earlier review's gap.


Summary

This is a well-executed, production-quality addition. The architecture decisions are sound, the code is clean and readable (comments only where non-obvious), and the test coverage is thorough. The main items from previous reviews have been addressed. No blocking issues remain.

Actionable items (non-blocking):

  1. Comment the group_size=5376 coupling to hidden_size in quantize_and_save.py:51
  2. Consider adding @torch.no_grad() decorator to Gemma4_31B.forward since this is inference-only

Contributor

Copilot AI left a comment


Pull request overview

Adds a full Gemma 4 31B-IT text-only pipeline to the ExecuTorch examples, including a reusable packing-agnostic quantization framework (recipe/quantize/serialize/pack), GGUF import, CUDA packing/export/inference flows, and a CUDA runner, with CI coverage.

Changes:

  • Introduces examples/models/gemma4_31b/ model implementation (ring-buffer KV cache), export/inference scripts, GGUF loader, and C++ CUDA runner + build targets.
  • Adds a new quant/ framework (recipes, min-max + HQQ quantization, safetensors format, CUDA packing, GGUF Q4_K/Q6_K unpack).
  • Adds unit/integration tests and wires them into the CUDA GitHub Actions workflow.

Reviewed changes

Copilot reviewed 31 out of 31 changed files in this pull request and generated 5 comments.

Show a summary per file
File Description
examples/models/gemma4_31b/tests/test_pipeline.py CPU-only pipeline + RingKVCache tests for quant/save/load and cache behavior
examples/models/gemma4_31b/tests/test_cuda_pipeline.py CUDA integration tests for pack/infer/export + chunked prefill equivalence
examples/models/gemma4_31b/sampler.py GPU-side Gumbel-max sampler (mirrors Qwen sampler behavior)
examples/models/gemma4_31b/quantize_and_save.py CLI to quantize HF checkpoint and save canonical safetensors checkpoint
examples/models/gemma4_31b/quant/tests/test_serialize.py Unit tests for canonical format + nibble packing + safetensors I/O
examples/models/gemma4_31b/quant/tests/test_recipe.py Unit tests for regex/layer-filter recipe matching + production recipe regression
examples/models/gemma4_31b/quant/tests/test_quantize.py Unit tests for min-max + HQQ quantize/dequantize and quantize_model behavior
examples/models/gemma4_31b/quant/tests/test_pack_cuda.py CUDA unit tests for packers (int4 tinygemm, int8 intx, dispatch/grouping)
examples/models/gemma4_31b/quant/tests/test_gguf.py Unit tests for GGUF Q4_K/Q6_K unpacking and serialize roundtrip
examples/models/gemma4_31b/quant/serialize.py CanonicalQuantizedWeight + serialize/deserialize + safetensors save/load
examples/models/gemma4_31b/quant/recipe.py QuantConfig/QuantRule/QuantRecipe declarative matching logic
examples/models/gemma4_31b/quant/quantize.py min-max + HQQ quantize_weight/dequantize_weight + per-model quantization
examples/models/gemma4_31b/quant/pack_cuda.py CUDA packers for Linear/Embedding and load+pack convenience wrapper
examples/models/gemma4_31b/quant/pack.py Backend-agnostic pack_model/pack_one dispatch + grouping by parent module
examples/models/gemma4_31b/quant/gguf.py GGUF tensor unpack + streaming iterator to canonical representation
examples/models/gemma4_31b/quant/__init__.py Public API exports for quant framework and CUDA packers
examples/models/gemma4_31b/quant/README.md Framework overview, dataflow, and backend/model extension guidance
examples/models/gemma4_31b/model.py Gemma4 31B model, ring-buffer KV cache, HF loader, runtime buffer materialization
examples/models/gemma4_31b/model.md Architecture/design notes + export/quantization details
examples/models/gemma4_31b/main.cpp CUDA runner driving exported prefill/decode methods with chunked prefill
examples/models/gemma4_31b/inference.py Eager CUDA inference with torch.compile + GGUF/prequantized load paths
examples/models/gemma4_31b/gguf_loader.py Streams GGUF tensors, remaps keys, handles tied embed/lm_head, packs model
examples/models/gemma4_31b/export.py Export+lower (decode+prefill) to ExecuTorch with CUDA backend and metadata
examples/models/gemma4_31b/__init__.py Package marker for gemma4_31b example
examples/models/gemma4_31b/README.md End-to-end usage docs for quantize/export/infer/build/run
examples/models/gemma4_31b/CMakePresets.json CMake workflow presets for building the runner (CUDA preset)
examples/models/gemma4_31b/CMakeLists.txt Runner build configuration + CUDA backend linkage
Makefile Adds gemma4_31b-cuda build target
.github/workflows/cuda.yml Runs gemma4_31b quant + pipeline tests in CUDA CI
examples/models/gemma4/text_decoder/gemma4_norm.py Replaces transformers RMSNorm import with self-contained implementation
examples/models/gemma4/text_decoder/__init__.py Exposes additional shared primitives used by gemma4_31b model


Comment thread examples/models/gemma4_31b/main.cpp
Comment thread examples/models/gemma4_31b/quant/recipe.py Outdated
Comment thread examples/models/gemma4_31b/quant/serialize.py Outdated
Comment thread examples/models/gemma4_31b/quant/quantize.py
Comment thread examples/models/gemma4_31b/quant/quantize.py Outdated
@mergennachin
Contributor Author

@claude Review this PR

@claude

claude Bot commented Apr 30, 2026

Claude finished @mergennachin's task in 2m 30s —— View job


PR Review: Gemma 4 31B-IT Model, Export, and Quantization Framework

  • Gather context and understand the PR structure
  • Review model implementation (model.py)
  • Review quantization framework (quant/)
  • Review GGUF loader and streaming support
  • Review export and inference scripts
  • Review C++ runner
  • Review tests and CI integration
  • Post comprehensive review feedback

Overall Assessment

This is a well-structured, high-quality addition (~5600 lines across 31 files) introducing a full Gemma 4 31B-IT text-only pipeline. The architecture is clean — the quantization framework separates concerns into orthogonal layers (recipe → quantize → serialize → pack), the model implementation is faithful to the HF reference, and the streaming GGUF support keeps peak memory tractable. Test coverage is comprehensive (78 tests), and the PR has clearly evolved through multiple review iterations addressing earlier feedback.

The codebase is ready for merge with minor non-blocking suggestions below.


Architecture Strengths

  1. Layered quantization framework — quant/ cleanly separates concerns: recipe.py (what to quantize), quantize.py (how to quantize), serialize.py (persistence format), pack.py (runtime dispatch), pack_cuda.py (backend-specific). Adding a new backend (Metal, XLA) requires only a new pack_*.py; adding a model requires only a recipe.

  2. Ring-buffer KV cache (model.py:64-101) — Sizing to 2 × sliding_window instead of max_seq_len saves ~75% of KV memory for sliding layers. The runtime assert on line 95-97 guards against oversized prefill — good defensive programming.

  3. Streaming I/O everywhere — iter_load (serialize.py:238-281), iter_gguf_tensors (gguf.py:201-219), and load_and_pack_for_cuda (pack_cuda.py:200-224) all process one weight at a time. At 31B scale, this is the difference between 17GB peak and 60GB+ peak during model loading.

  4. Shared primitives — The numerically sensitive ops (RMSNorm, RoPE, KV cache, MLP) are shared with the E2B/E4B variant via examples.models.gemma4.text_decoder, avoiding drift.

  5. Chunked prefill in the C++ runner (main.cpp:242-293) — Correctly queries get_max_prefill_chunk metadata and chunks accordingly, with cudaDeviceSynchronize (line 308) between prefill and decode to flush shared KV-cache writes.


Issues & Suggestions

1. _build_masks ring_pos computation when total_written == 0 (model.py:453-455)

total_written = input_pos[0] + seq_len
j = torch.arange(buf_size, dtype=torch.long, device=input_pos.device)
ring_pos = j + ((total_written - 1 - j) // buf_size) * buf_size

When input_pos[0] == 0 and seq_len >= 1, total_written >= 1, so total_written - 1 >= 0 and the floor division is well-defined. The invariant seq_len >= 1 is guaranteed by the export shapes (decode T=1, prefill T≥2). However, consider that total_written - 1 - j can be negative for j > total_written - 1 (i.e., slots that haven't been written yet). In that case (negative // buf_size) * buf_size yields a negative ring_pos, and the mask clause ring_pos >= 0 (line 457) correctly excludes those slots. This is sound.

No action needed — just documenting my verification.

2. QuantConfig.group_size comment is stale (recipe.py:24)

group_size: int  # 32, 64, 128

The actual usage includes group_size=5376 (per-axis embedding), group_size=16 (Q6_K), and group_size=32 (Q4_K, standard). The comment should say "any positive divisor of the last dim" or similar.


3. quantize_model() prints to stdout unconditionally (quantize.py:241-242)

print(f"  Quantized {i + 1}/{n_params}: {fqn}", end="\r")

Since this is a reusable library function called by tests, this pollutes test output. Consider gating behind a verbose parameter (default True for backward compat) or using logging.

4. _INT8_PER_AXIS hardcodes group_size=5376 (quantize_and_save.py:51-53)

_INT8_PER_AXIS = QuantConfig(  # group_size = hidden_size (5376) for Gemma 4 31B
    bits=8, group_size=5376, symmetric=True, method="min_max"
)

The comment already explains the coupling, which is good. When promoting this framework for Qwen 3.5 MoE, this will need to be parameterized. Fine for now.

5. C++ runner: missing <cstring> and <cmath> includes (main.cpp:62-88)

read_token uses memcpy (needs <cstring>) and llrintf (needs <cmath>). Currently these compile because they're transitively included via gflags/CUDA headers, but this is fragile.


6. deserialize() doesn't reshape loaded tensors (serialize.py:190-206)

deserialize() records original_shape from metadata but never reshapes qdata to match for 8-bit weights (4-bit gets unpacked to the right shape via _nibble_unpack). The CanonicalQuantizedWeight.__post_init__ validates element counts but not shapes. A corrupted file with transposed dimensions would pass validation but produce wrong results silently. Consider adding an explicit reshape:

qdata = qdata.reshape(original_shape)

after the nibble-unpack branch.

7. _move_to_cuda mutates parameters during iteration (inference.py:48-55)

for name, p in model.named_parameters():
    parts = name.rsplit(".", 1)
    parent = model.get_submodule(parts[0]) if len(parts) > 1 else model
    setattr(parent, parts[-1],
            torch.nn.Parameter(p.data.to("cuda"), requires_grad=False))

named_parameters() yields from a generator that walks the module tree. Mutating modules during iteration is safe in CPython (the generator captures module references upfront), but it's implementation-dependent. A safer pattern is list(model.named_parameters()). This is minor — the current code works correctly.

8. Sampler's temperature is None return type divergence (sampler.py:35-41)

if temperature is None:
    return logits  # (B, V)
...
return (logits + gumbel).argmax(dim=-1, keepdim=True).float()  # (B, 1)

The return type depends on the runtime value of temperature. This is intentional (documented in model.py:476-479) and works because the two code paths are exported as separate methods (prefill returns logits for the last chunk, decode uses temperature). No action needed, but a type annotation like -> torch.Tensor (current) is the best we can do here.


Test Coverage

The test suite is well-organized and thorough:

| File | Tests | Key coverage |
| --- | --- | --- |
| quant/tests/test_recipe.py | 14 | Regex matching, layer filters, production recipe regression |
| quant/tests/test_quantize.py | 12 | quantize_weight/dequantize_weight roundtrip, HQQ paths, error cases |
| quant/tests/test_serialize.py | ~15 | Nibble pack/unpack, save/load, 3D tensors, format versioning |
| quant/tests/test_pack_cuda.py | ~20 | Module dispatch, matmul accuracy, mixed precision, MoE grouping |
| quant/tests/test_gguf.py | 7 | Q4_K/Q6_K vs reference formulas, serialize roundtrip, edge cases |
| tests/test_pipeline.py | 7 | Quantize→save→load, RingKVCache isolation, corrupted checkpoint |
| tests/test_cuda_pipeline.py | 4 | Pack→generate, chunked prefill correctness, export e2e |

The GGUF test skip handling (test_gguf.py:21-28) is properly implemented — @unittest.skipUnless(_HAS_GGUF, ...) on each class, conditional import of unpack_gguf_tensor, and CI installs gguf before running (cuda.yml:152).

The RingKVCache unit tests (test_pipeline.py:212-266) cover sequential write, wraparound, multi-token, and oversized-prefill assertion — good targeted coverage for the most subtle piece of logic.


Summary

This is production-quality work with sound architecture decisions for memory efficiency at 31B scale. The main actionable items are:

  1. Add <cstring> and <cmath> includes in main.cpp (fragile transitive dependency)
  2. Update group_size comment in recipe.py to reflect actual supported values
  3. Consider reshaping qdata in deserialize() for 8-bit weights (defensive hardening)

None of these are blocking — the code is correct as written and the invariants are upheld by the export/runner logic. Previous review iterations (symmetric INT4 dequant fix, RingKVCache assert, GGUF skip handling, allclose tolerances) have all been addressed properly.




@dataclass
class CanonicalQuantizedWeight:
Contributor


There is subclass support for this "canonical" format that is later converted to specialized formats.

The basic flow is:

  • On your device-specific subclass, define a class method constructor "from_{canonical_format}_tensor", e.g., Int4PreshuffledTensor.from_int4_tensor(tensor)
  • Load "canonical" tensor subclass checkpoint
  • Iterate through canonical tensors, applying the conversion above.

IIRC, this is the path used by torchao in vLLM (cc @jerryzh168 to confirm).

CanonicalQuantizedWeight feels like a new abstraction that we'd maintain to do the same thing?

Contributor Author


@metascroy Good point.

I'll use Int4Tensor and IntxUnpackedToInt8Tensor as they seem like canonical/portable formats that can be serialized to/from safetensors.

Will delete CanonicalQuantizedWeight

Contributor Author


@metascroy Made the change. Thanks for the recommendation.

Copilot AI review requested due to automatic review settings May 1, 2026 16:50
@mergennachin mergennachin force-pushed the gemma4-31b-quant-framework branch from 49cfc55 to 69474e3 Compare May 1, 2026 16:50
@claude

claude Bot commented May 8, 2026

Claude finished @mergennachin's task in 3m 31s —— View job


PR Review: Gemma 4 31B-IT Model, Export, and Quantization Framework

  • Gather context and understand the PR structure
  • Review new CUDA backend additions (INT4 dispatch, dp4a kernel, split-K SDPA)
  • Review quantization framework (quant/)
  • Review model implementation (model.py)
  • Review export, inference, and GGUF loader
  • Review C++ runner and build system
  • Review tests
  • Post comprehensive review feedback

Overall Assessment

This is a substantial, well-architected PR (~6400 lines, 41 files, 11 commits) that introduces both new CUDA backend infrastructure (dp4a decode kernel, split-K flash-decoding) and a full Gemma 4 31B-IT pipeline. The code is clean, the architecture is sound for reuse, and the test coverage is comprehensive. The quantization framework's layered design (recipe → quantize → serialize → pack) is excellent for extensibility.

This review focuses on the newest additions (dp4a kernel, split-K SDPA selection, Int4Tensor dispatch) and fresh observations on the model/quant code that previous reviews haven't fully covered.


New CUDA Backend Infrastructure

1. W4A8 dp4a Kernel (int4_plain_mm.cuh) — Well-designed

The kernel design is solid:

  • 16-byte vectorized weight loads (uint4 / __ldg)
  • Warp-cooperative activation quantization (per-32-element blocks)
  • dp4a for fused int4×int8 dot products
  • Warp-level reduction via __shfl_xor_sync

The even/odd Q8Block layout (qs_even[16], qs_odd[16]) is unusual but clever — it aligns the activation quantized data so that dp4a can process the low and high nibbles of weights against their respective activation halves without shuffle overhead.

2. Global persistent buffer is a leak risk (int4_plain_mm.cuh:182-193)

static Q8Block* g_q8_buf = nullptr;
static size_t g_q8_buf_size = 0;

static Q8Block* get_q8_buffer(size_t needed) {
  if (g_q8_buf_size < needed) {
    if (g_q8_buf) cudaFree(g_q8_buf);
    cudaMalloc(&g_q8_buf, needed);
    g_q8_buf_size = needed;
  }
  return g_q8_buf;
}

The comment says "single-stream only" which is fine for the runner, but:

  • The buffer is never freed on program exit (minor — OS reclaims, but sanitizers will flag it)
  • cudaMalloc error is not checked. If allocation fails, g_q8_buf will be the old freed pointer (UAF) or nullptr
  • If this code is used with multiple CUDA devices (multi-GPU), the static global creates a device-affinity issue

Suggested fix for the cudaMalloc call:

cudaError_t err = cudaMalloc(&g_q8_buf, needed);
if (err != cudaSuccess) {
  g_q8_buf = nullptr;
  g_q8_buf_size = 0;
  ET_CHECK_MSG(false, "cudaMalloc failed for Q8 buffer: %s", cudaGetErrorString(err));
}


3. Missing K alignment validation (int4_plain_mm.cuh:225)

The kernel processes K in blocks of 32 (Q8_BLOCK_SIZE). The check K >= Q8_BLOCK_SIZE passes but doesn't verify K % Q8_BLOCK_SIZE == 0. If K is not a multiple of 32, the quantization kernel's grid calculation n_q8_blocks = K / Q8_BLOCK_SIZE truncates, causing the last elements of each activation row to be unquantized/unread — producing silent numerical errors. All Gemma 4 31B layer dimensions satisfy this (5376, 21504 are multiples of 32), but for a generic shim:

ET_CHECK_MSG(
    K % Q8_BLOCK_SIZE == 0,
    "K=%d must be divisible by %d for dp4a kernel",
    K, Q8_BLOCK_SIZE);


4. Int4Tensor F.linear dispatch (int4_dispatch.py) — Clean design

The dispatch approach is elegant:

  • M <= 4 → custom op executorch_cuda::int4_plain_mm (routes to dp4a C shim in .pte runtime)
  • M > 4 → inline dequant + F.linear (AOTI compiles to cuBLAS)

Both paths share the same dequant math in eager, which means the Python tests validate correctness for both code paths. The M=4 threshold makes sense — dp4a matvec saturates at M=1-4, while cuBLAS GEMM wins for larger M.
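
A pure-Python sketch of that threshold logic follows. The real decode branch calls the executorch_cuda::int4_plain_mm custom op; the stand-in below only marks where that call goes, and its signature is illustrative:

import torch
import torch.nn.functional as F

M_THRESHOLD = 4  # dp4a matvec wins for M <= 4; cuBLAS GEMM wins beyond that

def dp4a_matvec_stub(x, w_bf16):
    # Stand-in for the custom op; numerically it computes the same matvec.
    return F.linear(x, w_bf16)

def int4_linear_dispatch(x, w_bf16):
    M = x.reshape(-1, x.shape[-1]).shape[0]
    if M <= M_THRESHOLD:
        return dp4a_matvec_stub(x, w_bf16)   # decode path
    return F.linear(x, w_bf16)               # prefill path: inline dequant + cuBLAS

x_decode, x_prefill = torch.randn(1, 1, 64), torch.randn(1, 512, 64)
w = torch.randn(128, 64)
print(int4_linear_dispatch(x_decode, w).shape, int4_linear_dispatch(x_prefill, w).shape)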

5. Split-K SDPA selection (triton/replacement_pass.py) — Correct and conservative

if (
    isinstance(L_q, int) and L_q == 1
    and isinstance(L_kv, int) and L_kv > _SPLITK_LKV_THRESHOLD
    and D > 0 and (D & (D - 1)) == 0  # power of 2
):

The isinstance(L_q, int) check correctly handles dynamic shapes (symbolic dims won't match). The power-of-2 head_dim constraint avoids triton alignment issues. The 2048 threshold matches the sliding_window boundary perfectly — sliding layers (ring buffer ≤ 2048) get standard SDPA, full-attention layers (4096+) get split-K. Good.


Model & Quantization

6. _build_masks sliding mask is correct but subtle (model.py:453-457)

total_written = input_pos[0] + seq_len
j = torch.arange(buf_size, dtype=torch.long, device=input_pos.device)
ring_pos = j + ((total_written - 1 - j) // buf_size) * buf_size
delta = q_pos - ring_pos.unsqueeze(0)
sliding = (ring_pos >= 0) & (delta >= 0) & (delta < self.config.sliding_window)

I verified: for slots where j > total_written - 1 (unwritten), total_written - 1 - j is negative, floor division with positive buf_size produces a negative quotient, so ring_pos stays less than total_written — and the mask correctly excludes those slots because delta < 0 for unwritten positions or ring_pos < 0 catches the edge case. Sound.

7. quantize_model stdout side effect (quantize.py:303)

print(f"  Quantized {i + 1}/{n_params}: {fqn}", end="\r")

This is a reusable library function called by both scripts and tests. The unconditional print pollutes test output. Consider a verbose parameter or use logging.

8. QuantConfig.group_size comment is stale (recipe.py:28)

group_size: int  # 32, 64, 128

Actual usage includes 5376 (per-axis embedding), 16 (Q6_K). The comment should be updated or removed.


C++ Runner

9. Missing standard includes (main.cpp)

The file uses memcpy (needs <cstring>), llrintf (needs <cmath>), and fprintf (needs <cstdio>) without direct includes. This compiles today via transitive includes from gflags/CUDA headers but is fragile across toolchains.


10. Log level filter may suppress Error/Fatal (main.cpp:43)

if (level < 'W') {
    return;
}

If et_pal_log_level_t values are character literals ('D'=68, 'E'=69, 'F'=70, 'I'=73, 'W'=87), this filter suppresses Error and Fatal logs (since 'E' < 'W' and 'F' < 'W'). This would hide runtime errors. If the intent is to only suppress Debug/Info, compare against the actual enum values rather than character ordering:

if (level == 'D' || level == 'I') {
    return;
}

Or verify that et_pal_log_level_t uses numeric severity ordering, not char values.



Tests

The test suite is thorough and well-organized:

| Area | File | Coverage |
| --- | --- | --- |
| dp4a dispatch | backends/cuda/tests/test_int4_dispatch.py | Decode/prefill/batched/bias/compile/large shapes |
| Split-K SDPA | backends/cuda/tests/test_sdpa_splitk_replacement.py | Large KV→splitk, small KV→standard, non-pow2 head |
| dp4a C shim | backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_int4_plain_mm.cpp | Multi-group, single-group, prefill batch, group_size=128 |
| Recipe | quant/tests/test_recipe.py | Regex, layers, production recipe regression |
| Quantize | quant/tests/test_quantize.py | Types, roundtrip, HQQ, small-weight precision |
| Pack CUDA | quant/tests/test_pack_cuda.py | Dispatch, matmul accuracy, mixed precision |
| GGUF | quant/tests/test_gguf.py | Q4_K/Q6_K unpack, zero-scale edge case |
| Pipeline | tests/test_pipeline.py | E2E roundtrip, RingKVCache, GGUF key mapping |
| CUDA Pipeline | tests/test_cuda_pipeline.py | Generate, chunked prefill, export, Int4 passthrough |

The test_int4_dispatch.py tests are well-designed — they compare against the unquantized bf16 reference weight (not the dequantized weight), which validates the full quantize→dispatch pipeline end-to-end.


Summary

This is a high-quality PR that adds significant CUDA backend infrastructure alongside the Gemma 4 31B model. The dp4a kernel and split-K SDPA selection are well-engineered and properly tested. The quantization framework is ready for reuse.

Actionable items:

| # | Severity | Description |
| --- | --- | --- |
| 2 | Medium | get_q8_buffer doesn't check cudaMalloc return — potential UAF |
| 3 | Low | Missing K % 32 == 0 validation in dp4a kernel |
| 9 | Low | Missing <cstring>/<cmath>/<cstdio> includes in runner |
| 10 | Medium | Log filter level < 'W' may suppress Error/Fatal |
| 7 | Low | quantize_model prints to stdout unconditionally |
| 8 | Low | Stale group_size comment |

Items 2 and 10 are the most important for correctness/debuggability. The rest are hardening/hygiene.


Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 41 out of 41 changed files in this pull request and generated 5 comments.

Comment thread examples/models/gemma4_31b/main.cpp Outdated
Comment thread backends/cuda/triton/replacement_pass.py
Comment thread backends/cuda/cuda_backend.py
Comment thread backends/cuda/runtime/shims/int4_plain_mm.cuh
Comment thread backends/cuda/runtime/shims/int4_plain_mm.cuh Outdated
@mergennachin mergennachin force-pushed the gemma4-31b-quant-framework branch from 206feef to b777a05 Compare May 8, 2026 18:39
@mergennachin
Contributor Author

@claude Review this PR now

@claude

claude Bot commented May 8, 2026

Claude finished @mergennachin's task in 3m 47s


PR Review: Gemma 4 31B-IT Model, Export, and Quantization Framework

  • Gather context and understand recent changes since last review
  • Review CUDA backend additions (dp4a kernel, split-K SDPA, int4 dispatch)
  • Review quantization framework (quant/)
  • Review model implementation (model.py)
  • Review export, inference, and GGUF loader
  • Review C++ runner and build system
  • Review tests and CI integration
  • Post comprehensive review feedback

Overall Assessment

This is a well-architected, production-quality PR (~6400 lines, 41 files, 11 commits) that introduces both new CUDA backend infrastructure (W4A8 dp4a decode kernel, split-K flash-decoding SDPA selection) and a complete Gemma 4 31B-IT text-only pipeline. The code is clean, the layered design is excellent for reuse (Qwen 3.5 MoE and future models), and the test coverage is thorough. Previous review iterations have addressed all major correctness issues.

This review is fresh against the final state of the branch. The earlier concerns (symmetric INT4 dequant bug, log-level filter suppressing errors, missing <cstring>/<cmath>, torch.equal vs allclose, GGUF skip handling) have all been fixed.


Architecture Strengths

  1. Layered quantization framework — recipe.py (what to quantize) → quantize.py (how, returns torchao subclasses) → pack.py (assigns to model) → pack_cuda.py (backend packer). Adding a backend requires only a new packer; adding a model requires only a recipe. No intermediate custom format — torchao subclasses are the canonical representation.

  2. Int4Tensor F.linear dispatch (int4_dispatch.py) — Elegant dual-path: M≤4 routes to executorch_cuda::int4_plain_mm custom op (dp4a in runtime), M>4 does inline dequant+F.linear (AOTI compiles to cuBLAS). Both share the same dequant math, so Python tests validate correctness for both paths.

  3. Split-K SDPA selection (triton/replacement_pass.py:88-114) — The isinstance(L_q, int) guard correctly handles symbolic dims, the power-of-2 head_dim check avoids triton alignment issues, and the 2048 threshold naturally separates sliding-window layers (ring buffer ≤ 2048) from full-attention layers (4096+). Conservative and correct.

  4. Streaming everywhere — iter_gguf_tensors + pack_one, the per-tensor load_and_pack_for_cuda loop, and the from_hf_checkpoint shard-by-shard loader all keep peak memory bounded at ~1 shard for a 31B model.

  5. Ring-buffer KV cache (model.py:64-101) — Sized to 2×sliding_window, with a runtime assert guarding the seq_len <= buf_size invariant. The _build_masks ring_pos calculation at line 455 is non-trivial but sound — I verified the edge cases (unwritten slots are correctly excluded by delta < 0).
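
For reference, the write side of that ring buffer can be reduced to the following standalone sketch (simplified, not the PR's RingKVCache):

import torch

class TinyRingKVCache:
    def __init__(self, buf_size, n_heads, head_dim):
        self.buf_size = buf_size
        self.k = torch.zeros(1, n_heads, buf_size, head_dim)
        self.v = torch.zeros(1, n_heads, buf_size, head_dim)

    def update(self, input_pos, k_new, v_new):
        seq_len = k_new.shape[2]
        assert seq_len <= self.buf_size, "prefill chunk must fit in the ring buffer"
        positions = input_pos + torch.arange(seq_len)
        slots = positions % self.buf_size       # wraparound write
        self.k[:, :, slots] = k_new
        self.v[:, :, slots] = v_new
        return self.k, self.v

The seq_len <= buf_size assert is the same invariant the C++ runner's prefill chunking exists to maintain.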


Issues

1. get_q8_buffer is not robust to multi-stream or multi-device usage (int4_plain_mm.cuh:182-200)

The global static buffer pattern works for the single-stream runner but is fragile:

  • If called from two CUDA streams concurrently, the shared buffer produces data races
  • The comment says "single-stream only" which is fine for now, but there's no enforcement

The cudaMalloc error handling has been properly added (good fix from earlier review). The remaining concern is the buffer never being freed — sanitizers will report it. For the runner scope this is fine, but if this shim is ever used in a library context, consider adding atexit cleanup or a per-stream buffer strategy.

Severity: Low (documented limitation, runner-only scope)

2. _pick_sdpa_kernel doesn't guard against missing shape metadata (replacement_pass.py:98)

q_shape = node.args[0].meta["val"].shape

If meta["val"] is missing (e.g., from a non-standard export path), this raises KeyError. However, looking at the pass's call() method (line 66-71), the entire replacement is wrapped in try/except Exception, which catches this and logs a warning. So the fallback behavior is correct — the SDPA node just won't be replaced. Still, a more informative error message would help debugging:

if "val" not in node.args[0].meta:
    return triton.sdpa  # Can't determine shapes, use standard kernel

Severity: Low (existing try/except handles it, but silently)

3. try/except AttributeError: pass in cuda_backend.py:303-315 is too broad

try:
    options["aot_inductor.custom_ops_to_c_shims"] = {
        torch.ops.executorch_cuda.int4_plain_mm.default: [...]
    }
except AttributeError:
    pass

The intent is to skip when int4_dispatch.py hasn't been imported (so the op isn't registered). But AttributeError could also mask bugs like a schema mismatch or a typo in the op name. Consider checking explicitly:

if hasattr(torch.ops, "executorch_cuda") and hasattr(torch.ops.executorch_cuda, "int4_plain_mm"):
    options["aot_inductor.custom_ops_to_c_shims"] = { ... }


Severity: Low (correctness impact is nil — the op either exists or it doesn't)

4. quantize_model verbose output uses \r which doesn't play well with CI logs (quantize.py:305)

if verbose:
    print(f"  Quantized {i + 1}/{n_params}: {fqn}", end="\r")

The verbose parameter (added per earlier feedback — good) defaults to False, which is correct for library usage. However, end="\r" produces garbled output in CI logs (no terminal to interpret carriage return). Consider using \n when sys.stdout.isatty() is False, or just always use \n — the progress is infrequent enough (one line per weight, ~300 for this model) that it won't flood logs.

Severity: Negligible (cosmetic in CI)
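
A minimal sketch of the isatty-based variant (helper name illustrative):

import sys

def _progress_end():
    # carriage return only on an interactive terminal; plain newline in CI logs
    return "\r" if sys.stdout.isatty() else "\n"

# inside the loop:
# print(f"  Quantized {i + 1}/{n_params}: {fqn}", end=_progress_end())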

5. _move_to_cuda iterates and mutates simultaneously (inference.py:48-55)

for name, p in model.named_parameters():
    ...
    setattr(parent, parts[-1], torch.nn.Parameter(p.data.to("cuda"), ...))

This works in CPython because named_parameters() iterates from a module tree that captures references upfront. A safer pattern would be list(model.named_parameters()). Low priority since this code is correct and CPython is the only supported runtime.

Severity: Low (works correctly, fragile under hypothetical alternative Python runtimes)
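
The snapshot-based variant is a small change; a standalone sketch (illustrative, not the PR's _move_to_cuda):

import torch

def move_params_to_cuda(model):
    # list(...) snapshots the iterator before any module attributes are replaced
    for name, p in list(model.named_parameters()):
        parent = model.get_submodule(name.rsplit(".", 1)[0]) if "." in name else model
        leaf = name.rsplit(".", 1)[-1]
        setattr(parent, leaf, torch.nn.Parameter(p.detach().to("cuda"), requires_grad=False))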


Observations (No action needed)

  1. dp4a kernel correctness — The even/odd Q8Block layout enables dp4a to process low/high nibbles independently, accumulating in a single warp reduction. The a_sum8 calculation for zero-point correction (dp4a(0x01010101, ...)) is a clever use of dp4a as a horizontal sum. The math checks out (see the sketch after this list).

  2. K alignment check is present (int4_plain_mm.cuh:225): K >= Q8_BLOCK_SIZE && K % Q8_BLOCK_SIZE == 0 — addresses the earlier review concern about K divisibility.

  3. Log filter is correct (main.cpp:46): level == 'D' || level == 'I' — only suppresses Debug and Info, Error and Fatal are propagated. Fixed from the earlier level < 'W' bug.

  4. Standard includes are present (main.cpp:29-31): <cmath>, <cstdio>, <cstring> — all directly included now.

  5. allclose with tolerances (test_cuda_pipeline.py:128-134): atol=1e-2, rtol=1e-3 — appropriate for comparing across different CUDA kernel execution paths. Fixed from earlier torch.equal.

  6. GGUF skip handling (quant/tests/test_gguf.py): @unittest.skipUnless(_HAS_GGUF, ...) on test classes, CI installs gguf (cuda.yml:152). Properly handled.

  7. group_size comment (recipe.py:28): Now just says # 4 or 8 for bits — the group_size field has no misleading comment. The earlier stale comment issue was on a previous revision.
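
Expanding on observation 1: a single horizontal sum of the INT8 activations is enough for zero-point correction because dot(q_w - z, q_a) = dot(q_w, q_a) - z * sum(q_a). A tiny standalone check of the identity:

import torch

g = torch.Generator().manual_seed(0)
q_w = torch.randint(0, 16, (64,), generator=g)        # unsigned INT4 weight codes
q_a = torch.randint(-128, 128, (64,), generator=g)    # INT8-quantized activations
z = 8                                                  # zero point

lhs = torch.dot((q_w - z).float(), q_a.float())            # what the GEMV needs
rhs = torch.dot(q_w.float(), q_a.float()) - z * q_a.sum()  # dp4a dot + one horizontal sum
assert int(lhs) == int(rhs)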


Test Coverage

| Area | File | Highlights |
| --- | --- | --- |
| dp4a dispatch | backends/cuda/tests/test_int4_dispatch.py | Decode M=1, prefill M=64, batched, bias, symmetric, group_size=128, large shapes, compile |
| Split-K SDPA | backends/cuda/tests/test_sdpa_splitk_replacement.py | Large KV → split-K, small KV → standard, non-pow2 head_dim |
| dp4a C shim | backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_int4_plain_mm.cpp | Multi-group accuracy, single-group, prefill batch, group_size=128 |
| Recipe | quant/tests/test_recipe.py | Regex, layer filter, production recipe regression (14 tests) |
| Quantize | quant/tests/test_quantize.py | quantize/dequantize roundtrip, HQQ, error cases (12 tests) |
| Pack CUDA | quant/tests/test_pack_cuda.py | Module dispatch, matmul accuracy, mixed precision (~20 tests) |
| GGUF | quant/tests/test_gguf.py | Q4_K/Q6_K vs reference formulas, zero-scale edge case (7 tests) |
| Safetensors | quant/tests/test_safetensors_roundtrip.py | Int4Tensor + IntxUnpackedToInt8Tensor roundtrip |
| Pipeline | tests/test_pipeline.py | E2E roundtrip, RingKVCache unit tests, GGUF key mapping |
| CUDA Pipeline | tests/test_cuda_pipeline.py | Generate, chunked prefill correctness, export, Int4 passthrough |

The test_chunked_prefill_matches_sequential test is particularly valuable — it verifies the ring-buffer's most subtle correctness property (wraparound produces identical logits to sequential execution).
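
The structure of that check is roughly the following (an illustrative sketch; the model call signature and chunk size are assumptions, not the actual test code):

import torch

def logits_after_prefill(model, tokens, chunk):
    # tokens: [1, T]; feed the prompt in chunks of at most `chunk`, return last-step logits
    logits = None
    for start in range(0, tokens.shape[1], chunk):
        piece = tokens[:, start:start + chunk]
        pos = torch.arange(start, start + piece.shape[1])
        logits = model(piece, pos)          # model updates its KV cache in place
    return logits[:, -1]

# with a fresh KV cache for each run:
# sequential = logits_after_prefill(model, prompt, chunk=prompt.shape[1])
# chunked    = logits_after_prefill(model, prompt, chunk=512)
# torch.testing.assert_close(chunked, sequential, atol=1e-2, rtol=1e-3)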


Summary

This PR is ready for merge. The architecture is sound, the code is clean, all prior review issues have been addressed, and test coverage is comprehensive. The remaining items are minor hardening suggestions, none blocking:

| # | Severity | Item |
| --- | --- | --- |
| 1 | Low | get_q8_buffer global static is not freed / not multi-stream safe (documented) |
| 2 | Low | _pick_sdpa_kernel shape metadata access relies on outer try/except |
| 3 | Low | try/except AttributeError: pass in cuda_backend.py could mask typos |
| 4 | Negligible | \r progress output garbles CI logs |
| 5 | Low | _move_to_cuda mutates during iteration (correct in CPython) |

The codebase is well-positioned for reuse — adding Qwen 3.5 MoE will require only a recipe and optionally a custom packer, as designed.


Copilot AI review requested due to automatic review settings May 8, 2026 18:49
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 41 out of 41 changed files in this pull request and generated 6 comments.


def _read_f16(raw: torch.Tensor, col_start: int, col_end: int) -> torch.Tensor:
"""Read fp16 field from block bytes, return float32."""
return raw[:, col_start:col_end].contiguous().view(torch.float16).float()
Comment on lines +182 to +186
return _raw_tensor(tensor_data).view(torch.float32).reshape(shape).clone()
elif tensor_type == GGMLQuantizationType.F16:
return (
_raw_tensor(tensor_data)
.view(torch.float16)
Comment on lines +137 to +148
# Int4Tensor stores qdata as nibble-packed uint8 (N, K//2)
q = int_data.to(torch.uint8)
packed = q[..., ::2] | (q[..., 1::2] << 4)

# Int4Tensor stores scale/zero as (K//gs, N) — transposed from our (N, K//gs)
return Int4Tensor(
qdata=packed,
scale=scale.t().contiguous(),
zero_point=zero_point.t().contiguous(),
block_size=[1, config.group_size],
shape=torch.Size(int_data.shape),
)
Comment thread examples/models/gemma4_31b/quant/quantize.py
Comment thread backends/cuda/runtime/shims/int4_plain_mm.cu
Comment thread examples/models/gemma4_31b/inference.py
mergennachin and others added 13 commits May 12, 2026 07:36
…uTorch

Text-only export of Gemma 4 31B-IT to ExecuTorch with the CUDA backend
and INT4/INT8 weight quantization via a new packing-agnostic quant/
framework.

The quant/ package separates quantization into four concerns:
  - recipe.py: declarative QuantRecipe with regex FQN matching
  - quantize.py: produces CanonicalQuantizedWeight (min_max, HQQ)
  - serialize.py: save/load to safetensors with versioned headers
  - pack.py + pack_cuda.py: per-module packer dispatch for CUDA

Two production recipes: "default" (INT4 min_max + INT8 embedding) and
"sensitive" (INT8 for edge-layer v_proj/down_proj, INT4 HQQ elsewhere).

Sliding window attention uses a ring-buffer KV cache (2x window size)
for the 50 sliding layers, saving memory for long sequences. The 10
full-attention layers use a standard flat KV cache.

Includes C++ runner (main.cpp), eager inference script, and 60+ unit
and integration tests across quant/ and pipeline test files.
- Sliding window layers use RingKVCache (2×window) instead of flat
  max_seq_len buffer, reducing KV cache memory for long sequences.
- Prefill is capped to ring buffer size; the C++ runner chunks longer
  prompts automatically via get_max_prefill_chunk metadata.
- Both recipes now quantize embed_tokens to INT8 per-axis (~1.4 GB
  savings vs bf16). Embedding packer uses IntxUnpackedToInt8Tensor
  which supports gather.
- pack_model handles top-level FQNs (no parent module).
- C++ runner aligned with Qwen patterns: #ifdef guards for non-CUDA
  builds, better weight_sharing error handling, cudaDeviceSynchronize
  between prefill and decode.
- Test suite split into test_pipeline.py (CPU) and test_cuda_pipeline.py
  (CUDA) with shared fixtures. New chunked prefill correctness test.
- Prequantized checkpoint available at
  huggingface.co/SocialLocalMobile/gemma-4-31B-it-HQQ-INT4.
- Added Gemma 4 31B tests to cuda.yml CI workflow.
- Cleaned up stale terminology, docstrings, and comments throughout.
- quant/gguf.py: unpack Q4_K/Q6_K GGUF blocks to CanonicalQuantizedWeight,
  with iter_gguf_tensors for streaming (low peak memory). Validated
  against original bf16 weights (Q4_K: 7.9%, Q6_K: 1.9% error).
- gguf_loader.py: Gemma 4 31B GGUF key mapping + load_gguf_model.
  Handles tied embed/lm_head: embedding dequantized to bf16 (gather),
  lm_head keeps Q4_K (tinygemm matmul).
- export.py and inference.py: --gguf flag for direct GGUF file loading.
- quant/quantize.py: dequantize_weight (inverse of quantize_weight).
- quant/pack.py: pack_one for single-weight streaming; pack_model
  delegates to pack_one for unquantized, groups quantized by parent
  for multi-weight modules (MoE-compatible).
- quant/serialize.py: CanonicalQuantizedWeight.__post_init__ validation
  (dtype, shape, symmetric/zero consistency).
- Tests moved to tests/ folders (quant/tests/ and tests/).
- dequantize_weight now subtracts 8 from symmetric 4-bit qdata (stored
  as unsigned [0,15]) before scaling, matching the quantize_weight shift
- Guard test_gguf.py with skipUnless so CI doesn't break without gguf
- Install gguf in cuda.yml for GGUF test coverage
- Use torch.allclose instead of torch.equal for chunked prefill logit
  comparison to avoid CUDA FP flakiness
- Fix Usage docblock paths in test_pipeline.py and test_cuda_pipeline.py
- Fix float→uint64 truncation in main.cpp read_token (use llrintf)
- Add assert in RingKVCache.update to catch seq_len > buf_size misuse
- Add RingKVCache unit tests (sequential, wraparound, multi-token, assert)
- Add CanonicalQuantizedWeight __post_init__ validation error path tests
- Add GGUF Q4_K through tinygemm pack pipeline test (asymmetric)
- Add 8-bit asymmetric matmul test
- Add F16 GGUF tensor type test
- Document QuantConfig.bits as storage width and _INT8_PER_AXIS coupling
- serialize.py: add iter_load() generator that streams weights one at a
  time from safetensors, keeping peak memory proportional to the largest
  single weight instead of loading all weights into memory at once.
- pack_cuda.py: rewrite load_and_pack_for_cuda to use iter_load for
  streaming — avoids ~40 GB peak memory when loading the 31B checkpoint.
- __init__.py: remove low-level CUDA packer internals (pack_int4_for_cuda,
  pack_int8_for_cuda, pack_linear_for_cuda, pack_embedding_for_cuda) from
  the public API. Tests import these directly from pack_cuda.py.
Gemma's HuggingFace tokenizer does not auto-prepend BOS. Without it
the model's logits collapse. Add --bos_id (default 2) to prepend and
--eos_id (default 1) as a fallback stop token.
Delete the custom CanonicalQuantizedWeight dataclass and serialize.py
format. Quantized weights are now stored as torchao's native Int4Tensor
(4-bit) and IntxUnpackedToInt8Tensor (8-bit) subclasses, serialized via
torchao's safetensors integration.

Key changes:
- quantize_weight returns Int4Tensor or IntxUnpackedToInt8Tensor
- quantize_model returns a single state_dict (not two dicts)
- 8-bit quantization done in float32 to avoid bf16 precision loss
  (manual quantize + direct IntxUnpackedToInt8Tensor construction)
- Sensitive recipe uses HQQ asymmetric INT4 (scale + zero optimization)
- pack_model takes a single state_dict, dispatches by isinstance
- pack.py uses TorchAOBaseTensor for quantized weight detection
- GGUF unpacker produces Int4Tensor/IntxUnpackedToInt8Tensor directly
- serialize.py dissolved — callers inline torchao safetensors directly

Breaking change: existing prequantized checkpoints (old format) must
be regenerated with quantize_and_save.py.
- Use .detach() instead of .data when moving packed INT4 weight to CPU
  to preserve tensor subclass identity safely
- Remove unused loaded_keys set in load_and_pack_for_cuda
- Handle top-level tensor keys (no dot) in load_and_pack_for_cuda
Extend ReplaceEdgeOpWithTritonOpPass to select triton::sdpa_decode_splitk
for SDPA nodes where L_q=1 (decode) and L_kv exceeds 2048 (large KV
cache). This dramatically improves GPU utilization for full-attention
layers at long context lengths — standard SDPA launches only a handful
of CTAs (proportional to H_kv), while split-K partitions the KV sequence
across up to 128 CTAs.

Benchmarked on A100 with Gemma4 31B shapes at 128K context:
  Full-attention decode (H_kv=4, D=512, L_kv=131072):
    standard SDPA: 15.7ms/layer → split-K: 0.7ms/layer (22x)
  Sliding-attention decode (H_kv=16, D=256, L_kv=2048):
    unchanged (standard SDPA is faster for small L_kv)

The threshold of 2048 is chosen to match the sliding-window ring buffer
size — anything above is a full-attention cache where split-K wins.

No changes to model code — the pass inspects Q/K shapes in the exported
graph and selects the kernel automatically.
Adds executorch_cuda::int4_plain_mm custom op that reads Int4Tensor's plain
[N, K//2] nibble-packed format directly.

C shim (.pte runtime): W4A8 dp4a matvec with dynamic INT8 activation
quantization, 16-byte vectorized loads, warp-cooperative quantization.
No cuBLAS dependency.

Eager dispatch: M<=4 routes through the custom op (dp4a in .pte, dequant +
F.linear in eager). M>4 uses inline dequant + F.linear, which AOTI compiles
into the .so using inductor's own cuBLAS codegen.
Adapt gemma4_31b to upstream gemma4 changes (33419c0) that removed
precompute_freqs_cis in favor of on-the-fly RoPE computation:

- Store inv_freq buffer instead of precomputed [max_seq_len, head_dim]
  cos/sin tables — saves memory, matches qwen3_5_moe and gemma4 E2B
- Compute cos/sin per forward via torch.outer(positions, inv_freq)
- Fix gemma4/text_decoder/__init__.py to remove stale precompute_freqs_cis
  re-export
- Update model.md to reflect current architecture
@mergennachin mergennachin force-pushed the gemma4-31b-quant-framework branch from d519d97 to 6273bb2 Compare May 12, 2026 15:18
Contributor

@digantdesai digantdesai left a comment

Ok skimmed through. We can land this and continue working on it.

python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py examples/models/qwen3_5_moe/test_sampler.py -v -o "addopts="

# Run Gemma 4 31B tests (quant unit tests + pipeline integration tests)
pip install gguf
Contributor

Should this be installed by some requirements.txt?

int32_t K = A.size(1);
int32_t N = qdata.size(0);

ET_CHECK(A.dtype() == c10::ScalarType::BFloat16);
Contributor

Nit: check dtype for scale, zp, and output?

Contributor Author

done

int32_t blocks_per_m = (n_q8_blocks + Q8_WARPS - 1) / Q8_WARPS;
dim3 q8_grid(blocks_per_m, M);
dim3 q8_block(MV_WARP_SIZE, Q8_WARPS);
quantize_activations_q8_kernel<<<q8_grid, q8_block, 0, stream>>>(
Contributor

Why not inline this in the GEMV? For M==1 there isn't a lot of work to justify a separate launch.

Contributor Author

For memory-bound operations like decode, saving a kernel launch probably won't make a difference.


float amax = fabsf(val);
for (int offset = 16; offset > 0; offset >>= 1)
amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, offset));
Contributor

Nice

Comment on lines +159 to +160
float ws = __bfloat162float(__ldg(&scale_base[g * scale_stride]));
float wz = __bfloat162float(__ldg(&zero_base[g * scale_stride]));
Contributor

Nit: hoist these?

Contributor Author

done

gs = weight_tensor.block_size[-1]

M = x_2d.shape[0]
if M <= 4:
Contributor

Curious why not M == 1? It may be important for speculative decoding cases.



# ---------------------------------------------------------------------------
# Int4Tensor F.linear dispatch
Contributor

Can you highlight that this is export-time trace-through dispatch and not runtime dispatch?

Contributor Author

done


w = weights["weight"]
if isinstance(w, (Int4Tensor, IntxUnpackedToInt8Tensor)):
    module.weight = nn.Parameter(w, requires_grad=False)
Contributor

The weights are already packed, so why is it called pack and not load?

Contributor Author

The weights are in the Int4Tensor and IntxUnpackedToInt8Tensor formats, which are serializable, portable formats in torchao that interoperate with safetensors.

pack_ is mainly for packing into a backend-specific format. Initially I had the tiled-4D format for tinygemm, but I stopped using a special format and left the pack name.

For pack_mlx, I can imagine we'll have custom formats.

for i, tok_id in enumerate(input_ids):
    tok = torch.tensor([[tok_id]], dtype=torch.long, device="cuda")
    pos = torch.tensor([i], dtype=torch.long, device="cuda")
    sampled = model(tok, pos, temp_tensor)
Contributor

Don't we need different methods for export-time trace-through and kernel selection?

Contributor Author

This is just a test script to check that model.py is still good in eager mode. It doesn't go through export or ExecuTorch at all.


return sliding_mask, full_mask

def forward(
Contributor

Same here regarding different methods: prefill uses a dynamic shape, decode a static shape.

Contributor Author

Yes, it is exposed in export.py

Contributor Author

@mergennachin mergennachin left a comment

See inline replies above

- Add dtype checks for qdata (uint8/int8), scale (bf16), zero (bf16) in C shim
- Hoist weight scale/zero loads outside inner loop (reload only on group change)
- Clarify int4_dispatch.py docblock: runs at eager/trace time, not .pte runtime
- Clarify test docblock: tests eager dispatch, not C shim runtime
Copilot AI review requested due to automatic review settings May 12, 2026 18:46
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 41 out of 41 changed files in this pull request and generated 4 comments.

Comment on lines +163 to +167
# Prefill (T>=2): shim does dequant+cuBLAS (optimal for large M).
max_prefill = min(config.max_seq_len - 1, config.sliding_window * 2)
seq_dim = Dim("seq_len", min=5, max=max_prefill)
print(f"Exporting prefill (T in [2, {max_prefill}])...")
with torch.no_grad():
Comment on lines +281 to +295
std::string run_method = (chunk_len == 1) ? "decode" : "prefill";

std::vector<int64_t> token_data(
    prompt_tokens.begin() + prefill_pos,
    prompt_tokens.begin() + prefill_pos + chunk_len);
std::vector<int64_t> pos_data(chunk_len);
for (int64_t i = 0; i < chunk_len; i++) {
  pos_data[i] = prefill_pos + i;
}
auto tokens_tensor = from_blob(
    token_data.data(),
    {1, S(chunk_len)},
    executorch::aten::ScalarType::Long);
auto pos_tensor = from_blob(
    pos_data.data(), {S(chunk_len)}, executorch::aten::ScalarType::Long);
Comment on lines +55 to +71
int32_t M = self->size(0);
int32_t N = qdata->size(0);
Tensor* C = nullptr;
std::array<int64_t, 2> c_shape = {M, N};
std::array<int64_t, 2> c_stride = {N, 1};
aoti_torch_empty_strided(
    2,
    c_shape.data(),
    c_stride.data(),
    static_cast<int32_t>(
        executorch::backends::aoti::slim::c10::ScalarType::BFloat16),
    static_cast<int32_t>(
        executorch::backends::aoti::slim::c10::DeviceType::CUDA),
    0,
    &C);

_int4_plain_mm_cuda(*self, *qdata, *scale, *zero, group_size, C);
Comment on lines +185 to +204
// ---------------------------------------------------------------------------
// Persistent Q8 buffer (lazy init, not thread-safe — single-stream only)
// ---------------------------------------------------------------------------

static Q8Block* g_q8_buf = nullptr;
static size_t g_q8_buf_size = 0;

static Q8Block* get_q8_buffer(size_t needed) {
  if (g_q8_buf_size < needed) {
    if (g_q8_buf)
      cudaFree(g_q8_buf);
    cudaError_t err = cudaMalloc(&g_q8_buf, needed);
    ET_CHECK_MSG(
        err == cudaSuccess,
        "cudaMalloc failed for Q8 buffer: %s",
        cudaGetErrorString(err));
    g_q8_buf_size = needed;
  }
  return g_q8_buf;
}
@mergennachin mergennachin merged commit 9e36d62 into main May 12, 2026
184 of 192 checks passed
@mergennachin mergennachin deleted the gemma4-31b-quant-framework branch May 12, 2026 18:58
Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
