-
Notifications
You must be signed in to change notification settings - Fork 1k
Add Gemma 4 31B-IT model, export, and quantization framework for ExecuTorch #19213
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c761d6b
cecb42d
8768019
595ddfd
313799b
89cd615
94c0d4c
ad57fc9
48d83c7
869a2ac
cd50af2
644ec6e
6273bb2
e7375a1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -148,6 +148,10 @@ jobs: | |
| # Run Qwen 3.5 MoE tests (quantize roundtrip + TurboQuant KV cache + sampler) | ||
| python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py examples/models/qwen3_5_moe/test_sampler.py -v -o "addopts=" | ||
|
|
||
| # Run Gemma 4 31B tests (quant unit tests + pipeline integration tests) | ||
| pip install gguf | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should this be installed by some requirements.txt? |
||
| python -m pytest examples/models/gemma4_31b/quant/tests/ examples/models/gemma4_31b/tests/ -v -o "addopts=" | ||
|
|
||
| export-model-cuda-artifact: | ||
| name: export-model-cuda-artifact | ||
| # Skip this job if the pull request is from a fork (HuggingFace secrets are not available) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -226,6 +226,8 @@ def get_supported_fallback_kernels(cls) -> Dict[str, Any]: | |
| "at::_ops::_weight_int4pack_mm::call": None, | ||
| "at::_ops::sort_stable::call": None, | ||
| "aoti_torch_cuda_randint_low_out": None, | ||
| "executorch_cuda::int4_plain_mm": None, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nit: what does 'plain' signify? As opposed to what?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Plain as in no special format like tinygemm. It understands Int4Tensor format natively |
||
| "aoti_torch_cuda_int4_plain_mm": None, | ||
| } | ||
|
|
||
| @classmethod | ||
|
|
@@ -298,6 +300,20 @@ def get_aoti_compile_options( | |
| "aot_inductor.emit_multi_arch_kernel": emit_multi_arch_kernel, | ||
| } | ||
|
|
||
| try: | ||
| import torch | ||
|
|
||
| options["aot_inductor.custom_ops_to_c_shims"] = { | ||
| torch.ops.executorch_cuda.int4_plain_mm.default: [ | ||
| "AOTITorchError aoti_torch_cuda_int4_plain_mm(" | ||
| "AtenTensorHandle, AtenTensorHandle, AtenTensorHandle, " | ||
| "AtenTensorHandle, int64_t, AtenTensorHandle*)" | ||
| ], | ||
| } | ||
| except AttributeError: | ||
| # int4_dispatch.py not imported — op not registered, skip C shim mapping | ||
| pass | ||
|
|
||
|
mergennachin marked this conversation as resolved.
|
||
| # Parse compile_specs to check for platform | ||
|
|
||
| platform = "linux" | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,109 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| """Int4Tensor F.linear dispatch for CUDA — runs at eager / export trace time. | ||
|
|
||
| This module overrides Int4Tensor's F.linear dispatch so that torch.export | ||
| traces through our custom op and dequant logic instead of torchao's default | ||
| (mslk/tinygemm). The code here executes during eager inference and during | ||
| AOTI export tracing — it does NOT run at .pte runtime. | ||
|
|
||
| At .pte runtime, the captured graph is executed by the AOTI-generated .so: | ||
| - The custom op ``executorch_cuda::int4_plain_mm`` maps to a C shim that | ||
| runs the W4A8 dp4a matvec kernel (backends/cuda/runtime/shims/). | ||
| - The inline dequant + F.linear is compiled by inductor into fused Triton | ||
| dequant + cuBLAS matmul kernels. | ||
|
|
||
| Dispatch strategy (determines what gets captured in the export graph): | ||
| Decode (M<=4): Custom op ``executorch_cuda::int4_plain_mm`` | ||
| Prefill (M>4): Inline dequant + F.linear (standard PyTorch ops) | ||
|
|
||
| Import this module before using nn.Linear with Int4Tensor weights:: | ||
|
|
||
| import executorch.backends.cuda.int4_dispatch # noqa: F401 | ||
| """ | ||
|
|
||
| import torch | ||
| import torch.nn.functional as F | ||
| from torch.library import impl, Library | ||
| from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor | ||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Custom op for decode (M=1): dp4a matvec in C shim, dequant+F.linear in eager | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
# Register the custom op schema on the "executorch_cuda" namespace.
# The schema string is part of the op's runtime identity and must match the
# C-shim signature mapped in get_aoti_compile_options
# (aoti_torch_cuda_int4_plain_mm) — do not edit one without the other.
_lib = Library("executorch_cuda", "DEF")
_lib.define(
    "int4_plain_mm(Tensor self, Tensor qdata, Tensor scale, Tensor zero, int group_size) -> Tensor"
)
|
|
||
|
|
||
@impl(_lib, "int4_plain_mm", "Meta")
def _meta(self, qdata, scale, zero, group_size):
    """Shape/dtype propagation for export tracing.

    (M, K) activations x (N, K/2 packed) weights -> (M, N) output in the
    activation's dtype, on the activation's device.
    """
    rows = self.shape[0]
    out_features = qdata.shape[0]
    return torch.empty(rows, out_features, dtype=self.dtype, device=self.device)
