diff --git a/csrc/apis/attention.hpp b/csrc/apis/attention.hpp index 8448a02e7..e987b26ad 100644 --- a/csrc/apis/attention.hpp +++ b/csrc/apis/attention.hpp @@ -3,6 +3,8 @@ #include "../utils/compatibility.hpp" #if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE +#include + #include "../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp" #include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp" #include "../jit_kernels/impls/sm100_fp8_fp4_gemm_1d1d.hpp" @@ -16,6 +18,18 @@ namespace deep_gemm::attention { #if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE +static bool has_byte_overlap(const torch::Tensor& a, const torch::Tensor& b) { + if (a.numel() == 0 or b.numel() == 0) + return false; + + const auto a_begin = reinterpret_cast(a.data_ptr()); + const auto b_begin = reinterpret_cast(b.data_ptr()); + const auto a_end = a_begin + static_cast(a.nbytes()); + const auto b_end = b_begin + static_cast(b.nbytes()); + + return a_begin < b_end and b_begin < a_end; +} + static void fp8_gemm_nt_skip_head_mid(const std::pair& a, const std::pair& b, const torch::Tensor& d, @@ -195,18 +209,25 @@ static torch::Tensor fp8_fp4_mqa_logits(const std::tuple& indices) { +static void get_paged_mqa_logits_metadata_out(const torch::Tensor& context_lens, + const torch::Tensor& schedule_metadata, + int block_kv, int num_sms, + const std::optional& indices) { // NOTES: Only 2D context lens is supported for now DG_HOST_ASSERT(context_lens.dim() == 2); const bool is_context_lens_2d = true; const int batch_size = context_lens.size(0); const int next_n = context_lens.size(1); const bool is_varlen = indices.has_value(); + DG_HOST_ASSERT(num_sms > 0); DG_HOST_ASSERT(context_lens.scalar_type() == torch::kInt); DG_HOST_ASSERT(context_lens.is_contiguous()); - - // Create metadata tensor - auto schedule_metadata = torch::empty({num_sms + 1, 2}, context_lens.options()); + DG_HOST_ASSERT(schedule_metadata.dim() == 2); + DG_HOST_ASSERT(schedule_metadata.size(0) == num_sms + 1 and schedule_metadata.size(1) == 2); + DG_HOST_ASSERT(schedule_metadata.scalar_type() == torch::kInt); + DG_HOST_ASSERT(schedule_metadata.is_contiguous()); + DG_HOST_ASSERT(schedule_metadata.device() == context_lens.device()); + DG_HOST_ASSERT(not has_byte_overlap(schedule_metadata, context_lens)); // Dispatch implementation const auto arch_major = device_runtime->get_arch_major(); @@ -216,6 +237,8 @@ static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_ DG_HOST_ASSERT(indices_tensor.dim() == 1 and indices_tensor.size(0) == batch_size); DG_HOST_ASSERT(indices_tensor.is_contiguous()); DG_HOST_ASSERT(indices_tensor.scalar_type() == torch::kInt); + DG_HOST_ASSERT(indices_tensor.device() == context_lens.device()); + DG_HOST_ASSERT(not has_byte_overlap(schedule_metadata, indices_tensor)); sm100_paged_mqa_logits_metadata(context_lens, schedule_metadata, batch_size, batch_size * next_n, next_n, num_sms, is_context_lens_2d, true, indices_tensor.data_ptr()); } else if (arch_major == 10) { DG_HOST_ASSERT(block_kv == 64 or block_kv == 32); @@ -226,6 +249,13 @@ static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_ } else { DG_HOST_UNREACHABLE("Unsupported architecture"); } +} + +static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_lens, + int block_kv, int num_sms, + const std::optional& indices) { + auto schedule_metadata = torch::empty({num_sms + 1, 2}, context_lens.options()); + get_paged_mqa_logits_metadata_out(context_lens, schedule_metadata, block_kv, num_sms, indices); return schedule_metadata; } @@ -352,6 +382,7 @@ static torch::Tensor fp8_fp4_paged_mqa_logits(const std::tuple +#include #include #include #include @@ -573,6 +574,8 @@ Tensor dg_fp8_mqa_logits(TensorView q, TensorView kv_data, TensorView kv_sf, Tensor dg_get_paged_mqa_logits_metadata(TensorView context_lens, int64_t block_kv, int64_t num_sms, Optional indices) { + DG_HOST_ASSERT(num_sms > 0 and num_sms <= std::numeric_limits::max()); + DG_HOST_ASSERT(block_kv == 64 or block_kv == 32); auto indices_val = indices.has_value()? std::optional(convert_to_torch_tensor(indices.value())) : std::nullopt; @@ -582,6 +585,20 @@ Tensor dg_get_paged_mqa_logits_metadata(TensorView context_lens, int64_t block_k return Tensor::FromDLPack(at::toDLPack(result)); } +void dg_get_paged_mqa_logits_metadata_out(TensorView context_lens, TensorView schedule_metadata, + int64_t block_kv, int64_t num_sms, + Optional indices) { + DG_HOST_ASSERT(num_sms > 0 and num_sms <= std::numeric_limits::max()); + DG_HOST_ASSERT(block_kv == 64 or block_kv == 32); + auto indices_val = indices.has_value() ? + std::optional(convert_to_torch_tensor(indices.value())) + : std::nullopt; + attention::get_paged_mqa_logits_metadata_out( + convert_to_torch_tensor(context_lens), + convert_to_torch_tensor(schedule_metadata), static_cast(block_kv), + static_cast(num_sms), indices_val); +} + Tensor dg_fp8_paged_mqa_logits(TensorView q, TensorView fused_kv_cache, TensorView weights, TensorView context_lens, TensorView block_table, TensorView schedule_meta, @@ -630,6 +647,7 @@ Tensor dg_fp8_fp4_paged_mqa_logits(TensorView q, Optional q_sf, Tens TVM_FFI_DLL_EXPORT_TYPED_FUNC(fp8_gemm_nt_skip_head_mid, dg_fp8_gemm_nt_skip_head_mid); TVM_FFI_DLL_EXPORT_TYPED_FUNC(fp8_mqa_logits, dg_fp8_mqa_logits); TVM_FFI_DLL_EXPORT_TYPED_FUNC(get_paged_mqa_logits_metadata, dg_get_paged_mqa_logits_metadata); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(get_paged_mqa_logits_metadata_out, dg_get_paged_mqa_logits_metadata_out); TVM_FFI_DLL_EXPORT_TYPED_FUNC(fp8_paged_mqa_logits, dg_fp8_paged_mqa_logits); TVM_FFI_DLL_EXPORT_TYPED_FUNC(fp8_fp4_mqa_logits, dg_fp8_fp4_mqa_logits); TVM_FFI_DLL_EXPORT_TYPED_FUNC(fp8_fp4_paged_mqa_logits, dg_fp8_fp4_paged_mqa_logits); diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py index c782e09e1..c8bf8a595 100644 --- a/deep_gemm/__init__.py +++ b/deep_gemm/__init__.py @@ -62,6 +62,7 @@ # Attention kernels fp8_fp4_mqa_logits, get_paged_mqa_logits_metadata, + get_paged_mqa_logits_metadata_out, fp8_fp4_paged_mqa_logits, # Attention kernels (legacy) fp8_mqa_logits, diff --git a/sgl_deep_gemm/__init__.py b/sgl_deep_gemm/__init__.py index 4d94f6bdd..a01fae617 100644 --- a/sgl_deep_gemm/__init__.py +++ b/sgl_deep_gemm/__init__.py @@ -223,6 +223,13 @@ def fp8_fp4_mqa_logits(q, kv, weights, cu_seq_len_k_start, cu_seq_len_k_end, cle def get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms, indices=None): return _C.get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms, indices) + def get_paged_mqa_logits_metadata_out( + context_lens, schedule_metadata, block_kv, num_sms, indices=None + ): + _C.get_paged_mqa_logits_metadata_out( + context_lens, schedule_metadata, block_kv, num_sms, indices + ) + def tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits=None): _C.tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits) diff --git a/sgl_deep_gemm/tests/test_attention.py b/sgl_deep_gemm/tests/test_attention.py index adb219b7d..14f466fc8 100644 --- a/sgl_deep_gemm/tests/test_attention.py +++ b/sgl_deep_gemm/tests/test_attention.py @@ -245,6 +245,116 @@ def enumerate_mqa_logits(): print() +def _expect_paged_mqa_logits_metadata_out_failure(*args, **kwargs): + try: + deep_gemm.get_paged_mqa_logits_metadata_out(*args, **kwargs) + except Exception: + return + raise AssertionError("Expected get_paged_mqa_logits_metadata_out to fail") + + +def test_paged_mqa_logits_metadata_out(): + if get_arch_major() not in (9, 10): + return + + context_lens = torch.tensor( + [[128, 130, 132], [65, 67, 69], [200, 202, 204]], + device='cuda', + dtype=torch.int, + ) + num_sms = deep_gemm.get_num_sms() + + for block_kv in ((64, 32) if get_arch_major() == 10 else (64, )): + expected = deep_gemm.get_paged_mqa_logits_metadata( + context_lens, block_kv, num_sms + ) + actual = torch.full_like(expected, -1) + actual_ptr = actual.data_ptr() + ret = deep_gemm.get_paged_mqa_logits_metadata_out( + context_lens, actual, block_kv, num_sms + ) + torch.cuda.synchronize() + + assert ret is None + assert actual.data_ptr() == actual_ptr + assert torch.equal(actual, expected) + + _expect_paged_mqa_logits_metadata_out_failure( + context_lens, + torch.empty((num_sms, 2), device='cuda', dtype=torch.int), + 64, + num_sms, + ) + _expect_paged_mqa_logits_metadata_out_failure( + context_lens, + torch.empty((num_sms + 1, 2), device='cuda', dtype=torch.int64), + 64, + num_sms, + ) + _expect_paged_mqa_logits_metadata_out_failure( + context_lens, torch.empty((1, 2), device='cuda', dtype=torch.int), 64, 0 + ) + _expect_paged_mqa_logits_metadata_out_failure( + context_lens, torch.empty((1, 2), device='cuda', dtype=torch.int), 64, -1 + ) + _expect_paged_mqa_logits_metadata_out_failure( + context_lens, + torch.empty((num_sms + 1, 2), device='cuda', dtype=torch.int), + 48, + num_sms, + ) + _expect_paged_mqa_logits_metadata_out_failure( + context_lens, + torch.empty((num_sms + 1, 2), device='cuda', dtype=torch.int), + 2**40 + 64, + num_sms, + ) + _expect_paged_mqa_logits_metadata_out_failure( + context_lens, + torch.empty((num_sms + 1, 2), dtype=torch.int), + 64, + num_sms, + ) + alias_context_lens = torch.empty((3, 2), device='cuda', dtype=torch.int) + _expect_paged_mqa_logits_metadata_out_failure( + alias_context_lens, alias_context_lens, 64, 2 + ) + partial_overlap = torch.empty(7, device='cuda', dtype=torch.int) + _expect_paged_mqa_logits_metadata_out_failure( + partial_overlap[:6].view(3, 2), partial_overlap[1:].view(3, 2), 64, 2 + ) + + if get_arch_major() == 10: + indices = torch.tensor([2, 0, 1], device='cuda', dtype=torch.int) + context_lens = context_lens[:, :1].contiguous() + expected = deep_gemm.get_paged_mqa_logits_metadata( + context_lens, 64, num_sms, indices=indices + ) + actual = torch.full_like(expected, -1) + actual_ptr = actual.data_ptr() + ret = deep_gemm.get_paged_mqa_logits_metadata_out( + context_lens, actual, 64, num_sms, indices=indices + ) + torch.cuda.synchronize() + + assert ret is None + assert actual.data_ptr() == actual_ptr + assert torch.equal(actual, expected) + indices_overlap = torch.empty(5, device='cuda', dtype=torch.int) + _expect_paged_mqa_logits_metadata_out_failure( + context_lens, + indices_overlap[1:].view(2, 2), + 64, + 1, + indices=indices_overlap[:3], + ) + if torch.cuda.device_count() > 1: + wrong_device_indices = indices.to('cuda:1') + _expect_paged_mqa_logits_metadata_out_failure( + context_lens, actual, 64, num_sms, indices=wrong_device_indices + ) + + def ref_paged_mqa_logits(q: torch.Tensor, kv_cache: torch.Tensor, weights: torch.Tensor, context_lens: torch.Tensor, block_tables: torch.Tensor, max_model_len: int, use_2d_context_lens: bool): @@ -494,4 +604,5 @@ def enumerate_paged_mqa_logits(): test_gemm_skip_head_mid() test_mqa_logits() + test_paged_mqa_logits_metadata_out() test_paged_mqa_logits() diff --git a/tests/test_attention.py b/tests/test_attention.py index adb219b7d..14f466fc8 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -245,6 +245,116 @@ def enumerate_mqa_logits(): print() +def _expect_paged_mqa_logits_metadata_out_failure(*args, **kwargs): + try: + deep_gemm.get_paged_mqa_logits_metadata_out(*args, **kwargs) + except Exception: + return + raise AssertionError("Expected get_paged_mqa_logits_metadata_out to fail") + + +def test_paged_mqa_logits_metadata_out(): + if get_arch_major() not in (9, 10): + return + + context_lens = torch.tensor( + [[128, 130, 132], [65, 67, 69], [200, 202, 204]], + device='cuda', + dtype=torch.int, + ) + num_sms = deep_gemm.get_num_sms() + + for block_kv in ((64, 32) if get_arch_major() == 10 else (64, )): + expected = deep_gemm.get_paged_mqa_logits_metadata( + context_lens, block_kv, num_sms + ) + actual = torch.full_like(expected, -1) + actual_ptr = actual.data_ptr() + ret = deep_gemm.get_paged_mqa_logits_metadata_out( + context_lens, actual, block_kv, num_sms + ) + torch.cuda.synchronize() + + assert ret is None + assert actual.data_ptr() == actual_ptr + assert torch.equal(actual, expected) + + _expect_paged_mqa_logits_metadata_out_failure( + context_lens, + torch.empty((num_sms, 2), device='cuda', dtype=torch.int), + 64, + num_sms, + ) + _expect_paged_mqa_logits_metadata_out_failure( + context_lens, + torch.empty((num_sms + 1, 2), device='cuda', dtype=torch.int64), + 64, + num_sms, + ) + _expect_paged_mqa_logits_metadata_out_failure( + context_lens, torch.empty((1, 2), device='cuda', dtype=torch.int), 64, 0 + ) + _expect_paged_mqa_logits_metadata_out_failure( + context_lens, torch.empty((1, 2), device='cuda', dtype=torch.int), 64, -1 + ) + _expect_paged_mqa_logits_metadata_out_failure( + context_lens, + torch.empty((num_sms + 1, 2), device='cuda', dtype=torch.int), + 48, + num_sms, + ) + _expect_paged_mqa_logits_metadata_out_failure( + context_lens, + torch.empty((num_sms + 1, 2), device='cuda', dtype=torch.int), + 2**40 + 64, + num_sms, + ) + _expect_paged_mqa_logits_metadata_out_failure( + context_lens, + torch.empty((num_sms + 1, 2), dtype=torch.int), + 64, + num_sms, + ) + alias_context_lens = torch.empty((3, 2), device='cuda', dtype=torch.int) + _expect_paged_mqa_logits_metadata_out_failure( + alias_context_lens, alias_context_lens, 64, 2 + ) + partial_overlap = torch.empty(7, device='cuda', dtype=torch.int) + _expect_paged_mqa_logits_metadata_out_failure( + partial_overlap[:6].view(3, 2), partial_overlap[1:].view(3, 2), 64, 2 + ) + + if get_arch_major() == 10: + indices = torch.tensor([2, 0, 1], device='cuda', dtype=torch.int) + context_lens = context_lens[:, :1].contiguous() + expected = deep_gemm.get_paged_mqa_logits_metadata( + context_lens, 64, num_sms, indices=indices + ) + actual = torch.full_like(expected, -1) + actual_ptr = actual.data_ptr() + ret = deep_gemm.get_paged_mqa_logits_metadata_out( + context_lens, actual, 64, num_sms, indices=indices + ) + torch.cuda.synchronize() + + assert ret is None + assert actual.data_ptr() == actual_ptr + assert torch.equal(actual, expected) + indices_overlap = torch.empty(5, device='cuda', dtype=torch.int) + _expect_paged_mqa_logits_metadata_out_failure( + context_lens, + indices_overlap[1:].view(2, 2), + 64, + 1, + indices=indices_overlap[:3], + ) + if torch.cuda.device_count() > 1: + wrong_device_indices = indices.to('cuda:1') + _expect_paged_mqa_logits_metadata_out_failure( + context_lens, actual, 64, num_sms, indices=wrong_device_indices + ) + + def ref_paged_mqa_logits(q: torch.Tensor, kv_cache: torch.Tensor, weights: torch.Tensor, context_lens: torch.Tensor, block_tables: torch.Tensor, max_model_len: int, use_2d_context_lens: bool): @@ -494,4 +604,5 @@ def enumerate_paged_mqa_logits(): test_gemm_skip_head_mid() test_mqa_logits() + test_paged_mqa_logits_metadata_out() test_paged_mqa_logits()