Skip to content

[PERF] MXFP4 (a4w4) MoE backend for gfx950#3470

Open
ColorsWind wants to merge 4 commits into
ROCm:mainfrom
RadeonFlow:rf-mxfp4-moe
Open

[PERF] MXFP4 (a4w4) MoE backend for gfx950#3470
ColorsWind wants to merge 4 commits into
ROCm:mainfrom
RadeonFlow:rf-mxfp4-moe

Conversation

@ColorsWind
Copy link
Copy Markdown

[PERF] mxfp4_moe — a native MXFP4 (a4w4) MoE backend for gfx950

Motivation

Kimi-K2.5 / DeepSeek-R1 serve MoE FFNs with MXFP4 (a4w4) weights. The existing FlyDSL (MLIR) a4w4 MoE path leaves performance on the table across the whole token range — it is launch/latency-bound at decode-size M and not fully HBM-bound at prefill-size M. This PR adds mxfp4_moe, a native pure-HIP MXFP4 MoE backend that is faster than FlyDSL at every M bucket on MI350X.

Technical Details

The backend is two JIT modules plus a fused_moe integration:

  • module_moe_mxfp4_aux — token sort, mxfp4 quant (bf16 → fp4 + e8m0, group=32), scale sort/shuffle, and the weighted scatter-reduce.
  • module_moe_mxfp4_gemm — the two a4w4 GEMMs (gemm1 fused-SwiGLU, gemm2 down-proj) with codegen'd kernel-name dispatch (gen_instances.py) and MFMA f4f4 microkernels.
  • fused_moe wiringmxfp4_moe_* kernel names route to an _mxfp4_moe_run pipeline via a MOEMetadata hook; per-M kernel selection comes from a shuffle_kind-tagged tuned CSV, so the backend ships its own tuned rows next to the default tuning without dedup conflicts.

Several pipelines, dispatched per token-count M

The optimal fusion/accumulation strategy flips between decode-size and prefill-size M, so we designed several pipelines and pick one per M from the tuned CSV:

  • inline-quant + atomic — small M (BM=16): sort → quant-fused-gemm1 → gemm2 (atomic-add reduce). Quant is folded into gemm1 and gemm2 atomic-adds straight into the output, removing the per-sorted-row partial buffer and the separate quant / scatter passes. (The gemm2 B-load additionally has a per-M-tuned cached vs non-temporal choice.) Wins when M ≲ BM and the path is launch/latency-bound.
  • three-stage — mid-to-large M (BM=32 / 128): sort → quant → sort-scale → gemm1 → gemm2 [→ scatter-reduce]. Prologue runs as separate full-throughput kernels. At mid M (BM=32) gemm2 atomic-accumulates straight into the output (no separate reduce); at large M (BM=128) gemm2 writes per-sorted-row flat_out and a dedicated non-atomic scatter-reduce does the weighted sum — avoiding the atomic contention that grows with M. HBM-bound regime.

The core tradeoff: fusion + atomics win when operator shapes are complementary and launch overhead dominates (small M); at large M they break down — the same token is read by multiple experts (duplicated quant work) and atomic contention grows — so the separate, non-atomic path wins.

On top of the large-M non-atomic path, an optional mxfp4-intermediate variant (CSV-gated, enabled on the largest prefill buckets) stages gemm2's flat_out as packed fp4 + e8m0 (mxfp4-out epilog) and reads it back with scatter_reduce_q, cutting the reduce's HBM traffic ~3.8×.

Test Plan

Machine / shape

  • AMD Instinct MI350X (gfx950)
  • Kimi-K2.5, TP=4: NE=385, H=7168, D_INTER=512, TOPK=9
  • Routing: 1 shared expert (always selected) + 8 routed experts, softmax with a mild per-expert popularity skew (skew=0.5).
  • Both backends fed the same per-1x32 MXFP4 weights (quantized once).
  • FlyDSL baseline: ROCm/aiter @ ca781d81.