|
|
||
|
|
||
@impl(_lib, "int4_plain_mm", "CUDA")
def _cuda(self, qdata, scale, zero, group_size):
    """Eager CUDA implementation: dequantize-then-linear reference path.

    The dp4a C-shim kernel only runs at .pte runtime; in eager mode this op
    falls back to the inline dequant + F.linear helper.
    """
    return _dequant_matmul(self, qdata, scale, zero, group_size)
|
|
||
|
|
||
| def _dequant_matmul(x, qdata, scale, zero, group_size): | ||
| """Dequant INT4 weights to input dtype and call F.linear.""" | ||
| N, K_half = qdata.shape | ||
| K = K_half * 2 | ||
| n_groups = K // group_size | ||
| gs_half = group_size // 2 | ||
| dtype = x.dtype | ||
|
|
||
| p = qdata.to(torch.uint8).reshape(N, n_groups, gs_half) | ||
| low = (p & 0x0F).to(dtype) | ||
| high = ((p >> 4) & 0x0F).to(dtype) | ||
| data = torch.stack([low, high], dim=-1).reshape(N, n_groups, group_size) | ||
|
|
||
| s = scale.to(dtype).t().unsqueeze(-1) | ||
| z = zero.to(dtype).t().unsqueeze(-1) | ||
| w_deq = ((data - z) * s).reshape(N, K) | ||
|
|
||
| return F.linear(x, w_deq) | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Int4Tensor F.linear dispatch | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you highlight that this dispatch is traced through at export time and does not run at runtime?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
| # --------------------------------------------------------------------------- | ||
|
|
||
aten = torch.ops.aten
_implements = Int4Tensor.implements
_implements_torch_function = Int4Tensor.implements_torch_function


