From d39c77c59639e638860d8b7970838a7cdf43e2de Mon Sep 17 00:00:00 2001 From: youge325 Date: Fri, 27 Mar 2026 12:36:30 +0800 Subject: [PATCH 1/4] revert useless ScalarType first --- .../api/include/compat/c10/core/ScalarType.h | 115 ++---------------- test/cpp/compat/c10_ScalarType_test.cc | 76 ------------ 2 files changed, 8 insertions(+), 183 deletions(-) diff --git a/paddle/phi/api/include/compat/c10/core/ScalarType.h b/paddle/phi/api/include/compat/c10/core/ScalarType.h index 5495e655040dc..0fe99f5d1a0f0 100644 --- a/paddle/phi/api/include/compat/c10/core/ScalarType.h +++ b/paddle/phi/api/include/compat/c10/core/ScalarType.h @@ -130,58 +130,15 @@ struct dummy_int1_7_t {}; _(uint32_t, UINT32, UInt32) enum class PADDLE_API ScalarType : int8_t { - Byte = 0, - Char = 1, - Short = 2, - Int = 3, - Long = 4, - Half = 5, - Float = 6, - Double = 7, - ComplexHalf = 8, - ComplexFloat = 9, - ComplexDouble = 10, - Bool = 11, - QInt8 = 12, - QUInt8 = 13, - QInt32 = 14, - BFloat16 = 15, - QUInt4x2 = 16, - QUInt2x4 = 17, - Bits1x8 = 18, - Bits2x4 = 19, - Bits4x2 = 20, - Bits8 = 21, - Bits16 = 22, - Float8_e5m2 = 23, - Float8_e4m3fn = 24, - Float8_e5m2fnuz = 25, - Float8_e4m3fnuz = 26, - UInt16 = 27, - UInt32 = 28, - UInt64 = 29, - UInt1 = 30, - UInt2 = 31, - UInt3 = 32, - UInt4 = 33, - UInt5 = 34, - UInt6 = 35, - UInt7 = 36, - Int1 = 37, - Int2 = 38, - Int3 = 39, - Int4 = 40, - Int5 = 41, - Int6 = 42, - Int7 = 43, - Float8_e8m0fnu = 44, - Float4_e2m1fn_x2 = 45, - Undefined = 46, - NumOptions = 47 +#define DEFINE_ST_ENUM_VAL_(_1, _2, n) n, + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ST_ENUM_VAL_) +#undef DEFINE_ENUM_ST_ENUM_VAL_ +#define DEFINE_ST_ENUM_VAL_FOR_QINTS_(_1, n) n, + AT_FORALL_QINT_TYPES(DEFINE_ST_ENUM_VAL_FOR_QINTS_) +#undef DEFINE_ST_ENUM_VAL_FOR_QINTS_ + Undefined, + NumOptions }; - -constexpr uint16_t NumScalarTypes = - static_cast(ScalarType::NumOptions); namespace impl { // These are used to map ScalarTypes to C++ types. @@ -281,38 +238,6 @@ inline const char* toString(ScalarType t) { switch (t) { AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE) - case ScalarType::QInt8: - return "QInt8"; - case ScalarType::QUInt8: - return "QUInt8"; - case ScalarType::QInt32: - return "QInt32"; - case ScalarType::QUInt4x2: - return "QUInt4x2"; - case ScalarType::QUInt2x4: - return "QUInt2x4"; - case ScalarType::ComplexHalf: - return "ComplexHalf"; - case ScalarType::Bits1x8: - return "Bits1x8"; - case ScalarType::Bits2x4: - return "Bits2x4"; - case ScalarType::Bits4x2: - return "Bits4x2"; - case ScalarType::Bits8: - return "Bits8"; - case ScalarType::Bits16: - return "Bits16"; - case ScalarType::Float8_e5m2fnuz: - return "Float8_e5m2fnuz"; - case ScalarType::Float8_e4m3fnuz: - return "Float8_e4m3fnuz"; - case ScalarType::Float8_e8m0fnu: - return "Float8_e8m0fnu"; - case ScalarType::Float4_e2m1fn_x2: - return "Float4_e2m1fn_x2"; - case ScalarType::Undefined: - return "Undefined"; default: return "UNKNOWN_SCALAR"; } @@ -326,18 +251,6 @@ inline size_t elementSize(ScalarType t) { switch (t) { AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CASE_ELEMENTSIZE_CASE) - case ScalarType::QInt8: - case ScalarType::QUInt8: - case ScalarType::QUInt4x2: - case ScalarType::QUInt2x4: - case ScalarType::Bits1x8: - case ScalarType::Bits2x4: - case ScalarType::Bits4x2: - case ScalarType::Bits8: - return 1; - case ScalarType::QInt32: - case ScalarType::Bits16: - return 4; default: TORCH_CHECK(false, "Unknown ScalarType"); } @@ -410,7 +323,6 @@ inline bool isSignedType(ScalarType t) { // Complex types (treated as signed) case ScalarType::ComplexFloat: case ScalarType::ComplexDouble: - case ScalarType::ComplexHalf: return true; // Signed quantized types (explicitly return true) @@ -438,22 +350,11 @@ inline bool isSignedType(ScalarType t) { case ScalarType::QUInt8: case ScalarType::QUInt4x2: case ScalarType::QUInt2x4: - case ScalarType::Bits1x8: - case ScalarType::Bits2x4: - case ScalarType::Bits4x2: - case ScalarType::Bits8: - case ScalarType::Bits16: return false; // Bool is unsigned (using numeric_limits) CASE_ISSIGNED(Bool); - case ScalarType::Float8_e5m2fnuz: - case ScalarType::Float8_e4m3fnuz: - case ScalarType::Float8_e8m0fnu: - case ScalarType::Float4_e2m1fn_x2: - return true; - // Invalid/undefined types - should not happen in normal usage // If this is hit, it indicates a programming error or unsupported type case ScalarType::Undefined: diff --git a/test/cpp/compat/c10_ScalarType_test.cc b/test/cpp/compat/c10_ScalarType_test.cc index ea0895412f72d..a373ea8184104 100644 --- a/test/cpp/compat/c10_ScalarType_test.cc +++ b/test/cpp/compat/c10_ScalarType_test.cc @@ -25,7 +25,6 @@ #include #include #endif -#include #include "ATen/ATen.h" #include "gtest/gtest.h" #include "paddle/phi/common/float16.h" @@ -91,78 +90,3 @@ TEST(TensorBaseTest, TypeCheckingAPIs) { ASSERT_FALSE(uint8_tensor.is_signed()); ASSERT_FALSE(bool_tensor.is_signed()); } - -TEST(ScalarTypeCompatTest, ScalarTypeUtilityBranches) { - EXPECT_STREQ(c10::toString(c10::ScalarType::Bits1x8), "Bits1x8"); - EXPECT_STREQ(c10::toString(c10::ScalarType::Bits16), "Bits16"); - EXPECT_STREQ(c10::toString(c10::ScalarType::Float8_e5m2fnuz), - "Float8_e5m2fnuz"); - EXPECT_STREQ(c10::toString(c10::ScalarType::Float8_e4m3fnuz), - "Float8_e4m3fnuz"); - EXPECT_STREQ(c10::toString(c10::ScalarType::Float8_e8m0fnu), - "Float8_e8m0fnu"); - EXPECT_STREQ(c10::toString(c10::ScalarType::Float4_e2m1fn_x2), - "Float4_e2m1fn_x2"); - EXPECT_STREQ(c10::toString(c10::ScalarType::Undefined), "Undefined"); - EXPECT_STREQ(c10::toString(static_cast(-1)), - "UNKNOWN_SCALAR"); - - EXPECT_EQ(c10::elementSize(c10::ScalarType::QInt8), static_cast(1)); - EXPECT_EQ(c10::elementSize(c10::ScalarType::QUInt4x2), - static_cast(1)); - EXPECT_EQ(c10::elementSize(c10::ScalarType::QInt32), static_cast(4)); - EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits16), static_cast(4)); - EXPECT_THROW(c10::elementSize(c10::ScalarType::Undefined), ::std::exception); - - EXPECT_TRUE(c10::isIntegralType(c10::ScalarType::Bool, true)); - EXPECT_FALSE(c10::isIntegralType(c10::ScalarType::Bool, false)); - EXPECT_TRUE(c10::isFloat8Type(c10::ScalarType::Float8_e5m2)); - EXPECT_FALSE(c10::isFloat8Type(c10::ScalarType::Float8_e4m3fnuz)); - EXPECT_TRUE(c10::isReducedFloatingType(c10::ScalarType::BFloat16)); - EXPECT_TRUE(c10::isFloatingType(c10::ScalarType::Float)); - EXPECT_FALSE(c10::isComplexType(c10::ScalarType::ComplexHalf)); - - EXPECT_TRUE(c10::isSignedType(c10::ScalarType::Int1)); - EXPECT_FALSE(c10::isSignedType(c10::ScalarType::UInt3)); - EXPECT_FALSE(c10::isSignedType(c10::ScalarType::QUInt8)); - EXPECT_TRUE(c10::isSignedType(c10::ScalarType::Float8_e5m2fnuz)); - EXPECT_THROW(c10::isSignedType(c10::ScalarType::Undefined), ::std::exception); - - std::ostringstream oss; - oss << c10::ScalarType::UInt7; - EXPECT_EQ(oss.str(), "UInt7"); -} - -TEST(ScalarTypeCompatTest, AdditionalEnumAndPredicateBranches) { - EXPECT_STREQ(c10::toString(c10::ScalarType::QInt8), "QInt8"); - EXPECT_STREQ(c10::toString(c10::ScalarType::QUInt8), "QUInt8"); - EXPECT_STREQ(c10::toString(c10::ScalarType::QInt32), "QInt32"); - EXPECT_STREQ(c10::toString(c10::ScalarType::QUInt4x2), "QUInt4x2"); - EXPECT_STREQ(c10::toString(c10::ScalarType::QUInt2x4), "QUInt2x4"); - EXPECT_STREQ(c10::toString(c10::ScalarType::ComplexHalf), "ComplexHalf"); - EXPECT_STREQ(c10::toString(c10::ScalarType::Bits2x4), "Bits2x4"); - EXPECT_STREQ(c10::toString(c10::ScalarType::Bits4x2), "Bits4x2"); - EXPECT_STREQ(c10::toString(c10::ScalarType::Bits8), "Bits8"); - - EXPECT_EQ(c10::elementSize(c10::ScalarType::QUInt8), static_cast(1)); - EXPECT_EQ(c10::elementSize(c10::ScalarType::QUInt2x4), - static_cast(1)); - EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits2x4), static_cast(1)); - EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits4x2), static_cast(1)); - EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits8), static_cast(1)); - - EXPECT_TRUE(c10::isIntegralType(c10::ScalarType::UInt64, false)); - EXPECT_FALSE(c10::isIntegralType(c10::ScalarType::Float, true)); - EXPECT_TRUE(c10::isFloat8Type(c10::ScalarType::Float8_e4m3fn)); - EXPECT_TRUE(c10::isReducedFloatingType(c10::ScalarType::Half)); - EXPECT_FALSE(c10::isReducedFloatingType(c10::ScalarType::Float)); - EXPECT_TRUE(c10::isFloatingType(c10::ScalarType::Half)); - EXPECT_TRUE(c10::isComplexType(c10::ScalarType::ComplexFloat)); - - EXPECT_TRUE(c10::isSignedType(c10::ScalarType::QInt8)); - EXPECT_TRUE(c10::isSignedType(c10::ScalarType::ComplexHalf)); - EXPECT_FALSE(c10::isSignedType(c10::ScalarType::Byte)); - EXPECT_FALSE(c10::isSignedType(c10::ScalarType::Bool)); - EXPECT_THROW(c10::isSignedType(c10::ScalarType::NumOptions), - ::std::exception); -} From bed6e5620906044b356a66c69009a723da8479ea Mon Sep 17 00:00:00 2001 From: youge325 Date: Sat, 28 Mar 2026 11:24:53 +0800 Subject: [PATCH 2/4] resolve some mismatch, and fix compling error in FlashMLA --- .../phi/api/include/compat/ATen/ops/equal.h | 6 ++++ .../phi/api/include/compat/ATen/ops/select.h | 28 +++++++++++++++++-- paddle/phi/api/include/compat/ATen/ops/std.h | 15 ++++++++++ .../api/include/compat/c10/core/Stream.cpp | 8 ++++-- .../phi/api/include/compat/c10/core/Stream.h | 5 ++-- .../torch/csrc/api/include/torch/cuda.h | 2 -- 6 files changed, 53 insertions(+), 11 deletions(-) diff --git a/paddle/phi/api/include/compat/ATen/ops/equal.h b/paddle/phi/api/include/compat/ATen/ops/equal.h index 4619144f8f5ab..1ac49d9e245d7 100644 --- a/paddle/phi/api/include/compat/ATen/ops/equal.h +++ b/paddle/phi/api/include/compat/ATen/ops/equal.h @@ -22,6 +22,12 @@ namespace at { inline bool equal(const at::Tensor& self, const at::Tensor& other) { + PD_CHECK(self.defined(), + "Expected a proper Tensor but got None (or an undefined Tensor in " + "C++)"); + PD_CHECK(other.defined(), + "Expected a proper Tensor but got None (or an undefined Tensor in " + "C++)"); PD_CHECK(self.device() == other.device(), "Cannot compare two tensors on " "different devices. Got: ", diff --git a/paddle/phi/api/include/compat/ATen/ops/select.h b/paddle/phi/api/include/compat/ATen/ops/select.h index 8c859da44349b..6c522db600add 100644 --- a/paddle/phi/api/include/compat/ATen/ops/select.h +++ b/paddle/phi/api/include/compat/ATen/ops/select.h @@ -19,13 +19,35 @@ namespace at { inline at::Tensor select(const at::Tensor& self, int64_t dim, int64_t index) { + // Normalize dim to positive value for error messages + int64_t orig_dim = dim; if (dim < 0) { dim += self.dim(); } - // Handle negative indexing + // Check dim is valid + if (dim < 0 || dim >= self.dim()) { + PD_CHECK(false, + "select(): index ", + orig_dim, + " out of range for tensor of size ", + self.sizes(), + " at dimension ", + orig_dim); + } + // Handle negative index + int64_t orig_index = index; if (index < 0) { - int64_t dim_size = self.size(dim); - index = dim_size + index; + index = self.size(dim) + index; + } + // Check index is valid + if (index < 0 || index >= self.size(dim)) { + PD_CHECK(false, + "select(): index ", + orig_index, + " out of range for tensor of size ", + self.sizes(), + " at dimension ", + orig_dim < 0 ? orig_dim + self.dim() : orig_dim); } return Tensor( diff --git a/paddle/phi/api/include/compat/ATen/ops/std.h b/paddle/phi/api/include/compat/ATen/ops/std.h index b8600de2c857d..ed7875f64431c 100644 --- a/paddle/phi/api/include/compat/ATen/ops/std.h +++ b/paddle/phi/api/include/compat/ATen/ops/std.h @@ -32,6 +32,21 @@ inline Tensor std_impl(const Tensor& self, const std::vector& dims_vec, double correction_value, bool keepdim) { + // Validate dimensions before processing + int64_t ndim = self.dim(); + for (int64_t d : dims_vec) { + int64_t dim_idx = d < 0 ? d + ndim : d; + if (dim_idx < 0 || dim_idx >= ndim) { + PD_CHECK(false, + "Dimension out of range (expected to be in range of [", + -ndim, + ", ", + ndim - 1, + "], but got ", + d, + ")"); + } + } phi::IntArray dims_int_array(dims_vec); paddle::Tensor tensor = self._PD_GetInner(); diff --git a/paddle/phi/api/include/compat/c10/core/Stream.cpp b/paddle/phi/api/include/compat/c10/core/Stream.cpp index 9a52b8c9f9f05..60873f6e05a93 100644 --- a/paddle/phi/api/include/compat/c10/core/Stream.cpp +++ b/paddle/phi/api/include/compat/c10/core/Stream.cpp @@ -44,9 +44,11 @@ void* Stream::native_handle() const { return reinterpret_cast(static_cast(id_)); } #endif - PADDLE_THROW(::common::errors::Unimplemented( - "c10::Stream::native_handle() is not supported for device type %d", - static_cast(device_type()))); + // Match PyTorch error message format for unsupported device types + PD_CHECK(false, + "native_handle() is not supported for this device type (", + static_cast(device_type()), + ")"); } bool Stream::query() const { diff --git a/paddle/phi/api/include/compat/c10/core/Stream.h b/paddle/phi/api/include/compat/c10/core/Stream.h index 58912130daf30..e9bcbc939d921 100644 --- a/paddle/phi/api/include/compat/c10/core/Stream.h +++ b/paddle/phi/api/include/compat/c10/core/Stream.h @@ -90,9 +90,8 @@ class Stream final { }; inline std::ostream& operator<<(std::ostream& os, const Stream& s) { - os << "Stream(device_type=" << static_cast(s.device_type()) - << ", device_index=" << static_cast(s.device_index()) - << ", id=" << s.id() << ")"; + // Format: "stream {id} on device {device_type}:{device_index}" + os << "stream " << s.id() << " on device " << s.device(); return os; } diff --git a/paddle/phi/api/include/compat/torch/csrc/api/include/torch/cuda.h b/paddle/phi/api/include/compat/torch/csrc/api/include/torch/cuda.h index 3cf18fd4f2257..4eb38ceecc681 100644 --- a/paddle/phi/api/include/compat/torch/csrc/api/include/torch/cuda.h +++ b/paddle/phi/api/include/compat/torch/csrc/api/include/torch/cuda.h @@ -28,7 +28,5 @@ void synchronize(int64_t device_index = -1); } // namespace torch::cuda namespace at::cuda { -using torch::cuda::device_count; -using torch::cuda::is_available; using torch::cuda::synchronize; } // namespace at::cuda From 1e7a992200d69539a7b16ac4095e2cce3beb597e Mon Sep 17 00:00:00 2001 From: youge325 Date: Wed, 1 Apr 2026 20:10:32 +0800 Subject: [PATCH 3/4] fix: update tests to use torch::cuda::is_available --- test/cpp/compat/ATen_TensorAccessor_test.cc | 2 +- test/cpp/compat/c10_storage_test.cc | 9 +++++---- test/cpp/compat/compat_basic_test.cc | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/test/cpp/compat/ATen_TensorAccessor_test.cc b/test/cpp/compat/ATen_TensorAccessor_test.cc index cb1eaaf3c8add..4d6b8e9648bcf 100644 --- a/test/cpp/compat/ATen_TensorAccessor_test.cc +++ b/test/cpp/compat/ATen_TensorAccessor_test.cc @@ -198,7 +198,7 @@ TEST(TensorAccessorTest, PackedAccessorWithIntType) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) TEST(TensorAccessorTest, PackedAccessorCUDA) { - if (at::cuda::is_available()) { + if (torch::cuda::is_available()) { // Create CUDA tensor at::Tensor tensor = at::arange(12, at::TensorOptions().dtype(at::kFloat).device(at::kCUDA)) diff --git a/test/cpp/compat/c10_storage_test.cc b/test/cpp/compat/c10_storage_test.cc index f05083fe88747..d94dee0d47b60 100644 --- a/test/cpp/compat/c10_storage_test.cc +++ b/test/cpp/compat/c10_storage_test.cc @@ -26,8 +26,9 @@ #include "paddle/phi/backends/gpu/gpu_info.h" // Forward-declare getCUDADeviceAllocator to avoid include-order conflicts -// between ATen/cuda/CUDAContextLight.h (defines at::cuda::is_available inline) -// and torch/cuda.h (adds `using torch::cuda::is_available` to at::cuda). +// between ATen/cuda/CUDAContextLight.h (defines torch::cuda::is_available +// inline) and torch/cuda.h (adds `using torch::cuda::is_available` to +// at::cuda). namespace at::cuda { c10::Allocator* getCUDADeviceAllocator(); } // namespace at::cuda @@ -242,7 +243,7 @@ TEST(StorageTest, DeviceAndDeviceTypeAPIs) { ASSERT_EQ(place.GetType(), phi::AllocationType::CPU); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (at::cuda::is_available()) { + if (torch::cuda::is_available()) { at::TensorBase cuda_tensor = at::ones( {2, 3}, c10::TensorOptions().dtype(at::kFloat).device(at::kCUDA)); const c10::Storage& cuda_storage = cuda_tensor.storage(); @@ -1041,7 +1042,7 @@ TEST(StorageTest, ReferenceSemanticsSetNbytesVisibleThroughCopy) { TEST(StorageTest, CUDAAllocatorZeroBytePreservesDevice) { // getCUDADeviceAllocator()->allocate(0) must return a DataPtr whose device // is the current CUDA device, not a default-constructed CPU DataPtr. - if (!at::cuda::is_available()) { + if (!torch::cuda::is_available()) { return; // No CUDA device, skip } diff --git a/test/cpp/compat/compat_basic_test.cc b/test/cpp/compat/compat_basic_test.cc index d6f3386944900..232a9fd66e8f7 100644 --- a/test/cpp/compat/compat_basic_test.cc +++ b/test/cpp/compat/compat_basic_test.cc @@ -231,7 +231,7 @@ TEST(compat_basic_test, BasicCase) { TEST(TestDevice, DeviceAPIsOnCUDA) { // Test device related APIs on CUDA if available #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (at::cuda::is_available()) { + if (torch::cuda::is_available()) { at::TensorBase cuda_tensor = at::ones( {2, 3}, c10::TensorOptions().dtype(at::kFloat).device(at::kCUDA)); From 236217e8841ee457b187d75ace1c1c432cd7aa68 Mon Sep 17 00:00:00 2001 From: youge325 Date: Thu, 2 Apr 2026 22:28:29 +0800 Subject: [PATCH 4/4] revert ScalarType --- .../api/include/compat/c10/core/ScalarType.h | 155 ++++++++++++++++-- test/cpp/compat/c10_ScalarType_test.cc | 61 +++++++ 2 files changed, 199 insertions(+), 17 deletions(-) diff --git a/paddle/phi/api/include/compat/c10/core/ScalarType.h b/paddle/phi/api/include/compat/c10/core/ScalarType.h index 0fe99f5d1a0f0..e72a013c70f69 100644 --- a/paddle/phi/api/include/compat/c10/core/ScalarType.h +++ b/paddle/phi/api/include/compat/c10/core/ScalarType.h @@ -130,15 +130,58 @@ struct dummy_int1_7_t {}; _(uint32_t, UINT32, UInt32) enum class PADDLE_API ScalarType : int8_t { -#define DEFINE_ST_ENUM_VAL_(_1, _2, n) n, - AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ST_ENUM_VAL_) -#undef DEFINE_ENUM_ST_ENUM_VAL_ -#define DEFINE_ST_ENUM_VAL_FOR_QINTS_(_1, n) n, - AT_FORALL_QINT_TYPES(DEFINE_ST_ENUM_VAL_FOR_QINTS_) -#undef DEFINE_ST_ENUM_VAL_FOR_QINTS_ - Undefined, - NumOptions + Byte = 0, + Char = 1, + Short = 2, + Int = 3, + Long = 4, + Half = 5, + Float = 6, + Double = 7, + ComplexHalf = 8, + ComplexFloat = 9, + ComplexDouble = 10, + Bool = 11, + QInt8 = 12, + QUInt8 = 13, + QInt32 = 14, + BFloat16 = 15, + QUInt4x2 = 16, + QUInt2x4 = 17, + Bits1x8 = 18, + Bits2x4 = 19, + Bits4x2 = 20, + Bits8 = 21, + Bits16 = 22, + Float8_e5m2 = 23, + Float8_e4m3fn = 24, + Float8_e5m2fnuz = 25, + Float8_e4m3fnuz = 26, + UInt16 = 27, + UInt32 = 28, + UInt64 = 29, + UInt1 = 30, + UInt2 = 31, + UInt3 = 32, + UInt4 = 33, + UInt5 = 34, + UInt6 = 35, + UInt7 = 36, + Int1 = 37, + Int2 = 38, + Int3 = 39, + Int4 = 40, + Int5 = 41, + Int6 = 42, + Int7 = 43, + Float8_e8m0fnu = 44, + Float4_e2m1fn_x2 = 45, + Undefined = 46, + NumOptions = 47 }; + +constexpr uint16_t NumScalarTypes = + static_cast(ScalarType::NumOptions); namespace impl { // These are used to map ScalarTypes to C++ types. @@ -182,6 +225,23 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType) AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT) #undef DEFINE_CONSTANT +constexpr ScalarType kComplexHalf = ScalarType::ComplexHalf; +constexpr ScalarType kQInt8 = ScalarType::QInt8; +constexpr ScalarType kQUInt8 = ScalarType::QUInt8; +constexpr ScalarType kQInt32 = ScalarType::QInt32; +constexpr ScalarType kQUInt4x2 = ScalarType::QUInt4x2; +constexpr ScalarType kQUInt2x4 = ScalarType::QUInt2x4; +constexpr ScalarType kBits1x8 = ScalarType::Bits1x8; +constexpr ScalarType kBits2x4 = ScalarType::Bits2x4; +constexpr ScalarType kBits4x2 = ScalarType::Bits4x2; +constexpr ScalarType kBits8 = ScalarType::Bits8; +constexpr ScalarType kBits16 = ScalarType::Bits16; +constexpr ScalarType kFloat8_e5m2fnuz = ScalarType::Float8_e5m2fnuz; +constexpr ScalarType kFloat8_e4m3fnuz = ScalarType::Float8_e4m3fnuz; +constexpr ScalarType kFloat8_e8m0fnu = ScalarType::Float8_e8m0fnu; +constexpr ScalarType kFloat4_e2m1fn_x2 = ScalarType::Float4_e2m1fn_x2; +constexpr ScalarType kUndefined = ScalarType::Undefined; + #define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \ _(uint8_t, Byte) \ _(int8_t, Char) \ @@ -238,6 +298,38 @@ inline const char* toString(ScalarType t) { switch (t) { AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE) + case ScalarType::ComplexHalf: + return "ComplexHalf"; + case ScalarType::QInt8: + return "QInt8"; + case ScalarType::QUInt8: + return "QUInt8"; + case ScalarType::QInt32: + return "QInt32"; + case ScalarType::QUInt4x2: + return "QUInt4x2"; + case ScalarType::QUInt2x4: + return "QUInt2x4"; + case ScalarType::Bits1x8: + return "Bits1x8"; + case ScalarType::Bits2x4: + return "Bits2x4"; + case ScalarType::Bits4x2: + return "Bits4x2"; + case ScalarType::Bits8: + return "Bits8"; + case ScalarType::Bits16: + return "Bits16"; + case ScalarType::Float8_e5m2fnuz: + return "Float8_e5m2fnuz"; + case ScalarType::Float8_e4m3fnuz: + return "Float8_e4m3fnuz"; + case ScalarType::Float8_e8m0fnu: + return "Float8_e8m0fnu"; + case ScalarType::Float4_e2m1fn_x2: + return "Float4_e2m1fn_x2"; + case ScalarType::Undefined: + return "Undefined"; default: return "UNKNOWN_SCALAR"; } @@ -251,6 +343,25 @@ inline size_t elementSize(ScalarType t) { switch (t) { AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CASE_ELEMENTSIZE_CASE) + case ScalarType::ComplexHalf: + return sizeof(at::Half) * 2; + case ScalarType::QInt8: + case ScalarType::QUInt8: + case ScalarType::QUInt4x2: + case ScalarType::QUInt2x4: + case ScalarType::Bits1x8: + case ScalarType::Bits2x4: + case ScalarType::Bits4x2: + case ScalarType::Bits8: + case ScalarType::Float8_e5m2fnuz: + case ScalarType::Float8_e4m3fnuz: + case ScalarType::Float8_e8m0fnu: + case ScalarType::Float4_e2m1fn_x2: + return 1; + case ScalarType::QInt32: + return 4; + case ScalarType::Bits16: + return 2; default: TORCH_CHECK(false, "Unknown ScalarType"); } @@ -267,15 +378,14 @@ inline bool isIntegralType(ScalarType t, bool includeBool) { } inline bool isFloat8Type(ScalarType t) { - return t == ScalarType::Float8_e5m2 || t == ScalarType::Float8_e4m3fn; - // || t == ScalarType::Float8_e5m2fnuz - // || t == ScalarType::Float8_e4m3fnuz - // || t == ScalarType::Float8_e8m0fnu + return t == ScalarType::Float8_e5m2 || t == ScalarType::Float8_e4m3fn || + t == ScalarType::Float8_e5m2fnuz || t == ScalarType::Float8_e4m3fnuz || + t == ScalarType::Float8_e8m0fnu; } inline bool isReducedFloatingType(ScalarType t) { - return t == ScalarType::Half || t == ScalarType::BFloat16 || isFloat8Type(t); - //|| t == ScalarType::Float4_e2m1fn_x2 + return t == ScalarType::Half || t == ScalarType::BFloat16 || + isFloat8Type(t) || t == ScalarType::Float4_e2m1fn_x2; } inline bool isFloatingType(ScalarType t) { @@ -284,9 +394,8 @@ inline bool isFloatingType(ScalarType t) { } inline bool isComplexType(ScalarType t) { - return ( - /* t == ScalarType::ComplexHalf || */ t == ScalarType::ComplexFloat || - t == ScalarType::ComplexDouble); + return (t == ScalarType::ComplexHalf || t == ScalarType::ComplexFloat || + t == ScalarType::ComplexDouble); } inline bool isSignedType(ScalarType t) { @@ -321,6 +430,7 @@ inline bool isSignedType(ScalarType t) { CASE_ISSIGNED(Float8_e4m3fn); // Complex types (treated as signed) + case ScalarType::ComplexHalf: case ScalarType::ComplexFloat: case ScalarType::ComplexDouble: return true; @@ -350,11 +460,22 @@ inline bool isSignedType(ScalarType t) { case ScalarType::QUInt8: case ScalarType::QUInt4x2: case ScalarType::QUInt2x4: + case ScalarType::Bits1x8: + case ScalarType::Bits2x4: + case ScalarType::Bits4x2: + case ScalarType::Bits8: + case ScalarType::Bits16: return false; // Bool is unsigned (using numeric_limits) CASE_ISSIGNED(Bool); + case ScalarType::Float8_e5m2fnuz: + case ScalarType::Float8_e4m3fnuz: + case ScalarType::Float8_e8m0fnu: + case ScalarType::Float4_e2m1fn_x2: + return true; + // Invalid/undefined types - should not happen in normal usage // If this is hit, it indicates a programming error or unsupported type case ScalarType::Undefined: diff --git a/test/cpp/compat/c10_ScalarType_test.cc b/test/cpp/compat/c10_ScalarType_test.cc index a373ea8184104..6a3bbc9b77fff 100644 --- a/test/cpp/compat/c10_ScalarType_test.cc +++ b/test/cpp/compat/c10_ScalarType_test.cc @@ -90,3 +90,64 @@ TEST(TensorBaseTest, TypeCheckingAPIs) { ASSERT_FALSE(uint8_tensor.is_signed()); ASSERT_FALSE(bool_tensor.is_signed()); } + +TEST(ScalarTypeTest, RestoredCompatScalarTypesKeepSourceLevelSemantics) { + EXPECT_EQ(static_cast(c10::ScalarType::ComplexHalf), 8); + EXPECT_EQ(static_cast(c10::ScalarType::QInt8), 12); + EXPECT_EQ(static_cast(c10::ScalarType::Bits16), 22); + EXPECT_EQ(static_cast(c10::ScalarType::Float8_e5m2fnuz), 25); + EXPECT_EQ(static_cast(c10::ScalarType::Float4_e2m1fn_x2), 45); + EXPECT_EQ(c10::NumScalarTypes, 47); + + EXPECT_EQ(c10::kComplexHalf, c10::ScalarType::ComplexHalf); + EXPECT_EQ(c10::kQInt8, c10::ScalarType::QInt8); + EXPECT_EQ(c10::kBits16, c10::ScalarType::Bits16); + EXPECT_EQ(c10::kFloat8_e4m3fnuz, c10::ScalarType::Float8_e4m3fnuz); + EXPECT_EQ(c10::kFloat8_e8m0fnu, c10::ScalarType::Float8_e8m0fnu); + EXPECT_EQ(c10::kFloat4_e2m1fn_x2, c10::ScalarType::Float4_e2m1fn_x2); + EXPECT_EQ(c10::kUndefined, c10::ScalarType::Undefined); + + EXPECT_STREQ(c10::toString(c10::ScalarType::ComplexHalf), "ComplexHalf"); + EXPECT_STREQ(c10::toString(c10::ScalarType::QInt8), "QInt8"); + EXPECT_STREQ(c10::toString(c10::ScalarType::QUInt8), "QUInt8"); + EXPECT_STREQ(c10::toString(c10::ScalarType::QInt32), "QInt32"); + EXPECT_STREQ(c10::toString(c10::ScalarType::QUInt4x2), "QUInt4x2"); + EXPECT_STREQ(c10::toString(c10::ScalarType::QUInt2x4), "QUInt2x4"); + EXPECT_STREQ(c10::toString(c10::ScalarType::Bits1x8), "Bits1x8"); + EXPECT_STREQ(c10::toString(c10::ScalarType::Bits2x4), "Bits2x4"); + EXPECT_STREQ(c10::toString(c10::ScalarType::Bits4x2), "Bits4x2"); + EXPECT_STREQ(c10::toString(c10::ScalarType::Bits8), "Bits8"); + EXPECT_STREQ(c10::toString(c10::ScalarType::Bits16), "Bits16"); + EXPECT_STREQ(c10::toString(c10::ScalarType::Float8_e5m2fnuz), + "Float8_e5m2fnuz"); + EXPECT_STREQ(c10::toString(c10::ScalarType::Float8_e4m3fnuz), + "Float8_e4m3fnuz"); + EXPECT_STREQ(c10::toString(c10::ScalarType::Float8_e8m0fnu), + "Float8_e8m0fnu"); + EXPECT_STREQ(c10::toString(c10::ScalarType::Float4_e2m1fn_x2), + "Float4_e2m1fn_x2"); + EXPECT_STREQ(c10::toString(c10::ScalarType::Undefined), "Undefined"); + + EXPECT_EQ(c10::elementSize(c10::ScalarType::ComplexHalf), + sizeof(at::Half) * 2); + EXPECT_EQ(c10::elementSize(c10::ScalarType::QInt8), 1U); + EXPECT_EQ(c10::elementSize(c10::ScalarType::QUInt8), 1U); + EXPECT_EQ(c10::elementSize(c10::ScalarType::QInt32), 4U); + EXPECT_EQ(c10::elementSize(c10::ScalarType::QUInt4x2), 1U); + EXPECT_EQ(c10::elementSize(c10::ScalarType::QUInt2x4), 1U); + EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits1x8), 1U); + EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits2x4), 1U); + EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits4x2), 1U); + EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits8), 1U); + EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits16), 2U); + EXPECT_EQ(c10::elementSize(c10::ScalarType::Float8_e5m2fnuz), 1U); + EXPECT_EQ(c10::elementSize(c10::ScalarType::Float8_e4m3fnuz), 1U); + EXPECT_EQ(c10::elementSize(c10::ScalarType::Float8_e8m0fnu), 1U); + EXPECT_EQ(c10::elementSize(c10::ScalarType::Float4_e2m1fn_x2), 1U); + + EXPECT_TRUE(c10::isComplexType(c10::ScalarType::ComplexHalf)); + EXPECT_TRUE(c10::isFloat8Type(c10::ScalarType::Float8_e5m2fnuz)); + EXPECT_TRUE(c10::isFloat8Type(c10::ScalarType::Float8_e4m3fnuz)); + EXPECT_TRUE(c10::isFloat8Type(c10::ScalarType::Float8_e8m0fnu)); + EXPECT_TRUE(c10::isReducedFloatingType(c10::ScalarType::Float4_e2m1fn_x2)); +}