Cache / timing

  • Correctness: end-to-end pipeline vs a full-precision bf16 ground truth.
  • Performance: end-to-end MoE wall-clock (perf_counter + synchronize), median of 30 iters, warm cache (weights resident in L2; no per-iter flush). mxfp4_moe is driven through the production fused_moe entry — per-M kernel dispatch comes straight from the tuned CSV (including the mxfp4-intermediate path on the largest buckets), so the numbers carry real host-side dispatch overhead; FlyDSL is timed at its kernel stages (best block_m reported). Same routing and fp4 weights for both — if anything this is conservative for mxfp4_moe at small M.

Correctness

End-to-end cos(kernel, bf16_truth) is 0.975–0.978 across the full M range and never worse than the fp4 reference — all pass.

Performance (end-to-end, mxfp4_moe vs FlyDSL)

M mxfp4_moe (µs) flydsl (µs) speedup cos(mxfp4_moe, flydsl)
4 99.3 178.3 1.80× 0.985
8 116.2 201.7 1.74× 0.987
16 153.8 233.6 1.52× 0.987
32 204.5 300.3 1.47× 0.987
64 254.8 342.8 1.35× 0.987
128 304.2 412.4 1.36× 0.987
256 363.3 471.7 1.30× 0.987
4096 999.7 1115.2 1.12× 0.987
8192 1399.0 1785.9 1.28× 0.981
16384 2250.5 3176.7 1.41× 0.981

mxfp4_moe is faster than FlyDSL across the entire token range (1.12×–1.80×). At the largest buckets (M ≥ 8192) the CSV switches to the mxfp4-intermediate path; its cos vs FlyDSL drops to ~0.981 (the inherent loss of the before-sum fp4 quant) while still pulling ahead in latency.

Add module_moe_mxfp4_aux for the MXFP4 (a4w4) MoE backend: fused token
sort+quant, 3-stage sort, scale sort/shuffle, and weighted bf16
scatter-reduce (gfx950, Kimi-K2.5 / DSR shapes).
Add module_moe_mxfp4_gemm: the two MXFP4 (a4w4) MoE GEMMs (gemm1 fused
SwiGLU, gemm2 down-proj) with codegen'd kernel-name dispatch
(gen_instances.py), MFMA f4f4 microkernels, epilogs, non-temporal
B-load dispatch, and XCD-swizzle interface (gfx950, Kimi-K2.5 / DSR).
Route mxfp4_moe_* kernel names to the _mxfp4_moe_run pipeline
(sort/quant -> gemm1 -> gemm2 -> scatter_reduce) via a MOEMetadata
pipeline hook, with shuffle_kind-tagged CSV lookup so the backend ships
its own tuned rows alongside the default set. Add Kimi-K2.5 tuned CSV.
Add the optional mxfp4-intermediate path: gemm2 stages flat_out as
packed fp4 + e8m0 (mxfp4-out epilog) and scatter_reduce_q reads it back
as mxfp4, cutting the reduce's HBM traffic ~3.8x. Gated by a
_MXFP4OUT tag in the tuned CSV (large-M buckets only, where the reduce
win beats the gemm2 epilog overhead). bf16-intermediate stays default.
@ColorsWind ColorsWind requested a review from a team June 1, 2026 15:52
@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Jun 1, 2026

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:triton-300x Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X
ci:sglang SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy
ci:atom ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B
ci:atom_full ATOM accuracy suite for PR and main models from ATOM models_accuracy.json
ci:vllm vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5
ci:all All standard extended tests (excludes ci:atom_full)

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3470 --add-label <label>

@benenzhu
Copy link
Copy Markdown

benenzhu commented Jun 2, 2026

@coderfeli
Copy link
Copy Markdown
Collaborator

@benenzhu it is launch/latency-bound' in your bench script you use s = time.perf_counter() and then time.perf_counter() - s. It counts all cpu overhead. The real run in e2e uses cudagraph that will not count.

