Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 38 additions & 4 deletions csrc/apis/attention.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include "../utils/compatibility.hpp"

#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
#include <cstdint>

#include "../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp"
#include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp"
#include "../jit_kernels/impls/sm100_fp8_fp4_gemm_1d1d.hpp"
Expand All @@ -16,6 +18,18 @@
namespace deep_gemm::attention {

#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
static bool has_byte_overlap(const torch::Tensor& a, const torch::Tensor& b) {
if (a.numel() == 0 or b.numel() == 0)
return false;

const auto a_begin = reinterpret_cast<uintptr_t>(a.data_ptr());
const auto b_begin = reinterpret_cast<uintptr_t>(b.data_ptr());
const auto a_end = a_begin + static_cast<uintptr_t>(a.nbytes());
const auto b_end = b_begin + static_cast<uintptr_t>(b.nbytes());

return a_begin < b_end and b_begin < a_end;
}

static void fp8_gemm_nt_skip_head_mid(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
Expand Down Expand Up @@ -195,18 +209,25 @@ static torch::Tensor fp8_fp4_mqa_logits(const std::tuple<torch::Tensor, std::opt
return logits;
}

static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_lens, int block_kv, int num_sms, const std::optional<torch::Tensor>& indices) {
static void get_paged_mqa_logits_metadata_out(const torch::Tensor& context_lens,
const torch::Tensor& schedule_metadata,
int block_kv, int num_sms,
const std::optional<torch::Tensor>& indices) {
// NOTES: Only 2D context lens is supported for now
DG_HOST_ASSERT(context_lens.dim() == 2);
const bool is_context_lens_2d = true;
const int batch_size = context_lens.size(0);
const int next_n = context_lens.size(1);
const bool is_varlen = indices.has_value();
DG_HOST_ASSERT(num_sms > 0);
DG_HOST_ASSERT(context_lens.scalar_type() == torch::kInt);
DG_HOST_ASSERT(context_lens.is_contiguous());

// Create metadata tensor
auto schedule_metadata = torch::empty({num_sms + 1, 2}, context_lens.options());
DG_HOST_ASSERT(schedule_metadata.dim() == 2);
DG_HOST_ASSERT(schedule_metadata.size(0) == num_sms + 1 and schedule_metadata.size(1) == 2);
DG_HOST_ASSERT(schedule_metadata.scalar_type() == torch::kInt);
DG_HOST_ASSERT(schedule_metadata.is_contiguous());
DG_HOST_ASSERT(schedule_metadata.device() == context_lens.device());
DG_HOST_ASSERT(not has_byte_overlap(schedule_metadata, context_lens));

// Dispatch implementation
const auto arch_major = device_runtime->get_arch_major();
Expand All @@ -216,6 +237,8 @@ static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_
DG_HOST_ASSERT(indices_tensor.dim() == 1 and indices_tensor.size(0) == batch_size);
DG_HOST_ASSERT(indices_tensor.is_contiguous());
DG_HOST_ASSERT(indices_tensor.scalar_type() == torch::kInt);
DG_HOST_ASSERT(indices_tensor.device() == context_lens.device());
DG_HOST_ASSERT(not has_byte_overlap(schedule_metadata, indices_tensor));
sm100_paged_mqa_logits_metadata(context_lens, schedule_metadata, batch_size, batch_size * next_n, next_n, num_sms, is_context_lens_2d, true, indices_tensor.data_ptr<int>());
} else if (arch_major == 10) {
DG_HOST_ASSERT(block_kv == 64 or block_kv == 32);
Expand All @@ -226,6 +249,13 @@ static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_
} else {
DG_HOST_UNREACHABLE("Unsupported architecture");
}
}

static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_lens,
int block_kv, int num_sms,
const std::optional<torch::Tensor>& indices) {
auto schedule_metadata = torch::empty({num_sms + 1, 2}, context_lens.options());
get_paged_mqa_logits_metadata_out(context_lens, schedule_metadata, block_kv, num_sms, indices);

return schedule_metadata;
}
Expand Down Expand Up @@ -352,6 +382,7 @@ static torch::Tensor fp8_fp4_paged_mqa_logits(const std::tuple<torch::Tensor, st
DG_HOST_ASSERT(indices_tensor.dim() == 1 and indices_tensor.size(0) == batch_size);
DG_HOST_ASSERT(indices_tensor.is_contiguous());
DG_HOST_ASSERT(indices_tensor.scalar_type() == torch::kInt);
DG_HOST_ASSERT(indices_tensor.device() == context_lens.device());
}

