Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
bd42ee2
Add SM90 MXFP8 FP8 grouped kernels
zhangxiaolei123456 Jun 16, 2026
5e8cd3d
Fix MXFP8 FP8 per-column B scaling
zhangxiaolei123456 Jun 16, 2026
4365d1d
Fix MXFP8 contiguous accuracy test constraints
zhangxiaolei123456 Jun 16, 2026
5608f23
Add MXFP8 FP8 performance comparison test
zhangxiaolei123456 Jun 16, 2026
aaecde6
Stage MXFP8 B scales in shared memory
zhangxiaolei123456 Jun 16, 2026
cb8e933
Fix main-based MXFP8 include list
zhangxiaolei123456 Jun 16, 2026
dcb1017
Fence staged MXFP8 B scales before consume
zhangxiaolei123456 Jun 16, 2026
c1125fc
Support MXFP8 A scales on SM90 grouped kernels
zhangxiaolei123456 Jun 16, 2026
405be71
Load SM90 MXFP8 A scales from global memory
zhangxiaolei123456 Jun 16, 2026
c2b8716
Load SM90 MXFP8 B scales from global memory
zhangxiaolei123456 Jun 16, 2026
ffe17ba
Avoid SM90 masked MXFP8 cross-group stores
zhangxiaolei123456 Jun 16, 2026
44017de
Pack-load SM90 MXFP8 scales from SMEM in consumer
zhangxiaolei123456 Jun 16, 2026
6de214d
Support packed UE8M0 scales in SM90 MXFP8 GEMM
zhangxiaolei123456 Jun 16, 2026
06b2ac5
Pass explicit recipes to SM90 MXFP8 GEMM
zhangxiaolei123456 Jun 16, 2026
a5ef613
Fix SM90 masked packed scale group stride
zhangxiaolei123456 Jun 16, 2026
2b5b958
Allow non-contiguous SM90 MXFP8 scales
zhangxiaolei123456 Jun 16, 2026
9862891
Test UE8M0 int32 packing byte order
zhangxiaolei123456 Jun 17, 2026
65871a5
Fix SM90 MXFP8 contiguous RHS scale group
zhangxiaolei123456 Jun 24, 2026
203e3b3
Add SM90 MXFP8 DeepEP scale layout test
zhangxiaolei123456 Jun 25, 2026
d2032ea
Tighten SM90 MXFP8 DeepEP scale layout test
zhangxiaolei123456 Jun 25, 2026
664fa78
Add SM90 MXFP8 dense raw scale test
zhangxiaolei123456 Jun 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions csrc/apis/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
#include "../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp"
#include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp"
#include "../jit_kernels/impls/sm90_mxfp8_fp8_gemm_1d2d.hpp"
#include "../jit_kernels/impls/sm90_bf16_gemm.hpp"
#include "../jit_kernels/impls/sm100_fp8_fp4_gemm_1d1d.hpp"
#include "../jit_kernels/impls/sm100_bf16_gemm.hpp"
Expand Down Expand Up @@ -268,6 +269,113 @@ static void m_grouped_fp8_fp4_gemm_nt_masked(const std::pair<torch::Tensor, torc
}
}

