Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions backends/aoti/slim/c10/core/ScalarType.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ enum class ScalarType : int8_t {
Short = 2, // int16_t
Int = 3, // int32_t
Long = 4, // int64_t
// Half = 5, // float16 - not currently needed
Half = 5, // float16
Float = 6, // float
// Double = 7, // double - not currently needed
// ComplexHalf = 8,
Expand All @@ -48,6 +48,7 @@ constexpr ScalarType kChar = ScalarType::Char;
constexpr ScalarType kShort = ScalarType::Short;
constexpr ScalarType kInt = ScalarType::Int;
constexpr ScalarType kLong = ScalarType::Long;
constexpr ScalarType kHalf = ScalarType::Half;
constexpr ScalarType kFloat = ScalarType::Float;
constexpr ScalarType kBool = ScalarType::Bool;
constexpr ScalarType kBFloat16 = ScalarType::BFloat16;
Expand All @@ -67,6 +68,8 @@ inline size_t elementSize(ScalarType t) {
return sizeof(int32_t);
case ScalarType::Long:
return sizeof(int64_t);
case ScalarType::Half:
return 2; // sizeof(__half) = 2 bytes
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove comment?

case ScalarType::Float:
return sizeof(float);
case ScalarType::Bool:
Expand All @@ -93,6 +96,8 @@ inline const char* toString(ScalarType t) {
return "Int";
case ScalarType::Long:
return "Long";
case ScalarType::Half:
return "Half";
case ScalarType::Float:
return "Float";
case ScalarType::Bool:
Expand All @@ -110,7 +115,8 @@ inline const char* toString(ScalarType t) {
/// @param t The scalar type to check.
/// @return true if the scalar type is floating point, false otherwise.
inline bool isFloatingType(ScalarType t) {
return t == ScalarType::Float || t == ScalarType::BFloat16;
return t == ScalarType::Half || t == ScalarType::Float ||
t == ScalarType::BFloat16;
}

/// Checks if the scalar type is an integral type (including bool optionally).
Expand Down Expand Up @@ -149,6 +155,7 @@ inline bool isValidScalarType(ScalarType t) {
case ScalarType::Short:
case ScalarType::Int:
case ScalarType::Long:
case ScalarType::Half:
case ScalarType::Float:
case ScalarType::Bool:
case ScalarType::BFloat16:
Expand Down
35 changes: 35 additions & 0 deletions backends/aoti/slim/c10/core/test/test_scalar_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ const std::vector<ScalarTypeTestData> kAllScalarTypes = {
{ScalarType::Short, 2, 2, "Short", false, true, true, false},
{ScalarType::Int, 3, 4, "Int", false, true, true, false},
{ScalarType::Long, 4, 8, "Long", false, true, true, false},
{ScalarType::Half, 5, 2, "Half", true, false, false, false},
{ScalarType::Float, 6, 4, "Float", true, false, false, false},
Comment on lines 36 to 40
{ScalarType::Bool, 11, 1, "Bool", false, false, true, true},
{ScalarType::BFloat16, 15, 2, "BFloat16", true, false, false, false},
Expand Down Expand Up @@ -128,6 +129,10 @@ TEST_F(ScalarTypeConstantsTest, KLongConstant) {
EXPECT_EQ(kLong, ScalarType::Long);
}

TEST_F(ScalarTypeConstantsTest, KHalfConstant) {
EXPECT_EQ(kHalf, ScalarType::Half);
}

TEST_F(ScalarTypeConstantsTest, KFloatConstant) {
EXPECT_EQ(kFloat, ScalarType::Float);
}
Expand Down Expand Up @@ -185,6 +190,10 @@ TEST_F(ElementSizeConsistencyTest, LongMatchesSizeofInt64) {
EXPECT_EQ(elementSize(ScalarType::Long), sizeof(int64_t));
}

TEST_F(ElementSizeConsistencyTest, HalfIs2Bytes) {
EXPECT_EQ(elementSize(ScalarType::Half), 2);
}

TEST_F(ElementSizeConsistencyTest, FloatMatchesSizeofFloat) {
EXPECT_EQ(elementSize(ScalarType::Float), sizeof(float));
}
Expand All @@ -196,3 +205,29 @@ TEST_F(ElementSizeConsistencyTest, BoolMatchesSizeofBool) {
TEST_F(ElementSizeConsistencyTest, BFloat16MatchesSizeofBFloat16) {
EXPECT_EQ(elementSize(ScalarType::BFloat16), sizeof(BFloat16));
}

// =============================================================================
// isValidScalarType Tests
// =============================================================================

class IsValidScalarTypeTest : public ::testing::Test {};

TEST_F(IsValidScalarTypeTest, HalfIsValid) {
EXPECT_TRUE(isValidScalarType(ScalarType::Half));
}
Comment on lines +215 to +217
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can remove


TEST_F(IsValidScalarTypeTest, AllSupportedTypesAreValid) {
EXPECT_TRUE(isValidScalarType(ScalarType::Byte));
EXPECT_TRUE(isValidScalarType(ScalarType::Char));
EXPECT_TRUE(isValidScalarType(ScalarType::Short));
EXPECT_TRUE(isValidScalarType(ScalarType::Int));
EXPECT_TRUE(isValidScalarType(ScalarType::Long));
EXPECT_TRUE(isValidScalarType(ScalarType::Half));
EXPECT_TRUE(isValidScalarType(ScalarType::Float));
EXPECT_TRUE(isValidScalarType(ScalarType::Bool));
EXPECT_TRUE(isValidScalarType(ScalarType::BFloat16));
}

TEST_F(IsValidScalarTypeTest, UndefinedIsNotValid) {
EXPECT_FALSE(isValidScalarType(ScalarType::Undefined));
}
8 changes: 4 additions & 4 deletions backends/cuda/runtime/shims/sort.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ namespace executorch::backends::cuda {

namespace c10_slim = executorch::backends::aoti::slim::c10;

// PyTorch ScalarType::Half = 5, not defined in slim ScalarType enum.
constexpr auto kHalf = static_cast<c10_slim::ScalarType>(5);
// PyTorch ScalarType::Half = 5, now defined in slim ScalarType enum.
using c10_slim::kHalf;

Comment on lines +27 to 29
Comment on lines +27 to 29
namespace {

Expand Down Expand Up @@ -188,7 +188,7 @@ AOTITorchError aoti_torch_cuda_sort_stable(
case c10_slim::ScalarType::BFloat16:
elem_size = sizeof(__nv_bfloat16);
break;
case kHalf:
case c10_slim::ScalarType::Half:
elem_size = sizeof(__half);
break;
default:
Expand Down Expand Up @@ -387,7 +387,7 @@ AOTITorchError aoti_torch_cuda_sort_stable(
stream);
break;
}
case kHalf: {
case c10_slim::ScalarType::Half: {
sort_slice_impl(
static_cast<__half*>(values_base) + offset,
idx_ptr,
Expand Down
Loading