diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index 0270893c..facfb4bb 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -1,8 +1,9 @@ """Unified MXFP4/MXFP8/A8W4 GEMM kernel for gfx1250. -Supports FP4 (E2M1), FP8 (E4M3) and A8W4 (FP8 activation + FP4 weight) -data with E8M0 block scales via V_WMMA_SCALE instructions. -Select precision with ``data_format="fp4"|"fp8"|"a8w4"``. +Supports FP4 (E2M1), FP8 (E4M3) and A8W4 (FP8 activation + FP4 weight), +selected via ``data_format="fp4"|"fp8"|"a8w4"``. Scales are either E8M0 +block scales applied in-MMA (``scale_mode="mxscale"``) or per-token/ +per-channel fp32 scales applied in the epilogue (``scale_mode="ptpc"``). """ import functools @@ -12,7 +13,7 @@ import flydsl.compiler as flyc import flydsl.expr as fx from flydsl._mlir import ir -from flydsl._mlir.dialects import fly, llvm +from flydsl._mlir.dialects import fly, llvm, scf from flydsl.compiler.kernel_function import CompilationContext from flydsl.expr import arith, buffer_ops, const_expr, gpu, idx2crd, range_constexpr, rocdl, tdm_ops from flydsl.expr.rocdl import cluster @@ -69,10 +70,10 @@ def _make_tdm_desc(*, early_timeout=False, **kwargs): @functools.lru_cache(maxsize=256) -def compile_mxscale_gemm( +def compile_fp8fp4_gemm( *, data_format: str = "fp4", - M: int = 0, + scale_mode: str = "mxscale", N: int = 0, K: int, tile_m: int = 128, @@ -97,25 +98,33 @@ def compile_mxscale_gemm( scale_load_path: str = "tdm", fp8_schedule: str = "auto", ): - """Compile an MXFP4 or MXFP8 GEMM kernel with TDM async copy. + """Compile an FP4/FP8/A8W4 GEMM kernel with TDM async copy. Args: - data_format: "fp4" for FP4/E2M1, "fp8" for FP8/E4M3. + data_format: "fp4" (E2M1), "fp8" (E4M3), or "a8w4" (FP8 act + FP4 weight). + scale_mode: "mxscale" (E8M0 block scale via V_WMMA_SCALE) or "ptpc" + (per-token sa[M] / per-channel sb[N] fp32, applied in the epilogue). - Data layout (both formats): + Data layout: A: [M, K_packed] uint8 (FP4: K_packed=K//2, FP8: K_packed=K) B: [N, K_packed] uint8, preshuffled (16x16 byte tiles) - scale_A: [M, K//32] uint8 E8M0 (preshuffled) - scale_B: [N, K//32] uint8 E8M0 (preshuffled) + mxscale: scale_A [M, K//32], scale_B [N, K//32] uint8 E8M0 (preshuffled) + ptpc: scale_A [M], scale_B [N] fp32 Returns a JitFunction: - launch_fn(arg_c, arg_a, arg_b, arg_a_scale, arg_b_scale, M, N, stream) + launch_fn(arg_c, arg_a, arg_b, arg_a_scale, arg_b_scale, M, N, lda, ldc, stream) + where lda/ldc are A/C runtime leading-dim strides in elements (dense: lda=K, ldc=N). """ if data_format not in ("fp4", "fp8", "a8w4"): raise ValueError(f"data_format must be 'fp4', 'fp8', or 'a8w4', got {data_format!r}") + if scale_mode not in ("mxscale", "ptpc"): + raise ValueError(f"scale_mode must be 'mxscale' or 'ptpc', got {scale_mode!r}") + if scale_mode == "ptpc" and data_format not in ("fp8", "a8w4"): + raise ValueError("scale_mode='ptpc' currently only supports data_format='fp8' or 'a8w4'") is_fp4 = data_format == "fp4" is_a8w4 = data_format == "a8w4" + is_ptpc = scale_mode == "ptpc" if out_dtype not in ("f32", "bf16", "f16"): raise ValueError(f"out_dtype must be 'f32', 'bf16', or 'f16', got {out_dtype!r}") @@ -151,8 +160,9 @@ def compile_mxscale_gemm( if block_threads > 1024: raise ValueError(f"block_threads must be <= 1024, got {block_threads}") - if wave_specialized_tdm and num_warps < 4: - raise ValueError(f"wave_specialized_tdm requires at least 4 waves, got {num_warps}") + _min_wave_spec_warps = 2 if is_ptpc else 4 + if wave_specialized_tdm and num_warps < _min_wave_spec_warps: + raise ValueError(f"wave_specialized_tdm requires at least {_min_wave_spec_warps} waves, got {num_warps}") # ── Format-dependent compile-time constants ── # A8W4: activation is FP8 (PACK_FACTOR_A=1), weight is FP4 (PACK_FACTOR_B=2) @@ -240,8 +250,8 @@ def compile_mxscale_gemm( ab_split_a_rows = tile_m // 2 ab_split_b_groups = tile_n // 32 _scale_guard_bytes = 16 - lds_a_scale_bytes = tile_m * scale_k_per_tile + _scale_guard_bytes - lds_b_scale_bytes = tile_n * scale_k_per_tile + _scale_guard_bytes + lds_a_scale_bytes = 0 if is_ptpc else tile_m * scale_k_per_tile + _scale_guard_bytes + lds_b_scale_bytes = 0 if is_ptpc else tile_n * scale_k_per_tile + _scale_guard_bytes interleaved_scale_cols_a = wmma_m_rep * scale_k_per_tile interleaved_scale_cols_b = b_scale_load_rep * scale_k_per_tile @@ -418,7 +428,7 @@ def _align_up(value: int, align: int) -> int: COMPUTE_SCHEDULE_B_STREAMING = "b_streaming" fp8_deep_pipeline_eligible = ( - data_format == "fp8" + data_format in ("fp8", "a8w4") and tile_m == 256 and tile_n == 256 and tile_k == 128 @@ -447,7 +457,9 @@ def _pick_compute_schedule_kind(): # accumulators and uses the split to increase LDS-load-to-WMMA distance. if is_fp4: return COMPUTE_SCHEDULE_FP4_COL_BAND - if data_format == "fp8": + # A8W4 (FP8 act + FP4 weight) shares FP8's accumulator layout and operand + # path, so it reuses the FP8 schedules. + if data_format in ("fp8", "a8w4"): if fp8_schedule == "deep-pipeline" or (fp8_schedule == "auto" and fp8_deep_pipeline_eligible): return COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE return COMPUTE_SCHEDULE_FP8_QUADRANT @@ -495,7 +507,7 @@ def _pick_compute_schedule_kind(): _fp8_half_wm = wmma_m_rep // 2 _fp8_half_wn = wmma_n_rep // 2 _fp8_group_size = _fp8_half_wm * _fp8_half_wn - _fp8_b_scale_loads = (b_scale_load_rep + 3) // 4 + _fp8_b_scale_loads = 0 if is_ptpc else (b_scale_load_rep + 3) // 4 if use_fp8_deep_pipeline_schedule: _fp8_pair_wm = 2 _fp8_pair_wn = 2 @@ -503,7 +515,7 @@ def _pick_compute_schedule_kind(): _fp8_wn_pairs = wmma_n_rep // _fp8_pair_wn _fp8_pair_a_loads = _fp8_pair_wm * DS_LOADS_PER_A_FRAG _fp8_pair_b_loads = _fp8_pair_wn * _b_frag_loads_per_wn - _fp8_scale_loads = (wmma_m_rep + 3) // 4 + (b_scale_load_rep + 3) // 4 + _fp8_scale_loads = 0 if is_ptpc else (wmma_m_rep + 3) // 4 + (b_scale_load_rep + 3) // 4 @flyc.kernel(known_block_size=[block_threads, 1, 1]) def kernel_mxscale_gemm( @@ -514,6 +526,8 @@ def kernel_mxscale_gemm( arg_b_scale: fx.Tensor, i32_m: fx.Int32, i32_n: fx.Int32, + i32_lda: fx.Int32, + i32_ldc: fx.Int32, ): # Enable back-to-back WMMA issue (SCHED_MODE bit[4] = DISABLE_VALU_STALL) rocdl.disable_xdl_arb_stall() @@ -585,7 +599,13 @@ def _bvs_prefetch(k_base): return a, b m_idx = fx.Index(i32_m) - n_stride = arith.index(N) + # Runtime leading-dim strides (strided A/C). Dense callers pass lda == K, + # ldc == N for byte-identical addressing. A's stride is in packed elements. + if const_expr(PACK_FACTOR_A == 1): + lda_packed = fx.Index(i32_lda) + else: + lda_packed = fx.Index(i32_lda) / arith.index(PACK_FACTOR_A) + n_stride = fx.Index(i32_ldc) c_nrec = m_idx * n_stride * arith.index(elem_bytes_d) c_rsrc = buffer_ops.create_buffer_resource(arg_c, num_records_bytes=c_nrec) c_global_ptr_type = ir.Type.parse("!llvm.ptr<1>") @@ -600,7 +620,7 @@ def make_desc_a(memref, k_base): lds_memref=memref, global_offset=(blk_m, k_packed_off), tensor_shape=(tile_m, packed_tile_k_a), - strides=(K_packed_a, 1), + strides=(lda_packed, 1), tile_shape=(tile_m, packed_tile_k_a), elem_bytes=1, pad_interval=packed_tile_k_a, @@ -609,6 +629,7 @@ def make_desc_a(memref, k_base): workgroup_mask=a_mcast_mask, atomic_barrier_enable=atomic_barrier_enable, early_timeout=True, + oob_outer_bound=i32_m, ) def make_desc_b(memref, k_base): @@ -637,7 +658,7 @@ def make_desc_a_half(memref, k_base, m_half: int): lds_memref=memref, global_offset=(blk_m + arith.index(row_start), k_packed_off), tensor_shape=(tile_m, packed_tile_k_a), - strides=(K_packed_a, 1), + strides=(lda_packed, 1), tile_shape=(ab_split_a_rows, packed_tile_k_a), elem_bytes=1, pad_interval=packed_tile_k_a, @@ -647,6 +668,7 @@ def make_desc_a_half(memref, k_base, m_half: int): lds_byte_offset=arith.index(row_start * lds_a_stride_bytes), atomic_barrier_enable=atomic_barrier_enable, early_timeout=True, + oob_outer_bound=i32_m, ) def make_desc_b_half(memref, k_base, n_half: int): @@ -882,6 +904,8 @@ def _scales_for_emit(as_buf, as_bases, bs_buf, bs_bases, ks): FP4 BScale has no op_sel (scaleAType=0 fixed); only AScale halves. FP8/A8W4 16x16 supports op_sel on both. """ + if const_expr(is_ptpc): + return None, None a_all = load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) b_all = load_scale_b128(bs_buf, bs_bases[0], b_scale_load_rep, ks) if const_expr(use_scale_opsel): @@ -904,6 +928,35 @@ def _load_a_and_scales(a_buf, a_bases, as_buf, as_bases, bs_buf, bs_bases, ks): def _emit_wmma(accs, wm, wn, a_frag, b_frag, a_scales, b_scales): """Emit one WMMA instruction (format-specific).""" idx = wm * wmma_n_rep + wn + if const_expr(is_ptpc): + if const_expr(is_a8w4): + accs[idx] = rocdl.wmma_scale_f32_16x16x128_f8f6f4( + T.vec(8, T.f32), + b_frag, + a_frag, + accs[idx], + 0x7F7F7F7F, + 0x7F7F7F7F, + fmtA=4, + fmtB=0, + ) + else: + # PTPC-FP8 needs no per-K scaling. We emit the scaled f8f6f4 op + # with an identity E8M0 scale (0x7F = 2^0 = 1.0) for toolchain + # compatibility; it is numerically equivalent to the dedicated + # no-scale op. Future: switch to the equivalent no-scale wmma: + # accs[idx] = rocdl.wmma_f32_16x16x128_fp8_fp8(T.vec(8, T.f32), b_frag, a_frag, accs[idx]) + accs[idx] = rocdl.wmma_scale_f32_16x16x128_f8f6f4( + T.vec(8, T.f32), + b_frag, + a_frag, + accs[idx], + 0x7F7F7F7F, + 0x7F7F7F7F, + fmtA=0, + fmtB=0, + ) + return if const_expr(use_scale_opsel): a_scale_idx = wm // 2 a_opsel = wm % 2 @@ -1282,12 +1335,16 @@ def _load_b_half(wn_base, ks): ] def _load_a_scales(ks): + if const_expr(is_ptpc): + return None # PTPC: scale applied in epilogue, not in K-loop a_scales = load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) if const_expr(use_scale_opsel): return a_scales[::2] return a_scales def _load_b_scales(ks): + if const_expr(is_ptpc): + return None # PTPC: scale applied in epilogue, not in K-loop b_scales = load_scale_b128(bs_buf, bs_bases[0], b_scale_load_rep, ks) if const_expr(use_scale_opsel): return b_scales[::2] @@ -1464,6 +1521,8 @@ def load_b_pair(wn_pair, ks): ] def _load_a_scales(ks): + if const_expr(is_ptpc): + return None # PTPC: scale applied in epilogue, not in K-loop if const_expr(use_buffer_vgpr_scale): if const_expr(pf_a_scales is not None): return pf_a_scales # prefetched (issued in the prior compute tile) @@ -1471,6 +1530,8 @@ def _load_a_scales(ks): return load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) def _load_b_scales(ks): + if const_expr(is_ptpc): + return None # PTPC: scale applied in epilogue, not in K-loop if const_expr(use_buffer_vgpr_scale): if const_expr(pf_b_scales is not None): return pf_b_scales @@ -1663,17 +1724,18 @@ def hot_loop_scheduler(): _half_wm = wmma_m_rep // 2 _half_wmma = _half_wm * wmma_n_rep _b_loads_per_frag = 2 if is_a8w4 else 4 + _scale_dsrd = 0 if is_ptpc else 2 for _ks in range_constexpr(k_wmma_steps): if const_expr(_ks == 0): - rocdl.sched_dsrd(wmma_n_rep * _b_loads_per_frag + 2 + _half_wm * DS_LOADS_PER_A_FRAG) + rocdl.sched_dsrd(wmma_n_rep * _b_loads_per_frag + _scale_dsrd + _half_wm * DS_LOADS_PER_A_FRAG) else: rocdl.sched_dsrd(_half_wm * DS_LOADS_PER_A_FRAG) rocdl.sched_mfma(_half_wmma) rocdl.sched_dsrd(_half_wm * DS_LOADS_PER_A_FRAG) rocdl.sched_mfma(_half_wmma) if const_expr(_ks < k_wmma_steps - 1): - rocdl.sched_dsrd(wmma_n_rep * _b_loads_per_frag + 2) + rocdl.sched_dsrd(wmma_n_rep * _b_loads_per_frag + _scale_dsrd) rocdl.sched_barrier(0) def hot_loop_scheduler_fp4_bank_friendly(): @@ -1699,7 +1761,7 @@ def hot_loop_scheduler_fp4_bank_friendly(): rocdl.sched_barrier(0) def hot_loop_scheduler_fp8_quadrant(): - _a_scale_loads = (wmma_m_rep + 3) // 4 + _a_scale_loads = 0 if is_ptpc else (wmma_m_rep + 3) // 4 _a_top_loads = _fp8_half_wm * DS_LOADS_PER_A_FRAG _a_bottom_loads = _a_top_loads _b_half_loads = _fp8_half_wn * _b_frag_loads_per_wn @@ -1961,10 +2023,16 @@ def epilogue_atomic_adds(final_accs, addrs): addr_idx = 0 for acc_idx, vec_base, m_off, wn in _sub_tiles: sub8 = _get_acc_sub8(final_accs, acc_idx, vec_base) - if const_expr(_bf16_out): - addr_idx += _atomic_add_acc_vec8_to_buffer(sub8, addrs[addr_idx]) - else: - addr_idx += _atomic_add_acc_vec8_to_buffer(sub8, addrs[addr_idx : addr_idx + 2]) + n_slots = 1 if _bf16_out else 2 + addr_arg = addrs[addr_idx] if _bf16_out else addrs[addr_idx : addr_idx + 2] + # Atomics use a raw global ptr (no num_records clip), so predicate + # per-lane to skip rows >= M. + row = blk_m + warp_m_base + arith.index(m_off) + lane16 + if_op = scf.IfOp(row < m_idx, [], has_else=False) + with ir.InsertionPoint(if_op.then_block): + _atomic_add_acc_vec8_to_buffer(sub8, addr_arg) + scf.YieldOp([]) + addr_idx += n_slots def grouped_accs_to_row_major(accs_grouped): row_major = [None] * n_accs @@ -1977,6 +2045,43 @@ def finalize_acc_layout(accs_in): return grouped_accs_to_row_major(accs_in) return accs_in + def epilogue_load_ptpc_scales(): + # PTPC scales: sa[M] per-token (scalar per wm), sb[N] per-channel + # (8 contiguous N cols per wn). Both fp32, constant along K. + # The scale memrefs are dynamically shaped, so max_size=False would fall + # back to a max-sized descriptor and disable hardware OOB. Derive + # num_records from runtime M / compile-time N (fp32 = 4 bytes) so the + # partial last M-tile clips rows >= M (and cols >= N) to 0. + sa_rsrc = buffer_ops.create_buffer_resource(arg_a_scale, num_records_bytes=m_idx * arith.index(4)) + sb_rsrc = buffer_ops.create_buffer_resource(arg_b_scale, num_records_bytes=N * 4) + sa = [] + for wm in range_constexpr(wmma_m_rep): + row = blk_m + warp_m_base + arith.index(wm * WMMA_M) + lane16 + sv = buffer_ops.buffer_load(sa_rsrc, arith.index_cast(T.i32, row), vec_width=1, dtype=T.f32) + sa.append(fx.Vector.from_elements([sv] * 8)) + sb = [] + for wn in range_constexpr(wmma_n_rep): + col_base = blk_n + warp_n_base + arith.index(wn * WMMA_N) + lane_kgrp * arith.index(8) + # buffer_load vec_width is capped at 4: read 8 cols as 2x vec4. + lo = fx.Vector( + buffer_ops.buffer_load(sb_rsrc, arith.index_cast(T.i32, col_base), vec_width=4, dtype=T.f32) + ) + hi = fx.Vector( + buffer_ops.buffer_load( + sb_rsrc, arith.index_cast(T.i32, col_base + arith.index(4)), vec_width=4, dtype=T.f32 + ) + ) + sb.append(fx.Vector.from_elements([lo[0], lo[1], lo[2], lo[3], hi[0], hi[1], hi[2], hi[3]])) + return sa, sb + + def epilogue_apply_ptpc_scale(accs_in, sa, sb): + out = list(accs_in) + for wm in range_constexpr(wmma_m_rep): + for wn in range_constexpr(wmma_n_rep): + idx = wm * wmma_n_rep + wn + out[idx] = (fx.Vector(out[idx]) * sb[wn] * sa[wm]).ir_value() + return out + _effective_l2_pf = l2_prefetch_distance if const_expr(use_cluster and l2_prefetch_distance > 0): _effective_l2_pf = max(1, l2_prefetch_distance - 1) @@ -2025,14 +2130,21 @@ def _l2_prefetch(k_base): SmemPtr(arena_base_ptr, stage_b_data_off[i], elem_ty_lds, shape=(lds_b_data_f16,)) for i in range_constexpr(num_buffers) ] - stages_as = [ - SmemPtr(arena_base_ptr, stage_a_scale_off[i], elem_ty_lds, shape=(lds_a_scale_f16,)) - for i in range_constexpr(num_buffers) - ] - stages_bs = [ - SmemPtr(arena_base_ptr, stage_b_scale_off[i], elem_ty_lds, shape=(lds_b_scale_f16,)) - for i in range_constexpr(num_buffers) - ] + if const_expr(is_ptpc): + # PTPC applies sa*sb in the epilogue from global memory: no scale LDS. + # Alias the scale stage handles to A/B so the shared plumbing stays + # valid; for PTPC they are never written (no scale TDM) or read. + stages_as = stages_a + stages_bs = stages_b + else: + stages_as = [ + SmemPtr(arena_base_ptr, stage_a_scale_off[i], elem_ty_lds, shape=(lds_a_scale_f16,)) + for i in range_constexpr(num_buffers) + ] + stages_bs = [ + SmemPtr(arena_base_ptr, stage_b_scale_off[i], elem_ty_lds, shape=(lds_b_scale_f16,)) + for i in range_constexpr(num_buffers) + ] stages_a_mem = [stages_a[i].get() for i in range_constexpr(num_buffers)] stages_b_mem = [stages_b[i].get() for i in range_constexpr(num_buffers)] @@ -2070,7 +2182,7 @@ def _l2_prefetch(k_base): lds_memref=d_lds_base_ptr, global_offset=(blk_m + warp_m_off_sgpr, blk_n + warp_n_off_sgpr), tensor_shape=(warp_tile_m, warp_tile_n), - strides=(N, 1), + strides=(n_stride, 1), tile_shape=(warp_tile_m, warp_tile_n), elem_bytes=elem_bytes_d, pad_interval=warp_tile_n, @@ -2078,6 +2190,7 @@ def _l2_prefetch(k_base): num_warps=1, lds_byte_offset=d_warp_off_sgpr, for_store=True, + oob_outer_bound=i32_m, ) # TDM descriptor lane layout: dgroup0 = [predicate, lds_addr, addr_lo, addr_hi]. @@ -2095,13 +2208,22 @@ def _pack_dg0(pred, lds_addr, addr_lo, addr_hi): for i in range_constexpr(num_buffers): stages_a_lds_addr.append(_dg0_lane(make_desc_a(stages_a_mem[i], arith.index(0)), 1)) stages_b_lds_addr.append(_dg0_lane(make_desc_b(stages_b_mem[i], arith.index(0)), 1)) - stages_as_lds_addr.append(_dg0_lane(make_desc_as(stages_as_mem[i], arith.index(0)), 1)) - stages_bs_lds_addr.append(_dg0_lane(make_desc_bs(stages_bs_mem[i], arith.index(0)), 1)) + if const_expr(not is_ptpc): + stages_as_lds_addr.append(_dg0_lane(make_desc_as(stages_as_mem[i], arith.index(0)), 1)) + stages_bs_lds_addr.append(_dg0_lane(make_desc_bs(stages_bs_mem[i], arith.index(0)), 1)) desc_a_init = make_desc_a(stages_a_mem[0], split_k_base) desc_b_init = make_desc_b(stages_b_mem[0], split_k_base) - desc_as_init = make_desc_as(stages_as_mem[0], split_k_base) - desc_bs_init = make_desc_bs(stages_bs_mem[0], split_k_base) + if const_expr(is_ptpc): + # No scale TDM for PTPC: alias the scale descriptors/addresses to A/B. + # Scale waves are predicated off, so these selections are never issued. + stages_as_lds_addr = stages_a_lds_addr + stages_bs_lds_addr = stages_b_lds_addr + desc_as_init = desc_a_init + desc_bs_init = desc_b_init + else: + desc_as_init = make_desc_as(stages_as_mem[0], split_k_base) + desc_bs_init = make_desc_bs(stages_bs_mem[0], split_k_base) if const_expr(use_ab_half_split): stages_a0_lds_addr = [] stages_b0_lds_addr = [] @@ -2125,9 +2247,7 @@ def _pack_dg0(pred, lds_addr, addr_lo, addr_hi): pred_const = fx.Int32(1) if const_expr(wave_specialized_tdm): - # With scale on the VGPR path, drop scale waves 2,3 from the active TDM - # path -- unless ab-half-split repurposes them as the second A/B halves. - _drop_scale_waves = use_buffer_vgpr_scale and not use_ab_half_split + _drop_scale_waves = is_ptpc or (use_buffer_vgpr_scale and not use_ab_half_split) _active_wave_limit = 2 if _drop_scale_waves else 4 active_pred_const = arith.select(tdm_wave_id < fx.Int32(_active_wave_limit), fx.Int32(1), fx.Int32(0)) @@ -2407,6 +2527,12 @@ def _mid_tdm_nws( elif const_expr(use_cluster): cluster.cluster_barrier() epi_addrs_box = [None] + _ptpc_scale_box = [None] + + def _load_ptpc_scales_once(): + if const_expr(is_ptpc and _ptpc_scale_box[0] is None): + _ptpc_scale_box[0] = epilogue_load_ptpc_scales() + _tail_had_load = False # Tail K-tile index, so the VGPR-path scale buffer_load uses the right k_base. _bvs_tail_kt = [loop_iters * num_buffers] @@ -2431,6 +2557,7 @@ def _bvs_tail_kb(): stages_b_idx[_compute_stage], stages_as_idx[_compute_stage], stages_bs_idx[_compute_stage], + emit_filler=(_load_ptpc_scales_once if is_ptpc else None), a0_prefetch=a0_prefetch, scale_k_base=_entry_kb, ) @@ -2438,6 +2565,7 @@ def _bvs_tail_kb(): def _emit_epi_addrs(): epi_addrs_box[0] = epilogue_prepare_addrs() + _load_ptpc_scales_once() a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[_compute_stage]) accs = compute_tile_scheduled( @@ -2512,7 +2640,12 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): accs = finalize_acc_layout(accs) - if const_expr(use_tdm_store): + if const_expr(is_ptpc): + _load_ptpc_scales_once() + _ptpc_sa, _ptpc_sb = _ptpc_scale_box[0] + accs = epilogue_apply_ptpc_scale(accs, _ptpc_sa, _ptpc_sb) + + def _emit_tdm_store(): if const_expr(d_need_epilogue_fence): _pipeline_fence(outstanding=0) rocdl.sched_barrier(0) @@ -2520,7 +2653,8 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): rocdl.s_wait_dscnt(0) tdm_ops.tensor_store_2d(d_desc) tdm_ops.tensor_wait(0) - else: + + def _emit_buffer_store(): rocdl.sched_barrier(0) if const_expr(epi_addrs_box[0] is None): epi_addrs_box[0] = epilogue_prepare_addrs() @@ -2529,8 +2663,21 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): else: epilogue_stores(accs, epi_addrs_box[0]) + if const_expr(use_tdm_store): + full_tile = (blk_m + arith.index(tile_m)) <= m_idx + if_op = scf.IfOp(full_tile, [], has_else=True) + with ir.InsertionPoint(if_op.then_block): + _emit_tdm_store() + scf.YieldOp([]) + with ir.InsertionPoint(if_op.else_block): + _emit_buffer_store() + scf.YieldOp([]) + else: + _emit_buffer_store() + cache_tag = ( data_format, + scale_mode, K, tile_m, tile_n, @@ -2565,6 +2712,8 @@ def launch_mxscale_gemm( arg_b_scale: fx.Tensor, i32_m: fx.Int32, i32_n: fx.Int32, + i32_lda: fx.Int32, + i32_ldc: fx.Int32, stream: fx.Stream, ): _ = cache_tag @@ -2577,6 +2726,10 @@ def launch_mxscale_gemm( gy = (i32_n + (tile_n - 1)) // tile_n gz = split_k + if const_expr(use_cluster): + # Cluster launch needs a cluster-divisible grid + gx = ((gx + (cluster_m - 1)) // cluster_m) * cluster_m + cluster_arg = (cluster_m, cluster_n, 1) if use_cluster else None kernel_mxscale_gemm( arg_c, @@ -2586,6 +2739,8 @@ def launch_mxscale_gemm( arg_b_scale, i32_m, i32_n, + i32_lda, + i32_ldc, value_attrs={ "rocdl.waves_per_eu": effective_waves_per_eu, "rocdl.cluster_dims": f"{cluster_m},{cluster_n},1" if const_expr(use_cluster) else None, @@ -2605,16 +2760,88 @@ def launch_mxscale_gemm( return launch_mxscale_gemm +def compile_mxscale_gemm(**kw): + """Backward-compatible wrapper: MX block-scale (E8M0) GEMM.""" + return compile_fp8fp4_gemm(scale_mode="mxscale", **kw) + + def compile_mxfp4_gemm(**kw): - return compile_mxscale_gemm(data_format="fp4", **kw) + return compile_fp8fp4_gemm(data_format="fp4", scale_mode="mxscale", **kw) def compile_mxfp8_gemm(**kw): - return compile_mxscale_gemm(data_format="fp8", **kw) + return compile_fp8fp4_gemm(data_format="fp8", scale_mode="mxscale", **kw) def compile_a8w4_gemm(**kw): - return compile_mxscale_gemm(data_format="a8w4", **kw) + return compile_fp8fp4_gemm(data_format="a8w4", scale_mode="mxscale", **kw) + + +def compile_ptpc_gemm( + *, + N: int = 0, + K: int, + data_format: str = "fp8", + tile_m: int = 128, + tile_n: int = 128, + tile_k: int = 128, + m_warp: int = 2, + n_warp: int = 2, + num_buffers: int = 4, + waves_per_eu: int = None, + l2_prefetch_distance: int = 0, + cluster_m: int = 1, + cluster_n: int = 1, + out_dtype: str = "bf16", + inst_prefetch: bool = False, + expert_sched_mode: bool = True, + atomic_barrier_enable: bool = False, + split_k: int = 1, +): + """Compile a PTPC (per-token per-channel) GEMM kernel. + + A scale is per-token (sa[M], fp32), B scale is per-channel (sb[N], fp32), + both constant along K. The K-loop runs the WMMA unscaled (FP8) or with an + identity E8M0 scale (A8W4, which has no non-scale op); sa*sb is applied in + the epilogue in fp32. split_k>1 is supported (atomic add path). + + data_format: "fp8" (FP8 act + FP8 weight) or "a8w4" (FP8 act + FP4 weight). + wave_specialized_tdm=True requires m_warp*n_warp >= 2. + """ + return compile_fp8fp4_gemm( + data_format=data_format, + scale_mode="ptpc", + b_streaming=False, + wave_specialized_tdm=True, + use_scale_opsel=False, + fp8_schedule="auto", + scale_load_path="tdm", + use_tdm_store=(split_k == 1), + N=N, + K=K, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + m_warp=m_warp, + n_warp=n_warp, + num_buffers=num_buffers, + waves_per_eu=waves_per_eu, + l2_prefetch_distance=l2_prefetch_distance, + cluster_m=cluster_m, + cluster_n=cluster_n, + out_dtype=out_dtype, + inst_prefetch=inst_prefetch, + expert_sched_mode=expert_sched_mode, + atomic_barrier_enable=atomic_barrier_enable, + split_k=split_k, + ) -__all__ = ["compile_mxscale_gemm", "compile_mxfp4_gemm", "compile_mxfp8_gemm", "compile_a8w4_gemm"] +__all__ = [ + "compile_fp8fp4_gemm", + "compile_mxscale_gemm", + "compile_mxfp4_gemm", + "compile_mxfp8_gemm", + "compile_a8w4_gemm", + "compile_ptpc_gemm", +] diff --git a/python/flydsl/expr/rocdl.py b/python/flydsl/expr/rocdl.py index 99fa11c0..59b07925 100644 --- a/python/flydsl/expr/rocdl.py +++ b/python/flydsl/expr/rocdl.py @@ -121,6 +121,7 @@ def make_buffer_tensor(memref, alignment=4, loc=None, ip=None): # Keep references to ODS-generated builders so we can wrap them without losing access. _ods_wmma_scale_f32_16x16x128_f8f6f4 = globals().get("wmma_scale_f32_16x16x128_f8f6f4", None) _ods_wmma_scale_f32_32x16x128_f4 = globals().get("wmma_scale_f32_32x16x128_f4", None) +_ods_wmma_f32_16x16x128_fp8_fp8 = globals().get("wmma_f32_16x16x128_fp8_fp8", None) _ods_wave_id = wave_id # ODS: wave_id(res, ...) -> i32 _ods_cluster_workgroup_id_x = cluster_workgroup_id_x _ods_cluster_workgroup_id_y = cluster_workgroup_id_y @@ -483,6 +484,32 @@ def wmma_scale_f32_32x16x128_f4( ).result +def wmma_f32_16x16x128_fp8_fp8(result_type, a, b, c, *, modC=0, reuseA=False, reuseB=False, loc=None, ip=None): + """Non-scale V_WMMA_F32_16X16X128 (E4M3) for gfx1250 (wave32). + + Operand types (wave32): + a: vector<16xi32> (16x128 FP8/E4M3 data) + b: vector<16xi32> (128x16 FP8/E4M3 data) + c: vector<8xf32> (16x16 FP32 accumulator) + """ + if _ods_wmma_f32_16x16x128_fp8_fp8 is None: + raise AttributeError("ROCDL op not found: wmma_f32_16x16x128_fp8_fp8") + a_v = _unwrap_wmma_operand(a, loc=loc) + b_v = _unwrap_wmma_operand(b, loc=loc) + c_v = _unwrap_wmma_operand(c, loc=loc) + return _ods_wmma_f32_16x16x128_fp8_fp8( + result_type, + a_v, + b_v, + c_v, + modC=modC, + reuseA=reuseA, + reuseB=reuseB, + loc=loc, + ip=ip, + ).result + + def wave_id(): """Get wave-id-in-workgroup as SGPR (via TTMP8[29:25]). @@ -677,6 +704,7 @@ def lds_transpose_load(result_type, lds_memref, elem_offset, elem_bytes): "wmma_i32_16x16x32_iu4", "wmma_scale_f32_16x16x128_f8f6f4", # gfx1250 WMMA_SCALE 16x16x128 (FP4/FP6/FP8) "wmma_scale_f32_32x16x128_f4", # gfx1250 WMMA_SCALE 32x16x128 (FP4 only) + "wmma_f32_16x16x128_fp8_fp8", # gfx1250 WMMA 16x16x128 non-scale (E4M3) # Matrix operations - SMFMAC (Sparse Matrix FMA) "smfmac_f32_32x32x16_f16", "smfmac_f32_32x32x16_bf16", diff --git a/python/flydsl/expr/rocdl/__init__.py b/python/flydsl/expr/rocdl/__init__.py index 3bc5cc0a..e4085148 100644 --- a/python/flydsl/expr/rocdl/__init__.py +++ b/python/flydsl/expr/rocdl/__init__.py @@ -21,6 +21,7 @@ # Keep references to ODS-generated builders so we can wrap them without losing access. _ods_wmma_scale_f32_16x16x128_f8f6f4 = globals().get("wmma_scale_f32_16x16x128_f8f6f4", None) _ods_wmma_scale_f32_32x16x128_f4 = globals().get("wmma_scale_f32_32x16x128_f4", None) +_ods_wmma_f32_16x16x128_fp8_fp8 = globals().get("wmma_f32_16x16x128_fp8_fp8", None) _ods_wave_id = wave_id # ODS: wave_id(res, ...) -> i32 _ods_cluster_workgroup_id_x = cluster_workgroup_id_x _ods_cluster_workgroup_id_y = cluster_workgroup_id_y @@ -308,6 +309,32 @@ def wmma_scale_f32_32x16x128_f4( ).result +def wmma_f32_16x16x128_fp8_fp8(result_type, a, b, c, *, modC=0, reuseA=False, reuseB=False, loc=None, ip=None): + """Non-scale V_WMMA_F32_16X16X128 (E4M3) for gfx1250 (wave32). + + Operand types (wave32): + a: vector<16xi32> (16x128 FP8/E4M3 data) + b: vector<16xi32> (128x16 FP8/E4M3 data) + c: vector<8xf32> (16x16 FP32 accumulator) + """ + if _ods_wmma_f32_16x16x128_fp8_fp8 is None: + raise AttributeError("ROCDL op not found: wmma_f32_16x16x128_fp8_fp8") + a_v = _unwrap_mfma_operand(a, loc=loc) + b_v = _unwrap_mfma_operand(b, loc=loc) + c_v = _unwrap_mfma_operand(c, loc=loc) + return _ods_wmma_f32_16x16x128_fp8_fp8( + result_type, + a_v, + b_v, + c_v, + modC=modC, + reuseA=reuseA, + reuseB=reuseB, + loc=loc, + ip=ip, + ).result + + def wave_id(): """Get wave-id-in-workgroup as SGPR (via TTMP8[29:25]). diff --git a/python/flydsl/expr/rocdl/tdm_ops.py b/python/flydsl/expr/rocdl/tdm_ops.py index f2644d31..56a24b19 100644 --- a/python/flydsl/expr/rocdl/tdm_ops.py +++ b/python/flydsl/expr/rocdl/tdm_ops.py @@ -216,6 +216,7 @@ def make_tensor_descriptor_2d( for_store: bool = False, atomic_barrier_enable: bool = False, early_timeout: bool = False, + oob_outer_bound=None, ) -> TDMDescriptor2D: """Build a 2D TDM descriptor for tensor_load_to_lds_d2. @@ -238,7 +239,8 @@ def make_tensor_descriptor_2d( lds_memref: The LDS memref value (already the correct buffer slot). global_offset: (outer_idx, inner_idx) as MLIR index values. tensor_shape: (outer_size, inner_size) as Python ints. - strides: (outer_stride, inner_stride) as Python ints. + strides: (outer_stride, inner_stride); inner is a Python int, outer + may be an int or a runtime i32/index Value (strided A/C). tile_shape: (outer_tile, inner_tile) as Python ints. elem_bytes: Element size in bytes (2 for f16/bf16, 4 for f32). pad_interval: Padding interval in elements (0 to disable). @@ -265,6 +267,20 @@ def make_tensor_descriptor_2d( multicast-load knob (1 = GL1 returns to the requesters present when GL2 data arrives, latecomers re-broadcast; default 0 = standard wider-merge timeout). + oob_outer_bound: Optional runtime outer-dim global extent (e.g. real M for + a row-major A/C) for non-tile-aligned outer dims. When given, + ``tensor_dim1`` is set to the tile-start-relative remaining + extent ``max(0, oob_outer_bound - (outer_off + warp_off_outer))`` + while ``tile_dim1`` is left at the full per-warp tile, so the + partial last tile exceeds the tensor bound and the HW + OOB-handles the overhang. On the validated eng-sample a + regular-D# load issues no global fetch for the OOB rows + (fault-safe) and zero-fills them in LDS. Store-side OOB via + this field is HW-context dependent and not relied upon by + callers (see flydsl_fp8_perf/m_pad_oob/FINDINGS.md). Accepts a + Python int or an i32/index ir.Value. None (default) keeps + tensor_dim1 == tile_dim1 (OOB off) — byte-identical to the + original path. Returns: TDMDescriptor2D with dgroup0 and dgroup1 ready for tensor_load_2d. @@ -276,6 +292,23 @@ def make_tensor_descriptor_2d( outer_tile, inner_tile = tile_shape outer_off, inner_off = global_offset + # outer_stride may be a compile-time int or a runtime i32/index Value (strided + # A/C). Normalise to an index value for address math and remember if runtime. + if isinstance(outer_stride, int): + outer_stride_idx = arith.index(outer_stride) + outer_stride_is_runtime = False + else: + os_val = outer_stride.ir_value() if hasattr(outer_stride, "ir_value") else outer_stride + if not isinstance(os_val, ir.Value): + raise TypeError(f"outer stride must be int or i32/index ir.Value, got {type(outer_stride).__name__}") + if isinstance(os_val.type, ir.IndexType): + outer_stride_idx = _ArithValue(os_val) + elif isinstance(os_val.type, ir.IntegerType) and os_val.type.width == 32: + outer_stride_idx = arith.index_cast(T.index, os_val) + else: + raise TypeError(f"outer stride ir.Value must be index or i32, got {os_val.type}") + outer_stride_is_runtime = True + # -- Warp distribution -- warps_per_dim, block_per_warp = compute_warp_distribution( [outer_tile, inner_tile], @@ -305,9 +338,9 @@ def make_tensor_descriptor_2d( a_raw = global_ptr.__extract_to_ir_values__()[0] glb_ptr = _fly_d.extract_aligned_pointer_as_index(glb_ptr_type, a_raw) glb_base_i64 = _ArithValue(llvm_dialect.ptrtoint(i64, glb_ptr)) - glb_elem_off = (outer_off + warp_off_outer) * arith.index(outer_stride) + ( - inner_off + warp_off_inner - ) * arith.index(inner_stride) + glb_elem_off = (outer_off + warp_off_outer) * outer_stride_idx + (inner_off + warp_off_inner) * arith.index( + inner_stride + ) glb_byte_off = glb_elem_off * arith.index(elem_bytes) glb_byte_off_i64 = arith.index_cast(T.i64, glb_byte_off) glb_addr_i64 = glb_base_i64 + glb_byte_off_i64 @@ -395,23 +428,59 @@ def make_tensor_descriptor_2d( # sgpr1: atomic_barrier_addr[15:0]=0 | tensor_dim0_lo[31:16] g1_s1 = arith.constant((tdim0 & 0xFFFF) << 16, type=T.i32) - # sgpr2: tensor_dim0_hi[15:0] | tensor_dim1_lo[31:16] - g1_s2 = arith.constant( - ((tdim0 >> 16) & 0xFFFF) | ((tdim1 & 0xFFFF) << 16), - type=T.i32, - ) - - # sgpr3: tensor_dim1_hi[15:0] | tile_dim0[31:16] - g1_s3 = arith.constant( - ((tdim1 >> 16) & 0xFFFF) | (tile_d0 << 16), - type=T.i32, - ) + if oob_outer_bound is None: + # Compile-time tensor_dim1 == tile extent: OOB checking off. + # sgpr2: tensor_dim0_hi[15:0] | tensor_dim1_lo[31:16] + g1_s2 = arith.constant( + ((tdim0 >> 16) & 0xFFFF) | ((tdim1 & 0xFFFF) << 16), + type=T.i32, + ) + # sgpr3: tensor_dim1_hi[15:0] | tile_dim0[31:16] + g1_s3 = arith.constant( + ((tdim1 >> 16) & 0xFFFF) | (tile_d0 << 16), + type=T.i32, + ) + else: + # Runtime tensor_dim1 = max(0, oob_outer_bound - (outer_off + warp_off_outer)), + # tile-start-relative (the descriptor's global address already includes the + # tile/warp start). tile_dim1 (sgpr4) stays the full per-warp tile, so the + # partial last tile exceeds the tensor bound and the HW OOB-handles the + # overhang. tensor_dim0 (innermost) and the tile dims stay compile-time. + if isinstance(oob_outer_bound, int): + ob_i32 = arith.constant(oob_outer_bound, type=T.i32) + else: + ob_i32 = oob_outer_bound.ir_value() if hasattr(oob_outer_bound, "ir_value") else oob_outer_bound + if not isinstance(ob_i32, ir.Value): + raise TypeError( + f"oob_outer_bound must be int or i32/index ir.Value, got {type(oob_outer_bound).__name__}" + ) + if isinstance(ob_i32.type, ir.IndexType): + ob_i32 = arith.index_cast(T.i32, ob_i32) + elif not (isinstance(ob_i32.type, ir.IntegerType) and ob_i32.type.width == 32): + raise TypeError(f"oob_outer_bound ir.Value must be index or i32, got {ob_i32.type}") + start_i32 = arith.index_cast(T.i32, outer_off + warp_off_outer) + tdim1_rt = arith.maxsi(arith.subi(ob_i32, start_i32), arith.constant(0, type=T.i32)) + c16 = arith.constant(16, type=T.i32) + c_mask16 = arith.constant(0xFFFF, type=T.i32) + # sgpr2: tensor_dim0_hi[15:0] (const) | tensor_dim1_lo[31:16] (runtime) + g1_s2 = arith.ori( + arith.constant((tdim0 >> 16) & 0xFFFF, type=T.i32), + arith.shli(arith.andi(tdim1_rt, c_mask16), c16), + ) + # sgpr3: tensor_dim1_hi[15:0] (runtime) | tile_dim0[31:16] (const) + g1_s3 = arith.ori( + arith.andi(arith.shrui(tdim1_rt, c16), c_mask16), + arith.constant(tile_d0 << 16, type=T.i32), + ) - # sgpr4: tile_dim1[15:0] | tile_dim2[31:16]=0 + # sgpr4: tile_dim1[15:0] | tile_dim2[31:16]=0 (always the full per-warp tile) g1_s4 = arith.constant(tile_d1 & 0xFFFF, type=T.i32) # sgpr5: tensor_dim0_stride (low 32 bits) — stride of outermost dim - g1_s5 = arith.constant(stride0 & 0xFFFFFFFF, type=T.i32) + if outer_stride_is_runtime: + g1_s5 = arith.index_cast(T.i32, outer_stride_idx) + else: + g1_s5 = arith.constant(stride0 & 0xFFFFFFFF, type=T.i32) # sgpr6-7: for 2D, no higher-dim strides g1_s6 = arith.constant(0, type=T.i32) diff --git a/tests/kernels/test_gemm_fp8fp4_gfx1250.py b/tests/kernels/test_gemm_fp8fp4_gfx1250.py index 382a73f0..ad1daf3e 100644 --- a/tests/kernels/test_gemm_fp8fp4_gfx1250.py +++ b/tests/kernels/test_gemm_fp8fp4_gfx1250.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -"""Unified MXFP4/MXFP8/A8W4 GEMM correctness tests for gfx1250. +"""MXFP4/MXFP8/A8W4 and PTPC-FP8 GEMM correctness tests for gfx1250. Kernel implementation: kernels/gemm_fp8fp4_gfx1250.py """ @@ -24,7 +24,7 @@ import flydsl.compiler as flyc # noqa: E402,I001 from flydsl.runtime.device import get_rocm_arch # noqa: E402 -from kernels.gemm_fp8fp4_gfx1250 import compile_mxscale_gemm # noqa: E402 +from kernels.gemm_fp8fp4_gfx1250 import compile_mxscale_gemm, compile_ptpc_gemm # noqa: E402 from tests.kernels.utils import fp4_utils # noqa: E402 if not torch.cuda.is_available(): @@ -55,6 +55,7 @@ def preshuffle_e8m0_scale( scale_k_per_tile: int = 4, WMMA_DIM: int = 16, coalesced: bool = False, + row_align: int = None, ) -> torch.Tensor: """Preshuffle E8M0 scale: optional byte swap + interleave for WMMA access. @@ -63,8 +64,16 @@ def preshuffle_e8m0_scale( """ if coalesced: return preshuffle_e8m0_scale_coalesced(scale, block=warp_tile) - _, K_scale = scale.shape + rows, K_scale = scale.shape assert K_scale % 4 == 0, f"K_scale must be divisible by 4, got {K_scale}" + # Accept an unpadded row count (M for a_scale / N for b_scale): pad rows to + # row_align (the GEMM reads tile_m-granular tiles, so callers pass row_align=tile_m) + # with E8M0 127 (=1.0). Padding rows feed only discarded output rows. No-op when + # already aligned. Defaults to warp_tile (the minimum the reshape needs). + align = row_align if row_align is not None else warp_tile + if rows % align != 0: + pad = _align_up(rows, align) - rows + scale = torch.cat([scale, torch.full((pad, K_scale), 127, dtype=scale.dtype, device=scale.device)], dim=0) SCALES_PER_WMMA = 4 wmma_rep = warp_tile // WMMA_DIM k_groups = K_scale // scale_k_per_tile @@ -103,8 +112,23 @@ def _parse_fill_mode(arg: str): return ("const", value) +_MXFP4_MAGS = (0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0) + + +def _nearest_mxfp4_value(value: float) -> float: + """Nearest E2M1-representable value to `value`, never zero unless value == 0.""" + if value == 0: + return 0.0 + sign = -1.0 if value < 0 else 1.0 + mag = abs(float(value)) + return sign * min(_MXFP4_MAGS, key=lambda m: abs(m - mag)) + + def _fp4_e2m1_packed_fill(rows: int, cols: int, value: float) -> torch.Tensor: - dense = torch.full((rows, cols), float(value), dtype=torch.float32) + # Snap to the nearest nonzero E2M1 value: a raw round of a small fill (0.1) + # would land on 0 and make the whole weight tensor vanish. + snapped = _nearest_mxfp4_value(value) + dense = torch.full((rows, cols), float(snapped), dtype=torch.float32) return fp4_utils.f32_to_mxfp4(dense).view(torch.uint8) @@ -158,6 +182,11 @@ def _fill_mode_label(fill_spec, data_format: str) -> str: label = f"const={fill_spec[1]:g}, E8M0 byte=127" if data_format in ("fp8", "a8w4"): label += f", FP8 byte=0x{_fp8_e4m3fn_byte(fill_spec[1]):02x}" + if data_format in ("fp4", "a8w4"): + eff = _nearest_mxfp4_value(fill_spec[1]) + label += f", FP4={eff:g}" + if eff != fill_spec[1]: + label += f" (snapped from {fill_spec[1]:g})" return label @@ -261,17 +290,24 @@ def _get_padded_problem_shape( tile_k: int, split_k: int, ) -> dict[str, int]: - """Pad runtime problem to tile-aligned kernel dimensions.""" + """Validate tile alignment and return the (unpadded) kernel dimensions. + + N/K must divide their tiles; M is ragged (hardware OOB). Fail loudly instead + of silently host-padding. + """ if K % SCALE_BLOCK != 0: raise ValueError(f"K={K} must be divisible by SCALE_BLOCK={SCALE_BLOCK}") + if N % tile_n != 0: + raise ValueError(f"N={N} must be divisible by tile_n={tile_n} (no silent pad)") + if K % (tile_k * split_k) != 0: + raise ValueError(f"K={K} must be divisible by tile_k*split_k={tile_k * split_k} (no silent pad)") pack_a, pack_b = _mxscale_pack_factors(data_format) - padded_k = _align_up(K, tile_k * split_k) return { - "M": _align_up(M, tile_m), - "N": _align_up(N, tile_n), - "K": padded_k, - "K_scale": padded_k // SCALE_BLOCK, + "M": M, + "N": N, + "K": K, + "K_scale": K // SCALE_BLOCK, "pack_a": pack_a, "pack_b": pack_b, } @@ -364,8 +400,8 @@ def _run_mxscale_gemm_test( _dtype_map = {"f32": torch.float32, "bf16": torch.bfloat16, "f16": torch.float16} torch_out_dtype = _dtype_map[out_dtype] - # Split-K accumulates across workgroups in fp32; half outputs are converted after. - kernel_out_dtype = "f32" if (split_k > 1 and out_dtype in ("bf16", "f16")) else out_dtype + # Split-K accumulates at the output precision. + kernel_out_dtype = out_dtype torch_kernel_dtype = _dtype_map[kernel_out_dtype] torch.manual_seed(0) @@ -429,7 +465,6 @@ def _run_mxscale_gemm_test( launch_fn = compile_mxscale_gemm( data_format=data_format, - M=padded_m, N=padded_n, K=padded_k, tile_m=tile_m, @@ -469,11 +504,12 @@ def _run_mxscale_gemm_test( bs_flat, padded_m, padded_n, + padded_k, + padded_n, torch.cuda.current_stream(), ) torch.cuda.synchronize() - # Convert the fp32 split-K accumulation back to the requested half dtype. c_out = c_gpu[:M, :N].to(torch_out_dtype).cpu() print( @@ -517,7 +553,13 @@ def _run_mxscale_gemm_test( else: # FP8: standard SCALE_BLOCK=32 reference if out_dtype in ("bf16", "f16"): - torch.testing.assert_close(c_out_f, ref_f, rtol=1e-2, atol=5e-2) + # split-k atomic-adds at output precision; peak-scale tolerance to + # absorb the compounded bf16/f16 rounding on large-magnitude outputs. + if split_k > 1: + peak = float(ref_f.abs().max()) + torch.testing.assert_close(c_out_f, ref_f, rtol=2e-2, atol=max(5e-2, 2e-2 * peak)) + else: + torch.testing.assert_close(c_out_f, ref_f, rtol=1e-2, atol=5e-2) else: atol = max(1e-2, K * 0.6) torch.testing.assert_close(c_out_f, ref_f, rtol=1e-3, atol=atol) @@ -672,7 +714,7 @@ def test_mxfp8_gemm( ) -@pytest.mark.parametrize("split_k", [2, 4, 6, 8]) +@pytest.mark.parametrize("split_k", [2, 4]) @pytest.mark.parametrize("out_dtype", ["f32", "bf16"]) def test_mxfp8_gemm_splitk(split_k, out_dtype): """FP8 split-K: split_k workgroups accumulate partial K-sums into C via atomic add. @@ -701,8 +743,8 @@ def test_mxfp8_gemm_splitk(split_k, out_dtype): @pytest.mark.parametrize( "M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp", [ - (128, 5760, 2880, 128, 256, 256, 2, 2), - (128, 2880, 2880, 128, 256, 256, 2, 2), + (128, 5632, 2816, 128, 256, 256, 2, 2), + (128, 2816, 2816, 128, 256, 256, 2, 2), (1024, 1024, 1024, 128, 256, 128, 2, 4), ], ) @@ -734,12 +776,12 @@ def test_a8w4_gemm( @pytest.mark.parametrize( "M, N, K, use_tdm_store", [ - (13, 2880, 2880, True), - (33, 5760, 2880, False), + (13, 2816, 2816, True), + (33, 5632, 2816, False), ], ) def test_a8w4_gemm_irregular_m_tile16(M, N, K, use_tdm_store): - # Small-M path: pad M to 16 and dedicate one wave to the M dimension. + # Small-M path: ragged M via OOB, one wave dedicated to the M dimension. _run_mxscale_gemm_test( "a8w4", M, @@ -995,7 +1037,6 @@ def test_mxscale_gemm_cudagraph(data_format, M, N, K, tile_m, tile_n, tile_k, m_ launch_fn = compile_mxscale_gemm( data_format=data_format, - M=M, N=N, K=K, tile_m=tile_m, @@ -1024,13 +1065,15 @@ def test_mxscale_gemm_cudagraph(data_format, M, N, K, tile_m, tile_n, tile_k, m_ bs_flat, M, N, + K, + N, torch.cuda.current_stream(), ) # Resolve stream lazily inside the launch closure so graph capture sees # the active capture stream rather than a stream bound before capture. def launch(): - compiled_exe(c_flat, a_flat, b_flat, as_flat, bs_flat, M, N, torch.cuda.current_stream()) + compiled_exe(c_flat, a_flat, b_flat, as_flat, bs_flat, M, N, K, N, torch.cuda.current_stream()) # ── Eager run (reference) ── c_gpu.zero_() @@ -1190,6 +1233,470 @@ def _bench_kernel_us(run_fn, warmup=10, iters=50, flush_l2=True, prep_fn=None): return latencies[len(latencies) // 2] +def reference_ptpc_gemm(data_format, a, b, sa, sb, M, N, K): + """PTPC reference: D = (A @ B^T) * sa[:,None] * sb[None,:]. + + data_format="fp8": FP8 activation + FP8 weight. + data_format="a8w4": FP8 activation + FP4 (E2M1) weight. + """ + a_f32 = fp4_utils.fp8_e4m3_to_f32(a.view(torch.uint8))[:M, :K] + convert_b = fp4_utils.mxfp4_to_f32 if data_format == "a8w4" else fp4_utils.fp8_e4m3_to_f32 + b_f32 = convert_b(b.view(torch.uint8))[:N, :K] + raw = torch.matmul(a_f32, b_f32.T) + return raw * sa[:M].view(M, 1) * sb[:N].view(1, N) + + +def _run_ptpc_gemm_test( + M, + N, + K, + tile_m, + tile_n, + tile_k, + m_warp, + n_warp, + num_buffers, + out_dtype, + *, + data_format="fp8", + l2_prefetch_distance=2, + cluster_m=1, + cluster_n=1, + split_k=1, + lda_pad=0, + ldc_pad=0, +): + """Correctness body for PTPC (per-token per-channel) GEMM. + + A scale sa[M] (per-token) and B scale sb[N] (per-channel) are fp32, constant + along K. The K-loop runs the WMMA unscaled (fp8) or with an identity scale + (a8w4); sa*sb is applied in the epilogue. data_format: "fp8" or "a8w4". + """ + arch = str(get_rocm_arch()) + if arch != "gfx1250": + pytest.skip(f"PTPC requires gfx1250, got {arch}") + + padded_shape = _get_padded_problem_shape(data_format, M, N, K, tile_m, tile_n, tile_k, split_k) + padded_m, padded_n, padded_k = padded_shape["M"], padded_shape["N"], padded_shape["K"] + local_k = padded_k // split_k + num_k_tiles = local_k // tile_k + if num_buffers > 1 and num_k_tiles < num_buffers: + pytest.skip(f"{num_buffers}-buf requires num_k_tiles >= {num_buffers}") + + _dtype_map = {"f32": torch.float32, "bf16": torch.bfloat16, "f16": torch.float16} + torch_out_dtype = _dtype_map[out_dtype] + kernel_out_dtype = out_dtype # split-k atomic-adds at output precision + torch_kernel_dtype = _dtype_map[kernel_out_dtype] + + torch.manual_seed(0) + a = random_fp8_data(M, K) # FP8 activation for both fp8 and a8w4 + b = fp4_utils.random_fp4_packed(N, K) if data_format == "a8w4" else random_fp8_data(N, K) + # Per-token / per-channel fp32 scales in a benign range to avoid degeneracy. + sa = (0.5 + torch.rand(M, dtype=torch.float32)).contiguous() + sb = (0.5 + torch.rand(N, dtype=torch.float32)).contiguous() + + ref = reference_ptpc_gemm(data_format, a, b, sa, sb, M, N, K) + print( + f"\nRunning PTPC {data_format.upper()} GEMM: M={M}, N={N}, K={K}, tiles=({tile_m},{tile_n},{tile_k}), " + f"bufs={num_buffers}, split_k={split_k}, out={out_dtype}" + ) + print(f"Ref stats: min={ref.min():.2f}, max={ref.max():.2f}, mean={ref.mean():.2f}, std={ref.std():.2f}") + + # Pad data to tile-aligned shapes; B is preshuffled like the mxscale path. + # A8W4 packs the FP4 weight 2-per-byte, so B's column count is K/pack_b. + K_packed_b = padded_k // padded_shape["pack_b"] + a = _pad_2d_tensor(a, padded_m, padded_k, fill_value=0) + b = _pad_2d_tensor(b, padded_n, K_packed_b, fill_value=0) + b = fp4_utils.preshuffle_b_16x16(b, padded_n, K_packed_b) + # Pad scales (pad region is discarded in the [:M,:N] slice). + sa_p = torch.zeros(padded_m, dtype=torch.float32) + sa_p[:M] = sa + sb_p = torch.zeros(padded_n, dtype=torch.float32) + sb_p[:N] = sb + + # Optional strided A/C: back data with a wider leading dim (lda/ldc), exercising + # the runtime-stride descriptor path. lda/ldc are logical leading dims (elements). + pack_a = padded_shape["pack_a"] + lda = padded_k + lda_pad + ldc = padded_n + ldc_pad + if lda_pad: + a_full = torch.zeros(padded_m, lda // pack_a, dtype=a.dtype) + a_full[:, : padded_k // pack_a] = a + a = a_full + + a_gpu = a.cuda() + b_gpu = b.cuda() + sa_gpu = sa_p.cuda() + sb_gpu = sb_p.cuda() + c_gpu = torch.zeros(padded_m, ldc, dtype=torch_kernel_dtype, device="cuda") + + launch_fn = compile_ptpc_gemm( + N=padded_n, + K=padded_k, + data_format=data_format, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + m_warp=m_warp, + n_warp=n_warp, + num_buffers=num_buffers, + l2_prefetch_distance=l2_prefetch_distance, + cluster_m=cluster_m, + cluster_n=cluster_n, + out_dtype=kernel_out_dtype, + split_k=split_k, + ) + + flyc.compile( + launch_fn, + c_gpu.contiguous(), + a_gpu.contiguous(), + b_gpu.contiguous(), + sa_gpu.contiguous(), + sb_gpu.contiguous(), + padded_m, + padded_n, + lda, + ldc, + torch.cuda.current_stream(), + ) + torch.cuda.synchronize() + + c_out = c_gpu[:M, :N].to(torch_out_dtype).cpu() + print( + f"Out stats: min={c_out.float().min():.2f}, max={c_out.float().max():.2f}, " + f"mean={c_out.float().mean():.2f}, std={c_out.float().std():.2f}" + ) + if c_out.float().abs().max() < 1e-10: + print("WARNING: kernel output is all zeros!") + + c_out_f = c_out.float() + ref_f = ref.to(torch_out_dtype).float() if out_dtype in ("bf16", "f16") else ref.float() + diff = (c_out_f - ref_f).abs() + print(f"Abs diff: max={diff.max():.4f}, mean={diff.mean():.4f}") + cos_sim = torch.nn.functional.cosine_similarity( + c_out_f.flatten().unsqueeze(0).double(), ref_f.flatten().unsqueeze(0).double() + ).item() + print(f"Cosine similarity: {cos_sim:.6f}") + + peak = float(ref_f.abs().max()) + if out_dtype in ("bf16", "f16"): + torch.testing.assert_close(c_out_f, ref_f, rtol=2e-2, atol=max(5e-2, 2e-2 * peak)) + else: + torch.testing.assert_close(c_out_f, ref_f, rtol=1e-3, atol=max(1e-2, K * 0.6)) + print("PASSED") + + +@pytest.mark.parametrize("out_dtype", ["bf16", "f32"]) +@pytest.mark.parametrize( + "M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp, num_buffers", + [ + (256, 256, 512, 256, 256, 128, 2, 2, 4), # deep-pipeline eligible + (128, 256, 512, 128, 256, 128, 2, 2, 4), # quadrant fallback + ], +) +def test_ptpc_fp8_gemm(M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp, num_buffers, out_dtype): + _run_ptpc_gemm_test(M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp, num_buffers, out_dtype) + + +@pytest.mark.parametrize("lda_pad, ldc_pad", [(128, 0), (0, 256), (128, 256)]) +def test_ptpc_fp8_gemm_strided(lda_pad, ldc_pad): + """Strided A/C: data backed by a wider leading dim, passed via runtime lda/ldc.""" + _run_ptpc_gemm_test( + 128, 256, 512, 128, 256, 128, 2, 2, num_buffers=4, out_dtype="bf16", lda_pad=lda_pad, ldc_pad=ldc_pad + ) + + +@pytest.mark.parametrize("split_k", [2, 4]) +@pytest.mark.parametrize("out_dtype", ["bf16", "f32"]) +def test_ptpc_fp8_gemm_splitk(split_k, out_dtype): + """PTPC split-K: each chunk applies sa*sb then atomic-adds; sum stays correct.""" + _run_ptpc_gemm_test(128, 256, 2048, 128, 256, 128, 2, 4, num_buffers=2, out_dtype=out_dtype, split_k=split_k) + + +@pytest.mark.parametrize("out_dtype", ["bf16", "f32"]) +@pytest.mark.parametrize( + "M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp, num_buffers", + [ + (128, 256, 512, 128, 256, 128, 2, 4, 2), # row-major (a8w4) + wave-spec TDM + (128, 256, 1024, 128, 256, 256, 2, 4, 3), + ], +) +def test_ptpc_a8w4_gemm(M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp, num_buffers, out_dtype): + """PTPC A8W4 (FP8 act + FP4 weight): K-loop uses identity-scale f8f6f4 WMMA; + real per-token/per-channel sa*sb is applied in the epilogue.""" + _run_ptpc_gemm_test(M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp, num_buffers, out_dtype, data_format="a8w4") + + +@pytest.mark.parametrize("split_k", [2, 4]) +def test_ptpc_a8w4_gemm_splitk(split_k): + """PTPC A8W4 split-K: identity-scale K-loop + epilogue sa*sb + atomic add.""" + _run_ptpc_gemm_test( + 128, 256, 2048, 128, 256, 128, 2, 4, num_buffers=2, out_dtype="bf16", split_k=split_k, data_format="a8w4" + ) + + +# --------------------------------------------------------------------------- +# Non-tile-aligned M (the default, no host M-padding): A/C (and ptpc sa) are +# allocated at the real M. A-load TDM skips rows>=M, sa buffer_load OOB->0, C +# buffer_store clips via num_records. N,K stay tile-aligned. +# --------------------------------------------------------------------------- +_DT = {"f32": torch.float32, "bf16": torch.bfloat16, "f16": torch.float16} +_MPAD_MS = [1, 16, 31, 64, 65, 100, 127, 128, 129, 130, 192, 255, 256, 257, 384, 500, 1000, 2048] + + +def _assert_mpad(c_real, ref, out_dtype): + c = c_real.float() + ref_f = ref.to(_DT[out_dtype]).float() + peak = float(ref_f.abs().max()) + if out_dtype in ("bf16", "f16"): + torch.testing.assert_close(c, ref_f, rtol=2e-2, atol=max(5e-2, 2e-2 * peak)) + else: + torch.testing.assert_close(c, ref_f, rtol=1e-3, atol=max(1e-2, ref.shape[-1] * 0.6)) + + +def _run_ptpc_mpad( + M, + N, + K, + *, + data_format="fp8", + out_dtype="bf16", + split_k=1, + tile_m=128, + tile_n=128, + tile_k=128, + m_warp=2, + n_warp=2, + num_buffers=4, + cluster_m=1, + cluster_n=1, +): + arch = str(get_rocm_arch()) + if arch != "gfx1250": + pytest.skip(f"requires gfx1250, got {arch}") + assert N % tile_n == 0 and K % tile_k == 0, "M-pad test keeps N,K tile-aligned" + # split_k atomic-adds at output precision (per-lane predicate on row < M). + kernel_out_dtype = out_dtype + torch.manual_seed(0) + a = random_fp8_data(M, K) + b = fp4_utils.random_fp4_packed(N, K) if data_format == "a8w4" else random_fp8_data(N, K) + sa = (0.5 + torch.rand(M, dtype=torch.float32)).contiguous() + sb = (0.5 + torch.rand(N, dtype=torch.float32)).contiguous() + ref = reference_ptpc_gemm(data_format, a, b, sa, sb, M, N, K) + pack_b = 2 if data_format == "a8w4" else 1 + b_ps = fp4_utils.preshuffle_b_16x16(b, N, K // pack_b) + c_gpu = torch.zeros(M, N, dtype=_DT[kernel_out_dtype], device="cuda") # real M; zero for atomic + launch = compile_ptpc_gemm( + N=N, + K=K, + data_format=data_format, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + m_warp=m_warp, + n_warp=n_warp, + num_buffers=num_buffers, + out_dtype=kernel_out_dtype, + split_k=split_k, + cluster_m=cluster_m, + cluster_n=cluster_n, + ) + launch(c_gpu, a.cuda(), b_ps.cuda(), sa.cuda(), sb.cuda(), M, N, K, N, torch.cuda.current_stream()) + torch.cuda.synchronize() + _assert_mpad(c_gpu[:M].cpu(), ref, kernel_out_dtype) + + +def _run_mxscale_mpad( + M, + N, + K, + *, + out_dtype="bf16", + use_tdm_store=True, + tile_m=128, + tile_n=128, + tile_k=128, + m_warp=2, + n_warp=2, + num_buffers=4, + cluster_m=1, + cluster_n=1, +): + arch = str(get_rocm_arch()) + if arch != "gfx1250": + pytest.skip(f"requires gfx1250, got {arch}") + assert N % tile_n == 0 and K % tile_k == 0, "M-pad test keeps N,K tile-aligned" + torch.manual_seed(0) + a = random_fp8_data(M, K) + b = random_fp8_data(N, K) + a_scale = fp4_utils.random_e8m0(M, K // SCALE_BLOCK) # real M, unpadded + b_scale = fp4_utils.random_e8m0(N, K // SCALE_BLOCK) + ref = reference_mxfp8_gemm(a, b, a_scale, b_scale, M, N, K) + skt = tile_k // SCALE_BLOCK + # a_scale stays UNPADDED host-side; preshuffle pads rows to tile_m (the GEMM + # reads tile_m-granular scale tiles for the partial last M-tile). N is aligned. + as_ps = preshuffle_e8m0_scale(a_scale, tile_m // m_warp, scale_k_per_tile=skt, row_align=tile_m) + bs_ps = preshuffle_e8m0_scale(b_scale, tile_n // n_warp, scale_k_per_tile=skt) + b_ps = fp4_utils.preshuffle_b_16x16(b, N, K) + c_gpu = torch.zeros(M, N, dtype=_DT[out_dtype], device="cuda") # real M + launch = compile_mxscale_gemm( + data_format="fp8", + N=N, + K=K, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + m_warp=m_warp, + n_warp=n_warp, + num_buffers=num_buffers, + out_dtype=out_dtype, + use_tdm_store=use_tdm_store, + cluster_m=cluster_m, + cluster_n=cluster_n, + ) + launch(c_gpu, a.cuda(), b_ps.cuda(), as_ps.cuda(), bs_ps.cuda(), M, N, K, N, torch.cuda.current_stream()) + torch.cuda.synchronize() + _assert_mpad(c_gpu[:M].cpu(), ref, out_dtype) + + +@pytest.mark.parametrize("out_dtype", ["bf16", "f32"]) +@pytest.mark.parametrize("M", _MPAD_MS) +def test_ptpc_fp8_gemm_mpad(M, out_dtype): + _run_ptpc_mpad(M, 256, 512, out_dtype=out_dtype) + + +@pytest.mark.parametrize("M", _MPAD_MS) +def test_ptpc_a8w4_gemm_mpad(M): + _run_ptpc_mpad(M, 256, 512, data_format="a8w4", m_warp=2, n_warp=4, num_buffers=2) + + +@pytest.mark.parametrize("use_tdm_store", [True, False]) +@pytest.mark.parametrize("out_dtype", ["bf16", "f32"]) +@pytest.mark.parametrize("M", _MPAD_MS) +def test_mxfp8_gemm_mpad(M, out_dtype, use_tdm_store): + _run_mxscale_mpad(M, 256, 512, out_dtype=out_dtype, use_tdm_store=use_tdm_store) + + +@pytest.mark.parametrize("split_k", [2, 4]) +@pytest.mark.parametrize("M", [1, 64, 129, 192, 257, 500]) +def test_ptpc_fp8_gemm_splitk_mpad(M, split_k): + # split_k atomic output predicated per-lane on row < M (auto buffer/atomic path). + _run_ptpc_mpad(M, 256, 2048, m_warp=2, n_warp=4, num_buffers=2, split_k=split_k) + + +# Tile/warp-config diversity: the per-warp partial-tile clip uses +# warp_tile_m = tile_m // m_warp, so M must be exercised against different warp +# boundaries. Existing mpad tests are all m_warp=2 (warp_tile_m=64); these add +# warp_tile_m in {128 (single M-warp / tile_m=256), 32 (fine 4-way split)}. +_MPAD_WARP_CFGS = [ + # (tile_m, tile_n, tile_k, m_warp, n_warp, num_buffers) + (128, 128, 128, 1, 4, 4), # warp_tile_m=128: single M-warp, no M split + (128, 128, 128, 4, 2, 2), # warp_tile_m=32: fine-grained M warps + (256, 128, 128, 2, 2, 2), # tile_m=256, warp_tile_m=128 +] +# Boundary-diverse M for warp_tile_m in {32, 128}: partial/full/OOB warps + aligned. +_MPAD_WARP_MS = [1, 33, 64, 100, 129, 200, 256, 333] + + +@pytest.mark.parametrize("tile_m,tile_n,tile_k,m_warp,n_warp,num_buffers", _MPAD_WARP_CFGS) +@pytest.mark.parametrize("M", _MPAD_WARP_MS) +def test_ptpc_fp8_gemm_mpad_warps(M, tile_m, tile_n, tile_k, m_warp, n_warp, num_buffers): + _run_ptpc_mpad( + M, + 256, + 512, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + m_warp=m_warp, + n_warp=n_warp, + num_buffers=num_buffers, + ) + + +# M=100 -> grid_m 1->2, tile1 fully OOB (rows>=100) under M-multicast +# M=129,200,450 -> partial last M-tile, grid divisible +# M=256,512 -> tile-aligned +# M=257,300 -> grid_m 3->4 (rounded); M=300 also makes tile3 fully OOB +_MPAD_CLUSTER_MS = [100, 129, 200, 256, 257, 300, 450, 512] +_MPAD_CLUSTERS = [(2, 2), (2, 4)] + + +@pytest.mark.parametrize("cluster_m,cluster_n", _MPAD_CLUSTERS) +@pytest.mark.parametrize("M", _MPAD_CLUSTER_MS) +def test_ptpc_fp8_gemm_mpad_cluster(M, cluster_m, cluster_n): + _run_ptpc_mpad(M, 512, 512, m_warp=2, n_warp=2, num_buffers=2, cluster_m=cluster_m, cluster_n=cluster_n) + + +@pytest.mark.parametrize("cluster_m,cluster_n", _MPAD_CLUSTERS) +@pytest.mark.parametrize("M", _MPAD_CLUSTER_MS) +def test_ptpc_a8w4_gemm_mpad_cluster(M, cluster_m, cluster_n): + _run_ptpc_mpad( + M, 512, 512, data_format="a8w4", m_warp=2, n_warp=4, num_buffers=2, cluster_m=cluster_m, cluster_n=cluster_n + ) + + +@pytest.mark.parametrize("use_tdm_store", [True, False]) +@pytest.mark.parametrize("cluster_m,cluster_n", _MPAD_CLUSTERS) +@pytest.mark.parametrize("M", _MPAD_CLUSTER_MS) +def test_mxfp8_gemm_mpad_cluster(M, cluster_m, cluster_n, use_tdm_store): + _run_mxscale_mpad( + M, + 512, + 512, + m_warp=2, + n_warp=2, + num_buffers=2, + cluster_m=cluster_m, + cluster_n=cluster_n, + use_tdm_store=use_tdm_store, + ) + + +@pytest.mark.parametrize("split_k", [2, 4]) +@pytest.mark.parametrize("M", [100, 129, 256, 300, 450]) +def test_ptpc_fp8_gemm_splitk_mpad_cluster(M, split_k): + # split_k atomic output (per-lane row1. + _run_ptpc_mpad(M, 512, 2048, m_warp=2, n_warp=2, num_buffers=2, split_k=split_k, cluster_m=2, cluster_n=2) + + +@pytest.mark.parametrize("cluster_m,cluster_n", [(2, 2), (2, 4)]) +@pytest.mark.parametrize("M", [100, 300, 512, 600, 700, 1024]) +def test_ptpc_fp8_gemm_mpad_cluster_tm256(M, cluster_m, cluster_n): + _run_ptpc_mpad( + M, + 1024, + 512, + tile_m=256, + tile_n=256, + m_warp=2, + n_warp=2, + num_buffers=2, + cluster_m=cluster_m, + cluster_n=cluster_n, + ) + + +@pytest.mark.parametrize("use_tdm_store", [True, False]) +@pytest.mark.parametrize("cluster_m,cluster_n", [(2, 2), (2, 4)]) +@pytest.mark.parametrize("M", [100, 300, 512, 600, 700, 1024]) +def test_mxfp8_gemm_mpad_cluster_tm256(M, cluster_m, cluster_n, use_tdm_store): + _run_mxscale_mpad( + M, + 1024, + 512, + tile_m=256, + tile_n=256, + m_warp=2, + n_warp=2, + num_buffers=2, + cluster_m=cluster_m, + cluster_n=cluster_n, + use_tdm_store=use_tdm_store, + ) + + def _run_benchmark(args): """Benchmark mode: compile once, time kernel execution with proper methodology.""" import time @@ -1211,13 +1718,18 @@ def _run_benchmark(args): is_fp4 = data_format == "fp4" is_a8w4 = data_format == "a8w4" + is_ptpc = getattr(args, "scale_mode", "mxscale") == "ptpc" + if is_ptpc and data_format not in ("fp8", "a8w4"): + raise ValueError(f"scale_mode='ptpc' only supports data_format='fp8' or 'a8w4', got {data_format!r}") _dtype_map = {"f32": torch.float32, "bf16": torch.bfloat16, "f16": torch.float16} - # split_k>1 accumulates partial K-sums in fp32 for precision; bf16/f16 atomics are - # supported but compound rounding error, so we run f32 and convert back on the host. - kernel_out_dtype = "f32" if (args.split_k > 1 and args.out_dtype in ("bf16", "f16")) else args.out_dtype + # split_k atomic-adds at output precision (bf16/f16). + kernel_out_dtype = args.out_dtype torch_kernel_dtype = _dtype_map[kernel_out_dtype] elem_bytes_d = 2 if kernel_out_dtype in ("bf16", "f16") else 4 - fmt_name = "A8W4" if is_a8w4 else ("MXFP4" if is_fp4 else "MXFP8") + if is_ptpc: + fmt_name = "PTPC-A8W4" if is_a8w4 else "PTPC-FP8" + else: + fmt_name = "A8W4" if is_a8w4 else ("MXFP4" if is_fp4 else "MXFP8") print("=" * 72) print(f" {fmt_name} GEMM Benchmark on gfx1250") @@ -1237,23 +1749,74 @@ def _run_benchmark(args): l2_flush_label = "OFF (graph)" if getattr(args, "use_graph", False) else ("OFF" if args.no_flush_l2 else "ON") print(f" Warmup={args.warmup}, Iters={args.iters}, L2 flush={l2_flush_label}") print(" Output init: zero before warmup") + if is_ptpc: + # compile_ptpc_gemm forces these internally; flag the ones the user set off-default. + _ptpc_ignored = [] + if args.no_tdm_store: + _ptpc_ignored.append("--no-tdm-store") + if not args.wave_spec_tdm: + _ptpc_ignored.append("--no-wave-spec-tdm") + if args.use_scale_opsel: + _ptpc_ignored.append("--use-scale-opsel") + if args.scale_load_path != "tdm": + _ptpc_ignored.append(f"--scale-load-path {args.scale_load_path}") + if args.b_streaming: + _ptpc_ignored.append("--b-streaming") + if _ptpc_ignored: + print(f" Note: PTPC ignores (forced internally): {', '.join(_ptpc_ignored)}") print("=" * 72) torch.manual_seed(0) - a, b, a_scale, b_scale, fill_spec = _fill_mode_inputs(M, N, K, data_format, getattr(args, "fill_mode", "random")) - print(f" Fill mode: {_fill_mode_label(fill_spec, data_format)}") - - a, b, a_scale, b_scale = _pad_mxscale_inputs(a, b, a_scale, b_scale, padded_shape) - - skt = tile_k // SCALE_BLOCK warp_tile_m = tile_m // args.m_warp warp_tile_n = tile_n // args.n_warp - _coalesced_scale = args.scale_load_path in ("vgpr", "vgpr_ab_split") - a_scale = preshuffle_e8m0_scale(a_scale, warp_tile_m, scale_k_per_tile=skt, coalesced=_coalesced_scale) - b_scale = preshuffle_e8m0_scale(b_scale, warp_tile_n, scale_k_per_tile=skt, coalesced=_coalesced_scale) + if is_ptpc: + # PTPC: fp8 A with fp32 per-token (sa[M]) / per-channel (sb[N]) scales, no scale preshuffle. + # B is fp8 (data_format="fp8") or FP4-packed 2-per-byte (data_format="a8w4"). + K_packed_b = padded_k // PACK_B + b_kind = "fp4 (a8w4)" if is_a8w4 else "fp8" + fill_spec = _parse_fill_mode(getattr(args, "fill_mode", "random")) + if fill_spec[0] == "const": + value = fill_spec[1] + fp8_byte = _fp8_e4m3fn_byte(value) + a_raw = torch.full((M, K), fp8_byte, dtype=torch.uint8) + b_raw = _fp4_e2m1_packed_fill(N, K, value) if is_a8w4 else torch.full((N, K), fp8_byte, dtype=torch.uint8) + # Neutral per-token/per-channel scales so the const output stays predictable. + a_scale = torch.zeros(padded_m, dtype=torch.float32) + a_scale[:M] = 1.0 + b_scale = torch.zeros(padded_n, dtype=torch.float32) + b_scale[:N] = 1.0 + if is_a8w4: + eff_b = _nearest_mxfp4_value(value) + b_note = f"fp4 B={eff_b:g}" + (f" (snapped from {value:g})" if eff_b != value else "") + else: + b_note = "fp8 B" + print(f" Fill mode: const={value:g} (FP8 byte=0x{fp8_byte:02x}), {b_note}, sa=sb=1.0") + else: + a_raw = random_fp8_data(M, K) + b_raw = fp4_utils.random_fp4_packed(N, K) if is_a8w4 else random_fp8_data(N, K) + a_scale = torch.zeros(padded_m, dtype=torch.float32) + a_scale[:M] = 0.5 + torch.rand(M, dtype=torch.float32) + b_scale = torch.zeros(padded_n, dtype=torch.float32) + b_scale[:N] = 0.5 + torch.rand(N, dtype=torch.float32) + print(f" Fill mode: random fp8 A / {b_kind} B, fp32 per-token/per-channel scales") + a = _pad_2d_tensor(a_raw, padded_m, padded_k, fill_value=0) + b = _pad_2d_tensor(b_raw, padded_n, K_packed_b, fill_value=0) + b = fp4_utils.preshuffle_b_16x16(b, padded_n, K_packed_b) + else: + a, b, a_scale, b_scale, fill_spec = _fill_mode_inputs( + M, N, K, data_format, getattr(args, "fill_mode", "random") + ) + print(f" Fill mode: {_fill_mode_label(fill_spec, data_format)}") - K_packed = padded_k // PACK_B - b = fp4_utils.preshuffle_b_16x16(b, padded_n, K_packed) + a, b, a_scale, b_scale = _pad_mxscale_inputs(a, b, a_scale, b_scale, padded_shape) + + skt = tile_k // SCALE_BLOCK + _coalesced_scale = args.scale_load_path in ("vgpr", "vgpr_ab_split") + a_scale = preshuffle_e8m0_scale(a_scale, warp_tile_m, scale_k_per_tile=skt, coalesced=_coalesced_scale) + b_scale = preshuffle_e8m0_scale(b_scale, warp_tile_n, scale_k_per_tile=skt, coalesced=_coalesced_scale) + + K_packed = padded_k // PACK_B + b = fp4_utils.preshuffle_b_16x16(b, padded_n, K_packed) a_gpu = a.cuda() b_gpu = b.cuda() @@ -1267,32 +1830,54 @@ def _run_benchmark(args): if args.split_k > 1 and use_tdm_store: print(" Note: split-K forces buffer-store atomic epilogue; disabling TDM store.") use_tdm_store = False - launch_fn = compile_mxscale_gemm( - data_format=data_format, - M=padded_m, - N=padded_n, - K=padded_k, - tile_m=tile_m, - tile_n=tile_n, - tile_k=tile_k, - m_warp=args.m_warp, - n_warp=args.n_warp, - num_buffers=args.num_buffers, - waves_per_eu=args.waves_per_eu, - l2_prefetch_distance=args.l2_prefetch_distance, - cluster_m=args.cluster_m, - cluster_n=args.cluster_n, - use_tdm_store=use_tdm_store, - out_dtype=kernel_out_dtype, - inst_prefetch=args.inst_prefetch, - wave_specialized_tdm=args.wave_spec_tdm, - split_k=args.split_k, - use_scale_opsel=args.use_scale_opsel, - expert_sched_mode=args.expert_sched_mode, - atomic_barrier_enable=args.atomic_barrier_enable, - b_streaming=args.b_streaming, - scale_load_path=args.scale_load_path, - ) + if is_ptpc: + # compile_ptpc_gemm fixes scale_mode/wave_spec/use_tdm_store internally. + launch_fn = compile_ptpc_gemm( + N=padded_n, + K=padded_k, + data_format=data_format, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + m_warp=args.m_warp, + n_warp=args.n_warp, + num_buffers=args.num_buffers, + waves_per_eu=args.waves_per_eu, + l2_prefetch_distance=args.l2_prefetch_distance, + cluster_m=args.cluster_m, + cluster_n=args.cluster_n, + out_dtype=kernel_out_dtype, + inst_prefetch=args.inst_prefetch, + expert_sched_mode=args.expert_sched_mode, + atomic_barrier_enable=args.atomic_barrier_enable, + split_k=args.split_k, + ) + else: + launch_fn = compile_mxscale_gemm( + data_format=data_format, + N=padded_n, + K=padded_k, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + m_warp=args.m_warp, + n_warp=args.n_warp, + num_buffers=args.num_buffers, + waves_per_eu=args.waves_per_eu, + l2_prefetch_distance=args.l2_prefetch_distance, + cluster_m=args.cluster_m, + cluster_n=args.cluster_n, + use_tdm_store=use_tdm_store, + out_dtype=kernel_out_dtype, + inst_prefetch=args.inst_prefetch, + wave_specialized_tdm=args.wave_spec_tdm, + split_k=args.split_k, + use_scale_opsel=args.use_scale_opsel, + expert_sched_mode=args.expert_sched_mode, + atomic_barrier_enable=args.atomic_barrier_enable, + b_streaming=args.b_streaming, + scale_load_path=args.scale_load_path, + ) compiled_exe = flyc.compile( launch_fn, @@ -1303,6 +1888,8 @@ def _run_benchmark(args): bs_gpu, padded_m, padded_n, + padded_k, + padded_n, torch.cuda.current_stream(), ) @@ -1318,6 +1905,8 @@ def run_kernel(): bs_gpu, padded_m, padded_n, + padded_k, + padded_n, torch.cuda.current_stream(), ) @@ -1345,7 +1934,7 @@ def run_kernel(): bytes_a = padded_m * padded_k // PACK_A bytes_b = padded_n * padded_k // PACK_B - bytes_scale = (padded_m + padded_n) * padded_shape["K_scale"] + bytes_scale = (padded_m + padded_n) * (4 if is_ptpc else padded_shape["K_scale"]) bytes_d = padded_m * padded_n * elem_bytes_d read_bytes = bytes_a + bytes_b + bytes_scale write_bytes = bytes_d @@ -1447,15 +2036,13 @@ def _run_graph_verify(args): as_gpu = a_scale.cuda() bs_gpu = b_scale.cuda() _dtype_map = {"f32": torch.float32, "bf16": torch.bfloat16, "f16": torch.float16} - # split_k>1 accumulates partial K-sums in fp32 for precision; bf16/f16 atomics are - # supported but compound rounding error, so we run f32 and convert back on the host. - kernel_out_dtype = "f32" if (args.split_k > 1 and args.out_dtype in ("bf16", "f16")) else args.out_dtype + # split_k atomic-adds at output precision (bf16/f16). + kernel_out_dtype = args.out_dtype c_gpu = torch.zeros(padded_m, padded_n, dtype=_dtype_map[kernel_out_dtype], device="cuda") use_tdm_store = not args.no_tdm_store and args.split_k == 1 launch_fn = compile_mxscale_gemm( data_format=data_format, - M=padded_m, N=padded_n, K=padded_k, tile_m=tile_m, @@ -1494,11 +2081,24 @@ def _run_graph_verify(args): bs_flat, padded_m, padded_n, + padded_k, + padded_n, torch.cuda.current_stream(), ) def launch(): - compiled_exe(c_flat, a_flat, b_flat, as_flat, bs_flat, padded_m, padded_n, torch.cuda.current_stream()) + compiled_exe( + c_flat, + a_flat, + b_flat, + as_flat, + bs_flat, + padded_m, + padded_n, + padded_k, + padded_n, + torch.cuda.current_stream(), + ) c_gpu.zero_() launch() @@ -1549,6 +2149,14 @@ def launch(): parser = argparse.ArgumentParser() parser.add_argument("--data-format", type=str, default="fp8", choices=["fp4", "fp8", "a8w4"]) + parser.add_argument( + "--scale-mode", + type=str, + default="mxscale", + choices=["mxscale", "ptpc"], + help="Scale organization: 'mxscale' (E8M0 block scale) or 'ptpc' " + "(per-token/per-channel fp32; supports --data-format fp8 or a8w4).", + ) parser.add_argument("-M", type=int, default=1024) parser.add_argument("-N", type=int, default=1024) parser.add_argument("-K", type=int, default=2048) @@ -1614,12 +2222,33 @@ def launch(): ) args = parser.parse_args() + if args.scale_mode == "ptpc" and args.verify_graph: + raise SystemExit("--scale-mode ptpc does not support --verify-graph") + if args.verify_graph: _run_graph_verify(args) if not args.benchmark: sys.exit(0) if args.benchmark: _run_benchmark(args) + elif args.scale_mode == "ptpc": + _run_ptpc_gemm_test( + args.M, + args.N, + args.K, + args.tile_m, + args.tile_n, + args.tile_k, + args.m_warp, + args.n_warp, + num_buffers=args.num_buffers, + out_dtype=args.out_dtype, + data_format=args.data_format, + l2_prefetch_distance=args.l2_prefetch_distance, + cluster_m=args.cluster_m, + cluster_n=args.cluster_n, + split_k=args.split_k, + ) else: use_tdm_store = not args.no_tdm_store and args.split_k == 1 _run_mxscale_gemm_test(