[Cpp API Compatibility] Fix flashmla compile #78550
SigureMo merged 5 commits into PaddlePaddle:develop
Conversation
Your PR was submitted successfully. Thank you for your contribution to the open-source project!
Pull request overview
This PR updates Paddle’s C++ API compatibility layer to improve PyTorch-compat behavior and fix compilation issues encountered by FlashMLA (split out from #78484).
Changes:
- Removes torch/cuda.h re-exports of device_count / is_available into at::cuda to avoid symbol/redefinition conflicts.
- Refactors c10::ScalarType enum/utilities and trims related unit tests.
- Aligns some compat behaviors/messages (e.g., c10::Stream printing, added bounds/defined checks in a few ATen ops).
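The re-export conflict and the unified entry point can be sketched as follows. This is a hypothetical, simplified model (stub namespaces with stand-in return values, not Paddle's real headers): the single definition stays in at::cuda, as ATen/cuda/CUDAContext.h provides it, and the compat layer forwards from torch::cuda rather than injecting a second definition back into at::cuda.

```cpp
// Sketch only: stand-in stubs, not the real ATen/Paddle implementations.
namespace at::cuda {
// Single native definition (the role played by ATen/cuda/CUDAContext.h).
inline bool is_available() { return false; }  // stand-in return value
inline int device_count() { return 0; }       // stand-in return value
}  // namespace at::cuda

namespace torch::cuda {
// Compat entry point: forward to ATen instead of re-exporting these
// symbols into at::cuda, which is what caused the redefinition error.
inline bool is_available() { return at::cuda::is_available(); }
inline int device_count() { return at::cuda::device_count(); }
}  // namespace torch::cuda
```

With this shape, code compiled against either namespace sees exactly one definition per symbol.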
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| test/cpp/compat/c10_ScalarType_test.cc | Removes ScalarType utility branch tests (and an unused include). |
| paddle/phi/api/include/compat/torch/csrc/api/include/torch/cuda.h | Stops re-exporting device_count / is_available into at::cuda. |
| paddle/phi/api/include/compat/c10/core/Stream.h | Changes c10::Stream operator<< formatting to a PyTorch-like form. |
| paddle/phi/api/include/compat/c10/core/Stream.cpp | Changes unsupported native_handle() failure path/message construction. |
| paddle/phi/api/include/compat/c10/core/ScalarType.h | Refactors ScalarType enum generation and removes several explicit utility branches. |
| paddle/phi/api/include/compat/ATen/ops/std.h | Adds dimension-range validation for std implementation. |
| paddle/phi/api/include/compat/ATen/ops/select.h | Adds explicit dim/index range checks and negative index normalization. |
| paddle/phi/api/include/compat/ATen/ops/equal.h | Adds defined-tensor checks before comparing tensors. |
```cpp
inline const char* toString(ScalarType t) {
#define DEFINE_CASE(_1, _2, name) \
  case ScalarType::name:          \
    return #name;

  switch (t) {
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
    case ScalarType::QInt8:
      return "QInt8";
    case ScalarType::QUInt8:
      return "QUInt8";
    case ScalarType::QInt32:
      return "QInt32";
    case ScalarType::QUInt4x2:
      return "QUInt4x2";
    case ScalarType::QUInt2x4:
      return "QUInt2x4";
    case ScalarType::ComplexHalf:
      return "ComplexHalf";
    case ScalarType::Bits1x8:
      return "Bits1x8";
    case ScalarType::Bits2x4:
      return "Bits2x4";
    case ScalarType::Bits4x2:
      return "Bits4x2";
    case ScalarType::Bits8:
      return "Bits8";
    case ScalarType::Bits16:
      return "Bits16";
    case ScalarType::Float8_e5m2fnuz:
      return "Float8_e5m2fnuz";
    case ScalarType::Float8_e4m3fnuz:
      return "Float8_e4m3fnuz";
    case ScalarType::Float8_e8m0fnu:
      return "Float8_e8m0fnu";
    case ScalarType::Float4_e2m1fn_x2:
      return "Float4_e2m1fn_x2";
    case ScalarType::Undefined:
      return "Undefined";
    default:
      return "UNKNOWN_SCALAR";
  }
#undef DEFINE_CASE
}

inline size_t elementSize(ScalarType t) {
#define CASE_ELEMENTSIZE_CASE(ctype, _2, name) \
  case ScalarType::name:                       \
    return sizeof(ctype);

  switch (t) {
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CASE_ELEMENTSIZE_CASE)
    case ScalarType::QInt8:
    case ScalarType::QUInt8:
    case ScalarType::QUInt4x2:
    case ScalarType::QUInt2x4:
    case ScalarType::Bits1x8:
    case ScalarType::Bits2x4:
    case ScalarType::Bits4x2:
    case ScalarType::Bits8:
      return 1;
    case ScalarType::QInt32:
    case ScalarType::Bits16:
      return 4;
    default:
      TORCH_CHECK(false, "Unknown ScalarType");
  }
```
ScalarType still declares the quantized enums via AT_FORALL_QINT_TYPES(...), but toString() and elementSize() no longer handle QInt8/QUInt8/QInt32/QUInt4x2/QUInt2x4 (the switches only expand AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS). This is a behavioral regression: printing these scalar types will return "UNKNOWN_SCALAR", and elementSize() will throw. Add QInt cases back (e.g., by also expanding AT_FORALL_QINT_TYPES in these switches) and ensure their element sizes match the intended semantics.
```cpp
PD_CHECK(false,
         "native_handle() is not supported for this device type (",
```
PD_CHECK(false, ...) will throw a PD_Exception that appends an additional "Expected false, but it's not satisfied." context. If the goal is to match PyTorch's error message format exactly (as the comment states), this will not produce an exact match. Prefer throwing directly (e.g., TORCH_CHECK(false, ...) / PD_THROW(...)) or pass the real condition into PD_CHECK(...) instead of false so the extra context remains meaningful.
```diff
- PD_CHECK(false,
-          "native_handle() is not supported for this device type (",
+ PD_THROW("native_handle() is not supported for this device type (",
```
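The difference the reviewer describes can be demonstrated with a minimal sketch. MY_CHECK and MY_THROW below are hypothetical stand-ins, not Paddle's real PD_CHECK/PD_THROW: the CHECK-style macro appends its own "Expected ..., but it's not satisfied." context to the message, so it can never reproduce an upstream error string verbatim, while the THROW-style macro emits exactly what it is given.

```cpp
#include <sstream>
#include <stdexcept>
#include <string>

// Hypothetical CHECK-style macro: appends its own context on failure.
#define MY_CHECK(cond, msg)                                          \
  do {                                                               \
    if (!(cond)) {                                                   \
      std::ostringstream oss_;                                       \
      oss_ << msg << " Expected " #cond ", but it's not satisfied."; \
      throw std::runtime_error(oss_.str());                          \
    }                                                                \
  } while (0)

// Hypothetical THROW-style macro: emits the message verbatim.
#define MY_THROW(msg)                     \
  do {                                    \
    std::ostringstream oss_;              \
    oss_ << msg;                          \
    throw std::runtime_error(oss_.str()); \
  } while (0)

inline std::string check_message() {
  try {
    MY_CHECK(false, "native_handle() is not supported.");
  } catch (const std::runtime_error& e) {
    return e.what();  // message plus the appended "Expected ..." context
  }
  return "";
}

inline std::string throw_message() {
  try {
    MY_THROW("native_handle() is not supported.");
  } catch (const std::runtime_error& e) {
    return e.what();  // message exactly as written
  }
  return "";
}
```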
```cpp
if (dim < 0 || dim >= self.dim()) {
  PD_CHECK(false,
           "select(): index ",
           orig_dim,
           " out of range for tensor of size ",
           self.sizes(),
           " at dimension ",
           orig_dim);
}
// Handle negative index
int64_t orig_index = index;
if (index < 0) {
  int64_t dim_size = self.size(dim);
  index = dim_size + index;
}
// Check index is valid
if (index < 0 || index >= self.size(dim)) {
  PD_CHECK(false,
           "select(): index ",
           orig_index,
           " out of range for tensor of size ",
           self.sizes(),
           " at dimension ",
           orig_dim < 0 ? orig_dim + self.dim() : orig_dim);
}
```
New error handling was added for invalid dim / index values, but the existing select tests (e.g., test/cpp/compat/ATen_select_test.cc) don't cover these out-of-range branches. Add test cases that assert an exception is thrown for (1) dim out of range (including negative beyond -self.dim()), and (2) index out of range (including negative beyond -size(dim)).
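The validation logic under review can be isolated into a standalone sketch, which also shows the cases such tests would need to cover. This uses a plain std::vector of extents instead of at::Tensor, and the error text is illustrative, not Paddle's exact message.

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Sketch of select()-style dim/index validation: negative dim/index are
// normalized by adding the corresponding extent; out-of-range values throw.
inline int64_t normalize_select_index(const std::vector<int64_t>& sizes,
                                      int64_t dim,
                                      int64_t index) {
  const int64_t ndim = static_cast<int64_t>(sizes.size());
  const int64_t orig_dim = dim;
  if (dim < 0) dim += ndim;  // negative dim counts from the end
  if (dim < 0 || dim >= ndim) {
    throw std::out_of_range("select(): dim " + std::to_string(orig_dim) +
                            " out of range for " + std::to_string(ndim) +
                            "-d tensor");
  }
  const int64_t dim_size = sizes[dim];
  const int64_t orig_index = index;
  if (index < 0) index += dim_size;  // negative index counts from the end
  if (index < 0 || index >= dim_size) {
    throw std::out_of_range("select(): index " + std::to_string(orig_index) +
                            " out of range at dimension " +
                            std::to_string(dim));
  }
  return index;  // normalized, guaranteed in [0, dim_size)
}
```

A test along the lines the reviewer asks for would assert normal and boundary behavior (e.g., index -1 maps to dim_size - 1) and that dim beyond [-ndim, ndim) or index beyond [-dim_size, dim_size) throws.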
Codecov Report
✅ All modified and coverable lines are covered by tests.

```
@@ Coverage Diff @@
##           develop   #78550   +/-   ##
===========================================
  Coverage         ?   100.00%
===========================================
  Files            ?         2
  Lines            ?         6
  Branches         ?         0
===========================================
  Hits             ?         6
  Misses           ?         0
  Partials         ?         0
```

☔ View full report in Codecov by Sentry.
PR Category
Execute Infrastructure
PR Types
Bug fixes
Description
Split from #78484.
Fixes the redefinition error encountered when compiling FlashMLA. To use at::cuda::is_available(), callers must uniformly go through torch::cuda::is_available(); the definitions that torch/cuda.h duplicated into the at::cuda namespace are removed.
Change details
1. Fix the torch/cuda.h header issue

The original code re-exported torch::cuda::is_available() into the at::cuda namespace, conflicting with the definition in ATen/cuda/CUDAContext.h; FlashMLA's use of at::cuda::is_available() then triggered a compile error. The fix unifies the entry point:
- torch/cuda.h no longer exports is_available into at::cuda
- torch::cuda::is_available() serves as the cross-library compatibility entry point

2. Update test code to use torch::cuda::is_available()

Modified test files:
- test/cpp/compat/ATen_TensorAccessor_test.cc
- test/cpp/compat/compat_basic_test.cc
- test/cpp/compat/c10_storage_test.cc

3. Fix Stream-related interfaces (c10/core/Stream.h, c10/core/Stream.cpp)

Aligns:
- Stream's comparison operator behavior
- the exception-throwing semantics of native_handle() on CPU devices

4. Trim the ScalarType implementation
Temporarily removes some ScalarType-related code that is not yet fully aligned (it will be reintroduced in a follow-up PR).

Related documentation

Does this cause precision changes?
No