[attention] refactor hip kl #2778 (base: main, changes from 1 commit)
```diff
@@ -81,7 +81,7 @@ def gen_pa_fwd_asm(
     return torch.empty_like(Q)


-@compile_ops("module_attention", gen_fake=gen_pa_fwd_native_fake)
+@compile_ops("module_attention", gen_fake=gen_pa_fwd_native_fake, develop=True)
 def pa_fwd_naive(
     # [num_seqs, num_heads, head_size]
     query: torch.Tensor,
```
```diff
@@ -652,7 +652,7 @@ def paged_attention_ragged(
 MD_NAME = "module_mla_asm"


-@compile_ops(MD_NAME, ffi_type="ctypes")
+@compile_ops(MD_NAME, ffi_type="ctypes", develop=True)
 def mla_decode_stage1_asm_fwd(
     # [num_seqs, num_heads, head_size]
     Q: torch.Tensor,
```
```diff
@@ -688,7 +688,7 @@ def mla_decode_stage1_asm_fwd(
 ) -> None: ...


-@compile_ops(MD_NAME, ffi_type="ctypes")
+@compile_ops(MD_NAME, ffi_type="ctypes", develop=True)
 def mla_prefill_asm_fwd(
     # [num_seqs, num_heads, head_size]
     Q: torch.Tensor,
```
```diff
@@ -873,7 +873,7 @@ def get_ps_metadata_v1(
 ) -> None: ...


-@compile_ops(MD_NAME, ffi_type="ctypes")
+@compile_ops(MD_NAME, ffi_type="ctypes", develop=True)
```
|
||||||
| @compile_ops(MD_NAME, ffi_type="ctypes", develop=True) | |
| @compile_ops(MD_NAME, ffi_type="ctypes") |
Copilot
AI
Apr 17, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For is_sparse=True, batch_size is expanded by max_seqlen_qo, but max_qo_tiles_per_batch is also computed using max_seqlen_qo. This effectively double-counts the sequence length (tile count grows ~O(max_seqlen_qo^2)) and contradicts the sparse-mode assumption that each expanded batch has 1 Q token. Consider restoring an effective_seqlen_qo = 1 if is_sparse else max_seqlen_qo and using it consistently in the tile math.
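A minimal sketch of the fix this comment suggests, in Python. The metadata code under review is not part of the hunks shown here, so `tile_q` and the expansion step are illustrative assumptions; only `effective_seqlen_qo`, `batch_size`, `max_seqlen_qo`, and `max_qo_tiles_per_batch` come from the comment itself:

```python
def count_qo_tiles(batch_size: int, max_seqlen_qo: int, tile_q: int,
                   is_sparse: bool) -> int:
    # Sparse mode expands the batch so each entry holds exactly 1 Q token;
    # the tile math must then use a sequence length of 1, otherwise
    # max_seqlen_qo is counted twice and the tile count grows
    # ~O(max_seqlen_qo^2).
    effective_seqlen_qo = 1 if is_sparse else max_seqlen_qo
    if is_sparse:
        batch_size *= max_seqlen_qo  # the expansion happens exactly once
    max_qo_tiles_per_batch = (effective_seqlen_qo + tile_q - 1) // tile_q
    return batch_size * max_qo_tiles_per_batch
```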
```diff
@@ -1,14 +1,15 @@
 #pragma once
 // SPDX-License-Identifier: MIT
 // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
-#include <torch/extension.h>
+#include "aiter_tensor.h"
 #include <string>

 void paged_attention(
-    torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits,
-    torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache,
-    torch::Tensor &value_cache, int64_t num_kv_heads, double scale,
-    torch::Tensor &block_tables, torch::Tensor &context_lens,
+    const aiter_tensor_t &out, const aiter_tensor_t &exp_sums, const aiter_tensor_t &max_logits,
+    const aiter_tensor_t &tmp_out, const aiter_tensor_t &query, const aiter_tensor_t &key_cache,
+    const aiter_tensor_t &value_cache, int64_t num_kv_heads, double scale,
+    const aiter_tensor_t &block_tables, const aiter_tensor_t &context_lens,
     int64_t block_size, int64_t max_context_len,
-    const std::optional<torch::Tensor> &alibi_slopes,
+    const aiter_tensor_t *alibi_slopes,
     const std::string &kv_cache_dtype, double k_scale, double v_scale,
-    const std::optional<torch::Tensor> &fp8_out_scale, int64_t partition_size);
+    const aiter_tensor_t *fp8_out_scale, int64_t partition_size);
```
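This header change swaps the pybind-only `torch::Tensor &` references for an `aiter_tensor_t` handle and turns `std::optional<torch::Tensor>` parameters into nullable pointers, which is what a plain-C FFI can express. A minimal Python sketch of how such a nullable argument could be passed through ctypes; the `AiterTensor` field layout and the shared-library name are assumptions for illustration (the real definition lives in `aiter_tensor.h`):

```python
import ctypes

class AiterTensor(ctypes.Structure):
    # Hypothetical layout; the real aiter_tensor_t is defined in aiter_tensor.h.
    _fields_ = [
        ("data_ptr", ctypes.c_void_p),
        ("ndim", ctypes.c_int64),
        ("sizes", ctypes.c_int64 * 8),
        ("dtype", ctypes.c_int32),
    ]

def optional_arg(t):
    """A `const aiter_tensor_t *` parameter: pass a pointer when the
    tensor is present, NULL (None in ctypes) when it is absent."""
    return ctypes.byref(t) if t is not None else None

# Hypothetical usage against a built module (library name is an assumption):
# lib = ctypes.CDLL("libmodule_attention.so")
# lib.paged_attention(..., optional_arg(alibi_slopes), ...)
```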
```diff
@@ -1,24 +1,22 @@
 #pragma once
 // SPDX-License-Identifier: MIT
 // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
-#include <torch/extension.h>
+#include "aiter_tensor.h"
 #include <optional>
 #include <string>

 void paged_attention_ragged(
-    torch::Tensor &out, // [num_seqs, num_heads, head_size]
-    torch::Tensor &workspace_buffer,
-    torch::Tensor &query, // [num_seqs, num_heads, head_size]
-    torch::Tensor
-        &key_cache, // [num_blocks, num_heads, block_size, head_size] or
-                    // [num_blocks, block_size, num_heads, head_size]
-    torch::Tensor
-        &value_cache, // [num_blocks, num_heads, block_size, head_size] or
-                      // [num_blocks, block_size, num_heads, head_size]
+    const aiter_tensor_t &out, // [num_seqs, num_heads, head_size]
+    const aiter_tensor_t &workspace_buffer,
+    const aiter_tensor_t &query, // [num_seqs, num_heads, head_size]
+    const aiter_tensor_t &key_cache,
+    const aiter_tensor_t &value_cache,
     double scale,
-    torch::Tensor &kv_indptr, // [num_seqs + 1]
-    torch::Tensor &kv_page_indices, // [max_num_blocks]
-    std::optional<torch::Tensor> &kv_last_page_lens, // [num_seqs]
+    const aiter_tensor_t &kv_indptr, // [num_seqs + 1]
+    const aiter_tensor_t &kv_page_indices, // [max_num_blocks]
+    const aiter_tensor_t *kv_last_page_lens, // [num_seqs]
     int64_t block_size, int64_t max_num_partitions,
-    const std::optional<torch::Tensor> &alibi_slopes,
+    const aiter_tensor_t *alibi_slopes,
     const std::string &kv_cache_dtype, const std::string &kv_cache_layout,
-    float logits_soft_cap, torch::Tensor &k_scale, torch::Tensor &v_scale,
-    const std::optional<torch::Tensor> &fp8_out_scale, int64_t partition_size);
+    float logits_soft_cap, const aiter_tensor_t &k_scale, const aiter_tensor_t &v_scale,
+    const aiter_tensor_t *fp8_out_scale, int64_t partition_size);
```
```diff
@@ -1,23 +1,22 @@
 #pragma once
 // SPDX-License-Identifier: MIT
 // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
-#include <torch/extension.h>
+#include "aiter_tensor.h"
 #include <optional>
 #include <string>

 void paged_attention_v1(
-    torch::Tensor &out, // [num_seqs, num_heads, head_size]
-    torch::Tensor &workspace_buffer,
-    torch::Tensor &query, // [num_seqs, num_heads, head_size]
-    torch::Tensor
-        &key_cache, // [num_blocks, num_heads, block_size, head_size] or
-                    // [num_blocks, block_size, num_heads, head_size]
-    torch::Tensor
-        &value_cache, // [num_blocks, num_heads, block_size, head_size] or
-                      // [num_blocks, block_size, num_heads, head_size]
+    const aiter_tensor_t &out, // [num_seqs, num_heads, head_size]
+    const aiter_tensor_t &workspace_buffer,
+    const aiter_tensor_t &query, // [num_seqs, num_heads, head_size]
+    const aiter_tensor_t &key_cache,
+    const aiter_tensor_t &value_cache,
     double scale,
-    torch::Tensor &block_tables, // [num_seqs, max_num_blocks_per_seq]
-    const std::optional<torch::Tensor>& cu_query_lens, // [num_seqs+1]
-    torch::Tensor &context_lens, // [num_seqs]
-    int64_t max_context_len, const std::optional<torch::Tensor> &alibi_slopes,
+    const aiter_tensor_t &block_tables, // [num_seqs, max_num_blocks_per_seq]
+    const aiter_tensor_t *cu_query_lens, // [num_seqs+1]
+    const aiter_tensor_t &context_lens, // [num_seqs]
+    int64_t max_context_len,
+    const aiter_tensor_t *alibi_slopes,
     const std::string &kv_cache_dtype, const std::string &kv_cache_layout,
-    float logits_soft_cap, torch::Tensor &k_scale, torch::Tensor &v_scale,
-    const std::optional<torch::Tensor> &fp8_out_scale, int64_t partition_size);
+    float logits_soft_cap, const aiter_tensor_t &k_scale, const aiter_tensor_t &v_scale,
+    const aiter_tensor_t *fp8_out_scale, int64_t partition_size);
```
> `compile_ops(..., ffi_type="ctypes")` currently ignores the `develop` flag (it only affects the pybind wrapper). Setting `develop=True` here is misleading; either remove it or extend `compile_ops` to apply the intended behavior for ctypes as well.
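A minimal sketch of a guard for the mismatch this comment describes, assuming `compile_ops` is a plain decorator factory; its real internals are not shown in this diff, so the body below is illustrative only:

```python
import warnings

def compile_ops(md_name, ffi_type="pybind", gen_fake=None, develop=False):
    # Illustrative skeleton only: the real compile_ops JIT-builds and binds
    # the op. Here we just surface the develop/ctypes mismatch the review
    # comment points out, instead of silently dropping the flag.
    def decorator(fn):
        if develop and ffi_type == "ctypes":
            warnings.warn(
                f"{md_name}: develop=True currently has no effect with "
                "ffi_type='ctypes'; drop the flag or extend the ctypes "
                "build path to honor it.")
        return fn
    return decorator
```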