// Host-side validation and dispatch for the SM90 M-grouped (contiguous layout)
// MXFP8 x FP8 GEMM with both operands K-major.
//
// Arguments:
//   a              - (data, scales) pair; scales are 2-D [m, k_sfa].
//   b              - (data, scales) pair; scales are 3-D [num_groups, n, k_sfb].
//   d              - contiguous BF16 output of shape [m, n].
//   grouped_layout - contiguous int32 tensor with `m` elements.
//   compiled_dims  - forwarded unchanged to the kernel launcher.
//   recipe_a/b     - optional (1, gran_k) scale recipes with gran_k in {32, 128}.
//                    When omitted, gran_k is inferred from the scale tensor's
//                    last dimension.
//
// Scale tensors may be uint8 or int32; an int32 element is counted as covering
// four scale groups along K (hence the `pack` factor of 4 below).
// Any violated precondition trips DG_HOST_ASSERT. Returns without launching
// when m == 0.
static void m_grouped_mxfp8_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
                                                   const std::pair<torch::Tensor, torch::Tensor>& b,
                                                   const torch::Tensor& d,
                                                   const torch::Tensor& grouped_layout,
                                                   const std::string& compiled_dims,
                                                   const std::optional<std::tuple<int, int>>& recipe_a,
                                                   const std::optional<std::tuple<int, int>>& recipe_b) {
    const auto major_a = get_major_type_ab(a.first);
    const auto major_b = get_major_type_ab(b.first);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
    DG_HOST_ASSERT(grouped_layout.is_contiguous());

    // Operand / output shape checks
    const auto arch_major = device_runtime->get_arch_major();
    const auto [m, k] = check_ab_fp8_fp4(a.first, major_a, arch_major);
    const auto [num_groups, n, k_] = check_grouped_ab_fp8_fp4(b.first, major_b, arch_major);
    const auto [m_, n_] = get_shape<2>(d);
    const auto [m__] = get_shape<1>(grouped_layout);
    DG_HOST_ASSERT(arch_major == 9);
    DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_);
    DG_HOST_ASSERT(k % 32 == 0 and num_groups > 0);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 and d.is_contiguous());
    DG_HOST_ASSERT(grouped_layout.scalar_type() == torch::kInt);
    DG_HOST_ASSERT(a.second.scalar_type() == torch::kUInt8 or a.second.scalar_type() == torch::kInt);
    DG_HOST_ASSERT(b.second.scalar_type() == torch::kUInt8 or b.second.scalar_type() == torch::kInt);

    // Explicit recipes, when given, must be (1, 32) or (1, 128)
    if (recipe_a.has_value()) {
        DG_HOST_ASSERT(std::get<0>(recipe_a.value()) == 1 and
                       (std::get<1>(recipe_a.value()) == 32 or std::get<1>(recipe_a.value()) == 128));
    }
    if (recipe_b.has_value()) {
        DG_HOST_ASSERT(std::get<0>(recipe_b.value()) == 1 and
                       (std::get<1>(recipe_b.value()) == 32 or std::get<1>(recipe_b.value()) == 128));
    }

    // Number of K scale groups represented by one element of each scale tensor
    const int pack_a = a.second.scalar_type() == torch::kInt ? 4 : 1;
    const int pack_b = b.second.scalar_type() == torch::kInt ? 4 : 1;

    // LHS scales
    const auto [m_sfa, k_sfa] = get_shape<2>(a.second);
    // Inferring the granularity divides by `k_sfa`; reject an empty K dimension
    // up front instead of hitting an integer division by zero
    DG_HOST_ASSERT(recipe_a.has_value() or k_sfa > 0);
    const auto gran_k_a = recipe_a.has_value()
        ? std::get<1>(recipe_a.value())
        : k / (k_sfa * pack_a);
    DG_HOST_ASSERT(m == m_sfa and (gran_k_a == 32 or gran_k_a == 128) and
                   k_sfa == ceil_div(k, gran_k_a * pack_a));

    // RHS scales
    const auto [num_groups_sfb, n_sfb, k_sfb] = get_shape<3>(b.second);
    DG_HOST_ASSERT(recipe_b.has_value() or k_sfb > 0);
    const auto gran_k_b = recipe_b.has_value()
        ? std::get<1>(recipe_b.value())
        : k / (k_sfb * pack_b);
    DG_HOST_ASSERT(num_groups == num_groups_sfb and n == n_sfb and
                   (gran_k_b == 32 or gran_k_b == 128) and
                   k_sfb == ceil_div(k, gran_k_b * pack_b));

    // Nothing to compute for an empty problem
    if (m == 0)
        return;

    sm90_m_grouped_mxfp8_fp8_gemm_contiguous_1d2d(
        a.first, a.second, b.first, b.second, d, grouped_layout, num_groups, m, n, k,
        compiled_dims, recipe_a, recipe_b);
}

