diff --git a/backends/aoti/slim/c10/core/ScalarType.h b/backends/aoti/slim/c10/core/ScalarType.h index c1499a83f39..9a99aecf992 100644 --- a/backends/aoti/slim/c10/core/ScalarType.h +++ b/backends/aoti/slim/c10/core/ScalarType.h @@ -28,7 +28,7 @@ enum class ScalarType : int8_t { Short = 2, // int16_t Int = 3, // int32_t Long = 4, // int64_t - // Half = 5, // float16 - not currently needed + Half = 5, // float16 Float = 6, // float // Double = 7, // double - not currently needed // ComplexHalf = 8, @@ -48,6 +48,7 @@ constexpr ScalarType kChar = ScalarType::Char; constexpr ScalarType kShort = ScalarType::Short; constexpr ScalarType kInt = ScalarType::Int; constexpr ScalarType kLong = ScalarType::Long; +constexpr ScalarType kHalf = ScalarType::Half; constexpr ScalarType kFloat = ScalarType::Float; constexpr ScalarType kBool = ScalarType::Bool; constexpr ScalarType kBFloat16 = ScalarType::BFloat16; @@ -67,6 +68,8 @@ inline size_t elementSize(ScalarType t) { return sizeof(int32_t); case ScalarType::Long: return sizeof(int64_t); + case ScalarType::Half: + return 2; // sizeof(__half) = 2 bytes case ScalarType::Float: return sizeof(float); case ScalarType::Bool: @@ -93,6 +96,8 @@ inline const char* toString(ScalarType t) { return "Int"; case ScalarType::Long: return "Long"; + case ScalarType::Half: + return "Half"; case ScalarType::Float: return "Float"; case ScalarType::Bool: @@ -110,7 +115,8 @@ inline const char* toString(ScalarType t) { /// @param t The scalar type to check. /// @return true if the scalar type is floating point, false otherwise. inline bool isFloatingType(ScalarType t) { - return t == ScalarType::Float || t == ScalarType::BFloat16; + return t == ScalarType::Half || t == ScalarType::Float || + t == ScalarType::BFloat16; } /// Checks if the scalar type is an integral type (including bool optionally). @@ -149,6 +155,7 @@ inline bool isValidScalarType(ScalarType t) { case ScalarType::Short: case ScalarType::Int: case ScalarType::Long: + case ScalarType::Half: case ScalarType::Float: case ScalarType::Bool: case ScalarType::BFloat16: diff --git a/backends/aoti/slim/c10/core/test/test_scalar_type.cpp b/backends/aoti/slim/c10/core/test/test_scalar_type.cpp index 332f5d7d264..4c06f7ef101 100644 --- a/backends/aoti/slim/c10/core/test/test_scalar_type.cpp +++ b/backends/aoti/slim/c10/core/test/test_scalar_type.cpp @@ -36,6 +36,7 @@ const std::vector kAllScalarTypes = { {ScalarType::Short, 2, 2, "Short", false, true, true, false}, {ScalarType::Int, 3, 4, "Int", false, true, true, false}, {ScalarType::Long, 4, 8, "Long", false, true, true, false}, + {ScalarType::Half, 5, 2, "Half", true, false, false, false}, {ScalarType::Float, 6, 4, "Float", true, false, false, false}, {ScalarType::Bool, 11, 1, "Bool", false, false, true, true}, {ScalarType::BFloat16, 15, 2, "BFloat16", true, false, false, false}, @@ -128,6 +129,10 @@ TEST_F(ScalarTypeConstantsTest, KLongConstant) { EXPECT_EQ(kLong, ScalarType::Long); } +TEST_F(ScalarTypeConstantsTest, KHalfConstant) { + EXPECT_EQ(kHalf, ScalarType::Half); +} + TEST_F(ScalarTypeConstantsTest, KFloatConstant) { EXPECT_EQ(kFloat, ScalarType::Float); } @@ -185,6 +190,10 @@ TEST_F(ElementSizeConsistencyTest, LongMatchesSizeofInt64) { EXPECT_EQ(elementSize(ScalarType::Long), sizeof(int64_t)); } +TEST_F(ElementSizeConsistencyTest, HalfIs2Bytes) { + EXPECT_EQ(elementSize(ScalarType::Half), 2); +} + TEST_F(ElementSizeConsistencyTest, FloatMatchesSizeofFloat) { EXPECT_EQ(elementSize(ScalarType::Float), sizeof(float)); } @@ -196,3 +205,29 @@ TEST_F(ElementSizeConsistencyTest, BoolMatchesSizeofBool) { TEST_F(ElementSizeConsistencyTest, BFloat16MatchesSizeofBFloat16) { EXPECT_EQ(elementSize(ScalarType::BFloat16), sizeof(BFloat16)); } + +// ============================================================================= +// isValidScalarType Tests +// ============================================================================= + +class IsValidScalarTypeTest : public ::testing::Test {}; + +TEST_F(IsValidScalarTypeTest, HalfIsValid) { + EXPECT_TRUE(isValidScalarType(ScalarType::Half)); +} + +TEST_F(IsValidScalarTypeTest, AllSupportedTypesAreValid) { + EXPECT_TRUE(isValidScalarType(ScalarType::Byte)); + EXPECT_TRUE(isValidScalarType(ScalarType::Char)); + EXPECT_TRUE(isValidScalarType(ScalarType::Short)); + EXPECT_TRUE(isValidScalarType(ScalarType::Int)); + EXPECT_TRUE(isValidScalarType(ScalarType::Long)); + EXPECT_TRUE(isValidScalarType(ScalarType::Half)); + EXPECT_TRUE(isValidScalarType(ScalarType::Float)); + EXPECT_TRUE(isValidScalarType(ScalarType::Bool)); + EXPECT_TRUE(isValidScalarType(ScalarType::BFloat16)); +} + +TEST_F(IsValidScalarTypeTest, UndefinedIsNotValid) { + EXPECT_FALSE(isValidScalarType(ScalarType::Undefined)); +} diff --git a/backends/cuda/runtime/shims/sort.cu b/backends/cuda/runtime/shims/sort.cu index 804b5a55959..8d4a9771e62 100644 --- a/backends/cuda/runtime/shims/sort.cu +++ b/backends/cuda/runtime/shims/sort.cu @@ -24,8 +24,8 @@ namespace executorch::backends::cuda { namespace c10_slim = executorch::backends::aoti::slim::c10; -// PyTorch ScalarType::Half = 5, not defined in slim ScalarType enum. -constexpr auto kHalf = static_cast(5); +// PyTorch ScalarType::Half = 5, now defined in slim ScalarType enum. +using c10_slim::kHalf; namespace { @@ -188,7 +188,7 @@ AOTITorchError aoti_torch_cuda_sort_stable( case c10_slim::ScalarType::BFloat16: elem_size = sizeof(__nv_bfloat16); break; - case kHalf: + case c10_slim::ScalarType::Half: elem_size = sizeof(__half); break; default: @@ -387,7 +387,7 @@ AOTITorchError aoti_torch_cuda_sort_stable( stream); break; } - case kHalf: { + case c10_slim::ScalarType::Half: { sort_slice_impl( static_cast<__half*>(values_base) + offset, idx_ptr,