From 5e012d431084f13011b4f86e4b534024d6b54dd4 Mon Sep 17 00:00:00 2001 From: AichenF Date: Tue, 26 May 2026 11:24:07 +0800 Subject: [PATCH 01/15] Optimize SM90 MegaMoE kernels --- csrc/apis/mega.hpp | 253 +- csrc/jit/compiler.hpp | 6 + csrc/jit/handle.hpp | 7 +- csrc/jit_kernels/heuristics/mega_moe.hpp | 269 ++ csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp | 271 ++ csrc/jit_kernels/impls/sm90_w4a8_mega_moe.hpp | 271 ++ deep_gemm/__init__.py | 3 + deep_gemm/include/deep_gemm/common/math.cuh | 10 +- .../deep_gemm/impls/sm90_fp8_mega_moe.cuh | 2380 +++++++++++++++++ .../deep_gemm/impls/sm90_w4a8_mega_moe.cuh | 2318 ++++++++++++++++ .../deep_gemm/quantization/mxfp4_dequant.cuh | 58 + .../include/deep_gemm/scheduler/mega_moe.cuh | 64 +- deep_gemm/mega/__init__.py | 85 + scripts/run_ncu_mega_moe_sm90.sh | 89 + tests/bench_mega_moe_sm90.py | 250 ++ tests/bench_mega_moe_sm90_skew.py | 305 +++ tests/bench_mega_moe_sm90_skewed.py | 231 ++ tests/bench_w4a8_mega_moe_sm90.py | 250 ++ tests/test_mega_moe.py | 2 +- tests/test_mega_moe_hopper.py | 1525 +++++++++++ tests/test_mega_moe_sm90.py | 528 ++++ 21 files changed, 9145 insertions(+), 30 deletions(-) create mode 100644 csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp create mode 100644 csrc/jit_kernels/impls/sm90_w4a8_mega_moe.hpp create mode 100644 deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh create mode 100644 deep_gemm/include/deep_gemm/impls/sm90_w4a8_mega_moe.cuh create mode 100644 deep_gemm/include/deep_gemm/quantization/mxfp4_dequant.cuh create mode 100755 scripts/run_ncu_mega_moe_sm90.sh create mode 100644 tests/bench_mega_moe_sm90.py create mode 100644 tests/bench_mega_moe_sm90_skew.py create mode 100644 tests/bench_mega_moe_sm90_skewed.py create mode 100644 tests/bench_w4a8_mega_moe_sm90.py create mode 100644 tests/test_mega_moe_hopper.py create mode 100644 tests/test_mega_moe_sm90.py diff --git a/csrc/apis/mega.hpp b/csrc/apis/mega.hpp index efc3a780d1..ea9ff0fbab 100644 --- a/csrc/apis/mega.hpp +++ 
b/csrc/apis/mega.hpp @@ -8,6 +8,8 @@ #endif #include "../jit/device_runtime.hpp" #include "../jit_kernels/impls/sm100_fp8_fp4_mega_moe.hpp" +#include "../jit_kernels/impls/sm90_fp8_mega_moe.hpp" +#include "../jit_kernels/impls/sm90_w4a8_mega_moe.hpp" namespace deep_gemm::mega { @@ -23,6 +25,15 @@ get_symm_buffer_size_for_mega_moe( const bool& use_fp8_dispatch, const std::string& activation) { DG_HOST_ASSERT(num_experts % num_ranks == 0); + // Architecture-dependent SF dtype for the user-facing tensor view: + // * SM100: per-32 UE8M0 packed 4-into-int (`torch::kInt`). + // * SM90 : per-128 channel float (`torch::kFloat32`). + // Both use the same number of bytes per token (hidden / 32), so the symmetric + // buffer layout is shared; only the slice view dtype changes. + const auto arch_major = device_runtime->get_arch_major(); + const bool is_sm90 = arch_major == 9; + const auto sf_dtype = is_sm90 ? torch::kFloat32 : torch::kInt; + // Workspace bytes const auto workspace = layout::Workspace(nullptr, num_ranks, num_experts, num_max_tokens_per_rank, num_topk); @@ -31,7 +42,16 @@ get_symm_buffer_size_for_mega_moe( const auto bf16_token_layout = layout::Data(hidden * 2); const auto fp8_intermediate_token_layout = layout::Data(intermediate_hidden); const auto fp8_sf_layout = layout::Data(hidden / 32); - const auto fp8_intermediate_sf_layout = layout::Data(intermediate_hidden / 32); + // L2 acts SF granularity differs by arch: + // * SM100 packs 4 UE8M0 bytes per int along K, so each token uses + // `intermediate_hidden / 32` bytes (per-32 K). + // * SM90 stores per-64 K floats so that each L1 epilogue block (which + // produces 64 post-SwiGLU columns) can write its own SF independently + // without cross-CTA amax synchronisation; bytes per token become + // `intermediate_hidden / 64 * sizeof(float) = intermediate_hidden / 16`. + const int fp8_intermediate_sf_bytes_per_token = + is_sm90 ? 
(intermediate_hidden / 16) : (intermediate_hidden / 32); + const auto fp8_intermediate_sf_layout = layout::Data(fp8_intermediate_sf_bytes_per_token); const auto input_topk_idx_layout = layout::Data(num_topk * sizeof(int64_t), false); const auto input_topk_weights_layout = layout::Data(num_topk * sizeof(float), false); const auto l1_topk_weights_layout = layout::Data(sizeof(float), false); @@ -86,10 +106,14 @@ get_symm_buffer_size_for_mega_moe( // Check SF buffer requirements DG_HOST_ASSERT(hidden % 128 == 0 and intermediate_hidden % 128 == 0); - DG_HOST_ASSERT(num_max_padded_sf_pool_tokens % 4 == 0); + // SM100 packs 4 UE8M0 bytes per int along K, so the padded SF token count + // must be divisible by 4. SM90 stores per-128 floats and has no such constraint. + if (not is_sm90) + DG_HOST_ASSERT(num_max_padded_sf_pool_tokens % 4 == 0); // Slice function: creates `(x, x_sf, topk_weights, topk_idx, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf)` tensor views from the raw buffer // NOTES: `x_sf` is K-major, while `l1_acts_sf` and `l2_acts_sf` are M-major + // Dtype is per-arch (see `sf_dtype` above): float on SM90, int (packed UE8M0) on SM100. 
auto slice_input_buffers = [=](const torch::Tensor& buffer) { auto x = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_token_buffer.base)), @@ -98,7 +122,7 @@ get_symm_buffer_size_for_mega_moe( auto x_sf = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_sf_buffer.base)), {num_max_tokens_per_rank, hidden / 128}, - torch::TensorOptions().dtype(torch::kInt).device(buffer.device())); + torch::TensorOptions().dtype(sf_dtype).device(buffer.device())); auto topk_idx = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_topk_idx_buffer.base)), {num_max_tokens_per_rank, num_topk}, @@ -115,16 +139,16 @@ get_symm_buffer_size_for_mega_moe( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l1_sf_buffer.base)), {num_max_padded_sf_pool_tokens, hidden / 128}, {1, num_max_padded_sf_pool_tokens}, - torch::TensorOptions().dtype(torch::kInt).device(buffer.device())); + torch::TensorOptions().dtype(sf_dtype).device(buffer.device())); auto l2_acts = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l2_token_buffer.base)), {num_max_pool_tokens, intermediate_hidden}, torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device())); auto l2_acts_sf = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l2_sf_buffer.base)), - {num_max_padded_sf_pool_tokens, intermediate_hidden / 128}, + {num_max_padded_sf_pool_tokens, is_sm90 ? 
intermediate_hidden / 64 : intermediate_hidden / 128}, {1, num_max_padded_sf_pool_tokens}, - torch::TensorOptions().dtype(torch::kInt).device(buffer.device())); + torch::TensorOptions().dtype(sf_dtype).device(buffer.device())); return std::make_tuple(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf); }; return {reinterpret_cast(combine_token_buffer.get_end_ptr()), slice_input_buffers}; @@ -182,7 +206,10 @@ static void fp8_fp4_mega_moe( // Check stats counter if (cumulative_local_expert_recv_stats.has_value()) { DG_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt); - DG_HOST_ASSERT(cumulative_local_expert_recv_stats->numel() == num_experts_per_rank); + const auto stats_numel = cumulative_local_expert_recv_stats->numel(); + const bool phase_profile = get_env("DG_SM90_MOE_PHASE_PROFILE", 0) != 0; + DG_HOST_ASSERT(stats_numel == num_experts_per_rank or + (phase_profile and stats_numel >= num_experts_per_rank + 64)); DG_HOST_ASSERT(cumulative_local_expert_recv_stats->is_contiguous()); } @@ -224,11 +251,223 @@ static void fp8_fp4_mega_moe( sym_buffer.zero_(); } +// SM90 (Hopper) FP8 MegaMoE entry point. +// +// Mirrors `fp8_fp4_mega_moe` but expects FP8 (e4m3) weights with per-128 channel +// float scale factors. Top-level routing (which entry to call) is the caller's +// responsibility (see `deep_gemm/mega/__init__.py`). 
+static void fp8_mega_moe( + const torch::Tensor& y, + const std::tuple& l1_weights_tuple, + const std::tuple& l2_weights_tuple, + const std::optional& cumulative_local_expert_recv_stats, + const torch::Tensor& sym_buffer, + const std::vector& sym_buffer_ptrs, const int& rank_idx, + const int& num_max_tokens_per_rank, + const int& num_experts, const int& num_topk, + const std::tuple& recipe, + const std::string& activation, + const std::optional& activation_clamp_opt, + const bool& fast_math +) { + const auto [l1_weights, l1_weights_sf] = l1_weights_tuple; + const auto [l2_weights, l2_weights_sf] = l2_weights_tuple; + + // Architecture check + const auto arch_major = device_runtime->get_arch_major(); + DG_HOST_ASSERT(arch_major == 9); + + // Config checks: SM90 uses block (128, 128) float SF for weights, + // per-token per-128-K float SF for activations. + const auto num_tokens = static_cast(y.size(0)); + const auto [rm, rn, rk] = recipe; + DG_HOST_ASSERT(rm == 128 and rn == 128 and rk == 128); + DG_HOST_ASSERT(activation == "swiglu"); + + // Activation checks + const auto activation_clamp = + activation_clamp_opt.value_or(std::numeric_limits::infinity()); + DG_HOST_ASSERT(activation_clamp >= 0); + + // Tensor checks: SM90 weights must be FP8 e4m3, K-major + DG_HOST_ASSERT(get_major_type_ab(l1_weights) == cute::UMMA::Major::K); + DG_HOST_ASSERT(get_major_type_ab(l2_weights) == cute::UMMA::Major::K); + DG_HOST_ASSERT(l1_weights.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(l2_weights.scalar_type() == torch::kFloat8_e4m3fn); + const auto [num_experts_per_rank, intermediate_hidden_2, hidden] = get_shape<3>(l1_weights); + const auto [num_experts_per_rank_, hidden_, intermediate_hidden] = get_shape<3>(l2_weights); + DG_HOST_ASSERT(num_tokens <= num_max_tokens_per_rank); + DG_HOST_ASSERT(num_experts_per_rank == num_experts_per_rank_); + DG_HOST_ASSERT(hidden == hidden_); + DG_HOST_ASSERT(intermediate_hidden_2 == 2 * intermediate_hidden); + 
DG_HOST_ASSERT(l1_weights.is_contiguous() and l2_weights.is_contiguous()); + + // Shape constraints required by the SM90 kernel: + // * Hidden dims must be multiples of 128 (per-128 SF + scheduler integer-tiling). + // * `l2_arrival_mask` is uint64, with one bit per L1-output N-block of size 64 in the + // intermediate dim, so `kNumL1BlockNs = intermediate_hidden / 64` must be ≤ 64. + DG_HOST_ASSERT(hidden % 128 == 0 and intermediate_hidden % 128 == 0); + DG_HOST_ASSERT(intermediate_hidden / 64 <= 64); + + // Check weight SF layout (block (128, 128) float, MN-major; not TMA-loaded + // so no TMA-stride alignment is required, but we do require contiguity in + // the K-direction within each expert). + constexpr int kGranMN = 128, kGranK = 128; + check_sf_layout(l1_weights_sf, intermediate_hidden * 2, hidden, kGranMN, kGranK, + num_experts_per_rank, false, true, torch::kFloat); + check_sf_layout(l2_weights_sf, hidden, intermediate_hidden, kGranMN, kGranK, + num_experts_per_rank, false, true, torch::kFloat); + + // Check stats counter + if (cumulative_local_expert_recv_stats.has_value()) { + DG_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt); + const auto stats_numel = cumulative_local_expert_recv_stats->numel(); + const bool phase_profile = get_env("DG_SM90_MOE_PHASE_PROFILE", 0) != 0; + DG_HOST_ASSERT(stats_numel == num_experts_per_rank or + (phase_profile and stats_numel >= num_experts_per_rank + 64)); + DG_HOST_ASSERT(cumulative_local_expert_recv_stats->is_contiguous()); + } + + // Check buffer bytes + const auto num_ranks = static_cast(sym_buffer_ptrs.size()); + const auto num_experts_ = num_experts_per_rank * num_ranks; + const auto [num_required_bytes, slice] = get_symm_buffer_size_for_mega_moe( + num_ranks, num_experts, + num_max_tokens_per_rank, num_topk, + hidden, intermediate_hidden, + true, activation); + DG_HOST_ASSERT(sym_buffer.nbytes() >= static_cast(num_required_bytes)); + DG_HOST_ASSERT(num_experts == num_experts_); + + 
// Already registered tensors + const auto [x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf] = slice(sym_buffer); + + sm90_fp8_mega_moe(y, + l1_acts, l1_acts_sf, + l2_acts, l2_acts_sf, + l1_weights, l2_weights, + l1_weights_sf, l2_weights_sf, + cumulative_local_expert_recv_stats, + sym_buffer_ptrs, + rank_idx, num_max_tokens_per_rank, + num_experts_per_rank, + num_tokens, num_topk, + hidden, intermediate_hidden, + activation_clamp, fast_math); + + if (get_env("DG_COMM_KERNEL_DEBUG")) + sym_buffer.zero_(); +} + +// SM90 (Hopper) W4A8 MegaMoE entry point — Phase-0 scaffold. +// +// Accepts the same FP8 weights + per-128 float SF as `fp8_mega_moe` and +// dispatches to the W4A8 runtime. The runtime kernel is currently a verbatim +// copy of the FP8 kernel (only renamed); subsequent phases will: +// 1. swap weight dtype to packed MXFP4 (uint8) + per-32 E8M0 (uint8) SF +// 2. add `dequant_mxfp4_to_fp8` insertion in the math warp mainloop +// 3. apply per-32 E8M0 scale on WGMMA accumulator (Marlin-style) +// +// In Phase 0 this entry is bitwise equivalent to `fp8_mega_moe`. Used to +// verify build/link/run plumbing for the new kernel name before changing +// kernel logic. 
+static void w4a8_mega_moe( + const torch::Tensor& y, + const std::tuple& l1_weights_tuple, + const std::tuple& l2_weights_tuple, + const std::optional& cumulative_local_expert_recv_stats, + const torch::Tensor& sym_buffer, + const std::vector& sym_buffer_ptrs, const int& rank_idx, + const int& num_max_tokens_per_rank, + const int& num_experts, const int& num_topk, + const std::tuple& recipe, + const std::string& activation, + const std::optional& activation_clamp_opt, + const bool& fast_math +) { + const auto [l1_weights, l1_weights_sf] = l1_weights_tuple; + const auto [l2_weights, l2_weights_sf] = l2_weights_tuple; + + // Architecture check + const auto arch_major = device_runtime->get_arch_major(); + DG_HOST_ASSERT(arch_major == 9); + + // Phase-0 input checks: identical to fp8_mega_moe. + const auto num_tokens = static_cast(y.size(0)); + const auto [rm, rn, rk] = recipe; + DG_HOST_ASSERT(rm == 128 and rn == 128 and rk == 128); + DG_HOST_ASSERT(activation == "swiglu"); + + const auto activation_clamp = + activation_clamp_opt.value_or(std::numeric_limits::infinity()); + DG_HOST_ASSERT(activation_clamp >= 0); + + DG_HOST_ASSERT(get_major_type_ab(l1_weights) == cute::UMMA::Major::K); + DG_HOST_ASSERT(get_major_type_ab(l2_weights) == cute::UMMA::Major::K); + DG_HOST_ASSERT(l1_weights.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(l2_weights.scalar_type() == torch::kFloat8_e4m3fn); + const auto [num_experts_per_rank, intermediate_hidden_2, hidden] = get_shape<3>(l1_weights); + const auto [num_experts_per_rank_, hidden_, intermediate_hidden] = get_shape<3>(l2_weights); + DG_HOST_ASSERT(num_tokens <= num_max_tokens_per_rank); + DG_HOST_ASSERT(num_experts_per_rank == num_experts_per_rank_); + DG_HOST_ASSERT(hidden == hidden_); + DG_HOST_ASSERT(intermediate_hidden_2 == 2 * intermediate_hidden); + DG_HOST_ASSERT(l1_weights.is_contiguous() and l2_weights.is_contiguous()); + DG_HOST_ASSERT(hidden % 128 == 0 and intermediate_hidden % 128 == 0); + 
DG_HOST_ASSERT(intermediate_hidden / 64 <= 64); + + constexpr int kGranMN = 128, kGranK = 128; + check_sf_layout(l1_weights_sf, intermediate_hidden * 2, hidden, kGranMN, kGranK, + num_experts_per_rank, false, true, torch::kFloat); + check_sf_layout(l2_weights_sf, hidden, intermediate_hidden, kGranMN, kGranK, + num_experts_per_rank, false, true, torch::kFloat); + + if (cumulative_local_expert_recv_stats.has_value()) { + DG_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt); + const auto stats_numel = cumulative_local_expert_recv_stats->numel(); + const bool phase_profile = get_env("DG_SM90_MOE_PHASE_PROFILE", 0) != 0; + DG_HOST_ASSERT(stats_numel == num_experts_per_rank or + (phase_profile and stats_numel >= num_experts_per_rank + 64)); + DG_HOST_ASSERT(cumulative_local_expert_recv_stats->is_contiguous()); + } + + const auto num_ranks = static_cast(sym_buffer_ptrs.size()); + const auto num_experts_ = num_experts_per_rank * num_ranks; + const auto [num_required_bytes, slice] = get_symm_buffer_size_for_mega_moe( + num_ranks, num_experts, + num_max_tokens_per_rank, num_topk, + hidden, intermediate_hidden, + true, activation); + DG_HOST_ASSERT(sym_buffer.nbytes() >= static_cast(num_required_bytes)); + DG_HOST_ASSERT(num_experts == num_experts_); + + const auto [x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf] = slice(sym_buffer); + + // Dispatch to W4A8 runtime (Phase-0: identical kernel to FP8). 
+ sm90_w4a8_mega_moe(y, + l1_acts, l1_acts_sf, + l2_acts, l2_acts_sf, + l1_weights, l2_weights, + l1_weights_sf, l2_weights_sf, + cumulative_local_expert_recv_stats, + sym_buffer_ptrs, + rank_idx, num_max_tokens_per_rank, + num_experts_per_rank, + num_tokens, num_topk, + hidden, intermediate_hidden, + activation_clamp, fast_math); + + if (get_env("DG_COMM_KERNEL_DEBUG")) + sym_buffer.zero_(); +} + static void register_apis(pybind11::module_& m) { #if DG_TENSORMAP_COMPATIBLE m.def("get_token_alignment_for_mega_moe", &get_token_alignment_for_mega_moe); m.def("get_symm_buffer_size_for_mega_moe", &get_symm_buffer_size_for_mega_moe); m.def("fp8_fp4_mega_moe", &fp8_fp4_mega_moe); + m.def("fp8_mega_moe", &fp8_mega_moe); + m.def("w4a8_mega_moe", &w4a8_mega_moe); #endif } diff --git a/csrc/jit/compiler.hpp b/csrc/jit/compiler.hpp index 7d85a5f556..fb58a496ea 100644 --- a/csrc/jit/compiler.hpp +++ b/csrc/jit/compiler.hpp @@ -59,6 +59,12 @@ class Compiler { flags += " --ptxas-options=--verbose,--warn-on-local-memory-usage"; if (get_env("DG_JIT_WITH_LINEINFO", 0)) flags += " -Xcompiler -rdynamic -lineinfo"; + // NOTES: `--device-debug` (-G) emits full device DWARF so that cuda-gdb + // can resolve `__device__` global variables / line numbers in JIT + // kernels. It DISABLES device-side optimization and will tank perf, so + // it is gated behind an explicit env var. + if (get_env("DG_JIT_WITH_DEVICE_DEBUG", 0)) + flags += " --device-debug"; } virtual ~Compiler() = default; diff --git a/csrc/jit/handle.hpp b/csrc/jit/handle.hpp index be3bc31c07..f073646507 100644 --- a/csrc/jit/handle.hpp +++ b/csrc/jit/handle.hpp @@ -120,8 +120,11 @@ using KernelHandle = CUfunction; using LaunchConfigHandle = CUlaunchConfig; using LaunchAttrHandle = CUlaunchAttribute; -// `cuLibraryEnumerateKernels` is supported since CUDA Driver API 12.4 -#if CUDA_VERSION >= 12040 +// `cuLibraryEnumerateKernels` is supported since CUDA Driver API 12.4. 
+// Define `DG_JIT_FORCE_LEGACY_LOAD` to force the older `cuModuleLoad` path +// (useful when building against a newer CUDA SDK but running with an older +// driver that lacks the `cuLibrary*` symbols). +#if CUDA_VERSION >= 12040 && !defined(DG_JIT_FORCE_LEGACY_LOAD) #define DG_JIT_USE_LIBRARY_ENUM_KERNELS DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLibraryGetKernelCount); DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLibraryEnumerateKernels); diff --git a/csrc/jit_kernels/heuristics/mega_moe.hpp b/csrc/jit_kernels/heuristics/mega_moe.hpp index b1ba6bd70c..85ebf8e552 100644 --- a/csrc/jit_kernels/heuristics/mega_moe.hpp +++ b/csrc/jit_kernels/heuristics/mega_moe.hpp @@ -9,6 +9,7 @@ #include "../../utils/math.hpp" #include "../../utils/system.hpp" #include "sm100.hpp" +#include "sm90.hpp" namespace deep_gemm { @@ -237,4 +238,272 @@ static MegaMoEConfig get_mega_moe_config( return config; } +// ============================================================================ +// SM90 (Hopper) MegaMoE configuration +// ---------------------------------------------------------------------------- +// SM90 differs from SM100 in: +// - No tensor memory (TMEM): WGMMA accumulators live in registers. +// - No FP4: weights are FP8 e4m3, scales are per-128 channel float. +// - No 2-CTA cluster MMA: TMA multicast cluster=2 may still be used. +// - SF for activations is float (not UE8M0 int) and per-128 (not per-32). +// The kernel is in `deep_gemm/impls/sm90_fp8_mega_moe.cuh` and is currently +// a skeleton; this config is what the host runtime reads. +// ============================================================================ + +struct MegaMoESM90Config { + // Block tiling (no STORE_BLOCK_M / SF_BLOCK_M concept on SM90) + int block_m, block_n, block_k; + + // Cluster size for TMA multicast (1 or 2). Multicast is on B (weights): peer CTAs share the B TMA load.
+ int cluster_size; + + // Pool capacity and SF-padded token count (SF is per-128 float on SM90) + int num_max_pool_tokens; + int num_padded_sf_pool_tokens; + + // Swizzle modes for TMA descriptors (acts/weights). Both are 128B on FP8 K-major. + int swizzle_acts_mode, swizzle_weights_mode; + + // Number of experts to process per wave + int num_experts_per_wave; + + // Pipeline stages and shared memory + int num_stages, smem_size; + + // Thread layout: dispatch + non-epilogue (TMA) + epilogue (math) + int num_dispatch_threads, num_non_epilogue_threads, num_epilogue_threads; + + friend std::ostream& operator << (std::ostream& os, const MegaMoESM90Config& config) { + os << "MegaMoESM90Config(" + << "block_m=" << config.block_m << ", block_n=" << config.block_n << ", block_k=" << config.block_k + << ", cluster_size=" << config.cluster_size + << ", num_max_pool_tokens=" << config.num_max_pool_tokens + << ", num_padded_sf_pool_tokens=" << config.num_padded_sf_pool_tokens + << ", swizzle_acts_mode=" << config.swizzle_acts_mode << ", swizzle_weights_mode=" << config.swizzle_weights_mode + << ", num_experts_per_wave=" << config.num_experts_per_wave + << ", num_stages=" << config.num_stages << ", smem_size=" << config.smem_size + << ", num_dispatch_threads=" << config.num_dispatch_threads + << ", num_non_epilogue_threads=" << config.num_non_epilogue_threads + << ", num_epilogue_threads=" << config.num_epilogue_threads << ")"; + return os; + } +}; + +static std::tuple get_block_config_for_mega_moe_sm90( + const int& num_ranks, const int& num_experts, + const int& num_max_tokens_per_rank, const int& num_topk, + const int& num_tokens) { + // Keep mma.sync decode variants gated. M16 was correct but slower on H20; + // M32 is an experimental middle ground that cuts M64 padding without paying + // as much CTA/epilogue overhead as M16. 
+ const int num_experts_per_rank = num_experts / num_ranks; + const float expected_tokens_per_expert = + static_cast(num_tokens) * num_topk / num_experts_per_rank; + const int requested_mma_m = get_env("DG_SM90_MOE_MMA_SYNC_M") > 0 + ? get_env("DG_SM90_MOE_MMA_SYNC_M") + : (get_env("DG_SM90_MOE_MMA_SYNC") != 0 ? 16 : 0); + DG_HOST_ASSERT(requested_mma_m == 0 or requested_mma_m == 16 or requested_mma_m == 32); + const bool use_mma_sync_decode = + requested_mma_m > 0 and expected_tokens_per_expert <= static_cast(requested_mma_m); + const bool use_bn256_split_n = + get_env("DG_SM90_MOE_BN256_2WG") != 0 and not use_mma_sync_decode; + const bool use_bn256_seq_n = + get_env("DG_SM90_MOE_BN256_SEQ") != 0 and not use_mma_sync_decode; + const bool use_b_stationary_2wg = + get_env("DG_SM90_MOE_B_STATIONARY_2WG") != 0 and not use_mma_sync_decode; + DG_HOST_ASSERT(not (use_bn256_split_n and use_bn256_seq_n)); + DG_HOST_ASSERT(not (use_b_stationary_2wg and (use_bn256_split_n or use_bn256_seq_n))); + const int forced_block_m = get_env("DG_SM90_MOE_FORCE_BLOCK_M"); + const int forced_epilogue_warpgroups = get_env("DG_SM90_MOE_FORCE_EPILOGUE_WG"); + DG_HOST_ASSERT(forced_block_m == 0 or forced_block_m == 64 or forced_block_m == 128); + DG_HOST_ASSERT(forced_epilogue_warpgroups == 0 or + forced_epilogue_warpgroups == 1 or + forced_epilogue_warpgroups == 2); + const int block_m = forced_block_m > 0 + ? forced_block_m + : (use_b_stationary_2wg ? 128 : (use_mma_sync_decode ? requested_mma_m : 64)); + const int num_epilogue_warpgroups = forced_epilogue_warpgroups > 0 + ? forced_epilogue_warpgroups + : ((use_b_stationary_2wg or use_bn256_split_n) ? 
2 : 1); + DG_HOST_ASSERT(block_m % num_epilogue_warpgroups == 0); + DG_HOST_ASSERT((block_m != 16 and block_m != 32) or num_epilogue_warpgroups == 1); + DG_HOST_ASSERT(block_m != 128 or num_epilogue_warpgroups == 2); + + DG_HOST_ASSERT(std::any_of( + layout::kCandidateBlockM, layout::kCandidateBlockM + layout::kNumCandidateBlockMs, + [=](const auto& candidate) { return candidate == block_m; }) + ); + return {block_m, num_epilogue_warpgroups * 128}; +} + +static int get_num_experts_per_wave_for_mega_moe_sm90( + const int& num_experts_per_rank, const int& num_tokens, const int& num_topk, + const int& intermediate_hidden, const int& block_m, const int& block_n, const int& num_sms) { + if (const int forced = get_env("DG_SM90_MOE_EXPERTS_PER_WAVE"); forced > 0) { + DG_HOST_ASSERT(forced <= num_experts_per_rank); + DG_HOST_ASSERT(num_experts_per_rank % forced == 0); + return forced; + } + + // SM90 (Hopper) wave heuristic. + // + // The generic heuristic is useful in the middle of the block_m=64 band, but + // very sparse routing and large batches both do better as a single all-expert + // wave: sparse cases avoid extra L1->L2 wave transitions, while large cases + // keep enough work resident without fragmenting expert scheduling. 
+ const float expected_tokens_per_expert = + static_cast(num_tokens) * num_topk / num_experts_per_rank; + if (block_m == 64 and (expected_tokens_per_expert < 1.0f or expected_tokens_per_expert > 4.0f)) { + return num_experts_per_rank; + } + return get_num_experts_per_wave_for_mega_moe( + num_experts_per_rank, num_tokens, num_topk, + intermediate_hidden, block_m, block_n, num_sms); +} + +static std::pair get_pipeline_config_for_mega_moe_sm90( + const int& smem_capacity, + const int& num_experts, const int& hidden, + const int& block_m, const int& block_n, const int& block_k, + const int& num_dispatch_warps, const int& num_epilogue_warps) { + constexpr int kSmemAlignment = 1024; + + // Dispatch region (same as SM100) + const int smem_expert_count_size = align( + num_experts * static_cast(sizeof(uint32_t)), kSmemAlignment); + const int smem_send_buffers_size = align( + static_cast(layout::Buffer(layout::Data(hidden), num_dispatch_warps, 1).get_num_bytes()), + kSmemAlignment); + const int smem_dispatch_size = smem_expert_count_size + smem_send_buffers_size; + + // C/D output region: max of L1 FP8 (single-buffered, BLOCK_N/2 post-SwiGLU) + // and L2 BF16, then 1024-byte aligned (matches kernel's SMEM_CD_SIZE). + // The mma.sync decode path additionally stages one BLOCK_M x BLOCK_N FP32 + // accumulator tile in SMEM for logical-row epilogue mapping. + const auto num_epilogue_warpgroups = num_epilogue_warps / 4; + const bool split_n_warpgroups = block_m == 64 and block_n == 256 and num_epilogue_warpgroups == 2; + const bool serial_n_warpgroups = false; + const int wg_block_m = split_n_warpgroups ? block_m : block_m / num_epilogue_warpgroups; + const int wg_block_n = (split_n_warpgroups or serial_n_warpgroups) ? block_n / 2 : block_n; + const int smem_cd_accum = (block_m == 16 or block_m == 32) ? 
align(block_m * block_n * static_cast(sizeof(float)), kSmemAlignment) : 0; + const int smem_cd_l1 = num_epilogue_warpgroups * wg_block_m * (wg_block_n / 2); // 1 byte/elem (FP8) + const bool direct_l2_scatter = get_env("DG_SM90_MOE_DIRECT_L2_SCATTER", 0) != 0 and + block_m != 16 and block_m != 32 and + not split_n_warpgroups and not serial_n_warpgroups; + const bool async_l1_tma_store = get_env("DG_SM90_MOE_ASYNC_L1_STORE", 0) != 0 and + block_m != 16 and block_m != 32 and + not split_n_warpgroups and num_epilogue_warpgroups == 1; + const int smem_cd_l2 = direct_l2_scatter ? 0 : + num_epilogue_warpgroups * wg_block_m * wg_block_n * static_cast(sizeof(nv_bfloat16)); + const int smem_cd_l1_async = async_l1_tma_store ? + 2 * num_epilogue_warpgroups * wg_block_m * (block_n / 2) : 0; + const int smem_cd = smem_cd_accum + align(std::max(std::max(smem_cd_l1, smem_cd_l2), smem_cd_l1_async), kSmemAlignment); + + // SF on SM90: + // * SFA per stage must hold the larger of L1 (BLOCK_M floats, per-128 K) + // and L2 (2 * BLOCK_M floats, per-64 K), aligned to 128 bytes + // * SFB is loaded directly from global by the math warpgroup (block-(128,128) + // weight quantization), so no SMEM is reserved for it. 
+ const int smem_sfa_half_stride_bytes = align(block_m * static_cast(sizeof(float)), 128); + const int smem_sfa_per_stage = 2 * smem_sfa_half_stride_bytes; + const int smem_sfb_per_stage = 0; + + // Per-stage: A tile + B tile + SFA tile + SFB tile + const int smem_per_stage = block_m * block_k + block_n * block_k + + smem_sfa_per_stage + smem_sfb_per_stage; + + // Barriers (8 bytes each): + // * dispatch: num_dispatch_warps + // * GEMM full + empty: 2 * num_stages + // * combine: 2 * num_epilogue_warps + const int smem_barriers_fixed = (num_dispatch_warps + 2 * num_epilogue_warps) * 8; + const int smem_barriers_per_stage = 2 * 8; + + // Fixed total + const int smem_fixed = smem_dispatch_size + smem_cd + smem_barriers_fixed; + + // Select max num_stages, with an optional SM90-only sweep override. + const int max_num_stages = (smem_capacity - smem_fixed) / + (smem_per_stage + smem_barriers_per_stage); + const int forced_num_stages = get_env("DG_SM90_MOE_NUM_STAGES"); + const int num_stages = forced_num_stages > 0 ? 
forced_num_stages : max_num_stages; + DG_HOST_ASSERT(num_stages >= 2 and num_stages <= max_num_stages); + return {num_stages, + smem_fixed + num_stages * (smem_per_stage + smem_barriers_per_stage)}; +} + +static MegaMoESM90Config get_mega_moe_config_sm90( + const int& num_ranks, const int& num_experts, const int& num_experts_per_rank, + const int& num_max_tokens_per_rank, const int& num_tokens, const int& num_topk, + const int& hidden, const int& intermediate_hidden, + const int& num_padded_sf_pool_tokens) { + const auto [block_m, num_epilogue_threads] = get_block_config_for_mega_moe_sm90( + num_ranks, num_experts, num_max_tokens_per_rank, num_topk, num_tokens); + const bool use_bn256_split_n = + get_env("DG_SM90_MOE_BN256_2WG") != 0 and block_m != 16 and block_m != 32; + const bool use_bn256_seq_n = + get_env("DG_SM90_MOE_BN256_SEQ") != 0 and block_m != 16 and block_m != 32; + DG_HOST_ASSERT(not (use_bn256_split_n and use_bn256_seq_n)); + const int block_n = (use_bn256_split_n or use_bn256_seq_n) ? 256 : 128; + DG_HOST_ASSERT((not use_bn256_split_n) or num_epilogue_threads == 256); + const int block_k = 128; + // Default remains cluster_size=1. The experimental cluster=2 path below + // is M-split/B-multicast for exact-balanced large-M tests: peer CTAs process + // adjacent M blocks for the same expert/N tile and share the B TMA load. + // `DG_SM90_MOE_B_STATIONARY_2WG=1` extends the same idea to M128/2WG: + // two WGs split M within each CTA, and the cluster pair reuses one B tile + // across four M64 WGMMA consumers. + const bool use_b_stationary_2wg = get_env("DG_SM90_MOE_B_STATIONARY_2WG") != 0; + const bool use_cluster_bcast_b = get_env("DG_SM90_MOE_CLUSTER_BCAST_B") != 0 or use_b_stationary_2wg; + DG_HOST_ASSERT((not use_cluster_bcast_b) or + ((block_m == 64 and block_n == 128 and num_epilogue_threads == 128) or + (block_m == 128 and block_n == 128 and num_epilogue_threads == 256))); + const int cluster_size = use_cluster_bcast_b ? 
2 : 1; + const int num_max_pool_tokens = layout::get_num_max_pool_tokens( + num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank); + const int swizzle_acts_mode = (block_m == 16 or block_m == 32) ? 0 : 128; + const int swizzle_weights_mode = (block_m == 16 or block_m == 32) ? 0 : 128; + + const int num_sms = device_runtime->get_num_sms(); + const int num_experts_per_wave = get_num_experts_per_wave_for_mega_moe_sm90( + num_experts_per_rank, num_tokens, num_topk, + intermediate_hidden, block_m, block_n, num_sms); + + const int forced_dispatch_warps = get_env("DG_SM90_MOE_DISPATCH_WARPS"); + const bool compact_frontend = get_env("DG_SM90_MOE_COMPACT_FRONTEND") != 0; + DG_HOST_ASSERT(forced_dispatch_warps == 0 or forced_dispatch_warps == 2 or + forced_dispatch_warps == 4 or forced_dispatch_warps == 8); + const int num_dispatch_threads = (forced_dispatch_warps > 0 ? forced_dispatch_warps : 4) * 32; + DG_HOST_ASSERT((not compact_frontend) or num_dispatch_threads == 64); + const int num_non_epilogue_threads = compact_frontend ? 
64 : 128; + + const auto [num_stages, smem_size] = get_pipeline_config_for_mega_moe_sm90( + SM90ArchSpec::smem_capacity, + num_experts, hidden, + block_m, block_n, block_k, + num_dispatch_threads / 32, num_epilogue_threads / 32); + + const auto config = MegaMoESM90Config { + block_m, block_n, block_k, + cluster_size, + num_max_pool_tokens, num_padded_sf_pool_tokens, + swizzle_acts_mode, swizzle_weights_mode, + num_experts_per_wave, + num_stages, smem_size, + num_dispatch_threads, num_non_epilogue_threads, num_epilogue_threads + }; + + if (get_env("DG_JIT_DEBUG") or get_env("DG_PRINT_CONFIGS")) { + const auto key = fmt::format( + "MegaMoESM90Config(num_ranks={}, num_experts={}, hidden={}, intermediate_hidden={}, num_max_tokens_per_rank={}, num_tokens={}, num_topk={})", + num_ranks, num_experts, hidden, intermediate_hidden, num_max_tokens_per_rank, num_tokens, num_topk); + static std::unordered_set printed; + if (printed.count(key) == 0) { + std::cout << key << ": " << config << std::endl; + printed.insert(key); + } + } + return config; +} + } // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp b/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp new file mode 100644 index 0000000000..309e7de50f --- /dev/null +++ b/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp @@ -0,0 +1,271 @@ +#pragma once + +#include +#include "../../jit/compiler.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "runtime_utils.hpp" + +#include +#include + +#include "../heuristics/mega_moe.hpp" + +namespace deep_gemm { + +// ============================================================================ +// SM90 (Hopper) FP8 MegaMoE host runtime +// ---------------------------------------------------------------------------- +// This is the SM90 counterpart of `SM100FP8FP4MegaMoERuntime`. 
The kernel +// itself lives in `deep_gemm/impls/sm90_fp8_mega_moe.cuh` and is a full +// implementation: the combine step is ported from the SM100 kernel, while +// the GEMM (TMA load + WGMMA + epilogue) is written specifically for the +// SM90 path. +// +// Differences from SM100 path: +// * Activations and weights are both FP8 (e4m3); no FP4. +// * Activation/weight scale factors (SF) are per-128-channel float (not UE8M0 +// int + per-32 UTCCP layout). +// * No tensor memory: WGMMA accumulators are register-resident. +// * Cluster size is at most 2 (TMA multicast on B); no 2-CTA UMMA. +// ============================================================================ + +class SM90FP8MegaMoERuntime final : public LaunchRuntime { +public: + struct Args { + // Templated arguments + int num_max_tokens_per_rank; + int hidden, intermediate_hidden; + int num_experts, num_topk; + int num_ranks; + float activation_clamp; + bool fast_math; + bool async_l1_tma_store; + bool split_sfa_tma; + bool direct_l2_scatter; + bool l2_dual_accum; + bool phase_profile; + bool l1_dual_k_accum; + bool l2_nmajor_schedule; + bool l1_nmajor_schedule; + MegaMoESM90Config config; + + // Runtime arguments + void* y; + int* cumulative_local_expert_recv_stats; + int num_tokens; + layout::SymBuffer<> sym_buffer_ptrs; + + // Tensormaps for activations and weights. Weight scale factors use + // block (128, 128) quantization and are loaded by the math warpgroup + // directly from global memory (no TMA descriptor required).
+ CUtensorMap tensor_map_l1_acts; + CUtensorMap tensor_map_l1_acts_sf; + CUtensorMap tensor_map_l1_weights; + const float* l1_weights_sf; + CUtensorMap tensor_map_l1_output; + CUtensorMap tensor_map_l2_acts; + CUtensorMap tensor_map_l2_acts_sf; + CUtensorMap tensor_map_l2_weights; + const float* l2_weights_sf; + + // Launch configs + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm90_fp8_mega_moe_impl< + {}, + {}, {}, + {}, {}, + {}, + {}, {}, {}, + {}, + {}, + {}, + {}, + {}, {}, {}, + {}, + {}, {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {} + >); +}}; +)", + args.num_max_tokens_per_rank, + args.hidden, args.intermediate_hidden, + args.num_experts, args.num_topk, + args.config.num_experts_per_wave, + args.config.block_m, args.config.block_n, args.config.block_k, + args.config.num_max_pool_tokens, + args.config.num_padded_sf_pool_tokens, + args.config.num_stages, + args.config.num_dispatch_threads, args.config.num_non_epilogue_threads, args.config.num_epilogue_threads, + args.config.cluster_size, + args.launch_args.grid_dim.first, args.num_ranks, + to_string(args.activation_clamp), + args.fast_math ? "true" : "false", + args.async_l1_tma_store ? "true" : "false", + args.split_sfa_tma ? "true" : "false", + args.direct_l2_scatter ? "true" : "false", + args.l2_dual_accum ? "true" : "false", + args.phase_profile ? "true" : "false", + args.l1_dual_k_accum ? "true" : "false", + args.l2_nmajor_schedule ? "true" : "false", + args.l1_nmajor_schedule ? 
"true" : "false"); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.y, + args.cumulative_local_expert_recv_stats, + args.num_tokens, + args.sym_buffer_ptrs, + args.tensor_map_l1_acts, + args.tensor_map_l1_acts_sf, + args.tensor_map_l1_weights, + args.l1_weights_sf, + args.tensor_map_l1_output, + args.tensor_map_l2_acts, + args.tensor_map_l2_acts_sf, + args.tensor_map_l2_weights, + args.l2_weights_sf + )); + } +}; + +static void sm90_fp8_mega_moe( + const torch::Tensor& y, + const torch::Tensor& l1_acts, const torch::Tensor& l1_acts_sf, + const torch::Tensor& l2_acts, const torch::Tensor& l2_acts_sf, + const torch::Tensor& l1_weights, const torch::Tensor& l2_weights, + const torch::Tensor& l1_weights_sf, const torch::Tensor& l2_weights_sf, + const std::optional cumulative_local_expert_recv_stats, + const std::vector& sym_buffer_ptrs, + const int& rank_idx, const int& num_max_tokens_per_rank, + const int& num_experts_per_rank, + const int& num_tokens, const int& num_topk, + const int& hidden, const int& intermediate_hidden, + const float& activation_clamp, + const bool& fast_math +) { + const auto num_ranks = static_cast(sym_buffer_ptrs.size()); + const auto num_experts = num_experts_per_rank * num_ranks; + const auto num_padded_sf_pool_tokens = static_cast(l1_acts_sf.size(0)); + + // Heuristics + const auto config = get_mega_moe_config_sm90( + num_ranks, num_experts, num_experts_per_rank, + num_max_tokens_per_rank, num_tokens, num_topk, + hidden, intermediate_hidden, num_padded_sf_pool_tokens); + + // Tensormap construction + // Acts/weights: standard 2D TMA descriptors (FP8 K-major). + // Activation SF: per-128 channel float for L1, per-64 for L2 (MN-major, no swizzle). + // Weight SF: block (128, 128) raw float pointer (no TMA descriptor). 
+ constexpr int kGranK = 128; + constexpr int kL2ActsSFGranK = 64; + const auto tensor_map_l1_acts = make_tma_2d_desc(l1_acts, + hidden, config.num_max_pool_tokens, + config.block_k, config.block_m, + static_cast(l1_acts.stride(-2)), + config.swizzle_acts_mode); + const auto tensor_map_l1_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l1_acts_sf, + config.num_padded_sf_pool_tokens, hidden, + config.block_m, kGranK, + 1, 0); + const auto tensor_map_l1_weights = make_tma_2d_desc(l1_weights, + hidden, num_experts_per_rank * intermediate_hidden * 2, + config.block_k, config.block_n, + static_cast(l1_weights.stride(-2)), + config.swizzle_weights_mode); + // L1 output (post-SwiGLU FP8): N is halved. The SM90 epilogue writes this + // staging tile to SMEM as plain row-major bytes, so the TMA store descriptor + // must use no shared-memory swizzle. Later L2 TMA loads may still swizzle + // from this row-major global buffer into their own SMEM tile. + // The default TMA store is issued per warpgroup, each writing a WG_BLOCK_M + // row tile. The split-N experiment has two WGs produce different N halves + // of the same M rows, then one TMA store writes the full 64x128 post-SwiGLU tile. + const int num_epilogue_warpgroups_h = config.num_epilogue_threads / 128; + const bool split_n_warpgroups_h = + config.block_m == 64 and config.block_n == 256 and num_epilogue_warpgroups_h == 2; + const int wg_block_m = split_n_warpgroups_h ? 
config.block_m : config.block_m / num_epilogue_warpgroups_h; + const auto tensor_map_l1_output = make_tma_2d_desc(l2_acts, + intermediate_hidden, config.num_max_pool_tokens, + config.block_n / 2, wg_block_m, + static_cast(l2_acts.stride(-2)), + 0); + const auto tensor_map_l2_acts = make_tma_2d_desc(l2_acts, + intermediate_hidden, config.num_max_pool_tokens, + config.block_k, config.block_m, + static_cast(l2_acts.stride(-2)), + config.swizzle_acts_mode); + const auto tensor_map_l2_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l2_acts_sf, + config.num_padded_sf_pool_tokens, intermediate_hidden, + config.block_m, kL2ActsSFGranK, + 1, 0); + const auto tensor_map_l2_weights = make_tma_2d_desc(l2_weights, + intermediate_hidden, num_experts_per_rank * hidden, + config.block_k, config.block_n, + static_cast(l2_weights.stride(-2)), + config.swizzle_weights_mode); + + // Stats can be optional + int* cumulative_local_expert_recv_stats_ptr = nullptr; + if (cumulative_local_expert_recv_stats.has_value()) + cumulative_local_expert_recv_stats_ptr = cumulative_local_expert_recv_stats->data_ptr(); + + // Launch + const auto num_sms = device_runtime->get_num_sms(); + const SM90FP8MegaMoERuntime::Args args = { + .num_max_tokens_per_rank = num_max_tokens_per_rank, + .hidden = hidden, .intermediate_hidden = intermediate_hidden, + .num_experts = num_experts, .num_topk = num_topk, + .num_ranks = num_ranks, + .activation_clamp = activation_clamp, + .fast_math = fast_math, + .async_l1_tma_store = get_env("DG_SM90_MOE_ASYNC_L1_STORE", 0) != 0, + .split_sfa_tma = get_env("DG_SM90_MOE_SPLIT_SFA_TMA", 0) != 0, + .direct_l2_scatter = get_env("DG_SM90_MOE_DIRECT_L2_SCATTER", 0) != 0, + .l2_dual_accum = get_env("DG_SM90_MOE_L2_DUAL_ACCUM", 0) != 0, + .phase_profile = get_env("DG_SM90_MOE_PHASE_PROFILE", 0) != 0, + .l1_dual_k_accum = get_env("DG_SM90_MOE_L1_DUAL_K", 0) != 0, + .l2_nmajor_schedule = get_env("DG_SM90_MOE_L2_NMAJOR", 0) != 0, + .l1_nmajor_schedule = 
get_env("DG_SM90_MOE_L1_NMAJOR", 0) != 0, + .config = config, + .y = y.data_ptr(), + .cumulative_local_expert_recv_stats = cumulative_local_expert_recv_stats_ptr, + .num_tokens = num_tokens, + .sym_buffer_ptrs = layout::SymBuffer<>(sym_buffer_ptrs, rank_idx), + .tensor_map_l1_acts = tensor_map_l1_acts, + .tensor_map_l1_acts_sf = tensor_map_l1_acts_sf, + .tensor_map_l1_weights = tensor_map_l1_weights, + .l1_weights_sf = l1_weights_sf.data_ptr(), + .tensor_map_l1_output = tensor_map_l1_output, + .tensor_map_l2_acts = tensor_map_l2_acts, + .tensor_map_l2_acts_sf = tensor_map_l2_acts_sf, + .tensor_map_l2_weights = tensor_map_l2_weights, + .l2_weights_sf = l2_weights_sf.data_ptr(), + .launch_args = LaunchArgs(num_sms, config.num_dispatch_threads + config.num_non_epilogue_threads + config.num_epilogue_threads, + config.smem_size, config.cluster_size) + }; + const auto code = SM90FP8MegaMoERuntime::generate(args); + const auto runtime = compiler->build("sm90_fp8_mega_moe", code); + SM90FP8MegaMoERuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm90_w4a8_mega_moe.hpp b/csrc/jit_kernels/impls/sm90_w4a8_mega_moe.hpp new file mode 100644 index 0000000000..52e8b380aa --- /dev/null +++ b/csrc/jit_kernels/impls/sm90_w4a8_mega_moe.hpp @@ -0,0 +1,271 @@ +#pragma once + +#include +#include "../../jit/compiler.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "runtime_utils.hpp" + +#include +#include + +#include "../heuristics/mega_moe.hpp" + +namespace deep_gemm { + +// ============================================================================ +// SM90 (Hopper) FP8 MegaMoE host runtime +// ---------------------------------------------------------------------------- +// This is the SM90 counterpart of `SM100FP8FP4MegaMoERuntime`. 
The kernel +// itself lives in `deep_gemm/impls/sm90_w4a8_mega_moe.cuh` and is currently a +// skeleton: dispatch/combine paths are intended to be portable from the SM100 +// version, while the GEMM (TMA load + WGMMA + epilogue) is being implemented +// in a follow-up step. +// +// Differences from SM100 path: +// * Activations and weights are both FP8 (e4m3); no FP4. +// * Activation/weight scale factors (SF) are per-128-channel float (not UE8M0 +// int + per-32 UTCCP layout). +// * No tensor memory: WGMMA accumulators are register-resident. +// * Cluster size is at most 2 (TMA multicast on B); no 2-CTA UMMA. +// ============================================================================ + +class SM90W4A8MegaMoERuntime final : public LaunchRuntime { +public: + struct Args { + // Templated arguments + int num_max_tokens_per_rank; + int hidden, intermediate_hidden; + int num_experts, num_topk; + int num_ranks; + float activation_clamp; + bool fast_math; + bool async_l1_tma_store; + bool split_sfa_tma; + bool direct_l2_scatter; + bool l2_dual_accum; + bool phase_profile; + bool l1_dual_k_accum; + bool l2_nmajor_schedule; + bool l1_nmajor_schedule; + MegaMoESM90Config config; + + // Runtime arguments + void* y; + int* cumulative_local_expert_recv_stats; + int num_tokens; + layout::SymBuffer<> sym_buffer_ptrs; + + // Tensormaps for activations and weights. Weight scale factors use + // block (128, 128) quantization and are loaded by the math warpgroup + // directly from global memory (no TMA descriptor required).
+ CUtensorMap tensor_map_l1_acts; + CUtensorMap tensor_map_l1_acts_sf; + CUtensorMap tensor_map_l1_weights; + const float* l1_weights_sf; + CUtensorMap tensor_map_l1_output; + CUtensorMap tensor_map_l2_acts; + CUtensorMap tensor_map_l2_acts_sf; + CUtensorMap tensor_map_l2_weights; + const float* l2_weights_sf; + + // Launch configs + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm90_w4a8_mega_moe_impl< + {}, + {}, {}, + {}, {}, + {}, + {}, {}, {}, + {}, + {}, + {}, + {}, + {}, {}, {}, + {}, + {}, {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {} + >); +}}; +)", + args.num_max_tokens_per_rank, + args.hidden, args.intermediate_hidden, + args.num_experts, args.num_topk, + args.config.num_experts_per_wave, + args.config.block_m, args.config.block_n, args.config.block_k, + args.config.num_max_pool_tokens, + args.config.num_padded_sf_pool_tokens, + args.config.num_stages, + args.config.num_dispatch_threads, args.config.num_non_epilogue_threads, args.config.num_epilogue_threads, + args.config.cluster_size, + args.launch_args.grid_dim.first, args.num_ranks, + to_string(args.activation_clamp), + args.fast_math ? "true" : "false", + args.async_l1_tma_store ? "true" : "false", + args.split_sfa_tma ? "true" : "false", + args.direct_l2_scatter ? "true" : "false", + args.l2_dual_accum ? "true" : "false", + args.phase_profile ? "true" : "false", + args.l1_dual_k_accum ? "true" : "false", + args.l2_nmajor_schedule ? "true" : "false", + args.l1_nmajor_schedule ? 
"true" : "false"); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.y, + args.cumulative_local_expert_recv_stats, + args.num_tokens, + args.sym_buffer_ptrs, + args.tensor_map_l1_acts, + args.tensor_map_l1_acts_sf, + args.tensor_map_l1_weights, + args.l1_weights_sf, + args.tensor_map_l1_output, + args.tensor_map_l2_acts, + args.tensor_map_l2_acts_sf, + args.tensor_map_l2_weights, + args.l2_weights_sf + )); + } +}; + +static void sm90_w4a8_mega_moe( + const torch::Tensor& y, + const torch::Tensor& l1_acts, const torch::Tensor& l1_acts_sf, + const torch::Tensor& l2_acts, const torch::Tensor& l2_acts_sf, + const torch::Tensor& l1_weights, const torch::Tensor& l2_weights, + const torch::Tensor& l1_weights_sf, const torch::Tensor& l2_weights_sf, + const std::optional cumulative_local_expert_recv_stats, + const std::vector& sym_buffer_ptrs, + const int& rank_idx, const int& num_max_tokens_per_rank, + const int& num_experts_per_rank, + const int& num_tokens, const int& num_topk, + const int& hidden, const int& intermediate_hidden, + const float& activation_clamp, + const bool& fast_math +) { + const auto num_ranks = static_cast(sym_buffer_ptrs.size()); + const auto num_experts = num_experts_per_rank * num_ranks; + const auto num_padded_sf_pool_tokens = static_cast(l1_acts_sf.size(0)); + + // Heuristics + const auto config = get_mega_moe_config_sm90( + num_ranks, num_experts, num_experts_per_rank, + num_max_tokens_per_rank, num_tokens, num_topk, + hidden, intermediate_hidden, num_padded_sf_pool_tokens); + + // Tensormap construction + // Acts/weights: standard 2D TMA descriptors (FP8 K-major). + // Activation SF: per-128 channel float for L1, per-64 for L2 (MN-major, no swizzle). + // Weight SF: block (128, 128) raw float pointer (no TMA descriptor). 
+ constexpr int kGranK = 128; + constexpr int kL2ActsSFGranK = 64; + const auto tensor_map_l1_acts = make_tma_2d_desc(l1_acts, + hidden, config.num_max_pool_tokens, + config.block_k, config.block_m, + static_cast(l1_acts.stride(-2)), + config.swizzle_acts_mode); + const auto tensor_map_l1_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l1_acts_sf, + config.num_padded_sf_pool_tokens, hidden, + config.block_m, kGranK, + 1, 0); + const auto tensor_map_l1_weights = make_tma_2d_desc(l1_weights, + hidden, num_experts_per_rank * intermediate_hidden * 2, + config.block_k, config.block_n, + static_cast(l1_weights.stride(-2)), + config.swizzle_weights_mode); + // L1 output (post-SwiGLU FP8): N is halved. The SM90 epilogue writes this + // staging tile to SMEM as plain row-major bytes, so the TMA store descriptor + // must use no shared-memory swizzle. Later L2 TMA loads may still swizzle + // from this row-major global buffer into their own SMEM tile. + // The default TMA store is issued per warpgroup, each writing a WG_BLOCK_M + // row tile. The split-N experiment has two WGs produce different N halves + // of the same M rows, then one TMA store writes the full 64x128 post-SwiGLU tile. + const int num_epilogue_warpgroups_h = config.num_epilogue_threads / 128; + const bool split_n_warpgroups_h = + config.block_m == 64 and config.block_n == 256 and num_epilogue_warpgroups_h == 2; + const int wg_block_m = split_n_warpgroups_h ? 
config.block_m : config.block_m / num_epilogue_warpgroups_h; + const auto tensor_map_l1_output = make_tma_2d_desc(l2_acts, + intermediate_hidden, config.num_max_pool_tokens, + config.block_n / 2, wg_block_m, + static_cast(l2_acts.stride(-2)), + 0); + const auto tensor_map_l2_acts = make_tma_2d_desc(l2_acts, + intermediate_hidden, config.num_max_pool_tokens, + config.block_k, config.block_m, + static_cast(l2_acts.stride(-2)), + config.swizzle_acts_mode); + const auto tensor_map_l2_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l2_acts_sf, + config.num_padded_sf_pool_tokens, intermediate_hidden, + config.block_m, kL2ActsSFGranK, + 1, 0); + const auto tensor_map_l2_weights = make_tma_2d_desc(l2_weights, + intermediate_hidden, num_experts_per_rank * hidden, + config.block_k, config.block_n, + static_cast(l2_weights.stride(-2)), + config.swizzle_weights_mode); + + // Stats can be optional + int* cumulative_local_expert_recv_stats_ptr = nullptr; + if (cumulative_local_expert_recv_stats.has_value()) + cumulative_local_expert_recv_stats_ptr = cumulative_local_expert_recv_stats->data_ptr(); + + // Launch + const auto num_sms = device_runtime->get_num_sms(); + const SM90W4A8MegaMoERuntime::Args args = { + .num_max_tokens_per_rank = num_max_tokens_per_rank, + .hidden = hidden, .intermediate_hidden = intermediate_hidden, + .num_experts = num_experts, .num_topk = num_topk, + .num_ranks = num_ranks, + .activation_clamp = activation_clamp, + .fast_math = fast_math, + .async_l1_tma_store = get_env("DG_SM90_MOE_ASYNC_L1_STORE", 0) != 0, + .split_sfa_tma = get_env("DG_SM90_MOE_SPLIT_SFA_TMA", 0) != 0, + .direct_l2_scatter = get_env("DG_SM90_MOE_DIRECT_L2_SCATTER", 0) != 0, + .l2_dual_accum = get_env("DG_SM90_MOE_L2_DUAL_ACCUM", 0) != 0, + .phase_profile = get_env("DG_SM90_MOE_PHASE_PROFILE", 0) != 0, + .l1_dual_k_accum = get_env("DG_SM90_MOE_L1_DUAL_K", 0) != 0, + .l2_nmajor_schedule = get_env("DG_SM90_MOE_L2_NMAJOR", 0) != 0, + .l1_nmajor_schedule = 
get_env("DG_SM90_MOE_L1_NMAJOR", 0) != 0, + .config = config, + .y = y.data_ptr(), + .cumulative_local_expert_recv_stats = cumulative_local_expert_recv_stats_ptr, + .num_tokens = num_tokens, + .sym_buffer_ptrs = layout::SymBuffer<>(sym_buffer_ptrs, rank_idx), + .tensor_map_l1_acts = tensor_map_l1_acts, + .tensor_map_l1_acts_sf = tensor_map_l1_acts_sf, + .tensor_map_l1_weights = tensor_map_l1_weights, + .l1_weights_sf = l1_weights_sf.data_ptr(), + .tensor_map_l1_output = tensor_map_l1_output, + .tensor_map_l2_acts = tensor_map_l2_acts, + .tensor_map_l2_acts_sf = tensor_map_l2_acts_sf, + .tensor_map_l2_weights = tensor_map_l2_weights, + .l2_weights_sf = l2_weights_sf.data_ptr(), + .launch_args = LaunchArgs(num_sms, config.num_dispatch_threads + config.num_non_epilogue_threads + config.num_epilogue_threads, + config.smem_size, config.cluster_size) + }; + const auto code = SM90W4A8MegaMoERuntime::generate(args); + const auto runtime = compiler->build("sm90_w4a8_mega_moe", code); + SM90W4A8MegaMoERuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py index a9542e2f44..d7c2aaa18e 100644 --- a/deep_gemm/__init__.py +++ b/deep_gemm/__init__.py @@ -85,7 +85,10 @@ SymmBuffer, get_symm_buffer_for_mega_moe, transform_weights_for_mega_moe, + transform_weights_for_mega_moe_sm90, fp8_fp4_mega_moe, + fp8_mega_moe, + w4a8_mega_moe, ) # Some utils diff --git a/deep_gemm/include/deep_gemm/common/math.cuh b/deep_gemm/include/deep_gemm/common/math.cuh index 0f0d250481..a93ef04e01 100644 --- a/deep_gemm/include/deep_gemm/common/math.cuh +++ b/deep_gemm/include/deep_gemm/common/math.cuh @@ -62,6 +62,14 @@ CUTLASS_DEVICE float2 fma2(const float2& a, const float2& b, const float2& c) { #endif } +CUTLASS_DEVICE float2 mul2(const float2& a, const float2& b) { +#if defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000) + return __fmul2_rn(a, b); +#else + return make_float2(__fmul_rn(a.x, b.x), __fmul_rn(a.y, b.y)); +#endif +} + 
CUTLASS_HOST_DEVICE float fast_rcp(const float& x) { float ret; asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(ret) : "f"(x)); @@ -91,7 +99,7 @@ template CUTLASS_DEVICE void get_e4m3_sf_and_sf_inv(const float2& amax, float2& sf, float2& sf_inv) { DG_STATIC_ASSERT(kUseUE8M0, "Must use UE8M0"); const float2 finfo_factor = {1.0 / 448.0, 1.0 / 448.0}; - const auto scaled = __fmul2_rn(amax, finfo_factor); + const auto scaled = mul2(amax, finfo_factor); const auto exp_x = fast_log2_ceil(scaled.x); const auto exp_y = fast_log2_ceil(scaled.y); sf.x = fast_pow2(exp_x), sf_inv.x = fast_pow2(-exp_x); diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh new file mode 100644 index 0000000000..d0719a068e --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh @@ -0,0 +1,2380 @@ +#pragma once + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#define __CLION_IDE__ + +namespace deep_gemm { + +// ============================================================================ +// SM90 (Hopper) FP8 MegaMoE — full implementation +// ---------------------------------------------------------------------------- +// Pipeline (cluster=1, no TMA multicast): +// * Dispatch warps: pull tokens (FP8) and SF (per-128 channel float) from +// remote ranks via NVLink into the local L1 pool. +// * GEMM TMA-load warps (1 for A+SFA, 1 for B+SFB) feed the pipeline stages. 
+// * Math warpgroups (1 or 2, totalling kNumEpilogueThreads) consume each +// stage with WGMMA, accumulate into registers, then run the epilogue: +// - L1 (Linear1): SwiGLU with gate/up granularity-8 interleaved layout, +// per-row amax over the 64 post-SwiGLU columns of this block, FP8 e4m3 +// quantize, STSM into SMEM, TMA store to local L1 output buffer. +// The per-row SF is written as a *float* into the L2-acts SF buffer at +// per-64 K granularity (one SF per L1 N block), so each block is fully +// self-contained and no cross-CTA amax synchronisation is needed. +// - L2 (Linear2): BF16 cast of the GEMM output, STSM into SMEM, then +// NVLink scatter to remote combine buffers. +// * After all GEMM blocks, the math warps run the COMBINE step (top-k +// reduction in BF16) — ported verbatim from the SM100 kernel. +// ============================================================================ + +template < + uint32_t kNumMaxTokensPerRank, + uint32_t kHidden, uint32_t kIntermediateHidden, + uint32_t kNumExperts, uint32_t kNumTopk, + uint32_t kNumExpertsPerWave, + uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K, + uint32_t kNumMaxPoolTokens, + uint32_t kNumPaddedSFPoolTokens, + uint32_t kNumStages, + uint32_t kNumDispatchThreads, uint32_t kNumNonEpilogueThreads, + uint32_t kNumEpilogueThreads, + uint32_t kClusterSize, + uint32_t kNumSMs, uint32_t kNumRanks, + float kActivationClamp, + bool kFastMath, + bool kAsyncL1TMAStoreRequested = true, + bool kSplitSFATMARequested = false, + bool kDirectL2ScatterRequested = false, + bool kL2DualAccumRequested = false, + bool kPhaseProfileRequested = false, + bool kL1DualKAccumRequested = false, + bool kL2NMajorScheduleRequested = false, + bool kL1NMajorScheduleRequested = false, + uint32_t L1_SHAPE_N = kIntermediateHidden * 2, + uint32_t L1_SHAPE_K = kHidden, + uint32_t L2_SHAPE_N = kHidden, + uint32_t L2_SHAPE_K = kIntermediateHidden, + uint32_t kNumDispatchWarps = kNumDispatchThreads / 32, + uint32_t 
kNumMMANonEpilogueWarps = kNumNonEpilogueThreads / 32, + uint32_t kNumEpilogueWarps = kNumEpilogueThreads / 32, + uint32_t kNumEpilogueWarpgroups = kNumEpilogueWarps / 4, + uint32_t kNumThreads = kNumDispatchThreads + kNumNonEpilogueThreads + kNumEpilogueThreads, + uint32_t kNumTokensPerWarp = 32 / kNumTopk, + uint32_t kNumExpertsPerRank = kNumExperts / kNumRanks +> +CUTLASS_GLOBAL __launch_bounds__(kNumThreads, 1) void +sm90_fp8_mega_moe_impl(void* y, + int* cumulative_local_expert_recv_stats, + const uint32_t num_tokens, + const __grid_constant__ layout::SymBuffer sym_buffer, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts_sf, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_weights, + const float* __restrict__ l1_weights_sf, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_output, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_acts, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_acts_sf, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_weights, + const float* __restrict__ l2_weights_sf) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900) and (__CUDA_ARCH__ < 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // ===================================================================== + // Template checks + // ===================================================================== + DG_STATIC_ASSERT(kNumDispatchThreads == 64 or kNumDispatchThreads % 128 == 0, + "Invalid number of dispatch threads"); + DG_STATIC_ASSERT(kNumNonEpilogueThreads == 64 or kNumNonEpilogueThreads == 128, + "Invalid number of GEMM TMA warps (2 or 4 warps expected)"); + DG_STATIC_ASSERT(kNumEpilogueThreads % 128 == 0, "Invalid number of math/epilogue threads"); + DG_STATIC_ASSERT(kNumExperts % kNumRanks == 0, "Invalid number of experts or ranks"); + DG_STATIC_ASSERT(kClusterSize == 1 or kClusterSize == 2, 
"Invalid cluster size"); + DG_STATIC_ASSERT(kNumSMs % kClusterSize == 0, "SM count must be divisible by cluster size"); + DG_STATIC_ASSERT(BLOCK_M == 16 or BLOCK_M == 32 or BLOCK_M % 64 == 0, + "BLOCK_M must be 16/32 for mma.sync decode or a multiple of WGMMA::M (64)"); + DG_STATIC_ASSERT(BLOCK_N == 64 or BLOCK_N == 128 or BLOCK_N == 256, "BLOCK_N must be 64/128/256 for this SM90 path"); + DG_STATIC_ASSERT(BLOCK_K == 128, "BLOCK_K is fixed to 128 (per-128 SF)"); + + // ===================================================================== + // Thread / warp identification + // ===================================================================== + const uint32_t sm_idx = blockIdx.x; + const uint32_t thread_idx = threadIdx.x; + const uint32_t warp_idx = cutlass::canonical_warp_idx_sync(); + const uint32_t lane_idx = ptx::get_lane_idx(); + + // Prefetch all TMA descriptors at the very beginning + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_l1_acts); + cute::prefetch_tma_descriptor(&tensor_map_l1_acts_sf); + cute::prefetch_tma_descriptor(&tensor_map_l1_weights); + cute::prefetch_tma_descriptor(&tensor_map_l1_output); + cute::prefetch_tma_descriptor(&tensor_map_l2_acts); + cute::prefetch_tma_descriptor(&tensor_map_l2_acts_sf); + cute::prefetch_tma_descriptor(&tensor_map_l2_weights); + } + + // ===================================================================== + // Workspaces and symmetric buffer slicing (mirror SM100 layout, except SF + // for L2 activations uses per-64 K granularity) + // ===================================================================== + const auto workspace = layout::Workspace( + sym_buffer.get_base_ptr(), kNumRanks, kNumExperts, kNumMaxTokensPerRank, kNumTopk); + + constexpr auto fp8_token_layout = layout::Data(kHidden); + constexpr auto bf16_token_layout = layout::Data(kHidden * sizeof(nv_bfloat16)); + constexpr auto fp8_intermediate_token_layout = layout::Data(kIntermediateHidden); + // 
Per-128 K float SF: 4 bytes per per-128 group => `kHidden / 32` bytes/token (same as SM100 packing) + constexpr auto fp8_sf_layout = layout::Data(kHidden / 32); + // Per-64 K float SF (SM90 only): 4 bytes per per-64 group => `kIntermediateHidden / 16` bytes/token + constexpr auto fp8_intermediate_sf_layout = layout::Data(kIntermediateHidden / 16); + constexpr auto input_topk_idx_layout = layout::Data(kNumTopk * sizeof(int64_t), false); + constexpr auto input_topk_weights_layout = layout::Data(kNumTopk * sizeof(float), false); + constexpr auto l1_topk_weights_layout = layout::Data(sizeof(float), false); + + // Registered input area + const auto input_token_buffer = layout::Buffer(fp8_token_layout, 1, kNumMaxTokensPerRank, workspace.get_end_ptr()); + const auto input_sf_buffer = layout::Buffer(fp8_sf_layout, 1, kNumMaxTokensPerRank, input_token_buffer.get_end_ptr()); + const auto input_topk_idx_buffer = layout::Buffer(input_topk_idx_layout, 1, kNumMaxTokensPerRank, input_sf_buffer.get_end_ptr()); + const auto input_topk_weights_buffer = layout::Buffer(input_topk_weights_layout, 1, kNumMaxTokensPerRank, input_topk_idx_buffer.get_end_ptr()); + + // L1 input area + const auto l1_token_buffer = layout::Buffer(fp8_token_layout, 1, kNumMaxPoolTokens, input_topk_weights_buffer.get_end_ptr()); + const auto l1_sf_buffer = layout::Buffer(fp8_sf_layout, 1, kNumPaddedSFPoolTokens, l1_token_buffer.get_end_ptr()); + const auto l1_topk_weights_buffer = layout::Buffer(l1_topk_weights_layout, 1, kNumMaxPoolTokens, l1_sf_buffer.get_end_ptr()); + + // L2 input area + const auto l2_token_buffer = layout::Buffer(fp8_intermediate_token_layout, 1, kNumMaxPoolTokens, l1_topk_weights_buffer.get_end_ptr()); + const auto l2_sf_buffer = layout::Buffer(fp8_intermediate_sf_layout, 1, kNumPaddedSFPoolTokens, l2_token_buffer.get_end_ptr()); + + // Combine input area + const auto combine_token_buffer = layout::Buffer(bf16_token_layout, kNumTopk, kNumMaxTokensPerRank, l2_sf_buffer.get_end_ptr()); + + 
// ===================================================================== + // GEMM data types and shape constants + // ===================================================================== + using a_dtype_t = cutlass::float_e4m3_t; + using b_dtype_t = cutlass::float_e4m3_t; + constexpr bool kUseMMASync = (BLOCK_M == 16 or BLOCK_M == 32); + constexpr bool kSplitNWarpgroups = + (!kUseMMASync) && BLOCK_N == 256 && kNumEpilogueWarpgroups == 2; + constexpr bool kSerialNWarpgroups = false; + constexpr bool kWideNWarpgroups = + (!kUseMMASync) && BLOCK_N == 256 && kNumEpilogueWarpgroups == 1; + constexpr uint32_t WG_BLOCK_M = kSplitNWarpgroups ? BLOCK_M : BLOCK_M / kNumEpilogueWarpgroups; + constexpr uint32_t WG_BLOCK_N = (kSplitNWarpgroups || kSerialNWarpgroups) ? BLOCK_N / 2 : BLOCK_N; + constexpr uint32_t L1_OUT_BLOCK_N = BLOCK_N / 2; // post-SwiGLU tile N + constexpr uint32_t WG_L1_OUT_BLOCK_N = WG_BLOCK_N / 2; // post-SwiGLU per-WG N + constexpr bool kAsyncL1TMAStore = + kAsyncL1TMAStoreRequested && (!kUseMMASync) && (!kSplitNWarpgroups) && + kNumEpilogueWarpgroups == 1; + constexpr bool kSplitSFATMA = kSplitSFATMARequested && (!kUseMMASync); + constexpr bool kDirectL2Scatter = kDirectL2ScatterRequested && (!kUseMMASync) && + (!kSplitNWarpgroups) && (!kSerialNWarpgroups); + constexpr bool kL2DualAccum = kL2DualAccumRequested && (!kUseMMASync) && + (!kSplitNWarpgroups) && (!kSerialNWarpgroups) && WG_BLOCK_N == 128; + constexpr bool kL1DualKAccum = kL1DualKAccumRequested && (!kUseMMASync) && + (!kSplitNWarpgroups) && (!kSerialNWarpgroups) && WG_BLOCK_N == 128 && + (kHidden / BLOCK_K) % 2 == 0; + using L1WGMMA = typename mma::sm90::FP8MMASelector::type; + using L2WGMMA = typename mma::sm90::FP8MMASelector::type; + static_assert(L1WGMMA::M == 64 and L1WGMMA::N == WG_BLOCK_N and L1WGMMA::K == 32, + "Unexpected WGMMA shape"); + DG_STATIC_ASSERT((!kSplitNWarpgroups) or (BLOCK_M == 64 and WG_BLOCK_N == 128), + "Split-N path expects two M64N128 WGMMA consumers"); + + // A is 
always CTA-local. When kClusterSize=2 the scheduler pairs adjacent + // M blocks with identical expert/N/K coordinates so the B TMA can multicast. + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M; + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N; + constexpr uint32_t kSwizzleAMode = kUseMMASync ? 0 : BLOCK_K * sizeof(a_dtype_t); // 0 or 128 + constexpr uint32_t kSwizzleBMode = kUseMMASync ? 0 : BLOCK_K * sizeof(b_dtype_t); // 0 or 128 + constexpr uint32_t kSwizzleCDMode = 128; + constexpr uint32_t kGranK = 128; // L1 acts SF, weights SF + constexpr uint32_t kL2ActsSFGranK = 64; // L2 acts SF (per-64 K, SM90 only) + + // ===================================================================== + // Shared memory layout + // ===================================================================== + constexpr uint32_t kSharedMemoryAlignment = 1024; + extern __shared__ __align__(kSharedMemoryAlignment) uint8_t smem_buffer[]; + + constexpr uint32_t SMEM_EXPERT_COUNT_SIZE = + math::constexpr_align(kNumExperts * sizeof(uint32_t), kSharedMemoryAlignment); + constexpr uint32_t SMEM_SEND_BUFFER_SIZE = + math::constexpr_align(fp8_token_layout.get_num_bytes() * kNumDispatchWarps, kSharedMemoryAlignment); + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t); + // SFA per-stage must be sized for the larger of L1 (BLOCK_M floats) and L2 + // (two per-64-K halves). Each TMA destination must be 128B aligned, so + // the second L2 half cannot start immediately after 16 floats in M16 decode. + constexpr uint32_t kL2SFAHalfStride = + math::constexpr_align(BLOCK_M * sizeof(float), 128u) / sizeof(float); + constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = 2 * kL2SFAHalfStride * sizeof(float); + // Block (128, 128) weight SF: 1 float per (BLOCK_N, BLOCK_K) tile for L2, + // 2 floats (gate/up) for L1. Loaded by math warpgroup directly from global, + // so no SMEM is needed. 
+ constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = 0; + + // CD output: max of L1 FP8 (BLOCK_M * (BLOCK_N/2) * 1 byte * num_wg) and + // L2 BF16 (BLOCK_M * BLOCK_N * 2 bytes * num_wg). + constexpr uint32_t SMEM_CD_ACCUM_SIZE = kUseMMASync + ? math::constexpr_align(BLOCK_M * BLOCK_N * sizeof(float), kSharedMemoryAlignment) + : 0u; + constexpr uint32_t SMEM_CD_L1_SIZE = + kNumEpilogueWarpgroups * WG_BLOCK_M * WG_L1_OUT_BLOCK_N * sizeof(cutlass::float_e4m3_t); + constexpr uint32_t SMEM_CD_L2_SIZE = kDirectL2Scatter ? 0u : + kNumEpilogueWarpgroups * WG_BLOCK_M * WG_BLOCK_N * sizeof(nv_bfloat16); + constexpr uint32_t SMEM_CD_L1_ASYNC_SIZE = kAsyncL1TMAStore ? + 2 * kNumEpilogueWarpgroups * WG_BLOCK_M * L1_OUT_BLOCK_N * sizeof(cutlass::float_e4m3_t) : 0u; + constexpr uint32_t SMEM_CD_OUTPUT_BASE_SIZE = + SMEM_CD_L1_SIZE > SMEM_CD_L2_SIZE ? SMEM_CD_L1_SIZE : SMEM_CD_L2_SIZE; + constexpr uint32_t SMEM_CD_OUTPUT_UNALIGNED_SIZE = + SMEM_CD_OUTPUT_BASE_SIZE > SMEM_CD_L1_ASYNC_SIZE ? SMEM_CD_OUTPUT_BASE_SIZE : SMEM_CD_L1_ASYNC_SIZE; + constexpr uint32_t SMEM_CD_OUTPUT_SIZE = math::constexpr_align( + SMEM_CD_OUTPUT_UNALIGNED_SIZE, kSharedMemoryAlignment); + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_ACCUM_SIZE + SMEM_CD_OUTPUT_SIZE; + + constexpr uint32_t SMEM_BEFORE_BARRIER_SIZE = + SMEM_EXPERT_COUNT_SIZE + SMEM_SEND_BUFFER_SIZE + SMEM_CD_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + + // SMEM pointers + auto smem_expert_count = reinterpret_cast(smem_buffer); + const auto smem_send_buffers = layout::Buffer( + fp8_token_layout, kNumDispatchWarps, 1, + math::advance_ptr(smem_buffer, SMEM_EXPERT_COUNT_SIZE)); + + auto smem_gemm_base = math::advance_ptr( + smem_buffer, SMEM_EXPERT_COUNT_SIZE + SMEM_SEND_BUFFER_SIZE); + + // mma.sync decode stages FP32 accumulators through SMEM so the epilogue can + // use logical row/column indices instead of WGMMA accumulator layout. 
+ auto smem_accum_f32 = reinterpret_cast(smem_gemm_base); + auto smem_cd_base = math::advance_ptr(smem_gemm_base, SMEM_CD_ACCUM_SIZE); + // CD output is shared by L1 (FP8) and L2 (BF16); reinterpret-cast as needed. + auto smem_cd_l1 = reinterpret_cast(smem_cd_base); + auto smem_cd_l2 = reinterpret_cast(smem_cd_base); + + auto smem_a = utils::PatternVisitor([=](const uint32_t& i) { + return math::advance_ptr(smem_gemm_base, SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = utils::PatternVisitor([=](const uint32_t& i) { + return math::advance_ptr(smem_gemm_base, SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + auto sf_start_ptr = math::advance_ptr(smem_gemm_base, + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto smem_sfa = utils::PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE); + }); + + // Barriers live after SF (SFB is loaded directly from global, no SMEM) + auto barrier_start_ptr = reinterpret_cast( + sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE); + auto dispatch_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + i; }); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + kNumDispatchWarps + i; }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + kNumDispatchWarps + kNumStages + i; }); + auto combine_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + kNumDispatchWarps + kNumStages * 2 + i; }); + + // ===================================================================== + // Initialization + // ===================================================================== + if (warp_idx == 0) { + // Clean expert-count shared memory + #pragma unroll + for (uint32_t i = lane_idx; i < kNumExperts; i += 32) + ptx::st_shared(smem_expert_count + i, 0u); + } else if 
(warp_idx == 1) { + // Init dispatch m-barriers + #pragma unroll + for (uint32_t i = lane_idx; i < kNumDispatchWarps; i += 32) + dispatch_barriers[i]->init(1); + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Init GEMM full/empty barriers and combine barriers + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + // Producer arrivals: A(+SFA) + B, or A + B + SFA when + // the split-SFA experiment uses an otherwise idle TMA warp. + full_barriers[i]->init(kSplitSFATMA ? 3 : 2); + // With cluster multicast the leader CTA's TMA warp waits on peer + // empty barriers too, so every math warp releases both CTAs. + empty_barriers[i]->init(kClusterSize * kNumEpilogueWarps); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueWarps * 2; ++ i) + combine_barriers[i]->init(1); + } + cutlass::arch::fence_barrier_init(); + } + if constexpr (kClusterSize > 1) { + cute::cluster_sync(); + } else { + __syncthreads(); + } + + // ===================================================================== + // Scheduler (cluster=1) + // ===================================================================== + auto scheduler = sched::MegaMoEScheduler< + BLOCK_M, BLOCK_N, BLOCK_K, + L1_SHAPE_N, L1_SHAPE_K, + L2_SHAPE_N, L2_SHAPE_K, + kNumExpertsPerRank, kNumExpertsPerWave, + kNumSMs, kNumRanks, kClusterSize, kL2NMajorScheduleRequested, kL1NMajorScheduleRequested>(workspace); + + // Pipeline state shared by TMA loaders and math warpgroups + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + stage_idx = stage_idx == kNumStages - 1 ? 
0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + // Intra-SM barrier indices (mirroring SM100) + constexpr uint32_t kDispatchBarrierIdx = 0; + constexpr uint32_t kDispatchWithEpilogueBarrierIdx = 1; + constexpr uint32_t kEpilogueFullBarrierIdx = 2; + constexpr uint32_t kEpilogueWGBarrierStartIdx = 3; + + // Cross-rank NVLink barrier tags + constexpr uint32_t kBeforeDispatchPullBarrierTag = 1; + constexpr uint32_t kBeforeCombineReduceBarrierTag = 2; + constexpr uint32_t kAfterWorkspaceCleanBarrierTag = 3; + + // Register reconfiguration counts (chosen to fit in 64512 reg budget). + // For the 256-epilogue-thread case (block_m=128, 2 math WGs): + // 128*48 + 128*40 + 256*208 = 64512 exactly. + constexpr uint32_t kNumDispatchRegisters = 48; + constexpr bool kCompactFrontendWarpgroup = (kNumDispatchWarps == 2 and kNumMMANonEpilogueWarps == 2); + constexpr uint32_t kNumNonEpilogueRegisters = kCompactFrontendWarpgroup ? kNumDispatchRegisters : 40; + constexpr uint32_t kNumEpilogueRegisters = (kSerialNWarpgroups or kWideNWarpgroups) ? 256 : ((kUseMMASync and BLOCK_M == 32) ? 
240 : 208); + DG_STATIC_ASSERT(kNumDispatchRegisters * kNumDispatchThreads + + kNumNonEpilogueRegisters * kNumNonEpilogueThreads + + kNumEpilogueRegisters * kNumEpilogueThreads <= 64512, + "Too many registers"); + + constexpr uint32_t kDispatchGridSyncIndex = 0; + constexpr uint32_t kEpilogueGridSyncIndex = 1; + + constexpr uint32_t kProfileDispatchTotal = 0; + constexpr uint32_t kProfileDispatchPull = 1; + constexpr uint32_t kProfileMathLoop = 2; + constexpr uint32_t kProfileCombineBarrier = 3; + constexpr uint32_t kProfileCombineReduce = 4; + constexpr uint32_t kProfileGemmCore = 5; + constexpr uint32_t kProfileL1Epilogue = 6; + constexpr uint32_t kProfileL2Epilogue = 7; + const auto phase_profile_clock = [&]() -> unsigned long long { + if constexpr (kPhaseProfileRequested) { + unsigned long long t; + asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(t)); + return t; + } else { + return 0ull; + } + }; + const auto phase_profile_record = [&](const uint32_t& metric, const unsigned long long& cycles) { + if constexpr (kPhaseProfileRequested) { + if (cumulative_local_expert_recv_stats != nullptr and cycles > 0) { + auto profile = reinterpret_cast( + cumulative_local_expert_recv_stats + kNumExpertsPerRank); + atomicAdd(profile + metric, cycles); + atomicMax(profile + 8 + metric, cycles); + atomicAdd(profile + 16 + metric, 1ull); + } + } + }; + + // ===================================================================== + // ROLE 1: DISPATCH WARPS + // Mirrors SM100 dispatch with two changes: + // * SF is per-128 channel float (no UTCCP transpose). We store the + // remote per-token SF directly into the local L1 SF buffer in + // MN-major layout: `local_sf[k_chunk * num_padded_sf_pool_tokens + token_idx]`. + // * The "token_idx_in_expert" → SF token index is now the simple + // per-block linear mapping (no 4×32 transpose). 
+ // ===================================================================== + if (warp_idx < kNumDispatchWarps) { + cutlass::arch::warpgroup_reg_dealloc(); + const unsigned long long dispatch_total_start = phase_profile_clock(); + + DG_STATIC_ASSERT(kNumTopk <= 32, "Invalid number of topk"); + constexpr uint32_t kNumActivateLanes = kNumTokensPerWarp * kNumTopk; + const auto read_topk_idx = [&](const auto& process) { + #pragma unroll + for (uint32_t i = (sm_idx * kNumDispatchWarps + warp_idx) * kNumTokensPerWarp; + i < num_tokens; + i += kNumSMs * kNumDispatchWarps * kNumTokensPerWarp) { + int expert_idx = -1; + if (i + (lane_idx / kNumTopk) < num_tokens and lane_idx < kNumActivateLanes) { + expert_idx = static_cast( + __ldg(input_topk_idx_buffer.get_base_ptr() + i * kNumTopk + lane_idx)); + if (expert_idx >= 0) + process(i * kNumTopk + lane_idx, expert_idx); + } + __syncwarp(); + } + }; + + // Count tokens per expert + read_topk_idx([&](const uint32_t& token_topk_idx, const int& expert_idx) { + atomicAdd_block(smem_expert_count + expert_idx, 1); + }); + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + // Stake out per-expert SM offsets via global atomic + #pragma unroll + for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) { + const uint64_t send_value = (1ull << 32) | static_cast(smem_expert_count[i]); + smem_expert_count[i] = static_cast( + ptx::atomic_add(workspace.get_expert_send_count_ptr(i), send_value)); + } + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + // Write source token-topk indices to remote ranks + read_topk_idx([&](const uint32_t& token_topk_idx, const int& expert_idx) { + const auto dst_rank_idx = expert_idx / kNumExpertsPerRank; + const auto dst_slot_idx = atomicAdd_block(smem_expert_count + expert_idx, 1); + const auto dst_ptr = workspace.get_src_token_topk_idx_ptr( + expert_idx % kNumExpertsPerRank, sym_buffer.rank_idx, dst_slot_idx); + *sym_buffer.map(dst_ptr, dst_rank_idx) = 
token_topk_idx; + }); + + comm::grid_sync( + workspace, sm_idx, thread_idx, + [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); } + ); + + if (sm_idx == 0) { + #pragma unroll + for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) { + const auto dst_rank_idx = i / kNumExpertsPerRank; + const auto dst_local_expert_idx = i % kNumExpertsPerRank; + const auto expert_status = *workspace.get_expert_send_count_ptr(i); + *sym_buffer.map( + workspace.get_expert_recv_count_ptr(sym_buffer.rank_idx, dst_local_expert_idx), + dst_rank_idx) = expert_status & 0xffffffff; + ptx::atomic_add_sys( + sym_buffer.map(workspace.get_expert_recv_count_sum_ptr(dst_local_expert_idx), dst_rank_idx), + expert_status); + } + } + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + comm::nvlink_barrier( + workspace, sym_buffer, sm_idx, thread_idx, + [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); }, + false, true); + + // Sync with epilogue warps before pulling tokens + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + const unsigned long long dispatch_pull_start = phase_profile_clock(); + + // Token / SF pull loop + uint32_t pull_mbarrier_phase = 0; + const auto pull_buffer = smem_send_buffers.get_rank_buffer(warp_idx).get_data_buffer(0); + const auto pull_mbarrier = dispatch_barriers[warp_idx]; + + scheduler.fetch_expert_recv_count(); + + constexpr uint32_t kNumRanksPerLane = math::constexpr_ceil_div(kNumRanks, 32u); + int current_expert_idx = -1; + uint32_t stored_rank_count[kNumRanksPerLane] = {}; + uint32_t expert_start_idx = 0, expert_end_idx = 0; + uint32_t expert_pool_block_offset = 0; + + constexpr uint32_t kNumGlobalWarps = kNumSMs * kNumDispatchWarps; + for (uint32_t token_idx = sm_idx * kNumDispatchWarps + warp_idx; ; token_idx += kNumGlobalWarps) { + int old_expert_idx = current_expert_idx; + while (token_idx >= expert_end_idx) { + if (++ current_expert_idx >= 
kNumExpertsPerRank) + break; + expert_pool_block_offset += math::ceil_div(expert_end_idx - expert_start_idx, BLOCK_M); + expert_start_idx = expert_end_idx; + expert_end_idx += scheduler.get_num_tokens(current_expert_idx); + } + if (current_expert_idx >= kNumExpertsPerRank) + break; + + if (old_expert_idx != current_expert_idx) { + old_expert_idx = current_expert_idx; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) { + const uint32_t j = i * 32 + lane_idx; + stored_rank_count[i] = j < kNumRanks ? + static_cast(*workspace.get_expert_recv_count_ptr(j, current_expert_idx)) : 0; + } + } + + // Round-robin rank selection (identical to SM100) + uint32_t current_rank_in_expert_idx; + uint32_t remaining[kNumRanksPerLane]; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) + remaining[i] = stored_rank_count[i]; + uint32_t offset = 0; + uint32_t token_idx_in_expert = token_idx - expert_start_idx; + uint32_t slot_idx = token_idx_in_expert; + uint32_t token_idx_in_rank; + while (true) { + uint32_t num_actives_in_lane = 0; + uint32_t min_in_lane = 0xffffffff; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) { + num_actives_in_lane += remaining[i] > 0; + if (remaining[i] > 0) + min_in_lane = cute::min(min_in_lane, remaining[i]); + } + const uint32_t num_active_ranks = __reduce_add_sync(0xffffffff, num_actives_in_lane); + const uint32_t length = __reduce_min_sync(0xffffffff, min_in_lane); + + const uint32_t num_round_tokens = length * num_active_ranks; + if (slot_idx < num_round_tokens) { + const uint32_t slot_idx_in_round = slot_idx % num_active_ranks; + uint32_t num_seen_ranks = 0; + current_rank_in_expert_idx = 0; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) { + const uint32_t mask = __ballot_sync(0xffffffff, remaining[i] > 0); + const uint32_t num_active_lanes = __popc(mask); + if (slot_idx_in_round >= num_seen_ranks and slot_idx_in_round < num_seen_ranks + num_active_lanes) + 
current_rank_in_expert_idx = i * 32 + __fns(mask, 0, slot_idx_in_round - num_seen_ranks + 1); + num_seen_ranks += num_active_lanes; + } + token_idx_in_rank = offset + (slot_idx / num_active_ranks); + break; + } + slot_idx -= num_round_tokens; + offset += length; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) + remaining[i] -= cute::min(remaining[i], length); + } + + const uint32_t src_token_topk_idx = *workspace.get_src_token_topk_idx_ptr( + current_expert_idx, current_rank_in_expert_idx, token_idx_in_rank); + const uint32_t src_token_idx = src_token_topk_idx / kNumTopk; + const uint32_t src_topk_idx = src_token_topk_idx % kNumTopk; + + // TMA pull token data into SMEM + if (cute::elect_one_sync()) { + ptx::tma_load_1d( + pull_buffer.get_base_ptr(), + sym_buffer.map(input_token_buffer.get_data_buffer(src_token_idx).get_base_ptr(), + current_rank_in_expert_idx), + pull_mbarrier, kHidden); + } + __syncwarp(); + + // Copy SF: per-128 K floats, written linearly (no UTCCP transpose). 
+ constexpr uint32_t kNumSFFloats = kHidden / 128; + DG_STATIC_ASSERT(kNumSFFloats > 0 and kHidden % 128 == 0, "Invalid SF"); + const auto remote_sf_ptr = sym_buffer.map( + input_sf_buffer.get_data_buffer(src_token_idx).get_base_ptr(), + current_rank_in_expert_idx); + const auto local_sf_ptr = l1_sf_buffer.get_base_ptr(); + const uint32_t sf_pool_token_idx = expert_pool_block_offset * BLOCK_M + token_idx_in_expert; + #pragma unroll + for (uint32_t i = 0; i < math::constexpr_ceil_div(kNumSFFloats, 32u); ++ i) { + const uint32_t j = i * 32 + lane_idx; + if (j < kNumSFFloats) + local_sf_ptr[j * kNumPaddedSFPoolTokens + sf_pool_token_idx] = remote_sf_ptr[j]; + } + __syncwarp(); + + const uint32_t pool_token_idx = expert_pool_block_offset * BLOCK_M + token_idx_in_expert; + if (cute::elect_one_sync()) { + const auto weight = *sym_buffer.map( + input_topk_weights_buffer.get_base_ptr() + src_token_topk_idx, + current_rank_in_expert_idx); + *l1_topk_weights_buffer.get_data_buffer(pool_token_idx).get_base_ptr() = weight; + + ptx::mbarrier_arrive_and_set_tx(pull_mbarrier, kHidden); + ptx::mbarrier_wait_and_flip_phase(pull_mbarrier, pull_mbarrier_phase); + + ptx::tma_store_1d( + l1_token_buffer.get_data_buffer(pool_token_idx).get_base_ptr(), + pull_buffer.get_base_ptr(), pull_buffer.get_num_bytes()); + + *workspace.get_token_src_metadata_ptr(pool_token_idx) = + {current_rank_in_expert_idx, src_token_idx, src_topk_idx}; + + cute::tma_store_arrive(); + ptx::tma_store_wait<0>(); + ptx::red_add_rel( + workspace.get_l1_arrival_count_ptr(expert_pool_block_offset + token_idx_in_expert / BLOCK_M), 1); + } + __syncwarp(); + } + + + + // Cleanup workspace, overlapping with combine + const unsigned long long dispatch_pull_end = phase_profile_clock(); + if (lane_idx == 0) { + phase_profile_record(kProfileDispatchPull, dispatch_pull_end - dispatch_pull_start); + phase_profile_record(kProfileDispatchTotal, dispatch_pull_end - dispatch_total_start); + } + 
ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + + DG_STATIC_ASSERT(kNumSMs > 1, "Invalid SM count"); + if (sm_idx == 0) { + #pragma unroll + for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) + *workspace.get_expert_send_count_ptr(i) = 0; + } else { + for (uint32_t i = sm_idx - 1; i < kNumExpertsPerRank; i += kNumSMs - 1) { + const auto num_recv_tokens = static_cast( + *workspace.get_expert_recv_count_sum_ptr(i)); + const auto num_recv_m_blocks = math::ceil_div(num_recv_tokens, BLOCK_M); + + expert_pool_block_offset = scheduler.get_pool_block_offset(i); + + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + DG_STATIC_ASSERT(kNumDispatchWarps >= 2, "Not enough dispatch warps"); + if (warp_idx == 0) { + *workspace.get_expert_recv_count_sum_ptr(i) = 0; + } else if (warp_idx == 1) { + if (cute::elect_one_sync() and cumulative_local_expert_recv_stats != nullptr) + ptx::red_add(cumulative_local_expert_recv_stats + i, static_cast(num_recv_tokens)); + __syncwarp(); + } + + for (uint32_t j = thread_idx; j < kNumRanks; j += kNumDispatchThreads) + *workspace.get_expert_recv_count_ptr(j, i) = 0; + __syncwarp(); + + for (uint32_t j = thread_idx; j < num_recv_m_blocks; j += kNumDispatchThreads) { + *workspace.get_l1_arrival_count_ptr(expert_pool_block_offset + j) = 0; + *workspace.get_l2_arrival_mask_ptr(expert_pool_block_offset + j) = 0; + } + __syncwarp(); + } + } + + comm::nvlink_barrier( + workspace, sym_buffer, sm_idx, thread_idx, + [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); }, + true, false); + + // ===================================================================== + // ROLE 2: GEMM TMA LOAD warps (load A+SFA, B+SFB) + // Default: 4 non-epilogue warps, two active and two idle. + // Compact frontend experiment: 2 dispatch warps + 2 TMA warps share the + // first warpgroup, reducing total CTA threads for the M128/2WG path. 
+ // ===================================================================== + } else if (warp_idx == kNumDispatchWarps) { + cutlass::arch::warpgroup_reg_dealloc(); + + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + const auto tensor_map_a_ptr = block_phase == sched::BlockPhase::Linear2 + ? &tensor_map_l2_acts : &tensor_map_l1_acts; + const auto tensor_map_sfa_ptr = block_phase == sched::BlockPhase::Linear2 + ? &tensor_map_l2_acts_sf : &tensor_map_l1_acts_sf; + + const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx; + const uint32_t valid_m = scheduler.template get_valid_m(); + const bool has_valid_m = valid_m > 0; + + // Wait for the pool to be ready. Cluster peers can be dummy CTAs for + // the tail M unit when an expert has an odd number of M blocks. + if (has_valid_m) { + if (block_phase == sched::BlockPhase::Linear1) { + const auto ptr = workspace.get_l1_arrival_count_ptr(pool_block_idx); + const auto expected = valid_m; + while (ptx::ld_acq(ptr) != expected); + } else { + const auto ptr = workspace.get_l2_arrival_mask_ptr(pool_block_idx); + // Each L1 N block sets one bit; total bits = L1_SHAPE_N / BLOCK_N. + constexpr uint32_t kNumL1BlockNs = L1_SHAPE_N / BLOCK_N; + const uint64_t expected = (kNumL1BlockNs >= 64) + ? 
~0ull : ((1ull << kNumL1BlockNs) - 1ull); + while (ptx::ld_acq_gpu(ptr) != expected); + } + } + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + empty_barriers[stage_idx]->wait(phase ^ 1); + + if (cute::elect_one_sync()) { + if (has_valid_m) { + const uint32_t m_idx = pool_block_idx * BLOCK_M; + const uint32_t k_idx = k_block_idx * BLOCK_K; + + // TMA load A + tma::copy( + tensor_map_a_ptr, full_barriers[stage_idx], smem_a[stage_idx], + k_idx, m_idx, 1); + + if constexpr (kSplitSFATMA) { + full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE); + } else { + // TMA load SFA + if (block_phase == sched::BlockPhase::Linear1) { + // L1 SFA per-128: load (BLOCK_M, 1) at K=k_block_idx + tma::copy( + tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], + m_idx, k_block_idx, 1); + full_barriers[stage_idx]->arrive_and_expect_tx( + SMEM_A_SIZE_PER_STAGE + BLOCK_M * sizeof(float)); + } else { + // L2 SFA per-64: descriptor box is (block_mn, 1) (see make_tma_sf_desc), + // so we must issue two single-group TMAs and place them at smem offsets + // 0 and BLOCK_M to match math's load offsets (`+ 0 * BLOCK_M` / `+ 1 * BLOCK_M`). 
+ tma::copy( + tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], + m_idx, k_block_idx * 2, 1); + tma::copy( + tensor_map_sfa_ptr, full_barriers[stage_idx], + smem_sfa[stage_idx] + kL2SFAHalfStride, + m_idx, k_block_idx * 2 + 1, 1); + full_barriers[stage_idx]->arrive_and_expect_tx( + SMEM_A_SIZE_PER_STAGE + 2 * BLOCK_M * sizeof(float)); + } + } + } else { + full_barriers[stage_idx]->arrive(); + } + } + __syncwarp(); + } + }); + + } else if (warp_idx == kNumDispatchWarps + 1) { + cutlass::arch::warpgroup_reg_dealloc(); + + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + const auto tensor_map_b_ptr = + block_phase == sched::BlockPhase::Linear2 ? &tensor_map_l2_weights : &tensor_map_l1_weights; + + const uint32_t shape_n = block_phase == sched::BlockPhase::Linear2 ? L2_SHAPE_N : L1_SHAPE_N; + + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + empty_barriers[stage_idx]->wait(phase ^ 1); + + if (cute::elect_one_sync()) { + const uint32_t n_idx = local_expert_idx * shape_n + n_block_idx * BLOCK_N; + const uint32_t k_idx = k_block_idx * BLOCK_K; + + // TMA load B (weight SF is now loaded directly by math warps from global) + tma::copy( + tensor_map_b_ptr, full_barriers[stage_idx], smem_b[stage_idx], + k_idx, n_idx, kClusterSize); + + full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_B_SIZE_PER_STAGE); + } + __syncwarp(); + } + }); + + } else if (kSplitSFATMA && warp_idx == kNumDispatchWarps + 2) { + cutlass::arch::warpgroup_reg_dealloc(); + + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + (void)local_expert_idx; + (void)n_block_idx; + const auto tensor_map_sfa_ptr = block_phase == sched::BlockPhase::Linear2 + ? 
&tensor_map_l2_acts_sf : &tensor_map_l1_acts_sf; + + const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx; + const uint32_t valid_m = scheduler.template get_valid_m(); + const bool has_valid_m = valid_m > 0; + + if (has_valid_m) { + if (block_phase == sched::BlockPhase::Linear1) { + const auto ptr = workspace.get_l1_arrival_count_ptr(pool_block_idx); + const auto expected = valid_m; + while (ptx::ld_acq(ptr) != expected); + } else { + const auto ptr = workspace.get_l2_arrival_mask_ptr(pool_block_idx); + constexpr uint32_t kNumL1BlockNs = L1_SHAPE_N / BLOCK_N; + const uint64_t expected = (kNumL1BlockNs >= 64) + ? ~0ull : ((1ull << kNumL1BlockNs) - 1ull); + while (ptx::ld_acq_gpu(ptr) != expected); + } + } + + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + empty_barriers[stage_idx]->wait(phase ^ 1); + + if (cute::elect_one_sync()) { + if (has_valid_m) { + const uint32_t m_idx = pool_block_idx * BLOCK_M; + + if (block_phase == sched::BlockPhase::Linear1) { + tma::copy( + tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], + m_idx, k_block_idx, 1); + full_barriers[stage_idx]->arrive_and_expect_tx(BLOCK_M * sizeof(float)); + } else { + tma::copy( + tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], + m_idx, k_block_idx * 2, 1); + tma::copy( + tensor_map_sfa_ptr, full_barriers[stage_idx], + smem_sfa[stage_idx] + kL2SFAHalfStride, + m_idx, k_block_idx * 2 + 1, 1); + full_barriers[stage_idx]->arrive_and_expect_tx(2 * BLOCK_M * sizeof(float)); + } + } else { + full_barriers[stage_idx]->arrive(); + } + } + __syncwarp(); + } + }); + + } else if (warp_idx < kNumDispatchWarps + kNumMMANonEpilogueWarps) { + // Idle non-epilogue warps (kNumDispatchWarps+2, +3). They must still + // participate in the warpgroup-collective `setmaxnreg.dec.sync.aligned` + // so that the math warpgroup's `warpgroup_reg_alloc` can succeed. 
+ cutlass::arch::warpgroup_reg_dealloc(); + + } else if (warp_idx >= kNumDispatchWarps + kNumMMANonEpilogueWarps) { + // ===================================================================== + // ROLE 3: MATH WARPGROUPS (WGMMA + epilogue + combine) + // ===================================================================== + cutlass::arch::warpgroup_reg_alloc(); + + const uint32_t epilogue_warp_idx = warp_idx - (kNumDispatchWarps + kNumMMANonEpilogueWarps); + const uint32_t epilogue_wg_idx = epilogue_warp_idx / 4; + const uint32_t epilogue_thread_idx = epilogue_warp_idx * 32 + lane_idx; + const uint32_t warp_idx_in_wg = epilogue_warp_idx % 4; + + uint32_t async_l1_store_stage = 0; + bool async_l1_store_pending[2] = {false, false}; + uint32_t async_l1_store_pool[2] = {0, 0}; + uint32_t async_l1_store_n[2] = {0, 0}; + + const auto arrive_empty_barrier = [&](const uint32_t& s) { + if constexpr (kClusterSize == 1) { + if (lane_idx == 0) + empty_barriers[s]->arrive(); + } else { + if (lane_idx < kClusterSize) + empty_barriers[s]->arrive(lane_idx); + } + }; + + const auto notify_l1_ready = [&](const uint32_t& ready_pool_block_idx, + const uint32_t& ready_n_block_idx) { + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + ptx::red_or_rel_gpu( + workspace.get_l2_arrival_mask_ptr(ready_pool_block_idx), + 1ull << ready_n_block_idx); + } + __syncwarp(); + }; + + const auto drain_async_l1_store_stage = [&](const uint32_t& store_stage) { + if constexpr (kAsyncL1TMAStore) { + if (async_l1_store_pending[store_stage]) { + // Two SMEM L1 store buffers are used in FIFO order; waiting + // for <=1 outstanding store makes the older buffer reusable. 
+ ptx::tma_store_wait<1>(); + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + notify_l1_ready(async_l1_store_pool[store_stage], + async_l1_store_n[store_stage]); + async_l1_store_pending[store_stage] = false; + } + } + }; + + const auto drain_all_async_l1_stores = [&]() { + if constexpr (kAsyncL1TMAStore) { + if (async_l1_store_pending[0] or async_l1_store_pending[1]) { + ptx::tma_store_wait<0>(); + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + if (async_l1_store_pending[0]) { + notify_l1_ready(async_l1_store_pool[0], async_l1_store_n[0]); + async_l1_store_pending[0] = false; + } + if (async_l1_store_pending[1]) { + notify_l1_ready(async_l1_store_pool[1], async_l1_store_n[1]); + async_l1_store_pending[1] = false; + } + } + } + }; + + // WGMMA-output register layout helpers + const uint32_t row_idx = lane_idx / 4; + const uint32_t col_idx = lane_idx % 4; + const uint32_t r_0 = warp_idx_in_wg * 16 + row_idx; + const uint32_t r_1 = r_0 + 8; + + DG_STATIC_ASSERT(kSplitNWarpgroups || (BLOCK_M % kNumEpilogueWarpgroups == 0), "Invalid block M"); + if constexpr (kUseMMASync) { + DG_STATIC_ASSERT(WG_BLOCK_M == BLOCK_M, "mma.sync decode path uses one M tile per warpgroup"); + DG_STATIC_ASSERT(kNumEpilogueWarpgroups == 1, "mma.sync decode path currently uses one math warpgroup"); + } else if constexpr (kSplitNWarpgroups) { + DG_STATIC_ASSERT(WG_BLOCK_M == L1WGMMA::M and WG_BLOCK_N == L1WGMMA::N, + "Split-N WGs must each run one M64N128 WGMMA per K-block"); + } else if constexpr (kSerialNWarpgroups) { + DG_STATIC_ASSERT(WG_BLOCK_M == L1WGMMA::M and WG_BLOCK_N == L1WGMMA::N, + "Serial-N path runs two M64N128 WGMMAs per K-block"); + } else { + DG_STATIC_ASSERT(WG_BLOCK_M == L1WGMMA::M, "Each warpgroup must run exactly one WGMMA per K-block"); + } + + // Sync with dispatch + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + const unsigned long long math_loop_start = phase_profile_clock(); + + 
scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + const uint32_t valid_m = scheduler.template get_valid_m(); + const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx; + const uint32_t m_idx = pool_block_idx * BLOCK_M; + const uint32_t wg_n_idx = kSplitNWarpgroups ? epilogue_wg_idx * WG_BLOCK_N : 0; + const uint32_t wg_l1_out_n_idx = kSplitNWarpgroups ? epilogue_wg_idx * WG_L1_OUT_BLOCK_N : 0; + const uint32_t n_idx = n_block_idx * BLOCK_N + wg_n_idx; + const uint32_t row_block_offset = kSplitNWarpgroups ? 0 : epilogue_wg_idx * WG_BLOCK_M; + const uint32_t row_offset_r0 = row_block_offset + r_0; + const uint32_t row_offset_r1 = row_block_offset + r_1; + const bool valid_r0 = row_offset_r0 < valid_m; + const bool valid_r1 = row_offset_r1 < valid_m; + + + if constexpr (kAsyncL1TMAStore) { + if (block_phase != sched::BlockPhase::Linear1) + drain_all_async_l1_stores(); + } + + if constexpr (kUseMMASync) { + using MMASyncTiled = cute::TiledMMA< + cute::MMA_Atom, + cute::Layout, cute::_4, cute::_1>>>; + auto mma_sync_tiled = MMASyncTiled{}; + auto thr_mma = mma_sync_tiled.get_thread_slice(epilogue_thread_idx); + auto sC = cute::make_tensor( + cute::make_smem_ptr(smem_accum_f32), + cute::make_layout(cute::make_shape(cute::Int{}, cute::Int<128>{}), cute::GenRowMajor{})); + auto tCsC = thr_mma.partition_C(sC); + auto tCrC = thr_mma.make_fragment_C(tCsC); + auto tCrFinal = thr_mma.make_fragment_C(tCsC); + cute::clear(tCrFinal); + auto cC = cute::make_identity_tensor(cute::shape(sC)); + auto tCcC = thr_mma.partition_C(cC); + + auto add_scaled_fragment = [&](const float& scale_b, const uint32_t& sfa_offset) { + CUTE_UNROLL + for (int i = 0; i < cute::size(tCrC); ++i) { + const auto coord = tCcC(i); + if (cute::elem_less(coord, cute::shape(sC))) { + const uint32_t row = 
static_cast(cute::get<0>(coord)); + const uint32_t col = static_cast(cute::get<1>(coord)); + const float scale_a = ptx::ld_shared(smem_sfa[stage_idx] + sfa_offset + row); + tCrFinal(i) += scale_a * scale_b * tCrC(i); + } + } + }; + + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + full_barriers[stage_idx]->wait(phase); + + constexpr uint32_t kL1SFKBlocks = kHidden / 128; + constexpr uint32_t kL2SFKBlocks = kIntermediateHidden / 128; + constexpr uint32_t kL1SFGateBlks = kIntermediateHidden / 128; + constexpr uint32_t kL1SFPerExpert = (kIntermediateHidden * 2 / 128) * kL1SFKBlocks; + constexpr uint32_t kL2SFPerExpert = (kHidden / 128) * kL2SFKBlocks; + + if (block_phase == sched::BlockPhase::Linear1) { + const uint32_t gate_n = (n_block_idx * BLOCK_N + wg_n_idx) / 256u; + const uint32_t up_n = kL1SFGateBlks + gate_n; + const float* base = l1_weights_sf + local_expert_idx * kL1SFPerExpert + k_block_idx; + const float gate_sf = __ldg(base + gate_n * kL1SFKBlocks); + const float up_sf = __ldg(base + up_n * kL1SFKBlocks); + + auto sA = cute::make_tensor( + cute::make_smem_ptr(smem_a[stage_idx]), + cute::make_layout(cute::make_shape(cute::Int{}, cute::Int<128>{}), cute::GenRowMajor{})); + auto sB = cute::make_tensor( + cute::make_smem_ptr(smem_b[stage_idx]), + cute::make_layout(cute::make_shape(cute::Int<128>{}, cute::Int<128>{}), cute::GenRowMajor{})); + cute::clear(tCrC); + cute::cooperative_gemm(epilogue_thread_idx, mma_sync_tiled, sA, sB, tCrC); + + CUTE_UNROLL + for (int i = 0; i < cute::size(tCrC); ++i) { + const auto coord = tCcC(i); + if (cute::elem_less(coord, cute::shape(sC))) { + const uint32_t row = static_cast(cute::get<0>(coord)); + const uint32_t col = static_cast(cute::get<1>(coord)); + const float scale_a = ptx::ld_shared(smem_sfa[stage_idx] + row); + const float scale_b = ((col / 8u) & 1u) ? 
up_sf : gate_sf; + tCrFinal(i) += scale_a * scale_b * tCrC(i); + } + } + } else { + const float l2_sf = __ldg(l2_weights_sf + local_expert_idx * kL2SFPerExpert + + ((n_block_idx * BLOCK_N + wg_n_idx) / 128u) * kL2SFKBlocks + k_block_idx); + CUTE_UNROLL + for (uint32_t half = 0; half < 2; ++half) { + const uint32_t k_off = half * (BLOCK_K / 2); + auto sA = cute::make_tensor( + cute::make_smem_ptr(smem_a[stage_idx] + k_off), + cute::make_layout(cute::make_shape(cute::Int{}, cute::Int<64>{}), + cute::make_stride(cute::Int<128>{}, cute::Int<1>{}))); + auto sB = cute::make_tensor( + cute::make_smem_ptr(smem_b[stage_idx] + wg_n_idx * BLOCK_K + k_off), + cute::make_layout(cute::make_shape(cute::Int<128>{}, cute::Int<64>{}), + cute::make_stride(cute::Int<128>{}, cute::Int<1>{}))); + cute::clear(tCrC); + cute::cooperative_gemm(epilogue_thread_idx, mma_sync_tiled, sA, sB, tCrC); + add_scaled_fragment(l2_sf, half * kL2SFAHalfStride); + } + } + + arrive_empty_barrier(stage_idx); + __syncwarp(); + } + + CUTE_UNROLL + for (int i = 0; i < cute::size(tCrFinal); ++i) { + const auto coord = tCcC(i); + if (cute::elem_less(coord, cute::shape(sC))) { + const uint32_t row = static_cast(cute::get<0>(coord)); + const uint32_t col = static_cast(cute::get<1>(coord)); + smem_accum_f32[row * BLOCK_N + col] = tCrFinal(i); + } + } + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + + constexpr uint32_t kMMASyncRowsPerPass = kNumEpilogueThreads / 8; + DG_STATIC_ASSERT(kMMASyncRowsPerPass == 16, "mma.sync epilogue maps 8 lanes per row"); + + if (block_phase == sched::BlockPhase::Linear1) { + #pragma unroll + for (uint32_t row_base = 0; row_base < BLOCK_M; row_base += kMMASyncRowsPerPass) { + const uint32_t row = row_base + epilogue_thread_idx / 8; + const uint32_t lane_in_row = epilogue_thread_idx % 8; + const bool valid_row = row < valid_m; + float swiglu_values[8]; + float amax = 0.0f; + + #pragma unroll + for (uint32_t p = 0; p < 8; ++p) { + const uint32_t gate_col = p * 16 + 
lane_in_row; + const uint32_t up_col = gate_col + 8; + float gate = valid_row ? smem_accum_f32[row * BLOCK_N + gate_col] : 0.0f; + float up = valid_row ? smem_accum_f32[row * BLOCK_N + up_col] : 0.0f; + if constexpr (kActivationClamp != cute::numeric_limits::infinity()) { + gate = cute::min(gate, kActivationClamp); + up = cute::min(cute::max(up, -kActivationClamp), kActivationClamp); + } + const float e = kFastMath ? __expf(-gate) : expf(-gate); + const float sig = kFastMath ? math::fast_rcp(1.0f + e) : 1.0f / (1.0f + e); + const float weight = valid_row ? *l1_topk_weights_buffer + .get_data_buffer(m_idx + row) + .get_base_ptr() : 0.0f; + const float value = gate * sig * up * weight; + swiglu_values[p] = value; + amax = cute::max(amax, cute::abs(value)); + } + + amax = math::warp_reduce<8, false>(amax, math::ReduceMax()); + float2 amax_pair = {amax, amax}; + float2 sf_pair, sf_inv_pair; + math::get_e4m3_sf_and_sf_inv(amax_pair, sf_pair, sf_inv_pair); + const float sf = sf_pair.x; + const float sf_inv = sf_inv_pair.x; + + if (valid_row) { + #pragma unroll + for (uint32_t p = 0; p < 8; ++p) { + const uint32_t out_col = p * 8 + lane_in_row; + smem_cd_l1[row * L1_OUT_BLOCK_N + out_col] = + cutlass::float_e4m3_t(swiglu_values[p] * sf_inv); + } + if (lane_in_row == 0) { + auto sf_base_ptr = l2_sf_buffer.get_base_ptr(); + const uint32_t token_idx = pool_block_idx * BLOCK_M + row; + const uint32_t k_sf_idx = n_block_idx; + sf_base_ptr[k_sf_idx * kNumPaddedSFPoolTokens + token_idx] = sf; + } + } + } + + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + const uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N; + cute::tma_store_fence(); + cute::SM90_TMA_STORE_2D::copy( + &tensor_map_l1_output, + smem_cd_l1, + out_n_idx, + m_idx); + cute::tma_store_arrive(); + } + __syncwarp(); + ptx::tma_store_wait<0>(); + + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + if (epilogue_warp_idx == 0 
and cute::elect_one_sync()) { + ptx::red_or_rel_gpu( + workspace.get_l2_arrival_mask_ptr(pool_block_idx), + 1ull << n_block_idx); + } + __syncwarp(); + } else { + #pragma unroll + for (uint32_t row_base = 0; row_base < BLOCK_M; row_base += kMMASyncRowsPerPass) { + const uint32_t row = row_base + epilogue_thread_idx / 8; + const uint32_t lane_in_row = epilogue_thread_idx % 8; + if (row < valid_m) { + const auto src_metadata = *workspace.get_token_src_metadata_ptr(m_idx + row); + const uint32_t dst_rank_idx = src_metadata.rank_idx; + const uint32_t dst_token_idx = src_metadata.token_idx; + const uint32_t dst_topk_idx = src_metadata.topk_idx; + const auto dst_token = combine_token_buffer.get_rank_buffer(dst_topk_idx) + .get_data_buffer(dst_token_idx); + #pragma unroll + for (uint32_t v = 0; v < 2; ++v) { + const uint32_t col = lane_in_row * 16 + v * 8; + uint4 packed; + float f0 = smem_accum_f32[row * BLOCK_N + col + 0]; + float f1 = smem_accum_f32[row * BLOCK_N + col + 1]; + float f2 = smem_accum_f32[row * BLOCK_N + col + 2]; + float f3 = smem_accum_f32[row * BLOCK_N + col + 3]; + float f4 = smem_accum_f32[row * BLOCK_N + col + 4]; + float f5 = smem_accum_f32[row * BLOCK_N + col + 5]; + float f6 = smem_accum_f32[row * BLOCK_N + col + 6]; + float f7 = smem_accum_f32[row * BLOCK_N + col + 7]; + packed.x = math::cast_into_bf16_and_pack(f0, f1); + packed.y = math::cast_into_bf16_and_pack(f2, f3); + packed.z = math::cast_into_bf16_and_pack(f4, f5); + packed.w = math::cast_into_bf16_and_pack(f6, f7); + auto dst_ptr = math::advance_ptr( + dst_token.get_base_ptr(), + n_idx * sizeof(nv_bfloat16) + col * sizeof(nv_bfloat16)); + *sym_buffer.map(dst_ptr, dst_rank_idx) = packed; + } + } + } + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + } + return; + } + + if constexpr (kSerialNWarpgroups) { + using WGMMA = L1WGMMA; + constexpr uint32_t kAccumPerThread = WGMMA::kNumAccum; + constexpr uint32_t kNumSerialN = 2; + float final_accum[kNumSerialN][kAccumPerThread] 
= {}; + float accum[kAccumPerThread]; + + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + full_barriers[stage_idx]->wait(phase); + + float scale_a_0_lo, scale_a_1_lo; + float scale_a_0_hi, scale_a_1_hi; + if (block_phase == sched::BlockPhase::Linear1) { + scale_a_0_lo = ptx::ld_shared(smem_sfa[stage_idx] + row_offset_r0); + scale_a_1_lo = ptx::ld_shared(smem_sfa[stage_idx] + row_offset_r1); + } else { + scale_a_0_lo = ptx::ld_shared(smem_sfa[stage_idx] + row_offset_r0); + scale_a_1_lo = ptx::ld_shared(smem_sfa[stage_idx] + row_offset_r1); + scale_a_0_hi = ptx::ld_shared(smem_sfa[stage_idx] + kL2SFAHalfStride + row_offset_r0); + scale_a_1_hi = ptx::ld_shared(smem_sfa[stage_idx] + kL2SFAHalfStride + row_offset_r1); + } + + constexpr uint32_t kL1SFKBlocks = kHidden / 128; + constexpr uint32_t kL2SFKBlocks = kIntermediateHidden / 128; + constexpr uint32_t kL1SFGateBlks = kIntermediateHidden / 128; + constexpr uint32_t kL1SFPerExpert = (kIntermediateHidden * 2 / 128) * kL1SFKBlocks; + constexpr uint32_t kL2SFPerExpert = (kHidden / 128) * kL2SFKBlocks; + + #pragma unroll + for (uint32_t serial_n_idx = 0; serial_n_idx < kNumSerialN; ++serial_n_idx) { + const uint32_t serial_wg_n_idx = serial_n_idx * WG_BLOCK_N; + float gate_sf = 0.0f, up_sf = 0.0f, l2_sf = 0.0f; + if (block_phase == sched::BlockPhase::Linear1) { + const uint32_t gate_n = (n_block_idx * BLOCK_N + serial_wg_n_idx) / 256u; + const uint32_t up_n = kL1SFGateBlks + gate_n; + const float* base = l1_weights_sf + local_expert_idx * kL1SFPerExpert + k_block_idx; + gate_sf = __ldg(base + gate_n * kL1SFKBlocks); + up_sf = __ldg(base + up_n * kL1SFKBlocks); + + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = mma::sm90::make_smem_desc( + smem_a[stage_idx] + row_block_offset * BLOCK_K + k * WGMMA::K, 1); + 
auto desc_b = mma::sm90::make_smem_desc( + smem_b[stage_idx] + serial_wg_n_idx * BLOCK_K + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + ptx::warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_wait<0>(); + + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread / 4; ++ i) { + const float sb = (i & 1u) ? up_sf : gate_sf; + final_accum[serial_n_idx][i*4+0] += scale_a_0_lo * sb * accum[i*4+0]; + final_accum[serial_n_idx][i*4+1] += scale_a_0_lo * sb * accum[i*4+1]; + final_accum[serial_n_idx][i*4+2] += scale_a_1_lo * sb * accum[i*4+2]; + final_accum[serial_n_idx][i*4+3] += scale_a_1_lo * sb * accum[i*4+3]; + } + } else { + l2_sf = __ldg(l2_weights_sf + local_expert_idx * kL2SFPerExpert + + ((n_block_idx * BLOCK_N + serial_wg_n_idx) / 128u) * kL2SFKBlocks + k_block_idx); + + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < (BLOCK_K / 2) / WGMMA::K; ++ k) { + auto desc_a = mma::sm90::make_smem_desc( + smem_a[stage_idx] + row_block_offset * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = mma::sm90::make_smem_desc( + smem_b[stage_idx] + serial_wg_n_idx * BLOCK_K + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + ptx::warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_wait<0>(); + + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread / 4; ++ i) { + final_accum[serial_n_idx][i*4+0] += scale_a_0_lo * l2_sf * accum[i*4+0]; + final_accum[serial_n_idx][i*4+1] += scale_a_0_lo * l2_sf * accum[i*4+1]; + final_accum[serial_n_idx][i*4+2] += scale_a_1_lo * l2_sf * accum[i*4+2]; + final_accum[serial_n_idx][i*4+3] += scale_a_1_lo * l2_sf * accum[i*4+3]; + } + + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread; ++ i) 
ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < (BLOCK_K / 2) / WGMMA::K; ++ k) { + const uint32_t k_off = (BLOCK_K / 2) + k * WGMMA::K; + auto desc_a = mma::sm90::make_smem_desc( + smem_a[stage_idx] + row_block_offset * BLOCK_K + k_off, 1); + auto desc_b = mma::sm90::make_smem_desc( + smem_b[stage_idx] + serial_wg_n_idx * BLOCK_K + k_off, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + ptx::warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_wait<0>(); + + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread / 4; ++ i) { + final_accum[serial_n_idx][i*4+0] += scale_a_0_hi * l2_sf * accum[i*4+0]; + final_accum[serial_n_idx][i*4+1] += scale_a_0_hi * l2_sf * accum[i*4+1]; + final_accum[serial_n_idx][i*4+2] += scale_a_1_hi * l2_sf * accum[i*4+2]; + final_accum[serial_n_idx][i*4+3] += scale_a_1_hi * l2_sf * accum[i*4+3]; + } + } + } + + arrive_empty_barrier(stage_idx); + __syncwarp(); + } + + if (row_block_offset >= valid_m) { + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + return; + } + + if (block_phase == sched::BlockPhase::Linear1) { + constexpr uint32_t kNumPairs = kAccumPerThread / 8; + #pragma unroll + for (uint32_t serial_n_idx = 0; serial_n_idx < kNumSerialN; ++serial_n_idx) { + const uint32_t serial_l1_out_n_idx = serial_n_idx * WG_L1_OUT_BLOCK_N; + float swiglu_r0[kNumPairs][2]; + float swiglu_r1[kNumPairs][2]; + float amax_r0 = 0.0f, amax_r1 = 0.0f; + + #pragma unroll + for (uint32_t p = 0; p < kNumPairs; ++ p) { + const uint32_t gate = 2 * p, up = 2 * p + 1; + auto clamp_gate = [](float& x) { + if constexpr (kActivationClamp != cute::numeric_limits::infinity()) + x = cute::min(x, kActivationClamp); + }; + auto clamp_up = [](float& x) { + if constexpr (kActivationClamp != cute::numeric_limits::infinity()) + x = cute::min(cute::max(x, -kActivationClamp), 
kActivationClamp); + }; + float g_r0_c0 = final_accum[serial_n_idx][gate*4 + 0]; clamp_gate(g_r0_c0); + float g_r0_c1 = final_accum[serial_n_idx][gate*4 + 1]; clamp_gate(g_r0_c1); + float g_r1_c0 = final_accum[serial_n_idx][gate*4 + 2]; clamp_gate(g_r1_c0); + float g_r1_c1 = final_accum[serial_n_idx][gate*4 + 3]; clamp_gate(g_r1_c1); + float u_r0_c0 = final_accum[serial_n_idx][up*4 + 0]; clamp_up(u_r0_c0); + float u_r0_c1 = final_accum[serial_n_idx][up*4 + 1]; clamp_up(u_r0_c1); + float u_r1_c0 = final_accum[serial_n_idx][up*4 + 2]; clamp_up(u_r1_c0); + float u_r1_c1 = final_accum[serial_n_idx][up*4 + 3]; clamp_up(u_r1_c1); + auto silu = [](float x) -> float { + const float e = kFastMath ? __expf(-x) : expf(-x); + const float sig = kFastMath ? math::fast_rcp(1.0f + e) : 1.0f / (1.0f + e); + return x * sig; + }; + if (valid_r0) { + swiglu_r0[p][0] = silu(g_r0_c0) * u_r0_c0; + swiglu_r0[p][1] = silu(g_r0_c1) * u_r0_c1; + amax_r0 = cute::max(amax_r0, cute::max(cute::abs(swiglu_r0[p][0]), cute::abs(swiglu_r0[p][1]))); + } else { + swiglu_r0[p][0] = 0.0f; + swiglu_r0[p][1] = 0.0f; + } + if (valid_r1) { + swiglu_r1[p][0] = silu(g_r1_c0) * u_r1_c0; + swiglu_r1[p][1] = silu(g_r1_c1) * u_r1_c1; + amax_r1 = cute::max(amax_r1, cute::max(cute::abs(swiglu_r1[p][0]), cute::abs(swiglu_r1[p][1]))); + } else { + swiglu_r1[p][0] = 0.0f; + swiglu_r1[p][1] = 0.0f; + } + } + + const float weight_r0 = [&]() { + if constexpr (kNumMaxTokensPerRank <= 1024) { + float weight = 0.0f; + if (col_idx == 0) + weight = valid_r0 ? *l1_topk_weights_buffer + .get_data_buffer(m_idx + row_offset_r0) + .get_base_ptr() : 0.0f; + return __shfl_sync(0xffffffff, weight, static_cast(lane_idx - col_idx)); + } else { + return valid_r0 ? *l1_topk_weights_buffer + .get_data_buffer(m_idx + row_offset_r0) + .get_base_ptr() : 0.0f; + } + }(); + const float weight_r1 = [&]() { + if constexpr (kNumMaxTokensPerRank <= 1024) { + float weight = 0.0f; + if (col_idx == 0) + weight = valid_r1 ? 
*l1_topk_weights_buffer + .get_data_buffer(m_idx + row_offset_r1) + .get_base_ptr() : 0.0f; + return __shfl_sync(0xffffffff, weight, static_cast(lane_idx - col_idx)); + } else { + return valid_r1 ? *l1_topk_weights_buffer + .get_data_buffer(m_idx + row_offset_r1) + .get_base_ptr() : 0.0f; + } + }(); + #pragma unroll + for (uint32_t p = 0; p < kNumPairs; ++ p) { + swiglu_r0[p][0] *= weight_r0; + swiglu_r0[p][1] *= weight_r0; + swiglu_r1[p][0] *= weight_r1; + swiglu_r1[p][1] *= weight_r1; + } + amax_r0 *= cute::abs(weight_r0); + amax_r1 *= cute::abs(weight_r1); + amax_r0 = math::warp_reduce<4, false>(amax_r0, math::ReduceMax()); + amax_r1 = math::warp_reduce<4, false>(amax_r1, math::ReduceMax()); + + float sf_r0, sf_inv_r0, sf_r1, sf_inv_r1; + { + float2 amax_pair = {amax_r0, amax_r1}; + float2 sf_pair, sf_inv_pair; + math::get_e4m3_sf_and_sf_inv(amax_pair, sf_pair, sf_inv_pair); + sf_r0 = sf_pair.x; sf_inv_r0 = sf_inv_pair.x; + sf_r1 = sf_pair.y; sf_inv_r1 = sf_inv_pair.y; + } + + #pragma unroll + for (uint32_t p = 0; p < kNumPairs; ++ p) { + const float v00 = swiglu_r0[p][0] * sf_inv_r0; + const float v01 = swiglu_r0[p][1] * sf_inv_r0; + const float v10 = swiglu_r1[p][0] * sf_inv_r1; + const float v11 = swiglu_r1[p][1] * sf_inv_r1; + const __nv_fp8x2_e4m3 r0_pair(make_float2(v00, v01)); + const __nv_fp8x2_e4m3 r1_pair(make_float2(v10, v11)); + const uint32_t col = p * 8 + col_idx * 2; + auto* p0 = reinterpret_cast( + smem_cd_l1 + r_0 * L1_OUT_BLOCK_N + serial_l1_out_n_idx + col); + auto* p1 = reinterpret_cast( + smem_cd_l1 + r_1 * L1_OUT_BLOCK_N + serial_l1_out_n_idx + col); + if (valid_r0) + *p0 = r0_pair.__x; + if (valid_r1) + *p1 = r1_pair.__x; + } + + if (col_idx == 0) { + auto sf_base_ptr = l2_sf_buffer.get_base_ptr(); + const uint32_t token_r0 = pool_block_idx * BLOCK_M + row_offset_r0; + const uint32_t token_r1 = pool_block_idx * BLOCK_M + row_offset_r1; + const uint32_t k_sf_idx = (n_block_idx * L1_OUT_BLOCK_N + serial_l1_out_n_idx) / 64u; + if (valid_r0) + 
sf_base_ptr[k_sf_idx * kNumPaddedSFPoolTokens + token_r0] = sf_r0; + if (valid_r1) + sf_base_ptr[k_sf_idx * kNumPaddedSFPoolTokens + token_r1] = sf_r1; + } + } + + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + if (warp_idx_in_wg == 0 and cute::elect_one_sync()) { + const uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N; + cute::tma_store_fence(); + cute::SM90_TMA_STORE_2D::copy( + &tensor_map_l1_output, + smem_cd_l1, + out_n_idx, + m_idx + row_block_offset); + cute::tma_store_arrive(); + } + __syncwarp(); + ptx::tma_store_wait<0>(); + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + notify_l1_ready(pool_block_idx, n_block_idx); + } else { + constexpr uint32_t kNumRowsPerWarp = WG_BLOCK_M / 8; + #pragma unroll + for (uint32_t serial_n_idx = 0; serial_n_idx < kNumSerialN; ++serial_n_idx) { + const uint32_t serial_n_idx_base = n_block_idx * BLOCK_N + serial_n_idx * WG_BLOCK_N; + + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread / 8; ++ i) { + const uint32_t chunk_lo = 2 * i, chunk_hi = 2 * i + 1; + auto write_pair = [&](uint32_t row, uint32_t col, uint32_t packed) { + auto smem_ptr = smem_cd_l2 + row * WG_BLOCK_N + col; + *reinterpret_cast(smem_ptr) = packed; + }; + if (valid_r0) { + const uint32_t r0_lo = math::cast_into_bf16_and_pack( + final_accum[serial_n_idx][chunk_lo*4 + 0], final_accum[serial_n_idx][chunk_lo*4 + 1]); + const uint32_t r0_hi = math::cast_into_bf16_and_pack( + final_accum[serial_n_idx][chunk_hi*4 + 0], final_accum[serial_n_idx][chunk_hi*4 + 1]); + write_pair(r_0, chunk_lo * 8 + col_idx * 2, r0_lo); + write_pair(r_0, chunk_hi * 8 + col_idx * 2, r0_hi); + } + if (valid_r1) { + const uint32_t r1_lo = math::cast_into_bf16_and_pack( + final_accum[serial_n_idx][chunk_lo*4 + 2], final_accum[serial_n_idx][chunk_lo*4 + 3]); + const uint32_t r1_hi = math::cast_into_bf16_and_pack( + final_accum[serial_n_idx][chunk_hi*4 + 2], final_accum[serial_n_idx][chunk_hi*4 + 3]); + write_pair(r_1, chunk_lo * 8 + 
col_idx * 2, r1_lo); + write_pair(r_1, chunk_hi * 8 + col_idx * 2, r1_hi); + } + } + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + const uint32_t row_in_warp_block = lane_idx / 16; + const uint32_t lane_in_row = lane_idx % 16; + constexpr uint32_t cols_per_lane = WG_BLOCK_N / 16; + #pragma unroll + for (uint32_t j = 0; j < kNumRowsPerWarp; ++ j) { + const uint32_t row_in_wg = warp_idx_in_wg * 16 + j * 2 + row_in_warp_block; + const uint32_t m_idx_in_block = row_block_offset + row_in_wg; + if (m_idx_in_block >= valid_m) break; + + const auto src_metadata = *workspace.get_token_src_metadata_ptr(m_idx + m_idx_in_block); + const uint32_t dst_rank_idx = src_metadata.rank_idx; + const uint32_t dst_token_idx = src_metadata.token_idx; + const uint32_t dst_topk_idx = src_metadata.topk_idx; + auto smem_ptr = smem_cd_l2 + row_in_wg * WG_BLOCK_N + lane_in_row * cols_per_lane; + const auto dst_token = combine_token_buffer.get_rank_buffer(dst_topk_idx) + .get_data_buffer(dst_token_idx); + const auto packed = *reinterpret_cast(smem_ptr); + auto dst_ptr = math::advance_ptr( + dst_token.get_base_ptr(), + serial_n_idx_base * sizeof(nv_bfloat16) + lane_in_row * sizeof(uint4)); + *sym_buffer.map(dst_ptr, dst_rank_idx) = packed; + } + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + } + } + return; + } + + // ---------------- GEMM ---------------- + using WGMMA = L1WGMMA; + constexpr uint32_t kAccumPerThread = WGMMA::kNumAccum; // 64 for M=64,N=128 + float final_accum[kAccumPerThread] = {}; + float accum[kAccumPerThread]; + + const unsigned long long block_gemm_start = phase_profile_clock(); + const auto run_default_gemm_loop = [&]() { +for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + full_barriers[stage_idx]->wait(phase); + + // Read SF (must precede warpgroup_arrive) + float scale_a_0_lo, scale_a_1_lo; + float scale_a_0_hi, scale_a_1_hi; // Only used in L2 (per-64 K) + if (block_phase == 
sched::BlockPhase::Linear1) { + scale_a_0_lo = ptx::ld_shared(smem_sfa[stage_idx] + row_offset_r0); + scale_a_1_lo = ptx::ld_shared(smem_sfa[stage_idx] + row_offset_r1); + } else { + // L2: SFA layout is (K=2, M=BLOCK_M) MN-major; first half SF at offset 0, second at BLOCK_M + scale_a_0_lo = ptx::ld_shared(smem_sfa[stage_idx] + row_offset_r0); + scale_a_1_lo = ptx::ld_shared(smem_sfa[stage_idx] + row_offset_r1); + scale_a_0_hi = ptx::ld_shared(smem_sfa[stage_idx] + kL2SFAHalfStride + row_offset_r0); + scale_a_1_hi = ptx::ld_shared(smem_sfa[stage_idx] + kL2SFAHalfStride + row_offset_r1); + } + + // ----- Block (128, 128) weight SF (loaded directly from global) ----- + // L1 weight SF shape: (E, 2*IH/128, H/128) MN-major. The N axis is + // [gate(IH/128), up(IH/128)]; with the gate/up gran-8 interleave on + // the FP8 weight, each BLOCK_N=128 tile covers 64 rows of gate plus + // 64 rows of up taken from the same original 128-row block, so: + // gate_sf_n = n_block_idx / 2 + // up_sf_n = (IH/128) + n_block_idx / 2 + // + // L2 weight SF shape: (E, H/128, IH/128) MN-major. One scalar per + // (BLOCK_N, BLOCK_K) tile, broadcast across all WGMMA accumulators. + // + // NOTE: we tried hoisting these LDGs above the barrier wait and/or + // having only lane 0 load + shfl-broadcast. Both regressed on H20 + // by 7-11% across all batch sizes, presumably because (a) Hopper's + // L1 read-only cache already coalesces same-address LDGs from all + // 128 WG threads and (b) hoisting contended with the dispatch + // warps' NVLink LDGs on the MIO unit. Keep the simple parallel + // post-wait load. 
+ constexpr uint32_t kL1SFKBlocks = kHidden / 128; + constexpr uint32_t kL2SFKBlocks = kIntermediateHidden / 128; + constexpr uint32_t kL1SFGateBlks = kIntermediateHidden / 128; + constexpr uint32_t kL1SFPerExpert = (kIntermediateHidden * 2 / 128) * kL1SFKBlocks; + constexpr uint32_t kL2SFPerExpert = (kHidden / 128) * kL2SFKBlocks; + float gate_sf = 0.0f, up_sf = 0.0f, l2_sf_lo = 0.0f, l2_sf_hi = 0.0f; + if (block_phase == sched::BlockPhase::Linear1) { + const uint32_t gate_n = (n_block_idx * BLOCK_N + wg_n_idx) / 256u; + const uint32_t up_n = kL1SFGateBlks + gate_n; + const float* base = l1_weights_sf + local_expert_idx * kL1SFPerExpert + k_block_idx; + gate_sf = __ldg(base + gate_n * kL1SFKBlocks); + up_sf = __ldg(base + up_n * kL1SFKBlocks); + } else { + const float* base = l2_weights_sf + local_expert_idx * kL2SFPerExpert + k_block_idx; + const uint32_t sf_n = (n_block_idx * BLOCK_N + wg_n_idx) / 128u; + l2_sf_lo = __ldg(base + sf_n * kL2SFKBlocks); + if constexpr (WG_BLOCK_N > 128) + l2_sf_hi = __ldg(base + (sf_n + 1u) * kL2SFKBlocks); + else + l2_sf_hi = l2_sf_lo; + } + + if (block_phase == sched::BlockPhase::Linear1) { + // Single per-128 K-block WGMMA group + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = mma::sm90::make_smem_desc( + smem_a[stage_idx] + row_block_offset * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = mma::sm90::make_smem_desc( + smem_b[stage_idx] + wg_n_idx * BLOCK_K + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + ptx::warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_wait<0>(); + + arrive_empty_barrier(stage_idx); + + // L1: gate/up alternate at gran=8 along N; each `i` block of 8 + // cols belongs entirely to one of {gate, up}, so .x and .y + // share 
the same scalar. + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread / 4; ++ i) { + const float sb = (i & 1u) ? up_sf : gate_sf; + final_accum[i*4+0] += scale_a_0_lo * sb * accum[i*4+0]; + final_accum[i*4+1] += scale_a_0_lo * sb * accum[i*4+1]; + final_accum[i*4+2] += scale_a_1_lo * sb * accum[i*4+2]; + final_accum[i*4+3] += scale_a_1_lo * sb * accum[i*4+3]; + } + } else { + if constexpr (kL2DualAccum) { + float accum_hi[kAccumPerThread]; + + const auto desc_a_lo0 = mma::sm90::make_smem_desc( + smem_a[stage_idx] + row_block_offset * BLOCK_K, 1); + const auto desc_b_lo0 = mma::sm90::make_smem_desc( + smem_b[stage_idx] + wg_n_idx * BLOCK_K, 1); + const auto desc_a_lo1 = mma::sm90::make_smem_desc( + smem_a[stage_idx] + row_block_offset * BLOCK_K + WGMMA::K, 1); + const auto desc_b_lo1 = mma::sm90::make_smem_desc( + smem_b[stage_idx] + wg_n_idx * BLOCK_K + WGMMA::K, 1); + const auto desc_a_hi0 = mma::sm90::make_smem_desc( + smem_a[stage_idx] + row_block_offset * BLOCK_K + BLOCK_K / 2, 1); + const auto desc_b_hi0 = mma::sm90::make_smem_desc( + smem_b[stage_idx] + wg_n_idx * BLOCK_K + BLOCK_K / 2, 1); + const auto desc_a_hi1 = mma::sm90::make_smem_desc( + smem_a[stage_idx] + row_block_offset * BLOCK_K + BLOCK_K / 2 + WGMMA::K, 1); + const auto desc_b_hi1 = mma::sm90::make_smem_desc( + smem_b[stage_idx] + wg_n_idx * BLOCK_K + BLOCK_K / 2 + WGMMA::K, 1); + + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread; ++ i) { + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_fence_operand(accum_hi[i]); + } + ptx::warpgroup_arrive(); + WGMMA::wgmma(desc_a_lo0, desc_b_lo0, accum, false); + WGMMA::wgmma(desc_a_lo1, desc_b_lo1, accum, true); + WGMMA::wgmma(desc_a_hi0, desc_b_hi0, accum_hi, false); + WGMMA::wgmma(desc_a_hi1, desc_b_hi1, accum_hi, true); + ptx::warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread; ++ i) { + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_fence_operand(accum_hi[i]); + } + 
ptx::warpgroup_wait<0>(); + + arrive_empty_barrier(stage_idx); + + if constexpr (WG_BLOCK_N == 128) { + const float scale_0_lo = scale_a_0_lo * l2_sf_lo; + const float scale_1_lo = scale_a_1_lo * l2_sf_lo; + const float scale_0_hi = scale_a_0_hi * l2_sf_lo; + const float scale_1_hi = scale_a_1_hi * l2_sf_lo; + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread / 4; ++ i) { + final_accum[i*4+0] += scale_0_lo * accum[i*4+0]; + final_accum[i*4+1] += scale_0_lo * accum[i*4+1]; + final_accum[i*4+2] += scale_1_lo * accum[i*4+2]; + final_accum[i*4+3] += scale_1_lo * accum[i*4+3]; + final_accum[i*4+0] += scale_0_hi * accum_hi[i*4+0]; + final_accum[i*4+1] += scale_0_hi * accum_hi[i*4+1]; + final_accum[i*4+2] += scale_1_hi * accum_hi[i*4+2]; + final_accum[i*4+3] += scale_1_hi * accum_hi[i*4+3]; + } + } else { + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread / 4; ++ i) { + const float l2_sf = (i < 16u) ? l2_sf_lo : l2_sf_hi; + final_accum[i*4+0] += scale_a_0_lo * l2_sf * accum[i*4+0]; + final_accum[i*4+1] += scale_a_0_lo * l2_sf * accum[i*4+1]; + final_accum[i*4+2] += scale_a_1_lo * l2_sf * accum[i*4+2]; + final_accum[i*4+3] += scale_a_1_lo * l2_sf * accum[i*4+3]; + final_accum[i*4+0] += scale_a_0_hi * l2_sf * accum_hi[i*4+0]; + final_accum[i*4+1] += scale_a_0_hi * l2_sf * accum_hi[i*4+1]; + final_accum[i*4+2] += scale_a_1_hi * l2_sf * accum_hi[i*4+2]; + final_accum[i*4+3] += scale_a_1_hi * l2_sf * accum_hi[i*4+3]; + } + } + } else { + // L2: split BLOCK_K=128 into two halves (per-64 SFA), each 2 WGMMAs. 
+ // First half: K=0..63, SFA = scale_a_*_lo + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < (BLOCK_K / 2) / WGMMA::K; ++ k) { + auto desc_a = mma::sm90::make_smem_desc( + smem_a[stage_idx] + row_block_offset * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = mma::sm90::make_smem_desc( + smem_b[stage_idx] + wg_n_idx * BLOCK_K + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + ptx::warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_wait<0>(); + + // L2 weight SF is per 128 output columns; M64N256 spans two SF groups. + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread / 4; ++ i) { + const float l2_sf = (i < 16u) ? l2_sf_lo : l2_sf_hi; + final_accum[i*4+0] += scale_a_0_lo * l2_sf * accum[i*4+0]; + final_accum[i*4+1] += scale_a_0_lo * l2_sf * accum[i*4+1]; + final_accum[i*4+2] += scale_a_1_lo * l2_sf * accum[i*4+2]; + final_accum[i*4+3] += scale_a_1_lo * l2_sf * accum[i*4+3]; + } + + // Second half: K=64..127, SFA = scale_a_*_hi + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < (BLOCK_K / 2) / WGMMA::K; ++ k) { + const uint32_t k_off = (BLOCK_K / 2) + k * WGMMA::K; + auto desc_a = mma::sm90::make_smem_desc( + smem_a[stage_idx] + row_block_offset * BLOCK_K + k_off, 1); + auto desc_b = mma::sm90::make_smem_desc( + smem_b[stage_idx] + wg_n_idx * BLOCK_K + k_off, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + ptx::warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_wait<0>(); + + arrive_empty_barrier(stage_idx); + + // L2 second half: same SFA half, still choose weight SF by N chunk. 
+ #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread / 4; ++ i) { + const float l2_sf = (i < 16u) ? l2_sf_lo : l2_sf_hi; + final_accum[i*4+0] += scale_a_0_hi * l2_sf * accum[i*4+0]; + final_accum[i*4+1] += scale_a_0_hi * l2_sf * accum[i*4+1]; + final_accum[i*4+2] += scale_a_1_hi * l2_sf * accum[i*4+2]; + final_accum[i*4+3] += scale_a_1_hi * l2_sf * accum[i*4+3]; + } + } + } + } + }; + + const auto run_l1_dual_k_gemm_loop = [&]() { + DG_STATIC_ASSERT((kHidden / BLOCK_K) % 2 == 0, "L1 dual-K expects an even number of K blocks"); + constexpr uint32_t kL1SFKBlocks = kHidden / 128; + constexpr uint32_t kL1SFGateBlks = kIntermediateHidden / 128; + constexpr uint32_t kL1SFPerExpert = (kIntermediateHidden * 2 / 128) * kL1SFKBlocks; + const uint32_t gate_n = (n_block_idx * BLOCK_N + wg_n_idx) / 256u; + const uint32_t up_n = kL1SFGateBlks + gate_n; + const float* expert_sf_base = l1_weights_sf + local_expert_idx * kL1SFPerExpert; + const float* gate_sf_base = expert_sf_base + gate_n * kL1SFKBlocks; + const float* up_sf_base = expert_sf_base + up_n * kL1SFKBlocks; + float accum_b[kAccumPerThread]; + + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks;) { + const uint32_t stage0 = stage_idx; + const uint32_t phase0 = phase; + const uint32_t k0 = k_block_idx; + full_barriers[stage0]->wait(phase0); + + const float scale_a0_r0 = ptx::ld_shared(smem_sfa[stage0] + row_offset_r0); + const float scale_a0_r1 = ptx::ld_shared(smem_sfa[stage0] + row_offset_r1); + const float2 gate_sf_pair = __ldg(reinterpret_cast(gate_sf_base + k0)); + const float2 up_sf_pair = __ldg(reinterpret_cast(up_sf_base + k0)); + const float gate_sf0 = gate_sf_pair.x; + const float up_sf0 = up_sf_pair.x; + + advance_pipeline(k_block_idx); + const uint32_t stage1 = stage_idx; + const uint32_t phase1 = phase; + full_barriers[stage1]->wait(phase1); + + const float scale_a1_r0 = ptx::ld_shared(smem_sfa[stage1] + row_offset_r0); + const float scale_a1_r1 = ptx::ld_shared(smem_sfa[stage1] + 
row_offset_r1); + const float gate_sf1 = gate_sf_pair.y; + const float up_sf1 = up_sf_pair.y; + + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread; ++ i) { + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_fence_operand(accum_b[i]); + } + ptx::warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = mma::sm90::make_smem_desc( + smem_a[stage0] + row_block_offset * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = mma::sm90::make_smem_desc( + smem_b[stage0] + wg_n_idx * BLOCK_K + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = mma::sm90::make_smem_desc( + smem_a[stage1] + row_block_offset * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = mma::sm90::make_smem_desc( + smem_b[stage1] + wg_n_idx * BLOCK_K + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum_b, k); + } + ptx::warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread; ++ i) { + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_fence_operand(accum_b[i]); + } + ptx::warpgroup_wait<0>(); + + arrive_empty_barrier(stage0); + arrive_empty_barrier(stage1); + + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread / 4; ++ i) { + const float sb0 = (i & 1u) ? up_sf0 : gate_sf0; + const float sb1 = (i & 1u) ? 
up_sf1 : gate_sf1; + final_accum[i*4+0] += scale_a0_r0 * sb0 * accum[i*4+0]; + final_accum[i*4+1] += scale_a0_r0 * sb0 * accum[i*4+1]; + final_accum[i*4+2] += scale_a0_r1 * sb0 * accum[i*4+2]; + final_accum[i*4+3] += scale_a0_r1 * sb0 * accum[i*4+3]; + final_accum[i*4+0] += scale_a1_r0 * sb1 * accum_b[i*4+0]; + final_accum[i*4+1] += scale_a1_r0 * sb1 * accum_b[i*4+1]; + final_accum[i*4+2] += scale_a1_r1 * sb1 * accum_b[i*4+2]; + final_accum[i*4+3] += scale_a1_r1 * sb1 * accum_b[i*4+3]; + } + + advance_pipeline(k_block_idx); + } + }; + + if constexpr (kL1DualKAccum) { + if (block_phase == sched::BlockPhase::Linear1) + run_l1_dual_k_gemm_loop(); + else + run_default_gemm_loop(); + } else { + run_default_gemm_loop(); + } + + const unsigned long long block_gemm_end = phase_profile_clock(); + if (epilogue_warp_idx == 0 and lane_idx == 0) + phase_profile_record(kProfileGemmCore, block_gemm_end - block_gemm_start); + + // Skip epilogue when block is past valid M (still must release via empty). + // A dummy cluster peer may still carry an async L1 store from the + // previous valid block, so drain it before leaving the L1 wave. + if (row_block_offset >= valid_m) { + if constexpr (kAsyncL1TMAStore) { + if (block_phase == sched::BlockPhase::Linear1) + drain_all_async_l1_stores(); + } + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + return; + } + + const unsigned long long block_epilogue_start = phase_profile_clock(); + if (block_phase == sched::BlockPhase::Linear1) { + // ---------------- L1 EPILOGUE: SwiGLU + FP8 quantize + TMA store ---------------- + // Layout in `final_accum`: + // 16 chunks of 8 N-cols, each chunk = 4 floats per thread = (r0c0, r0c1, r1c0, r1c1). + // Gate chunks: even (0, 2, ..., 14). Up chunks: odd (1, 3, ..., 15). + // Pair `p` ∈ [0, 8): gate chunk = 2p, up chunk = 2p+1. + // + // For each pair we produce 4 post-SwiGLU floats per thread, mapped to + // output cols (p*8 + col_idx*2 + {0,1}) for both r0 and r1. 
+ + constexpr uint32_t kNumPairs = kAccumPerThread / 8; + constexpr uint32_t kNumSFGroups = WG_L1_OUT_BLOCK_N / 64; + DG_STATIC_ASSERT(WG_L1_OUT_BLOCK_N % 64 == 0, "L1 output SF is per 64 columns"); + float swiglu_r0[kNumPairs][2]; + float swiglu_r1[kNumPairs][2]; + + // Per-row amax, one scale for each 64-col L1 output group. + float amax_r0[kNumSFGroups] = {}; + float amax_r1[kNumSFGroups] = {}; + + // Compute SwiGLU + per-group amax. + #pragma unroll + for (uint32_t p = 0; p < kNumPairs; ++ p) { + const uint32_t gate = 2 * p, up = 2 * p + 1; + const uint32_t sf_group = p / 8; + + auto clamp_gate = [](float& x) { + if constexpr (kActivationClamp != cute::numeric_limits::infinity()) + x = cute::min(x, kActivationClamp); + }; + auto clamp_up = [](float& x) { + if constexpr (kActivationClamp != cute::numeric_limits::infinity()) + x = cute::min(cute::max(x, -kActivationClamp), kActivationClamp); + }; + float g_r0_c0 = final_accum[gate*4 + 0]; clamp_gate(g_r0_c0); + float g_r0_c1 = final_accum[gate*4 + 1]; clamp_gate(g_r0_c1); + float g_r1_c0 = final_accum[gate*4 + 2]; clamp_gate(g_r1_c0); + float g_r1_c1 = final_accum[gate*4 + 3]; clamp_gate(g_r1_c1); + float u_r0_c0 = final_accum[up*4 + 0]; clamp_up(u_r0_c0); + float u_r0_c1 = final_accum[up*4 + 1]; clamp_up(u_r0_c1); + float u_r1_c0 = final_accum[up*4 + 2]; clamp_up(u_r1_c0); + float u_r1_c1 = final_accum[up*4 + 3]; clamp_up(u_r1_c1); + + auto silu = [](float x) -> float { + const float e = kFastMath ? __expf(-x) : expf(-x); + const float sig = kFastMath ? 
math::fast_rcp(1.0f + e) : 1.0f / (1.0f + e); + return x * sig; + }; + + if (valid_r0) { + swiglu_r0[p][0] = silu(g_r0_c0) * u_r0_c0; + swiglu_r0[p][1] = silu(g_r0_c1) * u_r0_c1; + amax_r0[sf_group] = cute::max( + amax_r0[sf_group], + cute::max(cute::abs(swiglu_r0[p][0]), cute::abs(swiglu_r0[p][1]))); + } else { + swiglu_r0[p][0] = 0.0f; + swiglu_r0[p][1] = 0.0f; + } + if (valid_r1) { + swiglu_r1[p][0] = silu(g_r1_c0) * u_r1_c0; + swiglu_r1[p][1] = silu(g_r1_c1) * u_r1_c1; + amax_r1[sf_group] = cute::max( + amax_r1[sf_group], + cute::max(cute::abs(swiglu_r1[p][0]), cute::abs(swiglu_r1[p][1]))); + } else { + swiglu_r1[p][0] = 0.0f; + swiglu_r1[p][1] = 0.0f; + } + } + + + float weight_r0 = 0.0f, weight_r1 = 0.0f; + if constexpr (kNumMaxTokensPerRank <= 1024) { + const int topk_weight_src_lane = static_cast(lane_idx - col_idx); + if (col_idx == 0) { + weight_r0 = valid_r0 ? *l1_topk_weights_buffer + .get_data_buffer(m_idx + row_offset_r0) + .get_base_ptr() : 0.0f; + weight_r1 = valid_r1 ? *l1_topk_weights_buffer + .get_data_buffer(m_idx + row_offset_r1) + .get_base_ptr() : 0.0f; + } + weight_r0 = __shfl_sync(0xffffffff, weight_r0, topk_weight_src_lane); + weight_r1 = __shfl_sync(0xffffffff, weight_r1, topk_weight_src_lane); + } else { + weight_r0 = valid_r0 ? *l1_topk_weights_buffer + .get_data_buffer(m_idx + row_offset_r0) + .get_base_ptr() : 0.0f; + weight_r1 = valid_r1 ? 
*l1_topk_weights_buffer + .get_data_buffer(m_idx + row_offset_r1) + .get_base_ptr() : 0.0f; + } + #pragma unroll + for (uint32_t p = 0; p < kNumPairs; ++ p) { + swiglu_r0[p][0] *= weight_r0; + swiglu_r0[p][1] *= weight_r0; + swiglu_r1[p][0] *= weight_r1; + swiglu_r1[p][1] *= weight_r1; + } + #pragma unroll + for (uint32_t g = 0; g < kNumSFGroups; ++ g) { + amax_r0[g] *= cute::abs(weight_r0); + amax_r1[g] *= cute::abs(weight_r1); + } + #pragma unroll + for (uint32_t g = 0; g < kNumSFGroups; ++ g) { + amax_r0[g] = math::warp_reduce<4, false>(amax_r0[g], math::ReduceMax()); + amax_r1[g] = math::warp_reduce<4, false>(amax_r1[g], math::ReduceMax()); + } + + float sf_r0[kNumSFGroups], sf_inv_r0[kNumSFGroups]; + float sf_r1[kNumSFGroups], sf_inv_r1[kNumSFGroups]; + #pragma unroll + for (uint32_t g = 0; g < kNumSFGroups; ++ g) { + float2 amax_pair = {amax_r0[g], amax_r1[g]}; + float2 sf_pair, sf_inv_pair; + math::get_e4m3_sf_and_sf_inv(amax_pair, sf_pair, sf_inv_pair); + sf_r0[g] = sf_pair.x; sf_inv_r0[g] = sf_inv_pair.x; + sf_r1[g] = sf_pair.y; sf_inv_r1[g] = sf_inv_pair.y; + } + + // Quantize and write to smem_cd_l1 (row-major, no swizzle). + const uint32_t l1_store_stage = kAsyncL1TMAStore ? async_l1_store_stage : 0u; + if constexpr (kAsyncL1TMAStore) + drain_async_l1_store_stage(l1_store_stage); + auto* smem_cd_l1_wg = smem_cd_l1 + + l1_store_stage * kNumEpilogueWarpgroups * WG_BLOCK_M * L1_OUT_BLOCK_N + + (kSplitNWarpgroups ? 
0 : epilogue_wg_idx * WG_BLOCK_M * L1_OUT_BLOCK_N); + #pragma unroll + for (uint32_t p = 0; p < kNumPairs; ++ p) { + const uint32_t sf_group = p / 8; + const float v00 = swiglu_r0[p][0] * sf_inv_r0[sf_group]; + const float v01 = swiglu_r0[p][1] * sf_inv_r0[sf_group]; + const float v10 = swiglu_r1[p][0] * sf_inv_r1[sf_group]; + const float v11 = swiglu_r1[p][1] * sf_inv_r1[sf_group]; + + const __nv_fp8x2_e4m3 r0_pair(make_float2(v00, v01)); + const __nv_fp8x2_e4m3 r1_pair(make_float2(v10, v11)); + + const uint32_t col = p * 8 + col_idx * 2; + auto* p0 = reinterpret_cast( + smem_cd_l1_wg + r_0 * L1_OUT_BLOCK_N + wg_l1_out_n_idx + col); + auto* p1 = reinterpret_cast( + smem_cd_l1_wg + r_1 * L1_OUT_BLOCK_N + wg_l1_out_n_idx + col); + if (valid_r0) + *p0 = r0_pair.__x; + if (valid_r1) + *p1 = r1_pair.__x; + } + + // Write L2-activation SF as float, one value per 64 output columns. + if (col_idx == 0) { + auto sf_base_ptr = l2_sf_buffer.get_base_ptr(); + const uint32_t token_r0 = pool_block_idx * BLOCK_M + row_offset_r0; + const uint32_t token_r1 = pool_block_idx * BLOCK_M + row_offset_r1; + const uint32_t base_k_sf_idx = (n_block_idx * L1_OUT_BLOCK_N + wg_l1_out_n_idx) / 64u; + #pragma unroll + for (uint32_t g = 0; g < kNumSFGroups; ++ g) { + if (valid_r0) + sf_base_ptr[(base_k_sf_idx + g) * kNumPaddedSFPoolTokens + token_r0] = sf_r0[g]; + if (valid_r1) + sf_base_ptr[(base_k_sf_idx + g) * kNumPaddedSFPoolTokens + token_r1] = sf_r1[g]; + } + } + + // Issue TMA store of the entire tile. Padding rows beyond + // `valid_m` are written with stale/garbage FP8 to the L1-output + // pool buffer, but they are never consumed downstream: the L2 + // GEMM tile loads them, but its NVLink-scatter epilogue is + // gated by `m_idx_in_block >= valid_m`, and stale SF in the + // padding rows can produce NaN accumulators that simply stay + // in registers (only valid rows are converted to BF16 and + // STSM'd into smem). 
Using TMA for partial tiles is a large + // win for low-batch / decode where every tile is partial. + if constexpr (kSplitNWarpgroups) { + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + const uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N; + cute::tma_store_fence(); + cute::SM90_TMA_STORE_2D::copy( + &tensor_map_l1_output, + smem_cd_l1, + out_n_idx, + m_idx); + cute::tma_store_arrive(); + } + __syncwarp(); + ptx::tma_store_wait<0>(); + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + notify_l1_ready(pool_block_idx, n_block_idx); + } else { + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + if (warp_idx_in_wg == 0 and cute::elect_one_sync()) { + const uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N; + cute::tma_store_fence(); + cute::SM90_TMA_STORE_2D::copy( + &tensor_map_l1_output, + smem_cd_l1_wg, + out_n_idx, + m_idx + row_block_offset); + cute::tma_store_arrive(); + } + __syncwarp(); + if constexpr (kAsyncL1TMAStore) { + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + async_l1_store_pending[l1_store_stage] = true; + async_l1_store_pool[l1_store_stage] = pool_block_idx; + async_l1_store_n[l1_store_stage] = n_block_idx; + async_l1_store_stage ^= 1u; + } else { + ptx::tma_store_wait<0>(); + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + notify_l1_ready(pool_block_idx, n_block_idx); + } + } + const unsigned long long block_epilogue_end = phase_profile_clock(); + if (epilogue_warp_idx == 0 and lane_idx == 0) + phase_profile_record(kProfileL1Epilogue, block_epilogue_end - block_epilogue_start); + } else { + // ---------------- L2 EPILOGUE: BF16 cast + NVLink scatter ---------------- + constexpr uint32_t kNumRowsPerWarp = WG_BLOCK_M / 8; + + if constexpr (kDirectL2Scatter) { + DG_STATIC_ASSERT(WG_BLOCK_N == 128, "Direct L2 scatter prototype only supports N128"); + + auto scatter_direct_row = [&](const 
uint32_t& row_offset, const bool& valid_row, const uint32_t& row_accum_offset) { + if (valid_row) { + uint32_t dst_rank_idx = 0, dst_token_idx = 0, dst_topk_idx = 0; + if (col_idx == 0) { + const auto src_metadata = *workspace.get_token_src_metadata_ptr(m_idx + row_offset); + dst_rank_idx = src_metadata.rank_idx; + dst_token_idx = src_metadata.token_idx; + dst_topk_idx = src_metadata.topk_idx; + } + const int src_lane = static_cast(lane_idx - col_idx); + dst_rank_idx = __shfl_sync(0xffffffff, dst_rank_idx, src_lane); + dst_token_idx = __shfl_sync(0xffffffff, dst_token_idx, src_lane); + dst_topk_idx = __shfl_sync(0xffffffff, dst_topk_idx, src_lane); + const auto dst_token = combine_token_buffer.get_rank_buffer(dst_topk_idx) + .get_data_buffer(dst_token_idx); + auto dst_base = math::advance_ptr( + dst_token.get_base_ptr(), n_idx * sizeof(nv_bfloat16)); + auto mapped_dst_base = sym_buffer.map(dst_base, dst_rank_idx); + + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread / 8; ++ i) { + const uint32_t chunk_lo = 2 * i, chunk_hi = 2 * i + 1; + const uint32_t col_lo = chunk_lo * 8 + col_idx * 2; + const uint32_t col_hi = chunk_hi * 8 + col_idx * 2; + const uint32_t packed_lo = math::cast_into_bf16_and_pack( + final_accum[chunk_lo * 4 + row_accum_offset + 0], + final_accum[chunk_lo * 4 + row_accum_offset + 1]); + const uint32_t packed_hi = math::cast_into_bf16_and_pack( + final_accum[chunk_hi * 4 + row_accum_offset + 0], + final_accum[chunk_hi * 4 + row_accum_offset + 1]); + *reinterpret_cast(mapped_dst_base + col_lo * sizeof(nv_bfloat16)) = packed_lo; + *reinterpret_cast(mapped_dst_base + col_hi * sizeof(nv_bfloat16)) = packed_hi; + } + } + }; + + scatter_direct_row(row_offset_r0, valid_r0, 0); + scatter_direct_row(row_offset_r1, valid_r1, 2); + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + } else { + // STSM into smem_cd_l2 (BF16). Reuse SM100 column-swizzle layout. 
+ #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread / 8; ++ i) { + // Each i consumes 8 floats (one 16x256b chunk in SM100 terms). + // For SM90 WGMMA layout, 8 floats per i correspond to 2 chunks of 4 floats: + // final_accum[i*8 + (0..3)] = chunk 2i: (r0c0, r0c1, r1c0, r1c1) + // final_accum[i*8 + (4..7)] = chunk 2i+1: same shape + const uint32_t chunk_lo = 2 * i, chunk_hi = 2 * i + 1; + + // Write to SMEM at appropriate position + // Row r_0 cols [chunk_lo*8 + col_idx*2, chunk_lo*8 + col_idx*2 + 1] = r0_lo + // Row r_0 cols [chunk_hi*8 + col_idx*2, chunk_hi*8 + col_idx*2 + 1] = r0_hi + // Row r_1 cols [chunk_lo*8 + col_idx*2, chunk_lo*8 + col_idx*2 + 1] = r1_lo + // Row r_1 cols [chunk_hi*8 + col_idx*2, chunk_hi*8 + col_idx*2 + 1] = r1_hi + auto write_pair = [&](uint32_t row, uint32_t col, uint32_t packed) { + auto smem_ptr = smem_cd_l2 + + epilogue_wg_idx * WG_BLOCK_M * WG_BLOCK_N + + row * WG_BLOCK_N + + col; + // BF16 STS: 2 bf16 elements + *reinterpret_cast(smem_ptr) = packed; + }; + if (valid_r0) { + const uint32_t r0_lo = math::cast_into_bf16_and_pack( + final_accum[chunk_lo*4 + 0], final_accum[chunk_lo*4 + 1]); + const uint32_t r0_hi = math::cast_into_bf16_and_pack( + final_accum[chunk_hi*4 + 0], final_accum[chunk_hi*4 + 1]); + write_pair(r_0, chunk_lo * 8 + col_idx * 2, r0_lo); + write_pair(r_0, chunk_hi * 8 + col_idx * 2, r0_hi); + } + if (valid_r1) { + const uint32_t r1_lo = math::cast_into_bf16_and_pack( + final_accum[chunk_lo*4 + 2], final_accum[chunk_lo*4 + 3]); + const uint32_t r1_hi = math::cast_into_bf16_and_pack( + final_accum[chunk_hi*4 + 2], final_accum[chunk_hi*4 + 3]); + write_pair(r_1, chunk_lo * 8 + col_idx * 2, r1_lo); + write_pair(r_1, chunk_hi * 8 + col_idx * 2, r1_hi); + } + } + + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + // Scatter to remote ranks via NVLink (one row per warp-pair) + // Each warpgroup-warp covers 8 unique rows × 2 (r_0 + r_1 doubled by warps) + // Lane group of 16 within a warp 
→ 1 row. + const uint32_t row_in_warp_block = lane_idx / 16; // 0 or 1 + const uint32_t lane_in_row = lane_idx % 16; + const uint32_t cols_per_lane = WG_BLOCK_N / 16; + static_assert(WG_BLOCK_N == 64 or WG_BLOCK_N == 128 or WG_BLOCK_N == 256, + "L2 scatter supports per-WG N64/N128/N256"); + + #pragma unroll + for (uint32_t j = 0; j < kNumRowsPerWarp; ++ j) { + const uint32_t row_in_wg = warp_idx_in_wg * 16 + j * 2 + row_in_warp_block; + const uint32_t m_idx_in_block = row_block_offset + row_in_wg; + if (m_idx_in_block >= valid_m) break; + + const auto src_metadata = *workspace.get_token_src_metadata_ptr(m_idx + m_idx_in_block); + const uint32_t dst_rank_idx = src_metadata.rank_idx; + const uint32_t dst_token_idx = src_metadata.token_idx; + const uint32_t dst_topk_idx = src_metadata.topk_idx; + + auto smem_ptr = smem_cd_l2 + + epilogue_wg_idx * WG_BLOCK_M * WG_BLOCK_N + + row_in_wg * WG_BLOCK_N + + lane_in_row * cols_per_lane; + const auto dst_token = combine_token_buffer.get_rank_buffer(dst_topk_idx) + .get_data_buffer(dst_token_idx); + + if constexpr (WG_BLOCK_N == 256) { + const auto packed0 = *reinterpret_cast(smem_ptr); + const auto packed1 = *(reinterpret_cast(smem_ptr) + 1); + auto dst_ptr = math::advance_ptr( + dst_token.get_base_ptr(), + n_idx * sizeof(nv_bfloat16) + lane_in_row * 2u * sizeof(uint4)); + auto mapped_dst_ptr = sym_buffer.map(dst_ptr, dst_rank_idx); + mapped_dst_ptr[0] = packed0; + mapped_dst_ptr[1] = packed1; + } else if constexpr (WG_BLOCK_N == 128) { + const auto packed = *reinterpret_cast(smem_ptr); + auto dst_ptr = math::advance_ptr( + dst_token.get_base_ptr(), + n_idx * sizeof(nv_bfloat16) + lane_in_row * sizeof(uint4)); + *sym_buffer.map(dst_ptr, dst_rank_idx) = packed; + } else { + const auto packed = *reinterpret_cast(smem_ptr); + auto dst_ptr = math::advance_ptr( + dst_token.get_base_ptr(), + n_idx * sizeof(nv_bfloat16) + lane_in_row * sizeof(uint2)); + *sym_buffer.map(dst_ptr, dst_rank_idx) = packed; + } + } + + 
ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + } + const unsigned long long block_epilogue_end = phase_profile_clock(); + if (epilogue_warp_idx == 0 and lane_idx == 0) + phase_profile_record(kProfileL2Epilogue, block_epilogue_end - block_epilogue_start); + } + }); + const unsigned long long math_loop_end = phase_profile_clock(); + if (epilogue_warp_idx == 0 and lane_idx == 0) + phase_profile_record(kProfileMathLoop, math_loop_end - math_loop_start); + + // ---------------- COMBINE ---------------- + // NVLink barrier first: signals remote ranks that this rank's GEMM + // outputs (NVLink scatter targets) are fully written. + const unsigned long long combine_barrier_start = phase_profile_clock(); + comm::nvlink_barrier( + workspace, sym_buffer, sm_idx, epilogue_thread_idx, + [&]() { ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); } + ); + const unsigned long long combine_barrier_end = phase_profile_clock(); + if (epilogue_warp_idx == 0 and lane_idx == 0) + phase_profile_record(kProfileCombineBarrier, combine_barrier_end - combine_barrier_start); + + // Sync with dispatch (paired with dispatch's pre-cleanup sync) so that + // dispatch may now safely clean workspace state. + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + const unsigned long long combine_reduce_start = phase_profile_clock(); + + constexpr uint32_t kNumHiddenBytes = kHidden * sizeof(nv_bfloat16); + constexpr uint32_t kNumElemsPerUint4 = sizeof(uint4) / sizeof(nv_bfloat162); + + constexpr uint32_t kNumChunkSlots = 3; + constexpr uint32_t kNumMaxRegistersForBuffer = 128; + constexpr uint32_t kNumChunks = + (kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes <= SMEM_BEFORE_BARRIER_SIZE + and kHidden <= 32 * kNumMaxRegistersForBuffer) ? 
1 : 2; + constexpr uint32_t kNumChunkBytes = kNumHiddenBytes / kNumChunks; + constexpr uint32_t kNumChunkUint4 = kNumChunkBytes / sizeof(uint4); + constexpr uint32_t kNumUint4PerLane = kNumChunkUint4 / 32; + DG_STATIC_ASSERT(kHidden % kNumChunks == 0, "Hidden must be divisible by number of chunks"); + DG_STATIC_ASSERT(kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes / kNumChunks <= SMEM_BEFORE_BARRIER_SIZE, "Hidden is too large"); + DG_STATIC_ASSERT(kNumChunkBytes % 16 == 0, "Combine chunk must be TMA-aligned (16 bytes)"); + DG_STATIC_ASSERT(kNumChunkBytes % sizeof(uint4) == 0, "Combine chunk must be divisible by 16 bytes"); + DG_STATIC_ASSERT(kNumChunkUint4 % 32 == 0, "Combine chunk must be a multiple of 32 16-byte elements"); + DG_STATIC_ASSERT(kNumTopk <= 32, "Top-k must fit in a single warp"); + + DG_DEVICE_ASSERT(kNumChunkSlots * kNumEpilogueWarps * kNumChunkBytes <= static_cast( + reinterpret_cast(barrier_start_ptr) - smem_buffer)); + + const auto combine_load_buffer = utils::PatternVisitor([&](const uint32_t& i) { + return math::advance_ptr(smem_buffer, (epilogue_warp_idx + i * kNumEpilogueWarps) * kNumChunkBytes); + }); + const auto combine_store_buffer = math::advance_ptr( + smem_buffer, (epilogue_warp_idx + kNumEpilogueWarps * 2) * kNumChunkBytes); + + auto combine_load_barriers = utils::PatternVisitor([&](const uint32_t& i) { + return combine_barriers[i + epilogue_warp_idx * 2]; + }); + + uint32_t combine_phase = 0; + uint32_t load_stage_idx = 0; + for (uint32_t token_idx = sm_idx * kNumEpilogueWarps + epilogue_warp_idx; + token_idx < num_tokens; + token_idx += kNumSMs * kNumEpilogueWarps) { + const int stored_topk_slot_idx = lane_idx < kNumTopk ? 
+ static_cast(__ldg(input_topk_idx_buffer.get_base_ptr() + token_idx * kNumTopk + lane_idx)) : -1; + const uint32_t total_mask = __ballot_sync(0xffffffff, stored_topk_slot_idx >= 0); + + for (uint32_t chunk = 0; chunk < kNumChunks; ++ chunk) { + const uint32_t chunk_byte_offset = chunk * kNumChunkBytes; + + uint32_t mask = total_mask; + const auto move_mask_and_load = [&](const uint32_t& i) { + if (mask) { + const uint32_t slot_idx = __ffs(mask) - 1; + mask ^= 1 << slot_idx; + if (cute::elect_one_sync()) { + const auto src_ptr = math::advance_ptr( + combine_token_buffer.get_rank_buffer(slot_idx) + .get_data_buffer(token_idx).get_base_ptr(), + chunk_byte_offset); + ptx::tma_load_1d(combine_load_buffer[i], src_ptr, combine_load_barriers[i], kNumChunkBytes); + ptx::mbarrier_arrive_and_set_tx(combine_load_barriers[i], kNumChunkBytes); + } + __syncwarp(); + return true; + } + return false; + }; + + bool do_reduce = move_mask_and_load(load_stage_idx); + + float2 reduced[kNumUint4PerLane * kNumElemsPerUint4] = {}; + while (do_reduce) { + do_reduce = move_mask_and_load(load_stage_idx ^ 1); + combine_load_barriers[load_stage_idx]->wait(combine_phase); + #pragma unroll + for (uint32_t j = 0; j < kNumUint4PerLane; ++ j) { + const auto uint4_values = combine_load_buffer[load_stage_idx][j * 32 + lane_idx]; + const auto bf16_values = reinterpret_cast(&uint4_values); + #pragma unroll + for (uint32_t l = 0; l < kNumElemsPerUint4; ++ l) + ptx::accumulate(reduced[j * kNumElemsPerUint4 + l], bf16_values[l]); + } + combine_phase ^= load_stage_idx; + load_stage_idx ^= 1; + } + + #pragma unroll + for (uint32_t j = 0; j < kNumUint4PerLane; ++ j) { + uint4 casted; + auto casted_bf16 = reinterpret_cast(&casted); + #pragma unroll + for (uint32_t l = 0; l < kNumElemsPerUint4; ++ l) + casted_bf16[l] = __float22bfloat162_rn(reduced[j * kNumElemsPerUint4 + l]); + + if (j == 0) { + ptx::tma_store_wait<0>(); + __syncwarp(); + } + ptx::st_shared(combine_store_buffer + j * 32 + lane_idx, + 
casted.x, casted.y, casted.z, casted.w); + } + __syncwarp(); + + if (cute::elect_one_sync()) { + cute::tma_store_fence(); + ptx::tma_store_1d( + math::advance_ptr(y, static_cast(token_idx) * kNumHiddenBytes + chunk_byte_offset), + combine_store_buffer, kNumChunkBytes); + cute::tma_store_arrive(); + } + __syncwarp(); + } + } + const unsigned long long combine_reduce_end = phase_profile_clock(); + if (epilogue_warp_idx == 0 and lane_idx == 0) + phase_profile_record(kProfileCombineReduce, combine_reduce_end - combine_reduce_start); + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only supports sm_90"); +#endif +} + +} // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep_gemm/include/deep_gemm/impls/sm90_w4a8_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm90_w4a8_mega_moe.cuh new file mode 100644 index 0000000000..42c5980b0f --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/sm90_w4a8_mega_moe.cuh @@ -0,0 +1,2318 @@ +#pragma once + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#define __CLION_IDE__ + +namespace deep_gemm { + +// ============================================================================ +// SM90 (Hopper) FP8 MegaMoE — full implementation +// ---------------------------------------------------------------------------- +// Pipeline (cluster=1, no TMA multicast): +// * Dispatch warps: pull tokens (FP8) and SF (per-128 channel float) from +// remote ranks via NVLink into the local L1 pool. +// * GEMM TMA-load warps (1 for A+SFA, 1 for B+SFB) feed the pipeline stages. 
+// * Math warpgroups (1 or 2, totalling kNumEpilogueThreads) consume each +// stage with WGMMA, accumulate into registers, then run the epilogue: +// - L1 (Linear1): SwiGLU with gate/up granularity-8 interleaved layout, +// per-row amax over the 64 post-SwiGLU columns of this block, FP8 e4m3 +// quantize, STSM into SMEM, TMA store to local L1 output buffer. +// The per-row SF is written as a *float* into the L2-acts SF buffer at +// per-64 K granularity (one SF per L1 N block), so each block is fully +// self-contained and no cross-CTA amax synchronisation is needed. +// - L2 (Linear2): BF16 cast of the GEMM output, STSM into SMEM, then +// NVLink scatter to remote combine buffers. +// * After all GEMM blocks, the math warps run the COMBINE step (top-k +// reduction in BF16) — ported verbatim from the SM100 kernel. +// ============================================================================ + +template < + uint32_t kNumMaxTokensPerRank, + uint32_t kHidden, uint32_t kIntermediateHidden, + uint32_t kNumExperts, uint32_t kNumTopk, + uint32_t kNumExpertsPerWave, + uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K, + uint32_t kNumMaxPoolTokens, + uint32_t kNumPaddedSFPoolTokens, + uint32_t kNumStages, + uint32_t kNumDispatchThreads, uint32_t kNumNonEpilogueThreads, + uint32_t kNumEpilogueThreads, + uint32_t kClusterSize, + uint32_t kNumSMs, uint32_t kNumRanks, + float kActivationClamp, + bool kFastMath, + bool kAsyncL1TMAStoreRequested = true, + bool kSplitSFATMARequested = false, + bool kDirectL2ScatterRequested = false, + bool kL2DualAccumRequested = false, + bool kPhaseProfileRequested = false, + bool kL1DualKAccumRequested = false, + bool kL2NMajorScheduleRequested = false, + bool kL1NMajorScheduleRequested = false, + uint32_t L1_SHAPE_N = kIntermediateHidden * 2, + uint32_t L1_SHAPE_K = kHidden, + uint32_t L2_SHAPE_N = kHidden, + uint32_t L2_SHAPE_K = kIntermediateHidden, + uint32_t kNumDispatchWarps = kNumDispatchThreads / 32, + uint32_t 
kNumMMANonEpilogueWarps = kNumNonEpilogueThreads / 32, + uint32_t kNumEpilogueWarps = kNumEpilogueThreads / 32, + uint32_t kNumEpilogueWarpgroups = kNumEpilogueWarps / 4, + uint32_t kNumThreads = kNumDispatchThreads + kNumNonEpilogueThreads + kNumEpilogueThreads, + uint32_t kNumTokensPerWarp = 32 / kNumTopk, + uint32_t kNumExpertsPerRank = kNumExperts / kNumRanks +> +CUTLASS_GLOBAL __launch_bounds__(kNumThreads, 1) void +sm90_w4a8_mega_moe_impl(void* y, + int* cumulative_local_expert_recv_stats, + const uint32_t num_tokens, + const __grid_constant__ layout::SymBuffer sym_buffer, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts_sf, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_weights, + const float* __restrict__ l1_weights_sf, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_output, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_acts, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_acts_sf, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_weights, + const float* __restrict__ l2_weights_sf) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900) and (__CUDA_ARCH__ < 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // ===================================================================== + // Template checks + // ===================================================================== + DG_STATIC_ASSERT(kNumDispatchThreads == 64 or kNumDispatchThreads % 128 == 0, + "Invalid number of dispatch threads"); + DG_STATIC_ASSERT(kNumNonEpilogueThreads == 64 or kNumNonEpilogueThreads == 128, + "Invalid number of GEMM TMA warps (2 or 4 warps expected)"); + DG_STATIC_ASSERT(kNumEpilogueThreads % 128 == 0, "Invalid number of math/epilogue threads"); + DG_STATIC_ASSERT(kNumExperts % kNumRanks == 0, "Invalid number of experts or ranks"); + DG_STATIC_ASSERT(kClusterSize == 1 or kClusterSize == 
2, "Invalid cluster size"); + DG_STATIC_ASSERT(kNumSMs % kClusterSize == 0, "SM count must be divisible by cluster size"); + DG_STATIC_ASSERT(BLOCK_M == 16 or BLOCK_M == 32 or BLOCK_M % 64 == 0, + "BLOCK_M must be 16/32 for mma.sync decode or a multiple of WGMMA::M (64)"); + DG_STATIC_ASSERT(BLOCK_N == 64 or BLOCK_N == 128 or BLOCK_N == 256, "BLOCK_N must be 64/128/256 for this SM90 path"); + DG_STATIC_ASSERT(BLOCK_K == 128, "BLOCK_K is fixed to 128 (per-128 SF)"); + + // ===================================================================== + // Thread / warp identification + // ===================================================================== + const uint32_t sm_idx = blockIdx.x; + const uint32_t thread_idx = threadIdx.x; + const uint32_t warp_idx = cutlass::canonical_warp_idx_sync(); + const uint32_t lane_idx = ptx::get_lane_idx(); + + // Prefetch all TMA descriptors at the very beginning + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_l1_acts); + cute::prefetch_tma_descriptor(&tensor_map_l1_acts_sf); + cute::prefetch_tma_descriptor(&tensor_map_l1_weights); + cute::prefetch_tma_descriptor(&tensor_map_l1_output); + cute::prefetch_tma_descriptor(&tensor_map_l2_acts); + cute::prefetch_tma_descriptor(&tensor_map_l2_acts_sf); + cute::prefetch_tma_descriptor(&tensor_map_l2_weights); + } + + // ===================================================================== + // Workspaces and symmetric buffer slicing (mirror SM100 layout, except SF + // for L2 activations uses per-64 K granularity) + // ===================================================================== + const auto workspace = layout::Workspace( + sym_buffer.get_base_ptr(), kNumRanks, kNumExperts, kNumMaxTokensPerRank, kNumTopk); + + constexpr auto fp8_token_layout = layout::Data(kHidden); + constexpr auto bf16_token_layout = layout::Data(kHidden * sizeof(nv_bfloat16)); + constexpr auto fp8_intermediate_token_layout = layout::Data(kIntermediateHidden); + // 
Per-128 K float SF: 4 bytes per per-128 group => `kHidden / 32` bytes/token (same as SM100 packing) + constexpr auto fp8_sf_layout = layout::Data(kHidden / 32); + // Per-64 K float SF (SM90 only): 4 bytes per per-64 group => `kIntermediateHidden / 16` bytes/token + constexpr auto fp8_intermediate_sf_layout = layout::Data(kIntermediateHidden / 16); + constexpr auto input_topk_idx_layout = layout::Data(kNumTopk * sizeof(int64_t), false); + constexpr auto input_topk_weights_layout = layout::Data(kNumTopk * sizeof(float), false); + constexpr auto l1_topk_weights_layout = layout::Data(sizeof(float), false); + + // Registered input area + const auto input_token_buffer = layout::Buffer(fp8_token_layout, 1, kNumMaxTokensPerRank, workspace.get_end_ptr()); + const auto input_sf_buffer = layout::Buffer(fp8_sf_layout, 1, kNumMaxTokensPerRank, input_token_buffer.get_end_ptr()); + const auto input_topk_idx_buffer = layout::Buffer(input_topk_idx_layout, 1, kNumMaxTokensPerRank, input_sf_buffer.get_end_ptr()); + const auto input_topk_weights_buffer = layout::Buffer(input_topk_weights_layout, 1, kNumMaxTokensPerRank, input_topk_idx_buffer.get_end_ptr()); + + // L1 input area + const auto l1_token_buffer = layout::Buffer(fp8_token_layout, 1, kNumMaxPoolTokens, input_topk_weights_buffer.get_end_ptr()); + const auto l1_sf_buffer = layout::Buffer(fp8_sf_layout, 1, kNumPaddedSFPoolTokens, l1_token_buffer.get_end_ptr()); + const auto l1_topk_weights_buffer = layout::Buffer(l1_topk_weights_layout, 1, kNumMaxPoolTokens, l1_sf_buffer.get_end_ptr()); + + // L2 input area + const auto l2_token_buffer = layout::Buffer(fp8_intermediate_token_layout, 1, kNumMaxPoolTokens, l1_topk_weights_buffer.get_end_ptr()); + const auto l2_sf_buffer = layout::Buffer(fp8_intermediate_sf_layout, 1, kNumPaddedSFPoolTokens, l2_token_buffer.get_end_ptr()); + + // Combine input area + const auto combine_token_buffer = layout::Buffer(bf16_token_layout, kNumTopk, kNumMaxTokensPerRank, l2_sf_buffer.get_end_ptr()); + + 
// ===================================================================== + // GEMM data types and shape constants + // ===================================================================== + using a_dtype_t = cutlass::float_e4m3_t; + using b_dtype_t = cutlass::float_e4m3_t; + constexpr bool kUseMMASync = (BLOCK_M == 16 or BLOCK_M == 32); + constexpr bool kSplitNWarpgroups = + (!kUseMMASync) && BLOCK_N == 256 && kNumEpilogueWarpgroups == 2; + constexpr bool kSerialNWarpgroups = false; + constexpr bool kWideNWarpgroups = + (!kUseMMASync) && BLOCK_N == 256 && kNumEpilogueWarpgroups == 1; + constexpr uint32_t WG_BLOCK_M = kSplitNWarpgroups ? BLOCK_M : BLOCK_M / kNumEpilogueWarpgroups; + constexpr uint32_t WG_BLOCK_N = (kSplitNWarpgroups || kSerialNWarpgroups) ? BLOCK_N / 2 : BLOCK_N; + constexpr uint32_t L1_OUT_BLOCK_N = BLOCK_N / 2; // post-SwiGLU tile N + constexpr uint32_t WG_L1_OUT_BLOCK_N = WG_BLOCK_N / 2; // post-SwiGLU per-WG N + constexpr bool kAsyncL1TMAStore = + kAsyncL1TMAStoreRequested && (!kUseMMASync) && (!kSplitNWarpgroups) && + kNumEpilogueWarpgroups == 1; + constexpr bool kSplitSFATMA = kSplitSFATMARequested && (!kUseMMASync); + constexpr bool kDirectL2Scatter = kDirectL2ScatterRequested && (!kUseMMASync) && + (!kSplitNWarpgroups) && (!kSerialNWarpgroups); + constexpr bool kL2DualAccum = kL2DualAccumRequested && (!kUseMMASync) && + (!kSplitNWarpgroups) && (!kSerialNWarpgroups) && WG_BLOCK_N == 128; + constexpr bool kL1DualKAccum = kL1DualKAccumRequested && (!kUseMMASync) && + (!kSplitNWarpgroups) && (!kSerialNWarpgroups) && WG_BLOCK_N == 128 && + (kHidden / BLOCK_K) % 2 == 0; + using L1WGMMA = typename mma::sm90::FP8MMASelector::type; + using L2WGMMA = typename mma::sm90::FP8MMASelector::type; + static_assert(L1WGMMA::M == 64 and L1WGMMA::N == WG_BLOCK_N and L1WGMMA::K == 32, + "Unexpected WGMMA shape"); + DG_STATIC_ASSERT((!kSplitNWarpgroups) or (BLOCK_M == 64 and WG_BLOCK_N == 128), + "Split-N path expects two M64N128 WGMMA consumers"); + + // A is 
always CTA-local. When kClusterSize=2 the scheduler pairs adjacent + // M blocks with identical expert/N/K coordinates so the B TMA can multicast. + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M; + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N; + constexpr uint32_t kSwizzleAMode = kUseMMASync ? 0 : BLOCK_K * sizeof(a_dtype_t); // 0 or 128 + constexpr uint32_t kSwizzleBMode = kUseMMASync ? 0 : BLOCK_K * sizeof(b_dtype_t); // 0 or 128 + constexpr uint32_t kSwizzleCDMode = 128; + constexpr uint32_t kGranK = 128; // L1 acts SF, weights SF + constexpr uint32_t kL2ActsSFGranK = 64; // L2 acts SF (per-64 K, SM90 only) + + // ===================================================================== + // Shared memory layout + // ===================================================================== + constexpr uint32_t kSharedMemoryAlignment = 1024; + extern __shared__ __align__(kSharedMemoryAlignment) uint8_t smem_buffer[]; + + constexpr uint32_t SMEM_EXPERT_COUNT_SIZE = + math::constexpr_align(kNumExperts * sizeof(uint32_t), kSharedMemoryAlignment); + constexpr uint32_t SMEM_SEND_BUFFER_SIZE = + math::constexpr_align(fp8_token_layout.get_num_bytes() * kNumDispatchWarps, kSharedMemoryAlignment); + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t); + // SFA per-stage must be sized for the larger of L1 (BLOCK_M floats) and L2 + // (two per-64-K halves). Each TMA destination must be 128B aligned, so + // the second L2 half cannot start immediately after 16 floats in M16 decode. + constexpr uint32_t kL2SFAHalfStride = + math::constexpr_align(BLOCK_M * sizeof(float), 128u) / sizeof(float); + constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = 2 * kL2SFAHalfStride * sizeof(float); + // Block (128, 128) weight SF: 1 float per (BLOCK_N, BLOCK_K) tile for L2, + // 2 floats (gate/up) for L1. Loaded by math warpgroup directly from global, + // so no SMEM is needed. 
+ constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = 0; + + // CD output: max of L1 FP8 (BLOCK_M * (BLOCK_N/2) * 1 byte * num_wg) and + // L2 BF16 (BLOCK_M * BLOCK_N * 2 bytes * num_wg). + constexpr uint32_t SMEM_CD_ACCUM_SIZE = kUseMMASync + ? math::constexpr_align(BLOCK_M * BLOCK_N * sizeof(float), kSharedMemoryAlignment) + : 0u; + constexpr uint32_t SMEM_CD_L1_SIZE = + kNumEpilogueWarpgroups * WG_BLOCK_M * WG_L1_OUT_BLOCK_N * sizeof(cutlass::float_e4m3_t); + constexpr uint32_t SMEM_CD_L2_SIZE = kDirectL2Scatter ? 0u : + kNumEpilogueWarpgroups * WG_BLOCK_M * WG_BLOCK_N * sizeof(nv_bfloat16); + constexpr uint32_t SMEM_CD_L1_ASYNC_SIZE = kAsyncL1TMAStore ? + 2 * kNumEpilogueWarpgroups * WG_BLOCK_M * L1_OUT_BLOCK_N * sizeof(cutlass::float_e4m3_t) : 0u; + constexpr uint32_t SMEM_CD_OUTPUT_BASE_SIZE = + SMEM_CD_L1_SIZE > SMEM_CD_L2_SIZE ? SMEM_CD_L1_SIZE : SMEM_CD_L2_SIZE; + constexpr uint32_t SMEM_CD_OUTPUT_UNALIGNED_SIZE = + SMEM_CD_OUTPUT_BASE_SIZE > SMEM_CD_L1_ASYNC_SIZE ? SMEM_CD_OUTPUT_BASE_SIZE : SMEM_CD_L1_ASYNC_SIZE; + constexpr uint32_t SMEM_CD_OUTPUT_SIZE = math::constexpr_align( + SMEM_CD_OUTPUT_UNALIGNED_SIZE, kSharedMemoryAlignment); + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_ACCUM_SIZE + SMEM_CD_OUTPUT_SIZE; + + constexpr uint32_t SMEM_BEFORE_BARRIER_SIZE = + SMEM_EXPERT_COUNT_SIZE + SMEM_SEND_BUFFER_SIZE + SMEM_CD_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + + // SMEM pointers + auto smem_expert_count = reinterpret_cast(smem_buffer); + const auto smem_send_buffers = layout::Buffer( + fp8_token_layout, kNumDispatchWarps, 1, + math::advance_ptr(smem_buffer, SMEM_EXPERT_COUNT_SIZE)); + + auto smem_gemm_base = math::advance_ptr( + smem_buffer, SMEM_EXPERT_COUNT_SIZE + SMEM_SEND_BUFFER_SIZE); + + // mma.sync decode stages FP32 accumulators through SMEM so the epilogue can + // use logical row/column indices instead of WGMMA accumulator layout. 
+ auto smem_accum_f32 = reinterpret_cast(smem_gemm_base); + auto smem_cd_base = math::advance_ptr(smem_gemm_base, SMEM_CD_ACCUM_SIZE); + // CD output is shared by L1 (FP8) and L2 (BF16); reinterpret-cast as needed. + auto smem_cd_l1 = reinterpret_cast(smem_cd_base); + auto smem_cd_l2 = reinterpret_cast(smem_cd_base); + + auto smem_a = utils::PatternVisitor([=](const uint32_t& i) { + return math::advance_ptr(smem_gemm_base, SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = utils::PatternVisitor([=](const uint32_t& i) { + return math::advance_ptr(smem_gemm_base, SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + auto sf_start_ptr = math::advance_ptr(smem_gemm_base, + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto smem_sfa = utils::PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE); + }); + + // Barriers live after SF (SFB is loaded directly from global, no SMEM) + auto barrier_start_ptr = reinterpret_cast( + sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE); + auto dispatch_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + i; }); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + kNumDispatchWarps + i; }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + kNumDispatchWarps + kNumStages + i; }); + auto combine_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + kNumDispatchWarps + kNumStages * 2 + i; }); + + // ===================================================================== + // Initialization + // ===================================================================== + if (warp_idx == 0) { + // Clean expert-count shared memory + #pragma unroll + for (uint32_t i = lane_idx; i < kNumExperts; i += 32) + ptx::st_shared(smem_expert_count + i, 0u); + } else if 
(warp_idx == 1) { + // Init dispatch m-barriers + #pragma unroll + for (uint32_t i = lane_idx; i < kNumDispatchWarps; i += 32) + dispatch_barriers[i]->init(1); + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Init GEMM full/empty barriers and combine barriers + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + // Producer arrivals: A(+SFA) + B, or A + B + SFA when + // the split-SFA experiment uses an otherwise idle TMA warp. + full_barriers[i]->init(kSplitSFATMA ? 3 : 2); + // With cluster multicast the leader CTA's TMA warp waits on peer + // empty barriers too, so every math warp releases both CTAs. + empty_barriers[i]->init(kClusterSize * kNumEpilogueWarps); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueWarps * 2; ++ i) + combine_barriers[i]->init(1); + } + cutlass::arch::fence_barrier_init(); + } + if constexpr (kClusterSize > 1) { + cute::cluster_sync(); + } else { + __syncthreads(); + } + + // ===================================================================== + // Scheduler (cluster=1) + // ===================================================================== + auto scheduler = sched::MegaMoEScheduler< + BLOCK_M, BLOCK_N, BLOCK_K, + L1_SHAPE_N, L1_SHAPE_K, + L2_SHAPE_N, L2_SHAPE_K, + kNumExpertsPerRank, kNumExpertsPerWave, + kNumSMs, kNumRanks, kClusterSize, kL2NMajorScheduleRequested, kL1NMajorScheduleRequested>(workspace); + + // Pipeline state shared by TMA loaders and math warpgroups + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + stage_idx = stage_idx == kNumStages - 1 ? 
0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + // Intra-SM barrier indices (mirroring SM100) + constexpr uint32_t kDispatchBarrierIdx = 0; + constexpr uint32_t kDispatchWithEpilogueBarrierIdx = 1; + constexpr uint32_t kEpilogueFullBarrierIdx = 2; + constexpr uint32_t kEpilogueWGBarrierStartIdx = 3; + + // Cross-rank NVLink barrier tags + constexpr uint32_t kBeforeDispatchPullBarrierTag = 1; + constexpr uint32_t kBeforeCombineReduceBarrierTag = 2; + constexpr uint32_t kAfterWorkspaceCleanBarrierTag = 3; + + // Register reconfiguration counts (chosen to fit in 64512 reg budget). + // For the 256-epilogue-thread case (block_m=128, 2 math WGs): + // 128*48 + 128*40 + 256*208 = 64512 exactly. + constexpr uint32_t kNumDispatchRegisters = 48; + constexpr bool kCompactFrontendWarpgroup = (kNumDispatchWarps == 2 and kNumMMANonEpilogueWarps == 2); + constexpr uint32_t kNumNonEpilogueRegisters = kCompactFrontendWarpgroup ? kNumDispatchRegisters : 40; + constexpr uint32_t kNumEpilogueRegisters = (kSerialNWarpgroups or kWideNWarpgroups) ? 256 : ((kUseMMASync and BLOCK_M == 32) ? 
240 : 208); + DG_STATIC_ASSERT(kNumDispatchRegisters * kNumDispatchThreads + + kNumNonEpilogueRegisters * kNumNonEpilogueThreads + + kNumEpilogueRegisters * kNumEpilogueThreads <= 64512, + "Too many registers"); + + constexpr uint32_t kDispatchGridSyncIndex = 0; + constexpr uint32_t kEpilogueGridSyncIndex = 1; + + constexpr uint32_t kProfileDispatchTotal = 0; + constexpr uint32_t kProfileDispatchPull = 1; + constexpr uint32_t kProfileMathLoop = 2; + constexpr uint32_t kProfileCombineBarrier = 3; + constexpr uint32_t kProfileCombineReduce = 4; + constexpr uint32_t kProfileGemmCore = 5; + constexpr uint32_t kProfileL1Epilogue = 6; + constexpr uint32_t kProfileL2Epilogue = 7; + const auto phase_profile_clock = [&]() -> unsigned long long { + if constexpr (kPhaseProfileRequested) { + unsigned long long t; + asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(t)); + return t; + } else { + return 0ull; + } + }; + const auto phase_profile_record = [&](const uint32_t& metric, const unsigned long long& cycles) { + if constexpr (kPhaseProfileRequested) { + if (cumulative_local_expert_recv_stats != nullptr and cycles > 0) { + auto profile = reinterpret_cast( + cumulative_local_expert_recv_stats + kNumExpertsPerRank); + atomicAdd(profile + metric, cycles); + atomicMax(profile + 8 + metric, cycles); + atomicAdd(profile + 16 + metric, 1ull); + } + } + }; + + // ===================================================================== + // ROLE 1: DISPATCH WARPS + // Mirrors SM100 dispatch with two changes: + // * SF is per-128 channel float (no UTCCP transpose). We store the + // remote per-token SF directly into the local L1 SF buffer in + // MN-major layout: `local_sf[k_chunk * num_padded_sf_pool_tokens + token_idx]`. + // * The "token_idx_in_expert" → SF token index is now the simple + // per-block linear mapping (no 4×32 transpose). 
+ // ===================================================================== + if (warp_idx < kNumDispatchWarps) { + cutlass::arch::warpgroup_reg_dealloc(); + const unsigned long long dispatch_total_start = phase_profile_clock(); + + DG_STATIC_ASSERT(kNumTopk <= 32, "Invalid number of topk"); + constexpr uint32_t kNumActivateLanes = kNumTokensPerWarp * kNumTopk; + const auto read_topk_idx = [&](const auto& process) { + #pragma unroll + for (uint32_t i = (sm_idx * kNumDispatchWarps + warp_idx) * kNumTokensPerWarp; + i < num_tokens; + i += kNumSMs * kNumDispatchWarps * kNumTokensPerWarp) { + int expert_idx = -1; + if (i + (lane_idx / kNumTopk) < num_tokens and lane_idx < kNumActivateLanes) { + expert_idx = static_cast( + __ldg(input_topk_idx_buffer.get_base_ptr() + i * kNumTopk + lane_idx)); + if (expert_idx >= 0) + process(i * kNumTopk + lane_idx, expert_idx); + } + __syncwarp(); + } + }; + + // Count tokens per expert + read_topk_idx([&](const uint32_t& token_topk_idx, const int& expert_idx) { + atomicAdd_block(smem_expert_count + expert_idx, 1); + }); + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + // Stake out per-expert SM offsets via global atomic + #pragma unroll + for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) { + const uint64_t send_value = (1ull << 32) | static_cast(smem_expert_count[i]); + smem_expert_count[i] = static_cast( + ptx::atomic_add(workspace.get_expert_send_count_ptr(i), send_value)); + } + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + // Write source token-topk indices to remote ranks + read_topk_idx([&](const uint32_t& token_topk_idx, const int& expert_idx) { + const auto dst_rank_idx = expert_idx / kNumExpertsPerRank; + const auto dst_slot_idx = atomicAdd_block(smem_expert_count + expert_idx, 1); + const auto dst_ptr = workspace.get_src_token_topk_idx_ptr( + expert_idx % kNumExpertsPerRank, sym_buffer.rank_idx, dst_slot_idx); + *sym_buffer.map(dst_ptr, dst_rank_idx) = 
token_topk_idx; + }); + + comm::grid_sync( + workspace, sm_idx, thread_idx, + [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); } + ); + + if (sm_idx == 0) { + #pragma unroll + for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) { + const auto dst_rank_idx = i / kNumExpertsPerRank; + const auto dst_local_expert_idx = i % kNumExpertsPerRank; + const auto expert_status = *workspace.get_expert_send_count_ptr(i); + *sym_buffer.map( + workspace.get_expert_recv_count_ptr(sym_buffer.rank_idx, dst_local_expert_idx), + dst_rank_idx) = expert_status & 0xffffffff; + ptx::atomic_add_sys( + sym_buffer.map(workspace.get_expert_recv_count_sum_ptr(dst_local_expert_idx), dst_rank_idx), + expert_status); + } + } + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + comm::nvlink_barrier( + workspace, sym_buffer, sm_idx, thread_idx, + [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); }, + false, true); + + // Sync with epilogue warps before pulling tokens + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + const unsigned long long dispatch_pull_start = phase_profile_clock(); + + // Token / SF pull loop + uint32_t pull_mbarrier_phase = 0; + const auto pull_buffer = smem_send_buffers.get_rank_buffer(warp_idx).get_data_buffer(0); + const auto pull_mbarrier = dispatch_barriers[warp_idx]; + + scheduler.fetch_expert_recv_count(); + + constexpr uint32_t kNumRanksPerLane = math::constexpr_ceil_div(kNumRanks, 32u); + int current_expert_idx = -1; + uint32_t stored_rank_count[kNumRanksPerLane] = {}; + uint32_t expert_start_idx = 0, expert_end_idx = 0; + uint32_t expert_pool_block_offset = 0; + + constexpr uint32_t kNumGlobalWarps = kNumSMs * kNumDispatchWarps; + for (uint32_t token_idx = sm_idx * kNumDispatchWarps + warp_idx; ; token_idx += kNumGlobalWarps) { + int old_expert_idx = current_expert_idx; + while (token_idx >= expert_end_idx) { + if (++ current_expert_idx >= 
kNumExpertsPerRank) + break; + expert_pool_block_offset += math::ceil_div(expert_end_idx - expert_start_idx, BLOCK_M); + expert_start_idx = expert_end_idx; + expert_end_idx += scheduler.get_num_tokens(current_expert_idx); + } + if (current_expert_idx >= kNumExpertsPerRank) + break; + + if (old_expert_idx != current_expert_idx) { + old_expert_idx = current_expert_idx; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) { + const uint32_t j = i * 32 + lane_idx; + stored_rank_count[i] = j < kNumRanks ? + static_cast(*workspace.get_expert_recv_count_ptr(j, current_expert_idx)) : 0; + } + } + + // Round-robin rank selection (identical to SM100) + uint32_t current_rank_in_expert_idx; + uint32_t remaining[kNumRanksPerLane]; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) + remaining[i] = stored_rank_count[i]; + uint32_t offset = 0; + uint32_t token_idx_in_expert = token_idx - expert_start_idx; + uint32_t slot_idx = token_idx_in_expert; + uint32_t token_idx_in_rank; + while (true) { + uint32_t num_actives_in_lane = 0; + uint32_t min_in_lane = 0xffffffff; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) { + num_actives_in_lane += remaining[i] > 0; + if (remaining[i] > 0) + min_in_lane = cute::min(min_in_lane, remaining[i]); + } + const uint32_t num_active_ranks = __reduce_add_sync(0xffffffff, num_actives_in_lane); + const uint32_t length = __reduce_min_sync(0xffffffff, min_in_lane); + + const uint32_t num_round_tokens = length * num_active_ranks; + if (slot_idx < num_round_tokens) { + const uint32_t slot_idx_in_round = slot_idx % num_active_ranks; + uint32_t num_seen_ranks = 0; + current_rank_in_expert_idx = 0; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) { + const uint32_t mask = __ballot_sync(0xffffffff, remaining[i] > 0); + const uint32_t num_active_lanes = __popc(mask); + if (slot_idx_in_round >= num_seen_ranks and slot_idx_in_round < num_seen_ranks + num_active_lanes) + 
current_rank_in_expert_idx = i * 32 + __fns(mask, 0, slot_idx_in_round - num_seen_ranks + 1); + num_seen_ranks += num_active_lanes; + } + token_idx_in_rank = offset + (slot_idx / num_active_ranks); + break; + } + slot_idx -= num_round_tokens; + offset += length; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) + remaining[i] -= cute::min(remaining[i], length); + } + + const uint32_t src_token_topk_idx = *workspace.get_src_token_topk_idx_ptr( + current_expert_idx, current_rank_in_expert_idx, token_idx_in_rank); + const uint32_t src_token_idx = src_token_topk_idx / kNumTopk; + const uint32_t src_topk_idx = src_token_topk_idx % kNumTopk; + + // TMA pull token data into SMEM + if (cute::elect_one_sync()) { + ptx::tma_load_1d( + pull_buffer.get_base_ptr(), + sym_buffer.map(input_token_buffer.get_data_buffer(src_token_idx).get_base_ptr(), + current_rank_in_expert_idx), + pull_mbarrier, kHidden); + } + __syncwarp(); + + // Copy SF: per-128 K floats, written linearly (no UTCCP transpose). 
+ constexpr uint32_t kNumSFFloats = kHidden / 128; + DG_STATIC_ASSERT(kNumSFFloats > 0 and kHidden % 128 == 0, "Invalid SF"); + const auto remote_sf_ptr = sym_buffer.map( + input_sf_buffer.get_data_buffer(src_token_idx).get_base_ptr(), + current_rank_in_expert_idx); + const auto local_sf_ptr = l1_sf_buffer.get_base_ptr(); + const uint32_t sf_pool_token_idx = expert_pool_block_offset * BLOCK_M + token_idx_in_expert; + #pragma unroll + for (uint32_t i = 0; i < math::constexpr_ceil_div(kNumSFFloats, 32u); ++ i) { + const uint32_t j = i * 32 + lane_idx; + if (j < kNumSFFloats) + local_sf_ptr[j * kNumPaddedSFPoolTokens + sf_pool_token_idx] = remote_sf_ptr[j]; + } + __syncwarp(); + + const uint32_t pool_token_idx = expert_pool_block_offset * BLOCK_M + token_idx_in_expert; + if (cute::elect_one_sync()) { + const auto weight = *sym_buffer.map( + input_topk_weights_buffer.get_base_ptr() + src_token_topk_idx, + current_rank_in_expert_idx); + *l1_topk_weights_buffer.get_data_buffer(pool_token_idx).get_base_ptr() = weight; + + ptx::mbarrier_arrive_and_set_tx(pull_mbarrier, kHidden); + ptx::mbarrier_wait_and_flip_phase(pull_mbarrier, pull_mbarrier_phase); + + ptx::tma_store_1d( + l1_token_buffer.get_data_buffer(pool_token_idx).get_base_ptr(), + pull_buffer.get_base_ptr(), pull_buffer.get_num_bytes()); + + *workspace.get_token_src_metadata_ptr(pool_token_idx) = + {current_rank_in_expert_idx, src_token_idx, src_topk_idx}; + + cute::tma_store_arrive(); + ptx::tma_store_wait<0>(); + ptx::red_add_rel( + workspace.get_l1_arrival_count_ptr(expert_pool_block_offset + token_idx_in_expert / BLOCK_M), 1); + } + __syncwarp(); + } + + + + // Cleanup workspace, overlapping with combine + const unsigned long long dispatch_pull_end = phase_profile_clock(); + if (lane_idx == 0) { + phase_profile_record(kProfileDispatchPull, dispatch_pull_end - dispatch_pull_start); + phase_profile_record(kProfileDispatchTotal, dispatch_pull_end - dispatch_total_start); + } + 
ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + + DG_STATIC_ASSERT(kNumSMs > 1, "Invalid SM count"); + if (sm_idx == 0) { + #pragma unroll + for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) + *workspace.get_expert_send_count_ptr(i) = 0; + } else { + for (uint32_t i = sm_idx - 1; i < kNumExpertsPerRank; i += kNumSMs - 1) { + const auto num_recv_tokens = static_cast( + *workspace.get_expert_recv_count_sum_ptr(i)); + const auto num_recv_m_blocks = math::ceil_div(num_recv_tokens, BLOCK_M); + + expert_pool_block_offset = scheduler.get_pool_block_offset(i); + + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + DG_STATIC_ASSERT(kNumDispatchWarps >= 2, "Not enough dispatch warps"); + if (warp_idx == 0) { + *workspace.get_expert_recv_count_sum_ptr(i) = 0; + } else if (warp_idx == 1) { + if (cute::elect_one_sync() and cumulative_local_expert_recv_stats != nullptr) + ptx::red_add(cumulative_local_expert_recv_stats + i, static_cast(num_recv_tokens)); + __syncwarp(); + } + + for (uint32_t j = thread_idx; j < kNumRanks; j += kNumDispatchThreads) + *workspace.get_expert_recv_count_ptr(j, i) = 0; + __syncwarp(); + + for (uint32_t j = thread_idx; j < num_recv_m_blocks; j += kNumDispatchThreads) { + *workspace.get_l1_arrival_count_ptr(expert_pool_block_offset + j) = 0; + *workspace.get_l2_arrival_mask_ptr(expert_pool_block_offset + j) = 0; + } + __syncwarp(); + } + } + + comm::nvlink_barrier( + workspace, sym_buffer, sm_idx, thread_idx, + [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); }, + true, false); + + // ===================================================================== + // ROLE 2: GEMM TMA LOAD warps (load A+SFA and B; weight SF is not staged through SMEM by these warps) + // Default: 4 non-epilogue warps, two active and two idle (a third one issues the SFA TMA when kSplitSFATMA is set). + // Compact frontend experiment: 2 dispatch warps + 2 TMA warps share the + // first warpgroup, reducing total CTA threads for the M128/2WG path. 
+ // ===================================================================== + } else if (warp_idx == kNumDispatchWarps) { + cutlass::arch::warpgroup_reg_dealloc(); + + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + const auto tensor_map_a_ptr = block_phase == sched::BlockPhase::Linear2 + ? &tensor_map_l2_acts : &tensor_map_l1_acts; + const auto tensor_map_sfa_ptr = block_phase == sched::BlockPhase::Linear2 + ? &tensor_map_l2_acts_sf : &tensor_map_l1_acts_sf; + + const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx; + const uint32_t valid_m = scheduler.template get_valid_m(); + const bool has_valid_m = valid_m > 0; + + // Wait for the pool to be ready. Cluster peers can be dummy CTAs for + // the tail M unit when an expert has an odd number of M blocks. + if (has_valid_m) { + if (block_phase == sched::BlockPhase::Linear1) { + const auto ptr = workspace.get_l1_arrival_count_ptr(pool_block_idx); + const auto expected = valid_m; + while (ptx::ld_acq(ptr) != expected); + } else { + const auto ptr = workspace.get_l2_arrival_mask_ptr(pool_block_idx); + // Each L1 N block sets one bit; total bits = L1_SHAPE_N / BLOCK_N. + constexpr uint32_t kNumL1BlockNs = L1_SHAPE_N / BLOCK_N; + const uint64_t expected = (kNumL1BlockNs >= 64) + ? 
~0ull : ((1ull << kNumL1BlockNs) - 1ull); + while (ptx::ld_acq_gpu(ptr) != expected); + } + } + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + empty_barriers[stage_idx]->wait(phase ^ 1); + + if (cute::elect_one_sync()) { + if (has_valid_m) { + const uint32_t m_idx = pool_block_idx * BLOCK_M; + const uint32_t k_idx = k_block_idx * BLOCK_K; + + // TMA load A + tma::copy( + tensor_map_a_ptr, full_barriers[stage_idx], smem_a[stage_idx], + k_idx, m_idx, 1); + + if constexpr (kSplitSFATMA) { + full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE); + } else { + // TMA load SFA + if (block_phase == sched::BlockPhase::Linear1) { + // L1 SFA per-128: load (BLOCK_M, 1) at K=k_block_idx + tma::copy( + tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], + m_idx, k_block_idx, 1); + full_barriers[stage_idx]->arrive_and_expect_tx( + SMEM_A_SIZE_PER_STAGE + BLOCK_M * sizeof(float)); + } else { + // L2 SFA per-64: descriptor box is (block_mn, 1) (see make_tma_sf_desc), + // so we must issue two single-group TMAs and place them at smem offsets + // 0 and kL2SFAHalfStride (equal to BLOCK_M only when BLOCK_M is a multiple of 32; padded otherwise so the second TMA destination stays 128B-aligned). NOTE(review): the math-side SFA reads must use the same stride -- confirm for BLOCK_M == 16. 
+ tma::copy( + tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], + m_idx, k_block_idx * 2, 1); + tma::copy( + tensor_map_sfa_ptr, full_barriers[stage_idx], + smem_sfa[stage_idx] + kL2SFAHalfStride, + m_idx, k_block_idx * 2 + 1, 1); + full_barriers[stage_idx]->arrive_and_expect_tx( + SMEM_A_SIZE_PER_STAGE + 2 * BLOCK_M * sizeof(float)); + } + } + } else { + full_barriers[stage_idx]->arrive(); + } + } + __syncwarp(); + } + }); + + } else if (warp_idx == kNumDispatchWarps + 1) { + cutlass::arch::warpgroup_reg_dealloc(); + + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + const auto tensor_map_b_ptr = + block_phase == sched::BlockPhase::Linear2 ? &tensor_map_l2_weights : &tensor_map_l1_weights; + + const uint32_t shape_n = block_phase == sched::BlockPhase::Linear2 ? L2_SHAPE_N : L1_SHAPE_N; + + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + empty_barriers[stage_idx]->wait(phase ^ 1); + + if (cute::elect_one_sync()) { + const uint32_t n_idx = local_expert_idx * shape_n + n_block_idx * BLOCK_N; + const uint32_t k_idx = k_block_idx * BLOCK_K; + + // TMA load B (weight SF is now loaded directly by math warps from global) + tma::copy( + tensor_map_b_ptr, full_barriers[stage_idx], smem_b[stage_idx], + k_idx, n_idx, kClusterSize); + + full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_B_SIZE_PER_STAGE); + } + __syncwarp(); + } + }); + + } else if (kSplitSFATMA && warp_idx == kNumDispatchWarps + 2) { + cutlass::arch::warpgroup_reg_dealloc(); + + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + (void)local_expert_idx; + (void)n_block_idx; + const auto tensor_map_sfa_ptr = block_phase == sched::BlockPhase::Linear2 + ? 
&tensor_map_l2_acts_sf : &tensor_map_l1_acts_sf; + + const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx; + const uint32_t valid_m = scheduler.template get_valid_m(); + const bool has_valid_m = valid_m > 0; + + if (has_valid_m) { + if (block_phase == sched::BlockPhase::Linear1) { + const auto ptr = workspace.get_l1_arrival_count_ptr(pool_block_idx); + const auto expected = valid_m; + while (ptx::ld_acq(ptr) != expected); + } else { + const auto ptr = workspace.get_l2_arrival_mask_ptr(pool_block_idx); + constexpr uint32_t kNumL1BlockNs = L1_SHAPE_N / BLOCK_N; + const uint64_t expected = (kNumL1BlockNs >= 64) + ? ~0ull : ((1ull << kNumL1BlockNs) - 1ull); + while (ptx::ld_acq_gpu(ptr) != expected); + } + } + + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + empty_barriers[stage_idx]->wait(phase ^ 1); + + if (cute::elect_one_sync()) { + if (has_valid_m) { + const uint32_t m_idx = pool_block_idx * BLOCK_M; + + if (block_phase == sched::BlockPhase::Linear1) { + tma::copy( + tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], + m_idx, k_block_idx, 1); + full_barriers[stage_idx]->arrive_and_expect_tx(BLOCK_M * sizeof(float)); + } else { + tma::copy( + tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], + m_idx, k_block_idx * 2, 1); + tma::copy( + tensor_map_sfa_ptr, full_barriers[stage_idx], + smem_sfa[stage_idx] + kL2SFAHalfStride, + m_idx, k_block_idx * 2 + 1, 1); + full_barriers[stage_idx]->arrive_and_expect_tx(2 * BLOCK_M * sizeof(float)); + } + } else { + full_barriers[stage_idx]->arrive(); + } + } + __syncwarp(); + } + }); + + } else if (warp_idx < kNumDispatchWarps + kNumMMANonEpilogueWarps) { + // Idle non-epilogue warps (kNumDispatchWarps+2, +3). They must still + // participate in the warpgroup-collective `setmaxnreg.dec.sync.aligned` + // so that the math warpgroup's `warpgroup_reg_alloc` can succeed. 
+ cutlass::arch::warpgroup_reg_dealloc(); + + } else if (warp_idx >= kNumDispatchWarps + kNumMMANonEpilogueWarps) { + // ===================================================================== + // ROLE 3: MATH WARPGROUPS (WGMMA + epilogue + combine) + // ===================================================================== + cutlass::arch::warpgroup_reg_alloc(); + + const uint32_t epilogue_warp_idx = warp_idx - (kNumDispatchWarps + kNumMMANonEpilogueWarps); + const uint32_t epilogue_wg_idx = epilogue_warp_idx / 4; + const uint32_t epilogue_thread_idx = epilogue_warp_idx * 32 + lane_idx; + const uint32_t warp_idx_in_wg = epilogue_warp_idx % 4; + + uint32_t async_l1_store_stage = 0; + bool async_l1_store_pending[2] = {false, false}; + uint32_t async_l1_store_pool[2] = {0, 0}; + uint32_t async_l1_store_n[2] = {0, 0}; + + const auto arrive_empty_barrier = [&](const uint32_t& s) { + if constexpr (kClusterSize == 1) { + if (lane_idx == 0) + empty_barriers[s]->arrive(); + } else { + if (lane_idx < kClusterSize) + empty_barriers[s]->arrive(lane_idx); + } + }; + + const auto notify_l1_ready = [&](const uint32_t& ready_pool_block_idx, + const uint32_t& ready_n_block_idx) { + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + ptx::red_or_rel_gpu( + workspace.get_l2_arrival_mask_ptr(ready_pool_block_idx), + 1ull << ready_n_block_idx); + } + __syncwarp(); + }; + + const auto drain_async_l1_store_stage = [&](const uint32_t& store_stage) { + if constexpr (kAsyncL1TMAStore) { + if (async_l1_store_pending[store_stage]) { + // Two SMEM L1 store buffers are used in FIFO order; waiting + // for <=1 outstanding store makes the older buffer reusable. 
+ ptx::tma_store_wait<1>(); + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + notify_l1_ready(async_l1_store_pool[store_stage], + async_l1_store_n[store_stage]); + async_l1_store_pending[store_stage] = false; + } + } + }; + + const auto drain_all_async_l1_stores = [&]() { + if constexpr (kAsyncL1TMAStore) { + if (async_l1_store_pending[0] or async_l1_store_pending[1]) { + ptx::tma_store_wait<0>(); + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + if (async_l1_store_pending[0]) { + notify_l1_ready(async_l1_store_pool[0], async_l1_store_n[0]); + async_l1_store_pending[0] = false; + } + if (async_l1_store_pending[1]) { + notify_l1_ready(async_l1_store_pool[1], async_l1_store_n[1]); + async_l1_store_pending[1] = false; + } + } + } + }; + + // WGMMA-output register layout helpers + const uint32_t row_idx = lane_idx / 4; + const uint32_t col_idx = lane_idx % 4; + const uint32_t r_0 = warp_idx_in_wg * 16 + row_idx; + const uint32_t r_1 = r_0 + 8; + + DG_STATIC_ASSERT(kSplitNWarpgroups || (BLOCK_M % kNumEpilogueWarpgroups == 0), "Invalid block M"); + if constexpr (kUseMMASync) { + DG_STATIC_ASSERT(WG_BLOCK_M == BLOCK_M, "mma.sync decode path uses one M tile per warpgroup"); + DG_STATIC_ASSERT(kNumEpilogueWarpgroups == 1, "mma.sync decode path currently uses one math warpgroup"); + } else if constexpr (kSplitNWarpgroups) { + DG_STATIC_ASSERT(WG_BLOCK_M == L1WGMMA::M and WG_BLOCK_N == L1WGMMA::N, + "Split-N WGs must each run one M64N128 WGMMA per K-block"); + } else if constexpr (kSerialNWarpgroups) { + DG_STATIC_ASSERT(WG_BLOCK_M == L1WGMMA::M and WG_BLOCK_N == L1WGMMA::N, + "Serial-N path runs two M64N128 WGMMAs per K-block"); + } else { + DG_STATIC_ASSERT(WG_BLOCK_M == L1WGMMA::M, "Each warpgroup must run exactly one WGMMA per K-block"); + } + + // Sync with dispatch + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + const unsigned long long math_loop_start = phase_profile_clock(); + + 
scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + const uint32_t valid_m = scheduler.template get_valid_m(); + const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx; + const uint32_t m_idx = pool_block_idx * BLOCK_M; + const uint32_t wg_n_idx = kSplitNWarpgroups ? epilogue_wg_idx * WG_BLOCK_N : 0; + const uint32_t wg_l1_out_n_idx = kSplitNWarpgroups ? epilogue_wg_idx * WG_L1_OUT_BLOCK_N : 0; + const uint32_t n_idx = n_block_idx * BLOCK_N + wg_n_idx; + const uint32_t row_block_offset = kSplitNWarpgroups ? 0 : epilogue_wg_idx * WG_BLOCK_M; + const uint32_t row_offset_r0 = row_block_offset + r_0; + const uint32_t row_offset_r1 = row_block_offset + r_1; + const bool valid_r0 = row_offset_r0 < valid_m; + const bool valid_r1 = row_offset_r1 < valid_m; + + + if constexpr (kAsyncL1TMAStore) { + if (block_phase != sched::BlockPhase::Linear1) + drain_all_async_l1_stores(); + } + + if constexpr (kUseMMASync) { + using MMASyncTiled = cute::TiledMMA< + cute::MMA_Atom, + cute::Layout, cute::_4, cute::_1>>>; + auto mma_sync_tiled = MMASyncTiled{}; + auto thr_mma = mma_sync_tiled.get_thread_slice(epilogue_thread_idx); + auto sC = cute::make_tensor( + cute::make_smem_ptr(smem_accum_f32), + cute::make_layout(cute::make_shape(cute::Int{}, cute::Int<128>{}), cute::GenRowMajor{})); + auto tCsC = thr_mma.partition_C(sC); + auto tCrC = thr_mma.make_fragment_C(tCsC); + auto tCrFinal = thr_mma.make_fragment_C(tCsC); + cute::clear(tCrFinal); + auto cC = cute::make_identity_tensor(cute::shape(sC)); + auto tCcC = thr_mma.partition_C(cC); + + auto add_scaled_fragment = [&](const float& scale_b, const uint32_t& sfa_offset) { + CUTE_UNROLL + for (int i = 0; i < cute::size(tCrC); ++i) { + const auto coord = tCcC(i); + if (cute::elem_less(coord, cute::shape(sC))) { + const uint32_t row = 
static_cast(cute::get<0>(coord)); + const uint32_t col = static_cast(cute::get<1>(coord)); + const float scale_a = ptx::ld_shared(smem_sfa[stage_idx] + sfa_offset + row); + tCrFinal(i) += scale_a * scale_b * tCrC(i); + } + } + }; + + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + full_barriers[stage_idx]->wait(phase); + + constexpr uint32_t kL1SFKBlocks = kHidden / 128; + constexpr uint32_t kL2SFKBlocks = kIntermediateHidden / 128; + constexpr uint32_t kL1SFGateBlks = kIntermediateHidden / 128; + constexpr uint32_t kL1SFPerExpert = (kIntermediateHidden * 2 / 128) * kL1SFKBlocks; + constexpr uint32_t kL2SFPerExpert = (kHidden / 128) * kL2SFKBlocks; + + if (block_phase == sched::BlockPhase::Linear1) { + const uint32_t gate_n = (n_block_idx * BLOCK_N + wg_n_idx) / 256u; + const uint32_t up_n = kL1SFGateBlks + gate_n; + const float* base = l1_weights_sf + local_expert_idx * kL1SFPerExpert + k_block_idx; + const float gate_sf = __ldg(base + gate_n * kL1SFKBlocks); + const float up_sf = __ldg(base + up_n * kL1SFKBlocks); + + auto sA = cute::make_tensor( + cute::make_smem_ptr(smem_a[stage_idx]), + cute::make_layout(cute::make_shape(cute::Int{}, cute::Int<128>{}), cute::GenRowMajor{})); + auto sB = cute::make_tensor( + cute::make_smem_ptr(smem_b[stage_idx]), + cute::make_layout(cute::make_shape(cute::Int<128>{}, cute::Int<128>{}), cute::GenRowMajor{})); + cute::clear(tCrC); + cute::cooperative_gemm(epilogue_thread_idx, mma_sync_tiled, sA, sB, tCrC); + + CUTE_UNROLL + for (int i = 0; i < cute::size(tCrC); ++i) { + const auto coord = tCcC(i); + if (cute::elem_less(coord, cute::shape(sC))) { + const uint32_t row = static_cast(cute::get<0>(coord)); + const uint32_t col = static_cast(cute::get<1>(coord)); + const float scale_a = ptx::ld_shared(smem_sfa[stage_idx] + row); + const float scale_b = ((col / 8u) & 1u) ? 
up_sf : gate_sf; + tCrFinal(i) += scale_a * scale_b * tCrC(i); + } + } + } else { + const float l2_sf = __ldg(l2_weights_sf + local_expert_idx * kL2SFPerExpert + + ((n_block_idx * BLOCK_N + wg_n_idx) / 128u) * kL2SFKBlocks + k_block_idx); + CUTE_UNROLL + for (uint32_t half = 0; half < 2; ++half) { + const uint32_t k_off = half * (BLOCK_K / 2); + auto sA = cute::make_tensor( + cute::make_smem_ptr(smem_a[stage_idx] + k_off), + cute::make_layout(cute::make_shape(cute::Int{}, cute::Int<64>{}), + cute::make_stride(cute::Int<128>{}, cute::Int<1>{}))); + auto sB = cute::make_tensor( + cute::make_smem_ptr(smem_b[stage_idx] + wg_n_idx * BLOCK_K + k_off), + cute::make_layout(cute::make_shape(cute::Int<128>{}, cute::Int<64>{}), + cute::make_stride(cute::Int<128>{}, cute::Int<1>{}))); + cute::clear(tCrC); + cute::cooperative_gemm(epilogue_thread_idx, mma_sync_tiled, sA, sB, tCrC); + add_scaled_fragment(l2_sf, half * kL2SFAHalfStride); + } + } + + arrive_empty_barrier(stage_idx); + __syncwarp(); + } + + CUTE_UNROLL + for (int i = 0; i < cute::size(tCrFinal); ++i) { + const auto coord = tCcC(i); + if (cute::elem_less(coord, cute::shape(sC))) { + const uint32_t row = static_cast(cute::get<0>(coord)); + const uint32_t col = static_cast(cute::get<1>(coord)); + smem_accum_f32[row * BLOCK_N + col] = tCrFinal(i); + } + } + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + + constexpr uint32_t kMMASyncRowsPerPass = kNumEpilogueThreads / 8; + DG_STATIC_ASSERT(kMMASyncRowsPerPass == 16, "mma.sync epilogue maps 8 lanes per row"); + + if (block_phase == sched::BlockPhase::Linear1) { + #pragma unroll + for (uint32_t row_base = 0; row_base < BLOCK_M; row_base += kMMASyncRowsPerPass) { + const uint32_t row = row_base + epilogue_thread_idx / 8; + const uint32_t lane_in_row = epilogue_thread_idx % 8; + const bool valid_row = row < valid_m; + float swiglu_values[8]; + float amax = 0.0f; + + #pragma unroll + for (uint32_t p = 0; p < 8; ++p) { + const uint32_t gate_col = p * 16 + 
lane_in_row; + const uint32_t up_col = gate_col + 8; + float gate = valid_row ? smem_accum_f32[row * BLOCK_N + gate_col] : 0.0f; + float up = valid_row ? smem_accum_f32[row * BLOCK_N + up_col] : 0.0f; + if constexpr (kActivationClamp != cute::numeric_limits::infinity()) { + gate = cute::min(gate, kActivationClamp); + up = cute::min(cute::max(up, -kActivationClamp), kActivationClamp); + } + const float e = kFastMath ? __expf(-gate) : expf(-gate); + const float sig = kFastMath ? math::fast_rcp(1.0f + e) : 1.0f / (1.0f + e); + const float weight = valid_row ? *l1_topk_weights_buffer + .get_data_buffer(m_idx + row) + .get_base_ptr() : 0.0f; + const float value = gate * sig * up * weight; + swiglu_values[p] = value; + amax = cute::max(amax, cute::abs(value)); + } + + amax = math::warp_reduce<8, false>(amax, math::ReduceMax()); + float2 amax_pair = {amax, amax}; + float2 sf_pair, sf_inv_pair; + math::get_e4m3_sf_and_sf_inv(amax_pair, sf_pair, sf_inv_pair); + const float sf = sf_pair.x; + const float sf_inv = sf_inv_pair.x; + + if (valid_row) { + #pragma unroll + for (uint32_t p = 0; p < 8; ++p) { + const uint32_t out_col = p * 8 + lane_in_row; + smem_cd_l1[row * L1_OUT_BLOCK_N + out_col] = + cutlass::float_e4m3_t(swiglu_values[p] * sf_inv); + } + if (lane_in_row == 0) { + auto sf_base_ptr = l2_sf_buffer.get_base_ptr(); + const uint32_t token_idx = pool_block_idx * BLOCK_M + row; + const uint32_t k_sf_idx = n_block_idx; + sf_base_ptr[k_sf_idx * kNumPaddedSFPoolTokens + token_idx] = sf; + } + } + } + + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + const uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N; + cute::tma_store_fence(); + cute::SM90_TMA_STORE_2D::copy( + &tensor_map_l1_output, + smem_cd_l1, + out_n_idx, + m_idx); + cute::tma_store_arrive(); + } + __syncwarp(); + ptx::tma_store_wait<0>(); + + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + if (epilogue_warp_idx == 0 
and cute::elect_one_sync()) { + ptx::red_or_rel_gpu( + workspace.get_l2_arrival_mask_ptr(pool_block_idx), + 1ull << n_block_idx); + } + __syncwarp(); + } else { + #pragma unroll + for (uint32_t row_base = 0; row_base < BLOCK_M; row_base += kMMASyncRowsPerPass) { + const uint32_t row = row_base + epilogue_thread_idx / 8; + const uint32_t lane_in_row = epilogue_thread_idx % 8; + if (row < valid_m) { + const auto src_metadata = *workspace.get_token_src_metadata_ptr(m_idx + row); + const uint32_t dst_rank_idx = src_metadata.rank_idx; + const uint32_t dst_token_idx = src_metadata.token_idx; + const uint32_t dst_topk_idx = src_metadata.topk_idx; + const auto dst_token = combine_token_buffer.get_rank_buffer(dst_topk_idx) + .get_data_buffer(dst_token_idx); + #pragma unroll + for (uint32_t v = 0; v < 2; ++v) { + const uint32_t col = lane_in_row * 16 + v * 8; + uint4 packed; + float f0 = smem_accum_f32[row * BLOCK_N + col + 0]; + float f1 = smem_accum_f32[row * BLOCK_N + col + 1]; + float f2 = smem_accum_f32[row * BLOCK_N + col + 2]; + float f3 = smem_accum_f32[row * BLOCK_N + col + 3]; + float f4 = smem_accum_f32[row * BLOCK_N + col + 4]; + float f5 = smem_accum_f32[row * BLOCK_N + col + 5]; + float f6 = smem_accum_f32[row * BLOCK_N + col + 6]; + float f7 = smem_accum_f32[row * BLOCK_N + col + 7]; + packed.x = math::cast_into_bf16_and_pack(f0, f1); + packed.y = math::cast_into_bf16_and_pack(f2, f3); + packed.z = math::cast_into_bf16_and_pack(f4, f5); + packed.w = math::cast_into_bf16_and_pack(f6, f7); + auto dst_ptr = math::advance_ptr( + dst_token.get_base_ptr(), + n_idx * sizeof(nv_bfloat16) + col * sizeof(nv_bfloat16)); + *sym_buffer.map(dst_ptr, dst_rank_idx) = packed; + } + } + } + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + } + return; + } + + if constexpr (kSerialNWarpgroups) { + using WGMMA = L1WGMMA; + constexpr uint32_t kAccumPerThread = WGMMA::kNumAccum; + constexpr uint32_t kNumSerialN = 2; + float final_accum[kNumSerialN][kAccumPerThread] 
= {}; + float accum[kAccumPerThread]; + + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + full_barriers[stage_idx]->wait(phase); + + float scale_a_0_lo, scale_a_1_lo; + float scale_a_0_hi, scale_a_1_hi; + if (block_phase == sched::BlockPhase::Linear1) { + scale_a_0_lo = ptx::ld_shared(smem_sfa[stage_idx] + row_offset_r0); + scale_a_1_lo = ptx::ld_shared(smem_sfa[stage_idx] + row_offset_r1); + } else { + scale_a_0_lo = ptx::ld_shared(smem_sfa[stage_idx] + row_offset_r0); + scale_a_1_lo = ptx::ld_shared(smem_sfa[stage_idx] + row_offset_r1); + scale_a_0_hi = ptx::ld_shared(smem_sfa[stage_idx] + kL2SFAHalfStride + row_offset_r0); + scale_a_1_hi = ptx::ld_shared(smem_sfa[stage_idx] + kL2SFAHalfStride + row_offset_r1); + } + + constexpr uint32_t kL1SFKBlocks = kHidden / 128; + constexpr uint32_t kL2SFKBlocks = kIntermediateHidden / 128; + constexpr uint32_t kL1SFGateBlks = kIntermediateHidden / 128; + constexpr uint32_t kL1SFPerExpert = (kIntermediateHidden * 2 / 128) * kL1SFKBlocks; + constexpr uint32_t kL2SFPerExpert = (kHidden / 128) * kL2SFKBlocks; + + #pragma unroll + for (uint32_t serial_n_idx = 0; serial_n_idx < kNumSerialN; ++serial_n_idx) { + const uint32_t serial_wg_n_idx = serial_n_idx * WG_BLOCK_N; + float gate_sf = 0.0f, up_sf = 0.0f, l2_sf = 0.0f; + if (block_phase == sched::BlockPhase::Linear1) { + const uint32_t gate_n = (n_block_idx * BLOCK_N + serial_wg_n_idx) / 256u; + const uint32_t up_n = kL1SFGateBlks + gate_n; + const float* base = l1_weights_sf + local_expert_idx * kL1SFPerExpert + k_block_idx; + gate_sf = __ldg(base + gate_n * kL1SFKBlocks); + up_sf = __ldg(base + up_n * kL1SFKBlocks); + + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = mma::sm90::make_smem_desc( + smem_a[stage_idx] + row_block_offset * BLOCK_K + k * WGMMA::K, 1); + 
auto desc_b = mma::sm90::make_smem_desc( + smem_b[stage_idx] + serial_wg_n_idx * BLOCK_K + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + ptx::warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_wait<0>(); + + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread / 4; ++ i) { + const float sb = (i & 1u) ? up_sf : gate_sf; + final_accum[serial_n_idx][i*4+0] += scale_a_0_lo * sb * accum[i*4+0]; + final_accum[serial_n_idx][i*4+1] += scale_a_0_lo * sb * accum[i*4+1]; + final_accum[serial_n_idx][i*4+2] += scale_a_1_lo * sb * accum[i*4+2]; + final_accum[serial_n_idx][i*4+3] += scale_a_1_lo * sb * accum[i*4+3]; + } + } else { + l2_sf = __ldg(l2_weights_sf + local_expert_idx * kL2SFPerExpert + + ((n_block_idx * BLOCK_N + serial_wg_n_idx) / 128u) * kL2SFKBlocks + k_block_idx); + + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < (BLOCK_K / 2) / WGMMA::K; ++ k) { + auto desc_a = mma::sm90::make_smem_desc( + smem_a[stage_idx] + row_block_offset * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = mma::sm90::make_smem_desc( + smem_b[stage_idx] + serial_wg_n_idx * BLOCK_K + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + ptx::warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_wait<0>(); + + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread / 4; ++ i) { + final_accum[serial_n_idx][i*4+0] += scale_a_0_lo * l2_sf * accum[i*4+0]; + final_accum[serial_n_idx][i*4+1] += scale_a_0_lo * l2_sf * accum[i*4+1]; + final_accum[serial_n_idx][i*4+2] += scale_a_1_lo * l2_sf * accum[i*4+2]; + final_accum[serial_n_idx][i*4+3] += scale_a_1_lo * l2_sf * accum[i*4+3]; + } + + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread; ++ i) 
ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < (BLOCK_K / 2) / WGMMA::K; ++ k) { + const uint32_t k_off = (BLOCK_K / 2) + k * WGMMA::K; + auto desc_a = mma::sm90::make_smem_desc( + smem_a[stage_idx] + row_block_offset * BLOCK_K + k_off, 1); + auto desc_b = mma::sm90::make_smem_desc( + smem_b[stage_idx] + serial_wg_n_idx * BLOCK_K + k_off, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + ptx::warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_wait<0>(); + + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread / 4; ++ i) { + final_accum[serial_n_idx][i*4+0] += scale_a_0_hi * l2_sf * accum[i*4+0]; + final_accum[serial_n_idx][i*4+1] += scale_a_0_hi * l2_sf * accum[i*4+1]; + final_accum[serial_n_idx][i*4+2] += scale_a_1_hi * l2_sf * accum[i*4+2]; + final_accum[serial_n_idx][i*4+3] += scale_a_1_hi * l2_sf * accum[i*4+3]; + } + } + } + + arrive_empty_barrier(stage_idx); + __syncwarp(); + } + + if (row_block_offset >= valid_m) { + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + return; + } + + if (block_phase == sched::BlockPhase::Linear1) { + constexpr uint32_t kNumPairs = kAccumPerThread / 8; + #pragma unroll + for (uint32_t serial_n_idx = 0; serial_n_idx < kNumSerialN; ++serial_n_idx) { + const uint32_t serial_l1_out_n_idx = serial_n_idx * WG_L1_OUT_BLOCK_N; + float swiglu_r0[kNumPairs][2]; + float swiglu_r1[kNumPairs][2]; + float amax_r0 = 0.0f, amax_r1 = 0.0f; + + #pragma unroll + for (uint32_t p = 0; p < kNumPairs; ++ p) { + const uint32_t gate = 2 * p, up = 2 * p + 1; + auto clamp_gate = [](float& x) { + if constexpr (kActivationClamp != cute::numeric_limits::infinity()) + x = cute::min(x, kActivationClamp); + }; + auto clamp_up = [](float& x) { + if constexpr (kActivationClamp != cute::numeric_limits::infinity()) + x = cute::min(cute::max(x, -kActivationClamp), 
kActivationClamp); + }; + float g_r0_c0 = final_accum[serial_n_idx][gate*4 + 0]; clamp_gate(g_r0_c0); + float g_r0_c1 = final_accum[serial_n_idx][gate*4 + 1]; clamp_gate(g_r0_c1); + float g_r1_c0 = final_accum[serial_n_idx][gate*4 + 2]; clamp_gate(g_r1_c0); + float g_r1_c1 = final_accum[serial_n_idx][gate*4 + 3]; clamp_gate(g_r1_c1); + float u_r0_c0 = final_accum[serial_n_idx][up*4 + 0]; clamp_up(u_r0_c0); + float u_r0_c1 = final_accum[serial_n_idx][up*4 + 1]; clamp_up(u_r0_c1); + float u_r1_c0 = final_accum[serial_n_idx][up*4 + 2]; clamp_up(u_r1_c0); + float u_r1_c1 = final_accum[serial_n_idx][up*4 + 3]; clamp_up(u_r1_c1); + auto silu = [](float x) -> float { + const float e = kFastMath ? __expf(-x) : expf(-x); + const float sig = kFastMath ? math::fast_rcp(1.0f + e) : 1.0f / (1.0f + e); + return x * sig; + }; + if (valid_r0) { + swiglu_r0[p][0] = silu(g_r0_c0) * u_r0_c0; + swiglu_r0[p][1] = silu(g_r0_c1) * u_r0_c1; + amax_r0 = cute::max(amax_r0, cute::max(cute::abs(swiglu_r0[p][0]), cute::abs(swiglu_r0[p][1]))); + } else { + swiglu_r0[p][0] = 0.0f; + swiglu_r0[p][1] = 0.0f; + } + if (valid_r1) { + swiglu_r1[p][0] = silu(g_r1_c0) * u_r1_c0; + swiglu_r1[p][1] = silu(g_r1_c1) * u_r1_c1; + amax_r1 = cute::max(amax_r1, cute::max(cute::abs(swiglu_r1[p][0]), cute::abs(swiglu_r1[p][1]))); + } else { + swiglu_r1[p][0] = 0.0f; + swiglu_r1[p][1] = 0.0f; + } + } + + float weight_r0 = valid_r0 ? *l1_topk_weights_buffer + .get_data_buffer(m_idx + row_offset_r0) + .get_base_ptr() : 0.0f; + float weight_r1 = valid_r1 ? 
*l1_topk_weights_buffer + .get_data_buffer(m_idx + row_offset_r1) + .get_base_ptr() : 0.0f; + #pragma unroll + for (uint32_t p = 0; p < kNumPairs; ++ p) { + swiglu_r0[p][0] *= weight_r0; + swiglu_r0[p][1] *= weight_r0; + swiglu_r1[p][0] *= weight_r1; + swiglu_r1[p][1] *= weight_r1; + } + amax_r0 *= cute::abs(weight_r0); + amax_r1 *= cute::abs(weight_r1); + amax_r0 = math::warp_reduce<4, false>(amax_r0, math::ReduceMax()); + amax_r1 = math::warp_reduce<4, false>(amax_r1, math::ReduceMax()); + + float sf_r0, sf_inv_r0, sf_r1, sf_inv_r1; + { + float2 amax_pair = {amax_r0, amax_r1}; + float2 sf_pair, sf_inv_pair; + math::get_e4m3_sf_and_sf_inv(amax_pair, sf_pair, sf_inv_pair); + sf_r0 = sf_pair.x; sf_inv_r0 = sf_inv_pair.x; + sf_r1 = sf_pair.y; sf_inv_r1 = sf_inv_pair.y; + } + + #pragma unroll + for (uint32_t p = 0; p < kNumPairs; ++ p) { + const float v00 = swiglu_r0[p][0] * sf_inv_r0; + const float v01 = swiglu_r0[p][1] * sf_inv_r0; + const float v10 = swiglu_r1[p][0] * sf_inv_r1; + const float v11 = swiglu_r1[p][1] * sf_inv_r1; + const __nv_fp8x2_e4m3 r0_pair(make_float2(v00, v01)); + const __nv_fp8x2_e4m3 r1_pair(make_float2(v10, v11)); + const uint32_t col = p * 8 + col_idx * 2; + auto* p0 = reinterpret_cast( + smem_cd_l1 + r_0 * L1_OUT_BLOCK_N + serial_l1_out_n_idx + col); + auto* p1 = reinterpret_cast( + smem_cd_l1 + r_1 * L1_OUT_BLOCK_N + serial_l1_out_n_idx + col); + if (valid_r0) + *p0 = r0_pair.__x; + if (valid_r1) + *p1 = r1_pair.__x; + } + + if (col_idx == 0) { + auto sf_base_ptr = l2_sf_buffer.get_base_ptr(); + const uint32_t token_r0 = pool_block_idx * BLOCK_M + row_offset_r0; + const uint32_t token_r1 = pool_block_idx * BLOCK_M + row_offset_r1; + const uint32_t k_sf_idx = (n_block_idx * L1_OUT_BLOCK_N + serial_l1_out_n_idx) / 64u; + if (valid_r0) + sf_base_ptr[k_sf_idx * kNumPaddedSFPoolTokens + token_r0] = sf_r0; + if (valid_r1) + sf_base_ptr[k_sf_idx * kNumPaddedSFPoolTokens + token_r1] = sf_r1; + } + } + + ptx::sync_aligned(128, 
kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + if (warp_idx_in_wg == 0 and cute::elect_one_sync()) { + const uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N; + cute::tma_store_fence(); + cute::SM90_TMA_STORE_2D::copy( + &tensor_map_l1_output, + smem_cd_l1, + out_n_idx, + m_idx + row_block_offset); + cute::tma_store_arrive(); + } + __syncwarp(); + ptx::tma_store_wait<0>(); + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + notify_l1_ready(pool_block_idx, n_block_idx); + } else { + constexpr uint32_t kNumRowsPerWarp = WG_BLOCK_M / 8; + #pragma unroll + for (uint32_t serial_n_idx = 0; serial_n_idx < kNumSerialN; ++serial_n_idx) { + const uint32_t serial_n_idx_base = n_block_idx * BLOCK_N + serial_n_idx * WG_BLOCK_N; + + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread / 8; ++ i) { + const uint32_t chunk_lo = 2 * i, chunk_hi = 2 * i + 1; + auto write_pair = [&](uint32_t row, uint32_t col, uint32_t packed) { + auto smem_ptr = smem_cd_l2 + row * WG_BLOCK_N + col; + *reinterpret_cast(smem_ptr) = packed; + }; + if (valid_r0) { + const uint32_t r0_lo = math::cast_into_bf16_and_pack( + final_accum[serial_n_idx][chunk_lo*4 + 0], final_accum[serial_n_idx][chunk_lo*4 + 1]); + const uint32_t r0_hi = math::cast_into_bf16_and_pack( + final_accum[serial_n_idx][chunk_hi*4 + 0], final_accum[serial_n_idx][chunk_hi*4 + 1]); + write_pair(r_0, chunk_lo * 8 + col_idx * 2, r0_lo); + write_pair(r_0, chunk_hi * 8 + col_idx * 2, r0_hi); + } + if (valid_r1) { + const uint32_t r1_lo = math::cast_into_bf16_and_pack( + final_accum[serial_n_idx][chunk_lo*4 + 2], final_accum[serial_n_idx][chunk_lo*4 + 3]); + const uint32_t r1_hi = math::cast_into_bf16_and_pack( + final_accum[serial_n_idx][chunk_hi*4 + 2], final_accum[serial_n_idx][chunk_hi*4 + 3]); + write_pair(r_1, chunk_lo * 8 + col_idx * 2, r1_lo); + write_pair(r_1, chunk_hi * 8 + col_idx * 2, r1_hi); + } + } + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + const uint32_t 
row_in_warp_block = lane_idx / 16;
+                        const uint32_t lane_in_row = lane_idx % 16;
+                        constexpr uint32_t cols_per_lane = WG_BLOCK_N / 16;
+                        #pragma unroll
+                        for (uint32_t j = 0; j < kNumRowsPerWarp; ++ j) {
+                            const uint32_t row_in_wg = warp_idx_in_wg * 16 + j * 2 + row_in_warp_block;
+                            const uint32_t m_idx_in_block = row_block_offset + row_in_wg;
+                            if (m_idx_in_block >= valid_m) break;
+
+                            const auto src_metadata = *workspace.get_token_src_metadata_ptr(m_idx + m_idx_in_block);
+                            const uint32_t dst_rank_idx = src_metadata.rank_idx;
+                            const uint32_t dst_token_idx = src_metadata.token_idx;
+                            const uint32_t dst_topk_idx = src_metadata.topk_idx;
+                            auto smem_ptr = smem_cd_l2 + row_in_wg * WG_BLOCK_N + lane_in_row * cols_per_lane;
+                            const auto dst_token = combine_token_buffer.get_rank_buffer(dst_topk_idx)
+                                                                       .get_data_buffer(dst_token_idx);
+                            const auto packed = *reinterpret_cast(smem_ptr);
+                            auto dst_ptr = math::advance_ptr(
+                                dst_token.get_base_ptr(),
+                                serial_n_idx_base * sizeof(nv_bfloat16) + lane_in_row * sizeof(uint4));
+                            *sym_buffer.map(dst_ptr, dst_rank_idx) = packed;
+                        }
+                        ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx);
+                    }
+                }
+                return;
+            }
+
+            // ---------------- GEMM ----------------
+            // `accum` receives the raw WGMMA result of one K-block (or half
+            // K-block); `final_accum` is the zero-initialised running sum of
+            // `scale_a * scale_b * accum` across all K-blocks.
+            using WGMMA = L1WGMMA;
+            constexpr uint32_t kAccumPerThread = WGMMA::kNumAccum; // 64 for M=64,N=128
+            float final_accum[kAccumPerThread] = {};
+            float accum[kAccumPerThread];
+
+            const unsigned long long block_gemm_start = phase_profile_clock();
+            // Default K-loop, used for Linear2 always and for Linear1 unless the
+            // dual-K variant is selected. Per K-block it:
+            //   1. waits on the stage's full barrier,
+            //   2. reads the activation scale factors (SFA) for rows r0/r1 from
+            //      shared memory and the weight scale factor(s) from global memory,
+            //   3. issues the WGMMA group(s) and waits for them,
+            //   4. releases the stage via `arrive_empty_barrier`,
+            //   5. folds the scaled partial result into `final_accum`.
+            // Linear1 uses one WGMMA group over the whole BLOCK_K; Linear2 uses
+            // two half-K groups because its SFA is stored per half K-block.
+            // `advance_pipeline(k_block_idx)` is the loop increment; it presumably
+            // bumps `k_block_idx` and moves `stage_idx` / `phase` - confirm.
+            const auto run_default_gemm_loop = [&]() {
+for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) {
+                full_barriers[stage_idx]->wait(phase);
+
+                // Read SF (must precede warpgroup_arrive)
+                float scale_a_0_lo, scale_a_1_lo;
+                float scale_a_0_hi, scale_a_1_hi; // Only used in L2 (per-64 K)
+                if (block_phase == sched::BlockPhase::Linear1) {
+                    scale_a_0_lo = ptx::ld_shared(smem_sfa[stage_idx] + row_offset_r0);
+                    scale_a_1_lo = ptx::ld_shared(smem_sfa[stage_idx] + row_offset_r1);
+                } else {
+                    // L2: SFA layout is (K=2, M=BLOCK_M) MN-major; the first-half SF
+                    // sits at offset 0 and the second-half SF at `kL2SFAHalfStride`
+                    // (presumably equal to BLOCK_M - confirm against its definition).
+                    scale_a_0_lo = ptx::ld_shared(smem_sfa[stage_idx] + row_offset_r0);
+                    scale_a_1_lo = ptx::ld_shared(smem_sfa[stage_idx] + row_offset_r1);
+                    scale_a_0_hi = ptx::ld_shared(smem_sfa[stage_idx] + kL2SFAHalfStride + row_offset_r0);
+                    scale_a_1_hi = ptx::ld_shared(smem_sfa[stage_idx] + kL2SFAHalfStride + row_offset_r1);
+                }
+
+                // ----- Block (128, 128) weight SF (loaded directly from global) -----
+                // L1 weight SF shape: (E, 2*IH/128, H/128) MN-major. The N axis is
+                // [gate(IH/128), up(IH/128)]; with the gate/up gran-8 interleave on
+                // the FP8 weight, each BLOCK_N=128 tile covers 64 rows of gate plus
+                // 64 rows of up taken from the same original 128-row block, so:
+                //     gate_sf_n = n_block_idx / 2
+                //     up_sf_n   = (IH/128) + n_block_idx / 2
+                // The code below uses the general form
+                //     gate_n = (n_block_idx * BLOCK_N + wg_n_idx) / 256
+                // which reduces to `n_block_idx / 2` when BLOCK_N == 128 and
+                // wg_n_idx == 0.
+                //
+                // L2 weight SF shape: (E, H/128, IH/128) MN-major. One scalar per
+                // (128, BLOCK_K) tile of output columns; when WG_BLOCK_N > 128 a
+                // second scalar (`l2_sf_hi`) is loaded for the next 128 columns.
+                //
+                // NOTE: we tried hoisting these LDGs above the barrier wait and/or
+                // having only lane 0 load + shfl-broadcast. Both regressed on H20
+                // by 7-11% across all batch sizes, presumably because (a) Hopper's
+                // L1 read-only cache already coalesces same-address LDGs from all
+                // 128 WG threads and (b) hoisting contended with the dispatch
+                // warps' NVLink LDGs on the MIO unit. Keep the simple parallel
+                // post-wait load.
+                constexpr uint32_t kL1SFKBlocks = kHidden / 128;
+                constexpr uint32_t kL2SFKBlocks = kIntermediateHidden / 128;
+                constexpr uint32_t kL1SFGateBlks = kIntermediateHidden / 128;
+                constexpr uint32_t kL1SFPerExpert = (kIntermediateHidden * 2 / 128) * kL1SFKBlocks;
+                constexpr uint32_t kL2SFPerExpert = (kHidden / 128) * kL2SFKBlocks;
+                float gate_sf = 0.0f, up_sf = 0.0f, l2_sf_lo = 0.0f, l2_sf_hi = 0.0f;
+                if (block_phase == sched::BlockPhase::Linear1) {
+                    const uint32_t gate_n = (n_block_idx * BLOCK_N + wg_n_idx) / 256u;
+                    const uint32_t up_n = kL1SFGateBlks + gate_n;
+                    const float* base = l1_weights_sf + local_expert_idx * kL1SFPerExpert + k_block_idx;
+                    gate_sf = __ldg(base + gate_n * kL1SFKBlocks);
+                    up_sf = __ldg(base + up_n * kL1SFKBlocks);
+                } else {
+                    const float* base = l2_weights_sf + local_expert_idx * kL2SFPerExpert + k_block_idx;
+                    const uint32_t sf_n = (n_block_idx * BLOCK_N + wg_n_idx) / 128u;
+                    l2_sf_lo = __ldg(base + sf_n * kL2SFKBlocks);
+                    if constexpr (WG_BLOCK_N > 128)
+                        l2_sf_hi = __ldg(base + (sf_n + 1u) * kL2SFKBlocks);
+                    else
+                        l2_sf_hi = l2_sf_lo;
+                }
+
+                if (block_phase == sched::BlockPhase::Linear1) {
+                    // Single per-128 K-block WGMMA group
+                    // The last `wgmma` argument is `k`, i.e. 0 on the first K-slice;
+                    // presumably that makes the first slice overwrite `accum` and
+                    // later slices accumulate - confirm against `WGMMA::wgmma`.
+                    #pragma unroll
+                    for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]);
+                    ptx::warpgroup_arrive();
+                    #pragma unroll
+                    for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
+                        auto desc_a = mma::sm90::make_smem_desc(
+                            smem_a[stage_idx] + row_block_offset * BLOCK_K + k * WGMMA::K, 1);
+                        auto desc_b = mma::sm90::make_smem_desc(
+                            smem_b[stage_idx] + wg_n_idx * BLOCK_K + k * WGMMA::K, 1);
+                        WGMMA::wgmma(desc_a, desc_b, accum, k);
+                    }
+                    ptx::warpgroup_commit_batch();
+                    #pragma unroll
+                    for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]);
+                    ptx::warpgroup_wait<0>();
+
+                    // The stage is released only after `warpgroup_wait<0>` returns.
+                    arrive_empty_barrier(stage_idx);
+
+                    // L1: gate/up alternate at gran=8 along N; each `i` block of 8
+                    // cols belongs entirely to one of {gate, up}, so .x and .y
+                    // share the same scalar.
+                    // Within a chunk, elements 0/1 use the r0 scale and elements
+                    // 2/3 use the r1 scale.
+                    #pragma unroll
+                    for (uint32_t i = 0; i < kAccumPerThread / 4; ++ i) {
+                        const float sb = (i & 1u) ? up_sf : gate_sf;
+                        final_accum[i*4+0] += scale_a_0_lo * sb * accum[i*4+0];
+                        final_accum[i*4+1] += scale_a_0_lo * sb * accum[i*4+1];
+                        final_accum[i*4+2] += scale_a_1_lo * sb * accum[i*4+2];
+                        final_accum[i*4+3] += scale_a_1_lo * sb * accum[i*4+3];
+                    }
+                } else {
+                    if constexpr (kL2DualAccum) {
+                        // Dual-accumulator L2 variant: both K halves are issued in
+                        // one WGMMA batch, the low half into `accum` and the high
+                        // half into `accum_hi`, so only one commit/wait is needed.
+                        // NOTE(review): exactly two WGMMAs are issued per half, so
+                        // this path appears to assume (BLOCK_K / 2) / WGMMA::K == 2 -
+                        // confirm.
+                        float accum_hi[kAccumPerThread];
+
+                        const auto desc_a_lo0 = mma::sm90::make_smem_desc(
+                            smem_a[stage_idx] + row_block_offset * BLOCK_K, 1);
+                        const auto desc_b_lo0 = mma::sm90::make_smem_desc(
+                            smem_b[stage_idx] + wg_n_idx * BLOCK_K, 1);
+                        const auto desc_a_lo1 = mma::sm90::make_smem_desc(
+                            smem_a[stage_idx] + row_block_offset * BLOCK_K + WGMMA::K, 1);
+                        const auto desc_b_lo1 = mma::sm90::make_smem_desc(
+                            smem_b[stage_idx] + wg_n_idx * BLOCK_K + WGMMA::K, 1);
+                        const auto desc_a_hi0 = mma::sm90::make_smem_desc(
+                            smem_a[stage_idx] + row_block_offset * BLOCK_K + BLOCK_K / 2, 1);
+                        const auto desc_b_hi0 = mma::sm90::make_smem_desc(
+                            smem_b[stage_idx] + wg_n_idx * BLOCK_K + BLOCK_K / 2, 1);
+                        const auto desc_a_hi1 = mma::sm90::make_smem_desc(
+                            smem_a[stage_idx] + row_block_offset * BLOCK_K + BLOCK_K / 2 + WGMMA::K, 1);
+                        const auto desc_b_hi1 = mma::sm90::make_smem_desc(
+                            smem_b[stage_idx] + wg_n_idx * BLOCK_K + BLOCK_K / 2 + WGMMA::K, 1);
+
+                        #pragma unroll
+                        for (uint32_t i = 0; i < kAccumPerThread; ++ i) {
+                            ptx::warpgroup_fence_operand(accum[i]);
+                            ptx::warpgroup_fence_operand(accum_hi[i]);
+                        }
+                        ptx::warpgroup_arrive();
+                        // `false` / `true`: first WGMMA of each half starts the
+                        // accumulator, the second adds to it (presumably - confirm).
+                        WGMMA::wgmma(desc_a_lo0, desc_b_lo0, accum, false);
+                        WGMMA::wgmma(desc_a_lo1, desc_b_lo1, accum, true);
+                        WGMMA::wgmma(desc_a_hi0, desc_b_hi0, accum_hi, false);
+                        WGMMA::wgmma(desc_a_hi1, desc_b_hi1, accum_hi, true);
+                        ptx::warpgroup_commit_batch();
+                        #pragma unroll
+                        for (uint32_t i = 0; i < kAccumPerThread; ++ i) {
+                            ptx::warpgroup_fence_operand(accum[i]);
+                            ptx::warpgroup_fence_operand(accum_hi[i]);
+                        }
+                        ptx::warpgroup_wait<0>();
+
+                        arrive_empty_barrier(stage_idx);
+
+                        // Chunk `i` spans 8 output columns, so `i < 16` selects the
+                        // first 128 columns (weight SF `l2_sf_lo`), the rest use
+                        // `l2_sf_hi`. Low-half results use the `_lo` SFA, high-half
+                        // results the `_hi` SFA.
+                        #pragma unroll
+                        for (uint32_t i = 0; i < kAccumPerThread / 4; ++ i) {
+                            const float l2_sf = (i < 16u) ? l2_sf_lo : l2_sf_hi;
+                            final_accum[i*4+0] += scale_a_0_lo * l2_sf * accum[i*4+0];
+                            final_accum[i*4+1] += scale_a_0_lo * l2_sf * accum[i*4+1];
+                            final_accum[i*4+2] += scale_a_1_lo * l2_sf * accum[i*4+2];
+                            final_accum[i*4+3] += scale_a_1_lo * l2_sf * accum[i*4+3];
+                            final_accum[i*4+0] += scale_a_0_hi * l2_sf * accum_hi[i*4+0];
+                            final_accum[i*4+1] += scale_a_0_hi * l2_sf * accum_hi[i*4+1];
+                            final_accum[i*4+2] += scale_a_1_hi * l2_sf * accum_hi[i*4+2];
+                            final_accum[i*4+3] += scale_a_1_hi * l2_sf * accum_hi[i*4+3];
+                        }
+                    } else {
+                        // L2: split BLOCK_K=128 into two halves (per-64 SFA), each 2 WGMMAs.
+                        // `accum` is reused for both halves, so the first half must
+                        // be folded into `final_accum` before the second is issued.
+                        // First half: K=0..63, SFA = scale_a_*_lo
+                        #pragma unroll
+                        for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]);
+                        ptx::warpgroup_arrive();
+                        #pragma unroll
+                        for (uint32_t k = 0; k < (BLOCK_K / 2) / WGMMA::K; ++ k) {
+                            auto desc_a = mma::sm90::make_smem_desc(
+                                smem_a[stage_idx] + row_block_offset * BLOCK_K + k * WGMMA::K, 1);
+                            auto desc_b = mma::sm90::make_smem_desc(
+                                smem_b[stage_idx] + wg_n_idx * BLOCK_K + k * WGMMA::K, 1);
+                            WGMMA::wgmma(desc_a, desc_b, accum, k);
+                        }
+                        ptx::warpgroup_commit_batch();
+                        #pragma unroll
+                        for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]);
+                        ptx::warpgroup_wait<0>();
+
+                        // L2 weight SF is per 128 output columns; M64N256 spans two SF groups.
+                        #pragma unroll
+                        for (uint32_t i = 0; i < kAccumPerThread / 4; ++ i) {
+                            const float l2_sf = (i < 16u) ? l2_sf_lo : l2_sf_hi;
+                            final_accum[i*4+0] += scale_a_0_lo * l2_sf * accum[i*4+0];
+                            final_accum[i*4+1] += scale_a_0_lo * l2_sf * accum[i*4+1];
+                            final_accum[i*4+2] += scale_a_1_lo * l2_sf * accum[i*4+2];
+                            final_accum[i*4+3] += scale_a_1_lo * l2_sf * accum[i*4+3];
+                        }
+
+                        // Second half: K=64..127, SFA = scale_a_*_hi
+                        #pragma unroll
+                        for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]);
+                        ptx::warpgroup_arrive();
+                        #pragma unroll
+                        for (uint32_t k = 0; k < (BLOCK_K / 2) / WGMMA::K; ++ k) {
+                            const uint32_t k_off = (BLOCK_K / 2) + k * WGMMA::K;
+                            auto desc_a = mma::sm90::make_smem_desc(
+                                smem_a[stage_idx] + row_block_offset * BLOCK_K + k_off, 1);
+                            auto desc_b = mma::sm90::make_smem_desc(
+                                smem_b[stage_idx] + wg_n_idx * BLOCK_K + k_off, 1);
+                            WGMMA::wgmma(desc_a, desc_b, accum, k);
+                        }
+                        ptx::warpgroup_commit_batch();
+                        #pragma unroll
+                        for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]);
+                        ptx::warpgroup_wait<0>();
+
+                        // Both halves have been consumed from shared memory, so
+                        // the stage can be released here.
+                        arrive_empty_barrier(stage_idx);
+
+                        // L2 second half: same SFA half, still choose weight SF by N chunk.
+                        #pragma unroll
+                        for (uint32_t i = 0; i < kAccumPerThread / 4; ++ i) {
+                            const float l2_sf = (i < 16u) ?
l2_sf_lo : l2_sf_hi;
+                            final_accum[i*4+0] += scale_a_0_hi * l2_sf * accum[i*4+0];
+                            final_accum[i*4+1] += scale_a_0_hi * l2_sf * accum[i*4+1];
+                            final_accum[i*4+2] += scale_a_1_hi * l2_sf * accum[i*4+2];
+                            final_accum[i*4+3] += scale_a_1_hi * l2_sf * accum[i*4+3];
+                        }
+                    }
+                }
+            }
+            };
+
+            // Linear1-only K-loop that consumes two K-blocks per iteration.
+            //
+            // Both pipeline stages are waited on first, then their WGMMAs are
+            // issued in a single batch into two separate accumulators (`accum`
+            // for the first K-block, `accum_b` for the second). After the single
+            // `warpgroup_wait<0>` both stages are released and both partial
+            // results are scaled and folded into `final_accum`.
+            //
+            // The weight scale factors of the two K-blocks are fetched with one
+            // `float2` load each for gate and up (`.x` = block k0, `.y` = k0 + 1).
+            // NOTE(review): that load assumes `gate_sf_base + k0` and
+            // `up_sf_base + k0` are 8-byte aligned, and the loop assumes
+            // `num_k_blocks` is even; the static assert only checks
+            // `kHidden / BLOCK_K` - confirm both hold for every Linear1 block.
+            const auto run_l1_dual_k_gemm_loop = [&]() {
+                DG_STATIC_ASSERT((kHidden / BLOCK_K) % 2 == 0, "L1 dual-K expects an even number of K blocks");
+                constexpr uint32_t kL1SFKBlocks = kHidden / 128;
+                constexpr uint32_t kL1SFGateBlks = kIntermediateHidden / 128;
+                constexpr uint32_t kL1SFPerExpert = (kIntermediateHidden * 2 / 128) * kL1SFKBlocks;
+                // Row indices of the gate / up weight SF for this N tile; the up
+                // rows follow the `kL1SFGateBlks` gate rows.
+                const uint32_t gate_n = (n_block_idx * BLOCK_N + wg_n_idx) / 256u;
+                const uint32_t up_n = kL1SFGateBlks + gate_n;
+                const float* expert_sf_base = l1_weights_sf + local_expert_idx * kL1SFPerExpert;
+                const float* gate_sf_base = expert_sf_base + gate_n * kL1SFKBlocks;
+                const float* up_sf_base = expert_sf_base + up_n * kL1SFKBlocks;
+                float accum_b[kAccumPerThread];
+
+                // No increment expression: `advance_pipeline` is called twice in
+                // the body, once per consumed K-block.
+                for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks;) {
+                    // ---- First K-block of the pair ----
+                    const uint32_t stage0 = stage_idx;
+                    const uint32_t phase0 = phase;
+                    const uint32_t k0 = k_block_idx;
+                    full_barriers[stage0]->wait(phase0);
+
+                    const float scale_a0_r0 = ptx::ld_shared(smem_sfa[stage0] + row_offset_r0);
+                    const float scale_a0_r1 = ptx::ld_shared(smem_sfa[stage0] + row_offset_r1);
+                    const float2 gate_sf_pair = __ldg(reinterpret_cast(gate_sf_base + k0));
+                    const float2 up_sf_pair = __ldg(reinterpret_cast(up_sf_base + k0));
+                    const float gate_sf0 = gate_sf_pair.x;
+                    const float up_sf0 = up_sf_pair.x;
+
+                    // ---- Second K-block of the pair ----
+                    advance_pipeline(k_block_idx);
+                    const uint32_t stage1 = stage_idx;
+                    const uint32_t phase1 = phase;
+                    full_barriers[stage1]->wait(phase1);
+
+                    const float scale_a1_r0 = ptx::ld_shared(smem_sfa[stage1] + row_offset_r0);
+                    const float scale_a1_r1 = ptx::ld_shared(smem_sfa[stage1] + row_offset_r1);
+                    const float gate_sf1 = gate_sf_pair.y;
+                    const float up_sf1 = up_sf_pair.y;
+
+                    #pragma unroll
+                    for (uint32_t i = 0; i < kAccumPerThread; ++ i) {
+                        ptx::warpgroup_fence_operand(accum[i]);
+                        ptx::warpgroup_fence_operand(accum_b[i]);
+                    }
+                    ptx::warpgroup_arrive();
+                    // Stage 0 -> `accum`
+                    #pragma unroll
+                    for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
+                        auto desc_a = mma::sm90::make_smem_desc(
+                            smem_a[stage0] + row_block_offset * BLOCK_K + k * WGMMA::K, 1);
+                        auto desc_b = mma::sm90::make_smem_desc(
+                            smem_b[stage0] + wg_n_idx * BLOCK_K + k * WGMMA::K, 1);
+                        WGMMA::wgmma(desc_a, desc_b, accum, k);
+                    }
+                    // Stage 1 -> `accum_b`
+                    #pragma unroll
+                    for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
+                        auto desc_a = mma::sm90::make_smem_desc(
+                            smem_a[stage1] + row_block_offset * BLOCK_K + k * WGMMA::K, 1);
+                        auto desc_b = mma::sm90::make_smem_desc(
+                            smem_b[stage1] + wg_n_idx * BLOCK_K + k * WGMMA::K, 1);
+                        WGMMA::wgmma(desc_a, desc_b, accum_b, k);
+                    }
+                    ptx::warpgroup_commit_batch();
+                    #pragma unroll
+                    for (uint32_t i = 0; i < kAccumPerThread; ++ i) {
+                        ptx::warpgroup_fence_operand(accum[i]);
+                        ptx::warpgroup_fence_operand(accum_b[i]);
+                    }
+                    ptx::warpgroup_wait<0>();
+
+                    // Both stages are released only after the joint wait.
+                    arrive_empty_barrier(stage0);
+                    arrive_empty_barrier(stage1);
+
+                    // Odd chunks take the up SF, even chunks the gate SF; each
+                    // K-block is scaled with its own SFA and weight SF.
+                    #pragma unroll
+                    for (uint32_t i = 0; i < kAccumPerThread / 4; ++ i) {
+                        const float sb0 = (i & 1u) ? up_sf0 : gate_sf0;
+                        const float sb1 = (i & 1u) ? up_sf1 : gate_sf1;
+                        final_accum[i*4+0] += scale_a0_r0 * sb0 * accum[i*4+0];
+                        final_accum[i*4+1] += scale_a0_r0 * sb0 * accum[i*4+1];
+                        final_accum[i*4+2] += scale_a0_r1 * sb0 * accum[i*4+2];
+                        final_accum[i*4+3] += scale_a0_r1 * sb0 * accum[i*4+3];
+                        final_accum[i*4+0] += scale_a1_r0 * sb1 * accum_b[i*4+0];
+                        final_accum[i*4+1] += scale_a1_r0 * sb1 * accum_b[i*4+1];
+                        final_accum[i*4+2] += scale_a1_r1 * sb1 * accum_b[i*4+2];
+                        final_accum[i*4+3] += scale_a1_r1 * sb1 * accum_b[i*4+3];
+                    }
+
+                    advance_pipeline(k_block_idx);
+                }
+            };
+
+            // Pick the K-loop: the dual-K variant is used only for Linear1 blocks
+            // and only when `kL1DualKAccum` is enabled; everything else runs the
+            // default loop.
+            if constexpr (kL1DualKAccum) {
+                if (block_phase == sched::BlockPhase::Linear1)
+                    run_l1_dual_k_gemm_loop();
+                else
+                    run_default_gemm_loop();
+            } else {
+                run_default_gemm_loop();
+            }
+
+            // Only lane 0 of epilogue warp 0 records the GEMM-core duration.
+            const unsigned long long block_gemm_end = phase_profile_clock();
+            if (epilogue_warp_idx == 0 and lane_idx == 0)
+                phase_profile_record(kProfileGemmCore, block_gemm_end - block_gemm_start);
+
+            // Skip epilogue when block is past valid M (still must release via empty).
+            // A dummy cluster peer may still carry an async L1 store from the
+            // previous valid block, so drain it before leaving the L1 wave.
+            // The full epilogue barrier is still joined before returning.
+            if (row_block_offset >= valid_m) {
+                if constexpr (kAsyncL1TMAStore) {
+                    if (block_phase == sched::BlockPhase::Linear1)
+                        drain_all_async_l1_stores();
+                }
+                ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx);
+                return;
+            }
+
+            const unsigned long long block_epilogue_start = phase_profile_clock();
+            if (block_phase == sched::BlockPhase::Linear1) {
+                // ---------------- L1 EPILOGUE: SwiGLU + FP8 quantize + TMA store ----------------
+                // Layout in `final_accum`:
+                //   16 chunks of 8 N-cols, each chunk = 4 floats per thread = (r0c0, r0c1, r1c0, r1c1).
+                //   Gate chunks: even (0, 2, ..., 14). Up chunks: odd (1, 3, ..., 15).
+                //   Pair `p` ∈ [0, 8): gate chunk = 2p, up chunk = 2p+1.
+                //
+                // For each pair we produce 4 post-SwiGLU floats per thread, mapped to
+                // output cols (p*8 + col_idx*2 + {0,1}) for both r0 and r1.
+ + constexpr uint32_t kNumPairs = kAccumPerThread / 8; + constexpr uint32_t kNumSFGroups = WG_L1_OUT_BLOCK_N / 64; + DG_STATIC_ASSERT(WG_L1_OUT_BLOCK_N % 64 == 0, "L1 output SF is per 64 columns"); + float swiglu_r0[kNumPairs][2]; + float swiglu_r1[kNumPairs][2]; + + // Per-row amax, one scale for each 64-col L1 output group. + float amax_r0[kNumSFGroups] = {}; + float amax_r1[kNumSFGroups] = {}; + + // Compute SwiGLU + per-group amax. + #pragma unroll + for (uint32_t p = 0; p < kNumPairs; ++ p) { + const uint32_t gate = 2 * p, up = 2 * p + 1; + const uint32_t sf_group = p / 8; + + auto clamp_gate = [](float& x) { + if constexpr (kActivationClamp != cute::numeric_limits::infinity()) + x = cute::min(x, kActivationClamp); + }; + auto clamp_up = [](float& x) { + if constexpr (kActivationClamp != cute::numeric_limits::infinity()) + x = cute::min(cute::max(x, -kActivationClamp), kActivationClamp); + }; + float g_r0_c0 = final_accum[gate*4 + 0]; clamp_gate(g_r0_c0); + float g_r0_c1 = final_accum[gate*4 + 1]; clamp_gate(g_r0_c1); + float g_r1_c0 = final_accum[gate*4 + 2]; clamp_gate(g_r1_c0); + float g_r1_c1 = final_accum[gate*4 + 3]; clamp_gate(g_r1_c1); + float u_r0_c0 = final_accum[up*4 + 0]; clamp_up(u_r0_c0); + float u_r0_c1 = final_accum[up*4 + 1]; clamp_up(u_r0_c1); + float u_r1_c0 = final_accum[up*4 + 2]; clamp_up(u_r1_c0); + float u_r1_c1 = final_accum[up*4 + 3]; clamp_up(u_r1_c1); + + auto silu = [](float x) -> float { + const float e = kFastMath ? __expf(-x) : expf(-x); + const float sig = kFastMath ? 
math::fast_rcp(1.0f + e) : 1.0f / (1.0f + e); + return x * sig; + }; + + if (valid_r0) { + swiglu_r0[p][0] = silu(g_r0_c0) * u_r0_c0; + swiglu_r0[p][1] = silu(g_r0_c1) * u_r0_c1; + amax_r0[sf_group] = cute::max( + amax_r0[sf_group], + cute::max(cute::abs(swiglu_r0[p][0]), cute::abs(swiglu_r0[p][1]))); + } else { + swiglu_r0[p][0] = 0.0f; + swiglu_r0[p][1] = 0.0f; + } + if (valid_r1) { + swiglu_r1[p][0] = silu(g_r1_c0) * u_r1_c0; + swiglu_r1[p][1] = silu(g_r1_c1) * u_r1_c1; + amax_r1[sf_group] = cute::max( + amax_r1[sf_group], + cute::max(cute::abs(swiglu_r1[p][0]), cute::abs(swiglu_r1[p][1]))); + } else { + swiglu_r1[p][0] = 0.0f; + swiglu_r1[p][1] = 0.0f; + } + } + + + const float weight_r0 = valid_r0 ? *l1_topk_weights_buffer + .get_data_buffer(m_idx + row_offset_r0) + .get_base_ptr() : 0.0f; + const float weight_r1 = valid_r1 ? *l1_topk_weights_buffer + .get_data_buffer(m_idx + row_offset_r1) + .get_base_ptr() : 0.0f; + #pragma unroll + for (uint32_t p = 0; p < kNumPairs; ++ p) { + swiglu_r0[p][0] *= weight_r0; + swiglu_r0[p][1] *= weight_r0; + swiglu_r1[p][0] *= weight_r1; + swiglu_r1[p][1] *= weight_r1; + } + #pragma unroll + for (uint32_t g = 0; g < kNumSFGroups; ++ g) { + amax_r0[g] *= cute::abs(weight_r0); + amax_r1[g] *= cute::abs(weight_r1); + } + #pragma unroll + for (uint32_t g = 0; g < kNumSFGroups; ++ g) { + amax_r0[g] = math::warp_reduce<4, false>(amax_r0[g], math::ReduceMax()); + amax_r1[g] = math::warp_reduce<4, false>(amax_r1[g], math::ReduceMax()); + } + + float sf_r0[kNumSFGroups], sf_inv_r0[kNumSFGroups]; + float sf_r1[kNumSFGroups], sf_inv_r1[kNumSFGroups]; + #pragma unroll + for (uint32_t g = 0; g < kNumSFGroups; ++ g) { + float2 amax_pair = {amax_r0[g], amax_r1[g]}; + float2 sf_pair, sf_inv_pair; + math::get_e4m3_sf_and_sf_inv(amax_pair, sf_pair, sf_inv_pair); + sf_r0[g] = sf_pair.x; sf_inv_r0[g] = sf_inv_pair.x; + sf_r1[g] = sf_pair.y; sf_inv_r1[g] = sf_inv_pair.y; + } + + // Quantize and write to smem_cd_l1 (row-major, no swizzle). 
+ const uint32_t l1_store_stage = kAsyncL1TMAStore ? async_l1_store_stage : 0u; + if constexpr (kAsyncL1TMAStore) + drain_async_l1_store_stage(l1_store_stage); + auto* smem_cd_l1_wg = smem_cd_l1 + + l1_store_stage * kNumEpilogueWarpgroups * WG_BLOCK_M * L1_OUT_BLOCK_N + + (kSplitNWarpgroups ? 0 : epilogue_wg_idx * WG_BLOCK_M * L1_OUT_BLOCK_N); + #pragma unroll + for (uint32_t p = 0; p < kNumPairs; ++ p) { + const uint32_t sf_group = p / 8; + const float v00 = swiglu_r0[p][0] * sf_inv_r0[sf_group]; + const float v01 = swiglu_r0[p][1] * sf_inv_r0[sf_group]; + const float v10 = swiglu_r1[p][0] * sf_inv_r1[sf_group]; + const float v11 = swiglu_r1[p][1] * sf_inv_r1[sf_group]; + + const __nv_fp8x2_e4m3 r0_pair(make_float2(v00, v01)); + const __nv_fp8x2_e4m3 r1_pair(make_float2(v10, v11)); + + const uint32_t col = p * 8 + col_idx * 2; + auto* p0 = reinterpret_cast( + smem_cd_l1_wg + r_0 * L1_OUT_BLOCK_N + wg_l1_out_n_idx + col); + auto* p1 = reinterpret_cast( + smem_cd_l1_wg + r_1 * L1_OUT_BLOCK_N + wg_l1_out_n_idx + col); + if (valid_r0) + *p0 = r0_pair.__x; + if (valid_r1) + *p1 = r1_pair.__x; + } + + // Write L2-activation SF as float, one value per 64 output columns. + if (col_idx == 0) { + auto sf_base_ptr = l2_sf_buffer.get_base_ptr(); + const uint32_t token_r0 = pool_block_idx * BLOCK_M + row_offset_r0; + const uint32_t token_r1 = pool_block_idx * BLOCK_M + row_offset_r1; + const uint32_t base_k_sf_idx = (n_block_idx * L1_OUT_BLOCK_N + wg_l1_out_n_idx) / 64u; + #pragma unroll + for (uint32_t g = 0; g < kNumSFGroups; ++ g) { + if (valid_r0) + sf_base_ptr[(base_k_sf_idx + g) * kNumPaddedSFPoolTokens + token_r0] = sf_r0[g]; + if (valid_r1) + sf_base_ptr[(base_k_sf_idx + g) * kNumPaddedSFPoolTokens + token_r1] = sf_r1[g]; + } + } + + // Issue TMA store of the entire tile. 
Padding rows beyond + // `valid_m` are written with stale/garbage FP8 to the L1-output + // pool buffer, but they are never consumed downstream: the L2 + // GEMM tile loads them, but its NVLink-scatter epilogue is + // gated by `m_idx_in_block >= valid_m`, and stale SF in the + // padding rows can produce NaN accumulators that simply stay + // in registers (only valid rows are converted to BF16 and + // STSM'd into smem). Using TMA for partial tiles is a large + // win for low-batch / decode where every tile is partial. + if constexpr (kSplitNWarpgroups) { + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + const uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N; + cute::tma_store_fence(); + cute::SM90_TMA_STORE_2D::copy( + &tensor_map_l1_output, + smem_cd_l1, + out_n_idx, + m_idx); + cute::tma_store_arrive(); + } + __syncwarp(); + ptx::tma_store_wait<0>(); + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + notify_l1_ready(pool_block_idx, n_block_idx); + } else { + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + if (warp_idx_in_wg == 0 and cute::elect_one_sync()) { + const uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N; + cute::tma_store_fence(); + cute::SM90_TMA_STORE_2D::copy( + &tensor_map_l1_output, + smem_cd_l1_wg, + out_n_idx, + m_idx + row_block_offset); + cute::tma_store_arrive(); + } + __syncwarp(); + if constexpr (kAsyncL1TMAStore) { + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + async_l1_store_pending[l1_store_stage] = true; + async_l1_store_pool[l1_store_stage] = pool_block_idx; + async_l1_store_n[l1_store_stage] = n_block_idx; + async_l1_store_stage ^= 1u; + } else { + ptx::tma_store_wait<0>(); + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + notify_l1_ready(pool_block_idx, n_block_idx); + } + } + const unsigned long long block_epilogue_end = phase_profile_clock(); + if 
(epilogue_warp_idx == 0 and lane_idx == 0) + phase_profile_record(kProfileL1Epilogue, block_epilogue_end - block_epilogue_start); + } else { + // ---------------- L2 EPILOGUE: BF16 cast + NVLink scatter ---------------- + constexpr uint32_t kNumRowsPerWarp = WG_BLOCK_M / 8; + + if constexpr (kDirectL2Scatter) { + DG_STATIC_ASSERT(WG_BLOCK_N == 128, "Direct L2 scatter prototype only supports N128"); + + auto scatter_direct_row = [&](const uint32_t& row_offset, const bool& valid_row, const uint32_t& row_accum_offset) { + if (valid_row) { + const auto src_metadata = *workspace.get_token_src_metadata_ptr(m_idx + row_offset); + const uint32_t dst_rank_idx = src_metadata.rank_idx; + const uint32_t dst_token_idx = src_metadata.token_idx; + const uint32_t dst_topk_idx = src_metadata.topk_idx; + const auto dst_token = combine_token_buffer.get_rank_buffer(dst_topk_idx) + .get_data_buffer(dst_token_idx); + auto dst_base = math::advance_ptr( + dst_token.get_base_ptr(), n_idx * sizeof(nv_bfloat16)); + auto mapped_dst_base = sym_buffer.map(dst_base, dst_rank_idx); + + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread / 8; ++ i) { + const uint32_t chunk_lo = 2 * i, chunk_hi = 2 * i + 1; + const uint32_t col_lo = chunk_lo * 8 + col_idx * 2; + const uint32_t col_hi = chunk_hi * 8 + col_idx * 2; + const uint32_t packed_lo = math::cast_into_bf16_and_pack( + final_accum[chunk_lo * 4 + row_accum_offset + 0], + final_accum[chunk_lo * 4 + row_accum_offset + 1]); + const uint32_t packed_hi = math::cast_into_bf16_and_pack( + final_accum[chunk_hi * 4 + row_accum_offset + 0], + final_accum[chunk_hi * 4 + row_accum_offset + 1]); + *reinterpret_cast(mapped_dst_base + col_lo * sizeof(nv_bfloat16)) = packed_lo; + *reinterpret_cast(mapped_dst_base + col_hi * sizeof(nv_bfloat16)) = packed_hi; + } + } + }; + + scatter_direct_row(row_offset_r0, valid_r0, 0); + scatter_direct_row(row_offset_r1, valid_r1, 2); + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + } else { + // 
STSM into smem_cd_l2 (BF16). Reuse SM100 column-swizzle layout. + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread / 8; ++ i) { + // Each i consumes 8 floats (one 16x256b chunk in SM100 terms). + // For SM90 WGMMA layout, 8 floats per i correspond to 2 chunks of 4 floats: + // final_accum[i*8 + (0..3)] = chunk 2i: (r0c0, r0c1, r1c0, r1c1) + // final_accum[i*8 + (4..7)] = chunk 2i+1: same shape + const uint32_t chunk_lo = 2 * i, chunk_hi = 2 * i + 1; + + // Write to SMEM at appropriate position + // Row r_0 cols [chunk_lo*8 + col_idx*2, chunk_lo*8 + col_idx*2 + 1] = r0_lo + // Row r_0 cols [chunk_hi*8 + col_idx*2, chunk_hi*8 + col_idx*2 + 1] = r0_hi + // Row r_1 cols [chunk_lo*8 + col_idx*2, chunk_lo*8 + col_idx*2 + 1] = r1_lo + // Row r_1 cols [chunk_hi*8 + col_idx*2, chunk_hi*8 + col_idx*2 + 1] = r1_hi + auto write_pair = [&](uint32_t row, uint32_t col, uint32_t packed) { + auto smem_ptr = smem_cd_l2 + + epilogue_wg_idx * WG_BLOCK_M * WG_BLOCK_N + + row * WG_BLOCK_N + + col; + // BF16 STS: 2 bf16 elements + *reinterpret_cast(smem_ptr) = packed; + }; + if (valid_r0) { + const uint32_t r0_lo = math::cast_into_bf16_and_pack( + final_accum[chunk_lo*4 + 0], final_accum[chunk_lo*4 + 1]); + const uint32_t r0_hi = math::cast_into_bf16_and_pack( + final_accum[chunk_hi*4 + 0], final_accum[chunk_hi*4 + 1]); + write_pair(r_0, chunk_lo * 8 + col_idx * 2, r0_lo); + write_pair(r_0, chunk_hi * 8 + col_idx * 2, r0_hi); + } + if (valid_r1) { + const uint32_t r1_lo = math::cast_into_bf16_and_pack( + final_accum[chunk_lo*4 + 2], final_accum[chunk_lo*4 + 3]); + const uint32_t r1_hi = math::cast_into_bf16_and_pack( + final_accum[chunk_hi*4 + 2], final_accum[chunk_hi*4 + 3]); + write_pair(r_1, chunk_lo * 8 + col_idx * 2, r1_lo); + write_pair(r_1, chunk_hi * 8 + col_idx * 2, r1_hi); + } + } + + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + // Scatter to remote ranks via NVLink (one row per warp-pair) + // Each warpgroup-warp covers 8 unique rows × 2 
(r_0 + r_1 doubled by warps) + // Lane group of 16 within a warp → 1 row. + const uint32_t row_in_warp_block = lane_idx / 16; // 0 or 1 + const uint32_t lane_in_row = lane_idx % 16; + const uint32_t cols_per_lane = WG_BLOCK_N / 16; + static_assert(WG_BLOCK_N == 64 or WG_BLOCK_N == 128 or WG_BLOCK_N == 256, + "L2 scatter supports per-WG N64/N128/N256"); + + #pragma unroll + for (uint32_t j = 0; j < kNumRowsPerWarp; ++ j) { + const uint32_t row_in_wg = warp_idx_in_wg * 16 + j * 2 + row_in_warp_block; + const uint32_t m_idx_in_block = row_block_offset + row_in_wg; + if (m_idx_in_block >= valid_m) break; + + const auto src_metadata = *workspace.get_token_src_metadata_ptr(m_idx + m_idx_in_block); + const uint32_t dst_rank_idx = src_metadata.rank_idx; + const uint32_t dst_token_idx = src_metadata.token_idx; + const uint32_t dst_topk_idx = src_metadata.topk_idx; + + auto smem_ptr = smem_cd_l2 + + epilogue_wg_idx * WG_BLOCK_M * WG_BLOCK_N + + row_in_wg * WG_BLOCK_N + + lane_in_row * cols_per_lane; + const auto dst_token = combine_token_buffer.get_rank_buffer(dst_topk_idx) + .get_data_buffer(dst_token_idx); + + if constexpr (WG_BLOCK_N == 256) { + const auto packed0 = *reinterpret_cast(smem_ptr); + const auto packed1 = *(reinterpret_cast(smem_ptr) + 1); + auto dst_ptr = math::advance_ptr( + dst_token.get_base_ptr(), + n_idx * sizeof(nv_bfloat16) + lane_in_row * 2u * sizeof(uint4)); + auto mapped_dst_ptr = sym_buffer.map(dst_ptr, dst_rank_idx); + mapped_dst_ptr[0] = packed0; + mapped_dst_ptr[1] = packed1; + } else if constexpr (WG_BLOCK_N == 128) { + const auto packed = *reinterpret_cast(smem_ptr); + auto dst_ptr = math::advance_ptr( + dst_token.get_base_ptr(), + n_idx * sizeof(nv_bfloat16) + lane_in_row * sizeof(uint4)); + *sym_buffer.map(dst_ptr, dst_rank_idx) = packed; + } else { + const auto packed = *reinterpret_cast(smem_ptr); + auto dst_ptr = math::advance_ptr( + dst_token.get_base_ptr(), + n_idx * sizeof(nv_bfloat16) + lane_in_row * sizeof(uint2)); + 
*sym_buffer.map(dst_ptr, dst_rank_idx) = packed; + } + } + + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + } + const unsigned long long block_epilogue_end = phase_profile_clock(); + if (epilogue_warp_idx == 0 and lane_idx == 0) + phase_profile_record(kProfileL2Epilogue, block_epilogue_end - block_epilogue_start); + } + }); + const unsigned long long math_loop_end = phase_profile_clock(); + if (epilogue_warp_idx == 0 and lane_idx == 0) + phase_profile_record(kProfileMathLoop, math_loop_end - math_loop_start); + + // ---------------- COMBINE ---------------- + // NVLink barrier first: signals remote ranks that this rank's GEMM + // outputs (NVLink scatter targets) are fully written. + const unsigned long long combine_barrier_start = phase_profile_clock(); + comm::nvlink_barrier( + workspace, sym_buffer, sm_idx, epilogue_thread_idx, + [&]() { ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); } + ); + const unsigned long long combine_barrier_end = phase_profile_clock(); + if (epilogue_warp_idx == 0 and lane_idx == 0) + phase_profile_record(kProfileCombineBarrier, combine_barrier_end - combine_barrier_start); + + // Sync with dispatch (paired with dispatch's pre-cleanup sync) so that + // dispatch may now safely clean workspace state. + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + const unsigned long long combine_reduce_start = phase_profile_clock(); + + constexpr uint32_t kNumHiddenBytes = kHidden * sizeof(nv_bfloat16); + constexpr uint32_t kNumElemsPerUint4 = sizeof(uint4) / sizeof(nv_bfloat162); + + constexpr uint32_t kNumChunkSlots = 3; + constexpr uint32_t kNumMaxRegistersForBuffer = 128; + constexpr uint32_t kNumChunks = + (kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes <= SMEM_BEFORE_BARRIER_SIZE + and kHidden <= 32 * kNumMaxRegistersForBuffer) ? 
1 : 2; + constexpr uint32_t kNumChunkBytes = kNumHiddenBytes / kNumChunks; + constexpr uint32_t kNumChunkUint4 = kNumChunkBytes / sizeof(uint4); + constexpr uint32_t kNumUint4PerLane = kNumChunkUint4 / 32; + DG_STATIC_ASSERT(kHidden % kNumChunks == 0, "Hidden must be divisible by number of chunks"); + DG_STATIC_ASSERT(kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes / kNumChunks <= SMEM_BEFORE_BARRIER_SIZE, "Hidden is too large"); + DG_STATIC_ASSERT(kNumChunkBytes % 16 == 0, "Combine chunk must be TMA-aligned (16 bytes)"); + DG_STATIC_ASSERT(kNumChunkBytes % sizeof(uint4) == 0, "Combine chunk must be divisible by 16 bytes"); + DG_STATIC_ASSERT(kNumChunkUint4 % 32 == 0, "Combine chunk must be a multiple of 32 16-byte elements"); + DG_STATIC_ASSERT(kNumTopk <= 32, "Top-k must fit in a single warp"); + + DG_DEVICE_ASSERT(kNumChunkSlots * kNumEpilogueWarps * kNumChunkBytes <= static_cast( + reinterpret_cast(barrier_start_ptr) - smem_buffer)); + + const auto combine_load_buffer = utils::PatternVisitor([&](const uint32_t& i) { + return math::advance_ptr(smem_buffer, (epilogue_warp_idx + i * kNumEpilogueWarps) * kNumChunkBytes); + }); + const auto combine_store_buffer = math::advance_ptr( + smem_buffer, (epilogue_warp_idx + kNumEpilogueWarps * 2) * kNumChunkBytes); + + auto combine_load_barriers = utils::PatternVisitor([&](const uint32_t& i) { + return combine_barriers[i + epilogue_warp_idx * 2]; + }); + + uint32_t combine_phase = 0; + uint32_t load_stage_idx = 0; + for (uint32_t token_idx = sm_idx * kNumEpilogueWarps + epilogue_warp_idx; + token_idx < num_tokens; + token_idx += kNumSMs * kNumEpilogueWarps) { + const int stored_topk_slot_idx = lane_idx < kNumTopk ? 
+ static_cast(__ldg(input_topk_idx_buffer.get_base_ptr() + token_idx * kNumTopk + lane_idx)) : -1; + const uint32_t total_mask = __ballot_sync(0xffffffff, stored_topk_slot_idx >= 0); + + for (uint32_t chunk = 0; chunk < kNumChunks; ++ chunk) { + const uint32_t chunk_byte_offset = chunk * kNumChunkBytes; + + uint32_t mask = total_mask; + const auto move_mask_and_load = [&](const uint32_t& i) { + if (mask) { + const uint32_t slot_idx = __ffs(mask) - 1; + mask ^= 1 << slot_idx; + if (cute::elect_one_sync()) { + const auto src_ptr = math::advance_ptr( + combine_token_buffer.get_rank_buffer(slot_idx) + .get_data_buffer(token_idx).get_base_ptr(), + chunk_byte_offset); + ptx::tma_load_1d(combine_load_buffer[i], src_ptr, combine_load_barriers[i], kNumChunkBytes); + ptx::mbarrier_arrive_and_set_tx(combine_load_barriers[i], kNumChunkBytes); + } + __syncwarp(); + return true; + } + return false; + }; + + bool do_reduce = move_mask_and_load(load_stage_idx); + + float2 reduced[kNumUint4PerLane * kNumElemsPerUint4] = {}; + while (do_reduce) { + do_reduce = move_mask_and_load(load_stage_idx ^ 1); + combine_load_barriers[load_stage_idx]->wait(combine_phase); + #pragma unroll + for (uint32_t j = 0; j < kNumUint4PerLane; ++ j) { + const auto uint4_values = combine_load_buffer[load_stage_idx][j * 32 + lane_idx]; + const auto bf16_values = reinterpret_cast(&uint4_values); + #pragma unroll + for (uint32_t l = 0; l < kNumElemsPerUint4; ++ l) + ptx::accumulate(reduced[j * kNumElemsPerUint4 + l], bf16_values[l]); + } + combine_phase ^= load_stage_idx; + load_stage_idx ^= 1; + } + + #pragma unroll + for (uint32_t j = 0; j < kNumUint4PerLane; ++ j) { + uint4 casted; + auto casted_bf16 = reinterpret_cast(&casted); + #pragma unroll + for (uint32_t l = 0; l < kNumElemsPerUint4; ++ l) + casted_bf16[l] = __float22bfloat162_rn(reduced[j * kNumElemsPerUint4 + l]); + + if (j == 0) { + ptx::tma_store_wait<0>(); + __syncwarp(); + } + ptx::st_shared(combine_store_buffer + j * 32 + lane_idx, + 
casted.x, casted.y, casted.z, casted.w); + } + __syncwarp(); + + if (cute::elect_one_sync()) { + cute::tma_store_fence(); + ptx::tma_store_1d( + math::advance_ptr(y, static_cast(token_idx) * kNumHiddenBytes + chunk_byte_offset), + combine_store_buffer, kNumChunkBytes); + cute::tma_store_arrive(); + } + __syncwarp(); + } + } + const unsigned long long combine_reduce_end = phase_profile_clock(); + if (epilogue_warp_idx == 0 and lane_idx == 0) + phase_profile_record(kProfileCombineReduce, combine_reduce_end - combine_reduce_start); + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only supports sm_90"); +#endif +} + +} // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep_gemm/include/deep_gemm/quantization/mxfp4_dequant.cuh b/deep_gemm/include/deep_gemm/quantization/mxfp4_dequant.cuh new file mode 100644 index 0000000000..503b5b872d --- /dev/null +++ b/deep_gemm/include/deep_gemm/quantization/mxfp4_dequant.cuh @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// +// MXFP4 (E2M1) → FP8 (E4M3) dequant helper for SM90 W4A8 fused MegaMoE. +// +// Ported from vLLM Marlin's `dequant<__nv_fp8x4_e4m3, kFE2M1f, true>` in +// `csrc/quantization/marlin/dequant.h` (Marlin / vLLM Apache 2.0). The bit +// pattern conversion is intentionally identical so that weight files prepared +// with Marlin / TensorRT-LLM `mxfp4_moe` preprocessing can be reused. + +#pragma once + +#include +#include + +namespace deep_gemm { +namespace w4a8 { + +#define DG_W4A8_INLINE __device__ __forceinline__ + +// Convert one packed-FP4 dword (8 nibbles = 8 FP4 values) into eight FP8 E4M3 +// bytes laid out as two `__nv_fp8x4_e4m3` fragments. +// +// IMPORTANT: this routine performs the bit-pattern conversion only; the per-32 +// E8M0 group scale must be applied separately on the WGMMA accumulator +// (Marlin-style), not folded into the FP8 exponent (which is the Humming +// alternative we deliberately do not use here). 
+// +// Layout note inherited from Marlin: the upper half of `q` (q << 4 step) +// produces `frag_b[0]` and the lower half produces `frag_b[1]`. This reverse +// indexing matches the way Marlin permutes the packed-FP4 weight tile so the +// FP8 register layout aligns with the WGMMA m64n*k32 B-fragment layout. +DG_W4A8_INLINE void dequant_mxfp4_to_fp8(int q, __nv_fp8x4_e4m3* frag_b) { + constexpr int FP4_EXPONENT = 2; + constexpr int FP8_EXPONENT = 4; + constexpr int RIGHT_SHIFT = FP8_EXPONENT - FP4_EXPONENT; // = 2 + constexpr int MASK = 0x70707070; + + int Out1 = (q & 0x80808080) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 4; + int Out2 = (q & 0x80808080) | ((q & MASK) >> RIGHT_SHIFT); + + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +// Convert an E8M0 byte (unsigned 8-bit exponent, bias = 127) into a float +// multiplicative scale. E8M0 stores `2 ** (e - 127)`. Used to apply the +// MXFP4 per-32 group scale to a WGMMA accumulator. +DG_W4A8_INLINE float e8m0_to_float(std::uint8_t e8m0_byte) { + // 2 ** (e - 127). Build via float bit pattern: sign=0, exponent=e, mantissa=0. + std::uint32_t bits = static_cast(e8m0_byte) << 23; + return *reinterpret_cast(&bits); +} + +#undef DG_W4A8_INLINE + +} // namespace w4a8 +} // namespace deep_gemm diff --git a/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh b/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh index cdbecccd56..f65c891968 100644 --- a/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh @@ -22,6 +22,9 @@ template CUTLASS_DEVICE uint32_t get_valid_m() const { - const auto m = cute::min(current_num_tokens - m_block_idx * BLOCK_M, BLOCK_M); + const auto m_start = m_block_idx * BLOCK_M; + if (m_start >= current_num_tokens) + return 0; + const auto m = cute::min(current_num_tokens - m_start, BLOCK_M); return kDoUMMAAligned ? 
math::align(m, 16u) : m; } @@ -117,12 +125,22 @@ struct MegaMoEScheduler { const auto wave_end_expert_idx = get_wave_expert_end_idx(); while (current_local_expert_idx < wave_end_expert_idx) { const auto num_m_blocks = get_current_num_m_blocks(); - m_block_idx = block_idx / kNumL1BlockNs; - if (m_block_idx < num_m_blocks) + const auto num_m_units = math::ceil_div(num_m_blocks, kClusterSize); + if (block_idx < num_m_units * kNumL1BlockNs) { + if constexpr (kL1NMajorSchedule) { + n_block_idx = block_idx / num_m_units; + const auto m_unit_idx = block_idx - n_block_idx * num_m_units; + m_block_idx = m_unit_idx * kClusterSize + cute::block_rank_in_cluster(); + } else { + const auto m_unit_idx = block_idx / kNumL1BlockNs; + m_block_idx = m_unit_idx * kClusterSize + cute::block_rank_in_cluster(); + n_block_idx = block_idx % kNumL1BlockNs; + } return true; + } // Current expert is fully assigned, move to the next - block_idx -= num_m_blocks * kNumL1BlockNs; + block_idx -= num_m_units * kNumL1BlockNs; advance_expert_idx(); } return false; @@ -132,18 +150,28 @@ struct MegaMoEScheduler { const auto wave_end_expert_idx = get_wave_expert_end_idx(); while (current_local_expert_idx < wave_end_expert_idx) { const auto num_m_blocks = get_current_num_m_blocks(); - if (block_idx < num_m_blocks * kNumL2BlockNs) { - m_block_idx = block_idx / kNumL2BlockNs; + const auto num_m_units = math::ceil_div(num_m_blocks, kClusterSize); + if (block_idx < num_m_units * kNumL2BlockNs) { + if constexpr (kL2NMajorSchedule) { + n_block_idx = block_idx / num_m_units; + const auto m_unit_idx = block_idx - n_block_idx * num_m_units; + m_block_idx = m_unit_idx * kClusterSize + cute::block_rank_in_cluster(); + } else { + const auto m_unit_idx = block_idx / kNumL2BlockNs; + m_block_idx = m_unit_idx * kClusterSize + cute::block_rank_in_cluster(); + n_block_idx = block_idx % kNumL2BlockNs; + } return true; } // Current expert is fully assigned, move to the next - block_idx -= num_m_blocks * kNumL2BlockNs; + 
block_idx -= num_m_units * kNumL2BlockNs; advance_expert_idx(); } return false; } + // Core state machine: assigns the next block CUTLASS_DEVICE cute::tuple get_next_block() { while (true) { @@ -153,9 +181,8 @@ struct MegaMoEScheduler { if (next_phase == BlockPhase::Linear1) { if (fetch_next_l1_block()) { // Found a new L1 block - n_block_idx = block_idx - m_block_idx * kNumL1BlockNs; - // Jump to next block - block_idx += kNumSMs; + // Jump to next cluster-scheduled work unit + block_idx += kNumSMs / kClusterSize; return {BlockPhase::Linear1, current_local_expert_idx, m_block_idx, n_block_idx}; } else { // L1 for the current wave is complete, transition to L2 @@ -165,9 +192,8 @@ struct MegaMoEScheduler { } else { if (fetch_next_l2_block()) { // Found a new L2 block - n_block_idx = block_idx - m_block_idx * kNumL2BlockNs; - // Jump to next block - block_idx += kNumSMs; + // Jump to next cluster-scheduled work unit + block_idx += kNumSMs / kClusterSize; return {BlockPhase::Linear2, current_local_expert_idx, m_block_idx, n_block_idx}; } else { // Move to L1 of the next wave diff --git a/deep_gemm/mega/__init__.py b/deep_gemm/mega/__init__.py index e624ecf273..6f2295c748 100644 --- a/deep_gemm/mega/__init__.py +++ b/deep_gemm/mega/__init__.py @@ -105,6 +105,32 @@ def transform_weights_for_mega_moe( return l1_weights, l2_weights +def transform_weights_for_mega_moe_sm90( + l1_weights: Tuple[torch.Tensor, torch.Tensor], + l2_weights: Tuple[torch.Tensor, torch.Tensor] +) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + """SM90 (Hopper) variant of `transform_weights_for_mega_moe`. + + SM90 has no TMEM / UTCCP path, so the SF tensors are consumed directly by + WGMMA promote and don't need the 4x32 transpose. With block (128, 128) + weight quantization, weight SFs are read by the math warpgroup directly + from global memory in their natural ``(E, N/128, K/128)`` MN-major layout + and require no transformation. 
Only L1's gate/up FP8 weight interleave is + preserved. + """ + l1_fp8, l1_sf = l1_weights + # Reuse the gran-8 N interleave on the FP8 weight only; the block SF stays + # in its natural ``(E, 2*IH/128, H/128)`` layout (gate then up along N). + def _interleave_one(t, gran: int = 8) -> torch.Tensor: + g, n, *rest = t.shape + half = n // 2 + gate = t[:, :half].reshape(g, half // gran, gran, *rest) + up = t[:, half:].reshape(g, half // gran, gran, *rest) + return torch.empty_like(t).copy_(torch.stack([gate, up], dim=2).reshape(g, n, *rest)) + + return (_interleave_one(l1_fp8), l1_sf), l2_weights + + def fp8_fp4_mega_moe(y: torch.Tensor, l1_weights: Tuple[torch.Tensor, torch.Tensor], l2_weights: Tuple[torch.Tensor, torch.Tensor], @@ -126,3 +152,62 @@ def fp8_fp4_mega_moe(y: torch.Tensor, activation, activation_clamp, fast_math ) + + +def fp8_mega_moe(y: torch.Tensor, + l1_weights: Tuple[torch.Tensor, torch.Tensor], + l2_weights: Tuple[torch.Tensor, torch.Tensor], + sym_buffer: SymmBuffer, + cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None, + recipe: Tuple[int, int, int] = (128, 128, 128), + activation: str = 'swiglu', + activation_clamp: Optional[float] = None, + fast_math: bool = True): + """SM90 (Hopper) MegaMoE entry point. + + Expects FP8 e4m3 weights and block-(128, 128) float scale factors. The + weight SF layout matches the convention used by ``DeepSeekV4FlashFp8`` / + DeepEP, so the same SF tensors can be physically shared between the + DeepEP path and this kernel. 
+ """ + _C.fp8_mega_moe( + y, + l1_weights, l2_weights, + cumulative_local_expert_recv_stats, + sym_buffer.buffer, + sym_buffer.handle.buffer_ptrs, sym_buffer.group.rank(), + sym_buffer.num_max_tokens_per_rank, + sym_buffer.num_experts, sym_buffer.num_topk, + recipe, + activation, activation_clamp, + fast_math + ) + + +def w4a8_mega_moe(y: torch.Tensor, + l1_weights: Tuple[torch.Tensor, torch.Tensor], + l2_weights: Tuple[torch.Tensor, torch.Tensor], + sym_buffer: SymmBuffer, + cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None, + recipe: Tuple[int, int, int] = (128, 128, 128), + activation: str = 'swiglu', + activation_clamp: Optional[float] = None, + fast_math: bool = True): + """SM90 (Hopper) W4A8 MegaMoE entry — Phase 0 scaffold. + + Accepts the same FP8 weights + per-128 float SF as ``fp8_mega_moe`` and + dispatches to the W4A8 runtime (currently a verbatim FP8 copy under a + different name; subsequent phases will switch to packed MXFP4 + E8M0). + """ + _C.w4a8_mega_moe( + y, + l1_weights, l2_weights, + cumulative_local_expert_recv_stats, + sym_buffer.buffer, + sym_buffer.handle.buffer_ptrs, sym_buffer.group.rank(), + sym_buffer.num_max_tokens_per_rank, + sym_buffer.num_experts, sym_buffer.num_topk, + recipe, + activation, activation_clamp, + fast_math + ) diff --git a/scripts/run_ncu_mega_moe_sm90.sh b/scripts/run_ncu_mega_moe_sm90.sh new file mode 100755 index 0000000000..e8c2b0ef61 --- /dev/null +++ b/scripts/run_ncu_mega_moe_sm90.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# SM90 (Hopper) variant of run_ncu_mega_moe.sh +# Drives `tests/bench_mega_moe_sm90.py` with NCU, profiling the +# `sm90_fp8_mega_moe_impl` kernel for a single batch size. 
+ +set -e + +num_processes=8 +output_dir=work_sm90 +python_args=() +for ((arg_idx = 1; arg_idx <= $#; ++arg_idx)); do + arg="${!arg_idx}" + case "$arg" in + --num-processes) + python_args+=("$arg") + if ((arg_idx < $#)); then + ((arg_idx++)) + num_processes="${!arg_idx}" + python_args+=("$num_processes") + fi + ;; + -h|--help) + echo "Usage: $0 [--num-processes N] [--output DIR] [python args...]" + exit 0 + ;; + --num-processes=*) + num_processes="${arg#*=}" + python_args+=("$arg") + ;; + -o|--output) + if ((arg_idx < $#)); then + ((arg_idx++)) + output_dir="${!arg_idx}" + fi + ;; + --output=*) + output_dir="${arg#*=}" + ;; + *) + python_args+=("$arg") + ;; + esac +done + +echo "Python Args: ${python_args[*]}" +echo "Num Processes: $num_processes" +echo "Output Dir: $output_dir" +mkdir -p "$output_dir" + +export DG_JIT_WITH_LINEINFO=1 + +echo "Warm up JIT cache" +python tests/bench_mega_moe_sm90.py --ncu-profile-only "${python_args[@]}" + +sleep 2 + +ncu_args=( + --config-file off + --force-overwrite + --kernel-name sm90_fp8_mega_moe_impl + --import-source yes + --replay-mode application + --section SpeedOfLight + --section LaunchStats + --section SchedulerStats + --section WarpStateStats + --section MemoryWorkloadAnalysis + --section InstructionStats + --launch-skip 0 + --launch-count 1 + --clock-control none + --kill yes + --app-replay-buffer memory +) + +echo "Run Job" + +for ((i = 0; i < num_processes; ++i)); do + ncu "${ncu_args[@]}" -o "${output_dir%/}/mega-moe-sm90.$i" \ + python tests/bench_mega_moe_sm90.py \ + --local-rank-idx=$i \ + --ncu-profile-only \ + "${python_args[@]}" & +done + +echo "Waiting" +wait +echo "Done" diff --git a/tests/bench_mega_moe_sm90.py b/tests/bench_mega_moe_sm90.py new file mode 100644 index 0000000000..3711dd48c2 --- /dev/null +++ b/tests/bench_mega_moe_sm90.py @@ -0,0 +1,250 @@ +"""SM90 (Hopper) MegaMoE benchmark / NCU-profile harness.
+ +Mirrors ``tests/test_mega_moe.py``'s ``--ncu-profile-only`` / +``--local-rank-idx`` interface so the same ``scripts/run_ncu_mega_moe.sh`` +pattern can drive it for SM90. + +In normal (non-NCU) mode it sweeps a list of ``num_tokens`` values (default: +1, 2, 4, 8, 16, 32) and reports per-call kernel time via the same +``bench_kineto`` helper used by the SM100 perf test, plus a rough TFLOPS / +HBM GB/s figure useful for tracking optimisation deltas. +""" + +import argparse +import os +import random +import sys +import torch +import torch.distributed as dist +from typing import Tuple + +REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if REPO_ROOT not in sys.path: + sys.path.insert(0, REPO_ROOT) + +import deep_gemm +from deep_gemm.utils import per_token_cast_to_fp8 +from deep_gemm.utils.dist import dist_print, init_dist, uneven_all_gather +from deep_gemm.testing import bench_kineto, calc_diff, get_arch_major + + +def _quantize_grouped_fp8_block_128_128(w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + g, n, k = w.shape + assert n % 128 == 0 and k % 128 == 0 + w_view = w.view(g, n // 128, 128, k // 128, 128).float() + amax = w_view.abs().amax(dim=(-1, -3)).clamp(1e-4) + sf = amax / 448.0 + w_fp8 = (w_view / sf.unsqueeze(-1).unsqueeze(-3)).to(torch.float8_e4m3fn) + return w_fp8.view(g, n, k).contiguous(), sf.contiguous() + + +def _run_one_config(args, num_tokens, num_max_tokens_per_rank, + hidden, intermediate_hidden, + num_experts, num_topk, num_ranks, rank_idx, group, + activation_clamp, fast_math, + print_perf=True): + num_experts_per_rank = num_experts // num_ranks + assert num_tokens <= num_max_tokens_per_rank + + # Symmetric buffer (one per config: cheaper to recreate than to keep max-size) + buffer = deep_gemm.get_symm_buffer_for_mega_moe( + group, num_experts, + num_max_tokens_per_rank, num_topk, + hidden, intermediate_hidden, + ) + + # Inputs (bf16, then quantised) + x_bf = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, 
device='cuda') + l1_bf = torch.randn( + (num_experts_per_rank, intermediate_hidden * 2, hidden), + dtype=torch.bfloat16, device='cuda') * 0.05 + l2_bf = torch.randn( + (num_experts_per_rank, hidden, intermediate_hidden), + dtype=torch.bfloat16, device='cuda') * 0.05 + scores = torch.randn((num_tokens, num_experts), dtype=torch.float, device='cuda') + topk_w, topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False) + if args.masked_ratio > 0: + rand_mask = torch.rand_like(topk_idx, dtype=torch.float) + topk_idx.masked_fill_(rand_mask < args.masked_ratio, -1) + topk_w.masked_fill_(topk_idx < 0, 0) + + x_fp8, x_sf = per_token_cast_to_fp8(x_bf, use_ue8m0=False, gran_k=128, + use_packed_ue8m0=False) + l1_w_fp8, l1_w_sf = _quantize_grouped_fp8_block_128_128(l1_bf) + l2_w_fp8, l2_w_sf = _quantize_grouped_fp8_block_128_128(l2_bf) + transformed_l1, transformed_l2 = deep_gemm.transform_weights_for_mega_moe_sm90( + (l1_w_fp8, l1_w_sf), (l2_w_fp8, l2_w_sf), + ) + + phase_profile_enabled = os.environ.get('DG_SM90_MOE_PHASE_PROFILE', '0') != '0' + phase_profile_ints = 64 if phase_profile_enabled else 0 + cum_stats = torch.zeros(num_experts_per_rank + phase_profile_ints, dtype=torch.int, device='cuda') + + # Stage inputs once; bench-loop re-copies them each call (bench helper expects + # an idempotent ``fn``). 
+ def run_fused(): + buffer.x[:num_tokens].copy_(x_fp8) + buffer.x_sf[:num_tokens].copy_(x_sf) + buffer.topk_idx[:num_tokens].copy_(topk_idx) + buffer.topk_weights[:num_tokens].copy_(topk_w) + y = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') + deep_gemm.fp8_mega_moe( + y, transformed_l1, transformed_l2, buffer, + cumulative_local_expert_recv_stats=cum_stats, + recipe=(128, 128, 128), + activation='swiglu', + activation_clamp=activation_clamp, + fast_math=fast_math, + ) + return y + + if args.ncu_profile_only: + dist_print(f'[NCU] tokens={num_tokens} hidden={hidden} ih={intermediate_hidden}', + once_in_node=True) + run_fused() + torch.cuda.synchronize() + dist.barrier() + buffer.destroy() + return + + # Warm up + benchmark + run_fused() + dist.barrier() + if phase_profile_enabled: + cum_stats.zero_() + torch.cuda.synchronize() + dist.barrier() + # NSYS MULTI-ITER (aichenf): N timed iters with barrier+sleep between them. + # bench_kineto returns 1 under DG_USE_NVIDIA_TOOLS=1, but this loop puts + # multiple mega_moe instances on the nsys timeline so we can measure variance. 
+ import os as _os + _nsys_iters = int(_os.environ.get('NSYS_ITERS', '0')) + if _nsys_iters > 0: + for _it in range(_nsys_iters): + torch.cuda.synchronize() + dist.barrier() + torch.cuda._sleep(int(2e7)) # 10ms gap between iters + dist.barrier() + run_fused() + torch.cuda.synchronize() + dist.barrier() + t_fused = bench_kineto(run_fused, 'sm90_fp8_mega_moe', + barrier=lambda: dist.barrier(), + num_tests=args.num_tests, + suppress_kineto_output=True) + + # Count tokens that landed on this rank for stats + gathered_topk_idx = uneven_all_gather(topk_idx, group=group) + gathered_topk_idx[(gathered_topk_idx < rank_idx * num_experts_per_rank) | + (gathered_topk_idx >= (rank_idx + 1) * num_experts_per_rank)] = -1 + num_recv_tokens = (gathered_topk_idx != -1).sum().item() + + safe_div = lambda a, b: float('nan') if b == 0 else a / b + tflops = safe_div(2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_fused) + num_touched_experts = max(0, torch.unique(gathered_topk_idx.flatten()).numel() - 1) + # FP8 weights = 1 byte, FP8 acts = 1 byte, BF16 output = 2 bytes + num_hbm_bytes = ( + num_touched_experts * intermediate_hidden * 2 * hidden + # L1 weights + num_touched_experts * hidden * intermediate_hidden + # L2 weights + num_recv_tokens * hidden + # L1 acts read + num_recv_tokens * intermediate_hidden + # L1 out write + num_recv_tokens * intermediate_hidden + # L2 acts read + num_recv_tokens * hidden * 2 # L2 out write + ) + hbm_gbs = safe_div(num_hbm_bytes / 1e9, t_fused) + + if print_perf: + dist_print( + f' tokens={num_tokens:4d} recv={num_recv_tokens:5d} experts={num_touched_experts:4d} ' + f'{t_fused * 1e6:7.1f} us {tflops:6.1f} TFLOPS {hbm_gbs:6.0f} GB/s (rank{rank_idx})', + once_in_node=True, + ) + if phase_profile_enabled: + torch.cuda.synchronize() + profile = cum_stats[num_experts_per_rank:num_experts_per_rank + 48].view(torch.int64).cpu().tolist() + names = [ + 'dispatch_total', 'dispatch_pull', 'math_loop', 'combine_barrier', + 'combine_reduce', 
'gemm_core', 'l1_epilogue', 'l2_epilogue', + ] + for i, name in enumerate(names): + total, max_v, count = profile[i], profile[8 + i], profile[16 + i] + avg = float(total) / count if count else 0.0 + dist_print( + f' phase {name:16s} avg={avg:10.0f} max={max_v:10d} count={count}', + once_in_node=True, + ) + + dist.barrier() + buffer.destroy() + + +def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace): + rank_idx, num_ranks, group = init_dist(local_rank, num_local_ranks) + forced_num_sms = int(os.environ.get('DG_SM90_MOE_SET_NUM_SMS', '0')) + if forced_num_sms > 0: + deep_gemm.set_num_sms(forced_num_sms) + torch.manual_seed(rank_idx) + random.seed(rank_idx) + + if get_arch_major() != 9: + dist_print(f'[SKIP] requires SM90, got SM{get_arch_major()}0', once_in_node=True) + dist.destroy_process_group() + return + + if args.batches is None: + batches = [1, 2, 4, 8, 16, 32] + else: + batches = args.batches + + dist_print( + f'SM90 MegaMoE bench: ranks={num_ranks} hidden={args.hidden} ' + f'ih={args.intermediate_hidden} experts={args.num_experts} topk={args.num_topk} ' + f'masked_ratio={args.masked_ratio} fast_math={bool(args.fast_math)}', + once_in_node=True, + ) + + # In NCU mode we run only one batch (the first one in `batches`) so that + # ncu's `--launch-count 1` is unambiguous. 
+ if args.ncu_profile_only: + batches = batches[:1] + + num_max_tokens_per_rank = max(batches) + for num_tokens in batches: + _run_one_config( + args, num_tokens, num_max_tokens_per_rank, + args.hidden, args.intermediate_hidden, + args.num_experts, args.num_topk, + num_ranks, rank_idx, group, + activation_clamp=args.activation_clamp, + fast_math=bool(args.fast_math), + ) + + dist.barrier() + dist.destroy_process_group() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='SM90 MegaMoE benchmark') + + parser.add_argument('--ncu-profile-only', action='store_true') + parser.add_argument('--num-processes', type=int, default=8) + parser.add_argument('--local-rank-idx', type=int, default=None) + + parser.add_argument('--batches', type=int, nargs='+', default=None, + help='List of num_tokens to sweep (default: 1 2 4 8 16 32)') + parser.add_argument('--hidden', type=int, default=7168) + parser.add_argument('--intermediate-hidden', type=int, default=2048) + parser.add_argument('--num-experts', type=int, default=256) + parser.add_argument('--num-topk', type=int, default=8) + parser.add_argument('--activation-clamp', type=float, default=10.0) + parser.add_argument('--masked-ratio', type=float, default=0.0) + parser.add_argument('--fast-math', type=int, default=1) + parser.add_argument('--num-tests', type=int, default=20) + + args = parser.parse_args() + + if args.local_rank_idx is not None: + test(args.local_rank_idx, args.num_processes, args) + else: + np = args.num_processes + torch.multiprocessing.spawn(test, args=(np, args), nprocs=np) diff --git a/tests/bench_mega_moe_sm90_skew.py b/tests/bench_mega_moe_sm90_skew.py new file mode 100644 index 0000000000..4b6855d7ba --- /dev/null +++ b/tests/bench_mega_moe_sm90_skew.py @@ -0,0 +1,305 @@ +"""SM90 (Hopper) MegaMoE benchmark — patched to support power-law routing +and aggregate per-rank timing. 
+ +Adds vs upstream `tests/bench_mega_moe_sm90.py`: + --skew-alpha Zipf exponent for expert popularity (0 = uniform, upstream default) + --skew-gain Scale of log-bias added to scores (default 1.0) + --skew-seed Seed for the per-expert bias permutation. Constant across + ranks so all ranks see the SAME hot/cold expert assignment + (matches production: gate is shared globally). + +Output is aggregated across all ranks: + - per-rank t_us, recv_tokens (max/mean/min) + - per-rank max/mean ratio (intra-rank imbalance) + - inter-rank max/mean t_us ratio (cross-rank imbalance) +""" +import argparse +import os +import random +import sys +import torch +import torch.distributed as dist +from typing import Tuple + +REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if REPO_ROOT not in sys.path: + sys.path.insert(0, REPO_ROOT) + +import deep_gemm +from deep_gemm.utils import per_token_cast_to_fp8 +from deep_gemm.utils.dist import dist_print, init_dist, uneven_all_gather +from deep_gemm.testing import bench_kineto, get_arch_major + + +def _quantize_grouped_fp8_block_128_128(w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + g, n, k = w.shape + assert n % 128 == 0 and k % 128 == 0 + w_view = w.view(g, n // 128, 128, k // 128, 128).float() + amax = w_view.abs().amax(dim=(-1, -3)).clamp(1e-4) + sf = amax / 448.0 + w_fp8 = (w_view / sf.unsqueeze(-1).unsqueeze(-3)).to(torch.float8_e4m3fn) + return w_fp8.view(g, n, k).contiguous(), sf.contiguous() + + +def _make_global_bias(num_experts: int, alpha: float, gain: float, seed: int): + """Power-law (Zipf) bias per expert, identical across all ranks. + + Generated with a deterministic seed so all ranks see the same hot/cold + expert assignment. Returns None if alpha == 0 (uniform). 
+ """ + if alpha <= 0: + return None + gen = torch.Generator(device='cuda') + gen.manual_seed(seed) + ranks = torch.arange(1, num_experts + 1, device='cuda', dtype=torch.float) + probs = 1.0 / ranks.pow(alpha) + probs = probs / probs.sum() + bias = torch.log(probs * num_experts) * gain + perm = torch.randperm(num_experts, device='cuda', generator=gen) + return bias[perm] + + +def _run_one_config(args, num_tokens, num_max_tokens_per_rank, + hidden, intermediate_hidden, + num_experts, num_topk, num_ranks, rank_idx, group, + activation_clamp, fast_math, + print_perf=True): + num_experts_per_rank = num_experts // num_ranks + assert num_tokens <= num_max_tokens_per_rank + + buffer = deep_gemm.get_symm_buffer_for_mega_moe( + group, num_experts, + num_max_tokens_per_rank, num_topk, + hidden, intermediate_hidden, + ) + + x_bf = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') + l1_bf = torch.randn( + (num_experts_per_rank, intermediate_hidden * 2, hidden), + dtype=torch.bfloat16, device='cuda') * 0.05 + l2_bf = torch.randn( + (num_experts_per_rank, hidden, intermediate_hidden), + dtype=torch.bfloat16, device='cuda') * 0.05 + + # Per-rank random scores + global skew bias (added to all rows) + scores = torch.randn((num_tokens, num_experts), dtype=torch.float, device='cuda') + global_bias = _make_global_bias(num_experts, args.skew_alpha, args.skew_gain, + args.skew_seed) + if global_bias is not None: + scores = scores + global_bias.unsqueeze(0) + + replica_for = {} + if args.num_redundant_experts > 0: + assert args.num_redundant_experts % num_ranks == 0, 'redundant experts must divide ranks' + num_replicas_per_rank = args.num_redundant_experts // num_ranks + assert 0 < num_replicas_per_rank < num_experts_per_rank, 'invalid redundant expert count' + replica_slots = [] + for r in range(num_ranks): + base = r * num_experts_per_rank + replica_slots += list(range(base + num_experts_per_rank - num_replicas_per_rank, + base + num_experts_per_rank)) + 
logical_mask = torch.ones(num_experts, dtype=torch.bool, device='cuda') + logical_mask[torch.tensor(replica_slots, dtype=torch.long, device='cuda')] = False + scores[:, ~logical_mask] = -float('inf') + if global_bias is not None: + hot_order = torch.argsort(global_bias.masked_fill(~logical_mask, -float('inf')), descending=True).tolist() + else: + hot_order = torch.arange(num_experts, device='cuda')[logical_mask].tolist() + hot_experts = hot_order[:args.num_redundant_experts] + replica_for = {int(h): int(s) for h, s in zip(hot_experts, replica_slots)} + if rank_idx == 0 and print_perf: + print( + f'eplb_sim redundant={args.num_redundant_experts} ' + f'replicas_per_rank={num_replicas_per_rank} ' + f'dispatch={args.replica_dispatch}', + flush=True, + ) + + topk_w, topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False) + if replica_for and args.replica_dispatch == 'hash': + token_ids = torch.arange(num_tokens, device='cuda')[:, None] + slot_ids = torch.arange(num_topk, device='cuda')[None, :] + choose_replica = ((token_ids * num_topk + slot_ids + rank_idx) & 1).bool() + mapped = topk_idx.clone() + for logical_expert, replica_slot in replica_for.items(): + mapped = torch.where((topk_idx == logical_expert) & choose_replica, + torch.full_like(mapped, replica_slot), mapped) + topk_idx = mapped + elif replica_for and args.replica_dispatch == 'static': + mapped = topk_idx.clone() + for logical_expert, replica_slot in replica_for.items(): + logical_rank = logical_expert // num_experts_per_rank + replica_rank = replica_slot // num_experts_per_rank + if rank_idx == logical_rank: + chosen = logical_expert + elif rank_idx == replica_rank: + chosen = replica_slot + else: + chosen = replica_slot if ((rank_idx + logical_expert) & 1) else logical_expert + if chosen != logical_expert: + mapped = torch.where(topk_idx == logical_expert, + torch.full_like(mapped, chosen), mapped) + topk_idx = mapped + if args.masked_ratio > 0: + rand_mask = torch.rand_like(topk_idx, 
dtype=torch.float) + topk_idx.masked_fill_(rand_mask < args.masked_ratio, -1) + topk_w.masked_fill_(topk_idx < 0, 0) + + x_fp8, x_sf = per_token_cast_to_fp8(x_bf, use_ue8m0=False, gran_k=128, + use_packed_ue8m0=False) + l1_w_fp8, l1_w_sf = _quantize_grouped_fp8_block_128_128(l1_bf) + l2_w_fp8, l2_w_sf = _quantize_grouped_fp8_block_128_128(l2_bf) + transformed_l1, transformed_l2 = deep_gemm.transform_weights_for_mega_moe_sm90( + (l1_w_fp8, l1_w_sf), (l2_w_fp8, l2_w_sf), + ) + + cum_stats = torch.zeros(num_experts_per_rank, dtype=torch.int, device='cuda') + + def run_fused(): + buffer.x[:num_tokens].copy_(x_fp8) + buffer.x_sf[:num_tokens].copy_(x_sf) + buffer.topk_idx[:num_tokens].copy_(topk_idx) + buffer.topk_weights[:num_tokens].copy_(topk_w) + y = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') + deep_gemm.fp8_mega_moe( + y, transformed_l1, transformed_l2, buffer, + cumulative_local_expert_recv_stats=cum_stats, + recipe=(128, 128, 128), + activation='swiglu', + activation_clamp=activation_clamp, + fast_math=fast_math, + ) + return y + + run_fused() + dist.barrier() + t_fused = bench_kineto(run_fused, 'sm90_fp8_mega_moe', + barrier=lambda: dist.barrier(), + num_tests=args.num_tests, + suppress_kineto_output=True) + + # Per-rank token receive counts + per-local-expert distribution + gathered_topk_idx = uneven_all_gather(topk_idx, group=group) + local_mask = ((gathered_topk_idx >= rank_idx * num_experts_per_rank) & + (gathered_topk_idx < (rank_idx + 1) * num_experts_per_rank)) + gathered_local = gathered_topk_idx.clone() + gathered_local[~local_mask] = -1 + num_recv_tokens = int(local_mask.sum().item()) + num_touched_experts = max(0, int(torch.unique(gathered_local.flatten()).numel()) - 1) + + if num_recv_tokens > 0: + per_expert_count = torch.zeros(num_experts_per_rank, dtype=torch.int64, device='cuda') + local_ids = gathered_local[local_mask] - rank_idx * num_experts_per_rank + per_expert_count.scatter_add_(0, local_ids.to(torch.int64), + 
torch.ones_like(local_ids, dtype=torch.int64)) + counts = per_expert_count.cpu().tolist() + local_max = max(counts) + local_mean = sum(counts) / num_experts_per_rank + local_max_mean = local_max / local_mean if local_mean > 0 else 0.0 + else: + local_max_mean = 0.0 + + # Cross-rank aggregation + info = torch.tensor([t_fused, float(num_recv_tokens), float(num_touched_experts), + local_max_mean], device='cuda', dtype=torch.float64) + gather_buf = [torch.zeros_like(info) for _ in range(num_ranks)] + dist.all_gather(gather_buf, info, group=group) + + if rank_idx == 0 and print_perf: + all_t = [g[0].item() for g in gather_buf] + all_recv = [int(g[1].item()) for g in gather_buf] + all_touch = [int(g[2].item()) for g in gather_buf] + all_lmm = [g[3].item() for g in gather_buf] + t_mean = sum(all_t) / num_ranks + t_max = max(all_t) + t_min = min(all_t) + recv_mean = sum(all_recv) / num_ranks + recv_max = max(all_recv) + recv_min = min(all_recv) + inter_max_mean = t_max / t_mean if t_mean > 0 else 0.0 + # Aggregate TFLOPS based on max wall time (real serving sees max) + total_flops = 2 * sum(all_recv) * (hidden * intermediate_hidden * 3) + tflops_agg = total_flops / 1e12 / t_max if t_max > 0 else 0.0 + # Per-rank avg view: mean recv across ranks per mean time + tflops_mean = (2 * recv_mean * hidden * intermediate_hidden * 3) / 1e12 / t_mean if t_mean > 0 else 0.0 + + print(f'tokens={num_tokens:5d} ' + f'recv[mean/max/min]={recv_mean:6.0f}/{recv_max}/{recv_min} ' + f'local_max/mean[mean]={sum(all_lmm)/num_ranks:.2f} ' + f'inter_max/mean={inter_max_mean:.3f} ' + f't[mean/max/min]us={t_mean*1e6:7.1f}/{t_max*1e6:7.1f}/{t_min*1e6:7.1f} ' + f'TFLOPS_agg={tflops_agg:6.1f} TFLOPS_mean={tflops_mean:6.1f}', + flush=True) + + dist.barrier() + buffer.destroy() + + +def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace): + rank_idx, num_ranks, group = init_dist(local_rank, num_local_ranks) + torch.manual_seed(rank_idx) + random.seed(rank_idx) + + if get_arch_major() != 9: +
dist_print(f'[SKIP] requires SM90', once_in_node=True) + dist.destroy_process_group() + return + + batches = args.batches if args.batches else [1, 2, 4, 8, 16, 32] + + if rank_idx == 0: + print(f'SM90 MegaMoE bench: ranks={num_ranks} hidden={args.hidden} ' + f'ih={args.intermediate_hidden} experts={args.num_experts} ' + f'topk={args.num_topk} skew_alpha={args.skew_alpha} ' + f'skew_gain={args.skew_gain} masked_ratio={args.masked_ratio} ' + f'fast_math={bool(args.fast_math)}', flush=True) + + num_max_tokens_per_rank = max(batches) + for num_tokens in batches: + _run_one_config( + args, num_tokens, num_max_tokens_per_rank, + args.hidden, args.intermediate_hidden, + args.num_experts, args.num_topk, + num_ranks, rank_idx, group, + activation_clamp=args.activation_clamp, + fast_math=bool(args.fast_math), + ) + + dist.barrier() + dist.destroy_process_group() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='SM90 MegaMoE bench (skew-aware)') + + parser.add_argument('--ncu-profile-only', action='store_true') + parser.add_argument('--num-processes', type=int, default=8) + parser.add_argument('--local-rank-idx', type=int, default=None) + + parser.add_argument('--batches', type=int, nargs='+', default=None) + parser.add_argument('--hidden', type=int, default=4096) + parser.add_argument('--intermediate-hidden', type=int, default=2048) + parser.add_argument('--num-experts', type=int, default=256) + parser.add_argument('--num-topk', type=int, default=8) + parser.add_argument('--activation-clamp', type=float, default=10.0) + parser.add_argument('--masked-ratio', type=float, default=0.0) + parser.add_argument('--fast-math', type=int, default=1) + parser.add_argument('--num-tests', type=int, default=20) + parser.add_argument('--skew-alpha', type=float, default=0.0, + help='Zipf exponent for expert popularity; 0=uniform') + parser.add_argument('--skew-gain', type=float, default=1.0, + help='Multiplier on log-bias added to scores') + 
parser.add_argument('--skew-seed', type=int, default=0, + help='Seed for the global hot/cold expert permutation') + parser.add_argument('--num-redundant-experts', type=int, default=0, + help='Benchmark-only EPLB simulation: reserve physical expert slots as hot-expert replicas') + parser.add_argument('--replica-dispatch', choices=('hash', 'static'), default='hash', + help='Replica remap model: token-level hash or SGLang static source-rank approximation') + + args = parser.parse_args() + + if args.local_rank_idx is not None: + test(args.local_rank_idx, args.num_processes, args) + else: + np = args.num_processes + torch.multiprocessing.spawn(test, args=(np, args), nprocs=np) diff --git a/tests/bench_mega_moe_sm90_skewed.py b/tests/bench_mega_moe_sm90_skewed.py new file mode 100644 index 0000000000..eb72f870ba --- /dev/null +++ b/tests/bench_mega_moe_sm90_skewed.py @@ -0,0 +1,231 @@ +"""SM90 MegaMoE bench with power-law expert load skew. + +Adds --skew-alpha (Zipf exponent). The score generator biases the topk to +favor a power-law subset of experts: + + bias[i] = log( 1 / (rank[i]+1)^alpha * num_experts ) * gain + scores = N(0, 1) + bias + +alpha=0 => uniform (matches upstream bench). +alpha=1 => Zipfian harmonic — hot experts get ~ln(N) more tokens than cold. +alpha=2 => strong skew — top expert can get 5-10x mean. + +Per-rank expert permutation is randomized so hot experts spread across ranks +(matches production behavior where placement is shuffled). + +Reports observed skew via num_recv/num_touched and per-expert recv distribution. 
+""" +import argparse +import os +import random +import sys +import torch +import torch.distributed as dist +from typing import Tuple + +REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if REPO_ROOT not in sys.path: + sys.path.insert(0, REPO_ROOT) + +import deep_gemm +from deep_gemm.utils import per_token_cast_to_fp8 +from deep_gemm.utils.dist import dist_print, init_dist, uneven_all_gather +from deep_gemm.testing import bench_kineto, get_arch_major + + +def _quantize_grouped_fp8_block_128_128(w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + g, n, k = w.shape + assert n % 128 == 0 and k % 128 == 0 + w_view = w.view(g, n // 128, 128, k // 128, 128).float() + amax = w_view.abs().amax(dim=(-1, -3)).clamp(1e-4) + sf = amax / 448.0 + w_fp8 = (w_view / sf.unsqueeze(-1).unsqueeze(-3)).to(torch.float8_e4m3fn) + return w_fp8.view(g, n, k).contiguous(), sf.contiguous() + + +def _generate_skewed_scores(num_tokens: int, num_experts: int, num_topk: int, + alpha: float, gain: float, generator): + """Generate scores that produce a Zipfian topk distribution. + + Adds a per-expert bias derived from Zipfian rank to N(0,1) scores. The + final topk[i] is therefore drawn from a softer-than-Zipf but still + power-law-tailed distribution. + """ + scores = torch.randn((num_tokens, num_experts), dtype=torch.float, + device='cuda', generator=generator) + if alpha > 0: + # Zipf weights: probs[k] = 1 / (k+1)^alpha, normalized + ranks = torch.arange(1, num_experts + 1, device='cuda', dtype=torch.float) + probs = 1.0 / ranks.pow(alpha) + probs = probs / probs.sum() + # Bias scale relative to uniform expert prob + bias = torch.log(probs * num_experts) * gain + # Shuffle so hot experts are not concentrated at low ids (mimics + # production where init_expert_location randomizes placement). 
+ perm = torch.randperm(num_experts, device='cuda', generator=generator) + bias = bias[perm] + scores = scores + bias.unsqueeze(0) + return scores + + +def _run_one_config(args, num_tokens, num_max_tokens_per_rank, + hidden, intermediate_hidden, + num_experts, num_topk, num_ranks, rank_idx, group): + num_experts_per_rank = num_experts // num_ranks + assert num_tokens <= num_max_tokens_per_rank + + buffer = deep_gemm.get_symm_buffer_for_mega_moe( + group, num_experts, + num_max_tokens_per_rank, num_topk, + hidden, intermediate_hidden, + ) + + gen = torch.Generator(device='cuda') + gen.manual_seed(rank_idx * 1009 + int(num_tokens)) + + x_bf = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda', generator=gen) + l1_bf = torch.randn( + (num_experts_per_rank, intermediate_hidden * 2, hidden), + dtype=torch.bfloat16, device='cuda', generator=gen) * 0.05 + l2_bf = torch.randn( + (num_experts_per_rank, hidden, intermediate_hidden), + dtype=torch.bfloat16, device='cuda', generator=gen) * 0.05 + + scores = _generate_skewed_scores(num_tokens, num_experts, num_topk, + args.skew_alpha, args.skew_gain, gen) + topk_w, topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False) + + x_fp8, x_sf = per_token_cast_to_fp8(x_bf, use_ue8m0=False, gran_k=128, + use_packed_ue8m0=False) + l1_w_fp8, l1_w_sf = _quantize_grouped_fp8_block_128_128(l1_bf) + l2_w_fp8, l2_w_sf = _quantize_grouped_fp8_block_128_128(l2_bf) + transformed_l1, transformed_l2 = deep_gemm.transform_weights_for_mega_moe_sm90( + (l1_w_fp8, l1_w_sf), (l2_w_fp8, l2_w_sf), + ) + + cum_stats = torch.zeros(num_experts_per_rank, dtype=torch.int, device='cuda') + + def run_fused(): + buffer.x[:num_tokens].copy_(x_fp8) + buffer.x_sf[:num_tokens].copy_(x_sf) + buffer.topk_idx[:num_tokens].copy_(topk_idx) + buffer.topk_weights[:num_tokens].copy_(topk_w) + y = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') + deep_gemm.fp8_mega_moe( + y, transformed_l1, transformed_l2, buffer, 
+ cumulative_local_expert_recv_stats=cum_stats, + recipe=(128, 128, 128), + activation='swiglu', + activation_clamp=10.0, + fast_math=True, + ) + return y + + run_fused() + dist.barrier() + t_fused = bench_kineto(run_fused, 'sm90_fp8_mega_moe', + barrier=lambda: dist.barrier(), + num_tests=args.num_tests, + suppress_kineto_output=True) + + # Local expert count distribution + gathered_topk_idx = uneven_all_gather(topk_idx, group=group) + local_mask = ((gathered_topk_idx >= rank_idx * num_experts_per_rank) & + (gathered_topk_idx < (rank_idx + 1) * num_experts_per_rank)) + gathered_local = gathered_topk_idx.clone() + gathered_local[~local_mask] = -1 + num_recv_tokens = int(local_mask.sum().item()) + num_touched_experts = max(0, int(torch.unique(gathered_local.flatten()).numel()) - 1) + + # Per-local-expert recv count (for skew diagnostics) + if num_recv_tokens > 0: + per_expert_count = torch.zeros(num_experts_per_rank, dtype=torch.int64, device='cuda') + local_ids = gathered_local[local_mask] - rank_idx * num_experts_per_rank + per_expert_count.scatter_add_(0, local_ids.to(torch.int64), + torch.ones_like(local_ids, dtype=torch.int64)) + counts = per_expert_count.cpu().tolist() + mean_c = sum(counts) / num_experts_per_rank + max_c = max(counts) + min_c = min(counts) + nonzero = sum(1 for c in counts if c > 0) + skew_max_mean = max_c / mean_c if mean_c > 0 else 0.0 + else: + max_c = min_c = mean_c = 0 + skew_max_mean = 0.0 + nonzero = 0 + + safe_div = lambda a, b: float('nan') if b == 0 else a / b + tflops = safe_div(2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_fused) + num_hbm_bytes = ( + num_touched_experts * intermediate_hidden * 2 * hidden + + num_touched_experts * hidden * intermediate_hidden + + num_recv_tokens * hidden + + num_recv_tokens * intermediate_hidden + + num_recv_tokens * intermediate_hidden + + num_recv_tokens * hidden * 2 + ) + hbm_gbs = safe_div(num_hbm_bytes / 1e9, t_fused) + + dist_print( + f' tokens={num_tokens:5d} 
recv={num_recv_tokens:6d} nz_exp={nonzero:3d}/{num_experts_per_rank} ' + f'max/mean={skew_max_mean:.2f} ' + f'{t_fused * 1e6:7.1f} us {tflops:6.1f} TFLOPS {hbm_gbs:6.0f} GB/s (rank{rank_idx})', + once_in_node=True, + ) + + dist.barrier() + buffer.destroy() + + +def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace): + rank_idx, num_ranks, group = init_dist(local_rank, num_local_ranks) + torch.manual_seed(rank_idx) + random.seed(rank_idx) + + if get_arch_major() != 9: + dist_print(f'[SKIP] requires SM90', once_in_node=True) + dist.destroy_process_group() + return + + batches = args.batches if args.batches else [1, 2, 4, 8, 16, 32] + + dist_print( + f'SM90 MegaMoE bench (skew_alpha={args.skew_alpha} gain={args.skew_gain}): ' + f'ranks={num_ranks} hidden={args.hidden} ih={args.intermediate_hidden} ' + f'experts={args.num_experts} topk={args.num_topk} fast_math=True', + once_in_node=True, + ) + + num_max_tokens_per_rank = max(batches) + for num_tokens in batches: + _run_one_config( + args, num_tokens, num_max_tokens_per_rank, + args.hidden, args.intermediate_hidden, + args.num_experts, args.num_topk, + num_ranks, rank_idx, group, + ) + + dist.barrier() + dist.destroy_process_group() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='SM90 MegaMoE bench (skewed routing)') + parser.add_argument('--num-processes', type=int, default=8) + parser.add_argument('--local-rank-idx', type=int, default=None) + parser.add_argument('--batches', type=int, nargs='+', default=None) + parser.add_argument('--hidden', type=int, default=4096) + parser.add_argument('--intermediate-hidden', type=int, default=2048) + parser.add_argument('--num-experts', type=int, default=256) + parser.add_argument('--num-topk', type=int, default=8) + parser.add_argument('--num-tests', type=int, default=20) + parser.add_argument('--skew-alpha', type=float, default=0.0, + help='Zipf exponent for expert bias (0=uniform)') + parser.add_argument('--skew-gain', 
type=float, default=1.0, + help='Scale of log-bias added to scores (default 1.0)') + args = parser.parse_args() + if args.local_rank_idx is not None: + test(args.local_rank_idx, args.num_processes, args) + else: + np = args.num_processes + torch.multiprocessing.spawn(test, args=(np, args), nprocs=np) diff --git a/tests/bench_w4a8_mega_moe_sm90.py b/tests/bench_w4a8_mega_moe_sm90.py new file mode 100644 index 0000000000..0199fb27d3 --- /dev/null +++ b/tests/bench_w4a8_mega_moe_sm90.py @@ -0,0 +1,250 @@ +"""SM90 (Hopper) MegaMoE benchmark / NCU-profile harness. + +Mirrors ``tests/test_mega_moe.py``'s ``--ncu-profile-only`` / +``--local-rank-idx`` interface so the same ``scripts/run_ncu_mega_moe.sh`` +pattern can drive it for SM90. + +In normal (non-NCU) mode it sweeps a list of ``num_tokens`` values (default: +1, 2, 4, 8, 16, 32) and reports per-call kernel time via the same +``bench_kineto`` helper used by the SM100 perf test, plus a rough TFLOPS / +HBM GB/s figure useful for tracking optimisation deltas. 
+""" + +import argparse +import os +import random +import sys +import torch +import torch.distributed as dist +from typing import Tuple + +REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if REPO_ROOT not in sys.path: + sys.path.insert(0, REPO_ROOT) + +import deep_gemm +from deep_gemm.utils import per_token_cast_to_fp8 +from deep_gemm.utils.dist import dist_print, init_dist, uneven_all_gather +from deep_gemm.testing import bench_kineto, calc_diff, get_arch_major + + +def _quantize_grouped_fp8_block_128_128(w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + g, n, k = w.shape + assert n % 128 == 0 and k % 128 == 0 + w_view = w.view(g, n // 128, 128, k // 128, 128).float() + amax = w_view.abs().amax(dim=(-1, -3)).clamp(1e-4) + sf = amax / 448.0 + w_fp8 = (w_view / sf.unsqueeze(-1).unsqueeze(-3)).to(torch.float8_e4m3fn) + return w_fp8.view(g, n, k).contiguous(), sf.contiguous() + + +def _run_one_config(args, num_tokens, num_max_tokens_per_rank, + hidden, intermediate_hidden, + num_experts, num_topk, num_ranks, rank_idx, group, + activation_clamp, fast_math, + print_perf=True): + num_experts_per_rank = num_experts // num_ranks + assert num_tokens <= num_max_tokens_per_rank + + # Symmetric buffer (one per config: cheaper to recreate than to keep max-size) + buffer = deep_gemm.get_symm_buffer_for_mega_moe( + group, num_experts, + num_max_tokens_per_rank, num_topk, + hidden, intermediate_hidden, + ) + + # Inputs (bf16, then quantised) + x_bf = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') + l1_bf = torch.randn( + (num_experts_per_rank, intermediate_hidden * 2, hidden), + dtype=torch.bfloat16, device='cuda') * 0.05 + l2_bf = torch.randn( + (num_experts_per_rank, hidden, intermediate_hidden), + dtype=torch.bfloat16, device='cuda') * 0.05 + scores = torch.randn((num_tokens, num_experts), dtype=torch.float, device='cuda') + topk_w, topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False) + if 
args.masked_ratio > 0: + rand_mask = torch.rand_like(topk_idx, dtype=torch.float) + topk_idx.masked_fill_(rand_mask < args.masked_ratio, -1) + topk_w.masked_fill_(topk_idx < 0, 0) + + x_fp8, x_sf = per_token_cast_to_fp8(x_bf, use_ue8m0=False, gran_k=128, + use_packed_ue8m0=False) + l1_w_fp8, l1_w_sf = _quantize_grouped_fp8_block_128_128(l1_bf) + l2_w_fp8, l2_w_sf = _quantize_grouped_fp8_block_128_128(l2_bf) + transformed_l1, transformed_l2 = deep_gemm.transform_weights_for_mega_moe_sm90( + (l1_w_fp8, l1_w_sf), (l2_w_fp8, l2_w_sf), + ) + + phase_profile_enabled = os.environ.get('DG_SM90_MOE_PHASE_PROFILE', '0') != '0' + phase_profile_ints = 64 if phase_profile_enabled else 0 + cum_stats = torch.zeros(num_experts_per_rank + phase_profile_ints, dtype=torch.int, device='cuda') + + # Stage inputs once; bench-loop re-copies them each call (bench helper expects + # an idempotent ``fn``). + def run_fused(): + buffer.x[:num_tokens].copy_(x_fp8) + buffer.x_sf[:num_tokens].copy_(x_sf) + buffer.topk_idx[:num_tokens].copy_(topk_idx) + buffer.topk_weights[:num_tokens].copy_(topk_w) + y = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') + deep_gemm.w4a8_mega_moe( + y, transformed_l1, transformed_l2, buffer, + cumulative_local_expert_recv_stats=cum_stats, + recipe=(128, 128, 128), + activation='swiglu', + activation_clamp=activation_clamp, + fast_math=fast_math, + ) + return y + + if args.ncu_profile_only: + dist_print(f'[NCU] tokens={num_tokens} hidden={hidden} ih={intermediate_hidden}', + once_in_node=True) + run_fused() + torch.cuda.synchronize() + dist.barrier() + buffer.destroy() + return + + # Warm up + benchmark + run_fused() + dist.barrier() + if phase_profile_enabled: + cum_stats.zero_() + torch.cuda.synchronize() + dist.barrier() + # NSYS MULTI-ITER (aichenf): N timed iters with barrier+sleep between them. 
+ # bench_kineto returns 1 under DG_USE_NVIDIA_TOOLS=1, but this loop puts + # multiple mega_moe instances on the nsys timeline so we can measure variance. + import os as _os + _nsys_iters = int(_os.environ.get('NSYS_ITERS', '0')) + if _nsys_iters > 0: + for _it in range(_nsys_iters): + torch.cuda.synchronize() + dist.barrier() + torch.cuda._sleep(int(2e7)) # 10ms gap between iters + dist.barrier() + run_fused() + torch.cuda.synchronize() + dist.barrier() + t_fused = bench_kineto(run_fused, 'sm90_w4a8_mega_moe', + barrier=lambda: dist.barrier(), + num_tests=args.num_tests, + suppress_kineto_output=True) + + # Count tokens that landed on this rank for stats + gathered_topk_idx = uneven_all_gather(topk_idx, group=group) + gathered_topk_idx[(gathered_topk_idx < rank_idx * num_experts_per_rank) | + (gathered_topk_idx >= (rank_idx + 1) * num_experts_per_rank)] = -1 + num_recv_tokens = (gathered_topk_idx != -1).sum().item() + + safe_div = lambda a, b: float('nan') if b == 0 else a / b + tflops = safe_div(2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_fused) + num_touched_experts = max(0, torch.unique(gathered_topk_idx.flatten()).numel() - 1) + # FP8 weights = 1 byte, FP8 acts = 1 byte, BF16 output = 2 bytes + num_hbm_bytes = ( + num_touched_experts * intermediate_hidden * 2 * hidden + # L1 weights + num_touched_experts * hidden * intermediate_hidden + # L2 weights + num_recv_tokens * hidden + # L1 acts read + num_recv_tokens * intermediate_hidden + # L1 out write + num_recv_tokens * intermediate_hidden + # L2 acts read + num_recv_tokens * hidden * 2 # L2 out write + ) + hbm_gbs = safe_div(num_hbm_bytes / 1e9, t_fused) + + if print_perf: + dist_print( + f' tokens={num_tokens:4d} recv={num_recv_tokens:5d} experts={num_touched_experts:4d} ' + f'{t_fused * 1e6:7.1f} us {tflops:6.1f} TFLOPS {hbm_gbs:6.0f} GB/s (rank{rank_idx})', + once_in_node=True, + ) + if phase_profile_enabled: + torch.cuda.synchronize() + profile = 
cum_stats[num_experts_per_rank:num_experts_per_rank + 48].view(torch.int64).cpu().tolist() + names = [ + 'dispatch_total', 'dispatch_pull', 'math_loop', 'combine_barrier', + 'combine_reduce', 'gemm_core', 'l1_epilogue', 'l2_epilogue', + ] + for i, name in enumerate(names): + total, max_v, count = profile[i], profile[8 + i], profile[16 + i] + avg = float(total) / count if count else 0.0 + dist_print( + f' phase {name:16s} avg={avg:10.0f} max={max_v:10d} count={count}', + once_in_node=True, + ) + + dist.barrier() + buffer.destroy() + + +def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace): + rank_idx, num_ranks, group = init_dist(local_rank, num_local_ranks) + forced_num_sms = int(os.environ.get('DG_SM90_MOE_SET_NUM_SMS', '0')) + if forced_num_sms > 0: + deep_gemm.set_num_sms(forced_num_sms) + torch.manual_seed(rank_idx) + random.seed(rank_idx) + + if get_arch_major() != 9: + dist_print(f'[SKIP] requires SM90, got SM{get_arch_major()}0', once_in_node=True) + dist.destroy_process_group() + return + + if args.batches is None: + batches = [1, 2, 4, 8, 16, 32] + else: + batches = args.batches + + dist_print( + f'SM90 MegaMoE bench: ranks={num_ranks} hidden={args.hidden} ' + f'ih={args.intermediate_hidden} experts={args.num_experts} topk={args.num_topk} ' + f'masked_ratio={args.masked_ratio} fast_math={bool(args.fast_math)}', + once_in_node=True, + ) + + # In NCU mode we run only one batch (the first one in `batches`) so that + # ncu's `--launch-count 1` is unambiguous. 
+ if args.ncu_profile_only: + batches = batches[:1] + + num_max_tokens_per_rank = max(batches) + for num_tokens in batches: + _run_one_config( + args, num_tokens, num_max_tokens_per_rank, + args.hidden, args.intermediate_hidden, + args.num_experts, args.num_topk, + num_ranks, rank_idx, group, + activation_clamp=args.activation_clamp, + fast_math=bool(args.fast_math), + ) + + dist.barrier() + dist.destroy_process_group() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='SM90 MegaMoE benchmark') + + parser.add_argument('--ncu-profile-only', action='store_true') + parser.add_argument('--num-processes', type=int, default=8) + parser.add_argument('--local-rank-idx', type=int, default=None) + + parser.add_argument('--batches', type=int, nargs='+', default=None, + help='List of num_tokens to sweep (default: 1 2 4 8 16 32)') + parser.add_argument('--hidden', type=int, default=7168) + parser.add_argument('--intermediate-hidden', type=int, default=2048) + parser.add_argument('--num-experts', type=int, default=256) + parser.add_argument('--num-topk', type=int, default=8) + parser.add_argument('--activation-clamp', type=float, default=10.0) + parser.add_argument('--masked-ratio', type=float, default=0.0) + parser.add_argument('--fast-math', type=int, default=1) + parser.add_argument('--num-tests', type=int, default=20) + + args = parser.parse_args() + + if args.local_rank_idx is not None: + test(args.local_rank_idx, args.num_processes, args) + else: + np = args.num_processes + torch.multiprocessing.spawn(test, args=(np, args), nprocs=np) diff --git a/tests/test_mega_moe.py b/tests/test_mega_moe.py index e74b65e5d1..83e8d622f7 100644 --- a/tests/test_mega_moe.py +++ b/tests/test_mega_moe.py @@ -151,7 +151,7 @@ def run_fused(): num_topk=num_topk, use_fp8_dispatch=True, explicitly_destroy=True, allow_multiple_reduction=False, - gpu_timeout_secs=10, cpu_timeout_secs=30 + num_gpu_timeout_secs=10, num_cpu_timeout_secs=30 ) if is_legacy_loaded else None 
def run_baseline(): diff --git a/tests/test_mega_moe_hopper.py b/tests/test_mega_moe_hopper.py new file mode 100644 index 0000000000..021382710e --- /dev/null +++ b/tests/test_mega_moe_hopper.py @@ -0,0 +1,1525 @@ +""" +H200 (SM90 / Hopper) mega-MoE: fused kernel + 同管线 baseline 性能对比。 + +结构对齐 tests/test_mega_moe.py(B 系列 SM100 FP4 路径),但所有路径都换成 H200 FP8: + * fused:调用 `deep_gemm.fp8_mega_moe`(kernel symbol `sm90_fp8_mega_moe_impl`), + 使用 `transform_weights_for_mega_moe_sm90` 处理过的权重 + SymmBuffer。 + * baseline:DeepEP dispatch + 2 个 grouped FP8 GEMM + Triton SwiGLU + DeepEP combine, + 使用未变换的权重。由于当前 SM90 grouped GEMM 只支持 L2 activation + per-128-K SFA,而 fused SM90 mega-MoE 的 L1 epilogue 为避免跨 CTA + 同步使用 per-64-K SFA,所以该 baseline 是同管线 legacy 参照, + 不是 bitwise apples-to-apples correctness oracle。 + * 性能输出涵盖:TFLOPS / overlap TFLOPS / HBM GB/s / NVL GB/s / fused us / + reduction us / `t_baseline / t_fused` legacy 比。 +""" + +import argparse +import math +import os +import random +import torch +import torch.distributed as dist +import triton +import triton.language as tl +from typing import Tuple + +import deep_gemm +from deep_gemm.utils import per_token_cast_to_fp8 +from deep_gemm.utils.dist import dist_print, init_dist, uneven_all_gather +from deep_gemm.testing import bench_kineto, get_arch_major + +try: + import deep_ep as _deep_ep + _deep_ep_import_error = None +except Exception as ex: + _deep_ep = None + _deep_ep_import_error = ex + + +# 与 deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh 中模板入口同名, +# bench_kineto 用它从 trace 里挑出 fused mega-MoE 的 GPU 段 +SM90_KERNEL_NAME = "sm90_fp8_mega_moe_impl" + + +# FP8 e4m3fn 的最大可表示值,量化时用 amax / 448 作为 scale 基准 +FP8_E4M3_MAX = 448.0 +# 新版 Triton(>= 3.x)强制:jit 内核读到的 Python 全局必须是 tl.constexpr 实例, +# 否则编译期 NameError。宿主 Python 侧仍用上面的普通 float 做 torch 运算。 +_FP8_E4M3_MAX_TL = tl.constexpr(448.0) +L1_ACT_SF_GRAN = 128 +FUSED_L2_ACT_SF_GRAN = 64 +BASELINE_L2_ACT_SF_GRAN = 128 +WEIGHT_SF_GRAN_MN = 128 +WEIGHT_SF_GRAN_K = 128 + + +# 
============================================================================ +# 模块 1:Triton SwiGLU + FP8 量化内核 +# ---------------------------------------------------------------------------- +# baseline 的 L2 仍走 DeepGEMM SM90 grouped FP8 GEMM,所以 activation SFA 只能按 +# per-128-K 输入;但 scale 数值采用 fused epilogue 同款 UE8M0/power-of-two 规则, +# 避免再额外引入 exact-FP32-scale 差异。 +# 输入 x : (M, 2*H) bf16,内层是 [gate_part | up_part] +# 输入 topk_w : (M,) fp32,可选 +# 输出 y : (M, H) fp8_e4m3fn +# 输出 y_sf : (M, H/BLOCK_K) fp32 行主序 +# ============================================================================ + + +@triton.jit +def _swiglu_apply_weight_to_fp8_kernel( + x_ptr, + topk_w_ptr, + y_ptr, + y_sf_ptr, + M, + H, # 运行时形状 + stride_xm, + stride_xn, # x: (M, 2H) 的 stride + stride_ym, + stride_yn, # y: (M, H) 的 stride + stride_sfm, + stride_sfk, # y_sf: (M, H/BLOCK_K) 的 stride + clamp_value, # 当 HAS_CLAMP=False 时这个参数无意义 + HAS_TOPK: tl.constexpr, + HAS_CLAMP: tl.constexpr, + USE_UE8M0_SCALE: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_K: tl.constexpr, # = num_per_channels +): + # 一个 program 处理 (BLOCK_M 个 token) × (第 pid_k 个 K-block 的 BLOCK_K 列) + pid_m = tl.program_id(0) + pid_k = tl.program_id(1) + + # 行索引:本 program 负责 [pid_m*BLOCK_M, pid_m*BLOCK_M+BLOCK_M) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + # 当前 K-block 内的列索引(在 H 维度,不是 2H) + offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K) + mask_m = offs_m < M + + # ---- 1) 载入 gate(x 的前半段 [0, H))和 up(x 的后半段 [H, 2H))---- + # 注意 stride_xn 是元素 stride(一般 == 1),但 H + offs_k 偏移是按"元素"算的 + gate_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xn + up_ptrs = x_ptr + offs_m[:, None] * stride_xm + (H + offs_k[None, :]) * stride_xn + gate = tl.load(gate_ptrs, mask=mask_m[:, None], other=0.0).to(tl.float32) + up = tl.load(up_ptrs, mask=mask_m[:, None], other=0.0).to(tl.float32) + + # ---- 2) 可选 clamp(参考 tilelang 实现:gate 单边 max,up 双边)---- + if HAS_CLAMP: + gate = tl.minimum(gate, clamp_value) + up = tl.minimum(tl.maximum(up, 
-clamp_value), clamp_value) + + # ---- 3) SwiGLU:silu(gate) * up = gate * sigmoid(gate) * up(全程 FP32 累计)---- + y = gate * tl.sigmoid(gate) * up + + # ---- 4) 可选 MoE 权重缩放(per-token 标量)---- + if HAS_TOPK: + w = tl.load(topk_w_ptr + offs_m, mask=mask_m, other=1.0) + y = y * w[:, None] + + # ---- 5) 当前 K-block 内每行 absmax → scale ---- + amax = tl.max(tl.abs(y), axis=1) # (BLOCK_M,) + sf = tl.maximum(amax / _FP8_E4M3_MAX_TL, 1.0e-30) + if USE_UE8M0_SCALE: + # 对齐 deep_gemm/common/math.cuh::get_e4m3_sf_and_sf_inv: + # scale = 2 ** ceil(log2(amax / 448)). + sf = tl.exp2(tl.ceil(tl.log2(sf))) + + # ---- 6) 量化为 FP8 e4m3fn ---- + y_fp8 = (y / sf[:, None]).to(tl.float8e4nv) + + # ---- 7) 写回 y 和 sf ---- + y_ptrs = y_ptr + offs_m[:, None] * stride_ym + offs_k[None, :] * stride_yn + tl.store(y_ptrs, y_fp8, mask=mask_m[:, None]) + + sf_ptrs = y_sf_ptr + offs_m * stride_sfm + pid_k * stride_sfk + tl.store(sf_ptrs, sf, mask=mask_m) + + +def swiglu_apply_weight_to_fp8_triton( + x: torch.Tensor, + topk_weights: torch.Tensor | None, + clamp_value: float | None = None, + num_per_channels: int = BASELINE_L2_ACT_SF_GRAN, + use_ue8m0_scale: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + """SwiGLU + FP8 量化。语义等价于 PyTorch reference: + gate, up = x[:, :H], x[:, H:] + y = silu(gate.clamp(max=c)) * up.clamp(-c, c) * topk_w + y_sf = y.view(M, H/np, np).abs().amax(-1) / 448 + if use_ue8m0_scale: y_sf = ceil_to_power_of_2(y_sf) + y_fp8 = (y / y_sf.unsqueeze(-1)).to(fp8) + """ + assert x.is_cuda and x.dtype == torch.bfloat16 + assert x.is_contiguous(), "当前实现假设 x 是 contiguous 的,避免 stride 计算错位" + M, two_H = x.shape + H = two_H // 2 + assert H % num_per_channels == 0, f"H={H} 必须是 {num_per_channels} 的整数倍" + + y = torch.empty((M, H), dtype=torch.float8_e4m3fn, device=x.device) + y_sf = torch.empty((M, H // num_per_channels), dtype=torch.float32, device=x.device) + + # BLOCK_M 取 16:内核每个 program 处理 16 个 token × 128 列,寄存器压力小、容易调 + BLOCK_M = 16 + grid = (triton.cdiv(M, BLOCK_M), H // 
num_per_channels) + + # HAS_TOPK=False 时仍要传一个有效指针(Triton 不允许 nullptr),用 x 占位 + topk_ptr = topk_weights if topk_weights is not None else x + + _swiglu_apply_weight_to_fp8_kernel[grid]( + x, + topk_ptr, + y, + y_sf, + M, + H, + x.stride(0), + x.stride(1), + y.stride(0), + y.stride(1), + y_sf.stride(0), + y_sf.stride(1), + float(clamp_value) if clamp_value is not None else 0.0, + HAS_TOPK=topk_weights is not None, + HAS_CLAMP=clamp_value is not None, + USE_UE8M0_SCALE=use_ue8m0_scale, + BLOCK_M=BLOCK_M, + BLOCK_K=num_per_channels, + ) + return y, y_sf + + +# ============================================================================ +# 模块 2:grouped weight 的 (128, 128) FP8 块量化 +# ---------------------------------------------------------------------------- +# m_grouped_fp8_gemm_nt_contiguous 在 SM90 上对 weight 的输入约定: +# 每 (128, 128) 子块共享一个 FP32 SF,K 是 SF 的内层连续维(K-major)。 +# 与 SM100 FP4 路径的差异: +# * 不需要 deep_gemm.transform_sf_into_required_layout +# * SF 是 FP32,不是 UE8M0 packed +# ============================================================================ + + +def _quantize_grouped_fp8_block_128_128( + w: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """(G, N, K) bf16 → (G, N, K) fp8_e4m3fn + (G, N//128, K//128) fp32 SF。""" + g, n, k = w.shape + assert n % 128 == 0 and k % 128 == 0, f"weight 的 N={n}, K={k} 都必须是 128 的倍数" + + # 把 (N, K) 切成 (N/128, 128, K/128, 128),最后一维和倒数第三维就是 128×128 子块内部 + w_view = w.view(g, n // 128, 128, k // 128, 128).float() + + # 子块内 absmax → scale = amax / 448,clamp(1e-4) 避免全 0 子块 + amax = w_view.abs().amax(dim=(-1, -3)).clamp(1e-4) # (G, N/128, K/128) + sf = amax / FP8_E4M3_MAX + + # 量化:每个元素除以所属子块的 sf 后转 FP8 + # sf 形状 (G, N/128, K/128),需在 N-内 (axis -3) 和 K-内 (axis -1) 都补维度 + w_fp8 = (w_view / sf.unsqueeze(-1).unsqueeze(-3)).to(torch.float8_e4m3fn) + return w_fp8.view(g, n, k).contiguous(), sf.contiguous() + + +# ============================================================================ +# 模块 3:尝试导入 deep_ep(用于 dispatch / combine) +# 
============================================================================ + + +def _import_deep_ep(): + if _deep_ep is None: + dist_print(f"Failed to import deep_ep: {_deep_ep_import_error}", once_in_node=True) + return None + return _deep_ep + + +class _DeepEPHandle: + def __init__(self, raw_handle, psum_num_recv_tokens_per_expert: torch.Tensor): + self.raw_handle = raw_handle + self.psum_num_recv_tokens_per_expert = psum_num_recv_tokens_per_expert + + +class _DeepEPBufferCompat: + """Compatibility shim for newer DeepEP versions that expose Buffer, not ElasticBuffer.""" + + def __init__(self, deep_ep, group, num_nvl_bytes: int): + self.buffer = deep_ep.Buffer( + group, + num_nvl_bytes=num_nvl_bytes, + num_rdma_bytes=0, + explicitly_destroy=True, + ) + + def dispatch( + self, + x, + *, + topk_idx, + topk_weights, + num_experts: int, + expert_alignment: int, + **_, + ): + num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event = ( + self.buffer.get_dispatch_layout(topk_idx, num_experts) + ) + recv_x, _, recv_topk_weights, num_recv_tokens_per_expert, raw_handle, event = self.buffer.dispatch( + x, + num_tokens_per_rank=num_tokens_per_rank, + num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, + is_token_in_rank=is_token_in_rank, + num_tokens_per_expert=num_tokens_per_expert, + topk_idx=topk_idx, + topk_weights=topk_weights, + expert_alignment=expert_alignment, + ) + psum = torch.tensor( + num_recv_tokens_per_expert, dtype=torch.int, device=topk_idx.device + ).cumsum(dim=0, dtype=torch.int) + return recv_x, None, recv_topk_weights, _DeepEPHandle(raw_handle, psum), event + + def combine(self, x, *, handle): + raw_handle = handle.raw_handle if isinstance(handle, _DeepEPHandle) else handle + return self.buffer.combine(x, handle=raw_handle) + + def barrier(self, use_comm_stream: bool = False): + torch.cuda.synchronize() + dist.barrier() + + def destroy(self): + self.buffer.destroy() + + +def _make_deep_ep_buffer(deep_ep, group, 
num_max_tokens_per_rank, hidden, num_topk, sym_buffer_bytes): + if hasattr(deep_ep, "ElasticBuffer"): + return deep_ep.ElasticBuffer( + group, + num_max_tokens_per_rank=num_max_tokens_per_rank, + hidden=hidden, + num_topk=num_topk, + use_fp8_dispatch=True, + explicitly_destroy=True, + allow_multiple_reduction=False, + ) + nvl_alignment = 2 * 1024 * 1024 + num_nvl_bytes = ((int(sym_buffer_bytes) + nvl_alignment - 1) // nvl_alignment) * nvl_alignment + return _DeepEPBufferCompat(deep_ep, group, num_nvl_bytes=num_nvl_bytes) + + +# ============================================================================ +# 模块 4:CUDA event 中位数测时(避开对 tilelang.do_bench 的依赖) +# ============================================================================ + + +def _bench_cuda_events( + fn, num_warmup: int = 5, num_repeat: int = 20, l2_flush_gb: float = 8.0 +) -> float: + """返回 fn 的中位数耗时(秒)。""" + for _ in range(num_warmup): + fn() + torch.cuda.synchronize() + times_ms = [] + for _ in range(num_repeat): + # L2 flush,避免重复访问命中 cache 让测时偏低 + if l2_flush_gb > 0: + free_bytes, _ = torch.cuda.mem_get_info() + flush_bytes = min(int(l2_flush_gb * 1e9), int(free_bytes * 0.5)) + if flush_bytes >= 4: + torch.empty(flush_bytes // 4, dtype=torch.int, device="cuda").zero_() + s = torch.cuda.Event(enable_timing=True) + e = torch.cuda.Event(enable_timing=True) + s.record() + fn() + e.record() + e.synchronize() + times_ms.append(s.elapsed_time(e)) + times_ms.sort() + return times_ms[len(times_ms) // 2] / 1e3 + + +# ============================================================================ +# 模块 5:test() 主入口 — 在每个 rank 上跑一遍 baseline +# ============================================================================ + + +def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace): + # 初始化分布式:rank_idx 是全局 rank,group 是默认 NCCL group + rank_idx, num_ranks, group = init_dist(local_rank, num_local_ranks) + torch.manual_seed(rank_idx) + random.seed(rank_idx) + + if get_arch_major() != 9: + dist_print( + 
f"[SKIP] test_mega_moe_hopper requires SM90; got SM{get_arch_major()}0", + once_in_node=True, + ) + dist.destroy_process_group() + return + + # 形状参数(与 test_mega_moe.py 同名同义) + num_max_tokens_per_rank = args.num_max_tokens_per_rank + num_tokens = ( + max( + 0, + args.num_max_tokens_per_rank + - random.randint(0, args.num_max_removed_tokens), + ) + if args.num_tokens == 0 + else args.num_tokens + ) + hidden, intermediate_hidden = args.hidden, args.intermediate_hidden + num_experts, num_topk = args.num_experts, args.num_topk + num_experts_per_rank = num_experts // num_ranks + assert num_tokens <= num_max_tokens_per_rank + assert num_experts % num_ranks == 0, ( + f"num_experts={num_experts} 必须能被 num_ranks={num_ranks} 整除" + ) + + # SM90 fused kernel 的形状约束(来自 csrc/apis/mega.hpp::fp8_mega_moe): + # * H、IH 必须是 128 的倍数(L1 input per-128-K SF + block-(128,128) weight SF) + # * IH/64 ≤ 64 → IH ≤ 4096(l2_arrival_mask 是 uint64,每 bit 对应 64 列) + assert hidden % 128 == 0 + assert intermediate_hidden % 128 == 0 + assert intermediate_hidden // 64 <= 64, ( + f"SM90 fused kernel 要求 intermediate_hidden <= 4096, 当前 {intermediate_hidden}" + ) + + # ---- 创建 BF16 输入:token 与两层 weight ---- + # x: 每 rank 本地 num_tokens 个 token,每个 token hidden 维 + x_bf16 = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") + # L1 weight: 每个 expert 把 hidden → 2*intermediate_hidden(gate 和 up 拼一起) + l1_weights_bf16 = torch.randn( + (num_experts_per_rank, intermediate_hidden * 2, hidden), + dtype=torch.bfloat16, + device="cuda", + ) + # L2 weight: 每个 expert 把 intermediate_hidden → hidden + l2_weights_bf16 = torch.randn( + (num_experts_per_rank, hidden, intermediate_hidden), + dtype=torch.bfloat16, + device="cuda", + ) + + # 路由:scores → topk_idx (M, K) + topk_weights (M, K) + eplb_replica_for = {} + eplb_replica_slots: list[int] = [] + if args.num_redundant_experts > 0: + assert args.num_redundant_experts % num_ranks == 0, "num_redundant_experts must divide num_ranks" + num_replicas_per_rank = 
args.num_redundant_experts // num_ranks + assert 0 < num_replicas_per_rank < num_experts_per_rank, "invalid redundant expert count" + for r in range(num_ranks): + base = r * num_experts_per_rank + eplb_replica_slots += list(range(base + num_experts_per_rank - num_replicas_per_rank, + base + num_experts_per_rank)) + logical_mask = torch.ones(num_experts, dtype=torch.bool, device="cuda") + logical_mask[torch.tensor(eplb_replica_slots, dtype=torch.long, device="cuda")] = False + if args.score_powerlaw_alpha > 0: + expert_rank = torch.arange(1, num_experts + 1, dtype=torch.float, device="cuda") + bias_for_hot = torch.pow(expert_rank, -args.score_powerlaw_alpha) + bias_for_hot = (bias_for_hot - bias_for_hot.mean()) / (bias_for_hot.std() + 1e-6) + hot_order = torch.argsort(bias_for_hot.masked_fill(~logical_mask, -float("inf")), descending=True).cpu().tolist() + else: + hot_order = torch.arange(num_experts, device="cuda")[logical_mask].cpu().tolist() + hot_experts = hot_order[:args.num_redundant_experts] + eplb_replica_for = {int(h): int(s) for h, s in zip(hot_experts, eplb_replica_slots)} + if rank_idx == 0: + print( + f" > eplb_sim redundant={args.num_redundant_experts} " + f"replicas_per_rank={num_replicas_per_rank} " + f"dispatch={args.replica_dispatch}", + flush=True, + ) + + def make_scores(): + scores = torch.randn((num_tokens, num_experts), dtype=torch.float, device="cuda") + if args.score_powerlaw_alpha > 0: + expert_rank = torch.arange(1, num_experts + 1, dtype=torch.float, device="cuda") + bias = torch.pow(expert_rank, -args.score_powerlaw_alpha) + bias = (bias - bias.mean()) / (bias.std() + 1e-6) + scores = scores + args.score_powerlaw_scale * bias[None, :] + if eplb_replica_slots: + scores[:, torch.tensor(eplb_replica_slots, dtype=torch.long, device="cuda")] = -float("inf") + return scores + + def apply_eplb_replicas(idx: torch.Tensor) -> torch.Tensor: + if not eplb_replica_for: + return idx + mapped = idx.clone() + if args.replica_dispatch == "hash": + 
token_ids = torch.arange(num_tokens, device="cuda")[:, None] + slot_ids = torch.arange(num_topk, device="cuda")[None, :] + choose_replica = ((token_ids * num_topk + slot_ids + rank_idx) & 1).bool() + for logical_expert, replica_slot in eplb_replica_for.items(): + mapped = torch.where((idx == logical_expert) & choose_replica, + torch.full_like(mapped, replica_slot), mapped) + elif args.replica_dispatch == "static": + for logical_expert, replica_slot in eplb_replica_for.items(): + logical_rank = logical_expert // num_experts_per_rank + replica_rank = replica_slot // num_experts_per_rank + if rank_idx == logical_rank: + chosen = logical_expert + elif rank_idx == replica_rank: + chosen = replica_slot + else: + chosen = replica_slot if ((rank_idx + logical_expert) & 1) else logical_expert + if chosen != logical_expert: + mapped = torch.where(idx == logical_expert, + torch.full_like(mapped, chosen), mapped) + else: + raise ValueError(f"unknown replica_dispatch={args.replica_dispatch}") + return mapped + + if args.routing_mode == "balanced": + assert args.masked_ratio == 0.0, "balanced routing does not support masked_ratio" + assert (num_tokens * num_topk) % num_experts == 0, "balanced routing requires M*topk divisible by num_experts" + token_ids = torch.arange(num_tokens, device="cuda", dtype=torch.long)[:, None] + topk_offsets = torch.arange(num_topk, device="cuda", dtype=torch.long)[None, :] + topk_idx = (token_ids * num_topk + topk_offsets) % num_experts + topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float, device="cuda") + elif args.routing_mode in ("balanced-shuffled", "balanced-shuffled-score"): + assert args.masked_ratio == 0.0, f"{args.routing_mode} routing does not support masked_ratio" + assert num_tokens % num_experts == 0, f"{args.routing_mode} requires M divisible by num_experts" + assert num_experts % num_topk == 0, f"{args.routing_mode} requires experts divisible by topk" + token_perm = torch.randperm(num_tokens, device="cuda") + 
expert_perm = torch.randperm(num_experts, device="cuda") + positions = torch.arange(num_tokens, device="cuda", dtype=torch.long) + slot_stride = num_experts // num_topk + topk_idx = torch.empty((num_tokens, num_topk), dtype=torch.long, device="cuda") + for slot in range(num_topk): + expert_ids = expert_perm[(positions + slot * slot_stride) % num_experts] + topk_idx[token_perm, slot] = expert_ids + if args.routing_mode == "balanced-shuffled-score": + scores = make_scores() + topk_weights = torch.gather(scores, 1, topk_idx) + else: + topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float, device="cuda") + elif args.routing_mode == "topk-repair-budget-softmax": + assert args.masked_ratio == 0.0, f"{args.routing_mode} routing does not support masked_ratio" + assert (num_tokens * num_topk) % num_experts == 0, f"{args.routing_mode} requires exact local expert capacity" + scores = make_scores() + topk_weights, topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False) + expert_capacity = (num_tokens * num_topk) // num_experts + scores_cpu = scores.cpu() + probs_cpu = torch.softmax(scores_cpu, dim=1) + selected = topk_idx.cpu().tolist() + original_selected = [row[:] for row in selected] + selected_sets = [set(row) for row in selected] + counts = [0] * num_experts + for row in selected: + for expert in row: + counts[expert] += 1 + + selected_mass = 0.0 + for token, row in enumerate(selected): + selected_mass += sum(float(probs_cpu[token, expert].item()) for expert in row) + original_mass = selected_mass + changed_slots = 0 + touched_tokens = set() + budget = args.repair_mass_drop_budget + + while True: + overflow = {e for e, c in enumerate(counts) if c > expert_capacity} + underfull = [e for e, c in enumerate(counts) if c < expert_capacity] + if not overflow or not underfull: + break + candidates = [] + for token, row in enumerate(selected): + token_set = selected_sets[token] + for slot, old_expert in enumerate(row): + if old_expert not in 
overflow: + continue + old_score = float(scores_cpu[token, old_expert].item()) + old_mass = float(probs_cpu[token, old_expert].item()) + for new_expert in underfull: + if new_expert in token_set: + continue + new_mass = float(probs_cpu[token, new_expert].item()) + loss = old_score - float(scores_cpu[token, new_expert].item()) + mass_loss = old_mass - new_mass + candidates.append((loss, mass_loss, token, slot, old_expert, new_expert)) + if not candidates: + break + candidates.sort(key=lambda x: x[0]) + changed = False + for _, mass_loss, token, slot, old_expert, new_expert in candidates: + if counts[old_expert] <= expert_capacity or counts[new_expert] >= expert_capacity: + continue + if selected[token][slot] != old_expert or new_expert in selected_sets[token]: + continue + next_mass = selected_mass - mass_loss + next_drop = (original_mass - next_mass) / max(original_mass, 1e-12) + if next_drop > budget: + continue + selected[token][slot] = new_expert + selected_sets[token].remove(old_expert) + selected_sets[token].add(new_expert) + counts[old_expert] -= 1 + counts[new_expert] += 1 + selected_mass = next_mass + changed_slots += 1 + touched_tokens.add(token) + changed = True + if not changed: + break + + topk_idx = torch.tensor(selected, dtype=torch.long, device="cuda") + topk_weights = torch.softmax(torch.gather(scores, 1, topk_idx), dim=-1) + mass_drop = (original_mass - selected_mass) / max(original_mass, 1e-12) + over_slots = sum(max(0, c - expert_capacity) for c in counts) + max_count = max(counts) + dist_print( + f" > bounded_repair rank={rank_idx}: budget={budget:.3f} " + f"changed={changed_slots / max(num_tokens * num_topk, 1) * 100:.1f}% " + f"touched={len(touched_tokens) / max(num_tokens, 1) * 100:.1f}% " + f"mass_drop={mass_drop * 100:.1f}% over_slots={over_slots} max_count={max_count}", + once_in_node=False, + ) + elif args.routing_mode == "global-repair-budget-softmax": + assert args.masked_ratio == 0.0, f"{args.routing_mode} routing does not support 
masked_ratio" + scores = make_scores() + all_scores = uneven_all_gather(scores, group=group) + local_num_tokens = torch.tensor([num_tokens], dtype=torch.long, device="cuda") + all_num_tokens_t = [torch.zeros_like(local_num_tokens) for _ in range(num_ranks)] + dist.all_gather(all_num_tokens_t, local_num_tokens, group=group) + all_num_tokens = [int(x.item()) for x in all_num_tokens_t] + local_offset = sum(all_num_tokens[:rank_idx]) + total_num_tokens = sum(all_num_tokens) + assert (total_num_tokens * num_topk) % num_experts == 0, f"{args.routing_mode} requires exact global expert capacity" + + all_topk_weights, all_topk_idx = torch.topk(all_scores, num_topk, dim=-1, largest=True, sorted=False) + expert_capacity = (total_num_tokens * num_topk) // num_experts + scores_cpu = all_scores.cpu() + probs_cpu = torch.softmax(scores_cpu, dim=1) + score_order = torch.argsort(all_scores, dim=1, descending=True).cpu().tolist() + selected = all_topk_idx.cpu().tolist() + selected_sets = [set(row) for row in selected] + counts = [0] * num_experts + for row in selected: + for expert in row: + counts[expert] += 1 + + selected_mass = 0.0 + for token, row in enumerate(selected): + selected_mass += sum(float(probs_cpu[token, expert].item()) for expert in row) + original_mass = selected_mass + changed_slots = 0 + touched_tokens = set() + budget = args.repair_mass_drop_budget + max_rounds = max(1, args.repair_max_rounds) + + for _round in range(max_rounds): + overflow = {e for e, c in enumerate(counts) if c > expert_capacity} + if not overflow: + break + underfull = {e for e, c in enumerate(counts) if c < expert_capacity} + if not underfull: + break + candidates = [] + for token, row in enumerate(selected): + token_set = selected_sets[token] + for slot, old_expert in enumerate(row): + if old_expert not in overflow: + continue + new_expert = -1 + for cand in score_order[token]: + if cand in underfull and cand not in token_set: + new_expert = cand + break + if new_expert < 0: + continue + 
old_score = float(scores_cpu[token, old_expert].item()) + old_mass = float(probs_cpu[token, old_expert].item()) + new_mass = float(probs_cpu[token, new_expert].item()) + loss = old_score - float(scores_cpu[token, new_expert].item()) + mass_loss = old_mass - new_mass + candidates.append((loss, mass_loss, token, slot, old_expert, new_expert)) + if not candidates: + break + candidates.sort(key=lambda x: x[0]) + changed = False + for _, mass_loss, token, slot, old_expert, new_expert in candidates: + if counts[old_expert] <= expert_capacity or counts[new_expert] >= expert_capacity: + continue + if selected[token][slot] != old_expert or new_expert in selected_sets[token]: + continue + next_mass = selected_mass - mass_loss + next_drop = (original_mass - next_mass) / max(original_mass, 1e-12) + if next_drop > budget: + continue + selected[token][slot] = new_expert + selected_sets[token].remove(old_expert) + selected_sets[token].add(new_expert) + counts[old_expert] -= 1 + counts[new_expert] += 1 + selected_mass = next_mass + changed_slots += 1 + touched_tokens.add(token) + changed = True + if not changed: + break + + all_topk_idx = torch.tensor(selected, dtype=torch.long, device="cuda") + topk_idx = all_topk_idx[local_offset:local_offset + num_tokens].contiguous() + topk_weights = torch.softmax(torch.gather(scores, 1, topk_idx), dim=-1) + mass_drop = (original_mass - selected_mass) / max(original_mass, 1e-12) + over_slots = sum(max(0, c - expert_capacity) for c in counts) + max_count = max(counts) + if rank_idx == 0: + dist_print( + f" > global_bounded_repair: budget={budget:.3f} " + f"changed={changed_slots / max(total_num_tokens * num_topk, 1) * 100:.1f}% " + f"touched={len(touched_tokens) / max(total_num_tokens, 1) * 100:.1f}% " + f"mass_drop={mass_drop * 100:.1f}% over_slots={over_slots} max_count={max_count}", + once_in_node=False, + ) + elif args.routing_mode in ("topk-repair", "topk-repair-one", "topk-repair-softmax"): + assert args.masked_ratio == 0.0, 
f"{args.routing_mode} routing does not support masked_ratio" + assert (num_tokens * num_topk) % num_experts == 0, f"{args.routing_mode} requires exact local expert capacity" + scores = make_scores() + topk_weights, topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False) + expert_capacity = (num_tokens * num_topk) // num_experts + scores_cpu = scores.cpu() + selected = topk_idx.cpu().tolist() + selected_sets = [set(row) for row in selected] + counts = [0] * num_experts + for row in selected: + for expert in row: + counts[expert] += 1 + + while True: + overflow = {e for e, c in enumerate(counts) if c > expert_capacity} + underfull = [e for e, c in enumerate(counts) if c < expert_capacity] + if not overflow: + break + candidates = [] + underfull_set = set(underfull) + for token, row in enumerate(selected): + token_set = selected_sets[token] + for slot, old_expert in enumerate(row): + if old_expert not in overflow: + continue + old_score = float(scores_cpu[token, old_expert].item()) + for new_expert in underfull: + if new_expert in token_set: + continue + loss = old_score - float(scores_cpu[token, new_expert].item()) + candidates.append((loss, token, slot, old_expert, new_expert)) + assert candidates, "topk-repair could not find a repair candidate" + candidates.sort(key=lambda x: x[0]) + changed = False + for _, token, slot, old_expert, new_expert in candidates: + if counts[old_expert] <= expert_capacity or counts[new_expert] >= expert_capacity: + continue + if selected[token][slot] != old_expert or new_expert in selected_sets[token]: + continue + selected[token][slot] = new_expert + selected_sets[token].remove(old_expert) + selected_sets[token].add(new_expert) + counts[old_expert] -= 1 + counts[new_expert] += 1 + changed = True + assert changed, "topk-repair made no progress" + assert all(c == expert_capacity for c in counts), "topk-repair failed to reach exact capacity" + topk_idx = torch.tensor(selected, dtype=torch.long, device="cuda") + if 
args.routing_mode == "topk-repair-one": + topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float, device="cuda") + elif args.routing_mode == "topk-repair-softmax": + topk_weights = torch.softmax(torch.gather(scores, 1, topk_idx), dim=-1) + else: + topk_weights = torch.gather(scores, 1, topk_idx) + elif args.routing_mode in ("local-exact-score", "local-exact-score-one"): + assert args.masked_ratio == 0.0, f"{args.routing_mode} routing does not support masked_ratio" + assert (num_tokens * num_topk) % num_experts == 0, f"{args.routing_mode} requires exact local expert capacity" + scores = make_scores() + expert_capacity = (num_tokens * num_topk) // num_experts + assert expert_capacity % num_topk == 0, f"{args.routing_mode} requires per-slot expert capacity" + per_slot_capacity = expert_capacity // num_topk + score_order = torch.argsort(scores, dim=1, descending=True).cpu().tolist() + selected = [[-1] * num_topk for _ in range(num_tokens)] + selected_sets = [set() for _ in range(num_tokens)] + + import sys + sys.setrecursionlimit(max(10000, num_tokens * 4)) + for slot in range(num_topk): + assignment = [-1] * num_tokens + matched_tokens = [[] for _ in range(num_experts)] + token_order = sorted( + range(num_tokens), + key=lambda t: scores[t, score_order[t][0]].item(), + reverse=True, + ) + + def try_assign(token: int, seen_experts: set[int]) -> bool: + for expert in score_order[token]: + if expert in selected_sets[token] or expert in seen_experts: + continue + seen_experts.add(expert) + if len(matched_tokens[expert]) < per_slot_capacity: + matched_tokens[expert].append(token) + assignment[token] = expert + return True + for idx, other_token in enumerate(list(matched_tokens[expert])): + if try_assign(other_token, seen_experts): + matched_tokens[expert][idx] = token + assignment[token] = expert + return True + return False + + for token in token_order: + assert try_assign(token, set()), "local-exact-score matching failed" + assert all(expert >= 0 for expert 
in assignment), "local-exact-score left unassigned tokens" + assert all(len(tokens) == per_slot_capacity for tokens in matched_tokens), "local-exact-score left capacity imbalance" + for token, expert in enumerate(assignment): + selected[token][slot] = expert + selected_sets[token].add(expert) + + topk_idx = torch.tensor(selected, dtype=torch.long, device="cuda") + if args.routing_mode == "local-exact-score-one": + topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float, device="cuda") + else: + topk_weights = torch.gather(scores, 1, topk_idx) + elif args.routing_mode == "local-capacity": + assert args.masked_ratio == 0.0, "local-capacity routing does not support masked_ratio" + scores = make_scores() + expert_capacity = math.ceil(num_tokens * num_topk / num_experts) + candidate_k = num_experts + cand_vals, cand_idx = torch.topk(scores, candidate_k, dim=-1, largest=True, sorted=True) + order = torch.argsort(cand_vals[:, 0], descending=True).cpu().tolist() + cand_idx_cpu = cand_idx.cpu().tolist() + remaining = [expert_capacity] * num_experts + selected = [[-1] * num_topk for _ in range(num_tokens)] + pending: list[tuple[int, int]] = [] + for token in order: + used = 0 + for expert in cand_idx_cpu[token]: + if remaining[expert] > 0: + selected[token][used] = expert + remaining[expert] -= 1 + used += 1 + if used == num_topk: + break + if used < num_topk: + pending.append((token, used)) + fallback_cursor = 0 + for token, used in pending: + already = set(selected[token][:used]) + while used < num_topk: + found = False + for _ in range(num_experts): + expert = fallback_cursor % num_experts + fallback_cursor += 1 + if remaining[expert] > 0 and expert not in already: + selected[token][used] = expert + remaining[expert] -= 1 + already.add(expert) + used += 1 + found = True + break + if not found: + for expert in cand_idx_cpu[token]: + if expert not in already: + selected[token][used] = expert + already.add(expert) + used += 1 + found = True + break + assert 
found, "local-capacity routing could not fill all topk slots" + topk_idx = torch.tensor(selected, dtype=torch.long, device="cuda") + topk_weights = torch.gather(scores, 1, topk_idx) + elif args.routing_mode == "capacity": + assert args.masked_ratio == 0.0, "capacity routing does not support masked_ratio" + scores = make_scores() + all_scores = uneven_all_gather(scores, group=group) + local_num_tokens = torch.tensor([num_tokens], dtype=torch.long, device="cuda") + all_num_tokens_t = [torch.zeros_like(local_num_tokens) for _ in range(num_ranks)] + dist.all_gather(all_num_tokens_t, local_num_tokens, group=group) + all_num_tokens = [int(x.item()) for x in all_num_tokens_t] + local_offset = sum(all_num_tokens[:rank_idx]) + total_num_tokens = sum(all_num_tokens) + expert_capacity = math.ceil(total_num_tokens * num_topk / num_experts) + candidate_k = num_experts + cand_vals, cand_idx = torch.topk(all_scores, candidate_k, dim=-1, largest=True, sorted=True) + order = torch.argsort(cand_vals[:, 0], descending=True).cpu().tolist() + cand_idx_cpu = cand_idx.cpu().tolist() + remaining = [expert_capacity] * num_experts + selected = [[-1] * num_topk for _ in range(total_num_tokens)] + pending: list[tuple[int, int]] = [] + for token in order: + used = 0 + for expert in cand_idx_cpu[token]: + if remaining[expert] > 0: + selected[token][used] = expert + remaining[expert] -= 1 + used += 1 + if used == num_topk: + break + if used < num_topk: + pending.append((token, used)) + fallback_cursor = 0 + for token, used in pending: + already = set(selected[token][:used]) + while used < num_topk: + found = False + for _ in range(num_experts): + expert = fallback_cursor % num_experts + fallback_cursor += 1 + if remaining[expert] > 0 and expert not in already: + selected[token][used] = expert + remaining[expert] -= 1 + already.add(expert) + used += 1 + found = True + break + if not found: + for expert in cand_idx_cpu[token]: + if expert not in already: + selected[token][used] = expert + 
already.add(expert) + used += 1 + found = True + break + assert found, "capacity routing could not fill all topk slots" + if rank_idx == 0: + probs_cpu = torch.softmax(all_scores.cpu(), dim=1) + natural_mass = 0.0 + selected_mass = 0.0 + changed_slots = 0 + touched_tokens = 0 + for token, row in enumerate(selected): + natural = cand_idx_cpu[token][:num_topk] + natural_set = set(natural) + row_set = set(row) + overlap = len(natural_set & row_set) + changed_slots += num_topk - overlap + touched_tokens += overlap != num_topk + natural_mass += sum(float(probs_cpu[token, expert].item()) for expert in natural) + selected_mass += sum(float(probs_cpu[token, expert].item()) for expert in row) + mass_drop = (natural_mass - selected_mass) / max(natural_mass, 1e-12) + used_counts = [expert_capacity - r for r in remaining] + dist_print( + f" > capacity_quality: changed={changed_slots / max(total_num_tokens * num_topk, 1) * 100:.1f}% " + f"touched={touched_tokens / max(total_num_tokens, 1) * 100:.1f}% " + f"mass_drop={mass_drop * 100:.1f}% max_count={max(used_counts)}", + once_in_node=False, + ) + all_topk_idx = torch.tensor(selected, dtype=torch.long, device="cuda") + topk_idx = all_topk_idx[local_offset:local_offset + num_tokens].contiguous() + topk_weights = torch.gather(scores, 1, topk_idx) + else: + scores = make_scores() + topk_weights, topk_idx = torch.topk( + scores, num_topk, dim=-1, largest=True, sorted=False + ) + topk_idx = apply_eplb_replicas(topk_idx) + if args.masked_ratio > 0: + rand_mask = torch.rand_like(topk_idx, dtype=torch.float) + topk_idx.masked_fill_(rand_mask < args.masked_ratio, -1) + topk_weights.masked_fill_(topk_idx < 0, 0) + + # 累计接收统计:fused 与 baseline 各持一份避免相互覆盖 + phase_profile_enabled = os.environ.get("DG_SM90_MOE_PHASE_PROFILE", "0") not in ("", "0") + phase_profile_extra = 64 if phase_profile_enabled else 0 + cum_stats_fused = torch.zeros( + (num_experts_per_rank + phase_profile_extra,), dtype=torch.int, device="cuda" + ) + cum_stats_baseline = 
cum_stats_fused.clone() + + # ---- BF16 → FP8 量化 ---- + # x_fp8 是元组:(token_fp8 (M, hidden), token_sf (M, hidden//128) fp32 行主序) + # 注意 use_ue8m0=False, use_packed_ue8m0=False:SM90 不接受 UE8M0 packed SF + x_fp8 = per_token_cast_to_fp8( + x_bf16, use_ue8m0=False, gran_k=128, use_packed_ue8m0=False + ) + + # weight 量化:(G, N, K) bf16 → ((G, N, K) fp8 e4m3fn, (G, N//128, K//128) fp32 SF) + # baseline(DeepEP grouped GEMM)直接用这两个未变换的元组 + l1_weights = _quantize_grouped_fp8_block_128_128(l1_weights_bf16) + l2_weights = _quantize_grouped_fp8_block_128_128(l2_weights_bf16) + + # fused 路径:FP8 weight 上做 gate/up gran-8 N-轴 interleave;SF 不变 + transformed_l1, transformed_l2 = deep_gemm.transform_weights_for_mega_moe_sm90( + l1_weights, l2_weights + ) + + # SwiGLU clamp:finite → 传给 fused/triton;inf → None(关闭 clamp,与 SM90 fused 一致) + clamp_arg = args.activation_clamp if math.isfinite(args.activation_clamp) else None + run_baseline_enabled = args.run_baseline or bool(args.check_output_diff) + + # ---- DeepGEMM grouped GEMM 的 M 维 alignment(baseline 走 DeepEP 时也用这个)---- + alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout() + deep_gemm.set_mk_alignment_for_contiguous_layout(alignment) + + # ---- 分配 fused 的 SymmBuffer 与输出 buffer ---- + sym_buffer = deep_gemm.get_symm_buffer_for_mega_moe( + group, + num_experts, + num_max_tokens_per_rank, + num_topk, + hidden, + intermediate_hidden, + ) + y_fused = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") + + def run_fused(): + # NOTE: 跟 SM100 test_mega_moe.py 的处理一致 —— DG_COMM_KERNEL_DEBUG=1 时 + # kernel 出口会把 sym_buffer 整块清零,所以每次都要重新拷输入 + sym_buffer.x[:num_tokens].copy_(x_fp8[0]) + sym_buffer.x_sf[:num_tokens].copy_(x_fp8[1]) + sym_buffer.topk_idx[:num_tokens].copy_(topk_idx) + sym_buffer.topk_weights[:num_tokens].copy_(topk_weights) + + deep_gemm.fp8_mega_moe( + y_fused, + transformed_l1, + transformed_l2, + sym_buffer, + cumulative_local_expert_recv_stats=cum_stats_fused, + recipe=(128, 128, 128), + 
activation="swiglu", + activation_clamp=clamp_arg, + fast_math=bool(args.fast_math), + ) + return y_fused + + # ---- 打印 config ---- + dist_print("Config (H200 fused mega-MoE):", once_in_node=True) + dist_print(f" > Tokens: {num_tokens}/{num_max_tokens_per_rank}", once_in_node=True) + dist_print( + f" > Hidden: {hidden}, Intermediate: {intermediate_hidden}", once_in_node=True + ) + dist_print( + f" > Experts: {num_topk}/{num_experts} (per-rank: {num_experts_per_rank})", + once_in_node=True, + ) + dist_print(f" > Masked ratio: {args.masked_ratio}", once_in_node=True) + dist_print( + f" > Activation SF: fused L2 per-{FUSED_L2_ACT_SF_GRAN} UE8M0, " + f"baseline L2 per-{BASELINE_L2_ACT_SF_GRAN} UE8M0 " + f"(SM90 grouped GEMM constraint)", + once_in_node=True, + ) + dist_print( + f" > Baseline: {'enabled' if run_baseline_enabled else 'disabled'}", + once_in_node=True, + ) + dist_print( + f" > Buffer: {sym_buffer.buffer.nbytes / 2**30:.3f} GiB", once_in_node=True + ) + dist_print(once_in_node=True) + + # 与社区版 test_mega_moe.py 对齐:NCU 模式只跑 fused kernel,避免 baseline 噪声。 + if args.ncu_profile_only: + dist_print("Run fused SM90 mega-MoE kernel:", once_in_node=True) + y = run_fused() + torch.cuda.synchronize() + assert y.shape == (num_tokens, hidden) and y.dtype == torch.bfloat16 + dist_print(" > Done, exiting", once_in_node=True) + dist.barrier() + sym_buffer.destroy() + dist.destroy_process_group() + return + + # ---- 分配 DeepEP buffer(baseline 用)---- + deep_ep = _import_deep_ep() if run_baseline_enabled else None + ep_buffer = None + if deep_ep is not None: + ep_buffer = _make_deep_ep_buffer( + deep_ep, + group, + num_max_tokens_per_rank, + hidden, + num_topk, + sym_buffer.buffer.nbytes, + ) + + # ---------------------------------------------------------------- + # baseline 主体:dispatch → L1 GEMM → SwiGLU+量化 → L2 GEMM → combine + # 与 fused 用同一份 (FP8 weight, FP32 block-(128,128) SF) —— 但是 **未变换** + # 的版本(baseline grouped GEMM 不需要 gate/up interleave) + # 
---------------------------------------------------------------- + def run_baseline(): + recv_x, _, recv_topk_weights, handle, _ = ep_buffer.dispatch( + x_fp8, + topk_idx=topk_idx, + topk_weights=topk_weights, + cumulative_local_expert_recv_stats=cum_stats_baseline, + num_experts=num_experts, + expert_alignment=alignment, + do_cpu_sync=False, + do_handle_copy=False, + do_expand=True, + use_tma_aligned_col_major_sf=False, # SM90: row-major float SF + ) + n = recv_x[0].size(0) + + # L1 GEMM:FP8 token @ FP8 W1 → BF16 中间激活 (gate||up 拼接) + l1_y = torch.empty( + (n, intermediate_hidden * 2), dtype=torch.bfloat16, device="cuda" + ) + deep_gemm.m_grouped_fp8_gemm_nt_contiguous( + recv_x, + l1_weights, + l1_y, + handle.psum_num_recv_tokens_per_expert, + use_psum_layout=True, + disable_ue8m0_cast=True, + ) + + # Triton SwiGLU + FP8 量化(含 topk 权重乘法) + # 注意:fused SM90 mega-MoE 的 L2 activation SFA 是 per-64-K; + # 当前 DeepGEMM SM90 grouped GEMM 只支持 per-128-K SFA,所以性能 baseline + # 只能用 per-128-K,但 scale 数值采用 fused 同款 UE8M0/power-of-two。 + l1_y = swiglu_apply_weight_to_fp8_triton( + x=l1_y, + topk_weights=recv_topk_weights, + clamp_value=clamp_arg, + num_per_channels=BASELINE_L2_ACT_SF_GRAN, + use_ue8m0_scale=True, + ) + + # L2 GEMM:FP8 中间激活 @ FP8 W2 → BF16 + l2_y = torch.empty((n, hidden), dtype=torch.bfloat16, device="cuda") + deep_gemm.m_grouped_fp8_gemm_nt_contiguous( + l1_y, + l2_weights, + l2_y, + handle.psum_num_recv_tokens_per_expert, + use_psum_layout=True, + disable_ue8m0_cast=True, + ) + + # DeepEP combine:把每个 token 在 topk 个 expert 上的输出汇聚回源 rank + return ep_buffer.combine(l2_y, handle=handle)[0] + + # ---- 跑一次确保不报错(fused + 可选 baseline)---- + y = run_fused() + assert y.shape == (num_tokens, hidden) and y.dtype == torch.bfloat16, ( + f"fused 输出 shape/dtype 异常: shape={y.shape}, dtype={y.dtype}" + ) + if ep_buffer is not None: + out_b = run_baseline() + assert out_b.shape == (num_tokens, hidden) and out_b.dtype == torch.bfloat16, ( + f"baseline 输出 shape/dtype 异常: 
shape={out_b.shape}, dtype={out_b.dtype}" + ) + if args.check_output_diff: + diff = (y.float() - out_b.float()).abs() + denom = out_b.float().abs().mean().clamp_min(1e-12) + dist_print( + "Output diff (fused vs legacy-per128 baseline):", once_in_node=True + ) + dist_print( + f" > max_abs={diff.max().item():.6e}, " + f"mean_abs={diff.mean().item():.6e}, " + f"mean_abs/mean_ref={diff.mean().div(denom).item():.6e}", + once_in_node=True, + ) + dist_print(once_in_node=True) + + # ---- 统计本 rank 实际接收的 token 数与触达的 expert 数 ---- + # 把所有 rank 的 topk_idx 收齐,再把不落在本 rank 持有 expert 范围内的条目 + # 标成 -1;剩下的非 -1 条目数即"被路由进本 rank 的 (token, slot) 总数"。 + gathered_topk_idx = uneven_all_gather(topk_idx, group=group) + all_routed_topk_idx = gathered_topk_idx + local_num_tokens_t = torch.tensor([num_tokens], dtype=torch.long, device="cuda") + all_num_tokens_t = [torch.zeros_like(local_num_tokens_t) for _ in range(num_ranks)] + dist.all_gather(all_num_tokens_t, local_num_tokens_t, group=group) + all_num_tokens = [int(x.item()) for x in all_num_tokens_t] + peer_recv_counts = [] + row_start = 0 + for src_tokens in all_num_tokens: + src_topk = all_routed_topk_idx[row_start:row_start + src_tokens] + peer_recv_counts.append(int(((src_topk >= rank_idx * num_experts_per_rank) & + (src_topk < (rank_idx + 1) * num_experts_per_rank)).sum().item())) + row_start += src_tokens + max_peer_recv = max(peer_recv_counts) if peer_recv_counts else 0 + gathered_topk_idx = all_routed_topk_idx.clone() + gathered_topk_idx[ + (gathered_topk_idx < rank_idx * num_experts_per_rank) + | (gathered_topk_idx >= (rank_idx + 1) * num_experts_per_rank) + ] = -1 + local_expert_ids = gathered_topk_idx[gathered_topk_idx != -1] + num_recv_tokens = int(local_expert_ids.numel()) + num_touched_experts = int(torch.unique(local_expert_ids).numel()) + if num_recv_tokens > 0: + local_counts = torch.bincount( + local_expert_ids - rank_idx * num_experts_per_rank, + minlength=num_experts_per_rank, + ) + num_m_tiles = int(((local_counts + 63) 
// 64).sum().item()) + max_expert_tokens = int(local_counts.max().item()) + else: + num_m_tiles = 0 + max_expert_tokens = 0 + + # ---- NSYS external profiler multi-iter ---- + # Under DG_USE_NVIDIA_TOOLS=1, bench_kineto returns a sentinel and does not + # run its internal torch.profiler loop. Keep this explicit loop so nsys + # captures multiple steady-state mega_moe kernels, matching decode_t256/t512. + _nsys_iters = int(os.environ.get("NSYS_ITERS", "0")) + if _nsys_iters > 0: + for _it in range(_nsys_iters): + torch.cuda.synchronize() + if ep_buffer is not None: + ep_buffer.barrier(use_comm_stream=False) + else: + dist.barrier() + torch.cuda._sleep(int(2e7)) # ~10ms gap between iters + if ep_buffer is not None: + ep_buffer.barrier(use_comm_stream=False) + else: + dist.barrier() + run_fused() + torch.cuda.synchronize() + if ep_buffer is not None: + ep_buffer.barrier(use_comm_stream=False) + else: + dist.barrier() + + # ---- benchmark ---- + # fused:bench_kineto 抓 sm90_fp8_mega_moe_impl 的 GPU 段(不含 host overhead) + if phase_profile_enabled: + cum_stats_fused.zero_() + t_fused = bench_kineto( + run_fused, + SM90_KERNEL_NAME, + num_tests=args.num_bench_tests, + barrier=lambda: ep_buffer.barrier(use_comm_stream=False) + if ep_buffer is not None + else dist.barrier(), + trace_path=( + f"{args.dump_profile_traces}/mega_moe_hopper_rank{rank_idx}.json" + if args.dump_profile_traces + else None + ), + ) + if phase_profile_enabled: + cum_stats_fused.zero_() + torch.cuda.synchronize() + if ep_buffer is not None: + ep_buffer.barrier(use_comm_stream=False) + else: + dist.barrier() + phase_start = torch.cuda.Event(enable_timing=True) + phase_end = torch.cuda.Event(enable_timing=True) + phase_start.record() + run_fused() + phase_end.record() + torch.cuda.synchronize() + phase_event_us = phase_start.elapsed_time(phase_end) * 1000.0 + raw_i32 = cum_stats_fused[num_experts_per_rank:num_experts_per_rank + 64].detach().cpu().tolist() + def _u64(slot: int) -> int: + lo = raw_i32[slot * 
2] & 0xffffffff + hi = raw_i32[slot * 2 + 1] & 0xffffffff + return lo | (hi << 32) + names = ("dispatch_total", "dispatch_pull", "math_loop", "combine_barrier", "combine_reduce", "gemm_core", "l1_epilogue", "l2_epilogue") + pieces = [] + for idx, name in enumerate(names): + total = _u64(idx) + max_cycles = _u64(8 + idx) + count = _u64(16 + idx) + avg_us = (total / count / 1000.0) if count else 0.0 + max_us = max_cycles / 1000.0 + pieces.append(f"{name}:avg={avg_us:.1f}us,max={max_us:.1f}us,n={count},ns={total}/{max_cycles}") + dist_print(f" > phase_profile rank={rank_idx}: event={phase_event_us:.1f}us; " + "; ".join(pieces)) + # baseline:cuda events 中位数(tilelang.do_bench 在 H200 不一定有,统一用 events) + t_baseline = ( + _bench_cuda_events( + run_baseline, + num_warmup=args.num_warmup, + num_repeat=args.num_repeat, + l2_flush_gb=args.l2_flush_gb, + ) + if ep_buffer is not None + else 0.0 + ) + + def safe_div(a, b): + return float("nan") if b == 0 else a / b + + # 端到端 TFLOPS:3 个 matmul(L1 gate、L1 up、L2),每个 2*M*N*K,M=num_recv_tokens + tflops = safe_div( + 2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_fused + ) + + # HBM 字节估算(SM90: weight 是 FP8 = 1B/elem,与 SM100 FP4=0.5B 不同) + l1_weight_bytes = num_touched_experts * intermediate_hidden * 2 * hidden + l2_weight_bytes = num_touched_experts * hidden * intermediate_hidden + l1_weight_sf_bytes = ( + num_touched_experts + * (intermediate_hidden * 2 // WEIGHT_SF_GRAN_MN) + * (hidden // WEIGHT_SF_GRAN_K) + * 4 + ) + l2_weight_sf_bytes = ( + num_touched_experts + * (hidden // WEIGHT_SF_GRAN_MN) + * (intermediate_hidden // WEIGHT_SF_GRAN_K) + * 4 + ) + l1_input_sf_bytes = num_recv_tokens * (hidden // L1_ACT_SF_GRAN) * 4 + l2_act_sf_bytes = ( + num_recv_tokens * (intermediate_hidden // FUSED_L2_ACT_SF_GRAN) * 4 + ) + num_hbm_bytes = ( + l1_weight_bytes + + l2_weight_bytes # weights (FP8) + + l1_weight_sf_bytes + + l2_weight_sf_bytes # weight SF (FP32) + + num_recv_tokens * hidden + + l1_input_sf_bytes # L1 输入读 (FP8 
+ SF) + + num_recv_tokens * intermediate_hidden + + l2_act_sf_bytes # L1 输出写 (FP8 + SF) + + num_recv_tokens * intermediate_hidden + + l2_act_sf_bytes # L2 输入读 (FP8 + SF) + + num_recv_tokens * hidden * 2 # L2 输出写 (BF16) + ) + hbm_gbs = safe_div(num_hbm_bytes / 1e9, t_fused) + + # NVLink 字节:dispatch 拉 token + input SF + topk weight,combine 写回 BF16 + num_nvlink_bytes = num_recv_tokens * (hidden + hidden // 32 + 4 + hidden * 2) + nvlink_gbs = safe_div(num_nvlink_bytes / 1e9, t_fused) + + # combine reduction 串行下界(解析估计;6.5e12 = HBM 串行 reduction 经验吞吐 B/s) + t_reduction = num_tokens * hidden * 2 * (1 + num_topk) / 6.5e12 + + # overlap 校正:扣掉 fused 中无法重叠的串行 reduction 段后估计稳态吞吐 + approx_factor = t_fused / max(t_fused - t_reduction, 1e-12) + + # baseline 用同一份 FLOPs / HBM 字节,时间换成 t_baseline + tflops_baseline = safe_div( + 2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_baseline + ) + hbm_gbs_baseline = safe_div(num_hbm_bytes / 1e9, t_baseline) + nvlink_gbs_baseline = safe_div(num_nvlink_bytes / 1e9, t_baseline) + + def fmt_perf_line( + name: str, + t: float, + compute_tflops: float, + hbm_gbs_: float, + nvlink_gbs_: float, + reduction_us: float | None = None, + speedup: float | None = None, + ) -> str: + reduction = f"{reduction_us:13.1f}" if reduction_us is not None else f"{'-':>13}" + speedup_text = ( + f"{speedup:6.2f}x {'fused faster' if speedup > 1 else 'baseline faster'}" + if speedup is not None else + f"{'-':>21}" + ) + return ( + f" > {name:<10} {rank_idx:2d}/{num_ranks:<2d} " + f"{num_recv_tokens:12d} " + f"{num_touched_experts:14d} {num_m_tiles:7d} {max_expert_tokens:8d} {max_peer_recv:8d} | " + f"{compute_tflops:15.0f} " + f"{hbm_gbs_:9.0f} " + f"{nvlink_gbs_:9.0f} " + f"{t * 1e6:9.0f} " + f"{reduction} " + f"{speedup_text}" + ) + + dist_print("Performance:", once_in_node=True) + dist_print( + " > kind EP recv_tokens active_experts m_tiles max_exp max_peer | " + "compute(TFLOPS) HBM(GB/s) NVL(GB/s) time(us) reduction(us) speedup", + 
once_in_node=True, + ) + dist_print( + fmt_perf_line( + "[fused]", + t_fused, + tflops * approx_factor, + hbm_gbs * approx_factor, + nvlink_gbs * approx_factor, + reduction_us=t_reduction * 1e6, + ) + ) + if ep_buffer is not None: + speedup = safe_div(t_baseline, t_fused) + dist_print( + fmt_perf_line( + "[baseline]", + t_baseline, + tflops_baseline, + hbm_gbs_baseline, + nvlink_gbs_baseline, + speedup=speedup, + ) + ) + else: + reason = ( + "disabled; pass --run-baseline or --check-output-diff to compare" + if not run_baseline_enabled + else "deep_ep unavailable" + ) + dist_print(f" > [baseline] ({reason})", once_in_node=True) + + # ---- 清理 ---- + dist.barrier() + sym_buffer.destroy() + if ep_buffer is not None: + ep_buffer.destroy() + dist.destroy_process_group() + + +# ============================================================================ +# 模块 6:argparse + spawn +# ============================================================================ + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="H200 mega-MoE: fused (deep_gemm.fp8_mega_moe) vs DeepEP+grouped-FP8 baseline" + ) + + # 资源 + parser.add_argument( + "--ncu-profile-only", + action="store_true", + help="只运行一次 fused SM90 kernel,便于 NCU/Nsight 采样", + ) + parser.add_argument( + "--num-processes", type=int, default=8, help="spawn 出来的进程数(一卡一进程)" + ) + parser.add_argument( + "--local-rank-idx", + type=int, + default=None, + help="单进程模式的 local rank;用于外部 launcher/NCU 分别启动每个 rank", + ) + + # 模型形状 + # 注:SM90 fused kernel 要求 intermediate_hidden ≤ 4096 + parser.add_argument("--num-max-tokens-per-rank", type=int, default=8192) + parser.add_argument( + "--num-tokens", + type=int, + default=0, + help="per-rank 实际 token 数;0 表示用 num-max-tokens-per-rank", + ) + parser.add_argument( + "--num-max-removed-tokens", + type=int, + default=0, + help="num-tokens 为 0 时,每个 rank 随机移除的最大 token 数", + ) + parser.add_argument("--hidden", type=int, default=7168) + parser.add_argument( + 
"--intermediate-hidden", + type=int, + default=3072, + help="中间层维度(≤ 4096,受 SM90 l2_arrival_mask 约束)", + ) + parser.add_argument( + "--activation-clamp", + type=float, + default=10.0, + help="SwiGLU 前对 gate/up 的 clamp 阈值;传 inf 表示关闭", + ) + parser.add_argument("--num-experts", type=int, default=384) + parser.add_argument("--num-topk", type=int, default=6) + parser.add_argument( + "--routing-mode", + type=str, + default="random", + choices=("random", "balanced", "balanced-shuffled", "balanced-shuffled-score", "topk-repair", "topk-repair-one", "topk-repair-softmax", "topk-repair-budget-softmax", "global-repair-budget-softmax", "local-exact-score", "local-exact-score-one", "local-capacity", "capacity"), + help="routing 构造方式;balanced/balanced-shuffled/topk-repair/local-exact-score/local-capacity/capacity 控制每 expert assignment 数", + ) + parser.add_argument( + "--masked-ratio", + type=float, + default=0.0, + help="随机 mask 掉部分 topk expert selection,用于验证稀疏路由边界", + ) + parser.add_argument( + "--score-powerlaw-alpha", + type=float, + default=0.0, + help="给 routing score 加 Zipf/power-law expert bias;0 表示关闭", + ) + parser.add_argument( + "--score-powerlaw-scale", + type=float, + default=1.0, + help="power-law bias 的标准差尺度", + ) + parser.add_argument( + "--repair-mass-drop-budget", + type=float, + default=0.0, + help="topk-repair-budget-softmax 的 aggregate selected softmax-mass drop 上限,例如 0.10", + ) + parser.add_argument( + "--repair-max-rounds", + type=int, + default=4, + help="global-repair-budget-softmax 的 greedy repair 最大轮数,避免诊断代码超时", + ) + parser.add_argument( + "--num-redundant-experts", + type=int, + default=0, + help="benchmark-only EPLB simulation: reserve physical expert slots as hot-expert replicas", + ) + parser.add_argument( + "--replica-dispatch", + choices=("hash", "static"), + default="hash", + help="replica remap model: token-level hash or SGLang static source-rank approximation", + ) + parser.add_argument( + "--fast-math", + type=int, + default=1, + help="fused 
内 SwiGLU 是否启用 fast-math(0/1)", + ) + + # 测时 + parser.add_argument( + "--num-bench-tests", + type=int, + default=30, + help="bench_kineto 抓 fused 时的迭代数", + ) + parser.add_argument( + "--num-warmup", type=int, default=5, help="baseline cuda events warmup" + ) + parser.add_argument( + "--num-repeat", type=int, default=20, help="baseline cuda events 测时迭代" + ) + parser.add_argument( + "--l2-flush-gb", + type=float, + default=8.0, + help="baseline event 测时前用于 flush L2 的临时写入大小;0 表示关闭", + ) + parser.add_argument( + "--run-baseline", + action="store_true", + help="启用 DeepEP+grouped-FP8 legacy baseline;默认关闭以避免 full-size 默认配置触发 baseline kernel 非法访问", + ) + parser.add_argument( + "--check-output-diff", + type=int, + default=0, + help="非 0 时打印 fused 与 legacy-per128 baseline 的输出差异(预期非 bitwise)", + ) + parser.add_argument( + "--dump-profile-traces", + type=str, + default="", + help="非空时把 fused 的 Chrome trace 写到该目录(每 rank 一份)", + ) + + args = parser.parse_args() + + if args.dump_profile_traces: + os.makedirs(args.dump_profile_traces, exist_ok=True) + + if args.local_rank_idx is not None: + # 单进程模式:由外部 launcher 分别设置 MASTER_ADDR/PORT/WORLD_SIZE/RANK。 + test(args.local_rank_idx, args.num_processes, args) + else: + # 多进程启动:每个进程对应一个 GPU;test() 内部用 init_dist 建 NCCL group。 + torch.multiprocessing.spawn( + test, args=(args.num_processes, args), nprocs=args.num_processes + ) diff --git a/tests/test_mega_moe_sm90.py b/tests/test_mega_moe_sm90.py new file mode 100644 index 0000000000..f38a5be5c8 --- /dev/null +++ b/tests/test_mega_moe_sm90.py @@ -0,0 +1,528 @@ +"""Layered tests for the SM90 (Hopper) MegaMoE kernel. + +The fused FP8 SM90 MegaMoE kernel is exercised across a hierarchy of +scenarios so that each kernel path / heuristic branch / edge case is +covered with at least one configuration. + +Layers +------ + L1 Smoke : single tiny config; only verifies the kernel runs + and produces an output close to a PyTorch reference. 
+  L2 Heuristic    : sweeps tokens-per-expert across the bands of
+                    ``get_block_config_for_mega_moe_sm90`` so each
+                    ``{block_m, num_epilogue_warpgroups}`` case is hit.
+  L3 Shape sweep  : sweeps ``hidden``, ``intermediate_hidden`` and
+                    ``num_topk`` over divisible-by-128 values.
+  L4 Edge cases   : masking ratio, activation clamp (finite vs inf),
+                    ``fast_math`` 0/1, ``num_tokens`` boundaries.
+  L5 Stress       : ``--num-correctness-tests`` repeated random configs.
+
+Notes
+-----
+* The reference is a pure PyTorch BF16/FP32 simulation of the fused path
+  (dequantize -> matmul -> SwiGLU + clamp + per-row quantize -> matmul ->
+  cross-rank scatter -> BF16 reduce). It is *not* bitwise-identical to
+  the kernel; correctness is checked with ``calc_diff < 0.07``.
+* Because every scenario allocates its own symmetric memory buffer we
+  re-`init_dist`/`destroy` once per process at the outer level only,
+  and re-create ``SymmBuffer`` per scenario.
+* Skips itself when the device is not SM90.
+"""
+
+import argparse
+import math
+import os
+import random
+import sys
+import torch
+import torch.distributed as dist
+from typing import Tuple, List, Dict, Any
+
+REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+if REPO_ROOT not in sys.path:
+    sys.path.insert(0, REPO_ROOT)
+
+import deep_gemm
+from deep_gemm.utils import per_token_cast_to_fp8
+from deep_gemm.utils.dist import dist_print, init_dist, uneven_all_gather
+from deep_gemm.testing import calc_diff, get_arch_major
+
+
+# ----------------------------------------------------------------------------
+# Quantization helpers
+# ----------------------------------------------------------------------------
+
+def _quantize_grouped_fp8_block_128_128(w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Block (128, 128) FP8 quantization along (N, K).
+
+    Args
+    ----
+    w : (G, N, K) bf16, with N % 128 == 0 and K % 128 == 0
+
+    Returns
+    -------
+    fp8 : (G, N, K) torch.float8_e4m3fn
+    sf  : (G, N // 128, K // 128) torch.float32, MN-major in the (N, K)
+          plane (i.e. K is the inner contiguous dim, matching the kernel's
+          ``stride_k = 1`` expectation and the DeepEP convention).
+    """
+    g, n, k = w.shape
+    assert n % 128 == 0 and k % 128 == 0
+    # Split N and K into 128-wide tiles: (G, N/128, 128, K/128, 128).
+    w_view = w.view(g, n // 128, 128, k // 128, 128).float()
+    # One amax per (128, 128) tile; the clamp keeps the scale away from zero.
+    amax = w_view.abs().amax(dim=(-1, -3)).clamp(1e-4)   # (G, N/128, K/128)
+    # 448 is the divisor used to map the tile amax onto the FP8 range.
+    sf = amax / 448.0
+    w_fp8 = (w_view / sf.unsqueeze(-1).unsqueeze(-3)).to(torch.float8_e4m3fn)
+    return w_fp8.view(g, n, k).contiguous(), sf.contiguous()
+
+
+def _dequant_block_128_128(w_fp8: torch.Tensor, sf: torch.Tensor) -> torch.Tensor:
+    """Inverse of `_quantize_grouped_fp8_block_128_128`. Returns fp32.
+
+    Accepts any number of leading dims: ``w_fp8`` is (..., N, K) and ``sf``
+    is (..., N // 128, K // 128).
+    """
+    *prefix, n, k = w_fp8.shape
+    assert n % 128 == 0 and k % 128 == 0
+    w_view = w_fp8.float().view(*prefix, n // 128, 128, k // 128, 128)
+    return (w_view * sf.unsqueeze(-1).unsqueeze(-3)).view(*prefix, n, k)
+
+
+def _dequant_per_token_per_128_k(x_fp8: torch.Tensor, sf: torch.Tensor) -> torch.Tensor:
+    """For (M, K) fp8 with (M, K // 128) float SF (per-token, K-major).
+
+    Returns the (M, K) fp32 dequantized activations.
+    """
+    m, k = x_fp8.shape
+    assert k % 128 == 0
+    x_view = x_fp8.float().view(m, k // 128, 128)
+    return (x_view * sf.unsqueeze(-1)).view(m, k)
+
+
+# ----------------------------------------------------------------------------
+# PyTorch reference
+# ----------------------------------------------------------------------------
+
+def _swiglu_fp32(gate_up: torch.Tensor, clamp: float) -> torch.Tensor:
+    """SwiGLU with one-sided gate clamp and two-sided up clamp.
+
+    Matches the fused kernel: ``silu(min(gate, c)) * clamp(up, -c, c)``.
+
+    The last dim of ``gate_up`` is split in half: the first half is the gate,
+    the second half is the up projection. A non-finite ``clamp`` disables
+    clamping.
+    """
+    n2 = gate_up.size(-1)
+    half = n2 // 2
+    gate, up = gate_up[..., :half], gate_up[..., half:]
+    if math.isfinite(clamp):
+        gate = gate.clamp(max=clamp)
+        up = up.clamp(min=-clamp, max=clamp)
+    return torch.nn.functional.silu(gate) * up
+
+
+def _reference_fused(
+    x_fp8_local: torch.Tensor, x_sf_local: torch.Tensor,
+    topk_idx_local: torch.Tensor, topk_weights_local: torch.Tensor,
+    l1_w_fp8: torch.Tensor, l1_w_sf: torch.Tensor,
+    l2_w_fp8: torch.Tensor, l2_w_sf: torch.Tensor,
+    rank_idx: int, num_ranks: int, group: dist.ProcessGroup,
+    num_experts: int, num_topk: int,
+    hidden: int, intermediate_hidden: int,
+    activation_clamp: float,
+) -> torch.Tensor:
+    """Reference: returns (num_tokens, hidden) bf16 result for *this* rank.
+
+    All-gathers the global tokens / topk decisions / per-rank weights, then
+    for each global token routes through its topk experts, applies the
+    L1+SwiGLU+L2 path, and reduces over topk on the source rank.
+
+    This function issues collectives on ``group``, so every rank of the group
+    must call it. Entries of ``topk_idx_local`` that are negative are treated
+    as masked selections and contribute zero.
+    """
+    num_experts_per_rank = num_experts // num_ranks
+
+    # --- gather global token data --------------------------------------------
+    x_fp8_g = uneven_all_gather(x_fp8_local, group=group)            # (Mg, H)
+    x_sf_g = uneven_all_gather(x_sf_local, group=group)              # (Mg, H/128)
+    topk_idx_g = uneven_all_gather(topk_idx_local, group=group)      # (Mg, K)
+    topk_w_g = uneven_all_gather(topk_weights_local, group=group)    # (Mg, K)
+    mg = x_fp8_g.size(0)
+
+    # Per-rank token counts, needed at the end to slice this rank's rows out
+    # of the gathered (Mg, H) result.
+    local_size = torch.tensor([x_fp8_local.size(0)], device='cuda', dtype=torch.long)
+    sizes_t = torch.empty(num_ranks, dtype=torch.long, device='cuda')
+    dist.all_gather_into_tensor(sizes_t, local_size, group=group)
+    sizes_list = sizes_t.tolist()
+    assert sum(sizes_list) == mg
+
+    # --- gather all-rank weights ---------------------------------------------
+    # l1_w_fp8: (E_pr, 2*IH, H), l1_w_sf: (E_pr, 2*IH/128, H/128)
+    l1_w_g = [torch.empty_like(l1_w_fp8) for _ in range(num_ranks)]
+    l1_sf_g = [torch.empty_like(l1_w_sf) for _ in range(num_ranks)]
+    l2_w_g = [torch.empty_like(l2_w_fp8) for _ in range(num_ranks)]
+    l2_sf_g = [torch.empty_like(l2_w_sf) for _ in range(num_ranks)]
+    dist.all_gather(l1_w_g, l1_w_fp8, group=group)
+    dist.all_gather(l1_sf_g, l1_w_sf, group=group)
+    dist.all_gather(l2_w_g, l2_w_fp8, group=group)
+    dist.all_gather(l2_sf_g, l2_w_sf, group=group)
+    l1_w_all = torch.stack(l1_w_g, dim=0)     # (R, E_pr, 2*IH, H)
+    l1_sf_all = torch.stack(l1_sf_g, dim=0)
+    l2_w_all = torch.stack(l2_w_g, dim=0)
+    l2_sf_all = torch.stack(l2_sf_g, dim=0)
+
+    # --- per-token / per-topk compute ----------------------------------------
+    # Combine slots: (Mg, K, H). Stored as fp32, but every value written below
+    # is first rounded through bf16 to mimic the kernel's bf16 storage.
+    combine_buf = torch.zeros(mg, num_topk, hidden, dtype=torch.float32, device='cuda')
+
+    # Precompute dequantized x in fp32
+    x_fp32 = _dequant_per_token_per_128_k(x_fp8_g, x_sf_g)   # (Mg, H)
+
+    # Iterate (cheap; reference is for small test configs only)
+    # Token-chunked to keep gathered (S, 2*IH, H) dequant tensors below GPU memory.
+    _CHUNK = 256
+    for k in range(num_topk):
+        # Skip masked
+        mask = topk_idx_g[:, k] >= 0
+        if not mask.any():
+            continue
+        sel_idx_full = mask.nonzero(as_tuple=False).squeeze(-1)   # (S,)
+        for c0 in range(0, sel_idx_full.numel(), _CHUNK):
+            sel_idx = sel_idx_full[c0:c0 + _CHUNK]
+            eids = topk_idx_g[sel_idx, k]       # (S,)
+            weights = topk_w_g[sel_idx, k]      # (S,)
+            x_sel = x_fp32[sel_idx]             # (S, H)
+
+            # Global expert id -> (owning rank, expert index within that rank).
+            dst_rank = (eids // num_experts_per_rank).long()
+            dst_local = (eids % num_experts_per_rank).long()
+
+            # L1 GEMM (per-token): y = x @ W^T shape (S, 2*IH)
+            l1_w_sel = _dequant_block_128_128(
+                l1_w_all[dst_rank, dst_local],    # (S, 2*IH, H)
+                l1_sf_all[dst_rank, dst_local],
+            )
+            l1_y = torch.einsum('sk,snk->sn', x_sel, l1_w_sel)   # (S, 2*IH)
+            del l1_w_sel
+
+            # SwiGLU + clamp + multiply by topk weight
+            l1_y = _swiglu_fp32(l1_y, activation_clamp) * weights.unsqueeze(-1)   # (S, IH)
+
+            # Per-row, per-64-col FP8 quantize -> dequantize
+            s_, ih = l1_y.shape
+            assert ih == intermediate_hidden and ih % 64 == 0
+            l1_view = l1_y.view(s_, ih // 64, 64)
+            amax = l1_view.abs().amax(dim=-1).clamp(1e-4)   # (S, IH/64)
+            sf2 = amax / 448.0
+            l1_q = (l1_view / sf2.unsqueeze(-1)).to(torch.float8_e4m3fn).float()
+            l2_in = (l1_q * sf2.unsqueeze(-1)).view(s_, ih)   # (S, IH) fp32
+
+            # L2 GEMM
+            l2_w_sel = _dequant_block_128_128(
+                l2_w_all[dst_rank, dst_local],    # (S, H, IH)
+                l2_sf_all[dst_rank, dst_local],
+            )
+            l2_y = torch.einsum('sn,smn->sm', l2_in, l2_w_sel)   # (S, H)
+            del l2_w_sel
+
+            # Scatter to combine buffer (cast to bf16 then back to mimic kernel storage)
+            combine_buf[sel_idx, k] = l2_y.to(torch.bfloat16).float()
+
+    # Sum over K -> (Mg, H), keep only this rank's slice
+    y_full_bf16 = combine_buf.to(torch.bfloat16).sum(dim=1).to(torch.bfloat16)   # (Mg, H)
+    start = sum(sizes_list[:rank_idx])
+    end = start + sizes_list[rank_idx]
+    return y_full_bf16[start:end].contiguous()
+
+
+# ----------------------------------------------------------------------------
+# Single-scenario runner
+# 
----------------------------------------------------------------------------
+
+def _run_scenario(
+    name: str,
+    cfg: Dict[str, Any],
+    rank_idx: int, num_ranks: int, group: dist.ProcessGroup,
+    diff_tol: float,
+):
+    """Run one fused-kernel scenario and compare it with the PyTorch reference.
+
+    Args
+    ----
+    name     : scenario label; used for logging and to derive the RNG seed.
+    cfg      : scenario config. Required keys: ``num_max_tokens_per_rank``,
+               ``hidden``, ``intermediate_hidden``, ``num_experts``,
+               ``num_topk``. Optional keys: ``num_tokens`` (defaults to
+               ``num_max_tokens_per_rank``), ``masked_ratio`` (0.0),
+               ``activation_clamp`` (10.0), ``fast_math`` (True).
+    rank_idx, num_ranks, group : distributed context of the calling rank.
+    diff_tol : upper bound (exclusive) on ``calc_diff(fused, reference)``.
+
+    Raises
+    ------
+    AssertionError : if the config is invalid, or if *any* rank of ``group``
+        fails its check. The verdict is all-reduced, so every rank raises (or
+        returns) together, and the symmetric buffer is destroyed and the
+        barrier reached before the failure is raised.
+    """
+    # Local import: only this function needs a process-independent string hash.
+    import zlib
+
+    num_max = cfg['num_max_tokens_per_rank']
+    num_tokens = cfg.get('num_tokens', num_max)
+    hidden = cfg['hidden']
+    intermediate_hidden = cfg['intermediate_hidden']
+    num_experts = cfg['num_experts']
+    num_topk = cfg['num_topk']
+    masked_ratio = cfg.get('masked_ratio', 0.0)
+    activation_clamp = cfg.get('activation_clamp', 10.0)
+    fast_math = cfg.get('fast_math', True)
+
+    assert num_experts % num_ranks == 0, f'{name}: experts {num_experts} not divisible by ranks {num_ranks}'
+    num_experts_per_rank = num_experts // num_ranks
+    assert num_tokens <= num_max
+    assert hidden % 128 == 0 and intermediate_hidden % 128 == 0
+
+    verbose = bool(int(os.environ.get('DG_TEST_VERBOSE', '0')))
+
+    def _trace(stage: str):
+        if verbose:
+            print(f'[rank{rank_idx}] {name} :: {stage}', flush=True)
+
+    _trace('begin')
+    # Built-in hash() of a str is salted per process (PYTHONHASHSEED), which
+    # made the inputs irreproducible between runs. CRC32 is stable.
+    seed = rank_idx * 1000 + zlib.crc32(name.encode('utf-8')) % 1000
+    torch.manual_seed(seed)
+    random.seed(seed)
+
+    # ---- Inputs (bf16) -------------------------------------------------------
+    x_bf = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
+    l1_bf = torch.randn(
+        (num_experts_per_rank, intermediate_hidden * 2, hidden),
+        dtype=torch.bfloat16, device='cuda') * 0.05
+    l2_bf = torch.randn(
+        (num_experts_per_rank, hidden, intermediate_hidden),
+        dtype=torch.bfloat16, device='cuda') * 0.05
+    scores = torch.randn((num_tokens, num_experts), dtype=torch.float, device='cuda')
+    topk_w, topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)
+    if masked_ratio > 0:
+        # Masked selections are marked with expert id -1 and weight 0.
+        rand_mask = torch.rand_like(topk_idx, dtype=torch.float)
+        topk_idx.masked_fill_(rand_mask < masked_ratio, -1)
+        topk_w.masked_fill_(topk_idx < 0, 0)
+
+    # Quantize x to FP8 with per-128 K float SF (SM90 format)
+    x_fp8, x_sf = per_token_cast_to_fp8(x_bf, use_ue8m0=False, gran_k=128,
+                                        use_packed_ue8m0=False)
+    # Quantize weights with block (128, 128) — matches DeepSeekV4FlashFp8 / DeepEP.
+    l1_w_fp8, l1_w_sf = _quantize_grouped_fp8_block_128_128(l1_bf)
+    l2_w_fp8, l2_w_sf = _quantize_grouped_fp8_block_128_128(l2_bf)
+
+    # SM90 weight transform (gate/up interleave only). With block (128, 128)
+    # SF, the SF tensor is consumed by the kernel as-is — no MN-major TMA
+    # transform and no SF-side gate/up interleave is needed.
+    _trace('weight_transform')
+    transformed_l1, transformed_l2 = deep_gemm.transform_weights_for_mega_moe_sm90(
+        (l1_w_fp8, l1_w_sf), (l2_w_fp8, l2_w_sf)
+    )
+
+    # ---- Allocate symm buffer -----------------------------------------------
+    _trace('alloc_symm_buffer')
+    buffer = deep_gemm.get_symm_buffer_for_mega_moe(
+        group, num_experts,
+        num_max, num_topk,
+        hidden, intermediate_hidden,
+    )
+    cum_stats = torch.zeros(num_experts_per_rank, dtype=torch.int, device='cuda')
+
+    # ---- Run fused -----------------------------------------------------------
+    _trace('copy_inputs')
+    buffer.x[:num_tokens].copy_(x_fp8)
+    buffer.x_sf[:num_tokens].copy_(x_sf)
+    buffer.topk_idx[:num_tokens].copy_(topk_idx)
+    buffer.topk_weights[:num_tokens].copy_(topk_w)
+
+    y_fused = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
+    _trace('launch_fused (may JIT-compile, can take minutes)')
+    deep_gemm.fp8_mega_moe(
+        y_fused, transformed_l1, transformed_l2, buffer,
+        cumulative_local_expert_recv_stats=cum_stats,
+        recipe=(128, 128, 128),
+        activation='swiglu',
+        activation_clamp=activation_clamp if math.isfinite(activation_clamp) else None,
+        fast_math=fast_math,
+    )
+    _trace('sync_fused')
+    torch.cuda.synchronize()
+    _trace('fused_done')
+
+    # ---- Reference & check ---------------------------------------------------
+    # Use the FP8 weights and their block-(128, 128) SF directly — the dequant
+    # helper expects this MN/K-block SF layout, and the original (gate||up) row
+    # ordering is what `_swiglu_fp32` splits with ``[..., :IH], [..., IH:]``.
+    _trace('reference')
+    y_ref = _reference_fused(
+        x_fp8, x_sf, topk_idx, topk_w,
+        l1_w_fp8, l1_w_sf, l2_w_fp8, l2_w_sf,
+        rank_idx, num_ranks, group,
+        num_experts, num_topk,
+        hidden, intermediate_hidden,
+        activation_clamp,
+    )
+
+    diff = calc_diff(y_fused, y_ref)
+    diff_ok = bool(diff < diff_tol)
+
+    # NOTE(review): this is only a weak sanity check — it rejects a negative
+    # total but cannot tell whether dispatch actually incremented the counters
+    # (the total can legitimately be 0 if nothing was routed to this rank).
+    stats_ok = True
+    if num_tokens > 0 and masked_ratio < 1.0:
+        stats_ok = bool(cum_stats.sum().item() >= 0)
+
+    # Agree on a single verdict across ranks. A rank-local verdict lets one
+    # rank raise while its peers keep going, after which the ranks issue
+    # different collectives (most visibly with --fail-fast) and hang.
+    local_ok = diff_ok and stats_ok
+    verdict = torch.tensor([int(local_ok)], dtype=torch.int, device='cuda')
+    dist.all_reduce(verdict, op=dist.ReduceOp.MIN, group=group)
+    ok = bool(verdict.item())
+
+    dist_print(f' [{name:<32}] diff={diff:.4f} '
+               f'(tol={diff_tol:.2f}) {"OK" if ok else "FAIL"}',
+               once_in_node=True)
+
+    # Release the symmetric buffer and re-synchronise *before* raising, so a
+    # failing scenario neither leaks the buffer nor skips the barrier.
+    buffer.destroy()
+    dist.barrier()
+
+    if not diff_ok:
+        reason = f'{name}: diff={diff} >= tol={diff_tol}'
+    elif not stats_ok:
+        reason = f'{name}: negative cumulative recv stats on rank {rank_idx}'
+    else:
+        reason = f'{name}: another rank failed (local diff={diff}, tol={diff_tol})'
+    assert ok, reason
+
+
+# ----------------------------------------------------------------------------
+# Scenario tables
+# ----------------------------------------------------------------------------
+
+# A single tiny config used as a smoke test.
+_SMOKE = dict(
+    num_max_tokens_per_rank=64, num_tokens=64,
+    hidden=512, intermediate_hidden=512,
+    num_experts=8, num_topk=2,
+)
+
+
+def _layer1_smoke() -> List[Tuple[str, Dict[str, Any]]]:
+    """Layer 1: the single `_SMOKE` config (copied so callers may mutate it)."""
+    return [('L1.smoke', dict(_SMOKE))]
+
+
+def _layer2_heuristic_branches(num_ranks: int) -> List[Tuple[str, Dict[str, Any]]]:
+    """Vary tokens-per-rank so the ``get_block_config_for_mega_moe_sm90``
+    bands are exercised.
+
+    The heuristic decides on ``avg_tokens_per_expert``; we approximate by
+    sweeping ``num_max_tokens_per_rank`` while keeping ``num_experts``
+    (8 per rank) and ``num_topk`` (2) fixed. The bands are at
+    64.5 / 96.5 / 192.5.
+
+    NOTE(review): with 8 experts per rank and ``num_topk=2`` the average load
+    is ``tokens / 4`` per expert, i.e. 16 / 64 / 128 / 512 for the settings
+    below — confirm these land in the intended bands (nothing here appears
+    to fall between 64.5 and 96.5).
+    """
+    base = dict(hidden=1024, intermediate_hidden=1024,
+                num_experts=8 * num_ranks, num_topk=2)
+    out: List[Tuple[str, Dict[str, Any]]] = []
+    # tokens-per-rank settings chosen to hit (small / mid / large) bands
+    for tokens, label in [(64, 'small'), (256, 'midA'), (512, 'midB'), (2048, 'large')]:
+        cfg = dict(base, num_max_tokens_per_rank=tokens, num_tokens=tokens)
+        out.append((f'L2.heur.{label}.t{tokens}', cfg))
+    return out
+
+
+def _layer3_shape_sweep(num_ranks: int) -> List[Tuple[str, Dict[str, Any]]]:
+    """Layer 3: cross product of ``hidden`` x ``intermediate_hidden`` x ``num_topk``
+    at a fixed 128 tokens per rank and 8 experts per rank."""
+    out: List[Tuple[str, Dict[str, Any]]] = []
+    base_experts = 8 * num_ranks
+    for hidden in (512, 2048):
+        for ih in (512, 2048):
+            for topk in (1, 2, 4):
+                # A token cannot select more experts than exist.
+                if topk > base_experts:
+                    continue
+                cfg = dict(num_max_tokens_per_rank=128, num_tokens=128,
+                           hidden=hidden, intermediate_hidden=ih,
+                           num_experts=base_experts, num_topk=topk)
+                out.append((f'L3.h{hidden}_ih{ih}_k{topk}', cfg))
+    return out
+
+
+def _layer4_edges(num_ranks: int) -> List[Tuple[str, Dict[str, Any]]]:
+    """Layer 4: masking ratios, clamp values, ``fast_math`` and token-count
+    boundaries on a small fixed shape."""
+    base = dict(num_max_tokens_per_rank=128,
+                hidden=512, intermediate_hidden=512,
+                num_experts=8 * num_ranks, num_topk=2)
+    out: List[Tuple[str, Dict[str, Any]]] = []
+    # Masked ratios
+    for mr in (0.0, 0.3, 0.7):
+        out.append((f'L4.mask{mr:.1f}', dict(base, num_tokens=128, masked_ratio=mr)))
+    # All masked
+    out.append(('L4.mask_all', dict(base, num_tokens=128, masked_ratio=1.0)))
+    # Activation clamp variations (finite vs inf)
+    for c in (1.0, 10.0, math.inf):
+        out.append((f'L4.clamp{c}', dict(base, num_tokens=128, activation_clamp=c)))
+    # fast_math toggle
+    for fm in (True, False):
+        out.append((f'L4.fm{int(fm)}', dict(base, num_tokens=128, fast_math=fm)))
+    # num_tokens boundaries
+    out.append(('L4.tokens0', dict(base, num_tokens=0)))
+    out.append(('L4.tokens_max', dict(base, num_tokens=base['num_max_tokens_per_rank'])))
+    return out
+
+
+def 
_layer5_stress(num_ranks: int, num_tests: int) -> List[Tuple[str, Dict[str, Any]]]:
+    """Random configs under simple constraints.
+
+    Uses a fixed-seed private RNG, so every rank (and every run) builds the
+    same list of ``num_tests`` scenarios.
+    """
+    rng = random.Random(0xC0FFEE)
+    out: List[Tuple[str, Dict[str, Any]]] = []
+    for i in range(num_tests):
+        hidden = rng.choice([512, 1024, 2048])
+        ih = rng.choice([512, 1024, 2048])
+        topk = rng.choice([1, 2, 4])
+        tokens = rng.choice([32, 64, 128, 256, 512])
+        masked = rng.choice([0.0, 0.0, 0.3, 0.5])
+        clamp = rng.choice([1.0, 10.0, math.inf])
+        fm = rng.choice([True, False])
+        cfg = dict(num_max_tokens_per_rank=tokens, num_tokens=tokens,
+                   hidden=hidden, intermediate_hidden=ih,
+                   num_experts=8 * num_ranks, num_topk=topk,
+                   masked_ratio=masked, activation_clamp=clamp, fast_math=fm)
+        out.append((f'L5.rand{i:03d}', cfg))
+    return out
+
+
+# ----------------------------------------------------------------------------
+# Entry point
+# ----------------------------------------------------------------------------
+
+def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
+    """Per-process entry point: build the scenario plan and run it.
+
+    Initialises the process group, skips when the device is not SM90, runs
+    every selected scenario, prints a summary and exits with status 1 if any
+    scenario raised ``AssertionError``.
+    """
+    rank_idx, num_ranks, group = init_dist(local_rank, num_local_ranks)
+
+    # Skip on non-SM90
+    arch_major = get_arch_major()
+    if arch_major != 9:
+        dist_print(f'[SKIP] test_mega_moe_sm90 requires SM90; got SM{arch_major}0',
+                   once_in_node=True)
+        dist.destroy_process_group()
+        return
+
+    diff_tol = args.diff_tol
+    layers: List[Tuple[str, Dict[str, Any]]] = []
+
+    if 1 in args.layers:
+        layers += _layer1_smoke()
+    if 2 in args.layers:
+        layers += _layer2_heuristic_branches(num_ranks)
+    if 3 in args.layers:
+        layers += _layer3_shape_sweep(num_ranks)
+    if 4 in args.layers:
+        layers += _layer4_edges(num_ranks)
+    if 5 in args.layers:
+        # Only fall back to 8 when the flag was not given; an explicit 0 now
+        # means "no stress scenarios" instead of silently becoming 8.
+        num_stress = 8 if args.num_correctness_tests is None else args.num_correctness_tests
+        layers += _layer5_stress(num_ranks, num_stress)
+
+    if args.filter:
+        layers = [(n, c) for n, c in layers if args.filter in n]
+
+    dist_print(f'SM90 MegaMoE test plan: {len(layers)} scenarios across '
+               f'layers {sorted(args.layers)} on {num_ranks} ranks',
+               once_in_node=True)
+
+    failures: List[str] = []
+    for name, cfg in layers:
+        try:
+            _run_scenario(name, cfg, rank_idx, num_ranks, group, diff_tol)
+        except AssertionError as ex:
+            dist_print(f' [{name}] FAIL: {ex}', once_in_node=True)
+            failures.append(name)
+            if args.fail_fast:
+                break
+
+    dist_print('', once_in_node=True)
+    if failures:
+        dist_print(f'FAILED {len(failures)}/{len(layers)} scenarios: {failures}',
+                   once_in_node=True)
+    else:
+        dist_print(f'PASSED all {len(layers)} scenarios', once_in_node=True)
+
+    dist.barrier()
+    dist.destroy_process_group()
+    if failures:
+        sys.exit(1)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Layered SM90 MegaMoE tests')
+    parser.add_argument('--num-processes', type=int, default=2,
+                        help='Number of ranks to spawn (default: 2)')
+    # `choices` rejects unknown layer numbers, which were previously ignored
+    # silently and could yield an empty "PASSED all 0 scenarios" run.
+    parser.add_argument('--layers', type=int, nargs='+', default=[1, 2, 3, 4],
+                        choices=(1, 2, 3, 4, 5),
+                        help='Which layers to run (1..5). Default: 1 2 3 4. '
+                             'Layer 5 runs --num-correctness-tests random configs '
+                             '(8 if unset).')
+    parser.add_argument('--num-correctness-tests', type=int, default=None,
+                        help='Layer 5 stress test count')
+    parser.add_argument('--filter', type=str, default='',
+                        help='Substring filter on scenario names')
+    parser.add_argument('--diff-tol', type=float, default=0.07,
+                        help='calc_diff tolerance (default: 0.07)')
+    parser.add_argument('--fail-fast', action='store_true',
+                        help='Stop on first failing scenario')
+    args = parser.parse_args()
+
+    num_processes = args.num_processes
+    torch.multiprocessing.spawn(test, args=(num_processes, args), nprocs=num_processes)
From fb732739a958fc1aabf7cbbe72e7880c7715b453 Mon Sep 17 00:00:00 2001
From: AichenF
Date: Mon, 1 Jun 2026 14:31:27 +0800
Subject: [PATCH 02/15] Layer local SM90 split tuning on PR1

---
 csrc/jit_kernels/heuristics/mega_moe.hpp      | 118 +++++--
 csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp  |  73 ++++-
 .../deep_gemm/impls/sm90_fp8_mega_moe.cuh     |  58 ++--
 .../include/deep_gemm/scheduler/mega_moe.cuh  | 309 +++++++++++++++---
 4 files changed, 469 insertions(+), 89 deletions(-)

diff --git a/csrc/jit_kernels/heuristics/mega_moe.hpp
b/csrc/jit_kernels/heuristics/mega_moe.hpp index 7d5f15e82a..fe310122d1 100644 --- a/csrc/jit_kernels/heuristics/mega_moe.hpp +++ b/csrc/jit_kernels/heuristics/mega_moe.hpp @@ -406,7 +406,9 @@ static Sm90MoeHeuristicPolicy get_sm90_moe_heuristic_policy( static std::tuple get_block_config_for_mega_moe_sm90( const int& num_ranks, const int& num_experts, const int& num_max_tokens_per_rank, const int& num_topk, - const int& num_tokens) { + const int& num_tokens, + const int& split_phase_mode = 0) { + DG_HOST_ASSERT(split_phase_mode >= 0 and split_phase_mode <= 2); // Keep mma.sync decode variants opt-in; the default path uses M64 CTAs. const int num_experts_per_rank = num_experts / num_ranks; const float expected_tokens_per_expert = @@ -415,8 +417,21 @@ static std::tuple get_block_config_for_mega_moe_sm90( ? get_env("DG_SM90_MOE_MMA_SYNC_M") : (get_env("DG_SM90_MOE_MMA_SYNC") != 0 ? 16 : 0); DG_HOST_ASSERT(requested_mma_m == 0 or requested_mma_m == 16 or requested_mma_m == 32); - const int forced_block_m = get_env("DG_SM90_MOE_FORCE_BLOCK_M"); - const int forced_epilogue_warpgroups = get_env("DG_SM90_MOE_FORCE_EPILOGUE_WG"); + const bool split_linear1_only = split_phase_mode == 1; + const bool split_linear2_only = split_phase_mode == 2; + int forced_block_m = 0; + int forced_epilogue_warpgroups = 0; + if (split_linear1_only) { + forced_block_m = get_env("DG_SM90_MOE_L1_FORCE_BLOCK_M"); + forced_epilogue_warpgroups = get_env("DG_SM90_MOE_L1_FORCE_EPILOGUE_WG"); + } else if (split_linear2_only) { + forced_block_m = get_env("DG_SM90_MOE_L2_FORCE_BLOCK_M"); + forced_epilogue_warpgroups = get_env("DG_SM90_MOE_L2_FORCE_EPILOGUE_WG"); + } + if (forced_block_m == 0) + forced_block_m = get_env("DG_SM90_MOE_FORCE_BLOCK_M"); + if (forced_epilogue_warpgroups == 0) + forced_epilogue_warpgroups = get_env("DG_SM90_MOE_FORCE_EPILOGUE_WG"); DG_HOST_ASSERT(forced_block_m == 0 or forced_block_m == 64 or forced_block_m == 128); DG_HOST_ASSERT(forced_epilogue_warpgroups == 0 or 
forced_epilogue_warpgroups == 1 or @@ -424,13 +439,19 @@ static std::tuple get_block_config_for_mega_moe_sm90( const bool use_mma_sync_decode = requested_mma_m > 0 and expected_tokens_per_expert <= static_cast(requested_mma_m); const bool use_bn256_split_n = - get_env("DG_SM90_MOE_BN256_2WG", get_sm90_moe_split_l1_l2_default() ? 1 : 0) != 0 and + (get_env("DG_SM90_MOE_BN256_2WG", get_sm90_moe_split_l1_l2_default() ? 1 : 0) != 0 or + (split_linear1_only and get_env("DG_SM90_MOE_L1_BN256_2WG") != 0) or + (split_linear2_only and get_env("DG_SM90_MOE_L2_BN256_2WG") != 0)) and forced_block_m != 128 and not use_mma_sync_decode; const bool use_bn256_seq_n = - get_env("DG_SM90_MOE_BN256_SEQ") != 0 and + (get_env("DG_SM90_MOE_BN256_SEQ") != 0 or + (split_linear1_only and get_env("DG_SM90_MOE_L1_BN256_SEQ") != 0) or + (split_linear2_only and get_env("DG_SM90_MOE_L2_BN256_SEQ") != 0)) and forced_block_m != 128 and not use_mma_sync_decode; const bool use_b_stationary_2wg = - get_env("DG_SM90_MOE_B_STATIONARY_2WG") != 0 and not use_mma_sync_decode; + (get_env("DG_SM90_MOE_B_STATIONARY_2WG") != 0 or + (split_linear1_only and get_env("DG_SM90_MOE_L1_B_STATIONARY_2WG") != 0)) and + not split_linear2_only and not use_mma_sync_decode; DG_HOST_ASSERT(not (use_bn256_split_n and use_bn256_seq_n)); DG_HOST_ASSERT(not (use_b_stationary_2wg and (use_bn256_split_n or use_bn256_seq_n))); const int block_m = forced_block_m > 0 @@ -505,13 +526,18 @@ static std::pair get_pipeline_config_for_mega_moe_sm90( const int& block_m, const int& block_n, const int& block_k, const int& num_dispatch_warps, const int& num_epilogue_warps, const bool& direct_l2_scatter_default = false, - const int& default_num_stages = 0) { + const int& default_num_stages = 0, + const int& split_phase_mode = 0) { constexpr int kSmemAlignment = 1024; + DG_HOST_ASSERT(split_phase_mode >= 0 and split_phase_mode <= 2); + const bool split_linear1_only = split_phase_mode == 1; + const bool split_linear2_only = split_phase_mode == 
2; - // Dispatch region (same as SM100) - const int smem_expert_count_size = align( + // Dispatch region. Split K2 does not pull remote tokens, so its dynamic + // shared memory does not need expert-count scratch or dispatch buffers. + const int smem_expert_count_size = split_linear2_only ? 0 : align( num_experts * static_cast(sizeof(uint32_t)), kSmemAlignment); - const int smem_send_buffers_size = align( + const int smem_send_buffers_size = split_linear2_only ? 0 : align( static_cast(layout::Buffer(layout::Data(hidden), num_dispatch_warps, 1).get_num_bytes()), kSmemAlignment); const int smem_dispatch_size = smem_expert_count_size + smem_send_buffers_size; @@ -526,16 +552,19 @@ static std::pair get_pipeline_config_for_mega_moe_sm90( const int wg_block_m = split_n_warpgroups ? block_m : block_m / num_epilogue_warpgroups; const int wg_block_n = (split_n_warpgroups or serial_n_warpgroups) ? block_n / 2 : block_n; const int smem_cd_accum = (block_m == 16 or block_m == 32) ? align(block_m * block_n * static_cast(sizeof(float)), kSmemAlignment) : 0; - const int smem_cd_l1 = num_epilogue_warpgroups * wg_block_m * (wg_block_n / 2); // 1 byte/elem (FP8) - const bool direct_l2_scatter = get_env( + const int smem_cd_l1 = split_linear2_only ? 0 : + num_epilogue_warpgroups * wg_block_m * (wg_block_n / 2); // 1 byte/elem (FP8) + const bool direct_l2_scatter = (not split_linear1_only) and + get_env( "DG_SM90_MOE_DIRECT_L2_SCATTER", direct_l2_scatter_default ? 1 : 0) != 0 and block_m != 16 and block_m != 32 and not serial_n_warpgroups and wg_block_n == 128; - const bool async_l1_tma_store = get_env("DG_SM90_MOE_ASYNC_L1_STORE", 0) != 0 and + const bool async_l1_tma_store = (not split_linear2_only) and + get_env("DG_SM90_MOE_ASYNC_L1_STORE", 0) != 0 and block_m != 16 and block_m != 32 and not split_n_warpgroups and num_epilogue_warpgroups == 1; - const int smem_cd_l2 = direct_l2_scatter ? 0 : + const int smem_cd_l2 = (split_linear1_only or direct_l2_scatter) ? 
0 : num_epilogue_warpgroups * wg_block_m * wg_block_n * static_cast(sizeof(nv_bfloat16)); const int smem_cd_l1_async = async_l1_tma_store ? 2 * num_epilogue_warpgroups * wg_block_m * (block_n / 2) : 0; @@ -547,7 +576,8 @@ static std::pair get_pipeline_config_for_mega_moe_sm90( // * SFB is loaded directly from global by the math warpgroup (block-(128,128) // weight quantization), so no SMEM is reserved for it. const int smem_sfa_half_stride_bytes = align(block_m * static_cast(sizeof(float)), 128); - const int smem_sfa_per_stage = 2 * smem_sfa_half_stride_bytes; + const int smem_sfa_per_stage = split_linear1_only + ? smem_sfa_half_stride_bytes : 2 * smem_sfa_half_stride_bytes; const int smem_sfb_per_stage = 0; // Per-stage: A tile + B tile + SFA tile + SFB tile @@ -572,7 +602,13 @@ static std::pair get_pipeline_config_for_mega_moe_sm90( const int preferred_num_stages = default_num_stages > 0 ? std::min(default_num_stages, max_num_stages) : (prefer_bn256_split ? std::min(4, max_num_stages) : 0); - const int forced_num_stages = get_env("DG_SM90_MOE_NUM_STAGES"); + int forced_num_stages = 0; + if (split_linear1_only) + forced_num_stages = get_env("DG_SM90_MOE_L1_NUM_STAGES"); + else if (split_linear2_only) + forced_num_stages = get_env("DG_SM90_MOE_L2_NUM_STAGES"); + if (forced_num_stages == 0) + forced_num_stages = get_env("DG_SM90_MOE_NUM_STAGES"); const int num_stages = forced_num_stages > 0 ? forced_num_stages : (preferred_num_stages > 0 ? 
preferred_num_stages : max_num_stages); @@ -585,21 +621,40 @@ static MegaMoESM90Config get_mega_moe_config_sm90( const int& num_ranks, const int& num_experts, const int& num_experts_per_rank, const int& num_max_tokens_per_rank, const int& num_tokens, const int& num_topk, const int& hidden, const int& intermediate_hidden, - const int& num_padded_sf_pool_tokens) { + const int& num_padded_sf_pool_tokens, + const int& split_phase_mode = 0) { + DG_HOST_ASSERT(split_phase_mode >= 0 and split_phase_mode <= 2); + const bool split_linear1_only = split_phase_mode == 1; + const bool split_linear2_only = split_phase_mode == 2; const auto [block_m, num_epilogue_threads] = get_block_config_for_mega_moe_sm90( - num_ranks, num_experts, num_max_tokens_per_rank, num_topk, num_tokens); + num_ranks, num_experts, num_max_tokens_per_rank, num_topk, num_tokens, split_phase_mode); const bool use_bn256_split_n = - get_env("DG_SM90_MOE_BN256_2WG", get_sm90_moe_split_l1_l2_default() ? 1 : 0) != 0 and + (get_env("DG_SM90_MOE_BN256_2WG", get_sm90_moe_split_l1_l2_default() ? 1 : 0) != 0 or + (split_linear1_only and get_env("DG_SM90_MOE_L1_BN256_2WG") != 0) or + (split_linear2_only and get_env("DG_SM90_MOE_L2_BN256_2WG") != 0)) and block_m == 64; const bool use_bn256_seq_n = - get_env("DG_SM90_MOE_BN256_SEQ") != 0 and block_m == 64; + (get_env("DG_SM90_MOE_BN256_SEQ") != 0 or + (split_linear1_only and get_env("DG_SM90_MOE_L1_BN256_SEQ") != 0) or + (split_linear2_only and get_env("DG_SM90_MOE_L2_BN256_SEQ") != 0)) and + block_m == 64; DG_HOST_ASSERT(not (use_bn256_split_n and use_bn256_seq_n)); - const int block_n = (use_bn256_split_n or use_bn256_seq_n) ? 
256 : 128; + int forced_block_n = 0; + if (split_linear1_only) + forced_block_n = get_env("DG_SM90_MOE_L1_FORCE_BLOCK_N"); + else if (split_linear2_only) + forced_block_n = get_env("DG_SM90_MOE_L2_FORCE_BLOCK_N"); + if (forced_block_n == 0) + forced_block_n = get_env("DG_SM90_MOE_FORCE_BLOCK_N"); + DG_HOST_ASSERT(forced_block_n == 0 or forced_block_n == 128 or forced_block_n == 256); + const int block_n = forced_block_n > 0 ? forced_block_n : ((use_bn256_split_n or use_bn256_seq_n) ? 256 : 128); + DG_HOST_ASSERT((not use_bn256_split_n) or block_n == 256); DG_HOST_ASSERT((not use_bn256_split_n) or num_epilogue_threads == 256); const int block_k = 128; // Default remains cluster_size=1; the opt-in cluster path uses B multicast // across adjacent M blocks. - const bool use_b_stationary_2wg = get_env("DG_SM90_MOE_B_STATIONARY_2WG") != 0; + const bool use_b_stationary_2wg = get_env("DG_SM90_MOE_B_STATIONARY_2WG") != 0 or + (split_linear1_only and get_env("DG_SM90_MOE_L1_B_STATIONARY_2WG") != 0); const bool use_cluster_bcast_b = get_env("DG_SM90_MOE_CLUSTER_BCAST_B") != 0 or use_b_stationary_2wg; DG_HOST_ASSERT((not use_cluster_bcast_b) or ((block_m == 64 and block_n == 128 and num_epilogue_threads == 128) or @@ -611,9 +666,19 @@ static MegaMoESM90Config get_mega_moe_config_sm90( const int swizzle_weights_mode = (block_m == 16 or block_m == 32) ? 
0 : 128; const int num_sms = device_runtime->get_num_sms(); - const int num_experts_per_wave = get_num_experts_per_wave_for_mega_moe_sm90( + int num_experts_per_wave = get_num_experts_per_wave_for_mega_moe_sm90( num_experts_per_rank, num_tokens, num_topk, intermediate_hidden, block_m, block_n, num_sms); + int forced_phase_experts_per_wave = 0; + if (split_linear1_only) + forced_phase_experts_per_wave = get_env("DG_SM90_MOE_L1_EXPERTS_PER_WAVE"); + else if (split_linear2_only) + forced_phase_experts_per_wave = get_env("DG_SM90_MOE_L2_EXPERTS_PER_WAVE"); + if (forced_phase_experts_per_wave > 0) { + DG_HOST_ASSERT(forced_phase_experts_per_wave <= num_experts_per_rank); + DG_HOST_ASSERT(num_experts_per_rank % forced_phase_experts_per_wave == 0); + num_experts_per_wave = forced_phase_experts_per_wave; + } const bool split_sfa_tma = get_env("DG_SM90_MOE_SPLIT_SFA_TMA", 0) != 0; const bool prefer_compact_frontend = get_sm90_moe_split_l1_l2_default() and block_n == 256 and not split_sfa_tma; @@ -647,7 +712,8 @@ static MegaMoESM90Config get_mega_moe_config_sm90( block_m, block_n, block_k, num_dispatch_threads / 32, num_epilogue_threads / 32, direct_l2_scatter_default, - default_num_stages); + default_num_stages, + split_phase_mode); const auto config = MegaMoESM90Config { block_m, block_n, block_k, @@ -661,8 +727,8 @@ static MegaMoESM90Config get_mega_moe_config_sm90( if (get_env("DG_JIT_DEBUG") or get_env("DG_PRINT_CONFIGS")) { const auto key = fmt::format( - "MegaMoESM90Config(num_ranks={}, num_experts={}, hidden={}, intermediate_hidden={}, num_max_tokens_per_rank={}, num_tokens={}, num_topk={})", - num_ranks, num_experts, hidden, intermediate_hidden, num_max_tokens_per_rank, num_tokens, num_topk); + "MegaMoESM90Config(num_ranks={}, num_experts={}, hidden={}, intermediate_hidden={}, num_max_tokens_per_rank={}, num_tokens={}, num_topk={}, split_phase_mode={})", + num_ranks, num_experts, hidden, intermediate_hidden, num_max_tokens_per_rank, num_tokens, num_topk, 
split_phase_mode); static std::unordered_set printed; if (printed.count(key) == 0) { std::cout << key << ": " << config << std::endl; diff --git a/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp b/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp index 6411012495..60d13902f4 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp @@ -48,7 +48,11 @@ class SM90FP8MegaMoERuntime final : public LaunchRuntime bool l2_nmajor_schedule; bool l1_nmajor_schedule; bool one_warp_cleanup; + int l2_m_swizzle_group; + int l1_m_swizzle_group; int split_phase_mode; + int expert_range_start; + int expert_range_end; MegaMoESM90Config config; // Runtime arguments @@ -104,6 +108,10 @@ static void __instantiate_kernel() {{ {}, {}, {}, + {}, + {}, + {}, + {}, {} >); }}; @@ -130,7 +138,11 @@ static void __instantiate_kernel() {{ args.l2_nmajor_schedule ? "true" : "false", args.l1_nmajor_schedule ? "true" : "false", args.one_warp_cleanup ? "true" : "false", - args.split_phase_mode); + args.l2_m_swizzle_group, + args.l1_m_swizzle_group, + args.split_phase_mode, + args.expert_range_start, + args.expert_range_end); } static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { @@ -266,7 +278,11 @@ static void sm90_fp8_mega_moe( .one_warp_cleanup = get_env( "DG_SM90_MOE_ONE_WARP_CLEANUP", one_warp_cleanup_default ? 
1 : 0) != 0, + .l2_m_swizzle_group = get_env("DG_SM90_MOE_L2_M_SWIZZLE_GROUP", 0), + .l1_m_swizzle_group = get_env("DG_SM90_MOE_L1_M_SWIZZLE_GROUP", 0), .split_phase_mode = 0, + .expert_range_start = get_env("DG_SM90_MOE_EXPERT_RANGE_START", 0), + .expert_range_end = get_env("DG_SM90_MOE_EXPERT_RANGE_END", 0), .config = config, .y = y.data_ptr(), .cumulative_local_expert_recv_stats = cumulative_local_expert_recv_stats_ptr, @@ -287,6 +303,53 @@ static void sm90_fp8_mega_moe( const auto launch_with_split_mode = [&](const int split_phase_mode, const char* kernel_name) { auto split_args = args; split_args.split_phase_mode = split_phase_mode; + if (split_phase_mode != 0 and get_env("DG_SM90_MOE_SPLIT_PHASE_CONFIG", 1) != 0) { + split_args.config = get_mega_moe_config_sm90( + num_ranks, num_experts, num_experts_per_rank, + num_max_tokens_per_rank, num_tokens, num_topk, + hidden, intermediate_hidden, num_padded_sf_pool_tokens, + split_phase_mode); + split_args.launch_args = LaunchArgs( + num_sms, + split_args.config.num_dispatch_threads + split_args.config.num_non_epilogue_threads + split_args.config.num_epilogue_threads, + split_args.config.smem_size, split_args.config.cluster_size); + + const auto& phase_config = split_args.config; + split_args.tensor_map_l1_acts = make_tma_2d_desc( + l1_acts, hidden, phase_config.num_max_pool_tokens, + phase_config.block_k, phase_config.block_m, + static_cast(l1_acts.stride(-2)), phase_config.swizzle_acts_mode); + split_args.tensor_map_l1_acts_sf = make_tma_sf_desc( + cute::UMMA::Major::MN, l1_acts_sf, + phase_config.num_padded_sf_pool_tokens, hidden, + phase_config.block_m, kGranK, 1, 0); + split_args.tensor_map_l1_weights = make_tma_2d_desc( + l1_weights, hidden, num_experts_per_rank * intermediate_hidden * 2, + phase_config.block_k, phase_config.block_n, + static_cast(l1_weights.stride(-2)), phase_config.swizzle_weights_mode); + + const int phase_num_epilogue_warpgroups = phase_config.num_epilogue_threads / 128; + const bool 
phase_split_n_warpgroups = + phase_config.block_m == 64 and phase_config.block_n == 256 and phase_num_epilogue_warpgroups == 2; + const int phase_wg_block_m = phase_split_n_warpgroups + ? phase_config.block_m : phase_config.block_m / phase_num_epilogue_warpgroups; + split_args.tensor_map_l1_output = make_tma_2d_desc( + l2_acts, intermediate_hidden, phase_config.num_max_pool_tokens, + phase_config.block_n / 2, phase_wg_block_m, + static_cast(l2_acts.stride(-2)), 0); + split_args.tensor_map_l2_acts = make_tma_2d_desc( + l2_acts, intermediate_hidden, phase_config.num_max_pool_tokens, + phase_config.block_k, phase_config.block_m, + static_cast(l2_acts.stride(-2)), phase_config.swizzle_acts_mode); + split_args.tensor_map_l2_acts_sf = make_tma_sf_desc( + cute::UMMA::Major::MN, l2_acts_sf, + phase_config.num_padded_sf_pool_tokens, intermediate_hidden, + phase_config.block_m, kL2ActsSFGranK, 1, 0); + split_args.tensor_map_l2_weights = make_tma_2d_desc( + l2_weights, intermediate_hidden, num_experts_per_rank * hidden, + phase_config.block_k, phase_config.block_n, + static_cast(l2_weights.stride(-2)), phase_config.swizzle_weights_mode); + } const auto code = SM90FP8MegaMoERuntime::generate(split_args); const auto runtime = compiler->build(kernel_name, code); SM90FP8MegaMoERuntime::launch(runtime, split_args); @@ -294,8 +357,12 @@ static void sm90_fp8_mega_moe( const bool split_l1_l2 = get_env("DG_SM90_MOE_SPLIT_L1_L2", 1) != 0; if (split_l1_l2) { - launch_with_split_mode(1, "sm90_fp8_mega_moe_split_l1"); - launch_with_split_mode(2, "sm90_fp8_mega_moe_split_l2"); + const int split_only_phase = get_env("DG_SM90_MOE_SPLIT_ONLY_PHASE", 0); + DG_HOST_ASSERT(split_only_phase >= 0 and split_only_phase <= 2); + if (split_only_phase == 0 or split_only_phase == 1) + launch_with_split_mode(1, "sm90_fp8_mega_moe_split_l1"); + if (split_only_phase == 0 or split_only_phase == 2) + launch_with_split_mode(2, "sm90_fp8_mega_moe_split_l2"); } else { launch_with_split_mode(0, 
"sm90_fp8_mega_moe"); } diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh index c98c326700..14c8f7d7d3 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh @@ -74,7 +74,11 @@ template < bool kL2NMajorScheduleRequested = false, bool kL1NMajorScheduleRequested = false, bool kOneWarpCleanupRequested = false, + uint32_t kL2MSwizzleGroupRequested = 0, + uint32_t kL1MSwizzleGroupRequested = 0, uint32_t kSplitPhaseMode = 0, + uint32_t kExpertRangeStart = 0, + uint32_t kExpertRangeEnd = 0, uint32_t L1_SHAPE_N = kIntermediateHidden * 2, uint32_t L1_SHAPE_K = kHidden, uint32_t L2_SHAPE_N = kHidden, @@ -120,6 +124,8 @@ sm90_fp8_mega_moe_impl(void* y, DG_STATIC_ASSERT(BLOCK_N == 64 or BLOCK_N == 128 or BLOCK_N == 256, "BLOCK_N must be 64/128/256 for this SM90 path"); DG_STATIC_ASSERT(BLOCK_K == 128, "BLOCK_K is fixed to 128 (per-128 SF)"); DG_STATIC_ASSERT(kSplitPhaseMode <= 2, "Invalid SM90 MegaMoE split phase mode"); + DG_STATIC_ASSERT(kExpertRangeStart <= (kExpertRangeEnd == 0 ? kNumExpertsPerRank : kExpertRangeEnd), "Invalid expert range"); + DG_STATIC_ASSERT((kExpertRangeEnd == 0 ? kNumExpertsPerRank : kExpertRangeEnd) <= kNumExpertsPerRank, "Expert range exceeds local experts"); // ===================================================================== // Thread / warp identification @@ -197,15 +203,16 @@ sm90_fp8_mega_moe_impl(void* y, constexpr uint32_t WG_L1_OUT_BLOCK_N = WG_BLOCK_N / 2; // post-SwiGLU per-WG N constexpr bool kRunOnlyLinear1 = kSplitPhaseMode == 1; constexpr bool kRunOnlyLinear2 = kSplitPhaseMode == 2; + constexpr uint32_t kEffectiveExpertRangeEnd = kExpertRangeEnd == 0 ? 
kNumExpertsPerRank : kExpertRangeEnd; constexpr bool kAsyncL1TMAStore = - kAsyncL1TMAStoreRequested && (!kUseMMASync) && + kAsyncL1TMAStoreRequested && (!kRunOnlyLinear2) && (!kUseMMASync) && (!kSplitNWarpgroups) && kNumEpilogueWarpgroups == 1; constexpr bool kSplitSFATMA = kSplitSFATMARequested && (!kUseMMASync); - constexpr bool kDirectL2Scatter = kDirectL2ScatterRequested && (!kUseMMASync) && + constexpr bool kDirectL2Scatter = kDirectL2ScatterRequested && (!kRunOnlyLinear1) && (!kUseMMASync) && (!kSerialNWarpgroups) && WG_BLOCK_N == 128; - constexpr bool kL2DualAccum = kL2DualAccumRequested && (!kUseMMASync) && + constexpr bool kL2DualAccum = kL2DualAccumRequested && (!kRunOnlyLinear1) && (!kUseMMASync) && (!kSplitNWarpgroups) && (!kSerialNWarpgroups) && WG_BLOCK_N == 128; - constexpr bool kL1DualKAccum = kL1DualKAccumRequested && (!kUseMMASync) && + constexpr bool kL1DualKAccum = kL1DualKAccumRequested && (!kRunOnlyLinear2) && (!kUseMMASync) && (!kSplitNWarpgroups) && (!kSerialNWarpgroups) && WG_BLOCK_N == 128 && (kHidden / BLOCK_K) % 2 == 0; using L1WGMMA = typename mma::sm90::FP8MMASelector::type; @@ -231,9 +238,9 @@ sm90_fp8_mega_moe_impl(void* y, constexpr uint32_t kSharedMemoryAlignment = 1024; extern __shared__ __align__(kSharedMemoryAlignment) uint8_t smem_buffer[]; - constexpr uint32_t SMEM_EXPERT_COUNT_SIZE = + constexpr uint32_t SMEM_EXPERT_COUNT_SIZE = kRunOnlyLinear2 ? 0u : math::constexpr_align(kNumExperts * sizeof(uint32_t), kSharedMemoryAlignment); - constexpr uint32_t SMEM_SEND_BUFFER_SIZE = + constexpr uint32_t SMEM_SEND_BUFFER_SIZE = kRunOnlyLinear2 ? 
0u : math::constexpr_align(fp8_token_layout.get_num_bytes() * kNumDispatchWarps, kSharedMemoryAlignment); constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t); constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t); @@ -242,7 +249,8 @@ sm90_fp8_mega_moe_impl(void* y, // the second L2 half cannot start immediately after 16 floats in M16 decode. constexpr uint32_t kL2SFAHalfStride = math::constexpr_align(BLOCK_M * sizeof(float), 128u) / sizeof(float); - constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = 2 * kL2SFAHalfStride * sizeof(float); + constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = + (kRunOnlyLinear1 ? kL2SFAHalfStride : 2u * kL2SFAHalfStride) * sizeof(float); // Block (128, 128) weight SF: 1 float per (BLOCK_N, BLOCK_K) tile for L2, // 2 floats (gate/up) for L1. Loaded by math warpgroup directly from global, // so no SMEM is needed. @@ -253,9 +261,9 @@ sm90_fp8_mega_moe_impl(void* y, constexpr uint32_t SMEM_CD_ACCUM_SIZE = kUseMMASync ? math::constexpr_align(BLOCK_M * BLOCK_N * sizeof(float), kSharedMemoryAlignment) : 0u; - constexpr uint32_t SMEM_CD_L1_SIZE = + constexpr uint32_t SMEM_CD_L1_SIZE = kRunOnlyLinear2 ? 0u : kNumEpilogueWarpgroups * WG_BLOCK_M * WG_L1_OUT_BLOCK_N * sizeof(cutlass::float_e4m3_t); - constexpr uint32_t SMEM_CD_L2_SIZE = kDirectL2Scatter ? 0u : + constexpr uint32_t SMEM_CD_L2_SIZE = (kRunOnlyLinear1 || kDirectL2Scatter) ? 
0u : kNumEpilogueWarpgroups * WG_BLOCK_M * WG_BLOCK_N * sizeof(nv_bfloat16); constexpr uint32_t SMEM_CD_L1_ASYNC_ELEMS = kNumEpilogueWarpgroups * WG_BLOCK_M * L1_OUT_BLOCK_N; @@ -314,10 +322,12 @@ sm90_fp8_mega_moe_impl(void* y, // Initialization // ===================================================================== if (warp_idx == 0) { - // Clean expert-count shared memory - #pragma unroll - for (uint32_t i = lane_idx; i < kNumExperts; i += 32) - ptx::st_shared(smem_expert_count + i, 0u); + if constexpr (!kRunOnlyLinear2) { + // Clean expert-count shared memory + #pragma unroll + for (uint32_t i = lane_idx; i < kNumExperts; i += 32) + ptx::st_shared(smem_expert_count + i, 0u); + } } else if (warp_idx == 1) { // Init dispatch m-barriers #pragma unroll @@ -358,7 +368,9 @@ sm90_fp8_mega_moe_impl(void* y, L1_SHAPE_N, L1_SHAPE_K, L2_SHAPE_N, L2_SHAPE_K, kNumExpertsPerRank, kNumExpertsPerWave, - kNumSMs, kNumRanks, kClusterSize, kL2NMajorScheduleRequested, kL1NMajorScheduleRequested>(workspace); + kNumSMs, kNumRanks, kClusterSize, kL2NMajorScheduleRequested, kL1NMajorScheduleRequested, + kL2MSwizzleGroupRequested, kL1MSwizzleGroupRequested, + kExpertRangeStart, kEffectiveExpertRangeEnd>(workspace); // Pipeline state shared by TMA loaders and math warpgroups uint32_t stage_idx = 0, phase = 0; @@ -892,9 +904,10 @@ sm90_fp8_mega_moe_impl(void* y, const uint32_t k_idx = k_block_idx * BLOCK_K; // TMA load B (weight SF is now loaded directly by math warps from global) + const uint32_t num_tma_multicast_b = (kClusterSize > 1 and scheduler.is_b_multicast_valid) ? kClusterSize : 1u; tma::copy( tensor_map_b_ptr, full_barriers[stage_idx], smem_b[stage_idx], - k_idx, n_idx, kClusterSize); + k_idx, n_idx, num_tma_multicast_b); full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_B_SIZE_PER_STAGE); } @@ -993,8 +1006,9 @@ sm90_fp8_mega_moe_impl(void* y, if (lane_idx == 0) empty_barriers[s]->arrive(); } else { + const auto target_cta = scheduler.is_peer_cta_alive ? 
lane_idx : cute::block_rank_in_cluster(); if (lane_idx < kClusterSize) - empty_barriers[s]->arrive(lane_idx); + empty_barriers[s]->arrive(target_cta); } }; @@ -2384,9 +2398,12 @@ for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_bl if (epilogue_warp_idx == 0 and lane_idx == 0) phase_profile_record(kProfileCombineBarrier, combine_barrier_end - combine_barrier_start); - // Sync with dispatch (paired with dispatch's pre-cleanup sync) so that - // dispatch may now safely clean workspace state. - ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + // Fused mode overlaps dispatch cleanup with combine reduce. In split K2, + // delay the cleanup sync until the combine loads are issued so K2's + // epilogue-side memory traffic does not fight the reducer. + if constexpr (!kRunOnlyLinear2) { + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + } const unsigned long long combine_reduce_start = phase_profile_clock(); constexpr uint32_t kNumHiddenBytes = kHidden * sizeof(nv_bfloat16); @@ -2496,6 +2513,9 @@ for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_bl const unsigned long long combine_reduce_end = phase_profile_clock(); if (epilogue_warp_idx == 0 and lane_idx == 0) phase_profile_record(kProfileCombineReduce, combine_reduce_end - combine_reduce_start); + if constexpr (kRunOnlyLinear2) { + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + } } #else if (blockIdx.x == 0 and threadIdx.x == 0) diff --git a/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh b/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh index dd3b077df6..14206464a7 100644 --- a/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh @@ -25,6 +25,10 @@ template (math::align(current_local_expert_idx + 1, kNumExpertsPerWave), kExpertRangeEnd); } CUTLASS_DEVICE 
uint32_t get_num_tokens(const uint32_t& expert_idx) const { @@ -92,8 +106,47 @@ struct MegaMoEScheduler { return __reduce_add_sync(0xffffffff, num_blocks); } + template + CUTLASS_DEVICE uint32_t get_packed_l2_token_offset(const uint32_t& expert_idx) { + uint32_t num_tokens = 0; + #pragma unroll + for (uint32_t i = 0; i < kNumExpertsPerLane; ++ i) { + if (i * 32 + ptx::get_lane_idx() < expert_idx) + num_tokens += math::align(stored_num_tokens_per_expert[i], kPackedAlignment); + } + return __reduce_add_sync(0xffffffff, num_tokens); + } + + template + CUTLASS_DEVICE uint32_t get_packed_l2_block_offset(const uint32_t& expert_idx) { + DG_STATIC_ASSERT(kPackedAlignment % BLOCK_M == 0, "Packed L2 alignment must be a multiple of BLOCK_M"); + return get_packed_l2_token_offset(expert_idx) / BLOCK_M; + } + + template + CUTLASS_DEVICE uint32_t get_aligned_pool_token_offset(const uint32_t& expert_idx) { + uint32_t num_tokens = 0; + #pragma unroll + for (uint32_t i = 0; i < kNumExpertsPerLane; ++ i) { + if (i * 32 + ptx::get_lane_idx() < expert_idx) + num_tokens += math::ceil_div(stored_num_tokens_per_expert[i], kPoolBlockM) * kPoolBlockM; + } + return __reduce_add_sync(0xffffffff, num_tokens); + } + + CUTLASS_DEVICE uint32_t get_token_offset(const uint32_t& expert_idx) { + uint32_t num_tokens = 0; + #pragma unroll + for (uint32_t i = 0; i < kNumExpertsPerLane; ++ i) { + if (i * 32 + ptx::get_lane_idx() < expert_idx) + num_tokens += stored_num_tokens_per_expert[i]; + } + return __reduce_add_sync(0xffffffff, num_tokens); + } + CUTLASS_DEVICE void advance_expert_idx() { current_pool_block_offset += get_current_num_m_blocks(); + current_token_offset += current_num_tokens; current_local_expert_idx += 1; current_num_tokens = get_num_tokens(current_local_expert_idx); } @@ -102,12 +155,17 @@ struct MegaMoEScheduler { current_local_expert_idx = expert_idx; current_num_tokens = get_num_tokens(expert_idx); current_pool_block_offset = get_pool_block_offset(expert_idx); + 
current_token_offset = get_token_offset(expert_idx); } CUTLASS_DEVICE uint32_t get_current_pool_block_offset() const { return current_pool_block_offset; } + CUTLASS_DEVICE uint32_t get_current_token_offset() const { + return current_token_offset; + } + CUTLASS_DEVICE uint32_t get_current_num_m_blocks() const { return math::ceil_div(current_num_tokens, BLOCK_M); } @@ -121,26 +179,134 @@ struct MegaMoEScheduler { return kDoUMMAAligned ? math::align(m, 16u) : m; } + template + CUTLASS_DEVICE void map_expert_block_idx(const uint32_t& local_block_idx, + const uint32_t& num_m_units, + uint32_t& m_unit_idx, + uint32_t& n_idx) const { + if constexpr (kMSwizzleGroup > 0) { + const auto group_start = (local_block_idx / (kMSwizzleGroup * kNumBlockNs)) * kMSwizzleGroup; + const auto in_group_idx = local_block_idx % (kMSwizzleGroup * kNumBlockNs); + const auto num_m_in_group = cute::min(kMSwizzleGroup, num_m_units - group_start); + n_idx = in_group_idx / num_m_in_group; + m_unit_idx = group_start + in_group_idx - n_idx * num_m_in_group; + } else if constexpr (kNMajorSchedule) { + n_idx = local_block_idx / num_m_units; + m_unit_idx = local_block_idx - n_idx * num_m_units; + } else { + m_unit_idx = local_block_idx / kNumBlockNs; + n_idx = local_block_idx % kNumBlockNs; + } + } + + CUTLASS_DEVICE void update_peer_cta_alive(const uint32_t& num_m_blocks) { + if constexpr (kClusterSize == 1) { + is_peer_cta_alive = true; + is_a_multicast_valid = true; + is_b_multicast_valid = true; + } else { + const auto peer_m_block_idx = (m_block_idx ^ 1u); + is_peer_cta_alive = m_block_idx < num_m_blocks and + peer_m_block_idx < num_m_blocks; + is_a_multicast_valid = false; + is_b_multicast_valid = is_peer_cta_alive; + } + } + + template + CUTLASS_DEVICE void map_cluster_bcast_b_block_idx(const uint32_t& local_block_idx, + const uint32_t& num_m_blocks, + uint32_t& m_idx, + uint32_t& n_idx) { + DG_STATIC_ASSERT(kClusterSize == 2, "Cluster B multicast mapping requires 2 CTAs"); + 
DG_STATIC_ASSERT(kMSwizzleGroup == 0 or kMSwizzleGroup == 4 or kMSwizzleGroup == 8 or kMSwizzleGroup == 16, + "Invalid cluster B multicast M-swizzle group"); + + if constexpr (kMSwizzleGroup == 0) { + const auto even_m_blocks = num_m_blocks & ~1u; + const auto even_region_blocks = even_m_blocks * kNumBlockNs; + if (local_block_idx < even_region_blocks) { + n_idx = local_block_idx / even_m_blocks; + m_idx = local_block_idx - n_idx * even_m_blocks; + is_peer_cta_alive = true; + is_a_multicast_valid = false; + is_b_multicast_valid = true; + } else { + const auto tail_idx = local_block_idx - even_region_blocks; + m_idx = even_m_blocks; + n_idx = tail_idx; + // Odd-M tails pair CTAs across adjacent N blocks. The peer CTA is + // still alive for remote empty-barrier arrivals, but it does not + // consume the same B tile, so B multicast must be disabled. + is_peer_cta_alive = (tail_idx ^ 1u) < kNumBlockNs; + is_a_multicast_valid = false; + is_b_multicast_valid = false; + } + } else { + const auto num_blocks_per_group = kMSwizzleGroup * kNumBlockNs; + const auto group_idx = local_block_idx / num_blocks_per_group; + const auto first_m_idx = group_idx * kMSwizzleGroup; + auto in_group_idx = local_block_idx - group_idx * num_blocks_per_group; + auto num_m_in_group = cute::min(kMSwizzleGroup, num_m_blocks - first_m_idx); + + if (num_m_in_group % 2 != 0) { + const auto even_m_in_group = num_m_in_group ^ 1u; + const auto even_region_blocks = even_m_in_group * kNumBlockNs; + if (in_group_idx < even_region_blocks) { + num_m_in_group = even_m_in_group; + } else { + in_group_idx -= even_region_blocks; + m_idx = first_m_idx + even_m_in_group; + n_idx = in_group_idx; + is_peer_cta_alive = (in_group_idx ^ 1u) < kNumBlockNs; + is_a_multicast_valid = false; + is_b_multicast_valid = false; + return; + } + } + + m_idx = first_m_idx + in_group_idx % num_m_in_group; + n_idx = in_group_idx / num_m_in_group; + is_peer_cta_alive = true; + is_a_multicast_valid = false; + is_b_multicast_valid = 
true; + } + } + + template + CUTLASS_DEVICE void map_cluster_bcast_a_block_idx(const uint32_t& local_block_idx, + const uint32_t& num_m_blocks, + uint32_t& m_idx, + uint32_t& n_idx) { + (void)num_m_blocks; + DG_STATIC_ASSERT(kClusterSize == 2, "Cluster A multicast mapping requires 2 CTAs"); + DG_STATIC_ASSERT(kNumBlockNs % 2 == 0, "N block count must be even for 2-CTA A multicast"); + m_idx = local_block_idx / kNumBlockNs; + n_idx = local_block_idx - m_idx * kNumBlockNs; + is_peer_cta_alive = true; + is_a_multicast_valid = true; + is_b_multicast_valid = false; + } + CUTLASS_DEVICE bool fetch_next_l1_block() { const auto wave_end_expert_idx = get_wave_expert_end_idx(); while (current_local_expert_idx < wave_end_expert_idx) { const auto num_m_blocks = get_current_num_m_blocks(); - const auto num_m_units = math::ceil_div(num_m_blocks, kClusterSize); - if (block_idx < num_m_units * kNumL1BlockNs) { - if constexpr (kL1NMajorSchedule) { - n_block_idx = block_idx / num_m_units; - const auto m_unit_idx = block_idx - n_block_idx * num_m_units; - m_block_idx = m_unit_idx * kClusterSize + cute::block_rank_in_cluster(); + const auto num_blocks = num_m_blocks * kNumL1BlockNs; + if (block_idx < num_blocks) { + if constexpr (kClusterSize == 1) { + map_expert_block_idx( + block_idx, num_m_blocks, m_block_idx, n_block_idx); + update_peer_cta_alive(num_m_blocks); } else { - const auto m_unit_idx = block_idx / kNumL1BlockNs; - m_block_idx = m_unit_idx * kClusterSize + cute::block_rank_in_cluster(); - n_block_idx = block_idx % kNumL1BlockNs; + map_cluster_bcast_b_block_idx( + block_idx, num_m_blocks, m_block_idx, n_block_idx); } return true; } // Current expert is fully assigned, move to the next - block_idx -= num_m_units * kNumL1BlockNs; + block_idx -= num_blocks; advance_expert_idx(); } return false; @@ -150,22 +316,26 @@ struct MegaMoEScheduler { const auto wave_end_expert_idx = get_wave_expert_end_idx(); while (current_local_expert_idx < wave_end_expert_idx) { const auto 
num_m_blocks = get_current_num_m_blocks(); - const auto num_m_units = math::ceil_div(num_m_blocks, kClusterSize); - if (block_idx < num_m_units * kNumL2BlockNs) { - if constexpr (kL2NMajorSchedule) { - n_block_idx = block_idx / num_m_units; - const auto m_unit_idx = block_idx - n_block_idx * num_m_units; - m_block_idx = m_unit_idx * kClusterSize + cute::block_rank_in_cluster(); + const auto num_blocks = num_m_blocks * kNumL2BlockNs; + if (block_idx < num_blocks) { + if constexpr (kClusterSize == 1) { + map_expert_block_idx( + block_idx, num_m_blocks, m_block_idx, n_block_idx); + update_peer_cta_alive(num_m_blocks); } else { - const auto m_unit_idx = block_idx / kNumL2BlockNs; - m_block_idx = m_unit_idx * kClusterSize + cute::block_rank_in_cluster(); - n_block_idx = block_idx % kNumL2BlockNs; + if constexpr (kL2NMajorSchedule) { + map_cluster_bcast_a_block_idx( + block_idx, num_m_blocks, m_block_idx, n_block_idx); + } else { + map_cluster_bcast_b_block_idx( + block_idx, num_m_blocks, m_block_idx, n_block_idx); + } } return true; } // Current expert is fully assigned, move to the next - block_idx -= num_m_units * kNumL2BlockNs; + block_idx -= num_blocks; advance_expert_idx(); } return false; @@ -175,14 +345,14 @@ struct MegaMoEScheduler { // Core state machine: assigns the next block CUTLASS_DEVICE cute::tuple get_next_block() { while (true) { - if (current_local_expert_idx >= kNumExpertsPerRank) + if (current_local_expert_idx >= kExpertRangeEnd) break; if (next_phase == BlockPhase::Linear1) { if (fetch_next_l1_block()) { // Found a new L1 block - // Jump to next cluster-scheduled work unit - block_idx += kNumSMs / kClusterSize; + // Jump to next persistent CTA work item + block_idx += kNumSMs; return {BlockPhase::Linear1, current_local_expert_idx, m_block_idx, n_block_idx}; } else { // L1 for the current wave is complete, transition to L2 @@ -192,8 +362,8 @@ struct MegaMoEScheduler { } else { if (fetch_next_l2_block()) { // Found a new L2 block - // Jump to next 
cluster-scheduled work unit - block_idx += kNumSMs / kClusterSize; + // Jump to next persistent CTA work item + block_idx += kNumSMs; return {BlockPhase::Linear2, current_local_expert_idx, m_block_idx, n_block_idx}; } else { // Move to L1 of the next wave @@ -222,13 +392,67 @@ struct MegaMoEScheduler { __syncwarp(); } + CUTLASS_DEVICE void fetch_packed_l2_metadata_count() { + // Split K2 starts after split K1 globally completes. When K1 built the + // metadata prefix, the per-expert actual-M values are already final, so + // K2 can initialize scheduler counts without volatile polling. + #pragma unroll + for (uint32_t i = 0; i < kNumExpertsPerLane; ++ i) { + const auto expert_idx = i * 32 + ptx::get_lane_idx(); + uint32_t value = 0; + if (expert_idx < kNumExpertsPerRank) + value = *workspace.get_packed_l2_expert_m_ptr(expert_idx); + stored_num_tokens_per_expert[i] = value; + } + __syncwarp(); + } + + template + CUTLASS_DEVICE void build_packed_l2_metadata_prefix(const bool& is_leader_thread) const { + DG_STATIC_ASSERT(kPackedAlignment == 128, "Only 128-row packed L2 metadata is currently supported"); + if (blockIdx.x != 0 or !is_leader_thread) + return; + + uint32_t packed_offset = 0; + #pragma unroll + for (uint32_t expert_idx = 0; expert_idx < kNumExpertsPerRank; ++expert_idx) { + uint64_t value = 0; + do { + value = ptx::ld_volatile(workspace.get_expert_recv_count_sum_ptr(expert_idx)); + } while (static_cast(value >> 32) != kNumSMs * kNumRanks); + + const uint32_t actual_m = static_cast(value); + *workspace.get_packed_l2_expert_start_ptr(expert_idx) = packed_offset; + *workspace.get_packed_l2_expert_m_ptr(expert_idx) = actual_m; + packed_offset += math::align(actual_m, kPackedAlignment); + } + *workspace.get_packed_l2_expert_start_ptr(kNumExpertsPerRank) = packed_offset; + } + + template + CUTLASS_DEVICE void fill_packed_l2_metadata_rows(const uint32_t& thread_idx, + const uint32_t& num_threads) const { + DG_STATIC_ASSERT(kPackedAlignment == 128, "Only 128-row 
packed L2 metadata is currently supported"); + for (uint32_t expert_idx = blockIdx.x; expert_idx < kNumExpertsPerRank; expert_idx += kNumSMs) { + const uint32_t packed_start = *workspace.get_packed_l2_expert_start_ptr(expert_idx); + const uint32_t packed_end = *workspace.get_packed_l2_expert_start_ptr(expert_idx + 1); + const uint32_t actual_m = *workspace.get_packed_l2_expert_m_ptr(expert_idx); + const uint32_t packed_m = packed_end - packed_start; + + for (uint32_t row = thread_idx; row < packed_m; row += num_threads) { + *workspace.get_packed_l2_row_to_expert_ptr(packed_start + row) = + row < actual_m ? static_cast(expert_idx) : -1; + } + } + } + template CUTLASS_DEVICE void for_each_block(Func&& func) { // Wait for all expert counters to be finalized fetch_expert_recv_count(); // Initialize current expert with 0 - set_expert_idx(0); + set_expert_idx(kExpertRangeStart); // Iterate over all blocks // TODO: add swizzle within expert waves for better L2 cache utilization @@ -248,29 +472,32 @@ struct MegaMoEScheduler { // Split-kernel mode: K1 owns only dispatch + Linear1. Unlike // for_each_block(), do not burn scheduler iterations on Linear2 blocks. fetch_expert_recv_count(); - set_expert_idx(0); - while (current_local_expert_idx < kNumExpertsPerRank) { + set_expert_idx(kExpertRangeStart); + while (current_local_expert_idx < kExpertRangeEnd) { if (fetch_next_l1_block()) { - block_idx += kNumSMs / kClusterSize; + block_idx += kNumSMs; func(current_local_expert_idx, kNumL1BlockKs, m_block_idx, n_block_idx); - } else if (current_local_expert_idx >= kNumExpertsPerRank) { + } else if (current_local_expert_idx >= kExpertRangeEnd) { break; } } } - template + template CUTLASS_DEVICE void for_each_linear2_block(Func&& func) { // Split-kernel mode: K2 starts after K1 has completed globally, so all // L2-ready masks are already final. Schedule Linear2 blocks directly // from the phase start, like a standalone grouped GEMM. 
- fetch_expert_recv_count(); - set_expert_idx(0); - while (current_local_expert_idx < kNumExpertsPerRank) { + if constexpr (kUsePackedL2MetadataCounts) + fetch_packed_l2_metadata_count(); + else + fetch_expert_recv_count(); + set_expert_idx(kExpertRangeStart); + while (current_local_expert_idx < kExpertRangeEnd) { if (fetch_next_l2_block()) { - block_idx += kNumSMs / kClusterSize; + block_idx += kNumSMs; func(current_local_expert_idx, kNumL2BlockKs, m_block_idx, n_block_idx); - } else if (current_local_expert_idx >= kNumExpertsPerRank) { + } else if (current_local_expert_idx >= kExpertRangeEnd) { break; } } From 0a544b3fdf8d4c811d70a9d10dae10cca038578d Mon Sep 17 00:00:00 2001 From: Jinyan Chen Date: Sun, 31 May 2026 20:50:52 +0800 Subject: [PATCH 03/15] Optimize SM90 MegaMoE split path --- csrc/jit_kernels/heuristics/mega_moe.hpp | 244 +++++++++++++++--- csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp | 38 ++- deep_gemm/include/deep_gemm/comm/barrier.cuh | 8 +- .../deep_gemm/impls/sm90_fp8_mega_moe.cuh | 182 ++++++++----- deep_gemm/testing/bench.py | 37 ++- tests/bench_mega_moe_sm90.py | 74 +++--- tests/bench_mega_moe_sm90_skew.py | 113 ++++++-- tests/bench_mega_moe_sm90_skewed.py | 47 +++- tests/test_mega_moe_hopper.py | 152 +++++++---- 9 files changed, 661 insertions(+), 234 deletions(-) diff --git a/csrc/jit_kernels/heuristics/mega_moe.hpp b/csrc/jit_kernels/heuristics/mega_moe.hpp index 85ebf8e552..7d5f15e82a 100644 --- a/csrc/jit_kernels/heuristics/mega_moe.hpp +++ b/csrc/jit_kernels/heuristics/mega_moe.hpp @@ -246,8 +246,8 @@ static MegaMoEConfig get_mega_moe_config( // - No FP4: weights are FP8 e4m3, scales are per-128 channel float. // - No 2-CTA cluster MMA: TMA multicast cluster=2 may still be used. // - SF for activations is float (not UE8M0 int) and per-128 (not per-32). -// The kernel is in `deep_gemm/impls/sm90_fp8_mega_moe.cuh` and is currently -// a skeleton; this config is what the host runtime reads. 
+// The kernel is in `deep_gemm/impls/sm90_fp8_mega_moe.cuh`; this config is +// what the host runtime reads when instantiating a shape-specialized variant. // ============================================================================ struct MegaMoESM90Config { @@ -289,13 +289,125 @@ struct MegaMoESM90Config { } }; +static bool get_sm90_moe_split_l1_l2_default() { + return get_env("DG_SM90_MOE_SPLIT_L1_L2", 1) != 0; +} + +struct Sm90MoeHeuristicPolicy { + bool split_l1_l2; + int num_experts_per_rank, num_topk, intermediate_hidden; + int block_m, block_n; + float expected_tokens_per_expert; + + template + bool expected_is_one_of(const Values&... values) const { + return ((expected_tokens_per_expert == static_cast(values)) or ...); + } + + bool expected_is_between(const float& low, const float& high) const { + return expected_tokens_per_expert >= low and expected_tokens_per_expert <= high; + } + + bool uses_split_bn256() const { + return split_l1_l2 and block_m == 64 and block_n == 256; + } + + bool is_main_topk8() const { + return num_experts_per_rank == 32 and num_topk == 8 and intermediate_hidden == 2048; + } + + bool is_hopper_topk6() const { + return num_experts_per_rank == 48 and num_topk == 6 and intermediate_hidden == 3072; + } + + int experts_per_wave_override() const { + if (not (block_m == 64 and block_n == 256)) + return 0; + if (is_hopper_topk6() and expected_tokens_per_expert >= 8.0f and expected_tokens_per_expert <= 32.0f) + return 16; + if (is_main_topk8() and expected_tokens_per_expert == 128.0f) + return 16; + if (is_main_topk8() and expected_tokens_per_expert >= 256.0f and expected_tokens_per_expert < 512.0f) + return 16; + return 0; + } + + bool direct_l2_scatter() const { + if (not uses_split_bn256()) + return false; + if (is_main_topk8()) { + return expected_is_one_of(2, 4, 8, 16, 32, 64, 76, 80, 88, 128) or + expected_is_between(96.0f, 120.0f) or + expected_tokens_per_expert >= 144.0f; + } + if (is_hopper_topk6()) { + return 
expected_is_between(61.0f, 62.0f) or + expected_tokens_per_expert >= 64.0f; + } + return false; + } + + bool l2_nmajor_schedule(const bool& eplb_hint, const bool& skew_hint) const { + if (not uses_split_bn256() or not is_main_topk8()) + return false; + if (expected_tokens_per_expert == 256.0f and eplb_hint) + return false; + if (expected_tokens_per_expert >= 256.0f and skew_hint) + return false; + return expected_tokens_per_expert >= 256.0f; + } + + bool one_warp_cleanup(const bool& masked_hint) const { + if (not uses_split_bn256()) + return false; + if (is_main_topk8() and expected_tokens_per_expert <= 64.0f) + return true; + if (is_hopper_topk6() and masked_hint and expected_tokens_per_expert == 64.0f) + return true; + return is_hopper_topk6() and expected_is_one_of(80, 128); + } + + bool stage5_pipeline(const bool& direct_l2_scatter_enabled, + const bool& eplb_hint, + const bool& skew_hint, + const bool& masked_hint) const { + if (not direct_l2_scatter_enabled) + return false; + if (is_main_topk8()) { + const bool hinted_m64 = (eplb_hint or skew_hint or masked_hint) and expected_tokens_per_expert == 64.0f; + return expected_is_one_of(2, 4, 16, 32, 128) or + hinted_m64 or + expected_tokens_per_expert >= 192.0f; + } + if (is_hopper_topk6()) { + return expected_tokens_per_expert == 64.0f or + expected_is_between(76.0f, 96.0f) or + (expected_tokens_per_expert >= 128.0f and expected_tokens_per_expert < 240.0f) or + expected_tokens_per_expert >= 384.0f; + } + return false; + } +}; + +static Sm90MoeHeuristicPolicy get_sm90_moe_heuristic_policy( + const int& num_experts_per_rank, const int& num_tokens, const int& num_topk, + const int& intermediate_hidden, const int& block_m, const int& block_n) { + return { + get_sm90_moe_split_l1_l2_default(), + num_experts_per_rank, + num_topk, + intermediate_hidden, + block_m, + block_n, + static_cast(num_tokens) * num_topk / num_experts_per_rank + }; +} + static std::tuple get_block_config_for_mega_moe_sm90( const int& num_ranks, 
const int& num_experts, const int& num_max_tokens_per_rank, const int& num_topk, const int& num_tokens) { - // Keep mma.sync decode variants gated. M16 was correct but slower on H20; - // M32 is an experimental middle ground that cuts M64 padding without paying - // as much CTA/epilogue overhead as M16. + // Keep mma.sync decode variants opt-in; the default path uses M64 CTAs. const int num_experts_per_rank = num_experts / num_ranks; const float expected_tokens_per_expert = static_cast(num_tokens) * num_topk / num_experts_per_rank; @@ -303,28 +415,30 @@ static std::tuple get_block_config_for_mega_moe_sm90( ? get_env("DG_SM90_MOE_MMA_SYNC_M") : (get_env("DG_SM90_MOE_MMA_SYNC") != 0 ? 16 : 0); DG_HOST_ASSERT(requested_mma_m == 0 or requested_mma_m == 16 or requested_mma_m == 32); + const int forced_block_m = get_env("DG_SM90_MOE_FORCE_BLOCK_M"); + const int forced_epilogue_warpgroups = get_env("DG_SM90_MOE_FORCE_EPILOGUE_WG"); + DG_HOST_ASSERT(forced_block_m == 0 or forced_block_m == 64 or forced_block_m == 128); + DG_HOST_ASSERT(forced_epilogue_warpgroups == 0 or + forced_epilogue_warpgroups == 1 or + forced_epilogue_warpgroups == 2); const bool use_mma_sync_decode = requested_mma_m > 0 and expected_tokens_per_expert <= static_cast(requested_mma_m); const bool use_bn256_split_n = - get_env("DG_SM90_MOE_BN256_2WG") != 0 and not use_mma_sync_decode; + get_env("DG_SM90_MOE_BN256_2WG", get_sm90_moe_split_l1_l2_default() ? 
1 : 0) != 0 and + forced_block_m != 128 and not use_mma_sync_decode; const bool use_bn256_seq_n = - get_env("DG_SM90_MOE_BN256_SEQ") != 0 and not use_mma_sync_decode; + get_env("DG_SM90_MOE_BN256_SEQ") != 0 and + forced_block_m != 128 and not use_mma_sync_decode; const bool use_b_stationary_2wg = get_env("DG_SM90_MOE_B_STATIONARY_2WG") != 0 and not use_mma_sync_decode; DG_HOST_ASSERT(not (use_bn256_split_n and use_bn256_seq_n)); DG_HOST_ASSERT(not (use_b_stationary_2wg and (use_bn256_split_n or use_bn256_seq_n))); - const int forced_block_m = get_env("DG_SM90_MOE_FORCE_BLOCK_M"); - const int forced_epilogue_warpgroups = get_env("DG_SM90_MOE_FORCE_EPILOGUE_WG"); - DG_HOST_ASSERT(forced_block_m == 0 or forced_block_m == 64 or forced_block_m == 128); - DG_HOST_ASSERT(forced_epilogue_warpgroups == 0 or - forced_epilogue_warpgroups == 1 or - forced_epilogue_warpgroups == 2); const int block_m = forced_block_m > 0 ? forced_block_m : (use_b_stationary_2wg ? 128 : (use_mma_sync_decode ? requested_mma_m : 64)); const int num_epilogue_warpgroups = forced_epilogue_warpgroups > 0 ? forced_epilogue_warpgroups - : ((use_b_stationary_2wg or use_bn256_split_n) ? 2 : 1); + : ((block_m == 128 or use_b_stationary_2wg or use_bn256_split_n) ? 2 : 1); DG_HOST_ASSERT(block_m % num_epilogue_warpgroups == 0); DG_HOST_ASSERT((block_m != 16 and block_m != 32) or num_epilogue_warpgroups == 1); DG_HOST_ASSERT(block_m != 128 or num_epilogue_warpgroups == 2); @@ -345,15 +459,12 @@ static int get_num_experts_per_wave_for_mega_moe_sm90( return forced; } - // SM90 (Hopper) wave heuristic. - // - // The generic heuristic is useful in the middle of the block_m=64 band, but - // very sparse routing and large batches both do better as a single all-expert - // wave: sparse cases avoid extra L1->L2 wave transitions, while large cases - // keep enough work resident without fragmenting expert scheduling. 
- const float expected_tokens_per_expert = - static_cast(num_tokens) * num_topk / num_experts_per_rank; - if (block_m == 64 and (expected_tokens_per_expert < 1.0f or expected_tokens_per_expert > 4.0f)) { + const auto policy = get_sm90_moe_heuristic_policy( + num_experts_per_rank, num_tokens, num_topk, intermediate_hidden, block_m, block_n); + if (const int wave_override = policy.experts_per_wave_override(); wave_override > 0) + return wave_override; + if (block_m == 64 and + (policy.expected_tokens_per_expert < 1.0f or policy.expected_tokens_per_expert > 4.0f)) { return num_experts_per_rank; } return get_num_experts_per_wave_for_mega_moe( @@ -361,11 +472,40 @@ static int get_num_experts_per_wave_for_mega_moe_sm90( intermediate_hidden, block_m, block_n, num_sms); } +static bool get_sm90_moe_direct_l2_scatter_default( + const int& num_experts_per_rank, const int& num_tokens, const int& num_topk, + const int& intermediate_hidden, const int& block_m, const int& block_n) { + return get_sm90_moe_heuristic_policy( + num_experts_per_rank, num_tokens, num_topk, + intermediate_hidden, block_m, block_n).direct_l2_scatter(); +} + +static bool get_sm90_moe_l2_nmajor_schedule_default( + const int& num_experts_per_rank, const int& num_tokens, const int& num_topk, + const int& intermediate_hidden, const int& block_m, const int& block_n) { + return get_sm90_moe_heuristic_policy( + num_experts_per_rank, num_tokens, num_topk, + intermediate_hidden, block_m, block_n).l2_nmajor_schedule( + get_env("DG_SM90_MOE_EPLB_HINT", 0) != 0, + get_env("DG_SM90_MOE_SKEW_HINT", 0) != 0); +} + +static bool get_sm90_moe_one_warp_cleanup_default( + const int& num_experts_per_rank, const int& num_tokens, const int& num_topk, + const int& intermediate_hidden, const int& block_m, const int& block_n) { + return get_sm90_moe_heuristic_policy( + num_experts_per_rank, num_tokens, num_topk, + intermediate_hidden, block_m, block_n).one_warp_cleanup( + get_env("DG_SM90_MOE_MASKED_HINT", 0) != 0); +} + static 
std::pair get_pipeline_config_for_mega_moe_sm90( const int& smem_capacity, const int& num_experts, const int& hidden, const int& block_m, const int& block_n, const int& block_k, - const int& num_dispatch_warps, const int& num_epilogue_warps) { + const int& num_dispatch_warps, const int& num_epilogue_warps, + const bool& direct_l2_scatter_default = false, + const int& default_num_stages = 0) { constexpr int kSmemAlignment = 1024; // Dispatch region (same as SM100) @@ -387,9 +527,11 @@ static std::pair get_pipeline_config_for_mega_moe_sm90( const int wg_block_n = (split_n_warpgroups or serial_n_warpgroups) ? block_n / 2 : block_n; const int smem_cd_accum = (block_m == 16 or block_m == 32) ? align(block_m * block_n * static_cast(sizeof(float)), kSmemAlignment) : 0; const int smem_cd_l1 = num_epilogue_warpgroups * wg_block_m * (wg_block_n / 2); // 1 byte/elem (FP8) - const bool direct_l2_scatter = get_env("DG_SM90_MOE_DIRECT_L2_SCATTER", 0) != 0 and + const bool direct_l2_scatter = get_env( + "DG_SM90_MOE_DIRECT_L2_SCATTER", + direct_l2_scatter_default ? 1 : 0) != 0 and block_m != 16 and block_m != 32 and - not split_n_warpgroups and not serial_n_warpgroups; + not serial_n_warpgroups and wg_block_n == 128; const bool async_l1_tma_store = get_env("DG_SM90_MOE_ASYNC_L1_STORE", 0) != 0 and block_m != 16 and block_m != 32 and not split_n_warpgroups and num_epilogue_warpgroups == 1; @@ -422,11 +564,18 @@ static std::pair get_pipeline_config_for_mega_moe_sm90( // Fixed total const int smem_fixed = smem_dispatch_size + smem_cd + smem_barriers_fixed; - // Select max num_stages, with an optional SM90-only sweep override. + // Select the retained stage count for the current shape. const int max_num_stages = (smem_capacity - smem_fixed) / (smem_per_stage + smem_barriers_per_stage); + const bool split_l1_l2 = get_sm90_moe_split_l1_l2_default(); + const bool prefer_bn256_split = split_l1_l2 and block_n == 256; + const int preferred_num_stages = default_num_stages > 0 + ? 
std::min(default_num_stages, max_num_stages) + : (prefer_bn256_split ? std::min(4, max_num_stages) : 0); const int forced_num_stages = get_env("DG_SM90_MOE_NUM_STAGES"); - const int num_stages = forced_num_stages > 0 ? forced_num_stages : max_num_stages; + const int num_stages = forced_num_stages > 0 + ? forced_num_stages + : (preferred_num_stages > 0 ? preferred_num_stages : max_num_stages); DG_HOST_ASSERT(num_stages >= 2 and num_stages <= max_num_stages); return {num_stages, smem_fixed + num_stages * (smem_per_stage + smem_barriers_per_stage)}; @@ -440,19 +589,16 @@ static MegaMoESM90Config get_mega_moe_config_sm90( const auto [block_m, num_epilogue_threads] = get_block_config_for_mega_moe_sm90( num_ranks, num_experts, num_max_tokens_per_rank, num_topk, num_tokens); const bool use_bn256_split_n = - get_env("DG_SM90_MOE_BN256_2WG") != 0 and block_m != 16 and block_m != 32; + get_env("DG_SM90_MOE_BN256_2WG", get_sm90_moe_split_l1_l2_default() ? 1 : 0) != 0 and + block_m == 64; const bool use_bn256_seq_n = - get_env("DG_SM90_MOE_BN256_SEQ") != 0 and block_m != 16 and block_m != 32; + get_env("DG_SM90_MOE_BN256_SEQ") != 0 and block_m == 64; DG_HOST_ASSERT(not (use_bn256_split_n and use_bn256_seq_n)); const int block_n = (use_bn256_split_n or use_bn256_seq_n) ? 256 : 128; DG_HOST_ASSERT((not use_bn256_split_n) or num_epilogue_threads == 256); const int block_k = 128; - // Default remains cluster_size=1. The experimental cluster=2 path below - // is M-split/B-multicast for exact-balanced large-M tests: peer CTAs process - // adjacent M blocks for the same expert/N tile and share the B TMA load. - // `DG_SM90_MOE_B_STATIONARY_2WG=1` extends the same idea to M128/2WG: - // two WGs split M within each CTA, and the cluster pair reuses one B tile - // across four M64 WGMMA consumers. + // Default remains cluster_size=1; the opt-in cluster path uses B multicast + // across adjacent M blocks. 
const bool use_b_stationary_2wg = get_env("DG_SM90_MOE_B_STATIONARY_2WG") != 0; const bool use_cluster_bcast_b = get_env("DG_SM90_MOE_CLUSTER_BCAST_B") != 0 or use_b_stationary_2wg; DG_HOST_ASSERT((not use_cluster_bcast_b) or @@ -469,19 +615,39 @@ static MegaMoESM90Config get_mega_moe_config_sm90( num_experts_per_rank, num_tokens, num_topk, intermediate_hidden, block_m, block_n, num_sms); - const int forced_dispatch_warps = get_env("DG_SM90_MOE_DISPATCH_WARPS"); - const bool compact_frontend = get_env("DG_SM90_MOE_COMPACT_FRONTEND") != 0; + const bool split_sfa_tma = get_env("DG_SM90_MOE_SPLIT_SFA_TMA", 0) != 0; + const bool prefer_compact_frontend = get_sm90_moe_split_l1_l2_default() and block_n == 256 and not split_sfa_tma; + const bool compact_frontend = get_env("DG_SM90_MOE_COMPACT_FRONTEND", + prefer_compact_frontend ? 1 : 0) != 0; + const int forced_dispatch_warps = get_env("DG_SM90_MOE_DISPATCH_WARPS", + compact_frontend ? 2 : 0); DG_HOST_ASSERT(forced_dispatch_warps == 0 or forced_dispatch_warps == 2 or forced_dispatch_warps == 4 or forced_dispatch_warps == 8); const int num_dispatch_threads = (forced_dispatch_warps > 0 ? forced_dispatch_warps : 4) * 32; + DG_HOST_ASSERT((not split_sfa_tma) or (not compact_frontend)); DG_HOST_ASSERT((not compact_frontend) or num_dispatch_threads == 64); const int num_non_epilogue_threads = compact_frontend ? 64 : 128; + DG_HOST_ASSERT((num_dispatch_threads + num_non_epilogue_threads) % 128 == 0); + const auto policy = get_sm90_moe_heuristic_policy( + num_experts_per_rank, num_tokens, num_topk, + intermediate_hidden, block_m, block_n); + const bool direct_l2_scatter_default = policy.direct_l2_scatter(); + const bool direct_l2_scatter_enabled = get_env( + "DG_SM90_MOE_DIRECT_L2_SCATTER", + direct_l2_scatter_default ? 
1 : 0) != 0; + const int default_num_stages = policy.stage5_pipeline( + direct_l2_scatter_enabled, + get_env("DG_SM90_MOE_EPLB_HINT", 0) != 0, + get_env("DG_SM90_MOE_SKEW_HINT", 0) != 0, + get_env("DG_SM90_MOE_MASKED_HINT", 0) != 0) ? 5 : 0; const auto [num_stages, smem_size] = get_pipeline_config_for_mega_moe_sm90( SM90ArchSpec::smem_capacity, num_experts, hidden, block_m, block_n, block_k, - num_dispatch_threads / 32, num_epilogue_threads / 32); + num_dispatch_threads / 32, num_epilogue_threads / 32, + direct_l2_scatter_default, + default_num_stages); const auto config = MegaMoESM90Config { block_m, block_n, block_k, diff --git a/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp b/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp index 88b8b31705..6411012495 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp @@ -18,10 +18,8 @@ namespace deep_gemm { // SM90 (Hopper) FP8 MegaMoE host runtime // ---------------------------------------------------------------------------- // This is the SM90 counterpart of `SM100FP8FP4MegaMoERuntime`. The kernel -// itself lives in `deep_gemm/impls/sm90_fp8_mega_moe.cuh` and is currently a -// skeleton: dispatch/combine paths are intended to be portable from the SM100 -// version, while the GEMM (TMA load + WGMMA + epilogue) is being implemented -// in a follow-up step. +// itself lives in `deep_gemm/impls/sm90_fp8_mega_moe.cuh` and uses the same +// dispatch/combine contract with an SM90 FP8 TMA/WGMMA implementation. // // Differences from SM100 path: // * Activations and weights are both FP8 (e4m3); no FP4. 
@@ -49,6 +47,7 @@ class SM90FP8MegaMoERuntime final : public LaunchRuntime bool l1_dual_k_accum; bool l2_nmajor_schedule; bool l1_nmajor_schedule; + bool one_warp_cleanup; int split_phase_mode; MegaMoESM90Config config; @@ -91,7 +90,6 @@ static void __instantiate_kernel() {{ {}, {}, {}, - {}, {}, {}, {}, {}, {}, {}, @@ -104,6 +102,8 @@ static void __instantiate_kernel() {{ {}, {}, {}, + {}, + {}, {} >); }}; @@ -129,6 +129,7 @@ static void __instantiate_kernel() {{ args.l1_dual_k_accum ? "true" : "false", args.l2_nmajor_schedule ? "true" : "false", args.l1_nmajor_schedule ? "true" : "false", + args.one_warp_cleanup ? "true" : "false", args.split_phase_mode); } @@ -201,8 +202,8 @@ static void sm90_fp8_mega_moe( // must use no shared-memory swizzle. Later L2 TMA loads may still swizzle // from this row-major global buffer into their own SMEM tile. // The default TMA store is issued per warpgroup, each writing a WG_BLOCK_M - // row tile. The split-N experiment has two WGs produce different N halves - // of the same M rows, then one TMA store writes the full 64x128 post-SwiGLU tile. + // row tile. In split-N mode, two WGs produce different N halves of the same + // M rows, then one TMA store writes the full 64x128 post-SwiGLU tile. 
const int num_epilogue_warpgroups_h = config.num_epilogue_threads / 128; const bool split_n_warpgroups_h = config.block_m == 64 and config.block_n == 256 and num_epilogue_warpgroups_h == 2; @@ -234,6 +235,15 @@ static void sm90_fp8_mega_moe( // Launch const auto num_sms = device_runtime->get_num_sms(); + const bool direct_l2_scatter_default = get_sm90_moe_direct_l2_scatter_default( + num_experts_per_rank, num_tokens, num_topk, + intermediate_hidden, config.block_m, config.block_n); + const bool l2_nmajor_schedule_default = get_sm90_moe_l2_nmajor_schedule_default( + num_experts_per_rank, num_tokens, num_topk, + intermediate_hidden, config.block_m, config.block_n); + const bool one_warp_cleanup_default = get_sm90_moe_one_warp_cleanup_default( + num_experts_per_rank, num_tokens, num_topk, + intermediate_hidden, config.block_m, config.block_n); const SM90FP8MegaMoERuntime::Args args = { .num_max_tokens_per_rank = num_max_tokens_per_rank, .hidden = hidden, .intermediate_hidden = intermediate_hidden, @@ -243,12 +253,19 @@ static void sm90_fp8_mega_moe( .fast_math = fast_math, .async_l1_tma_store = get_env("DG_SM90_MOE_ASYNC_L1_STORE", 0) != 0, .split_sfa_tma = get_env("DG_SM90_MOE_SPLIT_SFA_TMA", 0) != 0, - .direct_l2_scatter = get_env("DG_SM90_MOE_DIRECT_L2_SCATTER", 0) != 0, + .direct_l2_scatter = get_env( + "DG_SM90_MOE_DIRECT_L2_SCATTER", + direct_l2_scatter_default ? 1 : 0) != 0, .l2_dual_accum = get_env("DG_SM90_MOE_L2_DUAL_ACCUM", 0) != 0, .phase_profile = get_env("DG_SM90_MOE_PHASE_PROFILE", 0) != 0, .l1_dual_k_accum = get_env("DG_SM90_MOE_L1_DUAL_K", 0) != 0, - .l2_nmajor_schedule = get_env("DG_SM90_MOE_L2_NMAJOR", 0) != 0, + .l2_nmajor_schedule = get_env( + "DG_SM90_MOE_L2_NMAJOR", + l2_nmajor_schedule_default ? 1 : 0) != 0, .l1_nmajor_schedule = get_env("DG_SM90_MOE_L1_NMAJOR", 0) != 0, + .one_warp_cleanup = get_env( + "DG_SM90_MOE_ONE_WARP_CLEANUP", + one_warp_cleanup_default ? 
1 : 0) != 0, .split_phase_mode = 0, .config = config, .y = y.data_ptr(), @@ -275,8 +292,7 @@ static void sm90_fp8_mega_moe( SM90FP8MegaMoERuntime::launch(runtime, split_args); }; - const bool split_l1_l2 = get_env( - "DG_SM90_MOE_SPLIT_L1_L2", num_max_tokens_per_rank >= 1024 ? 1 : 0) != 0; + const bool split_l1_l2 = get_env("DG_SM90_MOE_SPLIT_L1_L2", 1) != 0; if (split_l1_l2) { launch_with_split_mode(1, "sm90_fp8_mega_moe_split_l1"); launch_with_split_mode(2, "sm90_fp8_mega_moe_split_l2"); diff --git a/deep_gemm/include/deep_gemm/comm/barrier.cuh b/deep_gemm/include/deep_gemm/comm/barrier.cuh index eb9858d801..5f388282c7 100644 --- a/deep_gemm/include/deep_gemm/comm/barrier.cuh +++ b/deep_gemm/include/deep_gemm/comm/barrier.cuh @@ -6,6 +6,10 @@ #include #include +#ifndef DG_NVLINK_BARRIER_TIMEOUT_PRINTF +#define DG_NVLINK_BARRIER_TIMEOUT_PRINTF 0 +#endif + namespace deep_gemm::comm { CUTLASS_DEVICE void cluster_sync_with_relaxed_arrive() { @@ -67,9 +71,11 @@ CUTLASS_DEVICE void nvlink_barrier(const layout::Workspace& workspace, const auto start_clock = clock64(); while (ptx::ld_acq_sys(signal_ptr) != target) { if (clock64() - start_clock >= kNumTimeoutCycles) { +#if DG_NVLINK_BARRIER_TIMEOUT_PRINTF printf("DeepGEMM NVLink barrier timeout (30s): rank=%d, counter=%d, signal=%d, target=%d, phase=%d, sign=%d, tag=%d\n", sym_buffer.rank_idx, *counter_ptr, ptx::ld_acq_sys(signal_ptr), target, signal_phase, signal_sign, kTag); - DG_DEVICE_ASSERT(false and "NVLink barrier timeout"); +#endif + DG_TRAP_ONLY_DEVICE_ASSERT(false and "NVLink barrier timeout"); } } } diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh index 248ffc810f..c98c326700 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh @@ -73,6 +73,7 @@ template < bool kL1DualKAccumRequested = false, bool kL2NMajorScheduleRequested = false, bool 
kL1NMajorScheduleRequested = false, + bool kOneWarpCleanupRequested = false, uint32_t kSplitPhaseMode = 0, uint32_t L1_SHAPE_N = kIntermediateHidden * 2, uint32_t L1_SHAPE_K = kHidden, @@ -128,15 +129,19 @@ sm90_fp8_mega_moe_impl(void* y, const uint32_t warp_idx = cutlass::canonical_warp_idx_sync(); const uint32_t lane_idx = ptx::get_lane_idx(); - // Prefetch all TMA descriptors at the very beginning + // Prefetch the TMA descriptors used by this split phase. if (warp_idx == 0 and cute::elect_one_sync()) { - cute::prefetch_tma_descriptor(&tensor_map_l1_acts); - cute::prefetch_tma_descriptor(&tensor_map_l1_acts_sf); - cute::prefetch_tma_descriptor(&tensor_map_l1_weights); - cute::prefetch_tma_descriptor(&tensor_map_l1_output); - cute::prefetch_tma_descriptor(&tensor_map_l2_acts); - cute::prefetch_tma_descriptor(&tensor_map_l2_acts_sf); - cute::prefetch_tma_descriptor(&tensor_map_l2_weights); + if constexpr (kSplitPhaseMode != 2) { + cute::prefetch_tma_descriptor(&tensor_map_l1_acts); + cute::prefetch_tma_descriptor(&tensor_map_l1_acts_sf); + cute::prefetch_tma_descriptor(&tensor_map_l1_weights); + cute::prefetch_tma_descriptor(&tensor_map_l1_output); + } + if constexpr (kSplitPhaseMode != 1) { + cute::prefetch_tma_descriptor(&tensor_map_l2_acts); + cute::prefetch_tma_descriptor(&tensor_map_l2_acts_sf); + cute::prefetch_tma_descriptor(&tensor_map_l2_weights); + } } // ===================================================================== @@ -193,11 +198,11 @@ sm90_fp8_mega_moe_impl(void* y, constexpr bool kRunOnlyLinear1 = kSplitPhaseMode == 1; constexpr bool kRunOnlyLinear2 = kSplitPhaseMode == 2; constexpr bool kAsyncL1TMAStore = - kAsyncL1TMAStoreRequested && (!kUseMMASync) && (!kSplitNWarpgroups) && - kNumEpilogueWarpgroups == 1; + kAsyncL1TMAStoreRequested && (!kUseMMASync) && + (!kSplitNWarpgroups) && kNumEpilogueWarpgroups == 1; constexpr bool kSplitSFATMA = kSplitSFATMARequested && (!kUseMMASync); constexpr bool kDirectL2Scatter = kDirectL2ScatterRequested && 
(!kUseMMASync) && - (!kSplitNWarpgroups) && (!kSerialNWarpgroups); + (!kSerialNWarpgroups) && WG_BLOCK_N == 128; constexpr bool kL2DualAccum = kL2DualAccumRequested && (!kUseMMASync) && (!kSplitNWarpgroups) && (!kSerialNWarpgroups) && WG_BLOCK_N == 128; constexpr bool kL1DualKAccum = kL1DualKAccumRequested && (!kUseMMASync) && @@ -252,8 +257,10 @@ sm90_fp8_mega_moe_impl(void* y, kNumEpilogueWarpgroups * WG_BLOCK_M * WG_L1_OUT_BLOCK_N * sizeof(cutlass::float_e4m3_t); constexpr uint32_t SMEM_CD_L2_SIZE = kDirectL2Scatter ? 0u : kNumEpilogueWarpgroups * WG_BLOCK_M * WG_BLOCK_N * sizeof(nv_bfloat16); + constexpr uint32_t SMEM_CD_L1_ASYNC_ELEMS = + kNumEpilogueWarpgroups * WG_BLOCK_M * L1_OUT_BLOCK_N; constexpr uint32_t SMEM_CD_L1_ASYNC_SIZE = kAsyncL1TMAStore ? - 2 * kNumEpilogueWarpgroups * WG_BLOCK_M * L1_OUT_BLOCK_N * sizeof(cutlass::float_e4m3_t) : 0u; + 2 * SMEM_CD_L1_ASYNC_ELEMS * sizeof(cutlass::float_e4m3_t) : 0u; constexpr uint32_t SMEM_CD_OUTPUT_BASE_SIZE = SMEM_CD_L1_SIZE > SMEM_CD_L2_SIZE ? SMEM_CD_L1_SIZE : SMEM_CD_L2_SIZE; constexpr uint32_t SMEM_CD_OUTPUT_UNALIGNED_SIZE = @@ -323,15 +330,17 @@ sm90_fp8_mega_moe_impl(void* y, #pragma unroll for (uint32_t i = 0; i < kNumStages; ++ i) { // Producer arrivals: A(+SFA) + B, or A + B + SFA when - // the split-SFA experiment uses an otherwise idle TMA warp. + // split-SFA uses an otherwise idle TMA warp. full_barriers[i]->init(kSplitSFATMA ? 3 : 2); // With cluster multicast the leader CTA's TMA warp waits on peer // empty barriers too, so every math warp releases both CTAs. 
empty_barriers[i]->init(kClusterSize * kNumEpilogueWarps); } - #pragma unroll - for (uint32_t i = 0; i < kNumEpilogueWarps * 2; ++ i) - combine_barriers[i]->init(1); + if constexpr (!kRunOnlyLinear1) { + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueWarps * 2; ++ i) + combine_barriers[i]->init(1); + } } cutlass::arch::fence_barrier_init(); } @@ -393,6 +402,7 @@ sm90_fp8_mega_moe_impl(void* y, constexpr uint32_t kProfileGemmCore = 5; constexpr uint32_t kProfileL1Epilogue = 6; constexpr uint32_t kProfileL2Epilogue = 7; + constexpr uint32_t kNumProfileMetrics = 8; const auto phase_profile_clock = [&]() -> unsigned long long { if constexpr (kPhaseProfileRequested) { unsigned long long t; @@ -408,8 +418,8 @@ sm90_fp8_mega_moe_impl(void* y, auto profile = reinterpret_cast( cumulative_local_expert_recv_stats + kNumExpertsPerRank); atomicAdd(profile + metric, cycles); - atomicMax(profile + 8 + metric, cycles); - atomicAdd(profile + 16 + metric, 1ull); + atomicMax(profile + kNumProfileMetrics + metric, cycles); + atomicAdd(profile + 2 * kNumProfileMetrics + metric, 1ull); } } }; @@ -456,26 +466,59 @@ sm90_fp8_mega_moe_impl(void* y, const auto num_recv_m_blocks = math::ceil_div(num_recv_tokens, BLOCK_M); const auto cleanup_pool_block_offset = scheduler.get_pool_block_offset(i); - ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + if constexpr (!kOneWarpCleanupRequested) + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); DG_STATIC_ASSERT(kNumDispatchWarps >= 2, "Not enough dispatch warps"); - if (warp_idx == 0) { - *workspace.get_expert_recv_count_sum_ptr(i) = 0; - } else if (warp_idx == 1) { - if (cute::elect_one_sync() and cumulative_local_expert_recv_stats != nullptr) - ptx::red_add(cumulative_local_expert_recv_stats + i, static_cast(num_recv_tokens)); - __syncwarp(); + if constexpr (kOneWarpCleanupRequested) { + if (warp_idx == 0) { + if (lane_idx == 0) { + *workspace.get_expert_recv_count_sum_ptr(i) = 0; + if 
(cumulative_local_expert_recv_stats != nullptr) + ptx::red_add(cumulative_local_expert_recv_stats + i, static_cast(num_recv_tokens)); + } + } + } else { + if (warp_idx == 0) { + *workspace.get_expert_recv_count_sum_ptr(i) = 0; + } else if (warp_idx == 1) { + if (cute::elect_one_sync() and cumulative_local_expert_recv_stats != nullptr) + ptx::red_add(cumulative_local_expert_recv_stats + i, static_cast(num_recv_tokens)); + __syncwarp(); + } } - for (uint32_t j = thread_idx; j < kNumRanks; j += kNumDispatchThreads) - *workspace.get_expert_recv_count_ptr(j, i) = 0; - __syncwarp(); + if constexpr (!kRunOnlyLinear2) { + if constexpr (kOneWarpCleanupRequested) { + if (warp_idx == 0) { + for (uint32_t j = lane_idx; j < kNumRanks; j += 32) + *workspace.get_expert_recv_count_ptr(j, i) = 0; + __syncwarp(); + } + } else { + for (uint32_t j = thread_idx; j < kNumRanks; j += kNumDispatchThreads) + *workspace.get_expert_recv_count_ptr(j, i) = 0; + __syncwarp(); + } + } - for (uint32_t j = thread_idx; j < num_recv_m_blocks; j += kNumDispatchThreads) { - *workspace.get_l1_arrival_count_ptr(cleanup_pool_block_offset + j) = 0; - *workspace.get_l2_arrival_mask_ptr(cleanup_pool_block_offset + j) = 0; + if constexpr (kOneWarpCleanupRequested) { + if (warp_idx == 0) { + for (uint32_t j = lane_idx; j < num_recv_m_blocks; j += 32) { + *workspace.get_l1_arrival_count_ptr(cleanup_pool_block_offset + j) = 0; + if constexpr (!kRunOnlyLinear2) + *workspace.get_l2_arrival_mask_ptr(cleanup_pool_block_offset + j) = 0; + } + __syncwarp(); + } + } else { + for (uint32_t j = thread_idx; j < num_recv_m_blocks; j += kNumDispatchThreads) { + *workspace.get_l1_arrival_count_ptr(cleanup_pool_block_offset + j) = 0; + if constexpr (!kRunOnlyLinear2) + *workspace.get_l2_arrival_mask_ptr(cleanup_pool_block_offset + j) = 0; + } + __syncwarp(); } - __syncwarp(); } } }; @@ -724,10 +767,10 @@ sm90_fp8_mega_moe_impl(void* y, phase_profile_record(kProfileDispatchPull, dispatch_pull_end - dispatch_pull_start); 
phase_profile_record(kProfileDispatchTotal, dispatch_pull_end - dispatch_total_start); } - ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); if constexpr (kRunOnlyLinear1) return; + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); cleanup_workspace(); comm::nvlink_barrier(); @@ -957,12 +1000,18 @@ sm90_fp8_mega_moe_impl(void* y, const auto notify_l1_ready = [&](const uint32_t& ready_pool_block_idx, const uint32_t& ready_n_block_idx) { - if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { - ptx::red_or_rel_gpu( - workspace.get_l2_arrival_mask_ptr(ready_pool_block_idx), - 1ull << ready_n_block_idx); + // True split L2 runs in a later kernel launch and skips arrival-mask waits. + if constexpr (!kRunOnlyLinear1) { + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + ptx::red_or_rel_gpu( + workspace.get_l2_arrival_mask_ptr(ready_pool_block_idx), + 1ull << ready_n_block_idx); + } + __syncwarp(); + } else { + (void)ready_pool_block_idx; + (void)ready_n_block_idx; } - __syncwarp(); }; const auto drain_async_l1_store_stage = [&](const uint32_t& store_stage) { @@ -1631,13 +1680,8 @@ for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_bl // L2 weight SF shape: (E, H/128, IH/128) MN-major. One scalar per // (BLOCK_N, BLOCK_K) tile, broadcast across all WGMMA accumulators. // - // NOTE: we tried hoisting these LDGs above the barrier wait and/or - // having only lane 0 load + shfl-broadcast. Both regressed on H20 - // by 7-11% across all batch sizes, presumably because (a) Hopper's - // L1 read-only cache already coalesces same-address LDGs from all - // 128 WG threads and (b) hoisting contended with the dispatch - // warps' NVLink LDGs on the MIO unit. Keep the simple parallel - // post-wait load. + // Keep the simple parallel post-wait load; same-address LDGs + // are handled efficiently by Hopper's read-only cache. 
constexpr uint32_t kL1SFKBlocks = kHidden / 128; constexpr uint32_t kL2SFKBlocks = kIntermediateHidden / 128; constexpr uint32_t kL1SFGateBlks = kIntermediateHidden / 128; @@ -1980,7 +2024,7 @@ for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_bl float u_r1_c0 = final_accum[up*4 + 2]; clamp_up(u_r1_c0); float u_r1_c1 = final_accum[up*4 + 3]; clamp_up(u_r1_c1); - auto silu = [](float x) -> float { + auto silu = [](float x) { const float e = kFastMath ? __expf(-x) : expf(-x); const float sig = kFastMath ? math::fast_rcp(1.0f + e) : 1.0f / (1.0f + e); return x * sig; @@ -2064,7 +2108,7 @@ for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_bl if constexpr (kAsyncL1TMAStore) drain_async_l1_store_stage(l1_store_stage); auto* smem_cd_l1_wg = smem_cd_l1 - + l1_store_stage * kNumEpilogueWarpgroups * WG_BLOCK_M * L1_OUT_BLOCK_N + + l1_store_stage * SMEM_CD_L1_ASYNC_ELEMS + (kSplitNWarpgroups ? 0 : epilogue_wg_idx * WG_BLOCK_M * L1_OUT_BLOCK_N); #pragma unroll for (uint32_t p = 0; p < kNumPairs; ++ p) { @@ -2125,9 +2169,19 @@ for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_bl cute::tma_store_arrive(); } __syncwarp(); - ptx::tma_store_wait<0>(); - ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); - notify_l1_ready(pool_block_idx, n_block_idx); + if constexpr (kAsyncL1TMAStore) { + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + async_l1_store_pending[l1_store_stage] = true; + async_l1_store_pool[l1_store_stage] = pool_block_idx; + async_l1_store_n[l1_store_stage] = n_block_idx; + async_l1_store_stage ^= 1u; + } else { + ptx::tma_store_wait<0>(); + if constexpr (!kRunOnlyLinear1) { + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + notify_l1_ready(pool_block_idx, n_block_idx); + } + } } else { ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); if (warp_idx_in_wg == 0 and cute::elect_one_sync()) { @@ -2149,8 +2203,10 @@ for 
(uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_bl async_l1_store_stage ^= 1u; } else { ptx::tma_store_wait<0>(); - ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); - notify_l1_ready(pool_block_idx, n_block_idx); + if constexpr (!kRunOnlyLinear1) { + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + notify_l1_ready(pool_block_idx, n_block_idx); + } } } const unsigned long long block_epilogue_end = phase_profile_clock(); @@ -2163,19 +2219,22 @@ for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_bl if constexpr (kDirectL2Scatter) { DG_STATIC_ASSERT(WG_BLOCK_N == 128, "Direct L2 scatter prototype only supports N128"); - auto scatter_direct_row = [&](const uint32_t& row_offset, const bool& valid_row, const uint32_t& row_accum_offset) { + auto scatter_direct_row = [&](const uint32_t& row_offset, const bool& valid_row, + const uint32_t& row_accum_offset) { if (valid_row) { uint32_t dst_rank_idx = 0, dst_token_idx = 0, dst_topk_idx = 0; + const uint32_t row_group_base = lane_idx - col_idx; if (col_idx == 0) { const auto src_metadata = *workspace.get_token_src_metadata_ptr(m_idx + row_offset); dst_rank_idx = src_metadata.rank_idx; dst_token_idx = src_metadata.token_idx; dst_topk_idx = src_metadata.topk_idx; } - const int src_lane = static_cast(lane_idx - col_idx); - dst_rank_idx = __shfl_sync(0xffffffff, dst_rank_idx, src_lane); - dst_token_idx = __shfl_sync(0xffffffff, dst_token_idx, src_lane); - dst_topk_idx = __shfl_sync(0xffffffff, dst_topk_idx, src_lane); + const uint32_t row_group_mask = 0xfu << row_group_base; + const int src_lane = static_cast(row_group_base); + dst_rank_idx = __shfl_sync(row_group_mask, dst_rank_idx, src_lane); + dst_token_idx = __shfl_sync(row_group_mask, dst_token_idx, src_lane); + dst_topk_idx = __shfl_sync(row_group_mask, dst_topk_idx, src_lane); const auto dst_token = combine_token_buffer.get_rank_buffer(dst_topk_idx) .get_data_buffer(dst_token_idx); auto 
dst_base = math::advance_ptr( @@ -2201,7 +2260,6 @@ for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_bl scatter_direct_row(row_offset_r0, valid_r0, 0); scatter_direct_row(row_offset_r1, valid_r1, 2); - ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); } else { // STSM into smem_cd_l2 (BF16). Reuse SM100 column-swizzle layout. #pragma unroll @@ -2310,7 +2368,6 @@ for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_bl if constexpr (kRunOnlyLinear1) { if constexpr (kAsyncL1TMAStore) drain_all_async_l1_stores(); - ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); return; } @@ -2350,9 +2407,6 @@ for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_bl DG_STATIC_ASSERT(kNumChunkUint4 % 32 == 0, "Combine chunk must be a multiple of 32 16-byte elements"); DG_STATIC_ASSERT(kNumTopk <= 32, "Top-k must fit in a single warp"); - DG_DEVICE_ASSERT(kNumChunkSlots * kNumEpilogueWarps * kNumChunkBytes <= static_cast( - reinterpret_cast(barrier_start_ptr) - smem_buffer)); - const auto combine_load_buffer = utils::PatternVisitor([&](const uint32_t& i) { return math::advance_ptr(smem_buffer, (epilogue_warp_idx + i * kNumEpilogueWarps) * kNumChunkBytes); }); diff --git a/deep_gemm/testing/bench.py b/deep_gemm/testing/bench.py index 915be3533c..0e75bbabe4 100644 --- a/deep_gemm/testing/bench.py +++ b/deep_gemm/testing/bench.py @@ -1,6 +1,7 @@ import os import sys import torch +import torch.distributed as dist from typing import Callable, Optional @@ -89,10 +90,13 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, if int(os.environ.get('DG_USE_NVIDIA_TOOLS', 0)): return (1, ) * len(kernel_names) if is_tuple else 1 - # By default, flush L2 with an excessive 8 GB memset to give the GPU some (literal) chill time without full idle - flush_l2_size = int(8e9 // 4) + # Flush L2 between timed calls. 
Large MoE benchmark runs can use shared nodes + # with limited free memory, so allow lowering only the flush buffer while + # preserving the default behavior. + flush_l2_bytes = int(os.environ.get('DG_BENCH_FLUSH_L2_BYTES', str(int(8e9)))) + flush_l2_size = max(0, flush_l2_bytes // 4) - # For some auto-tuning kernels with prints + # Warm up once before profiling. fn() # Profile @@ -104,7 +108,7 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, with profiler: for i in range(2): for _ in range(num_tests): - if flush_l2: + if flush_l2 and flush_l2_size > 0: torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_() if barrier is not None: # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead @@ -116,11 +120,34 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, profiler.step() # Parse the profiling table - prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') + report_split_kernels = int(os.environ.get('DG_SM90_MOE_REPORT_SPLIT_KERNELS', 0)) != 0 + max_name_column_width = int(os.environ.get('DG_BENCH_MAX_NAME_COLUMN_WIDTH', 100)) + if report_split_kernels and with_multiple_kernels: + max_name_column_width = max(max_name_column_width, 512) + prof_lines = profiler.key_averages().table( + sort_by='cuda_time_total', + max_name_column_width=max_name_column_width).split('\n') kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names if not with_multiple_kernels: for name in kernel_names: assert sum([name in line for line in prof_lines]) <= 1, f'Errors of the kernel {name} in the profiling table {prof_lines}' + elif report_split_kernels: + rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0 + for line in prof_lines: + if any(name in line for name in kernel_names): + phase = '?' 
+ if 'sm90_fp8_mega_moe_impl<' in line: + template_args = line.split('sm90_fp8_mega_moe_impl<', 1)[1].split('>(void', 1)[0] + args = [arg.strip() for arg in template_args.split(',')] + if len(args) > 29: + phase_arg = args[29] + phase = {'1u': 'l1', '2u': 'l2', '0u': 'one'}.get(phase_arg, phase_arg) + fields = line.split() + cuda_time = fields[-2] if len(fields) >= 2 else 'unknown' + count = fields[-1] if fields else 'unknown' + print(f' > split_kernel rank={rank} phase={phase} cuda_time={cuda_time} count={count}') + if int(os.environ.get('DG_SM90_MOE_REPORT_SPLIT_KERNELS_RAW', 0)) != 0: + print(f' > split_kernel_raw rank={rank}: {" ".join(fields)}') # Save chrome traces if trace_path is not None: diff --git a/tests/bench_mega_moe_sm90.py b/tests/bench_mega_moe_sm90.py index 3711dd48c2..ae104f4fd4 100644 --- a/tests/bench_mega_moe_sm90.py +++ b/tests/bench_mega_moe_sm90.py @@ -4,7 +4,7 @@ ``--local-rank-idx`` interface so the same ``scripts/run_ncu_mega_moe.sh`` pattern can drive it for SM90. -In normal (non-NCU) mode it sweeps a list of ``num_tokens`` values (default: +In normal (non-NCU) mode it runs a list of ``num_tokens`` values (default: 1, 2, 4, 8, 16, 32) and reports per-call kernel time via the same ``bench_kineto`` helper used by the SM100 perf test, plus a rough TFLOPS / HBM GB/s figure useful for tracking optimisation deltas. 
@@ -31,11 +31,17 @@ def _quantize_grouped_fp8_block_128_128(w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: g, n, k = w.shape assert n % 128 == 0 and k % 128 == 0 - w_view = w.view(g, n // 128, 128, k // 128, 128).float() - amax = w_view.abs().amax(dim=(-1, -3)).clamp(1e-4) - sf = amax / 448.0 - w_fp8 = (w_view / sf.unsqueeze(-1).unsqueeze(-3)).to(torch.float8_e4m3fn) - return w_fp8.view(g, n, k).contiguous(), sf.contiguous() + chunk_g = 4 + w_fp8 = torch.empty_like(w, dtype=torch.float8_e4m3fn) + sf = torch.empty((g, n // 128, k // 128), dtype=torch.float, device=w.device) + for start in range(0, g, chunk_g): + end = min(start + chunk_g, g) + w_view = w[start:end].view(end - start, n // 128, 128, k // 128, 128).float() + sf_chunk = w_view.abs().amax(dim=(-1, -3)).clamp(1e-4) / 448.0 + w_fp8[start:end].copy_( + (w_view / sf_chunk.unsqueeze(-1).unsqueeze(-3)).to(torch.float8_e4m3fn).view(end - start, n, k)) + sf[start:end].copy_(sf_chunk) + return w_fp8, sf.contiguous() def _run_one_config(args, num_tokens, num_max_tokens_per_rank, @@ -79,6 +85,7 @@ def _run_one_config(args, num_tokens, num_max_tokens_per_rank, phase_profile_enabled = os.environ.get('DG_SM90_MOE_PHASE_PROFILE', '0') != '0' phase_profile_ints = 64 if phase_profile_enabled else 0 cum_stats = torch.zeros(num_experts_per_rank + phase_profile_ints, dtype=torch.int, device='cuda') + use_masked_hint = args.masked_ratio > 0 # Stage inputs once; bench-loop re-copies them each call (bench helper expects # an idempotent ``fn``). 
@@ -88,14 +95,24 @@ def run_fused(): buffer.topk_idx[:num_tokens].copy_(topk_idx) buffer.topk_weights[:num_tokens].copy_(topk_w) y = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') - deep_gemm.fp8_mega_moe( - y, transformed_l1, transformed_l2, buffer, - cumulative_local_expert_recv_stats=cum_stats, - recipe=(128, 128, 128), - activation='swiglu', - activation_clamp=activation_clamp, - fast_math=fast_math, - ) + old_masked_hint = os.environ.get('DG_SM90_MOE_MASKED_HINT') + if use_masked_hint: + os.environ['DG_SM90_MOE_MASKED_HINT'] = '1' + try: + deep_gemm.fp8_mega_moe( + y, transformed_l1, transformed_l2, buffer, + cumulative_local_expert_recv_stats=cum_stats, + recipe=(128, 128, 128), + activation='swiglu', + activation_clamp=activation_clamp, + fast_math=fast_math, + ) + finally: + if use_masked_hint: + if old_masked_hint is None: + os.environ.pop('DG_SM90_MOE_MASKED_HINT', None) + else: + os.environ['DG_SM90_MOE_MASKED_HINT'] = old_masked_hint return y if args.ncu_profile_only: @@ -114,24 +131,12 @@ def run_fused(): cum_stats.zero_() torch.cuda.synchronize() dist.barrier() - # NSYS MULTI-ITER (aichenf): N timed iters with barrier+sleep between them. - # bench_kineto returns 1 under DG_USE_NVIDIA_TOOLS=1, but this loop puts - # multiple mega_moe instances on the nsys timeline so we can measure variance. 
- import os as _os - _nsys_iters = int(_os.environ.get('NSYS_ITERS', '0')) - if _nsys_iters > 0: - for _it in range(_nsys_iters): - torch.cuda.synchronize() - dist.barrier() - torch.cuda._sleep(int(2e7)) # 10ms gap between iters - dist.barrier() - run_fused() - torch.cuda.synchronize() - dist.barrier() t_fused = bench_kineto(run_fused, 'sm90_fp8_mega_moe', barrier=lambda: dist.barrier(), num_tests=args.num_tests, - suppress_kineto_output=True) + suppress_kineto_output=True, + with_multiple_kernels=os.environ.get( + 'DG_SM90_MOE_SPLIT_L1_L2', '1') != '0') # Count tokens that landed on this rank for stats gathered_topk_idx = uneven_all_gather(topk_idx, group=group) @@ -161,13 +166,18 @@ def run_fused(): ) if phase_profile_enabled: torch.cuda.synchronize() - profile = cum_stats[num_experts_per_rank:num_experts_per_rank + 48].view(torch.int64).cpu().tolist() names = [ 'dispatch_total', 'dispatch_pull', 'math_loop', 'combine_barrier', 'combine_reduce', 'gemm_core', 'l1_epilogue', 'l2_epilogue', ] + num_profile_metrics = len(names) + profile = cum_stats[ + num_experts_per_rank:num_experts_per_rank + phase_profile_ints + ].view(torch.int64).cpu().tolist() for i, name in enumerate(names): - total, max_v, count = profile[i], profile[8 + i], profile[16 + i] + total = profile[i] + max_v = profile[num_profile_metrics + i] + count = profile[2 * num_profile_metrics + i] avg = float(total) / count if count else 0.0 dist_print( f' phase {name:16s} avg={avg:10.0f} max={max_v:10d} count={count}', @@ -231,7 +241,7 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace): parser.add_argument('--local-rank-idx', type=int, default=None) parser.add_argument('--batches', type=int, nargs='+', default=None, - help='List of num_tokens to sweep (default: 1 2 4 8 16 32)') + help='List of num_tokens to benchmark (default: 1 2 4 8 16 32)') parser.add_argument('--hidden', type=int, default=7168) parser.add_argument('--intermediate-hidden', type=int, default=2048) 
parser.add_argument('--num-experts', type=int, default=256) diff --git a/tests/bench_mega_moe_sm90_skew.py b/tests/bench_mega_moe_sm90_skew.py index 4b6855d7ba..1d8b40a728 100644 --- a/tests/bench_mega_moe_sm90_skew.py +++ b/tests/bench_mega_moe_sm90_skew.py @@ -34,11 +34,17 @@ def _quantize_grouped_fp8_block_128_128(w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: g, n, k = w.shape assert n % 128 == 0 and k % 128 == 0 - w_view = w.view(g, n // 128, 128, k // 128, 128).float() - amax = w_view.abs().amax(dim=(-1, -3)).clamp(1e-4) - sf = amax / 448.0 - w_fp8 = (w_view / sf.unsqueeze(-1).unsqueeze(-3)).to(torch.float8_e4m3fn) - return w_fp8.view(g, n, k).contiguous(), sf.contiguous() + chunk_g = 4 + w_fp8 = torch.empty_like(w, dtype=torch.float8_e4m3fn) + sf = torch.empty((g, n // 128, k // 128), dtype=torch.float, device=w.device) + for start in range(0, g, chunk_g): + end = min(start + chunk_g, g) + w_view = w[start:end].view(end - start, n // 128, 128, k // 128, 128).float() + sf_chunk = w_view.abs().amax(dim=(-1, -3)).clamp(1e-4) / 448.0 + w_fp8[start:end].copy_( + (w_view / sf_chunk.unsqueeze(-1).unsqueeze(-3)).to(torch.float8_e4m3fn).view(end - start, n, k)) + sf[start:end].copy_(sf_chunk) + return w_fp8, sf.contiguous() def _make_global_bias(num_experts: int, alpha: float, gain: float, seed: int): @@ -153,7 +159,12 @@ def _run_one_config(args, num_tokens, num_max_tokens_per_rank, (l1_w_fp8, l1_w_sf), (l2_w_fp8, l2_w_sf), ) - cum_stats = torch.zeros(num_experts_per_rank, dtype=torch.int, device='cuda') + phase_profile_enabled = os.environ.get('DG_SM90_MOE_PHASE_PROFILE', '0') != '0' + phase_profile_ints = 64 if phase_profile_enabled else 0 + cum_stats = torch.zeros(num_experts_per_rank + phase_profile_ints, dtype=torch.int, device='cuda') + use_eplb_hint = bool(replica_for) + use_skew_hint = global_bias is not None + use_masked_hint = args.masked_ratio > 0 def run_fused(): buffer.x[:num_tokens].copy_(x_fp8) @@ -161,22 +172,54 @@ def run_fused(): 
buffer.topk_idx[:num_tokens].copy_(topk_idx) buffer.topk_weights[:num_tokens].copy_(topk_w) y = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') - deep_gemm.fp8_mega_moe( - y, transformed_l1, transformed_l2, buffer, - cumulative_local_expert_recv_stats=cum_stats, - recipe=(128, 128, 128), - activation='swiglu', - activation_clamp=activation_clamp, - fast_math=fast_math, - ) + old_eplb_hint = os.environ.get('DG_SM90_MOE_EPLB_HINT') + old_skew_hint = os.environ.get('DG_SM90_MOE_SKEW_HINT') + old_masked_hint = os.environ.get('DG_SM90_MOE_MASKED_HINT') + if use_eplb_hint: + os.environ['DG_SM90_MOE_EPLB_HINT'] = '1' + if use_skew_hint: + os.environ['DG_SM90_MOE_SKEW_HINT'] = '1' + if use_masked_hint: + os.environ['DG_SM90_MOE_MASKED_HINT'] = '1' + try: + deep_gemm.fp8_mega_moe( + y, transformed_l1, transformed_l2, buffer, + cumulative_local_expert_recv_stats=cum_stats, + recipe=(128, 128, 128), + activation='swiglu', + activation_clamp=activation_clamp, + fast_math=fast_math, + ) + finally: + if use_eplb_hint: + if old_eplb_hint is None: + os.environ.pop('DG_SM90_MOE_EPLB_HINT', None) + else: + os.environ['DG_SM90_MOE_EPLB_HINT'] = old_eplb_hint + if use_skew_hint: + if old_skew_hint is None: + os.environ.pop('DG_SM90_MOE_SKEW_HINT', None) + else: + os.environ['DG_SM90_MOE_SKEW_HINT'] = old_skew_hint + if use_masked_hint: + if old_masked_hint is None: + os.environ.pop('DG_SM90_MOE_MASKED_HINT', None) + else: + os.environ['DG_SM90_MOE_MASKED_HINT'] = old_masked_hint return y run_fused() dist.barrier() + if phase_profile_enabled: + cum_stats.zero_() + torch.cuda.synchronize() + dist.barrier() t_fused = bench_kineto(run_fused, 'sm90_fp8_mega_moe', barrier=lambda: dist.barrier(), num_tests=args.num_tests, - suppress_kineto_output=True) + suppress_kineto_output=True, + with_multiple_kernels=os.environ.get( + 'DG_SM90_MOE_SPLIT_L1_L2', '1') != '0') # Per-rank token receive counts + per-local-expert distribution gathered_topk_idx = 
uneven_all_gather(topk_idx, group=group) @@ -205,6 +248,28 @@ def run_fused(): gather_buf = [torch.zeros_like(info) for _ in range(num_ranks)] dist.all_gather(gather_buf, info, group=group) + phase_gather_buf = None + phase_names = [ + 'dispatch_total', 'dispatch_pull', 'math_loop', 'combine_barrier', + 'combine_reduce', 'gemm_core', 'l1_epilogue', 'l2_epilogue', + ] + if phase_profile_enabled: + torch.cuda.synchronize() + num_profile_metrics = len(phase_names) + profile = cum_stats[ + num_experts_per_rank:num_experts_per_rank + phase_profile_ints + ].view(torch.int64) + phase_values = [] + for i in range(num_profile_metrics): + total = float(profile[i].item()) + max_v = float(profile[num_profile_metrics + i].item()) + count = float(profile[2 * num_profile_metrics + i].item()) + avg = total / count if count else 0.0 + phase_values.extend([avg, max_v, count]) + phase_info = torch.tensor(phase_values, device='cuda', dtype=torch.float64) + phase_gather_buf = [torch.zeros_like(phase_info) for _ in range(num_ranks)] + dist.all_gather(phase_gather_buf, phase_info, group=group) + if rank_idx == 0 and print_perf: all_t = [g[0].item() for g in gather_buf] all_recv = [int(g[1].item()) for g in gather_buf] @@ -230,6 +295,22 @@ def run_fused(): f't[mean/max/min]us={t_mean*1e6:7.1f}/{t_max*1e6:7.1f}/{t_min*1e6:7.1f} ' f'TFLOPS_agg={tflops_agg:6.1f} TFLOPS_mean={tflops_mean:6.1f}', flush=True) + if phase_gather_buf is not None: + phase_by_rank = [g.cpu().tolist() for g in phase_gather_buf] + for i, name in enumerate(phase_names): + avg_values = [rank_values[3 * i] for rank_values in phase_by_rank] + max_values = [rank_values[3 * i + 1] for rank_values in phase_by_rank] + count_values = [rank_values[3 * i + 2] for rank_values in phase_by_rank] + avg_mean = sum(avg_values) / len(avg_values) + avg_max = max(avg_values) + max_max = max(max_values) + count_max = max(count_values) + print( + f' phase {name:16s} avg_mean={avg_mean:10.0f} ' + f'avg_max={avg_max:10.0f} 
max={max_max:10.0f} ' + f'count_max={count_max:8.0f}', + flush=True, + ) dist.barrier() buffer.destroy() @@ -292,7 +373,7 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace): parser.add_argument('--skew-seed', type=int, default=0, help='Seed for the global hot/cold expert permutation') parser.add_argument('--num-redundant-experts', type=int, default=0, - help='Benchmark-only EPLB simulation: reserve physical expert slots as hot-expert replicas') + help='EPLB replica simulation: reserve physical expert slots as hot-expert replicas') parser.add_argument('--replica-dispatch', choices=('hash', 'static'), default='hash', help='Replica remap model: token-level hash or SGLang static source-rank approximation') diff --git a/tests/bench_mega_moe_sm90_skewed.py b/tests/bench_mega_moe_sm90_skewed.py index eb72f870ba..b875c7ebb2 100644 --- a/tests/bench_mega_moe_sm90_skewed.py +++ b/tests/bench_mega_moe_sm90_skewed.py @@ -36,11 +36,17 @@ def _quantize_grouped_fp8_block_128_128(w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: g, n, k = w.shape assert n % 128 == 0 and k % 128 == 0 - w_view = w.view(g, n // 128, 128, k // 128, 128).float() - amax = w_view.abs().amax(dim=(-1, -3)).clamp(1e-4) - sf = amax / 448.0 - w_fp8 = (w_view / sf.unsqueeze(-1).unsqueeze(-3)).to(torch.float8_e4m3fn) - return w_fp8.view(g, n, k).contiguous(), sf.contiguous() + chunk_g = 4 + w_fp8 = torch.empty_like(w, dtype=torch.float8_e4m3fn) + sf = torch.empty((g, n // 128, k // 128), dtype=torch.float, device=w.device) + for start in range(0, g, chunk_g): + end = min(start + chunk_g, g) + w_view = w[start:end].view(end - start, n // 128, 128, k // 128, 128).float() + sf_chunk = w_view.abs().amax(dim=(-1, -3)).clamp(1e-4) / 448.0 + w_fp8[start:end].copy_( + (w_view / sf_chunk.unsqueeze(-1).unsqueeze(-3)).to(torch.float8_e4m3fn).view(end - start, n, k)) + sf[start:end].copy_(sf_chunk) + return w_fp8, sf.contiguous() def _generate_skewed_scores(num_tokens: int, num_experts: int, 
num_topk: int, @@ -104,6 +110,7 @@ def _run_one_config(args, num_tokens, num_max_tokens_per_rank, ) cum_stats = torch.zeros(num_experts_per_rank, dtype=torch.int, device='cuda') + use_skew_hint = args.skew_alpha > 0.0 def run_fused(): buffer.x[:num_tokens].copy_(x_fp8) @@ -111,14 +118,24 @@ def run_fused(): buffer.topk_idx[:num_tokens].copy_(topk_idx) buffer.topk_weights[:num_tokens].copy_(topk_w) y = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') - deep_gemm.fp8_mega_moe( - y, transformed_l1, transformed_l2, buffer, - cumulative_local_expert_recv_stats=cum_stats, - recipe=(128, 128, 128), - activation='swiglu', - activation_clamp=10.0, - fast_math=True, - ) + old_skew_hint = os.environ.get('DG_SM90_MOE_SKEW_HINT') + if use_skew_hint: + os.environ['DG_SM90_MOE_SKEW_HINT'] = '1' + try: + deep_gemm.fp8_mega_moe( + y, transformed_l1, transformed_l2, buffer, + cumulative_local_expert_recv_stats=cum_stats, + recipe=(128, 128, 128), + activation='swiglu', + activation_clamp=10.0, + fast_math=True, + ) + finally: + if use_skew_hint: + if old_skew_hint is None: + os.environ.pop('DG_SM90_MOE_SKEW_HINT', None) + else: + os.environ['DG_SM90_MOE_SKEW_HINT'] = old_skew_hint return y run_fused() @@ -126,7 +143,9 @@ def run_fused(): t_fused = bench_kineto(run_fused, 'sm90_fp8_mega_moe', barrier=lambda: dist.barrier(), num_tests=args.num_tests, - suppress_kineto_output=True) + suppress_kineto_output=True, + with_multiple_kernels=os.environ.get( + 'DG_SM90_MOE_SPLIT_L1_L2', '1') != '0') # Local expert count distribution gathered_topk_idx = uneven_all_gather(topk_idx, group=group) diff --git a/tests/test_mega_moe_hopper.py b/tests/test_mega_moe_hopper.py index 461a2a2e04..5db2b4adc8 100644 --- a/tests/test_mega_moe_hopper.py +++ b/tests/test_mega_moe_hopper.py @@ -207,17 +207,18 @@ def _quantize_grouped_fp8_block_128_128( g, n, k = w.shape assert n % 128 == 0 and k % 128 == 0, f"weight 的 N={n}, K={k} 都必须是 128 的倍数" - # 把 (N, K) 切成 (N/128, 128, K/128, 
128),最后一维和倒数第三维就是 128×128 子块内部 - w_view = w.view(g, n // 128, 128, k // 128, 128).float() - - # 子块内 absmax → scale = amax / 448,clamp(1e-4) 避免全 0 子块 - amax = w_view.abs().amax(dim=(-1, -3)).clamp(1e-4) # (G, N/128, K/128) - sf = amax / FP8_E4M3_MAX - - # 量化:每个元素除以所属子块的 sf 后转 FP8 - # sf 形状 (G, N/128, K/128),需在 N-内 (axis -3) 和 K-内 (axis -1) 都补维度 - w_fp8 = (w_view / sf.unsqueeze(-1).unsqueeze(-3)).to(torch.float8_e4m3fn) - return w_fp8.view(g, n, k).contiguous(), sf.contiguous() + chunk_g = 4 + w_fp8 = torch.empty_like(w, dtype=torch.float8_e4m3fn) + sf = torch.empty((g, n // 128, k // 128), dtype=torch.float, device=w.device) + for start in range(0, g, chunk_g): + end = min(start + chunk_g, g) + w_view = w[start:end].view(end - start, n // 128, 128, k // 128, 128).float() + sf_chunk = w_view.abs().amax(dim=(-1, -3)).clamp(1e-4) / FP8_E4M3_MAX + w_fp8[start:end].copy_( + (w_view / sf_chunk.unsqueeze(-1).unsqueeze(-3)).to(torch.float8_e4m3fn).view(end - start, n, k) + ) + sf[start:end].copy_(sf_chunk) + return w_fp8, sf.contiguous() # ============================================================================ @@ -958,6 +959,9 @@ def try_assign(token: int, seen_experts: set[int]) -> bool: intermediate_hidden, ) y_fused = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") + use_eplb_hint = bool(eplb_replica_for) + use_skew_hint = args.score_powerlaw_alpha > 0.0 + use_masked_hint = args.masked_ratio > 0.0 def run_fused(): # NOTE: 跟 SM100 test_mega_moe.py 的处理一致 —— DG_COMM_KERNEL_DEBUG=1 时 @@ -967,17 +971,43 @@ def run_fused(): sym_buffer.topk_idx[:num_tokens].copy_(topk_idx) sym_buffer.topk_weights[:num_tokens].copy_(topk_weights) - deep_gemm.fp8_mega_moe( - y_fused, - transformed_l1, - transformed_l2, - sym_buffer, - cumulative_local_expert_recv_stats=cum_stats_fused, - recipe=(128, 128, 128), - activation="swiglu", - activation_clamp=clamp_arg, - fast_math=bool(args.fast_math), - ) + old_eplb_hint = os.environ.get("DG_SM90_MOE_EPLB_HINT") + 
old_skew_hint = os.environ.get("DG_SM90_MOE_SKEW_HINT") + old_masked_hint = os.environ.get("DG_SM90_MOE_MASKED_HINT") + if use_eplb_hint: + os.environ["DG_SM90_MOE_EPLB_HINT"] = "1" + if use_skew_hint: + os.environ["DG_SM90_MOE_SKEW_HINT"] = "1" + if use_masked_hint: + os.environ["DG_SM90_MOE_MASKED_HINT"] = "1" + try: + deep_gemm.fp8_mega_moe( + y_fused, + transformed_l1, + transformed_l2, + sym_buffer, + cumulative_local_expert_recv_stats=cum_stats_fused, + recipe=(128, 128, 128), + activation="swiglu", + activation_clamp=clamp_arg, + fast_math=bool(args.fast_math), + ) + finally: + if use_eplb_hint: + if old_eplb_hint is None: + os.environ.pop("DG_SM90_MOE_EPLB_HINT", None) + else: + os.environ["DG_SM90_MOE_EPLB_HINT"] = old_eplb_hint + if use_skew_hint: + if old_skew_hint is None: + os.environ.pop("DG_SM90_MOE_SKEW_HINT", None) + else: + os.environ["DG_SM90_MOE_SKEW_HINT"] = old_skew_hint + if use_masked_hint: + if old_masked_hint is None: + os.environ.pop("DG_SM90_MOE_MASKED_HINT", None) + else: + os.environ["DG_SM90_MOE_MASKED_HINT"] = old_masked_hint return y_fused # ---- 打印 config ---- @@ -1095,6 +1125,32 @@ def run_baseline(): assert y.shape == (num_tokens, hidden) and y.dtype == torch.bfloat16, ( f"fused 输出 shape/dtype 异常: shape={y.shape}, dtype={y.dtype}" ) + if args.check_split_vs_one_kernel: + old_split_env = os.environ.get("DG_SM90_MOE_SPLIT_L1_L2") + try: + os.environ["DG_SM90_MOE_SPLIT_L1_L2"] = "1" + y_split = run_fused().detach().clone() + torch.cuda.synchronize() + os.environ["DG_SM90_MOE_SPLIT_L1_L2"] = "0" + y_one_kernel = run_fused().detach().clone() + torch.cuda.synchronize() + finally: + if old_split_env is None: + os.environ.pop("DG_SM90_MOE_SPLIT_L1_L2", None) + else: + os.environ["DG_SM90_MOE_SPLIT_L1_L2"] = old_split_env + diff = (y_split.float() - y_one_kernel.float()).abs() + denom = y_one_kernel.float().abs().mean().clamp_min(1e-12) + dist_print( + "Output diff (split two-kernel vs one-kernel):", once_in_node=True + ) + dist_print( + 
f" > max_abs={diff.max().item():.6e}, " + f"mean_abs={diff.mean().item():.6e}, " + f"mean_abs/mean_ref={diff.mean().div(denom).item():.6e}", + once_in_node=True, + ) + dist_print(once_in_node=True) if ep_buffer is not None: out_b = run_baseline() assert out_b.shape == (num_tokens, hidden) and out_b.dtype == torch.bfloat16, ( @@ -1150,30 +1206,6 @@ def run_baseline(): num_m_tiles = 0 max_expert_tokens = 0 - # ---- NSYS external profiler multi-iter ---- - # Under DG_USE_NVIDIA_TOOLS=1, bench_kineto returns a sentinel and does not - # run its internal torch.profiler loop. Keep this explicit loop so nsys - # captures multiple steady-state mega_moe kernels, matching decode_t256/t512. - _nsys_iters = int(os.environ.get("NSYS_ITERS", "0")) - if _nsys_iters > 0: - for _it in range(_nsys_iters): - torch.cuda.synchronize() - if ep_buffer is not None: - ep_buffer.barrier(use_comm_stream=False) - else: - dist.barrier() - torch.cuda._sleep(int(2e7)) # ~10ms gap between iters - if ep_buffer is not None: - ep_buffer.barrier(use_comm_stream=False) - else: - dist.barrier() - run_fused() - torch.cuda.synchronize() - if ep_buffer is not None: - ep_buffer.barrier(use_comm_stream=False) - else: - dist.barrier() - # ---- benchmark ---- # fused:bench_kineto 抓 sm90_fp8_mega_moe_impl 的 GPU 段(不含 host overhead) if phase_profile_enabled: @@ -1192,7 +1224,7 @@ def run_baseline(): ), with_multiple_kernels=os.environ.get( "DG_SM90_MOE_SPLIT_L1_L2", - "1" if args.num_max_tokens_per_rank >= 1024 else "0", + "1", ) != "0", ) if phase_profile_enabled: @@ -1209,17 +1241,27 @@ def run_baseline(): phase_end.record() torch.cuda.synchronize() phase_event_us = phase_start.elapsed_time(phase_end) * 1000.0 - raw_i32 = cum_stats_fused[num_experts_per_rank:num_experts_per_rank + 64].detach().cpu().tolist() + raw_i32 = cum_stats_fused[num_experts_per_rank:num_experts_per_rank + phase_profile_extra].detach().cpu().tolist() def _u64(slot: int) -> int: lo = raw_i32[slot * 2] & 0xffffffff hi = raw_i32[slot * 2 + 
1] & 0xffffffff return lo | (hi << 32) - names = ("dispatch_total", "dispatch_pull", "math_loop", "combine_barrier", "combine_reduce", "gemm_core", "l1_epilogue", "l2_epilogue") + names = ( + "dispatch_total", + "dispatch_pull", + "math_loop", + "combine_barrier", + "combine_reduce", + "gemm_core", + "l1_epilogue", + "l2_epilogue", + ) + num_profile_metrics = len(names) pieces = [] for idx, name in enumerate(names): total = _u64(idx) - max_cycles = _u64(8 + idx) - count = _u64(16 + idx) + max_cycles = _u64(num_profile_metrics + idx) + count = _u64(2 * num_profile_metrics + idx) avg_us = (total / count / 1000.0) if count else 0.0 max_us = max_cycles / 1000.0 pieces.append(f"{name}:avg={avg_us:.1f}us,max={max_us:.1f}us,n={count},ns={total}/{max_cycles}") @@ -1462,7 +1504,7 @@ def fmt_perf_line( "--num-redundant-experts", type=int, default=0, - help="benchmark-only EPLB simulation: reserve physical expert slots as hot-expert replicas", + help="EPLB replica simulation: reserve physical expert slots as hot-expert replicas", ) parser.add_argument( "--replica-dispatch", @@ -1507,6 +1549,12 @@ def fmt_perf_line( default=0, help="非 0 时打印 fused 与 legacy-per128 baseline 的输出差异(预期非 bitwise)", ) + parser.add_argument( + "--check-split-vs-one-kernel", + type=int, + default=0, + help="非 0 时打印 split two-kernel 与 one-kernel fused 的输出差异", + ) parser.add_argument( "--dump-profile-traces", type=str, From 615ba0ab7dd600fb4d2b9e02113c96edfe5f7332 Mon Sep 17 00:00:00 2001 From: AichenF Date: Tue, 26 May 2026 12:51:10 +0800 Subject: [PATCH 04/15] Split SM90 MegaMoE into L1 and L2 kernels --- csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp | 25 ++- .../deep_gemm/impls/sm90_fp8_mega_moe.cuh | 198 ++++++++++++------ .../include/deep_gemm/scheduler/mega_moe.cuh | 33 +++ deep_gemm/testing/bench.py | 7 +- tests/test_mega_moe_hopper.py | 4 + 5 files changed, 200 insertions(+), 67 deletions(-) diff --git a/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp b/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp 
index 309e7de50f..88b8b31705 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp @@ -49,6 +49,7 @@ class SM90FP8MegaMoERuntime final : public LaunchRuntime bool l1_dual_k_accum; bool l2_nmajor_schedule; bool l1_nmajor_schedule; + int split_phase_mode; MegaMoESM90Config config; // Runtime arguments @@ -102,6 +103,7 @@ static void __instantiate_kernel() {{ {}, {}, {}, + {}, {} >); }}; @@ -126,7 +128,8 @@ static void __instantiate_kernel() {{ args.phase_profile ? "true" : "false", args.l1_dual_k_accum ? "true" : "false", args.l2_nmajor_schedule ? "true" : "false", - args.l1_nmajor_schedule ? "true" : "false"); + args.l1_nmajor_schedule ? "true" : "false", + args.split_phase_mode); } static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { @@ -246,6 +249,7 @@ static void sm90_fp8_mega_moe( .l1_dual_k_accum = get_env("DG_SM90_MOE_L1_DUAL_K", 0) != 0, .l2_nmajor_schedule = get_env("DG_SM90_MOE_L2_NMAJOR", 0) != 0, .l1_nmajor_schedule = get_env("DG_SM90_MOE_L1_NMAJOR", 0) != 0, + .split_phase_mode = 0, .config = config, .y = y.data_ptr(), .cumulative_local_expert_recv_stats = cumulative_local_expert_recv_stats_ptr, @@ -263,9 +267,22 @@ static void sm90_fp8_mega_moe( .launch_args = LaunchArgs(num_sms, config.num_dispatch_threads + config.num_non_epilogue_threads + config.num_epilogue_threads, config.smem_size, config.cluster_size) }; - const auto code = SM90FP8MegaMoERuntime::generate(args); - const auto runtime = compiler->build("sm90_fp8_mega_moe", code); - SM90FP8MegaMoERuntime::launch(runtime, args); + const auto launch_with_split_mode = [&](const int split_phase_mode, const char* kernel_name) { + auto split_args = args; + split_args.split_phase_mode = split_phase_mode; + const auto code = SM90FP8MegaMoERuntime::generate(split_args); + const auto runtime = compiler->build(kernel_name, code); + SM90FP8MegaMoERuntime::launch(runtime, split_args); + }; + + const bool 
split_l1_l2 = get_env( + "DG_SM90_MOE_SPLIT_L1_L2", num_max_tokens_per_rank >= 1024 ? 1 : 0) != 0; + if (split_l1_l2) { + launch_with_split_mode(1, "sm90_fp8_mega_moe_split_l1"); + launch_with_split_mode(2, "sm90_fp8_mega_moe_split_l2"); + } else { + launch_with_split_mode(0, "sm90_fp8_mega_moe"); + } } } // namespace deep_gemm diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh index d0719a068e..248ffc810f 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh @@ -73,6 +73,7 @@ template < bool kL1DualKAccumRequested = false, bool kL2NMajorScheduleRequested = false, bool kL1NMajorScheduleRequested = false, + uint32_t kSplitPhaseMode = 0, uint32_t L1_SHAPE_N = kIntermediateHidden * 2, uint32_t L1_SHAPE_K = kHidden, uint32_t L2_SHAPE_N = kHidden, @@ -117,6 +118,7 @@ sm90_fp8_mega_moe_impl(void* y, "BLOCK_M must be 16/32 for mma.sync decode or a multiple of WGMMA::M (64)"); DG_STATIC_ASSERT(BLOCK_N == 64 or BLOCK_N == 128 or BLOCK_N == 256, "BLOCK_N must be 64/128/256 for this SM90 path"); DG_STATIC_ASSERT(BLOCK_K == 128, "BLOCK_K is fixed to 128 (per-128 SF)"); + DG_STATIC_ASSERT(kSplitPhaseMode <= 2, "Invalid SM90 MegaMoE split phase mode"); // ===================================================================== // Thread / warp identification @@ -188,6 +190,8 @@ sm90_fp8_mega_moe_impl(void* y, constexpr uint32_t WG_BLOCK_N = (kSplitNWarpgroups || kSerialNWarpgroups) ? 
BLOCK_N / 2 : BLOCK_N; constexpr uint32_t L1_OUT_BLOCK_N = BLOCK_N / 2; // post-SwiGLU tile N constexpr uint32_t WG_L1_OUT_BLOCK_N = WG_BLOCK_N / 2; // post-SwiGLU per-WG N + constexpr bool kRunOnlyLinear1 = kSplitPhaseMode == 1; + constexpr bool kRunOnlyLinear2 = kSplitPhaseMode == 2; constexpr bool kAsyncL1TMAStore = kAsyncL1TMAStoreRequested && (!kUseMMASync) && (!kSplitNWarpgroups) && kNumEpilogueWarpgroups == 1; @@ -410,6 +414,72 @@ sm90_fp8_mega_moe_impl(void* y, } }; + const auto should_run_block_phase = [](const sched::BlockPhase& block_phase) { + if constexpr (kRunOnlyLinear1) { + return block_phase == sched::BlockPhase::Linear1; + } else if constexpr (kRunOnlyLinear2) { + return block_phase == sched::BlockPhase::Linear2; + } else { + (void)block_phase; + return true; + } + }; + + const auto for_each_selected_block = [&](auto&& func) { + if constexpr (kRunOnlyLinear1) { + scheduler.for_each_linear1_block([&](const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + func(sched::BlockPhase::Linear1, local_expert_idx, num_k_blocks, m_block_idx, n_block_idx); + }); + } else if constexpr (kRunOnlyLinear2) { + scheduler.for_each_linear2_block([&](const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + func(sched::BlockPhase::Linear2, local_expert_idx, num_k_blocks, m_block_idx, n_block_idx); + }); + } else { + scheduler.for_each_block(func); + } + }; + + const auto cleanup_workspace = [&]() { + DG_STATIC_ASSERT(kNumSMs > 1, "Invalid SM count"); + if (sm_idx == 0) { + #pragma unroll + for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) + *workspace.get_expert_send_count_ptr(i) = 0; + } else { + for (uint32_t i = sm_idx - 1; i < kNumExpertsPerRank; i += kNumSMs - 1) { + const auto num_recv_tokens = static_cast( + *workspace.get_expert_recv_count_sum_ptr(i)); + const auto num_recv_m_blocks = 
math::ceil_div(num_recv_tokens, BLOCK_M); + const auto cleanup_pool_block_offset = scheduler.get_pool_block_offset(i); + + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + DG_STATIC_ASSERT(kNumDispatchWarps >= 2, "Not enough dispatch warps"); + if (warp_idx == 0) { + *workspace.get_expert_recv_count_sum_ptr(i) = 0; + } else if (warp_idx == 1) { + if (cute::elect_one_sync() and cumulative_local_expert_recv_stats != nullptr) + ptx::red_add(cumulative_local_expert_recv_stats + i, static_cast(num_recv_tokens)); + __syncwarp(); + } + + for (uint32_t j = thread_idx; j < kNumRanks; j += kNumDispatchThreads) + *workspace.get_expert_recv_count_ptr(j, i) = 0; + __syncwarp(); + + for (uint32_t j = thread_idx; j < num_recv_m_blocks; j += kNumDispatchThreads) { + *workspace.get_l1_arrival_count_ptr(cleanup_pool_block_offset + j) = 0; + *workspace.get_l2_arrival_mask_ptr(cleanup_pool_block_offset + j) = 0; + } + __syncwarp(); + } + } + }; + // ===================================================================== // ROLE 1: DISPATCH WARPS // Mirrors SM100 dispatch with two changes: @@ -423,6 +493,19 @@ sm90_fp8_mega_moe_impl(void* y, cutlass::arch::warpgroup_reg_dealloc(); const unsigned long long dispatch_total_start = phase_profile_clock(); + if constexpr (kRunOnlyLinear2) { + scheduler.fetch_expert_recv_count(); + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + cleanup_workspace(); + comm::nvlink_barrier( + workspace, sym_buffer, sm_idx, thread_idx, + [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); }, + true, false); + return; + } + DG_STATIC_ASSERT(kNumTopk <= 32, "Invalid number of topk"); constexpr uint32_t kNumActivateLanes = kNumTokensPerWarp * kNumTopk; const auto read_topk_idx = [&](const auto& process) { @@ -642,42 +725,10 @@ sm90_fp8_mega_moe_impl(void* y, 
phase_profile_record(kProfileDispatchTotal, dispatch_pull_end - dispatch_total_start); } ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + if constexpr (kRunOnlyLinear1) + return; - DG_STATIC_ASSERT(kNumSMs > 1, "Invalid SM count"); - if (sm_idx == 0) { - #pragma unroll - for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) - *workspace.get_expert_send_count_ptr(i) = 0; - } else { - for (uint32_t i = sm_idx - 1; i < kNumExpertsPerRank; i += kNumSMs - 1) { - const auto num_recv_tokens = static_cast( - *workspace.get_expert_recv_count_sum_ptr(i)); - const auto num_recv_m_blocks = math::ceil_div(num_recv_tokens, BLOCK_M); - - expert_pool_block_offset = scheduler.get_pool_block_offset(i); - - ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); - - DG_STATIC_ASSERT(kNumDispatchWarps >= 2, "Not enough dispatch warps"); - if (warp_idx == 0) { - *workspace.get_expert_recv_count_sum_ptr(i) = 0; - } else if (warp_idx == 1) { - if (cute::elect_one_sync() and cumulative_local_expert_recv_stats != nullptr) - ptx::red_add(cumulative_local_expert_recv_stats + i, static_cast(num_recv_tokens)); - __syncwarp(); - } - - for (uint32_t j = thread_idx; j < kNumRanks; j += kNumDispatchThreads) - *workspace.get_expert_recv_count_ptr(j, i) = 0; - __syncwarp(); - - for (uint32_t j = thread_idx; j < num_recv_m_blocks; j += kNumDispatchThreads) { - *workspace.get_l1_arrival_count_ptr(expert_pool_block_offset + j) = 0; - *workspace.get_l2_arrival_mask_ptr(expert_pool_block_offset + j) = 0; - } - __syncwarp(); - } - } + cleanup_workspace(); comm::nvlink_barrier( @@ -694,13 +745,17 @@ sm90_fp8_mega_moe_impl(void* y, } else if (warp_idx == kNumDispatchWarps) { cutlass::arch::warpgroup_reg_dealloc(); - scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + for_each_selected_block([&](const sched::BlockPhase& block_phase, const uint32_t& local_expert_idx, const uint32_t& num_k_blocks, const uint32_t& 
m_block_idx, const uint32_t& n_block_idx) { - const auto tensor_map_a_ptr = block_phase == sched::BlockPhase::Linear2 + if (!should_run_block_phase(block_phase)) + return; + const bool is_linear1_phase = + kRunOnlyLinear1 ? true : (kRunOnlyLinear2 ? false : block_phase == sched::BlockPhase::Linear1); + const auto tensor_map_a_ptr = !is_linear1_phase ? &tensor_map_l2_acts : &tensor_map_l1_acts; - const auto tensor_map_sfa_ptr = block_phase == sched::BlockPhase::Linear2 + const auto tensor_map_sfa_ptr = !is_linear1_phase ? &tensor_map_l2_acts_sf : &tensor_map_l1_acts_sf; const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx; @@ -710,11 +765,11 @@ sm90_fp8_mega_moe_impl(void* y, // Wait for the pool to be ready. Cluster peers can be dummy CTAs for // the tail M unit when an expert has an odd number of M blocks. if (has_valid_m) { - if (block_phase == sched::BlockPhase::Linear1) { + if (is_linear1_phase) { const auto ptr = workspace.get_l1_arrival_count_ptr(pool_block_idx); const auto expected = valid_m; while (ptx::ld_acq(ptr) != expected); - } else { + } else if constexpr (!kRunOnlyLinear2) { const auto ptr = workspace.get_l2_arrival_mask_ptr(pool_block_idx); // Each L1 N block sets one bit; total bits = L1_SHAPE_N / BLOCK_N. 
constexpr uint32_t kNumL1BlockNs = L1_SHAPE_N / BLOCK_N; @@ -740,7 +795,7 @@ sm90_fp8_mega_moe_impl(void* y, full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE); } else { // TMA load SFA - if (block_phase == sched::BlockPhase::Linear1) { + if (is_linear1_phase) { // L1 SFA per-128: load (BLOCK_M, 1) at K=k_block_idx tma::copy( tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], @@ -773,14 +828,18 @@ sm90_fp8_mega_moe_impl(void* y, } else if (warp_idx == kNumDispatchWarps + 1) { cutlass::arch::warpgroup_reg_dealloc(); - scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + for_each_selected_block([&](const sched::BlockPhase& block_phase, const uint32_t& local_expert_idx, const uint32_t& num_k_blocks, const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + if (!should_run_block_phase(block_phase)) + return; + const bool is_linear1_phase = + kRunOnlyLinear1 ? true : (kRunOnlyLinear2 ? false : block_phase == sched::BlockPhase::Linear1); const auto tensor_map_b_ptr = - block_phase == sched::BlockPhase::Linear2 ? &tensor_map_l2_weights : &tensor_map_l1_weights; + !is_linear1_phase ? &tensor_map_l2_weights : &tensor_map_l1_weights; - const uint32_t shape_n = block_phase == sched::BlockPhase::Linear2 ? L2_SHAPE_N : L1_SHAPE_N; + const uint32_t shape_n = !is_linear1_phase ? 
L2_SHAPE_N : L1_SHAPE_N; for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { empty_barriers[stage_idx]->wait(phase ^ 1); @@ -803,13 +862,17 @@ sm90_fp8_mega_moe_impl(void* y, } else if (kSplitSFATMA && warp_idx == kNumDispatchWarps + 2) { cutlass::arch::warpgroup_reg_dealloc(); - scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + for_each_selected_block([&](const sched::BlockPhase& block_phase, const uint32_t& local_expert_idx, const uint32_t& num_k_blocks, const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + if (!should_run_block_phase(block_phase)) + return; + const bool is_linear1_phase = + kRunOnlyLinear1 ? true : (kRunOnlyLinear2 ? false : block_phase == sched::BlockPhase::Linear1); (void)local_expert_idx; (void)n_block_idx; - const auto tensor_map_sfa_ptr = block_phase == sched::BlockPhase::Linear2 + const auto tensor_map_sfa_ptr = !is_linear1_phase ? &tensor_map_l2_acts_sf : &tensor_map_l1_acts_sf; const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx; @@ -817,11 +880,11 @@ sm90_fp8_mega_moe_impl(void* y, const bool has_valid_m = valid_m > 0; if (has_valid_m) { - if (block_phase == sched::BlockPhase::Linear1) { + if (is_linear1_phase) { const auto ptr = workspace.get_l1_arrival_count_ptr(pool_block_idx); const auto expected = valid_m; while (ptx::ld_acq(ptr) != expected); - } else { + } else if constexpr (!kRunOnlyLinear2) { const auto ptr = workspace.get_l2_arrival_mask_ptr(pool_block_idx); constexpr uint32_t kNumL1BlockNs = L1_SHAPE_N / BLOCK_N; const uint64_t expected = (kNumL1BlockNs >= 64) @@ -837,7 +900,7 @@ sm90_fp8_mega_moe_impl(void* y, if (has_valid_m) { const uint32_t m_idx = pool_block_idx * BLOCK_M; - if (block_phase == sched::BlockPhase::Linear1) { + if (is_linear1_phase) { tma::copy( tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], m_idx, k_block_idx, 1); @@ -957,10 +1020,14 @@ sm90_fp8_mega_moe_impl(void* y, 
ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); const unsigned long long math_loop_start = phase_profile_clock(); - scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + for_each_selected_block([&](const sched::BlockPhase& block_phase, const uint32_t& local_expert_idx, const uint32_t& num_k_blocks, const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + if (!should_run_block_phase(block_phase)) + return; + const bool is_linear1_phase = + kRunOnlyLinear1 ? true : (kRunOnlyLinear2 ? false : block_phase == sched::BlockPhase::Linear1); const uint32_t valid_m = scheduler.template get_valid_m(); const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx; const uint32_t m_idx = pool_block_idx * BLOCK_M; @@ -975,7 +1042,7 @@ sm90_fp8_mega_moe_impl(void* y, if constexpr (kAsyncL1TMAStore) { - if (block_phase != sched::BlockPhase::Linear1) + if (!is_linear1_phase) drain_all_async_l1_stores(); } @@ -1017,7 +1084,7 @@ sm90_fp8_mega_moe_impl(void* y, constexpr uint32_t kL1SFPerExpert = (kIntermediateHidden * 2 / 128) * kL1SFKBlocks; constexpr uint32_t kL2SFPerExpert = (kHidden / 128) * kL2SFKBlocks; - if (block_phase == sched::BlockPhase::Linear1) { + if (is_linear1_phase) { const uint32_t gate_n = (n_block_idx * BLOCK_N + wg_n_idx) / 256u; const uint32_t up_n = kL1SFGateBlks + gate_n; const float* base = l1_weights_sf + local_expert_idx * kL1SFPerExpert + k_block_idx; @@ -1082,7 +1149,7 @@ sm90_fp8_mega_moe_impl(void* y, constexpr uint32_t kMMASyncRowsPerPass = kNumEpilogueThreads / 8; DG_STATIC_ASSERT(kMMASyncRowsPerPass == 16, "mma.sync epilogue maps 8 lanes per row"); - if (block_phase == sched::BlockPhase::Linear1) { + if (is_linear1_phase) { #pragma unroll for (uint32_t row_base = 0; row_base < BLOCK_M; row_base += kMMASyncRowsPerPass) { const uint32_t row = row_base + epilogue_thread_idx / 8; @@ -1207,7 +1274,7 @@ sm90_fp8_mega_moe_impl(void* y, float scale_a_0_lo, 
scale_a_1_lo; float scale_a_0_hi, scale_a_1_hi; - if (block_phase == sched::BlockPhase::Linear1) { + if (is_linear1_phase) { scale_a_0_lo = ptx::ld_shared(smem_sfa[stage_idx] + row_offset_r0); scale_a_1_lo = ptx::ld_shared(smem_sfa[stage_idx] + row_offset_r1); } else { @@ -1227,7 +1294,7 @@ sm90_fp8_mega_moe_impl(void* y, for (uint32_t serial_n_idx = 0; serial_n_idx < kNumSerialN; ++serial_n_idx) { const uint32_t serial_wg_n_idx = serial_n_idx * WG_BLOCK_N; float gate_sf = 0.0f, up_sf = 0.0f, l2_sf = 0.0f; - if (block_phase == sched::BlockPhase::Linear1) { + if (is_linear1_phase) { const uint32_t gate_n = (n_block_idx * BLOCK_N + serial_wg_n_idx) / 256u; const uint32_t up_n = kL1SFGateBlks + gate_n; const float* base = l1_weights_sf + local_expert_idx * kL1SFPerExpert + k_block_idx; @@ -1322,7 +1389,7 @@ sm90_fp8_mega_moe_impl(void* y, return; } - if (block_phase == sched::BlockPhase::Linear1) { + if (is_linear1_phase) { constexpr uint32_t kNumPairs = kAccumPerThread / 8; #pragma unroll for (uint32_t serial_n_idx = 0; serial_n_idx < kNumSerialN; ++serial_n_idx) { @@ -1542,7 +1609,7 @@ for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_bl // Read SF (must precede warpgroup_arrive) float scale_a_0_lo, scale_a_1_lo; float scale_a_0_hi, scale_a_1_hi; // Only used in L2 (per-64 K) - if (block_phase == sched::BlockPhase::Linear1) { + if (is_linear1_phase) { scale_a_0_lo = ptx::ld_shared(smem_sfa[stage_idx] + row_offset_r0); scale_a_1_lo = ptx::ld_shared(smem_sfa[stage_idx] + row_offset_r1); } else { @@ -1577,7 +1644,7 @@ for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_bl constexpr uint32_t kL1SFPerExpert = (kIntermediateHidden * 2 / 128) * kL1SFKBlocks; constexpr uint32_t kL2SFPerExpert = (kHidden / 128) * kL2SFKBlocks; float gate_sf = 0.0f, up_sf = 0.0f, l2_sf_lo = 0.0f, l2_sf_hi = 0.0f; - if (block_phase == sched::BlockPhase::Linear1) { + if (is_linear1_phase) { const uint32_t gate_n = (n_block_idx * BLOCK_N + 
wg_n_idx) / 256u; const uint32_t up_n = kL1SFGateBlks + gate_n; const float* base = l1_weights_sf + local_expert_idx * kL1SFPerExpert + k_block_idx; @@ -1593,7 +1660,7 @@ for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_bl l2_sf_hi = l2_sf_lo; } - if (block_phase == sched::BlockPhase::Linear1) { + if (is_linear1_phase) { // Single per-128 K-block WGMMA group #pragma unroll for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); @@ -1845,7 +1912,7 @@ for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_bl }; if constexpr (kL1DualKAccum) { - if (block_phase == sched::BlockPhase::Linear1) + if (is_linear1_phase) run_l1_dual_k_gemm_loop(); else run_default_gemm_loop(); @@ -1862,7 +1929,7 @@ for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_bl // previous valid block, so drain it before leaving the L1 wave. if (row_block_offset >= valid_m) { if constexpr (kAsyncL1TMAStore) { - if (block_phase == sched::BlockPhase::Linear1) + if (is_linear1_phase) drain_all_async_l1_stores(); } ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); @@ -1870,7 +1937,7 @@ for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_bl } const unsigned long long block_epilogue_start = phase_profile_clock(); - if (block_phase == sched::BlockPhase::Linear1) { + if (is_linear1_phase) { // ---------------- L1 EPILOGUE: SwiGLU + FP8 quantize + TMA store ---------------- // Layout in `final_accum`: // 16 chunks of 8 N-cols, each chunk = 4 floats per thread = (r0c0, r0c1, r1c0, r1c1). 
@@ -2240,6 +2307,13 @@ for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_bl if (epilogue_warp_idx == 0 and lane_idx == 0) phase_profile_record(kProfileMathLoop, math_loop_end - math_loop_start); + if constexpr (kRunOnlyLinear1) { + if constexpr (kAsyncL1TMAStore) + drain_all_async_l1_stores(); + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + return; + } + // ---------------- COMBINE ---------------- // NVLink barrier first: signals remote ranks that this rank's GEMM // outputs (NVLink scatter targets) are fully written. diff --git a/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh b/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh index f65c891968..dd3b077df6 100644 --- a/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh @@ -242,6 +242,39 @@ struct MegaMoEScheduler { m_block_idx, n_block_idx); } } + + template + CUTLASS_DEVICE void for_each_linear1_block(Func&& func) { + // Split-kernel mode: K1 owns only dispatch + Linear1. Unlike + // for_each_block(), do not burn scheduler iterations on Linear2 blocks. + fetch_expert_recv_count(); + set_expert_idx(0); + while (current_local_expert_idx < kNumExpertsPerRank) { + if (fetch_next_l1_block()) { + block_idx += kNumSMs / kClusterSize; + func(current_local_expert_idx, kNumL1BlockKs, m_block_idx, n_block_idx); + } else if (current_local_expert_idx >= kNumExpertsPerRank) { + break; + } + } + } + + template + CUTLASS_DEVICE void for_each_linear2_block(Func&& func) { + // Split-kernel mode: K2 starts after K1 has completed globally, so all + // L2-ready masks are already final. Schedule Linear2 blocks directly + // from the phase start, like a standalone grouped GEMM. 
+ fetch_expert_recv_count(); + set_expert_idx(0); + while (current_local_expert_idx < kNumExpertsPerRank) { + if (fetch_next_l2_block()) { + block_idx += kNumSMs / kClusterSize; + func(current_local_expert_idx, kNumL2BlockKs, m_block_idx, n_block_idx); + } else if (current_local_expert_idx >= kNumExpertsPerRank) { + break; + } + } + } }; } // namespace deep_gemm::sched diff --git a/deep_gemm/testing/bench.py b/deep_gemm/testing/bench.py index 552b9aa18a..915be3533c 100644 --- a/deep_gemm/testing/bench.py +++ b/deep_gemm/testing/bench.py @@ -141,6 +141,11 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, total_time += float(time_str.replace(unit, '')) / scale * int(num_str) total_num += int(num_str) break - kernel_times.append(total_time / total_num if total_num > 0 else 0) + if total_num > 0 and with_multiple_kernels: + # Multiple matching kernels can belong to one logical benchmarked op + # (e.g. split MegaMoE L1/L2). Report summed CUDA time per fn() call. + kernel_times.append(total_time / num_tests) + else: + kernel_times.append(total_time / total_num if total_num > 0 else 0) return tuple(kernel_times) if is_tuple else kernel_times[0] diff --git a/tests/test_mega_moe_hopper.py b/tests/test_mega_moe_hopper.py index 021382710e..461a2a2e04 100644 --- a/tests/test_mega_moe_hopper.py +++ b/tests/test_mega_moe_hopper.py @@ -1190,6 +1190,10 @@ def run_baseline(): if args.dump_profile_traces else None ), + with_multiple_kernels=os.environ.get( + "DG_SM90_MOE_SPLIT_L1_L2", + "1" if args.num_max_tokens_per_rank >= 1024 else "0", + ) != "0", ) if phase_profile_enabled: cum_stats_fused.zero_() From eece053b2770a8eb22943ce844acbaf9dad18d46 Mon Sep 17 00:00:00 2001 From: AichenF Date: Mon, 1 Jun 2026 17:14:50 +0800 Subject: [PATCH 05/15] Clean SM90 MegaMoE split path --- csrc/apis/mega.hpp | 104 - csrc/jit_kernels/heuristics/mega_moe.hpp | 118 +- csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp | 73 +- csrc/jit_kernels/impls/sm90_w4a8_mega_moe.hpp | 271 -- 
deep_gemm/__init__.py | 1 - .../deep_gemm/impls/sm90_fp8_mega_moe.cuh | 58 +- .../deep_gemm/impls/sm90_w4a8_mega_moe.cuh | 2318 ----------------- .../deep_gemm/quantization/mxfp4_dequant.cuh | 58 - .../include/deep_gemm/scheduler/mega_moe.cuh | 309 +-- deep_gemm/mega/__init__.py | 28 - tests/bench_w4a8_mega_moe_sm90.py | 250 -- 11 files changed, 89 insertions(+), 3499 deletions(-) delete mode 100644 csrc/jit_kernels/impls/sm90_w4a8_mega_moe.hpp delete mode 100644 deep_gemm/include/deep_gemm/impls/sm90_w4a8_mega_moe.cuh delete mode 100644 deep_gemm/include/deep_gemm/quantization/mxfp4_dequant.cuh delete mode 100644 tests/bench_w4a8_mega_moe_sm90.py diff --git a/csrc/apis/mega.hpp b/csrc/apis/mega.hpp index ea9ff0fbab..c999099b1b 100644 --- a/csrc/apis/mega.hpp +++ b/csrc/apis/mega.hpp @@ -9,7 +9,6 @@ #include "../jit/device_runtime.hpp" #include "../jit_kernels/impls/sm100_fp8_fp4_mega_moe.hpp" #include "../jit_kernels/impls/sm90_fp8_mega_moe.hpp" -#include "../jit_kernels/impls/sm90_w4a8_mega_moe.hpp" namespace deep_gemm::mega { @@ -359,115 +358,12 @@ static void fp8_mega_moe( sym_buffer.zero_(); } -// SM90 (Hopper) W4A8 MegaMoE entry point — Phase-0 scaffold. -// -// Accepts the same FP8 weights + per-128 float SF as `fp8_mega_moe` and -// dispatches to the W4A8 runtime. The runtime kernel is currently a verbatim -// copy of the FP8 kernel (only renamed); subsequent phases will: -// 1. swap weight dtype to packed MXFP4 (uint8) + per-32 E8M0 (uint8) SF -// 2. add `dequant_mxfp4_to_fp8` insertion in the math warp mainloop -// 3. apply per-32 E8M0 scale on WGMMA accumulator (Marlin-style) -// -// In Phase 0 this entry is bitwise equivalent to `fp8_mega_moe`. Used to -// verify build/link/run plumbing for the new kernel name before changing -// kernel logic. 
-static void w4a8_mega_moe( - const torch::Tensor& y, - const std::tuple& l1_weights_tuple, - const std::tuple& l2_weights_tuple, - const std::optional& cumulative_local_expert_recv_stats, - const torch::Tensor& sym_buffer, - const std::vector& sym_buffer_ptrs, const int& rank_idx, - const int& num_max_tokens_per_rank, - const int& num_experts, const int& num_topk, - const std::tuple& recipe, - const std::string& activation, - const std::optional& activation_clamp_opt, - const bool& fast_math -) { - const auto [l1_weights, l1_weights_sf] = l1_weights_tuple; - const auto [l2_weights, l2_weights_sf] = l2_weights_tuple; - - // Architecture check - const auto arch_major = device_runtime->get_arch_major(); - DG_HOST_ASSERT(arch_major == 9); - - // Phase-0 input checks: identical to fp8_mega_moe. - const auto num_tokens = static_cast(y.size(0)); - const auto [rm, rn, rk] = recipe; - DG_HOST_ASSERT(rm == 128 and rn == 128 and rk == 128); - DG_HOST_ASSERT(activation == "swiglu"); - - const auto activation_clamp = - activation_clamp_opt.value_or(std::numeric_limits::infinity()); - DG_HOST_ASSERT(activation_clamp >= 0); - - DG_HOST_ASSERT(get_major_type_ab(l1_weights) == cute::UMMA::Major::K); - DG_HOST_ASSERT(get_major_type_ab(l2_weights) == cute::UMMA::Major::K); - DG_HOST_ASSERT(l1_weights.scalar_type() == torch::kFloat8_e4m3fn); - DG_HOST_ASSERT(l2_weights.scalar_type() == torch::kFloat8_e4m3fn); - const auto [num_experts_per_rank, intermediate_hidden_2, hidden] = get_shape<3>(l1_weights); - const auto [num_experts_per_rank_, hidden_, intermediate_hidden] = get_shape<3>(l2_weights); - DG_HOST_ASSERT(num_tokens <= num_max_tokens_per_rank); - DG_HOST_ASSERT(num_experts_per_rank == num_experts_per_rank_); - DG_HOST_ASSERT(hidden == hidden_); - DG_HOST_ASSERT(intermediate_hidden_2 == 2 * intermediate_hidden); - DG_HOST_ASSERT(l1_weights.is_contiguous() and l2_weights.is_contiguous()); - DG_HOST_ASSERT(hidden % 128 == 0 and intermediate_hidden % 128 == 0); - 
DG_HOST_ASSERT(intermediate_hidden / 64 <= 64); - - constexpr int kGranMN = 128, kGranK = 128; - check_sf_layout(l1_weights_sf, intermediate_hidden * 2, hidden, kGranMN, kGranK, - num_experts_per_rank, false, true, torch::kFloat); - check_sf_layout(l2_weights_sf, hidden, intermediate_hidden, kGranMN, kGranK, - num_experts_per_rank, false, true, torch::kFloat); - - if (cumulative_local_expert_recv_stats.has_value()) { - DG_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt); - const auto stats_numel = cumulative_local_expert_recv_stats->numel(); - const bool phase_profile = get_env("DG_SM90_MOE_PHASE_PROFILE", 0) != 0; - DG_HOST_ASSERT(stats_numel == num_experts_per_rank or - (phase_profile and stats_numel >= num_experts_per_rank + 64)); - DG_HOST_ASSERT(cumulative_local_expert_recv_stats->is_contiguous()); - } - - const auto num_ranks = static_cast(sym_buffer_ptrs.size()); - const auto num_experts_ = num_experts_per_rank * num_ranks; - const auto [num_required_bytes, slice] = get_symm_buffer_size_for_mega_moe( - num_ranks, num_experts, - num_max_tokens_per_rank, num_topk, - hidden, intermediate_hidden, - true, activation); - DG_HOST_ASSERT(sym_buffer.nbytes() >= static_cast(num_required_bytes)); - DG_HOST_ASSERT(num_experts == num_experts_); - - const auto [x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf] = slice(sym_buffer); - - // Dispatch to W4A8 runtime (Phase-0: identical kernel to FP8). 
- sm90_w4a8_mega_moe(y, - l1_acts, l1_acts_sf, - l2_acts, l2_acts_sf, - l1_weights, l2_weights, - l1_weights_sf, l2_weights_sf, - cumulative_local_expert_recv_stats, - sym_buffer_ptrs, - rank_idx, num_max_tokens_per_rank, - num_experts_per_rank, - num_tokens, num_topk, - hidden, intermediate_hidden, - activation_clamp, fast_math); - - if (get_env("DG_COMM_KERNEL_DEBUG")) - sym_buffer.zero_(); -} - static void register_apis(pybind11::module_& m) { #if DG_TENSORMAP_COMPATIBLE m.def("get_token_alignment_for_mega_moe", &get_token_alignment_for_mega_moe); m.def("get_symm_buffer_size_for_mega_moe", &get_symm_buffer_size_for_mega_moe); m.def("fp8_fp4_mega_moe", &fp8_fp4_mega_moe); m.def("fp8_mega_moe", &fp8_mega_moe); - m.def("w4a8_mega_moe", &w4a8_mega_moe); #endif } diff --git a/csrc/jit_kernels/heuristics/mega_moe.hpp b/csrc/jit_kernels/heuristics/mega_moe.hpp index fe310122d1..7d5f15e82a 100644 --- a/csrc/jit_kernels/heuristics/mega_moe.hpp +++ b/csrc/jit_kernels/heuristics/mega_moe.hpp @@ -406,9 +406,7 @@ static Sm90MoeHeuristicPolicy get_sm90_moe_heuristic_policy( static std::tuple get_block_config_for_mega_moe_sm90( const int& num_ranks, const int& num_experts, const int& num_max_tokens_per_rank, const int& num_topk, - const int& num_tokens, - const int& split_phase_mode = 0) { - DG_HOST_ASSERT(split_phase_mode >= 0 and split_phase_mode <= 2); + const int& num_tokens) { // Keep mma.sync decode variants opt-in; the default path uses M64 CTAs. const int num_experts_per_rank = num_experts / num_ranks; const float expected_tokens_per_expert = @@ -417,21 +415,8 @@ static std::tuple get_block_config_for_mega_moe_sm90( ? get_env("DG_SM90_MOE_MMA_SYNC_M") : (get_env("DG_SM90_MOE_MMA_SYNC") != 0 ? 
16 : 0); DG_HOST_ASSERT(requested_mma_m == 0 or requested_mma_m == 16 or requested_mma_m == 32); - const bool split_linear1_only = split_phase_mode == 1; - const bool split_linear2_only = split_phase_mode == 2; - int forced_block_m = 0; - int forced_epilogue_warpgroups = 0; - if (split_linear1_only) { - forced_block_m = get_env("DG_SM90_MOE_L1_FORCE_BLOCK_M"); - forced_epilogue_warpgroups = get_env("DG_SM90_MOE_L1_FORCE_EPILOGUE_WG"); - } else if (split_linear2_only) { - forced_block_m = get_env("DG_SM90_MOE_L2_FORCE_BLOCK_M"); - forced_epilogue_warpgroups = get_env("DG_SM90_MOE_L2_FORCE_EPILOGUE_WG"); - } - if (forced_block_m == 0) - forced_block_m = get_env("DG_SM90_MOE_FORCE_BLOCK_M"); - if (forced_epilogue_warpgroups == 0) - forced_epilogue_warpgroups = get_env("DG_SM90_MOE_FORCE_EPILOGUE_WG"); + const int forced_block_m = get_env("DG_SM90_MOE_FORCE_BLOCK_M"); + const int forced_epilogue_warpgroups = get_env("DG_SM90_MOE_FORCE_EPILOGUE_WG"); DG_HOST_ASSERT(forced_block_m == 0 or forced_block_m == 64 or forced_block_m == 128); DG_HOST_ASSERT(forced_epilogue_warpgroups == 0 or forced_epilogue_warpgroups == 1 or @@ -439,19 +424,13 @@ static std::tuple get_block_config_for_mega_moe_sm90( const bool use_mma_sync_decode = requested_mma_m > 0 and expected_tokens_per_expert <= static_cast(requested_mma_m); const bool use_bn256_split_n = - (get_env("DG_SM90_MOE_BN256_2WG", get_sm90_moe_split_l1_l2_default() ? 1 : 0) != 0 or - (split_linear1_only and get_env("DG_SM90_MOE_L1_BN256_2WG") != 0) or - (split_linear2_only and get_env("DG_SM90_MOE_L2_BN256_2WG") != 0)) and + get_env("DG_SM90_MOE_BN256_2WG", get_sm90_moe_split_l1_l2_default() ? 
1 : 0) != 0 and forced_block_m != 128 and not use_mma_sync_decode; const bool use_bn256_seq_n = - (get_env("DG_SM90_MOE_BN256_SEQ") != 0 or - (split_linear1_only and get_env("DG_SM90_MOE_L1_BN256_SEQ") != 0) or - (split_linear2_only and get_env("DG_SM90_MOE_L2_BN256_SEQ") != 0)) and + get_env("DG_SM90_MOE_BN256_SEQ") != 0 and forced_block_m != 128 and not use_mma_sync_decode; const bool use_b_stationary_2wg = - (get_env("DG_SM90_MOE_B_STATIONARY_2WG") != 0 or - (split_linear1_only and get_env("DG_SM90_MOE_L1_B_STATIONARY_2WG") != 0)) and - not split_linear2_only and not use_mma_sync_decode; + get_env("DG_SM90_MOE_B_STATIONARY_2WG") != 0 and not use_mma_sync_decode; DG_HOST_ASSERT(not (use_bn256_split_n and use_bn256_seq_n)); DG_HOST_ASSERT(not (use_b_stationary_2wg and (use_bn256_split_n or use_bn256_seq_n))); const int block_m = forced_block_m > 0 @@ -526,18 +505,13 @@ static std::pair get_pipeline_config_for_mega_moe_sm90( const int& block_m, const int& block_n, const int& block_k, const int& num_dispatch_warps, const int& num_epilogue_warps, const bool& direct_l2_scatter_default = false, - const int& default_num_stages = 0, - const int& split_phase_mode = 0) { + const int& default_num_stages = 0) { constexpr int kSmemAlignment = 1024; - DG_HOST_ASSERT(split_phase_mode >= 0 and split_phase_mode <= 2); - const bool split_linear1_only = split_phase_mode == 1; - const bool split_linear2_only = split_phase_mode == 2; - // Dispatch region. Split K2 does not pull remote tokens, so its dynamic - // shared memory does not need expert-count scratch or dispatch buffers. - const int smem_expert_count_size = split_linear2_only ? 0 : align( + // Dispatch region (same as SM100) + const int smem_expert_count_size = align( num_experts * static_cast(sizeof(uint32_t)), kSmemAlignment); - const int smem_send_buffers_size = split_linear2_only ? 
0 : align( + const int smem_send_buffers_size = align( static_cast(layout::Buffer(layout::Data(hidden), num_dispatch_warps, 1).get_num_bytes()), kSmemAlignment); const int smem_dispatch_size = smem_expert_count_size + smem_send_buffers_size; @@ -552,19 +526,16 @@ static std::pair get_pipeline_config_for_mega_moe_sm90( const int wg_block_m = split_n_warpgroups ? block_m : block_m / num_epilogue_warpgroups; const int wg_block_n = (split_n_warpgroups or serial_n_warpgroups) ? block_n / 2 : block_n; const int smem_cd_accum = (block_m == 16 or block_m == 32) ? align(block_m * block_n * static_cast(sizeof(float)), kSmemAlignment) : 0; - const int smem_cd_l1 = split_linear2_only ? 0 : - num_epilogue_warpgroups * wg_block_m * (wg_block_n / 2); // 1 byte/elem (FP8) - const bool direct_l2_scatter = (not split_linear1_only) and - get_env( + const int smem_cd_l1 = num_epilogue_warpgroups * wg_block_m * (wg_block_n / 2); // 1 byte/elem (FP8) + const bool direct_l2_scatter = get_env( "DG_SM90_MOE_DIRECT_L2_SCATTER", direct_l2_scatter_default ? 1 : 0) != 0 and block_m != 16 and block_m != 32 and not serial_n_warpgroups and wg_block_n == 128; - const bool async_l1_tma_store = (not split_linear2_only) and - get_env("DG_SM90_MOE_ASYNC_L1_STORE", 0) != 0 and + const bool async_l1_tma_store = get_env("DG_SM90_MOE_ASYNC_L1_STORE", 0) != 0 and block_m != 16 and block_m != 32 and not split_n_warpgroups and num_epilogue_warpgroups == 1; - const int smem_cd_l2 = (split_linear1_only or direct_l2_scatter) ? 0 : + const int smem_cd_l2 = direct_l2_scatter ? 0 : num_epilogue_warpgroups * wg_block_m * wg_block_n * static_cast(sizeof(nv_bfloat16)); const int smem_cd_l1_async = async_l1_tma_store ? 2 * num_epilogue_warpgroups * wg_block_m * (block_n / 2) : 0; @@ -576,8 +547,7 @@ static std::pair get_pipeline_config_for_mega_moe_sm90( // * SFB is loaded directly from global by the math warpgroup (block-(128,128) // weight quantization), so no SMEM is reserved for it. 
const int smem_sfa_half_stride_bytes = align(block_m * static_cast(sizeof(float)), 128); - const int smem_sfa_per_stage = split_linear1_only - ? smem_sfa_half_stride_bytes : 2 * smem_sfa_half_stride_bytes; + const int smem_sfa_per_stage = 2 * smem_sfa_half_stride_bytes; const int smem_sfb_per_stage = 0; // Per-stage: A tile + B tile + SFA tile + SFB tile @@ -602,13 +572,7 @@ static std::pair get_pipeline_config_for_mega_moe_sm90( const int preferred_num_stages = default_num_stages > 0 ? std::min(default_num_stages, max_num_stages) : (prefer_bn256_split ? std::min(4, max_num_stages) : 0); - int forced_num_stages = 0; - if (split_linear1_only) - forced_num_stages = get_env("DG_SM90_MOE_L1_NUM_STAGES"); - else if (split_linear2_only) - forced_num_stages = get_env("DG_SM90_MOE_L2_NUM_STAGES"); - if (forced_num_stages == 0) - forced_num_stages = get_env("DG_SM90_MOE_NUM_STAGES"); + const int forced_num_stages = get_env("DG_SM90_MOE_NUM_STAGES"); const int num_stages = forced_num_stages > 0 ? forced_num_stages : (preferred_num_stages > 0 ? 
preferred_num_stages : max_num_stages); @@ -621,40 +585,21 @@ static MegaMoESM90Config get_mega_moe_config_sm90( const int& num_ranks, const int& num_experts, const int& num_experts_per_rank, const int& num_max_tokens_per_rank, const int& num_tokens, const int& num_topk, const int& hidden, const int& intermediate_hidden, - const int& num_padded_sf_pool_tokens, - const int& split_phase_mode = 0) { - DG_HOST_ASSERT(split_phase_mode >= 0 and split_phase_mode <= 2); - const bool split_linear1_only = split_phase_mode == 1; - const bool split_linear2_only = split_phase_mode == 2; + const int& num_padded_sf_pool_tokens) { const auto [block_m, num_epilogue_threads] = get_block_config_for_mega_moe_sm90( - num_ranks, num_experts, num_max_tokens_per_rank, num_topk, num_tokens, split_phase_mode); + num_ranks, num_experts, num_max_tokens_per_rank, num_topk, num_tokens); const bool use_bn256_split_n = - (get_env("DG_SM90_MOE_BN256_2WG", get_sm90_moe_split_l1_l2_default() ? 1 : 0) != 0 or - (split_linear1_only and get_env("DG_SM90_MOE_L1_BN256_2WG") != 0) or - (split_linear2_only and get_env("DG_SM90_MOE_L2_BN256_2WG") != 0)) and + get_env("DG_SM90_MOE_BN256_2WG", get_sm90_moe_split_l1_l2_default() ? 
1 : 0) != 0 and block_m == 64; const bool use_bn256_seq_n = - (get_env("DG_SM90_MOE_BN256_SEQ") != 0 or - (split_linear1_only and get_env("DG_SM90_MOE_L1_BN256_SEQ") != 0) or - (split_linear2_only and get_env("DG_SM90_MOE_L2_BN256_SEQ") != 0)) and - block_m == 64; + get_env("DG_SM90_MOE_BN256_SEQ") != 0 and block_m == 64; DG_HOST_ASSERT(not (use_bn256_split_n and use_bn256_seq_n)); - int forced_block_n = 0; - if (split_linear1_only) - forced_block_n = get_env("DG_SM90_MOE_L1_FORCE_BLOCK_N"); - else if (split_linear2_only) - forced_block_n = get_env("DG_SM90_MOE_L2_FORCE_BLOCK_N"); - if (forced_block_n == 0) - forced_block_n = get_env("DG_SM90_MOE_FORCE_BLOCK_N"); - DG_HOST_ASSERT(forced_block_n == 0 or forced_block_n == 128 or forced_block_n == 256); - const int block_n = forced_block_n > 0 ? forced_block_n : ((use_bn256_split_n or use_bn256_seq_n) ? 256 : 128); - DG_HOST_ASSERT((not use_bn256_split_n) or block_n == 256); + const int block_n = (use_bn256_split_n or use_bn256_seq_n) ? 256 : 128; DG_HOST_ASSERT((not use_bn256_split_n) or num_epilogue_threads == 256); const int block_k = 128; // Default remains cluster_size=1; the opt-in cluster path uses B multicast // across adjacent M blocks. - const bool use_b_stationary_2wg = get_env("DG_SM90_MOE_B_STATIONARY_2WG") != 0 or - (split_linear1_only and get_env("DG_SM90_MOE_L1_B_STATIONARY_2WG") != 0); + const bool use_b_stationary_2wg = get_env("DG_SM90_MOE_B_STATIONARY_2WG") != 0; const bool use_cluster_bcast_b = get_env("DG_SM90_MOE_CLUSTER_BCAST_B") != 0 or use_b_stationary_2wg; DG_HOST_ASSERT((not use_cluster_bcast_b) or ((block_m == 64 and block_n == 128 and num_epilogue_threads == 128) or @@ -666,19 +611,9 @@ static MegaMoESM90Config get_mega_moe_config_sm90( const int swizzle_weights_mode = (block_m == 16 or block_m == 32) ? 
0 : 128; const int num_sms = device_runtime->get_num_sms(); - int num_experts_per_wave = get_num_experts_per_wave_for_mega_moe_sm90( + const int num_experts_per_wave = get_num_experts_per_wave_for_mega_moe_sm90( num_experts_per_rank, num_tokens, num_topk, intermediate_hidden, block_m, block_n, num_sms); - int forced_phase_experts_per_wave = 0; - if (split_linear1_only) - forced_phase_experts_per_wave = get_env("DG_SM90_MOE_L1_EXPERTS_PER_WAVE"); - else if (split_linear2_only) - forced_phase_experts_per_wave = get_env("DG_SM90_MOE_L2_EXPERTS_PER_WAVE"); - if (forced_phase_experts_per_wave > 0) { - DG_HOST_ASSERT(forced_phase_experts_per_wave <= num_experts_per_rank); - DG_HOST_ASSERT(num_experts_per_rank % forced_phase_experts_per_wave == 0); - num_experts_per_wave = forced_phase_experts_per_wave; - } const bool split_sfa_tma = get_env("DG_SM90_MOE_SPLIT_SFA_TMA", 0) != 0; const bool prefer_compact_frontend = get_sm90_moe_split_l1_l2_default() and block_n == 256 and not split_sfa_tma; @@ -712,8 +647,7 @@ static MegaMoESM90Config get_mega_moe_config_sm90( block_m, block_n, block_k, num_dispatch_threads / 32, num_epilogue_threads / 32, direct_l2_scatter_default, - default_num_stages, - split_phase_mode); + default_num_stages); const auto config = MegaMoESM90Config { block_m, block_n, block_k, @@ -727,8 +661,8 @@ static MegaMoESM90Config get_mega_moe_config_sm90( if (get_env("DG_JIT_DEBUG") or get_env("DG_PRINT_CONFIGS")) { const auto key = fmt::format( - "MegaMoESM90Config(num_ranks={}, num_experts={}, hidden={}, intermediate_hidden={}, num_max_tokens_per_rank={}, num_tokens={}, num_topk={}, split_phase_mode={})", - num_ranks, num_experts, hidden, intermediate_hidden, num_max_tokens_per_rank, num_tokens, num_topk, split_phase_mode); + "MegaMoESM90Config(num_ranks={}, num_experts={}, hidden={}, intermediate_hidden={}, num_max_tokens_per_rank={}, num_tokens={}, num_topk={})", + num_ranks, num_experts, hidden, intermediate_hidden, num_max_tokens_per_rank, num_tokens, 
num_topk); static std::unordered_set printed; if (printed.count(key) == 0) { std::cout << key << ": " << config << std::endl; diff --git a/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp b/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp index 60d13902f4..6411012495 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp @@ -48,11 +48,7 @@ class SM90FP8MegaMoERuntime final : public LaunchRuntime bool l2_nmajor_schedule; bool l1_nmajor_schedule; bool one_warp_cleanup; - int l2_m_swizzle_group; - int l1_m_swizzle_group; int split_phase_mode; - int expert_range_start; - int expert_range_end; MegaMoESM90Config config; // Runtime arguments @@ -108,10 +104,6 @@ static void __instantiate_kernel() {{ {}, {}, {}, - {}, - {}, - {}, - {}, {} >); }}; @@ -138,11 +130,7 @@ static void __instantiate_kernel() {{ args.l2_nmajor_schedule ? "true" : "false", args.l1_nmajor_schedule ? "true" : "false", args.one_warp_cleanup ? "true" : "false", - args.l2_m_swizzle_group, - args.l1_m_swizzle_group, - args.split_phase_mode, - args.expert_range_start, - args.expert_range_end); + args.split_phase_mode); } static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { @@ -278,11 +266,7 @@ static void sm90_fp8_mega_moe( .one_warp_cleanup = get_env( "DG_SM90_MOE_ONE_WARP_CLEANUP", one_warp_cleanup_default ? 
1 : 0) != 0, - .l2_m_swizzle_group = get_env("DG_SM90_MOE_L2_M_SWIZZLE_GROUP", 0), - .l1_m_swizzle_group = get_env("DG_SM90_MOE_L1_M_SWIZZLE_GROUP", 0), .split_phase_mode = 0, - .expert_range_start = get_env("DG_SM90_MOE_EXPERT_RANGE_START", 0), - .expert_range_end = get_env("DG_SM90_MOE_EXPERT_RANGE_END", 0), .config = config, .y = y.data_ptr(), .cumulative_local_expert_recv_stats = cumulative_local_expert_recv_stats_ptr, @@ -303,53 +287,6 @@ static void sm90_fp8_mega_moe( const auto launch_with_split_mode = [&](const int split_phase_mode, const char* kernel_name) { auto split_args = args; split_args.split_phase_mode = split_phase_mode; - if (split_phase_mode != 0 and get_env("DG_SM90_MOE_SPLIT_PHASE_CONFIG", 1) != 0) { - split_args.config = get_mega_moe_config_sm90( - num_ranks, num_experts, num_experts_per_rank, - num_max_tokens_per_rank, num_tokens, num_topk, - hidden, intermediate_hidden, num_padded_sf_pool_tokens, - split_phase_mode); - split_args.launch_args = LaunchArgs( - num_sms, - split_args.config.num_dispatch_threads + split_args.config.num_non_epilogue_threads + split_args.config.num_epilogue_threads, - split_args.config.smem_size, split_args.config.cluster_size); - - const auto& phase_config = split_args.config; - split_args.tensor_map_l1_acts = make_tma_2d_desc( - l1_acts, hidden, phase_config.num_max_pool_tokens, - phase_config.block_k, phase_config.block_m, - static_cast(l1_acts.stride(-2)), phase_config.swizzle_acts_mode); - split_args.tensor_map_l1_acts_sf = make_tma_sf_desc( - cute::UMMA::Major::MN, l1_acts_sf, - phase_config.num_padded_sf_pool_tokens, hidden, - phase_config.block_m, kGranK, 1, 0); - split_args.tensor_map_l1_weights = make_tma_2d_desc( - l1_weights, hidden, num_experts_per_rank * intermediate_hidden * 2, - phase_config.block_k, phase_config.block_n, - static_cast(l1_weights.stride(-2)), phase_config.swizzle_weights_mode); - - const int phase_num_epilogue_warpgroups = phase_config.num_epilogue_threads / 128; - const bool 
phase_split_n_warpgroups = - phase_config.block_m == 64 and phase_config.block_n == 256 and phase_num_epilogue_warpgroups == 2; - const int phase_wg_block_m = phase_split_n_warpgroups - ? phase_config.block_m : phase_config.block_m / phase_num_epilogue_warpgroups; - split_args.tensor_map_l1_output = make_tma_2d_desc( - l2_acts, intermediate_hidden, phase_config.num_max_pool_tokens, - phase_config.block_n / 2, phase_wg_block_m, - static_cast(l2_acts.stride(-2)), 0); - split_args.tensor_map_l2_acts = make_tma_2d_desc( - l2_acts, intermediate_hidden, phase_config.num_max_pool_tokens, - phase_config.block_k, phase_config.block_m, - static_cast(l2_acts.stride(-2)), phase_config.swizzle_acts_mode); - split_args.tensor_map_l2_acts_sf = make_tma_sf_desc( - cute::UMMA::Major::MN, l2_acts_sf, - phase_config.num_padded_sf_pool_tokens, intermediate_hidden, - phase_config.block_m, kL2ActsSFGranK, 1, 0); - split_args.tensor_map_l2_weights = make_tma_2d_desc( - l2_weights, intermediate_hidden, num_experts_per_rank * hidden, - phase_config.block_k, phase_config.block_n, - static_cast(l2_weights.stride(-2)), phase_config.swizzle_weights_mode); - } const auto code = SM90FP8MegaMoERuntime::generate(split_args); const auto runtime = compiler->build(kernel_name, code); SM90FP8MegaMoERuntime::launch(runtime, split_args); @@ -357,12 +294,8 @@ static void sm90_fp8_mega_moe( const bool split_l1_l2 = get_env("DG_SM90_MOE_SPLIT_L1_L2", 1) != 0; if (split_l1_l2) { - const int split_only_phase = get_env("DG_SM90_MOE_SPLIT_ONLY_PHASE", 0); - DG_HOST_ASSERT(split_only_phase >= 0 and split_only_phase <= 2); - if (split_only_phase == 0 or split_only_phase == 1) - launch_with_split_mode(1, "sm90_fp8_mega_moe_split_l1"); - if (split_only_phase == 0 or split_only_phase == 2) - launch_with_split_mode(2, "sm90_fp8_mega_moe_split_l2"); + launch_with_split_mode(1, "sm90_fp8_mega_moe_split_l1"); + launch_with_split_mode(2, "sm90_fp8_mega_moe_split_l2"); } else { launch_with_split_mode(0, 
"sm90_fp8_mega_moe"); } diff --git a/csrc/jit_kernels/impls/sm90_w4a8_mega_moe.hpp b/csrc/jit_kernels/impls/sm90_w4a8_mega_moe.hpp deleted file mode 100644 index 52e8b380aa..0000000000 --- a/csrc/jit_kernels/impls/sm90_w4a8_mega_moe.hpp +++ /dev/null @@ -1,271 +0,0 @@ -#pragma once - -#include -#include "../../jit/compiler.hpp" -#include "../../jit/kernel_runtime.hpp" -#include "../../utils/exception.hpp" -#include "../../utils/format.hpp" -#include "runtime_utils.hpp" - -#include -#include - -#include "../heuristics/mega_moe.hpp" - -namespace deep_gemm { - -// ============================================================================ -// SM90 (Hopper) FP8 MegaMoE host runtime -// ---------------------------------------------------------------------------- -// This is the SM90 counterpart of `SM100FP8FP4MegaMoERuntime`. The kernel -// itself lives in `deep_gemm/impls/sm90_w4a8_mega_moe.cuh` and is currently a -// skeleton: dispatch/combine paths are intended to be portable from the SM100 -// version, while the GEMM (TMA load + WGMMA + epilogue) is being implemented -// in a follow-up step. -// -// Differences from SM100 path: -// * Activations and weights are both FP8 (e4m3); no FP4. -// * Activation/weight scale factors (SF) are per-128-channel float (not UE8M0 -// int + per-32 UTCCP layout). -// * No tensor memory: WGMMA accumulators are register-resident. -// * Cluster size is at most 2 (TMA multicast on A); no 2-CTA UMMA. 
-// ============================================================================ - -class SM90W4A8MegaMoERuntime final : public LaunchRuntime { -public: - struct Args { - // Templated arguments - int num_max_tokens_per_rank; - int hidden, intermediate_hidden; - int num_experts, num_topk; - int num_ranks; - float activation_clamp; - bool fast_math; - bool async_l1_tma_store; - bool split_sfa_tma; - bool direct_l2_scatter; - bool l2_dual_accum; - bool phase_profile; - bool l1_dual_k_accum; - bool l2_nmajor_schedule; - bool l1_nmajor_schedule; - MegaMoESM90Config config; - - // Runtime arguments - void* y; - int* cumulative_local_expert_recv_stats; - int num_tokens; - layout::SymBuffer<> sym_buffer_ptrs; - - // Tensormaps for activations and weights. Weight scale factors use - // block (128, 128) quantization and are loaded by the math warpgroup - // directly from global memory (no TMA descriptor required). - CUtensorMap tensor_map_l1_acts; - CUtensorMap tensor_map_l1_acts_sf; - CUtensorMap tensor_map_l1_weights; - const float* l1_weights_sf; - CUtensorMap tensor_map_l1_output; - CUtensorMap tensor_map_l2_acts; - CUtensorMap tensor_map_l2_acts_sf; - CUtensorMap tensor_map_l2_weights; - const float* l2_weights_sf; - - // Launch configs - LaunchArgs launch_args; - }; - - static std::string generate_impl(const Args& args) { - return fmt::format(R"( -#include - -using namespace deep_gemm; - -static void __instantiate_kernel() {{ - auto ptr = reinterpret_cast(&sm90_w4a8_mega_moe_impl< - {}, - {}, {}, - {}, {}, - {}, - {}, {}, {}, - {}, - {}, - {}, - {}, - {}, {}, {}, - {}, - {}, {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {} - >); -}}; -)", - args.num_max_tokens_per_rank, - args.hidden, args.intermediate_hidden, - args.num_experts, args.num_topk, - args.config.num_experts_per_wave, - args.config.block_m, args.config.block_n, args.config.block_k, - args.config.num_max_pool_tokens, - args.config.num_padded_sf_pool_tokens, - args.config.num_stages, - 
args.config.num_dispatch_threads, args.config.num_non_epilogue_threads, args.config.num_epilogue_threads, - args.config.cluster_size, - args.launch_args.grid_dim.first, args.num_ranks, - to_string(args.activation_clamp), - args.fast_math ? "true" : "false", - args.async_l1_tma_store ? "true" : "false", - args.split_sfa_tma ? "true" : "false", - args.direct_l2_scatter ? "true" : "false", - args.l2_dual_accum ? "true" : "false", - args.phase_profile ? "true" : "false", - args.l1_dual_k_accum ? "true" : "false", - args.l2_nmajor_schedule ? "true" : "false", - args.l1_nmajor_schedule ? "true" : "false"); - } - - static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { - DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, - args.y, - args.cumulative_local_expert_recv_stats, - args.num_tokens, - args.sym_buffer_ptrs, - args.tensor_map_l1_acts, - args.tensor_map_l1_acts_sf, - args.tensor_map_l1_weights, - args.l1_weights_sf, - args.tensor_map_l1_output, - args.tensor_map_l2_acts, - args.tensor_map_l2_acts_sf, - args.tensor_map_l2_weights, - args.l2_weights_sf - )); - } -}; - -static void sm90_w4a8_mega_moe( - const torch::Tensor& y, - const torch::Tensor& l1_acts, const torch::Tensor& l1_acts_sf, - const torch::Tensor& l2_acts, const torch::Tensor& l2_acts_sf, - const torch::Tensor& l1_weights, const torch::Tensor& l2_weights, - const torch::Tensor& l1_weights_sf, const torch::Tensor& l2_weights_sf, - const std::optional cumulative_local_expert_recv_stats, - const std::vector& sym_buffer_ptrs, - const int& rank_idx, const int& num_max_tokens_per_rank, - const int& num_experts_per_rank, - const int& num_tokens, const int& num_topk, - const int& hidden, const int& intermediate_hidden, - const float& activation_clamp, - const bool& fast_math -) { - const auto num_ranks = static_cast(sym_buffer_ptrs.size()); - const auto num_experts = num_experts_per_rank * num_ranks; - const auto num_padded_sf_pool_tokens = 
static_cast(l1_acts_sf.size(0)); - - // Heuristics - const auto config = get_mega_moe_config_sm90( - num_ranks, num_experts, num_experts_per_rank, - num_max_tokens_per_rank, num_tokens, num_topk, - hidden, intermediate_hidden, num_padded_sf_pool_tokens); - - // Tensormap construction - // Acts/weights: standard 2D TMA descriptors (FP8 K-major). - // Activation SF: per-128 channel float for L1, per-64 for L2 (MN-major, no swizzle). - // Weight SF: block (128, 128) raw float pointer (no TMA descriptor). - constexpr int kGranK = 128; - constexpr int kL2ActsSFGranK = 64; - const auto tensor_map_l1_acts = make_tma_2d_desc(l1_acts, - hidden, config.num_max_pool_tokens, - config.block_k, config.block_m, - static_cast(l1_acts.stride(-2)), - config.swizzle_acts_mode); - const auto tensor_map_l1_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l1_acts_sf, - config.num_padded_sf_pool_tokens, hidden, - config.block_m, kGranK, - 1, 0); - const auto tensor_map_l1_weights = make_tma_2d_desc(l1_weights, - hidden, num_experts_per_rank * intermediate_hidden * 2, - config.block_k, config.block_n, - static_cast(l1_weights.stride(-2)), - config.swizzle_weights_mode); - // L1 output (post-SwiGLU FP8): N is halved. The SM90 epilogue writes this - // staging tile to SMEM as plain row-major bytes, so the TMA store descriptor - // must use no shared-memory swizzle. Later L2 TMA loads may still swizzle - // from this row-major global buffer into their own SMEM tile. - // The default TMA store is issued per warpgroup, each writing a WG_BLOCK_M - // row tile. The split-N experiment has two WGs produce different N halves - // of the same M rows, then one TMA store writes the full 64x128 post-SwiGLU tile. - const int num_epilogue_warpgroups_h = config.num_epilogue_threads / 128; - const bool split_n_warpgroups_h = - config.block_m == 64 and config.block_n == 256 and num_epilogue_warpgroups_h == 2; - const int wg_block_m = split_n_warpgroups_h ? 
config.block_m : config.block_m / num_epilogue_warpgroups_h; - const auto tensor_map_l1_output = make_tma_2d_desc(l2_acts, - intermediate_hidden, config.num_max_pool_tokens, - config.block_n / 2, wg_block_m, - static_cast(l2_acts.stride(-2)), - 0); - const auto tensor_map_l2_acts = make_tma_2d_desc(l2_acts, - intermediate_hidden, config.num_max_pool_tokens, - config.block_k, config.block_m, - static_cast(l2_acts.stride(-2)), - config.swizzle_acts_mode); - const auto tensor_map_l2_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l2_acts_sf, - config.num_padded_sf_pool_tokens, intermediate_hidden, - config.block_m, kL2ActsSFGranK, - 1, 0); - const auto tensor_map_l2_weights = make_tma_2d_desc(l2_weights, - intermediate_hidden, num_experts_per_rank * hidden, - config.block_k, config.block_n, - static_cast(l2_weights.stride(-2)), - config.swizzle_weights_mode); - - // Stats can be optional - int* cumulative_local_expert_recv_stats_ptr = nullptr; - if (cumulative_local_expert_recv_stats.has_value()) - cumulative_local_expert_recv_stats_ptr = cumulative_local_expert_recv_stats->data_ptr(); - - // Launch - const auto num_sms = device_runtime->get_num_sms(); - const SM90W4A8MegaMoERuntime::Args args = { - .num_max_tokens_per_rank = num_max_tokens_per_rank, - .hidden = hidden, .intermediate_hidden = intermediate_hidden, - .num_experts = num_experts, .num_topk = num_topk, - .num_ranks = num_ranks, - .activation_clamp = activation_clamp, - .fast_math = fast_math, - .async_l1_tma_store = get_env("DG_SM90_MOE_ASYNC_L1_STORE", 0) != 0, - .split_sfa_tma = get_env("DG_SM90_MOE_SPLIT_SFA_TMA", 0) != 0, - .direct_l2_scatter = get_env("DG_SM90_MOE_DIRECT_L2_SCATTER", 0) != 0, - .l2_dual_accum = get_env("DG_SM90_MOE_L2_DUAL_ACCUM", 0) != 0, - .phase_profile = get_env("DG_SM90_MOE_PHASE_PROFILE", 0) != 0, - .l1_dual_k_accum = get_env("DG_SM90_MOE_L1_DUAL_K", 0) != 0, - .l2_nmajor_schedule = get_env("DG_SM90_MOE_L2_NMAJOR", 0) != 0, - .l1_nmajor_schedule = 
get_env("DG_SM90_MOE_L1_NMAJOR", 0) != 0, - .config = config, - .y = y.data_ptr(), - .cumulative_local_expert_recv_stats = cumulative_local_expert_recv_stats_ptr, - .num_tokens = num_tokens, - .sym_buffer_ptrs = layout::SymBuffer<>(sym_buffer_ptrs, rank_idx), - .tensor_map_l1_acts = tensor_map_l1_acts, - .tensor_map_l1_acts_sf = tensor_map_l1_acts_sf, - .tensor_map_l1_weights = tensor_map_l1_weights, - .l1_weights_sf = l1_weights_sf.data_ptr(), - .tensor_map_l1_output = tensor_map_l1_output, - .tensor_map_l2_acts = tensor_map_l2_acts, - .tensor_map_l2_acts_sf = tensor_map_l2_acts_sf, - .tensor_map_l2_weights = tensor_map_l2_weights, - .l2_weights_sf = l2_weights_sf.data_ptr(), - .launch_args = LaunchArgs(num_sms, config.num_dispatch_threads + config.num_non_epilogue_threads + config.num_epilogue_threads, - config.smem_size, config.cluster_size) - }; - const auto code = SM90W4A8MegaMoERuntime::generate(args); - const auto runtime = compiler->build("sm90_w4a8_mega_moe", code); - SM90W4A8MegaMoERuntime::launch(runtime, args); -} - -} // namespace deep_gemm diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py index d7c2aaa18e..66bc81a2c1 100644 --- a/deep_gemm/__init__.py +++ b/deep_gemm/__init__.py @@ -88,7 +88,6 @@ transform_weights_for_mega_moe_sm90, fp8_fp4_mega_moe, fp8_mega_moe, - w4a8_mega_moe, ) # Some utils diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh index 14c8f7d7d3..c98c326700 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh @@ -74,11 +74,7 @@ template < bool kL2NMajorScheduleRequested = false, bool kL1NMajorScheduleRequested = false, bool kOneWarpCleanupRequested = false, - uint32_t kL2MSwizzleGroupRequested = 0, - uint32_t kL1MSwizzleGroupRequested = 0, uint32_t kSplitPhaseMode = 0, - uint32_t kExpertRangeStart = 0, - uint32_t kExpertRangeEnd = 0, uint32_t L1_SHAPE_N = kIntermediateHidden 
* 2, uint32_t L1_SHAPE_K = kHidden, uint32_t L2_SHAPE_N = kHidden, @@ -124,8 +120,6 @@ sm90_fp8_mega_moe_impl(void* y, DG_STATIC_ASSERT(BLOCK_N == 64 or BLOCK_N == 128 or BLOCK_N == 256, "BLOCK_N must be 64/128/256 for this SM90 path"); DG_STATIC_ASSERT(BLOCK_K == 128, "BLOCK_K is fixed to 128 (per-128 SF)"); DG_STATIC_ASSERT(kSplitPhaseMode <= 2, "Invalid SM90 MegaMoE split phase mode"); - DG_STATIC_ASSERT(kExpertRangeStart <= (kExpertRangeEnd == 0 ? kNumExpertsPerRank : kExpertRangeEnd), "Invalid expert range"); - DG_STATIC_ASSERT((kExpertRangeEnd == 0 ? kNumExpertsPerRank : kExpertRangeEnd) <= kNumExpertsPerRank, "Expert range exceeds local experts"); // ===================================================================== // Thread / warp identification @@ -203,16 +197,15 @@ sm90_fp8_mega_moe_impl(void* y, constexpr uint32_t WG_L1_OUT_BLOCK_N = WG_BLOCK_N / 2; // post-SwiGLU per-WG N constexpr bool kRunOnlyLinear1 = kSplitPhaseMode == 1; constexpr bool kRunOnlyLinear2 = kSplitPhaseMode == 2; - constexpr uint32_t kEffectiveExpertRangeEnd = kExpertRangeEnd == 0 ? 
kNumExpertsPerRank : kExpertRangeEnd; constexpr bool kAsyncL1TMAStore = - kAsyncL1TMAStoreRequested && (!kRunOnlyLinear2) && (!kUseMMASync) && + kAsyncL1TMAStoreRequested && (!kUseMMASync) && (!kSplitNWarpgroups) && kNumEpilogueWarpgroups == 1; constexpr bool kSplitSFATMA = kSplitSFATMARequested && (!kUseMMASync); - constexpr bool kDirectL2Scatter = kDirectL2ScatterRequested && (!kRunOnlyLinear1) && (!kUseMMASync) && + constexpr bool kDirectL2Scatter = kDirectL2ScatterRequested && (!kUseMMASync) && (!kSerialNWarpgroups) && WG_BLOCK_N == 128; - constexpr bool kL2DualAccum = kL2DualAccumRequested && (!kRunOnlyLinear1) && (!kUseMMASync) && + constexpr bool kL2DualAccum = kL2DualAccumRequested && (!kUseMMASync) && (!kSplitNWarpgroups) && (!kSerialNWarpgroups) && WG_BLOCK_N == 128; - constexpr bool kL1DualKAccum = kL1DualKAccumRequested && (!kRunOnlyLinear2) && (!kUseMMASync) && + constexpr bool kL1DualKAccum = kL1DualKAccumRequested && (!kUseMMASync) && (!kSplitNWarpgroups) && (!kSerialNWarpgroups) && WG_BLOCK_N == 128 && (kHidden / BLOCK_K) % 2 == 0; using L1WGMMA = typename mma::sm90::FP8MMASelector::type; @@ -238,9 +231,9 @@ sm90_fp8_mega_moe_impl(void* y, constexpr uint32_t kSharedMemoryAlignment = 1024; extern __shared__ __align__(kSharedMemoryAlignment) uint8_t smem_buffer[]; - constexpr uint32_t SMEM_EXPERT_COUNT_SIZE = kRunOnlyLinear2 ? 0u : + constexpr uint32_t SMEM_EXPERT_COUNT_SIZE = math::constexpr_align(kNumExperts * sizeof(uint32_t), kSharedMemoryAlignment); - constexpr uint32_t SMEM_SEND_BUFFER_SIZE = kRunOnlyLinear2 ? 
0u : + constexpr uint32_t SMEM_SEND_BUFFER_SIZE = math::constexpr_align(fp8_token_layout.get_num_bytes() * kNumDispatchWarps, kSharedMemoryAlignment); constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t); constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t); @@ -249,8 +242,7 @@ sm90_fp8_mega_moe_impl(void* y, // the second L2 half cannot start immediately after 16 floats in M16 decode. constexpr uint32_t kL2SFAHalfStride = math::constexpr_align(BLOCK_M * sizeof(float), 128u) / sizeof(float); - constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = - (kRunOnlyLinear1 ? kL2SFAHalfStride : 2u * kL2SFAHalfStride) * sizeof(float); + constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = 2 * kL2SFAHalfStride * sizeof(float); // Block (128, 128) weight SF: 1 float per (BLOCK_N, BLOCK_K) tile for L2, // 2 floats (gate/up) for L1. Loaded by math warpgroup directly from global, // so no SMEM is needed. @@ -261,9 +253,9 @@ sm90_fp8_mega_moe_impl(void* y, constexpr uint32_t SMEM_CD_ACCUM_SIZE = kUseMMASync ? math::constexpr_align(BLOCK_M * BLOCK_N * sizeof(float), kSharedMemoryAlignment) : 0u; - constexpr uint32_t SMEM_CD_L1_SIZE = kRunOnlyLinear2 ? 0u : + constexpr uint32_t SMEM_CD_L1_SIZE = kNumEpilogueWarpgroups * WG_BLOCK_M * WG_L1_OUT_BLOCK_N * sizeof(cutlass::float_e4m3_t); - constexpr uint32_t SMEM_CD_L2_SIZE = (kRunOnlyLinear1 || kDirectL2Scatter) ? 0u : + constexpr uint32_t SMEM_CD_L2_SIZE = kDirectL2Scatter ? 
0u : kNumEpilogueWarpgroups * WG_BLOCK_M * WG_BLOCK_N * sizeof(nv_bfloat16); constexpr uint32_t SMEM_CD_L1_ASYNC_ELEMS = kNumEpilogueWarpgroups * WG_BLOCK_M * L1_OUT_BLOCK_N; @@ -322,12 +314,10 @@ sm90_fp8_mega_moe_impl(void* y, // Initialization // ===================================================================== if (warp_idx == 0) { - if constexpr (!kRunOnlyLinear2) { - // Clean expert-count shared memory - #pragma unroll - for (uint32_t i = lane_idx; i < kNumExperts; i += 32) - ptx::st_shared(smem_expert_count + i, 0u); - } + // Clean expert-count shared memory + #pragma unroll + for (uint32_t i = lane_idx; i < kNumExperts; i += 32) + ptx::st_shared(smem_expert_count + i, 0u); } else if (warp_idx == 1) { // Init dispatch m-barriers #pragma unroll @@ -368,9 +358,7 @@ sm90_fp8_mega_moe_impl(void* y, L1_SHAPE_N, L1_SHAPE_K, L2_SHAPE_N, L2_SHAPE_K, kNumExpertsPerRank, kNumExpertsPerWave, - kNumSMs, kNumRanks, kClusterSize, kL2NMajorScheduleRequested, kL1NMajorScheduleRequested, - kL2MSwizzleGroupRequested, kL1MSwizzleGroupRequested, - kExpertRangeStart, kEffectiveExpertRangeEnd>(workspace); + kNumSMs, kNumRanks, kClusterSize, kL2NMajorScheduleRequested, kL1NMajorScheduleRequested>(workspace); // Pipeline state shared by TMA loaders and math warpgroups uint32_t stage_idx = 0, phase = 0; @@ -904,10 +892,9 @@ sm90_fp8_mega_moe_impl(void* y, const uint32_t k_idx = k_block_idx * BLOCK_K; // TMA load B (weight SF is now loaded directly by math warps from global) - const uint32_t num_tma_multicast_b = (kClusterSize > 1 and scheduler.is_b_multicast_valid) ? kClusterSize : 1u; tma::copy( tensor_map_b_ptr, full_barriers[stage_idx], smem_b[stage_idx], - k_idx, n_idx, num_tma_multicast_b); + k_idx, n_idx, kClusterSize); full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_B_SIZE_PER_STAGE); } @@ -1006,9 +993,8 @@ sm90_fp8_mega_moe_impl(void* y, if (lane_idx == 0) empty_barriers[s]->arrive(); } else { - const auto target_cta = scheduler.is_peer_cta_alive ? 
lane_idx : cute::block_rank_in_cluster(); if (lane_idx < kClusterSize) - empty_barriers[s]->arrive(target_cta); + empty_barriers[s]->arrive(lane_idx); } }; @@ -2398,12 +2384,9 @@ for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_bl if (epilogue_warp_idx == 0 and lane_idx == 0) phase_profile_record(kProfileCombineBarrier, combine_barrier_end - combine_barrier_start); - // Fused mode overlaps dispatch cleanup with combine reduce. In split K2, - // delay the cleanup sync until the combine loads are issued so K2's - // epilogue-side memory traffic does not fight the reducer. - if constexpr (!kRunOnlyLinear2) { - ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); - } + // Sync with dispatch (paired with dispatch's pre-cleanup sync) so that + // dispatch may now safely clean workspace state. + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); const unsigned long long combine_reduce_start = phase_profile_clock(); constexpr uint32_t kNumHiddenBytes = kHidden * sizeof(nv_bfloat16); @@ -2513,9 +2496,6 @@ for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_bl const unsigned long long combine_reduce_end = phase_profile_clock(); if (epilogue_warp_idx == 0 and lane_idx == 0) phase_profile_record(kProfileCombineReduce, combine_reduce_end - combine_reduce_start); - if constexpr (kRunOnlyLinear2) { - ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); - } } #else if (blockIdx.x == 0 and threadIdx.x == 0) diff --git a/deep_gemm/include/deep_gemm/impls/sm90_w4a8_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm90_w4a8_mega_moe.cuh deleted file mode 100644 index 42c5980b0f..0000000000 --- a/deep_gemm/include/deep_gemm/impls/sm90_w4a8_mega_moe.cuh +++ /dev/null @@ -1,2318 +0,0 @@ -#pragma once - -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wunknown-attributes" - -#include 
-#include -#include - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#define __CLION_IDE__ - -namespace deep_gemm { - -// ============================================================================ -// SM90 (Hopper) FP8 MegaMoE — full implementation -// ---------------------------------------------------------------------------- -// Pipeline (cluster=1, no TMA multicast): -// * Dispatch warps: pull tokens (FP8) and SF (per-128 channel float) from -// remote ranks via NVLink into the local L1 pool. -// * GEMM TMA-load warps (1 for A+SFA, 1 for B+SFB) feed the pipeline stages. -// * Math warpgroups (1 or 2, totalling kNumEpilogueThreads) consume each -// stage with WGMMA, accumulate into registers, then run the epilogue: -// - L1 (Linear1): SwiGLU with gate/up granularity-8 interleaved layout, -// per-row amax over the 64 post-SwiGLU columns of this block, FP8 e4m3 -// quantize, STSM into SMEM, TMA store to local L1 output buffer. -// The per-row SF is written as a *float* into the L2-acts SF buffer at -// per-64 K granularity (one SF per L1 N block), so each block is fully -// self-contained and no cross-CTA amax synchronisation is needed. -// - L2 (Linear2): BF16 cast of the GEMM output, STSM into SMEM, then -// NVLink scatter to remote combine buffers. -// * After all GEMM blocks, the math warps run the COMBINE step (top-k -// reduction in BF16) — ported verbatim from the SM100 kernel. 
-// ============================================================================ - -template < - uint32_t kNumMaxTokensPerRank, - uint32_t kHidden, uint32_t kIntermediateHidden, - uint32_t kNumExperts, uint32_t kNumTopk, - uint32_t kNumExpertsPerWave, - uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K, - uint32_t kNumMaxPoolTokens, - uint32_t kNumPaddedSFPoolTokens, - uint32_t kNumStages, - uint32_t kNumDispatchThreads, uint32_t kNumNonEpilogueThreads, - uint32_t kNumEpilogueThreads, - uint32_t kClusterSize, - uint32_t kNumSMs, uint32_t kNumRanks, - float kActivationClamp, - bool kFastMath, - bool kAsyncL1TMAStoreRequested = true, - bool kSplitSFATMARequested = false, - bool kDirectL2ScatterRequested = false, - bool kL2DualAccumRequested = false, - bool kPhaseProfileRequested = false, - bool kL1DualKAccumRequested = false, - bool kL2NMajorScheduleRequested = false, - bool kL1NMajorScheduleRequested = false, - uint32_t L1_SHAPE_N = kIntermediateHidden * 2, - uint32_t L1_SHAPE_K = kHidden, - uint32_t L2_SHAPE_N = kHidden, - uint32_t L2_SHAPE_K = kIntermediateHidden, - uint32_t kNumDispatchWarps = kNumDispatchThreads / 32, - uint32_t kNumMMANonEpilogueWarps = kNumNonEpilogueThreads / 32, - uint32_t kNumEpilogueWarps = kNumEpilogueThreads / 32, - uint32_t kNumEpilogueWarpgroups = kNumEpilogueWarps / 4, - uint32_t kNumThreads = kNumDispatchThreads + kNumNonEpilogueThreads + kNumEpilogueThreads, - uint32_t kNumTokensPerWarp = 32 / kNumTopk, - uint32_t kNumExpertsPerRank = kNumExperts / kNumRanks -> -CUTLASS_GLOBAL __launch_bounds__(kNumThreads, 1) void -sm90_w4a8_mega_moe_impl(void* y, - int* cumulative_local_expert_recv_stats, - const uint32_t num_tokens, - const __grid_constant__ layout::SymBuffer sym_buffer, - const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts, - const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts_sf, - const __grid_constant__ cute::TmaDescriptor tensor_map_l1_weights, - const float* __restrict__ l1_weights_sf, - const 
__grid_constant__ cute::TmaDescriptor tensor_map_l1_output, - const __grid_constant__ cute::TmaDescriptor tensor_map_l2_acts, - const __grid_constant__ cute::TmaDescriptor tensor_map_l2_acts_sf, - const __grid_constant__ cute::TmaDescriptor tensor_map_l2_weights, - const float* __restrict__ l2_weights_sf) { -#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900) and (__CUDA_ARCH__ < 1000)) or defined(__CLION_IDE__) - using Barrier = cutlass::arch::ClusterTransactionBarrier; - - // ===================================================================== - // Template checks - // ===================================================================== - DG_STATIC_ASSERT(kNumDispatchThreads == 64 or kNumDispatchThreads % 128 == 0, - "Invalid number of dispatch threads"); - DG_STATIC_ASSERT(kNumNonEpilogueThreads == 64 or kNumNonEpilogueThreads == 128, - "Invalid number of GEMM TMA warps (2 or 4 warps expected)"); - DG_STATIC_ASSERT(kNumEpilogueThreads % 128 == 0, "Invalid number of math/epilogue threads"); - DG_STATIC_ASSERT(kNumExperts % kNumRanks == 0, "Invalid number of experts or ranks"); - DG_STATIC_ASSERT(kClusterSize == 1 or kClusterSize == 2, "Invalid cluster size"); - DG_STATIC_ASSERT(kNumSMs % kClusterSize == 0, "SM count must be divisible by cluster size"); - DG_STATIC_ASSERT(BLOCK_M == 16 or BLOCK_M == 32 or BLOCK_M % 64 == 0, - "BLOCK_M must be 16/32 for mma.sync decode or a multiple of WGMMA::M (64)"); - DG_STATIC_ASSERT(BLOCK_N == 64 or BLOCK_N == 128 or BLOCK_N == 256, "BLOCK_N must be 64/128/256 for this SM90 path"); - DG_STATIC_ASSERT(BLOCK_K == 128, "BLOCK_K is fixed to 128 (per-128 SF)"); - - // ===================================================================== - // Thread / warp identification - // ===================================================================== - const uint32_t sm_idx = blockIdx.x; - const uint32_t thread_idx = threadIdx.x; - const uint32_t warp_idx = cutlass::canonical_warp_idx_sync(); - const uint32_t lane_idx = 
ptx::get_lane_idx(); - - // Prefetch all TMA descriptors at the very beginning - if (warp_idx == 0 and cute::elect_one_sync()) { - cute::prefetch_tma_descriptor(&tensor_map_l1_acts); - cute::prefetch_tma_descriptor(&tensor_map_l1_acts_sf); - cute::prefetch_tma_descriptor(&tensor_map_l1_weights); - cute::prefetch_tma_descriptor(&tensor_map_l1_output); - cute::prefetch_tma_descriptor(&tensor_map_l2_acts); - cute::prefetch_tma_descriptor(&tensor_map_l2_acts_sf); - cute::prefetch_tma_descriptor(&tensor_map_l2_weights); - } - - // ===================================================================== - // Workspaces and symmetric buffer slicing (mirror SM100 layout, except SF - // for L2 activations uses per-64 K granularity) - // ===================================================================== - const auto workspace = layout::Workspace( - sym_buffer.get_base_ptr(), kNumRanks, kNumExperts, kNumMaxTokensPerRank, kNumTopk); - - constexpr auto fp8_token_layout = layout::Data(kHidden); - constexpr auto bf16_token_layout = layout::Data(kHidden * sizeof(nv_bfloat16)); - constexpr auto fp8_intermediate_token_layout = layout::Data(kIntermediateHidden); - // Per-128 K float SF: 4 bytes per per-128 group => `kHidden / 32` bytes/token (same as SM100 packing) - constexpr auto fp8_sf_layout = layout::Data(kHidden / 32); - // Per-64 K float SF (SM90 only): 4 bytes per per-64 group => `kIntermediateHidden / 16` bytes/token - constexpr auto fp8_intermediate_sf_layout = layout::Data(kIntermediateHidden / 16); - constexpr auto input_topk_idx_layout = layout::Data(kNumTopk * sizeof(int64_t), false); - constexpr auto input_topk_weights_layout = layout::Data(kNumTopk * sizeof(float), false); - constexpr auto l1_topk_weights_layout = layout::Data(sizeof(float), false); - - // Registered input area - const auto input_token_buffer = layout::Buffer(fp8_token_layout, 1, kNumMaxTokensPerRank, workspace.get_end_ptr()); - const auto input_sf_buffer = layout::Buffer(fp8_sf_layout, 1, 
kNumMaxTokensPerRank, input_token_buffer.get_end_ptr()); - const auto input_topk_idx_buffer = layout::Buffer(input_topk_idx_layout, 1, kNumMaxTokensPerRank, input_sf_buffer.get_end_ptr()); - const auto input_topk_weights_buffer = layout::Buffer(input_topk_weights_layout, 1, kNumMaxTokensPerRank, input_topk_idx_buffer.get_end_ptr()); - - // L1 input area - const auto l1_token_buffer = layout::Buffer(fp8_token_layout, 1, kNumMaxPoolTokens, input_topk_weights_buffer.get_end_ptr()); - const auto l1_sf_buffer = layout::Buffer(fp8_sf_layout, 1, kNumPaddedSFPoolTokens, l1_token_buffer.get_end_ptr()); - const auto l1_topk_weights_buffer = layout::Buffer(l1_topk_weights_layout, 1, kNumMaxPoolTokens, l1_sf_buffer.get_end_ptr()); - - // L2 input area - const auto l2_token_buffer = layout::Buffer(fp8_intermediate_token_layout, 1, kNumMaxPoolTokens, l1_topk_weights_buffer.get_end_ptr()); - const auto l2_sf_buffer = layout::Buffer(fp8_intermediate_sf_layout, 1, kNumPaddedSFPoolTokens, l2_token_buffer.get_end_ptr()); - - // Combine input area - const auto combine_token_buffer = layout::Buffer(bf16_token_layout, kNumTopk, kNumMaxTokensPerRank, l2_sf_buffer.get_end_ptr()); - - // ===================================================================== - // GEMM data types and shape constants - // ===================================================================== - using a_dtype_t = cutlass::float_e4m3_t; - using b_dtype_t = cutlass::float_e4m3_t; - constexpr bool kUseMMASync = (BLOCK_M == 16 or BLOCK_M == 32); - constexpr bool kSplitNWarpgroups = - (!kUseMMASync) && BLOCK_N == 256 && kNumEpilogueWarpgroups == 2; - constexpr bool kSerialNWarpgroups = false; - constexpr bool kWideNWarpgroups = - (!kUseMMASync) && BLOCK_N == 256 && kNumEpilogueWarpgroups == 1; - constexpr uint32_t WG_BLOCK_M = kSplitNWarpgroups ? BLOCK_M : BLOCK_M / kNumEpilogueWarpgroups; - constexpr uint32_t WG_BLOCK_N = (kSplitNWarpgroups || kSerialNWarpgroups) ? 
BLOCK_N / 2 : BLOCK_N; - constexpr uint32_t L1_OUT_BLOCK_N = BLOCK_N / 2; // post-SwiGLU tile N - constexpr uint32_t WG_L1_OUT_BLOCK_N = WG_BLOCK_N / 2; // post-SwiGLU per-WG N - constexpr bool kAsyncL1TMAStore = - kAsyncL1TMAStoreRequested && (!kUseMMASync) && (!kSplitNWarpgroups) && - kNumEpilogueWarpgroups == 1; - constexpr bool kSplitSFATMA = kSplitSFATMARequested && (!kUseMMASync); - constexpr bool kDirectL2Scatter = kDirectL2ScatterRequested && (!kUseMMASync) && - (!kSplitNWarpgroups) && (!kSerialNWarpgroups); - constexpr bool kL2DualAccum = kL2DualAccumRequested && (!kUseMMASync) && - (!kSplitNWarpgroups) && (!kSerialNWarpgroups) && WG_BLOCK_N == 128; - constexpr bool kL1DualKAccum = kL1DualKAccumRequested && (!kUseMMASync) && - (!kSplitNWarpgroups) && (!kSerialNWarpgroups) && WG_BLOCK_N == 128 && - (kHidden / BLOCK_K) % 2 == 0; - using L1WGMMA = typename mma::sm90::FP8MMASelector::type; - using L2WGMMA = typename mma::sm90::FP8MMASelector::type; - static_assert(L1WGMMA::M == 64 and L1WGMMA::N == WG_BLOCK_N and L1WGMMA::K == 32, - "Unexpected WGMMA shape"); - DG_STATIC_ASSERT((!kSplitNWarpgroups) or (BLOCK_M == 64 and WG_BLOCK_N == 128), - "Split-N path expects two M64N128 WGMMA consumers"); - - // A is always CTA-local. When kClusterSize=2 the scheduler pairs adjacent - // M blocks with identical expert/N/K coordinates so the B TMA can multicast. - constexpr uint32_t LOAD_BLOCK_M = BLOCK_M; - constexpr uint32_t LOAD_BLOCK_N = BLOCK_N; - constexpr uint32_t kSwizzleAMode = kUseMMASync ? 0 : BLOCK_K * sizeof(a_dtype_t); // 0 or 128 - constexpr uint32_t kSwizzleBMode = kUseMMASync ? 
0 : BLOCK_K * sizeof(b_dtype_t); // 0 or 128 - constexpr uint32_t kSwizzleCDMode = 128; - constexpr uint32_t kGranK = 128; // L1 acts SF, weights SF - constexpr uint32_t kL2ActsSFGranK = 64; // L2 acts SF (per-64 K, SM90 only) - - // ===================================================================== - // Shared memory layout - // ===================================================================== - constexpr uint32_t kSharedMemoryAlignment = 1024; - extern __shared__ __align__(kSharedMemoryAlignment) uint8_t smem_buffer[]; - - constexpr uint32_t SMEM_EXPERT_COUNT_SIZE = - math::constexpr_align(kNumExperts * sizeof(uint32_t), kSharedMemoryAlignment); - constexpr uint32_t SMEM_SEND_BUFFER_SIZE = - math::constexpr_align(fp8_token_layout.get_num_bytes() * kNumDispatchWarps, kSharedMemoryAlignment); - constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t); - constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t); - // SFA per-stage must be sized for the larger of L1 (BLOCK_M floats) and L2 - // (two per-64-K halves). Each TMA destination must be 128B aligned, so - // the second L2 half cannot start immediately after 16 floats in M16 decode. - constexpr uint32_t kL2SFAHalfStride = - math::constexpr_align(BLOCK_M * sizeof(float), 128u) / sizeof(float); - constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = 2 * kL2SFAHalfStride * sizeof(float); - // Block (128, 128) weight SF: 1 float per (BLOCK_N, BLOCK_K) tile for L2, - // 2 floats (gate/up) for L1. Loaded by math warpgroup directly from global, - // so no SMEM is needed. - constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = 0; - - // CD output: max of L1 FP8 (BLOCK_M * (BLOCK_N/2) * 1 byte * num_wg) and - // L2 BF16 (BLOCK_M * BLOCK_N * 2 bytes * num_wg). - constexpr uint32_t SMEM_CD_ACCUM_SIZE = kUseMMASync - ? 
math::constexpr_align(BLOCK_M * BLOCK_N * sizeof(float), kSharedMemoryAlignment) - : 0u; - constexpr uint32_t SMEM_CD_L1_SIZE = - kNumEpilogueWarpgroups * WG_BLOCK_M * WG_L1_OUT_BLOCK_N * sizeof(cutlass::float_e4m3_t); - constexpr uint32_t SMEM_CD_L2_SIZE = kDirectL2Scatter ? 0u : - kNumEpilogueWarpgroups * WG_BLOCK_M * WG_BLOCK_N * sizeof(nv_bfloat16); - constexpr uint32_t SMEM_CD_L1_ASYNC_SIZE = kAsyncL1TMAStore ? - 2 * kNumEpilogueWarpgroups * WG_BLOCK_M * L1_OUT_BLOCK_N * sizeof(cutlass::float_e4m3_t) : 0u; - constexpr uint32_t SMEM_CD_OUTPUT_BASE_SIZE = - SMEM_CD_L1_SIZE > SMEM_CD_L2_SIZE ? SMEM_CD_L1_SIZE : SMEM_CD_L2_SIZE; - constexpr uint32_t SMEM_CD_OUTPUT_UNALIGNED_SIZE = - SMEM_CD_OUTPUT_BASE_SIZE > SMEM_CD_L1_ASYNC_SIZE ? SMEM_CD_OUTPUT_BASE_SIZE : SMEM_CD_L1_ASYNC_SIZE; - constexpr uint32_t SMEM_CD_OUTPUT_SIZE = math::constexpr_align( - SMEM_CD_OUTPUT_UNALIGNED_SIZE, kSharedMemoryAlignment); - constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_ACCUM_SIZE + SMEM_CD_OUTPUT_SIZE; - - constexpr uint32_t SMEM_BEFORE_BARRIER_SIZE = - SMEM_EXPERT_COUNT_SIZE + SMEM_SEND_BUFFER_SIZE + SMEM_CD_SIZE + - kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); - - // SMEM pointers - auto smem_expert_count = reinterpret_cast(smem_buffer); - const auto smem_send_buffers = layout::Buffer( - fp8_token_layout, kNumDispatchWarps, 1, - math::advance_ptr(smem_buffer, SMEM_EXPERT_COUNT_SIZE)); - - auto smem_gemm_base = math::advance_ptr( - smem_buffer, SMEM_EXPERT_COUNT_SIZE + SMEM_SEND_BUFFER_SIZE); - - // mma.sync decode stages FP32 accumulators through SMEM so the epilogue can - // use logical row/column indices instead of WGMMA accumulator layout. - auto smem_accum_f32 = reinterpret_cast(smem_gemm_base); - auto smem_cd_base = math::advance_ptr(smem_gemm_base, SMEM_CD_ACCUM_SIZE); - // CD output is shared by L1 (FP8) and L2 (BF16); reinterpret-cast as needed. 
- auto smem_cd_l1 = reinterpret_cast(smem_cd_base); - auto smem_cd_l2 = reinterpret_cast(smem_cd_base); - - auto smem_a = utils::PatternVisitor([=](const uint32_t& i) { - return math::advance_ptr(smem_gemm_base, SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); - }); - auto smem_b = utils::PatternVisitor([=](const uint32_t& i) { - return math::advance_ptr(smem_gemm_base, SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); - }); - auto sf_start_ptr = math::advance_ptr(smem_gemm_base, - SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); - auto smem_sfa = utils::PatternVisitor([=](const uint32_t& i) { - return reinterpret_cast(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE); - }); - - // Barriers live after SF (SFB is loaded directly from global, no SMEM) - auto barrier_start_ptr = reinterpret_cast( - sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE); - auto dispatch_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + i; }); - auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + kNumDispatchWarps + i; }); - auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + kNumDispatchWarps + kNumStages + i; }); - auto combine_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + kNumDispatchWarps + kNumStages * 2 + i; }); - - // ===================================================================== - // Initialization - // ===================================================================== - if (warp_idx == 0) { - // Clean expert-count shared memory - #pragma unroll - for (uint32_t i = lane_idx; i < kNumExperts; i += 32) - ptx::st_shared(smem_expert_count + i, 0u); - } else if (warp_idx == 1) { - // Init dispatch m-barriers - #pragma unroll - for (uint32_t i = lane_idx; i < kNumDispatchWarps; i += 32) - dispatch_barriers[i]->init(1); - cutlass::arch::fence_barrier_init(); - } else if 
(warp_idx == 2) { - // Init GEMM full/empty barriers and combine barriers - if (cute::elect_one_sync()) { - #pragma unroll - for (uint32_t i = 0; i < kNumStages; ++ i) { - // Producer arrivals: A(+SFA) + B, or A + B + SFA when - // the split-SFA experiment uses an otherwise idle TMA warp. - full_barriers[i]->init(kSplitSFATMA ? 3 : 2); - // With cluster multicast the leader CTA's TMA warp waits on peer - // empty barriers too, so every math warp releases both CTAs. - empty_barriers[i]->init(kClusterSize * kNumEpilogueWarps); - } - #pragma unroll - for (uint32_t i = 0; i < kNumEpilogueWarps * 2; ++ i) - combine_barriers[i]->init(1); - } - cutlass::arch::fence_barrier_init(); - } - if constexpr (kClusterSize > 1) { - cute::cluster_sync(); - } else { - __syncthreads(); - } - - // ===================================================================== - // Scheduler (cluster=1) - // ===================================================================== - auto scheduler = sched::MegaMoEScheduler< - BLOCK_M, BLOCK_N, BLOCK_K, - L1_SHAPE_N, L1_SHAPE_K, - L2_SHAPE_N, L2_SHAPE_K, - kNumExpertsPerRank, kNumExpertsPerWave, - kNumSMs, kNumRanks, kClusterSize, kL2NMajorScheduleRequested, kL1NMajorScheduleRequested>(workspace); - - // Pipeline state shared by TMA loaders and math warpgroups - uint32_t stage_idx = 0, phase = 0; - auto advance_pipeline = [&](uint32_t& k_block_idx) { - ++ k_block_idx; - stage_idx = stage_idx == kNumStages - 1 ? 
0 : stage_idx + 1; - phase ^= stage_idx == 0; - }; - - // Intra-SM barrier indices (mirroring SM100) - constexpr uint32_t kDispatchBarrierIdx = 0; - constexpr uint32_t kDispatchWithEpilogueBarrierIdx = 1; - constexpr uint32_t kEpilogueFullBarrierIdx = 2; - constexpr uint32_t kEpilogueWGBarrierStartIdx = 3; - - // Cross-rank NVLink barrier tags - constexpr uint32_t kBeforeDispatchPullBarrierTag = 1; - constexpr uint32_t kBeforeCombineReduceBarrierTag = 2; - constexpr uint32_t kAfterWorkspaceCleanBarrierTag = 3; - - // Register reconfiguration counts (chosen to fit in 64512 reg budget). - // For the 256-epilogue-thread case (block_m=128, 2 math WGs): - // 128*48 + 128*40 + 256*208 = 64512 exactly. - constexpr uint32_t kNumDispatchRegisters = 48; - constexpr bool kCompactFrontendWarpgroup = (kNumDispatchWarps == 2 and kNumMMANonEpilogueWarps == 2); - constexpr uint32_t kNumNonEpilogueRegisters = kCompactFrontendWarpgroup ? kNumDispatchRegisters : 40; - constexpr uint32_t kNumEpilogueRegisters = (kSerialNWarpgroups or kWideNWarpgroups) ? 256 : ((kUseMMASync and BLOCK_M == 32) ? 
240 : 208); - DG_STATIC_ASSERT(kNumDispatchRegisters * kNumDispatchThreads + - kNumNonEpilogueRegisters * kNumNonEpilogueThreads + - kNumEpilogueRegisters * kNumEpilogueThreads <= 64512, - "Too many registers"); - - constexpr uint32_t kDispatchGridSyncIndex = 0; - constexpr uint32_t kEpilogueGridSyncIndex = 1; - - constexpr uint32_t kProfileDispatchTotal = 0; - constexpr uint32_t kProfileDispatchPull = 1; - constexpr uint32_t kProfileMathLoop = 2; - constexpr uint32_t kProfileCombineBarrier = 3; - constexpr uint32_t kProfileCombineReduce = 4; - constexpr uint32_t kProfileGemmCore = 5; - constexpr uint32_t kProfileL1Epilogue = 6; - constexpr uint32_t kProfileL2Epilogue = 7; - const auto phase_profile_clock = [&]() -> unsigned long long { - if constexpr (kPhaseProfileRequested) { - unsigned long long t; - asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(t)); - return t; - } else { - return 0ull; - } - }; - const auto phase_profile_record = [&](const uint32_t& metric, const unsigned long long& cycles) { - if constexpr (kPhaseProfileRequested) { - if (cumulative_local_expert_recv_stats != nullptr and cycles > 0) { - auto profile = reinterpret_cast( - cumulative_local_expert_recv_stats + kNumExpertsPerRank); - atomicAdd(profile + metric, cycles); - atomicMax(profile + 8 + metric, cycles); - atomicAdd(profile + 16 + metric, 1ull); - } - } - }; - - // ===================================================================== - // ROLE 1: DISPATCH WARPS - // Mirrors SM100 dispatch with two changes: - // * SF is per-128 channel float (no UTCCP transpose). We store the - // remote per-token SF directly into the local L1 SF buffer in - // MN-major layout: `local_sf[k_chunk * num_padded_sf_pool_tokens + token_idx]`. - // * The "token_idx_in_expert" → SF token index is now the simple - // per-block linear mapping (no 4×32 transpose). 
- // ===================================================================== - if (warp_idx < kNumDispatchWarps) { - cutlass::arch::warpgroup_reg_dealloc(); - const unsigned long long dispatch_total_start = phase_profile_clock(); - - DG_STATIC_ASSERT(kNumTopk <= 32, "Invalid number of topk"); - constexpr uint32_t kNumActivateLanes = kNumTokensPerWarp * kNumTopk; - const auto read_topk_idx = [&](const auto& process) { - #pragma unroll - for (uint32_t i = (sm_idx * kNumDispatchWarps + warp_idx) * kNumTokensPerWarp; - i < num_tokens; - i += kNumSMs * kNumDispatchWarps * kNumTokensPerWarp) { - int expert_idx = -1; - if (i + (lane_idx / kNumTopk) < num_tokens and lane_idx < kNumActivateLanes) { - expert_idx = static_cast( - __ldg(input_topk_idx_buffer.get_base_ptr() + i * kNumTopk + lane_idx)); - if (expert_idx >= 0) - process(i * kNumTopk + lane_idx, expert_idx); - } - __syncwarp(); - } - }; - - // Count tokens per expert - read_topk_idx([&](const uint32_t& token_topk_idx, const int& expert_idx) { - atomicAdd_block(smem_expert_count + expert_idx, 1); - }); - ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); - - // Stake out per-expert SM offsets via global atomic - #pragma unroll - for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) { - const uint64_t send_value = (1ull << 32) | static_cast(smem_expert_count[i]); - smem_expert_count[i] = static_cast( - ptx::atomic_add(workspace.get_expert_send_count_ptr(i), send_value)); - } - ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); - - // Write source token-topk indices to remote ranks - read_topk_idx([&](const uint32_t& token_topk_idx, const int& expert_idx) { - const auto dst_rank_idx = expert_idx / kNumExpertsPerRank; - const auto dst_slot_idx = atomicAdd_block(smem_expert_count + expert_idx, 1); - const auto dst_ptr = workspace.get_src_token_topk_idx_ptr( - expert_idx % kNumExpertsPerRank, sym_buffer.rank_idx, dst_slot_idx); - *sym_buffer.map(dst_ptr, dst_rank_idx) = 
token_topk_idx; - }); - - comm::grid_sync( - workspace, sm_idx, thread_idx, - [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); } - ); - - if (sm_idx == 0) { - #pragma unroll - for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) { - const auto dst_rank_idx = i / kNumExpertsPerRank; - const auto dst_local_expert_idx = i % kNumExpertsPerRank; - const auto expert_status = *workspace.get_expert_send_count_ptr(i); - *sym_buffer.map( - workspace.get_expert_recv_count_ptr(sym_buffer.rank_idx, dst_local_expert_idx), - dst_rank_idx) = expert_status & 0xffffffff; - ptx::atomic_add_sys( - sym_buffer.map(workspace.get_expert_recv_count_sum_ptr(dst_local_expert_idx), dst_rank_idx), - expert_status); - } - } - ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); - - comm::nvlink_barrier( - workspace, sym_buffer, sm_idx, thread_idx, - [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); }, - false, true); - - // Sync with epilogue warps before pulling tokens - ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); - const unsigned long long dispatch_pull_start = phase_profile_clock(); - - // Token / SF pull loop - uint32_t pull_mbarrier_phase = 0; - const auto pull_buffer = smem_send_buffers.get_rank_buffer(warp_idx).get_data_buffer(0); - const auto pull_mbarrier = dispatch_barriers[warp_idx]; - - scheduler.fetch_expert_recv_count(); - - constexpr uint32_t kNumRanksPerLane = math::constexpr_ceil_div(kNumRanks, 32u); - int current_expert_idx = -1; - uint32_t stored_rank_count[kNumRanksPerLane] = {}; - uint32_t expert_start_idx = 0, expert_end_idx = 0; - uint32_t expert_pool_block_offset = 0; - - constexpr uint32_t kNumGlobalWarps = kNumSMs * kNumDispatchWarps; - for (uint32_t token_idx = sm_idx * kNumDispatchWarps + warp_idx; ; token_idx += kNumGlobalWarps) { - int old_expert_idx = current_expert_idx; - while (token_idx >= expert_end_idx) { - if (++ current_expert_idx >= 
kNumExpertsPerRank) - break; - expert_pool_block_offset += math::ceil_div(expert_end_idx - expert_start_idx, BLOCK_M); - expert_start_idx = expert_end_idx; - expert_end_idx += scheduler.get_num_tokens(current_expert_idx); - } - if (current_expert_idx >= kNumExpertsPerRank) - break; - - if (old_expert_idx != current_expert_idx) { - old_expert_idx = current_expert_idx; - #pragma unroll - for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) { - const uint32_t j = i * 32 + lane_idx; - stored_rank_count[i] = j < kNumRanks ? - static_cast(*workspace.get_expert_recv_count_ptr(j, current_expert_idx)) : 0; - } - } - - // Round-robin rank selection (identical to SM100) - uint32_t current_rank_in_expert_idx; - uint32_t remaining[kNumRanksPerLane]; - #pragma unroll - for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) - remaining[i] = stored_rank_count[i]; - uint32_t offset = 0; - uint32_t token_idx_in_expert = token_idx - expert_start_idx; - uint32_t slot_idx = token_idx_in_expert; - uint32_t token_idx_in_rank; - while (true) { - uint32_t num_actives_in_lane = 0; - uint32_t min_in_lane = 0xffffffff; - #pragma unroll - for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) { - num_actives_in_lane += remaining[i] > 0; - if (remaining[i] > 0) - min_in_lane = cute::min(min_in_lane, remaining[i]); - } - const uint32_t num_active_ranks = __reduce_add_sync(0xffffffff, num_actives_in_lane); - const uint32_t length = __reduce_min_sync(0xffffffff, min_in_lane); - - const uint32_t num_round_tokens = length * num_active_ranks; - if (slot_idx < num_round_tokens) { - const uint32_t slot_idx_in_round = slot_idx % num_active_ranks; - uint32_t num_seen_ranks = 0; - current_rank_in_expert_idx = 0; - #pragma unroll - for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) { - const uint32_t mask = __ballot_sync(0xffffffff, remaining[i] > 0); - const uint32_t num_active_lanes = __popc(mask); - if (slot_idx_in_round >= num_seen_ranks and slot_idx_in_round < num_seen_ranks + num_active_lanes) - 
current_rank_in_expert_idx = i * 32 + __fns(mask, 0, slot_idx_in_round - num_seen_ranks + 1); - num_seen_ranks += num_active_lanes; - } - token_idx_in_rank = offset + (slot_idx / num_active_ranks); - break; - } - slot_idx -= num_round_tokens; - offset += length; - #pragma unroll - for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) - remaining[i] -= cute::min(remaining[i], length); - } - - const uint32_t src_token_topk_idx = *workspace.get_src_token_topk_idx_ptr( - current_expert_idx, current_rank_in_expert_idx, token_idx_in_rank); - const uint32_t src_token_idx = src_token_topk_idx / kNumTopk; - const uint32_t src_topk_idx = src_token_topk_idx % kNumTopk; - - // TMA pull token data into SMEM - if (cute::elect_one_sync()) { - ptx::tma_load_1d( - pull_buffer.get_base_ptr(), - sym_buffer.map(input_token_buffer.get_data_buffer(src_token_idx).get_base_ptr(), - current_rank_in_expert_idx), - pull_mbarrier, kHidden); - } - __syncwarp(); - - // Copy SF: per-128 K floats, written linearly (no UTCCP transpose). 
- constexpr uint32_t kNumSFFloats = kHidden / 128; - DG_STATIC_ASSERT(kNumSFFloats > 0 and kHidden % 128 == 0, "Invalid SF"); - const auto remote_sf_ptr = sym_buffer.map( - input_sf_buffer.get_data_buffer(src_token_idx).get_base_ptr(), - current_rank_in_expert_idx); - const auto local_sf_ptr = l1_sf_buffer.get_base_ptr(); - const uint32_t sf_pool_token_idx = expert_pool_block_offset * BLOCK_M + token_idx_in_expert; - #pragma unroll - for (uint32_t i = 0; i < math::constexpr_ceil_div(kNumSFFloats, 32u); ++ i) { - const uint32_t j = i * 32 + lane_idx; - if (j < kNumSFFloats) - local_sf_ptr[j * kNumPaddedSFPoolTokens + sf_pool_token_idx] = remote_sf_ptr[j]; - } - __syncwarp(); - - const uint32_t pool_token_idx = expert_pool_block_offset * BLOCK_M + token_idx_in_expert; - if (cute::elect_one_sync()) { - const auto weight = *sym_buffer.map( - input_topk_weights_buffer.get_base_ptr() + src_token_topk_idx, - current_rank_in_expert_idx); - *l1_topk_weights_buffer.get_data_buffer(pool_token_idx).get_base_ptr() = weight; - - ptx::mbarrier_arrive_and_set_tx(pull_mbarrier, kHidden); - ptx::mbarrier_wait_and_flip_phase(pull_mbarrier, pull_mbarrier_phase); - - ptx::tma_store_1d( - l1_token_buffer.get_data_buffer(pool_token_idx).get_base_ptr(), - pull_buffer.get_base_ptr(), pull_buffer.get_num_bytes()); - - *workspace.get_token_src_metadata_ptr(pool_token_idx) = - {current_rank_in_expert_idx, src_token_idx, src_topk_idx}; - - cute::tma_store_arrive(); - ptx::tma_store_wait<0>(); - ptx::red_add_rel( - workspace.get_l1_arrival_count_ptr(expert_pool_block_offset + token_idx_in_expert / BLOCK_M), 1); - } - __syncwarp(); - } - - - - // Cleanup workspace, overlapping with combine - const unsigned long long dispatch_pull_end = phase_profile_clock(); - if (lane_idx == 0) { - phase_profile_record(kProfileDispatchPull, dispatch_pull_end - dispatch_pull_start); - phase_profile_record(kProfileDispatchTotal, dispatch_pull_end - dispatch_total_start); - } - 
ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); - - DG_STATIC_ASSERT(kNumSMs > 1, "Invalid SM count"); - if (sm_idx == 0) { - #pragma unroll - for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) - *workspace.get_expert_send_count_ptr(i) = 0; - } else { - for (uint32_t i = sm_idx - 1; i < kNumExpertsPerRank; i += kNumSMs - 1) { - const auto num_recv_tokens = static_cast( - *workspace.get_expert_recv_count_sum_ptr(i)); - const auto num_recv_m_blocks = math::ceil_div(num_recv_tokens, BLOCK_M); - - expert_pool_block_offset = scheduler.get_pool_block_offset(i); - - ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); - - DG_STATIC_ASSERT(kNumDispatchWarps >= 2, "Not enough dispatch warps"); - if (warp_idx == 0) { - *workspace.get_expert_recv_count_sum_ptr(i) = 0; - } else if (warp_idx == 1) { - if (cute::elect_one_sync() and cumulative_local_expert_recv_stats != nullptr) - ptx::red_add(cumulative_local_expert_recv_stats + i, static_cast(num_recv_tokens)); - __syncwarp(); - } - - for (uint32_t j = thread_idx; j < kNumRanks; j += kNumDispatchThreads) - *workspace.get_expert_recv_count_ptr(j, i) = 0; - __syncwarp(); - - for (uint32_t j = thread_idx; j < num_recv_m_blocks; j += kNumDispatchThreads) { - *workspace.get_l1_arrival_count_ptr(expert_pool_block_offset + j) = 0; - *workspace.get_l2_arrival_mask_ptr(expert_pool_block_offset + j) = 0; - } - __syncwarp(); - } - } - - comm::nvlink_barrier( - workspace, sym_buffer, sm_idx, thread_idx, - [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); }, - true, false); - - // ===================================================================== - // ROLE 2: GEMM TMA LOAD warps (load A+SFA, B+SFB) - // Default: 4 non-epilogue warps, two active and two idle. - // Compact frontend experiment: 2 dispatch warps + 2 TMA warps share the - // first warpgroup, reducing total CTA threads for the M128/2WG path. 
- // ===================================================================== - } else if (warp_idx == kNumDispatchWarps) { - cutlass::arch::warpgroup_reg_dealloc(); - - scheduler.for_each_block([&](const sched::BlockPhase& block_phase, - const uint32_t& local_expert_idx, - const uint32_t& num_k_blocks, - const uint32_t& m_block_idx, const uint32_t& n_block_idx) { - const auto tensor_map_a_ptr = block_phase == sched::BlockPhase::Linear2 - ? &tensor_map_l2_acts : &tensor_map_l1_acts; - const auto tensor_map_sfa_ptr = block_phase == sched::BlockPhase::Linear2 - ? &tensor_map_l2_acts_sf : &tensor_map_l1_acts_sf; - - const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx; - const uint32_t valid_m = scheduler.template get_valid_m(); - const bool has_valid_m = valid_m > 0; - - // Wait for the pool to be ready. Cluster peers can be dummy CTAs for - // the tail M unit when an expert has an odd number of M blocks. - if (has_valid_m) { - if (block_phase == sched::BlockPhase::Linear1) { - const auto ptr = workspace.get_l1_arrival_count_ptr(pool_block_idx); - const auto expected = valid_m; - while (ptx::ld_acq(ptr) != expected); - } else { - const auto ptr = workspace.get_l2_arrival_mask_ptr(pool_block_idx); - // Each L1 N block sets one bit; total bits = L1_SHAPE_N / BLOCK_N. - constexpr uint32_t kNumL1BlockNs = L1_SHAPE_N / BLOCK_N; - const uint64_t expected = (kNumL1BlockNs >= 64) - ? 
~0ull : ((1ull << kNumL1BlockNs) - 1ull); - while (ptx::ld_acq_gpu(ptr) != expected); - } - } - for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { - empty_barriers[stage_idx]->wait(phase ^ 1); - - if (cute::elect_one_sync()) { - if (has_valid_m) { - const uint32_t m_idx = pool_block_idx * BLOCK_M; - const uint32_t k_idx = k_block_idx * BLOCK_K; - - // TMA load A - tma::copy( - tensor_map_a_ptr, full_barriers[stage_idx], smem_a[stage_idx], - k_idx, m_idx, 1); - - if constexpr (kSplitSFATMA) { - full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE); - } else { - // TMA load SFA - if (block_phase == sched::BlockPhase::Linear1) { - // L1 SFA per-128: load (BLOCK_M, 1) at K=k_block_idx - tma::copy( - tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], - m_idx, k_block_idx, 1); - full_barriers[stage_idx]->arrive_and_expect_tx( - SMEM_A_SIZE_PER_STAGE + BLOCK_M * sizeof(float)); - } else { - // L2 SFA per-64: descriptor box is (block_mn, 1) (see make_tma_sf_desc), - // so we must issue two single-group TMAs and place them at smem offsets - // 0 and BLOCK_M to match math's load offsets (`+ 0 * BLOCK_M` / `+ 1 * BLOCK_M`). 
- tma::copy( - tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], - m_idx, k_block_idx * 2, 1); - tma::copy( - tensor_map_sfa_ptr, full_barriers[stage_idx], - smem_sfa[stage_idx] + kL2SFAHalfStride, - m_idx, k_block_idx * 2 + 1, 1); - full_barriers[stage_idx]->arrive_and_expect_tx( - SMEM_A_SIZE_PER_STAGE + 2 * BLOCK_M * sizeof(float)); - } - } - } else { - full_barriers[stage_idx]->arrive(); - } - } - __syncwarp(); - } - }); - - } else if (warp_idx == kNumDispatchWarps + 1) { - cutlass::arch::warpgroup_reg_dealloc(); - - scheduler.for_each_block([&](const sched::BlockPhase& block_phase, - const uint32_t& local_expert_idx, - const uint32_t& num_k_blocks, - const uint32_t& m_block_idx, const uint32_t& n_block_idx) { - const auto tensor_map_b_ptr = - block_phase == sched::BlockPhase::Linear2 ? &tensor_map_l2_weights : &tensor_map_l1_weights; - - const uint32_t shape_n = block_phase == sched::BlockPhase::Linear2 ? L2_SHAPE_N : L1_SHAPE_N; - - for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { - empty_barriers[stage_idx]->wait(phase ^ 1); - - if (cute::elect_one_sync()) { - const uint32_t n_idx = local_expert_idx * shape_n + n_block_idx * BLOCK_N; - const uint32_t k_idx = k_block_idx * BLOCK_K; - - // TMA load B (weight SF is now loaded directly by math warps from global) - tma::copy( - tensor_map_b_ptr, full_barriers[stage_idx], smem_b[stage_idx], - k_idx, n_idx, kClusterSize); - - full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_B_SIZE_PER_STAGE); - } - __syncwarp(); - } - }); - - } else if (kSplitSFATMA && warp_idx == kNumDispatchWarps + 2) { - cutlass::arch::warpgroup_reg_dealloc(); - - scheduler.for_each_block([&](const sched::BlockPhase& block_phase, - const uint32_t& local_expert_idx, - const uint32_t& num_k_blocks, - const uint32_t& m_block_idx, const uint32_t& n_block_idx) { - (void)local_expert_idx; - (void)n_block_idx; - const auto tensor_map_sfa_ptr = block_phase == sched::BlockPhase::Linear2 - ? 
&tensor_map_l2_acts_sf : &tensor_map_l1_acts_sf; - - const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx; - const uint32_t valid_m = scheduler.template get_valid_m(); - const bool has_valid_m = valid_m > 0; - - if (has_valid_m) { - if (block_phase == sched::BlockPhase::Linear1) { - const auto ptr = workspace.get_l1_arrival_count_ptr(pool_block_idx); - const auto expected = valid_m; - while (ptx::ld_acq(ptr) != expected); - } else { - const auto ptr = workspace.get_l2_arrival_mask_ptr(pool_block_idx); - constexpr uint32_t kNumL1BlockNs = L1_SHAPE_N / BLOCK_N; - const uint64_t expected = (kNumL1BlockNs >= 64) - ? ~0ull : ((1ull << kNumL1BlockNs) - 1ull); - while (ptx::ld_acq_gpu(ptr) != expected); - } - } - - for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { - empty_barriers[stage_idx]->wait(phase ^ 1); - - if (cute::elect_one_sync()) { - if (has_valid_m) { - const uint32_t m_idx = pool_block_idx * BLOCK_M; - - if (block_phase == sched::BlockPhase::Linear1) { - tma::copy( - tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], - m_idx, k_block_idx, 1); - full_barriers[stage_idx]->arrive_and_expect_tx(BLOCK_M * sizeof(float)); - } else { - tma::copy( - tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], - m_idx, k_block_idx * 2, 1); - tma::copy( - tensor_map_sfa_ptr, full_barriers[stage_idx], - smem_sfa[stage_idx] + kL2SFAHalfStride, - m_idx, k_block_idx * 2 + 1, 1); - full_barriers[stage_idx]->arrive_and_expect_tx(2 * BLOCK_M * sizeof(float)); - } - } else { - full_barriers[stage_idx]->arrive(); - } - } - __syncwarp(); - } - }); - - } else if (warp_idx < kNumDispatchWarps + kNumMMANonEpilogueWarps) { - // Idle non-epilogue warps (kNumDispatchWarps+2, +3). They must still - // participate in the warpgroup-collective `setmaxnreg.dec.sync.aligned` - // so that the math warpgroup's `warpgroup_reg_alloc` can succeed. 
- cutlass::arch::warpgroup_reg_dealloc(); - - } else if (warp_idx >= kNumDispatchWarps + kNumMMANonEpilogueWarps) { - // ===================================================================== - // ROLE 3: MATH WARPGROUPS (WGMMA + epilogue + combine) - // ===================================================================== - cutlass::arch::warpgroup_reg_alloc(); - - const uint32_t epilogue_warp_idx = warp_idx - (kNumDispatchWarps + kNumMMANonEpilogueWarps); - const uint32_t epilogue_wg_idx = epilogue_warp_idx / 4; - const uint32_t epilogue_thread_idx = epilogue_warp_idx * 32 + lane_idx; - const uint32_t warp_idx_in_wg = epilogue_warp_idx % 4; - - uint32_t async_l1_store_stage = 0; - bool async_l1_store_pending[2] = {false, false}; - uint32_t async_l1_store_pool[2] = {0, 0}; - uint32_t async_l1_store_n[2] = {0, 0}; - - const auto arrive_empty_barrier = [&](const uint32_t& s) { - if constexpr (kClusterSize == 1) { - if (lane_idx == 0) - empty_barriers[s]->arrive(); - } else { - if (lane_idx < kClusterSize) - empty_barriers[s]->arrive(lane_idx); - } - }; - - const auto notify_l1_ready = [&](const uint32_t& ready_pool_block_idx, - const uint32_t& ready_n_block_idx) { - if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { - ptx::red_or_rel_gpu( - workspace.get_l2_arrival_mask_ptr(ready_pool_block_idx), - 1ull << ready_n_block_idx); - } - __syncwarp(); - }; - - const auto drain_async_l1_store_stage = [&](const uint32_t& store_stage) { - if constexpr (kAsyncL1TMAStore) { - if (async_l1_store_pending[store_stage]) { - // Two SMEM L1 store buffers are used in FIFO order; waiting - // for <=1 outstanding store makes the older buffer reusable. 
- ptx::tma_store_wait<1>(); - ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); - notify_l1_ready(async_l1_store_pool[store_stage], - async_l1_store_n[store_stage]); - async_l1_store_pending[store_stage] = false; - } - } - }; - - const auto drain_all_async_l1_stores = [&]() { - if constexpr (kAsyncL1TMAStore) { - if (async_l1_store_pending[0] or async_l1_store_pending[1]) { - ptx::tma_store_wait<0>(); - ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); - if (async_l1_store_pending[0]) { - notify_l1_ready(async_l1_store_pool[0], async_l1_store_n[0]); - async_l1_store_pending[0] = false; - } - if (async_l1_store_pending[1]) { - notify_l1_ready(async_l1_store_pool[1], async_l1_store_n[1]); - async_l1_store_pending[1] = false; - } - } - } - }; - - // WGMMA-output register layout helpers - const uint32_t row_idx = lane_idx / 4; - const uint32_t col_idx = lane_idx % 4; - const uint32_t r_0 = warp_idx_in_wg * 16 + row_idx; - const uint32_t r_1 = r_0 + 8; - - DG_STATIC_ASSERT(kSplitNWarpgroups || (BLOCK_M % kNumEpilogueWarpgroups == 0), "Invalid block M"); - if constexpr (kUseMMASync) { - DG_STATIC_ASSERT(WG_BLOCK_M == BLOCK_M, "mma.sync decode path uses one M tile per warpgroup"); - DG_STATIC_ASSERT(kNumEpilogueWarpgroups == 1, "mma.sync decode path currently uses one math warpgroup"); - } else if constexpr (kSplitNWarpgroups) { - DG_STATIC_ASSERT(WG_BLOCK_M == L1WGMMA::M and WG_BLOCK_N == L1WGMMA::N, - "Split-N WGs must each run one M64N128 WGMMA per K-block"); - } else if constexpr (kSerialNWarpgroups) { - DG_STATIC_ASSERT(WG_BLOCK_M == L1WGMMA::M and WG_BLOCK_N == L1WGMMA::N, - "Serial-N path runs two M64N128 WGMMAs per K-block"); - } else { - DG_STATIC_ASSERT(WG_BLOCK_M == L1WGMMA::M, "Each warpgroup must run exactly one WGMMA per K-block"); - } - - // Sync with dispatch - ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); - const unsigned long long math_loop_start = phase_profile_clock(); - - 
scheduler.for_each_block([&](const sched::BlockPhase& block_phase, - const uint32_t& local_expert_idx, - const uint32_t& num_k_blocks, - const uint32_t& m_block_idx, const uint32_t& n_block_idx) { - const uint32_t valid_m = scheduler.template get_valid_m(); - const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx; - const uint32_t m_idx = pool_block_idx * BLOCK_M; - const uint32_t wg_n_idx = kSplitNWarpgroups ? epilogue_wg_idx * WG_BLOCK_N : 0; - const uint32_t wg_l1_out_n_idx = kSplitNWarpgroups ? epilogue_wg_idx * WG_L1_OUT_BLOCK_N : 0; - const uint32_t n_idx = n_block_idx * BLOCK_N + wg_n_idx; - const uint32_t row_block_offset = kSplitNWarpgroups ? 0 : epilogue_wg_idx * WG_BLOCK_M; - const uint32_t row_offset_r0 = row_block_offset + r_0; - const uint32_t row_offset_r1 = row_block_offset + r_1; - const bool valid_r0 = row_offset_r0 < valid_m; - const bool valid_r1 = row_offset_r1 < valid_m; - - - if constexpr (kAsyncL1TMAStore) { - if (block_phase != sched::BlockPhase::Linear1) - drain_all_async_l1_stores(); - } - - if constexpr (kUseMMASync) { - using MMASyncTiled = cute::TiledMMA< - cute::MMA_Atom, - cute::Layout, cute::_4, cute::_1>>>; - auto mma_sync_tiled = MMASyncTiled{}; - auto thr_mma = mma_sync_tiled.get_thread_slice(epilogue_thread_idx); - auto sC = cute::make_tensor( - cute::make_smem_ptr(smem_accum_f32), - cute::make_layout(cute::make_shape(cute::Int{}, cute::Int<128>{}), cute::GenRowMajor{})); - auto tCsC = thr_mma.partition_C(sC); - auto tCrC = thr_mma.make_fragment_C(tCsC); - auto tCrFinal = thr_mma.make_fragment_C(tCsC); - cute::clear(tCrFinal); - auto cC = cute::make_identity_tensor(cute::shape(sC)); - auto tCcC = thr_mma.partition_C(cC); - - auto add_scaled_fragment = [&](const float& scale_b, const uint32_t& sfa_offset) { - CUTE_UNROLL - for (int i = 0; i < cute::size(tCrC); ++i) { - const auto coord = tCcC(i); - if (cute::elem_less(coord, cute::shape(sC))) { - const uint32_t row = 
static_cast(cute::get<0>(coord)); - const uint32_t col = static_cast(cute::get<1>(coord)); - const float scale_a = ptx::ld_shared(smem_sfa[stage_idx] + sfa_offset + row); - tCrFinal(i) += scale_a * scale_b * tCrC(i); - } - } - }; - - for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { - full_barriers[stage_idx]->wait(phase); - - constexpr uint32_t kL1SFKBlocks = kHidden / 128; - constexpr uint32_t kL2SFKBlocks = kIntermediateHidden / 128; - constexpr uint32_t kL1SFGateBlks = kIntermediateHidden / 128; - constexpr uint32_t kL1SFPerExpert = (kIntermediateHidden * 2 / 128) * kL1SFKBlocks; - constexpr uint32_t kL2SFPerExpert = (kHidden / 128) * kL2SFKBlocks; - - if (block_phase == sched::BlockPhase::Linear1) { - const uint32_t gate_n = (n_block_idx * BLOCK_N + wg_n_idx) / 256u; - const uint32_t up_n = kL1SFGateBlks + gate_n; - const float* base = l1_weights_sf + local_expert_idx * kL1SFPerExpert + k_block_idx; - const float gate_sf = __ldg(base + gate_n * kL1SFKBlocks); - const float up_sf = __ldg(base + up_n * kL1SFKBlocks); - - auto sA = cute::make_tensor( - cute::make_smem_ptr(smem_a[stage_idx]), - cute::make_layout(cute::make_shape(cute::Int{}, cute::Int<128>{}), cute::GenRowMajor{})); - auto sB = cute::make_tensor( - cute::make_smem_ptr(smem_b[stage_idx]), - cute::make_layout(cute::make_shape(cute::Int<128>{}, cute::Int<128>{}), cute::GenRowMajor{})); - cute::clear(tCrC); - cute::cooperative_gemm(epilogue_thread_idx, mma_sync_tiled, sA, sB, tCrC); - - CUTE_UNROLL - for (int i = 0; i < cute::size(tCrC); ++i) { - const auto coord = tCcC(i); - if (cute::elem_less(coord, cute::shape(sC))) { - const uint32_t row = static_cast(cute::get<0>(coord)); - const uint32_t col = static_cast(cute::get<1>(coord)); - const float scale_a = ptx::ld_shared(smem_sfa[stage_idx] + row); - const float scale_b = ((col / 8u) & 1u) ? 
up_sf : gate_sf; - tCrFinal(i) += scale_a * scale_b * tCrC(i); - } - } - } else { - const float l2_sf = __ldg(l2_weights_sf + local_expert_idx * kL2SFPerExpert - + ((n_block_idx * BLOCK_N + wg_n_idx) / 128u) * kL2SFKBlocks + k_block_idx); - CUTE_UNROLL - for (uint32_t half = 0; half < 2; ++half) { - const uint32_t k_off = half * (BLOCK_K / 2); - auto sA = cute::make_tensor( - cute::make_smem_ptr(smem_a[stage_idx] + k_off), - cute::make_layout(cute::make_shape(cute::Int{}, cute::Int<64>{}), - cute::make_stride(cute::Int<128>{}, cute::Int<1>{}))); - auto sB = cute::make_tensor( - cute::make_smem_ptr(smem_b[stage_idx] + wg_n_idx * BLOCK_K + k_off), - cute::make_layout(cute::make_shape(cute::Int<128>{}, cute::Int<64>{}), - cute::make_stride(cute::Int<128>{}, cute::Int<1>{}))); - cute::clear(tCrC); - cute::cooperative_gemm(epilogue_thread_idx, mma_sync_tiled, sA, sB, tCrC); - add_scaled_fragment(l2_sf, half * kL2SFAHalfStride); - } - } - - arrive_empty_barrier(stage_idx); - __syncwarp(); - } - - CUTE_UNROLL - for (int i = 0; i < cute::size(tCrFinal); ++i) { - const auto coord = tCcC(i); - if (cute::elem_less(coord, cute::shape(sC))) { - const uint32_t row = static_cast(cute::get<0>(coord)); - const uint32_t col = static_cast(cute::get<1>(coord)); - smem_accum_f32[row * BLOCK_N + col] = tCrFinal(i); - } - } - ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); - - constexpr uint32_t kMMASyncRowsPerPass = kNumEpilogueThreads / 8; - DG_STATIC_ASSERT(kMMASyncRowsPerPass == 16, "mma.sync epilogue maps 8 lanes per row"); - - if (block_phase == sched::BlockPhase::Linear1) { - #pragma unroll - for (uint32_t row_base = 0; row_base < BLOCK_M; row_base += kMMASyncRowsPerPass) { - const uint32_t row = row_base + epilogue_thread_idx / 8; - const uint32_t lane_in_row = epilogue_thread_idx % 8; - const bool valid_row = row < valid_m; - float swiglu_values[8]; - float amax = 0.0f; - - #pragma unroll - for (uint32_t p = 0; p < 8; ++p) { - const uint32_t gate_col = p * 16 + 
lane_in_row; - const uint32_t up_col = gate_col + 8; - float gate = valid_row ? smem_accum_f32[row * BLOCK_N + gate_col] : 0.0f; - float up = valid_row ? smem_accum_f32[row * BLOCK_N + up_col] : 0.0f; - if constexpr (kActivationClamp != cute::numeric_limits::infinity()) { - gate = cute::min(gate, kActivationClamp); - up = cute::min(cute::max(up, -kActivationClamp), kActivationClamp); - } - const float e = kFastMath ? __expf(-gate) : expf(-gate); - const float sig = kFastMath ? math::fast_rcp(1.0f + e) : 1.0f / (1.0f + e); - const float weight = valid_row ? *l1_topk_weights_buffer - .get_data_buffer(m_idx + row) - .get_base_ptr() : 0.0f; - const float value = gate * sig * up * weight; - swiglu_values[p] = value; - amax = cute::max(amax, cute::abs(value)); - } - - amax = math::warp_reduce<8, false>(amax, math::ReduceMax()); - float2 amax_pair = {amax, amax}; - float2 sf_pair, sf_inv_pair; - math::get_e4m3_sf_and_sf_inv(amax_pair, sf_pair, sf_inv_pair); - const float sf = sf_pair.x; - const float sf_inv = sf_inv_pair.x; - - if (valid_row) { - #pragma unroll - for (uint32_t p = 0; p < 8; ++p) { - const uint32_t out_col = p * 8 + lane_in_row; - smem_cd_l1[row * L1_OUT_BLOCK_N + out_col] = - cutlass::float_e4m3_t(swiglu_values[p] * sf_inv); - } - if (lane_in_row == 0) { - auto sf_base_ptr = l2_sf_buffer.get_base_ptr(); - const uint32_t token_idx = pool_block_idx * BLOCK_M + row; - const uint32_t k_sf_idx = n_block_idx; - sf_base_ptr[k_sf_idx * kNumPaddedSFPoolTokens + token_idx] = sf; - } - } - } - - ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); - if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { - const uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N; - cute::tma_store_fence(); - cute::SM90_TMA_STORE_2D::copy( - &tensor_map_l1_output, - smem_cd_l1, - out_n_idx, - m_idx); - cute::tma_store_arrive(); - } - __syncwarp(); - ptx::tma_store_wait<0>(); - - ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); - if (epilogue_warp_idx == 0 
and cute::elect_one_sync()) { - ptx::red_or_rel_gpu( - workspace.get_l2_arrival_mask_ptr(pool_block_idx), - 1ull << n_block_idx); - } - __syncwarp(); - } else { - #pragma unroll - for (uint32_t row_base = 0; row_base < BLOCK_M; row_base += kMMASyncRowsPerPass) { - const uint32_t row = row_base + epilogue_thread_idx / 8; - const uint32_t lane_in_row = epilogue_thread_idx % 8; - if (row < valid_m) { - const auto src_metadata = *workspace.get_token_src_metadata_ptr(m_idx + row); - const uint32_t dst_rank_idx = src_metadata.rank_idx; - const uint32_t dst_token_idx = src_metadata.token_idx; - const uint32_t dst_topk_idx = src_metadata.topk_idx; - const auto dst_token = combine_token_buffer.get_rank_buffer(dst_topk_idx) - .get_data_buffer(dst_token_idx); - #pragma unroll - for (uint32_t v = 0; v < 2; ++v) { - const uint32_t col = lane_in_row * 16 + v * 8; - uint4 packed; - float f0 = smem_accum_f32[row * BLOCK_N + col + 0]; - float f1 = smem_accum_f32[row * BLOCK_N + col + 1]; - float f2 = smem_accum_f32[row * BLOCK_N + col + 2]; - float f3 = smem_accum_f32[row * BLOCK_N + col + 3]; - float f4 = smem_accum_f32[row * BLOCK_N + col + 4]; - float f5 = smem_accum_f32[row * BLOCK_N + col + 5]; - float f6 = smem_accum_f32[row * BLOCK_N + col + 6]; - float f7 = smem_accum_f32[row * BLOCK_N + col + 7]; - packed.x = math::cast_into_bf16_and_pack(f0, f1); - packed.y = math::cast_into_bf16_and_pack(f2, f3); - packed.z = math::cast_into_bf16_and_pack(f4, f5); - packed.w = math::cast_into_bf16_and_pack(f6, f7); - auto dst_ptr = math::advance_ptr( - dst_token.get_base_ptr(), - n_idx * sizeof(nv_bfloat16) + col * sizeof(nv_bfloat16)); - *sym_buffer.map(dst_ptr, dst_rank_idx) = packed; - } - } - } - ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); - } - return; - } - - if constexpr (kSerialNWarpgroups) { - using WGMMA = L1WGMMA; - constexpr uint32_t kAccumPerThread = WGMMA::kNumAccum; - constexpr uint32_t kNumSerialN = 2; - float final_accum[kNumSerialN][kAccumPerThread] 
= {}; - float accum[kAccumPerThread]; - - for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { - full_barriers[stage_idx]->wait(phase); - - float scale_a_0_lo, scale_a_1_lo; - float scale_a_0_hi, scale_a_1_hi; - if (block_phase == sched::BlockPhase::Linear1) { - scale_a_0_lo = ptx::ld_shared(smem_sfa[stage_idx] + row_offset_r0); - scale_a_1_lo = ptx::ld_shared(smem_sfa[stage_idx] + row_offset_r1); - } else { - scale_a_0_lo = ptx::ld_shared(smem_sfa[stage_idx] + row_offset_r0); - scale_a_1_lo = ptx::ld_shared(smem_sfa[stage_idx] + row_offset_r1); - scale_a_0_hi = ptx::ld_shared(smem_sfa[stage_idx] + kL2SFAHalfStride + row_offset_r0); - scale_a_1_hi = ptx::ld_shared(smem_sfa[stage_idx] + kL2SFAHalfStride + row_offset_r1); - } - - constexpr uint32_t kL1SFKBlocks = kHidden / 128; - constexpr uint32_t kL2SFKBlocks = kIntermediateHidden / 128; - constexpr uint32_t kL1SFGateBlks = kIntermediateHidden / 128; - constexpr uint32_t kL1SFPerExpert = (kIntermediateHidden * 2 / 128) * kL1SFKBlocks; - constexpr uint32_t kL2SFPerExpert = (kHidden / 128) * kL2SFKBlocks; - - #pragma unroll - for (uint32_t serial_n_idx = 0; serial_n_idx < kNumSerialN; ++serial_n_idx) { - const uint32_t serial_wg_n_idx = serial_n_idx * WG_BLOCK_N; - float gate_sf = 0.0f, up_sf = 0.0f, l2_sf = 0.0f; - if (block_phase == sched::BlockPhase::Linear1) { - const uint32_t gate_n = (n_block_idx * BLOCK_N + serial_wg_n_idx) / 256u; - const uint32_t up_n = kL1SFGateBlks + gate_n; - const float* base = l1_weights_sf + local_expert_idx * kL1SFPerExpert + k_block_idx; - gate_sf = __ldg(base + gate_n * kL1SFKBlocks); - up_sf = __ldg(base + up_n * kL1SFKBlocks); - - #pragma unroll - for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); - ptx::warpgroup_arrive(); - #pragma unroll - for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - auto desc_a = mma::sm90::make_smem_desc( - smem_a[stage_idx] + row_block_offset * BLOCK_K + k * WGMMA::K, 1); - 
auto desc_b = mma::sm90::make_smem_desc( - smem_b[stage_idx] + serial_wg_n_idx * BLOCK_K + k * WGMMA::K, 1); - WGMMA::wgmma(desc_a, desc_b, accum, k); - } - ptx::warpgroup_commit_batch(); - #pragma unroll - for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); - ptx::warpgroup_wait<0>(); - - #pragma unroll - for (uint32_t i = 0; i < kAccumPerThread / 4; ++ i) { - const float sb = (i & 1u) ? up_sf : gate_sf; - final_accum[serial_n_idx][i*4+0] += scale_a_0_lo * sb * accum[i*4+0]; - final_accum[serial_n_idx][i*4+1] += scale_a_0_lo * sb * accum[i*4+1]; - final_accum[serial_n_idx][i*4+2] += scale_a_1_lo * sb * accum[i*4+2]; - final_accum[serial_n_idx][i*4+3] += scale_a_1_lo * sb * accum[i*4+3]; - } - } else { - l2_sf = __ldg(l2_weights_sf + local_expert_idx * kL2SFPerExpert - + ((n_block_idx * BLOCK_N + serial_wg_n_idx) / 128u) * kL2SFKBlocks + k_block_idx); - - #pragma unroll - for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); - ptx::warpgroup_arrive(); - #pragma unroll - for (uint32_t k = 0; k < (BLOCK_K / 2) / WGMMA::K; ++ k) { - auto desc_a = mma::sm90::make_smem_desc( - smem_a[stage_idx] + row_block_offset * BLOCK_K + k * WGMMA::K, 1); - auto desc_b = mma::sm90::make_smem_desc( - smem_b[stage_idx] + serial_wg_n_idx * BLOCK_K + k * WGMMA::K, 1); - WGMMA::wgmma(desc_a, desc_b, accum, k); - } - ptx::warpgroup_commit_batch(); - #pragma unroll - for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); - ptx::warpgroup_wait<0>(); - - #pragma unroll - for (uint32_t i = 0; i < kAccumPerThread / 4; ++ i) { - final_accum[serial_n_idx][i*4+0] += scale_a_0_lo * l2_sf * accum[i*4+0]; - final_accum[serial_n_idx][i*4+1] += scale_a_0_lo * l2_sf * accum[i*4+1]; - final_accum[serial_n_idx][i*4+2] += scale_a_1_lo * l2_sf * accum[i*4+2]; - final_accum[serial_n_idx][i*4+3] += scale_a_1_lo * l2_sf * accum[i*4+3]; - } - - #pragma unroll - for (uint32_t i = 0; i < kAccumPerThread; ++ i) 
ptx::warpgroup_fence_operand(accum[i]); - ptx::warpgroup_arrive(); - #pragma unroll - for (uint32_t k = 0; k < (BLOCK_K / 2) / WGMMA::K; ++ k) { - const uint32_t k_off = (BLOCK_K / 2) + k * WGMMA::K; - auto desc_a = mma::sm90::make_smem_desc( - smem_a[stage_idx] + row_block_offset * BLOCK_K + k_off, 1); - auto desc_b = mma::sm90::make_smem_desc( - smem_b[stage_idx] + serial_wg_n_idx * BLOCK_K + k_off, 1); - WGMMA::wgmma(desc_a, desc_b, accum, k); - } - ptx::warpgroup_commit_batch(); - #pragma unroll - for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); - ptx::warpgroup_wait<0>(); - - #pragma unroll - for (uint32_t i = 0; i < kAccumPerThread / 4; ++ i) { - final_accum[serial_n_idx][i*4+0] += scale_a_0_hi * l2_sf * accum[i*4+0]; - final_accum[serial_n_idx][i*4+1] += scale_a_0_hi * l2_sf * accum[i*4+1]; - final_accum[serial_n_idx][i*4+2] += scale_a_1_hi * l2_sf * accum[i*4+2]; - final_accum[serial_n_idx][i*4+3] += scale_a_1_hi * l2_sf * accum[i*4+3]; - } - } - } - - arrive_empty_barrier(stage_idx); - __syncwarp(); - } - - if (row_block_offset >= valid_m) { - ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); - return; - } - - if (block_phase == sched::BlockPhase::Linear1) { - constexpr uint32_t kNumPairs = kAccumPerThread / 8; - #pragma unroll - for (uint32_t serial_n_idx = 0; serial_n_idx < kNumSerialN; ++serial_n_idx) { - const uint32_t serial_l1_out_n_idx = serial_n_idx * WG_L1_OUT_BLOCK_N; - float swiglu_r0[kNumPairs][2]; - float swiglu_r1[kNumPairs][2]; - float amax_r0 = 0.0f, amax_r1 = 0.0f; - - #pragma unroll - for (uint32_t p = 0; p < kNumPairs; ++ p) { - const uint32_t gate = 2 * p, up = 2 * p + 1; - auto clamp_gate = [](float& x) { - if constexpr (kActivationClamp != cute::numeric_limits::infinity()) - x = cute::min(x, kActivationClamp); - }; - auto clamp_up = [](float& x) { - if constexpr (kActivationClamp != cute::numeric_limits::infinity()) - x = cute::min(cute::max(x, -kActivationClamp), 
kActivationClamp); - }; - float g_r0_c0 = final_accum[serial_n_idx][gate*4 + 0]; clamp_gate(g_r0_c0); - float g_r0_c1 = final_accum[serial_n_idx][gate*4 + 1]; clamp_gate(g_r0_c1); - float g_r1_c0 = final_accum[serial_n_idx][gate*4 + 2]; clamp_gate(g_r1_c0); - float g_r1_c1 = final_accum[serial_n_idx][gate*4 + 3]; clamp_gate(g_r1_c1); - float u_r0_c0 = final_accum[serial_n_idx][up*4 + 0]; clamp_up(u_r0_c0); - float u_r0_c1 = final_accum[serial_n_idx][up*4 + 1]; clamp_up(u_r0_c1); - float u_r1_c0 = final_accum[serial_n_idx][up*4 + 2]; clamp_up(u_r1_c0); - float u_r1_c1 = final_accum[serial_n_idx][up*4 + 3]; clamp_up(u_r1_c1); - auto silu = [](float x) -> float { - const float e = kFastMath ? __expf(-x) : expf(-x); - const float sig = kFastMath ? math::fast_rcp(1.0f + e) : 1.0f / (1.0f + e); - return x * sig; - }; - if (valid_r0) { - swiglu_r0[p][0] = silu(g_r0_c0) * u_r0_c0; - swiglu_r0[p][1] = silu(g_r0_c1) * u_r0_c1; - amax_r0 = cute::max(amax_r0, cute::max(cute::abs(swiglu_r0[p][0]), cute::abs(swiglu_r0[p][1]))); - } else { - swiglu_r0[p][0] = 0.0f; - swiglu_r0[p][1] = 0.0f; - } - if (valid_r1) { - swiglu_r1[p][0] = silu(g_r1_c0) * u_r1_c0; - swiglu_r1[p][1] = silu(g_r1_c1) * u_r1_c1; - amax_r1 = cute::max(amax_r1, cute::max(cute::abs(swiglu_r1[p][0]), cute::abs(swiglu_r1[p][1]))); - } else { - swiglu_r1[p][0] = 0.0f; - swiglu_r1[p][1] = 0.0f; - } - } - - float weight_r0 = valid_r0 ? *l1_topk_weights_buffer - .get_data_buffer(m_idx + row_offset_r0) - .get_base_ptr() : 0.0f; - float weight_r1 = valid_r1 ? 
*l1_topk_weights_buffer - .get_data_buffer(m_idx + row_offset_r1) - .get_base_ptr() : 0.0f; - #pragma unroll - for (uint32_t p = 0; p < kNumPairs; ++ p) { - swiglu_r0[p][0] *= weight_r0; - swiglu_r0[p][1] *= weight_r0; - swiglu_r1[p][0] *= weight_r1; - swiglu_r1[p][1] *= weight_r1; - } - amax_r0 *= cute::abs(weight_r0); - amax_r1 *= cute::abs(weight_r1); - amax_r0 = math::warp_reduce<4, false>(amax_r0, math::ReduceMax()); - amax_r1 = math::warp_reduce<4, false>(amax_r1, math::ReduceMax()); - - float sf_r0, sf_inv_r0, sf_r1, sf_inv_r1; - { - float2 amax_pair = {amax_r0, amax_r1}; - float2 sf_pair, sf_inv_pair; - math::get_e4m3_sf_and_sf_inv(amax_pair, sf_pair, sf_inv_pair); - sf_r0 = sf_pair.x; sf_inv_r0 = sf_inv_pair.x; - sf_r1 = sf_pair.y; sf_inv_r1 = sf_inv_pair.y; - } - - #pragma unroll - for (uint32_t p = 0; p < kNumPairs; ++ p) { - const float v00 = swiglu_r0[p][0] * sf_inv_r0; - const float v01 = swiglu_r0[p][1] * sf_inv_r0; - const float v10 = swiglu_r1[p][0] * sf_inv_r1; - const float v11 = swiglu_r1[p][1] * sf_inv_r1; - const __nv_fp8x2_e4m3 r0_pair(make_float2(v00, v01)); - const __nv_fp8x2_e4m3 r1_pair(make_float2(v10, v11)); - const uint32_t col = p * 8 + col_idx * 2; - auto* p0 = reinterpret_cast( - smem_cd_l1 + r_0 * L1_OUT_BLOCK_N + serial_l1_out_n_idx + col); - auto* p1 = reinterpret_cast( - smem_cd_l1 + r_1 * L1_OUT_BLOCK_N + serial_l1_out_n_idx + col); - if (valid_r0) - *p0 = r0_pair.__x; - if (valid_r1) - *p1 = r1_pair.__x; - } - - if (col_idx == 0) { - auto sf_base_ptr = l2_sf_buffer.get_base_ptr(); - const uint32_t token_r0 = pool_block_idx * BLOCK_M + row_offset_r0; - const uint32_t token_r1 = pool_block_idx * BLOCK_M + row_offset_r1; - const uint32_t k_sf_idx = (n_block_idx * L1_OUT_BLOCK_N + serial_l1_out_n_idx) / 64u; - if (valid_r0) - sf_base_ptr[k_sf_idx * kNumPaddedSFPoolTokens + token_r0] = sf_r0; - if (valid_r1) - sf_base_ptr[k_sf_idx * kNumPaddedSFPoolTokens + token_r1] = sf_r1; - } - } - - ptx::sync_aligned(128, 
kEpilogueWGBarrierStartIdx + epilogue_wg_idx); - if (warp_idx_in_wg == 0 and cute::elect_one_sync()) { - const uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N; - cute::tma_store_fence(); - cute::SM90_TMA_STORE_2D::copy( - &tensor_map_l1_output, - smem_cd_l1, - out_n_idx, - m_idx + row_block_offset); - cute::tma_store_arrive(); - } - __syncwarp(); - ptx::tma_store_wait<0>(); - ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); - notify_l1_ready(pool_block_idx, n_block_idx); - } else { - constexpr uint32_t kNumRowsPerWarp = WG_BLOCK_M / 8; - #pragma unroll - for (uint32_t serial_n_idx = 0; serial_n_idx < kNumSerialN; ++serial_n_idx) { - const uint32_t serial_n_idx_base = n_block_idx * BLOCK_N + serial_n_idx * WG_BLOCK_N; - - #pragma unroll - for (uint32_t i = 0; i < kAccumPerThread / 8; ++ i) { - const uint32_t chunk_lo = 2 * i, chunk_hi = 2 * i + 1; - auto write_pair = [&](uint32_t row, uint32_t col, uint32_t packed) { - auto smem_ptr = smem_cd_l2 + row * WG_BLOCK_N + col; - *reinterpret_cast(smem_ptr) = packed; - }; - if (valid_r0) { - const uint32_t r0_lo = math::cast_into_bf16_and_pack( - final_accum[serial_n_idx][chunk_lo*4 + 0], final_accum[serial_n_idx][chunk_lo*4 + 1]); - const uint32_t r0_hi = math::cast_into_bf16_and_pack( - final_accum[serial_n_idx][chunk_hi*4 + 0], final_accum[serial_n_idx][chunk_hi*4 + 1]); - write_pair(r_0, chunk_lo * 8 + col_idx * 2, r0_lo); - write_pair(r_0, chunk_hi * 8 + col_idx * 2, r0_hi); - } - if (valid_r1) { - const uint32_t r1_lo = math::cast_into_bf16_and_pack( - final_accum[serial_n_idx][chunk_lo*4 + 2], final_accum[serial_n_idx][chunk_lo*4 + 3]); - const uint32_t r1_hi = math::cast_into_bf16_and_pack( - final_accum[serial_n_idx][chunk_hi*4 + 2], final_accum[serial_n_idx][chunk_hi*4 + 3]); - write_pair(r_1, chunk_lo * 8 + col_idx * 2, r1_lo); - write_pair(r_1, chunk_hi * 8 + col_idx * 2, r1_hi); - } - } - ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); - - const uint32_t 
row_in_warp_block = lane_idx / 16; - const uint32_t lane_in_row = lane_idx % 16; - constexpr uint32_t cols_per_lane = WG_BLOCK_N / 16; - #pragma unroll - for (uint32_t j = 0; j < kNumRowsPerWarp; ++ j) { - const uint32_t row_in_wg = warp_idx_in_wg * 16 + j * 2 + row_in_warp_block; - const uint32_t m_idx_in_block = row_block_offset + row_in_wg; - if (m_idx_in_block >= valid_m) break; - - const auto src_metadata = *workspace.get_token_src_metadata_ptr(m_idx + m_idx_in_block); - const uint32_t dst_rank_idx = src_metadata.rank_idx; - const uint32_t dst_token_idx = src_metadata.token_idx; - const uint32_t dst_topk_idx = src_metadata.topk_idx; - auto smem_ptr = smem_cd_l2 + row_in_wg * WG_BLOCK_N + lane_in_row * cols_per_lane; - const auto dst_token = combine_token_buffer.get_rank_buffer(dst_topk_idx) - .get_data_buffer(dst_token_idx); - const auto packed = *reinterpret_cast(smem_ptr); - auto dst_ptr = math::advance_ptr( - dst_token.get_base_ptr(), - serial_n_idx_base * sizeof(nv_bfloat16) + lane_in_row * sizeof(uint4)); - *sym_buffer.map(dst_ptr, dst_rank_idx) = packed; - } - ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); - } - } - return; - } - - // ---------------- GEMM ---------------- - using WGMMA = L1WGMMA; - constexpr uint32_t kAccumPerThread = WGMMA::kNumAccum; // 64 for M=64,N=128 - float final_accum[kAccumPerThread] = {}; - float accum[kAccumPerThread]; - - const unsigned long long block_gemm_start = phase_profile_clock(); - const auto run_default_gemm_loop = [&]() { -for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { - full_barriers[stage_idx]->wait(phase); - - // Read SF (must precede warpgroup_arrive) - float scale_a_0_lo, scale_a_1_lo; - float scale_a_0_hi, scale_a_1_hi; // Only used in L2 (per-64 K) - if (block_phase == sched::BlockPhase::Linear1) { - scale_a_0_lo = ptx::ld_shared(smem_sfa[stage_idx] + row_offset_r0); - scale_a_1_lo = ptx::ld_shared(smem_sfa[stage_idx] + row_offset_r1); - } else { 
- // L2: SFA layout is (K=2, M=BLOCK_M) MN-major; first half SF at offset 0, second at BLOCK_M - scale_a_0_lo = ptx::ld_shared(smem_sfa[stage_idx] + row_offset_r0); - scale_a_1_lo = ptx::ld_shared(smem_sfa[stage_idx] + row_offset_r1); - scale_a_0_hi = ptx::ld_shared(smem_sfa[stage_idx] + kL2SFAHalfStride + row_offset_r0); - scale_a_1_hi = ptx::ld_shared(smem_sfa[stage_idx] + kL2SFAHalfStride + row_offset_r1); - } - - // ----- Block (128, 128) weight SF (loaded directly from global) ----- - // L1 weight SF shape: (E, 2*IH/128, H/128) MN-major. The N axis is - // [gate(IH/128), up(IH/128)]; with the gate/up gran-8 interleave on - // the FP8 weight, each BLOCK_N=128 tile covers 64 rows of gate plus - // 64 rows of up taken from the same original 128-row block, so: - // gate_sf_n = n_block_idx / 2 - // up_sf_n = (IH/128) + n_block_idx / 2 - // - // L2 weight SF shape: (E, H/128, IH/128) MN-major. One scalar per - // (BLOCK_N, BLOCK_K) tile, broadcast across all WGMMA accumulators. - // - // NOTE: we tried hoisting these LDGs above the barrier wait and/or - // having only lane 0 load + shfl-broadcast. Both regressed on H20 - // by 7-11% across all batch sizes, presumably because (a) Hopper's - // L1 read-only cache already coalesces same-address LDGs from all - // 128 WG threads and (b) hoisting contended with the dispatch - // warps' NVLink LDGs on the MIO unit. Keep the simple parallel - // post-wait load. 
- constexpr uint32_t kL1SFKBlocks = kHidden / 128; - constexpr uint32_t kL2SFKBlocks = kIntermediateHidden / 128; - constexpr uint32_t kL1SFGateBlks = kIntermediateHidden / 128; - constexpr uint32_t kL1SFPerExpert = (kIntermediateHidden * 2 / 128) * kL1SFKBlocks; - constexpr uint32_t kL2SFPerExpert = (kHidden / 128) * kL2SFKBlocks; - float gate_sf = 0.0f, up_sf = 0.0f, l2_sf_lo = 0.0f, l2_sf_hi = 0.0f; - if (block_phase == sched::BlockPhase::Linear1) { - const uint32_t gate_n = (n_block_idx * BLOCK_N + wg_n_idx) / 256u; - const uint32_t up_n = kL1SFGateBlks + gate_n; - const float* base = l1_weights_sf + local_expert_idx * kL1SFPerExpert + k_block_idx; - gate_sf = __ldg(base + gate_n * kL1SFKBlocks); - up_sf = __ldg(base + up_n * kL1SFKBlocks); - } else { - const float* base = l2_weights_sf + local_expert_idx * kL2SFPerExpert + k_block_idx; - const uint32_t sf_n = (n_block_idx * BLOCK_N + wg_n_idx) / 128u; - l2_sf_lo = __ldg(base + sf_n * kL2SFKBlocks); - if constexpr (WG_BLOCK_N > 128) - l2_sf_hi = __ldg(base + (sf_n + 1u) * kL2SFKBlocks); - else - l2_sf_hi = l2_sf_lo; - } - - if (block_phase == sched::BlockPhase::Linear1) { - // Single per-128 K-block WGMMA group - #pragma unroll - for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); - ptx::warpgroup_arrive(); - #pragma unroll - for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - auto desc_a = mma::sm90::make_smem_desc( - smem_a[stage_idx] + row_block_offset * BLOCK_K + k * WGMMA::K, 1); - auto desc_b = mma::sm90::make_smem_desc( - smem_b[stage_idx] + wg_n_idx * BLOCK_K + k * WGMMA::K, 1); - WGMMA::wgmma(desc_a, desc_b, accum, k); - } - ptx::warpgroup_commit_batch(); - #pragma unroll - for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); - ptx::warpgroup_wait<0>(); - - arrive_empty_barrier(stage_idx); - - // L1: gate/up alternate at gran=8 along N; each `i` block of 8 - // cols belongs entirely to one of {gate, up}, so .x and .y - // share 
the same scalar. - #pragma unroll - for (uint32_t i = 0; i < kAccumPerThread / 4; ++ i) { - const float sb = (i & 1u) ? up_sf : gate_sf; - final_accum[i*4+0] += scale_a_0_lo * sb * accum[i*4+0]; - final_accum[i*4+1] += scale_a_0_lo * sb * accum[i*4+1]; - final_accum[i*4+2] += scale_a_1_lo * sb * accum[i*4+2]; - final_accum[i*4+3] += scale_a_1_lo * sb * accum[i*4+3]; - } - } else { - if constexpr (kL2DualAccum) { - float accum_hi[kAccumPerThread]; - - const auto desc_a_lo0 = mma::sm90::make_smem_desc( - smem_a[stage_idx] + row_block_offset * BLOCK_K, 1); - const auto desc_b_lo0 = mma::sm90::make_smem_desc( - smem_b[stage_idx] + wg_n_idx * BLOCK_K, 1); - const auto desc_a_lo1 = mma::sm90::make_smem_desc( - smem_a[stage_idx] + row_block_offset * BLOCK_K + WGMMA::K, 1); - const auto desc_b_lo1 = mma::sm90::make_smem_desc( - smem_b[stage_idx] + wg_n_idx * BLOCK_K + WGMMA::K, 1); - const auto desc_a_hi0 = mma::sm90::make_smem_desc( - smem_a[stage_idx] + row_block_offset * BLOCK_K + BLOCK_K / 2, 1); - const auto desc_b_hi0 = mma::sm90::make_smem_desc( - smem_b[stage_idx] + wg_n_idx * BLOCK_K + BLOCK_K / 2, 1); - const auto desc_a_hi1 = mma::sm90::make_smem_desc( - smem_a[stage_idx] + row_block_offset * BLOCK_K + BLOCK_K / 2 + WGMMA::K, 1); - const auto desc_b_hi1 = mma::sm90::make_smem_desc( - smem_b[stage_idx] + wg_n_idx * BLOCK_K + BLOCK_K / 2 + WGMMA::K, 1); - - #pragma unroll - for (uint32_t i = 0; i < kAccumPerThread; ++ i) { - ptx::warpgroup_fence_operand(accum[i]); - ptx::warpgroup_fence_operand(accum_hi[i]); - } - ptx::warpgroup_arrive(); - WGMMA::wgmma(desc_a_lo0, desc_b_lo0, accum, false); - WGMMA::wgmma(desc_a_lo1, desc_b_lo1, accum, true); - WGMMA::wgmma(desc_a_hi0, desc_b_hi0, accum_hi, false); - WGMMA::wgmma(desc_a_hi1, desc_b_hi1, accum_hi, true); - ptx::warpgroup_commit_batch(); - #pragma unroll - for (uint32_t i = 0; i < kAccumPerThread; ++ i) { - ptx::warpgroup_fence_operand(accum[i]); - ptx::warpgroup_fence_operand(accum_hi[i]); - } - 
ptx::warpgroup_wait<0>(); - - arrive_empty_barrier(stage_idx); - - #pragma unroll - for (uint32_t i = 0; i < kAccumPerThread / 4; ++ i) { - const float l2_sf = (i < 16u) ? l2_sf_lo : l2_sf_hi; - final_accum[i*4+0] += scale_a_0_lo * l2_sf * accum[i*4+0]; - final_accum[i*4+1] += scale_a_0_lo * l2_sf * accum[i*4+1]; - final_accum[i*4+2] += scale_a_1_lo * l2_sf * accum[i*4+2]; - final_accum[i*4+3] += scale_a_1_lo * l2_sf * accum[i*4+3]; - final_accum[i*4+0] += scale_a_0_hi * l2_sf * accum_hi[i*4+0]; - final_accum[i*4+1] += scale_a_0_hi * l2_sf * accum_hi[i*4+1]; - final_accum[i*4+2] += scale_a_1_hi * l2_sf * accum_hi[i*4+2]; - final_accum[i*4+3] += scale_a_1_hi * l2_sf * accum_hi[i*4+3]; - } - } else { - // L2: split BLOCK_K=128 into two halves (per-64 SFA), each 2 WGMMAs. - // First half: K=0..63, SFA = scale_a_*_lo - #pragma unroll - for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); - ptx::warpgroup_arrive(); - #pragma unroll - for (uint32_t k = 0; k < (BLOCK_K / 2) / WGMMA::K; ++ k) { - auto desc_a = mma::sm90::make_smem_desc( - smem_a[stage_idx] + row_block_offset * BLOCK_K + k * WGMMA::K, 1); - auto desc_b = mma::sm90::make_smem_desc( - smem_b[stage_idx] + wg_n_idx * BLOCK_K + k * WGMMA::K, 1); - WGMMA::wgmma(desc_a, desc_b, accum, k); - } - ptx::warpgroup_commit_batch(); - #pragma unroll - for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); - ptx::warpgroup_wait<0>(); - - // L2 weight SF is per 128 output columns; M64N256 spans two SF groups. - #pragma unroll - for (uint32_t i = 0; i < kAccumPerThread / 4; ++ i) { - const float l2_sf = (i < 16u) ? 
l2_sf_lo : l2_sf_hi; - final_accum[i*4+0] += scale_a_0_lo * l2_sf * accum[i*4+0]; - final_accum[i*4+1] += scale_a_0_lo * l2_sf * accum[i*4+1]; - final_accum[i*4+2] += scale_a_1_lo * l2_sf * accum[i*4+2]; - final_accum[i*4+3] += scale_a_1_lo * l2_sf * accum[i*4+3]; - } - - // Second half: K=64..127, SFA = scale_a_*_hi - #pragma unroll - for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); - ptx::warpgroup_arrive(); - #pragma unroll - for (uint32_t k = 0; k < (BLOCK_K / 2) / WGMMA::K; ++ k) { - const uint32_t k_off = (BLOCK_K / 2) + k * WGMMA::K; - auto desc_a = mma::sm90::make_smem_desc( - smem_a[stage_idx] + row_block_offset * BLOCK_K + k_off, 1); - auto desc_b = mma::sm90::make_smem_desc( - smem_b[stage_idx] + wg_n_idx * BLOCK_K + k_off, 1); - WGMMA::wgmma(desc_a, desc_b, accum, k); - } - ptx::warpgroup_commit_batch(); - #pragma unroll - for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); - ptx::warpgroup_wait<0>(); - - arrive_empty_barrier(stage_idx); - - // L2 second half: same SFA half, still choose weight SF by N chunk. - #pragma unroll - for (uint32_t i = 0; i < kAccumPerThread / 4; ++ i) { - const float l2_sf = (i < 16u) ? 
l2_sf_lo : l2_sf_hi; - final_accum[i*4+0] += scale_a_0_hi * l2_sf * accum[i*4+0]; - final_accum[i*4+1] += scale_a_0_hi * l2_sf * accum[i*4+1]; - final_accum[i*4+2] += scale_a_1_hi * l2_sf * accum[i*4+2]; - final_accum[i*4+3] += scale_a_1_hi * l2_sf * accum[i*4+3]; - } - } - } - } - }; - - const auto run_l1_dual_k_gemm_loop = [&]() { - DG_STATIC_ASSERT((kHidden / BLOCK_K) % 2 == 0, "L1 dual-K expects an even number of K blocks"); - constexpr uint32_t kL1SFKBlocks = kHidden / 128; - constexpr uint32_t kL1SFGateBlks = kIntermediateHidden / 128; - constexpr uint32_t kL1SFPerExpert = (kIntermediateHidden * 2 / 128) * kL1SFKBlocks; - const uint32_t gate_n = (n_block_idx * BLOCK_N + wg_n_idx) / 256u; - const uint32_t up_n = kL1SFGateBlks + gate_n; - const float* expert_sf_base = l1_weights_sf + local_expert_idx * kL1SFPerExpert; - const float* gate_sf_base = expert_sf_base + gate_n * kL1SFKBlocks; - const float* up_sf_base = expert_sf_base + up_n * kL1SFKBlocks; - float accum_b[kAccumPerThread]; - - for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks;) { - const uint32_t stage0 = stage_idx; - const uint32_t phase0 = phase; - const uint32_t k0 = k_block_idx; - full_barriers[stage0]->wait(phase0); - - const float scale_a0_r0 = ptx::ld_shared(smem_sfa[stage0] + row_offset_r0); - const float scale_a0_r1 = ptx::ld_shared(smem_sfa[stage0] + row_offset_r1); - const float2 gate_sf_pair = __ldg(reinterpret_cast(gate_sf_base + k0)); - const float2 up_sf_pair = __ldg(reinterpret_cast(up_sf_base + k0)); - const float gate_sf0 = gate_sf_pair.x; - const float up_sf0 = up_sf_pair.x; - - advance_pipeline(k_block_idx); - const uint32_t stage1 = stage_idx; - const uint32_t phase1 = phase; - full_barriers[stage1]->wait(phase1); - - const float scale_a1_r0 = ptx::ld_shared(smem_sfa[stage1] + row_offset_r0); - const float scale_a1_r1 = ptx::ld_shared(smem_sfa[stage1] + row_offset_r1); - const float gate_sf1 = gate_sf_pair.y; - const float up_sf1 = up_sf_pair.y; - - #pragma unroll - for 
(uint32_t i = 0; i < kAccumPerThread; ++ i) { - ptx::warpgroup_fence_operand(accum[i]); - ptx::warpgroup_fence_operand(accum_b[i]); - } - ptx::warpgroup_arrive(); - #pragma unroll - for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - auto desc_a = mma::sm90::make_smem_desc( - smem_a[stage0] + row_block_offset * BLOCK_K + k * WGMMA::K, 1); - auto desc_b = mma::sm90::make_smem_desc( - smem_b[stage0] + wg_n_idx * BLOCK_K + k * WGMMA::K, 1); - WGMMA::wgmma(desc_a, desc_b, accum, k); - } - #pragma unroll - for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - auto desc_a = mma::sm90::make_smem_desc( - smem_a[stage1] + row_block_offset * BLOCK_K + k * WGMMA::K, 1); - auto desc_b = mma::sm90::make_smem_desc( - smem_b[stage1] + wg_n_idx * BLOCK_K + k * WGMMA::K, 1); - WGMMA::wgmma(desc_a, desc_b, accum_b, k); - } - ptx::warpgroup_commit_batch(); - #pragma unroll - for (uint32_t i = 0; i < kAccumPerThread; ++ i) { - ptx::warpgroup_fence_operand(accum[i]); - ptx::warpgroup_fence_operand(accum_b[i]); - } - ptx::warpgroup_wait<0>(); - - arrive_empty_barrier(stage0); - arrive_empty_barrier(stage1); - - #pragma unroll - for (uint32_t i = 0; i < kAccumPerThread / 4; ++ i) { - const float sb0 = (i & 1u) ? up_sf0 : gate_sf0; - const float sb1 = (i & 1u) ? 
up_sf1 : gate_sf1; - final_accum[i*4+0] += scale_a0_r0 * sb0 * accum[i*4+0]; - final_accum[i*4+1] += scale_a0_r0 * sb0 * accum[i*4+1]; - final_accum[i*4+2] += scale_a0_r1 * sb0 * accum[i*4+2]; - final_accum[i*4+3] += scale_a0_r1 * sb0 * accum[i*4+3]; - final_accum[i*4+0] += scale_a1_r0 * sb1 * accum_b[i*4+0]; - final_accum[i*4+1] += scale_a1_r0 * sb1 * accum_b[i*4+1]; - final_accum[i*4+2] += scale_a1_r1 * sb1 * accum_b[i*4+2]; - final_accum[i*4+3] += scale_a1_r1 * sb1 * accum_b[i*4+3]; - } - - advance_pipeline(k_block_idx); - } - }; - - if constexpr (kL1DualKAccum) { - if (block_phase == sched::BlockPhase::Linear1) - run_l1_dual_k_gemm_loop(); - else - run_default_gemm_loop(); - } else { - run_default_gemm_loop(); - } - - const unsigned long long block_gemm_end = phase_profile_clock(); - if (epilogue_warp_idx == 0 and lane_idx == 0) - phase_profile_record(kProfileGemmCore, block_gemm_end - block_gemm_start); - - // Skip epilogue when block is past valid M (still must release via empty). - // A dummy cluster peer may still carry an async L1 store from the - // previous valid block, so drain it before leaving the L1 wave. - if (row_block_offset >= valid_m) { - if constexpr (kAsyncL1TMAStore) { - if (block_phase == sched::BlockPhase::Linear1) - drain_all_async_l1_stores(); - } - ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); - return; - } - - const unsigned long long block_epilogue_start = phase_profile_clock(); - if (block_phase == sched::BlockPhase::Linear1) { - // ---------------- L1 EPILOGUE: SwiGLU + FP8 quantize + TMA store ---------------- - // Layout in `final_accum`: - // 16 chunks of 8 N-cols, each chunk = 4 floats per thread = (r0c0, r0c1, r1c0, r1c1). - // Gate chunks: even (0, 2, ..., 14). Up chunks: odd (1, 3, ..., 15). - // Pair `p` ∈ [0, 8): gate chunk = 2p, up chunk = 2p+1. - // - // For each pair we produce 4 post-SwiGLU floats per thread, mapped to - // output cols (p*8 + col_idx*2 + {0,1}) for both r0 and r1. 
- - constexpr uint32_t kNumPairs = kAccumPerThread / 8; - constexpr uint32_t kNumSFGroups = WG_L1_OUT_BLOCK_N / 64; - DG_STATIC_ASSERT(WG_L1_OUT_BLOCK_N % 64 == 0, "L1 output SF is per 64 columns"); - float swiglu_r0[kNumPairs][2]; - float swiglu_r1[kNumPairs][2]; - - // Per-row amax, one scale for each 64-col L1 output group. - float amax_r0[kNumSFGroups] = {}; - float amax_r1[kNumSFGroups] = {}; - - // Compute SwiGLU + per-group amax. - #pragma unroll - for (uint32_t p = 0; p < kNumPairs; ++ p) { - const uint32_t gate = 2 * p, up = 2 * p + 1; - const uint32_t sf_group = p / 8; - - auto clamp_gate = [](float& x) { - if constexpr (kActivationClamp != cute::numeric_limits::infinity()) - x = cute::min(x, kActivationClamp); - }; - auto clamp_up = [](float& x) { - if constexpr (kActivationClamp != cute::numeric_limits::infinity()) - x = cute::min(cute::max(x, -kActivationClamp), kActivationClamp); - }; - float g_r0_c0 = final_accum[gate*4 + 0]; clamp_gate(g_r0_c0); - float g_r0_c1 = final_accum[gate*4 + 1]; clamp_gate(g_r0_c1); - float g_r1_c0 = final_accum[gate*4 + 2]; clamp_gate(g_r1_c0); - float g_r1_c1 = final_accum[gate*4 + 3]; clamp_gate(g_r1_c1); - float u_r0_c0 = final_accum[up*4 + 0]; clamp_up(u_r0_c0); - float u_r0_c1 = final_accum[up*4 + 1]; clamp_up(u_r0_c1); - float u_r1_c0 = final_accum[up*4 + 2]; clamp_up(u_r1_c0); - float u_r1_c1 = final_accum[up*4 + 3]; clamp_up(u_r1_c1); - - auto silu = [](float x) -> float { - const float e = kFastMath ? __expf(-x) : expf(-x); - const float sig = kFastMath ? 
math::fast_rcp(1.0f + e) : 1.0f / (1.0f + e); - return x * sig; - }; - - if (valid_r0) { - swiglu_r0[p][0] = silu(g_r0_c0) * u_r0_c0; - swiglu_r0[p][1] = silu(g_r0_c1) * u_r0_c1; - amax_r0[sf_group] = cute::max( - amax_r0[sf_group], - cute::max(cute::abs(swiglu_r0[p][0]), cute::abs(swiglu_r0[p][1]))); - } else { - swiglu_r0[p][0] = 0.0f; - swiglu_r0[p][1] = 0.0f; - } - if (valid_r1) { - swiglu_r1[p][0] = silu(g_r1_c0) * u_r1_c0; - swiglu_r1[p][1] = silu(g_r1_c1) * u_r1_c1; - amax_r1[sf_group] = cute::max( - amax_r1[sf_group], - cute::max(cute::abs(swiglu_r1[p][0]), cute::abs(swiglu_r1[p][1]))); - } else { - swiglu_r1[p][0] = 0.0f; - swiglu_r1[p][1] = 0.0f; - } - } - - - const float weight_r0 = valid_r0 ? *l1_topk_weights_buffer - .get_data_buffer(m_idx + row_offset_r0) - .get_base_ptr() : 0.0f; - const float weight_r1 = valid_r1 ? *l1_topk_weights_buffer - .get_data_buffer(m_idx + row_offset_r1) - .get_base_ptr() : 0.0f; - #pragma unroll - for (uint32_t p = 0; p < kNumPairs; ++ p) { - swiglu_r0[p][0] *= weight_r0; - swiglu_r0[p][1] *= weight_r0; - swiglu_r1[p][0] *= weight_r1; - swiglu_r1[p][1] *= weight_r1; - } - #pragma unroll - for (uint32_t g = 0; g < kNumSFGroups; ++ g) { - amax_r0[g] *= cute::abs(weight_r0); - amax_r1[g] *= cute::abs(weight_r1); - } - #pragma unroll - for (uint32_t g = 0; g < kNumSFGroups; ++ g) { - amax_r0[g] = math::warp_reduce<4, false>(amax_r0[g], math::ReduceMax()); - amax_r1[g] = math::warp_reduce<4, false>(amax_r1[g], math::ReduceMax()); - } - - float sf_r0[kNumSFGroups], sf_inv_r0[kNumSFGroups]; - float sf_r1[kNumSFGroups], sf_inv_r1[kNumSFGroups]; - #pragma unroll - for (uint32_t g = 0; g < kNumSFGroups; ++ g) { - float2 amax_pair = {amax_r0[g], amax_r1[g]}; - float2 sf_pair, sf_inv_pair; - math::get_e4m3_sf_and_sf_inv(amax_pair, sf_pair, sf_inv_pair); - sf_r0[g] = sf_pair.x; sf_inv_r0[g] = sf_inv_pair.x; - sf_r1[g] = sf_pair.y; sf_inv_r1[g] = sf_inv_pair.y; - } - - // Quantize and write to smem_cd_l1 (row-major, no swizzle). 
- const uint32_t l1_store_stage = kAsyncL1TMAStore ? async_l1_store_stage : 0u; - if constexpr (kAsyncL1TMAStore) - drain_async_l1_store_stage(l1_store_stage); - auto* smem_cd_l1_wg = smem_cd_l1 - + l1_store_stage * kNumEpilogueWarpgroups * WG_BLOCK_M * L1_OUT_BLOCK_N - + (kSplitNWarpgroups ? 0 : epilogue_wg_idx * WG_BLOCK_M * L1_OUT_BLOCK_N); - #pragma unroll - for (uint32_t p = 0; p < kNumPairs; ++ p) { - const uint32_t sf_group = p / 8; - const float v00 = swiglu_r0[p][0] * sf_inv_r0[sf_group]; - const float v01 = swiglu_r0[p][1] * sf_inv_r0[sf_group]; - const float v10 = swiglu_r1[p][0] * sf_inv_r1[sf_group]; - const float v11 = swiglu_r1[p][1] * sf_inv_r1[sf_group]; - - const __nv_fp8x2_e4m3 r0_pair(make_float2(v00, v01)); - const __nv_fp8x2_e4m3 r1_pair(make_float2(v10, v11)); - - const uint32_t col = p * 8 + col_idx * 2; - auto* p0 = reinterpret_cast( - smem_cd_l1_wg + r_0 * L1_OUT_BLOCK_N + wg_l1_out_n_idx + col); - auto* p1 = reinterpret_cast( - smem_cd_l1_wg + r_1 * L1_OUT_BLOCK_N + wg_l1_out_n_idx + col); - if (valid_r0) - *p0 = r0_pair.__x; - if (valid_r1) - *p1 = r1_pair.__x; - } - - // Write L2-activation SF as float, one value per 64 output columns. - if (col_idx == 0) { - auto sf_base_ptr = l2_sf_buffer.get_base_ptr(); - const uint32_t token_r0 = pool_block_idx * BLOCK_M + row_offset_r0; - const uint32_t token_r1 = pool_block_idx * BLOCK_M + row_offset_r1; - const uint32_t base_k_sf_idx = (n_block_idx * L1_OUT_BLOCK_N + wg_l1_out_n_idx) / 64u; - #pragma unroll - for (uint32_t g = 0; g < kNumSFGroups; ++ g) { - if (valid_r0) - sf_base_ptr[(base_k_sf_idx + g) * kNumPaddedSFPoolTokens + token_r0] = sf_r0[g]; - if (valid_r1) - sf_base_ptr[(base_k_sf_idx + g) * kNumPaddedSFPoolTokens + token_r1] = sf_r1[g]; - } - } - - // Issue TMA store of the entire tile. 
Padding rows beyond - // `valid_m` are written with stale/garbage FP8 to the L1-output - // pool buffer, but they are never consumed downstream: the L2 - // GEMM tile loads them, but its NVLink-scatter epilogue is - // gated by `m_idx_in_block >= valid_m`, and stale SF in the - // padding rows can produce NaN accumulators that simply stay - // in registers (only valid rows are converted to BF16 and - // STSM'd into smem). Using TMA for partial tiles is a large - // win for low-batch / decode where every tile is partial. - if constexpr (kSplitNWarpgroups) { - ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); - if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { - const uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N; - cute::tma_store_fence(); - cute::SM90_TMA_STORE_2D::copy( - &tensor_map_l1_output, - smem_cd_l1, - out_n_idx, - m_idx); - cute::tma_store_arrive(); - } - __syncwarp(); - ptx::tma_store_wait<0>(); - ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); - notify_l1_ready(pool_block_idx, n_block_idx); - } else { - ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); - if (warp_idx_in_wg == 0 and cute::elect_one_sync()) { - const uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N; - cute::tma_store_fence(); - cute::SM90_TMA_STORE_2D::copy( - &tensor_map_l1_output, - smem_cd_l1_wg, - out_n_idx, - m_idx + row_block_offset); - cute::tma_store_arrive(); - } - __syncwarp(); - if constexpr (kAsyncL1TMAStore) { - ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); - async_l1_store_pending[l1_store_stage] = true; - async_l1_store_pool[l1_store_stage] = pool_block_idx; - async_l1_store_n[l1_store_stage] = n_block_idx; - async_l1_store_stage ^= 1u; - } else { - ptx::tma_store_wait<0>(); - ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); - notify_l1_ready(pool_block_idx, n_block_idx); - } - } - const unsigned long long block_epilogue_end = phase_profile_clock(); - if 
(epilogue_warp_idx == 0 and lane_idx == 0) - phase_profile_record(kProfileL1Epilogue, block_epilogue_end - block_epilogue_start); - } else { - // ---------------- L2 EPILOGUE: BF16 cast + NVLink scatter ---------------- - constexpr uint32_t kNumRowsPerWarp = WG_BLOCK_M / 8; - - if constexpr (kDirectL2Scatter) { - DG_STATIC_ASSERT(WG_BLOCK_N == 128, "Direct L2 scatter prototype only supports N128"); - - auto scatter_direct_row = [&](const uint32_t& row_offset, const bool& valid_row, const uint32_t& row_accum_offset) { - if (valid_row) { - const auto src_metadata = *workspace.get_token_src_metadata_ptr(m_idx + row_offset); - const uint32_t dst_rank_idx = src_metadata.rank_idx; - const uint32_t dst_token_idx = src_metadata.token_idx; - const uint32_t dst_topk_idx = src_metadata.topk_idx; - const auto dst_token = combine_token_buffer.get_rank_buffer(dst_topk_idx) - .get_data_buffer(dst_token_idx); - auto dst_base = math::advance_ptr( - dst_token.get_base_ptr(), n_idx * sizeof(nv_bfloat16)); - auto mapped_dst_base = sym_buffer.map(dst_base, dst_rank_idx); - - #pragma unroll - for (uint32_t i = 0; i < kAccumPerThread / 8; ++ i) { - const uint32_t chunk_lo = 2 * i, chunk_hi = 2 * i + 1; - const uint32_t col_lo = chunk_lo * 8 + col_idx * 2; - const uint32_t col_hi = chunk_hi * 8 + col_idx * 2; - const uint32_t packed_lo = math::cast_into_bf16_and_pack( - final_accum[chunk_lo * 4 + row_accum_offset + 0], - final_accum[chunk_lo * 4 + row_accum_offset + 1]); - const uint32_t packed_hi = math::cast_into_bf16_and_pack( - final_accum[chunk_hi * 4 + row_accum_offset + 0], - final_accum[chunk_hi * 4 + row_accum_offset + 1]); - *reinterpret_cast(mapped_dst_base + col_lo * sizeof(nv_bfloat16)) = packed_lo; - *reinterpret_cast(mapped_dst_base + col_hi * sizeof(nv_bfloat16)) = packed_hi; - } - } - }; - - scatter_direct_row(row_offset_r0, valid_r0, 0); - scatter_direct_row(row_offset_r1, valid_r1, 2); - ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); - } else { - // 
STSM into smem_cd_l2 (BF16). Reuse SM100 column-swizzle layout. - #pragma unroll - for (uint32_t i = 0; i < kAccumPerThread / 8; ++ i) { - // Each i consumes 8 floats (one 16x256b chunk in SM100 terms). - // For SM90 WGMMA layout, 8 floats per i correspond to 2 chunks of 4 floats: - // final_accum[i*8 + (0..3)] = chunk 2i: (r0c0, r0c1, r1c0, r1c1) - // final_accum[i*8 + (4..7)] = chunk 2i+1: same shape - const uint32_t chunk_lo = 2 * i, chunk_hi = 2 * i + 1; - - // Write to SMEM at appropriate position - // Row r_0 cols [chunk_lo*8 + col_idx*2, chunk_lo*8 + col_idx*2 + 1] = r0_lo - // Row r_0 cols [chunk_hi*8 + col_idx*2, chunk_hi*8 + col_idx*2 + 1] = r0_hi - // Row r_1 cols [chunk_lo*8 + col_idx*2, chunk_lo*8 + col_idx*2 + 1] = r1_lo - // Row r_1 cols [chunk_hi*8 + col_idx*2, chunk_hi*8 + col_idx*2 + 1] = r1_hi - auto write_pair = [&](uint32_t row, uint32_t col, uint32_t packed) { - auto smem_ptr = smem_cd_l2 - + epilogue_wg_idx * WG_BLOCK_M * WG_BLOCK_N - + row * WG_BLOCK_N - + col; - // BF16 STS: 2 bf16 elements - *reinterpret_cast(smem_ptr) = packed; - }; - if (valid_r0) { - const uint32_t r0_lo = math::cast_into_bf16_and_pack( - final_accum[chunk_lo*4 + 0], final_accum[chunk_lo*4 + 1]); - const uint32_t r0_hi = math::cast_into_bf16_and_pack( - final_accum[chunk_hi*4 + 0], final_accum[chunk_hi*4 + 1]); - write_pair(r_0, chunk_lo * 8 + col_idx * 2, r0_lo); - write_pair(r_0, chunk_hi * 8 + col_idx * 2, r0_hi); - } - if (valid_r1) { - const uint32_t r1_lo = math::cast_into_bf16_and_pack( - final_accum[chunk_lo*4 + 2], final_accum[chunk_lo*4 + 3]); - const uint32_t r1_hi = math::cast_into_bf16_and_pack( - final_accum[chunk_hi*4 + 2], final_accum[chunk_hi*4 + 3]); - write_pair(r_1, chunk_lo * 8 + col_idx * 2, r1_lo); - write_pair(r_1, chunk_hi * 8 + col_idx * 2, r1_hi); - } - } - - ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); - - // Scatter to remote ranks via NVLink (one row per warp-pair) - // Each warpgroup-warp covers 8 unique rows × 2 
(r_0 + r_1 doubled by warps) - // Lane group of 16 within a warp → 1 row. - const uint32_t row_in_warp_block = lane_idx / 16; // 0 or 1 - const uint32_t lane_in_row = lane_idx % 16; - const uint32_t cols_per_lane = WG_BLOCK_N / 16; - static_assert(WG_BLOCK_N == 64 or WG_BLOCK_N == 128 or WG_BLOCK_N == 256, - "L2 scatter supports per-WG N64/N128/N256"); - - #pragma unroll - for (uint32_t j = 0; j < kNumRowsPerWarp; ++ j) { - const uint32_t row_in_wg = warp_idx_in_wg * 16 + j * 2 + row_in_warp_block; - const uint32_t m_idx_in_block = row_block_offset + row_in_wg; - if (m_idx_in_block >= valid_m) break; - - const auto src_metadata = *workspace.get_token_src_metadata_ptr(m_idx + m_idx_in_block); - const uint32_t dst_rank_idx = src_metadata.rank_idx; - const uint32_t dst_token_idx = src_metadata.token_idx; - const uint32_t dst_topk_idx = src_metadata.topk_idx; - - auto smem_ptr = smem_cd_l2 - + epilogue_wg_idx * WG_BLOCK_M * WG_BLOCK_N - + row_in_wg * WG_BLOCK_N - + lane_in_row * cols_per_lane; - const auto dst_token = combine_token_buffer.get_rank_buffer(dst_topk_idx) - .get_data_buffer(dst_token_idx); - - if constexpr (WG_BLOCK_N == 256) { - const auto packed0 = *reinterpret_cast(smem_ptr); - const auto packed1 = *(reinterpret_cast(smem_ptr) + 1); - auto dst_ptr = math::advance_ptr( - dst_token.get_base_ptr(), - n_idx * sizeof(nv_bfloat16) + lane_in_row * 2u * sizeof(uint4)); - auto mapped_dst_ptr = sym_buffer.map(dst_ptr, dst_rank_idx); - mapped_dst_ptr[0] = packed0; - mapped_dst_ptr[1] = packed1; - } else if constexpr (WG_BLOCK_N == 128) { - const auto packed = *reinterpret_cast(smem_ptr); - auto dst_ptr = math::advance_ptr( - dst_token.get_base_ptr(), - n_idx * sizeof(nv_bfloat16) + lane_in_row * sizeof(uint4)); - *sym_buffer.map(dst_ptr, dst_rank_idx) = packed; - } else { - const auto packed = *reinterpret_cast(smem_ptr); - auto dst_ptr = math::advance_ptr( - dst_token.get_base_ptr(), - n_idx * sizeof(nv_bfloat16) + lane_in_row * sizeof(uint2)); - 
*sym_buffer.map(dst_ptr, dst_rank_idx) = packed; - } - } - - ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); - } - const unsigned long long block_epilogue_end = phase_profile_clock(); - if (epilogue_warp_idx == 0 and lane_idx == 0) - phase_profile_record(kProfileL2Epilogue, block_epilogue_end - block_epilogue_start); - } - }); - const unsigned long long math_loop_end = phase_profile_clock(); - if (epilogue_warp_idx == 0 and lane_idx == 0) - phase_profile_record(kProfileMathLoop, math_loop_end - math_loop_start); - - // ---------------- COMBINE ---------------- - // NVLink barrier first: signals remote ranks that this rank's GEMM - // outputs (NVLink scatter targets) are fully written. - const unsigned long long combine_barrier_start = phase_profile_clock(); - comm::nvlink_barrier( - workspace, sym_buffer, sm_idx, epilogue_thread_idx, - [&]() { ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); } - ); - const unsigned long long combine_barrier_end = phase_profile_clock(); - if (epilogue_warp_idx == 0 and lane_idx == 0) - phase_profile_record(kProfileCombineBarrier, combine_barrier_end - combine_barrier_start); - - // Sync with dispatch (paired with dispatch's pre-cleanup sync) so that - // dispatch may now safely clean workspace state. - ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); - const unsigned long long combine_reduce_start = phase_profile_clock(); - - constexpr uint32_t kNumHiddenBytes = kHidden * sizeof(nv_bfloat16); - constexpr uint32_t kNumElemsPerUint4 = sizeof(uint4) / sizeof(nv_bfloat162); - - constexpr uint32_t kNumChunkSlots = 3; - constexpr uint32_t kNumMaxRegistersForBuffer = 128; - constexpr uint32_t kNumChunks = - (kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes <= SMEM_BEFORE_BARRIER_SIZE - and kHidden <= 32 * kNumMaxRegistersForBuffer) ? 
1 : 2; - constexpr uint32_t kNumChunkBytes = kNumHiddenBytes / kNumChunks; - constexpr uint32_t kNumChunkUint4 = kNumChunkBytes / sizeof(uint4); - constexpr uint32_t kNumUint4PerLane = kNumChunkUint4 / 32; - DG_STATIC_ASSERT(kHidden % kNumChunks == 0, "Hidden must be divisible by number of chunks"); - DG_STATIC_ASSERT(kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes / kNumChunks <= SMEM_BEFORE_BARRIER_SIZE, "Hidden is too large"); - DG_STATIC_ASSERT(kNumChunkBytes % 16 == 0, "Combine chunk must be TMA-aligned (16 bytes)"); - DG_STATIC_ASSERT(kNumChunkBytes % sizeof(uint4) == 0, "Combine chunk must be divisible by 16 bytes"); - DG_STATIC_ASSERT(kNumChunkUint4 % 32 == 0, "Combine chunk must be a multiple of 32 16-byte elements"); - DG_STATIC_ASSERT(kNumTopk <= 32, "Top-k must fit in a single warp"); - - DG_DEVICE_ASSERT(kNumChunkSlots * kNumEpilogueWarps * kNumChunkBytes <= static_cast( - reinterpret_cast(barrier_start_ptr) - smem_buffer)); - - const auto combine_load_buffer = utils::PatternVisitor([&](const uint32_t& i) { - return math::advance_ptr(smem_buffer, (epilogue_warp_idx + i * kNumEpilogueWarps) * kNumChunkBytes); - }); - const auto combine_store_buffer = math::advance_ptr( - smem_buffer, (epilogue_warp_idx + kNumEpilogueWarps * 2) * kNumChunkBytes); - - auto combine_load_barriers = utils::PatternVisitor([&](const uint32_t& i) { - return combine_barriers[i + epilogue_warp_idx * 2]; - }); - - uint32_t combine_phase = 0; - uint32_t load_stage_idx = 0; - for (uint32_t token_idx = sm_idx * kNumEpilogueWarps + epilogue_warp_idx; - token_idx < num_tokens; - token_idx += kNumSMs * kNumEpilogueWarps) { - const int stored_topk_slot_idx = lane_idx < kNumTopk ? 
- static_cast(__ldg(input_topk_idx_buffer.get_base_ptr() + token_idx * kNumTopk + lane_idx)) : -1; - const uint32_t total_mask = __ballot_sync(0xffffffff, stored_topk_slot_idx >= 0); - - for (uint32_t chunk = 0; chunk < kNumChunks; ++ chunk) { - const uint32_t chunk_byte_offset = chunk * kNumChunkBytes; - - uint32_t mask = total_mask; - const auto move_mask_and_load = [&](const uint32_t& i) { - if (mask) { - const uint32_t slot_idx = __ffs(mask) - 1; - mask ^= 1 << slot_idx; - if (cute::elect_one_sync()) { - const auto src_ptr = math::advance_ptr( - combine_token_buffer.get_rank_buffer(slot_idx) - .get_data_buffer(token_idx).get_base_ptr(), - chunk_byte_offset); - ptx::tma_load_1d(combine_load_buffer[i], src_ptr, combine_load_barriers[i], kNumChunkBytes); - ptx::mbarrier_arrive_and_set_tx(combine_load_barriers[i], kNumChunkBytes); - } - __syncwarp(); - return true; - } - return false; - }; - - bool do_reduce = move_mask_and_load(load_stage_idx); - - float2 reduced[kNumUint4PerLane * kNumElemsPerUint4] = {}; - while (do_reduce) { - do_reduce = move_mask_and_load(load_stage_idx ^ 1); - combine_load_barriers[load_stage_idx]->wait(combine_phase); - #pragma unroll - for (uint32_t j = 0; j < kNumUint4PerLane; ++ j) { - const auto uint4_values = combine_load_buffer[load_stage_idx][j * 32 + lane_idx]; - const auto bf16_values = reinterpret_cast(&uint4_values); - #pragma unroll - for (uint32_t l = 0; l < kNumElemsPerUint4; ++ l) - ptx::accumulate(reduced[j * kNumElemsPerUint4 + l], bf16_values[l]); - } - combine_phase ^= load_stage_idx; - load_stage_idx ^= 1; - } - - #pragma unroll - for (uint32_t j = 0; j < kNumUint4PerLane; ++ j) { - uint4 casted; - auto casted_bf16 = reinterpret_cast(&casted); - #pragma unroll - for (uint32_t l = 0; l < kNumElemsPerUint4; ++ l) - casted_bf16[l] = __float22bfloat162_rn(reduced[j * kNumElemsPerUint4 + l]); - - if (j == 0) { - ptx::tma_store_wait<0>(); - __syncwarp(); - } - ptx::st_shared(combine_store_buffer + j * 32 + lane_idx, - 
casted.x, casted.y, casted.z, casted.w); - } - __syncwarp(); - - if (cute::elect_one_sync()) { - cute::tma_store_fence(); - ptx::tma_store_1d( - math::advance_ptr(y, static_cast(token_idx) * kNumHiddenBytes + chunk_byte_offset), - combine_store_buffer, kNumChunkBytes); - cute::tma_store_arrive(); - } - __syncwarp(); - } - } - const unsigned long long combine_reduce_end = phase_profile_clock(); - if (epilogue_warp_idx == 0 and lane_idx == 0) - phase_profile_record(kProfileCombineReduce, combine_reduce_end - combine_reduce_start); - } -#else - if (blockIdx.x == 0 and threadIdx.x == 0) - DG_DEVICE_ASSERT(false and "This kernel only supports sm_90"); -#endif -} - -} // namespace deep_gemm - -#pragma clang diagnostic pop diff --git a/deep_gemm/include/deep_gemm/quantization/mxfp4_dequant.cuh b/deep_gemm/include/deep_gemm/quantization/mxfp4_dequant.cuh deleted file mode 100644 index 503b5b872d..0000000000 --- a/deep_gemm/include/deep_gemm/quantization/mxfp4_dequant.cuh +++ /dev/null @@ -1,58 +0,0 @@ -// SPDX-License-Identifier: MIT -// -// MXFP4 (E2M1) → FP8 (E4M3) dequant helper for SM90 W4A8 fused MegaMoE. -// -// Ported from vLLM Marlin's `dequant<__nv_fp8x4_e4m3, kFE2M1f, true>` in -// `csrc/quantization/marlin/dequant.h` (Marlin / vLLM Apache 2.0). The bit -// pattern conversion is intentionally identical so that weight files prepared -// with Marlin / TensorRT-LLM `mxfp4_moe` preprocessing can be reused. - -#pragma once - -#include -#include - -namespace deep_gemm { -namespace w4a8 { - -#define DG_W4A8_INLINE __device__ __forceinline__ - -// Convert one packed-FP4 dword (8 nibbles = 8 FP4 values) into eight FP8 E4M3 -// bytes laid out as two `__nv_fp8x4_e4m3` fragments. -// -// IMPORTANT: this routine performs the bit-pattern conversion only; the per-32 -// E8M0 group scale must be applied separately on the WGMMA accumulator -// (Marlin-style), not folded into the FP8 exponent (which is the Humming -// alternative we deliberately do not use here). 
-// -// Layout note inherited from Marlin: the upper half of `q` (q << 4 step) -// produces `frag_b[0]` and the lower half produces `frag_b[1]`. This reverse -// indexing matches the way Marlin permutes the packed-FP4 weight tile so the -// FP8 register layout aligns with the WGMMA m64n*k32 B-fragment layout. -DG_W4A8_INLINE void dequant_mxfp4_to_fp8(int q, __nv_fp8x4_e4m3* frag_b) { - constexpr int FP4_EXPONENT = 2; - constexpr int FP8_EXPONENT = 4; - constexpr int RIGHT_SHIFT = FP8_EXPONENT - FP4_EXPONENT; // = 2 - constexpr int MASK = 0x70707070; - - int Out1 = (q & 0x80808080) | ((q & MASK) >> RIGHT_SHIFT); - q <<= 4; - int Out2 = (q & 0x80808080) | ((q & MASK) >> RIGHT_SHIFT); - - frag_b[1] = *reinterpret_cast(&Out1); - frag_b[0] = *reinterpret_cast(&Out2); -} - -// Convert an E8M0 byte (unsigned 8-bit exponent, bias = 127) into a float -// multiplicative scale. E8M0 stores `2 ** (e - 127)`. Used to apply the -// MXFP4 per-32 group scale to a WGMMA accumulator. -DG_W4A8_INLINE float e8m0_to_float(std::uint8_t e8m0_byte) { - // 2 ** (e - 127). Build via float bit pattern: sign=0, exponent=e, mantissa=0. 
- std::uint32_t bits = static_cast(e8m0_byte) << 23; - return *reinterpret_cast(&bits); -} - -#undef DG_W4A8_INLINE - -} // namespace w4a8 -} // namespace deep_gemm diff --git a/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh b/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh index 14206464a7..dd3b077df6 100644 --- a/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh @@ -25,10 +25,6 @@ template (math::align(current_local_expert_idx + 1, kNumExpertsPerWave), kExpertRangeEnd); + return math::align(current_local_expert_idx + 1, kNumExpertsPerWave); } CUTLASS_DEVICE uint32_t get_num_tokens(const uint32_t& expert_idx) const { @@ -106,47 +92,8 @@ struct MegaMoEScheduler { return __reduce_add_sync(0xffffffff, num_blocks); } - template - CUTLASS_DEVICE uint32_t get_packed_l2_token_offset(const uint32_t& expert_idx) { - uint32_t num_tokens = 0; - #pragma unroll - for (uint32_t i = 0; i < kNumExpertsPerLane; ++ i) { - if (i * 32 + ptx::get_lane_idx() < expert_idx) - num_tokens += math::align(stored_num_tokens_per_expert[i], kPackedAlignment); - } - return __reduce_add_sync(0xffffffff, num_tokens); - } - - template - CUTLASS_DEVICE uint32_t get_packed_l2_block_offset(const uint32_t& expert_idx) { - DG_STATIC_ASSERT(kPackedAlignment % BLOCK_M == 0, "Packed L2 alignment must be a multiple of BLOCK_M"); - return get_packed_l2_token_offset(expert_idx) / BLOCK_M; - } - - template - CUTLASS_DEVICE uint32_t get_aligned_pool_token_offset(const uint32_t& expert_idx) { - uint32_t num_tokens = 0; - #pragma unroll - for (uint32_t i = 0; i < kNumExpertsPerLane; ++ i) { - if (i * 32 + ptx::get_lane_idx() < expert_idx) - num_tokens += math::ceil_div(stored_num_tokens_per_expert[i], kPoolBlockM) * kPoolBlockM; - } - return __reduce_add_sync(0xffffffff, num_tokens); - } - - CUTLASS_DEVICE uint32_t get_token_offset(const uint32_t& expert_idx) { - uint32_t num_tokens = 0; - #pragma unroll - for (uint32_t i = 0; i < 
kNumExpertsPerLane; ++ i) { - if (i * 32 + ptx::get_lane_idx() < expert_idx) - num_tokens += stored_num_tokens_per_expert[i]; - } - return __reduce_add_sync(0xffffffff, num_tokens); - } - CUTLASS_DEVICE void advance_expert_idx() { current_pool_block_offset += get_current_num_m_blocks(); - current_token_offset += current_num_tokens; current_local_expert_idx += 1; current_num_tokens = get_num_tokens(current_local_expert_idx); } @@ -155,17 +102,12 @@ struct MegaMoEScheduler { current_local_expert_idx = expert_idx; current_num_tokens = get_num_tokens(expert_idx); current_pool_block_offset = get_pool_block_offset(expert_idx); - current_token_offset = get_token_offset(expert_idx); } CUTLASS_DEVICE uint32_t get_current_pool_block_offset() const { return current_pool_block_offset; } - CUTLASS_DEVICE uint32_t get_current_token_offset() const { - return current_token_offset; - } - CUTLASS_DEVICE uint32_t get_current_num_m_blocks() const { return math::ceil_div(current_num_tokens, BLOCK_M); } @@ -179,134 +121,26 @@ struct MegaMoEScheduler { return kDoUMMAAligned ? 
math::align(m, 16u) : m; } - template - CUTLASS_DEVICE void map_expert_block_idx(const uint32_t& local_block_idx, - const uint32_t& num_m_units, - uint32_t& m_unit_idx, - uint32_t& n_idx) const { - if constexpr (kMSwizzleGroup > 0) { - const auto group_start = (local_block_idx / (kMSwizzleGroup * kNumBlockNs)) * kMSwizzleGroup; - const auto in_group_idx = local_block_idx % (kMSwizzleGroup * kNumBlockNs); - const auto num_m_in_group = cute::min(kMSwizzleGroup, num_m_units - group_start); - n_idx = in_group_idx / num_m_in_group; - m_unit_idx = group_start + in_group_idx - n_idx * num_m_in_group; - } else if constexpr (kNMajorSchedule) { - n_idx = local_block_idx / num_m_units; - m_unit_idx = local_block_idx - n_idx * num_m_units; - } else { - m_unit_idx = local_block_idx / kNumBlockNs; - n_idx = local_block_idx % kNumBlockNs; - } - } - - CUTLASS_DEVICE void update_peer_cta_alive(const uint32_t& num_m_blocks) { - if constexpr (kClusterSize == 1) { - is_peer_cta_alive = true; - is_a_multicast_valid = true; - is_b_multicast_valid = true; - } else { - const auto peer_m_block_idx = (m_block_idx ^ 1u); - is_peer_cta_alive = m_block_idx < num_m_blocks and - peer_m_block_idx < num_m_blocks; - is_a_multicast_valid = false; - is_b_multicast_valid = is_peer_cta_alive; - } - } - - template - CUTLASS_DEVICE void map_cluster_bcast_b_block_idx(const uint32_t& local_block_idx, - const uint32_t& num_m_blocks, - uint32_t& m_idx, - uint32_t& n_idx) { - DG_STATIC_ASSERT(kClusterSize == 2, "Cluster B multicast mapping requires 2 CTAs"); - DG_STATIC_ASSERT(kMSwizzleGroup == 0 or kMSwizzleGroup == 4 or kMSwizzleGroup == 8 or kMSwizzleGroup == 16, - "Invalid cluster B multicast M-swizzle group"); - - if constexpr (kMSwizzleGroup == 0) { - const auto even_m_blocks = num_m_blocks & ~1u; - const auto even_region_blocks = even_m_blocks * kNumBlockNs; - if (local_block_idx < even_region_blocks) { - n_idx = local_block_idx / even_m_blocks; - m_idx = local_block_idx - n_idx * even_m_blocks; - 
is_peer_cta_alive = true; - is_a_multicast_valid = false; - is_b_multicast_valid = true; - } else { - const auto tail_idx = local_block_idx - even_region_blocks; - m_idx = even_m_blocks; - n_idx = tail_idx; - // Odd-M tails pair CTAs across adjacent N blocks. The peer CTA is - // still alive for remote empty-barrier arrivals, but it does not - // consume the same B tile, so B multicast must be disabled. - is_peer_cta_alive = (tail_idx ^ 1u) < kNumBlockNs; - is_a_multicast_valid = false; - is_b_multicast_valid = false; - } - } else { - const auto num_blocks_per_group = kMSwizzleGroup * kNumBlockNs; - const auto group_idx = local_block_idx / num_blocks_per_group; - const auto first_m_idx = group_idx * kMSwizzleGroup; - auto in_group_idx = local_block_idx - group_idx * num_blocks_per_group; - auto num_m_in_group = cute::min(kMSwizzleGroup, num_m_blocks - first_m_idx); - - if (num_m_in_group % 2 != 0) { - const auto even_m_in_group = num_m_in_group ^ 1u; - const auto even_region_blocks = even_m_in_group * kNumBlockNs; - if (in_group_idx < even_region_blocks) { - num_m_in_group = even_m_in_group; - } else { - in_group_idx -= even_region_blocks; - m_idx = first_m_idx + even_m_in_group; - n_idx = in_group_idx; - is_peer_cta_alive = (in_group_idx ^ 1u) < kNumBlockNs; - is_a_multicast_valid = false; - is_b_multicast_valid = false; - return; - } - } - - m_idx = first_m_idx + in_group_idx % num_m_in_group; - n_idx = in_group_idx / num_m_in_group; - is_peer_cta_alive = true; - is_a_multicast_valid = false; - is_b_multicast_valid = true; - } - } - - template - CUTLASS_DEVICE void map_cluster_bcast_a_block_idx(const uint32_t& local_block_idx, - const uint32_t& num_m_blocks, - uint32_t& m_idx, - uint32_t& n_idx) { - (void)num_m_blocks; - DG_STATIC_ASSERT(kClusterSize == 2, "Cluster A multicast mapping requires 2 CTAs"); - DG_STATIC_ASSERT(kNumBlockNs % 2 == 0, "N block count must be even for 2-CTA A multicast"); - m_idx = local_block_idx / kNumBlockNs; - n_idx = local_block_idx - 
m_idx * kNumBlockNs; - is_peer_cta_alive = true; - is_a_multicast_valid = true; - is_b_multicast_valid = false; - } - CUTLASS_DEVICE bool fetch_next_l1_block() { const auto wave_end_expert_idx = get_wave_expert_end_idx(); while (current_local_expert_idx < wave_end_expert_idx) { const auto num_m_blocks = get_current_num_m_blocks(); - const auto num_blocks = num_m_blocks * kNumL1BlockNs; - if (block_idx < num_blocks) { - if constexpr (kClusterSize == 1) { - map_expert_block_idx( - block_idx, num_m_blocks, m_block_idx, n_block_idx); - update_peer_cta_alive(num_m_blocks); + const auto num_m_units = math::ceil_div(num_m_blocks, kClusterSize); + if (block_idx < num_m_units * kNumL1BlockNs) { + if constexpr (kL1NMajorSchedule) { + n_block_idx = block_idx / num_m_units; + const auto m_unit_idx = block_idx - n_block_idx * num_m_units; + m_block_idx = m_unit_idx * kClusterSize + cute::block_rank_in_cluster(); } else { - map_cluster_bcast_b_block_idx( - block_idx, num_m_blocks, m_block_idx, n_block_idx); + const auto m_unit_idx = block_idx / kNumL1BlockNs; + m_block_idx = m_unit_idx * kClusterSize + cute::block_rank_in_cluster(); + n_block_idx = block_idx % kNumL1BlockNs; } return true; } // Current expert is fully assigned, move to the next - block_idx -= num_blocks; + block_idx -= num_m_units * kNumL1BlockNs; advance_expert_idx(); } return false; @@ -316,26 +150,22 @@ struct MegaMoEScheduler { const auto wave_end_expert_idx = get_wave_expert_end_idx(); while (current_local_expert_idx < wave_end_expert_idx) { const auto num_m_blocks = get_current_num_m_blocks(); - const auto num_blocks = num_m_blocks * kNumL2BlockNs; - if (block_idx < num_blocks) { - if constexpr (kClusterSize == 1) { - map_expert_block_idx( - block_idx, num_m_blocks, m_block_idx, n_block_idx); - update_peer_cta_alive(num_m_blocks); + const auto num_m_units = math::ceil_div(num_m_blocks, kClusterSize); + if (block_idx < num_m_units * kNumL2BlockNs) { + if constexpr (kL2NMajorSchedule) { + n_block_idx = 
block_idx / num_m_units; + const auto m_unit_idx = block_idx - n_block_idx * num_m_units; + m_block_idx = m_unit_idx * kClusterSize + cute::block_rank_in_cluster(); } else { - if constexpr (kL2NMajorSchedule) { - map_cluster_bcast_a_block_idx( - block_idx, num_m_blocks, m_block_idx, n_block_idx); - } else { - map_cluster_bcast_b_block_idx( - block_idx, num_m_blocks, m_block_idx, n_block_idx); - } + const auto m_unit_idx = block_idx / kNumL2BlockNs; + m_block_idx = m_unit_idx * kClusterSize + cute::block_rank_in_cluster(); + n_block_idx = block_idx % kNumL2BlockNs; } return true; } // Current expert is fully assigned, move to the next - block_idx -= num_blocks; + block_idx -= num_m_units * kNumL2BlockNs; advance_expert_idx(); } return false; @@ -345,14 +175,14 @@ struct MegaMoEScheduler { // Core state machine: assigns the next block CUTLASS_DEVICE cute::tuple get_next_block() { while (true) { - if (current_local_expert_idx >= kExpertRangeEnd) + if (current_local_expert_idx >= kNumExpertsPerRank) break; if (next_phase == BlockPhase::Linear1) { if (fetch_next_l1_block()) { // Found a new L1 block - // Jump to next persistent CTA work item - block_idx += kNumSMs; + // Jump to next cluster-scheduled work unit + block_idx += kNumSMs / kClusterSize; return {BlockPhase::Linear1, current_local_expert_idx, m_block_idx, n_block_idx}; } else { // L1 for the current wave is complete, transition to L2 @@ -362,8 +192,8 @@ struct MegaMoEScheduler { } else { if (fetch_next_l2_block()) { // Found a new L2 block - // Jump to next persistent CTA work item - block_idx += kNumSMs; + // Jump to next cluster-scheduled work unit + block_idx += kNumSMs / kClusterSize; return {BlockPhase::Linear2, current_local_expert_idx, m_block_idx, n_block_idx}; } else { // Move to L1 of the next wave @@ -392,67 +222,13 @@ struct MegaMoEScheduler { __syncwarp(); } - CUTLASS_DEVICE void fetch_packed_l2_metadata_count() { - // Split K2 starts after split K1 globally completes. 
When K1 built the - // metadata prefix, the per-expert actual-M values are already final, so - // K2 can initialize scheduler counts without volatile polling. - #pragma unroll - for (uint32_t i = 0; i < kNumExpertsPerLane; ++ i) { - const auto expert_idx = i * 32 + ptx::get_lane_idx(); - uint32_t value = 0; - if (expert_idx < kNumExpertsPerRank) - value = *workspace.get_packed_l2_expert_m_ptr(expert_idx); - stored_num_tokens_per_expert[i] = value; - } - __syncwarp(); - } - - template - CUTLASS_DEVICE void build_packed_l2_metadata_prefix(const bool& is_leader_thread) const { - DG_STATIC_ASSERT(kPackedAlignment == 128, "Only 128-row packed L2 metadata is currently supported"); - if (blockIdx.x != 0 or !is_leader_thread) - return; - - uint32_t packed_offset = 0; - #pragma unroll - for (uint32_t expert_idx = 0; expert_idx < kNumExpertsPerRank; ++expert_idx) { - uint64_t value = 0; - do { - value = ptx::ld_volatile(workspace.get_expert_recv_count_sum_ptr(expert_idx)); - } while (static_cast(value >> 32) != kNumSMs * kNumRanks); - - const uint32_t actual_m = static_cast(value); - *workspace.get_packed_l2_expert_start_ptr(expert_idx) = packed_offset; - *workspace.get_packed_l2_expert_m_ptr(expert_idx) = actual_m; - packed_offset += math::align(actual_m, kPackedAlignment); - } - *workspace.get_packed_l2_expert_start_ptr(kNumExpertsPerRank) = packed_offset; - } - - template - CUTLASS_DEVICE void fill_packed_l2_metadata_rows(const uint32_t& thread_idx, - const uint32_t& num_threads) const { - DG_STATIC_ASSERT(kPackedAlignment == 128, "Only 128-row packed L2 metadata is currently supported"); - for (uint32_t expert_idx = blockIdx.x; expert_idx < kNumExpertsPerRank; expert_idx += kNumSMs) { - const uint32_t packed_start = *workspace.get_packed_l2_expert_start_ptr(expert_idx); - const uint32_t packed_end = *workspace.get_packed_l2_expert_start_ptr(expert_idx + 1); - const uint32_t actual_m = *workspace.get_packed_l2_expert_m_ptr(expert_idx); - const uint32_t packed_m = 
packed_end - packed_start; - - for (uint32_t row = thread_idx; row < packed_m; row += num_threads) { - *workspace.get_packed_l2_row_to_expert_ptr(packed_start + row) = - row < actual_m ? static_cast(expert_idx) : -1; - } - } - } - template CUTLASS_DEVICE void for_each_block(Func&& func) { // Wait for all expert counters to be finalized fetch_expert_recv_count(); // Initialize current expert with 0 - set_expert_idx(kExpertRangeStart); + set_expert_idx(0); // Iterate over all blocks // TODO: add swizzle within expert waves for better L2 cache utilization @@ -472,32 +248,29 @@ struct MegaMoEScheduler { // Split-kernel mode: K1 owns only dispatch + Linear1. Unlike // for_each_block(), do not burn scheduler iterations on Linear2 blocks. fetch_expert_recv_count(); - set_expert_idx(kExpertRangeStart); - while (current_local_expert_idx < kExpertRangeEnd) { + set_expert_idx(0); + while (current_local_expert_idx < kNumExpertsPerRank) { if (fetch_next_l1_block()) { - block_idx += kNumSMs; + block_idx += kNumSMs / kClusterSize; func(current_local_expert_idx, kNumL1BlockKs, m_block_idx, n_block_idx); - } else if (current_local_expert_idx >= kExpertRangeEnd) { + } else if (current_local_expert_idx >= kNumExpertsPerRank) { break; } } } - template + template CUTLASS_DEVICE void for_each_linear2_block(Func&& func) { // Split-kernel mode: K2 starts after K1 has completed globally, so all // L2-ready masks are already final. Schedule Linear2 blocks directly // from the phase start, like a standalone grouped GEMM. 
- if constexpr (kUsePackedL2MetadataCounts) - fetch_packed_l2_metadata_count(); - else - fetch_expert_recv_count(); - set_expert_idx(kExpertRangeStart); - while (current_local_expert_idx < kExpertRangeEnd) { + fetch_expert_recv_count(); + set_expert_idx(0); + while (current_local_expert_idx < kNumExpertsPerRank) { if (fetch_next_l2_block()) { - block_idx += kNumSMs; + block_idx += kNumSMs / kClusterSize; func(current_local_expert_idx, kNumL2BlockKs, m_block_idx, n_block_idx); - } else if (current_local_expert_idx >= kExpertRangeEnd) { + } else if (current_local_expert_idx >= kNumExpertsPerRank) { break; } } diff --git a/deep_gemm/mega/__init__.py b/deep_gemm/mega/__init__.py index 6f2295c748..e2c5e026b2 100644 --- a/deep_gemm/mega/__init__.py +++ b/deep_gemm/mega/__init__.py @@ -183,31 +183,3 @@ def fp8_mega_moe(y: torch.Tensor, fast_math ) - -def w4a8_mega_moe(y: torch.Tensor, - l1_weights: Tuple[torch.Tensor, torch.Tensor], - l2_weights: Tuple[torch.Tensor, torch.Tensor], - sym_buffer: SymmBuffer, - cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None, - recipe: Tuple[int, int, int] = (128, 128, 128), - activation: str = 'swiglu', - activation_clamp: Optional[float] = None, - fast_math: bool = True): - """SM90 (Hopper) W4A8 MegaMoE entry — Phase 0 scaffold. - - Accepts the same FP8 weights + per-128 float SF as ``fp8_mega_moe`` and - dispatches to the W4A8 runtime (currently a verbatim FP8 copy under a - different name; subsequent phases will switch to packed MXFP4 + E8M0). 
- """ - _C.w4a8_mega_moe( - y, - l1_weights, l2_weights, - cumulative_local_expert_recv_stats, - sym_buffer.buffer, - sym_buffer.handle.buffer_ptrs, sym_buffer.group.rank(), - sym_buffer.num_max_tokens_per_rank, - sym_buffer.num_experts, sym_buffer.num_topk, - recipe, - activation, activation_clamp, - fast_math - ) diff --git a/tests/bench_w4a8_mega_moe_sm90.py b/tests/bench_w4a8_mega_moe_sm90.py deleted file mode 100644 index 0199fb27d3..0000000000 --- a/tests/bench_w4a8_mega_moe_sm90.py +++ /dev/null @@ -1,250 +0,0 @@ -"""SM90 (Hopper) MegaMoE benchmark / NCU-profile harness. - -Mirrors ``tests/test_mega_moe.py``'s ``--ncu-profile-only`` / -``--local-rank-idx`` interface so the same ``scripts/run_ncu_mega_moe.sh`` -pattern can drive it for SM90. - -In normal (non-NCU) mode it sweeps a list of ``num_tokens`` values (default: -1, 2, 4, 8, 16, 32) and reports per-call kernel time via the same -``bench_kineto`` helper used by the SM100 perf test, plus a rough TFLOPS / -HBM GB/s figure useful for tracking optimisation deltas. 
-""" - -import argparse -import os -import random -import sys -import torch -import torch.distributed as dist -from typing import Tuple - -REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -if REPO_ROOT not in sys.path: - sys.path.insert(0, REPO_ROOT) - -import deep_gemm -from deep_gemm.utils import per_token_cast_to_fp8 -from deep_gemm.utils.dist import dist_print, init_dist, uneven_all_gather -from deep_gemm.testing import bench_kineto, calc_diff, get_arch_major - - -def _quantize_grouped_fp8_block_128_128(w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - g, n, k = w.shape - assert n % 128 == 0 and k % 128 == 0 - w_view = w.view(g, n // 128, 128, k // 128, 128).float() - amax = w_view.abs().amax(dim=(-1, -3)).clamp(1e-4) - sf = amax / 448.0 - w_fp8 = (w_view / sf.unsqueeze(-1).unsqueeze(-3)).to(torch.float8_e4m3fn) - return w_fp8.view(g, n, k).contiguous(), sf.contiguous() - - -def _run_one_config(args, num_tokens, num_max_tokens_per_rank, - hidden, intermediate_hidden, - num_experts, num_topk, num_ranks, rank_idx, group, - activation_clamp, fast_math, - print_perf=True): - num_experts_per_rank = num_experts // num_ranks - assert num_tokens <= num_max_tokens_per_rank - - # Symmetric buffer (one per config: cheaper to recreate than to keep max-size) - buffer = deep_gemm.get_symm_buffer_for_mega_moe( - group, num_experts, - num_max_tokens_per_rank, num_topk, - hidden, intermediate_hidden, - ) - - # Inputs (bf16, then quantised) - x_bf = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') - l1_bf = torch.randn( - (num_experts_per_rank, intermediate_hidden * 2, hidden), - dtype=torch.bfloat16, device='cuda') * 0.05 - l2_bf = torch.randn( - (num_experts_per_rank, hidden, intermediate_hidden), - dtype=torch.bfloat16, device='cuda') * 0.05 - scores = torch.randn((num_tokens, num_experts), dtype=torch.float, device='cuda') - topk_w, topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False) - if 
args.masked_ratio > 0: - rand_mask = torch.rand_like(topk_idx, dtype=torch.float) - topk_idx.masked_fill_(rand_mask < args.masked_ratio, -1) - topk_w.masked_fill_(topk_idx < 0, 0) - - x_fp8, x_sf = per_token_cast_to_fp8(x_bf, use_ue8m0=False, gran_k=128, - use_packed_ue8m0=False) - l1_w_fp8, l1_w_sf = _quantize_grouped_fp8_block_128_128(l1_bf) - l2_w_fp8, l2_w_sf = _quantize_grouped_fp8_block_128_128(l2_bf) - transformed_l1, transformed_l2 = deep_gemm.transform_weights_for_mega_moe_sm90( - (l1_w_fp8, l1_w_sf), (l2_w_fp8, l2_w_sf), - ) - - phase_profile_enabled = os.environ.get('DG_SM90_MOE_PHASE_PROFILE', '0') != '0' - phase_profile_ints = 64 if phase_profile_enabled else 0 - cum_stats = torch.zeros(num_experts_per_rank + phase_profile_ints, dtype=torch.int, device='cuda') - - # Stage inputs once; bench-loop re-copies them each call (bench helper expects - # an idempotent ``fn``). - def run_fused(): - buffer.x[:num_tokens].copy_(x_fp8) - buffer.x_sf[:num_tokens].copy_(x_sf) - buffer.topk_idx[:num_tokens].copy_(topk_idx) - buffer.topk_weights[:num_tokens].copy_(topk_w) - y = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') - deep_gemm.w4a8_mega_moe( - y, transformed_l1, transformed_l2, buffer, - cumulative_local_expert_recv_stats=cum_stats, - recipe=(128, 128, 128), - activation='swiglu', - activation_clamp=activation_clamp, - fast_math=fast_math, - ) - return y - - if args.ncu_profile_only: - dist_print(f'[NCU] tokens={num_tokens} hidden={hidden} ih={intermediate_hidden}', - once_in_node=True) - run_fused() - torch.cuda.synchronize() - dist.barrier() - buffer.destroy() - return - - # Warm up + benchmark - run_fused() - dist.barrier() - if phase_profile_enabled: - cum_stats.zero_() - torch.cuda.synchronize() - dist.barrier() - # NSYS MULTI-ITER (aichenf): N timed iters with barrier+sleep between them. 
- # bench_kineto returns 1 under DG_USE_NVIDIA_TOOLS=1, but this loop puts - # multiple mega_moe instances on the nsys timeline so we can measure variance. - import os as _os - _nsys_iters = int(_os.environ.get('NSYS_ITERS', '0')) - if _nsys_iters > 0: - for _it in range(_nsys_iters): - torch.cuda.synchronize() - dist.barrier() - torch.cuda._sleep(int(2e7)) # 10ms gap between iters - dist.barrier() - run_fused() - torch.cuda.synchronize() - dist.barrier() - t_fused = bench_kineto(run_fused, 'sm90_w4a8_mega_moe', - barrier=lambda: dist.barrier(), - num_tests=args.num_tests, - suppress_kineto_output=True) - - # Count tokens that landed on this rank for stats - gathered_topk_idx = uneven_all_gather(topk_idx, group=group) - gathered_topk_idx[(gathered_topk_idx < rank_idx * num_experts_per_rank) | - (gathered_topk_idx >= (rank_idx + 1) * num_experts_per_rank)] = -1 - num_recv_tokens = (gathered_topk_idx != -1).sum().item() - - safe_div = lambda a, b: float('nan') if b == 0 else a / b - tflops = safe_div(2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_fused) - num_touched_experts = max(0, torch.unique(gathered_topk_idx.flatten()).numel() - 1) - # FP8 weights = 1 byte, FP8 acts = 1 byte, BF16 output = 2 bytes - num_hbm_bytes = ( - num_touched_experts * intermediate_hidden * 2 * hidden + # L1 weights - num_touched_experts * hidden * intermediate_hidden + # L2 weights - num_recv_tokens * hidden + # L1 acts read - num_recv_tokens * intermediate_hidden + # L1 out write - num_recv_tokens * intermediate_hidden + # L2 acts read - num_recv_tokens * hidden * 2 # L2 out write - ) - hbm_gbs = safe_div(num_hbm_bytes / 1e9, t_fused) - - if print_perf: - dist_print( - f' tokens={num_tokens:4d} recv={num_recv_tokens:5d} experts={num_touched_experts:4d} ' - f'{t_fused * 1e6:7.1f} us {tflops:6.1f} TFLOPS {hbm_gbs:6.0f} GB/s (rank{rank_idx})', - once_in_node=True, - ) - if phase_profile_enabled: - torch.cuda.synchronize() - profile = 
cum_stats[num_experts_per_rank:num_experts_per_rank + 48].view(torch.int64).cpu().tolist() - names = [ - 'dispatch_total', 'dispatch_pull', 'math_loop', 'combine_barrier', - 'combine_reduce', 'gemm_core', 'l1_epilogue', 'l2_epilogue', - ] - for i, name in enumerate(names): - total, max_v, count = profile[i], profile[8 + i], profile[16 + i] - avg = float(total) / count if count else 0.0 - dist_print( - f' phase {name:16s} avg={avg:10.0f} max={max_v:10d} count={count}', - once_in_node=True, - ) - - dist.barrier() - buffer.destroy() - - -def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace): - rank_idx, num_ranks, group = init_dist(local_rank, num_local_ranks) - forced_num_sms = int(os.environ.get('DG_SM90_MOE_SET_NUM_SMS', '0')) - if forced_num_sms > 0: - deep_gemm.set_num_sms(forced_num_sms) - torch.manual_seed(rank_idx) - random.seed(rank_idx) - - if get_arch_major() != 9: - dist_print(f'[SKIP] requires SM90, got SM{get_arch_major()}0', once_in_node=True) - dist.destroy_process_group() - return - - if args.batches is None: - batches = [1, 2, 4, 8, 16, 32] - else: - batches = args.batches - - dist_print( - f'SM90 MegaMoE bench: ranks={num_ranks} hidden={args.hidden} ' - f'ih={args.intermediate_hidden} experts={args.num_experts} topk={args.num_topk} ' - f'masked_ratio={args.masked_ratio} fast_math={bool(args.fast_math)}', - once_in_node=True, - ) - - # In NCU mode we run only one batch (the first one in `batches`) so that - # ncu's `--launch-count 1` is unambiguous. 
- if args.ncu_profile_only: - batches = batches[:1] - - num_max_tokens_per_rank = max(batches) - for num_tokens in batches: - _run_one_config( - args, num_tokens, num_max_tokens_per_rank, - args.hidden, args.intermediate_hidden, - args.num_experts, args.num_topk, - num_ranks, rank_idx, group, - activation_clamp=args.activation_clamp, - fast_math=bool(args.fast_math), - ) - - dist.barrier() - dist.destroy_process_group() - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='SM90 MegaMoE benchmark') - - parser.add_argument('--ncu-profile-only', action='store_true') - parser.add_argument('--num-processes', type=int, default=8) - parser.add_argument('--local-rank-idx', type=int, default=None) - - parser.add_argument('--batches', type=int, nargs='+', default=None, - help='List of num_tokens to sweep (default: 1 2 4 8 16 32)') - parser.add_argument('--hidden', type=int, default=7168) - parser.add_argument('--intermediate-hidden', type=int, default=2048) - parser.add_argument('--num-experts', type=int, default=256) - parser.add_argument('--num-topk', type=int, default=8) - parser.add_argument('--activation-clamp', type=float, default=10.0) - parser.add_argument('--masked-ratio', type=float, default=0.0) - parser.add_argument('--fast-math', type=int, default=1) - parser.add_argument('--num-tests', type=int, default=20) - - args = parser.parse_args() - - if args.local_rank_idx is not None: - test(args.local_rank_idx, args.num_processes, args) - else: - np = args.num_processes - torch.multiprocessing.spawn(test, args=(np, args), nprocs=np) From 03a5c9a3b83074cd7a5b0cc4a9a261b157572b77 Mon Sep 17 00:00:00 2001 From: AichenF Date: Mon, 1 Jun 2026 18:36:04 +0800 Subject: [PATCH 06/15] Add SM90 MegaMoE config selector --- csrc/jit_kernels/heuristics/mega_moe.hpp | 476 ++++++++++++++++--- csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp | 15 +- docs/sm90_megamoe_config_search.md | 122 +++++ 3 files changed, 528 insertions(+), 85 deletions(-) create mode 
100644 docs/sm90_megamoe_config_search.md diff --git a/csrc/jit_kernels/heuristics/mega_moe.hpp b/csrc/jit_kernels/heuristics/mega_moe.hpp index 7d5f15e82a..74c3d62756 100644 --- a/csrc/jit_kernels/heuristics/mega_moe.hpp +++ b/csrc/jit_kernels/heuristics/mega_moe.hpp @@ -1,7 +1,10 @@ #pragma once #include +#include +#include #include +#include #include @@ -273,6 +276,11 @@ struct MegaMoESM90Config { // Thread layout: dispatch + non-epilogue (TMA) + epilogue (math) int num_dispatch_threads, num_non_epilogue_threads, num_epilogue_threads; + // Chosen scheduler / epilogue modes. Keeping these in the config makes the + // SM90 path follow the same single-source-of-truth style as regular GEMM + // configs: the selector chooses a complete candidate, then launch consumes it. + bool direct_l2_scatter, l2_nmajor_schedule, one_warp_cleanup; + friend std::ostream& operator << (std::ostream& os, const MegaMoESM90Config& config) { os << "MegaMoESM90Config(" << "block_m=" << config.block_m << ", block_n=" << config.block_n << ", block_k=" << config.block_k @@ -284,7 +292,10 @@ struct MegaMoESM90Config { << ", num_stages=" << config.num_stages << ", smem_size=" << config.smem_size << ", num_dispatch_threads=" << config.num_dispatch_threads << ", num_non_epilogue_threads=" << config.num_non_epilogue_threads - << ", num_epilogue_threads=" << config.num_epilogue_threads << ")"; + << ", num_epilogue_threads=" << config.num_epilogue_threads + << ", direct_l2_scatter=" << config.direct_l2_scatter + << ", l2_nmajor_schedule=" << config.l2_nmajor_schedule + << ", one_warp_cleanup=" << config.one_warp_cleanup << ")"; return os; } }; @@ -325,6 +336,8 @@ struct Sm90MoeHeuristicPolicy { return 0; if (is_hopper_topk6() and expected_tokens_per_expert >= 8.0f and expected_tokens_per_expert <= 32.0f) return 16; + if (is_main_topk8() and expected_tokens_per_expert == 8.0f) + return 16; if (is_main_topk8() and expected_tokens_per_expert == 128.0f) return 16; if (is_main_topk8() and 
expected_tokens_per_expert >= 256.0f and expected_tokens_per_expert < 512.0f) @@ -337,6 +350,7 @@ struct Sm90MoeHeuristicPolicy { return false; if (is_main_topk8()) { return expected_is_one_of(2, 4, 8, 16, 32, 64, 76, 80, 88, 128) or + expected_is_between(64.0f, 80.0f) or expected_is_between(96.0f, 120.0f) or expected_tokens_per_expert >= 144.0f; } @@ -360,7 +374,7 @@ struct Sm90MoeHeuristicPolicy { bool one_warp_cleanup(const bool& masked_hint) const { if (not uses_split_bn256()) return false; - if (is_main_topk8() and expected_tokens_per_expert <= 64.0f) + if (is_main_topk8() and expected_tokens_per_expert <= 80.0f) return true; if (is_hopper_topk6() and masked_hint and expected_tokens_per_expert == 64.0f) return true; @@ -504,7 +518,7 @@ static std::pair get_pipeline_config_for_mega_moe_sm90( const int& num_experts, const int& hidden, const int& block_m, const int& block_n, const int& block_k, const int& num_dispatch_warps, const int& num_epilogue_warps, - const bool& direct_l2_scatter_default = false, + const bool& direct_l2_scatter_enabled = false, const int& default_num_stages = 0) { constexpr int kSmemAlignment = 1024; @@ -527,9 +541,7 @@ static std::pair get_pipeline_config_for_mega_moe_sm90( const int wg_block_n = (split_n_warpgroups or serial_n_warpgroups) ? block_n / 2 : block_n; const int smem_cd_accum = (block_m == 16 or block_m == 32) ? align(block_m * block_n * static_cast(sizeof(float)), kSmemAlignment) : 0; const int smem_cd_l1 = num_epilogue_warpgroups * wg_block_m * (wg_block_n / 2); // 1 byte/elem (FP8) - const bool direct_l2_scatter = get_env( - "DG_SM90_MOE_DIRECT_L2_SCATTER", - direct_l2_scatter_default ? 
1 : 0) != 0 and + const bool direct_l2_scatter = direct_l2_scatter_enabled and block_m != 16 and block_m != 32 and not serial_n_warpgroups and wg_block_n == 128; const bool async_l1_tma_store = get_env("DG_SM90_MOE_ASYNC_L1_STORE", 0) != 0 and @@ -574,90 +586,408 @@ static std::pair get_pipeline_config_for_mega_moe_sm90( : (prefer_bn256_split ? std::min(4, max_num_stages) : 0); const int forced_num_stages = get_env("DG_SM90_MOE_NUM_STAGES"); const int num_stages = forced_num_stages > 0 - ? forced_num_stages + ? std::min(forced_num_stages, max_num_stages) : (preferred_num_stages > 0 ? preferred_num_stages : max_num_stages); DG_HOST_ASSERT(num_stages >= 2 and num_stages <= max_num_stages); return {num_stages, smem_fixed + num_stages * (smem_per_stage + smem_barriers_per_stage)}; } -static MegaMoESM90Config get_mega_moe_config_sm90( +template +static void append_unique_moe_candidate(std::vector& values, const T& value) { + if (std::find(values.begin(), values.end(), value) == values.end()) + values.emplace_back(value); +} + +static std::vector get_sm90_moe_bool_candidates( + const std::string& env_name, + const bool& default_value, + const bool& allow_alternative) { + const int forced = get_env(env_name, -1); + DG_HOST_ASSERT(forced == -1 or forced == 0 or forced == 1); + std::vector values; + if (forced != -1) { + values.emplace_back(forced); + return values; + } + append_unique_moe_candidate(values, default_value ? 1 : 0); + if (allow_alternative) + append_unique_moe_candidate(values, default_value ? 
0 : 1); + return values; +} + +struct Sm90MoeConfigInfo { + int64_t score; + int num_blocks, num_waves, last_wave_util; + int empirical_penalty; + MegaMoESM90Config config; + + friend std::ostream& operator << (std::ostream& os, const Sm90MoeConfigInfo& info) { + os << "Sm90MoeConfigInfo(score=" << info.score + << ", num_blocks=" << info.num_blocks + << ", num_waves=" << info.num_waves + << ", last_wave_util=" << info.last_wave_util + << ", empirical_penalty=" << info.empirical_penalty + << ", config=" << info.config << ")"; + return os; + } +}; + +static Sm90MoeConfigInfo get_sm90_moe_config_info( + const MegaMoESM90Config& config, + const int& num_experts_per_rank, const int& num_tokens, const int& num_topk, + const int& hidden, const int& intermediate_hidden, const int& num_sms, + const bool& empirical_direct_l2_scatter, + const bool& empirical_l2_nmajor_schedule, + const bool& empirical_one_warp_cleanup, + const int& empirical_num_stages, + const int& empirical_num_experts_per_wave) { + const float expected_tokens_per_expert = + static_cast(num_tokens) * num_topk / num_experts_per_rank; + const int expected_tokens_ceil = + std::max(1, static_cast(std::ceil(expected_tokens_per_expert))); + const int num_m_blocks = ceil_div(expected_tokens_ceil, config.block_m); + const int num_l1_n_blocks = ceil_div(2 * intermediate_hidden, config.block_n); + const int num_l2_n_blocks = ceil_div(hidden, config.block_n); + const int num_blocks = num_experts_per_rank * num_m_blocks * + (num_l1_n_blocks + num_l2_n_blocks); + const int num_waves = ceil_div(num_blocks, num_sms); + const int num_last_blocks = num_blocks % num_sms; + const int last_wave_util = num_last_blocks == 0 ? 
num_sms : num_last_blocks; + + // This cost model intentionally mirrors DeepGEMM's style: enumerate legal + // candidates, rank them with cheap shape-derived estimates, and use a small + // empirical calibration layer for modes whose cost is dominated by dispatch / + // combine details not captured by the block-count model. + int empirical_penalty = 0; + if (config.direct_l2_scatter != empirical_direct_l2_scatter) + empirical_penalty += 1000000; + if (config.l2_nmajor_schedule != empirical_l2_nmajor_schedule) + empirical_penalty += 500000; + if (config.one_warp_cleanup != empirical_one_warp_cleanup) + empirical_penalty += 250000; + if (config.num_stages != empirical_num_stages) + empirical_penalty += 500000; + if (config.num_experts_per_wave != empirical_num_experts_per_wave) + empirical_penalty += 250000; + + int64_t score = 0; + score += static_cast(num_waves) * 100000; + score -= static_cast(last_wave_util) * 100; + score += static_cast(num_blocks); + score += static_cast(config.smem_size / 1024); + score += empirical_penalty; + + // Prefer the compact split frontend when the calibrated modes tie. + if (config.block_m == 64 and config.block_n == 256 and + config.num_dispatch_threads == 64 and config.num_non_epilogue_threads == 64) + score -= 1000; + + return {score, num_blocks, num_waves, last_wave_util, empirical_penalty, config}; +} + +static std::vector get_mega_moe_config_candidates_sm90( const int& num_ranks, const int& num_experts, const int& num_experts_per_rank, const int& num_max_tokens_per_rank, const int& num_tokens, const int& num_topk, const int& hidden, const int& intermediate_hidden, const int& num_padded_sf_pool_tokens) { - const auto [block_m, num_epilogue_threads] = get_block_config_for_mega_moe_sm90( - num_ranks, num_experts, num_max_tokens_per_rank, num_topk, num_tokens); - const bool use_bn256_split_n = - get_env("DG_SM90_MOE_BN256_2WG", get_sm90_moe_split_l1_l2_default() ? 
1 : 0) != 0 and - block_m == 64; - const bool use_bn256_seq_n = - get_env("DG_SM90_MOE_BN256_SEQ") != 0 and block_m == 64; - DG_HOST_ASSERT(not (use_bn256_split_n and use_bn256_seq_n)); - const int block_n = (use_bn256_split_n or use_bn256_seq_n) ? 256 : 128; - DG_HOST_ASSERT((not use_bn256_split_n) or num_epilogue_threads == 256); - const int block_k = 128; - // Default remains cluster_size=1; the opt-in cluster path uses B multicast - // across adjacent M blocks. - const bool use_b_stationary_2wg = get_env("DG_SM90_MOE_B_STATIONARY_2WG") != 0; - const bool use_cluster_bcast_b = get_env("DG_SM90_MOE_CLUSTER_BCAST_B") != 0 or use_b_stationary_2wg; - DG_HOST_ASSERT((not use_cluster_bcast_b) or - ((block_m == 64 and block_n == 128 and num_epilogue_threads == 128) or - (block_m == 128 and block_n == 128 and num_epilogue_threads == 256))); - const int cluster_size = use_cluster_bcast_b ? 2 : 1; + const bool split_l1_l2 = get_sm90_moe_split_l1_l2_default(); + const bool extra_modes = get_env("DG_SM90_MOE_SEARCH_EXTRA_MODES", 1) != 0; + const bool extra_block_shapes = get_env("DG_SM90_MOE_SEARCH_BLOCK_SHAPES", 0) != 0; + const float expected_tokens_per_expert = + static_cast(num_tokens) * num_topk / num_experts_per_rank; + + const int requested_mma_m = get_env("DG_SM90_MOE_MMA_SYNC_M") > 0 + ? get_env("DG_SM90_MOE_MMA_SYNC_M") + : (get_env("DG_SM90_MOE_MMA_SYNC") != 0 ? 
16 : 0); + DG_HOST_ASSERT(requested_mma_m == 0 or requested_mma_m == 16 or requested_mma_m == 32); + const int forced_block_m = get_env("DG_SM90_MOE_FORCE_BLOCK_M"); + const int forced_epilogue_warpgroups = get_env("DG_SM90_MOE_FORCE_EPILOGUE_WG"); + DG_HOST_ASSERT(forced_block_m == 0 or forced_block_m == 64 or forced_block_m == 128); + DG_HOST_ASSERT(forced_epilogue_warpgroups == 0 or + forced_epilogue_warpgroups == 1 or + forced_epilogue_warpgroups == 2); + + const bool use_mma_sync_decode = + requested_mma_m > 0 and expected_tokens_per_expert <= static_cast(requested_mma_m); + const bool use_b_stationary_2wg = + get_env("DG_SM90_MOE_B_STATIONARY_2WG") != 0 and not use_mma_sync_decode; + const bool use_bn256_split_n_env = + get_env("DG_SM90_MOE_BN256_2WG", split_l1_l2 ? 1 : 0) != 0 and + forced_block_m != 128 and not use_mma_sync_decode; + const bool use_bn256_seq_n_env = + get_env("DG_SM90_MOE_BN256_SEQ") != 0 and + forced_block_m != 128 and not use_mma_sync_decode; + DG_HOST_ASSERT(not (use_bn256_split_n_env and use_bn256_seq_n_env)); + DG_HOST_ASSERT(not (use_b_stationary_2wg and (use_bn256_split_n_env or use_bn256_seq_n_env))); + + std::vector block_m_candidates; + if (forced_block_m > 0) { + append_unique_moe_candidate(block_m_candidates, forced_block_m); + } else if (use_b_stationary_2wg) { + append_unique_moe_candidate(block_m_candidates, 128); + } else if (use_mma_sync_decode) { + append_unique_moe_candidate(block_m_candidates, requested_mma_m); + } else { + append_unique_moe_candidate(block_m_candidates, 64); + if (extra_block_shapes) + append_unique_moe_candidate(block_m_candidates, 128); + } + const int num_max_pool_tokens = layout::get_num_max_pool_tokens( num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank); - const int swizzle_acts_mode = (block_m == 16 or block_m == 32) ? 0 : 128; - const int swizzle_weights_mode = (block_m == 16 or block_m == 32) ? 
0 : 128; - + const int block_k = 128; + const bool split_sfa_tma = get_env("DG_SM90_MOE_SPLIT_SFA_TMA", 0) != 0; + const bool use_cluster_bcast_b = get_env("DG_SM90_MOE_CLUSTER_BCAST_B") != 0 or use_b_stationary_2wg; const int num_sms = device_runtime->get_num_sms(); - const int num_experts_per_wave = get_num_experts_per_wave_for_mega_moe_sm90( - num_experts_per_rank, num_tokens, num_topk, - intermediate_hidden, block_m, block_n, num_sms); - const bool split_sfa_tma = get_env("DG_SM90_MOE_SPLIT_SFA_TMA", 0) != 0; - const bool prefer_compact_frontend = get_sm90_moe_split_l1_l2_default() and block_n == 256 and not split_sfa_tma; - const bool compact_frontend = get_env("DG_SM90_MOE_COMPACT_FRONTEND", - prefer_compact_frontend ? 1 : 0) != 0; - const int forced_dispatch_warps = get_env("DG_SM90_MOE_DISPATCH_WARPS", - compact_frontend ? 2 : 0); - DG_HOST_ASSERT(forced_dispatch_warps == 0 or forced_dispatch_warps == 2 or - forced_dispatch_warps == 4 or forced_dispatch_warps == 8); - const int num_dispatch_threads = (forced_dispatch_warps > 0 ? forced_dispatch_warps : 4) * 32; - DG_HOST_ASSERT((not split_sfa_tma) or (not compact_frontend)); - DG_HOST_ASSERT((not compact_frontend) or num_dispatch_threads == 64); - const int num_non_epilogue_threads = compact_frontend ? 
64 : 128; - DG_HOST_ASSERT((num_dispatch_threads + num_non_epilogue_threads) % 128 == 0); + std::vector candidates; + for (const int& block_m: block_m_candidates) { + DG_HOST_ASSERT(std::any_of( + layout::kCandidateBlockM, layout::kCandidateBlockM + layout::kNumCandidateBlockMs, + [=](const auto& candidate) { return candidate == block_m; }) + ); + + std::vector block_n_candidates; + if (block_m == 64 and not use_mma_sync_decode and + (use_bn256_split_n_env or use_bn256_seq_n_env)) { + append_unique_moe_candidate(block_n_candidates, 256); + if (extra_block_shapes) + append_unique_moe_candidate(block_n_candidates, 128); + } else { + append_unique_moe_candidate(block_n_candidates, 128); + if (extra_block_shapes and block_m == 64 and not use_mma_sync_decode) + append_unique_moe_candidate(block_n_candidates, 256); + } - const auto policy = get_sm90_moe_heuristic_policy( - num_experts_per_rank, num_tokens, num_topk, - intermediate_hidden, block_m, block_n); - const bool direct_l2_scatter_default = policy.direct_l2_scatter(); - const bool direct_l2_scatter_enabled = get_env( - "DG_SM90_MOE_DIRECT_L2_SCATTER", - direct_l2_scatter_default ? 1 : 0) != 0; - const int default_num_stages = policy.stage5_pipeline( - direct_l2_scatter_enabled, - get_env("DG_SM90_MOE_EPLB_HINT", 0) != 0, - get_env("DG_SM90_MOE_SKEW_HINT", 0) != 0, - get_env("DG_SM90_MOE_MASKED_HINT", 0) != 0) ? 5 : 0; - const auto [num_stages, smem_size] = get_pipeline_config_for_mega_moe_sm90( - SM90ArchSpec::smem_capacity, - num_experts, hidden, - block_m, block_n, block_k, - num_dispatch_threads / 32, num_epilogue_threads / 32, - direct_l2_scatter_default, - default_num_stages); + for (const int& block_n: block_n_candidates) { + std::vector epilogue_wg_candidates; + if (forced_epilogue_warpgroups > 0) { + append_unique_moe_candidate(epilogue_wg_candidates, forced_epilogue_warpgroups); + } else { + append_unique_moe_candidate(epilogue_wg_candidates, + (block_m == 128 or use_b_stationary_2wg or block_n == 256) ? 
2 : 1); + } + + for (const int& num_epilogue_warpgroups: epilogue_wg_candidates) { + if (block_m % num_epilogue_warpgroups != 0) + continue; + if ((block_m == 16 or block_m == 32) and num_epilogue_warpgroups != 1) + continue; + if (block_m == 128 and num_epilogue_warpgroups != 2) + continue; + if (block_m == 64 and block_n == 256 and num_epilogue_warpgroups != 2) + continue; + const int num_epilogue_threads = num_epilogue_warpgroups * 128; + + if (use_cluster_bcast_b and + not ((block_m == 64 and block_n == 128 and num_epilogue_threads == 128) or + (block_m == 128 and block_n == 128 and num_epilogue_threads == 256))) + continue; + const int cluster_size = use_cluster_bcast_b ? 2 : 1; + const int swizzle_acts_mode = (block_m == 16 or block_m == 32) ? 0 : 128; + const int swizzle_weights_mode = (block_m == 16 or block_m == 32) ? 0 : 128; + + const bool prefer_compact_frontend = + split_l1_l2 and block_n == 256 and not split_sfa_tma; + const bool compact_frontend = get_env("DG_SM90_MOE_COMPACT_FRONTEND", + prefer_compact_frontend ? 1 : 0) != 0; + const int forced_dispatch_warps = get_env("DG_SM90_MOE_DISPATCH_WARPS", -1); + DG_HOST_ASSERT(forced_dispatch_warps == -1 or forced_dispatch_warps == 0 or + forced_dispatch_warps == 2 or forced_dispatch_warps == 4 or + forced_dispatch_warps == 8); + std::vector dispatch_warp_candidates; + if (forced_dispatch_warps > 0) { + append_unique_moe_candidate(dispatch_warp_candidates, forced_dispatch_warps); + } else { + append_unique_moe_candidate(dispatch_warp_candidates, compact_frontend ? 2 : 4); + if (extra_block_shapes and not compact_frontend) + append_unique_moe_candidate(dispatch_warp_candidates, 2); + } + + for (const int& num_dispatch_warps: dispatch_warp_candidates) { + if (compact_frontend and num_dispatch_warps != 2) + continue; + if (split_sfa_tma and compact_frontend) + continue; + const int num_dispatch_threads = num_dispatch_warps * 32; + const int num_non_epilogue_threads = compact_frontend ? 
64 : 128; + if ((num_dispatch_threads + num_non_epilogue_threads) % 128 != 0) + continue; + + const auto policy = get_sm90_moe_heuristic_policy( + num_experts_per_rank, num_tokens, num_topk, + intermediate_hidden, block_m, block_n); + const bool direct_l2_scatter_default = policy.direct_l2_scatter(); + const bool l2_nmajor_schedule_default = policy.l2_nmajor_schedule( + get_env("DG_SM90_MOE_EPLB_HINT", 0) != 0, + get_env("DG_SM90_MOE_SKEW_HINT", 0) != 0); + const bool one_warp_cleanup_default = policy.one_warp_cleanup( + get_env("DG_SM90_MOE_MASKED_HINT", 0) != 0); + const bool direct_l2_scatter_legal = + block_m != 16 and block_m != 32 and + ((block_m == 64 and block_n == 256 and num_epilogue_warpgroups == 2) or + block_n == 128); + + auto direct_candidates = get_sm90_moe_bool_candidates( + "DG_SM90_MOE_DIRECT_L2_SCATTER", + direct_l2_scatter_default and direct_l2_scatter_legal, + extra_modes and direct_l2_scatter_legal); + auto l2_nmajor_candidates = get_sm90_moe_bool_candidates( + "DG_SM90_MOE_L2_NMAJOR", + l2_nmajor_schedule_default, + extra_modes and policy.uses_split_bn256() and policy.is_main_topk8()); + auto cleanup_candidates = get_sm90_moe_bool_candidates( + "DG_SM90_MOE_ONE_WARP_CLEANUP", + one_warp_cleanup_default, + extra_modes and policy.uses_split_bn256()); + + const int default_epw = get_num_experts_per_wave_for_mega_moe_sm90( + num_experts_per_rank, num_tokens, num_topk, + intermediate_hidden, block_m, block_n, num_sms); + std::vector experts_per_wave_candidates; + append_unique_moe_candidate(experts_per_wave_candidates, default_epw); + if (get_env("DG_SM90_MOE_EXPERTS_PER_WAVE") <= 0 and extra_modes) { + if (16 <= num_experts_per_rank and num_experts_per_rank % 16 == 0) + append_unique_moe_candidate(experts_per_wave_candidates, 16); + append_unique_moe_candidate(experts_per_wave_candidates, num_experts_per_rank); + } + + for (const int& direct_value: direct_candidates) { + const bool direct_l2_scatter = direct_value != 0; + if (direct_l2_scatter 
and not direct_l2_scatter_legal) + continue; + const int empirical_stage = policy.stage5_pipeline( + direct_l2_scatter, + get_env("DG_SM90_MOE_EPLB_HINT", 0) != 0, + get_env("DG_SM90_MOE_SKEW_HINT", 0) != 0, + get_env("DG_SM90_MOE_MASKED_HINT", 0) != 0) ? 5 : 4; + const int forced_num_stages = get_env("DG_SM90_MOE_NUM_STAGES"); + std::vector stage_candidates; + if (forced_num_stages > 0) { + append_unique_moe_candidate(stage_candidates, forced_num_stages); + } else { + append_unique_moe_candidate(stage_candidates, empirical_stage); + if (extra_modes) { + append_unique_moe_candidate(stage_candidates, 4); + append_unique_moe_candidate(stage_candidates, 5); + } + } + + for (const int& requested_num_stages: stage_candidates) { + const auto [num_stages, smem_size] = get_pipeline_config_for_mega_moe_sm90( + SM90ArchSpec::smem_capacity, + num_experts, hidden, + block_m, block_n, block_k, + num_dispatch_threads / 32, num_epilogue_threads / 32, + direct_l2_scatter, + requested_num_stages); + for (const int& l2_nmajor_value: l2_nmajor_candidates) { + for (const int& cleanup_value: cleanup_candidates) { + for (const int& num_experts_per_wave: experts_per_wave_candidates) { + if (num_experts_per_wave <= 0 or + num_experts_per_wave > num_experts_per_rank or + num_experts_per_rank % num_experts_per_wave != 0) + continue; + candidates.emplace_back(MegaMoESM90Config { + block_m, block_n, block_k, + cluster_size, + num_max_pool_tokens, num_padded_sf_pool_tokens, + swizzle_acts_mode, swizzle_weights_mode, + num_experts_per_wave, + num_stages, smem_size, + num_dispatch_threads, num_non_epilogue_threads, num_epilogue_threads, + direct_l2_scatter, l2_nmajor_value != 0, cleanup_value != 0 + }); + } + } + } + } + } + } + } + } + } - const auto config = MegaMoESM90Config { - block_m, block_n, block_k, - cluster_size, - num_max_pool_tokens, num_padded_sf_pool_tokens, - swizzle_acts_mode, swizzle_weights_mode, - num_experts_per_wave, - num_stages, smem_size, - num_dispatch_threads, 
num_non_epilogue_threads, num_epilogue_threads + DG_HOST_ASSERT(not candidates.empty()); + return candidates; +} + +static Sm90MoeConfigInfo get_best_mega_moe_config_info_sm90( + const int& num_ranks, const int& num_experts, const int& num_experts_per_rank, + const int& num_max_tokens_per_rank, const int& num_tokens, const int& num_topk, + const int& hidden, const int& intermediate_hidden, + const int& num_padded_sf_pool_tokens) { + const auto candidates = get_mega_moe_config_candidates_sm90( + num_ranks, num_experts, num_experts_per_rank, + num_max_tokens_per_rank, num_tokens, num_topk, + hidden, intermediate_hidden, num_padded_sf_pool_tokens); + const int num_sms = device_runtime->get_num_sms(); + + Sm90MoeConfigInfo best { + std::numeric_limits::max(), 0, 0, 0, 0, candidates[0] }; + std::vector infos; + infos.reserve(candidates.size()); + for (const auto& candidate: candidates) { + const auto policy = get_sm90_moe_heuristic_policy( + num_experts_per_rank, num_tokens, num_topk, + intermediate_hidden, candidate.block_m, candidate.block_n); + const bool empirical_direct_l2_scatter = policy.direct_l2_scatter(); + const bool empirical_l2_nmajor_schedule = policy.l2_nmajor_schedule( + get_env("DG_SM90_MOE_EPLB_HINT", 0) != 0, + get_env("DG_SM90_MOE_SKEW_HINT", 0) != 0); + const bool empirical_one_warp_cleanup = policy.one_warp_cleanup( + get_env("DG_SM90_MOE_MASKED_HINT", 0) != 0); + const int empirical_num_stages = policy.stage5_pipeline( + candidate.direct_l2_scatter, + get_env("DG_SM90_MOE_EPLB_HINT", 0) != 0, + get_env("DG_SM90_MOE_SKEW_HINT", 0) != 0, + get_env("DG_SM90_MOE_MASKED_HINT", 0) != 0) ? 
5 : 4; + const int empirical_num_experts_per_wave = get_num_experts_per_wave_for_mega_moe_sm90( + num_experts_per_rank, num_tokens, num_topk, + intermediate_hidden, candidate.block_m, candidate.block_n, num_sms); + auto info = get_sm90_moe_config_info( + candidate, + num_experts_per_rank, num_tokens, num_topk, + hidden, intermediate_hidden, num_sms, + empirical_direct_l2_scatter, + empirical_l2_nmajor_schedule, + empirical_one_warp_cleanup, + empirical_num_stages, + empirical_num_experts_per_wave); + infos.emplace_back(info); + if (info.score < best.score) + best = info; + } + + if (get_env("DG_SM90_MOE_PRINT_SEARCH", 0) > 0) { + const auto key = fmt::format( + "Sm90MoeConfigSearch(num_ranks={}, num_experts={}, hidden={}, intermediate_hidden={}, num_max_tokens_per_rank={}, num_tokens={}, num_topk={})", + num_ranks, num_experts, hidden, intermediate_hidden, num_max_tokens_per_rank, num_tokens, num_topk); + static std::unordered_set printed; + if (printed.count(key) == 0) { + std::cout << key << ": selected " << best << std::endl; + if (get_env("DG_SM90_MOE_PRINT_SEARCH", 0) > 1) { + for (const auto& info: infos) + std::cout << " candidate " << info << std::endl; + } + printed.insert(key); + } + } + return best; +} + +static MegaMoESM90Config get_mega_moe_config_sm90( + const int& num_ranks, const int& num_experts, const int& num_experts_per_rank, + const int& num_max_tokens_per_rank, const int& num_tokens, const int& num_topk, + const int& hidden, const int& intermediate_hidden, + const int& num_padded_sf_pool_tokens) { + const auto config_info = get_best_mega_moe_config_info_sm90( + num_ranks, num_experts, num_experts_per_rank, + num_max_tokens_per_rank, num_tokens, num_topk, + hidden, intermediate_hidden, num_padded_sf_pool_tokens); + const auto config = config_info.config; if (get_env("DG_JIT_DEBUG") or get_env("DG_PRINT_CONFIGS")) { const auto key = fmt::format( diff --git a/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp 
b/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp index 6411012495..f189c6e1ae 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp @@ -235,15 +235,6 @@ static void sm90_fp8_mega_moe( // Launch const auto num_sms = device_runtime->get_num_sms(); - const bool direct_l2_scatter_default = get_sm90_moe_direct_l2_scatter_default( - num_experts_per_rank, num_tokens, num_topk, - intermediate_hidden, config.block_m, config.block_n); - const bool l2_nmajor_schedule_default = get_sm90_moe_l2_nmajor_schedule_default( - num_experts_per_rank, num_tokens, num_topk, - intermediate_hidden, config.block_m, config.block_n); - const bool one_warp_cleanup_default = get_sm90_moe_one_warp_cleanup_default( - num_experts_per_rank, num_tokens, num_topk, - intermediate_hidden, config.block_m, config.block_n); const SM90FP8MegaMoERuntime::Args args = { .num_max_tokens_per_rank = num_max_tokens_per_rank, .hidden = hidden, .intermediate_hidden = intermediate_hidden, @@ -255,17 +246,17 @@ static void sm90_fp8_mega_moe( .split_sfa_tma = get_env("DG_SM90_MOE_SPLIT_SFA_TMA", 0) != 0, .direct_l2_scatter = get_env( "DG_SM90_MOE_DIRECT_L2_SCATTER", - direct_l2_scatter_default ? 1 : 0) != 0, + config.direct_l2_scatter ? 1 : 0) != 0, .l2_dual_accum = get_env("DG_SM90_MOE_L2_DUAL_ACCUM", 0) != 0, .phase_profile = get_env("DG_SM90_MOE_PHASE_PROFILE", 0) != 0, .l1_dual_k_accum = get_env("DG_SM90_MOE_L1_DUAL_K", 0) != 0, .l2_nmajor_schedule = get_env( "DG_SM90_MOE_L2_NMAJOR", - l2_nmajor_schedule_default ? 1 : 0) != 0, + config.l2_nmajor_schedule ? 1 : 0) != 0, .l1_nmajor_schedule = get_env("DG_SM90_MOE_L1_NMAJOR", 0) != 0, .one_warp_cleanup = get_env( "DG_SM90_MOE_ONE_WARP_CLEANUP", - one_warp_cleanup_default ? 1 : 0) != 0, + config.one_warp_cleanup ? 
1 : 0) != 0, .split_phase_mode = 0, .config = config, .y = y.data_ptr(), diff --git a/docs/sm90_megamoe_config_search.md b/docs/sm90_megamoe_config_search.md new file mode 100644 index 0000000000..d94dcf8949 --- /dev/null +++ b/docs/sm90_megamoe_config_search.md @@ -0,0 +1,122 @@ +# SM90 MegaMoE Config Search + +This note records the H20 SM90 MegaMoE split-kernel defaults before moving the +selector to the DeepGEMM-style candidate-search path. + +## Selector Model + +Regular DeepGEMM GEMM selects parameters by: + +1. generating legal config candidates, +2. estimating a cheap `LayoutInfo`-like score from the shape, +3. choosing the best candidate deterministically, +4. JIT-compiling only the selected kernel variant. + +SM90 MegaMoE now follows the same shape. The selector builds complete +`MegaMoESM90Config` candidates containing tile sizes, wave grouping, stage +count, thread layout, and scheduler/epilogue mode flags. It then ranks them +with a block/wave score plus an empirical calibration layer for MoE-specific +dispatch/combine choices. + +This is not online autotuning. Runtime launch does not benchmark candidate +kernels. The default search is deterministic and keeps the H20 empirical +choices as calibration targets. + +## Search Space + +Default candidate dimensions: + +- `block_m`: default `64`; `128` can be included with + `DG_SM90_MOE_SEARCH_BLOCK_SHAPES=1`. +- `block_n`: default `256` for split L1/L2, otherwise `128`; alternate block-N + candidates can be included with `DG_SM90_MOE_SEARCH_BLOCK_SHAPES=1`. +- `num_epilogue_threads`: derived from block shape; `64x256` uses two epilogue + warpgroups. +- `num_dispatch_threads`: compact split frontend uses `64`; otherwise `128`. +- `direct_l2_scatter`: candidate default from empirical policy, alternate + included unless forced by `DG_SM90_MOE_DIRECT_L2_SCATTER`. +- `l2_nmajor_schedule`: candidate default from empirical policy, alternate + included unless forced by `DG_SM90_MOE_L2_NMAJOR`. 
+- `one_warp_cleanup`: candidate default from empirical policy, alternate + included unless forced by `DG_SM90_MOE_ONE_WARP_CLEANUP`. +- `num_stages`: empirical 4/5-stage target plus alternate 4/5 candidates unless + forced by `DG_SM90_MOE_NUM_STAGES`; a forced value above the shared-memory-limited + maximum stage count is clamped to it, while a forced value of 1 fails the host assert. +- `num_experts_per_wave`: empirical target plus `16` and full-rank candidates + when legal unless forced by `DG_SM90_MOE_EXPERTS_PER_WAVE`. + +Debugging: + +```bash +DG_PRINT_CONFIGS=1 python3 tests/bench_mega_moe_sm90.py --num-processes 8 --batches 8 16 32 --num-tests 1 +DG_SM90_MOE_PRINT_SEARCH=2 python3 tests/bench_mega_moe_sm90.py --num-processes 8 --batches 128 --num-tests 1 +``` + +## H20 Empirical Reference + +Configuration: `num_ranks=8`, `hidden=7168`, `intermediate_hidden=2048`, +`num_experts=256`, `num_experts_per_rank=32`, `topk=8`. + +`expected_tokens_per_expert = M * topk / num_experts_per_rank = M / 4`. + +| M | expected/expert | block_m | block_n | stages | dispatch/non-epi/epi threads | experts/wave | direct L2 scatter | L2 N-major | one-warp cleanup | +|---:|---:|---:|---:|---:|---:|---:|---|---|---| +| 8 | 2 | 64 | 256 | 5 | 64 / 64 / 256 | 16 | on | off | on | +| 16 | 4 | 64 | 256 | 5 | 64 / 64 / 256 | 16 | on | off | on | +| 32 | 8 | 64 | 256 | 4 | 64 / 64 / 256 | 32 | on | off | on | +| 64 | 16 | 64 | 256 | 5 | 64 / 64 / 256 | 32 | on | off | on | +| 128 | 32 | 64 | 256 | 5 | 64 / 64 / 256 | 32 | on | off | on | +| 256 | 64 | 64 | 256 | 4 | 64 / 64 / 256 | 32 | on | off | on | +| 260 | 65 | 64 | 256 | 4 | 64 / 64 / 256 | 32 | off | off | off | +| 512 | 128 | 64 | 256 | 5 | 64 / 64 / 256 | 16 | on | off | off | +| 1024 | 256 | 64 | 256 | 5 | 64 / 64 / 256 | 16 | on | on | off | + +These are the pre-search empirical defaults. Focused H20 sweeps after adding +candidate search found two better defaults for single-size runs: + +- `M=32`: `num_experts_per_wave=16` was faster than the previous `32`. 
+- `M=260`: enabling both direct L2 scatter and one-warp cleanup was faster than + the previous boundary behavior that turned both off after + `expected_tokens_per_expert > 64`. + +The selector calibration has been updated for those two cases. Other tested +differences, such as 4-stage vs 5-stage at `M=512` and L2 N-major at `M=1024`, +were not stable enough to change the default. + +## Validation Log + +Commands: + +```bash +DG_PRINT_CONFIGS=1 python3 tests/bench_mega_moe_sm90.py --num-processes 8 --batches 8 16 32 64 128 256 260 512 1024 --num-tests 1 +python3 tests/test_mega_moe_sm90.py --num-processes 1 --layers 1 2 --fail-fast +python3 tests/test_mega_moe_sm90.py --num-processes 2 --layers 1 2 --fail-fast +python3 tests/bench_mega_moe_sm90.py --num-processes 8 --batches 8 16 32 64 128 256 260 512 --num-tests 20 +``` + +Correctness after candidate selector: + +- `--num-processes 1 --layers 1 2 --fail-fast`: passed all 5 scenarios, + max printed diff `0.0006`. +- `--num-processes 2 --layers 1 2 --fail-fast`: passed all 5 scenarios, + max printed diff `0.0006`. + +Focused H20 benchmark after candidate selector: + +```text +M latency us +8 849.8 +16 806.4 +32 748.7 +64 799.9 +128 864.0 +256 1160.3 +260 1174.0 +512 1894.0 +1024 3226.0 +``` + +The strongest improvements versus the pre-search defaults are at `M=32` +(`experts_per_wave=16`) and `M=260` (direct L2 scatter plus one-warp cleanup). +Other search candidates remain available through env overrides, but the default +calibration was not changed when repeated measurements were noisy or neutral. 
From 1a5bd09a42f7670d7a072acb3124fc808c48f794 Mon Sep 17 00:00:00 2001 From: AichenF Date: Tue, 2 Jun 2026 13:30:21 +0800 Subject: [PATCH 07/15] Add SM90 MegaMoE H20 H200 profile selector --- csrc/jit_kernels/heuristics/mega_moe.hpp | 158 +++++++++++++++++++++++ 1 file changed, 158 insertions(+) diff --git a/csrc/jit_kernels/heuristics/mega_moe.hpp b/csrc/jit_kernels/heuristics/mega_moe.hpp index 74c3d62756..46dc07b630 100644 --- a/csrc/jit_kernels/heuristics/mega_moe.hpp +++ b/csrc/jit_kernels/heuristics/mega_moe.hpp @@ -1,8 +1,10 @@ #pragma once #include +#include #include #include +#include #include #include @@ -304,7 +306,47 @@ static bool get_sm90_moe_split_l1_l2_default() { return get_env("DG_SM90_MOE_SPLIT_L1_L2", 1) != 0; } +enum class Sm90MoeDeviceProfile { + Generic, + H20, + H200 +}; + +static std::string get_sm90_moe_lowercase(std::string value) { + std::transform(value.begin(), value.end(), value.begin(), [](const unsigned char c) { + return static_cast(std::tolower(c)); + }); + return value; +} + +static Sm90MoeDeviceProfile get_sm90_moe_device_profile() { + const auto forced = get_sm90_moe_lowercase( + get_env("DG_SM90_MOE_DEVICE_PROFILE", "")); + if (not forced.empty() and forced != "auto") { + DG_HOST_ASSERT(forced == "generic" or forced == "h20" or forced == "h200"); + if (forced == "h20") + return Sm90MoeDeviceProfile::H20; + if (forced == "h200") + return Sm90MoeDeviceProfile::H200; + return Sm90MoeDeviceProfile::Generic; + } + + const auto device_name = get_sm90_moe_lowercase(device_runtime->get_prop()->name); + if (device_name.find("h200") != std::string::npos) + return Sm90MoeDeviceProfile::H200; + if (device_name.find("h20") != std::string::npos) + return Sm90MoeDeviceProfile::H20; + return Sm90MoeDeviceProfile::Generic; +} + +struct Sm90MoeProfileConfig { + // A zero wave count means "use the generic SM-count based computation". 
+ int num_experts_per_wave, num_stages; + bool direct_l2_scatter, l2_nmajor_schedule, one_warp_cleanup; +}; + struct Sm90MoeHeuristicPolicy { + Sm90MoeDeviceProfile device_profile; bool split_l1_l2; int num_experts_per_rank, num_topk, intermediate_hidden; int block_m, block_n; @@ -331,9 +373,111 @@ struct Sm90MoeHeuristicPolicy { return num_experts_per_rank == 48 and num_topk == 6 and intermediate_hidden == 3072; } + bool h20_main_topk8_profile_config(Sm90MoeProfileConfig& config, + const bool& direct_l2_scatter_enabled, + const bool& eplb_hint, + const bool& skew_hint, + const bool& masked_hint) const { + int wave_override = 0; + if (expected_tokens_per_expert == 8.0f or + expected_tokens_per_expert == 128.0f or + (expected_tokens_per_expert >= 256.0f and expected_tokens_per_expert < 512.0f)) { + wave_override = 16; + } + + const bool direct_l2_scatter_enabled_by_profile = + expected_is_one_of(2, 4, 8, 16, 32, 64, 76, 80, 88, 128) or + expected_is_between(64.0f, 80.0f) or + expected_is_between(96.0f, 120.0f) or + expected_tokens_per_expert >= 144.0f; + + const bool l2_nmajor_schedule_enabled = [&]() { + if (expected_tokens_per_expert == 256.0f and eplb_hint) + return false; + if (expected_tokens_per_expert >= 256.0f and skew_hint) + return false; + return expected_tokens_per_expert >= 256.0f; + }(); + + const bool one_warp_cleanup_enabled = expected_tokens_per_expert <= 80.0f; + const bool stage5_pipeline_enabled = [&]() { + if (not direct_l2_scatter_enabled) + return false; + const bool hinted_m64 = + (eplb_hint or skew_hint or masked_hint) and expected_tokens_per_expert == 64.0f; + return expected_is_one_of(2, 4, 16, 32, 128) or + hinted_m64 or + expected_tokens_per_expert >= 192.0f; + }(); + + config = { + wave_override, + stage5_pipeline_enabled ? 
5 : 4, + direct_l2_scatter_enabled_by_profile, + l2_nmajor_schedule_enabled, + one_warp_cleanup_enabled + }; + return true; + } + + bool h200_main_topk8_profile_config(Sm90MoeProfileConfig& config) const { + // Calibrated from the H200 0601 sweep. Buckets are keyed by + // expected_tokens_per_expert = M * topk / experts_per_rank. + // The sweep labels include some requested `d0_s5` rows, but BN256 with + // direct L2 scatter disabled is capped at 4 stages by SM90 SMEM limits. + if (expected_tokens_per_expert <= 3.0f) { + config = {32, 4, true, true, false}; + } else if (expected_tokens_per_expert <= 6.0f) { + config = {32, 4, false, true, true}; + } else if (expected_tokens_per_expert <= 12.0f) { + config = {32, 4, true, false, true}; + } else if (expected_tokens_per_expert <= 24.0f) { + config = {32, 4, false, true, true}; + } else if (expected_tokens_per_expert <= 48.0f) { + config = {32, 4, true, false, true}; + } else if (expected_tokens_per_expert <= 64.5f) { + config = {32, 4, false, true, true}; + } else if (expected_tokens_per_expert <= 160.0f) { + config = {32, 4, false, true, false}; + } else if (expected_tokens_per_expert <= 240.0f) { + config = {32, 4, false, true, false}; + } else if (expected_tokens_per_expert <= 384.0f) { + config = {16, 4, false, true, false}; + } else if (expected_tokens_per_expert <= 640.0f) { + config = {32, 4, false, true, true}; + } else if (expected_tokens_per_expert <= 896.0f) { + config = {32, 4, false, true, false}; + } else if (expected_tokens_per_expert <= 1536.0f) { + config = {32, 4, false, true, true}; + } else { + config = {32, 4, false, true, false}; + } + return true; + } + + bool device_profile_config(Sm90MoeProfileConfig& config, + const bool& direct_l2_scatter_enabled = false, + const bool& eplb_hint = false, + const bool& skew_hint = false, + const bool& masked_hint = false) const { + if (not uses_split_bn256() or not is_main_topk8()) + return false; + + if (device_profile == Sm90MoeDeviceProfile::H20) { + 
return h20_main_topk8_profile_config( + config, direct_l2_scatter_enabled, eplb_hint, skew_hint, masked_hint); + } + if (device_profile == Sm90MoeDeviceProfile::H200) + return h200_main_topk8_profile_config(config); + return false; + } + int experts_per_wave_override() const { if (not (block_m == 64 and block_n == 256)) return 0; + Sm90MoeProfileConfig profile_config; + if (device_profile_config(profile_config)) + return profile_config.num_experts_per_wave; if (is_hopper_topk6() and expected_tokens_per_expert >= 8.0f and expected_tokens_per_expert <= 32.0f) return 16; if (is_main_topk8() and expected_tokens_per_expert == 8.0f) @@ -348,6 +492,9 @@ struct Sm90MoeHeuristicPolicy { bool direct_l2_scatter() const { if (not uses_split_bn256()) return false; + Sm90MoeProfileConfig profile_config; + if (device_profile_config(profile_config)) + return profile_config.direct_l2_scatter; if (is_main_topk8()) { return expected_is_one_of(2, 4, 8, 16, 32, 64, 76, 80, 88, 128) or expected_is_between(64.0f, 80.0f) or @@ -364,6 +511,9 @@ struct Sm90MoeHeuristicPolicy { bool l2_nmajor_schedule(const bool& eplb_hint, const bool& skew_hint) const { if (not uses_split_bn256() or not is_main_topk8()) return false; + Sm90MoeProfileConfig profile_config; + if (device_profile_config(profile_config, false, eplb_hint, skew_hint)) + return profile_config.l2_nmajor_schedule; if (expected_tokens_per_expert == 256.0f and eplb_hint) return false; if (expected_tokens_per_expert >= 256.0f and skew_hint) @@ -374,6 +524,9 @@ struct Sm90MoeHeuristicPolicy { bool one_warp_cleanup(const bool& masked_hint) const { if (not uses_split_bn256()) return false; + Sm90MoeProfileConfig profile_config; + if (device_profile_config(profile_config, false, false, false, masked_hint)) + return profile_config.one_warp_cleanup; if (is_main_topk8() and expected_tokens_per_expert <= 80.0f) return true; if (is_hopper_topk6() and masked_hint and expected_tokens_per_expert == 64.0f) @@ -385,6 +538,10 @@ struct 
Sm90MoeHeuristicPolicy { const bool& eplb_hint, const bool& skew_hint, const bool& masked_hint) const { + Sm90MoeProfileConfig profile_config; + if (device_profile_config( + profile_config, direct_l2_scatter_enabled, eplb_hint, skew_hint, masked_hint)) + return profile_config.num_stages == 5; if (not direct_l2_scatter_enabled) return false; if (is_main_topk8()) { @@ -407,6 +564,7 @@ static Sm90MoeHeuristicPolicy get_sm90_moe_heuristic_policy( const int& num_experts_per_rank, const int& num_tokens, const int& num_topk, const int& intermediate_hidden, const int& block_m, const int& block_n) { return { + get_sm90_moe_device_profile(), get_sm90_moe_split_l1_l2_default(), num_experts_per_rank, num_topk, From b03343004d8fc0edc20829911e18c4b31af03726 Mon Sep 17 00:00:00 2001 From: AichenF Date: Tue, 2 Jun 2026 14:10:14 +0800 Subject: [PATCH 08/15] Refactor SM90 MegaMoE split kernels --- csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp | 42 +-- .../deep_gemm/impls/sm90_fp8_mega_moe.cuh | 323 +++++++++++------- .../include/deep_gemm/scheduler/mega_moe.cuh | 39 +-- tests/test_mega_moe_hopper.py | 35 +- 4 files changed, 267 insertions(+), 172 deletions(-) diff --git a/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp b/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp index f189c6e1ae..78963bd046 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp @@ -31,6 +31,12 @@ namespace deep_gemm { class SM90FP8MegaMoERuntime final : public LaunchRuntime { public: + enum class KernelPhase { + Fused, + Linear1, + Linear2 + }; + struct Args { // Templated arguments int num_max_tokens_per_rank; @@ -48,7 +54,7 @@ class SM90FP8MegaMoERuntime final : public LaunchRuntime bool l2_nmajor_schedule; bool l1_nmajor_schedule; bool one_warp_cleanup; - int split_phase_mode; + KernelPhase kernel_phase; MegaMoESM90Config config; // Runtime arguments @@ -75,13 +81,15 @@ class SM90FP8MegaMoERuntime final : public LaunchRuntime }; static std::string 
generate_impl(const Args& args) { + const char* kernel_symbol = args.kernel_phase == KernelPhase::Linear1 ? "sm90_fp8_mega_moe_l1_impl" : + (args.kernel_phase == KernelPhase::Linear2 ? "sm90_fp8_mega_moe_l2_impl" : "sm90_fp8_mega_moe_impl"); return fmt::format(R"( #include using namespace deep_gemm; static void __instantiate_kernel() {{ - auto ptr = reinterpret_cast(&sm90_fp8_mega_moe_impl< + auto ptr = reinterpret_cast(&{}< {}, {}, {}, {}, {}, @@ -103,11 +111,11 @@ static void __instantiate_kernel() {{ {}, {}, {}, - {}, {} >); }}; )", + kernel_symbol, args.num_max_tokens_per_rank, args.hidden, args.intermediate_hidden, args.num_experts, args.num_topk, @@ -129,8 +137,7 @@ static void __instantiate_kernel() {{ args.l1_dual_k_accum ? "true" : "false", args.l2_nmajor_schedule ? "true" : "false", args.l1_nmajor_schedule ? "true" : "false", - args.one_warp_cleanup ? "true" : "false", - args.split_phase_mode); + args.one_warp_cleanup ? "true" : "false"); } static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { @@ -244,20 +251,14 @@ static void sm90_fp8_mega_moe( .fast_math = fast_math, .async_l1_tma_store = get_env("DG_SM90_MOE_ASYNC_L1_STORE", 0) != 0, .split_sfa_tma = get_env("DG_SM90_MOE_SPLIT_SFA_TMA", 0) != 0, - .direct_l2_scatter = get_env( - "DG_SM90_MOE_DIRECT_L2_SCATTER", - config.direct_l2_scatter ? 1 : 0) != 0, + .direct_l2_scatter = config.direct_l2_scatter, .l2_dual_accum = get_env("DG_SM90_MOE_L2_DUAL_ACCUM", 0) != 0, .phase_profile = get_env("DG_SM90_MOE_PHASE_PROFILE", 0) != 0, .l1_dual_k_accum = get_env("DG_SM90_MOE_L1_DUAL_K", 0) != 0, - .l2_nmajor_schedule = get_env( - "DG_SM90_MOE_L2_NMAJOR", - config.l2_nmajor_schedule ? 1 : 0) != 0, + .l2_nmajor_schedule = config.l2_nmajor_schedule, .l1_nmajor_schedule = get_env("DG_SM90_MOE_L1_NMAJOR", 0) != 0, - .one_warp_cleanup = get_env( - "DG_SM90_MOE_ONE_WARP_CLEANUP", - config.one_warp_cleanup ? 
1 : 0) != 0, - .split_phase_mode = 0, + .one_warp_cleanup = config.one_warp_cleanup, + .kernel_phase = SM90FP8MegaMoERuntime::KernelPhase::Fused, .config = config, .y = y.data_ptr(), .cumulative_local_expert_recv_stats = cumulative_local_expert_recv_stats_ptr, @@ -275,9 +276,10 @@ static void sm90_fp8_mega_moe( .launch_args = LaunchArgs(num_sms, config.num_dispatch_threads + config.num_non_epilogue_threads + config.num_epilogue_threads, config.smem_size, config.cluster_size) }; - const auto launch_with_split_mode = [&](const int split_phase_mode, const char* kernel_name) { + const auto launch_with_phase = [&](const SM90FP8MegaMoERuntime::KernelPhase kernel_phase, + const char* kernel_name) { auto split_args = args; - split_args.split_phase_mode = split_phase_mode; + split_args.kernel_phase = kernel_phase; const auto code = SM90FP8MegaMoERuntime::generate(split_args); const auto runtime = compiler->build(kernel_name, code); SM90FP8MegaMoERuntime::launch(runtime, split_args); @@ -285,10 +287,10 @@ static void sm90_fp8_mega_moe( const bool split_l1_l2 = get_env("DG_SM90_MOE_SPLIT_L1_L2", 1) != 0; if (split_l1_l2) { - launch_with_split_mode(1, "sm90_fp8_mega_moe_split_l1"); - launch_with_split_mode(2, "sm90_fp8_mega_moe_split_l2"); + launch_with_phase(SM90FP8MegaMoERuntime::KernelPhase::Linear1, "sm90_fp8_mega_moe_l1_impl"); + launch_with_phase(SM90FP8MegaMoERuntime::KernelPhase::Linear2, "sm90_fp8_mega_moe_l2_impl"); } else { - launch_with_split_mode(0, "sm90_fp8_mega_moe"); + launch_with_phase(SM90FP8MegaMoERuntime::KernelPhase::Fused, "sm90_fp8_mega_moe_impl"); } } diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh index c98c326700..fd9310ed98 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh @@ -50,57 +50,152 @@ namespace deep_gemm { // reduction in BF16) — ported verbatim from the SM100 kernel. 
// ============================================================================ -template < - uint32_t kNumMaxTokensPerRank, - uint32_t kHidden, uint32_t kIntermediateHidden, - uint32_t kNumExperts, uint32_t kNumTopk, - uint32_t kNumExpertsPerWave, - uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K, - uint32_t kNumMaxPoolTokens, - uint32_t kNumPaddedSFPoolTokens, - uint32_t kNumStages, - uint32_t kNumDispatchThreads, uint32_t kNumNonEpilogueThreads, - uint32_t kNumEpilogueThreads, - uint32_t kClusterSize, - uint32_t kNumSMs, uint32_t kNumRanks, - float kActivationClamp, - bool kFastMath, - bool kAsyncL1TMAStoreRequested = true, - bool kSplitSFATMARequested = false, - bool kDirectL2ScatterRequested = false, - bool kL2DualAccumRequested = false, - bool kPhaseProfileRequested = false, - bool kL1DualKAccumRequested = false, - bool kL2NMajorScheduleRequested = false, - bool kL1NMajorScheduleRequested = false, - bool kOneWarpCleanupRequested = false, - uint32_t kSplitPhaseMode = 0, - uint32_t L1_SHAPE_N = kIntermediateHidden * 2, - uint32_t L1_SHAPE_K = kHidden, - uint32_t L2_SHAPE_N = kHidden, - uint32_t L2_SHAPE_K = kIntermediateHidden, - uint32_t kNumDispatchWarps = kNumDispatchThreads / 32, - uint32_t kNumMMANonEpilogueWarps = kNumNonEpilogueThreads / 32, - uint32_t kNumEpilogueWarps = kNumEpilogueThreads / 32, - uint32_t kNumEpilogueWarpgroups = kNumEpilogueWarps / 4, - uint32_t kNumThreads = kNumDispatchThreads + kNumNonEpilogueThreads + kNumEpilogueThreads, - uint32_t kNumTokensPerWarp = 32 / kNumTopk, +enum class MegaMoEPhaseKind { + Fused, + Linear1, + Linear2 +}; + +template +struct MegaMoEPhasePolicy { + static constexpr bool is_fused = kKind == MegaMoEPhaseKind::Fused; + static constexpr bool is_linear1_only = kKind == MegaMoEPhaseKind::Linear1; + static constexpr bool is_linear2_only = kKind == MegaMoEPhaseKind::Linear2; + static constexpr bool runs_linear1 = is_fused or is_linear1_only; + static constexpr bool runs_linear2 = is_fused or is_linear2_only; 
+ static constexpr bool needs_dispatch_pull = runs_linear1; + static constexpr bool needs_combine = runs_linear2; + static constexpr bool signals_l1_ready = is_fused; + static constexpr bool waits_for_l2_arrival_mask = is_fused; + static constexpr bool cleans_l1_dispatch_state = is_fused; + + CUTLASS_DEVICE static bool is_linear1_phase(const sched::BlockPhase& block_phase) { + if constexpr (is_linear1_only) { + (void)block_phase; + return true; + } else if constexpr (is_linear2_only) { + (void)block_phase; + return false; + } else { + return block_phase == sched::BlockPhase::Linear1; + } + } + + template + CUTLASS_DEVICE static void for_each_selected_block(Scheduler& scheduler, Func&& func) { + if constexpr (is_linear1_only) { + scheduler.template for_each_phase_block( + [&](const uint32_t& local_expert_idx, const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + func(sched::BlockPhase::Linear1, local_expert_idx, + num_k_blocks, m_block_idx, n_block_idx); + }); + } else if constexpr (is_linear2_only) { + scheduler.template for_each_phase_block( + [&](const uint32_t& local_expert_idx, const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + func(sched::BlockPhase::Linear2, local_expert_idx, + num_k_blocks, m_block_idx, n_block_idx); + }); + } else { + scheduler.for_each_block(func); + } + } +}; + +using MegaMoEFusedPhase = MegaMoEPhasePolicy; +using MegaMoELinear1Phase = MegaMoEPhasePolicy; +using MegaMoELinear2Phase = MegaMoEPhasePolicy; + +#define DG_SM90_FP8_MOE_TEMPLATE_PARAMS \ + uint32_t kNumMaxTokensPerRank, \ + uint32_t kHidden, uint32_t kIntermediateHidden, \ + uint32_t kNumExperts, uint32_t kNumTopk, \ + uint32_t kNumExpertsPerWave, \ + uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K, \ + uint32_t kNumMaxPoolTokens, \ + uint32_t kNumPaddedSFPoolTokens, \ + uint32_t kNumStages, \ + uint32_t kNumDispatchThreads, uint32_t kNumNonEpilogueThreads, \ + uint32_t kNumEpilogueThreads, \ 
+ uint32_t kClusterSize, \ + uint32_t kNumSMs, uint32_t kNumRanks, \ + float kActivationClamp, \ + bool kFastMath, \ + bool kAsyncL1TMAStoreRequested = true, \ + bool kSplitSFATMARequested = false, \ + bool kDirectL2ScatterRequested = false, \ + bool kL2DualAccumRequested = false, \ + bool kPhaseProfileRequested = false, \ + bool kL1DualKAccumRequested = false, \ + bool kL2NMajorScheduleRequested = false, \ + bool kL1NMajorScheduleRequested = false, \ + bool kOneWarpCleanupRequested = false, \ + uint32_t L1_SHAPE_N = kIntermediateHidden * 2, \ + uint32_t L1_SHAPE_K = kHidden, \ + uint32_t L2_SHAPE_N = kHidden, \ + uint32_t L2_SHAPE_K = kIntermediateHidden, \ + uint32_t kNumDispatchWarps = kNumDispatchThreads / 32, \ + uint32_t kNumMMANonEpilogueWarps = kNumNonEpilogueThreads / 32, \ + uint32_t kNumEpilogueWarps = kNumEpilogueThreads / 32, \ + uint32_t kNumEpilogueWarpgroups = kNumEpilogueWarps / 4, \ + uint32_t kNumThreads = kNumDispatchThreads + kNumNonEpilogueThreads + kNumEpilogueThreads, \ + uint32_t kNumTokensPerWarp = 32 / kNumTopk, \ uint32_t kNumExpertsPerRank = kNumExperts / kNumRanks -> -CUTLASS_GLOBAL __launch_bounds__(kNumThreads, 1) void -sm90_fp8_mega_moe_impl(void* y, - int* cumulative_local_expert_recv_stats, - const uint32_t num_tokens, - const __grid_constant__ layout::SymBuffer sym_buffer, - const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts, - const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts_sf, - const __grid_constant__ cute::TmaDescriptor tensor_map_l1_weights, - const float* __restrict__ l1_weights_sf, - const __grid_constant__ cute::TmaDescriptor tensor_map_l1_output, - const __grid_constant__ cute::TmaDescriptor tensor_map_l2_acts, - const __grid_constant__ cute::TmaDescriptor tensor_map_l2_acts_sf, - const __grid_constant__ cute::TmaDescriptor tensor_map_l2_weights, - const float* __restrict__ l2_weights_sf) { + +#define DG_SM90_FP8_MOE_KERNEL_ARGS_DECL \ + void* y, \ + int* cumulative_local_expert_recv_stats, \ + 
const uint32_t num_tokens, \ + const __grid_constant__ layout::SymBuffer sym_buffer, \ + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts, \ + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts_sf, \ + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_weights, \ + const float* __restrict__ l1_weights_sf, \ + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_output, \ + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_acts, \ + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_acts_sf, \ + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_weights, \ + const float* __restrict__ l2_weights_sf + +#define DG_SM90_FP8_MOE_CORE_ARGS_DECL \ + void* y, \ + int* cumulative_local_expert_recv_stats, \ + const uint32_t num_tokens, \ + const layout::SymBuffer& sym_buffer, \ + const cute::TmaDescriptor& tensor_map_l1_acts, \ + const cute::TmaDescriptor& tensor_map_l1_acts_sf, \ + const cute::TmaDescriptor& tensor_map_l1_weights, \ + const float* __restrict__ l1_weights_sf, \ + const cute::TmaDescriptor& tensor_map_l1_output, \ + const cute::TmaDescriptor& tensor_map_l2_acts, \ + const cute::TmaDescriptor& tensor_map_l2_acts_sf, \ + const cute::TmaDescriptor& tensor_map_l2_weights, \ + const float* __restrict__ l2_weights_sf + +#define DG_SM90_FP8_MOE_KERNEL_ARGS \ + y, cumulative_local_expert_recv_stats, num_tokens, sym_buffer, \ + tensor_map_l1_acts, tensor_map_l1_acts_sf, tensor_map_l1_weights, \ + l1_weights_sf, tensor_map_l1_output, tensor_map_l2_acts, \ + tensor_map_l2_acts_sf, tensor_map_l2_weights, l2_weights_sf + +#define DG_SM90_FP8_MOE_CORE_TEMPLATE_ARGS(PhasePolicy) \ + PhasePolicy, \ + kNumMaxTokensPerRank, kHidden, kIntermediateHidden, kNumExperts, kNumTopk, \ + kNumExpertsPerWave, BLOCK_M, BLOCK_N, BLOCK_K, kNumMaxPoolTokens, \ + kNumPaddedSFPoolTokens, kNumStages, kNumDispatchThreads, \ + kNumNonEpilogueThreads, kNumEpilogueThreads, kClusterSize, kNumSMs, \ + kNumRanks, kActivationClamp, kFastMath, 
kAsyncL1TMAStoreRequested, \ + kSplitSFATMARequested, kDirectL2ScatterRequested, kL2DualAccumRequested, \ + kPhaseProfileRequested, kL1DualKAccumRequested, kL2NMajorScheduleRequested, \ + kL1NMajorScheduleRequested, kOneWarpCleanupRequested, L1_SHAPE_N, \ + L1_SHAPE_K, L2_SHAPE_N, L2_SHAPE_K, kNumDispatchWarps, \ + kNumMMANonEpilogueWarps, kNumEpilogueWarps, kNumEpilogueWarpgroups, \ + kNumThreads, kNumTokensPerWarp, kNumExpertsPerRank + +template +CUTLASS_DEVICE void +sm90_fp8_mega_moe_core(DG_SM90_FP8_MOE_CORE_ARGS_DECL) { #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900) and (__CUDA_ARCH__ < 1000)) or defined(__CLION_IDE__) using Barrier = cutlass::arch::ClusterTransactionBarrier; @@ -119,7 +214,6 @@ sm90_fp8_mega_moe_impl(void* y, "BLOCK_M must be 16/32 for mma.sync decode or a multiple of WGMMA::M (64)"); DG_STATIC_ASSERT(BLOCK_N == 64 or BLOCK_N == 128 or BLOCK_N == 256, "BLOCK_N must be 64/128/256 for this SM90 path"); DG_STATIC_ASSERT(BLOCK_K == 128, "BLOCK_K is fixed to 128 (per-128 SF)"); - DG_STATIC_ASSERT(kSplitPhaseMode <= 2, "Invalid SM90 MegaMoE split phase mode"); // ===================================================================== // Thread / warp identification @@ -131,13 +225,13 @@ sm90_fp8_mega_moe_impl(void* y, // Prefetch the TMA descriptors used by this split phase. 
if (warp_idx == 0 and cute::elect_one_sync()) { - if constexpr (kSplitPhaseMode != 2) { + if constexpr (MegaMoEPhase::runs_linear1) { cute::prefetch_tma_descriptor(&tensor_map_l1_acts); cute::prefetch_tma_descriptor(&tensor_map_l1_acts_sf); cute::prefetch_tma_descriptor(&tensor_map_l1_weights); cute::prefetch_tma_descriptor(&tensor_map_l1_output); } - if constexpr (kSplitPhaseMode != 1) { + if constexpr (MegaMoEPhase::runs_linear2) { cute::prefetch_tma_descriptor(&tensor_map_l2_acts); cute::prefetch_tma_descriptor(&tensor_map_l2_acts_sf); cute::prefetch_tma_descriptor(&tensor_map_l2_weights); @@ -195,17 +289,15 @@ sm90_fp8_mega_moe_impl(void* y, constexpr uint32_t WG_BLOCK_N = (kSplitNWarpgroups || kSerialNWarpgroups) ? BLOCK_N / 2 : BLOCK_N; constexpr uint32_t L1_OUT_BLOCK_N = BLOCK_N / 2; // post-SwiGLU tile N constexpr uint32_t WG_L1_OUT_BLOCK_N = WG_BLOCK_N / 2; // post-SwiGLU per-WG N - constexpr bool kRunOnlyLinear1 = kSplitPhaseMode == 1; - constexpr bool kRunOnlyLinear2 = kSplitPhaseMode == 2; constexpr bool kAsyncL1TMAStore = - kAsyncL1TMAStoreRequested && (!kUseMMASync) && + kAsyncL1TMAStoreRequested && MegaMoEPhase::runs_linear1 && (!kUseMMASync) && (!kSplitNWarpgroups) && kNumEpilogueWarpgroups == 1; constexpr bool kSplitSFATMA = kSplitSFATMARequested && (!kUseMMASync); - constexpr bool kDirectL2Scatter = kDirectL2ScatterRequested && (!kUseMMASync) && + constexpr bool kDirectL2Scatter = kDirectL2ScatterRequested && MegaMoEPhase::runs_linear2 && (!kUseMMASync) && (!kSerialNWarpgroups) && WG_BLOCK_N == 128; - constexpr bool kL2DualAccum = kL2DualAccumRequested && (!kUseMMASync) && + constexpr bool kL2DualAccum = kL2DualAccumRequested && MegaMoEPhase::runs_linear2 && (!kUseMMASync) && (!kSplitNWarpgroups) && (!kSerialNWarpgroups) && WG_BLOCK_N == 128; - constexpr bool kL1DualKAccum = kL1DualKAccumRequested && (!kUseMMASync) && + constexpr bool kL1DualKAccum = kL1DualKAccumRequested && MegaMoEPhase::runs_linear1 && (!kUseMMASync) && (!kSplitNWarpgroups) && 
(!kSerialNWarpgroups) && WG_BLOCK_N == 128 && (kHidden / BLOCK_K) % 2 == 0; using L1WGMMA = typename mma::sm90::FP8MMASelector::type; @@ -231,10 +323,10 @@ sm90_fp8_mega_moe_impl(void* y, constexpr uint32_t kSharedMemoryAlignment = 1024; extern __shared__ __align__(kSharedMemoryAlignment) uint8_t smem_buffer[]; - constexpr uint32_t SMEM_EXPERT_COUNT_SIZE = - math::constexpr_align(kNumExperts * sizeof(uint32_t), kSharedMemoryAlignment); - constexpr uint32_t SMEM_SEND_BUFFER_SIZE = - math::constexpr_align(fp8_token_layout.get_num_bytes() * kNumDispatchWarps, kSharedMemoryAlignment); + constexpr uint32_t SMEM_EXPERT_COUNT_SIZE = MegaMoEPhase::needs_dispatch_pull ? + math::constexpr_align(kNumExperts * sizeof(uint32_t), kSharedMemoryAlignment) : 0u; + constexpr uint32_t SMEM_SEND_BUFFER_SIZE = MegaMoEPhase::needs_dispatch_pull ? + math::constexpr_align(fp8_token_layout.get_num_bytes() * kNumDispatchWarps, kSharedMemoryAlignment) : 0u; constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t); constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t); // SFA per-stage must be sized for the larger of L1 (BLOCK_M floats) and L2 @@ -253,9 +345,9 @@ sm90_fp8_mega_moe_impl(void* y, constexpr uint32_t SMEM_CD_ACCUM_SIZE = kUseMMASync ? math::constexpr_align(BLOCK_M * BLOCK_N * sizeof(float), kSharedMemoryAlignment) : 0u; - constexpr uint32_t SMEM_CD_L1_SIZE = - kNumEpilogueWarpgroups * WG_BLOCK_M * WG_L1_OUT_BLOCK_N * sizeof(cutlass::float_e4m3_t); - constexpr uint32_t SMEM_CD_L2_SIZE = kDirectL2Scatter ? 0u : + constexpr uint32_t SMEM_CD_L1_SIZE = MegaMoEPhase::runs_linear1 ? + kNumEpilogueWarpgroups * WG_BLOCK_M * WG_L1_OUT_BLOCK_N * sizeof(cutlass::float_e4m3_t) : 0u; + constexpr uint32_t SMEM_CD_L2_SIZE = (!MegaMoEPhase::runs_linear2 || kDirectL2Scatter) ? 
0u : kNumEpilogueWarpgroups * WG_BLOCK_M * WG_BLOCK_N * sizeof(nv_bfloat16); constexpr uint32_t SMEM_CD_L1_ASYNC_ELEMS = kNumEpilogueWarpgroups * WG_BLOCK_M * L1_OUT_BLOCK_N; @@ -336,7 +428,7 @@ sm90_fp8_mega_moe_impl(void* y, // empty barriers too, so every math warp releases both CTAs. empty_barriers[i]->init(kClusterSize * kNumEpilogueWarps); } - if constexpr (!kRunOnlyLinear1) { + if constexpr (MegaMoEPhase::needs_combine) { #pragma unroll for (uint32_t i = 0; i < kNumEpilogueWarps * 2; ++ i) combine_barriers[i]->init(1); @@ -424,33 +516,8 @@ sm90_fp8_mega_moe_impl(void* y, } }; - const auto should_run_block_phase = [](const sched::BlockPhase& block_phase) { - if constexpr (kRunOnlyLinear1) { - return block_phase == sched::BlockPhase::Linear1; - } else if constexpr (kRunOnlyLinear2) { - return block_phase == sched::BlockPhase::Linear2; - } else { - (void)block_phase; - return true; - } - }; - const auto for_each_selected_block = [&](auto&& func) { - if constexpr (kRunOnlyLinear1) { - scheduler.for_each_linear1_block([&](const uint32_t& local_expert_idx, - const uint32_t& num_k_blocks, - const uint32_t& m_block_idx, const uint32_t& n_block_idx) { - func(sched::BlockPhase::Linear1, local_expert_idx, num_k_blocks, m_block_idx, n_block_idx); - }); - } else if constexpr (kRunOnlyLinear2) { - scheduler.for_each_linear2_block([&](const uint32_t& local_expert_idx, - const uint32_t& num_k_blocks, - const uint32_t& m_block_idx, const uint32_t& n_block_idx) { - func(sched::BlockPhase::Linear2, local_expert_idx, num_k_blocks, m_block_idx, n_block_idx); - }); - } else { - scheduler.for_each_block(func); - } + MegaMoEPhase::for_each_selected_block(scheduler, func); }; const auto cleanup_workspace = [&]() { @@ -488,7 +555,7 @@ sm90_fp8_mega_moe_impl(void* y, } } - if constexpr (!kRunOnlyLinear2) { + if constexpr (MegaMoEPhase::cleans_l1_dispatch_state) { if constexpr (kOneWarpCleanupRequested) { if (warp_idx == 0) { for (uint32_t j = lane_idx; j < kNumRanks; j += 32) @@ 
-506,7 +573,7 @@ sm90_fp8_mega_moe_impl(void* y, if (warp_idx == 0) { for (uint32_t j = lane_idx; j < num_recv_m_blocks; j += 32) { *workspace.get_l1_arrival_count_ptr(cleanup_pool_block_offset + j) = 0; - if constexpr (!kRunOnlyLinear2) + if constexpr (MegaMoEPhase::cleans_l1_dispatch_state) *workspace.get_l2_arrival_mask_ptr(cleanup_pool_block_offset + j) = 0; } __syncwarp(); @@ -514,7 +581,7 @@ sm90_fp8_mega_moe_impl(void* y, } else { for (uint32_t j = thread_idx; j < num_recv_m_blocks; j += kNumDispatchThreads) { *workspace.get_l1_arrival_count_ptr(cleanup_pool_block_offset + j) = 0; - if constexpr (!kRunOnlyLinear2) + if constexpr (MegaMoEPhase::cleans_l1_dispatch_state) *workspace.get_l2_arrival_mask_ptr(cleanup_pool_block_offset + j) = 0; } __syncwarp(); @@ -536,7 +603,7 @@ sm90_fp8_mega_moe_impl(void* y, cutlass::arch::warpgroup_reg_dealloc(); const unsigned long long dispatch_total_start = phase_profile_clock(); - if constexpr (kRunOnlyLinear2) { + if constexpr (MegaMoEPhase::is_linear2_only) { scheduler.fetch_expert_recv_count(); ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); @@ -767,7 +834,7 @@ sm90_fp8_mega_moe_impl(void* y, phase_profile_record(kProfileDispatchPull, dispatch_pull_end - dispatch_pull_start); phase_profile_record(kProfileDispatchTotal, dispatch_pull_end - dispatch_total_start); } - if constexpr (kRunOnlyLinear1) + if constexpr (MegaMoEPhase::is_linear1_only) return; ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); @@ -792,10 +859,7 @@ sm90_fp8_mega_moe_impl(void* y, const uint32_t& local_expert_idx, const uint32_t& num_k_blocks, const uint32_t& m_block_idx, const uint32_t& n_block_idx) { - if (!should_run_block_phase(block_phase)) - return; - const bool is_linear1_phase = - kRunOnlyLinear1 ? true : (kRunOnlyLinear2 ? 
false : block_phase == sched::BlockPhase::Linear1); + const bool is_linear1_phase = MegaMoEPhase::is_linear1_phase(block_phase); const auto tensor_map_a_ptr = !is_linear1_phase ? &tensor_map_l2_acts : &tensor_map_l1_acts; const auto tensor_map_sfa_ptr = !is_linear1_phase @@ -812,7 +876,7 @@ sm90_fp8_mega_moe_impl(void* y, const auto ptr = workspace.get_l1_arrival_count_ptr(pool_block_idx); const auto expected = valid_m; while (ptx::ld_acq(ptr) != expected); - } else if constexpr (!kRunOnlyLinear2) { + } else if constexpr (MegaMoEPhase::waits_for_l2_arrival_mask) { const auto ptr = workspace.get_l2_arrival_mask_ptr(pool_block_idx); // Each L1 N block sets one bit; total bits = L1_SHAPE_N / BLOCK_N. constexpr uint32_t kNumL1BlockNs = L1_SHAPE_N / BLOCK_N; @@ -875,10 +939,7 @@ sm90_fp8_mega_moe_impl(void* y, const uint32_t& local_expert_idx, const uint32_t& num_k_blocks, const uint32_t& m_block_idx, const uint32_t& n_block_idx) { - if (!should_run_block_phase(block_phase)) - return; - const bool is_linear1_phase = - kRunOnlyLinear1 ? true : (kRunOnlyLinear2 ? false : block_phase == sched::BlockPhase::Linear1); + const bool is_linear1_phase = MegaMoEPhase::is_linear1_phase(block_phase); const auto tensor_map_b_ptr = !is_linear1_phase ? &tensor_map_l2_weights : &tensor_map_l1_weights; @@ -909,10 +970,7 @@ sm90_fp8_mega_moe_impl(void* y, const uint32_t& local_expert_idx, const uint32_t& num_k_blocks, const uint32_t& m_block_idx, const uint32_t& n_block_idx) { - if (!should_run_block_phase(block_phase)) - return; - const bool is_linear1_phase = - kRunOnlyLinear1 ? true : (kRunOnlyLinear2 ? 
false : block_phase == sched::BlockPhase::Linear1); + const bool is_linear1_phase = MegaMoEPhase::is_linear1_phase(block_phase); (void)local_expert_idx; (void)n_block_idx; const auto tensor_map_sfa_ptr = !is_linear1_phase @@ -927,7 +985,7 @@ sm90_fp8_mega_moe_impl(void* y, const auto ptr = workspace.get_l1_arrival_count_ptr(pool_block_idx); const auto expected = valid_m; while (ptx::ld_acq(ptr) != expected); - } else if constexpr (!kRunOnlyLinear2) { + } else if constexpr (MegaMoEPhase::waits_for_l2_arrival_mask) { const auto ptr = workspace.get_l2_arrival_mask_ptr(pool_block_idx); constexpr uint32_t kNumL1BlockNs = L1_SHAPE_N / BLOCK_N; const uint64_t expected = (kNumL1BlockNs >= 64) @@ -1000,8 +1058,7 @@ sm90_fp8_mega_moe_impl(void* y, const auto notify_l1_ready = [&](const uint32_t& ready_pool_block_idx, const uint32_t& ready_n_block_idx) { - // True split L2 runs in a later kernel launch and skips arrival-mask waits. - if constexpr (!kRunOnlyLinear1) { + if constexpr (MegaMoEPhase::signals_l1_ready) { if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { ptx::red_or_rel_gpu( workspace.get_l2_arrival_mask_ptr(ready_pool_block_idx), @@ -1073,10 +1130,7 @@ sm90_fp8_mega_moe_impl(void* y, const uint32_t& local_expert_idx, const uint32_t& num_k_blocks, const uint32_t& m_block_idx, const uint32_t& n_block_idx) { - if (!should_run_block_phase(block_phase)) - return; - const bool is_linear1_phase = - kRunOnlyLinear1 ? true : (kRunOnlyLinear2 ? 
false : block_phase == sched::BlockPhase::Linear1); + const bool is_linear1_phase = MegaMoEPhase::is_linear1_phase(block_phase); const uint32_t valid_m = scheduler.template get_valid_m(); const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx; const uint32_t m_idx = pool_block_idx * BLOCK_M; @@ -2177,7 +2231,7 @@ for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_bl async_l1_store_stage ^= 1u; } else { ptx::tma_store_wait<0>(); - if constexpr (!kRunOnlyLinear1) { + if constexpr (MegaMoEPhase::signals_l1_ready) { ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); notify_l1_ready(pool_block_idx, n_block_idx); } @@ -2203,7 +2257,7 @@ for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_bl async_l1_store_stage ^= 1u; } else { ptx::tma_store_wait<0>(); - if constexpr (!kRunOnlyLinear1) { + if constexpr (MegaMoEPhase::signals_l1_ready) { ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); notify_l1_ready(pool_block_idx, n_block_idx); } @@ -2365,7 +2419,7 @@ for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_bl if (epilogue_warp_idx == 0 and lane_idx == 0) phase_profile_record(kProfileMathLoop, math_loop_end - math_loop_start); - if constexpr (kRunOnlyLinear1) { + if constexpr (!MegaMoEPhase::needs_combine) { if constexpr (kAsyncL1TMAStore) drain_all_async_l1_stores(); return; @@ -2503,6 +2557,33 @@ for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_bl #endif } +template +CUTLASS_GLOBAL __launch_bounds__(kNumThreads, 1) void +sm90_fp8_mega_moe_impl(DG_SM90_FP8_MOE_KERNEL_ARGS_DECL) { + sm90_fp8_mega_moe_core( + DG_SM90_FP8_MOE_KERNEL_ARGS); +} + +template +CUTLASS_GLOBAL __launch_bounds__(kNumThreads, 1) void +sm90_fp8_mega_moe_l1_impl(DG_SM90_FP8_MOE_KERNEL_ARGS_DECL) { + sm90_fp8_mega_moe_core( + DG_SM90_FP8_MOE_KERNEL_ARGS); +} + +template +CUTLASS_GLOBAL __launch_bounds__(kNumThreads, 1) void 
+sm90_fp8_mega_moe_l2_impl(DG_SM90_FP8_MOE_KERNEL_ARGS_DECL) { + sm90_fp8_mega_moe_core( + DG_SM90_FP8_MOE_KERNEL_ARGS); +} + +#undef DG_SM90_FP8_MOE_TEMPLATE_PARAMS +#undef DG_SM90_FP8_MOE_KERNEL_ARGS_DECL +#undef DG_SM90_FP8_MOE_CORE_ARGS_DECL +#undef DG_SM90_FP8_MOE_KERNEL_ARGS +#undef DG_SM90_FP8_MOE_CORE_TEMPLATE_ARGS + } // namespace deep_gemm #pragma clang diagnostic pop diff --git a/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh b/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh index dd3b077df6..2c2fa82782 100644 --- a/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh @@ -243,33 +243,26 @@ struct MegaMoEScheduler { } } - template - CUTLASS_DEVICE void for_each_linear1_block(Func&& func) { - // Split-kernel mode: K1 owns only dispatch + Linear1. Unlike - // for_each_block(), do not burn scheduler iterations on Linear2 blocks. - fetch_expert_recv_count(); - set_expert_idx(0); - while (current_local_expert_idx < kNumExpertsPerRank) { - if (fetch_next_l1_block()) { - block_idx += kNumSMs / kClusterSize; - func(current_local_expert_idx, kNumL1BlockKs, m_block_idx, n_block_idx); - } else if (current_local_expert_idx >= kNumExpertsPerRank) { - break; - } - } - } - - template - CUTLASS_DEVICE void for_each_linear2_block(Func&& func) { - // Split-kernel mode: K2 starts after K1 has completed globally, so all - // L2-ready masks are already final. Schedule Linear2 blocks directly - // from the phase start, like a standalone grouped GEMM. + template + CUTLASS_DEVICE void for_each_phase_block(Func&& func) { + DG_STATIC_ASSERT(kPhase == BlockPhase::Linear1 or kPhase == BlockPhase::Linear2, + "Invalid MegaMoE scheduler phase"); + // Split-kernel mode schedules one phase directly instead of burning + // scheduler iterations on the other phase. 
fetch_expert_recv_count(); set_expert_idx(0); while (current_local_expert_idx < kNumExpertsPerRank) { - if (fetch_next_l2_block()) { + const bool found_block = [&]() { + if constexpr (kPhase == BlockPhase::Linear1) + return fetch_next_l1_block(); + else + return fetch_next_l2_block(); + }(); + if (found_block) { block_idx += kNumSMs / kClusterSize; - func(current_local_expert_idx, kNumL2BlockKs, m_block_idx, n_block_idx); + constexpr uint32_t kNumPhaseBlockKs = + kPhase == BlockPhase::Linear1 ? kNumL1BlockKs : kNumL2BlockKs; + func(current_local_expert_idx, kNumPhaseBlockKs, m_block_idx, n_block_idx); } else if (current_local_expert_idx >= kNumExpertsPerRank) { break; } diff --git a/tests/test_mega_moe_hopper.py b/tests/test_mega_moe_hopper.py index 5db2b4adc8..ca6d38c539 100644 --- a/tests/test_mega_moe_hopper.py +++ b/tests/test_mega_moe_hopper.py @@ -36,9 +36,11 @@ _deep_ep_import_error = ex -# 与 deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh 中模板入口同名, -# bench_kineto 用它从 trace 里挑出 fused mega-MoE 的 GPU 段 -SM90_KERNEL_NAME = "sm90_fp8_mega_moe_impl" +# 与 deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh 中模板入口前缀同名, +# bench_kineto 用它从 trace 里汇总 fused 或 split mega-MoE 的 GPU 段 +SM90_KERNEL_NAME_PREFIX = "sm90_fp8_mega_moe" +SPLIT_VS_ONE_MAX_ABS_REL_TOL = 1e-4 +FUSED_VS_LEGACY_MEAN_ABS_REL_TOL = 7e-2 # FP8 e4m3fn 的最大可表示值,量化时用 amax / 448 作为 scale 基准 @@ -1141,16 +1143,26 @@ def run_baseline(): os.environ["DG_SM90_MOE_SPLIT_L1_L2"] = old_split_env diff = (y_split.float() - y_one_kernel.float()).abs() denom = y_one_kernel.float().abs().mean().clamp_min(1e-12) + max_abs = diff.max() + mean_abs_rel = diff.mean().div(denom) + max_abs_threshold = denom * SPLIT_VS_ONE_MAX_ABS_REL_TOL dist_print( "Output diff (split two-kernel vs one-kernel):", once_in_node=True ) dist_print( - f" > max_abs={diff.max().item():.6e}, " + f" > max_abs={max_abs.item():.6e}, " f"mean_abs={diff.mean().item():.6e}, " - f"mean_abs/mean_ref={diff.mean().div(denom).item():.6e}", + 
f"mean_abs/mean_ref={mean_abs_rel.item():.6e}, " + f"max_abs_threshold={max_abs_threshold.item():.6e}", once_in_node=True, ) dist_print(once_in_node=True) + assert max_abs <= max_abs_threshold, ( + "split two-kernel output differs from one-kernel output: " + f"max_abs={max_abs.item():.6e}, " + f"threshold={max_abs_threshold.item():.6e}, " + f"mean_ref={denom.item():.6e}" + ) if ep_buffer is not None: out_b = run_baseline() assert out_b.shape == (num_tokens, hidden) and out_b.dtype == torch.bfloat16, ( @@ -1159,16 +1171,23 @@ def run_baseline(): if args.check_output_diff: diff = (y.float() - out_b.float()).abs() denom = out_b.float().abs().mean().clamp_min(1e-12) + mean_abs_rel = diff.mean().div(denom) dist_print( "Output diff (fused vs legacy-per128 baseline):", once_in_node=True ) dist_print( f" > max_abs={diff.max().item():.6e}, " f"mean_abs={diff.mean().item():.6e}, " - f"mean_abs/mean_ref={diff.mean().div(denom).item():.6e}", + f"mean_abs/mean_ref={mean_abs_rel.item():.6e}, " + f"threshold={FUSED_VS_LEGACY_MEAN_ABS_REL_TOL:.6e}", once_in_node=True, ) dist_print(once_in_node=True) + assert mean_abs_rel <= FUSED_VS_LEGACY_MEAN_ABS_REL_TOL, ( + "fused output differs from legacy-per128 baseline: " + f"mean_abs/mean_ref={mean_abs_rel.item():.6e}, " + f"threshold={FUSED_VS_LEGACY_MEAN_ABS_REL_TOL:.6e}" + ) # ---- 统计本 rank 实际接收的 token 数与触达的 expert 数 ---- # 把所有 rank 的 topk_idx 收齐,再把不落在本 rank 持有 expert 范围内的条目 @@ -1207,12 +1226,12 @@ def run_baseline(): max_expert_tokens = 0 # ---- benchmark ---- - # fused:bench_kineto 抓 sm90_fp8_mega_moe_impl 的 GPU 段(不含 host overhead) + # fused:bench_kineto 抓 sm90_fp8_mega_moe* 的 GPU 段(不含 host overhead) if phase_profile_enabled: cum_stats_fused.zero_() t_fused = bench_kineto( run_fused, - SM90_KERNEL_NAME, + SM90_KERNEL_NAME_PREFIX, num_tests=args.num_bench_tests, barrier=lambda: ep_buffer.barrier(use_comm_stream=False) if ep_buffer is not None From e35a3e3b8f95e9568207c41befc9ed22435d8d52 Mon Sep 17 00:00:00 2001 From: AichenF Date: 
Tue, 2 Jun 2026 15:11:13 +0800 Subject: [PATCH 09/15] Keep SM90 split L2 combine smem capacity --- .../include/deep_gemm/impls/sm90_fp8_mega_moe.cuh | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh index fd9310ed98..995db3ca4b 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh @@ -323,10 +323,12 @@ sm90_fp8_mega_moe_core(DG_SM90_FP8_MOE_CORE_ARGS_DECL) { constexpr uint32_t kSharedMemoryAlignment = 1024; extern __shared__ __align__(kSharedMemoryAlignment) uint8_t smem_buffer[]; - constexpr uint32_t SMEM_EXPERT_COUNT_SIZE = MegaMoEPhase::needs_dispatch_pull ? - math::constexpr_align(kNumExperts * sizeof(uint32_t), kSharedMemoryAlignment) : 0u; - constexpr uint32_t SMEM_SEND_BUFFER_SIZE = MegaMoEPhase::needs_dispatch_pull ? - math::constexpr_align(fp8_token_layout.get_num_bytes() * kNumDispatchWarps, kSharedMemoryAlignment) : 0u; + // Combine reuses the pre-barrier SMEM region, so split L2 keeps this + // dispatch scratch capacity even though it does not run dispatch pull. 
+ constexpr uint32_t SMEM_EXPERT_COUNT_SIZE = + math::constexpr_align(kNumExperts * sizeof(uint32_t), kSharedMemoryAlignment); + constexpr uint32_t SMEM_SEND_BUFFER_SIZE = + math::constexpr_align(fp8_token_layout.get_num_bytes() * kNumDispatchWarps, kSharedMemoryAlignment); constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t); constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t); // SFA per-stage must be sized for the larger of L1 (BLOCK_M floats) and L2 From facc921020903d4c2e11a95edc6f18f8d7c5dc90 Mon Sep 17 00:00:00 2001 From: AichenF Date: Tue, 2 Jun 2026 16:23:39 +0800 Subject: [PATCH 10/15] Remove SM90 MegaMoE fused fallback --- csrc/jit_kernels/heuristics/mega_moe.hpp | 23 +-- csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp | 14 +- .../deep_gemm/impls/sm90_fp8_mega_moe.cuh | 98 +--------- deep_gemm/testing/bench.py | 10 +- tests/bench_mega_moe_sm90.py | 23 ++- tests/bench_mega_moe_sm90_skew.py | 17 +- tests/bench_mega_moe_sm90_skewed.py | 21 +-- tests/test_mega_moe_hopper.py | 173 +++++++----------- 8 files changed, 113 insertions(+), 266 deletions(-) diff --git a/csrc/jit_kernels/heuristics/mega_moe.hpp b/csrc/jit_kernels/heuristics/mega_moe.hpp index 46dc07b630..2e33d4aa9e 100644 --- a/csrc/jit_kernels/heuristics/mega_moe.hpp +++ b/csrc/jit_kernels/heuristics/mega_moe.hpp @@ -302,10 +302,6 @@ struct MegaMoESM90Config { } }; -static bool get_sm90_moe_split_l1_l2_default() { - return get_env("DG_SM90_MOE_SPLIT_L1_L2", 1) != 0; -} - enum class Sm90MoeDeviceProfile { Generic, H20, @@ -347,7 +343,6 @@ struct Sm90MoeProfileConfig { struct Sm90MoeHeuristicPolicy { Sm90MoeDeviceProfile device_profile; - bool split_l1_l2; int num_experts_per_rank, num_topk, intermediate_hidden; int block_m, block_n; float expected_tokens_per_expert; @@ -362,7 +357,7 @@ struct Sm90MoeHeuristicPolicy { } bool uses_split_bn256() const { - return split_l1_l2 and block_m == 64 and block_n == 256; + return block_m == 64 and 
block_n == 256; } bool is_main_topk8() const { @@ -379,8 +374,7 @@ struct Sm90MoeHeuristicPolicy { const bool& skew_hint, const bool& masked_hint) const { int wave_override = 0; - if (expected_tokens_per_expert == 8.0f or - expected_tokens_per_expert == 128.0f or + if (expected_tokens_per_expert == 128.0f or (expected_tokens_per_expert >= 256.0f and expected_tokens_per_expert < 512.0f)) { wave_override = 16; } @@ -480,8 +474,6 @@ struct Sm90MoeHeuristicPolicy { return profile_config.num_experts_per_wave; if (is_hopper_topk6() and expected_tokens_per_expert >= 8.0f and expected_tokens_per_expert <= 32.0f) return 16; - if (is_main_topk8() and expected_tokens_per_expert == 8.0f) - return 16; if (is_main_topk8() and expected_tokens_per_expert == 128.0f) return 16; if (is_main_topk8() and expected_tokens_per_expert >= 256.0f and expected_tokens_per_expert < 512.0f) @@ -565,7 +557,6 @@ static Sm90MoeHeuristicPolicy get_sm90_moe_heuristic_policy( const int& intermediate_hidden, const int& block_m, const int& block_n) { return { get_sm90_moe_device_profile(), - get_sm90_moe_split_l1_l2_default(), num_experts_per_rank, num_topk, intermediate_hidden, @@ -596,7 +587,7 @@ static std::tuple get_block_config_for_mega_moe_sm90( const bool use_mma_sync_decode = requested_mma_m > 0 and expected_tokens_per_expert <= static_cast(requested_mma_m); const bool use_bn256_split_n = - get_env("DG_SM90_MOE_BN256_2WG", get_sm90_moe_split_l1_l2_default() ? 1 : 0) != 0 and + get_env("DG_SM90_MOE_BN256_2WG", 1) != 0 and forced_block_m != 128 and not use_mma_sync_decode; const bool use_bn256_seq_n = get_env("DG_SM90_MOE_BN256_SEQ") != 0 and @@ -737,8 +728,7 @@ static std::pair get_pipeline_config_for_mega_moe_sm90( // Select the retained stage count for the current shape. 
const int max_num_stages = (smem_capacity - smem_fixed) / (smem_per_stage + smem_barriers_per_stage); - const bool split_l1_l2 = get_sm90_moe_split_l1_l2_default(); - const bool prefer_bn256_split = split_l1_l2 and block_n == 256; + const bool prefer_bn256_split = block_n == 256; const int preferred_num_stages = default_num_stages > 0 ? std::min(default_num_stages, max_num_stages) : (prefer_bn256_split ? std::min(4, max_num_stages) : 0); @@ -849,7 +839,6 @@ static std::vector get_mega_moe_config_candidates_sm90( const int& num_max_tokens_per_rank, const int& num_tokens, const int& num_topk, const int& hidden, const int& intermediate_hidden, const int& num_padded_sf_pool_tokens) { - const bool split_l1_l2 = get_sm90_moe_split_l1_l2_default(); const bool extra_modes = get_env("DG_SM90_MOE_SEARCH_EXTRA_MODES", 1) != 0; const bool extra_block_shapes = get_env("DG_SM90_MOE_SEARCH_BLOCK_SHAPES", 0) != 0; const float expected_tokens_per_expert = @@ -871,7 +860,7 @@ static std::vector get_mega_moe_config_candidates_sm90( const bool use_b_stationary_2wg = get_env("DG_SM90_MOE_B_STATIONARY_2WG") != 0 and not use_mma_sync_decode; const bool use_bn256_split_n_env = - get_env("DG_SM90_MOE_BN256_2WG", split_l1_l2 ? 1 : 0) != 0 and + get_env("DG_SM90_MOE_BN256_2WG", 1) != 0 and forced_block_m != 128 and not use_mma_sync_decode; const bool use_bn256_seq_n_env = get_env("DG_SM90_MOE_BN256_SEQ") != 0 and @@ -947,7 +936,7 @@ static std::vector get_mega_moe_config_candidates_sm90( const int swizzle_weights_mode = (block_m == 16 or block_m == 32) ? 0 : 128; const bool prefer_compact_frontend = - split_l1_l2 and block_n == 256 and not split_sfa_tma; + block_n == 256 and not split_sfa_tma; const bool compact_frontend = get_env("DG_SM90_MOE_COMPACT_FRONTEND", prefer_compact_frontend ? 
1 : 0) != 0; const int forced_dispatch_warps = get_env("DG_SM90_MOE_DISPATCH_WARPS", -1); diff --git a/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp b/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp index 78963bd046..15fcbdbc28 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp @@ -32,7 +32,6 @@ namespace deep_gemm { class SM90FP8MegaMoERuntime final : public LaunchRuntime { public: enum class KernelPhase { - Fused, Linear1, Linear2 }; @@ -82,7 +81,7 @@ class SM90FP8MegaMoERuntime final : public LaunchRuntime static std::string generate_impl(const Args& args) { const char* kernel_symbol = args.kernel_phase == KernelPhase::Linear1 ? "sm90_fp8_mega_moe_l1_impl" : - (args.kernel_phase == KernelPhase::Linear2 ? "sm90_fp8_mega_moe_l2_impl" : "sm90_fp8_mega_moe_impl"); + "sm90_fp8_mega_moe_l2_impl"; return fmt::format(R"( #include @@ -258,7 +257,7 @@ static void sm90_fp8_mega_moe( .l2_nmajor_schedule = config.l2_nmajor_schedule, .l1_nmajor_schedule = get_env("DG_SM90_MOE_L1_NMAJOR", 0) != 0, .one_warp_cleanup = config.one_warp_cleanup, - .kernel_phase = SM90FP8MegaMoERuntime::KernelPhase::Fused, + .kernel_phase = SM90FP8MegaMoERuntime::KernelPhase::Linear1, .config = config, .y = y.data_ptr(), .cumulative_local_expert_recv_stats = cumulative_local_expert_recv_stats_ptr, @@ -285,13 +284,8 @@ static void sm90_fp8_mega_moe( SM90FP8MegaMoERuntime::launch(runtime, split_args); }; - const bool split_l1_l2 = get_env("DG_SM90_MOE_SPLIT_L1_L2", 1) != 0; - if (split_l1_l2) { - launch_with_phase(SM90FP8MegaMoERuntime::KernelPhase::Linear1, "sm90_fp8_mega_moe_l1_impl"); - launch_with_phase(SM90FP8MegaMoERuntime::KernelPhase::Linear2, "sm90_fp8_mega_moe_l2_impl"); - } else { - launch_with_phase(SM90FP8MegaMoERuntime::KernelPhase::Fused, "sm90_fp8_mega_moe_impl"); - } + launch_with_phase(SM90FP8MegaMoERuntime::KernelPhase::Linear1, "sm90_fp8_mega_moe_l1_impl"); + launch_with_phase(SM90FP8MegaMoERuntime::KernelPhase::Linear2, 
"sm90_fp8_mega_moe_l2_impl"); } } // namespace deep_gemm diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh index 995db3ca4b..f8522709c7 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh @@ -51,33 +51,26 @@ namespace deep_gemm { // ============================================================================ enum class MegaMoEPhaseKind { - Fused, Linear1, Linear2 }; template struct MegaMoEPhasePolicy { - static constexpr bool is_fused = kKind == MegaMoEPhaseKind::Fused; static constexpr bool is_linear1_only = kKind == MegaMoEPhaseKind::Linear1; static constexpr bool is_linear2_only = kKind == MegaMoEPhaseKind::Linear2; - static constexpr bool runs_linear1 = is_fused or is_linear1_only; - static constexpr bool runs_linear2 = is_fused or is_linear2_only; + static constexpr bool runs_linear1 = is_linear1_only; + static constexpr bool runs_linear2 = is_linear2_only; static constexpr bool needs_dispatch_pull = runs_linear1; static constexpr bool needs_combine = runs_linear2; - static constexpr bool signals_l1_ready = is_fused; - static constexpr bool waits_for_l2_arrival_mask = is_fused; - static constexpr bool cleans_l1_dispatch_state = is_fused; CUTLASS_DEVICE static bool is_linear1_phase(const sched::BlockPhase& block_phase) { if constexpr (is_linear1_only) { (void)block_phase; return true; - } else if constexpr (is_linear2_only) { + } else { (void)block_phase; return false; - } else { - return block_phase == sched::BlockPhase::Linear1; } } @@ -90,20 +83,17 @@ struct MegaMoEPhasePolicy { func(sched::BlockPhase::Linear1, local_expert_idx, num_k_blocks, m_block_idx, n_block_idx); }); - } else if constexpr (is_linear2_only) { + } else { scheduler.template for_each_phase_block( [&](const uint32_t& local_expert_idx, const uint32_t& num_k_blocks, const uint32_t& m_block_idx, const uint32_t& n_block_idx) { 
func(sched::BlockPhase::Linear2, local_expert_idx, num_k_blocks, m_block_idx, n_block_idx); }); - } else { - scheduler.for_each_block(func); } } }; -using MegaMoEFusedPhase = MegaMoEPhasePolicy; using MegaMoELinear1Phase = MegaMoEPhasePolicy; using MegaMoELinear2Phase = MegaMoEPhasePolicy; @@ -557,34 +547,16 @@ sm90_fp8_mega_moe_core(DG_SM90_FP8_MOE_CORE_ARGS_DECL) { } } - if constexpr (MegaMoEPhase::cleans_l1_dispatch_state) { - if constexpr (kOneWarpCleanupRequested) { - if (warp_idx == 0) { - for (uint32_t j = lane_idx; j < kNumRanks; j += 32) - *workspace.get_expert_recv_count_ptr(j, i) = 0; - __syncwarp(); - } - } else { - for (uint32_t j = thread_idx; j < kNumRanks; j += kNumDispatchThreads) - *workspace.get_expert_recv_count_ptr(j, i) = 0; - __syncwarp(); - } - } - if constexpr (kOneWarpCleanupRequested) { if (warp_idx == 0) { for (uint32_t j = lane_idx; j < num_recv_m_blocks; j += 32) { *workspace.get_l1_arrival_count_ptr(cleanup_pool_block_offset + j) = 0; - if constexpr (MegaMoEPhase::cleans_l1_dispatch_state) - *workspace.get_l2_arrival_mask_ptr(cleanup_pool_block_offset + j) = 0; } __syncwarp(); } } else { for (uint32_t j = thread_idx; j < num_recv_m_blocks; j += kNumDispatchThreads) { *workspace.get_l1_arrival_count_ptr(cleanup_pool_block_offset + j) = 0; - if constexpr (MegaMoEPhase::cleans_l1_dispatch_state) - *workspace.get_l2_arrival_mask_ptr(cleanup_pool_block_offset + j) = 0; } __syncwarp(); } @@ -878,13 +850,6 @@ sm90_fp8_mega_moe_core(DG_SM90_FP8_MOE_CORE_ARGS_DECL) { const auto ptr = workspace.get_l1_arrival_count_ptr(pool_block_idx); const auto expected = valid_m; while (ptx::ld_acq(ptr) != expected); - } else if constexpr (MegaMoEPhase::waits_for_l2_arrival_mask) { - const auto ptr = workspace.get_l2_arrival_mask_ptr(pool_block_idx); - // Each L1 N block sets one bit; total bits = L1_SHAPE_N / BLOCK_N. - constexpr uint32_t kNumL1BlockNs = L1_SHAPE_N / BLOCK_N; - const uint64_t expected = (kNumL1BlockNs >= 64) - ? 
~0ull : ((1ull << kNumL1BlockNs) - 1ull); - while (ptx::ld_acq_gpu(ptr) != expected); } } for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { @@ -987,12 +952,6 @@ sm90_fp8_mega_moe_core(DG_SM90_FP8_MOE_CORE_ARGS_DECL) { const auto ptr = workspace.get_l1_arrival_count_ptr(pool_block_idx); const auto expected = valid_m; while (ptx::ld_acq(ptr) != expected); - } else if constexpr (MegaMoEPhase::waits_for_l2_arrival_mask) { - const auto ptr = workspace.get_l2_arrival_mask_ptr(pool_block_idx); - constexpr uint32_t kNumL1BlockNs = L1_SHAPE_N / BLOCK_N; - const uint64_t expected = (kNumL1BlockNs >= 64) - ? ~0ull : ((1ull << kNumL1BlockNs) - 1ull); - while (ptx::ld_acq_gpu(ptr) != expected); } } @@ -1045,8 +1004,6 @@ sm90_fp8_mega_moe_core(DG_SM90_FP8_MOE_CORE_ARGS_DECL) { uint32_t async_l1_store_stage = 0; bool async_l1_store_pending[2] = {false, false}; - uint32_t async_l1_store_pool[2] = {0, 0}; - uint32_t async_l1_store_n[2] = {0, 0}; const auto arrive_empty_barrier = [&](const uint32_t& s) { if constexpr (kClusterSize == 1) { @@ -1058,21 +1015,6 @@ sm90_fp8_mega_moe_core(DG_SM90_FP8_MOE_CORE_ARGS_DECL) { } }; - const auto notify_l1_ready = [&](const uint32_t& ready_pool_block_idx, - const uint32_t& ready_n_block_idx) { - if constexpr (MegaMoEPhase::signals_l1_ready) { - if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { - ptx::red_or_rel_gpu( - workspace.get_l2_arrival_mask_ptr(ready_pool_block_idx), - 1ull << ready_n_block_idx); - } - __syncwarp(); - } else { - (void)ready_pool_block_idx; - (void)ready_n_block_idx; - } - }; - const auto drain_async_l1_store_stage = [&](const uint32_t& store_stage) { if constexpr (kAsyncL1TMAStore) { if (async_l1_store_pending[store_stage]) { @@ -1080,8 +1022,6 @@ sm90_fp8_mega_moe_core(DG_SM90_FP8_MOE_CORE_ARGS_DECL) { // for <=1 outstanding store makes the older buffer reusable. 
ptx::tma_store_wait<1>(); ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); - notify_l1_ready(async_l1_store_pool[store_stage], - async_l1_store_n[store_stage]); async_l1_store_pending[store_stage] = false; } } @@ -1092,14 +1032,8 @@ sm90_fp8_mega_moe_core(DG_SM90_FP8_MOE_CORE_ARGS_DECL) { if (async_l1_store_pending[0] or async_l1_store_pending[1]) { ptx::tma_store_wait<0>(); ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); - if (async_l1_store_pending[0]) { - notify_l1_ready(async_l1_store_pool[0], async_l1_store_n[0]); - async_l1_store_pending[0] = false; - } - if (async_l1_store_pending[1]) { - notify_l1_ready(async_l1_store_pool[1], async_l1_store_n[1]); - async_l1_store_pending[1] = false; - } + async_l1_store_pending[0] = false; + async_l1_store_pending[1] = false; } } }; @@ -1639,7 +1573,6 @@ sm90_fp8_mega_moe_core(DG_SM90_FP8_MOE_CORE_ARGS_DECL) { __syncwarp(); ptx::tma_store_wait<0>(); ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); - notify_l1_ready(pool_block_idx, n_block_idx); } else { constexpr uint32_t kNumRowsPerWarp = WG_BLOCK_M / 8; #pragma unroll @@ -2228,15 +2161,9 @@ for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_bl if constexpr (kAsyncL1TMAStore) { ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); async_l1_store_pending[l1_store_stage] = true; - async_l1_store_pool[l1_store_stage] = pool_block_idx; - async_l1_store_n[l1_store_stage] = n_block_idx; async_l1_store_stage ^= 1u; } else { ptx::tma_store_wait<0>(); - if constexpr (MegaMoEPhase::signals_l1_ready) { - ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); - notify_l1_ready(pool_block_idx, n_block_idx); - } } } else { ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); @@ -2254,15 +2181,9 @@ for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_bl if constexpr (kAsyncL1TMAStore) { ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + 
epilogue_wg_idx); async_l1_store_pending[l1_store_stage] = true; - async_l1_store_pool[l1_store_stage] = pool_block_idx; - async_l1_store_n[l1_store_stage] = n_block_idx; async_l1_store_stage ^= 1u; } else { ptx::tma_store_wait<0>(); - if constexpr (MegaMoEPhase::signals_l1_ready) { - ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); - notify_l1_ready(pool_block_idx, n_block_idx); - } } } const unsigned long long block_epilogue_end = phase_profile_clock(); @@ -2559,13 +2480,6 @@ for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_bl #endif } -template -CUTLASS_GLOBAL __launch_bounds__(kNumThreads, 1) void -sm90_fp8_mega_moe_impl(DG_SM90_FP8_MOE_KERNEL_ARGS_DECL) { - sm90_fp8_mega_moe_core( - DG_SM90_FP8_MOE_KERNEL_ARGS); -} - template CUTLASS_GLOBAL __launch_bounds__(kNumThreads, 1) void sm90_fp8_mega_moe_l1_impl(DG_SM90_FP8_MOE_KERNEL_ARGS_DECL) { diff --git a/deep_gemm/testing/bench.py b/deep_gemm/testing/bench.py index 0e75bbabe4..0c38468e7f 100644 --- a/deep_gemm/testing/bench.py +++ b/deep_gemm/testing/bench.py @@ -136,12 +136,10 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, for line in prof_lines: if any(name in line for name in kernel_names): phase = '?' 
- if 'sm90_fp8_mega_moe_impl<' in line: - template_args = line.split('sm90_fp8_mega_moe_impl<', 1)[1].split('>(void', 1)[0] - args = [arg.strip() for arg in template_args.split(',')] - if len(args) > 29: - phase_arg = args[29] - phase = {'1u': 'l1', '2u': 'l2', '0u': 'one'}.get(phase_arg, phase_arg) + if 'sm90_fp8_mega_moe_l1_impl<' in line: + phase = 'l1' + elif 'sm90_fp8_mega_moe_l2_impl<' in line: + phase = 'l2' fields = line.split() cuda_time = fields[-2] if len(fields) >= 2 else 'unknown' count = fields[-1] if fields else 'unknown' diff --git a/tests/bench_mega_moe_sm90.py b/tests/bench_mega_moe_sm90.py index ae104f4fd4..b470126b84 100644 --- a/tests/bench_mega_moe_sm90.py +++ b/tests/bench_mega_moe_sm90.py @@ -89,7 +89,7 @@ def _run_one_config(args, num_tokens, num_max_tokens_per_rank, # Stage inputs once; bench-loop re-copies them each call (bench helper expects # an idempotent ``fn``). - def run_fused(): + def run_sm90(): buffer.x[:num_tokens].copy_(x_fp8) buffer.x_sf[:num_tokens].copy_(x_sf) buffer.topk_idx[:num_tokens].copy_(topk_idx) @@ -118,25 +118,24 @@ def run_fused(): if args.ncu_profile_only: dist_print(f'[NCU] tokens={num_tokens} hidden={hidden} ih={intermediate_hidden}', once_in_node=True) - run_fused() + run_sm90() torch.cuda.synchronize() dist.barrier() buffer.destroy() return # Warm up + benchmark - run_fused() + run_sm90() dist.barrier() if phase_profile_enabled: cum_stats.zero_() torch.cuda.synchronize() dist.barrier() - t_fused = bench_kineto(run_fused, 'sm90_fp8_mega_moe', - barrier=lambda: dist.barrier(), - num_tests=args.num_tests, - suppress_kineto_output=True, - with_multiple_kernels=os.environ.get( - 'DG_SM90_MOE_SPLIT_L1_L2', '1') != '0') + t_sm90 = bench_kineto(run_sm90, 'sm90_fp8_mega_moe', + barrier=lambda: dist.barrier(), + num_tests=args.num_tests, + suppress_kineto_output=True, + with_multiple_kernels=True) # Count tokens that landed on this rank for stats gathered_topk_idx = uneven_all_gather(topk_idx, group=group) @@ -145,7 
+144,7 @@ def run_fused(): num_recv_tokens = (gathered_topk_idx != -1).sum().item() safe_div = lambda a, b: float('nan') if b == 0 else a / b - tflops = safe_div(2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_fused) + tflops = safe_div(2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_sm90) num_touched_experts = max(0, torch.unique(gathered_topk_idx.flatten()).numel() - 1) # FP8 weights = 1 byte, FP8 acts = 1 byte, BF16 output = 2 bytes num_hbm_bytes = ( @@ -156,12 +155,12 @@ def run_fused(): num_recv_tokens * intermediate_hidden + # L2 acts read num_recv_tokens * hidden * 2 # L2 out write ) - hbm_gbs = safe_div(num_hbm_bytes / 1e9, t_fused) + hbm_gbs = safe_div(num_hbm_bytes / 1e9, t_sm90) if print_perf: dist_print( f' tokens={num_tokens:4d} recv={num_recv_tokens:5d} experts={num_touched_experts:4d} ' - f'{t_fused * 1e6:7.1f} us {tflops:6.1f} TFLOPS {hbm_gbs:6.0f} GB/s (rank{rank_idx})', + f'{t_sm90 * 1e6:7.1f} us {tflops:6.1f} TFLOPS {hbm_gbs:6.0f} GB/s (rank{rank_idx})', once_in_node=True, ) if phase_profile_enabled: diff --git a/tests/bench_mega_moe_sm90_skew.py b/tests/bench_mega_moe_sm90_skew.py index 1d8b40a728..d99d99e7e4 100644 --- a/tests/bench_mega_moe_sm90_skew.py +++ b/tests/bench_mega_moe_sm90_skew.py @@ -166,7 +166,7 @@ def _run_one_config(args, num_tokens, num_max_tokens_per_rank, use_skew_hint = global_bias is not None use_masked_hint = args.masked_ratio > 0 - def run_fused(): + def run_sm90(): buffer.x[:num_tokens].copy_(x_fp8) buffer.x_sf[:num_tokens].copy_(x_sf) buffer.topk_idx[:num_tokens].copy_(topk_idx) @@ -208,18 +208,17 @@ def run_fused(): os.environ['DG_SM90_MOE_MASKED_HINT'] = old_masked_hint return y - run_fused() + run_sm90() dist.barrier() if phase_profile_enabled: cum_stats.zero_() torch.cuda.synchronize() dist.barrier() - t_fused = bench_kineto(run_fused, 'sm90_fp8_mega_moe', - barrier=lambda: dist.barrier(), - num_tests=args.num_tests, - suppress_kineto_output=True, - 
with_multiple_kernels=os.environ.get( - 'DG_SM90_MOE_SPLIT_L1_L2', '1') != '0') + t_sm90 = bench_kineto(run_sm90, 'sm90_fp8_mega_moe', + barrier=lambda: dist.barrier(), + num_tests=args.num_tests, + suppress_kineto_output=True, + with_multiple_kernels=True) # Per-rank token receive counts + per-local-expert distribution gathered_topk_idx = uneven_all_gather(topk_idx, group=group) @@ -243,7 +242,7 @@ def run_fused(): local_max_mean = 0.0 # Cross-rank aggregation - info = torch.tensor([t_fused, float(num_recv_tokens), float(num_touched_experts), + info = torch.tensor([t_sm90, float(num_recv_tokens), float(num_touched_experts), local_max_mean], device='cuda', dtype=torch.float64) gather_buf = [torch.zeros_like(info) for _ in range(num_ranks)] dist.all_gather(gather_buf, info, group=group) diff --git a/tests/bench_mega_moe_sm90_skewed.py b/tests/bench_mega_moe_sm90_skewed.py index b875c7ebb2..74cca611bd 100644 --- a/tests/bench_mega_moe_sm90_skewed.py +++ b/tests/bench_mega_moe_sm90_skewed.py @@ -112,7 +112,7 @@ def _run_one_config(args, num_tokens, num_max_tokens_per_rank, cum_stats = torch.zeros(num_experts_per_rank, dtype=torch.int, device='cuda') use_skew_hint = args.skew_alpha > 0.0 - def run_fused(): + def run_sm90(): buffer.x[:num_tokens].copy_(x_fp8) buffer.x_sf[:num_tokens].copy_(x_sf) buffer.topk_idx[:num_tokens].copy_(topk_idx) @@ -138,14 +138,13 @@ def run_fused(): os.environ['DG_SM90_MOE_SKEW_HINT'] = old_skew_hint return y - run_fused() + run_sm90() dist.barrier() - t_fused = bench_kineto(run_fused, 'sm90_fp8_mega_moe', - barrier=lambda: dist.barrier(), - num_tests=args.num_tests, - suppress_kineto_output=True, - with_multiple_kernels=os.environ.get( - 'DG_SM90_MOE_SPLIT_L1_L2', '1') != '0') + t_sm90 = bench_kineto(run_sm90, 'sm90_fp8_mega_moe', + barrier=lambda: dist.barrier(), + num_tests=args.num_tests, + suppress_kineto_output=True, + with_multiple_kernels=True) # Local expert count distribution gathered_topk_idx = uneven_all_gather(topk_idx, 
group=group) @@ -174,7 +173,7 @@ def run_fused(): nonzero = 0 safe_div = lambda a, b: float('nan') if b == 0 else a / b - tflops = safe_div(2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_fused) + tflops = safe_div(2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_sm90) num_hbm_bytes = ( num_touched_experts * intermediate_hidden * 2 * hidden + num_touched_experts * hidden * intermediate_hidden + @@ -183,12 +182,12 @@ def run_fused(): num_recv_tokens * intermediate_hidden + num_recv_tokens * hidden * 2 ) - hbm_gbs = safe_div(num_hbm_bytes / 1e9, t_fused) + hbm_gbs = safe_div(num_hbm_bytes / 1e9, t_sm90) dist_print( f' tokens={num_tokens:5d} recv={num_recv_tokens:6d} nz_exp={nonzero:3d}/{num_experts_per_rank} ' f'max/mean={skew_max_mean:.2f} ' - f'{t_fused * 1e6:7.1f} us {tflops:6.1f} TFLOPS {hbm_gbs:6.0f} GB/s (rank{rank_idx})', + f'{t_sm90 * 1e6:7.1f} us {tflops:6.1f} TFLOPS {hbm_gbs:6.0f} GB/s (rank{rank_idx})', once_in_node=True, ) diff --git a/tests/test_mega_moe_hopper.py b/tests/test_mega_moe_hopper.py index ca6d38c539..f799553dd9 100644 --- a/tests/test_mega_moe_hopper.py +++ b/tests/test_mega_moe_hopper.py @@ -1,16 +1,17 @@ """ -H200 (SM90 / Hopper) mega-MoE: fused kernel + 同管线 baseline 性能对比。 +H200 (SM90 / Hopper) mega-MoE: split L1/L2 kernel + 同管线 baseline 性能对比。 结构对齐 tests/test_mega_moe.py(B 系列 SM100 FP4 路径),但所有路径都换成 H200 FP8: - * fused:调用 `deep_gemm.fp8_mega_moe`(kernel symbol `sm90_fp8_mega_moe_impl`), + * split SM90:调用 `deep_gemm.fp8_mega_moe` + (kernel symbols `sm90_fp8_mega_moe_l1_impl` + `sm90_fp8_mega_moe_l2_impl`), 使用 `transform_weights_for_mega_moe_sm90` 处理过的权重 + SymmBuffer。 * baseline:DeepEP dispatch + 2 个 grouped FP8 GEMM + Triton SwiGLU + DeepEP combine, 使用未变换的权重。由于当前 SM90 grouped GEMM 只支持 L2 activation - per-128-K SFA,而 fused SM90 mega-MoE 的 L1 epilogue 为避免跨 CTA + per-128-K SFA,而 SM90 split mega-MoE 的 L1 epilogue 为避免跨 CTA 同步使用 per-64-K SFA,所以该 baseline 是同管线 legacy 参照, 不是 bitwise apples-to-apples correctness 
oracle。 - * 性能输出涵盖:TFLOPS / overlap TFLOPS / HBM GB/s / NVL GB/s / fused us / - reduction us / `t_baseline / t_fused` legacy 比。 + * 性能输出涵盖:TFLOPS / overlap TFLOPS / HBM GB/s / NVL GB/s / SM90 us / + reduction us / `t_baseline / t_sm90` legacy 比。 """ import argparse @@ -37,10 +38,9 @@ # 与 deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh 中模板入口前缀同名, -# bench_kineto 用它从 trace 里汇总 fused 或 split mega-MoE 的 GPU 段 +# bench_kineto 用它从 trace 里汇总 split mega-MoE 的 L1/L2 GPU 段 SM90_KERNEL_NAME_PREFIX = "sm90_fp8_mega_moe" -SPLIT_VS_ONE_MAX_ABS_REL_TOL = 1e-4 -FUSED_VS_LEGACY_MEAN_ABS_REL_TOL = 7e-2 +SM90_VS_LEGACY_MEAN_ABS_REL_TOL = 7e-2 # FP8 e4m3fn 的最大可表示值,量化时用 amax / 448 作为 scale 基准 @@ -49,7 +49,7 @@ # 否则编译期 NameError。宿主 Python 侧仍用上面的普通 float 做 torch 运算。 _FP8_E4M3_MAX_TL = tl.constexpr(448.0) L1_ACT_SF_GRAN = 128 -FUSED_L2_ACT_SF_GRAN = 64 +SM90_L2_ACT_SF_GRAN = 64 BASELINE_L2_ACT_SF_GRAN = 128 WEIGHT_SF_GRAN_MN = 128 WEIGHT_SF_GRAN_K = 128 @@ -59,7 +59,7 @@ # 模块 1:Triton SwiGLU + FP8 量化内核 # ---------------------------------------------------------------------------- # baseline 的 L2 仍走 DeepGEMM SM90 grouped FP8 GEMM,所以 activation SFA 只能按 -# per-128-K 输入;但 scale 数值采用 fused epilogue 同款 UE8M0/power-of-two 规则, +# per-128-K 输入;但 scale 数值采用 SM90 epilogue 同款 UE8M0/power-of-two 规则, # 避免再额外引入 exact-FP32-scale 差异。 # 输入 x : (M, 2*H) bf16,内层是 [gate_part | up_part] # 输入 topk_w : (M,) fp32,可选 @@ -377,13 +377,13 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace): f"num_experts={num_experts} 必须能被 num_ranks={num_ranks} 整除" ) - # SM90 fused kernel 的形状约束(来自 csrc/apis/mega.hpp::fp8_mega_moe): + # SM90 mega-MoE kernel 的形状约束(来自 csrc/apis/mega.hpp::fp8_mega_moe): # * H、IH 必须是 128 的倍数(L1 input per-128-K SF + block-(128,128) weight SF) # * IH/64 ≤ 64 → IH ≤ 4096(l2_arrival_mask 是 uint64,每 bit 对应 64 列) assert hidden % 128 == 0 assert intermediate_hidden % 128 == 0 assert intermediate_hidden // 64 <= 64, ( - f"SM90 fused kernel 要求 intermediate_hidden <= 4096, 当前 
{intermediate_hidden}" + f"SM90 mega-MoE kernel 要求 intermediate_hidden <= 4096, 当前 {intermediate_hidden}" ) # ---- 创建 BF16 输入:token 与两层 weight ---- @@ -918,13 +918,13 @@ def try_assign(token: int, seen_experts: set[int]) -> bool: topk_idx.masked_fill_(rand_mask < args.masked_ratio, -1) topk_weights.masked_fill_(topk_idx < 0, 0) - # 累计接收统计:fused 与 baseline 各持一份避免相互覆盖 + # 累计接收统计:SM90 与 baseline 各持一份避免相互覆盖 phase_profile_enabled = os.environ.get("DG_SM90_MOE_PHASE_PROFILE", "0") not in ("", "0") phase_profile_extra = 64 if phase_profile_enabled else 0 - cum_stats_fused = torch.zeros( + cum_stats_sm90 = torch.zeros( (num_experts_per_rank + phase_profile_extra,), dtype=torch.int, device="cuda" ) - cum_stats_baseline = cum_stats_fused.clone() + cum_stats_baseline = cum_stats_sm90.clone() # ---- BF16 → FP8 量化 ---- # x_fp8 是元组:(token_fp8 (M, hidden), token_sf (M, hidden//128) fp32 行主序) @@ -938,12 +938,12 @@ def try_assign(token: int, seen_experts: set[int]) -> bool: l1_weights = _quantize_grouped_fp8_block_128_128(l1_weights_bf16) l2_weights = _quantize_grouped_fp8_block_128_128(l2_weights_bf16) - # fused 路径:FP8 weight 上做 gate/up gran-8 N-轴 interleave;SF 不变 + # SM90 路径:FP8 weight 上做 gate/up gran-8 N-轴 interleave;SF 不变 transformed_l1, transformed_l2 = deep_gemm.transform_weights_for_mega_moe_sm90( l1_weights, l2_weights ) - # SwiGLU clamp:finite → 传给 fused/triton;inf → None(关闭 clamp,与 SM90 fused 一致) + # SwiGLU clamp:finite -> 传给 SM90/triton;inf -> None(关闭 clamp,与 SM90 一致) clamp_arg = args.activation_clamp if math.isfinite(args.activation_clamp) else None run_baseline_enabled = args.run_baseline or bool(args.check_output_diff) @@ -951,7 +951,7 @@ def try_assign(token: int, seen_experts: set[int]) -> bool: alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout() deep_gemm.set_mk_alignment_for_contiguous_layout(alignment) - # ---- 分配 fused 的 SymmBuffer 与输出 buffer ---- + # ---- 分配 SM90 的 SymmBuffer 与输出 buffer ---- sym_buffer = 
deep_gemm.get_symm_buffer_for_mega_moe( group, num_experts, @@ -960,12 +960,12 @@ def try_assign(token: int, seen_experts: set[int]) -> bool: hidden, intermediate_hidden, ) - y_fused = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") + y_sm90 = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") use_eplb_hint = bool(eplb_replica_for) use_skew_hint = args.score_powerlaw_alpha > 0.0 use_masked_hint = args.masked_ratio > 0.0 - def run_fused(): + def run_sm90(): # NOTE: 跟 SM100 test_mega_moe.py 的处理一致 —— DG_COMM_KERNEL_DEBUG=1 时 # kernel 出口会把 sym_buffer 整块清零,所以每次都要重新拷输入 sym_buffer.x[:num_tokens].copy_(x_fp8[0]) @@ -984,11 +984,11 @@ def run_fused(): os.environ["DG_SM90_MOE_MASKED_HINT"] = "1" try: deep_gemm.fp8_mega_moe( - y_fused, + y_sm90, transformed_l1, transformed_l2, sym_buffer, - cumulative_local_expert_recv_stats=cum_stats_fused, + cumulative_local_expert_recv_stats=cum_stats_sm90, recipe=(128, 128, 128), activation="swiglu", activation_clamp=clamp_arg, @@ -1010,10 +1010,10 @@ def run_fused(): os.environ.pop("DG_SM90_MOE_MASKED_HINT", None) else: os.environ["DG_SM90_MOE_MASKED_HINT"] = old_masked_hint - return y_fused + return y_sm90 # ---- 打印 config ---- - dist_print("Config (H200 fused mega-MoE):", once_in_node=True) + dist_print("Config (H200 split SM90 mega-MoE):", once_in_node=True) dist_print(f" > Tokens: {num_tokens}/{num_max_tokens_per_rank}", once_in_node=True) dist_print( f" > Hidden: {hidden}, Intermediate: {intermediate_hidden}", once_in_node=True @@ -1024,7 +1024,7 @@ def run_fused(): ) dist_print(f" > Masked ratio: {args.masked_ratio}", once_in_node=True) dist_print( - f" > Activation SF: fused L2 per-{FUSED_L2_ACT_SF_GRAN} UE8M0, " + f" > Activation SF: SM90 L2 per-{SM90_L2_ACT_SF_GRAN} UE8M0, " f"baseline L2 per-{BASELINE_L2_ACT_SF_GRAN} UE8M0 " f"(SM90 grouped GEMM constraint)", once_in_node=True, @@ -1038,10 +1038,10 @@ def run_fused(): ) dist_print(once_in_node=True) - # 与社区版 test_mega_moe.py 对齐:NCU 
模式只跑 fused kernel,避免 baseline 噪声。 + # 与社区版 test_mega_moe.py 对齐:NCU 模式只跑 SM90 mega-MoE kernel,避免 baseline 噪声。 if args.ncu_profile_only: - dist_print("Run fused SM90 mega-MoE kernel:", once_in_node=True) - y = run_fused() + dist_print("Run split SM90 mega-MoE kernel:", once_in_node=True) + y = run_sm90() torch.cuda.synchronize() assert y.shape == (num_tokens, hidden) and y.dtype == torch.bfloat16 dist_print(" > Done, exiting", once_in_node=True) @@ -1065,7 +1065,7 @@ def run_fused(): # ---------------------------------------------------------------- # baseline 主体:dispatch → L1 GEMM → SwiGLU+量化 → L2 GEMM → combine - # 与 fused 用同一份 (FP8 weight, FP32 block-(128,128) SF) —— 但是 **未变换** + # 与 SM90 用同一份 (FP8 weight, FP32 block-(128,128) SF) —— 但是 **未变换** # 的版本(baseline grouped GEMM 不需要 gate/up interleave) # ---------------------------------------------------------------- def run_baseline(): @@ -1097,9 +1097,9 @@ def run_baseline(): ) # Triton SwiGLU + FP8 量化(含 topk 权重乘法) - # 注意:fused SM90 mega-MoE 的 L2 activation SFA 是 per-64-K; + # 注意:SM90 split mega-MoE 的 L2 activation SFA 是 per-64-K; # 当前 DeepGEMM SM90 grouped GEMM 只支持 per-128-K SFA,所以性能 baseline - # 只能用 per-128-K,但 scale 数值采用 fused 同款 UE8M0/power-of-two。 + # 只能用 per-128-K,但 scale 数值采用 SM90 同款 UE8M0/power-of-two。 l1_y = swiglu_apply_weight_to_fp8_triton( x=l1_y, topk_weights=recv_topk_weights, @@ -1122,47 +1122,11 @@ def run_baseline(): # DeepEP combine:把每个 token 在 topk 个 expert 上的输出汇聚回源 rank return ep_buffer.combine(l2_y, handle=handle)[0] - # ---- 跑一次确保不报错(fused + 可选 baseline)---- - y = run_fused() + # ---- 跑一次确保不报错(SM90 split + 可选 baseline)---- + y = run_sm90() assert y.shape == (num_tokens, hidden) and y.dtype == torch.bfloat16, ( - f"fused 输出 shape/dtype 异常: shape={y.shape}, dtype={y.dtype}" + f"SM90 split 输出 shape/dtype 异常: shape={y.shape}, dtype={y.dtype}" ) - if args.check_split_vs_one_kernel: - old_split_env = os.environ.get("DG_SM90_MOE_SPLIT_L1_L2") - try: - os.environ["DG_SM90_MOE_SPLIT_L1_L2"] = "1" - 
y_split = run_fused().detach().clone() - torch.cuda.synchronize() - os.environ["DG_SM90_MOE_SPLIT_L1_L2"] = "0" - y_one_kernel = run_fused().detach().clone() - torch.cuda.synchronize() - finally: - if old_split_env is None: - os.environ.pop("DG_SM90_MOE_SPLIT_L1_L2", None) - else: - os.environ["DG_SM90_MOE_SPLIT_L1_L2"] = old_split_env - diff = (y_split.float() - y_one_kernel.float()).abs() - denom = y_one_kernel.float().abs().mean().clamp_min(1e-12) - max_abs = diff.max() - mean_abs_rel = diff.mean().div(denom) - max_abs_threshold = denom * SPLIT_VS_ONE_MAX_ABS_REL_TOL - dist_print( - "Output diff (split two-kernel vs one-kernel):", once_in_node=True - ) - dist_print( - f" > max_abs={max_abs.item():.6e}, " - f"mean_abs={diff.mean().item():.6e}, " - f"mean_abs/mean_ref={mean_abs_rel.item():.6e}, " - f"max_abs_threshold={max_abs_threshold.item():.6e}", - once_in_node=True, - ) - dist_print(once_in_node=True) - assert max_abs <= max_abs_threshold, ( - "split two-kernel output differs from one-kernel output: " - f"max_abs={max_abs.item():.6e}, " - f"threshold={max_abs_threshold.item():.6e}, " - f"mean_ref={denom.item():.6e}" - ) if ep_buffer is not None: out_b = run_baseline() assert out_b.shape == (num_tokens, hidden) and out_b.dtype == torch.bfloat16, ( @@ -1173,20 +1137,20 @@ def run_baseline(): denom = out_b.float().abs().mean().clamp_min(1e-12) mean_abs_rel = diff.mean().div(denom) dist_print( - "Output diff (fused vs legacy-per128 baseline):", once_in_node=True + "Output diff (SM90 split vs legacy-per128 baseline):", once_in_node=True ) dist_print( f" > max_abs={diff.max().item():.6e}, " f"mean_abs={diff.mean().item():.6e}, " f"mean_abs/mean_ref={mean_abs_rel.item():.6e}, " - f"threshold={FUSED_VS_LEGACY_MEAN_ABS_REL_TOL:.6e}", + f"threshold={SM90_VS_LEGACY_MEAN_ABS_REL_TOL:.6e}", once_in_node=True, ) dist_print(once_in_node=True) - assert mean_abs_rel <= FUSED_VS_LEGACY_MEAN_ABS_REL_TOL, ( - "fused output differs from legacy-per128 baseline: " + assert 
mean_abs_rel <= SM90_VS_LEGACY_MEAN_ABS_REL_TOL, ( + "SM90 split output differs from legacy-per128 baseline: " f"mean_abs/mean_ref={mean_abs_rel.item():.6e}, " - f"threshold={FUSED_VS_LEGACY_MEAN_ABS_REL_TOL:.6e}" + f"threshold={SM90_VS_LEGACY_MEAN_ABS_REL_TOL:.6e}" ) # ---- 统计本 rank 实际接收的 token 数与触达的 expert 数 ---- @@ -1226,11 +1190,11 @@ def run_baseline(): max_expert_tokens = 0 # ---- benchmark ---- - # fused:bench_kineto 抓 sm90_fp8_mega_moe* 的 GPU 段(不含 host overhead) + # split SM90:bench_kineto 抓 sm90_fp8_mega_moe* 的 GPU 段(不含 host overhead) if phase_profile_enabled: - cum_stats_fused.zero_() - t_fused = bench_kineto( - run_fused, + cum_stats_sm90.zero_() + t_sm90 = bench_kineto( + run_sm90, SM90_KERNEL_NAME_PREFIX, num_tests=args.num_bench_tests, barrier=lambda: ep_buffer.barrier(use_comm_stream=False) @@ -1241,13 +1205,10 @@ def run_baseline(): if args.dump_profile_traces else None ), - with_multiple_kernels=os.environ.get( - "DG_SM90_MOE_SPLIT_L1_L2", - "1", - ) != "0", + with_multiple_kernels=True, ) if phase_profile_enabled: - cum_stats_fused.zero_() + cum_stats_sm90.zero_() torch.cuda.synchronize() if ep_buffer is not None: ep_buffer.barrier(use_comm_stream=False) @@ -1256,11 +1217,11 @@ def run_baseline(): phase_start = torch.cuda.Event(enable_timing=True) phase_end = torch.cuda.Event(enable_timing=True) phase_start.record() - run_fused() + run_sm90() phase_end.record() torch.cuda.synchronize() phase_event_us = phase_start.elapsed_time(phase_end) * 1000.0 - raw_i32 = cum_stats_fused[num_experts_per_rank:num_experts_per_rank + phase_profile_extra].detach().cpu().tolist() + raw_i32 = cum_stats_sm90[num_experts_per_rank:num_experts_per_rank + phase_profile_extra].detach().cpu().tolist() def _u64(slot: int) -> int: lo = raw_i32[slot * 2] & 0xffffffff hi = raw_i32[slot * 2 + 1] & 0xffffffff @@ -1302,7 +1263,7 @@ def safe_div(a, b): # 端到端 TFLOPS:3 个 matmul(L1 gate、L1 up、L2),每个 2*M*N*K,M=num_recv_tokens tflops = safe_div( - 2 * num_recv_tokens * (hidden * 
intermediate_hidden * 3) / 1e12, t_fused + 2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_sm90 ) # HBM 字节估算(SM90: weight 是 FP8 = 1B/elem,与 SM100 FP4=0.5B 不同) @@ -1322,7 +1283,7 @@ def safe_div(a, b): ) l1_input_sf_bytes = num_recv_tokens * (hidden // L1_ACT_SF_GRAN) * 4 l2_act_sf_bytes = ( - num_recv_tokens * (intermediate_hidden // FUSED_L2_ACT_SF_GRAN) * 4 + num_recv_tokens * (intermediate_hidden // SM90_L2_ACT_SF_GRAN) * 4 ) num_hbm_bytes = ( l1_weight_bytes @@ -1337,17 +1298,17 @@ def safe_div(a, b): + l2_act_sf_bytes # L2 输入读 (FP8 + SF) + num_recv_tokens * hidden * 2 # L2 输出写 (BF16) ) - hbm_gbs = safe_div(num_hbm_bytes / 1e9, t_fused) + hbm_gbs = safe_div(num_hbm_bytes / 1e9, t_sm90) # NVLink 字节:dispatch 拉 token + input SF + topk weight,combine 写回 BF16 num_nvlink_bytes = num_recv_tokens * (hidden + hidden // 32 + 4 + hidden * 2) - nvlink_gbs = safe_div(num_nvlink_bytes / 1e9, t_fused) + nvlink_gbs = safe_div(num_nvlink_bytes / 1e9, t_sm90) # combine reduction 串行下界(解析估计;6.5e12 = HBM 串行 reduction 经验吞吐 B/s) t_reduction = num_tokens * hidden * 2 * (1 + num_topk) / 6.5e12 - # overlap 校正:扣掉 fused 中无法重叠的串行 reduction 段后估计稳态吞吐 - approx_factor = t_fused / max(t_fused - t_reduction, 1e-12) + # overlap 校正:扣掉 SM90 中无法重叠的串行 reduction 段后估计稳态吞吐 + approx_factor = t_sm90 / max(t_sm90 - t_reduction, 1e-12) # baseline 用同一份 FLOPs / HBM 字节,时间换成 t_baseline tflops_baseline = safe_div( @@ -1367,7 +1328,7 @@ def fmt_perf_line( ) -> str: reduction = f"{reduction_us:13.1f}" if reduction_us is not None else f"{'-':>13}" speedup_text = ( - f"{speedup:6.2f}x {'fused faster' if speedup > 1 else 'baseline faster'}" + f"{speedup:6.2f}x {'SM90 faster' if speedup > 1 else 'baseline faster'}" if speedup is not None else f"{'-':>21}" ) @@ -1391,8 +1352,8 @@ def fmt_perf_line( ) dist_print( fmt_perf_line( - "[fused]", - t_fused, + "[sm90]", + t_sm90, tflops * approx_factor, hbm_gbs * approx_factor, nvlink_gbs * approx_factor, @@ -1400,7 +1361,7 @@ def fmt_perf_line( ) ) if 
ep_buffer is not None: - speedup = safe_div(t_baseline, t_fused) + speedup = safe_div(t_baseline, t_sm90) dist_print( fmt_perf_line( "[baseline]", @@ -1433,14 +1394,14 @@ def fmt_perf_line( if __name__ == "__main__": parser = argparse.ArgumentParser( - description="H200 mega-MoE: fused (deep_gemm.fp8_mega_moe) vs DeepEP+grouped-FP8 baseline" + description="H200 mega-MoE: split SM90 deep_gemm.fp8_mega_moe vs DeepEP+grouped-FP8 baseline" ) # 资源 parser.add_argument( "--ncu-profile-only", action="store_true", - help="只运行一次 fused SM90 kernel,便于 NCU/Nsight 采样", + help="只运行一次 split SM90 kernel,便于 NCU/Nsight 采样", ) parser.add_argument( "--num-processes", type=int, default=8, help="spawn 出来的进程数(一卡一进程)" @@ -1453,7 +1414,7 @@ def fmt_perf_line( ) # 模型形状 - # 注:SM90 fused kernel 要求 intermediate_hidden ≤ 4096 + # 注:SM90 mega-MoE kernel 要求 intermediate_hidden <= 4096 parser.add_argument("--num-max-tokens-per-rank", type=int, default=8192) parser.add_argument( "--num-tokens", @@ -1535,7 +1496,7 @@ def fmt_perf_line( "--fast-math", type=int, default=1, - help="fused 内 SwiGLU 是否启用 fast-math(0/1)", + help="SM90 kernel 内 SwiGLU 是否启用 fast-math(0/1)", ) # 测时 @@ -1543,7 +1504,7 @@ def fmt_perf_line( "--num-bench-tests", type=int, default=30, - help="bench_kineto 抓 fused 时的迭代数", + help="bench_kineto 抓 SM90 split kernel 时的迭代数", ) parser.add_argument( "--num-warmup", type=int, default=5, help="baseline cuda events warmup" @@ -1566,19 +1527,13 @@ def fmt_perf_line( "--check-output-diff", type=int, default=0, - help="非 0 时打印 fused 与 legacy-per128 baseline 的输出差异(预期非 bitwise)", - ) - parser.add_argument( - "--check-split-vs-one-kernel", - type=int, - default=0, - help="非 0 时打印 split two-kernel 与 one-kernel fused 的输出差异", + help="非 0 时打印 SM90 split 与 legacy-per128 baseline 的输出差异(预期非 bitwise)", ) parser.add_argument( "--dump-profile-traces", type=str, default="", - help="非空时把 fused 的 Chrome trace 写到该目录(每 rank 一份)", + help="非空时把 SM90 split kernel 的 Chrome trace 写到该目录(每 rank 一份)", ) args = 
parser.parse_args() From ad68c2d0ad0b195a4ed6a52d8cc058253140c55b Mon Sep 17 00:00:00 2001 From: AichenF Date: Tue, 2 Jun 2026 17:59:34 +0800 Subject: [PATCH 11/15] Clean up SM90 MegaMoE split config path --- csrc/jit_kernels/heuristics/mega_moe.hpp | 100 ++---------------- .../deep_gemm/impls/sm90_fp8_mega_moe.cuh | 48 +++++---- tests/test_mega_moe_sm90.py | 26 ++--- 3 files changed, 49 insertions(+), 125 deletions(-) diff --git a/csrc/jit_kernels/heuristics/mega_moe.hpp b/csrc/jit_kernels/heuristics/mega_moe.hpp index 2e33d4aa9e..443dd4f915 100644 --- a/csrc/jit_kernels/heuristics/mega_moe.hpp +++ b/csrc/jit_kernels/heuristics/mega_moe.hpp @@ -356,7 +356,7 @@ struct Sm90MoeHeuristicPolicy { return expected_tokens_per_expert >= low and expected_tokens_per_expert <= high; } - bool uses_split_bn256() const { + bool uses_bn256_main_tile() const { return block_m == 64 and block_n == 256; } @@ -454,7 +454,7 @@ struct Sm90MoeHeuristicPolicy { const bool& eplb_hint = false, const bool& skew_hint = false, const bool& masked_hint = false) const { - if (not uses_split_bn256() or not is_main_topk8()) + if (not uses_bn256_main_tile() or not is_main_topk8()) return false; if (device_profile == Sm90MoeDeviceProfile::H20) { @@ -482,7 +482,7 @@ struct Sm90MoeHeuristicPolicy { } bool direct_l2_scatter() const { - if (not uses_split_bn256()) + if (not uses_bn256_main_tile()) return false; Sm90MoeProfileConfig profile_config; if (device_profile_config(profile_config)) @@ -501,7 +501,7 @@ struct Sm90MoeHeuristicPolicy { } bool l2_nmajor_schedule(const bool& eplb_hint, const bool& skew_hint) const { - if (not uses_split_bn256() or not is_main_topk8()) + if (not uses_bn256_main_tile() or not is_main_topk8()) return false; Sm90MoeProfileConfig profile_config; if (device_profile_config(profile_config, false, eplb_hint, skew_hint)) @@ -514,7 +514,7 @@ struct Sm90MoeHeuristicPolicy { } bool one_warp_cleanup(const bool& masked_hint) const { - if (not uses_split_bn256()) + if (not 
uses_bn256_main_tile()) return false; Sm90MoeProfileConfig profile_config; if (device_profile_config(profile_config, false, false, false, masked_hint)) @@ -566,53 +566,6 @@ static Sm90MoeHeuristicPolicy get_sm90_moe_heuristic_policy( }; } -static std::tuple get_block_config_for_mega_moe_sm90( - const int& num_ranks, const int& num_experts, - const int& num_max_tokens_per_rank, const int& num_topk, - const int& num_tokens) { - // Keep mma.sync decode variants opt-in; the default path uses M64 CTAs. - const int num_experts_per_rank = num_experts / num_ranks; - const float expected_tokens_per_expert = - static_cast(num_tokens) * num_topk / num_experts_per_rank; - const int requested_mma_m = get_env("DG_SM90_MOE_MMA_SYNC_M") > 0 - ? get_env("DG_SM90_MOE_MMA_SYNC_M") - : (get_env("DG_SM90_MOE_MMA_SYNC") != 0 ? 16 : 0); - DG_HOST_ASSERT(requested_mma_m == 0 or requested_mma_m == 16 or requested_mma_m == 32); - const int forced_block_m = get_env("DG_SM90_MOE_FORCE_BLOCK_M"); - const int forced_epilogue_warpgroups = get_env("DG_SM90_MOE_FORCE_EPILOGUE_WG"); - DG_HOST_ASSERT(forced_block_m == 0 or forced_block_m == 64 or forced_block_m == 128); - DG_HOST_ASSERT(forced_epilogue_warpgroups == 0 or - forced_epilogue_warpgroups == 1 or - forced_epilogue_warpgroups == 2); - const bool use_mma_sync_decode = - requested_mma_m > 0 and expected_tokens_per_expert <= static_cast(requested_mma_m); - const bool use_bn256_split_n = - get_env("DG_SM90_MOE_BN256_2WG", 1) != 0 and - forced_block_m != 128 and not use_mma_sync_decode; - const bool use_bn256_seq_n = - get_env("DG_SM90_MOE_BN256_SEQ") != 0 and - forced_block_m != 128 and not use_mma_sync_decode; - const bool use_b_stationary_2wg = - get_env("DG_SM90_MOE_B_STATIONARY_2WG") != 0 and not use_mma_sync_decode; - DG_HOST_ASSERT(not (use_bn256_split_n and use_bn256_seq_n)); - DG_HOST_ASSERT(not (use_b_stationary_2wg and (use_bn256_split_n or use_bn256_seq_n))); - const int block_m = forced_block_m > 0 - ? 
forced_block_m - : (use_b_stationary_2wg ? 128 : (use_mma_sync_decode ? requested_mma_m : 64)); - const int num_epilogue_warpgroups = forced_epilogue_warpgroups > 0 - ? forced_epilogue_warpgroups - : ((block_m == 128 or use_b_stationary_2wg or use_bn256_split_n) ? 2 : 1); - DG_HOST_ASSERT(block_m % num_epilogue_warpgroups == 0); - DG_HOST_ASSERT((block_m != 16 and block_m != 32) or num_epilogue_warpgroups == 1); - DG_HOST_ASSERT(block_m != 128 or num_epilogue_warpgroups == 2); - - DG_HOST_ASSERT(std::any_of( - layout::kCandidateBlockM, layout::kCandidateBlockM + layout::kNumCandidateBlockMs, - [=](const auto& candidate) { return candidate == block_m; }) - ); - return {block_m, num_epilogue_warpgroups * 128}; -} - static int get_num_experts_per_wave_for_mega_moe_sm90( const int& num_experts_per_rank, const int& num_tokens, const int& num_topk, const int& intermediate_hidden, const int& block_m, const int& block_n, const int& num_sms) { @@ -635,33 +588,6 @@ static int get_num_experts_per_wave_for_mega_moe_sm90( intermediate_hidden, block_m, block_n, num_sms); } -static bool get_sm90_moe_direct_l2_scatter_default( - const int& num_experts_per_rank, const int& num_tokens, const int& num_topk, - const int& intermediate_hidden, const int& block_m, const int& block_n) { - return get_sm90_moe_heuristic_policy( - num_experts_per_rank, num_tokens, num_topk, - intermediate_hidden, block_m, block_n).direct_l2_scatter(); -} - -static bool get_sm90_moe_l2_nmajor_schedule_default( - const int& num_experts_per_rank, const int& num_tokens, const int& num_topk, - const int& intermediate_hidden, const int& block_m, const int& block_n) { - return get_sm90_moe_heuristic_policy( - num_experts_per_rank, num_tokens, num_topk, - intermediate_hidden, block_m, block_n).l2_nmajor_schedule( - get_env("DG_SM90_MOE_EPLB_HINT", 0) != 0, - get_env("DG_SM90_MOE_SKEW_HINT", 0) != 0); -} - -static bool get_sm90_moe_one_warp_cleanup_default( - const int& num_experts_per_rank, const int& num_tokens, 
const int& num_topk, - const int& intermediate_hidden, const int& block_m, const int& block_n) { - return get_sm90_moe_heuristic_policy( - num_experts_per_rank, num_tokens, num_topk, - intermediate_hidden, block_m, block_n).one_warp_cleanup( - get_env("DG_SM90_MOE_MASKED_HINT", 0) != 0); -} - static std::pair get_pipeline_config_for_mega_moe_sm90( const int& smem_capacity, const int& num_experts, const int& hidden, @@ -728,10 +654,10 @@ static std::pair get_pipeline_config_for_mega_moe_sm90( // Select the retained stage count for the current shape. const int max_num_stages = (smem_capacity - smem_fixed) / (smem_per_stage + smem_barriers_per_stage); - const bool prefer_bn256_split = block_n == 256; + const bool prefer_bn256_n_tile = block_n == 256; const int preferred_num_stages = default_num_stages > 0 ? std::min(default_num_stages, max_num_stages) - : (prefer_bn256_split ? std::min(4, max_num_stages) : 0); + : (prefer_bn256_n_tile ? std::min(4, max_num_stages) : 0); const int forced_num_stages = get_env("DG_SM90_MOE_NUM_STAGES"); const int num_stages = forced_num_stages > 0 ? 
std::min(forced_num_stages, max_num_stages) @@ -862,11 +788,7 @@ static std::vector get_mega_moe_config_candidates_sm90( const bool use_bn256_split_n_env = get_env("DG_SM90_MOE_BN256_2WG", 1) != 0 and forced_block_m != 128 and not use_mma_sync_decode; - const bool use_bn256_seq_n_env = - get_env("DG_SM90_MOE_BN256_SEQ") != 0 and - forced_block_m != 128 and not use_mma_sync_decode; - DG_HOST_ASSERT(not (use_bn256_split_n_env and use_bn256_seq_n_env)); - DG_HOST_ASSERT(not (use_b_stationary_2wg and (use_bn256_split_n_env or use_bn256_seq_n_env))); + DG_HOST_ASSERT(not (use_b_stationary_2wg and use_bn256_split_n_env)); std::vector block_m_candidates; if (forced_block_m > 0) { @@ -897,7 +819,7 @@ static std::vector get_mega_moe_config_candidates_sm90( std::vector block_n_candidates; if (block_m == 64 and not use_mma_sync_decode and - (use_bn256_split_n_env or use_bn256_seq_n_env)) { + use_bn256_split_n_env) { append_unique_moe_candidate(block_n_candidates, 256); if (extra_block_shapes) append_unique_moe_candidate(block_n_candidates, 128); @@ -983,11 +905,11 @@ static std::vector get_mega_moe_config_candidates_sm90( auto l2_nmajor_candidates = get_sm90_moe_bool_candidates( "DG_SM90_MOE_L2_NMAJOR", l2_nmajor_schedule_default, - extra_modes and policy.uses_split_bn256() and policy.is_main_topk8()); + extra_modes and policy.uses_bn256_main_tile() and policy.is_main_topk8()); auto cleanup_candidates = get_sm90_moe_bool_candidates( "DG_SM90_MOE_ONE_WARP_CLEANUP", one_warp_cleanup_default, - extra_modes and policy.uses_split_bn256()); + extra_modes and policy.uses_bn256_main_tile()); const int default_epw = get_num_experts_per_wave_for_mega_moe_sm90( num_experts_per_rank, num_tokens, num_topk, diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh index f8522709c7..89c7713588 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh 
@@ -64,16 +64,6 @@ struct MegaMoEPhasePolicy { static constexpr bool needs_dispatch_pull = runs_linear1; static constexpr bool needs_combine = runs_linear2; - CUTLASS_DEVICE static bool is_linear1_phase(const sched::BlockPhase& block_phase) { - if constexpr (is_linear1_only) { - (void)block_phase; - return true; - } else { - (void)block_phase; - return false; - } - } - template CUTLASS_DEVICE static void for_each_selected_block(Scheduler& scheduler, Func&& func) { if constexpr (is_linear1_only) { @@ -357,6 +347,19 @@ sm90_fp8_mega_moe_core(DG_SM90_FP8_MOE_CORE_ARGS_DECL) { SMEM_EXPERT_COUNT_SIZE + SMEM_SEND_BUFFER_SIZE + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + constexpr uint32_t kCombineHiddenBytes = kHidden * sizeof(nv_bfloat16); + constexpr uint32_t kCombineChunkSlots = 3; + constexpr uint32_t kCombineMaxRegistersForBuffer = 128; + constexpr uint32_t kCombineNumChunks = + (kCombineChunkSlots * kNumEpilogueWarps * kCombineHiddenBytes <= SMEM_BEFORE_BARRIER_SIZE + and kHidden <= 32 * kCombineMaxRegistersForBuffer) ? 1 : 2; + constexpr uint32_t kCombineChunkBytes = kCombineHiddenBytes / kCombineNumChunks; + constexpr uint32_t SMEM_COMBINE_ALIAS_SIZE = MegaMoEPhase::needs_combine + ? 
kCombineChunkSlots * kNumEpilogueWarps * kCombineChunkBytes : 0u; + DG_STATIC_ASSERT(kHidden % kCombineNumChunks == 0, "Hidden must be divisible by number of combine chunks"); + DG_STATIC_ASSERT(SMEM_COMBINE_ALIAS_SIZE <= SMEM_BEFORE_BARRIER_SIZE, + "Combine SMEM alias exceeds the pre-barrier scratch region"); + // SMEM pointers auto smem_expert_count = reinterpret_cast(smem_buffer); const auto smem_send_buffers = layout::Buffer( @@ -833,7 +836,8 @@ sm90_fp8_mega_moe_core(DG_SM90_FP8_MOE_CORE_ARGS_DECL) { const uint32_t& local_expert_idx, const uint32_t& num_k_blocks, const uint32_t& m_block_idx, const uint32_t& n_block_idx) { - const bool is_linear1_phase = MegaMoEPhase::is_linear1_phase(block_phase); + (void)block_phase; + constexpr bool is_linear1_phase = MegaMoEPhase::runs_linear1; const auto tensor_map_a_ptr = !is_linear1_phase ? &tensor_map_l2_acts : &tensor_map_l1_acts; const auto tensor_map_sfa_ptr = !is_linear1_phase @@ -906,7 +910,8 @@ sm90_fp8_mega_moe_core(DG_SM90_FP8_MOE_CORE_ARGS_DECL) { const uint32_t& local_expert_idx, const uint32_t& num_k_blocks, const uint32_t& m_block_idx, const uint32_t& n_block_idx) { - const bool is_linear1_phase = MegaMoEPhase::is_linear1_phase(block_phase); + (void)block_phase; + constexpr bool is_linear1_phase = MegaMoEPhase::runs_linear1; const auto tensor_map_b_ptr = !is_linear1_phase ? 
&tensor_map_l2_weights : &tensor_map_l1_weights; @@ -937,7 +942,8 @@ sm90_fp8_mega_moe_core(DG_SM90_FP8_MOE_CORE_ARGS_DECL) { const uint32_t& local_expert_idx, const uint32_t& num_k_blocks, const uint32_t& m_block_idx, const uint32_t& n_block_idx) { - const bool is_linear1_phase = MegaMoEPhase::is_linear1_phase(block_phase); + (void)block_phase; + constexpr bool is_linear1_phase = MegaMoEPhase::runs_linear1; (void)local_expert_idx; (void)n_block_idx; const auto tensor_map_sfa_ptr = !is_linear1_phase @@ -1066,7 +1072,8 @@ sm90_fp8_mega_moe_core(DG_SM90_FP8_MOE_CORE_ARGS_DECL) { const uint32_t& local_expert_idx, const uint32_t& num_k_blocks, const uint32_t& m_block_idx, const uint32_t& n_block_idx) { - const bool is_linear1_phase = MegaMoEPhase::is_linear1_phase(block_phase); + (void)block_phase; + constexpr bool is_linear1_phase = MegaMoEPhase::runs_linear1; const uint32_t valid_m = scheduler.template get_valid_m(); const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx; const uint32_t m_idx = pool_block_idx * BLOCK_M; @@ -2366,19 +2373,14 @@ for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_bl ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); const unsigned long long combine_reduce_start = phase_profile_clock(); - constexpr uint32_t kNumHiddenBytes = kHidden * sizeof(nv_bfloat16); + constexpr uint32_t kNumHiddenBytes = kCombineHiddenBytes; constexpr uint32_t kNumElemsPerUint4 = sizeof(uint4) / sizeof(nv_bfloat162); - constexpr uint32_t kNumChunkSlots = 3; - constexpr uint32_t kNumMaxRegistersForBuffer = 128; - constexpr uint32_t kNumChunks = - (kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes <= SMEM_BEFORE_BARRIER_SIZE - and kHidden <= 32 * kNumMaxRegistersForBuffer) ? 
1 : 2; - constexpr uint32_t kNumChunkBytes = kNumHiddenBytes / kNumChunks; + constexpr uint32_t kNumChunkSlots = kCombineChunkSlots; + constexpr uint32_t kNumChunks = kCombineNumChunks; + constexpr uint32_t kNumChunkBytes = kCombineChunkBytes; constexpr uint32_t kNumChunkUint4 = kNumChunkBytes / sizeof(uint4); constexpr uint32_t kNumUint4PerLane = kNumChunkUint4 / 32; - DG_STATIC_ASSERT(kHidden % kNumChunks == 0, "Hidden must be divisible by number of chunks"); - DG_STATIC_ASSERT(kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes / kNumChunks <= SMEM_BEFORE_BARRIER_SIZE, "Hidden is too large"); DG_STATIC_ASSERT(kNumChunkBytes % 16 == 0, "Combine chunk must be TMA-aligned (16 bytes)"); DG_STATIC_ASSERT(kNumChunkBytes % sizeof(uint4) == 0, "Combine chunk must be divisible by 16 bytes"); DG_STATIC_ASSERT(kNumChunkUint4 % 32 == 0, "Combine chunk must be a multiple of 32 16-byte elements"); diff --git a/tests/test_mega_moe_sm90.py b/tests/test_mega_moe_sm90.py index f38a5be5c8..33e5a1c3ef 100644 --- a/tests/test_mega_moe_sm90.py +++ b/tests/test_mega_moe_sm90.py @@ -1,6 +1,6 @@ """Layered tests for the SM90 (Hopper) MegaMoE kernel. -The fused FP8 SM90 MegaMoE kernel is exercised across a hierarchy of +The split FP8 SM90 MegaMoE kernel is exercised across a hierarchy of scenarios so that each kernel path / heuristic branch / edge case is covered with at least one configuration. @@ -9,8 +9,8 @@ L1 Smoke : single tiny config; only verifies the kernel runs and produces an output close to a PyTorch reference. L2 Heuristic : sweeps tokens-per-expert across the bands of - ``get_block_config_for_mega_moe_sm90`` so each - ``{block_m, num_epilogue_warpgroups}`` case is hit. + the SM90 config selector so the main branch buckets + are covered. L3 Shape sweep : sweeps ``hidden``, ``intermediate_hidden`` and ``num_topk`` over divisible-by-128 values. 
L4 Edge cases : masking ratio, activation clamp (finite vs inf), @@ -19,10 +19,10 @@ Notes ----- -* The reference is a pure PyTorch BF16/FP32 simulation of the fused path +* The reference is a pure PyTorch BF16/FP32 simulation of the split path (dequantize -> matmul -> SwiGLU + clamp + per-row quantize -> matmul -> cross-rank scatter -> BF16 reduce). It is *not* bitwise-identical to - the kernel; correctness is checked with ``calc_diff < 0.07``. + the kernel; correctness is checked with ``calc_diff < 0.01`` by default. * Because every scenario allocates its own symmetric memory buffer we re-`init_dist`/`destroy` once per process at the outer level only, and re-create ``SymmBuffer`` per scenario. @@ -301,7 +301,7 @@ def _trace(stage: str): ) cum_stats = torch.zeros(num_experts_per_rank, dtype=torch.int, device='cuda') - # ---- Run fused ----------------------------------------------------------- + # ---- Run SM90 MegaMoE ---------------------------------------------------- _trace('copy_inputs') buffer.x[:num_tokens].copy_(x_fp8) buffer.x_sf[:num_tokens].copy_(x_sf) @@ -309,7 +309,7 @@ def _trace(stage: str): buffer.topk_weights[:num_tokens].copy_(topk_w) y_fused = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') - _trace('launch_fused (may JIT-compile, can take minutes)') + _trace('launch_sm90 (may JIT-compile, can take minutes)') deep_gemm.fp8_mega_moe( y_fused, transformed_l1, transformed_l2, buffer, cumulative_local_expert_recv_stats=cum_stats, @@ -318,9 +318,9 @@ def _trace(stage: str): activation_clamp=activation_clamp if math.isfinite(activation_clamp) else None, fast_math=fast_math, ) - _trace('sync_fused') + _trace('sync_sm90') torch.cuda.synchronize() - _trace('fused_done') + _trace('sm90_done') # ---- Reference & check --------------------------------------------------- # Use the FP8 weights and their block-(128, 128) SF directly — the dequant @@ -368,8 +368,8 @@ def _layer1_smoke() -> List[Tuple[str, Dict[str, Any]]]: def 
_layer2_heuristic_branches(num_ranks: int) -> List[Tuple[str, Dict[str, Any]]]: - """Vary tokens / (num_experts * num_topk / num_ranks) so each - ``get_block_config_for_mega_moe_sm90`` band fires at least once. + """Vary tokens / (num_experts * num_topk / num_ranks) so the selector's + main tokens-per-expert buckets fire at least once. The heuristic decides on ``avg_tokens_per_expert``; we approximate by setting ``num_max_tokens_per_rank`` and ``num_topk`` while keeping @@ -518,8 +518,8 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace): help='Layer 5 stress test count') parser.add_argument('--filter', type=str, default='', help='Substring filter on scenario names') - parser.add_argument('--diff-tol', type=float, default=0.07, - help='calc_diff tolerance (default: 0.07)') + parser.add_argument('--diff-tol', type=float, default=0.01, + help='calc_diff tolerance (default: 0.01)') parser.add_argument('--fail-fast', action='store_true', help='Stop on first failing scenario') args = parser.parse_args() From 6e66e1d9eff252b1a8e0bb42ba05f23274c12a52 Mon Sep 17 00:00:00 2001 From: AichenF Date: Tue, 2 Jun 2026 18:18:03 +0800 Subject: [PATCH 12/15] Drop unused SM90 MegaMoE helper files --- docs/sm90_megamoe_config_search.md | 122 --- scripts/run_ncu_mega_moe_sm90.sh | 89 -- tests/bench_mega_moe_sm90_skew.py | 385 ------- tests/bench_mega_moe_sm90_skewed.py | 249 ----- tests/test_mega_moe_hopper.py | 1551 --------------------------- 5 files changed, 2396 deletions(-) delete mode 100644 docs/sm90_megamoe_config_search.md delete mode 100755 scripts/run_ncu_mega_moe_sm90.sh delete mode 100644 tests/bench_mega_moe_sm90_skew.py delete mode 100644 tests/bench_mega_moe_sm90_skewed.py delete mode 100644 tests/test_mega_moe_hopper.py diff --git a/docs/sm90_megamoe_config_search.md b/docs/sm90_megamoe_config_search.md deleted file mode 100644 index d94dcf8949..0000000000 --- a/docs/sm90_megamoe_config_search.md +++ /dev/null @@ -1,122 +0,0 @@ -# SM90 MegaMoE 
Config Search - -This note records the H20 SM90 MegaMoE split-kernel defaults before moving the -selector to the DeepGEMM-style candidate-search path. - -## Selector Model - -Regular DeepGEMM GEMM selects parameters by: - -1. generating legal config candidates, -2. estimating a cheap `LayoutInfo`-like score from the shape, -3. choosing the best candidate deterministically, -4. JIT-compiling only the selected kernel variant. - -SM90 MegaMoE now follows the same shape. The selector builds complete -`MegaMoESM90Config` candidates containing tile sizes, wave grouping, stage -count, thread layout, and scheduler/epilogue mode flags. It then ranks them -with a block/wave score plus an empirical calibration layer for MoE-specific -dispatch/combine choices. - -This is not online autotuning. Runtime launch does not benchmark candidate -kernels. The default search is deterministic and keeps the H20 empirical -choices as calibration targets. - -## Search Space - -Default candidate dimensions: - -- `block_m`: default `64`; `128` can be included with - `DG_SM90_MOE_SEARCH_BLOCK_SHAPES=1`. -- `block_n`: default `256` for split L1/L2, otherwise `128`; alternate block-N - candidates can be included with `DG_SM90_MOE_SEARCH_BLOCK_SHAPES=1`. -- `num_epilogue_threads`: derived from block shape; `64x256` uses two epilogue - warpgroups. -- `num_dispatch_threads`: compact split frontend uses `64`; otherwise `128`. -- `direct_l2_scatter`: candidate default from empirical policy, alternate - included unless forced by `DG_SM90_MOE_DIRECT_L2_SCATTER`. -- `l2_nmajor_schedule`: candidate default from empirical policy, alternate - included unless forced by `DG_SM90_MOE_L2_NMAJOR`. -- `one_warp_cleanup`: candidate default from empirical policy, alternate - included unless forced by `DG_SM90_MOE_ONE_WARP_CLEANUP`. 
-- `num_stages`: empirical 4/5-stage target plus alternate 4/5 candidates unless - forced by `DG_SM90_MOE_NUM_STAGES`; impossible forced values are clamped to - the shared-memory-limited maximum stage count. -- `num_experts_per_wave`: empirical target plus `16` and full-rank candidates - when legal unless forced by `DG_SM90_MOE_EXPERTS_PER_WAVE`. - -Debugging: - -```bash -DG_PRINT_CONFIGS=1 python3 tests/bench_mega_moe_sm90.py --num-processes 8 --batches 8 16 32 --num-tests 1 -DG_SM90_MOE_PRINT_SEARCH=2 python3 tests/bench_mega_moe_sm90.py --num-processes 8 --batches 128 --num-tests 1 -``` - -## H20 Empirical Reference - -Configuration: `num_ranks=8`, `hidden=7168`, `intermediate_hidden=2048`, -`num_experts=256`, `num_experts_per_rank=32`, `topk=8`. - -`expected_tokens_per_expert = M * topk / num_experts_per_rank = M / 4`. - -| M | expected/expert | block_m | block_n | stages | dispatch/non-epi/epi threads | experts/wave | direct L2 scatter | L2 N-major | one-warp cleanup | -|---:|---:|---:|---:|---:|---:|---:|---|---|---| -| 8 | 2 | 64 | 256 | 5 | 64 / 64 / 256 | 16 | on | off | on | -| 16 | 4 | 64 | 256 | 5 | 64 / 64 / 256 | 16 | on | off | on | -| 32 | 8 | 64 | 256 | 4 | 64 / 64 / 256 | 32 | on | off | on | -| 64 | 16 | 64 | 256 | 5 | 64 / 64 / 256 | 32 | on | off | on | -| 128 | 32 | 64 | 256 | 5 | 64 / 64 / 256 | 32 | on | off | on | -| 256 | 64 | 64 | 256 | 4 | 64 / 64 / 256 | 32 | on | off | on | -| 260 | 65 | 64 | 256 | 4 | 64 / 64 / 256 | 32 | off | off | off | -| 512 | 128 | 64 | 256 | 5 | 64 / 64 / 256 | 16 | on | off | off | -| 1024 | 256 | 64 | 256 | 5 | 64 / 64 / 256 | 16 | on | on | off | - -These are the pre-search empirical defaults. Focused H20 sweeps after adding -candidate search found two better defaults for single-size runs: - -- `M=32`: `num_experts_per_wave=16` was faster than the previous `32`. 
-- `M=260`: enabling both direct L2 scatter and one-warp cleanup was faster than - the previous boundary behavior that turned both off after - `expected_tokens_per_expert > 64`. - -The selector calibration has been updated for those two cases. Other tested -differences, such as 4-stage vs 5-stage at `M=512` and L2 N-major at `M=1024`, -were not stable enough to change the default. - -## Validation Log - -Commands: - -```bash -DG_PRINT_CONFIGS=1 python3 tests/bench_mega_moe_sm90.py --num-processes 8 --batches 8 16 32 64 128 256 260 512 1024 --num-tests 1 -python3 tests/test_mega_moe_sm90.py --num-processes 1 --layers 1 2 --fail-fast -python3 tests/test_mega_moe_sm90.py --num-processes 2 --layers 1 2 --fail-fast -python3 tests/bench_mega_moe_sm90.py --num-processes 8 --batches 8 16 32 64 128 256 260 512 --num-tests 20 -``` - -Correctness after candidate selector: - -- `--num-processes 1 --layers 1 2 --fail-fast`: passed all 5 scenarios, - max printed diff `0.0006`. -- `--num-processes 2 --layers 1 2 --fail-fast`: passed all 5 scenarios, - max printed diff `0.0006`. - -Focused H20 benchmark after candidate selector: - -```text -M latency us -8 849.8 -16 806.4 -32 748.7 -64 799.9 -128 864.0 -256 1160.3 -260 1174.0 -512 1894.0 -1024 3226.0 -``` - -The strongest improvements versus the pre-search defaults are at `M=32` -(`experts_per_wave=16`) and `M=260` (direct L2 scatter plus one-warp cleanup). -Other search candidates remain available through env overrides, but the default -calibration was not changed when repeated measurements were noisy or neutral. diff --git a/scripts/run_ncu_mega_moe_sm90.sh b/scripts/run_ncu_mega_moe_sm90.sh deleted file mode 100755 index e8c2b0ef61..0000000000 --- a/scripts/run_ncu_mega_moe_sm90.sh +++ /dev/null @@ -1,89 +0,0 @@ -#!/bin/bash - -# SM90 (Hopper) variant of run_ncu_mega_moe.sh -# Drives `tests/bench_mega_moe_sm90.py` with NCU, profiling the -# `sm90_fp8_mega_moe_impl` kernel for a single batch size. 
- -set -e - -num_processes=8 -output_dir=work_sm90 -python_args=() -for ((arg_idx = 1; arg_idx <= $#; ++arg_idx)); do - arg="${!arg_idx}" - case "$arg" in - --num-processes) - python_args+=("$arg") - if ((arg_idx < $#)); then - ((arg_idx++)) - num_processes="${!arg_idx}" - python_args+=("$num_processes") - fi - ;; - -h|--help) - echo "Usage: $0 [--num-processes N] [--output DIR] [python args...]" - exit 0 - ;; - --num-processes=*) - num_processes="${arg#*=}" - python_args+=("$arg") - ;; - -o|--output) - if ((arg_idx < $#)); then - ((arg_idx++)) - output_dir="${!arg_idx}" - fi - ;; - --output=*) - output_dir="${arg#*=}" - ;; - *) - python_args+=("$arg") - ;; - esac -done - -echo "Python Args: ${python_args[*]}" -echo "Num Processes: $num_processes" -echo "Output Dir: $output_dir" -mkdir -p "$output_dir" - -export DG_JIT_WITH_LINEINFO=1 - -echo "Warm up JIT cache" -python tests/bench_mega_moe_sm90.py --ncu-profile-only "${python_args[@]}" - -sleep 2 - -ncu_args=( - --config-file off - --force-overwrite - --kernel-name sm90_fp8_mega_moe_impl - --import-source yes - --replay-mode application - --section SpeedOfLight - --section LaunchStats - --section SchedulerStats - --section WarpStateStats - --section MemoryWorkloadAnalysis - --section InstructionStats - --launch-skip 0 - --launch-count 1 - --clock-control none - --kill yes - --app-replay-buffer memory -) - -echo "Run Job" - -for ((i = 0; i < num_processes; ++i)); do - ncu ${ncu_args[@]} -o "${output_dir%/}/mega-moe-sm90.$i" \ - python tests/bench_mega_moe_sm90.py \ - --local-rank-idx=$i \ - --ncu-profile-only \ - "${python_args[@]}" & -done - -echo "Waiting" -wait -echo "Done" diff --git a/tests/bench_mega_moe_sm90_skew.py b/tests/bench_mega_moe_sm90_skew.py deleted file mode 100644 index d99d99e7e4..0000000000 --- a/tests/bench_mega_moe_sm90_skew.py +++ /dev/null @@ -1,385 +0,0 @@ -"""SM90 (Hopper) MegaMoE benchmark — patched to support power-law routing -and aggregate per-rank timing. 
- -Adds vs upstream `tests/bench_mega_moe_sm90.py`: - --skew-alpha Zipf exponent for expert popularity (0 = uniform, upstream default) - --skew-gain Scale of log-bias added to scores (default 1.0) - --skew-seed Seed for the per-expert bias permutation. Constant across - ranks so all ranks see the SAME hot/cold expert assignment - (matches production: gate is shared globally). - -Output is aggregated across all ranks: - - per-rank t_us, recv_tokens (max/mean/min) - - per-rank max/mean ratio (intra-rank imbalance) - - inter-rank max/mean t_us ratio (cross-rank imbalance) -""" -import argparse -import os -import random -import sys -import torch -import torch.distributed as dist -from typing import Tuple - -REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -if REPO_ROOT not in sys.path: - sys.path.insert(0, REPO_ROOT) - -import deep_gemm -from deep_gemm.utils import per_token_cast_to_fp8 -from deep_gemm.utils.dist import dist_print, init_dist, uneven_all_gather -from deep_gemm.testing import bench_kineto, get_arch_major - - -def _quantize_grouped_fp8_block_128_128(w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - g, n, k = w.shape - assert n % 128 == 0 and k % 128 == 0 - chunk_g = 4 - w_fp8 = torch.empty_like(w, dtype=torch.float8_e4m3fn) - sf = torch.empty((g, n // 128, k // 128), dtype=torch.float, device=w.device) - for start in range(0, g, chunk_g): - end = min(start + chunk_g, g) - w_view = w[start:end].view(end - start, n // 128, 128, k // 128, 128).float() - sf_chunk = w_view.abs().amax(dim=(-1, -3)).clamp(1e-4) / 448.0 - w_fp8[start:end].copy_( - (w_view / sf_chunk.unsqueeze(-1).unsqueeze(-3)).to(torch.float8_e4m3fn).view(end - start, n, k)) - sf[start:end].copy_(sf_chunk) - return w_fp8, sf.contiguous() - - -def _make_global_bias(num_experts: int, alpha: float, gain: float, seed: int): - """Power-law (Zipf) bias per expert, identical across all ranks. 
- - Generated with a deterministic seed so all ranks see the same hot/cold - expert assignment. Returns None if alpha == 0 (uniform). - """ - if alpha <= 0: - return None - gen = torch.Generator(device='cuda') - gen.manual_seed(seed) - ranks = torch.arange(1, num_experts + 1, device='cuda', dtype=torch.float) - probs = 1.0 / ranks.pow(alpha) - probs = probs / probs.sum() - bias = torch.log(probs * num_experts) * gain - perm = torch.randperm(num_experts, device='cuda', generator=gen) - return bias[perm] - - -def _run_one_config(args, num_tokens, num_max_tokens_per_rank, - hidden, intermediate_hidden, - num_experts, num_topk, num_ranks, rank_idx, group, - activation_clamp, fast_math, - print_perf=True): - num_experts_per_rank = num_experts // num_ranks - assert num_tokens <= num_max_tokens_per_rank - - buffer = deep_gemm.get_symm_buffer_for_mega_moe( - group, num_experts, - num_max_tokens_per_rank, num_topk, - hidden, intermediate_hidden, - ) - - x_bf = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') - l1_bf = torch.randn( - (num_experts_per_rank, intermediate_hidden * 2, hidden), - dtype=torch.bfloat16, device='cuda') * 0.05 - l2_bf = torch.randn( - (num_experts_per_rank, hidden, intermediate_hidden), - dtype=torch.bfloat16, device='cuda') * 0.05 - - # Per-rank random scores + global skew bias (added to all rows) - scores = torch.randn((num_tokens, num_experts), dtype=torch.float, device='cuda') - global_bias = _make_global_bias(num_experts, args.skew_alpha, args.skew_gain, - args.skew_seed) - if global_bias is not None: - scores = scores + global_bias.unsqueeze(0) - - replica_for = {} - if args.num_redundant_experts > 0: - assert args.num_redundant_experts % num_ranks == 0, 'redundant experts must divide ranks' - num_replicas_per_rank = args.num_redundant_experts // num_ranks - assert 0 < num_replicas_per_rank < num_experts_per_rank, 'invalid redundant expert count' - replica_slots = [] - for r in range(num_ranks): - base = r * 
num_experts_per_rank - replica_slots += list(range(base + num_experts_per_rank - num_replicas_per_rank, - base + num_experts_per_rank)) - logical_mask = torch.ones(num_experts, dtype=torch.bool, device='cuda') - logical_mask[torch.tensor(replica_slots, dtype=torch.long, device='cuda')] = False - scores[:, ~logical_mask] = -float('inf') - if global_bias is not None: - hot_order = torch.argsort(global_bias.masked_fill(~logical_mask, -float('inf')), descending=True).tolist() - else: - hot_order = torch.arange(num_experts, device='cuda')[logical_mask].tolist() - hot_experts = hot_order[:args.num_redundant_experts] - replica_for = {int(h): int(s) for h, s in zip(hot_experts, replica_slots)} - if rank_idx == 0 and print_perf: - print( - f'eplb_sim redundant={args.num_redundant_experts} ' - f'replicas_per_rank={num_replicas_per_rank} ' - f'dispatch={args.replica_dispatch}', - flush=True, - ) - - topk_w, topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False) - if replica_for and args.replica_dispatch == 'hash': - token_ids = torch.arange(num_tokens, device='cuda')[:, None] - slot_ids = torch.arange(num_topk, device='cuda')[None, :] - choose_replica = ((token_ids * num_topk + slot_ids + rank_idx) & 1).bool() - mapped = topk_idx.clone() - for logical_expert, replica_slot in replica_for.items(): - mapped = torch.where((topk_idx == logical_expert) & choose_replica, - torch.full_like(mapped, replica_slot), mapped) - topk_idx = mapped - elif replica_for and args.replica_dispatch == 'static': - mapped = topk_idx.clone() - for logical_expert, replica_slot in replica_for.items(): - logical_rank = logical_expert // num_experts_per_rank - replica_rank = replica_slot // num_experts_per_rank - if rank_idx == logical_rank: - chosen = logical_expert - elif rank_idx == replica_rank: - chosen = replica_slot - else: - chosen = replica_slot if ((rank_idx + logical_expert) & 1) else logical_expert - if chosen != logical_expert: - mapped = torch.where(topk_idx == 
logical_expert, - torch.full_like(mapped, chosen), mapped) - topk_idx = mapped - if args.masked_ratio > 0: - rand_mask = torch.rand_like(topk_idx, dtype=torch.float) - topk_idx.masked_fill_(rand_mask < args.masked_ratio, -1) - topk_w.masked_fill_(topk_idx < 0, 0) - - x_fp8, x_sf = per_token_cast_to_fp8(x_bf, use_ue8m0=False, gran_k=128, - use_packed_ue8m0=False) - l1_w_fp8, l1_w_sf = _quantize_grouped_fp8_block_128_128(l1_bf) - l2_w_fp8, l2_w_sf = _quantize_grouped_fp8_block_128_128(l2_bf) - transformed_l1, transformed_l2 = deep_gemm.transform_weights_for_mega_moe_sm90( - (l1_w_fp8, l1_w_sf), (l2_w_fp8, l2_w_sf), - ) - - phase_profile_enabled = os.environ.get('DG_SM90_MOE_PHASE_PROFILE', '0') != '0' - phase_profile_ints = 64 if phase_profile_enabled else 0 - cum_stats = torch.zeros(num_experts_per_rank + phase_profile_ints, dtype=torch.int, device='cuda') - use_eplb_hint = bool(replica_for) - use_skew_hint = global_bias is not None - use_masked_hint = args.masked_ratio > 0 - - def run_sm90(): - buffer.x[:num_tokens].copy_(x_fp8) - buffer.x_sf[:num_tokens].copy_(x_sf) - buffer.topk_idx[:num_tokens].copy_(topk_idx) - buffer.topk_weights[:num_tokens].copy_(topk_w) - y = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') - old_eplb_hint = os.environ.get('DG_SM90_MOE_EPLB_HINT') - old_skew_hint = os.environ.get('DG_SM90_MOE_SKEW_HINT') - old_masked_hint = os.environ.get('DG_SM90_MOE_MASKED_HINT') - if use_eplb_hint: - os.environ['DG_SM90_MOE_EPLB_HINT'] = '1' - if use_skew_hint: - os.environ['DG_SM90_MOE_SKEW_HINT'] = '1' - if use_masked_hint: - os.environ['DG_SM90_MOE_MASKED_HINT'] = '1' - try: - deep_gemm.fp8_mega_moe( - y, transformed_l1, transformed_l2, buffer, - cumulative_local_expert_recv_stats=cum_stats, - recipe=(128, 128, 128), - activation='swiglu', - activation_clamp=activation_clamp, - fast_math=fast_math, - ) - finally: - if use_eplb_hint: - if old_eplb_hint is None: - os.environ.pop('DG_SM90_MOE_EPLB_HINT', None) - else: - 
os.environ['DG_SM90_MOE_EPLB_HINT'] = old_eplb_hint - if use_skew_hint: - if old_skew_hint is None: - os.environ.pop('DG_SM90_MOE_SKEW_HINT', None) - else: - os.environ['DG_SM90_MOE_SKEW_HINT'] = old_skew_hint - if use_masked_hint: - if old_masked_hint is None: - os.environ.pop('DG_SM90_MOE_MASKED_HINT', None) - else: - os.environ['DG_SM90_MOE_MASKED_HINT'] = old_masked_hint - return y - - run_sm90() - dist.barrier() - if phase_profile_enabled: - cum_stats.zero_() - torch.cuda.synchronize() - dist.barrier() - t_sm90 = bench_kineto(run_sm90, 'sm90_fp8_mega_moe', - barrier=lambda: dist.barrier(), - num_tests=args.num_tests, - suppress_kineto_output=True, - with_multiple_kernels=True) - - # Per-rank token receive counts + per-local-expert distribution - gathered_topk_idx = uneven_all_gather(topk_idx, group=group) - local_mask = ((gathered_topk_idx >= rank_idx * num_experts_per_rank) & - (gathered_topk_idx < (rank_idx + 1) * num_experts_per_rank)) - gathered_local = gathered_topk_idx.clone() - gathered_local[~local_mask] = -1 - num_recv_tokens = int(local_mask.sum().item()) - num_touched_experts = max(0, int(torch.unique(gathered_local.flatten()).numel()) - 1) - - if num_recv_tokens > 0: - per_expert_count = torch.zeros(num_experts_per_rank, dtype=torch.int64, device='cuda') - local_ids = gathered_local[local_mask] - rank_idx * num_experts_per_rank - per_expert_count.scatter_add_(0, local_ids.to(torch.int64), - torch.ones_like(local_ids, dtype=torch.int64)) - counts = per_expert_count.cpu().tolist() - local_max = max(counts) - local_mean = sum(counts) / num_experts_per_rank - local_max_mean = local_max / local_mean if local_mean > 0 else 0.0 - else: - local_max_mean = 0.0 - - # Cross-rank aggregation - info = torch.tensor([t_sm90, float(num_recv_tokens), float(num_touched_experts), - local_max_mean], device='cuda', dtype=torch.float64) - gather_buf = [torch.zeros_like(info) for _ in range(num_ranks)] - dist.all_gather(gather_buf, info, group=group) - - phase_gather_buf 
= None - phase_names = [ - 'dispatch_total', 'dispatch_pull', 'math_loop', 'combine_barrier', - 'combine_reduce', 'gemm_core', 'l1_epilogue', 'l2_epilogue', - ] - if phase_profile_enabled: - torch.cuda.synchronize() - num_profile_metrics = len(phase_names) - profile = cum_stats[ - num_experts_per_rank:num_experts_per_rank + phase_profile_ints - ].view(torch.int64) - phase_values = [] - for i in range(num_profile_metrics): - total = float(profile[i].item()) - max_v = float(profile[num_profile_metrics + i].item()) - count = float(profile[2 * num_profile_metrics + i].item()) - avg = total / count if count else 0.0 - phase_values.extend([avg, max_v, count]) - phase_info = torch.tensor(phase_values, device='cuda', dtype=torch.float64) - phase_gather_buf = [torch.zeros_like(phase_info) for _ in range(num_ranks)] - dist.all_gather(phase_gather_buf, phase_info, group=group) - - if rank_idx == 0 and print_perf: - all_t = [g[0].item() for g in gather_buf] - all_recv = [int(g[1].item()) for g in gather_buf] - all_touch = [int(g[2].item()) for g in gather_buf] - all_lmm = [g[3].item() for g in gather_buf] - t_mean = sum(all_t) / num_ranks - t_max = max(all_t) - t_min = min(all_t) - recv_mean = sum(all_recv) / num_ranks - recv_max = max(all_recv) - recv_min = min(all_recv) - inter_max_mean = t_max / t_mean if t_mean > 0 else 0.0 - # Aggregate TFLOPS based on max wall time (real serving sees max) - total_flops = 2 * sum(all_recv) * (hidden * intermediate_hidden * 3) - tflops_agg = total_flops / 1e12 / t_max if t_max > 0 else 0.0 - # Per-rank avg view: mean recv across ranks per mean time - tflops_mean = (2 * recv_mean * hidden * intermediate_hidden * 3) / 1e12 / t_mean - - print(f'tokens={num_tokens:5d} ' - f'recv[mean/max/min]={recv_mean:6.0f}/{recv_max}/{recv_min} ' - f'local_max/mean[mean]={sum(all_lmm)/num_ranks:.2f} ' - f'inter_max/mean={inter_max_mean:.3f} ' - f't[mean/max/min]us={t_mean*1e6:7.1f}/{t_max*1e6:7.1f}/{t_min*1e6:7.1f} ' - f'TFLOPS_agg={tflops_agg:6.1f} 
TFLOPS_mean={tflops_mean:6.1f}', - flush=True) - if phase_gather_buf is not None: - phase_by_rank = [g.cpu().tolist() for g in phase_gather_buf] - for i, name in enumerate(phase_names): - avg_values = [rank_values[3 * i] for rank_values in phase_by_rank] - max_values = [rank_values[3 * i + 1] for rank_values in phase_by_rank] - count_values = [rank_values[3 * i + 2] for rank_values in phase_by_rank] - avg_mean = sum(avg_values) / len(avg_values) - avg_max = max(avg_values) - max_max = max(max_values) - count_max = max(count_values) - print( - f' phase {name:16s} avg_mean={avg_mean:10.0f} ' - f'avg_max={avg_max:10.0f} max={max_max:10.0f} ' - f'count_max={count_max:8.0f}', - flush=True, - ) - - dist.barrier() - buffer.destroy() - - -def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace): - rank_idx, num_ranks, group = init_dist(local_rank, num_local_ranks) - torch.manual_seed(rank_idx) - random.seed(rank_idx) - - if get_arch_major() != 9: - dist_print(f'[SKIP] requires SM90', once_in_node=True) - dist.destroy_process_group() - return - - batches = args.batches if args.batches else [1, 2, 4, 8, 16, 32] - - if rank_idx == 0: - print(f'SM90 MegaMoE bench: ranks={num_ranks} hidden={args.hidden} ' - f'ih={args.intermediate_hidden} experts={args.num_experts} ' - f'topk={args.num_topk} skew_alpha={args.skew_alpha} ' - f'skew_gain={args.skew_gain} masked_ratio={args.masked_ratio} ' - f'fast_math={bool(args.fast_math)}', flush=True) - - num_max_tokens_per_rank = max(batches) - for num_tokens in batches: - _run_one_config( - args, num_tokens, num_max_tokens_per_rank, - args.hidden, args.intermediate_hidden, - args.num_experts, args.num_topk, - num_ranks, rank_idx, group, - activation_clamp=args.activation_clamp, - fast_math=bool(args.fast_math), - ) - - dist.barrier() - dist.destroy_process_group() - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='SM90 MegaMoE bench (skew-aware)') - - parser.add_argument('--ncu-profile-only', 
action='store_true') - parser.add_argument('--num-processes', type=int, default=8) - parser.add_argument('--local-rank-idx', type=int, default=None) - - parser.add_argument('--batches', type=int, nargs='+', default=None) - parser.add_argument('--hidden', type=int, default=4096) - parser.add_argument('--intermediate-hidden', type=int, default=2048) - parser.add_argument('--num-experts', type=int, default=256) - parser.add_argument('--num-topk', type=int, default=8) - parser.add_argument('--activation-clamp', type=float, default=10.0) - parser.add_argument('--masked-ratio', type=float, default=0.0) - parser.add_argument('--fast-math', type=int, default=1) - parser.add_argument('--num-tests', type=int, default=20) - parser.add_argument('--skew-alpha', type=float, default=0.0, - help='Zipf exponent for expert popularity; 0=uniform') - parser.add_argument('--skew-gain', type=float, default=1.0, - help='Multiplier on log-bias added to scores') - parser.add_argument('--skew-seed', type=int, default=0, - help='Seed for the global hot/cold expert permutation') - parser.add_argument('--num-redundant-experts', type=int, default=0, - help='EPLB replica simulation: reserve physical expert slots as hot-expert replicas') - parser.add_argument('--replica-dispatch', choices=('hash', 'static'), default='hash', - help='Replica remap model: token-level hash or SGLang static source-rank approximation') - - args = parser.parse_args() - - if args.local_rank_idx is not None: - test(args.local_rank_idx, args.num_processes, args) - else: - np = args.num_processes - torch.multiprocessing.spawn(test, args=(np, args), nprocs=np) diff --git a/tests/bench_mega_moe_sm90_skewed.py b/tests/bench_mega_moe_sm90_skewed.py deleted file mode 100644 index 74cca611bd..0000000000 --- a/tests/bench_mega_moe_sm90_skewed.py +++ /dev/null @@ -1,249 +0,0 @@ -"""SM90 MegaMoE bench with power-law expert load skew. - -Adds --skew-alpha (Zipf exponent). 
The score generator biases the topk to -favor a power-law subset of experts: - - bias[i] = log( 1 / (rank[i]+1)^alpha * num_experts ) * gain - scores = N(0, 1) + bias - -alpha=0 => uniform (matches upstream bench). -alpha=1 => Zipfian harmonic — hot experts get ~ln(N) more tokens than cold. -alpha=2 => strong skew — top expert can get 5-10x mean. - -Per-rank expert permutation is randomized so hot experts spread across ranks -(matches production behavior where placement is shuffled). - -Reports observed skew via num_recv/num_touched and per-expert recv distribution. -""" -import argparse -import os -import random -import sys -import torch -import torch.distributed as dist -from typing import Tuple - -REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -if REPO_ROOT not in sys.path: - sys.path.insert(0, REPO_ROOT) - -import deep_gemm -from deep_gemm.utils import per_token_cast_to_fp8 -from deep_gemm.utils.dist import dist_print, init_dist, uneven_all_gather -from deep_gemm.testing import bench_kineto, get_arch_major - - -def _quantize_grouped_fp8_block_128_128(w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - g, n, k = w.shape - assert n % 128 == 0 and k % 128 == 0 - chunk_g = 4 - w_fp8 = torch.empty_like(w, dtype=torch.float8_e4m3fn) - sf = torch.empty((g, n // 128, k // 128), dtype=torch.float, device=w.device) - for start in range(0, g, chunk_g): - end = min(start + chunk_g, g) - w_view = w[start:end].view(end - start, n // 128, 128, k // 128, 128).float() - sf_chunk = w_view.abs().amax(dim=(-1, -3)).clamp(1e-4) / 448.0 - w_fp8[start:end].copy_( - (w_view / sf_chunk.unsqueeze(-1).unsqueeze(-3)).to(torch.float8_e4m3fn).view(end - start, n, k)) - sf[start:end].copy_(sf_chunk) - return w_fp8, sf.contiguous() - - -def _generate_skewed_scores(num_tokens: int, num_experts: int, num_topk: int, - alpha: float, gain: float, generator): - """Generate scores that produce a Zipfian topk distribution. 
- - Adds a per-expert bias derived from Zipfian rank to N(0,1) scores. The - final topk[i] is therefore drawn from a softer-than-Zipf but still - power-law-tailed distribution. - """ - scores = torch.randn((num_tokens, num_experts), dtype=torch.float, - device='cuda', generator=generator) - if alpha > 0: - # Zipf weights: probs[k] = 1 / (k+1)^alpha, normalized - ranks = torch.arange(1, num_experts + 1, device='cuda', dtype=torch.float) - probs = 1.0 / ranks.pow(alpha) - probs = probs / probs.sum() - # Bias scale relative to uniform expert prob - bias = torch.log(probs * num_experts) * gain - # Shuffle so hot experts are not concentrated at low ids (mimics - # production where init_expert_location randomizes placement). - perm = torch.randperm(num_experts, device='cuda', generator=generator) - bias = bias[perm] - scores = scores + bias.unsqueeze(0) - return scores - - -def _run_one_config(args, num_tokens, num_max_tokens_per_rank, - hidden, intermediate_hidden, - num_experts, num_topk, num_ranks, rank_idx, group): - num_experts_per_rank = num_experts // num_ranks - assert num_tokens <= num_max_tokens_per_rank - - buffer = deep_gemm.get_symm_buffer_for_mega_moe( - group, num_experts, - num_max_tokens_per_rank, num_topk, - hidden, intermediate_hidden, - ) - - gen = torch.Generator(device='cuda') - gen.manual_seed(rank_idx * 1009 + int(num_tokens)) - - x_bf = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda', generator=gen) - l1_bf = torch.randn( - (num_experts_per_rank, intermediate_hidden * 2, hidden), - dtype=torch.bfloat16, device='cuda', generator=gen) * 0.05 - l2_bf = torch.randn( - (num_experts_per_rank, hidden, intermediate_hidden), - dtype=torch.bfloat16, device='cuda', generator=gen) * 0.05 - - scores = _generate_skewed_scores(num_tokens, num_experts, num_topk, - args.skew_alpha, args.skew_gain, gen) - topk_w, topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False) - - x_fp8, x_sf = per_token_cast_to_fp8(x_bf, 
use_ue8m0=False, gran_k=128, - use_packed_ue8m0=False) - l1_w_fp8, l1_w_sf = _quantize_grouped_fp8_block_128_128(l1_bf) - l2_w_fp8, l2_w_sf = _quantize_grouped_fp8_block_128_128(l2_bf) - transformed_l1, transformed_l2 = deep_gemm.transform_weights_for_mega_moe_sm90( - (l1_w_fp8, l1_w_sf), (l2_w_fp8, l2_w_sf), - ) - - cum_stats = torch.zeros(num_experts_per_rank, dtype=torch.int, device='cuda') - use_skew_hint = args.skew_alpha > 0.0 - - def run_sm90(): - buffer.x[:num_tokens].copy_(x_fp8) - buffer.x_sf[:num_tokens].copy_(x_sf) - buffer.topk_idx[:num_tokens].copy_(topk_idx) - buffer.topk_weights[:num_tokens].copy_(topk_w) - y = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') - old_skew_hint = os.environ.get('DG_SM90_MOE_SKEW_HINT') - if use_skew_hint: - os.environ['DG_SM90_MOE_SKEW_HINT'] = '1' - try: - deep_gemm.fp8_mega_moe( - y, transformed_l1, transformed_l2, buffer, - cumulative_local_expert_recv_stats=cum_stats, - recipe=(128, 128, 128), - activation='swiglu', - activation_clamp=10.0, - fast_math=True, - ) - finally: - if use_skew_hint: - if old_skew_hint is None: - os.environ.pop('DG_SM90_MOE_SKEW_HINT', None) - else: - os.environ['DG_SM90_MOE_SKEW_HINT'] = old_skew_hint - return y - - run_sm90() - dist.barrier() - t_sm90 = bench_kineto(run_sm90, 'sm90_fp8_mega_moe', - barrier=lambda: dist.barrier(), - num_tests=args.num_tests, - suppress_kineto_output=True, - with_multiple_kernels=True) - - # Local expert count distribution - gathered_topk_idx = uneven_all_gather(topk_idx, group=group) - local_mask = ((gathered_topk_idx >= rank_idx * num_experts_per_rank) & - (gathered_topk_idx < (rank_idx + 1) * num_experts_per_rank)) - gathered_local = gathered_topk_idx.clone() - gathered_local[~local_mask] = -1 - num_recv_tokens = int(local_mask.sum().item()) - num_touched_experts = max(0, int(torch.unique(gathered_local.flatten()).numel()) - 1) - - # Per-local-expert recv count (for skew diagnostics) - if num_recv_tokens > 0: - per_expert_count = 
torch.zeros(num_experts_per_rank, dtype=torch.int64, device='cuda') - local_ids = gathered_local[local_mask] - rank_idx * num_experts_per_rank - per_expert_count.scatter_add_(0, local_ids.to(torch.int64), - torch.ones_like(local_ids, dtype=torch.int64)) - counts = per_expert_count.cpu().tolist() - mean_c = sum(counts) / num_experts_per_rank - max_c = max(counts) - min_c = min(counts) - nonzero = sum(1 for c in counts if c > 0) - skew_max_mean = max_c / mean_c if mean_c > 0 else 0.0 - else: - max_c = min_c = mean_c = 0 - skew_max_mean = 0.0 - nonzero = 0 - - safe_div = lambda a, b: float('nan') if b == 0 else a / b - tflops = safe_div(2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_sm90) - num_hbm_bytes = ( - num_touched_experts * intermediate_hidden * 2 * hidden + - num_touched_experts * hidden * intermediate_hidden + - num_recv_tokens * hidden + - num_recv_tokens * intermediate_hidden + - num_recv_tokens * intermediate_hidden + - num_recv_tokens * hidden * 2 - ) - hbm_gbs = safe_div(num_hbm_bytes / 1e9, t_sm90) - - dist_print( - f' tokens={num_tokens:5d} recv={num_recv_tokens:6d} nz_exp={nonzero:3d}/{num_experts_per_rank} ' - f'max/mean={skew_max_mean:.2f} ' - f'{t_sm90 * 1e6:7.1f} us {tflops:6.1f} TFLOPS {hbm_gbs:6.0f} GB/s (rank{rank_idx})', - once_in_node=True, - ) - - dist.barrier() - buffer.destroy() - - -def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace): - rank_idx, num_ranks, group = init_dist(local_rank, num_local_ranks) - torch.manual_seed(rank_idx) - random.seed(rank_idx) - - if get_arch_major() != 9: - dist_print(f'[SKIP] requires SM90', once_in_node=True) - dist.destroy_process_group() - return - - batches = args.batches if args.batches else [1, 2, 4, 8, 16, 32] - - dist_print( - f'SM90 MegaMoE bench (skew_alpha={args.skew_alpha} gain={args.skew_gain}): ' - f'ranks={num_ranks} hidden={args.hidden} ih={args.intermediate_hidden} ' - f'experts={args.num_experts} topk={args.num_topk} fast_math=True', - 
once_in_node=True, - ) - - num_max_tokens_per_rank = max(batches) - for num_tokens in batches: - _run_one_config( - args, num_tokens, num_max_tokens_per_rank, - args.hidden, args.intermediate_hidden, - args.num_experts, args.num_topk, - num_ranks, rank_idx, group, - ) - - dist.barrier() - dist.destroy_process_group() - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='SM90 MegaMoE bench (skewed routing)') - parser.add_argument('--num-processes', type=int, default=8) - parser.add_argument('--local-rank-idx', type=int, default=None) - parser.add_argument('--batches', type=int, nargs='+', default=None) - parser.add_argument('--hidden', type=int, default=4096) - parser.add_argument('--intermediate-hidden', type=int, default=2048) - parser.add_argument('--num-experts', type=int, default=256) - parser.add_argument('--num-topk', type=int, default=8) - parser.add_argument('--num-tests', type=int, default=20) - parser.add_argument('--skew-alpha', type=float, default=0.0, - help='Zipf exponent for expert bias (0=uniform)') - parser.add_argument('--skew-gain', type=float, default=1.0, - help='Scale of log-bias added to scores (default 1.0)') - args = parser.parse_args() - if args.local_rank_idx is not None: - test(args.local_rank_idx, args.num_processes, args) - else: - np = args.num_processes - torch.multiprocessing.spawn(test, args=(np, args), nprocs=np) diff --git a/tests/test_mega_moe_hopper.py b/tests/test_mega_moe_hopper.py deleted file mode 100644 index f799553dd9..0000000000 --- a/tests/test_mega_moe_hopper.py +++ /dev/null @@ -1,1551 +0,0 @@ -""" -H200 (SM90 / Hopper) mega-MoE: split L1/L2 kernel + 同管线 baseline 性能对比。 - -结构对齐 tests/test_mega_moe.py(B 系列 SM100 FP4 路径),但所有路径都换成 H200 FP8: - * split SM90:调用 `deep_gemm.fp8_mega_moe` - (kernel symbols `sm90_fp8_mega_moe_l1_impl` + `sm90_fp8_mega_moe_l2_impl`), - 使用 `transform_weights_for_mega_moe_sm90` 处理过的权重 + SymmBuffer。 - * baseline:DeepEP dispatch + 2 个 grouped FP8 GEMM + Triton SwiGLU + 
DeepEP combine, - 使用未变换的权重。由于当前 SM90 grouped GEMM 只支持 L2 activation - per-128-K SFA,而 SM90 split mega-MoE 的 L1 epilogue 为避免跨 CTA - 同步使用 per-64-K SFA,所以该 baseline 是同管线 legacy 参照, - 不是 bitwise apples-to-apples correctness oracle。 - * 性能输出涵盖:TFLOPS / overlap TFLOPS / HBM GB/s / NVL GB/s / SM90 us / - reduction us / `t_baseline / t_sm90` legacy 比。 -""" - -import argparse -import math -import os -import random -import torch -import torch.distributed as dist -import triton -import triton.language as tl -from typing import Tuple - -import deep_gemm -from deep_gemm.utils import per_token_cast_to_fp8 -from deep_gemm.utils.dist import dist_print, init_dist, uneven_all_gather -from deep_gemm.testing import bench_kineto, get_arch_major - -try: - import deep_ep as _deep_ep - _deep_ep_import_error = None -except Exception as ex: - _deep_ep = None - _deep_ep_import_error = ex - - -# 与 deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh 中模板入口前缀同名, -# bench_kineto 用它从 trace 里汇总 split mega-MoE 的 L1/L2 GPU 段 -SM90_KERNEL_NAME_PREFIX = "sm90_fp8_mega_moe" -SM90_VS_LEGACY_MEAN_ABS_REL_TOL = 7e-2 - - -# FP8 e4m3fn 的最大可表示值,量化时用 amax / 448 作为 scale 基准 -FP8_E4M3_MAX = 448.0 -# 新版 Triton(>= 3.x)强制:jit 内核读到的 Python 全局必须是 tl.constexpr 实例, -# 否则编译期 NameError。宿主 Python 侧仍用上面的普通 float 做 torch 运算。 -_FP8_E4M3_MAX_TL = tl.constexpr(448.0) -L1_ACT_SF_GRAN = 128 -SM90_L2_ACT_SF_GRAN = 64 -BASELINE_L2_ACT_SF_GRAN = 128 -WEIGHT_SF_GRAN_MN = 128 -WEIGHT_SF_GRAN_K = 128 - - -# ============================================================================ -# 模块 1:Triton SwiGLU + FP8 量化内核 -# ---------------------------------------------------------------------------- -# baseline 的 L2 仍走 DeepGEMM SM90 grouped FP8 GEMM,所以 activation SFA 只能按 -# per-128-K 输入;但 scale 数值采用 SM90 epilogue 同款 UE8M0/power-of-two 规则, -# 避免再额外引入 exact-FP32-scale 差异。 -# 输入 x : (M, 2*H) bf16,内层是 [gate_part | up_part] -# 输入 topk_w : (M,) fp32,可选 -# 输出 y : (M, H) fp8_e4m3fn -# 输出 y_sf : (M, H/BLOCK_K) fp32 行主序 -# 
# ============================================================================
# (closing rule of the preceding section banner)


@triton.jit
def _swiglu_apply_weight_to_fp8_kernel(
    x_ptr,
    topk_w_ptr,
    y_ptr,
    y_sf_ptr,
    M,
    H,  # runtime shape
    stride_xm,
    stride_xn,  # strides of x, laid out as (M, 2H)
    stride_ym,
    stride_yn,  # strides of y, laid out as (M, H)
    stride_sfm,
    stride_sfk,  # strides of y_sf, laid out as (M, H / BLOCK_K)
    clamp_value,  # ignored when HAS_CLAMP is False
    HAS_TOPK: tl.constexpr,
    HAS_CLAMP: tl.constexpr,
    USE_UE8M0_SCALE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,  # = num_per_channels
):
    """Fused SwiGLU, optional per-row weighting and per-(row, K-block) FP8 quantisation.

    Each program handles BLOCK_M rows x the BLOCK_K columns of K-block
    `pid_k`. Rows are masked against M; columns are NOT masked, so the
    launcher must guarantee that H is a multiple of BLOCK_K.
    """
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)

    # Rows owned by this program: [pid_m * BLOCK_M, pid_m * BLOCK_M + BLOCK_M).
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # Column indices of the current K-block, expressed in the H dimension (not 2H).
    offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
    mask_m = offs_m < M

    # ---- 1) Load gate (first half of x, [0, H)) and up (second half, [H, 2H)) ----
    # stride_xn is an element stride; the `H + offs_k` offset is counted in elements too.
    gate_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xn
    up_ptrs = x_ptr + offs_m[:, None] * stride_xm + (H + offs_k[None, :]) * stride_xn
    gate = tl.load(gate_ptrs, mask=mask_m[:, None], other=0.0).to(tl.float32)
    up = tl.load(up_ptrs, mask=mask_m[:, None], other=0.0).to(tl.float32)

    # ---- 2) Optional clamp: gate is bounded from above only, up on both sides ----
    if HAS_CLAMP:
        gate = tl.minimum(gate, clamp_value)
        up = tl.minimum(tl.maximum(up, -clamp_value), clamp_value)

    # ---- 3) SwiGLU: silu(gate) * up = gate * sigmoid(gate) * up, computed in FP32 ----
    y = gate * tl.sigmoid(gate) * up

    # ---- 4) Optional per-row scalar weight (read as topk_w_ptr[row], unit stride) ----
    if HAS_TOPK:
        w = tl.load(topk_w_ptr + offs_m, mask=mask_m, other=1.0)
        y = y * w[:, None]

    # ---- 5) Per-row absmax inside the current K-block -> scale factor ----
    amax = tl.max(tl.abs(y), axis=1)  # (BLOCK_M,)
    # The 1e-30 floor keeps the division below finite for all-zero (or masked) rows.
    sf = tl.maximum(amax / _FP8_E4M3_MAX_TL, 1.0e-30)
    if USE_UE8M0_SCALE:
        # Round the scale up to a power of two: scale = 2 ** ceil(log2(amax / max)).
        # Per the original note this is meant to match
        # deep_gemm/common/math.cuh::get_e4m3_sf_and_sf_inv.
        sf = tl.exp2(tl.ceil(tl.log2(sf)))

    # ---- 6) Quantise to FP8 e4m3fn ----
    y_fp8 = (y / sf[:, None]).to(tl.float8e4nv)

    # ---- 7) Store y and its scale factors ----
    y_ptrs = y_ptr + offs_m[:, None] * stride_ym + offs_k[None, :] * stride_yn
    tl.store(y_ptrs, y_fp8, mask=mask_m[:, None])

    sf_ptrs = y_sf_ptr + offs_m * stride_sfm + pid_k * stride_sfk
    tl.store(sf_ptrs, sf, mask=mask_m)


def swiglu_apply_weight_to_fp8_triton(
    x: torch.Tensor,
    topk_weights: torch.Tensor | None,
    clamp_value: float | None = None,
    num_per_channels: int = BASELINE_L2_ACT_SF_GRAN,
    use_ue8m0_scale: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """SwiGLU followed by per-block FP8 quantisation.

    Intended to be equivalent to this PyTorch reference (H = x.shape[1] // 2,
    np = num_per_channels, c = clamp_value):

        gate, up = x[:, :H], x[:, H:]
        y = silu(gate.clamp(max=c)) * up.clamp(-c, c) * topk_w
        y_sf = y.view(M, H/np, np).abs().amax(-1) / 448
        if use_ue8m0_scale: y_sf = ceil_to_power_of_2(y_sf)
        y_fp8 = (y / y_sf.unsqueeze(-1)).to(fp8)

    The kernel additionally floors y_sf at 1e-30 before the optional rounding.

    Args:
        x: contiguous CUDA bf16 tensor of shape (M, 2H), gate half first.
        topk_weights: optional per-row weights; the kernel reads element `m`
            for row `m`, so at least M elements are required. None skips the
            multiplication.
        clamp_value: optional clamp bound; None disables clamping.
        num_per_channels: number of columns sharing one scale factor; must be
            a power of two that divides H.
        use_ue8m0_scale: round every scale factor up to a power of two.

    Returns:
        (y, y_sf): y is (M, H) float8_e4m3fn and y_sf is
        (M, H // num_per_channels) float32.
    """
    assert x.is_cuda and x.dtype == torch.bfloat16
    assert x.is_contiguous(), "当前实现假设 x 是 contiguous 的,避免 stride 计算错位"
    # The kernel builds its column range with tl.arange(0, BLOCK_K), which only
    # accepts power-of-two extents; fail here with a readable message instead.
    assert num_per_channels > 0 and (num_per_channels & (num_per_channels - 1)) == 0, (
        f"num_per_channels={num_per_channels} must be a positive power of two"
    )
    M, two_H = x.shape
    # x is split into equal gate / up halves; an odd width would silently drop a column.
    assert two_H % 2 == 0, f"x.shape[1]={two_H} must be even (gate and up halves)"
    H = two_H // 2
    assert H % num_per_channels == 0, f"H={H} 必须是 {num_per_channels} 的整数倍"

    if topk_weights is not None:
        # The kernel addresses the weights as `ptr + row`: it needs unit stride
        # and at least M elements, otherwise it reads wrong or out-of-bounds data.
        assert topk_weights.numel() >= M, (
            f"topk_weights has {topk_weights.numel()} elements, need at least M={M}"
        )
        topk_weights = topk_weights.contiguous()

    y = torch.empty((M, H), dtype=torch.float8_e4m3fn, device=x.device)
    y_sf = torch.empty((M, H // num_per_channels), dtype=torch.float32, device=x.device)

    # BLOCK_M = 16: each program covers 16 rows x num_per_channels columns,
    # which keeps register pressure low and is easy to tune.
    BLOCK_M = 16
    grid = (triton.cdiv(M, BLOCK_M), H // num_per_channels)

    # A valid pointer is still required when HAS_TOPK is False (Triton does not
    # accept a null pointer); x is passed as a placeholder and never read.
    topk_ptr = topk_weights if topk_weights is not None else x

    _swiglu_apply_weight_to_fp8_kernel[grid](
        x,
        topk_ptr,
        y,
        y_sf,
        M,
        H,
        x.stride(0),
        x.stride(1),
        y.stride(0),
        y.stride(1),
        y_sf.stride(0),
        y_sf.stride(1),
        float(clamp_value) if clamp_value is not None else 0.0,
        HAS_TOPK=topk_weights is not None,
        HAS_CLAMP=clamp_value is not None,
        USE_UE8M0_SCALE=use_ue8m0_scale,
        BLOCK_M=BLOCK_M,
        BLOCK_K=num_per_channels,
    )
    return y, y_sf


# ============================================================================
# Module 2: (128, 128) block-wise FP8 quantisation of grouped weights
# ----------------------------------------------------------------------------
# Weight input convention of m_grouped_fp8_gemm_nt_contiguous on SM90
# (per the original author's note):
#   every (128, 128) sub-block shares one FP32 SF, and K is the inner
#   contiguous dimension of the SF tensor (K-major).
# Differences from the SM100 FP4 path:
#   * deep_gemm.transform_sf_into_required_layout is not needed
#   * the SF is FP32, not packed UE8M0
# ============================================================================


def _quantize_grouped_fp8_block_128_128(
    w: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantise a (G, N, K) weight to FP8 with one scale per (128, 128) block.

    Returns:
        (w_fp8, sf): w_fp8 is (G, N, K) float8_e4m3fn and sf is
        (G, N // 128, K // 128) float32.
    """
    g, n, k = w.shape
    assert n % 128 == 0 and k % 128 == 0, f"weight 的 N={n}, K={k} 都必须是 128 的倍数"

    # Work on a few groups at a time so the FP32 copy of the weights stays small.
    chunk_g = 4
    w_fp8 = torch.empty_like(w, dtype=torch.float8_e4m3fn)
    sf = torch.empty((g, n // 128, k // 128), dtype=torch.float, device=w.device)
    for start in range(0, g, chunk_g):
        end = min(start + chunk_g, g)
        # (c, N/128, 128, K/128, 128): dims -3 and -1 index inside one block.
        w_view = w[start:end].view(end - start, n // 128, 128, k // 128, 128).float()
        # Per-block absmax, floored at 1e-4 so all-zero blocks do not divide by zero.
        sf_chunk = w_view.abs().amax(dim=(-1, -3)).clamp(1e-4) / FP8_E4M3_MAX
        w_fp8[start:end].copy_(
            (w_view / sf_chunk.unsqueeze(-1).unsqueeze(-3)).to(torch.float8_e4m3fn).view(end - start, n, k)
        )
        sf[start:end].copy_(sf_chunk)
    return w_fp8, sf.contiguous()


# ============================================================================
# Module 3: try to import deep_ep (used for dispatch / combine)
# ============================================================================


def _import_deep_ep():
    """Return the module-level `_deep_ep`, or None (after logging the import error)."""
    if _deep_ep is None:
        dist_print(f"Failed to import deep_ep: {_deep_ep_import_error}", once_in_node=True)
        return None
    return _deep_ep


class _DeepEPHandle:
    """Pairs a raw dispatch handle with the per-expert prefix sum of received tokens."""

    def __init__(self, raw_handle, psum_num_recv_tokens_per_expert: torch.Tensor):
        self.raw_handle = raw_handle
        self.psum_num_recv_tokens_per_expert = psum_num_recv_tokens_per_expert


-class _DeepEPBufferCompat: - """Compatibility shim for newer DeepEP versions that expose Buffer, not ElasticBuffer.""" - - def __init__(self, deep_ep, group, num_nvl_bytes: int): - self.buffer = deep_ep.Buffer( - group, - num_nvl_bytes=num_nvl_bytes, - num_rdma_bytes=0, - explicitly_destroy=True, - ) - - def dispatch( - self, - x, - *, - topk_idx, - topk_weights, - num_experts: int, - expert_alignment: int, - **_, - ): - num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event = ( - self.buffer.get_dispatch_layout(topk_idx, num_experts) - ) - recv_x, _, recv_topk_weights, num_recv_tokens_per_expert, raw_handle, event = self.buffer.dispatch( - x, - num_tokens_per_rank=num_tokens_per_rank, - num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, - is_token_in_rank=is_token_in_rank, - num_tokens_per_expert=num_tokens_per_expert, - topk_idx=topk_idx, - topk_weights=topk_weights, - expert_alignment=expert_alignment, - ) - psum = torch.tensor( - num_recv_tokens_per_expert, dtype=torch.int, device=topk_idx.device - ).cumsum(dim=0, dtype=torch.int) - return recv_x, None, recv_topk_weights, _DeepEPHandle(raw_handle, psum), event - - def combine(self, x, *, handle): - raw_handle = handle.raw_handle if isinstance(handle, _DeepEPHandle) else handle - return self.buffer.combine(x, handle=raw_handle) - - def barrier(self, use_comm_stream: bool = False): - torch.cuda.synchronize() - dist.barrier() - - def destroy(self): - self.buffer.destroy() - - -def _make_deep_ep_buffer(deep_ep, group, num_max_tokens_per_rank, hidden, num_topk, sym_buffer_bytes): - if hasattr(deep_ep, "ElasticBuffer"): - return deep_ep.ElasticBuffer( - group, - num_max_tokens_per_rank=num_max_tokens_per_rank, - hidden=hidden, - num_topk=num_topk, - use_fp8_dispatch=True, - explicitly_destroy=True, - allow_multiple_reduction=False, - ) - nvl_alignment = 2 * 1024 * 1024 - num_nvl_bytes = ((int(sym_buffer_bytes) + nvl_alignment - 1) // nvl_alignment) * nvl_alignment - return 
_DeepEPBufferCompat(deep_ep, group, num_nvl_bytes=num_nvl_bytes) - - -# ============================================================================ -# 模块 4:CUDA event 中位数测时(避开对 tilelang.do_bench 的依赖) -# ============================================================================ - - -def _bench_cuda_events( - fn, num_warmup: int = 5, num_repeat: int = 20, l2_flush_gb: float = 8.0 -) -> float: - """返回 fn 的中位数耗时(秒)。""" - for _ in range(num_warmup): - fn() - torch.cuda.synchronize() - times_ms = [] - for _ in range(num_repeat): - # L2 flush,避免重复访问命中 cache 让测时偏低 - if l2_flush_gb > 0: - free_bytes, _ = torch.cuda.mem_get_info() - flush_bytes = min(int(l2_flush_gb * 1e9), int(free_bytes * 0.5)) - if flush_bytes >= 4: - torch.empty(flush_bytes // 4, dtype=torch.int, device="cuda").zero_() - s = torch.cuda.Event(enable_timing=True) - e = torch.cuda.Event(enable_timing=True) - s.record() - fn() - e.record() - e.synchronize() - times_ms.append(s.elapsed_time(e)) - times_ms.sort() - return times_ms[len(times_ms) // 2] / 1e3 - - -# ============================================================================ -# 模块 5:test() 主入口 — 在每个 rank 上跑一遍 baseline -# ============================================================================ - - -def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace): - # 初始化分布式:rank_idx 是全局 rank,group 是默认 NCCL group - rank_idx, num_ranks, group = init_dist(local_rank, num_local_ranks) - torch.manual_seed(rank_idx) - random.seed(rank_idx) - - if get_arch_major() != 9: - dist_print( - f"[SKIP] test_mega_moe_hopper requires SM90; got SM{get_arch_major()}0", - once_in_node=True, - ) - dist.destroy_process_group() - return - - # 形状参数(与 test_mega_moe.py 同名同义) - num_max_tokens_per_rank = args.num_max_tokens_per_rank - num_tokens = ( - max( - 0, - args.num_max_tokens_per_rank - - random.randint(0, args.num_max_removed_tokens), - ) - if args.num_tokens == 0 - else args.num_tokens - ) - hidden, intermediate_hidden = args.hidden, 
args.intermediate_hidden - num_experts, num_topk = args.num_experts, args.num_topk - num_experts_per_rank = num_experts // num_ranks - assert num_tokens <= num_max_tokens_per_rank - assert num_experts % num_ranks == 0, ( - f"num_experts={num_experts} 必须能被 num_ranks={num_ranks} 整除" - ) - - # SM90 mega-MoE kernel 的形状约束(来自 csrc/apis/mega.hpp::fp8_mega_moe): - # * H、IH 必须是 128 的倍数(L1 input per-128-K SF + block-(128,128) weight SF) - # * IH/64 ≤ 64 → IH ≤ 4096(l2_arrival_mask 是 uint64,每 bit 对应 64 列) - assert hidden % 128 == 0 - assert intermediate_hidden % 128 == 0 - assert intermediate_hidden // 64 <= 64, ( - f"SM90 mega-MoE kernel 要求 intermediate_hidden <= 4096, 当前 {intermediate_hidden}" - ) - - # ---- 创建 BF16 输入:token 与两层 weight ---- - # x: 每 rank 本地 num_tokens 个 token,每个 token hidden 维 - x_bf16 = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") - # L1 weight: 每个 expert 把 hidden → 2*intermediate_hidden(gate 和 up 拼一起) - l1_weights_bf16 = torch.randn( - (num_experts_per_rank, intermediate_hidden * 2, hidden), - dtype=torch.bfloat16, - device="cuda", - ) - # L2 weight: 每个 expert 把 intermediate_hidden → hidden - l2_weights_bf16 = torch.randn( - (num_experts_per_rank, hidden, intermediate_hidden), - dtype=torch.bfloat16, - device="cuda", - ) - - # 路由:scores → topk_idx (M, K) + topk_weights (M, K) - eplb_replica_for = {} - eplb_replica_slots: list[int] = [] - if args.num_redundant_experts > 0: - assert args.num_redundant_experts % num_ranks == 0, "num_redundant_experts must divide num_ranks" - num_replicas_per_rank = args.num_redundant_experts // num_ranks - assert 0 < num_replicas_per_rank < num_experts_per_rank, "invalid redundant expert count" - for r in range(num_ranks): - base = r * num_experts_per_rank - eplb_replica_slots += list(range(base + num_experts_per_rank - num_replicas_per_rank, - base + num_experts_per_rank)) - logical_mask = torch.ones(num_experts, dtype=torch.bool, device="cuda") - logical_mask[torch.tensor(eplb_replica_slots, 
dtype=torch.long, device="cuda")] = False - if args.score_powerlaw_alpha > 0: - expert_rank = torch.arange(1, num_experts + 1, dtype=torch.float, device="cuda") - bias_for_hot = torch.pow(expert_rank, -args.score_powerlaw_alpha) - bias_for_hot = (bias_for_hot - bias_for_hot.mean()) / (bias_for_hot.std() + 1e-6) - hot_order = torch.argsort(bias_for_hot.masked_fill(~logical_mask, -float("inf")), descending=True).cpu().tolist() - else: - hot_order = torch.arange(num_experts, device="cuda")[logical_mask].cpu().tolist() - hot_experts = hot_order[:args.num_redundant_experts] - eplb_replica_for = {int(h): int(s) for h, s in zip(hot_experts, eplb_replica_slots)} - if rank_idx == 0: - print( - f" > eplb_sim redundant={args.num_redundant_experts} " - f"replicas_per_rank={num_replicas_per_rank} " - f"dispatch={args.replica_dispatch}", - flush=True, - ) - - def make_scores(): - scores = torch.randn((num_tokens, num_experts), dtype=torch.float, device="cuda") - if args.score_powerlaw_alpha > 0: - expert_rank = torch.arange(1, num_experts + 1, dtype=torch.float, device="cuda") - bias = torch.pow(expert_rank, -args.score_powerlaw_alpha) - bias = (bias - bias.mean()) / (bias.std() + 1e-6) - scores = scores + args.score_powerlaw_scale * bias[None, :] - if eplb_replica_slots: - scores[:, torch.tensor(eplb_replica_slots, dtype=torch.long, device="cuda")] = -float("inf") - return scores - - def apply_eplb_replicas(idx: torch.Tensor) -> torch.Tensor: - if not eplb_replica_for: - return idx - mapped = idx.clone() - if args.replica_dispatch == "hash": - token_ids = torch.arange(num_tokens, device="cuda")[:, None] - slot_ids = torch.arange(num_topk, device="cuda")[None, :] - choose_replica = ((token_ids * num_topk + slot_ids + rank_idx) & 1).bool() - for logical_expert, replica_slot in eplb_replica_for.items(): - mapped = torch.where((idx == logical_expert) & choose_replica, - torch.full_like(mapped, replica_slot), mapped) - elif args.replica_dispatch == "static": - for logical_expert, 
replica_slot in eplb_replica_for.items(): - logical_rank = logical_expert // num_experts_per_rank - replica_rank = replica_slot // num_experts_per_rank - if rank_idx == logical_rank: - chosen = logical_expert - elif rank_idx == replica_rank: - chosen = replica_slot - else: - chosen = replica_slot if ((rank_idx + logical_expert) & 1) else logical_expert - if chosen != logical_expert: - mapped = torch.where(idx == logical_expert, - torch.full_like(mapped, chosen), mapped) - else: - raise ValueError(f"unknown replica_dispatch={args.replica_dispatch}") - return mapped - - if args.routing_mode == "balanced": - assert args.masked_ratio == 0.0, "balanced routing does not support masked_ratio" - assert (num_tokens * num_topk) % num_experts == 0, "balanced routing requires M*topk divisible by num_experts" - token_ids = torch.arange(num_tokens, device="cuda", dtype=torch.long)[:, None] - topk_offsets = torch.arange(num_topk, device="cuda", dtype=torch.long)[None, :] - topk_idx = (token_ids * num_topk + topk_offsets) % num_experts - topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float, device="cuda") - elif args.routing_mode in ("balanced-shuffled", "balanced-shuffled-score"): - assert args.masked_ratio == 0.0, f"{args.routing_mode} routing does not support masked_ratio" - assert num_tokens % num_experts == 0, f"{args.routing_mode} requires M divisible by num_experts" - assert num_experts % num_topk == 0, f"{args.routing_mode} requires experts divisible by topk" - token_perm = torch.randperm(num_tokens, device="cuda") - expert_perm = torch.randperm(num_experts, device="cuda") - positions = torch.arange(num_tokens, device="cuda", dtype=torch.long) - slot_stride = num_experts // num_topk - topk_idx = torch.empty((num_tokens, num_topk), dtype=torch.long, device="cuda") - for slot in range(num_topk): - expert_ids = expert_perm[(positions + slot * slot_stride) % num_experts] - topk_idx[token_perm, slot] = expert_ids - if args.routing_mode == 
"balanced-shuffled-score": - scores = make_scores() - topk_weights = torch.gather(scores, 1, topk_idx) - else: - topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float, device="cuda") - elif args.routing_mode == "topk-repair-budget-softmax": - assert args.masked_ratio == 0.0, f"{args.routing_mode} routing does not support masked_ratio" - assert (num_tokens * num_topk) % num_experts == 0, f"{args.routing_mode} requires exact local expert capacity" - scores = make_scores() - topk_weights, topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False) - expert_capacity = (num_tokens * num_topk) // num_experts - scores_cpu = scores.cpu() - probs_cpu = torch.softmax(scores_cpu, dim=1) - selected = topk_idx.cpu().tolist() - original_selected = [row[:] for row in selected] - selected_sets = [set(row) for row in selected] - counts = [0] * num_experts - for row in selected: - for expert in row: - counts[expert] += 1 - - selected_mass = 0.0 - for token, row in enumerate(selected): - selected_mass += sum(float(probs_cpu[token, expert].item()) for expert in row) - original_mass = selected_mass - changed_slots = 0 - touched_tokens = set() - budget = args.repair_mass_drop_budget - - while True: - overflow = {e for e, c in enumerate(counts) if c > expert_capacity} - underfull = [e for e, c in enumerate(counts) if c < expert_capacity] - if not overflow or not underfull: - break - candidates = [] - for token, row in enumerate(selected): - token_set = selected_sets[token] - for slot, old_expert in enumerate(row): - if old_expert not in overflow: - continue - old_score = float(scores_cpu[token, old_expert].item()) - old_mass = float(probs_cpu[token, old_expert].item()) - for new_expert in underfull: - if new_expert in token_set: - continue - new_mass = float(probs_cpu[token, new_expert].item()) - loss = old_score - float(scores_cpu[token, new_expert].item()) - mass_loss = old_mass - new_mass - candidates.append((loss, mass_loss, token, slot, old_expert, 
new_expert)) - if not candidates: - break - candidates.sort(key=lambda x: x[0]) - changed = False - for _, mass_loss, token, slot, old_expert, new_expert in candidates: - if counts[old_expert] <= expert_capacity or counts[new_expert] >= expert_capacity: - continue - if selected[token][slot] != old_expert or new_expert in selected_sets[token]: - continue - next_mass = selected_mass - mass_loss - next_drop = (original_mass - next_mass) / max(original_mass, 1e-12) - if next_drop > budget: - continue - selected[token][slot] = new_expert - selected_sets[token].remove(old_expert) - selected_sets[token].add(new_expert) - counts[old_expert] -= 1 - counts[new_expert] += 1 - selected_mass = next_mass - changed_slots += 1 - touched_tokens.add(token) - changed = True - if not changed: - break - - topk_idx = torch.tensor(selected, dtype=torch.long, device="cuda") - topk_weights = torch.softmax(torch.gather(scores, 1, topk_idx), dim=-1) - mass_drop = (original_mass - selected_mass) / max(original_mass, 1e-12) - over_slots = sum(max(0, c - expert_capacity) for c in counts) - max_count = max(counts) - dist_print( - f" > bounded_repair rank={rank_idx}: budget={budget:.3f} " - f"changed={changed_slots / max(num_tokens * num_topk, 1) * 100:.1f}% " - f"touched={len(touched_tokens) / max(num_tokens, 1) * 100:.1f}% " - f"mass_drop={mass_drop * 100:.1f}% over_slots={over_slots} max_count={max_count}", - once_in_node=False, - ) - elif args.routing_mode == "global-repair-budget-softmax": - assert args.masked_ratio == 0.0, f"{args.routing_mode} routing does not support masked_ratio" - scores = make_scores() - all_scores = uneven_all_gather(scores, group=group) - local_num_tokens = torch.tensor([num_tokens], dtype=torch.long, device="cuda") - all_num_tokens_t = [torch.zeros_like(local_num_tokens) for _ in range(num_ranks)] - dist.all_gather(all_num_tokens_t, local_num_tokens, group=group) - all_num_tokens = [int(x.item()) for x in all_num_tokens_t] - local_offset = 
sum(all_num_tokens[:rank_idx]) - total_num_tokens = sum(all_num_tokens) - assert (total_num_tokens * num_topk) % num_experts == 0, f"{args.routing_mode} requires exact global expert capacity" - - all_topk_weights, all_topk_idx = torch.topk(all_scores, num_topk, dim=-1, largest=True, sorted=False) - expert_capacity = (total_num_tokens * num_topk) // num_experts - scores_cpu = all_scores.cpu() - probs_cpu = torch.softmax(scores_cpu, dim=1) - score_order = torch.argsort(all_scores, dim=1, descending=True).cpu().tolist() - selected = all_topk_idx.cpu().tolist() - selected_sets = [set(row) for row in selected] - counts = [0] * num_experts - for row in selected: - for expert in row: - counts[expert] += 1 - - selected_mass = 0.0 - for token, row in enumerate(selected): - selected_mass += sum(float(probs_cpu[token, expert].item()) for expert in row) - original_mass = selected_mass - changed_slots = 0 - touched_tokens = set() - budget = args.repair_mass_drop_budget - max_rounds = max(1, args.repair_max_rounds) - - for _round in range(max_rounds): - overflow = {e for e, c in enumerate(counts) if c > expert_capacity} - if not overflow: - break - underfull = {e for e, c in enumerate(counts) if c < expert_capacity} - if not underfull: - break - candidates = [] - for token, row in enumerate(selected): - token_set = selected_sets[token] - for slot, old_expert in enumerate(row): - if old_expert not in overflow: - continue - new_expert = -1 - for cand in score_order[token]: - if cand in underfull and cand not in token_set: - new_expert = cand - break - if new_expert < 0: - continue - old_score = float(scores_cpu[token, old_expert].item()) - old_mass = float(probs_cpu[token, old_expert].item()) - new_mass = float(probs_cpu[token, new_expert].item()) - loss = old_score - float(scores_cpu[token, new_expert].item()) - mass_loss = old_mass - new_mass - candidates.append((loss, mass_loss, token, slot, old_expert, new_expert)) - if not candidates: - break - candidates.sort(key=lambda x: 
x[0]) - changed = False - for _, mass_loss, token, slot, old_expert, new_expert in candidates: - if counts[old_expert] <= expert_capacity or counts[new_expert] >= expert_capacity: - continue - if selected[token][slot] != old_expert or new_expert in selected_sets[token]: - continue - next_mass = selected_mass - mass_loss - next_drop = (original_mass - next_mass) / max(original_mass, 1e-12) - if next_drop > budget: - continue - selected[token][slot] = new_expert - selected_sets[token].remove(old_expert) - selected_sets[token].add(new_expert) - counts[old_expert] -= 1 - counts[new_expert] += 1 - selected_mass = next_mass - changed_slots += 1 - touched_tokens.add(token) - changed = True - if not changed: - break - - all_topk_idx = torch.tensor(selected, dtype=torch.long, device="cuda") - topk_idx = all_topk_idx[local_offset:local_offset + num_tokens].contiguous() - topk_weights = torch.softmax(torch.gather(scores, 1, topk_idx), dim=-1) - mass_drop = (original_mass - selected_mass) / max(original_mass, 1e-12) - over_slots = sum(max(0, c - expert_capacity) for c in counts) - max_count = max(counts) - if rank_idx == 0: - dist_print( - f" > global_bounded_repair: budget={budget:.3f} " - f"changed={changed_slots / max(total_num_tokens * num_topk, 1) * 100:.1f}% " - f"touched={len(touched_tokens) / max(total_num_tokens, 1) * 100:.1f}% " - f"mass_drop={mass_drop * 100:.1f}% over_slots={over_slots} max_count={max_count}", - once_in_node=False, - ) - elif args.routing_mode in ("topk-repair", "topk-repair-one", "topk-repair-softmax"): - assert args.masked_ratio == 0.0, f"{args.routing_mode} routing does not support masked_ratio" - assert (num_tokens * num_topk) % num_experts == 0, f"{args.routing_mode} requires exact local expert capacity" - scores = make_scores() - topk_weights, topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False) - expert_capacity = (num_tokens * num_topk) // num_experts - scores_cpu = scores.cpu() - selected = topk_idx.cpu().tolist() - 
selected_sets = [set(row) for row in selected] - counts = [0] * num_experts - for row in selected: - for expert in row: - counts[expert] += 1 - - while True: - overflow = {e for e, c in enumerate(counts) if c > expert_capacity} - underfull = [e for e, c in enumerate(counts) if c < expert_capacity] - if not overflow: - break - candidates = [] - underfull_set = set(underfull) - for token, row in enumerate(selected): - token_set = selected_sets[token] - for slot, old_expert in enumerate(row): - if old_expert not in overflow: - continue - old_score = float(scores_cpu[token, old_expert].item()) - for new_expert in underfull: - if new_expert in token_set: - continue - loss = old_score - float(scores_cpu[token, new_expert].item()) - candidates.append((loss, token, slot, old_expert, new_expert)) - assert candidates, "topk-repair could not find a repair candidate" - candidates.sort(key=lambda x: x[0]) - changed = False - for _, token, slot, old_expert, new_expert in candidates: - if counts[old_expert] <= expert_capacity or counts[new_expert] >= expert_capacity: - continue - if selected[token][slot] != old_expert or new_expert in selected_sets[token]: - continue - selected[token][slot] = new_expert - selected_sets[token].remove(old_expert) - selected_sets[token].add(new_expert) - counts[old_expert] -= 1 - counts[new_expert] += 1 - changed = True - assert changed, "topk-repair made no progress" - assert all(c == expert_capacity for c in counts), "topk-repair failed to reach exact capacity" - topk_idx = torch.tensor(selected, dtype=torch.long, device="cuda") - if args.routing_mode == "topk-repair-one": - topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float, device="cuda") - elif args.routing_mode == "topk-repair-softmax": - topk_weights = torch.softmax(torch.gather(scores, 1, topk_idx), dim=-1) - else: - topk_weights = torch.gather(scores, 1, topk_idx) - elif args.routing_mode in ("local-exact-score", "local-exact-score-one"): - assert args.masked_ratio == 0.0, 
f"{args.routing_mode} routing does not support masked_ratio" - assert (num_tokens * num_topk) % num_experts == 0, f"{args.routing_mode} requires exact local expert capacity" - scores = make_scores() - expert_capacity = (num_tokens * num_topk) // num_experts - assert expert_capacity % num_topk == 0, f"{args.routing_mode} requires per-slot expert capacity" - per_slot_capacity = expert_capacity // num_topk - score_order = torch.argsort(scores, dim=1, descending=True).cpu().tolist() - selected = [[-1] * num_topk for _ in range(num_tokens)] - selected_sets = [set() for _ in range(num_tokens)] - - import sys - sys.setrecursionlimit(max(10000, num_tokens * 4)) - for slot in range(num_topk): - assignment = [-1] * num_tokens - matched_tokens = [[] for _ in range(num_experts)] - token_order = sorted( - range(num_tokens), - key=lambda t: scores[t, score_order[t][0]].item(), - reverse=True, - ) - - def try_assign(token: int, seen_experts: set[int]) -> bool: - for expert in score_order[token]: - if expert in selected_sets[token] or expert in seen_experts: - continue - seen_experts.add(expert) - if len(matched_tokens[expert]) < per_slot_capacity: - matched_tokens[expert].append(token) - assignment[token] = expert - return True - for idx, other_token in enumerate(list(matched_tokens[expert])): - if try_assign(other_token, seen_experts): - matched_tokens[expert][idx] = token - assignment[token] = expert - return True - return False - - for token in token_order: - assert try_assign(token, set()), "local-exact-score matching failed" - assert all(expert >= 0 for expert in assignment), "local-exact-score left unassigned tokens" - assert all(len(tokens) == per_slot_capacity for tokens in matched_tokens), "local-exact-score left capacity imbalance" - for token, expert in enumerate(assignment): - selected[token][slot] = expert - selected_sets[token].add(expert) - - topk_idx = torch.tensor(selected, dtype=torch.long, device="cuda") - if args.routing_mode == "local-exact-score-one": - 
topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float, device="cuda") - else: - topk_weights = torch.gather(scores, 1, topk_idx) - elif args.routing_mode == "local-capacity": - assert args.masked_ratio == 0.0, "local-capacity routing does not support masked_ratio" - scores = make_scores() - expert_capacity = math.ceil(num_tokens * num_topk / num_experts) - candidate_k = num_experts - cand_vals, cand_idx = torch.topk(scores, candidate_k, dim=-1, largest=True, sorted=True) - order = torch.argsort(cand_vals[:, 0], descending=True).cpu().tolist() - cand_idx_cpu = cand_idx.cpu().tolist() - remaining = [expert_capacity] * num_experts - selected = [[-1] * num_topk for _ in range(num_tokens)] - pending: list[tuple[int, int]] = [] - for token in order: - used = 0 - for expert in cand_idx_cpu[token]: - if remaining[expert] > 0: - selected[token][used] = expert - remaining[expert] -= 1 - used += 1 - if used == num_topk: - break - if used < num_topk: - pending.append((token, used)) - fallback_cursor = 0 - for token, used in pending: - already = set(selected[token][:used]) - while used < num_topk: - found = False - for _ in range(num_experts): - expert = fallback_cursor % num_experts - fallback_cursor += 1 - if remaining[expert] > 0 and expert not in already: - selected[token][used] = expert - remaining[expert] -= 1 - already.add(expert) - used += 1 - found = True - break - if not found: - for expert in cand_idx_cpu[token]: - if expert not in already: - selected[token][used] = expert - already.add(expert) - used += 1 - found = True - break - assert found, "local-capacity routing could not fill all topk slots" - topk_idx = torch.tensor(selected, dtype=torch.long, device="cuda") - topk_weights = torch.gather(scores, 1, topk_idx) - elif args.routing_mode == "capacity": - assert args.masked_ratio == 0.0, "capacity routing does not support masked_ratio" - scores = make_scores() - all_scores = uneven_all_gather(scores, group=group) - local_num_tokens = 
torch.tensor([num_tokens], dtype=torch.long, device="cuda") - all_num_tokens_t = [torch.zeros_like(local_num_tokens) for _ in range(num_ranks)] - dist.all_gather(all_num_tokens_t, local_num_tokens, group=group) - all_num_tokens = [int(x.item()) for x in all_num_tokens_t] - local_offset = sum(all_num_tokens[:rank_idx]) - total_num_tokens = sum(all_num_tokens) - expert_capacity = math.ceil(total_num_tokens * num_topk / num_experts) - candidate_k = num_experts - cand_vals, cand_idx = torch.topk(all_scores, candidate_k, dim=-1, largest=True, sorted=True) - order = torch.argsort(cand_vals[:, 0], descending=True).cpu().tolist() - cand_idx_cpu = cand_idx.cpu().tolist() - remaining = [expert_capacity] * num_experts - selected = [[-1] * num_topk for _ in range(total_num_tokens)] - pending: list[tuple[int, int]] = [] - for token in order: - used = 0 - for expert in cand_idx_cpu[token]: - if remaining[expert] > 0: - selected[token][used] = expert - remaining[expert] -= 1 - used += 1 - if used == num_topk: - break - if used < num_topk: - pending.append((token, used)) - fallback_cursor = 0 - for token, used in pending: - already = set(selected[token][:used]) - while used < num_topk: - found = False - for _ in range(num_experts): - expert = fallback_cursor % num_experts - fallback_cursor += 1 - if remaining[expert] > 0 and expert not in already: - selected[token][used] = expert - remaining[expert] -= 1 - already.add(expert) - used += 1 - found = True - break - if not found: - for expert in cand_idx_cpu[token]: - if expert not in already: - selected[token][used] = expert - already.add(expert) - used += 1 - found = True - break - assert found, "capacity routing could not fill all topk slots" - if rank_idx == 0: - probs_cpu = torch.softmax(all_scores.cpu(), dim=1) - natural_mass = 0.0 - selected_mass = 0.0 - changed_slots = 0 - touched_tokens = 0 - for token, row in enumerate(selected): - natural = cand_idx_cpu[token][:num_topk] - natural_set = set(natural) - row_set = set(row) - 
overlap = len(natural_set & row_set) - changed_slots += num_topk - overlap - touched_tokens += overlap != num_topk - natural_mass += sum(float(probs_cpu[token, expert].item()) for expert in natural) - selected_mass += sum(float(probs_cpu[token, expert].item()) for expert in row) - mass_drop = (natural_mass - selected_mass) / max(natural_mass, 1e-12) - used_counts = [expert_capacity - r for r in remaining] - dist_print( - f" > capacity_quality: changed={changed_slots / max(total_num_tokens * num_topk, 1) * 100:.1f}% " - f"touched={touched_tokens / max(total_num_tokens, 1) * 100:.1f}% " - f"mass_drop={mass_drop * 100:.1f}% max_count={max(used_counts)}", - once_in_node=False, - ) - all_topk_idx = torch.tensor(selected, dtype=torch.long, device="cuda") - topk_idx = all_topk_idx[local_offset:local_offset + num_tokens].contiguous() - topk_weights = torch.gather(scores, 1, topk_idx) - else: - scores = make_scores() - topk_weights, topk_idx = torch.topk( - scores, num_topk, dim=-1, largest=True, sorted=False - ) - topk_idx = apply_eplb_replicas(topk_idx) - if args.masked_ratio > 0: - rand_mask = torch.rand_like(topk_idx, dtype=torch.float) - topk_idx.masked_fill_(rand_mask < args.masked_ratio, -1) - topk_weights.masked_fill_(topk_idx < 0, 0) - - # 累计接收统计:SM90 与 baseline 各持一份避免相互覆盖 - phase_profile_enabled = os.environ.get("DG_SM90_MOE_PHASE_PROFILE", "0") not in ("", "0") - phase_profile_extra = 64 if phase_profile_enabled else 0 - cum_stats_sm90 = torch.zeros( - (num_experts_per_rank + phase_profile_extra,), dtype=torch.int, device="cuda" - ) - cum_stats_baseline = cum_stats_sm90.clone() - - # ---- BF16 → FP8 量化 ---- - # x_fp8 是元组:(token_fp8 (M, hidden), token_sf (M, hidden//128) fp32 行主序) - # 注意 use_ue8m0=False, use_packed_ue8m0=False:SM90 不接受 UE8M0 packed SF - x_fp8 = per_token_cast_to_fp8( - x_bf16, use_ue8m0=False, gran_k=128, use_packed_ue8m0=False - ) - - # weight 量化:(G, N, K) bf16 → ((G, N, K) fp8 e4m3fn, (G, N//128, K//128) fp32 SF) - # baseline(DeepEP grouped 
GEMM)直接用这两个未变换的元组 - l1_weights = _quantize_grouped_fp8_block_128_128(l1_weights_bf16) - l2_weights = _quantize_grouped_fp8_block_128_128(l2_weights_bf16) - - # SM90 路径:FP8 weight 上做 gate/up gran-8 N-轴 interleave;SF 不变 - transformed_l1, transformed_l2 = deep_gemm.transform_weights_for_mega_moe_sm90( - l1_weights, l2_weights - ) - - # SwiGLU clamp:finite -> 传给 SM90/triton;inf -> None(关闭 clamp,与 SM90 一致) - clamp_arg = args.activation_clamp if math.isfinite(args.activation_clamp) else None - run_baseline_enabled = args.run_baseline or bool(args.check_output_diff) - - # ---- DeepGEMM grouped GEMM 的 M 维 alignment(baseline 走 DeepEP 时也用这个)---- - alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout() - deep_gemm.set_mk_alignment_for_contiguous_layout(alignment) - - # ---- 分配 SM90 的 SymmBuffer 与输出 buffer ---- - sym_buffer = deep_gemm.get_symm_buffer_for_mega_moe( - group, - num_experts, - num_max_tokens_per_rank, - num_topk, - hidden, - intermediate_hidden, - ) - y_sm90 = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") - use_eplb_hint = bool(eplb_replica_for) - use_skew_hint = args.score_powerlaw_alpha > 0.0 - use_masked_hint = args.masked_ratio > 0.0 - - def run_sm90(): - # NOTE: 跟 SM100 test_mega_moe.py 的处理一致 —— DG_COMM_KERNEL_DEBUG=1 时 - # kernel 出口会把 sym_buffer 整块清零,所以每次都要重新拷输入 - sym_buffer.x[:num_tokens].copy_(x_fp8[0]) - sym_buffer.x_sf[:num_tokens].copy_(x_fp8[1]) - sym_buffer.topk_idx[:num_tokens].copy_(topk_idx) - sym_buffer.topk_weights[:num_tokens].copy_(topk_weights) - - old_eplb_hint = os.environ.get("DG_SM90_MOE_EPLB_HINT") - old_skew_hint = os.environ.get("DG_SM90_MOE_SKEW_HINT") - old_masked_hint = os.environ.get("DG_SM90_MOE_MASKED_HINT") - if use_eplb_hint: - os.environ["DG_SM90_MOE_EPLB_HINT"] = "1" - if use_skew_hint: - os.environ["DG_SM90_MOE_SKEW_HINT"] = "1" - if use_masked_hint: - os.environ["DG_SM90_MOE_MASKED_HINT"] = "1" - try: - deep_gemm.fp8_mega_moe( - y_sm90, - transformed_l1, - transformed_l2, - 
sym_buffer, - cumulative_local_expert_recv_stats=cum_stats_sm90, - recipe=(128, 128, 128), - activation="swiglu", - activation_clamp=clamp_arg, - fast_math=bool(args.fast_math), - ) - finally: - if use_eplb_hint: - if old_eplb_hint is None: - os.environ.pop("DG_SM90_MOE_EPLB_HINT", None) - else: - os.environ["DG_SM90_MOE_EPLB_HINT"] = old_eplb_hint - if use_skew_hint: - if old_skew_hint is None: - os.environ.pop("DG_SM90_MOE_SKEW_HINT", None) - else: - os.environ["DG_SM90_MOE_SKEW_HINT"] = old_skew_hint - if use_masked_hint: - if old_masked_hint is None: - os.environ.pop("DG_SM90_MOE_MASKED_HINT", None) - else: - os.environ["DG_SM90_MOE_MASKED_HINT"] = old_masked_hint - return y_sm90 - - # ---- 打印 config ---- - dist_print("Config (H200 split SM90 mega-MoE):", once_in_node=True) - dist_print(f" > Tokens: {num_tokens}/{num_max_tokens_per_rank}", once_in_node=True) - dist_print( - f" > Hidden: {hidden}, Intermediate: {intermediate_hidden}", once_in_node=True - ) - dist_print( - f" > Experts: {num_topk}/{num_experts} (per-rank: {num_experts_per_rank})", - once_in_node=True, - ) - dist_print(f" > Masked ratio: {args.masked_ratio}", once_in_node=True) - dist_print( - f" > Activation SF: SM90 L2 per-{SM90_L2_ACT_SF_GRAN} UE8M0, " - f"baseline L2 per-{BASELINE_L2_ACT_SF_GRAN} UE8M0 " - f"(SM90 grouped GEMM constraint)", - once_in_node=True, - ) - dist_print( - f" > Baseline: {'enabled' if run_baseline_enabled else 'disabled'}", - once_in_node=True, - ) - dist_print( - f" > Buffer: {sym_buffer.buffer.nbytes / 2**30:.3f} GiB", once_in_node=True - ) - dist_print(once_in_node=True) - - # 与社区版 test_mega_moe.py 对齐:NCU 模式只跑 SM90 mega-MoE kernel,避免 baseline 噪声。 - if args.ncu_profile_only: - dist_print("Run split SM90 mega-MoE kernel:", once_in_node=True) - y = run_sm90() - torch.cuda.synchronize() - assert y.shape == (num_tokens, hidden) and y.dtype == torch.bfloat16 - dist_print(" > Done, exiting", once_in_node=True) - dist.barrier() - sym_buffer.destroy() - 
dist.destroy_process_group() - return - - # ---- 分配 DeepEP buffer(baseline 用)---- - deep_ep = _import_deep_ep() if run_baseline_enabled else None - ep_buffer = None - if deep_ep is not None: - ep_buffer = _make_deep_ep_buffer( - deep_ep, - group, - num_max_tokens_per_rank, - hidden, - num_topk, - sym_buffer.buffer.nbytes, - ) - - # ---------------------------------------------------------------- - # baseline 主体:dispatch → L1 GEMM → SwiGLU+量化 → L2 GEMM → combine - # 与 SM90 用同一份 (FP8 weight, FP32 block-(128,128) SF) —— 但是 **未变换** - # 的版本(baseline grouped GEMM 不需要 gate/up interleave) - # ---------------------------------------------------------------- - def run_baseline(): - recv_x, _, recv_topk_weights, handle, _ = ep_buffer.dispatch( - x_fp8, - topk_idx=topk_idx, - topk_weights=topk_weights, - cumulative_local_expert_recv_stats=cum_stats_baseline, - num_experts=num_experts, - expert_alignment=alignment, - do_cpu_sync=False, - do_handle_copy=False, - do_expand=True, - use_tma_aligned_col_major_sf=False, # SM90: row-major float SF - ) - n = recv_x[0].size(0) - - # L1 GEMM:FP8 token @ FP8 W1 → BF16 中间激活 (gate||up 拼接) - l1_y = torch.empty( - (n, intermediate_hidden * 2), dtype=torch.bfloat16, device="cuda" - ) - deep_gemm.m_grouped_fp8_gemm_nt_contiguous( - recv_x, - l1_weights, - l1_y, - handle.psum_num_recv_tokens_per_expert, - use_psum_layout=True, - disable_ue8m0_cast=True, - ) - - # Triton SwiGLU + FP8 量化(含 topk 权重乘法) - # 注意:SM90 split mega-MoE 的 L2 activation SFA 是 per-64-K; - # 当前 DeepGEMM SM90 grouped GEMM 只支持 per-128-K SFA,所以性能 baseline - # 只能用 per-128-K,但 scale 数值采用 SM90 同款 UE8M0/power-of-two。 - l1_y = swiglu_apply_weight_to_fp8_triton( - x=l1_y, - topk_weights=recv_topk_weights, - clamp_value=clamp_arg, - num_per_channels=BASELINE_L2_ACT_SF_GRAN, - use_ue8m0_scale=True, - ) - - # L2 GEMM:FP8 中间激活 @ FP8 W2 → BF16 - l2_y = torch.empty((n, hidden), dtype=torch.bfloat16, device="cuda") - deep_gemm.m_grouped_fp8_gemm_nt_contiguous( - l1_y, - l2_weights, - l2_y, - 
handle.psum_num_recv_tokens_per_expert, - use_psum_layout=True, - disable_ue8m0_cast=True, - ) - - # DeepEP combine:把每个 token 在 topk 个 expert 上的输出汇聚回源 rank - return ep_buffer.combine(l2_y, handle=handle)[0] - - # ---- 跑一次确保不报错(SM90 split + 可选 baseline)---- - y = run_sm90() - assert y.shape == (num_tokens, hidden) and y.dtype == torch.bfloat16, ( - f"SM90 split 输出 shape/dtype 异常: shape={y.shape}, dtype={y.dtype}" - ) - if ep_buffer is not None: - out_b = run_baseline() - assert out_b.shape == (num_tokens, hidden) and out_b.dtype == torch.bfloat16, ( - f"baseline 输出 shape/dtype 异常: shape={out_b.shape}, dtype={out_b.dtype}" - ) - if args.check_output_diff: - diff = (y.float() - out_b.float()).abs() - denom = out_b.float().abs().mean().clamp_min(1e-12) - mean_abs_rel = diff.mean().div(denom) - dist_print( - "Output diff (SM90 split vs legacy-per128 baseline):", once_in_node=True - ) - dist_print( - f" > max_abs={diff.max().item():.6e}, " - f"mean_abs={diff.mean().item():.6e}, " - f"mean_abs/mean_ref={mean_abs_rel.item():.6e}, " - f"threshold={SM90_VS_LEGACY_MEAN_ABS_REL_TOL:.6e}", - once_in_node=True, - ) - dist_print(once_in_node=True) - assert mean_abs_rel <= SM90_VS_LEGACY_MEAN_ABS_REL_TOL, ( - "SM90 split output differs from legacy-per128 baseline: " - f"mean_abs/mean_ref={mean_abs_rel.item():.6e}, " - f"threshold={SM90_VS_LEGACY_MEAN_ABS_REL_TOL:.6e}" - ) - - # ---- 统计本 rank 实际接收的 token 数与触达的 expert 数 ---- - # 把所有 rank 的 topk_idx 收齐,再把不落在本 rank 持有 expert 范围内的条目 - # 标成 -1;剩下的非 -1 条目数即"被路由进本 rank 的 (token, slot) 总数"。 - gathered_topk_idx = uneven_all_gather(topk_idx, group=group) - all_routed_topk_idx = gathered_topk_idx - local_num_tokens_t = torch.tensor([num_tokens], dtype=torch.long, device="cuda") - all_num_tokens_t = [torch.zeros_like(local_num_tokens_t) for _ in range(num_ranks)] - dist.all_gather(all_num_tokens_t, local_num_tokens_t, group=group) - all_num_tokens = [int(x.item()) for x in all_num_tokens_t] - peer_recv_counts = [] - row_start = 0 - for 
src_tokens in all_num_tokens: - src_topk = all_routed_topk_idx[row_start:row_start + src_tokens] - peer_recv_counts.append(int(((src_topk >= rank_idx * num_experts_per_rank) & - (src_topk < (rank_idx + 1) * num_experts_per_rank)).sum().item())) - row_start += src_tokens - max_peer_recv = max(peer_recv_counts) if peer_recv_counts else 0 - gathered_topk_idx = all_routed_topk_idx.clone() - gathered_topk_idx[ - (gathered_topk_idx < rank_idx * num_experts_per_rank) - | (gathered_topk_idx >= (rank_idx + 1) * num_experts_per_rank) - ] = -1 - local_expert_ids = gathered_topk_idx[gathered_topk_idx != -1] - num_recv_tokens = int(local_expert_ids.numel()) - num_touched_experts = int(torch.unique(local_expert_ids).numel()) - if num_recv_tokens > 0: - local_counts = torch.bincount( - local_expert_ids - rank_idx * num_experts_per_rank, - minlength=num_experts_per_rank, - ) - num_m_tiles = int(((local_counts + 63) // 64).sum().item()) - max_expert_tokens = int(local_counts.max().item()) - else: - num_m_tiles = 0 - max_expert_tokens = 0 - - # ---- benchmark ---- - # split SM90:bench_kineto 抓 sm90_fp8_mega_moe* 的 GPU 段(不含 host overhead) - if phase_profile_enabled: - cum_stats_sm90.zero_() - t_sm90 = bench_kineto( - run_sm90, - SM90_KERNEL_NAME_PREFIX, - num_tests=args.num_bench_tests, - barrier=lambda: ep_buffer.barrier(use_comm_stream=False) - if ep_buffer is not None - else dist.barrier(), - trace_path=( - f"{args.dump_profile_traces}/mega_moe_hopper_rank{rank_idx}.json" - if args.dump_profile_traces - else None - ), - with_multiple_kernels=True, - ) - if phase_profile_enabled: - cum_stats_sm90.zero_() - torch.cuda.synchronize() - if ep_buffer is not None: - ep_buffer.barrier(use_comm_stream=False) - else: - dist.barrier() - phase_start = torch.cuda.Event(enable_timing=True) - phase_end = torch.cuda.Event(enable_timing=True) - phase_start.record() - run_sm90() - phase_end.record() - torch.cuda.synchronize() - phase_event_us = phase_start.elapsed_time(phase_end) * 1000.0 - raw_i32 
= cum_stats_sm90[num_experts_per_rank:num_experts_per_rank + phase_profile_extra].detach().cpu().tolist() - def _u64(slot: int) -> int: - lo = raw_i32[slot * 2] & 0xffffffff - hi = raw_i32[slot * 2 + 1] & 0xffffffff - return lo | (hi << 32) - names = ( - "dispatch_total", - "dispatch_pull", - "math_loop", - "combine_barrier", - "combine_reduce", - "gemm_core", - "l1_epilogue", - "l2_epilogue", - ) - num_profile_metrics = len(names) - pieces = [] - for idx, name in enumerate(names): - total = _u64(idx) - max_cycles = _u64(num_profile_metrics + idx) - count = _u64(2 * num_profile_metrics + idx) - avg_us = (total / count / 1000.0) if count else 0.0 - max_us = max_cycles / 1000.0 - pieces.append(f"{name}:avg={avg_us:.1f}us,max={max_us:.1f}us,n={count},ns={total}/{max_cycles}") - dist_print(f" > phase_profile rank={rank_idx}: event={phase_event_us:.1f}us; " + "; ".join(pieces)) - # baseline:cuda events 中位数(tilelang.do_bench 在 H200 不一定有,统一用 events) - t_baseline = ( - _bench_cuda_events( - run_baseline, - num_warmup=args.num_warmup, - num_repeat=args.num_repeat, - l2_flush_gb=args.l2_flush_gb, - ) - if ep_buffer is not None - else 0.0 - ) - - def safe_div(a, b): - return float("nan") if b == 0 else a / b - - # 端到端 TFLOPS:3 个 matmul(L1 gate、L1 up、L2),每个 2*M*N*K,M=num_recv_tokens - tflops = safe_div( - 2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_sm90 - ) - - # HBM 字节估算(SM90: weight 是 FP8 = 1B/elem,与 SM100 FP4=0.5B 不同) - l1_weight_bytes = num_touched_experts * intermediate_hidden * 2 * hidden - l2_weight_bytes = num_touched_experts * hidden * intermediate_hidden - l1_weight_sf_bytes = ( - num_touched_experts - * (intermediate_hidden * 2 // WEIGHT_SF_GRAN_MN) - * (hidden // WEIGHT_SF_GRAN_K) - * 4 - ) - l2_weight_sf_bytes = ( - num_touched_experts - * (hidden // WEIGHT_SF_GRAN_MN) - * (intermediate_hidden // WEIGHT_SF_GRAN_K) - * 4 - ) - l1_input_sf_bytes = num_recv_tokens * (hidden // L1_ACT_SF_GRAN) * 4 - l2_act_sf_bytes = ( - num_recv_tokens * 
(intermediate_hidden // SM90_L2_ACT_SF_GRAN) * 4 - ) - num_hbm_bytes = ( - l1_weight_bytes - + l2_weight_bytes # weights (FP8) - + l1_weight_sf_bytes - + l2_weight_sf_bytes # weight SF (FP32) - + num_recv_tokens * hidden - + l1_input_sf_bytes # L1 输入读 (FP8 + SF) - + num_recv_tokens * intermediate_hidden - + l2_act_sf_bytes # L1 输出写 (FP8 + SF) - + num_recv_tokens * intermediate_hidden - + l2_act_sf_bytes # L2 输入读 (FP8 + SF) - + num_recv_tokens * hidden * 2 # L2 输出写 (BF16) - ) - hbm_gbs = safe_div(num_hbm_bytes / 1e9, t_sm90) - - # NVLink 字节:dispatch 拉 token + input SF + topk weight,combine 写回 BF16 - num_nvlink_bytes = num_recv_tokens * (hidden + hidden // 32 + 4 + hidden * 2) - nvlink_gbs = safe_div(num_nvlink_bytes / 1e9, t_sm90) - - # combine reduction 串行下界(解析估计;6.5e12 = HBM 串行 reduction 经验吞吐 B/s) - t_reduction = num_tokens * hidden * 2 * (1 + num_topk) / 6.5e12 - - # overlap 校正:扣掉 SM90 中无法重叠的串行 reduction 段后估计稳态吞吐 - approx_factor = t_sm90 / max(t_sm90 - t_reduction, 1e-12) - - # baseline 用同一份 FLOPs / HBM 字节,时间换成 t_baseline - tflops_baseline = safe_div( - 2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_baseline - ) - hbm_gbs_baseline = safe_div(num_hbm_bytes / 1e9, t_baseline) - nvlink_gbs_baseline = safe_div(num_nvlink_bytes / 1e9, t_baseline) - - def fmt_perf_line( - name: str, - t: float, - compute_tflops: float, - hbm_gbs_: float, - nvlink_gbs_: float, - reduction_us: float | None = None, - speedup: float | None = None, - ) -> str: - reduction = f"{reduction_us:13.1f}" if reduction_us is not None else f"{'-':>13}" - speedup_text = ( - f"{speedup:6.2f}x {'SM90 faster' if speedup > 1 else 'baseline faster'}" - if speedup is not None else - f"{'-':>21}" - ) - return ( - f" > {name:<10} {rank_idx:2d}/{num_ranks:<2d} " - f"{num_recv_tokens:12d} " - f"{num_touched_experts:14d} {num_m_tiles:7d} {max_expert_tokens:8d} {max_peer_recv:8d} | " - f"{compute_tflops:15.0f} " - f"{hbm_gbs_:9.0f} " - f"{nvlink_gbs_:9.0f} " - f"{t * 1e6:9.0f} " - 
f"{reduction} " - f"{speedup_text}" - ) - - dist_print("Performance:", once_in_node=True) - dist_print( - " > kind EP recv_tokens active_experts m_tiles max_exp max_peer | " - "compute(TFLOPS) HBM(GB/s) NVL(GB/s) time(us) reduction(us) speedup", - once_in_node=True, - ) - dist_print( - fmt_perf_line( - "[sm90]", - t_sm90, - tflops * approx_factor, - hbm_gbs * approx_factor, - nvlink_gbs * approx_factor, - reduction_us=t_reduction * 1e6, - ) - ) - if ep_buffer is not None: - speedup = safe_div(t_baseline, t_sm90) - dist_print( - fmt_perf_line( - "[baseline]", - t_baseline, - tflops_baseline, - hbm_gbs_baseline, - nvlink_gbs_baseline, - speedup=speedup, - ) - ) - else: - reason = ( - "disabled; pass --run-baseline or --check-output-diff to compare" - if not run_baseline_enabled - else "deep_ep unavailable" - ) - dist_print(f" > [baseline] ({reason})", once_in_node=True) - - # ---- 清理 ---- - dist.barrier() - sym_buffer.destroy() - if ep_buffer is not None: - ep_buffer.destroy() - dist.destroy_process_group() - - -# ============================================================================ -# 模块 6:argparse + spawn -# ============================================================================ - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="H200 mega-MoE: split SM90 deep_gemm.fp8_mega_moe vs DeepEP+grouped-FP8 baseline" - ) - - # 资源 - parser.add_argument( - "--ncu-profile-only", - action="store_true", - help="只运行一次 split SM90 kernel,便于 NCU/Nsight 采样", - ) - parser.add_argument( - "--num-processes", type=int, default=8, help="spawn 出来的进程数(一卡一进程)" - ) - parser.add_argument( - "--local-rank-idx", - type=int, - default=None, - help="单进程模式的 local rank;用于外部 launcher/NCU 分别启动每个 rank", - ) - - # 模型形状 - # 注:SM90 mega-MoE kernel 要求 intermediate_hidden <= 4096 - parser.add_argument("--num-max-tokens-per-rank", type=int, default=8192) - parser.add_argument( - "--num-tokens", - type=int, - default=0, - help="per-rank 实际 token 数;0 表示用 
num-max-tokens-per-rank", - ) - parser.add_argument( - "--num-max-removed-tokens", - type=int, - default=0, - help="num-tokens 为 0 时,每个 rank 随机移除的最大 token 数", - ) - parser.add_argument("--hidden", type=int, default=7168) - parser.add_argument( - "--intermediate-hidden", - type=int, - default=3072, - help="中间层维度(≤ 4096,受 SM90 l2_arrival_mask 约束)", - ) - parser.add_argument( - "--activation-clamp", - type=float, - default=10.0, - help="SwiGLU 前对 gate/up 的 clamp 阈值;传 inf 表示关闭", - ) - parser.add_argument("--num-experts", type=int, default=384) - parser.add_argument("--num-topk", type=int, default=6) - parser.add_argument( - "--routing-mode", - type=str, - default="random", - choices=("random", "balanced", "balanced-shuffled", "balanced-shuffled-score", "topk-repair", "topk-repair-one", "topk-repair-softmax", "topk-repair-budget-softmax", "global-repair-budget-softmax", "local-exact-score", "local-exact-score-one", "local-capacity", "capacity"), - help="routing 构造方式;balanced/balanced-shuffled/topk-repair/local-exact-score/local-capacity/capacity 控制每 expert assignment 数", - ) - parser.add_argument( - "--masked-ratio", - type=float, - default=0.0, - help="随机 mask 掉部分 topk expert selection,用于验证稀疏路由边界", - ) - parser.add_argument( - "--score-powerlaw-alpha", - type=float, - default=0.0, - help="给 routing score 加 Zipf/power-law expert bias;0 表示关闭", - ) - parser.add_argument( - "--score-powerlaw-scale", - type=float, - default=1.0, - help="power-law bias 的标准差尺度", - ) - parser.add_argument( - "--repair-mass-drop-budget", - type=float, - default=0.0, - help="topk-repair-budget-softmax 的 aggregate selected softmax-mass drop 上限,例如 0.10", - ) - parser.add_argument( - "--repair-max-rounds", - type=int, - default=4, - help="global-repair-budget-softmax 的 greedy repair 最大轮数,避免诊断代码超时", - ) - parser.add_argument( - "--num-redundant-experts", - type=int, - default=0, - help="EPLB replica simulation: reserve physical expert slots as hot-expert replicas", - ) - parser.add_argument( - 
"--replica-dispatch", - choices=("hash", "static"), - default="hash", - help="replica remap model: token-level hash or SGLang static source-rank approximation", - ) - parser.add_argument( - "--fast-math", - type=int, - default=1, - help="SM90 kernel 内 SwiGLU 是否启用 fast-math(0/1)", - ) - - # 测时 - parser.add_argument( - "--num-bench-tests", - type=int, - default=30, - help="bench_kineto 抓 SM90 split kernel 时的迭代数", - ) - parser.add_argument( - "--num-warmup", type=int, default=5, help="baseline cuda events warmup" - ) - parser.add_argument( - "--num-repeat", type=int, default=20, help="baseline cuda events 测时迭代" - ) - parser.add_argument( - "--l2-flush-gb", - type=float, - default=8.0, - help="baseline event 测时前用于 flush L2 的临时写入大小;0 表示关闭", - ) - parser.add_argument( - "--run-baseline", - action="store_true", - help="启用 DeepEP+grouped-FP8 legacy baseline;默认关闭以避免 full-size 默认配置触发 baseline kernel 非法访问", - ) - parser.add_argument( - "--check-output-diff", - type=int, - default=0, - help="非 0 时打印 SM90 split 与 legacy-per128 baseline 的输出差异(预期非 bitwise)", - ) - parser.add_argument( - "--dump-profile-traces", - type=str, - default="", - help="非空时把 SM90 split kernel 的 Chrome trace 写到该目录(每 rank 一份)", - ) - - args = parser.parse_args() - - if args.dump_profile_traces: - os.makedirs(args.dump_profile_traces, exist_ok=True) - - if args.local_rank_idx is not None: - # 单进程模式:由外部 launcher 分别设置 MASTER_ADDR/PORT/WORLD_SIZE/RANK。 - test(args.local_rank_idx, args.num_processes, args) - else: - # 多进程启动:每个进程对应一个 GPU;test() 内部用 init_dist 建 NCCL group。 - torch.multiprocessing.spawn( - test, args=(args.num_processes, args), nprocs=args.num_processes - ) From e83d0fa35b60c2a21462e7e4c3b2f7ec7d1c903d Mon Sep 17 00:00:00 2001 From: Jinyan Chen Date: Wed, 3 Jun 2026 12:30:32 +0800 Subject: [PATCH 13/15] Optimize MegaMoE split path --- csrc/jit_kernels/heuristics/mega_moe.hpp | 117 +++++++---------------- deep_gemm/mega/__init__.py | 2 - tests/test_mega_moe_sm90.py | 37 +++---- 3 files changed, 53 
insertions(+), 103 deletions(-) diff --git a/csrc/jit_kernels/heuristics/mega_moe.hpp b/csrc/jit_kernels/heuristics/mega_moe.hpp index 443dd4f915..95f1700e9d 100644 --- a/csrc/jit_kernels/heuristics/mega_moe.hpp +++ b/csrc/jit_kernels/heuristics/mega_moe.hpp @@ -302,10 +302,10 @@ struct MegaMoESM90Config { } }; -enum class Sm90MoeDeviceProfile { +enum class Sm90MoeRuntimeProfile { Generic, - H20, - H200 + LowSm, + HighSm }; static std::string get_sm90_moe_lowercase(std::string value) { @@ -315,24 +315,24 @@ static std::string get_sm90_moe_lowercase(std::string value) { return value; } -static Sm90MoeDeviceProfile get_sm90_moe_device_profile() { +static Sm90MoeRuntimeProfile get_sm90_moe_runtime_profile() { const auto forced = get_sm90_moe_lowercase( get_env("DG_SM90_MOE_DEVICE_PROFILE", "")); if (not forced.empty() and forced != "auto") { - DG_HOST_ASSERT(forced == "generic" or forced == "h20" or forced == "h200"); - if (forced == "h20") - return Sm90MoeDeviceProfile::H20; - if (forced == "h200") - return Sm90MoeDeviceProfile::H200; - return Sm90MoeDeviceProfile::Generic; + DG_HOST_ASSERT(forced == "generic" or forced == "low_sm" or forced == "high_sm"); + if (forced == "low_sm") + return Sm90MoeRuntimeProfile::LowSm; + if (forced == "high_sm") + return Sm90MoeRuntimeProfile::HighSm; + return Sm90MoeRuntimeProfile::Generic; } - const auto device_name = get_sm90_moe_lowercase(device_runtime->get_prop()->name); - if (device_name.find("h200") != std::string::npos) - return Sm90MoeDeviceProfile::H200; - if (device_name.find("h20") != std::string::npos) - return Sm90MoeDeviceProfile::H20; - return Sm90MoeDeviceProfile::Generic; + const int num_sms = device_runtime->get_num_sms(); + if (num_sms <= 80) + return Sm90MoeRuntimeProfile::LowSm; + if (num_sms >= 100) + return Sm90MoeRuntimeProfile::HighSm; + return Sm90MoeRuntimeProfile::Generic; } struct Sm90MoeProfileConfig { @@ -342,7 +342,7 @@ struct Sm90MoeProfileConfig { }; struct Sm90MoeHeuristicPolicy { - 
Sm90MoeDeviceProfile device_profile; + Sm90MoeRuntimeProfile runtime_profile; int num_experts_per_rank, num_topk, intermediate_hidden; int block_m, block_n; float expected_tokens_per_expert; @@ -368,11 +368,11 @@ struct Sm90MoeHeuristicPolicy { return num_experts_per_rank == 48 and num_topk == 6 and intermediate_hidden == 3072; } - bool h20_main_topk8_profile_config(Sm90MoeProfileConfig& config, - const bool& direct_l2_scatter_enabled, - const bool& eplb_hint, - const bool& skew_hint, - const bool& masked_hint) const { + bool low_sm_main_topk8_profile_config(Sm90MoeProfileConfig& config, + const bool& direct_l2_scatter_enabled, + const bool& eplb_hint, + const bool& skew_hint, + const bool& masked_hint) const { int wave_override = 0; if (expected_tokens_per_expert == 128.0f or (expected_tokens_per_expert >= 256.0f and expected_tokens_per_expert < 512.0f)) { @@ -414,11 +414,8 @@ struct Sm90MoeHeuristicPolicy { return true; } - bool h200_main_topk8_profile_config(Sm90MoeProfileConfig& config) const { - // Calibrated from the H200 0601 sweep. Buckets are keyed by - // expected_tokens_per_expert = M * topk / experts_per_rank. - // The sweep labels include some requested `d0_s5` rows, but BN256 with - // direct L2 scatter disabled is capped at 4 stages by SM90 SMEM limits. + bool high_sm_main_topk8_profile_config(Sm90MoeProfileConfig& config) const { + // Profile buckets keyed by expected_tokens_per_expert. 
if (expected_tokens_per_expert <= 3.0f) { config = {32, 4, true, true, false}; } else if (expected_tokens_per_expert <= 6.0f) { @@ -457,12 +454,12 @@ struct Sm90MoeHeuristicPolicy { if (not uses_bn256_main_tile() or not is_main_topk8()) return false; - if (device_profile == Sm90MoeDeviceProfile::H20) { - return h20_main_topk8_profile_config( + if (runtime_profile == Sm90MoeRuntimeProfile::LowSm) { + return low_sm_main_topk8_profile_config( config, direct_l2_scatter_enabled, eplb_hint, skew_hint, masked_hint); } - if (device_profile == Sm90MoeDeviceProfile::H200) - return h200_main_topk8_profile_config(config); + if (runtime_profile == Sm90MoeRuntimeProfile::HighSm) + return high_sm_main_topk8_profile_config(config); return false; } @@ -556,7 +553,7 @@ static Sm90MoeHeuristicPolicy get_sm90_moe_heuristic_policy( const int& num_experts_per_rank, const int& num_tokens, const int& num_topk, const int& intermediate_hidden, const int& block_m, const int& block_n) { return { - get_sm90_moe_device_profile(), + get_sm90_moe_runtime_profile(), num_experts_per_rank, num_topk, intermediate_hidden, @@ -675,8 +672,7 @@ static void append_unique_moe_candidate(std::vector& values, const T& value) static std::vector get_sm90_moe_bool_candidates( const std::string& env_name, - const bool& default_value, - const bool& allow_alternative) { + const bool& default_value) { const int forced = get_env(env_name, -1); DG_HOST_ASSERT(forced == -1 or forced == 0 or forced == 1); std::vector values; @@ -685,8 +681,6 @@ static std::vector get_sm90_moe_bool_candidates( return values; } append_unique_moe_candidate(values, default_value ? 1 : 0); - if (allow_alternative) - append_unique_moe_candidate(values, default_value ? 0 : 1); return values; } @@ -729,10 +723,7 @@ static Sm90MoeConfigInfo get_sm90_moe_config_info( const int num_last_blocks = num_blocks % num_sms; const int last_wave_util = num_last_blocks == 0 ? 
num_sms : num_last_blocks; - // This cost model intentionally mirrors DeepGEMM's style: enumerate legal - // candidates, rank them with cheap shape-derived estimates, and use a small - // empirical calibration layer for modes whose cost is dominated by dispatch / - // combine details not captured by the block-count model. + // Rank legal selector candidates with cheap shape-derived estimates. int empirical_penalty = 0; if (config.direct_l2_scatter != empirical_direct_l2_scatter) empirical_penalty += 1000000; @@ -765,8 +756,6 @@ static std::vector get_mega_moe_config_candidates_sm90( const int& num_max_tokens_per_rank, const int& num_tokens, const int& num_topk, const int& hidden, const int& intermediate_hidden, const int& num_padded_sf_pool_tokens) { - const bool extra_modes = get_env("DG_SM90_MOE_SEARCH_EXTRA_MODES", 1) != 0; - const bool extra_block_shapes = get_env("DG_SM90_MOE_SEARCH_BLOCK_SHAPES", 0) != 0; const float expected_tokens_per_expert = static_cast(num_tokens) * num_topk / num_experts_per_rank; @@ -799,8 +788,6 @@ static std::vector get_mega_moe_config_candidates_sm90( append_unique_moe_candidate(block_m_candidates, requested_mma_m); } else { append_unique_moe_candidate(block_m_candidates, 64); - if (extra_block_shapes) - append_unique_moe_candidate(block_m_candidates, 128); } const int num_max_pool_tokens = layout::get_num_max_pool_tokens( @@ -821,12 +808,8 @@ static std::vector get_mega_moe_config_candidates_sm90( if (block_m == 64 and not use_mma_sync_decode and use_bn256_split_n_env) { append_unique_moe_candidate(block_n_candidates, 256); - if (extra_block_shapes) - append_unique_moe_candidate(block_n_candidates, 128); } else { append_unique_moe_candidate(block_n_candidates, 128); - if (extra_block_shapes and block_m == 64 and not use_mma_sync_decode) - append_unique_moe_candidate(block_n_candidates, 256); } for (const int& block_n: block_n_candidates) { @@ -870,8 +853,6 @@ static std::vector get_mega_moe_config_candidates_sm90( 
append_unique_moe_candidate(dispatch_warp_candidates, forced_dispatch_warps); } else { append_unique_moe_candidate(dispatch_warp_candidates, compact_frontend ? 2 : 4); - if (extra_block_shapes and not compact_frontend) - append_unique_moe_candidate(dispatch_warp_candidates, 2); } for (const int& num_dispatch_warps: dispatch_warp_candidates) { @@ -900,27 +881,19 @@ static std::vector get_mega_moe_config_candidates_sm90( auto direct_candidates = get_sm90_moe_bool_candidates( "DG_SM90_MOE_DIRECT_L2_SCATTER", - direct_l2_scatter_default and direct_l2_scatter_legal, - extra_modes and direct_l2_scatter_legal); + direct_l2_scatter_default and direct_l2_scatter_legal); auto l2_nmajor_candidates = get_sm90_moe_bool_candidates( "DG_SM90_MOE_L2_NMAJOR", - l2_nmajor_schedule_default, - extra_modes and policy.uses_bn256_main_tile() and policy.is_main_topk8()); + l2_nmajor_schedule_default); auto cleanup_candidates = get_sm90_moe_bool_candidates( "DG_SM90_MOE_ONE_WARP_CLEANUP", - one_warp_cleanup_default, - extra_modes and policy.uses_bn256_main_tile()); + one_warp_cleanup_default); const int default_epw = get_num_experts_per_wave_for_mega_moe_sm90( num_experts_per_rank, num_tokens, num_topk, intermediate_hidden, block_m, block_n, num_sms); std::vector experts_per_wave_candidates; append_unique_moe_candidate(experts_per_wave_candidates, default_epw); - if (get_env("DG_SM90_MOE_EXPERTS_PER_WAVE") <= 0 and extra_modes) { - if (16 <= num_experts_per_rank and num_experts_per_rank % 16 == 0) - append_unique_moe_candidate(experts_per_wave_candidates, 16); - append_unique_moe_candidate(experts_per_wave_candidates, num_experts_per_rank); - } for (const int& direct_value: direct_candidates) { const bool direct_l2_scatter = direct_value != 0; @@ -937,10 +910,6 @@ static std::vector get_mega_moe_config_candidates_sm90( append_unique_moe_candidate(stage_candidates, forced_num_stages); } else { append_unique_moe_candidate(stage_candidates, empirical_stage); - if (extra_modes) { - 
append_unique_moe_candidate(stage_candidates, 4); - append_unique_moe_candidate(stage_candidates, 5); - } } for (const int& requested_num_stages: stage_candidates) { @@ -996,8 +965,6 @@ static Sm90MoeConfigInfo get_best_mega_moe_config_info_sm90( Sm90MoeConfigInfo best { std::numeric_limits::max(), 0, 0, 0, 0, candidates[0] }; - std::vector infos; - infos.reserve(candidates.size()); for (const auto& candidate: candidates) { const auto policy = get_sm90_moe_heuristic_policy( num_experts_per_rank, num_tokens, num_topk, @@ -1025,25 +992,9 @@ static Sm90MoeConfigInfo get_best_mega_moe_config_info_sm90( empirical_one_warp_cleanup, empirical_num_stages, empirical_num_experts_per_wave); - infos.emplace_back(info); if (info.score < best.score) best = info; } - - if (get_env("DG_SM90_MOE_PRINT_SEARCH", 0) > 0) { - const auto key = fmt::format( - "Sm90MoeConfigSearch(num_ranks={}, num_experts={}, hidden={}, intermediate_hidden={}, num_max_tokens_per_rank={}, num_tokens={}, num_topk={})", - num_ranks, num_experts, hidden, intermediate_hidden, num_max_tokens_per_rank, num_tokens, num_topk); - static std::unordered_set printed; - if (printed.count(key) == 0) { - std::cout << key << ": selected " << best << std::endl; - if (get_env("DG_SM90_MOE_PRINT_SEARCH", 0) > 1) { - for (const auto& info: infos) - std::cout << " candidate " << info << std::endl; - } - printed.insert(key); - } - } return best; } diff --git a/deep_gemm/mega/__init__.py b/deep_gemm/mega/__init__.py index e2c5e026b2..3775220ed1 100644 --- a/deep_gemm/mega/__init__.py +++ b/deep_gemm/mega/__init__.py @@ -153,7 +153,6 @@ def fp8_fp4_mega_moe(y: torch.Tensor, fast_math ) - def fp8_mega_moe(y: torch.Tensor, l1_weights: Tuple[torch.Tensor, torch.Tensor], l2_weights: Tuple[torch.Tensor, torch.Tensor], @@ -182,4 +181,3 @@ def fp8_mega_moe(y: torch.Tensor, activation, activation_clamp, fast_math ) - diff --git a/tests/test_mega_moe_sm90.py b/tests/test_mega_moe_sm90.py index 33e5a1c3ef..e5356d1f3c 100644 --- 
a/tests/test_mega_moe_sm90.py +++ b/tests/test_mega_moe_sm90.py @@ -8,11 +8,9 @@ ------ L1 Smoke : single tiny config; only verifies the kernel runs and produces an output close to a PyTorch reference. - L2 Heuristic : sweeps tokens-per-expert across the bands of - the SM90 config selector so the main branch buckets - are covered. - L3 Shape sweep : sweeps ``hidden``, ``intermediate_hidden`` and - ``num_topk`` over divisible-by-128 values. + L2 Heuristic : covers tokens-per-expert bands of the SM90 selector. + L3 Shape coverage : covers divisible-by-128 ``hidden``, + ``intermediate_hidden`` and ``num_topk`` values. L4 Edge cases : masking ratio, activation clamp (finite vs inf), ``fast_math`` 0/1, ``num_tokens`` boundaries. L5 Stress : ``--num-correctness-tests`` repeated random configs. @@ -91,6 +89,10 @@ def _dequant_per_token_per_128_k(x_fp8: torch.Tensor, sf: torch.Tensor) -> torch return (w_view * sf.unsqueeze(-1)).view(m, k) +def _stable_name_seed(name: str) -> int: + return sum((i + 1) * ord(ch) for i, ch in enumerate(name)) % 1000 + + # ---------------------------------------------------------------------------- # PyTorch reference # ---------------------------------------------------------------------------- @@ -258,8 +260,9 @@ def _trace(stage: str): print(f'[rank{rank_idx}] {name} :: {stage}', flush=True) _trace('begin') - torch.manual_seed(rank_idx * 1000 + abs(hash(name)) % 1000) - random.seed(rank_idx * 1000 + abs(hash(name)) % 1000) + seed = rank_idx * 1000 + _stable_name_seed(name) + torch.manual_seed(seed) + random.seed(seed) # ---- Inputs (bf16) ------------------------------------------------------- x_bf = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') @@ -276,7 +279,6 @@ def _trace(stage: str): topk_idx.masked_fill_(rand_mask < masked_ratio, -1) topk_w.masked_fill_(topk_idx < 0, 0) - # Quantize x to FP8 with per-128 K float SF (SM90 format) # Quantize x to FP8 with per-128 K float SF (SM90 format) x_fp8, x_sf = 
per_token_cast_to_fp8(x_bf, use_ue8m0=False, gran_k=128, use_packed_ue8m0=False) @@ -368,25 +370,24 @@ def _layer1_smoke() -> List[Tuple[str, Dict[str, Any]]]: def _layer2_heuristic_branches(num_ranks: int) -> List[Tuple[str, Dict[str, Any]]]: - """Vary tokens / (num_experts * num_topk / num_ranks) so the selector's - main tokens-per-expert buckets fire at least once. - - The heuristic decides on ``avg_tokens_per_expert``; we approximate by - setting ``num_max_tokens_per_rank`` and ``num_topk`` while keeping - ``num_experts`` fixed. The bands are at 64.5 / 96.5 / 192.5. - """ + """Cover generic heuristic bands and the main topk8 profile selector.""" base = dict(hidden=1024, intermediate_hidden=1024, num_experts=8 * num_ranks, num_topk=2) out: List[Tuple[str, Dict[str, Any]]] = [] - # tokens-per-rank settings chosen to hit (small / mid / large) bands for tokens, label in [(64, 'small'), (256, 'midA'), (512, 'midB'), (2048, 'large')]: cfg = dict(base) cfg.update(num_max_tokens_per_rank=tokens, num_tokens=tokens) out.append((f'L2.heur.{label}.t{tokens}', cfg)) + profile_base = dict(hidden=512, intermediate_hidden=2048, + num_experts=32 * num_ranks, num_topk=8) + for tokens in (16, 64, 260, 1024): + cfg = dict(profile_base) + cfg.update(num_max_tokens_per_rank=tokens, num_tokens=tokens) + out.append((f'L2.profile_topk8.t{tokens}', cfg)) return out -def _layer3_shape_sweep(num_ranks: int) -> List[Tuple[str, Dict[str, Any]]]: +def _layer3_shape_cases(num_ranks: int) -> List[Tuple[str, Dict[str, Any]]]: out: List[Tuple[str, Dict[str, Any]]] = [] base_experts = 8 * num_ranks for hidden in (512, 2048): @@ -471,7 +472,7 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace): if 2 in args.layers: layers += _layer2_heuristic_branches(num_ranks) if 3 in args.layers: - layers += _layer3_shape_sweep(num_ranks) + layers += _layer3_shape_cases(num_ranks) if 4 in args.layers: layers += _layer4_edges(num_ranks) if 5 in args.layers: From 
52e34b2d0f72e89823e319f8be61ffcd6fd1c18a Mon Sep 17 00:00:00 2001 From: aichenf Date: Mon, 8 Jun 2026 17:36:16 +0800 Subject: [PATCH 14/15] Remove SM90 MegaMoE MMA sync path --- csrc/jit_kernels/heuristics/mega_moe.hpp | 36 +-- .../deep_gemm/impls/sm90_fp8_mega_moe.cuh | 256 +----------------- 2 files changed, 23 insertions(+), 269 deletions(-) diff --git a/csrc/jit_kernels/heuristics/mega_moe.hpp b/csrc/jit_kernels/heuristics/mega_moe.hpp index 95f1700e9d..0d195874e9 100644 --- a/csrc/jit_kernels/heuristics/mega_moe.hpp +++ b/csrc/jit_kernels/heuristics/mega_moe.hpp @@ -604,26 +604,21 @@ static std::pair get_pipeline_config_for_mega_moe_sm90( // C/D output region: max of L1 FP8 (single-buffered, BLOCK_N/2 post-SwiGLU) // and L2 BF16, then 1024-byte aligned (matches kernel's SMEM_CD_SIZE). - // The mma.sync decode path additionally stages one BLOCK_M x BLOCK_N FP32 - // accumulator tile in SMEM for logical-row epilogue mapping. const auto num_epilogue_warpgroups = num_epilogue_warps / 4; const bool split_n_warpgroups = block_m == 64 and block_n == 256 and num_epilogue_warpgroups == 2; const bool serial_n_warpgroups = false; const int wg_block_m = split_n_warpgroups ? block_m : block_m / num_epilogue_warpgroups; const int wg_block_n = (split_n_warpgroups or serial_n_warpgroups) ? block_n / 2 : block_n; - const int smem_cd_accum = (block_m == 16 or block_m == 32) ? align(block_m * block_n * static_cast(sizeof(float)), kSmemAlignment) : 0; const int smem_cd_l1 = num_epilogue_warpgroups * wg_block_m * (wg_block_n / 2); // 1 byte/elem (FP8) const bool direct_l2_scatter = direct_l2_scatter_enabled and - block_m != 16 and block_m != 32 and not serial_n_warpgroups and wg_block_n == 128; const bool async_l1_tma_store = get_env("DG_SM90_MOE_ASYNC_L1_STORE", 0) != 0 and - block_m != 16 and block_m != 32 and not split_n_warpgroups and num_epilogue_warpgroups == 1; const int smem_cd_l2 = direct_l2_scatter ? 
0 : num_epilogue_warpgroups * wg_block_m * wg_block_n * static_cast(sizeof(nv_bfloat16)); const int smem_cd_l1_async = async_l1_tma_store ? 2 * num_epilogue_warpgroups * wg_block_m * (block_n / 2) : 0; - const int smem_cd = smem_cd_accum + align(std::max(std::max(smem_cd_l1, smem_cd_l2), smem_cd_l1_async), kSmemAlignment); + const int smem_cd = align(std::max(std::max(smem_cd_l1, smem_cd_l2), smem_cd_l1_async), kSmemAlignment); // SF on SM90: // * SFA per stage must hold the larger of L1 (BLOCK_M floats, per-128 K) @@ -756,13 +751,6 @@ static std::vector get_mega_moe_config_candidates_sm90( const int& num_max_tokens_per_rank, const int& num_tokens, const int& num_topk, const int& hidden, const int& intermediate_hidden, const int& num_padded_sf_pool_tokens) { - const float expected_tokens_per_expert = - static_cast(num_tokens) * num_topk / num_experts_per_rank; - - const int requested_mma_m = get_env("DG_SM90_MOE_MMA_SYNC_M") > 0 - ? get_env("DG_SM90_MOE_MMA_SYNC_M") - : (get_env("DG_SM90_MOE_MMA_SYNC") != 0 ? 
16 : 0); - DG_HOST_ASSERT(requested_mma_m == 0 or requested_mma_m == 16 or requested_mma_m == 32); const int forced_block_m = get_env("DG_SM90_MOE_FORCE_BLOCK_M"); const int forced_epilogue_warpgroups = get_env("DG_SM90_MOE_FORCE_EPILOGUE_WG"); DG_HOST_ASSERT(forced_block_m == 0 or forced_block_m == 64 or forced_block_m == 128); @@ -770,13 +758,11 @@ static std::vector get_mega_moe_config_candidates_sm90( forced_epilogue_warpgroups == 1 or forced_epilogue_warpgroups == 2); - const bool use_mma_sync_decode = - requested_mma_m > 0 and expected_tokens_per_expert <= static_cast(requested_mma_m); const bool use_b_stationary_2wg = - get_env("DG_SM90_MOE_B_STATIONARY_2WG") != 0 and not use_mma_sync_decode; + get_env("DG_SM90_MOE_B_STATIONARY_2WG") != 0; const bool use_bn256_split_n_env = get_env("DG_SM90_MOE_BN256_2WG", 1) != 0 and - forced_block_m != 128 and not use_mma_sync_decode; + forced_block_m != 128; DG_HOST_ASSERT(not (use_b_stationary_2wg and use_bn256_split_n_env)); std::vector block_m_candidates; @@ -784,8 +770,6 @@ static std::vector get_mega_moe_config_candidates_sm90( append_unique_moe_candidate(block_m_candidates, forced_block_m); } else if (use_b_stationary_2wg) { append_unique_moe_candidate(block_m_candidates, 128); - } else if (use_mma_sync_decode) { - append_unique_moe_candidate(block_m_candidates, requested_mma_m); } else { append_unique_moe_candidate(block_m_candidates, 64); } @@ -805,8 +789,7 @@ static std::vector get_mega_moe_config_candidates_sm90( ); std::vector block_n_candidates; - if (block_m == 64 and not use_mma_sync_decode and - use_bn256_split_n_env) { + if (block_m == 64 and use_bn256_split_n_env) { append_unique_moe_candidate(block_n_candidates, 256); } else { append_unique_moe_candidate(block_n_candidates, 128); @@ -824,8 +807,6 @@ static std::vector get_mega_moe_config_candidates_sm90( for (const int& num_epilogue_warpgroups: epilogue_wg_candidates) { if (block_m % num_epilogue_warpgroups != 0) continue; - if ((block_m == 16 or block_m 
== 32) and num_epilogue_warpgroups != 1) - continue; if (block_m == 128 and num_epilogue_warpgroups != 2) continue; if (block_m == 64 and block_n == 256 and num_epilogue_warpgroups != 2) @@ -837,8 +818,8 @@ static std::vector get_mega_moe_config_candidates_sm90( (block_m == 128 and block_n == 128 and num_epilogue_threads == 256))) continue; const int cluster_size = use_cluster_bcast_b ? 2 : 1; - const int swizzle_acts_mode = (block_m == 16 or block_m == 32) ? 0 : 128; - const int swizzle_weights_mode = (block_m == 16 or block_m == 32) ? 0 : 128; + const int swizzle_acts_mode = 128; + const int swizzle_weights_mode = 128; const bool prefer_compact_frontend = block_n == 256 and not split_sfa_tma; @@ -875,9 +856,8 @@ static std::vector get_mega_moe_config_candidates_sm90( const bool one_warp_cleanup_default = policy.one_warp_cleanup( get_env("DG_SM90_MOE_MASKED_HINT", 0) != 0); const bool direct_l2_scatter_legal = - block_m != 16 and block_m != 32 and - ((block_m == 64 and block_n == 256 and num_epilogue_warpgroups == 2) or - block_n == 128); + (block_m == 64 and block_n == 256 and num_epilogue_warpgroups == 2) or + block_n == 128; auto direct_candidates = get_sm90_moe_bool_candidates( "DG_SM90_MOE_DIRECT_L2_SCATTER", diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh index 89c7713588..34bf88a569 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh @@ -190,8 +190,8 @@ sm90_fp8_mega_moe_core(DG_SM90_FP8_MOE_CORE_ARGS_DECL) { DG_STATIC_ASSERT(kNumExperts % kNumRanks == 0, "Invalid number of experts or ranks"); DG_STATIC_ASSERT(kClusterSize == 1 or kClusterSize == 2, "Invalid cluster size"); DG_STATIC_ASSERT(kNumSMs % kClusterSize == 0, "SM count must be divisible by cluster size"); - DG_STATIC_ASSERT(BLOCK_M == 16 or BLOCK_M == 32 or BLOCK_M % 64 == 0, - "BLOCK_M must be 16/32 for mma.sync decode or a multiple of 
WGMMA::M (64)"); + DG_STATIC_ASSERT(BLOCK_M % 64 == 0, + "BLOCK_M must be a multiple of WGMMA::M (64)"); DG_STATIC_ASSERT(BLOCK_N == 64 or BLOCK_N == 128 or BLOCK_N == 256, "BLOCK_N must be 64/128/256 for this SM90 path"); DG_STATIC_ASSERT(BLOCK_K == 128, "BLOCK_K is fixed to 128 (per-128 SF)"); @@ -259,25 +259,24 @@ sm90_fp8_mega_moe_core(DG_SM90_FP8_MOE_CORE_ARGS_DECL) { // ===================================================================== using a_dtype_t = cutlass::float_e4m3_t; using b_dtype_t = cutlass::float_e4m3_t; - constexpr bool kUseMMASync = (BLOCK_M == 16 or BLOCK_M == 32); constexpr bool kSplitNWarpgroups = - (!kUseMMASync) && BLOCK_N == 256 && kNumEpilogueWarpgroups == 2; + BLOCK_N == 256 && kNumEpilogueWarpgroups == 2; constexpr bool kSerialNWarpgroups = false; constexpr bool kWideNWarpgroups = - (!kUseMMASync) && BLOCK_N == 256 && kNumEpilogueWarpgroups == 1; + BLOCK_N == 256 && kNumEpilogueWarpgroups == 1; constexpr uint32_t WG_BLOCK_M = kSplitNWarpgroups ? BLOCK_M : BLOCK_M / kNumEpilogueWarpgroups; constexpr uint32_t WG_BLOCK_N = (kSplitNWarpgroups || kSerialNWarpgroups) ? 
BLOCK_N / 2 : BLOCK_N; constexpr uint32_t L1_OUT_BLOCK_N = BLOCK_N / 2; // post-SwiGLU tile N constexpr uint32_t WG_L1_OUT_BLOCK_N = WG_BLOCK_N / 2; // post-SwiGLU per-WG N constexpr bool kAsyncL1TMAStore = - kAsyncL1TMAStoreRequested && MegaMoEPhase::runs_linear1 && (!kUseMMASync) && + kAsyncL1TMAStoreRequested && MegaMoEPhase::runs_linear1 && (!kSplitNWarpgroups) && kNumEpilogueWarpgroups == 1; - constexpr bool kSplitSFATMA = kSplitSFATMARequested && (!kUseMMASync); - constexpr bool kDirectL2Scatter = kDirectL2ScatterRequested && MegaMoEPhase::runs_linear2 && (!kUseMMASync) && + constexpr bool kSplitSFATMA = kSplitSFATMARequested; + constexpr bool kDirectL2Scatter = kDirectL2ScatterRequested && MegaMoEPhase::runs_linear2 && (!kSerialNWarpgroups) && WG_BLOCK_N == 128; - constexpr bool kL2DualAccum = kL2DualAccumRequested && MegaMoEPhase::runs_linear2 && (!kUseMMASync) && + constexpr bool kL2DualAccum = kL2DualAccumRequested && MegaMoEPhase::runs_linear2 && (!kSplitNWarpgroups) && (!kSerialNWarpgroups) && WG_BLOCK_N == 128; - constexpr bool kL1DualKAccum = kL1DualKAccumRequested && MegaMoEPhase::runs_linear1 && (!kUseMMASync) && + constexpr bool kL1DualKAccum = kL1DualKAccumRequested && MegaMoEPhase::runs_linear1 && (!kSplitNWarpgroups) && (!kSerialNWarpgroups) && WG_BLOCK_N == 128 && (kHidden / BLOCK_K) % 2 == 0; using L1WGMMA = typename mma::sm90::FP8MMASelector::type; @@ -291,8 +290,8 @@ sm90_fp8_mega_moe_core(DG_SM90_FP8_MOE_CORE_ARGS_DECL) { // M blocks with identical expert/N/K coordinates so the B TMA can multicast. constexpr uint32_t LOAD_BLOCK_M = BLOCK_M; constexpr uint32_t LOAD_BLOCK_N = BLOCK_N; - constexpr uint32_t kSwizzleAMode = kUseMMASync ? 0 : BLOCK_K * sizeof(a_dtype_t); // 0 or 128 - constexpr uint32_t kSwizzleBMode = kUseMMASync ? 
0 : BLOCK_K * sizeof(b_dtype_t); // 0 or 128 + constexpr uint32_t kSwizzleAMode = BLOCK_K * sizeof(a_dtype_t); // 128 + constexpr uint32_t kSwizzleBMode = BLOCK_K * sizeof(b_dtype_t); // 128 constexpr uint32_t kSwizzleCDMode = 128; constexpr uint32_t kGranK = 128; // L1 acts SF, weights SF constexpr uint32_t kL2ActsSFGranK = 64; // L2 acts SF (per-64 K, SM90 only) @@ -312,8 +311,7 @@ sm90_fp8_mega_moe_core(DG_SM90_FP8_MOE_CORE_ARGS_DECL) { constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t); constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t); // SFA per-stage must be sized for the larger of L1 (BLOCK_M floats) and L2 - // (two per-64-K halves). Each TMA destination must be 128B aligned, so - // the second L2 half cannot start immediately after 16 floats in M16 decode. + // (two per-64-K halves). Each TMA destination must be 128B aligned. constexpr uint32_t kL2SFAHalfStride = math::constexpr_align(BLOCK_M * sizeof(float), 128u) / sizeof(float); constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = 2 * kL2SFAHalfStride * sizeof(float); @@ -324,9 +322,7 @@ sm90_fp8_mega_moe_core(DG_SM90_FP8_MOE_CORE_ARGS_DECL) { // CD output: max of L1 FP8 (BLOCK_M * (BLOCK_N/2) * 1 byte * num_wg) and // L2 BF16 (BLOCK_M * BLOCK_N * 2 bytes * num_wg). - constexpr uint32_t SMEM_CD_ACCUM_SIZE = kUseMMASync - ? math::constexpr_align(BLOCK_M * BLOCK_N * sizeof(float), kSharedMemoryAlignment) - : 0u; + constexpr uint32_t SMEM_CD_ACCUM_SIZE = 0u; constexpr uint32_t SMEM_CD_L1_SIZE = MegaMoEPhase::runs_linear1 ? kNumEpilogueWarpgroups * WG_BLOCK_M * WG_L1_OUT_BLOCK_N * sizeof(cutlass::float_e4m3_t) : 0u; constexpr uint32_t SMEM_CD_L2_SIZE = (!MegaMoEPhase::runs_linear2 || kDirectL2Scatter) ? 
0u : @@ -369,9 +365,6 @@ sm90_fp8_mega_moe_core(DG_SM90_FP8_MOE_CORE_ARGS_DECL) { auto smem_gemm_base = math::advance_ptr( smem_buffer, SMEM_EXPERT_COUNT_SIZE + SMEM_SEND_BUFFER_SIZE); - // mma.sync decode stages FP32 accumulators through SMEM so the epilogue can - // use logical row/column indices instead of WGMMA accumulator layout. - auto smem_accum_f32 = reinterpret_cast(smem_gemm_base); auto smem_cd_base = math::advance_ptr(smem_gemm_base, SMEM_CD_ACCUM_SIZE); // CD output is shared by L1 (FP8) and L2 (BF16); reinterpret-cast as needed. auto smem_cd_l1 = reinterpret_cast(smem_cd_base); @@ -472,7 +465,7 @@ sm90_fp8_mega_moe_core(DG_SM90_FP8_MOE_CORE_ARGS_DECL) { constexpr uint32_t kNumDispatchRegisters = 48; constexpr bool kCompactFrontendWarpgroup = (kNumDispatchWarps == 2 and kNumMMANonEpilogueWarps == 2); constexpr uint32_t kNumNonEpilogueRegisters = kCompactFrontendWarpgroup ? kNumDispatchRegisters : 40; - constexpr uint32_t kNumEpilogueRegisters = (kSerialNWarpgroups or kWideNWarpgroups) ? 256 : ((kUseMMASync and BLOCK_M == 32) ? 240 : 208); + constexpr uint32_t kNumEpilogueRegisters = (kSerialNWarpgroups or kWideNWarpgroups) ? 
256 : 208; DG_STATIC_ASSERT(kNumDispatchRegisters * kNumDispatchThreads + kNumNonEpilogueRegisters * kNumNonEpilogueThreads + kNumEpilogueRegisters * kNumEpilogueThreads <= 64512, @@ -1051,10 +1044,7 @@ sm90_fp8_mega_moe_core(DG_SM90_FP8_MOE_CORE_ARGS_DECL) { const uint32_t r_1 = r_0 + 8; DG_STATIC_ASSERT(kSplitNWarpgroups || (BLOCK_M % kNumEpilogueWarpgroups == 0), "Invalid block M"); - if constexpr (kUseMMASync) { - DG_STATIC_ASSERT(WG_BLOCK_M == BLOCK_M, "mma.sync decode path uses one M tile per warpgroup"); - DG_STATIC_ASSERT(kNumEpilogueWarpgroups == 1, "mma.sync decode path currently uses one math warpgroup"); - } else if constexpr (kSplitNWarpgroups) { + if constexpr (kSplitNWarpgroups) { DG_STATIC_ASSERT(WG_BLOCK_M == L1WGMMA::M and WG_BLOCK_N == L1WGMMA::N, "Split-N WGs must each run one M64N128 WGMMA per K-block"); } else if constexpr (kSerialNWarpgroups) { @@ -1092,222 +1082,6 @@ sm90_fp8_mega_moe_core(DG_SM90_FP8_MOE_CORE_ARGS_DECL) { drain_all_async_l1_stores(); } - if constexpr (kUseMMASync) { - using MMASyncTiled = cute::TiledMMA< - cute::MMA_Atom, - cute::Layout, cute::_4, cute::_1>>>; - auto mma_sync_tiled = MMASyncTiled{}; - auto thr_mma = mma_sync_tiled.get_thread_slice(epilogue_thread_idx); - auto sC = cute::make_tensor( - cute::make_smem_ptr(smem_accum_f32), - cute::make_layout(cute::make_shape(cute::Int{}, cute::Int<128>{}), cute::GenRowMajor{})); - auto tCsC = thr_mma.partition_C(sC); - auto tCrC = thr_mma.make_fragment_C(tCsC); - auto tCrFinal = thr_mma.make_fragment_C(tCsC); - cute::clear(tCrFinal); - auto cC = cute::make_identity_tensor(cute::shape(sC)); - auto tCcC = thr_mma.partition_C(cC); - - auto add_scaled_fragment = [&](const float& scale_b, const uint32_t& sfa_offset) { - CUTE_UNROLL - for (int i = 0; i < cute::size(tCrC); ++i) { - const auto coord = tCcC(i); - if (cute::elem_less(coord, cute::shape(sC))) { - const uint32_t row = static_cast(cute::get<0>(coord)); - const uint32_t col = static_cast(cute::get<1>(coord)); - const 
float scale_a = ptx::ld_shared(smem_sfa[stage_idx] + sfa_offset + row); - tCrFinal(i) += scale_a * scale_b * tCrC(i); - } - } - }; - - for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { - full_barriers[stage_idx]->wait(phase); - - constexpr uint32_t kL1SFKBlocks = kHidden / 128; - constexpr uint32_t kL2SFKBlocks = kIntermediateHidden / 128; - constexpr uint32_t kL1SFGateBlks = kIntermediateHidden / 128; - constexpr uint32_t kL1SFPerExpert = (kIntermediateHidden * 2 / 128) * kL1SFKBlocks; - constexpr uint32_t kL2SFPerExpert = (kHidden / 128) * kL2SFKBlocks; - - if (is_linear1_phase) { - const uint32_t gate_n = (n_block_idx * BLOCK_N + wg_n_idx) / 256u; - const uint32_t up_n = kL1SFGateBlks + gate_n; - const float* base = l1_weights_sf + local_expert_idx * kL1SFPerExpert + k_block_idx; - const float gate_sf = __ldg(base + gate_n * kL1SFKBlocks); - const float up_sf = __ldg(base + up_n * kL1SFKBlocks); - - auto sA = cute::make_tensor( - cute::make_smem_ptr(smem_a[stage_idx]), - cute::make_layout(cute::make_shape(cute::Int{}, cute::Int<128>{}), cute::GenRowMajor{})); - auto sB = cute::make_tensor( - cute::make_smem_ptr(smem_b[stage_idx]), - cute::make_layout(cute::make_shape(cute::Int<128>{}, cute::Int<128>{}), cute::GenRowMajor{})); - cute::clear(tCrC); - cute::cooperative_gemm(epilogue_thread_idx, mma_sync_tiled, sA, sB, tCrC); - - CUTE_UNROLL - for (int i = 0; i < cute::size(tCrC); ++i) { - const auto coord = tCcC(i); - if (cute::elem_less(coord, cute::shape(sC))) { - const uint32_t row = static_cast(cute::get<0>(coord)); - const uint32_t col = static_cast(cute::get<1>(coord)); - const float scale_a = ptx::ld_shared(smem_sfa[stage_idx] + row); - const float scale_b = ((col / 8u) & 1u) ? 
up_sf : gate_sf; - tCrFinal(i) += scale_a * scale_b * tCrC(i); - } - } - } else { - const float l2_sf = __ldg(l2_weights_sf + local_expert_idx * kL2SFPerExpert - + ((n_block_idx * BLOCK_N + wg_n_idx) / 128u) * kL2SFKBlocks + k_block_idx); - CUTE_UNROLL - for (uint32_t half = 0; half < 2; ++half) { - const uint32_t k_off = half * (BLOCK_K / 2); - auto sA = cute::make_tensor( - cute::make_smem_ptr(smem_a[stage_idx] + k_off), - cute::make_layout(cute::make_shape(cute::Int{}, cute::Int<64>{}), - cute::make_stride(cute::Int<128>{}, cute::Int<1>{}))); - auto sB = cute::make_tensor( - cute::make_smem_ptr(smem_b[stage_idx] + wg_n_idx * BLOCK_K + k_off), - cute::make_layout(cute::make_shape(cute::Int<128>{}, cute::Int<64>{}), - cute::make_stride(cute::Int<128>{}, cute::Int<1>{}))); - cute::clear(tCrC); - cute::cooperative_gemm(epilogue_thread_idx, mma_sync_tiled, sA, sB, tCrC); - add_scaled_fragment(l2_sf, half * kL2SFAHalfStride); - } - } - - arrive_empty_barrier(stage_idx); - __syncwarp(); - } - - CUTE_UNROLL - for (int i = 0; i < cute::size(tCrFinal); ++i) { - const auto coord = tCcC(i); - if (cute::elem_less(coord, cute::shape(sC))) { - const uint32_t row = static_cast(cute::get<0>(coord)); - const uint32_t col = static_cast(cute::get<1>(coord)); - smem_accum_f32[row * BLOCK_N + col] = tCrFinal(i); - } - } - ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); - - constexpr uint32_t kMMASyncRowsPerPass = kNumEpilogueThreads / 8; - DG_STATIC_ASSERT(kMMASyncRowsPerPass == 16, "mma.sync epilogue maps 8 lanes per row"); - - if (is_linear1_phase) { - #pragma unroll - for (uint32_t row_base = 0; row_base < BLOCK_M; row_base += kMMASyncRowsPerPass) { - const uint32_t row = row_base + epilogue_thread_idx / 8; - const uint32_t lane_in_row = epilogue_thread_idx % 8; - const bool valid_row = row < valid_m; - float swiglu_values[8]; - float amax = 0.0f; - - #pragma unroll - for (uint32_t p = 0; p < 8; ++p) { - const uint32_t gate_col = p * 16 + lane_in_row; - const 
uint32_t up_col = gate_col + 8; - float gate = valid_row ? smem_accum_f32[row * BLOCK_N + gate_col] : 0.0f; - float up = valid_row ? smem_accum_f32[row * BLOCK_N + up_col] : 0.0f; - if constexpr (kActivationClamp != cute::numeric_limits::infinity()) { - gate = cute::min(gate, kActivationClamp); - up = cute::min(cute::max(up, -kActivationClamp), kActivationClamp); - } - const float e = kFastMath ? __expf(-gate) : expf(-gate); - const float sig = kFastMath ? math::fast_rcp(1.0f + e) : 1.0f / (1.0f + e); - const float weight = valid_row ? *l1_topk_weights_buffer - .get_data_buffer(m_idx + row) - .get_base_ptr() : 0.0f; - const float value = gate * sig * up * weight; - swiglu_values[p] = value; - amax = cute::max(amax, cute::abs(value)); - } - - amax = math::warp_reduce<8, false>(amax, math::ReduceMax()); - float2 amax_pair = {amax, amax}; - float2 sf_pair, sf_inv_pair; - math::get_e4m3_sf_and_sf_inv(amax_pair, sf_pair, sf_inv_pair); - const float sf = sf_pair.x; - const float sf_inv = sf_inv_pair.x; - - if (valid_row) { - #pragma unroll - for (uint32_t p = 0; p < 8; ++p) { - const uint32_t out_col = p * 8 + lane_in_row; - smem_cd_l1[row * L1_OUT_BLOCK_N + out_col] = - cutlass::float_e4m3_t(swiglu_values[p] * sf_inv); - } - if (lane_in_row == 0) { - auto sf_base_ptr = l2_sf_buffer.get_base_ptr(); - const uint32_t token_idx = pool_block_idx * BLOCK_M + row; - const uint32_t k_sf_idx = n_block_idx; - sf_base_ptr[k_sf_idx * kNumPaddedSFPoolTokens + token_idx] = sf; - } - } - } - - ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); - if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { - const uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N; - cute::tma_store_fence(); - cute::SM90_TMA_STORE_2D::copy( - &tensor_map_l1_output, - smem_cd_l1, - out_n_idx, - m_idx); - cute::tma_store_arrive(); - } - __syncwarp(); - ptx::tma_store_wait<0>(); - - ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); - if (epilogue_warp_idx == 0 and 
cute::elect_one_sync()) { - ptx::red_or_rel_gpu( - workspace.get_l2_arrival_mask_ptr(pool_block_idx), - 1ull << n_block_idx); - } - __syncwarp(); - } else { - #pragma unroll - for (uint32_t row_base = 0; row_base < BLOCK_M; row_base += kMMASyncRowsPerPass) { - const uint32_t row = row_base + epilogue_thread_idx / 8; - const uint32_t lane_in_row = epilogue_thread_idx % 8; - if (row < valid_m) { - const auto src_metadata = *workspace.get_token_src_metadata_ptr(m_idx + row); - const uint32_t dst_rank_idx = src_metadata.rank_idx; - const uint32_t dst_token_idx = src_metadata.token_idx; - const uint32_t dst_topk_idx = src_metadata.topk_idx; - const auto dst_token = combine_token_buffer.get_rank_buffer(dst_topk_idx) - .get_data_buffer(dst_token_idx); - #pragma unroll - for (uint32_t v = 0; v < 2; ++v) { - const uint32_t col = lane_in_row * 16 + v * 8; - uint4 packed; - float f0 = smem_accum_f32[row * BLOCK_N + col + 0]; - float f1 = smem_accum_f32[row * BLOCK_N + col + 1]; - float f2 = smem_accum_f32[row * BLOCK_N + col + 2]; - float f3 = smem_accum_f32[row * BLOCK_N + col + 3]; - float f4 = smem_accum_f32[row * BLOCK_N + col + 4]; - float f5 = smem_accum_f32[row * BLOCK_N + col + 5]; - float f6 = smem_accum_f32[row * BLOCK_N + col + 6]; - float f7 = smem_accum_f32[row * BLOCK_N + col + 7]; - packed.x = math::cast_into_bf16_and_pack(f0, f1); - packed.y = math::cast_into_bf16_and_pack(f2, f3); - packed.z = math::cast_into_bf16_and_pack(f4, f5); - packed.w = math::cast_into_bf16_and_pack(f6, f7); - auto dst_ptr = math::advance_ptr( - dst_token.get_base_ptr(), - n_idx * sizeof(nv_bfloat16) + col * sizeof(nv_bfloat16)); - *sym_buffer.map(dst_ptr, dst_rank_idx) = packed; - } - } - } - ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); - } - return; - } - if constexpr (kSerialNWarpgroups) { using WGMMA = L1WGMMA; constexpr uint32_t kAccumPerThread = WGMMA::kNumAccum; From bdb43956aef93eff7884a9f1ee4f1732b0ea0ba0 Mon Sep 17 00:00:00 2001 From: AichenF Date: Tue, 9 
Jun 2026 12:07:02 +0800 Subject: [PATCH 15/15] Prune unused SM90 MegaMoE env knobs --- csrc/jit_kernels/heuristics/mega_moe.hpp | 40 ++++--------------- csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp | 20 ---------- .../deep_gemm/impls/sm90_fp8_mega_moe.cuh | 27 ++++--------- 3 files changed, 16 insertions(+), 71 deletions(-) diff --git a/csrc/jit_kernels/heuristics/mega_moe.hpp b/csrc/jit_kernels/heuristics/mega_moe.hpp index 0d195874e9..2fb02077c3 100644 --- a/csrc/jit_kernels/heuristics/mega_moe.hpp +++ b/csrc/jit_kernels/heuristics/mega_moe.hpp @@ -612,8 +612,7 @@ static std::pair get_pipeline_config_for_mega_moe_sm90( const int smem_cd_l1 = num_epilogue_warpgroups * wg_block_m * (wg_block_n / 2); // 1 byte/elem (FP8) const bool direct_l2_scatter = direct_l2_scatter_enabled and not serial_n_warpgroups and wg_block_n == 128; - const bool async_l1_tma_store = get_env("DG_SM90_MOE_ASYNC_L1_STORE", 0) != 0 and - not split_n_warpgroups and num_epilogue_warpgroups == 1; + const bool async_l1_tma_store = false; const int smem_cd_l2 = direct_l2_scatter ? 0 : num_epilogue_warpgroups * wg_block_m * wg_block_n * static_cast(sizeof(nv_bfloat16)); const int smem_cd_l1_async = async_l1_tma_store ? @@ -758,27 +757,16 @@ static std::vector get_mega_moe_config_candidates_sm90( forced_epilogue_warpgroups == 1 or forced_epilogue_warpgroups == 2); - const bool use_b_stationary_2wg = - get_env("DG_SM90_MOE_B_STATIONARY_2WG") != 0; const bool use_bn256_split_n_env = get_env("DG_SM90_MOE_BN256_2WG", 1) != 0 and forced_block_m != 128; - DG_HOST_ASSERT(not (use_b_stationary_2wg and use_bn256_split_n_env)); std::vector block_m_candidates; - if (forced_block_m > 0) { - append_unique_moe_candidate(block_m_candidates, forced_block_m); - } else if (use_b_stationary_2wg) { - append_unique_moe_candidate(block_m_candidates, 128); - } else { - append_unique_moe_candidate(block_m_candidates, 64); - } + append_unique_moe_candidate(block_m_candidates, forced_block_m > 0 ? 
forced_block_m : 64); const int num_max_pool_tokens = layout::get_num_max_pool_tokens( num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank); const int block_k = 128; - const bool split_sfa_tma = get_env("DG_SM90_MOE_SPLIT_SFA_TMA", 0) != 0; - const bool use_cluster_bcast_b = get_env("DG_SM90_MOE_CLUSTER_BCAST_B") != 0 or use_b_stationary_2wg; const int num_sms = device_runtime->get_num_sms(); std::vector candidates; @@ -800,8 +788,7 @@ static std::vector get_mega_moe_config_candidates_sm90( if (forced_epilogue_warpgroups > 0) { append_unique_moe_candidate(epilogue_wg_candidates, forced_epilogue_warpgroups); } else { - append_unique_moe_candidate(epilogue_wg_candidates, - (block_m == 128 or use_b_stationary_2wg or block_n == 256) ? 2 : 1); + append_unique_moe_candidate(epilogue_wg_candidates, block_n == 256 ? 2 : 1); } for (const int& num_epilogue_warpgroups: epilogue_wg_candidates) { @@ -813,34 +800,23 @@ static std::vector get_mega_moe_config_candidates_sm90( continue; const int num_epilogue_threads = num_epilogue_warpgroups * 128; - if (use_cluster_bcast_b and - not ((block_m == 64 and block_n == 128 and num_epilogue_threads == 128) or - (block_m == 128 and block_n == 128 and num_epilogue_threads == 256))) - continue; - const int cluster_size = use_cluster_bcast_b ? 2 : 1; + const int cluster_size = 1; const int swizzle_acts_mode = 128; const int swizzle_weights_mode = 128; - const bool prefer_compact_frontend = - block_n == 256 and not split_sfa_tma; - const bool compact_frontend = get_env("DG_SM90_MOE_COMPACT_FRONTEND", - prefer_compact_frontend ? 
1 : 0) != 0; + const bool compact_frontend = block_n == 256; const int forced_dispatch_warps = get_env("DG_SM90_MOE_DISPATCH_WARPS", -1); DG_HOST_ASSERT(forced_dispatch_warps == -1 or forced_dispatch_warps == 0 or forced_dispatch_warps == 2 or forced_dispatch_warps == 4 or forced_dispatch_warps == 8); std::vector dispatch_warp_candidates; - if (forced_dispatch_warps > 0) { - append_unique_moe_candidate(dispatch_warp_candidates, forced_dispatch_warps); - } else { - append_unique_moe_candidate(dispatch_warp_candidates, compact_frontend ? 2 : 4); - } + append_unique_moe_candidate(dispatch_warp_candidates, + forced_dispatch_warps > 0 ? forced_dispatch_warps : + (compact_frontend ? 2 : 4)); for (const int& num_dispatch_warps: dispatch_warp_candidates) { if (compact_frontend and num_dispatch_warps != 2) continue; - if (split_sfa_tma and compact_frontend) - continue; const int num_dispatch_threads = num_dispatch_warps * 32; const int num_non_epilogue_threads = compact_frontend ? 64 : 128; if ((num_dispatch_threads + num_non_epilogue_threads) % 128 != 0) diff --git a/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp b/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp index 15fcbdbc28..df63142e93 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp @@ -44,14 +44,9 @@ class SM90FP8MegaMoERuntime final : public LaunchRuntime int num_ranks; float activation_clamp; bool fast_math; - bool async_l1_tma_store; - bool split_sfa_tma; bool direct_l2_scatter; - bool l2_dual_accum; bool phase_profile; - bool l1_dual_k_accum; bool l2_nmajor_schedule; - bool l1_nmajor_schedule; bool one_warp_cleanup; KernelPhase kernel_phase; MegaMoESM90Config config; @@ -105,11 +100,6 @@ static void __instantiate_kernel() {{ {}, {}, {}, - {}, - {}, - {}, - {}, - {}, {} >); }}; @@ -128,14 +118,9 @@ static void __instantiate_kernel() {{ args.launch_args.grid_dim.first, args.num_ranks, to_string(args.activation_clamp), args.fast_math ? 
"true" : "false", - args.async_l1_tma_store ? "true" : "false", - args.split_sfa_tma ? "true" : "false", args.direct_l2_scatter ? "true" : "false", - args.l2_dual_accum ? "true" : "false", args.phase_profile ? "true" : "false", - args.l1_dual_k_accum ? "true" : "false", args.l2_nmajor_schedule ? "true" : "false", - args.l1_nmajor_schedule ? "true" : "false", args.one_warp_cleanup ? "true" : "false"); } @@ -248,14 +233,9 @@ static void sm90_fp8_mega_moe( .num_ranks = num_ranks, .activation_clamp = activation_clamp, .fast_math = fast_math, - .async_l1_tma_store = get_env("DG_SM90_MOE_ASYNC_L1_STORE", 0) != 0, - .split_sfa_tma = get_env("DG_SM90_MOE_SPLIT_SFA_TMA", 0) != 0, .direct_l2_scatter = config.direct_l2_scatter, - .l2_dual_accum = get_env("DG_SM90_MOE_L2_DUAL_ACCUM", 0) != 0, .phase_profile = get_env("DG_SM90_MOE_PHASE_PROFILE", 0) != 0, - .l1_dual_k_accum = get_env("DG_SM90_MOE_L1_DUAL_K", 0) != 0, .l2_nmajor_schedule = config.l2_nmajor_schedule, - .l1_nmajor_schedule = get_env("DG_SM90_MOE_L1_NMAJOR", 0) != 0, .one_warp_cleanup = config.one_warp_cleanup, .kernel_phase = SM90FP8MegaMoERuntime::KernelPhase::Linear1, .config = config, diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh index 34bf88a569..865f1e670a 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh @@ -102,14 +102,9 @@ using MegaMoELinear2Phase = MegaMoEPhasePolicy; uint32_t kNumSMs, uint32_t kNumRanks, \ float kActivationClamp, \ bool kFastMath, \ - bool kAsyncL1TMAStoreRequested = true, \ - bool kSplitSFATMARequested = false, \ bool kDirectL2ScatterRequested = false, \ - bool kL2DualAccumRequested = false, \ bool kPhaseProfileRequested = false, \ - bool kL1DualKAccumRequested = false, \ bool kL2NMajorScheduleRequested = false, \ - bool kL1NMajorScheduleRequested = false, \ bool kOneWarpCleanupRequested = false, \ uint32_t L1_SHAPE_N = 
kIntermediateHidden * 2, \ uint32_t L1_SHAPE_K = kHidden, \ @@ -165,10 +160,9 @@ using MegaMoELinear2Phase = MegaMoEPhasePolicy; kNumExpertsPerWave, BLOCK_M, BLOCK_N, BLOCK_K, kNumMaxPoolTokens, \ kNumPaddedSFPoolTokens, kNumStages, kNumDispatchThreads, \ kNumNonEpilogueThreads, kNumEpilogueThreads, kClusterSize, kNumSMs, \ - kNumRanks, kActivationClamp, kFastMath, kAsyncL1TMAStoreRequested, \ - kSplitSFATMARequested, kDirectL2ScatterRequested, kL2DualAccumRequested, \ - kPhaseProfileRequested, kL1DualKAccumRequested, kL2NMajorScheduleRequested, \ - kL1NMajorScheduleRequested, kOneWarpCleanupRequested, L1_SHAPE_N, \ + kNumRanks, kActivationClamp, kFastMath, kDirectL2ScatterRequested, \ + kPhaseProfileRequested, kL2NMajorScheduleRequested, kOneWarpCleanupRequested, \ + L1_SHAPE_N, \ L1_SHAPE_K, L2_SHAPE_N, L2_SHAPE_K, kNumDispatchWarps, \ kNumMMANonEpilogueWarps, kNumEpilogueWarps, kNumEpilogueWarpgroups, \ kNumThreads, kNumTokensPerWarp, kNumExpertsPerRank @@ -268,17 +262,12 @@ sm90_fp8_mega_moe_core(DG_SM90_FP8_MOE_CORE_ARGS_DECL) { constexpr uint32_t WG_BLOCK_N = (kSplitNWarpgroups || kSerialNWarpgroups) ? 
BLOCK_N / 2 : BLOCK_N; constexpr uint32_t L1_OUT_BLOCK_N = BLOCK_N / 2; // post-SwiGLU tile N constexpr uint32_t WG_L1_OUT_BLOCK_N = WG_BLOCK_N / 2; // post-SwiGLU per-WG N - constexpr bool kAsyncL1TMAStore = - kAsyncL1TMAStoreRequested && MegaMoEPhase::runs_linear1 && - (!kSplitNWarpgroups) && kNumEpilogueWarpgroups == 1; - constexpr bool kSplitSFATMA = kSplitSFATMARequested; + constexpr bool kAsyncL1TMAStore = false; + constexpr bool kSplitSFATMA = false; constexpr bool kDirectL2Scatter = kDirectL2ScatterRequested && MegaMoEPhase::runs_linear2 && (!kSerialNWarpgroups) && WG_BLOCK_N == 128; - constexpr bool kL2DualAccum = kL2DualAccumRequested && MegaMoEPhase::runs_linear2 && - (!kSplitNWarpgroups) && (!kSerialNWarpgroups) && WG_BLOCK_N == 128; - constexpr bool kL1DualKAccum = kL1DualKAccumRequested && MegaMoEPhase::runs_linear1 && - (!kSplitNWarpgroups) && (!kSerialNWarpgroups) && WG_BLOCK_N == 128 && - (kHidden / BLOCK_K) % 2 == 0; + constexpr bool kL2DualAccum = false; + constexpr bool kL1DualKAccum = false; using L1WGMMA = typename mma::sm90::FP8MMASelector::type; using L2WGMMA = typename mma::sm90::FP8MMASelector::type; static_assert(L1WGMMA::M == 64 and L1WGMMA::N == WG_BLOCK_N and L1WGMMA::K == 32, @@ -438,7 +427,7 @@ sm90_fp8_mega_moe_core(DG_SM90_FP8_MOE_CORE_ARGS_DECL) { L1_SHAPE_N, L1_SHAPE_K, L2_SHAPE_N, L2_SHAPE_K, kNumExpertsPerRank, kNumExpertsPerWave, - kNumSMs, kNumRanks, kClusterSize, kL2NMajorScheduleRequested, kL1NMajorScheduleRequested>(workspace); + kNumSMs, kNumRanks, kClusterSize, kL2NMajorScheduleRequested, false>(workspace); // Pipeline state shared by TMA loaders and math warpgroups uint32_t stage_idx = 0, phase = 0;