From fb73fe94d0bd0b11e97924644265797754b7f608 Mon Sep 17 00:00:00 2001 From: wangwei02 Date: Tue, 23 Jun 2026 03:51:46 +0000 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=20DeepEPv2=20=E5=8F=8D?= =?UTF-8?q?=E5=90=91=20MoE=20=E5=9C=BA=E6=99=AF=E7=9A=84=E8=B7=AF=E7=94=B1?= =?UTF-8?q?=E6=9D=83=E9=87=8D=E6=A2=AF=E5=BA=A6=E4=BC=A0=E6=92=AD=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Opus 4.8 (1M context) (cherry picked from commit 7a9715dffbb47a33a7aebbb337d507b17d3408ec) --- csrc/elastic/buffer.hpp | 100 +++++++++++++++--- csrc/kernels/elastic/dispatch.hpp | 23 +++- deep_ep/buffers/elastic.py | 46 ++++++-- deep_ep/include/deep_ep/common/layout.cuh | 12 ++- deep_ep/include/deep_ep/impls/combine.cuh | 18 +++- .../deep_ep/impls/combine_reduce_epilogue.cuh | 38 +++++-- deep_ep/include/deep_ep/impls/dispatch.cuh | 3 + .../deep_ep/impls/dispatch_copy_epilogue.cuh | 47 +++++++- 8 files changed, 242 insertions(+), 45 deletions(-) diff --git a/csrc/elastic/buffer.hpp b/csrc/elastic/buffer.hpp index f3d73a7f8..11d242560 100644 --- a/csrc/elastic/buffer.hpp +++ b/csrc/elastic/buffer.hpp @@ -666,11 +666,13 @@ class ElasticBuffer { std::vector, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional, std::optional, + std::optional, // recv_aux_weights: per-row router-weight gradient scalar in NvS layout std::optional> dispatch(const torch::Tensor& x, const std::optional& sf, const torch::Tensor& topk_idx, const std::optional& topk_weights, + const std::optional& aux_weights, // second per-(t,k) scalar carried alongside topk_weights const std::optional& cumulative_local_expert_recv_stats, const std::optional& cached_num_recv_tokens, const std::optional>& cached_num_recv_tokens_per_expert_list, @@ -679,6 +681,7 @@ class ElasticBuffer { const std::optional& cached_dst_buffer_slot_idx, const std::optional& cached_token_metadata_at_forward, const std::optional& cached_channel_linked_list, + const std::optional& cached_recv_src_metadata, // cached+expand replay reuses the recorded expanded-row map const int& num_max_tokens_per_rank, const int& num_experts, const int& expert_alignment, const int& num_sms, const int& num_qps, @@ -687,7 +690,9 @@ class ElasticBuffer { const bool& async_with_compute_stream, const bool& allocate_on_comm_stream, const bool& do_handle_copy, const bool& do_cpu_sync, const bool& do_expand, - const bool& use_tma_aligned_col_major_sf) const { + const bool& use_tma_aligned_col_major_sf, + const std::optional& scatter_to_nvs_out, + const std::optional& scatter_src_metadata) const { // Check SM count EP_HOST_ASSERT(num_sms > 0); @@ -746,6 +751,18 @@ class ElasticBuffer { topk_weights_ptr = topk_weights->data_ptr(); } + // Optional second per-(t,k) scalar (router-weight gradient), carried alongside topk_weights so a + // single dispatch yields both routing (recv_topk_weights) and this scalar (recv_aux_weights) in the + // same layout, letting combine output d_hidden and d_route_weights in one pass. + float* aux_weights_ptr = nullptr; + if (aux_weights.has_value()) { + const auto [num_tokens_a, num_topk_a] = get_shape<2>(aux_weights.value()); + EP_HOST_ASSERT(num_tokens == num_tokens_a and num_topk == num_topk_a); + EP_HOST_ASSERT(aux_weights->is_cuda() and aux_weights->is_contiguous()); + EP_HOST_ASSERT(aux_weights->scalar_type() == torch::kFloat); + aux_weights_ptr = aux_weights->data_ptr(); + } + // Expert receiving counter int* cumulative_local_expert_recv_stats_ptr = nullptr; if (cumulative_local_expert_recv_stats.has_value()) { @@ -947,6 +964,7 @@ class ElasticBuffer { EP_HOST_ASSERT(num_sms <= jit::device_runtime->get_num_sms()); launch_dispatch(x.data_ptr(), sf_ptr, topk_idx.data_ptr(), topk_weights_ptr, + aux_weights_ptr, copied_topk_idx_ptr, cumulative_local_expert_recv_stats_ptr, psum_num_recv_tokens_per_scaleup_rank.data_ptr(), @@ -977,12 +995,27 @@ class ElasticBuffer { // Assign these values according to modes if (cached_mode) { // Cached mode - // TODO: support to expand for MoE training backward with cached handles from non-expanding forward, - // which requires maintaining the same expanding order between forward and backward - EP_HOST_ASSERT(not do_expand and "Cannot do expand with cached mode"); + // Support expand replay for MoE training backward. The fresh expand dispatch that built this + // handle recorded the per-(token,k) expanded-row assignment into recv_src_metadata[:, 2+k]; the + // cached epilogue reuses that map (passed in via `cached_recv_src_metadata`) instead of + // re-running the race-dependent atomicAdd, so the replay's expanded layout is row-for-row + // identical to the forward (see dispatch_copy_epilogue.cuh). + if (do_expand) { + EP_HOST_ASSERT(cached_recv_src_metadata.has_value() and + "cached+expand replay requires cached recv_src_metadata"); + EP_HOST_ASSERT(expert_alignment == 1 and + "cached+expand replay is implemented for expert_alignment == 1"); + } EP_HOST_ASSERT(not do_cpu_sync and "Cannot do CPU sync with cached mode"); num_recv_tokens = cached_num_recv_tokens.value(); num_recv_tokens_per_expert_list = cached_num_recv_tokens_per_expert_list.value(); + if (do_expand) { + // Total expanded rows = sum of per-expert received counts (expert_alignment == 1, + // so there is no per-expert alignment padding). Matches the fresh do_cpu_sync path. + num_expanded_tokens = 0; + for (const auto c : num_recv_tokens_per_expert_list) + num_expanded_tokens += c; + } } else if (do_cpu_sync) { // Non-cached mode with sync const auto start_cpu_time = std::chrono::high_resolution_clock::now(); @@ -1046,14 +1079,20 @@ class ElasticBuffer { auto recv_sf = std::optional(); auto recv_topk_idx = std::optional(); auto recv_topk_weights = std::optional(); - auto recv_src_metadata = torch::empty( - {num_recv_tokens, num_topk + 2}, - torch::TensorOptions(torch::kCUDA).dtype(torch::kInt)); + auto recv_aux_weights = std::optional(); // per-row router-weight gradient scalar output + // In cached+expand replay, reuse the handle's recorded metadata (its [:, 2+k] columns hold the + // forward's expanded-row map, which the epilogue reads back). Otherwise allocate fresh. + auto recv_src_metadata = (cached_mode and do_expand) + ? cached_recv_src_metadata.value() + : torch::empty( + {num_recv_tokens, num_topk + 2}, + torch::TensorOptions(torch::kCUDA).dtype(torch::kInt)); // Optional tensors void* recv_sf_ptr = nullptr; topk_idx_t* recv_topk_idx_ptr = nullptr; float* recv_topk_weights_ptr = nullptr; + float* recv_aux_weights_ptr = nullptr; // per-row router-weight gradient scalar int recv_sf_token_stride = 0, recv_sf_hidden_stride = 0; if (sf.has_value()) { if (not use_tma_aligned_col_major_sf) { @@ -1077,18 +1116,38 @@ class ElasticBuffer { torch::empty({num_allocated_tokens, num_topk}, topk_weights->options()); recv_topk_weights_ptr = recv_topk_weights->data_ptr(); } + // recv_aux_weights only in expand mode (needs the per-row [NvS] scalar) + if (aux_weights.has_value() and do_expand) { + recv_aux_weights = torch::empty({num_allocated_tokens}, aux_weights->options()); + recv_aux_weights_ptr = recv_aux_weights->data_ptr(); + } // Process prefix sum, in expanding mode, it is also atomic counters if (do_expand) { - // Slice and exclusive part and do atomic additions into inclusive - EP_HOST_ASSERT(not cached_mode); - psum_num_recv_tokens_per_expert = psum_num_recv_tokens_per_expert.slice(0, 0, num_local_experts); + // For cached+expand the epilogue reuses the recorded expanded-row map (no atomic counters), + // and the cached psum is already shaped [num_local_experts]. Only the fresh expand path needs + // the exclusive prefix-sum slice used as atomic counters. + if (not cached_mode) { + // Slice the exclusive part and do atomic additions into inclusive + psum_num_recv_tokens_per_expert = psum_num_recv_tokens_per_expert.slice(0, 0, num_local_experts); + } } else if (not cached_mode) { // Slice the inclusive part (and will not be used in the epilogue) psum_num_recv_tokens_per_expert = psum_num_recv_tokens_per_expert.slice(0, 1, num_local_experts + 1); } EP_HOST_ASSERT(psum_num_recv_tokens_per_expert.size(0) == num_local_experts); + // Fused combine-backward scatter. When `scatter_to_nvs_out` is given (only meaningful for the + // non-expand cached backward), the copy epilogue additionally scatters each received row to + // out_nvs[scatter_src_metadata[i, 2+k]] -- the NvS destinations recorded by the forward expand + // dispatch -- fusing the Python-side gather/index_copy into this dispatch kernel. `scatter_src_metadata` + // must be the forward handle's recv_src_metadata ([num_recv, num_topk+2]). + const bool do_scatter_to_nvs = scatter_to_nvs_out.has_value() and not do_expand; + void* out_nvs_ptr = scatter_to_nvs_out.has_value() ? scatter_to_nvs_out->data_ptr() : nullptr; + const int* scatter_src_metadata_ptr = + scatter_src_metadata.has_value() ? scatter_src_metadata->data_ptr() : nullptr; + EP_HOST_ASSERT(not do_scatter_to_nvs or scatter_src_metadata_ptr != nullptr); + // Launch copy kernels with full SMs stream_control_before_epilogue(previous_event_before_epilogue); launch_dispatch_copy_epilogue(buffer, workspace, @@ -1096,6 +1155,7 @@ class ElasticBuffer { psum_num_recv_tokens_per_expert.data_ptr(), recv_x.data_ptr(), recv_sf_ptr, recv_topk_idx_ptr, recv_topk_weights_ptr, + recv_aux_weights_ptr, recv_src_metadata.data_ptr(), channel_linked_list_ptr, num_recv_tokens, num_max_tokens_per_rank, @@ -1108,6 +1168,7 @@ class ElasticBuffer { jit::device_runtime->get_num_smem_bytes(), num_channels, do_expand, cached_mode, + do_scatter_to_nvs, scatter_src_metadata_ptr, out_nvs_ptr, comm_stream); // Stream control @@ -1136,6 +1197,7 @@ class ElasticBuffer { dst_buffer_slot_idx, token_metadata_at_forward, channel_linked_list, + recv_aux_weights, // per-row router-weight gradient scalar in NvS event}; } @@ -1186,8 +1248,22 @@ class ElasticBuffer { // Check optional tensors if (use_expanded_layout) { - // Reduction should be done with SwiGLU - EP_HOST_ASSERT(not topk_weights.has_value()); + // Optionally carry the per-expanded-row router weight gradient (droute_weights_nvs) so combine + // can reduce it back to droute_weights_sk [S,K] in one pass. The weight is a 1D [NvS] float + // tensor aligned row-for-row with `x`. Each k-slot's scalar must land in its own output column, + // so we require expanded-send (allow_multiple_reduction == false) — local reduction would merge + // distinct k-slots' weights — and a single NVLink domain (the only path wired to deliver + // per-row weights to the source rank). + if (topk_weights.has_value()) { + const auto [num_w] = get_shape<1>(topk_weights.value()); + EP_HOST_ASSERT(num_w == num_tokens); + EP_HOST_ASSERT(topk_weights->is_cuda() and topk_weights->is_contiguous()); + EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat); + EP_HOST_ASSERT(not allow_multiple_reduction and + "expand-combine topk_weights requires allow_multiple_reduction=False"); + EP_HOST_ASSERT(nccl_context->num_scaleout_ranks == 1 and + "expand-combine topk_weights is implemented for a single NVLink domain only"); + } } else if (topk_weights.has_value()) { const auto [num_tokens__, num_topk__] = get_shape<2>(topk_weights.value()); EP_HOST_ASSERT(num_tokens == num_tokens__ and num_topk == num_topk__); diff --git a/csrc/kernels/elastic/dispatch.hpp b/csrc/kernels/elastic/dispatch.hpp index ddac5db8d..20823c0ad 100644 --- a/csrc/kernels/elastic/dispatch.hpp +++ b/csrc/kernels/elastic/dispatch.hpp @@ -106,6 +106,7 @@ class DispatchRuntime final : public jit::LaunchRuntime { // Parameters void* x; sf_pack_t* sf; topk_idx_t* topk_idx; float* topk_weights; + float* aux_weights; // second per-(t,k) scalar (router-weight gradient) topk_idx_t* copied_topk_idx; int* cumulative_local_expert_recv_stats; int* psum_num_recv_tokens_per_scaleup_rank; @@ -168,6 +169,7 @@ static void __instantiate_kernel() {{ EP_CUDA_UNIFIED_CHECK(jit::launch_kernel( kernel, config, args.x, args.sf, args.topk_idx, args.topk_weights, + args.aux_weights, args.copied_topk_idx, args.cumulative_local_expert_recv_stats, args.psum_num_recv_tokens_per_scaleup_rank, @@ -213,6 +215,7 @@ static layout::TokenLayout get_dispatch_token_layout( static void launch_dispatch(void* x, void* sf, topk_idx_t* topk_idx, float* topk_weights, + float* aux_weights, topk_idx_t* copied_topk_idx, int* cumulative_local_expert_recv_stats, int* psum_num_recv_tokens_per_scaleup_rank, @@ -287,6 +290,7 @@ static void launch_dispatch(void* x, void* sf, .num_experts = num_experts, .num_topk = num_topk, .expert_alignment = expert_alignment, .num_qps = num_qps, .num_timeout_cycles = num_timeout_cycles, .x = x, .sf = static_cast(sf), .topk_idx = topk_idx, .topk_weights = topk_weights, + .aux_weights = aux_weights, .copied_topk_idx = copied_topk_idx, .cumulative_local_expert_recv_stats = cumulative_local_expert_recv_stats, .psum_num_recv_tokens_per_scaleup_rank = psum_num_recv_tokens_per_scaleup_rank, @@ -310,7 +314,7 @@ class DispatchCopyEpilogueRuntime final : public jit::LaunchRuntime(&dispatch_copy_epilogue_impl<{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}>); + auto ptr = reinterpret_cast(&dispatch_copy_epilogue_impl<{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}>); }} )", - args.do_expand, args.cached_mode, + args.do_expand, args.cached_mode, args.do_scatter_to_nvs, args.launch_args.grid_dim.first, args.num_channels, args.num_warps, args.num_scaleout_ranks, args.num_scaleup_ranks, args.num_hidden_bytes, args.num_sf_packs, @@ -357,11 +363,13 @@ static void __instantiate_kernel() {{ args.psum_num_recv_tokens_per_scaleup_rank, args.psum_num_recv_tokens_per_expert, args.recv_x, args.recv_sf, args.recv_topk_idx, args.recv_topk_weights, + args.recv_aux_weights, args.recv_src_metadata, args.channel_linked_list, args.num_recv_tokens, args.recv_sf_token_stride, args.recv_sf_hidden_stride, - args.scaleout_rank_idx, args.scaleup_rank_idx)); + args.scaleout_rank_idx, args.scaleup_rank_idx, + args.scatter_src_metadata, args.out_nvs)); } }; @@ -370,6 +378,7 @@ static void launch_dispatch_copy_epilogue(void* buffer, void* workspace, int* psum_num_recv_tokens_per_expert, void* recv_x, void* recv_sf, topk_idx_t* recv_topk_idx, float* recv_topk_weights, + float* recv_aux_weights, int* recv_src_metadata, int* channel_linked_list, const int& num_recv_tokens, const int& num_max_tokens_per_rank, @@ -381,6 +390,8 @@ static void launch_dispatch_copy_epilogue(void* buffer, void* workspace, const int& num_sms, const int& num_smem_bytes, const int& num_channels, const bool& do_expand, const bool& cached_mode, + const bool& do_scatter_to_nvs, + const int* scatter_src_metadata, void* out_nvs, const at::cuda::CUDAStream& stream) { // Maximize shared memory utilization const auto token_layout = layout::TokenLayout(num_hidden_bytes, num_sf_packs * sizeof(sf_pack_t), num_topk, true); @@ -389,7 +400,7 @@ static void launch_dispatch_copy_epilogue(void* buffer, void* workspace, // Generate, build and launch const DispatchCopyEpilogueRuntime::Args args = { - .do_expand = do_expand, .cached_mode = cached_mode, + .do_expand = do_expand, .cached_mode = cached_mode, .do_scatter_to_nvs = do_scatter_to_nvs, .num_channels = num_channels, .num_warps = num_warps, .num_scaleout_ranks = num_scaleout_ranks, .num_scaleup_ranks = num_scaleup_ranks, .num_hidden_bytes = num_hidden_bytes, .num_sf_packs = num_sf_packs, @@ -400,11 +411,13 @@ static void launch_dispatch_copy_epilogue(void* buffer, void* workspace, .psum_num_recv_tokens_per_expert = psum_num_recv_tokens_per_expert, .recv_x = recv_x, .recv_sf = recv_sf, .recv_topk_idx = recv_topk_idx, .recv_topk_weights = recv_topk_weights, + .recv_aux_weights = recv_aux_weights, .recv_src_metadata = recv_src_metadata, .channel_linked_list = channel_linked_list, .num_recv_tokens = num_recv_tokens, .recv_sf_token_stride = recv_sf_token_stride, .recv_sf_hidden_stride = recv_sf_hidden_stride, .scaleout_rank_idx = scaleout_rank_idx, .scaleup_rank_idx = scaleup_rank_idx, + .scatter_src_metadata = scatter_src_metadata, .out_nvs = out_nvs, .launch_args = jit::LaunchArgs(num_sms, num_threads, num_smem_bytes, 1, false, true)}; const auto code = DispatchCopyEpilogueRuntime::generate(args); const auto runtime = jit::compiler->build("dispatch_copy_epilogue", code); diff --git a/deep_ep/buffers/elastic.py b/deep_ep/buffers/elastic.py index 19414e9ad..ebbbb3729 100644 --- a/deep_ep/buffers/elastic.py +++ b/deep_ep/buffers/elastic.py @@ -65,7 +65,8 @@ def __init__(self, recv_src_metadata: torch.Tensor, dst_buffer_slot_idx: torch.Tensor, token_metadata_at_forward: Optional[torch.Tensor], - channel_linked_list: Optional[torch.Tensor]): + channel_linked_list: Optional[torch.Tensor], + recv_aux_weights: Optional[torch.Tensor] = None): # per-row router-weight gradient scalar in NvS # NOTES: remember to copy the original users' input to prevent uncasual modifications on them assert topk_idx is not None @@ -82,6 +83,7 @@ def __init__(self, self.dst_buffer_slot_idx = dst_buffer_slot_idx self.token_metadata_at_forward = token_metadata_at_forward self.channel_linked_list = channel_linked_list + self.recv_aux_weights = recv_aux_weights # pass as combine(topk_weights=handle.recv_aux_weights) to reduce d_route # Inferred value, may not accurate without CPU sync self.num_recv_tokens = recv_src_metadata.shape[0] @@ -379,16 +381,18 @@ def barrier(self, use_comm_stream: bool = True, with_cpu_sync: bool = False) -> def _unpack_handle(handle: Optional[EPHandle] = None) \ -> Tuple[Optional[int], Optional[list], Optional[torch.Tensor], Optional[torch.Tensor], - Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], + Optional[torch.Tensor]]: if handle is None: - return None, None, None, None, None, None, None + return None, None, None, None, None, None, None, None return (handle.num_recv_tokens, handle.num_recv_tokens_per_expert_list, handle.psum_num_recv_tokens_per_scaleup_rank, handle.psum_num_recv_tokens_per_expert, handle.dst_buffer_slot_idx, handle.token_metadata_at_forward, - handle.channel_linked_list) + handle.channel_linked_list, + handle.recv_src_metadata) # reused by cached+expand replay to reproduce the expanded-row layout @staticmethod def capture() -> EventHandle: @@ -709,6 +713,7 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None, + aux_weights: Optional[torch.Tensor] = None, # second per-(t,k) router-weight gradient scalar cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None, num_experts: Optional[int] = None, num_max_tokens_per_rank: Optional[int] = None, @@ -722,7 +727,8 @@ def dispatch(self, do_handle_copy: bool = True, do_cpu_sync: Optional[bool] = None, do_expand: bool = False, - use_tma_aligned_col_major_sf: bool = False) \ + use_tma_aligned_col_major_sf: bool = False, + scatter_to_nvs_out: Optional[torch.Tensor] = None) \ -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], Optional[torch.Tensor], Optional[torch.Tensor], EPHandle, EventOverlap]: @@ -799,13 +805,23 @@ def dispatch(self, cached_psum_num_recv_tokens_per_scaleup_rank, cached_psum_num_recv_tokens_per_expert, cached_dst_buffer_slot_idx, cached_token_metadata_at_forward, - cached_channel_linked_list) = self._unpack_handle(handle) + cached_channel_linked_list, + cached_recv_src_metadata) = self._unpack_handle(handle) # last item used by cached+expand replay # Some default values num_max_tokens_per_rank = value_or(num_max_tokens_per_rank, self.num_max_tokens_per_rank) expert_alignment = value_or(expert_alignment, 1) do_cpu_sync = value_or(do_cpu_sync, True) + # Fused combine-backward scatter: when an output buffer is given, the copy epilogue scatters each + # received row directly into it, using the forward handle's recv_src_metadata ([:, 2:] holds the + # NvS destinations recorded by the forward expand dispatch). Only valid for the non-expand cached + # backward; this fuses the Python-side gather + index_copy into the dispatch kernel. + if scatter_to_nvs_out is not None: + assert handle is not None, 'scatter_to_nvs_out requires a cached forward handle' + assert not do_expand, 'scatter_to_nvs_out is only valid for non-expand (combine backward)' + scatter_src_metadata = handle.recv_src_metadata if scatter_to_nvs_out is not None else None + # Do dispatch (recv_x, recv_sf, recv_topk_idx, recv_topk_weights, @@ -817,7 +833,9 @@ def dispatch(self, dst_buffer_slot_idx, token_metadata_at_forward, channel_linked_list, + recv_aux_weights, # per-row router-weight gradient scalar in NvS event) = self.runtime.dispatch(x, sf, topk_idx, topk_weights, + aux_weights, # matches the C++ dispatch param (after topk_weights) cumulative_local_expert_recv_stats, cached_num_recv_tokens, cached_num_recv_tokens_per_expert_list, @@ -826,6 +844,7 @@ def dispatch(self, cached_dst_buffer_slot_idx, cached_token_metadata_at_forward, cached_channel_linked_list, + cached_recv_src_metadata, # matches the C++ dispatch param num_max_tokens_per_rank, num_experts, expert_alignment, num_sms, num_qps, @@ -833,7 +852,8 @@ def dispatch(self, previous_event_before_epilogue, async_with_compute_stream, allocate_on_comm_stream, do_handle_copy, do_cpu_sync, do_expand, - use_tma_aligned_col_major_sf) + use_tma_aligned_col_major_sf, + scatter_to_nvs_out, scatter_src_metadata) if handle is None: handle = EPHandle(do_expand, num_experts, expert_alignment, @@ -846,7 +866,8 @@ def dispatch(self, recv_src_metadata, dst_buffer_slot_idx, token_metadata_at_forward, - channel_linked_list) + channel_linked_list, + recv_aux_weights) # per-row router-weight gradient scalar in NvS # Repack SF recv_x = (recv_x, recv_sf) if recv_sf is not None else recv_x @@ -883,8 +904,13 @@ def combine(self, Arguments: x: `[num_tokens, hidden]` with `torch.bfloat16`, the tokens to send for reducing to its original ranks. handle: a must-set communication handle, you can obtain this from the `dispatch` function. - topk_weights: `[num_tokens, num_topk]` with `torch.float`, the tokens' top-k weights for reducing to - its original ranks. Not used in expand mode. + topk_weights: in non-expand mode, `[num_tokens, num_topk]` with `torch.float`, the tokens' + top-k weights for reducing to its original ranks. + In expand mode it may instead be a 1D `[num_tokens]` (= `[NvS]`) `torch.float` tensor — + the per-expanded-row router weight gradient (droute_weights_nvs) — and combine will reduce + it back to `combined_topk_weights` `[S, num_topk]` (droute_weights_sk). + Requires the buffer built with `allow_multiple_reduction=False` and a single NVLink + domain. Leave as `None` for a plain hidden-only combine. bias: 0, 1 or 2 `[num_combined_tokens, hidden]` with `torch.bfloat16` final bias to the output. num_sms: the number of SMs to use (0 to reuse the SM count from the dispatch handle). num_qps: the number of RDMA QPs to use (0 for automatic via `get_theoretical_num_qps`). diff --git a/deep_ep/include/deep_ep/common/layout.cuh b/deep_ep/include/deep_ep/common/layout.cuh index 0e7a61e83..ac89599fc 100644 --- a/deep_ep/include/deep_ep/common/layout.cuh +++ b/deep_ep/include/deep_ep/common/layout.cuh @@ -192,7 +192,8 @@ struct TokenLayout { with_metadata(with_metadata), num_topk(num_topk), num_metadata_bytes(num_topk * (sizeof(int) + sizeof(float)) + - (with_metadata ? (1 + num_topk) * sizeof(int) : 0)), + (with_metadata ? (1 + num_topk) * sizeof(int) : 0) + + num_topk * sizeof(float)), // trailing aux-weight region (K floats) for the second scalar base(base) { EP_STATIC_ASSERT(sizeof(int) == sizeof(float), "Invalid size assumption"); EP_UNIFIED_ASSERT(num_hidden_bytes % ptx::kNumTMAAlignBytes == 0); @@ -243,6 +244,15 @@ struct TokenLayout { return get_src_token_global_idx_ptr() + 1; } + // Second per-(t,k) scalar (router-weight gradient) region at the end of the metadata (after + // topk_weights/src/linked_list), so it changes no existing offset; num_metadata_bytes is +K*float + // accordingly. Lets dispatch carry routing and this scalar together. + __forceinline__ __device__ __host__ float* get_aux_weights_ptr() const { + const int off = num_topk * static_cast(sizeof(int) + sizeof(float)) + + (with_metadata ? (1 + num_topk) * static_cast(sizeof(int)) : 0); + return math::advance_ptr(get_metadata_ptr(), off); + } + __forceinline__ __device__ ptx::mbarrier* get_mbarrier_ptr() const { return math::advance_ptr(get_metadata_ptr(), math::align(num_metadata_bytes, ptx::kNumTMAAlignBytes)); } diff --git a/deep_ep/include/deep_ep/impls/combine.cuh b/deep_ep/include/deep_ep/impls/combine.cuh index 79f3931c7..4330a74db 100644 --- a/deep_ep/include/deep_ep/impls/combine.cuh +++ b/deep_ep/include/deep_ep/impls/combine.cuh @@ -64,9 +64,9 @@ combine_impl(nv_bfloat16* x, ptx::mbarrier_init_with_fence(mbarrier_ptr, 1); __syncwarp(); - // Expanding mode must not be backward - if constexpr (kUseExpandedLayout) - EP_DEVICE_ASSERT(topk_weights == nullptr); + // Expand mode may carry `topk_weights` ([NvS] router weight gradient), delivered per k-slot in the + // expanded-send branch below (the only path taken when kUseExpandedLayout and not + // kAllowMultipleReduction). // Gin handle // We treat each warp as a "channel" @@ -194,7 +194,19 @@ combine_impl(nv_bfloat16* x, ptx::tma_store_1d(gin.get_sym_ptr(token_buffer.get_base_ptr(), src_rank_idx), tma_buffer.get_base_ptr(), kNumHiddenBytes); ptx::tma_store_commit(); + + // Deliver this k-slot's router weight gradient scalar into the same remote + // buffer slot's weight region (position 0). The reduce epilogue reads it back + // and scatters to combined_topk_weights[src_token, k] — a direct store through + // the NVLink symmetric pointer, covered by the final combine barrier. + if (topk_weights != nullptr) { + float* w_dst = gin.get_sym_ptr(token_buffer.get_topk_weights_ptr(), src_rank_idx); + *w_dst = __ldg(topk_weights + slot_idx); + } } else { + // NOTE: the RDMA path does not carry the per-row weight. expand-combine with + // topk_weights is host-asserted to a single NVLink domain (num_scaleout_ranks==1), + // so nvlink_bypass is always true here for the supported config. // Write to the RDMA send buffer const auto send_token_buffer = send_buffer.get_rank_buffer(src_rank_idx).get_token_buffer(src_token_idx * kNumTopk + k); diff --git a/deep_ep/include/deep_ep/impls/combine_reduce_epilogue.cuh b/deep_ep/include/deep_ep/impls/combine_reduce_epilogue.cuh index b57415011..266bbb2f6 100644 --- a/deep_ep/include/deep_ep/impls/combine_reduce_epilogue.cuh +++ b/deep_ep/include/deep_ep/impls/combine_reduce_epilogue.cuh @@ -126,18 +126,36 @@ combine_reduce_epilogue_impl(nv_bfloat16* combined_x, // Write top-k weights if (combined_topk_weights != nullptr) { - const auto master_lane_idx = ptx::get_master_lane_idx(ptx::match(stored_dst_rank_idx)); - if (lane_idx < kNumTopk) { - float value = 0; - if (stored_dst_rank_idx >= 0) { - const auto dst_ptr = comm_buffer - .get_rank_buffer(kUseRankLayout ? stored_dst_rank_idx : master_lane_idx) - .get_token_buffer(token_idx).get_topk_weights_ptr() + lane_idx; - value = *dst_ptr; + if constexpr (kUseExpandedLayout) { + // In expand mode each selected k-slot's router weight gradient was delivered by the combine + // push (expanded-send) into that k-slot's own buffer slot at weight position 0. `lane_idx` + // is the k-slot; scatter the scalar straight to column lane_idx. Unselected slots + // (stored_dst_rank_idx < 0) contribute 0. This requires !kAllowMultipleReduction (so + // kUseRankLayout==false and slot index == k-slot, matching the push), host-asserted for + // expand-combine with weights. + if (lane_idx < kNumTopk) { + float value = 0; + if (stored_dst_rank_idx >= 0) { + value = *(comm_buffer.get_rank_buffer(lane_idx) + .get_token_buffer(token_idx).get_topk_weights_ptr()); + } + combined_topk_weights[token_idx * kNumTopk + lane_idx] = value; } - combined_topk_weights[token_idx * kNumTopk + lane_idx] = value; + __syncwarp(); + } else { + const auto master_lane_idx = ptx::get_master_lane_idx(ptx::match(stored_dst_rank_idx)); + if (lane_idx < kNumTopk) { + float value = 0; + if (stored_dst_rank_idx >= 0) { + const auto dst_ptr = comm_buffer + .get_rank_buffer(kUseRankLayout ? stored_dst_rank_idx : master_lane_idx) + .get_token_buffer(token_idx).get_topk_weights_ptr() + lane_idx; + value = *dst_ptr; + } + combined_topk_weights[token_idx * kNumTopk + lane_idx] = value; + } + __syncwarp(); } - __syncwarp(); } } } diff --git a/deep_ep/include/deep_ep/impls/dispatch.cuh b/deep_ep/include/deep_ep/impls/dispatch.cuh index cf21a18b0..2c1e23a5d 100644 --- a/deep_ep/include/deep_ep/impls/dispatch.cuh +++ b/deep_ep/include/deep_ep/impls/dispatch.cuh @@ -31,6 +31,7 @@ template (blockIdx.x), thread_idx = static_cast(threadIdx.x); const auto warp_idx = ptx::get_warp_idx(), lane_idx = ptx::get_lane_idx(); @@ -110,8 +115,20 @@ dispatch_copy_epilogue_impl(void* buffer, void* workspace, // Calculate target indices in the tensor int dst_tensor_idx = -1; - if (not kDoExpand and ptx::elect_one_sync()) { - dst_tensor_idx = i; + if constexpr (not kDoExpand) { + if (ptx::elect_one_sync()) + dst_tensor_idx = i; + } else if constexpr (kCachedMode) { + // Cached+expand replay. The fresh expand dispatch that built this handle recorded, per + // (recv-token i, k-slot lane), the atomically-assigned expanded row into + // recv_src_metadata[:, 2+lane] (see the write near the end of this loop). The expanded row + // order within an expert is otherwise race-dependent and not reproducible across dispatch + // calls, so a cached replay must reuse that recorded map (passed in via the handle's + // recv_src_metadata) instead of re-running the racy atomicAdd, guaranteeing the replay is + // row-for-row aligned with the forward. `-1` rows (unselected k-slots) stay -1 and are + // skipped by the `dst_tensor_idx >= 0` store guards below. + if (lane_idx < kNumTopk) + dst_tensor_idx = recv_src_metadata[i * (2 + kNumTopk) + 2 + lane_idx]; } else if (kDoExpand and dst_expert_idx >= 0) { dst_tensor_idx = atomicAdd(psum_num_recv_tokens_per_expert + dst_expert_idx, 1); } @@ -137,6 +154,24 @@ dispatch_copy_epilogue_impl(void* buffer, void* workspace, } __syncwarp(); + // Fused combine-backward scatter (non-expand path). Each valid lane k additionally writes the same + // received hidden row to out_nvs[scatter_src_metadata[i, 2+k]] -- the NvS destinations recorded by + // the forward expand dispatch (deterministic, no atomicAdd). This is bit-identical to the Python + // gather path out[meta[i,2+k]] = d_base[i]: same bytes, same destinations, no reduction. Mirrors the + // expand store above (per-lane TMA store from the same shared source to a per-lane destination). + if constexpr (kDoScatterToNvS) { + constexpr int kScatterMetadataStride = 2 + kNumTopk; + if (lane_idx < kNumTopk) { + const int nvs_row = scatter_src_metadata[i * kScatterMetadataStride + 2 + lane_idx]; + if (nvs_row >= 0) { + ptx::tma_store_1d(math::advance_ptr(out_nvs, static_cast(nvs_row) * kNumHiddenBytes), + tma_buffer.get_hidden_ptr(), kNumHiddenBytes); + ptx::tma_store_commit(); + } + } + __syncwarp(); + } + // Store SF if constexpr (kNumSFPacks > 0) { constexpr auto kNumFullIters = kNumSFPacks / 32; @@ -178,6 +213,10 @@ dispatch_copy_epilogue_impl(void* buffer, void* workspace, // For backward, weights are optional recv_topk_weights[i * kNumTopk + lane_idx] = tma_buffer.get_topk_weights_ptr()[lane_idx]; } + // Aux scalar in parallel with recv_topk_weights, expand mode only (the per-row [NvS] scalar) + if (kDoExpand and recv_aux_weights != nullptr and dst_tensor_idx >= 0) { + recv_aux_weights[dst_tensor_idx] = tma_buffer.get_aux_weights_ptr()[lane_idx]; + } __syncwarp(); // Write source token index