Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions cmake/generic.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,38 @@ message(STATUS "tilefusion: CUDA detected: " ${CUDA_VERSION})
message(STATUS "tilefusion: CUDA nvcc is: " ${CUDA_NVCC_EXECUTABLE})
message(STATUS "tilefusion: CUDA toolkit directory: " ${CUDA_TOOLKIT_ROOT_DIR})

# Detect GPU architecture for FP8 support
if(${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "12"
OR (${CUDA_VERSION_MAJOR} VERSION_EQUAL "11" AND ${CUDA_VERSION_MINOR}
VERSION_GREATER_EQUAL "8"))

cuda_select_nvcc_arch_flags(FP8_ARCH_FLAGS Auto)

set(FP8_SUPPORT_DETECTED FALSE)
string(REGEX MATCHALL "compute_([0-9]+)" COMPUTE_ARCHS "${FP8_ARCH_FLAGS}")
foreach(compute_arch ${COMPUTE_ARCHS})
string(REGEX REPLACE "compute_([0-9]+)" "\\1" arch_num "${compute_arch}")
if(arch_num GREATER_EQUAL 89)
set(FP8_SUPPORT_DETECTED TRUE)
message(
STATUS "tilefusion: FP8-capable architecture detected: sm_${arch_num}")
break()
endif()
endforeach()

if(FP8_SUPPORT_DETECTED)
message(STATUS "tilefusion: FP8 hardware support detected - enabling FP8")
add_compile_definitions(CUDA_FP8_HARDWARE_AVAILABLE=1)
else()
message(STATUS "tilefusion: FP8 hardware support NOT detected")
add_compile_definitions(CUDA_FP8_HARDWARE_AVAILABLE=0)
endif()
else()
message(STATUS "tilefusion: CUDA version ${CUDA_VERSION} "
"does not support FP8 (requires 11.8+)")
add_compile_definitions(CUDA_FP8_HARDWARE_AVAILABLE=0)
endif()

if(ENABLE_DEBUG)
message(STATUS "tilefusion: Debug mode enabled")
set(CMAKE_BUILD_TYPE Debug)
Expand Down
59 changes: 51 additions & 8 deletions include/cell/compute/math_functor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

#pragma once

#include "cuda_utils.hpp"
#include "types/base.hpp"

namespace tilefusion::cell::compute {

Expand Down Expand Up @@ -146,14 +146,12 @@ struct Relu<__half> {
};
#endif

template <typename SrcType, typename DstType>
template <typename InType, typename OutType>
struct Convert {
DEVICE DstType operator()(SrcType a) const {
return static_cast<DstType>(a);
}
DEVICE OutType operator()(InType a) const { return OutType(a); }

DEVICE void operator()(const SrcType& src, DstType& dst) {
dst = static_cast<DstType>(src);
DEVICE void operator()(const InType& src, OutType& dst) {
dst = OutType(src);
}
};

Expand All @@ -176,6 +174,51 @@ struct Convert<__half, float> {
dst = __half2float(src);
}
};
#endif

#ifdef CUDA_FP8_AVAILABLE
// FP8 E4M3 conversions
template <>
struct Convert<float, __nv_fp8_e4m3> {
DEVICE __nv_fp8_e4m3 operator()(float a) const {
return from_float<__nv_fp8_e4m3>(a);
}

DEVICE void operator()(const float& src, __nv_fp8_e4m3& dst) {
dst = from_float<__nv_fp8_e4m3>(src);
}
};

template <>
struct Convert<__nv_fp8_e4m3, float> {
DEVICE float operator()(__nv_fp8_e4m3 a) const { return to_float(a); }

DEVICE void operator()(const __nv_fp8_e4m3& src, float& dst) {
dst = to_float(src);
}
};

// FP8 E5M2 conversions
template <>
struct Convert<float, __nv_fp8_e5m2> {
DEVICE __nv_fp8_e5m2 operator()(float a) const {
return from_float<__nv_fp8_e5m2>(a);
}

DEVICE void operator()(const float& src, __nv_fp8_e5m2& dst) {
dst = from_float<__nv_fp8_e5m2>(src);
}
};

template <>
struct Convert<__nv_fp8_e5m2, float> {
DEVICE float operator()(__nv_fp8_e5m2 a) const { return to_float(a); }

DEVICE void operator()(const __nv_fp8_e5m2& src, float& dst) {
dst = to_float(src);
}
};
#endif // CUDA_FP8_AVAILABLE

#endif // defined(__CUDA_ARCH__)

} // namespace tilefusion::cell::compute
4 changes: 2 additions & 2 deletions include/cell/copy/copy_atom.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
*/
#pragma once

#include "traits/base.hpp"
#include "types/base.hpp"
#include "types/layout.hpp"

namespace tilefusion::cell::copy::atom {
Expand Down Expand Up @@ -150,7 +150,7 @@ DEVICE void ld_shared_st_global<16>(void* dst, uint32_t src) {
} // namespace

template <typename Element>
requires traits::HalfType<Element>
requires HalfType<Element>
struct LoadMatBase {
using DType = Element;
using ThreadLayout = tile_layout::ColMajor<16, 2>;
Expand Down
25 changes: 8 additions & 17 deletions include/cell/copy/global_to_shared.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#pragma once

#include "cell/copy/mod.hpp"
#include "traits/base.hpp"
#include "types/mod.hpp"

namespace tilefusion::cell::copy {
Expand Down Expand Up @@ -77,11 +76,9 @@ struct GlobalToSharedLoaderImpl<Global_, Shared_, BaseShape_, kRowExec_,
}

private:
static constexpr int kNumPerAccess =
traits::AccessBase<DType>::kNumPerAccess;
static constexpr int kNumPerAccess = AccessBase<DType>::kNumPerAccess;

static constexpr int kAccessInBytes =
traits::AccessBase<DType>::kAccessInBytes;
static constexpr int kAccessInBytes = AccessBase<DType>::kAccessInBytes;

using SrcLayout = tl::MatrixLayout<kRowExec, kColExec,
BaseShape::kRows * Global::kRowStride,
Expand Down Expand Up @@ -168,11 +165,9 @@ struct GlobalToSharedLoaderImpl<Global_, Shared_, BaseShape_, kRowExec_,
}

private:
static constexpr int kNumPerAccess =
traits::AccessBase<DType>::kNumPerAccess;
static constexpr int kNumPerAccess = AccessBase<DType>::kNumPerAccess;

static constexpr int kAccessInBytes =
traits::AccessBase<DType>::kAccessInBytes;
static constexpr int kAccessInBytes = AccessBase<DType>::kAccessInBytes;

using SrcLayout = tl::MatrixLayout<kRowExec, kColExec, BaseShape::kRows,
BaseShape::kCols * Global::kColStride>;
Expand Down Expand Up @@ -257,8 +252,7 @@ struct SharedToGlobalStorerImpl<Shared_, Global_, BaseShape_, kRowExec_,
}

private:
static constexpr int kAccessInBytes =
traits::AccessBase<DType>::kAccessInBytes;
static constexpr int kAccessInBytes = AccessBase<DType>::kAccessInBytes;

using DstLayout = tl::MatrixLayout<kRowExec, kColExec,
BaseShape::kRows * Global::kRowStride,
Expand All @@ -271,8 +265,7 @@ struct SharedToGlobalStorerImpl<Shared_, Global_, BaseShape_, kRowExec_,
// consistent with those used in `SharedLayoutWrapper` within the
// register-to-shared storer.
static constexpr int kAccessInBits = 2 * int(sizeof(DType) * 8);
static constexpr int kNumPerAccess =
traits::AccessBase<DType>::kNumPerAccess;
static constexpr int kNumPerAccess = AccessBase<DType>::kNumPerAccess;

using GlobalLayout = tl::MatrixLayout<BaseShape::kRows, BaseShape::kCols,
Global::kRowStride, 1>;
Expand Down Expand Up @@ -342,8 +335,7 @@ struct SharedToGlobalStorerImpl<Shared_, Global_, BaseShape_, kRowExec_,
}

