Skip to content

修复 DeepEPv2 反向 MoE 场景的路由权重梯度传播问题#672

Open
WANGWEI011 wants to merge 1 commit into
deepseek-ai:mainfrom
WANGWEI011:wangwei-deepepv2-0626
Open

修复 DeepEPv2 反向 MoE 场景的路由权重梯度传播问题#672
WANGWEI011 wants to merge 1 commit into
deepseek-ai:mainfrom
WANGWEI011:wangwei-deepepv2-0626

Conversation

@WANGWEI011

Copy link
Copy Markdown

(cherry picked from commit 7a9715dffbb47a33a7aebbb337d507b17d3408ec)

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
(cherry picked from commit 7a9715dffbb47a33a7aebbb337d507b17d3408ec)

@ds-review-bot ds-review-bot left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤖 ds-review-bot Code Review

Model 1

本 MR 为 DeepEPv2 反向 MoE 场景修复路由权重梯度的传播问题。核心思路是在 dispatch 阶段让 aux_weights(每个 (t,k) 的路由梯度标量)搭乘已有的 topk_weights 通道,经同一个 dispatch 同时产出 recv_topk_weights(路由)与 recv_aux_weights(逐行标量),从而在 combine 反向时把 handle.recv_aux_weights 作为 topk_weights 输入,一次归约即得到 droute_weights_sk。同时:1) 令 cached+expand 重放不再重跑竞态相关的 atomicAdd,而是复用 forward 记录在 recv_src_metadata[:, 2+k] 的展开行映射,使重放与 forward 逐行对齐;2) 为非 expand 的 combine 反向新增 fused scatter,把接收行按 forward 记录的 NvS 目标就地 scatter 到 out_nvs,融合 Python 端的 gather/index_copy。整体设计自洽、与现有基础设施(layout 元数据偏移、num_metadata_bytes 尾部新增 K*float 区域、dispatch/combine 的 slot 映射、网格同步)衔接紧密,未发现破坏已有前向行为的改动,aux 区域仅在反向被实际使用。

主要需确认的一点是 combine push 中通过裸指针向对端/对称内存写入权重标量(而非像隐层数据那样走 TMA + fence),在对端可见性上存在平台相关的风险,建议验证。

Model 2

本 MR 修复 DeepEPv2 反向 MoE 场景下路由权重(router weight)梯度的传播问题,引入两条新通路:(1) forward dispatch 新增 aux_weights(每 (token,k) 标量),与 topk_weights 并行携带并在 expand 模式下产出 recv_aux_weights;(2) combine 支持 expand 模式下传入 1D [NvS]topk_weights(路由权重梯度),通过 expanded-send 推送并在 reduce epilogue 还原为 [S, num_topk]combined_topk_weights。同时支持 cached+expand 回放(复用 forward 记录的 expanded-row 映射避免 race-dependent 的 atomicAdd),以及 combine-backward 的 fused scatter(在 dispatch copy epilogue 直接 scatter 到 NvS 输出)。layout 在 metadata 末尾新增 K*float 的 aux 区域,集中修改且 buffer size 计算自动同步。

整体评估:实现严谨,host 端断言对支持的配置(allow_multiple_reduction=False、单 NVLink 域、expert_alignment==1)做了清晰约束,C++/Python/kernel 三侧参数传递一致,新增字段均沿用既有 num_allocated_tokens/TokenLayout 模式,未发现明确 bug。layout metadata 无条件增大 K*float 会让所有 dispatch/combine 路径的 buffer 占用略增,但已通过统一的 buffer size 计算正确传播,属预期行为。

Files reviewed: 8
Issues found: 🟡 1 warning | 🔵 1 suggestion
Inline comments posted: 2

Comment on lines +202 to +205
if (topk_weights != nullptr) {
float* w_dst = gin.get_sym_ptr<team_t>(token_buffer.get_topk_weights_ptr(), src_rank_idx);
*w_dst = __ldg(topk_weights + slot_idx);
}

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 warning: 这里通过 gin.get_sym_ptr 拿到对端对称地址后,用裸指针 *w_dst = __ldg(...) 直接把单个权重标量写入对端缓冲区;而该 k-slot 对应的隐层数据在 nvlink_bypass 分支中是走 TMA store (tma_store_commit) 写入的。TMA 写入后 combine 再经 cudaGridDependencySynchronize 让 reduce epilogue 可见,但这条权重标量的裸写既没有 threadfence,也不参与 TMA 的刷写语义,能否被对端/读端可见完全依赖网格完成 + NVLink 的缓存一致性。在 Hopper 等 NVLink 缓存一致的系统上可能没问题,但在需要显式 flush/fence 才能让对端可见的平台(或对称内存被映射为 uncached 的情形)下,reduce epilogue 可能读到陈旧值,导致 droute_weights_sk 计算错误。建议与隐层数据保持一致(例如在写后加 __threadfence,或改为与现有数据路径一致的 fence/TMA 写入)并实际验证。

// row-for-row aligned with the forward. `-1` rows (unselected k-slots) stay -1 and are
// skipped by the `dst_tensor_idx >= 0` store guards below.
if (lane_idx < kNumTopk)
dst_tensor_idx = recv_src_metadata[i * (2 + kNumTopk) + 2 + lane_idx];

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔵 suggestion: cached+expand 重放(第 131 行)与 scatter_to_nvs(第 165 行)都以 i (0..num_recv_tokens-1) 为行索引读取 recv_src_metadata/scatter_src_metadata[i * (2+kNumTopk) + 2 + lane],没有对该元数据张量的行数做任何断言/边界检查。当前调用方传入的都是 handle.recv_src_metadata(其行数等于缓存句柄的 num_recv_tokens),正常路径下恰好对齐;但一旦调用方传入的行数与当前 num_recv_tokens 不一致,kernel 内会出现静默越界读写(元数据越界读取、out_nvs 越界写入),属于典型的灾难式失败模式。建议在 C++ 侧补充形状断言(如 scatter_src_metadata 行数需等于 num_recv_tokens),以防调用侧误用。

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants