Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion kernels/rmsnorm_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def _build_rmsnorm_large_m_small_n_module(M: int, N: int, dtype_str: str):
BLOCK_THREADS_SPECIAL = BLOCK_M * THREADS_PER_ROW
elem_bits = 32 if dtype_str == "f32" else 16

- @flyc.kernel
+ @flyc.kernel(known_block_size=[BLOCK_THREADS_SPECIAL, 1, 1])
def rmsnorm_large_m_small_n_kernel(
Input: fx.Tensor,
Gamma: fx.Tensor,
Expand Down
5 changes: 5 additions & 0 deletions tests/kernels/test_rmsnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,11 @@ def test_all():
# (16, 512, "bf16"), # BF16
# (1024, 8192, "bf16"), # BF16
(32768, 8192, "bf16"),
+ # Covers the large-M small-N path in build_rmsnorm_module
+ # (M > 8192 and N <= 2048): it launches BLOCK_M * THREADS_PER_ROW
+ # = 512..1024 threads/block, which requires known_block_size.
+ # N=512 is a real DeepSeek-R1 shape and hits the 1024-thread case.
+ (16384, 512, "bf16"),
]

do_compare = os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1"
Expand Down