diff --git a/paddle/phi/api/include/compat/ATen/ops/equal.h b/paddle/phi/api/include/compat/ATen/ops/equal.h index 4619144f8f5aba..1ac49d9e245d74 100644 --- a/paddle/phi/api/include/compat/ATen/ops/equal.h +++ b/paddle/phi/api/include/compat/ATen/ops/equal.h @@ -22,6 +22,12 @@ namespace at { inline bool equal(const at::Tensor& self, const at::Tensor& other) { + PD_CHECK(self.defined(), + "Expected a proper Tensor but got None (or an undefined Tensor in " + "C++)"); + PD_CHECK(other.defined(), + "Expected a proper Tensor but got None (or an undefined Tensor in " + "C++)"); PD_CHECK(self.device() == other.device(), "Cannot compare two tensors on " "different devices. Got: ", diff --git a/paddle/phi/api/include/compat/ATen/ops/select.h b/paddle/phi/api/include/compat/ATen/ops/select.h index 8c859da44349b6..6c522db600add2 100644 --- a/paddle/phi/api/include/compat/ATen/ops/select.h +++ b/paddle/phi/api/include/compat/ATen/ops/select.h @@ -19,13 +19,35 @@ namespace at { inline at::Tensor select(const at::Tensor& self, int64_t dim, int64_t index) { + // Normalize dim to positive value for error messages + int64_t orig_dim = dim; if (dim < 0) { dim += self.dim(); } - // Handle negative indexing + // Check dim is valid + if (dim < 0 || dim >= self.dim()) { + PD_CHECK(false, + "select(): index ", + orig_dim, + " out of range for tensor of size ", + self.sizes(), + " at dimension ", + orig_dim); + } + // Handle negative index + int64_t orig_index = index; if (index < 0) { - int64_t dim_size = self.size(dim); - index = dim_size + index; + index = self.size(dim) + index; + } + // Check index is valid + if (index < 0 || index >= self.size(dim)) { + PD_CHECK(false, + "select(): index ", + orig_index, + " out of range for tensor of size ", + self.sizes(), + " at dimension ", + orig_dim < 0 ? orig_dim + self.dim() : orig_dim); } return Tensor( diff --git a/paddle/phi/api/include/compat/ATen/ops/std.h b/paddle/phi/api/include/compat/ATen/ops/std.h index b8600de2c857d9..ed7875f64431c1 100644 --- a/paddle/phi/api/include/compat/ATen/ops/std.h +++ b/paddle/phi/api/include/compat/ATen/ops/std.h @@ -32,6 +32,21 @@ inline Tensor std_impl(const Tensor& self, const std::vector& dims_vec, double correction_value, bool keepdim) { + // Validate dimensions before processing + int64_t ndim = self.dim(); + for (int64_t d : dims_vec) { + int64_t dim_idx = d < 0 ? d + ndim : d; + if (dim_idx < 0 || dim_idx >= ndim) { + PD_CHECK(false, + "Dimension out of range (expected to be in range of [", + -ndim, + ", ", + ndim - 1, + "], but got ", + d, + ")"); + } + } phi::IntArray dims_int_array(dims_vec); paddle::Tensor tensor = self._PD_GetInner(); diff --git a/paddle/phi/api/include/compat/c10/core/ScalarType.h b/paddle/phi/api/include/compat/c10/core/ScalarType.h index 5495e655040dcb..e72a013c70f697 100644 --- a/paddle/phi/api/include/compat/c10/core/ScalarType.h +++ b/paddle/phi/api/include/compat/c10/core/ScalarType.h @@ -225,6 +225,23 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType) AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT) #undef DEFINE_CONSTANT +constexpr ScalarType kComplexHalf = ScalarType::ComplexHalf; +constexpr ScalarType kQInt8 = ScalarType::QInt8; +constexpr ScalarType kQUInt8 = ScalarType::QUInt8; +constexpr ScalarType kQInt32 = ScalarType::QInt32; +constexpr ScalarType kQUInt4x2 = ScalarType::QUInt4x2; +constexpr ScalarType kQUInt2x4 = ScalarType::QUInt2x4; +constexpr ScalarType kBits1x8 = ScalarType::Bits1x8; +constexpr ScalarType kBits2x4 = ScalarType::Bits2x4; +constexpr ScalarType kBits4x2 = ScalarType::Bits4x2; +constexpr ScalarType kBits8 = ScalarType::Bits8; +constexpr ScalarType kBits16 = ScalarType::Bits16; +constexpr ScalarType kFloat8_e5m2fnuz = ScalarType::Float8_e5m2fnuz; +constexpr ScalarType kFloat8_e4m3fnuz = ScalarType::Float8_e4m3fnuz; +constexpr ScalarType kFloat8_e8m0fnu = ScalarType::Float8_e8m0fnu; +constexpr ScalarType kFloat4_e2m1fn_x2 = ScalarType::Float4_e2m1fn_x2; +constexpr ScalarType kUndefined = ScalarType::Undefined; + #define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \ _(uint8_t, Byte) \ _(int8_t, Char) \ @@ -281,6 +298,8 @@ inline const char* toString(ScalarType t) { switch (t) { AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE) + case ScalarType::ComplexHalf: + return "ComplexHalf"; case ScalarType::QInt8: return "QInt8"; case ScalarType::QUInt8: @@ -291,8 +310,6 @@ inline const char* toString(ScalarType t) { return "QUInt4x2"; case ScalarType::QUInt2x4: return "QUInt2x4"; - case ScalarType::ComplexHalf: - return "ComplexHalf"; case ScalarType::Bits1x8: return "Bits1x8"; case ScalarType::Bits2x4: @@ -326,6 +343,8 @@ inline size_t elementSize(ScalarType t) { switch (t) { AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CASE_ELEMENTSIZE_CASE) + case ScalarType::ComplexHalf: + return sizeof(at::Half) * 2; case ScalarType::QInt8: case ScalarType::QUInt8: case ScalarType::QUInt4x2: @@ -334,10 +353,15 @@ inline size_t elementSize(ScalarType t) { case ScalarType::Bits2x4: case ScalarType::Bits4x2: case ScalarType::Bits8: + case ScalarType::Float8_e5m2fnuz: + case ScalarType::Float8_e4m3fnuz: + case ScalarType::Float8_e8m0fnu: + case ScalarType::Float4_e2m1fn_x2: return 1; case ScalarType::QInt32: - case ScalarType::Bits16: return 4; + case ScalarType::Bits16: + return 2; default: TORCH_CHECK(false, "Unknown ScalarType"); } @@ -354,15 +378,14 @@ inline bool isIntegralType(ScalarType t, bool includeBool) { } inline bool isFloat8Type(ScalarType t) { - return t == ScalarType::Float8_e5m2 || t == ScalarType::Float8_e4m3fn; - // || t == ScalarType::Float8_e5m2fnuz - // || t == ScalarType::Float8_e4m3fnuz - // || t == ScalarType::Float8_e8m0fnu + return t == ScalarType::Float8_e5m2 || t == ScalarType::Float8_e4m3fn || + t == ScalarType::Float8_e5m2fnuz || t == ScalarType::Float8_e4m3fnuz || + t == ScalarType::Float8_e8m0fnu; } inline bool isReducedFloatingType(ScalarType t) { - return t == ScalarType::Half || t == ScalarType::BFloat16 || isFloat8Type(t); - //|| t == ScalarType::Float4_e2m1fn_x2 + return t == ScalarType::Half || t == ScalarType::BFloat16 || + isFloat8Type(t) || t == ScalarType::Float4_e2m1fn_x2; } inline bool isFloatingType(ScalarType t) { @@ -371,9 +394,8 @@ inline bool isFloatingType(ScalarType t) { } inline bool isComplexType(ScalarType t) { - return ( - /* t == ScalarType::ComplexHalf || */ t == ScalarType::ComplexFloat || - t == ScalarType::ComplexDouble); + return (t == ScalarType::ComplexHalf || t == ScalarType::ComplexFloat || + t == ScalarType::ComplexDouble); } inline bool isSignedType(ScalarType t) { @@ -408,9 +430,9 @@ inline bool isSignedType(ScalarType t) { CASE_ISSIGNED(Float8_e4m3fn); // Complex types (treated as signed) + case ScalarType::ComplexHalf: case ScalarType::ComplexFloat: case ScalarType::ComplexDouble: - case ScalarType::ComplexHalf: return true; // Signed quantized types (explicitly return true) diff --git a/paddle/phi/api/include/compat/c10/core/Stream.cpp b/paddle/phi/api/include/compat/c10/core/Stream.cpp index 9a52b8c9f9f05d..60873f6e05a93c 100644 --- a/paddle/phi/api/include/compat/c10/core/Stream.cpp +++ b/paddle/phi/api/include/compat/c10/core/Stream.cpp @@ -44,9 +44,11 @@ void* Stream::native_handle() const { return reinterpret_cast(static_cast(id_)); } #endif - PADDLE_THROW(::common::errors::Unimplemented( - "c10::Stream::native_handle() is not supported for device type %d", - static_cast(device_type()))); + // Match PyTorch error message format for unsupported device types + PD_CHECK(false, + "native_handle() is not supported for this device type (", + static_cast(device_type()), + ")"); } bool Stream::query() const { diff --git a/paddle/phi/api/include/compat/c10/core/Stream.h b/paddle/phi/api/include/compat/c10/core/Stream.h index 58912130daf303..e9bcbc939d9215 100644 --- a/paddle/phi/api/include/compat/c10/core/Stream.h +++ b/paddle/phi/api/include/compat/c10/core/Stream.h @@ -90,9 +90,8 @@ class Stream final { }; inline std::ostream& operator<<(std::ostream& os, const Stream& s) { - os << "Stream(device_type=" << static_cast(s.device_type()) - << ", device_index=" << static_cast(s.device_index()) - << ", id=" << s.id() << ")"; + // Format: "stream {id} on device {device_type}:{device_index}" + os << "stream " << s.id() << " on device " << s.device(); return os; } diff --git a/paddle/phi/api/include/compat/torch/csrc/api/include/torch/cuda.h b/paddle/phi/api/include/compat/torch/csrc/api/include/torch/cuda.h index 3cf18fd4f22574..4eb38ceecc681f 100644 --- a/paddle/phi/api/include/compat/torch/csrc/api/include/torch/cuda.h +++ b/paddle/phi/api/include/compat/torch/csrc/api/include/torch/cuda.h @@ -28,7 +28,5 @@ void synchronize(int64_t device_index = -1); } // namespace torch::cuda namespace at::cuda { -using torch::cuda::device_count; -using torch::cuda::is_available; using torch::cuda::synchronize; } // namespace at::cuda diff --git a/test/cpp/compat/ATen_TensorAccessor_test.cc b/test/cpp/compat/ATen_TensorAccessor_test.cc index cb1eaaf3c8add0..4d6b8e9648bcfa 100644 --- a/test/cpp/compat/ATen_TensorAccessor_test.cc +++ b/test/cpp/compat/ATen_TensorAccessor_test.cc @@ -198,7 +198,7 @@ TEST(TensorAccessorTest, PackedAccessorWithIntType) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) TEST(TensorAccessorTest, PackedAccessorCUDA) { - if (at::cuda::is_available()) { + if (torch::cuda::is_available()) { // Create CUDA tensor at::Tensor tensor = at::arange(12, at::TensorOptions().dtype(at::kFloat).device(at::kCUDA)) diff --git a/test/cpp/compat/c10_ScalarType_test.cc b/test/cpp/compat/c10_ScalarType_test.cc index ea0895412f72df..6a3bbc9b77fff9 100644 --- a/test/cpp/compat/c10_ScalarType_test.cc +++ b/test/cpp/compat/c10_ScalarType_test.cc @@ -25,7 +25,6 @@ #include #include #endif -#include #include "ATen/ATen.h" #include "gtest/gtest.h" #include "paddle/phi/common/float16.h" @@ -92,8 +91,32 @@ TEST(TensorBaseTest, TypeCheckingAPIs) { ASSERT_FALSE(bool_tensor.is_signed()); } -TEST(ScalarTypeCompatTest, ScalarTypeUtilityBranches) { +TEST(ScalarTypeTest, RestoredCompatScalarTypesKeepSourceLevelSemantics) { + EXPECT_EQ(static_cast(c10::ScalarType::ComplexHalf), 8); + EXPECT_EQ(static_cast(c10::ScalarType::QInt8), 12); + EXPECT_EQ(static_cast(c10::ScalarType::Bits16), 22); + EXPECT_EQ(static_cast(c10::ScalarType::Float8_e5m2fnuz), 25); + EXPECT_EQ(static_cast(c10::ScalarType::Float4_e2m1fn_x2), 45); + EXPECT_EQ(c10::NumScalarTypes, 47); + + EXPECT_EQ(c10::kComplexHalf, c10::ScalarType::ComplexHalf); + EXPECT_EQ(c10::kQInt8, c10::ScalarType::QInt8); + EXPECT_EQ(c10::kBits16, c10::ScalarType::Bits16); + EXPECT_EQ(c10::kFloat8_e4m3fnuz, c10::ScalarType::Float8_e4m3fnuz); + EXPECT_EQ(c10::kFloat8_e8m0fnu, c10::ScalarType::Float8_e8m0fnu); + EXPECT_EQ(c10::kFloat4_e2m1fn_x2, c10::ScalarType::Float4_e2m1fn_x2); + EXPECT_EQ(c10::kUndefined, c10::ScalarType::Undefined); + + EXPECT_STREQ(c10::toString(c10::ScalarType::ComplexHalf), "ComplexHalf"); + EXPECT_STREQ(c10::toString(c10::ScalarType::QInt8), "QInt8"); + EXPECT_STREQ(c10::toString(c10::ScalarType::QUInt8), "QUInt8"); + EXPECT_STREQ(c10::toString(c10::ScalarType::QInt32), "QInt32"); + EXPECT_STREQ(c10::toString(c10::ScalarType::QUInt4x2), "QUInt4x2"); + EXPECT_STREQ(c10::toString(c10::ScalarType::QUInt2x4), "QUInt2x4"); EXPECT_STREQ(c10::toString(c10::ScalarType::Bits1x8), "Bits1x8"); + EXPECT_STREQ(c10::toString(c10::ScalarType::Bits2x4), "Bits2x4"); + EXPECT_STREQ(c10::toString(c10::ScalarType::Bits4x2), "Bits4x2"); + EXPECT_STREQ(c10::toString(c10::ScalarType::Bits8), "Bits8"); EXPECT_STREQ(c10::toString(c10::ScalarType::Bits16), "Bits16"); EXPECT_STREQ(c10::toString(c10::ScalarType::Float8_e5m2fnuz), "Float8_e5m2fnuz"); @@ -104,65 +127,27 @@ TEST(ScalarTypeCompatTest, ScalarTypeUtilityBranches) { EXPECT_STREQ(c10::toString(c10::ScalarType::Float4_e2m1fn_x2), "Float4_e2m1fn_x2"); EXPECT_STREQ(c10::toString(c10::ScalarType::Undefined), "Undefined"); - EXPECT_STREQ(c10::toString(static_cast(-1)), - "UNKNOWN_SCALAR"); - - EXPECT_EQ(c10::elementSize(c10::ScalarType::QInt8), static_cast(1)); - EXPECT_EQ(c10::elementSize(c10::ScalarType::QUInt4x2), - static_cast(1)); - EXPECT_EQ(c10::elementSize(c10::ScalarType::QInt32), static_cast(4)); - EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits16), static_cast(4)); - EXPECT_THROW(c10::elementSize(c10::ScalarType::Undefined), ::std::exception); - - EXPECT_TRUE(c10::isIntegralType(c10::ScalarType::Bool, true)); - EXPECT_FALSE(c10::isIntegralType(c10::ScalarType::Bool, false)); - EXPECT_TRUE(c10::isFloat8Type(c10::ScalarType::Float8_e5m2)); - EXPECT_FALSE(c10::isFloat8Type(c10::ScalarType::Float8_e4m3fnuz)); - EXPECT_TRUE(c10::isReducedFloatingType(c10::ScalarType::BFloat16)); - EXPECT_TRUE(c10::isFloatingType(c10::ScalarType::Float)); - EXPECT_FALSE(c10::isComplexType(c10::ScalarType::ComplexHalf)); - - EXPECT_TRUE(c10::isSignedType(c10::ScalarType::Int1)); - EXPECT_FALSE(c10::isSignedType(c10::ScalarType::UInt3)); - EXPECT_FALSE(c10::isSignedType(c10::ScalarType::QUInt8)); - EXPECT_TRUE(c10::isSignedType(c10::ScalarType::Float8_e5m2fnuz)); - EXPECT_THROW(c10::isSignedType(c10::ScalarType::Undefined), ::std::exception); - - std::ostringstream oss; - oss << c10::ScalarType::UInt7; - EXPECT_EQ(oss.str(), "UInt7"); -} - -TEST(ScalarTypeCompatTest, AdditionalEnumAndPredicateBranches) { - EXPECT_STREQ(c10::toString(c10::ScalarType::QInt8), "QInt8"); - EXPECT_STREQ(c10::toString(c10::ScalarType::QUInt8), "QUInt8"); - EXPECT_STREQ(c10::toString(c10::ScalarType::QInt32), "QInt32"); - EXPECT_STREQ(c10::toString(c10::ScalarType::QUInt4x2), "QUInt4x2"); - EXPECT_STREQ(c10::toString(c10::ScalarType::QUInt2x4), "QUInt2x4"); - EXPECT_STREQ(c10::toString(c10::ScalarType::ComplexHalf), "ComplexHalf"); - EXPECT_STREQ(c10::toString(c10::ScalarType::Bits2x4), "Bits2x4"); - EXPECT_STREQ(c10::toString(c10::ScalarType::Bits4x2), "Bits4x2"); - EXPECT_STREQ(c10::toString(c10::ScalarType::Bits8), "Bits8"); - - EXPECT_EQ(c10::elementSize(c10::ScalarType::QUInt8), static_cast(1)); - EXPECT_EQ(c10::elementSize(c10::ScalarType::QUInt2x4), - static_cast(1)); - EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits2x4), static_cast(1)); - EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits4x2), static_cast(1)); - EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits8), static_cast(1)); - - EXPECT_TRUE(c10::isIntegralType(c10::ScalarType::UInt64, false)); - EXPECT_FALSE(c10::isIntegralType(c10::ScalarType::Float, true)); - EXPECT_TRUE(c10::isFloat8Type(c10::ScalarType::Float8_e4m3fn)); - EXPECT_TRUE(c10::isReducedFloatingType(c10::ScalarType::Half)); - EXPECT_FALSE(c10::isReducedFloatingType(c10::ScalarType::Float)); - EXPECT_TRUE(c10::isFloatingType(c10::ScalarType::Half)); - EXPECT_TRUE(c10::isComplexType(c10::ScalarType::ComplexFloat)); - EXPECT_TRUE(c10::isSignedType(c10::ScalarType::QInt8)); - EXPECT_TRUE(c10::isSignedType(c10::ScalarType::ComplexHalf)); - EXPECT_FALSE(c10::isSignedType(c10::ScalarType::Byte)); - EXPECT_FALSE(c10::isSignedType(c10::ScalarType::Bool)); - EXPECT_THROW(c10::isSignedType(c10::ScalarType::NumOptions), - ::std::exception); + EXPECT_EQ(c10::elementSize(c10::ScalarType::ComplexHalf), + sizeof(at::Half) * 2); + EXPECT_EQ(c10::elementSize(c10::ScalarType::QInt8), 1U); + EXPECT_EQ(c10::elementSize(c10::ScalarType::QUInt8), 1U); + EXPECT_EQ(c10::elementSize(c10::ScalarType::QInt32), 4U); + EXPECT_EQ(c10::elementSize(c10::ScalarType::QUInt4x2), 1U); + EXPECT_EQ(c10::elementSize(c10::ScalarType::QUInt2x4), 1U); + EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits1x8), 1U); + EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits2x4), 1U); + EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits4x2), 1U); + EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits8), 1U); + EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits16), 2U); + EXPECT_EQ(c10::elementSize(c10::ScalarType::Float8_e5m2fnuz), 1U); + EXPECT_EQ(c10::elementSize(c10::ScalarType::Float8_e4m3fnuz), 1U); + EXPECT_EQ(c10::elementSize(c10::ScalarType::Float8_e8m0fnu), 1U); + EXPECT_EQ(c10::elementSize(c10::ScalarType::Float4_e2m1fn_x2), 1U); + + EXPECT_TRUE(c10::isComplexType(c10::ScalarType::ComplexHalf)); + EXPECT_TRUE(c10::isFloat8Type(c10::ScalarType::Float8_e5m2fnuz)); + EXPECT_TRUE(c10::isFloat8Type(c10::ScalarType::Float8_e4m3fnuz)); + EXPECT_TRUE(c10::isFloat8Type(c10::ScalarType::Float8_e8m0fnu)); + EXPECT_TRUE(c10::isReducedFloatingType(c10::ScalarType::Float4_e2m1fn_x2)); } diff --git a/test/cpp/compat/c10_storage_test.cc b/test/cpp/compat/c10_storage_test.cc index f05083fe88747f..d94dee0d47b605 100644 --- a/test/cpp/compat/c10_storage_test.cc +++ b/test/cpp/compat/c10_storage_test.cc @@ -26,8 +26,9 @@ #include "paddle/phi/backends/gpu/gpu_info.h" // Forward-declare getCUDADeviceAllocator to avoid include-order conflicts -// between ATen/cuda/CUDAContextLight.h (defines at::cuda::is_available inline) -// and torch/cuda.h (adds `using torch::cuda::is_available` to at::cuda). +// between ATen/cuda/CUDAContextLight.h (defines torch::cuda::is_available +// inline) and torch/cuda.h (adds `using torch::cuda::is_available` to +// at::cuda). namespace at::cuda { c10::Allocator* getCUDADeviceAllocator(); } // namespace at::cuda @@ -242,7 +243,7 @@ TEST(StorageTest, DeviceAndDeviceTypeAPIs) { ASSERT_EQ(place.GetType(), phi::AllocationType::CPU); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (at::cuda::is_available()) { + if (torch::cuda::is_available()) { at::TensorBase cuda_tensor = at::ones( {2, 3}, c10::TensorOptions().dtype(at::kFloat).device(at::kCUDA)); const c10::Storage& cuda_storage = cuda_tensor.storage(); @@ -1041,7 +1042,7 @@ TEST(StorageTest, ReferenceSemanticsSetNbytesVisibleThroughCopy) { TEST(StorageTest, CUDAAllocatorZeroBytePreservesDevice) { // getCUDADeviceAllocator()->allocate(0) must return a DataPtr whose device // is the current CUDA device, not a default-constructed CPU DataPtr. - if (!at::cuda::is_available()) { + if (!torch::cuda::is_available()) { return; // No CUDA device, skip } diff --git a/test/cpp/compat/compat_basic_test.cc b/test/cpp/compat/compat_basic_test.cc index d6f33869449003..232a9fd66e8f7c 100644 --- a/test/cpp/compat/compat_basic_test.cc +++ b/test/cpp/compat/compat_basic_test.cc @@ -231,7 +231,7 @@ TEST(compat_basic_test, BasicCase) { TEST(TestDevice, DeviceAPIsOnCUDA) { // Test device related APIs on CUDA if available #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (at::cuda::is_available()) { + if (torch::cuda::is_available()) { at::TensorBase cuda_tensor = at::ones( {2, 3}, c10::TensorOptions().dtype(at::kFloat).device(at::kCUDA));