[Bugfix] rmsnorm: annotate known_block_size on large-M small-N path#639
[Bugfix] rmsnorm: annotate known_block_size on large-M small-N path#639jhinpan wants to merge 3 commits into
Conversation
The large-M small-N rmsnorm path (M>8192 & N<=2048) launches blocks of BLOCK_M*THREADS_PER_ROW threads, which is 1024 (N<=512) or 512 (N=1536). Without a known_block_size annotation the AMDGPU backend keeps the default max flat workgroup size of 256, so the launch aborts with: ValueError: launch block size 1024x1x1 = 1024 threads exceeds the AMDGPU default max_flat_workgroup_size of 256. This crashes rmsnorm on real shapes -- e.g. DeepSeek-R1 (N=512/1536) and Qwen3 q/k-norm (N=128) at large prefill token counts. Add known_block_size=[BLOCK_THREADS_SPECIAL, 1, 1] to the kernel decorator, the same idiom every other >256-thread kernel uses (fp8_gemm_8wave, flash_attn_func, hgemm_splitk, ...). Verified on MI350X/gfx950: M=16384/32768 x N=128/512/1536/2048 now compile, run, and match the fp32 reference (max|err| ~0.015 in bf16); fast-vectorized and generic paths unchanged. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
There was a problem hiding this comment.
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
This PR updates the RMSNorm kernel configuration to provide an explicit compile-time block size to the flyc.kernel decorator when building the “large M / small N” module, likely to improve compilation and/or launch specialization.
Changes:
- Passes
known_block_sizeto@flyc.kernelforrmsnorm_large_m_small_n_kernel.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| elem_bits = 32 if dtype_str == "f32" else 16 | ||
|
|
||
| @flyc.kernel | ||
| @flyc.kernel(known_block_size=[BLOCK_THREADS_SPECIAL, 1, 1]) |
This path in build_rmsnorm_module launches BLOCK_M*THREADS_PER_ROW = 512..1024 threads/block and was never exercised by any test: test_all's only enabled shape is (32768, 8192) (large M, but N=8192 > 2048 -> regular path), so the missing known_block_size annotation fixed in this PR stayed latent since the path landed. Add (16384, 512, "bf16") -- a real DeepSeek-R1 shape that routes through the large-M small-N branch and hits its 1024-thread worst case. Verified on MI350X (gfx950): without the fix this shape raises "launch block size 1024x1x1 = 1024 threads exceeds the AMDGPU default max_flat_workgroup_size of 256"; with the fix it compiles and matches the torch fp32 reference (max abs err 1.25e-2 < bf16 atol 2e-2). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
|
@coderfeli mind giving this a quick review when you get a chance? Worth flagging why this stayed hidden: it's another test-coverage gap. The large-M/small-N path ( I've added a committed regression test so it can't silently re-break:
Diff stays minimal — 1 kernel line + 1 test shape. Thanks! |
Fixes #614.
Problem
The large-M small-N rmsnorm path (
M > 8192 and N <= 2048) launches blocks ofBLOCK_M * THREADS_PER_ROW= 1024 threads (N<=512) or 512 threads (N=1536), but the kernel is decorated with a bare@flyc.kernel. Withoutknown_block_size, the AMDGPU backend keeps its default max flat workgroup size of 256, so the launch aborts:This crashes rmsnorm on real shapes — DeepSeek-R1 (
N=512,N=1536) and Qwen3 per-head q/k-norm (N=128) at large prefill token counts.Fix
Add
known_block_size=[BLOCK_THREADS_SPECIAL, 1, 1]to the kernel decorator — the same idiom every other >256-thread kernel inkernels/uses (fp8_gemm_8wave,flash_attn_func,hgemm_splitk,custom_all_reduce_kernel, ...). One line; preserves the path's intended block size.Verification (MI350X / gfx950, ROCm 7.2,
FLYDSL_RUNTIME_ENABLE_CACHE=0)All previously-crashing shapes now run and match the fp32 reference within bf16 tolerance; fast-vectorized and generic paths are untouched.
🤖 Generated with Claude Code