[PERF] MXFP4 (a4w4) MoE backend for gfx950#3470
Conversation
Add module_moe_mxfp4_aux for the MXFP4 (a4w4) MoE backend: fused token sort+quant, 3-stage sort, scale sort/shuffle, and weighted bf16 scatter-reduce (gfx950, Kimi-K2.5 / DSR shapes).
Add module_moe_mxfp4_gemm: the two MXFP4 (a4w4) MoE GEMMs (gemm1 fused SwiGLU, gemm2 down-proj) with codegen'd kernel-name dispatch (gen_instances.py), MFMA f4f4 microkernels, epilogs, non-temporal B-load dispatch, and XCD-swizzle interface (gfx950, Kimi-K2.5 / DSR).
Route mxfp4_moe_* kernel names to the _mxfp4_moe_run pipeline (sort/quant -> gemm1 -> gemm2 -> scatter_reduce) via a MOEMetadata pipeline hook, with shuffle_kind-tagged CSV lookup so the backend ships its own tuned rows alongside the default set. Add Kimi-K2.5 tuned CSV.
Add the optional mxfp4-intermediate path: gemm2 stages flat_out as packed fp4 + e8m0 (mxfp4-out epilog) and scatter_reduce_q reads it back as mxfp4, cutting the reduce's HBM traffic ~3.8x. Gated by a _MXFP4OUT tag in the tuned CSV (large-M buckets only, where the reduce win beats the gemm2 epilog overhead). bf16-intermediate stays default.
🏷️ CI GuideRuns automatically on every PR:
Extended tests (opt-in via labels):
|
|
Performance bench script: https://gist.github.com/ftyghome/088870c544693da57258d22968573620 |
|
@benenzhu it is launch/latency-bound' in your bench script you use s = time.perf_counter() and then time.perf_counter() - s. It counts all cpu overhead. The real run in e2e uses cudagraph that will not count. |
Hi Felix, Thanks for the review!
|
Hello @ftyghome, Could you please update the perf compare again? Collect the flydsl date with fused_moe interface, so that flydsl will be called through the moe tuner too. |
Hi updated: @lalala-sh
|
| } | ||
| if (NE == 257) { | ||
| if (MB == 32) { LAUNCH(257, 9, 32, 7168); return; } | ||
| } |
There was a problem hiding this comment.
why so many constraints here? Could you make it more general ? I think the kernels don't have these limitation?
| } | ||
| if (NE == 257) { | ||
| if (MB == 16) { LAUNCH_IQ_ZI(257, 9, 16, 7168); return; } | ||
| } |
| if (NE == 257) { | ||
| if (MB == 16) { LAUNCH_IQ(257, 9, 16, 7168); return; } | ||
| } | ||
| } |
| } | ||
| if (NE == 257) { | ||
| if (MB == 32) { LAUNCH(257, 9, 32, 7168); return; } | ||
| if (MB == 128) { LAUNCH(257, 9, 128, 7168); return; } |
There was a problem hiding this comment.
Also make quant a general usable kernel?
| __launch_bounds__(256, | ||
| is_nonatomic_v<kEpilog> ? 1 : | ||
| ((BM == 16) ? 4 : 2)) | ||
| kernel( |
There was a problem hiding this comment.
ROCm/FlyDSL#660 for gemm2, we try to impl a flydsl pipeline that's exactly the same as this one. Would you mind transfer the same style and infra for better long term maintaining?
|
|
||
| if (D_HIDDEN == 7168 && D_INTER == 512) { | ||
| if (NE == 385) { LAUNCH_G2_MXFP4(385, 512, 7168); return; } | ||
| if (NE == 257) { LAUNCH_G2_MXFP4(257, 512, 7168); return; } |
There was a problem hiding this comment.
hard code shapes. also couldn't be extended.
| sorted_weights.data_ptr<float>(), \ | ||
| reinterpret_cast<__hip_bfloat16*>(out.data_ptr())) | ||
|
|
||
| if (D_HIDDEN == 7168 && TOPK == 9) { |
| const bool nt_hints = (MB >= 128); | ||
|
|
||
| #define LAUNCH_Q(D_HIDDEN_, TOPK_, NT_) \ | ||
| aiter::mxfp4_moe::moe_scatter_reduce::launch_mxfp4< \ |
There was a problem hiding this comment.
All input is quanted to fp4. It's ok for gsm8k but not in our many other real tests. could you also add flags to enable ?
| const buffer_rsrc_t B_q_rsrc = | ||
| make_buffer_rsrc(B_q, | ||
| (uint32_t)((long long)NUM_EXPERTS * N_OUT * K_HALF * sizeof(__hip_fp4x2_storage_t))); | ||
| constexpr int kAS_bound_div = kAtomic ? BM_GRID : 32; |
There was a problem hiding this comment.
The A_scale buffer-resource bound is 2× too small for BM=64 (atomic path).
A_scale is laid out in fixed 32-row chunks, and the read side addresses it accordingly at line 333:
const int chunk_base = (BM == 16) ? (m_row / BM_GRID) : (m_row / 32);
a_scale_s_base[sub] = (chunk_base + sub) * kAS_per_chunk_dw * 4;
So the chunk-stride divisor is 16 for BM=16 and 32 for everything else — it is not BM. But here the bound divides by kAtomic ? BM_GRID :
32, which only matches the addressing for BM∈{16,32}:
For BM=64 the rsrc extent is (MAX_M/64)*kAS_per_chunk_dw, but addressing reaches (MAX_M/32)*kAS_per_chunk_dw. Once m_row ≥ MAX_M/2, the
hardware buffer-bound check fails the scale loads and returns 0 → e8m0 = 0 → scale 2^-127 → those output rows silently collapse to ~0.
Wrong results, no error.
| (257, 7168, 512, 9), # DSR | ||
| ] | ||
|
|
||
| MAX_M = 655360 |
There was a problem hiding this comment.
this is the hard limit of topk * tokens? Then all prefill inputs >72.8K will fail and produce wrong outputs?
[PERF] mxfp4_moe — a native MXFP4 (a4w4) MoE backend for gfx950
Motivation
Kimi-K2.5 / DeepSeek-R1 serve MoE FFNs with MXFP4 (a4w4) weights. The existing FlyDSL (MLIR) a4w4 MoE path leaves performance on the table across the whole token range — it is launch/latency-bound at decode-size M and not fully HBM-bound at prefill-size M. This PR adds
mxfp4_moe, a native pure-HIP MXFP4 MoE backend that is faster than FlyDSL at every M bucket on MI350X.Technical Details
The backend is two JIT modules plus a
fused_moeintegration:module_moe_mxfp4_aux— token sort, mxfp4 quant (bf16 → fp4 + e8m0, group=32), scale sort/shuffle, and the weighted scatter-reduce.module_moe_mxfp4_gemm— the two a4w4 GEMMs (gemm1 fused-SwiGLU, gemm2 down-proj) with codegen'd kernel-name dispatch (gen_instances.py) and MFMA f4f4 microkernels.fused_moewiring —mxfp4_moe_*kernel names route to an_mxfp4_moe_runpipeline via aMOEMetadatahook; per-M kernel selection comes from ashuffle_kind-tagged tuned CSV, so the backend ships its own tuned rows next to the default tuning without dedup conflicts.Several pipelines, dispatched per token-count M
The optimal fusion/accumulation strategy flips between decode-size and prefill-size M, so we designed several pipelines and pick one per M from the tuned CSV:
sort → quant-fused-gemm1 → gemm2 (atomic-add reduce). Quant is folded into gemm1 and gemm2 atomic-adds straight into the output, removing the per-sorted-row partial buffer and the separate quant / scatter passes. (The gemm2 B-load additionally has a per-M-tuned cached vs non-temporal choice.) Wins when M ≲ BM and the path is launch/latency-bound.sort → quant → sort-scale → gemm1 → gemm2 [→ scatter-reduce]. Prologue runs as separate full-throughput kernels. At mid M (BM=32) gemm2 atomic-accumulates straight into the output (no separate reduce); at large M (BM=128) gemm2 writes per-sorted-rowflat_outand a dedicated non-atomic scatter-reduce does the weighted sum — avoiding the atomic contention that grows with M. HBM-bound regime.The core tradeoff: fusion + atomics win when operator shapes are complementary and launch overhead dominates (small M); at large M they break down — the same token is read by multiple experts (duplicated quant work) and atomic contention grows — so the separate, non-atomic path wins.
On top of the large-M non-atomic path, an optional mxfp4-intermediate variant (CSV-gated, enabled on the largest prefill buckets) stages gemm2's
flat_outas packed fp4 + e8m0 (mxfp4-out epilog) and reads it back withscatter_reduce_q, cutting the reduce's HBM traffic ~3.8×.Test Plan
Machine / shape
ca781d81.Cache / timing
perf_counter+synchronize), median of 30 iters, warm cache (weights resident in L2; no per-iter flush).mxfp4_moeis driven through the productionfused_moeentry — per-M kernel dispatch comes straight from the tuned CSV (including the mxfp4-intermediate path on the largest buckets), so the numbers carry real host-side dispatch overhead; FlyDSL is timed at its kernel stages (best block_m reported). Same routing and fp4 weights for both — if anything this is conservative formxfp4_moeat small M.Correctness
End-to-end
cos(kernel, bf16_truth)is 0.975–0.978 across the full M range and never worse than the fp4 reference — all pass.Performance (end-to-end, mxfp4_moe vs FlyDSL)
mxfp4_moeis faster than FlyDSL across the entire token range (1.12×–1.80×). At the largest buckets (M ≥ 8192) the CSV switches to the mxfp4-intermediate path; its cos vs FlyDSL drops to ~0.981 (the inherent loss of the before-sum fp4 quant) while still pulling ahead in latency.