[attention] refactor hip kl #2778

Open
amd-ruitang3 wants to merge 2 commits into main from attention_hip_kl_refactor

Conversation

@amd-ruitang3
Contributor

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

@amd-ruitang3 amd-ruitang3 requested review from a team and Copilot April 17, 2026 11:15
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label         | Tests
ci:triton-355 | Run Triton tests on MI355 in addition to MI325
ci:sglang     | SGLang integration tests
ci:atom       | ATOM benchmark (DeepSeek-R1 + GPT-OSS)
ci:vllm       | vLLM benchmark
ci:all        | All of the above

Add labels via the sidebar or gh pr edit 2778 --add-label <label>

Contributor

Copilot AI left a comment


Pull request overview

This PR refactors several attention kernel entrypoints to use AITER’s lightweight tensor/stream abstractions (instead of direct Torch/ATen types), and extends FP8 dtype enumeration support used across the C++/Python boundary.

Changes:

  • Switched attention kernel APIs from torch::Tensor / at::hip::* to aiter_tensor_t, HipDeviceGuard, and aiter::getCurrentHIPStream().
  • Added explicit FP8 E4M3FN / E4M3FNUZ dtype variants to the shared dtype enum and Python dtype exports.
  • Adjusted MLA metadata sizing logic and enabled develop=True on several compile_ops decorators.

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 9 comments.

Show a summary per file
File Description
csrc/kernels/attention_v1.cu Moves v1 paged-attention launcher to aiter_tensor_t + AITER stream/guard and updates checks.
csrc/kernels/attention_ragged.cu Moves ragged paged-attention launcher to aiter_tensor_t + AITER stream/guard and updates checks.
csrc/kernels/attention.cu Refactors core paged attention launcher to aiter_tensor_t and AITER stream handling.
csrc/include/attention_v1.h Updates v1 public header API to aiter_tensor_t and pointer-based optionals.
csrc/include/attention_ragged.h Updates ragged public header API to aiter_tensor_t and pointer-based optionals.
csrc/include/attention.h Updates core attention header API to aiter_tensor_t and pointer-based optionals.
csrc/include/aiter_enum.h Extends AiterDtype enum with explicit FP8 E4M3FN/E4M3FNUZ variants.
aiter/utility/dtypes.py Exposes new FP8 dtype aliases in Python.
aiter/ops/attention.py Sets develop=True for several ops and modifies MLA metadata sizing logic.


Comment on lines 316 to +317
  int* kv_last_page_lens_ptr =
-     BLOCK_SIZE > 1 ? kv_last_page_lens.value().data_ptr<int>() : nullptr;
+     BLOCK_SIZE > 1 ? reinterpret_cast<int*>(kv_last_page_lens->data_ptr()) : nullptr;

Copilot AI Apr 17, 2026


kv_last_page_lens is now a raw pointer, but it is dereferenced unconditionally when BLOCK_SIZE > 1. If a caller passes nullptr (which the new signature allows), this will segfault. Either make kv_last_page_lens a reference (required) for block sizes >1, or add an AITER_CHECK(kv_last_page_lens != nullptr, ...) before dereferencing and keep passing nullptr only for BLOCK_SIZE == 1.
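One way to encode the suggested precondition, sketched here as a hypothetical Python-side wrapper check rather than the actual C++ AITER_CHECK (the function name check_kv_last_page_lens is invented for illustration; only the BLOCK_SIZE > 1 invariant comes from the diff):

```python
def check_kv_last_page_lens(kv_last_page_lens, block_size):
    """Mirror of the proposed AITER_CHECK: reject a missing tensor whenever
    block_size > 1, since the C++ launcher would dereference it on that path."""
    if block_size > 1 and kv_last_page_lens is None:
        raise ValueError(
            "kv_last_page_lens must be provided when block_size > 1; "
            "the kernel dereferences it unconditionally on that path"
        )
    return kv_last_page_lens
```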

Comment thread aiter/ops/attention.py
Comment on lines 936 to 956
  max_qo_tiles_per_batch = (
-     int(math.ceil(effective_seqlen_qo * num_head_qo / 128))
+     int(math.ceil(max_seqlen_qo * num_head_qo / 128))
      if num_head_qo == 16
      or (
          get_gfx() == "gfx942"
          and num_head_qo == 128
          and kv_dtype == dtypes.fp8
          and q_dtype == dtypes.fp8
      )
      or (
          get_gfx() == "gfx950"
          and (num_head_qo * effective_seqlen_qo) % 128 == 0
          and kv_dtype == dtypes.bf16
          and q_dtype == dtypes.bf16
      )
      or (
          get_gfx() == "gfx950"
          and num_head_qo == 64
          and q_dtype == dtypes.fp8
          and kv_dtype == dtypes.fp8
          and effective_seqlen_qo == 1
          and max_seqlen_qo == 1
      )
      or use_qseqlen_fold
-     else int(math.ceil(effective_seqlen_qo * num_head_qo / 16))
+     else int(math.ceil(max_seqlen_qo * num_head_qo / 16))
  )
  batch_size = batch_size * max_seqlen_qo if is_sparse else batch_size
  tile_cnt = batch_size * max_qo_tiles_per_batch

Copilot AI Apr 17, 2026


For is_sparse=True, batch_size is expanded by max_seqlen_qo, but max_qo_tiles_per_batch is also computed using max_seqlen_qo. This effectively double-counts the sequence length (tile count grows ~O(max_seqlen_qo^2)) and contradicts the sparse-mode assumption that each expanded batch has 1 Q token. Consider restoring an effective_seqlen_qo = 1 if is_sparse else max_seqlen_qo and using it consistently in the tile math.
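To make the double-counting concrete, a small standalone sketch of the tile math (only the 16-wide tile divisor and the sparse batch expansion come from the diff; the helper name tile_cnt and the concrete numbers are invented):

```python
import math

def tile_cnt(batch_size, max_seqlen_qo, num_head_qo, is_sparse, effective_seqlen_qo):
    # Per-batch tile count and sparse batch expansion, as in the diff above.
    tiles_per_batch = math.ceil(effective_seqlen_qo * num_head_qo / 16)
    expanded_batch = batch_size * max_seqlen_qo if is_sparse else batch_size
    return expanded_batch * tiles_per_batch

# Patched code uses max_seqlen_qo in both places, so seqlen enters twice:
patched = tile_cnt(2, 8, 32, True, effective_seqlen_qo=8)  # 2*8 * ceil(8*32/16) = 256
# Reviewer's fix: effective_seqlen_qo == 1 in sparse mode, seqlen enters once:
fixed = tile_cnt(2, 8, 32, True, effective_seqlen_qo=1)    # 2*8 * ceil(1*32/16) = 32
```

Doubling max_seqlen_qo quadruples the patched tile count, which is the ~O(max_seqlen_qo^2) growth the comment describes.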

Comment thread aiter/ops/attention.py


- @compile_ops(MD_NAME, ffi_type="ctypes")
+ @compile_ops(MD_NAME, ffi_type="ctypes", develop=True)

Copilot AI Apr 17, 2026


compile_ops(..., ffi_type="ctypes") currently ignores the develop flag (it only affects the pybind wrapper). Setting develop=True here is misleading; either remove it or extend compile_ops to apply the intended behavior for ctypes as well.

Suggested change
- @compile_ops(MD_NAME, ffi_type="ctypes", develop=True)
+ @compile_ops(MD_NAME, ffi_type="ctypes")
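The alternative the reviewer mentions, extending compile_ops to react to the flag, could look roughly like this. This is an illustrative stand-in, not AITER's real decorator; it simply surfaces the ignored develop flag instead of silently accepting it:

```python
import functools
import warnings

def compile_ops(md_name, ffi_type="pybind", develop=False):
    # Hypothetical behavior: warn when develop=True would have no effect.
    if ffi_type == "ctypes" and develop:
        warnings.warn(
            f"{md_name}: develop=True has no effect with ffi_type='ctypes'",
            stacklevel=2,
        )
    def deco(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            return fn(*args, **kwargs)
        return wrapper
    return deco
```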

Comment thread aiter/ops/attention.py


- @compile_ops(MD_NAME, ffi_type="ctypes")
+ @compile_ops(MD_NAME, ffi_type="ctypes", develop=True)

Copilot AI Apr 17, 2026


compile_ops(..., ffi_type="ctypes") currently ignores the develop flag (it only affects the pybind wrapper). Setting develop=True here is misleading; either remove it or extend compile_ops to apply the intended behavior for ctypes as well.

Suggested change
- @compile_ops(MD_NAME, ffi_type="ctypes", develop=True)
+ @compile_ops(MD_NAME, ffi_type="ctypes")

Comment on lines 542 to 544
  {
-     TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+     AITER_CHECK(false, "Unsupported data type");
  }

Copilot AI Apr 17, 2026


This error path drops the actual dtype information (AITER_CHECK(false, "Unsupported data type")). Since query.dtype() is an AiterDtype, consider including AiterDtype_to_str(query.dtype()) in the message (as done in csrc/kernels/attention.cu) to make debugging unsupported inputs easier.
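The richer error message the comment asks for, mirrored here as a Python sketch. The AiterDtype members and the aiter_dtype_to_str helper are assumed stand-ins for the C++ enum in csrc/include/aiter_enum.h and its string conversion, with the FP8 variant names taken from the PR summary:

```python
from enum import Enum

class AiterDtype(Enum):
    FP16 = 0
    BF16 = 1
    FP8_E4M3FN = 2
    FP8_E4M3FNUZ = 3

def aiter_dtype_to_str(dtype: AiterDtype) -> str:
    return dtype.name.lower()

def raise_unsupported(dtype: AiterDtype):
    # Include the offending dtype in the message, as attention.cu does,
    # so unsupported inputs are easy to diagnose.
    raise TypeError(f"Unsupported data type: {aiter_dtype_to_str(dtype)}")
```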

Comment thread aiter/ops/attention.py
Comment on lines +691 to 692
+ @compile_ops(MD_NAME, ffi_type="ctypes", develop=True)
  def mla_prefill_asm_fwd(

Copilot AI Apr 17, 2026


compile_ops(..., ffi_type="ctypes") currently ignores the develop flag (it only affects the pybind wrapper). Setting develop=True here is misleading; either remove it or extend compile_ops to apply the intended behavior for ctypes as well.

Comment on lines 480 to 482
  {
-     TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+     AITER_CHECK(false, "Unsupported data type");
  }

Copilot AI Apr 17, 2026


This error path drops the actual dtype information (AITER_CHECK(false, "Unsupported data type")). Since query.dtype() is an AiterDtype, consider including AiterDtype_to_str(query.dtype()) in the message (as done in csrc/kernels/attention.cu) to make debugging unsupported inputs easier.

Comment on lines 496 to 498
  {
-     TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+     AITER_CHECK(false, "Unsupported data type");
  }

Copilot AI Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This error path drops the actual dtype information (AITER_CHECK(false, "Unsupported data type")). Since query.dtype() is an AiterDtype, consider including AiterDtype_to_str(query.dtype()) in the message (as done in csrc/kernels/attention.cu) to make debugging unsupported inputs easier.

Comment on lines 526 to 528
  {
-     TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+     AITER_CHECK(false, "Unsupported data type");
  }

Copilot AI Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This error path drops the actual dtype information (AITER_CHECK(false, "Unsupported data type")). Since query.dtype() is an AiterDtype, consider including AiterDtype_to_str(query.dtype()) in the message (as done in csrc/kernels/attention.cu) to make debugging unsupported inputs easier.

