2525#include < c10/cuda/CUDAFunctions.h>
2626#include < c10/cuda/CUDAGuard.h>
2727#endif
28- #include < sstream>
2928#include " ATen/ATen.h"
3029#include " gtest/gtest.h"
3130#include " paddle/phi/common/float16.h"
@@ -92,8 +91,32 @@ TEST(TensorBaseTest, TypeCheckingAPIs) {
9291 ASSERT_FALSE (bool_tensor.is_signed ());
9392}
9493
95- TEST (ScalarTypeCompatTest, ScalarTypeUtilityBranches) {
94+ TEST (ScalarTypeTest, RestoredCompatScalarTypesKeepSourceLevelSemantics) {
95+ EXPECT_EQ (static_cast <int >(c10::ScalarType::ComplexHalf), 8 );
96+ EXPECT_EQ (static_cast <int >(c10::ScalarType::QInt8), 12 );
97+ EXPECT_EQ (static_cast <int >(c10::ScalarType::Bits16), 22 );
98+ EXPECT_EQ (static_cast <int >(c10::ScalarType::Float8_e5m2fnuz), 25 );
99+ EXPECT_EQ (static_cast <int >(c10::ScalarType::Float4_e2m1fn_x2), 45 );
100+ EXPECT_EQ (c10::NumScalarTypes, 47 );
101+
102+ EXPECT_EQ (c10::kComplexHalf , c10::ScalarType::ComplexHalf);
103+ EXPECT_EQ (c10::kQInt8 , c10::ScalarType::QInt8);
104+ EXPECT_EQ (c10::kBits16 , c10::ScalarType::Bits16);
105+ EXPECT_EQ (c10::kFloat8_e4m3fnuz , c10::ScalarType::Float8_e4m3fnuz);
106+ EXPECT_EQ (c10::kFloat8_e8m0fnu , c10::ScalarType::Float8_e8m0fnu);
107+ EXPECT_EQ (c10::kFloat4_e2m1fn_x2 , c10::ScalarType::Float4_e2m1fn_x2);
108+ EXPECT_EQ (c10::kUndefined , c10::ScalarType::Undefined);
109+
110+ EXPECT_STREQ (c10::toString (c10::ScalarType::ComplexHalf), " ComplexHalf" );
111+ EXPECT_STREQ (c10::toString (c10::ScalarType::QInt8), " QInt8" );
112+ EXPECT_STREQ (c10::toString (c10::ScalarType::QUInt8), " QUInt8" );
113+ EXPECT_STREQ (c10::toString (c10::ScalarType::QInt32), " QInt32" );
114+ EXPECT_STREQ (c10::toString (c10::ScalarType::QUInt4x2), " QUInt4x2" );
115+ EXPECT_STREQ (c10::toString (c10::ScalarType::QUInt2x4), " QUInt2x4" );
96116 EXPECT_STREQ (c10::toString (c10::ScalarType::Bits1x8), " Bits1x8" );
117+ EXPECT_STREQ (c10::toString (c10::ScalarType::Bits2x4), " Bits2x4" );
118+ EXPECT_STREQ (c10::toString (c10::ScalarType::Bits4x2), " Bits4x2" );
119+ EXPECT_STREQ (c10::toString (c10::ScalarType::Bits8), " Bits8" );
97120 EXPECT_STREQ (c10::toString (c10::ScalarType::Bits16), " Bits16" );
98121 EXPECT_STREQ (c10::toString (c10::ScalarType::Float8_e5m2fnuz),
99122 " Float8_e5m2fnuz" );
@@ -104,65 +127,27 @@ TEST(ScalarTypeCompatTest, ScalarTypeUtilityBranches) {
104127 EXPECT_STREQ (c10::toString (c10::ScalarType::Float4_e2m1fn_x2),
105128 " Float4_e2m1fn_x2" );
106129 EXPECT_STREQ (c10::toString (c10::ScalarType::Undefined), " Undefined" );
107- EXPECT_STREQ (c10::toString (static_cast <c10::ScalarType>(-1 )),
108- " UNKNOWN_SCALAR" );
109-
110- EXPECT_EQ (c10::elementSize (c10::ScalarType::QInt8), static_cast <size_t >(1 ));
111- EXPECT_EQ (c10::elementSize (c10::ScalarType::QUInt4x2),
112- static_cast <size_t >(1 ));
113- EXPECT_EQ (c10::elementSize (c10::ScalarType::QInt32), static_cast <size_t >(4 ));
114- EXPECT_EQ (c10::elementSize (c10::ScalarType::Bits16), static_cast <size_t >(4 ));
115- EXPECT_THROW (c10::elementSize (c10::ScalarType::Undefined), ::std::exception);
116-
117- EXPECT_TRUE (c10::isIntegralType (c10::ScalarType::Bool, true ));
118- EXPECT_FALSE (c10::isIntegralType (c10::ScalarType::Bool, false ));
119- EXPECT_TRUE (c10::isFloat8Type (c10::ScalarType::Float8_e5m2));
120- EXPECT_FALSE (c10::isFloat8Type (c10::ScalarType::Float8_e4m3fnuz));
121- EXPECT_TRUE (c10::isReducedFloatingType (c10::ScalarType::BFloat16));
122- EXPECT_TRUE (c10::isFloatingType (c10::ScalarType::Float));
123- EXPECT_FALSE (c10::isComplexType (c10::ScalarType::ComplexHalf));
124-
125- EXPECT_TRUE (c10::isSignedType (c10::ScalarType::Int1));
126- EXPECT_FALSE (c10::isSignedType (c10::ScalarType::UInt3));
127- EXPECT_FALSE (c10::isSignedType (c10::ScalarType::QUInt8));
128- EXPECT_TRUE (c10::isSignedType (c10::ScalarType::Float8_e5m2fnuz));
129- EXPECT_THROW (c10::isSignedType (c10::ScalarType::Undefined), ::std::exception);
130-
131- std::ostringstream oss;
132- oss << c10::ScalarType::UInt7;
133- EXPECT_EQ (oss.str (), " UInt7" );
134- }
135-
136- TEST (ScalarTypeCompatTest, AdditionalEnumAndPredicateBranches) {
137- EXPECT_STREQ (c10::toString (c10::ScalarType::QInt8), " QInt8" );
138- EXPECT_STREQ (c10::toString (c10::ScalarType::QUInt8), " QUInt8" );
139- EXPECT_STREQ (c10::toString (c10::ScalarType::QInt32), " QInt32" );
140- EXPECT_STREQ (c10::toString (c10::ScalarType::QUInt4x2), " QUInt4x2" );
141- EXPECT_STREQ (c10::toString (c10::ScalarType::QUInt2x4), " QUInt2x4" );
142- EXPECT_STREQ (c10::toString (c10::ScalarType::ComplexHalf), " ComplexHalf" );
143- EXPECT_STREQ (c10::toString (c10::ScalarType::Bits2x4), " Bits2x4" );
144- EXPECT_STREQ (c10::toString (c10::ScalarType::Bits4x2), " Bits4x2" );
145- EXPECT_STREQ (c10::toString (c10::ScalarType::Bits8), " Bits8" );
146-
147- EXPECT_EQ (c10::elementSize (c10::ScalarType::QUInt8), static_cast <size_t >(1 ));
148- EXPECT_EQ (c10::elementSize (c10::ScalarType::QUInt2x4),
149- static_cast <size_t >(1 ));
150- EXPECT_EQ (c10::elementSize (c10::ScalarType::Bits2x4), static_cast <size_t >(1 ));
151- EXPECT_EQ (c10::elementSize (c10::ScalarType::Bits4x2), static_cast <size_t >(1 ));
152- EXPECT_EQ (c10::elementSize (c10::ScalarType::Bits8), static_cast <size_t >(1 ));
153-
154- EXPECT_TRUE (c10::isIntegralType (c10::ScalarType::UInt64, false ));
155- EXPECT_FALSE (c10::isIntegralType (c10::ScalarType::Float, true ));
156- EXPECT_TRUE (c10::isFloat8Type (c10::ScalarType::Float8_e4m3fn));
157- EXPECT_TRUE (c10::isReducedFloatingType (c10::ScalarType::Half));
158- EXPECT_FALSE (c10::isReducedFloatingType (c10::ScalarType::Float));
159- EXPECT_TRUE (c10::isFloatingType (c10::ScalarType::Half));
160- EXPECT_TRUE (c10::isComplexType (c10::ScalarType::ComplexFloat));
161130
162- EXPECT_TRUE (c10::isSignedType (c10::ScalarType::QInt8));
163- EXPECT_TRUE (c10::isSignedType (c10::ScalarType::ComplexHalf));
164- EXPECT_FALSE (c10::isSignedType (c10::ScalarType::Byte));
165- EXPECT_FALSE (c10::isSignedType (c10::ScalarType::Bool));
166- EXPECT_THROW (c10::isSignedType (c10::ScalarType::NumOptions),
167- ::std::exception);
131+ EXPECT_EQ (c10::elementSize (c10::ScalarType::ComplexHalf),
132+ sizeof (at::Half) * 2 );
133+ EXPECT_EQ (c10::elementSize (c10::ScalarType::QInt8), 1U );
134+ EXPECT_EQ (c10::elementSize (c10::ScalarType::QUInt8), 1U );
135+ EXPECT_EQ (c10::elementSize (c10::ScalarType::QInt32), 4U );
136+ EXPECT_EQ (c10::elementSize (c10::ScalarType::QUInt4x2), 1U );
137+ EXPECT_EQ (c10::elementSize (c10::ScalarType::QUInt2x4), 1U );
138+ EXPECT_EQ (c10::elementSize (c10::ScalarType::Bits1x8), 1U );
139+ EXPECT_EQ (c10::elementSize (c10::ScalarType::Bits2x4), 1U );
140+ EXPECT_EQ (c10::elementSize (c10::ScalarType::Bits4x2), 1U );
141+ EXPECT_EQ (c10::elementSize (c10::ScalarType::Bits8), 1U );
142+ EXPECT_EQ (c10::elementSize (c10::ScalarType::Bits16), 2U );
143+ EXPECT_EQ (c10::elementSize (c10::ScalarType::Float8_e5m2fnuz), 1U );
144+ EXPECT_EQ (c10::elementSize (c10::ScalarType::Float8_e4m3fnuz), 1U );
145+ EXPECT_EQ (c10::elementSize (c10::ScalarType::Float8_e8m0fnu), 1U );
146+ EXPECT_EQ (c10::elementSize (c10::ScalarType::Float4_e2m1fn_x2), 1U );
147+
148+ EXPECT_TRUE (c10::isComplexType (c10::ScalarType::ComplexHalf));
149+ EXPECT_TRUE (c10::isFloat8Type (c10::ScalarType::Float8_e5m2fnuz));
150+ EXPECT_TRUE (c10::isFloat8Type (c10::ScalarType::Float8_e4m3fnuz));
151+ EXPECT_TRUE (c10::isFloat8Type (c10::ScalarType::Float8_e8m0fnu));
152+ EXPECT_TRUE (c10::isReducedFloatingType (c10::ScalarType::Float4_e2m1fn_x2));
168153}
0 commit comments