@ftyghome
Copy link
Copy Markdown

ftyghome commented Jun 3, 2026

@benenzhu it is launch/latency-bound' in your bench script you use s = time.perf_counter() and then time.perf_counter() - s. It counts all cpu overhead. The real run in e2e uses cudagraph that will not count.

Hi Felix,

Thanks for the review!
We just updated the microbench script to capture 60 runs in a cudagraph. The new perf data is listed below:

M mxfp4 us flydsl us flydsl/mxfp4 cos
4 43.5 69.9 1.61x 0.9845
8 62.2 96.4 1.55x 0.9868
16 101.6 132.9 1.31x 0.9874
32 151.3 190.2 1.26x 0.9868
64 206.8 236.6 1.14x 0.9874
128 256.3 305.8 1.19x 0.9875
256 310.2 341.0 1.10x 0.9871
4096 977.4 1025.6 1.05x 0.9875
8192 1381.3 1718.0 1.24x 0.9812
16384 2254.4 3133.8 1.39x 0.9811

@lalala-sh
Copy link
Copy Markdown
Contributor

@benenzhu it is launch/latency-bound' in your bench script you use s = time.perf_counter() and then time.perf_counter() - s. It counts all cpu overhead. The real run in e2e uses cudagraph that will not count.

Hi Felix,

Thanks for the review! We just updated the microbench script to capture 60 runs in a cudagraph. The new perf data is listed below:

M mxfp4 us flydsl us flydsl/mxfp4 cos
4 43.5 69.9 1.61x 0.9845
8 62.2 96.4 1.55x 0.9868
16 101.6 132.9 1.31x 0.9874
32 151.3 190.2 1.26x 0.9868
64 206.8 236.6 1.14x 0.9874
128 256.3 305.8 1.19x 0.9875
256 310.2 341.0 1.10x 0.9871
4096 977.4 1025.6 1.05x 0.9875
8192 1381.3 1718.0 1.24x 0.9812
16384 2254.4 3133.8 1.39x 0.9811

Hello @ftyghome, Could you please update the perf compare again? Collect the flydsl date with fused_moe interface, so that flydsl will be called through the moe tuner too.

@benenzhu
Copy link
Copy Markdown

benenzhu commented Jun 3, 2026

@benenzhu it is launch/latency-bound' in your bench script you use s = time.perf_counter() and then time.perf_counter() - s. It counts all cpu overhead. The real run in e2e uses cudagraph that will not count.

Hi Felix,
Thanks for the review! We just updated the microbench script to capture 60 runs in a cudagraph. The new perf data is listed below:
M mxfp4 us flydsl us flydsl/mxfp4 cos
4 43.5 69.9 1.61x 0.9845
8 62.2 96.4 1.55x 0.9868
16 101.6 132.9 1.31x 0.9874
32 151.3 190.2 1.26x 0.9868
64 206.8 236.6 1.14x 0.9874
128 256.3 305.8 1.19x 0.9875
256 310.2 341.0 1.10x 0.9871
4096 977.4 1025.6 1.05x 0.9875
8192 1381.3 1718.0 1.24x 0.9812
16384 2254.4 3133.8 1.39x 0.9811

Hello @ftyghome, Could you please update the perf compare again? Collect the flydsl date with fused_moe interface, so that flydsl will be called through the moe tuner too.

Hi updated: @lalala-sh
Script: https://gist.github.com/benenzhu/ea38af8e5606ed5596abc1890e56a5a1
Change by ai, not verify the correctness.

