diff --git a/kernels/rmsnorm_kernel.py b/kernels/rmsnorm_kernel.py index ce4bd0a98..55aa1821e 100644 --- a/kernels/rmsnorm_kernel.py +++ b/kernels/rmsnorm_kernel.py @@ -304,7 +304,7 @@ def _build_rmsnorm_large_m_small_n_module(M: int, N: int, dtype_str: str): BLOCK_THREADS_SPECIAL = BLOCK_M * THREADS_PER_ROW elem_bits = 32 if dtype_str == "f32" else 16 - @flyc.kernel + @flyc.kernel(known_block_size=[BLOCK_THREADS_SPECIAL, 1, 1]) def rmsnorm_large_m_small_n_kernel( Input: fx.Tensor, Gamma: fx.Tensor, diff --git a/tests/kernels/test_rmsnorm.py b/tests/kernels/test_rmsnorm.py index 04eae3c92..b3d7998e9 100644 --- a/tests/kernels/test_rmsnorm.py +++ b/tests/kernels/test_rmsnorm.py @@ -167,6 +167,11 @@ def test_all(): # (16, 512, "bf16"), # BF16 # (1024, 8192, "bf16"), # BF16 (32768, 8192, "bf16"), + # Covers the large-M small-N path in build_rmsnorm_module + # (M > 8192 and N <= 2048): it launches BLOCK_M * THREADS_PER_ROW + # = 512..1024 threads/block, which requires known_block_size. + # N=512 is a real DeepSeek-R1 shape and hits the 1024-thread case. + (16384, 512, "bf16"), ] do_compare = os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1"