diff --git a/aiter/ops/attention.py b/aiter/ops/attention.py
index 305b8f3f9b..d1c827b0bc 100644
--- a/aiter/ops/attention.py
+++ b/aiter/ops/attention.py
@@ -81,7 +81,7 @@ def gen_pa_fwd_asm(
     return torch.empty_like(Q)
 
 
-@compile_ops("module_attention", gen_fake=gen_pa_fwd_native_fake)
+@compile_ops("module_attention", gen_fake=gen_pa_fwd_native_fake, develop=True)
 def pa_fwd_naive(
     # [num_seqs, num_heads, head_size]
     query: torch.Tensor,
@@ -652,7 +652,7 @@ def paged_attention_ragged(
 
 MD_NAME = "module_mla_asm"
 
-@compile_ops(MD_NAME, ffi_type="ctypes")
+@compile_ops(MD_NAME, ffi_type="ctypes", develop=True)
 def mla_decode_stage1_asm_fwd(
     # [num_seqs, num_heads, head_size]
     Q: torch.Tensor,
@@ -688,7 +688,7 @@ def mla_decode_stage1_asm_fwd(
 ) -> None: ...
 
 
-@compile_ops(MD_NAME, ffi_type="ctypes")
+@compile_ops(MD_NAME, ffi_type="ctypes", develop=True)
 def mla_prefill_asm_fwd(
     # [num_seqs, num_heads, head_size]
     Q: torch.Tensor,
@@ -873,7 +873,7 @@ def get_ps_metadata_v1(
 ) -> None: ...
 
 
-@compile_ops(MD_NAME, ffi_type="ctypes")
+@compile_ops(MD_NAME, ffi_type="ctypes", develop=True)
 def mla_prefill_ps_asm_fwd(
     Q: torch.Tensor,
     K: torch.Tensor,
@@ -933,10 +933,8 @@ def get_mla_metadata_info_v1(
         )
     )
 
-    # In sparse mode, each expanded batch has 1 Q token
-    effective_seqlen_qo = 1 if is_sparse else max_seqlen_qo
     max_qo_tiles_per_batch = (
-        int(math.ceil(effective_seqlen_qo * num_head_qo / 128))
+        int(math.ceil(max_seqlen_qo * num_head_qo / 128))
         if num_head_qo == 16
         or (
             get_gfx() == "gfx942"
@@ -944,21 +942,15 @@ def get_mla_metadata_info_v1(
             and kv_dtype == dtypes.fp8
             and q_dtype == dtypes.fp8
         )
-        or (
-            get_gfx() == "gfx950"
-            and (num_head_qo * effective_seqlen_qo) % 128 == 0
-            and kv_dtype == dtypes.bf16
-            and q_dtype == dtypes.bf16
-        )
         or (
             get_gfx() == "gfx950"
             and num_head_qo == 64
             and q_dtype == dtypes.fp8
             and kv_dtype == dtypes.fp8
-            and effective_seqlen_qo == 1
+            and max_seqlen_qo == 1
         )
         or use_qseqlen_fold
-        else int(math.ceil(effective_seqlen_qo * num_head_qo / 16))
+        else int(math.ceil(max_seqlen_qo * num_head_qo / 16))
     )
     batch_size = batch_size * max_seqlen_qo if is_sparse else batch_size
     tile_cnt = batch_size * max_qo_tiles_per_batch
diff --git a/aiter/utility/dtypes.py b/aiter/utility/dtypes.py
index 5fbe00f9be..6c595749e6 100644
--- a/aiter/utility/dtypes.py
+++ b/aiter/utility/dtypes.py
@@ -23,6 +23,8 @@ def get_dtype_fp8():
 i4x2 = getattr(torch, "int4", _8bit_fallback)
 fp4x2 = getattr(torch, "float4_e2m1fn_x2", _8bit_fallback)
 fp8 = get_dtype_fp8()
+fp8_e4m3fn = torch.float8_e4m3fn
+fp8_e4m3fnuz = torch.float8_e4m3fnuz
 fp8_e8m0 = getattr(torch, "float8_e8m0fnu", _8bit_fallback)
 fp16 = torch.float16
 bf16 = torch.bfloat16
diff --git a/csrc/include/aiter_enum.h b/csrc/include/aiter_enum.h
index c36a725f95..9dd5341e57 100644
--- a/csrc/include/aiter_enum.h
+++ b/csrc/include/aiter_enum.h
@@ -24,7 +24,9 @@ enum class QuantType : int
 };
 
 typedef enum
 {
-    AITER_DTYPE_fp8,
+    AITER_DTYPE_fp8, // Ambiguous dtype; resolved to the device-specific fp8 type (e4m3fn or e4m3fnuz).
+    AITER_DTYPE_fp8_e4m3fn,
+    AITER_DTYPE_fp8_e4m3fnuz,
     AITER_DTYPE_fp8_e8m0,
     AITER_DTYPE_fp16,
     AITER_DTYPE_bf16,
@@ -45,6 +47,8 @@ static inline size_t AiterDtype_element_size(AiterDtype dtype)
     switch(dtype)
     {
     case AITER_DTYPE_fp8:
+    case AITER_DTYPE_fp8_e4m3fn:
+    case AITER_DTYPE_fp8_e4m3fnuz:
     case AITER_DTYPE_fp8_e8m0:
     case AITER_DTYPE_i4x2:
     case AITER_DTYPE_fp4x2:
@@ -67,6 +71,8 @@ static inline std::string AiterDtype_to_str(int dtype)
     switch(dtype)
     {
     case AITER_DTYPE_fp8: return "fp8";
+    case AITER_DTYPE_fp8_e4m3fn: return "fp8_e4m3fn";
+    case AITER_DTYPE_fp8_e4m3fnuz: return "fp8_e4m3fnuz";
     case AITER_DTYPE_fp8_e8m0: return "fp8_e8m0";
     case AITER_DTYPE_fp16: return "fp16";
     case AITER_DTYPE_bf16: return "bf16";
diff --git a/csrc/include/attention.h b/csrc/include/attention.h
index ecf2fd94ed..fece2825ce 100644
--- a/csrc/include/attention.h
+++ b/csrc/include/attention.h
@@ -1,14 +1,15 @@
 #pragma once
 // SPDX-License-Identifier: MIT
 // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
-#include <torch/all.h>
+#include "aiter_tensor.h"
+#include <string>
 
 void paged_attention(
-    torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits,
-    torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache,
-    torch::Tensor &value_cache, int64_t num_kv_heads, double scale,
-    torch::Tensor &block_tables, torch::Tensor &context_lens,
+    const aiter_tensor_t &out, const aiter_tensor_t &exp_sums, const aiter_tensor_t &max_logits,
+    const aiter_tensor_t &tmp_out, const aiter_tensor_t &query, const aiter_tensor_t &key_cache,
+    const aiter_tensor_t &value_cache, int64_t num_kv_heads, double scale,
+    const aiter_tensor_t &block_tables, const aiter_tensor_t &context_lens,
     int64_t block_size, int64_t max_context_len,
-    const std::optional<torch::Tensor> &alibi_slopes,
+    const aiter_tensor_t *alibi_slopes,
     const std::string &kv_cache_dtype, double k_scale, double v_scale,
-    const std::optional<torch::Tensor> &fp8_out_scale, int64_t partition_size);
+    const aiter_tensor_t *fp8_out_scale, int64_t partition_size);
diff --git a/csrc/include/attention_ragged.h b/csrc/include/attention_ragged.h
index 4a5068ff07..f9e65c7e91 100644
--- a/csrc/include/attention_ragged.h
+++ b/csrc/include/attention_ragged.h
@@ -1,24 +1,22 @@
 #pragma once
 // SPDX-License-Identifier: MIT
 // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
-#include <torch/all.h>
+#include "aiter_tensor.h"
+#include <string>
+#include <cstdint>
 
 void paged_attention_ragged(
-    torch::Tensor &out, // [num_seqs, num_heads, head_size]
-    torch::Tensor &workspace_buffer,
-    torch::Tensor &query, // [num_seqs, num_heads, head_size]
-    torch::Tensor
-        &key_cache, // [num_blocks, num_heads, block_size, head_size] or
-                    // [num_blocks, block_size, num_heads, head_size]
-    torch::Tensor
-        &value_cache, // [num_blocks, num_heads, block_size, head_size] or
-                      // [num_blocks, block_size, num_heads, head_size]
+    const aiter_tensor_t &out, // [num_seqs, num_heads, head_size]
+    const aiter_tensor_t &workspace_buffer,
+    const aiter_tensor_t &query, // [num_seqs, num_heads, head_size]
+    const aiter_tensor_t &key_cache,
+    const aiter_tensor_t &value_cache,
     double scale,
-    torch::Tensor &kv_indptr,                        // [num_seqs + 1]
-    torch::Tensor &kv_page_indices,                  // [max_num_blocks]
-    std::optional<torch::Tensor> &kv_last_page_lens, // [num_seqs]
+    const aiter_tensor_t &kv_indptr,         // [num_seqs + 1]
+    const aiter_tensor_t &kv_page_indices,   // [max_num_blocks]
+    const aiter_tensor_t *kv_last_page_lens, // [num_seqs]
     int64_t block_size, int64_t max_num_partitions,
-    const std::optional<torch::Tensor> &alibi_slopes,
+    const aiter_tensor_t *alibi_slopes,
     const std::string &kv_cache_dtype, const std::string &kv_cache_layout,
-    float logits_soft_cap, torch::Tensor &k_scale, torch::Tensor &v_scale,
-    const std::optional<torch::Tensor> &fp8_out_scale, int64_t partition_size);
+    float logits_soft_cap, const aiter_tensor_t &k_scale, const aiter_tensor_t &v_scale,
+    const aiter_tensor_t *fp8_out_scale, int64_t partition_size);
diff --git a/csrc/include/attention_v1.h b/csrc/include/attention_v1.h
index a6127b004c..cb0e810f8d 100644
--- a/csrc/include/attention_v1.h
+++ b/csrc/include/attention_v1.h
@@ -1,23 +1,22 @@
 #pragma once
 // SPDX-License-Identifier: MIT
 // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
-#include <torch/all.h>
+#include "aiter_tensor.h"
+#include <string>
+#include <cstdint>
 
 void paged_attention_v1(
-    torch::Tensor &out, // [num_seqs, num_heads, head_size]
-    torch::Tensor &workspace_buffer,
-    torch::Tensor &query, // [num_seqs, num_heads, head_size]
-    torch::Tensor
-        &key_cache, // [num_blocks, num_heads, block_size, head_size] or
-                    // [num_blocks, block_size, num_heads, head_size]
-    torch::Tensor
-        &value_cache, // [num_blocks, num_heads, block_size, head_size] or
-                      // [num_blocks, block_size, num_heads, head_size]
+    const aiter_tensor_t &out, // [num_seqs, num_heads, head_size]
+    const aiter_tensor_t &workspace_buffer,
+    const aiter_tensor_t &query, // [num_seqs, num_heads, head_size]
+    const aiter_tensor_t &key_cache,
+    const aiter_tensor_t &value_cache,
     double scale,
-    torch::Tensor &block_tables, // [num_seqs, max_num_blocks_per_seq]
-    const std::optional<torch::Tensor>& cu_query_lens, // [num_seqs+1]
-    torch::Tensor &context_lens, // [num_seqs]
-    int64_t max_context_len, const std::optional<torch::Tensor> &alibi_slopes,
+    const aiter_tensor_t &block_tables, // [num_seqs, max_num_blocks_per_seq]
+    const aiter_tensor_t *cu_query_lens, // [num_seqs+1]
+    const aiter_tensor_t &context_lens, // [num_seqs]
+    int64_t max_context_len,
+    const aiter_tensor_t *alibi_slopes,
     const std::string &kv_cache_dtype, const std::string &kv_cache_layout,
-    float logits_soft_cap, torch::Tensor &k_scale, torch::Tensor &v_scale,
-    const std::optional<torch::Tensor> &fp8_out_scale, int64_t partition_size);
+    float logits_soft_cap, const aiter_tensor_t &k_scale, const aiter_tensor_t &v_scale,
+    const aiter_tensor_t *fp8_out_scale, int64_t partition_size);
diff --git a/csrc/kernels/attention.cu b/csrc/kernels/attention.cu
index 515c7e3f01..edb1a22f0b 100644
--- a/csrc/kernels/attention.cu
+++ b/csrc/kernels/attention.cu
@@ -15,14 +15,14 @@
  * limitations under the License.
  */
 
-#include <torch/all.h>
-#include <ATen/hip/HIPContext.h>
-#include <c10/hip/HIPGuard.h>
 #include <algorithm>
 
 #include "aiter_hip_common.h"
+#include "aiter_tensor.h"
+#include "aiter_stream.h"
 #include "attention.h"
 
 #include <hip/hip_bf16.h>
+#include <string>
 #include "dtype_fp8.cuh"
 #include "quant_utils.cuh"
@@ -2201,13 +2201,13 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
 
 template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
           int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE>
 void paged_attention_custom_launcher(
-    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
-    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
-    torch::Tensor& value_cache, const int num_kv_heads, float scale,
-    torch::Tensor& block_tables, torch::Tensor& context_lens,
-    int max_context_len, const std::optional<torch::Tensor>& alibi_slopes,
+    const aiter_tensor_t& out, const aiter_tensor_t& exp_sums, const aiter_tensor_t& max_logits,
+    const aiter_tensor_t& tmp_out, const aiter_tensor_t& query, const aiter_tensor_t& key_cache,
+    const aiter_tensor_t& value_cache, const int num_kv_heads, float scale,
+    const aiter_tensor_t& block_tables, const aiter_tensor_t& context_lens,
+    int max_context_len, const aiter_tensor_t* alibi_slopes,
     float k_scale, float v_scale,
-    const std::optional<torch::Tensor>& fp8_out_scale) {
+    const aiter_tensor_t* fp8_out_scale) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -2219,7 +2219,7 @@ void paged_attention_custom_launcher(
   // NOTE: alibi_slopes is optional.
   const float* alibi_slopes_ptr =
       alibi_slopes
-          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+          ? reinterpret_cast<const float*>(alibi_slopes->data_ptr())
          : nullptr;
 
  float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
@@ -2228,13 +2228,13 @@ void paged_attention_custom_launcher(
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
   KVT* key_cache_ptr = reinterpret_cast<KVT*>(key_cache.data_ptr());
   KVT* value_cache_ptr = reinterpret_cast<KVT*>(value_cache.data_ptr());
-  int* block_tables_ptr = block_tables.data_ptr<int>();
-  int* context_lens_ptr = context_lens.data_ptr<int>();
+  int* block_tables_ptr = reinterpret_cast<int*>(block_tables.data_ptr());
+  int* context_lens_ptr = reinterpret_cast<int*>(context_lens.data_ptr());
 
   // NOTE: fp8_out_scale is optional.
   const float* fp8_out_scale_ptr =
       fp8_out_scale
-          ? reinterpret_cast<const float*>(fp8_out_scale.value().data_ptr())
+          ? reinterpret_cast<const float*>(fp8_out_scale->data_ptr())
          : nullptr;
   OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr());
@@ -2249,8 +2249,8 @@
   constexpr int NTHR = 256; //PARTITION_SIZE;
   dim3 grid(num_seqs, max_num_partitions, num_kv_heads);
   dim3 block(NTHR);
-  const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(query));
-  const hipStream_t stream = at::hip::getCurrentHIPStream();
+  HipDeviceGuard device_guard(query.device_id);
+  const hipStream_t stream = aiter::getCurrentHIPStream();
   switch (gqa_ratio) {
     case 1:
       //LAUNCH_CUSTOM_ATTENTION(1);
@@ -2317,7 +2317,7 @@ void paged_attention_custom_launcher(
       LAUNCH_CUSTOM_ATTENTION_MFMA16(16);
       break;
     default:
-      TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio);
+      AITER_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio);
       break;
   }
@@ -2360,7 +2360,7 @@ void paged_attention_custom_launcher(
       LAUNCH_CUSTOM_REDUCTION(8);
       break;
     default:
-      TORCH_CHECK(false, "Unsupported npar_loops: ", npar_loops);
+      AITER_CHECK(false, "Unsupported npar_loops: ", npar_loops);
       break;
   }
 #endif
@@ -2382,7 +2382,7 @@ void paged_attention_custom_launcher(
       CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, 256); \
       break; \
     default: \
-      TORCH_CHECK(false, "Unsupported partition size: ", partition_size); \
+      AITER_CHECK(false, "Unsupported partition size: ", partition_size); \
       break; \
   }
 /*
@@ -2390,7 +2390,7 @@
 #if defined(__HIPCC__) && defined(__gfx90a__)
 #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
   if (fp8_out_scale) { \
-    TORCH_CHECK(false, "fp8 out scale unsupported for gfx90a"); \
+    AITER_CHECK(false, "fp8 out scale unsupported for gfx90a"); \
   } else { \
     CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T); \
   }
@@ -2413,7 +2413,7 @@ void paged_attention_custom_launcher(
       CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \
       break; \
     default: \
-      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
+      AITER_CHECK(false, "Unsupported block size: ", block_size); \
      break; \
   }
 /*
@@ -2427,7 +2427,7 @@ void paged_attention_custom_launcher(
       CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128); \
       break; \
     default: \
-      TORCH_CHECK(false, "Unsupported head size: ", head_size); \
+      AITER_CHECK(false, "Unsupported head size: ", head_size); \
       break; \
   }
 /*
@@ -2436,46 +2436,46 @@ void paged_attention_custom_launcher(
       break; \
 */
 void paged_attention(
-    torch::Tensor& out,        // [num_seqs, num_heads, head_size]
-    torch::Tensor& exp_sums,   // [num_seqs, num_heads, max_num_partitions]
-    torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
-    torch::Tensor&
+    const aiter_tensor_t& out,        // [num_seqs, num_heads, head_size]
+    const aiter_tensor_t& exp_sums,   // [num_seqs, num_heads, max_num_partitions]
+    const aiter_tensor_t& max_logits, // [num_seqs, num_heads, max_num_partitions]
+    const aiter_tensor_t&
         tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
-    torch::Tensor& query, // [num_seqs, num_heads, head_size]
-    torch::Tensor&
+    const aiter_tensor_t& query, // [num_seqs, num_heads, head_size]
+    const aiter_tensor_t&
         key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
-    torch::Tensor&
+    const aiter_tensor_t&
         value_cache, // [num_blocks, num_heads, head_size, block_size]
     int64_t num_kv_heads, double scale,
-    torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
-    torch::Tensor& context_lens, // [num_seqs]
+    const aiter_tensor_t& block_tables, // [num_seqs, max_num_blocks_per_seq]
+    const aiter_tensor_t& context_lens, // [num_seqs]
     int64_t block_size, int64_t max_context_len,
-    const std::optional<torch::Tensor>& alibi_slopes,
+    const aiter_tensor_t* alibi_slopes,
     const std::string& kv_cache_dtype, double k_scale, double v_scale,
-    const std::optional<torch::Tensor>& fp8_out_scale, int64_t partition_size) {
+    const aiter_tensor_t* fp8_out_scale, int64_t partition_size) {
   const int head_size = query.size(2);
   if (kv_cache_dtype == "auto") {
-    if (query.dtype() == at::ScalarType::Half) {
+    if (query.dtype() == AITER_DTYPE_fp16) {
       CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, _Float16, vllm::Fp8KVCacheDataType::kAuto);
-    } else if (query.dtype() == at::ScalarType::BFloat16) {
+    } else if (query.dtype() == AITER_DTYPE_bf16) {
       CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, __hip_bfloat16, vllm::Fp8KVCacheDataType::kAuto);
     } else {
-      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+      AITER_CHECK(false, "Unsupported data type: ", AiterDtype_to_str(query.dtype()));
     }
   } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
-    if (query.dtype() == at::ScalarType::Half) {
+    if (query.dtype() == AITER_DTYPE_fp16) {
       CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
-    } else if (query.dtype() == at::ScalarType::BFloat16) {
+    } else if (query.dtype() == AITER_DTYPE_bf16) {
       CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
     } else {
-      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+      AITER_CHECK(false, "Unsupported data type: ", AiterDtype_to_str(query.dtype()));
     }
   } else {
-    TORCH_CHECK(false, "Unsupported KV cache dtype: ", kv_cache_dtype);
+    AITER_CHECK(false, "Unsupported KV cache dtype: ", kv_cache_dtype);
   }
 }
diff --git a/csrc/kernels/attention_ragged.cu b/csrc/kernels/attention_ragged.cu
index 3bd8fa0711..0b225b5177 100644
--- a/csrc/kernels/attention_ragged.cu
+++ b/csrc/kernels/attention_ragged.cu
@@ -2,9 +2,9 @@
 // Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "attention_common.cuh" #include "attention_ragged.h" -#include -#include -#include +#include "aiter_hip_common.h" +#include "aiter_tensor.h" +#include "aiter_stream.h" #if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ defined(__gfx942__) || defined(__gfx950__)) @@ -278,22 +278,22 @@ template -void paged_attention_custom_launcher(torch::Tensor& out, - torch::Tensor& workspace_buffer, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, +void paged_attention_custom_launcher(const aiter_tensor_t& out, + const aiter_tensor_t& workspace_buffer, + const aiter_tensor_t& query, + const aiter_tensor_t& key_cache, + const aiter_tensor_t& value_cache, float scale, - torch::Tensor& kv_indptr, - torch::Tensor& kv_page_indices, - std::optional& kv_last_page_lens, + const aiter_tensor_t& kv_indptr, + const aiter_tensor_t& kv_page_indices, + const aiter_tensor_t* kv_last_page_lens, int max_num_partitions, - const std::optional& alibi_slopes, + const aiter_tensor_t* alibi_slopes, const std::string& kv_cache_layout, float logits_soft_cap, - torch::Tensor& k_scale, - torch::Tensor& v_scale, - const std::optional& fp8_out_scale) + const aiter_tensor_t& k_scale, + const aiter_tensor_t& v_scale, + const aiter_tensor_t* fp8_out_scale) { const int num_kv_heads = kv_cache_layout == "HND" ? key_cache.size(1) : key_cache.size(2); int num_seqs = query.size(0); @@ -306,21 +306,21 @@ void paged_attention_custom_launcher(torch::Tensor& out, // NOTE: alibi_slopes is optional. const float* alibi_slopes_ptr = - alibi_slopes ? reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; + alibi_slopes ? reinterpret_cast(alibi_slopes->data_ptr()) : nullptr; T* query_ptr = reinterpret_cast(query.data_ptr()); KVT* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); KVT* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int* kv_indptr_ptr = kv_indptr.data_ptr(); - int* kv_page_indices_ptr = kv_page_indices.data_ptr(); + int* kv_indptr_ptr = reinterpret_cast(kv_indptr.data_ptr()); + int* kv_page_indices_ptr = reinterpret_cast(kv_page_indices.data_ptr()); int* kv_last_page_lens_ptr = - BLOCK_SIZE > 1 ? kv_last_page_lens.value().data_ptr() : nullptr; + BLOCK_SIZE > 1 ? reinterpret_cast(kv_last_page_lens->data_ptr()) : nullptr; const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); // NOTE: fp8_out_scale is optional. const float* fp8_out_scale_ptr = - fp8_out_scale ? reinterpret_cast(fp8_out_scale.value().data_ptr()) : nullptr; + fp8_out_scale ? reinterpret_cast(fp8_out_scale->data_ptr()) : nullptr; OUTT* out_ptr = reinterpret_cast(out.data_ptr()); const float logits_soft_cap_rcp = (LOGITS_SOFT_CAP_ENABLED ? 
1.f / logits_soft_cap : 0.f); @@ -343,8 +343,8 @@ void paged_attention_custom_launcher(torch::Tensor& out, constexpr int NTHR = 256; dim3 grid(num_seqs, max_num_partitions, num_kv_heads); dim3 block(NTHR); - const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(query)); - const hipStream_t stream = at::hip::getCurrentHIPStream(); + HipDeviceGuard device_guard(query.device_id); + const hipStream_t stream = aiter::getCurrentHIPStream(); // mfma4 kernel is faster than mfma16 for gqa_ratio <= 4 switch(gqa_ratio) @@ -365,7 +365,7 @@ void paged_attention_custom_launcher(torch::Tensor& out, case 14: LAUNCH_CUSTOM_ATTENTION_MFMA16(14); break; case 15: LAUNCH_CUSTOM_ATTENTION_MFMA16(15); break; case 16: LAUNCH_CUSTOM_ATTENTION_MFMA16(16); break; - default: TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio); break; + default: AITER_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio); break; } dim3 reduce_grid(num_heads, num_seqs); @@ -384,7 +384,7 @@ void paged_attention_custom_launcher(torch::Tensor& out, case 6: LAUNCH_CUSTOM_REDUCTION(6); break; case 7: LAUNCH_CUSTOM_REDUCTION(7); break; case 8: LAUNCH_CUSTOM_REDUCTION(8); break; - default: TORCH_CHECK(false, "Unsupported npar_loops: ", npar_loops); break; + default: AITER_CHECK(false, "Unsupported npar_loops: ", npar_loops); break; } } @@ -429,7 +429,7 @@ void paged_attention_custom_launcher(torch::Tensor& out, } \ else \ { \ - TORCH_CHECK(false, "logits_soft_cap must be non-negative"); \ + AITER_CHECK(false, "logits_soft_cap must be non-negative"); \ } #define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE) \ @@ -446,14 +446,14 @@ void paged_attention_custom_launcher(torch::Tensor& out, switch(partition_size) \ { \ case 256: CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, 256); break; \ - default: TORCH_CHECK(false, "Unsupported partition size: ", partition_size); break; \ + default: AITER_CHECK(false, "Unsupported partition size: ", partition_size); break; \ } #if defined(__HIPCC__) && defined(__gfx90a__) #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \ if(fp8_out_scale) \ { \ - TORCH_CHECK(false, "fp8 out scale unsupported for gfx90a"); \ + AITER_CHECK(false, "fp8 out scale unsupported for gfx90a"); \ } \ else \ { \ @@ -476,7 +476,7 @@ void paged_attention_custom_launcher(torch::Tensor& out, case 1: CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 1, HEAD_SIZE); break; \ case 16: CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE); break; \ case 32: CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE); break; \ - default: TORCH_CHECK(false, "Unsupported block size: ", block_size); break; \ + default: AITER_CHECK(false, "Unsupported block size: ", block_size); break; \ } #define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \ @@ -484,68 +484,68 @@ void paged_attention_custom_launcher(torch::Tensor& out, { \ case 64: CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64); break; \ case 128: CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128); break; \ - default: TORCH_CHECK(false, "Unsupported head size: ", head_size); break; \ + default: AITER_CHECK(false, "Unsupported head size: ", head_size); break; \ } void paged_attention_ragged( - torch::Tensor& out, // [num_seqs, num_heads, head_size] - torch::Tensor& workspace_buffer, - torch::Tensor& query, // [num_seqs, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_size] or - // [num_blocks, block_size, num_heads, head_size] - torch::Tensor& value_cache, // 
[num_blocks, num_heads, block_size, head_size] or - // [num_blocks, block_size, num_heads, head_size] + const aiter_tensor_t& out, // [num_seqs, num_heads, head_size] + const aiter_tensor_t& workspace_buffer, + const aiter_tensor_t& query, // [num_seqs, num_heads, head_size] + const aiter_tensor_t& key_cache, // [num_blocks, num_heads, block_size, head_size] or + // [num_blocks, block_size, num_heads, head_size] + const aiter_tensor_t& value_cache, // [num_blocks, num_heads, block_size, head_size] or + // [num_blocks, block_size, num_heads, head_size] double scale, - torch::Tensor& kv_indptr, // [num_seqs + 1] - torch::Tensor& kv_page_indices, // [max_num_blocks] - std::optional& kv_last_page_lens, // [num_seqs] + const aiter_tensor_t& kv_indptr, // [num_seqs + 1] + const aiter_tensor_t& kv_page_indices, // [max_num_blocks] + const aiter_tensor_t* kv_last_page_lens, // [num_seqs] int64_t block_size, int64_t max_num_partitions, - const std::optional& alibi_slopes, + const aiter_tensor_t* alibi_slopes, const std::string& kv_cache_dtype, const std::string& kv_cache_layout, float logits_soft_cap, - torch::Tensor& k_scale, - torch::Tensor& v_scale, - const std::optional& fp8_out_scale, + const aiter_tensor_t& k_scale, + const aiter_tensor_t& v_scale, + const aiter_tensor_t* fp8_out_scale, int64_t partition_size) { const int head_size = query.size(2); if(kv_cache_dtype == "auto") { - if(query.dtype() == at::ScalarType::Half) + if(query.dtype() == AITER_DTYPE_fp16) { CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, _Float16, vllm::Fp8KVCacheDataType::kAuto); } - else if(query.dtype() == at::ScalarType::BFloat16) + else if(query.dtype() == AITER_DTYPE_bf16) { CALL_CUSTOM_LAUNCHER_BLK_HEAD( __hip_bfloat16, __hip_bfloat16, vllm::Fp8KVCacheDataType::kAuto); } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + AITER_CHECK(false, "Unsupported data type"); } } else if(kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") { - if(query.dtype() == at::ScalarType::Half) + if(query.dtype() == AITER_DTYPE_fp16) { CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); } - else if(query.dtype() == at::ScalarType::BFloat16) + else if(query.dtype() == AITER_DTYPE_bf16) { CALL_CUSTOM_LAUNCHER_BLK_HEAD( __hip_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + AITER_CHECK(false, "Unsupported data type"); } } else { - TORCH_CHECK(false, "Unsupported KV cache dtype: ", kv_cache_dtype); + AITER_CHECK(false, "Unsupported KV cache dtype: ", kv_cache_dtype); } } diff --git a/csrc/kernels/attention_v1.cu b/csrc/kernels/attention_v1.cu index d8a57fa905..156096b7e3 100644 --- a/csrc/kernels/attention_v1.cu +++ b/csrc/kernels/attention_v1.cu @@ -1,9 +1,9 @@ // SPDX-License-Identifier: MIT // Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. 
 
-#include <torch/all.h>
-#include <ATen/hip/HIPContext.h>
-#include <c10/hip/HIPGuard.h>
+#include "aiter_hip_common.h"
+#include "aiter_tensor.h"
+#include "aiter_stream.h"
 #include "attention_v1.h"
 #include "attention_common.cuh"
@@ -226,23 +226,23 @@ template
-void paged_attention_custom_launcher(torch::Tensor& out,
-                                     torch::Tensor& workspace_buffer,
-                                     torch::Tensor& query,
-                                     torch::Tensor& key_cache,
-                                     torch::Tensor& value_cache,
+void paged_attention_custom_launcher(const aiter_tensor_t& out,
+                                     const aiter_tensor_t& workspace_buffer,
+                                     const aiter_tensor_t& query,
+                                     const aiter_tensor_t& key_cache,
+                                     const aiter_tensor_t& value_cache,
                                      float scale,
-                                     torch::Tensor& block_tables,
-                                     const std::optional<torch::Tensor>& cu_query_lens,
-                                     torch::Tensor& context_lens,
+                                     const aiter_tensor_t& block_tables,
+                                     const aiter_tensor_t* cu_query_lens,
+                                     const aiter_tensor_t& context_lens,
                                      int max_num_blocks_per_seq,
                                      int max_num_partitions,
-                                     const std::optional<torch::Tensor>& alibi_slopes,
+                                     const aiter_tensor_t* alibi_slopes,
                                      const std::string& kv_cache_layout,
                                      float logits_soft_cap,
-                                     torch::Tensor& k_scale,
-                                     torch::Tensor& v_scale,
-                                     const std::optional<torch::Tensor>& fp8_out_scale)
+                                     const aiter_tensor_t& k_scale,
+                                     const aiter_tensor_t& v_scale,
+                                     const aiter_tensor_t* fp8_out_scale)
 {
     const int num_kv_heads = kv_cache_layout=="HND" ? key_cache.size(1) : key_cache.size(2);
     int num_seqs = context_lens.size(0);
@@ -255,20 +255,20 @@ void paged_attention_custom_launcher(torch::Tensor& out,
 
     // NOTE: alibi_slopes is optional.
     const float* alibi_slopes_ptr =
-        alibi_slopes ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) : nullptr;
+        alibi_slopes ? reinterpret_cast<const float*>(alibi_slopes->data_ptr()) : nullptr;
 
     T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
     KVT* key_cache_ptr = reinterpret_cast<KVT*>(key_cache.data_ptr());
     KVT* value_cache_ptr = reinterpret_cast<KVT*>(value_cache.data_ptr());
-    int* context_lens_ptr = context_lens.data_ptr<int>();
-    int* block_tables_ptr = block_tables.data_ptr<int>();
-    int* cu_query_lens_ptr = cu_query_lens ? cu_query_lens.value().data_ptr<int>() : nullptr;
+    int* context_lens_ptr = reinterpret_cast<int*>(context_lens.data_ptr());
+    int* block_tables_ptr = reinterpret_cast<int*>(block_tables.data_ptr());
+    int* cu_query_lens_ptr = cu_query_lens ? reinterpret_cast<int*>(cu_query_lens->data_ptr()) : nullptr;
     const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
     const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
     // NOTE: fp8_out_scale is optional.
     const float* fp8_out_scale_ptr =
-        fp8_out_scale ? reinterpret_cast<const float*>(fp8_out_scale.value().data_ptr()) : nullptr;
+        fp8_out_scale ? reinterpret_cast<const float*>(fp8_out_scale->data_ptr()) : nullptr;
     OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr());
 
     const float logits_soft_cap_rcp = (LOGITS_SOFT_CAP_ENABLED ? 1.f / logits_soft_cap : 0.f);
@@ -291,8 +291,8 @@ void paged_attention_custom_launcher(torch::Tensor& out,
     dim3 grid(num_seqs, max_num_partitions, num_kv_heads);
     dim3 block(NTHR);
-    const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(query));
-    const hipStream_t stream = at::hip::getCurrentHIPStream();
+    HipDeviceGuard device_guard(query.device_id);
+    const hipStream_t stream = aiter::getCurrentHIPStream();
 
     // mfma4 kernel is faster than mfma16 for gqa_ratio <= 4
     switch(gqa_ratio)
     {
@@ -313,7 +313,7 @@ void paged_attention_custom_launcher(torch::Tensor& out,
     case 14: LAUNCH_CUSTOM_ATTENTION_MFMA16(14); break;
     case 15: LAUNCH_CUSTOM_ATTENTION_MFMA16(15); break;
     case 16: LAUNCH_CUSTOM_ATTENTION_MFMA16(16); break;
-    default: TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio); break;
+    default: AITER_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio); break;
     }
 
     dim3 reduce_grid(num_heads, num_seqs);
@@ -331,7 +331,7 @@ void paged_attention_custom_launcher(torch::Tensor& out,
     case 6: LAUNCH_CUSTOM_REDUCTION(6); break;
     case 7: LAUNCH_CUSTOM_REDUCTION(7); break;
     case 8: LAUNCH_CUSTOM_REDUCTION(8); break;
-    default: TORCH_CHECK(false, "Unsupported npar_loops: ", npar_loops); break;
+    default: AITER_CHECK(false, "Unsupported npar_loops: ", npar_loops); break;
     }
 }
@@ -378,7 +378,7 @@ void paged_attention_custom_launcher(torch::Tensor& out,
     } \
     else \
     { \
-        TORCH_CHECK(false, "logits_soft_cap must be non-negative"); \
+        AITER_CHECK(false, "logits_soft_cap must be non-negative"); \
     }
 
@@ -396,14 +396,14 @@ void paged_attention_custom_launcher(
     switch(partition_size) \
     { \
     case 256: CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, 256); break; \
-    default: TORCH_CHECK(false, "Unsupported partition size: ", partition_size); break; \
+    default: AITER_CHECK(false, "Unsupported partition size: ", partition_size); break; \
     }
 
 #if defined(__HIPCC__) && defined(__gfx90a__)
 #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
     if(fp8_out_scale) \
     { \
-        TORCH_CHECK(false, "fp8 out scale unsupported for gfx90a"); \
+        AITER_CHECK(false, "fp8 out scale unsupported for gfx90a"); \
     } \
     else \
     { \
@@ -426,7 +426,7 @@ void paged_attention_custom_launcher(torch::Tensor& out,
     case 1: CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 1, HEAD_SIZE); break; \
     case 16: CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE); break; \
     case 32: CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE); break; \
-    default: TORCH_CHECK(false, "Unsupported block size: ", block_size); break; \
+    default: AITER_CHECK(false, "Unsupported block size: ", block_size); break; \
     }
 
 #define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \
@@ -434,30 +434,30 @@ void paged_attention_custom_launcher(torch::Tensor& out,
     { \
     case 64: CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64); break; \
     case 128: CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128); break; \
-    default: TORCH_CHECK(false, "Unsupported head size: ", head_size); break; \
+    default: AITER_CHECK(false, "Unsupported head size: ", head_size); break; \
     }
 
 void paged_attention_v1(
-    torch::Tensor& out,   // [num_seqs, num_heads, head_size]
-    torch::Tensor& workspace_buffer,
-    torch::Tensor& query, // [num_seqs, num_heads, head_size]
-    torch::Tensor& key_cache,   // [num_blocks, num_heads, block_size, head_size] or
-                                // [num_blocks, block_size, num_heads, head_size]
-    torch::Tensor& value_cache, // [num_blocks, num_heads, block_size, head_size] or
-                                // [num_blocks, block_size, num_heads, head_size]
+    const aiter_tensor_t& out,   // [num_seqs, num_heads, head_size]
+    const aiter_tensor_t& workspace_buffer,
+    const aiter_tensor_t& query, // [num_seqs, num_heads, head_size]
+    const aiter_tensor_t& key_cache,   // [num_blocks, num_heads, block_size, head_size] or
+                                       // [num_blocks, block_size, num_heads, head_size]
+    const aiter_tensor_t& value_cache, // [num_blocks, num_heads, block_size, head_size] or
+                                       // [num_blocks, block_size, num_heads, head_size]
     double scale,
-    torch::Tensor& block_tables,                       // [num_seqs, max_num_blocks_per_seq]
-    const std::optional<torch::Tensor>& cu_query_lens, // [num_seqs+1]
-    torch::Tensor& context_lens,                       // [num_seqs]
+    const aiter_tensor_t& block_tables,  // [num_seqs, max_num_blocks_per_seq]
+    const aiter_tensor_t* cu_query_lens, // [num_seqs+1]
+    const aiter_tensor_t& context_lens,  // [num_seqs]
     int64_t max_context_len,
-    const std::optional<torch::Tensor>& alibi_slopes,
+    const aiter_tensor_t* alibi_slopes,
     const std::string& kv_cache_dtype,
     const std::string& kv_cache_layout,
     float logits_soft_cap,
-    torch::Tensor& k_scale,
-    torch::Tensor& v_scale,
-    const std::optional<torch::Tensor>& fp8_out_scale, int64_t partition_size)
+    const aiter_tensor_t& k_scale,
+    const aiter_tensor_t& v_scale,
+    const aiter_tensor_t* fp8_out_scale, int64_t partition_size)
 {
     const int64_t block_size = kv_cache_layout=="HND" ? key_cache.size(2) : key_cache.size(1);
     const int head_size = query.size(2);
@@ -467,39 +467,39 @@ void paged_attention_v1(
     if(kv_cache_dtype == "auto")
     {
-        if(query.dtype() == at::ScalarType::Half)
+        if(query.dtype() == AITER_DTYPE_fp16)
         {
             CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, _Float16, vllm::Fp8KVCacheDataType::kAuto);
         }
-        else if(query.dtype() == at::ScalarType::BFloat16)
+        else if(query.dtype() == AITER_DTYPE_bf16)
         {
             CALL_CUSTOM_LAUNCHER_BLK_HEAD(
                 __hip_bfloat16, __hip_bfloat16, vllm::Fp8KVCacheDataType::kAuto);
         }
         else
         {
-            TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+            AITER_CHECK(false, "Unsupported data type");
         }
     }
     else if(kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3")
     {
-        if(query.dtype() == at::ScalarType::Half)
+        if(query.dtype() == AITER_DTYPE_fp16)
         {
             CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
         }
-        else if(query.dtype() == at::ScalarType::BFloat16)
+        else if(query.dtype() == AITER_DTYPE_bf16)
         {
             CALL_CUSTOM_LAUNCHER_BLK_HEAD(
                 __hip_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
         }
         else
         {
-            TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+            AITER_CHECK(false, "Unsupported data type");
         }
     }
     else
     {
-        TORCH_CHECK(false, "Unsupported KV cache dtype: ", kv_cache_dtype);
+        AITER_CHECK(false, "Unsupported KV cache dtype: ", kv_cache_dtype);
     }
 }
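
The enum additions in csrc/include/aiter_enum.h are self-contained, so they can be sanity-checked in isolation. A minimal sketch follows; the file name and the main() harness are illustrative only and not part of this patch, and it assumes nothing beyond the patched header and a hosted C++ toolchain:

    // check_fp8_enums.cpp -- hypothetical harness, not part of this patch.
    // Exercises the patched AiterDtype helpers: the new explicit fp8 entries
    // join the 1-byte case group in AiterDtype_element_size() and get their
    // own names in AiterDtype_to_str().
    #include <cstdio>
    #include "aiter_enum.h"

    int main() {
        const AiterDtype dtypes[] = {
            AITER_DTYPE_fp8,          // ambiguous: resolved per device
            AITER_DTYPE_fp8_e4m3fn,   // explicit OCP e4m3fn
            AITER_DTYPE_fp8_e4m3fnuz, // explicit e4m3fnuz variant
        };
        for (AiterDtype d : dtypes) {
            std::printf("%-14s -> %zu byte(s)\n", AiterDtype_to_str(d).c_str(),
                        AiterDtype_element_size(d));
        }
        return 0;
    }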
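On the call side, the headers above replace every std::optional<torch::Tensor> parameter with a const aiter_tensor_t*, so an absent optional is now a plain nullptr. A hypothetical call-site sketch against the new attention_ragged.h follows; the run_decode wrapper, its arguments, and all numeric values are assumptions for illustration, not part of the patch:

    // Hypothetical call site, not part of this patch. Shows the new
    // pointer-style optionals: nullptr where std::optional used to be empty.
    #include "attention_ragged.h"

    void run_decode(const aiter_tensor_t& out, const aiter_tensor_t& workspace,
                    const aiter_tensor_t& query, const aiter_tensor_t& key_cache,
                    const aiter_tensor_t& value_cache, const aiter_tensor_t& kv_indptr,
                    const aiter_tensor_t& kv_page_indices,
                    const aiter_tensor_t& k_scale, const aiter_tensor_t& v_scale) {
        paged_attention_ragged(out, workspace, query, key_cache, value_cache,
                               /*scale=*/0.08838834764,        // e.g. 1/sqrt(128); illustrative
                               kv_indptr, kv_page_indices,
                               /*kv_last_page_lens=*/nullptr,  // only read when block_size > 1
                               /*block_size=*/1,
                               /*max_num_partitions=*/256,
                               /*alibi_slopes=*/nullptr,       // absent optional -> nullptr
                               /*kv_cache_dtype=*/"auto",
                               /*kv_cache_layout=*/"HND",
                               /*logits_soft_cap=*/0.0f,
                               k_scale, v_scale,
                               /*fp8_out_scale=*/nullptr,      // absent optional -> nullptr
                               /*partition_size=*/256);
    }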