Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions csrc/apis/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ static bool early_return(const int& m, const int &n, const int& k,
DG_HOST_ASSERT(c->sizes() == d.sizes() and c->strides() == d.strides());
if (c.has_value()) {
check_major_type_cd(c.value());
DG_HOST_ASSERT(d.scalar_type() == torch::kFloat);
DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat);
DG_HOST_ASSERT(d.scalar_type() == torch::kFloat or d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(c.value().scalar_type() == d.scalar_type());
}

// No accumulation
Expand Down
14 changes: 7 additions & 7 deletions csrc/jit_kernels/heuristics/sm90.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ struct SM90ArchSpec {
static std::vector<int> get_block_n_candidates(const KernelType& kernel_type, const at::ScalarType& cd_dtype) {
int start = 16;

// Avoid bank conflicts for 1D1D kernel FP32 output
// Avoid bank conflicts for 1D1D kernel (applies to both FP32 and BF16 since no D-swizzling)
std::vector<int> candidates;
if (kernel_type == KernelType::Kernel1D1D and cd_dtype == torch::kFloat) {
if (kernel_type == KernelType::Kernel1D1D) {
candidates.push_back(16);
start = 24;
}
Expand Down Expand Up @@ -63,15 +63,15 @@ struct SM90ArchSpec {
const MmaKind& mma_kind, const at::ScalarType& cd_dtype,
const int& m, const int& n, const int& k,
const int& block_m, const int& block_n, const int& block_k) {
// SM90 FP32 output does not support `block_m == 256`
if (cd_dtype == at::kFloat and block_m == 256)
// SM90 FP32 output does not support `block_m == 256`; 1D1D kernel only supports block_m <= 128
if (block_m == 256 and (cd_dtype == at::kFloat or kernel_type == KernelType::Kernel1D1D))
return false;

// Avoid large C/D shared memory for FP32 output
// Avoid large C/D shared memory
// Ensure `num_stages >= 4` (for 1D1D Kernel), `num_stages >= 3` (for No SF kernel)
if (kernel_type == KernelType::Kernel1D1D and block_n > 152)
return false;
if (block_n > 128 and cd_dtype == torch::kFloat) {
if (kernel_type == KernelType::Kernel1D1D and block_n > 152)
return false;
if (kernel_type == KernelType::KernelNoSF and block_n > 200)
return false;
}
Expand Down
10 changes: 6 additions & 4 deletions csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ static void __instantiate_kernel() {{
{}, {},
{}, {},
{},
{}, {}
{}, {}, {}
>);
}};
)",
Expand All @@ -60,6 +60,7 @@ static void __instantiate_kernel() {{
args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads,
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type),
args.gemm_config.with_accumulation,
to_string(args.gemm_config.cd_dtype));
}

Expand All @@ -82,7 +83,7 @@ static void sm90_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
DG_HOST_ASSERT(c.has_value() and d.scalar_type() == torch::kFloat);
DG_HOST_ASSERT(d.scalar_type() == torch::kFloat or d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);

const auto& config = get_best_config<SM90ArchSpec>(
Expand Down Expand Up @@ -148,7 +149,8 @@ static void sm90_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Te
const torch::Tensor& tensor_map_buffer,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
DG_HOST_ASSERT(c.has_value() and d.scalar_type() == torch::kFloat);
DG_HOST_ASSERT(d.scalar_type() == torch::kFloat or d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(c.has_value());
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);

// Get config using max K for better performance
Expand Down Expand Up @@ -188,7 +190,7 @@ static void sm90_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Te
SM90ArchSpec::get_cd_store_block_m(config.block_m, true),
SM90ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(d.stride(-2)), num_groups,
config.smem_config.swizzle_cd_mode);
0);

// Launch
const SM90FP8Gemm1D1DRuntime::Args& args = {
Expand Down
38 changes: 26 additions & 12 deletions deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
uint32_t kNumSMs,
GemmType kGemmType, typename cd_dtype_t>
GemmType kGemmType, bool kWithAccumulation, typename cd_dtype_t>
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
int* grouped_layout,
Expand All @@ -41,7 +41,7 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
// Scaling checks
DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid Threads");
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
DG_STATIC_ASSERT(cute::is_same_v<cd_dtype_t, float>, "Invalid C/D data dtype");
DG_STATIC_ASSERT(cute::is_same_v<cd_dtype_t, float> or cute::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid C/D data dtype");
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous, "Invalid GEMM type");

// Types
Expand All @@ -56,7 +56,7 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,

// Shared memory
static constexpr uint32_t SMEM_TENSOR_MAP_SIZE = (kGemmType == GemmType::KGroupedContiguous ? sizeof(cute::TmaDescriptor) * 4 : 0);
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float);
static constexpr uint32_t SMEM_D_SIZE = constexpr_align(BLOCK_M * BLOCK_N * static_cast<uint32_t>(sizeof(cd_dtype_t)), 1024u);
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
Expand Down Expand Up @@ -93,7 +93,7 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
auto gmem_tensor_map_b = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + 2 + i; });