@_implements([aten.linear.default])
@_implements_torch_function([F.linear])
def _(func, types, args, kwargs):
    """Export-time F.linear dispatch for Int4Tensor weights.

    Flattens leading batch dims, routes small-M (decode) calls through the
    custom op and larger-M (prefill) calls through inline dequant + linear,
    then restores the original leading shape and applies bias if present.
    """
    x = args[0]
    weight = args[1]
    bias = args[2] if len(args) > 2 else None

    lead_shape = x.shape
    flattened = x.reshape(-1, lead_shape[-1])

    group_size = weight.block_size[-1]
    rows = flattened.shape[0]

    # Decode-sized batches capture the custom dp4a op in the export graph;
    # larger batches capture standard ops that inductor can fuse.
    if rows <= 4:
        result = torch.ops.executorch_cuda.int4_plain_mm(
            flattened, weight.qdata, weight.scale, weight.zero_point, group_size
        )
    else:
        result = _dequant_matmul(
            flattened, weight.qdata, weight.scale, weight.zero_point, group_size
        )

    result = result.reshape(*lead_shape[:-1], -1)
    return result if bias is None else result + bias
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,81 @@ | ||
| /* | ||
| * Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| * All rights reserved. | ||
| * | ||
| * This source code is licensed under the BSD-style license found in the | ||
| * LICENSE file in the root directory of this source tree. | ||
| */ | ||
|
|
||
| #include <cuda.h> | ||
| #include <cuda_runtime.h> | ||
|
|
||
| #include <executorch/backends/aoti/utils.h> | ||
| #include <executorch/backends/cuda/runtime/shims/int4_plain_mm.h> | ||
| #include <executorch/backends/cuda/runtime/shims/int4_plain_mm.cuh> | ||
| #include <executorch/backends/cuda/runtime/shims/memory.h> | ||
| #include <executorch/runtime/platform/log.h> | ||
|
|
||
| namespace executorch::backends::cuda { | ||
| #ifdef __cplusplus | ||
| extern "C" { | ||
| #endif | ||
|
|
||
| AOTITorchError aoti_torch_cuda_int4_plain_mm( | ||
| Tensor* self, | ||
| Tensor* qdata, | ||
| Tensor* scale, | ||
| Tensor* zero, | ||
| int64_t group_size, | ||
| Tensor** ret0) { | ||
| ET_CHECK_OR_RETURN_ERROR( | ||
| self != nullptr, | ||
| InvalidArgument, | ||
| "aoti_torch_cuda_int4_plain_mm: self is null"); | ||
|
|
||
| ET_CHECK_OR_RETURN_ERROR( | ||
| qdata != nullptr, | ||
| InvalidArgument, | ||
| "aoti_torch_cuda_int4_plain_mm: qdata is null"); | ||
|
|
||
| ET_CHECK_OR_RETURN_ERROR( | ||
| scale != nullptr, | ||
| InvalidArgument, | ||
| "aoti_torch_cuda_int4_plain_mm: scale is null"); | ||
|
|
||
| ET_CHECK_OR_RETURN_ERROR( | ||
| zero != nullptr, | ||
| InvalidArgument, | ||
| "aoti_torch_cuda_int4_plain_mm: zero is null"); | ||
|
|
||
| ET_CHECK_OR_RETURN_ERROR( | ||
| ret0 != nullptr, | ||
| InvalidArgument, | ||
| "aoti_torch_cuda_int4_plain_mm: ret0 is null"); | ||
|
|
||
| int32_t M = self->size(0); | ||
| int32_t N = qdata->size(0); | ||
| Tensor* C = nullptr; | ||
| std::array<int64_t, 2> c_shape = {M, N}; | ||
| std::array<int64_t, 2> c_stride = {N, 1}; | ||
| aoti_torch_empty_strided( | ||
| 2, | ||
| c_shape.data(), | ||
| c_stride.data(), | ||
| static_cast<int32_t>( | ||
| executorch::backends::aoti::slim::c10::ScalarType::BFloat16), | ||
| static_cast<int32_t>( | ||
| executorch::backends::aoti::slim::c10::DeviceType::CUDA), | ||
| 0, | ||
| &C); | ||
|
|
||
| _int4_plain_mm_cuda(*self, *qdata, *scale, *zero, group_size, C); | ||
|
mergennachin marked this conversation as resolved.
Comment on lines
+55
to
+71
|
||
| ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR(); | ||
|
|
||
| *ret0 = C; | ||
| return Error::Ok; | ||
| } | ||
|
|
||
| #ifdef __cplusplus | ||
| } | ||
| #endif | ||
| } // namespace executorch::backends::cuda | ||
Uh oh!
There was an error while loading. Please reload this page.