Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 88 additions & 12 deletions csrc/elastic/buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -666,11 +666,13 @@ class ElasticBuffer {
std::vector<int>,
torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
std::optional<torch::Tensor>, std::optional<torch::Tensor>,
std::optional<torch::Tensor>, // recv_aux_weights: per-row router-weight gradient scalar in NvS layout
std::optional<EventHandle>>
dispatch(const torch::Tensor& x,
const std::optional<torch::Tensor>& sf,
const torch::Tensor& topk_idx,
const std::optional<torch::Tensor>& topk_weights,
const std::optional<torch::Tensor>& aux_weights, // second per-(t,k) scalar carried alongside topk_weights
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
const std::optional<int>& cached_num_recv_tokens,
const std::optional<std::vector<int>>& cached_num_recv_tokens_per_expert_list,
Expand All @@ -679,6 +681,7 @@ class ElasticBuffer {
const std::optional<torch::Tensor>& cached_dst_buffer_slot_idx,
const std::optional<torch::Tensor>& cached_token_metadata_at_forward,
const std::optional<torch::Tensor>& cached_channel_linked_list,
const std::optional<torch::Tensor>& cached_recv_src_metadata, // cached+expand replay reuses the recorded expanded-row map
const int& num_max_tokens_per_rank,
const int& num_experts, const int& expert_alignment,
const int& num_sms, const int& num_qps,
Expand All @@ -687,7 +690,9 @@ class ElasticBuffer {
const bool& async_with_compute_stream,
const bool& allocate_on_comm_stream,
const bool& do_handle_copy, const bool& do_cpu_sync, const bool& do_expand,
const bool& use_tma_aligned_col_major_sf) const {
const bool& use_tma_aligned_col_major_sf,
const std::optional<torch::Tensor>& scatter_to_nvs_out,
const std::optional<torch::Tensor>& scatter_src_metadata) const {
// Check SM count
EP_HOST_ASSERT(num_sms > 0);

Expand Down Expand Up @@ -746,6 +751,18 @@ class ElasticBuffer {
topk_weights_ptr = topk_weights->data_ptr<float>();
}

// Optional second per-(t,k) scalar (router-weight gradient), carried alongside topk_weights so a
// single dispatch yields both routing (recv_topk_weights) and this scalar (recv_aux_weights) in the
// same layout, letting combine output d_hidden and d_route_weights in one pass.
float* aux_weights_ptr = nullptr;
if (aux_weights.has_value()) {
const auto [num_tokens_a, num_topk_a] = get_shape<2>(aux_weights.value());
EP_HOST_ASSERT(num_tokens == num_tokens_a and num_topk == num_topk_a);
EP_HOST_ASSERT(aux_weights->is_cuda() and aux_weights->is_contiguous());
EP_HOST_ASSERT(aux_weights->scalar_type() == torch::kFloat);
aux_weights_ptr = aux_weights->data_ptr<float>();
}

// Expert receiving counter
int* cumulative_local_expert_recv_stats_ptr = nullptr;
if (cumulative_local_expert_recv_stats.has_value()) {
Expand Down Expand Up @@ -947,6 +964,7 @@ class ElasticBuffer {
EP_HOST_ASSERT(num_sms <= jit::device_runtime->get_num_sms());
launch_dispatch(x.data_ptr(), sf_ptr,
topk_idx.data_ptr<topk_idx_t>(), topk_weights_ptr,
aux_weights_ptr,
copied_topk_idx_ptr,
cumulative_local_expert_recv_stats_ptr,
psum_num_recv_tokens_per_scaleup_rank.data_ptr<int>(),
Expand Down Expand Up @@ -977,12 +995,27 @@ class ElasticBuffer {
// Assign these values according to modes
if (cached_mode) {
// Cached mode
// TODO: support to expand for MoE training backward with cached handles from non-expanding forward,
// which requires maintaining the same expanding order between forward and backward
EP_HOST_ASSERT(not do_expand and "Cannot do expand with cached mode");
// Support expand replay for MoE training backward. The fresh expand dispatch that built this
// handle recorded the per-(token,k) expanded-row assignment into recv_src_metadata[:, 2+k]; the
// cached epilogue reuses that map (passed in via `cached_recv_src_metadata`) instead of
// re-running the race-dependent atomicAdd, so the replay's expanded layout is row-for-row
// identical to the forward (see dispatch_copy_epilogue.cuh).
if (do_expand) {
EP_HOST_ASSERT(cached_recv_src_metadata.has_value() and
"cached+expand replay requires cached recv_src_metadata");
EP_HOST_ASSERT(expert_alignment == 1 and
"cached+expand replay is implemented for expert_alignment == 1");
}
EP_HOST_ASSERT(not do_cpu_sync and "Cannot do CPU sync with cached mode");
num_recv_tokens = cached_num_recv_tokens.value();
num_recv_tokens_per_expert_list = cached_num_recv_tokens_per_expert_list.value();
if (do_expand) {
// Total expanded rows = sum of per-expert received counts (expert_alignment == 1,
// so there is no per-expert alignment padding). Matches the fresh do_cpu_sync path.
num_expanded_tokens = 0;
for (const auto c : num_recv_tokens_per_expert_list)
num_expanded_tokens += c;
}
} else if (do_cpu_sync) {
// Non-cached mode with sync
const auto start_cpu_time = std::chrono::high_resolution_clock::now();
Expand Down Expand Up @@ -1046,14 +1079,20 @@ class ElasticBuffer {
auto recv_sf = std::optional<torch::Tensor>();
auto recv_topk_idx = std::optional<torch::Tensor>();
auto recv_topk_weights = std::optional<torch::Tensor>();
auto recv_src_metadata = torch::empty(
{num_recv_tokens, num_topk + 2},
torch::TensorOptions(torch::kCUDA).dtype(torch::kInt));
auto recv_aux_weights = std::optional<torch::Tensor>(); // per-row router-weight gradient scalar output
// In cached+expand replay, reuse the handle's recorded metadata (its [:, 2+k] columns hold the
// forward's expanded-row map, which the epilogue reads back). Otherwise allocate fresh.
auto recv_src_metadata = (cached_mode and do_expand)
? cached_recv_src_metadata.value()
: torch::empty(
{num_recv_tokens, num_topk + 2},
torch::TensorOptions(torch::kCUDA).dtype(torch::kInt));

// Optional tensors
void* recv_sf_ptr = nullptr;
topk_idx_t* recv_topk_idx_ptr = nullptr;
float* recv_topk_weights_ptr = nullptr;
float* recv_aux_weights_ptr = nullptr; // per-row router-weight gradient scalar
int recv_sf_token_stride = 0, recv_sf_hidden_stride = 0;
if (sf.has_value()) {
if (not use_tma_aligned_col_major_sf) {
Expand All @@ -1077,25 +1116,46 @@ class ElasticBuffer {
torch::empty({num_allocated_tokens, num_topk}, topk_weights->options());
recv_topk_weights_ptr = recv_topk_weights->data_ptr<float>();
}
// recv_aux_weights only in expand mode (needs the per-row [NvS] scalar)
if (aux_weights.has_value() and do_expand) {
recv_aux_weights = torch::empty({num_allocated_tokens}, aux_weights->options());
recv_aux_weights_ptr = recv_aux_weights->data_ptr<float>();
}

// Process prefix sum, in expanding mode, it is also atomic counters
if (do_expand) {
// Slice and exclusive part and do atomic additions into inclusive
EP_HOST_ASSERT(not cached_mode);
psum_num_recv_tokens_per_expert = psum_num_recv_tokens_per_expert.slice(0, 0, num_local_experts);
// For cached+expand the epilogue reuses the recorded expanded-row map (no atomic counters),
// and the cached psum is already shaped [num_local_experts]. Only the fresh expand path needs
// the exclusive prefix-sum slice used as atomic counters.
if (not cached_mode) {
// Slice the exclusive part and do atomic additions into inclusive
psum_num_recv_tokens_per_expert = psum_num_recv_tokens_per_expert.slice(0, 0, num_local_experts);
}
} else if (not cached_mode) {
// Slice the inclusive part (and will not be used in the epilogue)
psum_num_recv_tokens_per_expert = psum_num_recv_tokens_per_expert.slice(0, 1, num_local_experts + 1);
}
EP_HOST_ASSERT(psum_num_recv_tokens_per_expert.size(0) == num_local_experts);

// Fused combine-backward scatter. When `scatter_to_nvs_out` is given (only meaningful for the
// non-expand cached backward), the copy epilogue additionally scatters each received row to
// out_nvs[scatter_src_metadata[i, 2+k]] -- the NvS destinations recorded by the forward expand
// dispatch -- fusing the Python-side gather/index_copy into this dispatch kernel. `scatter_src_metadata`
// must be the forward handle's recv_src_metadata ([num_recv, num_topk+2]).
const bool do_scatter_to_nvs = scatter_to_nvs_out.has_value() and not do_expand;
void* out_nvs_ptr = scatter_to_nvs_out.has_value() ? scatter_to_nvs_out->data_ptr() : nullptr;
const int* scatter_src_metadata_ptr =
scatter_src_metadata.has_value() ? scatter_src_metadata->data_ptr<int>() : nullptr;
EP_HOST_ASSERT(not do_scatter_to_nvs or scatter_src_metadata_ptr != nullptr);

// Launch copy kernels with full SMs
stream_control_before_epilogue(previous_event_before_epilogue);
launch_dispatch_copy_epilogue(buffer, workspace,
psum_num_recv_tokens_per_scaleup_rank.data_ptr<int>(),
psum_num_recv_tokens_per_expert.data_ptr<int>(),
recv_x.data_ptr(), recv_sf_ptr,
recv_topk_idx_ptr, recv_topk_weights_ptr,
recv_aux_weights_ptr,
recv_src_metadata.data_ptr<int>(),
channel_linked_list_ptr,
num_recv_tokens, num_max_tokens_per_rank,
Expand All @@ -1108,6 +1168,7 @@ class ElasticBuffer {
jit::device_runtime->get_num_smem_bytes(),
num_channels,
do_expand, cached_mode,
do_scatter_to_nvs, scatter_src_metadata_ptr, out_nvs_ptr,
comm_stream);

// Stream control
Expand Down Expand Up @@ -1136,6 +1197,7 @@ class ElasticBuffer {
dst_buffer_slot_idx,
token_metadata_at_forward,
channel_linked_list,
recv_aux_weights, // per-row router-weight gradient scalar in NvS
event};
}

Expand Down Expand Up @@ -1186,8 +1248,22 @@ class ElasticBuffer {

// Check optional tensors
if (use_expanded_layout) {
// Reduction should be done with SwiGLU
EP_HOST_ASSERT(not topk_weights.has_value());
// Optionally carry the per-expanded-row router weight gradient (droute_weights_nvs) so combine
// can reduce it back to droute_weights_sk [S,K] in one pass. The weight is a 1D [NvS] float
// tensor aligned row-for-row with `x`. Each k-slot's scalar must land in its own output column,
// so we require expanded-send (allow_multiple_reduction == false) — local reduction would merge
// distinct k-slots' weights — and a single NVLink domain (the only path wired to deliver
// per-row weights to the source rank).
if (topk_weights.has_value()) {
const auto [num_w] = get_shape<1>(topk_weights.value());
EP_HOST_ASSERT(num_w == num_tokens);
EP_HOST_ASSERT(topk_weights->is_cuda() and topk_weights->is_contiguous());
EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat);
EP_HOST_ASSERT(not allow_multiple_reduction and
"expand-combine topk_weights requires allow_multiple_reduction=False");
EP_HOST_ASSERT(nccl_context->num_scaleout_ranks == 1 and
"expand-combine topk_weights is implemented for a single NVLink domain only");
}
} else if (topk_weights.has_value()) {
const auto [num_tokens__, num_topk__] = get_shape<2>(topk_weights.value());
EP_HOST_ASSERT(num_tokens == num_tokens__ and num_topk == num_topk__);
Expand Down
Loading