diff --git a/csrc/apis/gemm.hpp b/csrc/apis/gemm.hpp index 42622df7d8..9675cf6e8c 100644 --- a/csrc/apis/gemm.hpp +++ b/csrc/apis/gemm.hpp @@ -5,6 +5,7 @@ #if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE #include "../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp" #include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp" +#include "../jit_kernels/impls/sm90_mxfp8_fp8_gemm_1d2d.hpp" #include "../jit_kernels/impls/sm90_bf16_gemm.hpp" #include "../jit_kernels/impls/sm100_fp8_fp4_gemm_1d1d.hpp" #include "../jit_kernels/impls/sm100_bf16_gemm.hpp" @@ -268,6 +269,113 @@ static void m_grouped_fp8_fp4_gemm_nt_masked(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const torch::Tensor& grouped_layout, + const std::string& compiled_dims, + const std::optional>& recipe_a, + const std::optional>& recipe_b) { + (void) compiled_dims; + const auto major_a = get_major_type_ab(a.first); + const auto major_b = get_major_type_ab(b.first); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + DG_HOST_ASSERT(grouped_layout.is_contiguous()); + + const auto arch_major = device_runtime->get_arch_major(); + const auto [m, k] = check_ab_fp8_fp4(a.first, major_a, arch_major); + const auto [num_groups, n, k_] = check_grouped_ab_fp8_fp4(b.first, major_b, arch_major); + const auto [m_, n_] = get_shape<2>(d); + const auto [m__] = get_shape<1>(grouped_layout); + DG_HOST_ASSERT(arch_major == 9); + DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_); + DG_HOST_ASSERT(k % 32 == 0 and num_groups > 0); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 and d.is_contiguous()); + DG_HOST_ASSERT(grouped_layout.scalar_type() == torch::kInt); + DG_HOST_ASSERT(a.second.scalar_type() == torch::kUInt8 or a.second.scalar_type() == torch::kInt); + DG_HOST_ASSERT(b.second.scalar_type() == torch::kUInt8 or b.second.scalar_type() == torch::kInt); + if (recipe_a.has_value()) + DG_HOST_ASSERT(std::get<0>(recipe_a.value()) == 1 and + 
(std::get<1>(recipe_a.value()) == 32 or std::get<1>(recipe_a.value()) == 128)); + if (recipe_b.has_value()) + DG_HOST_ASSERT(std::get<0>(recipe_b.value()) == 1 and + (std::get<1>(recipe_b.value()) == 32 or std::get<1>(recipe_b.value()) == 128)); + const auto [m_sfa, k_sfa] = get_shape<2>(a.second); + const auto gran_k_a = recipe_a.has_value() + ? std::get<1>(recipe_a.value()) + : k / (k_sfa * (a.second.scalar_type() == torch::kInt ? 4 : 1)); + DG_HOST_ASSERT(m == m_sfa and (gran_k_a == 32 or gran_k_a == 128) and + k_sfa == ceil_div(k, gran_k_a * (a.second.scalar_type() == torch::kInt ? 4 : 1))); + const auto [num_groups_sfb, n_sfb, k_sfb] = get_shape<3>(b.second); + const auto gran_k_b = recipe_b.has_value() + ? std::get<1>(recipe_b.value()) + : k / (k_sfb * (b.second.scalar_type() == torch::kInt ? 4 : 1)); + DG_HOST_ASSERT(num_groups == num_groups_sfb and n == n_sfb and + (gran_k_b == 32 or gran_k_b == 128) and + k_sfb == ceil_div(k, gran_k_b * (b.second.scalar_type() == torch::kInt ? 4 : 1))); + + if (m == 0) + return; + + sm90_m_grouped_mxfp8_fp8_gemm_contiguous_1d2d( + a.first, a.second, b.first, b.second, d, grouped_layout, num_groups, m, n, k, + compiled_dims, recipe_a, recipe_b); +} + +static void m_grouped_mxfp8_fp8_gemm_nt_masked(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const torch::Tensor& masked_m, + const int& expected_m, + const std::string& compiled_dims, + const std::optional>& recipe_a, + const std::optional>& recipe_b) { + (void) expected_m; + (void) compiled_dims; + const auto major_a = get_major_type_ab(a.first); + const auto major_b = get_major_type_ab(b.first); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + DG_HOST_ASSERT(masked_m.is_contiguous()); + + const auto arch_major = device_runtime->get_arch_major(); + const auto [num_groups, m, k] = check_grouped_ab_fp8_fp4(a.first, major_a, arch_major); + const auto [num_groups_, n, k_] = check_grouped_ab_fp8_fp4(b.first, major_b, 
arch_major); + const auto [num_groups__, m_, n_] = get_shape<3>(d); + DG_HOST_ASSERT(arch_major == 9); + DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__); + DG_HOST_ASSERT(masked_m.numel() == num_groups); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + DG_HOST_ASSERT(k % 32 == 0 and m > 0 and n > 0); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 and d.is_contiguous()); + DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt); + DG_HOST_ASSERT(a.second.scalar_type() == torch::kUInt8 or a.second.scalar_type() == torch::kInt); + DG_HOST_ASSERT(b.second.scalar_type() == torch::kUInt8 or b.second.scalar_type() == torch::kInt); + if (recipe_a.has_value()) + DG_HOST_ASSERT(std::get<0>(recipe_a.value()) == 1 and + (std::get<1>(recipe_a.value()) == 32 or std::get<1>(recipe_a.value()) == 128)); + if (recipe_b.has_value()) + DG_HOST_ASSERT(std::get<0>(recipe_b.value()) == 1 and + (std::get<1>(recipe_b.value()) == 32 or std::get<1>(recipe_b.value()) == 128)); + const auto [num_groups_sfa, m_sfa, k_sfa] = get_shape<3>(a.second); + const auto gran_k_a = recipe_a.has_value() + ? std::get<1>(recipe_a.value()) + : k / (k_sfa * (a.second.scalar_type() == torch::kInt ? 4 : 1)); + DG_HOST_ASSERT(num_groups == num_groups_sfa and m == m_sfa and + (gran_k_a == 32 or gran_k_a == 128) and + k_sfa == ceil_div(k, gran_k_a * (a.second.scalar_type() == torch::kInt ? 4 : 1))); + const auto [num_groups_sfb, n_sfb, k_sfb] = get_shape<3>(b.second); + const auto gran_k_b = recipe_b.has_value() + ? std::get<1>(recipe_b.value()) + : k / (k_sfb * (b.second.scalar_type() == torch::kInt ? 4 : 1)); + DG_HOST_ASSERT(num_groups == num_groups_sfb and n == n_sfb and + (gran_k_b == 32 or gran_k_b == 128) and + k_sfb == ceil_div(k, gran_k_b * (b.second.scalar_type() == torch::kInt ? 
4 : 1))); + + sm90_m_grouped_mxfp8_fp8_gemm_masked_1d2d( + a.first, a.second, b.first, b.second, d, masked_m, num_groups, m, n, k, + compiled_dims, recipe_a, recipe_b); +} + static void k_grouped_fp8_gemm_tn_contiguous(const std::pair& a, const std::pair& b, const torch::Tensor& d, @@ -644,6 +752,14 @@ static void register_apis(pybind11::module_& m) { py::arg("expected_m"), py::arg("recipe") = std::nullopt, py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt, py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false); + m.def("m_grouped_mxfp8_fp8_gemm_nt_contiguous", &m_grouped_mxfp8_fp8_gemm_nt_contiguous, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("grouped_layout"), + py::arg("compiled_dims") = "nk", py::arg("recipe_a") = std::nullopt, + py::arg("recipe_b") = std::nullopt); + m.def("m_grouped_mxfp8_fp8_gemm_nt_masked", &m_grouped_mxfp8_fp8_gemm_nt_masked, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"), + py::arg("expected_m"), py::arg("compiled_dims") = "nk", + py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt); m.def("k_grouped_fp8_gemm_tn_contiguous", &k_grouped_fp8_gemm_tn_contiguous, py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"), py::arg("ks_tensor"), py::arg("c") = std::nullopt, diff --git a/csrc/jit_kernels/heuristics/sm90.hpp b/csrc/jit_kernels/heuristics/sm90.hpp index c411fb7e01..7b53c8ba24 100644 --- a/csrc/jit_kernels/heuristics/sm90.hpp +++ b/csrc/jit_kernels/heuristics/sm90.hpp @@ -32,7 +32,16 @@ struct SM90ArchSpec { desc.gemm_type == GemmType::MGroupedContiguousWithPsumLayout) { block_m_candidates = std::vector{heuristics_runtime->get_mk_alignment_for_contiguous_layout()}; } else if (desc.gemm_type == GemmType::MGroupedMasked) { - block_m_candidates = {64, 128}; + // Masked grouped outputs are laid out as [num_groups, expected_m, n]. 
+ // Keep BLOCK_M aligned with expected_m when possible, otherwise a + // full-block TMA store can cross into the next group's rows. + const int expected_m = desc.get_expected_m(); + for (int candidate: {16, 32, 64, 128}) { + if (expected_m % candidate == 0) + block_m_candidates.push_back(candidate); + } + if (block_m_candidates.empty()) + block_m_candidates = {64, 128}; } // Block N candidates diff --git a/csrc/jit_kernels/impls/sm90_mxfp8_fp8_gemm_1d2d.hpp b/csrc/jit_kernels/impls/sm90_mxfp8_fp8_gemm_1d2d.hpp new file mode 100644 index 0000000000..c50def739b --- /dev/null +++ b/csrc/jit_kernels/impls/sm90_mxfp8_fp8_gemm_1d2d.hpp @@ -0,0 +1,276 @@ +#pragma once + +#include + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../../utils/math.hpp" +#include "../heuristics/sm90.hpp" + +#include "runtime_utils.hpp" + +namespace deep_gemm { + +template +class SM90MXFP8FP8Gemm1D2DRuntime final: public LaunchRuntime> { +public: + struct Args { + GemmDesc gemm_desc; + GemmConfig gemm_config; + LaunchArgs launch_args; + void *sfa, *sfb, *grouped_layout; + uint32_t sfa_stride_group, sfa_stride_m, sfa_stride_k; + uint32_t sfa_gran_k; + bool sfa_packed_int32; + uint32_t sfb_stride_group, sfb_stride_n, sfb_stride_k; + uint32_t sfb_gran_k; + bool sfb_packed_int32; + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_d; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; +static constexpr int kSm90MXFP8FP8ScaleRecipeJitVersion = 8; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm90_mxfp8_fp8_gemm_1d2d_impl< + {}, + {}, {}, {}, + {}, + {}, {}, {}, + {}, {}, {}, + {}, {}, {}, + {}, {}, + {}, {} + >); +}} +)", + kMasked ? 
"true" : "false", + get_compiled_dim(args.gemm_desc.m, 'm', args.gemm_desc.compiled_dims), + get_compiled_dim(args.gemm_desc.n, 'n', args.gemm_desc.compiled_dims), + get_compiled_dim(args.gemm_desc.k, 'k', args.gemm_desc.compiled_dims), + args.gemm_desc.num_groups, + args.gemm_config.layout.block_m, args.gemm_config.layout.block_n, args.gemm_config.layout.block_k, + args.gemm_config.storage_config.swizzle_a_mode, + args.gemm_config.storage_config.swizzle_b_mode, + args.gemm_config.storage_config.swizzle_cd_mode, + args.gemm_config.pipeline_config.num_stages, + args.gemm_config.launch_config.num_tma_threads, args.gemm_config.launch_config.num_math_threads, + args.gemm_config.layout.get_cluster_size(), args.gemm_config.layout.cluster_n > 1, + args.gemm_config.launch_config.num_sms, to_string(args.gemm_desc.gemm_type)); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.sfa, args.sfb, args.grouped_layout, + args.sfa_stride_group, args.sfa_stride_m, args.sfa_stride_k, + args.sfa_gran_k, args.sfa_packed_int32, + args.sfb_stride_group, args.sfb_stride_n, args.sfb_stride_k, + args.sfb_gran_k, args.sfb_packed_int32, + args.gemm_desc.m, args.gemm_desc.n, args.gemm_desc.k, + args.tensor_map_a, args.tensor_map_b, args.tensor_map_d)); + } +}; + +static void tune_mxfp8_fp8_smem_config(GemmConfig& config, const GemmDesc& desc) { + const int orig_num_stages = config.pipeline_config.num_stages; + const int original_per_stage = + config.storage_config.load_block_m * config.layout.block_k * c10::elementSize(desc.a_dtype) + + config.storage_config.load_block_n * config.layout.block_k * c10::elementSize(desc.b_dtype) + + align(config.layout.block_m * static_cast(sizeof(float)), 128); + const int sfa_per_stage = align(config.layout.block_m * (config.layout.block_k / 32) * static_cast(sizeof(uint8_t)), 128); + const int sfb_per_stage = align(config.layout.block_n * 
(config.layout.block_k / 32) * static_cast(sizeof(uint8_t)), 128); + const int smem_extra = config.pipeline_config.smem_size - orig_num_stages * original_per_stage; + const int merged_per_stage = + config.storage_config.load_block_m * config.layout.block_k * c10::elementSize(desc.a_dtype) + + config.storage_config.load_block_n * config.layout.block_k * c10::elementSize(desc.b_dtype) + + sfa_per_stage + sfb_per_stage; + int chosen_stages = std::min(orig_num_stages, (SM90ArchSpec::smem_capacity - smem_extra) / merged_per_stage); + DG_HOST_ASSERT(chosen_stages >= 1); + config.pipeline_config.num_stages = chosen_stages; + config.pipeline_config.smem_size = smem_extra + chosen_stages * merged_per_stage; +} + +static void sm90_m_grouped_mxfp8_fp8_gemm_contiguous_1d2d( + const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, const torch::Tensor& grouped_layout, + const int& num_groups, const int& m, const int& n, const int& k, + const std::string& compiled_dims, + const std::optional>& recipe_a, + const std::optional>& recipe_b) { + DG_HOST_ASSERT(a.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(b.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(sfa.scalar_type() == torch::kUInt8 or sfa.scalar_type() == torch::kInt); + DG_HOST_ASSERT(sfb.scalar_type() == torch::kUInt8 or sfb.scalar_type() == torch::kInt); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(grouped_layout.scalar_type() == torch::kInt and grouped_layout.is_contiguous()); + DG_HOST_ASSERT(a.is_contiguous() and b.is_contiguous() and d.is_contiguous()); + + const auto desc = GemmDesc { + .gemm_type = GemmType::MGroupedContiguous, + .kernel_type = KernelType::Kernel1D2D, + .m = m, .n = n, .k = k, .num_groups = num_groups, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = cute::UMMA::Major::K, .major_b = cute::UMMA::Major::K, + .with_accumulation = 
false, + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims, + .expected_m = m, .expected_n = n, .expected_k = k, .expected_num_groups = 1 + }; + auto config = get_best_config(desc); + tune_mxfp8_fp8_smem_config(config, desc); + DG_HOST_ASSERT(config.storage_config.swizzle_a_mode == config.layout.block_k); + DG_HOST_ASSERT(config.storage_config.swizzle_b_mode == config.layout.block_k); + const auto sfa_gran_k = recipe_a.has_value() + ? std::get<1>(recipe_a.value()) + : k / (static_cast(sfa.size(1)) * (sfa.scalar_type() == torch::kInt ? 4 : 1)); + DG_HOST_ASSERT(sfa_gran_k == 32 or sfa_gran_k == 128); + DG_HOST_ASSERT(sfa.size(1) == ceil_div(k, sfa_gran_k * (sfa.scalar_type() == torch::kInt ? 4 : 1))); + const auto sfb_gran_k = recipe_b.has_value() + ? std::get<1>(recipe_b.value()) + : k / (static_cast(sfb.size(-1)) * (sfb.scalar_type() == torch::kInt ? 4 : 1)); + DG_HOST_ASSERT(sfb_gran_k == 32 or sfb_gran_k == 128); + DG_HOST_ASSERT(sfb.size(-1) == ceil_div(k, sfb_gran_k * (sfb.scalar_type() == torch::kInt ? 
4 : 1))); + + const auto tensor_map_a = make_tma_a_desc(cute::UMMA::Major::K, a, m, k, + config.storage_config.load_block_m, + config.layout.block_k, + static_cast(a.stride(0)), 1, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(cute::UMMA::Major::K, b, n, k, + config.storage_config.load_block_n, + config.layout.block_k, + static_cast(b.stride(1)), num_groups, + config.storage_config.swizzle_b_mode); + const auto tensor_map_d = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(-2)), 1, + config.storage_config.swizzle_cd_mode); + const typename SM90MXFP8FP8Gemm1D2DRuntime::Args& args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .sfa = sfa.data_ptr(), + .sfb = sfb.data_ptr(), + .grouped_layout = grouped_layout.data_ptr(), + .sfa_stride_group = 0, + .sfa_stride_m = static_cast(sfa.stride(0)), + .sfa_stride_k = static_cast(sfa.stride(1)), + .sfa_gran_k = static_cast(sfa_gran_k), + .sfa_packed_int32 = sfa.scalar_type() == torch::kInt, + .sfb_stride_group = static_cast(sfb.stride(0)), + .sfb_stride_n = static_cast(sfb.stride(1)), + .sfb_stride_k = static_cast(sfb.stride(2)), + .sfb_gran_k = static_cast(sfb_gran_k), + .sfb_packed_int32 = sfb.scalar_type() == torch::kInt, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + }; + const auto code = SM90MXFP8FP8Gemm1D2DRuntime::generate(args); + const auto runtime = compiler->build("sm90_m_grouped_mxfp8_fp8_gemm_contiguous_1d2d_scale_recipe_v8", code); + SM90MXFP8FP8Gemm1D2DRuntime::launch(runtime, args); +} + +static void sm90_m_grouped_mxfp8_fp8_gemm_masked_1d2d( + const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, const 
torch::Tensor& masked_m, + const int& num_groups, const int& m, const int& n, const int& k, + const std::string& compiled_dims, + const std::optional>& recipe_a, + const std::optional>& recipe_b) { + DG_HOST_ASSERT(a.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(b.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(sfa.scalar_type() == torch::kUInt8 or sfa.scalar_type() == torch::kInt); + DG_HOST_ASSERT(sfb.scalar_type() == torch::kUInt8 or sfb.scalar_type() == torch::kInt); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt and masked_m.is_contiguous()); + DG_HOST_ASSERT(a.is_contiguous() and b.is_contiguous() and d.is_contiguous()); + + const auto desc = GemmDesc { + .gemm_type = GemmType::MGroupedMasked, + .kernel_type = KernelType::Kernel1D2D, + .m = m, .n = n, .k = k, .num_groups = num_groups, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = cute::UMMA::Major::K, .major_b = cute::UMMA::Major::K, + .with_accumulation = false, + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims, + .expected_m = m, .expected_n = n, .expected_k = k, .expected_num_groups = num_groups + }; + auto config = get_best_config(desc); + tune_mxfp8_fp8_smem_config(config, desc); + DG_HOST_ASSERT(config.storage_config.swizzle_a_mode == config.layout.block_k); + DG_HOST_ASSERT(config.storage_config.swizzle_b_mode == config.layout.block_k); + const auto sfa_gran_k = recipe_a.has_value() + ? std::get<1>(recipe_a.value()) + : k / (static_cast(sfa.size(-1)) * (sfa.scalar_type() == torch::kInt ? 4 : 1)); + DG_HOST_ASSERT(sfa_gran_k == 32 or sfa_gran_k == 128); + DG_HOST_ASSERT(sfa.size(-1) == ceil_div(k, sfa_gran_k * (sfa.scalar_type() == torch::kInt ? 4 : 1))); + const auto sfb_gran_k = recipe_b.has_value() + ? 
std::get<1>(recipe_b.value()) + : k / (static_cast(sfb.size(-1)) * (sfb.scalar_type() == torch::kInt ? 4 : 1)); + DG_HOST_ASSERT(sfb_gran_k == 32 or sfb_gran_k == 128); + DG_HOST_ASSERT(sfb.size(-1) == ceil_div(k, sfb_gran_k * (sfb.scalar_type() == torch::kInt ? 4 : 1))); + + const auto tensor_map_a = make_tma_a_desc(cute::UMMA::Major::K, a, m, k, + config.storage_config.load_block_m, + config.layout.block_k, + static_cast(a.stride(1)), num_groups, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(cute::UMMA::Major::K, b, n, k, + config.storage_config.load_block_n, + config.layout.block_k, + static_cast(b.stride(1)), num_groups, + config.storage_config.swizzle_b_mode); + const auto tensor_map_d = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(-2)), num_groups, + config.storage_config.swizzle_cd_mode); + const typename SM90MXFP8FP8Gemm1D2DRuntime::Args& args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .sfa = sfa.data_ptr(), + .sfb = sfb.data_ptr(), + .grouped_layout = masked_m.data_ptr(), + .sfa_stride_group = static_cast(sfa.stride(0)), + .sfa_stride_m = static_cast(sfa.stride(1)), + .sfa_stride_k = static_cast(sfa.stride(2)), + .sfa_gran_k = static_cast(sfa_gran_k), + .sfa_packed_int32 = sfa.scalar_type() == torch::kInt, + .sfb_stride_group = static_cast(sfb.stride(0)), + .sfb_stride_n = static_cast(sfb.stride(1)), + .sfb_stride_k = static_cast(sfb.stride(2)), + .sfb_gran_k = static_cast(sfb_gran_k), + .sfb_packed_int32 = sfb.scalar_type() == torch::kInt, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + }; + const auto code = SM90MXFP8FP8Gemm1D2DRuntime::generate(args); + const auto runtime = 
compiler->build("sm90_m_grouped_mxfp8_fp8_gemm_masked_1d2d_scale_recipe_v8", code); + SM90MXFP8FP8Gemm1D2DRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py index a9542e2f44..ddea7c59a7 100644 --- a/deep_gemm/__init__.py +++ b/deep_gemm/__init__.py @@ -47,6 +47,8 @@ m_grouped_fp8_gemm_nt_contiguous, m_grouped_fp8_gemm_nn_contiguous, m_grouped_fp8_gemm_nt_masked, + m_grouped_mxfp8_fp8_gemm_nt_contiguous, + m_grouped_mxfp8_fp8_gemm_nt_masked, k_grouped_fp8_gemm_nt_contiguous, k_grouped_fp8_gemm_tn_contiguous, # BF16 GEMMs diff --git a/deep_gemm/include/deep_gemm/impls/sm90_mxfp8_fp8_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_mxfp8_fp8_gemm_1d2d.cuh new file mode 100644 index 0000000000..bd1ea3e56a --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/sm90_mxfp8_fp8_gemm_1d2d.cuh @@ -0,0 +1,394 @@ +#pragma once + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +namespace mxfp8_fp8_detail { + +CUTLASS_DEVICE float e8m0_to_float(uint8_t scale) { + return __uint_as_float(static_cast(scale) << 23); +} + +CUTLASS_DEVICE uint8_t load_e8m0_scale(const void* ptr, uint32_t base_offset, + uint32_t k_scale_idx, uint32_t stride_k, + bool packed_int32) { + if (packed_int32) { + const auto packed = reinterpret_cast(ptr)[base_offset + (k_scale_idx / 4) * stride_k]; + return static_cast((packed >> ((k_scale_idx % 4) * 8)) & 0xff); + } + return reinterpret_cast(ptr)[base_offset + k_scale_idx * stride_k]; +} + +} // namespace mxfp8_fp8_detail + +template +CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +sm90_mxfp8_fp8_gemm_1d2d_impl(void* sfa, void* sfb, int* grouped_layout, + uint32_t sfa_stride_group, uint32_t sfa_stride_m, uint32_t sfa_stride_k, + uint32_t 
sfa_gran_k, bool sfa_packed_int32, + uint32_t sfb_stride_group, uint32_t sfb_stride_n, uint32_t sfb_stride_k, + uint32_t sfb_gran_k, bool sfb_packed_int32, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + DG_STATIC_ASSERT(BLOCK_K == 128, "MXFP8 scale stage assumes 4 K/32 scale groups"); + DG_STATIC_ASSERT(kNumStages >= 1, "Invalid pipeline stages"); + + using WGMMA = typename mma::sm90::FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + + static constexpr uint32_t SMEM_D_SIZE = math::constexpr_align(BLOCK_M * BLOCK_N * static_cast(sizeof(__nv_bfloat16)), 1024u); + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SHAPE_K_SFA_PER_STAGE = BLOCK_K / 32; + static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * SHAPE_K_SFA_PER_STAGE * sizeof(uint8_t); + static constexpr uint32_t ALIGNED_SMEM_SFA_SIZE_PER_STAGE = math::constexpr_align(SMEM_SFA_SIZE_PER_STAGE, 128u); + static constexpr uint32_t SHAPE_K_SFB_PER_STAGE = BLOCK_K / 32; + static constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = BLOCK_N * SHAPE_K_SFB_PER_STAGE * sizeof(uint8_t); + static constexpr uint32_t ALIGNED_SMEM_SFB_SIZE_PER_STAGE = math::constexpr_align(SMEM_SFB_SIZE_PER_STAGE, 128u); + + const uint32_t num_total_k_blocks = math::ceil_div(shape_k, BLOCK_K); + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = ptx::get_lane_idx(); + + if (warp_idx == 
kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + __syncwarp(); + + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + constexpr uint32_t SMEM_SF_OFFSET = SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + auto smem_sfa = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_SF_OFFSET + i * ALIGNED_SMEM_SFA_SIZE_PER_STAGE); + }); + auto smem_sfb = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_SF_OFFSET + kNumStages * ALIGNED_SMEM_SFA_SIZE_PER_STAGE + + i * ALIGNED_SMEM_SFB_SIZE_PER_STAGE); + }); + + auto barrier_start_ptr = reinterpret_cast( + smem_buffer + SMEM_SF_OFFSET + kNumStages * (ALIGNED_SMEM_SFA_SIZE_PER_STAGE + ALIGNED_SMEM_SFB_SIZE_PER_STAGE)); + auto full_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + i; }); + auto empty_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + kNumStages + i; }); + + if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); + } + cutlass::arch::fence_barrier_init(); + } + (kNumTMAMulticast > 1) ? 
cute::cluster_sync() : __syncthreads(); + + constexpr uint32_t kNumTMARegisters = 40; + constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 232; + + cudaGridDependencySynchronize(); + + uint32_t m_block_idx, n_block_idx; + auto scheduler = sched::Scheduler(shape_m, shape_n, shape_k, grouped_layout); + + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + if (warp_idx >= kNumMathThreads / 32) { + cutlass::arch::warpgroup_reg_dealloc(); + if (warp_idx == kNumMathThreads / 32 + 2) { + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); + const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? 
kNumTMAMulticast : 1; + DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); + + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + const bool is_producer_leader = cute::elect_one_sync(); + if (is_producer_leader) + empty_barriers[stage_idx]->wait(phase ^ 1); + __syncwarp(); + + constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; + auto& full_barrier = *full_barriers[stage_idx]; + const uint32_t k_idx = k_block_idx * BLOCK_K; + if (is_producer_leader) { + tma::copy(&tensor_map_a, &full_barrier, + smem_a[stage_idx], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx), + num_tma_multicast_a); + tma::copy(&tensor_map_b, &full_barrier, + smem_b[stage_idx], k_idx, scheduler.get_global_idx(shape_n, BLOCK_N, n_block_idx, m_block_idx), + num_tma_multicast_b); + } + + const uint32_t sfa_base_m = scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx); + for (uint32_t i = lane_idx; i < BLOCK_M * SHAPE_K_SFA_PER_STAGE; i += 32) { + const uint32_t m_offset = i / SHAPE_K_SFA_PER_STAGE; + const uint32_t k_scale_offset = i % SHAPE_K_SFA_PER_STAGE; + const uint32_t m_idx = sfa_base_m + m_offset; + const uint32_t k_idx_for_scale = k_block_idx * BLOCK_K + k_scale_offset * 32; + const uint32_t k_scale_idx = k_idx_for_scale / sfa_gran_k; + const bool is_valid = m_idx < shape_m * (kMasked ? kNumGroups : 1) and + k_idx_for_scale < shape_k; + const uint32_t sfa_local_m = kMasked ? (m_idx - scheduler.current_group_idx * shape_m) : m_idx; + const uint32_t sfa_base_offset = (kMasked ? scheduler.current_group_idx * sfa_stride_group : 0) + + sfa_local_m * sfa_stride_m; + smem_sfa[stage_idx][i] = is_valid ? 
+ mxfp8_fp8_detail::load_e8m0_scale( + sfa, sfa_base_offset, k_scale_idx, sfa_stride_k, sfa_packed_int32) : + static_cast(127); + } + + for (uint32_t i = lane_idx; i < BLOCK_N * SHAPE_K_SFB_PER_STAGE; i += 32) { + const uint32_t n_offset = i / SHAPE_K_SFB_PER_STAGE; + const uint32_t k_scale_offset = i % SHAPE_K_SFB_PER_STAGE; + const uint32_t n_idx = n_block_idx * BLOCK_N + n_offset; + const uint32_t k_idx_for_scale = k_block_idx * BLOCK_K + k_scale_offset * 32; + const uint32_t k_scale_idx = k_idx_for_scale / sfb_gran_k; + const bool is_valid = n_idx < shape_n and k_idx_for_scale < shape_k; + const uint32_t sfb_group_idx = kMasked ? + scheduler.current_group_idx : + static_cast(cute::max(0, grouped_layout[m_block_idx * BLOCK_M])); + const uint32_t sfb_base_offset = sfb_group_idx * sfb_stride_group + + n_idx * sfb_stride_n; + smem_sfb[stage_idx][i] = is_valid ? + mxfp8_fp8_detail::load_e8m0_scale( + sfb, sfb_base_offset, k_scale_idx, sfb_stride_k, sfb_packed_int32) : + static_cast(127); + } + __threadfence_block(); + __syncwarp(); + + if (is_producer_leader) + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + } + } + + if constexpr (kNumTMAMulticast > 1) { + for (uint32_t i = 0; i < kNumStages; advance_pipeline(i)) + empty_barriers[stage_idx]->wait(phase ^ 1); + } + } + } else { + cutlass::arch::warpgroup_reg_alloc(); + + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8; + + auto a_desc = mma::sm90::make_smem_desc(smem_a[0] + math_wg_idx * WGMMA::M * BLOCK_K, 1); + auto b_desc = mma::sm90::make_smem_desc(smem_b[0], 1); + const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0); + const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0); + + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + constexpr uint32_t WAVE_BLOCK_M = BLOCK_M <= WGMMA::M ? 
BLOCK_M : WGMMA::M * 2; + DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes"); + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0}; + + DG_STATIC_ASSERT(BLOCK_M >= 64 or kNumMathThreads == 128, "Only one math warp group for BLOCK_M < 64"); + constexpr uint32_t kNumWGMMAStoreThreads = WAVE_BLOCK_M * (128 / WGMMA::M); + const bool do_wgmma_store = BLOCK_M >= WGMMA::M or warp_idx < kNumWGMMAStoreThreads / 32; + + auto empty_barrier_arrive = [&]() { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[stage_idx]->arrive() : void(); + } else { + auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); + lane_idx < kNumTMAMulticast ? empty_barriers[stage_idx]->arrive(target_cta) : void(); + } + }; + + if (scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M)) { + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + const auto a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16); + const auto b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16); + full_barriers[stage_idx]->wait(phase); + + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + + // Pack-load all SFA/SFB bytes for this wave_block from SMEM: + // SFA row stride is SHAPE_K_SFA_PER_STAGE (== BLOCK_K/32 == 4) bytes, + // so 4 SFA bytes (one per kk) are loaded with a single 32-bit LDS. + // SFB has the same layout; two adjacent N rows (n_base, n_base+1) form + // 8 contiguous bytes that we fetch with a single ld.shared.v2.u32. 
+ uint32_t sfa_pack_0 = 0, sfa_pack_1 = 0; + uint32_t sfb_pack[WGMMA::kNumAccum / 4][2]; + if (do_wgmma_store) { + sfa_pack_0 = ptx::ld_shared(reinterpret_cast( + smem_sfa[stage_idx] + (r_0 + m_offset) * SHAPE_K_SFA_PER_STAGE)); + sfa_pack_1 = ptx::ld_shared(reinterpret_cast( + smem_sfa[stage_idx] + (r_1 + m_offset) * SHAPE_K_SFA_PER_STAGE)); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + const uint32_t n_scale_offset = i * 8 + (lane_idx % 4) * 2; + asm volatile("ld.shared.v2.u32 {%0, %1}, [%2];\n" + : "=r"(sfb_pack[i][0]), "=r"(sfb_pack[i][1]) + : "l"(__cvta_generic_to_shared( + smem_sfb[stage_idx] + + n_scale_offset * SHAPE_K_SFB_PER_STAGE))); + } + } + + #pragma unroll + for (uint32_t kk = 0; kk < BLOCK_K / WGMMA::K; ++ kk) { + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); + a_desc.reg32_[0] = a_desc_base_lo + (m_offset * BLOCK_K + kk * WGMMA::K) / 16; + b_desc.reg32_[0] = b_desc_base_lo + kk * WGMMA::K / 16; + WGMMA::wgmma(a_desc, b_desc, accum, false); + ptx::warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_wait<0>(); + + if (not do_wgmma_store) + continue; + + const float scale_a_0 = mxfp8_fp8_detail::e8m0_to_float( + static_cast((sfa_pack_0 >> (kk * 8)) & 0xff)); + const float scale_a_1 = mxfp8_fp8_detail::e8m0_to_float( + static_cast((sfa_pack_1 >> (kk * 8)) & 0xff)); + auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + const float scale_b_0 = mxfp8_fp8_detail::e8m0_to_float( + static_cast((sfb_pack[i][0] >> (kk * 8)) & 0xff)); + const float scale_b_1 = mxfp8_fp8_detail::e8m0_to_float( + static_cast((sfb_pack[i][1] >> (kk * 8)) & 0xff)); + shifted_accum[i * 4 + 0] += scale_a_0 * scale_b_0 * accum[i * 4 + 0]; + shifted_accum[i * 4 + 1] += scale_a_0 * 
scale_b_1 * accum[i * 4 + 1]; + shifted_accum[i * 4 + 2] += scale_a_1 * scale_b_0 * accum[i * 4 + 2]; + shifted_accum[i * 4 + 3] += scale_a_1 * scale_b_1 * accum[i * 4 + 3]; + } + } + + if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1) + empty_barrier_arrive(); + } + } + } else { + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + full_barriers[stage_idx]->wait(phase); + empty_barrier_arrive(); + } + } + + constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16); + constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes); + constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4; + DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom"); + DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32, + "Unaligned TMA store or too many TMA store instructions"); + DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N"); + + if (not do_wgmma_store) + continue; + + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) + cute::tma_store_wait<0>(); + cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 1); + + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + uint8_t* smem_ptr = nullptr; + if constexpr (kSwizzleDMode > 0) { + constexpr uint32_t kNumBankGroupBytes = 16; + auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8); + auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes); + constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? 
(in_atom_offset) : (bank_group_index % 8); + col ^= row % (kSwizzleDMode / 16); + smem_ptr = reinterpret_cast(smem_d) + + warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + + m_offset * kSwizzleDMode + + atom_offset * BLOCK_M * kSwizzleDMode + + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; + } else { + smem_ptr = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8); + } + + ptx::SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}), + __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}), + smem_ptr + ); + } + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 1); + + constexpr bool kWithGroupOffsetD = kGemmType == GemmType::MGroupedMasked; + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { + auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N; + auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M; + auto n_idx = n_block_idx * BLOCK_N + in_block_n_offset; + auto m_idx = scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx); + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, n_idx, m_idx); + cute::tma_store_arrive(); + } + __syncwarp(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only supports sm_90a"); +#endif +} + +} // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/tests/test_sm90_mxfp8_fp8.py b/tests/test_sm90_mxfp8_fp8.py new file mode 100644 index 0000000000..6c2ba76f4c --- /dev/null +++ b/tests/test_sm90_mxfp8_fp8.py @@ -0,0 +1,501 @@ +import sys +from pathlib import Path +from typing import Callable + +import pytest +import torch + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +import deep_gemm +from deep_gemm.testing import bench, calc_diff +from deep_gemm.utils.math import per_token_cast_to_fp8 + + +def _require_sm90() -> None: + if not torch.cuda.is_available(): + pytest.skip("CUDA is required for 
SM90 MXFP8FP8 tests") + major, _ = torch.cuda.get_device_capability() + if major != 9: + pytest.skip(f"SM90 MXFP8FP8 tests require sm_90, got sm_{major}x") + + +def _cast_back_from_fp8_1d(x: torch.Tensor, sf: torch.Tensor, gran_k: int) -> torch.Tensor: + group_idx = torch.arange(x.size(-1), device=x.device) // gran_k + return x.float() * sf[..., group_idx] + + +def _e8m0_from_fp32_pow2(sf: torch.Tensor) -> torch.Tensor: + assert sf.dtype == torch.float32 or sf.dtype == torch.float + return ((sf.view(torch.int32) >> 23) & 0xFF).to(torch.uint8) + + +def _pack_ue8m0_u8_to_i32(sf: torch.Tensor) -> torch.Tensor: + assert sf.dtype == torch.uint8 + if sf.shape[-1] % 4 != 0: + padded = torch.zeros( + (*sf.shape[:-1], ((sf.shape[-1] + 3) // 4) * 4), + device=sf.device, + dtype=sf.dtype, + ) + padded[..., : sf.shape[-1]] = sf + sf = padded + sf_i32 = sf.contiguous().view(*sf.shape[:-1], sf.shape[-1] // 4, 4).to(torch.int32) + return ( + sf_i32[..., 0] + | torch.bitwise_left_shift(sf_i32[..., 1], 8) + | torch.bitwise_left_shift(sf_i32[..., 2], 16) + | torch.bitwise_left_shift(sf_i32[..., 3], 24) + ).contiguous() + + +def _pack_ue8m0_u8_to_i32_mn_major(sf: torch.Tensor) -> torch.Tensor: + packed = _pack_ue8m0_u8_to_i32(sf) + return packed.transpose(-1, -2).contiguous().transpose(-1, -2) + + +def _fp32_from_e8m0_u8(sf: torch.Tensor) -> torch.Tensor: + return torch.bitwise_left_shift(sf.to(torch.int32), 23).contiguous().view(torch.float32) + + +def test_packed_ue8m0_i32_byte_order_matches_sm100_layout(): + _require_sm90() + import deep_gemm.utils.layout + + sf = torch.tensor( + [[2.0, 4.0, 8.0, 16.0], [32.0, 64.0, 128.0, 256.0]], + device="cuda", + dtype=torch.float32, + ) + expected = _pack_ue8m0_u8_to_i32(_e8m0_from_fp32_pow2(sf)) + packed = deep_gemm.utils.layout.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf) + + assert torch.equal(packed.cpu(), expected.cpu()) + + +def _tflops(m: int, n: int, k: int, elapsed: float) -> float: + return 2.0 * m * n * k / elapsed / 1e12 + + 
def _time_kernel(fn: Callable[[], None]) -> float:
    """Call ``fn`` once, then return whatever ``bench`` measures for it.

    The first call is made outside the measured region. ``bench`` is asked for
    5 warm-up runs and 10 timed runs. The perf table below multiplies the
    result by 1e6 and labels it "us", so the value is presumably in seconds
    (confirm against ``deep_gemm.testing.bench``).
    """
    fn()
    return bench(fn, num_warmups=5, num_tests=10)


def _cast_groups_to_fp8(ref: torch.Tensor, use_ue8m0: bool, gran_k: int):
    """Quantize a ``(groups, rows, k)`` tensor one group at a time.

    Returns ``(data, sf)`` where ``data`` is ``float8_e4m3fn`` with the shape of
    ``ref`` and ``sf`` is ``float32`` with shape ``(groups, rows, k // gran_k)``,
    each slice filled from ``per_token_cast_to_fp8``.
    """
    groups, rows, k = ref.shape
    data = torch.empty((groups, rows, k), device="cuda", dtype=torch.float8_e4m3fn)
    sf = torch.empty((groups, rows, k // gran_k), device="cuda", dtype=torch.float32)
    for group_id in range(groups):
        data[group_id], sf[group_id] = per_token_cast_to_fp8(
            ref[group_id], use_ue8m0=use_ue8m0, gran_k=gran_k
        )
    return data, sf


def _group_major_layout(groups: int, m_per_group: int) -> torch.Tensor:
    """Return an int32 vector ``[0]*m_per_group + [1]*m_per_group + ...`` on CUDA."""
    return torch.arange(groups, device="cuda", dtype=torch.int32).repeat_interleave(m_per_group)


def _contiguous_reference(
    a_dequant: torch.Tensor,
    b_data: torch.Tensor,
    b_sf_fp32: torch.Tensor,
    m_per_group: int,
    gran_k_b: int = 32,
) -> torch.Tensor:
    """Build the BF16 reference for the grouped-contiguous layout.

    Rows ``[g * m_per_group, (g + 1) * m_per_group)`` of ``a_dequant`` are
    multiplied with the dequantized, transposed ``b_data[g]``.
    """
    groups, n = b_data.size(0), b_data.size(1)
    ref = torch.empty((groups * m_per_group, n), device="cuda", dtype=torch.bfloat16)
    for group_id in range(groups):
        start = group_id * m_per_group
        end = start + m_per_group
        b_dequant = _cast_back_from_fp8_1d(b_data[group_id], b_sf_fp32[group_id], gran_k=gran_k_b)
        ref[start:end] = (a_dequant[start:end] @ b_dequant.t()).to(torch.bfloat16)
    return ref


def _masked_reference(
    a_dequant: torch.Tensor,
    b_data: torch.Tensor,
    b_sf_fp32: torch.Tensor,
    masked_m: torch.Tensor,
    gran_k_b: int = 32,
) -> torch.Tensor:
    """Build the BF16 reference for the masked layout.

    Only the first ``masked_m[g]`` rows of each group are computed; the
    remaining rows stay zero.
    """
    groups, max_m = a_dequant.size(0), a_dequant.size(1)
    n = b_data.size(1)
    ref = torch.zeros((groups, max_m, n), device="cuda", dtype=torch.bfloat16)
    for group_id, valid_m in enumerate(masked_m.tolist()):
        b_dequant = _cast_back_from_fp8_1d(b_data[group_id], b_sf_fp32[group_id], gran_k=gran_k_b)
        ref[group_id, :valid_m] = (a_dequant[group_id, :valid_m] @ b_dequant.t()).to(torch.bfloat16)
    return ref


def _max_masked_diff(d: torch.Tensor, ref: torch.Tensor, masked_m: torch.Tensor):
    """Return the largest per-group ``calc_diff`` over the valid rows of each group."""
    return max(
        calc_diff(d[group_id, :valid_m], ref[group_id, :valid_m])
        for group_id, valid_m in enumerate(masked_m.tolist())
    )


def _make_contiguous_case(groups: int, m_per_group: int, n: int, k: int):
    """Create a grouped-contiguous case with raw uint8 E8M0 scales (gran_k=32).

    Returns ``(a, b, grouped_layout, ref)`` where ``a`` and ``b`` are
    ``(fp8_data, uint8_scale)`` pairs and ``ref`` is the BF16 reference output.
    """
    m = groups * m_per_group
    a_ref = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
    b_ref = torch.randn((groups, n, k), device="cuda", dtype=torch.bfloat16)

    a_data, a_sf_fp32 = per_token_cast_to_fp8(a_ref, use_ue8m0=True, gran_k=32)
    b_data, b_sf_fp32 = _cast_groups_to_fp8(b_ref, use_ue8m0=True, gran_k=32)

    grouped_layout = _group_major_layout(groups, m_per_group)
    a = (a_data, _e8m0_from_fp32_pow2(a_sf_fp32))
    a_dequant = _cast_back_from_fp8_1d(a_data, a_sf_fp32, gran_k=32)
    ref = _contiguous_reference(a_dequant, b_data, b_sf_fp32, m_per_group)
    return a, (b_data, _e8m0_from_fp32_pow2(b_sf_fp32)), grouped_layout, ref


def test_m_grouped_mxfp8_fp8_contiguous_e8m0_scale_accuracy():
    """Grouped-contiguous GEMM with raw uint8 E8M0 scales and default recipes."""
    _require_sm90()
    # Fixed seed so a threshold failure can be reproduced.
    torch.manual_seed(2)
    # SM90 grouped-contiguous WGMMA/TMA maps one B group per M tile.
    groups, m_per_group, n, k = 2, 128, 48, 128
    a, b, grouped_layout, ref = _make_contiguous_case(groups, m_per_group, n, k)
    d = torch.empty_like(ref)

    deep_gemm.m_grouped_mxfp8_fp8_gemm_nt_contiguous(a, b, d, grouped_layout)
    diff = calc_diff(d, ref)
    assert diff < 0.03


def test_m_grouped_mxfp8_fp8_masked_e8m0_scale_accuracy():
    """Masked grouped GEMM with raw uint8 E8M0 scales; only valid rows are compared."""
    _require_sm90()
    # Fixed seed so a threshold failure can be reproduced.
    torch.manual_seed(3)
    groups, max_m, n, k = 2, 32, 48, 128
    masked_m = torch.tensor([7, 19], device="cuda", dtype=torch.int32)
    a_ref = torch.randn((groups, max_m, k), device="cuda", dtype=torch.bfloat16)
    b_ref = torch.randn((groups, n, k), device="cuda", dtype=torch.bfloat16)

    a_data, a_sf_fp32 = _cast_groups_to_fp8(a_ref, use_ue8m0=True, gran_k=32)
    b_data, b_sf_fp32 = _cast_groups_to_fp8(b_ref, use_ue8m0=True, gran_k=32)

    a = (a_data, _e8m0_from_fp32_pow2(a_sf_fp32))
    b = (b_data, _e8m0_from_fp32_pow2(b_sf_fp32))
    a_dequant = _cast_back_from_fp8_1d(a_data, a_sf_fp32, gran_k=32)
    ref = _masked_reference(a_dequant, b_data, b_sf_fp32, masked_m)

    # Rows past masked_m[g] are never compared, so the output may start uninitialized.
    d = torch.empty_like(ref)
    deep_gemm.m_grouped_mxfp8_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m=max_m)
    diff = _max_masked_diff(d, ref, masked_m)
    assert diff < 0.03


def test_m_grouped_mxfp8_fp8_contiguous_packed_int32_scale_accuracy():
    """Grouped-contiguous GEMM with both scale tensors packed four-per-int32.

    A uses gran_k=128 and B uses gran_k=32, passed explicitly through the recipes.
    """
    _require_sm90()
    # Fixed seed so a threshold failure can be reproduced.
    torch.manual_seed(4)
    groups, m_per_group, n, k = 2, 128, 48, 640
    m = groups * m_per_group
    a_ref = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
    b_ref = torch.randn((groups, n, k), device="cuda", dtype=torch.bfloat16)

    a_data, a_sf_fp32 = per_token_cast_to_fp8(a_ref, use_ue8m0=True, gran_k=128)
    b_data, b_sf_fp32 = _cast_groups_to_fp8(b_ref, use_ue8m0=True, gran_k=32)

    a = (a_data, _pack_ue8m0_u8_to_i32(_e8m0_from_fp32_pow2(a_sf_fp32)))
    b = (b_data, _pack_ue8m0_u8_to_i32(_e8m0_from_fp32_pow2(b_sf_fp32)))
    grouped_layout = _group_major_layout(groups, m_per_group)

    a_dequant = _cast_back_from_fp8_1d(a_data, a_sf_fp32, gran_k=128)
    ref = _contiguous_reference(a_dequant, b_data, b_sf_fp32, m_per_group)

    d = torch.empty_like(ref)
    deep_gemm.m_grouped_mxfp8_fp8_gemm_nt_contiguous(
        a, b, d, grouped_layout, recipe_a=(1, 128), recipe_b=(1, 32)
    )
    diff = calc_diff(d, ref)
    assert diff < 0.03


def test_m_grouped_mxfp8_fp8_contiguous_deepep_normal_scale_layout_accuracy():
    """Grouped-contiguous GEMM with synthetic exponents in the DeepEP-normal scale layout."""
    _require_sm90()
    torch.manual_seed(0)
    # Matches SGLang DeepEP normal layout:
    #   A scale: packed int32 MN-major non-contiguous view, gran_k=128
    #   B scale: raw uint8 [expert, n, k/32], gran_k=32
    groups, m_per_group, n, k = 3, 128, 80, 640
    m = groups * m_per_group
    a_ref = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
    b_ref = torch.randn((groups, n, k), device="cuda", dtype=torch.bfloat16)
    # Only the FP8 payloads are kept; the scales are replaced by synthetic exponents below.
    a_data, _ = per_token_cast_to_fp8(a_ref, use_ue8m0=True, gran_k=128)
    b_data, _ = _cast_groups_to_fp8(b_ref, use_ue8m0=True, gran_k=32)

    # Keep exponents near 127 (scale 1.0). Wider synthetic ranges produce very
    # large BF16 outputs, where a normal one-ULP BF16 difference has a misleading
    # absolute error while relative/cosine error is still essentially zero.
    a_exp = (
        126
        + (torch.arange(m, device="cuda", dtype=torch.uint8).view(m, 1) % 2)
        + (torch.arange(k // 128, device="cuda", dtype=torch.uint8).view(1, -1) % 2)
    )
    b_exp = (
        126
        + (torch.arange(groups, device="cuda", dtype=torch.uint8).view(groups, 1, 1) % 2)
        + (torch.arange(n, device="cuda", dtype=torch.uint8).view(1, n, 1) % 2)
        + (torch.arange(k // 32, device="cuda", dtype=torch.uint8).view(1, 1, -1) % 2)
    )

    a_scale_i32 = _pack_ue8m0_u8_to_i32_mn_major(a_exp)
    b_scale_u8 = b_exp.contiguous()
    a = (a_data, a_scale_i32)
    b = (b_data, b_scale_u8)
    grouped_layout = _group_major_layout(groups, m_per_group)

    a_dequant = _cast_back_from_fp8_1d(a_data, _fp32_from_e8m0_u8(a_exp), gran_k=128)
    b_scale_fp32 = _fp32_from_e8m0_u8(b_exp)
    ref = _contiguous_reference(a_dequant, b_data, b_scale_fp32, m_per_group)

    d = torch.empty_like(ref)
    deep_gemm.m_grouped_mxfp8_fp8_gemm_nt_contiguous(
        a, b, d, grouped_layout, recipe_a=(1, 128), recipe_b=(1, 32)
    )
    diff = calc_diff(d, ref)
    max_abs_diff = (d.float() - ref.float()).abs().max().item()
    ref_absmax = ref.float().abs().max().item()
    max_rel_diff = max_abs_diff / max(ref_absmax, 1.0)
    print(
        "DeepEP-normal scale layout diff: "
        f"calc_diff={diff:.6f}, max_abs_diff={max_abs_diff:.6f}, "
        f"ref_absmax={ref_absmax:.6f}, max_rel_diff={max_rel_diff:.6f}, "
        f"a_scale_shape={tuple(a_scale_i32.shape)}, a_scale_stride={tuple(a_scale_i32.stride())}, "
        f"b_scale_shape={tuple(b_scale_u8.shape)}, b_scale_stride={tuple(b_scale_u8.stride())}"
    )
    assert diff < 0.03


def test_m_grouped_mxfp8_fp8_contiguous_dense_linear_raw_u8_scale_accuracy():
    """Single-group, padded-M GEMM with synthetic raw uint8 exponents on A and B.

    Also checks that the output is clearly closer to the reference than to a
    reference built from mirrored exponents (``254 - e``).
    """
    _require_sm90()
    torch.manual_seed(1)
    # Matches SGLang dense linear through the SM90 grouped-contiguous wrapper:
    # one logical RHS group, raw uint8 UE8M0 scales on both A and B, and padded M.
    m, padded_m, n, k = 137, 256, 96, 640
    a_ref = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
    b_ref = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
    a_data, _ = per_token_cast_to_fp8(a_ref, use_ue8m0=True, gran_k=32)
    b_data, _ = per_token_cast_to_fp8(b_ref, use_ue8m0=True, gran_k=32)

    a_exp = (
        124
        + (torch.arange(m, device="cuda", dtype=torch.uint8).view(m, 1) % 5)
        + (torch.arange(k // 32, device="cuda", dtype=torch.uint8).view(1, -1) % 3)
    )
    b_exp = (
        124
        + (torch.arange(n, device="cuda", dtype=torch.uint8).view(n, 1) % 5)
        + (torch.arange(k // 32, device="cuda", dtype=torch.uint8).view(1, -1) % 3)
    )

    # Pad A (data and scales) with zeros up to padded_m; padded rows get layout index -1.
    kernel_a = torch.zeros((padded_m, k), device="cuda", dtype=torch.float8_e4m3fn)
    kernel_a[:m] = a_data
    kernel_a_scale = torch.zeros((padded_m, k // 32), device="cuda", dtype=torch.uint8)
    kernel_a_scale[:m] = a_exp
    kernel_b = b_data.unsqueeze(0).contiguous()
    kernel_b_scale = b_exp.unsqueeze(0).contiguous()
    m_indices = torch.full((padded_m,), -1, device="cuda", dtype=torch.int32)
    m_indices[:m] = 0

    a_dequant = _cast_back_from_fp8_1d(a_data, _fp32_from_e8m0_u8(a_exp), gran_k=32)
    b_dequant = _cast_back_from_fp8_1d(b_data, _fp32_from_e8m0_u8(b_exp), gran_k=32)
    ref = (a_dequant @ b_dequant.t()).to(torch.bfloat16)

    d_padded = torch.empty((padded_m, n), device="cuda", dtype=torch.bfloat16)
    deep_gemm.m_grouped_mxfp8_fp8_gemm_nt_contiguous(
        (kernel_a, kernel_a_scale),
        (kernel_b, kernel_b_scale),
        d_padded,
        m_indices,
        recipe_a=(1, 32),
        recipe_b=(1, 32),
    )
    d = d_padded[:m]

    # Exponent e mirrored around the bias 127 is 254 - e, i.e. the reciprocal scale.
    inv_a_scale = _fp32_from_e8m0_u8((254 - a_exp).to(torch.uint8))
    inv_b_scale = _fp32_from_e8m0_u8((254 - b_exp).to(torch.uint8))
    inv_ref = (
        _cast_back_from_fp8_1d(a_data, inv_a_scale, gran_k=32)
        @ _cast_back_from_fp8_1d(b_data, inv_b_scale, gran_k=32).t()
    ).to(torch.bfloat16)

    diff = calc_diff(d, ref)
    inverse_diff = calc_diff(d, inv_ref)
    max_abs_diff = (d.float() - ref.float()).abs().max().item()
    ref_absmax = ref.float().abs().max().item()
    max_rel_diff = max_abs_diff / max(ref_absmax, 1.0)
    print(
        "Dense raw-u8 scale layout diff: "
        f"calc_diff={diff:.6f}, inverse_scale_calc_diff={inverse_diff:.6f}, "
        f"max_abs_diff={max_abs_diff:.6f}, ref_absmax={ref_absmax:.6f}, "
        f"max_rel_diff={max_rel_diff:.6f}, "
        f"a_scale_shape={tuple(kernel_a_scale.shape)}, "
        f"b_scale_shape={tuple(kernel_b_scale.shape)}"
    )
    assert diff < 0.03
    assert inverse_diff > diff + 0.1


def test_m_grouped_mxfp8_fp8_masked_packed_int32_mn_major_scale_accuracy():
    """Masked grouped GEMM with MN-major packed-int32 A scales and raw uint8 B scales."""
    _require_sm90()
    # Fixed seed so a threshold failure can be reproduced.
    torch.manual_seed(5)
    groups, max_m, n, k = 3, 128, 64, 640
    masked_m = torch.tensor([7, 65, 113], device="cuda", dtype=torch.int32)
    a_ref = torch.randn((groups, max_m, k), device="cuda", dtype=torch.bfloat16)
    b_ref = torch.randn((groups, n, k), device="cuda", dtype=torch.bfloat16)

    a_data, a_sf_fp32 = _cast_groups_to_fp8(a_ref, use_ue8m0=True, gran_k=128)
    b_data, b_sf_fp32 = _cast_groups_to_fp8(b_ref, use_ue8m0=True, gran_k=32)

    a = (a_data, _pack_ue8m0_u8_to_i32_mn_major(_e8m0_from_fp32_pow2(a_sf_fp32)))
    b = (b_data, _e8m0_from_fp32_pow2(b_sf_fp32))
    a_dequant = _cast_back_from_fp8_1d(a_data, a_sf_fp32, gran_k=128)
    ref = _masked_reference(a_dequant, b_data, b_sf_fp32, masked_m)

    d = torch.empty_like(ref)
    deep_gemm.m_grouped_mxfp8_fp8_gemm_nt_masked(
        a, b, d, masked_m, expected_m=max_m, recipe_a=(1, 128), recipe_b=(1, 32)
    )
    diff = _max_masked_diff(d, ref, masked_m)
    assert diff < 0.03


def test_m_grouped_mxfp8_vs_fp8_perf_contiguous_and_masked():
    """Time the MXFP8 kernels against the FP8 (gran_k=128) kernels and print a table.

    This is a smoke benchmark: it only asserts that every timing is positive and
    that the MXFP8-vs-FP8 output differences are not NaN.
    """
    import math

    _require_sm90()
    # Fixed seed so the printed diffs are comparable between runs.
    torch.manual_seed(6)
    groups, n, k = 4, 1024, 1024

    # Contiguous: one B group per M tile.
    m_per_group = 128
    m = groups * m_per_group
    a_ref = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
    b_ref = torch.randn((groups, n, k), device="cuda", dtype=torch.bfloat16)
    a_data, a_mx_sf_fp32 = per_token_cast_to_fp8(a_ref, use_ue8m0=True, gran_k=32)
    a = (a_data, _e8m0_from_fp32_pow2(a_mx_sf_fp32))
    a_fp8 = per_token_cast_to_fp8(a_ref, use_ue8m0=False, gran_k=128)
    grouped_layout = _group_major_layout(groups, m_per_group)

    b_mx_data, b_mx_sf_fp32 = _cast_groups_to_fp8(b_ref, use_ue8m0=True, gran_k=32)
    b_fp8_data, b_fp8_sf = _cast_groups_to_fp8(b_ref, use_ue8m0=False, gran_k=128)
    b_mx = (b_mx_data, _e8m0_from_fp32_pow2(b_mx_sf_fp32))
    b_fp8 = (b_fp8_data, b_fp8_sf)
    d_mx = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
    d_fp8 = torch.empty_like(d_mx)

    def run_mx_contiguous():
        deep_gemm.m_grouped_mxfp8_fp8_gemm_nt_contiguous(a, b_mx, d_mx, grouped_layout)

    def run_fp8_contiguous():
        deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
            a_fp8, b_fp8, d_fp8, grouped_layout, recipe_a=(1, 128), recipe_b=(1, 128)
        )

    mx_contiguous_elapsed = _time_kernel(run_mx_contiguous)
    fp8_contiguous_elapsed = _time_kernel(run_fp8_contiguous)
    contiguous_diff = float(calc_diff(d_mx, d_fp8))

    # Masked: same shape class, but allow uneven active rows per group.
    max_m = 128
    masked_m = torch.tensor([128, 96, 64, 32], device="cuda", dtype=torch.int32)
    a_ref_masked = torch.randn((groups, max_m, k), device="cuda", dtype=torch.bfloat16)
    b_ref_masked = torch.randn((groups, n, k), device="cuda", dtype=torch.bfloat16)
    a_masked_data, a_masked_sf_fp32 = _cast_groups_to_fp8(a_ref_masked, use_ue8m0=True, gran_k=32)
    a_fp8_masked_data, a_fp8_masked_sf = _cast_groups_to_fp8(a_ref_masked, use_ue8m0=False, gran_k=128)
    b_mx_masked_data, b_mx_masked_sf_fp32 = _cast_groups_to_fp8(b_ref_masked, use_ue8m0=True, gran_k=32)
    b_fp8_masked_data, b_fp8_masked_sf = _cast_groups_to_fp8(b_ref_masked, use_ue8m0=False, gran_k=128)
    a_masked = (a_masked_data, _e8m0_from_fp32_pow2(a_masked_sf_fp32))
    a_fp8_masked = (a_fp8_masked_data, a_fp8_masked_sf)
    b_mx_masked = (b_mx_masked_data, _e8m0_from_fp32_pow2(b_mx_masked_sf_fp32))
    b_fp8_masked = (b_fp8_masked_data, b_fp8_masked_sf)
    d_mx_masked = torch.empty((groups, max_m, n), device="cuda", dtype=torch.bfloat16)
    d_fp8_masked = torch.empty_like(d_mx_masked)

    def run_mx_masked():
        deep_gemm.m_grouped_mxfp8_fp8_gemm_nt_masked(
            a_masked, b_mx_masked, d_mx_masked, masked_m, expected_m=max_m
        )

    def run_fp8_masked():
        deep_gemm.m_grouped_fp8_gemm_nt_masked(
            a_fp8_masked,
            b_fp8_masked,
            d_fp8_masked,
            masked_m,
            expected_m=max_m,
            recipe_a=(1, 128),
            recipe_b=(1, 128),
        )

    mx_masked_elapsed = _time_kernel(run_mx_masked)
    fp8_masked_elapsed = _time_kernel(run_fp8_masked)
    masked_diff = max(
        float(calc_diff(d_mx_masked[group_id, :valid_m], d_fp8_masked[group_id, :valid_m]))
        for group_id, valid_m in enumerate(masked_m.tolist())
    )

    # TFLOPS for the masked case counts only the active rows.
    masked_active_m = int(masked_m.sum().item())
    rows = [
        (
            "contiguous",
            m,
            mx_contiguous_elapsed,
            fp8_contiguous_elapsed,
            _tflops(m, n, k, mx_contiguous_elapsed),
            _tflops(m, n, k, fp8_contiguous_elapsed),
            contiguous_diff,
        ),
        (
            "masked",
            masked_active_m,
            mx_masked_elapsed,
            fp8_masked_elapsed,
            _tflops(masked_active_m, n, k, mx_masked_elapsed),
            _tflops(masked_active_m, n, k, fp8_masked_elapsed),
            masked_diff,
        ),
    ]
    print("kernel | active M | MXFP8 us | FP8 us | MXFP8 TFLOPS | FP8 TFLOPS | speedup | diff")
    print("-- | -- | -- | -- | -- | -- | -- | --")
    for name, active_m, mx_elapsed, fp8_elapsed, mx_tflops, fp8_tflops, diff in rows:
        print(
            f"{name} | {active_m} | {mx_elapsed * 1e6:.0f} | {fp8_elapsed * 1e6:.0f} | "
            f"{mx_tflops:.1f} | {fp8_tflops:.1f} | {fp8_elapsed / mx_elapsed:.2f}x | {diff:.4f}"
        )

    assert mx_contiguous_elapsed > 0
    assert fp8_contiguous_elapsed > 0
    assert mx_masked_elapsed > 0
    assert fp8_masked_elapsed > 0
    # Same predicate as the self-comparison idiom `x == x`: fails only on NaN.
    assert not math.isnan(contiguous_diff)
    assert not math.isnan(masked_diff)