Skip to content

[Cpp API Compatibility] Align event api#78553

Merged
SigureMo merged 8 commits intoPaddlePaddle:developfrom
youge325:align-event-api
Apr 2, 2026
Merged

[Cpp API Compatibility] Align event api#78553
SigureMo merged 8 commits intoPaddlePaddle:developfrom
youge325:align-event-api

Conversation

@youge325
Copy link
Copy Markdown
Contributor

@youge325 youge325 commented Apr 1, 2026

PR Category

Execute Infrastructure

PR Types

Improvements

Description

拆分自 #78484

Align Event related APIs,将 c10::Event 实现为完整的跨平台兼容类。

变更详情

问题背景

历史上 c10::Event 存在以下问题:

  1. #ifdef PADDLE_WITH_CUDA 包裹,非 CUDA 构建时不可用
  2. 缺少 EventFlag 枚举和对应构造函数
  3. 缺少移动语义和属性读取接口
  4. CPU 路径异常行为不一致

变更内容 (c10/core/Event.h)

重构 c10::Event 为完整的跨平台兼容类:

// EventFlag 枚举
enum class EventFlag { PYTORCH_DEFAULT, BACKEND_DEFAULT, INVALID };

// 完整构造函数
Event(const DeviceType device_type,
      const EventFlag flag = EventFlag::PYTORCH_DEFAULT);

// 禁止拷贝、允许移动
Event(const Event&) = delete;
Event& operator=(const Event&) = delete;
Event(Event&&) = default;
Event& operator=(Event&&) = default;

// 属性读取
Device device() const noexcept;
DeviceType device_type() const noexcept;
DeviceIndex device_index() const noexcept;
EventFlag flag() const noexcept;
bool was_marked_for_recording() const noexcept;

// 记录方法
void record(const Stream& stream);
void recordOnce(const Stream& stream);

// 同步/查询
void block(const Stream& stream) const;
bool query() const;
void synchronize() const;
double elapsedTime(const Event& event) const;
void* eventId() const;

跨平台支持

  • c10::Event 移出 #ifdef PADDLE_WITH_CUDA 条件编译
  • CUDA 构建时包含完整的 CUDA event 功能
  • 非 CUDA 构建时提供占位实现,抛出 Backend doesn't support events 异常

对齐效果

测试项 当前 Paddle PyTorch 结论
EventCompatTest.EventDefault / EventWithFlag / EventRecordThrows / EventRecordOnceThrows / EventMove / EventDevice 已进入常规回归 一致 ✅ 已纳入回归

相关文档

是否引起精度变化

Copilot AI review requested due to automatic review settings April 1, 2026 11:17
@paddle-bot
Copy link
Copy Markdown

paddle-bot bot commented Apr 1, 2026

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Apr 1, 2026
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR updates the C++ compat layer to better match PyTorch’s stream/event APIs, including stream-pool signature parity and a more feature-complete c10::Event surface.

Changes:

  • Update c10::cuda::getStreamFromPool to support the (int priority, DeviceIndex) overload and map negative priorities to the high-priority pool.
  • Expand c10::Event to track device metadata, add PyTorch-like methods (e.g., recordOnce, query, eventId, synchronize), and tighten stream/device-type checks.
  • Remove deprecated/raw stream record_stream overloads and legacy raw_stream() exposure.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
paddle/phi/api/include/compat/c10/cuda/CUDAStream.h Adds PyTorch-shaped getStreamFromPool(int priority, ...) and updates priority selection logic.
paddle/phi/api/include/compat/c10/core/Stream.h Exposes c10::Stream into namespace at for API compatibility.
paddle/phi/api/include/compat/c10/core/Event.h Refactors/extends c10::Event interface and behavior to resemble PyTorch.
paddle/phi/api/include/compat/ATen/ops/record_stream.h Removes deprecated record_stream overloads (CUDAStream / cudaStream_t) in favor of at::Stream.
paddle/phi/api/include/compat/ATen/core/TensorBody.h Removes the corresponding record_stream overload declarations.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

#include <c10/core/Stream.h>

#ifdef PADDLE_WITH_CUDA
#include <c10/cuda/CUDAStream.h>
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This header uses std::mutex/std::unique_lock in EventPool, but doesn't include <mutex> directly (it currently relies on transitive includes from other headers). Adding an explicit #include <mutex> under the PADDLE_WITH_CUDA guard would make this file self-contained and less fragile to include-order changes.

Suggested change
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAStream.h>
#include <mutex>

Copilot uses AI. Check for mistakes.
Comment on lines 687 to 689
}

void record_stream(at::Stream s) const;
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that the record_stream(at::cuda::CUDAStream) overload was removed, the forward-declaration block for at::cuda::CUDAStream near the top of this file (and its comment) appears unused and misleading (the file no longer declares any overload needing it). Consider removing that forward declaration/comment to avoid stale documentation and unnecessary namespace pollution.

Copilot uses AI. Check for mistakes.
@youge325
Copy link
Copy Markdown
Contributor Author

youge325 commented Apr 2, 2026

/re-run all-failed

@codecov-commenter
Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 89.28571% with 9 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@638c7c3). Learn more about missing BASE report.

Files with missing lines Patch % Lines
paddle/phi/api/include/compat/c10/core/Event.h 89.28% 9 Missing ⚠️

❌ Your patch status has failed because the patch coverage (89.28%) is below the target coverage (90.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files
@@            Coverage Diff             @@
##             develop   #78553   +/-   ##
==========================================
  Coverage           ?   89.28%           
==========================================
  Files              ?        1           
  Lines              ?       84           
  Branches           ?        0           
==========================================
  Hits               ?       75           
  Misses             ?        9           
  Partials           ?        0           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@SigureMo
Copy link
Copy Markdown
Member

SigureMo commented Apr 2, 2026

/skip-reason: 未覆盖分支均为正常逻辑不可达分支,为确保与 PyTorch 逻辑对齐而添加的完善逻辑

Copy link
Copy Markdown
Member

@SigureMo SigureMo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTMeow 🐾

@SigureMo SigureMo merged commit 0a5a78c into PaddlePaddle:develop Apr 2, 2026
148 of 157 checks passed
@youge325 youge325 deleted the align-event-api branch April 2, 2026 13:31
liuhao2638 pushed a commit to liuhao2638/Paddle that referenced this pull request Apr 7, 2026
YuhanXu pushed a commit to YuhanXu/Paddle that referenced this pull request Apr 14, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants