Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions deep_gemm/include/deep_gemm/epilogue/sm100_store_cd.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ sm100_store_cd(const utils::PatternVisitor<pattern_cd_t>& smem_cd, uint32_t& tma
uint32_t tmem_addr = tmem_base_addr + // Accumulator offset
w * BLOCK_N + // Wave offset
s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
#ifdef __CUDACC_DEBUG__
tmem_addr |= ((epilogue_warp_idx % 4) * 32) << 16;
#endif
auto smem_ptr = smem_base_ptr + // Base pointer
epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ sm100_store_cd_swap_ab(const utils::PatternVisitor<pattern_cd_t>& smem_cd, uint3
uint32_t tmem_addr = tmem_base_addr +
s * STORE_BLOCK_M + // Store stage offset
i * kNumSwizzleAtomRows; // In-block offset
#ifdef __CUDACC_DEBUG__
tmem_addr |= ((epilogue_warp_idx % 4) * 32) << 16;
#endif
uint32_t values[kNumSwizzleAtomRows];

// Warps cooperatively write an atomic block to shared memory
Expand Down
3 changes: 3 additions & 0 deletions deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,9 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
// Load from tensor memory, store into shared memory
uint32_t values[kNumElemsPerBankGroup];
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
#ifdef __CUDACC_DEBUG__
tmem_addr |= ((warp_idx % 4) * 32) << 16;
#endif
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
values[0], values[1], values[2], values[3]);
cutlass::arch::fence_view_async_tmem_load();
Expand Down
7 changes: 6 additions & 1 deletion deep_gemm/include/deep_gemm/impls/sm100_fp4_mqa_logits.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,13 @@ void sm100_fp4_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
using Loader = cute::conditional_t<N == 32,
cute::SM100_TMEM_LOAD_32dp32b32x,
cute::SM100_TMEM_LOAD_32dp32b64x>;
#ifdef __CUDACC_DEBUG__
uint32_t addr = tmem_addr | (((threadIdx.x / 32) % 4 * 32) << 16);
#else
const auto& addr = tmem_addr;
#endif
[&]<size_t... Is>(cute::index_sequence<Is...>) {
Loader::copy(tmem_addr, reinterpret_cast<uint32_t*>(accum)[Is]...);
Loader::copy(addr, reinterpret_cast<uint32_t*>(accum)[Is]...);
}(cute::make_index_sequence<N>{});
cutlass::arch::fence_view_async_tmem_load();
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -386,8 +386,13 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
using Loader = cute::conditional_t<N == 32,
cute::SM100_TMEM_LOAD_32dp32b32x,
cute::SM100_TMEM_LOAD_32dp32b64x>;
#ifdef __CUDACC_DEBUG__
uint32_t addr = tmem_addr | (((threadIdx.x / 32) % 4 * 32) << 16);
#else
const auto& addr = tmem_addr;
#endif
[&]<size_t... Is>(cute::index_sequence<Is...>) {
Loader::copy(tmem_addr, reinterpret_cast<uint32_t*>(accum)[Is]...);
Loader::copy(addr, reinterpret_cast<uint32_t*>(accum)[Is]...);
}(cute::make_index_sequence<N>{});
cutlass::arch::fence_view_async_tmem_load();
};
Expand Down
6 changes: 6 additions & 0 deletions deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -978,6 +978,9 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
// Load from TMEM
uint32_t tmem_addr = accum_stage_idx * UMMA_N + epilogue_wg_idx * WG_BLOCK_M + j * ATOM_M;
uint32_t values[ATOM_M];
#ifdef __CUDACC_DEBUG__
tmem_addr |= (warp_idx_in_wg * 32) << 16;
#endif
cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr,
values[0], values[1], values[2], values[3]);
cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr | 0x00100000,
Expand Down Expand Up @@ -1141,6 +1144,9 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
// Load from TMEM using .16x256b shape to satisfy STSM layout requirements
// Start from lane index 0 and 16
uint32_t tmem_addr = accum_stage_idx * UMMA_N + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M;
#ifdef __CUDACC_DEBUG__
tmem_addr |= (warp_idx_in_wg * 32) << 16;
#endif
uint32_t values[ATOM_M];
cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr,
values[0], values[1], values[2], values[3]);
Expand Down
3 changes: 3 additions & 0 deletions deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,9 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,

// Load from tensor memory, store into shared memory
uint32_t values[kNumElemsPerBankGroup];
#ifdef __CUDACC_DEBUG__
tmem_addr |= ((epilogue_warp_idx % 4) * 32) << 16;
#endif
if constexpr (cute::is_same_v<cd_dtype_t, float>) {
// For FP32 output, read and store
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
Expand Down
7 changes: 6 additions & 1 deletion deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,13 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
using Loader = cute::conditional_t<N == 32,
cute::SM100_TMEM_LOAD_32dp32b32x,
cute::SM100_TMEM_LOAD_32dp32b64x>;
#ifdef __CUDACC_DEBUG__
uint32_t addr = tmem_addr | (((threadIdx.x / 32) % 4 * 32) << 16);
#else
const auto& addr = tmem_addr;
#endif
[&]<size_t... Is>(cute::index_sequence<Is...>) {
Loader::copy(tmem_addr, reinterpret_cast<uint32_t*>(accum)[Is]...);
Loader::copy(addr, reinterpret_cast<uint32_t*>(accum)[Is]...);
}(cute::make_index_sequence<N>{});
cutlass::arch::fence_view_async_tmem_load();
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,13 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
using Loader = cute::conditional_t<N == 32,
cute::SM100_TMEM_LOAD_32dp32b32x,
cute::SM100_TMEM_LOAD_32dp32b64x>;
#ifdef __CUDACC_DEBUG__
uint32_t addr = tmem_addr | (((threadIdx.x / 32) % 4 * 32) << 16);
#else
const auto& addr = tmem_addr;
#endif
[&]<size_t... Is>(cute::index_sequence<Is...>) {
Loader::copy(tmem_addr, reinterpret_cast<uint32_t*>(accum)[Is]...);
Loader::copy(addr, reinterpret_cast<uint32_t*>(accum)[Is]...);
}(cute::make_index_sequence<N>{});
cutlass::arch::fence_view_async_tmem_load();
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,9 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
// Load from tensor memory, store into shared memory
uint32_t values[kNumElemsPerBankGroup];
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
#ifdef __CUDACC_DEBUG__
tmem_addr |= ((warp_idx % 4) * 32) << 16;
#endif
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
values[0], values[1], values[2], values[3]);
cutlass::arch::fence_view_async_tmem_load();
Expand Down