修复 DeepEPv2 反向 MoE 场景的路由权重梯度传播问题#672
Conversation
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> (cherry picked from commit 7a9715dffbb47a33a7aebbb337d507b17d3408ec)
ds-review-bot
left a comment
There was a problem hiding this comment.
🤖 ds-review-bot Code Review
Model 1
本 MR 为 DeepEPv2 反向 MoE 场景修复路由权重梯度的传播问题。核心思路是在 dispatch 阶段让 aux_weights(每个 (t,k) 的路由梯度标量)搭乘已有的 topk_weights 通道,经同一个 dispatch 同时产出 recv_topk_weights(路由)与 recv_aux_weights(逐行标量),从而在 combine 反向时把 handle.recv_aux_weights 作为 topk_weights 输入,一次归约即得到 droute_weights_sk。同时:1) 令 cached+expand 重放不再重跑竞态相关的 atomicAdd,而是复用 forward 记录在 recv_src_metadata[:, 2+k] 的展开行映射,使重放与 forward 逐行对齐;2) 为非 expand 的 combine 反向新增 fused scatter,把接收行按 forward 记录的 NvS 目标就地 scatter 到 out_nvs,融合 Python 端的 gather/index_copy。整体设计自洽、与现有基础设施(layout 元数据偏移、num_metadata_bytes 尾部新增 K*float 区域、dispatch/combine 的 slot 映射、网格同步)衔接紧密,未发现破坏已有前向行为的改动,aux 区域仅在反向被实际使用。
主要需确认的一点是 combine push 中通过裸指针向对端/对称内存写入权重标量(而非像隐层数据那样走 TMA + fence),在对端可见性上存在平台相关的风险,建议验证。
Model 2
本 MR 修复 DeepEPv2 反向 MoE 场景下路由权重(router weight)梯度的传播问题,引入两条新通路:(1) forward dispatch 新增 aux_weights(每 (token,k) 标量),与 topk_weights 并行携带并在 expand 模式下产出 recv_aux_weights;(2) combine 支持 expand 模式下传入 1D [NvS] 的 topk_weights(路由权重梯度),通过 expanded-send 推送并在 reduce epilogue 还原为 [S, num_topk] 的 combined_topk_weights。同时支持 cached+expand 回放(复用 forward 记录的 expanded-row 映射避免 race-dependent 的 atomicAdd),以及 combine-backward 的 fused scatter(在 dispatch copy epilogue 直接 scatter 到 NvS 输出)。layout 在 metadata 末尾新增 K*float 的 aux 区域,集中修改且 buffer size 计算自动同步。
整体评估:实现严谨,host 端断言对支持的配置(allow_multiple_reduction=False、单 NVLink 域、expert_alignment==1)做了清晰约束,C++/Python/kernel 三侧参数传递一致,新增字段均沿用既有 num_allocated_tokens/TokenLayout 模式,未发现明确 bug。layout metadata 无条件增大 K*float 会让所有 dispatch/combine 路径的 buffer 占用略增,但已通过统一的 buffer size 计算正确传播,属预期行为。
Files reviewed: 8
Issues found: 🟡 1 warning | 🔵 1 suggestion
Inline comments posted: 2
| if (topk_weights != nullptr) { | ||
| float* w_dst = gin.get_sym_ptr<team_t>(token_buffer.get_topk_weights_ptr(), src_rank_idx); | ||
| *w_dst = __ldg(topk_weights + slot_idx); | ||
| } |
There was a problem hiding this comment.
🟡 warning: 这里通过 gin.get_sym_ptr 拿到对端对称地址后,用裸指针 *w_dst = __ldg(...) 直接把单个权重标量写入对端缓冲区;而该 k-slot 对应的隐层数据在 nvlink_bypass 分支中是走 TMA store (tma_store_commit) 写入的。TMA 写入后 combine 再经 cudaGridDependencySynchronize 让 reduce epilogue 可见,但这条权重标量的裸写既没有 threadfence,也不参与 TMA 的刷写语义,能否被对端/读端可见完全依赖网格完成 + NVLink 的缓存一致性。在 Hopper 等 NVLink 缓存一致的系统上可能没问题,但在需要显式 flush/fence 才能让对端可见的平台(或对称内存被映射为 uncached 的情形)下,reduce epilogue 可能读到陈旧值,导致 droute_weights_sk 计算错误。建议与隐层数据保持一致(例如在写后加 __threadfence,或改为与现有数据路径一致的 fence/TMA 写入)并实际验证。
| // row-for-row aligned with the forward. `-1` rows (unselected k-slots) stay -1 and are | ||
| // skipped by the `dst_tensor_idx >= 0` store guards below. | ||
| if (lane_idx < kNumTopk) | ||
| dst_tensor_idx = recv_src_metadata[i * (2 + kNumTopk) + 2 + lane_idx]; |
There was a problem hiding this comment.
🔵 suggestion: cached+expand 重放(第 131 行)与 scatter_to_nvs(第 165 行)都以 i (0..num_recv_tokens-1) 为行索引读取 recv_src_metadata/scatter_src_metadata[i * (2+kNumTopk) + 2 + lane],没有对该元数据张量的行数做任何断言/边界检查。当前调用方传入的都是 handle.recv_src_metadata(其行数等于缓存句柄的 num_recv_tokens),正常路径下恰好对齐;但一旦调用方传入的行数与当前 num_recv_tokens 不一致,kernel 内会出现静默越界读写(元数据越界读取、out_nvs 越界写入),属于典型的灾难式失败模式。建议在 C++ 侧补充形状断言(如 scatter_src_metadata 行数需等于 num_recv_tokens),以防调用侧误用。
(cherry picked from commit 7a9715dffbb47a33a7aebbb337d507b17d3408ec)