// Host-side validation and dispatch for the SM90 M-grouped (masked layout)
// MXFP8 x FP8 GEMM with both operands K-major.
//
// Arguments:
//   a              - (data, scales) pair; scales are 3-D [num_groups, m, k_sfa].
//   b              - (data, scales) pair; scales are 3-D [num_groups, n, k_sfb].
//   d              - contiguous BF16 output of shape [num_groups, m, n].
//   masked_m       - contiguous int32 tensor with `num_groups` elements.
//   expected_m     - accepted for interface compatibility; not read here.
//   compiled_dims  - forwarded unchanged to the kernel launcher.
//   recipe_a/b     - optional (1, gran_k) scale recipes with gran_k in {32, 128}.
//                    When omitted, gran_k is inferred from the scale tensor's
//                    last dimension.
//
// Scale tensors may be uint8 or int32; an int32 element is counted as covering
// four scale groups along K (hence the `pack` factor of 4 below).
// Any violated precondition trips DG_HOST_ASSERT.
static void m_grouped_mxfp8_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>& a,
                                               const std::pair<torch::Tensor, torch::Tensor>& b,
                                               const torch::Tensor& d,
                                               const torch::Tensor& masked_m,
                                               const int& expected_m,
                                               const std::string& compiled_dims,
                                               const std::optional<std::tuple<int, int>>& recipe_a,
                                               const std::optional<std::tuple<int, int>>& recipe_b) {
    (void) expected_m;
    const auto major_a = get_major_type_ab(a.first);
    const auto major_b = get_major_type_ab(b.first);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
    DG_HOST_ASSERT(masked_m.is_contiguous());

    // Operand / output shape checks
    const auto arch_major = device_runtime->get_arch_major();
    const auto [num_groups, m, k] = check_grouped_ab_fp8_fp4(a.first, major_a, arch_major);
    const auto [num_groups_, n, k_] = check_grouped_ab_fp8_fp4(b.first, major_b, arch_major);
    const auto [num_groups__, m_, n_] = get_shape<3>(d);
    DG_HOST_ASSERT(arch_major == 9);
    DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__);
    DG_HOST_ASSERT(masked_m.numel() == num_groups);
    DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
    DG_HOST_ASSERT(k % 32 == 0 and m > 0 and n > 0);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 and d.is_contiguous());
    DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt);
    DG_HOST_ASSERT(a.second.scalar_type() == torch::kUInt8 or a.second.scalar_type() == torch::kInt);
    DG_HOST_ASSERT(b.second.scalar_type() == torch::kUInt8 or b.second.scalar_type() == torch::kInt);

    // Explicit recipes, when given, must be (1, 32) or (1, 128)
    if (recipe_a.has_value()) {
        DG_HOST_ASSERT(std::get<0>(recipe_a.value()) == 1 and
                       (std::get<1>(recipe_a.value()) == 32 or std::get<1>(recipe_a.value()) == 128));
    }
    if (recipe_b.has_value()) {
        DG_HOST_ASSERT(std::get<0>(recipe_b.value()) == 1 and
                       (std::get<1>(recipe_b.value()) == 32 or std::get<1>(recipe_b.value()) == 128));
    }

    // Number of K scale groups represented by one element of each scale tensor
    const int pack_a = a.second.scalar_type() == torch::kInt ? 4 : 1;
    const int pack_b = b.second.scalar_type() == torch::kInt ? 4 : 1;

    // LHS scales
    const auto [num_groups_sfa, m_sfa, k_sfa] = get_shape<3>(a.second);
    // Inferring the granularity divides by `k_sfa`; reject an empty K dimension
    // up front instead of hitting an integer division by zero
    DG_HOST_ASSERT(recipe_a.has_value() or k_sfa > 0);
    const auto gran_k_a = recipe_a.has_value()
        ? std::get<1>(recipe_a.value())
        : k / (k_sfa * pack_a);
    DG_HOST_ASSERT(num_groups == num_groups_sfa and m == m_sfa and
                   (gran_k_a == 32 or gran_k_a == 128) and
                   k_sfa == ceil_div(k, gran_k_a * pack_a));

    // RHS scales
    const auto [num_groups_sfb, n_sfb, k_sfb] = get_shape<3>(b.second);
    DG_HOST_ASSERT(recipe_b.has_value() or k_sfb > 0);
    const auto gran_k_b = recipe_b.has_value()
        ? std::get<1>(recipe_b.value())
        : k / (k_sfb * pack_b);
    DG_HOST_ASSERT(num_groups == num_groups_sfb and n == n_sfb and
                   (gran_k_b == 32 or gran_k_b == 128) and
                   k_sfb == ceil_div(k, gran_k_b * pack_b));

    sm90_m_grouped_mxfp8_fp8_gemm_masked_1d2d(
        a.first, a.second, b.first, b.second, d, masked_m, num_groups, m, n, k,
        compiled_dims, recipe_a, recipe_b);
}

static void k_grouped_fp8_gemm_tn_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
Expand Down Expand Up @@ -644,6 +752,14 @@ static void register_apis(pybind11::module_& m) {
py::arg("expected_m"), py::arg("recipe") = std::nullopt,
py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt,
py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false);
m.def("m_grouped_mxfp8_fp8_gemm_nt_contiguous", &m_grouped_mxfp8_fp8_gemm_nt_contiguous,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("grouped_layout"),
py::arg("compiled_dims") = "nk", py::arg("recipe_a") = std::nullopt,
py::arg("recipe_b") = std::nullopt);
m.def("m_grouped_mxfp8_fp8_gemm_nt_masked", &m_grouped_mxfp8_fp8_gemm_nt_masked,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"),
py::arg("expected_m"), py::arg("compiled_dims") = "nk",
py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt);
m.def("k_grouped_fp8_gemm_tn_contiguous", &k_grouped_fp8_gemm_tn_contiguous,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"),
py::arg("ks_tensor"), py::arg("c") = std::nullopt,
Expand Down
11 changes: 10 additions & 1 deletion csrc/jit_kernels/heuristics/sm90.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,16 @@ struct SM90ArchSpec {
desc.gemm_type == GemmType::MGroupedContiguousWithPsumLayout) {
block_m_candidates = std::vector{heuristics_runtime->get_mk_alignment_for_contiguous_layout()};
} else if (desc.gemm_type == GemmType::MGroupedMasked) {
block_m_candidates = {64, 128};
// Masked grouped outputs are laid out as [num_groups, expected_m, n].
// Keep BLOCK_M aligned with expected_m when possible, otherwise a
// full-block TMA store can cross into the next group's rows.
const int expected_m = desc.get_expected_m();
for (int candidate: {16, 32, 64, 128}) {
if (expected_m % candidate == 0)
block_m_candidates.push_back(candidate);
}
if (block_m_candidates.empty())
block_m_candidates = {64, 128};
}

// Block N candidates
Expand Down
Loading