[attention] refactor hip kl #2778
Pull request overview
This PR refactors several attention kernel entrypoints to use AITER’s lightweight tensor/stream abstractions (instead of direct Torch/ATen types), and extends FP8 dtype enumeration support used across the C++/Python boundary.
Changes:
- Switched attention kernel APIs from `torch::Tensor` / `at::hip::*` to `aiter_tensor_t`, `HipDeviceGuard`, and `aiter::getCurrentHIPStream()`.
- Added explicit FP8 E4M3FN / E4M3FNUZ dtype variants to the shared dtype enum and Python dtype exports.
- Adjusted MLA metadata sizing logic and enabled `develop=True` on several `compile_ops` decorators.
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| `csrc/kernels/attention_v1.cu` | Moves the v1 paged-attention launcher to `aiter_tensor_t` + AITER stream/guard and updates checks. |
| `csrc/kernels/attention_ragged.cu` | Moves the ragged paged-attention launcher to `aiter_tensor_t` + AITER stream/guard and updates checks. |
| `csrc/kernels/attention.cu` | Refactors the core paged-attention launcher to `aiter_tensor_t` and AITER stream handling. |
| `csrc/include/attention_v1.h` | Updates the v1 public header API to `aiter_tensor_t` and pointer-based optionals. |
| `csrc/include/attention_ragged.h` | Updates the ragged public header API to `aiter_tensor_t` and pointer-based optionals. |
| `csrc/include/attention.h` | Updates the core attention header API to `aiter_tensor_t` and pointer-based optionals. |
| `csrc/include/aiter_enum.h` | Extends the `AiterDtype` enum with explicit FP8 E4M3FN/E4M3FNUZ variants. |
| `aiter/utility/dtypes.py` | Exposes the new FP8 dtype aliases in Python. |
| `aiter/ops/attention.py` | Sets `develop=True` for several ops and modifies MLA metadata sizing logic. |
```diff
 int* kv_last_page_lens_ptr =
-    BLOCK_SIZE > 1 ? kv_last_page_lens.value().data_ptr<int>() : nullptr;
+    BLOCK_SIZE > 1 ? reinterpret_cast<int*>(kv_last_page_lens->data_ptr()) : nullptr;
```
`kv_last_page_lens` is now a raw pointer, but it is dereferenced unconditionally when `BLOCK_SIZE > 1`. If a caller passes `nullptr` (which the new signature allows), this will segfault. Either make `kv_last_page_lens` a reference (required) for block sizes > 1, or add an `AITER_CHECK(kv_last_page_lens != nullptr, ...)` before dereferencing and keep passing `nullptr` only for `BLOCK_SIZE == 1`; a sketch of the latter follows.
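A minimal sketch of the check-based option, assuming `AITER_CHECK` accepts a condition plus message arguments the way this PR uses it elsewhere:

```cpp
// Sketch only: guard the dereference so a nullptr from the new
// pointer-based signature fails with a clear error, not a segfault.
int* kv_last_page_lens_ptr = nullptr;
if (BLOCK_SIZE > 1)
{
    AITER_CHECK(kv_last_page_lens != nullptr,
                "kv_last_page_lens must be provided when BLOCK_SIZE > 1");
    kv_last_page_lens_ptr = reinterpret_cast<int*>(kv_last_page_lens->data_ptr());
}
```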
```diff
 max_qo_tiles_per_batch = (
-    int(math.ceil(effective_seqlen_qo * num_head_qo / 128))
+    int(math.ceil(max_seqlen_qo * num_head_qo / 128))
     if num_head_qo == 16
     or (
         get_gfx() == "gfx942"
         and num_head_qo == 128
         and kv_dtype == dtypes.fp8
         and q_dtype == dtypes.fp8
     )
     or (
         get_gfx() == "gfx950"
         and (num_head_qo * effective_seqlen_qo) % 128 == 0
         and kv_dtype == dtypes.bf16
         and q_dtype == dtypes.bf16
     )
     or (
         get_gfx() == "gfx950"
         and num_head_qo == 64
         and q_dtype == dtypes.fp8
         and kv_dtype == dtypes.fp8
         and effective_seqlen_qo == 1
         and max_seqlen_qo == 1
     )
     or use_qseqlen_fold
-    else int(math.ceil(effective_seqlen_qo * num_head_qo / 16))
+    else int(math.ceil(max_seqlen_qo * num_head_qo / 16))
 )
 batch_size = batch_size * max_seqlen_qo if is_sparse else batch_size
 tile_cnt = batch_size * max_qo_tiles_per_batch
```
For `is_sparse=True`, `batch_size` is expanded by `max_seqlen_qo`, but `max_qo_tiles_per_batch` is also computed using `max_seqlen_qo`. This effectively double-counts the sequence length (tile count grows ~O(`max_seqlen_qo`^2)) and contradicts the sparse-mode assumption that each expanded batch has 1 Q token. Consider restoring an `effective_seqlen_qo = 1 if is_sparse else max_seqlen_qo` and using it consistently in the tile math, as in the sketch below.
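A hedged sketch of that suggestion, with the per-GPU dispatch conditions elided and the helper name (`mla_tile_count`) invented for illustration:

```python
import math

def mla_tile_count(batch_size, max_seqlen_qo, num_head_qo, is_sparse, tile=128):
    # In sparse mode each expanded batch holds a single Q token, so the
    # sequence length must appear in the batch expansion *or* the per-batch
    # tile count, never in both.
    effective_seqlen_qo = 1 if is_sparse else max_seqlen_qo
    max_qo_tiles_per_batch = int(math.ceil(effective_seqlen_qo * num_head_qo / tile))
    batch_size = batch_size * max_seqlen_qo if is_sparse else batch_size
    return batch_size * max_qo_tiles_per_batch

# Example: batch=4, max_seqlen_qo=256, 16 heads, sparse.
# With the fix: 4 * 256 * ceil(1 * 16 / 128) = 1024 tiles (linear in seqlen).
# Using max_seqlen_qo in both places instead gives
# 1024 * ceil(256 * 16 / 128) = 32768 tiles, i.e. quadratic growth.
print(mla_tile_count(4, 256, 16, is_sparse=True))  # -> 1024
```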
```diff
-@compile_ops(MD_NAME, ffi_type="ctypes")
+@compile_ops(MD_NAME, ffi_type="ctypes", develop=True)
```
`compile_ops(..., ffi_type="ctypes")` currently ignores the `develop` flag (it only affects the pybind wrapper). Setting `develop=True` here is misleading; either remove it or extend `compile_ops` to apply the intended behavior for ctypes as well.
Suggested change:

```diff
-@compile_ops(MD_NAME, ffi_type="ctypes", develop=True)
+@compile_ops(MD_NAME, ffi_type="ctypes")
```
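If extending is preferred instead, a hypothetical sketch of threading the flag through both FFI paths; the `compile_ops` internals and the `build_module` helper here are assumptions for illustration, not AITER's actual implementation:

```python
import functools

def build_module(md_name, ffi_type, develop):
    """Stand-in for the real JIT build step (hypothetical)."""
    raise NotImplementedError

def compile_ops(md_name, ffi_type="pybind", develop=False):
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # Forward develop to the build step for *both* FFI flavors,
            # so the flag is honored under ctypes as well as pybind.
            mod = build_module(md_name, ffi_type=ffi_type, develop=develop)
            return getattr(mod, fn.__name__)(*args, **kwargs)
        return wrapper
    return decorator
```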
```diff
-@compile_ops(MD_NAME, ffi_type="ctypes")
+@compile_ops(MD_NAME, ffi_type="ctypes", develop=True)
```
`compile_ops(..., ffi_type="ctypes")` currently ignores the `develop` flag (it only affects the pybind wrapper). Setting `develop=True` here is misleading; either remove it or extend `compile_ops` to apply the intended behavior for ctypes as well.
Suggested change:

```diff
-@compile_ops(MD_NAME, ffi_type="ctypes", develop=True)
+@compile_ops(MD_NAME, ffi_type="ctypes")
```
```diff
 {
-    TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    AITER_CHECK(false, "Unsupported data type");
 }
```
This error path drops the actual dtype information (`AITER_CHECK(false, "Unsupported data type")`). Since `query.dtype()` is an `AiterDtype`, consider including `AiterDtype_to_str(query.dtype())` in the message (as done in csrc/kernels/attention.cu) to make debugging unsupported inputs easier; see the sketch below.
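A minimal sketch of that message, assuming `AITER_CHECK` accepts additional message arguments the way the snippet above uses it:

```cpp
// Sketch only: mirrors the pattern the reviewer cites from
// csrc/kernels/attention.cu; the exact AITER_CHECK signature may differ.
AITER_CHECK(false,
            "Unsupported data type: ",
            AiterDtype_to_str(query.dtype()));
```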
```diff
+@compile_ops(MD_NAME, ffi_type="ctypes", develop=True)
 def mla_prefill_asm_fwd(
```
`compile_ops(..., ffi_type="ctypes")` currently ignores the `develop` flag (it only affects the pybind wrapper). Setting `develop=True` here is misleading; either remove it or extend `compile_ops` to apply the intended behavior for ctypes as well.
```diff
 {
-    TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    AITER_CHECK(false, "Unsupported data type");
 }
```
This error path drops the actual dtype information (`AITER_CHECK(false, "Unsupported data type")`). Since `query.dtype()` is an `AiterDtype`, consider including `AiterDtype_to_str(query.dtype())` in the message (as done in csrc/kernels/attention.cu) to make debugging unsupported inputs easier.
```diff
 {
-    TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    AITER_CHECK(false, "Unsupported data type");
 }
```
This error path drops the actual dtype information (`AITER_CHECK(false, "Unsupported data type")`). Since `query.dtype()` is an `AiterDtype`, consider including `AiterDtype_to_str(query.dtype())` in the message (as done in csrc/kernels/attention.cu) to make debugging unsupported inputs easier.
```diff
 {
-    TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    AITER_CHECK(false, "Unsupported data type");
 }
```
This error path drops the actual dtype information (`AITER_CHECK(false, "Unsupported data type")`). Since `query.dtype()` is an `AiterDtype`, consider including `AiterDtype_to_str(query.dtype())` in the message (as done in csrc/kernels/attention.cu) to make debugging unsupported inputs easier.
Motivation
Technical Details
Test Plan
Test Result
Submission Checklist