// Check schedule metadata
Expand Down Expand Up @@ -448,6 +479,9 @@ static void register_apis(pybind11::module_& m) {
m.def("get_paged_mqa_logits_metadata", &get_paged_mqa_logits_metadata,
py::arg("context_lens"), py::arg("block_kv"), py::arg("num_sms"),
py::arg("indices") = std::nullopt);
m.def("get_paged_mqa_logits_metadata_out", &get_paged_mqa_logits_metadata_out,
py::arg("context_lens"), py::arg("schedule_metadata"), py::arg("block_kv"),
py::arg("num_sms"), py::arg("indices") = std::nullopt);
m.def("fp8_fp4_paged_mqa_logits", &fp8_fp4_paged_mqa_logits,
py::arg("q"), py::arg("kv_cache"), py::arg("weights"),
py::arg("context_lens"), py::arg("block_table"), py::arg("schedule_meta"),
Expand Down
18 changes: 18 additions & 0 deletions csrc/tvm_ffi_api.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <cstdint>
#include <limits>
#include <optional>
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/container/array.h>
Expand Down Expand Up @@ -573,6 +574,8 @@ Tensor dg_fp8_mqa_logits(TensorView q, TensorView kv_data, TensorView kv_sf,

Tensor dg_get_paged_mqa_logits_metadata(TensorView context_lens, int64_t block_kv,
int64_t num_sms, Optional<TensorView> indices) {
DG_HOST_ASSERT(num_sms > 0 and num_sms <= std::numeric_limits<int>::max());
DG_HOST_ASSERT(block_kv == 64 or block_kv == 32);
auto indices_val = indices.has_value()?
std::optional<torch::Tensor>(convert_to_torch_tensor(indices.value()))
: std::nullopt;
Expand All @@ -582,6 +585,20 @@ Tensor dg_get_paged_mqa_logits_metadata(TensorView context_lens, int64_t block_k
return Tensor::FromDLPack(at::toDLPack(result));
}

void dg_get_paged_mqa_logits_metadata_out(TensorView context_lens, TensorView schedule_metadata,
int64_t block_kv, int64_t num_sms,
Optional<TensorView> indices) {
DG_HOST_ASSERT(num_sms > 0 and num_sms <= std::numeric_limits<int>::max());
DG_HOST_ASSERT(block_kv == 64 or block_kv == 32);
auto indices_val = indices.has_value() ?
std::optional<torch::Tensor>(convert_to_torch_tensor(indices.value()))
: std::nullopt;
attention::get_paged_mqa_logits_metadata_out(
convert_to_torch_tensor(context_lens),
convert_to_torch_tensor(schedule_metadata), static_cast<int>(block_kv),
static_cast<int>(num_sms), indices_val);
}

Tensor dg_fp8_paged_mqa_logits(TensorView q, TensorView fused_kv_cache,
TensorView weights, TensorView context_lens,
TensorView block_table, TensorView schedule_meta,
Expand Down Expand Up @@ -630,6 +647,7 @@ Tensor dg_fp8_fp4_paged_mqa_logits(TensorView q, Optional<TensorView> q_sf, Tens
TVM_FFI_DLL_EXPORT_TYPED_FUNC(fp8_gemm_nt_skip_head_mid, dg_fp8_gemm_nt_skip_head_mid);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(fp8_mqa_logits, dg_fp8_mqa_logits);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(get_paged_mqa_logits_metadata, dg_get_paged_mqa_logits_metadata);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(get_paged_mqa_logits_metadata_out, dg_get_paged_mqa_logits_metadata_out);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(fp8_paged_mqa_logits, dg_fp8_paged_mqa_logits);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(fp8_fp4_mqa_logits, dg_fp8_fp4_mqa_logits);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(fp8_fp4_paged_mqa_logits, dg_fp8_fp4_paged_mqa_logits);
Expand Down
1 change: 1 addition & 0 deletions deep_gemm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
# Attention kernels
fp8_fp4_mqa_logits,
get_paged_mqa_logits_metadata,
get_paged_mqa_logits_metadata_out,
fp8_fp4_paged_mqa_logits,
# Attention kernels (legacy)
fp8_mqa_logits,
Expand Down
7 changes: 7 additions & 0 deletions sgl_deep_gemm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,13 @@ def fp8_fp4_mqa_logits(q, kv, weights, cu_seq_len_k_start, cu_seq_len_k_end, cle
def get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms, indices=None):
return _C.get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms, indices)

def get_paged_mqa_logits_metadata_out(
context_lens, schedule_metadata, block_kv, num_sms, indices=None
):
_C.get_paged_mqa_logits_metadata_out(
context_lens, schedule_metadata, block_kv, num_sms, indices
)

def tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits=None):
_C.tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits)

Expand Down
111 changes: 111 additions & 0 deletions sgl_deep_gemm/tests/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,116 @@ def enumerate_mqa_logits():
print()


def _expect_paged_mqa_logits_metadata_out_failure(*args, **kwargs):
try:
deep_gemm.get_paged_mqa_logits_metadata_out(*args, **kwargs)
except Exception:
return
raise AssertionError("Expected get_paged_mqa_logits_metadata_out to fail")


def test_paged_mqa_logits_metadata_out():
if get_arch_major() not in (9, 10):
return

context_lens = torch.tensor(
[[128, 130, 132], [65, 67, 69], [200, 202, 204]],
device='cuda',
dtype=torch.int,
)
num_sms = deep_gemm.get_num_sms()

for block_kv in ((64, 32) if get_arch_major() == 10 else (64, )):
expected = deep_gemm.get_paged_mqa_logits_metadata(
context_lens, block_kv, num_sms
)
actual = torch.full_like(expected, -1)
actual_ptr = actual.data_ptr()
ret = deep_gemm.get_paged_mqa_logits_metadata_out(
context_lens, actual, block_kv, num_sms
)
torch.cuda.synchronize()

assert ret is None
assert actual.data_ptr() == actual_ptr
assert torch.equal(actual, expected)

_expect_paged_mqa_logits_metadata_out_failure(
context_lens,
torch.empty((num_sms, 2), device='cuda', dtype=torch.int),
64,
num_sms,
)
_expect_paged_mqa_logits_metadata_out_failure(
context_lens,
torch.empty((num_sms + 1, 2), device='cuda', dtype=torch.int64),
64,
num_sms,
)
_expect_paged_mqa_logits_metadata_out_failure(
context_lens, torch.empty((1, 2), device='cuda', dtype=torch.int), 64, 0
)
_expect_paged_mqa_logits_metadata_out_failure(
context_lens, torch.empty((1, 2), device='cuda', dtype=torch.int), 64, -1
)
_expect_paged_mqa_logits_metadata_out_failure(
context_lens,
torch.empty((num_sms + 1, 2), device='cuda', dtype=torch.int),
48,
num_sms,
)
_expect_paged_mqa_logits_metadata_out_failure(
context_lens,
torch.empty((num_sms + 1, 2), device='cuda', dtype=torch.int),
2**40 + 64,
num_sms,
)
_expect_paged_mqa_logits_metadata_out_failure(
context_lens,
torch.empty((num_sms + 1, 2), dtype=torch.int),
64,
num_sms,
)
alias_context_lens = torch.empty((3, 2), device='cuda', dtype=torch.int)
_expect_paged_mqa_logits_metadata_out_failure(
alias_context_lens, alias_context_lens, 64, 2
)
partial_overlap = torch.empty(7, device='cuda', dtype=torch.int)
_expect_paged_mqa_logits_metadata_out_failure(
partial_overlap[:6].view(3, 2), partial_overlap[1:].view(3, 2), 64, 2
)

if get_arch_major() == 10:
indices = torch.tensor([2, 0, 1], device='cuda', dtype=torch.int)
context_lens = context_lens[:, :1].contiguous()
expected = deep_gemm.get_paged_mqa_logits_metadata(
context_lens, 64, num_sms, indices=indices
)
actual = torch.full_like(expected, -1)
actual_ptr = actual.data_ptr()
ret = deep_gemm.get_paged_mqa_logits_metadata_out(
context_lens, actual, 64, num_sms, indices=indices
)
torch.cuda.synchronize()

assert ret is None
assert actual.data_ptr() == actual_ptr
assert torch.equal(actual, expected)
indices_overlap = torch.empty(5, device='cuda', dtype=torch.int)
_expect_paged_mqa_logits_metadata_out_failure(
context_lens,
indices_overlap[1:].view(2, 2),
64,
1,
indices=indices_overlap[:3],
)
if torch.cuda.device_count() > 1:
wrong_device_indices = indices.to('cuda:1')
_expect_paged_mqa_logits_metadata_out_failure(
context_lens, actual, 64, num_sms, indices=wrong_device_indices
)


def ref_paged_mqa_logits(q: torch.Tensor, kv_cache: torch.Tensor,
weights: torch.Tensor, context_lens: torch.Tensor, block_tables: torch.Tensor,
max_model_len: int, use_2d_context_lens: bool):
Expand Down Expand Up @@ -494,4 +604,5 @@ def enumerate_paged_mqa_logits():

test_gemm_skip_head_mid()
test_mqa_logits()
test_paged_mqa_logits_metadata_out()
test_paged_mqa_logits()
Loading