Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
ffb3d29
support tvm ffi interfaces
Fridge003 Apr 9, 2026
c82e589
rebase on 0426 upstream
rainj-me May 2, 2026
3d6ab9e
Relax timeout to 180s
Fridge003 Apr 28, 2026
f958b89
pin tvm to 0.1.9
Fridge003 Apr 28, 2026
5af4329
Support wheel compilation (#26)
Fridge003 May 4, 2026
66ea23e
Merge branch 'deepseek-ai:main' into release-0426
Fridge003 May 12, 2026
b97540e
[Misc] Remove verbose import message of legacy
Fridge003 May 12, 2026
160e75b
Add FP4 acts + MXF4 kind support and fused mega_moe_pre_dispatch kern…
pranjalssh May 6, 2026
d80617d
Add DG_USE_FP8_COMBINE: FP8 + per-row UE8M0 SF on the second a2a (com…
pranjalssh May 6, 2026
8ab0096
Add tvm-ffi wrapper for w4a4 megamoe
Fridge003 May 12, 2026
89f2f00
Bump to v0.1.0
Fridge003 May 12, 2026
23105b2
update readme
Fridge003 May 14, 2026
46a2294
Export the PDL utils of DeepGEMM (#34)
b8zhong May 20, 2026
48eb6a6
Expose BF16 grouped GEMM wrappers
popsiclexu May 28, 2026
6c9eaca
Fix version of dev branch to 0.0.0
Fridge003 May 29, 2026
86d705d
Fix IMA guard in paged MQA logits scheduler (#38)
nvjullin Jun 1, 2026
09fc810
Fix various issues in DeepGEMM tests (#39)
b8zhong Jun 9, 2026
7020fd8
Change license in pyproject.toml to avoid build and publish failures …
b8zhong Jun 10, 2026
c872baf
Move run tests script for DeepGEMM (#42)
b8zhong Jun 11, 2026
2176dff
Support num_heads == 16 in MQA logits (#43)
zRzRzRzRzRzRzR Jun 12, 2026
35d4d8c
Sm90 mega moe on sgl dev (#36)
qiushixiaoyu Jun 15, 2026
a36f7fd
Move MegaMoE Hopper test into sgl_deep_gemm tests (#45)
Fridge003 Jun 15, 2026
b5238ad
Update test modification instruction
Fridge003 Jun 15, 2026
bdabf6c
Add Hopper mega moe test to runner (#46)
Fridge003 Jun 16, 2026
77c9522
chore: bump apache-tvm-ffi 0.1.9 -> 0.1.11 (#47)
MartinHua Jun 18, 2026
774d081
feat: add signal for SBO in SM90 masked gemm
Jun 22, 2026
f4945a5
feat: add test for signal GEMM
Jun 22, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@ deep_gemm/include/cutlass
stubs/

# Symlinks to compiled extensions
deep_gemm/*.so
deep_gemm/*.so
deep_gemm/_C_build
6 changes: 4 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ set(CMAKE_VERBOSE_MAKEFILE ON)

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -Wno-psabi -Wno-deprecated-declarations")
set(CUDA_SEPARABLE_COMPILATION ON)

list(APPEND CUDA_NVCC_FLAGS "-DENABLE_FAST_DEBUG")
list(APPEND CUDA_NVCC_FLAGS "-O3")
list(APPEND CUDA_NVCC_FLAGS "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage")
Expand All @@ -16,13 +17,14 @@ set(TORCH_CUDA_ARCH_LIST "${CUDA_ARCH_LIST}")
find_package(CUDAToolkit REQUIRED)
find_package(pybind11 REQUIRED)
find_package(Torch REQUIRED)
find_package(tvm_ffi REQUIRED)

set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CUDA_STANDARD 20)

include_directories(deep_gemm/include third-party/cutlass/include third-party/cutlass/tools/util/include third-party/fmt/include)
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include/cccl ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS})
link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs)
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS} ${tvm_ffi_INCLUDE_DIR} ${tvm_ffi_DLPACK_INCLUDE_DIR})
link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs ${tvm_ffi_ROOT_DIR}/lib)

# The main Python API entrance
pybind11_add_module(_C csrc/python_api.cpp)
Expand Down
4 changes: 4 additions & 0 deletions build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ original_dir=$(pwd)
script_dir=$(realpath "$(dirname "$0")")
cd "$script_dir"

# Link CUTLASS includes
ln -sf $script_dir/third-party/cutlass/include/cutlass deep_gemm/include
ln -sf $script_dir/third-party/cutlass/include/cute deep_gemm/include

# Remove old dist file, build files, and install
rm -rf build dist
rm -rf *.egg-info
Expand Down
151 changes: 151 additions & 0 deletions build_sgl_deep_gemm.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
#!/usr/bin/env bash
#
# Build a wheel for the `sgl-deep-gemm` distribution.
#
# Distribution name: sgl-deep-gemm. Top-level import name: `deep_gemm`
# (so existing call sites like `import deep_gemm` in sglang keep working).
#
# Build flow:
# 1. Initialises submodules (cutlass, fmt) — same prerequisite as `bash build.sh`.
# 2. Stages the package layout under build/deep_gemm/ with the Python
# sub-modules pulled from the source deep_gemm/ tree (utils, testing,
# legacy, mega).
# 3. Reads the version string from sgl_deep_gemm/VERSION.
# 4. Pre-compiles the tvm-ffi `_C.so` extension and bundles it into the wheel.
# 5. Invokes `python -m build` to produce dist/*.whl.

set -euo pipefail

PYTHON_EXE=$(which python3 || which python)
ROOT_DIR=$(realpath "$(dirname "$0")")
BUILD_DIR="${ROOT_DIR}/build"
PKG_DIR="${BUILD_DIR}/deep_gemm"
DIST_DIR="${ROOT_DIR}/dist"

cd "$ROOT_DIR"

if [[ ! -f "setup.py" || ! -d "sgl_deep_gemm" || ! -d "deep_gemm" || ! -d "csrc" ]]; then
echo "Error: Run from the DeepGEMM project root." >&2
exit 1
fi

echo "--- Initialising submodules ---"
git submodule update --init --recursive

echo "--- Linking CUTLASS headers into deep_gemm/include ---"
ln -sfn "${ROOT_DIR}/third-party/cutlass/include/cutlass" "${ROOT_DIR}/deep_gemm/include/cutlass"
ln -sfn "${ROOT_DIR}/third-party/cutlass/include/cute" "${ROOT_DIR}/deep_gemm/include/cute"

echo "--- Preparing build directory ---"
rm -rf "$BUILD_DIR"
mkdir -p "$PKG_DIR"

cp sgl_deep_gemm/LICENSE sgl_deep_gemm/README.md sgl_deep_gemm/pyproject.toml "$BUILD_DIR/"
cp sgl_deep_gemm/__init__.py "$PKG_DIR/"

# `__init__.py` imports `.utils`, `.testing`, `.legacy`, `.mega` — pulled from
# the existing deep_gemm/ tree.
for sub in utils testing legacy mega; do
cp -r "deep_gemm/${sub}" "$PKG_DIR/"
done

# Headers required by the runtime JIT (same set the deep_gemm wheel ships).
mkdir -p "$PKG_DIR/include"
cp -r "${ROOT_DIR}/deep_gemm/include/deep_gemm" "$PKG_DIR/include/deep_gemm"
cp -r "${ROOT_DIR}/third-party/cutlass/include/cute" "$PKG_DIR/include/cute"
cp -r "${ROOT_DIR}/third-party/cutlass/include/cutlass" "$PKG_DIR/include/cutlass"

echo "--- Reading version from sgl_deep_gemm/VERSION ---"
if [[ ! -f "sgl_deep_gemm/VERSION" ]]; then
echo "Error: sgl_deep_gemm/VERSION is missing — create it with the desired version (e.g. 0.0.1)." >&2
exit 1
fi
# Strip surrounding whitespace; the file is the single source of truth.
tr -d '[:space:]' < sgl_deep_gemm/VERSION > "$PKG_DIR/VERSION"
echo "Version: $(cat "$PKG_DIR/VERSION")"

echo "--- Compiling _C.so ---"
ROOT_DIR="$ROOT_DIR" PKG_DIR="$PKG_DIR" "$PYTHON_EXE" -u - <<'PY'
import os, shutil, subprocess, sys, sysconfig
root_dir = os.environ['ROOT_DIR']
pkg_dir = os.environ['PKG_DIR']
sys.path.insert(0, root_dir)

# tvm_ffi.cpp.build runs ninja with capture_output=True, hiding compile logs
# until a failure. Patch subprocess.run so the ninja invocation streams to the
# terminal — leaves other internal calls (nvidia-smi, nvcc --version) alone.
_orig_run = subprocess.run
def _streamed_run(*args, **kwargs):
cmd = kwargs.get('args') if 'args' in kwargs else (args[0] if args else None)
is_ninja = isinstance(cmd, (list, tuple)) and cmd and 'ninja' in str(cmd[0])
if is_ninja:
kwargs.pop('capture_output', None)
kwargs['stdout'] = None
kwargs['stderr'] = None
return _orig_run(*args, **kwargs)
subprocess.run = _streamed_run

import torch
import tvm_ffi.cpp
from setup import _find_cuda_home, _get_cuda_arch

cuda_home = _find_cuda_home()
os.environ.setdefault('TVM_FFI_CUDA_ARCH_LIST', _get_cuda_arch())

cxx_abi = int(torch.compiled_with_cxx11_abi())
extra_cflags = [
'-std=c++17', '-O3', '-fPIC',
'-Wno-psabi', '-Wno-deprecated-declarations',
f'-D_GLIBCXX_USE_CXX11_ABI={cxx_abi}',
]
if int(os.environ.get('DG_JIT_USE_RUNTIME_API', '0')):
extra_cflags.append('-DDG_JIT_USE_RUNTIME_API')

torch_dir = os.path.dirname(torch.__file__)
extra_include_paths = [
f'{cuda_home}/include',
sysconfig.get_path('include'),
os.path.join(torch_dir, 'include'),
os.path.join(torch_dir, 'include', 'torch', 'csrc', 'api', 'include'),
os.path.join(root_dir, 'deep_gemm', 'include'),
os.path.join(root_dir, 'third-party', 'cutlass', 'include'),
os.path.join(root_dir, 'third-party', 'fmt', 'include'),
]
cccl = f'{cuda_home}/include/cccl'
if os.path.exists(cccl):
extra_include_paths.append(cccl)

extra_ldflags = [
f'-L{cuda_home}/lib64',
f'-L{os.path.join(torch_dir, "lib")}',
'-lcudart', '-lnvrtc', '-lcublasLt', '-lcublas',
'-ltorch', '-ltorch_cpu', '-lc10', '-lc10_cuda', '-ltorch_cuda',
]

build_subdir = os.path.join(pkg_dir, '_C_build')
os.makedirs(build_subdir, exist_ok=True)
lib_path = tvm_ffi.cpp.build(
name='_C',
cpp_files=[os.path.join(root_dir, 'csrc', 'tvm_ffi_api.cpp')],
extra_cflags=extra_cflags,
extra_ldflags=extra_ldflags,
extra_include_paths=extra_include_paths,
build_directory=build_subdir,
)
target = os.path.join(pkg_dir, '_C.so')
if os.path.exists(target):
os.remove(target)
shutil.copy2(lib_path, target)
shutil.rmtree(build_subdir, ignore_errors=True)
print(f"Built {target}")
PY

echo "--- Installing build frontend ---"
"$PYTHON_EXE" -m pip install --quiet --upgrade build

echo "--- Building wheel ---"
mkdir -p "$DIST_DIR"
"$PYTHON_EXE" -m build --wheel "$BUILD_DIR" --outdir "$DIST_DIR"

echo "--- Done ---"
ls -lh "$DIST_DIR"/sgl_deep_gemm-*.whl 2>/dev/null || ls -lh "$DIST_DIR"/sgl-deep-gemm-*.whl 2>/dev/null || ls -lh "$DIST_DIR"/sgl_deep_gemm*.whl
12 changes: 8 additions & 4 deletions csrc/apis/attention.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ static torch::Tensor fp8_fp4_mqa_logits(const std::tuple<torch::Tensor, std::opt
// Check FP4 Q
std::tie(seq_len, num_heads, head_dim) = get_shape<3>(q_fp);
head_dim *= 2;
DG_HOST_ASSERT(num_heads == 32 or num_heads == 64);
DG_HOST_ASSERT(num_heads == 16 or num_heads == 32 or num_heads == 64);
DG_HOST_ASSERT(head_dim == 128);
DG_HOST_ASSERT(q_fp.is_contiguous());
DG_HOST_ASSERT(q_fp.scalar_type() == kPackedFP4);
Expand All @@ -117,7 +117,7 @@ static torch::Tensor fp8_fp4_mqa_logits(const std::tuple<torch::Tensor, std::opt
} else {
// Check FP8 Q
std::tie(seq_len, num_heads, head_dim) = get_shape<3>(q_fp);
DG_HOST_ASSERT(num_heads == 32 or num_heads == 64);
DG_HOST_ASSERT(num_heads == 16 or num_heads == 32 or num_heads == 64);
DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128);
DG_HOST_ASSERT(q_fp.is_contiguous());
DG_HOST_ASSERT(q_fp.scalar_type() == torch::kFloat8_e4m3fn);
Expand Down Expand Up @@ -247,7 +247,7 @@ static torch::Tensor fp8_fp4_paged_mqa_logits(const std::tuple<torch::Tensor, st
std::tie(batch_size, next_n, num_heads, head_dim) = get_shape<4>(q_fp);
head_dim *= 2;
DG_HOST_ASSERT(next_n >= 1);
DG_HOST_ASSERT(num_heads == 32 or num_heads == 64);
DG_HOST_ASSERT(num_heads == 16 or num_heads == 32 or num_heads == 64);
DG_HOST_ASSERT(head_dim == 128);
DG_HOST_ASSERT(q_fp.is_contiguous());
DG_HOST_ASSERT(q_fp.scalar_type() == kPackedFP4);
Expand Down Expand Up @@ -285,7 +285,7 @@ static torch::Tensor fp8_fp4_paged_mqa_logits(const std::tuple<torch::Tensor, st
// Check FP8 Q
std::tie(batch_size, next_n, num_heads, head_dim) = get_shape<4>(q_fp);
DG_HOST_ASSERT(next_n >= 1);
DG_HOST_ASSERT(num_heads == 32 or num_heads == 64);
DG_HOST_ASSERT(num_heads == 16 or num_heads == 32 or num_heads == 64);
DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128);
DG_HOST_ASSERT(q_fp.is_contiguous());
DG_HOST_ASSERT(q_fp.scalar_type() == torch::kFloat8_e4m3fn);
Expand Down Expand Up @@ -413,6 +413,8 @@ static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q,
}
#endif

#if 0

static void register_apis(pybind11::module_& m) {
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
m.def("fp8_gemm_nt_skip_head_mid", &fp8_gemm_nt_skip_head_mid,
Expand Down Expand Up @@ -450,4 +452,6 @@ static void register_apis(pybind11::module_& m) {
#endif
}

#endif

} // namespace deep_gemm::attention
7 changes: 4 additions & 3 deletions csrc/apis/einsum.hpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
#pragma once

#include <pybind11/pybind11.h>
#include <torch/python.h>

#include "../utils/exception.hpp"
#include "../utils/format.hpp"
#include "../utils/layout.hpp"
Expand Down Expand Up @@ -215,6 +212,8 @@ static void fp8_einsum(const std::string& expr,
}
#endif

#if 0

static void register_apis(pybind11::module_& m) {
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
m.def("einsum", &einsum,
Expand All @@ -228,4 +227,6 @@ static void register_apis(pybind11::module_& m) {
#endif
}

#endif

} // namespace deep_gemm::einsum
29 changes: 24 additions & 5 deletions csrc/apis/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ static void m_grouped_fp8_fp4_gemm_nn_contiguous(const std::pair<torch::Tensor,
d, grouped_layout, recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast, use_psum_layout, std::nullopt);
}

static void m_grouped_fp8_fp4_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>& a,
static std::optional<std::pair<int, int>> m_grouped_fp8_fp4_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
const torch::Tensor& masked_m,
Expand All @@ -228,13 +228,22 @@ static void m_grouped_fp8_fp4_gemm_nt_masked(const std::pair<torch::Tensor, torc
std::optional<std::tuple<int, int>> recipe_a,
std::optional<std::tuple<int, int>> recipe_b,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
const bool& disable_ue8m0_cast,
const int& max_block_n,
const bool& enable_overlap,
const std::optional<torch::Tensor>& signal) {
// Shape must be `[G, M, K] @ [G, N, K].mT`
const auto major_a = get_major_type_ab(a.first);
const auto major_b = get_major_type_ab(b.first);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
DG_HOST_ASSERT(masked_m.is_contiguous());

if (enable_overlap) {
DG_HOST_ASSERT(signal.has_value());
DG_HOST_ASSERT(signal.value().is_contiguous());
DG_HOST_ASSERT(signal.value().scalar_type() == torch::kInt32);
}

// Type and shape checks
const auto arch_major = device_runtime->get_arch_major();
const auto [num_groups , m , k ] = check_grouped_ab_fp8_fp4(a.first, major_a, arch_major);
Expand All @@ -255,17 +264,21 @@ static void m_grouped_fp8_fp4_gemm_nt_masked(const std::pair<torch::Tensor, torc
a.second, b.second, m, n, k, recipe, recipe_a, recipe_b, num_groups, num_groups, disable_ue8m0_cast);

// Dispatch implementation
std::optional<std::pair<int, int>> result = std::nullopt;
if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
const auto major_sfb = get_major_type_ab(sfb);
sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
num_groups, m, n, k, expected_m, major_a, major_b, major_sfb, compiled_dims);
result = sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
num_groups, m, n, k, expected_m, major_a, major_b, major_sfb, compiled_dims,
max_block_n, enable_overlap, signal);
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
DG_HOST_ASSERT(not enable_overlap and "SBO overlap is only supported on SM90");
sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m,
num_groups, m, n, k, expected_m, gran_k_a, gran_k_b,
major_a, major_b, compiled_dims);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
}
return result;
}

static void k_grouped_fp8_gemm_tn_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
Expand Down Expand Up @@ -596,6 +609,8 @@ static void cublaslt_gemm_tt(const torch::Tensor& a, const torch::Tensor& b,
cublaslt_gemm_nt(a.transpose(0, 1), b, d, c);
}

#if 0

static void register_apis(pybind11::module_& m) {

#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
Expand Down Expand Up @@ -643,7 +658,9 @@ static void register_apis(pybind11::module_& m) {
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"),
py::arg("expected_m"), py::arg("recipe") = std::nullopt,
py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt,
py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false);
py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false,
py::arg("max_block_n") = 256, py::arg("enable_overlap") = false,
py::arg("signal") = std::nullopt);
m.def("k_grouped_fp8_gemm_tn_contiguous", &k_grouped_fp8_gemm_tn_contiguous,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"),
py::arg("ks_tensor"), py::arg("c") = std::nullopt,
Expand Down Expand Up @@ -712,4 +729,6 @@ static void register_apis(pybind11::module_& m) {
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt);
}

#endif

} // namespace deep_gemm::gemm
4 changes: 4 additions & 0 deletions csrc/apis/hyperconnection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ static void tf32_hc_prenorm_gemm(const torch::Tensor& a,

#endif

#if 0

static void register_apis(pybind11::module_& m) {
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
m.def("tf32_hc_prenorm_gemm", &tf32_hc_prenorm_gemm,
Expand All @@ -67,4 +69,6 @@ static void register_apis(pybind11::module_& m) {
#endif
}

#endif

} // namespace deep_gemm::hyperconnection
4 changes: 4 additions & 0 deletions csrc/apis/layout.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ static torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Te

#endif

#if 0

static void register_apis(pybind11::module_& m) {
#if DG_TENSORMAP_COMPATIBLE
m.def("transform_sf_into_required_layout", &transform_sf_into_required_layout,
Expand All @@ -140,4 +142,6 @@ static void register_apis(pybind11::module_& m) {
}, py::arg("expected_m") = std::nullopt);
}

#endif

} // namespace deep_gemm::layout
Loading