Add Gemma 4 31B-IT model, export, and quantization framework for ExecuTorch#19213
Conversation
Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19213. Note: links to docs will display an error until the docs builds have completed.
1 active SEV — if your PR is affected, please view it.
As of commit e7375a1 with merge base f1062a7: 154 pending jobs, 1 unclassified failure. Dr. CI could not classify the failing job because the workflow did not run on the merge base; the failure may be pre-existing on trunk or introduced by this PR.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Pull request overview
Adds a new Gemma 4 31B-IT example pipeline for ExecuTorch (CUDA backend), including a packing-agnostic quantization format + recipes, CUDA packers, export/inference scripts, a C++ runner, and CI coverage.
Changes:
- Introduces examples/models/gemma4_31b/quant/ with a recipe → quantize → serialize → pack flow, plus unit tests.
- Adds the Gemma 4 31B model implementation with hybrid attention and a sliding-window KV cache, plus export and eager inference entrypoints.
- Adds CUDA runner build targets and runs Gemma 4 31B tests in the CUDA GitHub Actions workflow.
Reviewed changes
Copilot reviewed 28 out of 28 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| examples/models/gemma4_31b/test_pipeline.py | CPU-only integration tests for quantize/save/load roundtrip and tiny checkpoint fixtures. |
| examples/models/gemma4_31b/test_cuda_pipeline.py | CUDA integration tests for pack/infer/export on a tiny model. |
| examples/models/gemma4_31b/sampler.py | GPU-side Gumbel-max sampler used by the exported model. |
| examples/models/gemma4_31b/quantize_and_save.py | CLI to quantize HF checkpoints and write packing-agnostic safetensors bundles + production recipes. |
| examples/models/gemma4_31b/quant/test_serialize.py | Unit tests for nibble packing and safetensors serialization format. |
| examples/models/gemma4_31b/quant/test_recipe.py | Unit tests for regex/layer-filter recipe matching + production recipe regression tests. |
| examples/models/gemma4_31b/quant/test_quantize.py | Unit tests for quantize_weight and quantize_model APIs (CPU + CUDA/HQQ paths). |
| examples/models/gemma4_31b/quant/test_pack_cuda.py | CUDA unit tests for int4/int8 packers and load-and-pack dispatcher behavior. |
| examples/models/gemma4_31b/quant/serialize.py | Canonical quantized weight format + safetensors save/load with versioned metadata. |
| examples/models/gemma4_31b/quant/recipe.py | Declarative quantization recipe/rule objects with regex FQN matching and optional layer filters. |
| examples/models/gemma4_31b/quant/quantize.py | Implements min-max and HQQ quantization into canonical (packing-free) representations. |
| examples/models/gemma4_31b/quant/pack_cuda.py | CUDA-specific packers converting canonical weights into torchao runtime tensor subclasses. |
| examples/models/gemma4_31b/quant/pack.py | Backend-agnostic pack dispatcher that assigns weights/buffers and calls module-type packers. |
| examples/models/gemma4_31b/quant/__init__.py | Public API re-exports for the quant/ package. |
| examples/models/gemma4_31b/quant/README.md | Documentation of the quant framework, data flow, and backend extension points. |
| examples/models/gemma4_31b/model.py | Gemma 4 31B model definition, HF checkpoint loader, ring KV cache for sliding layers, runtime buffer materialization. |
| examples/models/gemma4_31b/model.md | Architecture/design notes for model + quant pipeline. |
| examples/models/gemma4_31b/main.cpp | ExecuTorch CUDA runner driving exported prefill/decode and HF tokenizer decoding. |
| examples/models/gemma4_31b/inference.py | Eager CUDA inference script loading prequantized weights, packing, and generating text. |
| examples/models/gemma4_31b/export.py | Export + lowering pipeline (decode + prefill methods) targeting the CUDA backend. |
| examples/models/gemma4_31b/__init__.py | Package marker for the new model example. |
| examples/models/gemma4_31b/README.md | User-facing instructions for quantize/export/inference/build/run workflows. |
| examples/models/gemma4_31b/CMakePresets.json | CMake preset for building the Gemma 4 31B CUDA runner. |
| examples/models/gemma4_31b/CMakeLists.txt | CMake build for the Gemma 4 31B runner, linking ExecuTorch + CUDA backend + tokenizer. |
| examples/models/gemma4/text_decoder/gemma4_norm.py | Replaces transformers RMSNorm dependency with a self-contained implementation. |
| examples/models/gemma4/text_decoder/__init__.py | Exposes attention/norm/MLP primitives used by gemma4_31b for shared numerically-sensitive ops. |
| Makefile | Adds gemma4_31b-cuda build target. |
| .github/workflows/cuda.yml | Adds Gemma 4 31B quant + pipeline tests to the CUDA CI job. |
Pull request overview
Adds a full Gemma 4 31B-IT example to ExecuTorch, including a new packing-agnostic quantization framework, CUDA packing/export/inference tooling, GGUF import support, a C++ CUDA runner, and a comprehensive test suite integrated into CI.
Changes:
- Introduces the examples/models/gemma4_31b/quant/ canonical quantization framework (recipe → quantize → serialize → pack) with CUDA packers and safetensors persistence.
- Adds the Gemma 4 31B-IT model implementation with a ring-buffer KV cache for sliding-window layers, plus export/eager inference/runner scripts.
- Adds unit + integration tests (CPU and CUDA) and runs them in the CUDA CI workflow.
Reviewed changes
Copilot reviewed 31 out of 31 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
| examples/models/gemma4_31b/tests/test_pipeline.py | CPU-only integration tests for quantize→save→load roundtrip and fixtures for CUDA tests. |
| examples/models/gemma4_31b/tests/test_cuda_pipeline.py | CUDA integration tests for packing, generation, chunked prefill, and export. |
| examples/models/gemma4_31b/sampler.py | GPU-side Gumbel-max sampler used by the exported model for on-device sampling. |
| examples/models/gemma4_31b/quantize_and_save.py | CLI to quantize HF checkpoints and persist packing-agnostic safetensors checkpoints. |
| examples/models/gemma4_31b/quant/tests/test_serialize.py | Unit tests for canonical serialize/deserialize and nibble pack/unpack. |
| examples/models/gemma4_31b/quant/tests/test_recipe.py | Unit tests for regex/layer-filter recipe matching + production recipe regression tests. |
| examples/models/gemma4_31b/quant/tests/test_quantize.py | Unit tests for canonical quantize/dequantize APIs and model-walking quantization. |
| examples/models/gemma4_31b/quant/tests/test_pack_cuda.py | CUDA unit tests for packing canonical weights into CUDA runtime formats and dispatch. |
| examples/models/gemma4_31b/quant/tests/test_gguf.py | Unit tests validating GGUF Q4_K/Q6_K unpacking against reference formulas. |
| examples/models/gemma4_31b/quant/serialize.py | Canonical CQW representation + safetensors format, nibble packing, save/load. |
| examples/models/gemma4_31b/quant/recipe.py | Declarative quantization recipe/rule/config structures and matching logic. |
| examples/models/gemma4_31b/quant/quantize.py | Canonical quantization implementations (min_max, HQQ) + per-model quantization walk. |
| examples/models/gemma4_31b/quant/pack_cuda.py | CUDA packers from canonical weights to tinygemm/int8 subclass runtime formats. |
| examples/models/gemma4_31b/quant/pack.py | Backend-agnostic pack dispatcher grouping weights per module and applying packers. |
| examples/models/gemma4_31b/quant/gguf.py | GGUF tensor unpacker/streamer to canonical CQW or dense tensors. |
| examples/models/gemma4_31b/quant/__init__.py | Public API re-exports for the quant/ package. |
| examples/models/gemma4_31b/quant/README.md | Documentation of the quant framework layers, data flow, and on-disk format. |
| examples/models/gemma4_31b/model.py | Gemma 4 31B-IT model definition, ring-buffer KV cache, HF load/remap, runtime buffer materialization. |
| examples/models/gemma4_31b/model.md | Architecture/design notes including attention flavors, caching strategy, and export methods. |
| examples/models/gemma4_31b/main.cpp | CUDA ExecuTorch runner driving exported prefill/decode with tokenizer integration. |
| examples/models/gemma4_31b/inference.py | Eager CUDA inference path (load/pack/materialize/compile + generate loop). |
| examples/models/gemma4_31b/gguf_loader.py | Streams GGUF weights, remaps keys, handles tied weights, and packs into runtime model. |
| examples/models/gemma4_31b/export.py | Export+lower pipeline producing .pte/.ptd with decode/prefill methods and CUDA lowering. |
| examples/models/gemma4_31b/__init__.py | Package marker for the Gemma 4 31B example. |
| examples/models/gemma4_31b/README.md | User-facing docs for quantize/export/inference/build/run flows. |
| examples/models/gemma4_31b/CMakePresets.json | CMake workflow presets to build the CUDA runner. |
| examples/models/gemma4_31b/CMakeLists.txt | CMake build for the Gemma 4 31B CUDA runner binary. |
| examples/models/gemma4/text_decoder/gemma4_norm.py | Replace transformers RMSNorm dependency with a self-contained implementation. |
| examples/models/gemma4/text_decoder/__init__.py | Export shared primitives (attention helpers, KV cache, MLP, norm) for reuse by gemma4_31b. |
| Makefile | Add gemma4_31b-cuda build target and help entry. |
| .github/workflows/cuda.yml | Run Gemma 4 31B quant and pipeline tests as part of CUDA unit test job. |
@claude Review this PR
Claude finished @mergennachin's task in 2m 54s — View job

PR Review: Gemma 4 31B-IT Model, Export, and Quantization Framework

Overall Assessment

This is a well-structured, substantial addition (~5300 lines) that introduces a full Gemma 4 31B-IT pipeline. The architecture is clean — the quantization framework separates concerns nicely (recipe → quantize → serialize → pack), and the code is well-organized for reuse with future models. The test coverage is good (78 tests across multiple files), and the CI integration is appropriate. I've reviewed the full diff and have the following feedback:

Architecture & Design (Positive)

Issues
@claude Review this PR
Claude finished @mergennachin's task in 2m 30s — View job

PR Review: Gemma 4 31B-IT Model, Export, and Quantization Framework

Overall Assessment

This is a high-quality, well-architected PR (~5500 lines) adding a full Gemma 4 31B-IT pipeline. The separation of concerns is clean — recipe declares intent, quantize produces canonical form, serialize persists, pack converts to backend format. Test coverage is comprehensive (78 tests), and the PR addresses the previous review's concerns well.

Architecture & Design (Strengths)

Issues & Suggestions
| File | Scope | Key tests |
|---|---|---|
| test_recipe.py | Regex matching, layer filters, production recipe regression | 14 test cases |
| test_quantize.py | quantize_weight/quantize_model API contract, HQQ paths | 12 test cases |
| test_serialize.py | Nibble packing, save/load, 3D tensors, versioned headers | ~15 test cases |
| test_pack_cuda.py | Module-type dispatch, matmul accuracy, mixed precision, MoE | ~20 test cases |
| test_gguf.py | Q4_K/Q6_K unpacking vs reference formulas, serialize roundtrip | 7 test cases |
| test_pipeline.py | Quantize→save→load integration, RingKVCache isolation | 7 test cases |
| test_cuda_pipeline.py | Pack→generate, chunked prefill correctness, export e2e | 4 test cases |
Good coverage. The RingKVCache unit tests (test_pipeline.py:212-266) are a welcome addition addressing the earlier review's gap.
Summary
This is a well-executed, production-quality addition. The architecture decisions are sound, the code is clean and readable (comments only where non-obvious), and the test coverage is thorough. The main items from previous reviews have been addressed. No blocking issues remain.
Actionable items (non-blocking):
- Comment the group_size=5376 coupling to hidden_size in quantize_and_save.py:51
- Consider adding a @torch.no_grad() decorator to Gemma4_31B.forward since this is inference-only
Pull request overview
Adds a full Gemma 4 31B-IT text-only pipeline to the ExecuTorch examples, including a reusable packing-agnostic quantization framework (recipe/quantize/serialize/pack), GGUF import, CUDA packing/export/inference flows, and a CUDA runner, with CI coverage.
Changes:
- Introduces the examples/models/gemma4_31b/ model implementation (ring-buffer KV cache), export/inference scripts, GGUF loader, and a C++ CUDA runner + build targets.
- Adds a new quant/ framework (recipes, min-max + HQQ quantization, safetensors format, CUDA packing, GGUF Q4_K/Q6_K unpack).
- Adds unit/integration tests and wires them into the CUDA GitHub Actions workflow.
Reviewed changes
Copilot reviewed 31 out of 31 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| examples/models/gemma4_31b/tests/test_pipeline.py | CPU-only pipeline + RingKVCache tests for quant/save/load and cache behavior |
| examples/models/gemma4_31b/tests/test_cuda_pipeline.py | CUDA integration tests for pack/infer/export + chunked prefill equivalence |
| examples/models/gemma4_31b/sampler.py | GPU-side Gumbel-max sampler (mirrors Qwen sampler behavior) |
| examples/models/gemma4_31b/quantize_and_save.py | CLI to quantize HF checkpoint and save canonical safetensors checkpoint |
| examples/models/gemma4_31b/quant/tests/test_serialize.py | Unit tests for canonical format + nibble packing + safetensors I/O |
| examples/models/gemma4_31b/quant/tests/test_recipe.py | Unit tests for regex/layer-filter recipe matching + production recipe regression |
| examples/models/gemma4_31b/quant/tests/test_quantize.py | Unit tests for min-max + HQQ quantize/dequantize and quantize_model behavior |
| examples/models/gemma4_31b/quant/tests/test_pack_cuda.py | CUDA unit tests for packers (int4 tinygemm, int8 intx, dispatch/grouping) |
| examples/models/gemma4_31b/quant/tests/test_gguf.py | Unit tests for GGUF Q4_K/Q6_K unpacking and serialize roundtrip |
| examples/models/gemma4_31b/quant/serialize.py | CanonicalQuantizedWeight + serialize/deserialize + safetensors save/load |
| examples/models/gemma4_31b/quant/recipe.py | QuantConfig/QuantRule/QuantRecipe declarative matching logic |
| examples/models/gemma4_31b/quant/quantize.py | min-max + HQQ quantize_weight/dequantize_weight + per-model quantization |
| examples/models/gemma4_31b/quant/pack_cuda.py | CUDA packers for Linear/Embedding and load+pack convenience wrapper |
| examples/models/gemma4_31b/quant/pack.py | Backend-agnostic pack_model/pack_one dispatch + grouping by parent module |
| examples/models/gemma4_31b/quant/gguf.py | GGUF tensor unpack + streaming iterator to canonical representation |
| examples/models/gemma4_31b/quant/__init__.py | Public API exports for quant framework and CUDA packers |
| examples/models/gemma4_31b/quant/README.md | Framework overview, dataflow, and backend/model extension guidance |
| examples/models/gemma4_31b/model.py | Gemma4 31B model, ring-buffer KV cache, HF loader, runtime buffer materialization |
| examples/models/gemma4_31b/model.md | Architecture/design notes + export/quantization details |
| examples/models/gemma4_31b/main.cpp | CUDA runner driving exported prefill/decode methods with chunked prefill |
| examples/models/gemma4_31b/inference.py | Eager CUDA inference with torch.compile + GGUF/prequantized load paths |
| examples/models/gemma4_31b/gguf_loader.py | Streams GGUF tensors, remaps keys, handles tied embed/lm_head, packs model |
| examples/models/gemma4_31b/export.py | Export+lower (decode+prefill) to ExecuTorch with CUDA backend and metadata |
| examples/models/gemma4_31b/__init__.py | Package marker for gemma4_31b example |
| examples/models/gemma4_31b/README.md | End-to-end usage docs for quantize/export/infer/build/run |
| examples/models/gemma4_31b/CMakePresets.json | CMake workflow presets for building the runner (CUDA preset) |
| examples/models/gemma4_31b/CMakeLists.txt | Runner build configuration + CUDA backend linkage |
| Makefile | Adds gemma4_31b-cuda build target |
| .github/workflows/cuda.yml | Runs gemma4_31b quant + pipeline tests in CUDA CI |
| examples/models/gemma4/text_decoder/gemma4_norm.py | Replaces transformers RMSNorm import with self-contained implementation |
| examples/models/gemma4/text_decoder/__init__.py | Exposes additional shared primitives used by the gemma4_31b model |
@claude Review this PR
Claude finished @mergennachin's task in 2m 30s — View job

PR Review: Gemma 4 31B-IT Model, Export, and Quantization Framework

Overall Assessment

This is a well-structured, high-quality addition (~5600 lines across 31 files) introducing a full Gemma 4 31B-IT text-only pipeline. The architecture is clean — the quantization framework separates concerns into orthogonal layers (recipe → quantize → serialize → pack), the model implementation is faithful to the HF reference, and the streaming GGUF support keeps peak memory tractable. Test coverage is comprehensive (78 tests), and the PR has clearly evolved through multiple review iterations addressing earlier feedback. The codebase is ready for merge with minor non-blocking suggestions below.

Architecture Strengths

Issues & Suggestions
| File | Tests | Key coverage |
|---|---|---|
| quant/tests/test_recipe.py | 14 | Regex matching, layer filters, production recipe regression |
| quant/tests/test_quantize.py | 12 | quantize_weight/dequantize_weight roundtrip, HQQ paths, error cases |
| quant/tests/test_serialize.py | ~15 | Nibble pack/unpack, save/load, 3D tensors, format versioning |
| quant/tests/test_pack_cuda.py | ~20 | Module dispatch, matmul accuracy, mixed precision, MoE grouping |
| quant/tests/test_gguf.py | 7 | Q4_K/Q6_K vs reference formulas, serialize roundtrip, edge cases |
| tests/test_pipeline.py | 7 | Quantize→save→load, RingKVCache isolation, corrupted checkpoint |
| tests/test_cuda_pipeline.py | 4 | Pack→generate, chunked prefill correctness, export e2e |
The GGUF test skip handling (test_gguf.py:21-28) is properly implemented — @unittest.skipUnless(_HAS_GGUF, ...) on each class, conditional import of unpack_gguf_tensor, and CI installs gguf before running (cuda.yml:152).
The RingKVCache unit tests (test_pipeline.py:212-266) cover sequential write, wraparound, multi-token, and oversized-prefill assertion — good targeted coverage for the most subtle piece of logic.
Summary
This is production-quality work with sound architecture decisions for memory efficiency at 31B scale. The main actionable items are:
- Add `<cstring>` and `<cmath>` includes in main.cpp (fragile transitive dependency)
- Update the group_size comment in recipe.py to reflect actual supported values
- Consider reshaping qdata in deserialize() for 8-bit weights (defensive hardening)
None of these are blocking — the code is correct as written and the invariants are upheld by the export/runner logic. Previous review iterations (symmetric INT4 dequant fix, RingKVCache assert, GGUF skip handling, allclose tolerances) have all been addressed properly.
```python
@dataclass
class CanonicalQuantizedWeight:
```
There is subclass support for this "canonical" format that is later converted to specialized formats.
The basic flow is:
- On your device-specific subclass, define a class method constructor "from_{canonical_format}_tensor", e.g., Int4PreshuffledTensor.from_int4_tensor(tensor)
- Load "canonical" tensor subclass checkpoint
- Iterate through canonical tensors, applying the conversion above.
IIRC, this is the path used by torchao in vLLM (cc @jerryzh168 to confirm).
CanonicalQuantizedWeight feels like a new abstraction that we'd maintain to do the same thing?
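A minimal sketch of that flow, assuming a hypothetical device-specific subclass that exposes a `from_int4_tensor` classmethod in the torchao style (names and the type-name check are illustrative, not the PR's code):

```python
def convert_to_device_format(state_dict: dict, device_cls) -> dict:
    """Convert canonical int4 tensor subclasses to a device-specific subclass.

    device_cls is assumed to follow the torchao pattern of a
    from_<canonical>_tensor classmethod, e.g.
    Int4PreshuffledTensor.from_int4_tensor(tensor).
    """
    out = {}
    for name, tensor in state_dict.items():
        # Canonical tensors are converted; everything else passes through.
        # A name check is used here only to keep the sketch import-free.
        if type(tensor).__name__ == "Int4Tensor":
            out[name] = device_cls.from_int4_tensor(tensor)
        else:
            out[name] = tensor
    return out
```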
@metascroy Good point.
I'll use Int4Tensor and IntxUnpackedToInt8Tensor, as they seem like canonical/portable formats that can be serialized to/from safetensors.
Will delete CanonicalQuantizedWeight.
@metascroy Made the change. Thanks for the recommendation.
Force-pushed from 49cfc55 to 69474e3.
Claude finished @mergennachin's task in 3m 31s — View job

PR Review: Gemma 4 31B-IT Model, Export, and Quantization Framework

Overall Assessment

This is a substantial, well-architected PR (~6400 lines, 41 files, 11 commits) that introduces both new CUDA backend infrastructure (dp4a decode kernel, split-K flash-decoding) and a full Gemma 4 31B-IT pipeline. The code is clean, the architecture is sound for reuse, and the test coverage is comprehensive. The quantization framework's layered design (recipe → quantize → serialize → pack) is excellent for extensibility. This review focuses on the newest additions (dp4a kernel, split-K SDPA selection, Int4Tensor dispatch) and fresh observations on the model/quant code that previous reviews haven't fully covered.

New CUDA Backend Infrastructure

1. W4A8 dp4a Kernel
| Area | File | Coverage |
|---|---|---|
| dp4a dispatch | backends/cuda/tests/test_int4_dispatch.py | Decode/prefill/batched/bias/compile/large shapes |
| Split-K SDPA | backends/cuda/tests/test_sdpa_splitk_replacement.py | Large KV→splitk, small KV→standard, non-pow2 head |
| dp4a C shim | backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_int4_plain_mm.cpp | Multi-group, single-group, prefill batch, group_size=128 |
| Recipe | quant/tests/test_recipe.py | Regex, layers, production recipe regression |
| Quantize | quant/tests/test_quantize.py | Types, roundtrip, HQQ, small-weight precision |
| Pack CUDA | quant/tests/test_pack_cuda.py | Dispatch, matmul accuracy, mixed precision |
| GGUF | quant/tests/test_gguf.py | Q4_K/Q6_K unpack, zero-scale edge case |
| Pipeline | tests/test_pipeline.py | E2E roundtrip, RingKVCache, GGUF key mapping |
| CUDA Pipeline | tests/test_cuda_pipeline.py | Generate, chunked prefill, export, Int4 passthrough |
The test_int4_dispatch.py tests are well-designed — they compare against the unquantized bf16 reference weight (not the dequantized weight), which validates the full quantize→dispatch pipeline end-to-end.
Summary
This is a high-quality PR that adds significant CUDA backend infrastructure alongside the Gemma 4 31B model. The dp4a kernel and split-K SDPA selection are well-engineered and properly tested. The quantization framework is ready for reuse.
Actionable items:
| # | Severity | Description |
|---|---|---|
| 2 | Medium | get_q8_buffer doesn't check cudaMalloc return — potential UAF |
| 3 | Low | Missing K % 32 == 0 validation in dp4a kernel |
| 9 | Low | Missing <cstring>/<cmath>/<cstdio> includes in runner |
| 10 | Medium | Log filter level < 'W' may suppress Error/Fatal |
| 7 | Low | quantize_model prints to stdout unconditionally |
| 8 | Low | Stale group_size comment |
Items 2 and 10 are the most important for correctness/debuggability. The rest are hardening/hygiene.
Force-pushed from 206feef to b777a05.
@claude Review this PR now
Claude finished @mergennachin's task in 3m 47s — View job

PR Review: Gemma 4 31B-IT Model, Export, and Quantization Framework

Overall Assessment

This is a well-architected, production-quality PR (~6400 lines, 41 files, 11 commits) that introduces both new CUDA backend infrastructure (W4A8 dp4a decode kernel, split-K flash-decoding SDPA selection) and a complete Gemma 4 31B-IT text-only pipeline. The code is clean, the layered design is excellent for reuse (Qwen 3.5 MoE and future models), and the test coverage is thorough. Previous review iterations have addressed all major correctness issues. This review is fresh against the final state of the branch. The earlier concerns (symmetric INT4 dequant bug, log-level filter suppressing errors, missing includes) have been addressed.

Architecture Strengths

Issues
| Area | File | Highlights |
|---|---|---|
| dp4a dispatch | backends/cuda/tests/test_int4_dispatch.py | Decode M=1, prefill M=64, batched, bias, symmetric, group_size=128, large shapes, compile |
| Split-K SDPA | backends/cuda/tests/test_sdpa_splitk_replacement.py | Large KV→splitk, small KV→standard, non-pow2 head_dim |
| dp4a C shim | backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_int4_plain_mm.cpp | Multi-group accuracy, single-group, prefill batch, group_size=128 |
| Recipe | quant/tests/test_recipe.py | Regex, layer filter, production recipe regression (14 tests) |
| Quantize | quant/tests/test_quantize.py | quantize/dequantize roundtrip, HQQ, error cases (12 tests) |
| Pack CUDA | quant/tests/test_pack_cuda.py | Module dispatch, matmul accuracy, mixed precision (~20 tests) |
| GGUF | quant/tests/test_gguf.py | Q4_K/Q6_K vs reference formulas, zero-scale edge case (7 tests) |
| Safetensors | quant/tests/test_safetensors_roundtrip.py | Int4Tensor + IntxUnpackedToInt8Tensor roundtrip |
| Pipeline | tests/test_pipeline.py | E2E roundtrip, RingKVCache unit tests, GGUF key mapping |
| CUDA Pipeline | tests/test_cuda_pipeline.py | Generate, chunked prefill correctness, export, Int4 passthrough |
The test_chunked_prefill_matches_sequential test is particularly valuable — it verifies the ring-buffer's most subtle correctness property (wraparound produces identical logits to sequential execution).
Summary
This PR is ready for merge. The architecture is sound, the code is clean, all prior review issues have been addressed, and test coverage is comprehensive. The remaining items are minor hardening suggestions, none blocking:
| # | Severity | Item |
|---|---|---|
| 1 | Low | get_q8_buffer global static is not freed / not multi-stream safe (documented) |
| 2 | Low | _pick_sdpa_kernel shape metadata access relies on outer try/except |
| 3 | Low | try/except AttributeError: pass in cuda_backend.py could mask typos |
| 4 | Negligible | \r progress output garbles CI logs |
| 5 | Low | _move_to_cuda mutates during iteration (correct in CPython) |
The codebase is well-positioned for reuse — adding Qwen 3.5 MoE will require only a recipe and optionally a custom packer, as designed.
```python
def _read_f16(raw: torch.Tensor, col_start: int, col_end: int) -> torch.Tensor:
    """Read fp16 field from block bytes, return float32."""
    return raw[:, col_start:col_end].contiguous().view(torch.float16).float()
```
```python
        return _raw_tensor(tensor_data).view(torch.float32).reshape(shape).clone()
    elif tensor_type == GGMLQuantizationType.F16:
        return (
            _raw_tensor(tensor_data)
            .view(torch.float16)
```
```python
    # Int4Tensor stores qdata as nibble-packed uint8 (N, K//2)
    q = int_data.to(torch.uint8)
    packed = q[..., ::2] | (q[..., 1::2] << 4)

    # Int4Tensor stores scale/zero as (K//gs, N) — transposed from our (N, K//gs)
    return Int4Tensor(
        qdata=packed,
        scale=scale.t().contiguous(),
        zero_point=zero_point.t().contiguous(),
        block_size=[1, config.group_size],
        shape=torch.Size(int_data.shape),
    )
```
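For reference, a minimal sketch of the inverse of the packing above (helper name hypothetical): even columns live in the low nibble, odd columns in the high nibble, and symmetric INT4 dequant shifts the unsigned [0, 15] codes by 8 before scaling, matching the fix noted in the commit messages below.

```python
import torch

def unpack_nibbles(packed: torch.Tensor) -> torch.Tensor:
    """Inverse of `q[..., ::2] | (q[..., 1::2] << 4)` above."""
    lo = packed & 0x0F          # even columns
    hi = (packed >> 4) & 0x0F   # odd columns
    out = torch.empty(
        *packed.shape[:-1], packed.shape[-1] * 2,
        dtype=torch.uint8, device=packed.device,
    )
    out[..., ::2] = lo
    out[..., 1::2] = hi
    return out

# Symmetric INT4 dequant: codes are stored unsigned [0, 15], so subtract 8:
# dequant = (unpack_nibbles(packed).float() - 8.0) * scale_per_group
```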
…uTorch

Text-only export of Gemma 4 31B-IT to ExecuTorch with the CUDA backend and INT4/INT8 weight quantization via a new packing-agnostic quant/ framework.

The quant/ package separates quantization into four concerns:
- recipe.py: declarative QuantRecipe with regex FQN matching (sketched below)
- quantize.py: produces CanonicalQuantizedWeight (min_max, HQQ)
- serialize.py: save/load to safetensors with versioned headers
- pack.py + pack_cuda.py: per-module packer dispatch for CUDA

Two production recipes: "default" (INT4 min_max + INT8 embedding) and "sensitive" (INT8 for edge-layer v_proj/down_proj, INT4 HQQ elsewhere).

Sliding window attention uses a ring-buffer KV cache (2x window size) for the 50 sliding layers, saving memory for long sequences. The 10 full-attention layers use a standard flat KV cache.

Includes C++ runner (main.cpp), eager inference script, and 60+ unit and integration tests across quant/ and pipeline test files.
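To make the recipe layer concrete, here is a minimal sketch of declarative rules with regex FQN matching. QuantConfig/QuantRule/QuantRecipe are the PR's names, but the fields and method shown are illustrative assumptions, not the PR's code:

```python
import re
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class QuantConfig:
    bits: int                  # storage width: 4 or 8
    group_size: int = 128
    method: str = "min_max"    # or "hqq"

@dataclass
class QuantRule:
    pattern: str               # regex over the weight's fully-qualified name
    config: QuantConfig
    layers: Optional[List[int]] = None  # optional layer-index filter

@dataclass
class QuantRecipe:
    rules: List[QuantRule] = field(default_factory=list)

    def match(self, fqn: str, layer: Optional[int] = None) -> Optional[QuantConfig]:
        for rule in self.rules:
            if re.search(rule.pattern, fqn) and (
                rule.layers is None or layer in rule.layers
            ):
                return rule.config  # first matching rule wins
        return None  # no rule: leave the weight unquantized
```

For example, a rule like `QuantRule(r"\.v_proj\.weight$", QuantConfig(bits=8), layers=[0])` mirrors the "sensitive" recipe's idea of keeping edge-layer projections at INT8.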
- Sliding window layers use RingKVCache (2×window) instead of a flat max_seq_len buffer, reducing KV cache memory for long sequences (see the sketch after this list).
- Prefill is capped to the ring buffer size; the C++ runner chunks longer prompts automatically via get_max_prefill_chunk metadata.
- Both recipes now quantize embed_tokens to INT8 per-axis (~1.4 GB savings vs bf16). The embedding packer uses IntxUnpackedToInt8Tensor, which supports gather.
- pack_model handles top-level FQNs (no parent module).
- C++ runner aligned with Qwen patterns: #ifdef guards for non-CUDA builds, better weight_sharing error handling, cudaDeviceSynchronize between prefill and decode.
- Test suite split into test_pipeline.py (CPU) and test_cuda_pipeline.py (CUDA) with shared fixtures. New chunked prefill correctness test.
- Prequantized checkpoint available at huggingface.co/SocialLocalMobile/gemma-4-31B-it-HQQ-INT4.
- Added Gemma 4 31B tests to the cuda.yml CI workflow.
- Cleaned up stale terminology, docstrings, and comments throughout.
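A minimal sketch of the ring-buffer idea, assuming buf_size = 2 × sliding window (hypothetical shapes and method signature; the PR's real RingKVCache lives in model.py):

```python
import torch

class RingKVCache(torch.nn.Module):
    """Fixed-size KV cache where positions wrap modulo the buffer size."""

    def __init__(self, n_heads: int, buf_size: int, head_dim: int):
        super().__init__()
        self.buf_size = buf_size
        shape = (1, n_heads, buf_size, head_dim)
        self.register_buffer("k", torch.zeros(shape, dtype=torch.bfloat16))
        self.register_buffer("v", torch.zeros(shape, dtype=torch.bfloat16))

    def update(self, pos: int, k_new: torch.Tensor, v_new: torch.Tensor):
        seq_len = k_new.shape[2]
        # Oversized prefills must be chunked by the caller (the runner uses
        # get_max_prefill_chunk metadata for this).
        assert seq_len <= self.buf_size
        slots = (pos + torch.arange(seq_len, device=k_new.device)) % self.buf_size
        self.k[:, :, slots] = k_new  # wraparound write
        self.v[:, :, slots] = v_new
        return self.k, self.v
```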
- quant/gguf.py: unpack Q4_K/Q6_K GGUF blocks to CanonicalQuantizedWeight, with iter_gguf_tensors for streaming (low peak memory; see the sketch after this list). Validated against original bf16 weights (Q4_K: 7.9%, Q6_K: 1.9% error).
- gguf_loader.py: Gemma 4 31B GGUF key mapping + load_gguf_model. Handles tied embed/lm_head: embedding dequantized to bf16 (gather), lm_head keeps Q4_K (tinygemm matmul).
- export.py and inference.py: --gguf flag for direct GGUF file loading.
- quant/quantize.py: dequantize_weight (inverse of quantize_weight).
- quant/pack.py: pack_one for single-weight streaming; pack_model delegates to pack_one for unquantized weights and groups quantized weights by parent for multi-weight modules (MoE-compatible).
- quant/serialize.py: CanonicalQuantizedWeight.__post_init__ validation (dtype, shape, symmetric/zero consistency).
- Tests moved to tests/ folders (quant/tests/ and tests/).
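A sketch of the streaming shape, using the gguf package's GGUFReader (the iter_gguf_tensors name comes from the commit; the body here is an assumption about how such an iterator could look, not the PR's code):

```python
from gguf import GGUFReader  # pip install gguf

def iter_gguf_tensors(path: str):
    """Yield (name, tensor_type, raw_data) one tensor at a time."""
    reader = GGUFReader(path)
    for t in reader.tensors:
        # t.data is a numpy view into the memory-mapped file, so peak memory
        # stays proportional to a single tensor, not the whole checkpoint.
        yield t.name, t.tensor_type, t.data
```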
- dequantize_weight now subtracts 8 from symmetric 4-bit qdata (stored as unsigned [0,15]) before scaling, matching the quantize_weight shift
- Guard test_gguf.py with skipUnless so CI doesn't break without gguf
- Install gguf in cuda.yml for GGUF test coverage
- Use torch.allclose instead of torch.equal for the chunked prefill logit comparison to avoid CUDA FP flakiness
- Fix Usage docblock paths in test_pipeline.py and test_cuda_pipeline.py
- Fix float→uint64 truncation in main.cpp read_token (use llrintf)
- Add assert in RingKVCache.update to catch seq_len > buf_size misuse
- Add RingKVCache unit tests (sequential, wraparound, multi-token, assert)
- Add CanonicalQuantizedWeight __post_init__ validation error path tests
- Add GGUF Q4_K through tinygemm pack pipeline test (asymmetric)
- Add 8-bit asymmetric matmul test
- Add F16 GGUF tensor type test
- Document QuantConfig.bits as storage width and the _INT8_PER_AXIS coupling
- serialize.py: add an iter_load() generator that streams weights one at a time from safetensors, keeping peak memory proportional to the largest single weight instead of loading all weights into memory at once (see the sketch after this list).
- pack_cuda.py: rewrite load_and_pack_for_cuda to use iter_load for streaming — avoids ~40 GB peak memory when loading the 31B checkpoint.
- __init__.py: remove low-level CUDA packer internals (pack_int4_for_cuda, pack_int8_for_cuda, pack_linear_for_cuda, pack_embedding_for_cuda) from the public API. Tests import these directly from pack_cuda.py.
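A sketch of the streaming idea using safetensors' safe_open API (the iter_load signature here is assumed, not copied from the PR):

```python
from safetensors import safe_open

def iter_load(path: str):
    """Yield (name, tensor) pairs one at a time from a safetensors file."""
    with safe_open(path, framework="pt", device="cpu") as f:
        for key in f.keys():
            # Only one tensor is materialized at a time, so peak memory is
            # bounded by the largest single weight.
            yield key, f.get_tensor(key)
```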
Gemma's HuggingFace tokenizer does not auto-prepend BOS. Without it the model's logits collapse. Add --bos_id (default 2) to prepend and --eos_id (default 1) as a fallback stop token.
Delete the custom CanonicalQuantizedWeight dataclass and serialize.py format. Quantized weights are now stored as torchao's native Int4Tensor (4-bit) and IntxUnpackedToInt8Tensor (8-bit) subclasses, serialized via torchao's safetensors integration.

Key changes:
- quantize_weight returns Int4Tensor or IntxUnpackedToInt8Tensor
- quantize_model returns a single state_dict (not two dicts)
- 8-bit quantization done in float32 to avoid bf16 precision loss (manual quantize + direct IntxUnpackedToInt8Tensor construction)
- Sensitive recipe uses HQQ asymmetric INT4 (scale + zero optimization)
- pack_model takes a single state_dict, dispatches by isinstance
- pack.py uses TorchAOBaseTensor for quantized weight detection
- GGUF unpacker produces Int4Tensor/IntxUnpackedToInt8Tensor directly
- serialize.py dissolved — callers inline torchao safetensors directly

Breaking change: existing prequantized checkpoints (old format) must be regenerated with quantize_and_save.py.
- Use .detach() instead of .data when moving the packed INT4 weight to CPU, to preserve tensor subclass identity safely
- Remove the unused loaded_keys set in load_and_pack_for_cuda
- Handle top-level tensor keys (no dot) in load_and_pack_for_cuda
Extend ReplaceEdgeOpWithTritonOpPass to select triton::sdpa_decode_splitk
for SDPA nodes where L_q=1 (decode) and L_kv exceeds 2048 (large KV
cache). This dramatically improves GPU utilization for full-attention
layers at long context lengths — standard SDPA launches only a handful
of CTAs (proportional to H_kv), while split-K partitions the KV sequence
across up to 128 CTAs.
Benchmarked on A100 with Gemma4 31B shapes at 128K context:
Full-attention decode (H_kv=4, D=512, L_kv=131072):
standard SDPA: 15.7ms/layer → split-K: 0.7ms/layer (22x)
Sliding-attention decode (H_kv=16, D=256, L_kv=2048):
unchanged (standard SDPA is faster for small L_kv)
The threshold of 2048 is chosen to match the sliding-window ring buffer
size — anything above is a full-attention cache where split-K wins.
No changes to model code — the pass inspects Q/K shapes in the exported
graph and selects the kernel automatically.
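The selection reduces to a simple shape predicate; a hedged sketch (function and constant names hypothetical — the real logic lives inside ReplaceEdgeOpWithTritonOpPass and inspects fake-tensor shapes on the exported graph):

```python
SPLITK_KV_THRESHOLD = 2048  # matches the sliding-window ring buffer size

def pick_sdpa_kernel(q_shape, kv_shape) -> str:
    l_q = q_shape[-2]    # query length: 1 during decode
    l_kv = kv_shape[-2]  # KV cache length
    if l_q == 1 and l_kv > SPLITK_KV_THRESHOLD:
        return "triton::sdpa_decode_splitk"  # decode into a large full-attention cache
    return "triton::sdpa"                    # prefill, or small ring-buffer cache
```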
Adds an executorch_cuda::int4_plain_mm custom op that reads Int4Tensor's plain [N, K//2] nibble-packed format directly.

C shim (.pte runtime): W4A8 dp4a matvec with dynamic INT8 activation quantization, 16-byte vectorized loads, warp-cooperative quantization. No cuBLAS dependency.

Eager dispatch: M<=4 routes through the custom op (dp4a in .pte, dequant + F.linear in eager). M>4 uses inline dequant + F.linear, which AOTI compiles into the .so using inductor's own cuBLAS codegen.
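A Python reference for the dynamic per-block INT8 activation quantization the kernel performs on-GPU (function name and block size are assumptions for illustration; the CUDA kernel does this warp-cooperatively):

```python
import torch

def quantize_activations_q8(x: torch.Tensor, block: int = 32):
    """Symmetric per-block INT8: q = round(x / scale), scale = absmax / 127.

    Assumes x.numel() is divisible by `block`.
    """
    xb = x.float().reshape(-1, block)
    scale = xb.abs().amax(dim=1, keepdim=True) / 127.0
    scale = scale.clamp(min=1e-8)  # avoid divide-by-zero on all-zero blocks
    q = torch.round(xb / scale).clamp(-127, 127).to(torch.int8)
    return q.reshape(x.shape), scale.squeeze(1)
```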
Adapt gemma4_31b to upstream gemma4 changes (33419c0) that removed precompute_freqs_cis in favor of on-the-fly RoPE computation:
- Store an inv_freq buffer instead of precomputed [max_seq_len, head_dim] cos/sin tables — saves memory, matches qwen3_5_moe and gemma4 E2B
- Compute cos/sin per forward via torch.outer(positions, inv_freq)
- Fix gemma4/text_decoder/__init__.py to remove the stale precompute_freqs_cis re-export
- Update model.md to reflect the current architecture
Force-pushed from d519d97 to 6273bb2.
digantdesai left a comment:
Ok, skimmed through. We can land this and continue working on it.
```sh
python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py examples/models/qwen3_5_moe/test_sampler.py -v -o "addopts="

# Run Gemma 4 31B tests (quant unit tests + pipeline integration tests)
pip install gguf
```
Should this be installed by some requirements.txt?
```cpp
  int32_t K = A.size(1);
  int32_t N = qdata.size(0);

  ET_CHECK(A.dtype() == c10::ScalarType::BFloat16);
```
Nit: check dtype for scale, zp, and output?
```cpp
  int32_t blocks_per_m = (n_q8_blocks + Q8_WARPS - 1) / Q8_WARPS;
  dim3 q8_grid(blocks_per_m, M);
  dim3 q8_block(MV_WARP_SIZE, Q8_WARPS);
  quantize_activations_q8_kernel<<<q8_grid, q8_block, 0, stream>>>(
```
Why not inline this in the gemv? For M==1 there isn't a lot of work for this launch.
For memory-bound operations like decode, saving a kernel launch probably won't make a difference.
```cpp
  float amax = fabsf(val);
  for (int offset = 16; offset > 0; offset >>= 1)
    amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, offset));
```
```cpp
  float ws = __bfloat162float(__ldg(&scale_base[g * scale_stride]));
  float wz = __bfloat162float(__ldg(&zero_base[g * scale_stride]));
```
```python
    gs = weight_tensor.block_size[-1]

    M = x_2d.shape[0]
    if M <= 4:
```
Curious — why not M == 1? May be important for speculative decoding cases.
```python
# ---------------------------------------------------------------------------
# Int4Tensor F.linear dispatch
```
Can you highlight that this is export-time trace-through dispatch and not at runtime?
```python
    w = weights["weight"]
    if isinstance(w, (Int4Tensor, IntxUnpackedToInt8Tensor)):
        module.weight = nn.Parameter(w, requires_grad=False)
```
The weights are already packed, so why is it called pack and not load?
The weights are in Int4Tensor and IntxUnpackedToInt8Tensor formats, which are serializable and portable formats in torchao that can interoperate with safetensors.
pack_ is mainly for packing into a backend-specific format. Initially I had the tiled 4d format for tinygemm, but I stopped using a special format and left the pack name.
For pack_mlx, I can imagine we'll have custom formats.
```python
    for i, tok_id in enumerate(input_ids):
        tok = torch.tensor([[tok_id]], dtype=torch.long, device="cuda")
        pos = torch.tensor([i], dtype=torch.long, device="cuda")
        sampled = model(tok, pos, temp_tensor)
```
Don't we need different methods for export-time trace-through and kernel selection?
This is just a test script to check if the model.py is still good in eager. It doesn't go through export or executorch at all.
```python
        return sliding_mask, full_mask

    def forward(
```
Same here for different methods: prefill with dynamic shapes, decode with static shapes.
Yes, it is exposed in export.py
- Add dtype checks for qdata (uint8/int8), scale (bf16), zero (bf16) in the C shim
- Hoist weight scale/zero loads outside the inner loop (reload only on group change)
- Clarify int4_dispatch.py docblock: runs at eager/trace time, not .pte runtime
- Clarify test docblock: tests eager dispatch, not the C shim runtime
```python
    # Prefill (T>=2): shim does dequant+cuBLAS (optimal for large M).
    max_prefill = min(config.max_seq_len - 1, config.sliding_window * 2)
    seq_dim = Dim("seq_len", min=5, max=max_prefill)
    print(f"Exporting prefill (T in [2, {max_prefill}])...")
    with torch.no_grad():
```
```cpp
    std::string run_method = (chunk_len == 1) ? "decode" : "prefill";

    std::vector<int64_t> token_data(
        prompt_tokens.begin() + prefill_pos,
        prompt_tokens.begin() + prefill_pos + chunk_len);
    std::vector<int64_t> pos_data(chunk_len);
    for (int64_t i = 0; i < chunk_len; i++) {
      pos_data[i] = prefill_pos + i;
    }
    auto tokens_tensor = from_blob(
        token_data.data(),
        {1, S(chunk_len)},
        executorch::aten::ScalarType::Long);
    auto pos_tensor = from_blob(
        pos_data.data(), {S(chunk_len)}, executorch::aten::ScalarType::Long);
```
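A condensed view of the chunking policy the loop above implements, as Python pseudocode (max_prefill_chunk comes from the exported model's metadata; the names here are illustrative):

```python
def prefill_chunks(prompt_len: int, max_prefill_chunk: int):
    """Yield (start, length, method) for each runner call over the prompt."""
    pos = 0
    while pos < prompt_len:
        chunk_len = min(prompt_len - pos, max_prefill_chunk)
        # A single trailing token goes through "decode" (static shape);
        # everything else through "prefill" (dynamic shape).
        method = "decode" if chunk_len == 1 else "prefill"
        yield pos, chunk_len, method
        pos += chunk_len
```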
```cpp
  int32_t M = self->size(0);
  int32_t N = qdata->size(0);
  Tensor* C = nullptr;
  std::array<int64_t, 2> c_shape = {M, N};
  std::array<int64_t, 2> c_stride = {N, 1};
  aoti_torch_empty_strided(
      2,
      c_shape.data(),
      c_stride.data(),
      static_cast<int32_t>(
          executorch::backends::aoti::slim::c10::ScalarType::BFloat16),
      static_cast<int32_t>(
          executorch::backends::aoti::slim::c10::DeviceType::CUDA),
      0,
      &C);

  _int4_plain_mm_cuda(*self, *qdata, *scale, *zero, group_size, C);
```
```cpp
// ---------------------------------------------------------------------------
// Persistent Q8 buffer (lazy init, not thread-safe — single-stream only)
// ---------------------------------------------------------------------------

static Q8Block* g_q8_buf = nullptr;
static size_t g_q8_buf_size = 0;

static Q8Block* get_q8_buffer(size_t needed) {
  if (g_q8_buf_size < needed) {
    if (g_q8_buf)
      cudaFree(g_q8_buf);
    cudaError_t err = cudaMalloc(&g_q8_buf, needed);
    ET_CHECK_MSG(
        err == cudaSuccess,
        "cudaMalloc failed for Q8 buffer: %s",
        cudaGetErrorString(err));
    g_q8_buf_size = needed;
  }
  return g_q8_buf;
}
```
Text-only export of Gemma 4 31B-IT to ExecuTorch with INT4/INT8 weight quantization. Quantized weights use torchao's native tensor subclasses (Int4Tensor, IntxUnpackedToInt8Tensor) for serialization, aligning with the torchao ecosystem.
The quant/ package separates quantization into independent modules: recipe.py (declarative recipes with regex FQN matching), quantize.py (min-max and HQQ), pack.py/pack_cuda.py (backend-specific packing), and gguf.py (GGUF import).
Serialization uses torchao's safetensors integration (torchao.prototype.safetensors) — no custom format. Checkpoints are compatible with torchao's save_pretrained/load_pretrained and can be loaded by vLLM.
This framework is designed to be promoted and reused for Qwen 3.5 MoE and other models — adding a new model requires only a QuantRecipe and optionally a custom packer.
Quantization recipes: "default" (INT4 min_max linears + INT8 per-axis embedding) and "sensitive" (INT8 for edge-layer v_proj/down_proj, INT4 HQQ asymmetric elsewhere).
Dual-path INT4 linear dispatch: IntxUnpackedToInt8Tensor's F.linear dispatch dequantizes to bf16 and calls cuBLAS, optimal for prefill (12x faster than tinygemm at T=2048). For decode, a model-agnostic source transform (backends/cuda/transforms/int4_linear_dispatch.py) converts to Int4TilePackedTo4dTensor (tinygemm), optimal for M=1. Export flow: prefill first (dequant+cuBLAS), then tinygemm transform, then decode export. inference.py applies the tinygemm transform for fast eager decode.
Split-K flash-decoding: ReplaceEdgeOpWithTritonOpPass in the CUDA backend selects triton::sdpa_decode_splitk for SDPA nodes where L_q=1 and L_kv exceeds 2048. At 128K context, full-attention decode SDPA improves from 15.7ms/layer to 0.7ms/layer (22x). Sliding-window layers (ring buffer <= 2048) use standard triton::sdpa. No model code changes — the pass inspects Q/K shapes in the exported graph automatically.
GGUF support: inference.py --gguf and export.py --gguf load community-quantized GGUF files directly. Tied embed/lm_head is untied — embedding dequantized to bf16 for gather, lm_head keeps INT4 for matmul.
Ring-buffer KV cache: Sliding window layers use RingKVCache (2x window) instead of flat max_seq_len buffers. The C++ runner chunks long prompts automatically via get_max_prefill_chunk metadata. Chunked prefill produces identical logits to sequential (verified by test).
Includes: C++ runner with BOS/EOS handling, chunked prefill, and #ifdef guards for non-CUDA builds; eager inference with torch.compile; unit and integration tests across quant/tests/, tests/, and backends/cuda/tests/.