
[Feature] Support CUDA Graph under mixed mode DeepEP communication#7344

Closed
lizexu123 wants to merge 8 commits into PaddlePaddle:develop from lizexu123:cuda_graph_stream

Conversation


@lizexu123 lizexu123 commented Apr 11, 2026

Motivation

💡 If this PR is a Cherry Pick, the PR title needs to follow the format by adding the [Cherry-Pick] label at the very beginning and appending the original PR ID at the end. For example, [Cherry-Pick][CI] Add check trigger and logic(#5191)


Modifications

Error log:

DeepEP/csrc/kernels/internode_ll.cu:553 operation would make the legacy stream depend on a capturing blocking stream

Root cause:

Python: low_latency_dispatch(return_recv_hook=True, async_finish=False)
    ↓
C++: deep_ep.cpp:1679 → launch_stream.stream() is passed to internode_ll::dispatch()
    ↓
C++: internode_ll.cu:520 → LAUNCH_KERNEL(&cfg, dispatch_func, ...)
    ↓
Macro expansion: launch.cuh:29 → CUDA_CHECK(cudaLaunchKernelEx(config, kernel, ...))
    ↓
CUDA driver: detects an implicit dependency between the legacy stream and the capturing stream → error!

// deep_ep.cpp:1614-1615
auto compute_stream = at::cuda::getCurrentCUDAStream();  // ← c10's TLS
auto launch_stream = return_recv_hook ? compute_stream : comm_stream.value();

at::cuda::getCurrentCUDAStream() returns the default stream stored in c10's TLS (stream 0, i.e. the legacy stream), rather than Paddle's current capture stream. As a result, cudaLaunchKernelEx is called on the legacy stream, while the CUDA graph capture is running on the stream created via Paddle's CudaStreamDefault.
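The two-registries mismatch described above can be modeled in plain Python. This is a conceptual sketch only: the dicts below stand in for Paddle's GPUContext and c10's TLS, and the integer "streams" are stand-ins for real CUDA stream handles — no actual CUDA objects are involved.

```python
from contextlib import contextmanager

LEGACY_STREAM = 0  # stand-in for stream 0, the legacy default stream

# Two independent notions of "current stream", mirroring the real layout:
# Paddle tracks its current stream in its own GPUContext, while the
# torch/c10 runtime that DeepEP links against keeps its own TLS slot.
paddle_ctx = {"current_stream": LEGACY_STREAM}
c10_tls = {"current_stream": LEGACY_STREAM}

@contextmanager
def paddle_stream_guard(stream):
    """Mimics paddle.device.stream_guard(): updates only Paddle's side."""
    prev = paddle_ctx["current_stream"]
    paddle_ctx["current_stream"] = stream
    try:
        yield
    finally:
        paddle_ctx["current_stream"] = prev

def get_current_cuda_stream():
    """Mimics at::cuda::getCurrentCUDAStream(): reads only c10's TLS."""
    return c10_tls["current_stream"]

capture_stream = 42  # stand-in for the stream used for graph capture
with paddle_stream_guard(capture_stream):
    # Paddle sees the capture stream ...
    assert paddle_ctx["current_stream"] == capture_stream
    # ... but DeepEP's launch-stream lookup still sees the legacy stream,
    # so the kernel would be launched outside the capture.
    assert get_current_cuda_stream() == LEGACY_STREAM
```

The model shows why switching streams on the Paddle side alone cannot fix the error: the lookup DeepEP performs never consults Paddle's state.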

What this PR changes:

# 1. Create the capture stream via the Python API → defaults to cudaStreamNonBlocking
self._capture_stream = paddle.device.Stream()
# paddle/fluid/pybind/cuda_streams_py.cc:370
# → auto stream_flag = phi::CUDAStream::StreamFlag::kStreamNonBlocking;

# 2. _DeepEPStreamGuard keeps c10's TLS in sync
with _DeepEPStreamGuard(self._capture_stream):
    # at::cuda::getCurrentCUDAStream() → returns capture_stream (not the legacy stream)
    # cudaLaunchKernelEx runs on capture_stream → captured into the graph
    # capture_stream is NonBlocking → no implicit synchronization with the legacy stream
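The guard pattern sketched above boils down to save/set/restore semantics. The class below is a minimal illustration, not the PR's actual implementation: the `get_stream`/`set_stream` callables are injected stand-ins for the real c10 TLS accessors (which the PR reaches via ctypes), so the sketch runs without CUDA.

```python
class StreamGuardSketch:
    """Save the current c10 stream on enter, install the capture stream,
    and restore the previous stream on exit.  get_stream/set_stream are
    injected stand-ins for the real c10 TLS accessors."""

    def __init__(self, stream, get_stream, set_stream):
        self._stream = stream
        self._get = get_stream
        self._set = set_stream
        self._prev = None

    def __enter__(self):
        self._prev = self._get()   # remember what c10 thought was current
        self._set(self._stream)    # point c10's TLS at the capture stream
        return self

    def __exit__(self, exc_type, exc, tb):
        self._set(self._prev)      # always restore, even on exceptions
        return False


# Exercise the guard with a dict-backed stub standing in for the TLS:
tls = {"cur": 0}
with StreamGuardSketch(5, lambda: tls["cur"],
                       lambda s: tls.__setitem__("cur", s)):
    assert tls["cur"] == 5   # inside the guard, TLS points at stream 5
assert tls["cur"] == 0       # restored on exit
```

Restoring in `__exit__` (rather than only on the happy path) matters: if capture fails mid-way, the TLS must not be left pointing at a dead capture stream.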

This could have been implemented much more simply, e.g. the way sglang/python/sglang/srt/distributed/parallel_state.py:483-510 does it:

@contextmanager
def graph_capture(self, stream=None):
    if stream is None:
        stream = torch.cuda.Stream()              # 1. create a new stream

    curr_stream = torch.cuda.current_stream()
    stream.wait_stream(curr_stream)               # 2. the new stream waits for the current stream

    with torch.cuda.stream(stream):               # 3. switch the current stream
        yield graph_capture_context

with torch.cuda.stream(stream):
# PyTorch's current stream → stream ✓
# c10's TLS → stream ✓
# DeepEP's call to at::cuda::getCurrentCUDAStream() → stream ✓
However, Paddle's paddle.device.stream_guard() only updates Paddle's own GPUContext and does not touch c10's TLS:

with paddle.device.stream_guard(stream):
    # Paddle GPUContext → stream     ✓
    # c10 TLS → not updated!         ✗
    # DeepEP's at::cuda::getCurrentCUDAStream() → legacy stream (stream 0)  ✗

That is why we have to call c10::cuda::setCurrentCUDAStream() manually via ctypes to bridge this gap.
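One practical wrinkle with the ctypes approach: C++ symbols are name-mangled, and the exact mangled name can vary across compilers and ABI versions. A small helper that probes a list of candidate symbol names keeps that brittleness in one place. This is a sketch only — the library name and the mangled symbol in the commented-out usage are illustrative assumptions, not the actual symbol shipped by any particular libtorch build.

```python
import ctypes


def load_symbol(lib, candidates):
    """Return the first symbol from `candidates` found in `lib`, else None.

    Useful when the exact mangled name of a C++ function is uncertain:
    ctypes raises AttributeError for symbols the library does not export,
    so we simply try each candidate in order.
    """
    for name in candidates:
        try:
            return getattr(lib, name)
        except AttributeError:
            continue
    return None


# Illustrative use (library path and mangled name below are assumptions):
# lib = ctypes.CDLL("libc10_cuda.so")
# set_stream = load_symbol(lib, [
#     "_ZN3c104cuda20setCurrentCUDAStreamENS0_10CUDAStreamE",
# ])
# if set_stream is None:
#     raise RuntimeError("could not locate setCurrentCUDAStream in libc10_cuda")
```

Falling back to `None` (instead of raising inside the helper) lets the caller decide whether a missing symbol is fatal or whether the guard should degrade to a no-op.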

Usage or Command

Accuracy Tests

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but their semantics must be clear.
  • Format your code and run pre-commit before committing.
  • Add unit tests. If no unit tests are added, please state the reason in this PR.
  • Provide accuracy results.
  • If the current PR targets a release branch, make sure it has first been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

@paddle-bot

paddle-bot bot commented Apr 11, 2026

Thanks for your contribution!

@CLAassistant

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you all sign our Contributor License Agreement before we can accept your contribution.
1 out of 2 committers have signed the CLA.

✅ lizexu123
❌ root


root seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you have already a GitHub account, please add the email address used for this commit to your account.
You have signed the CLA already but the status is still pending? Let us recheck it.


@fastdeploy-bot

🤖 AI Code Review | 2026-04-12

📋 Review Summary

PR overview: adds support for CUDA Graph under mixed-mode DeepEP communication.
Scope of changes: model_executor/graph_optimization/, model_executor/forward_meta.py, model_executor/utils.py, model_executor/layers/moe/ep.py
Impact tag: [Feature]

PR Convention Check

The PR title contains a valid tag ([Feature]) and the description includes the Motivation and Modifications sections, conforming to the conventions.

Issues

Severity | File | Summary
🔴 Bug | cudagraph_piecewise_backend.py:417 | Signature mismatch: the is_decode parameter was removed, but the call site still passes it

Overall Assessment

The PR syncs the c10 stream state via _DeepEPStreamGuard, fixing the stream conflict DeepEP hits during CUDA Graph capture. The core logic is clear, but there is one blocking bug that must be fixed. In addition, the newly added audio_token_num field is unused; please confirm whether it is needed, or remove it.


📍 Specific Issues

🔴 Bug - cudagraph_piecewise_backend.py:417

Signature mismatch causes a runtime error

According to the diff, the signature of run_static_model has been changed from def run_static_model(self, entry: ConcreteSizeEntry, is_decode: bool = False, **kwargs) to def run_static_model(self, entry: ConcreteSizeEntry, **kwargs), removing the is_decode parameter.

However, the call site here still passes is_decode=static_cudagraph_for_decode, which will raise a TypeError at runtime.

Suggested fix:

# remove the is_decode argument
return self.run_static_model(entry, **kwargs)

🟡 Suggestion - forward_meta.py:161

The newly added audio_token_num field is unused

A search of the codebase shows that the audio_token_num field appears only at its definition (line 161) and is used nowhere else. It is also unrelated to the topic of this PR (CUDA Graph + DeepEP communication).

Please confirm:

  • If the field is reserved for a follow-up feature, explain its purpose in the PR description
  • If it was added by mistake, remove it

🟡 Suggestion - utils.py:134

Log message language is inconsistent and unprofessional

The newly added Chinese log logger.info("权重没初始化啊!") ("the weight isn't initialized!") is colloquial. It should be replaced with a professional English message, consistent with the rest of the codebase.

Suggested change:

logger.warning(f"Weight '{weight_name}' is not initialized, skipping process_weight_transpose")