M mxfp4 us bm flydsl us flydsl/mxfp4 mx·fly mx·ref fly·ref
4 41.9 16 59.2 1.41x 0.9845 0.9533 0.9482
8 61.2 16 74.6 1.22x 0.9868 0.9524 0.9485
16 100.3 16 122.1 1.22x 0.9876 0.9538 0.9514
32 147.8 16 165.0 1.12x 0.9868 0.9548 0.9520
64 203.5 16 213.9 1.05x 0.9874 0.9555 0.9529
128 254.2 16 263.8 1.04x 0.9875 0.9562 0.9536
256 304.1 32 306.4 1.01x 0.9871 0.9561 0.9538
8192 1137.3 128 1359.0 1.19x 0.9811 0.9505 0.9541
16384 1821.7 128 2325.2 1.28x 0.9810 0.9506 0.9541

}
if (NE == 257) {
if (MB == 32) { LAUNCH(257, 9, 32, 7168); return; }
}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why so many constraints here? Could you make it more general ? I think the kernels don't have these limitation?

}
if (NE == 257) {
if (MB == 16) { LAUNCH_IQ_ZI(257, 9, 16, 7168); return; }
}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if (NE == 257) {
if (MB == 16) { LAUNCH_IQ(257, 9, 16, 7168); return; }
}
}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

++

}
if (NE == 257) {
if (MB == 32) { LAUNCH(257, 9, 32, 7168); return; }
if (MB == 128) { LAUNCH(257, 9, 128, 7168); return; }
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also make quant a general usable kernel?

__launch_bounds__(256,
is_nonatomic_v<kEpilog> ? 1 :
((BM == 16) ? 4 : 2))
kernel(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ROCm/FlyDSL#660 for gemm2, we try to impl a flydsl pipeline that's exactly the same as this one. Would you mind transfer the same style and infra for better long term maintaining?


if (D_HIDDEN == 7168 && D_INTER == 512) {
if (NE == 385) { LAUNCH_G2_MXFP4(385, 512, 7168); return; }
if (NE == 257) { LAUNCH_G2_MXFP4(257, 512, 7168); return; }
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hard code shapes. also couldn't be extended.

sorted_weights.data_ptr<float>(), \
reinterpret_cast<__hip_bfloat16*>(out.data_ptr()))

if (D_HIDDEN == 7168 && TOPK == 9) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

++

const bool nt_hints = (MB >= 128);

#define LAUNCH_Q(D_HIDDEN_, TOPK_, NT_) \
aiter::mxfp4_moe::moe_scatter_reduce::launch_mxfp4< \
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All input is quanted to fp4. It's ok for gsm8k but not in our many other real tests. could you also add flags to enable ?

const buffer_rsrc_t B_q_rsrc =
make_buffer_rsrc(B_q,
(uint32_t)((long long)NUM_EXPERTS * N_OUT * K_HALF * sizeof(__hip_fp4x2_storage_t)));
constexpr int kAS_bound_div = kAtomic ? BM_GRID : 32;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The A_scale buffer-resource bound is 2× too small for BM=64 (atomic path).

A_scale is laid out in fixed 32-row chunks, and the read side addresses it accordingly at line 333:
const int chunk_base = (BM == 16) ? (m_row / BM_GRID) : (m_row / 32);
a_scale_s_base[sub] = (chunk_base + sub) * kAS_per_chunk_dw * 4;
So the chunk-stride divisor is 16 for BM=16 and 32 for everything else — it is not BM. But here the bound divides by kAtomic ? BM_GRID :
32, which only matches the addressing for BM∈{16,32}:
For BM=64 the rsrc extent is (MAX_M/64)*kAS_per_chunk_dw, but addressing reaches (MAX_M/32)*kAS_per_chunk_dw. Once m_row ≥ MAX_M/2, the
hardware buffer-bound check fails the scale loads and returns 0 → e8m0 = 0 → scale 2^-127 → those output rows silently collapse to ~0.
Wrong results, no error.

(257, 7168, 512, 9), # DSR
]

MAX_M = 655360
Copy link
Copy Markdown
Collaborator

@coderfeli coderfeli Jun 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is the hard limit of topk * tokens? Then all prefill inputs >72.8K will fail and produce wrong outputs?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants