[gfx1250][gemm] Add PTPC FP8/A8W4, non-tile-aligned M, and strided A/C support#649
Open
aoli26 wants to merge 9 commits into
- kernel: m_oob_clip + m_oob_store {buffer, tdm_tail}. A/A-scale load clip via
TDM tensor_dim1, C-store clips via buffer num_records, split-K via per-lane
(row < M) predicate on the atomic path.
- tdm_ops: make_tensor_descriptor_2d gains oob_outer_bound. It sets only
tensor_dim1 (HW OOB field); tile_dim1 stays the full per-warp tile. Accepts
int|index|i32, raises otherwise. None keeps the original (byte-identical) path.
- tests: M-pad coverage (M=16..1000 x buffer/tdm_tail x bf16/f32 + split-K).
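The M-OOB clipping described in this commit can be illustrated with a small host-side emulation. This is only a sketch of the intended semantics, not the kernel code: `load_tile_rows` and `store_tile_rows` are hypothetical helpers, and the assumption that out-of-bounds rows read as zeros is made here for illustration rather than taken from the PR.

```python
# Host-side emulation of M-OOB handling for a non-tile-aligned M.
# Assumption (not confirmed by the PR): rows at or beyond the bound read as zeros.

def load_tile_rows(a, m_start, tile_m, m_bound):
    """Return tile_m rows starting at m_start; rows >= m_bound come back as zeros."""
    k = len(a[0])
    tile = []
    for r in range(m_start, m_start + tile_m):
        # The tile always has tile_m rows; only the bound limits what is read.
        tile.append(list(a[r]) if r < m_bound else [0.0] * k)
    return tile


def store_tile_rows(c, tile, m_start, m_bound):
    """Write tile rows back, skipping rows >= m_bound (the per-row store clip)."""
    for i, row in enumerate(tile):
        r = m_start + i
        if r < m_bound:  # same shape as the per-lane (row < M) predicate
            c[r] = list(row)


M, K, TILE_M = 5, 2, 4  # M is deliberately not a multiple of TILE_M
a = [[float(r * K + k) for k in range(K)] for r in range(M)]
c = [[-1.0] * K for _ in range(M)]

for m_start in range(0, M, TILE_M):
    tile = load_tile_rows(a, m_start, TILE_M, M)
    store_tile_rows(c, tile, m_start, M)
```

The second tile covers rows 4..7, of which only row 4 is in bounds; the other three rows are never read from `a` and never written to `c`, so no host-side padding is needed.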
Remove the m_oob_store parameter from compile_fp8fp4_gemm / compile_ptpc_gemm
and pick the non-aligned-M output clip internally:
- tdm_tail when use_tdm_store and split_k == 1 (full tiles keep the fast TDM
store; the <=1 partial last M-tile falls back to buffer num_records)
- buffer otherwise (whole-output num_records clip; split_k>1 uses the
per-lane row < M atomic predicate)
A whole-output buffer clip regressed aligned production prefill by +15%..+82%,
while tdm_tail stays within ~2% of the no-clip path, so a static buffer default
was wrong. The choice is fully derivable from use_tdm_store/split_k, so cache_tag
drops m_oob_store too (no collision).
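The selection rule stated in this commit message can be written down as a tiny pure function. This is a sketch only: `select_m_oob_store` is a hypothetical name, and the real compile functions may structure the decision differently.

```python
def select_m_oob_store(use_tdm_store: bool, split_k: int) -> str:
    """Pick the non-aligned-M output clip from use_tdm_store and split_k.

    Hypothetical helper mirroring the rule in the commit message:
    tdm_tail when use_tdm_store and split_k == 1, buffer otherwise.
    """
    if split_k < 1:
        raise ValueError(f"split_k must be >= 1, got {split_k}")
    if use_tdm_store and split_k == 1:
        # Full tiles keep the TDM store; only a partial last M-tile is clipped.
        return "tdm_tail"
    # Whole-output clip; with split_k > 1 the atomic path uses the row < M predicate.
    return "buffer"
```

Because the result depends only on the two inputs, a cache key that already contains `use_tdm_store` and `split_k` distinguishes every case, which is consistent with dropping `m_oob_store` from `cache_tag`.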
Tests: the mxscale mpad test now parametrizes use_tdm_store to cover both auto
branches (tdm_tail / buffer); the atomic branch stays covered by the split-k mpad
test.
Pull request overview
This PR extends the gfx1250 FP8/FP4 GEMM implementation to support (1) a new PTPC scaling mode (per-token sa[M] and per-channel sb[N], applied in the epilogue) and (2) non-tile-aligned runtime M handling (avoiding host-side padding/copies). It also adds a new ROCDL WMMA wrapper and substantially expands correctness coverage for PTPC and M-tail behavior.
Changes:
- Add `scale_mode="ptpc"` support to the unified gfx1250 GEMM kernel, with fp32 `sa`/`sb` loads and epilogue scaling (including split-K atomic path).
- Add non-tile-aligned M (“M-OOB”) support via TDM descriptor bounds and output clipping (TDM store for full tiles, buffer/atomic for tails).
- Expand tests/benchmark CLI to cover PTPC (FP8 + A8W4), split-K, and a wide range of M-tail cases; add a no-scale WMMA wrapper op.
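As a reference for the PTPC scaling summarized above, the following pure-Python sketch applies per-token `sa[m]` and per-channel `sb[n]` to each split-K partial sum before accumulating. It is an illustration under assumptions, not the kernel: `ptpc_gemm` and `matmul_k_range` are hypothetical names, plain floats stand in for FP8/FP4 inputs, and the `+=` stands in for the atomic add.

```python
def matmul_k_range(a, b, k0, k1):
    """Unscaled partial product of a (M x K) and b (K x N) over k in [k0, k1)."""
    m, n = len(a), len(b[0])
    return [[sum(a[i][k] * b[k][j] for k in range(k0, k1)) for j in range(n)]
            for i in range(m)]


def ptpc_gemm(a, b, sa, sb, split_k=1):
    """C[i][j] = sa[i] * sb[j] * sum_k a[i][k] * b[k][j], computed in K chunks."""
    m, k, n = len(a), len(b), len(b[0])
    c = [[0.0] * n for _ in range(m)]
    chunk = (k + split_k - 1) // split_k  # ceil-divide K across the splits
    for s in range(split_k):
        k0, k1 = s * chunk, min((s + 1) * chunk, k)
        partial = matmul_k_range(a, b, k0, k1)
        for i in range(m):
            for j in range(n):
                # Scale each chunk, then accumulate (models per-chunk scale + atomic add).
                c[i][j] += sa[i] * sb[j] * partial[i][j]
    return c


a = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]
b = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.0]]
sa = [0.5, 2.0]   # one scale per row (token) of A
sb = [1.0, 0.25]  # one scale per column (channel) of B

full = ptpc_gemm(a, b, sa, sb, split_k=1)
split = ptpc_gemm(a, b, sa, sb, split_k=2)
```

Because the scales are constant along K, scaling every chunk and summing gives the same result as scaling the full sum once, up to floating-point rounding. The values here are chosen so that both paths agree exactly.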
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| `kernels/gemm_fp8fp4_gfx1250.py` | Implements PTPC scaling mode and M-OOB handling in the kernel (TDM descriptors, epilogue scaling, store/atomic clipping). |
| `python/flydsl/expr/rocdl/tdm_ops.py` | Extends `make_tensor_descriptor_2d` with `oob_outer_bound` to support runtime outer-dim OOB-safe TDM loads. |
| `python/flydsl/expr/rocdl/__init__.py` | Adds ROCDL builder wrapper for `wmma_f32_16x16x128_fp8_fp8`. |
| `python/flydsl/expr/rocdl.py` | Adds the same ROCDL wrapper and exports it via `__all__`. |
| `tests/kernels/test_gemm_fp8fp4_gfx1250.py` | Adds PTPC correctness tests, M-tail test matrix, and benchmark/CLI support for PTPC scale mode. |
Motivation
Add per-token per-channel (PTPC) scaling to the gfx1250 GEMM kernel: per-token `sa[M]` and per-channel `sb[N]` scales, constant along K, stored as fp32 and applied once in the epilogue instead of per K-block. Also add non-tile-aligned M and strided A/C so the host can pass an unpadded runtime M and arbitrary A/C leading-dim strides directly, eliminating per-call padding allocation and memcpy.
Technical Details
- `sa*sb` is applied in fp32 in the epilogue (split-K supported via per-chunk scale + atomic add).
- `make_tensor_descriptor_2d` gains an `oob_outer_bound` parameter.
- `M` is no longer a compile-time arg; `launch_fn` now takes runtime `lda`/`ldc` leading-dim strides (dense callers pass `lda == K`, `ldc == N`).
- `make_tensor_descriptor_2d` accepts a runtime i32/index outer stride. Aligned/dense callers are byte-identical to before.
Test Plan
`pytest tests/kernels/test_gemm_fp8fp4_gfx1250.py -k 'ptpc or mpad or strided'`, plus ISA inspection of the PTPC kernels.
Test Result
All PTPC, M-pad, and strided A/C tests pass. ISA confirms scale TDM removal and epilogue prefetch, with lower VGPR count and 0 spills.
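For readers unfamiliar with leading-dim strides, the `lda`/`ldc` arguments mentioned under Technical Details correspond to ordinary row-major addressing, where element `(r, c)` lives at offset `r * ld + c`. The sketch below assumes that convention and uses a hypothetical `gather_rows` helper on flat Python lists; it does not call the real `launch_fn`.

```python
def gather_rows(flat, rows, cols, ld):
    """Read a rows x cols row-major matrix from a flat buffer with leading dim ld."""
    if ld < cols:
        raise ValueError("leading dimension must be at least the row width")
    return [[flat[r * ld + c] for c in range(cols)] for r in range(rows)]


M, K = 3, 4

# Dense case: rows are packed back to back, so lda == K.
dense = list(range(M * K))

# Strided case: the same rows embedded in a wider buffer, so lda > K.
padded_ld = 6
padded = []
for r in range(M):
    padded.extend(dense[r * K:(r + 1) * K] + [-1] * (padded_ld - K))

from_dense = gather_rows(dense, M, K, K)
from_padded = gather_rows(padded, M, K, padded_ld)
```

Both calls yield the same logical matrix, which is why a caller holding a slice of a larger buffer can pass its stride directly instead of copying into a packed allocation.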
Submission Checklist