// Data on shared memory
auto smem_d = reinterpret_cast<float*>(smem_buffer + SMEM_TENSOR_MAP_SIZE);
auto smem_d = reinterpret_cast<cd_dtype_t*>(smem_buffer + SMEM_TENSOR_MAP_SIZE);
auto smem_a = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE));
});
Expand Down Expand Up @@ -348,20 +348,34 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
cutlass::arch::NamedBarrier::sync(128, math_wg_idx);

// Store to D shared memory
const auto& smem_d_0 = reinterpret_cast<float2*>(smem_d + r_0 * BLOCK_N + col_idx * 2);
const auto& smem_d_1 = reinterpret_cast<float2*>(smem_d + r_1 * BLOCK_N + col_idx * 2);
#pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
st_shared(smem_d_0 + i * 4, {final_accum[i * 4 + 0], final_accum[i * 4 + 1]});
st_shared(smem_d_1 + i * 4, {final_accum[i * 4 + 2], final_accum[i * 4 + 3]});
if constexpr (cute::is_same_v<cd_dtype_t, float>) {
const auto& smem_d_0 = reinterpret_cast<float2*>(smem_d + r_0 * BLOCK_N + col_idx * 2);
const auto& smem_d_1 = reinterpret_cast<float2*>(smem_d + r_1 * BLOCK_N + col_idx * 2);
#pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
st_shared(smem_d_0 + i * 4, {final_accum[i * 4 + 0], final_accum[i * 4 + 1]});
st_shared(smem_d_1 + i * 4, {final_accum[i * 4 + 2], final_accum[i * 4 + 3]});
}
} else {
#pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
auto packed_0 = cast_into_bf16_and_pack(final_accum[i * 4 + 0], final_accum[i * 4 + 1]);
auto packed_1 = cast_into_bf16_and_pack(final_accum[i * 4 + 2], final_accum[i * 4 + 3]);
st_shared(reinterpret_cast<const uint32_t*>(smem_d + r_0 * BLOCK_N + col_idx * 2 + i * 8),
static_cast<uint32_t>(packed_0));
st_shared(reinterpret_cast<const uint32_t*>(smem_d + r_1 * BLOCK_N + col_idx * 2 + i * 8),
static_cast<uint32_t>(packed_1));
}
}
cute::tma_store_fence();
cutlass::arch::NamedBarrier::sync(128, math_wg_idx);

// Use TMA store to write back to global memory
if (warp_idx % 4 == 0 and cute::elect_one_sync()) {
cute::SM90_TMA_REDUCE_ADD_2D::copy(
&tensor_map_cd, smem_d_0, n_block_idx * BLOCK_N,
using tma_store_t = cute::conditional_t<kWithAccumulation,
cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
tma_store_t::copy(
&tensor_map_cd, smem_d + r_0 * BLOCK_N + col_idx * 2, n_block_idx * BLOCK_N,
current_group_idx * shape_m + m_block_idx * BLOCK_M + r_0);
cute::tma_store_arrive();
}
Expand Down