Skip to content

Commit 8946a32

Browse files
youge325 authored and liuhao2638 committed
[Cpp API Compatibility] Fix flashmla compile (PaddlePaddle#78550)
1 parent 52c1a86 commit 8946a32

File tree

8 files changed

+135
-86
lines changed

8 files changed

+135
-86
lines changed

paddle/phi/api/include/compat/ATen/ops/equal.h

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,12 @@
2222
namespace at {
2323

2424
inline bool equal(const at::Tensor& self, const at::Tensor& other) {
25+
PD_CHECK(self.defined(),
26+
"Expected a proper Tensor but got None (or an undefined Tensor in "
27+
"C++)");
28+
PD_CHECK(other.defined(),
29+
"Expected a proper Tensor but got None (or an undefined Tensor in "
30+
"C++)");
2531
PD_CHECK(self.device() == other.device(),
2632
"Cannot compare two tensors on "
2733
"different devices. Got: ",

paddle/phi/api/include/compat/ATen/ops/select.h

Lines changed: 25 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -19,13 +19,35 @@
1919
namespace at {
2020

2121
inline at::Tensor select(const at::Tensor& self, int64_t dim, int64_t index) {
22+
// Normalize dim to positive value for error messages
23+
int64_t orig_dim = dim;
2224
if (dim < 0) {
2325
dim += self.dim();
2426
}
25-
// Handle negative indexing
27+
// Check dim is valid
28+
if (dim < 0 || dim >= self.dim()) {
29+
PD_CHECK(false,
30+
"select(): index ",
31+
orig_dim,
32+
" out of range for tensor of size ",
33+
self.sizes(),
34+
" at dimension ",
35+
orig_dim);
36+
}
37+
// Handle negative index
38+
int64_t orig_index = index;
2639
if (index < 0) {
27-
int64_t dim_size = self.size(dim);
28-
index = dim_size + index;
40+
index = self.size(dim) + index;
41+
}
42+
// Check index is valid
43+
if (index < 0 || index >= self.size(dim)) {
44+
PD_CHECK(false,
45+
"select(): index ",
46+
orig_index,
47+
" out of range for tensor of size ",
48+
self.sizes(),
49+
" at dimension ",
50+
orig_dim < 0 ? orig_dim + self.dim() : orig_dim);
2951
}
3052

3153
return Tensor(

paddle/phi/api/include/compat/ATen/ops/std.h

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,21 @@ inline Tensor std_impl(const Tensor& self,
3232
const std::vector<int64_t>& dims_vec,
3333
double correction_value,
3434
bool keepdim) {
35+
// Validate dimensions before processing
36+
int64_t ndim = self.dim();
37+
for (int64_t d : dims_vec) {
38+
int64_t dim_idx = d < 0 ? d + ndim : d;
39+
if (dim_idx < 0 || dim_idx >= ndim) {
40+
PD_CHECK(false,
41+
"Dimension out of range (expected to be in range of [",
42+
-ndim,
43+
", ",
44+
ndim - 1,
45+
"], but got ",
46+
d,
47+
")");
48+
}
49+
}
3550
phi::IntArray dims_int_array(dims_vec);
3651
paddle::Tensor tensor = self._PD_GetInner();
3752

paddle/phi/api/include/compat/c10/core/ScalarType.h

Lines changed: 35 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -225,6 +225,23 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
225225
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT)
226226
#undef DEFINE_CONSTANT
227227

228+
constexpr ScalarType kComplexHalf = ScalarType::ComplexHalf;
229+
constexpr ScalarType kQInt8 = ScalarType::QInt8;
230+
constexpr ScalarType kQUInt8 = ScalarType::QUInt8;
231+
constexpr ScalarType kQInt32 = ScalarType::QInt32;
232+
constexpr ScalarType kQUInt4x2 = ScalarType::QUInt4x2;
233+
constexpr ScalarType kQUInt2x4 = ScalarType::QUInt2x4;
234+
constexpr ScalarType kBits1x8 = ScalarType::Bits1x8;
235+
constexpr ScalarType kBits2x4 = ScalarType::Bits2x4;
236+
constexpr ScalarType kBits4x2 = ScalarType::Bits4x2;
237+
constexpr ScalarType kBits8 = ScalarType::Bits8;
238+
constexpr ScalarType kBits16 = ScalarType::Bits16;
239+
constexpr ScalarType kFloat8_e5m2fnuz = ScalarType::Float8_e5m2fnuz;
240+
constexpr ScalarType kFloat8_e4m3fnuz = ScalarType::Float8_e4m3fnuz;
241+
constexpr ScalarType kFloat8_e8m0fnu = ScalarType::Float8_e8m0fnu;
242+
constexpr ScalarType kFloat4_e2m1fn_x2 = ScalarType::Float4_e2m1fn_x2;
243+
constexpr ScalarType kUndefined = ScalarType::Undefined;
244+
228245
#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \
229246
_(uint8_t, Byte) \
230247
_(int8_t, Char) \
@@ -281,6 +298,8 @@ inline const char* toString(ScalarType t) {
281298

282299
switch (t) {
283300
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
301+
case ScalarType::ComplexHalf:
302+
return "ComplexHalf";
284303
case ScalarType::QInt8:
285304
return "QInt8";
286305
case ScalarType::QUInt8:
@@ -291,8 +310,6 @@ inline const char* toString(ScalarType t) {
291310
return "QUInt4x2";
292311
case ScalarType::QUInt2x4:
293312
return "QUInt2x4";
294-
case ScalarType::ComplexHalf:
295-
return "ComplexHalf";
296313
case ScalarType::Bits1x8:
297314
return "Bits1x8";
298315
case ScalarType::Bits2x4:
@@ -326,6 +343,8 @@ inline size_t elementSize(ScalarType t) {
326343

327344
switch (t) {
328345
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CASE_ELEMENTSIZE_CASE)
346+
case ScalarType::ComplexHalf:
347+
return sizeof(at::Half) * 2;
329348
case ScalarType::QInt8:
330349
case ScalarType::QUInt8:
331350
case ScalarType::QUInt4x2:
@@ -334,10 +353,15 @@ inline size_t elementSize(ScalarType t) {
334353
case ScalarType::Bits2x4:
335354
case ScalarType::Bits4x2:
336355
case ScalarType::Bits8:
356+
case ScalarType::Float8_e5m2fnuz:
357+
case ScalarType::Float8_e4m3fnuz:
358+
case ScalarType::Float8_e8m0fnu:
359+
case ScalarType::Float4_e2m1fn_x2:
337360
return 1;
338361
case ScalarType::QInt32:
339-
case ScalarType::Bits16:
340362
return 4;
363+
case ScalarType::Bits16:
364+
return 2;
341365
default:
342366
TORCH_CHECK(false, "Unknown ScalarType");
343367
}
@@ -354,15 +378,14 @@ inline bool isIntegralType(ScalarType t, bool includeBool) {
354378
}
355379

356380
inline bool isFloat8Type(ScalarType t) {
357-
return t == ScalarType::Float8_e5m2 || t == ScalarType::Float8_e4m3fn;
358-
// || t == ScalarType::Float8_e5m2fnuz
359-
// || t == ScalarType::Float8_e4m3fnuz
360-
// || t == ScalarType::Float8_e8m0fnu
381+
return t == ScalarType::Float8_e5m2 || t == ScalarType::Float8_e4m3fn ||
382+
t == ScalarType::Float8_e5m2fnuz || t == ScalarType::Float8_e4m3fnuz ||
383+
t == ScalarType::Float8_e8m0fnu;
361384
}
362385

363386
inline bool isReducedFloatingType(ScalarType t) {
364-
return t == ScalarType::Half || t == ScalarType::BFloat16 || isFloat8Type(t);
365-
//|| t == ScalarType::Float4_e2m1fn_x2
387+
return t == ScalarType::Half || t == ScalarType::BFloat16 ||
388+
isFloat8Type(t) || t == ScalarType::Float4_e2m1fn_x2;
366389
}
367390

368391
inline bool isFloatingType(ScalarType t) {
@@ -371,9 +394,8 @@ inline bool isFloatingType(ScalarType t) {
371394
}
372395

373396
inline bool isComplexType(ScalarType t) {
374-
return (
375-
/* t == ScalarType::ComplexHalf || */ t == ScalarType::ComplexFloat ||
376-
t == ScalarType::ComplexDouble);
397+
return (t == ScalarType::ComplexHalf || t == ScalarType::ComplexFloat ||
398+
t == ScalarType::ComplexDouble);
377399
}
378400

379401
inline bool isSignedType(ScalarType t) {
@@ -408,9 +430,9 @@ inline bool isSignedType(ScalarType t) {
408430
CASE_ISSIGNED(Float8_e4m3fn);
409431

410432
// Complex types (treated as signed)
433+
case ScalarType::ComplexHalf:
411434
case ScalarType::ComplexFloat:
412435
case ScalarType::ComplexDouble:
413-
case ScalarType::ComplexHalf:
414436
return true;
415437

416438
// Signed quantized types (explicitly return true)

paddle/phi/api/include/compat/c10/core/Stream.cpp

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -44,9 +44,11 @@ void* Stream::native_handle() const {
4444
return reinterpret_cast<void*>(static_cast<intptr_t>(id_));
4545
}
4646
#endif
47-
PADDLE_THROW(::common::errors::Unimplemented(
48-
"c10::Stream::native_handle() is not supported for device type %d",
49-
static_cast<int>(device_type())));
47+
// Match PyTorch error message format for unsupported device types
48+
PD_CHECK(false,
49+
"native_handle() is not supported for this device type (",
50+
static_cast<int>(device_type()),
51+
")");
5052
}
5153

5254
bool Stream::query() const {

paddle/phi/api/include/compat/c10/core/Stream.h

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -90,9 +90,8 @@ class Stream final {
9090
};
9191

9292
inline std::ostream& operator<<(std::ostream& os, const Stream& s) {
93-
os << "Stream(device_type=" << static_cast<int>(s.device_type())
94-
<< ", device_index=" << static_cast<int>(s.device_index())
95-
<< ", id=" << s.id() << ")";
93+
// Format: "stream {id} on device {device_type}:{device_index}"
94+
os << "stream " << s.id() << " on device " << s.device();
9695
return os;
9796
}
9897

paddle/phi/api/include/compat/torch/csrc/api/include/torch/cuda.h

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,5 @@ void synchronize(int64_t device_index = -1);
2828

2929
} // namespace torch::cuda
3030
namespace at::cuda {
31-
using torch::cuda::device_count;
32-
using torch::cuda::is_available;
3331
using torch::cuda::synchronize;
3432
} // namespace at::cuda

test/cpp/compat/c10_ScalarType_test.cc

Lines changed: 47 additions & 62 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,6 @@
2525
#include <c10/cuda/CUDAFunctions.h>
2626
#include <c10/cuda/CUDAGuard.h>
2727
#endif
28-
#include <sstream>
2928
#include "ATen/ATen.h"
3029
#include "gtest/gtest.h"
3130
#include "paddle/phi/common/float16.h"
@@ -92,8 +91,32 @@ TEST(TensorBaseTest, TypeCheckingAPIs) {
9291
ASSERT_FALSE(bool_tensor.is_signed());
9392
}
9493

95-
TEST(ScalarTypeCompatTest, ScalarTypeUtilityBranches) {
94+
TEST(ScalarTypeTest, RestoredCompatScalarTypesKeepSourceLevelSemantics) {
95+
EXPECT_EQ(static_cast<int>(c10::ScalarType::ComplexHalf), 8);
96+
EXPECT_EQ(static_cast<int>(c10::ScalarType::QInt8), 12);
97+
EXPECT_EQ(static_cast<int>(c10::ScalarType::Bits16), 22);
98+
EXPECT_EQ(static_cast<int>(c10::ScalarType::Float8_e5m2fnuz), 25);
99+
EXPECT_EQ(static_cast<int>(c10::ScalarType::Float4_e2m1fn_x2), 45);
100+
EXPECT_EQ(c10::NumScalarTypes, 47);
101+
102+
EXPECT_EQ(c10::kComplexHalf, c10::ScalarType::ComplexHalf);
103+
EXPECT_EQ(c10::kQInt8, c10::ScalarType::QInt8);
104+
EXPECT_EQ(c10::kBits16, c10::ScalarType::Bits16);
105+
EXPECT_EQ(c10::kFloat8_e4m3fnuz, c10::ScalarType::Float8_e4m3fnuz);
106+
EXPECT_EQ(c10::kFloat8_e8m0fnu, c10::ScalarType::Float8_e8m0fnu);
107+
EXPECT_EQ(c10::kFloat4_e2m1fn_x2, c10::ScalarType::Float4_e2m1fn_x2);
108+
EXPECT_EQ(c10::kUndefined, c10::ScalarType::Undefined);
109+
110+
EXPECT_STREQ(c10::toString(c10::ScalarType::ComplexHalf), "ComplexHalf");
111+
EXPECT_STREQ(c10::toString(c10::ScalarType::QInt8), "QInt8");
112+
EXPECT_STREQ(c10::toString(c10::ScalarType::QUInt8), "QUInt8");
113+
EXPECT_STREQ(c10::toString(c10::ScalarType::QInt32), "QInt32");
114+
EXPECT_STREQ(c10::toString(c10::ScalarType::QUInt4x2), "QUInt4x2");
115+
EXPECT_STREQ(c10::toString(c10::ScalarType::QUInt2x4), "QUInt2x4");
96116
EXPECT_STREQ(c10::toString(c10::ScalarType::Bits1x8), "Bits1x8");
117+
EXPECT_STREQ(c10::toString(c10::ScalarType::Bits2x4), "Bits2x4");
118+
EXPECT_STREQ(c10::toString(c10::ScalarType::Bits4x2), "Bits4x2");
119+
EXPECT_STREQ(c10::toString(c10::ScalarType::Bits8), "Bits8");
97120
EXPECT_STREQ(c10::toString(c10::ScalarType::Bits16), "Bits16");
98121
EXPECT_STREQ(c10::toString(c10::ScalarType::Float8_e5m2fnuz),
99122
"Float8_e5m2fnuz");
@@ -104,65 +127,27 @@ TEST(ScalarTypeCompatTest, ScalarTypeUtilityBranches) {
104127
EXPECT_STREQ(c10::toString(c10::ScalarType::Float4_e2m1fn_x2),
105128
"Float4_e2m1fn_x2");
106129
EXPECT_STREQ(c10::toString(c10::ScalarType::Undefined), "Undefined");
107-
EXPECT_STREQ(c10::toString(static_cast<c10::ScalarType>(-1)),
108-
"UNKNOWN_SCALAR");
109-
110-
EXPECT_EQ(c10::elementSize(c10::ScalarType::QInt8), static_cast<size_t>(1));
111-
EXPECT_EQ(c10::elementSize(c10::ScalarType::QUInt4x2),
112-
static_cast<size_t>(1));
113-
EXPECT_EQ(c10::elementSize(c10::ScalarType::QInt32), static_cast<size_t>(4));
114-
EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits16), static_cast<size_t>(4));
115-
EXPECT_THROW(c10::elementSize(c10::ScalarType::Undefined), ::std::exception);
116-
117-
EXPECT_TRUE(c10::isIntegralType(c10::ScalarType::Bool, true));
118-
EXPECT_FALSE(c10::isIntegralType(c10::ScalarType::Bool, false));
119-
EXPECT_TRUE(c10::isFloat8Type(c10::ScalarType::Float8_e5m2));
120-
EXPECT_FALSE(c10::isFloat8Type(c10::ScalarType::Float8_e4m3fnuz));
121-
EXPECT_TRUE(c10::isReducedFloatingType(c10::ScalarType::BFloat16));
122-
EXPECT_TRUE(c10::isFloatingType(c10::ScalarType::Float));
123-
EXPECT_FALSE(c10::isComplexType(c10::ScalarType::ComplexHalf));
124-
125-
EXPECT_TRUE(c10::isSignedType(c10::ScalarType::Int1));
126-
EXPECT_FALSE(c10::isSignedType(c10::ScalarType::UInt3));
127-
EXPECT_FALSE(c10::isSignedType(c10::ScalarType::QUInt8));
128-
EXPECT_TRUE(c10::isSignedType(c10::ScalarType::Float8_e5m2fnuz));
129-
EXPECT_THROW(c10::isSignedType(c10::ScalarType::Undefined), ::std::exception);
130-
131-
std::ostringstream oss;
132-
oss << c10::ScalarType::UInt7;
133-
EXPECT_EQ(oss.str(), "UInt7");
134-
}
135-
136-
TEST(ScalarTypeCompatTest, AdditionalEnumAndPredicateBranches) {
137-
EXPECT_STREQ(c10::toString(c10::ScalarType::QInt8), "QInt8");
138-
EXPECT_STREQ(c10::toString(c10::ScalarType::QUInt8), "QUInt8");
139-
EXPECT_STREQ(c10::toString(c10::ScalarType::QInt32), "QInt32");
140-
EXPECT_STREQ(c10::toString(c10::ScalarType::QUInt4x2), "QUInt4x2");
141-
EXPECT_STREQ(c10::toString(c10::ScalarType::QUInt2x4), "QUInt2x4");
142-
EXPECT_STREQ(c10::toString(c10::ScalarType::ComplexHalf), "ComplexHalf");
143-
EXPECT_STREQ(c10::toString(c10::ScalarType::Bits2x4), "Bits2x4");
144-
EXPECT_STREQ(c10::toString(c10::ScalarType::Bits4x2), "Bits4x2");
145-
EXPECT_STREQ(c10::toString(c10::ScalarType::Bits8), "Bits8");
146-
147-
EXPECT_EQ(c10::elementSize(c10::ScalarType::QUInt8), static_cast<size_t>(1));
148-
EXPECT_EQ(c10::elementSize(c10::ScalarType::QUInt2x4),
149-
static_cast<size_t>(1));
150-
EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits2x4), static_cast<size_t>(1));
151-
EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits4x2), static_cast<size_t>(1));
152-
EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits8), static_cast<size_t>(1));
153-
154-
EXPECT_TRUE(c10::isIntegralType(c10::ScalarType::UInt64, false));
155-
EXPECT_FALSE(c10::isIntegralType(c10::ScalarType::Float, true));
156-
EXPECT_TRUE(c10::isFloat8Type(c10::ScalarType::Float8_e4m3fn));
157-
EXPECT_TRUE(c10::isReducedFloatingType(c10::ScalarType::Half));
158-
EXPECT_FALSE(c10::isReducedFloatingType(c10::ScalarType::Float));
159-
EXPECT_TRUE(c10::isFloatingType(c10::ScalarType::Half));
160-
EXPECT_TRUE(c10::isComplexType(c10::ScalarType::ComplexFloat));
161130

162-
EXPECT_TRUE(c10::isSignedType(c10::ScalarType::QInt8));
163-
EXPECT_TRUE(c10::isSignedType(c10::ScalarType::ComplexHalf));
164-
EXPECT_FALSE(c10::isSignedType(c10::ScalarType::Byte));
165-
EXPECT_FALSE(c10::isSignedType(c10::ScalarType::Bool));
166-
EXPECT_THROW(c10::isSignedType(c10::ScalarType::NumOptions),
167-
::std::exception);
131+
EXPECT_EQ(c10::elementSize(c10::ScalarType::ComplexHalf),
132+
sizeof(at::Half) * 2);
133+
EXPECT_EQ(c10::elementSize(c10::ScalarType::QInt8), 1U);
134+
EXPECT_EQ(c10::elementSize(c10::ScalarType::QUInt8), 1U);
135+
EXPECT_EQ(c10::elementSize(c10::ScalarType::QInt32), 4U);
136+
EXPECT_EQ(c10::elementSize(c10::ScalarType::QUInt4x2), 1U);
137+
EXPECT_EQ(c10::elementSize(c10::ScalarType::QUInt2x4), 1U);
138+
EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits1x8), 1U);
139+
EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits2x4), 1U);
140+
EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits4x2), 1U);
141+
EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits8), 1U);
142+
EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits16), 2U);
143+
EXPECT_EQ(c10::elementSize(c10::ScalarType::Float8_e5m2fnuz), 1U);
144+
EXPECT_EQ(c10::elementSize(c10::ScalarType::Float8_e4m3fnuz), 1U);
145+
EXPECT_EQ(c10::elementSize(c10::ScalarType::Float8_e8m0fnu), 1U);
146+
EXPECT_EQ(c10::elementSize(c10::ScalarType::Float4_e2m1fn_x2), 1U);
147+
148+
EXPECT_TRUE(c10::isComplexType(c10::ScalarType::ComplexHalf));
149+
EXPECT_TRUE(c10::isFloat8Type(c10::ScalarType::Float8_e5m2fnuz));
150+
EXPECT_TRUE(c10::isFloat8Type(c10::ScalarType::Float8_e4m3fnuz));
151+
EXPECT_TRUE(c10::isFloat8Type(c10::ScalarType::Float8_e8m0fnu));
152+
EXPECT_TRUE(c10::isReducedFloatingType(c10::ScalarType::Float4_e2m1fn_x2));
168153
}

0 commit comments

Comments
 (0)