22 changes: 7 additions & 15 deletions aiter/ops/attention.py
@@ -81,7 +81,7 @@ def gen_pa_fwd_asm(
return torch.empty_like(Q)


@compile_ops("module_attention", gen_fake=gen_pa_fwd_native_fake)
@compile_ops("module_attention", gen_fake=gen_pa_fwd_native_fake, develop=True)
def pa_fwd_naive(
# [num_seqs, num_heads, head_size]
query: torch.Tensor,
@@ -652,7 +652,7 @@ def paged_attention_ragged(
MD_NAME = "module_mla_asm"


@compile_ops(MD_NAME, ffi_type="ctypes")
@compile_ops(MD_NAME, ffi_type="ctypes", develop=True)
Copilot AI Apr 17, 2026

compile_ops(..., ffi_type="ctypes") currently ignores the develop flag (it only affects the pybind wrapper). Setting develop=True here is misleading; either remove it or extend compile_ops to apply the intended behavior for ctypes as well.

Suggested change
@compile_ops(MD_NAME, ffi_type="ctypes", develop=True)
@compile_ops(MD_NAME, ffi_type="ctypes")
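A minimal sketch of the second option the comment mentions, i.e. making the decorator honor develop for the ctypes backend as well. The _load_or_build helper, its signature, and the force_rebuild flag are assumptions made for illustration, not the actual aiter build machinery:

import functools

def _load_or_build(md_name, ffi_type, force_rebuild):
    # Placeholder for aiter's real JIT build/load step; assumed to return an
    # object that exposes the compiled kernels as attributes.
    raise NotImplementedError

def compile_ops(md_name, ffi_type="pybind", gen_fake=None, develop=False):
    # Sketch only: gen_fake handling is omitted here.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # Force a rebuild in develop mode for both backends, so a stale
            # ctypes shared object is never silently reused.
            module = _load_or_build(md_name, ffi_type=ffi_type, force_rebuild=develop)
            return getattr(module, fn.__name__)(*args, **kwargs)
        return wrapper
    return decorator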

def mla_decode_stage1_asm_fwd(
# [num_seqs, num_heads, head_size]
Q: torch.Tensor,
@@ -688,7 +688,7 @@ def mla_decode_stage1_asm_fwd(
) -> None: ...


@compile_ops(MD_NAME, ffi_type="ctypes")
@compile_ops(MD_NAME, ffi_type="ctypes", develop=True)
def mla_prefill_asm_fwd(
Comment on lines +691 to 692

Copilot AI Apr 17, 2026


compile_ops(..., ffi_type="ctypes") currently ignores the develop flag (it only affects the pybind wrapper). Setting develop=True here is misleading; either remove it or extend compile_ops to apply the intended behavior for ctypes as well.

# [num_seqs, num_heads, head_size]
Q: torch.Tensor,
@@ -873,7 +873,7 @@ def get_ps_metadata_v1(
) -> None: ...


@compile_ops(MD_NAME, ffi_type="ctypes")
@compile_ops(MD_NAME, ffi_type="ctypes", develop=True)

Copilot AI Apr 17, 2026


compile_ops(..., ffi_type="ctypes") currently ignores the develop flag (it only affects the pybind wrapper). Setting develop=True here is misleading; either remove it or extend compile_ops to apply the intended behavior for ctypes as well.

Suggested change
@compile_ops(MD_NAME, ffi_type="ctypes", develop=True)
@compile_ops(MD_NAME, ffi_type="ctypes")

def mla_prefill_ps_asm_fwd(
Q: torch.Tensor,
K: torch.Tensor,
@@ -933,32 +933,24 @@ def get_mla_metadata_info_v1(
)
)

# In sparse mode, each expanded batch has 1 Q token
effective_seqlen_qo = 1 if is_sparse else max_seqlen_qo
max_qo_tiles_per_batch = (
int(math.ceil(effective_seqlen_qo * num_head_qo / 128))
int(math.ceil(max_seqlen_qo * num_head_qo / 128))
if num_head_qo == 16
or (
get_gfx() == "gfx942"
and num_head_qo == 128
and kv_dtype == dtypes.fp8
and q_dtype == dtypes.fp8
)
or (
get_gfx() == "gfx950"
and (num_head_qo * effective_seqlen_qo) % 128 == 0
and kv_dtype == dtypes.bf16
and q_dtype == dtypes.bf16
)
or (
get_gfx() == "gfx950"
and num_head_qo == 64
and q_dtype == dtypes.fp8
and kv_dtype == dtypes.fp8
and effective_seqlen_qo == 1
and max_seqlen_qo == 1
)
or use_qseqlen_fold
else int(math.ceil(effective_seqlen_qo * num_head_qo / 16))
else int(math.ceil(max_seqlen_qo * num_head_qo / 16))
)
batch_size = batch_size * max_seqlen_qo if is_sparse else batch_size
tile_cnt = batch_size * max_qo_tiles_per_batch
Comment on lines 936 to 956

Copilot AI Apr 17, 2026


For is_sparse=True, batch_size is expanded by max_seqlen_qo, but max_qo_tiles_per_batch is also computed using max_seqlen_qo. This effectively double-counts the sequence length (tile count grows ~O(max_seqlen_qo^2)) and contradicts the sparse-mode assumption that each expanded batch has 1 Q token. Consider restoring an effective_seqlen_qo = 1 if is_sparse else max_seqlen_qo and using it consistently in the tile math.
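A sketch of the consistent tile math the comment asks for, using a hypothetical helper name and collapsing the gfx/dtype branching into a single tile-width parameter:

import math

def mla_tile_count(batch_size, max_seqlen_qo, num_head_qo, is_sparse, tile=128):
    # In sparse mode each expanded batch carries exactly one Q token, so the
    # per-batch tile count must not scale with max_seqlen_qo a second time.
    effective_seqlen_qo = 1 if is_sparse else max_seqlen_qo
    max_qo_tiles_per_batch = int(math.ceil(effective_seqlen_qo * num_head_qo / tile))
    expanded_batch = batch_size * max_seqlen_qo if is_sparse else batch_size
    return expanded_batch * max_qo_tiles_per_batch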

2 changes: 2 additions & 0 deletions aiter/utility/dtypes.py
@@ -23,6 +23,8 @@ def get_dtype_fp8():
i4x2 = getattr(torch, "int4", _8bit_fallback)
fp4x2 = getattr(torch, "float4_e2m1fn_x2", _8bit_fallback)
fp8 = get_dtype_fp8()
fp8_e4m3fn = torch.float8_e4m3fn
fp8_e4m3fnuz = torch.float8_e4m3fnuz
fp8_e8m0 = getattr(torch, "float8_e8m0fnu", _8bit_fallback)
fp16 = torch.float16
bf16 = torch.bfloat16
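For illustration, a hedged sketch of the per-device resolution implied by the ambiguous fp8 alias; the arch-to-variant mapping below is an assumption, not necessarily what get_dtype_fp8() implements:

import torch

def resolve_fp8(gfx_arch: str) -> torch.dtype:
    # gfx94x parts are assumed to use the fnuz encoding, newer parts the OCP
    # e4m3fn variant; adjust the mapping to match the real device table.
    if gfx_arch.startswith("gfx94"):
        return torch.float8_e4m3fnuz
    return torch.float8_e4m3fn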
8 changes: 7 additions & 1 deletion csrc/include/aiter_enum.h
@@ -24,7 +24,9 @@ enum class QuantType : int
};
typedef enum
{
AITER_DTYPE_fp8,
AITER_DTYPE_fp8, // Ambiguous dtype, use the corresponding type per device.
AITER_DTYPE_fp8_e4m3fn,
AITER_DTYPE_fp8_e4m3fnuz,
AITER_DTYPE_fp8_e8m0,
AITER_DTYPE_fp16,
AITER_DTYPE_bf16,
@@ -45,6 +47,8 @@ static inline size_t AiterDtype_element_size(AiterDtype dtype)
switch(dtype)
{
case AITER_DTYPE_fp8:
case AITER_DTYPE_fp8_e4m3fn:
case AITER_DTYPE_fp8_e4m3fnuz:
case AITER_DTYPE_fp8_e8m0:
case AITER_DTYPE_i4x2:
case AITER_DTYPE_fp4x2:
@@ -67,6 +71,8 @@ static inline std::string AiterDtype_to_str(int dtype)
switch(dtype)
{
case AITER_DTYPE_fp8: return "fp8";
case AITER_DTYPE_fp8_e4m3fn: return "fp8_e4m3fn";
case AITER_DTYPE_fp8_e4m3fnuz: return "fp8_e4m3fnuz";
case AITER_DTYPE_fp8_e8m0: return "fp8_e8m0";
case AITER_DTYPE_fp16: return "fp16";
case AITER_DTYPE_bf16: return "bf16";
15 changes: 8 additions & 7 deletions csrc/include/attention.h
@@ -1,14 +1,15 @@
#pragma once
// SPDX-License-Identifier: MIT
// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <torch/extension.h>
#include "aiter_tensor.h"
#include <string>

void paged_attention(
torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits,
torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache,
torch::Tensor &value_cache, int64_t num_kv_heads, double scale,
torch::Tensor &block_tables, torch::Tensor &context_lens,
const aiter_tensor_t &out, const aiter_tensor_t &exp_sums, const aiter_tensor_t &max_logits,
const aiter_tensor_t &tmp_out, const aiter_tensor_t &query, const aiter_tensor_t &key_cache,
const aiter_tensor_t &value_cache, int64_t num_kv_heads, double scale,
const aiter_tensor_t &block_tables, const aiter_tensor_t &context_lens,
int64_t block_size, int64_t max_context_len,
const std::optional<torch::Tensor> &alibi_slopes,
const aiter_tensor_t *alibi_slopes,
const std::string &kv_cache_dtype, double k_scale, double v_scale,
const std::optional<torch::Tensor> &fp8_out_scale, int64_t partition_size);
const aiter_tensor_t *fp8_out_scale, int64_t partition_size);
30 changes: 14 additions & 16 deletions csrc/include/attention_ragged.h
@@ -1,24 +1,22 @@
#pragma once
// SPDX-License-Identifier: MIT
// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <torch/extension.h>
#include "aiter_tensor.h"
#include <optional>
#include <string>

void paged_attention_ragged(
torch::Tensor &out, // [num_seqs, num_heads, head_size]
torch::Tensor &workspace_buffer,
torch::Tensor &query, // [num_seqs, num_heads, head_size]
torch::Tensor
&key_cache, // [num_blocks, num_heads, block_size, head_size] or
// [num_blocks, block_size, num_heads, head_size]
torch::Tensor
&value_cache, // [num_blocks, num_heads, block_size, head_size] or
// [num_blocks, block_size, num_heads, head_size]
const aiter_tensor_t &out, // [num_seqs, num_heads, head_size]
const aiter_tensor_t &workspace_buffer,
const aiter_tensor_t &query, // [num_seqs, num_heads, head_size]
const aiter_tensor_t &key_cache,
const aiter_tensor_t &value_cache,
double scale,
torch::Tensor &kv_indptr, // [num_seqs + 1]
torch::Tensor &kv_page_indices, // [max_num_blocks]
std::optional<torch::Tensor> &kv_last_page_lens, // [num_seqs]
const aiter_tensor_t &kv_indptr, // [num_seqs + 1]
const aiter_tensor_t &kv_page_indices, // [max_num_blocks]
const aiter_tensor_t *kv_last_page_lens, // [num_seqs]
int64_t block_size, int64_t max_num_partitions,
const std::optional<torch::Tensor> &alibi_slopes,
const aiter_tensor_t *alibi_slopes,
const std::string &kv_cache_dtype, const std::string &kv_cache_layout,
float logits_soft_cap, torch::Tensor &k_scale, torch::Tensor &v_scale,
const std::optional<torch::Tensor> &fp8_out_scale, int64_t partition_size);
float logits_soft_cap, const aiter_tensor_t &k_scale, const aiter_tensor_t &v_scale,
const aiter_tensor_t *fp8_out_scale, int64_t partition_size);
31 changes: 15 additions & 16 deletions csrc/include/attention_v1.h
@@ -1,23 +1,22 @@
#pragma once
// SPDX-License-Identifier: MIT
// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <torch/extension.h>
#include "aiter_tensor.h"
#include <optional>
#include <string>

void paged_attention_v1(
torch::Tensor &out, // [num_seqs, num_heads, head_size]
torch::Tensor &workspace_buffer,
torch::Tensor &query, // [num_seqs, num_heads, head_size]
torch::Tensor
&key_cache, // [num_blocks, num_heads, block_size, head_size] or
// [num_blocks, block_size, num_heads, head_size]
torch::Tensor
&value_cache, // [num_blocks, num_heads, block_size, head_size] or
// [num_blocks, block_size, num_heads, head_size]
const aiter_tensor_t &out, // [num_seqs, num_heads, head_size]
const aiter_tensor_t &workspace_buffer,
const aiter_tensor_t &query, // [num_seqs, num_heads, head_size]
const aiter_tensor_t &key_cache,
const aiter_tensor_t &value_cache,
double scale,
torch::Tensor &block_tables, // [num_seqs, max_num_blocks_per_seq]
const std::optional<torch::Tensor>& cu_query_lens, // [num_seqs+1]
torch::Tensor &context_lens, // [num_seqs]
int64_t max_context_len, const std::optional<torch::Tensor> &alibi_slopes,
const aiter_tensor_t &block_tables, // [num_seqs, max_num_blocks_per_seq]
const aiter_tensor_t *cu_query_lens, // [num_seqs+1]
const aiter_tensor_t &context_lens, // [num_seqs]
int64_t max_context_len,
const aiter_tensor_t *alibi_slopes,
const std::string &kv_cache_dtype, const std::string &kv_cache_layout,
float logits_soft_cap, torch::Tensor &k_scale, torch::Tensor &v_scale,
const std::optional<torch::Tensor> &fp8_out_scale, int64_t partition_size);
float logits_soft_cap, const aiter_tensor_t &k_scale, const aiter_tensor_t &v_scale,
const aiter_tensor_t *fp8_out_scale, int64_t partition_size);