[Cpp API Compatibility] Align misc apis #78555

Merged
SigureMo merged 9 commits into PaddlePaddle:develop from youge325:align-misc-apis
Apr 3, 2026

youge325 (Contributor) commented Apr 1, 2026

PR Category

Execute Infrastructure

PR Types

Improvements

Description

Split out from #78484.

Miscellaneous API alignment: Allocator, CUDAFunctions, CUDAGuard, CUDAStream, version.h, and others.

Change details

1. c10/core/Allocator.h: fill in missing interfaces

Added:

// Type definitions
typedef uint64_t CaptureId_t;
typedef int64_t MempoolId_t;
struct MempoolIdHash { ... };

// DataPtr method
void* mutable_get() const;

// Allocator registration
void SetAllocator(DeviceType t, Allocator* alloc, uint8_t priority);
Allocator* GetAllocator(const DeviceType& t);

// Helper class
struct InefficientStdFunctionContext {
  static DataPtr makeDataPtr(void* ptr, std::function<void(void*)> deleter, Device device);
};
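
As a rough illustration of the registration pair listed above, here is a minimal sketch of a per-device allocator registry. It assumes the "equal or higher priority wins" rule that upstream PyTorch uses for SetAllocator; whether the compat layer stores allocators this way is an assumption, and DeviceType and Allocator below are hypothetical stand-ins rather than the real c10 types.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical stand-ins for illustration; the real types live in c10.
enum class DeviceType : std::int8_t { CPU = 0, CUDA = 1, COUNT = 2 };
struct Allocator {
  virtual ~Allocator() = default;
};

namespace {
constexpr std::size_t kNumDevices = static_cast<std::size_t>(DeviceType::COUNT);
std::array<Allocator*, kNumDevices> g_allocators{};    // all nullptr initially
std::array<std::uint8_t, kNumDevices> g_priorities{};  // all 0 initially
}  // namespace

// Registers `alloc` for device type `t` unless a strictly higher-priority
// allocator is already registered (assumed rule, mirroring upstream PyTorch).
void SetAllocator(DeviceType t, Allocator* alloc, std::uint8_t priority) {
  const auto i = static_cast<std::size_t>(t);
  assert(i < kNumDevices);
  if (priority >= g_priorities[i]) {
    g_allocators[i] = alloc;
    g_priorities[i] = priority;
  }
}

// Returns the registered allocator, or nullptr if none was registered.
Allocator* GetAllocator(const DeviceType& t) {
  const auto i = static_cast<std::size_t>(t);
  assert(i < kNumDevices);
  return g_allocators[i];
}
```

With this rule, registering an allocator at priority 1 and then another at priority 0 leaves the first one in place, while a second registration at priority 1 replaces it.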

Behavior fix:

  • is_simple_data_ptr() semantics corrected to get() == get_context() (consistent with PyTorch)
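
To make the get() == get_context() rule concrete, the sketch below models a data pointer that carries a separate deleter context. MiniDataPtr, FnContext, and the two factory functions are hypothetical names for illustration only; they are not the compat layer's classes. A pointer whose deleter context is the data pointer itself counts as "simple", while one that wraps a std::function deleter (the InefficientStdFunctionContext pattern) does not.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <functional>
#include <memory>
#include <utility>

// Hypothetical stand-in for c10::DataPtr, for illustration only.
class MiniDataPtr {
 public:
  using Deleter = void (*)(void*);
  MiniDataPtr(void* data, void* ctx, Deleter deleter)
      : data_(data), ctx_(ctx, deleter) {}
  void* get() const { return data_; }
  void* get_context() const { return ctx_.get(); }
  // "Simple" means the deleter context is the data pointer itself.
  bool is_simple_data_ptr() const { return get() == get_context(); }

 private:
  void* data_;
  std::unique_ptr<void, Deleter> ctx_;  // runs the deleter on the context
};

// Heap context that owns a std::function deleter.
struct FnContext {
  void* ptr;
  std::function<void(void*)> deleter;
  ~FnContext() {
    if (deleter) deleter(ptr);
  }
};

inline void free_raw(void* p) { std::free(p); }
inline void delete_fn_context(void* ctx) { delete static_cast<FnContext*>(ctx); }

// Simple case: data and context are the same pointer.
inline MiniDataPtr make_simple(std::size_t nbytes) {
  void* p = std::malloc(nbytes);
  return MiniDataPtr(p, p, &free_raw);
}

// Non-simple case: the context is a separate heap object.
inline MiniDataPtr make_with_context(void* p, std::function<void(void*)> d) {
  auto* ctx = new FnContext{p, std::move(d)};
  return MiniDataPtr(p, ctx, &delete_fn_context);
}
```

Under these assumptions, make_simple(16).is_simple_data_ptr() is true, and a pointer built with make_with_context reports false because its context is the FnContext object.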

2. c10/cuda/CUDAFunctions.h/cpp: refactoring

  • Move part of the implementation in CUDAFunctions.h into the new file CUDAFunctions.cpp
  • Fix a Windows build issue

3. c10/cuda/CUDAGuard.h: fill in missing interfaces

New methods:

// CUDAGuard
DeviceIndex original_device() const;

// OptionalCUDAGuard
DeviceIndex original_device() const;
void reset();
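
The role of original_device() and reset() can be sketched with a CPU-only mock (C++17). Everything here is hypothetical: current_device() stands in for the CUDA runtime's current device, and the optional guard returning std::optional is a design choice of this sketch, not the signature listed in the PR.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

using DeviceIndex = std::int8_t;

// Hypothetical stand-in for the runtime's per-thread current device.
inline DeviceIndex& current_device() {
  thread_local DeviceIndex dev = 0;
  return dev;
}

// RAII guard: switches device on construction, restores it on destruction.
class MiniDeviceGuard {
 public:
  explicit MiniDeviceGuard(DeviceIndex target) : original_(current_device()) {
    current_device() = target;
  }
  ~MiniDeviceGuard() { current_device() = original_; }
  MiniDeviceGuard(const MiniDeviceGuard&) = delete;
  MiniDeviceGuard& operator=(const MiniDeviceGuard&) = delete;

  // Device that was current when the guard was constructed.
  DeviceIndex original_device() const { return original_; }

 private:
  DeviceIndex original_;
};

// Optional variant: starts empty and can be released early with reset().
class MiniOptionalDeviceGuard {
 public:
  void set_index(DeviceIndex target) {
    guard_.reset();  // restore the original device first, if engaged
    guard_.emplace(target);
  }
  std::optional<DeviceIndex> original_device() const {
    if (guard_) return guard_->original_device();
    return std::nullopt;
  }
  void reset() { guard_.reset(); }

 private:
  std::optional<MiniDeviceGuard> guard_;
};
```

In this mock, reset() restores the original device immediately instead of waiting for the guard to go out of scope.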

4. c10/cuda/CUDAStream.h: extensions

New functionality:

// Priority support
static std::tuple<int, int> priority_range();
CUDAStream(Unchecked /*unused*/, Stream stream);
int priority() const;

// Serialization
StreamData3 pack3() const;
static CUDAStream unpack3(StreamId stream_id, DeviceIndex device_index, DeviceType device_type);

// Wrapping an external stream
CUDAStream getStreamFromExternal(cudaStream_t ext_stream, c10::DeviceIndex device_index);

// Query and synchronization
bool query() const;
void synchronize() const;

// Hash support
struct hash<CUDAStream>;
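
The serialization and hashing additions can be pictured with a small mock: pack3() flattens a stream into three plain fields, unpack3() rebuilds it, and a std::hash specialization lets streams serve as keys in unordered containers. MiniStream and the StreamData3 layout below are hypothetical simplifications, not the real c10 definitions.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <unordered_set>

using StreamId = std::int64_t;
using DeviceIndex = std::int8_t;
enum class DeviceType : std::int8_t { CPU = 0, CUDA = 1 };

// Hypothetical plain-data form of a stream, for illustration only.
struct StreamData3 {
  StreamId stream_id;
  DeviceIndex device_index;
  DeviceType device_type;
};

// Hypothetical stand-in for c10::cuda::CUDAStream.
class MiniStream {
 public:
  MiniStream(StreamId id, DeviceIndex dev) : id_(id), dev_(dev) {}
  StreamId id() const { return id_; }
  DeviceIndex device_index() const { return dev_; }

  // Flatten into three plain fields that can cross an API boundary.
  StreamData3 pack3() const { return {id_, dev_, DeviceType::CUDA}; }

  // Rebuild a stream from the three fields produced by pack3().
  static MiniStream unpack3(StreamId id, DeviceIndex dev, DeviceType type) {
    assert(type == DeviceType::CUDA);
    (void)type;  // silence unused warnings when assert is compiled out
    return MiniStream(id, dev);
  }

  bool operator==(const MiniStream& o) const {
    return id_ == o.id_ && dev_ == o.dev_;
  }

 private:
  StreamId id_;
  DeviceIndex dev_;
};

namespace std {
template <>
struct hash<MiniStream> {
  std::size_t operator()(const MiniStream& s) const noexcept {
    // Combine both fields; the mixing constant is an arbitrary choice.
    std::size_t h = std::hash<StreamId>{}(s.id());
    return h ^ (std::hash<int>{}(s.device_index()) + 0x9e3779b9u + (h << 6) +
                (h >> 2));
  }
};
}  // namespace std
```

A stream that goes through pack3() and unpack3() compares equal to the original, so inserting both into a std::unordered_set keeps a single element.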

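5. New torch/version.h

Provides macros for PyTorch version information so that third-party libraries (such as DeepGEMM and FlashMLA) can detect the version of the Paddle compatibility layer.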
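
A third-party library would typically consume such macros with a preprocessor check, roughly as sketched below. The macro names TORCH_VERSION_MAJOR, TORCH_VERSION_MINOR, and TORCH_VERSION_PATCH are the ones upstream PyTorch defines; that the compat header defines exactly these, and the placeholder values used as a fallback, are assumptions. MY_LIB_TORCH_AT_LEAST and kUseNewApiPath are hypothetical names.

```cpp
// A real build would get these from <torch/version.h>. The fallback values
// are placeholders for illustration, not the versions shipped by Paddle.
#ifndef TORCH_VERSION_MAJOR
#define TORCH_VERSION_MAJOR 2
#define TORCH_VERSION_MINOR 0
#define TORCH_VERSION_PATCH 0
#endif

// Hypothetical helper: true when the advertised version is >= major.minor.
#define MY_LIB_TORCH_AT_LEAST(major, minor) \
  (TORCH_VERSION_MAJOR > (major) ||         \
   (TORCH_VERSION_MAJOR == (major) && TORCH_VERSION_MINOR >= (minor)))

// Select a code path at compile time based on the advertised version.
#if MY_LIB_TORCH_AT_LEAST(2, 0)
constexpr bool kUseNewApiPath = true;
#else
constexpr bool kUseNewApiPath = false;
#endif
```

Because the check happens in the preprocessor, a library can compile against either LibTorch or the compat layer without source changes, provided both headers expose the same macro names.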

6. Test fixes

  • test/cpp/compat/ATen_TensorAccessor_test.cc: fix the CUDA test condition
  • test/cpp/compat/c10_storage_test.cc: fix the is_simple_data_ptr assertion
  • test/cpp/compat/c10_layout_test.cc: add SparseCooTensorInferSize assertions
  • test/cpp/compat/compat_basic_test.cc: fix the CUDA test condition

Related documentation

Does this change affect numerical precision?

Copilot AI review requested due to automatic review settings April 1, 2026 11:19

paddle-bot bot commented Apr 1, 2026

Your PR has been submitted. Thanks for your contribution to this open-source project!
Please wait for the CI results first. See the Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Apr 1, 2026

Copilot AI left a comment

Pull request overview

This PR continues the C++ API compatibility alignment work by bringing several misc compat-layer APIs and behaviors closer to LibTorch/PyTorch, and updating the corresponding C++ compat tests.

Changes:

  • Switch CUDA availability checks in compat tests to torch::cuda::is_available().
  • Expand c10::cuda::CUDAStream / c10::cuda::CUDAGuard compat APIs and adjust stream-pool and allocator behaviors to better match PyTorch.
  • Add missing compat headers/sources (e.g., torch/version.h, c10/cuda/CUDAFunctions.cpp) and strengthen sparse tensor test assertions.

Reviewed changes

Copilot reviewed 12 out of 12 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| test/cpp/compat/compat_basic_test.cc | Uses torch::cuda::is_available() for CUDA-gated test sections. |
| test/cpp/compat/c10_storage_test.cc | Aligns CUDA availability checks and updates is_simple_data_ptr expectations to PyTorch semantics. |
| test/cpp/compat/c10_layout_test.cc | Adds extra sparse COO shape/dim assertions. |
| test/cpp/compat/ATen_TensorAccessor_test.cc | Uses torch::cuda::is_available() for CUDA-gated accessor test. |
| paddle/phi/api/include/compat/torch/csrc/api/include/torch/version.h | Introduces LibTorch-style version macros for the compat layer. |
| paddle/phi/api/include/compat/CMakeLists.txt | Adds c10/cuda/CUDAFunctions.cpp to the compat build sources. |
| paddle/phi/api/include/compat/c10/cuda/CUDAStream.h | Adds stream API surface (priority/query/sync/pack/unpack/hash/etc.) and new pool/external stream helpers. |
| paddle/phi/api/include/compat/c10/cuda/CUDAGuard.h | Tracks original/current device to better emulate PyTorch guard semantics. |
| paddle/phi/api/include/compat/c10/cuda/CUDAFunctions.h | Moves device_count/device_synchronize to out-of-line definitions; keeps stream sync helper gated by CUDA/HIP. |
| paddle/phi/api/include/compat/c10/cuda/CUDAFunctions.cpp | Implements device_count() and device_synchronize() out-of-line. |
| paddle/phi/api/include/compat/c10/core/Allocator.h | Aligns is_simple_data_ptr semantics and adds allocator registry/utilities. |
| paddle/phi/api/include/compat/ATen/core/ivalue.h | Adds camelCase API aliases and exposes c10::IValue alias for compatibility. |

Comment on lines 230 to +240
/**
 * Set the current CUDA stream for the device of the given stream in the
 * calling thread.
 *
 * Implements per-thread, per-device current stream semantics: the change is
 * local to the calling OS thread and does not affect any shared state such as
 * Paddle's GPUContext. Other threads continue to see their own current stream.
 */
inline CUDAStream getStreamFromPool(const bool isHighPriority,
                                    c10::DeviceIndex device_index) {
  return getStreamFromPool(isHighPriority ? -1 : 0, device_index);
Copilot AI Apr 1, 2026

The doc comment above getStreamFromPool(...) describes "Set the current CUDA stream..." but the function below is a stream-pool accessor. This mismatch is misleading for users and makes the header harder to maintain; update the comment to describe getStreamFromPool (and keep setCurrentCUDAStream documented at its own definition).
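
For context, the per-thread, per-device semantics that the misplaced doc comment describes (and which belong with setCurrentCUDAStream) can be sketched with thread_local storage. This is a CPU-only mock under stated assumptions: StreamHandle, kMaxDevices, and both functions are hypothetical and do not reflect how the compat header is actually implemented.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <thread>

constexpr int kMaxDevices = 4;  // hypothetical upper bound for the sketch
using StreamHandle = int;       // hypothetical stand-in; 0 = default stream

// One table per OS thread, one slot per device.
inline std::array<StreamHandle, kMaxDevices>& tls_current_streams() {
  thread_local std::array<StreamHandle, kMaxDevices> streams{};
  return streams;
}

// Affects only the calling thread's view of `device`.
inline void set_current_stream(int device, StreamHandle s) {
  assert(device >= 0 && device < kMaxDevices);
  tls_current_streams()[static_cast<std::size_t>(device)] = s;
}

inline StreamHandle get_current_stream(int device) {
  assert(device >= 0 && device < kMaxDevices);
  return tls_current_streams()[static_cast<std::size_t>(device)];
}
```

A stream set on the main thread is not visible from a freshly started std::thread, which still sees the default handle 0, and changes made in that thread do not leak back.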

 * Paddle's GPUContext. Other threads continue to see their own current stream.
 */
inline CUDAStream getStreamFromPool(const bool isHighPriority,
                                    c10::DeviceIndex device_index) {
Copilot AI Apr 1, 2026

getStreamFromPool(bool isHighPriority, DeviceIndex ...) no longer has a default device_index. This makes calls like getStreamFromPool(true) either fail to compile or (more dangerously) bind to the getStreamFromPool(int priority, DeviceIndex=-1) overload with priority=1, changing semantics (high-priority requested but low-priority returned). Add device_index = -1 to the bool overload (matching PyTorch) and ensure bool arguments cannot silently resolve to the int overload.

Suggested change
- c10::DeviceIndex device_index) {
+ c10::DeviceIndex device_index = -1) {
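
The overload-resolution hazard described in this comment can be reproduced in isolation. The sketch below is a standalone mock: the functions return a label instead of a stream, the namespaces before_fix and after_fix are hypothetical, and "negative priority means high priority" is assumed from the `isHighPriority ? -1 : 0` mapping in the quoted code.

```cpp
#include <cassert>
#include <string>

namespace before_fix {
// int overload: device_index has a default.
inline std::string getStreamFromPool(int priority, int /*device_index*/ = -1) {
  return priority < 0 ? "high" : "low";
}
// bool overload: no default, so a one-argument call cannot select it.
inline std::string getStreamFromPool(const bool isHighPriority,
                                     int device_index) {
  return getStreamFromPool(isHighPriority ? -1 : 0, device_index);
}
}  // namespace before_fix

namespace after_fix {
inline std::string getStreamFromPool(int priority, int /*device_index*/ = -1) {
  return priority < 0 ? "high" : "low";
}
// bool overload with the suggested default: exact match for a bool argument.
inline std::string getStreamFromPool(const bool isHighPriority,
                                     int device_index = -1) {
  return getStreamFromPool(isHighPriority ? -1 : 0, device_index);
}
}  // namespace after_fix
```

Here before_fix::getStreamFromPool(true) promotes true to the int 1 and yields "low", while after_fix::getStreamFromPool(true) selects the bool overload and yields "high". Integer arguments such as -1 still pick the int overload in both versions.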

youge325 commented Apr 1, 2026

/re-run all-failed

youge325 commented Apr 2, 2026

/re-run all-failed

youge325 commented Apr 2, 2026

/re-run all-failed

youge325 commented Apr 2, 2026

/re-run all-failed

codecov-commenter commented Apr 2, 2026

Codecov Report

❌ Patch coverage is 95.55556% with 2 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@bb246f6). Learn more about missing BASE report.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| ...addle/phi/api/include/compat/c10/cuda/CUDAStream.h | 80.00% | 2 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             develop   #78555   +/-   ##
==========================================
  Coverage           ?   95.55%           
==========================================
  Files              ?        4           
  Lines              ?       45           
  Branches           ?        0           
==========================================
  Hits               ?       43           
  Misses             ?        2           
  Partials           ?        0           

@SigureMo SigureMo closed this Apr 3, 2026
@SigureMo SigureMo reopened this Apr 3, 2026

SigureMo (Member) left a comment

LGTMeow 🐾

@SigureMo SigureMo merged commit 6d87891 into PaddlePaddle:develop Apr 3, 2026
91 of 92 checks passed
@youge325 youge325 deleted the align-misc-apis branch April 3, 2026 01:39
liuhao2638 pushed a commit to liuhao2638/Paddle that referenced this pull request Apr 7, 2026