diff --git a/kernels/layernorm_kernel.py b/kernels/layernorm_kernel.py index ffc3530ad..a0590ef23 100644 --- a/kernels/layernorm_kernel.py +++ b/kernels/layernorm_kernel.py @@ -17,6 +17,7 @@ import flydsl.expr as fx from flydsl.expr import arith, const_expr, gpu, range_constexpr from flydsl.expr import math as fmath +from flydsl.expr.typing import Vector as Vec from flydsl.expr.vector import ReductionOp, full from flydsl.runtime.device import get_rocm_arch as get_hip_arch from kernels.kernels_common import dtype_to_elem_type, get_warp_size @@ -28,8 +29,76 @@ BLOCK_THREADS = 256 WARP_SIZE = get_warp_size() VEC_WIDTH = 8 -USE_NONTEMPORAL = True -VEC_ALIGN = 16 + + +# ── Shared-memory allocation for block reductions ───────────────────── +def _make_reduction_storage(red_slots: int): + @fx.struct + class SharedStorage: + s_red: fx.Array[fx.Float32, red_slots, 16] + s_red2: fx.Array[fx.Float32, red_slots, 16] + + return SharedStorage + + +def _load_scalar(copy_atom, elem_dtype, divided_tensor, index): + view = fx.slice(divided_tensor, (None, index)) + r = fx.make_rmem_tensor(1, elem_dtype) + fx.copy_atom_call(copy_atom, view, r) + return fx.memref_load_vec(r)[0] + + +def _store_scalar(copy_atom, elem_dtype, store_dtype, divided_tensor, index, val): + r = fx.make_rmem_tensor(1, elem_dtype) + ts = full(1, store_dtype(val), store_dtype) + fx.memref_store_vec(ts, r) + view = fx.slice(divided_tensor, (None, index)) + fx.copy_atom_call(copy_atom, r, view) + + +def _load_vec(copy_atom, vec_width, elem_dtype, div_tensor, idx): + r = fx.make_rmem_tensor(vec_width, elem_dtype) + fx.copy_atom_call(copy_atom, fx.slice(div_tensor, (None, idx)), r) + return fx.memref_load_vec(r) + + +def _store_vec(copy_atom, vec_width, elem_dtype, val, div_tensor, idx): + r = fx.make_rmem_tensor(vec_width, elem_dtype) + fx.memref_store_vec(val, r) + fx.copy_atom_call(copy_atom, r, fx.slice(div_tensor, (None, idx))) + + +def _to_elem_scalar(dtype_str: str, elem_dtype, y): + if const_expr(dtype_str == "f32"): + return y + return y.to(elem_dtype) + + +def _to_elem_vec(dtype_str: str, elem_dtype, use_hw_cvt_bf16: bool, y): + if const_expr(dtype_str == "bf16"): + if const_expr(use_hw_cvt_bf16): + return y.to(elem_dtype) + u = y.bitcast(fx.Uint32) + upper = u >> 16 + lsb = upper & 1 + bias = lsb + 0x7FFF + u_round = y.bitcast(fx.Uint32) + bias + bf16_bits = u_round >> 16 + even = bf16_bits.shuffle(bf16_bits, [0, 2, 4, 6]) + odd = bf16_bits.shuffle(bf16_bits, [1, 3, 5, 7]) + odd_sh = odd << 16 + packed = even | odd_sh + return packed.bitcast(elem_dtype) + if const_expr(dtype_str == "f32"): + return y + return y.to(elem_dtype) + + +def _store_yscale(scale_copy_atom, yscale_div, index, val): + r = fx.make_rmem_tensor(1, fx.Float32) + ts = full(1, fx.Float32(val), fx.Float32) + fx.memref_store_vec(ts, r) + fx.copy_atom_call(scale_copy_atom, r, fx.slice(yscale_div, (None, index))) def build_layernorm_module(M: int, N: int, dtype_str: str): @@ -37,14 +106,9 @@ def build_layernorm_module(M: int, N: int, dtype_str: str): USE_HW_CVT_PK_BF16_F32 = (arch == "gfx950") or str(arch).startswith("gfx95") RED_SLOTS = max(1, (BLOCK_THREADS + WARP_SIZE - 1) // WARP_SIZE) - elem_bits = 32 if dtype_str == "f32" else 16 - # ── Shared-memory allocation for block reductions ───────────────────── - @fx.struct - class SharedStorage: - s_sum: fx.Array[fx.Float32, RED_SLOTS, 16] - s_sumsq: fx.Array[fx.Float32, RED_SLOTS, 16] + SharedStorage = _make_reduction_storage(RED_SLOTS) # ── GPU kernel ──────────────────────────────────────────────────────── @flyc.kernel @@ -146,20 +210,10 @@ def compute_mean_rstd(sum_val, sumsq_val): copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), elem_bits) - def _load_vec(div_tensor, idx): - r = fx.make_rmem_tensor(VEC_WIDTH, elem_dtype) - fx.copy_atom_call(copy_atom, fx.slice(div_tensor, (None, idx)), r) - return fx.memref_load_vec(r) - - def _store_vec(val, div_tensor, idx): - r = fx.make_rmem_tensor(VEC_WIDTH, elem_dtype) - fx.memref_store_vec(val, r) - fx.copy_atom_call(copy_atom, r, fx.slice(div_tensor, (None, idx))) - # ── Pass 1: load input, accumulate sum / sumsq ─────────────── for tile_i in range_constexpr(num_tiles_py): idx = tid + tile_i * BLOCK_THREADS - vec = _load_vec(in_div, idx) + vec = _load_vec(copy_atom, VEC_WIDTH, elem_dtype, in_div, idx) in_local.append(vec) x = vec.to(fx.Float32) @@ -172,8 +226,8 @@ def _store_vec(val, div_tensor, idx): sum_val, sumsq_val = block_reduce_add2(thread_sum, thread_sumsq) mean, rstd = compute_mean_rstd(sum_val, sumsq_val) - g_cur = _load_vec(gamma_div, tid).to(fx.Float32) - b_cur = _load_vec(beta_div, tid).to(fx.Float32) + g_cur = _load_vec(copy_atom, VEC_WIDTH, elem_dtype, gamma_div, tid).to(fx.Float32) + b_cur = _load_vec(copy_atom, VEC_WIDTH, elem_dtype, beta_div, tid).to(fx.Float32) # ── Pass 2: normalize + affine + store ─────────────────────── for tile_i in range_constexpr(num_tiles_py): @@ -181,8 +235,8 @@ def _store_vec(val, div_tensor, idx): b_next = b_cur if const_expr(tile_i + 1 < num_tiles_py): next_idx = tid + (tile_i + 1) * BLOCK_THREADS - g_next = _load_vec(gamma_div, next_idx).to(fx.Float32) - b_next = _load_vec(beta_div, next_idx).to(fx.Float32) + g_next = _load_vec(copy_atom, VEC_WIDTH, elem_dtype, gamma_div, next_idx).to(fx.Float32) + b_next = _load_vec(copy_atom, VEC_WIDTH, elem_dtype, beta_div, next_idx).to(fx.Float32) else: g_next = g_cur b_next = b_cur @@ -191,29 +245,9 @@ def _store_vec(val, div_tensor, idx): y = (x - mean) * rstd y = y * g_cur + b_cur - out_e = y.to(elem_dtype) - if const_expr(dtype_str == "bf16"): - if const_expr(USE_HW_CVT_PK_BF16_F32): - out_e = y.to(elem_dtype) - else: - u = y.bitcast(fx.Uint32) - upper = u >> 16 - lsb = upper & 1 - bias = lsb + 0x7FFF - u_round = y.bitcast(fx.Uint32) + bias - bf16_bits = u_round >> 16 - even = bf16_bits.shuffle(bf16_bits, [0, 2, 4, 6]) - odd = bf16_bits.shuffle(bf16_bits, [1, 3, 5, 7]) - odd_sh = odd << 16 - packed = even | odd_sh - out_e = packed.bitcast(elem_dtype) - elif const_expr(dtype_str == "f32"): - out_e = y - else: - out_e = y.to(elem_dtype) - + out_e = _to_elem_vec(dtype_str, elem_dtype, USE_HW_CVT_PK_BF16_F32, y) out_idx = tid + tile_i * BLOCK_THREADS - _store_vec(out_e, out_div, out_idx) + _store_vec(copy_atom, VEC_WIDTH, elem_dtype, out_e, out_div, out_idx) g_cur = g_next b_cur = b_next @@ -244,25 +278,12 @@ def _store_vec(val, div_tensor, idx): beta_div = fx.logical_divide(Beta_buf, fx.make_layout(1, 1)) out_div = fx.logical_divide(row_out, fx.make_layout(1, 1)) - def _load_scalar(divided_tensor, index): - view = fx.slice(divided_tensor, (None, index)) - r = fx.make_rmem_tensor(1, elem_dtype) - fx.copy_atom_call(copy_atom_s, view, r) - return fx.memref_load_vec(r)[0] - - def _store_scalar(divided_tensor, index, val): - r = fx.make_rmem_tensor(1, elem_dtype) - ts = full(1, elem_dtype(val), elem_dtype) - fx.memref_store_vec(ts, r) - view = fx.slice(divided_tensor, (None, index)) - fx.copy_atom_call(copy_atom_s, r, view) - # ── Pass 1: sum + sumsq ────────────────────────────────────── for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): idx = tid + base_idx_int is_valid = idx < N idx_safe = is_valid.select(idx, 0) - x_e = _load_scalar(row_div, idx_safe) + x_e = _load_scalar(copy_atom_s, elem_dtype, row_div, idx_safe) x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) x2 = x * x x_safe = is_valid.select(x, c_zero_f) @@ -277,9 +298,9 @@ def _store_scalar(divided_tensor, index, val): for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): idx = tid + base_idx_int if idx < N: - x_e = _load_scalar(row_div, idx) - g_e = _load_scalar(gamma_div, idx) - b_e = _load_scalar(beta_div, idx) + x_e = _load_scalar(copy_atom_s, elem_dtype, row_div, idx) + g_e = _load_scalar(copy_atom_s, elem_dtype, gamma_div, idx) + b_e = _load_scalar(copy_atom_s, elem_dtype, beta_div, idx) x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) b = b_e if dtype_str == "f32" else b_e.to(fx.Float32) @@ -287,14 +308,8 @@ def _store_scalar(divided_tensor, index, val): norm = diff * rstd scaled = norm * g y = scaled + b - y_e = y - if const_expr(dtype_str == "bf16"): - y_e = y.to(elem_dtype) - elif const_expr(dtype_str == "f32"): - y_e = y - else: - y_e = y.to(elem_dtype) - _store_scalar(out_div, idx, y_e) + y_e = _to_elem_scalar(dtype_str, elem_dtype, y) + _store_scalar(copy_atom_s, elem_dtype, elem_dtype, out_div, idx, y_e) # ── JIT host launcher ───────────────────────────────────────────────── @flyc.jit @@ -329,13 +344,13 @@ def _quant_dtype_max(dtype_str: str) -> float: def build_fused_add_layernorm_module(M: int, N: int, dtype_str: str): + arch = get_hip_arch() + USE_HW_CVT_PK_BF16_F32 = (arch == "gfx950") or str(arch).startswith("gfx95") + RED_SLOTS = max(1, (BLOCK_THREADS + WARP_SIZE - 1) // WARP_SIZE) elem_bits = 32 if dtype_str == "f32" else 16 - @fx.struct - class SharedStorage: - s_sum: fx.Array[fx.Float32, RED_SLOTS, 16] - s_sumsq: fx.Array[fx.Float32, RED_SLOTS, 16] + SharedStorage = _make_reduction_storage(RED_SLOTS) @flyc.kernel def fused_add_layernorm_kernel( @@ -404,82 +419,128 @@ def compute_mean_rstd(sum_val, sumsq_val): var = (var < 0.0).select(0.0, var) return mean, fmath.rsqrt(var + eps_c, fastmath=fm_fast) - Input_buf = fx.rocdl.make_buffer_tensor(Input) - ResidualIn_buf = fx.rocdl.make_buffer_tensor(ResidualIn) - Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) - Beta_buf = fx.rocdl.make_buffer_tensor(Beta) - Output_buf = fx.rocdl.make_buffer_tensor(Output) - ResidualOut_buf = fx.rocdl.make_buffer_tensor(ResidualOut) - - row_in = fx.slice(Input_buf, (bid, None)) - row_residual_in = fx.slice(ResidualIn_buf, (bid, None)) - row_out = fx.slice(Output_buf, (bid, None)) - row_residual_out = fx.slice(ResidualOut_buf, (bid, None)) - - copy_atom_s = fx.make_copy_atom( - fx.rocdl.BufferCopy16b() if elem_bits <= 16 else fx.rocdl.BufferCopy32b(), - elem_bits, - ) + # ================================================================== + # Fast path: N == BLOCK_THREADS * VEC_WIDTH * 4 + # ================================================================== + if const_expr(N == (BLOCK_THREADS * VEC_WIDTH * 4) and elem_bits <= 16): + num_tiles_py = 4 + c_zero_f = fx.Float32(0.0) + thread_sum = c_zero_f + thread_sumsq = c_zero_f + added_local = [] - in_div = fx.logical_divide(row_in, fx.make_layout(1, 1)) - residual_in_div = fx.logical_divide(row_residual_in, fx.make_layout(1, 1)) - gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(1, 1)) - beta_div = fx.logical_divide(Beta_buf, fx.make_layout(1, 1)) - out_div = fx.logical_divide(row_out, fx.make_layout(1, 1)) - residual_out_div = fx.logical_divide(row_residual_out, fx.make_layout(1, 1)) - - def _load_scalar(divided_tensor, index): - view = fx.slice(divided_tensor, (None, index)) - r = fx.make_rmem_tensor(1, elem_dtype) - fx.copy_atom_call(copy_atom_s, view, r) - return fx.memref_load_vec(r)[0] - - def _store_scalar(divided_tensor, index, val): - r = fx.make_rmem_tensor(1, elem_dtype) - ts = full(1, elem_dtype(val), elem_dtype) - fx.memref_store_vec(ts, r) - view = fx.slice(divided_tensor, (None, index)) - fx.copy_atom_call(copy_atom_s, r, view) + Input_buf = fx.rocdl.make_buffer_tensor(Input) + ResidualIn_buf = fx.rocdl.make_buffer_tensor(ResidualIn) + Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) + Beta_buf = fx.rocdl.make_buffer_tensor(Beta) + Output_buf = fx.rocdl.make_buffer_tensor(Output) + ResidualOut_buf = fx.rocdl.make_buffer_tensor(ResidualOut) - c_zero_f = fx.Float32(0.0) - thread_sum = c_zero_f - thread_sumsq = c_zero_f - - for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): - idx = tid + base_idx_int - is_valid = idx < N - idx_safe = is_valid.select(idx, 0) - x_e = _load_scalar(in_div, idx_safe) - r_e = _load_scalar(residual_in_div, idx_safe) - x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) - residual = r_e if dtype_str == "f32" else r_e.to(fx.Float32) - added_e = (x + residual) if dtype_str == "f32" else (x + residual).to(elem_dtype) - added = added_e if dtype_str == "f32" else added_e.to(fx.Float32) - added_safe = is_valid.select(added, c_zero_f) - thread_sum = thread_sum + added_safe - thread_sumsq = thread_sumsq + is_valid.select(added * added, c_zero_f) - if idx < N: - _store_scalar(residual_out_div, idx, added_e) - - sum_val, sumsq_val = block_reduce_add2(thread_sum, thread_sumsq) - mean, rstd = compute_mean_rstd(sum_val, sumsq_val) - - for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): - idx = tid + base_idx_int - if idx < N: - added_e = _load_scalar(residual_out_div, idx) - g_e = _load_scalar(gamma_div, idx) - b_e = _load_scalar(beta_div, idx) - added = added_e if dtype_str == "f32" else added_e.to(fx.Float32) - g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) - b = b_e if dtype_str == "f32" else b_e.to(fx.Float32) + row_in = fx.slice(Input_buf, (bid, None)) + row_residual_in = fx.slice(ResidualIn_buf, (bid, None)) + row_out = fx.slice(Output_buf, (bid, None)) + row_residual_out = fx.slice(ResidualOut_buf, (bid, None)) + + in_div = fx.logical_divide(row_in, fx.make_layout(VEC_WIDTH, 1)) + residual_in_div = fx.logical_divide(row_residual_in, fx.make_layout(VEC_WIDTH, 1)) + gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(VEC_WIDTH, 1)) + beta_div = fx.logical_divide(Beta_buf, fx.make_layout(VEC_WIDTH, 1)) + out_div = fx.logical_divide(row_out, fx.make_layout(VEC_WIDTH, 1)) + residual_out_div = fx.logical_divide(row_residual_out, fx.make_layout(VEC_WIDTH, 1)) + + copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), elem_bits) + + # Pass 1: add residual, cache/store it, and accumulate sum/sumsq. + for tile_i in range_constexpr(num_tiles_py): + idx = tid + tile_i * BLOCK_THREADS + x = _load_vec(copy_atom, VEC_WIDTH, elem_dtype, in_div, idx).to(fx.Float32) + residual = _load_vec(copy_atom, VEC_WIDTH, elem_dtype, residual_in_div, idx).to(fx.Float32) + added_e = _to_elem_vec(dtype_str, elem_dtype, USE_HW_CVT_PK_BF16_F32, x + residual) + added_local.append(added_e) + added = added_e.to(fx.Float32) + added2 = added * added + thread_sum = thread_sum + added.reduce(ReductionOp.ADD, fastmath=fm_fast) + thread_sumsq = thread_sumsq + added2.reduce(ReductionOp.ADD, fastmath=fm_fast) + _store_vec(copy_atom, VEC_WIDTH, elem_dtype, added_e, residual_out_div, idx) + + sum_val, sumsq_val = block_reduce_add2(thread_sum, thread_sumsq) + mean, rstd = compute_mean_rstd(sum_val, sumsq_val) + + # Pass 2: normalize + affine + store, reusing cached added values. + for tile_i in range_constexpr(num_tiles_py): + idx = tid + tile_i * BLOCK_THREADS + added = added_local[tile_i].to(fx.Float32) + g = _load_vec(copy_atom, VEC_WIDTH, elem_dtype, gamma_div, idx).to(fx.Float32) + b = _load_vec(copy_atom, VEC_WIDTH, elem_dtype, beta_div, idx).to(fx.Float32) y = (added - mean) * rstd y = y * g + b - if const_expr(dtype_str == "f32"): - y_e = y - else: - y_e = y.to(elem_dtype) - _store_scalar(out_div, idx, y_e) + y_e = _to_elem_vec(dtype_str, elem_dtype, USE_HW_CVT_PK_BF16_F32, y) + _store_vec(copy_atom, VEC_WIDTH, elem_dtype, y_e, out_div, idx) + + else: + # ============================================================== + # Generic path: scalar 2-pass implementation for arbitrary N + # ============================================================== + Input_buf = fx.rocdl.make_buffer_tensor(Input) + ResidualIn_buf = fx.rocdl.make_buffer_tensor(ResidualIn) + Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) + Beta_buf = fx.rocdl.make_buffer_tensor(Beta) + Output_buf = fx.rocdl.make_buffer_tensor(Output) + ResidualOut_buf = fx.rocdl.make_buffer_tensor(ResidualOut) + + row_in = fx.slice(Input_buf, (bid, None)) + row_residual_in = fx.slice(ResidualIn_buf, (bid, None)) + row_out = fx.slice(Output_buf, (bid, None)) + row_residual_out = fx.slice(ResidualOut_buf, (bid, None)) + + copy_atom_s = fx.make_copy_atom( + fx.rocdl.BufferCopy16b() if elem_bits <= 16 else fx.rocdl.BufferCopy32b(), + elem_bits, + ) + + in_div = fx.logical_divide(row_in, fx.make_layout(1, 1)) + residual_in_div = fx.logical_divide(row_residual_in, fx.make_layout(1, 1)) + gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(1, 1)) + beta_div = fx.logical_divide(Beta_buf, fx.make_layout(1, 1)) + out_div = fx.logical_divide(row_out, fx.make_layout(1, 1)) + residual_out_div = fx.logical_divide(row_residual_out, fx.make_layout(1, 1)) + + c_zero_f = fx.Float32(0.0) + thread_sum = c_zero_f + thread_sumsq = c_zero_f + + for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): + idx = tid + base_idx_int + is_valid = idx < N + idx_safe = is_valid.select(idx, 0) + x_e = _load_scalar(copy_atom_s, elem_dtype, in_div, idx_safe) + r_e = _load_scalar(copy_atom_s, elem_dtype, residual_in_div, idx_safe) + x = x_e if dtype_str == "f32" else x_e.to(fx.Float32) + residual = r_e if dtype_str == "f32" else r_e.to(fx.Float32) + added_e = _to_elem_scalar(dtype_str, elem_dtype, x + residual) + added = added_e if dtype_str == "f32" else added_e.to(fx.Float32) + added_safe = is_valid.select(added, c_zero_f) + thread_sum = thread_sum + added_safe + thread_sumsq = thread_sumsq + is_valid.select(added * added, c_zero_f) + if idx < N: + _store_scalar(copy_atom_s, elem_dtype, elem_dtype, residual_out_div, idx, added_e) + + sum_val, sumsq_val = block_reduce_add2(thread_sum, thread_sumsq) + mean, rstd = compute_mean_rstd(sum_val, sumsq_val) + + for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): + idx = tid + base_idx_int + if idx < N: + added_e = _load_scalar(copy_atom_s, elem_dtype, residual_out_div, idx) + g_e = _load_scalar(copy_atom_s, elem_dtype, gamma_div, idx) + b_e = _load_scalar(copy_atom_s, elem_dtype, beta_div, idx) + added = added_e if dtype_str == "f32" else added_e.to(fx.Float32) + g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) + b = b_e if dtype_str == "f32" else b_e.to(fx.Float32) + y = (added - mean) * rstd + y = y * g + b + y_e = _to_elem_scalar(dtype_str, elem_dtype, y) + _store_scalar(copy_atom_s, elem_dtype, elem_dtype, out_div, idx, y_e) @flyc.jit def launch_fused_add_layernorm( @@ -511,14 +572,14 @@ def _build_layernorm_quant_module( is_fused_add: bool, quant_dtype_str: str = "i8", ): + arch = get_hip_arch() + USE_HW_CVT_PK_BF16_F32 = (arch == "gfx950") or str(arch).startswith("gfx95") + RED_SLOTS = max(1, (BLOCK_THREADS + WARP_SIZE - 1) // WARP_SIZE) elem_bits = 32 if dtype_str == "f32" else 16 quant_dtype_max = _quant_dtype_max(quant_dtype_str) - @fx.struct - class SharedStorage: - s_sum: fx.Array[fx.Float32, RED_SLOTS, 16] - s_sumsq: fx.Array[fx.Float32, RED_SLOTS, 16] + SharedStorage = _make_reduction_storage(RED_SLOTS) @flyc.kernel def layernorm_quant_kernel( @@ -553,12 +614,6 @@ def layernorm_quant_kernel( yscale_div = fx.logical_divide(YScale_buf, fx.make_layout(1, 1)) scale_copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), 32) - def _store_yscale(index, val): - r = fx.make_rmem_tensor(1, fx.Float32) - ts = full(1, fx.Float32(val), fx.Float32) - fx.memref_store_vec(ts, r) - fx.copy_atom_call(scale_copy_atom, r, fx.slice(yscale_div, (None, index))) - def wave_reduce_add(x): w = x for _sh_exp in range_constexpr(int(math.log2(WARP_SIZE))): @@ -628,145 +683,239 @@ def block_reduce_max(val): return fx.memref_load(s_sum, 0) - Input_buf = fx.rocdl.make_buffer_tensor(Input) - Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) - Beta_buf = fx.rocdl.make_buffer_tensor(Beta) - Output_buf = fx.rocdl.make_buffer_tensor(Output) - if const_expr(is_fused_add): - ResidualIn_buf = fx.rocdl.make_buffer_tensor(ResidualIn) - ResidualOut_buf = fx.rocdl.make_buffer_tensor(ResidualOut) - if const_expr(is_smooth): - XScale_buf = fx.rocdl.make_buffer_tensor(XScale) + # ================================================================== + # Fast path: N == BLOCK_THREADS * VEC_WIDTH * 4 + # ================================================================== + if const_expr(N == (BLOCK_THREADS * VEC_WIDTH * 4) and elem_bits <= 16): + num_tiles_py = 4 + quant_half_width = VEC_WIDTH // 2 + xscale_vec_width = 4 + abs_mask = full(VEC_WIDTH, fx.Uint32(0x7FFFFFFF), fx.Uint32) - row_in = fx.slice(Input_buf, (bid, None)) - row_out = fx.slice(Output_buf, (bid, None)) - if const_expr(is_fused_add): - row_residual_in = fx.slice(ResidualIn_buf, (bid, None)) - row_residual_out = fx.slice(ResidualOut_buf, (bid, None)) + Input_buf = fx.rocdl.make_buffer_tensor(Input) + Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) + Beta_buf = fx.rocdl.make_buffer_tensor(Beta) + Output_buf = fx.rocdl.make_buffer_tensor(Output) + if const_expr(is_fused_add): + ResidualIn_buf = fx.rocdl.make_buffer_tensor(ResidualIn) + ResidualOut_buf = fx.rocdl.make_buffer_tensor(ResidualOut) + if const_expr(is_smooth): + XScale_buf = fx.rocdl.make_buffer_tensor(XScale) - copy_atom_s = fx.make_copy_atom( - fx.rocdl.BufferCopy16b() if elem_bits <= 16 else fx.rocdl.BufferCopy32b(), - elem_bits, - ) - copy_atom_qs = fx.make_copy_atom(fx.rocdl.BufferCopy(8), 8) + row_in = fx.slice(Input_buf, (bid, None)) + row_out = fx.slice(Output_buf, (bid, None)) + if const_expr(is_fused_add): + row_residual_in = fx.slice(ResidualIn_buf, (bid, None)) + row_residual_out = fx.slice(ResidualOut_buf, (bid, None)) - in_div = fx.logical_divide(row_in, fx.make_layout(1, 1)) - gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(1, 1)) - beta_div = fx.logical_divide(Beta_buf, fx.make_layout(1, 1)) - out_div = fx.logical_divide(row_out, fx.make_layout(1, 1)) - if const_expr(is_fused_add): - residual_in_div = fx.logical_divide(row_residual_in, fx.make_layout(1, 1)) - residual_out_div = fx.logical_divide(row_residual_out, fx.make_layout(1, 1)) - if const_expr(is_smooth): - xscale_div = fx.logical_divide(XScale_buf, fx.make_layout(1, 1)) - - def _load_scalar(divided_tensor, index): - view = fx.slice(divided_tensor, (None, index)) - r = fx.make_rmem_tensor(1, elem_dtype) - fx.copy_atom_call(copy_atom_s, view, r) - return fx.memref_load_vec(r)[0] - - def _store_elem_scalar(divided_tensor, index, val): - r = fx.make_rmem_tensor(1, elem_dtype) - ts = full(1, elem_dtype(val), elem_dtype) - fx.memref_store_vec(ts, r) - view = fx.slice(divided_tensor, (None, index)) - fx.copy_atom_call(copy_atom_s, r, view) - - def _store_quant_scalar(divided_tensor, index, val): - r = fx.make_rmem_tensor(1, quant_dtype) - ts = full(1, quant_dtype(val), quant_dtype) - fx.memref_store_vec(ts, r) - view = fx.slice(divided_tensor, (None, index)) - fx.copy_atom_call(copy_atom_qs, r, view) - - def _abs_scalar(val): - is_neg = val < c_zero_f - neg_val = c_zero_f - val - return is_neg.select(neg_val, val) - - def _load_base_input_value(index): - x_e = _load_scalar(in_div, index) - return x_e if dtype_str == "f32" else x_e.to(fx.Float32) - - def _load_norm_input_value(index): + in_div = fx.logical_divide(row_in, fx.make_layout(VEC_WIDTH, 1)) + gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(VEC_WIDTH, 1)) + beta_div = fx.logical_divide(Beta_buf, fx.make_layout(VEC_WIDTH, 1)) + out_div_q = fx.logical_divide(row_out, fx.make_layout(quant_half_width, 1)) + if const_expr(is_fused_add): + residual_in_div = fx.logical_divide(row_residual_in, fx.make_layout(VEC_WIDTH, 1)) + residual_out_div = fx.logical_divide(row_residual_out, fx.make_layout(VEC_WIDTH, 1)) + if const_expr(is_smooth): + xscale_div = fx.logical_divide(XScale_buf, fx.make_layout(xscale_vec_width, 1)) + + copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), elem_bits) + copy_atom_q = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), 8) + if const_expr(is_smooth): + copy_atom_xs = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), 32) + + def _load_xscale_vec(idx): + s_lo = _load_vec(copy_atom_xs, xscale_vec_width, fx.Float32, xscale_div, idx * 2) + s_hi = _load_vec(copy_atom_xs, xscale_vec_width, fx.Float32, xscale_div, idx * 2 + 1) + return Vec(s_lo).shuffle(Vec(s_hi), [0, 1, 2, 3, 4, 5, 6, 7]).ir_value() + + thread_sum = c_zero_f + thread_sumsq = c_zero_f + norm_input_local = [] + + # Pass 1: prepare normalization input and accumulate sum/sumsq. + for tile_i in range_constexpr(num_tiles_py): + idx = tid + tile_i * BLOCK_THREADS + if const_expr(is_fused_add): + x = _load_vec(copy_atom, VEC_WIDTH, elem_dtype, in_div, idx).to(fx.Float32) + residual = _load_vec(copy_atom, VEC_WIDTH, elem_dtype, residual_in_div, idx).to(fx.Float32) + added_e = _to_elem_vec(dtype_str, elem_dtype, USE_HW_CVT_PK_BF16_F32, x + residual) + norm_input_local.append(added_e) + x_norm = added_e.to(fx.Float32) + _store_vec(copy_atom, VEC_WIDTH, elem_dtype, added_e, residual_out_div, idx) + else: + x_e = _load_vec(copy_atom, VEC_WIDTH, elem_dtype, in_div, idx) + norm_input_local.append(x_e) + x_norm = x_e.to(fx.Float32) + x2 = x_norm * x_norm + thread_sum = thread_sum + x_norm.reduce(ReductionOp.ADD, fastmath=fm_fast) + thread_sumsq = thread_sumsq + x2.reduce(ReductionOp.ADD, fastmath=fm_fast) + + sum_val, sumsq_val = block_reduce_add2(thread_sum, thread_sumsq) + mean = sum_val / n_float + var = sumsq_val / n_float - mean * mean + var = (var < c_zero_f).select(c_zero_f, var) + rstd = (var + eps_c).rsqrt(fastmath=fm_fast) + + thread_row_max = c_zero_f + y_local = [] + + # Pass 2: affine (+ optional smooth scale), cache y, accumulate row max. + for tile_i in range_constexpr(num_tiles_py): + idx = tid + tile_i * BLOCK_THREADS + x = norm_input_local[tile_i].to(fx.Float32) + g = _load_vec(copy_atom, VEC_WIDTH, elem_dtype, gamma_div, idx).to(fx.Float32) + b = _load_vec(copy_atom, VEC_WIDTH, elem_dtype, beta_div, idx).to(fx.Float32) + y = (x - mean) * rstd + y = y * g + b + if const_expr(is_smooth): + y = y * _load_xscale_vec(idx) + y_local.append(y) + y_abs = (y.bitcast(fx.Uint32) & abs_mask).bitcast(fx.Float32) + tile_max = y_abs.reduce(ReductionOp.MAX) + thread_row_max = thread_row_max.maximumf(tile_max) + + row_max = block_reduce_max(thread_row_max) + scale = row_max / c_dtype_max + final_scale = (scale == c_zero_f).select(c_one_f, scale) + + if tid == 0: + _store_yscale(scale_copy_atom, yscale_div, bid, final_scale) + + inv_scale = c_one_f / final_scale + + # Pass 3: quantize + store using per-row scale. + for tile_i in range_constexpr(num_tiles_py): + q = y_local[tile_i] * inv_scale + q_i8 = q.to(quant_dtype) + q_lo = q_i8.shuffle(q_i8, [0, 1, 2, 3]) + q_hi = q_i8.shuffle(q_i8, [4, 5, 6, 7]) + out_idx = tid * 2 + tile_i * BLOCK_THREADS * 2 + _store_vec(copy_atom_q, quant_half_width, quant_dtype, q_lo, out_div_q, out_idx) + _store_vec(copy_atom_q, quant_half_width, quant_dtype, q_hi, out_div_q, out_idx + 1) + + else: + # ============================================================== + # Generic path: scalar 3-pass implementation for arbitrary N + # ============================================================== + Input_buf = fx.rocdl.make_buffer_tensor(Input) + Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) + Beta_buf = fx.rocdl.make_buffer_tensor(Beta) + Output_buf = fx.rocdl.make_buffer_tensor(Output) if const_expr(is_fused_add): - added_e = _load_scalar(residual_out_div, index) - return added_e if dtype_str == "f32" else added_e.to(fx.Float32) - return _load_base_input_value(index) + ResidualIn_buf = fx.rocdl.make_buffer_tensor(ResidualIn) + ResidualOut_buf = fx.rocdl.make_buffer_tensor(ResidualOut) + if const_expr(is_smooth): + XScale_buf = fx.rocdl.make_buffer_tensor(XScale) - thread_sum = c_zero_f - thread_sumsq = c_zero_f + row_in = fx.slice(Input_buf, (bid, None)) + row_out = fx.slice(Output_buf, (bid, None)) + if const_expr(is_fused_add): + row_residual_in = fx.slice(ResidualIn_buf, (bid, None)) + row_residual_out = fx.slice(ResidualOut_buf, (bid, None)) - for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): - idx = tid + base_idx_int - is_valid = idx < N - idx_safe = is_valid.select(idx, 0) + copy_atom_s = fx.make_copy_atom( + fx.rocdl.BufferCopy16b() if elem_bits <= 16 else fx.rocdl.BufferCopy32b(), + elem_bits, + ) + copy_atom_qs = fx.make_copy_atom(fx.rocdl.BufferCopy(8), 8) + + in_div = fx.logical_divide(row_in, fx.make_layout(1, 1)) + gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(1, 1)) + beta_div = fx.logical_divide(Beta_buf, fx.make_layout(1, 1)) + out_div = fx.logical_divide(row_out, fx.make_layout(1, 1)) if const_expr(is_fused_add): - x = _load_base_input_value(idx_safe) - r_e = _load_scalar(residual_in_div, idx_safe) - residual = r_e if dtype_str == "f32" else r_e.to(fx.Float32) - added_e = (x + residual) if dtype_str == "f32" else (x + residual).to(elem_dtype) - if idx < N: - _store_elem_scalar(residual_out_div, idx, added_e) - x = added_e if dtype_str == "f32" else added_e.to(fx.Float32) - else: - x = _load_norm_input_value(idx_safe) - x2 = x * x - thread_sum = thread_sum + is_valid.select(x, c_zero_f) - thread_sumsq = thread_sumsq + is_valid.select(x2, c_zero_f) - - sum_val, sumsq_val = block_reduce_add2(thread_sum, thread_sumsq) - mean = sum_val / n_float - var = sumsq_val / n_float - mean * mean - var = (var < c_zero_f).select(c_zero_f, var) - rstd = (var + eps_c).rsqrt(fastmath=fm_fast) - - thread_row_max = c_zero_f - for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): - idx = tid + base_idx_int - is_valid = idx < N - idx_safe = is_valid.select(idx, 0) - x = _load_norm_input_value(idx_safe) - g_e = _load_scalar(gamma_div, idx_safe) - b_e = _load_scalar(beta_div, idx_safe) - g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) - b = b_e if dtype_str == "f32" else b_e.to(fx.Float32) - y = (x - mean) * rstd - y = y * g + b + residual_in_div = fx.logical_divide(row_residual_in, fx.make_layout(1, 1)) + residual_out_div = fx.logical_divide(row_residual_out, fx.make_layout(1, 1)) if const_expr(is_smooth): - s_e = _load_scalar(xscale_div, idx_safe) - s = s_e if dtype_str == "f32" else s_e.to(fx.Float32) - y = y * s - y_abs = _abs_scalar(y) - thread_row_max = thread_row_max.maximumf(is_valid.select(y_abs, c_zero_f)) - - row_max = block_reduce_max(thread_row_max) - scale = row_max / c_dtype_max - final_scale = (scale == c_zero_f).select(c_one_f, scale) - - if tid == 0: - _store_yscale(bid, final_scale) - - inv_scale = c_one_f / final_scale - - for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): - idx = tid + base_idx_int - if idx < N: - x = _load_norm_input_value(idx) - g_e = _load_scalar(gamma_div, idx) - b_e = _load_scalar(beta_div, idx) + xscale_div = fx.logical_divide(XScale_buf, fx.make_layout(1, 1)) + + def _abs_scalar(val): + is_neg = val < c_zero_f + neg_val = c_zero_f - val + return is_neg.select(neg_val, val) + + def _load_base_input_value(index): + x_e = _load_scalar(copy_atom_s, elem_dtype, in_div, index) + return x_e if dtype_str == "f32" else x_e.to(fx.Float32) + + def _load_norm_input_value(index): + if const_expr(is_fused_add): + added_e = _load_scalar(copy_atom_s, elem_dtype, residual_out_div, index) + return added_e if dtype_str == "f32" else added_e.to(fx.Float32) + return _load_base_input_value(index) + + thread_sum = c_zero_f + thread_sumsq = c_zero_f + + for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): + idx = tid + base_idx_int + is_valid = idx < N + idx_safe = is_valid.select(idx, 0) + if const_expr(is_fused_add): + x = _load_base_input_value(idx_safe) + r_e = _load_scalar(copy_atom_s, elem_dtype, residual_in_div, idx_safe) + residual = r_e if dtype_str == "f32" else r_e.to(fx.Float32) + added_e = _to_elem_scalar(dtype_str, elem_dtype, x + residual) + if idx < N: + _store_scalar(copy_atom_s, elem_dtype, elem_dtype, residual_out_div, idx, added_e) + x = added_e if dtype_str == "f32" else added_e.to(fx.Float32) + else: + x = _load_norm_input_value(idx_safe) + x2 = x * x + thread_sum = thread_sum + is_valid.select(x, c_zero_f) + thread_sumsq = thread_sumsq + is_valid.select(x2, c_zero_f) + + sum_val, sumsq_val = block_reduce_add2(thread_sum, thread_sumsq) + mean = sum_val / n_float + var = sumsq_val / n_float - mean * mean + var = (var < c_zero_f).select(c_zero_f, var) + rstd = (var + eps_c).rsqrt(fastmath=fm_fast) + + thread_row_max = c_zero_f + for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): + idx = tid + base_idx_int + is_valid = idx < N + idx_safe = is_valid.select(idx, 0) + x = _load_norm_input_value(idx_safe) + g_e = _load_scalar(copy_atom_s, elem_dtype, gamma_div, idx_safe) + b_e = _load_scalar(copy_atom_s, elem_dtype, beta_div, idx_safe) g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) b = b_e if dtype_str == "f32" else b_e.to(fx.Float32) y = (x - mean) * rstd y = y * g + b if const_expr(is_smooth): - s_e = _load_scalar(xscale_div, idx) + s_e = _load_scalar(copy_atom_s, elem_dtype, xscale_div, idx_safe) s = s_e if dtype_str == "f32" else s_e.to(fx.Float32) y = y * s - q = y * inv_scale - q_i8 = q.to(quant_dtype) - _store_quant_scalar(out_div, idx, q_i8) + y_abs = _abs_scalar(y) + thread_row_max = thread_row_max.maximumf(is_valid.select(y_abs, c_zero_f)) + + row_max = block_reduce_max(thread_row_max) + scale = row_max / c_dtype_max + final_scale = (scale == c_zero_f).select(c_one_f, scale) + + if tid == 0: + _store_yscale(scale_copy_atom, yscale_div, bid, final_scale) + + inv_scale = c_one_f / final_scale + + for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): + idx = tid + base_idx_int + if idx < N: + x = _load_norm_input_value(idx) + g_e = _load_scalar(copy_atom_s, elem_dtype, gamma_div, idx) + b_e = _load_scalar(copy_atom_s, elem_dtype, beta_div, idx) + g = g_e if dtype_str == "f32" else g_e.to(fx.Float32) + b = b_e if dtype_str == "f32" else b_e.to(fx.Float32) + y = (x - mean) * rstd + y = y * g + b + if const_expr(is_smooth): + s_e = _load_scalar(copy_atom_s, elem_dtype, xscale_div, idx) + s = s_e if dtype_str == "f32" else s_e.to(fx.Float32) + y = y * s + q = y * inv_scale + q_i8 = q.to(quant_dtype) + _store_scalar(copy_atom_qs, quant_dtype, quant_dtype, out_div, idx, q_i8) if is_fused_add: if is_smooth: diff --git a/tests/kernels/benchmark_common.py b/tests/kernels/benchmark_common.py index 03ffc4b9e..13691b615 100644 --- a/tests/kernels/benchmark_common.py +++ b/tests/kernels/benchmark_common.py @@ -41,10 +41,10 @@ class PerfRow: aiter_gpu_us: Optional[float] @property - def speedup_aiter_vs_flydsl(self) -> Optional[float]: + def speedup_flydsl_vs_aiter(self) -> Optional[float]: if self.flydsl_gpu_us is None or self.aiter_gpu_us is None: return None - return self.flydsl_gpu_us / self.aiter_gpu_us + return self.aiter_gpu_us / self.flydsl_gpu_us def _fmt_us(x: Optional[float]) -> str: @@ -57,7 +57,7 @@ def print_perf_table(rows: List[PerfRow]) -> None: print("=" * 100) print(f"{'op':10s} {'shape':18s} {'dtype':6s} {'FlyDSL(gpu us)':>14s} {'AIter(gpu us)':>14s} {'speedup':>10s}") for r in rows: - sp = r.speedup_aiter_vs_flydsl + sp = r.speedup_flydsl_vs_aiter sp_s = "-" if sp is None else f"{sp:,.2f}x" print( f"{r.op:10s} {r.shape:18s} {r.dtype:6s} {_fmt_us(r.flydsl_gpu_us):>14s} {_fmt_us(r.aiter_gpu_us):>14s} {sp_s:>10s}" @@ -946,7 +946,7 @@ def main() -> None: print("=" * 100) print(f"{'op':10s} {'shape':18s} {'dtype':6s} {'FlyDSL(gpu us)':>14s} {'torch(gpu us)':>14s} {'speedup':>10s}") for r in wmma_rows: - sp = r.speedup_aiter_vs_flydsl + sp = r.speedup_flydsl_vs_aiter sp_s = "-" if sp is None else f"{sp:,.2f}x" print( f"{r.op:10s} {r.shape:18s} {r.dtype:6s} {_fmt_us(r.flydsl_gpu_us):>14s} {_fmt_us(r.aiter_gpu_us):>14s} {sp_s:>10s}" diff --git a/tests/kernels/test_rmsnorm.py b/tests/kernels/test_rmsnorm.py index 04eae3c92..e0a7f6407 100644 --- a/tests/kernels/test_rmsnorm.py +++ b/tests/kernels/test_rmsnorm.py @@ -53,6 +53,40 @@ BENCH_ITERS = 100 +def _torch_dtype(dtype: str): + if dtype == "f32": + return DTYPE_FP32 + if dtype == "f16": + return DTYPE_FP16 + if dtype == "bf16": + return DTYPE_BF16 + raise ValueError(f"unsupported dtype: {dtype}") + + +def _get_rmsnorm_configs(): + shapes_env = os.environ.get("ROCDSL_RMSNORM_SHAPES", "").strip() + if shapes_env: + configs = [] + for part in shapes_env.split(";"): + p = part.strip() + if not p: + continue + m_s, n_s, dt = [x.strip() for x in p.split(",")] + configs.append((int(m_s), int(n_s), dt)) + else: + # Prefer N multiples of 2048 to exercise the fast path. + configs = [ + (64, 256, "f32"), # Aligned + (128, 1024, "f32"), # Aligned + (32, 128, "f16"), # Aligned + (64, 2000, "f32"), # Unaligned (tail handling) + (16, 512, "bf16"), # BF16 + (1024, 8192, "bf16"), # BF16 + (32768, 8192, "bf16"), + ] + return configs + + def run_test(M: int, N: int, dtype: str = "f32"): print(f"\nTesting RMSNorm (M={M}, N={N}, dtype={dtype})") @@ -61,6 +95,7 @@ def run_test(M: int, N: int, dtype: str = "f32"): except Exception as e: print(f"[FAIL] Compile failed for (M={M}, N={N}, dtype={dtype}): {type(e).__name__}: {e}") return False, None + torch.manual_seed(42) input_t = torch.randn((M, N), device="cuda", dtype=DTYPE_FP32) gamma_t = torch.rand((N,), device="cuda", dtype=DTYPE_FP32) @@ -143,166 +178,6 @@ def kernel_launch(): return ok, flydsl_gpu_us -def test_all(): - print("=" * 80) - print("Running RMSNorm Tests") - print("=" * 80) - - shapes_env = os.environ.get("ROCDSL_RMSNORM_SHAPES", "").strip() - if shapes_env: - configs = [] - for part in shapes_env.split(";"): - p = part.strip() - if not p: - continue - m_s, n_s, dt = [x.strip() for x in p.split(",")] - configs.append((int(m_s), int(n_s), dt)) - else: - # Prefer N multiples of 2048 to exercise the fast path. - configs = [ - # (64, 256, "f32"), # Aligned - # (128, 1024, "f32"), # Aligned - # (32, 128, "f16"), # Aligned - # (64, 2000, "f32"), # Unaligned (tail handling) - # (16, 512, "bf16"), # BF16 - # (1024, 8192, "bf16"), # BF16 - (32768, 8192, "bf16"), - ] - - do_compare = os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1" - perf_rows = [] - - failures = 0 - for M, N, dtype in configs: - ok, flydsl_gpu_us = run_test(M, N, dtype) - if not ok: - failures += 1 - - if do_compare: - import torch - - aiter_us = None - if maybe_enable_aiter(): - try: - from aiter.ops.triton.rmsnorm import rms_norm as aiter_rms_norm - - x = torch.randn( - (M, N), - device="cuda", - dtype=DTYPE_BF16 if dtype == "bf16" else (DTYPE_FP16 if dtype == "f16" else DTYPE_FP32), - ) - w = torch.rand((N,), device="cuda", dtype=x.dtype) - - def run_aiter(): - aiter_rms_norm(x, w, EPS) - - aiter_us = bench_gpu_us_torch(run_aiter, warmup=WARMUP_ITERS, iters=BENCH_ITERS) - print(f"[Perf] AIter rmsnorm gpu: {aiter_us:.1f} us") - except Exception as e: - print(f"[Perf] AIter rmsnorm skipped: {type(e).__name__}: {e!r}") - - perf_rows.append( - PerfRow(op="rmsnorm", shape=f"{M}x{N}", dtype=dtype, flydsl_gpu_us=flydsl_gpu_us, aiter_gpu_us=aiter_us) - ) - - print("\n" + "=" * 80) - if failures == 0: - print("ALL TESTS PASSED") - else: - print(f"{failures} TESTS FAILED") - print("=" * 80) - if do_compare and perf_rows: - print_perf_table(perf_rows) - # Ensure a non-zero exit code on failure for shell wrappers. - if failures != 0: - raise SystemExit(1) - - -def _torch_dtype(dtype: str): - if dtype == "f32": - return DTYPE_FP32 - if dtype == "f16": - return DTYPE_FP16 - if dtype == "bf16": - return DTYPE_BF16 - raise ValueError(f"unsupported dtype: {dtype}") - - -def _get_rmsnorm_configs(): - shapes_env = os.environ.get("ROCDSL_RMSNORM_SHAPES", "").strip() - if shapes_env: - configs = [] - for part in shapes_env.split(";"): - p = part.strip() - if not p: - continue - m_s, n_s, dt = [x.strip() for x in p.split(",")] - configs.append((int(m_s), int(n_s), dt)) - return configs - - # Prefer N multiples of 2048 to exercise the fast path. - return [ - # (64, 256, "f32"), # Aligned - # (128, 1024, "f32"), # Aligned - # (32, 128, "f16"), # Aligned - # (64, 2000, "f32"), # Unaligned (tail handling) - # (16, 512, "bf16"), # BF16 - # (1024, 8192, "bf16"), # BF16 - (32768, 8192, "bf16"), - ] - - -def _reference_rmsnorm_quant(input_dev, gamma_dev, *, xscale_dev=None): - x = input_dev.to(DTYPE_FP32) - gamma = gamma_dev.to(DTYPE_FP32) - expected = (x / torch.sqrt((x * x).mean(dim=1, keepdim=True) + EPS)) * gamma - if xscale_dev is not None: - expected = expected * xscale_dev.to(DTYPE_FP32) - - yscale = expected.abs().amax(dim=1) / 127.0 - yscale = torch.where(yscale == 0, torch.ones_like(yscale), yscale) - q = torch.clamp(torch.trunc(expected / yscale.unsqueeze(1)), -127, 127).to(torch.int8) - return expected, q, yscale - - -def _bench_aiter_rmsnorm_quant(M: int, N: int, dtype: str, *, is_smooth: bool): - mode = "smoothquant" if is_smooth else "dynamicquant" - torch_dtype = _torch_dtype(dtype) - - try: - if is_smooth: - from aiter.ops.triton.normalization.rmsnorm import ( - rmsnorm2d_fwd_with_smoothquant as aiter_rmsnorm_quant, - ) - else: - from aiter.ops.triton.normalization.rmsnorm import ( - rmsnorm2d_fwd_with_dynamicquant as aiter_rmsnorm_quant, - ) - except Exception as e: - print(f"[Perf] AIter rmsnorm {mode} skipped: {type(e).__name__}: {e!r}") - return None - - x = torch.randn((M, N), device="cuda", dtype=torch_dtype).contiguous() - w = torch.rand((N,), device="cuda", dtype=torch_dtype).contiguous() - y = torch.empty((M, N), dtype=torch.int8, device="cuda") - yscale = torch.empty((M, 1), dtype=torch.float32, device="cuda") - - if is_smooth: - xscale = (torch.rand((N,), device="cuda", dtype=DTYPE_FP32) + 0.5).contiguous() - - def run_aiter(): - aiter_rmsnorm_quant(y, x, xscale, yscale, w, EPS) - - else: - - def run_aiter(): - aiter_rmsnorm_quant(y, x, yscale, w, EPS) - - aiter_us = bench_gpu_us_torch(run_aiter, warmup=WARMUP_ITERS, iters=BENCH_ITERS) - print(f"[Perf] AIter rmsnorm {mode} gpu: {aiter_us:.1f} us") - return aiter_us - - def run_quant_test(M: int, N: int, dtype: str, *, is_smooth: bool): mode = "smoothquant" if is_smooth else "dynamicquant" print(f"\nTesting RMSNorm {mode} (M={M}, N={N}, dtype={dtype})") @@ -315,6 +190,7 @@ def run_quant_test(M: int, N: int, dtype: str, *, is_smooth: bool): except Exception as e: print(f"[FAIL] Compile failed for {mode} (M={M}, N={N}, dtype={dtype}): " f"{type(e).__name__}: {e}") return False, None + torch.manual_seed(42) input_t = torch.randn((M, N), device="cuda", dtype=DTYPE_FP32) gamma_t = torch.rand((N,), device="cuda", dtype=DTYPE_FP32) @@ -336,17 +212,7 @@ def run_quant_test(M: int, N: int, dtype: str, *, is_smooth: bool): xscale_dev = None if is_smooth: xscale_dev = (torch.rand((N,), device="cuda", dtype=DTYPE_FP32) + 0.5).contiguous() - dequant_tol = 0.25 if is_smooth else 0.2 - scale_tol = 1e-2 if is_smooth else 5e-3 - - # PyTorch Reference: - # RMS(x) = sqrt(mean(x^2) + eps) ; RMSNorm(x) = x / RMS(x) * gamma - # Quant path additionally computes per-row yscale and int8 output from the fp32 reference. - expected, q_ref, yscale_ref = _reference_rmsnorm_quant( - input_dev, - gamma_dev, - xscale_dev=xscale_dev, - ) + scale_tol = 1e-3 print("Launching kernel...") stream = torch.cuda.current_stream() @@ -381,30 +247,30 @@ def kernel_launch(): if flydsl_gpu_us is not None: print(f"[Perf] FlyDSL rmsnorm {mode} gpu: {flydsl_gpu_us:.1f} us") + # PyTorch Reference: + # RMS(x) = sqrt(mean(x^2) + eps) ; RMSNorm(x) = x / RMS(x) * gamma + # Quant path additionally computes per-row yscale and int8 output from the fp32 reference. + q_ref, yscale_ref = _reference_rmsnorm_quant( + input_dev, + gamma_dev, + xscale_dev=xscale_dev, + ) q_out = output_dev.to(torch.int16) q_expected = q_ref.to(torch.int16) yscale_out = yscale_dev.cpu() yscale_expected = yscale_ref.cpu() - output_ref = output_dev.to(DTYPE_FP32) * yscale_dev.unsqueeze(1) - error = (output_ref - expected).abs().max().item() - scale_diff = (yscale_out - yscale_expected).abs().max().item() - quant_diff = (q_out - q_expected).abs().max().item() + quant_error = (q_out - q_expected).abs().max().item() + scale_error = (yscale_out - yscale_expected).abs().max().item() - print(f"Max dequant error: {error:.2e} (tol={dequant_tol})") - print(f"Max scale diff: {scale_diff:.2e} (tol={scale_tol})") - print(f"Max quant diff: {quant_diff}") + print(f"Max quant diff: {quant_error}") + print(f"Max scale diff: {scale_error:.2e} (tol={scale_tol})") - ok = error < dequant_tol and scale_diff < scale_tol and quant_diff <= 1 + ok = quant_error <= 1 and scale_error < scale_tol if ok: print("PASSED") - ok = True else: print("FAILED") - print("First row Expected:") - print(expected[0, :5]) - print("First row Actual:") - print(output_ref[0, :5]) print("First row Quant Expected:") print(q_expected[0, :8]) print("First row Quant Actual:") @@ -413,168 +279,49 @@ def kernel_launch(): print(yscale_expected[:5]) print("First few YScale Actual:") print(yscale_out[:5]) - ok = False return ok, flydsl_gpu_us -def test_rmsnorm_dynamicquant(): - print("=" * 80) - print("Running RMSNorm DynamicQuant Tests") - print("=" * 80) - - do_compare = os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1" - perf_rows = [] +def run_fused_add_test(M: int, N: int, dtype: str): + print(f"\nTesting FusedAdd RMSNorm (M={M}, N={N}, dtype={dtype})") - failures = 0 - for M, N, dtype in _get_rmsnorm_configs(): - ok, flydsl_gpu_us = run_quant_test(M, N, dtype, is_smooth=False) - if not ok: - failures += 1 + try: + launch_fn = build_fused_add_rmsnorm_module(M, N, dtype) + except Exception as e: + print(f"[FAIL] Compile failed for fused_add rmsnorm (M={M}, N={N}, dtype={dtype}): " f"{type(e).__name__}: {e}") + return False, None - if do_compare: - aiter_us = None - if maybe_enable_aiter(): - aiter_us = _bench_aiter_rmsnorm_quant(M, N, dtype, is_smooth=False) - perf_rows.append( - PerfRow( - op="rmsnorm_dq", - shape=f"{M}x{N}", - dtype=dtype, - flydsl_gpu_us=flydsl_gpu_us, - aiter_gpu_us=aiter_us, - ) - ) + torch.manual_seed(42) + input_t = torch.randn((M, N), device="cuda", dtype=DTYPE_FP32) + residual_t = torch.randn((M, N), device="cuda", dtype=DTYPE_FP32) + gamma_t = torch.rand((N,), device="cuda", dtype=DTYPE_FP32) - print("\n" + "=" * 80) - if failures == 0: - print("ALL TESTS PASSED") + if dtype == "f32": + input_dev = input_t.contiguous() + residual_in_dev = residual_t.contiguous() + gamma_dev = gamma_t.contiguous() + output_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP32) + residual_out_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP32) + atol = 1e-4 + elif dtype == "f16": + input_dev = input_t.to(DTYPE_FP16).contiguous() + residual_in_dev = residual_t.to(DTYPE_FP16).contiguous() + gamma_dev = gamma_t.to(DTYPE_FP16).contiguous() + output_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP16) + residual_out_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP16) + atol = 1e-2 + elif dtype == "bf16": + input_dev = input_t.to(DTYPE_BF16).contiguous() + residual_in_dev = residual_t.to(DTYPE_BF16).contiguous() + gamma_dev = gamma_t.to(DTYPE_BF16).contiguous() + output_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_BF16) + residual_out_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_BF16) + atol = 2e-2 else: - print(f"{failures} TESTS FAILED") - print("=" * 80) - if do_compare and perf_rows: - print_perf_table(perf_rows) - # Ensure a non-zero exit code on failure for shell wrappers. - if failures != 0: - raise SystemExit(1) - + raise ValueError(f"unsupported dtype: {dtype}") -def test_rmsnorm_smoothquant(): - print("=" * 80) - print("Running RMSNorm SmoothQuant Tests") - print("=" * 80) - - do_compare = os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1" - perf_rows = [] - failures = 0 - - for M, N, dtype in _get_rmsnorm_configs(): - ok, flydsl_gpu_us = run_quant_test(M, N, dtype, is_smooth=True) - if not ok: - failures += 1 - - if do_compare: - aiter_us = None - if maybe_enable_aiter(): - aiter_us = _bench_aiter_rmsnorm_quant(M, N, dtype, is_smooth=True) - perf_rows.append( - PerfRow( - op="rmsnorm_sq", - shape=f"{M}x{N}", - dtype=dtype, - flydsl_gpu_us=flydsl_gpu_us, - aiter_gpu_us=aiter_us, - ) - ) - - print("\n" + "=" * 80) - if failures == 0: - print("ALL TESTS PASSED") - else: - print(f"{failures} TESTS FAILED") - print("=" * 80) - if do_compare and perf_rows: - print_perf_table(perf_rows) - # Ensure a non-zero exit code on failure for shell wrappers. - if failures != 0: - raise SystemExit(1) - - -def _reference_fused_add_rmsnorm(input_dev, residual_in_dev, gamma_dev): - added = input_dev + residual_in_dev - added_fp32 = added.to(DTYPE_FP32) - gamma = gamma_dev.to(DTYPE_FP32) - expected = (added_fp32 / torch.sqrt((added_fp32 * added_fp32).mean(dim=1, keepdim=True) + EPS)) * gamma - return added_fp32, expected - - -def _bench_aiter_fused_add_rmsnorm(M: int, N: int, dtype: str): - torch_dtype = _torch_dtype(dtype) - - try: - from aiter.ops.triton.normalization.rmsnorm import ( - rmsnorm2d_fwd_with_add as aiter_fused_add_rmsnorm, - ) - except Exception as e: - print(f"[Perf] AIter fused_add rmsnorm skipped: {type(e).__name__}: {e!r}") - return None - - x = torch.randn((M, N), device="cuda", dtype=torch_dtype).contiguous() - residual_in = torch.randn((M, N), device="cuda", dtype=torch_dtype).contiguous() - w = torch.rand((N,), device="cuda", dtype=torch_dtype).contiguous() - out = torch.empty((M, N), device="cuda", dtype=torch_dtype) - residual_out = torch.empty((M, N), device="cuda", dtype=torch_dtype) - - def run_aiter(): - aiter_fused_add_rmsnorm(out, x, residual_in, residual_out, w, EPS) - - aiter_us = bench_gpu_us_torch(run_aiter, warmup=WARMUP_ITERS, iters=BENCH_ITERS) - print(f"[Perf] AIter fused_add rmsnorm gpu: {aiter_us:.1f} us") - return aiter_us - - -def run_fused_add_test(M: int, N: int, dtype: str): - print(f"\nTesting FusedAdd RMSNorm (M={M}, N={N}, dtype={dtype})") - - try: - launch_fn = build_fused_add_rmsnorm_module(M, N, dtype) - except Exception as e: - print(f"[FAIL] Compile failed for fused_add rmsnorm (M={M}, N={N}, dtype={dtype}): " f"{type(e).__name__}: {e}") - return False, None - - torch.manual_seed(42) - input_t = torch.randn((M, N), device="cuda", dtype=DTYPE_FP32) - residual_t = torch.randn((M, N), device="cuda", dtype=DTYPE_FP32) - gamma_t = torch.rand((N,), device="cuda", dtype=DTYPE_FP32) - - if dtype == "f32": - input_dev = input_t.contiguous() - residual_in_dev = residual_t.contiguous() - gamma_dev = gamma_t.contiguous() - output_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP32) - residual_out_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP32) - output_atol = 1e-4 - residual_atol = 1e-4 - elif dtype == "f16": - input_dev = input_t.to(DTYPE_FP16).contiguous() - residual_in_dev = residual_t.to(DTYPE_FP16).contiguous() - gamma_dev = gamma_t.to(DTYPE_FP16).contiguous() - output_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP16) - residual_out_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_FP16) - output_atol = 1e-2 - residual_atol = 1e-2 - elif dtype == "bf16": - input_dev = input_t.to(DTYPE_BF16).contiguous() - residual_in_dev = residual_t.to(DTYPE_BF16).contiguous() - gamma_dev = gamma_t.to(DTYPE_BF16).contiguous() - output_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_BF16) - residual_out_dev = torch.empty((M, N), device="cuda", dtype=DTYPE_BF16) - output_atol = 2e-2 - residual_atol = 2e-2 - else: - raise ValueError(f"unsupported dtype: {dtype}") - - print("Launching kernel...") - stream = torch.cuda.current_stream() + print("Launching kernel...") + stream = torch.cuda.current_stream() def kernel_launch(): launch_fn( @@ -607,6 +354,8 @@ def kernel_launch(): if flydsl_gpu_us is not None: print(f"[Perf] FlyDSL fused_add rmsnorm gpu: {flydsl_gpu_us:.1f} us") + # PyTorch Reference: + # RMS(x) = sqrt(mean(x^2) + eps) ; RMSNorm(x) = x / RMS(x) * gamma residual_expected, output_expected = _reference_fused_add_rmsnorm( input_dev, residual_in_dev, @@ -618,121 +367,25 @@ def kernel_launch(): residual_error = (residual_out_ref - residual_expected).abs().max().item() output_error = (output_ref - output_expected).abs().max().item() - print(f"Max residual error: {residual_error:.2e} (atol={residual_atol})") - print(f"Max output error: {output_error:.2e} (atol={output_atol})") + print(f"Max residual error: {residual_error:.2e} (atol={atol})") + print(f"Max output error: {output_error:.2e} (atol={atol})") - ok = residual_error < residual_atol and output_error < output_atol + ok = residual_error < atol and output_error < atol if ok: print("PASSED") else: print("FAILED") + print("First row Residual Expected:") + print(residual_expected[0, :5]) + print("First row Residual Actual:") + print(residual_out_ref[0, :5]) + print("First row Output Expected:") + print(output_expected[0, :5]) + print("First row Output Actual:") + print(output_ref[0, :5]) return ok, flydsl_gpu_us -def test_rmsnorm_fused_add(): - print("=" * 80) - print("Running FusedAdd RMSNorm Tests") - print("=" * 80) - - do_compare = os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1" - perf_rows = [] - failures = 0 - - for M, N, dtype in _get_rmsnorm_configs(): - ok, flydsl_gpu_us = run_fused_add_test(M, N, dtype) - if not ok: - failures += 1 - - if do_compare: - aiter_us = None - if maybe_enable_aiter(): - aiter_us = _bench_aiter_fused_add_rmsnorm(M, N, dtype) - perf_rows.append( - PerfRow( - op="rmsnorm_add", - shape=f"{M}x{N}", - dtype=dtype, - flydsl_gpu_us=flydsl_gpu_us, - aiter_gpu_us=aiter_us, - ) - ) - - print("\n" + "=" * 80) - if failures == 0: - print("ALL TESTS PASSED") - else: - print(f"{failures} TESTS FAILED") - print("=" * 80) - if do_compare and perf_rows: - print_perf_table(perf_rows) - # Ensure a non-zero exit code on failure for shell wrappers. - if failures != 0: - raise SystemExit(1) - - -def _reference_fused_add_rmsnorm_quant( - input_dev, - residual_in_dev, - gamma_dev, - *, - xscale_dev=None, -): - added = input_dev + residual_in_dev - residual_expected = added.to(DTYPE_FP32) - expected, q, yscale = _reference_rmsnorm_quant( - added, - gamma_dev, - xscale_dev=xscale_dev, - ) - return residual_expected, expected, q, yscale - - -def _bench_aiter_fused_add_rmsnorm_quant( - M: int, - N: int, - dtype: str, - *, - is_smooth: bool, -): - mode = "smoothquant" if is_smooth else "dynamicquant" - torch_dtype = _torch_dtype(dtype) - - try: - if is_smooth: - from aiter.ops.triton.normalization.rmsnorm import ( - rmsnorm2d_fwd_with_add_smoothquant as aiter_fused_add_rmsnorm_quant, - ) - else: - from aiter.ops.triton.normalization.rmsnorm import ( - rmsnorm2d_fwd_with_add_dynamicquant as aiter_fused_add_rmsnorm_quant, - ) - except Exception as e: - print(f"[Perf] AIter fused_add rmsnorm {mode} skipped: {type(e).__name__}: {e!r}") - return None - - x = torch.randn((M, N), device="cuda", dtype=torch_dtype).contiguous() - residual_in = torch.randn((M, N), device="cuda", dtype=torch_dtype).contiguous() - w = torch.rand((N,), device="cuda", dtype=torch_dtype).contiguous() - y = torch.empty((M, N), dtype=torch.int8, device="cuda") - residual_out = torch.empty((M, N), device="cuda", dtype=torch_dtype) - yscale = torch.empty((M, 1), dtype=torch.float32, device="cuda") - - if is_smooth: - xscale = (torch.rand((N,), device="cuda", dtype=DTYPE_FP32) + 0.5).contiguous() - - def run_aiter(): - aiter_fused_add_rmsnorm_quant(y, x, residual_in, residual_out, xscale, yscale, w, EPS) - - else: - - def run_aiter(): - aiter_fused_add_rmsnorm_quant(y, x, residual_in, residual_out, yscale, w, EPS) - - aiter_us = bench_gpu_us_torch(run_aiter, warmup=WARMUP_ITERS, iters=BENCH_ITERS) - print(f"[Perf] AIter fused_add rmsnorm {mode} gpu: {aiter_us:.1f} us") - return aiter_us - - def run_fused_add_quant_test(M: int, N: int, dtype: str, *, is_smooth: bool): mode = "smoothquant" if is_smooth else "dynamicquant" print(f"\nTesting FusedAdd RMSNorm {mode} (M={M}, N={N}, dtype={dtype})") @@ -780,15 +433,7 @@ def run_fused_add_quant_test(M: int, N: int, dtype: str, *, is_smooth: bool): xscale_dev = None if is_smooth: xscale_dev = (torch.rand((N,), device="cuda", dtype=DTYPE_FP32) + 0.5).contiguous() - dequant_tol = 0.25 if is_smooth else 0.2 - scale_tol = 1e-2 if is_smooth else 5e-3 - - residual_expected, expected, q_ref, yscale_ref = _reference_fused_add_rmsnorm_quant( - input_dev, - residual_in_dev, - gamma_dev, - xscale_dev=xscale_dev, - ) + scale_tol = 1e-3 print("Launching kernel...") stream = torch.cuda.current_stream() @@ -840,24 +485,29 @@ def kernel_launch(): if flydsl_gpu_us is not None: print(f"[Perf] FlyDSL fused_add rmsnorm {mode} gpu: {flydsl_gpu_us:.1f} us") + # PyTorch Reference: + # RMS(x) = sqrt(mean(x^2) + eps) ; RMSNorm(x) = x / RMS(x) * gamma + residual_expected, q_ref, yscale_ref = _reference_fused_add_rmsnorm_quant( + input_dev, + residual_in_dev, + gamma_dev, + xscale_dev=xscale_dev, + ) residual_out_ref = residual_out_dev.to(DTYPE_FP32) - output_ref = output_dev.to(DTYPE_FP32) * yscale_dev.unsqueeze(1) q_out = output_dev.to(torch.int16) q_expected = q_ref.to(torch.int16) yscale_out = yscale_dev.cpu() yscale_expected = yscale_ref.cpu() residual_error = (residual_out_ref - residual_expected).abs().max().item() - dequant_error = (output_ref - expected).abs().max().item() - scale_diff = (yscale_out - yscale_expected).abs().max().item() - quant_diff = (q_out - q_expected).abs().max().item() + scale_error = (yscale_out - yscale_expected).abs().max().item() + quant_error = (q_out - q_expected).abs().max().item() print(f"Max residual error: {residual_error:.2e} (tol={residual_atol})") - print(f"Max dequant error: {dequant_error:.2e} (tol={dequant_tol})") - print(f"Max scale diff: {scale_diff:.2e} (tol={scale_tol})") - print(f"Max quant diff: {quant_diff}") + print(f"Max scale error: {scale_error:.2e} (tol={scale_tol})") + print(f"Max quant error: {quant_error}") - ok = residual_error < residual_atol and dequant_error < dequant_tol and scale_diff < scale_tol and quant_diff <= 1 + ok = residual_error < residual_atol and scale_error < scale_tol and quant_error <= 1 if ok: print("PASSED") else: @@ -866,10 +516,6 @@ def kernel_launch(): print(residual_expected[0, :5]) print("First row Residual Actual:") print(residual_out_ref[0, :5]) - print("First row Expected:") - print(expected[0, :5]) - print("First row Actual:") - print(output_ref[0, :5]) print("First row Quant Expected:") print(q_expected[0, :8]) print("First row Quant Actual:") @@ -881,56 +527,383 @@ def kernel_launch(): return ok, flydsl_gpu_us -def test_rmsnorm_fused_add_dynamicquant(): - print("=" * 80) - print("Running FusedAdd RMSNorm DynamicQuant Tests") - print("=" * 80) - - do_compare = os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1" - perf_rows = [] - failures = 0 +def _reference_rmsnorm_quant(input_dev, gamma_dev, *, xscale_dev=None): + x = input_dev.to(DTYPE_FP32) + gamma = gamma_dev.to(DTYPE_FP32) + normalized = (x / torch.sqrt((x * x).mean(dim=1, keepdim=True) + EPS)) * gamma + if xscale_dev is not None: + normalized = normalized * xscale_dev.to(DTYPE_FP32) - for M, N, dtype in _get_rmsnorm_configs(): - ok, flydsl_gpu_us = run_fused_add_quant_test(M, N, dtype, is_smooth=False) - if not ok: - failures += 1 + yscale = normalized.abs().amax(dim=1) / 127.0 + yscale = torch.where(yscale == 0, torch.ones_like(yscale), yscale) + q = torch.clamp(torch.trunc(normalized / yscale.unsqueeze(1)), -127, 127).to(torch.int8) + return q, yscale - if do_compare: - aiter_us = None - if maybe_enable_aiter(): - aiter_us = _bench_aiter_fused_add_rmsnorm_quant(M, N, dtype, is_smooth=False) - perf_rows.append( - PerfRow( - op="rmsnorm_add_dq", - shape=f"{M}x{N}", - dtype=dtype, - flydsl_gpu_us=flydsl_gpu_us, - aiter_gpu_us=aiter_us, - ) - ) - print("\n" + "=" * 80) - if failures == 0: - print("ALL TESTS PASSED") - else: +def _reference_fused_add_rmsnorm(input_dev, residual_in_dev, gamma_dev): + added = input_dev + residual_in_dev + added_fp32 = added.to(DTYPE_FP32) + gamma = gamma_dev.to(DTYPE_FP32) + expected = (added_fp32 / torch.sqrt((added_fp32 * added_fp32).mean(dim=1, keepdim=True) + EPS)) * gamma + return added_fp32, expected + + +def _reference_fused_add_rmsnorm_quant( + input_dev, + residual_in_dev, + gamma_dev, + *, + xscale_dev=None, +): + added = input_dev + residual_in_dev + residual_expected = added.to(DTYPE_FP32) + q, yscale = _reference_rmsnorm_quant( + added, + gamma_dev, + xscale_dev=xscale_dev, + ) + return residual_expected, q, yscale + + +def _bench_aiter_rmsnorm_quant(M: int, N: int, dtype: str, *, is_smooth: bool): + mode = "smoothquant" if is_smooth else "dynamicquant" + torch_dtype = _torch_dtype(dtype) + + try: + if is_smooth: + from aiter.ops.triton.normalization.rmsnorm import ( + rmsnorm2d_fwd_with_smoothquant as aiter_rmsnorm_quant, + ) + else: + from aiter.ops.triton.normalization.rmsnorm import ( + rmsnorm2d_fwd_with_dynamicquant as aiter_rmsnorm_quant, + ) + except Exception as e: + print(f"[Perf] AIter rmsnorm {mode} skipped: {type(e).__name__}: {e!r}") + return None + + x = torch.randn((M, N), device="cuda", dtype=torch_dtype).contiguous() + w = torch.rand((N,), device="cuda", dtype=torch_dtype).contiguous() + y = torch.empty((M, N), dtype=torch.int8, device="cuda") + yscale = torch.empty((M, 1), dtype=torch.float32, device="cuda") + + if is_smooth: + xscale = (torch.rand((N,), device="cuda", dtype=DTYPE_FP32) + 0.5).contiguous() + + def run_aiter(): + aiter_rmsnorm_quant(y, x, xscale, yscale, w, EPS) + + else: + + def run_aiter(): + aiter_rmsnorm_quant(y, x, yscale, w, EPS) + + aiter_us = bench_gpu_us_torch(run_aiter, warmup=WARMUP_ITERS, iters=BENCH_ITERS) + print(f"[Perf] AIter rmsnorm {mode} gpu: {aiter_us:.1f} us") + return aiter_us + + +def _bench_aiter_fused_add_rmsnorm(M: int, N: int, dtype: str): + torch_dtype = _torch_dtype(dtype) + + try: + from aiter.ops.triton.normalization.rmsnorm import ( + rmsnorm2d_fwd_with_add as aiter_fused_add_rmsnorm, + ) + except Exception as e: + print(f"[Perf] AIter fused_add rmsnorm skipped: {type(e).__name__}: {e!r}") + return None + + x = torch.randn((M, N), device="cuda", dtype=torch_dtype).contiguous() + residual_in = torch.randn((M, N), device="cuda", dtype=torch_dtype).contiguous() + w = torch.rand((N,), device="cuda", dtype=torch_dtype).contiguous() + out = torch.empty((M, N), device="cuda", dtype=torch_dtype) + residual_out = torch.empty((M, N), device="cuda", dtype=torch_dtype) + + def run_aiter(): + aiter_fused_add_rmsnorm(out, x, residual_in, residual_out, w, EPS) + + aiter_us = bench_gpu_us_torch(run_aiter, warmup=WARMUP_ITERS, iters=BENCH_ITERS) + print(f"[Perf] AIter fused_add rmsnorm gpu: {aiter_us:.1f} us") + return aiter_us + + +def _bench_aiter_fused_add_rmsnorm_quant(M: int, N: int, dtype: str, *, is_smooth: bool): + mode = "smoothquant" if is_smooth else "dynamicquant" + torch_dtype = _torch_dtype(dtype) + + try: + if is_smooth: + from aiter.ops.triton.normalization.rmsnorm import ( + rmsnorm2d_fwd_with_add_smoothquant as aiter_fused_add_rmsnorm_quant, + ) + else: + from aiter.ops.triton.normalization.rmsnorm import ( + rmsnorm2d_fwd_with_add_dynamicquant as aiter_fused_add_rmsnorm_quant, + ) + except Exception as e: + print(f"[Perf] AIter fused_add rmsnorm {mode} skipped: {type(e).__name__}: {e!r}") + return None + + x = torch.randn((M, N), device="cuda", dtype=torch_dtype).contiguous() + residual_in = torch.randn((M, N), device="cuda", dtype=torch_dtype).contiguous() + w = torch.rand((N,), device="cuda", dtype=torch_dtype).contiguous() + y = torch.empty((M, N), dtype=torch.int8, device="cuda") + residual_out = torch.empty((M, N), device="cuda", dtype=torch_dtype) + yscale = torch.empty((M, 1), dtype=torch.float32, device="cuda") + + if is_smooth: + xscale = (torch.rand((N,), device="cuda", dtype=DTYPE_FP32) + 0.5).contiguous() + + def run_aiter(): + aiter_fused_add_rmsnorm_quant(y, x, residual_in, residual_out, xscale, yscale, w, EPS) + + else: + + def run_aiter(): + aiter_fused_add_rmsnorm_quant(y, x, residual_in, residual_out, yscale, w, EPS) + + aiter_us = bench_gpu_us_torch(run_aiter, warmup=WARMUP_ITERS, iters=BENCH_ITERS) + print(f"[Perf] AIter fused_add rmsnorm {mode} gpu: {aiter_us:.1f} us") + return aiter_us + + +def test_rmsnorm(): + print("=" * 80) + print("Running RMSNorm Tests") + print("=" * 80) + + configs = _get_rmsnorm_configs() + + do_compare = os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1" + perf_rows = [] + + failures = 0 + for M, N, dtype in configs: + ok, flydsl_gpu_us = run_test(M, N, dtype) + if not ok: + failures += 1 + + if do_compare: + aiter_us = None + if maybe_enable_aiter(): + try: + from aiter.ops.triton.rmsnorm import rms_norm as aiter_rms_norm + + torch_dtype = _torch_dtype(dtype) + x = torch.randn((M, N), device="cuda", dtype=torch_dtype) + w = torch.rand((N,), device="cuda", dtype=torch_dtype) + + def run_aiter(): + aiter_rms_norm(x, w, EPS) + + aiter_us = bench_gpu_us_torch(run_aiter, warmup=WARMUP_ITERS, iters=BENCH_ITERS) + print(f"[Perf] AIter rmsnorm gpu: {aiter_us:.1f} us") + except Exception as e: + print(f"[Perf] AIter rmsnorm skipped: {type(e).__name__}: {e!r}") + + perf_rows.append( + PerfRow(op="rmsnorm", shape=f"{M}x{N}", dtype=dtype, flydsl_gpu_us=flydsl_gpu_us, aiter_gpu_us=aiter_us) + ) + + print("\n" + "=" * 80) + if failures == 0: + print("ALL TESTS PASSED") + else: print(f"{failures} TESTS FAILED") print("=" * 80) if do_compare and perf_rows: print_perf_table(perf_rows) + # Ensure a non-zero exit code on failure for shell wrappers. if failures != 0: raise SystemExit(1) -def test_rmsnorm_fused_add_smoothquant(): +def test_rmsnorm_dynamicquant(): print("=" * 80) - print("Running FusedAdd RMSNorm SmoothQuant Tests") + print("Running RMSNorm DynamicQuant Tests") print("=" * 80) + configs = _get_rmsnorm_configs() + do_compare = os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1" perf_rows = [] + failures = 0 + for M, N, dtype in configs: + ok, flydsl_gpu_us = run_quant_test(M, N, dtype, is_smooth=False) + if not ok: + failures += 1 + + if do_compare: + aiter_us = None + if maybe_enable_aiter(): + aiter_us = _bench_aiter_rmsnorm_quant(M, N, dtype, is_smooth=False) + + perf_rows.append( + PerfRow( + op="rmsnorm_dynamicquant", + shape=f"{M}x{N}", + dtype=dtype, + flydsl_gpu_us=flydsl_gpu_us, + aiter_gpu_us=aiter_us, + ) + ) + + print("\n" + "=" * 80) + if failures == 0: + print("ALL TESTS PASSED") + else: + print(f"{failures} TESTS FAILED") + print("=" * 80) + if do_compare and perf_rows: + print_perf_table(perf_rows) + # Ensure a non-zero exit code on failure for shell wrappers. + if failures != 0: + raise SystemExit(1) + + +def test_rmsnorm_smoothquant(): + print("=" * 80) + print("Running RMSNorm SmoothQuant Tests") + print("=" * 80) + + configs = _get_rmsnorm_configs() + + do_compare = os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1" + perf_rows = [] + + failures = 0 + for M, N, dtype in configs: + ok, flydsl_gpu_us = run_quant_test(M, N, dtype, is_smooth=True) + if not ok: + failures += 1 + + if do_compare: + aiter_us = None + if maybe_enable_aiter(): + aiter_us = _bench_aiter_rmsnorm_quant(M, N, dtype, is_smooth=True) + + perf_rows.append( + PerfRow( + op="rmsnorm_smoothquant", + shape=f"{M}x{N}", + dtype=dtype, + flydsl_gpu_us=flydsl_gpu_us, + aiter_gpu_us=aiter_us, + ) + ) + + print("\n" + "=" * 80) + if failures == 0: + print("ALL TESTS PASSED") + else: + print(f"{failures} TESTS FAILED") + print("=" * 80) + if do_compare and perf_rows: + print_perf_table(perf_rows) + # Ensure a non-zero exit code on failure for shell wrappers. + if failures != 0: + raise SystemExit(1) + + +def test_fused_add_rmsnorm(): + print("=" * 80) + print("Running FusedAdd RMSNorm Tests") + print("=" * 80) + + configs = _get_rmsnorm_configs() + + do_compare = os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1" + perf_rows = [] - for M, N, dtype in _get_rmsnorm_configs(): + failures = 0 + for M, N, dtype in configs: + ok, flydsl_gpu_us = run_fused_add_test(M, N, dtype) + if not ok: + failures += 1 + + if do_compare: + aiter_us = None + if maybe_enable_aiter(): + aiter_us = _bench_aiter_fused_add_rmsnorm(M, N, dtype) + perf_rows.append( + PerfRow( + op="rmsnorm_add", + shape=f"{M}x{N}", + dtype=dtype, + flydsl_gpu_us=flydsl_gpu_us, + aiter_gpu_us=aiter_us, + ) + ) + + print("\n" + "=" * 80) + if failures == 0: + print("ALL TESTS PASSED") + else: + print(f"{failures} TESTS FAILED") + print("=" * 80) + if do_compare and perf_rows: + print_perf_table(perf_rows) + # Ensure a non-zero exit code on failure for shell wrappers. + if failures != 0: + raise SystemExit(1) + + +def test_fused_add_rmsnorm_dynamicquant(): + print("=" * 80) + print("Running FusedAdd RMSNorm DynamicQuant Tests") + print("=" * 80) + + configs = _get_rmsnorm_configs() + + do_compare = os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1" + perf_rows = [] + + failures = 0 + for M, N, dtype in configs: + ok, flydsl_gpu_us = run_fused_add_quant_test(M, N, dtype, is_smooth=False) + if not ok: + failures += 1 + + if do_compare: + aiter_us = None + if maybe_enable_aiter(): + aiter_us = _bench_aiter_fused_add_rmsnorm_quant(M, N, dtype, is_smooth=False) + perf_rows.append( + PerfRow( + op="rmsnorm_add_dynamicquant", + shape=f"{M}x{N}", + dtype=dtype, + flydsl_gpu_us=flydsl_gpu_us, + aiter_gpu_us=aiter_us, + ) + ) + + print("\n" + "=" * 80) + if failures == 0: + print("ALL TESTS PASSED") + else: + print(f"{failures} TESTS FAILED") + print("=" * 80) + if do_compare and perf_rows: + print_perf_table(perf_rows) + if failures != 0: + raise SystemExit(1) + + +def test_fused_add_rmsnorm_smoothquant(): + print("=" * 80) + print("Running FusedAdd RMSNorm SmoothQuant Tests") + print("=" * 80) + + configs = _get_rmsnorm_configs() + + do_compare = os.environ.get("ROCDSL_COMPARE_AITER", "0") == "1" + perf_rows = [] + + failures = 0 + for M, N, dtype in configs: ok, flydsl_gpu_us = run_fused_add_quant_test(M, N, dtype, is_smooth=True) if not ok: failures += 1 @@ -941,7 +914,7 @@ def test_rmsnorm_fused_add_smoothquant(): aiter_us = _bench_aiter_fused_add_rmsnorm_quant(M, N, dtype, is_smooth=True) perf_rows.append( PerfRow( - op="rmsnorm_add_sq", + op="rmsnorm_add_smoothquant", shape=f"{M}x{N}", dtype=dtype, flydsl_gpu_us=flydsl_gpu_us, @@ -962,4 +935,9 @@ def test_rmsnorm_fused_add_smoothquant(): if __name__ == "__main__": - test_all() + test_rmsnorm() + test_rmsnorm_dynamicquant() + test_rmsnorm_smoothquant() + test_fused_add_rmsnorm() + test_fused_add_rmsnorm_dynamicquant() + test_fused_add_rmsnorm_smoothquant()