private:
static constexpr int kAccessInBytes =
traits::AccessBase<DType>::kAccessInBytes;
static constexpr int kAccessInBytes = AccessBase<DType>::kAccessInBytes;

using DstLayout = tl::MatrixLayout<kRowExec, kColExec, BaseShape::kRows,
BaseShape::kCols * Global::kColStride>;
Expand All @@ -355,8 +347,7 @@ struct SharedToGlobalStorerImpl<Shared_, Global_, BaseShape_, kRowExec_,
// consistent with those used in `SharedLayoutWrapper` within the
// register-to-shared storer.
static constexpr int kAccessInBits = 2 * int(sizeof(DType) * 8);
static constexpr int kNumPerAccess =
traits::AccessBase<DType>::kNumPerAccess;
static constexpr int kNumPerAccess = AccessBase<DType>::kNumPerAccess;

using GlobalLayout = tl::MatrixLayout<BaseShape::kRows, BaseShape::kCols, 1,
Global::kColStride>;
Expand Down
1 change: 0 additions & 1 deletion include/cell/copy/shared_to_register.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#pragma once

#include "cell/copy/mod.hpp"
#include "traits/base.hpp"
#include "types/mod.hpp"

namespace tilefusion::cell::copy {
Expand Down
1 change: 0 additions & 1 deletion include/cell/mod.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,4 @@
#include "cell/mask.hpp"
#include "cell/pipeline.hpp"
#include "cell/warp.hpp"
#include "traits/base.hpp"
#include "types/mod.hpp"
7 changes: 7 additions & 0 deletions include/config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,10 @@
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
#define CP_ASYNC_SM80_ENABLED
#endif

// FP8 support requires CUDA 11.8+ AND Ada Lovelace (8.9+) or Hopper (9.0+)
// architecture. The hardware support is detected at build time by CMake.
#if defined(CUDA_FP8_HARDWARE_AVAILABLE) && CUDA_FP8_HARDWARE_AVAILABLE == 1
#include <cuda_fp8.h>
#define CUDA_FP8_AVAILABLE 1
#endif
2 changes: 1 addition & 1 deletion include/jit/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

#pragma once

#include "traits/base.hpp"
#include "types/base.hpp"

namespace tilefusion::jit {
template <typename DType>
Expand Down
2 changes: 1 addition & 1 deletion include/kernels/flash_attention_device.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ struct FlashAttentionTraits {

static constexpr int kSharedAccess = 64;

using BaseShape = traits::BaseTileShape<InType>;
using BaseShape = BaseTileShape<InType>;

static constexpr int kM = dim_size<0, WholeShape>; // query length
static constexpr int kN = dim_size<1, WholeShape>; // key/value length
Expand Down
2 changes: 1 addition & 1 deletion include/kernels/fused_two_gemms_device.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ struct FusedTwoGemmsTraits {
using SharedA = SharedTile<InType, tl::RowMajor<kTM, kTK>, kUseSwizzling,
kSharedAccess>;

using BaseShape = traits::BaseTileShape<InType>;
using BaseShape = BaseTileShape<InType>;
static constexpr int kAMs = kTM / kWarpPerRow / BaseShape::kRows;
static constexpr int kAKs = kTK / BaseShape::kCols;
using RegA = RegTile<BaseTileRowMajor<InType>, tl::RowMajor<kAMs, kAKs>>;
Expand Down
2 changes: 1 addition & 1 deletion include/kernels/gemm_device.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ template <typename InType_, typename AccType_, typename WarpLayout, //
struct KeGemmTraits {
using InType = InType_;
using AccType = AccType_;
using BaseShape = traits::BaseTileShape<InType>;
using BaseShape = BaseTileShape<InType>;
static constexpr int kNumStages = kNumStages_;

static constexpr int kThreads = tl::get_numel<WarpLayout> * 32;
Expand Down
51 changes: 0 additions & 51 deletions include/traits/base.hpp

This file was deleted.

Loading