diff --git a/kernels/moe_sorting_kernel.py b/kernels/moe_sorting_kernel.py index 6916fef6d..ca59b408b 100644 --- a/kernels/moe_sorting_kernel.py +++ b/kernels/moe_sorting_kernel.py @@ -37,6 +37,10 @@ from flydsl.runtime.device import get_rocm_arch as get_hip_arch from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr from kernels.kernels_common import get_warp_size +from kernels.topk_gating_softmax_kernel import ( + _compute_topk_gating_layout, + _emit_topk_gating_softmax_body, +) BLOCK_SIZE = 256 UNIT_SIZE = 32 # GEMM tile-M, aka block_size in CK @@ -187,9 +191,21 @@ def _lds_store_raw(raw_mr, val, idx): # _make_cache_key, and dict lookup, reducing dispatch from ~70 us to ~5 us. # --------------------------------------------------------------------------- _oneshot_cf_cache = {} # (num_experts, topk, max_tokens, unit_size, has_mask, device) -> CompiledFunction +_oneshot_fused_cf_cache = {} # fused oneshot: same key + (dtype_str, renormalize) -> CompiledFunction _multiphase_cf_cache = {} # (num_experts, topk, unit_size, kernel_name, *constexpr_vals) -> CompiledFunction _dummy_mask_cache = {} # device -> torch.Tensor(1, dtype=i32, value=1) +# Caches for moe_softmax_sort_flydsl's unfused fallback path. +_topk_fallback_scratch_cache = {} # (device, M, topk) -> (topk_weights, topk_ids, tei) +_topk_fallback_builder_cache = {} # (num_experts, topk, dtype_str, renormalize) -> launch_fn + + +# `_compute_topk_gating_layout` and `_emit_topk_gating_softmax_body` are +# imported from `kernels.topk_gating_softmax_kernel` (see the top-level +# imports) so the fused oneshot kernel below and the standalone gating +# kernel share the same layout heuristic and MLIR emission code — they +# can never disagree about VPT / TPT / ATOM_BITS or about per-token math. 
+ # --------------------------------------------------------------------------- # FlyDSL GPU kernel — oneshot path (single kernel, all phases in LDS) @@ -703,6 +719,510 @@ def launch_moe_sorting_oneshot( return launch_moe_sorting_oneshot +# --------------------------------------------------------------------------- +# FlyDSL GPU kernel — fused (gating + sort) oneshot path +# --------------------------------------------------------------------------- +@functools.lru_cache(maxsize=256) +def compile_moe_sorting_oneshot_fused( + *, + num_experts: int, + topk: int, + dtype_str: str = "bf16", + renormalize: bool = True, + max_tokens: int = 16, + unit_size: int = UNIT_SIZE, + has_mask: bool = False, +): + """Compile the fused (gating + sort) oneshot-path MoE kernel. + + Reads `gating_logits[M, E]` directly, performs softmax + top-K in block + 0, stages all per-token-K intermediates in on-chip LDS, then runs the + sort phases (count → prefix-sum → scatter + padding) without ever + round-tripping the routing intermediates through HBM. Phase 1 of the + original sort (mesh fill) is gone — gating's leader-write callback + fills the mesh directly. Phase 3's per-token weight load now reads from + a small `weights_lds[max_tokens, topk]` region instead of HBM. Blocks + > 0 zero the moe_buf as in the unfused oneshot kernel. 
+ + Parameters identical to ``_compile_moe_sorting_oneshot`` plus: + dtype_str — input dtype for the gating logits (``f32`` / ``f16`` / ``bf16``) + renormalize — whether the top-K weights are renormalised to sum to 1 + """ + arch = get_hip_arch() + E = num_experts + smem_cols = E + 1 + + if arch in ("gfx942",) or str(arch).startswith("gfx94"): + lds_capacity_bytes = 65536 + elif str(arch).startswith("gfx95"): + lds_capacity_bytes = 163840 + else: + lds_capacity_bytes = 65536 + + lds_capacity_ints = lds_capacity_bytes // 4 + target_occupancy = 2 + r = lds_capacity_ints // target_occupancy // smem_cols + sub_unroll = 8 + cumsum_bufs = 2 + if r < (cumsum_bufs + sub_unroll): + raise ValueError( + f"LDS too small for E={E}: need at least " f"{(cumsum_bufs + sub_unroll) * smem_cols * 4} bytes" + ) + r_for_sub = ((r - cumsum_bufs) // sub_unroll) * sub_unroll + r_token_min = ((max_tokens + sub_unroll - 1) // sub_unroll) * sub_unroll + r_for_sub = min(r_for_sub, r_token_min) + sub_tokens = r_for_sub + + allocator = SmemAllocator(None, arch=arch) + cumsum_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = cumsum_offset + smem_cols * 4 + cumdup_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = cumdup_offset + smem_cols * 4 + mesh_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = mesh_offset + sub_tokens * smem_cols * 4 + + # V2 LDS region: weights_lds[max_tokens, topk] — gating winner weights + # staged on-chip so Phase 3's scatter reads them from LDS instead of + # HBM. ≤512 B at the oneshot-path bound (max_tokens=16, topk=8) so it + # has negligible occupancy impact. 
+ weights_lds_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = weights_lds_offset + max_tokens * topk * 4 + + gating_layout = _compute_topk_gating_layout(E, topk, dtype_str) + + @flyc.kernel(known_block_size=[BLOCK_SIZE, 1, 1]) + def moe_sorting_oneshot_fused_kernel( + gating_logits: fx.Tensor, + sorted_token_ids: fx.Tensor, + sorted_weights_out: fx.Tensor, + sorted_expert_ids: fx.Tensor, + num_valid_ids: fx.Tensor, + moe_buf: fx.Tensor, + expert_mask_tensor: fx.Tensor, + i32_tokens: fx.Int32, + i32_moe_buf_elems: fx.Int32, + ): + bid = gpu.block_idx.x + tid = gpu.thread_idx.x + lane = tid % WARP_SIZE + wave = tid // WARP_SIZE + tokens = i32_tokens + c_zero_i32 = fx.Int32(0) + c_one_i32 = fx.Int32(1) + c_oob_idx = fx.Int32(0x7FFFFFFF) + c4_i32 = fx.Int32(4) + + moe_buf_rsrc = buffer_ops.create_buffer_resource(moe_buf, max_size=True) + sorted_ids_rsrc = buffer_ops.create_buffer_resource(sorted_token_ids, max_size=True) + sorted_w_rsrc = buffer_ops.create_buffer_resource(sorted_weights_out, max_size=True) + sorted_e_rsrc = buffer_ops.create_buffer_resource(sorted_expert_ids, max_size=True) + nvalid_rsrc = buffer_ops.create_buffer_resource(num_valid_ids, max_size=True) + mask_rsrc = buffer_ops.create_buffer_resource(expert_mask_tensor, max_size=True) + + base_ptr = allocator.get_base() + cumsum_mr = SmemPtr(base_ptr, cumsum_offset, T.i32, shape=(smem_cols,)).get() + cumdup_mr = SmemPtr(base_ptr, cumdup_offset, T.i32, shape=(smem_cols,)).get() + mesh_mr = SmemPtr(base_ptr, mesh_offset, T.i32, shape=(sub_tokens * smem_cols,)).get() + weights_lds_mr = SmemPtr( + base_ptr, + weights_lds_offset, + T.i32, + shape=(max_tokens * topk,), + ).get() + + c_topk = fx.Int32(topk) + c_E = fx.Int32(E) + c_unit = fx.Int32(unit_size) + c_sub_tokens = fx.Int32(sub_tokens) + c_smem_cols = fx.Int32(smem_cols) + c_sentinel = fx.Int32((topk << 24)) + + # =================== MOE_BUF ZEROING (blocks > 0 only) =============== + is_zero_block = bid != c_zero_i32 + if 
is_zero_block: + zero_gid_v4 = (bid - c_one_i32) * fx.Int32(BLOCK_SIZE) + tid + num_zero_blocks = gpu.grid_dim.x - c_one_i32 + zero_stride_v4 = num_zero_blocks * fx.Int32(BLOCK_SIZE) + i32_moe_buf_v4 = i32_moe_buf_elems >> fx.Int32(2) + zero_niters = (i32_moe_buf_v4 + zero_stride_v4 - c_one_i32) // zero_stride_v4 + _zs = fx.Index(0) + _ze = ArithValue(zero_niters).index_cast(T.index) + _z1 = fx.Index(1) + c_zero_v4 = fx.Vector.filled(4, 0, fx.Int32) + c4_i32 = fx.Int32(4) + for _z in range(_zs, _ze, _z1): + z_idx_v4 = zero_gid_v4 + fx.Int32(_z) * zero_stride_v4 + z_valid = z_idx_v4 < i32_moe_buf_v4 + z_elem = z_valid.select(z_idx_v4 * c4_i32, c_oob_idx) + buffer_ops.buffer_store(c_zero_v4, moe_buf_rsrc, z_elem) + + # =================== SORTING (block 0 only) ========================== + is_sort_block = bid == c_zero_i32 + if is_sort_block: + # ========== PHASE 1 (mesh CLEAR ONLY in fused kernel) ============ + # Gating's on_winner_idx callback fills `mesh_LDS` directly, so + # the old "Phase 1 fill" loop is gone. The clear must still + # happen first (gating only writes winning cells; Phase 2 reads + # every cell and treats 0 as "no token"). + for i_clear in range_constexpr(0, sub_tokens * smem_cols, BLOCK_SIZE): + idx = fx.Int32(i_clear) + tid + is_valid = idx < fx.Int32(sub_tokens * smem_cols) + safe_idx = is_valid.select(idx, c_zero_i32) + safe_idx_ix = ArithValue(safe_idx).index_cast(T.index) + _lds_store_raw(mesh_mr, c_zero_i32, safe_idx_ix) + # Make the mesh clear visible to the gating callback writes + # (gating's leader lanes update individual cells immediately + # after this barrier). + gpu.barrier() + + # =========== PHASE 0 (fused): Gating + softmax + top-K =========== + # `local_token` is the per-block 0..TOKENS_PER_BLOCK-1 index. + # For block 0 the global token equals local_token, but we use + # local_token for the LDS indices (mesh / weights_lds rows). + # + # mesh_LDS layout: `[sub_tokens, smem_cols]` row-major. 
Each + # winning cell is set to `k_idx + 1` (so 0 still means "empty"). + # + # weights_LDS layout: `[max_tokens, topk]` row-major. Stored as + # i32 bits (we bitcast the f32 weight; Phase 3 stores those + # bits to sorted_w which is also accessed via the i32 path). + def on_winner_idx(local_token, global_token, k_int, expert_idx): + mesh_idx = local_token * c_smem_cols + expert_idx + _lds_store_raw(mesh_mr, fx.Int32(k_int + 1), mesh_idx) + + def on_winner_weight(local_token, global_token, k_int, weight): + w_bits = ArithValue(weight).bitcast(T.i32) + w_idx = local_token * c_topk + fx.Int32(k_int) + _lds_store_raw(weights_lds_mr, fx.Int32(w_bits), w_idx) + + _emit_topk_gating_softmax_body( + gating_logits, + None, # TopkWeights HBM not used in fused + None, # TopkIndices HBM not used in fused + None, # TokenExpertIndices HBM not used in fused + i32_tokens, + num_experts=num_experts, + topk=topk, + dtype_str=dtype_str, + renormalize=renormalize, + on_winner_idx=on_winner_idx, + on_winner_weight=on_winner_weight, + emit_tei=False, + **gating_layout, + ) + # Make the gating LDS writes visible to Phase 2 (different + # threads will read those mesh / weights_LDS cells). 
+ gpu.barrier() + + # ===================== PHASE 2: Count + Prefix Sum ===================== + c_lane_group_sz = fx.Int32(8) + lane_group_id = tid // c_lane_group_sz + lane_group_os = tid % c_lane_group_sz + width8_i32 = fx.Int32(8) + + is_t0 = tid == c_zero_i32 + _lds_store_raw(cumsum_mr, c_zero_i32, c_zero_i32) + gpu.barrier() + + for i_e in range_constexpr(0, E, BLOCK_SIZE // 8): + eid_local = fx.Int32(i_e) + lane_group_id + eid_valid = eid_local < c_E + + cnt = c_zero_i32 + for i_sub in range_constexpr(0, sub_tokens, 8): + sub_idx = fx.Int32(i_sub) + lane_group_os + sub_valid = sub_idx < c_sub_tokens + combined_valid = eid_valid & sub_valid + + safe_sub = combined_valid.select(sub_idx, c_zero_i32) + safe_eid = combined_valid.select(eid_local, c_zero_i32) + mesh_rd_addr = safe_sub * c_smem_cols + safe_eid + mesh_rd_ix = ArithValue(mesh_rd_addr).index_cast(T.index) + mesh_val = _lds_load_raw(mesh_mr, mesh_rd_ix) + + has_token = combined_valid.select( + (mesh_val != c_zero_i32).select(c_one_i32, c_zero_i32), + c_zero_i32, + ) + + reduced = has_token + for sh in range_constexpr(3): + off = fx.Int32(1 << sh) + peer = reduced.shuffle_xor(off, width8_i32) + reduced = reduced + peer + cnt = cnt + reduced + + write_valid = eid_valid & (lane_group_os == c_zero_i32) + cs_idx = write_valid.select(eid_local + c_one_i32, c_zero_i32) + cs_ix = ArithValue(cs_idx).index_cast(T.index) + cs_val = write_valid.select(cnt, c_zero_i32) + _lds_store_raw(cumsum_mr, cs_val, cs_ix) + gpu.barrier() + + for i_cvt in range_constexpr(0, E, BLOCK_SIZE): + cvt_eid = fx.Int32(i_cvt) + tid + cvt_valid = cvt_eid < c_E + safe_cvt_idx = cvt_valid.select(cvt_eid + c_one_i32, c_zero_i32) + cvt_ix = ArithValue(safe_cvt_idx).index_cast(T.index) + raw_cnt_cvt = _lds_load_raw(cumsum_mr, cvt_ix) + blocks_cvt = (raw_cnt_cvt + c_unit - c_one_i32) // c_unit + padded_cvt = (raw_cnt_cvt == c_zero_i32).select(c_zero_i32, blocks_cvt * c_unit) + _lds_store_raw(cumsum_mr, cvt_valid.select(padded_cvt, c_zero_i32), 
cvt_ix) + gpu.barrier() + + if has_mask: + for i_ep in range_constexpr(0, E, BLOCK_SIZE): + ep_eid = fx.Int32(i_ep) + tid + ep_valid = ep_eid < c_E + ep_safe_eid = ep_valid.select(ep_eid, c_zero_i32) + ep_m = buffer_ops.buffer_load(mask_rsrc, ep_safe_eid, vec_width=1, dtype=T.i32) + should_zero = ep_valid & (ep_m == c_zero_i32) + ep_cs_ix = ArithValue(ep_valid.select(ep_eid + c_one_i32, c_zero_i32)).index_cast(T.index) + _lds_store_raw( + cumsum_mr, should_zero.select(c_zero_i32, _lds_load_raw(cumsum_mr, ep_cs_ix)), ep_cs_ix + ) + gpu.barrier() + + is_wave0 = wave == c_zero_i32 + prev_chunk_total = c_zero_i32 + + for chunk_start in range_constexpr(0, E, WARP_SIZE): + eid_ps = fx.Int32(chunk_start) + lane + eid_ps_valid = is_wave0 & (eid_ps < c_E) + safe_eid_ps = eid_ps_valid.select(eid_ps + c_one_i32, c_zero_i32) + ps_ix = ArithValue(safe_eid_ps).index_cast(T.index) + val = eid_ps_valid.select(_lds_load_raw(cumsum_mr, ps_ix), c_zero_i32) + + val = _dpp_intra_wave_prefix_sum(val, lane, WARP_SIZE) + val = val + prev_chunk_total + + _lds_store_raw( + cumdup_mr, eid_ps_valid.select(val, c_zero_i32), eid_ps_valid.select(eid_ps + c_one_i32, c_zero_i32) + ) + + last_addr = fx.Int32((WARP_SIZE - 1) * 4) + prev_chunk_total = fly_rocdl.ds_bpermute(T.i32, last_addr, val) + prev_chunk_total = fx.Int32(prev_chunk_total) + + _lds_store_raw(cumdup_mr, is_t0.select(c_zero_i32, _lds_load_raw(cumdup_mr, c_zero_i32)), c_zero_i32) + gpu.barrier() + + cs_E_ix_ps = ArithValue(c_E).index_cast(T.index) + total_padded = _lds_load_raw(cumdup_mr, cs_E_ix_ps) + buffer_ops.buffer_store(total_padded, nvalid_rsrc, c_zero_i32) + buffer_ops.buffer_store(tokens, nvalid_rsrc, c_one_i32) + gpu.barrier() + + for i_cp in range_constexpr(0, E + 1, BLOCK_SIZE): + cp_idx = fx.Int32(i_cp) + tid + cp_valid = cp_idx <= c_E + safe_cp_idx = cp_valid.select(cp_idx, c_zero_i32) + cp_ix = ArithValue(safe_cp_idx).index_cast(T.index) + cp_val = _lds_load_raw(cumdup_mr, cp_ix) + _lds_store_raw(cumsum_mr, cp_val, 
cp_ix) + gpu.barrier() + + if has_mask: + for i_ml in range_constexpr(0, E, BLOCK_SIZE): + ml_eid = fx.Int32(i_ml) + tid + ml_valid = ml_eid < c_E + safe_ml_eid = ml_valid.select(ml_eid, c_zero_i32) + ml_mask = buffer_ops.buffer_load(mask_rsrc, safe_ml_eid, vec_width=1, dtype=T.i32) + ml_val = ml_valid.select(ml_mask, c_zero_i32) + ml_ix = ArithValue(ml_valid.select(ml_eid + c_one_i32, c_zero_i32)).index_cast(T.index) + _lds_store_raw(cumdup_mr, ml_val, ml_ix) + _lds_store_raw(cumdup_mr, is_t0.select(c_zero_i32, _lds_load_raw(cumdup_mr, c_zero_i32)), c_zero_i32) + gpu.barrier() + + prev_chunk_total_m = c_zero_i32 + for chunk_start_m in range_constexpr(0, E, WARP_SIZE): + eid_m = fx.Int32(chunk_start_m) + lane + eid_m_valid = is_wave0 & (eid_m < c_E) + safe_eid_m = eid_m_valid.select(eid_m + c_one_i32, c_zero_i32) + m_ix = ArithValue(safe_eid_m).index_cast(T.index) + mval = eid_m_valid.select(_lds_load_raw(cumdup_mr, m_ix), c_zero_i32) + + mval = _dpp_intra_wave_prefix_sum(mval, lane, WARP_SIZE) + mval = mval + prev_chunk_total_m + _lds_store_raw( + cumdup_mr, + eid_m_valid.select(mval, c_zero_i32), + eid_m_valid.select(eid_m + c_one_i32, c_zero_i32), + ) + + last_addr_m = fx.Int32((WARP_SIZE - 1) * 4) + prev_chunk_total_m = fly_rocdl.ds_bpermute(T.i32, last_addr_m, mval) + prev_chunk_total_m = fx.Int32(prev_chunk_total_m) + + _lds_store_raw(cumdup_mr, is_t0.select(c_zero_i32, _lds_load_raw(cumdup_mr, c_zero_i32)), c_zero_i32) + gpu.barrier() + else: + for i_ml in range_constexpr(0, E, BLOCK_SIZE): + ml_eid = fx.Int32(i_ml) + tid + ml_valid = ml_eid < c_E + safe_ml_eid = ml_valid.select(ml_eid, c_zero_i32) + ml_ix = ArithValue(safe_ml_eid).index_cast(T.index) + _lds_store_raw(cumdup_mr, ml_valid.select(safe_ml_eid, c_zero_i32), ml_ix) + gpu.barrier() + + for i_eid in range_constexpr(0, E, BLOCK_SIZE): + eid_wr = fx.Int32(i_eid) + tid + eid_wr_valid = eid_wr < c_E + safe_eid_wr = eid_wr_valid.select(eid_wr, c_zero_i32) + + cs_start_ix = 
ArithValue(safe_eid_wr).index_cast(T.index) + cs_end_ix = ArithValue(safe_eid_wr + c_one_i32).index_cast(T.index) + e_start = _lds_load_raw(cumsum_mr, cs_start_ix) + e_end = eid_wr_valid.select(_lds_load_raw(cumsum_mr, cs_end_ix), e_start) + local_eid = _lds_load_raw(cumdup_mr, cs_start_ix) + + _lds_store_raw(cumdup_mr, e_start, cs_start_ix) + + blk_start = e_start // c_unit + blk_end = e_end // c_unit + for j_blk in range_constexpr(max_tokens): + blk_idx = blk_start + fx.Int32(j_blk) + blk_valid = eid_wr_valid & (blk_idx < blk_end) + safe_blk = blk_valid.select(blk_idx, c_oob_idx) + buffer_ops.buffer_store(local_eid, sorted_e_rsrc, safe_blk) + gpu.barrier() + + cs_E_ix = ArithValue(c_E).index_cast(T.index) + cumE = _lds_load_raw(cumsum_mr, cs_E_ix) + _lds_store_raw(cumdup_mr, cumE, cs_E_ix) + gpu.barrier() + + # ====================== PHASE 3: Scatter ============================== + for i_e2 in range_constexpr(0, E, BLOCK_SIZE // 8): + eid_sc = fx.Int32(i_e2) + lane_group_id + eid_sc_valid = eid_sc < c_E + safe_eid_sc = eid_sc_valid.select(eid_sc, c_E) + + sc_expert_enabled = eid_sc_valid + if has_mask: + sc_mask_val = buffer_ops.buffer_load( + mask_rsrc, eid_sc_valid.select(eid_sc, c_zero_i32), vec_width=1, dtype=T.i32 + ) + sc_expert_enabled = eid_sc_valid & (sc_mask_val != c_zero_i32) + + cs_sc_ix = ArithValue(safe_eid_sc).index_cast(T.index) + position = _lds_load_raw(cumsum_mr, cs_sc_ix) + + for i_sub2 in range_constexpr(0, sub_tokens, 8): + my_sub = fx.Int32(i_sub2) + lane_group_os + my_sub_valid = sc_expert_enabled & (my_sub < c_sub_tokens) + safe_my_sub = my_sub_valid.select(my_sub, c_zero_i32) + my_mesh_addr = safe_my_sub * c_smem_cols + safe_eid_sc + my_mesh_ix = ArithValue(my_mesh_addr).index_cast(T.index) + my_x = _lds_load_raw(mesh_mr, my_mesh_ix) + my_has_token = my_sub_valid & (my_x != c_zero_i32) + local_cnt = my_has_token.select(c_one_i32, c_zero_i32) + + cnt_raw = _unwrap_val(local_cnt) + zero_raw = _unwrap_val(c_zero_i32) + + remote = 
fly_rocdl.update_dpp( + T.i32, zero_raw, cnt_raw, DPP_ROW_SHR_1, DPP_ROW_MASK, DPP_BANK_MASK, True + ) + should_add = lane_group_os >= c_one_i32 + local_cnt = should_add.select(local_cnt + fx.Int32(remote), local_cnt) + + cnt_raw = _unwrap_val(local_cnt) + remote = fly_rocdl.update_dpp( + T.i32, zero_raw, cnt_raw, DPP_ROW_SHR_2, DPP_ROW_MASK, DPP_BANK_MASK, True + ) + should_add = lane_group_os >= fx.Int32(2) + local_cnt = should_add.select(local_cnt + fx.Int32(remote), local_cnt) + + cnt_raw = _unwrap_val(local_cnt) + remote = fly_rocdl.update_dpp( + T.i32, zero_raw, cnt_raw, DPP_ROW_SHR_4, DPP_ROW_MASK, DPP_BANK_MASK, True + ) + should_add = lane_group_os >= fx.Int32(4) + local_cnt = should_add.select(local_cnt + fx.Int32(remote), local_cnt) + + last_lane_of_group = tid | fx.Int32(7) + last_addr = last_lane_of_group * c4_i32 + batch_total = fly_rocdl.ds_bpermute(T.i32, last_addr, local_cnt) + batch_total = fx.Int32(batch_total) + + slot = position + local_cnt - c_one_i32 + safe_x = my_has_token.select(my_x, c_one_i32) + topk_slot_sc = safe_x - c_one_i32 + packed_id = (topk_slot_sc << fx.Int32(24)) | my_sub + safe_slot = my_has_token.select(slot, c_oob_idx) + buffer_ops.buffer_store(packed_id, sorted_ids_rsrc, safe_slot) + + # Fused: weight comes from LDS (gating staged it there) + # instead of HBM. `my_sub` is the per-token local index + # 0..max_tokens-1; topk_slot_sc identifies the K rank. 
+ w_lds_idx = my_has_token.select(my_sub * c_topk + topk_slot_sc, c_zero_i32) + w_val_i32 = _lds_load_raw(weights_lds_mr, w_lds_idx) + buffer_ops.buffer_store(w_val_i32, sorted_w_rsrc, safe_slot) + + position = position + batch_total + + _lds_store_raw(cumsum_mr, position, cs_sc_ix) + gpu.barrier() + + sentinel_val = c_sentinel | tokens + c_zero_as_i32 = c_zero_i32 + for i_pad in range_constexpr(0, E, BLOCK_SIZE): + eid_pad = fx.Int32(i_pad) + tid + pad_valid = eid_pad < c_E + safe_eid_pad = pad_valid.select(eid_pad, c_zero_i32) + + cs_pad_ix = ArithValue(safe_eid_pad).index_cast(T.index) + cdp_ix = ArithValue(safe_eid_pad + c_one_i32).index_cast(T.index) + pad_start = _lds_load_raw(cumsum_mr, cs_pad_ix) + pad_end = pad_valid.select(_lds_load_raw(cumdup_mr, cdp_ix), pad_start) + + for j_pad in range_constexpr(unit_size): + pad_slot = pad_start + fx.Int32(j_pad) + pad_slot_valid = pad_valid & (pad_slot < pad_end) + safe_pad_slot = pad_slot_valid.select(pad_slot, c_oob_idx) + buffer_ops.buffer_store(sentinel_val, sorted_ids_rsrc, safe_pad_slot) + buffer_ops.buffer_store(c_zero_as_i32, sorted_w_rsrc, safe_pad_slot) + + @flyc.jit + def launch_moe_sorting_oneshot_fused( + gating_logits: fx.Tensor, + sorted_token_ids: fx.Tensor, + sorted_weights_out: fx.Tensor, + sorted_expert_ids: fx.Tensor, + num_valid_ids_out: fx.Tensor, + moe_buf: fx.Tensor, + expert_mask_tensor: fx.Tensor, + i32_tokens: fx.Int32, + i32_moe_buf_elems: fx.Int32, + n_grid_blocks: fx.Constexpr[int], + stream: fx.Stream = fx.Stream(None), + ): + allocator.finalized = False + ctx = CompilationContext.get_current() + with ir.InsertionPoint(ctx.gpu_module_body): + allocator.finalize() + + launcher = moe_sorting_oneshot_fused_kernel( + gating_logits, + sorted_token_ids, + sorted_weights_out, + sorted_expert_ids, + num_valid_ids_out, + moe_buf, + expert_mask_tensor, + i32_tokens, + i32_moe_buf_elems, + ) + launcher.launch( + grid=(n_grid_blocks, 1, 1), + block=(BLOCK_SIZE, 1, 1), + stream=stream, + ) + + 
def _supports_fused_oneshot(num_experts: int, topk: int, dtype_str: str) -> bool:
    """Return whether the fused oneshot kernel supports this configuration.

    The decision is delegated to `_compute_topk_gating_layout`: the
    configuration is supported exactly when that helper does not raise
    ``ValueError`` (it raises when no valid (VPT, THREADS_PER_TOKEN) pair
    exists for `num_experts`, or when `topk > num_experts`).
    """
    try:
        _compute_topk_gating_layout(num_experts, topk, dtype_str)
    except ValueError:
        return False
    return True


def _alloc_topk_fallback(device, M, topk, num_experts):
    """Allocate / cache HBM buffers used by the unfused 2-kernel fallback.

    Returns a ``(topk_weights, topk_ids, tei)`` tuple:

    * ``topk_weights`` — f32, shape ``[M, topk]``
    * ``topk_ids``     — i32, shape ``[M, topk]``
    * ``tei``          — i32, shape ``[M, topk]`` (written by the gating
      kernel; this module does not read it afterwards)

    The cache is keyed on ``(device, M, topk)``, so every distinct ``M``
    gets its own set of buffers and repeated calls with the same key return
    the *same* tensors. ``num_experts`` is accepted for call-site symmetry
    but does not influence the allocation.
    """
    key = (device, M, topk)
    cached = _topk_fallback_scratch_cache.get(key)
    if cached is not None:
        return cached

    topk_weights = torch.empty((M, topk), dtype=torch.float32, device=device)
    topk_ids = torch.empty((M, topk), dtype=torch.int32, device=device)
    tei = torch.empty((M, topk), dtype=torch.int32, device=device)

    cached = (topk_weights, topk_ids, tei)
    _topk_fallback_scratch_cache[key] = cached
    return cached


def moe_softmax_sort_flydsl(
    gating_logits,
    sorted_ids,
    sorted_weights,
    sorted_expert_ids,
    num_valid_ids,
    moe_buf,
    num_experts,
    topk,
    dtype_str,
    unit_size=UNIT_SIZE,
    expert_mask=None,
    renormalize=True,
    num_local_tokens=None,
):
    """Fused entry point: gating logits → softmax → top-K → sort.

    For small M (M ≤ FUSED_ONESHOT_MAX_T) with a supported
    (num_experts, topk, dtype) layout, runs a single fused kernel that
    produces the sort outputs directly from
    `gating_logits[M, num_experts]`. Otherwise falls back to the original
    2-kernel chain (`topk_gating_softmax` then `moe_sorting_flydsl`).

    The fused path is additionally restricted to
    ``M <= TOKENS_PER_BLOCK`` of the gating layout, because the fused
    kernel runs the gating body in a single block.

    The output tensor contract is identical to `moe_sorting_flydsl` so this
    function is a drop-in replacement at the next-higher layer that wants
    to skip the explicit gating step.

    Parameters
    ----------
    gating_logits : (M, num_experts) bf16/f16/f32 router logits.
    sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf
                    : pre-allocated output tensors (same layout/contract as
                      `moe_sorting_flydsl`).
    num_experts     : E (e.g. 256 for DeepSeek R1).
    topk            : K (e.g. 8 for DeepSeek R1).
    dtype_str       : 'f32' | 'f16' | 'bf16' — must match `gating_logits.dtype`.
    unit_size       : GEMM tile-M alignment (default `UNIT_SIZE`).
    expert_mask     : optional EP mask (int32 [E], 1=enabled / 0=masked).
    renormalize     : if True, rescale top-K weights to sum to 1 (matches
                      DeepSeek V3 / vLLM convention).
    num_local_tokens: optional override for the dynamic M (matches the
                      `moe_sorting_flydsl` convention). May be an int or a
                      tensor; a tensor is read with ``.item()``.

    Returns
    -------
    ``(sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf)``
    — the caller-provided output tensors, filled in place.
    """
    if num_local_tokens is not None:
        M = num_local_tokens.item() if isinstance(num_local_tokens, torch.Tensor) else int(num_local_tokens)
    else:
        M = gating_logits.shape[0]

    device = gating_logits.device
    moe_buf_i32 = moe_buf.view(torch.int32)
    moe_buf_elems = moe_buf_i32.numel()

    has_mask = expert_mask is not None
    if not has_mask:
        # The kernel signature always takes a mask tensor; hand it a cached
        # one-element dummy when the caller did not supply a mask.
        mask_tensor = _dummy_mask_cache.get(device)
        if mask_tensor is None:
            mask_tensor = torch.ones(1, dtype=torch.int32, device=device)
            _dummy_mask_cache[device] = mask_tensor
    else:
        mask_tensor = expert_mask

    sub_tokens = _compute_sub_tokens(num_experts)
    FUSED_ONESHOT_MAX_T = 16

    fusion_ok = M <= min(sub_tokens, FUSED_ONESHOT_MAX_T) and topk <= num_experts
    if fusion_ok:
        try:
            gating_layout = _compute_topk_gating_layout(num_experts, topk, dtype_str)
        except ValueError:
            # Same criterion as `_supports_fused_oneshot`: no valid layout.
            fusion_ok = False
        else:
            # The fused kernel emits the gating body in block 0 only, and the
            # gating body maps tokens as
            # `global_token = bid * TOKENS_PER_BLOCK + local_token` with
            # `local_token` in 0..TOKENS_PER_BLOCK-1. Tokens at or beyond
            # TOKENS_PER_BLOCK would therefore never be gated by the fused
            # kernel, so route those shapes to the 2-kernel fallback.
            # NOTE(review): based on the visible token mapping; confirm the
            # gating body has no internal per-block token loop.
            fusion_ok = M <= gating_layout["TOKENS_PER_BLOCK"]

    if fusion_ok:
        # Round the LDS staging capacity up to a multiple of 8 (minimum 8).
        max_tokens = max(M, 8)
        max_tokens = ((max_tokens + 7) // 8) * 8

        target_occupancy = 2
        num_cu = torch.cuda.get_device_properties(device).multi_processor_count
        n_zero_blocks = min(
            (moe_buf_elems + BLOCK_SIZE - 1) // BLOCK_SIZE,
            num_cu * target_occupancy,
        )
        # Block 0 sorts; every other block zeroes a slice of moe_buf.
        n_grid_blocks = 1 + n_zero_blocks

        launch_moe_sorting_oneshot_fused_path(
            gating_logits,
            sorted_ids,
            sorted_weights,
            sorted_expert_ids,
            num_valid_ids,
            moe_buf_i32,
            mask_tensor,
            M,
            moe_buf_elems,
            n_grid_blocks,
            num_experts=num_experts,
            topk=topk,
            dtype_str=dtype_str,
            renormalize=renormalize,
            max_tokens=max_tokens,
            unit_size=unit_size,
            has_mask=has_mask,
        )
    else:
        # Fallback: run gating and sort as two separate kernels.
        topk_weights, topk_ids, tei = _alloc_topk_fallback(device, M, topk, num_experts)

        builder_key = (num_experts, topk, dtype_str, renormalize)
        launch_topk = _topk_fallback_builder_cache.get(builder_key)
        if launch_topk is None:
            # Local import avoids forcing every caller to know about the
            # builder; the cache keeps the import cost out of the hot path.
            from kernels.topk_gating_softmax_kernel import (
                build_topk_gating_softmax_module,
            )

            launch_topk = build_topk_gating_softmax_module(
                num_experts=num_experts,
                topk=topk,
                dtype_str=dtype_str,
                renormalize=renormalize,
            )
            _topk_fallback_builder_cache[builder_key] = launch_topk

        stream = torch.cuda.current_stream()
        launch_topk(
            gating_logits,
            topk_weights,
            topk_ids,
            tei,
            M,
            stream=stream,
        )
        moe_sorting_flydsl(
            topk_ids,
            topk_weights,
            sorted_ids,
            sorted_weights,
            sorted_expert_ids,
            num_valid_ids,
            moe_buf,
            num_experts,
            unit_size=unit_size,
            expert_mask=expert_mask,
            num_local_tokens=M,
        )

    return sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf


def launch_moe_sorting_oneshot_fused_path(
    gating_logits,
    sorted_ids,
    sorted_weights,
    sorted_expert_ids,
    num_valid_ids,
    moe_buf_i32,
    expert_mask,
    i32_tokens,
    i32_moe_buf_elems,
    n_grid_blocks,
    *,
    num_experts,
    topk,
    dtype_str,
    renormalize=True,
    max_tokens=16,
    unit_size=UNIT_SIZE,
    has_mask=False,
):
    """Low-level launcher for the fused (gating + sort) oneshot path.

    LDS-direct: reads `gating_logits[M, E]` and emits the sort outputs
    in one kernel launch with no HBM intermediates. The launcher carries no
    scratch tensors — all per-token-K data is staged in LDS inside the
    kernel.

    On a cache hit the cached compiled function is called directly. On a
    miss the launcher from `compile_moe_sorting_oneshot_fused` is invoked
    once, then compiled with ``flyc.compile`` and stored in
    `_oneshot_fused_cf_cache` for subsequent calls. Returns ``None``.
    """
    # Positional arguments shared by the launcher, the compiled function and
    # `flyc.compile`; only the trailing stream argument differs between them.
    kernel_args = (
        gating_logits,
        sorted_ids,
        sorted_weights,
        sorted_expert_ids,
        num_valid_ids,
        moe_buf_i32,
        expert_mask,
        i32_tokens,
        i32_moe_buf_elems,
        n_grid_blocks,
    )

    # The device is part of the key (as documented for the unfused oneshot
    # cache) so a function compiled on one device is never reused on another.
    cache_key = (
        num_experts,
        topk,
        max_tokens,
        unit_size,
        n_grid_blocks,
        dtype_str,
        renormalize,
        has_mask,
        gating_logits.device,
    )
    cf = _oneshot_fused_cf_cache.get(cache_key)
    if cf is not None:
        stream = torch.cuda.current_stream()
        cf(*kernel_args, fx.Stream(stream))
        return

    launch_fn = compile_moe_sorting_oneshot_fused(
        num_experts=num_experts,
        topk=topk,
        dtype_str=dtype_str,
        renormalize=renormalize,
        max_tokens=max_tokens,
        unit_size=unit_size,
        has_mask=has_mask,
    )
    stream = torch.cuda.current_stream()
    launch_fn(*kernel_args, stream=stream)

    cf = flyc.compile(launch_fn, *kernel_args, fx.Stream(stream))
    _oneshot_fused_cf_cache[cache_key] = cf
@contextmanager
def _if_then(if_op):
    """Enter the then-region of an explicit ``scf.IfOp``.

    Sets the insertion point to ``if_op.then_block`` for the duration of the
    ``with`` body and yields that block. On exit (including on error) the
    block is closed with an empty ``scf.YieldOp`` unless its last operation
    already is one.
    """
    then_block = if_op.then_block
    with ir.InsertionPoint(then_block):
        try:
            yield then_block
        finally:
            ops = if_op.then_block.operations
            already_terminated = bool(ops) and isinstance(ops[-1], scf.YieldOp)
            if not already_terminated:
                scf.YieldOp([])
+ Shared by the standalone kernel in this module and the fused oneshot + kernel in ``kernels/moe_sorting_kernel.py`` so the two paths can never + disagree on the layout. """ elem_bits = 32 if dtype_str == "f32" else 16 @@ -78,13 +94,12 @@ def build_topk_gating_softmax_module( f"layout: requires num_experts // VPT to be a power of 2 <= " f"WARP_SIZE={WARP_SIZE} for some VPT in [16, 8, 4, 2, 1]." ) + if topk > num_experts: + raise ValueError(f"topk={topk} > num_experts={num_experts}") TOKENS_PER_WARP = WARP_SIZE // THREADS_PER_TOKEN TOKENS_PER_BLOCK = WARPS_PER_BLOCK * TOKENS_PER_WARP - if topk > num_experts: - raise ValueError(f"topk={topk} > num_experts={num_experts}") - if elem_bits <= 16 and VPT % 8 == 0: ATOM_BITS = 128 # 8 bf16/f16 per atom call elif elem_bits <= 16 and VPT % 4 == 0: @@ -98,6 +113,324 @@ def build_topk_gating_softmax_module( ELEMS_PER_ATOM = ATOM_BITS // elem_bits ATOMS_PER_THREAD = VPT // ELEMS_PER_ATOM + return dict( + elem_bits=elem_bits, + VPT=VPT, + THREADS_PER_TOKEN=THREADS_PER_TOKEN, + TOKENS_PER_WARP=TOKENS_PER_WARP, + TOKENS_PER_BLOCK=TOKENS_PER_BLOCK, + ATOM_BITS=ATOM_BITS, + ELEMS_PER_ATOM=ELEMS_PER_ATOM, + ATOMS_PER_THREAD=ATOMS_PER_THREAD, + ) + + +@flyc.jit +def _emit_topk_gating_softmax_body( + GatingOutput, + TopkWeights, + TopkIndices, + TokenExpertIndices, + i32_num_tokens, + *, + num_experts: int, + topk: int, + dtype_str: str, + renormalize: bool, + VPT: int, + THREADS_PER_TOKEN: int, + TOKENS_PER_WARP: int, + TOKENS_PER_BLOCK: int, + ATOM_BITS: int, + ELEMS_PER_ATOM: int, + ATOMS_PER_THREAD: int, + elem_bits: int, + on_winner_idx=None, + on_winner_weight=None, + emit_tei: bool = True, +): + """Emit MLIR for gating logits → softmax → top-K into the current + ``@flyc.kernel`` insertion point. Used by the fused oneshot kernel + in ``kernels/moe_sorting_kernel.py``. + + Must be called from inside an ``@flyc.kernel`` so that ``fx.block_idx``, + ``fx.thread_idx``, buffer/copy-atom operations, etc. 
are valid in the + current tracing context. + + Output stores + ------------- + By default each leader lane writes the winning (weight, expert_idx, tei) + triples to the HBM tensors ``TopkWeights``, ``TopkIndices``, + ``TokenExpertIndices``. If ``on_winner_idx`` and/or ``on_winner_weight`` + callbacks are provided they replace the corresponding HBM store. Each + callback is invoked once per (token, k) inside the leader-active region + with signature:: + + on_winner_idx(local_token_i32, global_token_i32, k_int, expert_idx_i32) + on_winner_weight(local_token_i32, global_token_i32, k_int, weight_f32) + + ``local_token_i32`` is the token index within the block + (``0..TOKENS_PER_BLOCK-1``) and is suitable for indexing per-block LDS + staging buffers. + + If ``emit_tei=False`` the TEI HBM store is suppressed entirely and the + ``TokenExpertIndices`` tensor argument is unused (callers may pass + ``None``). Similarly, callers providing ``on_winner_idx`` / + ``on_winner_weight`` may pass ``None`` for the corresponding HBM + tensor — the buffer-resource + slice for that output is only + materialised when its HBM store is enabled. 
+ """ + bid = fx.block_idx.x + tid = fx.thread_idx.x + + elem_dtype = dtype_to_elem_type(dtype_str) + elem_type = elem_dtype.ir_type + compute_type = T.f32 + register_addr_space = int(fx.AddressSpace.Register) + + fm_fast = arith.FastMathFlags.fast + + c_zero_f = fx.Float32(0.0) + c_neg_inf = fx.Float32(float("-inf")) + c_log2e = fx.Float32(1.4426950408889634) + c_one_f = fx.Float32(1.0) + + c_warp = fx.Int32(WARP_SIZE) + c_tpt = fx.Int32(THREADS_PER_TOKEN) + c_tpw = fx.Int32(TOKENS_PER_WARP) + c_tpb = fx.Int32(TOKENS_PER_BLOCK) + c_vpt = fx.Int32(VPT) + + warp_id = tid // c_warp # 0..WARPS_PER_BLOCK-1 + lane = tid % c_warp # 0..WARP_SIZE-1 + token_in_warp = lane // c_tpt # 0..TOKENS_PER_WARP-1 + expert_lane = lane % c_tpt # 0..THREADS_PER_TOKEN-1 + local_token = warp_id * c_tpw + token_in_warp # 0..TOKENS_PER_BLOCK-1 + global_token = bid * c_tpb + local_token # token row + + in_range = global_token < i32_num_tokens + global_token_safe = in_range.select(global_token, fx.Int32(0)) + + def group_reduce(x, mode): + """Butterfly reduce within a THREADS_PER_TOKEN sub-warp group.""" + width_i32 = c_tpt + w = x + for _sh in range_constexpr(int(math.log2(THREADS_PER_TOKEN))): + off = fx.Int32(THREADS_PER_TOKEN // (2 << _sh)) + peer = w.shuffle_xor(off, width_i32) + if mode == "max": + w = w.maximumf(peer) + else: + w = w.addf(peer, fastmath=fm_fast) + return w + + def group_reduce_argmax(val, idx): + """Butterfly argmax within a THREADS_PER_TOKEN sub-warp group. + + All lanes in the group end with the same (max_val, max_idx). + Ties are broken by the lower expert index. 
+ """ + width_i32 = c_tpt + wv, wi = val, idx + for _sh in range_constexpr(int(math.log2(THREADS_PER_TOKEN))): + off = fx.Int32(THREADS_PER_TOKEN // (2 << _sh)) + peer_v = wv.shuffle_xor(off, width_i32) + peer_i = wi.shuffle_xor(off, width_i32) + is_greater = peer_v > wv + is_equal = ArithValue(peer_v) == ArithValue(wv) + peer_lower_idx = peer_i < wi + take_peer = is_greater | (is_equal & peer_lower_idx) + wv = take_peer.select(peer_v, wv) + wi = take_peer.select(peer_i, wi) + return wv, wi + + GatingOutput_buf = fx.rocdl.make_buffer_tensor(GatingOutput) + row_gating = fx.slice(GatingOutput_buf, (global_token_safe, None)) + gating_div = fx.logical_divide(row_gating, fx.make_layout(ELEMS_PER_ATOM, 1)) + + # Only materialise the output views/buffer-resources for the stores we + # actually emit. Callers supplying callbacks (V2 on-chip-sink mode) + # may pass `None` for the corresponding HBM tensors. + weights_div = None + if on_winner_weight is None: + TopkWeights_buf = fx.rocdl.make_buffer_tensor(TopkWeights) + row_weights = fx.slice(TopkWeights_buf, (global_token_safe, None)) + weights_div = fx.logical_divide(row_weights, fx.make_layout(1, 1)) + + indices_div = None + if on_winner_idx is None: + TopkIndices_buf = fx.rocdl.make_buffer_tensor(TopkIndices) + row_indices = fx.slice(TopkIndices_buf, (global_token_safe, None)) + indices_div = fx.logical_divide(row_indices, fx.make_layout(1, 1)) + + tei_div = None + if emit_tei: + TokenExpertIndices_buf = fx.rocdl.make_buffer_tensor(TokenExpertIndices) + row_tei = fx.slice(TokenExpertIndices_buf, (global_token_safe, None)) + tei_div = fx.logical_divide(row_tei, fx.make_layout(1, 1)) + + copy_atom_in = fx.make_copy_atom(fx.rocdl.BufferCopy(ATOM_BITS), elem_bits) + # Use the older fx.memref_alloca + explicit MemRefType API rather than + # the newer fx.make_rmem_tensor helper: the latter is missing from the + # pre-built flydsl shipped under `build-fly/python_packages/` (which + # pytest's conftest puts ahead of the source 
`python/flydsl/` on + # sys.path), so make_rmem_tensor only works in script mode. memref_alloca + # is present in both versions and produces equivalent IR. + atom_reg_ty_in = fx.MemRefType.get( + elem_type, + fx.LayoutType.get(ELEMS_PER_ATOM, 1), + register_addr_space, + ) + atom_reg_lay_in = fx.make_layout(ELEMS_PER_ATOM, 1) + + copy_atom_f32 = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), 32) + scalar_reg_ty_f32 = fx.MemRefType.get(T.f32, fx.LayoutType.get(1, 1), register_addr_space) + scalar_reg_lay = fx.make_layout(1, 1) + + def _load_atom_in(divided, atom_index): + """Load ELEMS_PER_ATOM contiguous elements starting at atom_index.""" + view = fx.slice(divided, (None, atom_index)) + r = fx.memref_alloca(atom_reg_ty_in, atom_reg_lay_in) + fx.copy_atom_call(copy_atom_in, view, r) + return fx.memref_load_vec(r) + + def _store_scalar_f32(divided, index, val): + r = fx.memref_alloca(scalar_reg_ty_f32, scalar_reg_lay) + v = fx.Vector.from_elements([val], fx.Float32) + fx.memref_store_vec(v, r) + view = fx.slice(divided, (None, index)) + fx.copy_atom_call(copy_atom_f32, r, view) + + def _store_scalar_i32(divided, index, val): + # `divided` is a logical_divide of a torch.float32-viewed buffer, + # so its element type is f32. Reinterpret the i32 bits as f32 and + # store via the f32 copy atom (avoids signed-vs-signless legalize + # failures when going through si32). 
+ val_f32 = ArithValue(val).bitcast(T.f32) + r = fx.memref_alloca(scalar_reg_ty_f32, scalar_reg_lay) + v = fx.Vector.from_elements([val_f32], fx.Float32) + fx.memref_store_vec(v, r) + view = fx.slice(divided, (None, index)) + fx.copy_atom_call(copy_atom_f32, r, view) + + # Pass 1: load this thread's VPT experts + per-thread max + col_idx_list = [] + for v in range_constexpr(VPT): + col_idx_list.append(expert_lane * c_vpt + fx.Int32(v)) + + c_atoms_pt = fx.Int32(ATOMS_PER_THREAD) + x_list = [] + thread_max = c_neg_inf + for a in range_constexpr(ATOMS_PER_THREAD): + atom_idx = expert_lane * c_atoms_pt + fx.Int32(a) + atom_vec = _load_atom_in(gating_div, atom_idx) + for v in range_constexpr(ELEMS_PER_ATOM): + val_e = vector.extract(atom_vec, static_position=[v]) + xv = val_e if dtype_str == "f32" else val_e.extf(compute_type) + x_list.append(xv) + thread_max = thread_max.maximumf(xv) + + group_max = group_reduce(thread_max, "max") + + # Pass 2: exp(x - max) and per-token sum + thread_sum = c_zero_f + exp_list = [] + for v in range_constexpr(VPT): + sub = x_list[v] - group_max + scaled = sub * c_log2e + ev = scaled.exp2(fastmath=fm_fast) + exp_list.append(ev) + thread_sum = thread_sum + ev + + group_sum = group_reduce(thread_sum, "sum") + + # Pass 3: normalise -> softmax probabilities (kept in registers) + inv_sum = c_one_f / group_sum + prob_list = [] + for v in range_constexpr(VPT): + prob_list.append(exp_list[v] * inv_sum) + + # Pass 4: iterative top-K (sub-warp argmax → mask) + selected_weights = [] # one f32 per k iter (replicated across the group) + selected_indices = [] # one i32 per k iter (replicated across the group) + selected_sum = c_zero_f + + for k_idx in range_constexpr(topk): + thread_best_val = c_neg_inf + thread_best_idx = fx.Int32(-1) + for v in range_constexpr(VPT): + pv = prob_list[v] + ci = col_idx_list[v] + is_better = pv > thread_best_val + thread_best_val = is_better.select(pv, thread_best_val) + thread_best_idx = is_better.select(ci, 
thread_best_idx) + + global_best_val, global_best_idx = group_reduce_argmax(thread_best_val, thread_best_idx) + + selected_weights.append(global_best_val) + selected_indices.append(global_best_idx) + selected_sum = selected_sum + global_best_val + + for v in range_constexpr(VPT): + ci = col_idx_list[v] + is_winner = ArithValue(ci) == ArithValue(global_best_idx) + prob_list[v] = is_winner.select(c_neg_inf, prob_list[v]) + + # Pass 5: leader writes weights/indices/tei (with optional renorm). + c_eps = fx.Float32(1e-20) + denom = selected_sum.maximumf(c_eps) + inv_denom = c_one_f / denom + + if (expert_lane == fx.Int32(0)) & (global_token < i32_num_tokens): + num_tokens_v = ArithValue(i32_num_tokens) + for k_idx in range_constexpr(topk): + w_val = selected_weights[k_idx] + if renormalize: + w_val = w_val * inv_denom + if on_winner_weight is not None: + on_winner_weight(local_token, global_token, k_idx, w_val) + else: + _store_scalar_f32(weights_div, Int32(k_idx), w_val) + + if on_winner_idx is not None: + on_winner_idx(local_token, global_token, k_idx, selected_indices[k_idx]) + else: + _store_scalar_i32(indices_div, Int32(k_idx), selected_indices[k_idx]) + + if emit_tei: + # tei[t, k] = k * num_tokens + t (matches vLLM convention). + tei_val = Int32(k_idx) * num_tokens_v + global_token + _store_scalar_i32(tei_div, Int32(k_idx), tei_val) + + +def build_topk_gating_softmax_module( + num_experts: int, + topk: int, + dtype_str: str = "bf16", + renormalize: bool = True, +): + """Build a fused TopK gating softmax kernel. + + Args: + num_experts: Number of MoE experts (columns in gating_output). + topk: Number of top experts to select per token. + dtype_str: Input data type ('f32', 'f16', 'bf16'). + renormalize: If True, rescale selected weights to sum to 1. + + Returns: + A @flyc.jit launcher function with signature + ``(gating, weights, indices, tei, num_tokens, *, stream)``. 
+ """ + layout = _compute_topk_gating_layout(num_experts, topk, dtype_str) + elem_bits = layout["elem_bits"] + VPT = layout["VPT"] + THREADS_PER_TOKEN = layout["THREADS_PER_TOKEN"] + TOKENS_PER_WARP = layout["TOKENS_PER_WARP"] + TOKENS_PER_BLOCK = layout["TOKENS_PER_BLOCK"] + ATOM_BITS = layout["ATOM_BITS"] + ELEMS_PER_ATOM = layout["ELEMS_PER_ATOM"] + ATOMS_PER_THREAD = layout["ATOMS_PER_THREAD"] + # No shared memory used — every reduction stays inside a sub-warp lane group. @flyc.kernel diff --git a/tests/kernels/test_moe_sorting.py b/tests/kernels/test_moe_sorting.py index 39286db02..caeaf3a81 100644 --- a/tests/kernels/test_moe_sorting.py +++ b/tests/kernels/test_moe_sorting.py @@ -35,8 +35,13 @@ from kernels.moe_sorting_kernel import ( # noqa: E402 UNIT_SIZE, + _supports_fused_oneshot, + moe_softmax_sort_flydsl, moe_sorting_flydsl, ) +from kernels.topk_gating_softmax_kernel import ( # noqa: E402 + build_topk_gating_softmax_module, +) WARMUP_ITERS = 3 RUN_BENCH = os.environ.get("MOE_SORTING_BENCH", "0") == "1" @@ -645,6 +650,234 @@ def test_moe_sorting_vs_aiter(T, E, topk): assert result, f"FlyDSL vs aiter mismatch for T={T}, E={E}, topk={topk}" +# --------------------------------------------------------------------------- +# Fused softmax+top-K+sort tests (moe_softmax_sort_flydsl) +# --------------------------------------------------------------------------- +_TORCH_DTYPE = {"f32": torch.float32, "f16": torch.float16, "bf16": torch.bfloat16} + + +def _call_softmax_sort_fused( + gating_logits, E, topk, dtype_str, *, model_dim=4096, unit_size=UNIT_SIZE, expert_mask=None, renormalize=True +): + """Allocate outputs and call moe_softmax_sort_flydsl. 
Mirrors + `_call_flydsl` but takes raw gating logits and dispatches through the + fused entry point.""" + M = gating_logits.shape[0] + max_padded = M * topk + E * unit_size - topk + max_blocks = (max_padded + unit_size - 1) // unit_size + device = gating_logits.device + s_ids = torch.empty(max_padded, dtype=torch.int32, device=device) + s_w = torch.empty(max_padded, dtype=torch.float32, device=device) + s_eids = torch.empty(max_blocks, dtype=torch.int32, device=device) + nv = torch.empty(2, dtype=torch.int32, device=device) + buf = torch.empty((M, model_dim), dtype=torch.bfloat16, device=device) + return moe_softmax_sort_flydsl( + gating_logits, + s_ids, + s_w, + s_eids, + nv, + buf, + E, + topk, + dtype_str, + unit_size=unit_size, + expert_mask=expert_mask, + renormalize=renormalize, + ) + + +def _two_kernel_reference( + gating_logits, E, topk, dtype_str, *, model_dim=4096, unit_size=UNIT_SIZE, expert_mask=None, renormalize=True +): + """Run gating + sort as two separate kernels; return the same output + tuple as the fused path. 
Used as the regression oracle for the fused + kernel: anything the fused kernel produces must match what these two + kernels produce on the same gating logits.""" + M = gating_logits.shape[0] + device = gating_logits.device + + topk_weights = torch.empty((M, topk), dtype=torch.float32, device=device) + topk_ids = torch.empty((M, topk), dtype=torch.int32, device=device) + tei = torch.empty((M, topk), dtype=torch.int32, device=device) + + launch_topk = build_topk_gating_softmax_module( + num_experts=E, + topk=topk, + dtype_str=dtype_str, + renormalize=renormalize, + ) + stream = torch.cuda.current_stream() + launch_topk(gating_logits, topk_weights, topk_ids, tei, M, stream=stream) + + max_padded = M * topk + E * unit_size - topk + max_blocks = (max_padded + unit_size - 1) // unit_size + s_ids = torch.empty(max_padded, dtype=torch.int32, device=device) + s_w = torch.empty(max_padded, dtype=torch.float32, device=device) + s_eids = torch.empty(max_blocks, dtype=torch.int32, device=device) + nv = torch.empty(2, dtype=torch.int32, device=device) + buf = torch.empty((M, model_dim), dtype=torch.bfloat16, device=device) + return moe_sorting_flydsl( + topk_ids, + topk_weights, + s_ids, + s_w, + s_eids, + nv, + buf, + E, + unit_size, + expert_mask=expert_mask, + ) + + +def _check_outputs_equal(ref_tuple, fused_tuple, *, topk, M, unit_size, label): + """Compare the 5-tuple outputs of the two paths. Returns True on success. + + `sorted_ids` may legitimately differ in order within each expert's + padded block (the sort is bag-of-tokens within an expert). We compare + set-equality per expert block and exact equality for everything else. 
+ """ + ref_ids, ref_w, ref_eids, ref_nv, ref_buf = ref_tuple + fused_ids, fused_w, fused_eids, fused_nv, fused_buf = fused_tuple + + passed = True + + nv_ok = torch.equal(ref_nv, fused_nv) + print( + f" [{label}/num_valid_ids] ref={ref_nv.tolist()} fused={fused_nv.tolist()} " f"({'OK' if nv_ok else 'FAIL'})" + ) + passed &= nv_ok + + num_padded = ref_nv[0].item() + passed &= check_sorted_ids(ref_ids, fused_ids, num_padded, topk, M, f"{label}/sorted_ids") + passed &= check_sorted_weights( + ref_w, + fused_w, + ref_ids, + topk, + M, + gpu_ids=fused_ids, + num_padded=num_padded, + label=f"{label}/sorted_weights", + ) + # Both paths leave the trailing blocks of `sorted_expert_ids` + # uninitialised (they're allocated via torch.empty in both the fused + # and reference launchers), so compare only the in-range blocks. + num_valid_blocks = num_padded // unit_size + passed &= check_expert_ids( + ref_eids, + fused_eids, + f"{label}/sorted_expert_ids", + num_valid_blocks=num_valid_blocks, + ) + + buf_zero = (fused_buf.view(torch.int32) == 0).all().item() + print(f" [{label}/moe_buf_zeroed] {'OK' if buf_zero else 'FAIL'}") + passed &= buf_zero + + return passed + + +def _run_softmax_sort_fused_test(T, E, topk, dtype_str, *, renormalize=True, unit_size=UNIT_SIZE, model_dim=4096): + """Generate gating logits, run both paths, compare. Returns bool.""" + print(f"\n{'=' * 60}") + print(f"Fused softmax_sort test: T={T}, E={E}, topk={topk}, " f"dtype={dtype_str}, renorm={renormalize}") + print(f"{'=' * 60}") + + torch.manual_seed(42 + T * 1000 + E * 10 + topk + hash(dtype_str) % 100) + torch_dtype = _TORCH_DTYPE[dtype_str] + + # Generate logits in fp32, then quantise to the kernel dtype so the + # reference path sees identical bytes. Without this, bf16/f16 boundary + # ties at the top-K cutoff can swing differently in fp32 vs quantised + # arithmetic and produce spurious mismatches. 
+ gating_fp32 = torch.rand((T, E), device="cuda", dtype=torch.float32) * 4.0 - 2.0 + gating_dev = gating_fp32.to(torch_dtype).contiguous() + + fused_out = _call_softmax_sort_fused( + gating_dev, + E, + topk, + dtype_str, + model_dim=model_dim, + unit_size=unit_size, + renormalize=renormalize, + ) + ref_out = _two_kernel_reference( + gating_dev, + E, + topk, + dtype_str, + model_dim=model_dim, + unit_size=unit_size, + renormalize=renormalize, + ) + torch.cuda.synchronize() + + passed = _check_outputs_equal( + ref_out, + fused_out, + topk=topk, + M=T, + unit_size=unit_size, + label=f"fused(T={T},E={E},k={topk},{dtype_str})", + ) + print(f" >>> {'PASSED' if passed else 'FAILED'}") + return passed + + +FUSED_ONESHOT_CONFIGS = [ + # (T, E, topk, dtype) + (1, 256, 8, "bf16"), + (4, 256, 8, "bf16"), + (8, 256, 8, "bf16"), + (16, 256, 8, "bf16"), + (1, 128, 4, "bf16"), + (8, 128, 4, "bf16"), + (1, 32, 5, "bf16"), + (4, 32, 5, "bf16"), + # dtype coverage + (8, 256, 8, "f16"), + (8, 256, 8, "f32"), + # Edge cases + (7, 256, 8, "bf16"), # M not a multiple of TOKENS_PER_BLOCK + (13, 256, 8, "bf16"), # arbitrary M < TOKENS_PER_BLOCK +] + + +@pytest.mark.parametrize("T,E,topk,dtype_str", FUSED_ONESHOT_CONFIGS) +def test_moe_softmax_sort_fused_oneshot(T, E, topk, dtype_str): + assert _supports_fused_oneshot(E, topk, dtype_str), ( + f"Test config {E=}/{topk=}/{dtype_str=} not supported by fused oneshot; " "check FUSED_ONESHOT_CONFIGS" + ) + passed = _run_softmax_sort_fused_test(T, E, topk, dtype_str) + assert passed, ( + f"moe_softmax_sort_flydsl fused oneshot mismatch for " f"T={T}, E={E}, topk={topk}, dtype={dtype_str}" + ) + + +@pytest.mark.parametrize( + "T,E,topk,dtype_str", + [ + # M > FUSED_ONESHOT_MAX_T forces the fused entry to take its 2-kernel fallback + # (separate gating + moe_sorting launches). The gating layout must + # still be supported. 
+ (32, 256, 8, "bf16"), + (64, 256, 8, "bf16"), + (128, 256, 8, "bf16"), + (1024, 256, 8, "bf16"), + ], +) +def test_moe_softmax_sort_fallback(T, E, topk, dtype_str): + """The fused-path entry must remain correct when it falls back to the + 2-kernel chain because M exceeds the oneshot bound.""" + passed = _run_softmax_sort_fused_test(T, E, topk, dtype_str) + assert passed, ( + f"moe_softmax_sort_flydsl fallback path mismatch for " f"T={T}, E={E}, topk={topk}, dtype={dtype_str}" + ) + + # --------------------------------------------------------------------------- # Benchmark utilities # --------------------------------------------------------------------------- @@ -683,6 +916,56 @@ def bench_eager_us(fn, warmup=BENCH_WARMUP, iters=BENCH_MEASURE, flush_l2=True): return latencies[len(latencies) // 2] +def bench_kernel_us(fn, warmup=BENCH_WARMUP, iters=BENCH_MEASURE): + """Pure on-device kernel time (per invocation, microseconds). + + Uses ``torch.profiler`` (CUPTI on CUDA, roctracer on ROCm) to capture + per-kernel begin/end timestamps from the GPU command processor itself. + The returned figure sums every GPU kernel that ran during ``iters`` + invocations of ``fn`` and divides by ``iters``. + + Compared to ``bench_graph_us``: + - ``bench_graph_us`` measures end-to-end CUDA-graph replay latency, + which still includes graph-replay overhead and any inter-kernel + dispatch gaps on the GPU command processor. + - ``bench_kernel_us`` measures only the wall time the GPU is actually + executing kernels — i.e. the floor on kernel runtime, with launch + and dispatch effects removed. + + For multi-kernel paths (e.g. unfused gating + sort) this returns the + sum of all per-kernel durations, which is the right comparison point + for fusion: it isolates how much on-device compute fusion saved, + independent of dispatch / scheduler effects. 
+ """ + try: + from torch.profiler import ProfilerActivity, profile + except ImportError: + return None + + for _ in range(warmup): + fn() + torch.cuda.synchronize() + + try: + with profile(activities=[ProfilerActivity.CUDA]) as prof: + for _ in range(iters): + fn() + torch.cuda.synchronize() + except Exception: + return None + + total_us = 0.0 + for k in prof.key_averages(): + # Sum events with non-zero on-device dwell. This naturally excludes + # host-side stubs like hipModuleLaunchKernel / cudaLaunchKernel and + # hipDeviceSynchronize / cudaDeviceSynchronize, whose self_device_time + # is 0 because no work runs on the GPU under their name. + sd = getattr(k, "self_device_time_total", 0.0) + if sd > 0: + total_us += sd + return total_us / iters + + def bench_graph_us(fn, warmup=BENCH_WARMUP, iters=BENCH_MEASURE): """CUDA graph benchmark — amortizes kernel launch overhead.""" for _ in range(warmup): @@ -722,8 +1005,8 @@ def run_bench_comparison(token_sweep=None): """Benchmark FlyDSL vs CK (aiter) across T values in eager and graph modes.""" try: from aiter.fused_moe import moe_sorting as aiter_moe_sorting - except ImportError: - print(" aiter not available, skipping CK comparison") + except (ImportError, AttributeError) as e: + print(f" aiter not available ({type(e).__name__}: {e}), skipping CK comparison") aiter_moe_sorting = None E, topk, model_dim = 256, 8, 4096 @@ -739,13 +1022,18 @@ def run_bench_comparison(token_sweep=None): print(f" Device: {torch.cuda.get_device_name(0)}") props = torch.cuda.get_device_properties(0) print(f" CUs: {props.multi_processor_count}, oneshot threshold: T<={sub_tokens}") - print(f" Modes: eager (with L2 flush, median of {BENCH_MEASURE}), graph ({BENCH_MEASURE} replays)") - print(f"{'=' * 110}") print( - f"{'T':>6s} | {'Path':>7s} | {'FLY eager':>10s} | {'FLY graph':>10s} | " - f"{'CK eager':>10s} | {'CK graph':>10s} | {'Eager':>7s} | {'Graph':>7s}" + f" Modes: eager (with L2 flush, median of {BENCH_MEASURE}), graph ({BENCH_MEASURE} 
replays), " + f"kernel (on-device dwell)" + ) + print(f"{'=' * 140}") + print( + f"{'T':>6s} | {'Path':>7s} | " + f"{'FLY eager':>10s} | {'FLY graph':>10s} | {'FLY kern':>10s} | " + f"{'CK eager':>10s} | {'CK graph':>10s} | {'CK kern':>10s} | " + f"{'Eager':>7s} | {'Graph':>7s} | {'Kern':>7s}" ) - print("-" * 110) + print("-" * 140) for T in token_sweep: torch.manual_seed(42) @@ -779,8 +1067,9 @@ def fly_fn(): fly_eager = bench_eager_us(fly_fn) fly_graph = bench_graph_us(fly_fn) + fly_kernel = bench_kernel_us(fly_fn) - ck_eager, ck_graph = None, None + ck_eager, ck_graph, ck_kernel = None, None, None if aiter_moe_sorting is not None: def ck_fn(): @@ -790,6 +1079,7 @@ def ck_fn(): ck_eager = bench_eager_us(ck_fn) ck_graph = bench_graph_us(ck_fn) + ck_kernel = bench_kernel_us(ck_fn) def fmt(v): return f"{v:8.1f}us" if v is not None else " N/A" @@ -801,13 +1091,161 @@ def ratio(a, b): return f" {r:.2f}x" print( - f"{T:>6d} | {path:>7s} | {fmt(fly_eager)} | {fmt(fly_graph)} | " - f"{fmt(ck_eager)} | {fmt(ck_graph)} | " - f"{ratio(fly_eager, ck_eager)} | {ratio(fly_graph, ck_graph)}" + f"{T:>6d} | {path:>7s} | " + f"{fmt(fly_eager)} | {fmt(fly_graph)} | {fmt(fly_kernel)} | " + f"{fmt(ck_eager)} | {fmt(ck_graph)} | {fmt(ck_kernel)} | " + f"{ratio(fly_eager, ck_eager)} | {ratio(fly_graph, ck_graph)} | " + f"{ratio(fly_kernel, ck_kernel)}" ) - print("=" * 110) - print(" Ratio < 1.0 = FlyDSL faster. Eager includes launch overhead. Graph amortizes it.") + print("=" * 140) + print(" Ratio < 1.0 = FlyDSL faster. Eager includes host launch overhead;") + print(" Graph amortizes launch but includes dispatch gaps; Kern is pure on-GPU kernel time.") + print() + + +def run_fused_bench_comparison(token_sweep=None, dtype_str="bf16", num_experts=256, topk=8, model_dim=4096): + """Benchmark fused softmax+top-K+sort vs the unfused 2-kernel chain. 
+ + Measures the routing stage (gating + sort) end-to-end across a range of + M values, in eager mode, CUDA-graph mode, and pure on-device kernel + time (profiler-based). + + Parameters + ---------- + token_sweep : list[int] | None + M values to sweep. Default: [1, 4, 8, 16, 32, 64]. + dtype_str : 'bf16' | 'f16' | 'f32'. + num_experts : E for the MoE router (e.g. 256 for DeepSeek R1). + topk : Experts per token (e.g. 8 for DeepSeek R1). + model_dim : Hidden size — sets `moe_buf` size, which controls the + blocks-1..N zero-pass cost portion of the kernel + (DeepSeek R1: 7168). + """ + E = num_experts + if token_sweep is None: + # Decode regime is where fusion applies. Include a few sizes above + # FUSED_ONESHOT_MAX_T=16 so the fallback path is also exercised. + token_sweep = [1, 4, 8, 16, 32, 64] + + torch_dtype = _TORCH_DTYPE[dtype_str] + + print(f"\n{'=' * 145}") + print( + f" MoE Fused Routing Benchmark: fused vs (gating + moe_sorting) " + f"(E={E}, topk={topk}, dtype={dtype_str}, " + f"model_dim={model_dim}, unit_size={UNIT_SIZE})" + ) + print(f" Device: {torch.cuda.get_device_name(0)}") + print( + f" Modes: eager (with L2 flush, median of {BENCH_MEASURE}), " + f"graph ({BENCH_MEASURE} replays), kernel (sum of on-device kernel time)" + ) + print(f"{'=' * 145}") + print( + f"{'T':>6s} | {'Path':>9s} | " + f"{'unfused eager':>14s} | {'fused eager':>13s} | " + f"{'unfused graph':>14s} | {'fused graph':>13s} | " + f"{'unfused kern':>13s} | {'fused kern':>12s} | " + f"{'Eager':>7s} | {'Graph':>7s} | {'Kern':>7s}" + ) + print("-" * 145) + + # Cache the standalone gating launcher so the unfused path doesn't pay + # compile time inside the measured region. 
+ launch_topk = build_topk_gating_softmax_module( + num_experts=E, + topk=topk, + dtype_str=dtype_str, + renormalize=True, + ) + + for T in token_sweep: + torch.manual_seed(42) + gating_logits = ( + (torch.rand((T, E), device="cuda", dtype=torch.float32) * 4.0 - 2.0).to(torch_dtype).contiguous() + ) + + max_num_tokens_padded = T * topk + E * UNIT_SIZE - topk + max_num_m_blocks = (max_num_tokens_padded + UNIT_SIZE - 1) // UNIT_SIZE + sorted_ids = torch.empty(max_num_tokens_padded, dtype=torch.int32, device="cuda") + sorted_w = torch.empty(max_num_tokens_padded, dtype=torch.float32, device="cuda") + sorted_eids = torch.empty(max_num_m_blocks, dtype=torch.int32, device="cuda") + nvalid = torch.empty(2, dtype=torch.int32, device="cuda") + moe_buf_2d = torch.empty((T, model_dim), dtype=torch.bfloat16, device="cuda") + + # Unfused: separate gating + sort tensors + u_topk_w = torch.empty((T, topk), dtype=torch.float32, device="cuda") + u_topk_ids = torch.empty((T, topk), dtype=torch.int32, device="cuda") + u_tei = torch.empty((T, topk), dtype=torch.int32, device="cuda") + + def unfused_fn(): + stream = torch.cuda.current_stream() + launch_topk(gating_logits, u_topk_w, u_topk_ids, u_tei, T, stream=stream) + moe_sorting_flydsl( + u_topk_ids, + u_topk_w, + sorted_ids, + sorted_w, + sorted_eids, + nvalid, + moe_buf_2d, + E, + UNIT_SIZE, + ) + + def fused_fn(): + moe_softmax_sort_flydsl( + gating_logits, + sorted_ids, + sorted_w, + sorted_eids, + nvalid, + moe_buf_2d, + E, + topk, + dtype_str, + unit_size=UNIT_SIZE, + ) + + # Warm up both paths once before measurement (covers compile cache). 
+ unfused_fn() + fused_fn() + torch.cuda.synchronize() + + unfused_eager = bench_eager_us(unfused_fn) + fused_eager = bench_eager_us(fused_fn) + unfused_graph = bench_graph_us(unfused_fn) + fused_graph = bench_graph_us(fused_fn) + unfused_kernel = bench_kernel_us(unfused_fn) + fused_kernel = bench_kernel_us(fused_fn) + + path = "fused" if T <= 16 else "fallback" + + def fmt(v, w=12): + return f"{v:{w}.1f}us" if v is not None else f"{'N/A':>{w + 2}s}" + + def ratio(unfused, fused): + if unfused is None or fused is None or fused == 0: + return " N/A" + # Speedup: how much faster is fused vs unfused? >1.0 = fused wins. + r = unfused / fused + return f" {r:.2f}x" + + print( + f"{T:>6d} | {path:>9s} | " + f"{fmt(unfused_eager)} | {fmt(fused_eager)} | " + f"{fmt(unfused_graph)} | {fmt(fused_graph)} | " + f"{fmt(unfused_kernel, 11)} | {fmt(fused_kernel, 10)} | " + f"{ratio(unfused_eager, fused_eager)} | " + f"{ratio(unfused_graph, fused_graph)} | " + f"{ratio(unfused_kernel, fused_kernel)}" + ) + + print("=" * 145) + print(" Ratio > 1.0 = fused faster. 
Eager includes host launch overhead (2 launches vs 1);") + print(" Graph amortizes launch but still includes inter-kernel dispatch gaps;") + print(" Kern is pure on-GPU kernel time (sum of per-kernel device dwell, via torch.profiler).") print() @@ -821,10 +1259,28 @@ def main(): parser.add_argument("-k", "--topk", type=int, default=None, help="Top-k") parser.add_argument("--all", action="store_true", help="Run all configs") parser.add_argument("--aiter", action="store_true", help="Compare with aiter") - parser.add_argument("--bench", action="store_true", help="Run benchmark sweep (eager + graph, FlyDSL vs CK)") + parser.add_argument("--bench", action="store_true", help="Run benchmark sweep (eager + graph + kern, FlyDSL vs CK)") + parser.add_argument( + "--bench-fused", action="store_true", help="Run fused-vs-unfused benchmark for moe_softmax_sort_flydsl" + ) parser.add_argument( "--bench-tokens", type=str, default=None, help="Comma-separated T values for bench (default: all)" ) + parser.add_argument("--bench-experts", type=int, default=256, help="Num experts E for --bench-fused (default 256)") + parser.add_argument("--bench-topk", type=int, default=8, help="topk for --bench-fused (default 8)") + parser.add_argument( + "--bench-dtype", + type=str, + default="bf16", + choices=["bf16", "f16", "f32"], + help="Gating-logits dtype for --bench-fused (default bf16)", + ) + parser.add_argument( + "--bench-model-dim", + type=int, + default=4096, + help="Hidden size (moe_buf width) for --bench-fused " "(DeepSeek R1: 7168; default 4096)", + ) args = parser.parse_args() if args.bench: @@ -834,6 +1290,19 @@ def main(): run_bench_comparison(token_sweep=token_sweep) return + if args.bench_fused: + token_sweep = None + if args.bench_tokens: + token_sweep = [int(t) for t in args.bench_tokens.split(",")] + run_fused_bench_comparison( + token_sweep=token_sweep, + dtype_str=args.bench_dtype, + num_experts=args.bench_experts, + topk=args.bench_topk, + model_dim=args.bench_model_dim, + ) 
+ return + if args.T is not None: E = args.E or 256 topk = args.topk or 8