Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion kernels/rmsnorm_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def _build_rmsnorm_large_m_small_n_module(M: int, N: int, dtype_str: str):
BLOCK_THREADS_SPECIAL = BLOCK_M * THREADS_PER_ROW
elem_bits = 32 if dtype_str == "f32" else 16

-@flyc.kernel
+@flyc.kernel(known_block_size=[BLOCK_THREADS_SPECIAL, 1, 1])
def rmsnorm_large_m_small_n_kernel(
Input: fx.Tensor,
Gamma: fx.Tensor,
Expand Down
Loading