Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .clang-format
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
BasedOnStyle: Google
UseTab: Never
ColumnLimit: 80
IndentWidth: 4
IndentWidth: 2

AccessModifierOffset: -2
AccessModifierOffset: -1

DerivePointerAlignment: false
PointerAlignment: Left
Expand Down
60 changes: 30 additions & 30 deletions benchmarks/cpp/flashattention/convert.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,58 +16,58 @@ using namespace cute;

/**
 * @brief Convert the element type of a tensor, keeping its layout.
 *
 * The source element type is `Engine::value_type`. The number of elements is
 * read from `size(tensor)` at compile time, so the tensor must have a static
 * size. The storage behind `tensor.data()` is reinterpreted as a single
 * `cutlass::Array<From_type, numel>` and converted in one call to
 * `cutlass::NumericArrayConverter`.
 *
 * NOTE(review): the reinterpret_cast presumably requires the `numel` elements
 * to be stored contiguously starting at `tensor.data()` -- confirm for every
 * caller.
 * NOTE(review): the returned tensor is built from the address of the local
 * `frag`; this looks like it relies on the function being inlined into the
 * caller -- confirm before using the result across a call boundary.
 *
 * @tparam To_type Destination element type.
 * @param tensor Tensor whose elements are converted.
 * @return A tensor of `To_type` with `tensor.layout()`.
 */
template <typename To_type, typename Engine, typename Layout>
CUTE_DEVICE auto convert_type(cute::Tensor<Engine, Layout> const& tensor) {
  using From_type = typename Engine::value_type;
  constexpr int numel = decltype(size(tensor))::value;
  cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
  auto frag =
      convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel>*>(
          tensor.data()));
  return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}

/**
 * @brief Regroup the modes of a two-mode (row, column) layout into a
 *        three-mode layout.
 *
 * Compile-time requirements: `size<0, 0>(rowcol_layout)` and
 * `size<1, 0>(rowcol_layout)` must both equal 2.
 *
 * The second sub-mode of the column mode is divided by 2. The result is
 * assembled as
 *   ((col<0>, row<0>, col<1>-divided<0>), row<1>, col<1>-divided<1>).
 *
 * NOTE(review): presumably this maps ((2, M), (2, N)) to ((2, 2, 2), M, N / 2)
 * for use as an MMA A-operand register layout -- confirm against the callers.
 *
 * @param rowcol_layout Layout whose two top-level modes each start with a
 *                      size-2 sub-mode.
 * @return The regrouped layout.
 */
template <typename Layout>
DEVICE auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
  using namespace cute;
  static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
  static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
  // Leave mode 0 and sub-mode (1, 0) untouched; split sub-mode (1, 1) by 2.
  auto l = logical_divide(rowcol_layout,
                          Shape<Underscore, Shape<Underscore, Int<2>>>{});

  return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)),
                                 get<0>(get<1>(get<1>(l)))),
                     get<1>(get<0>(l)), get<1>(get<1>(get<1>(l))));
}

/**
 * @brief Build a fixed layout by regrouping the shape ((2, 2), 2, 16).
 *
 * A default-strided layout of shape ((2, 2), 2, 16) is created and its last
 * mode is divided by 2. The result is assembled as
 *   ((mode0<0>, mode0<1>, mode2-divided<0>), mode1, mode2-divided<1>).
 *
 * NOTE(review): the resulting shape should be ((2, 2, 2), 2, 8); the extents
 * are hard-coded here -- confirm they match the accumulator tile in use.
 *
 * @return The regrouped layout.
 */
DEVICE auto convert_layout_C_Aregs() {
  using namespace cute;
  auto layout_s = Layout<Shape<Shape<_2, _2>, _2, _16>>{};
  // Only the third mode is split; the first two are passed through.
  auto l = logical_divide(layout_s, Shape<Underscore, Underscore, _2>{});

  return make_layout(
      make_layout(get<0>(get<0>(l)), get<1>(get<0>(l)), get<0>(get<2>(l))),
      get<1>(l), get<1>(get<2>(l)));
}

/**
 * @brief Convert a 3d register tensor layout into a 2d register tensor
 *        layout.
 *
 * Compile-time requirements: `layout_s` must have rank 3 and its first mode
 * must have size 4.
 *
 * The first mode is divided by 2, giving two sub-modes. The result pairs the
 * second sub-mode with mode 1 and the first sub-mode with mode 2:
 *   ((mode0<1>, mode1), (mode0<0>, mode2)).
 *
 * NOTE(review): presumably this maps (4, MMA_M, MMA_N) to
 * ((2, MMA_M), (2, MMA_N)), i.e. (rows, columns) -- confirm against callers.
 *
 * @param layout_s Rank-3 layout whose first mode has size 4.
 * @return The two-mode layout described above.
 */
template <class LayoutType>
DEVICE auto convert_layout_scores(LayoutType layout_s) {
  using namespace cute;
  static_assert(decltype(size<0>(layout_s))::value == 4);
  static_assert(decltype(rank(layout_s))::value == 3);

  // Split the size-4 leading mode into (2, 2).
  auto l = logical_divide(layout_s, Shape<_2>{});
  return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)),
                     make_layout(get<0>(get<0>(l)), get<2>(l)));
}

/**
 * @brief Regroup a layout by splitting its second mode into chunks of
 *        `ATOMNUM`.
 *
 * Mode 1 of `layout_s` is divided by `ATOMNUM`. The result is assembled as
 *   (mode1-divided<0>, mode0, mode1-divided<1>).
 *
 * NOTE(review): presumably `ATOMNUM` is the number of elements handled per
 * copy atom and the result is the view handed to a tiled copy -- confirm
 * against the callers.
 *
 * @tparam ATOMNUM Divisor applied to the second mode.
 * @param layout_s Layout with at least two modes.
 * @return The three-mode layout described above.
 */
template <int ATOMNUM, class LayoutType>
DEVICE auto convert_layout_scores_copyview(LayoutType layout_s) {
  using namespace cute;

  // Mode 0 is passed through; mode 1 is split into (ATOMNUM, rest).
  auto l = logical_divide(layout_s, Shape<Underscore, Int<ATOMNUM>>{});
  return make_layout(get<0>(get<1>(l)), get<0>(l), get<1>(get<1>(l)));
}

} // namespace cutlass_wrapper
Expand Down
Loading