From 6b1fb7954d432d6b51e59e8af70ba6c17dd2e417 Mon Sep 17 00:00:00 2001 From: lcy-seso Date: Wed, 4 Jun 2025 10:47:44 +0000 Subject: [PATCH] adjust indent from 4 to 2 characters. --- .clang-format | 4 +- benchmarks/cpp/flashattention/convert.cuh | 60 +- benchmarks/cpp/flashattention/copy.cuh | 918 +++++++++-------- benchmarks/cpp/flashattention/cutlass_fa.cuh | 803 ++++++++------- benchmarks/cpp/flashattention/main.cu | 158 +-- benchmarks/cpp/flashattention/reduce.cuh | 99 +- benchmarks/cpp/flashattention/util.hpp | 29 +- benchmarks/cpp/fused_two_gemms/bench.cu | 366 ++++--- .../cutlass_fused_two_gemms.cuh | 365 ++++--- benchmarks/cpp/fused_two_gemms/util.cuh | 186 ++-- benchmarks/cpp/g2s_copy/cutlass_copy.cuh | 178 ++-- benchmarks/cpp/g2s_copy/main.cu | 208 ++-- benchmarks/cpp/g2s_copy/tilefusion_copy.cuh | 24 +- benchmarks/cpp/gemm/bench.cu | 335 +++--- benchmarks/cpp/gemm/cutlass_gemm.cuh | 194 ++-- benchmarks/cpp/gemm/tilefusion_gemm.cuh | 195 ++-- benchmarks/cpp/gemm/util.cuh | 146 +-- benchmarks/utils/cpp/cuda_utils.cuh | 78 +- benchmarks/utils/cpp/cutlass/convert.cuh | 60 +- benchmarks/utils/cpp/cutlass/copy.cuh | 204 ++-- benchmarks/utils/cpp/cutlass/traits_base.cuh | 8 +- examples/101_gemm/01_gemm_global_reg/gemm.hpp | 100 +- examples/101_gemm/01_gemm_global_reg/main.cu | 195 ++-- examples/101_gemm/02_gemm_all_mem/gemm.hpp | 218 ++-- examples/101_gemm/02_gemm_all_mem/main.cu | 219 ++-- examples/101_gemm/util.hpp | 124 +-- include/cell/compute/broadcast.hpp | 194 ++-- include/cell/compute/gemm.hpp | 284 +++--- include/cell/compute/map.hpp | 63 +- include/cell/compute/math_functor.hpp | 176 ++-- include/cell/compute/reduce.hpp | 198 ++-- include/cell/copy/constants.hpp | 12 +- include/cell/copy/copy_atom.hpp | 293 +++--- include/cell/copy/global_to_register.hpp | 340 +++---- include/cell/copy/global_to_shared.hpp | 822 ++++++++------- include/cell/copy/register.hpp | 22 +- include/cell/copy/shared_to_register.hpp | 566 +++++------ include/cell/copy/sync.hpp | 8 +- include/cell/copy/vectorize.hpp | 70 +- include/cell/copy/warp.hpp | 403 ++++---- include/cell/mask.hpp | 166 ++- include/cell/pipeline.hpp | 158 +-- include/cell/warp.hpp | 4 +- include/config.hpp | 20 +- include/cuda_utils.hpp | 64 +- include/jit/common.hpp | 30 +- include/jit/compiler.hpp | 100 +- include/jit/config.hpp | 32 +- include/kernel_registry.hpp | 101 +- include/kernels/common.hpp | 8 +- include/kernels/dispatch_macros.hpp | 33 +- include/kernels/flash_attention_device.cuh | 476 +++++---- include/kernels/fused_two_gemms_device.cuh | 368 ++++--- include/kernels/gemm_device.cuh | 646 ++++++------ include/types/base.hpp | 130 ++- include/types/base_tile.hpp | 138 +-- include/types/global.hpp | 70 +- include/types/global_tile_iterator.hpp | 254 ++--- include/types/layout.hpp | 176 ++-- include/types/packing.hpp | 12 +- include/types/register.hpp | 86 +- include/types/shared.hpp | 190 ++-- include/types/shared_tile_iterator.hpp | 379 ++++--- include/types/swizzle.hpp | 266 ++--- include/types/tile_shape.hpp | 14 +- include/util/cuda_info.hpp | 28 +- include/util/cuda_timer.hpp | 49 +- include/util/debug.hpp | 12 +- include/util/math_utils.hpp | 38 +- include/util/print.hpp | 210 ++-- src/cuda_info.cc | 111 +- src/cuda_utils.cc | 46 +- src/jit/compiler.cc | 354 ++++--- src/kernels/flash_attn.cu | 228 ++--- src/kernels/fused_two_gemms.cu | 181 ++-- src/kernels/gemm.cu | 203 ++-- src/kernels/scatter_nd.cu | 129 ++- src/torch_bind.cc | 4 +- tests/cpp/cell/test_broadcast.cu | 116 +-- tests/cpp/cell/test_flash_attn.cu | 235 +++-- tests/cpp/cell/test_g2r_copy.cu | 426 ++++---- tests/cpp/cell/test_g2s_load.cu | 436 ++++---- tests/cpp/cell/test_gemm.cu | 544 +++++----- tests/cpp/cell/test_reduce.cu | 260 ++--- tests/cpp/cell/test_s2r_copy.cu | 405 ++++---- tests/cpp/cell/test_single_wmma.cu | 272 +++-- tests/cpp/cell/test_swizzled_copy.cu | 962 +++++++++--------- tests/cpp/common/test_utils.cc | 24 +- tests/cpp/jit/test_jit.cc | 153 ++- tests/cpp/test_unit.cc | 8 +- tests/cpp/types/test_fp8.cu | 483 +++++---- tests/cpp/types/test_gtile_iterator.cu | 184 ++-- tests/cpp/types/test_layout.cu | 146 +-- tests/cpp/types/test_stile_iterator.cu | 276 ++--- tests/cpp/types/test_swizzle.cu | 116 +-- tests/cpp/types/test_warp_base_tile_shape.cu | 240 ++--- 96 files changed, 9772 insertions(+), 9953 deletions(-) diff --git a/.clang-format b/.clang-format index 1261a85a..5423c9bc 100644 --- a/.clang-format +++ b/.clang-format @@ -3,9 +3,9 @@ BasedOnStyle: Google UseTab: Never ColumnLimit: 80 -IndentWidth: 4 +IndentWidth: 2 -AccessModifierOffset: -2 +AccessModifierOffset: -1 DerivePointerAlignment: false PointerAlignment: Left diff --git a/benchmarks/cpp/flashattention/convert.cuh b/benchmarks/cpp/flashattention/convert.cuh index 703513db..1f7e5510 100644 --- a/benchmarks/cpp/flashattention/convert.cuh +++ b/benchmarks/cpp/flashattention/convert.cuh @@ -16,36 +16,36 @@ using namespace cute; template CUTE_DEVICE auto convert_type(cute::Tensor const& tensor) { - using From_type = typename Engine::value_type; - constexpr int numel = decltype(size(tensor))::value; - cutlass::NumericArrayConverter convert_op; - auto frag = - convert_op(*reinterpret_cast*>( - tensor.data())); - return make_tensor(make_rmem_ptr(&frag), tensor.layout()); + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + auto frag = + convert_op(*reinterpret_cast*>( + tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); } template DEVICE auto convert_layout_rowcol_Aregs(Layout rowcol_layout) { - using namespace cute; - static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2); - static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2); - auto l = logical_divide(rowcol_layout, - Shape>>{}); + using namespace cute; + static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2); + static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2); + auto l = logical_divide(rowcol_layout, + Shape>>{}); - return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), - get<0>(get<1>(get<1>(l)))), - get<1>(get<0>(l)), get<1>(get<1>(get<1>(l)))); + return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), + get<0>(get<1>(get<1>(l)))), + get<1>(get<0>(l)), get<1>(get<1>(get<1>(l)))); } DEVICE auto convert_layout_C_Aregs() { - using namespace cute; - auto layout_s = Layout, _2, _16>>{}; - auto l = logical_divide(layout_s, Shape{}); + using namespace cute; + auto layout_s = Layout, _2, _16>>{}; + auto l = logical_divide(layout_s, Shape{}); - return make_layout( - make_layout(get<0>(get<0>(l)), get<1>(get<0>(l)), get<0>(get<2>(l))), - get<1>(l), get<1>(get<2>(l))); + return make_layout( + make_layout(get<0>(get<0>(l)), get<1>(get<0>(l)), get<0>(get<2>(l))), + get<1>(l), get<1>(get<2>(l))); } /** @@ -53,21 +53,21 @@ DEVICE auto convert_layout_C_Aregs() { */ template DEVICE auto convert_layout_scores(LayoutType layout_s) { - using namespace cute; - static_assert(decltype(size<0>(layout_s))::value == 4); - static_assert(decltype(rank(layout_s))::value == 3); + using namespace cute; + static_assert(decltype(size<0>(layout_s))::value == 4); + static_assert(decltype(rank(layout_s))::value == 3); - auto l = logical_divide(layout_s, Shape<_2>{}); - return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), - make_layout(get<0>(get<0>(l)), get<2>(l))); + auto l = logical_divide(layout_s, Shape<_2>{}); + return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), + make_layout(get<0>(get<0>(l)), get<2>(l))); } template DEVICE auto convert_layout_scores_copyview(LayoutType layout_s) { - using namespace cute; + using namespace cute; - auto l = logical_divide(layout_s, Shape>{}); - return make_layout(get<0>(get<1>(l)), get<0>(l), get<1>(get<1>(l))); + auto l = logical_divide(layout_s, Shape>{}); + return make_layout(get<0>(get<1>(l)), get<0>(l), get<1>(get<1>(l))); } } // namespace cutlass_wrapper diff --git a/benchmarks/cpp/flashattention/copy.cuh b/benchmarks/cpp/flashattention/copy.cuh index d25de1d5..f5baac72 100644 --- a/benchmarks/cpp/flashattention/copy.cuh +++ b/benchmarks/cpp/flashattention/copy.cuh @@ -17,245 +17,245 @@ namespace detail { template class G2SCopyQK { - public: - DEVICE G2SCopyQK(GQTensor& gQ, SQTensor& sQ, GKTensor& gK, SKTensor& sK, - TiledCopy tiled_copy, int gQ_stride, int sQ_stride, - int gK_stride, int sK_stride, int num_stage = 2) - : gQ(gQ), - sQ(sQ), - gK(gK), - sK(sK), - gQ_stride(gQ_stride), - sQ_stride(sQ_stride), - gK_stride(gK_stride), - sK_stride(sK_stride), - cur_iter(0), - cur_iter_sk(0), - num_stage(num_stage) {} - - /** - * @brief Update the pointer of the global K tensor. - * - * Since the K matrix is split along both the n and k dimensions, the - * pointer offset for the K matrix needs to be updated to the next kTN * kK - * position during the next n dimension iteration. - * - * @param gK_slice The stride in N dimension. - * @param gK_stride The stride in K dimension. - */ - DEVICE void update_tile_K(int gK_slice, int gK_stride) { - gK.data() = gK.data() + (-gK_stride) + gK_slice * gK_stride; - } - - /** - * @brief Reset the pointer of the global K tensor. - * - * The current function is called when `load_q_once` is true, i.e., when - * kTK == kK. In this case, the pointer of Q needs to be restored to the - * starting position. - * - * @param stride The stride in K dimension. - */ - DEVICE void reset_tile_Q(int stride) { sQ.data() = sQ.data() + (-stride); } - - /** - * @brief Preload the K matrix. When `load_q_once` is true, the Q matrix - * only needs to be loaded once and does not require repeated loading, while - * the K matrix needs to be updated and loaded. - */ - DEVICE void prologue_K() { + public: + DEVICE G2SCopyQK(GQTensor& gQ, SQTensor& sQ, GKTensor& gK, SKTensor& sK, + TiledCopy tiled_copy, int gQ_stride, int sQ_stride, + int gK_stride, int sK_stride, int num_stage = 2) + : gQ(gQ), + sQ(sQ), + gK(gK), + sK(sK), + gQ_stride(gQ_stride), + sQ_stride(sQ_stride), + gK_stride(gK_stride), + sK_stride(sK_stride), + cur_iter(0), + cur_iter_sk(0), + num_stage(num_stage) {} + + /** + * @brief Update the pointer of the global K tensor. + * + * Since the K matrix is split along both the n and k dimensions, the + * pointer offset for the K matrix needs to be updated to the next kTN * kK + * position during the next n dimension iteration. + * + * @param gK_slice The stride in N dimension. + * @param gK_stride The stride in K dimension. + */ + DEVICE void update_tile_K(int gK_slice, int gK_stride) { + gK.data() = gK.data() + (-gK_stride) + gK_slice * gK_stride; + } + + /** + * @brief Reset the pointer of the global K tensor. + * + * The current function is called when `load_q_once` is true, i.e., when + * kTK == kK. In this case, the pointer of Q needs to be restored to the + * starting position. + * + * @param stride The stride in K dimension. + */ + DEVICE void reset_tile_Q(int stride) { sQ.data() = sQ.data() + (-stride); } + + /** + * @brief Preload the K matrix. When `load_q_once` is true, the Q matrix + * only needs to be loaded once and does not require repeated loading, while + * the K matrix needs to be updated and loaded. + */ + DEVICE void prologue_K() { #pragma unroll - for (int m = 0; m < size<1>(gK); ++m) { + for (int m = 0; m < size<1>(gK); ++m) { #pragma unroll - for (int k = 0; k < size<2>(gK); ++k) { - cute::copy(tiled_copy, gK(_, m, k), sK(_, m, k)); - } - } - - cute::cp_async_fence(); + for (int k = 0; k < size<2>(gK); ++k) { + cute::copy(tiled_copy, gK(_, m, k), sK(_, m, k)); + } + } - gK.data() = gK.data() + gK_stride; - sK.data() = sK.data() + sK_stride; + cute::cp_async_fence(); - if ((cur_iter_sk + 1) % num_stage == 0) { - sK.data() = sK.data() + (-sK_stride * num_stage); - } + gK.data() = gK.data() + gK_stride; + sK.data() = sK.data() + sK_stride; - cur_iter_sk++; + if ((cur_iter_sk + 1) % num_stage == 0) { + sK.data() = sK.data() + (-sK_stride * num_stage); } - DEVICE void prologue() { + cur_iter_sk++; + } + + DEVICE void prologue() { #pragma unroll - for (int m = 0; m < size<1>(gQ); ++m) { + for (int m = 0; m < size<1>(gQ); ++m) { #pragma unroll - for (int k = 0; k < size<2>(gQ); ++k) { - cute::copy(tiled_copy, gQ(_, m, k), sQ(_, m, k)); - } - } + for (int k = 0; k < size<2>(gQ); ++k) { + cute::copy(tiled_copy, gQ(_, m, k), sQ(_, m, k)); + } + } #pragma unroll - for (int m = 0; m < size<1>(gK); ++m) { + for (int m = 0; m < size<1>(gK); ++m) { #pragma unroll - for (int k = 0; k < size<2>(gK); ++k) { - cute::copy(tiled_copy, gK(_, m, k), sK(_, m, k)); - } - } - - cute::cp_async_fence(); + for (int k = 0; k < size<2>(gK); ++k) { + cute::copy(tiled_copy, gK(_, m, k), sK(_, m, k)); + } + } - gQ.data() = gQ.data() + gQ_stride; - sQ.data() = sQ.data() + sQ_stride; - gK.data() = gK.data() + gK_stride; - sK.data() = sK.data() + sK_stride; + cute::cp_async_fence(); - // Circlically read SMEM Buffer - if ((cur_iter + 1) % num_stage == 0) { - sQ.data() = sQ.data() + (-sQ_stride * num_stage); - sK.data() = sK.data() + (-sK_stride * num_stage); - } + gQ.data() = gQ.data() + gQ_stride; + sQ.data() = sQ.data() + sQ_stride; + gK.data() = gK.data() + gK_stride; + sK.data() = sK.data() + sK_stride; - cur_iter++; + // Circlically read SMEM Buffer + if ((cur_iter + 1) % num_stage == 0) { + sQ.data() = sQ.data() + (-sQ_stride * num_stage); + sK.data() = sK.data() + (-sK_stride * num_stage); } - DEVICE void body() { + cur_iter++; + } + + DEVICE void body() { #pragma unroll - for (int m = 0; m < size<1>(gQ); ++m) { + for (int m = 0; m < size<1>(gQ); ++m) { #pragma unroll - for (int k = 0; k < size<2>(gQ); ++k) { - cute::copy(tiled_copy, gQ(_, m, k), sQ(_, m, k)); - } - } + for (int k = 0; k < size<2>(gQ); ++k) { + cute::copy(tiled_copy, gQ(_, m, k), sQ(_, m, k)); + } + } #pragma unroll - for (int m = 0; m < size<1>(gK); ++m) { + for (int m = 0; m < size<1>(gK); ++m) { #pragma unroll - for (int k = 0; k < size<2>(gK); ++k) { - cute::copy(tiled_copy, gK(_, m, k), sK(_, m, k)); - } - } - - cute::cp_async_fence(); + for (int k = 0; k < size<2>(gK); ++k) { + cute::copy(tiled_copy, gK(_, m, k), sK(_, m, k)); + } + } - gQ.data() = gQ.data() + gQ_stride; - sQ.data() = sQ.data() + sQ_stride; - gK.data() = gK.data() + gK_stride; - sK.data() = sK.data() + sK_stride; + cute::cp_async_fence(); - if ((cur_iter + 1) % num_stage == 0) { - sQ.data() = sQ.data() + (-sQ_stride * num_stage); - sK.data() = sK.data() + (-sK_stride * num_stage); - } + gQ.data() = gQ.data() + gQ_stride; + sQ.data() = sQ.data() + sQ_stride; + gK.data() = gK.data() + gK_stride; + sK.data() = sK.data() + sK_stride; - cur_iter++; + if ((cur_iter + 1) % num_stage == 0) { + sQ.data() = sQ.data() + (-sQ_stride * num_stage); + sK.data() = sK.data() + (-sK_stride * num_stage); } - DEVICE void epilogue() { + cur_iter++; + } + + DEVICE void epilogue() { #pragma unroll - for (int m = 0; m < size<1>(gQ); ++m) { + for (int m = 0; m < size<1>(gQ); ++m) { #pragma unroll - for (int k = 0; k < size<2>(gQ); ++k) { - cute::copy(tiled_copy, gQ(_, m, k), sQ(_, m, k)); - } - } + for (int k = 0; k < size<2>(gQ); ++k) { + cute::copy(tiled_copy, gQ(_, m, k), sQ(_, m, k)); + } + } #pragma unroll - for (int m = 0; m < size<1>(gK); ++m) { + for (int m = 0; m < size<1>(gK); ++m) { #pragma unroll - for (int k = 0; k < size<2>(gK); ++k) { - cute::copy(tiled_copy, gK(_, m, k), sK(_, m, k)); - } - } - - cute::cp_async_fence(); + for (int k = 0; k < size<2>(gK); ++k) { + cute::copy(tiled_copy, gK(_, m, k), sK(_, m, k)); + } } - private: - GQTensor& gQ; - SQTensor& sQ; - GKTensor& gK; - SKTensor& sK; - TiledCopy tiled_copy; - int gQ_stride; - int sQ_stride; - int gK_stride; - int sK_stride; - int cur_iter; - int cur_iter_sk; - int num_stage; + cute::cp_async_fence(); + } + + private: + GQTensor& gQ; + SQTensor& sQ; + GKTensor& gK; + SKTensor& sK; + TiledCopy tiled_copy; + int gQ_stride; + int sQ_stride; + int gK_stride; + int sK_stride; + int cur_iter; + int cur_iter_sk; + int num_stage; }; template class G2SCopyV { - public: - DEVICE G2SCopyV(GVTensor& gV, SVTensor& sV, TiledCopy tiled_copy, - int gV_stride, int sV_stride, int num_stage = 2) - : gV(gV), - sV(sV), - gV_stride(gV_stride), - sV_stride(sV_stride), - cur_iter(0), - num_stage(num_stage) {} - - DEVICE void prologue() { + public: + DEVICE G2SCopyV(GVTensor& gV, SVTensor& sV, TiledCopy tiled_copy, + int gV_stride, int sV_stride, int num_stage = 2) + : gV(gV), + sV(sV), + gV_stride(gV_stride), + sV_stride(sV_stride), + cur_iter(0), + num_stage(num_stage) {} + + DEVICE void prologue() { #pragma unroll - for (int m = 0; m < size<1>(gV); ++m) { + for (int m = 0; m < size<1>(gV); ++m) { #pragma unroll - for (int k = 0; k < size<2>(gV); ++k) { - cute::copy(tiled_copy, gV(_, m, k), sV(_, m, k)); - } - } - - cute::cp_async_fence(); - gV.data() = gV.data() + gV_stride; - sV.data() = sV.data() + sV_stride; + for (int k = 0; k < size<2>(gV); ++k) { + cute::copy(tiled_copy, gV(_, m, k), sV(_, m, k)); + } + } - if ((cur_iter + 1) % num_stage == 0) { - sV.data() = sV.data() + (-sV_stride * num_stage); - } + cute::cp_async_fence(); + gV.data() = gV.data() + gV_stride; + sV.data() = sV.data() + sV_stride; - cur_iter++; + if ((cur_iter + 1) % num_stage == 0) { + sV.data() = sV.data() + (-sV_stride * num_stage); } - DEVICE void body() { + cur_iter++; + } + + DEVICE void body() { #pragma unroll - for (int m = 0; m < size<1>(gV); ++m) { + for (int m = 0; m < size<1>(gV); ++m) { #pragma unroll - for (int k = 0; k < size<2>(gV); ++k) { - cute::copy(tiled_copy, gV(_, m, k), sV(_, m, k)); - } - } - - cute::cp_async_fence(); + for (int k = 0; k < size<2>(gV); ++k) { + cute::copy(tiled_copy, gV(_, m, k), sV(_, m, k)); + } + } - gV.data() = gV.data() + gV_stride; - sV.data() = sV.data() + sV_stride; + cute::cp_async_fence(); - if ((cur_iter + 1) % num_stage == 0) { - sV.data() = sV.data() + (-sV_stride * num_stage); - } + gV.data() = gV.data() + gV_stride; + sV.data() = sV.data() + sV_stride; - cur_iter++; + if ((cur_iter + 1) % num_stage == 0) { + sV.data() = sV.data() + (-sV_stride * num_stage); } - DEVICE void epilogue() { + cur_iter++; + } + + DEVICE void epilogue() { #pragma unroll - for (int m = 0; m < size<1>(gV); ++m) { + for (int m = 0; m < size<1>(gV); ++m) { #pragma unroll - for (int k = 0; k < size<2>(gV); ++k) { - cute::copy(tiled_copy, gV(_, m, k), sV(_, m, k)); - } - } - cute::cp_async_fence(); + for (int k = 0; k < size<2>(gV); ++k) { + cute::copy(tiled_copy, gV(_, m, k), sV(_, m, k)); + } } - - private: - GVTensor& gV; - SVTensor& sV; - TiledCopy tiled_copy; - int gV_stride; - int sV_stride; - int cur_iter; - int num_stage; + cute::cp_async_fence(); + } + + private: + GVTensor& gV; + SVTensor& sV; + TiledCopy tiled_copy; + int gV_stride; + int sV_stride; + int cur_iter; + int num_stage; }; template class S2RPipelineQK { - public: - DEVICE S2RPipelineQK(SQTensor& sQ, RQMmaView& rQ_mma_view, - RQCopyView& rQ_copy_view, SKTensor& sK, - RKMmaView& rK_mma_view, RKCopyView& rK_copy_view, - RAccTensor& acc, TiledCopyQ copy_q, TiledCopyK copy_k, - TiledMma tiled_mma, int sQ_stride, int sK_stride, - int num_stage = 2) - : sQ(sQ), - rQ_mma_view(rQ_mma_view), - rQ_copy_view(rQ_copy_view), - sK(sK), - rK_mma_view(rK_mma_view), - rK_copy_view(rK_copy_view), - acc(acc), - copy_q(copy_q), - copy_k(copy_k), - tiled_mma(tiled_mma), - sQ_stride(sQ_stride), - sK_stride(sK_stride), - num_stage(num_stage), - cur_iter(0), - cur_iter_sq(0) {} - - DEVICE void prologue() { - cur_iter = 0; + public: + DEVICE S2RPipelineQK(SQTensor& sQ, RQMmaView& rQ_mma_view, + RQCopyView& rQ_copy_view, SKTensor& sK, + RKMmaView& rK_mma_view, RKCopyView& rK_copy_view, + RAccTensor& acc, TiledCopyQ copy_q, TiledCopyK copy_k, + TiledMma tiled_mma, int sQ_stride, int sK_stride, + int num_stage = 2) + : sQ(sQ), + rQ_mma_view(rQ_mma_view), + rQ_copy_view(rQ_copy_view), + sK(sK), + rK_mma_view(rK_mma_view), + rK_copy_view(rK_copy_view), + acc(acc), + copy_q(copy_q), + copy_k(copy_k), + tiled_mma(tiled_mma), + sQ_stride(sQ_stride), + sK_stride(sK_stride), + num_stage(num_stage), + cur_iter(0), + cur_iter_sq(0) {} + + DEVICE void prologue() { + cur_iter = 0; + cute::copy(copy_q, sQ(_, _, _0{}), rQ_copy_view(_, _, _0{})); + cute::copy(copy_k, sK(_, _, _0{}), rK_copy_view(_, _, _0{})); + +#pragma unroll + for (int i = 0; i < size<2>(rK_mma_view); ++i) { + if (i < size<2>(rK_mma_view) - 1) { cute::copy(copy_q, sQ(_, _, _0{}), rQ_copy_view(_, _, _0{})); cute::copy(copy_k, sK(_, _, _0{}), rK_copy_view(_, _, _0{})); + } + cute::gemm(tiled_mma, rQ_mma_view(_, _, i), rK_mma_view(_, _, i), acc); + } + sQ.data() = sQ.data() + sQ_stride; + sK.data() = sK.data() + sK_stride; + + cur_iter++; + } + + DEVICE void body() { + cute::copy(copy_q, sQ(_, _, _0{}), rQ_copy_view(_, _, _0{})); + cute::copy(copy_k, sK(_, _, _0{}), rK_copy_view(_, _, _0{})); #pragma unroll - for (int i = 0; i < size<2>(rK_mma_view); ++i) { - if (i < size<2>(rK_mma_view) - 1) { - cute::copy(copy_q, sQ(_, _, _0{}), rQ_copy_view(_, _, _0{})); - cute::copy(copy_k, sK(_, _, _0{}), rK_copy_view(_, _, _0{})); - } - cute::gemm(tiled_mma, rQ_mma_view(_, _, i), rK_mma_view(_, _, i), - acc); - } - sQ.data() = sQ.data() + sQ_stride; - sK.data() = sK.data() + sK_stride; - - cur_iter++; + for (int i = 0; i < size<2>(rK_mma_view); ++i) { + if (i < size<2>(rK_mma_view) - 1) { + cute::copy(copy_q, sQ(_, _, i + 1), rQ_copy_view(_, _, i + 1)); + cute::copy(copy_k, sK(_, _, i + 1), rK_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, rQ_mma_view(_, _, i), rK_mma_view(_, _, i), acc); } + sQ.data() = sQ.data() + sQ_stride; + sK.data() = sK.data() + sK_stride; - DEVICE void body() { - cute::copy(copy_q, sQ(_, _, _0{}), rQ_copy_view(_, _, _0{})); - cute::copy(copy_k, sK(_, _, _0{}), rK_copy_view(_, _, _0{})); + if ((cur_iter + 1) % num_stage == 0) { + sK.data() = sK.data() + (-sK_stride * num_stage); + } + + cur_iter++; + cur_iter_sq++; + } + + DEVICE void epilogue() { + cute::copy(copy_q, sQ(_, _, _0{}), rQ_copy_view(_, _, _0{})); + cute::copy(copy_k, sK(_, _, _0{}), rK_copy_view(_, _, _0{})); #pragma unroll - for (int i = 0; i < size<2>(rK_mma_view); ++i) { - if (i < size<2>(rK_mma_view) - 1) { - cute::copy(copy_q, sQ(_, _, i + 1), rQ_copy_view(_, _, i + 1)); - cute::copy(copy_k, sK(_, _, i + 1), rK_copy_view(_, _, i + 1)); - } - cute::gemm(tiled_mma, rQ_mma_view(_, _, i), rK_mma_view(_, _, i), - acc); - } - sQ.data() = sQ.data() + sQ_stride; - sK.data() = sK.data() + sK_stride; - - if ((cur_iter + 1) % num_stage == 0) { - sK.data() = sK.data() + (-sK_stride * num_stage); - } - - cur_iter++; - cur_iter_sq++; + for (int i = 0; i < size<2>(rK_mma_view); ++i) { + if (i < size<2>(rK_mma_view) - 1) { + cute::copy(copy_q, sQ(_, _, i + 1), rQ_copy_view(_, _, i + 1)); + cute::copy(copy_k, sK(_, _, i + 1), rK_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, rQ_mma_view(_, _, i), rK_mma_view(_, _, i), acc); } - DEVICE void epilogue() { - cute::copy(copy_q, sQ(_, _, _0{}), rQ_copy_view(_, _, _0{})); - cute::copy(copy_k, sK(_, _, _0{}), rK_copy_view(_, _, _0{})); + sQ.data() = sQ.data() + (-sQ_stride * cur_iter_sq); + sK.data() = sK.data() + sK_stride; -#pragma unroll - for (int i = 0; i < size<2>(rK_mma_view); ++i) { - if (i < size<2>(rK_mma_view) - 1) { - cute::copy(copy_q, sQ(_, _, i + 1), rQ_copy_view(_, _, i + 1)); - cute::copy(copy_k, sK(_, _, i + 1), rK_copy_view(_, _, i + 1)); - } - cute::gemm(tiled_mma, rQ_mma_view(_, _, i), rK_mma_view(_, _, i), - acc); - } - - sQ.data() = sQ.data() + (-sQ_stride * cur_iter_sq); - sK.data() = sK.data() + sK_stride; - - if ((cur_iter + 1) % num_stage == 0) { - sK.data() = sK.data() + (-sK_stride * num_stage); - } - - cur_iter++; - cur_iter_sq = 0; + if ((cur_iter + 1) % num_stage == 0) { + sK.data() = sK.data() + (-sK_stride * num_stage); } - private: - SQTensor& sQ; - RQMmaView& rQ_mma_view; - RQCopyView& rQ_copy_view; - SKTensor& sK; - RKMmaView& rK_mma_view; - RKCopyView& rK_copy_view; - RAccTensor& acc; - TiledCopyQ copy_q; - TiledCopyK copy_k; - TiledMma tiled_mma; - int sQ_stride; - int sK_stride; - int num_stage; - int cur_iter; - int cur_iter_sq; + cur_iter++; + cur_iter_sq = 0; + } + + private: + SQTensor& sQ; + RQMmaView& rQ_mma_view; + RQCopyView& rQ_copy_view; + SKTensor& sK; + RKMmaView& rK_mma_view; + RKCopyView& rK_copy_view; + RAccTensor& acc; + TiledCopyQ copy_q; + TiledCopyK copy_k; + TiledMma tiled_mma; + int sQ_stride; + int sK_stride; + int num_stage; + int cur_iter; + int cur_iter_sq; }; template class S2RPipelineV { - public: - DEVICE S2RPipelineV(SVTensor& sV, RVMmaView& rV_mma_view, - RVCopyView& rV_copy_view, RegAcc& acc, - TiledCopy tiled_copy, TiledMma tiled_mma, int sV_stride, - int num_stage = 2) - : sV(sV), - rV_mma_view(rV_mma_view), - rV_copy_view(rV_copy_view), - acc(acc), - tiled_copy(tiled_copy), - sV_stride(sV_stride), - num_stage(num_stage), - cur_iter(0), - cur_iter_sv(0) {} - - template - DEVICE void prologue(RegValue& value) { - cur_iter = 0; - cute::copy(tiled_copy, sV(_, _, _0{}), rV_copy_view(_, _, _0{})); + public: + DEVICE S2RPipelineV(SVTensor& sV, RVMmaView& rV_mma_view, + RVCopyView& rV_copy_view, RegAcc& acc, + TiledCopy tiled_copy, TiledMma tiled_mma, int sV_stride, + int num_stage = 2) + : sV(sV), + rV_mma_view(rV_mma_view), + rV_copy_view(rV_copy_view), + acc(acc), + tiled_copy(tiled_copy), + sV_stride(sV_stride), + num_stage(num_stage), + cur_iter(0), + cur_iter_sv(0) {} + + template + DEVICE void prologue(RegValue& value) { + cur_iter = 0; + cute::copy(tiled_copy, sV(_, _, _0{}), rV_copy_view(_, _, _0{})); #pragma unroll - for (int i = 0; i < size<2>(rV_mma_view); ++i) { - if (i < size<2>(rV_mma_view) - 1) { - cute::copy(tiled_copy, sV(_, _, i + 1), - rV_copy_view(_, _, i + 1)); - } - // TODO(KuangjuX): Understand this code. Why do we need to use - // `value(_, _, cur_iter * size<2>(rV_mma_view) + i)`? - cute::gemm(tiled_mma, - value(_, _, cur_iter * size<2>(rV_mma_view) + i), - rV_mma_view(_, _, i), acc); - } - - sV.data() = sV.data() + sV_stride; - cur_iter++; + for (int i = 0; i < size<2>(rV_mma_view); ++i) { + if (i < size<2>(rV_mma_view) - 1) { + cute::copy(tiled_copy, sV(_, _, i + 1), rV_copy_view(_, _, i + 1)); + } + // TODO(KuangjuX): Understand this code. Why do we need to use + // `value(_, _, cur_iter * size<2>(rV_mma_view) + i)`? + cute::gemm(tiled_mma, value(_, _, cur_iter * size<2>(rV_mma_view) + i), + rV_mma_view(_, _, i), acc); } - template - DEVICE void body(RegValue& value) { - cute::copy(tiled_copy, sV(_, _, _0{}), rV_copy_view(_, _, _0{})); + sV.data() = sV.data() + sV_stride; + cur_iter++; + } + + template + DEVICE void body(RegValue& value) { + cute::copy(tiled_copy, sV(_, _, _0{}), rV_copy_view(_, _, _0{})); #pragma unroll - for (int i = 0; i < size<2>(rV_mma_view); ++i) { - if (i < size<2>(rV_mma_view) - 1) { - cute::copy(tiled_copy, sV(_, _, i + 1), - rV_copy_view(_, _, i + 1)); - } - cute::gemm(tiled_mma, - value(_, _, cur_iter * size<2>(rV_mma_view) + i), - rV_mma_view(_, _, i), acc); - } - - sV.data() = sV.data() + sV_stride; - if ((cur_iter + 1) % num_stage == 0) { - sV.data() = sV.data() + (-sV_stride * num_stage); - } - - cur_iter++; - cur_iter_sv++; + for (int i = 0; i < size<2>(rV_mma_view); ++i) { + if (i < size<2>(rV_mma_view) - 1) { + cute::copy(tiled_copy, sV(_, _, i + 1), rV_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, value(_, _, cur_iter * size<2>(rV_mma_view) + i), + rV_mma_view(_, _, i), acc); + } + + sV.data() = sV.data() + sV_stride; + if ((cur_iter + 1) % num_stage == 0) { + sV.data() = sV.data() + (-sV_stride * num_stage); } - template - DEVICE void epilogue(RegValue& value) { - cute::copy(tiled_copy, sV(_, _, _0{}), rV_copy_view(_, _, _0{})); + cur_iter++; + cur_iter_sv++; + } + + template + DEVICE void epilogue(RegValue& value) { + cute::copy(tiled_copy, sV(_, _, _0{}), rV_copy_view(_, _, _0{})); #pragma unroll - for (int i = 0; i < size<2>(rV_mma_view); ++i) { - if (i < size<2>(rV_mma_view) - 1) { - cute::copy(tiled_copy, sV(_, _, i + 1), - rV_copy_view(_, _, i + 1)); - } - cute::gemm(tiled_mma, - value(_, _, cur_iter * size<2>(rV_mma_view) + i), - rV_mma_view(_, _, i), acc); - } - - sV.data() = sV.data() + (-sV_stride * cur_iter_sv); - - if ((cur_iter + 1) % num_stage == 0) { - sV.data() = sV.data() + (-sV_stride * num_stage); - } - - cur_iter++; - cur_iter_sv = 0; + for (int i = 0; i < size<2>(rV_mma_view); ++i) { + if (i < size<2>(rV_mma_view) - 1) { + cute::copy(tiled_copy, sV(_, _, i + 1), rV_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, value(_, _, cur_iter * size<2>(rV_mma_view) + i), + rV_mma_view(_, _, i), acc); } - private: - SVTensor& sV; - RVMmaView& rV_mma_view; - RVCopyView& rV_copy_view; - RegAcc& acc; - TiledCopy tiled_copy; - TiledMma tiled_mma; - int sV_stride; - int num_stage; - int cur_iter; - int cur_iter_sv; + sV.data() = sV.data() + (-sV_stride * cur_iter_sv); + + if ((cur_iter + 1) % num_stage == 0) { + sV.data() = sV.data() + (-sV_stride * num_stage); + } + + cur_iter++; + cur_iter_sv = 0; + } + + private: + SVTensor& sV; + RVMmaView& rV_mma_view; + RVCopyView& rV_copy_view; + RegAcc& acc; + TiledCopy tiled_copy; + TiledMma tiled_mma; + int sV_stride; + int num_stage; + int cur_iter; + int cur_iter_sv; }; } // namespace detail @@ -481,65 +472,65 @@ template DEVICE auto make_g2s_v(const Element* gV_ptr, Element* sV_ptr, int gV_stride) { - int tid = threadIdx.x; + int tid = threadIdx.x; - auto gV = make_tensor(make_gmem_ptr(gV_ptr), GlobalVLayout{}); - auto sV = make_tensor(make_smem_ptr(sV_ptr), SharedVLayout{}); + auto gV = make_tensor(make_gmem_ptr(gV_ptr), GlobalVLayout{}); + auto sV = make_tensor(make_smem_ptr(sV_ptr), SharedVLayout{}); - TiledCopy tiled_copy; + TiledCopy tiled_copy; - auto loader = tiled_copy.get_thread_slice(tid); + auto loader = tiled_copy.get_thread_slice(tid); - auto gVs = loader.partition_S(gV); - auto sVs = loader.partition_D(sV); + auto gVs = loader.partition_S(gV); + auto sVs = loader.partition_D(sV); - int sV_stride = size(sV); + int sV_stride = size(sV); #ifdef DEBUG - if (thread0()) { - printf("gV_stride: %d, sV_stride: %d\n", gV_stride, sV_stride); - } + if (thread0()) { + printf("gV_stride: %d, sV_stride: %d\n", gV_stride, sV_stride); + } #endif - detail::G2SCopyV copy_v(gVs, sVs, tiled_copy, gV_stride, sV_stride); + detail::G2SCopyV copy_v(gVs, sVs, tiled_copy, gV_stride, sV_stride); - return copy_v; + return copy_v; } template DEVICE auto make_s2r_v(const Element* sV_ptr, SVLayout sV_layout, RegAcc& acc, SmemCopyAtom copy_atom, TiledMma tiled_mma) { - int tid = threadIdx.x; + int tid = threadIdx.x; - auto sV_ = make_tensor(make_smem_ptr(sV_ptr), sV_layout); + auto sV_ = make_tensor(make_smem_ptr(sV_ptr), sV_layout); - auto thr_mma = tiled_mma.get_thread_slice(tid); + auto thr_mma = tiled_mma.get_thread_slice(tid); - auto s2r_copy_v = make_tiled_copy_B(copy_atom, tiled_mma); - auto s2r_thr_copy_v = s2r_copy_v.get_thread_slice(tid); + auto s2r_copy_v = make_tiled_copy_B(copy_atom, tiled_mma); + auto s2r_thr_copy_v = s2r_copy_v.get_thread_slice(tid); - auto sV = s2r_thr_copy_v.partition_S(sV_); + auto sV = s2r_thr_copy_v.partition_S(sV_); - auto rV_mma = thr_mma.partition_fragment_B(sV_); - auto rV_copy = s2r_thr_copy_v.retile_D(rV_mma); + auto rV_mma = thr_mma.partition_fragment_B(sV_); + auto rV_copy = s2r_thr_copy_v.retile_D(rV_mma); - int sV_stride = size(sV_); + int sV_stride = size(sV_); - detail::S2RPipelineV s2r_pipeline_v(sV, rV_mma, rV_copy, acc, s2r_copy_v, - tiled_mma, sV_stride); + detail::S2RPipelineV s2r_pipeline_v(sV, rV_mma, rV_copy, acc, s2r_copy_v, + tiled_mma, sV_stride); - return s2r_pipeline_v; + return s2r_pipeline_v; } template DEVICE auto store_r2s_o(Element* sO_ptr, SOLayout sO_layout, RegO& o, SmemCopyAtom copy_atom, TiledMma tiled_mma) { - auto sO = make_tensor(make_smem_ptr(sO_ptr), sO_layout); + auto sO = make_tensor(make_smem_ptr(sO_ptr), sO_layout); - auto r2s_copy_o = make_tiled_copy_C(copy_atom, tiled_mma); - auto r2s_thr_copy_o = r2s_copy_o.get_thread_slice(threadIdx.x); + auto r2s_copy_o = make_tiled_copy_C(copy_atom, tiled_mma); + auto r2s_thr_copy_o = r2s_copy_o.get_thread_slice(threadIdx.x); - auto src = r2s_thr_copy_o.retile_S(o); - auto dst = r2s_thr_copy_o.partition_D(sO); + auto src = r2s_thr_copy_o.retile_S(o); + auto dst = r2s_thr_copy_o.partition_D(sO); - cute::copy(r2s_copy_o, src, dst); + cute::copy(r2s_copy_o, src, dst); } template (gO_partition); ++m) { + for (int m = 0; m < size<1>(gO_partition); ++m) { #pragma unroll - for (int n = 0; n < size<2>(gO_partition); ++n) { - cute::copy(tiled_copy, sO_partition(_, m, n), - gO_partition(_, m, n)); - } + for (int n = 0; n < size<2>(gO_partition); ++n) { + cute::copy(tiled_copy, sO_partition(_, m, n), gO_partition(_, m, n)); } + } } } // namespace cutlass_wrapper diff --git a/benchmarks/cpp/flashattention/cutlass_fa.cuh b/benchmarks/cpp/flashattention/cutlass_fa.cuh index a8487ede..a8b3e145 100644 --- a/benchmarks/cpp/flashattention/cutlass_fa.cuh +++ b/benchmarks/cpp/flashattention/cutlass_fa.cuh @@ -21,78 +21,78 @@ template > struct FATraits : public Base { - using Element = Element_; - - static_assert(kTP == kP, "The current implementation requires kTP == P."); - static_assert(kSecondaryTN == kTN, - "The current implementation requires kSecondaryTN == kTN."); - static_assert(kM % kTM == 0, "kM must be a multiple of kTM."); - static_assert(kN % kTN == 0, "kN must be a multiple of kTN."); - static_assert(kK % kTK == 0, "kK must be a multiple of kTK."); - static_assert(kWarpPerCol == 1, - "The current implementation requires kWarpPerCol == 1."); - - // Declare global to shared memory copy layout. - using GmemLayoutQ = Layout, Int>, Stride, _1>>; - using GmemLayoutK = Layout, Int>, Stride, _1>>; - using GmemLayoutV = Layout, Int>, Stride, _1>>; - using GmemLayoutO = Layout, Int>, Stride, _1>>; - - static constexpr int kThreads = kWarpPerRow * kWarpPerCol * 32; - - /** - * Define the atomic layout of shared memory, which is the smallest - * configuration unit of shared memory. Larger shapes are tiled based on the - * atomic layout. - */ - using SmemLayoutAtom = decltype(composition( - Swizzle{}, - Layout>, Stride, _1>>{})); - - using SmemLayoutQ = - decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); - using SmemLayoutK = - decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); - using SmemLayoutV = - decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); - using SmemLayoutO = - decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); - - /** - * In the Ampere architecture, loading from shared memory to register memory - * requires the use of the `ldmatrix` instruction, while storing from - * register memory to shared memory does not have hardware support and uses - * a default copy instead.” - */ - using LoadS2RCopyAtom = Copy_Atom; - using StoreR2SCopyAtom = Copy_Atom; - - static constexpr int kWarps = kThreads / 32; - - using TiledMma = - TiledMMA, - Layout, Int, _1>>, - Tile, Int<16 * kWarpPerCol>, _16>>; + using Element = Element_; + + static_assert(kTP == kP, "The current implementation requires kTP == P."); + static_assert(kSecondaryTN == kTN, + "The current implementation requires kSecondaryTN == kTN."); + static_assert(kM % kTM == 0, "kM must be a multiple of kTM."); + static_assert(kN % kTN == 0, "kN must be a multiple of kTN."); + static_assert(kK % kTK == 0, "kK must be a multiple of kTK."); + static_assert(kWarpPerCol == 1, + "The current implementation requires kWarpPerCol == 1."); + + // Declare global to shared memory copy layout. + using GmemLayoutQ = Layout, Int>, Stride, _1>>; + using GmemLayoutK = Layout, Int>, Stride, _1>>; + using GmemLayoutV = Layout, Int>, Stride, _1>>; + using GmemLayoutO = Layout, Int>, Stride, _1>>; + + static constexpr int kThreads = kWarpPerRow * kWarpPerCol * 32; + + /** + * Define the atomic layout of shared memory, which is the smallest + * configuration unit of shared memory. Larger shapes are tiled based on the + * atomic layout. + */ + using SmemLayoutAtom = decltype(composition( + Swizzle{}, + Layout>, Stride, _1>>{})); + + using SmemLayoutQ = + decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); + using SmemLayoutK = + decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); + using SmemLayoutV = + decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); + using SmemLayoutO = + decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); + + /** + * In the Ampere architecture, loading from shared memory to register memory + * requires the use of the `ldmatrix` instruction, while storing from + * register memory to shared memory does not have hardware support and uses + * a default copy instead.” + */ + using LoadS2RCopyAtom = Copy_Atom; + using StoreR2SCopyAtom = Copy_Atom; + + static constexpr int kWarps = kThreads / 32; + + using TiledMma = + TiledMMA, + Layout, Int, _1>>, + Tile, Int<16 * kWarpPerCol>, _16>>; #ifdef CP_ASYNC_SM80_ENABLED - // for Ampere - using CopyInstG2S = - Copy_Atom, Element>; + // for Ampere + using CopyInstG2S = + Copy_Atom, Element>; #else - using CopyInstG2S = Copy_Atom; + using CopyInstG2S = Copy_Atom; #endif - // TODO(KuangjuX): Understand this configuration. - using GmemCopyLayoutAtom = - Layout, Int>, - Stride, _1>>; + // TODO(KuangjuX): Understand this configuration. + using GmemCopyLayoutAtom = + Layout, Int>, + Stride, _1>>; - using TiledCopyG2S = decltype(make_tiled_copy( - CopyInstG2S{}, GmemCopyLayoutAtom{}, Layout>{})); + using TiledCopyG2S = decltype(make_tiled_copy( + CopyInstG2S{}, GmemCopyLayoutAtom{}, Layout>{})); - using TiledCopyS2G = decltype(make_tiled_copy( - Copy_Atom{}, GmemCopyLayoutAtom{}, - Layout>{})); + using TiledCopyS2G = + decltype(make_tiled_copy(Copy_Atom{}, + GmemCopyLayoutAtom{}, Layout>{})); }; template (buf_); - - const Element* Q = dQ + blockIdx.z * kTM * kN + blockIdx.x * kTM * kK; - const Element* K = dK + blockIdx.z * kK * kN; - const Element* V = dV + blockIdx.z * kP * kN + blockIdx.y * kTP * kN; - Element* O = - dO + blockIdx.z * kM * kP + blockIdx.x * (kTM * kP) + blockIdx.y * kTP; - - Element* sQ_ptr = reinterpret_cast(buf); - Element* sK_ptr = sQ_ptr + kTM * kTK * kStagesQK; - Element* sV_ptr = sK_ptr + kTN * kTK * kStagesQK; - Element* sO_ptr = sQ_ptr; - - typename KeTraits::TiledMma mma; - typename KeTraits::TiledCopyG2S tiled_copy_g2s; - - // Build the copy plan for QK from global memory to shared memory. - auto g2s_copy_qk = make_g2s_qk< - Element, typename KeTraits::GmemLayoutQ, typename KeTraits::SmemLayoutQ, - typename KeTraits::GmemLayoutK, typename KeTraits::SmemLayoutK, - typename KeTraits::TiledCopyG2S>(Q, sQ_ptr, K, sK_ptr, kTK, kTK); + // constexpr float softmax_scale = 1.250000e-01f; + // TODO(KuangjuX): Use a fixed value for easy comparison. + constexpr float softmax_scale = 1.0f; + const bool load_q_once = (kTK == kK); + + extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; + auto* buf = reinterpret_cast(buf_); + + const Element* Q = dQ + blockIdx.z * kTM * kN + blockIdx.x * kTM * kK; + const Element* K = dK + blockIdx.z * kK * kN; + const Element* V = dV + blockIdx.z * kP * kN + blockIdx.y * kTP * kN; + Element* O = + dO + blockIdx.z * kM * kP + blockIdx.x * (kTM * kP) + blockIdx.y * kTP; + + Element* sQ_ptr = reinterpret_cast(buf); + Element* sK_ptr = sQ_ptr + kTM * kTK * kStagesQK; + Element* sV_ptr = sK_ptr + kTN * kTK * kStagesQK; + Element* sO_ptr = sQ_ptr; + + typename KeTraits::TiledMma mma; + typename KeTraits::TiledCopyG2S tiled_copy_g2s; + + // Build the copy plan for QK from global memory to shared memory. + auto g2s_copy_qk = make_g2s_qk< + Element, typename KeTraits::GmemLayoutQ, typename KeTraits::SmemLayoutQ, + typename KeTraits::GmemLayoutK, typename KeTraits::SmemLayoutK, + typename KeTraits::TiledCopyG2S>(Q, sQ_ptr, K, sK_ptr, kTK, kTK); + + /** + * In FractalTensor, The size of the V matrix is [kN, kP], and the size + * processed in a single SM Block is [kN, kTP]. When split along the N + * dimension, the size is [kTN, kTP]. Therefore, the stride for global + * memory should be set to kTN * kP. + * + * In the current implementation, the shape of the V matrix is [kP, kN], and + * the block size processed by a single Block is [kTP, kN]. Therefore, the + * stride only needs to be set to kTN each time. + */ + auto g2s_copy_v = make_g2s_v(V, sV_ptr, kTN); + + auto acc0 = get_acc(mma); + auto acco = get_acc(mma); - /** - * In FractalTensor, The size of the V matrix is [kN, kP], and the size - * processed in a single SM Block is [kN, kTP]. When split along the N - * dimension, the size is [kTN, kTP]. Therefore, the stride for global - * memory should be set to kTN * kP. - * - * In the current implementation, the shape of the V matrix is [kP, kN], and - * the block size processed by a single Block is [kTP, kN]. Therefore, the - * stride only needs to be set to kTN each time. - */ - auto g2s_copy_v = - make_g2s_v(V, sV_ptr, kTN); +#ifdef DEBUG + if (thread0()) { + printf("acc0 size<0>: %d, size<1>: %d, size<2>: %d\n", (int)size<0>(acc0), + (int)size<1>(acc0), (int)size<2>(acc0)); + printf("acco size<0>: %d, size<1>: %d, size<2>: %d\n", (int)size<0>(acco), + (int)size<1>(acco), (int)size<2>(acco)); + } +#endif + + /** + * In TileFusion, we use + * ```cpp + * using RegVec = RegTile>; + * ``` + * We need to store the reduce results for both the top row and the bottom + * row simultaneously. + */ + + auto m_new = make_tensor(Shape(acc0)>>{}); + auto lse_new = make_fragment_like(m_new); + + auto s2r_pipeline_qk = + make_s2r_qk(sQ_ptr, sK_ptr, typename KeTraits::SmemLayoutQ{}, + typename KeTraits::SmemLayoutK{}, acc0, + typename KeTraits::LoadS2RCopyAtom{}, mma); + + auto s2r_pipeline_v = + make_s2r_v(sV_ptr, typename KeTraits::SmemLayoutV{}, acco, + typename KeTraits::LoadS2RCopyAtom{}, mma); + + // Issue global to shared memory copy before the main loop. + g2s_copy_qk.prologue(); + + fill(lse_new, 0.0f); + fill(m_new, -INFINITY); + clear(acco); + + /** + * Flash Attention performs two-level tiling for each SM Block, splitting + * along the N dimension and the K dimension. The Q matrix is split along + * the K dimension, the V matrix is split along the N dimension, and the K + * matrix is split along both dimensions simultaneously. + */ + // TODO(KuangjuX): Add unroll for last iteration. + int split_n = kN / kTN; + for (int n = 0; n < split_n - (kUnrollLastIteration ? 1 : 0); ++n) { + clear(acc0); + + // When `load_q_once` is true, the following code is not executed. + int slice_k = kK / kTK - 1; + for (int k = 0; k < slice_k; ++k) { + // Barrier to ensure all data are loaded into shared memory. + cp_async_wait_flash<0>(); + __syncthreads(); + g2s_copy_qk.body(); + // Load data from shared memory into register and issue MMA. + s2r_pipeline_qk.body(); + } + + cp_async_wait_flash<0>(); + __syncthreads(); + g2s_copy_v.prologue(); + // When `load_q_once` is true, `g2s_copy_qk.prologue()` is executed only + // once, and `s2r_pipeline_qk.epilogue()` is executed once as well. + s2r_pipeline_qk.epilogue(); + + // scores = dot(q, k) + auto scores = + make_tensor(acc0.data(), convert_layout_scores(acc0.layout())); + + auto m_old = make_fragment_like(m_new); + copy(m_new, m_old); + + auto scores_max = make_fragment_like(m_new); + + // scores_max = reduce_max(scores, axis=1) + reduce_max<4, true>(scores, scores_max); + + // Compute new partial max value. + for (int ax0 = 0; ax0 < size<0>(m_new); ++ax0) { + m_new(ax0) = max(m_new(ax0), scores_max(ax0)); + } - auto acc0 = get_acc(mma); - auto acco = get_acc(mma); + // Currently, `acco` stores the results from the previous iteration's + // computation. + auto previous_attn_block = + make_tensor(acco.data(), convert_layout_scores(acco.layout())); #ifdef DEBUG if (thread0()) { - printf("acc0 size<0>: %d, size<1>: %d, size<2>: %d\n", - (int)size<0>(acc0), (int)size<1>(acc0), (int)size<2>(acc0)); - printf("acco size<0>: %d, size<1>: %d, size<2>: %d\n", - (int)size<0>(acco), (int)size<1>(acco), (int)size<2>(acco)); + printf("scores size<0>: %d, size<1>: %d\n", (int)size<0>(scores), + (int)size<1>(scores)); + printf("previous_attn_block size<0>: %d, size<1>: %d\n", + (int)size<0>(previous_attn_block), + (int)size<1>(previous_attn_block)); } #endif - /** - * In TileFusion, we use - * ```cpp - * using RegVec = RegTile>; - * ``` - * We need to store the reduce results for both the top row and the bottom - * row simultaneously. - */ - - auto m_new = make_tensor(Shape(acc0)>>{}); - auto lse_new = make_fragment_like(m_new); + // Renormalization for the previous block. + for (int ax0 = 0; ax0 < size<0>(previous_attn_block); ++ax0) { + // Compute `acc_o_scale = exp(m_i - m_ij)` + float scale = exp((m_old(ax0) - m_new(ax0)) * softmax_scale); + lse_new(ax0) = lse_new(ax0) * scale; + // Compute `acc_o = acc_o_scale * acc_o` + for (int ax1 = 0; ax1 < size<1>(previous_attn_block); ++ax1) { + previous_attn_block(ax0, ax1) *= scale; + } + } - auto s2r_pipeline_qk = - make_s2r_qk(sQ_ptr, sK_ptr, typename KeTraits::SmemLayoutQ{}, - typename KeTraits::SmemLayoutK{}, acc0, - typename KeTraits::LoadS2RCopyAtom{}, mma); + for (int ax0 = 0; ax0 < size<0>(scores); ++ax0) { + // Compute `p = exp(qk - m_ij)` + float m_scaled = m_new(ax0) * softmax_scale; + for (int ax1 = 0; ax1 < size<1>(scores); ++ax1) { + scores(ax0, ax1) = exp(scores(ax0, ax1) * softmax_scale - m_scaled); + } + } - auto s2r_pipeline_v = - make_s2r_v(sV_ptr, typename KeTraits::SmemLayoutV{}, acco, - typename KeTraits::LoadS2RCopyAtom{}, mma); + // Compute `l_ij = sum(p)`. + auto scores_sum = make_fragment_like(lse_new); + reduce_sum<4>(scores, scores_sum); - // Issue global to shared memory copy before the main loop. - g2s_copy_qk.prologue(); + // Compute `l_i_new = exp(lse_i - m_ij) + l_ij`. + for (int ax0 = 0; ax0 < size<0>(lse_new); ++ax0) { + lse_new(ax0) = lse_new(ax0) + scores_sum(ax0); + } - fill(lse_new, 0.0f); - fill(m_new, -INFINITY); - clear(acco); + // TODO(KuangjuX): Understand the following code. + auto frag = convert_type(scores); + auto rP = make_tensor(make_rmem_ptr(&frag), scores.layout()); + auto rP_Aregs = + make_tensor(rP.data(), convert_layout_rowcol_Aregs(rP.layout())); /** - * Flash Attention performs two-level tiling for each SM Block, splitting - * along the N dimension and the K dimension. The Q matrix is split along - * the K dimension, the V matrix is split along the N dimension, and the K - * matrix is split along both dimensions simultaneously. + * In FractalTensor, the `kTN` dimension is split again. To simplify the + * current implementation of rhe pipeline flashattention, the `tile_n` + * is hardcoded to 0 at this point. */ - // TODO(KuangjuX): Add unroll for last iteration. - int split_n = kN / kTN; - for (int n = 0; n < split_n - (kUnrollLastIteration ? 1 : 0); ++n) { - clear(acc0); - - // When `load_q_once` is true, the following code is not executed. - int slice_k = kK / kTK - 1; - for (int k = 0; k < slice_k; ++k) { - // Barrier to ensure all data are loaded into shared memory. - cp_async_wait_flash<0>(); - __syncthreads(); - g2s_copy_qk.body(); - // Load data from shared memory into register and issue MMA. - s2r_pipeline_qk.body(); - } - - cp_async_wait_flash<0>(); - __syncthreads(); - g2s_copy_v.prologue(); - // When `load_q_once` is true, `g2s_copy_qk.prologue()` is executed only - // once, and `s2r_pipeline_qk.epilogue()` is executed once as well. - s2r_pipeline_qk.epilogue(); - - // scores = dot(q, k) - auto scores = - make_tensor(acc0.data(), convert_layout_scores(acc0.layout())); - - auto m_old = make_fragment_like(m_new); - copy(m_new, m_old); - - auto scores_max = make_fragment_like(m_new); - - // scores_max = reduce_max(scores, axis=1) - reduce_max<4, true>(scores, scores_max); - - // Compute new partial max value. - for (int ax0 = 0; ax0 < size<0>(m_new); ++ax0) { - m_new(ax0) = max(m_new(ax0), scores_max(ax0)); - } - - // Currently, `acco` stores the results from the previous iteration's - // computation. - auto previous_attn_block = - make_tensor(acco.data(), convert_layout_scores(acco.layout())); - -#ifdef DEBUG - if (thread0()) { - printf("scores size<0>: %d, size<1>: %d\n", (int)size<0>(scores), - (int)size<1>(scores)); - printf("previous_attn_block size<0>: %d, size<1>: %d\n", - (int)size<0>(previous_attn_block), - (int)size<1>(previous_attn_block)); - } -#endif + int secondary_tile_n = kSecondaryTN / kTN - 1; + for (int tile_ = 0; tile_ < secondary_tile_n; ++tile_) { + // Barrier to ensure all data are loaded into shared memory. + cp_async_wait_flash<0>(); + __syncthreads(); + g2s_copy_v.body(); + s2r_pipeline_v.body(rP_Aregs); + } - // Renormalization for the previous block. - for (int ax0 = 0; ax0 < size<0>(previous_attn_block); ++ax0) { - // Compute `acc_o_scale = exp(m_i - m_ij)` - float scale = exp((m_old(ax0) - m_new(ax0)) * softmax_scale); - lse_new(ax0) = lse_new(ax0) * scale; - // Compute `acc_o = acc_o_scale * acc_o` - for (int ax1 = 0; ax1 < size<1>(previous_attn_block); ++ax1) { - previous_attn_block(ax0, ax1) *= scale; - } - } - - for (int ax0 = 0; ax0 < size<0>(scores); ++ax0) { - // Compute `p = exp(qk - m_ij)` - float m_scaled = m_new(ax0) * softmax_scale; - for (int ax1 = 0; ax1 < size<1>(scores); ++ax1) { - scores(ax0, ax1) = - exp(scores(ax0, ax1) * softmax_scale - m_scaled); - } - } - - // Compute `l_ij = sum(p)`. - auto scores_sum = make_fragment_like(lse_new); - reduce_sum<4>(scores, scores_sum); - - // Compute `l_i_new = exp(lse_i - m_ij) + l_ij`. - for (int ax0 = 0; ax0 < size<0>(lse_new); ++ax0) { - lse_new(ax0) = lse_new(ax0) + scores_sum(ax0); - } - - // TODO(KuangjuX): Understand the following code. - auto frag = convert_type(scores); - auto rP = make_tensor(make_rmem_ptr(&frag), scores.layout()); - auto rP_Aregs = - make_tensor(rP.data(), convert_layout_rowcol_Aregs(rP.layout())); + cp_async_wait_flash<0>(); + __syncthreads(); + if (n < split_n - 1) { + /** + * Update K tile because the entire K Block will be processed in a + * single SM Block. + * + * For example, In `TileFusion`: + * ```cpp + * for (int n = 0; n < GIteratorV::sc0; ++n) { + * load_sv(gVs(n), sV); + * for (int k = 0; k < GIteratorQ::sc1; ++k) { + * load_sq(gQs(k), sQ); + * load_sk(gKs(k, n), sK); + * } + * } + * ``` + */ + g2s_copy_qk.update_tile_K(kTN, kK); + /** + * `load_q_once` means that at this point `kK == kTK`, and the Q is + * loaded into shared memory in blocks only once. In this case, we + * only need to update the pointer of K and do not need to update + * the pointer for Q, because the blocking along the k dimension + * will not be executed, thus the Q is always reloaded. + */ + if (load_q_once) { + g2s_copy_qk.prologue_K(); + } else { /** - * In FractalTensor, the `kTN` dimension is split again. To simplify the - * current implementation of rhe pipeline flashattention, the `tile_n` - * is hardcoded to 0 at this point. + * In this case, we need to reset thr pointer of Q to the + * starting position and simultaneously preload the Q and K. */ - int secondary_tile_n = kSecondaryTN / kTN - 1; - for (int tile_ = 0; tile_ < secondary_tile_n; ++tile_) { - // Barrier to ensure all data are loaded into shared memory. - cp_async_wait_flash<0>(); - __syncthreads(); - g2s_copy_v.body(); - s2r_pipeline_v.body(rP_Aregs); - } - - cp_async_wait_flash<0>(); - __syncthreads(); - - if (n < split_n - 1) { - /** - * Update K tile because the entire K Block will be processed in a - * single SM Block. - * - * For example, In `TileFusion`: - * ```cpp - * for (int n = 0; n < GIteratorV::sc0; ++n) { - * load_sv(gVs(n), sV); - * for (int k = 0; k < GIteratorQ::sc1; ++k) { - * load_sq(gQs(k), sQ); - * load_sk(gKs(k, n), sK); - * } - * } - * ``` - */ - g2s_copy_qk.update_tile_K(kTN, kK); - /** - * `load_q_once` means that at this point `kK == kTK`, and the Q is - * loaded into shared memory in blocks only once. In this case, we - * only need to update the pointer of K and do not need to update - * the pointer for Q, because the blocking along the k dimension - * will not be executed, thus the Q is always reloaded. - */ - if (load_q_once) { - g2s_copy_qk.prologue_K(); - } else { - /** - * In this case, we need to reset thr pointer of Q to the - * starting position and simultaneously preload the Q and K. - */ - g2s_copy_qk.reset_tile_Q(kK); - g2s_copy_qk.prologue(); - } - } - - // Compute `acc_o = acc_o + dot(p, v)` - s2r_pipeline_v.epilogue(rP_Aregs); - - // Compute `lse_i = m_ij + log(l_i_new)`. - for (int ax0 = 0; ax0 < size<0>(m_new); ++ax0) { - m_new(ax0) = m_new(ax0) * softmax_scale + log(lse_new(ax0)); - } + g2s_copy_qk.reset_tile_Q(kK); + g2s_copy_qk.prologue(); + } } - if (kUnrollLastIteration) { - clear(acc0); - - // When `load_q_once` is true, the following code is not executed. - int slice_k = kK / kTK - 1; - for (int k = 0; k < slice_k; ++k) { - // Barrier to ensure all data are loaded into shared memory. - cp_async_wait_flash<0>(); - __syncthreads(); - g2s_copy_qk.body(); - // Load data from shared memory into register and issue MMA. - s2r_pipeline_qk.body(); - } - - cp_async_wait_flash<0>(); - __syncthreads(); - g2s_copy_v.prologue(); - s2r_pipeline_qk.epilogue(); - - // scores = dot(q, k) - auto scores = - make_tensor(acc0.data(), convert_layout_scores(acc0.layout())); - - auto m_old = make_fragment_like(m_new); - copy(m_new, m_old); - - auto scores_max = make_fragment_like(m_new); - - // scores_max = reduce_max(scores, axis=1) - reduce_max<4, true>(scores, scores_max); - - // Compute new partial max value. - for (int ax0 = 0; ax0 < size<0>(m_new); ++ax0) { - m_new(ax0) = max(m_new(ax0), scores_max(ax0)); - } - - // Currently, `acco` stores the results from the previous iteration's - // computation. - auto previous_attn_block = - make_tensor(acco.data(), convert_layout_scores(acco.layout())); - - // Renormalization for the previous block. - for (int ax0 = 0; ax0 < size<0>(previous_attn_block); ++ax0) { - // Compute `acc_o_scale = exp(m_i - m_ij)` - float scale = exp((m_old(ax0) - m_new(ax0)) * softmax_scale); - lse_new(ax0) = lse_new(ax0) * scale; - // Compute `acc_o = acc_o_scale * acc_o` - for (int ax1 = 0; ax1 < size<1>(previous_attn_block); ++ax1) { - previous_attn_block(ax0, ax1) *= scale; - } - } - - for (int ax0 = 0; ax0 < size<0>(scores); ++ax0) { - // Compute `p = exp(qk - m_ij)` - float m_scaled = m_new(ax0) * softmax_scale; - for (int ax1 = 0; ax1 < size<1>(scores); ++ax1) { - scores(ax0, ax1) = - exp(scores(ax0, ax1) * softmax_scale - m_scaled); - } - } - - // Compute `l_ij = sum(p)`. - auto scores_sum = make_fragment_like(lse_new); - reduce_sum<4>(scores, scores_sum); - - // Compute `l_i_new = exp(lse_i - m_ij) + l_ij`. - for (int ax0 = 0; ax0 < size<0>(lse_new); ++ax0) { - lse_new(ax0) = lse_new(ax0) + scores_sum(ax0); - } - - // TODO(KuangjuX): Understand the following code. - auto frag = convert_type(scores); - auto rP = make_tensor(make_rmem_ptr(&frag), scores.layout()); - auto rP_Aregs = - make_tensor(rP.data(), convert_layout_rowcol_Aregs(rP.layout())); + // Compute `acc_o = acc_o + dot(p, v)` + s2r_pipeline_v.epilogue(rP_Aregs); - /** - * In FractalTensor, the `kTN` dimension is split again. To simplify the - * current implementation of the pipeline flashattention, the `tile_n` - * is hardcoded to 0 at this point. - */ - int secondary_tile_n = kTN / kSecondaryTN - 1; - for (int tile_ = 0; tile_ < secondary_tile_n; ++tile_) { - // Barrier to ensure all data are loaded into shared memory. - cp_async_wait_flash<0>(); - __syncthreads(); - g2s_copy_v.body(); - s2r_pipeline_v.body(rP_Aregs); - } - - cp_async_wait_flash<0>(); - __syncthreads(); - - // Compute `acc_o = acc_o + dot(p, v)` - s2r_pipeline_v.epilogue(rP_Aregs); - - // Compute `lse_i = m_ij + log(l_i_new)`. - for (int ax0 = 0; ax0 < size<0>(m_new); ++ax0) { - m_new(ax0) = m_new(ax0) * softmax_scale + log(lse_new(ax0)); - } + // Compute `lse_i = m_ij + log(l_i_new)`. + for (int ax0 = 0; ax0 < size<0>(m_new); ++ax0) { + m_new(ax0) = m_new(ax0) * softmax_scale + log(lse_new(ax0)); + } + } + + if (kUnrollLastIteration) { + clear(acc0); + + // When `load_q_once` is true, the following code is not executed. + int slice_k = kK / kTK - 1; + for (int k = 0; k < slice_k; ++k) { + // Barrier to ensure all data are loaded into shared memory. + cp_async_wait_flash<0>(); + __syncthreads(); + g2s_copy_qk.body(); + // Load data from shared memory into register and issue MMA. + s2r_pipeline_qk.body(); + } + + cp_async_wait_flash<0>(); + __syncthreads(); + g2s_copy_v.prologue(); + s2r_pipeline_qk.epilogue(); + + // scores = dot(q, k) + auto scores = + make_tensor(acc0.data(), convert_layout_scores(acc0.layout())); + + auto m_old = make_fragment_like(m_new); + copy(m_new, m_old); + + auto scores_max = make_fragment_like(m_new); + + // scores_max = reduce_max(scores, axis=1) + reduce_max<4, true>(scores, scores_max); + + // Compute new partial max value. + for (int ax0 = 0; ax0 < size<0>(m_new); ++ax0) { + m_new(ax0) = max(m_new(ax0), scores_max(ax0)); } - // Normalize the attention block. - auto attn_block = + // Currently, `acco` stores the results from the previous iteration's + // computation. + auto previous_attn_block = make_tensor(acco.data(), convert_layout_scores(acco.layout())); - for (int ax0 = 0; ax0 < size<0>(attn_block); ++ax0) { - // TODO(KuangjuX): fix the following code? -> `o_scale = exp(m_i - - // lse_i)`. - - // float scale = 1 / lse_new(ax0); - float o_scale = exp(m_new(ax0) - lse_new(ax0)); - // TODO(KuangjuX): Move this code into loop? - // lse_new(ax0) = m_new(ax0) * softmax_scale + log(lse_new(ax0)); - for (int ax1 = 0; ax1 < size<1>(attn_block); ++ax1) { - attn_block(ax0, ax1) *= o_scale; - } + + // Renormalization for the previous block. + for (int ax0 = 0; ax0 < size<0>(previous_attn_block); ++ax0) { + // Compute `acc_o_scale = exp(m_i - m_ij)` + float scale = exp((m_old(ax0) - m_new(ax0)) * softmax_scale); + lse_new(ax0) = lse_new(ax0) * scale; + // Compute `acc_o = acc_o_scale * acc_o` + for (int ax1 = 0; ax1 < size<1>(previous_attn_block); ++ax1) { + previous_attn_block(ax0, ax1) *= scale; + } + } + + for (int ax0 = 0; ax0 < size<0>(scores); ++ax0) { + // Compute `p = exp(qk - m_ij)` + float m_scaled = m_new(ax0) * softmax_scale; + for (int ax1 = 0; ax1 < size<1>(scores); ++ax1) { + scores(ax0, ax1) = exp(scores(ax0, ax1) * softmax_scale - m_scaled); + } + } + + // Compute `l_ij = sum(p)`. + auto scores_sum = make_fragment_like(lse_new); + reduce_sum<4>(scores, scores_sum); + + // Compute `l_i_new = exp(lse_i - m_ij) + l_ij`. + for (int ax0 = 0; ax0 < size<0>(lse_new); ++ax0) { + lse_new(ax0) = lse_new(ax0) + scores_sum(ax0); + } + + // TODO(KuangjuX): Understand the following code. + auto frag = convert_type(scores); + auto rP = make_tensor(make_rmem_ptr(&frag), scores.layout()); + auto rP_Aregs = + make_tensor(rP.data(), convert_layout_rowcol_Aregs(rP.layout())); + + /** + * In FractalTensor, the `kTN` dimension is split again. To simplify the + * current implementation of the pipeline flashattention, the `tile_n` + * is hardcoded to 0 at this point. + */ + int secondary_tile_n = kTN / kSecondaryTN - 1; + for (int tile_ = 0; tile_ < secondary_tile_n; ++tile_) { + // Barrier to ensure all data are loaded into shared memory. + cp_async_wait_flash<0>(); + __syncthreads(); + g2s_copy_v.body(); + s2r_pipeline_v.body(rP_Aregs); } - // Store O from registers to shared memory and then to global memory. - store_r2s_o(sO_ptr, typename KeTraits::SmemLayoutO{}, acco, - typename KeTraits::StoreR2SCopyAtom{}, mma); + cp_async_wait_flash<0>(); __syncthreads(); - store_s2g_o(O, sO_ptr, typename KeTraits::GmemLayoutO{}, - typename KeTraits::SmemLayoutO{}, - typename KeTraits::TiledCopyS2G{}); + // Compute `acc_o = acc_o + dot(p, v)` + s2r_pipeline_v.epilogue(rP_Aregs); + + // Compute `lse_i = m_ij + log(l_i_new)`. + for (int ax0 = 0; ax0 < size<0>(m_new); ++ax0) { + m_new(ax0) = m_new(ax0) * softmax_scale + log(lse_new(ax0)); + } + } + + // Normalize the attention block. + auto attn_block = + make_tensor(acco.data(), convert_layout_scores(acco.layout())); + for (int ax0 = 0; ax0 < size<0>(attn_block); ++ax0) { + // TODO(KuangjuX): fix the following code? -> `o_scale = exp(m_i - + // lse_i)`. + + // float scale = 1 / lse_new(ax0); + float o_scale = exp(m_new(ax0) - lse_new(ax0)); + // TODO(KuangjuX): Move this code into loop? + // lse_new(ax0) = m_new(ax0) * softmax_scale + log(lse_new(ax0)); + for (int ax1 = 0; ax1 < size<1>(attn_block); ++ax1) { + attn_block(ax0, ax1) *= o_scale; + } + } + + // Store O from registers to shared memory and then to global memory. + store_r2s_o(sO_ptr, typename KeTraits::SmemLayoutO{}, acco, + typename KeTraits::StoreR2SCopyAtom{}, mma); + __syncthreads(); + + store_s2g_o(O, sO_ptr, typename KeTraits::GmemLayoutO{}, + typename KeTraits::SmemLayoutO{}, + typename KeTraits::TiledCopyS2G{}); } } // namespace cutlass_wrapper diff --git a/benchmarks/cpp/flashattention/main.cu b/benchmarks/cpp/flashattention/main.cu index b2a71e35..ccba7816 100644 --- a/benchmarks/cpp/flashattention/main.cu +++ b/benchmarks/cpp/flashattention/main.cu @@ -8,114 +8,114 @@ template void run(bool check = true) { - using InType = cutlass::half_t; - using AccType = cutlass::half_t; - using OutType = cutlass::half_t; + using InType = cutlass::half_t; + using AccType = cutlass::half_t; + using OutType = cutlass::half_t; - // Currently `kBatch` is fixed to 1. - static constexpr int kBatch = 1; - static constexpr int kThreads = kWarpPerCol * kWarpPerRow * 32; + // Currently `kBatch` is fixed to 1. + static constexpr int kBatch = 1; + static constexpr int kThreads = kWarpPerCol * kWarpPerRow * 32; - static_assert(kP == kTP, - "The current implementation requires kTP == P for now."); + static_assert(kP == kTP, + "The current implementation requires kTP == P for now."); - // initialize data - thrust::host_vector h_a(kM * kK * kBatch); + // initialize data + thrust::host_vector h_a(kM * kK * kBatch); - for (int i = 0; i < h_a.size(); ++i) - h_a[i] = static_cast(rand_float()); + for (int i = 0; i < h_a.size(); ++i) + h_a[i] = static_cast(rand_float()); - thrust::host_vector h_b(kK * kN * kBatch); - for (int i = 0; i < h_b.size(); ++i) - h_b[i] = static_cast(rand_float()); + thrust::host_vector h_b(kK * kN * kBatch); + for (int i = 0; i < h_b.size(); ++i) + h_b[i] = static_cast(rand_float()); - thrust::host_vector h_c(kN * kP * kBatch); - for (int i = 0; i < h_c.size(); ++i) - h_c[i] = static_cast(rand_float()); + thrust::host_vector h_c(kN * kP * kBatch); + for (int i = 0; i < h_c.size(); ++i) + h_c[i] = static_cast(rand_float()); - thrust::host_vector h_d(kM * kP * kBatch); - thrust::fill(h_d.begin(), h_d.end(), 0.); + thrust::host_vector h_d(kM * kP * kBatch); + thrust::fill(h_d.begin(), h_d.end(), 0.); - // Host side memory initialization. - thrust::host_vector acc(kM * kN * kBatch); - thrust::fill(acc.begin(), acc.end(), 0.); + // Host side memory initialization. + thrust::host_vector acc(kM * kN * kBatch); + thrust::fill(acc.begin(), acc.end(), 0.); - thrust::host_vector exp_values(kM * kP * kBatch); - thrust::fill(exp_values.begin(), exp_values.end(), 0.); + thrust::host_vector exp_values(kM * kP * kBatch); + thrust::fill(exp_values.begin(), exp_values.end(), 0.); - thrust::host_vector h_o(kM * kP * kBatch); - thrust::fill(h_o.begin(), h_o.end(), 0.); + thrust::host_vector h_o(kM * kP * kBatch); + thrust::fill(h_o.begin(), h_o.end(), 0.); - thrust::host_vector cur_row_max(kM * kBatch); - thrust::fill(cur_row_max.begin(), cur_row_max.end(), 0.); + thrust::host_vector cur_row_max(kM * kBatch); + thrust::fill(cur_row_max.begin(), cur_row_max.end(), 0.); - thrust::host_vector prev_row_max(kM * kBatch); - thrust::fill(prev_row_max.begin(), prev_row_max.end(), 0.); + thrust::host_vector prev_row_max(kM * kBatch); + thrust::fill(prev_row_max.begin(), prev_row_max.end(), 0.); - thrust::host_vector new_row_max(kM * kBatch); - thrust::fill(new_row_max.begin(), new_row_max.end(), 0.); + thrust::host_vector new_row_max(kM * kBatch); + thrust::fill(new_row_max.begin(), new_row_max.end(), 0.); - thrust::host_vector prev_norm_vec(kM * kBatch); - thrust::fill(prev_norm_vec.begin(), prev_norm_vec.end(), 0.); + thrust::host_vector prev_norm_vec(kM * kBatch); + thrust::fill(prev_norm_vec.begin(), prev_norm_vec.end(), 0.); - thrust::host_vector new_norm_vec(kM * kBatch); - thrust::fill(new_norm_vec.begin(), new_norm_vec.end(), 0.); + thrust::host_vector new_norm_vec(kM * kBatch); + thrust::fill(new_norm_vec.begin(), new_norm_vec.end(), 0.); - thrust::host_vector prev_sum_vec(kM * kBatch); - thrust::fill(prev_sum_vec.begin(), prev_sum_vec.end(), 0.); + thrust::host_vector prev_sum_vec(kM * kBatch); + thrust::fill(prev_sum_vec.begin(), prev_sum_vec.end(), 0.); - thrust::host_vector cur_sum_vec(kM * kBatch); - thrust::fill(cur_sum_vec.begin(), cur_sum_vec.end(), 0.); + thrust::host_vector cur_sum_vec(kM * kBatch); + thrust::fill(cur_sum_vec.begin(), cur_sum_vec.end(), 0.); - thrust::host_vector new_sum_vec(kM * kBatch); - thrust::fill(new_sum_vec.begin(), new_sum_vec.end(), 0.); + thrust::host_vector new_sum_vec(kM * kBatch); + thrust::fill(new_sum_vec.begin(), new_sum_vec.end(), 0.); - thrust::device_vector d_a = h_a; - thrust::device_vector d_b = h_b; - thrust::device_vector d_c = h_c; - thrust::device_vector d_d = h_d; + thrust::device_vector d_a = h_a; + thrust::device_vector d_b = h_b; + thrust::device_vector d_c = h_c; + thrust::device_vector d_d = h_d; - const InType* A = thrust::raw_pointer_cast(d_a.data()); - const InType* B = thrust::raw_pointer_cast(d_b.data()); - const InType* C = thrust::raw_pointer_cast(d_c.data()); - InType* D = thrust::raw_pointer_cast(d_d.data()); + const InType* A = thrust::raw_pointer_cast(d_a.data()); + const InType* B = thrust::raw_pointer_cast(d_b.data()); + const InType* C = thrust::raw_pointer_cast(d_c.data()); + InType* D = thrust::raw_pointer_cast(d_d.data()); - int block_x = (kM + kTM - 1) / kTM; - int block_y = (kP + kTP - 1) / kTP; - int block_z = kBatch; + int block_x = (kM + kTM - 1) / kTM; + int block_y = (kP + kTP - 1) / kTP; + int block_z = kBatch; - dim3 grid(block_x, block_y, block_z); - dim3 block(kThreads, 1, 1); + dim3 grid(block_x, block_y, block_z); + dim3 block(kThreads, 1, 1); - int shm_input = - (kTM * kTK * kStagesQK + kTK * kTN * kStagesQK + kTN * kTP * kStagesV); - int shm_output = kTM * kTP; - int shm_size = shm_input < shm_output ? shm_output * sizeof(InType) - : shm_input * sizeof(InType); + int shm_input = + (kTM * kTK * kStagesQK + kTK * kTN * kStagesQK + kTN * kTP * kStagesV); + int shm_output = kTM * kTP; + int shm_size = shm_input < shm_output ? shm_output * sizeof(InType) + : shm_input * sizeof(InType); - using Traits = - benchmarks::cutlass_wrapper::FATraits; + using Traits = + benchmarks::cutlass_wrapper::FATraits; - auto fa_kernel = - benchmarks::cutlass_wrapper::fa_kernel; + auto fa_kernel = + benchmarks::cutlass_wrapper::fa_kernel; - if (shm_size > 48 * 1024) { - cudaFuncSetAttribute( - fa_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size); - } + if (shm_size > 48 * 1024) { + cudaFuncSetAttribute(fa_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + shm_size); + } - fa_kernel<<>>(A, B, C, D); + fa_kernel<<>>(A, B, C, D); - cudaDeviceSynchronize(); + cudaDeviceSynchronize(); } int main() { - // - run<64, 64, 128, 128, 64, 64, 128, 128, 1, 1, 1, 1>(); - // run<64, 64, 256, 128, 64, 64, 128, 128, 1, 1, 1, 1>(); + // + run<64, 64, 128, 128, 64, 64, 128, 128, 1, 1, 1, 1>(); + // run<64, 64, 256, 128, 64, 64, 128, 128, 1, 1, 1, 1>(); } diff --git a/benchmarks/cpp/flashattention/reduce.cuh b/benchmarks/cpp/flashattention/reduce.cuh index f85d34d6..301c9eab 100644 --- a/benchmarks/cpp/flashattention/reduce.cuh +++ b/benchmarks/cpp/flashattention/reduce.cuh @@ -13,40 +13,37 @@ namespace cutlass_wrapper { using namespace cute; struct MaxOp_float { - DEVICE float operator()(float const& x, float const& y) { - return max(x, y); - } + DEVICE float operator()(float const& x, float const& y) { return max(x, y); } }; template struct SumOp { - DEVICE T operator()(T const& x, T const& y) { return x + y; } + DEVICE T operator()(T const& x, T const& y) { return x + y; } }; template struct SumAbsOp { - DEVICE T operator()(T const& x, T const& y) { return x + abs(y); } + DEVICE T operator()(T const& x, T const& y) { return x + abs(y); } }; template struct Allreduce { - static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || - THREADS == 4); - template - static DEVICE T run(T x, Operator& op) { - constexpr int OFFSET = THREADS / 2; - x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); - return Allreduce::run(x, op); - } + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static DEVICE T run(T x, Operator& op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } }; template <> struct Allreduce<2> { - template - static DEVICE T run(T x, Operator& op) { - x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); - return x; - } + template + static DEVICE T run(T x, Operator& op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; + } }; template const& tensor, cute::Tensor& summary, Operator& op) { - using namespace cute; - static_assert(Layout0::rank == 2, "Only support 2D Tensor"); - static_assert(Layout1::rank == 1, "Only support 1D Tensor"); - CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); + using namespace cute; + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); #pragma unroll - for (int mi = 0; mi < size<0>(tensor); mi++) { - summary(mi) = - zero_init ? op(0, tensor(mi, 0)) : op(summary(mi), tensor(mi, 0)); + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = + zero_init ? op(0, tensor(mi, 0)) : op(summary(mi), tensor(mi, 0)); #pragma unroll - for (int ni = 1; ni < size<1>(tensor); ni++) { - summary(mi) = op(summary(mi), tensor(mi, ni)); - } + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); } + } } template DEVICE void quad_allreduce_(cute::Tensor& dst, cute::Tensor& src, Operator& op) { - using namespace cute; - CUTE_STATIC_ASSERT_V(size(dst) == size(src)); + using namespace cute; + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); #pragma unroll - for (int i = 0; i < size(dst); i++) { - dst(i) = Allreduce<4>::run(src(i), op); - } + for (int i = 0; i < size(dst); i++) { + dst(i) = Allreduce<4>::run(src(i), op); + } } template & dst, cute::Tensor& src, Operator& op) { - using namespace cute; - CUTE_STATIC_ASSERT_V(size(dst) == size(src)); + using namespace cute; + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); #pragma unroll - for (int i = 0; i < size(dst); i++) { - dst(i) = Allreduce<8>::run(src(i), op); - } + for (int i = 0; i < size(dst); i++) { + dst(i) = Allreduce<8>::run(src(i), op); + } } template DEVICE void allreduce_(cute::Tensor& dst, cute::Tensor& src, Operator& op) { - using namespace cute; - CUTE_STATIC_ASSERT_V(size(dst) == size(src)); + using namespace cute; + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); #pragma unroll - for (int i = 0; i < size(dst); i++) { - dst(i) = Allreduce::run(src(i), op); - } + for (int i = 0; i < size(dst); i++) { + dst(i) = Allreduce::run(src(i), op); + } } template DEVICE void reduce_(cute::Tensor const& tensor, cute::Tensor& summary, Operator& op) { - thread_reduce_(tensor, summary, op); - allreduce_(summary, summary, op); + thread_reduce_(tensor, summary, op); + allreduce_(summary, summary, op); } template DEVICE void reduce_max(cute::Tensor const& tensor, cute::Tensor& max) { - MaxOp_float max_op; - reduce_(tensor, max, max_op); + MaxOp_float max_op; + reduce_(tensor, max, max_op); } template DEVICE void reduce_sum(cute::Tensor const& tensor, cute::Tensor& sum) { - SumOp sum_op; - reduce_(tensor, sum, sum_op); + SumOp sum_op; + reduce_(tensor, sum, sum_op); } template DEVICE void reduce_sumabs(cute::Tensor const& tensor, cute::Tensor& sum) { - SumAbsOp sumabs_op; - reduce_(tensor, sum, sumabs_op); + SumAbsOp sumabs_op; + reduce_(tensor, sum, sumabs_op); } } // namespace cutlass_wrapper diff --git a/benchmarks/cpp/flashattention/util.hpp b/benchmarks/cpp/flashattention/util.hpp index 1cc00eb4..1d96c1c8 100644 --- a/benchmarks/cpp/flashattention/util.hpp +++ b/benchmarks/cpp/flashattention/util.hpp @@ -9,24 +9,23 @@ #include float rand_float(float a = 1e-1, float b = 5e-2) { - float random = ((float)rand()) / (float)RAND_MAX; - float diff = b - a; - float r = random * diff; - return a + r; + float random = ((float)rand()) / (float)RAND_MAX; + float diff = b - a; + float r = random * diff; + return a + r; } bool check_results(const __half* values1, const __half* values2, int numel) { - bool passed = true; - const float epsilon = 1e-1; + bool passed = true; + const float epsilon = 1e-1; - for (int i = 0; i < numel; ++i) { - if (fabs(__half2float(values1[i]) - __half2float(values2[i])) > - epsilon) { - printf("%d-th value differs: %.3f vs. %.3f\n", i, - __half2float(values1[i]), __half2float(values2[i])); - passed = false; - break; - } + for (int i = 0; i < numel; ++i) { + if (fabs(__half2float(values1[i]) - __half2float(values2[i])) > epsilon) { + printf("%d-th value differs: %.3f vs. %.3f\n", i, + __half2float(values1[i]), __half2float(values2[i])); + passed = false; + break; } - return passed; + } + return passed; } diff --git a/benchmarks/cpp/fused_two_gemms/bench.cu b/benchmarks/cpp/fused_two_gemms/bench.cu index 2fd36d58..fcceb5ce 100644 --- a/benchmarks/cpp/fused_two_gemms/bench.cu +++ b/benchmarks/cpp/fused_two_gemms/bench.cu @@ -11,238 +11,236 @@ using namespace tilefusion::kernels; template __attribute__((global)) void kernel_wrapper(const InType* A, const InType* B, const InType* C, InType* D) { - ke_fused_two_gemms(A, B, C, D); + ke_fused_two_gemms(A, B, C, D); } template void run(float epsilon = 1e-3) { - using InType = __half; - using AccType = float; + using InType = __half; + using AccType = float; - static constexpr int kM = dim_size<0, WholeShape>; - static constexpr int kN = dim_size<1, WholeShape>; - static constexpr int kK = dim_size<2, WholeShape>; - static constexpr int kP = dim_size<3, WholeShape>; + static constexpr int kM = dim_size<0, WholeShape>; + static constexpr int kN = dim_size<1, WholeShape>; + static constexpr int kK = dim_size<2, WholeShape>; + static constexpr int kP = dim_size<3, WholeShape>; + + static constexpr int kTM = dim_size<0, CtaTileShape>; + static constexpr int kTN = dim_size<1, CtaTileShape>; + static constexpr int kTK = dim_size<2, CtaTileShape>; + static constexpr int kTP = dim_size<3, CtaTileShape>; - static constexpr int kTM = dim_size<0, CtaTileShape>; - static constexpr int kTN = dim_size<1, CtaTileShape>; - static constexpr int kTK = dim_size<2, CtaTileShape>; - static constexpr int kTP = dim_size<3, CtaTileShape>; + static_assert(kK == kTK, "The current implementation requires kTK == K."); + static_assert(kP == kTP, "The current implementation requires kTP == P."); - static_assert(kK == kTK, "The current implementation requires kTK == K."); - static_assert(kP == kTP, "The current implementation requires kTP == P."); + static constexpr int kWarpPerRow = tl::num_rows; + static constexpr int kWarpPerCol = tl::num_cols; - static constexpr int kWarpPerRow = tl::num_rows; - static constexpr int kWarpPerCol = tl::num_cols; + thrust::host_vector h_a(kM * kK * kBatch); - thrust::host_vector h_a(kM * kK * kBatch); + for (int i = 0; i < h_a.size(); ++i) { + h_a[i] = static_cast(rand_float()); + } - for (int i = 0; i < h_a.size(); ++i) { - h_a[i] = static_cast(rand_float()); - } + thrust::host_vector h_b(kK * kN * kBatch); + for (int i = 0; i < h_b.size(); ++i) { + h_b[i] = static_cast(rand_float()); + } - thrust::host_vector h_b(kK * kN * kBatch); - for (int i = 0; i < h_b.size(); ++i) { - h_b[i] = static_cast(rand_float()); - } + thrust::host_vector h_c(kN * kP * kBatch); + for (int i = 0; i < h_c.size(); ++i) { + h_c[i] = static_cast(rand_float()); + } - thrust::host_vector h_c(kN * kP * kBatch); - for (int i = 0; i < h_c.size(); ++i) { - h_c[i] = static_cast(rand_float()); - } + thrust::host_vector h_d(kM * kP * kBatch); + thrust::fill(h_d.begin(), h_d.end(), 0.); - thrust::host_vector h_d(kM * kP * kBatch); - thrust::fill(h_d.begin(), h_d.end(), 0.); + thrust::host_vector h_d2(kM * kP * kBatch); + thrust::fill(h_d2.begin(), h_d2.end(), 0.); - thrust::host_vector h_d2(kM * kP * kBatch); - thrust::fill(h_d2.begin(), h_d2.end(), 0.); + thrust::host_vector<__half> h_d3(kM * kP * kBatch); + thrust::fill(h_d3.begin(), h_d3.end(), 0.); - thrust::host_vector<__half> h_d3(kM * kP * kBatch); - thrust::fill(h_d3.begin(), h_d3.end(), 0.); + thrust::device_vector d_a = h_a; + thrust::device_vector d_b = h_b; + thrust::device_vector d_c = h_c; + thrust::device_vector d_d = h_d; + thrust::device_vector d_d2 = h_d2; + thrust::device_vector<__half> d_d3 = h_d3; - thrust::device_vector d_a = h_a; - thrust::device_vector d_b = h_b; - thrust::device_vector d_c = h_c; - thrust::device_vector d_d = h_d; - thrust::device_vector d_d2 = h_d2; - thrust::device_vector<__half> d_d3 = h_d3; + const cutlass::half_t* CA = thrust::raw_pointer_cast(d_a.data()); + const cutlass::half_t* CB = thrust::raw_pointer_cast(d_b.data()); + const cutlass::half_t* CC = thrust::raw_pointer_cast(d_c.data()); + cutlass::half_t* CD = thrust::raw_pointer_cast(d_d2.data()); - const cutlass::half_t* CA = thrust::raw_pointer_cast(d_a.data()); - const cutlass::half_t* CB = thrust::raw_pointer_cast(d_b.data()); - const cutlass::half_t* CC = thrust::raw_pointer_cast(d_c.data()); - cutlass::half_t* CD = thrust::raw_pointer_cast(d_d2.data()); + const InType* A = reinterpret_cast(CA); + const InType* B = reinterpret_cast(CB); + const InType* C = reinterpret_cast(CC); + InType* D = thrust::raw_pointer_cast(d_d.data()); - const InType* A = reinterpret_cast(CA); - const InType* B = reinterpret_cast(CB); - const InType* C = reinterpret_cast(CC); - InType* D = thrust::raw_pointer_cast(d_d.data()); + using Config = FusedTwoGemmsTraits; - using Config = FusedTwoGemmsTraits; + int block_x = CeilDiv; + int block_y = CeilDiv; + int block_z = kBatch; + static constexpr int kThreads = tl::get_numel * 32; - int block_x = CeilDiv; - int block_y = CeilDiv; - int block_z = kBatch; - static constexpr int kThreads = tl::get_numel * 32; + dim3 grid(block_x, block_y, block_z); + dim3 block(kThreads, 1, 1); - dim3 grid(block_x, block_y, block_z); - dim3 block(kThreads, 1, 1); + static constexpr int kShmInput = (kTM * kTK + kTK * kTN + kTN * kTP); + static constexpr int kShmOutput = kTM * kTP; + static constexpr int kSharedSize = kShmInput < kShmOutput + ? kShmOutput * sizeof(InType) + : kShmInput * sizeof(InType); - static constexpr int kShmInput = (kTM * kTK + kTK * kTN + kTN * kTP); - static constexpr int kShmOutput = kTM * kTP; - static constexpr int kSharedSize = kShmInput < kShmOutput - ? kShmOutput * sizeof(InType) - : kShmInput * sizeof(InType); + auto ke_tilefusion = &kernel_wrapper; - auto ke_tilefusion = &kernel_wrapper; + auto cutlass_fused_gemm = + &cute_fused_gemm; - auto cutlass_fused_gemm = - &cute_fused_gemm; + if (kSharedSize > 48 * 1024) { + cudaFuncSetAttribute(ke_tilefusion, + cudaFuncAttributeMaxDynamicSharedMemorySize, + kSharedSize); + } - if (kSharedSize > 48 * 1024) { - cudaFuncSetAttribute(ke_tilefusion, - cudaFuncAttributeMaxDynamicSharedMemorySize, - kSharedSize); - } + ke_tilefusion<<>>(A, B, C, D); + cudaDeviceSynchronize(); - ke_tilefusion<<>>(A, B, C, D); - cudaDeviceSynchronize(); - - h_d = d_d; + h_d = d_d; - cutlass_fused_gemm(CA, CB, CC, CD, false, 0, 0); - h_d2 = d_d2; + cutlass_fused_gemm(CA, CB, CC, CD, false, 0, 0); + h_d2 = d_d2; - thrust::host_vector h_acc(kM * kN * kBatch); - thrust::fill(h_acc.begin(), h_acc.end(), 0.); - thrust::device_vector d_acc = h_acc; + thrust::host_vector h_acc(kM * kN * kBatch); + thrust::fill(h_acc.begin(), h_acc.end(), 0.); + thrust::device_vector d_acc = h_acc; - cublas_two_gemms(kM, kN, kK, kP, kBatch, A, B, C, - thrust::raw_pointer_cast(d_d3.data()), - thrust::raw_pointer_cast(d_acc.data()), false); - cudaDeviceSynchronize(); - h_acc = d_acc; - h_d3 = d_d3; + cublas_two_gemms(kM, kN, kK, kP, kBatch, A, B, C, + thrust::raw_pointer_cast(d_d3.data()), + thrust::raw_pointer_cast(d_acc.data()), false); + cudaDeviceSynchronize(); + h_acc = d_acc; + h_d3 = d_d3; #ifdef DEBUG - InType* data = thrust::raw_pointer_cast(h_d.data()); - cutlass::half_t* cutlass_data = thrust::raw_pointer_cast(h_d2.data()); - __half* cutlass_data_half = reinterpret_cast<__half*>(cutlass_data); - __half* ground_truth = thrust::raw_pointer_cast(h_d3.data()); - - const int numel = 256; - printf("ours:\n"); - for (int i = 0; i < numel; ++i) { - printf("%.3f, ", __half2float(data[i])); - if (i && (i + 1) % 16 == 0) printf("\n"); - } - printf("cutlass:\n"); - for (int i = 0; i < numel; ++i) { - printf("%.3f, ", __half2float(cutlass_data_half[i])); - if (i && (i + 1) % 16 == 0) printf("\n"); - } - printf("\nground_truth:\n"); - for (int i = 0; i < numel; ++i) { - printf("%.3f, ", __half2float(ground_truth[i])); - if (i && (i + 1) % 16 == 0) printf("\n"); - } - - bool passed1 = check_results(data, ground_truth, kM * kP, epsilon); - bool passed2 = - check_results(cutlass_data_half, ground_truth, kM * kP, epsilon); - std::cout << "passed1: " << passed1 << ", passed2: " << passed2 - << std::endl; - - if (passed1 && passed2) { - std::cout << "[" << kM << ", " << kN << ", " << kK << ", " << kP - << "], batch = " << kBatch << ", passed." << std::endl; - } else { - std::cout << "[" << kM << ", " << kN << ", " << kK << ", " << kP - << "], batch = " << kBatch << ", failed." << std::endl; - } + InType* data = thrust::raw_pointer_cast(h_d.data()); + cutlass::half_t* cutlass_data = thrust::raw_pointer_cast(h_d2.data()); + __half* cutlass_data_half = reinterpret_cast<__half*>(cutlass_data); + __half* ground_truth = thrust::raw_pointer_cast(h_d3.data()); + + const int numel = 256; + printf("ours:\n"); + for (int i = 0; i < numel; ++i) { + printf("%.3f, ", __half2float(data[i])); + if (i && (i + 1) % 16 == 0) printf("\n"); + } + printf("cutlass:\n"); + for (int i = 0; i < numel; ++i) { + printf("%.3f, ", __half2float(cutlass_data_half[i])); + if (i && (i + 1) % 16 == 0) printf("\n"); + } + printf("\nground_truth:\n"); + for (int i = 0; i < numel; ++i) { + printf("%.3f, ", __half2float(ground_truth[i])); + if (i && (i + 1) % 16 == 0) printf("\n"); + } + + bool passed1 = check_results(data, ground_truth, kM * kP, epsilon); + bool passed2 = + check_results(cutlass_data_half, ground_truth, kM * kP, epsilon); + std::cout << "passed1: " << passed1 << ", passed2: " << passed2 << std::endl; + + if (passed1 && passed2) { + std::cout << "[" << kM << ", " << kN << ", " << kK << ", " << kP + << "], batch = " << kBatch << ", passed." << std::endl; + } else { + std::cout << "[" << kM << ", " << kN << ", " << kK << ", " << kP + << "], batch = " << kBatch << ", failed." << std::endl; + } #endif - CudaTimer timer; - const int warm_up = 10; - const int iters = 50; - - for (int i = 0; i < warm_up; ++i) { - ke_tilefusion<<>>(A, B, C, D); - } - cudaDeviceSynchronize(); - - timer.start(); - for (int i = 0; i < iters; ++i) { - ke_tilefusion<<>>(A, B, C, D); - } - cudaDeviceSynchronize(); - float tilefusion_time = timer.stop() / iters; - - float cutlass_time = - cutlass_fused_gemm(CA, CB, CC, CD, true, warm_up, iters); - - float cublas_time = cublas_two_gemms( - kM, kN, kK, kP, kBatch, A, B, C, thrust::raw_pointer_cast(d_d3.data()), - thrust::raw_pointer_cast(d_acc.data()), true); - - std::cout << "[" << kM << ", " << kN << ", " << kK << ", " << kP << "]\t[" - << kTM << ", " << kTN << ", " << kTK << ", " << kTP << "]\t" - << std::fixed << std::setprecision(4) << cublas_time << "\t" - << cutlass_time << "(" << cutlass_time / cublas_time << ")" - << "\t" << tilefusion_time << "(" << tilefusion_time / cublas_time - << ")" << std::endl; + CudaTimer timer; + const int warm_up = 10; + const int iters = 50; + + for (int i = 0; i < warm_up; ++i) { + ke_tilefusion<<>>(A, B, C, D); + } + cudaDeviceSynchronize(); + + timer.start(); + for (int i = 0; i < iters; ++i) { + ke_tilefusion<<>>(A, B, C, D); + } + cudaDeviceSynchronize(); + float tilefusion_time = timer.stop() / iters; + + float cutlass_time = cutlass_fused_gemm(CA, CB, CC, CD, true, warm_up, iters); + + float cublas_time = cublas_two_gemms( + kM, kN, kK, kP, kBatch, A, B, C, thrust::raw_pointer_cast(d_d3.data()), + thrust::raw_pointer_cast(d_acc.data()), true); + + std::cout << "[" << kM << ", " << kN << ", " << kK << ", " << kP << "]\t[" + << kTM << ", " << kTN << ", " << kTK << ", " << kTP << "]\t" + << std::fixed << std::setprecision(4) << cublas_time << "\t" + << cutlass_time << "(" << cutlass_time / cublas_time << ")" + << "\t" << tilefusion_time << "(" << tilefusion_time / cublas_time + << ")" << std::endl; } int main() { - using WarpLayout = tl::RowMajor<4, 1>; - static constexpr int kSharedAccess = 128; + using WarpLayout = tl::RowMajor<4, 1>; + static constexpr int kSharedAccess = 128; - run, - B2BGemmShape<64 /*kTM*/, 128 /*kTN*/, 128 /*kTK*/, 128 /*kTP*/>, - WarpLayout, 1, kSharedAccess>(5e-3); + run, + B2BGemmShape<64 /*kTM*/, 128 /*kTN*/, 128 /*kTK*/, 128 /*kTP*/>, + WarpLayout, 1, kSharedAccess>(5e-3); - run, - B2BGemmShape<64 /*kTM*/, 128 /*kTN*/, 128 /*kTK*/, 128 /*kTP*/>, - WarpLayout, 1, kSharedAccess>(5e-3); + run, + B2BGemmShape<64 /*kTM*/, 128 /*kTN*/, 128 /*kTK*/, 128 /*kTP*/>, + WarpLayout, 1, kSharedAccess>(5e-3); - run, - B2BGemmShape<64 /*kTM*/, 128 /*kTN*/, 128 /*kTK*/, 128 /*kTP*/>, - WarpLayout, 1, kSharedAccess>(5e-3); + run, + B2BGemmShape<64 /*kTM*/, 128 /*kTN*/, 128 /*kTK*/, 128 /*kTP*/>, + WarpLayout, 1, kSharedAccess>(5e-3); - run, - B2BGemmShape<64 /*kTM*/, 128 /*kTN*/, 128 /*kTK*/, 128 /*kTP*/>, - WarpLayout, 1, kSharedAccess>(5e-3); + run, + B2BGemmShape<64 /*kTM*/, 128 /*kTN*/, 128 /*kTK*/, 128 /*kTP*/>, + WarpLayout, 1, kSharedAccess>(5e-3); - run, - B2BGemmShape<64 /*kTM*/, 128 /*kTN*/, 128 /*kTK*/, 128 /*kTP*/>, - WarpLayout, 1, kSharedAccess>(5e-3); + run, + B2BGemmShape<64 /*kTM*/, 128 /*kTN*/, 128 /*kTK*/, 128 /*kTP*/>, + WarpLayout, 1, kSharedAccess>(5e-3); - run, - B2BGemmShape<64 /*kTM*/, 128 /*kTN*/, 128 /*kTK*/, 128 /*kTP*/>, - WarpLayout, 1, kSharedAccess>(5e-3); + run, + B2BGemmShape<64 /*kTM*/, 128 /*kTN*/, 128 /*kTK*/, 128 /*kTP*/>, + WarpLayout, 1, kSharedAccess>(5e-3); - run, - B2BGemmShape<64 /*kTM*/, 128 /*kTN*/, 128 /*kTK*/, 128 /*kTP*/>, - WarpLayout, 1, kSharedAccess>(5e-3); + run, + B2BGemmShape<64 /*kTM*/, 128 /*kTN*/, 128 /*kTK*/, 128 /*kTP*/>, + WarpLayout, 1, kSharedAccess>(5e-3); - run, - B2BGemmShape<64 /*kTM*/, 128 /*kTN*/, 128 /*kTK*/, 128 /*kTP*/>, - WarpLayout, 1, kSharedAccess>(5e-3); + run, + B2BGemmShape<64 /*kTM*/, 128 /*kTN*/, 128 /*kTK*/, 128 /*kTP*/>, + WarpLayout, 1, kSharedAccess>(5e-3); - run, - B2BGemmShape<64 /*kTM*/, 128 /*kTN*/, 128 /*kTK*/, 128 /*kTP*/>, - WarpLayout, 1, kSharedAccess>(5e-3); + run, + B2BGemmShape<64 /*kTM*/, 128 /*kTN*/, 128 /*kTK*/, 128 /*kTP*/>, + WarpLayout, 1, kSharedAccess>(5e-3); - run, - B2BGemmShape<64 /*kTM*/, 128 /*kTN*/, 128 /*kTK*/, 128 /*kTP*/>, - WarpLayout, 1, kSharedAccess>(5e-3); + run, + B2BGemmShape<64 /*kTM*/, 128 /*kTN*/, 128 /*kTK*/, 128 /*kTP*/>, + WarpLayout, 1, kSharedAccess>(5e-3); - run, - B2BGemmShape<64 /*kTM*/, 128 /*kTN*/, 128 /*kTK*/, 128 /*kTP*/>, - WarpLayout, 1, kSharedAccess>(5e-3); + run, + B2BGemmShape<64 /*kTM*/, 128 /*kTN*/, 128 /*kTK*/, 128 /*kTP*/>, + WarpLayout, 1, kSharedAccess>(5e-3); - return 0; + return 0; } diff --git a/benchmarks/cpp/fused_two_gemms/cutlass_fused_two_gemms.cuh b/benchmarks/cpp/fused_two_gemms/cutlass_fused_two_gemms.cuh index b4818cae..5060cc52 100644 --- a/benchmarks/cpp/fused_two_gemms/cutlass_fused_two_gemms.cuh +++ b/benchmarks/cpp/fused_two_gemms/cutlass_fused_two_gemms.cuh @@ -18,72 +18,71 @@ template > struct FusedGemmTraits : public Base { - using Element = Element_; - - static_assert(kTK == kTN && kTN == kTP, - "Fused GEMM requires kTK == kTN == kTP."); - static_assert(kWarpPerCol == 1, - "The Fused GEMM requires a single warp along CTA tile."); - - using GmemLayoutA = Layout, Int>, Stride, _1>>; - using GmemLayoutB = Layout, Int>, Stride, _1>>; - using GmemLayoutC = Layout, Int>, Stride, _1>>; - using GmemLayoutD = Layout, Int>, Stride, _1>>; - - // TODO(haruhi): The current implementation uses ldmatrix.x4 - // instruction which requires the TileMMA configuration to be - // fixed as follows. Make it able to be tuned by policy in - // future implementation. - using TiledMma = - TiledMMA, // for ampere - Layout, Int, _1>>, - Tile, Int<16 * kWarpPerCol>, _16>>; - static constexpr int kThreads = size(TiledMma{}); - static_assert(kThreads == kWarpPerRow * kWarpPerCol * 32); - - static constexpr int kNumPerAccess = Base::kNumPerAccess; - static constexpr int kThreadsPerCol = CeilDiv; - static constexpr int kThreadsPerRow = CeilDiv; - - // static constexpr int kSwizzle = (kTK == 32 ? 2 : 3); - using SmemLayoutAtom = decltype(composition( - Swizzle<2, 3, 3>{}, - Layout>, Stride, _1>>{})); - - using SmemLayoutA = - decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); - - // The current implementation requires B are laid out in column - // major. a [kTK, kTN] matrix in column major can be interpreted - // as a [kTN, kTK] matrix in row major. - using SmemLayoutB = - decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); - // a [kTN, kTP] matrix in column major fashion, - // can be interpreted as a [kTP, kTN] matrix in row major fashion. - using SmemLayoutC = - decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); + using Element = Element_; + + static_assert(kTK == kTN && kTN == kTP, + "Fused GEMM requires kTK == kTN == kTP."); + static_assert(kWarpPerCol == 1, + "The Fused GEMM requires a single warp along CTA tile."); + + using GmemLayoutA = Layout, Int>, Stride, _1>>; + using GmemLayoutB = Layout, Int>, Stride, _1>>; + using GmemLayoutC = Layout, Int>, Stride, _1>>; + using GmemLayoutD = Layout, Int>, Stride, _1>>; + + // TODO(haruhi): The current implementation uses ldmatrix.x4 + // instruction which requires the TileMMA configuration to be + // fixed as follows. Make it able to be tuned by policy in + // future implementation. + using TiledMma = + TiledMMA, // for ampere + Layout, Int, _1>>, + Tile, Int<16 * kWarpPerCol>, _16>>; + static constexpr int kThreads = size(TiledMma{}); + static_assert(kThreads == kWarpPerRow * kWarpPerCol * 32); + + static constexpr int kNumPerAccess = Base::kNumPerAccess; + static constexpr int kThreadsPerCol = CeilDiv; + static constexpr int kThreadsPerRow = CeilDiv; + + // static constexpr int kSwizzle = (kTK == 32 ? 2 : 3); + using SmemLayoutAtom = decltype(composition( + Swizzle<2, 3, 3>{}, Layout>, Stride, _1>>{})); + + using SmemLayoutA = + decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); + + // The current implementation requires B are laid out in column + // major. a [kTK, kTN] matrix in column major can be interpreted + // as a [kTN, kTK] matrix in row major. + using SmemLayoutB = + decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); + // a [kTN, kTP] matrix in column major fashion, + // can be interpreted as a [kTP, kTN] matrix in row major fashion. + using SmemLayoutC = + decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); #ifdef CP_ASYNC_SM80_ENABLED - using CopyInstG2S = - Copy_Atom, Element>; + using CopyInstG2S = + Copy_Atom, Element>; #else - using CopyInstG2S = Copy_Atom; + using CopyInstG2S = Copy_Atom; #endif - using TiledCopyG2S = decltype(make_tiled_copy( - CopyInstG2S{}, - Layout, Int>, - Stride, _1>>{}, - Layout>>{})); - - using TiledCopyS2G = decltype(make_tiled_copy( - Copy_Atom{}, - Layout, Int>, - Stride, _1>>{}, - Layout>>{})); - using SmemLayoutD = - decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); - - using StoreD_R2S = R2SCopy2D; + using TiledCopyG2S = decltype(make_tiled_copy( + CopyInstG2S{}, + Layout, Int>, + Stride, _1>>{}, + Layout>>{})); + + using TiledCopyS2G = decltype(make_tiled_copy( + Copy_Atom{}, + Layout, Int>, + Stride, _1>>{}, + Layout>>{})); + using SmemLayoutD = + decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); + + using StoreD_R2S = R2SCopy2D; }; template __global__ void fused_gemm_kernel(const Element* dA, const Element* dB, const Element* dC, Element* dD) { - // Advance to the global data tile to the current CTA. - Element* A = const_cast(dA) + blockIdx.x * (kTM * kK); - Element* B = const_cast(dB); - Element* gC_ptr = const_cast(dC) + blockIdx.y * (kTP * kN); - Element* gD_ptr = dD + blockIdx.x * (kTM * kP) + (blockIdx.y * kTP); - - Element* gA_ptr; - Element* gB_ptr; - - extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[]; - auto* shm = reinterpret_cast(shared_buf); - // pointers to shared memory tiles - Element* sA_ptr = shm; - Element* sB_ptr = shm + kTM * kTK; - Element* sC_ptr = shm + kTM * kTK + kTK * kTN; - Element* sD_ptr = shm; - - typename KeTraits::TiledMma mma; // for shared memory to register copy - typename KeTraits::TiledCopyG2S tiled_copy; - - auto rA = make_s2rA(sA_ptr, typename KeTraits::SmemLayoutA{}, mma); - auto rB = make_s2rB(sB_ptr, typename KeTraits::SmemLayoutB{}, mma); - auto acc1 = get_acc(mma); // accumulator for the 1st gemm - - auto rC = make_s2rB(sC_ptr, typename KeTraits::SmemLayoutC{}, mma); - auto acc2 = get_acc(mma); // accumulator for the 2nd gemm - - typename KeTraits::StoreD_R2S sD; // declare register to shared store plan - - for (int n = 0; n < kN; n += kTN) { // iterate over N - gA_ptr = A; // A tile is repeated loaded - gB_ptr = B + n * kK; - for (int k = 0; k < kK; k += kTK) { // iterate over K - copy_tile_g2s(gA_ptr, sA_ptr, typename KeTraits::GmemLayoutA{}, - typename KeTraits::SmemLayoutA{}, tiled_copy); - copy_tile_g2s(gB_ptr, sB_ptr, typename KeTraits::GmemLayoutB{}, - typename KeTraits::SmemLayoutB{}, tiled_copy); - __copy_async(); - __syncthreads(); - - // iterate over the register tiles along the kTK dimension - for (int i = 0; i < rA.get_iters(); ++i) { - rA.copy(i); // load A register tile from shared memory - rB.copy(i); // load B register tile from shared memory - gemm(mma, rA[i], rB[i], acc1); // compute - } - __syncthreads(); - - gA_ptr += kTK; - gB_ptr += kTK; - } - - // The output type of the first tensor core matrix multiplication is - // float32. However, before the second GEMM operation, the output - // needs to be converted to half precision. - auto acc_half = convert_type(acc1); - auto rA2 = convert_layout(acc_half); - - // load C tile from global to shared memory - copy_tile_g2s(gC_ptr, sC_ptr, typename KeTraits::GmemLayoutC{}, - typename KeTraits::SmemLayoutC{}, tiled_copy); - __copy_async(); - __syncthreads(); - - // iterate over register tiles along the kTN dimension - for (int i = 0; i < rC.get_iters(); ++i) { - rC.copy(i); // load C tile from shared memory to register - gemm(mma, rA2[i], rC[i], acc2); // compute - } - __syncthreads(); - - clear(acc1); - gC_ptr += kTN; + // Advance to the global data tile to the current CTA. + Element* A = const_cast(dA) + blockIdx.x * (kTM * kK); + Element* B = const_cast(dB); + Element* gC_ptr = const_cast(dC) + blockIdx.y * (kTP * kN); + Element* gD_ptr = dD + blockIdx.x * (kTM * kP) + (blockIdx.y * kTP); + + Element* gA_ptr; + Element* gB_ptr; + + extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[]; + auto* shm = reinterpret_cast(shared_buf); + // pointers to shared memory tiles + Element* sA_ptr = shm; + Element* sB_ptr = shm + kTM * kTK; + Element* sC_ptr = shm + kTM * kTK + kTK * kTN; + Element* sD_ptr = shm; + + typename KeTraits::TiledMma mma; // for shared memory to register copy + typename KeTraits::TiledCopyG2S tiled_copy; + + auto rA = make_s2rA(sA_ptr, typename KeTraits::SmemLayoutA{}, mma); + auto rB = make_s2rB(sB_ptr, typename KeTraits::SmemLayoutB{}, mma); + auto acc1 = get_acc(mma); // accumulator for the 1st gemm + + auto rC = make_s2rB(sC_ptr, typename KeTraits::SmemLayoutC{}, mma); + auto acc2 = get_acc(mma); // accumulator for the 2nd gemm + + typename KeTraits::StoreD_R2S sD; // declare register to shared store plan + + for (int n = 0; n < kN; n += kTN) { // iterate over N + gA_ptr = A; // A tile is repeated loaded + gB_ptr = B + n * kK; + for (int k = 0; k < kK; k += kTK) { // iterate over K + copy_tile_g2s(gA_ptr, sA_ptr, typename KeTraits::GmemLayoutA{}, + typename KeTraits::SmemLayoutA{}, tiled_copy); + copy_tile_g2s(gB_ptr, sB_ptr, typename KeTraits::GmemLayoutB{}, + typename KeTraits::SmemLayoutB{}, tiled_copy); + __copy_async(); + __syncthreads(); + + // iterate over the register tiles along the kTK dimension + for (int i = 0; i < rA.get_iters(); ++i) { + rA.copy(i); // load A register tile from shared memory + rB.copy(i); // load B register tile from shared memory + gemm(mma, rA[i], rB[i], acc1); // compute + } + __syncthreads(); + + gA_ptr += kTK; + gB_ptr += kTK; } - // store register tile to shared memory - sD.copy(acc2, shm); + // The output type of the first tensor core matrix multiplication is + // float32. However, before the second GEMM operation, the output + // needs to be converted to half precision. + auto acc_half = convert_type(acc1); + auto rA2 = convert_layout(acc_half); + + // load C tile from global to shared memory + copy_tile_g2s(gC_ptr, sC_ptr, typename KeTraits::GmemLayoutC{}, + typename KeTraits::SmemLayoutC{}, tiled_copy); + __copy_async(); __syncthreads(); - copy_tile_s2g(sD_ptr, gD_ptr, typename KeTraits::SmemLayoutD{}, - typename KeTraits::GmemLayoutD{}, - typename KeTraits::TiledCopyS2G{}); + // iterate over register tiles along the kTN dimension + for (int i = 0; i < rC.get_iters(); ++i) { + rC.copy(i); // load C tile from shared memory to register + gemm(mma, rA2[i], rC[i], acc2); // compute + } + __syncthreads(); + + clear(acc1); + gC_ptr += kTN; + } + + // store register tile to shared memory + sD.copy(acc2, shm); + __syncthreads(); + + copy_tile_s2g(sD_ptr, gD_ptr, typename KeTraits::SmemLayoutD{}, + typename KeTraits::GmemLayoutD{}, + typename KeTraits::TiledCopyS2G{}); } } // namespace cutlass_wrapper } // namespace benchmarks @@ -186,50 +185,50 @@ template ; - - auto kernel = &benchmarks::cutlass_wrapper::fused_gemm_kernel< - Element, kM, kN, kK, kP, kTM, kTN, kTK, kTP, KeTraits>; - - int shm_input = (kTM * kTK + kTK * kTN + kTN * kTP); - int shm_output = kTM * kTP; - int shm_size = shm_input < shm_output ? shm_output * sizeof(Element) - : shm_input * sizeof(Element); - - // maximal statically allocated smem per block - const int kMaxSmemPerBlock = 48 * 1024; - if (shm_size > kMaxSmemPerBlock) { - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size); + using namespace benchmarks::cutlass_wrapper; + + using KeTraits = benchmarks::cutlass_wrapper::FusedGemmTraits< + Element, kWarpPerRow, kWarpPerCol, kM, kN, kK, kP, kTM, kTN, kTK, kTP>; + + auto kernel = &benchmarks::cutlass_wrapper::fused_gemm_kernel< + Element, kM, kN, kK, kP, kTM, kTN, kTK, kTP, KeTraits>; + + int shm_input = (kTM * kTK + kTK * kTN + kTN * kTP); + int shm_output = kTM * kTP; + int shm_size = shm_input < shm_output ? shm_output * sizeof(Element) + : shm_input * sizeof(Element); + + // maximal statically allocated smem per block + const int kMaxSmemPerBlock = 48 * 1024; + if (shm_size > kMaxSmemPerBlock) { + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + shm_size); + } + + // blocks are launched along the M and P dimensions. + int block_x = (kM + kTM - 1) / kTM; + int block_y = (kP + kTP - 1) / kTP; + const int kThreads = KeTraits::kThreads; + + dim3 gridDim(block_x, block_y, 1); + dim3 blockDim(kThreads, 1, 1); + + float elapsed = 0.; + if (timeit) { + for (int i = 0; i < warp_up; ++i) { + kernel<<>>(dA, dB, dC, dD); } + cudaDeviceSynchronize(); - // blocks are launched along the M and P dimensions. - int block_x = (kM + kTM - 1) / kTM; - int block_y = (kP + kTP - 1) / kTP; - const int kThreads = KeTraits::kThreads; - - dim3 gridDim(block_x, block_y, 1); - dim3 blockDim(kThreads, 1, 1); - - float elapsed = 0.; - if (timeit) { - for (int i = 0; i < warp_up; ++i) { - kernel<<>>(dA, dB, dC, dD); - } - cudaDeviceSynchronize(); - - CudaTimer timer; - timer.start(); - for (int i = 0; i < iters; ++i) { - kernel<<>>(dA, dB, dC, dD); - } - cudaDeviceSynchronize(); - elapsed = timer.stop() / iters; - } else { - kernel<<>>(dA, dB, dC, dD); + CudaTimer timer; + timer.start(); + for (int i = 0; i < iters; ++i) { + kernel<<>>(dA, dB, dC, dD); } - return elapsed; + cudaDeviceSynchronize(); + elapsed = timer.stop() / iters; + } else { + kernel<<>>(dA, dB, dC, dD); + } + return elapsed; } diff --git a/benchmarks/cpp/fused_two_gemms/util.cuh b/benchmarks/cpp/fused_two_gemms/util.cuh index 2a0f5e65..b5a663f4 100644 --- a/benchmarks/cpp/fused_two_gemms/util.cuh +++ b/benchmarks/cpp/fused_two_gemms/util.cuh @@ -14,37 +14,37 @@ template using B2BGemmShape = TileShape; float rand_float(float a = 1e-1, float b = 5e-2) { - float random = ((float)rand()) / (float)RAND_MAX; - float diff = b - a; - float r = random * diff; - return a + r; + float random = ((float)rand()) / (float)RAND_MAX; + float diff = b - a; + float r = random * diff; + return a + r; } void cublas_two_gemms_impl(cublasHandle_t handle, int kM, int kN, int kK, int kP, int kBatch, const __half* A, const __half* B, const __half* C, __half* D, __half* acc) { - __half alf = static_cast<__half>(1.); - __half bet = static_cast<__half>(0.); - for (int b = 0; b < kBatch; ++b) { - A += b * kM * kK; - B += b * kK * kN; - C += b * kM * kN; - acc += b * kM * kN; - D += b * kM * kP; - // acc = A @ B - // acc^T = B^T @ A^T - // [n, m] = [n, k] @ [k, m] - cublasHgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N /* transb*/, kN, kM, kK, - &alf, B, kK, A, kK, &bet, acc, kN); - - // D and acc are laid out in row-major fashion, while C is in column - // major fashion. Operands of cuBLAS is by default in column - // fashion. D = acc @ C D^T = C^T @ acc^T; [p, m] = [p, n] @ [n, m] - cublasHgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N /* transb*/, kP, kM, kN, - &alf, C, kN, acc, kN, &bet, D, kP); - - // cudaDeviceSynchronize(); - } + __half alf = static_cast<__half>(1.); + __half bet = static_cast<__half>(0.); + for (int b = 0; b < kBatch; ++b) { + A += b * kM * kK; + B += b * kK * kN; + C += b * kM * kN; + acc += b * kM * kN; + D += b * kM * kP; + // acc = A @ B + // acc^T = B^T @ A^T + // [n, m] = [n, k] @ [k, m] + cublasHgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N /* transb*/, kN, kM, kK, &alf, + B, kK, A, kK, &bet, acc, kN); + + // D and acc are laid out in row-major fashion, while C is in column + // major fashion. Operands of cuBLAS is by default in column + // fashion. D = acc @ C D^T = C^T @ acc^T; [p, m] = [p, n] @ [n, m] + cublasHgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N /* transb*/, kP, kM, kN, &alf, + C, kN, acc, kN, &bet, D, kP); + + // cudaDeviceSynchronize(); + } } /* In this implementation, A and D are interpreted as being laid out in @@ -60,99 +60,97 @@ float cublas_two_gemms(int kM, int kN, int kK, int kP, int kBatch, const __half* As, const __half* Bs, const __half* Cs, __half* Ds, __half* accs, bool timeit = false, int warm_up = 5, int iters = 20) { - cublasHandle_t handle; - cublasCreate(&handle); - - const __half* A = As; - const __half* B = Bs; - const __half* C = Cs; - __half* acc = accs; - __half* D = Ds; - - float elapsed = 0.; - - if (timeit) { - for (int i = 0; i < warm_up; ++i) { - cublas_two_gemms_impl(handle, kM, kN, kK, kP, kBatch, A, B, C, D, - acc); - } - cudaDeviceSynchronize(); - - CudaTimer timer; - timer.start(); - for (int i = 0; i < iters; ++i) { - cublas_two_gemms_impl(handle, kM, kN, kK, kP, kBatch, A, B, C, D, - acc); - } - cudaDeviceSynchronize(); - elapsed = timer.stop() / iters; - } else { - cublas_two_gemms_impl(handle, kM, kN, kK, kP, kBatch, A, B, C, D, acc); + cublasHandle_t handle; + cublasCreate(&handle); + + const __half* A = As; + const __half* B = Bs; + const __half* C = Cs; + __half* acc = accs; + __half* D = Ds; + + float elapsed = 0.; + + if (timeit) { + for (int i = 0; i < warm_up; ++i) { + cublas_two_gemms_impl(handle, kM, kN, kK, kP, kBatch, A, B, C, D, acc); } + cudaDeviceSynchronize(); + CudaTimer timer; + timer.start(); + for (int i = 0; i < iters; ++i) { + cublas_two_gemms_impl(handle, kM, kN, kK, kP, kBatch, A, B, C, D, acc); + } cudaDeviceSynchronize(); - cublasDestroy(handle); - return elapsed; + elapsed = timer.stop() / iters; + } else { + cublas_two_gemms_impl(handle, kM, kN, kK, kP, kBatch, A, B, C, D, acc); + } + + cudaDeviceSynchronize(); + cublasDestroy(handle); + return elapsed; } bool check_results(const float* values1, const __half* values2, int numel, float epsilon) { - bool passed = true; + bool passed = true; - float v2 = 0.; + float v2 = 0.; - double total_diff = 0.; - double max_abs_diff = FLT_MIN; - double diff = 0.; + double total_diff = 0.; + double max_abs_diff = FLT_MIN; + double diff = 0.; - for (int i = 0; i < numel; ++i) { - v2 = __half2float(values2[i]); - diff = abs(values1[i] - v2); - max_abs_diff = max_abs_diff < diff ? diff : max_abs_diff; - total_diff += diff; + for (int i = 0; i < numel; ++i) { + v2 = __half2float(values2[i]); + diff = abs(values1[i] - v2); + max_abs_diff = max_abs_diff < diff ? diff : max_abs_diff; + total_diff += diff; #ifdef DEBUG - if (diff > epsilon) { - printf("%d-th value has large differences: %.3f vs. %.3f\n", i, - values1[i], v2); - } -#endif + if (diff > epsilon) { + printf("%d-th value has large differences: %.3f vs. %.3f\n", i, + values1[i], v2); } +#endif + } - double avg_diff = total_diff / numel; - if (avg_diff > epsilon) passed = false; + double avg_diff = total_diff / numel; + if (avg_diff > epsilon) passed = false; - return passed; + return passed; } bool check_results(const __half* values1, const __half* values2, int numel, float epsilon) { - bool passed = true; + bool passed = true; - float v1 = 0.; - float v2 = 0.; + float v1 = 0.; + float v2 = 0.; - double total_diff = 0.; - double max_abs_diff = FLT_MIN; - double diff = 0.; + double total_diff = 0.; + double max_abs_diff = FLT_MIN; + double diff = 0.; - for (int i = 0; i < numel; ++i) { - v1 = __half2float(values1[i]); - v2 = __half2float(values2[i]); - diff = abs(v1 - v2); - max_abs_diff = max_abs_diff < diff ? diff : max_abs_diff; - total_diff += diff; + for (int i = 0; i < numel; ++i) { + v1 = __half2float(values1[i]); + v2 = __half2float(values2[i]); + diff = abs(v1 - v2); + max_abs_diff = max_abs_diff < diff ? diff : max_abs_diff; + total_diff += diff; #ifdef DEBUG - if (diff > epsilon) { - printf("%d-th value has large differences: %.3f vs. %.3f\n", i, - values1[i], v2); - } -#endif + if (diff > epsilon) { + printf("%d-th value has large differences: %.3f vs. %.3f\n", i, + values1[i], v2); } +#endif + } - double avg_diff = total_diff / numel; - if (avg_diff > epsilon) passed = false; + double avg_diff = total_diff / numel; + if (avg_diff > epsilon) passed = false; - return passed; + return passed; } diff --git a/benchmarks/cpp/g2s_copy/cutlass_copy.cuh b/benchmarks/cpp/g2s_copy/cutlass_copy.cuh index 091d7e9e..09109e84 100644 --- a/benchmarks/cpp/g2s_copy/cutlass_copy.cuh +++ b/benchmarks/cpp/g2s_copy/cutlass_copy.cuh @@ -19,122 +19,122 @@ template struct Loader { - DEVICE void operator()(const Element* src_, Element* dst_) { - int tid = threadIdx.x; + DEVICE void operator()(const Element* src_, Element* dst_) { + int tid = threadIdx.x; - auto gtile = make_tensor(make_gmem_ptr(src_), src_layout_); - auto stile = make_tensor(make_smem_ptr(dst_), dst_layout_); + auto gtile = make_tensor(make_gmem_ptr(src_), src_layout_); + auto stile = make_tensor(make_smem_ptr(dst_), dst_layout_); - auto loader = tiled_copy_.get_thread_slice(tid); + auto loader = tiled_copy_.get_thread_slice(tid); - auto src = loader.partition_S(gtile); - auto dst = loader.partition_D(stile); + auto src = loader.partition_S(gtile); + auto dst = loader.partition_D(stile); #pragma unroll - for (int i = 0; i < int(size<1>(src)); ++i) + for (int i = 0; i < int(size<1>(src)); ++i) #pragma unroll - for (int j = 0; j < int(size<2>(src)); ++j) - cute::copy(tiled_copy_, src(cute::_, i, j), dst(cute::_, i, j)); - } - - private: - // source - using GlobalLayout = - cute::Layout, Int>, Stride, _1>>; - GlobalLayout src_layout_; - - // destination - using LayoutAtom = - decltype(composition(cute::Swizzle<2, 3, 3>{}, - cute::Layout, Stride<_64, _1>>{})); - using SharedLayout = decltype(tile_to_shape( - LayoutAtom{}, Shape, Int>{}, cute::Step<_2, _1>{})); - SharedLayout dst_layout_; - - // tiled copy - static constexpr int kThreadCols = kCols * 16 / 128; - static constexpr int kThreadRows = kWarpRows * kWarpCols * 32 / kThreadCols; - - using ThreadLayout = cute::Layout, Int>, - Stride, _1>>; - using ValueLayout = cute::Layout>; - - using CopyInst = - Copy_Atom, Element>; - using TiledCopy = - decltype(make_tiled_copy(CopyInst{}, ThreadLayout{}, ValueLayout{})); - TiledCopy tiled_copy_; + for (int j = 0; j < int(size<2>(src)); ++j) + cute::copy(tiled_copy_, src(cute::_, i, j), dst(cute::_, i, j)); + } + + private: + // source + using GlobalLayout = + cute::Layout, Int>, Stride, _1>>; + GlobalLayout src_layout_; + + // destination + using LayoutAtom = + decltype(composition(cute::Swizzle<2, 3, 3>{}, + cute::Layout, Stride<_64, _1>>{})); + using SharedLayout = decltype(tile_to_shape( + LayoutAtom{}, Shape, Int>{}, cute::Step<_2, _1>{})); + SharedLayout dst_layout_; + + // tiled copy + static constexpr int kThreadCols = kCols * 16 / 128; + static constexpr int kThreadRows = kWarpRows * kWarpCols * 32 / kThreadCols; + + using ThreadLayout = cute::Layout, Int>, + Stride, _1>>; + using ValueLayout = cute::Layout>; + + using CopyInst = + Copy_Atom, Element>; + using TiledCopy = + decltype(make_tiled_copy(CopyInst{}, ThreadLayout{}, ValueLayout{})); + TiledCopy tiled_copy_; }; template struct Storer { - DEVICE void operator()(const Element* src_, Element* dst_) { - int tid = threadIdx.x; + DEVICE void operator()(const Element* src_, Element* dst_) { + int tid = threadIdx.x; - auto stile = make_tensor(make_smem_ptr(src_), src_layout_); // shared - auto gtile = make_tensor(make_gmem_ptr(dst_), dst_layout_); // global + auto stile = make_tensor(make_smem_ptr(src_), src_layout_); // shared + auto gtile = make_tensor(make_gmem_ptr(dst_), dst_layout_); // global - auto loader = tiled_copy_.get_thread_slice(tid); + auto loader = tiled_copy_.get_thread_slice(tid); - auto src = loader.partition_S(stile); - auto dst = loader.partition_D(gtile); + auto src = loader.partition_S(stile); + auto dst = loader.partition_D(gtile); #pragma unroll - for (int i = 0; i < int(size<1>(src)); ++i) + for (int i = 0; i < int(size<1>(src)); ++i) #pragma unroll - for (int j = 0; j < int(size<2>(src)); ++j) - cute::copy(tiled_copy_, src(cute::_, i, j), dst(cute::_, i, j)); - } - - private: - // declare the source layout - using LayoutAtom = - decltype(composition(cute::Swizzle<2, 3, 3>{}, - cute::Layout, Stride<_64, _1>>{})); - using SharedLayout = decltype(tile_to_shape( - LayoutAtom{}, Shape, Int>{}, cute::Step<_2, _1>{})); - SharedLayout src_layout_; - - // declare the destination layout - using GlobalLayout = - cute::Layout, Int>, Stride, _1>>; - GlobalLayout dst_layout_; - - // declare the tiled copy - static constexpr int kThreadCols = kCols * 16 / 128; - static constexpr int kThreadRows = kWarpRows * kWarpCols * 32 / kThreadCols; - using ThreadLayout = cute::Layout, Int>, - Stride, _1>>; - using ValueLayout = cute::Layout>; - - using CopyInst = Copy_Atom; - using TiledCopy = - decltype(make_tiled_copy(CopyInst{}, ThreadLayout{}, ValueLayout{})); - TiledCopy tiled_copy_; + for (int j = 0; j < int(size<2>(src)); ++j) + cute::copy(tiled_copy_, src(cute::_, i, j), dst(cute::_, i, j)); + } + + private: + // declare the source layout + using LayoutAtom = + decltype(composition(cute::Swizzle<2, 3, 3>{}, + cute::Layout, Stride<_64, _1>>{})); + using SharedLayout = decltype(tile_to_shape( + LayoutAtom{}, Shape, Int>{}, cute::Step<_2, _1>{})); + SharedLayout src_layout_; + + // declare the destination layout + using GlobalLayout = + cute::Layout, Int>, Stride, _1>>; + GlobalLayout dst_layout_; + + // declare the tiled copy + static constexpr int kThreadCols = kCols * 16 / 128; + static constexpr int kThreadRows = kWarpRows * kWarpCols * 32 / kThreadCols; + using ThreadLayout = cute::Layout, Int>, + Stride, _1>>; + using ValueLayout = cute::Layout>; + + using CopyInst = Copy_Atom; + using TiledCopy = + decltype(make_tiled_copy(CopyInst{}, ThreadLayout{}, ValueLayout{})); + TiledCopy tiled_copy_; }; } // namespace template __global__ void cutlass_g2s_data_transfer(const Element* src, Element* dst) { - extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; - auto* buf = reinterpret_cast(buf_); + extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; + auto* buf = reinterpret_cast(buf_); - using G2S = Loader; - G2S loader; + using G2S = Loader; + G2S loader; - using S2G = Storer; - S2G storer; + using S2G = Storer; + S2G storer; - for (int k = 0; k < kRepeat; ++k) { - loader(src, buf); + for (int k = 0; k < kRepeat; ++k) { + loader(src, buf); - cutlass_wrapper::__copy_async(); - __syncthreads(); + cutlass_wrapper::__copy_async(); + __syncthreads(); - storer(buf, dst); - __syncthreads(); - } + storer(buf, dst); + __syncthreads(); + } } diff --git a/benchmarks/cpp/g2s_copy/main.cu b/benchmarks/cpp/g2s_copy/main.cu index 2c1dfd87..805084d8 100644 --- a/benchmarks/cpp/g2s_copy/main.cu +++ b/benchmarks/cpp/g2s_copy/main.cu @@ -20,147 +20,147 @@ const int kRepeat = 100; template bool check_results(const Element* dst1, const Element* dst2, int64_t numel) { - float epsilon = 1e-3; - for (int i = 0; i < numel; ++i) { - float v1 = abs(static_cast(dst1[i])); - float v2 = abs(static_cast(dst2[i])); - if (v1 - v2 > epsilon) { - std::cerr << "Mismatch at " << i << ": " << v1 << " vs " << v2 - << std::endl; - return false; - } + float epsilon = 1e-3; + for (int i = 0; i < numel; ++i) { + float v1 = abs(static_cast(dst1[i])); + float v2 = abs(static_cast(dst2[i])); + if (v1 - v2 > epsilon) { + std::cerr << "Mismatch at " << i << ": " << v1 << " vs " << v2 + << std::endl; + return false; } - return true; + } + return true; } template float test_tilefusion(const Element* src, Element* dst) { - using Global = GlobalTile; - using Shared = SharedTile; + using Global = GlobalTile; + using Shared = SharedTile; - using Loader = GlobalToSharedLoader; - Loader loader; + using Loader = GlobalToSharedLoader; + Loader loader; - using Storer = SharedToGlobalStorer; - Storer storer; + using Storer = SharedToGlobalStorer; + Storer storer; - auto kernel = - &g2s_data_transfer; + auto kernel = + &g2s_data_transfer; - static const int kThreads = WarpLayout::kNumel * 32; - int shm_size = Shared::kNumel * sizeof(Element); + static const int kThreads = WarpLayout::kNumel * 32; + int shm_size = Shared::kNumel * sizeof(Element); - if (shm_size > 48 * 1024) { - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size); - } + if (shm_size > 48 * 1024) { + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + shm_size); + } - dim3 grids(1, 1, 1); - dim3 blocks(kThreads); + dim3 grids(1, 1, 1); + dim3 blocks(kThreads); - for (int i = 0; i < warmup; ++i) // warm up - kernel<<>>(src, dst, loader, storer); - cudaDeviceSynchronize(); + for (int i = 0; i < warmup; ++i) // warm up + kernel<<>>(src, dst, loader, storer); + cudaDeviceSynchronize(); - CudaTimer timer; - timer.start(); - for (int i = 0; i < iters; ++i) - kernel<<>>(src, dst, loader, storer); - cudaDeviceSynchronize(); - return timer.stop() / iters; + CudaTimer timer; + timer.start(); + for (int i = 0; i < iters; ++i) + kernel<<>>(src, dst, loader, storer); + cudaDeviceSynchronize(); + return timer.stop() / iters; } template float test_cutlass(const Element* src, Element* dst) { - auto kernel = &cutlass_g2s_data_transfer; - - int shm_size = Layout::kNumel * sizeof(Element); - int kThreads = WarpLayout::kNumel * 32; - - if (shm_size > 48 * 1024) { - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size); - } - - dim3 grids(1, 1, 1); - dim3 blocks(kThreads); - - for (int i = 0; i < warmup; ++i) { - kernel<<>>(src, dst); - } - cudaDeviceSynchronize(); - - CudaTimer timer; - timer.start(); - for (int i = 0; i < iters; ++i) { - kernel<<>>(src, dst); - } - cudaDeviceSynchronize(); - return timer.stop() / iters; + auto kernel = + &cutlass_g2s_data_transfer; + + int shm_size = Layout::kNumel * sizeof(Element); + int kThreads = WarpLayout::kNumel * 32; + + if (shm_size > 48 * 1024) { + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + shm_size); + } + + dim3 grids(1, 1, 1); + dim3 blocks(kThreads); + + for (int i = 0; i < warmup; ++i) { + kernel<<>>(src, dst); + } + cudaDeviceSynchronize(); + + CudaTimer timer; + timer.start(); + for (int i = 0; i < iters; ++i) { + kernel<<>>(src, dst); + } + cudaDeviceSynchronize(); + return timer.stop() / iters; } template void run_test_rowmajor() { - int numel = Layout::kNumel; + int numel = Layout::kNumel; - thrust::host_vector h_src(numel); - for (int i = 0; i < h_src.size(); ++i) - h_src[i] = static_cast(i % 2048); + thrust::host_vector h_src(numel); + for (int i = 0; i < h_src.size(); ++i) + h_src[i] = static_cast(i % 2048); - thrust::device_vector d_src = h_src; - const Element* src = thrust::raw_pointer_cast(d_src.data()); + thrust::device_vector d_src = h_src; + const Element* src = thrust::raw_pointer_cast(d_src.data()); - thrust::device_vector d_dst1(numel); - thrust::fill(d_dst1.begin(), d_dst1.end(), static_cast(0.)); - Element* dst1 = thrust::raw_pointer_cast(d_dst1.data()); + thrust::device_vector d_dst1(numel); + thrust::fill(d_dst1.begin(), d_dst1.end(), static_cast(0.)); + Element* dst1 = thrust::raw_pointer_cast(d_dst1.data()); - thrust::device_vector d_dst2(numel); - thrust::fill(d_dst2.begin(), d_dst2.end(), static_cast(0.)); - Element* dst2 = thrust::raw_pointer_cast(d_dst2.data()); + thrust::device_vector d_dst2(numel); + thrust::fill(d_dst2.begin(), d_dst2.end(), static_cast(0.)); + Element* dst2 = thrust::raw_pointer_cast(d_dst2.data()); - float t1 = test_tilefusion(src, dst1); - thrust::host_vector h_dst1 = d_dst1; + float t1 = test_tilefusion(src, dst1); + thrust::host_vector h_dst1 = d_dst1; - float t2 = test_cutlass(src, dst2); - thrust::host_vector h_dst2 = d_dst2; + float t2 = test_cutlass(src, dst2); + thrust::host_vector h_dst2 = d_dst2; - bool passed = check_results(thrust::raw_pointer_cast(h_dst1.data()), - thrust::raw_pointer_cast(h_dst2.data()), numel); - if (!passed) { - std::cerr << "Test failed" << std::endl; - return; - } + bool passed = check_results(thrust::raw_pointer_cast(h_dst1.data()), + thrust::raw_pointer_cast(h_dst2.data()), numel); + if (!passed) { + std::cerr << "Test failed" << std::endl; + return; + } - std::cout << "|RowMajor(" << Layout::kRows << ", " << Layout::kCols << ")|(" - << WarpLayout::kRows << ", " << WarpLayout::kCols << ")|" << t1 - << "|" << t2 << "|" << t1 / t2 << "|" << std::endl; + std::cout << "|RowMajor(" << Layout::kRows << ", " << Layout::kCols << ")|(" + << WarpLayout::kRows << ", " << WarpLayout::kCols << ")|" << t1 + << "|" << t2 << "|" << t1 / t2 << "|" << std::endl; } int main() { - std::cout << std::setprecision(4) - << "|Shape|Warp Layout|tilefusion(ms)|cutlass(ms)|Ratio|" - << std::endl - << "|:---|:---:|:---:|:---:|:---:|" << std::endl; + std::cout << std::setprecision(4) + << "|Shape|Warp Layout|tilefusion(ms)|cutlass(ms)|Ratio|" + << std::endl + << "|:---|:---:|:---:|:---:|:---:|" << std::endl; - using DType = __half; + using DType = __half; - run_test_rowmajor, tl::RowMajor<1, 1>>(); - run_test_rowmajor, tl::RowMajor<1, 1>>(); - run_test_rowmajor, tl::RowMajor<2, 1>>(); - run_test_rowmajor, tl::RowMajor<4, 1>>(); + run_test_rowmajor, tl::RowMajor<1, 1>>(); + run_test_rowmajor, tl::RowMajor<1, 1>>(); + run_test_rowmajor, tl::RowMajor<2, 1>>(); + run_test_rowmajor, tl::RowMajor<4, 1>>(); - run_test_rowmajor, tl::RowMajor<1, 1>>(); - run_test_rowmajor, tl::RowMajor<2, 2>>(); - run_test_rowmajor, tl::RowMajor<4, 2>>(); + run_test_rowmajor, tl::RowMajor<1, 1>>(); + run_test_rowmajor, tl::RowMajor<2, 2>>(); + run_test_rowmajor, tl::RowMajor<4, 2>>(); - run_test_rowmajor, tl::RowMajor<1, 1>>(); - run_test_rowmajor, tl::RowMajor<2, 2>>(); - run_test_rowmajor, tl::RowMajor<2, 4>>(); - run_test_rowmajor, tl::RowMajor<4, 4>>(); + run_test_rowmajor, tl::RowMajor<1, 1>>(); + run_test_rowmajor, tl::RowMajor<2, 2>>(); + run_test_rowmajor, tl::RowMajor<2, 4>>(); + run_test_rowmajor, tl::RowMajor<4, 4>>(); - return 0; + return 0; } diff --git a/benchmarks/cpp/g2s_copy/tilefusion_copy.cuh b/benchmarks/cpp/g2s_copy/tilefusion_copy.cuh index 0fa31acd..014e1759 100644 --- a/benchmarks/cpp/g2s_copy/tilefusion_copy.cuh +++ b/benchmarks/cpp/g2s_copy/tilefusion_copy.cuh @@ -11,19 +11,19 @@ template __global__ void g2s_data_transfer(const Element* src_ptr, Element* dst_ptr, Loader& loader, Storer& storer) { - extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; - auto* buf = reinterpret_cast(buf_); + extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; + auto* buf = reinterpret_cast(buf_); - Global src(src_ptr); - Shared inter(buf); - Global dst(dst_ptr); // global memory tile + Global src(src_ptr); + Shared inter(buf); + Global dst(dst_ptr); // global memory tile - for (int i = 0; i < kRepeat; ++i) { - loader(src, inter); - copy::__copy_async(); - __syncthreads(); + for (int i = 0; i < kRepeat; ++i) { + loader(src, inter); + copy::__copy_async(); + __syncthreads(); - storer(inter, dst); - __syncthreads(); - } + storer(inter, dst); + __syncthreads(); + } } diff --git a/benchmarks/cpp/gemm/bench.cu b/benchmarks/cpp/gemm/bench.cu index 6c3b1307..57fe31e8 100644 --- a/benchmarks/cpp/gemm/bench.cu +++ b/benchmarks/cpp/gemm/bench.cu @@ -21,185 +21,184 @@ static constexpr int kRK = 64; static constexpr int kSharedAccess = 64; void run_test(std::ofstream& fout) { - //// =============== Declaration =============== //// - static constexpr int kM = dim_size<0, WholeShape>; - static constexpr int kN = dim_size<1, WholeShape>; - static constexpr int kK = dim_size<2, WholeShape>; - - static constexpr int kTM = dim_size<0, CtaTileShape>; - static constexpr int kTN = dim_size<1, CtaTileShape>; - static constexpr int kTK = dim_size<2, CtaTileShape>; - - using InType = __half; - using AccType = float; - - using Config = KeGemmTraits; - auto tilefusion_gemm = - &gemm; - - using KeTraits = benchmarks::cutlass_wrapper::GemmTraits< - cutlass::half_t, kWarpPerRow, kWarpPerCol, kM, kN, kK, kTM, kTN, kTK>; - auto cutlass_gemm = - &benchmarks::cutlass_wrapper::gemm_kernel; - - static constexpr int inputs = kTK * (kTN + kTM) * sizeof(InType); - static constexpr int acc = kTM * kTN * sizeof(InType); - static constexpr int smem_size = inputs > acc ? inputs : acc; - - const int kMaxSmemPerBlock = 48 * 1024; - if (smem_size > kMaxSmemPerBlock) { - cudaFuncSetAttribute(tilefusion_gemm, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - cudaFuncSetAttribute(cutlass_gemm, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - } - - int block_x = benchmarks::CeilDiv; - int block_y = benchmarks::CeilDiv; - dim3 dim_grid(block_x, block_y, 1); - dim3 dim_block(Config::kThreads, 1, 1); - - std::cout << "Running test:" << std::endl - << "[M, N, K] = " << kM << ", " << kN << ", " << kK - << ", [TM, TN, TK] = " << kTM << ", " << kTN << ", " << kTK - << ", RK = " << kRK << ", WarpLayout = [" << kWarpPerRow << ", " - << kWarpPerCol << "]" << std::endl - << "blocks = [" << block_x << ", " << block_y << "]" << std::endl - << std::endl; - - //// =============== Prepare data =============== //// - // input matrix A - thrust::host_vector h_a(kM * kK); - for (int i = 0; i < h_a.size(); ++i) - h_a[i] = static_cast(rand_float()); - thrust::device_vector d_a = h_a; - const cutlass::half_t* dA = thrust::raw_pointer_cast(d_a.data()); - const __half* dA2 = reinterpret_cast(dA); - - // input matrix B - thrust::host_vector h_b(kK * kN); - for (int i = 0; i < h_b.size(); ++i) - h_b[i] = static_cast(rand_float()); - thrust::device_vector d_b = h_b; - const cutlass::half_t* dB = thrust::raw_pointer_cast(d_b.data()); - const __half* dB2 = reinterpret_cast(dB); - - // output matrix C for cutlass GEMM kernel - thrust::device_vector d_c(kM * kN); - cutlass::half_t* dC = thrust::raw_pointer_cast(d_c.data()); - thrust::device_vector d_c2(kM * kN); - InType* dC2 = thrust::raw_pointer_cast(d_c2.data()); - - // output matrix C for cublas gemm - thrust::device_vector<__half> d_c3(kM * kN); - __half* dC3 = thrust::raw_pointer_cast(d_c3.data()); - - thrust::host_vector h_c; - thrust::host_vector h_c2; - thrust::host_vector<__half> h_c3; + //// =============== Declaration =============== //// + static constexpr int kM = dim_size<0, WholeShape>; + static constexpr int kN = dim_size<1, WholeShape>; + static constexpr int kK = dim_size<2, WholeShape>; + + static constexpr int kTM = dim_size<0, CtaTileShape>; + static constexpr int kTN = dim_size<1, CtaTileShape>; + static constexpr int kTK = dim_size<2, CtaTileShape>; + + using InType = __half; + using AccType = float; + + using Config = KeGemmTraits; + auto tilefusion_gemm = + &gemm; + + using KeTraits = benchmarks::cutlass_wrapper::GemmTraits< + cutlass::half_t, kWarpPerRow, kWarpPerCol, kM, kN, kK, kTM, kTN, kTK>; + auto cutlass_gemm = + &benchmarks::cutlass_wrapper::gemm_kernel; + + static constexpr int inputs = kTK * (kTN + kTM) * sizeof(InType); + static constexpr int acc = kTM * kTN * sizeof(InType); + static constexpr int smem_size = inputs > acc ? inputs : acc; + + const int kMaxSmemPerBlock = 48 * 1024; + if (smem_size > kMaxSmemPerBlock) { + cudaFuncSetAttribute(tilefusion_gemm, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + cudaFuncSetAttribute( + cutlass_gemm, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } + + int block_x = benchmarks::CeilDiv; + int block_y = benchmarks::CeilDiv; + dim3 dim_grid(block_x, block_y, 1); + dim3 dim_block(Config::kThreads, 1, 1); + + std::cout << "Running test:" << std::endl + << "[M, N, K] = " << kM << ", " << kN << ", " << kK + << ", [TM, TN, TK] = " << kTM << ", " << kTN << ", " << kTK + << ", RK = " << kRK << ", WarpLayout = [" << kWarpPerRow << ", " + << kWarpPerCol << "]" << std::endl + << "blocks = [" << block_x << ", " << block_y << "]" << std::endl + << std::endl; + + //// =============== Prepare data =============== //// + // input matrix A + thrust::host_vector h_a(kM * kK); + for (int i = 0; i < h_a.size(); ++i) + h_a[i] = static_cast(rand_float()); + thrust::device_vector d_a = h_a; + const cutlass::half_t* dA = thrust::raw_pointer_cast(d_a.data()); + const __half* dA2 = reinterpret_cast(dA); + + // input matrix B + thrust::host_vector h_b(kK * kN); + for (int i = 0; i < h_b.size(); ++i) + h_b[i] = static_cast(rand_float()); + thrust::device_vector d_b = h_b; + const cutlass::half_t* dB = thrust::raw_pointer_cast(d_b.data()); + const __half* dB2 = reinterpret_cast(dB); + + // output matrix C for cutlass GEMM kernel + thrust::device_vector d_c(kM * kN); + cutlass::half_t* dC = thrust::raw_pointer_cast(d_c.data()); + thrust::device_vector d_c2(kM * kN); + InType* dC2 = thrust::raw_pointer_cast(d_c2.data()); + + // output matrix C for cublas gemm + thrust::device_vector<__half> d_c3(kM * kN); + __half* dC3 = thrust::raw_pointer_cast(d_c3.data()); + + thrust::host_vector h_c; + thrust::host_vector h_c2; + thrust::host_vector<__half> h_c3; //// =============== check correctness =============== //// #ifdef CHECK_CORRECTNESS - thrust::fill(d_c.begin(), d_c.end(), static_cast(0.)); - thrust::fill(d_c2.begin(), d_c2.end(), static_cast(0.)); - thrust::fill(d_c3.begin(), d_c3.end(), static_cast<__half>(0.)); + thrust::fill(d_c.begin(), d_c.end(), static_cast(0.)); + thrust::fill(d_c2.begin(), d_c2.end(), static_cast(0.)); + thrust::fill(d_c3.begin(), d_c3.end(), static_cast<__half>(0.)); + + cutlass_gemm<<>>(dA, dB, dC); + cudaDeviceSynchronize(); + h_c = d_c; + + tilefusion_gemm<<>>(dA2, dB2, dC2); + cudaDeviceSynchronize(); + h_c2 = d_c2; + + // cublas + cublas_hgemm(kM, kN, kK, dA2, dB2, dC3, false /*timeit*/); + h_c3 = d_c3; + + bool passed1 = + check_results(thrust::raw_pointer_cast(h_c2.data()), /*tilefusion */ + thrust::raw_pointer_cast(h_c.data()), /*cutlass */ kM * kN); + + bool passed2 = + check_results(thrust::raw_pointer_cast(h_c3.data()), /*cublas */ + thrust::raw_pointer_cast(h_c.data()), /*cutlass */ kM * kN); + + if (!(passed1 && passed2)) { + std::cerr << "Test failed" << std::endl; + return; + } + std::cout << "Test passed" << std::endl; +#endif - cutlass_gemm<<>>(dA, dB, dC); - cudaDeviceSynchronize(); - h_c = d_c; + //// =============== Timing =============== //// + thrust::fill(d_c.begin(), d_c.end(), static_cast(0.)); + thrust::fill(d_c2.begin(), d_c2.end(), static_cast(0.)); + thrust::fill(d_c3.begin(), d_c3.end(), static_cast<__half>(0.)); + float cublas_time = cublas_hgemm(kM, kN, kK, dA2, dB2, dC3, true); + h_c3 = d_c3; + + const int warm_up = 10; + const int iters = 50; + for (int i = 0; i < warm_up; ++i) { + cutlass_gemm<<>>(dA, dB, dC); tilefusion_gemm<<>>(dA2, dB2, dC2); - cudaDeviceSynchronize(); - h_c2 = d_c2; - - // cublas - cublas_hgemm(kM, kN, kK, dA2, dB2, dC3, false /*timeit*/); - h_c3 = d_c3; - - bool passed1 = check_results( - thrust::raw_pointer_cast(h_c2.data()), /*tilefusion */ - thrust::raw_pointer_cast(h_c.data()), /*cutlass */ kM * kN); - - bool passed2 = check_results( - thrust::raw_pointer_cast(h_c3.data()), /*cublas */ - thrust::raw_pointer_cast(h_c.data()), /*cutlass */ kM * kN); - - if (!(passed1 && passed2)) { - std::cerr << "Test failed" << std::endl; - return; - } - std::cout << "Test passed" << std::endl; -#endif + } + cudaDeviceSynchronize(); + + CudaTimer timer; + timer.start(); + for (int i = 0; i < iters; ++i) { + cutlass_gemm<<>>(dA, dB, dC); + } + cudaDeviceSynchronize(); + float cutlass_time = timer.stop() / iters; - //// =============== Timing =============== //// - thrust::fill(d_c.begin(), d_c.end(), static_cast(0.)); - thrust::fill(d_c2.begin(), d_c2.end(), static_cast(0.)); - thrust::fill(d_c3.begin(), d_c3.end(), static_cast<__half>(0.)); - - float cublas_time = cublas_hgemm(kM, kN, kK, dA2, dB2, dC3, true); - h_c3 = d_c3; - - const int warm_up = 10; - const int iters = 50; - for (int i = 0; i < warm_up; ++i) { - cutlass_gemm<<>>(dA, dB, dC); - tilefusion_gemm<<>>(dA2, dB2, dC2); - } - cudaDeviceSynchronize(); - - CudaTimer timer; - timer.start(); - for (int i = 0; i < iters; ++i) { - cutlass_gemm<<>>(dA, dB, dC); - } - cudaDeviceSynchronize(); - float cutlass_time = timer.stop() / iters; - - timer.start(); - for (int i = 0; i < iters; ++i) { - tilefusion_gemm<<>>(dA2, dB2, dC2); - } - cudaDeviceSynchronize(); - float tilefusion_time = timer.stop() / iters; - - float base = cublas_time; - - fout << "[" << kM << ", " << kN << ", " << kK << "]\t[" << kTM << ", " - << kTN << ", " << kTK << "]\t" << kRK << "\t[" << kWarpPerRow << ", " - << kWarpPerCol << "]\t" << cublas_time << "\t" << cutlass_time << "(" - << std::setprecision(2) << cutlass_time / base << ")" - << "\t" << std::setprecision(6) << tilefusion_time << " (" - << std::setprecision(2) << tilefusion_time / base << ")" << std::endl; + timer.start(); + for (int i = 0; i < iters; ++i) { + tilefusion_gemm<<>>(dA2, dB2, dC2); + } + cudaDeviceSynchronize(); + float tilefusion_time = timer.stop() / iters; + + float base = cublas_time; + + fout << "[" << kM << ", " << kN << ", " << kK << "]\t[" << kTM << ", " << kTN + << ", " << kTK << "]\t" << kRK << "\t[" << kWarpPerRow << ", " + << kWarpPerCol << "]\t" << cublas_time << "\t" << cutlass_time << "(" + << std::setprecision(2) << cutlass_time / base << ")" + << "\t" << std::setprecision(6) << tilefusion_time << " (" + << std::setprecision(2) << tilefusion_time / base << ")" << std::endl; } int main() { - std::ofstream fout; - fout.setf(std::ios::fixed); - fout.precision(6); + std::ofstream fout; + fout.setf(std::ios::fixed); + fout.precision(6); - auto dev_name = tilefusion::get_device_name(); - std::stringstream file_name; - file_name << "figures/bench_" << dev_name << "_gemm.tsv"; - fout.open(file_name.str(), std::ios::out); + auto dev_name = tilefusion::get_device_name(); + std::stringstream file_name; + file_name << "figures/bench_" << dev_name << "_gemm.tsv"; + fout.open(file_name.str(), std::ios::out); - fout << "[M, N, K]\t[kTM, kTN, kTK]\tkRK\tWarp Layout\t" - "cuBLAS(ms)\tcutlass(ms)\ttilefusion(ms)" - << std::endl; + fout << "[M, N, K]\t[kTM, kTN, kTK]\tkRK\tWarp Layout\t" + "cuBLAS(ms)\tcutlass(ms)\ttilefusion(ms)" + << std::endl; - run_test(fout); - return 0; + run_test(fout); + return 0; } diff --git a/benchmarks/cpp/gemm/cutlass_gemm.cuh b/benchmarks/cpp/gemm/cutlass_gemm.cuh index be54c850..3f37eda9 100644 --- a/benchmarks/cpp/gemm/cutlass_gemm.cuh +++ b/benchmarks/cpp/gemm/cutlass_gemm.cuh @@ -19,115 +19,115 @@ template > struct GemmTraits : public Base { - using Element = Element_; - - static_assert(kTM % kWarpPerRow == 0, - "the M dimension of the CTA tile should be divisible by the " - "number of warps along that that dimension."); - static_assert(kTN % kWarpPerCol == 0, - "the N dimension of the CTA tile should be divisible by the " - "number of warps along that that dimension."); - - // declare global to shared memory copy layout. - using GmemLayoutA = Layout, Int>, Stride, _1>>; - using GmemLayoutB = Layout, Int>, Stride, _1>>; - using GmemLayoutC = Layout, Int>, Stride, _1>>; - - using TiledMma = - TiledMMA, // for ampere - Layout, Int, _1>>, - Tile, Int<16 * kWarpPerCol>, _16>>; - - static constexpr int kThreads = size(TiledMma{}); - static_assert(kThreads == kWarpPerRow * kWarpPerCol * 32); - - static constexpr int kNumPerAccess = Base::kNumPerAccess; - static constexpr int kThreadsPerCol = CeilDiv; - static constexpr int kThreadsPerRow = CeilDiv; - - using SmemLayoutAtom = decltype(composition( - Swizzle<2, 3, 3>{}, Layout>, - Stride, _1>>{})); - - using SmemLayoutA = - decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); - using SmemLayoutB = - decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); - using SmemLayoutC = - decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); + using Element = Element_; + + static_assert(kTM % kWarpPerRow == 0, + "the M dimension of the CTA tile should be divisible by the " + "number of warps along that that dimension."); + static_assert(kTN % kWarpPerCol == 0, + "the N dimension of the CTA tile should be divisible by the " + "number of warps along that that dimension."); + + // declare global to shared memory copy layout. + using GmemLayoutA = Layout, Int>, Stride, _1>>; + using GmemLayoutB = Layout, Int>, Stride, _1>>; + using GmemLayoutC = Layout, Int>, Stride, _1>>; + + using TiledMma = + TiledMMA, // for ampere + Layout, Int, _1>>, + Tile, Int<16 * kWarpPerCol>, _16>>; + + static constexpr int kThreads = size(TiledMma{}); + static_assert(kThreads == kWarpPerRow * kWarpPerCol * 32); + + static constexpr int kNumPerAccess = Base::kNumPerAccess; + static constexpr int kThreadsPerCol = CeilDiv; + static constexpr int kThreadsPerRow = CeilDiv; + + using SmemLayoutAtom = decltype(composition( + Swizzle<2, 3, 3>{}, Layout>, + Stride, _1>>{})); + + using SmemLayoutA = + decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); + using SmemLayoutB = + decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); + using SmemLayoutC = + decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); #ifdef CP_ASYNC_SM80_ENABLED - using CopyInstG2S = - Copy_Atom, Element>; + using CopyInstG2S = + Copy_Atom, Element>; #else - using CopyInstG2S = Copy_Atom; + using CopyInstG2S = Copy_Atom; #endif - using TiledCopyG2S = decltype(make_tiled_copy( - CopyInstG2S{}, - Layout, Int>, - Stride, _1>>{}, - Layout>>{})); - - using TiledCopyS2G = decltype(make_tiled_copy( - Copy_Atom{}, - Layout, Int>, - Stride, _1>>{}, - Layout>>{})); - using StoreC_R2S = R2SCopy2D; + using TiledCopyG2S = decltype(make_tiled_copy( + CopyInstG2S{}, + Layout, Int>, + Stride, _1>>{}, + Layout>>{})); + + using TiledCopyS2G = decltype(make_tiled_copy( + Copy_Atom{}, + Layout, Int>, + Stride, _1>>{}, + Layout>>{})); + using StoreC_R2S = R2SCopy2D; }; template __global__ void gemm_kernel(const Element* dA, const Element* dB, Element* dC) { - extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; - auto* buf = reinterpret_cast(buf_); - - // Advance to the global data tile to the current CTA. - Element* gA_ptr = const_cast(dA) + blockIdx.x * kK * kTM; - Element* gB_ptr = const_cast(dB) + blockIdx.y * kK * kTN; - Element* gC_ptr = dC + blockIdx.x * kTM * kN + blockIdx.y * kTN; - - // pointers to shared memory tiles - Element* sA_ptr = buf; - Element* sB_ptr = buf + kTM * kTK; - Element* sC_ptr = buf; - - typename KeTraits::TiledMma mma; - typename KeTraits::TiledCopyG2S tiled_copy; - - auto rA = make_s2rA(sA_ptr, typename KeTraits::SmemLayoutA{}, mma); - auto rB = make_s2rB(sB_ptr, typename KeTraits::SmemLayoutB{}, mma); - auto acc = get_acc(mma); - - for (int k = 0; k < kK; k += kTK) { - copy_tile_g2s(gA_ptr, sA_ptr, typename KeTraits::GmemLayoutA{}, - typename KeTraits::SmemLayoutA{}, tiled_copy); - copy_tile_g2s(gB_ptr, sB_ptr, typename KeTraits::GmemLayoutB{}, - typename KeTraits::SmemLayoutB{}, tiled_copy); - __copy_async(); - __syncthreads(); - - for (int i = 0; i < rA.get_iters(); ++i) { - rA.copy(i); // load A register tile from shared memory - rB.copy(i); // load B register tile from shared memory - - gemm(mma, rA[i], rB[i], acc); - } - gA_ptr += kTK; - gB_ptr += kTK; - } - - // declare register to shared store plan - typename KeTraits::StoreC_R2S sC; - // store register tile to shared memory - sC.copy(acc, buf); + extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; + auto* buf = reinterpret_cast(buf_); + + // Advance to the global data tile to the current CTA. + Element* gA_ptr = const_cast(dA) + blockIdx.x * kK * kTM; + Element* gB_ptr = const_cast(dB) + blockIdx.y * kK * kTN; + Element* gC_ptr = dC + blockIdx.x * kTM * kN + blockIdx.y * kTN; + + // pointers to shared memory tiles + Element* sA_ptr = buf; + Element* sB_ptr = buf + kTM * kTK; + Element* sC_ptr = buf; + + typename KeTraits::TiledMma mma; + typename KeTraits::TiledCopyG2S tiled_copy; + + auto rA = make_s2rA(sA_ptr, typename KeTraits::SmemLayoutA{}, mma); + auto rB = make_s2rB(sB_ptr, typename KeTraits::SmemLayoutB{}, mma); + auto acc = get_acc(mma); + + for (int k = 0; k < kK; k += kTK) { + copy_tile_g2s(gA_ptr, sA_ptr, typename KeTraits::GmemLayoutA{}, + typename KeTraits::SmemLayoutA{}, tiled_copy); + copy_tile_g2s(gB_ptr, sB_ptr, typename KeTraits::GmemLayoutB{}, + typename KeTraits::SmemLayoutB{}, tiled_copy); + __copy_async(); __syncthreads(); - // store shared memory tile to global memory - copy_tile_s2g(sC_ptr, gC_ptr, typename KeTraits::SmemLayoutC{}, - typename KeTraits::GmemLayoutC{}, - typename KeTraits::TiledCopyS2G{}); + for (int i = 0; i < rA.get_iters(); ++i) { + rA.copy(i); // load A register tile from shared memory + rB.copy(i); // load B register tile from shared memory + + gemm(mma, rA[i], rB[i], acc); + } + gA_ptr += kTK; + gB_ptr += kTK; + } + + // declare register to shared store plan + typename KeTraits::StoreC_R2S sC; + // store register tile to shared memory + sC.copy(acc, buf); + __syncthreads(); + + // store shared memory tile to global memory + copy_tile_s2g(sC_ptr, gC_ptr, typename KeTraits::SmemLayoutC{}, + typename KeTraits::GmemLayoutC{}, + typename KeTraits::TiledCopyS2G{}); } } // namespace cutlass_wrapper } // namespace benchmarks diff --git a/benchmarks/cpp/gemm/tilefusion_gemm.cuh b/benchmarks/cpp/gemm/tilefusion_gemm.cuh index d4b627ef..2011c392 100644 --- a/benchmarks/cpp/gemm/tilefusion_gemm.cuh +++ b/benchmarks/cpp/gemm/tilefusion_gemm.cuh @@ -23,83 +23,82 @@ template struct KeGemmTraits { - using BaseShape = traits::BaseTileShape; + using BaseShape = traits::BaseTileShape; - static constexpr int kThreads = tl::get_numel * 32; - static constexpr int kWarpPerRow = tl::num_rows; - static constexpr int kWarpPerCol = tl::num_cols; + static constexpr int kThreads = tl::get_numel * 32; + static constexpr int kWarpPerRow = tl::num_rows; + static constexpr int kWarpPerCol = tl::num_cols; - static constexpr int kM = dim_size<0, WholeShape>; - static constexpr int kN = dim_size<1, WholeShape>; - static constexpr int kK = dim_size<2, WholeShape>; + static constexpr int kM = dim_size<0, WholeShape>; + static constexpr int kN = dim_size<1, WholeShape>; + static constexpr int kK = dim_size<2, WholeShape>; - static constexpr int kTM = dim_size<0, CtaTileShape>; - static constexpr int kTN = dim_size<1, CtaTileShape>; - static constexpr int kTK = dim_size<2, CtaTileShape>; + static constexpr int kTM = dim_size<0, CtaTileShape>; + static constexpr int kTN = dim_size<1, CtaTileShape>; + static constexpr int kTK = dim_size<2, CtaTileShape>; - static const bool kSwizzled = true; + static const bool kSwizzled = true; - // Total data access for operand A in global memory - using GlobalA = GlobalTile>; - using GIteratorA = GTileIterator>; + // Total data access for operand A in global memory + using GlobalA = GlobalTile>; + using GIteratorA = GTileIterator>; - // Shared Tile for operand A - using SharedA = - SharedTile, kSwizzled, kSharedAccess>; - using LoadSharedA = - tilefusion::cell::copy::GlobalToSharedLoader; + // Shared Tile for operand A + using SharedA = + SharedTile, kSwizzled, kSharedAccess>; + using LoadSharedA = + tilefusion::cell::copy::GlobalToSharedLoader; - // Access a single register tile for operand A - using SIteratorA = STileIterator>; + // Access a single register tile for operand A + using SIteratorA = STileIterator>; - // Register tile for a single thread of operand A - static constexpr int kAMs = kTM / kWarpPerRow / BaseShape::kRows; - static constexpr int kAKs = kRK / BaseShape::kCols; - using RegA = RegTile, tl::RowMajor>; + // Register tile for a single thread of operand A + static constexpr int kAMs = kTM / kWarpPerRow / BaseShape::kRows; + static constexpr int kAKs = kRK / BaseShape::kCols; + using RegA = RegTile, tl::RowMajor>; - using LoadRegA = - SharedToRegLoader; + using LoadRegA = + SharedToRegLoader; - // Total data access for operand B in global memory - // using GlobalB = GlobalTile>; - using GlobalB = GlobalTile>; - using GIteratorB = GTileIterator>; + // Total data access for operand B in global memory + // using GlobalB = GlobalTile>; + using GlobalB = GlobalTile>; + using GIteratorB = GTileIterator>; - // Shared Tile for operand B - using SharedB = - SharedTile, kSwizzled, kSharedAccess>; - using LoadSharedB = - tilefusion::cell::copy::GlobalToSharedLoader; + // Shared Tile for operand B + using SharedB = + SharedTile, kSwizzled, kSharedAccess>; + using LoadSharedB = + tilefusion::cell::copy::GlobalToSharedLoader; - // Access a single register tile for operand B - using SIteratorB = STileIterator>; + // Access a single register tile for operand B + using SIteratorB = STileIterator>; - static_assert(SIteratorA::sc1 == SIteratorB::sc0, - "mismatched K dimension!"); + static_assert(SIteratorA::sc1 == SIteratorB::sc0, "mismatched K dimension!"); - // Register tile for a single thread of operand A - static constexpr int kBKs = kRK / BaseShape::kRows; - static constexpr int kBNs = kTN / kWarpPerCol / BaseShape::kCols; - using RegB = RegTile, tl::ColMajor>; + // Register tile for a single thread of operand A + static constexpr int kBKs = kRK / BaseShape::kRows; + static constexpr int kBNs = kTN / kWarpPerCol / BaseShape::kCols; + using RegB = RegTile, tl::ColMajor>; - using LoadRegB = - SharedToRegLoader; + using LoadRegB = + SharedToRegLoader; - // Global Tile for output C - using GlobalC = GlobalTile>; - // Shared Tile for output C - using SharedC = SharedTile, kSwizzled>; + // Global Tile for output C + using GlobalC = GlobalTile>; + // Shared Tile for output C + using SharedC = SharedTile, kSwizzled>; - // Register Tile for output C - static constexpr int kCMs = kTM / kWarpPerRow / BaseShape::kRows; - static constexpr int kCNs = kTN / kWarpPerCol / BaseShape::kCols; - using Acc = RegTile, tl::RowMajor>; - using AccHalf = RegTile, tl::RowMajor>; + // Register Tile for output C + static constexpr int kCMs = kTM / kWarpPerRow / BaseShape::kRows; + static constexpr int kCNs = kTN / kWarpPerCol / BaseShape::kCols; + using Acc = RegTile, tl::RowMajor>; + using AccHalf = RegTile, tl::RowMajor>; - using ConvertAcc = compute::RegTileConvert; + using ConvertAcc = compute::RegTileConvert; - using StoreRegC = RegToSharedStorer; - using StoreSharedC = SharedToGlobalStorer; + using StoreRegC = RegToSharedStorer; + using StoreSharedC = SharedToGlobalStorer; }; template __global__ void gemm(const InType* dA_, const InType* dB_, InType* dC_) { - InType* dA = const_cast(dA_) + blockIdx.x * kTM * kK; - InType* dB = const_cast(dB_) + blockIdx.y * kTN * kK; - InType* dC = dC_ + blockIdx.x * kTM * kN + blockIdx.y * kTN; + InType* dA = const_cast(dA_) + blockIdx.x * kTM * kK; + InType* dB = const_cast(dB_) + blockIdx.y * kTN * kK; + InType* dC = dC_ + blockIdx.x * kTM * kN + blockIdx.y * kTN; - extern __shared__ __align__(sizeof(double)) unsigned char buf[]; - InType* sA_ptr = reinterpret_cast(buf); - InType* sB_ptr = sA_ptr + SIteratorA::Tile::kNumel; - InType* sC_ptr = reinterpret_cast(buf); + extern __shared__ __align__(sizeof(double)) unsigned char buf[]; + InType* sA_ptr = reinterpret_cast(buf); + InType* sB_ptr = sA_ptr + SIteratorA::Tile::kNumel; + InType* sC_ptr = reinterpret_cast(buf); - GIteratorA gAs(dA); - GIteratorB gBs(dB); + GIteratorA gAs(dA); + GIteratorB gBs(dB); - SharedA sA(sA_ptr); - SIteratorA sAs(sA_ptr); - RegA rA; + SharedA sA(sA_ptr); + SIteratorA sAs(sA_ptr); + RegA rA; - SharedB sB(sB_ptr); - SIteratorB sBs(sB_ptr); - RegB rB; + SharedB sB(sB_ptr); + SIteratorB sBs(sB_ptr); + RegB rB; - Acc acc; - AccHalf acc_half; - ConvertAcc convert_acc; + Acc acc; + AccHalf acc_half; + ConvertAcc convert_acc; - SharedC sC(sC_ptr); - GlobalC gC(dC); + SharedC sC(sC_ptr); + GlobalC gC(dC); - LoadSharedA load_sA; - LoadRegA load_rA; + LoadSharedA load_sA; + LoadRegA load_rA; - LoadSharedB load_sB; - LoadRegB load_rB; + LoadSharedB load_sB; + LoadRegB load_rB; - StoreRegC store_rC; - StoreSharedC store_sC; + StoreRegC store_rC; + StoreSharedC store_sC; - for (int k1 = 0; k1 < GIteratorA::sc1; ++k1) { - load_sA(gAs(k1), sA); - load_sB(gBs(k1), sB); - __copy_async(); - __syncthreads(); + for (int k1 = 0; k1 < GIteratorA::sc1; ++k1) { + load_sA(gAs(k1), sA); + load_sB(gBs(k1), sB); + __copy_async(); + __syncthreads(); - for (int k2 = 0; k2 < SIteratorA::sc1; ++k2) { - load_rA(sAs(k2), rA); - load_rB(sBs(k2), rB); + for (int k2 = 0; k2 < SIteratorA::sc1; ++k2) { + load_rA(sAs(k2), rA); + load_rB(sBs(k2), rB); - gemm(rA, rB, acc); - } + gemm(rA, rB, acc); } + } - convert_acc(acc, acc_half); + convert_acc(acc, acc_half); - store_rC(acc_half, sC); - __syncthreads(); - store_sC(sC, gC); + store_rC(acc_half, sC); + __syncthreads(); + store_sC(sC, gC); } diff --git a/benchmarks/cpp/gemm/util.cuh b/benchmarks/cpp/gemm/util.cuh index 56e7474a..78f815e3 100644 --- a/benchmarks/cpp/gemm/util.cuh +++ b/benchmarks/cpp/gemm/util.cuh @@ -16,56 +16,56 @@ using namespace benchmarks; using namespace tilefusion; float rand_float(float a = 1e-4, float b = 5e-3) { - float random = ((float)rand()) / (float)RAND_MAX; - float diff = b - a; - float r = random * diff; - return a + r; + float random = ((float)rand()) / (float)RAND_MAX; + float diff = b - a; + float r = random * diff; + return a + r; } namespace { bool check_results_impl(const __half* values1, const __half* values2, int numel) { - bool passed = true; - const float epsilon = 1e-3; + bool passed = true; + const float epsilon = 1e-3; - double total_diff = 0.; - double max_abs_diff = FLT_MIN; - double diff = 0.; + double total_diff = 0.; + double max_abs_diff = FLT_MIN; + double diff = 0.; #ifdef DEBUG - int cut_off = 128; - printf("\nground truth:\n"); - for (int i = 0; i < cut_off; ++i) { - printf("%.5f, ", __half2float(values1[i])); - if (i && (i + 1) % 16 == 0) printf("\n"); - } - printf("\ncomputed values:\n"); - for (int i = 0; i < cut_off; ++i) { - printf("%.5f, ", __half2float(values2[i])); - if (i && (i + 1) % 16 == 0) printf("\n"); - } + int cut_off = 128; + printf("\nground truth:\n"); + for (int i = 0; i < cut_off; ++i) { + printf("%.5f, ", __half2float(values1[i])); + if (i && (i + 1) % 16 == 0) printf("\n"); + } + printf("\ncomputed values:\n"); + for (int i = 0; i < cut_off; ++i) { + printf("%.5f, ", __half2float(values2[i])); + if (i && (i + 1) % 16 == 0) printf("\n"); + } #endif - for (int i = 0; i < numel; ++i) { - float v1 = __half2float(values1[i]); - float v2 = __half2float(values2[i]); + for (int i = 0; i < numel; ++i) { + float v1 = __half2float(values1[i]); + float v2 = __half2float(values2[i]); - diff = fabs(v1 - v2); - max_abs_diff = max_abs_diff < diff ? diff : max_abs_diff; - total_diff += diff; + diff = fabs(v1 - v2); + max_abs_diff = max_abs_diff < diff ? diff : max_abs_diff; + total_diff += diff; #ifdef DEBUG - if (diff > epsilon) { - printf("the %d-th value differs (%.4f): %.4f vs. %.4f\n", i, diff, - v1, v2); - } -#endif + if (diff > epsilon) { + printf("the %d-th value differs (%.4f): %.4f vs. %.4f\n", i, diff, v1, + v2); } +#endif + } - double avg_diff = total_diff / numel; - if (avg_diff > epsilon) passed = false; + double avg_diff = total_diff / numel; + if (avg_diff > epsilon) passed = false; - return passed; + return passed; } } // namespace @@ -76,61 +76,61 @@ bool check_results(const T* values1_, const cutlass::half_t* values2, template <> bool check_results(const cutlass::half_t* values1_, const cutlass::half_t* values2_, int numel) { - const __half* values1 = reinterpret_cast(values1_); - const __half* values2 = reinterpret_cast(values2_); - return check_results_impl(values1, values2, numel); + const __half* values1 = reinterpret_cast(values1_); + const __half* values2 = reinterpret_cast(values2_); + return check_results_impl(values1, values2, numel); } template <> bool check_results(const __half* values1, const cutlass::half_t* values2_, int numel) { - const __half* values2 = reinterpret_cast(values2_); - return check_results_impl(values1, values2, numel); + const __half* values2 = reinterpret_cast(values2_); + return check_results_impl(values1, values2, numel); } template <> bool check_results(const float* values1, const cutlass::half_t* values2_, int numel) { - const __half* values2 = reinterpret_cast(values2_); - __half* hvalues1 = (__half*)malloc(numel * sizeof(__half)); - for (int i = 0; i < numel; ++i) { - hvalues1[i] = __float2half(values1[i]); - } - return check_results_impl(hvalues1, values2, numel); + const __half* values2 = reinterpret_cast(values2_); + __half* hvalues1 = (__half*)malloc(numel * sizeof(__half)); + for (int i = 0; i < numel; ++i) { + hvalues1[i] = __float2half(values1[i]); + } + return check_results_impl(hvalues1, values2, numel); } float cublas_hgemm(int64_t kM, int64_t kN, int64_t kK, const __half* A, const __half* B, __half* C, bool timeit = false, int warm_up = 5, int iters = 20) { - cublasHandle_t handle; - cublasCreate(&handle); - - __half alf = static_cast<__half>(1.); - __half bet = static_cast<__half>(0.); - - float elapsed = 0.; - - if (timeit) { - for (int i = 0; i < warm_up; ++i) { - cublasHgemm(handle, CUBLAS_OP_T /* transb*/, CUBLAS_OP_N, kN, kM, - kK, &alf, B, kK, A, kK, &bet, C, kN); - } - cudaDeviceSynchronize(); - - CudaTimer timer; - timer.start(); - for (int i = 0; i < iters; ++i) { - cublasHgemm(handle, CUBLAS_OP_T /* transb*/, CUBLAS_OP_N, kN, kM, - kK, &alf, B, kK, A, kK, &bet, C, kN); - } - cudaDeviceSynchronize(); - elapsed = timer.stop() / iters; - } else { - cublasHgemm(handle, CUBLAS_OP_T /* transb*/, CUBLAS_OP_N, kN, kM, kK, - &alf, B, kK, A, kK, &bet, C, kN); + cublasHandle_t handle; + cublasCreate(&handle); + + __half alf = static_cast<__half>(1.); + __half bet = static_cast<__half>(0.); + + float elapsed = 0.; + + if (timeit) { + for (int i = 0; i < warm_up; ++i) { + cublasHgemm(handle, CUBLAS_OP_T /* transb*/, CUBLAS_OP_N, kN, kM, kK, + &alf, B, kK, A, kK, &bet, C, kN); } cudaDeviceSynchronize(); - cublasDestroy(handle); - return elapsed; + CudaTimer timer; + timer.start(); + for (int i = 0; i < iters; ++i) { + cublasHgemm(handle, CUBLAS_OP_T /* transb*/, CUBLAS_OP_N, kN, kM, kK, + &alf, B, kK, A, kK, &bet, C, kN); + } + cudaDeviceSynchronize(); + elapsed = timer.stop() / iters; + } else { + cublasHgemm(handle, CUBLAS_OP_T /* transb*/, CUBLAS_OP_N, kN, kM, kK, &alf, + B, kK, A, kK, &bet, C, kN); + } + cudaDeviceSynchronize(); + + cublasDestroy(handle); + return elapsed; } diff --git a/benchmarks/utils/cpp/cuda_utils.cuh b/benchmarks/utils/cpp/cuda_utils.cuh index 81800a18..f2dc51bd 100644 --- a/benchmarks/utils/cpp/cuda_utils.cuh +++ b/benchmarks/utils/cpp/cuda_utils.cuh @@ -15,57 +15,57 @@ template inline constexpr int CeilDiv = (a + b - 1) / b; // for compile-time values #if defined(__CUDA_ARCH__) - #define HOST_DEVICE __forceinline__ __host__ __device__ - #define DEVICE __forceinline__ __device__ - #define HOST __forceinline__ __host__ + #define HOST_DEVICE __forceinline__ __host__ __device__ + #define DEVICE __forceinline__ __device__ + #define HOST __forceinline__ __host__ #else - #define HOST_DEVICE inline - #define DEVICE inline - #define HOST inline + #define HOST_DEVICE inline + #define DEVICE inline + #define HOST inline #endif const char* cublasGetErrorString(cublasStatus_t status) { - switch (status) { - case CUBLAS_STATUS_SUCCESS: - return "CUBLAS_STATUS_SUCCESS"; - case CUBLAS_STATUS_NOT_INITIALIZED: - return "CUBLAS_STATUS_NOT_INITIALIZED"; - case CUBLAS_STATUS_ALLOC_FAILED: - return "CUBLAS_STATUS_ALLOC_FAILED"; - case CUBLAS_STATUS_INVALID_VALUE: - return "CUBLAS_STATUS_INVALID_VALUE"; - case CUBLAS_STATUS_ARCH_MISMATCH: - return "CUBLAS_STATUS_ARCH_MISMATCH"; - case CUBLAS_STATUS_MAPPING_ERROR: - return "CUBLAS_STATUS_MAPPING_ERROR"; - case CUBLAS_STATUS_EXECUTION_FAILED: - return "CUBLAS_STATUS_EXECUTION_FAILED"; - case CUBLAS_STATUS_INTERNAL_ERROR: - return "CUBLAS_STATUS_INTERNAL_ERROR"; - case CUBLAS_STATUS_NOT_SUPPORTED: - return "CUBLAS_STATUS_NOT_SUPPORTED"; - case CUBLAS_STATUS_LICENSE_ERROR: - return "CUBLAS_STATUS_LICENSE_ERROR"; - } - return "unknown error"; + switch (status) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED"; + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR"; + } + return "unknown error"; } inline void __cublasCheck(const cublasStatus_t err, const char* file, int line) { - if (err != CUBLAS_STATUS_SUCCESS) { - fprintf(stderr, "%s(%d): Cublas error: %s.\n", file, line, - cublasGetErrorString(err)); - exit(EXIT_FAILURE); - } + if (err != CUBLAS_STATUS_SUCCESS) { + fprintf(stderr, "%s(%d): Cublas error: %s.\n", file, line, + cublasGetErrorString(err)); + exit(EXIT_FAILURE); + } } #define CublasCheck(call) __cublasCheck(call, __FILE__, __LINE__) inline void __cudaCheck(const cudaError err, const char* file, int line) { - if (err != cudaSuccess) { - fprintf(stderr, "%s(%d): CUDA error: %s.\n", file, line, - cudaGetErrorString(err)); - exit(EXIT_FAILURE); - } + if (err != cudaSuccess) { + fprintf(stderr, "%s(%d): CUDA error: %s.\n", file, line, + cudaGetErrorString(err)); + exit(EXIT_FAILURE); + } } #define CudaCheck(call) __cudaCheck(call, __FILE__, __LINE__) diff --git a/benchmarks/utils/cpp/cutlass/convert.cuh b/benchmarks/utils/cpp/cutlass/convert.cuh index ab833dca..b0ee3520 100644 --- a/benchmarks/utils/cpp/cutlass/convert.cuh +++ b/benchmarks/utils/cpp/cutlass/convert.cuh @@ -13,25 +13,25 @@ using namespace cute; namespace { template DEVICE auto convert_type(cute::Tensor const& tensor) { - using From_type = typename Engine::value_type; - constexpr int numel = decltype(size(tensor))::value; - cutlass::NumericArrayConverter convert_op; - // HACK: this requires tensor to be "contiguous" - auto frag = - convert_op(*reinterpret_cast*>( - tensor.data())); + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = + convert_op(*reinterpret_cast*>( + tensor.data())); - return make_tensor(make_rmem_ptr(&frag), tensor.layout()); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); } template struct IndexedTensor_ { - DEVICE IndexedTensor_(Tensor& tensor) : tensor_(tensor) {} + DEVICE IndexedTensor_(Tensor& tensor) : tensor_(tensor) {} - DEVICE const auto operator[](int idx) { return tensor_(_, _, idx); } + DEVICE const auto operator[](int idx) { return tensor_(_, _, idx); } - private: - Tensor& tensor_; + private: + Tensor& tensor_; }; } // namespace @@ -40,28 +40,28 @@ struct IndexedTensor_ { // using m16n8k8. template DEVICE auto convert_layout(const Tensor& acc) { - auto acc_layout = acc.layout(); + auto acc_layout = acc.layout(); - using X = Underscore; - static_assert(decltype(size<0>(acc_layout))::value == 4); - static_assert(decltype(cute::rank(acc_layout))::value == 3); + using X = Underscore; + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(cute::rank(acc_layout))::value == 3); - constexpr int mma_shape_K = cute::get<2>(typename MMA::Shape_MNK{}); - static_assert(mma_shape_K == 8 || mma_shape_K == 16); + constexpr int mma_shape_K = cute::get<2>(typename MMA::Shape_MNK{}); + static_assert(mma_shape_K == 8 || mma_shape_K == 16); - if constexpr (mma_shape_K == 8) { - IndexedTensor_ indexed_tensor(acc); - return indexed_tensor; - } else { - // (4, MMA_M, (2, MMA_N / 2))) - auto l = cute::logical_divide(acc_layout, Shape{}); - auto new_layout = make_layout(make_layout(get<0>(l), get<2, 0>(l)), - get<1>(l), get<2, 1>(l)); - auto new_tensor = make_tensor(acc.data(), new_layout); + if constexpr (mma_shape_K == 8) { + IndexedTensor_ indexed_tensor(acc); + return indexed_tensor; + } else { + // (4, MMA_M, (2, MMA_N / 2))) + auto l = cute::logical_divide(acc_layout, Shape{}); + auto new_layout = make_layout(make_layout(get<0>(l), get<2, 0>(l)), + get<1>(l), get<2, 1>(l)); + auto new_tensor = make_tensor(acc.data(), new_layout); - IndexedTensor_ indexed_tensor(new_tensor); - return indexed_tensor; - } + IndexedTensor_ indexed_tensor(new_tensor); + return indexed_tensor; + } }; } // namespace cutlass_wrapper } // namespace benchmarks diff --git a/benchmarks/utils/cpp/cutlass/copy.cuh b/benchmarks/utils/cpp/cutlass/copy.cuh index 24c96a10..9987b938 100644 --- a/benchmarks/utils/cpp/cutlass/copy.cuh +++ b/benchmarks/utils/cpp/cutlass/copy.cuh @@ -11,31 +11,31 @@ namespace cutlass_wrapper { using namespace cute; #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) - #define CP_ASYNC_SM80_ENABLED + #define CP_ASYNC_SM80_ENABLED #endif template DEVICE void wait_group() { #if defined(CP_ASYNC_SM80_ENABLED) - asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); #endif } DEVICE void commit_copy_group() { #if defined(CP_ASYNC_SM80_ENABLED) - cute::cp_async_fence(); + cute::cp_async_fence(); #endif } DEVICE void __copy_async() { - commit_copy_group(); - wait_group<0>(); + commit_copy_group(); + wait_group<0>(); } template DEVICE void cp_async_wait_flash() { #if defined(CP_ASYNC_SM80_ENABLED) - asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); #endif } @@ -45,21 +45,21 @@ template (src)); ++i) + for (int i = 0; i < int(size<1>(src)); ++i) #pragma unroll - for (int j = 0; j < int(size<2>(src)); ++j) - cute::copy(tiled_copy, src(_, i, j), dst(_, i, j)); + for (int j = 0; j < int(size<2>(src)); ++j) + cute::copy(tiled_copy, src(_, i, j), dst(_, i, j)); } // Copy a tensor from shared memory to global memory @@ -68,115 +68,115 @@ template (src)); ++i) + for (int i = 0; i < int(size<1>(src)); ++i) #pragma unroll - for (int j = 0; j < int(size<2>(src)); ++j) - cute::copy(tiled_copy, src(_, i, j), dst(_, i, j)); + for (int j = 0; j < int(size<2>(src)); ++j) + cute::copy(tiled_copy, src(_, i, j), dst(_, i, j)); } template struct R2SCopy2D { - using TiledMma = TiledMma_; - using Dstlayout_ = DstLayout; - using CopyAtom = Copy_Atom; - - public: - template - __device__ void copy(cute::Tensor const& acc, - Element* dst_data) { - int tid = threadIdx.x; - - // FIXME(haruhi): This implementation is specifically designed - // for tcu WMMA and assumes that the ACC value has a - // floating-point precision. The code converts the ACC value - // to half-precision. - auto src_tensor = convert_type(acc); - auto dst_tensor = make_tensor(make_smem_ptr(dst_data), DstLayout{}); - - auto tiled_copy = make_tiled_copy_C(CopyAtom{}, TiledMma{}); - auto thrd_copy = tiled_copy.get_thread_slice(tid); - - auto src = thrd_copy.retile_S(src_tensor); - auto dst = thrd_copy.partition_D(dst_tensor); - cute::copy(tiled_copy, src, dst); - } - - private: - template - DEVICE auto convert_type(cute::Tensor const& tensor) { - using From_type = typename Engine::value_type; - constexpr int numel = decltype(size(tensor))::value; - cutlass::NumericArrayConverter convert_op; - // HACK: this requires tensor to be "contiguous" - auto frag = convert_op( - *reinterpret_cast*>( - tensor.data())); - return make_tensor(make_rmem_ptr(&frag), tensor.layout()); - } + using TiledMma = TiledMma_; + using Dstlayout_ = DstLayout; + using CopyAtom = Copy_Atom; + + public: + template + __device__ void copy(cute::Tensor const& acc, + Element* dst_data) { + int tid = threadIdx.x; + + // FIXME(haruhi): This implementation is specifically designed + // for tcu WMMA and assumes that the ACC value has a + // floating-point precision. The code converts the ACC value + // to half-precision. + auto src_tensor = convert_type(acc); + auto dst_tensor = make_tensor(make_smem_ptr(dst_data), DstLayout{}); + + auto tiled_copy = make_tiled_copy_C(CopyAtom{}, TiledMma{}); + auto thrd_copy = tiled_copy.get_thread_slice(tid); + + auto src = thrd_copy.retile_S(src_tensor); + auto dst = thrd_copy.partition_D(dst_tensor); + cute::copy(tiled_copy, src, dst); + } + + private: + template + DEVICE auto convert_type(cute::Tensor const& tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = + convert_op(*reinterpret_cast*>( + tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); + } }; template struct Shm2RegLoad { - public: - DEVICE Shm2RegLoad(TiledCopy& copy, const STensor& src, DTensor& dst, - DTensorView& dst_view) - : tiled_copy_(copy), src_(src), dst_(dst), dst_view_(dst_view) {} + public: + DEVICE Shm2RegLoad(TiledCopy& copy, const STensor& src, DTensor& dst, + DTensorView& dst_view) + : tiled_copy_(copy), src_(src), dst_(dst), dst_view_(dst_view) {} - DEVICE void copy(int pos) { - cute::copy(tiled_copy_, src_(_, _, pos), dst_view_(_, _, pos)); - } + DEVICE void copy(int pos) { + cute::copy(tiled_copy_, src_(_, _, pos), dst_view_(_, _, pos)); + } - DEVICE int get_iters() { return size<2>(dst_); } + DEVICE int get_iters() { return size<2>(dst_); } - DEVICE const auto operator[](int idx) { return dst_(_, _, idx); } + DEVICE const auto operator[](int idx) { return dst_(_, _, idx); } - private: - TiledCopy& tiled_copy_; - const STensor& src_; - DTensor& dst_; - DTensorView& dst_view_; + private: + TiledCopy& tiled_copy_; + const STensor& src_; + DTensor& dst_; + DTensorView& dst_view_; }; template DEVICE auto get_acc(const TiledMma& tiled_mma) { - auto acc = partition_fragment_C(tiled_mma, Shape, Int>{}); - clear(acc); + auto acc = partition_fragment_C(tiled_mma, Shape, Int>{}); + clear(acc); - return acc; + return acc; } template DEVICE auto make_s2rA(const Element* data, const Layout& layout, const TiledMma& tiled_mma) { - int tid = threadIdx.x; + int tid = threadIdx.x; - auto tensor = cute::make_tensor(make_smem_ptr(data), layout); + auto tensor = cute::make_tensor(make_smem_ptr(data), layout); - using SmemLoadAtom = Copy_Atom; - auto tiled_copy = make_tiled_copy_A(SmemLoadAtom{}, tiled_mma); + using SmemLoadAtom = Copy_Atom; + auto tiled_copy = make_tiled_copy_A(SmemLoadAtom{}, tiled_mma); - auto thrd_copy = tiled_copy.get_thread_slice(tid); - auto src = thrd_copy.partition_S(tensor); + auto thrd_copy = tiled_copy.get_thread_slice(tid); + auto src = thrd_copy.partition_S(tensor); - // partition register - auto thr_mma = tiled_mma.get_thread_slice(tid); - auto dst = thr_mma.partition_fragment_A(tensor); - auto dst_view = thrd_copy.retile_D(dst); + // partition register + auto thr_mma = tiled_mma.get_thread_slice(tid); + auto dst = thr_mma.partition_fragment_A(tensor); + auto dst_view = thrd_copy.retile_D(dst); - Shm2RegLoad loader(tiled_copy, src, dst, dst_view); - return loader; + Shm2RegLoad loader(tiled_copy, src, dst, dst_view); + return loader; } // FIXIME(haruhi): the current implementation is for fast experiment, @@ -184,22 +184,22 @@ DEVICE auto make_s2rA(const Element* data, const Layout& layout, template DEVICE auto make_s2rB(const Element* data, const Layout& layout, const TiledMma& tiled_mma) { - int tid = threadIdx.x; + int tid = threadIdx.x; - using SmemLoadAtom = Copy_Atom; - auto tiled_copy = make_tiled_copy_B(SmemLoadAtom{}, tiled_mma); - auto thrd_copy = tiled_copy.get_thread_slice(tid); + using SmemLoadAtom = Copy_Atom; + auto tiled_copy = make_tiled_copy_B(SmemLoadAtom{}, tiled_mma); + auto thrd_copy = tiled_copy.get_thread_slice(tid); - auto tensor = make_tensor(make_smem_ptr(data), layout); - auto src = thrd_copy.partition_S(tensor); + auto tensor = make_tensor(make_smem_ptr(data), layout); + auto src = thrd_copy.partition_S(tensor); - // partition register - auto thr_mma = tiled_mma.get_thread_slice(tid); - auto dst = thr_mma.partition_fragment_B(tensor); - auto dst_view = thrd_copy.retile_D(dst); + // partition register + auto thr_mma = tiled_mma.get_thread_slice(tid); + auto dst = thr_mma.partition_fragment_B(tensor); + auto dst_view = thrd_copy.retile_D(dst); - Shm2RegLoad loader(tiled_copy, src, dst, dst_view); - return loader; + Shm2RegLoad loader(tiled_copy, src, dst, dst_view); + return loader; } } // namespace cutlass_wrapper diff --git a/benchmarks/utils/cpp/cutlass/traits_base.cuh b/benchmarks/utils/cpp/cutlass/traits_base.cuh index ce3e4f2d..140487c4 100644 --- a/benchmarks/utils/cpp/cutlass/traits_base.cuh +++ b/benchmarks/utils/cpp/cutlass/traits_base.cuh @@ -8,10 +8,10 @@ namespace cutlass_wrapper { template struct AccessBase { - // the maximal width of vectorized access. - static constexpr int kAccessInBits = 128; - static constexpr int kElmentBits = cutlass::sizeof_bits::value; - static constexpr int kNumPerAccess = kAccessInBits / kElmentBits; + // the maximal width of vectorized access. + static constexpr int kAccessInBits = 128; + static constexpr int kElmentBits = cutlass::sizeof_bits::value; + static constexpr int kNumPerAccess = kAccessInBits / kElmentBits; }; } // namespace cutlass_wrapper diff --git a/examples/101_gemm/01_gemm_global_reg/gemm.hpp b/examples/101_gemm/01_gemm_global_reg/gemm.hpp index 02054ea9..a519e855 100644 --- a/examples/101_gemm/01_gemm_global_reg/gemm.hpp +++ b/examples/101_gemm/01_gemm_global_reg/gemm.hpp @@ -16,47 +16,47 @@ using GemmShape = TileShape; template struct GemmTraits { - using BaseShape = traits::BaseTileShape; - static constexpr int kChunkK = 64; + using BaseShape = traits::BaseTileShape; + static constexpr int kChunkK = 64; - static constexpr int kThreads = tl::get_numel * 32; - static constexpr int kWarpPerRow = tl::num_rows; - static constexpr int kWarpPerCol = tl::num_cols; + static constexpr int kThreads = tl::get_numel * 32; + static constexpr int kWarpPerRow = tl::num_rows; + static constexpr int kWarpPerCol = tl::num_cols; - static constexpr int kM = dim_size<0, WholeShape>; - static constexpr int kN = dim_size<1, WholeShape>; - static constexpr int kK = dim_size<2, WholeShape>; + static constexpr int kM = dim_size<0, WholeShape>; + static constexpr int kN = dim_size<1, WholeShape>; + static constexpr int kK = dim_size<2, WholeShape>; - // operand A - using GlobalA = GlobalTile>; - using IteratorA = GTileIterator>; + // operand A + using GlobalA = GlobalTile>; + using IteratorA = GTileIterator>; - static constexpr int kAMs = kTM / kWarpPerRow / BaseShape::kRows; - static constexpr int kAKs = kChunkK / BaseShape::kCols; - using RegA = RegTile, tl::RowMajor>; + static constexpr int kAMs = kTM / kWarpPerRow / BaseShape::kRows; + static constexpr int kAKs = kChunkK / BaseShape::kCols; + using RegA = RegTile, tl::RowMajor>; - using ALoader = copy::GlobalToRegLoader; + using ALoader = + copy::GlobalToRegLoader; - // operand B - using GlobalB = GlobalTile>; - using IteratorB = GTileIterator>; + // operand B + using GlobalB = GlobalTile>; + using IteratorB = GTileIterator>; - static constexpr int kBKs = kChunkK / BaseShape::kRows; - static constexpr int kBNs = kTN / kWarpPerCol / BaseShape::kCols; - using RegB = RegTile, tl::ColMajor>; + static constexpr int kBKs = kChunkK / BaseShape::kRows; + static constexpr int kBNs = kTN / kWarpPerCol / BaseShape::kCols; + using RegB = RegTile, tl::ColMajor>; - using BLoader = copy::GlobalToRegLoader; + using BLoader = + copy::GlobalToRegLoader; - // output C - using GlobalC = GlobalTile>; + // output C + using GlobalC = GlobalTile>; - static constexpr int kCMs = kTM / kWarpPerRow / BaseShape::kRows; - static constexpr int kCNs = kTN / kWarpPerCol / BaseShape::kCols; - using RegC = RegTile, tl::RowMajor>; + static constexpr int kCMs = kTM / kWarpPerRow / BaseShape::kRows; + static constexpr int kCNs = kTN / kWarpPerCol / BaseShape::kCols; + using RegC = RegTile, tl::RowMajor>; - using CStorer = copy::RegToGlobalStorer; + using CStorer = copy::RegToGlobalStorer; }; template __global__ void simple_gemm(const InType* dA, const InType* dB, AccType* dC) { - int offset_a = blockIdx.x * kTM * kK; - int offset_b = blockIdx.y * kTN * kK; - int offset_c = blockIdx.x * kTM * kN + blockIdx.y * kTN; + int offset_a = blockIdx.x * kTM * kK; + int offset_b = blockIdx.y * kTN * kK; + int offset_c = blockIdx.x * kTM * kN + blockIdx.y * kTN; - IteratorA gAs(dA + offset_a); - RegA rA; - ALoader loader_a; + IteratorA gAs(dA + offset_a); + RegA rA; + ALoader loader_a; - IteratorB gBs(dB + offset_b); - RegB rB; - BLoader loader_b; + IteratorB gBs(dB + offset_b); + RegB rB; + BLoader loader_b; - RegC acc; - GlobalC gC(dC + offset_c); - CStorer storer_c; + RegC acc; + GlobalC gC(dC + offset_c); + CStorer storer_c; - for (int k = 0; k < IteratorA::sc1; ++k) { - loader_a(gAs(k), rA); - loader_b(gBs(k), rB); - __syncthreads(); - - compute::gemm(rA, rB, acc); - } + for (int k = 0; k < IteratorA::sc1; ++k) { + loader_a(gAs(k), rA); + loader_b(gBs(k), rB); __syncthreads(); - storer_c(acc, gC); + compute::gemm(rA, rB, acc); + } + __syncthreads(); + + storer_c(acc, gC); } diff --git a/examples/101_gemm/01_gemm_global_reg/main.cu b/examples/101_gemm/01_gemm_global_reg/main.cu index 8eac50d8..41b58d8d 100644 --- a/examples/101_gemm/01_gemm_global_reg/main.cu +++ b/examples/101_gemm/01_gemm_global_reg/main.cu @@ -6,107 +6,106 @@ template int run_test() { - using InType = __half; - using AccType = float; - - static constexpr int kM = dim_size<0, WholeShape>; - static constexpr int kN = dim_size<1, WholeShape>; - static constexpr int kK = dim_size<2, WholeShape>; - - static constexpr int kTM = dim_size<0, CtaTileShape>; - static constexpr int kTN = dim_size<1, CtaTileShape>; - - thrust::host_vector h_a(kM * kK); - for (int i = 0; i < h_a.size(); ++i) - h_a[i] = static_cast(rand_float()); - - thrust::host_vector h_b(kK * kN); - for (int i = 0; i < h_b.size(); ++i) - h_b[i] = static_cast(rand_float()); - - thrust::host_vector h_c(kM * kN); - thrust::fill(h_c.begin(), h_c.end(), 0.); - - thrust::device_vector d_a = h_a; - thrust::device_vector d_b = h_b; - thrust::device_vector d_c = h_c; - - const InType* A = thrust::raw_pointer_cast(d_a.data()); - const InType* B = thrust::raw_pointer_cast(d_b.data()); - AccType* C = thrust::raw_pointer_cast(d_c.data()); - - using Config = - GemmTraits; - - using RegA = typename Config::RegA; - using RegB = typename Config::RegB; - using RegC = typename Config::RegC; - - using IteratorA = typename Config::IteratorA; - using IteratorB = typename Config::IteratorB; - - int block_x = CeilDiv; - int block_y = CeilDiv; - - std::cout << "kThreads: " << Config::kThreads << std::endl - << "RegA: " << RegA{} << std::endl - << "RegB: " << RegB{} << std::endl - << "RegC: " << RegC{} << std::endl - << "IteratorA: " << IteratorA{} << std::endl - << "IteratorB: " << IteratorB{} << std::endl - << "blocks: [" << block_x << ", " << block_y << "]" << std::endl - << std::endl; - - dim3 dim_grid(block_x, block_y, 1); - dim3 dim_block(Config::kThreads, 1, 1); - simple_gemm - <<>>(A, B, C); + using InType = __half; + using AccType = float; + + static constexpr int kM = dim_size<0, WholeShape>; + static constexpr int kN = dim_size<1, WholeShape>; + static constexpr int kK = dim_size<2, WholeShape>; + + static constexpr int kTM = dim_size<0, CtaTileShape>; + static constexpr int kTN = dim_size<1, CtaTileShape>; + + thrust::host_vector h_a(kM * kK); + for (int i = 0; i < h_a.size(); ++i) + h_a[i] = static_cast(rand_float()); + + thrust::host_vector h_b(kK * kN); + for (int i = 0; i < h_b.size(); ++i) + h_b[i] = static_cast(rand_float()); + + thrust::host_vector h_c(kM * kN); + thrust::fill(h_c.begin(), h_c.end(), 0.); + + thrust::device_vector d_a = h_a; + thrust::device_vector d_b = h_b; + thrust::device_vector d_c = h_c; + + const InType* A = thrust::raw_pointer_cast(d_a.data()); + const InType* B = thrust::raw_pointer_cast(d_b.data()); + AccType* C = thrust::raw_pointer_cast(d_c.data()); + + using Config = GemmTraits; + + using RegA = typename Config::RegA; + using RegB = typename Config::RegB; + using RegC = typename Config::RegC; + + using IteratorA = typename Config::IteratorA; + using IteratorB = typename Config::IteratorB; + + int block_x = CeilDiv; + int block_y = CeilDiv; + + std::cout << "kThreads: " << Config::kThreads << std::endl + << "RegA: " << RegA{} << std::endl + << "RegB: " << RegB{} << std::endl + << "RegC: " << RegC{} << std::endl + << "IteratorA: " << IteratorA{} << std::endl + << "IteratorB: " << IteratorB{} << std::endl + << "blocks: [" << block_x << ", " << block_y << "]" << std::endl + << std::endl; + + dim3 dim_grid(block_x, block_y, 1); + dim3 dim_block(Config::kThreads, 1, 1); + simple_gemm + <<>>(A, B, C); + cudaDeviceSynchronize(); + h_c = d_c; + + // check correctness + thrust::device_vector d_c2(kM * kN); + thrust::fill(d_c2.begin(), d_c2.end(), 0.); + + cublas_hgemm(kM, kN, kK, thrust::raw_pointer_cast(d_a.data()), + thrust::raw_pointer_cast(d_b.data()), + thrust::raw_pointer_cast(d_c2.data()), false /*timeit*/); + thrust::host_vector h_c2 = d_c2; + + bool passed = check_results(thrust::raw_pointer_cast(h_c.data()), + thrust::raw_pointer_cast(h_c2.data()), kM * kN); + + if (passed) { + std::cout << "Test passed." << std::endl; + + CudaTimer timer; + timer.start(); + int iters = 20; + for (int i = 0; i < iters; ++i) { + simple_gemm + <<>>(A, B, C); + } cudaDeviceSynchronize(); - h_c = d_c; - - // check correctness - thrust::device_vector d_c2(kM * kN); - thrust::fill(d_c2.begin(), d_c2.end(), 0.); - - cublas_hgemm(kM, kN, kK, thrust::raw_pointer_cast(d_a.data()), - thrust::raw_pointer_cast(d_b.data()), - thrust::raw_pointer_cast(d_c2.data()), false /*timeit*/); - thrust::host_vector h_c2 = d_c2; - - bool passed = check_results(thrust::raw_pointer_cast(h_c.data()), - thrust::raw_pointer_cast(h_c2.data()), kM * kN); - - if (passed) { - std::cout << "Test passed." << std::endl; - - CudaTimer timer; - timer.start(); - int iters = 20; - for (int i = 0; i < iters; ++i) { - simple_gemm - <<>>(A, B, C); - } - cudaDeviceSynchronize(); - - float time = timer.stop(); - std::cout << std::setprecision(4) << "elapsed time: " << time / iters - << " ms" << std::endl; - - } else - std::cerr << "Test failed." << std::endl; - - return 0; + + float time = timer.stop(); + std::cout << std::setprecision(4) << "elapsed time: " << time / iters + << " ms" << std::endl; + + } else + std::cerr << "Test failed." << std::endl; + + return 0; } int main(int argc, char* argv[]) { - run_test, GemmShape<256, 128, 64>, - tl::RowMajor<2, 2>>(); + run_test, GemmShape<256, 128, 64>, + tl::RowMajor<2, 2>>(); - return 0; + return 0; } diff --git a/examples/101_gemm/02_gemm_all_mem/gemm.hpp b/examples/101_gemm/02_gemm_all_mem/gemm.hpp index a5e36cf4..0249537d 100644 --- a/examples/101_gemm/02_gemm_all_mem/gemm.hpp +++ b/examples/101_gemm/02_gemm_all_mem/gemm.hpp @@ -18,78 +18,76 @@ using GemmShape = TileShape; template struct KeGemmTraits { - using BaseShape = traits::BaseTileShape; - - static constexpr int kThreads = tl::get_numel * 32; - static constexpr int kWarpPerRow = tl::num_rows; - static constexpr int kWarpPerCol = tl::num_cols; - - static constexpr int kM = dim_size<0, WholeShape>; - static constexpr int kN = dim_size<1, WholeShape>; - static constexpr int kK = dim_size<2, WholeShape>; - - static constexpr int kTM = dim_size<0, CtaTileShape>; - static constexpr int kTN = dim_size<1, CtaTileShape>; - static constexpr int kTK = dim_size<2, CtaTileShape>; - - static const bool kSwizzled = true; - - // Total data access for operand A in global memory - using GlobalA = GlobalTile>; - // Access a single global tile for operand A - using GIteratorA = GTileIterator>; - - // Shared Tile for operand A - using SharedA = SharedTile, kSwizzled>; - // Access a single register tile for operand A - using SIteratorA = STileIterator>; - - // Register tile for a single thread of operand A - static constexpr int kAMs = kTM / kWarpPerRow / BaseShape::kRows; - static constexpr int kAKs = kRK / BaseShape::kCols; - using RegA = RegTile, tl::RowMajor>; - - // Loaders for operand A - using G2SLoaderA = GlobalToSharedLoader; - using S2RLoaderA = - SharedToRegLoader; - - // Total data access for operand B in global memory - using GlobalB = GlobalTile>; - // Access a single global tile for operand B - using GIteratorB = GTileIterator>; - - // Shared Tile for operand B - using SharedB = SharedTile, kSwizzled>; - // Access a single register tile for operand B - using SIteratorB = STileIterator>; - - static_assert(GIteratorA::sc1 == GIteratorB::sc0, - "mismatched K dimension!"); - static_assert(SIteratorA::sc1 == SIteratorB::sc0, - "mismatched K dimension!"); - - // Register tile for a single thread of operand A - static constexpr int kBKs = kRK / BaseShape::kRows; - static constexpr int kBNs = kTN / kWarpPerCol / BaseShape::kCols; - using RegB = RegTile, tl::ColMajor>; - - using G2SLoaderB = GlobalToSharedLoader; - using S2RLoaderB = - SharedToRegLoader; - - // Global Tile for output C - using GlobalC = GlobalTile>; - // Shared Tile for output C - using SharedC = SharedTile, kSwizzled>; - - // Register Tile for output C - static constexpr int kCMs = kTM / kWarpPerRow / BaseShape::kRows; - static constexpr int kCNs = kTN / kWarpPerCol / BaseShape::kCols; - using RegC = RegTile, tl::RowMajor>; - - using R2SStorerC = RegToSharedStorer; - using S2GStorerC = SharedToGlobalStorer; + using BaseShape = traits::BaseTileShape; + + static constexpr int kThreads = tl::get_numel * 32; + static constexpr int kWarpPerRow = tl::num_rows; + static constexpr int kWarpPerCol = tl::num_cols; + + static constexpr int kM = dim_size<0, WholeShape>; + static constexpr int kN = dim_size<1, WholeShape>; + static constexpr int kK = dim_size<2, WholeShape>; + + static constexpr int kTM = dim_size<0, CtaTileShape>; + static constexpr int kTN = dim_size<1, CtaTileShape>; + static constexpr int kTK = dim_size<2, CtaTileShape>; + + static const bool kSwizzled = true; + + // Total data access for operand A in global memory + using GlobalA = GlobalTile>; + // Access a single global tile for operand A + using GIteratorA = GTileIterator>; + + // Shared Tile for operand A + using SharedA = SharedTile, kSwizzled>; + // Access a single register tile for operand A + using SIteratorA = STileIterator>; + + // Register tile for a single thread of operand A + static constexpr int kAMs = kTM / kWarpPerRow / BaseShape::kRows; + static constexpr int kAKs = kRK / BaseShape::kCols; + using RegA = RegTile, tl::RowMajor>; + + // Loaders for operand A + using G2SLoaderA = GlobalToSharedLoader; + using S2RLoaderA = + SharedToRegLoader; + + // Total data access for operand B in global memory + using GlobalB = GlobalTile>; + // Access a single global tile for operand B + using GIteratorB = GTileIterator>; + + // Shared Tile for operand B + using SharedB = SharedTile, kSwizzled>; + // Access a single register tile for operand B + using SIteratorB = STileIterator>; + + static_assert(GIteratorA::sc1 == GIteratorB::sc0, "mismatched K dimension!"); + static_assert(SIteratorA::sc1 == SIteratorB::sc0, "mismatched K dimension!"); + + // Register tile for a single thread of operand A + static constexpr int kBKs = kRK / BaseShape::kRows; + static constexpr int kBNs = kTN / kWarpPerCol / BaseShape::kCols; + using RegB = RegTile, tl::ColMajor>; + + using G2SLoaderB = GlobalToSharedLoader; + using S2RLoaderB = + SharedToRegLoader; + + // Global Tile for output C + using GlobalC = GlobalTile>; + // Shared Tile for output C + using SharedC = SharedTile, kSwizzled>; + + // Register Tile for output C + static constexpr int kCMs = kTM / kWarpPerRow / BaseShape::kRows; + static constexpr int kCNs = kTN / kWarpPerCol / BaseShape::kCols; + using RegC = RegTile, tl::RowMajor>; + + using R2SStorerC = RegToSharedStorer; + using S2GStorerC = SharedToGlobalStorer; }; template __global__ void gemm(const InType* dA, const InType* dB, AccType* dC) { - int offset_a = blockIdx.x * kTM * kK; - int offset_b = blockIdx.y * kTN * kK; - int offset_c = blockIdx.x * kTM * kN + blockIdx.y * kTN; + int offset_a = blockIdx.x * kTM * kK; + int offset_b = blockIdx.y * kTN * kK; + int offset_c = blockIdx.x * kTM * kN + blockIdx.y * kTN; - extern __shared__ __align__(sizeof(double)) unsigned char buf[]; - InType* sA_ptr = reinterpret_cast(buf); - InType* sB_ptr = sA_ptr + SIteratorA::Tile::kNumel; - AccType* sC_ptr = reinterpret_cast(buf); + extern __shared__ __align__(sizeof(double)) unsigned char buf[]; + InType* sA_ptr = reinterpret_cast(buf); + InType* sB_ptr = sA_ptr + SIteratorA::Tile::kNumel; + AccType* sC_ptr = reinterpret_cast(buf); - // declare tiles, iterators and loaders - GIteratorA gAs(dA + offset_a); - SIteratorA sAs(sA_ptr); + // declare tiles, iterators and loaders + GIteratorA gAs(dA + offset_a); + SIteratorA sAs(sA_ptr); - GIteratorB gBs(dB + offset_b); - SIteratorB sBs(sB_ptr); + GIteratorB gBs(dB + offset_b); + SIteratorB sBs(sB_ptr); - SharedA sA(sA_ptr); - RegA rA; + SharedA sA(sA_ptr); + RegA rA; - SharedB sB(sB_ptr); - RegB rB; + SharedB sB(sB_ptr); + RegB rB; - RegC acc; - SharedC sC(sC_ptr); - GlobalC gC(dC + offset_c); + RegC acc; + SharedC sC(sC_ptr); + GlobalC gC(dC + offset_c); - G2SLoaderA g2s_a; - S2RLoaderA s2r_a; + G2SLoaderA g2s_a; + S2RLoaderA s2r_a; - G2SLoaderB g2s_b; - S2RLoaderB s2r_b; + G2SLoaderB g2s_b; + S2RLoaderB s2r_b; - R2SStorerC r2s_c; - S2GStorerC s2g_c; + R2SStorerC r2s_c; + S2GStorerC s2g_c; - for (int k1 = 0; k1 < GIteratorA::sc1; ++k1) { - g2s_a(gAs(k1), sA); - g2s_b(gBs(k1), sB); - __copy_async(); - __syncthreads(); + for (int k1 = 0; k1 < GIteratorA::sc1; ++k1) { + g2s_a(gAs(k1), sA); + g2s_b(gBs(k1), sB); + __copy_async(); + __syncthreads(); - for (int k2 = 0; k2 < SIteratorA::sc1; ++k2) { - s2r_a(sAs(k2), rA); - s2r_b(sBs(k2), rB); + for (int k2 = 0; k2 < SIteratorA::sc1; ++k2) { + s2r_a(sAs(k2), rA); + s2r_b(sBs(k2), rB); - compute::gemm(rA, rB, acc); - } + compute::gemm(rA, rB, acc); } - r2s_c(acc, sC); - __syncthreads(); - s2g_c(sC, gC); + } + r2s_c(acc, sC); + __syncthreads(); + s2g_c(sC, gC); } diff --git a/examples/101_gemm/02_gemm_all_mem/main.cu b/examples/101_gemm/02_gemm_all_mem/main.cu index 8edadb91..e2ccba90 100644 --- a/examples/101_gemm/02_gemm_all_mem/main.cu +++ b/examples/101_gemm/02_gemm_all_mem/main.cu @@ -5,120 +5,119 @@ #include "util.hpp" void run_test() { - using WholeShape = GemmShape<4096, 4096, 4096>; - using CtaTileShape = GemmShape<64, 128, 128>; - using WarpLayout = tl::RowMajor<2, 2>; - static constexpr int kRK = 32; - - using InType = __half; - using AccType = float; - - static constexpr int kM = dim_size<0, WholeShape>; - static constexpr int kN = dim_size<1, WholeShape>; - static constexpr int kK = dim_size<2, WholeShape>; - - static constexpr int kTM = dim_size<0, CtaTileShape>; - static constexpr int kTN = dim_size<1, CtaTileShape>; - static constexpr int kTK = dim_size<2, CtaTileShape>; - - thrust::host_vector h_a(kM * kK); - for (int i = 0; i < h_a.size(); ++i) - h_a[i] = static_cast(rand_float()); - - thrust::host_vector h_b(kK * kN); - for (int i = 0; i < h_b.size(); ++i) - h_b[i] = static_cast(rand_float()); - - thrust::host_vector h_c(kM * kN); - thrust::fill(h_c.begin(), h_c.end(), 0.); - - thrust::device_vector d_a = h_a; - thrust::device_vector d_b = h_b; - thrust::device_vector d_c = h_c; - - const InType* A = thrust::raw_pointer_cast(d_a.data()); - const InType* B = thrust::raw_pointer_cast(d_b.data()); - AccType* C = thrust::raw_pointer_cast(d_c.data()); - - using Config = KeGemmTraits; - auto kernel = - &gemm; - - static constexpr int smem_size_inputs = kTK * (kTN + kTM) * sizeof(InType); - static constexpr int smem_size_accumulators = kTM * kTN * sizeof(AccType); - static constexpr int smem_size = smem_size_inputs > smem_size_accumulators - ? smem_size_inputs - : smem_size_accumulators; - - const int kMaxSmemPerBlock = 48 * 1024; - if (smem_size > kMaxSmemPerBlock) { - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + using WholeShape = GemmShape<4096, 4096, 4096>; + using CtaTileShape = GemmShape<64, 128, 128>; + using WarpLayout = tl::RowMajor<2, 2>; + static constexpr int kRK = 32; + + using InType = __half; + using AccType = float; + + static constexpr int kM = dim_size<0, WholeShape>; + static constexpr int kN = dim_size<1, WholeShape>; + static constexpr int kK = dim_size<2, WholeShape>; + + static constexpr int kTM = dim_size<0, CtaTileShape>; + static constexpr int kTN = dim_size<1, CtaTileShape>; + static constexpr int kTK = dim_size<2, CtaTileShape>; + + thrust::host_vector h_a(kM * kK); + for (int i = 0; i < h_a.size(); ++i) + h_a[i] = static_cast(rand_float()); + + thrust::host_vector h_b(kK * kN); + for (int i = 0; i < h_b.size(); ++i) + h_b[i] = static_cast(rand_float()); + + thrust::host_vector h_c(kM * kN); + thrust::fill(h_c.begin(), h_c.end(), 0.); + + thrust::device_vector d_a = h_a; + thrust::device_vector d_b = h_b; + thrust::device_vector d_c = h_c; + + const InType* A = thrust::raw_pointer_cast(d_a.data()); + const InType* B = thrust::raw_pointer_cast(d_b.data()); + AccType* C = thrust::raw_pointer_cast(d_c.data()); + + using Config = + KeGemmTraits; + auto kernel = &gemm; + + static constexpr int smem_size_inputs = kTK * (kTN + kTM) * sizeof(InType); + static constexpr int smem_size_accumulators = kTM * kTN * sizeof(AccType); + static constexpr int smem_size = smem_size_inputs > smem_size_accumulators + ? smem_size_inputs + : smem_size_accumulators; + + const int kMaxSmemPerBlock = 48 * 1024; + if (smem_size > kMaxSmemPerBlock) { + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + + int block_x = CeilDiv; + int block_y = CeilDiv; + + dim3 dim_grid(block_x, block_y, 1); + dim3 dim_block(Config::kThreads, 1, 1); + + kernel<<>>(A, B, C); + cudaDeviceSynchronize(); + h_c = d_c; + + // check correctness + thrust::device_vector d_c2(kM * kN); + thrust::fill(d_c2.begin(), d_c2.end(), 0.); + + cublas_hgemm(kM, kN, kK, thrust::raw_pointer_cast(d_a.data()), + thrust::raw_pointer_cast(d_b.data()), + thrust::raw_pointer_cast(d_c2.data()), false /*timeit*/); + thrust::host_vector h_c2 = d_c2; + + bool passed = check_results(thrust::raw_pointer_cast(h_c.data()), + thrust::raw_pointer_cast(h_c2.data()), kM * kN); + + if (passed) { + std::cout << "Test passed." << std::endl; + + for (int i = 0; i < 10; ++i) { // warm up + kernel<<>>(A, B, C); } - - int block_x = CeilDiv; - int block_y = CeilDiv; - - dim3 dim_grid(block_x, block_y, 1); - dim3 dim_block(Config::kThreads, 1, 1); - - kernel<<>>(A, B, C); cudaDeviceSynchronize(); - h_c = d_c; - - // check correctness - thrust::device_vector d_c2(kM * kN); - thrust::fill(d_c2.begin(), d_c2.end(), 0.); - - cublas_hgemm(kM, kN, kK, thrust::raw_pointer_cast(d_a.data()), - thrust::raw_pointer_cast(d_b.data()), - thrust::raw_pointer_cast(d_c2.data()), false /*timeit*/); - thrust::host_vector h_c2 = d_c2; - - bool passed = check_results(thrust::raw_pointer_cast(h_c.data()), - thrust::raw_pointer_cast(h_c2.data()), kM * kN); - - if (passed) { - std::cout << "Test passed." << std::endl; - - for (int i = 0; i < 10; ++i) { // warm up - kernel<<>>(A, B, C); - } - cudaDeviceSynchronize(); - - CudaTimer timer; - timer.start(); - int iters = 20; - for (int i = 0; i < iters; ++i) { - kernel<<>>(A, B, C); - } - cudaDeviceSynchronize(); - float time2 = timer.stop() / iters; - - float time1 = cublas_hgemm( - kM, kN, kK, thrust::raw_pointer_cast(d_a.data()), - thrust::raw_pointer_cast(d_b.data()), - thrust::raw_pointer_cast(d_c2.data()), true /*timeit*/); - - std::cout << "cuBLAS\ttilefusion\tRatio" << std::endl; - std::cout << std::setprecision(4) << time1 << "\t" << time2 << "\t" - << time2 / time1 << std::endl; - } else { - std::cerr << "Test failed." << std::endl; + + CudaTimer timer; + timer.start(); + int iters = 20; + for (int i = 0; i < iters; ++i) { + kernel<<>>(A, B, C); } + cudaDeviceSynchronize(); + float time2 = timer.stop() / iters; + + float time1 = + cublas_hgemm(kM, kN, kK, thrust::raw_pointer_cast(d_a.data()), + thrust::raw_pointer_cast(d_b.data()), + thrust::raw_pointer_cast(d_c2.data()), true /*timeit*/); + + std::cout << "cuBLAS\ttilefusion\tRatio" << std::endl; + std::cout << std::setprecision(4) << time1 << "\t" << time2 << "\t" + << time2 / time1 << std::endl; + } else { + std::cerr << "Test failed." << std::endl; + } } int main(int argc, char* argv[]) { - run_test(); - return 0; + run_test(); + return 0; } diff --git a/examples/101_gemm/util.hpp b/examples/101_gemm/util.hpp index b8f2a948..e9b43c5b 100644 --- a/examples/101_gemm/util.hpp +++ b/examples/101_gemm/util.hpp @@ -11,87 +11,87 @@ #include float rand_float(float a = 5e-4, float b = 1e-2) { - float random = ((float)rand()) / (float)RAND_MAX; - float diff = b - a; - float r = random * diff; - return a + r; + float random = ((float)rand()) / (float)RAND_MAX; + float diff = b - a; + float r = random * diff; + return a + r; } float cublas_hgemm(int64_t kM, int64_t kN, int64_t kK, // problem shape const __half* A, const __half* B, __half* C, bool timeit = false, int warm_up = 5, int iters = 20) { - cublasHandle_t handle; - cublasCreate(&handle); - - __half alf = static_cast<__half>(1.); - __half bet = static_cast<__half>(0.); - - float elapsed = 0.; - - if (timeit) { - for (int i = 0; i < warm_up; ++i) { - cublasHgemm(handle, CUBLAS_OP_T /* transb*/, CUBLAS_OP_N, kN, kM, - kK, &alf, B, kK, A, kK, &bet, C, kN); - } - cudaDeviceSynchronize(); - - CudaTimer timer; - timer.start(); - for (int i = 0; i < iters; ++i) { - cublasHgemm(handle, CUBLAS_OP_T /* transb*/, CUBLAS_OP_N, kN, kM, - kK, &alf, B, kK, A, kK, &bet, C, kN); - } - cudaDeviceSynchronize(); - elapsed = timer.stop() / iters; - } else { - cublasHgemm(handle, CUBLAS_OP_T /* transb*/, CUBLAS_OP_N, kN, kM, kK, - &alf, B, kK, A, kK, &bet, C, kN); + cublasHandle_t handle; + cublasCreate(&handle); + + __half alf = static_cast<__half>(1.); + __half bet = static_cast<__half>(0.); + + float elapsed = 0.; + + if (timeit) { + for (int i = 0; i < warm_up; ++i) { + cublasHgemm(handle, CUBLAS_OP_T /* transb*/, CUBLAS_OP_N, kN, kM, kK, + &alf, B, kK, A, kK, &bet, C, kN); } cudaDeviceSynchronize(); - cublasDestroy(handle); - return elapsed; + CudaTimer timer; + timer.start(); + for (int i = 0; i < iters; ++i) { + cublasHgemm(handle, CUBLAS_OP_T /* transb*/, CUBLAS_OP_N, kN, kM, kK, + &alf, B, kK, A, kK, &bet, C, kN); + } + cudaDeviceSynchronize(); + elapsed = timer.stop() / iters; + } else { + cublasHgemm(handle, CUBLAS_OP_T /* transb*/, CUBLAS_OP_N, kN, kM, kK, &alf, + B, kK, A, kK, &bet, C, kN); + } + cudaDeviceSynchronize(); + + cublasDestroy(handle); + return elapsed; } bool check_results(const float* values1, const __half* values2, int numel) { - bool passed = true; - const float epsilon = 1e-3; + bool passed = true; + const float epsilon = 1e-3; - double total_diff = 0.; - double max_abs_diff = FLT_MIN; - double diff = 0.; + double total_diff = 0.; + double max_abs_diff = FLT_MIN; + double diff = 0.; #ifdef DEBUG - int cut_off = 128; - printf("ground truth:\n"); - for (int i = 0; i < cut_off; ++i) { - printf("%.4f, ", __half2float(values2[i])); - if (i && (i + 1) % 16 == 0) printf("\n"); - } - printf("\ncomputed values:\n"); - for (int i = 0; i < cut_off; ++i) { - printf("%.4f, ", values1[i]); - if (i && (i + 1) % 16 == 0) printf("\n"); - } + int cut_off = 128; + printf("ground truth:\n"); + for (int i = 0; i < cut_off; ++i) { + printf("%.4f, ", __half2float(values2[i])); + if (i && (i + 1) % 16 == 0) printf("\n"); + } + printf("\ncomputed values:\n"); + for (int i = 0; i < cut_off; ++i) { + printf("%.4f, ", values1[i]); + if (i && (i + 1) % 16 == 0) printf("\n"); + } #endif - for (int i = 0; i < numel; ++i) { - float v1 = values1[i]; - float v2 = __half2float(values2[i]); - diff = fabs(v1 - v2); - max_abs_diff = max_abs_diff < diff ? diff : max_abs_diff; - total_diff += diff; + for (int i = 0; i < numel; ++i) { + float v1 = values1[i]; + float v2 = __half2float(values2[i]); + diff = fabs(v1 - v2); + max_abs_diff = max_abs_diff < diff ? diff : max_abs_diff; + total_diff += diff; #ifdef DEBUG - if (diff > epsilon) { - printf("the %d-th value differs (%.4f): %.4f vs. %.4f\n", i, diff, - v1, v2); - } -#endif + if (diff > epsilon) { + printf("the %d-th value differs (%.4f): %.4f vs. %.4f\n", i, diff, v1, + v2); } +#endif + } - double avg_diff = total_diff / numel; - if (avg_diff > epsilon) passed = false; + double avg_diff = total_diff / numel; + if (avg_diff > epsilon) passed = false; - return passed; + return passed; } diff --git a/include/cell/compute/broadcast.hpp b/include/cell/compute/broadcast.hpp index fd509b06..37de93cb 100644 --- a/include/cell/compute/broadcast.hpp +++ b/include/cell/compute/broadcast.hpp @@ -13,153 +13,153 @@ namespace tl = tile_layout; template struct Broadcast { - using DType = typename DstTile::DType::DType; + using DType = typename DstTile::DType::DType; - static constexpr int kRows = DstTile::kRows; - static constexpr int kCols = DstTile::kCols; + static constexpr int kRows = DstTile::kRows; + static constexpr int kCols = DstTile::kCols; - DEVICE void operator()(const SrcTile& src, DstTile& dst) {} + DEVICE void operator()(const SrcTile& src, DstTile& dst) {} }; template struct Broadcast { - using DType = typename DstTile::DType::DType; + using DType = typename DstTile::DType::DType; - static constexpr int kRows = DstTile::kRows; - static constexpr int kCols = DstTile::kCols; + static constexpr int kRows = DstTile::kRows; + static constexpr int kCols = DstTile::kCols; - DEVICE void operator()(const SrcTile& src, DstTile& dst) { + DEVICE void operator()(const SrcTile& src, DstTile& dst) { #pragma unroll - for (int i = 0; i < kRows; ++i) { - DType top_row = src(i, 0); - DType bottom_row = src(i, 1); + for (int i = 0; i < kRows; ++i) { + DType top_row = src(i, 0); + DType bottom_row = src(i, 1); #pragma unroll - for (int j = 0; j < kCols; ++j) { - dst(i, j)(0, 0) = top_row; - dst(i, j)(0, 1) = top_row; - dst(i, j)(1, 0) = top_row; - dst(i, j)(1, 1) = top_row; - - dst(i, j)(0, 2) = bottom_row; - dst(i, j)(0, 3) = bottom_row; - dst(i, j)(1, 2) = bottom_row; - dst(i, j)(1, 3) = bottom_row; - } - } + for (int j = 0; j < kCols; ++j) { + dst(i, j)(0, 0) = top_row; + dst(i, j)(0, 1) = top_row; + dst(i, j)(1, 0) = top_row; + dst(i, j)(1, 1) = top_row; + + dst(i, j)(0, 2) = bottom_row; + dst(i, j)(0, 3) = bottom_row; + dst(i, j)(1, 2) = bottom_row; + dst(i, j)(1, 3) = bottom_row; + } } + } }; template struct Broadcast { - using DType = typename DstTile::DType::DType; + using DType = typename DstTile::DType::DType; - static constexpr int kRows = DstTile::kRows; - static constexpr int kCols = DstTile::kCols; + static constexpr int kRows = DstTile::kRows; + static constexpr int kCols = DstTile::kCols; - DEVICE void operator()(const SrcTile& src, DstTile& dst) { + DEVICE void operator()(const SrcTile& src, DstTile& dst) { #pragma unroll - for (int j = 0; j < kCols; ++j) { - DType top_col = src(0, j); - DType bottom_col = src(1, j); + for (int j = 0; j < kCols; ++j) { + DType top_col = src(0, j); + DType bottom_col = src(1, j); #pragma unroll - for (int i = 0; i < kRows; ++i) { - dst(i, j)(0, 0) = top_col; - dst(i, j)(1, 0) = top_col; - dst(i, j)(0, 1) = top_col; - dst(i, j)(1, 1) = top_col; - - dst(i, j)(2, 0) = bottom_col; - dst(i, j)(3, 0) = bottom_col; - dst(i, j)(2, 1) = bottom_col; - dst(i, j)(3, 1) = bottom_col; - } - } + for (int i = 0; i < kRows; ++i) { + dst(i, j)(0, 0) = top_col; + dst(i, j)(1, 0) = top_col; + dst(i, j)(0, 1) = top_col; + dst(i, j)(1, 1) = top_col; + + dst(i, j)(2, 0) = bottom_col; + dst(i, j)(3, 0) = bottom_col; + dst(i, j)(2, 1) = bottom_col; + dst(i, j)(3, 1) = bottom_col; + } } + } }; template struct BroadcastFuse { - using DType = typename DstTile::DType::DType; + using DType = typename DstTile::DType::DType; - static constexpr int kRows = DstTile::kRows; - static constexpr int kCols = DstTile::kCols; + static constexpr int kRows = DstTile::kRows; + static constexpr int kCols = DstTile::kCols; - DEVICE void operator()(const SrcTile& src, DstTile& dst) {} + DEVICE void operator()(const SrcTile& src, DstTile& dst) {} }; template struct BroadcastFuse { - using DType = typename DstTile::DType::DType; + using DType = typename DstTile::DType::DType; - static constexpr int kRows = DstTile::kRows; - static constexpr int kCols = DstTile::kCols; + static constexpr int kRows = DstTile::kRows; + static constexpr int kCols = DstTile::kCols; - DEVICE void operator()(const SrcTile& src, DstTile& dst) { - Functor f; + DEVICE void operator()(const SrcTile& src, DstTile& dst) { + Functor f; #pragma unroll - for (int i = 0; i < kRows; ++i) { - DType top_row = src(i, 0); - DType bottom_row = src(i, 1); + for (int i = 0; i < kRows; ++i) { + DType top_row = src(i, 0); + DType bottom_row = src(i, 1); #pragma unroll - for (int j = 0; j < kCols; ++j) { - f(dst(i, j)(0, 0), top_row, dst(i, j)(0, 0)); - f(dst(i, j)(0, 1), top_row, dst(i, j)(0, 1)); - f(dst(i, j)(1, 0), top_row, dst(i, j)(1, 0)); - f(dst(i, j)(1, 1), top_row, dst(i, j)(1, 1)); - - f(dst(i, j)(0, 2), bottom_row, dst(i, j)(0, 2)); - f(dst(i, j)(0, 3), bottom_row, dst(i, j)(0, 3)); - f(dst(i, j)(1, 2), bottom_row, dst(i, j)(1, 2)); - f(dst(i, j)(1, 3), bottom_row, dst(i, j)(1, 3)); - } - } + for (int j = 0; j < kCols; ++j) { + f(dst(i, j)(0, 0), top_row, dst(i, j)(0, 0)); + f(dst(i, j)(0, 1), top_row, dst(i, j)(0, 1)); + f(dst(i, j)(1, 0), top_row, dst(i, j)(1, 0)); + f(dst(i, j)(1, 1), top_row, dst(i, j)(1, 1)); + + f(dst(i, j)(0, 2), bottom_row, dst(i, j)(0, 2)); + f(dst(i, j)(0, 3), bottom_row, dst(i, j)(0, 3)); + f(dst(i, j)(1, 2), bottom_row, dst(i, j)(1, 2)); + f(dst(i, j)(1, 3), bottom_row, dst(i, j)(1, 3)); + } } + } }; template struct BroadcastFuse { - using DType = typename DstTile::DType::DType; + using DType = typename DstTile::DType::DType; - static constexpr int kRows = DstTile::kRows; - static constexpr int kCols = DstTile::kCols; + static constexpr int kRows = DstTile::kRows; + static constexpr int kCols = DstTile::kCols; - DEVICE void operator()(const SrcTile& src, DstTile& dst) { - Functor f; + DEVICE void operator()(const SrcTile& src, DstTile& dst) { + Functor f; #pragma unroll - for (int j = 0; j < kCols; ++j) { - DType top_col = src(0, j); - DType bottom_col = src(1, j); + for (int j = 0; j < kCols; ++j) { + DType top_col = src(0, j); + DType bottom_col = src(1, j); #pragma unroll - for (int i = 0; i < kRows; ++i) { - f(dst(i, j)(0, 0), top_col, dst(i, j)(0, 0)); - f(dst(i, j)(1, 0), top_col, dst(i, j)(1, 0)); - f(dst(i, j)(0, 1), top_col, dst(i, j)(0, 1)); - f(dst(i, j)(1, 1), top_col, dst(i, j)(1, 1)); - - f(dst(i, j)(2, 0), bottom_col, dst(i, j)(2, 0)); - f(dst(i, j)(3, 0), bottom_col, dst(i, j)(3, 0)); - f(dst(i, j)(2, 1), bottom_col, dst(i, j)(2, 1)); - f(dst(i, j)(3, 1), bottom_col, dst(i, j)(3, 1)); - } - } + for (int i = 0; i < kRows; ++i) { + f(dst(i, j)(0, 0), top_col, dst(i, j)(0, 0)); + f(dst(i, j)(1, 0), top_col, dst(i, j)(1, 0)); + f(dst(i, j)(0, 1), top_col, dst(i, j)(0, 1)); + f(dst(i, j)(1, 1), top_col, dst(i, j)(1, 1)); + + f(dst(i, j)(2, 0), bottom_col, dst(i, j)(2, 0)); + f(dst(i, j)(3, 0), bottom_col, dst(i, j)(3, 0)); + f(dst(i, j)(2, 1), bottom_col, dst(i, j)(2, 1)); + f(dst(i, j)(3, 1), bottom_col, dst(i, j)(3, 1)); + } } + } }; template struct BroadcastScalar { - static constexpr int kRows = RegTile::kRows; - static constexpr int kCols = RegTile::kCols; - - template - DEVICE void operator()(const RegTile& src, Element scalar, RegTile& dst) { - Functor f; - for (int i = 0; i < kRows; ++i) { - for (int j = 0; j < kCols; ++j) { - f(src(i, j), scalar, dst(i, j)); - } - } + static constexpr int kRows = RegTile::kRows; + static constexpr int kCols = RegTile::kCols; + + template + DEVICE void operator()(const RegTile& src, Element scalar, RegTile& dst) { + Functor f; + for (int i = 0; i < kRows; ++i) { + for (int j = 0; j < kCols; ++j) { + f(src(i, j), scalar, dst(i, j)); + } } + } }; template diff --git a/include/cell/compute/gemm.hpp b/include/cell/compute/gemm.hpp index b8505782..fdefdfe4 100644 --- a/include/cell/compute/gemm.hpp +++ b/include/cell/compute/gemm.hpp @@ -17,136 +17,136 @@ struct MmaAtom; template <> struct MmaAtom<__half, __half, float, MMA_ATOM_16x16x16> { - struct BaseTile { - static constexpr int kRows = 16; - static constexpr int kCols = 16; - static constexpr int kNumel = 256; - }; - using BaseTileA = BaseTile; - using BaseTileB = BaseTile; - using BaseTileC = BaseTile; - - DEVICE void operator()(const __half* ra, const __half* rb, float* rc) { - const uint32_t* A = reinterpret_cast(ra); - const uint32_t* B = reinterpret_cast(rb); - float* C = static_cast(rc); + struct BaseTile { + static constexpr int kRows = 16; + static constexpr int kCols = 16; + static constexpr int kNumel = 256; + }; + using BaseTileA = BaseTile; + using BaseTileB = BaseTile; + using BaseTileC = BaseTile; + + DEVICE void operator()(const __half* ra, const __half* rb, float* rc) { + const uint32_t* A = reinterpret_cast(ra); + const uint32_t* B = reinterpret_cast(rb); + float* C = static_cast(rc); #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - "{%8, %9}," - "{%10, %11, %12, %13};\n" - : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[2]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); - - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - "{%8, %9}," - "{%10, %11, %12, %13};\n" - : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[1]), "r"(B[3]), - "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[2]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); + + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[1]), "r"(B[3]), + "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])); #else - assert(false && - "This GEMM implementation requires SM80 (Ampere) or later " - "architecture"); + assert(false && + "This GEMM implementation requires SM80 (Ampere) or later " + "architecture"); #endif - } + } }; template <> struct MmaAtom<__half, __half, __half, MMA_ATOM_16x16x16> { - struct BaseTile { - static constexpr int kRows = 16; - static constexpr int kCols = 16; - static constexpr int kNumel = 256; - }; - using BaseTileA = BaseTile; - using BaseTileB = BaseTile; - using BaseTileC = BaseTile; - - DEVICE void operator()(const __half* ra, const __half* rb, __half* rc) { - const uint32_t* A = reinterpret_cast(ra); - const uint32_t* B = reinterpret_cast(rb); - uint32_t* C = reinterpret_cast(rc); + struct BaseTile { + static constexpr int kRows = 16; + static constexpr int kCols = 16; + static constexpr int kNumel = 256; + }; + using BaseTileA = BaseTile; + using BaseTileB = BaseTile; + using BaseTileC = BaseTile; + + DEVICE void operator()(const __half* ra, const __half* rb, __half* rc) { + const uint32_t* A = reinterpret_cast(ra); + const uint32_t* B = reinterpret_cast(rb); + uint32_t* C = reinterpret_cast(rc); #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " - "{%0, %1}, " - "{%2, %3, %4, %5}, " - "{%6, %7}, " - "{%8, %9};" - : "=r"(C[0]), "=r"(C[1]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[2]), - "r"(C[0]), "r"(C[1])); - - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " - "{%0, %1}, " - "{%2, %3, %4, %5}, " - "{%6, %7}, " - "{%8, %9};" - : "=r"(C[2]), "=r"(C[3]) - : "r"(A[4]), "r"(A[5]), "r"(A[6]), "r"(A[7]), "r"(B[1]), "r"(B[3]), - "r"(C[2]), "r"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1}, " + "{%2, %3, %4, %5}, " + "{%6, %7}, " + "{%8, %9};" + : "=r"(C[0]), "=r"(C[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[2]), + "r"(C[0]), "r"(C[1])); + + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1}, " + "{%2, %3, %4, %5}, " + "{%6, %7}, " + "{%8, %9};" + : "=r"(C[2]), "=r"(C[3]) + : "r"(A[4]), "r"(A[5]), "r"(A[6]), "r"(A[7]), "r"(B[1]), "r"(B[3]), + "r"(C[2]), "r"(C[3])); #else - assert(false && - "This GEMM implementation requires SM80 (Ampere) or later " - "architecture"); + assert(false && + "This GEMM implementation requires SM80 (Ampere) or later " + "architecture"); #endif - } + } }; template <> struct MmaAtom<__bfloat16, __bfloat16, float, MMA_ATOM_16x16x16> { - struct BaseTile { - static constexpr int kRows = 16; - static constexpr int kCols = 16; - static constexpr int kNumel = 256; - }; - using BaseTileA = BaseTile; - using BaseTileB = BaseTile; - using BaseTileC = BaseTile; - - DEVICE void operator()(const __bfloat16* ra, const __bfloat16* rb, - float* rc) { - const uint32_t* A = reinterpret_cast(ra); - const uint32_t* B = reinterpret_cast(rb); - float* C = static_cast(rc); + struct BaseTile { + static constexpr int kRows = 16; + static constexpr int kCols = 16; + static constexpr int kNumel = 256; + }; + using BaseTileA = BaseTile; + using BaseTileB = BaseTile; + using BaseTileC = BaseTile; + + DEVICE void operator()(const __bfloat16* ra, const __bfloat16* rb, + float* rc) { + const uint32_t* A = reinterpret_cast(ra); + const uint32_t* B = reinterpret_cast(rb); + float* C = static_cast(rc); #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - "{%8, %9}," - "{%10, %11, %12, %13};\n" - : "+f"(C[0]), "+f"(C[1]), "+f"(C[2]), "+f"(C[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[2]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); - - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - "{%8, %9}," - "{%10, %11, %12, %13};\n" - : "+f"(C[4]), "+f"(C[5]), "+f"(C[6]), "+f"(C[7]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[1]), "r"(B[3]), - "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "+f"(C[0]), "+f"(C[1]), "+f"(C[2]), "+f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[2]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); + + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "+f"(C[4]), "+f"(C[5]), "+f"(C[6]), "+f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[1]), "r"(B[3]), + "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])); #else - assert(false && - "This GEMM implementation requires SM80 (Ampere) or later " - "architecture"); + assert(false && + "This GEMM implementation requires SM80 (Ampere) or later " + "architecture"); #endif - } + } }; /// @brief: Functor to warp wmma PTX instruction. See the below document for @@ -155,49 +155,49 @@ struct MmaAtom<__bfloat16, __bfloat16, float, MMA_ATOM_16x16x16> { template struct Gemm { - using InTypeA = typename RegTileA::DType::DType; - using InTypeB = typename RegTileB::DType::DType; - using OutType = typename RegTileC::DType::DType; - - static_assert(std::is_same_v || - std::is_same_v, - "This GEMM implementation supports only half-precision as " - "the input element type."); - static_assert(std::is_same_v || - std::is_same_v, - "The output type must be float or half."); - static_assert(std::is_same_v, - "Mismatched data type for operand A and B."); - static_assert(RegTileB::kRows == RegTileA::kCols, - "Mismatched k-dimension for operand A and B."); - - static constexpr int kMs = RegTileA::kRows; - static constexpr int kNs = RegTileB::kCols; - static constexpr int kKs = RegTileA::kCols; - static_assert(kMs && kNs && kKs, "Invalid tile shapes for GEMM."); - - DEVICE void operator()(const RegTileA& a, const RegTileB& b, RegTileC& c) { - for (int i = 0; i < kMs; ++i) { - for (int j = 0; j < kNs; ++j) { + using InTypeA = typename RegTileA::DType::DType; + using InTypeB = typename RegTileB::DType::DType; + using OutType = typename RegTileC::DType::DType; + + static_assert(std::is_same_v || + std::is_same_v, + "This GEMM implementation supports only half-precision as " + "the input element type."); + static_assert(std::is_same_v || + std::is_same_v, + "The output type must be float or half."); + static_assert(std::is_same_v, + "Mismatched data type for operand A and B."); + static_assert(RegTileB::kRows == RegTileA::kCols, + "Mismatched k-dimension for operand A and B."); + + static constexpr int kMs = RegTileA::kRows; + static constexpr int kNs = RegTileB::kCols; + static constexpr int kKs = RegTileA::kCols; + static_assert(kMs && kNs && kKs, "Invalid tile shapes for GEMM."); + + DEVICE void operator()(const RegTileA& a, const RegTileB& b, RegTileC& c) { + for (int i = 0; i < kMs; ++i) { + for (int j = 0; j < kNs; ++j) { #pragma unroll - for (int k = 0; k < kKs; ++k) { - mma(a(i, k).data(), b(k, j).data(), c(i, j).mutable_data()); - } - } + for (int k = 0; k < kKs; ++k) { + mma(a(i, k).data(), b(k, j).data(), c(i, j).mutable_data()); } + } } + } - private: - using MmaAtom = MmaAtom; - MmaAtom mma; + private: + using MmaAtom = MmaAtom; + MmaAtom mma; }; template > DEVICE void gemm(const RegTileA& a, const RegTileB& b, RegTileC& c) { - GemmOp gemm; - gemm(a, b, c); + GemmOp gemm; + gemm(a, b, c); } } // namespace tilefusion::cell::compute diff --git a/include/cell/compute/map.hpp b/include/cell/compute/map.hpp index e6547f0f..2346d125 100644 --- a/include/cell/compute/map.hpp +++ b/include/cell/compute/map.hpp @@ -19,21 +19,21 @@ namespace detail { // which will affect the memory access performance. template struct ElementWise { - using DType = typename RegTile::DType; + using DType = typename RegTile::DType; - static constexpr int kRows = RegTile::kRows; - static constexpr int kCols = RegTile::kCols; + static constexpr int kRows = RegTile::kRows; + static constexpr int kCols = RegTile::kCols; - DEVICE void operator()(const RegTile& src, RegTile& dst) { - Functor f; + DEVICE void operator()(const RegTile& src, RegTile& dst) { + Functor f; #pragma unroll - for (int i = 0; i < kRows; ++i) { + for (int i = 0; i < kRows; ++i) { #pragma unroll - for (int j = 0; j < kCols; ++j) { - f(src(i, j), dst(i, j)); - } - } + for (int j = 0; j < kCols; ++j) { + f(src(i, j), dst(i, j)); + } } + } }; // TODO(KuangjuX): Distinguish whether the `Layout` is Row Major or Column @@ -41,22 +41,22 @@ struct ElementWise { // which will affect the memory access performance. template struct ElementWise2 { - static constexpr int kRows = SrcRegTile::kRows; - static constexpr int kCols = SrcRegTile::kCols; + static constexpr int kRows = SrcRegTile::kRows; + static constexpr int kCols = SrcRegTile::kCols; - static_assert(kRows == DstRegTile::kRows, "kRows must be equal"); - static_assert(kCols == DstRegTile::kCols, "kCols must be equal"); + static_assert(kRows == DstRegTile::kRows, "kRows must be equal"); + static_assert(kCols == DstRegTile::kCols, "kCols must be equal"); - DEVICE void operator()(const SrcRegTile& src, DstRegTile& dst) { - Functor f; + DEVICE void operator()(const SrcRegTile& src, DstRegTile& dst) { + Functor f; #pragma unroll - for (int i = 0; i < kRows; ++i) { + for (int i = 0; i < kRows; ++i) { #pragma unroll - for (int j = 0; j < kCols; ++j) { - f(src(i, j), dst(i, j)); - } - } + for (int j = 0; j < kCols; ++j) { + f(src(i, j), dst(i, j)); + } } + } }; // TODO(KuangjuX): Distinguish whether the `Layout` is Row Major or Column @@ -64,22 +64,21 @@ struct ElementWise2 { // which will affect the memory access performance. template struct Binary { - using DType = typename RegTile::DType; + using DType = typename RegTile::DType; - static constexpr int kRows = RegTile::kRows; - static constexpr int kCols = RegTile::kCols; + static constexpr int kRows = RegTile::kRows; + static constexpr int kCols = RegTile::kCols; - DEVICE void operator()(const RegTile& lhs, const RegTile& rhs, - RegTile& dst) { - Functor f; + DEVICE void operator()(const RegTile& lhs, const RegTile& rhs, RegTile& dst) { + Functor f; #pragma unroll - for (int i = 0; i < kRows; ++i) { + for (int i = 0; i < kRows; ++i) { #pragma unroll - for (int j = 0; j < kCols; ++j) { - f(lhs(i, j), rhs(i, j), dst(i, j)); - } - } + for (int j = 0; j < kCols; ++j) { + f(lhs(i, j), rhs(i, j), dst(i, j)); + } } + } }; } // namespace detail diff --git a/include/cell/compute/math_functor.hpp b/include/cell/compute/math_functor.hpp index 1f48d36a..91c710cc 100644 --- a/include/cell/compute/math_functor.hpp +++ b/include/cell/compute/math_functor.hpp @@ -9,215 +9,207 @@ namespace tilefusion::cell::compute { template struct Add { - DEVICE Element operator()(Element a, Element b) const { return a + b; } + DEVICE Element operator()(Element a, Element b) const { return a + b; } - DEVICE void operator()(const Element& lhs, const Element& rhs, - Element& dst) { - dst = lhs + rhs; - } + DEVICE void operator()(const Element& lhs, const Element& rhs, Element& dst) { + dst = lhs + rhs; + } }; template struct Sub { - DEVICE Element operator()(Element a, Element b) const { return a - b; } + DEVICE Element operator()(Element a, Element b) const { return a - b; } - DEVICE void operator()(const Element& lhs, const Element& rhs, - Element& dst) { - dst = lhs - rhs; - } + DEVICE void operator()(const Element& lhs, const Element& rhs, Element& dst) { + dst = lhs - rhs; + } }; template struct Mul { - DEVICE Element operator()(Element a, Element b) const { return a * b; } + DEVICE Element operator()(Element a, Element b) const { return a * b; } - DEVICE void operator()(const Element& lhs, const Element& rhs, - Element& dst) { - dst = lhs * rhs; - } + DEVICE void operator()(const Element& lhs, const Element& rhs, Element& dst) { + dst = lhs * rhs; + } }; template struct Div { - DEVICE Element operator()(Element a, Element b) const { return a / b; } + DEVICE Element operator()(Element a, Element b) const { return a / b; } - DEVICE void operator()(const Element& lhs, const Element& rhs, - Element& dst) { - dst = lhs / rhs; - } + DEVICE void operator()(const Element& lhs, const Element& rhs, Element& dst) { + dst = lhs / rhs; + } }; template struct Max { - DEVICE Element operator()(Element a, Element b) const { - return a > b ? a : b; - } + DEVICE Element operator()(Element a, Element b) const { + return a > b ? a : b; + } - DEVICE void operator()(const Element& lhs, const Element& rhs, - Element& dst) { - dst = lhs > rhs ? lhs : rhs; - } + DEVICE void operator()(const Element& lhs, const Element& rhs, Element& dst) { + dst = lhs > rhs ? lhs : rhs; + } }; template struct Min { - DEVICE Element operator()(Element a, Element b) const { - return a < b ? a : b; - } + DEVICE Element operator()(Element a, Element b) const { + return a < b ? a : b; + } - DEVICE void operator()(const Element& lhs, const Element& rhs, - Element& dst) { - dst = lhs < rhs ? lhs : rhs; - } + DEVICE void operator()(const Element& lhs, const Element& rhs, Element& dst) { + dst = lhs < rhs ? lhs : rhs; + } }; template struct Exp { - DEVICE Element operator()(Element a) const { return exp(a); } + DEVICE Element operator()(Element a) const { return exp(a); } - DEVICE void operator()(const Element& src, Element& dst) { dst = exp(src); } + DEVICE void operator()(const Element& src, Element& dst) { dst = exp(src); } }; #if defined(__CUDA_ARCH__) template <> struct Exp { - DEVICE float operator()(float a) const { return __expf(a); } + DEVICE float operator()(float a) const { return __expf(a); } - DEVICE void operator()(const float& src, float& dst) { dst = __expf(src); } + DEVICE void operator()(const float& src, float& dst) { dst = __expf(src); } }; template <> struct Exp<__half> { - DEVICE __half operator()(__half a) const { return hexp(a); } + DEVICE __half operator()(__half a) const { return hexp(a); } - DEVICE void operator()(const __half& src, __half& dst) { dst = hexp(src); } + DEVICE void operator()(const __half& src, __half& dst) { dst = hexp(src); } }; #endif template struct Log { - DEVICE Element operator()(Element a) const { return log(a); } + DEVICE Element operator()(Element a) const { return log(a); } - DEVICE void operator()(const Element& src, Element& dst) { dst = log(src); } + DEVICE void operator()(const Element& src, Element& dst) { dst = log(src); } }; #if defined(__CUDA_ARCH__) template <> struct Log { - DEVICE float operator()(float a) const { return __logf(a); } + DEVICE float operator()(float a) const { return __logf(a); } - DEVICE void operator()(const float& src, float& dst) { dst = __logf(src); } + DEVICE void operator()(const float& src, float& dst) { dst = __logf(src); } }; template <> struct Log<__half> { - DEVICE __half operator()(__half a) const { return hlog(a); } + DEVICE __half operator()(__half a) const { return hlog(a); } - DEVICE void operator()(const __half& src, __half& dst) { dst = hlog(src); } + DEVICE void operator()(const __half& src, __half& dst) { dst = hlog(src); } }; #endif template struct Relu { - DEVICE Element operator()(Element a) const { return a > 0 ? a : 0; } + DEVICE Element operator()(Element a) const { return a > 0 ? a : 0; } - DEVICE void operator()(const Element& src, Element& dst) { - dst = src > 0 ? src : 0; - } + DEVICE void operator()(const Element& src, Element& dst) { + dst = src > 0 ? src : 0; + } }; #if defined(__CUDA_ARCH__) template <> struct Relu { - DEVICE float operator()(float a) const { return max(a, 0.f); } + DEVICE float operator()(float a) const { return max(a, 0.f); } - DEVICE void operator()(const float& src, float& dst) { - dst = max(src, 0.f); - } + DEVICE void operator()(const float& src, float& dst) { dst = max(src, 0.f); } }; template <> struct Relu<__half> { - DEVICE __half operator()(__half a) const { return __hmax(a, 0); } + DEVICE __half operator()(__half a) const { return __hmax(a, 0); } - DEVICE void operator()(const __half& src, __half& dst) { - dst = __hmax(src, 0); - } + DEVICE void operator()(const __half& src, __half& dst) { + dst = __hmax(src, 0); + } }; #endif template struct Convert { - DEVICE OutType operator()(InType a) const { return OutType(a); } + DEVICE OutType operator()(InType a) const { return OutType(a); } - DEVICE void operator()(const InType& src, OutType& dst) { - dst = OutType(src); - } + DEVICE void operator()(const InType& src, OutType& dst) { + dst = OutType(src); + } }; #if defined(__CUDA_ARCH__) template <> struct Convert { - DEVICE __half operator()(float a) const { return __float2half(a); } + DEVICE __half operator()(float a) const { return __float2half(a); } - DEVICE void operator()(const float& src, __half& dst) { - dst = __float2half(src); - } + DEVICE void operator()(const float& src, __half& dst) { + dst = __float2half(src); + } }; template <> struct Convert<__half, float> { - DEVICE float operator()(__half a) const { return __half2float(a); } + DEVICE float operator()(__half a) const { return __half2float(a); } - DEVICE void operator()(const __half& src, float& dst) { - dst = __half2float(src); - } + DEVICE void operator()(const __half& src, float& dst) { + dst = __half2float(src); + } }; - #ifdef CUDA_FP8_AVAILABLE + #ifdef CUDA_FP8_AVAILABLE // FP8 E4M3 conversions template <> struct Convert { - DEVICE __fp8_e4m3 operator()(float a) const { - return from_float<__fp8_e4m3>(a); - } + DEVICE __fp8_e4m3 operator()(float a) const { + return from_float<__fp8_e4m3>(a); + } - DEVICE void operator()(const float& src, __fp8_e4m3& dst) { - dst = from_float<__fp8_e4m3>(src); - } + DEVICE void operator()(const float& src, __fp8_e4m3& dst) { + dst = from_float<__fp8_e4m3>(src); + } }; template <> struct Convert<__fp8_e4m3, float> { - DEVICE float operator()(__fp8_e4m3 a) const { return to_float(a); } + DEVICE float operator()(__fp8_e4m3 a) const { return to_float(a); } - DEVICE void operator()(const __fp8_e4m3& src, float& dst) { - dst = to_float(src); - } + DEVICE void operator()(const __fp8_e4m3& src, float& dst) { + dst = to_float(src); + } }; // FP8 E5M2 conversions template <> struct Convert { - DEVICE __fp8_e5m2 operator()(float a) const { - return from_float<__fp8_e5m2>(a); - } + DEVICE __fp8_e5m2 operator()(float a) const { + return from_float<__fp8_e5m2>(a); + } - DEVICE void operator()(const float& src, __fp8_e5m2& dst) { - dst = from_float<__fp8_e5m2>(src); - } + DEVICE void operator()(const float& src, __fp8_e5m2& dst) { + dst = from_float<__fp8_e5m2>(src); + } }; template <> struct Convert<__fp8_e5m2, float> { - DEVICE float operator()(__fp8_e5m2 a) const { return to_float(a); } + DEVICE float operator()(__fp8_e5m2 a) const { return to_float(a); } - DEVICE void operator()(const __fp8_e5m2& src, float& dst) { - dst = to_float(src); - } + DEVICE void operator()(const __fp8_e5m2& src, float& dst) { + dst = to_float(src); + } }; - #endif // CUDA_FP8_AVAILABLE + #endif // CUDA_FP8_AVAILABLE #endif // defined(__CUDA_ARCH__) diff --git a/include/cell/compute/reduce.hpp b/include/cell/compute/reduce.hpp index f1eebf1e..4c2dca58 100644 --- a/include/cell/compute/reduce.hpp +++ b/include/cell/compute/reduce.hpp @@ -15,140 +15,140 @@ struct Reduce; template struct Reduce { - using DType = typename RegTile::DType::DType; + using DType = typename RegTile::DType::DType; - static constexpr int kRows = RegTile::kRows; - static constexpr int kCols = RegTile::kCols; + static constexpr int kRows = RegTile::kRows; + static constexpr int kCols = RegTile::kCols; - template - DEVICE void operator()(const RegTile& src, DstTile& dst, Reduce reduce) { - const int leader = threadIdx.x & 0x1C; + template + DEVICE void operator()(const RegTile& src, DstTile& dst, Reduce reduce) { + const int leader = threadIdx.x & 0x1C; #pragma unroll - for (int i = 0; i < kRows; ++i) { - DType top_rows[kCols]; - DType bottom_rows[kCols]; + for (int i = 0; i < kRows; ++i) { + DType top_rows[kCols]; + DType bottom_rows[kCols]; #pragma unroll - for (int j = 0; j < kCols; ++j) { - auto base_tile = src(i, j); - DType top_row_0 = reduce(base_tile(0, 0), base_tile(0, 1)); - DType top_row_1 = reduce(base_tile(1, 0), base_tile(1, 1)); - top_rows[j] = reduce(top_row_0, top_row_1); + for (int j = 0; j < kCols; ++j) { + auto base_tile = src(i, j); + DType top_row_0 = reduce(base_tile(0, 0), base_tile(0, 1)); + DType top_row_1 = reduce(base_tile(1, 0), base_tile(1, 1)); + top_rows[j] = reduce(top_row_0, top_row_1); - DType bottom_row_0 = reduce(base_tile(0, 2), base_tile(0, 3)); - DType bottom_row_1 = reduce(base_tile(1, 2), base_tile(1, 3)); - bottom_rows[j] = reduce(bottom_row_0, bottom_row_1); - } + DType bottom_row_0 = reduce(base_tile(0, 2), base_tile(0, 3)); + DType bottom_row_1 = reduce(base_tile(1, 2), base_tile(1, 3)); + bottom_rows[j] = reduce(bottom_row_0, bottom_row_1); + } - DType top_row = top_rows[0]; - DType bottom_row = bottom_rows[0]; + DType top_row = top_rows[0]; + DType bottom_row = bottom_rows[0]; - // Compute the reduction of the top and bottom rows. + // Compute the reduction of the top and bottom rows. #pragma unroll - for (int j = 1; j < kCols; ++j) { - top_row = reduce(top_row, top_rows[j]); - bottom_row = reduce(bottom_row, bottom_rows[j]); - } - - // Shuffle the results to the leader thread. - top_row = reduce(top_row, shuffle_down_sync(MASK_ALL, top_row, 2)); - top_row = reduce(top_row, shuffle_down_sync(MASK_ALL, top_row, 1)); - - bottom_row = - reduce(bottom_row, shuffle_down_sync(MASK_ALL, bottom_row, 2)); - bottom_row = - reduce(bottom_row, shuffle_down_sync(MASK_ALL, bottom_row, 1)); - - // Group the threads into groups of four, and broadcast the data - // from the first thread in each group to the other three threads. - top_row = shuffle_sync(MASK_ALL, top_row, leader); - bottom_row = shuffle_sync(MASK_ALL, bottom_row, leader); - - // Store the results to the destination tile. - dst(i, 0) = top_row; - dst(i, 1) = bottom_row; - } + for (int j = 1; j < kCols; ++j) { + top_row = reduce(top_row, top_rows[j]); + bottom_row = reduce(bottom_row, bottom_rows[j]); + } + + // Shuffle the results to the leader thread. + top_row = reduce(top_row, shuffle_down_sync(MASK_ALL, top_row, 2)); + top_row = reduce(top_row, shuffle_down_sync(MASK_ALL, top_row, 1)); + + bottom_row = + reduce(bottom_row, shuffle_down_sync(MASK_ALL, bottom_row, 2)); + bottom_row = + reduce(bottom_row, shuffle_down_sync(MASK_ALL, bottom_row, 1)); + + // Group the threads into groups of four, and broadcast the data + // from the first thread in each group to the other three threads. + top_row = shuffle_sync(MASK_ALL, top_row, leader); + bottom_row = shuffle_sync(MASK_ALL, bottom_row, leader); + + // Store the results to the destination tile. + dst(i, 0) = top_row; + dst(i, 1) = bottom_row; } + } }; template struct Reduce { - using DType = typename RegTile::DType::DType; + using DType = typename RegTile::DType::DType; - static constexpr int kRows = RegTile::kRows; - static constexpr int kCols = RegTile::kCols; + static constexpr int kRows = RegTile::kRows; + static constexpr int kCols = RegTile::kCols; - template - DEVICE void operator()(const RegTile& tile, DstTile& dst, Reduce reduce) { - const int leader = threadIdx.x & 0x1C; + template + DEVICE void operator()(const RegTile& tile, DstTile& dst, Reduce reduce) { + const int leader = threadIdx.x & 0x1C; #pragma unroll - for (int i = 0; i < kCols; ++i) { - DType top_cols[kRows]; - DType bottom_cols[kRows]; + for (int i = 0; i < kCols; ++i) { + DType top_cols[kRows]; + DType bottom_cols[kRows]; #pragma unroll - for (int j = 0; j < kRows; ++j) { - auto base_tile = tile(j, i); - DType top_col_0 = reduce(base_tile(0, 0), base_tile(1, 0)); - DType top_col_1 = reduce(base_tile(0, 1), base_tile(1, 1)); - top_cols[j] = reduce(top_col_0, top_col_1); + for (int j = 0; j < kRows; ++j) { + auto base_tile = tile(j, i); + DType top_col_0 = reduce(base_tile(0, 0), base_tile(1, 0)); + DType top_col_1 = reduce(base_tile(0, 1), base_tile(1, 1)); + top_cols[j] = reduce(top_col_0, top_col_1); - DType bottom_col_0 = reduce(base_tile(2, 0), base_tile(3, 0)); - DType bottom_col_1 = reduce(base_tile(2, 1), base_tile(3, 1)); - bottom_cols[j] = reduce(bottom_col_0, bottom_col_1); - } + DType bottom_col_0 = reduce(base_tile(2, 0), base_tile(3, 0)); + DType bottom_col_1 = reduce(base_tile(2, 1), base_tile(3, 1)); + bottom_cols[j] = reduce(bottom_col_0, bottom_col_1); + } - DType top_col = top_cols[0]; - DType bottom_col = bottom_cols[0]; + DType top_col = top_cols[0]; + DType bottom_col = bottom_cols[0]; - // Compute the reduction of the top and bottom columns. + // Compute the reduction of the top and bottom columns. #pragma unroll - for (int j = 1; j < kRows; ++j) { - top_col = reduce(top_col, top_cols[j]); - bottom_col = reduce(bottom_col, bottom_cols[j]); - } - - // Shuffle the results to the leader thread. - top_col = reduce(top_col, shuffle_down_sync(MASK_ALL, top_col, 2)); - top_col = reduce(top_col, shuffle_down_sync(MASK_ALL, top_col, 1)); - bottom_col = - reduce(bottom_col, shuffle_down_sync(MASK_ALL, bottom_col, 2)); - bottom_col = - reduce(bottom_col, shuffle_down_sync(MASK_ALL, bottom_col, 1)); - - // Group the threads into groups of four, and broadcast the data - // from the first thread in each group to the other three threads. - top_col = shuffle_sync(MASK_ALL, top_col, leader); - bottom_col = shuffle_sync(MASK_ALL, bottom_col, leader); - - // Store the results to the destination tile. - dst(0, i) = top_col; - dst(1, i) = bottom_col; - } + for (int j = 1; j < kRows; ++j) { + top_col = reduce(top_col, top_cols[j]); + bottom_col = reduce(bottom_col, bottom_cols[j]); + } + + // Shuffle the results to the leader thread. + top_col = reduce(top_col, shuffle_down_sync(MASK_ALL, top_col, 2)); + top_col = reduce(top_col, shuffle_down_sync(MASK_ALL, top_col, 1)); + bottom_col = + reduce(bottom_col, shuffle_down_sync(MASK_ALL, bottom_col, 2)); + bottom_col = + reduce(bottom_col, shuffle_down_sync(MASK_ALL, bottom_col, 1)); + + // Group the threads into groups of four, and broadcast the data + // from the first thread in each group to the other three threads. + top_col = shuffle_sync(MASK_ALL, top_col, leader); + bottom_col = shuffle_sync(MASK_ALL, bottom_col, leader); + + // Store the results to the destination tile. + dst(0, i) = top_col; + dst(1, i) = bottom_col; } + } }; } // namespace detail template struct SumReduce { - using DType = typename RegTile::DType::DType; + using DType = typename RegTile::DType::DType; - template - DEVICE void operator()(const RegTile& src, DstTile& dst) { - detail::Reduce row_sum; - row_sum(src, dst, Add{}); - } + template + DEVICE void operator()(const RegTile& src, DstTile& dst) { + detail::Reduce row_sum; + row_sum(src, dst, Add{}); + } }; template struct MaxReduce { - using DType = typename RegTile::DType::DType; + using DType = typename RegTile::DType::DType; - template - DEVICE void operator()(const RegTile& src, DstTile& dst) { - detail::Reduce row_max; - row_max(src, dst, Max{}); - } + template + DEVICE void operator()(const RegTile& src, DstTile& dst) { + detail::Reduce row_max; + row_max(src, dst, Max{}); + } }; } // namespace tilefusion::cell::compute diff --git a/include/cell/copy/constants.hpp b/include/cell/copy/constants.hpp index 3c36a1bc..8e089034 100644 --- a/include/cell/copy/constants.hpp +++ b/include/cell/copy/constants.hpp @@ -5,11 +5,11 @@ namespace tilefusion::cell::copy { enum class WarpReuse { - // data are evenly partitioned to be loaded by warps. - kCont = 0, // all warps continuously load data, no reuse - kRowReuseCont = 1, // Row-wise even reuse, warps in the same row - // repeatedly load the same data - kColReuseCont = 2 // Column-wise even reuse, warps in the same column - // repeatedly load the same data + // data are evenly partitioned to be loaded by warps. + kCont = 0, // all warps continuously load data, no reuse + kRowReuseCont = 1, // Row-wise even reuse, warps in the same row + // repeatedly load the same data + kColReuseCont = 2 // Column-wise even reuse, warps in the same column + // repeatedly load the same data }; } // namespace tilefusion::cell::copy diff --git a/include/cell/copy/copy_atom.hpp b/include/cell/copy/copy_atom.hpp index 7650c62e..7cc51c9b 100644 --- a/include/cell/copy/copy_atom.hpp +++ b/include/cell/copy/copy_atom.hpp @@ -17,33 +17,33 @@ namespace tl = tile_layout; namespace { template DEVICE void ld_global_st_shared(uint32_t dst, void const* src) { - static_assert(kBytes == 4 || kBytes == 8 || kBytes == 16); + static_assert(kBytes == 4 || kBytes == 8 || kBytes == 16); #if (__CUDA_ARCH__ >= 800) - // SM90, hopper, SM80, SM86, ampere - // TODO(ying): add a wrapper to allow choosing between different caching - // policies (e.g. "cache all levels"). - asm volatile("cp.async.cg.shared.global [%0], [%1], %2;\n" ::"r"(dst), - "l"(src), "n"(kBytes)); + // SM90, hopper, SM80, SM86, ampere + // TODO(ying): add a wrapper to allow choosing between different caching + // policies (e.g. "cache all levels"). + asm volatile("cp.async.cg.shared.global [%0], [%1], %2;\n" ::"r"(dst), + "l"(src), "n"(kBytes)); #else - // SM75, turing - unsigned tmp[kBytes / 4]; - if constexpr (kBytes == 16) { - asm volatile("ld.global.v4.b32 {%0, %1, %2, %3}, [%4];\n" - : "=r"(tmp[0]), "=r"(tmp[1]), "=r"(tmp[2]), "=r"(tmp[3]) - : "l"(src)); - asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n" ::"r"(dst), - "r"(tmp[0]), "r"(tmp[1]), "r"(tmp[2]), "r"(tmp[3])); - } else if constexpr (kBytes == 8) { - asm volatile("ld.global.v2.b32 {%0, %1}, [%2];\n" - : "=r"(tmp[0]), "=r"(tmp[1]) - : "l"(src)); - asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n" ::"r"(dst), - "r"(tmp[0]), "r"(tmp[1])); - } else if constexpr (kBytes == 4) { - asm volatile("ld.global.b32 %0, [%1];\n" : "=r"(tmp[0]) : "l"(src)); - asm volatile("st.shared.b32 [%0], %1;\n" ::"r"(dst), "r"(tmp[0])); - } + // SM75, turing + unsigned tmp[kBytes / 4]; + if constexpr (kBytes == 16) { + asm volatile("ld.global.v4.b32 {%0, %1, %2, %3}, [%4];\n" + : "=r"(tmp[0]), "=r"(tmp[1]), "=r"(tmp[2]), "=r"(tmp[3]) + : "l"(src)); + asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n" ::"r"(dst), + "r"(tmp[0]), "r"(tmp[1]), "r"(tmp[2]), "r"(tmp[3])); + } else if constexpr (kBytes == 8) { + asm volatile("ld.global.v2.b32 {%0, %1}, [%2];\n" + : "=r"(tmp[0]), "=r"(tmp[1]) + : "l"(src)); + asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n" ::"r"(dst), "r"(tmp[0]), + "r"(tmp[1])); + } else if constexpr (kBytes == 4) { + asm volatile("ld.global.b32 %0, [%1];\n" : "=r"(tmp[0]) : "l"(src)); + asm volatile("st.shared.b32 [%0], %1;\n" ::"r"(dst), "r"(tmp[0])); + } #endif } @@ -54,36 +54,36 @@ DEVICE void ld_shared(void* dst, uint32_t src); /// ld.shared - 16b template <> DEVICE void ld_shared<2>(void* dst, uint32_t src) { - asm volatile("ld.shared.u16 %0, [%1];\n" - : "=h"(*reinterpret_cast(dst)) - : "r"(src)); + asm volatile("ld.shared.u16 %0, [%1];\n" + : "=h"(*reinterpret_cast(dst)) + : "r"(src)); } /// ld.shared - 32b template <> DEVICE void ld_shared<4>(void* dst, uint32_t src) { - asm volatile("ld.shared.u32 %0, [%1];\n" - : "=r"(*reinterpret_cast(dst)) - : "r"(src)); + asm volatile("ld.shared.u32 %0, [%1];\n" + : "=r"(*reinterpret_cast(dst)) + : "r"(src)); } /// ld.shared - 64b template <> DEVICE void ld_shared<8>(void* dst, uint32_t src) { - uint2* dst_u64 = reinterpret_cast(dst); - asm volatile("ld.shared.v2.u32 {%0, %1}, [%2];\n" - : "=r"(dst_u64->x), "=r"(dst_u64->y) - : "r"(src)); + uint2* dst_u64 = reinterpret_cast(dst); + asm volatile("ld.shared.v2.u32 {%0, %1}, [%2];\n" + : "=r"(dst_u64->x), "=r"(dst_u64->y) + : "r"(src)); } /// ld.shared - 128b template <> DEVICE void ld_shared<16>(void* dst, uint32_t src) { - uint4* dst_u128 = reinterpret_cast(dst); - asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];\n" - : "=r"(dst_u128->x), "=r"(dst_u128->y), "=r"(dst_u128->z), - "=r"(dst_u128->w) - : "r"(src)); + uint4* dst_u128 = reinterpret_cast(dst); + asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst_u128->x), "=r"(dst_u128->y), "=r"(dst_u128->z), + "=r"(dst_u128->w) + : "r"(src)); } /// st.shared @@ -93,36 +93,36 @@ DEVICE void st_shared(uint32_t dst, void const* src); /// st.shared - 16b template <> DEVICE void st_shared<2>(uint32_t dst, void const* src) { - asm volatile("st.shared.u16 [%0], %1;\n" - : - : "r"(dst), "h"(*reinterpret_cast(src))); + asm volatile("st.shared.u16 [%0], %1;\n" + : + : "r"(dst), "h"(*reinterpret_cast(src))); } /// st.shared - 32b template <> DEVICE void st_shared<4>(uint32_t dst, void const* src) { - asm volatile("st.shared.u32 [%0], %1;\n" - : - : "r"(dst), "r"(*reinterpret_cast(src))); + asm volatile("st.shared.u32 [%0], %1;\n" + : + : "r"(dst), "r"(*reinterpret_cast(src))); } /// st.shared - 64b template <> DEVICE void st_shared<8>(uint32_t dst, void const* src) { - uint2 const* dst_u64 = reinterpret_cast(src); - asm volatile("st.shared.v2.u32 [%0], {%1, %2};\n" - : - : "r"(dst), "r"(dst_u64->x), "r"(dst_u64->y)); + uint2 const* dst_u64 = reinterpret_cast(src); + asm volatile("st.shared.v2.u32 [%0], {%1, %2};\n" + : + : "r"(dst), "r"(dst_u64->x), "r"(dst_u64->y)); } /// st.shared - 128b template <> DEVICE void st_shared<16>(uint32_t dst, void const* src) { - uint4 const* dst_u128 = reinterpret_cast(src); - asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};\n" - : - : "r"(dst), "r"(dst_u128->x), "r"(dst_u128->y), - "r"(dst_u128->z), "r"(dst_u128->w)); + uint4 const* dst_u128 = reinterpret_cast(src); + asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};\n" + : + : "r"(dst), "r"(dst_u128->x), "r"(dst_u128->y), "r"(dst_u128->z), + "r"(dst_u128->w)); } /// st.global @@ -131,11 +131,11 @@ DEVICE void st_global(void* dst, const void* src); template <> DEVICE void st_global<16>(void* dst, const void* src) { - uint4 const* dst_u128 = reinterpret_cast(src); - asm volatile("st.global.v4.b32 [%0], {%1, %2, %3, %4};\n" - : - : "l"(dst), "r"(dst_u128->x), "r"(dst_u128->y), - "r"(dst_u128->z), "r"(dst_u128->w)); + uint4 const* dst_u128 = reinterpret_cast(src); + asm volatile("st.global.v4.b32 [%0], {%1, %2, %3, %4};\n" + : + : "l"(dst), "r"(dst_u128->x), "r"(dst_u128->y), "r"(dst_u128->z), + "r"(dst_u128->w)); } template @@ -143,59 +143,58 @@ DEVICE void ld_shared_st_global(void* dst, uint32_t src); template <> DEVICE void ld_shared_st_global<16>(void* dst, uint32_t src) { - unsigned tmp[4]; - ld_shared<16>(tmp, src); - st_global<16>(dst, tmp); + unsigned tmp[4]; + ld_shared<16>(tmp, src); + st_global<16>(dst, tmp); } } // namespace template - requires HalfType + requires HalfType struct LoadMatBase { - using DType = Element; - using ThreadLayout = tile_layout::ColMajor<16, 2>; - - static constexpr int kAccessInBits = 128; // 128 bits - static constexpr int kElmentBits = sizeof(DType) * 8; - static constexpr int kNumPerAccess = kAccessInBits / kElmentBits; - - /// @brief Returns the lane row of the current thread within a warp. - // For ldmatrix, threads in a warp are arranged in a 16x2 - // column-major layout: - // - // | | 0 | 1| - // |--|---|---| - // |0 | 0 | 16| - // |1 | 2 | 17| - // |2 | 4 | 18| - // | |...|...| - // |15| 15| 31| - /// For example, if threadIdx.x is 43, its lane_row is 8 and lane_col is 0. - - /// @brief Returns the lane row of the current thread within a warp. - DEVICE int lane_row_id() { - int lane_id = threadIdx.x % WARP_SIZE; - return lane_id % tl::num_rows; - } - - /// @brief returns the lane col of the current thread within a warp. - DEVICE int lane_col_id() { - int lane_id = threadIdx.x % WARP_SIZE; - return lane_id / tl::num_rows; - } - - /// @brief a thin wrapper for executing ldmatrix instruction to load a - /// `16x16` tile to register. - DEVICE void ldmatrix(const DType* src, DType* dst) { - uint32_t* reg = reinterpret_cast(dst); - uint32_t smem_addr = - static_cast(__cvta_generic_to_shared(src)); - - asm volatile( - "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" - : "=r"(reg[0]), "=r"(reg[1]), "=r"(reg[2]), "=r"(reg[3]) - : "r"(smem_addr)); - } + using DType = Element; + using ThreadLayout = tile_layout::ColMajor<16, 2>; + + static constexpr int kAccessInBits = 128; // 128 bits + static constexpr int kElmentBits = sizeof(DType) * 8; + static constexpr int kNumPerAccess = kAccessInBits / kElmentBits; + + /// @brief Returns the lane row of the current thread within a warp. + // For ldmatrix, threads in a warp are arranged in a 16x2 + // column-major layout: + // + // | | 0 | 1| + // |--|---|---| + // |0 | 0 | 16| + // |1 | 2 | 17| + // |2 | 4 | 18| + // | |...|...| + // |15| 15| 31| + /// For example, if threadIdx.x is 43, its lane_row is 8 and lane_col is 0. + + /// @brief Returns the lane row of the current thread within a warp. + DEVICE int lane_row_id() { + int lane_id = threadIdx.x % WARP_SIZE; + return lane_id % tl::num_rows; + } + + /// @brief returns the lane col of the current thread within a warp. + DEVICE int lane_col_id() { + int lane_id = threadIdx.x % WARP_SIZE; + return lane_id / tl::num_rows; + } + + /// @brief a thin wrapper for executing ldmatrix instruction to load a + /// `16x16` tile to register. + DEVICE void ldmatrix(const DType* src, DType* dst) { + uint32_t* reg = reinterpret_cast(dst); + uint32_t smem_addr = static_cast(__cvta_generic_to_shared(src)); + + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(reg[0]), "=r"(reg[1]), "=r"(reg[2]), "=r"(reg[3]) + : "r"(smem_addr)); + } }; template @@ -204,64 +203,64 @@ struct StoreMatBase; /// TODO(haruhi): try to reduece reusable codes. template struct StoreMatBase { - using DType = Shared::DType; + using DType = Shared::DType; - // the thread layout for wmma's output tile. - using ThreadLayout = tile_layout::RowMajor<8, 4>; + // the thread layout for wmma's output tile. + using ThreadLayout = tile_layout::RowMajor<8, 4>; - static constexpr int kThreadRows = tl::num_rows; - static constexpr int kThreadCols = tl::num_cols; + static constexpr int kThreadRows = tl::num_rows; + static constexpr int kThreadCols = tl::num_cols; - // in the output of a wmma tile, each thread stores four segments in 2x2 - // layout, and each fragment contains 2 elements regardless of the data - // type - static constexpr int kSegRows = 2; - static constexpr int kSegCols = 2; + // in the output of a wmma tile, each thread stores four segments in 2x2 + // layout, and each fragment contains 2 elements regardless of the data + // type + static constexpr int kSegRows = 2; + static constexpr int kSegCols = 2; - // the number of elements per segment, vectorized instruction are used to - // access `kElemPerSeg` elements. - static constexpr int kElemPerSeg = 2; + // the number of elements per segment, vectorized instruction are used to + // access `kElemPerSeg` elements. + static constexpr int kElemPerSeg = 2; - static constexpr int kAccessInBits = kElemPerSeg * int(sizeof(DType) * 8); + static constexpr int kAccessInBits = kElemPerSeg * int(sizeof(DType) * 8); - DEVICE int lane_row_id() { - return (threadIdx.x % WARP_SIZE) / tl::num_cols; - } + DEVICE int lane_row_id() { + return (threadIdx.x % WARP_SIZE) / tl::num_cols; + } - DEVICE int lane_col_id() { - return (threadIdx.x % WARP_SIZE) % tl::num_cols; - } + DEVICE int lane_col_id() { + return (threadIdx.x % WARP_SIZE) % tl::num_cols; + } }; /// TODO(haruhi): try to reduece reusable codes. template struct StoreMatBase { - using DType = Shared::DType; + using DType = Shared::DType; - // the thread layout for wmma's output tile. - using ThreadLayout = tile_layout::ColMajor<4, 8>; + // the thread layout for wmma's output tile. + using ThreadLayout = tile_layout::ColMajor<4, 8>; - static constexpr int kThreadRows = tl::num_rows; - static constexpr int kThreadCols = tl::num_cols; + static constexpr int kThreadRows = tl::num_rows; + static constexpr int kThreadCols = tl::num_cols; - // in the output of a wmma tile, each thread stores four segments in 2x2 - // layout, and each fragment contains 2 elements regardless of the data - // type - static constexpr int kSegRows = 2; - static constexpr int kSegCols = 2; + // in the output of a wmma tile, each thread stores four segments in 2x2 + // layout, and each fragment contains 2 elements regardless of the data + // type + static constexpr int kSegRows = 2; + static constexpr int kSegCols = 2; - // the number of elements per segment, vectorized instruction are used to - // access `kElemPerSeg` elements. - static constexpr int kElemPerSeg = 2; - static constexpr int kAccessInBits = kElemPerSeg * int(sizeof(DType) * 8); + // the number of elements per segment, vectorized instruction are used to + // access `kElemPerSeg` elements. + static constexpr int kElemPerSeg = 2; + static constexpr int kAccessInBits = kElemPerSeg * int(sizeof(DType) * 8); - DEVICE int lane_row_id() { - return (threadIdx.x % WARP_SIZE) % tl::num_rows; - } + DEVICE int lane_row_id() { + return (threadIdx.x % WARP_SIZE) % tl::num_rows; + } - DEVICE int lane_col_id() { - return (threadIdx.x % WARP_SIZE) / tl::num_rows; - } + DEVICE int lane_col_id() { + return (threadIdx.x % WARP_SIZE) / tl::num_rows; + } }; } // namespace tilefusion::cell::copy::atom diff --git a/include/cell/copy/global_to_register.hpp b/include/cell/copy/global_to_register.hpp index 5957b7e0..a3e2f1fb 100644 --- a/include/cell/copy/global_to_register.hpp +++ b/include/cell/copy/global_to_register.hpp @@ -12,10 +12,10 @@ namespace tilefusion::cell::copy { namespace { struct BaseTileConfig { - static constexpr int kRows = 16; - static constexpr int kCols = 16; - static constexpr int kThreadRows = 8; - static constexpr int kThreadCols = 4; + static constexpr int kRows = 16; + static constexpr int kCols = 16; + static constexpr int kThreadRows = 8; + static constexpr int kThreadCols = 4; }; } // namespace @@ -30,87 +30,87 @@ struct GlobalToRegLoaderImpl; template struct GlobalToRegLoaderImpl { - using Global = Global_; - using Reg = Reg_; - using DType = typename Global::DType; + using Global = Global_; + using Reg = Reg_; + using DType = typename Global::DType; - DEVICE void operator()(const DType* src, Reg& dst) { - int lane_id = threadIdx.x % WARP_SIZE; - const DType* data; - int land_row = lane_id / kThreadCol; - int land_col = lane_id % kThreadCol * 2; + DEVICE void operator()(const DType* src, Reg& dst) { + int lane_id = threadIdx.x % WARP_SIZE; + const DType* data; + int land_row = lane_id / kThreadCol; + int land_col = lane_id % kThreadCol * 2; - Vectorize copy; + Vectorize copy; #pragma unroll - for (int i = 0; i < kRowExec; ++i) { - int row = i * kCols + land_row; + for (int i = 0; i < kRowExec; ++i) { + int row = i * kCols + land_row; #pragma unroll - for (int j = 0; j < kColExec; ++j) { - int col = j * kRows + land_col; - data = src + row * kStride + col; - - copy(data, &dst(i, j)(0, 0)); - copy(data + 8, &dst(i, j)(1, 0)); - copy(data + 8 * kStride, &dst(i, j)(0, 2)); - copy(data + 8 * kStride + 8, &dst(i, j)(1, 2)); - } - } + for (int j = 0; j < kColExec; ++j) { + int col = j * kRows + land_col; + data = src + row * kStride + col; + + copy(data, &dst(i, j)(0, 0)); + copy(data + 8, &dst(i, j)(1, 0)); + copy(data + 8 * kStride, &dst(i, j)(0, 2)); + copy(data + 8 * kStride + 8, &dst(i, j)(1, 2)); + } } - - private: - // pre-computed values - static constexpr int kThreadCol = BaseTileConfig::kThreadCols; - static constexpr int kRows = BaseTileConfig::kRows; - static constexpr int kCols = BaseTileConfig::kCols; - static constexpr int kStride = Global::kRowStride; - - // how many times a `BaseTile` is executed along the row and column - // direction. - static constexpr int kRowExec = Reg::kRows; - static constexpr int kColExec = Reg::kCols; + } + + private: + // pre-computed values + static constexpr int kThreadCol = BaseTileConfig::kThreadCols; + static constexpr int kRows = BaseTileConfig::kRows; + static constexpr int kCols = BaseTileConfig::kCols; + static constexpr int kStride = Global::kRowStride; + + // how many times a `BaseTile` is executed along the row and column + // direction. + static constexpr int kRowExec = Reg::kRows; + static constexpr int kColExec = Reg::kCols; }; template struct GlobalToRegLoaderImpl { - using Global = Global_; - using Reg = Reg_; - using DType = typename Global::DType; + using Global = Global_; + using Reg = Reg_; + using DType = typename Global::DType; - DEVICE void operator()(const DType* src, Reg& dst) { - int lane_id = threadIdx.x % WARP_SIZE; + DEVICE void operator()(const DType* src, Reg& dst) { + int lane_id = threadIdx.x % WARP_SIZE; - int land_row = lane_id / kThreadCol; - int land_col = lane_id % kThreadCol * 2; + int land_row = lane_id / kThreadCol; + int land_col = lane_id % kThreadCol * 2; - const DType* data; - Vectorize copy; + const DType* data; + Vectorize copy; #pragma unroll - for (int i = 0; i < kColExec; ++i) { - int col = i * BaseTileConfig::kRows + land_row; - for (int j = 0; j < kRowExec; ++j) { - int row = j * BaseTileConfig::kCols + land_col; - data = src + col * kStride + row; - - copy(data, &dst(j, i)(0, 0)); - copy(data + 8, &dst(j, i)(0, 1)); - copy(data + 8 * kStride, &dst(j, i)(2, 0)); - copy(data + 8 * kStride + 8, &dst(j, i)(2, 1)); - } - } + for (int i = 0; i < kColExec; ++i) { + int col = i * BaseTileConfig::kRows + land_row; + for (int j = 0; j < kRowExec; ++j) { + int row = j * BaseTileConfig::kCols + land_col; + data = src + col * kStride + row; + + copy(data, &dst(j, i)(0, 0)); + copy(data + 8, &dst(j, i)(0, 1)); + copy(data + 8 * kStride, &dst(j, i)(2, 0)); + copy(data + 8 * kStride + 8, &dst(j, i)(2, 1)); + } } - - private: - // pre-computed values - static constexpr int kThreadCol = BaseTileConfig::kThreadCols; - static constexpr int kRows = BaseTileConfig::kRows; - static constexpr int kCols = BaseTileConfig::kCols; - static constexpr int kStride = Global::kColStride; - - // how many times a `BaseTile` is executed along the row and column - // direction. - static constexpr int kRowExec = Reg::kRows; - static constexpr int kColExec = Reg::kCols; + } + + private: + // pre-computed values + static constexpr int kThreadCol = BaseTileConfig::kThreadCols; + static constexpr int kRows = BaseTileConfig::kRows; + static constexpr int kCols = BaseTileConfig::kCols; + static constexpr int kStride = Global::kColStride; + + // how many times a `BaseTile` is executed along the row and column + // direction. + static constexpr int kRowExec = Reg::kRows; + static constexpr int kColExec = Reg::kCols; }; /** @@ -124,87 +124,87 @@ struct RegToGlobalStorerImpl; template struct RegToGlobalStorerImpl { - using Global = Global_; - using Reg = Reg_; - using DType = typename Global::DType; + using Global = Global_; + using Reg = Reg_; + using DType = typename Global::DType; - DEVICE void operator()(const Reg& src, DType* dst) { - int lane_id = threadIdx.x % WARP_SIZE; - DType* data; + DEVICE void operator()(const Reg& src, DType* dst) { + int lane_id = threadIdx.x % WARP_SIZE; + DType* data; - int land_row = lane_id / kThreadCol; - int land_col = lane_id % kThreadCol * 2; + int land_row = lane_id / kThreadCol; + int land_col = lane_id % kThreadCol * 2; - Vectorize copy; + Vectorize copy; #pragma unroll - for (int i = 0; i < kRowExec; ++i) { - int row = i * kCols + land_row; + for (int i = 0; i < kRowExec; ++i) { + int row = i * kCols + land_row; #pragma unroll - for (int j = 0; j < kColExec; ++j) { - int col = j * kRows + land_col; - data = dst + row * kStride + col; - - copy(&src(i, j)(0, 0), data); - copy(&src(i, j)(1, 0), data + 8); - copy(&src(i, j)(0, 2), data + 8 * kStride); - copy(&src(i, j)(1, 2), data + 8 * kStride + 8); - } - } + for (int j = 0; j < kColExec; ++j) { + int col = j * kRows + land_col; + data = dst + row * kStride + col; + + copy(&src(i, j)(0, 0), data); + copy(&src(i, j)(1, 0), data + 8); + copy(&src(i, j)(0, 2), data + 8 * kStride); + copy(&src(i, j)(1, 2), data + 8 * kStride + 8); + } } - - private: - // pre-computed values - static constexpr int kThreadCol = BaseTileConfig::kThreadCols; - static constexpr int kRows = BaseTileConfig::kRows; - static constexpr int kCols = BaseTileConfig::kCols; - static constexpr int kStride = Global::kRowStride; - - // how many times a `BaseTile` is executed along the row and column - // direction. - static constexpr int kRowExec = Reg::kRows; - static constexpr int kColExec = Reg::kCols; + } + + private: + // pre-computed values + static constexpr int kThreadCol = BaseTileConfig::kThreadCols; + static constexpr int kRows = BaseTileConfig::kRows; + static constexpr int kCols = BaseTileConfig::kCols; + static constexpr int kStride = Global::kRowStride; + + // how many times a `BaseTile` is executed along the row and column + // direction. + static constexpr int kRowExec = Reg::kRows; + static constexpr int kColExec = Reg::kCols; }; template struct RegToGlobalStorerImpl { - using Global = Global_; - using Reg = Reg_; - using DType = typename Global::DType; + using Global = Global_; + using Reg = Reg_; + using DType = typename Global::DType; - DEVICE void operator()(const Reg& src, DType* dst) { - int lane_id = threadIdx.x % WARP_SIZE; - DType* data; + DEVICE void operator()(const Reg& src, DType* dst) { + int lane_id = threadIdx.x % WARP_SIZE; + DType* data; - int land_row = lane_id / kThreadCol; - int land_col = lane_id % kThreadCol * 2; + int land_row = lane_id / kThreadCol; + int land_col = lane_id % kThreadCol * 2; - Vectorize copy; + Vectorize copy; #pragma unroll - for (int i = 0; i < kColExec; ++i) { - int col = i * kRows + land_row; + for (int i = 0; i < kColExec; ++i) { + int col = i * kRows + land_row; #pragma unroll - for (int j = 0; j < kRowExec; ++j) { - int row = j * kCols + land_col; - data = dst + col * kStride + row; - - copy(&src(j, i)(0, 0), data); - copy(&src(j, i)(0, 1), data + 8); - copy(&src(j, i)(2, 0), data + 8 * kStride); - copy(&src(j, i)(2, 1), data + 8 * kStride + 8); - } - } + for (int j = 0; j < kRowExec; ++j) { + int row = j * kCols + land_col; + data = dst + col * kStride + row; + + copy(&src(j, i)(0, 0), data); + copy(&src(j, i)(0, 1), data + 8); + copy(&src(j, i)(2, 0), data + 8 * kStride); + copy(&src(j, i)(2, 1), data + 8 * kStride + 8); + } } - - private: - // pre-computed values - static constexpr int kThreadCol = BaseTileConfig::kThreadCols; - static constexpr int kRows = BaseTileConfig::kRows; - static constexpr int kCols = BaseTileConfig::kCols; - static constexpr int kStride = Global::kColStride; - // how many times a `BaseTile` is executed along the row and column - // direction. - static constexpr int kRowExec = Reg::kRows; - static constexpr int kColExec = Reg::kCols; + } + + private: + // pre-computed values + static constexpr int kThreadCol = BaseTileConfig::kThreadCols; + static constexpr int kRows = BaseTileConfig::kRows; + static constexpr int kCols = BaseTileConfig::kCols; + static constexpr int kStride = Global::kColStride; + // how many times a `BaseTile` is executed along the row and column + // direction. + static constexpr int kRowExec = Reg::kRows; + static constexpr int kColExec = Reg::kCols; }; /** @@ -216,26 +216,26 @@ struct RegToGlobalStorerImpl { */ template struct GlobalToRegLoader { - using Reg = Reg_; - using DType = typename Reg::DType::DType; - using WarpLayout = WarpLayout_; - static constexpr WarpReuse kMode = kMode_; - - template - DEVICE void operator()(const Global& src, Reg& dst) { - // advance the pointer to input data to the current warp - // according to warp reuse mode. - int offset = global_offset_.template get_warp_offset(); - - using Loader = GlobalToRegLoaderImpl; - Loader loader; - loader(src.data() + offset, dst); - } - - private: - using GlobalOffset = warp::GlobalOffsetHelper; - - GlobalOffset global_offset_; + using Reg = Reg_; + using DType = typename Reg::DType::DType; + using WarpLayout = WarpLayout_; + static constexpr WarpReuse kMode = kMode_; + + template + DEVICE void operator()(const Global& src, Reg& dst) { + // advance the pointer to input data to the current warp + // according to warp reuse mode. + int offset = global_offset_.template get_warp_offset(); + + using Loader = GlobalToRegLoaderImpl; + Loader loader; + loader(src.data() + offset, dst); + } + + private: + using GlobalOffset = warp::GlobalOffsetHelper; + + GlobalOffset global_offset_; }; /** @@ -247,25 +247,25 @@ struct GlobalToRegLoader { */ template struct RegToGlobalStorer { - using Global = Global_; - using Reg = Reg_; - using DType = typename Global::DType; - using WarpLayout = WarpLayout_; + using Global = Global_; + using Reg = Reg_; + using DType = typename Global::DType; + using WarpLayout = WarpLayout_; - DEVICE void operator()(const Reg& src, Global& dst) { - DType* dst_ptr = dst.mutable_data(); + DEVICE void operator()(const Reg& src, Global& dst) { + DType* dst_ptr = dst.mutable_data(); - // advance the pointer to output data to the current warp - // according to warp reuse mode. - int offset = global_offset_.template get_warp_offset(); + // advance the pointer to output data to the current warp + // according to warp reuse mode. + int offset = global_offset_.template get_warp_offset(); - using Storer = RegToGlobalStorerImpl; - Storer storer; - storer(src, dst_ptr + offset); - } + using Storer = RegToGlobalStorerImpl; + Storer storer; + storer(src, dst_ptr + offset); + } - using GlobalOffset = warp::GlobalOffsetHelper; + using GlobalOffset = warp::GlobalOffsetHelper; - GlobalOffset global_offset_; + GlobalOffset global_offset_; }; } // namespace tilefusion::cell::copy diff --git a/include/cell/copy/global_to_shared.hpp b/include/cell/copy/global_to_shared.hpp index 3da6101b..e4d5081e 100644 --- a/include/cell/copy/global_to_shared.hpp +++ b/include/cell/copy/global_to_shared.hpp @@ -30,171 +30,167 @@ template struct GlobalToSharedLoaderImpl { - using Global = Global_; - using Shared = Shared_; - using DType = Global::DType; - using BaseShape = BaseShape_; - - static_assert(Global::kRows == Shared::kRows && - Global::kCols == Shared::kCols, - "Global and shared memory should have the same shape."); - static_assert(Global::kType == Shared::kType, - "The layout of Global memory and Shared memory tile should " - "be the same."); - static_assert(Global::kType == tl::Layout::kRowMajor, - "The layout of Global memory and Shared memory tile should " - "be row-major."); - static_assert(std::is_same_v, - "The data type of Shared and Global must be the same."); - - static constexpr int kRowExec = kRowExec_; - static constexpr int kColExec = kColExec_; - - DEVICE void operator()(const DType* src, DType* dst, int warp_offset) { - int row = lane_row_id(); - int col = lane_col_id() * kNumPerAccess; - - int src_offset = 0, dst_offset = 0, offset = 0; - uint32_t dst_ptr; + using Global = Global_; + using Shared = Shared_; + using DType = Global::DType; + using BaseShape = BaseShape_; + + static_assert(Global::kRows == Shared::kRows && + Global::kCols == Shared::kCols, + "Global and shared memory should have the same shape."); + static_assert(Global::kType == Shared::kType, + "The layout of Global memory and Shared memory tile should " + "be the same."); + static_assert(Global::kType == tl::Layout::kRowMajor, + "The layout of Global memory and Shared memory tile should " + "be row-major."); + static_assert(std::is_same_v, + "The data type of Shared and Global must be the same."); + + static constexpr int kRowExec = kRowExec_; + static constexpr int kColExec = kColExec_; + + DEVICE void operator()(const DType* src, DType* dst, int warp_offset) { + int row = lane_row_id(); + int col = lane_col_id() * kNumPerAccess; + + int src_offset = 0, dst_offset = 0, offset = 0; + uint32_t dst_ptr; #pragma unroll - for (int i = 0; i < kRowExec; ++i) { + for (int i = 0; i < kRowExec; ++i) { #pragma unroll - for (int j = 0; j < kColExec; ++j) { - src_offset = src_tile_(i, j) + in_src_tile_(row, col); - offset = warp_offset + - i * BaseShape::kRows * Shared::kRowStride + - j * BaseShape::kCols + row * Shared::kRowStride + col; - - dst_offset = - shared_tile.fetch_physical_offset(offset) - warp_offset; - - dst_ptr = static_cast( - __cvta_generic_to_shared(dst + dst_offset)); - ld_global_st_shared(dst_ptr, src + src_offset); - } - } - } - - private: - static constexpr int kNumPerAccess = AccessBase::kNumPerAccess; - - static constexpr int kAccessInBytes = AccessBase::kAccessInBytes; - - using SrcLayout = tl::MatrixLayout; - SrcLayout src_tile_; + for (int j = 0; j < kColExec; ++j) { + src_offset = src_tile_(i, j) + in_src_tile_(row, col); + offset = warp_offset + i * BaseShape::kRows * Shared::kRowStride + + j * BaseShape::kCols + row * Shared::kRowStride + col; - // Given a thread index, the GlobalLayout and SharedLayout below return the - // data offset from which the thread should load from the global memory tile - // and where to store it in the shared memory tile, respectively. - using InSrcLayout = tl::MatrixLayout; - - // `in_src_tile_` is a basetile handled by a single warp. - InSrcLayout in_src_tile_; - - Shared shared_tile; - - /// @brief returns the lane row of the current thread within a warp. - DEVICE int lane_row_id() { - // NOTE: When copying a RowMajor data tile, the thread layout is - // interpreted as RowMajor. - int lane_id = threadIdx.x % WARP_SIZE; - return lane_id / BaseShape::kColThreads; - } + dst_offset = shared_tile.fetch_physical_offset(offset) - warp_offset; - /// @brief returns the lane col of the current thread within a warp. - DEVICE int lane_col_id() { - // NOTE: When copying a RowMajor data tile, the thread layout is - // interpreted as RowMajor. - int lane_id = threadIdx.x % WARP_SIZE; - return lane_id % BaseShape::kColThreads; + dst_ptr = + static_cast(__cvta_generic_to_shared(dst + dst_offset)); + ld_global_st_shared(dst_ptr, src + src_offset); + } } + } + + private: + static constexpr int kNumPerAccess = AccessBase::kNumPerAccess; + + static constexpr int kAccessInBytes = AccessBase::kAccessInBytes; + + using SrcLayout = + tl::MatrixLayout; + SrcLayout src_tile_; + + // Given a thread index, the GlobalLayout and SharedLayout below return the + // data offset from which the thread should load from the global memory tile + // and where to store it in the shared memory tile, respectively. + using InSrcLayout = tl::MatrixLayout; + + // `in_src_tile_` is a basetile handled by a single warp. + InSrcLayout in_src_tile_; + + Shared shared_tile; + + /// @brief returns the lane row of the current thread within a warp. + DEVICE int lane_row_id() { + // NOTE: When copying a RowMajor data tile, the thread layout is + // interpreted as RowMajor. + int lane_id = threadIdx.x % WARP_SIZE; + return lane_id / BaseShape::kColThreads; + } + + /// @brief returns the lane col of the current thread within a warp. + DEVICE int lane_col_id() { + // NOTE: When copying a RowMajor data tile, the thread layout is + // interpreted as RowMajor. + int lane_id = threadIdx.x % WARP_SIZE; + return lane_id % BaseShape::kColThreads; + } }; template struct GlobalToSharedLoaderImpl { - using Global = Global_; - using Shared = Shared_; - using DType = Global::DType; - using BaseShape = BaseShape_; - - static_assert(Global::kRows == Shared::kRows && - Global::kCols == Shared::kCols, - "Global and shared memory should have the same shape."); - static_assert(Global::kType == Shared::kType, - "The layout of Global memory and Shared memory tile should " - "be the same."); - static_assert(Global::kType == tl::Layout::kColMajor, - "The layout of Global memory and Shared memory tile should " - "be column-major."); - - static_assert(std::is_same_v, - "The data type of Shared and Global must be the same."); - - static constexpr int kRowExec = kRowExec_; - static constexpr int kColExec = kColExec_; - - DEVICE void operator()(const DType* src, DType* dst, int warp_offset) { - int lane_row = lane_row_id() * kNumPerAccess; - int lane_col = lane_col_id(); - - int src_offset = 0, dst_offset = 0; - int offset = 0; - uint32_t dst_ptr; + using Global = Global_; + using Shared = Shared_; + using DType = Global::DType; + using BaseShape = BaseShape_; + + static_assert(Global::kRows == Shared::kRows && + Global::kCols == Shared::kCols, + "Global and shared memory should have the same shape."); + static_assert(Global::kType == Shared::kType, + "The layout of Global memory and Shared memory tile should " + "be the same."); + static_assert(Global::kType == tl::Layout::kColMajor, + "The layout of Global memory and Shared memory tile should " + "be column-major."); + + static_assert(std::is_same_v, + "The data type of Shared and Global must be the same."); + + static constexpr int kRowExec = kRowExec_; + static constexpr int kColExec = kColExec_; + + DEVICE void operator()(const DType* src, DType* dst, int warp_offset) { + int lane_row = lane_row_id() * kNumPerAccess; + int lane_col = lane_col_id(); + + int src_offset = 0, dst_offset = 0; + int offset = 0; + uint32_t dst_ptr; #pragma unroll - for (int j = 0; j < kColExec; ++j) { + for (int j = 0; j < kColExec; ++j) { #pragma unroll - for (int i = 0; i < kRowExec; ++i) { - src_offset = src_tile_(i, j) + in_src_tile_(lane_row, lane_col); - offset = warp_offset + - j * BaseShape::kCols * Shared::kColStride + - i * BaseShape::kRows + lane_col * Shared::kColStride + - lane_row; - dst_offset = - shared_tile.fetch_physical_offset(offset) - warp_offset; - - dst_ptr = static_cast( - __cvta_generic_to_shared(dst + dst_offset)); - ld_global_st_shared(dst_ptr, src + src_offset); - } - } + for (int i = 0; i < kRowExec; ++i) { + src_offset = src_tile_(i, j) + in_src_tile_(lane_row, lane_col); + offset = warp_offset + j * BaseShape::kCols * Shared::kColStride + + i * BaseShape::kRows + lane_col * Shared::kColStride + + lane_row; + dst_offset = shared_tile.fetch_physical_offset(offset) - warp_offset; + + dst_ptr = + static_cast(__cvta_generic_to_shared(dst + dst_offset)); + ld_global_st_shared(dst_ptr, src + src_offset); + } } + } - private: - static constexpr int kNumPerAccess = AccessBase::kNumPerAccess; + private: + static constexpr int kNumPerAccess = AccessBase::kNumPerAccess; - static constexpr int kAccessInBytes = AccessBase::kAccessInBytes; + static constexpr int kAccessInBytes = AccessBase::kAccessInBytes; - using SrcLayout = tl::MatrixLayout; - SrcLayout src_tile_; + using SrcLayout = tl::MatrixLayout; + SrcLayout src_tile_; - // Given a thread index, the GlobalLayout and SharedLayout below return the - // data offset from which the thread should load from the global memory tile - // and where to store it in the shared memory tile, respectively. - using GlobalLayout = tl::MatrixLayout; + // Given a thread index, the GlobalLayout and SharedLayout below return the + // data offset from which the thread should load from the global memory tile + // and where to store it in the shared memory tile, respectively. + using GlobalLayout = tl::MatrixLayout; - // `src_tile_` is a basetile handled by a single warp. - GlobalLayout in_src_tile_; + // `src_tile_` is a basetile handled by a single warp. + GlobalLayout in_src_tile_; - Shared shared_tile; + Shared shared_tile; - /// @brief returns the lane row of the current thread within a warp. - DEVICE int lane_row_id() { - int lane_id = threadIdx.x % WARP_SIZE; - return lane_id / BaseShape::kColThreads; - } + /// @brief returns the lane row of the current thread within a warp. + DEVICE int lane_row_id() { + int lane_id = threadIdx.x % WARP_SIZE; + return lane_id / BaseShape::kColThreads; + } - /// @brief returns the lane col of the current thread within a warp. - DEVICE int lane_col_id() { - int lane_id = threadIdx.x % WARP_SIZE; - return lane_id % BaseShape::kColThreads; - } + /// @brief returns the lane col of the current thread within a warp. + DEVICE int lane_col_id() { + int lane_id = threadIdx.x % WARP_SIZE; + return lane_id % BaseShape::kColThreads; + } }; template struct SharedToGlobalStorerImpl { - using Shared = Shared_; - using Global = Global_; - using DType = Shared::DType; - using BaseShape = BaseShape_; - static_assert(Global::kRows == Shared::kRows && - Global::kCols == Shared::kCols, - "Global and shared memory should have the same shape."); - static_assert(Global::kType == Shared::kType, - "The layout of Global memory and Shared memory tile should " - "be the same."); - static_assert(Global::kType == tl::Layout::kRowMajor, - "The layout of Global memory and Shared memory tile should " - "be row-major."); - static_assert(std::is_same_v, - "The data type of Shared and Global must be the same."); - - static constexpr int kRowExec = kRowExec_; - static constexpr int kColExec = kColExec_; - - DEVICE void operator()(const DType* src, DType* dst, int warp_offset) { - int row = lane_row_id(); - int col = lane_col_id() * kNumPerAccess; - - uint32_t src_ptr; - int src_offset = 0, dst_offset = 0; - int offset = 0; + using Shared = Shared_; + using Global = Global_; + using DType = Shared::DType; + using BaseShape = BaseShape_; + static_assert(Global::kRows == Shared::kRows && + Global::kCols == Shared::kCols, + "Global and shared memory should have the same shape."); + static_assert(Global::kType == Shared::kType, + "The layout of Global memory and Shared memory tile should " + "be the same."); + static_assert(Global::kType == tl::Layout::kRowMajor, + "The layout of Global memory and Shared memory tile should " + "be row-major."); + static_assert(std::is_same_v, + "The data type of Shared and Global must be the same."); + + static constexpr int kRowExec = kRowExec_; + static constexpr int kColExec = kColExec_; + + DEVICE void operator()(const DType* src, DType* dst, int warp_offset) { + int row = lane_row_id(); + int col = lane_col_id() * kNumPerAccess; + + uint32_t src_ptr; + int src_offset = 0, dst_offset = 0; + int offset = 0; #pragma unroll - for (int i = 0; i < kRowExec; ++i) { + for (int i = 0; i < kRowExec; ++i) { #pragma unroll - for (int j = 0; j < kColExec; ++j) { - offset = warp_offset + - i * BaseShape::kRows * Shared::kRowStride + - j * BaseShape::kCols + row * Shared::kRowStride + col; - src_offset = - shared_tile.fetch_physical_offset(offset) - warp_offset; - dst_offset = dst_tile_(i, j) + in_dst_tile_(row, col); - - src_ptr = static_cast( - __cvta_generic_to_shared(src + src_offset)); - ld_shared_st_global(dst + dst_offset, src_ptr); - } - } - } - - private: - static constexpr int kAccessInBytes = AccessBase::kAccessInBytes; - - using DstLayout = tl::MatrixLayout; - DstLayout dst_tile_; - - // NOTE: DO NOT modify `kNumPerAccess` and `kAccessInBits` here. - // `kAccessInBits` in the storer is for tensor core's output where only two - // numbers are contiguous in memory. This ensures the parameters remain - // consistent with those used in `SharedLayoutWrapper` within the - // register-to-shared storer. - static constexpr int kAccessInBits = 2 * int(sizeof(DType) * 8); - static constexpr int kNumPerAccess = AccessBase::kNumPerAccess; - - using GlobalLayout = tl::MatrixLayout; - GlobalLayout in_dst_tile_; - - Shared shared_tile; - - /// @brief returns the lane col of the current thread within a warp. - DEVICE int lane_row_id() { - return (threadIdx.x % WARP_SIZE) / BaseShape::kColThreads; - } - - /// @brief returns the lane col of the current thread within a warp. - DEVICE int lane_col_id() { - return (threadIdx.x % WARP_SIZE) % BaseShape::kColThreads; + for (int j = 0; j < kColExec; ++j) { + offset = warp_offset + i * BaseShape::kRows * Shared::kRowStride + + j * BaseShape::kCols + row * Shared::kRowStride + col; + src_offset = shared_tile.fetch_physical_offset(offset) - warp_offset; + dst_offset = dst_tile_(i, j) + in_dst_tile_(row, col); + + src_ptr = + static_cast(__cvta_generic_to_shared(src + src_offset)); + ld_shared_st_global(dst + dst_offset, src_ptr); + } } + } + + private: + static constexpr int kAccessInBytes = AccessBase::kAccessInBytes; + + using DstLayout = + tl::MatrixLayout; + DstLayout dst_tile_; + + // NOTE: DO NOT modify `kNumPerAccess` and `kAccessInBits` here. + // `kAccessInBits` in the storer is for tensor core's output where only two + // numbers are contiguous in memory. This ensures the parameters remain + // consistent with those used in `SharedLayoutWrapper` within the + // register-to-shared storer. + static constexpr int kAccessInBits = 2 * int(sizeof(DType) * 8); + static constexpr int kNumPerAccess = AccessBase::kNumPerAccess; + + using GlobalLayout = tl::MatrixLayout; + GlobalLayout in_dst_tile_; + + Shared shared_tile; + + /// @brief returns the lane col of the current thread within a warp. + DEVICE int lane_row_id() { + return (threadIdx.x % WARP_SIZE) / BaseShape::kColThreads; + } + + /// @brief returns the lane col of the current thread within a warp. + DEVICE int lane_col_id() { + return (threadIdx.x % WARP_SIZE) % BaseShape::kColThreads; + } }; template struct SharedToGlobalStorerImpl { - using Shared = Shared_; - using Global = Global_; - using DType = Shared::DType; - using BaseShape = BaseShape_; - - static_assert(Global::kRows == Shared::kRows && - Global::kCols == Shared::kCols, - "Global and shared memory should have the same shape."); - static_assert(Global::kType == Shared::kType, - "The layout of Global memory and Shared memory tile should " - "be the same."); - static_assert(Global::kType == tl::Layout::kColMajor, - "The layout of Global memory and Shared memory tile should " - "be column-major."); - static_assert(std::is_same_v, - "The data type of Shared and Global must be the same."); - - static constexpr int kRowExec = kRowExec_; - static constexpr int kColExec = kColExec_; - - DEVICE void operator()(const DType* src, DType* dst, int warp_offset) { - int lane_row = lane_row_id() * kNumPerAccess; - int lane_col = lane_col_id(); - - int src_offset = 0, dst_offset = 0; - int offset = 0; - uint32_t src_ptr; + using Shared = Shared_; + using Global = Global_; + using DType = Shared::DType; + using BaseShape = BaseShape_; + + static_assert(Global::kRows == Shared::kRows && + Global::kCols == Shared::kCols, + "Global and shared memory should have the same shape."); + static_assert(Global::kType == Shared::kType, + "The layout of Global memory and Shared memory tile should " + "be the same."); + static_assert(Global::kType == tl::Layout::kColMajor, + "The layout of Global memory and Shared memory tile should " + "be column-major."); + static_assert(std::is_same_v, + "The data type of Shared and Global must be the same."); + + static constexpr int kRowExec = kRowExec_; + static constexpr int kColExec = kColExec_; + + DEVICE void operator()(const DType* src, DType* dst, int warp_offset) { + int lane_row = lane_row_id() * kNumPerAccess; + int lane_col = lane_col_id(); + + int src_offset = 0, dst_offset = 0; + int offset = 0; + uint32_t src_ptr; #pragma unroll - for (int j = 0; j < kColExec; ++j) { + for (int j = 0; j < kColExec; ++j) { #pragma unroll - for (int i = 0; i < kRowExec; ++i) { - offset = warp_offset + - j * BaseShape::kCols * Shared::kColStride + - i * BaseShape::kRows + lane_col * Shared::kColStride + - lane_row; - src_offset = - shared_tile.fetch_physical_offset(offset) - warp_offset; - dst_offset = dst_tile_(i, j) + in_dst_tile_(lane_row, lane_col); - - src_ptr = static_cast( - __cvta_generic_to_shared(src + src_offset)); - ld_shared_st_global(dst + dst_offset, src_ptr); - } - } - } - - private: - static constexpr int kAccessInBytes = AccessBase::kAccessInBytes; - - using DstLayout = tl::MatrixLayout; - DstLayout dst_tile_; - - // NOTE: DO NOT modify `kNumPerAccess` and `kAccessInBits` here. - // `kAccessInBits` in the storer is for tensor core's output where only two - // numbers are contiguous in memory. This ensures the parameters remain - // consistent with those used in `SharedLayoutWrapper` within the - // register-to-shared storer. - static constexpr int kAccessInBits = 2 * int(sizeof(DType) * 8); - static constexpr int kNumPerAccess = AccessBase::kNumPerAccess; - - using GlobalLayout = tl::MatrixLayout; - GlobalLayout in_dst_tile_; - - Shared shared_tile; - - /// @brief returns the lane col of the current thread within a warp. - DEVICE int lane_row_id() { - return (threadIdx.x % WARP_SIZE) / BaseShape::kColThreads; - } - - /// @brief returns the lane col of the current thread within a warp. - DEVICE int lane_col_id() { - return (threadIdx.x % WARP_SIZE) % BaseShape::kColThreads; + for (int i = 0; i < kRowExec; ++i) { + offset = warp_offset + j * BaseShape::kCols * Shared::kColStride + + i * BaseShape::kRows + lane_col * Shared::kColStride + + lane_row; + src_offset = shared_tile.fetch_physical_offset(offset) - warp_offset; + dst_offset = dst_tile_(i, j) + in_dst_tile_(lane_row, lane_col); + + src_ptr = + static_cast(__cvta_generic_to_shared(src + src_offset)); + ld_shared_st_global(dst + dst_offset, src_ptr); + } } + } + + private: + static constexpr int kAccessInBytes = AccessBase::kAccessInBytes; + + using DstLayout = tl::MatrixLayout; + DstLayout dst_tile_; + + // NOTE: DO NOT modify `kNumPerAccess` and `kAccessInBits` here. + // `kAccessInBits` in the storer is for tensor core's output where only two + // numbers are contiguous in memory. This ensures the parameters remain + // consistent with those used in `SharedLayoutWrapper` within the + // register-to-shared storer. + static constexpr int kAccessInBits = 2 * int(sizeof(DType) * 8); + static constexpr int kNumPerAccess = AccessBase::kNumPerAccess; + + using GlobalLayout = tl::MatrixLayout; + GlobalLayout in_dst_tile_; + + Shared shared_tile; + + /// @brief returns the lane col of the current thread within a warp. + DEVICE int lane_row_id() { + return (threadIdx.x % WARP_SIZE) / BaseShape::kColThreads; + } + + /// @brief returns the lane col of the current thread within a warp. + DEVICE int lane_col_id() { + return (threadIdx.x % WARP_SIZE) % BaseShape::kColThreads; + } }; /// @brief The thread-block level API that cooperatively transfers a data tile @@ -371,137 +363,137 @@ struct SharedToGlobalStorerImpl struct GlobalToSharedLoader { - using Shared = Shared_; - using DType = Shared::DType; - using WarpLayout = WarpLayout_; - - // NOTE: The WarpShape calculated here is for the warp reuse mode `kCont`. - // If you use a different mode, update the WarpShape accordingly. - static_assert((Shared::kRows % WarpLayout ::kRows == 0) && - (Shared::kCols % WarpLayout::kCols == 0), - "The shape of SharedTile must be divisible by the shape of " - "WarpLayout."); - - using WarpShape = TileShape; - using BaseShape = WarpBaseTileShape; - - static_assert(Shared::kRows % BaseShape ::kRows == 0, - "Shared::kRows must be divisible by BaseShape::kRows."); - static_assert(Shared::kCols % BaseShape::kCols == 0, - "Shared::kCols must be divisible by BaseShape::kCols."); - - static const WarpReuse kMode = WarpReuse::kCont; // warp reuse mode - using ExecCounter = warp::ExecCounter; - using GlobalOffset = warp::GlobalOffsetHelper; - using SharedOffset = - warp::SharedOffsetHelper; - - static constexpr int kRowExec = ExecCounter::kRowExec; - static constexpr int kColExec = ExecCounter::kColExec; - - static_assert(kRowExec && kColExec, - "Ensure that the execution count for all rows and columns is " - "greater than 0."); - - static constexpr int kSharedAccessInBytes = Shared::SwizzleBytes; - static constexpr int kSharedContInBytes = - Shared::kType == tl::Layout::kRowMajor - ? Shared::kCols * sizeof(DType) / WarpLayout::kCols - : Shared::kRows * sizeof(DType) / WarpLayout::kRows; - - static_assert(kSharedAccessInBytes <= kSharedContInBytes, - "kSharedAccessInBytes must be less than or equal to " - "kSharedContInBytes"); - static_assert(kSharedAccessInBytes % 32 == 0, - "The number of bytes in a warp tile must be divisible by " - "32."); - - template - DEVICE void operator()(const Global& src, Shared& dst) { - static_assert( - Global::kRows == Shared::kRows && Global::kCols == Shared::kCols, - "Global and shared memory should have the same shape."); - - const DType* src_ptr = src.data(); - DType* dst_ptr = dst.mutable_data(); - - // get warp offset for global and shared memory - int offset_src = global_offset_.template get_warp_offset(); - int offset_dst = shared_offset_.get_warp_offset(); - - // Load a single warp tile from global memory to shared memory - using Loader = GlobalToSharedLoaderImpl; - Loader loader; - loader(src_ptr + offset_src, dst_ptr + offset_dst, offset_dst); - } - - private: - GlobalOffset global_offset_; - SharedOffset shared_offset_; + using Shared = Shared_; + using DType = Shared::DType; + using WarpLayout = WarpLayout_; + + // NOTE: The WarpShape calculated here is for the warp reuse mode `kCont`. + // If you use a different mode, update the WarpShape accordingly. + static_assert((Shared::kRows % WarpLayout ::kRows == 0) && + (Shared::kCols % WarpLayout::kCols == 0), + "The shape of SharedTile must be divisible by the shape of " + "WarpLayout."); + + using WarpShape = TileShape; + using BaseShape = WarpBaseTileShape; + + static_assert(Shared::kRows % BaseShape ::kRows == 0, + "Shared::kRows must be divisible by BaseShape::kRows."); + static_assert(Shared::kCols % BaseShape::kCols == 0, + "Shared::kCols must be divisible by BaseShape::kCols."); + + static const WarpReuse kMode = WarpReuse::kCont; // warp reuse mode + using ExecCounter = warp::ExecCounter; + using GlobalOffset = warp::GlobalOffsetHelper; + using SharedOffset = + warp::SharedOffsetHelper; + + static constexpr int kRowExec = ExecCounter::kRowExec; + static constexpr int kColExec = ExecCounter::kColExec; + + static_assert(kRowExec && kColExec, + "Ensure that the execution count for all rows and columns is " + "greater than 0."); + + static constexpr int kSharedAccessInBytes = Shared::SwizzleBytes; + static constexpr int kSharedContInBytes = + Shared::kType == tl::Layout::kRowMajor + ? Shared::kCols * sizeof(DType) / WarpLayout::kCols + : Shared::kRows * sizeof(DType) / WarpLayout::kRows; + + static_assert(kSharedAccessInBytes <= kSharedContInBytes, + "kSharedAccessInBytes must be less than or equal to " + "kSharedContInBytes"); + static_assert(kSharedAccessInBytes % 32 == 0, + "The number of bytes in a warp tile must be divisible by " + "32."); + + template + DEVICE void operator()(const Global& src, Shared& dst) { + static_assert( + Global::kRows == Shared::kRows && Global::kCols == Shared::kCols, + "Global and shared memory should have the same shape."); + + const DType* src_ptr = src.data(); + DType* dst_ptr = dst.mutable_data(); + + // get warp offset for global and shared memory + int offset_src = global_offset_.template get_warp_offset(); + int offset_dst = shared_offset_.get_warp_offset(); + + // Load a single warp tile from global memory to shared memory + using Loader = + GlobalToSharedLoaderImpl; + Loader loader; + loader(src_ptr + offset_src, dst_ptr + offset_dst, offset_dst); + } + + private: + GlobalOffset global_offset_; + SharedOffset shared_offset_; }; template struct SharedToGlobalStorer { - using Shared = Shared_; - using DType = Shared::DType; - using WarpLayout = WarpLayout_; - - using WarpShape = TileShape; - using BaseShape = WarpBaseTileShape; - - static_assert(Shared::kRows % BaseShape::kRows == 0, - "Shared::kRows must be divisible by BaseShape::kRows."); - static_assert(Shared::kCols % BaseShape::kCols == 0, - "Shared::kCols must be divisible by BaseShape::kCols."); - - static const WarpReuse kMode = WarpReuse::kCont; // warp reuse mode - - using GlobalOffset = warp::GlobalOffsetHelper; - using SharedOffset = - warp::SharedOffsetHelper; - - using ExecCounter = warp::ExecCounter; - - static constexpr int kRowExec = ExecCounter::kRowExec; - static constexpr int kColExec = ExecCounter::kColExec; - - static_assert(kRowExec && kColExec, - "Execution count should be greater than 0."); - - static constexpr int kSharedAccessInBytes = Shared::SwizzleBytes; - static constexpr int kSharedContInBytes = - Shared::kType == tl::Layout::kRowMajor - ? Shared::kCols * sizeof(DType) / WarpLayout::kCols - : Shared::kRows * sizeof(DType) / WarpLayout::kRows; - - static_assert(kSharedAccessInBytes <= kSharedContInBytes, - "kSharedAccessInBytes must be less than or equal to " - "kSharedContInBytes"); - static_assert(kSharedAccessInBytes % 32 == 0, - "The number of bytes in a warp tile must be divisible by " - "32."); - - template - DEVICE void operator()(const Shared& src_, Global& dst_) { - const DType* src = src_.data(); - DType* dst = dst_.mutable_data(); - - // The offset for data that the current warp should access - int offset_src = shared_offset_.get_warp_offset(); - int offset_dst = global_offset_.template get_warp_offset(); - - using Storer = SharedToGlobalStorerImpl; - Storer storer; - storer(src + offset_src, dst + offset_dst, offset_src); - } - - private: - SharedOffset shared_offset_; - GlobalOffset global_offset_; + using Shared = Shared_; + using DType = Shared::DType; + using WarpLayout = WarpLayout_; + + using WarpShape = TileShape; + using BaseShape = WarpBaseTileShape; + + static_assert(Shared::kRows % BaseShape::kRows == 0, + "Shared::kRows must be divisible by BaseShape::kRows."); + static_assert(Shared::kCols % BaseShape::kCols == 0, + "Shared::kCols must be divisible by BaseShape::kCols."); + + static const WarpReuse kMode = WarpReuse::kCont; // warp reuse mode + + using GlobalOffset = warp::GlobalOffsetHelper; + using SharedOffset = + warp::SharedOffsetHelper; + + using ExecCounter = warp::ExecCounter; + + static constexpr int kRowExec = ExecCounter::kRowExec; + static constexpr int kColExec = ExecCounter::kColExec; + + static_assert(kRowExec && kColExec, + "Execution count should be greater than 0."); + + static constexpr int kSharedAccessInBytes = Shared::SwizzleBytes; + static constexpr int kSharedContInBytes = + Shared::kType == tl::Layout::kRowMajor + ? Shared::kCols * sizeof(DType) / WarpLayout::kCols + : Shared::kRows * sizeof(DType) / WarpLayout::kRows; + + static_assert(kSharedAccessInBytes <= kSharedContInBytes, + "kSharedAccessInBytes must be less than or equal to " + "kSharedContInBytes"); + static_assert(kSharedAccessInBytes % 32 == 0, + "The number of bytes in a warp tile must be divisible by " + "32."); + + template + DEVICE void operator()(const Shared& src_, Global& dst_) { + const DType* src = src_.data(); + DType* dst = dst_.mutable_data(); + + // The offset for data that the current warp should access + int offset_src = shared_offset_.get_warp_offset(); + int offset_dst = global_offset_.template get_warp_offset(); + + using Storer = + SharedToGlobalStorerImpl; + Storer storer; + storer(src + offset_src, dst + offset_dst, offset_src); + } + + private: + SharedOffset shared_offset_; + GlobalOffset global_offset_; }; } // namespace tilefusion::cell::copy diff --git a/include/cell/copy/register.hpp b/include/cell/copy/register.hpp index 68fa0275..9aa83acb 100644 --- a/include/cell/copy/register.hpp +++ b/include/cell/copy/register.hpp @@ -10,26 +10,26 @@ namespace tilefusion::cell::copy { namespace detail { template struct DataCopy { - DEVICE void operator()(const Element& src, Element& dst) { dst = src; } + DEVICE void operator()(const Element& src, Element& dst) { dst = src; } }; template struct RegCopy { - using DType = typename RegTile::DType; + using DType = typename RegTile::DType; - static constexpr int kRows = RegTile::kRows; - static constexpr int kCols = RegTile::kCols; + static constexpr int kRows = RegTile::kRows; + static constexpr int kCols = RegTile::kCols; - DEVICE void operator()(const RegTile& src, RegTile& dst) { - Copy c; + DEVICE void operator()(const RegTile& src, RegTile& dst) { + Copy c; #pragma unroll - for (int i = 0; i < kRows; ++i) { + for (int i = 0; i < kRows; ++i) { #pragma unroll - for (int j = 0; j < kCols; ++j) { - c(src(i, j), dst(i, j)); - } - } + for (int j = 0; j < kCols; ++j) { + c(src(i, j), dst(i, j)); + } } + } }; } // namespace detail diff --git a/include/cell/copy/shared_to_register.hpp b/include/cell/copy/shared_to_register.hpp index 0c4c2a12..0c789657 100644 --- a/include/cell/copy/shared_to_register.hpp +++ b/include/cell/copy/shared_to_register.hpp @@ -21,47 +21,46 @@ template : public LoadMatBase { - using LoadMat = LoadMatBase; - using DType = Shared::DType; - using Reg = Reg_; + using LoadMat = LoadMatBase; + using DType = Shared::DType; + using Reg = Reg_; - static constexpr int kRowExec = kRowExec_; - static constexpr int kColExec = kColExec_; + static constexpr int kRowExec = kRowExec_; + static constexpr int kColExec = kColExec_; - DEVICE void operator()(const DType* src, Reg& dst, int warp_offset, - int iterator_offset) { - int global_offset = warp_offset + iterator_offset; - int lane_row = this->lane_row_id(); - int lane_col = this->lane_col_id() * LoadMat::kNumPerAccess; + DEVICE void operator()(const DType* src, Reg& dst, int warp_offset, + int iterator_offset) { + int global_offset = warp_offset + iterator_offset; + int lane_row = this->lane_row_id(); + int lane_col = this->lane_col_id() * LoadMat::kNumPerAccess; #pragma unroll - for (int i = 0; i < kRowExec; ++i) { + for (int i = 0; i < kRowExec; ++i) { #pragma unroll - for (int j = 0; j < kColExec; ++j) { - int tile_offset = global_offset + - i * kSharedRowStride * BaseShape::kRows + - j * BaseShape::kCols + - lane_row * kSharedRowStride + lane_col; - int offset = shared_tile.fetch_physical_offset(tile_offset) - - iterator_offset; - - // advance pointer to the 16x16 `BaseTile` indexed by(i, j). - // issue the hardware-backed memory access instruction. - this->ldmatrix(src + offset, dst(i, j).mutable_data()); - } - } + for (int j = 0; j < kColExec; ++j) { + int tile_offset = + global_offset + i * kSharedRowStride * BaseShape::kRows + + j * BaseShape::kCols + lane_row * kSharedRowStride + lane_col; + int offset = + shared_tile.fetch_physical_offset(tile_offset) - iterator_offset; + + // advance pointer to the 16x16 `BaseTile` indexed by(i, j). + // issue the hardware-backed memory access instruction. + this->ldmatrix(src + offset, dst(i, j).mutable_data()); + } } - - private: - // FIXME(ying): Address the unnatural dependency on the MMA atom caused by - // BaseShape. - // Future refactoring of the program's concepts and interfaces should - // eliminate this dependency. - using MmaAtom = - compute::MmaAtom<__half, __half, __half, compute::MMA_ATOM_16x16x16>; - using BaseShape = MmaAtom::BaseTile; - static constexpr int kSharedRowStride = Shared::kRowStride; - Shared shared_tile; + } + + private: + // FIXME(ying): Address the unnatural dependency on the MMA atom caused by + // BaseShape. + // Future refactoring of the program's concepts and interfaces should + // eliminate this dependency. + using MmaAtom = + compute::MmaAtom<__half, __half, __half, compute::MMA_ATOM_16x16x16>; + using BaseShape = MmaAtom::BaseTile; + static constexpr int kSharedRowStride = Shared::kRowStride; + Shared shared_tile; }; /// @brief partial specialization for column-major shared memory tile. @@ -70,49 +69,48 @@ template : public LoadMatBase { - using Reg = Reg_; - using DType = Shared::DType; - using LoadMat = LoadMatBase; - - static constexpr int kRowExec = kRowExec_; - static constexpr int kColExec = kColExec_; - - DEVICE void operator()(const DType* src, Reg& dst, int warp_offset, - int iterator_offset) { - int global_offset = warp_offset + iterator_offset; - // transpose the lane position if the shared memory is in - // column-major. 16 threads are mapped to the strided dimension - // of the data while the 2 threads are mapped to the contiguous - // dimension of the data. - int lane_row = this->lane_col_id() * LoadMat::kNumPerAccess; - int lane_col = this->lane_row_id(); - - for (int i = 0; i < kColExec; ++i) { + using Reg = Reg_; + using DType = Shared::DType; + using LoadMat = LoadMatBase; + + static constexpr int kRowExec = kRowExec_; + static constexpr int kColExec = kColExec_; + + DEVICE void operator()(const DType* src, Reg& dst, int warp_offset, + int iterator_offset) { + int global_offset = warp_offset + iterator_offset; + // transpose the lane position if the shared memory is in + // column-major. 16 threads are mapped to the strided dimension + // of the data while the 2 threads are mapped to the contiguous + // dimension of the data. + int lane_row = this->lane_col_id() * LoadMat::kNumPerAccess; + int lane_col = this->lane_row_id(); + + for (int i = 0; i < kColExec; ++i) { #pragma unroll - for (int j = 0; j < kRowExec; ++j) { - int tile_offset = global_offset + - i * kSharedColStride * BaseShape::kCols + - j * BaseShape::kRows + - lane_col * kSharedColStride + lane_row; - int offset = shared_tile.fetch_physical_offset(tile_offset) - - iterator_offset; - - // issue the hardware-backed memory access instruction - this->ldmatrix(src + offset, dst(j, i).mutable_data()); - } - } + for (int j = 0; j < kRowExec; ++j) { + int tile_offset = + global_offset + i * kSharedColStride * BaseShape::kCols + + j * BaseShape::kRows + lane_col * kSharedColStride + lane_row; + int offset = + shared_tile.fetch_physical_offset(tile_offset) - iterator_offset; + + // issue the hardware-backed memory access instruction + this->ldmatrix(src + offset, dst(j, i).mutable_data()); + } } - - private: - // FIXME(ying): Address the unnatural dependency on the MMA atom caused by - // BaseShape. - // Future refactoring of the program's concepts and interfaces should - // eliminate this dependency. - using MmaAtom = - compute::MmaAtom<__half, __half, __half, compute::MMA_ATOM_16x16x16>; - using BaseShape = MmaAtom::BaseTile; - static constexpr int kSharedColStride = Shared::kColStride; - Shared shared_tile; + } + + private: + // FIXME(ying): Address the unnatural dependency on the MMA atom caused by + // BaseShape. + // Future refactoring of the program's concepts and interfaces should + // eliminate this dependency. + using MmaAtom = + compute::MmaAtom<__half, __half, __half, compute::MMA_ATOM_16x16x16>; + using BaseShape = MmaAtom::BaseTile; + static constexpr int kSharedColStride = Shared::kColStride; + Shared shared_tile; }; template : public StoreMatBase { - using Reg = Reg_; - using Shared = Shared_; - using DType = Shared::DType; - using StoreMat = StoreMatBase; + using Reg = Reg_; + using Shared = Shared_; + using DType = Shared::DType; + using StoreMat = StoreMatBase; - static constexpr int kRowExec = kRowExec_; - static constexpr int kColExec = kColExec_; + static constexpr int kRowExec = kRowExec_; + static constexpr int kColExec = kColExec_; - DEVICE void operator()(const Reg& src, DType* dst, int warp_offset) { + DEVICE void operator()(const Reg& src, DType* dst, int warp_offset) { #pragma unroll - for (int j = 0; j < kColExec; ++j) { + for (int j = 0; j < kColExec; ++j) { #pragma unroll - for (int i = 0; i < kRowExec; ++i) { - int lane_row = this->lane_row_id(); - int lane_col = this->lane_col_id(); + for (int i = 0; i < kRowExec; ++i) { + int lane_row = this->lane_row_id(); + int lane_col = this->lane_col_id(); - int tile_offset = warp_offset + i * kRowStride + j * kColStride; - int row = 0, col = 0; + int tile_offset = warp_offset + i * kRowStride + j * kColStride; + int row = 0, col = 0; #pragma unroll - for (int m = 0; m < StoreMat::kSegRows; ++m) { - row = lane_row + m * StoreMat::kThreadRows; + for (int m = 0; m < StoreMat::kSegRows; ++m) { + row = lane_row + m * StoreMat::kThreadRows; #pragma unroll - for (int n = 0; n < StoreMat::kSegCols; ++n) { - col = StoreMat::kElemPerSeg * - (lane_col + n * StoreMat::kThreadCols); - int in_tile_offset = row * Shared::kRowStride + col; - int offset = tile_offset + in_tile_offset; - int swizzled_offset = - shared_tile.fetch_physical_offset(offset); - - const PackedType* src_ptr = - reinterpret_cast( - src(i, j).data()); - PackedType* dst_ptr = - reinterpret_cast(dst); - - dst_ptr[swizzled_offset / StoreMat::kElemPerSeg] = - src_ptr[n * StoreMat::kSegCols + m]; - } - } - } + for (int n = 0; n < StoreMat::kSegCols; ++n) { + col = + StoreMat::kElemPerSeg * (lane_col + n * StoreMat::kThreadCols); + int in_tile_offset = row * Shared::kRowStride + col; + int offset = tile_offset + in_tile_offset; + int swizzled_offset = shared_tile.fetch_physical_offset(offset); + + const PackedType* src_ptr = + reinterpret_cast(src(i, j).data()); + PackedType* dst_ptr = reinterpret_cast(dst); + + dst_ptr[swizzled_offset / StoreMat::kElemPerSeg] = + src_ptr[n * StoreMat::kSegCols + m]; + } } + } } - - private: - // FIXME(ying): Address the unnatural dependency on the MMA atom caused by - // BaseShape. - // Future refactoring of the program's concepts and interfaces should - // eliminate this dependency. - using MmaAtom = - compute::MmaAtom<__half, __half, __half, compute::MMA_ATOM_16x16x16>; - using BaseShape = MmaAtom::BaseTile; - using PackedType = - typename Packing::PackedType; - - static constexpr int kSharedRowStride = Shared::kRowStride; - static constexpr int kRowStride = BaseShape::kRows * kSharedRowStride; - static constexpr int kColStride = BaseShape::kCols; - - Shared shared_tile; + } + + private: + // FIXME(ying): Address the unnatural dependency on the MMA atom caused by + // BaseShape. + // Future refactoring of the program's concepts and interfaces should + // eliminate this dependency. + using MmaAtom = + compute::MmaAtom<__half, __half, __half, compute::MMA_ATOM_16x16x16>; + using BaseShape = MmaAtom::BaseTile; + using PackedType = typename Packing::PackedType; + + static constexpr int kSharedRowStride = Shared::kRowStride; + static constexpr int kRowStride = BaseShape::kRows * kSharedRowStride; + static constexpr int kColStride = BaseShape::kCols; + + Shared shared_tile; }; template : public StoreMatBase { - using Reg = Reg_; - using Shared = Shared_; - using DType = Shared::DType; - using StoreMat = StoreMatBase; + using Reg = Reg_; + using Shared = Shared_; + using DType = Shared::DType; + using StoreMat = StoreMatBase; - static constexpr int kRowExec = kRowExec_; - static constexpr int kColExec = kColExec_; + static constexpr int kRowExec = kRowExec_; + static constexpr int kColExec = kColExec_; - DEVICE void operator()(const Reg& src, DType* dst, int warp_offset) { + DEVICE void operator()(const Reg& src, DType* dst, int warp_offset) { #pragma unroll - for (int j = 0; j < kColExec; ++j) { + for (int j = 0; j < kColExec; ++j) { #pragma unroll - for (int i = 0; i < kRowExec; ++i) { - int tile_offset = warp_offset + j * kColStride + i * kRowStride; - int lane_row = this->lane_row_id(); - int lane_col = this->lane_col_id(); + for (int i = 0; i < kRowExec; ++i) { + int tile_offset = warp_offset + j * kColStride + i * kRowStride; + int lane_row = this->lane_row_id(); + int lane_col = this->lane_col_id(); - int row = 0, col = 0; + int row = 0, col = 0; #pragma unroll - for (int m = 0; m < StoreMat::kSegRows; ++m) { - row = StoreMat::kElemPerSeg * - (lane_row + m * StoreMat::kThreadRows); + for (int m = 0; m < StoreMat::kSegRows; ++m) { + row = StoreMat::kElemPerSeg * (lane_row + m * StoreMat::kThreadRows); #pragma unroll - for (int n = 0; n < StoreMat::kSegCols; ++n) { - col = lane_col + n * StoreMat::kThreadCols; - - int in_tile_offset = col * Shared::kColStride + row; - int offset = tile_offset + in_tile_offset; - int swizzled_offset = - shared_tile.fetch_physical_offset(offset); - - const PackedType* src_ptr = - reinterpret_cast( - src(i, j).data()); - PackedType* dst_ptr = - reinterpret_cast(dst); - dst_ptr[swizzled_offset / StoreMat::kElemPerSeg] = - src_ptr[n * StoreMat::kSegCols + m]; - } - } - } + for (int n = 0; n < StoreMat::kSegCols; ++n) { + col = lane_col + n * StoreMat::kThreadCols; + + int in_tile_offset = col * Shared::kColStride + row; + int offset = tile_offset + in_tile_offset; + int swizzled_offset = shared_tile.fetch_physical_offset(offset); + + const PackedType* src_ptr = + reinterpret_cast(src(i, j).data()); + PackedType* dst_ptr = reinterpret_cast(dst); + dst_ptr[swizzled_offset / StoreMat::kElemPerSeg] = + src_ptr[n * StoreMat::kSegCols + m]; + } } + } } - - private: - // FIXME(ying): Address the unnatural dependency on the MMA atom caused by - // BaseShape. - // Future refactoring of the program's concepts and interfaces should - // eliminate this dependency. - using MmaAtom = - compute::MmaAtom<__half, __half, __half, compute::MMA_ATOM_16x16x16>; - using BaseShape = MmaAtom::BaseTile; - using PackedType = - typename Packing::PackedType; - - static constexpr int kSharedColStride = Shared::kColStride; - static constexpr int kRowStride = BaseShape::kRows; - static constexpr int kColStride = BaseShape::kCols * kSharedColStride; - - Shared shared_tile; + } + + private: + // FIXME(ying): Address the unnatural dependency on the MMA atom caused by + // BaseShape. + // Future refactoring of the program's concepts and interfaces should + // eliminate this dependency. + using MmaAtom = + compute::MmaAtom<__half, __half, __half, compute::MMA_ATOM_16x16x16>; + using BaseShape = MmaAtom::BaseTile; + using PackedType = typename Packing::PackedType; + + static constexpr int kSharedColStride = Shared::kColStride; + static constexpr int kRowStride = BaseShape::kRows; + static constexpr int kColStride = BaseShape::kCols * kSharedColStride; + + Shared shared_tile; }; } // namespace detail @@ -258,53 +247,53 @@ struct RegToSharedStorerImpl struct SharedToRegLoader { - using Reg = Reg_; - using DType = typename Reg::DType::DType; // the element data type - using WarpLayout = WarpLayout_; - static constexpr WarpReuse kMode = kMode_; - - using MmaAtom = - compute::MmaAtom<__half, __half, __half, compute::MMA_ATOM_16x16x16>; - using BaseShape = MmaAtom::BaseTile; - - // how many times a `BaseTile` is executed along the row and column - // direction. - static constexpr int kRowExec = Reg::kRows; - static constexpr int kColExec = Reg::kCols; - - static_assert(kRowExec && kColExec, - "Execution count should be greater than 0."); - - template - DEVICE void operator()(const Shared& src, Reg& dst) { - static_assert(std::is_same_v, - "The data type of Shared and Reg must be the same."); - static_assert(Shared::kRows % WarpLayout::kRows == 0, - "The current implementation requires Shared::kRows must " - "be divisible by WarpLayout::kRows"); - static_assert(Shared::kCols % WarpLayout::kCols == 0, - "The current implementation requires Shared::kCols must " - "be divisible by WarpLayout::kCols"); - - static constexpr int kSharedAccessInBytes = Shared::SwizzleBytes; - static_assert(kSharedAccessInBytes % 32 == 0, - "The number of bytes in a warp tile must be divisible by " - "32."); - - using SharedOffset = - warp::SharedOffsetHelper; - SharedOffset shared_offset_; - - // advance the pointer to input data to the current warp according to - // warp reuse mode. - int warp_offset = shared_offset_.get_warp_offset(); - int iterator_offset = src.get_offset(); - - using Loader = detail::SharedToRegLoaderImpl; - Loader loader; - loader(src.data(), dst, warp_offset, iterator_offset); - } + using Reg = Reg_; + using DType = typename Reg::DType::DType; // the element data type + using WarpLayout = WarpLayout_; + static constexpr WarpReuse kMode = kMode_; + + using MmaAtom = + compute::MmaAtom<__half, __half, __half, compute::MMA_ATOM_16x16x16>; + using BaseShape = MmaAtom::BaseTile; + + // how many times a `BaseTile` is executed along the row and column + // direction. + static constexpr int kRowExec = Reg::kRows; + static constexpr int kColExec = Reg::kCols; + + static_assert(kRowExec && kColExec, + "Execution count should be greater than 0."); + + template + DEVICE void operator()(const Shared& src, Reg& dst) { + static_assert(std::is_same_v, + "The data type of Shared and Reg must be the same."); + static_assert(Shared::kRows % WarpLayout::kRows == 0, + "The current implementation requires Shared::kRows must " + "be divisible by WarpLayout::kRows"); + static_assert(Shared::kCols % WarpLayout::kCols == 0, + "The current implementation requires Shared::kCols must " + "be divisible by WarpLayout::kCols"); + + static constexpr int kSharedAccessInBytes = Shared::SwizzleBytes; + static_assert(kSharedAccessInBytes % 32 == 0, + "The number of bytes in a warp tile must be divisible by " + "32."); + + using SharedOffset = + warp::SharedOffsetHelper; + SharedOffset shared_offset_; + + // advance the pointer to input data to the current warp according to + // warp reuse mode. + int warp_offset = shared_offset_.get_warp_offset(); + int iterator_offset = src.get_offset(); + + using Loader = detail::SharedToRegLoaderImpl; + Loader loader; + loader(src.data(), dst, warp_offset, iterator_offset); + } }; /// @brief partial specialization for 16x16x16 wmma's output, and st.shared.f32 @@ -312,69 +301,68 @@ struct SharedToRegLoader { /// matrix. template struct RegToSharedStorer { - using Reg = Reg_; - // elementary data type stored in the register tile. - using DType = typename Reg::DType::DType; - using WarpLayout = WarpLayout_; - - // FIXME(ying): Address the unnatural dependency on the MMA atom caused by - // BaseShape. Future refactoring of the program's concepts and interfaces - // should eliminate this dependency. - using MmaAtom = - compute::MmaAtom<__half, __half, __half, compute::MMA_ATOM_16x16x16>; - using BaseShape = MmaAtom::BaseTile; - - // how many times a `BaseTile` is executed along the row and column - // direction. - static constexpr int kRowExec = Reg::kRows; - static constexpr int kColExec = Reg::kCols; - - static_assert(kRowExec && kColExec, - "Execution count should be greater than 0."); - - /// @brief Store the WMMA output register tile to shared memory. The source - /// is the current thread's local register tile, and the destination - /// is shared memory. - template - DEVICE void operator()(const Reg& src, Shared& dst_) { - static_assert(std::is_same_v, - "The element data type of Shared and Register tile must " - "be the same."); - static_assert((Reg::kNumel * Reg::DType::kNumel * 32 /*warp size*/ * - WarpLayout::kNumel) == Shared::kNumel, - "The number of elements held in the local register file " - "by all threads in the CTA must be the same as the " - "number held in the shared memory tile."); - static_assert( - Shared::kType == Reg::kType, - "The layout of Shared and Register tile must be the same."); - static_assert(Shared::kRows % BaseShape::kRows == 0, - "The number of shared memory rows must be divisible by " - "the base tile row."); - static_assert(Shared::kCols % BaseShape::kCols == 0, - "The number of shared memory columns must be divisible " - "by the base tile column."); - - static constexpr int kSharedAccessInBytes = Shared::SwizzleBytes; - - static_assert(kSharedAccessInBytes % 32 == 0, - "The number of bytes in a warp tile must be divisible by " - "32."); - - // advance the pointer to input data to the current warp according to - // warp reuse mode. During the store process, threads do not write to - // the same shared memory location, thus the warp reuse mode is set to - // `Cont`. - using SharedOffset = warp::SharedOffsetHelper; - SharedOffset shared_offset_; - int warp_offset = shared_offset_.get_warp_offset(); - - using Storer = detail::RegToSharedStorerImpl; - Storer storer; - - storer(src, dst_.mutable_data(), warp_offset); - } + using Reg = Reg_; + // elementary data type stored in the register tile. + using DType = typename Reg::DType::DType; + using WarpLayout = WarpLayout_; + + // FIXME(ying): Address the unnatural dependency on the MMA atom caused by + // BaseShape. Future refactoring of the program's concepts and interfaces + // should eliminate this dependency. + using MmaAtom = + compute::MmaAtom<__half, __half, __half, compute::MMA_ATOM_16x16x16>; + using BaseShape = MmaAtom::BaseTile; + + // how many times a `BaseTile` is executed along the row and column + // direction. + static constexpr int kRowExec = Reg::kRows; + static constexpr int kColExec = Reg::kCols; + + static_assert(kRowExec && kColExec, + "Execution count should be greater than 0."); + + /// @brief Store the WMMA output register tile to shared memory. The source + /// is the current thread's local register tile, and the destination + /// is shared memory. + template + DEVICE void operator()(const Reg& src, Shared& dst_) { + static_assert(std::is_same_v, + "The element data type of Shared and Register tile must " + "be the same."); + static_assert((Reg::kNumel * Reg::DType::kNumel * 32 /*warp size*/ * + WarpLayout::kNumel) == Shared::kNumel, + "The number of elements held in the local register file " + "by all threads in the CTA must be the same as the " + "number held in the shared memory tile."); + static_assert(Shared::kType == Reg::kType, + "The layout of Shared and Register tile must be the same."); + static_assert(Shared::kRows % BaseShape::kRows == 0, + "The number of shared memory rows must be divisible by " + "the base tile row."); + static_assert(Shared::kCols % BaseShape::kCols == 0, + "The number of shared memory columns must be divisible " + "by the base tile column."); + + static constexpr int kSharedAccessInBytes = Shared::SwizzleBytes; + + static_assert(kSharedAccessInBytes % 32 == 0, + "The number of bytes in a warp tile must be divisible by " + "32."); + + // advance the pointer to input data to the current warp according to + // warp reuse mode. During the store process, threads do not write to + // the same shared memory location, thus the warp reuse mode is set to + // `Cont`. + using SharedOffset = warp::SharedOffsetHelper; + SharedOffset shared_offset_; + int warp_offset = shared_offset_.get_warp_offset(); + + using Storer = detail::RegToSharedStorerImpl; + Storer storer; + + storer(src, dst_.mutable_data(), warp_offset); + } }; } // namespace tilefusion::cell::copy diff --git a/include/cell/copy/sync.hpp b/include/cell/copy/sync.hpp index 4c4877ac..0f83b129 100644 --- a/include/cell/copy/sync.hpp +++ b/include/cell/copy/sync.hpp @@ -10,18 +10,18 @@ namespace tilefusion::cell::copy { template DEVICE void wait_group() { #if defined(CP_ASYNC_SM80_ENABLED) - asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); #endif } DEVICE void commit_copy_group() { #if defined(CP_ASYNC_SM80_ENABLED) - asm volatile("cp.async.commit_group;\n" ::); + asm volatile("cp.async.commit_group;\n" ::); #endif } DEVICE void __copy_async() { - commit_copy_group(); - wait_group<0>(); + commit_copy_group(); + wait_group<0>(); } } // namespace tilefusion::cell::copy diff --git a/include/cell/copy/vectorize.hpp b/include/cell/copy/vectorize.hpp index b38c57da..66c957b2 100644 --- a/include/cell/copy/vectorize.hpp +++ b/include/cell/copy/vectorize.hpp @@ -13,49 +13,49 @@ namespace tilefusion::cell::copy { */ template struct Vectorize { - using UnVecType = Element; - using VecType = Element; - static constexpr int vectorize_nums = kVecNums; - - /** - * @brief Copy data from unvectorized to vectorized. - * - * @param src Source data. - * @param dst Destination data. - */ - DEVICE void operator()(const UnVecType* src, UnVecType* dst) { - const VecType* src_vec = reinterpret_cast(src); - VecType* dst_vec = reinterpret_cast(dst); - *dst_vec = *src_vec; - } + using UnVecType = Element; + using VecType = Element; + static constexpr int vectorize_nums = kVecNums; + + /** + * @brief Copy data from unvectorized to vectorized. + * + * @param src Source data. + * @param dst Destination data. + */ + DEVICE void operator()(const UnVecType* src, UnVecType* dst) { + const VecType* src_vec = reinterpret_cast(src); + VecType* dst_vec = reinterpret_cast(dst); + *dst_vec = *src_vec; + } }; template <> struct Vectorize<__half, 2> { - using UnVecType = __half; - using VecType = __half2; - static constexpr int vectorize_nums = 2; - static constexpr int vectorize_bits = 32; - - DEVICE void operator()(const __half* src, __half* dst) { - const __half2* src_vec = reinterpret_cast(src); - __half2* dst_vec = reinterpret_cast<__half2*>(dst); - *dst_vec = *src_vec; - } + using UnVecType = __half; + using VecType = __half2; + static constexpr int vectorize_nums = 2; + static constexpr int vectorize_bits = 32; + + DEVICE void operator()(const __half* src, __half* dst) { + const __half2* src_vec = reinterpret_cast(src); + __half2* dst_vec = reinterpret_cast<__half2*>(dst); + *dst_vec = *src_vec; + } }; template <> struct Vectorize { - using UnVecType = float; - using VecType = float2; - static constexpr int vectorize_nums = 2; - static constexpr int vectorize_bits = 64; - - DEVICE void operator()(const float* src, float* dst) { - const float2* src_vec = reinterpret_cast(src); - float2* dst_vec = reinterpret_cast(dst); - *dst_vec = *src_vec; - } + using UnVecType = float; + using VecType = float2; + static constexpr int vectorize_nums = 2; + static constexpr int vectorize_bits = 64; + + DEVICE void operator()(const float* src, float* dst) { + const float2* src_vec = reinterpret_cast(src); + float2* dst_vec = reinterpret_cast(dst); + *dst_vec = *src_vec; + } }; } // namespace tilefusion::cell::copy diff --git a/include/cell/copy/warp.hpp b/include/cell/copy/warp.hpp index 1529311d..a39b5f17 100644 --- a/include/cell/copy/warp.hpp +++ b/include/cell/copy/warp.hpp @@ -22,40 +22,40 @@ struct WarpOffsetHelper; template struct WarpOffsetHelper { - static constexpr int kRowStride = kRowStride_; - static constexpr int kColStride = kColStride_; + static constexpr int kRowStride = kRowStride_; + static constexpr int kColStride = kColStride_; - DEVICE int operator()(int i, int j) const { - return i * kRowStride + j * kColStride; - } + DEVICE int operator()(int i, int j) const { + return i * kRowStride + j * kColStride; + } }; template struct WarpOffsetHelper { - static constexpr int kRowStride = kRowStride_; - static constexpr int kColStride = kColStride_; + static constexpr int kRowStride = kRowStride_; + static constexpr int kColStride = kColStride_; - DEVICE int operator()(int i, int j) const { return j * kColStride; } + DEVICE int operator()(int i, int j) const { return j * kColStride; } }; template struct WarpOffsetHelper { - static constexpr int kRowStride = kRowStride_; - static constexpr int kColStride = kColStride_; + static constexpr int kRowStride = kRowStride_; + static constexpr int kColStride = kColStride_; - DEVICE int operator()(int i, int j) const { return i * kRowStride; } + DEVICE int operator()(int i, int j) const { return i * kRowStride; } }; /// @brief Helper for pretty printing a BaseTile's static shape-related /// information. This printer works ONLY on the host. struct BaseTilePrettyPrinter { - template - static HOST void print(std::ostream& out, const BaseShape& tile) { - // parameter `tile` here is not used - out << "BaseShape = (" << BaseShape::kRows << ", " << BaseShape::kCols - << "), Numel = " << BaseShape::kNumel << ", ThreadLayout = (" - << BaseShape::kRowThreads << ", " << BaseShape::kColThreads << ")"; - } + template + static HOST void print(std::ostream& out, const BaseShape& tile) { + // parameter `tile` here is not used + out << "BaseShape = (" << BaseShape::kRows << ", " << BaseShape::kCols + << "), Numel = " << BaseShape::kNumel << ", ThreadLayout = (" + << BaseShape::kRowThreads << ", " << BaseShape::kColThreads << ")"; + } }; } // namespace @@ -64,32 +64,32 @@ struct BaseTilePrettyPrinter { // calculates the row index of the current thread. template DEVICE int warp_row_id() { - /* - * Example1: suppose the warp layout is RowMajor<2,2>, like this: - * |-|-----|-----| - * |0|warp0|warp1| - * |-|-----|-----| - * |1|warp2|warp3| - * |-|-----|-----|, and the threadIdx is 67, then the warp row is 1. - * - * Example2: suppose the warp layout is ColMajor<2,2>, like this: - * |-|-----|-----| - * |0|warp0|warp2| - * |-|-----|-----| - * |1|warp1|warp3| - * |-|-----|-----|, and the threadIdx is 67, then the warp row is 0. - */ - int wid = threadIdx.x / WARP_SIZE; - - switch (tl::layout_type) { - case tl::Layout::kRowMajor: - return wid / tl::num_cols; - case tl::Layout::kColMajor: - return wid % tl::num_rows; - default: - assert(false && "Not implemented yet."); - return -1; - } + /* + * Example1: suppose the warp layout is RowMajor<2,2>, like this: + * |-|-----|-----| + * |0|warp0|warp1| + * |-|-----|-----| + * |1|warp2|warp3| + * |-|-----|-----|, and the threadIdx is 67, then the warp row is 1. + * + * Example2: suppose the warp layout is ColMajor<2,2>, like this: + * |-|-----|-----| + * |0|warp0|warp2| + * |-|-----|-----| + * |1|warp1|warp3| + * |-|-----|-----|, and the threadIdx is 67, then the warp row is 0. + */ + int wid = threadIdx.x / WARP_SIZE; + + switch (tl::layout_type) { + case tl::Layout::kRowMajor: + return wid / tl::num_cols; + case tl::Layout::kColMajor: + return wid % tl::num_rows; + default: + assert(false && "Not implemented yet."); + return -1; + } } // @brief In a thread block, warps are organized as 2-D matrices, each with @@ -97,142 +97,137 @@ DEVICE int warp_row_id() { // calculates the column index of the current thread. template DEVICE int warp_col_id() { - /* - * Example1: suppose the warp layout is RowMajor<2,2>, like this: - * |-----|-----| - * | 0 | 1 | - * |-----|-----| - * |warp0|warp1| - * |-----|-----| - * |warp2|warp3| - * |-----|-----|, and the threadIdx is 67, then the warp col is 0. - * - * Example2: suppose the warp layout is ColMajor<2,2>, like this: - * |-----|-----| - * | 0 | 1 | - * |-----|-----| - * |warp0|warp2| - * |-----|-----| - * |warp1|warp3| - * |-----|-----|, and the threadIdx is 67, then the warp row is 1. - */ - int wid = threadIdx.x / WARP_SIZE; - - switch (tl::layout_type) { - case tl::Layout::kRowMajor: - return wid % tl::num_cols; - case tl::Layout::kColMajor: - return wid / tl::num_rows; - default: - assert(false && "Not implemented yet."); - return -1; - } + /* + * Example1: suppose the warp layout is RowMajor<2,2>, like this: + * |-----|-----| + * | 0 | 1 | + * |-----|-----| + * |warp0|warp1| + * |-----|-----| + * |warp2|warp3| + * |-----|-----|, and the threadIdx is 67, then the warp col is 0. + * + * Example2: suppose the warp layout is ColMajor<2,2>, like this: + * |-----|-----| + * | 0 | 1 | + * |-----|-----| + * |warp0|warp2| + * |-----|-----| + * |warp1|warp3| + * |-----|-----|, and the threadIdx is 67, then the warp row is 1. + */ + int wid = threadIdx.x / WARP_SIZE; + + switch (tl::layout_type) { + case tl::Layout::kRowMajor: + return wid % tl::num_cols; + case tl::Layout::kColMajor: + return wid / tl::num_rows; + default: + assert(false && "Not implemented yet."); + return -1; + } } template HOST_DEVICE constexpr int warp_tile_rows() { - if constexpr (kMode == WarpReuse::kCont) { - return kSharedRows / kWarpRows; - } else if constexpr (kMode == WarpReuse::kRowReuseCont) { - return kSharedRows / kWarpRows; - } else if constexpr (kMode == WarpReuse::kColReuseCont) { - return kSharedRows; - } - return -1; + if constexpr (kMode == WarpReuse::kCont) { + return kSharedRows / kWarpRows; + } else if constexpr (kMode == WarpReuse::kRowReuseCont) { + return kSharedRows / kWarpRows; + } else if constexpr (kMode == WarpReuse::kColReuseCont) { + return kSharedRows; + } + return -1; } template HOST_DEVICE constexpr int warp_tile_cols() { - if constexpr (kMode == WarpReuse::kCont) { - return kSharedCols / kWarpCols; - } else if constexpr (kMode == WarpReuse::kRowReuseCont) { - return kSharedCols; - } else if constexpr (kMode == WarpReuse::kColReuseCont) { - return kSharedCols / kWarpCols; - } - return -1; + if constexpr (kMode == WarpReuse::kCont) { + return kSharedCols / kWarpCols; + } else if constexpr (kMode == WarpReuse::kRowReuseCont) { + return kSharedCols; + } else if constexpr (kMode == WarpReuse::kColReuseCont) { + return kSharedCols / kWarpCols; + } + return -1; } template struct ExecCounter { - using BaseShape = BaseShape_; - using Tile = Tile_; - - static_assert( - Tile::kCols % BaseShape::kCols == 0, - "The number of shared memory columns must be divisible by the base " - "tile column.\n"); - static_assert( - Tile::kRows % BaseShape::kRows == 0, - "The current implementation requires that the number of shared " - "memory rows be divisible by the base tile row.\n"); - - static constexpr int kWarpsPerRow = tl::num_rows; - static constexpr int kWarpsPerCol = tl::num_cols; - static constexpr WarpReuse kMode = kMode_; - - // @brief This function returns the number of times a `BaseTile` is executed - // along the direction of the shared memory row. - DEVICE static constexpr int row_exec_count() { - switch (kMode) { - // Warps in the same columns (`warps_per_row` in total) repeatedly - // load the shared memory rows. Therefore, `row_exec` is not divided - // by warps_per_row. - case WarpReuse::kColReuseCont: - return Tile::kRows / BaseShape::kRows; - default: // Cont, RowReuseCont hit this case. - return Tile::kRows / BaseShape::kRows / kWarpsPerRow; - } + using BaseShape = BaseShape_; + using Tile = Tile_; + + static_assert( + Tile::kCols % BaseShape::kCols == 0, + "The number of shared memory columns must be divisible by the base " + "tile column.\n"); + static_assert(Tile::kRows % BaseShape::kRows == 0, + "The current implementation requires that the number of shared " + "memory rows be divisible by the base tile row.\n"); + + static constexpr int kWarpsPerRow = tl::num_rows; + static constexpr int kWarpsPerCol = tl::num_cols; + static constexpr WarpReuse kMode = kMode_; + + // @brief This function returns the number of times a `BaseTile` is executed + // along the direction of the shared memory row. + DEVICE static constexpr int row_exec_count() { + switch (kMode) { + // Warps in the same columns (`warps_per_row` in total) repeatedly + // load the shared memory rows. Therefore, `row_exec` is not divided + // by warps_per_row. + case WarpReuse::kColReuseCont: + return Tile::kRows / BaseShape::kRows; + default: // Cont, RowReuseCont hit this case. + return Tile::kRows / BaseShape::kRows / kWarpsPerRow; } - - DEVICE static constexpr int col_exec_count() { - switch (kMode) { - // Warps in the same rows (`warps_per_col` in total) repeatedly load - // the shared memory columns. Therefore, `col_exec` is not divided - // by `warps_per_col`. - case WarpReuse::kRowReuseCont: - return Tile::kCols / BaseShape::kCols; - default: // Cont, ColReuseCont hit this case. - return Tile::kCols / BaseShape::kCols / kWarpsPerCol; - } + } + + DEVICE static constexpr int col_exec_count() { + switch (kMode) { + // Warps in the same rows (`warps_per_col` in total) repeatedly load + // the shared memory columns. Therefore, `col_exec` is not divided + // by `warps_per_col`. + case WarpReuse::kRowReuseCont: + return Tile::kCols / BaseShape::kCols; + default: // Cont, ColReuseCont hit this case. + return Tile::kCols / BaseShape::kCols / kWarpsPerCol; } + } - static constexpr int kRowExec = row_exec_count(); - static constexpr int kColExec = col_exec_count(); + static constexpr int kRowExec = row_exec_count(); + static constexpr int kColExec = col_exec_count(); }; template struct GlobalOffsetHelper { - static constexpr WarpReuse kMode = kMode_; - using WarpLayout = WarpLayout_; - - // @brief This function returns the offset to the start position of the - // current warp in the shared memory according to the warp reuse - // mode. - template - DEVICE int get_warp_offset() { - // Tile shape for a single warp - constexpr static int kWarpShapeRow = - Tile::kRows / tl::num_rows; - constexpr static int kWarpShapeCol = - Tile::kCols / tl::num_cols; - - constexpr static int kWarpRstride = - Tile::kType == tl::Layout::kRowMajor - ? Tile::kRowStride * kWarpShapeRow - : kWarpShapeRow; - constexpr static int kWarpCstride = - Tile::kType == tl::Layout::kRowMajor - ? kWarpShapeCol - : Tile::kColStride * kWarpShapeCol; - - using Offset = WarpOffsetHelper; - Offset offset_; - return offset_(warp_row_id(), warp_col_id()); - } + static constexpr WarpReuse kMode = kMode_; + using WarpLayout = WarpLayout_; + + // @brief This function returns the offset to the start position of the + // current warp in the shared memory according to the warp reuse + // mode. + template + DEVICE int get_warp_offset() { + // Tile shape for a single warp + constexpr static int kWarpShapeRow = Tile::kRows / tl::num_rows; + constexpr static int kWarpShapeCol = Tile::kCols / tl::num_cols; + + constexpr static int kWarpRstride = Tile::kType == tl::Layout::kRowMajor + ? Tile::kRowStride * kWarpShapeRow + : kWarpShapeRow; + constexpr static int kWarpCstride = Tile::kType == tl::Layout::kRowMajor + ? kWarpShapeCol + : Tile::kColStride * kWarpShapeCol; + + using Offset = WarpOffsetHelper; + Offset offset_; + return offset_(warp_row_id(), warp_col_id()); + } }; template struct SharedOffsetHelper { - DEVICE int get_warp_offset() { - switch (kMode) { - case WarpReuse::kCont: - return warp_row_id() * kRowStride * - BaseShape::kRows * Shared::kRowStride + - warp_col_id() * kColStride * - BaseShape::kCols; - case WarpReuse::kRowReuseCont: - return warp_row_id() * kRowStride * - BaseShape::kRows * Shared::kRowStride; - default: - assert(false && "Not implemented yet."); - return -1; - } + DEVICE int get_warp_offset() { + switch (kMode) { + case WarpReuse::kCont: + return warp_row_id() * kRowStride * BaseShape::kRows * + Shared::kRowStride + + warp_col_id() * kColStride * BaseShape::kCols; + case WarpReuse::kRowReuseCont: + return warp_row_id() * kRowStride * BaseShape::kRows * + Shared::kRowStride; + default: + assert(false && "Not implemented yet."); + return -1; } + } - private: - using Shared = Shared_; - using WarpLayout = WarpLayout_; - using BaseShape = BaseShape_; - static constexpr WarpReuse kMode = kMode_; + private: + using Shared = Shared_; + using WarpLayout = WarpLayout_; + using BaseShape = BaseShape_; + static constexpr WarpReuse kMode = kMode_; - constexpr static int kTilePerRow = Shared::kRows / BaseShape::kRows; - constexpr static int kTilePerCol = Shared::kCols / BaseShape::kCols; + constexpr static int kTilePerRow = Shared::kRows / BaseShape::kRows; + constexpr static int kTilePerCol = Shared::kCols / BaseShape::kCols; - // TODO(KuangjuX): hotfix this. - constexpr static int kRowStride = kTilePerRow / tl::num_rows; - constexpr static int kColStride = kTilePerCol / tl::num_cols; + // TODO(KuangjuX): hotfix this. + constexpr static int kRowStride = kTilePerRow / tl::num_rows; + constexpr static int kColStride = kTilePerCol / tl::num_cols; }; template struct SharedOffsetHelper { - DEVICE int get_warp_offset() { - switch (kMode) { - case WarpReuse::kCont: - return warp_row_id() * kRowStride * - BaseShape::kRows + - warp_col_id() * kColStride * - BaseShape::kCols * Shared::kColStride; - case WarpReuse::kColReuseCont: - return warp_col_id() * kColStride * - BaseShape::kCols * Shared::kColStride; - default: - assert(false && "Not implemented yet."); - return -1; - } + DEVICE int get_warp_offset() { + switch (kMode) { + case WarpReuse::kCont: + return warp_row_id() * kRowStride * BaseShape::kRows + + warp_col_id() * kColStride * BaseShape::kCols * + Shared::kColStride; + case WarpReuse::kColReuseCont: + return warp_col_id() * kColStride * BaseShape::kCols * + Shared::kColStride; + default: + assert(false && "Not implemented yet."); + return -1; } + } - private: - using Shared = Shared_; - using WarpLayout = WarpLayout_; - using BaseShape = BaseShape_; - static constexpr WarpReuse kMode = kMode_; + private: + using Shared = Shared_; + using WarpLayout = WarpLayout_; + using BaseShape = BaseShape_; + static constexpr WarpReuse kMode = kMode_; - constexpr static int kTilePerRow = Shared::kRows / BaseShape::kRows; - constexpr static int kTilePerCol = Shared::kCols / BaseShape::kCols; + constexpr static int kTilePerRow = Shared::kRows / BaseShape::kRows; + constexpr static int kTilePerCol = Shared::kCols / BaseShape::kCols; - constexpr static int kRowStride = kTilePerRow / tl::num_rows; - constexpr static int kColStride = kTilePerCol / tl::num_cols; + constexpr static int kRowStride = kTilePerRow / tl::num_rows; + constexpr static int kColStride = kTilePerCol / tl::num_cols; }; } // namespace tilefusion::cell::copy::warp diff --git a/include/cell/mask.hpp b/include/cell/mask.hpp index af6ad6a6..2f2a0246 100644 --- a/include/cell/mask.hpp +++ b/include/cell/mask.hpp @@ -7,9 +7,9 @@ namespace tilefusion::cell { enum class MaskMode { - kNone = 0U, // No mask - kCausal = 1U, // Causal mask - kCustom = 2U, // Custom mask + kNone = 0U, // No mask + kCausal = 1U, // Causal mask + kCustom = 2U, // Custom mask }; template struct ApplyMask { - using Element = RegTile::DType::DType; - static_assert(WarpLayout::kCols == 1, "WarpLayout::kCols must be 1"); - // static_assert(std::is_same_v || - // std::is_same_v, - // "Element must be float or half"); - // TODO(KuangjuX): support half precision. - static_assert(std::is_same_v, "Element must be float"); - - // Each thread processes 2 consecutive elements in the register tile. - static constexpr int kThreadStride = 2; - // Each 4 threads as a group to process a row of the register tile. - static constexpr int kThreadGroupSize = 4; - - static constexpr int kRegTileRows = RegTile::kRows; - static constexpr int kRegTileCols = RegTile::kCols; - static constexpr int kSubtileRows = RegTile::DType::kRows; - static constexpr int kSubtileCols = RegTile::DType::kCols; - - static constexpr int kWarpRows = WarpLayout::kRows; - static constexpr int kWarpCols = WarpLayout::kCols; - - // A BaseTile is a 16x16 tile by default. - static constexpr int kBaseShapeRows = BaseShape::kRows; - static constexpr int kBaseShapeCols = BaseShape::kCols; - - template - DEVICE void operator()(RegTile& tile, const int row_offset, - const int col_offset, Element mask_value) { - // Compute the column index offset for the current thread. - const int col_idx_offset = - col_offset + get_thread_col_offset() + get_warp_col_offset(); - - // Compute the row index offset for the current thread. - const int row_idx_offset = - row_offset + get_thread_row_offset() + get_warp_row_offset(); + using Element = RegTile::DType::DType; + static_assert(WarpLayout::kCols == 1, "WarpLayout::kCols must be 1"); + // static_assert(std::is_same_v || + // std::is_same_v, + // "Element must be float or half"); + // TODO(KuangjuX): support half precision. + static_assert(std::is_same_v, "Element must be float"); + + // Each thread processes 2 consecutive elements in the register tile. + static constexpr int kThreadStride = 2; + // Each 4 threads as a group to process a row of the register tile. + static constexpr int kThreadGroupSize = 4; + + static constexpr int kRegTileRows = RegTile::kRows; + static constexpr int kRegTileCols = RegTile::kCols; + static constexpr int kSubtileRows = RegTile::DType::kRows; + static constexpr int kSubtileCols = RegTile::DType::kCols; + + static constexpr int kWarpRows = WarpLayout::kRows; + static constexpr int kWarpCols = WarpLayout::kCols; + + // A BaseTile is a 16x16 tile by default. + static constexpr int kBaseShapeRows = BaseShape::kRows; + static constexpr int kBaseShapeCols = BaseShape::kCols; + + template + DEVICE void operator()(RegTile& tile, const int row_offset, + const int col_offset, Element mask_value) { + // Compute the column index offset for the current thread. + const int col_idx_offset = + col_offset + get_thread_col_offset() + get_warp_col_offset(); + + // Compute the row index offset for the current thread. + const int row_idx_offset = + row_offset + get_thread_row_offset() + get_warp_row_offset(); #pragma unroll - for (int m = 0; m < kRegTileRows; ++m) { + for (int m = 0; m < kRegTileRows; ++m) { #pragma unroll - for (int n = 0; n < kRegTileCols; ++n) { + for (int n = 0; n < kRegTileCols; ++n) { #pragma unroll - for (int i = 0; i < kSubtileRows; ++i) { + for (int i = 0; i < kSubtileRows; ++i) { #pragma unroll - for (int j = 0; j < kSubtileCols; ++j) { - /** - * In BaseTile Layout(ThreadIdx.x = 0): - * 0.00, 1.00, 128.00, 129.00 - * 8.00, 9.00, 136.00, 137.00 - * - * In BaseTile Layout(ThreadIdx.x = 1): - * 2.00, 3.00, 130.00, 131.00 - * 10.00, 11.00, 138.00, 139.00 - */ - - const int col_idx = - col_idx_offset + n * kBaseShapeCols + (j % 2) + - i * (kThreadGroupSize * kThreadStride); - const int row_idx = row_idx_offset + - m * kBaseShapeRows + - (j / 2) * (kBaseShapeRows / 2); - - if (col_idx > row_idx) { - tile(m, n)(i, j) = mask_value; - } - } - } + for (int j = 0; j < kSubtileCols; ++j) { + /** + * In BaseTile Layout(ThreadIdx.x = 0): + * 0.00, 1.00, 128.00, 129.00 + * 8.00, 9.00, 136.00, 137.00 + * + * In BaseTile Layout(ThreadIdx.x = 1): + * 2.00, 3.00, 130.00, 131.00 + * 10.00, 11.00, 138.00, 139.00 + */ + + const int col_idx = col_idx_offset + n * kBaseShapeCols + (j % 2) + + i * (kThreadGroupSize * kThreadStride); + const int row_idx = row_idx_offset + m * kBaseShapeRows + + (j / 2) * (kBaseShapeRows / 2); + + if (col_idx > row_idx) { + tile(m, n)(i, j) = mask_value; } + } } + } } + } - private: - DEVICE int lane_id() { return threadIdx.x % WARP_SIZE; } + private: + DEVICE int lane_id() { return threadIdx.x % WARP_SIZE; } - DEVICE int get_warp_row_offset() { - return get_warp_row_id() * get_warp_row_stride(); - } + DEVICE int get_warp_row_offset() { + return get_warp_row_id() * get_warp_row_stride(); + } - DEVICE int get_warp_col_offset() { - return get_warp_col_id() * get_warp_col_stride(); - } + DEVICE int get_warp_col_offset() { + return get_warp_col_id() * get_warp_col_stride(); + } - DEVICE int get_thread_row_offset() { return lane_id() / kThreadGroupSize; } + DEVICE int get_thread_row_offset() { return lane_id() / kThreadGroupSize; } - DEVICE int get_thread_col_offset() { - return (lane_id() % kThreadGroupSize) * kThreadStride; - } + DEVICE int get_thread_col_offset() { + return (lane_id() % kThreadGroupSize) * kThreadStride; + } - DEVICE int get_warp_row_stride() { - return RegTile::kRows * BaseShape::kRows; - } + DEVICE int get_warp_row_stride() { return RegTile::kRows * BaseShape::kRows; } - DEVICE int get_warp_col_stride() { - return RegTile::kCols * BaseShape::kCols; - } + DEVICE int get_warp_col_stride() { return RegTile::kCols * BaseShape::kCols; } - DEVICE int get_warp_row_id() { - return copy::warp::warp_row_id(); - } + DEVICE int get_warp_row_id() { return copy::warp::warp_row_id(); } - DEVICE int get_warp_col_id() { - return copy::warp::warp_col_id(); - } + DEVICE int get_warp_col_id() { return copy::warp::warp_col_id(); } }; } // namespace tilefusion::cell diff --git a/include/cell/pipeline.hpp b/include/cell/pipeline.hpp index 0089d228..794b7a5b 100644 --- a/include/cell/pipeline.hpp +++ b/include/cell/pipeline.hpp @@ -49,87 +49,87 @@ template struct Pipeline { - public: - // The number of iterations for the body kernel. - static constexpr int Iterations = Iterations_; - - DEVICE Pipeline(const Element* src_ptr, Element* dst_ptr) - : src_tile(SrcTile(src_ptr)), - tile_iter(TileIterator(src_tile.data())), - data_ptr(0), - cur_stages(0) { - // initialize the circular buffer - for (int i = 0; i < NUM_STAGES; i++) { - cyc_buffer[i] = DstTile(dst_ptr + i * DstTile::kNumel); - } + public: + // The number of iterations for the body kernel. + static constexpr int Iterations = Iterations_; + + DEVICE Pipeline(const Element* src_ptr, Element* dst_ptr) + : src_tile(SrcTile(src_ptr)), + tile_iter(TileIterator(src_tile.data())), + data_ptr(0), + cur_stages(0) { + // initialize the circular buffer + for (int i = 0; i < NUM_STAGES; i++) { + cyc_buffer[i] = DstTile(dst_ptr + i * DstTile::kNumel); } - - DEVICE Pipeline(SrcTile src_tile, DstTile dst_tiles[]) - : src_tile(src_tile), - tile_iter(TileIterator(src_tile.data())), - data_ptr(0), - cur_stages(0) { - for (int i = 0; i < NUM_STAGES; i++) { - cyc_buffer[i] = dst_tiles[i]; - } - } - - /** - * @brief Reset the source tile. - * @param src_ptr The pointer to the source tile. - */ - DEVICE void reset_src_tile(const Element* src_ptr) { - src_tile = SrcTile(src_ptr); - tile_iter = TileIterator(src_tile.data()); - data_ptr = 0; - } - - /** - * @brief Commit the copy operation. - */ - DEVICE void commit() { - copy(tile_iter(data_ptr), cyc_buffer[cur_stages % NUM_STAGES]); - data_ptr++; - cur_stages++; - } - - DEVICE const Element* get_dst_ptr_by_index(int index) const { - return cyc_buffer[index % NUM_STAGES].data(); + } + + DEVICE Pipeline(SrcTile src_tile, DstTile dst_tiles[]) + : src_tile(src_tile), + tile_iter(TileIterator(src_tile.data())), + data_ptr(0), + cur_stages(0) { + for (int i = 0; i < NUM_STAGES; i++) { + cyc_buffer[i] = dst_tiles[i]; } - - DEVICE const DstTile& get_dst_tile_by_index(int index) const { - return cyc_buffer[index % NUM_STAGES]; + } + + /** + * @brief Reset the source tile. + * @param src_ptr The pointer to the source tile. + */ + DEVICE void reset_src_tile(const Element* src_ptr) { + src_tile = SrcTile(src_ptr); + tile_iter = TileIterator(src_tile.data()); + data_ptr = 0; + } + + /** + * @brief Commit the copy operation. + */ + DEVICE void commit() { + copy(tile_iter(data_ptr), cyc_buffer[cur_stages % NUM_STAGES]); + data_ptr++; + cur_stages++; + } + + DEVICE const Element* get_dst_ptr_by_index(int index) const { + return cyc_buffer[index % NUM_STAGES].data(); + } + + DEVICE const DstTile& get_dst_tile_by_index(int index) const { + return cyc_buffer[index % NUM_STAGES]; + } + + DEVICE const Element* get_prev_dst() const { + return cyc_buffer[(cur_stages - 2) % NUM_STAGES].data(); + } + + DEVICE const Element* get_cur_dst() const { + return cyc_buffer[(cur_stages - 1) % NUM_STAGES].data(); + } + + /** + * @brief Dump the destination tile value. + * @param index The index of the destination tile. + */ + DEVICE void dump_dst_tile_value(int index) { + if (thread0()) { + printf("data[%d]:\n", index); + cyc_buffer[index % NUM_STAGES].dump_value(); } - - DEVICE const Element* get_prev_dst() const { - return cyc_buffer[(cur_stages - 2) % NUM_STAGES].data(); - } - - DEVICE const Element* get_cur_dst() const { - return cyc_buffer[(cur_stages - 1) % NUM_STAGES].data(); - } - - /** - * @brief Dump the destination tile value. - * @param index The index of the destination tile. - */ - DEVICE void dump_dst_tile_value(int index) { - if (thread0()) { - printf("data[%d]:\n", index); - cyc_buffer[index % NUM_STAGES].dump_value(); - } - } - - private: - static constexpr int kNumStages = NUM_STAGES; - - int data_ptr; - int cur_stages; - SrcTile src_tile; - // In multistage pipeline, the destination tile has circular buffer with a - // size of `NUM_STAGES`. - DstTile cyc_buffer[NUM_STAGES]; - TileIterator tile_iter; - Copy copy; + } + + private: + static constexpr int kNumStages = NUM_STAGES; + + int data_ptr; + int cur_stages; + SrcTile src_tile; + // In multistage pipeline, the destination tile has circular buffer with a + // size of `NUM_STAGES`. + DstTile cyc_buffer[NUM_STAGES]; + TileIterator tile_iter; + Copy copy; }; } // namespace tilefusion::cell diff --git a/include/cell/warp.hpp b/include/cell/warp.hpp index 75a72a1f..bb93bfcd 100644 --- a/include/cell/warp.hpp +++ b/include/cell/warp.hpp @@ -15,7 +15,7 @@ static constexpr uint32_t MASK_ALL = 0xFFFFFFFF; */ template DEVICE Element shuffle_sync(uint32_t mask, Element value, int src_lane) { - return __shfl_sync(mask, value, src_lane); + return __shfl_sync(mask, value, src_lane); } /** @@ -25,7 +25,7 @@ DEVICE Element shuffle_sync(uint32_t mask, Element value, int src_lane) { */ template DEVICE Element shuffle_down_sync(uint32_t mask, Element value, int delta) { - return __shfl_down_sync(mask, value, delta); + return __shfl_down_sync(mask, value, delta); } } // namespace tilefusion::cell diff --git a/include/config.hpp b/include/config.hpp index 271d3182..2ec592f6 100644 --- a/include/config.hpp +++ b/include/config.hpp @@ -4,26 +4,26 @@ #pragma once #if defined(__CUDA_ARCH__) - #define HOST_DEVICE __forceinline__ __host__ __device__ - #define DEVICE __forceinline__ __device__ - #define HOST __forceinline__ __host__ + #define HOST_DEVICE __forceinline__ __host__ __device__ + #define DEVICE __forceinline__ __device__ + #define HOST __forceinline__ __host__ #else - #define HOST_DEVICE inline - #define DEVICE inline - #define HOST inline + #define HOST_DEVICE inline + #define DEVICE inline + #define HOST inline #endif #if defined(__CUDACC__) - #define WARP_SIZE 32 + #define WARP_SIZE 32 #endif #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) - #define CP_ASYNC_SM80_ENABLED + #define CP_ASYNC_SM80_ENABLED #endif // FP8 support requires CUDA 11.8+ AND Ada Lovelace (8.9+) or Hopper (9.0+) // architecture. The hardware support is detected at build time by CMake. #if defined(CUDA_FP8_HARDWARE_AVAILABLE) && CUDA_FP8_HARDWARE_AVAILABLE == 1 - #include - #define CUDA_FP8_AVAILABLE 1 + #include + #define CUDA_FP8_AVAILABLE 1 #endif diff --git a/include/cuda_utils.hpp b/include/cuda_utils.hpp index 231c0de0..4bb83482 100644 --- a/include/cuda_utils.hpp +++ b/include/cuda_utils.hpp @@ -22,49 +22,49 @@ inline int ceil_div(int a, int b) { return (a + b - 1) / b; } const char* cublasGetErrorString(cublasStatus_t status); inline void __cudaCheck(const cudaError err, const char* file, int line) { - if (err != cudaSuccess) { - fprintf(stderr, "%s(%d): CUDA error: %s.\n", file, line, - cudaGetErrorString(err)); - exit(EXIT_FAILURE); - } + if (err != cudaSuccess) { + fprintf(stderr, "%s(%d): CUDA error: %s.\n", file, line, + cudaGetErrorString(err)); + exit(EXIT_FAILURE); + } } #define CUDA_CHECK(call) __cudaCheck(call, __FILE__, __LINE__) inline void __cublasCheck(const cublasStatus_t err, const char* file, int line) { - if (err != CUBLAS_STATUS_SUCCESS) { - fprintf(stderr, "%s(%d): Cublas error: %s.\n", file, line, - cublasGetErrorString(err)); - exit(EXIT_FAILURE); - } + if (err != CUBLAS_STATUS_SUCCESS) { + fprintf(stderr, "%s(%d): Cublas error: %s.\n", file, line, + cublasGetErrorString(err)); + exit(EXIT_FAILURE); + } } #define CUBLAS_CHECK(call) __cublasCheck(call, __FILE__, __LINE__) -#define CUDA_DRIVER_CHECK(call) \ - do { \ - CUresult result = call; \ - if (result != CUDA_SUCCESS) { \ - const char* error_string; \ - cuGetErrorString(result, &error_string); \ - std::stringstream err; \ - err << "CUDA error: " << error_string << " (" << result << ") at " \ - << __FILE__ << ":" << __LINE__; \ - LOG(ERROR) << err.str(); \ - throw std::runtime_error(err.str()); \ - } \ - } while (0) +#define CUDA_DRIVER_CHECK(call) \ + do { \ + CUresult result = call; \ + if (result != CUDA_SUCCESS) { \ + const char* error_string; \ + cuGetErrorString(result, &error_string); \ + std::stringstream err; \ + err << "CUDA error: " << error_string << " (" << result << ") at " \ + << __FILE__ << ":" << __LINE__; \ + LOG(ERROR) << err.str(); \ + throw std::runtime_error(err.str()); \ + } \ + } while (0) inline void check_gpu_memory() { - size_t free_byte; - size_t total_byte; - CUDA_CHECK(cudaMemGetInfo(&free_byte, &total_byte)); + size_t free_byte; + size_t total_byte; + CUDA_CHECK(cudaMemGetInfo(&free_byte, &total_byte)); - double free_db = (double)free_byte; - double total_db = (double)total_byte; - double used_db = total_db - free_db; - printf("GPU memory usage: used = %f MB, free = %f MB, total = %f MB\n", - used_db / 1024.0 / 1024.0, free_db / 1024.0 / 1024.0, - total_db / 1024.0 / 1024.0); + double free_db = (double)free_byte; + double total_db = (double)total_byte; + double used_db = total_db - free_db; + printf("GPU memory usage: used = %f MB, free = %f MB, total = %f MB\n", + used_db / 1024.0 / 1024.0, free_db / 1024.0 / 1024.0, + total_db / 1024.0 / 1024.0); } } // namespace tilefusion diff --git a/include/jit/common.hpp b/include/jit/common.hpp index 63a89457..1054d67e 100644 --- a/include/jit/common.hpp +++ b/include/jit/common.hpp @@ -8,20 +8,20 @@ namespace tilefusion::jit { template static constexpr const char* get_type_string() { - if constexpr (std::is_same_v) { - return "float"; - } else if constexpr (std::is_same_v) { - return "double"; - } else if constexpr (std::is_same_v) { - return "int"; - } else if constexpr (std::is_same_v) { - return "__half"; - } else if constexpr (std::is_same_v) { - return "__bfloat16"; - } else { - // Makes the assertion dependent on the template parameter - // Only triggers when an unsupported type is actually used. - static_assert(sizeof(DType) == 0, "Unsupported data type"); - } + if constexpr (std::is_same_v) { + return "float"; + } else if constexpr (std::is_same_v) { + return "double"; + } else if constexpr (std::is_same_v) { + return "int"; + } else if constexpr (std::is_same_v) { + return "__half"; + } else if constexpr (std::is_same_v) { + return "__bfloat16"; + } else { + // Makes the assertion dependent on the template parameter + // Only triggers when an unsupported type is actually used. + static_assert(sizeof(DType) == 0, "Unsupported data type"); + } } } // namespace tilefusion::jit diff --git a/include/jit/compiler.hpp b/include/jit/compiler.hpp index 153eba41..d979fac0 100644 --- a/include/jit/compiler.hpp +++ b/include/jit/compiler.hpp @@ -19,65 +19,65 @@ namespace tilefusion::jit { * This class allows for runtime compilation of CUDA kernels using NVCC. */ class JitCompiler { - public: - static JitCompiler& instance(); + public: + static JitCompiler& instance(); - /** - * Compiles a CUDA kernel at runtime and returns a function pointer to the - * kernel. - * - * @param kernel_name The name of the kernel function to compile - * @param cuda_source The CUDA source code containing the kernel - * @param include_paths Additional include paths to pass to NVCC - * @param compile_args Additional compiler arguments to pass to NVCC - * @return Function pointer to the compiled kernel or nullptr if compilation - * fails - */ - CUfunction compile_kernel( - const std::string& kernel_name, const std::string& cuda_source, - const std::vector& include_paths = {}, - const std::vector& compile_args = {}); + /** + * Compiles a CUDA kernel at runtime and returns a function pointer to the + * kernel. + * + * @param kernel_name The name of the kernel function to compile + * @param cuda_source The CUDA source code containing the kernel + * @param include_paths Additional include paths to pass to NVCC + * @param compile_args Additional compiler arguments to pass to NVCC + * @return Function pointer to the compiled kernel or nullptr if compilation + * fails + */ + CUfunction compile_kernel(const std::string& kernel_name, + const std::string& cuda_source, + const std::vector& include_paths = {}, + const std::vector& compile_args = {}); - /** - * Gets a previously compiled kernel or compiles it if it doesn't exist. - * - * @param kernel_name The name of the kernel function - * @param cuda_source The CUDA source code - * @param include_paths Additional include paths - * @param compile_args Additional compiler arguments - * @return Function pointer to the compiled kernel - */ - CUfunction get_or_compile_kernel( - const std::string& kernel_name, const std::string& cuda_source, - const std::vector& include_paths = {}, - const std::vector& compile_args = {}); + /** + * Gets a previously compiled kernel or compiles it if it doesn't exist. + * + * @param kernel_name The name of the kernel function + * @param cuda_source The CUDA source code + * @param include_paths Additional include paths + * @param compile_args Additional compiler arguments + * @return Function pointer to the compiled kernel + */ + CUfunction get_or_compile_kernel( + const std::string& kernel_name, const std::string& cuda_source, + const std::vector& include_paths = {}, + const std::vector& compile_args = {}); - /** - * Clears the cache of compiled kernels. - */ - void clear_cache(); + /** + * Clears the cache of compiled kernels. + */ + void clear_cache(); - private: - JitCompiler(); - ~JitCompiler(); + private: + JitCompiler(); + ~JitCompiler(); - JitCompiler(const JitCompiler&) = delete; - JitCompiler& operator=(const JitCompiler&) = delete; + JitCompiler(const JitCompiler&) = delete; + JitCompiler& operator=(const JitCompiler&) = delete; - std::string compile_to_ptx(const std::string& cuda_source, - const std::vector& include_paths, - const std::vector& compile_args); + std::string compile_to_ptx(const std::string& cuda_source, + const std::vector& include_paths, + const std::vector& compile_args); - CUfunction load_ptx_and_get_kernel(const std::string& ptx, - const std::string& kernel_name); + CUfunction load_ptx_and_get_kernel(const std::string& ptx, + const std::string& kernel_name); - std::string write_to_temp_file(const std::string& content, - const std::string& extension); + std::string write_to_temp_file(const std::string& content, + const std::string& extension); - CUcontext cuda_context_; - std::unordered_map module_cache_; - std::unordered_map kernel_cache_; - std::mutex mutex_; + CUcontext cuda_context_; + std::unordered_map module_cache_; + std::unordered_map kernel_cache_; + std::mutex mutex_; }; } // namespace tilefusion::jit diff --git a/include/jit/config.hpp b/include/jit/config.hpp index 7fc59326..c4e76067 100644 --- a/include/jit/config.hpp +++ b/include/jit/config.hpp @@ -6,25 +6,25 @@ namespace tilefusion::jit { // Default JIT include paths inline std::vector get_default_include_paths() { - std::string current_file = __FILE__; - std::string project_root = - current_file.substr(0, current_file.find("/include/")); - return {project_root + "/include", - project_root + "/3rd-party/cutlass/include"}; + std::string current_file = __FILE__; + std::string project_root = + current_file.substr(0, current_file.find("/include/")); + return {project_root + "/include", + project_root + "/3rd-party/cutlass/include"}; } // Default JIT compilation arguments inline std::vector get_default_compile_args() { - return {"-O3", - "-std=c++20", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "-DNDEBUG", - "-Xcompiler", - "-fPIC", - "-Xcompiler", - "-Wall", - "-Xcompiler", - "-Wextra"}; + return {"-O3", + "-std=c++20", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "-DNDEBUG", + "-Xcompiler", + "-fPIC", + "-Xcompiler", + "-Wall", + "-Xcompiler", + "-Wextra"}; } } // namespace tilefusion::jit diff --git a/include/kernel_registry.hpp b/include/kernel_registry.hpp index b9139567..d22315b6 100644 --- a/include/kernel_registry.hpp +++ b/include/kernel_registry.hpp @@ -19,73 +19,72 @@ namespace tilefusion { // import the TileFusion module in Python. #define TILEFUSION_EXPORT extern "C" __attribute__((visibility("default"))) -#define REGISTER_OP(name, schema, func) \ - namespace { \ - static bool name##_registered = []() { \ - tilefusion::KernelRegistry::instance().add_kernel(#name, schema, \ - func); \ - return true; \ - }(); \ - } +#define REGISTER_OP(name, schema, func) \ + namespace { \ + static bool name##_registered = []() { \ + tilefusion::KernelRegistry::instance().add_kernel(#name, schema, func); \ + return true; \ + }(); \ + } template struct KernelTraits { - static void register_impl(torch::Library& m, const char* name, - KernelFunc func) { - m.impl(name, torch::DispatchKey::CUDA, func); - } + static void register_impl(torch::Library& m, const char* name, + KernelFunc func) { + m.impl(name, torch::DispatchKey::CUDA, func); + } }; struct KernelInfo { - const char* name; - const char* schema; - void* func; - std::type_index type; + const char* name; + const char* schema; + void* func; + std::type_index type; }; class KernelRegistry { - public: - static KernelRegistry& instance() { - static KernelRegistry registry; - return registry; - } + public: + static KernelRegistry& instance() { + static KernelRegistry registry; + return registry; + } - template - void add_kernel(const char* name, const char* schema, KernelFunc func) { - kernels_.push_back( - {name, schema, reinterpret_cast(func), typeid(KernelFunc)}); - register_kernel_type(); - } + template + void add_kernel(const char* name, const char* schema, KernelFunc func) { + kernels_.push_back( + {name, schema, reinterpret_cast(func), typeid(KernelFunc)}); + register_kernel_type(); + } - void register_with_torch(torch::Library& m) const { - for (const auto& kernel : kernels_) { - m.def(kernel.schema); - } + void register_with_torch(torch::Library& m) const { + for (const auto& kernel : kernels_) { + m.def(kernel.schema); } + } - void register_implementations(torch::Library& m) const { - for (const auto& kernel : kernels_) { - auto it = registration_functions_.find(kernel.type); - if (it != registration_functions_.end()) { - it->second(m, kernel.name, kernel.func); - } - } + void register_implementations(torch::Library& m) const { + for (const auto& kernel : kernels_) { + auto it = registration_functions_.find(kernel.type); + if (it != registration_functions_.end()) { + it->second(m, kernel.name, kernel.func); + } } + } - private: - template - void register_kernel_type() { - registration_functions_[typeid(KernelFunc)] = - [](torch::Library& m, const char* name, void* func) { - KernelTraits::register_impl( - m, name, reinterpret_cast(func)); - }; - } + private: + template + void register_kernel_type() { + registration_functions_[typeid(KernelFunc)] = + [](torch::Library& m, const char* name, void* func) { + KernelTraits::register_impl( + m, name, reinterpret_cast(func)); + }; + } - std::vector kernels_; - std::unordered_map> - registration_functions_; + std::vector kernels_; + std::unordered_map> + registration_functions_; }; } // namespace tilefusion diff --git a/include/kernels/common.hpp b/include/kernels/common.hpp index 5eb307bb..04aadb8e 100644 --- a/include/kernels/common.hpp +++ b/include/kernels/common.hpp @@ -4,11 +4,11 @@ #pragma once #define CHECK_CUDA(x) \ - TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor.") + TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor.") #define CHECK_CONTIGUOUS(x) \ - TORCH_CHECK(x.is_contiguous(), #x " must be contiguous.") + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous.") #define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_CONTIGUOUS(x) + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) diff --git a/include/kernels/dispatch_macros.hpp b/include/kernels/dispatch_macros.hpp index b094675c..5084f42d 100644 --- a/include/kernels/dispatch_macros.hpp +++ b/include/kernels/dispatch_macros.hpp @@ -4,21 +4,20 @@ #pragma once #define DISPATCH_TYPE_CASE(TYPE, NV_TYPE, ...) \ - case TYPE: { \ - using scalar_t = NV_TYPE; \ - return __VA_ARGS__(); \ - } + case TYPE: { \ + using scalar_t = NV_TYPE; \ + return __VA_ARGS__(); \ + } -#define TILEFUSION_DISPATCH_ALL_TYPES(TYPE, ...) \ - c10::ScalarType _type = TYPE; \ - [&] { \ - switch (_type) { \ - DISPATCH_TYPE_CASE(c10::ScalarType::Float, float, __VA_ARGS__) \ - DISPATCH_TYPE_CASE(c10::ScalarType::Half, __half, __VA_ARGS__) \ - DISPATCH_TYPE_CASE(c10::ScalarType::BFloat16, __bfloat16, \ - __VA_ARGS__) \ - default: \ - AT_ERROR("Dispatch is not implemented for type: '", \ - toString(_type), "'"); \ - } \ - }(); +#define TILEFUSION_DISPATCH_ALL_TYPES(TYPE, ...) \ + c10::ScalarType _type = TYPE; \ + [&] { \ + switch (_type) { \ + DISPATCH_TYPE_CASE(c10::ScalarType::Float, float, __VA_ARGS__) \ + DISPATCH_TYPE_CASE(c10::ScalarType::Half, __half, __VA_ARGS__) \ + DISPATCH_TYPE_CASE(c10::ScalarType::BFloat16, __bfloat16, __VA_ARGS__) \ + default: \ + AT_ERROR("Dispatch is not implemented for type: '", toString(_type), \ + "'"); \ + } \ + }(); diff --git a/include/kernels/flash_attention_device.cuh b/include/kernels/flash_attention_device.cuh index bac1d976..60122635 100644 --- a/include/kernels/flash_attention_device.cuh +++ b/include/kernels/flash_attention_device.cuh @@ -17,118 +17,114 @@ template struct FlashAttentionTraits { - /// constants - static constexpr int kWarpPerRow = tl::num_rows; - static constexpr int kWarpPerCol = tl::num_cols; - static_assert(kWarpPerCol == 1, - "warps must be arranged as a column vector."); - static constexpr int kThreads = tl::get_numel * 32; - - static constexpr int kSharedAccess = 64; - - using BaseShape = BaseTileShape; - - static constexpr int kM = dim_size<0, WholeShape>; // query length - static constexpr int kN = dim_size<1, WholeShape>; // key/value length - static constexpr int kK = dim_size<2, WholeShape>; // query/key hidden dim - static constexpr int kP = dim_size<3, WholeShape>; // value hidden dim - - static constexpr int kTM = dim_size<0, CtaTileShape>; - static constexpr int kTN = dim_size<1, CtaTileShape>; - static constexpr int kTK = dim_size<2, CtaTileShape>; - static constexpr int kTP = dim_size<3, CtaTileShape>; - - static constexpr double kSoftmaxScale = kSoftmaxScale_; - static constexpr bool kIsCausal = kIsCausal_; - - // query - using GlobalQ = GlobalTile>; - using GIteratorQ = GTileIterator>; - using SharedQ = - SharedTile, true, kSharedAccess>; - - static constexpr int kQMs = kTM / kWarpPerRow / BaseShape::kRows; - static constexpr int kQKs = kTK / BaseShape::kCols; - using RegQ = RegTile, tl::RowMajor>; - - using SharedQLoader = GlobalToSharedLoader; - using RegQLoader = - SharedToRegLoader; - - // key - using GlobalK = GlobalTile>; - using GIteratorK = GTileIterator>; - using SharedK = - SharedTile, true, kSharedAccess>; - - static constexpr int kKKs = kTK / BaseShape::kRows; - static constexpr int kKNs = kTN / kWarpPerCol / BaseShape::kCols; - using RegK = RegTile, tl::ColMajor>; - - using SharedKLoader = GlobalToSharedLoader; - using RegKLoader = - SharedToRegLoader; - - // value - using GlobalV = GlobalTile>; - using GIteratorV = GTileIterator>; - using SharedV = - SharedTile, true, kSharedAccess>; - - static constexpr int kVNs = kTN / BaseShape::kRows; - static constexpr int kVPs = kTP / kWarpPerCol / BaseShape::kCols; - using RegV = RegTile, tl::ColMajor>; - - using SharedVLoader = GlobalToSharedLoader; - using RegVLoader = - SharedToRegLoader; - - // output - using GlobalO = GlobalTile>; - - static constexpr int kOMs = kTM / kWarpPerRow / BaseShape::kRows; - static constexpr int kOPs = kTP / kWarpPerCol / BaseShape::kCols; - using RegO = RegTile, tl::RowMajor>; - using RegOCast = - RegTile, tl::RowMajor>; - using OStorer = RegToGlobalStorer; - - // Reg Acc - static constexpr int kAccMs = kTM / kWarpPerRow / BaseShape::kRows; - static constexpr int kAccNs = kTN / kWarpPerCol / BaseShape::kCols; - using RegAcc = - RegTile, tl::RowMajor>; - using RegAccCast = - RegTile, tl::RowMajor>; - - // Convert the accumulator to half - using ConvertAcc = RegTileConvert; - using ConvertO = RegTileConvert; - - using RegVec = RegTile>; - - using CopyVec = BaseTileCopy; - using RowMax = MaxReduce; - - using RowSum = SumReduce; - - using BroadcastSub = - BroadcastSub; - using BroadcastMul = BroadcastMul; - using BroadcastDiv = BroadcastDiv; - - using BlockExp = RegTileExp; - using BlockAdd = RegTileAdd; - - using VecMax = BaseTileMax; - using VecAdd = BaseTileAdd; - using VecSub = BaseTileSub; - using VecMul = BaseTileMul; - using VecExp = BaseTileExp; - - using ApplyMask = - ApplyMask; - using ApplyScoreScale = BroadcastScalarMul; + /// constants + static constexpr int kWarpPerRow = tl::num_rows; + static constexpr int kWarpPerCol = tl::num_cols; + static_assert(kWarpPerCol == 1, "warps must be arranged as a column vector."); + static constexpr int kThreads = tl::get_numel * 32; + + static constexpr int kSharedAccess = 64; + + using BaseShape = BaseTileShape; + + static constexpr int kM = dim_size<0, WholeShape>; // query length + static constexpr int kN = dim_size<1, WholeShape>; // key/value length + static constexpr int kK = dim_size<2, WholeShape>; // query/key hidden dim + static constexpr int kP = dim_size<3, WholeShape>; // value hidden dim + + static constexpr int kTM = dim_size<0, CtaTileShape>; + static constexpr int kTN = dim_size<1, CtaTileShape>; + static constexpr int kTK = dim_size<2, CtaTileShape>; + static constexpr int kTP = dim_size<3, CtaTileShape>; + + static constexpr double kSoftmaxScale = kSoftmaxScale_; + static constexpr bool kIsCausal = kIsCausal_; + + // query + using GlobalQ = GlobalTile>; + using GIteratorQ = GTileIterator>; + using SharedQ = + SharedTile, true, kSharedAccess>; + + static constexpr int kQMs = kTM / kWarpPerRow / BaseShape::kRows; + static constexpr int kQKs = kTK / BaseShape::kCols; + using RegQ = RegTile, tl::RowMajor>; + + using SharedQLoader = GlobalToSharedLoader; + using RegQLoader = + SharedToRegLoader; + + // key + using GlobalK = GlobalTile>; + using GIteratorK = GTileIterator>; + using SharedK = + SharedTile, true, kSharedAccess>; + + static constexpr int kKKs = kTK / BaseShape::kRows; + static constexpr int kKNs = kTN / kWarpPerCol / BaseShape::kCols; + using RegK = RegTile, tl::ColMajor>; + + using SharedKLoader = GlobalToSharedLoader; + using RegKLoader = + SharedToRegLoader; + + // value + using GlobalV = GlobalTile>; + using GIteratorV = GTileIterator>; + using SharedV = + SharedTile, true, kSharedAccess>; + + static constexpr int kVNs = kTN / BaseShape::kRows; + static constexpr int kVPs = kTP / kWarpPerCol / BaseShape::kCols; + using RegV = RegTile, tl::ColMajor>; + + using SharedVLoader = GlobalToSharedLoader; + using RegVLoader = + SharedToRegLoader; + + // output + using GlobalO = GlobalTile>; + + static constexpr int kOMs = kTM / kWarpPerRow / BaseShape::kRows; + static constexpr int kOPs = kTP / kWarpPerCol / BaseShape::kCols; + using RegO = RegTile, tl::RowMajor>; + using RegOCast = RegTile, tl::RowMajor>; + using OStorer = RegToGlobalStorer; + + // Reg Acc + static constexpr int kAccMs = kTM / kWarpPerRow / BaseShape::kRows; + static constexpr int kAccNs = kTN / kWarpPerCol / BaseShape::kCols; + using RegAcc = + RegTile, tl::RowMajor>; + using RegAccCast = + RegTile, tl::RowMajor>; + + // Convert the accumulator to half + using ConvertAcc = RegTileConvert; + using ConvertO = RegTileConvert; + + using RegVec = RegTile>; + + using CopyVec = BaseTileCopy; + using RowMax = MaxReduce; + + using RowSum = SumReduce; + + using BroadcastSub = BroadcastSub; + using BroadcastMul = BroadcastMul; + using BroadcastDiv = BroadcastDiv; + + using BlockExp = RegTileExp; + using BlockAdd = RegTileAdd; + + using VecMax = BaseTileMax; + using VecAdd = BaseTileAdd; + using VecSub = BaseTileSub; + using VecMul = BaseTileMul; + using VecExp = BaseTileExp; + + using ApplyMask = ApplyMask; + using ApplyScoreScale = BroadcastScalarMul; }; template (shared_buf); + /// declare shared memory buffer + extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[]; + auto* shm = reinterpret_cast(shared_buf); - InType* sQ_ptr = shm; - InType* sK_ptr = sQ_ptr + KeTraits::SharedQ::kNumel; - InType* sV_ptr = sK_ptr + KeTraits::SharedK::kNumel; + InType* sQ_ptr = shm; + InType* sK_ptr = sQ_ptr + KeTraits::SharedQ::kNumel; + InType* sV_ptr = sK_ptr + KeTraits::SharedK::kNumel; - typename KeTraits::GIteratorQ gQs(Q); - typename KeTraits::SharedQ sQ(sQ_ptr); - typename KeTraits::RegQ rQ; + typename KeTraits::GIteratorQ gQs(Q); + typename KeTraits::SharedQ sQ(sQ_ptr); + typename KeTraits::RegQ rQ; - typename KeTraits::SharedQLoader load_sq; - typename KeTraits::RegQLoader load_rq; + typename KeTraits::SharedQLoader load_sq; + typename KeTraits::RegQLoader load_rq; - typename KeTraits::GIteratorK gKs(K); - typename KeTraits::SharedK sK(sK_ptr); - typename KeTraits::RegK rK; + typename KeTraits::GIteratorK gKs(K); + typename KeTraits::SharedK sK(sK_ptr); + typename KeTraits::RegK rK; - typename KeTraits::SharedKLoader load_sk; - typename KeTraits::RegKLoader load_rk; + typename KeTraits::SharedKLoader load_sk; + typename KeTraits::RegKLoader load_rk; - typename KeTraits::GIteratorV gVs(V); - typename KeTraits::SharedV sV(sV_ptr); + typename KeTraits::GIteratorV gVs(V); + typename KeTraits::SharedV sV(sV_ptr); - typename KeTraits::SharedVLoader load_sv; - typename KeTraits::RegVLoader load_rv; - typename KeTraits::RegV rV; + typename KeTraits::SharedVLoader load_sv; + typename KeTraits::RegVLoader load_rv; + typename KeTraits::RegV rV; - typename KeTraits::RegO exp_values_f32; + typename KeTraits::RegO exp_values_f32; - typename KeTraits::RegOCast rO; - typename KeTraits::RegOCast exp_values; + typename KeTraits::RegOCast rO; + typename KeTraits::RegOCast exp_values; - typename KeTraits::RegAcc attn_block_f32; - typename KeTraits::RegAccCast attn_block; + typename KeTraits::RegAcc attn_block_f32; + typename KeTraits::RegAccCast attn_block; - typename KeTraits::RegVec prev_norm_vec; - typename KeTraits::RegVec cur_norm_vec; + typename KeTraits::RegVec prev_norm_vec; + typename KeTraits::RegVec cur_norm_vec; - typename KeTraits::RegVec prev_max_vec; - typename KeTraits::RegVec cur_max_vec; - typename KeTraits::RegVec new_max_vec; + typename KeTraits::RegVec prev_max_vec; + typename KeTraits::RegVec cur_max_vec; + typename KeTraits::RegVec new_max_vec; - typename KeTraits::RegVec prev_sum_vec; - typename KeTraits::RegVec cur_sum_vec; - typename KeTraits::RegVec new_sum_vec; + typename KeTraits::RegVec prev_sum_vec; + typename KeTraits::RegVec cur_sum_vec; + typename KeTraits::RegVec new_sum_vec; - typename KeTraits::RegVec prev_norm_mul_sum; - typename KeTraits::RegVec cur_norm_mul_sum; - typename KeTraits::RegVec prev_sum_mul_norm; + typename KeTraits::RegVec prev_norm_mul_sum; + typename KeTraits::RegVec cur_norm_mul_sum; + typename KeTraits::RegVec prev_sum_mul_norm; - typename KeTraits::RowMax row_max; - typename KeTraits::RowSum row_sum; - typename KeTraits::CopyVec copy_vec; + typename KeTraits::RowMax row_max; + typename KeTraits::RowSum row_sum; + typename KeTraits::CopyVec copy_vec; - typename KeTraits::ConvertAcc cast_acc; // Convert acc to half precision - typename KeTraits::ConvertO cast_o; // Convert half precision to float. + typename KeTraits::ConvertAcc cast_acc; // Convert acc to half precision + typename KeTraits::ConvertO cast_o; // Convert half precision to float. - typename KeTraits::BroadcastSub broadcast_sub; - typename KeTraits::BroadcastMul broadcast_mul; - typename KeTraits::BroadcastDiv broadcast_div; + typename KeTraits::BroadcastSub broadcast_sub; + typename KeTraits::BroadcastMul broadcast_mul; + typename KeTraits::BroadcastDiv broadcast_div; - typename KeTraits::BlockExp block_exp; - typename KeTraits::BlockAdd block_add; + typename KeTraits::BlockExp block_exp; + typename KeTraits::BlockAdd block_add; - typename KeTraits::VecMax vec_max; - typename KeTraits::VecAdd vec_add; - typename KeTraits::VecSub vec_sub; - typename KeTraits::VecMul vec_mul; - typename KeTraits::VecExp vec_exp; + typename KeTraits::VecMax vec_max; + typename KeTraits::VecAdd vec_add; + typename KeTraits::VecSub vec_sub; + typename KeTraits::VecMul vec_mul; + typename KeTraits::VecExp vec_exp; - typename KeTraits::ApplyMask apply_mask; - typename KeTraits::ApplyScoreScale apply_score_scale; + typename KeTraits::ApplyMask apply_mask; + typename KeTraits::ApplyScoreScale apply_score_scale; - for (int n = 0; n < KeTraits::GIteratorV::sc0; ++n) { - load_sv(gVs(n), sV); + for (int n = 0; n < KeTraits::GIteratorV::sc0; ++n) { + load_sv(gVs(n), sV); - for (int k = 0; k < KeTraits::GIteratorQ::sc1; ++k) { - load_sq(gQs(k), sQ); - load_sk(gKs(k, n), sK); - __copy_async(); - __syncthreads(); + for (int k = 0; k < KeTraits::GIteratorQ::sc1; ++k) { + load_sq(gQs(k), sQ); + load_sk(gKs(k, n), sK); + __copy_async(); + __syncthreads(); - load_rq(sQ, rQ); - load_rk(sK, rK); - __syncthreads(); + load_rq(sQ, rQ); + load_rk(sK, rK); + __syncthreads(); - compute::gemm(rQ, rK, attn_block_f32); - } - load_rv(sV, rV); - __syncthreads(); + compute::gemm(rQ, rK, attn_block_f32); + } + load_rv(sV, rV); + __syncthreads(); - if (kIsCausal) { - apply_mask(attn_block_f32, blockIdx.x * kTM, n * kTN, -INFINITY); - } + if (kIsCausal) { + apply_mask(attn_block_f32, blockIdx.x * kTM, n * kTN, -INFINITY); + } - apply_score_scale(attn_block_f32, kSoftmaxScale, attn_block_f32); + apply_score_scale(attn_block_f32, kSoftmaxScale, attn_block_f32); - cast_acc(attn_block_f32, attn_block); + cast_acc(attn_block_f32, attn_block); - // Compute row max. - row_max(attn_block, cur_max_vec); + // Compute row max. + row_max(attn_block, cur_max_vec); - // Broadcast subtract from `attn_block`. - broadcast_sub(cur_max_vec, attn_block); + // Broadcast subtract from `attn_block`. + broadcast_sub(cur_max_vec, attn_block); - // Compute exp in `attn_block`. - block_exp(attn_block, attn_block); + // Compute exp in `attn_block`. + block_exp(attn_block, attn_block); - // Compute `cur_sum_vec` by reduce sum of `attn_block`. - row_sum(attn_block, cur_sum_vec); + // Compute `cur_sum_vec` by reduce sum of `attn_block`. + row_sum(attn_block, cur_sum_vec); - // Compute new max vector. - vec_max(cur_max_vec, prev_max_vec, new_max_vec); + // Compute new max vector. + vec_max(cur_max_vec, prev_max_vec, new_max_vec); - // Renormalization for the previous block. - vec_sub(prev_max_vec, new_max_vec, prev_norm_vec); - vec_exp(prev_norm_vec, prev_norm_vec); + // Renormalization for the previous block. + vec_sub(prev_max_vec, new_max_vec, prev_norm_vec); + vec_exp(prev_norm_vec, prev_norm_vec); - // Renormalization for the current block. - vec_sub(cur_max_vec, new_max_vec, cur_norm_vec); - vec_exp(cur_norm_vec, cur_norm_vec); + // Renormalization for the current block. + vec_sub(cur_max_vec, new_max_vec, cur_norm_vec); + vec_exp(cur_norm_vec, cur_norm_vec); - // Update normalization factor l(x) - vec_mul(prev_norm_vec, prev_sum_vec, prev_norm_mul_sum); - vec_mul(cur_norm_vec, cur_sum_vec, cur_norm_mul_sum); - vec_add(prev_norm_mul_sum, cur_norm_mul_sum, new_sum_vec); + // Update normalization factor l(x) + vec_mul(prev_norm_vec, prev_sum_vec, prev_norm_mul_sum); + vec_mul(cur_norm_vec, cur_sum_vec, cur_norm_mul_sum); + vec_add(prev_norm_mul_sum, cur_norm_mul_sum, new_sum_vec); - // Compute unnormized attention block. - compute::gemm(attn_block, rV, exp_values_f32); + // Compute unnormized attention block. + compute::gemm(attn_block, rV, exp_values_f32); - cast_o(exp_values_f32, exp_values); + cast_o(exp_values_f32, exp_values); - broadcast_mul(prev_norm_mul_sum, rO); + broadcast_mul(prev_norm_mul_sum, rO); - broadcast_mul(cur_norm_vec, exp_values); + broadcast_mul(cur_norm_vec, exp_values); - block_add(rO, exp_values, rO); + block_add(rO, exp_values, rO); - // Normalize the attention block. - broadcast_div(new_sum_vec, rO); + // Normalize the attention block. + broadcast_div(new_sum_vec, rO); - // Update max vector and sum vector. - copy_vec(new_max_vec, prev_max_vec); - copy_vec(new_sum_vec, prev_sum_vec); + // Update max vector and sum vector. + copy_vec(new_max_vec, prev_max_vec); + copy_vec(new_sum_vec, prev_sum_vec); - // Clear the accumulator. - attn_block_f32.clear(); - exp_values_f32.clear(); - } + // Clear the accumulator. + attn_block_f32.clear(); + exp_values_f32.clear(); + } - __syncthreads(); - typename KeTraits::GlobalO gO(O); - typename KeTraits::OStorer storer_o; - storer_o(rO, gO); + __syncthreads(); + typename KeTraits::GlobalO gO(O); + typename KeTraits::OStorer storer_o; + storer_o(rO, gO); } } // namespace tilefusion::kernels diff --git a/include/kernels/fused_two_gemms_device.cuh b/include/kernels/fused_two_gemms_device.cuh index 81e029ac..35eb03b6 100644 --- a/include/kernels/fused_two_gemms_device.cuh +++ b/include/kernels/fused_two_gemms_device.cuh @@ -17,99 +17,98 @@ template struct FusedTwoGemmsTraits { - /// constants - using InType = InType_; - using AccType = AccType_; - - static constexpr int kM = kM_; - static constexpr int kN = kN_; - static constexpr int kK = kK_; - static constexpr int kP = kP_; - - static constexpr int kTM = kTM_; - static constexpr int kTN = kTN_; - static constexpr int kTK = kTK_; - static constexpr int kTP = kTP_; - - static constexpr int kSharedAccess = 64; - - static constexpr int kWarpPerRow = tl::num_rows; - static constexpr int kWarpPerCol = tl::num_cols; - static_assert(kWarpPerCol == 1, "WarpPerCol must be 1"); - - // operand A - using GlobalA = GlobalTile>; - // chunk the K dimension to fit into shared memory - using GIteratorA = GTileIterator>; - - static const bool kUseSwizzling = true; - - using SharedA = SharedTile, kUseSwizzling, - kSharedAccess>; - - using BaseShape = BaseTileShape; - static constexpr int kAMs = kTM / kWarpPerRow / BaseShape::kRows; - static constexpr int kAKs = kTK / BaseShape::kCols; - using RegA = RegTile, tl::RowMajor>; - - using SharedALoader = GlobalToSharedLoader; - using RegALoader = - SharedToRegLoader; - - // operand B - using GlobalB = GlobalTile>; - using GIteratorB = GTileIterator>; - using SharedB = SharedTile, kUseSwizzling, - kSharedAccess>; - - static constexpr int kBKs = kTK / BaseShape::kRows; - static constexpr int kBNs = kTN / kWarpPerCol / BaseShape::kCols; - using RegB = RegTile, tl::ColMajor>; - - using SharedBLoader = GlobalToSharedLoader; - using RegBLoader = - SharedToRegLoader; - - // operand C - using GlobalC = GlobalTile>; - // chunk the N dimension to fit into shared memory - using GIteratorC = GTileIterator>; - using SharedC = SharedTile, kUseSwizzling, - kSharedAccess>; - - static constexpr int kCNs = kTN / BaseShape::kRows; - static constexpr int kCPs = kTP / kWarpPerCol / BaseShape::kCols; - using RegC = RegTile, tl::ColMajor>; - - using SharedCLoader = GlobalToSharedLoader; - using RegCLoader = - SharedToRegLoader; - - // output D - using GlobalD = GlobalTile>; - using SharedD = SharedTile, kUseSwizzling, - kSharedAccess>; - - static constexpr int kDMs = kTM / kWarpPerRow / BaseShape::kRows; - static constexpr int kDPs = kTP / kWarpPerCol / BaseShape::kCols; - using RegD = RegTile, tl::RowMajor>; - using RegDHalf = - RegTile, tl::RowMajor>; - - // Reg Acc - static constexpr int kAccMs = kTM / kWarpPerRow / BaseShape::kRows; - static constexpr int kAccNs = kTN / kWarpPerCol / BaseShape::kCols; - using RegAcc = - RegTile, tl::RowMajor>; - using RegAccCast = - RegTile, tl::RowMajor>; - - // Convert the accumulator to half - using ConvertHalf = compute::RegTileConvert; - using ConvertD = compute::RegTileConvert; - - using StoreRegD = RegToSharedStorer; - using StoreSharedD = SharedToGlobalStorer; + /// constants + using InType = InType_; + using AccType = AccType_; + + static constexpr int kM = kM_; + static constexpr int kN = kN_; + static constexpr int kK = kK_; + static constexpr int kP = kP_; + + static constexpr int kTM = kTM_; + static constexpr int kTN = kTN_; + static constexpr int kTK = kTK_; + static constexpr int kTP = kTP_; + + static constexpr int kSharedAccess = 64; + + static constexpr int kWarpPerRow = tl::num_rows; + static constexpr int kWarpPerCol = tl::num_cols; + static_assert(kWarpPerCol == 1, "WarpPerCol must be 1"); + + // operand A + using GlobalA = GlobalTile>; + // chunk the K dimension to fit into shared memory + using GIteratorA = GTileIterator>; + + static const bool kUseSwizzling = true; + + using SharedA = + SharedTile, kUseSwizzling, kSharedAccess>; + + using BaseShape = BaseTileShape; + static constexpr int kAMs = kTM / kWarpPerRow / BaseShape::kRows; + static constexpr int kAKs = kTK / BaseShape::kCols; + using RegA = RegTile, tl::RowMajor>; + + using SharedALoader = GlobalToSharedLoader; + using RegALoader = + SharedToRegLoader; + + // operand B + using GlobalB = GlobalTile>; + using GIteratorB = GTileIterator>; + using SharedB = + SharedTile, kUseSwizzling, kSharedAccess>; + + static constexpr int kBKs = kTK / BaseShape::kRows; + static constexpr int kBNs = kTN / kWarpPerCol / BaseShape::kCols; + using RegB = RegTile, tl::ColMajor>; + + using SharedBLoader = GlobalToSharedLoader; + using RegBLoader = + SharedToRegLoader; + + // operand C + using GlobalC = GlobalTile>; + // chunk the N dimension to fit into shared memory + using GIteratorC = GTileIterator>; + using SharedC = + SharedTile, kUseSwizzling, kSharedAccess>; + + static constexpr int kCNs = kTN / BaseShape::kRows; + static constexpr int kCPs = kTP / kWarpPerCol / BaseShape::kCols; + using RegC = RegTile, tl::ColMajor>; + + using SharedCLoader = GlobalToSharedLoader; + using RegCLoader = + SharedToRegLoader; + + // output D + using GlobalD = GlobalTile>; + using SharedD = + SharedTile, kUseSwizzling, kSharedAccess>; + + static constexpr int kDMs = kTM / kWarpPerRow / BaseShape::kRows; + static constexpr int kDPs = kTP / kWarpPerCol / BaseShape::kCols; + using RegD = RegTile, tl::RowMajor>; + using RegDHalf = RegTile, tl::RowMajor>; + + // Reg Acc + static constexpr int kAccMs = kTM / kWarpPerRow / BaseShape::kRows; + static constexpr int kAccNs = kTN / kWarpPerCol / BaseShape::kCols; + using RegAcc = + RegTile, tl::RowMajor>; + using RegAccCast = + RegTile, tl::RowMajor>; + + // Convert the accumulator to half + using ConvertHalf = compute::RegTileConvert; + using ConvertD = compute::RegTileConvert; + + using StoreRegD = RegToSharedStorer; + using StoreSharedD = SharedToGlobalStorer; }; template @@ -117,101 +116,100 @@ __device__ __forceinline__ void ke_fused_two_gemms(const InType* dA, const InType* dB, const InType* dC, InType* dD) { - // constants - static constexpr int kM = KeTraits::kM; - static constexpr int kN = KeTraits::kN; - static constexpr int kK = KeTraits::kK; - static constexpr int kP = KeTraits::kP; - - static constexpr int kTM = KeTraits::kTM; - static constexpr int kTP = KeTraits::kTP; - - using SharedA = KeTraits::SharedA; - using SharedB = KeTraits::SharedB; - using SharedC = KeTraits::SharedC; - using SharedD = KeTraits::SharedD; - - // Advance to the global data tile to the current CTA. - const InType* A = dA + blockIdx.z * (kM * kK) + blockIdx.x * (kTM * kK); - const InType* B = dB + blockIdx.z * (kK * kN); - const InType* gC_ptr = - dC + blockIdx.z * (kN * kP) + blockIdx.y * (kTP * kN); - - InType* gD_ptr = dD + blockIdx.z * (kM * kP) + blockIdx.x * (kTM * kP) + - (blockIdx.y * kTP); - - // shared memory buffer - extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[]; - auto* shm = reinterpret_cast(shared_buf); - - InType* sA_ptr = shm; - InType* sB_ptr = shm + SharedA::kNumel; - InType* sC_ptr = shm + SharedA::kNumel + SharedB::kNumel; - InType* sD_ptr = shm; - - // declare tile, iterators, loaders, and storers - typename KeTraits::GIteratorA gAs(A); - typename KeTraits::SharedA sA(sA_ptr); - typename KeTraits::RegA rA; - - typename KeTraits::SharedALoader load_sa; - typename KeTraits::RegALoader load_ra; - - typename KeTraits::GIteratorB gBs(B); - typename KeTraits::SharedB sB(sB_ptr); - typename KeTraits::RegB rB; - - typename KeTraits::SharedBLoader load_sb; - typename KeTraits::RegBLoader load_rb; - - typename KeTraits::GIteratorC gCs(gC_ptr); - typename KeTraits::SharedC sC(sC_ptr); - - typename KeTraits::SharedCLoader load_sc; - typename KeTraits::RegCLoader load_rc; - typename KeTraits::RegC rC; - - typename KeTraits::GlobalD gD(gD_ptr); - typename KeTraits::SharedD sD(sD_ptr); - typename KeTraits::RegD rD; - typename KeTraits::RegDHalf rD_half; - - typename KeTraits::RegAcc acc; - typename KeTraits::RegAccCast acc_half; - - typename KeTraits::ConvertHalf cast_acc; - typename KeTraits::ConvertD convert_d; - - for (int n = 0; n < KeTraits::GIteratorC::sc0; ++n) { - load_sc(gCs(n), sC); - - for (int k = 0; k < KeTraits::GIteratorA::sc1; ++k) { - load_sa(gAs(k), sA); - load_sb(gBs(k, n), sB); - __copy_async(); - __syncthreads(); - - load_ra(sA, rA); - load_rb(sB, rB); - __syncthreads(); - gemm(rA, rB, acc); - } - load_rc(sC, rC); - __syncthreads(); - - cast_acc(acc, acc_half); - - gemm(acc_half, rC, rD); - acc.clear(); + // constants + static constexpr int kM = KeTraits::kM; + static constexpr int kN = KeTraits::kN; + static constexpr int kK = KeTraits::kK; + static constexpr int kP = KeTraits::kP; + + static constexpr int kTM = KeTraits::kTM; + static constexpr int kTP = KeTraits::kTP; + + using SharedA = KeTraits::SharedA; + using SharedB = KeTraits::SharedB; + using SharedC = KeTraits::SharedC; + using SharedD = KeTraits::SharedD; + + // Advance to the global data tile to the current CTA. + const InType* A = dA + blockIdx.z * (kM * kK) + blockIdx.x * (kTM * kK); + const InType* B = dB + blockIdx.z * (kK * kN); + const InType* gC_ptr = dC + blockIdx.z * (kN * kP) + blockIdx.y * (kTP * kN); + + InType* gD_ptr = dD + blockIdx.z * (kM * kP) + blockIdx.x * (kTM * kP) + + (blockIdx.y * kTP); + + // shared memory buffer + extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[]; + auto* shm = reinterpret_cast(shared_buf); + + InType* sA_ptr = shm; + InType* sB_ptr = shm + SharedA::kNumel; + InType* sC_ptr = shm + SharedA::kNumel + SharedB::kNumel; + InType* sD_ptr = shm; + + // declare tile, iterators, loaders, and storers + typename KeTraits::GIteratorA gAs(A); + typename KeTraits::SharedA sA(sA_ptr); + typename KeTraits::RegA rA; + + typename KeTraits::SharedALoader load_sa; + typename KeTraits::RegALoader load_ra; + + typename KeTraits::GIteratorB gBs(B); + typename KeTraits::SharedB sB(sB_ptr); + typename KeTraits::RegB rB; + + typename KeTraits::SharedBLoader load_sb; + typename KeTraits::RegBLoader load_rb; + + typename KeTraits::GIteratorC gCs(gC_ptr); + typename KeTraits::SharedC sC(sC_ptr); + + typename KeTraits::SharedCLoader load_sc; + typename KeTraits::RegCLoader load_rc; + typename KeTraits::RegC rC; + + typename KeTraits::GlobalD gD(gD_ptr); + typename KeTraits::SharedD sD(sD_ptr); + typename KeTraits::RegD rD; + typename KeTraits::RegDHalf rD_half; + + typename KeTraits::RegAcc acc; + typename KeTraits::RegAccCast acc_half; + + typename KeTraits::ConvertHalf cast_acc; + typename KeTraits::ConvertD convert_d; + + for (int n = 0; n < KeTraits::GIteratorC::sc0; ++n) { + load_sc(gCs(n), sC); + + for (int k = 0; k < KeTraits::GIteratorA::sc1; ++k) { + load_sa(gAs(k), sA); + load_sb(gBs(k, n), sB); + __copy_async(); + __syncthreads(); + + load_ra(sA, rA); + load_rb(sB, rB); + __syncthreads(); + gemm(rA, rB, acc); } + load_rc(sC, rC); __syncthreads(); - convert_d(rD, rD_half); - typename KeTraits::StoreRegD store_rD; - store_rD(rD_half, sD); - __syncthreads(); + cast_acc(acc, acc_half); + + gemm(acc_half, rC, rD); + acc.clear(); + } + __syncthreads(); + convert_d(rD, rD_half); + + typename KeTraits::StoreRegD store_rD; + store_rD(rD_half, sD); + __syncthreads(); - typename KeTraits::StoreSharedD store_sD; - store_sD(sD, gD); + typename KeTraits::StoreSharedD store_sD; + store_sD(sD, gD); } } // namespace tilefusion::kernels diff --git a/include/kernels/gemm_device.cuh b/include/kernels/gemm_device.cuh index 1b0fe1f7..fac4fd86 100644 --- a/include/kernels/gemm_device.cuh +++ b/include/kernels/gemm_device.cuh @@ -18,406 +18,404 @@ template struct KeGemmTraits { - using InType = InType_; - using AccType = AccType_; - using BaseShape = BaseTileShape; - static constexpr int kNumStages = kNumStages_; - - static constexpr int kThreads = tl::get_numel * 32; - static constexpr int kWarpPerRow = tl::num_rows; - static constexpr int kWarpPerCol = tl::num_cols; - - static constexpr int kM = kM_; - static constexpr int kN = kN_; - static constexpr int kK = kK_; - - static constexpr int kTM = kTM_; - static constexpr int kTN = kTN_; - static constexpr int kTK = kTK_; - static constexpr int kRK = kRK_; - - static const bool kSwizzled = true; - - // Total data access for operand A in global memory - using GlobalA = GlobalTile>; - // Access a single global tile for operand A - using GIteratorA = GTileIterator>; - - // Shared Tile for operand A - using SharedA = - SharedTile, kSwizzled, kSharedAccess>; - // Access a single register tile for operand A - using SIteratorA = STileIterator>; - - // Register tile for a single thread of operand A - static constexpr int kAMs = kTM / kWarpPerRow / BaseShape::kRows; - static constexpr int kAKs = kRK / BaseShape::kCols; - using RegA = RegTile, tl::RowMajor>; - - // Loaders for operand A - using G2SLoaderA = GlobalToSharedLoader; - using S2RLoaderA = - SharedToRegLoader; - - // Total data access for operand B in global memory - using GlobalB = GlobalTile>; - // Access a single global tile for operand B - using GIteratorB = GTileIterator>; - - // Shared Tile for operand B - using SharedB = - SharedTile, kSwizzled, kSharedAccess>; - // Access a single register tile for operand B - using SIteratorB = STileIterator>; - - static_assert(GIteratorA::sc1 == GIteratorB::sc0, - "mismatched K dimension!"); - static_assert(SIteratorA::sc1 == SIteratorB::sc0, - "mismatched K dimension!"); - - // Register tile for a single thread of operand A - static constexpr int kBKs = kRK / BaseShape::kRows; - static constexpr int kBNs = kTN / kWarpPerCol / BaseShape::kCols; - using RegB = RegTile, tl::ColMajor>; - - using G2SLoaderB = GlobalToSharedLoader; - using S2RLoaderB = - SharedToRegLoader; - - // Global Tile for output C - using GlobalC = GlobalTile>; - // Shared Tile for output C - using SharedC = SharedTile, kSwizzled>; - - // Register Tile for output C - static constexpr int kCMs = kTM / kWarpPerRow / BaseShape::kRows; - static constexpr int kCNs = kTN / kWarpPerCol / BaseShape::kCols; - using RegC = RegTile, tl::RowMajor>; - - using R2GStorerC = RegToGlobalStorer; - using R2SStorerC = RegToSharedStorer; - using S2GStorerC = SharedToGlobalStorer; - - using PipelineG2SA = Pipeline; - using PipelineG2SB = Pipeline; - - using PipelineS2RA = Pipeline; - using PipelineS2RB = Pipeline; + using InType = InType_; + using AccType = AccType_; + using BaseShape = BaseTileShape; + static constexpr int kNumStages = kNumStages_; + + static constexpr int kThreads = tl::get_numel * 32; + static constexpr int kWarpPerRow = tl::num_rows; + static constexpr int kWarpPerCol = tl::num_cols; + + static constexpr int kM = kM_; + static constexpr int kN = kN_; + static constexpr int kK = kK_; + + static constexpr int kTM = kTM_; + static constexpr int kTN = kTN_; + static constexpr int kTK = kTK_; + static constexpr int kRK = kRK_; + + static const bool kSwizzled = true; + + // Total data access for operand A in global memory + using GlobalA = GlobalTile>; + // Access a single global tile for operand A + using GIteratorA = GTileIterator>; + + // Shared Tile for operand A + using SharedA = + SharedTile, kSwizzled, kSharedAccess>; + // Access a single register tile for operand A + using SIteratorA = STileIterator>; + + // Register tile for a single thread of operand A + static constexpr int kAMs = kTM / kWarpPerRow / BaseShape::kRows; + static constexpr int kAKs = kRK / BaseShape::kCols; + using RegA = RegTile, tl::RowMajor>; + + // Loaders for operand A + using G2SLoaderA = GlobalToSharedLoader; + using S2RLoaderA = + SharedToRegLoader; + + // Total data access for operand B in global memory + using GlobalB = GlobalTile>; + // Access a single global tile for operand B + using GIteratorB = GTileIterator>; + + // Shared Tile for operand B + using SharedB = + SharedTile, kSwizzled, kSharedAccess>; + // Access a single register tile for operand B + using SIteratorB = STileIterator>; + + static_assert(GIteratorA::sc1 == GIteratorB::sc0, "mismatched K dimension!"); + static_assert(SIteratorA::sc1 == SIteratorB::sc0, "mismatched K dimension!"); + + // Register tile for a single thread of operand A + static constexpr int kBKs = kRK / BaseShape::kRows; + static constexpr int kBNs = kTN / kWarpPerCol / BaseShape::kCols; + using RegB = RegTile, tl::ColMajor>; + + using G2SLoaderB = GlobalToSharedLoader; + using S2RLoaderB = + SharedToRegLoader; + + // Global Tile for output C + using GlobalC = GlobalTile>; + // Shared Tile for output C + using SharedC = SharedTile, kSwizzled>; + + // Register Tile for output C + static constexpr int kCMs = kTM / kWarpPerRow / BaseShape::kRows; + static constexpr int kCNs = kTN / kWarpPerCol / BaseShape::kCols; + using RegC = RegTile, tl::RowMajor>; + + using R2GStorerC = RegToGlobalStorer; + using R2SStorerC = RegToSharedStorer; + using S2GStorerC = SharedToGlobalStorer; + + using PipelineG2SA = Pipeline; + using PipelineG2SB = Pipeline; + + using PipelineS2RA = Pipeline; + using PipelineS2RB = Pipeline; }; template __device__ __forceinline__ void ke_gemm(const InType* dA, const InType* dB, AccType* dC) { - static constexpr int kN = KeTraits::kN; - static constexpr int kK = KeTraits::kK; - static constexpr int kTM = KeTraits::kTM; - static constexpr int kTN = KeTraits::kTN; + static constexpr int kN = KeTraits::kN; + static constexpr int kK = KeTraits::kK; + static constexpr int kTM = KeTraits::kTM; + static constexpr int kTN = KeTraits::kTN; - int offset_a = blockIdx.x * kTM * kK; - int offset_b = blockIdx.y * kTN * kK; - int offset_c = blockIdx.x * kTM * kN + blockIdx.y * kTN; + int offset_a = blockIdx.x * kTM * kK; + int offset_b = blockIdx.y * kTN * kK; + int offset_c = blockIdx.x * kTM * kN + blockIdx.y * kTN; - extern __shared__ __align__(sizeof(double)) unsigned char buf[]; - InType* sA_ptr = reinterpret_cast(buf); - InType* sB_ptr = sA_ptr + KeTraits::SIteratorA::Tile::kNumel; - AccType* sC_ptr = reinterpret_cast(buf); + extern __shared__ __align__(sizeof(double)) unsigned char buf[]; + InType* sA_ptr = reinterpret_cast(buf); + InType* sB_ptr = sA_ptr + KeTraits::SIteratorA::Tile::kNumel; + AccType* sC_ptr = reinterpret_cast(buf); - // declare tiles, iterators and loaders - typename KeTraits::GIteratorA gAs(dA + offset_a); - typename KeTraits::SIteratorA sAs(sA_ptr); + // declare tiles, iterators and loaders + typename KeTraits::GIteratorA gAs(dA + offset_a); + typename KeTraits::SIteratorA sAs(sA_ptr); - typename KeTraits::GIteratorB gBs(dB + offset_b); - typename KeTraits::SIteratorB sBs(sB_ptr); + typename KeTraits::GIteratorB gBs(dB + offset_b); + typename KeTraits::SIteratorB sBs(sB_ptr); - typename KeTraits::SharedA sA(sA_ptr); - typename KeTraits::RegA rA; + typename KeTraits::SharedA sA(sA_ptr); + typename KeTraits::RegA rA; - typename KeTraits::SharedB sB(sB_ptr); - typename KeTraits::RegB rB; + typename KeTraits::SharedB sB(sB_ptr); + typename KeTraits::RegB rB; - typename KeTraits::RegC acc; - typename KeTraits::SharedC sC(sC_ptr); - typename KeTraits::GlobalC gC(dC + offset_c); + typename KeTraits::RegC acc; + typename KeTraits::SharedC sC(sC_ptr); + typename KeTraits::GlobalC gC(dC + offset_c); - typename KeTraits::G2SLoaderA g2s_a; - typename KeTraits::S2RLoaderA s2r_a; + typename KeTraits::G2SLoaderA g2s_a; + typename KeTraits::S2RLoaderA s2r_a; - typename KeTraits::G2SLoaderB g2s_b; - typename KeTraits::S2RLoaderB s2r_b; + typename KeTraits::G2SLoaderB g2s_b; + typename KeTraits::S2RLoaderB s2r_b; - typename KeTraits::R2SStorerC r2s_c; - typename KeTraits::S2GStorerC s2g_c; + typename KeTraits::R2SStorerC r2s_c; + typename KeTraits::S2GStorerC s2g_c; - for (int k1 = 0; k1 < KeTraits::GIteratorA::sc1; ++k1) { - g2s_a(gAs(k1), sA); - g2s_b(gBs(k1), sB); - __copy_async(); - __syncthreads(); + for (int k1 = 0; k1 < KeTraits::GIteratorA::sc1; ++k1) { + g2s_a(gAs(k1), sA); + g2s_b(gBs(k1), sB); + __copy_async(); + __syncthreads(); - for (int k2 = 0; k2 < KeTraits::SIteratorA::sc1; ++k2) { - s2r_a(sAs(k2), rA); - s2r_b(sBs(k2), rB); + for (int k2 = 0; k2 < KeTraits::SIteratorA::sc1; ++k2) { + s2r_a(sAs(k2), rA); + s2r_b(sBs(k2), rB); - compute::gemm(rA, rB, acc); - } + compute::gemm(rA, rB, acc); } - r2s_c(acc, sC); - __syncthreads(); - s2g_c(sC, gC); + } + r2s_c(acc, sC); + __syncthreads(); + s2g_c(sC, gC); } template __device__ __forceinline__ void ke_gemm_level1_pipeline(const InType* dA, const InType* dB, AccType* dC) { - static constexpr int kTM = KeTraits::kTM; - static constexpr int kTN = KeTraits::kTN; - static constexpr int kN = KeTraits::kN; - static constexpr int kK = KeTraits::kK; - static constexpr int kNumStages = KeTraits::kNumStages; + static constexpr int kTM = KeTraits::kTM; + static constexpr int kTN = KeTraits::kTN; + static constexpr int kN = KeTraits::kN; + static constexpr int kK = KeTraits::kK; + static constexpr int kNumStages = KeTraits::kNumStages; - int offset_a = blockIdx.x * kTM * kK; - int offset_b = blockIdx.y * kTN * kK; - int offset_c = blockIdx.x * kTM * kN + blockIdx.y * kTN; + int offset_a = blockIdx.x * kTM * kK; + int offset_b = blockIdx.y * kTN * kK; + int offset_c = blockIdx.x * kTM * kN + blockIdx.y * kTN; - extern __shared__ __align__(sizeof(double)) unsigned char buf[]; - InType* sA_ptr = reinterpret_cast(buf); - InType* sB_ptr = sA_ptr + KeTraits::SIteratorA::Tile::kNumel * kNumStages; - AccType* sC_ptr = reinterpret_cast(buf); + extern __shared__ __align__(sizeof(double)) unsigned char buf[]; + InType* sA_ptr = reinterpret_cast(buf); + InType* sB_ptr = sA_ptr + KeTraits::SIteratorA::Tile::kNumel * kNumStages; + AccType* sC_ptr = reinterpret_cast(buf); - typename KeTraits::RegA rA; - typename KeTraits::RegB rB; + typename KeTraits::RegA rA; + typename KeTraits::RegB rB; - typename KeTraits::RegC acc; - typename KeTraits::SharedC sC(sC_ptr); - typename KeTraits::GlobalC gC(dC + offset_c); + typename KeTraits::RegC acc; + typename KeTraits::SharedC sC(sC_ptr); + typename KeTraits::GlobalC gC(dC + offset_c); - typename KeTraits::G2SLoaderA g2s_a; - typename KeTraits::S2RLoaderA s2r_a; + typename KeTraits::G2SLoaderA g2s_a; + typename KeTraits::S2RLoaderA s2r_a; - typename KeTraits::G2SLoaderB g2s_b; - typename KeTraits::S2RLoaderB s2r_b; + typename KeTraits::G2SLoaderB g2s_b; + typename KeTraits::S2RLoaderB s2r_b; - typename KeTraits::PipelineG2SA pipeline_g2s_a(dA + offset_a, sA_ptr); - typename KeTraits::PipelineG2SB pipeline_g2s_b(dB + offset_b, sB_ptr); + typename KeTraits::PipelineG2SA pipeline_g2s_a(dA + offset_a, sA_ptr); + typename KeTraits::PipelineG2SB pipeline_g2s_b(dB + offset_b, sB_ptr); - // Issue the global to shared copy before main loop. - pipeline_g2s_a.commit(); - pipeline_g2s_b.commit(); - commit_copy_group(); + // Issue the global to shared copy before main loop. + pipeline_g2s_a.commit(); + pipeline_g2s_b.commit(); + commit_copy_group(); - for (int k = 0; k < KeTraits::PipelineG2SA::Iterations - 1; ++k) { - // Barrier to wait for the previous copy to finish. - wait_group<0>(); - __syncthreads(); - pipeline_g2s_a.commit(); - pipeline_g2s_b.commit(); - commit_copy_group(); - // Compute(i - 1) - const InType* sA_ptr_prev = pipeline_g2s_a.get_prev_dst(); - const InType* sB_ptr_prev = pipeline_g2s_b.get_prev_dst(); - typename KeTraits::SIteratorA sAs(sA_ptr_prev); - typename KeTraits::SIteratorB sBs(sB_ptr_prev); - for (int k2 = 0; k2 < KeTraits::SIteratorA::sc1; ++k2) { - s2r_a(sAs(k2), rA); - s2r_b(sBs(k2), rB); - compute::gemm(rA, rB, acc); - } - } + for (int k = 0; k < KeTraits::PipelineG2SA::Iterations - 1; ++k) { + // Barrier to wait for the previous copy to finish. wait_group<0>(); __syncthreads(); - - // Compute(i) - const InType* sA_ptr_cur = pipeline_g2s_a.get_cur_dst(); - const InType* sB_ptr_cur = pipeline_g2s_b.get_cur_dst(); - typename KeTraits::SIteratorA sAs(sA_ptr_cur); - typename KeTraits::SIteratorB sBs(sB_ptr_cur); + pipeline_g2s_a.commit(); + pipeline_g2s_b.commit(); + commit_copy_group(); + // Compute(i - 1) + const InType* sA_ptr_prev = pipeline_g2s_a.get_prev_dst(); + const InType* sB_ptr_prev = pipeline_g2s_b.get_prev_dst(); + typename KeTraits::SIteratorA sAs(sA_ptr_prev); + typename KeTraits::SIteratorB sBs(sB_ptr_prev); for (int k2 = 0; k2 < KeTraits::SIteratorA::sc1; ++k2) { - s2r_a(sAs(k2), rA); - s2r_b(sBs(k2), rB); - compute::gemm(rA, rB, acc); + s2r_a(sAs(k2), rA); + s2r_b(sBs(k2), rB); + compute::gemm(rA, rB, acc); } - __syncthreads(); - - // Store the result from register tile to global memory. - typename KeTraits::R2SStorerC r2s_c; - typename KeTraits::S2GStorerC s2g_c; - r2s_c(acc, sC); - __syncthreads(); - s2g_c(sC, gC); + } + wait_group<0>(); + __syncthreads(); + + // Compute(i) + const InType* sA_ptr_cur = pipeline_g2s_a.get_cur_dst(); + const InType* sB_ptr_cur = pipeline_g2s_b.get_cur_dst(); + typename KeTraits::SIteratorA sAs(sA_ptr_cur); + typename KeTraits::SIteratorB sBs(sB_ptr_cur); + for (int k2 = 0; k2 < KeTraits::SIteratorA::sc1; ++k2) { + s2r_a(sAs(k2), rA); + s2r_b(sBs(k2), rB); + compute::gemm(rA, rB, acc); + } + __syncthreads(); + + // Store the result from register tile to global memory. + typename KeTraits::R2SStorerC r2s_c; + typename KeTraits::S2GStorerC s2g_c; + r2s_c(acc, sC); + __syncthreads(); + s2g_c(sC, gC); } template __device__ __forceinline__ void ke_gemm_level2_pipeline(const InType* dA, const InType* dB, AccType* dC) { - static constexpr int kTM = KeTraits::kTM; - static constexpr int kTN = KeTraits::kTN; - static constexpr int kN = KeTraits::kN; - static constexpr int kK = KeTraits::kK; - static constexpr int kRK = KeTraits::kRK; - static constexpr int kNumStages = KeTraits::kNumStages; + static constexpr int kTM = KeTraits::kTM; + static constexpr int kTN = KeTraits::kTN; + static constexpr int kN = KeTraits::kN; + static constexpr int kK = KeTraits::kK; + static constexpr int kRK = KeTraits::kRK; + static constexpr int kNumStages = KeTraits::kNumStages; - int offset_a = blockIdx.x * kTM * kK; - int offset_b = blockIdx.y * kTN * kK; - int offset_c = blockIdx.x * kTM * kN + blockIdx.y * kTN; + int offset_a = blockIdx.x * kTM * kK; + int offset_b = blockIdx.y * kTN * kK; + int offset_c = blockIdx.x * kTM * kN + blockIdx.y * kTN; - extern __shared__ __align__(sizeof(double)) unsigned char buf[]; - InType* sA_ptr = reinterpret_cast(buf); - InType* sB_ptr = sA_ptr + KeTraits::SIteratorA::Tile::kNumel * kNumStages; - AccType* sC_ptr = reinterpret_cast(buf); + extern __shared__ __align__(sizeof(double)) unsigned char buf[]; + InType* sA_ptr = reinterpret_cast(buf); + InType* sB_ptr = sA_ptr + KeTraits::SIteratorA::Tile::kNumel * kNumStages; + AccType* sC_ptr = reinterpret_cast(buf); - // Declare the cycle buffer for the register tiles. - typename KeTraits::RegA rA_cyc_buf[kNumStages - 1]; - typename KeTraits::RegB rB_cyc_buf[kNumStages - 1]; + // Declare the cycle buffer for the register tiles. + typename KeTraits::RegA rA_cyc_buf[kNumStages - 1]; + typename KeTraits::RegB rB_cyc_buf[kNumStages - 1]; - typename KeTraits::RegC acc; - typename KeTraits::SharedC sC(sC_ptr); - typename KeTraits::GlobalC gC(dC + offset_c); + typename KeTraits::RegC acc; + typename KeTraits::SharedC sC(sC_ptr); + typename KeTraits::GlobalC gC(dC + offset_c); - typename KeTraits::G2SLoaderA g2s_a; - typename KeTraits::S2RLoaderA s2r_a; + typename KeTraits::G2SLoaderA g2s_a; + typename KeTraits::S2RLoaderA s2r_a; - typename KeTraits::G2SLoaderB g2s_b; - typename KeTraits::S2RLoaderB s2r_b; + typename KeTraits::G2SLoaderB g2s_b; + typename KeTraits::S2RLoaderB s2r_b; - typename KeTraits::PipelineG2SA pipeline_g2s_a(dA + offset_a, sA_ptr); - typename KeTraits::PipelineG2SB pipeline_g2s_b(dB + offset_b, sB_ptr); + typename KeTraits::PipelineG2SA pipeline_g2s_a(dA + offset_a, sA_ptr); + typename KeTraits::PipelineG2SB pipeline_g2s_b(dB + offset_b, sB_ptr); - // In 3-stage pipeline, we need to issue 2 global to shared copies before - // the main loop. + // In 3-stage pipeline, we need to issue 2 global to shared copies before + // the main loop. - // We issue copy instructions using 2 commit groups. - pipeline_g2s_a.commit(); - pipeline_g2s_b.commit(); - commit_copy_group(); + // We issue copy instructions using 2 commit groups. + pipeline_g2s_a.commit(); + pipeline_g2s_b.commit(); + commit_copy_group(); - pipeline_g2s_a.commit(); - pipeline_g2s_b.commit(); - commit_copy_group(); + pipeline_g2s_a.commit(); + pipeline_g2s_b.commit(); + commit_copy_group(); - // Wait for at least 1 copy to finish. - wait_group<1>(); - __syncthreads(); + // Wait for at least 1 copy to finish. + wait_group<1>(); + __syncthreads(); - const InType* sA0 = pipeline_g2s_a.get_dst_ptr_by_index(0); - const InType* sB0 = pipeline_g2s_b.get_dst_ptr_by_index(0); + const InType* sA0 = pipeline_g2s_a.get_dst_ptr_by_index(0); + const InType* sB0 = pipeline_g2s_b.get_dst_ptr_by_index(0); - typename KeTraits::PipelineS2RA pipeline_s2r_a(sA0, rA_cyc_buf); - typename KeTraits::PipelineS2RB pipeline_s2r_b(sB0, rB_cyc_buf); + typename KeTraits::PipelineS2RA pipeline_s2r_a(sA0, rA_cyc_buf); + typename KeTraits::PipelineS2RB pipeline_s2r_b(sB0, rB_cyc_buf); - // Issue the first data loading from shared memory to register tile. - pipeline_s2r_a.commit(); - pipeline_s2r_b.commit(); - auto rA = pipeline_s2r_a.get_dst_tile_by_index(0); - auto rB = pipeline_s2r_b.get_dst_tile_by_index(0); + // Issue the first data loading from shared memory to register tile. + pipeline_s2r_a.commit(); + pipeline_s2r_b.commit(); + auto rA = pipeline_s2r_a.get_dst_tile_by_index(0); + auto rB = pipeline_s2r_b.get_dst_tile_by_index(0); - // gemm stage 1: handle all global to shared copies. - // BLOCK: GIteratorA::sc1 - 2 -#pragma unroll - for (int k = 0; k < KeTraits::PipelineG2SA::Iterations - 2; ++k) { - // NOTE(KuangjuX): we have to add `#pragma unroll` here, otherwise - // misaligned errors will be reported. + // gemm stage 1: handle all global to shared copies. + // BLOCK: GIteratorA::sc1 - 2 #pragma unroll - for (int k2 = 0; k2 < KeTraits::PipelineS2RA::Iterations; ++k2) { - // circular issue next data loading from shared memory to register - // tile. - - pipeline_s2r_a.commit(); - pipeline_s2r_b.commit(); - - if (k2 == KeTraits::PipelineS2RA::Iterations - 2) { - wait_group<0>(); - __syncthreads(); - /** - * When `k2 == PipelineS2RA::Iterations - 2`, the current shared - * tile has just been traversed and needs to be replaced with a - * new shared tile. Since `PipelineG2S` is emitted twice before - * the loop, and `S2R` obtains the data of the first emission - * outside the loop, when `k = 0`, the index of the data to be - * obtained is `k + 1`. - */ - auto sA = pipeline_g2s_a.get_cur_dst(); - auto sB = pipeline_g2s_b.get_cur_dst(); - // reset the shared tile in shared to register copy. - pipeline_s2r_a.reset_src_tile(sA); - pipeline_s2r_b.reset_src_tile(sB); - } - - // execute gemm operation in previous register tile. - auto rA = pipeline_s2r_a.get_dst_tile_by_index(k2); - auto rB = pipeline_s2r_b.get_dst_tile_by_index(k2); - __syncthreads(); - compute::gemm(rA, rB, acc); - } - - // Issue the next global to shared copy. - pipeline_g2s_a.commit(); - pipeline_g2s_b.commit(); - commit_copy_group(); - } - - // gemm stage 2: handle the second-to-last shared tile. + for (int k = 0; k < KeTraits::PipelineG2SA::Iterations - 2; ++k) { // NOTE(KuangjuX): we have to add `#pragma unroll` here, otherwise // misaligned errors will be reported. #pragma unroll for (int k2 = 0; k2 < KeTraits::PipelineS2RA::Iterations; ++k2) { - // circular issue next data loading from shared memory to register - // tile. + // circular issue next data loading from shared memory to register + // tile. + + pipeline_s2r_a.commit(); + pipeline_s2r_b.commit(); + + if (k2 == KeTraits::PipelineS2RA::Iterations - 2) { + wait_group<0>(); + __syncthreads(); + /** + * When `k2 == PipelineS2RA::Iterations - 2`, the current shared + * tile has just been traversed and needs to be replaced with a + * new shared tile. Since `PipelineG2S` is emitted twice before + * the loop, and `S2R` obtains the data of the first emission + * outside the loop, when `k = 0`, the index of the data to be + * obtained is `k + 1`. + */ + auto sA = pipeline_g2s_a.get_cur_dst(); + auto sB = pipeline_g2s_b.get_cur_dst(); + // reset the shared tile in shared to register copy. + pipeline_s2r_a.reset_src_tile(sA); + pipeline_s2r_b.reset_src_tile(sB); + } + + // execute gemm operation in previous register tile. + auto rA = pipeline_s2r_a.get_dst_tile_by_index(k2); + auto rB = pipeline_s2r_b.get_dst_tile_by_index(k2); + __syncthreads(); + compute::gemm(rA, rB, acc); + } - pipeline_s2r_a.commit(); - pipeline_s2r_b.commit(); + // Issue the next global to shared copy. + pipeline_g2s_a.commit(); + pipeline_g2s_b.commit(); + commit_copy_group(); + } - if (k2 == KeTraits::PipelineS2RA::Iterations - 2) { - // Wait the last global to shared tile copy to finish. - wait_group<0>(); - __syncthreads(); + // gemm stage 2: handle the second-to-last shared tile. + // NOTE(KuangjuX): we have to add `#pragma unroll` here, otherwise + // misaligned errors will be reported. +#pragma unroll + for (int k2 = 0; k2 < KeTraits::PipelineS2RA::Iterations; ++k2) { + // circular issue next data loading from shared memory to register + // tile. - // fetch the last shared tile in global memory. - const InType* sA = pipeline_g2s_a.get_cur_dst(); - const InType* sB = pipeline_g2s_b.get_cur_dst(); + pipeline_s2r_a.commit(); + pipeline_s2r_b.commit(); - // reset the last shared tile in shared to register copy. - pipeline_s2r_a.reset_src_tile(sA); - pipeline_s2r_b.reset_src_tile(sB); - } + if (k2 == KeTraits::PipelineS2RA::Iterations - 2) { + // Wait the last global to shared tile copy to finish. + wait_group<0>(); + __syncthreads(); - auto rA = pipeline_s2r_a.get_dst_tile_by_index(k2); - auto rB = pipeline_s2r_b.get_dst_tile_by_index(k2); + // fetch the last shared tile in global memory. + const InType* sA = pipeline_g2s_a.get_cur_dst(); + const InType* sB = pipeline_g2s_b.get_cur_dst(); - compute::gemm(rA, rB, acc); + // reset the last shared tile in shared to register copy. + pipeline_s2r_a.reset_src_tile(sA); + pipeline_s2r_b.reset_src_tile(sB); } - // gemm stage 3: handle the last shared tile - // NOTE(KuangjuX): we have to add `#pragma unroll` here, otherwise - // misaligned errors will be reported. + auto rA = pipeline_s2r_a.get_dst_tile_by_index(k2); + auto rB = pipeline_s2r_b.get_dst_tile_by_index(k2); + + compute::gemm(rA, rB, acc); + } + + // gemm stage 3: handle the last shared tile + // NOTE(KuangjuX): we have to add `#pragma unroll` here, otherwise + // misaligned errors will be reported. #pragma unroll - for (int k2 = 0; k2 < KeTraits::PipelineS2RA::Iterations; ++k2) { - // In last stage, we only need to issue Iterations - 1 times - // data loading from shared memory to register tile beacuase - // we have already done an advance copy in the previous stage. - if (k2 < KeTraits::PipelineS2RA::Iterations - 1) { - pipeline_s2r_a.commit(); - pipeline_s2r_b.commit(); - } - - auto rA = pipeline_s2r_a.get_dst_tile_by_index(k2); - auto rB = pipeline_s2r_b.get_dst_tile_by_index(k2); - - compute::gemm(rA, rB, acc); + for (int k2 = 0; k2 < KeTraits::PipelineS2RA::Iterations; ++k2) { + // In last stage, we only need to issue Iterations - 1 times + // data loading from shared memory to register tile beacuase + // we have already done an advance copy in the previous stage. + if (k2 < KeTraits::PipelineS2RA::Iterations - 1) { + pipeline_s2r_a.commit(); + pipeline_s2r_b.commit(); } - __syncthreads(); + auto rA = pipeline_s2r_a.get_dst_tile_by_index(k2); + auto rB = pipeline_s2r_b.get_dst_tile_by_index(k2); - typename KeTraits::R2SStorerC r2s_c; - typename KeTraits::S2GStorerC s2g_c; - r2s_c(acc, sC); - __syncthreads(); - s2g_c(sC, gC); + compute::gemm(rA, rB, acc); + } + + __syncthreads(); + + typename KeTraits::R2SStorerC r2s_c; + typename KeTraits::S2GStorerC s2g_c; + r2s_c(acc, sC); + __syncthreads(); + s2g_c(sC, gC); } } // namespace tilefusion::kernels diff --git a/include/types/base.hpp b/include/types/base.hpp index 18ffbeca..44ffd8a2 100644 --- a/include/types/base.hpp +++ b/include/types/base.hpp @@ -41,30 +41,30 @@ concept Fp8Type = /// @param Element: the data type of the elements. template struct AccessBase { - // the maximal width of vectorized access. - static constexpr int kAccessInBits = 128; - static constexpr int kAccessInBytes = kAccessInBits / 8; + // the maximal width of vectorized access. + static constexpr int kAccessInBits = 128; + static constexpr int kAccessInBytes = kAccessInBits / 8; - static constexpr int kElementBits = sizeof(Element) * 8; - static constexpr int kNumPerAccess = kAccessInBits / kElementBits; + static constexpr int kElementBits = sizeof(Element) * 8; + static constexpr int kNumPerAccess = kAccessInBits / kElementBits; - // the width of memory transaction, Shared memory cacheline width. - static constexpr int kMemTransWidth = 1024; // 1024 bits, 128 bytes + // the width of memory transaction, Shared memory cacheline width. + static constexpr int kMemTransWidth = 1024; // 1024 bits, 128 bytes - // The ideal number of columns for a single warp to load. - // When loading data through the L1 cache, an entire 128-byte cache line is - // fetched. Ensuring contiguous threads read contiguous data in memory - // optimizes the usage of the L1 cache. - static constexpr int kExpectedSize = kMemTransWidth / kElementBits; + // The ideal number of columns for a single warp to load. + // When loading data through the L1 cache, an entire 128-byte cache line is + // fetched. Ensuring contiguous threads read contiguous data in memory + // optimizes the usage of the L1 cache. + static constexpr int kExpectedSize = kMemTransWidth / kElementBits; }; // FIXME(ying): Legacy code, remove it gradually. template - requires BaseType + requires BaseType struct BaseTileShape { - static constexpr int kRows = 16; - static constexpr int kCols = 16; - static constexpr int kNumel = 256 /* kRows * kCols */; + static constexpr int kRows = 16; + static constexpr int kCols = 16; + static constexpr int kNumel = 256 /* kRows * kCols */; }; #ifdef CUDA_FP8_AVAILABLE @@ -72,40 +72,38 @@ struct BaseTileShape { /// constructors/conversions template DEVICE T from_float(float val) { - if constexpr (std::is_same_v) { - return __fp8_e4m3(val); - } else if constexpr (std::is_same_v) { - return __fp8_e5m2(val); - } else if constexpr (std::is_same_v) { - return __float2half(val); - } else if constexpr (std::is_same_v) { - return __float2bfloat16(val); - } else if constexpr (std::is_same_v) { - return val; - } else { - static_assert(sizeof(T) == 0, - "Unsupported type for from_float conversion"); - } + if constexpr (std::is_same_v) { + return __fp8_e4m3(val); + } else if constexpr (std::is_same_v) { + return __fp8_e5m2(val); + } else if constexpr (std::is_same_v) { + return __float2half(val); + } else if constexpr (std::is_same_v) { + return __float2bfloat16(val); + } else if constexpr (std::is_same_v) { + return val; + } else { + static_assert(sizeof(T) == 0, "Unsupported type for from_float conversion"); + } } /// @brief Convert from source type to float using built-in cast /// operators/conversions template DEVICE float to_float(const T& val) { - if constexpr (std::is_same_v) { - return static_cast(val); - } else if constexpr (std::is_same_v) { - return static_cast(val); - } else if constexpr (std::is_same_v) { - return __half2float(val); - } else if constexpr (std::is_same_v) { - return __bfloat162float(val); - } else if constexpr (std::is_same_v) { - return val; - } else { - static_assert(sizeof(T) == 0, - "Unsupported type for to_float conversion"); - } + if constexpr (std::is_same_v) { + return static_cast(val); + } else if constexpr (std::is_same_v) { + return static_cast(val); + } else if constexpr (std::is_same_v) { + return __half2float(val); + } else if constexpr (std::is_same_v) { + return __bfloat162float(val); + } else if constexpr (std::is_same_v) { + return val; + } else { + static_assert(sizeof(T) == 0, "Unsupported type for to_float conversion"); + } } #else // !CUDA_FP8_AVAILABLE @@ -113,32 +111,32 @@ DEVICE float to_float(const T& val) { /// @brief Fallback conversion functions when FP8 is not available template DEVICE T from_float(float val) { - if constexpr (std::is_same_v) { - return __float2half(val); - } else if constexpr (std::is_same_v) { - return __float2bfloat16(val); - } else if constexpr (std::is_same_v) { - return val; - } else { - static_assert(sizeof(T) == 0, - "FP8 types not available - requires Ada Lovelace (RTX " - "4090) or Hopper (H100) GPU with CUDA 11.8+"); - } + if constexpr (std::is_same_v) { + return __float2half(val); + } else if constexpr (std::is_same_v) { + return __float2bfloat16(val); + } else if constexpr (std::is_same_v) { + return val; + } else { + static_assert(sizeof(T) == 0, + "FP8 types not available - requires Ada Lovelace (RTX " + "4090) or Hopper (H100) GPU with CUDA 11.8+"); + } } template DEVICE float to_float(const T& val) { - if constexpr (std::is_same_v) { - return __half2float(val); - } else if constexpr (std::is_same_v) { - return __bfloat162float(val); - } else if constexpr (std::is_same_v) { - return val; - } else { - static_assert(sizeof(T) == 0, - "FP8 types not available - requires Ada Lovelace (RTX " - "4090) or Hopper (H100) GPU with CUDA 11.8+"); - } + if constexpr (std::is_same_v) { + return __half2float(val); + } else if constexpr (std::is_same_v) { + return __bfloat162float(val); + } else if constexpr (std::is_same_v) { + return val; + } else { + static_assert(sizeof(T) == 0, + "FP8 types not available - requires Ada Lovelace (RTX " + "4090) or Hopper (H100) GPU with CUDA 11.8+"); + } } #endif // CUDA_FP8_AVAILABLE diff --git a/include/types/base_tile.hpp b/include/types/base_tile.hpp index 6060e65c..2b5784a9 100644 --- a/include/types/base_tile.hpp +++ b/include/types/base_tile.hpp @@ -12,13 +12,13 @@ namespace { /// @brief Helper for pretty printing a BaseTile's static shape-related /// information. This printer works ONLY on the host. struct BaseTilePrettyPrinter { - template - static HOST void print(std::ostream& out, const BaseShape& tile) { - // parameter `tile` here is not used - out << "BaseShape = (" << BaseShape::kRows << ", " << BaseShape::kCols - << "), Numel = " << BaseShape::kNumel << ", ThreadLayout = (" - << BaseShape::kRowThreads << ", " << BaseShape::kColThreads << ")"; - } + template + static HOST void print(std::ostream& out, const BaseShape& tile) { + // parameter `tile` here is not used + out << "BaseShape = (" << BaseShape::kRows << ", " << BaseShape::kCols + << "), Numel = " << BaseShape::kNumel << ", ThreadLayout = (" + << BaseShape::kRowThreads << ", " << BaseShape::kColThreads << ")"; + } }; } // namespace @@ -30,88 +30,88 @@ struct WarpBaseTileShape; template struct WarpBaseTileShape { - using AccessInfo = AccessBase; + using AccessInfo = AccessBase; - static constexpr int kTileRows = dim_size<0, TileShape>; - static constexpr int kTileCols = dim_size<1, TileShape>; + static constexpr int kTileRows = dim_size<0, TileShape>; + static constexpr int kTileCols = dim_size<1, TileShape>; - // In a row-major layout, columns are the contiguous dimension in memory. We - // enforce the use of 128-bit vectorized instructions for data loading by a - // single thread. This implies that the minimum number of columns should be - // at least 128 bits. - static constexpr int kMinCols = - AccessInfo::kAccessInBits / (sizeof(DType) * 8); + // In a row-major layout, columns are the contiguous dimension in memory. We + // enforce the use of 128-bit vectorized instructions for data loading by a + // single thread. This implies that the minimum number of columns should be + // at least 128 bits. + static constexpr int kMinCols = + AccessInfo::kAccessInBits / (sizeof(DType) * 8); - static_assert(kTileCols >= kMinCols, "The number of columns is too small."); + static_assert(kTileCols >= kMinCols, "The number of columns is too small."); - static_assert(kTileCols < AccessInfo::kExpectedSize || - (kTileCols >= AccessInfo::kExpectedSize && - kTileCols % AccessInfo::kExpectedSize == 0), - "The current implementation requires that the number of " - "columns of the tile be divisible by the cache line width."); + static_assert(kTileCols < AccessInfo::kExpectedSize || + (kTileCols >= AccessInfo::kExpectedSize && + kTileCols % AccessInfo::kExpectedSize == 0), + "The current implementation requires that the number of " + "columns of the tile be divisible by the cache line width."); - static constexpr int kCols = kTileCols >= AccessInfo::kExpectedSize - ? AccessInfo::kExpectedSize - : kTileCols; + static constexpr int kCols = kTileCols >= AccessInfo::kExpectedSize + ? AccessInfo::kExpectedSize + : kTileCols; - // number of columns in a warp - static constexpr int kColThreads = kCols / AccessInfo::kNumPerAccess; - static_assert(WARP_SIZE % kColThreads == 0, - "Fail to infer warp thread layout."); - static constexpr int kRowThreads = WARP_SIZE / kColThreads; + // number of columns in a warp + static constexpr int kColThreads = kCols / AccessInfo::kNumPerAccess; + static_assert(WARP_SIZE % kColThreads == 0, + "Fail to infer warp thread layout."); + static constexpr int kRowThreads = WARP_SIZE / kColThreads; - static constexpr int kRows = kRowThreads; - static_assert(kTileRows % kRowThreads == 0, - "The number of rows of the tile isn't evenly divisible by " - "the number of threads in a column."); + static constexpr int kRows = kRowThreads; + static_assert(kTileRows % kRowThreads == 0, + "The number of rows of the tile isn't evenly divisible by " + "the number of threads in a column."); - static constexpr tl::Layout kType = tl::Layout::kRowMajor; - static constexpr int kNumel = kRows * kCols; + static constexpr tl::Layout kType = tl::Layout::kRowMajor; + static constexpr int kNumel = kRows * kCols; - using WarpThreadLayout = tl::RowMajor; + using WarpThreadLayout = tl::RowMajor; }; template struct WarpBaseTileShape { - using AccessInfo = AccessBase; + using AccessInfo = AccessBase; - static constexpr int kTileRows = dim_size<0, TileShape>; - static constexpr int kTileCols = dim_size<1, TileShape>; + static constexpr int kTileRows = dim_size<0, TileShape>; + static constexpr int kTileCols = dim_size<1, TileShape>; - // In a column-major layout, columns are the contiguous dimension in memory. - // We enforce the use of 128-bit vectorized instructions for data loading by - // a single thread. This implies that the minimum number of columns should - // be at least 128 bits. - static constexpr int kMinRows = - AccessInfo::kAccessInBits / (sizeof(DType) * 8); + // In a column-major layout, columns are the contiguous dimension in memory. + // We enforce the use of 128-bit vectorized instructions for data loading by + // a single thread. This implies that the minimum number of columns should + // be at least 128 bits. + static constexpr int kMinRows = + AccessInfo::kAccessInBits / (sizeof(DType) * 8); - static_assert(kTileRows >= kMinRows, "The number of rows is too small."); + static_assert(kTileRows >= kMinRows, "The number of rows is too small."); - static_assert(kTileRows < AccessInfo::kExpectedSize || - (kTileRows >= AccessInfo::kExpectedSize && - kTileRows % AccessInfo::kExpectedSize == 0), - "The current implementation requires that the number of " - "rows of the tile be divisible by the cache line width."); + static_assert(kTileRows < AccessInfo::kExpectedSize || + (kTileRows >= AccessInfo::kExpectedSize && + kTileRows % AccessInfo::kExpectedSize == 0), + "The current implementation requires that the number of " + "rows of the tile be divisible by the cache line width."); - static constexpr int kRows = kTileRows >= AccessInfo::kExpectedSize - ? AccessInfo::kExpectedSize - : kTileRows; + static constexpr int kRows = kTileRows >= AccessInfo::kExpectedSize + ? AccessInfo::kExpectedSize + : kTileRows; - // number of rows in a warp - static constexpr int kRowThreads = kRows / AccessInfo::kNumPerAccess; - static_assert(WARP_SIZE % kRowThreads == 0, - "Fail to infer warp thread layout."); - static constexpr int kColThreads = WARP_SIZE / kRowThreads; + // number of rows in a warp + static constexpr int kRowThreads = kRows / AccessInfo::kNumPerAccess; + static_assert(WARP_SIZE % kRowThreads == 0, + "Fail to infer warp thread layout."); + static constexpr int kColThreads = WARP_SIZE / kRowThreads; - static constexpr int kCols = kColThreads; - static_assert(kTileCols % kColThreads == 0, - "The number of columns of the tile isn't evenly divisible by " - "the number of threads in a row."); + static constexpr int kCols = kColThreads; + static_assert(kTileCols % kColThreads == 0, + "The number of columns of the tile isn't evenly divisible by " + "the number of threads in a row."); - static constexpr tl::Layout kType = tl::Layout::kColMajor; - static constexpr int kNumel = kRows * kCols; + static constexpr tl::Layout kType = tl::Layout::kColMajor; + static constexpr int kNumel = kRows * kCols; - using WarpThreadLayout = tl::ColMajor; + using WarpThreadLayout = tl::ColMajor; }; /// @brief Pretty printer for the static shape information of a @@ -120,7 +120,7 @@ struct WarpBaseTileShape { template static HOST std::ostream& operator<<( std::ostream& out, const WarpBaseTileShape& tile) { - BaseTilePrettyPrinter::print(out, tile); - return out; + BaseTilePrettyPrinter::print(out, tile); + return out; } } // namespace tilefusion diff --git a/include/types/global.hpp b/include/types/global.hpp index 765e875d..c9998003 100644 --- a/include/types/global.hpp +++ b/include/types/global.hpp @@ -14,59 +14,59 @@ namespace { /// @brief Helper for pretty printing a GlobalTile's static shape-related /// information. This printer works ONLY on the host. struct GlobalTilePrettyPrinter { - template - static HOST void print(std::ostream& out, const Global& tile) { - // parameter `tile` here is not used - out << layout_type_to_str(Global::kType) << "(" << Global::kRows << ", " - << Global::kCols << ", " << Global::kRowStride << ", " - << Global::kColStride << "), numel = " << Global::kNumel; - } + template + static HOST void print(std::ostream& out, const Global& tile) { + // parameter `tile` here is not used + out << layout_type_to_str(Global::kType) << "(" << Global::kRows << ", " + << Global::kCols << ", " << Global::kRowStride << ", " + << Global::kColStride << "), numel = " << Global::kNumel; + } }; } // namespace template struct GlobalTile { - using DType = Element_; - using Layout = Layout_; + using DType = Element_; + using Layout = Layout_; - static constexpr int kNumel = tl::get_numel; + static constexpr int kNumel = tl::get_numel; - static constexpr int kRows = tl::num_rows; - static constexpr int kCols = tl::num_cols; + static constexpr int kRows = tl::num_rows; + static constexpr int kCols = tl::num_cols; - static constexpr int kRowStride = tl::row_stride; - static constexpr int kColStride = tl::col_stride; + static constexpr int kRowStride = tl::row_stride; + static constexpr int kColStride = tl::col_stride; - static constexpr tl::Layout kType = tl::layout_type; + static constexpr tl::Layout kType = tl::layout_type; - // This Ctor is to enable the use of the pretty printer of SharedTile in the - // host code. - HOST GlobalTile() : data_(nullptr), layout_(Layout{}) {} + // This Ctor is to enable the use of the pretty printer of SharedTile in the + // host code. + HOST GlobalTile() : data_(nullptr), layout_(Layout{}) {} - DEVICE GlobalTile(DType* data) : data_(data), layout_(Layout{}) {} + DEVICE GlobalTile(DType* data) : data_(data), layout_(Layout{}) {} - DEVICE GlobalTile(const DType* data) - : data_(const_cast(data)), layout_(Layout{}) {} + DEVICE GlobalTile(const DType* data) + : data_(const_cast(data)), layout_(Layout{}) {} - DEVICE DType* mutable_data() { return data_; } + DEVICE DType* mutable_data() { return data_; } - DEVICE const DType* data() const { return data_; } + DEVICE const DType* data() const { return data_; } - HOST_DEVICE const Layout& layout() const { return layout_; } + HOST_DEVICE const Layout& layout() const { return layout_; } - // for write access - DEVICE DType& operator()(int x, int y) { return data_[layout_(x, y)]; } + // for write access + DEVICE DType& operator()(int x, int y) { return data_[layout_(x, y)]; } - // for read access - DEVICE - const DType& operator()(int x, int y) const { return data_[layout_(x, y)]; } + // for read access + DEVICE + const DType& operator()(int x, int y) const { return data_[layout_(x, y)]; } - DEVICE void dump_value() { util::print_tile(data_, layout_); } + DEVICE void dump_value() { util::print_tile(data_, layout_); } - private: - DType* data_; - Layout layout_; + private: + DType* data_; + Layout layout_; }; /// @brief Pretty printer for the static shape information of a SharedTile. @@ -74,8 +74,8 @@ struct GlobalTile { template static HOST std::ostream& operator<<(std::ostream& out, const GlobalTile& tile) { - GlobalTilePrettyPrinter::print(out, tile); - return out; + GlobalTilePrettyPrinter::print(out, tile); + return out; } } // namespace tilefusion diff --git a/include/types/global_tile_iterator.hpp b/include/types/global_tile_iterator.hpp index cfcad62f..3553689f 100644 --- a/include/types/global_tile_iterator.hpp +++ b/include/types/global_tile_iterator.hpp @@ -13,15 +13,15 @@ namespace { /// @brief Helper for pretty printing a tile iterator's static shape-related /// information. This printer works ONLY on the host. struct GTileIteratorPrettyPrinter { - template - static HOST void print(std::ostream& out, const TileIterator& itr) { - size_t size1 = dim_size<0, typename TileIterator::ChunkShape>; - size_t size2 = dim_size<1, typename TileIterator::ChunkShape>; - - out << "numel = " << TileIterator::Tile::kNumel << ", ChunkShape = (" - << size1 << ", " << size2 << "), stripe count = (" - << TileIterator::sc0 << ", " << TileIterator::sc1 << ")"; - } + template + static HOST void print(std::ostream& out, const TileIterator& itr) { + size_t size1 = dim_size<0, typename TileIterator::ChunkShape>; + size_t size2 = dim_size<1, typename TileIterator::ChunkShape>; + + out << "numel = " << TileIterator::Tile::kNumel << ", ChunkShape = (" + << size1 << ", " << size2 << "), stripe count = (" << TileIterator::sc0 + << ", " << TileIterator::sc1 << ")"; + } }; } // namespace @@ -32,122 +32,122 @@ struct GTileIteratorPrettyPrinter { /// tile is partitioned (chunk shape). template class GTileIterator { - public: - using Tile = Tile_; - using DType = Tile::DType; - using ChunkShape = ChunkShape_; - - static_assert(Tile::kRows >= dim_size<0, ChunkShape>, - "Tile::kRows must be >= dim_size<0, ChunkShape>"); - static_assert(Tile::kCols >= dim_size<1, ChunkShape>, - "Tile::kCols must be >= dim_size<1, ChunkShape>"); - - static constexpr int kStride0 = dim_size<0, ChunkShape>; - static constexpr int kStride1 = dim_size<1, ChunkShape>; - - static constexpr int sc0 = Tile::kRows / kStride0; - static constexpr int sc1 = Tile::kCols / kStride1; - - HOST_DEVICE GTileIterator() : data_(nullptr) {} - - DEVICE GTileIterator(DType* data) : data_(data) {} - - DEVICE GTileIterator(const DType* data) : data_(const_cast(data)) {} - - // Since a Tile is considered to be at most a 2D array, the iterator - // traverses over these two dimensions. The current rules are: - // 1. If the index is a 2D integer, this access is considered to be a - // single tile, hence it returns a Tile. - // 2. If any part of the index is an underscore, this access is - // considered to be a slice, naturally it returns a TileIterator. - DEVICE auto operator()(int i) { - assert(data_); // The iterator is not initialized. - static_assert(sc0 == 1 || sc1 == 1, - "A single index is supported only when the strip count " - "of one of the iterator's dimensions is 1."); - - int x = sc0 == 1 ? 0 : i; - int y = sc0 == 1 ? i : 0; - - using TileLayout = tl::MatrixLayout; - using NewTile = GlobalTile; - - int offset = Tile::kType == tl::Layout::kRowMajor - ? x * (kStride0 * Tile::kRowStride) + y * kStride1 - : x * kStride0 + y * (Tile::kColStride * kStride1); - - NewTile tile(data_ + offset); - - return tile; - } - - DEVICE auto operator()(int x, int y) { - assert(data_); // The iterator is not initialized. - assert(x < sc0 && y < sc1); // indices must be within the strip count. - - using TileLayout = tl::MatrixLayout; - using NewTile = GlobalTile; - - int offset = Tile::kType == tl::Layout::kRowMajor - ? x * (kStride0 * Tile::kRowStride) + y * kStride1 - : x * kStride0 + y * (Tile::kColStride * kStride1); - NewTile tile(data_ + offset); - - return tile; - } - - DEVICE auto operator()(int x, const Underscore& y) { - assert(data_); // The iterator is not initialized. - assert(x < sc0); // index must be within the strip count. - - // Updated the layout for sub-tiles accessed by the sliced iterator. - // Note: Only the shape changes; the stride remains the same. - using TileLayout = tl::MatrixLayout; - using NewTile = GlobalTile; - using Iter = GTileIterator; - static_assert(Iter::sc0 == 1); - - // advance pointer to the correct start position - int offset = Tile::kType == tl::Layout::kRowMajor - ? x * (kStride0 * Tile::kCols) - : x * kStride0; - - Iter iter(data_ + offset); - return iter; - } - - DEVICE auto operator()(const Underscore& x, int y) { - assert(data_); // The iterator is not initialized. - assert(y < sc1); // index must be within the strip count. - - // Updated the layout for sub-tiles accessed by the sliced iterator. - // Note: Only the shape changes; the stride remains the same. - - using TileLayout = tl::MatrixLayout; - using NewTile = GlobalTile; - using Iter = GTileIterator; - static_assert(Iter::sc1 == 1); - - // advance pointer to the correct start position - int offset = Tile::kType == tl::Layout::kRowMajor - ? y * kStride1 - : y * (Tile::kRows * kStride1); - - Iter iter(data_ + offset); - return iter; - } - - DEVICE auto to_tile() { - Tile tile(data_); - return tile; - } - - private: - DType* data_; + public: + using Tile = Tile_; + using DType = Tile::DType; + using ChunkShape = ChunkShape_; + + static_assert(Tile::kRows >= dim_size<0, ChunkShape>, + "Tile::kRows must be >= dim_size<0, ChunkShape>"); + static_assert(Tile::kCols >= dim_size<1, ChunkShape>, + "Tile::kCols must be >= dim_size<1, ChunkShape>"); + + static constexpr int kStride0 = dim_size<0, ChunkShape>; + static constexpr int kStride1 = dim_size<1, ChunkShape>; + + static constexpr int sc0 = Tile::kRows / kStride0; + static constexpr int sc1 = Tile::kCols / kStride1; + + HOST_DEVICE GTileIterator() : data_(nullptr) {} + + DEVICE GTileIterator(DType* data) : data_(data) {} + + DEVICE GTileIterator(const DType* data) : data_(const_cast(data)) {} + + // Since a Tile is considered to be at most a 2D array, the iterator + // traverses over these two dimensions. The current rules are: + // 1. If the index is a 2D integer, this access is considered to be a + // single tile, hence it returns a Tile. + // 2. If any part of the index is an underscore, this access is + // considered to be a slice, naturally it returns a TileIterator. + DEVICE auto operator()(int i) { + assert(data_); // The iterator is not initialized. + static_assert(sc0 == 1 || sc1 == 1, + "A single index is supported only when the strip count " + "of one of the iterator's dimensions is 1."); + + int x = sc0 == 1 ? 0 : i; + int y = sc0 == 1 ? i : 0; + + using TileLayout = tl::MatrixLayout; + using NewTile = GlobalTile; + + int offset = Tile::kType == tl::Layout::kRowMajor + ? x * (kStride0 * Tile::kRowStride) + y * kStride1 + : x * kStride0 + y * (Tile::kColStride * kStride1); + + NewTile tile(data_ + offset); + + return tile; + } + + DEVICE auto operator()(int x, int y) { + assert(data_); // The iterator is not initialized. + assert(x < sc0 && y < sc1); // indices must be within the strip count. + + using TileLayout = tl::MatrixLayout; + using NewTile = GlobalTile; + + int offset = Tile::kType == tl::Layout::kRowMajor + ? x * (kStride0 * Tile::kRowStride) + y * kStride1 + : x * kStride0 + y * (Tile::kColStride * kStride1); + NewTile tile(data_ + offset); + + return tile; + } + + DEVICE auto operator()(int x, const Underscore& y) { + assert(data_); // The iterator is not initialized. + assert(x < sc0); // index must be within the strip count. + + // Updated the layout for sub-tiles accessed by the sliced iterator. + // Note: Only the shape changes; the stride remains the same. + using TileLayout = tl::MatrixLayout; + using NewTile = GlobalTile; + using Iter = GTileIterator; + static_assert(Iter::sc0 == 1); + + // advance pointer to the correct start position + int offset = Tile::kType == tl::Layout::kRowMajor + ? x * (kStride0 * Tile::kCols) + : x * kStride0; + + Iter iter(data_ + offset); + return iter; + } + + DEVICE auto operator()(const Underscore& x, int y) { + assert(data_); // The iterator is not initialized. + assert(y < sc1); // index must be within the strip count. + + // Updated the layout for sub-tiles accessed by the sliced iterator. + // Note: Only the shape changes; the stride remains the same. + + using TileLayout = tl::MatrixLayout; + using NewTile = GlobalTile; + using Iter = GTileIterator; + static_assert(Iter::sc1 == 1); + + // advance pointer to the correct start position + int offset = Tile::kType == tl::Layout::kRowMajor + ? y * kStride1 + : y * (Tile::kRows * kStride1); + + Iter iter(data_ + offset); + return iter; + } + + DEVICE auto to_tile() { + Tile tile(data_); + return tile; + } + + private: + DType* data_; }; /// @brief Pretty printer for the static shape information of a TileIterator. @@ -155,8 +155,8 @@ class GTileIterator { template static HOST std::ostream& operator<<( std::ostream& out, const GTileIterator& itr) { - GTileIteratorPrettyPrinter::print(out, itr); - return out; + GTileIteratorPrettyPrinter::print(out, itr); + return out; } } // namespace tilefusion diff --git a/include/types/layout.hpp b/include/types/layout.hpp index abfc6b26..30721d20 100644 --- a/include/types/layout.hpp +++ b/include/types/layout.hpp @@ -18,31 +18,31 @@ namespace tilefusion::tile_layout { */ enum class Layout { - kRowMajor = 0, - kColMajor = 1, + kRowMajor = 0, + kColMajor = 1, }; HOST_DEVICE const char* layout_type_to_str(Layout type) { - switch (type) { - case Layout::kRowMajor: - return "RowMajor"; - case Layout::kColMajor: - return "ColMajor"; - } - return "UnsupportedLayout"; + switch (type) { + case Layout::kRowMajor: + return "RowMajor"; + case Layout::kColMajor: + return "ColMajor"; + } + return "UnsupportedLayout"; } namespace { /// @brief Helper for pretty printing a matrix layout's static shape-related /// information. This printer works ONLY on the host. struct MatrixLayoutPrettyPrinter { - template - static HOST void print(std::ostream& out, const Layout& layout) { - out << layout_type_to_str(Layout::kType) << "<" << Layout::kRows << ", " - << Layout::kCols << ">, Strides<" << Layout::kRowStride << ", " - << Layout::kColStride << ">, Numel = " << Layout::kNumel; - } + template + static HOST void print(std::ostream& out, const Layout& layout) { + out << layout_type_to_str(Layout::kType) << "<" << Layout::kRows << ", " + << Layout::kCols << ">, Strides<" << Layout::kRowStride << ", " + << Layout::kColStride << ">, Numel = " << Layout::kNumel; + } }; } // namespace @@ -51,19 +51,19 @@ template struct MatrixLayout { - static constexpr int kRows = kRows_; - static constexpr int kCols = kCols_; + static constexpr int kRows = kRows_; + static constexpr int kCols = kCols_; - static constexpr int kRowStride = kRowStride_; - static constexpr int kColStride = kColStride_; + static constexpr int kRowStride = kRowStride_; + static constexpr int kColStride = kColStride_; - static constexpr int kNumel = kRows * kCols; + static constexpr int kNumel = kRows * kCols; - static constexpr Layout kType = kType_; + static constexpr Layout kType = kType_; - HOST_DEVICE int operator()(int i, int j) const { - return i * kRowStride + j * kColStride; - } + HOST_DEVICE int operator()(int i, int j) const { + return i * kRowStride + j * kColStride; + } }; /// @brief Pretty printer for the static shape information of a MatrixLayout. @@ -73,8 +73,8 @@ template & layout) { - MatrixLayoutPrettyPrinter::print(out, layout); - return out; + MatrixLayoutPrettyPrinter::print(out, layout); + return out; } // In the row major layout, the contiguous dimension in memory is the @@ -110,89 +110,87 @@ static constexpr Layout layout_type = Layout_::kType; template struct is_row_major { - static constexpr bool value = Layout_::kType == Layout::kRowMajor; + static constexpr bool value = Layout_::kType == Layout::kRowMajor; }; template struct is_col_major { - static constexpr bool value = Layout_::kType == Layout::kColMajor; + static constexpr bool value = Layout_::kType == Layout::kColMajor; }; template struct is_contiguous { - static constexpr bool value = - is_row_major::value - ? (Layout::kRowStride == Layout::kCols && Layout::kColStride == 1) - : (Layout::kColStride == Layout::kRows && Layout::kRowStride == 1); + static constexpr bool value = + is_row_major::value + ? (Layout::kRowStride == Layout::kCols && Layout::kColStride == 1) + : (Layout::kColStride == Layout::kRows && Layout::kRowStride == 1); }; template struct BlockMatrxLayout { - using InnerLayout = InnerLayout_; - using OuterLayout = OuterLayout_; + using InnerLayout = InnerLayout_; + using OuterLayout = OuterLayout_; - static constexpr int kRows = OuterLayout_::kRows; - static constexpr int kCols = OuterLayout_::kCols; - static constexpr int kNumel = OuterLayout_::kNumel; + static constexpr int kRows = OuterLayout_::kRows; + static constexpr int kCols = OuterLayout_::kCols; + static constexpr int kNumel = OuterLayout_::kNumel; - static constexpr int kInnerRows = InnerLayout_::kRows; - static constexpr int kInnerCols = InnerLayout_::kCols; + static constexpr int kInnerRows = InnerLayout_::kRows; + static constexpr int kInnerCols = InnerLayout_::kCols; - static_assert(kRows % kInnerRows == 0, - "OuterLayout rows must be divisible by InnerLayout rows"); - static_assert(kCols % kInnerCols == 0, - "OuterLayout cols must be divisible by InnerLayout cols"); + static_assert(kRows % kInnerRows == 0, + "OuterLayout rows must be divisible by InnerLayout rows"); + static_assert(kCols % kInnerCols == 0, + "OuterLayout cols must be divisible by InnerLayout cols"); - static constexpr int kInnerNumel = InnerLayout_::kNumel; - static constexpr Layout kType = OuterLayout::kType; + static constexpr int kInnerNumel = InnerLayout_::kNumel; + static constexpr Layout kType = OuterLayout::kType; - static constexpr int kTileRows = kRows / kInnerRows; - static constexpr int kTileCols = kCols / kInnerCols; + static constexpr int kTileRows = kRows / kInnerRows; + static constexpr int kTileCols = kCols / kInnerCols; - static constexpr bool kIsRowMajor = is_row_major::value; - static constexpr bool kIsContiguous = is_contiguous::value; + static constexpr bool kIsRowMajor = is_row_major::value; + static constexpr bool kIsContiguous = is_contiguous::value; - static constexpr int kRowStride = - kIsContiguous ? (kIsRowMajor ? kTileCols * kInnerNumel : kInnerNumel) - : OuterLayout::kRowStride; - static constexpr int kColStride = - kIsContiguous ? (kIsRowMajor ? kInnerNumel : kTileRows * kInnerNumel) - : OuterLayout::kColStride; + static constexpr int kRowStride = + kIsContiguous ? (kIsRowMajor ? kTileCols * kInnerNumel : kInnerNumel) + : OuterLayout::kRowStride; + static constexpr int kColStride = + kIsContiguous ? (kIsRowMajor ? kInnerNumel : kTileRows * kInnerNumel) + : OuterLayout::kColStride; - HOST_DEVICE int operator()(int i, int j) const { - const int outer_i = RowDivMod::div(i); - const int outer_j = ColDivMod::div(j); + HOST_DEVICE int operator()(int i, int j) const { + const int outer_i = RowDivMod::div(i); + const int outer_j = ColDivMod::div(j); - const int inner_i = RowDivMod::mod(i); - const int inner_j = ColDivMod::mod(j); + const int inner_i = RowDivMod::mod(i); + const int inner_j = ColDivMod::mod(j); - return outer_(outer_i, outer_j) + inner_(inner_i, inner_j); - } + return outer_(outer_i, outer_j) + inner_(inner_i, inner_j); + } - HOST_DEVICE void dump() const { - for (int i = 0; i < kRows; ++i) { - for (int j = 0; j < kCols; ++j) { - printf("%d, ", operator()(i, j)); - } - printf("\n"); - } + HOST_DEVICE void dump() const { + for (int i = 0; i < kRows; ++i) { + for (int j = 0; j < kCols; ++j) { + printf("%d, ", operator()(i, j)); + } + printf("\n"); } + } - HOST auto get_outer_layout() const { return decltype(outer_){}; } + HOST auto get_outer_layout() const { return decltype(outer_){}; } - private: - static constexpr bool kInnerRowsIsPow2 = - (kInnerRows & (kInnerRows - 1)) == 0; - static constexpr bool kInnerColsIsPow2 = - (kInnerCols & (kInnerCols - 1)) == 0; + private: + static constexpr bool kInnerRowsIsPow2 = (kInnerRows & (kInnerRows - 1)) == 0; + static constexpr bool kInnerColsIsPow2 = (kInnerCols & (kInnerCols - 1)) == 0; - using RowDivMod = DivModSelector; - using ColDivMod = DivModSelector; + using RowDivMod = DivModSelector; + using ColDivMod = DivModSelector; - using BlockOuter = MatrixLayout; - BlockOuter outer_; - InnerLayout inner_; + using BlockOuter = MatrixLayout; + BlockOuter outer_; + InnerLayout inner_; }; /// @brief Pretty printer for BlockMatrxLayout @@ -200,11 +198,11 @@ template static HOST std::ostream& operator<<( std::ostream& out, const BlockMatrxLayout& layout) { - out << "BlockMatrixLayout {" << std::endl - << " Outer: " << layout.get_outer_layout() << ", " << std::endl - << " Inner: " << InnerLayout_{} << std::endl - << " }"; - return out; + out << "BlockMatrixLayout {" << std::endl + << " Outer: " << layout.get_outer_layout() << ", " << std::endl + << " Inner: " << InnerLayout_{} << std::endl + << " }"; + return out; } template @@ -221,15 +219,15 @@ concept BlockMixedLayout = (is_col_major::value && is_row_major::value); template - requires BlockRowMajorLayout + requires BlockRowMajorLayout using BlockRowMajor = BlockMatrxLayout; template - requires BlockColMajorLayout + requires BlockColMajorLayout using BlockColMajor = BlockMatrxLayout; template - requires BlockMixedLayout + requires BlockMixedLayout using BlockMixed = BlockMatrxLayout; } // namespace tilefusion::tile_layout diff --git a/include/types/packing.hpp b/include/types/packing.hpp index b64060d1..51b5b9c7 100644 --- a/include/types/packing.hpp +++ b/include/types/packing.hpp @@ -9,16 +9,16 @@ struct Packing; template <> struct Packing<__half, 2> { - static constexpr int kDateBytes = 2; - static constexpr int kPackedBytes = 4; - using PackedType = int; + static constexpr int kDateBytes = 2; + static constexpr int kPackedBytes = 4; + using PackedType = int; }; template <> struct Packing { - static constexpr int kDateBytes = 4; - static constexpr int kPackedBytes = 4; - using PackedType = int2; + static constexpr int kDateBytes = 4; + static constexpr int kPackedBytes = 4; + using PackedType = int2; }; } // namespace tilefusion diff --git a/include/types/register.hpp b/include/types/register.hpp index 99ade752..0ff0e83e 100644 --- a/include/types/register.hpp +++ b/include/types/register.hpp @@ -47,81 +47,81 @@ constexpr int get_cols<__fp8_e5m2> = 1; /// @brief Helper for pretty printing a register tile's static shape /// information. This printer works ONLY on the host. struct RegTilePrettyPrinter { - template - static HOST void print(std::ostream& out, const Tile& tile) { - out << layout_type_to_str(Tile::kType) << "(" - << Tile::kRows * get_rows << ", " - << Tile::kCols * get_cols << ")"; - } + template + static HOST void print(std::ostream& out, const Tile& tile) { + out << layout_type_to_str(Tile::kType) << "(" + << Tile::kRows * get_rows << ", " + << Tile::kCols * get_cols << ")"; + } }; DEVICE void clear(float* data, int numel) { - memset((void*)data, 0, sizeof(float) * numel); + memset((void*)data, 0, sizeof(float) * numel); } DEVICE void clear(__half* data, int numel) { - memset((void*)data, 0, sizeof(__half) * numel); + memset((void*)data, 0, sizeof(__half) * numel); } #ifdef CUDA_FP8_AVAILABLE DEVICE void clear(__fp8_e4m3* data, int numel) { - memset((void*)data, 0, sizeof(__fp8_e4m3) * numel); + memset((void*)data, 0, sizeof(__fp8_e4m3) * numel); } DEVICE void clear(__fp8_e5m2* data, int numel) { - memset((void*)data, 0, sizeof(__fp8_e5m2) * numel); + memset((void*)data, 0, sizeof(__fp8_e5m2) * numel); } #endif template DEVICE void clear_impl(DType* data, int numel) { - for (int i = 0; i < numel; ++i) { - clear(data[i].mutable_data(), 8); - } + for (int i = 0; i < numel; ++i) { + clear(data[i].mutable_data(), 8); + } } } // namespace template class RegTile { - public: - using DType = Element_; - using Layout = Layout_; + public: + using DType = Element_; + using Layout = Layout_; - static constexpr int kNumel = Layout::kNumel; - static constexpr int kRows = Layout::kRows; - static constexpr int kCols = Layout::kCols; + static constexpr int kNumel = Layout::kNumel; + static constexpr int kRows = Layout::kRows; + static constexpr int kCols = Layout::kCols; - // FIXME(haruhi): this is a hack to fix the layout type deduction for when - // the shape is 1x1. This is a workaround. Fix this to be more robust. - static constexpr tl::Layout kType = tl::layout_type; + // FIXME(haruhi): this is a hack to fix the layout type deduction for when + // the shape is 1x1. This is a workaround. Fix this to be more robust. + static constexpr tl::Layout kType = tl::layout_type; - DEVICE RegTile() : layout_(Layout{}) { - memset((void*)data_, 0, sizeof(data_)); - } + DEVICE RegTile() : layout_(Layout{}) { + memset((void*)data_, 0, sizeof(data_)); + } - DEVICE DType* mutable_data() { return (DType*)data_; } + DEVICE DType* mutable_data() { return (DType*)data_; } - DEVICE const DType* data() const { return (DType*)data_; } + DEVICE const DType* data() const { return (DType*)data_; } - DEVICE const Layout& layout() const { return layout_; } + DEVICE const Layout& layout() const { return layout_; } - // for write access - DEVICE DType& operator()(int x, int y) { return data_[layout_(x, y)]; } + // for write access + DEVICE DType& operator()(int x, int y) { return data_[layout_(x, y)]; } - // for read access - DEVICE const DType& operator()(int x, int y) const { - return data_[layout_(x, y)]; - } + // for read access + DEVICE const DType& operator()(int x, int y) const { + return data_[layout_(x, y)]; + } - DEVICE void dump_value() const { - util::print_tile(const_cast(data_), layout_); - } + DEVICE void dump_value() const { + util::print_tile(const_cast(data_), layout_); + } - DEVICE void clear() { clear_impl(data_, kNumel); } + DEVICE void clear() { clear_impl(data_, kNumel); } - private: - DType data_[kNumel]; - Layout layout_; + private: + DType data_[kNumel]; + Layout layout_; }; template @@ -140,8 +140,8 @@ using BaseTileColMajor = RegTile>; template static HOST std::ostream& operator<<(std::ostream& out, const RegTile& tile) { - RegTilePrettyPrinter::print(out, tile); - return out; + RegTilePrettyPrinter::print(out, tile); + return out; } } // namespace tilefusion diff --git a/include/types/shared.hpp b/include/types/shared.hpp index fbb04a95..d1a32097 100644 --- a/include/types/shared.hpp +++ b/include/types/shared.hpp @@ -16,15 +16,15 @@ namespace { /// @brief Helper for pretty printing a SharedTile's static shape-related /// information. This printer works ONLY on the host. struct SharedTilePrettyPrinter { - template - static HOST void print(std::ostream& out, const Shared& tile) { - // parameter `tile` here is not used - auto swizzled = Shared::kSwizzled ? "swizzled" : "non-swizzled"; - out << "SharedTile {" << std::endl - << " " << typename Shared::Layout{} << std::endl - << " Swizzled = " << swizzled << std::endl - << "}"; - } + template + static HOST void print(std::ostream& out, const Shared& tile) { + // parameter `tile` here is not used + auto swizzled = Shared::kSwizzled ? "swizzled" : "non-swizzled"; + out << "SharedTile {" << std::endl + << " " << typename Shared::Layout{} << std::endl + << " Swizzled = " << swizzled << std::endl + << "}"; + } }; } // namespace @@ -32,123 +32,121 @@ struct SharedTilePrettyPrinter { template class SharedTile { - public: - using DType = Element_; - using Layout = Layout_; + public: + using DType = Element_; + using Layout = Layout_; - static constexpr int kRows = Layout::kRows; - static constexpr int kCols = Layout::kCols; - static constexpr int kRowStride = Layout::kRowStride; - static constexpr int kColStride = Layout::kColStride; + static constexpr int kRows = Layout::kRows; + static constexpr int kCols = Layout::kCols; + static constexpr int kRowStride = Layout::kRowStride; + static constexpr int kColStride = Layout::kColStride; - static constexpr tl::Layout kType = Layout::kType; - static constexpr int kNumel = Layout::kNumel; + static constexpr tl::Layout kType = Layout::kType; + static constexpr int kNumel = Layout::kNumel; - static constexpr bool isRowMajor = tl::is_row_major::value; + static constexpr bool isRowMajor = tl::is_row_major::value; - static constexpr int SwizzleBytes = SwizzleBytes_; - static constexpr bool kSwizzled = kSwizzled_; + static constexpr int SwizzleBytes = SwizzleBytes_; + static constexpr bool kSwizzled = kSwizzled_; - using SwizzleBaseShape = SwizzleBaseTileShape; + using SwizzleBaseShape = SwizzleBaseTileShape; - static constexpr int kSwizzleRows = - isRowMajor ? SwizzleBaseShape::kRows : SwizzleBaseShape::kCols; - static constexpr int kSwizzleCols = - isRowMajor ? SwizzleBaseShape::kCols : SwizzleBaseShape::kRows; + static constexpr int kSwizzleRows = + isRowMajor ? SwizzleBaseShape::kRows : SwizzleBaseShape::kCols; + static constexpr int kSwizzleCols = + isRowMajor ? SwizzleBaseShape::kCols : SwizzleBaseShape::kRows; - using NonSwizzled = std::conditional_t< - isRowMajor, tl::MatrixLayout, - tl::MatrixLayout>; + using NonSwizzled = std::conditional_t< + isRowMajor, tl::MatrixLayout, + tl::MatrixLayout>; - using Swizzled = SwizzledLayout< - NonSwizzled, - Swizzle, - kType>; + using Swizzled = SwizzledLayout< + NonSwizzled, + Swizzle, + kType>; - using InTileLayout = std::conditional_t; + using InTileLayout = std::conditional_t; - using TileLayout = std::conditional_t< - isRowMajor, - tl::MatrixLayout, - tl::MatrixLayout>; + using TileLayout = std::conditional_t< + isRowMajor, + tl::MatrixLayout, + tl::MatrixLayout>; - InTileLayout in_tile_layout_; - TileLayout tile_layout_; + InTileLayout in_tile_layout_; + TileLayout tile_layout_; - // This Ctor is to enable the use of the pretty printer of SharedTile - // in the host code. - DEVICE SharedTile() : data_(nullptr), layout_(Layout{}), offset_(0) {} + // This Ctor is to enable the use of the pretty printer of SharedTile + // in the host code. + DEVICE SharedTile() : data_(nullptr), layout_(Layout{}), offset_(0) {} - DEVICE SharedTile(DType* data) - : data_(data), layout_(Layout{}), offset_(0) {} + DEVICE SharedTile(DType* data) : data_(data), layout_(Layout{}), offset_(0) {} - DEVICE SharedTile(const DType* data) - : data_(const_cast(data)), layout_(Layout{}), offset_(0) {} + DEVICE SharedTile(const DType* data) + : data_(const_cast(data)), layout_(Layout{}), offset_(0) {} - DEVICE SharedTile(DType* data, int offset) - : data_(data), layout_(Layout{}), offset_(offset) {} + DEVICE SharedTile(DType* data, int offset) + : data_(data), layout_(Layout{}), offset_(offset) {} - DEVICE DType* mutable_data() { return data_; } + DEVICE DType* mutable_data() { return data_; } - DEVICE const DType* data() const { return data_; } + DEVICE const DType* data() const { return data_; } - DEVICE int get_offset() const { return offset_; } + DEVICE int get_offset() const { return offset_; } - HOST_DEVICE const Layout& layout() const { return layout_; } + HOST_DEVICE const Layout& layout() const { return layout_; } - // for write access - DEVICE DType& operator()(int x, int y) { return data_[layout_(x, y)]; } + // for write access + DEVICE DType& operator()(int x, int y) { return data_[layout_(x, y)]; } - // for read access - DEVICE - const DType& operator()(int x, int y) const { return data_[layout_(x, y)]; } + // for read access + DEVICE + const DType& operator()(int x, int y) const { return data_[layout_(x, y)]; } - DEVICE int fetch_physical_offset(int offset) { - return swizzle_offset(offset); - } + DEVICE int fetch_physical_offset(int offset) { + return swizzle_offset(offset); + } - DEVICE void dump_value() { util::print_tile(data_, layout_); } + DEVICE void dump_value() { util::print_tile(data_, layout_); } - private: - DType* data_; - Layout layout_; - int offset_; + private: + DType* data_; + Layout layout_; + int offset_; - DEVICE int2 swizzle_tile_id(int offset) { - int swizzle_tile_row = kType == tl::Layout::kRowMajor - ? (offset / kRowStride) / kSwizzleRows - : (offset % kColStride) / kSwizzleRows; + DEVICE int2 swizzle_tile_id(int offset) { + int swizzle_tile_row = kType == tl::Layout::kRowMajor + ? (offset / kRowStride) / kSwizzleRows + : (offset % kColStride) / kSwizzleRows; - int swizzle_tile_col = kType == tl::Layout::kRowMajor - ? (offset % kRowStride) / kSwizzleCols - : (offset / kColStride) / kSwizzleCols; + int swizzle_tile_col = kType == tl::Layout::kRowMajor + ? (offset % kRowStride) / kSwizzleCols + : (offset / kColStride) / kSwizzleCols; - return make_int2(swizzle_tile_row, swizzle_tile_col); - } + return make_int2(swizzle_tile_row, swizzle_tile_col); + } - DEVICE int2 in_swizzle_tile_id(int offset) { - int row = kType == tl::Layout::kRowMajor ? offset / kRowStride - : offset % kColStride; - int col = kType == tl::Layout::kRowMajor ? offset % kRowStride - : offset / kColStride; + DEVICE int2 in_swizzle_tile_id(int offset) { + int row = kType == tl::Layout::kRowMajor ? offset / kRowStride + : offset % kColStride; + int col = kType == tl::Layout::kRowMajor ? offset % kRowStride + : offset / kColStride; - int in_swizzle_tile_row = row % kSwizzleRows; - int in_swizzle_tile_col = col % kSwizzleCols; + int in_swizzle_tile_row = row % kSwizzleRows; + int in_swizzle_tile_col = col % kSwizzleCols; - return make_int2(in_swizzle_tile_row, in_swizzle_tile_col); - } + return make_int2(in_swizzle_tile_row, in_swizzle_tile_col); + } - DEVICE int swizzle_offset(int offset) { - auto tile_id = swizzle_tile_id(offset); - auto in_tile_id = in_swizzle_tile_id(offset); - int swizzle_tile_offset = tile_layout_(tile_id.x, tile_id.y); - int in_swizzle_tile_offset = - in_tile_layout_(in_tile_id.x, in_tile_id.y); + DEVICE int swizzle_offset(int offset) { + auto tile_id = swizzle_tile_id(offset); + auto in_tile_id = in_swizzle_tile_id(offset); + int swizzle_tile_offset = tile_layout_(tile_id.x, tile_id.y); + int in_swizzle_tile_offset = in_tile_layout_(in_tile_id.x, in_tile_id.y); - return swizzle_tile_offset + in_swizzle_tile_offset; - } + return swizzle_tile_offset + in_swizzle_tile_offset; + } }; /// @brief Pretty printer for the static shape information of a SharedTile. @@ -156,8 +154,8 @@ class SharedTile { template static HOST std::ostream& operator<<( std::ostream& out, const SharedTile& tile) { - SharedTilePrettyPrinter::print(out, tile); - return out; + SharedTilePrettyPrinter::print(out, tile); + return out; } } // namespace tilefusion diff --git a/include/types/shared_tile_iterator.hpp b/include/types/shared_tile_iterator.hpp index 390a8305..9fdc40b9 100644 --- a/include/types/shared_tile_iterator.hpp +++ b/include/types/shared_tile_iterator.hpp @@ -15,28 +15,26 @@ namespace { /// @brief Helper for pretty printing a tile iterator's static shape-related /// information. This printer works ONLY on the host. struct STileIteratorPrettyPrinter { - template - static HOST void print(std::ostream& out, const TileIterator& itr) { - out << "SharedTileIterator {" << std::endl - << " ChunkShape = (" << TileIterator::kChunkRows << ", " - << TileIterator::kChunkCols << "), stripe count = (" - << TileIterator::sc0 << ", " << TileIterator::sc1 << ")" - << std::endl - << "}"; - } + template + static HOST void print(std::ostream& out, const TileIterator& itr) { + out << "SharedTileIterator {" << std::endl + << " ChunkShape = (" << TileIterator::kChunkRows << ", " + << TileIterator::kChunkCols << "), stripe count = (" + << TileIterator::sc0 << ", " << TileIterator::sc1 << ")" << std::endl + << "}"; + } }; /// @brief Helper for pretty printing STileIterator2's static shape information struct STileIterator2PrettyPrinter { - template - static HOST void print(std::ostream& out, const TileIterator& itr) { - out << "SharedTileIterator2 {" << std::endl - << " ChunkShape = (" << TileIterator::kChunkRows << ", " - << TileIterator::kChunkCols << "), stripe count = (" - << TileIterator::sc0 << ", " << TileIterator::sc1 << ")" - << std::endl - << "}"; - } + template + static HOST void print(std::ostream& out, const TileIterator& itr) { + out << "SharedTileIterator2 {" << std::endl + << " ChunkShape = (" << TileIterator::kChunkRows << ", " + << TileIterator::kChunkCols << "), stripe count = (" + << TileIterator::sc0 << ", " << TileIterator::sc1 << ")" << std::endl + << "}"; + } }; /// @brief Type trait to detect if a layout is a BlockMatrxLayout @@ -59,19 +57,18 @@ struct SubTileLayoutCreator; /// preserve the block structure template struct SubTileLayoutCreator { - using OuterLayout = - tl::MatrixLayout; - using type = - tl::BlockMatrxLayout; + using OuterLayout = + tl::MatrixLayout; + using type = + tl::BlockMatrxLayout; }; /// @brief Specialization for simple MatrixLayout template struct SubTileLayoutCreator { - using type = - tl::MatrixLayout; + using type = tl::MatrixLayout; }; template @@ -86,107 +83,105 @@ using SubTileLayout_t = /// @param ChunkShape_ The shape of the smaller tiles (chunk shape) template class STileIterator { - public: - using Tile = Tile_; - using DType = Tile::DType; - using ChunkShape = ChunkShape_; + public: + using Tile = Tile_; + using DType = Tile::DType; + using ChunkShape = ChunkShape_; - // FIXME(ying): a hotfix. The akwared dependencies on mma will be removed - // in future refactor. - using MmaAtom = MmaAtom<__half, __half, __half, MMA_ATOM_16x16x16>; - using BaseShape = typename MmaAtom::BaseTile; + // FIXME(ying): a hotfix. The akwared dependencies on mma will be removed + // in future refactor. + using MmaAtom = MmaAtom<__half, __half, __half, MMA_ATOM_16x16x16>; + using BaseShape = typename MmaAtom::BaseTile; - static constexpr int kChunkRows = dim_size<0, ChunkShape>; - static constexpr int kChunkCols = dim_size<1, ChunkShape>; + static constexpr int kChunkRows = dim_size<0, ChunkShape>; + static constexpr int kChunkCols = dim_size<1, ChunkShape>; - static_assert(Tile::kRows >= kChunkRows, - "Tile::kRows must be >= kChunkRows"); - static_assert(Tile::kCols >= kChunkCols, - "Tile::kCols must be >= kChunkCols"); + static_assert(Tile::kRows >= kChunkRows, "Tile::kRows must be >= kChunkRows"); + static_assert(Tile::kCols >= kChunkCols, "Tile::kCols must be >= kChunkCols"); - static constexpr int sc0 = Tile::kRows / kChunkRows; - static constexpr int sc1 = Tile::kCols / kChunkCols; + static constexpr int sc0 = Tile::kRows / kChunkRows; + static constexpr int sc1 = Tile::kCols / kChunkCols; - HOST_DEVICE STileIterator() : data_(nullptr) {} + HOST_DEVICE STileIterator() : data_(nullptr) {} - DEVICE explicit STileIterator(DType* data) : data_(data) {} + DEVICE explicit STileIterator(DType* data) : data_(data) {} - DEVICE explicit STileIterator(const DType* data) - : data_(const_cast(data)) {} + DEVICE explicit STileIterator(const DType* data) + : data_(const_cast(data)) {} - /// @brief Access a single sub-tile by linear index - /// @param i Linear index of the sub-tile - /// @return A new tile representing the sub-tile - DEVICE auto operator()(int i) { - static_assert(sc0 == 1 || sc1 == 1, - "A single index is supported only when the strip count " - "of one of the iterator's dimensions is 1."); + /// @brief Access a single sub-tile by linear index + /// @param i Linear index of the sub-tile + /// @return A new tile representing the sub-tile + DEVICE auto operator()(int i) { + static_assert(sc0 == 1 || sc1 == 1, + "A single index is supported only when the strip count " + "of one of the iterator's dimensions is 1."); - assert(data_ != nullptr); + assert(data_ != nullptr); - const int x = sc0 == 1 ? 0 : i; - const int y = sc0 == 1 ? i : 0; + const int x = sc0 == 1 ? 0 : i; + const int y = sc0 == 1 ? i : 0; - using TileLayout = tl::MatrixLayout; - using NewTile = - SharedTile; + using TileLayout = tl::MatrixLayout; + using NewTile = + SharedTile; - const int offset = compute_offset(x, y); - return NewTile(data_ + offset, offset); - } + const int offset = compute_offset(x, y); + return NewTile(data_ + offset, offset); + } - DEVICE auto operator()(int x, int y) { - assert(false && "Not implemented yet."); - return 0; - } + DEVICE auto operator()(int x, int y) { + assert(false && "Not implemented yet."); + return 0; + } - DEVICE auto operator()(int x, const Underscore& y) { - assert(false && "Not implemented yet."); - return 0; - } + DEVICE auto operator()(int x, const Underscore& y) { + assert(false && "Not implemented yet."); + return 0; + } - DEVICE auto operator()(const Underscore& x, int y) { - assert(false && "Not implemented yet."); - return 0; - } + DEVICE auto operator()(const Underscore& x, int y) { + assert(false && "Not implemented yet."); + return 0; + } - /// @brief Convert back to the original tile - DEVICE auto to_tile() { return Tile(data_); } + /// @brief Convert back to the original tile + DEVICE auto to_tile() { return Tile(data_); } - private: - // pre-compute values - static constexpr int kTilePerRow = Tile::kRows / BaseShape::kRows; - static constexpr int kTilePerCol = Tile::kCols / BaseShape::kCols; + private: + // pre-compute values + static constexpr int kTilePerRow = Tile::kRows / BaseShape::kRows; + static constexpr int kTilePerCol = Tile::kCols / BaseShape::kCols; - static constexpr int kTilePerChunkRow = kChunkRows / BaseShape::kRows; - static constexpr int kTilePerChunkCol = kChunkCols / BaseShape::kCols; + static constexpr int kTilePerChunkRow = kChunkRows / BaseShape::kRows; + static constexpr int kTilePerChunkCol = kChunkCols / BaseShape::kCols; - static constexpr bool kIsRowMajor = Tile::kType == tl::Layout::kRowMajor; + static constexpr bool kIsRowMajor = Tile::kType == tl::Layout::kRowMajor; - static constexpr int kTileRowStride = kIsRowMajor ? Tile::kCols : 1; - static constexpr int kTileColStride = kIsRowMajor ? 1 : Tile::kRows; + static constexpr int kTileRowStride = kIsRowMajor ? Tile::kCols : 1; + static constexpr int kTileColStride = kIsRowMajor ? 1 : Tile::kRows; - /// @brief Compute memory offset for sub-tile at position (x, y) - DEVICE int compute_offset(int x, int y) const { - if constexpr (kIsRowMajor) { - return x * (kChunkRows * Tile::kRowStride) + - y * kTilePerChunkCol * BaseShape::kCols; - } else { - return x * kTilePerChunkRow * BaseShape::kRows + - y * (Tile::kColStride * kChunkCols); - } + /// @brief Compute memory offset for sub-tile at position (x, y) + DEVICE int compute_offset(int x, int y) const { + if constexpr (kIsRowMajor) { + return x * (kChunkRows * Tile::kRowStride) + + y * kTilePerChunkCol * BaseShape::kCols; + } else { + return x * kTilePerChunkRow * BaseShape::kRows + + y * (Tile::kColStride * kChunkCols); } + } - DType* data_; + DType* data_; }; /// @brief Pretty printer for STileIterator template static HOST std::ostream& operator<<( std::ostream& out, const STileIterator& itr) { - STileIteratorPrettyPrinter::print(out, itr); - return out; + STileIteratorPrettyPrinter::print(out, itr); + return out; } /// @brief Advanced SharedTileIterator with better block layout support @@ -194,111 +189,111 @@ static HOST std::ostream& operator<<( /// @param ChunkShape_ The shape of the smaller tiles (chunk shape) template class STileIterator2 { - public: - using Tile = Tile_; - using DType = Tile::DType; - using ChunkShape = ChunkShape_; - - static constexpr int kChunkRows = dim_size<0, ChunkShape>; - static constexpr int kChunkCols = dim_size<1, ChunkShape>; - - static_assert( - Tile::kRows >= kChunkRows && Tile::kRows % kChunkRows == 0, - "Tile::kRows must be >= kChunkRows and divisible by kChunkRows"); - static_assert( - Tile::kCols >= kChunkCols && Tile::kCols % kChunkCols == 0, - "Tile::kCols must be >= kChunkCols and divisible by kChunkCols"); - - static constexpr int sc0 = Tile::kRows / kChunkRows; - static constexpr int sc1 = Tile::kCols / kChunkCols; - static constexpr int kNumel = sc0 * sc1; - - HOST_DEVICE STileIterator2() : tile_(nullptr), data_(nullptr) {} - - DEVICE explicit STileIterator2(Tile* tile) - : tile_(tile), data_(const_cast(tile->data())) {} - - /// @brief Access a single sub-tile by linear index - /// @param i Linear index of the sub-tile - /// @return A new tile representing the sub-tile - DEVICE auto operator()(int i) { - static_assert(sc0 == 1 || sc1 == 1, - "A single index is supported only when the strip count " - "of one of the iterator's dimensions is 1."); - assert(tile_ != nullptr && data_ != nullptr); - - // A tile is partitioned into sub-tiles along the row or column - // dimension. `x` and `y` are the indices of the sub-tile in the - // row and column dimension, respectively. - const int x = sc0 == 1 ? 0 : i; - const int y = sc0 == 1 ? i : 0; - - using TileLayout = SubTileLayout_t; - using NewTile = SharedTile; - - const int offset = compute_offset(x, y); - return NewTile(data_ + offset); + public: + using Tile = Tile_; + using DType = Tile::DType; + using ChunkShape = ChunkShape_; + + static constexpr int kChunkRows = dim_size<0, ChunkShape>; + static constexpr int kChunkCols = dim_size<1, ChunkShape>; + + static_assert( + Tile::kRows >= kChunkRows && Tile::kRows % kChunkRows == 0, + "Tile::kRows must be >= kChunkRows and divisible by kChunkRows"); + static_assert( + Tile::kCols >= kChunkCols && Tile::kCols % kChunkCols == 0, + "Tile::kCols must be >= kChunkCols and divisible by kChunkCols"); + + static constexpr int sc0 = Tile::kRows / kChunkRows; + static constexpr int sc1 = Tile::kCols / kChunkCols; + static constexpr int kNumel = sc0 * sc1; + + HOST_DEVICE STileIterator2() : tile_(nullptr), data_(nullptr) {} + + DEVICE explicit STileIterator2(Tile* tile) + : tile_(tile), data_(const_cast(tile->data())) {} + + /// @brief Access a single sub-tile by linear index + /// @param i Linear index of the sub-tile + /// @return A new tile representing the sub-tile + DEVICE auto operator()(int i) { + static_assert(sc0 == 1 || sc1 == 1, + "A single index is supported only when the strip count " + "of one of the iterator's dimensions is 1."); + assert(tile_ != nullptr && data_ != nullptr); + + // A tile is partitioned into sub-tiles along the row or column + // dimension. `x` and `y` are the indices of the sub-tile in the + // row and column dimension, respectively. + const int x = sc0 == 1 ? 0 : i; + const int y = sc0 == 1 ? i : 0; + + using TileLayout = SubTileLayout_t; + using NewTile = SharedTile; + + const int offset = compute_offset(x, y); + return NewTile(data_ + offset); + } + + DEVICE auto operator()(int x, int y) { + assert(false && "Not implemented yet."); + return 0; + } + + DEVICE auto operator()(int x, const Underscore& y) { + assert(false && "Not implemented yet."); + return 0; + } + + DEVICE auto operator()(const Underscore& x, int y) { + assert(false && "Not implemented yet."); + return 0; + } + + /// @brief Convert back to the original tile + DEVICE auto to_tile() { + assert(tile_ != nullptr); + return *tile_; + } + + private: + using Layout = typename Tile::Layout; + static constexpr bool kIsBlockLayout = is_block_layout_v; + + // Compute stride multipliers based on layout type + static constexpr int kRowCount = []() { + if constexpr (kIsBlockLayout) { + return kChunkRows / Layout::InnerLayout::kRows; + } else { + return kChunkRows; } + }(); - DEVICE auto operator()(int x, int y) { - assert(false && "Not implemented yet."); - return 0; + static constexpr int kColCount = []() { + if constexpr (kIsBlockLayout) { + return kChunkCols / Layout::InnerLayout::kCols; + } else { + return kChunkCols; } + }(); - DEVICE auto operator()(int x, const Underscore& y) { - assert(false && "Not implemented yet."); - return 0; - } - - DEVICE auto operator()(const Underscore& x, int y) { - assert(false && "Not implemented yet."); - return 0; - } + static constexpr int kRowStride = Layout::kRowStride * kRowCount; + static constexpr int kColStride = Layout::kColStride * kColCount; - /// @brief Convert back to the original tile - DEVICE auto to_tile() { - assert(tile_ != nullptr); - return *tile_; - } - - private: - using Layout = typename Tile::Layout; - static constexpr bool kIsBlockLayout = is_block_layout_v; - - // Compute stride multipliers based on layout type - static constexpr int kRowCount = []() { - if constexpr (kIsBlockLayout) { - return kChunkRows / Layout::InnerLayout::kRows; - } else { - return kChunkRows; - } - }(); - - static constexpr int kColCount = []() { - if constexpr (kIsBlockLayout) { - return kChunkCols / Layout::InnerLayout::kCols; - } else { - return kChunkCols; - } - }(); - - static constexpr int kRowStride = Layout::kRowStride * kRowCount; - static constexpr int kColStride = Layout::kColStride * kColCount; - - /// @brief Compute memory offset for sub-tile at position (x, y) - DEVICE int compute_offset(int x, int y) const { - return x * kRowStride + y * kColStride; - } + /// @brief Compute memory offset for sub-tile at position (x, y) + DEVICE int compute_offset(int x, int y) const { + return x * kRowStride + y * kColStride; + } - Tile* tile_; - DType* data_; + Tile* tile_; + DType* data_; }; /// @brief Pretty printer for STileIterator2 template static HOST std::ostream& operator<<( std::ostream& out, const STileIterator2& itr) { - STileIterator2PrettyPrinter::print(out, itr); - return out; + STileIterator2PrettyPrinter::print(out, itr); + return out; } } // namespace tilefusion diff --git a/include/types/swizzle.hpp b/include/types/swizzle.hpp index 6c1da778..67d86674 100644 --- a/include/types/swizzle.hpp +++ b/include/types/swizzle.hpp @@ -15,31 +15,31 @@ namespace tl = tile_layout; */ template struct Swizzle { - static constexpr int Bbits = kB; - static constexpr int Mbits = kM; - static constexpr int Sbits = kS; - /** - * @brief Applies the swizzle function to permute a 1-D index. - * - * @param idx The 1-D index within the swizzle space of 2^B * 2^S * 2^M - * elements. - * @return The permuted (swizzled) index. - */ - HOST_DEVICE int operator()(int idx) const { - // | Bbits | Sbits | Mbits | - // Mbits as mask for the lower bits. - - int bs = idx >> Mbits; - // (b, s) as a 2d coordinate. - int y = bs & ((1 << Sbits) - 1); - int x = bs >> Sbits; - - int swizzled_y = x ^ y; - - // Use swizzled_y instead of y and build swizzled idx. - return (x << (Mbits + Sbits)) | (swizzled_y << Mbits) | - (idx & ((1 << Mbits) - 1)); - } + static constexpr int Bbits = kB; + static constexpr int Mbits = kM; + static constexpr int Sbits = kS; + /** + * @brief Applies the swizzle function to permute a 1-D index. + * + * @param idx The 1-D index within the swizzle space of 2^B * 2^S * 2^M + * elements. + * @return The permuted (swizzled) index. + */ + HOST_DEVICE int operator()(int idx) const { + // | Bbits | Sbits | Mbits | + // Mbits as mask for the lower bits. + + int bs = idx >> Mbits; + // (b, s) as a 2d coordinate. + int y = bs & ((1 << Sbits) - 1); + int x = bs >> Sbits; + + int swizzled_y = x ^ y; + + // Use swizzled_y instead of y and build swizzled idx. + return (x << (Mbits + Sbits)) | (swizzled_y << Mbits) | + (idx & ((1 << Mbits) - 1)); + } }; /** @@ -62,84 +62,84 @@ struct SwizzledLayout; template struct SwizzledLayout { - using Layout = Layout_; - using Swizzle = Swizzle_; - - static constexpr int Bbits = Swizzle_::Bbits; - static constexpr int Mbits = Swizzle_::Mbits; - static constexpr int Sbits = Swizzle_::Sbits; - - static_assert(Layout::kRows == (1 << Bbits), - "The number of rows in the layout should be 2^B."); - static_assert(Layout::kCols == (1 << (Mbits + Sbits)), - "The number of columns in the layout should be 2^S * 2^M."); - - // to be compatible with all the other layouts - static constexpr int kRows = Layout::kRows; - static constexpr int kCols = Layout::kCols; - static constexpr int kNumel = Layout::kNumel; - static constexpr tl::Layout kType = Layout::kType; - - /** - * @brief Compose the swizzle function with the layout function. - * - * @param x The row index, with a total of 2^B rows. - * @param y The column index, with a total of 2^S * 2^M columns. - * @return The swizzled index after applying the layout function. - */ - HOST_DEVICE auto operator()(int x, int y) const { - int idx = (x << (Mbits + Sbits)) | y; - - int swizzled_idx = swizzle_(idx); - int swizzled_x = swizzled_idx >> (Mbits + Sbits); - int swizzled_y = swizzled_idx & ((1 << (Mbits + Sbits)) - 1); - return layout_(swizzled_x, swizzled_y); - } - - private: - Swizzle swizzle_; - Layout layout_; + using Layout = Layout_; + using Swizzle = Swizzle_; + + static constexpr int Bbits = Swizzle_::Bbits; + static constexpr int Mbits = Swizzle_::Mbits; + static constexpr int Sbits = Swizzle_::Sbits; + + static_assert(Layout::kRows == (1 << Bbits), + "The number of rows in the layout should be 2^B."); + static_assert(Layout::kCols == (1 << (Mbits + Sbits)), + "The number of columns in the layout should be 2^S * 2^M."); + + // to be compatible with all the other layouts + static constexpr int kRows = Layout::kRows; + static constexpr int kCols = Layout::kCols; + static constexpr int kNumel = Layout::kNumel; + static constexpr tl::Layout kType = Layout::kType; + + /** + * @brief Compose the swizzle function with the layout function. + * + * @param x The row index, with a total of 2^B rows. + * @param y The column index, with a total of 2^S * 2^M columns. + * @return The swizzled index after applying the layout function. + */ + HOST_DEVICE auto operator()(int x, int y) const { + int idx = (x << (Mbits + Sbits)) | y; + + int swizzled_idx = swizzle_(idx); + int swizzled_x = swizzled_idx >> (Mbits + Sbits); + int swizzled_y = swizzled_idx & ((1 << (Mbits + Sbits)) - 1); + return layout_(swizzled_x, swizzled_y); + } + + private: + Swizzle swizzle_; + Layout layout_; }; template struct SwizzledLayout { - using Layout = Layout_; - using Swizzle = Swizzle_; - - static constexpr int Bbits = Swizzle::Bbits; - static constexpr int Mbits = Swizzle::Mbits; - static constexpr int Sbits = Swizzle::Sbits; - - static_assert(Layout::kRows == (1 << (Mbits + Sbits)), - "The number of rows in the layout should be 2^S * 2^M."); - static_assert(Layout::kCols == (1 << Bbits), - "The number of columns in the layout should be 2^B."); - - // to be compatible with all the other layouts - static constexpr int kRows = Layout::kRows; - static constexpr int kCols = Layout::kCols; - static constexpr int kNumel = Layout::kNumel; - static constexpr tl::Layout kType = Layout::kType; - - /** - * @brief Compose the swizzle function with the layout function. - * - * @param x The row index, with a total of 2^B rows. - * @param y The column index, with a total of 2^S * 2^M columns. - * @return The swizzled index after applying the layout function. - */ - HOST_DEVICE auto operator()(int x, int y) const { - int idx = (y << (Bbits + Mbits)) | x; - - int swizzled_idx = swizzle_(idx); - int swizzled_y = swizzled_idx >> (Mbits + Sbits); - int swizzled_x = swizzled_idx & ((1 << (Mbits + Sbits)) - 1); - return layout_(swizzled_x, swizzled_y); - } - - private: - Swizzle swizzle_; - Layout layout_; + using Layout = Layout_; + using Swizzle = Swizzle_; + + static constexpr int Bbits = Swizzle::Bbits; + static constexpr int Mbits = Swizzle::Mbits; + static constexpr int Sbits = Swizzle::Sbits; + + static_assert(Layout::kRows == (1 << (Mbits + Sbits)), + "The number of rows in the layout should be 2^S * 2^M."); + static_assert(Layout::kCols == (1 << Bbits), + "The number of columns in the layout should be 2^B."); + + // to be compatible with all the other layouts + static constexpr int kRows = Layout::kRows; + static constexpr int kCols = Layout::kCols; + static constexpr int kNumel = Layout::kNumel; + static constexpr tl::Layout kType = Layout::kType; + + /** + * @brief Compose the swizzle function with the layout function. + * + * @param x The row index, with a total of 2^B rows. + * @param y The column index, with a total of 2^S * 2^M columns. + * @return The swizzled index after applying the layout function. + */ + HOST_DEVICE auto operator()(int x, int y) const { + int idx = (y << (Bbits + Mbits)) | x; + + int swizzled_idx = swizzle_(idx); + int swizzled_y = swizzled_idx >> (Mbits + Sbits); + int swizzled_x = swizzled_idx & ((1 << (Mbits + Sbits)) - 1); + return layout_(swizzled_x, swizzled_y); + } + + private: + Swizzle swizzle_; + Layout layout_; }; /// @brief Pretty printer for SwizzledLayout @@ -147,70 +147,70 @@ template static HOST std::ostream& operator<<( std::ostream& out, const SwizzledLayout& layout) { - out << "SwizzledLayout { " << Layout_{} << ", Swizzle<" << Swizzle_::Bbits - << ", " << Swizzle_::Mbits << ", " << Swizzle_::Sbits << "> }"; - return out; + out << "SwizzledLayout { " << Layout_{} << ", Swizzle<" << Swizzle_::Bbits + << ", " << Swizzle_::Mbits << ", " << Swizzle_::Sbits << "> }"; + return out; } /** * @brief The base tile shape for Swizzle<3, 3, 3>. */ template - requires BaseType + requires BaseType struct SwizzleBaseTileShape; template - requires HalfType + requires HalfType struct SwizzleBaseTileShape { - using DType = Element; + using DType = Element; - static constexpr int kRows = 8; - static constexpr int kCols = 64; - static constexpr int kNumel = kRows * kCols; + static constexpr int kRows = 8; + static constexpr int kCols = 64; + static constexpr int kNumel = kRows * kCols; - static constexpr int B = 3; - static constexpr int M = 3; - static constexpr int S = 3; + static constexpr int B = 3; + static constexpr int M = 3; + static constexpr int S = 3; }; template <> struct SwizzleBaseTileShape { - using DType = float; + using DType = float; - static constexpr int kRows = 8; - static constexpr int kCols = 32; - static constexpr int kNumel = kRows * kCols; + static constexpr int kRows = 8; + static constexpr int kCols = 32; + static constexpr int kNumel = kRows * kCols; - static constexpr int B = 3; - static constexpr int M = 2; - static constexpr int S = 3; + static constexpr int B = 3; + static constexpr int M = 2; + static constexpr int S = 3; }; template - requires HalfType + requires HalfType struct SwizzleBaseTileShape { - using DType = Element; + using DType = Element; - static constexpr int kRows = 4; - static constexpr int kCols = 32; - static constexpr int kNumel = kRows * kCols; + static constexpr int kRows = 4; + static constexpr int kCols = 32; + static constexpr int kNumel = kRows * kCols; - static constexpr int B = 2; - static constexpr int M = 3; - static constexpr int S = 2; + static constexpr int B = 2; + static constexpr int M = 3; + static constexpr int S = 2; }; template <> struct SwizzleBaseTileShape { - using DType = float; + using DType = float; - static constexpr int kRows = 4; - static constexpr int kCols = 16; - static constexpr int kNumel = kRows * kCols; + static constexpr int kRows = 4; + static constexpr int kCols = 16; + static constexpr int kNumel = kRows * kCols; - static constexpr int B = 2; - static constexpr int M = 2; - static constexpr int S = 2; + static constexpr int B = 2; + static constexpr int M = 2; + static constexpr int S = 2; }; } // namespace tilefusion diff --git a/include/types/tile_shape.hpp b/include/types/tile_shape.hpp index b6cc7d02..0143766f 100644 --- a/include/types/tile_shape.hpp +++ b/include/types/tile_shape.hpp @@ -10,15 +10,15 @@ namespace tilefusion { template struct TileShape { - static constexpr cute::array shape = {Ns...}; + static constexpr cute::array shape = {Ns...}; - static constexpr size_t get_numel() { - size_t product = 1; - for (size_t n : shape) product *= n; - return product; - } + static constexpr size_t get_numel() { + size_t product = 1; + for (size_t n : shape) product *= n; + return product; + } - static constexpr size_t kNumel = get_numel(); + static constexpr size_t kNumel = get_numel(); }; template diff --git a/include/util/cuda_info.hpp b/include/util/cuda_info.hpp index dc6b6c6d..3f4bd2fc 100644 --- a/include/util/cuda_info.hpp +++ b/include/util/cuda_info.hpp @@ -12,23 +12,23 @@ namespace tilefusion { // Returns the name of the device. std::string get_device_name() { - cudaDeviceProp prop; - cudaGetDeviceProperties(&prop, 0); + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, 0); - std::stringstream ss(prop.name); - const char delim = ' '; + std::stringstream ss(prop.name); + const char delim = ' '; - std::string s; - std::vector out; + std::string s; + std::vector out; - while (std::getline(ss, s, delim)) { - out.push_back(s); - } + while (std::getline(ss, s, delim)) { + out.push_back(s); + } - std::stringstream out_ss; - int i = 0; - for (; i < static_cast(out.size()) - 1; ++i) out_ss << out[i] << "_"; - out_ss << out[i]; - return out_ss.str(); + std::stringstream out_ss; + int i = 0; + for (; i < static_cast(out.size()) - 1; ++i) out_ss << out[i] << "_"; + out_ss << out[i]; + return out_ss.str(); } } // namespace tilefusion diff --git a/include/util/cuda_timer.hpp b/include/util/cuda_timer.hpp index 7e0b5c53..a57dcc9a 100644 --- a/include/util/cuda_timer.hpp +++ b/include/util/cuda_timer.hpp @@ -13,34 +13,33 @@ namespace tilefusion { /// ... /// float time = timer.stop(); class CudaTimer { - public: - CudaTimer() { - CUDA_CHECK(cudaEventCreate(&start_event)); - CUDA_CHECK(cudaEventCreate(&stop_event)); - } + public: + CudaTimer() { + CUDA_CHECK(cudaEventCreate(&start_event)); + CUDA_CHECK(cudaEventCreate(&stop_event)); + } - ~CudaTimer() { - CUDA_CHECK(cudaEventDestroy(start_event)); - CUDA_CHECK(cudaEventDestroy(stop_event)); - } + ~CudaTimer() { + CUDA_CHECK(cudaEventDestroy(start_event)); + CUDA_CHECK(cudaEventDestroy(stop_event)); + } - void start(cudaStream_t st = 0) { - stream = st; - CUDA_CHECK(cudaEventRecord(start_event, stream)); - } + void start(cudaStream_t st = 0) { + stream = st; + CUDA_CHECK(cudaEventRecord(start_event, stream)); + } - float stop() { - float milliseconds = 0.; - CUDA_CHECK(cudaEventRecord(stop_event, stream)); - CUDA_CHECK(cudaEventSynchronize(stop_event)); - CUDA_CHECK( - cudaEventElapsedTime(&milliseconds, start_event, stop_event)); - return milliseconds; - } + float stop() { + float milliseconds = 0.; + CUDA_CHECK(cudaEventRecord(stop_event, stream)); + CUDA_CHECK(cudaEventSynchronize(stop_event)); + CUDA_CHECK(cudaEventElapsedTime(&milliseconds, start_event, stop_event)); + return milliseconds; + } - private: - cudaEvent_t start_event; - cudaEvent_t stop_event; - cudaStream_t stream; + private: + cudaEvent_t start_event; + cudaEvent_t stop_event; + cudaStream_t stream; }; } // namespace tilefusion diff --git a/include/util/debug.hpp b/include/util/debug.hpp index e2abfb0b..6164813c 100644 --- a/include/util/debug.hpp +++ b/include/util/debug.hpp @@ -10,15 +10,15 @@ namespace tilefusion { DEVICE bool block(int bid) { - int id = blockIdx.x + blockIdx.y * gridDim.x + - blockIdx.z * gridDim.x * gridDim.y; - return id == bid; + int id = + blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y; + return id == bid; } DEVICE bool thread(int tid, int bid) { - int id = threadIdx.x + threadIdx.y * blockDim.x + - threadIdx.z * blockDim.x * blockDim.y; - return id == tid && block(bid); + int id = threadIdx.x + threadIdx.y * blockDim.x + + threadIdx.z * blockDim.x * blockDim.y; + return id == tid && block(bid); } // usage, e.g. diff --git a/include/util/math_utils.hpp b/include/util/math_utils.hpp index 8d843971..c8a51357 100644 --- a/include/util/math_utils.hpp +++ b/include/util/math_utils.hpp @@ -10,46 +10,46 @@ namespace tilefusion { /// @brief Helper function to check if a number is power of 2 at compile time template consteval bool is_power_of_2() { - return kN > 0 && (kN & (kN - 1)) == 0; + return kN > 0 && (kN & (kN - 1)) == 0; } /// @brief Helper function to count trailing zeros template HOST_DEVICE constexpr int count_trailing_zeros() { - static_assert(is_power_of_2(), "kN must be a power of 2"); - int count = 0; - int temp = kN; - while ((temp & 1) == 0) { - temp >>= 1; - count++; - } - return count; + static_assert(is_power_of_2(), "kN must be a power of 2"); + int count = 0; + int temp = kN; + while ((temp & 1) == 0) { + temp >>= 1; + count++; + } + return count; } /// @brief Helper function to compute division for power of 2 template HOST_DEVICE constexpr int div_pow2(int x) { - static_assert(is_power_of_2(), "kN must be a power of 2"); - return x >> count_trailing_zeros(); + static_assert(is_power_of_2(), "kN must be a power of 2"); + return x >> count_trailing_zeros(); } /// @brief Helper function to compute modulo for power of 2 template HOST_DEVICE constexpr int mod_pow2(int x) { - static_assert(is_power_of_2(), "kN must be a power of 2"); - return x & (kN - 1); + static_assert(is_power_of_2(), "kN must be a power of 2"); + return x & (kN - 1); } /// @brief Helper function to compute division and modulo for any number template HOST_DEVICE constexpr int div_any(int x) { - return x / kN; + return x / kN; } /// @brief Helper function to compute modulo for any number template HOST_DEVICE constexpr int mod_any(int x) { - return x % kN; + return x % kN; } /// @brief Select appropriate division/modulo functions based on whether n is @@ -59,15 +59,15 @@ struct DivModSelector; template struct DivModSelector { - static HOST_DEVICE constexpr int div(int x) { return div_any(x); } + static HOST_DEVICE constexpr int div(int x) { return div_any(x); } - static HOST_DEVICE constexpr int mod(int x) { return mod_any(x); } + static HOST_DEVICE constexpr int mod(int x) { return mod_any(x); } }; template struct DivModSelector { - static HOST_DEVICE constexpr int div(int x) { return div_pow2(x); } + static HOST_DEVICE constexpr int div(int x) { return div_pow2(x); } - static HOST_DEVICE constexpr int mod(int x) { return mod_pow2(x); } + static HOST_DEVICE constexpr int mod(int x) { return mod_pow2(x); } }; } // namespace tilefusion diff --git a/include/util/print.hpp b/include/util/print.hpp index 8cb28723..9ba76e00 100644 --- a/include/util/print.hpp +++ b/include/util/print.hpp @@ -11,15 +11,15 @@ namespace tilefusion::util { namespace tl = tile_layout; template - requires BaseType + requires BaseType DEVICE void print_numeric_tile(const DType* data, const Layout& layout) { - for (int i = 0; i < Layout::kRows; ++i) { - for (int j = 0; j < Layout::kCols; ++j) - printf("%.2f, ", to_float(data[layout(i, j)])); - printf("\n"); + for (int i = 0; i < Layout::kRows; ++i) { + for (int j = 0; j < Layout::kCols; ++j) + printf("%.2f, ", to_float(data[layout(i, j)])); + printf("\n"); - if (i && (i + 1) % 16 == 0) printf("\n"); - } + if (i && (i + 1) % 16 == 0) printf("\n"); + } } /// @brief Print a tile of floating point numbers. NOTE: when @@ -31,130 +31,130 @@ DEVICE void print_numeric_tile(const DType* data, const Layout& layout) { // } template DEVICE void print_tile(const DType* data, const Layout& layout) { - if constexpr (std::is_same::value || - std::is_same::value || - std::is_same::value + if constexpr (std::is_same::value || + std::is_same::value || + std::is_same::value #ifdef CUDA_FP8_AVAILABLE - || std::is_same::value || - std::is_same::value + || std::is_same::value || + std::is_same::value #endif - ) { - print_numeric_tile(data, layout); - } else { - /// Since register tile is a nested array-like structure. printing - /// resigter tile hits this function. - for (int i = 0; i < Layout::kRows; ++i) { - for (int j = 0; j < Layout::kCols; ++j) { - auto tile = data[layout(i, j)]; - print_numeric_tile(tile.data(), tile.layout()); - } - } + ) { + print_numeric_tile(data, layout); + } else { + /// Since register tile is a nested array-like structure. printing + /// resigter tile hits this function. + for (int i = 0; i < Layout::kRows; ++i) { + for (int j = 0; j < Layout::kCols; ++j) { + auto tile = data[layout(i, j)]; + print_numeric_tile(tile.data(), tile.layout()); + } } + } } template struct RegVecPrinter { - static constexpr int kRows = RegTile::kRows; + static constexpr int kRows = RegTile::kRows; - DEVICE void operator()(const RegTile& tile, int tid) { - int lane_id = tid % 32; - for (int i = 0; i < kRows; ++i) { - if (lane_id % 4 == 0) { - printf("%.2f, ", to_float(tile(i, 0))); - } + DEVICE void operator()(const RegTile& tile, int tid) { + int lane_id = tid % 32; + for (int i = 0; i < kRows; ++i) { + if (lane_id % 4 == 0) { + printf("%.2f, ", to_float(tile(i, 0))); + } #if defined(__CUDA_ARCH__) - // Sync Threads to print in-order data. - __syncthreads(); + // Sync Threads to print in-order data. + __syncthreads(); #endif - if (lane_id % 4 == 0) { - printf("%.2f, ", to_float(tile(i, 1))); - } - } - - if (lane_id == 0) printf("\n"); + if (lane_id % 4 == 0) { + printf("%.2f, ", to_float(tile(i, 1))); + } } + + if (lane_id == 0) printf("\n"); + } }; template struct RegTilePrinter { - constexpr static int kRows = RegTile::kRows; - constexpr static int kCols = RegTile::kCols; + constexpr static int kRows = RegTile::kRows; + constexpr static int kCols = RegTile::kCols; - void operator()(const RegTile& tile, int tid) {} + void operator()(const RegTile& tile, int tid) {} }; template struct RegTilePrinter { - constexpr static int kRows = RegTile::kRows; - constexpr static int kCols = RegTile::kCols; - - using DType = typename RegTile::DType::DType; - - DEVICE void print_tile_col(const RegTile& tile, int lane_id, int row_num, - bool is_top) { - for (int col_num = 0; col_num < kCols; ++col_num) { - if (is_top) { - printf("%.2f, %.2f, ", to_float(tile(row_num, col_num)(0, 0)), - to_float(tile(row_num, col_num)(0, 1))); - printf("%.2f, %.2f, ", to_float(tile(row_num, col_num)(1, 0)), - to_float(tile(row_num, col_num)(1, 1))); - } else { - printf("%.2f, %.2f, ", to_float(tile(row_num, col_num)(0, 2)), - to_float(tile(row_num, col_num)(0, 3))); - printf("%.2f, %.2f, ", to_float(tile(row_num, col_num)(1, 2)), - to_float(tile(row_num, col_num)(1, 3))); - } - } - if (lane_id % 4 == 0) printf("\n"); + constexpr static int kRows = RegTile::kRows; + constexpr static int kCols = RegTile::kCols; + + using DType = typename RegTile::DType::DType; + + DEVICE void print_tile_col(const RegTile& tile, int lane_id, int row_num, + bool is_top) { + for (int col_num = 0; col_num < kCols; ++col_num) { + if (is_top) { + printf("%.2f, %.2f, ", to_float(tile(row_num, col_num)(0, 0)), + to_float(tile(row_num, col_num)(0, 1))); + printf("%.2f, %.2f, ", to_float(tile(row_num, col_num)(1, 0)), + to_float(tile(row_num, col_num)(1, 1))); + } else { + printf("%.2f, %.2f, ", to_float(tile(row_num, col_num)(0, 2)), + to_float(tile(row_num, col_num)(0, 3))); + printf("%.2f, %.2f, ", to_float(tile(row_num, col_num)(1, 2)), + to_float(tile(row_num, col_num)(1, 3))); + } } - - DEVICE void operator()(const RegTile& tile, int tid) { - // BaseTile base_tile; - int lane_id = tid % 32; - for (int i = 0; i < kRows; ++i) { - // Print top row. - if (lane_id >= 0 && lane_id <= 3) - print_tile_col(tile, lane_id, i, true); - else if (lane_id >= 4 && lane_id <= 7) - print_tile_col(tile, lane_id, i, true); - else if (lane_id >= 8 && lane_id <= 11) - print_tile_col(tile, lane_id, i, true); - else if (lane_id >= 12 && lane_id <= 15) - print_tile_col(tile, lane_id, i, true); - else if (lane_id >= 16 && lane_id <= 19) - print_tile_col(tile, lane_id, i, true); - else if (lane_id >= 20 && lane_id <= 23) - print_tile_col(tile, lane_id, i, true); - else if (lane_id >= 24 && lane_id <= 27) - print_tile_col(tile, lane_id, i, true); - else if (lane_id >= 28 && lane_id <= 31) - print_tile_col(tile, lane_id, i, true); + if (lane_id % 4 == 0) printf("\n"); + } + + DEVICE void operator()(const RegTile& tile, int tid) { + // BaseTile base_tile; + int lane_id = tid % 32; + for (int i = 0; i < kRows; ++i) { + // Print top row. + if (lane_id >= 0 && lane_id <= 3) + print_tile_col(tile, lane_id, i, true); + else if (lane_id >= 4 && lane_id <= 7) + print_tile_col(tile, lane_id, i, true); + else if (lane_id >= 8 && lane_id <= 11) + print_tile_col(tile, lane_id, i, true); + else if (lane_id >= 12 && lane_id <= 15) + print_tile_col(tile, lane_id, i, true); + else if (lane_id >= 16 && lane_id <= 19) + print_tile_col(tile, lane_id, i, true); + else if (lane_id >= 20 && lane_id <= 23) + print_tile_col(tile, lane_id, i, true); + else if (lane_id >= 24 && lane_id <= 27) + print_tile_col(tile, lane_id, i, true); + else if (lane_id >= 28 && lane_id <= 31) + print_tile_col(tile, lane_id, i, true); #if defined(__CUDA_ARCH__) - // Sync Threads to print in-order data. - __syncthreads(); + // Sync Threads to print in-order data. + __syncthreads(); #endif - // Print bottom row. - if (lane_id >= 0 && lane_id <= 3) - print_tile_col(tile, lane_id, i, false); - else if (lane_id >= 4 && lane_id <= 7) - print_tile_col(tile, lane_id, i, false); - else if (lane_id >= 8 && lane_id <= 11) - print_tile_col(tile, lane_id, i, false); - else if (lane_id >= 12 && lane_id <= 15) - print_tile_col(tile, lane_id, i, false); - else if (lane_id >= 16 && lane_id <= 19) - print_tile_col(tile, lane_id, i, false); - else if (lane_id >= 20 && lane_id <= 23) - print_tile_col(tile, lane_id, i, false); - else if (lane_id >= 24 && lane_id <= 27) - print_tile_col(tile, lane_id, i, false); - else if (lane_id >= 28 && lane_id <= 31) - print_tile_col(tile, lane_id, i, false); - } - if (lane_id == 0) printf("\n"); + // Print bottom row. + if (lane_id >= 0 && lane_id <= 3) + print_tile_col(tile, lane_id, i, false); + else if (lane_id >= 4 && lane_id <= 7) + print_tile_col(tile, lane_id, i, false); + else if (lane_id >= 8 && lane_id <= 11) + print_tile_col(tile, lane_id, i, false); + else if (lane_id >= 12 && lane_id <= 15) + print_tile_col(tile, lane_id, i, false); + else if (lane_id >= 16 && lane_id <= 19) + print_tile_col(tile, lane_id, i, false); + else if (lane_id >= 20 && lane_id <= 23) + print_tile_col(tile, lane_id, i, false); + else if (lane_id >= 24 && lane_id <= 27) + print_tile_col(tile, lane_id, i, false); + else if (lane_id >= 28 && lane_id <= 31) + print_tile_col(tile, lane_id, i, false); } + if (lane_id == 0) printf("\n"); + } }; } // namespace tilefusion::util diff --git a/src/cuda_info.cc b/src/cuda_info.cc index b95ea500..efe15391 100644 --- a/src/cuda_info.cc +++ b/src/cuda_info.cc @@ -11,102 +11,101 @@ namespace tilefusion { // Returns the number of GPUs. int GetGPUDeviceCount() { - int deviceCount = 0; - CUDA_CHECK(cudaGetDeviceCount(&deviceCount)); - return deviceCount; + int deviceCount = 0; + CUDA_CHECK(cudaGetDeviceCount(&deviceCount)); + return deviceCount; } // Returns the compute capability of the given GPU. int GetGPUComputeCapability(int id) { - int major, minor; - CUDA_CHECK( - cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, id)); - CUDA_CHECK( - cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, id)); - return major * 10 + minor; + int major, minor; + CUDA_CHECK( + cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, id)); + CUDA_CHECK( + cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, id)); + return major * 10 + minor; } // Returns the number of multiprocessors for the given GPU. int GetGPUMultiProcessors(int id) { - int count; - CUDA_CHECK( - cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, id)); - return count; + int count; + CUDA_CHECK( + cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, id)); + return count; } // Returns the maximum number of threads per multiprocessor for the given GPU. int GetGPUMaxThreadsPerMultiProcessor(int id) { - int count; - CUDA_CHECK(cudaDeviceGetAttribute( - &count, cudaDevAttrMaxThreadsPerMultiProcessor, id)); - return count; + int count; + CUDA_CHECK(cudaDeviceGetAttribute( + &count, cudaDevAttrMaxThreadsPerMultiProcessor, id)); + return count; } // Returns the maximum number of threads per block for the given GPU. int GetGPUMaxThreadsPerBlock(int id) { - int count; - CUDA_CHECK( - cudaDeviceGetAttribute(&count, cudaDevAttrMaxThreadsPerBlock, id)); - return count; + int count; + CUDA_CHECK(cudaDeviceGetAttribute(&count, cudaDevAttrMaxThreadsPerBlock, id)); + return count; } // Returns the maximum grid size for the given GPU. dim3 GetGpuMaxGridDimSize(int id) { - dim3 grid_size; + dim3 grid_size; - int size; - CUDA_CHECK(cudaDeviceGetAttribute(&size, cudaDevAttrMaxGridDimX, id)); - grid_size.x = size; + int size; + CUDA_CHECK(cudaDeviceGetAttribute(&size, cudaDevAttrMaxGridDimX, id)); + grid_size.x = size; - CUDA_CHECK(cudaDeviceGetAttribute(&size, cudaDevAttrMaxGridDimY, id)); - grid_size.y = size; + CUDA_CHECK(cudaDeviceGetAttribute(&size, cudaDevAttrMaxGridDimY, id)); + grid_size.y = size; - CUDA_CHECK(cudaDeviceGetAttribute(&size, cudaDevAttrMaxGridDimZ, id)); - grid_size.z = size; - return grid_size; + CUDA_CHECK(cudaDeviceGetAttribute(&size, cudaDevAttrMaxGridDimZ, id)); + grid_size.z = size; + return grid_size; } // Returns the name of the device. std::string GetDeviceName() { - cudaDeviceProp prop; - cudaGetDeviceProperties(&prop, 0); + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, 0); - std::stringstream ss(prop.name); - const char delim = ' '; + std::stringstream ss(prop.name); + const char delim = ' '; - std::string s; - std::vector out; + std::string s; + std::vector out; - while (std::getline(ss, s, delim)) { - out.push_back(s); - } + while (std::getline(ss, s, delim)) { + out.push_back(s); + } - std::stringstream out_ss; - int i = 0; - for (; i < static_cast(out.size()) - 1; ++i) out_ss << out[i] << "_"; - out_ss << out[i]; - return out_ss.str(); + std::stringstream out_ss; + int i = 0; + for (; i < static_cast(out.size()) - 1; ++i) out_ss << out[i] << "_"; + out_ss << out[i]; + return out_ss.str(); } std::string GetComputeCapability() { - int device_id; - CUDA_CHECK(cudaGetDevice(&device_id)); + int device_id; + CUDA_CHECK(cudaGetDevice(&device_id)); - cudaDeviceProp prop; - CUDA_CHECK(cudaGetDeviceProperties(&prop, device_id)); + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, device_id)); - std::stringstream ss; - ss << "sm_" << prop.major << prop.minor; - return ss.str(); + std::stringstream ss; + ss << "sm_" << prop.major << prop.minor; + return ss.str(); } int GetMaxSharedMemoryPerBlock() { - int device_id; - CUDA_CHECK(cudaGetDevice(&device_id)); + int device_id; + CUDA_CHECK(cudaGetDevice(&device_id)); - cudaDeviceProp prop; - CUDA_CHECK(cudaGetDeviceProperties(&prop, device_id)); - return prop.sharedMemPerBlock; + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, device_id)); + return prop.sharedMemPerBlock; } } // namespace tilefusion diff --git a/src/cuda_utils.cc b/src/cuda_utils.cc index 1327a9c9..6eb4b4ea 100644 --- a/src/cuda_utils.cc +++ b/src/cuda_utils.cc @@ -5,28 +5,28 @@ namespace tilefusion { const char* cublasGetErrorString(cublasStatus_t status) { - switch (status) { - case CUBLAS_STATUS_SUCCESS: - return "CUBLAS_STATUS_SUCCESS"; - case CUBLAS_STATUS_NOT_INITIALIZED: - return "CUBLAS_STATUS_NOT_INITIALIZED"; - case CUBLAS_STATUS_ALLOC_FAILED: - return "CUBLAS_STATUS_ALLOC_FAILED"; - case CUBLAS_STATUS_INVALID_VALUE: - return "CUBLAS_STATUS_INVALID_VALUE"; - case CUBLAS_STATUS_ARCH_MISMATCH: - return "CUBLAS_STATUS_ARCH_MISMATCH"; - case CUBLAS_STATUS_MAPPING_ERROR: - return "CUBLAS_STATUS_MAPPING_ERROR"; - case CUBLAS_STATUS_EXECUTION_FAILED: - return "CUBLAS_STATUS_EXECUTION_FAILED"; - case CUBLAS_STATUS_INTERNAL_ERROR: - return "CUBLAS_STATUS_INTERNAL_ERROR"; - case CUBLAS_STATUS_NOT_SUPPORTED: - return "CUBLAS_STATUS_NOT_SUPPORTED"; - case CUBLAS_STATUS_LICENSE_ERROR: - return "CUBLAS_STATUS_LICENSE_ERROR"; - } - return "unknown error"; + switch (status) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED"; + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR"; + } + return "unknown error"; } } // namespace tilefusion diff --git a/src/jit/compiler.cc b/src/jit/compiler.cc index 925a337a..bd57d443 100644 --- a/src/jit/compiler.cc +++ b/src/jit/compiler.cc @@ -24,263 +24,257 @@ namespace tilefusion::jit { namespace { // Generate a random string to use as part of the temp file name std::string generate_random_string(size_t length) { - static const char alphanum[] = - "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution<> dist(0, sizeof(alphanum) - 2); - - std::string result; - result.reserve(length); - for (size_t i = 0; i < length; ++i) { - result += alphanum[dist(gen)]; - } - return result; + static const char alphanum[] = + "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dist(0, sizeof(alphanum) - 2); + + std::string result; + result.reserve(length); + for (size_t i = 0; i < length; ++i) { + result += alphanum[dist(gen)]; + } + return result; } std::string exec_cmd(const std::string& cmd) { - LOG(INFO) << "Executing command: " << cmd; - std::array buffer; - std::string result; - std::unique_ptr pipe(popen(cmd.c_str(), "r"), - pclose); - - if (!pipe) { - LOG(ERROR) << "popen() failed for command: " << cmd; - throw std::runtime_error("popen() failed!"); - } - - while (fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr) { - result += buffer.data(); - } - - int status = pclose(pipe.release()); - if (status != 0) { - LOG(WARNING) << "Command failed with status " << status << ": " << cmd; - LOG(WARNING) << "Output: " << result; - } - - return result; + LOG(INFO) << "Executing command: " << cmd; + std::array buffer; + std::string result; + std::unique_ptr pipe(popen(cmd.c_str(), "r"), + pclose); + + if (!pipe) { + LOG(ERROR) << "popen() failed for command: " << cmd; + throw std::runtime_error("popen() failed!"); + } + + while (fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr) { + result += buffer.data(); + } + + int status = pclose(pipe.release()); + if (status != 0) { + LOG(WARNING) << "Command failed with status " << status << ": " << cmd; + LOG(WARNING) << "Output: " << result; + } + + return result; } std::string get_hash_key(const std::string& kernel_name, const std::string& cuda_source, const std::vector& compile_args) { - std::stringstream ss; - ss << kernel_name << "___"; - ss << cuda_source << "___"; - for (const auto& arg : compile_args) { - ss << arg << " "; - } - - return ss.str(); + std::stringstream ss; + ss << kernel_name << "___"; + ss << cuda_source << "___"; + for (const auto& arg : compile_args) { + ss << arg << " "; + } + + return ss.str(); } std::string get_nvcc_path() { #ifdef NVCC_PATH - return NVCC_PATH; + return NVCC_PATH; #else - return "nvcc"; + return "nvcc"; #endif } } // namespace JitCompiler& JitCompiler::instance() { - static JitCompiler instance; - return instance; + static JitCompiler instance; + return instance; } JitCompiler::JitCompiler() { - // FIXME(ying): GLog should be initialized before this function is called - - CUresult result = cuInit(0); + // FIXME(ying): GLog should be initialized before this function is called + + CUresult result = cuInit(0); + if (result != CUDA_SUCCESS) { + LOG(FATAL) << "Failed to initialize CUDA driver"; + throw std::runtime_error("Failed to initialize CUDA driver"); + } + + result = cuCtxGetCurrent(&cuda_context_); + if (result != CUDA_SUCCESS) { + CUdevice device; + result = cuDeviceGet(&device, 0); if (result != CUDA_SUCCESS) { - LOG(FATAL) << "Failed to initialize CUDA driver"; - throw std::runtime_error("Failed to initialize CUDA driver"); + throw std::runtime_error("Failed to get CUDA device"); } - result = cuCtxGetCurrent(&cuda_context_); + result = cuCtxCreate(&cuda_context_, 0, device); if (result != CUDA_SUCCESS) { - CUdevice device; - result = cuDeviceGet(&device, 0); - if (result != CUDA_SUCCESS) { - throw std::runtime_error("Failed to get CUDA device"); - } - - result = cuCtxCreate(&cuda_context_, 0, device); - if (result != CUDA_SUCCESS) { - throw std::runtime_error("Failed to create CUDA context"); - } + throw std::runtime_error("Failed to create CUDA context"); } + } } JitCompiler::~JitCompiler() { - for (auto& [key, module] : module_cache_) { - cuModuleUnload(module); - } + for (auto& [key, module] : module_cache_) { + cuModuleUnload(module); + } } CUfunction JitCompiler::compile_kernel( const std::string& kernel_name, const std::string& cuda_source, const std::vector& include_paths, const std::vector& compile_args) { - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); - std::string key = get_hash_key(kernel_name, cuda_source, compile_args); + std::string key = get_hash_key(kernel_name, cuda_source, compile_args); - auto kernel_it = kernel_cache_.find(key); - if (kernel_it != kernel_cache_.end()) { - return kernel_it->second; - } + auto kernel_it = kernel_cache_.find(key); + if (kernel_it != kernel_cache_.end()) { + return kernel_it->second; + } - try { - std::string ptx = - compile_to_ptx(cuda_source, include_paths, compile_args); - CUfunction kernel = load_ptx_and_get_kernel(ptx, kernel_name); - kernel_cache_[key] = kernel; - return kernel; - } catch (const std::exception& e) { - std::cerr << "Error compiling kernel: " << e.what() << std::endl; - return nullptr; - } + try { + std::string ptx = compile_to_ptx(cuda_source, include_paths, compile_args); + CUfunction kernel = load_ptx_and_get_kernel(ptx, kernel_name); + kernel_cache_[key] = kernel; + return kernel; + } catch (const std::exception& e) { + std::cerr << "Error compiling kernel: " << e.what() << std::endl; + return nullptr; + } } CUfunction JitCompiler::get_or_compile_kernel( const std::string& kernel_name, const std::string& cuda_source, const std::vector& include_paths, const std::vector& compile_args) { - return compile_kernel(kernel_name, cuda_source, include_paths, - compile_args); + return compile_kernel(kernel_name, cuda_source, include_paths, compile_args); } void JitCompiler::clear_cache() { - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); - for (auto& [key, module] : module_cache_) { - cuModuleUnload(module); - } + for (auto& [key, module] : module_cache_) { + cuModuleUnload(module); + } - module_cache_.clear(); - kernel_cache_.clear(); + module_cache_.clear(); + kernel_cache_.clear(); } std::string JitCompiler::compile_to_ptx( const std::string& cuda_source, const std::vector& include_paths, const std::vector& compile_args) { - std::string cu_file = write_to_temp_file(cuda_source, ".cu"); + std::string cu_file = write_to_temp_file(cuda_source, ".cu"); - std::stringstream cmd; - cmd << get_nvcc_path() << " -ptx "; + std::stringstream cmd; + cmd << get_nvcc_path() << " -ptx "; - cmd << "-arch=" << GetComputeCapability() << " "; + cmd << "-arch=" << GetComputeCapability() << " "; - for (const auto& path : include_paths) { - cmd << "-I" << path << " "; - } + for (const auto& path : include_paths) { + cmd << "-I" << path << " "; + } - for (const auto& arg : compile_args) { - cmd << arg << " "; - } + for (const auto& arg : compile_args) { + cmd << arg << " "; + } - std::string ptx_file = cu_file.substr(0, cu_file.size() - 3) + ".ptx"; - cmd << cu_file << " -o " << ptx_file; + std::string ptx_file = cu_file.substr(0, cu_file.size() - 3) + ".ptx"; + cmd << cu_file << " -o " << ptx_file; - std::string output = exec_cmd(cmd.str()); + std::string output = exec_cmd(cmd.str()); - std::ifstream ptx_stream(ptx_file); - if (!ptx_stream.good()) { - LOG(ERROR) << "Failed to open PTX file: " << ptx_file; - LOG(ERROR) << "nvcc output: " << output; - throw std::runtime_error("Failed to compile CUDA source: " + output); - } + std::ifstream ptx_stream(ptx_file); + if (!ptx_stream.good()) { + LOG(ERROR) << "Failed to open PTX file: " << ptx_file; + LOG(ERROR) << "nvcc output: " << output; + throw std::runtime_error("Failed to compile CUDA source: " + output); + } - std::stringstream ptx_content; - ptx_content << ptx_stream.rdbuf(); - ptx_stream.close(); + std::stringstream ptx_content; + ptx_content << ptx_stream.rdbuf(); + ptx_stream.close(); - // For debugging, keep the files instead of removing them - LOG(INFO) << "Keeping temporary files for debugging:"; - LOG(INFO) << " Source: " << cu_file; - LOG(INFO) << " PTX: " << ptx_file; + // For debugging, keep the files instead of removing them + LOG(INFO) << "Keeping temporary files for debugging:"; + LOG(INFO) << " Source: " << cu_file; + LOG(INFO) << " PTX: " << ptx_file; - // Comment out the removal for debugging - // std::remove(cu_file.c_str()); - // std::remove(ptx_file.c_str()); + // Comment out the removal for debugging + // std::remove(cu_file.c_str()); + // std::remove(ptx_file.c_str()); - return ptx_content.str(); + return ptx_content.str(); } CUfunction JitCompiler::load_ptx_and_get_kernel( const std::string& ptx, const std::string& kernel_name) { - std::string module_key = kernel_name + "_" + generate_random_string(10); + std::string module_key = kernel_name + "_" + generate_random_string(10); - CUmodule module; - CUDA_DRIVER_CHECK(cuModuleLoadData(&module, ptx.c_str())); + CUmodule module; + CUDA_DRIVER_CHECK(cuModuleLoadData(&module, ptx.c_str())); - CUfunction kernel; - CUDA_DRIVER_CHECK( - cuModuleGetFunction(&kernel, module, kernel_name.c_str())); + CUfunction kernel; + CUDA_DRIVER_CHECK(cuModuleGetFunction(&kernel, module, kernel_name.c_str())); - module_cache_[module_key] = module; - LOG(INFO) << "Loaded PTX module: " << module_key; - return kernel; + module_cache_[module_key] = module; + LOG(INFO) << "Loaded PTX module: " << module_key; + return kernel; } std::string JitCompiler::write_to_temp_file(const std::string& content, const std::string& extension) { - const char* home_dir = getenv("HOME"); - std::string cache_dir; - - if (!home_dir || strlen(home_dir) == 0) { - LOG(WARNING) - << "HOME environment variable not set or empty, using /tmp instead"; - cache_dir = "/tmp/tilefusion"; - } else { - cache_dir = std::string(home_dir) + "/.cache/tilefusion"; - } - - std::string mkdir_cmd = "mkdir -p " + cache_dir; - - int ret = system(mkdir_cmd.c_str()); - if (ret != 0) { - LOG(ERROR) << "Failed to create cache directory (ret=" << ret - << "): " << cache_dir; - throw std::runtime_error("Failed to create cache directory: " + - cache_dir); - } - - std::string filename = - cache_dir + "/" + generate_random_string(10) + extension; - - std::ofstream out(filename); - if (!out.good()) { - LOG(ERROR) << "Failed to open file for writing: " << filename; - throw std::runtime_error("Failed to create temporary file: " + - filename); - } - - out << content; - if (!out.good()) { - LOG(ERROR) << "Failed to write content to file: " << filename; - throw std::runtime_error("Failed to write to temporary file: " + - filename); - } - - out.close(); - if (!out) { - LOG(ERROR) << "Failed to close file: " << filename; - throw std::runtime_error("Failed to close temporary file: " + filename); - } - - std::ifstream check(filename); - if (!check.good()) { - LOG(ERROR) << "File verification failed: " << filename; - throw std::runtime_error("File doesn't exist after write: " + filename); - } - check.close(); - - return filename; + const char* home_dir = getenv("HOME"); + std::string cache_dir; + + if (!home_dir || strlen(home_dir) == 0) { + LOG(WARNING) + << "HOME environment variable not set or empty, using /tmp instead"; + cache_dir = "/tmp/tilefusion"; + } else { + cache_dir = std::string(home_dir) + "/.cache/tilefusion"; + } + + std::string mkdir_cmd = "mkdir -p " + cache_dir; + + int ret = system(mkdir_cmd.c_str()); + if (ret != 0) { + LOG(ERROR) << "Failed to create cache directory (ret=" << ret + << "): " << cache_dir; + throw std::runtime_error("Failed to create cache directory: " + cache_dir); + } + + std::string filename = + cache_dir + "/" + generate_random_string(10) + extension; + + std::ofstream out(filename); + if (!out.good()) { + LOG(ERROR) << "Failed to open file for writing: " << filename; + throw std::runtime_error("Failed to create temporary file: " + filename); + } + + out << content; + if (!out.good()) { + LOG(ERROR) << "Failed to write content to file: " << filename; + throw std::runtime_error("Failed to write to temporary file: " + filename); + } + + out.close(); + if (!out) { + LOG(ERROR) << "Failed to close file: " << filename; + throw std::runtime_error("Failed to close temporary file: " + filename); + } + + std::ifstream check(filename); + if (!check.good()) { + LOG(ERROR) << "File verification failed: " << filename; + throw std::runtime_error("File doesn't exist after write: " + filename); + } + check.close(); + + return filename; } } // namespace tilefusion::jit diff --git a/src/kernels/flash_attn.cu b/src/kernels/flash_attn.cu index 5b13b47c..408ed018 100644 --- a/src/kernels/flash_attn.cu +++ b/src/kernels/flash_attn.cu @@ -21,37 +21,37 @@ std::string generate_kernel_wrapper( int64_t tile_hidden_qk, int64_t tile_hidden_v, // int64_t warp_rows, int64_t warp_cols, // double softmax_scale, bool causal) { - std::stringstream ss; + std::stringstream ss; - ss << R"( + ss << R"( #include "kernels/flash_attention_device.cuh" using namespace tilefusion::kernels; )"; - ss << "\n// Layout and shape definitions\n"; - ss << "using WarpLayout = tl::RowMajor<" << warp_rows << ", " << warp_cols - << ">;\n"; - ss << "using WholeShape = TileShape<" << length_q << ", " << length_kv - << ", " << hidden_qk << ", " << hidden_v << ">;\n"; - ss << "using CtaTileShape = TileShape<" << tile_length_q << ", " - << tile_length_kv << ", " << tile_hidden_qk << ", " << tile_hidden_v - << ">;\n\n"; - - ss << "// Flash attention configuration\n"; - ss << "using Config = FlashAttentionTraits<" << in_type << ", " << acc_type - << ", " << out_type << ", WholeShape, CtaTileShape, WarpLayout, " - << softmax_scale << ", " << causal << ">;\n\n"; - - ss << "// Kernel function\n"; - ss << "extern \"C\" __global__ void " << kernel_name << "(const " << in_type - << "* Q, const " << in_type << "* K, const " << in_type << "* V, " - << out_type << "* O) {\n"; - ss << " ke_flash_attention<" << in_type << ", " << acc_type << ", " - << out_type << ", Config>(Q, K, V, O);\n"; - ss << "}\n"; - - return ss.str(); + ss << "\n// Layout and shape definitions\n"; + ss << "using WarpLayout = tl::RowMajor<" << warp_rows << ", " << warp_cols + << ">;\n"; + ss << "using WholeShape = TileShape<" << length_q << ", " << length_kv << ", " + << hidden_qk << ", " << hidden_v << ">;\n"; + ss << "using CtaTileShape = TileShape<" << tile_length_q << ", " + << tile_length_kv << ", " << tile_hidden_qk << ", " << tile_hidden_v + << ">;\n\n"; + + ss << "// Flash attention configuration\n"; + ss << "using Config = FlashAttentionTraits<" << in_type << ", " << acc_type + << ", " << out_type << ", WholeShape, CtaTileShape, WarpLayout, " + << softmax_scale << ", " << causal << ">;\n\n"; + + ss << "// Kernel function\n"; + ss << "extern \"C\" __global__ void " << kernel_name << "(const " << in_type + << "* Q, const " << in_type << "* K, const " << in_type << "* V, " + << out_type << "* O) {\n"; + ss << " ke_flash_attention<" << in_type << ", " << acc_type << ", " + << out_type << ", Config>(Q, K, V, O);\n"; + ss << "}\n"; + + return ss.str(); } } // namespace @@ -60,95 +60,95 @@ void flash_attention(const torch::Tensor& Q, const torch::Tensor& K, int64_t tile_length_q, int64_t tile_length_kv, int64_t tile_hidden_qk, int64_t tile_hidden_v, double softmax_scale, bool causal) { - CHECK_INPUT(Q); - CHECK_INPUT(K); - CHECK_INPUT(V); - CHECK_INPUT(O); - - const at::ScalarType dtype = Q.scalar_type(); - TORCH_CHECK(dtype == at::ScalarType::Half && K.scalar_type() == dtype && - V.scalar_type() == dtype && O.scalar_type() == dtype, - "the inputs and output must be half-precision (fp16)."); - - const int64_t length_q = Q.size(0); - const int64_t length_kv = K.size(1); - const int64_t hidden_qk = Q.size(1); - const int64_t hidden_v = V.size(0); - - using InType = __half; - using AccType = float; - using OutType = __half; - using WarpLayout = tl::RowMajor<4, 1>; - - std::string in_type = jit::get_type_string(); - std::string acc_type = jit::get_type_string(); - std::string out_type = jit::get_type_string(); - - std::stringstream kernel_name_ss; - kernel_name_ss << "flash_attention_kernel_" << in_type << "_" << acc_type - << "_" << out_type << "_" << length_q << "_" << length_kv - << "_" << hidden_qk << "_" << hidden_v << "_" - << tile_length_q << "_" << tile_length_kv << "_" - << tile_hidden_qk << "_" << tile_hidden_v; - std::string kernel_name = kernel_name_ss.str(); - - std::string kernel_wrapper_src = generate_kernel_wrapper( - kernel_name, in_type, acc_type, out_type, length_q, length_kv, - hidden_qk, hidden_v, tile_length_q, tile_length_kv, tile_hidden_qk, - tile_hidden_v, tl::num_rows, tl::num_cols, - softmax_scale, causal); - - auto& jit = jit::JitCompiler::instance(); - - auto include_paths = jit::get_default_include_paths(); - auto compile_args = jit::get_default_compile_args(); - CUfunction kernel = jit.get_or_compile_kernel( - kernel_name, kernel_wrapper_src, include_paths, compile_args); - if (!kernel) { - throw std::runtime_error("Failed to compile or retrieve kernel"); - } - - const InType* dQ = reinterpret_cast(Q.data_ptr()); - const InType* dK = reinterpret_cast(K.data_ptr()); - const InType* dV = reinterpret_cast(V.data_ptr()); - OutType* dO = reinterpret_cast(O.data_ptr()); - - void* args[] = {(void*)&dQ, (void*)&dK, (void*)&dV, (void*)&dO}; - - int shm_input = - (tile_length_q * tile_hidden_qk + tile_hidden_qk * tile_length_kv + - tile_length_kv * tile_hidden_v); - int shm_output = tile_length_q * tile_hidden_v; - int shm_size = shm_input < shm_output ? shm_output * sizeof(OutType) - : shm_input * sizeof(InType); - - int block_x = ceil_div(length_q, tile_length_q); - int block_y = ceil_div(hidden_v, tile_hidden_v); - int batch_size = 1; // FIXME(ying): batch size is hardcoded to 1 for now - - static constexpr int kThreads = tl::get_numel * 32; - - dim3 grid(block_x, block_y, batch_size); - dim3 block(kThreads, 1, 1); - - if (shm_size > GetMaxSharedMemoryPerBlock()) { - // Set shared memory size if it exceeds the device limit - CUDA_DRIVER_CHECK(cuFuncSetAttribute( - kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shm_size)); - } - - CUDA_DRIVER_CHECK(cuLaunchKernel( - kernel, grid.x, grid.y, grid.z, // grid dimensions - block.x, block.y, block.z, // block dimensions - shm_size, // shared memory size - 0, // stream - args, // kernel arguments - nullptr // extra parameters - )); - - cudaDeviceSynchronize(); - - LOG(INFO) << "flash_attention kernel launched successfully"; + CHECK_INPUT(Q); + CHECK_INPUT(K); + CHECK_INPUT(V); + CHECK_INPUT(O); + + const at::ScalarType dtype = Q.scalar_type(); + TORCH_CHECK(dtype == at::ScalarType::Half && K.scalar_type() == dtype && + V.scalar_type() == dtype && O.scalar_type() == dtype, + "the inputs and output must be half-precision (fp16)."); + + const int64_t length_q = Q.size(0); + const int64_t length_kv = K.size(1); + const int64_t hidden_qk = Q.size(1); + const int64_t hidden_v = V.size(0); + + using InType = __half; + using AccType = float; + using OutType = __half; + using WarpLayout = tl::RowMajor<4, 1>; + + std::string in_type = jit::get_type_string(); + std::string acc_type = jit::get_type_string(); + std::string out_type = jit::get_type_string(); + + std::stringstream kernel_name_ss; + kernel_name_ss << "flash_attention_kernel_" << in_type << "_" << acc_type + << "_" << out_type << "_" << length_q << "_" << length_kv + << "_" << hidden_qk << "_" << hidden_v << "_" << tile_length_q + << "_" << tile_length_kv << "_" << tile_hidden_qk << "_" + << tile_hidden_v; + std::string kernel_name = kernel_name_ss.str(); + + std::string kernel_wrapper_src = generate_kernel_wrapper( + kernel_name, in_type, acc_type, out_type, length_q, length_kv, hidden_qk, + hidden_v, tile_length_q, tile_length_kv, tile_hidden_qk, tile_hidden_v, + tl::num_rows, tl::num_cols, softmax_scale, + causal); + + auto& jit = jit::JitCompiler::instance(); + + auto include_paths = jit::get_default_include_paths(); + auto compile_args = jit::get_default_compile_args(); + CUfunction kernel = jit.get_or_compile_kernel(kernel_name, kernel_wrapper_src, + include_paths, compile_args); + if (!kernel) { + throw std::runtime_error("Failed to compile or retrieve kernel"); + } + + const InType* dQ = reinterpret_cast(Q.data_ptr()); + const InType* dK = reinterpret_cast(K.data_ptr()); + const InType* dV = reinterpret_cast(V.data_ptr()); + OutType* dO = reinterpret_cast(O.data_ptr()); + + void* args[] = {(void*)&dQ, (void*)&dK, (void*)&dV, (void*)&dO}; + + int shm_input = + (tile_length_q * tile_hidden_qk + tile_hidden_qk * tile_length_kv + + tile_length_kv * tile_hidden_v); + int shm_output = tile_length_q * tile_hidden_v; + int shm_size = shm_input < shm_output ? shm_output * sizeof(OutType) + : shm_input * sizeof(InType); + + int block_x = ceil_div(length_q, tile_length_q); + int block_y = ceil_div(hidden_v, tile_hidden_v); + int batch_size = 1; // FIXME(ying): batch size is hardcoded to 1 for now + + static constexpr int kThreads = tl::get_numel * 32; + + dim3 grid(block_x, block_y, batch_size); + dim3 block(kThreads, 1, 1); + + if (shm_size > GetMaxSharedMemoryPerBlock()) { + // Set shared memory size if it exceeds the device limit + CUDA_DRIVER_CHECK(cuFuncSetAttribute( + kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shm_size)); + } + + CUDA_DRIVER_CHECK(cuLaunchKernel( + kernel, grid.x, grid.y, grid.z, // grid dimensions + block.x, block.y, block.z, // block dimensions + shm_size, // shared memory size + 0, // stream + args, // kernel arguments + nullptr // extra parameters + )); + + cudaDeviceSynchronize(); + + LOG(INFO) << "flash_attention kernel launched successfully"; } } // namespace tilefusion::kernels diff --git a/src/kernels/fused_two_gemms.cu b/src/kernels/fused_two_gemms.cu index ff38f57b..b8afbec4 100644 --- a/src/kernels/fused_two_gemms.cu +++ b/src/kernels/fused_two_gemms.cu @@ -38,111 +38,108 @@ std::string generate_kernel_wrapper(const std::string& kernel_name, int64_t n, int64_t k, int64_t p, int64_t tm = 64, int64_t tn = 64, int64_t tk = 64, int64_t tp = 64) { - std::stringstream ss; + std::stringstream ss; - ss << R"( + ss << R"( #include "kernels/fused_two_gemms_device.cuh" using namespace tilefusion::kernels; )"; - ss << "\n// Fused two gemms configuration\n"; - ss << "using Config = FusedTwoGemmsTraits<" << in_type << ", " << acc_type - << ", tl::RowMajor<2, 1>, " << m << ", " << n << ", " << k << ", " << p - << ", " << tm << ", " << tn << ", " << tk << ", " << tp << ">;\n\n"; + ss << "\n// Fused two gemms configuration\n"; + ss << "using Config = FusedTwoGemmsTraits<" << in_type << ", " << acc_type + << ", tl::RowMajor<2, 1>, " << m << ", " << n << ", " << k << ", " << p + << ", " << tm << ", " << tn << ", " << tk << ", " << tp << ">;\n\n"; - ss << "// Kernel function\n"; - ss << "extern \"C\" __global__ void " << kernel_name << "(const " << in_type - << "* A, const " << in_type << "* B, const " << in_type << "* C, " - << in_type << "* D) {\n"; - ss << " ke_fused_two_gemms<" << in_type << ", " << acc_type << ", " - << "Config>(A, B, C, D);\n"; - ss << "}\n"; + ss << "// Kernel function\n"; + ss << "extern \"C\" __global__ void " << kernel_name << "(const " << in_type + << "* A, const " << in_type << "* B, const " << in_type << "* C, " + << in_type << "* D) {\n"; + ss << " ke_fused_two_gemms<" << in_type << ", " << acc_type << ", " + << "Config>(A, B, C, D);\n"; + ss << "}\n"; - return ss.str(); + return ss.str(); } } // namespace void fused_two_gemms(const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& C, torch::Tensor& D, int64_t tm, int64_t tn, int64_t tk, int64_t tp) { - CHECK_INPUT(A); - CHECK_INPUT(B); - CHECK_INPUT(C); - CHECK_INPUT(D); - - const at::ScalarType dtype = A.scalar_type(); - TORCH_CHECK(dtype == at::ScalarType::Half && B.scalar_type() == dtype && - C.scalar_type() == dtype && D.scalar_type() == dtype, - "the inputs and output must be half-precision (fp16)."); - - const int64_t m = A.size(0); - const int64_t n = B.size(0); - const int64_t k = B.size(1); - const int64_t p = C.size(0); - - // TODO(ying): warp layout should be a configurable parameter - using WarpLayout = tl::RowMajor<2, 1>; - using InType = __half; - using AccType = float; - - // calculate shared memory usage - int shm_input = (tm * tk + tk * tn + tn * tp); - int shm_output = tm * tp; - const int shm_size = shm_input < shm_output ? shm_output * sizeof(InType) - : shm_input * sizeof(InType); - - std::string in_type = jit::get_type_string(); - std::string acc_type = jit::get_type_string(); - - std::stringstream kernel_name_ss; - kernel_name_ss << "fused_two_gemms_kernel_" << in_type << "_" << acc_type - << "_" << m << "_" << n << "_" << k << "_" << p; - std::string kernel_name = kernel_name_ss.str(); - - std::string kernel_wrapper = generate_kernel_wrapper( - kernel_name, in_type, acc_type, m, n, k, p, tm, tn, tk, tp); - - auto& jit = jit::JitCompiler::instance(); - - auto include_paths = jit::get_default_include_paths(); - auto compile_args = jit::get_default_compile_args(); - CUfunction kernel = jit.get_or_compile_kernel(kernel_name, kernel_wrapper, - include_paths, compile_args); - - if (!kernel) { - throw std::runtime_error("Failed to compile or retrieve kernel"); - } - - const InType* A_ptr = - reinterpret_cast(A.data_ptr()); - const InType* B_ptr = - reinterpret_cast(B.data_ptr()); - const InType* C_ptr = - reinterpret_cast(C.data_ptr()); - InType* D_ptr = reinterpret_cast(D.data_ptr()); - - void* args[] = {(void*)&A_ptr, (void*)&B_ptr, (void*)&C_ptr, (void*)&D_ptr, - (void*)&m, (void*)&n, (void*)&k, (void*)&p}; - - int block_x = ceil_div(m, tm); - int block_y = ceil_div(p, tp); - int block_z = 1; - static constexpr int kThreads = tl::get_numel * 32; - - if (shm_size > GetMaxSharedMemoryPerBlock()) { - // Set shared memory size if it exceeds the device limit - CUDA_DRIVER_CHECK(cuFuncSetAttribute( - kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shm_size)); - } - - CUDA_DRIVER_CHECK(cuLaunchKernel(kernel, block_x, block_y, block_z, // grid - kThreads, 1, 1, // block - shm_size, // shared memory bytes - nullptr, // stream - args, // kernel parameters - nullptr)); // extra parameters - - LOG(INFO) << "Fused two gemms kernel launched successfully"; + CHECK_INPUT(A); + CHECK_INPUT(B); + CHECK_INPUT(C); + CHECK_INPUT(D); + + const at::ScalarType dtype = A.scalar_type(); + TORCH_CHECK(dtype == at::ScalarType::Half && B.scalar_type() == dtype && + C.scalar_type() == dtype && D.scalar_type() == dtype, + "the inputs and output must be half-precision (fp16)."); + + const int64_t m = A.size(0); + const int64_t n = B.size(0); + const int64_t k = B.size(1); + const int64_t p = C.size(0); + + // TODO(ying): warp layout should be a configurable parameter + using WarpLayout = tl::RowMajor<2, 1>; + using InType = __half; + using AccType = float; + + // calculate shared memory usage + int shm_input = (tm * tk + tk * tn + tn * tp); + int shm_output = tm * tp; + const int shm_size = shm_input < shm_output ? shm_output * sizeof(InType) + : shm_input * sizeof(InType); + + std::string in_type = jit::get_type_string(); + std::string acc_type = jit::get_type_string(); + + std::stringstream kernel_name_ss; + kernel_name_ss << "fused_two_gemms_kernel_" << in_type << "_" << acc_type + << "_" << m << "_" << n << "_" << k << "_" << p; + std::string kernel_name = kernel_name_ss.str(); + + std::string kernel_wrapper = generate_kernel_wrapper( + kernel_name, in_type, acc_type, m, n, k, p, tm, tn, tk, tp); + + auto& jit = jit::JitCompiler::instance(); + + auto include_paths = jit::get_default_include_paths(); + auto compile_args = jit::get_default_compile_args(); + CUfunction kernel = jit.get_or_compile_kernel(kernel_name, kernel_wrapper, + include_paths, compile_args); + + if (!kernel) { + throw std::runtime_error("Failed to compile or retrieve kernel"); + } + + const InType* A_ptr = reinterpret_cast(A.data_ptr()); + const InType* B_ptr = reinterpret_cast(B.data_ptr()); + const InType* C_ptr = reinterpret_cast(C.data_ptr()); + InType* D_ptr = reinterpret_cast(D.data_ptr()); + + void* args[] = {(void*)&A_ptr, (void*)&B_ptr, (void*)&C_ptr, (void*)&D_ptr, + (void*)&m, (void*)&n, (void*)&k, (void*)&p}; + + int block_x = ceil_div(m, tm); + int block_y = ceil_div(p, tp); + int block_z = 1; + static constexpr int kThreads = tl::get_numel * 32; + + if (shm_size > GetMaxSharedMemoryPerBlock()) { + // Set shared memory size if it exceeds the device limit + CUDA_DRIVER_CHECK(cuFuncSetAttribute( + kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shm_size)); + } + + CUDA_DRIVER_CHECK(cuLaunchKernel(kernel, block_x, block_y, block_z, // grid + kThreads, 1, 1, // block + shm_size, // shared memory bytes + nullptr, // stream + args, // kernel parameters + nullptr)); // extra parameters + + LOG(INFO) << "Fused two gemms kernel launched successfully"; } } // namespace tilefusion::kernels diff --git a/src/kernels/gemm.cu b/src/kernels/gemm.cu index 8a063cf9..a592ba6b 100644 --- a/src/kernels/gemm.cu +++ b/src/kernels/gemm.cu @@ -24,38 +24,38 @@ std::string generate_gemm_kernel_wrapper(const std::string& in_type, int64_t tn, int64_t tk, int64_t num_stages, int64_t pipeline_level) { - int64_t kRK = 16; - int swizzle_bytes = 64; - std::stringstream ss; - ss << R"( + int64_t kRK = 16; + int swizzle_bytes = 64; + std::stringstream ss; + ss << R"( #include "kernels/gemm_device.cuh" using namespace tilefusion::kernels; using Config = KeGemmTraits<)" - << in_type << ", " << acc_type << R"(, + << in_type << ", " << acc_type << R"(, tl::RowMajor<1, 1>, )" - << m << ", " << n << ", " << k << ", " << tm << ", " << tn << ", " << tk - << ", " << kRK << ", " << num_stages << ", " << swizzle_bytes << R"(>; + << m << ", " << n << ", " << k << ", " << tm << ", " << tn << ", " << tk + << ", " << kRK << ", " << num_stages << ", " << swizzle_bytes << R"(>; extern "C" __global__ void gemm_kernel_)" - << in_type << "_" << acc_type << "_" << m << "_" << n << "_" << k << "_" - << tm << "_" << tn << "_" << tk << "_" << num_stages << "_" - << pipeline_level << R"((const )" << in_type << R"(* A, const )" - << in_type << R"(* B, )" << acc_type << R"(* C) {)"; - ss << std::endl; - if (pipeline_level == 0) { - ss << "ke_gemm(A, B, C);"; - } else if (pipeline_level == 1) { - ss << "ke_gemm_level1_pipeline(A, B, C);"; - } else if (pipeline_level == 2) { - ss << "ke_gemm_level2_pipeline(A, B, C);"; - } - - ss << std::endl << "}"; - - return ss.str(); + << in_type << "_" << acc_type << "_" << m << "_" << n << "_" << k << "_" + << tm << "_" << tn << "_" << tk << "_" << num_stages << "_" + << pipeline_level << R"((const )" << in_type << R"(* A, const )" << in_type + << R"(* B, )" << acc_type << R"(* C) {)"; + ss << std::endl; + if (pipeline_level == 0) { + ss << "ke_gemm(A, B, C);"; + } else if (pipeline_level == 1) { + ss << "ke_gemm_level1_pipeline(A, B, C);"; + } else if (pipeline_level == 2) { + ss << "ke_gemm_level2_pipeline(A, B, C);"; + } + + ss << std::endl << "}"; + + return ss.str(); } } // namespace @@ -63,83 +63,82 @@ extern "C" __global__ void gemm_kernel_)" void gemm(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& C, int64_t tm, int64_t tn, int64_t tk, int64_t num_stages, int64_t pipeline_level) { - CHECK_INPUT(A); - CHECK_INPUT(B); - - const at::ScalarType dtype = A.scalar_type(); - TORCH_CHECK(dtype == at::ScalarType::Half && B.scalar_type() == dtype, - "the inputs must be half-precision (fp16)."); - - const int64_t m = A.size(0); - const int64_t k = A.size(1); - const int64_t n = B.size(1); - - using WarpLayout = tl::RowMajor<1, 1>; - using InType = __half; - using AccType = float; - - int shm_input = (tm * tk + tk * tn) * num_stages; - int shm_output = tm * tn; - int shm_size = shm_input < shm_output ? shm_output * sizeof(InType) - : shm_input * sizeof(InType); - - std::string in_type = jit::get_type_string(); - std::string acc_type = jit::get_type_string(); - - std::string kernel_wrapper = generate_gemm_kernel_wrapper( - in_type, acc_type, m, n, k, tm, tn, tk, num_stages, pipeline_level); - - std::string kernel_name = - "gemm_kernel_" + in_type + "_" + acc_type + "_" + std::to_string(m) + - "_" + std::to_string(n) + "_" + std::to_string(k) + "_" + - std::to_string(tm) + "_" + std::to_string(tn) + "_" + - std::to_string(tk) + "_" + std::to_string(num_stages) + "_" + - std::to_string(pipeline_level); - - auto& jit = jit::JitCompiler::instance(); - auto include_paths = jit::get_default_include_paths(); - auto compile_args = jit::get_default_compile_args(); - CUfunction kernel = jit.get_or_compile_kernel(kernel_name, kernel_wrapper, - include_paths, compile_args); - - if (!kernel) { - throw std::runtime_error("Failed to compile or retrieve kernel"); - } - - const InType* a_ptr = reinterpret_cast(A.data_ptr()); - const InType* b_ptr = reinterpret_cast(B.data_ptr()); - AccType* c_ptr = reinterpret_cast(C.data_ptr()); - - void* args[] = {(void*)&a_ptr, - (void*)&b_ptr, - (void*)&c_ptr, - (void*)&m, - (void*)&n, - (void*)&k, - (void*)&tm, - (void*)&tn, - (void*)&tk, - (void*)&num_stages, - (void*)&pipeline_level}; - - int block_x = ceil_div(m, tm); - int block_y = ceil_div(n, tn); - int block_z = 1; - static constexpr int kThreads = tl::get_numel * 32; - - if (shm_size > GetMaxSharedMemoryPerBlock()) { - // Set shared memory size if it exceeds the device limit - CUDA_DRIVER_CHECK(cuFuncSetAttribute( - kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shm_size)); - } - - CUDA_DRIVER_CHECK(cuLaunchKernel(kernel, block_x, block_y, block_z, // grid - kThreads, 1, 1, // block - shm_size, // shared memory bytes - nullptr, // stream - args, // arguments - nullptr)); // extra parameters - LOG(INFO) << "gemm kernel launched successfully"; + CHECK_INPUT(A); + CHECK_INPUT(B); + + const at::ScalarType dtype = A.scalar_type(); + TORCH_CHECK(dtype == at::ScalarType::Half && B.scalar_type() == dtype, + "the inputs must be half-precision (fp16)."); + + const int64_t m = A.size(0); + const int64_t k = A.size(1); + const int64_t n = B.size(1); + + using WarpLayout = tl::RowMajor<1, 1>; + using InType = __half; + using AccType = float; + + int shm_input = (tm * tk + tk * tn) * num_stages; + int shm_output = tm * tn; + int shm_size = shm_input < shm_output ? shm_output * sizeof(InType) + : shm_input * sizeof(InType); + + std::string in_type = jit::get_type_string(); + std::string acc_type = jit::get_type_string(); + + std::string kernel_wrapper = generate_gemm_kernel_wrapper( + in_type, acc_type, m, n, k, tm, tn, tk, num_stages, pipeline_level); + + std::string kernel_name = + "gemm_kernel_" + in_type + "_" + acc_type + "_" + std::to_string(m) + + "_" + std::to_string(n) + "_" + std::to_string(k) + "_" + + std::to_string(tm) + "_" + std::to_string(tn) + "_" + std::to_string(tk) + + "_" + std::to_string(num_stages) + "_" + std::to_string(pipeline_level); + + auto& jit = jit::JitCompiler::instance(); + auto include_paths = jit::get_default_include_paths(); + auto compile_args = jit::get_default_compile_args(); + CUfunction kernel = jit.get_or_compile_kernel(kernel_name, kernel_wrapper, + include_paths, compile_args); + + if (!kernel) { + throw std::runtime_error("Failed to compile or retrieve kernel"); + } + + const InType* a_ptr = reinterpret_cast(A.data_ptr()); + const InType* b_ptr = reinterpret_cast(B.data_ptr()); + AccType* c_ptr = reinterpret_cast(C.data_ptr()); + + void* args[] = {(void*)&a_ptr, + (void*)&b_ptr, + (void*)&c_ptr, + (void*)&m, + (void*)&n, + (void*)&k, + (void*)&tm, + (void*)&tn, + (void*)&tk, + (void*)&num_stages, + (void*)&pipeline_level}; + + int block_x = ceil_div(m, tm); + int block_y = ceil_div(n, tn); + int block_z = 1; + static constexpr int kThreads = tl::get_numel * 32; + + if (shm_size > GetMaxSharedMemoryPerBlock()) { + // Set shared memory size if it exceeds the device limit + CUDA_DRIVER_CHECK(cuFuncSetAttribute( + kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shm_size)); + } + + CUDA_DRIVER_CHECK(cuLaunchKernel(kernel, block_x, block_y, block_z, // grid + kThreads, 1, 1, // block + shm_size, // shared memory bytes + nullptr, // stream + args, // arguments + nullptr)); // extra parameters + LOG(INFO) << "gemm kernel launched successfully"; } } // namespace tilefusion::kernels diff --git a/src/kernels/scatter_nd.cu b/src/kernels/scatter_nd.cu index ea9270d3..9b2c979e 100644 --- a/src/kernels/scatter_nd.cu +++ b/src/kernels/scatter_nd.cu @@ -11,88 +11,87 @@ template __global__ void ke_scatter_nd(const T* in, T* out, const int64_t* indices, unsigned int const* __restrict__ strides, size_t n, size_t rank, size_t slice_size) { - for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x, - step = blockDim.x * gridDim.x; - tid < n; tid += step) { - if (tid < n) { - // tid = indices_index - unsigned int out_index = 0; - // the rank of `data`. - auto i = indices + tid * rank; - // Compute the offset in the output. - // j = i[0] * strides[0] + i[1] * strides[1] + ... + i[k] * - // strides[k] - - for (auto k = 0; k < rank; ++k) { - out_index += i[k] * __ldg(strides + k); - }; - for (size_t offset = 0; offset < slice_size; ++offset) { - atomicAdd(out + out_index + offset, - in[tid * slice_size + offset]); - } - } + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x, + step = blockDim.x * gridDim.x; + tid < n; tid += step) { + if (tid < n) { + // tid = indices_index + unsigned int out_index = 0; + // the rank of `data`. + auto i = indices + tid * rank; + // Compute the offset in the output. + // j = i[0] * strides[0] + i[1] * strides[1] + ... + i[k] * + // strides[k] + + for (auto k = 0; k < rank; ++k) { + out_index += i[k] * __ldg(strides + k); + }; + for (size_t offset = 0; offset < slice_size; ++offset) { + atomicAdd(out + out_index + offset, in[tid * slice_size + offset]); + } } + } } void scatter_nd(const torch::Tensor& data, torch::Tensor& updates, const torch::Tensor& indices) { - auto data_dims = data.sizes(); - auto update_dims = updates.sizes(); - auto indices_dims = indices.sizes(); + auto data_dims = data.sizes(); + auto update_dims = updates.sizes(); + auto indices_dims = indices.sizes(); - // k is the last dimension of indices. - int64_t k = indices_dims[indices_dims.size() - 1]; + // k is the last dimension of indices. + int64_t k = indices_dims[indices_dims.size() - 1]; - // the rank of data. - size_t rank = data_dims.size(); + // the rank of data. + size_t rank = data_dims.size(); - unsigned int* strides = new unsigned int[rank]; - strides[rank - 1] = 1; + unsigned int* strides = new unsigned int[rank]; + strides[rank - 1] = 1; - for (int64_t i = rank - 2; i >= 0; --i) { - strides[i] = strides[i + 1] * data_dims[i + 1]; - } + for (int64_t i = rank - 2; i >= 0; --i) { + strides[i] = strides[i + 1] * data_dims[i + 1]; + } - unsigned int* device_strides; - CUDA_CHECK(cudaMalloc(&device_strides, rank * sizeof(unsigned int))); - CUDA_CHECK(cudaMemcpy(device_strides, strides, rank * sizeof(unsigned int), - cudaMemcpyHostToDevice)); + unsigned int* device_strides; + CUDA_CHECK(cudaMalloc(&device_strides, rank * sizeof(unsigned int))); + CUDA_CHECK(cudaMemcpy(device_strides, strides, rank * sizeof(unsigned int), + cudaMemcpyHostToDevice)); - // `n` is the product of all dimensions excluding the innermost - // dimension of `indices`. - size_t n = indices.numel() / k; + // `n` is the product of all dimensions excluding the innermost + // dimension of `indices`. + size_t n = indices.numel() / k; - size_t slice_size = 1; - for (size_t i = k; i < rank; ++i) { - slice_size *= data_dims[i]; - } + size_t slice_size = 1; + for (size_t i = k; i < rank; ++i) { + slice_size *= data_dims[i]; + } - size_t data_size = data.numel(); + size_t data_size = data.numel(); #ifdef DEBUG - for (int i = rank - 1; i >= 0; --i) { - std::cout << "strides[" << i << "]: " << strides[i] << std::endl; - } - for (int i = rank - 1; i >= 0; --i) { - std::cout << "data_dims[" << i << "]: " << data_dims[i] << std::endl; - } - std::cout << "k: " << k << ", rank: " << rank << std::endl; - std::cout << "n: " << n << ", slice_size: " << slice_size << std::endl; - std::cout << "data_size: " << data_size << std::endl; + for (int i = rank - 1; i >= 0; --i) { + std::cout << "strides[" << i << "]: " << strides[i] << std::endl; + } + for (int i = rank - 1; i >= 0; --i) { + std::cout << "data_dims[" << i << "]: " << data_dims[i] << std::endl; + } + std::cout << "k: " << k << ", rank: " << rank << std::endl; + std::cout << "n: " << n << ", slice_size: " << slice_size << std::endl; + std::cout << "data_size: " << data_size << std::endl; #endif - // TODO: Add some assertion checks. - int64_t block = 256; - int64_t grid = (n + block - 1) / block; - - TILEFUSION_DISPATCH_ALL_TYPES(data.scalar_type(), [&] { - ke_scatter_nd<<>>( - reinterpret_cast(updates.const_data_ptr()), - reinterpret_cast(data.mutable_data_ptr()), - reinterpret_cast(indices.const_data_ptr()), - reinterpret_cast(device_strides), n, k, - slice_size); - }); + // TODO: Add some assertion checks. + int64_t block = 256; + int64_t grid = (n + block - 1) / block; + + TILEFUSION_DISPATCH_ALL_TYPES(data.scalar_type(), [&] { + ke_scatter_nd<<>>( + reinterpret_cast(updates.const_data_ptr()), + reinterpret_cast(data.mutable_data_ptr()), + reinterpret_cast(indices.const_data_ptr()), + reinterpret_cast(device_strides), n, k, + slice_size); + }); } } // namespace tilefusion::kernels diff --git a/src/torch_bind.cc b/src/torch_bind.cc index 08213ba2..ed84f894 100644 --- a/src/torch_bind.cc +++ b/src/torch_bind.cc @@ -6,11 +6,11 @@ namespace tilefusion { TORCH_LIBRARY(tilefusion, m) { - KernelRegistry::instance().register_with_torch(m); + KernelRegistry::instance().register_with_torch(m); } TORCH_LIBRARY_IMPL(tilefusion, CUDA, m) { - KernelRegistry::instance().register_implementations(m); + KernelRegistry::instance().register_implementations(m); } } // namespace tilefusion diff --git a/tests/cpp/cell/test_broadcast.cu b/tests/cpp/cell/test_broadcast.cu index ab66d9f1..197bd5ca 100644 --- a/tests/cpp/cell/test_broadcast.cu +++ b/tests/cpp/cell/test_broadcast.cu @@ -20,94 +20,94 @@ template __global__ void reg_broadcast(Element* src) { - using SrcLoadTile = GlobalTile; - using DstLoadTile = RegTile; - using SrcReduceTile = DstLoadTile; - using DstReduceTile = RegTile>; - using SrcBroadcastTile = DstReduceTile; - using DstBroadcastTile = SrcReduceTile; + using SrcLoadTile = GlobalTile; + using DstLoadTile = RegTile; + using SrcReduceTile = DstLoadTile; + using DstReduceTile = RegTile>; + using SrcBroadcastTile = DstReduceTile; + using DstBroadcastTile = SrcReduceTile; - SrcLoadTile src_load_tile(src); - DstLoadTile dst_load_tile; - DstReduceTile dst_reduce_tile; - DstBroadcastTile dst_broadcast_tile; + SrcLoadTile src_load_tile(src); + DstLoadTile dst_load_tile; + DstReduceTile dst_reduce_tile; + DstBroadcastTile dst_broadcast_tile; - // Load data from global memory to register file - copy::GlobalToRegLoader loader; - loader(src_load_tile, dst_load_tile); - __syncthreads(); + // Load data from global memory to register file + copy::GlobalToRegLoader loader; + loader(src_load_tile, dst_load_tile); + __syncthreads(); - // Execute reduce operation. - compute::MaxReduce row_max; - row_max(dst_load_tile, dst_reduce_tile); + // Execute reduce operation. + compute::MaxReduce row_max; + row_max(dst_load_tile, dst_reduce_tile); - __syncthreads(); + __syncthreads(); - compute::Broadcast - broadcast_reduce; + compute::Broadcast + broadcast_reduce; - broadcast_reduce(dst_reduce_tile, dst_broadcast_tile); + broadcast_reduce(dst_reduce_tile, dst_broadcast_tile); - __syncthreads(); + __syncthreads(); - if (thread(0)) { - printf("Row Max:\n"); - printf("Thread 0:\n"); - dst_broadcast_tile.dump_value(); - } + if (thread(0)) { + printf("Row Max:\n"); + printf("Thread 0:\n"); + dst_broadcast_tile.dump_value(); + } } template void run_row_major_reg_broadcast() { - int kNumel = 16 * 16 * kHeight * kWidth; - int kWarpSize = tl::get_numel; + int kNumel = 16 * 16 * kHeight * kWidth; + int kWarpSize = tl::get_numel; - using ReduceLayout = tl::RowMajor; + using ReduceLayout = tl::RowMajor; - thrust::host_vector h_src(kNumel); - for (int i = 0; i < kNumel; ++i) { - h_src[i] = (Element)i; - } + thrust::host_vector h_src(kNumel); + for (int i = 0; i < kNumel; ++i) { + h_src[i] = (Element)i; + } - thrust::device_vector d_src = h_src; + thrust::device_vector d_src = h_src; - reg_broadcast - <<<1, 32 * kWarpSize>>>(thrust::raw_pointer_cast(d_src.data())); + reg_broadcast + <<<1, 32 * kWarpSize>>>(thrust::raw_pointer_cast(d_src.data())); } TEST(TestRegBroadcast, row_major_reg_broadcast_0) { - const int kHeight = 1; - const int kWidth = 1; - using Element = float; - using WarpLayout = tl::RowMajor<1, 1>; - using RegLayout = tl::RowMajor; + const int kHeight = 1; + const int kWidth = 1; + using Element = float; + using WarpLayout = tl::RowMajor<1, 1>; + using RegLayout = tl::RowMajor; - const copy::WarpReuse kMode = copy::WarpReuse::kCont; + const copy::WarpReuse kMode = copy::WarpReuse::kCont; - using GlobalLayout = tl::RowMajor<16 * kHeight, 16 * kWidth>; + using GlobalLayout = tl::RowMajor<16 * kHeight, 16 * kWidth>; - run_row_major_reg_broadcast< - Element, RegLayout, GlobalLayout, BaseTileRowMajor, WarpLayout, - tl::Layout::kRowMajor, kMode, kHeight, kWidth>(); + run_row_major_reg_broadcast, WarpLayout, + tl::Layout::kRowMajor, kMode, kHeight, kWidth>(); } TEST(TestRegBroadcast, row_major_reg_broadcast_1) { - const int kHeight = 2; - const int kWidth = 2; - using Element = float; - using WarpLayout = tl::RowMajor<1, 1>; - using RegLayout = tl::RowMajor; + const int kHeight = 2; + const int kWidth = 2; + using Element = float; + using WarpLayout = tl::RowMajor<1, 1>; + using RegLayout = tl::RowMajor; - const copy::WarpReuse kMode = copy::WarpReuse::kCont; + const copy::WarpReuse kMode = copy::WarpReuse::kCont; - using GlobalLayout = tl::RowMajor<16 * kHeight, 16 * kWidth>; + using GlobalLayout = tl::RowMajor<16 * kHeight, 16 * kWidth>; - run_row_major_reg_broadcast< - Element, RegLayout, GlobalLayout, BaseTileRowMajor, WarpLayout, - tl::Layout::kRowMajor, kMode, kHeight, kWidth>(); + run_row_major_reg_broadcast, WarpLayout, + tl::Layout::kRowMajor, kMode, kHeight, kWidth>(); } } // namespace tilefusion::testing diff --git a/tests/cpp/cell/test_flash_attn.cu b/tests/cpp/cell/test_flash_attn.cu index 16a64f51..50145e91 100644 --- a/tests/cpp/cell/test_flash_attn.cu +++ b/tests/cpp/cell/test_flash_attn.cu @@ -24,140 +24,139 @@ template __global__ void flash_attn_reg_reduce(Element* src) { - using SrcLoadTile = GlobalTile; - using DstLoadTile = RegTile; - using SrcReduceTile = DstLoadTile; - using DstReduceTile = RegTile>; - using SrcBroadcastTile = DstReduceTile; - using DstBroadcastTile = SrcReduceTile; - - SrcLoadTile src_load_tile(src); - DstLoadTile attn_block; - DstReduceTile last_max_vec; - DstReduceTile max_vec; - DstReduceTile last_norm_vec; - DstReduceTile norm_vec; - DstBroadcastTile max_broadcast_tile; - DstBroadcastTile norm_broadcast_tile; - - // Load data from global memory to register file - copy::GlobalToRegLoader loader; - loader(src_load_tile, attn_block); - - // Copy `max_vec` into `last_max_vec` - copy::BaseTileCopy copy_max_reg; - copy_max_reg(max_vec, last_max_vec); - // Copy `norm_vec` into `last_norm_vec` - copy::BaseTileCopy copy_norm_reg; - copy_norm_reg(norm_vec, last_norm_vec); - - // Execute reduce operation. - compute::MaxReduce row_max; - // accumulate onto the max_vec - row_max(attn_block, max_vec); - - compute::Broadcast - broadcast_max; - - broadcast_max(max_vec, max_broadcast_tile); - - if (thread(0)) { - printf("Thread 0:\n"); - max_vec.dump_value(); - max_broadcast_tile.dump_value(); - attn_block.dump_value(); - } - - // subtract max from attention -- now all <= 0. - compute::RegTileSub sub_row_max; - sub_row_max(attn_block, max_broadcast_tile, attn_block); - - if (thread(0)) { - printf("Thread 0:\n"); - attn_block.dump_value(); - } - - // exponentiate the block in-place. - compute::RegTileExp exp_attn; - exp_attn(attn_block, attn_block); - - if (thread(0)) { - printf("Thread 0:\n"); - attn_block.dump_value(); - } - - // subtract new max from old max to find the new normalization. - compute::BaseTileSub sub_new_max; - sub_new_max(last_max_vec, max_vec, last_max_vec); - - // exponentiate this vector -- this is what we need to normalize by. - compute::BaseTileExp exp_max; - exp_max(last_max_vec, last_max_vec); - - // and the norm vec is now normalized. - compute::BaseTileMul mul_norm; - mul_norm(last_max_vec, norm_vec, norm_vec); - - // Accumulate the new attention block onto the now-rescaled norm-vec. - // Reduce Sum + Add - DstReduceTile sum_vec; - compute::SumReduce row_sum; - row_sum(attn_block, sum_vec); - compute::BaseTileAdd add_sum; - add_sum(sum_vec, norm_vec, norm_vec); - - // Now the attention block is correctly normalized. - // Broadcast + Divide - compute::Broadcast - broadcast_norm; - broadcast_norm(norm_vec, norm_broadcast_tile); - compute::RegTileDiv div_norm; - div_norm(attn_block, norm_broadcast_tile, attn_block); - - // Normalize the previous norm vec accorfing to the new max. - compute::BaseTileMul mul_norm_new; - mul_norm_new(last_max_vec, last_norm_vec, last_norm_vec); - - // Normalize the previous norm vec according to the new norm. - compute::BaseTileDiv div_norm_new; - div_norm_new(last_norm_vec, norm_vec, last_norm_vec); + using SrcLoadTile = GlobalTile; + using DstLoadTile = RegTile; + using SrcReduceTile = DstLoadTile; + using DstReduceTile = RegTile>; + using SrcBroadcastTile = DstReduceTile; + using DstBroadcastTile = SrcReduceTile; + + SrcLoadTile src_load_tile(src); + DstLoadTile attn_block; + DstReduceTile last_max_vec; + DstReduceTile max_vec; + DstReduceTile last_norm_vec; + DstReduceTile norm_vec; + DstBroadcastTile max_broadcast_tile; + DstBroadcastTile norm_broadcast_tile; + + // Load data from global memory to register file + copy::GlobalToRegLoader loader; + loader(src_load_tile, attn_block); + + // Copy `max_vec` into `last_max_vec` + copy::BaseTileCopy copy_max_reg; + copy_max_reg(max_vec, last_max_vec); + // Copy `norm_vec` into `last_norm_vec` + copy::BaseTileCopy copy_norm_reg; + copy_norm_reg(norm_vec, last_norm_vec); + + // Execute reduce operation. + compute::MaxReduce row_max; + // accumulate onto the max_vec + row_max(attn_block, max_vec); + + compute::Broadcast broadcast_max; + + broadcast_max(max_vec, max_broadcast_tile); + + if (thread(0)) { + printf("Thread 0:\n"); + max_vec.dump_value(); + max_broadcast_tile.dump_value(); + attn_block.dump_value(); + } + + // subtract max from attention -- now all <= 0. + compute::RegTileSub sub_row_max; + sub_row_max(attn_block, max_broadcast_tile, attn_block); + + if (thread(0)) { + printf("Thread 0:\n"); + attn_block.dump_value(); + } + + // exponentiate the block in-place. + compute::RegTileExp exp_attn; + exp_attn(attn_block, attn_block); + + if (thread(0)) { + printf("Thread 0:\n"); + attn_block.dump_value(); + } + + // subtract new max from old max to find the new normalization. + compute::BaseTileSub sub_new_max; + sub_new_max(last_max_vec, max_vec, last_max_vec); + + // exponentiate this vector -- this is what we need to normalize by. + compute::BaseTileExp exp_max; + exp_max(last_max_vec, last_max_vec); + + // and the norm vec is now normalized. + compute::BaseTileMul mul_norm; + mul_norm(last_max_vec, norm_vec, norm_vec); + + // Accumulate the new attention block onto the now-rescaled norm-vec. + // Reduce Sum + Add + DstReduceTile sum_vec; + compute::SumReduce row_sum; + row_sum(attn_block, sum_vec); + compute::BaseTileAdd add_sum; + add_sum(sum_vec, norm_vec, norm_vec); + + // Now the attention block is correctly normalized. + // Broadcast + Divide + compute::Broadcast + broadcast_norm; + broadcast_norm(norm_vec, norm_broadcast_tile); + compute::RegTileDiv div_norm; + div_norm(attn_block, norm_broadcast_tile, attn_block); + + // Normalize the previous norm vec accorfing to the new max. + compute::BaseTileMul mul_norm_new; + mul_norm_new(last_max_vec, last_norm_vec, last_norm_vec); + + // Normalize the previous norm vec according to the new norm. + compute::BaseTileDiv div_norm_new; + div_norm_new(last_norm_vec, norm_vec, last_norm_vec); } template void run_row_major_reg_flash_attn() { - int kNumel = 16 * 16 * kHeight * kWidth; - int kWarpSize = tl::get_numel; + int kNumel = 16 * 16 * kHeight * kWidth; + int kWarpSize = tl::get_numel; - using ReduceLayout = tl::RowMajor; + using ReduceLayout = tl::RowMajor; - thrust::host_vector h_src(kNumel); - for (int i = 0; i < kNumel; ++i) { - h_src[i] = (Element)i; - } + thrust::host_vector h_src(kNumel); + for (int i = 0; i < kNumel; ++i) { + h_src[i] = (Element)i; + } - thrust::device_vector d_src = h_src; + thrust::device_vector d_src = h_src; - flash_attn_reg_reduce - <<<1, 32 * kWarpSize>>>(thrust::raw_pointer_cast(d_src.data())); + flash_attn_reg_reduce + <<<1, 32 * kWarpSize>>>(thrust::raw_pointer_cast(d_src.data())); } TEST(TestRegBroadcast, row_major_reg_flash_attn_0) { - const int kHeight = 1; - const int kWidth = 1; - using Element = float; - using WarpLayout = tl::RowMajor<1, 1>; - using RegLayout = tl::RowMajor; + const int kHeight = 1; + const int kWidth = 1; + using Element = float; + using WarpLayout = tl::RowMajor<1, 1>; + using RegLayout = tl::RowMajor; - const copy::WarpReuse kMode = copy::WarpReuse::kCont; + const copy::WarpReuse kMode = copy::WarpReuse::kCont; - using GlobalLayout = tl::RowMajor<16 * kHeight, 16 * kWidth>; + using GlobalLayout = tl::RowMajor<16 * kHeight, 16 * kWidth>; - run_row_major_reg_flash_attn< - Element, RegLayout, GlobalLayout, BaseTileRowMajor, WarpLayout, - tl::Layout::kRowMajor, kMode, kHeight, kWidth>(); + run_row_major_reg_flash_attn, WarpLayout, + tl::Layout::kRowMajor, kMode, kHeight, kWidth>(); } } // namespace tilefusion::testing diff --git a/tests/cpp/cell/test_g2r_copy.cu b/tests/cpp/cell/test_g2r_copy.cu index ad5f554f..00611965 100644 --- a/tests/cpp/cell/test_g2r_copy.cu +++ b/tests/cpp/cell/test_g2r_copy.cu @@ -18,120 +18,106 @@ template __global__ void load_g2r(Element* src) { - using SrcTile = GlobalTile; - using DstTile = RegTile; - SrcTile src_tile(src); - DstTile dst_tile; + using SrcTile = GlobalTile; + using DstTile = RegTile; + SrcTile src_tile(src); + DstTile dst_tile; - copy::GlobalToRegLoader loader; - loader(src_tile, dst_tile); - __syncthreads(); + copy::GlobalToRegLoader loader; + loader(src_tile, dst_tile); + __syncthreads(); - if (thread0()) { - printf("thread 0:\n"); - dst_tile.dump_value(); - } + if (thread0()) { + printf("thread 0:\n"); + dst_tile.dump_value(); + } - if (thread(1)) { - printf("thread 1:\n"); - dst_tile.dump_value(); - } + if (thread(1)) { + printf("thread 1:\n"); + dst_tile.dump_value(); + } - __syncthreads(); + __syncthreads(); - if (thread(32)) { - printf("thread 32:\n"); - dst_tile.dump_value(); - } + if (thread(32)) { + printf("thread 32:\n"); + dst_tile.dump_value(); + } - __syncthreads(); + __syncthreads(); - if (thread(64)) { - printf("thread 64:\n"); - dst_tile.dump_value(); - } + if (thread(64)) { + printf("thread 64:\n"); + dst_tile.dump_value(); + } - __syncthreads(); + __syncthreads(); - if (thread(96)) { - printf("thread 96:\n"); - dst_tile.dump_value(); - } + if (thread(96)) { + printf("thread 96:\n"); + dst_tile.dump_value(); + } } template __global__ void store_r2g(Element* dst) { - using SrcLayout = RegTile; - using DstLayout = GlobalTile; - - SrcLayout src_tile; - DstLayout dst_tile(dst); - - int lane_id = threadIdx.x % 32; - - switch (kType) { - case tl::Layout::kRowMajor: - // row major - for (int i = 0; i < kHeight; ++i) { - int row = i * 16 + lane_id / 4; - for (int j = 0; j < kWidth; ++j) { - int col = j * 16 + (lane_id % 4) * 2; - src_tile(i, j)(0, 0) = row * DstLayout::kRowStride + col; - src_tile(i, j)(0, 1) = - row * DstLayout::kRowStride + col + 1; - src_tile(i, j)(1, 0) = - row * DstLayout::kRowStride + col + 8; - src_tile(i, j)(1, 1) = - row * DstLayout::kRowStride + col + 9; - src_tile(i, j)(0, 2) = - (row + 8) * DstLayout::kRowStride + col; - src_tile(i, j)(0, 3) = - (row + 8) * DstLayout::kRowStride + col + 1; - src_tile(i, j)(1, 2) = - (row + 8) * DstLayout::kRowStride + col + 8; - src_tile(i, j)(1, 3) = - (row + 8) * DstLayout::kRowStride + col + 9; - } - } - break; - case tl::Layout::kColMajor: - // col major - for (int i = 0; i < kWidth; ++i) { - int col = i * 16 + lane_id / 4; - for (int j = 0; j < kHeight; ++j) { - int row = j * 16 + (lane_id % 4) * 2; - src_tile(j, i)(0, 0) = col * DstLayout::kColStride + row; - src_tile(j, i)(1, 0) = - col * DstLayout::kColStride + row + 1; - src_tile(j, i)(0, 1) = - col * DstLayout::kColStride + row + 8; - src_tile(j, i)(1, 1) = - col * DstLayout::kColStride + row + 9; - src_tile(j, i)(2, 0) = - (col + 8) * DstLayout::kColStride + row; - src_tile(j, i)(3, 0) = - (col + 8) * DstLayout::kColStride + row + 1; - src_tile(j, i)(2, 1) = - (col + 8) * DstLayout::kColStride + row + 8; - src_tile(j, i)(3, 1) = - (col + 8) * DstLayout::kColStride + row + 9; - } - } - break; - default: - break; - } - - copy::RegToGlobalStorer storer; - storer(src_tile, dst_tile); - - __syncthreads(); - - if (thread(0)) { - dst_tile.dump_value(); - } + using SrcLayout = RegTile; + using DstLayout = GlobalTile; + + SrcLayout src_tile; + DstLayout dst_tile(dst); + + int lane_id = threadIdx.x % 32; + + switch (kType) { + case tl::Layout::kRowMajor: + // row major + for (int i = 0; i < kHeight; ++i) { + int row = i * 16 + lane_id / 4; + for (int j = 0; j < kWidth; ++j) { + int col = j * 16 + (lane_id % 4) * 2; + src_tile(i, j)(0, 0) = row * DstLayout::kRowStride + col; + src_tile(i, j)(0, 1) = row * DstLayout::kRowStride + col + 1; + src_tile(i, j)(1, 0) = row * DstLayout::kRowStride + col + 8; + src_tile(i, j)(1, 1) = row * DstLayout::kRowStride + col + 9; + src_tile(i, j)(0, 2) = (row + 8) * DstLayout::kRowStride + col; + src_tile(i, j)(0, 3) = (row + 8) * DstLayout::kRowStride + col + 1; + src_tile(i, j)(1, 2) = (row + 8) * DstLayout::kRowStride + col + 8; + src_tile(i, j)(1, 3) = (row + 8) * DstLayout::kRowStride + col + 9; + } + } + break; + case tl::Layout::kColMajor: + // col major + for (int i = 0; i < kWidth; ++i) { + int col = i * 16 + lane_id / 4; + for (int j = 0; j < kHeight; ++j) { + int row = j * 16 + (lane_id % 4) * 2; + src_tile(j, i)(0, 0) = col * DstLayout::kColStride + row; + src_tile(j, i)(1, 0) = col * DstLayout::kColStride + row + 1; + src_tile(j, i)(0, 1) = col * DstLayout::kColStride + row + 8; + src_tile(j, i)(1, 1) = col * DstLayout::kColStride + row + 9; + src_tile(j, i)(2, 0) = (col + 8) * DstLayout::kColStride + row; + src_tile(j, i)(3, 0) = (col + 8) * DstLayout::kColStride + row + 1; + src_tile(j, i)(2, 1) = (col + 8) * DstLayout::kColStride + row + 8; + src_tile(j, i)(3, 1) = (col + 8) * DstLayout::kColStride + row + 9; + } + } + break; + default: + break; + } + + copy::RegToGlobalStorer storer; + storer(src_tile, dst_tile); + + __syncthreads(); + + if (thread(0)) { + dst_tile.dump_value(); + } } template void run_load_g2r_test() { - int kNumel = 16 * 16 * kHeight * kWidth; - thrust::host_vector h_src(kNumel); - for (int i = 0; i < kNumel; ++i) { - h_src[i] = (Element)i; - } + int kNumel = 16 * 16 * kHeight * kWidth; + thrust::host_vector h_src(kNumel); + for (int i = 0; i < kNumel; ++i) { + h_src[i] = (Element)i; + } - thrust::device_vector d_src = h_src; + thrust::device_vector d_src = h_src; - load_g2r<<<1, 32 * kWarpSize>>>(d_src.data().get()); + load_g2r<<<1, 32 * kWarpSize>>>(d_src.data().get()); } template void run_store_r2g_test() { - int kNumel = 16 * 16 * kHeight * kWidth; - thrust::host_vector h_dst(kNumel, 0); + int kNumel = 16 * 16 * kHeight * kWidth; + thrust::host_vector h_dst(kNumel, 0); - thrust::device_vector d_dst = h_dst; + thrust::device_vector d_dst = h_dst; - store_r2g<<<1, 32 * kWarpSize>>>(d_dst.data().get()); + store_r2g<<<1, 32 * kWarpSize>>>(d_dst.data().get()); } } // namespace TEST(TestG2RegCopy, copy_2d_tile_g2r_row_major_0) { - using Element = float; - using WarpLayout = tl::RowMajor<1, 1>; - using RegLayout = tl::RowMajor<1, 1>; + using Element = float; + using WarpLayout = tl::RowMajor<1, 1>; + using RegLayout = tl::RowMajor<1, 1>; - const int kHeight = 1; - const int kWidth = 1; - const int kWarpSize = 1; - const copy::WarpReuse kMode = copy::WarpReuse::kCont; + const int kHeight = 1; + const int kWidth = 1; + const int kWarpSize = 1; + const copy::WarpReuse kMode = copy::WarpReuse::kCont; - using GlobalLayout = tl::RowMajor<16 * kHeight, 16 * kWidth>; + using GlobalLayout = tl::RowMajor<16 * kHeight, 16 * kWidth>; - run_load_g2r_test, kMode, kHeight, - kWidth, kWarpSize>(); + run_load_g2r_test, kMode, kHeight, + kWidth, kWarpSize>(); } TEST(TestG2RegCopy, load_2d_tile_g2r_row_major_1) { - using Element = float; - using WarpLayout = tl::RowMajor<1, 1>; - using RegLayout = tl::RowMajor<2, 2>; + using Element = float; + using WarpLayout = tl::RowMajor<1, 1>; + using RegLayout = tl::RowMajor<2, 2>; - const int kHeight = 2; - const int kWidth = 2; - const int kWarpSize = 1; - const copy::WarpReuse kMode = copy::WarpReuse::kCont; + const int kHeight = 2; + const int kWidth = 2; + const int kWarpSize = 1; + const copy::WarpReuse kMode = copy::WarpReuse::kCont; - using GlobalLayout = tl::RowMajor<16 * kHeight, 16 * kWidth>; + using GlobalLayout = tl::RowMajor<16 * kHeight, 16 * kWidth>; - run_load_g2r_test, kMode, kHeight, - kWidth, kWarpSize>(); + run_load_g2r_test, kMode, kHeight, + kWidth, kWarpSize>(); } TEST(TestG2RegCopy, load_2d_tile_g2r_row_major_2) { - using Element = float; - using WarpLayout = tl::RowMajor<2, 2>; - using RegLayout = tl::RowMajor<1, 2>; + using Element = float; + using WarpLayout = tl::RowMajor<2, 2>; + using RegLayout = tl::RowMajor<1, 2>; - const int kHeight = 2; - const int kWidth = 2; - const int kWarpSize = 4; - const copy::WarpReuse kMode = copy::WarpReuse::kRowReuseCont; + const int kHeight = 2; + const int kWidth = 2; + const int kWarpSize = 4; + const copy::WarpReuse kMode = copy::WarpReuse::kRowReuseCont; - using GlobalLayout = tl::RowMajor<16 * kHeight, 16 * kWidth>; + using GlobalLayout = tl::RowMajor<16 * kHeight, 16 * kWidth>; - run_load_g2r_test, kMode, kHeight, - kWidth, kWarpSize>(); + run_load_g2r_test, kMode, kHeight, + kWidth, kWarpSize>(); } TEST(TestG2RegCopy, load_2d_tile_g2r_row_major_3) { - using Element = float; - using WarpLayout = tl::RowMajor<2, 2>; - using RegLayout = tl::RowMajor<2, 1>; + using Element = float; + using WarpLayout = tl::RowMajor<2, 2>; + using RegLayout = tl::RowMajor<2, 1>; - const int kHeight = 2; - const int kWidth = 2; - const int kWarpSize = 4; - const copy::WarpReuse kMode = copy::WarpReuse::kColReuseCont; + const int kHeight = 2; + const int kWidth = 2; + const int kWarpSize = 4; + const copy::WarpReuse kMode = copy::WarpReuse::kColReuseCont; - using GlobalLayout = tl::RowMajor<16 * kHeight, 16 * kWidth>; + using GlobalLayout = tl::RowMajor<16 * kHeight, 16 * kWidth>; - run_load_g2r_test, kMode, kHeight, - kWidth, kWarpSize>(); + run_load_g2r_test, kMode, kHeight, + kWidth, kWarpSize>(); } TEST(TestG2RegCopy, load_2d_tile_g2r_col_major_0) { - using Element = float; - using WarpLayout = tl::RowMajor<1, 1>; - using RegLayout = tl::ColMajor<1, 1>; + using Element = float; + using WarpLayout = tl::RowMajor<1, 1>; + using RegLayout = tl::ColMajor<1, 1>; - const int kHeight = 1; - const int kWidth = 1; - const int kWarpSize = 1; - const copy::WarpReuse kMode = copy::WarpReuse::kCont; + const int kHeight = 1; + const int kWidth = 1; + const int kWarpSize = 1; + const copy::WarpReuse kMode = copy::WarpReuse::kCont; - using GlobalLayout = tl::ColMajor<16 * kWidth, 16 * kHeight>; + using GlobalLayout = tl::ColMajor<16 * kWidth, 16 * kHeight>; - run_load_g2r_test, kMode, kHeight, - kWidth, kWarpSize>(); + run_load_g2r_test, kMode, kHeight, + kWidth, kWarpSize>(); } TEST(TestG2RegCopy, load_2d_tile_g2r_col_major_1) { - using Element = float; - using WarpLayout = tl::RowMajor<1, 1>; - using RegLayout = tl::ColMajor<2, 2>; + using Element = float; + using WarpLayout = tl::RowMajor<1, 1>; + using RegLayout = tl::ColMajor<2, 2>; - const int kHeight = 2; - const int kWidth = 2; - const int kWarpSize = 1; - const copy::WarpReuse kMode = copy::WarpReuse::kCont; + const int kHeight = 2; + const int kWidth = 2; + const int kWarpSize = 1; + const copy::WarpReuse kMode = copy::WarpReuse::kCont; - using GlobalLayout = tl::ColMajor<16 * kWidth, 16 * kHeight>; + using GlobalLayout = tl::ColMajor<16 * kWidth, 16 * kHeight>; - run_load_g2r_test, kMode, kHeight, - kWidth, kWarpSize>(); + run_load_g2r_test, kMode, kHeight, + kWidth, kWarpSize>(); } TEST(TestG2RegCopy, load_2d_tile_g2r_col_major_2) { - using Element = float; - using WarpLayout = tl::RowMajor<2, 2>; - using RegLayout = tl::ColMajor<1, 2>; + using Element = float; + using WarpLayout = tl::RowMajor<2, 2>; + using RegLayout = tl::ColMajor<1, 2>; - const int kHeight = 2; - const int kWidth = 2; - const int kWarpSize = 4; - const copy::WarpReuse kMode = copy::WarpReuse::kRowReuseCont; + const int kHeight = 2; + const int kWidth = 2; + const int kWarpSize = 4; + const copy::WarpReuse kMode = copy::WarpReuse::kRowReuseCont; - using GlobalLayout = tl::ColMajor<16 * kWidth, 16 * kHeight>; + using GlobalLayout = tl::ColMajor<16 * kWidth, 16 * kHeight>; - run_load_g2r_test, kMode, kHeight, - kWidth, kWarpSize>(); + run_load_g2r_test, kMode, kHeight, + kWidth, kWarpSize>(); } TEST(TestG2RegCopy, load_2d_tile_g2r_col_major_3) { - using Element = float; - using WarpLayout = tl::RowMajor<2, 2>; - using RegLayout = tl::ColMajor<2, 1>; + using Element = float; + using WarpLayout = tl::RowMajor<2, 2>; + using RegLayout = tl::ColMajor<2, 1>; - const int kHeight = 2; - const int kWidth = 2; - const int kWarpSize = 4; - const copy::WarpReuse kMode = copy::WarpReuse::kColReuseCont; + const int kHeight = 2; + const int kWidth = 2; + const int kWarpSize = 4; + const copy::WarpReuse kMode = copy::WarpReuse::kColReuseCont; - using GlobalLayout = tl::ColMajor<16 * kWidth, 16 * kHeight>; + using GlobalLayout = tl::ColMajor<16 * kWidth, 16 * kHeight>; - run_load_g2r_test, kMode, kHeight, - kWidth, kWarpSize>(); + run_load_g2r_test, kMode, kHeight, + kWidth, kWarpSize>(); } TEST(TestG2RegCopy, store_2d_tile_r2g_row_major) { - using Element = float; - using WarpLayout = tl::RowMajor<1, 1>; - using RegLayout = tl::RowMajor<1, 1>; + using Element = float; + using WarpLayout = tl::RowMajor<1, 1>; + using RegLayout = tl::RowMajor<1, 1>; - const int kHeight = 1; - const int kWidth = 1; - const int kWarpSize = 1; + const int kHeight = 1; + const int kWidth = 1; + const int kWarpSize = 1; - using GlobalLayout = tl::RowMajor<16 * kHeight, 16 * kWidth>; + using GlobalLayout = tl::RowMajor<16 * kHeight, 16 * kWidth>; - run_store_r2g_test, kHeight, kWidth, - kWarpSize>(); + run_store_r2g_test, kHeight, kWidth, + kWarpSize>(); } TEST(TestG2RegCopy, store_2d_tile_r2g_col_major) { - using Element = float; - using WarpLayout = tl::ColMajor<1, 1>; - using RegLayout = tl::ColMajor<1, 1>; + using Element = float; + using WarpLayout = tl::ColMajor<1, 1>; + using RegLayout = tl::ColMajor<1, 1>; - const int kHeight = 1; - const int kWidth = 1; - const int kWarpSize = 1; + const int kHeight = 1; + const int kWidth = 1; + const int kWarpSize = 1; - using GlobalLayout = tl::ColMajor<16 * kWidth, 16 * kHeight>; + using GlobalLayout = tl::ColMajor<16 * kWidth, 16 * kHeight>; - run_store_r2g_test, kHeight, kWidth, - kWarpSize>(); + run_store_r2g_test, kHeight, kWidth, + kWarpSize>(); } } // namespace tilefusion::testing diff --git a/tests/cpp/cell/test_g2s_load.cu b/tests/cpp/cell/test_g2s_load.cu index 465f3814..bf9bd8f9 100644 --- a/tests/cpp/cell/test_g2s_load.cu +++ b/tests/cpp/cell/test_g2s_load.cu @@ -18,29 +18,29 @@ template __global__ void copy_g2s(const Element* src_ptr, Element* dst_ptr, Loader& loader, Storer& storer) { - extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; - auto* buf = reinterpret_cast(buf_); + extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; + auto* buf = reinterpret_cast(buf_); - SrcTile src(src_ptr); // global memory tile - DstTile inter(buf); // shared memory tile - SrcTile dst(dst_ptr); // global memory tile + SrcTile src(src_ptr); // global memory tile + DstTile inter(buf); // shared memory tile + SrcTile dst(dst_ptr); // global memory tile - loader(src, inter); - copy::__copy_async(); - __syncthreads(); + loader(src, inter); + copy::__copy_async(); + __syncthreads(); - storer(inter, dst); - __syncthreads(); + storer(inter, dst); + __syncthreads(); #if defined(DEBUG) - if (thread(0)) { - printf("\nshared\n"); - inter.dump_value(); - - printf("\nglobal\n"); - dst.dump_value(); - printf("\n"); - } + if (thread(0)) { + printf("\nshared\n"); + inter.dump_value(); + + printf("\nglobal\n"); + dst.dump_value(); + printf("\n"); + } #endif } @@ -48,258 +48,222 @@ template void run_test_row_major() { - static const int kThreads = tl::get_numel * 32; + static const int kThreads = tl::get_numel * 32; - int numel = kRows * kCols; - thrust::host_vector h_A(numel); - for (int i = 0; i < h_A.size(); ++i) - h_A[i] = static_cast(i % 2048); + int numel = kRows * kCols; + thrust::host_vector h_A(numel); + for (int i = 0; i < h_A.size(); ++i) h_A[i] = static_cast(i % 2048); - thrust::device_vector d_B(numel); - thrust::fill(d_B.begin(), d_B.end(), static_cast(0.)); - thrust::device_vector d_A = h_A; + thrust::device_vector d_B(numel); + thrust::fill(d_B.begin(), d_B.end(), static_cast(0.)); + thrust::device_vector d_A = h_A; - using SrcTile = GlobalTile>; + using SrcTile = GlobalTile>; - using DstTile = SharedTile, kSwizzled, - kSharedAccessInBytes>; + using DstTile = SharedTile, kSwizzled, + kSharedAccessInBytes>; - using Loader = GlobalToSharedLoader; - Loader loader; + using Loader = GlobalToSharedLoader; + Loader loader; - using Storer = SharedToGlobalStorer; - Storer storer; + using Storer = SharedToGlobalStorer; + Storer storer; - auto copy_kernel = copy_g2s; + auto copy_kernel = copy_g2s; - int shm_size = kRows * kCols * sizeof(Element); - if (shm_size > 48 * 1024) { - cudaFuncSetAttribute( - copy_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size); - } + int shm_size = kRows * kCols * sizeof(Element); + if (shm_size > 48 * 1024) { + cudaFuncSetAttribute(copy_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size); + } - dim3 dim_grid(1, 1); - dim3 dim_block(kThreads); - copy_kernel<<>>( - thrust::raw_pointer_cast(d_A.data()), - thrust::raw_pointer_cast(d_B.data()), loader, storer); - cudaDeviceSynchronize(); + dim3 dim_grid(1, 1); + dim3 dim_block(kThreads); + copy_kernel<<>>( + thrust::raw_pointer_cast(d_A.data()), + thrust::raw_pointer_cast(d_B.data()), loader, storer); + cudaDeviceSynchronize(); - thrust::host_vector h_B(numel); - h_B = d_B; + thrust::host_vector h_B(numel); + h_B = d_B; - assert_equal( - reinterpret_cast(thrust::raw_pointer_cast(h_A.data())), - reinterpret_cast(thrust::raw_pointer_cast(h_B.data())), numel, - 1e-5); + assert_equal(reinterpret_cast(thrust::raw_pointer_cast(h_A.data())), + reinterpret_cast(thrust::raw_pointer_cast(h_B.data())), + numel, 1e-5); } template void run_test_col_major() { - static const int kThreads = tl::get_numel * 32; + static const int kThreads = tl::get_numel * 32; - int numel = kRows * kCols; - thrust::host_vector h_A(numel); - for (int i = 0; i < h_A.size(); ++i) - h_A[i] = static_cast(i % 2048); + int numel = kRows * kCols; + thrust::host_vector h_A(numel); + for (int i = 0; i < h_A.size(); ++i) h_A[i] = static_cast(i % 2048); - thrust::device_vector d_B(numel); - thrust::fill(d_B.begin(), d_B.end(), static_cast(0.)); - thrust::device_vector d_A = h_A; + thrust::device_vector d_B(numel); + thrust::fill(d_B.begin(), d_B.end(), static_cast(0.)); + thrust::device_vector d_A = h_A; - using SrcTile = GlobalTile>; - using DstTile = SharedTile, kSwizzled, - kSharedAccessInBytes>; + using SrcTile = GlobalTile>; + using DstTile = SharedTile, kSwizzled, + kSharedAccessInBytes>; - using Loader = copy::GlobalToSharedLoader; - Loader loader; + using Loader = copy::GlobalToSharedLoader; + Loader loader; - using Storer = copy::SharedToGlobalStorer; - Storer storer; + using Storer = copy::SharedToGlobalStorer; + Storer storer; - dim3 dim_grid(1, 1); - dim3 dim_block(kThreads); + dim3 dim_grid(1, 1); + dim3 dim_block(kThreads); - auto copy_kernel = copy_g2s; + auto copy_kernel = copy_g2s; - int shm_size = kRows * kCols * sizeof(Element); - if (shm_size > 48 * 1024) { - cudaFuncSetAttribute( - copy_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size); - } + int shm_size = kRows * kCols * sizeof(Element); + if (shm_size > 48 * 1024) { + cudaFuncSetAttribute(copy_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size); + } - copy_kernel<<>>( - thrust::raw_pointer_cast(d_A.data()), - thrust::raw_pointer_cast(d_B.data()), loader, storer); - cudaDeviceSynchronize(); + copy_kernel<<>>( + thrust::raw_pointer_cast(d_A.data()), + thrust::raw_pointer_cast(d_B.data()), loader, storer); + cudaDeviceSynchronize(); - thrust::host_vector h_B(numel); - h_B = d_B; + thrust::host_vector h_B(numel); + h_B = d_B; - assert_equal( - reinterpret_cast(thrust::raw_pointer_cast(h_A.data())), - reinterpret_cast(thrust::raw_pointer_cast(h_B.data())), numel, - 1e-5); + assert_equal(reinterpret_cast(thrust::raw_pointer_cast(h_A.data())), + reinterpret_cast(thrust::raw_pointer_cast(h_B.data())), + numel, 1e-5); } } // namespace TEST(GlobalToSharedLoad, test_row_major_load) { - { // test non-swizzled __half. - static constexpr bool kSwizzled = false; - - run_test_row_major<__half, tl::RowMajor<1, 1>, 16, 64, 128, - kSwizzled>(); - run_test_row_major<__half, tl::RowMajor<1, 1>, 16, 256, 128, - kSwizzled>(); - run_test_row_major<__half, tl::RowMajor<2, 1>, 64, 64, 128, - kSwizzled>(); - run_test_row_major<__half, tl::RowMajor<1, 4>, 16, 256, 128, - kSwizzled>(); - run_test_row_major<__half, tl::RowMajor<4, 1>, 64, 64, 128, - kSwizzled>(); - run_test_row_major<__half, tl::RowMajor<2, 2>, 32, 128, 128, - kSwizzled>(); - run_test_row_major<__half, tl::RowMajor<2, 4>, 32, 256, 128, - kSwizzled>(); - run_test_row_major<__half, tl::RowMajor<2, 4>, 64, 512, 128, - kSwizzled>(); - - // Swizzle<2, 3, 3> - run_test_row_major<__half, tl::RowMajor<1, 1>, 16, 32, 64, kSwizzled>(); - run_test_row_major<__half, tl::RowMajor<1, 1>, 32, 32, 64, kSwizzled>(); - run_test_row_major<__half, tl::RowMajor<2, 2>, 32, 64, 64, kSwizzled>(); - } - - { // test swizzled __half. - static constexpr bool kSwizzled = true; - - run_test_row_major<__half, tl::RowMajor<1, 1>, 16, 64, 128, - kSwizzled>(); - run_test_row_major<__half, tl::RowMajor<1, 1>, 16, 256, 128, - kSwizzled>(); - run_test_row_major<__half, tl::RowMajor<2, 1>, 64, 64, 128, - kSwizzled>(); - run_test_row_major<__half, tl::RowMajor<1, 4>, 16, 256, 128, - kSwizzled>(); - run_test_row_major<__half, tl::RowMajor<4, 1>, 64, 64, 128, - kSwizzled>(); - run_test_row_major<__half, tl::RowMajor<2, 2>, 32, 128, 128, - kSwizzled>(); - run_test_row_major<__half, tl::RowMajor<2, 4>, 32, 256, 128, - kSwizzled>(); - run_test_row_major<__half, tl::RowMajor<2, 4>, 64, 512, 128, - kSwizzled>(); - - // Swizzle<2, 3, 3> - run_test_row_major<__half, tl::RowMajor<1, 1>, 16, 32, 64, kSwizzled>(); - run_test_row_major<__half, tl::RowMajor<1, 1>, 32, 32, 64, kSwizzled>(); - run_test_row_major<__half, tl::RowMajor<2, 2>, 32, 64, 64, kSwizzled>(); - } - - { // test non-swizzled float. - static constexpr bool kSwizzled = false; - - run_test_row_major, 8, 32, 128, kSwizzled>(); - run_test_row_major, 16, 64, 128, kSwizzled>(); - run_test_row_major, 16, 128, 128, - kSwizzled>(); - run_test_row_major, 64, 32, 128, kSwizzled>(); - run_test_row_major, 32, 64, 128, kSwizzled>(); - run_test_row_major, 32, 128, 128, - kSwizzled>(); - - // Swizzle<2, 3, 3> - run_test_row_major, 8, 16, 64, kSwizzled>(); - run_test_row_major, 16, 16, 64, kSwizzled>(); - run_test_row_major, 16, 32, 64, kSwizzled>(); - } - - { // test swizzled float. - static constexpr bool kSwizzled = true; - - run_test_row_major, 8, 32, 128, kSwizzled>(); - run_test_row_major, 16, 64, 128, kSwizzled>(); - run_test_row_major, 16, 128, 128, - kSwizzled>(); - run_test_row_major, 64, 32, 128, kSwizzled>(); - run_test_row_major, 32, 64, 128, kSwizzled>(); - run_test_row_major, 32, 128, 128, - kSwizzled>(); - - // Swizzle<2, 3, 3> - run_test_row_major, 8, 16, 64, kSwizzled>(); - run_test_row_major, 16, 16, 64, kSwizzled>(); - run_test_row_major, 16, 32, 64, kSwizzled>(); - } + { // test non-swizzled __half. + static constexpr bool kSwizzled = false; + + run_test_row_major<__half, tl::RowMajor<1, 1>, 16, 64, 128, kSwizzled>(); + run_test_row_major<__half, tl::RowMajor<1, 1>, 16, 256, 128, kSwizzled>(); + run_test_row_major<__half, tl::RowMajor<2, 1>, 64, 64, 128, kSwizzled>(); + run_test_row_major<__half, tl::RowMajor<1, 4>, 16, 256, 128, kSwizzled>(); + run_test_row_major<__half, tl::RowMajor<4, 1>, 64, 64, 128, kSwizzled>(); + run_test_row_major<__half, tl::RowMajor<2, 2>, 32, 128, 128, kSwizzled>(); + run_test_row_major<__half, tl::RowMajor<2, 4>, 32, 256, 128, kSwizzled>(); + run_test_row_major<__half, tl::RowMajor<2, 4>, 64, 512, 128, kSwizzled>(); + + // Swizzle<2, 3, 3> + run_test_row_major<__half, tl::RowMajor<1, 1>, 16, 32, 64, kSwizzled>(); + run_test_row_major<__half, tl::RowMajor<1, 1>, 32, 32, 64, kSwizzled>(); + run_test_row_major<__half, tl::RowMajor<2, 2>, 32, 64, 64, kSwizzled>(); + } + + { // test swizzled __half. + static constexpr bool kSwizzled = true; + + run_test_row_major<__half, tl::RowMajor<1, 1>, 16, 64, 128, kSwizzled>(); + run_test_row_major<__half, tl::RowMajor<1, 1>, 16, 256, 128, kSwizzled>(); + run_test_row_major<__half, tl::RowMajor<2, 1>, 64, 64, 128, kSwizzled>(); + run_test_row_major<__half, tl::RowMajor<1, 4>, 16, 256, 128, kSwizzled>(); + run_test_row_major<__half, tl::RowMajor<4, 1>, 64, 64, 128, kSwizzled>(); + run_test_row_major<__half, tl::RowMajor<2, 2>, 32, 128, 128, kSwizzled>(); + run_test_row_major<__half, tl::RowMajor<2, 4>, 32, 256, 128, kSwizzled>(); + run_test_row_major<__half, tl::RowMajor<2, 4>, 64, 512, 128, kSwizzled>(); + + // Swizzle<2, 3, 3> + run_test_row_major<__half, tl::RowMajor<1, 1>, 16, 32, 64, kSwizzled>(); + run_test_row_major<__half, tl::RowMajor<1, 1>, 32, 32, 64, kSwizzled>(); + run_test_row_major<__half, tl::RowMajor<2, 2>, 32, 64, 64, kSwizzled>(); + } + + { // test non-swizzled float. + static constexpr bool kSwizzled = false; + + run_test_row_major, 8, 32, 128, kSwizzled>(); + run_test_row_major, 16, 64, 128, kSwizzled>(); + run_test_row_major, 16, 128, 128, kSwizzled>(); + run_test_row_major, 64, 32, 128, kSwizzled>(); + run_test_row_major, 32, 64, 128, kSwizzled>(); + run_test_row_major, 32, 128, 128, kSwizzled>(); + + // Swizzle<2, 3, 3> + run_test_row_major, 8, 16, 64, kSwizzled>(); + run_test_row_major, 16, 16, 64, kSwizzled>(); + run_test_row_major, 16, 32, 64, kSwizzled>(); + } + + { // test swizzled float. + static constexpr bool kSwizzled = true; + + run_test_row_major, 8, 32, 128, kSwizzled>(); + run_test_row_major, 16, 64, 128, kSwizzled>(); + run_test_row_major, 16, 128, 128, kSwizzled>(); + run_test_row_major, 64, 32, 128, kSwizzled>(); + run_test_row_major, 32, 64, 128, kSwizzled>(); + run_test_row_major, 32, 128, 128, kSwizzled>(); + + // Swizzle<2, 3, 3> + run_test_row_major, 8, 16, 64, kSwizzled>(); + run_test_row_major, 16, 16, 64, kSwizzled>(); + run_test_row_major, 16, 32, 64, kSwizzled>(); + } } TEST(GlobalToSharedLoad, test_col_major_load) { - { - static constexpr bool kSwizzled = false; - - run_test_col_major<__half, tl::RowMajor<1, 1>, 64, 16, 128, - kSwizzled>(); - run_test_col_major<__half, tl::RowMajor<1, 1>, 128, 16, 128, - kSwizzled>(); - run_test_col_major<__half, tl::RowMajor<1, 4>, 64, 128, 128, - kSwizzled>(); - run_test_col_major<__half, tl::RowMajor<4, 1>, 256, 16, 128, - kSwizzled>(); - run_test_col_major<__half, tl::RowMajor<2, 2>, 128, 32, 128, - kSwizzled>(); - - // Swizzle<2, 3, 3> - run_test_col_major<__half, tl::RowMajor<1, 1>, 32, 16, 64, kSwizzled>(); - run_test_col_major<__half, tl::RowMajor<1, 1>, 32, 32, 64, kSwizzled>(); - run_test_col_major<__half, tl::RowMajor<2, 2>, 64, 32, 64, kSwizzled>(); - } - - { - static constexpr bool kSwizzled = true; - run_test_col_major<__half, tl::RowMajor<1, 1>, 64, 16, 128, - kSwizzled>(); - run_test_col_major<__half, tl::RowMajor<1, 1>, 128, 16, 128, - kSwizzled>(); - run_test_col_major<__half, tl::RowMajor<1, 4>, 64, 128, 128, - kSwizzled>(); - run_test_col_major<__half, tl::RowMajor<4, 1>, 256, 32, 128, - kSwizzled>(); - run_test_col_major<__half, tl::RowMajor<2, 2>, 128, 32, 128, - kSwizzled>(); - - // Swizzle<2, 3, 3> - run_test_col_major<__half, tl::RowMajor<1, 1>, 32, 16, 64, kSwizzled>(); - run_test_col_major<__half, tl::RowMajor<1, 1>, 32, 32, 64, kSwizzled>(); - run_test_col_major<__half, tl::RowMajor<2, 2>, 64, 32, 64, kSwizzled>(); - } - - { - static constexpr bool kSwizzled = false; - - run_test_col_major, 32, 16, 128, kSwizzled>(); - run_test_col_major, 64, 16, 128, kSwizzled>(); - run_test_col_major, 64, 64, 128, kSwizzled>(); - run_test_col_major, 128, 32, 128, - kSwizzled>(); - run_test_col_major, 64, 64, 128, kSwizzled>(); - - // Swizzle<2, 3, 3> - run_test_col_major, 16, 16, 64, kSwizzled>(); - run_test_col_major, 16, 32, 64, kSwizzled>(); - run_test_col_major, 32, 32, 64, kSwizzled>(); - } - - { - static constexpr bool kSwizzled = true; - run_test_col_major, 128, 128, 128, - kSwizzled>(); - - // Swizzle<2, 3, 3> - run_test_col_major, 16, 16, 64, kSwizzled>(); - run_test_col_major, 16, 32, 64, kSwizzled>(); - run_test_col_major, 32, 32, 64, kSwizzled>(); - } + { + static constexpr bool kSwizzled = false; + + run_test_col_major<__half, tl::RowMajor<1, 1>, 64, 16, 128, kSwizzled>(); + run_test_col_major<__half, tl::RowMajor<1, 1>, 128, 16, 128, kSwizzled>(); + run_test_col_major<__half, tl::RowMajor<1, 4>, 64, 128, 128, kSwizzled>(); + run_test_col_major<__half, tl::RowMajor<4, 1>, 256, 16, 128, kSwizzled>(); + run_test_col_major<__half, tl::RowMajor<2, 2>, 128, 32, 128, kSwizzled>(); + + // Swizzle<2, 3, 3> + run_test_col_major<__half, tl::RowMajor<1, 1>, 32, 16, 64, kSwizzled>(); + run_test_col_major<__half, tl::RowMajor<1, 1>, 32, 32, 64, kSwizzled>(); + run_test_col_major<__half, tl::RowMajor<2, 2>, 64, 32, 64, kSwizzled>(); + } + + { + static constexpr bool kSwizzled = true; + run_test_col_major<__half, tl::RowMajor<1, 1>, 64, 16, 128, kSwizzled>(); + run_test_col_major<__half, tl::RowMajor<1, 1>, 128, 16, 128, kSwizzled>(); + run_test_col_major<__half, tl::RowMajor<1, 4>, 64, 128, 128, kSwizzled>(); + run_test_col_major<__half, tl::RowMajor<4, 1>, 256, 32, 128, kSwizzled>(); + run_test_col_major<__half, tl::RowMajor<2, 2>, 128, 32, 128, kSwizzled>(); + + // Swizzle<2, 3, 3> + run_test_col_major<__half, tl::RowMajor<1, 1>, 32, 16, 64, kSwizzled>(); + run_test_col_major<__half, tl::RowMajor<1, 1>, 32, 32, 64, kSwizzled>(); + run_test_col_major<__half, tl::RowMajor<2, 2>, 64, 32, 64, kSwizzled>(); + } + + { + static constexpr bool kSwizzled = false; + + run_test_col_major, 32, 16, 128, kSwizzled>(); + run_test_col_major, 64, 16, 128, kSwizzled>(); + run_test_col_major, 64, 64, 128, kSwizzled>(); + run_test_col_major, 128, 32, 128, kSwizzled>(); + run_test_col_major, 64, 64, 128, kSwizzled>(); + + // Swizzle<2, 3, 3> + run_test_col_major, 16, 16, 64, kSwizzled>(); + run_test_col_major, 16, 32, 64, kSwizzled>(); + run_test_col_major, 32, 32, 64, kSwizzled>(); + } + + { + static constexpr bool kSwizzled = true; + run_test_col_major, 128, 128, 128, kSwizzled>(); + + // Swizzle<2, 3, 3> + run_test_col_major, 16, 16, 64, kSwizzled>(); + run_test_col_major, 16, 32, 64, kSwizzled>(); + run_test_col_major, 32, 32, 64, kSwizzled>(); + } } } // namespace tilefusion::testing diff --git a/tests/cpp/cell/test_gemm.cu b/tests/cpp/cell/test_gemm.cu index e513e4c4..0bd6d065 100644 --- a/tests/cpp/cell/test_gemm.cu +++ b/tests/cpp/cell/test_gemm.cu @@ -17,65 +17,62 @@ namespace tl = tile_layout; namespace { float rand_float(float a = 1e-3, float b = 1) { - float random = ((float)rand()) / (float)RAND_MAX; - float diff = b - a; - float r = random * diff; - return a + r; + float random = ((float)rand()) / (float)RAND_MAX; + float diff = b - a; + float r = random * diff; + return a + r; } bool check_correctness(const half* hc1, const float* hc2, int row, int col) { - int numel = row * col; - bool pass_unittest = true; - static const float eps = 5e-2; + int numel = row * col; + bool pass_unittest = true; + static const float eps = 5e-2; #if defined(DEBUG) - int cut_off = 128; - std::stringstream ss; - ss << std::setprecision(3) << std::endl - << "ours:" << std::endl - << 0 << ":\t"; - for (int i = 0; i < cut_off; ++i) { - ss << hc2[i] << ", "; - if (i & (i + 1) % 16 == 0) { - ss << std::endl << (i + 1) / 16 << ":\t"; - } + int cut_off = 128; + std::stringstream ss; + ss << std::setprecision(3) << std::endl << "ours:" << std::endl << 0 << ":\t"; + for (int i = 0; i < cut_off; ++i) { + ss << hc2[i] << ", "; + if (i & (i + 1) % 16 == 0) { + ss << std::endl << (i + 1) / 16 << ":\t"; } + } - ss << std::endl << "cublas:" << std::endl << 0 << ":\t"; - for (int i = 0; i < cut_off; ++i) { - ss << __half2float(hc1[i]) << ", "; - if (i & (i + 1) % 16 == 0) { - ss << std::endl << (i + 1) / 16 << "\t"; - } + ss << std::endl << "cublas:" << std::endl << 0 << ":\t"; + for (int i = 0; i < cut_off; ++i) { + ss << __half2float(hc1[i]) << ", "; + if (i & (i + 1) % 16 == 0) { + ss << std::endl << (i + 1) / 16 << "\t"; } - LOG(INFO) << ss.str(); + } + LOG(INFO) << ss.str(); #endif - double total_diff = 0.; - double max_abs_diff = FLT_MIN; - double diff = 0.; + double total_diff = 0.; + double max_abs_diff = FLT_MIN; + double diff = 0.; - for (int i = 0; i < numel; ++i) { - diff = abs(__half2float(hc1[i]) - hc2[i]); - max_abs_diff = max_abs_diff < diff ? diff : max_abs_diff; - total_diff += diff; + for (int i = 0; i < numel; ++i) { + diff = abs(__half2float(hc1[i]) - hc2[i]); + max_abs_diff = max_abs_diff < diff ? diff : max_abs_diff; + total_diff += diff; #if defined(DEBUG) - if (diff > eps) { - LOG(INFO) << i - << "-th value has large numeric absolute diff: " << diff - << ", Expected: " << __half2float(hc1[i]) - << "; Got: " << hc2[i] << std::endl; - } -#endif + if (diff > eps) { + LOG(INFO) << i << "-th value has large numeric absolute diff: " << diff + << ", Expected: " << __half2float(hc1[i]) << "; Got: " << hc2[i] + << std::endl; } +#endif + } - double avg_diff = total_diff / numel; - LOG(INFO) << "Average absolute diff: " << avg_diff - << ", Max absolute diff: " << max_abs_diff; - if (avg_diff > eps) pass_unittest = false; + double avg_diff = total_diff / numel; + LOG(INFO) << "Average absolute diff: " << avg_diff + << ", Max absolute diff: " << max_abs_diff; + if (avg_diff > eps) pass_unittest = false; - return pass_unittest; + return pass_unittest; } // @brief: This implementation interprets A and C as being laid out in row-major @@ -93,14 +90,14 @@ bool check_correctness(const half* hc1, const float* hc2, int row, int col) { // [N, M] = [N, K] @ [K, M] void cublas_hgemm(int m, int n, int k, const __half* A, const __half* B, __half* C, int lda, int ldb, int ldc) { - __half alf = 1.; - __half bet = 0.; - - cublasHandle_t handle; - CUBLAS_CHECK(cublasCreate(&handle)); - CUBLAS_CHECK(cublasHgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, k, &alf, A, - lda, B, ldb, &bet, C, ldc)); - CUBLAS_CHECK(cublasDestroy(handle)); + __half alf = 1.; + __half bet = 0.; + + cublasHandle_t handle; + CUBLAS_CHECK(cublasCreate(&handle)); + CUBLAS_CHECK(cublasHgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, k, &alf, A, + lda, B, ldb, &bet, C, ldc)); + CUBLAS_CHECK(cublasDestroy(handle)); } // @param strided_k: chunk size to partition the k dimension of the shared @@ -109,62 +106,61 @@ template struct TestTraits { - using MmaAtom = compute::MmaAtom; - using BaseShape = MmaAtom::BaseTile; - - /// ======== 1. configure threads and warp layout in a CTA ============ - using WarpLayout = WarpLayout_; - static constexpr int kThreads = tl::get_numel * 32; - static constexpr int kWarpPerRow = tl::num_rows; - static constexpr int kWarpPerCol = tl::num_cols; - - /// == 2. configure tile transfer between global and shared using CuTe == - using GlobalA = GlobalTile>; - using SharedA = SharedTile, kSwizzled, - kSharedAccessInBytes>; - using LoadSharedA = GlobalToSharedLoader; - - using GlobalB = GlobalTile>; - using SharedB = SharedTile, kSwizzled, - kSharedAccessInBytes>; - using LoadSharedB = GlobalToSharedLoader; - - /// === 3. configure tile transfer between shared and register loader === - // shared tile for operand A - using TileIteratorA = STileIterator>; - // shared tile for operand B - using TileIteratorB = STileIterator>; - - static_assert(TileIteratorA::sc1 == TileIteratorB::sc0, - "mismatched K dimension!"); - - // register tile for operand A, calculate register usage for operand A - // warp tile shape for the operand A - static constexpr int kAMs = kM / kWarpPerRow / BaseShape::kRows; - static constexpr int kAKs = kChunkK / BaseShape::kCols; - using RegA = RegTile, tl::RowMajor>; - // load RegTileA from shared - using LoadRegA = - SharedToRegLoader; - - // register tile for operand B, calculate register usage for operand B - static constexpr int kBKs = kChunkK / BaseShape::kRows; - static constexpr int kBNs = kN / kWarpPerCol / BaseShape::kCols; - using RegB = RegTile, tl::ColMajor>; - // load RegTileB from shared - using LoadRegB = - SharedToRegLoader; - - // register tile for output C - // calculate register usage for output C - static constexpr int kCMs = kM / kWarpPerRow / BaseShape::kRows; - static constexpr int kCNs = kN / kWarpPerCol / BaseShape::kCols; - - using RegC = - RegTile, tl::RowMajor>; - using GlobalC = GlobalTile>; - using CStorer = copy::RegToGlobalStorer; + using MmaAtom = compute::MmaAtom; + using BaseShape = MmaAtom::BaseTile; + + /// ======== 1. configure threads and warp layout in a CTA ============ + using WarpLayout = WarpLayout_; + static constexpr int kThreads = tl::get_numel * 32; + static constexpr int kWarpPerRow = tl::num_rows; + static constexpr int kWarpPerCol = tl::num_cols; + + /// == 2. configure tile transfer between global and shared using CuTe == + using GlobalA = GlobalTile>; + using SharedA = SharedTile, kSwizzled, + kSharedAccessInBytes>; + using LoadSharedA = GlobalToSharedLoader; + + using GlobalB = GlobalTile>; + using SharedB = SharedTile, kSwizzled, + kSharedAccessInBytes>; + using LoadSharedB = GlobalToSharedLoader; + + /// === 3. configure tile transfer between shared and register loader === + // shared tile for operand A + using TileIteratorA = STileIterator>; + // shared tile for operand B + using TileIteratorB = STileIterator>; + + static_assert(TileIteratorA::sc1 == TileIteratorB::sc0, + "mismatched K dimension!"); + + // register tile for operand A, calculate register usage for operand A + // warp tile shape for the operand A + static constexpr int kAMs = kM / kWarpPerRow / BaseShape::kRows; + static constexpr int kAKs = kChunkK / BaseShape::kCols; + using RegA = RegTile, tl::RowMajor>; + // load RegTileA from shared + using LoadRegA = + SharedToRegLoader; + + // register tile for operand B, calculate register usage for operand B + static constexpr int kBKs = kChunkK / BaseShape::kRows; + static constexpr int kBNs = kN / kWarpPerCol / BaseShape::kCols; + using RegB = RegTile, tl::ColMajor>; + // load RegTileB from shared + using LoadRegB = + SharedToRegLoader; + + // register tile for output C + // calculate register usage for output C + static constexpr int kCMs = kM / kWarpPerRow / BaseShape::kRows; + static constexpr int kCNs = kN / kWarpPerCol / BaseShape::kCols; + + using RegC = RegTile, tl::RowMajor>; + using GlobalC = GlobalTile>; + using CStorer = copy::RegToGlobalStorer; }; template __global__ void test_gemm(const Element* ga, const Element* gb, ElementAcc* gc) { - GlobalA gA(ga); - GlobalB gB(gb); + GlobalA gA(ga); + GlobalB gB(gb); - extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; - auto* shared_a = reinterpret_cast(buf_); - auto* shared_b = shared_a + TileIteratorA::Tile::kNumel; + extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; + auto* shared_a = reinterpret_cast(buf_); + auto* shared_b = shared_a + TileIteratorA::Tile::kNumel; - SharedA sA(shared_a); - SharedB sB(shared_b); + SharedA sA(shared_a); + SharedB sB(shared_b); - LoadSharedA loaderA; - loaderA(gA, sA); + LoadSharedA loaderA; + loaderA(gA, sA); - LoadSharedB loaderB; - loaderB(gB, sB); - __copy_async(); - __syncthreads(); + LoadSharedB loaderB; + loaderB(gB, sB); + __copy_async(); + __syncthreads(); - TileIteratorA sAs(shared_a); - TileIteratorB sBs(shared_b); + TileIteratorA sAs(shared_a); + TileIteratorB sBs(shared_b); - LoadRegA load_rA; - RegA rA; + LoadRegA load_rA; + RegA rA; - LoadRegB load_rB; - RegB rB; + LoadRegB load_rB; + RegB rB; - RegC acc; + RegC acc; - for (int k = 0; k < TileIteratorA::sc1; ++k) { - load_rA(sAs(k), rA); - load_rB(sBs(k), rB); + for (int k = 0; k < TileIteratorA::sc1; ++k) { + load_rA(sAs(k), rA); + load_rB(sBs(k), rB); - compute::gemm(rA, rB, acc); - } - __syncthreads(); + compute::gemm(rA, rB, acc); + } + __syncthreads(); - // store from register to global - GlobalC gC(gc); - StoreC store_rC; - store_rC(acc, gC); + // store from register to global + GlobalC gC(gc); + StoreC store_rC; + store_rC(acc, gC); } } // namespace @@ -223,169 +219,159 @@ template void run_test() { - // unittest for register-level gemm by calling into wmma PTX - using Element = __half; - using ElementAcc = float; + // unittest for register-level gemm by calling into wmma PTX + using Element = __half; + using ElementAcc = float; - // initialize data - thrust::host_vector h_a(kM * kK); - for (int i = 0; i < h_a.size(); ++i) { + // initialize data + thrust::host_vector h_a(kM * kK); + for (int i = 0; i < h_a.size(); ++i) { #if defined(DEBUG) - h_a[i] = static_cast(i % 2048); + h_a[i] = static_cast(i % 2048); #else - h_a[i] = static_cast(rand_float()); + h_a[i] = static_cast(rand_float()); #endif - } + } - thrust::host_vector h_b(kK * kN); - for (int i = 0; i < h_b.size(); ++i) { + thrust::host_vector h_b(kK * kN); + for (int i = 0; i < h_b.size(); ++i) { #if defined(DEBUG) - h_b[i] = static_cast(i % 2048); + h_b[i] = static_cast(i % 2048); #else - h_b[i] = static_cast(rand_float()); + h_b[i] = static_cast(rand_float()); #endif - } + } - thrust::host_vector h_c(kM * kN); - thrust::fill(h_c.begin(), h_c.end(), 0.); + thrust::host_vector h_c(kM * kN); + thrust::fill(h_c.begin(), h_c.end(), 0.); - thrust::device_vector d_a = h_a; - thrust::device_vector d_b = h_b; - thrust::device_vector d_c = h_c; + thrust::device_vector d_a = h_a; + thrust::device_vector d_b = h_b; + thrust::device_vector d_c = h_c; - // define the configuration of the test - using config = TestTraits; + // define the configuration of the test + using config = TestTraits; - LOG(INFO) << "[" << kM << ", " << kN << ", " << kK << "], warps: [" - << config::kWarpPerRow << ", " << config::kWarpPerCol - << "], k_chunk_size: " << kChunkK - << ", kThreads: " << config::kThreads; + LOG(INFO) << "[" << kM << ", " << kN << ", " << kK << "], warps: [" + << config::kWarpPerRow << ", " << config::kWarpPerCol + << "], k_chunk_size: " << kChunkK + << ", kThreads: " << config::kThreads; - using RegA = typename config::RegA; - using RegB = typename config::RegB; - using RegC = typename config::RegC; + using RegA = typename config::RegA; + using RegB = typename config::RegB; + using RegC = typename config::RegC; - using IteratorA = typename config::TileIteratorA; - using IteratorB = typename config::TileIteratorB; + using IteratorA = typename config::TileIteratorA; + using IteratorB = typename config::TileIteratorB; #if defined(DEBUG) - LOG(INFO) << "TileIteratorA: " << IteratorA{} << std::endl - << "TileIteratorB: " << IteratorB{} << std::endl - << "RegA: " << RegA{} << std::endl - << "RegB: " << RegB{} << std::endl - << "RegC: " << RegC{} << std::endl; + LOG(INFO) << "TileIteratorA: " << IteratorA{} << std::endl + << "TileIteratorB: " << IteratorB{} << std::endl + << "RegA: " << RegA{} << std::endl + << "RegB: " << RegB{} << std::endl + << "RegC: " << RegC{} << std::endl; #endif - dim3 dim_grid(1, 1, 1); - dim3 dim_block(config::kThreads, 1, 1); - int shm_size = (kM + kN) * kK * sizeof(Element); - - auto kernel = test_gemm< - Element, ElementAcc, typename config::GlobalA, typename config::SharedA, - typename config::LoadSharedA, typename config::GlobalB, - typename config::SharedB, typename config::LoadSharedB, IteratorA, RegA, - typename config::LoadRegA, IteratorB, RegB, typename config::LoadRegB, - typename config::GlobalC, RegC, typename config::CStorer>; - - if (shm_size > 48 * 1024) { - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size); - } - - kernel<<>>( - thrust::raw_pointer_cast(d_a.data()), - thrust::raw_pointer_cast(d_b.data()), - thrust::raw_pointer_cast(d_c.data())); - cudaDeviceSynchronize(); - h_c = d_c; - - // unittest for correctness, take cublas as the ground-truth - thrust::device_vector<__half> d_cublas(kM * kN); - thrust::fill(d_cublas.begin(), d_cublas.end(), 0.); - - cublas_hgemm( - kN, kM, kK, - reinterpret_cast(thrust::raw_pointer_cast(d_b.data())), - reinterpret_cast(thrust::raw_pointer_cast(d_a.data())), - reinterpret_cast<__half*>(thrust::raw_pointer_cast(d_cublas.data())), - kK /*lda*/, kK /*ldb*/, kN /*ldc*/); - thrust::host_vector<__half> h_cublas = d_cublas; - - EXPECT_TRUE(check_correctness(thrust::raw_pointer_cast(h_cublas.data()), - thrust::raw_pointer_cast(h_c.data()), kM, kN)) - << "Failed test!" << std::endl; + dim3 dim_grid(1, 1, 1); + dim3 dim_block(config::kThreads, 1, 1); + int shm_size = (kM + kN) * kK * sizeof(Element); + + auto kernel = test_gemm< + Element, ElementAcc, typename config::GlobalA, typename config::SharedA, + typename config::LoadSharedA, typename config::GlobalB, + typename config::SharedB, typename config::LoadSharedB, IteratorA, RegA, + typename config::LoadRegA, IteratorB, RegB, typename config::LoadRegB, + typename config::GlobalC, RegC, typename config::CStorer>; + + if (shm_size > 48 * 1024) { + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + shm_size); + } + + kernel<<>>( + thrust::raw_pointer_cast(d_a.data()), + thrust::raw_pointer_cast(d_b.data()), + thrust::raw_pointer_cast(d_c.data())); + cudaDeviceSynchronize(); + h_c = d_c; + + // unittest for correctness, take cublas as the ground-truth + thrust::device_vector<__half> d_cublas(kM * kN); + thrust::fill(d_cublas.begin(), d_cublas.end(), 0.); + + cublas_hgemm( + kN, kM, kK, + reinterpret_cast(thrust::raw_pointer_cast(d_b.data())), + reinterpret_cast(thrust::raw_pointer_cast(d_a.data())), + reinterpret_cast<__half*>(thrust::raw_pointer_cast(d_cublas.data())), + kK /*lda*/, kK /*ldb*/, kN /*ldc*/); + thrust::host_vector<__half> h_cublas = d_cublas; + + EXPECT_TRUE(check_correctness(thrust::raw_pointer_cast(h_cublas.data()), + thrust::raw_pointer_cast(h_c.data()), kM, kN)) + << "Failed test!" << std::endl; } TEST(TestGemm, test) { - // This unit test loads the entire matrices A and B into shared memory. - // For example, on A100, do not test GEMM larger than [128, 128, 128], - // as this will cause a shared memory overflow. - { - static constexpr int kSharedAccessInBytes = 128; - // 1 warp - run_test<16, 16, 64, tl::RowMajor<1, 1>, 64, true, - kSharedAccessInBytes>(); // minimal shape - run_test<32, 16, 64, tl::RowMajor<1, 1>, 64, true, - kSharedAccessInBytes>(); - run_test<16, 32, 64, tl::RowMajor<1, 1>, 64, true, - kSharedAccessInBytes>(); - run_test<32, 32, 64, tl::RowMajor<1, 1>, 64, true, - kSharedAccessInBytes>(); - run_test<64, 64, 64, tl::RowMajor<1, 1>, 64, true, - kSharedAccessInBytes>(); - run_test<128, 64, 64, tl::RowMajor<1, 1>, 64, true, - kSharedAccessInBytes>(); - run_test<128, 64, 128, tl::RowMajor<1, 1>, 64, true, - kSharedAccessInBytes>(); - - // smaller chunk size - run_test<128, 64, 128, tl::RowMajor<1, 1>, 32, true, - kSharedAccessInBytes>(); - run_test<128, 64, 128, tl::RowMajor<1, 1>, 16, true, - kSharedAccessInBytes>(); - - // 2 x 1 warps - run_test<32, 64, 128, tl::RowMajor<2, 1>, 128, true, - kSharedAccessInBytes>(); - run_test<64, 64, 128, tl::RowMajor<2, 1>, 128, true, - kSharedAccessInBytes>(); - run_test<32, 128, 128, tl::RowMajor<2, 1>, 128, true, - kSharedAccessInBytes>(); - run_test<32, 128, 128, tl::RowMajor<2, 1>, 64, true, - kSharedAccessInBytes>(); - - // 1 x 2 warps - run_test<32, 128, 128, tl::RowMajor<1, 2>, 128, true, - kSharedAccessInBytes>(); - - // 2 x 2 warps - run_test<64, 64, 128, tl::RowMajor<2, 2>, 128, true, - kSharedAccessInBytes>(); - run_test<64, 64, 128, tl::RowMajor<2, 2>, 64, true, - kSharedAccessInBytes>(); - - // smaller chunk size - run_test<64, 64, 128, tl::RowMajor<2, 2>, 32, true, - kSharedAccessInBytes>(); - run_test<64, 64, 128, tl::RowMajor<2, 2>, 16, true, - kSharedAccessInBytes>(); - - // 4 x 1 warps - run_test<64, 16, 256, tl::RowMajor<4, 1>, 256, true, - kSharedAccessInBytes>(); - } - - { - static constexpr int kSharedAccessInBytes = 64; - // Swizzle<2, 3, 3> - run_test<32, 32, 32, tl::RowMajor<1, 1>, 32, true, - kSharedAccessInBytes>(); - run_test<64, 64, 64, tl::RowMajor<2, 2>, 64, true, - kSharedAccessInBytes>(); - run_test<128, 128, 64, tl::RowMajor<2, 2>, 64, true, - kSharedAccessInBytes>(); - } + // This unit test loads the entire matrices A and B into shared memory. + // For example, on A100, do not test GEMM larger than [128, 128, 128], + // as this will cause a shared memory overflow. + { + static constexpr int kSharedAccessInBytes = 128; + // 1 warp + run_test<16, 16, 64, tl::RowMajor<1, 1>, 64, true, + kSharedAccessInBytes>(); // minimal shape + run_test<32, 16, 64, tl::RowMajor<1, 1>, 64, true, kSharedAccessInBytes>(); + run_test<16, 32, 64, tl::RowMajor<1, 1>, 64, true, kSharedAccessInBytes>(); + run_test<32, 32, 64, tl::RowMajor<1, 1>, 64, true, kSharedAccessInBytes>(); + run_test<64, 64, 64, tl::RowMajor<1, 1>, 64, true, kSharedAccessInBytes>(); + run_test<128, 64, 64, tl::RowMajor<1, 1>, 64, true, kSharedAccessInBytes>(); + run_test<128, 64, 128, tl::RowMajor<1, 1>, 64, true, + kSharedAccessInBytes>(); + + // smaller chunk size + run_test<128, 64, 128, tl::RowMajor<1, 1>, 32, true, + kSharedAccessInBytes>(); + run_test<128, 64, 128, tl::RowMajor<1, 1>, 16, true, + kSharedAccessInBytes>(); + + // 2 x 1 warps + run_test<32, 64, 128, tl::RowMajor<2, 1>, 128, true, + kSharedAccessInBytes>(); + run_test<64, 64, 128, tl::RowMajor<2, 1>, 128, true, + kSharedAccessInBytes>(); + run_test<32, 128, 128, tl::RowMajor<2, 1>, 128, true, + kSharedAccessInBytes>(); + run_test<32, 128, 128, tl::RowMajor<2, 1>, 64, true, + kSharedAccessInBytes>(); + + // 1 x 2 warps + run_test<32, 128, 128, tl::RowMajor<1, 2>, 128, true, + kSharedAccessInBytes>(); + + // 2 x 2 warps + run_test<64, 64, 128, tl::RowMajor<2, 2>, 128, true, + kSharedAccessInBytes>(); + run_test<64, 64, 128, tl::RowMajor<2, 2>, 64, true, kSharedAccessInBytes>(); + + // smaller chunk size + run_test<64, 64, 128, tl::RowMajor<2, 2>, 32, true, kSharedAccessInBytes>(); + run_test<64, 64, 128, tl::RowMajor<2, 2>, 16, true, kSharedAccessInBytes>(); + + // 4 x 1 warps + run_test<64, 16, 256, tl::RowMajor<4, 1>, 256, true, + kSharedAccessInBytes>(); + } + + { + static constexpr int kSharedAccessInBytes = 64; + // Swizzle<2, 3, 3> + run_test<32, 32, 32, tl::RowMajor<1, 1>, 32, true, kSharedAccessInBytes>(); + run_test<64, 64, 64, tl::RowMajor<2, 2>, 64, true, kSharedAccessInBytes>(); + run_test<128, 128, 64, tl::RowMajor<2, 2>, 64, true, + kSharedAccessInBytes>(); + } } } // namespace tilefusion::testing diff --git a/tests/cpp/cell/test_reduce.cu b/tests/cpp/cell/test_reduce.cu index f86db2d8..5c64f570 100644 --- a/tests/cpp/cell/test_reduce.cu +++ b/tests/cpp/cell/test_reduce.cu @@ -19,180 +19,180 @@ template __global__ void reg_reduce(Element* src) { - using SrcLoadTile = GlobalTile; - using DstLoadTile = RegTile; - using SrcReduceTile = DstLoadTile; - using DstReduceTile = RegTile>; - - SrcLoadTile src_load_tile(src); - DstLoadTile dst_load_tile; - DstReduceTile dst_reduce_tile; - - // Load data from global memory to register file - copy::GlobalToRegLoader loader; - loader(src_load_tile, dst_load_tile); - __syncthreads(); - - // Execute reduce operation. - compute::MaxReduce row_max; - row_max(dst_load_tile, dst_reduce_tile); - - __syncthreads(); - - if (thread(0)) { - printf("Row Max:\n"); - printf("Thread 0:\n"); - dst_reduce_tile.dump_value(); - } - - if (thread(1)) { - printf("Thread 1:\n"); - dst_reduce_tile.dump_value(); - } - - if (thread(4)) { - printf("Thread 4:\n"); - dst_reduce_tile.dump_value(); - } - - if (thread(8)) { - printf("Thread 8:\n"); - dst_reduce_tile.dump_value(); - } - - __syncthreads(); - - compute::SumReduce row_sum; - row_sum(dst_load_tile, dst_reduce_tile); - - __syncthreads(); - - if (thread(0)) { - printf("Row Sum:\n"); - printf("Thread 0:\n"); - dst_reduce_tile.dump_value(); - } - - if (thread(1)) { - printf("Thread 1:\n"); - dst_reduce_tile.dump_value(); - } - - if (thread(4)) { - printf("Thread 4:\n"); - dst_reduce_tile.dump_value(); - } - - if (thread(8)) { - printf("Thread 8:\n"); - dst_reduce_tile.dump_value(); - } + using SrcLoadTile = GlobalTile; + using DstLoadTile = RegTile; + using SrcReduceTile = DstLoadTile; + using DstReduceTile = RegTile>; + + SrcLoadTile src_load_tile(src); + DstLoadTile dst_load_tile; + DstReduceTile dst_reduce_tile; + + // Load data from global memory to register file + copy::GlobalToRegLoader loader; + loader(src_load_tile, dst_load_tile); + __syncthreads(); + + // Execute reduce operation. + compute::MaxReduce row_max; + row_max(dst_load_tile, dst_reduce_tile); + + __syncthreads(); + + if (thread(0)) { + printf("Row Max:\n"); + printf("Thread 0:\n"); + dst_reduce_tile.dump_value(); + } + + if (thread(1)) { + printf("Thread 1:\n"); + dst_reduce_tile.dump_value(); + } + + if (thread(4)) { + printf("Thread 4:\n"); + dst_reduce_tile.dump_value(); + } + + if (thread(8)) { + printf("Thread 8:\n"); + dst_reduce_tile.dump_value(); + } + + __syncthreads(); + + compute::SumReduce row_sum; + row_sum(dst_load_tile, dst_reduce_tile); + + __syncthreads(); + + if (thread(0)) { + printf("Row Sum:\n"); + printf("Thread 0:\n"); + dst_reduce_tile.dump_value(); + } + + if (thread(1)) { + printf("Thread 1:\n"); + dst_reduce_tile.dump_value(); + } + + if (thread(4)) { + printf("Thread 4:\n"); + dst_reduce_tile.dump_value(); + } + + if (thread(8)) { + printf("Thread 8:\n"); + dst_reduce_tile.dump_value(); + } } template void run_row_major_reg_reduce() { - int kNumel = 16 * 16 * kHeight * kWidth; - int kWarpSize = tl::get_numel; + int kNumel = 16 * 16 * kHeight * kWidth; + int kWarpSize = tl::get_numel; - using ReduceLayout = tl::RowMajor; + using ReduceLayout = tl::RowMajor; - thrust::host_vector h_src(kNumel); - for (int i = 0; i < kNumel; ++i) { - h_src[i] = (Element)i; - } + thrust::host_vector h_src(kNumel); + for (int i = 0; i < kNumel; ++i) { + h_src[i] = (Element)i; + } - thrust::device_vector d_src = h_src; + thrust::device_vector d_src = h_src; - reg_reduce - <<<1, 32 * kWarpSize>>>(thrust::raw_pointer_cast(d_src.data())); + reg_reduce + <<<1, 32 * kWarpSize>>>(thrust::raw_pointer_cast(d_src.data())); } template void run_col_major_reg_reduce() { - int kNumel = 16 * 16 * kHeight * kWidth; - int kWarpSize = tl::get_numel; + int kNumel = 16 * 16 * kHeight * kWidth; + int kWarpSize = tl::get_numel; - using ReduceLayout = tl::ColMajor<2, kWidth>; + using ReduceLayout = tl::ColMajor<2, kWidth>; - thrust::host_vector h_src(kNumel); - for (int i = 0; i < kNumel; ++i) { - h_src[i] = (Element)i; - } + thrust::host_vector h_src(kNumel); + for (int i = 0; i < kNumel; ++i) { + h_src[i] = (Element)i; + } - thrust::device_vector d_src = h_src; + thrust::device_vector d_src = h_src; - reg_reduce - <<<1, 32 * kWarpSize>>>(thrust::raw_pointer_cast(d_src.data())); + reg_reduce + <<<1, 32 * kWarpSize>>>(thrust::raw_pointer_cast(d_src.data())); } TEST(TestRegReduce, row_major_reg_reduce_0) { - const int kHeight = 1; - const int kWidth = 1; - using Element = float; - using WarpLayout = tl::RowMajor<1, 1>; - using RegLayout = tl::RowMajor; + const int kHeight = 1; + const int kWidth = 1; + using Element = float; + using WarpLayout = tl::RowMajor<1, 1>; + using RegLayout = tl::RowMajor; - const copy::WarpReuse kMode = copy::WarpReuse::kCont; + const copy::WarpReuse kMode = copy::WarpReuse::kCont; - using GlobalLayout = tl::RowMajor<16 * kHeight, 16 * kWidth>; + using GlobalLayout = tl::RowMajor<16 * kHeight, 16 * kWidth>; - run_row_major_reg_reduce, WarpLayout, - tl::Layout::kRowMajor, kMode, kHeight, kWidth>(); + run_row_major_reg_reduce, WarpLayout, + tl::Layout::kRowMajor, kMode, kHeight, kWidth>(); } TEST(TestRegReduce, row_major_reg_reduce_1) { - const int kHeight = 2; - const int kWidth = 2; - using Element = float; - using WarpLayout = tl::RowMajor<1, 1>; - using RegLayout = tl::RowMajor; + const int kHeight = 2; + const int kWidth = 2; + using Element = float; + using WarpLayout = tl::RowMajor<1, 1>; + using RegLayout = tl::RowMajor; - const copy::WarpReuse kMode = copy::WarpReuse::kCont; + const copy::WarpReuse kMode = copy::WarpReuse::kCont; - using GlobalLayout = tl::RowMajor<16 * kHeight, 16 * kWidth>; + using GlobalLayout = tl::RowMajor<16 * kHeight, 16 * kWidth>; - run_row_major_reg_reduce, WarpLayout, - tl::Layout::kRowMajor, kMode, kHeight, kWidth>(); + run_row_major_reg_reduce, WarpLayout, + tl::Layout::kRowMajor, kMode, kHeight, kWidth>(); } TEST(TestRegReduce, col_major_reg_reduce_0) { - const int kHeight = 1; - const int kWidth = 1; - using Element = float; - using WarpLayout = tl::ColMajor<1, 1>; - using RegLayout = tl::ColMajor; + const int kHeight = 1; + const int kWidth = 1; + using Element = float; + using WarpLayout = tl::ColMajor<1, 1>; + using RegLayout = tl::ColMajor; - const copy::WarpReuse kMode = copy::WarpReuse::kCont; + const copy::WarpReuse kMode = copy::WarpReuse::kCont; - using GlobalLayout = tl::ColMajor<16 * kHeight, 16 * kWidth>; + using GlobalLayout = tl::ColMajor<16 * kHeight, 16 * kWidth>; - run_col_major_reg_reduce, WarpLayout, - tl::Layout::kColMajor, kMode, kHeight, kWidth>(); + run_col_major_reg_reduce, WarpLayout, + tl::Layout::kColMajor, kMode, kHeight, kWidth>(); } TEST(TestRegReduce, col_major_reg_reduce_1) { - const int kHeight = 2; - const int kWidth = 2; - using Element = float; - using WarpLayout = tl::ColMajor<1, 1>; - using RegLayout = tl::ColMajor; + const int kHeight = 2; + const int kWidth = 2; + using Element = float; + using WarpLayout = tl::ColMajor<1, 1>; + using RegLayout = tl::ColMajor; - const copy::WarpReuse kMode = copy::WarpReuse::kCont; + const copy::WarpReuse kMode = copy::WarpReuse::kCont; - using GlobalLayout = tl::ColMajor<16 * kHeight, 16 * kWidth>; + using GlobalLayout = tl::ColMajor<16 * kHeight, 16 * kWidth>; - run_col_major_reg_reduce, WarpLayout, - tl::Layout::kColMajor, kMode, kHeight, kWidth>(); + run_col_major_reg_reduce, WarpLayout, + tl::Layout::kColMajor, kMode, kHeight, kWidth>(); } } // namespace tilefusion::testing diff --git a/tests/cpp/cell/test_s2r_copy.cu b/tests/cpp/cell/test_s2r_copy.cu index 8f519734..64d1c302 100644 --- a/tests/cpp/cell/test_s2r_copy.cu +++ b/tests/cpp/cell/test_s2r_copy.cu @@ -16,7 +16,7 @@ namespace tl = tile_layout; namespace { template __device__ void init_value(Element* data, int numel) { - for (int i = 0; i < numel; ++i) data[i] = static_cast(i % 2048); + for (int i = 0; i < numel; ++i) data[i] = static_cast(i % 2048); } template @@ -24,82 +24,81 @@ __device__ bool check_results(const Element* data, int numel); template <> __device__ bool check_results(const __half* data, int numel) { - const float epsilon = 1e-4; - bool pass_test = true; - - float v = 0.; - float diff = 0.; - for (int i = 0; i < numel; ++i) { - v = static_cast(i % 2048); - diff = abs(__half2float(data[i]) - v); - if (diff > epsilon) { - printf("Error data[%d]; Expected: %.0f, Got: %.0f\n", i, v, - __half2float(data[i])); - pass_test = false; - } + const float epsilon = 1e-4; + bool pass_test = true; + + float v = 0.; + float diff = 0.; + for (int i = 0; i < numel; ++i) { + v = static_cast(i % 2048); + diff = abs(__half2float(data[i]) - v); + if (diff > epsilon) { + printf("Error data[%d]; Expected: %.0f, Got: %.0f\n", i, v, + __half2float(data[i])); + pass_test = false; } + } - return pass_test; + return pass_test; } template <> __device__ bool check_results(const float* data, int numel) { - const float epsilon = 1e-4; - bool pass_test = true; - - for (int i = 0; i < numel; ++i) { - float v = float(i % 2048); - if (abs(data[i] - v) > epsilon) { - printf("Error data[%d]; Expected: %.0f, Got: %.0f\n", i, v, - data[i]); - pass_test = false; - } + const float epsilon = 1e-4; + bool pass_test = true; + + for (int i = 0; i < numel; ++i) { + float v = float(i % 2048); + if (abs(data[i] - v) > epsilon) { + printf("Error data[%d]; Expected: %.0f, Got: %.0f\n", i, v, data[i]); + pass_test = false; } + } - return pass_test; + return pass_test; } template __global__ void run_test_load(Copy& copy) { - extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; - auto* buf = reinterpret_cast(buf_); - init_value(buf, Shared::kNumel); + extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; + auto* buf = reinterpret_cast(buf_); + init_value(buf, Shared::kNumel); - Shared s_tile(buf); - Reg r_tile; + Shared s_tile(buf); + Reg r_tile; - copy(s_tile, r_tile); + copy(s_tile, r_tile); #if defined(DEBUG) - if (thread0()) { - r_tile.dump_value(); - } + if (thread0()) { + r_tile.dump_value(); + } #endif } template __global__ void run_test_store(Loader& loader, Storer& storer) { - using DType = typename Shared::DType; + using DType = typename Shared::DType; - extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; - auto* buf = reinterpret_cast(buf_); - init_value(buf, Shared::kNumel); + extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; + auto* buf = reinterpret_cast(buf_); + init_value(buf, Shared::kNumel); - Shared s_tile(buf); - Reg r_tile; + Shared s_tile(buf); + Reg r_tile; - loader(s_tile, r_tile); // load from shared to register - __syncthreads(); + loader(s_tile, r_tile); // load from shared to register + __syncthreads(); - memset(buf, 0, Shared::kNumel * sizeof(DType)); // clean the shared memory + memset(buf, 0, Shared::kNumel * sizeof(DType)); // clean the shared memory - // the reverse operation, store from register to shared - storer(r_tile, s_tile); - __syncthreads(); + // the reverse operation, store from register to shared + storer(r_tile, s_tile); + __syncthreads(); - if (thread0()) { - assert(check_results(buf, Shared::kNumel)); - } + if (thread0()) { + assert(check_results(buf, Shared::kNumel)); + } } template __global__ void run_test_store_float(ConvertHalf& convert, Loader& loader, Storer& storer) { - extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; + extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; - // load half data to shared memory - using DType = typename SharedHalf::DType; - auto* buf = reinterpret_cast(buf_); - init_value(buf, SharedHalf::kNumel); + // load half data to shared memory + using DType = typename SharedHalf::DType; + auto* buf = reinterpret_cast(buf_); + init_value(buf, SharedHalf::kNumel); - // store buffer on shared memory for storing tcu's output register tile - using AccType = typename SharedFloat::DType; - auto* store_buf = reinterpret_cast(buf + SharedHalf::kNumel); - memset(store_buf, 0, SharedFloat::kNumel * sizeof(AccType)); + // store buffer on shared memory for storing tcu's output register tile + using AccType = typename SharedFloat::DType; + auto* store_buf = reinterpret_cast(buf + SharedHalf::kNumel); + memset(store_buf, 0, SharedFloat::kNumel * sizeof(AccType)); - SharedHalf sh_tile(buf); - RegHalf rh_tile; - RegFloat rf_tile; + SharedHalf sh_tile(buf); + RegHalf rh_tile; + RegFloat rf_tile; - loader(sh_tile, rh_tile); // load from shared to register - __syncthreads(); + loader(sh_tile, rh_tile); // load from shared to register + __syncthreads(); - convert(rh_tile, rf_tile); + convert(rh_tile, rf_tile); #if defined(DEBUG) - if (thread0()) { - printf("register tile:\n"); - rh_tile.dump_value(); - - printf("converted register tile:\n"); - rf_tile.dump_value(); - printf("\n"); - } + if (thread0()) { + printf("register tile:\n"); + rh_tile.dump_value(); + + printf("converted register tile:\n"); + rf_tile.dump_value(); + printf("\n"); + } #endif - SharedFloat sf_tile(store_buf); + SharedFloat sf_tile(store_buf); - // the reverse operation, store from register to shared - storer(rf_tile, sf_tile); - __syncthreads(); + // the reverse operation, store from register to shared + storer(rf_tile, sf_tile); + __syncthreads(); - if (thread0()) { - assert(check_results(store_buf, SharedFloat::kNumel)); - } + if (thread0()) { + assert(check_results(store_buf, SharedFloat::kNumel)); + } } } // namespace TEST(TestShared2Reg, operand_A) { // load mode for loading operand A in gemm - using Element = __half; - - using WarpLayout = tl::RowMajor<1, 1>; - const int kThreads = tl::get_numel * 32; - - using Shared = SharedTile>; - // Each thread accesses 2x4 elements (the shape of `BaseHalfTileRowMajor`) - // within a 16x16 `BaseTile`. These 2x4 elements are accessed 2x2 times - // along each dimension, contributing to the final register tile handled by - // a single thread. - using Reg = RegTile, tl::RowMajor<4, 4>>; - - // In the `RowReuseCont` mode, warps in the same row repeatedly access the - // same data. - using Copy = SharedToRegLoader; - Copy copy; - - dim3 dim_grid(1, 1, 1); - dim3 dim_block(kThreads, 1, 1); - int shm_size = Shared::kNumel * sizeof(Element); - - run_test_load - <<>>(copy); - cudaDeviceSynchronize(); + using Element = __half; + + using WarpLayout = tl::RowMajor<1, 1>; + const int kThreads = tl::get_numel * 32; + + using Shared = SharedTile>; + // Each thread accesses 2x4 elements (the shape of `BaseHalfTileRowMajor`) + // within a 16x16 `BaseTile`. These 2x4 elements are accessed 2x2 times + // along each dimension, contributing to the final register tile handled by + // a single thread. + using Reg = RegTile, tl::RowMajor<4, 4>>; + + // In the `RowReuseCont` mode, warps in the same row repeatedly access the + // same data. + using Copy = SharedToRegLoader; + Copy copy; + + dim3 dim_grid(1, 1, 1); + dim3 dim_block(kThreads, 1, 1); + int shm_size = Shared::kNumel * sizeof(Element); + + run_test_load + <<>>(copy); + cudaDeviceSynchronize(); } TEST(TestShared2Reg, operand_B) { // load mode for loading operand B in gemm - using Element = __half; - - using WarpLayout = tl::RowMajor<1, 1>; - const int kThreads = tl::get_numel * 32; - - // a 32x64 row-major shared tile is equivalent to a 64x32 col-major tile - // using Shared = SharedTile>; - using Shared = SharedTile>; - - // Each thread accesses 4x2 elements (the shape of `BaseHalfTileRowMajor`) - // within a 16x16 `BaseTile`. These 4x2 elements are accessed 2x2 times - // along each dimension, contributing to the final register tile handled by - // a single thread. - using Reg = RegTile, tl::ColMajor<2, 2>>; - // In the `ColReuseCont` mode, warps in the same column repeatedly access - // the same data. - using Copy = SharedToRegLoader; - Copy copy; - - dim3 dim_grid(1, 1, 1); - dim3 dim_block(kThreads, 1, 1); - int shm_size = Shared::kNumel * sizeof(Element); - - run_test_load - <<>>(copy); - cudaDeviceSynchronize(); + using Element = __half; + + using WarpLayout = tl::RowMajor<1, 1>; + const int kThreads = tl::get_numel * 32; + + // a 32x64 row-major shared tile is equivalent to a 64x32 col-major tile + // using Shared = SharedTile>; + using Shared = SharedTile>; + + // Each thread accesses 4x2 elements (the shape of `BaseHalfTileRowMajor`) + // within a 16x16 `BaseTile`. These 4x2 elements are accessed 2x2 times + // along each dimension, contributing to the final register tile handled by + // a single thread. + using Reg = RegTile, tl::ColMajor<2, 2>>; + // In the `ColReuseCont` mode, warps in the same column repeatedly access + // the same data. + using Copy = SharedToRegLoader; + Copy copy; + + dim3 dim_grid(1, 1, 1); + dim3 dim_block(kThreads, 1, 1); + int shm_size = Shared::kNumel * sizeof(Element); + + run_test_load + <<>>(copy); + cudaDeviceSynchronize(); } TEST(TestReg2Shared, operand_C_half) { - using Element = __half; + using Element = __half; - using WarpLayout = tl::RowMajor<1, 1>; - const int kThreads = tl::get_numel * 32; + using WarpLayout = tl::RowMajor<1, 1>; + const int kThreads = tl::get_numel * 32; - using Shared = SharedTile>; - using Reg = RegTile, tl::RowMajor<4, 8>>; + using Shared = SharedTile>; + using Reg = RegTile, tl::RowMajor<4, 8>>; - using Loader = SharedToRegLoader; - Loader loader; + using Loader = SharedToRegLoader; + Loader loader; - using Storer = RegToSharedStorer; - Storer storer; + using Storer = RegToSharedStorer; + Storer storer; - dim3 dim_grid(1, 1, 1); - dim3 dim_block(kThreads, 1, 1); - int shm_size = Shared::kNumel * sizeof(Element); + dim3 dim_grid(1, 1, 1); + dim3 dim_block(kThreads, 1, 1); + int shm_size = Shared::kNumel * sizeof(Element); - run_test_store - <<>>(loader, storer); - cudaDeviceSynchronize(); + run_test_store + <<>>(loader, storer); + cudaDeviceSynchronize(); } TEST(TestShared2Reg, operand_A_swizzle) { - using Element = __half; + using Element = __half; - using WarpLayout = tl::RowMajor<1, 1>; - const int kThreads = tl::get_numel * 32; + using WarpLayout = tl::RowMajor<1, 1>; + const int kThreads = tl::get_numel * 32; - const int kRows = 16; - const int kCols = 64; + const int kRows = 16; + const int kCols = 64; - using SharedLayout = tl::RowMajor; - const bool kUseSwizzledLayout = true; - using Shared = SharedTile; - using Reg = RegTile, tl::RowMajor<1, 4>>; + using SharedLayout = tl::RowMajor; + const bool kUseSwizzledLayout = true; + using Shared = SharedTile; + using Reg = RegTile, tl::RowMajor<1, 4>>; - using Copy = SharedToRegLoader; - Copy copy; + using Copy = SharedToRegLoader; + Copy copy; - dim3 dim_grid(1, 1, 1); - dim3 dim_block(kThreads, 1, 1); - int shm_size = Shared::kNumel * sizeof(Element); + dim3 dim_grid(1, 1, 1); + dim3 dim_block(kThreads, 1, 1); + int shm_size = Shared::kNumel * sizeof(Element); - run_test_load - <<>>(copy); - cudaDeviceSynchronize(); + run_test_load + <<>>(copy); + cudaDeviceSynchronize(); } TEST(TestReg2Shared, operand_C_float) { - using Element = __half; - using AccType = float; - - const int kRowRepeats = 4; - const int kColRepeats = 8; - const int kRows = 16 * kRowRepeats; - const int kCols = 16 * kColRepeats; - - const int kWarpPerRow = 1; - const int kWarpPerCol = 1; - using WarpLayout = tl::RowMajor; - const int kThreads = tl::get_numel * 32; - - using SharedHalf = SharedTile>; - using RegHalf = RegTile< - BaseTileRowMajor, - tl::RowMajor>; - - using SharedFloat = SharedTile>; - using RegFloat = RegTile< - BaseTileRowMajor, - tl::RowMajor>; - - using ConvertHalf = compute::RegTileConvert; - ConvertHalf convert; - - using Loader = SharedToRegLoader; - Loader loader; - - using Storer = RegToSharedStorer; - Storer storer; - - dim3 dim_grid(1, 1, 1); - dim3 dim_block(kThreads, 1, 1); - int shm_size = SharedHalf::kNumel * sizeof(Element) + - SharedFloat::kNumel * sizeof(AccType); - - run_test_store_float - <<>>(convert, loader, storer); - cudaDeviceSynchronize(); + using Element = __half; + using AccType = float; + + const int kRowRepeats = 4; + const int kColRepeats = 8; + const int kRows = 16 * kRowRepeats; + const int kCols = 16 * kColRepeats; + + const int kWarpPerRow = 1; + const int kWarpPerCol = 1; + using WarpLayout = tl::RowMajor; + const int kThreads = tl::get_numel * 32; + + using SharedHalf = SharedTile>; + using RegHalf = RegTile< + BaseTileRowMajor, + tl::RowMajor>; + + using SharedFloat = SharedTile>; + using RegFloat = RegTile< + BaseTileRowMajor, + tl::RowMajor>; + + using ConvertHalf = compute::RegTileConvert; + ConvertHalf convert; + + using Loader = SharedToRegLoader; + Loader loader; + + using Storer = RegToSharedStorer; + Storer storer; + + dim3 dim_grid(1, 1, 1); + dim3 dim_block(kThreads, 1, 1); + int shm_size = SharedHalf::kNumel * sizeof(Element) + + SharedFloat::kNumel * sizeof(AccType); + + run_test_store_float + <<>>(convert, loader, storer); + cudaDeviceSynchronize(); } } // namespace tilefusion::testing diff --git a/tests/cpp/cell/test_single_wmma.cu b/tests/cpp/cell/test_single_wmma.cu index 24dfbfa1..89a624cd 100644 --- a/tests/cpp/cell/test_single_wmma.cu +++ b/tests/cpp/cell/test_single_wmma.cu @@ -19,61 +19,59 @@ namespace { // row-major, and B is interpreted as being laid out in column-major. __device__ void naive_gemm(int kM, int kN, int kK, // const __half* A, const __half* B, float* C) { - if (!thread0()) return; - - for (int i = 0; i < kM; ++i) { - for (int j = 0; j < kN; ++j) { - float s = 0.; - for (int k = 0; k < kK; ++k) { - s += __half2float(A[i * kK + k]) * __half2float(B[k + kK * j]); - } - C[i * kN + j] = s; - } + if (!thread0()) return; + + for (int i = 0; i < kM; ++i) { + for (int j = 0; j < kN; ++j) { + float s = 0.; + for (int k = 0; k < kK; ++k) { + s += __half2float(A[i * kK + k]) * __half2float(B[k + kK * j]); + } + C[i * kN + j] = s; } + } } __device__ void check_results(const float* hc1, const float* hc2, int numel) { - for (int i = 0; i < numel; ++i) { - if (fabs(hc1[i] - hc2[i]) > 1e-3) { - printf("error: %d, %.4f, %.4f\n", i, hc1[i], hc2[i]); - printf("test failed!\n"); - return; - } + for (int i = 0; i < numel; ++i) { + if (fabs(hc1[i] - hc2[i]) > 1e-3) { + printf("error: %d, %.4f, %.4f\n", i, hc1[i], hc2[i]); + printf("test failed!\n"); + return; } + } - printf("test passed!\n"); + printf("test passed!\n"); #if defined(DEBUG) - if (thread0()) { - int cut_off = numel < 128 ? numel : 128; - printf("\nours:\n"); - printf("%d:\t", 0); - for (int i = 0; i < cut_off; i++) { - printf("%.2f, ", hc1[i]); - if (i & (i + 1) % 16 == 0) printf("\n%d:\t", (i + 1) / 16); - } - printf("\nground-truth:\n"); - printf("%d:\t", 0); - for (int i = 0; i < cut_off; i++) { - printf("%.2f, ", hc2[i]); - if (i & (i + 1) % 16 == 0) printf("\n%d:\t", (i + 1) / 16); - } + if (thread0()) { + int cut_off = numel < 128 ? numel : 128; + printf("\nours:\n"); + printf("%d:\t", 0); + for (int i = 0; i < cut_off; i++) { + printf("%.2f, ", hc1[i]); + if (i & (i + 1) % 16 == 0) printf("\n%d:\t", (i + 1) / 16); } + printf("\nground-truth:\n"); + printf("%d:\t", 0); + for (int i = 0; i < cut_off; i++) { + printf("%.2f, ", hc2[i]); + if (i & (i + 1) % 16 == 0) printf("\n%d:\t", (i + 1) / 16); + } + } #endif } template __device__ void init_values(Element* a, Element* b, ElementAcc* c, int M, int N, int K) { - if (!thread0()) return; + if (!thread0()) return; - for (int i = 0; i < M * K; ++i) - a[i] = static_cast(i % 2048 / 1000); + for (int i = 0; i < M * K; ++i) a[i] = static_cast(i % 2048 / 1000); - for (int i = 0; i < K * N; ++i) - b[i] = static_cast(i % 2048 / 1000); + for (int i = 0; i < K * N; ++i) b[i] = static_cast(i % 2048 / 1000); - for (int i = 0; i < M * N; ++i) c[i] = 0.; + for (int i = 0; i < M * N; ++i) c[i] = 0.; } template __global__ void test_wmma(LoadRegA& load_rA, LoadRegB& load_rB, StoreRegC& store_rC) { - extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; - auto* shared_a = reinterpret_cast(buf_); - auto* shared_b = shared_a + TileIteratorA::Tile::kNumel; - auto* shared_c = - reinterpret_cast(shared_b + TileIteratorB::Tile::kNumel); - auto shared_ref = - shared_c + TileIteratorA::Tile::kRows * TileIteratorB::Tile::kCols; + extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; + auto* shared_a = reinterpret_cast(buf_); + auto* shared_b = shared_a + TileIteratorA::Tile::kNumel; + auto* shared_c = + reinterpret_cast(shared_b + TileIteratorB::Tile::kNumel); + auto shared_ref = + shared_c + TileIteratorA::Tile::kRows * TileIteratorB::Tile::kCols; - init_values(shared_a, shared_b, shared_c, M, N, K); + init_values(shared_a, shared_b, shared_c, M, N, K); - __syncthreads(); + __syncthreads(); - SharedC sC(shared_c); + SharedC sC(shared_c); - TileIteratorA sAs(shared_a); - TileIteratorB sBs(shared_b); + TileIteratorA sAs(shared_a); + TileIteratorB sBs(shared_b); - RegA rA; - RegB rB; - RegC acc; + RegA rA; + RegB rB; + RegC acc; - for (int k = 0; k < TileIteratorA::sc1; ++k) { - auto sA = sAs(k); - auto sB = sBs(k); + for (int k = 0; k < TileIteratorA::sc1; ++k) { + auto sA = sAs(k); + auto sB = sBs(k); - load_rA(sA, rA); - load_rB(sB, rB); - __syncthreads(); + load_rA(sA, rA); + load_rB(sB, rB); + __syncthreads(); - gemm(rA, rB, acc); - } + gemm(rA, rB, acc); + } - __syncthreads(); + __syncthreads(); - store_rC(acc, sC); - __syncthreads(); + store_rC(acc, sC); + __syncthreads(); - if (thread0()) { - __half* dA = reinterpret_cast<__half*>(shared_a); - __half* dB = reinterpret_cast<__half*>(shared_b); - float* dC = reinterpret_cast(shared_ref); - naive_gemm(M, N, K, dA, dB, dC); + if (thread0()) { + __half* dA = reinterpret_cast<__half*>(shared_a); + __half* dB = reinterpret_cast<__half*>(shared_b); + float* dC = reinterpret_cast(shared_ref); + naive_gemm(M, N, K, dA, dB, dC); - check_results(dC, shared_c, M * N); - } + check_results(dC, shared_c, M * N); + } } } // namespace @@ -135,103 +133,101 @@ __global__ void test_wmma(LoadRegA& load_rA, LoadRegB& load_rB, template struct TestTraits { - static const int kThreads = tl::get_numel * 32; + static const int kThreads = tl::get_numel * 32; - static constexpr int kWarpPerRow = tl::num_rows; - static constexpr int kWarpPerCol = tl::num_cols; + static constexpr int kWarpPerRow = tl::num_rows; + static constexpr int kWarpPerCol = tl::num_cols; - // ============= shared to register loader ================= - using MmaAtom = MmaAtom; - using BaseShape = MmaAtom::BaseTile; + // ============= shared to register loader ================= + using MmaAtom = MmaAtom; + using BaseShape = MmaAtom::BaseTile; - static constexpr int kAMs = kM / kWarpPerRow / BaseShape::kRows; - static constexpr int kAKs = kK / BaseShape::kCols; + static constexpr int kAMs = kM / kWarpPerRow / BaseShape::kRows; + static constexpr int kAKs = kK / BaseShape::kCols; - static constexpr int kBKs = kK / BaseShape::kRows; - static constexpr int kBNs = kN / kWarpPerCol / BaseShape::kCols; + static constexpr int kBKs = kK / BaseShape::kRows; + static constexpr int kBNs = kN / kWarpPerCol / BaseShape::kCols; - static constexpr int kCMs = kM / kWarpPerRow / BaseShape::kRows; - static constexpr int kCNs = kN / kWarpPerCol / BaseShape::kCols; + static constexpr int kCMs = kM / kWarpPerRow / BaseShape::kRows; + static constexpr int kCNs = kN / kWarpPerCol / BaseShape::kCols; - using SharedA = SharedTile>; - using TileIteratorA = STileIterator>; + using SharedA = SharedTile>; + using TileIteratorA = STileIterator>; - using RegA = RegTile, tl::RowMajor>; - using LoadRegA = - SharedToRegLoader; + using RegA = RegTile, tl::RowMajor>; + using LoadRegA = + SharedToRegLoader; - using SharedB = SharedTile>; + using SharedB = SharedTile>; - using RegB = RegTile, tl::ColMajor>; - using TileIteratorB = STileIterator>; - using LoadRegB = - SharedToRegLoader; + using RegB = RegTile, tl::ColMajor>; + using TileIteratorB = STileIterator>; + using LoadRegB = + SharedToRegLoader; - static_assert(TileIteratorA::sc1 == TileIteratorB::sc0, - "dimension mismatch!"); + static_assert(TileIteratorA::sc1 == TileIteratorB::sc0, + "dimension mismatch!"); - // ============= register to shared storer ================= - using SharedC = SharedTile>; - using RegC = - RegTile, tl::RowMajor>; - using StoreRegC = RegToSharedStorer; + // ============= register to shared storer ================= + using SharedC = SharedTile>; + using RegC = RegTile, tl::RowMajor>; + using StoreRegC = RegToSharedStorer; }; template void run_test() { - using Element = __half; - using ElementAcc = float; + using Element = __half; + using ElementAcc = float; - using config = TestTraits; + using config = TestTraits; - dim3 dim_grid(1, 1, 1); - dim3 dim_block(config::kThreads, 1, 1); + dim3 dim_grid(1, 1, 1); + dim3 dim_block(config::kThreads, 1, 1); - typename config::LoadRegA load_rA; - typename config::LoadRegB load_rB; - typename config::StoreRegC store_rC; + typename config::LoadRegA load_rA; + typename config::LoadRegB load_rB; + typename config::StoreRegC store_rC; - using RegA = typename config::RegA; - using RegB = typename config::RegB; - using RegC = typename config::RegC; + using RegA = typename config::RegA; + using RegB = typename config::RegB; + using RegC = typename config::RegC; - LOG(INFO) << std::endl - << "RegA: " << RegA{} << std::endl - << "RegB: " << RegB{} << std::endl - << "RegC: " << RegC{} << std::endl; + LOG(INFO) << std::endl + << "RegA: " << RegA{} << std::endl + << "RegB: " << RegB{} << std::endl + << "RegC: " << RegC{} << std::endl; - auto kernel = - test_wmma; + auto kernel = + test_wmma; - int shm_size = - (kM + kN) * kK * sizeof(Element) + 2 * kM * kN * sizeof(ElementAcc); + int shm_size = + (kM + kN) * kK * sizeof(Element) + 2 * kM * kN * sizeof(ElementAcc); - if (shm_size > 48 * 1024) { - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size); - } + if (shm_size > 48 * 1024) { + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + shm_size); + } - kernel<<>>(load_rA, load_rB, store_rC); + kernel<<>>(load_rA, load_rB, store_rC); - cudaDeviceSynchronize(); + cudaDeviceSynchronize(); - LOG(INFO) << "[" << kM << ", " << kN << ", " << kK << "]. Test passed!" - << std::endl; + LOG(INFO) << "[" << kM << ", " << kN << ", " << kK << "]. Test passed!" + << std::endl; } TEST(TestWmma, test_m16n16k16_f) { - run_test<16, 16, 64, tl::RowMajor<1, 1>>(); - run_test<32, 32, 64, tl::RowMajor<1, 1>>(); - run_test<64, 64, 64, tl::RowMajor<1, 1>>(); - - run_test<32, 64, 64, tl::RowMajor<2, 1>>(); - run_test<32, 64, 64, tl::RowMajor<2, 1>>(); - run_test<64, 64, 64, tl::RowMajor<2, 1>>(); - run_test<128, 64, 64, tl::RowMajor<2, 1>>(); + run_test<16, 16, 64, tl::RowMajor<1, 1>>(); + run_test<32, 32, 64, tl::RowMajor<1, 1>>(); + run_test<64, 64, 64, tl::RowMajor<1, 1>>(); + + run_test<32, 64, 64, tl::RowMajor<2, 1>>(); + run_test<32, 64, 64, tl::RowMajor<2, 1>>(); + run_test<64, 64, 64, tl::RowMajor<2, 1>>(); + run_test<128, 64, 64, tl::RowMajor<2, 1>>(); } } // namespace tilefusion::testing diff --git a/tests/cpp/cell/test_swizzled_copy.cu b/tests/cpp/cell/test_swizzled_copy.cu index 7dcd8611..82744ff7 100644 --- a/tests/cpp/cell/test_swizzled_copy.cu +++ b/tests/cpp/cell/test_swizzled_copy.cu @@ -21,22 +21,22 @@ namespace tl = tile_layout; namespace { template __device__ void init_value(Element* data, int numel) { - for (int i = 0; i < numel; ++i) data[i] = static_cast(0.); + for (int i = 0; i < numel; ++i) data[i] = static_cast(0.); } template DEVICE void check_results(const Reg& r_tile, const Reg& r_tile_swizzled, int rows, int cols) { - const int numel = BaseTileRowMajor::kNumel; + const int numel = BaseTileRowMajor::kNumel; - for (int i = 0; i < rows; ++i) { - for (int j = 0; j < cols; ++j) { - const DType* data1 = r_tile(i, j).data(); - const DType* data2 = r_tile_swizzled(i, j).data(); + for (int i = 0; i < rows; ++i) { + for (int j = 0; j < cols; ++j) { + const DType* data1 = r_tile(i, j).data(); + const DType* data2 = r_tile_swizzled(i, j).data(); - for (int n = 0; n < numel; ++n) assert(data1[n] == data2[n]); - } + for (int n = 0; n < numel; ++n) assert(data1[n] == data2[n]); } + } } template __global__ void swizzled_copy(const Element* data, G2S1& g2s, G2S2& g2s_swizzled, S2R& s2r) { - extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; - auto* buf = reinterpret_cast(buf_); - init_value(buf, Shared1::kNumel + Shared2::kNumel); + extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; + auto* buf = reinterpret_cast(buf_); + init_value(buf, Shared1::kNumel + Shared2::kNumel); - GIterator g_tiles(data); + GIterator g_tiles(data); - Shared1 s_tile(buf); - Shared2 s_swizzled_tile(buf + Shared1::kNumel); + Shared1 s_tile(buf); + Shared2 s_swizzled_tile(buf + Shared1::kNumel); - Reg r_tile; - Reg r_tile_swizzled; + Reg r_tile; + Reg r_tile_swizzled; - SIterator1 s_tiles(buf); - SIterator2 s_swizzled_tiles(buf + Shared1::kNumel); + SIterator1 s_tiles(buf); + SIterator2 s_swizzled_tiles(buf + Shared1::kNumel); - for (int k = 0; k < GIterator::sc1; ++k) { - g2s(g_tiles(k), s_tile); - g2s_swizzled(g_tiles(k), s_swizzled_tile); - __copy_async(); - __syncthreads(); + for (int k = 0; k < GIterator::sc1; ++k) { + g2s(g_tiles(k), s_tile); + g2s_swizzled(g_tiles(k), s_swizzled_tile); + __copy_async(); + __syncthreads(); - for (int i = 0; i < SIterator1::sc1; ++i) { - s2r(s_tiles(i), r_tile); - s2r(s_swizzled_tiles(i), r_tile_swizzled); - __syncthreads(); + for (int i = 0; i < SIterator1::sc1; ++i) { + s2r(s_tiles(i), r_tile); + s2r(s_swizzled_tiles(i), r_tile_swizzled); + __syncthreads(); #ifdef DEBUG - if (thread(0)) { - printf("\niteration [%d, %d]\n", k, i); - s_tiles(i).dump_value(); + if (thread(0)) { + printf("\niteration [%d, %d]\n", k, i); + s_tiles(i).dump_value(); - printf("\ns_swizzled_tiles:\n"); - s_swizzled_tiles(i).dump_value(); + printf("\ns_swizzled_tiles:\n"); + s_swizzled_tiles(i).dump_value(); - printf("r_tile:\n"); - r_tile.dump_value(); + printf("r_tile:\n"); + r_tile.dump_value(); - printf("\nr_tile_swizzled:\n"); - r_tile_swizzled.dump_value(); - } + printf("\nr_tile_swizzled:\n"); + r_tile_swizzled.dump_value(); + } #endif - check_results(r_tile, r_tile_swizzled, Reg::kRows, - Reg::kCols); - } + check_results(r_tile, r_tile_swizzled, Reg::kRows, + Reg::kCols); } + } } /// @brief This unit test verifies the correctness of the swizzled row-major @@ -99,93 +99,93 @@ template void run_test_rowmajor() { - static_assert(kShmRows == kRows, "kShmRows must be equal to kRows"); + static_assert(kShmRows == kRows, "kShmRows must be equal to kRows"); - using Element = __half; - const int kThreads = tl::get_numel * 32; - static constexpr int kWarpPerRow = tl::num_rows; + using Element = __half; + const int kThreads = tl::get_numel * 32; + static constexpr int kWarpPerRow = tl::num_rows; - using Global = GlobalTile>; - using GIterator = GTileIterator>; + using Global = GlobalTile>; + using GIterator = GTileIterator>; - // for non-swizzled layout - using Shared1 = SharedTile, false, - kSharedAccessInBytes>; - using SIterator1 = STileIterator>; + // for non-swizzled layout + using Shared1 = SharedTile, false, + kSharedAccessInBytes>; + using SIterator1 = STileIterator>; - // for swizzled layout - using Shared2 = SharedTile, true, - kSharedAccessInBytes>; - using SIterator2 = STileIterator>; + // for swizzled layout + using Shared2 = SharedTile, true, + kSharedAccessInBytes>; + using SIterator2 = STileIterator>; - // FIXME(ying): Address the unnatural dependency on the MMA atom caused by - // BaseShape. Future refactoring of the program's concepts and interfaces - // should eliminate this dependency. - using MmaAtom = - compute::MmaAtom<__half, __half, __half, compute::MMA_ATOM_16x16x16>; - using BaseShape = MmaAtom::BaseTile; + // FIXME(ying): Address the unnatural dependency on the MMA atom caused by + // BaseShape. Future refactoring of the program's concepts and interfaces + // should eliminate this dependency. + using MmaAtom = + compute::MmaAtom<__half, __half, __half, compute::MMA_ATOM_16x16x16>; + using BaseShape = MmaAtom::BaseTile; - const int kSc0 = kShmRows / kWarpPerRow / BaseShape::kRows; - const int kSc1 = kChunkShm / BaseShape::kCols; + const int kSc0 = kShmRows / kWarpPerRow / BaseShape::kRows; + const int kSc1 = kChunkShm / BaseShape::kCols; - using Reg = RegTile, tl::RowMajor>; + using Reg = RegTile, tl::RowMajor>; #ifdef DEBUG - LOG(INFO) << std::endl - << "GlobalTile: " << Global{} << std::endl - << "GIterator: " << GIterator{} << std::endl - << "SharedTile2: " << std::endl - << Shared1{} << std::endl - << "SIterator1: " << SIterator1{} << std::endl - << std::endl - << "SharedTile2: " << std::endl - << Shared2{} << std::endl - << "SIterator2: " << SIterator2{} << std::endl; + LOG(INFO) << std::endl + << "GlobalTile: " << Global{} << std::endl + << "GIterator: " << GIterator{} << std::endl + << "SharedTile2: " << std::endl + << Shared1{} << std::endl + << "SIterator1: " << SIterator1{} << std::endl + << std::endl + << "SharedTile2: " << std::endl + << Shared2{} << std::endl + << "SIterator2: " << SIterator2{} << std::endl; #endif - using G2S1 = GlobalToSharedLoader; - using G2S2 = GlobalToSharedLoader; - using S2R = SharedToRegLoader; - - dim3 dim_grid(1, 1, 1); - dim3 dim_block(kThreads, 1, 1); - int shm_size = (Shared1::kNumel + Shared2::kNumel) * sizeof(Element); - - const int numel = kRows * kCols; - using Element = __half; - thrust::host_vector hA(numel); - for (int i = 0; i < hA.size(); ++i) { - hA[i] = static_cast(i % 2048); - } - thrust::device_vector dA = hA; - - G2S1 g2s; - G2S2 g2s_swizzled; - S2R s2r; - - auto test_func = - &swizzled_copy; - - // maximal statically allocated smem per block - const int kMaxSmemPerBlock = 48 * 1024; - if (shm_size > kMaxSmemPerBlock) { - cudaFuncSetAttribute( - test_func, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size); - } - - test_func<<>>( - thrust::raw_pointer_cast(dA.data()), g2s, g2s_swizzled, s2r); - cudaError_t err = cudaDeviceSynchronize(); - if (err != cudaSuccess) { - fprintf(stderr, "Kernel failed: %s\n", cudaGetErrorString(err)); - exit(1); - } - - std::ostringstream ss; - ss << "[" << kRows << ", " << kCols << ", " << kShmRows << ", " << kShmCols - << ", " << kChunkShm << "]"; - LOG(INFO) << std::endl << ss.str() << " passed!" << std::endl; + using G2S1 = GlobalToSharedLoader; + using G2S2 = GlobalToSharedLoader; + using S2R = SharedToRegLoader; + + dim3 dim_grid(1, 1, 1); + dim3 dim_block(kThreads, 1, 1); + int shm_size = (Shared1::kNumel + Shared2::kNumel) * sizeof(Element); + + const int numel = kRows * kCols; + using Element = __half; + thrust::host_vector hA(numel); + for (int i = 0; i < hA.size(); ++i) { + hA[i] = static_cast(i % 2048); + } + thrust::device_vector dA = hA; + + G2S1 g2s; + G2S2 g2s_swizzled; + S2R s2r; + + auto test_func = + &swizzled_copy; + + // maximal statically allocated smem per block + const int kMaxSmemPerBlock = 48 * 1024; + if (shm_size > kMaxSmemPerBlock) { + cudaFuncSetAttribute(test_func, cudaFuncAttributeMaxDynamicSharedMemorySize, + shm_size); + } + + test_func<<>>( + thrust::raw_pointer_cast(dA.data()), g2s, g2s_swizzled, s2r); + cudaError_t err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + fprintf(stderr, "Kernel failed: %s\n", cudaGetErrorString(err)); + exit(1); + } + + std::ostringstream ss; + ss << "[" << kRows << ", " << kCols << ", " << kShmRows << ", " << kShmCols + << ", " << kChunkShm << "]"; + LOG(INFO) << std::endl << ss.str() << " passed!" << std::endl; } /// @brief This unit test verifies the correctness of the swizzled column-major @@ -194,435 +194,435 @@ template void run_test_colmajor() { - using Element = __half; - const int kThreads = tl::get_numel * 32; - static constexpr int kWarpPerCol = tl::num_cols; + using Element = __half; + const int kThreads = tl::get_numel * 32; + static constexpr int kWarpPerCol = tl::num_cols; - static_assert(kShmCols == kCols, "kShmCols must be equal to kCols."); + static_assert(kShmCols == kCols, "kShmCols must be equal to kCols."); - using Global = GlobalTile>; - using GIterator = GTileIterator>; + using Global = GlobalTile>; + using GIterator = GTileIterator>; - // for non-swizzled layout - using Shared1 = SharedTile, - false /*disable swizzled layout on shared*/, - kSharedAccessInBytes>; - using SIterator1 = STileIterator>; + // for non-swizzled layout + using Shared1 = SharedTile, + false /*disable swizzled layout on shared*/, + kSharedAccessInBytes>; + using SIterator1 = STileIterator>; - // for swizzled layout - using Shared2 = SharedTile, - true /*enable swizzled layout on shared*/, - kSharedAccessInBytes>; - using SIterator2 = STileIterator>; + // for swizzled layout + using Shared2 = SharedTile, + true /*enable swizzled layout on shared*/, + kSharedAccessInBytes>; + using SIterator2 = STileIterator>; - // FIXME(ying): Address the unnatural dependency on the MMA atom caused by - // BaseShape. Future refactoring of the program's concepts and interfaces - // should eliminate this dependency. - using MmaAtom = - compute::MmaAtom<__half, __half, __half, compute::MMA_ATOM_16x16x16>; - using BaseShape = MmaAtom::BaseTile; + // FIXME(ying): Address the unnatural dependency on the MMA atom caused by + // BaseShape. Future refactoring of the program's concepts and interfaces + // should eliminate this dependency. + using MmaAtom = + compute::MmaAtom<__half, __half, __half, compute::MMA_ATOM_16x16x16>; + using BaseShape = MmaAtom::BaseTile; - const int kSc0 = kChunkShm / BaseShape::kRows; - const int kSc1 = kShmCols / BaseShape::kCols / kWarpPerCol; + const int kSc0 = kChunkShm / BaseShape::kRows; + const int kSc1 = kShmCols / BaseShape::kCols / kWarpPerCol; - using Reg = RegTile, tl::ColMajor>; + using Reg = RegTile, tl::ColMajor>; #ifdef DEBUG - LOG(INFO) << std::endl - << "GlobalTile: " << Global{} << std::endl - << "GIterator: " << GIterator{} << std::endl - << "SharedTile2: " << std::endl - << Shared1{} << std::endl - << "SIterator1: " << SIterator1{} << std::endl - << std::endl - << "SharedTile2: " << std::endl - << Shared2{} << std::endl - << "SIterator2: " << SIterator2{} << std::endl; + LOG(INFO) << std::endl + << "GlobalTile: " << Global{} << std::endl + << "GIterator: " << GIterator{} << std::endl + << "SharedTile2: " << std::endl + << Shared1{} << std::endl + << "SIterator1: " << SIterator1{} << std::endl + << std::endl + << "SharedTile2: " << std::endl + << Shared2{} << std::endl + << "SIterator2: " << SIterator2{} << std::endl; #endif - using G2S1 = GlobalToSharedLoader; - using G2S2 = GlobalToSharedLoader; - using S2R = SharedToRegLoader; - - dim3 dim_grid(1, 1, 1); - dim3 dim_block(kThreads, 1, 1); - int shm_size = (Shared1::kNumel + Shared2::kNumel) * sizeof(Element); - - const int numel = kRows * kCols; - using Element = __half; - thrust::host_vector hA(numel); - for (int i = 0; i < hA.size(); ++i) { - hA[i] = static_cast(i % 2048); - } - thrust::device_vector dA = hA; - - G2S1 g2s; - G2S2 g2s_swizzled; - S2R s2r; - - auto test_func = - &swizzled_copy; - - // maximal statically allocated smem per block - const int kMaxSmemPerBlock = 48 * 1024; - if (shm_size > kMaxSmemPerBlock) { - cudaFuncSetAttribute( - test_func, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size); - } - - test_func<<>>( - thrust::raw_pointer_cast(dA.data()), g2s, g2s_swizzled, s2r); - cudaDeviceSynchronize(); - - std::ostringstream ss; - ss << "[" << kRows << ", " << kCols << ", " << kShmRows << ", " << kShmCols - << ", " << kChunkShm << "]"; - LOG(INFO) << std::endl << ss.str() << " passed!" << std::endl; + using G2S1 = GlobalToSharedLoader; + using G2S2 = GlobalToSharedLoader; + using S2R = SharedToRegLoader; + + dim3 dim_grid(1, 1, 1); + dim3 dim_block(kThreads, 1, 1); + int shm_size = (Shared1::kNumel + Shared2::kNumel) * sizeof(Element); + + const int numel = kRows * kCols; + using Element = __half; + thrust::host_vector hA(numel); + for (int i = 0; i < hA.size(); ++i) { + hA[i] = static_cast(i % 2048); + } + thrust::device_vector dA = hA; + + G2S1 g2s; + G2S2 g2s_swizzled; + S2R s2r; + + auto test_func = + &swizzled_copy; + + // maximal statically allocated smem per block + const int kMaxSmemPerBlock = 48 * 1024; + if (shm_size > kMaxSmemPerBlock) { + cudaFuncSetAttribute(test_func, cudaFuncAttributeMaxDynamicSharedMemorySize, + shm_size); + } + + test_func<<>>( + thrust::raw_pointer_cast(dA.data()), g2s, g2s_swizzled, s2r); + cudaDeviceSynchronize(); + + std::ostringstream ss; + ss << "[" << kRows << ", " << kCols << ", " << kShmRows << ", " << kShmCols + << ", " << kChunkShm << "]"; + LOG(INFO) << std::endl << ss.str() << " passed!" << std::endl; } template __global__ void swizzled_store(const Element* src, Element* dst) { - extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; - auto* buf = reinterpret_cast(buf_); + extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; + auto* buf = reinterpret_cast(buf_); - Loader loader; - StorerR2S storer1; - StorerS2G storer2; + Loader loader; + StorerR2S storer1; + StorerS2G storer2; - Global g_src_tile(src); - Reg r_tile; + Global g_src_tile(src); + Reg r_tile; - Shared s_tile(buf); - Global g_dst_tile(dst); + Shared s_tile(buf); + Global g_dst_tile(dst); - loader(g_src_tile, r_tile); - __syncthreads(); + loader(g_src_tile, r_tile); + __syncthreads(); - storer1(r_tile, s_tile); - __syncthreads(); + storer1(r_tile, s_tile); + __syncthreads(); - storer2(s_tile, g_dst_tile); - __syncthreads(); + storer2(s_tile, g_dst_tile); + __syncthreads(); #if defined(DEBUG) - if (thread0()) { - printf("\nglobal tile source:\n"); - g_src_tile.dump_value(); + if (thread0()) { + printf("\nglobal tile source:\n"); + g_src_tile.dump_value(); - printf("\nshared tile:\n"); - s_tile.dump_value(); + printf("\nshared tile:\n"); + s_tile.dump_value(); - printf("\nglobal tile target:\n"); - g_dst_tile.dump_value(); - } + printf("\nglobal tile target:\n"); + g_dst_tile.dump_value(); + } #endif } template void test_row_major_store() { - // FIXME(ying): Address the unnatural dependency on the MMA atom caused by - // BaseShape. Future refactoring of the program's concepts and interfaces - // should eliminate this dependency. - using MmaAtom = - compute::MmaAtom<__half, __half, __half, compute::MMA_ATOM_16x16x16>; - using BaseShape = MmaAtom::BaseTile; + // FIXME(ying): Address the unnatural dependency on the MMA atom caused by + // BaseShape. Future refactoring of the program's concepts and interfaces + // should eliminate this dependency. + using MmaAtom = + compute::MmaAtom<__half, __half, __half, compute::MMA_ATOM_16x16x16>; + using BaseShape = MmaAtom::BaseTile; - const int kThreads = tl::get_numel * 32; + const int kThreads = tl::get_numel * 32; - // define tiles - using Global = GlobalTile>; - static constexpr int kRowRepeats = - kRows / WarpLayout::kRows / BaseShape::kRows; - static constexpr int kColRepeats = - kCols / WarpLayout::kCols / BaseShape::kCols; + // define tiles + using Global = GlobalTile>; + static constexpr int kRowRepeats = + kRows / WarpLayout::kRows / BaseShape::kRows; + static constexpr int kColRepeats = + kCols / WarpLayout::kCols / BaseShape::kCols; - using Reg = RegTile, - tl::RowMajor>; - using Shared = SharedTile, kSwizzled, - kSharedAccessInBytes>; + using Reg = RegTile, + tl::RowMajor>; + using Shared = SharedTile, kSwizzled, + kSharedAccessInBytes>; - // define loader and storer - using Loader = GlobalToRegLoader; - using StorerR2S = RegToSharedStorer; - using StorerS2G = SharedToGlobalStorer; + // define loader and storer + using Loader = GlobalToRegLoader; + using StorerR2S = RegToSharedStorer; + using StorerS2G = SharedToGlobalStorer; - int numel = kRows * kCols; + int numel = kRows * kCols; - thrust::host_vector h_src(numel); - for (int i = 0; i < h_src.size(); ++i) h_src[i] = static_cast(i); + thrust::host_vector h_src(numel); + for (int i = 0; i < h_src.size(); ++i) h_src[i] = static_cast(i); - thrust::device_vector d_src = h_src; + thrust::device_vector d_src = h_src; - thrust::device_vector d_dst(numel); - thrust::fill(d_dst.begin(), d_dst.end(), static_cast(0.)); + thrust::device_vector d_dst(numel); + thrust::fill(d_dst.begin(), d_dst.end(), static_cast(0.)); - auto test_func = &swizzled_store; + auto test_func = &swizzled_store; - dim3 dim_grid(1, 1, 1); - dim3 dim_block(kThreads, 1, 1); - int shm_size = Shared::kNumel * sizeof(Element); + dim3 dim_grid(1, 1, 1); + dim3 dim_block(kThreads, 1, 1); + int shm_size = Shared::kNumel * sizeof(Element); - test_func<<>>( - thrust::raw_pointer_cast(d_src.data()), - thrust::raw_pointer_cast(d_dst.data())); - cudaDeviceSynchronize(); + test_func<<>>( + thrust::raw_pointer_cast(d_src.data()), + thrust::raw_pointer_cast(d_dst.data())); + cudaDeviceSynchronize(); - thrust::host_vector h_dst = d_dst; + thrust::host_vector h_dst = d_dst; - assert_equal(thrust::raw_pointer_cast(h_src.data()), - thrust::raw_pointer_cast(h_dst.data()), numel, 1e-4); + assert_equal(thrust::raw_pointer_cast(h_src.data()), + thrust::raw_pointer_cast(h_dst.data()), numel, 1e-4); - LOG(INFO) << "[" << kRows << ", " << kCols << "] test passed!" << std::endl; + LOG(INFO) << "[" << kRows << ", " << kCols << "] test passed!" << std::endl; }; template void test_col_major_store() { - // FIXME(ying): Address the unnatural dependency on the MMA atom caused by - // BaseShape. - // Future refactoring of the program's concepts and interfaces should - // eliminate this dependency. - using MmaAtom = - compute::MmaAtom<__half, __half, __half, compute::MMA_ATOM_16x16x16>; - using BaseShape = MmaAtom::BaseTile; - const int kThreads = tl::get_numel * 32; - - // define tiles - using Global = GlobalTile>; - static constexpr int kRowRepeats = - kRows / WarpLayout::kRows / BaseShape::kRows; - static constexpr int kColRepeats = - kCols / WarpLayout::kCols / BaseShape::kCols; - using Reg = RegTile, - tl::ColMajor>; - using Shared = SharedTile, kSwizzled>; - - // define loader and storer - using Loader = GlobalToRegLoader; - using StorerR2S = RegToSharedStorer; - using StorerS2G = SharedToGlobalStorer; - - int numel = kRows * kCols; - thrust::host_vector h_src(numel); - for (int i = 0; i < h_src.size(); ++i) - h_src[i] = static_cast(i % 2048); - thrust::device_vector d_src = h_src; - - thrust::device_vector d_dst(numel); - thrust::fill(d_dst.begin(), d_dst.end(), static_cast(0.)); - - auto test_func = &swizzled_store; - - dim3 dim_grid(1, 1, 1); - dim3 dim_block(kThreads, 1, 1); - int shm_size = Shared::kNumel * sizeof(Element); - - test_func<<>>( - thrust::raw_pointer_cast(d_src.data()), - thrust::raw_pointer_cast(d_dst.data())); - cudaDeviceSynchronize(); - - thrust::host_vector h_dst = d_dst; - - assert_equal(thrust::raw_pointer_cast(h_src.data()), - thrust::raw_pointer_cast(h_dst.data()), numel, 1e-4); - - LOG(INFO) << "[" << kRows << ", " << kCols << "] test passed!" << std::endl; + // FIXME(ying): Address the unnatural dependency on the MMA atom caused by + // BaseShape. + // Future refactoring of the program's concepts and interfaces should + // eliminate this dependency. + using MmaAtom = + compute::MmaAtom<__half, __half, __half, compute::MMA_ATOM_16x16x16>; + using BaseShape = MmaAtom::BaseTile; + const int kThreads = tl::get_numel * 32; + + // define tiles + using Global = GlobalTile>; + static constexpr int kRowRepeats = + kRows / WarpLayout::kRows / BaseShape::kRows; + static constexpr int kColRepeats = + kCols / WarpLayout::kCols / BaseShape::kCols; + using Reg = RegTile, + tl::ColMajor>; + using Shared = SharedTile, kSwizzled>; + + // define loader and storer + using Loader = GlobalToRegLoader; + using StorerR2S = RegToSharedStorer; + using StorerS2G = SharedToGlobalStorer; + + int numel = kRows * kCols; + thrust::host_vector h_src(numel); + for (int i = 0; i < h_src.size(); ++i) + h_src[i] = static_cast(i % 2048); + thrust::device_vector d_src = h_src; + + thrust::device_vector d_dst(numel); + thrust::fill(d_dst.begin(), d_dst.end(), static_cast(0.)); + + auto test_func = &swizzled_store; + + dim3 dim_grid(1, 1, 1); + dim3 dim_block(kThreads, 1, 1); + int shm_size = Shared::kNumel * sizeof(Element); + + test_func<<>>( + thrust::raw_pointer_cast(d_src.data()), + thrust::raw_pointer_cast(d_dst.data())); + cudaDeviceSynchronize(); + + thrust::host_vector h_dst = d_dst; + + assert_equal(thrust::raw_pointer_cast(h_src.data()), + thrust::raw_pointer_cast(h_dst.data()), numel, 1e-4); + + LOG(INFO) << "[" << kRows << ", " << kCols << "] test passed!" << std::endl; }; } // namespace TEST(TestSwizzledLoad, test_load_row_major) { - { - static constexpr int kSharedAccessInBytes = 128; - - run_test_rowmajor, 32, 64, 32, 64, 64, - kSharedAccessInBytes>(); - run_test_rowmajor, 32, 128, 32, 64, 64, - kSharedAccessInBytes>(); - run_test_rowmajor, 32, 128, 32, 128, 64, - kSharedAccessInBytes>(); - run_test_rowmajor, 32, 256, 32, 256, 64, - kSharedAccessInBytes>(); - run_test_rowmajor, 64, 64, 64, 64, 64, - kSharedAccessInBytes>(); - - // smaller chunk - run_test_rowmajor, 64, 64, 64, 64, 32, - kSharedAccessInBytes>(); - run_test_rowmajor, 64, 64, 64, 64, 16, - kSharedAccessInBytes>(); - - run_test_rowmajor, 128, 128, 128, 64, 64, - kSharedAccessInBytes>(); - run_test_rowmajor, 128, 128, 128, 128, 128, - kSharedAccessInBytes>(); - run_test_rowmajor, 128, 128, 128, 128, 128, - kSharedAccessInBytes>(); - - run_test_rowmajor, 16, 256, 16, 128, 128, - kSharedAccessInBytes>(); - run_test_rowmajor, 32, 256, 32, 128, 128, - kSharedAccessInBytes>(); - - run_test_rowmajor, 32, 128, 32, 128, 128, - kSharedAccessInBytes>(); - run_test_rowmajor, 64, 128, 64, 128, 128, - kSharedAccessInBytes>(); - run_test_rowmajor, 64, 256, 64, 128, 128, - kSharedAccessInBytes>(); - - run_test_rowmajor, 32, 128, 32, 128, 128, - kSharedAccessInBytes>(); - run_test_rowmajor, 64, 256, 64, 128, 128, - kSharedAccessInBytes>(); - run_test_rowmajor, 64, 256, 64, 128, 64, - kSharedAccessInBytes>(); - run_test_rowmajor, 64, 256, 64, 128, 32, - kSharedAccessInBytes>(); - - run_test_rowmajor, 32, 64, 32, 64, 64, - kSharedAccessInBytes>(); - run_test_rowmajor, 64, 64, 64, 64, 64, - kSharedAccessInBytes>(); - } - - { - static constexpr int kSharedAccessInBytes = 64; - // Swizzle <2, 3, 3> - run_test_rowmajor, 16, 32, 16, 32, 32, - kSharedAccessInBytes>(); - run_test_rowmajor, 32, 32, 32, 32, 32, - kSharedAccessInBytes>(); - run_test_rowmajor, 32, 64, 32, 64, 64, - kSharedAccessInBytes>(); - } + { + static constexpr int kSharedAccessInBytes = 128; + + run_test_rowmajor, 32, 64, 32, 64, 64, + kSharedAccessInBytes>(); + run_test_rowmajor, 32, 128, 32, 64, 64, + kSharedAccessInBytes>(); + run_test_rowmajor, 32, 128, 32, 128, 64, + kSharedAccessInBytes>(); + run_test_rowmajor, 32, 256, 32, 256, 64, + kSharedAccessInBytes>(); + run_test_rowmajor, 64, 64, 64, 64, 64, + kSharedAccessInBytes>(); + + // smaller chunk + run_test_rowmajor, 64, 64, 64, 64, 32, + kSharedAccessInBytes>(); + run_test_rowmajor, 64, 64, 64, 64, 16, + kSharedAccessInBytes>(); + + run_test_rowmajor, 128, 128, 128, 64, 64, + kSharedAccessInBytes>(); + run_test_rowmajor, 128, 128, 128, 128, 128, + kSharedAccessInBytes>(); + run_test_rowmajor, 128, 128, 128, 128, 128, + kSharedAccessInBytes>(); + + run_test_rowmajor, 16, 256, 16, 128, 128, + kSharedAccessInBytes>(); + run_test_rowmajor, 32, 256, 32, 128, 128, + kSharedAccessInBytes>(); + + run_test_rowmajor, 32, 128, 32, 128, 128, + kSharedAccessInBytes>(); + run_test_rowmajor, 64, 128, 64, 128, 128, + kSharedAccessInBytes>(); + run_test_rowmajor, 64, 256, 64, 128, 128, + kSharedAccessInBytes>(); + + run_test_rowmajor, 32, 128, 32, 128, 128, + kSharedAccessInBytes>(); + run_test_rowmajor, 64, 256, 64, 128, 128, + kSharedAccessInBytes>(); + run_test_rowmajor, 64, 256, 64, 128, 64, + kSharedAccessInBytes>(); + run_test_rowmajor, 64, 256, 64, 128, 32, + kSharedAccessInBytes>(); + + run_test_rowmajor, 32, 64, 32, 64, 64, + kSharedAccessInBytes>(); + run_test_rowmajor, 64, 64, 64, 64, 64, + kSharedAccessInBytes>(); + } + + { + static constexpr int kSharedAccessInBytes = 64; + // Swizzle <2, 3, 3> + run_test_rowmajor, 16, 32, 16, 32, 32, + kSharedAccessInBytes>(); + run_test_rowmajor, 32, 32, 32, 32, 32, + kSharedAccessInBytes>(); + run_test_rowmajor, 32, 64, 32, 64, 64, + kSharedAccessInBytes>(); + } } TEST(TestSwizzledLoad, test_load_col_major) { - { - static constexpr int kSharedAccessInBytes = 128; - - run_test_colmajor, 64, 32, 64, 32, 32, - kSharedAccessInBytes>(); - run_test_colmajor, 128, 64, 64, 64, 32, - kSharedAccessInBytes>(); - run_test_colmajor, 128, 64, 64, 64, 32, - kSharedAccessInBytes>(); - - run_test_colmajor, 128, 64, 128, 64, 64, - kSharedAccessInBytes>(); - run_test_colmajor, 64, 128, 64, 128, 64, - kSharedAccessInBytes>(); - - run_test_colmajor, 128, 128, 128, 128, 64, - kSharedAccessInBytes>(); - run_test_colmajor, 256, 128, 256, 128, 64, - kSharedAccessInBytes>(); - } - - { - static constexpr int kSharedAccessInBytes = 64; - // Swizzle <2, 3, 3> - run_test_colmajor, 32, 16, 32, 16, 16, - kSharedAccessInBytes>(); - run_test_colmajor, 32, 32, 32, 32, 32, - kSharedAccessInBytes>(); - run_test_colmajor, 64, 32, 64, 32, 32, - kSharedAccessInBytes>(); - } + { + static constexpr int kSharedAccessInBytes = 128; + + run_test_colmajor, 64, 32, 64, 32, 32, + kSharedAccessInBytes>(); + run_test_colmajor, 128, 64, 64, 64, 32, + kSharedAccessInBytes>(); + run_test_colmajor, 128, 64, 64, 64, 32, + kSharedAccessInBytes>(); + + run_test_colmajor, 128, 64, 128, 64, 64, + kSharedAccessInBytes>(); + run_test_colmajor, 64, 128, 64, 128, 64, + kSharedAccessInBytes>(); + + run_test_colmajor, 128, 128, 128, 128, 64, + kSharedAccessInBytes>(); + run_test_colmajor, 256, 128, 256, 128, 64, + kSharedAccessInBytes>(); + } + + { + static constexpr int kSharedAccessInBytes = 64; + // Swizzle <2, 3, 3> + run_test_colmajor, 32, 16, 32, 16, 16, + kSharedAccessInBytes>(); + run_test_colmajor, 32, 32, 32, 32, 32, + kSharedAccessInBytes>(); + run_test_colmajor, 64, 32, 64, 32, 32, + kSharedAccessInBytes>(); + } } TEST(TestNonSwizzledStore, test_row_major) { - static constexpr int kSwizzled = false; - - { - static constexpr int kSharedAccessInBytes = 128; - test_row_major_store<__half, tl::RowMajor<1, 1>, 16, 64, kSwizzled, - kSharedAccessInBytes>(); - test_row_major_store<__half, tl::RowMajor<1, 1>, 32, 64, kSwizzled, - kSharedAccessInBytes>(); - test_row_major_store<__half, tl::RowMajor<2, 1>, 32, 64, kSwizzled, - kSharedAccessInBytes>(); - test_row_major_store<__half, tl::RowMajor<2, 1>, 64, 64, kSwizzled, - kSharedAccessInBytes>(); - test_row_major_store<__half, tl::RowMajor<1, 2>, 64, 128, kSwizzled, - kSharedAccessInBytes>(); - test_row_major_store<__half, tl::RowMajor<2, 2>, 64, 128, kSwizzled, - kSharedAccessInBytes>(); - - test_row_major_store, 16, 32, kSwizzled, - kSharedAccessInBytes>(); - test_row_major_store, 16, 64, kSwizzled, - kSharedAccessInBytes>(); - test_row_major_store, 32, 64, kSwizzled, - kSharedAccessInBytes>(); - test_row_major_store, 64, 64, kSwizzled, - kSharedAccessInBytes>(); - test_row_major_store, 64, 128, kSwizzled, - kSharedAccessInBytes>(); - test_row_major_store, 64, 128, kSwizzled, - kSharedAccessInBytes>(); - } - - { - static constexpr int kSharedAccessInBytes = 64; - // Swizzle <2, 3, 3> - test_row_major_store<__half, tl::RowMajor<1, 1>, 16, 32, kSwizzled, - kSharedAccessInBytes>(); - test_row_major_store<__half, tl::RowMajor<1, 1>, 32, 32, kSwizzled, - kSharedAccessInBytes>(); - test_row_major_store<__half, tl::RowMajor<2, 2>, 32, 64, kSwizzled, - kSharedAccessInBytes>(); - } + static constexpr int kSwizzled = false; + + { + static constexpr int kSharedAccessInBytes = 128; + test_row_major_store<__half, tl::RowMajor<1, 1>, 16, 64, kSwizzled, + kSharedAccessInBytes>(); + test_row_major_store<__half, tl::RowMajor<1, 1>, 32, 64, kSwizzled, + kSharedAccessInBytes>(); + test_row_major_store<__half, tl::RowMajor<2, 1>, 32, 64, kSwizzled, + kSharedAccessInBytes>(); + test_row_major_store<__half, tl::RowMajor<2, 1>, 64, 64, kSwizzled, + kSharedAccessInBytes>(); + test_row_major_store<__half, tl::RowMajor<1, 2>, 64, 128, kSwizzled, + kSharedAccessInBytes>(); + test_row_major_store<__half, tl::RowMajor<2, 2>, 64, 128, kSwizzled, + kSharedAccessInBytes>(); + + test_row_major_store, 16, 32, kSwizzled, + kSharedAccessInBytes>(); + test_row_major_store, 16, 64, kSwizzled, + kSharedAccessInBytes>(); + test_row_major_store, 32, 64, kSwizzled, + kSharedAccessInBytes>(); + test_row_major_store, 64, 64, kSwizzled, + kSharedAccessInBytes>(); + test_row_major_store, 64, 128, kSwizzled, + kSharedAccessInBytes>(); + test_row_major_store, 64, 128, kSwizzled, + kSharedAccessInBytes>(); + } + + { + static constexpr int kSharedAccessInBytes = 64; + // Swizzle <2, 3, 3> + test_row_major_store<__half, tl::RowMajor<1, 1>, 16, 32, kSwizzled, + kSharedAccessInBytes>(); + test_row_major_store<__half, tl::RowMajor<1, 1>, 32, 32, kSwizzled, + kSharedAccessInBytes>(); + test_row_major_store<__half, tl::RowMajor<2, 2>, 32, 64, kSwizzled, + kSharedAccessInBytes>(); + } } TEST(TestSwizzledStored, test_row_major) { - static constexpr int kSwizzled = true; - - { - static constexpr int kSharedAccessInBytes = 128; - test_row_major_store<__half, tl::RowMajor<1, 1>, 16, 64, kSwizzled, - kSharedAccessInBytes>(); - test_row_major_store<__half, tl::RowMajor<1, 1>, 32, 64, kSwizzled, - kSharedAccessInBytes>(); - test_row_major_store<__half, tl::RowMajor<2, 1>, 32, 64, kSwizzled, - kSharedAccessInBytes>(); - test_row_major_store<__half, tl::RowMajor<2, 1>, 64, 64, kSwizzled, - kSharedAccessInBytes>(); - test_row_major_store<__half, tl::RowMajor<1, 2>, 64, 128, kSwizzled, - kSharedAccessInBytes>(); - test_row_major_store<__half, tl::RowMajor<2, 2>, 64, 128, kSwizzled, - kSharedAccessInBytes>(); - - test_row_major_store, 16, 32, kSwizzled, - kSharedAccessInBytes>(); - test_row_major_store, 16, 64, kSwizzled, - kSharedAccessInBytes>(); - test_row_major_store, 32, 64, kSwizzled, - kSharedAccessInBytes>(); - test_row_major_store, 64, 64, kSwizzled, - kSharedAccessInBytes>(); - test_row_major_store, 64, 128, kSwizzled, - kSharedAccessInBytes>(); - test_row_major_store, 64, 128, kSwizzled, - kSharedAccessInBytes>(); - } - - { - static constexpr int kSharedAccessInBytes = 64; - // Swizzle <2, 3, 3> - test_row_major_store<__half, tl::RowMajor<1, 1>, 16, 32, kSwizzled, - kSharedAccessInBytes>(); - test_row_major_store<__half, tl::RowMajor<1, 1>, 32, 32, kSwizzled, - kSharedAccessInBytes>(); - test_row_major_store<__half, tl::RowMajor<2, 2>, 32, 64, kSwizzled, - kSharedAccessInBytes>(); - } + static constexpr int kSwizzled = true; + + { + static constexpr int kSharedAccessInBytes = 128; + test_row_major_store<__half, tl::RowMajor<1, 1>, 16, 64, kSwizzled, + kSharedAccessInBytes>(); + test_row_major_store<__half, tl::RowMajor<1, 1>, 32, 64, kSwizzled, + kSharedAccessInBytes>(); + test_row_major_store<__half, tl::RowMajor<2, 1>, 32, 64, kSwizzled, + kSharedAccessInBytes>(); + test_row_major_store<__half, tl::RowMajor<2, 1>, 64, 64, kSwizzled, + kSharedAccessInBytes>(); + test_row_major_store<__half, tl::RowMajor<1, 2>, 64, 128, kSwizzled, + kSharedAccessInBytes>(); + test_row_major_store<__half, tl::RowMajor<2, 2>, 64, 128, kSwizzled, + kSharedAccessInBytes>(); + + test_row_major_store, 16, 32, kSwizzled, + kSharedAccessInBytes>(); + test_row_major_store, 16, 64, kSwizzled, + kSharedAccessInBytes>(); + test_row_major_store, 32, 64, kSwizzled, + kSharedAccessInBytes>(); + test_row_major_store, 64, 64, kSwizzled, + kSharedAccessInBytes>(); + test_row_major_store, 64, 128, kSwizzled, + kSharedAccessInBytes>(); + test_row_major_store, 64, 128, kSwizzled, + kSharedAccessInBytes>(); + } + + { + static constexpr int kSharedAccessInBytes = 64; + // Swizzle <2, 3, 3> + test_row_major_store<__half, tl::RowMajor<1, 1>, 16, 32, kSwizzled, + kSharedAccessInBytes>(); + test_row_major_store<__half, tl::RowMajor<1, 1>, 32, 32, kSwizzled, + kSharedAccessInBytes>(); + test_row_major_store<__half, tl::RowMajor<2, 2>, 32, 64, kSwizzled, + kSharedAccessInBytes>(); + } } // TEST(TestNonSwizzledStored, test_col_major) { diff --git a/tests/cpp/common/test_utils.cc b/tests/cpp/common/test_utils.cc index f43768c6..b8bcc1cb 100644 --- a/tests/cpp/common/test_utils.cc +++ b/tests/cpp/common/test_utils.cc @@ -8,23 +8,23 @@ namespace tilefusion::testing { template <> void assert_equal(const __half* v1, const __half* v2, int64_t numel, float epsilon) { - float a = 0.f; - float b = 0.f; - for (int i = 0; i < numel; ++i) { - a = __half2float(v1[i]); - b = __half2float(v2[i]); + float a = 0.f; + float b = 0.f; + for (int i = 0; i < numel; ++i) { + a = __half2float(v1[i]); + b = __half2float(v2[i]); - EXPECT_NEAR(a, b, epsilon) << "v1[" << i << "] vs. v2[" << i - << "] = " << a << " vs. " << b << std::endl; - } + EXPECT_NEAR(a, b, epsilon) << "v1[" << i << "] vs. v2[" << i << "] = " << a + << " vs. " << b << std::endl; + } } template <> void assert_equal(const float* v1, const float* v2, int64_t numel, float epsilon) { - for (int i = 0; i < numel; ++i) - EXPECT_NEAR(v1[i], v2[i], epsilon) - << "v1[" << i << "] vs. v2[" << i << "] = " << v1[i] << " vs. " - << v2[i] << std::endl; + for (int i = 0; i < numel; ++i) + EXPECT_NEAR(v1[i], v2[i], epsilon) + << "v1[" << i << "] vs. v2[" << i << "] = " << v1[i] << " vs. " << v2[i] + << std::endl; } } // namespace tilefusion::testing diff --git a/tests/cpp/jit/test_jit.cc b/tests/cpp/jit/test_jit.cc index 4d19f5d4..ebd22091 100644 --- a/tests/cpp/jit/test_jit.cc +++ b/tests/cpp/jit/test_jit.cc @@ -19,15 +19,15 @@ using namespace tilefusion::jit; namespace { float rand_float(float a = 1e-3, float b = 1) { - float random = ((float)rand()) / (float)RAND_MAX; - float diff = b - a; - float r = random * diff; - return a + r; + float random = ((float)rand()) / (float)RAND_MAX; + float diff = b - a; + float r = random * diff; + return a + r; } std::string generate_add_kernel_source(const std::string& dtype, int numel) { - std::stringstream ss; - ss << R"( + std::stringstream ss; + ss << R"( template __device__ void add_device(const DType* a, const DType* b, DType* out) { int idx = blockIdx.x * blockDim.x + threadIdx.x; @@ -35,102 +35,101 @@ __device__ void add_device(const DType* a, const DType* b, DType* out) { } extern "C" __global__ void add_kernel_)" - << dtype << "_" << numel << R"(( + << dtype << "_" << numel << R"(( const )" - << dtype << R"(* a, const )" << dtype << R"(* b, )" << dtype - << R"(* out) { + << dtype << R"(* a, const )" << dtype << R"(* b, )" << dtype << R"(* out) { add_device<)" - << dtype << ", " << numel << R"(>(a, b, out); + << dtype << ", " << numel << R"(>(a, b, out); } )"; - return ss.str(); + return ss.str(); } template void jit_add_template(const DType* a, const DType* b, DType* out, int n) { - if (n == 0) return; + if (n == 0) return; - std::string dtype = get_type_string(); - std::string kernel_source = generate_add_kernel_source(dtype, n); - std::string kernel_name = "add_kernel_" + dtype + "_" + std::to_string(n); + std::string dtype = get_type_string(); + std::string kernel_source = generate_add_kernel_source(dtype, n); + std::string kernel_name = "add_kernel_" + dtype + "_" + std::to_string(n); - auto& jit = JitCompiler::instance(); - CUfunction kernel = jit.get_or_compile_kernel(kernel_name, kernel_source); + auto& jit = JitCompiler::instance(); + CUfunction kernel = jit.get_or_compile_kernel(kernel_name, kernel_source); - if (!kernel) { - throw std::runtime_error("Failed to compile or retrieve kernel"); - } + if (!kernel) { + throw std::runtime_error("Failed to compile or retrieve kernel"); + } - int block_size = 128; - int grid_size = (n + block_size - 1) / block_size; + int block_size = 128; + int grid_size = (n + block_size - 1) / block_size; - void* args[] = {&a, &b, &out}; + void* args[] = {&a, &b, &out}; - CUDA_DRIVER_CHECK(cuLaunchKernel(kernel, grid_size, 1, 1, block_size, 1, 1, - 0, nullptr, args, nullptr)); + CUDA_DRIVER_CHECK(cuLaunchKernel(kernel, grid_size, 1, 1, block_size, 1, 1, 0, + nullptr, args, nullptr)); - LOG(INFO) << "Kernel launched successfully"; + LOG(INFO) << "Kernel launched successfully"; } } // namespace TEST(TESTJit, test_jit) { - const int kNumel = 1024; - using Element = float; - - thrust::host_vector h_a(kNumel); - for (size_t i = 0; i < h_a.size(); ++i) h_a[i] = rand_float(); - - thrust::host_vector h_b(kNumel); - for (size_t i = 0; i < h_b.size(); ++i) h_b[i] = rand_float(); - - thrust::host_vector h_out(kNumel); - thrust::fill(h_out.begin(), h_out.end(), 0.); - - thrust::device_vector d_a = h_a; - thrust::device_vector d_b = h_b; - thrust::device_vector d_out = h_out; - - jit_add_template(thrust::raw_pointer_cast(d_a.data()), - thrust::raw_pointer_cast(d_b.data()), - thrust::raw_pointer_cast(d_out.data()), kNumel); - h_out = d_out; - - // Verify results - bool all_correct = true; - float max_diff = 0.0f; - size_t error_idx = 0; - - // Calculate ground truth on CPU - thrust::host_vector h_expected(kNumel); - for (size_t i = 0; i < kNumel; ++i) { - h_expected[i] = h_a[i] + h_b[i]; + const int kNumel = 1024; + using Element = float; + + thrust::host_vector h_a(kNumel); + for (size_t i = 0; i < h_a.size(); ++i) h_a[i] = rand_float(); + + thrust::host_vector h_b(kNumel); + for (size_t i = 0; i < h_b.size(); ++i) h_b[i] = rand_float(); + + thrust::host_vector h_out(kNumel); + thrust::fill(h_out.begin(), h_out.end(), 0.); + + thrust::device_vector d_a = h_a; + thrust::device_vector d_b = h_b; + thrust::device_vector d_out = h_out; + + jit_add_template(thrust::raw_pointer_cast(d_a.data()), + thrust::raw_pointer_cast(d_b.data()), + thrust::raw_pointer_cast(d_out.data()), kNumel); + h_out = d_out; + + // Verify results + bool all_correct = true; + float max_diff = 0.0f; + size_t error_idx = 0; + + // Calculate ground truth on CPU + thrust::host_vector h_expected(kNumel); + for (size_t i = 0; i < kNumel; ++i) { + h_expected[i] = h_a[i] + h_b[i]; + } + + for (size_t i = 0; i < kNumel; ++i) { + float diff = std::abs(h_out[i] - h_expected[i]); + if (diff > max_diff) { + max_diff = diff; + error_idx = i; } - for (size_t i = 0; i < kNumel; ++i) { - float diff = std::abs(h_out[i] - h_expected[i]); - if (diff > max_diff) { - max_diff = diff; - error_idx = i; - } - - if (diff > 1e-5f) { - all_correct = false; - if (i < 10) { - LOG(ERROR) << "Mismatch at index " << i << ": GPU=" << h_out[i] - << ", CPU=" << h_expected[i] << ", diff=" << diff; - } - } + if (diff > 1e-5f) { + all_correct = false; + if (i < 10) { + LOG(ERROR) << "Mismatch at index " << i << ": GPU=" << h_out[i] + << ", CPU=" << h_expected[i] << ", diff=" << diff; + } } + } - LOG(INFO) << "Max difference: " << max_diff << " at index " << error_idx; + LOG(INFO) << "Max difference: " << max_diff << " at index " << error_idx; - LOG(INFO) << "Sample results (first 10 elements):"; - for (size_t i = 0; i < 10 && i < kNumel; ++i) { - LOG(INFO) << "Index " << i << ": GPU=" << h_out[i] - << ", CPU=" << h_expected[i]; - } + LOG(INFO) << "Sample results (first 10 elements):"; + for (size_t i = 0; i < 10 && i < kNumel; ++i) { + LOG(INFO) << "Index " << i << ": GPU=" << h_out[i] + << ", CPU=" << h_expected[i]; + } - EXPECT_TRUE(all_correct) << "GPU and CPU results do not match!"; + EXPECT_TRUE(all_correct) << "GPU and CPU results do not match!"; } } // namespace tilefusion::testing diff --git a/tests/cpp/test_unit.cc b/tests/cpp/test_unit.cc index e44f4ae7..ae566695 100644 --- a/tests/cpp/test_unit.cc +++ b/tests/cpp/test_unit.cc @@ -4,10 +4,10 @@ #include "common/test_utils.hpp" int main(int argc, char** argv) { - FLAGS_alsologtostderr = 1; // redirect log to stderr - google::InitGoogleLogging(argv[0]); + FLAGS_alsologtostderr = 1; // redirect log to stderr + google::InitGoogleLogging(argv[0]); - testing::InitGoogleTest(&argc, argv); + testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); + return RUN_ALL_TESTS(); } diff --git a/tests/cpp/types/test_fp8.cu b/tests/cpp/types/test_fp8.cu index 89d74fb3..7ae686fe 100644 --- a/tests/cpp/types/test_fp8.cu +++ b/tests/cpp/types/test_fp8.cu @@ -12,19 +12,19 @@ __global__ void fp8_conversion_kernel(const float* input, void* output_e4m3, void* output_e5m2, float* result_e4m3, float* result_e5m2, int size) { #ifdef CUDA_FP8_AVAILABLE - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < size) { - __nv_fp8_e4m3* e4m3_output = static_cast<__nv_fp8_e4m3*>(output_e4m3); - __nv_fp8_e5m2* e5m2_output = static_cast<__nv_fp8_e5m2*>(output_e5m2); - - // Convert float to FP8 - e4m3_output[idx] = from_float<__fp8_e4m3>(input[idx]); - e5m2_output[idx] = from_float<__fp8_e5m2>(input[idx]); - - // Convert back to float - result_e4m3[idx] = to_float(e4m3_output[idx]); - result_e5m2[idx] = to_float(e5m2_output[idx]); - } + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < size) { + __nv_fp8_e4m3* e4m3_output = static_cast<__nv_fp8_e4m3*>(output_e4m3); + __nv_fp8_e5m2* e5m2_output = static_cast<__nv_fp8_e5m2*>(output_e5m2); + + // Convert float to FP8 + e4m3_output[idx] = from_float<__fp8_e4m3>(input[idx]); + e5m2_output[idx] = from_float<__fp8_e5m2>(input[idx]); + + // Convert back to float + result_e4m3[idx] = to_float(e4m3_output[idx]); + result_e5m2[idx] = to_float(e5m2_output[idx]); + } #endif } @@ -34,283 +34,282 @@ namespace tilefusion::testing { /// @brief Test basic FP8 construction and conversion TEST(TestFP8, test_fp8_construction) { - // Test with values that are exactly representable in FP8 - { - // Test simple powers of 2 and small integers - // usually exactly representable - __fp8_e4m3 e4m3_val = from_float<__fp8_e4m3>(2.0f); - float e4m3_back = to_float(e4m3_val); - printf("E4M3 (2.0): %f\n", e4m3_back); - EXPECT_EQ(e4m3_back, 2.0f); // Should be exact - - __fp8_e5m2 e5m2_val(4.0f); - float e5m2_back = static_cast(e5m2_val); - printf("E5M2 (4.0): %f\n", e5m2_back); - EXPECT_EQ(e5m2_back, 4.0f); // Should be exact - } - - // Test edge cases - { - __fp8_e4m3 e4m3_zero(0.0f); - __fp8_e5m2 e5m2_zero(0.0f); - EXPECT_EQ(static_cast(e4m3_zero), 0.0f); - EXPECT_EQ(static_cast(e5m2_zero), 0.0f); - - __fp8_e4m3 e4m3_one(1.0f); - __fp8_e5m2 e5m2_one(1.0f); - EXPECT_EQ(static_cast(e4m3_one), 1.0f); - EXPECT_EQ(static_cast(e5m2_one), 1.0f); - } + // Test with values that are exactly representable in FP8 + { + // Test simple powers of 2 and small integers + // usually exactly representable + __fp8_e4m3 e4m3_val = from_float<__fp8_e4m3>(2.0f); + float e4m3_back = to_float(e4m3_val); + printf("E4M3 (2.0): %f\n", e4m3_back); + EXPECT_EQ(e4m3_back, 2.0f); // Should be exact + + __fp8_e5m2 e5m2_val(4.0f); + float e5m2_back = static_cast(e5m2_val); + printf("E5M2 (4.0): %f\n", e5m2_back); + EXPECT_EQ(e5m2_back, 4.0f); // Should be exact + } + + // Test edge cases + { + __fp8_e4m3 e4m3_zero(0.0f); + __fp8_e5m2 e5m2_zero(0.0f); + EXPECT_EQ(static_cast(e4m3_zero), 0.0f); + EXPECT_EQ(static_cast(e5m2_zero), 0.0f); + + __fp8_e4m3 e4m3_one(1.0f); + __fp8_e5m2 e5m2_one(1.0f); + EXPECT_EQ(static_cast(e4m3_one), 1.0f); + EXPECT_EQ(static_cast(e5m2_one), 1.0f); + } } /// @brief Test FP8 precision characteristics and ranges TEST(TestFP8, test_fp8_precision_characteristics) { - { // Test small values in the precise range - float test_val = 0.5f; - __fp8_e4m3 e4m3_val(test_val); - __fp8_e5m2 e5m2_val(test_val); - - printf("Small value (0.5): E4M3=%f, E5M2=%f\n", to_float(e4m3_val), - to_float(e5m2_val)); - - // Use relative tolerance - EXPECT_NEAR(to_float(e4m3_val), test_val, test_val * 0.1f); - EXPECT_NEAR(to_float(e5m2_val), test_val, test_val * 0.1f); - } - - { // Test medium values (where precision loss starts) - float test_val = 3.0f; // Use a value more likely to be representable - __fp8_e4m3 e4m3_val(test_val); - __fp8_e5m2 e5m2_val(test_val); - - printf("Medium value (3.0): E4M3=%f, E5M2=%f\n", to_float(e4m3_val), - to_float(e5m2_val)); - - // Use relative tolerance - EXPECT_NEAR(to_float(e4m3_val), test_val, test_val * 0.15f); - EXPECT_NEAR(to_float(e5m2_val), test_val, test_val * 0.25f); - } - - { // Test larger values (significant quantization) - float test_val = 9.0f; - __fp8_e4m3 e4m3_val = from_float<__fp8_e4m3>(test_val); - __fp8_e5m2 e5m2_val = from_float<__fp8_e5m2>(test_val); - - printf("Large value (8.0): E4M3=%f, E5M2=%f\n", to_float(e4m3_val), - to_float(e5m2_val)); - - // Much larger relative tolerance for larger values - EXPECT_NEAR(to_float(e4m3_val), test_val, test_val * 0.25f); - EXPECT_NEAR(to_float(e5m2_val), test_val, test_val * 0.5f); - } + { // Test small values in the precise range + float test_val = 0.5f; + __fp8_e4m3 e4m3_val(test_val); + __fp8_e5m2 e5m2_val(test_val); + + printf("Small value (0.5): E4M3=%f, E5M2=%f\n", to_float(e4m3_val), + to_float(e5m2_val)); + + // Use relative tolerance + EXPECT_NEAR(to_float(e4m3_val), test_val, test_val * 0.1f); + EXPECT_NEAR(to_float(e5m2_val), test_val, test_val * 0.1f); + } + + { // Test medium values (where precision loss starts) + float test_val = 3.0f; // Use a value more likely to be representable + __fp8_e4m3 e4m3_val(test_val); + __fp8_e5m2 e5m2_val(test_val); + + printf("Medium value (3.0): E4M3=%f, E5M2=%f\n", to_float(e4m3_val), + to_float(e5m2_val)); + + // Use relative tolerance + EXPECT_NEAR(to_float(e4m3_val), test_val, test_val * 0.15f); + EXPECT_NEAR(to_float(e5m2_val), test_val, test_val * 0.25f); + } + + { // Test larger values (significant quantization) + float test_val = 9.0f; + __fp8_e4m3 e4m3_val = from_float<__fp8_e4m3>(test_val); + __fp8_e5m2 e5m2_val = from_float<__fp8_e5m2>(test_val); + + printf("Large value (8.0): E4M3=%f, E5M2=%f\n", to_float(e4m3_val), + to_float(e5m2_val)); + + // Much larger relative tolerance for larger values + EXPECT_NEAR(to_float(e4m3_val), test_val, test_val * 0.25f); + EXPECT_NEAR(to_float(e5m2_val), test_val, test_val * 0.5f); + } } /// @brief Test that conversion functions work without crashing TEST(TestFP8, test_fp8_conversion_safety) { - // Test a diverse range of values to ensure no crashes - std::vector test_values = { - // Small values - 0.0f, 0.0625f, 0.125f, 0.1875f, 0.25f, 0.375f, 0.5f, 0.625f, 0.75f, - 0.875f, - // Around 1.0 - 1.0f, 1.125f, 1.25f, 1.375f, 1.5f, 1.625f, 1.75f, 1.875f, - // Small integers and fractions - 2.0f, 2.25f, 2.5f, 2.75f, 3.0f, 3.25f, 3.5f, 3.75f, - // Medium values - 4.0f, 4.5f, 5.0f, 5.5f, 6.0f, 6.5f, 7.0f, 7.5f, - // Larger values (testing FP8 range limits) - 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 14.0f, 16.0f, 20.0f, 24.0f, 28.0f}; - - for (float val : test_values) { - // Just test that conversion works without crashing - __fp8_e4m3 e4m3_val = from_float<__fp8_e4m3>(val); - __fp8_e5m2 e5m2_val = from_float<__fp8_e5m2>(val); - - float e4m3_back = to_float(e4m3_val); - float e5m2_back = to_float(e5m2_val); - - // Basic sanity check - result should be finite - EXPECT_TRUE(std::isfinite(e4m3_back)) - << "E4M3 conversion of " << val << " produced non-finite result"; - EXPECT_TRUE(std::isfinite(e5m2_back)) - << "E5M2 conversion of " << val << " produced non-finite result"; - - printf("Value %f -> E4M3: %f, E5M2: %f\n", val, e4m3_back, e5m2_back); - } + // Test a diverse range of values to ensure no crashes + std::vector test_values = { + // Small values + 0.0f, 0.0625f, 0.125f, 0.1875f, 0.25f, 0.375f, 0.5f, 0.625f, 0.75f, + 0.875f, + // Around 1.0 + 1.0f, 1.125f, 1.25f, 1.375f, 1.5f, 1.625f, 1.75f, 1.875f, + // Small integers and fractions + 2.0f, 2.25f, 2.5f, 2.75f, 3.0f, 3.25f, 3.5f, 3.75f, + // Medium values + 4.0f, 4.5f, 5.0f, 5.5f, 6.0f, 6.5f, 7.0f, 7.5f, + // Larger values (testing FP8 range limits) + 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 14.0f, 16.0f, 20.0f, 24.0f, 28.0f}; + + for (float val : test_values) { + // Just test that conversion works without crashing + __fp8_e4m3 e4m3_val = from_float<__fp8_e4m3>(val); + __fp8_e5m2 e5m2_val = from_float<__fp8_e5m2>(val); + + float e4m3_back = to_float(e4m3_val); + float e5m2_back = to_float(e5m2_val); + + // Basic sanity check - result should be finite + EXPECT_TRUE(std::isfinite(e4m3_back)) + << "E4M3 conversion of " << val << " produced non-finite result"; + EXPECT_TRUE(std::isfinite(e5m2_back)) + << "E5M2 conversion of " << val << " produced non-finite result"; + + printf("Value %f -> E4M3: %f, E5M2: %f\n", val, e4m3_back, e5m2_back); + } } /// @brief Test TileFusion utility functions TEST(TestFP8, test_fp8_utility_functions) { - float original = 1.5f; + float original = 1.5f; - // Test E4M3 conversions - __fp8_e4m3 e4m3_val = from_float<__fp8_e4m3>(original); - float e4m3_result = to_float(e4m3_val); - EXPECT_NEAR(e4m3_result, original, 0.01f); + // Test E4M3 conversions + __fp8_e4m3 e4m3_val = from_float<__fp8_e4m3>(original); + float e4m3_result = to_float(e4m3_val); + EXPECT_NEAR(e4m3_result, original, 0.01f); - // Test E5M2 conversions - __fp8_e5m2 e5m2_val = from_float<__fp8_e5m2>(original); - float e5m2_result = to_float(e5m2_val); - EXPECT_NEAR(e5m2_result, original, 0.01f); + // Test E5M2 conversions + __fp8_e5m2 e5m2_val = from_float<__fp8_e5m2>(original); + float e5m2_result = to_float(e5m2_val); + EXPECT_NEAR(e5m2_result, original, 0.01f); } /// @brief Test FP8 arithmetic operations (through float conversion) TEST(TestFP8, test_fp8_arithmetic) { - __fp8_e4m3 a(3.0f); - __fp8_e5m2 b(2.0f); + __fp8_e4m3 a(3.0f); + __fp8_e5m2 b(2.0f); - // Convert to float for computation - float a_float = to_float(a); - float b_float = to_float(b); + // Convert to float for computation + float a_float = to_float(a); + float b_float = to_float(b); - // Perform computation in float - float sum = a_float + b_float; - float product = a_float * b_float; + // Perform computation in float + float sum = a_float + b_float; + float product = a_float * b_float; - // Convert back to FP8 - __fp8_e4m3 sum_e4m3(sum); - __fp8_e5m2 product_e5m2(product); + // Convert back to FP8 + __fp8_e4m3 sum_e4m3(sum); + __fp8_e5m2 product_e5m2(product); - EXPECT_NEAR(static_cast(sum_e4m3), 5.0f, 0.1f); - EXPECT_NEAR(static_cast(product_e5m2), 6.0f, 0.1f); + EXPECT_NEAR(static_cast(sum_e4m3), 5.0f, 0.1f); + EXPECT_NEAR(static_cast(product_e5m2), 6.0f, 0.1f); } /// @brief Test FP8 type traits TEST(TestFP8, test_fp8_traits) { - // Test that FP8 types satisfy BaseType concept - static_assert(BaseType<__fp8_e4m3>); - static_assert(BaseType<__fp8_e5m2>); - - // Test that FP8 types satisfy Fp8Type concept - static_assert(Fp8Type<__fp8_e4m3>); - static_assert(Fp8Type<__fp8_e5m2>); - - // Test that other types don't satisfy Fp8Type concept - static_assert(!Fp8Type); - static_assert(!Fp8Type<__half>); - static_assert(!Fp8Type<__bfloat16>); + // Test that FP8 types satisfy BaseType concept + static_assert(BaseType<__fp8_e4m3>); + static_assert(BaseType<__fp8_e5m2>); + + // Test that FP8 types satisfy Fp8Type concept + static_assert(Fp8Type<__fp8_e4m3>); + static_assert(Fp8Type<__fp8_e5m2>); + + // Test that other types don't satisfy Fp8Type concept + static_assert(!Fp8Type); + static_assert(!Fp8Type<__half>); + static_assert(!Fp8Type<__bfloat16>); } /// @brief Test FP8 operations on device TEST(TestFP8, test_fp8_device_operations) { - const int size = 64; - const int bytes = size * sizeof(float); - - std::vector h_input(size); - std::vector h_output_e4m3(size); - std::vector h_output_e5m2(size); - - // Initialize input with better test values for FP8 - std::vector good_fp8_values = { - // Small precise values - 0.0f, 0.125f, 0.25f, 0.375f, 0.5f, 0.625f, 0.75f, 0.875f, - // Around 1.0 - 1.0f, 1.25f, 1.5f, 1.75f, - // Small integers and key fractions - 2.0f, 2.5f, 3.0f, 3.5f, 4.0f, 4.5f, 5.0f, 6.0f, 7.0f, 8.0f, - // Larger values within FP8 range - 9.0f, 10.0f, 12.0f, 14.0f, 16.0f}; - - for (int i = 0; i < size; ++i) { - h_input[i] = good_fp8_values[i % good_fp8_values.size()]; - } - - // Device memory - float *d_input, *d_result_e4m3, *d_result_e5m2; - __fp8_e4m3* d_fp8_e4m3; - __fp8_e5m2* d_fp8_e5m2; - - cudaMalloc(&d_input, bytes); - cudaMalloc(&d_result_e4m3, bytes); - cudaMalloc(&d_result_e5m2, bytes); - cudaMalloc(&d_fp8_e4m3, size * sizeof(__fp8_e4m3)); - cudaMalloc(&d_fp8_e5m2, size * sizeof(__fp8_e5m2)); - - cudaMemcpy(d_input, h_input.data(), bytes, cudaMemcpyHostToDevice); - - dim3 block(256); - dim3 grid((size + block.x - 1) / block.x); - fp8_conversion_kernel<<>>( - d_input, static_cast(d_fp8_e4m3), static_cast(d_fp8_e5m2), - d_result_e4m3, d_result_e5m2, size); - - // Copy results back - cudaMemcpy(h_output_e4m3.data(), d_result_e4m3, bytes, - cudaMemcpyDeviceToHost); - cudaMemcpy(h_output_e5m2.data(), d_result_e5m2, bytes, - cudaMemcpyDeviceToHost); - - // Verify results with appropriate tolerances - for (int i = 0; i < size; ++i) { - float input_val = h_input[i]; - float e4m3_result = h_output_e4m3[i]; - float e5m2_result = h_output_e5m2[i]; - - // Use relative tolerance that scales with the input value - float e4m3_tolerance = std::max(0.1f, input_val * 0.2f); - float e5m2_tolerance = std::max(0.1f, input_val * 0.3f); - - EXPECT_NEAR(e4m3_result, input_val, e4m3_tolerance) - << "E4M3 mismatch at index " << i << " (input=" << input_val << ")"; - EXPECT_NEAR(e5m2_result, input_val, e5m2_tolerance) - << "E5M2 mismatch at index " << i << " (input=" << input_val << ")"; - } - - cudaFree(d_input); - cudaFree(d_result_e4m3); - cudaFree(d_result_e5m2); - cudaFree(d_fp8_e4m3); - cudaFree(d_fp8_e5m2); + const int size = 64; + const int bytes = size * sizeof(float); + + std::vector h_input(size); + std::vector h_output_e4m3(size); + std::vector h_output_e5m2(size); + + // Initialize input with better test values for FP8 + std::vector good_fp8_values = { + // Small precise values + 0.0f, 0.125f, 0.25f, 0.375f, 0.5f, 0.625f, 0.75f, 0.875f, + // Around 1.0 + 1.0f, 1.25f, 1.5f, 1.75f, + // Small integers and key fractions + 2.0f, 2.5f, 3.0f, 3.5f, 4.0f, 4.5f, 5.0f, 6.0f, 7.0f, 8.0f, + // Larger values within FP8 range + 9.0f, 10.0f, 12.0f, 14.0f, 16.0f}; + + for (int i = 0; i < size; ++i) { + h_input[i] = good_fp8_values[i % good_fp8_values.size()]; + } + + // Device memory + float *d_input, *d_result_e4m3, *d_result_e5m2; + __fp8_e4m3* d_fp8_e4m3; + __fp8_e5m2* d_fp8_e5m2; + + cudaMalloc(&d_input, bytes); + cudaMalloc(&d_result_e4m3, bytes); + cudaMalloc(&d_result_e5m2, bytes); + cudaMalloc(&d_fp8_e4m3, size * sizeof(__fp8_e4m3)); + cudaMalloc(&d_fp8_e5m2, size * sizeof(__fp8_e5m2)); + + cudaMemcpy(d_input, h_input.data(), bytes, cudaMemcpyHostToDevice); + + dim3 block(256); + dim3 grid((size + block.x - 1) / block.x); + fp8_conversion_kernel<<>>( + d_input, static_cast(d_fp8_e4m3), static_cast(d_fp8_e5m2), + d_result_e4m3, d_result_e5m2, size); + + // Copy results back + cudaMemcpy(h_output_e4m3.data(), d_result_e4m3, bytes, + cudaMemcpyDeviceToHost); + cudaMemcpy(h_output_e5m2.data(), d_result_e5m2, bytes, + cudaMemcpyDeviceToHost); + + // Verify results with appropriate tolerances + for (int i = 0; i < size; ++i) { + float input_val = h_input[i]; + float e4m3_result = h_output_e4m3[i]; + float e5m2_result = h_output_e5m2[i]; + + // Use relative tolerance that scales with the input value + float e4m3_tolerance = std::max(0.1f, input_val * 0.2f); + float e5m2_tolerance = std::max(0.1f, input_val * 0.3f); + + EXPECT_NEAR(e4m3_result, input_val, e4m3_tolerance) + << "E4M3 mismatch at index " << i << " (input=" << input_val << ")"; + EXPECT_NEAR(e5m2_result, input_val, e5m2_tolerance) + << "E5M2 mismatch at index " << i << " (input=" << input_val << ")"; + } + + cudaFree(d_input); + cudaFree(d_result_e4m3); + cudaFree(d_result_e5m2); + cudaFree(d_fp8_e4m3); + cudaFree(d_fp8_e5m2); } #else // !CUDA_FP8_AVAILABLE /// @brief Test that runs when FP8 is not available TEST(TestFP8, test_fp8_not_available) { - // This test ensures that the build works even when FP8 is not available - GTEST_SKIP() << "FP8 support not available - requires Ada Lovelace or " - "Hopper GPU with CUDA 11.8+"; + // This test ensures that the build works even when FP8 is not available + GTEST_SKIP() << "FP8 support not available - requires Ada Lovelace or " + "Hopper GPU with CUDA 11.8+"; } #endif // CUDA_FP8_AVAILABLE /// @brief Test hardware detection (runs regardless of FP8 support) TEST(TestFP8, test_hardware_detection) { - int device; - cudaError_t err = cudaGetDevice(&device); - EXPECT_EQ(err, cudaSuccess); - - cudaDeviceProp prop; - err = cudaGetDeviceProperties(&prop, device); - EXPECT_EQ(err, cudaSuccess); - - int major = prop.major; - int minor = prop.minor; - int compute_capability = major * 10 + minor; - - LOG(INFO) << "GPU: " << prop.name; - LOG(INFO) << "Compute Capability: " << major << "." << minor; - - if (compute_capability >= 89) { - LOG(INFO) << "FP8 hardware support: YES"; - if (major == 8 && minor == 9) { - LOG(INFO) << "Architecture: Ada Lovelace"; - } else if (major >= 9) { - LOG(INFO) << "Architecture: Hopper"; - } - } else { - LOG(INFO) - << "FP8 hardware support: NO (requires compute capability 8.9+)"; + int device; + cudaError_t err = cudaGetDevice(&device); + EXPECT_EQ(err, cudaSuccess); + + cudaDeviceProp prop; + err = cudaGetDeviceProperties(&prop, device); + EXPECT_EQ(err, cudaSuccess); + + int major = prop.major; + int minor = prop.minor; + int compute_capability = major * 10 + minor; + + LOG(INFO) << "GPU: " << prop.name; + LOG(INFO) << "Compute Capability: " << major << "." << minor; + + if (compute_capability >= 89) { + LOG(INFO) << "FP8 hardware support: YES"; + if (major == 8 && minor == 9) { + LOG(INFO) << "Architecture: Ada Lovelace"; + } else if (major >= 9) { + LOG(INFO) << "Architecture: Hopper"; } + } else { + LOG(INFO) << "FP8 hardware support: NO (requires compute capability 8.9+)"; + } #ifdef CUDA_FP8_AVAILABLE - LOG(INFO) << "FP8 compile-time support: YES"; + LOG(INFO) << "FP8 compile-time support: YES"; #else - LOG(INFO) << "FP8 compile-time support: NO"; + LOG(INFO) << "FP8 compile-time support: NO"; #endif - // Test always passes - this is just for information - EXPECT_TRUE(true); + // Test always passes - this is just for information + EXPECT_TRUE(true); } } // namespace tilefusion::testing diff --git a/tests/cpp/types/test_gtile_iterator.cu b/tests/cpp/types/test_gtile_iterator.cu index 9f090bb7..414c58a8 100644 --- a/tests/cpp/types/test_gtile_iterator.cu +++ b/tests/cpp/types/test_gtile_iterator.cu @@ -18,142 +18,142 @@ struct GTileIteratorTester; template struct GTileIteratorTester { - using Element = float; - using Layout = Layout_; + using Element = float; + using Layout = Layout_; - static constexpr int kRows = Layout::kRows; - static constexpr int kCols = Layout::kCols; + static constexpr int kRows = Layout::kRows; + static constexpr int kCols = Layout::kCols; - static constexpr int kStride0 = dim_size<0, ChunkShape>; - static constexpr int kStride1 = dim_size<1, ChunkShape>; + static constexpr int kStride0 = dim_size<0, ChunkShape>; + static constexpr int kStride1 = dim_size<1, ChunkShape>; - const int kTileRowStride = kStride0 * Layout::kRowStride; - const int kTileColStride = kStride1; + const int kTileRowStride = kStride0 * Layout::kRowStride; + const int kTileColStride = kStride1; - static_assert(kRows % kStride0 == 0, "kRows must be divisible by kStride0"); - static_assert(kCols % kStride1 == 0, "kCols must be divisible by kStride1"); + static_assert(kRows % kStride0 == 0, "kRows must be divisible by kStride0"); + static_assert(kCols % kStride1 == 0, "kCols must be divisible by kStride1"); - using Tile = GlobalTile>; - using Iterator = GTileIterator; + using Tile = GlobalTile>; + using Iterator = GTileIterator; - void operator()() { - int numel = kRows * kCols; - thrust::host_vector data(numel); + void operator()() { + int numel = kRows * kCols; + thrust::host_vector data(numel); - Layout layout; - Element* ptr = data.data(); - int count = 0; - for (int i = 0; i < kRows; ++i) - for (int j = 0; j < kCols; ++j) ptr[count++] = layout(i, j); + Layout layout; + Element* ptr = data.data(); + int count = 0; + for (int i = 0; i < kRows; ++i) + for (int j = 0; j < kCols; ++j) ptr[count++] = layout(i, j); #if defined(DEBUG_PRINT) - Tile gtile(ptr); - gtile.dump_value(); + Tile gtile(ptr); + gtile.dump_value(); #endif - EXPECT_EQ(Iterator::sc0, kRows / kStride0); - EXPECT_EQ(Iterator::sc1, kCols / kStride1); - - Iterator iter(data.data()); - - for (int i = 0; i < Iterator::sc0; ++i) { - for (int j = 0; j < Iterator::sc1; ++j) { - int start_n = i * kTileRowStride + j * kTileColStride; - auto tile = iter(i, j); - for (int m = 0; m < kStride0; ++m) { - for (int n = 0; n < kStride1; ++n) { - int v1 = int(tile(m, n)); - int v2 = start_n + m * Layout::kRowStride + n; - EXPECT_EQ(v1, v2); - } - } + EXPECT_EQ(Iterator::sc0, kRows / kStride0); + EXPECT_EQ(Iterator::sc1, kCols / kStride1); + + Iterator iter(data.data()); + + for (int i = 0; i < Iterator::sc0; ++i) { + for (int j = 0; j < Iterator::sc1; ++j) { + int start_n = i * kTileRowStride + j * kTileColStride; + auto tile = iter(i, j); + for (int m = 0; m < kStride0; ++m) { + for (int n = 0; n < kStride1; ++n) { + int v1 = int(tile(m, n)); + int v2 = start_n + m * Layout::kRowStride + n; + EXPECT_EQ(v1, v2); + } + } #if defined(DEBUG_PRINT) - printf("\nIteration-[%d, %d]:\n", i, j); - iter(i, j).dump_value(); - printf("\n"); + printf("\nIteration-[%d, %d]:\n", i, j); + iter(i, j).dump_value(); + printf("\n"); #endif - } - } + } } + } }; template struct GTileIteratorTester { - using Element = float; - using Layout = Layout_; + using Element = float; + using Layout = Layout_; - static constexpr int kRows = Layout::kRows; - static constexpr int kCols = Layout::kCols; + static constexpr int kRows = Layout::kRows; + static constexpr int kCols = Layout::kCols; - static constexpr int kStride0 = dim_size<0, ChunkShape>; - static constexpr int kStride1 = dim_size<1, ChunkShape>; + static constexpr int kStride0 = dim_size<0, ChunkShape>; + static constexpr int kStride1 = dim_size<1, ChunkShape>; - const int kTileRowStride = kStride0; - const int kTileColStride = kStride1 * Layout::kColStride; + const int kTileRowStride = kStride0; + const int kTileColStride = kStride1 * Layout::kColStride; - static_assert(kRows % kStride0 == 0, "kRows must be divisible by kStride0"); - static_assert(kCols % kStride1 == 0, "kCols must be divisible by kStride1"); + static_assert(kRows % kStride0 == 0, "kRows must be divisible by kStride0"); + static_assert(kCols % kStride1 == 0, "kCols must be divisible by kStride1"); - using Tile = GlobalTile>; - using Iterator = GTileIterator; + using Tile = GlobalTile>; + using Iterator = GTileIterator; - void operator()() { - int numel = kRows * kCols; - thrust::host_vector data(numel); + void operator()() { + int numel = kRows * kCols; + thrust::host_vector data(numel); - Layout layout; - Element* ptr = data.data(); - int count = 0; - for (int i = 0; i < kRows; ++i) - for (int j = 0; j < kCols; ++j) ptr[count++] = layout(i, j); + Layout layout; + Element* ptr = data.data(); + int count = 0; + for (int i = 0; i < kRows; ++i) + for (int j = 0; j < kCols; ++j) ptr[count++] = layout(i, j); #if defined(DEBUG_PRINT) - Tile gtile(ptr); - gtile.dump_value(); + Tile gtile(ptr); + gtile.dump_value(); #endif - EXPECT_EQ(Iterator::sc0, kRows / kStride0); - EXPECT_EQ(Iterator::sc1, kCols / kStride1); + EXPECT_EQ(Iterator::sc0, kRows / kStride0); + EXPECT_EQ(Iterator::sc1, kCols / kStride1); - Iterator iter(data.data()); + Iterator iter(data.data()); - for (int i = 0; i < Iterator::sc0; ++i) { - for (int j = 0; j < Iterator::sc1; ++j) { - int start_n = i * kTileRowStride + j * kTileColStride; + for (int i = 0; i < Iterator::sc0; ++i) { + for (int j = 0; j < Iterator::sc1; ++j) { + int start_n = i * kTileRowStride + j * kTileColStride; - auto tile = iter(i, j); - for (int m = 0; m < kStride0; ++m) { - for (int n = 0; n < kStride1; ++n) { - int v1 = int(tile(m, n)); - int v2 = start_n + m + n * Layout::kColStride; + auto tile = iter(i, j); + for (int m = 0; m < kStride0; ++m) { + for (int n = 0; n < kStride1; ++n) { + int v1 = int(tile(m, n)); + int v2 = start_n + m + n * Layout::kColStride; - EXPECT_EQ(v1, v2); - } - } + EXPECT_EQ(v1, v2); + } + } #if defined(DEBUG_PRINT) - printf("\nIteration-[%d, %d]:\n", i, j); - iter(i, j).dump_value(); - printf("\n"); + printf("\nIteration-[%d, %d]:\n", i, j); + iter(i, j).dump_value(); + printf("\n"); #endif - } - } + } } + } }; } // namespace TEST(TestGTileIterator, test_row_major) { - using Tester = GTileIteratorTester, TileShape<2, 3>, - tl::Layout::kRowMajor>; - Tester tester; - tester(); + using Tester = GTileIteratorTester, TileShape<2, 3>, + tl::Layout::kRowMajor>; + Tester tester; + tester(); } TEST(TestGTileIterator, col_major) { - using Tester = GTileIteratorTester, TileShape<2, 3>, - tl::Layout::kColMajor>; - Tester tester; - tester(); + using Tester = GTileIteratorTester, TileShape<2, 3>, + tl::Layout::kColMajor>; + Tester tester; + tester(); } } // namespace tilefusion::testing diff --git a/tests/cpp/types/test_layout.cu b/tests/cpp/types/test_layout.cu index caf65184..98ee21ba 100644 --- a/tests/cpp/types/test_layout.cu +++ b/tests/cpp/types/test_layout.cu @@ -10,112 +10,112 @@ using namespace tilefusion::cell; namespace tl = tile_layout; TEST(TestLayout, test_layout) { - using Element = __half; - - using Layout1 = tl::RowMajor<4, 7>; - EXPECT_EQ(tl::num_rows, 4); - EXPECT_EQ(tl::num_cols, 7); - EXPECT_EQ(tl::get_numel, 28); - EXPECT_EQ(tl::row_stride, 7); - EXPECT_EQ(tl::col_stride, 1); - - tl::Layout type1 = tl::layout_type; - EXPECT_EQ(type1, tl::Layout::kRowMajor); - auto layout_name1 = layout_type_to_str(type1); - EXPECT_EQ(layout_name1, "RowMajor"); - - using Layout2 = tl::ColMajor<4, 7>; - EXPECT_EQ(tl::num_rows, 4); - EXPECT_EQ(tl::num_cols, 7); - EXPECT_EQ(tl::get_numel, 28); - EXPECT_EQ(tl::row_stride, 1); - EXPECT_EQ(tl::col_stride, 4); - - tl::Layout type2 = tl::layout_type; - EXPECT_EQ(type2, tl::Layout::kColMajor); - auto layout_name2 = layout_type_to_str(type2); - EXPECT_EQ(layout_name2, "ColMajor"); + using Element = __half; + + using Layout1 = tl::RowMajor<4, 7>; + EXPECT_EQ(tl::num_rows, 4); + EXPECT_EQ(tl::num_cols, 7); + EXPECT_EQ(tl::get_numel, 28); + EXPECT_EQ(tl::row_stride, 7); + EXPECT_EQ(tl::col_stride, 1); + + tl::Layout type1 = tl::layout_type; + EXPECT_EQ(type1, tl::Layout::kRowMajor); + auto layout_name1 = layout_type_to_str(type1); + EXPECT_EQ(layout_name1, "RowMajor"); + + using Layout2 = tl::ColMajor<4, 7>; + EXPECT_EQ(tl::num_rows, 4); + EXPECT_EQ(tl::num_cols, 7); + EXPECT_EQ(tl::get_numel, 28); + EXPECT_EQ(tl::row_stride, 1); + EXPECT_EQ(tl::col_stride, 4); + + tl::Layout type2 = tl::layout_type; + EXPECT_EQ(type2, tl::Layout::kColMajor); + auto layout_name2 = layout_type_to_str(type2); + EXPECT_EQ(layout_name2, "ColMajor"); } TEST(TestLayout, test_block_row_major) { - using Layout = tl::BlockRowMajor, tl::RowMajor<2, 3>>; + using Layout = tl::BlockRowMajor, tl::RowMajor<2, 3>>; - EXPECT_EQ(Layout::kTileRows, 7); - EXPECT_EQ(Layout::kTileCols, 3); - EXPECT_EQ(Layout::kRowStride, 18); - EXPECT_EQ(Layout::kColStride, 6); - EXPECT_EQ(Layout::kType, tl::Layout::kRowMajor); + EXPECT_EQ(Layout::kTileRows, 7); + EXPECT_EQ(Layout::kTileCols, 3); + EXPECT_EQ(Layout::kRowStride, 18); + EXPECT_EQ(Layout::kColStride, 6); + EXPECT_EQ(Layout::kType, tl::Layout::kRowMajor); - Layout layout; + Layout layout; #if defined(DEBUG) - layout.dump(); + layout.dump(); #endif - EXPECT_EQ(layout(2, 0), 18); - EXPECT_EQ(layout(2, 1), 19); - EXPECT_EQ(layout(4, 3), 42); - EXPECT_EQ(layout(4, 4), 43); + EXPECT_EQ(layout(2, 0), 18); + EXPECT_EQ(layout(2, 1), 19); + EXPECT_EQ(layout(4, 3), 42); + EXPECT_EQ(layout(4, 4), 43); } TEST(TestLayout, test_block_col_major) { - using Layout = tl::BlockColMajor, tl::ColMajor<2, 3>>; + using Layout = tl::BlockColMajor, tl::ColMajor<2, 3>>; - EXPECT_EQ(Layout::kTileRows, 7); - EXPECT_EQ(Layout::kTileCols, 3); - EXPECT_EQ(Layout::kRowStride, 6); - EXPECT_EQ(Layout::kColStride, 42); - EXPECT_EQ(Layout::kType, tl::Layout::kColMajor); + EXPECT_EQ(Layout::kTileRows, 7); + EXPECT_EQ(Layout::kTileCols, 3); + EXPECT_EQ(Layout::kRowStride, 6); + EXPECT_EQ(Layout::kColStride, 42); + EXPECT_EQ(Layout::kType, tl::Layout::kColMajor); - Layout layout; + Layout layout; #if defined(DEBUG) - layout.dump(); + layout.dump(); #endif - EXPECT_EQ(layout(6, 0), 18); - EXPECT_EQ(layout(7, 0), 19); - EXPECT_EQ(layout(0, 3), 42); - EXPECT_EQ(layout(1, 3), 43); + EXPECT_EQ(layout(6, 0), 18); + EXPECT_EQ(layout(7, 0), 19); + EXPECT_EQ(layout(0, 3), 42); + EXPECT_EQ(layout(1, 3), 43); } TEST(TestLayout, test_block_mixed1) { - using Layout = tl::BlockMixed, tl::ColMajor<2, 3>>; + using Layout = tl::BlockMixed, tl::ColMajor<2, 3>>; - EXPECT_EQ(Layout::kTileRows, 7); - EXPECT_EQ(Layout::kTileCols, 3); - EXPECT_EQ(Layout::kRowStride, 18); - EXPECT_EQ(Layout::kColStride, 6); - EXPECT_EQ(Layout::kType, tl::Layout::kRowMajor); + EXPECT_EQ(Layout::kTileRows, 7); + EXPECT_EQ(Layout::kTileCols, 3); + EXPECT_EQ(Layout::kRowStride, 18); + EXPECT_EQ(Layout::kColStride, 6); + EXPECT_EQ(Layout::kType, tl::Layout::kRowMajor); - Layout layout; + Layout layout; #if defined(DEBUG) - layout.dump(); + layout.dump(); #endif - EXPECT_EQ(layout(2, 0), 18); - EXPECT_EQ(layout(3, 0), 19); - EXPECT_EQ(layout(4, 3), 42); - EXPECT_EQ(layout(5, 3), 43); + EXPECT_EQ(layout(2, 0), 18); + EXPECT_EQ(layout(3, 0), 19); + EXPECT_EQ(layout(4, 3), 42); + EXPECT_EQ(layout(5, 3), 43); } TEST(TestLayout, test_block_mixed2) { - using Layout = tl::BlockMixed, tl::RowMajor<2, 3>>; + using Layout = tl::BlockMixed, tl::RowMajor<2, 3>>; - EXPECT_EQ(Layout::kTileRows, 7); - EXPECT_EQ(Layout::kTileCols, 3); - EXPECT_EQ(Layout::kRowStride, 6); - EXPECT_EQ(Layout::kColStride, 42); - EXPECT_EQ(Layout::kType, tl::Layout::kColMajor); + EXPECT_EQ(Layout::kTileRows, 7); + EXPECT_EQ(Layout::kTileCols, 3); + EXPECT_EQ(Layout::kRowStride, 6); + EXPECT_EQ(Layout::kColStride, 42); + EXPECT_EQ(Layout::kType, tl::Layout::kColMajor); - Layout layout; + Layout layout; #if defined(DEBUG) - layout.dump(); + layout.dump(); #endif - EXPECT_EQ(layout(6, 0), 18); - EXPECT_EQ(layout(6, 1), 19); - EXPECT_EQ(layout(0, 3), 42); - EXPECT_EQ(layout(0, 4), 43); + EXPECT_EQ(layout(6, 0), 18); + EXPECT_EQ(layout(6, 1), 19); + EXPECT_EQ(layout(0, 3), 42); + EXPECT_EQ(layout(0, 4), 43); } } // namespace tilefusion::testing diff --git a/tests/cpp/types/test_stile_iterator.cu b/tests/cpp/types/test_stile_iterator.cu index 41435ef1..d8a0bae7 100644 --- a/tests/cpp/types/test_stile_iterator.cu +++ b/tests/cpp/types/test_stile_iterator.cu @@ -16,209 +16,209 @@ namespace { /// @brief Initialize buffer with sequential values for testing template __device__ void init_buf(DType* buf, int numel) { - for (int i = 0; i < numel; ++i) { - buf[i] = static_cast(i); - } + for (int i = 0; i < numel; ++i) { + buf[i] = static_cast(i); + } } /// @brief Test kernel for shared tile iterator template __global__ void test_stile_iterator() { - using DType = typename Shared::DType; - extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; - DType* buf = reinterpret_cast(buf_); + using DType = typename Shared::DType; + extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; + DType* buf = reinterpret_cast(buf_); - init_buf(buf, Shared::kNumel); + init_buf(buf, Shared::kNumel); - Shared s_tile(buf); - SIterator s_itr(&s_tile); + Shared s_tile(buf); + SIterator s_itr(&s_tile); - printf("\nshared tile:\n"); - s_tile.dump_value(); + printf("\nshared tile:\n"); + s_tile.dump_value(); - for (int i = 0; i < SIterator::kNumel; ++i) { - printf("\nsub-tile %d:\n", i); - auto tile = s_itr(i); - tile.dump_value(); - } + for (int i = 0; i < SIterator::kNumel; ++i) { + printf("\nsub-tile %d:\n", i); + auto tile = s_itr(i); + tile.dump_value(); + } } } // namespace TEST(TestSharedTileIterator, row_major) { - using InType = __half; - static constexpr int kRows = 4; - static constexpr int kCols = 24; + using InType = __half; + static constexpr int kRows = 4; + static constexpr int kCols = 24; - static constexpr int kChunkRows = 4; - static constexpr int kChunkCols = 8; + static constexpr int kChunkRows = 4; + static constexpr int kChunkCols = 8; - using SharedLayout = tl::RowMajor; - using Shared = SharedTile; - using SIterator = STileIterator2>; + using SharedLayout = tl::RowMajor; + using Shared = SharedTile; + using SIterator = STileIterator2>; - LOG(INFO) << std::endl << Shared{} << std::endl; - LOG(INFO) << std::endl << SIterator{} << std::endl; + LOG(INFO) << std::endl << Shared{} << std::endl; + LOG(INFO) << std::endl << SIterator{} << std::endl; - using SubTileLayout = - SubTileLayoutCreator; - using Layout = typename SubTileLayout::type; + using SubTileLayout = + SubTileLayoutCreator; + using Layout = typename SubTileLayout::type; - LOG(INFO) << std::endl - << "SubTileLayout: " << std::endl - << "\t" << Layout{} << std::endl; + LOG(INFO) << std::endl + << "SubTileLayout: " << std::endl + << "\t" << Layout{} << std::endl; - int shm_size = Shared::kNumel * sizeof(InType); - dim3 blocks(1, 1, 1); - dim3 threads(1, 1, 1); - test_stile_iterator<<>>(); - cudaDeviceSynchronize(); + int shm_size = Shared::kNumel * sizeof(InType); + dim3 blocks(1, 1, 1); + dim3 threads(1, 1, 1); + test_stile_iterator<<>>(); + cudaDeviceSynchronize(); } TEST(TestSharedTileIterator, col_major) { - using InType = __half; - static constexpr int kRows = 24; - static constexpr int kCols = 4; + using InType = __half; + static constexpr int kRows = 24; + static constexpr int kCols = 4; - static constexpr int kChunkRows = 8; - static constexpr int kChunkCols = 4; + static constexpr int kChunkRows = 8; + static constexpr int kChunkCols = 4; - using SharedLayout = tl::ColMajor; - using Shared = SharedTile; - using SIterator = STileIterator2>; + using SharedLayout = tl::ColMajor; + using Shared = SharedTile; + using SIterator = STileIterator2>; - LOG(INFO) << std::endl << Shared{} << std::endl; - LOG(INFO) << std::endl << SIterator{} << std::endl; + LOG(INFO) << std::endl << Shared{} << std::endl; + LOG(INFO) << std::endl << SIterator{} << std::endl; - using SubTileLayout = - SubTileLayoutCreator; - using Layout = typename SubTileLayout::type; + using SubTileLayout = + SubTileLayoutCreator; + using Layout = typename SubTileLayout::type; - LOG(INFO) << std::endl - << "SubTileLayout: " << std::endl - << "\t" << Layout{} << std::endl; + LOG(INFO) << std::endl + << "SubTileLayout: " << std::endl + << "\t" << Layout{} << std::endl; - int shm_size = Shared::kNumel * sizeof(InType); - dim3 blocks(1, 1, 1); - dim3 threads(1, 1, 1); - test_stile_iterator<<>>(); - cudaDeviceSynchronize(); + int shm_size = Shared::kNumel * sizeof(InType); + dim3 blocks(1, 1, 1); + dim3 threads(1, 1, 1); + test_stile_iterator<<>>(); + cudaDeviceSynchronize(); } TEST(TestSharedTileIterator, block_row_major) { - using InType = __half; - static constexpr int kRows = 4; - static constexpr int kCols = 16; + using InType = __half; + static constexpr int kRows = 4; + static constexpr int kCols = 16; - static constexpr int kChunkRows = 4; - static constexpr int kChunkCols = 8; + static constexpr int kChunkRows = 4; + static constexpr int kChunkCols = 8; - using SharedLayout = - tl::BlockRowMajor, tl::RowMajor<2, 4>>; + using SharedLayout = + tl::BlockRowMajor, tl::RowMajor<2, 4>>; - std::cout << "SharedLayout: " << std::endl << SharedLayout{} << std::endl; + std::cout << "SharedLayout: " << std::endl << SharedLayout{} << std::endl; - using Shared = SharedTile; - using SIterator = STileIterator2>; + using Shared = SharedTile; + using SIterator = STileIterator2>; - LOG(INFO) << std::endl << Shared{} << std::endl; - LOG(INFO) << std::endl << SIterator{} << std::endl; + LOG(INFO) << std::endl << Shared{} << std::endl; + LOG(INFO) << std::endl << SIterator{} << std::endl; - using SubTileLayout = - SubTileLayoutCreator; - using Layout = typename SubTileLayout::type; + using SubTileLayout = + SubTileLayoutCreator; + using Layout = typename SubTileLayout::type; - LOG(INFO) << std::endl - << "SubTileLayout: " << std::endl - << Layout{} << std::endl; + LOG(INFO) << std::endl + << "SubTileLayout: " << std::endl + << Layout{} << std::endl; - int shm_size = Shared::kNumel * sizeof(InType); - dim3 blocks(1, 1, 1); - dim3 threads(1, 1, 1); - test_stile_iterator<<>>(); - cudaDeviceSynchronize(); + int shm_size = Shared::kNumel * sizeof(InType); + dim3 blocks(1, 1, 1); + dim3 threads(1, 1, 1); + test_stile_iterator<<>>(); + cudaDeviceSynchronize(); } TEST(TestSharedTileIterator, block_col_major) { - using InType = __half; - static constexpr int kRows = 16; - static constexpr int kCols = 4; + using InType = __half; + static constexpr int kRows = 16; + static constexpr int kCols = 4; - static constexpr int kChunkRows = 8; - static constexpr int kChunkCols = 4; + static constexpr int kChunkRows = 8; + static constexpr int kChunkCols = 4; - using SharedLayout = - tl::BlockColMajor, tl::ColMajor<4, 2>>; + using SharedLayout = + tl::BlockColMajor, tl::ColMajor<4, 2>>; - std::cout << "SharedLayout: " << std::endl << SharedLayout{} << std::endl; + std::cout << "SharedLayout: " << std::endl << SharedLayout{} << std::endl; - using Shared = SharedTile; - using SIterator = STileIterator2>; + using Shared = SharedTile; + using SIterator = STileIterator2>; - LOG(INFO) << std::endl << Shared{} << std::endl; - LOG(INFO) << std::endl << SIterator{} << std::endl; + LOG(INFO) << std::endl << Shared{} << std::endl; + LOG(INFO) << std::endl << SIterator{} << std::endl; - using SubTileLayout = - SubTileLayoutCreator; - using Layout = typename SubTileLayout::type; + using SubTileLayout = + SubTileLayoutCreator; + using Layout = typename SubTileLayout::type; - LOG(INFO) << std::endl - << "SubTileLayout: " << std::endl - << Layout{} << std::endl; + LOG(INFO) << std::endl + << "SubTileLayout: " << std::endl + << Layout{} << std::endl; - int shm_size = Shared::kNumel * sizeof(InType); - dim3 blocks(1, 1, 1); - dim3 threads(1, 1, 1); - test_stile_iterator<<>>(); - cudaDeviceSynchronize(); + int shm_size = Shared::kNumel * sizeof(InType); + dim3 blocks(1, 1, 1); + dim3 threads(1, 1, 1); + test_stile_iterator<<>>(); + cudaDeviceSynchronize(); } TEST(TestSharedTileIterator, block_swizzled_row_major) { - using InType = __half; - static constexpr int kRows = 8; - static constexpr int kCols = 16; + using InType = __half; + static constexpr int kRows = 8; + static constexpr int kCols = 16; - static constexpr int kChunkRows = 8; - static constexpr int kChunkCols = 4; + static constexpr int kChunkRows = 8; + static constexpr int kChunkCols = 4; - using SharedLayout = - tl::BlockRowMajor, - SwizzledLayout, Swizzle<1, 0, 2>>>; + using SharedLayout = + tl::BlockRowMajor, + SwizzledLayout, Swizzle<1, 0, 2>>>; - using Shared = SharedTile; - using SIterator = STileIterator2>; + using Shared = SharedTile; + using SIterator = STileIterator2>; - LOG(INFO) << std::endl << Shared{} << std::endl; - LOG(INFO) << std::endl << SIterator{} << std::endl; + LOG(INFO) << std::endl << Shared{} << std::endl; + LOG(INFO) << std::endl << SIterator{} << std::endl; - int shm_size = Shared::kNumel * sizeof(InType); - dim3 blocks(1, 1, 1); - dim3 threads(1, 1, 1); - test_stile_iterator<<>>(); - cudaDeviceSynchronize(); + int shm_size = Shared::kNumel * sizeof(InType); + dim3 blocks(1, 1, 1); + dim3 threads(1, 1, 1); + test_stile_iterator<<>>(); + cudaDeviceSynchronize(); } TEST(TestSharedTileIterator, block_swizzled_col_major) { - using InType = __half; - static constexpr int kRows = 32; - static constexpr int kCols = 8; + using InType = __half; + static constexpr int kRows = 32; + static constexpr int kCols = 8; - static constexpr int kChunkRows = 8; - static constexpr int kChunkCols = 8; + static constexpr int kChunkRows = 8; + static constexpr int kChunkCols = 8; - using SharedLayout = - tl::BlockColMajor, - SwizzledLayout, Swizzle<2, 0, 2>>>; + using SharedLayout = + tl::BlockColMajor, + SwizzledLayout, Swizzle<2, 0, 2>>>; - using Shared = SharedTile; - using SIterator = STileIterator2>; + using Shared = SharedTile; + using SIterator = STileIterator2>; - LOG(INFO) << std::endl << Shared{} << std::endl; - LOG(INFO) << std::endl << SIterator{} << std::endl; + LOG(INFO) << std::endl << Shared{} << std::endl; + LOG(INFO) << std::endl << SIterator{} << std::endl; - int shm_size = Shared::kNumel * sizeof(InType); - dim3 blocks(1, 1, 1); - dim3 threads(1, 1, 1); - test_stile_iterator<<>>(); - cudaDeviceSynchronize(); + int shm_size = Shared::kNumel * sizeof(InType); + dim3 blocks(1, 1, 1); + dim3 threads(1, 1, 1); + test_stile_iterator<<>>(); + cudaDeviceSynchronize(); } } // namespace tilefusion::testing diff --git a/tests/cpp/types/test_swizzle.cu b/tests/cpp/types/test_swizzle.cu index cc44cda2..837bbbc0 100644 --- a/tests/cpp/types/test_swizzle.cu +++ b/tests/cpp/types/test_swizzle.cu @@ -18,90 +18,90 @@ int flatten(int x, int y, int width) { return x * width + y; } template int swizzle_ref(int x, int y) { - int b = x; - int s = y >> kM; + int b = x; + int s = y >> kM; - int swizzled_s = b ^ s; - int swizzle_idx = - (b << (kM + kS)) | (swizzled_s << kM) | (y & ((1 << kM) - 1)); + int swizzled_s = b ^ s; + int swizzle_idx = + (b << (kM + kS)) | (swizzled_s << kM) | (y & ((1 << kM) - 1)); - return swizzle_idx; + return swizzle_idx; } template int2 test_swizzle(int x, int y) { - Swizzle swizzle; - int idx = flatten(x, y, 1 << (kS + kM)); - int swizzled_idx = swizzle(idx); + Swizzle swizzle; + int idx = flatten(x, y, 1 << (kS + kM)); + int swizzled_idx = swizzle(idx); - int ref_swizzled_idx = swizzle_ref(x, y); + int ref_swizzled_idx = swizzle_ref(x, y); #ifdef DEBUG - printf("idx: %d, swizzled_idx: %d, ref_swizzled_idx: %d\n", idx, - swizzled_idx, ref_swizzled_idx); + printf("idx: %d, swizzled_idx: %d, ref_swizzled_idx: %d\n", idx, swizzled_idx, + ref_swizzled_idx); #endif - return make_int2(swizzled_idx, ref_swizzled_idx); + return make_int2(swizzled_idx, ref_swizzled_idx); } } // namespace TEST(TestSwizzle, test_swizzle_function) { - const int kB = 3; - const int kM = 3; - const int kS = 3; - - int2 swizzled_idx_0_0 = test_swizzle(0, 0); - int2 swizzled_idx_1_0 = test_swizzle(1, 0); - int2 swizzled_idx_1_4 = test_swizzle(1, 4); - int2 swizzled_idx_2_0 = test_swizzle(2, 0); - int2 swizzled_idx_2_4 = test_swizzle(2, 4); - - EXPECT_EQ(swizzled_idx_0_0.x, swizzled_idx_0_0.y); - EXPECT_EQ(swizzled_idx_1_0.x, swizzled_idx_1_0.y); - EXPECT_EQ(swizzled_idx_1_4.x, swizzled_idx_1_4.y); - EXPECT_EQ(swizzled_idx_2_0.x, swizzled_idx_2_0.y); - EXPECT_EQ(swizzled_idx_2_4.x, swizzled_idx_2_4.y); + const int kB = 3; + const int kM = 3; + const int kS = 3; + + int2 swizzled_idx_0_0 = test_swizzle(0, 0); + int2 swizzled_idx_1_0 = test_swizzle(1, 0); + int2 swizzled_idx_1_4 = test_swizzle(1, 4); + int2 swizzled_idx_2_0 = test_swizzle(2, 0); + int2 swizzled_idx_2_4 = test_swizzle(2, 4); + + EXPECT_EQ(swizzled_idx_0_0.x, swizzled_idx_0_0.y); + EXPECT_EQ(swizzled_idx_1_0.x, swizzled_idx_1_0.y); + EXPECT_EQ(swizzled_idx_1_4.x, swizzled_idx_1_4.y); + EXPECT_EQ(swizzled_idx_2_0.x, swizzled_idx_2_0.y); + EXPECT_EQ(swizzled_idx_2_4.x, swizzled_idx_2_4.y); } TEST(TestSwizzle, test_swizzled_row_major) { - using BlockRowMajor = tl::BlockRowMajor< - tl::RowMajor<16, 64>, - SwizzledLayout, Swizzle<3, 3, 3>>>; - - // for unit test - using Atom = - decltype(composition(cute::Swizzle<3, 3, 3>{}, - cute::Layout, Stride<_64, _1>>{})); - using CuteLayout = - decltype(tile_to_shape(Atom{}, Shape<_16, _64>{}, Step<_16, _1>{})); - - BlockRowMajor layout1; - CuteLayout layout2; - - for (int i = 0; i < int(size<0>(layout2)); ++i) { - for (int j = 0; j < int(size<1>(layout2)); ++j) { - EXPECT_EQ(layout1(i, j), layout2(i, j)); - } + using BlockRowMajor = + tl::BlockRowMajor, + SwizzledLayout, Swizzle<3, 3, 3>>>; + + // for unit test + using Atom = + decltype(composition(cute::Swizzle<3, 3, 3>{}, + cute::Layout, Stride<_64, _1>>{})); + using CuteLayout = + decltype(tile_to_shape(Atom{}, Shape<_16, _64>{}, Step<_16, _1>{})); + + BlockRowMajor layout1; + CuteLayout layout2; + + for (int i = 0; i < int(size<0>(layout2)); ++i) { + for (int j = 0; j < int(size<1>(layout2)); ++j) { + EXPECT_EQ(layout1(i, j), layout2(i, j)); } + } } TEST(TestSwizzle, test_swizzled_col_major) { - using BlockColMajor = tl::BlockColMajor< - tl::ColMajor<64, 16>, - SwizzledLayout, Swizzle<3, 3, 3>>>; + using BlockColMajor = + tl::BlockColMajor, + SwizzledLayout, Swizzle<3, 3, 3>>>; - using Atom = decltype(composition(cute::Swizzle<3, 3, 3>{}, - cute::Layout>{})); - using CuteLayout = decltype(tile_to_shape(Atom{}, Shape<_64, _16>{})); + using Atom = decltype(composition(cute::Swizzle<3, 3, 3>{}, + cute::Layout>{})); + using CuteLayout = decltype(tile_to_shape(Atom{}, Shape<_64, _16>{})); - BlockColMajor layout1; - CuteLayout layout2; + BlockColMajor layout1; + CuteLayout layout2; - for (int i = 0; i < int(size<0>(layout2)); ++i) { - for (int j = 0; j < int(size<1>(layout2)); ++j) { - EXPECT_EQ(layout1(i, j), layout2(i, j)); - } + for (int i = 0; i < int(size<0>(layout2)); ++i) { + for (int j = 0; j < int(size<1>(layout2)); ++j) { + EXPECT_EQ(layout1(i, j), layout2(i, j)); } + } } } // namespace tilefusion::testing diff --git a/tests/cpp/types/test_warp_base_tile_shape.cu b/tests/cpp/types/test_warp_base_tile_shape.cu index 0d039979..f1ee4232 100644 --- a/tests/cpp/types/test_warp_base_tile_shape.cu +++ b/tests/cpp/types/test_warp_base_tile_shape.cu @@ -8,182 +8,182 @@ namespace tilefusion::testing { namespace tl = tile_layout; TEST(InferAtomicWarpTile, test1_half_row_major) { - using DType = __half; - const tl::Layout kLayout = tl::Layout::kRowMajor; + using DType = __half; + const tl::Layout kLayout = tl::Layout::kRowMajor; - { // atomic warp shape: 32x8, thread layout: 32x1 - using WarpTile = WarpBaseTileShape, kLayout>; + { // atomic warp shape: 32x8, thread layout: 32x1 + using WarpTile = WarpBaseTileShape, kLayout>; - EXPECT_EQ(WarpTile::kRows, 32); - EXPECT_EQ(WarpTile::kCols, 8); + EXPECT_EQ(WarpTile::kRows, 32); + EXPECT_EQ(WarpTile::kCols, 8); - EXPECT_EQ(WarpTile::WarpThreadLayout::kRows, 32); - EXPECT_EQ(WarpTile::WarpThreadLayout::kCols, 1); - } + EXPECT_EQ(WarpTile::WarpThreadLayout::kRows, 32); + EXPECT_EQ(WarpTile::WarpThreadLayout::kCols, 1); + } - { // atomic warp shape: 16x16, thread layout: 16x2 - using WarpTile = WarpBaseTileShape, kLayout>; + { // atomic warp shape: 16x16, thread layout: 16x2 + using WarpTile = WarpBaseTileShape, kLayout>; - EXPECT_EQ(WarpTile::kRows, 16); - EXPECT_EQ(WarpTile::kCols, 16); + EXPECT_EQ(WarpTile::kRows, 16); + EXPECT_EQ(WarpTile::kCols, 16); - EXPECT_EQ(WarpTile::WarpThreadLayout::kRows, 16); - EXPECT_EQ(WarpTile::WarpThreadLayout::kCols, 2); - } + EXPECT_EQ(WarpTile::WarpThreadLayout::kRows, 16); + EXPECT_EQ(WarpTile::WarpThreadLayout::kCols, 2); + } - { // atomic warp shape: 8x32, thread layout: 8x4 - using WarpTile = WarpBaseTileShape, kLayout>; + { // atomic warp shape: 8x32, thread layout: 8x4 + using WarpTile = WarpBaseTileShape, kLayout>; - EXPECT_EQ(WarpTile::kRows, 8); - EXPECT_EQ(WarpTile::kCols, 32); + EXPECT_EQ(WarpTile::kRows, 8); + EXPECT_EQ(WarpTile::kCols, 32); - EXPECT_EQ(WarpTile::WarpThreadLayout::kRows, 8); - EXPECT_EQ(WarpTile::WarpThreadLayout::kCols, 4); - } + EXPECT_EQ(WarpTile::WarpThreadLayout::kRows, 8); + EXPECT_EQ(WarpTile::WarpThreadLayout::kCols, 4); + } - { // atomic warp shape: 4x64, thread layout: 4x8 - using WarpTile = WarpBaseTileShape, kLayout>; + { // atomic warp shape: 4x64, thread layout: 4x8 + using WarpTile = WarpBaseTileShape, kLayout>; - EXPECT_EQ(WarpTile::kRows, 4); - EXPECT_EQ(WarpTile::kCols, 64); + EXPECT_EQ(WarpTile::kRows, 4); + EXPECT_EQ(WarpTile::kCols, 64); - EXPECT_EQ(WarpTile::WarpThreadLayout::kRows, 4); - EXPECT_EQ(WarpTile::WarpThreadLayout::kCols, 8); - } + EXPECT_EQ(WarpTile::WarpThreadLayout::kRows, 4); + EXPECT_EQ(WarpTile::WarpThreadLayout::kCols, 8); + } } TEST(InferAtomicWarpTile, test2_half_column_major) { - using DType = __half; - const tl::Layout kLayout = tl::Layout::kColMajor; + using DType = __half; + const tl::Layout kLayout = tl::Layout::kColMajor; - { // atomic warp shape: 8x32, thread layout: 1x32 - using WarpTile = WarpBaseTileShape, kLayout>; + { // atomic warp shape: 8x32, thread layout: 1x32 + using WarpTile = WarpBaseTileShape, kLayout>; - EXPECT_EQ(WarpTile::kRows, 8); - EXPECT_EQ(WarpTile::kCols, 32); + EXPECT_EQ(WarpTile::kRows, 8); + EXPECT_EQ(WarpTile::kCols, 32); - EXPECT_EQ(WarpTile::WarpThreadLayout::kRows, 1); - EXPECT_EQ(WarpTile::WarpThreadLayout::kCols, 32); - } + EXPECT_EQ(WarpTile::WarpThreadLayout::kRows, 1); + EXPECT_EQ(WarpTile::WarpThreadLayout::kCols, 32); + } - { // atomic warp shape: 16x16, thread layout: 2x16 - using WarpTile = WarpBaseTileShape, kLayout>; + { // atomic warp shape: 16x16, thread layout: 2x16 + using WarpTile = WarpBaseTileShape, kLayout>; - EXPECT_EQ(WarpTile::kRows, 16); - EXPECT_EQ(WarpTile::kCols, 16); + EXPECT_EQ(WarpTile::kRows, 16); + EXPECT_EQ(WarpTile::kCols, 16); - EXPECT_EQ(WarpTile::WarpThreadLayout::kRows, 2); - EXPECT_EQ(WarpTile::WarpThreadLayout::kCols, 16); - } + EXPECT_EQ(WarpTile::WarpThreadLayout::kRows, 2); + EXPECT_EQ(WarpTile::WarpThreadLayout::kCols, 16); + } - { // atomic warp shape: 32x8, thread layout: 4x8 - using WarpTile = WarpBaseTileShape, kLayout>; + { // atomic warp shape: 32x8, thread layout: 4x8 + using WarpTile = WarpBaseTileShape, kLayout>; - EXPECT_EQ(WarpTile::kRows, 32); - EXPECT_EQ(WarpTile::kCols, 8); + EXPECT_EQ(WarpTile::kRows, 32); + EXPECT_EQ(WarpTile::kCols, 8); - EXPECT_EQ(WarpTile::WarpThreadLayout::kRows, 4); - EXPECT_EQ(WarpTile::WarpThreadLayout::kCols, 8); - } + EXPECT_EQ(WarpTile::WarpThreadLayout::kRows, 4); + EXPECT_EQ(WarpTile::WarpThreadLayout::kCols, 8); + } - { // atomic warp shape: 64x4, thread layout: 8x4 - using WarpTile = WarpBaseTileShape, kLayout>; + { // atomic warp shape: 64x4, thread layout: 8x4 + using WarpTile = WarpBaseTileShape, kLayout>; - EXPECT_EQ(WarpTile::kRows, 64); - EXPECT_EQ(WarpTile::kCols, 4); + EXPECT_EQ(WarpTile::kRows, 64); + EXPECT_EQ(WarpTile::kCols, 4); - EXPECT_EQ(WarpTile::WarpThreadLayout::kRows, 8); - EXPECT_EQ(WarpTile::WarpThreadLayout::kCols, 4); - } + EXPECT_EQ(WarpTile::WarpThreadLayout::kRows, 8); + EXPECT_EQ(WarpTile::WarpThreadLayout::kCols, 4); + } } TEST(InferAtomicWarpTile, test3_float_row_major) { - using DType = float; - const tl::Layout kLayout = tl::Layout::kRowMajor; + using DType = float; + const tl::Layout kLayout = tl::Layout::kRowMajor; - { // atomic warp shape: 32x4, thread layout: 32x1 - using WarpTile = WarpBaseTileShape, kLayout>; + { // atomic warp shape: 32x4, thread layout: 32x1 + using WarpTile = WarpBaseTileShape, kLayout>; - EXPECT_EQ(WarpTile::kRows, 32); - EXPECT_EQ(WarpTile::kCols, 4); + EXPECT_EQ(WarpTile::kRows, 32); + EXPECT_EQ(WarpTile::kCols, 4); - EXPECT_EQ(WarpTile::WarpThreadLayout::kRows, 32); - EXPECT_EQ(WarpTile::WarpThreadLayout::kCols, 1); - } + EXPECT_EQ(WarpTile::WarpThreadLayout::kRows, 32); + EXPECT_EQ(WarpTile::WarpThreadLayout::kCols, 1); + } - { // atomic warp shape: 16x8, thread layout: 16x2 - using WarpTile = WarpBaseTileShape, kLayout>; + { // atomic warp shape: 16x8, thread layout: 16x2 + using WarpTile = WarpBaseTileShape, kLayout>; - EXPECT_EQ(WarpTile::kRows, 16); - EXPECT_EQ(WarpTile::kCols, 8); + EXPECT_EQ(WarpTile::kRows, 16); + EXPECT_EQ(WarpTile::kCols, 8); - EXPECT_EQ(WarpTile::WarpThreadLayout::kRows, 16); - EXPECT_EQ(WarpTile::WarpThreadLayout::kCols, 2); - } + EXPECT_EQ(WarpTile::WarpThreadLayout::kRows, 16); + EXPECT_EQ(WarpTile::WarpThreadLayout::kCols, 2); + } - { // atomic warp shape: 8x16, thread layout: 8x4 - using WarpTile = WarpBaseTileShape, kLayout>; + { // atomic warp shape: 8x16, thread layout: 8x4 + using WarpTile = WarpBaseTileShape, kLayout>; - EXPECT_EQ(WarpTile::kRows, 8); - EXPECT_EQ(WarpTile::kCols, 16); + EXPECT_EQ(WarpTile::kRows, 8); + EXPECT_EQ(WarpTile::kCols, 16); - EXPECT_EQ(WarpTile::WarpThreadLayout::kRows, 8); - EXPECT_EQ(WarpTile::WarpThreadLayout::kCols, 4); - } + EXPECT_EQ(WarpTile::WarpThreadLayout::kRows, 8); + EXPECT_EQ(WarpTile::WarpThreadLayout::kCols, 4); + } - { // atomic warp shape: 4x32, thread layout: 4x8 - using WarpTile = WarpBaseTileShape, kLayout>; + { // atomic warp shape: 4x32, thread layout: 4x8 + using WarpTile = WarpBaseTileShape, kLayout>; - EXPECT_EQ(WarpTile::kRows, 4); - EXPECT_EQ(WarpTile::kCols, 32); + EXPECT_EQ(WarpTile::kRows, 4); + EXPECT_EQ(WarpTile::kCols, 32); - EXPECT_EQ(WarpTile::WarpThreadLayout::kRows, 4); - EXPECT_EQ(WarpTile::WarpThreadLayout::kCols, 8); - } + EXPECT_EQ(WarpTile::WarpThreadLayout::kRows, 4); + EXPECT_EQ(WarpTile::WarpThreadLayout::kCols, 8); + } } TEST(InferAtomicWarpTile, test4_float_column_major) { - using DType = float; - const tl::Layout kLayout = tl::Layout::kColMajor; + using DType = float; + const tl::Layout kLayout = tl::Layout::kColMajor; - { // atomic warp shape: 4x32, thread layout: 1x32 - using WarpTile = WarpBaseTileShape, kLayout>; + { // atomic warp shape: 4x32, thread layout: 1x32 + using WarpTile = WarpBaseTileShape, kLayout>; - EXPECT_EQ(WarpTile::kRows, 4); - EXPECT_EQ(WarpTile::kCols, 32); + EXPECT_EQ(WarpTile::kRows, 4); + EXPECT_EQ(WarpTile::kCols, 32); - EXPECT_EQ(WarpTile::WarpThreadLayout::kRows, 1); - EXPECT_EQ(WarpTile::WarpThreadLayout::kCols, 32); - } + EXPECT_EQ(WarpTile::WarpThreadLayout::kRows, 1); + EXPECT_EQ(WarpTile::WarpThreadLayout::kCols, 32); + } - { // atomic warp shape: 8x16, thread layout: 2x16 - using WarpTile = WarpBaseTileShape, kLayout>; + { // atomic warp shape: 8x16, thread layout: 2x16 + using WarpTile = WarpBaseTileShape, kLayout>; - EXPECT_EQ(WarpTile::kRows, 8); - EXPECT_EQ(WarpTile::kCols, 16); + EXPECT_EQ(WarpTile::kRows, 8); + EXPECT_EQ(WarpTile::kCols, 16); - EXPECT_EQ(WarpTile::WarpThreadLayout::kRows, 2); - EXPECT_EQ(WarpTile::WarpThreadLayout::kCols, 16); - } + EXPECT_EQ(WarpTile::WarpThreadLayout::kRows, 2); + EXPECT_EQ(WarpTile::WarpThreadLayout::kCols, 16); + } - { // atomic warp shape: 16x8, thread layout: 4x8 - using WarpTile = WarpBaseTileShape, kLayout>; + { // atomic warp shape: 16x8, thread layout: 4x8 + using WarpTile = WarpBaseTileShape, kLayout>; - EXPECT_EQ(WarpTile::kRows, 16); - EXPECT_EQ(WarpTile::kCols, 8); + EXPECT_EQ(WarpTile::kRows, 16); + EXPECT_EQ(WarpTile::kCols, 8); - EXPECT_EQ(WarpTile::WarpThreadLayout::kRows, 4); - EXPECT_EQ(WarpTile::WarpThreadLayout::kCols, 8); - } + EXPECT_EQ(WarpTile::WarpThreadLayout::kRows, 4); + EXPECT_EQ(WarpTile::WarpThreadLayout::kCols, 8); + } - { // atomic warp shape: 4x32, thread layout: 8x4 - using WarpTile = WarpBaseTileShape, kLayout>; + { // atomic warp shape: 4x32, thread layout: 8x4 + using WarpTile = WarpBaseTileShape, kLayout>; - EXPECT_EQ(WarpTile::kRows, 32); - EXPECT_EQ(WarpTile::kCols, 4); + EXPECT_EQ(WarpTile::kRows, 32); + EXPECT_EQ(WarpTile::kCols, 4); - EXPECT_EQ(WarpTile::WarpThreadLayout::kRows, 8); - EXPECT_EQ(WarpTile::WarpThreadLayout::kCols, 4); - } + EXPECT_EQ(WarpTile::WarpThreadLayout::kRows, 8); + EXPECT_EQ(WarpTile::WarpThreadLayout::kCols, 4); + } } } // namespace tilefusion::testing