[Bugfix] rmsnorm: annotate known_block_size on large-M small-N path #639

Open
jhinpan wants to merge 3 commits into ROCm:main from jhinpan:fix/rmsnorm-large-m-small-n-block-size

@jhinpan jhinpan commented Jun 3, 2026

Restored PR: this is a recreation of #615. The previous PR was closed accidentally when the old head repository was deleted, so GitHub lost the head repo association. This branch is restored at the same old head commit 5dbacfa5fa3df9218ba8a47959dfbba332cea8e6.


Fixes #614.

Problem

The large-M small-N rmsnorm path (M > 8192 and N <= 2048) launches blocks of BLOCK_M * THREADS_PER_ROW = 1024 threads (N<=512) or 512 threads (N=1536), but the kernel is decorated with a bare @flyc.kernel. Without known_block_size, the AMDGPU backend keeps its default max flat workgroup size of 256, so the launch aborts:

ValueError: launch block size 1024x1x1 = 1024 threads exceeds the AMDGPU
default max_flat_workgroup_size of 256.

This crashes rmsnorm on real shapes — DeepSeek-R1 (N=512, N=1536) and Qwen3 per-head q/k-norm (N=128) at large prefill token counts.
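The failure mode can be sketched in a few lines of Python. This is a minimal illustration, not FlyDSL's actual guard: `check_launch_block_size` and `DEFAULT_MAX_FLAT_WORKGROUP_SIZE` are hypothetical names, and the only details taken from this PR are the default limit of 256 and the general wording of the error. It assumes that declaring known_block_size raises the allowed limit to the declared thread count.

```python
from math import prod

# Default limit quoted in the PR's error message; the constant name is hypothetical.
DEFAULT_MAX_FLAT_WORKGROUP_SIZE = 256


def check_launch_block_size(block, known_block_size=None):
    """Return the flat thread count, or raise ValueError if it exceeds the limit.

    Assumption for illustration: a declared known_block_size lifts the limit
    to the declared thread count; without it the default of 256 applies.
    """
    threads = prod(block)
    if known_block_size is not None:
        limit = prod(known_block_size)
    else:
        limit = DEFAULT_MAX_FLAT_WORKGROUP_SIZE
    if threads > limit:
        dims = "x".join(str(d) for d in block)
        raise ValueError(
            f"launch block size {dims} = {threads} threads exceeds the "
            f"max flat workgroup size of {limit}."
        )
    return threads
```

Under these assumptions, `check_launch_block_size((1024, 1, 1))` raises, while passing `known_block_size=[1024, 1, 1]` lets the same launch through.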

Fix

Add known_block_size=[BLOCK_THREADS_SPECIAL, 1, 1] to the kernel decorator — the same idiom every other >256-thread kernel in kernels/ uses (fp8_gemm_8wave, flash_attn_func, hgemm_splitk, custom_all_reduce_kernel, ...). One line; preserves the path's intended block size.

-    @flyc.kernel
+    @flyc.kernel(known_block_size=[BLOCK_THREADS_SPECIAL, 1, 1])
     def rmsnorm_large_m_small_n_kernel(
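For readers unfamiliar with decorators that work both bare and with keyword arguments, the following toy sketch shows the general Python pattern behind the two spellings in the diff. The `kernel` function here is a hypothetical stand-in, not flyc.kernel: it only records the block size as an attribute, and the value given to BLOCK_THREADS_SPECIAL is an assumed example (the 1024-thread case).

```python
import functools


def kernel(fn=None, *, known_block_size=None):
    """Toy decorator usable bare (@kernel) or with arguments (@kernel(...)).

    Hypothetical stand-in, not flyc.kernel: it only stores the block size
    on the wrapped function so the two spellings can be compared.
    """
    def wrap(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)
        wrapper.known_block_size = known_block_size
        return wrapper

    # Bare form: Python passes the decorated function itself as `fn`.
    if fn is not None:
        return wrap(fn)
    # Parameterised form: return the real decorator for Python to apply.
    return wrap


BLOCK_THREADS_SPECIAL = 1024  # assumed example value (the N <= 512 case)


@kernel
def before():
    return "ran"


@kernel(known_block_size=[BLOCK_THREADS_SPECIAL, 1, 1])
def after():
    return "ran"
```

In this sketch `before.known_block_size` is None while `after.known_block_size` carries the declared size, which mirrors why the bare spelling leaves the backend with no block-size information.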

Verification (MI350X / gfx950, ROCm 7.2, FLYDSL_RUNTIME_ENABLE_CACHE=0)

M      N     path            result           max|err| vs fp32 ref (bf16)
16384  128   special         compiles + runs  0.0147
16384  512   special         compiles + runs  0.0149
16384  1536  special         compiles + runs  0.0155
32768  128   special         compiles + runs  0.0139
16384  2048  special         compiles + runs  0.0156
16384  4096  fast (control)  unchanged        0.0156
32768  8192  fast (control)  unchanged        0.0156

All previously-crashing shapes now run and match the fp32 reference within bf16 tolerance; fast-vectorized and generic paths are untouched.
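The max|err| metric can be illustrated with a small pure-Python sketch. It is not the script used for the numbers above: it assumes the standard RMSNorm definition y = x / sqrt(mean(x^2) + eps) * w, uses an assumed eps of 1e-6 and unit weights, takes a Python float64 reference in place of the fp32 one, and emulates only the final bf16 store. All function names are hypothetical; the 2e-2 tolerance is the bf16 atol quoted in this PR's test commit.

```python
import math
import random
import struct


def to_bf16(x):
    """Round a Python float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round to nearest, ties to even
    bits &= 0xFFFF0000                   # keep sign, exponent, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def rmsnorm_row(x, weight, eps=1e-6):
    """Reference RMSNorm for one row; eps is an assumed value."""
    mean_sq = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [v * scale * w for v, w in zip(x, weight)]


def max_abs_err(rows, weight):
    """Largest |bf16(output) - reference| over all rows."""
    worst = 0.0
    for x in rows:
        ref = rmsnorm_row(x, weight)
        out = [to_bf16(v) for v in ref]  # emulate a bf16 output store
        worst = max(worst, max(abs(a - b) for a, b in zip(out, ref)))
    return worst


random.seed(0)
N = 128
# Inputs are rounded to bf16 first so reference and output see the same data.
rows = [[to_bf16(random.gauss(0.0, 1.0)) for _ in range(N)] for _ in range(8)]
weight = [1.0] * N
err = max_abs_err(rows, weight)
assert err < 2e-2  # bf16 atol quoted in the PR
```

The bound is plausible on paper: bf16 keeps 8 significant bits, so a stored value with magnitude in [2, 4) can be off by up to half of 2^-6, roughly 0.0078, and one in [4, 8) by up to 0.0156.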

🤖 Generated with Claude Code

The large-M small-N rmsnorm path (M>8192 & N<=2048) launches blocks of
BLOCK_M*THREADS_PER_ROW threads, which is 1024 (N<=512) or 512 (N=1536).
Without a known_block_size annotation the AMDGPU backend keeps the default
max flat workgroup size of 256, so the launch aborts with:

  ValueError: launch block size 1024x1x1 = 1024 threads exceeds the AMDGPU
  default max_flat_workgroup_size of 256.

This crashes rmsnorm on real shapes -- e.g. DeepSeek-R1 (N=512/1536) and
Qwen3 q/k-norm (N=128) at large prefill token counts. Add
known_block_size=[BLOCK_THREADS_SPECIAL, 1, 1] to the kernel decorator, the
same idiom every other >256-thread kernel uses (fp8_gemm_8wave, flash_attn_func,
hgemm_splitk, ...). Verified on MI350X/gfx950: M=16384/32768 x N=128/512/1536/2048
now compile, run, and match the fp32 reference (max|err| ~0.015 in bf16);
fast-vectorized and generic paths unchanged.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Copilot AI review requested due to automatic review settings June 3, 2026 08:20

Copilot AI left a comment

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

This PR updates the RMSNorm kernel configuration to provide an explicit compile-time block size to the flyc.kernel decorator when building the “large M / small N” module, likely to improve compilation and/or launch specialization.

Changes:

  • Passes known_block_size to @flyc.kernel for rmsnorm_large_m_small_n_kernel.

Comment thread kernels/rmsnorm_kernel.py
     elem_bits = 32 if dtype_str == "f32" else 16

-    @flyc.kernel
+    @flyc.kernel(known_block_size=[BLOCK_THREADS_SPECIAL, 1, 1])
jhinpan and others added 2 commits June 4, 2026 17:30
This path in build_rmsnorm_module launches BLOCK_M*THREADS_PER_ROW =
512..1024 threads/block and was never exercised by any test: test_all's
only enabled shape is (32768, 8192) (large M, but N=8192 > 2048 -> regular
path), so the missing known_block_size annotation fixed in this PR stayed
latent since the path landed.

Add (16384, 512, "bf16") -- a real DeepSeek-R1 shape that routes through
the large-M small-N branch and hits its 1024-thread worst case.

Verified on MI350X (gfx950): without the fix this shape raises
"launch block size 1024x1x1 = 1024 threads exceeds the AMDGPU default
max_flat_workgroup_size of 256"; with the fix it compiles and matches the
torch fp32 reference (max abs err 1.25e-2 < bf16 atol 2e-2).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

jhinpan commented Jun 4, 2026

@coderfeli mind giving this a quick review when you get a chance?

Worth flagging why this stayed hidden: it's another test-coverage gap. The large-M/small-N path (M > 8192 and N <= 2048) was never selected by any test: test_all's only enabled shape is (32768, 8192), which has large M but N=8192 > 2048 and so takes the regular path. The max_flat_workgroup_size guard (#329) also predates this path (#481), so the path has been broken since it landed; CI just never ran it. The verification table above was all manual.
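The coverage gap follows directly from the path condition, as this small sketch shows. The function and constant names are hypothetical, and the real dispatch in build_rmsnorm_module may include further conditions; only the two thresholds come from this PR.

```python
LARGE_M_THRESHOLD = 8192  # from the PR: M > 8192
SMALL_N_LIMIT = 2048      # from the PR: N <= 2048


def takes_large_m_small_n_path(m, n):
    """True if (m, n) meets the large-M small-N condition quoted in the PR."""
    return m > LARGE_M_THRESHOLD and n <= SMALL_N_LIMIT


# The only shape previously enabled in test_all misses the path ...
assert not takes_large_m_small_n_path(32768, 8192)
# ... while the newly added regression shape selects it.
assert takes_large_m_small_n_path(16384, 512)
```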

I've added a committed regression test so it can't silently re-break: (16384, 512, "bf16") in test_all — a real DeepSeek-R1 shape that drives this branch at its 1024-thread worst case. Re-confirmed on MI350X (gfx950):

  • without the fix → ValueError: launch block size 1024x1x1 = 1024 threads exceeds the AMDGPU default max_flat_workgroup_size of 256
  • with the fix → compiles + matches the fp32 reference (max|err| 1.25e-2 < bf16 tol)

Diff stays minimal — 1 kernel line + 1 test shape. Thanks!

Development

Successfully merging this pull request may close these issues.

rmsnorm large-M small-N path crashes: block size exceeds AMDGPU max_flat_workgroup_size (missing known_block_size)
