4 changes: 4 additions & 0 deletions .github/workflows/cuda.yml
@@ -148,6 +148,10 @@ jobs:
# Run Qwen 3.5 MoE tests (quantize roundtrip + TurboQuant KV cache + sampler)
python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py examples/models/qwen3_5_moe/test_sampler.py -v -o "addopts="

# Run Gemma 4 31B tests (quant unit tests + pipeline integration tests)
pip install gguf
python -m pytest examples/models/gemma4_31b/quant/tests/ examples/models/gemma4_31b/tests/ -v -o "addopts="

Comment thread on `pip install gguf` (marked as resolved by mergennachin)
Contributor: Should this be installed by some requirements.txt?

export-model-cuda-artifact:
name: export-model-cuda-artifact
# Skip this job if the pull request is from a fork (HuggingFace secrets are not available)
12 changes: 11 additions & 1 deletion Makefile
@@ -91,7 +91,7 @@
#
# ==============================================================================

.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu qwen3_5_moe-cuda qwen3_5_moe-metal clean help
.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu gemma4_31b-cuda qwen3_5_moe-cuda qwen3_5_moe-metal clean help

help:
@echo "This Makefile adds targets to build runners for various models on various backends. Run using \`make <target>\`. Available targets:"
@@ -126,6 +126,7 @@ help:
@echo " llava-cpu - Build Llava runner with CPU backend"
@echo " gemma3-cuda - Build Gemma3 runner with CUDA backend"
@echo " gemma3-cpu - Build Gemma3 runner with CPU backend"
@echo " gemma4_31b-cuda - Build Gemma 4 31B runner with CUDA backend"
@echo " qwen3_5_moe-cuda - Build Qwen3.5 MoE runner with CUDA backend"
@echo " qwen3_5_moe-metal - Build Qwen3.5 MoE runner with Metal backend"
@echo " clean - Clean build artifacts"
@@ -425,6 +426,15 @@ qwen3_5_moe-cuda:
@echo "✓ Build complete!"
@echo " Binary: cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_runner"

gemma4_31b-cuda:
@echo "==> Building and installing ExecuTorch with CUDA..."
cmake --workflow --preset llm-release-cuda
@echo "==> Building Gemma 4 31B runner with CUDA..."
cd examples/models/gemma4_31b && cmake --workflow --preset gemma4-31b-cuda
@echo ""
@echo "✓ Build complete!"
@echo " Binary: cmake-out/examples/models/gemma4_31b/gemma4_31b_runner"

qwen3_5_moe-metal:
@echo "==> Building and installing ExecuTorch with Metal..."
cmake --workflow --preset llm-release-metal
3 changes: 2 additions & 1 deletion backends/cuda/CMakeLists.txt
@@ -110,7 +110,8 @@ set(_aoti_cuda_shim_sources runtime/shims/memory.cpp
# Only build CUDA shims when CUDA language/toolchain is available.
if(CMAKE_CUDA_COMPILER)
list(APPEND _aoti_cuda_shim_sources runtime/shims/int4mm.cu
runtime/shims/sort.cu runtime/shims/rand.cu
runtime/shims/int4_plain_mm.cu runtime/shims/sort.cu
runtime/shims/rand.cu
)
endif()

16 changes: 16 additions & 0 deletions backends/cuda/cuda_backend.py
@@ -226,6 +226,8 @@ def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
"at::_ops::_weight_int4pack_mm::call": None,
"at::_ops::sort_stable::call": None,
"aoti_torch_cuda_randint_low_out": None,
"executorch_cuda::int4_plain_mm": None,
Comment thread
Contributor: Nit: what does 'plain' signify? As opposed to what?
Contributor (Author): Plain as in no special format like tinygemm. It understands the Int4Tensor format natively.
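For illustration only (not part of this PR), a minimal sketch of the nibble packing the plain layout implies, matching the unpack logic in int4_dispatch.py below; shapes are hypothetical:

import torch

# Two int4 values per uint8 byte, low nibble first (the order _dequant_matmul unpacks).
N, K = 8, 64
w_int4 = torch.randint(0, 16, (N, K), dtype=torch.uint8)  # int4 values in [0, 16)
packed = w_int4[:, 0::2] | (w_int4[:, 1::2] << 4)         # plain [N, K // 2] qdata

# Round-trip using the same unpack order as the dispatch code
low = packed & 0x0F
high = (packed >> 4) & 0x0F
assert torch.equal(torch.stack([low, high], dim=-1).reshape(N, K), w_int4)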
"aoti_torch_cuda_int4_plain_mm": None,
}

@classmethod
@@ -298,6 +300,20 @@ def get_aoti_compile_options(
"aot_inductor.emit_multi_arch_kernel": emit_multi_arch_kernel,
}

try:
import torch

options["aot_inductor.custom_ops_to_c_shims"] = {
torch.ops.executorch_cuda.int4_plain_mm.default: [
"AOTITorchError aoti_torch_cuda_int4_plain_mm("
"AtenTensorHandle, AtenTensorHandle, AtenTensorHandle, "
"AtenTensorHandle, int64_t, AtenTensorHandle*)"
],
}
except AttributeError:
# int4_dispatch.py not imported — op not registered, skip C shim mapping
pass

Comment thread (marked as resolved by mergennachin)
# Parse compile_specs to check for platform

platform = "linux"
109 changes: 109 additions & 0 deletions backends/cuda/int4_dispatch.py
@@ -0,0 +1,109 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Int4Tensor F.linear dispatch for CUDA — runs at eager / export trace time.

This module overrides Int4Tensor's F.linear dispatch so that torch.export
traces through our custom op and dequant logic instead of torchao's default
(mslk/tinygemm). The code here executes during eager inference and during
AOTI export tracing — it does NOT run at .pte runtime.

At .pte runtime, the captured graph is executed by the AOTI-generated .so:
- The custom op ``executorch_cuda::int4_plain_mm`` maps to a C shim that
runs the W4A8 dp4a matvec kernel (backends/cuda/runtime/shims/).
- The inline dequant + F.linear is compiled by inductor into fused Triton
dequant + cuBLAS matmul kernels.

Dispatch strategy (determines what gets captured in the export graph):
Decode (M<=4): Custom op ``executorch_cuda::int4_plain_mm``
Prefill (M>4): Inline dequant + F.linear (standard PyTorch ops)

Import this module before using nn.Linear with Int4Tensor weights::

import executorch.backends.cuda.int4_dispatch # noqa: F401
"""

import torch
import torch.nn.functional as F
from torch.library import impl, Library
from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor

# ---------------------------------------------------------------------------
# Custom op for decode (M=1): dp4a matvec in C shim, dequant+F.linear in eager
# ---------------------------------------------------------------------------

_lib = Library("executorch_cuda", "DEF")
_lib.define(
"int4_plain_mm(Tensor self, Tensor qdata, Tensor scale, Tensor zero, int group_size) -> Tensor"
)


@impl(_lib, "int4_plain_mm", "Meta")
def _meta(self, qdata, scale, zero, group_size):
return torch.empty(
self.shape[0], qdata.shape[0], dtype=self.dtype, device=self.device
)


@impl(_lib, "int4_plain_mm", "CUDA")
def _cuda(self, qdata, scale, zero, group_size):
return _dequant_matmul(self, qdata, scale, zero, group_size)


def _dequant_matmul(x, qdata, scale, zero, group_size):
"""Dequant INT4 weights to input dtype and call F.linear."""
N, K_half = qdata.shape
K = K_half * 2
n_groups = K // group_size
gs_half = group_size // 2
dtype = x.dtype

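# Unpack two int4 values from each uint8 byte, low nibble first.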
p = qdata.to(torch.uint8).reshape(N, n_groups, gs_half)
low = (p & 0x0F).to(dtype)
high = ((p >> 4) & 0x0F).to(dtype)
data = torch.stack([low, high], dim=-1).reshape(N, n_groups, group_size)

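# Per-group dequant, w = (q - zero) * scale; scale/zero arrive as [n_groups, N].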
s = scale.to(dtype).t().unsqueeze(-1)
z = zero.to(dtype).t().unsqueeze(-1)
w_deq = ((data - z) * s).reshape(N, K)

return F.linear(x, w_deq)


# ---------------------------------------------------------------------------
# Int4Tensor F.linear dispatch
# ---------------------------------------------------------------------------

Comment thread
Contributor: Can you highlight that this is export-time trace-through dispatch and not at runtime?
Contributor (Author): Done.

aten = torch.ops.aten
_implements = Int4Tensor.implements
_implements_torch_function = Int4Tensor.implements_torch_function


@_implements([aten.linear.default])
@_implements_torch_function([F.linear])
def _(func, types, args, kwargs):
input_tensor = args[0]
weight_tensor = args[1]
bias = args[2] if len(args) > 2 else None

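# Flatten leading dims so M counts the total rows (tokens) hitting the matmul.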
orig_shape = input_tensor.shape
x_2d = input_tensor.reshape(-1, orig_shape[-1])

qdata = weight_tensor.qdata
scale = weight_tensor.scale
zero = weight_tensor.zero_point
gs = weight_tensor.block_size[-1]

M = x_2d.shape[0]
if M <= 4:
Comment thread on `if M <= 4:`
Contributor: Curious, why not M == 1? May be important for spec-dec cases.

out = torch.ops.executorch_cuda.int4_plain_mm(x_2d, qdata, scale, zero, gs)
else:
out = _dequant_matmul(x_2d, qdata, scale, zero, gs)

out = out.reshape(*orig_shape[:-1], -1)
if bias is not None:
out = out + bias
return out
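
For reference, a hedged usage sketch of the decode-path op registered above, with hypothetical shapes; in eager mode this executes the Python _dequant_matmul fallback, since the dp4a C shim is only reached from the AOTI-compiled .so:

import torch
import executorch.backends.cuda.int4_dispatch  # noqa: F401  # registers the op

N, K, group_size = 128, 256, 64
# Plain-layout weight: two int4 values per byte -> [N, K // 2] qdata
qdata = torch.randint(0, 256, (N, K // 2), dtype=torch.uint8, device="cuda")
scale = torch.rand(K // group_size, N, dtype=torch.bfloat16, device="cuda")
zero = torch.full((K // group_size, N), 8.0, dtype=torch.bfloat16, device="cuda")

x = torch.randn(2, K, dtype=torch.bfloat16, device="cuda")  # M <= 4: decode route
out = torch.ops.executorch_cuda.int4_plain_mm(x, qdata, scale, zero, group_size)
assert out.shape == (2, N)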
81 changes: 81 additions & 0 deletions backends/cuda/runtime/shims/int4_plain_mm.cu
@@ -0,0 +1,81 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <cuda.h>
#include <cuda_runtime.h>

#include <executorch/backends/aoti/utils.h>
#include <executorch/backends/cuda/runtime/shims/int4_plain_mm.h>
#include <executorch/backends/cuda/runtime/shims/int4_plain_mm.cuh>
#include <executorch/backends/cuda/runtime/shims/memory.h>
#include <executorch/runtime/platform/log.h>

namespace executorch::backends::cuda {
#ifdef __cplusplus
extern "C" {
#endif

AOTITorchError aoti_torch_cuda_int4_plain_mm(
Tensor* self,
Tensor* qdata,
Tensor* scale,
Tensor* zero,
int64_t group_size,
Tensor** ret0) {
ET_CHECK_OR_RETURN_ERROR(
self != nullptr,
InvalidArgument,
"aoti_torch_cuda_int4_plain_mm: self is null");

ET_CHECK_OR_RETURN_ERROR(
qdata != nullptr,
InvalidArgument,
"aoti_torch_cuda_int4_plain_mm: qdata is null");

ET_CHECK_OR_RETURN_ERROR(
scale != nullptr,
InvalidArgument,
"aoti_torch_cuda_int4_plain_mm: scale is null");

ET_CHECK_OR_RETURN_ERROR(
zero != nullptr,
InvalidArgument,
"aoti_torch_cuda_int4_plain_mm: zero is null");

ET_CHECK_OR_RETURN_ERROR(
ret0 != nullptr,
InvalidArgument,
"aoti_torch_cuda_int4_plain_mm: ret0 is null");

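// Allocate a [M, N] BF16 output via the AOTI shim allocator before launching the kernel.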
int32_t M = self->size(0);
int32_t N = qdata->size(0);
Tensor* C = nullptr;
std::array<int64_t, 2> c_shape = {M, N};
std::array<int64_t, 2> c_stride = {N, 1};
aoti_torch_empty_strided(
2,
c_shape.data(),
c_stride.data(),
static_cast<int32_t>(
executorch::backends::aoti::slim::c10::ScalarType::BFloat16),
static_cast<int32_t>(
executorch::backends::aoti::slim::c10::DeviceType::CUDA),
0,
&C);

_int4_plain_mm_cuda(*self, *qdata, *scale, *zero, group_size, C);
Comment thread on lines +55 to +71 (marked as resolved by mergennachin)
ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR();

*ret0 = C;
return Error::Ok;
}

#ifdef __cplusplus
}
#endif
} // namespace executorch::backends::cuda