diff --git a/3rdparty/deep_ep/config.hpp b/3rdparty/deep_ep/config.hpp new file mode 100644 index 0000000000..0839265799 --- /dev/null +++ b/3rdparty/deep_ep/config.hpp @@ -0,0 +1,193 @@ +// clang-format off +#pragma once + +#include "kernels/api.cuh" +#include "kernels/exception.cuh" + +namespace deep_ep { + +template +dtype_t ceil_div(dtype_t a, dtype_t b) { + return (a + b - 1) / b; +} + +template +dtype_t align_up(dtype_t a, dtype_t b) { + return ceil_div(a, b) * b; +} + +template +dtype_t align_down(dtype_t a, dtype_t b) { + return a / b * b; +} + +struct Config { + int num_sms; + int num_max_nvl_chunked_send_tokens; + int num_max_nvl_chunked_recv_tokens; + int num_max_rdma_chunked_send_tokens; + int num_max_rdma_chunked_recv_tokens; + + Config(int num_sms, + int num_max_nvl_chunked_send_tokens, + int num_max_nvl_chunked_recv_tokens, + int num_max_rdma_chunked_send_tokens, + int num_max_rdma_chunked_recv_tokens) + : num_sms(num_sms), + num_max_nvl_chunked_send_tokens(num_max_nvl_chunked_send_tokens), + num_max_nvl_chunked_recv_tokens(num_max_nvl_chunked_recv_tokens), + num_max_rdma_chunked_send_tokens(num_max_rdma_chunked_send_tokens), + num_max_rdma_chunked_recv_tokens(num_max_rdma_chunked_recv_tokens) { + EP_HOST_ASSERT(num_sms >= 0); + EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens > 0 and num_max_nvl_chunked_recv_tokens > 0); + EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens < num_max_nvl_chunked_recv_tokens); + EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens > 0 and num_max_rdma_chunked_recv_tokens > 0); + + // Ceil up RDMA buffer size + this->num_max_rdma_chunked_recv_tokens = align_up(num_max_rdma_chunked_recv_tokens, num_max_rdma_chunked_send_tokens); + EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens < num_max_rdma_chunked_recv_tokens); + // NOTES: this assertion is related to RDMA lazy head update, we must ensure senders always have space to push + EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens <= num_max_rdma_chunked_recv_tokens / 2); + } + + size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks) const { + // Below are some assumptions + // TODO: add assertions + constexpr int kNumMaxTopK = 128; + constexpr int kNumMaxScales = 128; + EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0); + EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % 2 == 0); + const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1); + const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS); + const int num_channels = num_sms / 2; + + size_t num_bytes = 0; + num_bytes += num_channels * num_nvl_ranks * (2 * num_rdma_ranks + 3) * sizeof(int); + num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * hidden_bytes; + num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * internode::get_source_meta_bytes(); + num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(topk_idx_t); + num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(float); + num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxScales * sizeof(float); + num_bytes = ((num_bytes + 127) / 128) * 128; + return num_bytes; + } + + size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks) const { + // Legacy mode + if (num_ranks <= NUM_MAX_NVL_PEERS) + return 0; + + // Below are some assumptions + // TODO: add assertions + constexpr int kNumMaxTopK = 128; + constexpr int kNumMaxScales = 128; + 
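+        // NOTES: the two constants above are sizing assumptions rather than enforced
+        // limits; the byte budget below reserves room for at most `kNumMaxTopK` top-k
+        // indices/weights and `kNumMaxScales` FP8 scales per received token.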
EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0); + EP_HOST_ASSERT(num_sms % 2 == 0); + const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; + const int num_channels = num_sms / 2; + + size_t num_bytes = 0; + num_bytes += num_channels * num_rdma_ranks * (NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int); + num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * hidden_bytes * 2; + num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * internode::get_source_meta_bytes() * 2; + num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(topk_idx_t) * 2; + num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(float) * 2; + num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxScales * sizeof(float) * 2; + num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2; + num_bytes = ((num_bytes + 127) / 128) * 128; + return num_bytes; + } +}; + +struct LowLatencyBuffer { + int num_clean_int = 0; + + void* dispatch_rdma_send_buffer = nullptr; + void* dispatch_rdma_recv_data_buffer = nullptr; + int* dispatch_rdma_recv_count_buffer = nullptr; + + void* combine_rdma_send_buffer = nullptr; + void* combine_rdma_recv_data_buffer = nullptr; + int* combine_rdma_recv_flag_buffer = nullptr; + + void* combine_rdma_send_buffer_data_start = nullptr; + size_t num_bytes_per_combine_msg = 0; + + std::pair clean_meta() { + EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer); + return {dispatch_rdma_recv_count_buffer, num_clean_int}; + } +}; + +struct LowLatencyLayout { + void* rdma_buffer = nullptr; + size_t total_bytes = 0; + LowLatencyBuffer buffers[2]; + + template + out_ptr_t advance(const in_ptr_t& ptr, size_t count) { + return reinterpret_cast(reinterpret_cast(ptr) + count); + } + + LowLatencyLayout(void* rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts): rdma_buffer(rdma_buffer) { + const int num_scales = hidden / 128; + + // Dispatch and combine layout: + // - 2 symmetric odd/even send buffer + // - 2 symmetric odd/even receive buffers + // - 2 symmetric odd/even signaling buffers + + // Message sizes + // NOTES: you should add a control `int4` for combine messages if you want to do data transformation + // NOTES: `num_scales * sizeof(nv_bfloat162)` means the per-128-channel min/max + EP_HOST_ASSERT(num_scales * sizeof(float) <= hidden); + size_t num_bytes_per_dispatch_msg = sizeof(int4) + std::max(hidden * sizeof(nv_bfloat16), hidden + num_scales * sizeof(float)); + size_t num_bytes_per_combine_msg = num_scales * sizeof(nv_bfloat162) + hidden * sizeof(nv_bfloat16); + + // Send buffer + size_t dispatch_send_buffer_bytes = num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg; + size_t combine_send_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg; + size_t send_buffer_bytes = std::max(dispatch_send_buffer_bytes, combine_send_buffer_bytes); + EP_HOST_ASSERT(send_buffer_bytes % sizeof(int4) == 0); + total_bytes += send_buffer_bytes * 2; + + // Symmetric receive buffers + // TODO: optimize memory usages + size_t dispatch_recv_data_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg; + size_t combine_recv_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg; + size_t recv_buffer_bytes = 
std::max(dispatch_recv_data_buffer_bytes, combine_recv_buffer_bytes); + EP_HOST_ASSERT(recv_buffer_bytes % sizeof(int4) == 0); + total_bytes += recv_buffer_bytes * 2; + + // Symmetric signaling buffers + size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int); + size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes; + size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes); + size_t signaling_buffer_bytes_aligned = align_up(signaling_buffer_bytes, 128); + total_bytes += signaling_buffer_bytes_aligned * 2; + + // Assign pointers + // NOTES: we still leave some space for distinguishing dispatch/combine buffer, + // so you may see some parameters are duplicated + for (int i = 0; i < 2; ++i) { + buffers[i] = {static_cast(signaling_buffer_bytes / sizeof(int)), + advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i), + advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 + recv_buffer_bytes * i), + advance(rdma_buffer, signaling_buffer_bytes_aligned * i), + advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i), + advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 + recv_buffer_bytes * i), + advance(rdma_buffer, signaling_buffer_bytes_aligned * i), + advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i), + num_bytes_per_combine_msg}; + } + } +}; + +inline size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) { + auto num_bytes = LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts).total_bytes; + return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) * NUM_BUFFER_ALIGNMENT_BYTES; +} + +} // namespace deep_ep + +// clang-format on diff --git a/3rdparty/deep_ep/deep_ep.cpp b/3rdparty/deep_ep/deep_ep.cpp new file mode 100644 index 0000000000..0207e91353 --- /dev/null +++ b/3rdparty/deep_ep/deep_ep.cpp @@ -0,0 +1,1604 @@ +#include "deep_ep.hpp" + +#include "kernels/api.cuh" +#include "kernels/exception.cuh" +#include "src/turbomind/comm/host_comm.h" +#include "src/turbomind/core/allocator.h" +#include "src/turbomind/core/context.h" +#include "src/turbomind/core/data_type.h" +#include "src/turbomind/kernels/core/math.h" +#include "src/turbomind/utils/cuda_utils.h" +#include "src/turbomind/utils/string_utils.h" + +#include +#include +#include +#include +#include +#include +#include + +using turbomind::fmtstr; +using turbomind::round_up; + +namespace shared_memory { +void cu_mem_set_access_all(void* ptr, size_t size) +{ + int device_count; + CUDA_CHECK(cudaGetDeviceCount(&device_count)); + + CUmemAccessDesc access_desc[device_count]; + for (int idx = 0; idx < device_count; ++idx) { + access_desc[idx].location.type = CU_MEM_LOCATION_TYPE_DEVICE; + access_desc[idx].location.id = idx; + access_desc[idx].flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + } + + CU_CHECK(cuMemSetAccess((CUdeviceptr)ptr, size, access_desc, device_count)); +} + +void cu_mem_free(void* ptr) +{ + CUmemGenericAllocationHandle handle; + CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr)); + + size_t size = 0; + CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr)); + + CU_CHECK(cuMemUnmap((CUdeviceptr)ptr, size)); + CU_CHECK(cuMemAddressFree((CUdeviceptr)ptr, size)); + CU_CHECK(cuMemRelease(handle)); +} + +size_t get_size_align_to_granularity(size_t size_raw, size_t granularity) +{ + 
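+    // Round `size_raw` up to a whole number of granules, returning at least one
+    // granule for zero-sized requests. The bit mask below only rounds correctly when
+    // `granularity` is a power of two, which holds for the CUDA allocation
+    // granularities the callers pass in (e.g. a 1-byte request with a 2 MiB
+    // granularity yields 2 MiB).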
size_t size = (size_raw + granularity - 1) & ~(granularity - 1); + if (size == 0) + size = granularity; + return size; +} + +SharedMemoryAllocator::SharedMemoryAllocator(bool use_fabric): use_fabric(use_fabric) {} + +void SharedMemoryAllocator::malloc(void** ptr, size_t size_raw) +{ + if (use_fabric) { + CUdevice device; + CU_CHECK(cuCtxGetDevice(&device)); + + CUmemAllocationProp prop = {}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC; + prop.location.id = device; + + size_t granularity = 0; + CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM)); + + size_t size = get_size_align_to_granularity(size_raw, granularity); + + CUmemGenericAllocationHandle handle; + CU_CHECK(cuMemCreate(&handle, size, &prop, 0)); + + CU_CHECK(cuMemAddressReserve((CUdeviceptr*)ptr, size, granularity, 0, 0)); + CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0)); + cu_mem_set_access_all(*ptr, size); + } + else { + CUDA_CHECK(cudaMalloc(ptr, size_raw)); + } +} + +void SharedMemoryAllocator::free(void* ptr) +{ + if (use_fabric) { + cu_mem_free(ptr); + } + else { + CUDA_CHECK(cudaFree(ptr)); + } +} + +void SharedMemoryAllocator::get_mem_handle(MemHandle* mem_handle, void* ptr) +{ + size_t size = 0; + CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr)); + + mem_handle->size = size; + + if (use_fabric) { + CUmemGenericAllocationHandle handle; + CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr)); + + CU_CHECK(cuMemExportToShareableHandle( + &mem_handle->inner.cu_mem_fabric_handle, handle, CU_MEM_HANDLE_TYPE_FABRIC, 0)); + } + else { + CUDA_CHECK(cudaIpcGetMemHandle(&mem_handle->inner.cuda_ipc_mem_handle, ptr)); + } +} + +void SharedMemoryAllocator::open_mem_handle(void** ptr, MemHandle* mem_handle) +{ + if (use_fabric) { + size_t size = mem_handle->size; + + CUmemGenericAllocationHandle handle; + CU_CHECK(cuMemImportFromShareableHandle( + &handle, &mem_handle->inner.cu_mem_fabric_handle, CU_MEM_HANDLE_TYPE_FABRIC)); + + CU_CHECK(cuMemAddressReserve((CUdeviceptr*)ptr, size, 0, 0, 0)); + CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0)); + cu_mem_set_access_all(*ptr, size); + } + else { + CUDA_CHECK(cudaIpcOpenMemHandle(ptr, mem_handle->inner.cuda_ipc_mem_handle, cudaIpcMemLazyEnablePeerAccess)); + } +} + +void SharedMemoryAllocator::close_mem_handle(void* ptr) +{ + if (use_fabric) { + cu_mem_free(ptr); + } + else { + CUDA_CHECK(cudaIpcCloseMemHandle(ptr)); + } +} +} // namespace shared_memory + +namespace deep_ep { + +Buffer::Buffer(int rank, + int num_ranks, + int64_t num_nvl_bytes, + int64_t num_rdma_bytes, + int64_t num_ll_rdma_bytes, + bool low_latency_mode, + bool enable_shrink, + bool use_fabric, + int qps_per_rank, + HostComm h_comm): + rank(rank), + num_ranks(num_ranks), + num_nvl_bytes(num_nvl_bytes), + low_latency_mode(low_latency_mode), + num_rdma_bytes(num_rdma_bytes), + num_ll_rdma_bytes(num_ll_rdma_bytes), + enable_shrink(enable_shrink), + shared_memory_allocator(use_fabric), + qps_per_rank(qps_per_rank), + h_comm(h_comm) +{ + // Common checks + EP_STATIC_ASSERT(NUM_BUFFER_ALIGNMENT_BYTES % sizeof(int4) == 0, "Invalid alignment"); + EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 + and (num_nvl_bytes <= std::numeric_limits::max() or num_rdma_bytes == 0)); + EP_HOST_ASSERT(num_rdma_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 + and (low_latency_mode or num_rdma_bytes <= std::numeric_limits::max())); + 
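+    // The two checks below keep each buffer small enough that its size, measured in
+    // 16-byte `int4` units, still fits in a signed 32-bit `int` offset.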
EP_HOST_ASSERT(num_nvl_bytes / sizeof(int4) < std::numeric_limits::max()); + EP_HOST_ASSERT(num_rdma_bytes / sizeof(int4) < std::numeric_limits::max()); + EP_HOST_ASSERT(0 <= rank and rank < num_ranks + and (num_ranks <= NUM_MAX_NVL_PEERS * NUM_MAX_RDMA_PEERS or low_latency_mode)); + EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0); + if (num_rdma_bytes > 0) { + EP_HOST_ASSERT(num_ranks > NUM_MAX_NVL_PEERS or low_latency_mode); + } + + // Get ranks + CUDA_CHECK(cudaGetDevice(&device_id)); + rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; + num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS), num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS); + + // Get device info + cudaDeviceProp device_prop = {}; + CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id)); + num_device_sms = device_prop.multiProcessorCount; + + // Number of per-channel bytes cannot be large + EP_HOST_ASSERT(ceil_div(num_nvl_bytes, num_device_sms / 2) < std::numeric_limits::max()); + EP_HOST_ASSERT(ceil_div(num_rdma_bytes, num_device_sms / 2) < std::numeric_limits::max()); + + auto comm_stream = turbomind::core::Context::stream().handle(); + + // Create 32 MiB workspace + CUDA_CHECK(cudaMalloc(&workspace, NUM_WORKSPACE_BYTES)); + CUDA_CHECK(cudaMemsetAsync(workspace, 0, NUM_WORKSPACE_BYTES, comm_stream)); + + // MoE counter + CUDA_CHECK(cudaMallocHost(&moe_recv_counter, sizeof(int64_t), cudaHostAllocMapped)); + CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_counter_mapped, const_cast(moe_recv_counter), 0)); + *moe_recv_counter = -1; + + // MoE expert-level counter + CUDA_CHECK(cudaMallocHost(&moe_recv_expert_counter, sizeof(int) * NUM_MAX_LOCAL_EXPERTS, cudaHostAllocMapped)); + CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_expert_counter_mapped, const_cast(moe_recv_expert_counter), 0)); + for (int i = 0; i < NUM_MAX_LOCAL_EXPERTS; ++i) + moe_recv_expert_counter[i] = -1; + + // MoE RDMA-level counter + if (num_rdma_ranks > 0) { + CUDA_CHECK(cudaMallocHost(&moe_recv_rdma_counter, sizeof(int), cudaHostAllocMapped)); + CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_rdma_counter_mapped, const_cast(moe_recv_rdma_counter), 0)); + *moe_recv_rdma_counter = -1; + } + + // NVLink + if (num_nvl_bytes > 0) { + allocate_sync_nvl_buffer(); + } + + // RDMA + if (num_rdma_bytes || num_ll_rdma_bytes) { + allocate_rdma_buffer(); + } + + turbomind::core::Context::stream().Sync(); + h_comm->Sync(); + + // Ready to use + available = true; +} + +void Buffer::allocate_sync_nvl_buffer() +{ + // Metadata memory + int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int); + int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void*); + int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int*); + + auto stream = turbomind::core::Context::stream().handle(); + + HostComm h_nvl_comm = h_comm->Split(rdma_rank, 0); + TM_CHECK_EQ(h_nvl_comm->is_same_process(), true); + + ipc_comm = CreateDeviceCommunicator("cuda-ipc", h_nvl_comm->n_ranks(), nvl_rank, h_nvl_comm); + + buffer_ptrs[nvl_rank] = + ipc_comm->Allocate(num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes); + + buffer_ptrs_gpu = + reinterpret_cast(static_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + barrier_signal_bytes); + + // Set barrier signals + barrier_signal_ptrs[nvl_rank] = + reinterpret_cast(static_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes); + barrier_signal_ptrs_gpu = reinterpret_cast(static_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + + barrier_signal_bytes + 
buffer_ptr_bytes); + + // No need to synchronize, will do a full device sync during `sync` + CUDA_CHECK(cudaMemsetAsync(barrier_signal_ptrs[nvl_rank], 0, barrier_signal_bytes, stream)); + + AllGather(h_nvl_comm, buffer_ptrs, 1); + + for (int i = 0; i < num_nvl_ranks; ++i) { + if (i != nvl_rank) { + barrier_signal_ptrs[i] = reinterpret_cast(static_cast(buffer_ptrs[i]) + num_nvl_bytes); + } + } + + // Copy all buffer and barrier signal pointers to GPU + CUDA_CHECK(cudaMemcpyAsync( + buffer_ptrs_gpu, buffer_ptrs, sizeof(void*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync(barrier_signal_ptrs_gpu, + barrier_signal_ptrs, + sizeof(int*) * NUM_MAX_NVL_PEERS, + cudaMemcpyHostToDevice, + stream)); +} + +void Buffer::allocate_rdma_buffer() +{ + TM_CHECK_EQ(comm, nullptr); + if ((not low_latency_mode) and (num_rdma_ranks == 1)) { + return; + } + + std::vector unique_ids; + if (rank == 0) { + unique_ids = deep_ep::internode::get_unique_id(); + } + Broadcast(h_comm, unique_ids, 0); + + comm = std::make_shared(); + comm->init(unique_ids, rank, num_ranks, low_latency_mode, qps_per_rank); + internode::barrier(comm.get()); + + auto stream = turbomind::core::Context::stream().handle(); + + if (num_rdma_bytes) { + // Allocate High-Throughput RDMA buffer + rdma_buffer_ptr = internode::alloc(num_rdma_bytes, NUM_BUFFER_ALIGNMENT_BYTES, comm.get()); + // Register memory with NCCL communicators (sets up windows for RDMA) + internode::register_memory(rdma_buffer_ptr, num_rdma_bytes, comm.get()); + } + + if (num_ll_rdma_bytes) { + // Allocate Low-Latency RDMA buffer + rdma_ll_buffer_ptr = internode::alloc(num_ll_rdma_bytes, NUM_BUFFER_ALIGNMENT_BYTES, comm.get()); + // Register memory with NCCL communicators (sets up windows for RDMA) + internode::register_memory(rdma_ll_buffer_ptr, num_ll_rdma_bytes, comm.get()); + + // Clean buffer (mainly for low-latency mode) + CUDA_CHECK(cudaMemsetAsync(rdma_ll_buffer_ptr, 0, num_ll_rdma_bytes, stream)); + + internode_ll::set_p2p_disabled_flag(comm->is_p2p_disabled()); + } + + // Allocate and clean shrink buffer + if (enable_shrink) { + int num_mask_buffer_bytes = num_ranks * sizeof(int); + int num_sync_buffer_bytes = num_ranks * sizeof(int); + mask_buffer_ptr = + reinterpret_cast(internode::alloc(num_mask_buffer_bytes, NUM_BUFFER_ALIGNMENT_BYTES, comm.get())); + sync_buffer_ptr = + reinterpret_cast(internode::alloc(num_sync_buffer_bytes, NUM_BUFFER_ALIGNMENT_BYTES, comm.get())); + CUDA_CHECK(cudaMemsetAsync(mask_buffer_ptr, 0, num_mask_buffer_bytes, stream)); + CUDA_CHECK(cudaMemset(sync_buffer_ptr, 0, num_sync_buffer_bytes)); + } + + // Barrier + internode::barrier(comm.get()); +} + +bool Buffer::is_available() const +{ + return available; +} + +bool Buffer::is_internode_available() const +{ + return is_available() and num_ranks > NUM_MAX_NVL_PEERS; +} + +int Buffer::get_num_rdma_ranks() const +{ + return num_rdma_ranks; +} + +int Buffer::get_rdma_rank() const +{ + return rdma_rank; +} + +int Buffer::get_root_rdma_rank(bool global) const +{ + return global ? 
nvl_rank : 0; +} + +int Buffer::get_local_device_id() const +{ + return device_id; +} + +void Buffer::destroy() +{ + TM_LOG_DEBUG("[NCCLEP][%d] Destroying buffer", rank); + EP_HOST_ASSERT(not destroyed); + + // Synchronize + auto comm_stream = turbomind::core::Context::stream().handle(); + + if (num_nvl_bytes > 0 && ipc_comm) { + turbomind::core::Context::stream().Sync(); + ipc_comm->Free(buffer_ptrs[nvl_rank]); + ipc_comm = {}; + } + + // Free NVSHMEM + if (is_available() && comm != nullptr) { + turbomind::core::Context::stream().Sync(); + if (num_rdma_bytes > 0) { + internode::free(rdma_buffer_ptr, comm.get()); + } + if (num_ll_rdma_bytes > 0) { + internode::free(rdma_ll_buffer_ptr, comm.get()); + } + if (enable_shrink) { + internode::free(mask_buffer_ptr, comm.get()); + internode::free(sync_buffer_ptr, comm.get()); + } + internode::finalize(comm.get()); + } + + // Free workspace and MoE counter + CUDA_CHECK(cudaFree(workspace)); + CUDA_CHECK(cudaFreeHost(const_cast(moe_recv_counter))); + + // Free chunked mode staffs + CUDA_CHECK(cudaFreeHost(const_cast(moe_recv_expert_counter))); + + destroyed = true; + available = false; +} + +std::tuple, Tensor, Tensor> // +Buffer::get_dispatch_layout(const Tensor& topk_idx, int num_experts) +{ + + auto num_tokens = static_cast(topk_idx.shape(0)); + auto num_topk = static_cast(topk_idx.shape(1)); + auto num_tokens_per_rank = Tensor_{{num_ranks}, turbomind::kDEVICE}; + auto num_tokens_per_rdma_rank = std::optional(); + auto num_tokens_per_expert = Tensor_{{num_experts}, turbomind::kDEVICE}; + auto is_token_in_rank = Tensor_{{num_tokens, num_ranks}, turbomind::kDEVICE}; + if (is_internode_available()) { + num_tokens_per_rdma_rank = Buffer_{num_rdma_ranks, turbomind::kDEVICE}; + } + static_assert(sizeof(topk_idx_t) == sizeof(int64_t), "topk_idx_t must be int64_t"); + + auto stream = turbomind::core::Context::stream().handle(); + layout::get_dispatch_layout(topk_idx.data(), + num_tokens_per_rank.data(), + num_tokens_per_rdma_rank.has_value() ? num_tokens_per_rdma_rank.value().data() : + nullptr, + num_tokens_per_expert.data(), + is_token_in_rank.data_or((bool*)nullptr), // num_tokens may be zero + num_tokens, + num_topk, + num_ranks, + num_experts, + stream); + + return {num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank}; +} + +std::tuple, + std::optional, + std::optional, + std::vector, + Tensor, + Tensor, + Tensor, + Tensor, + Tensor, + Tensor> +Buffer::intranode_dispatch(const Tensor& x, + const std::optional& x_scales, + const std::optional& topk_idx, + const std::optional& topk_weights, + const std::optional& num_tokens_per_rank, + const Tensor& is_token_in_rank, + const std::optional& num_tokens_per_expert, + int cached_num_recv_tokens, + const std::optional& cached_rank_prefix_matrix, + const std::optional& cached_channel_prefix_matrix, + int expert_alignment, + int num_worst_tokens, + const Config& config) +{ + bool cached_mode = cached_rank_prefix_matrix.has_value(); + + // One channel use two blocks, even-numbered blocks for sending, odd-numbered blocks for receiving. 
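+    // Hence `config.num_sms` must be even and each channel maps to one block pair
+    // (`num_channels = num_sms / 2`).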
+ EP_HOST_ASSERT(config.num_sms % 2 == 0); + int num_channels = config.num_sms / 2; + if (cached_mode) { + EP_HOST_ASSERT(cached_rank_prefix_matrix.has_value()); + EP_HOST_ASSERT(cached_channel_prefix_matrix.has_value()); + } + else { + EP_HOST_ASSERT(num_tokens_per_rank.has_value()); + EP_HOST_ASSERT(num_tokens_per_expert.has_value()); + } + // Type checks + EP_HOST_ASSERT(is_token_in_rank.dtype() == turbomind::kBool); + if (cached_mode) { + EP_HOST_ASSERT(cached_rank_prefix_matrix->dtype() == turbomind::kInt32); + EP_HOST_ASSERT(cached_channel_prefix_matrix->dtype() == turbomind::kInt32); + } + else { + EP_HOST_ASSERT(num_tokens_per_expert->dtype() == turbomind::kInt32); + EP_HOST_ASSERT(num_tokens_per_rank->dtype() == turbomind::kInt32); + } + + // Shape and contiguous checks + EP_HOST_ASSERT(x.ndim() == 2 and x.is_contiguous()); + EP_HOST_ASSERT((x.shape(1) * byte_size(x.dtype())) % sizeof(int4) == 0); + EP_HOST_ASSERT(is_token_in_rank.ndim() == 2 and is_token_in_rank.is_contiguous()); + EP_HOST_ASSERT(is_token_in_rank.shape(0) == x.shape(0) and is_token_in_rank.shape(1) == num_ranks); + if (cached_mode) { + EP_HOST_ASSERT(cached_rank_prefix_matrix->ndim() == 2 and cached_rank_prefix_matrix->is_contiguous()); + EP_HOST_ASSERT(cached_rank_prefix_matrix->shape(0) == num_ranks + and cached_rank_prefix_matrix->shape(1) == num_ranks); + EP_HOST_ASSERT(cached_channel_prefix_matrix->ndim() == 2 and cached_channel_prefix_matrix->is_contiguous()); + EP_HOST_ASSERT(cached_channel_prefix_matrix->shape(0) == num_ranks + and cached_channel_prefix_matrix->shape(1) == num_channels); + } + else { + EP_HOST_ASSERT(num_tokens_per_expert->ndim() == 1 and num_tokens_per_expert->is_contiguous()); + EP_HOST_ASSERT(num_tokens_per_expert->shape(0) % num_ranks == 0); + EP_HOST_ASSERT(num_tokens_per_expert->shape(0) / num_ranks <= NUM_MAX_LOCAL_EXPERTS); + EP_HOST_ASSERT(num_tokens_per_rank->ndim() == 1 and num_tokens_per_rank->is_contiguous()); + EP_HOST_ASSERT(num_tokens_per_rank->shape(0) == num_ranks); + } + + auto num_tokens = static_cast(x.shape(0)); + auto hidden = static_cast(x.shape(1)); + auto num_experts = cached_mode ? 0 : static_cast(num_tokens_per_expert->shape(0)); + auto num_local_experts = num_experts / num_ranks; + + // Top-k checks + int num_topk = 0; + const topk_idx_t* topk_idx_ptr = nullptr; + const float* topk_weights_ptr = nullptr; + EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value()); + if (topk_idx.has_value()) { + num_topk = static_cast(topk_idx->shape(1)); + EP_HOST_ASSERT(num_experts > 0); + EP_HOST_ASSERT(topk_idx->ndim() == 2 and topk_idx->is_contiguous()); + EP_HOST_ASSERT(topk_weights->ndim() == 2 and topk_weights->is_contiguous()); + EP_HOST_ASSERT(num_tokens == topk_idx->shape(0) and num_tokens == topk_weights->shape(0)); + EP_HOST_ASSERT(num_topk == topk_weights->shape(1)); + EP_HOST_ASSERT(topk_weights->dtype() == turbomind::kFloat32); + topk_idx_ptr = topk_idx->data_or((topk_idx_t*)nullptr); + topk_weights_ptr = topk_weights->data_or((float*)nullptr); + } + + // FP8 scales checks + const float* x_scales_ptr = nullptr; + int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0; + if (x_scales.has_value()) { + EP_HOST_ASSERT(byte_size(x.dtype()) == 1); + EP_HOST_ASSERT(x_scales->dtype() == turbomind::kFloat32 or x_scales->dtype() == turbomind::kInt32); + EP_HOST_ASSERT(x_scales->ndim() == 2); + EP_HOST_ASSERT(x_scales->shape(0) == num_tokens); + num_scales = x_scales->ndim() == 1 ? 
1 : static_cast(x_scales->shape(1)); + x_scales_ptr = x_scales->data_or((float*)nullptr); + scale_token_stride = static_cast(x_scales->stride(0)); + scale_hidden_stride = static_cast(x_scales->stride(1)); + } + + // Create handles (only return for non-cached mode) + int num_recv_tokens = -1; + auto rank_prefix_matrix = Tensor(); + auto channel_prefix_matrix = Tensor(); + std::vector num_recv_tokens_per_expert_list; + + // used to compute offsets in MoeFfnLayer + auto moe_recv_expert_counter_ten = Tensor({num_local_experts}, turbomind::kInt32, turbomind::kDEVICE); + + // Barrier or send sizes + // To clean: channel start/end offset, head and tail + int num_memset_int = num_channels * num_ranks * 4; + if (cached_mode) { + EP_HOST_ASSERT(0); + // num_recv_tokens = cached_num_recv_tokens; + // rank_prefix_matrix = cached_rank_prefix_matrix.value(); + // channel_prefix_matrix = cached_channel_prefix_matrix.value(); + + // // Copy rank prefix matrix and clean flags + // intranode::cached_notify_dispatch(rank_prefix_matrix.data_ptr(), + // num_memset_int, + // buffer_ptrs_gpu, + // barrier_signal_ptrs_gpu, + // rank, + // num_ranks, + // comm_stream); + } + else { + rank_prefix_matrix = Tensor({num_ranks, num_ranks}, turbomind::kInt32, turbomind::kDEVICE); + channel_prefix_matrix = Tensor({num_ranks, num_channels}, turbomind::kInt32, turbomind::kDEVICE); + + // Send sizes + // Meta information: + // - Size prefix by ranks, shaped as `[num_ranks, num_ranks]` + // - Size prefix by experts (not used later), shaped as `[num_ranks, num_local_experts]` + // NOTES: no more token dropping in this version + *moe_recv_counter = -1; + for (int i = 0; i < num_local_experts; ++i) + moe_recv_expert_counter[i] = -1; + EP_HOST_ASSERT(num_ranks * (num_ranks + num_local_experts) * sizeof(int) <= num_nvl_bytes); + intranode::notify_dispatch(num_tokens_per_rank->data(), + moe_recv_counter_mapped, + num_ranks, + num_tokens_per_expert->data(), + moe_recv_expert_counter_mapped, + moe_recv_expert_counter_ten.data(), + num_experts, + num_tokens, + is_token_in_rank.data_or((bool*)nullptr), // num_tokens may be zero + channel_prefix_matrix.data(), + rank_prefix_matrix.data(), + num_memset_int, + expert_alignment, + buffer_ptrs_gpu, + barrier_signal_ptrs_gpu, + rank, + turbomind::core::Context::stream().handle(), + num_channels); + + if (num_worst_tokens > 0) { + // No CPU sync, just allocate the worst case + num_recv_tokens = num_worst_tokens; + + // Must be forward with top-k stuffs + EP_HOST_ASSERT(topk_idx.has_value()); + EP_HOST_ASSERT(topk_weights.has_value()); + } + else { + // Synchronize total received tokens and tokens per expert + auto start_time = std::chrono::high_resolution_clock::now(); + while (true) { + // Read total count + num_recv_tokens = static_cast(*moe_recv_counter); + + // Read per-expert count + bool ready = (num_recv_tokens >= 0); + for (int i = 0; i < num_local_experts and ready; ++i) + ready &= moe_recv_expert_counter[i] >= 0; + + if (ready) + break; + + // Timeout check + if (std::chrono::duration_cast(std::chrono::high_resolution_clock::now() + - start_time) + .count() + > NUM_CPU_TIMEOUT_SECS) + throw std::runtime_error("DeepEP error: CPU recv timeout"); + } + num_recv_tokens_per_expert_list = + std::vector(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts); + } + } + + // Allocate new tensors + auto recv_x = Tensor({num_recv_tokens, hidden}, x.dtype(), turbomind::kDEVICE); + auto recv_src_idx = Tensor({num_recv_tokens}, turbomind::kInt32, turbomind::kDEVICE); + auto 
recv_topk_idx = std::optional(); + auto recv_topk_weights = std::optional(); + auto recv_x_scales = std::optional(); + auto recv_channel_prefix_matrix = Tensor({num_ranks, num_channels}, turbomind::kInt32, turbomind::kDEVICE); + auto send_head = Tensor({num_tokens, num_ranks}, turbomind::kInt32, turbomind::kDEVICE); + + // Assign pointers + topk_idx_t* recv_topk_idx_ptr = nullptr; + float* recv_topk_weights_ptr = nullptr; + float* recv_x_scales_ptr = nullptr; + if (topk_idx.has_value()) { + recv_topk_idx = Tensor({num_recv_tokens, num_topk}, topk_idx->dtype(), topk_idx->device()); + recv_topk_weights = Tensor({num_recv_tokens, num_topk}, topk_weights->dtype(), topk_weights->device()); + recv_topk_idx_ptr = recv_topk_idx->data_or((topk_idx_t*)nullptr); + recv_topk_weights_ptr = recv_topk_weights->data_or((float*)nullptr); + } + if (x_scales.has_value()) { + recv_x_scales = x_scales->ndim() == 1 ? + Tensor({num_recv_tokens}, x_scales->dtype(), x_scales->device()) : + Tensor({num_recv_tokens, num_scales}, x_scales->dtype(), x_scales->device()); + recv_x_scales_ptr = recv_x_scales->data_or((float*)nullptr); + } + + // Dispatch + EP_HOST_ASSERT( + num_ranks * num_ranks * sizeof(int) + // Size prefix matrix + num_channels * num_ranks * sizeof(int) + // Channel start offset + num_channels * num_ranks * sizeof(int) + // Channel end offset + num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * byte_size(recv_x.dtype()) + + // Data buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) + // Source index buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(topk_idx_t) + + // Top-k index buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(float) + + // Top-k weight buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(float) + * num_scales // FP8 scale buffer + <= num_nvl_bytes); + + intranode::dispatch(recv_x.data_or((void*)nullptr), + recv_x_scales_ptr, + recv_src_idx.data_or((int*)nullptr), + recv_topk_idx_ptr, + recv_topk_weights_ptr, + recv_channel_prefix_matrix.data(), + send_head.data_or((int*)nullptr), + x.data_or((void*)nullptr), + x_scales_ptr, + topk_idx_ptr, + topk_weights_ptr, + is_token_in_rank.data_or((bool*)nullptr), + channel_prefix_matrix.data(), + num_tokens, + num_worst_tokens, + static_cast(hidden * byte_size(recv_x.dtype()) / sizeof(int4)), + num_topk, + num_experts, + num_scales, + scale_token_stride, + scale_hidden_stride, + buffer_ptrs_gpu, + rank, + num_ranks, + turbomind::core::Context::stream().handle(), + config.num_sms, + config.num_max_nvl_chunked_send_tokens, + config.num_max_nvl_chunked_recv_tokens); + + // Return values + return {recv_x, + recv_x_scales, + recv_topk_idx, + recv_topk_weights, + num_recv_tokens_per_expert_list, + moe_recv_expert_counter_ten, + rank_prefix_matrix, + channel_prefix_matrix, + recv_channel_prefix_matrix, + recv_src_idx, + send_head}; +} + +std::tuple> +Buffer::intranode_combine(const Tensor& x, + const std::optional& topk_weights, + const std::optional& bias_0, + const std::optional& bias_1, + const Tensor& src_idx, + const Tensor& rank_prefix_matrix, + const Tensor& channel_prefix_matrix, + Tensor& send_head, + const Config& config) +{ + EP_HOST_ASSERT(x.ndim() == 2 and x.is_contiguous()); + EP_HOST_ASSERT(src_idx.ndim() == 1 and src_idx.is_contiguous() and src_idx.dtype() == turbomind::kInt32); + 
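+    // `send_head` comes from `intranode_dispatch`; its first dimension is the number
+    // of original (pre-dispatch) tokens, i.e. the number of combined tokens produced
+    // below.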
EP_HOST_ASSERT(send_head.ndim() == 2 and send_head.is_contiguous() and send_head.dtype() == turbomind::kInt32); + EP_HOST_ASSERT(rank_prefix_matrix.ndim() == 2 and rank_prefix_matrix.is_contiguous() + and rank_prefix_matrix.dtype() == turbomind::kInt32); + EP_HOST_ASSERT(channel_prefix_matrix.ndim() == 2 and channel_prefix_matrix.is_contiguous() + and channel_prefix_matrix.dtype() == turbomind::kInt32); + + // One channel use two blocks, even-numbered blocks for sending, odd-numbered blocks for receiving. + EP_HOST_ASSERT(config.num_sms % 2 == 0); + int num_channels = config.num_sms / 2; + + auto num_tokens = static_cast(x.shape(0)), hidden = static_cast(x.shape(1)); + auto num_recv_tokens = static_cast(send_head.shape(0)); + EP_HOST_ASSERT(src_idx.shape(0) == num_tokens); + EP_HOST_ASSERT(send_head.shape(1) == num_ranks); + EP_HOST_ASSERT(rank_prefix_matrix.shape(0) == num_ranks and rank_prefix_matrix.shape(1) == num_ranks); + EP_HOST_ASSERT(channel_prefix_matrix.shape(0) == num_ranks and channel_prefix_matrix.shape(1) == num_channels); + EP_HOST_ASSERT((hidden * byte_size(x.dtype())) % sizeof(int4) == 0); + + int num_topk = 0; + auto recv_topk_weights = std::optional(); + const float* topk_weights_ptr = nullptr; + float* recv_topk_weights_ptr = nullptr; + if (topk_weights.has_value()) { + EP_HOST_ASSERT(topk_weights->ndim() == 2 and topk_weights->is_contiguous()); + EP_HOST_ASSERT(topk_weights->shape(0) == num_tokens); + EP_HOST_ASSERT(topk_weights->dtype() == turbomind::kFloat32); + num_topk = static_cast(topk_weights->shape(1)); + topk_weights_ptr = topk_weights->data_or((float*)nullptr); + recv_topk_weights = Tensor({num_recv_tokens, num_topk}, turbomind::kFloat32, turbomind::kDEVICE); + recv_topk_weights_ptr = recv_topk_weights->data_or((float*)nullptr); + } + + // Launch barrier and reset queue head and tail + EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 <= num_nvl_bytes); + intranode::cached_notify_combine(buffer_ptrs_gpu, + send_head.data_or((int*)nullptr), + num_channels, + num_recv_tokens, + num_channels * num_ranks * 2, + barrier_signal_ptrs_gpu, + rank, + num_ranks, + turbomind::core::Context::stream().handle()); + + // Assign bias pointers + auto bias_opts = std::vector>({bias_0, bias_1}); + void* bias_ptrs[2] = {nullptr, nullptr}; + for (int i = 0; i < 2; ++i) + if (bias_opts[i].has_value()) { + auto bias = bias_opts[i].value(); + EP_HOST_ASSERT(bias.ndim() == 2 and bias.is_contiguous()); + EP_HOST_ASSERT(bias.dtype() == x.dtype()); + EP_HOST_ASSERT(bias.shape(0) == num_recv_tokens and bias.shape(1) == hidden); + bias_ptrs[i] = bias.data_or((void*)nullptr); + } + + // Combine data + auto recv_x = Tensor({num_recv_tokens, hidden}, x.dtype(), turbomind::kDEVICE); + EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * byte_size(x.dtype()) + + // Data buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) + + // Source index buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk + * sizeof(float) // Top-k weight buffer + <= num_nvl_bytes); + intranode::combine(CUDA_R_16BF, + recv_x.data_or((void*)nullptr), + recv_topk_weights_ptr, + x.data_or((void*)nullptr), + topk_weights_ptr, + bias_ptrs[0], + bias_ptrs[1], + src_idx.data_or((int*)nullptr), + rank_prefix_matrix.data(), + channel_prefix_matrix.data(), + send_head.data_or((int*)nullptr), + num_tokens, + num_recv_tokens, + hidden, + num_topk, + 
buffer_ptrs_gpu, + rank, + num_ranks, + turbomind::core::Context::stream().handle(), + config.num_sms, + config.num_max_nvl_chunked_send_tokens, + config.num_max_nvl_chunked_recv_tokens); + + return {recv_x, std::nullopt}; +} + +std::tuple, + Tensor, + Tensor, + Tensor> +Buffer::low_latency_dispatch(const Tensor& x, + const Tensor& topk_idx, + const std::optional& cumulative_local_expert_recv_stats, + const std::optional& dispatch_wait_recv_cost_stats, + int num_max_dispatch_tokens_per_rank, + int num_experts, + bool use_fp8, + bool round_scale, + bool use_ue8m0) +{ + // Tensor checks + // By default using `ptp128c` FP8 cast + EP_HOST_ASSERT(x.ndim() == 2 and x.is_contiguous() and x.dtype() == turbomind::kBfloat16); + EP_HOST_ASSERT(x.shape(1) % sizeof(int4) == 0 and x.shape(1) % 128 == 0); + EP_HOST_ASSERT(topk_idx.ndim() == 2 and topk_idx.is_contiguous()); + EP_HOST_ASSERT(x.shape(0) == topk_idx.shape(0) and x.shape(0) <= num_max_dispatch_tokens_per_rank); + EP_HOST_ASSERT(topk_idx.dtype() == turbomind::kInt64); + EP_HOST_ASSERT(num_experts % num_ranks == 0); + + // Diagnosis tensors + EP_HOST_ASSERT(not cumulative_local_expert_recv_stats.has_value()); + EP_HOST_ASSERT(not dispatch_wait_recv_cost_stats.has_value()); + // if (cumulative_local_expert_recv_stats.has_value()) { + // EP_HOST_ASSERT(cumulative_local_expert_recv_stats->dtype() == turbomind::kInt32); + // EP_HOST_ASSERT(cumulative_local_expert_recv_stats->ndim() == 1 + // and cumulative_local_expert_recv_stats->is_contiguous()); + // EP_HOST_ASSERT(cumulative_local_expert_recv_stats->shape(0) == num_experts / num_ranks); + // } + // if (dispatch_wait_recv_cost_stats.has_value()) { + // EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->dtype() == turbomind::kInt64); + // EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->ndim() == 1 and + // dispatch_wait_recv_cost_stats->is_contiguous()); EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->shape(0) == + // num_ranks); + // } + + auto num_tokens = static_cast(x.shape(0)); + auto hidden = static_cast(x.shape(1)); + auto num_topk = static_cast(topk_idx.shape(1)); + auto num_local_experts = num_experts / num_ranks; + + // Buffer control + LowLatencyLayout layout(rdma_ll_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); + EP_HOST_ASSERT(layout.total_bytes <= num_ll_rdma_bytes); + auto buffer = layout.buffers[low_latency_buffer_idx]; + auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; + + // Allocate packed tensors + auto packed_recv_x = Tensor( + {num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, x.dtype(), turbomind::kDEVICE); + auto packed_recv_src_info = Tensor( + {num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, turbomind::kInt32, turbomind::kDEVICE); + auto packed_recv_layout_range = Tensor({num_local_experts, num_ranks}, turbomind::kInt64, turbomind::kDEVICE); + auto packed_recv_count = Tensor({num_local_experts}, turbomind::kInt32, turbomind::kDEVICE); + + // Allocate column-majored scales + auto packed_recv_x_scales = std::optional(); + void* packed_recv_x_scales_ptr = nullptr; + EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 + and "TMA requires the number of tokens to be multiple of 4"); + + if (use_fp8) { + // TODO: support unaligned cases + EP_HOST_ASSERT(hidden % 512 == 0); + if (not use_ue8m0) { + packed_recv_x_scales = + Tensor({num_local_experts, hidden / 128, num_ranks * num_max_dispatch_tokens_per_rank}, + turbomind::kFloat32, + turbomind::kDEVICE); + } + else { + 
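+            // UE8M0 path: scales are stored as packed 8-bit exponents (four per
+            // `int32`), so the scale dimension shrinks from `hidden / 128` floats to
+            // `hidden / 512` ints and requires pre-rounded scales (`round_scale`).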
EP_HOST_ASSERT(round_scale); + packed_recv_x_scales = + Tensor({num_local_experts, hidden / 512, num_ranks * num_max_dispatch_tokens_per_rank}, + turbomind::kInt32, + turbomind::kDEVICE); + } + packed_recv_x_scales = packed_recv_x_scales->transpose(1, 2); + packed_recv_x_scales_ptr = packed_recv_x_scales->data_or((float*)nullptr); + } + + // Kernel launch + auto next_clean_meta = next_buffer.clean_meta(); + const int phases = LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE; + auto dev_comm = comm->get_device_communicator(true); + auto nccl_win = comm->get_device_nccl_window(rdma_ll_buffer_ptr); + auto signals_base = comm->get_signals_base(low_latency_buffer_idx, true); + + internode_ll::dispatch( + packed_recv_x.raw_data(), + packed_recv_x_scales_ptr, + packed_recv_src_info.data(), + packed_recv_layout_range.data(), + packed_recv_count.data(), + mask_buffer_ptr, + nullptr, + nullptr, + buffer.dispatch_rdma_recv_data_buffer, + buffer.dispatch_rdma_recv_count_buffer, + buffer.dispatch_rdma_send_buffer, + reinterpret_cast(buffer.dispatch_rdma_recv_data_buffer) - reinterpret_cast(rdma_ll_buffer_ptr), + reinterpret_cast(buffer.dispatch_rdma_recv_count_buffer) - reinterpret_cast(rdma_ll_buffer_ptr), + reinterpret_cast(buffer.dispatch_rdma_send_buffer) - reinterpret_cast(rdma_ll_buffer_ptr), + x.raw_data(), + topk_idx.data(), + next_clean_meta.first, + next_clean_meta.second, + num_tokens, + hidden, + num_max_dispatch_tokens_per_rank, + num_topk, + num_experts, + rank, + num_ranks, + use_fp8, + round_scale, + use_ue8m0, + workspace, + num_device_sms, + nccl_win, + dev_comm, + signals_base, + turbomind::core::Context::stream().handle(), + phases); + + return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range}; +} + +std::tuple // +Buffer::low_latency_combine(const Tensor& x, + const Tensor& topk_idx, + const Tensor& topk_weights, + const Tensor& src_info, + const Tensor& layout_range, + const std::optional& combine_wait_recv_cost_stats, + int num_max_dispatch_tokens_per_rank, + int num_experts, + bool use_logfmt, + bool zero_copy, + const std::optional& out) +{ + EP_HOST_ASSERT(low_latency_mode); + + // Tensor checks + EP_HOST_ASSERT(x.ndim() == 3 and x.is_contiguous() and x.dtype() == turbomind::kBfloat16); + EP_HOST_ASSERT(x.shape(0) == num_experts / num_ranks); + EP_HOST_ASSERT(x.shape(1) == num_ranks * num_max_dispatch_tokens_per_rank); + EP_HOST_ASSERT(x.shape(2) % sizeof(int4) == 0 and x.shape(2) % 128 == 0); + EP_HOST_ASSERT(topk_idx.ndim() == 2 and topk_idx.is_contiguous()); + EP_HOST_ASSERT(topk_idx.shape(0) == topk_weights.shape(0) and topk_idx.shape(1) == topk_weights.shape(1)); + EP_HOST_ASSERT(topk_idx.dtype() == turbomind::kInt64); + EP_HOST_ASSERT(topk_weights.ndim() == 2 and topk_weights.is_contiguous()); + EP_HOST_ASSERT(topk_weights.shape(0) <= num_max_dispatch_tokens_per_rank); + EP_HOST_ASSERT(topk_weights.dtype() == turbomind::kFloat32); + EP_HOST_ASSERT(src_info.ndim() == 2 and src_info.is_contiguous()); + EP_HOST_ASSERT(src_info.dtype() == turbomind::kInt32 and x.shape(0) == src_info.shape(0)); + EP_HOST_ASSERT(layout_range.ndim() == 2 and layout_range.is_contiguous()); + EP_HOST_ASSERT(layout_range.dtype() == turbomind::kInt64); + EP_HOST_ASSERT(layout_range.shape(0) == num_experts / num_ranks and layout_range.shape(1) == num_ranks); + + EP_HOST_ASSERT(not combine_wait_recv_cost_stats.has_value()); + // if (combine_wait_recv_cost_stats.has_value()) { + // EP_HOST_ASSERT(combine_wait_recv_cost_stats->dtype() == 
turbomind::kInt64); + // EP_HOST_ASSERT(combine_wait_recv_cost_stats->ndim() == 1 and combine_wait_recv_cost_stats->is_contiguous()); + // EP_HOST_ASSERT(combine_wait_recv_cost_stats->shape(0) == num_ranks); + // } + + auto hidden = static_cast(x.shape(2)); + auto num_topk = static_cast(topk_weights.shape(1)); + auto num_combined_tokens = static_cast(topk_weights.shape(0)); + + // Buffer control + LowLatencyLayout layout(rdma_ll_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); + EP_HOST_ASSERT(layout.total_bytes <= num_ll_rdma_bytes); + auto buffer = layout.buffers[low_latency_buffer_idx]; + auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; + + // Allocate output tensor + Tensor combined_x; + if (out.has_value()) { + EP_HOST_ASSERT(out->ndim() == 2 and out->is_contiguous()); + EP_HOST_ASSERT(out->shape(0) == num_combined_tokens and out->shape(1) == hidden); + EP_HOST_ASSERT(out->dtype() == x.dtype()); + combined_x = out.value(); + } + else { + combined_x = Tensor({num_combined_tokens, hidden}, x.dtype(), turbomind::kDEVICE); + } + + // Kernel launch + auto next_clean_meta = next_buffer.clean_meta(); + const int phases = LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE; + auto dev_comm = comm->get_device_communicator(true); + auto nccl_win = comm->get_device_nccl_window(rdma_ll_buffer_ptr); + auto signals_base = comm->get_signals_base(low_latency_buffer_idx, true); + + internode_ll::combine( + combined_x.data_or((void*)nullptr), + buffer.combine_rdma_recv_data_buffer, + buffer.combine_rdma_recv_flag_buffer, + buffer.combine_rdma_send_buffer, + reinterpret_cast(buffer.combine_rdma_recv_data_buffer) - reinterpret_cast(rdma_ll_buffer_ptr), + reinterpret_cast(buffer.combine_rdma_recv_flag_buffer) - reinterpret_cast(rdma_ll_buffer_ptr), + reinterpret_cast(buffer.combine_rdma_send_buffer) - reinterpret_cast(rdma_ll_buffer_ptr), + x.raw_data(), + topk_idx.data_or((topk_idx_t*)nullptr), + topk_weights.data_or((float*)nullptr), + src_info.data(), + layout_range.data(), + mask_buffer_ptr, + nullptr, + next_clean_meta.first, + next_clean_meta.second, + num_combined_tokens, + hidden, + num_max_dispatch_tokens_per_rank, + num_topk, + num_experts, + rank, + num_ranks, + use_logfmt, + workspace, + num_device_sms, + nccl_win, + dev_comm, + signals_base, + turbomind::core::Context::stream().handle(), + phases, + zero_copy); + + return {combined_x}; +} + +std::tuple, + std::optional, + std::optional, + std::vector, + Tensor, + Tensor, + Tensor, + std::optional, + Tensor, + std::optional, + Tensor, + std::optional, + std::optional, + std::optional> +Buffer::internode_dispatch(const Tensor& x, + const std::optional& x_scales, + const std::optional& topk_idx, + const std::optional& topk_weights, + const std::optional& num_tokens_per_rank, + const std::optional& num_tokens_per_rdma_rank, + const Tensor& is_token_in_rank, + const std::optional& num_tokens_per_expert, + int cached_num_recv_tokens, + int cached_num_rdma_recv_tokens, + const std::optional& cached_rdma_channel_prefix_matrix, + const std::optional& cached_recv_rdma_rank_prefix_sum, + const std::optional& cached_gbl_channel_prefix_matrix, + const std::optional& cached_recv_gbl_rank_prefix_sum, + int expert_alignment, + int num_worst_tokens, + const Config& config) +{ + + const int num_channels = config.num_sms / 2; + EP_HOST_ASSERT(config.num_sms % 2 == 0); + EP_HOST_ASSERT(0 < get_num_rdma_ranks() and get_num_rdma_ranks() <= NUM_MAX_RDMA_PEERS); + + bool cached_mode = 
cached_rdma_channel_prefix_matrix.has_value(); + if (cached_mode) { + EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix.has_value()); + EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum.has_value()); + EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix.has_value()); + EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum.has_value()); + } + else { + EP_HOST_ASSERT(num_tokens_per_rank.has_value()); + EP_HOST_ASSERT(num_tokens_per_rdma_rank.has_value()); + EP_HOST_ASSERT(num_tokens_per_expert.has_value()); + } + + // Type checks + if (cached_mode) { + EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->dtype() == turbomind::kInt32); + EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->dtype() == turbomind::kInt32); + EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->dtype() == turbomind::kInt32); + EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->dtype() == turbomind::kInt32); + } + else { + EP_HOST_ASSERT(num_tokens_per_rank->dtype() == turbomind::kInt32); + EP_HOST_ASSERT(num_tokens_per_rdma_rank->dtype() == turbomind::kInt32); + EP_HOST_ASSERT(num_tokens_per_expert->dtype() == turbomind::kInt32); + } + + // Shape and contiguous checks + EP_HOST_ASSERT(x.ndim() == 2 and x.is_contiguous()); + EP_HOST_ASSERT((x.shape(1) * byte_size(x.dtype())) % sizeof(int4) == 0); + if (cached_mode) { + EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->ndim() == 2 + and cached_rdma_channel_prefix_matrix->is_contiguous()); + EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->shape(0) == num_rdma_ranks + and cached_rdma_channel_prefix_matrix->shape(1) == num_channels); + EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->ndim() == 1 + and cached_recv_rdma_rank_prefix_sum->is_contiguous()); + EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->shape(0) == num_rdma_ranks); + EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->ndim() == 2 + and cached_gbl_channel_prefix_matrix->is_contiguous()); + EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->shape(0) == num_ranks + and cached_gbl_channel_prefix_matrix->shape(1) == num_channels); + EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->ndim() == 1 + and cached_recv_gbl_rank_prefix_sum->is_contiguous()); + EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->shape(0) == num_ranks); + } + else { + EP_HOST_ASSERT(num_tokens_per_rank->ndim() == 1 and num_tokens_per_rank->is_contiguous()); + EP_HOST_ASSERT(num_tokens_per_rdma_rank->ndim() == 1 and num_tokens_per_rdma_rank->is_contiguous()); + EP_HOST_ASSERT(num_tokens_per_expert->ndim() == 1 and num_tokens_per_expert->is_contiguous()); + EP_HOST_ASSERT(num_tokens_per_rank->shape(0) == num_ranks); + EP_HOST_ASSERT(num_tokens_per_rdma_rank->shape(0) == num_rdma_ranks); + EP_HOST_ASSERT(num_tokens_per_expert->shape(0) % num_ranks == 0); + EP_HOST_ASSERT(num_tokens_per_expert->shape(0) / num_ranks <= NUM_MAX_LOCAL_EXPERTS); + } + + auto num_tokens = static_cast(x.shape(0)); + auto hidden = static_cast(x.shape(1)); + auto hidden_int4 = static_cast(x.shape(1) * byte_size(x.dtype()) / sizeof(int4)); + auto num_experts = cached_mode ? 
0 : static_cast(num_tokens_per_expert->shape(0)); + auto num_local_experts = num_experts / num_ranks; + + // Top-k checks + int num_topk = 0; + const topk_idx_t* topk_idx_ptr = nullptr; + const float* topk_weights_ptr = nullptr; + EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value()); + if (topk_idx.has_value()) { + num_topk = static_cast(topk_idx->shape(1)); + EP_HOST_ASSERT(num_experts > 0); + EP_HOST_ASSERT(topk_idx->ndim() == 2 and topk_idx->is_contiguous()); + EP_HOST_ASSERT(topk_weights->ndim() == 2 and topk_weights->is_contiguous()); + EP_HOST_ASSERT(num_tokens == topk_idx->shape(0) and num_tokens == topk_weights->shape(0)); + EP_HOST_ASSERT(num_topk == topk_weights->shape(1)); + EP_HOST_ASSERT(topk_weights->dtype() == turbomind::kFloat32); + topk_idx_ptr = topk_idx->data_or((topk_idx_t*)nullptr); + topk_weights_ptr = topk_weights->data_or((float*)nullptr); + } + + // FP8 scales checks + const float* x_scales_ptr = nullptr; + int num_scales = 0; + int scale_token_stride = 0; + int scale_hidden_stride = 0; + if (x_scales.has_value()) { + EP_HOST_ASSERT(byte_size(x.dtype()) == 1); + EP_HOST_ASSERT(x_scales->dtype() == turbomind::kFloat32 or x_scales->dtype() == turbomind::kInt32); + EP_HOST_ASSERT(x_scales->ndim() == 2); + EP_HOST_ASSERT(x_scales->shape(0) == num_tokens); + num_scales = x_scales->ndim() == 1 ? 1 : static_cast(x_scales->shape(1)); + x_scales_ptr = x_scales->data_or((float*)nullptr); + scale_token_stride = static_cast(x_scales->stride(0)); + scale_hidden_stride = static_cast(x_scales->stride(1)); + } + + // Create handles (only return for non-cached mode) + int num_recv_tokens = -1, num_rdma_recv_tokens = -1; + auto rdma_channel_prefix_matrix = Tensor(); + auto recv_rdma_rank_prefix_sum = Tensor(); + auto gbl_channel_prefix_matrix = Tensor(); + auto recv_gbl_rank_prefix_sum = Tensor(); + std::vector num_recv_tokens_per_expert_list; + + // used to compute offsets in MoeFfnLayer + auto moe_recv_expert_counter_ten = Tensor({num_local_experts}, turbomind::kInt32, turbomind::kDEVICE); + + auto dev_comm = comm->get_device_communicator(false); + auto nccl_win = comm->get_device_nccl_window(rdma_buffer_ptr); + auto signals_base = comm->get_signals_base(0, false); + auto gin_base_ptr = rdma_buffer_ptr; + + if (cached_mode) { + EP_HOST_ASSERT(not cached_mode); + } + else { + rdma_channel_prefix_matrix = Tensor({num_rdma_ranks, num_channels}, turbomind::kInt32, turbomind::kDEVICE); + recv_rdma_rank_prefix_sum = Tensor({num_rdma_ranks}, turbomind::kInt32, turbomind::kDEVICE); + gbl_channel_prefix_matrix = Tensor({num_ranks, num_channels}, turbomind::kInt32, turbomind::kDEVICE); + recv_gbl_rank_prefix_sum = Tensor({num_ranks}, turbomind::kInt32, turbomind::kDEVICE); + + // Send sizes + *moe_recv_counter = -1, *moe_recv_rdma_counter = -1; + for (int i = 0; i < num_local_experts; ++i) + moe_recv_expert_counter[i] = -1; + internode::notify_dispatch(num_tokens_per_rank->data(), + moe_recv_counter_mapped, + num_ranks, + num_tokens_per_rdma_rank->data(), + moe_recv_rdma_counter_mapped, + num_tokens_per_expert->data(), + moe_recv_expert_counter_mapped, + moe_recv_expert_counter_ten.data(), + num_experts, + is_token_in_rank.data_or((bool*)nullptr), + num_tokens, + num_worst_tokens, + num_channels, + hidden_int4, + num_scales, + num_topk, + expert_alignment, + rdma_channel_prefix_matrix.data(), + recv_rdma_rank_prefix_sum.data(), + gbl_channel_prefix_matrix.data(), + recv_gbl_rank_prefix_sum.data(), + rdma_buffer_ptr, + config.num_max_rdma_chunked_recv_tokens, + buffer_ptrs_gpu, + 
config.num_max_nvl_chunked_recv_tokens, + barrier_signal_ptrs_gpu, + rank, + turbomind::core::Context::stream().handle(), + config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), + num_nvl_bytes, + low_latency_mode, + gin_base_ptr, + dev_comm, + nccl_win, + signals_base); + + // Synchronize total received tokens and tokens per expert + if (num_worst_tokens > 0) { + num_recv_tokens = num_worst_tokens; + num_rdma_recv_tokens = num_worst_tokens; + } + else { + auto start_time = std::chrono::high_resolution_clock::now(); + while (true) { + // Read total count + num_recv_tokens = static_cast(*moe_recv_counter); + num_rdma_recv_tokens = static_cast(*moe_recv_rdma_counter); + + // Read per-expert count + bool ready = (num_recv_tokens >= 0) and (num_rdma_recv_tokens >= 0); + for (int i = 0; i < num_local_experts and ready; ++i) + ready &= moe_recv_expert_counter[i] >= 0; + + if (ready) + break; + + // Timeout check + if (std::chrono::duration_cast(std::chrono::high_resolution_clock::now() + - start_time) + .count() + > NUM_CPU_TIMEOUT_SECS) { + printf("Global rank: %d, num_recv_tokens: %d, num_rdma_recv_tokens: %d\n", + rank, + num_recv_tokens, + num_rdma_recv_tokens); + for (int i = 0; i < num_local_experts; ++i) + printf("moe_recv_expert_counter[%d]: %d\n", i, moe_recv_expert_counter[i]); + throw std::runtime_error("DeepEP error: timeout (dispatch CPU)"); + } + } + num_recv_tokens_per_expert_list = + std::vector(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts); + } + } + + // Allocate new tensors + auto recv_x = Tensor({num_recv_tokens, hidden}, x.dtype(), turbomind::kDEVICE); + auto recv_topk_idx = std::optional(); + auto recv_topk_weights = std::optional(); + auto recv_x_scales = std::optional(); + auto recv_src_meta = std::optional(); + auto recv_rdma_channel_prefix_matrix = std::optional(); + auto recv_gbl_channel_prefix_matrix = std::optional(); + auto send_rdma_head = std::optional(); + auto send_nvl_head = std::optional(); + if (not cached_mode) { + recv_src_meta = + Tensor({num_recv_tokens, internode::get_source_meta_bytes()}, turbomind::kUint8, turbomind::kDEVICE); + recv_rdma_channel_prefix_matrix = Tensor({num_rdma_ranks, num_channels}, turbomind::kInt32, turbomind::kDEVICE); + recv_gbl_channel_prefix_matrix = Tensor({num_ranks, num_channels}, turbomind::kInt32, turbomind::kDEVICE); + send_rdma_head = Tensor({num_tokens, num_rdma_ranks}, turbomind::kInt32, turbomind::kDEVICE); + send_nvl_head = Tensor({num_rdma_recv_tokens, NUM_MAX_NVL_PEERS}, turbomind::kInt32, turbomind::kDEVICE); + } + + // Assign pointers + topk_idx_t* recv_topk_idx_ptr = nullptr; + float* recv_topk_weights_ptr = nullptr; + float* recv_x_scales_ptr = nullptr; + if (topk_idx.has_value()) { + recv_topk_idx = Tensor({num_recv_tokens, num_topk}, topk_idx->dtype(), turbomind::kDEVICE); + recv_topk_weights = Tensor({num_recv_tokens, num_topk}, topk_weights->dtype(), turbomind::kDEVICE); + recv_topk_idx_ptr = recv_topk_idx->data_or((topk_idx_t*)nullptr); + recv_topk_weights_ptr = recv_topk_weights->data_or((float*)nullptr); + } + if (x_scales.has_value()) { + recv_x_scales = x_scales->ndim() == 1 ? 
+ Tensor({num_recv_tokens}, x_scales->dtype(), turbomind::kDEVICE) : + Tensor({num_recv_tokens, num_scales}, x_scales->dtype(), turbomind::kDEVICE); + recv_x_scales_ptr = recv_x_scales->data_or((float*)nullptr); + } + + // Launch data dispatch + // NOTES: the buffer size checks are moved into the `.cu` file + internode::dispatch(recv_x.data_or((void*)nullptr), + recv_x_scales_ptr, + recv_topk_idx_ptr, + recv_topk_weights_ptr, + cached_mode ? nullptr : recv_src_meta->data_or((void*)nullptr), + x.data_or((void*)nullptr), + x_scales_ptr, + topk_idx_ptr, + topk_weights_ptr, + cached_mode ? nullptr : send_rdma_head->data_or((int*)nullptr), + cached_mode ? nullptr : send_nvl_head->data_or((int*)nullptr), + cached_mode ? nullptr : recv_rdma_channel_prefix_matrix->data(), + cached_mode ? nullptr : recv_gbl_channel_prefix_matrix->data(), + rdma_channel_prefix_matrix.data(), + recv_rdma_rank_prefix_sum.data(), + gbl_channel_prefix_matrix.data(), + recv_gbl_rank_prefix_sum.data(), + is_token_in_rank.data_or((bool*)nullptr), + num_tokens, + num_worst_tokens, + hidden_int4, + num_scales, + num_topk, + num_experts, + scale_token_stride, + scale_hidden_stride, + rdma_buffer_ptr, + config.num_max_rdma_chunked_send_tokens, + config.num_max_rdma_chunked_recv_tokens, + buffer_ptrs_gpu, + config.num_max_nvl_chunked_send_tokens, + config.num_max_nvl_chunked_recv_tokens, + rank, + num_ranks, + cached_mode, + turbomind::core::Context::stream().handle(), + num_channels, + low_latency_mode, + gin_base_ptr, + dev_comm, + nccl_win, + signals_base); + + // Return values + return {recv_x, + recv_x_scales, + recv_topk_idx, + recv_topk_weights, + num_recv_tokens_per_expert_list, + moe_recv_expert_counter_ten, + rdma_channel_prefix_matrix, + gbl_channel_prefix_matrix, + recv_rdma_channel_prefix_matrix, + recv_rdma_rank_prefix_sum, + recv_gbl_channel_prefix_matrix, + recv_gbl_rank_prefix_sum, + recv_src_meta, + send_rdma_head, + send_nvl_head}; +} + +std::tuple> // +Buffer::internode_combine(const Tensor& x, + const std::optional& topk_weights, + const std::optional& bias_0, + const std::optional& bias_1, + const Tensor& src_meta, + const Tensor& is_combined_token_in_rank, + const Tensor& rdma_channel_prefix_matrix, + const Tensor& rdma_rank_prefix_sum, + const Tensor& gbl_channel_prefix_matrix, + Tensor& combined_rdma_head, + Tensor& combined_nvl_head, + const Config& config) +{ + const int num_channels = config.num_sms / 2; + EP_HOST_ASSERT(config.num_sms % 2 == 0); + + // Shape and contiguous checks + EP_HOST_ASSERT(x.ndim() == 2 and x.is_contiguous()); + EP_HOST_ASSERT(src_meta.ndim() == 2 and src_meta.is_contiguous() and src_meta.dtype() == turbomind::kUint8); + EP_HOST_ASSERT(is_combined_token_in_rank.ndim() == 2 and is_combined_token_in_rank.is_contiguous() + and is_combined_token_in_rank.dtype() == turbomind::kBool); + EP_HOST_ASSERT(rdma_channel_prefix_matrix.ndim() == 2 and rdma_channel_prefix_matrix.is_contiguous() + and rdma_channel_prefix_matrix.dtype() == turbomind::kInt32); + EP_HOST_ASSERT(rdma_rank_prefix_sum.ndim() == 1 and rdma_rank_prefix_sum.is_contiguous() + and rdma_rank_prefix_sum.dtype() == turbomind::kInt32); + EP_HOST_ASSERT(gbl_channel_prefix_matrix.ndim() == 2 and gbl_channel_prefix_matrix.is_contiguous() + and gbl_channel_prefix_matrix.dtype() == turbomind::kInt32); + EP_HOST_ASSERT(combined_rdma_head.ndim() == 2 and combined_rdma_head.is_contiguous() + and combined_rdma_head.dtype() == turbomind::kInt32); + EP_HOST_ASSERT(combined_nvl_head.ndim() == 2 and combined_nvl_head.is_contiguous() + 
and combined_nvl_head.dtype() == turbomind::kInt32); + + auto num_tokens = static_cast(x.shape(0)); + auto hidden = static_cast(x.shape(1)); + auto hidden_int4 = static_cast(x.shape(1) * byte_size(x.dtype()) / sizeof(int4)); + auto num_combined_tokens = static_cast(is_combined_token_in_rank.shape(0)); + EP_HOST_ASSERT((hidden * byte_size(x.dtype())) % sizeof(int4) == 0); + EP_HOST_ASSERT(src_meta.shape(1) == internode::get_source_meta_bytes()); + EP_HOST_ASSERT(is_combined_token_in_rank.shape(1) == num_ranks); + EP_HOST_ASSERT(rdma_channel_prefix_matrix.shape(0) == num_rdma_ranks + and rdma_channel_prefix_matrix.shape(1) == num_channels); + EP_HOST_ASSERT(rdma_rank_prefix_sum.shape(0) == num_rdma_ranks); + EP_HOST_ASSERT(gbl_channel_prefix_matrix.shape(0) == num_ranks + and gbl_channel_prefix_matrix.shape(1) == num_channels); + EP_HOST_ASSERT(combined_rdma_head.ndim() == 2 and combined_rdma_head.shape(0) == num_combined_tokens + and combined_rdma_head.shape(1) == num_rdma_ranks); + EP_HOST_ASSERT(combined_nvl_head.ndim() == 2 and combined_nvl_head.shape(1) == NUM_MAX_NVL_PEERS); + + // Top-k checks + int num_topk = 0; + auto combined_topk_weights = std::optional(); + const float* topk_weights_ptr = nullptr; + float* combined_topk_weights_ptr = nullptr; + if (topk_weights.has_value()) { + EP_HOST_ASSERT(topk_weights->ndim() == 2 and topk_weights->is_contiguous()); + EP_HOST_ASSERT(topk_weights->shape(0) == num_tokens); + EP_HOST_ASSERT(topk_weights->dtype() == turbomind::kFloat32); + num_topk = static_cast(topk_weights->shape(1)); + topk_weights_ptr = topk_weights->data_or((float*)nullptr); + combined_topk_weights = Tensor({num_combined_tokens, num_topk}, turbomind::kFloat32, turbomind::kDEVICE); + combined_topk_weights_ptr = combined_topk_weights->data_or((float*)nullptr); + } + + // Extra check for avoid-dead-lock design + EP_HOST_ASSERT(config.num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0); + EP_HOST_ASSERT(config.num_max_nvl_chunked_send_tokens <= config.num_max_nvl_chunked_recv_tokens / num_rdma_ranks); + + auto gin_base_ptr = rdma_buffer_ptr; + auto dev_comm = comm->get_device_communicator(false); + auto nccl_win = comm->get_device_nccl_window(rdma_buffer_ptr); + auto signals_base = comm->get_signals_base(0, false); + + // Launch barrier and reset queue head and tail + internode::cached_notify(hidden_int4, + 0, + 0, + num_topk, + num_ranks, + num_channels, + num_combined_tokens, + combined_rdma_head.data_or((int*)nullptr), + rdma_channel_prefix_matrix.data(), + rdma_rank_prefix_sum.data(), + combined_nvl_head.data_or((int*)nullptr), + rdma_buffer_ptr, + config.num_max_rdma_chunked_recv_tokens, + buffer_ptrs_gpu, + config.num_max_nvl_chunked_recv_tokens, + barrier_signal_ptrs_gpu, + rank, + turbomind::core::Context::stream().handle(), + config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), + num_nvl_bytes, + false, + low_latency_mode, + gin_base_ptr, + dev_comm, + nccl_win, + signals_base); + + // Assign bias pointers + auto bias_opts = std::vector>({bias_0, bias_1}); + void* bias_ptrs[2] = {nullptr, nullptr}; + for (int i = 0; i < 2; ++i) + if (bias_opts[i].has_value()) { + auto bias = bias_opts[i].value(); + EP_HOST_ASSERT(bias.ndim() == 2 and bias.is_contiguous()); + EP_HOST_ASSERT(bias.dtype() == x.dtype()); + EP_HOST_ASSERT(bias.shape(0) == num_combined_tokens and bias.shape(1) == hidden); + bias_ptrs[i] = bias.data_or((void*)nullptr); + } + + // Launch data combine + auto combined_x = Tensor({num_combined_tokens, hidden}, x.dtype(), turbomind::kDEVICE); 
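+ // [Editorial sketch, illustrative only -- not part of the upstream kernel contract; names
+ // below simply mirror the surrounding arguments.] Conceptually, the `internode::combine`
+ // launch that follows performs a weighted gather-reduce of previously dispatched tokens
+ // back to their source ranks, roughly:
+ //   for each combined token t:
+ //     combined_x[t]  = sum of the expert outputs received for t from every rank flagged in
+ //                      is_combined_token_in_rank[t], plus bias_0[t] / bias_1[t] when given;
+ //     combined_topk_weights[t] = the top-k weights for t reduced from the contributing
+ //                      ranks (only when topk_weights is provided).
+ // Routing relies on `combined_rdma_head` / `combined_nvl_head` and the prefix matrices
+ // produced by the matching dispatch.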
+ internode::combine(CUDA_R_16BF, + combined_x.data_or((void*)nullptr), + combined_topk_weights_ptr, + is_combined_token_in_rank.data_or((bool*)nullptr), + x.data_or((void*)nullptr), + topk_weights_ptr, + bias_ptrs[0], + bias_ptrs[1], + combined_rdma_head.data_or((int*)nullptr), + combined_nvl_head.data_or((int*)nullptr), + src_meta.data_or((void*)nullptr), + rdma_channel_prefix_matrix.data(), + rdma_rank_prefix_sum.data(), + gbl_channel_prefix_matrix.data(), + num_tokens, + num_combined_tokens, + hidden, + num_topk, + rdma_buffer_ptr, + config.num_max_rdma_chunked_send_tokens, + config.num_max_rdma_chunked_recv_tokens, + buffer_ptrs_gpu, + config.num_max_nvl_chunked_send_tokens, + config.num_max_nvl_chunked_recv_tokens, + rank, + num_ranks, + turbomind::core::Context::stream().handle(), + num_channels, + low_latency_mode, + gin_base_ptr, + dev_comm, + nccl_win, + signals_base); + + return {combined_x, combined_topk_weights}; +} + +Config Buffer::get_dispatch_config() +{ + static std::unordered_map config_map = { + {2, Config(num_sms, 24, 256, 6, 128)}, + {4, Config(num_sms, 6, 256, 6, 128)}, + {8, Config(num_sms, 6, 256, 6, 128)}, + {16, Config(num_sms, 36, 288, 20, 128)}, + {24, Config(num_sms, 32, 288, 8, 128)}, + {32, Config(num_sms, 32, 288, 8, 128)}, + {48, Config(num_sms, 32, 288, 8, 128)}, + {64, Config(num_sms, 32, 288, 8, 128)}, + {96, Config(num_sms, 20, 480, 12, 128)}, + {128, Config(num_sms, 20, 560, 12, 128)}, + {144, Config(num_sms, 32, 720, 12, 128)}, + {160, Config(num_sms, 28, 720, 12, 128)}, + }; + const auto it = config_map.find(num_ranks); + TM_CHECK(it != config_map.end()); + return it->second; +} + +Config Buffer::get_combine_config() +{ + static std::unordered_map config_map = { + {2, Config(num_sms, 10, 256, 6, 128)}, + {4, Config(num_sms, 9, 256, 6, 128)}, + {8, Config(num_sms, 4, 256, 6, 128)}, + {16, Config(num_sms, 4, 288, 12, 128)}, + {24, Config(num_sms, 1, 288, 8, 128)}, + {32, Config(num_sms, 1, 288, 8, 128)}, + {48, Config(num_sms, 1, 288, 8, 128)}, + {64, Config(num_sms, 1, 288, 8, 128)}, + {96, Config(num_sms, 1, 480, 8, 128)}, + {128, Config(num_sms, 1, 560, 8, 128)}, + {144, Config(num_sms, 2, 720, 8, 128)}, + {160, Config(num_sms, 2, 720, 8, 128)}, + }; + const auto it = config_map.find(num_ranks); + TM_CHECK(it != config_map.end()); + return it->second; +} + +}; // namespace deep_ep diff --git a/3rdparty/deep_ep/deep_ep.hpp b/3rdparty/deep_ep/deep_ep.hpp new file mode 100644 index 0000000000..8f3a221b1b --- /dev/null +++ b/3rdparty/deep_ep/deep_ep.hpp @@ -0,0 +1,272 @@ +#pragma once + +#include "config.hpp" + +#include "gin_backend.h" +#include "kernels/configs.cuh" +#include "kernels/exception.cuh" +#include "src/turbomind/comm/device_comm.h" +#include "src/turbomind/comm/host_comm.h" +#include "src/turbomind/core/tensor.h" + +#include + +#include +#include +#include + +using turbomind::comm::HostComm; +using turbomind::comm::DeviceComm; +using turbomind::core::Tensor; +using turbomind::core::Tensor_; +using turbomind::core::Buffer_; + +namespace shared_memory { + +union MemHandleInner { + cudaIpcMemHandle_t cuda_ipc_mem_handle; + CUmemFabricHandle cu_mem_fabric_handle; +}; + +struct MemHandle { + MemHandleInner inner; + size_t size; +}; + +constexpr size_t HANDLE_SIZE = sizeof(MemHandle); + +class SharedMemoryAllocator { +public: + SharedMemoryAllocator(bool use_fabric); + void malloc(void** ptr, size_t size); + void free(void* ptr); + void get_mem_handle(MemHandle* mem_handle, void* ptr); + void open_mem_handle(void** ptr, MemHandle* mem_handle); 
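+    // [Editorial sketch, illustrative only -- a hypothetical owner/peer exchange built from the
+    // interface declared in this class; the allocator does not prescribe how the `MemHandle`
+    // bytes are transported between processes, and `bytes` is a placeholder size.]
+    //   shared_memory::SharedMemoryAllocator alloc(/*use_fabric=*/false);
+    //   void* ptr = nullptr;
+    //   alloc.malloc(&ptr, bytes);                        // owner allocates
+    //   shared_memory::MemHandle h{};
+    //   alloc.get_mem_handle(&h, ptr);                    // owner exports HANDLE_SIZE bytes
+    //   /* send `h` to the peer process, e.g. over the host communicator */
+    //   void* peer_ptr = nullptr;
+    //   alloc.open_mem_handle(&peer_ptr, &h);             // peer maps the same allocation
+    //   /* ... use peer_ptr ... */
+    //   alloc.close_mem_handle(peer_ptr);                 // peer unmaps
+    //   alloc.free(ptr);                                  // owner releases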
+ void close_mem_handle(void* ptr); + +private: + bool use_fabric; +}; +} // namespace shared_memory + +namespace deep_ep { + +class Buffer { + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "The number of maximum NVLink peers must be 8"); + +public: + HostComm h_comm; + DeviceComm ipc_comm; + int num_sms{24}; + + std::shared_ptr comm; + + // Low-latency mode buffer + int low_latency_buffer_idx = 0; + bool low_latency_mode = false; + + // NVLink Buffer + int64_t num_nvl_bytes; + void* buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; + void** buffer_ptrs_gpu = nullptr; + + // NVSHMEM Buffer + int64_t num_rdma_bytes; + int64_t num_ll_rdma_bytes; + void* rdma_buffer_ptr = nullptr; + void* rdma_ll_buffer_ptr = nullptr; + + // Shrink mode buffer + bool enable_shrink = false; + int* mask_buffer_ptr = nullptr; + int* sync_buffer_ptr = nullptr; + + // Device info and communication + int device_id; + int num_device_sms; + int rank, rdma_rank, nvl_rank; + int num_ranks, num_rdma_ranks, num_nvl_ranks; + int qps_per_rank; + shared_memory::MemHandle ipc_handles[NUM_MAX_NVL_PEERS]; + + // After IPC/NVSHMEM synchronization, this flag will be true + bool available = false; + + // After `destroy()` be called, this flag will be true + bool destroyed = false; + + // Barrier signals + int* barrier_signal_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; + int** barrier_signal_ptrs_gpu = nullptr; + + // Workspace + void* workspace = nullptr; + + // Host-side MoE info + volatile int* moe_recv_counter = nullptr; + int* moe_recv_counter_mapped = nullptr; + + // Host-side expert-level MoE info + volatile int* moe_recv_expert_counter = nullptr; + int* moe_recv_expert_counter_mapped = nullptr; + + // Host-side RDMA-level MoE info + volatile int* moe_recv_rdma_counter = nullptr; + int* moe_recv_rdma_counter_mapped = nullptr; + + shared_memory::SharedMemoryAllocator shared_memory_allocator; + + Buffer(int rank, // + int num_ranks, + int64_t num_nvl_bytes, + int64_t num_rdma_bytes, + int64_t num_ll_rdma_bytes, + bool low_latency_mode, + bool enable_shrink, + bool use_fabric, + int qps_per_rank, + HostComm h_comm); + + Buffer(): shared_memory_allocator{false} {}; + + ~Buffer() = default; + + void allocate_sync_nvl_buffer(); + + void allocate_rdma_buffer(); + + bool is_available() const; + + bool is_internode_available() const; + + int get_num_rdma_ranks() const; + + int get_rdma_rank() const; + + int get_root_rdma_rank(bool global) const; + + int get_local_device_id() const; + + void destroy(); + + std::tuple, Tensor, Tensor> // + get_dispatch_layout(const Tensor& topk_idx, int num_experts); + + std::tuple, + std::optional, + std::optional, + std::vector, + Tensor, + Tensor, + Tensor, + Tensor, + Tensor, + Tensor> + intranode_dispatch(const Tensor& x, + const std::optional& x_scales, + const std::optional& topk_idx, + const std::optional& topk_weights, + const std::optional& num_tokens_per_rank, + const Tensor& is_token_in_rank, + const std::optional& num_tokens_per_expert, + int cached_num_recv_tokens, + const std::optional& cached_rank_prefix_matrix, + const std::optional& cached_channel_prefix_matrix, + int expert_alignment, + int num_worst_tokens, + const Config& config); + + std::tuple> + intranode_combine(const Tensor& x, + const std::optional& topk_weights, + const std::optional& bias_0, + const std::optional& bias_1, + const Tensor& src_idx, + const Tensor& rank_prefix_matrix, + const Tensor& channel_prefix_matrix, + Tensor& send_head, + const Config& config); + + std::tuple, + Tensor, + Tensor, + Tensor> + low_latency_dispatch(const 
Tensor& x, + const Tensor& topk_idx, + const std::optional& cumulative_local_expert_recv_stats, + const std::optional& dispatch_wait_recv_cost_stats, + int num_max_dispatch_tokens_per_rank, + int num_experts, + bool use_fp8, + bool round_scale, + bool use_ue8m0); + + std::tuple // + low_latency_combine(const Tensor& x, + const Tensor& topk_idx, + const Tensor& topk_weights, + const Tensor& src_info, + const Tensor& layout_range, + const std::optional& combine_wait_recv_cost_stats, + int num_max_dispatch_tokens_per_rank, + int num_experts, + bool use_logfmt, + bool zero_copy, + const std::optional& out = std::nullopt); + + std::tuple, + std::optional, + std::optional, + std::vector, + Tensor, + Tensor, + Tensor, + std::optional, + Tensor, + std::optional, + Tensor, + std::optional, + std::optional, + std::optional> + internode_dispatch(const Tensor& x, + const std::optional& x_scales, + const std::optional& topk_idx, + const std::optional& topk_weights, + const std::optional& num_tokens_per_rank, + const std::optional& num_tokens_per_rdma_rank, + const Tensor& is_token_in_rank, + const std::optional& num_tokens_per_expert, + int cached_num_recv_tokens, + int cached_num_rdma_recv_tokens, + const std::optional& cached_rdma_channel_prefix_matrix, + const std::optional& cached_recv_rdma_rank_prefix_sum, + const std::optional& cached_gbl_channel_prefix_matrix, + const std::optional& cached_recv_gbl_rank_prefix_sum, + int expert_alignment, + int num_worst_tokens, + const Config& config); + + std::tuple> // + internode_combine(const Tensor& x, + const std::optional& topk_weights, + const std::optional& bias_0, + const std::optional& bias_1, + const Tensor& src_meta, + const Tensor& is_combined_token_in_rank, + const Tensor& rdma_channel_prefix_matrix, + const Tensor& rdma_rank_prefix_sum, + const Tensor& gbl_channel_prefix_matrix, + Tensor& combined_rdma_head, + Tensor& combined_nvl_head, + const Config& config); + + Config get_dispatch_config(); + + Config get_combine_config(); +}; + +}; // namespace deep_ep diff --git a/3rdparty/deep_ep/gin_backend.cu b/3rdparty/deep_ep/gin_backend.cu new file mode 100644 index 0000000000..3ac3439043 --- /dev/null +++ b/3rdparty/deep_ep/gin_backend.cu @@ -0,0 +1,244 @@ +#include "gin_backend.h" + +#include "kernels/configs.cuh" +#include "kernels/exception.cuh" +#include "src/turbomind/core/check.h" +#include "src/turbomind/core/context.h" +#include "src/turbomind/utils/logger.h" + +#include + +namespace deep_ep { +namespace internode { + +NCCLGINBackend::~NCCLGINBackend() +{ + if (initialized_) { + finalize(); + } +} + +int NCCLGINBackend::init( + const std::vector& root_unique_id_val, int rank, int num_ranks, bool low_latency_mode, int qps_per_rank) +{ + if (initialized_) { + return rank_; + } + TM_CHECK_EQ(low_latency_mode, true); // compatible with low latency mode + + // Check if P2P/NVLink is disabled via environment variable + const char* nccl_disable_p2p = std::getenv("NCCL_P2P_DISABLE"); + p2p_disabled_ = (nccl_disable_p2p != nullptr && std::string(nccl_disable_p2p) == "1"); + + // Determine communication topology based on mode + const int gpus_per_server = NUM_MAX_NVL_PEERS; + int comm_rank; // Rank to use for NCCL initialization + int comm_nranks; // Number of ranks in communicator + int color = -1; // Symmetric group ID (only for high throughput mode) + int group_rank = -1; // Rank within symmetric group + + if (low_latency_mode) { + // LOW LATENCY MODE: Connect to all ranks + comm_rank = rank; + comm_nranks = num_ranks; + } + else { + // HIGH 
THROUGHPUT MODE: Connect only to symmetric RDMA ranks + color = rank % gpus_per_server; + group_rank = rank / gpus_per_server; + comm_nranks = (num_ranks + gpus_per_server - 1) / gpus_per_server; + comm_rank = group_rank; + } + + size_t single_id_size = sizeof(ncclUniqueId); + size_t expected_ids = gpus_per_server; + EP_HOST_ASSERT(root_unique_id_val.size() == expected_ids * single_id_size + && "Number of unique IDs doesn't match NUM_MAX_NVL_PEERS"); + + if (rank == 0) { + // Print NCCL version from the actually loaded library + int nccl_version; + NCCL_CHECK(ncclGetVersion(&nccl_version)); + TM_LOG_DEBUG("[NCCLEP] NCCL version: %d.%d.%d (loaded library)", + nccl_version / 10000, + (nccl_version % 10000) / 100, + nccl_version % 100); + } + + // All gpus form a group for low latency compatible, + // otherwise, gpus with the same index across different nodes form a group. + ncclUniqueId id; + const int id_offset = (low_latency_mode) ? 0 : color * single_id_size; + std::memcpy(&id, root_unique_id_val.data() + id_offset, single_id_size); + NCCL_CHECK(ncclCommInitRank(&nccl_comm_, comm_nranks, id, comm_rank)); + + // The assumption is that kDecoupled is false when initializing SymBuffers in internode.cu + // IMPORTANT: Use global num_ranks, not comm_nranks, because kernels use global topology + const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1); + int rdma_channel_head_signals = num_rdma_ranks * DEEP_EP_NCCL_MAX_NUM_CHANNELS; + int rdma_channel_tail_signals = num_rdma_ranks * DEEP_EP_NCCL_MAX_NUM_CHANNELS; + // + num_ht_signals_ = rdma_channel_head_signals + rdma_channel_tail_signals; + num_ll_signals_ = qps_per_rank * comm_nranks * 2; + + // Initialize Device Communicators + auto CreateDevComm = [&](ncclDevComm_t& comm, int signals) { + ncclDevCommRequirements reqs = NCCL_DEV_COMM_REQUIREMENTS_INITIALIZER; + reqs.barrierCount = MAX_BARRIER_SESSIONS; + reqs.ginSignalCount = signals + MAX_BARRIER_SESSIONS; + reqs.ginConnectionType = NCCL_GIN_CONNECTION_FULL; + reqs.ginContextCount = qps_per_rank; + NCCL_CHECK(ncclDevCommCreate(nccl_comm_, &reqs, &comm)); + }; + CreateDevComm(dev_ll_comm_, num_ll_signals_); // low latency mode + CreateDevComm(dev_ht_comm_, num_ht_signals_); // high throughput mode + + // Allocate barrier dummy variable + CUDA_CHECK(cudaMalloc(reinterpret_cast(&d_barrier_var_), sizeof(int))); + CUDA_CHECK(cudaMemset(d_barrier_var_, 0, sizeof(int))); + + // Store global rank and num_ranks (for external API) + rank_ = rank; + num_ranks_ = num_ranks; + + // Store communicator-specific ranks for internal use + comm_rank_ = comm_rank; + comm_nranks_ = comm_nranks; + + initialized_ = true; + TM_LOG_DEBUG( + "[NCCLEP] Initialized global rank %d/%d (comm rank %d/%d)", rank_, num_ranks_, comm_rank_, comm_nranks_); + + return rank_; +} + +void NCCLGINBackend::finalize() +{ + TM_LOG_DEBUG("[NCCLEP][%d] Finalizing", rank_); + if (!initialized_) { + return; + } + + // Destroy device communicators + auto DestroyDevComm = [&](ncclDevComm_t& comm, std::string_view key) { + ncclResult_t res = ncclDevCommDestroy(nccl_comm_, &comm); + if (res != ncclSuccess) { + TM_LOG_ERROR("[NCCLEP][%d] Failed to destroy device communication %s: %s", + rank_, + key.data(), + ncclGetErrorString(res)); + } + }; + DestroyDevComm(dev_ll_comm_, "low latency mode"); + DestroyDevComm(dev_ht_comm_, "high throughput mode"); + + for (auto& [ptr, win] : wins_) { + TM_LOG_WARNING("[NCCLEP][%d] Memory %p is not deregistered", rank_, ptr); + } + for (auto& [ptr, size] : buffers_) { + 
TM_LOG_WARNING("[NCCLEP][%d] Allocation (%p, %lu) is not freed", rank_, ptr, size); + } + + // Free barrier dummy variable + if (d_barrier_var_ != nullptr) { + cudaFree(d_barrier_var_); + d_barrier_var_ = nullptr; + } + // Destroy all communicators + ncclCommFinalize(nccl_comm_); + ncclCommDestroy(nccl_comm_); + + TM_LOG_DEBUG("[NCCLEP][%d] Destroyed NCCL communicator", rank_); + initialized_ = false; +} + +void NCCLGINBackend::barrier() +{ + TM_CHECK_EQ(initialized_, true); + TM_CHECK_NE(d_barrier_var_, nullptr); + + cudaStream_t stream = turbomind::core::Context::stream().handle(); + NCCL_CHECK(ncclGroupStart()); + NCCL_CHECK(ncclAllReduce(d_barrier_var_, d_barrier_var_, 1, ncclInt, ncclSum, nccl_comm_, stream)); + NCCL_CHECK(ncclGroupEnd()); +} + +void* NCCLGINBackend::alloc(size_t size, size_t /*alignment*/) +{ + TM_CHECK_EQ(initialized_, true); + + void* ptr = nullptr; + // NCCL memory is already aligned to page size, so alignment parameter is ignored for now. + NCCL_CHECK(ncclMemAlloc(&ptr, size)); + buffers_.emplace(ptr, size); + return ptr; +} + +void NCCLGINBackend::register_memory(void* ptr, size_t size) +{ + TM_CHECK_EQ(initialized_, true); + TM_CHECK_EQ(buffers_.find(ptr) != buffers_.end(), true); + TM_CHECK_EQ(wins_.find(ptr) == wins_.end(), true); + ncclWindow_t win{}; + NCCL_CHECK(ncclCommWindowRegister(nccl_comm_, ptr, size, &win, 0)); + wins_.emplace(ptr, win); +} + +void NCCLGINBackend::free(void* ptr) +{ + TM_CHECK_EQ(initialized_, true); + auto it = wins_.find(ptr); + TM_CHECK_EQ(it != wins_.end(), true); + NCCL_CHECK(ncclCommWindowDeregister(nccl_comm_, it->second)); + NCCL_CHECK(ncclMemFree(ptr)); + wins_.erase(it); + buffers_.erase(ptr); +} + +int NCCLGINBackend::get_rank() const +{ + TM_CHECK_NE(rank_, -1); + return rank_; +} + +int NCCLGINBackend::get_num_ranks() const +{ + TM_CHECK_NE(num_ranks_, -1); + return num_ranks_; +} + +bool NCCLGINBackend::is_p2p_disabled() const +{ + return p2p_disabled_; +} + +unsigned NCCLGINBackend::get_signals_base(int buffer_idx, bool low_latency_mode) const +{ + if (low_latency_mode) { + EP_HOST_ASSERT(buffer_idx == 0 || buffer_idx == 1); + TM_CHECK_NE(num_ll_signals_, 0); + return buffer_idx * num_ll_signals_ / 2; + } + else { + EP_HOST_ASSERT(buffer_idx == 0); + TM_CHECK_NE(num_ht_signals_, 0); + return 0; + } +} + +ncclWindow_t NCCLGINBackend::get_device_nccl_window(void* ptr) +{ + TM_CHECK_EQ(initialized_, true); + auto it = wins_.find(ptr); + TM_CHECK_EQ(it != wins_.end(), true); + return it->second; +} + +ncclDevComm NCCLGINBackend::get_device_communicator(bool low_latency_mode) const +{ + TM_CHECK_EQ(initialized_, true); + return low_latency_mode ? 
dev_ll_comm_ : dev_ht_comm_; +} + +} // namespace internode +} // namespace deep_ep diff --git a/3rdparty/deep_ep/gin_backend.h b/3rdparty/deep_ep/gin_backend.h new file mode 100644 index 0000000000..a10de528c3 --- /dev/null +++ b/3rdparty/deep_ep/gin_backend.h @@ -0,0 +1,82 @@ +#pragma once + +#include +#include +#include + +#include + +#define DEEP_EP_GIN_MAX_CONTEXTS 32 +#define DEEP_EP_NCCL_GIN_CTXS_PER_COMM 4 +#define DEEP_EP_NCCL_MAX_NUM_CHANNELS 32 // Max number of local experts per GPU + +namespace deep_ep { +namespace internode { + +struct NcclGinMemHandle { + void* ptr = nullptr; +}; + +class NCCLGINBackend { +public: + NCCLGINBackend(): initialized_(false), rank_(-1), num_ranks_(-1) {} + + ~NCCLGINBackend(); + + // Required interface methods + int init(const std::vector& root_unique_id_val, + int rank, + int num_ranks, + bool low_latency_mode, + int qps_per_rank); + + void finalize(); + void barrier(); + + // Memory management interface methods + void* alloc(size_t size, size_t alignment); + void register_memory(void* ptr, size_t size); // NCCL-specific: register allocated memory with communicators + void free(void* ptr); + + int get_rank() const; + int get_num_ranks() const; + + // NCCL-specific methods + bool is_p2p_disabled() const; + + // NCCL specific methods + unsigned get_signals_base(int buffer_idx, bool low_latency_mode) const; + + // Device arrays for kernels + ncclWindow_t get_device_nccl_window(void* ptr); + ncclDevComm get_device_communicator(bool low_latency_mode) const; + +private: + bool initialized_ = false; + bool p2p_disabled_ = false; // True if P2P/NVLink is disabled + int rank_ = -1; // Global rank (for external API) + int num_ranks_ = -1; // Global num_ranks (for external API) + int comm_rank_ = -1; // Rank within NCCL communicator + int comm_nranks_ = -1; // Number of ranks in NCCL communicator + + ncclComm_t nccl_comm_; + + ncclDevComm_t dev_ht_comm_{}; + ncclDevComm_t dev_ll_comm_{}; + + std::unordered_map wins_; + std::unordered_map buffers_; + + // GIN signal management + int num_ht_signals_ = 0; + int num_ll_signals_ = 0; + + // GIN barriers -- assume 32 rdma ranks + const int MAX_BARRIER_SESSIONS = 32; + + // Barrier variable + int* d_barrier_var_ = nullptr; +}; + +} // namespace internode +} // namespace deep_ep diff --git a/3rdparty/deep_ep/kernels/api.cuh b/3rdparty/deep_ep/kernels/api.cuh new file mode 100644 index 0000000000..282aa8480d --- /dev/null +++ b/3rdparty/deep_ep/kernels/api.cuh @@ -0,0 +1,397 @@ +// clang-format off +#pragma once + +#include +#include + +#include + +#include "configs.cuh" + +namespace deep_ep { + +// Intranode runtime +namespace intranode { + +void barrier(int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream); + +} // namespace intranode + +// Internode runtime +namespace internode { + +class NCCLGINBackend; + +std::vector get_unique_id(); + +int init(const std::vector& root_unique_id_val, + int rank, + int num_ranks, + bool low_latency_mode, + int qps_per_rank, + NCCLGINBackend* comm); + +void* alloc(size_t size, size_t alignment, NCCLGINBackend* comm); + +void register_memory(void* ptr, size_t size, NCCLGINBackend* comm); + +void free(void* ptr, NCCLGINBackend* comm); + +void barrier(NCCLGINBackend* comm); + +void finalize(NCCLGINBackend* comm); + +} // namespace internode + +// Layout kernels +namespace layout { + +void get_dispatch_layout(const topk_idx_t* topk_idx, + int* num_tokens_per_rank, + int* num_tokens_per_rdma_rank, + int* num_tokens_per_expert, + bool* is_token_in_rank, + int 
num_tokens, + int num_topk, + int num_ranks, + int num_experts, + cudaStream_t stream); + +} // namespace layout + +// Intranode kernels +namespace intranode { + +void notify_dispatch(const int* num_tokens_per_rank, + int* moe_recv_counter_mapped, + int num_ranks, + const int* num_tokens_per_expert, + int* moe_recv_expert_counter_mapped, + int* moe_recv_expert_counter_ten, + int num_experts, + int num_tokens, + const bool* is_token_in_rank, + int* channel_prefix_matrix, + int* rank_prefix_matrix_copy, + int num_memset_int, + int expert_alignment, + void** buffer_ptrs, + int** barrier_signal_ptrs, + int rank, + cudaStream_t stream, + int num_sms); + +void cached_notify_dispatch(const int* rank_prefix_matrix, + int num_memset_int, + void** buffer_ptrs, + int** barrier_signal_ptrs, + int rank, + int num_ranks, + cudaStream_t stream); + +void dispatch(void* recv_x, + float* recv_x_scales, + int* recv_src_idx, + topk_idx_t* recv_topk_idx, + float* recv_topk_weights, + int* recv_channel_offset, + int* send_head, + const void* x, + const float* x_scales, + const topk_idx_t* topk_idx, + const float* topk_weights, + const bool* is_token_in_rank, + const int* channel_prefix_matrix, + int num_tokens, + int num_worst_tokens, + int hidden_int4, + int num_topk, + int num_experts, + int num_scales, + int scale_token_stride, + int scale_hidden_stride, + void** buffer_ptrs, + int rank, + int num_ranks, + cudaStream_t stream, + int num_sms, + int num_max_send_tokens, + int num_recv_buffer_tokens); + +void cached_notify_combine(void** buffer_ptrs, + int* send_head, + int num_channels, + int num_recv_tokens, + int num_memset_int, + int** barrier_signal_ptrs, + int rank, + int num_ranks, + cudaStream_t stream); + +void combine(cudaDataType_t type, + void* recv_x, + float* recv_topk_weights, + const void* x, + const float* topk_weights, + const void* bias_0, + const void* bias_1, + const int* src_idx, + const int* rank_prefix_matrix, + const int* channel_prefix_matrix, + int* send_head, + int num_tokens, + int num_recv_tokens, + int hidden, + int num_topk, + void** buffer_ptrs, + int rank, + int num_ranks, + cudaStream_t stream, + int num_sms, + int num_max_send_tokens, + int num_recv_buffer_tokens); + +} // namespace intranode + +// Internode kernels +namespace internode { + +int get_source_meta_bytes(); + +void notify_dispatch(const int* num_tokens_per_rank, + int* moe_recv_counter_mapped, + int num_ranks, + const int* num_tokens_per_rdma_rank, + int* moe_recv_rdma_counter_mapped, + const int* num_tokens_per_expert, + int* moe_recv_expert_counter_mapped, + int* moe_recv_expert_counter_ten, + int num_experts, + const bool* is_token_in_rank, + int num_tokens, + int num_worst_tokens, + int num_channels, + int hidden_int4, + int num_scales, + int num_topk, + int expert_alignment, + int* rdma_channel_prefix_matrix, + int* recv_rdma_rank_prefix_sum, + int* gbl_channel_prefix_matrix, + int* recv_gbl_rank_prefix_sum, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_recv_tokens, + int** barrier_signal_ptrs, + int rank, + cudaStream_t stream, + int64_t num_rdma_bytes, + int64_t num_nvl_bytes, + bool low_latency_mode, + void* gin_base_ptr, + ncclDevComm dev_comm, + ncclWindow_t nccl_win, + unsigned signals_base); + +void dispatch(void* recv_x, + float* recv_x_scales, + topk_idx_t* recv_topk_idx, + float* recv_topk_weights, + void* recv_src_meta, + const void* x, + const float* x_scales, + const topk_idx_t* topk_idx, + const float* topk_weights, + int* 
send_rdma_head, + int* send_nvl_head, + int* recv_rdma_channel_prefix_matrix, + int* recv_gbl_channel_prefix_matrix, + const int* rdma_channel_prefix_matrix, + const int* recv_rdma_rank_prefix_sum, + const int* gbl_channel_prefix_matrix, + const int* recv_gbl_rank_prefix_sum, + const bool* is_token_in_rank, + int num_tokens, + int num_worst_tokens, + int hidden_int4, + int num_scales, + int num_topk, + int num_experts, + int scale_token_stride, + int scale_hidden_stride, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_send_tokens, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_send_tokens, + int num_max_nvl_chunked_recv_tokens, + int rank, + int num_ranks, + bool is_cached_dispatch, + cudaStream_t stream, + int num_channels, + bool low_latency_mode, + void* gin_base_ptr, + ncclDevComm dev_comm, + ncclWindow_t nccl_win, + unsigned signals_base); + +void cached_notify(int hidden_int4, + int num_scales, + int num_topk_idx, + int num_topk_weights, + int num_ranks, + int num_channels, + int num_combined_tokens, + int* combined_rdma_head, + const int* rdma_channel_prefix_matrix, + const int* rdma_rank_prefix_sum, + int* combined_nvl_head, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_recv_tokens, + int** barrier_signal_ptrs, + int rank, + cudaStream_t stream, + int64_t num_rdma_bytes, + int64_t num_nvl_bytes, + bool is_cached_dispatch, + bool low_latency_mode, + void* gin_base_ptr, + ncclDevComm dev_comm, + ncclWindow_t nccl_win, + unsigned signals_base); + +void combine(cudaDataType_t type, + void* combined_x, + float* combined_topk_weights, + const bool* is_combined_token_in_rank, + const void* x, + const float* topk_weights, + const void* bias_0, + const void* bias_1, + const int* combined_rdma_head, + const int* combined_nvl_head, + const void* src_meta, + const int* rdma_channel_prefix_matrix, + const int* rdma_rank_prefix_sum, + const int* gbl_channel_prefix_matrix, + int num_tokens, + int num_combined_tokens, + int hidden, + int num_topk, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_send_tokens, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_send_tokens, + int num_max_nvl_chunked_recv_tokens, + int rank, + int num_ranks, + cudaStream_t stream, + int num_channels, + bool low_latency_mode, + void* gin_base_ptr, + ncclDevComm dev_comm, + ncclWindow_t nccl_win, + unsigned signals_base); + +} // namespace internode + +// Internode low-latency kernels +namespace internode_ll { + +void clean_low_latency_buffer(int* clean_0, + int num_clean_int_0, + int* clean_1, + int num_clean_int_1, + int rank, + int num_ranks, + int* mask_buffer, + int* sync_buffer, + cudaStream_t stream); + +void dispatch(void* packed_recv_x, + void* packed_recv_x_scales, + int* packed_recv_src_info, + int64_t* packed_recv_layout_range, + int* packed_recv_count, + int* mask_buffer, + int* cumulative_local_expert_recv_stats, + int64_t* dispatch_wait_recv_cost_stats, + void* rdma_recv_x, + int* rdma_recv_count, + void* rdma_x, + size_t rdma_recv_x_offset, + size_t rdma_recv_count_offset, + size_t rdma_x_offset, + const void* x, + const topk_idx_t* topk_idx, + int* next_clean, + int num_next_clean_int, + int num_tokens, + int hidden, + int num_max_dispatch_tokens_per_rank, + int num_topk, + int num_experts, + int rank, + int num_ranks, + bool use_fp8, + bool round_scale, + bool use_ue8m0, + void* workspace, + int num_device_sms, + ncclWindow_t nccl_win, + ncclDevComm 
dev_comm, + unsigned signals_base, + cudaStream_t stream, + int phases); + +void combine(void* combined_x, + void* rdma_recv_x, + int* rdma_recv_flag, + void* rdma_send_x, + size_t rdma_recv_x_offset, + size_t rdma_recv_flag_offset, + size_t rdma_send_x_offset, + const void* x, + const topk_idx_t* topk_idx, + const float* topk_weights, + const int* src_info, + const int64_t* layout_range, + int* mask_buffer, + int64_t* combine_wait_recv_cost_stats, + int* next_clean, + int num_next_clean_int, + int num_combined_tokens, + int hidden, + int num_max_dispatch_tokens_per_rank, + int num_topk, + int num_experts, + int rank, + int num_ranks, + bool use_logfmt, + void* workspace, + int num_device_sms, + ncclWindow_t nccl_win, + ncclDevComm dev_comm, + unsigned signals_base, + cudaStream_t stream, + int phases, + bool zero_copy); + +void query_mask_buffer(int* mask_buffer_ptr, int num_ranks, int* output_mask_tensor, cudaStream_t stream); + +void update_mask_buffer(int* mask_buffer_ptr, int rank_to_mask, bool mask, cudaStream_t stream); + +void clean_mask_buffer(int* mask_buffer_ptr, int num_ranks, cudaStream_t stream); + +void set_p2p_disabled_flag(bool disabled); + +} // namespace internode_ll + +} // namespace deep_ep + +// clang-format on diff --git a/3rdparty/deep_ep/kernels/buffer.cuh b/3rdparty/deep_ep/kernels/buffer.cuh new file mode 100644 index 0000000000..673fc86ae4 --- /dev/null +++ b/3rdparty/deep_ep/kernels/buffer.cuh @@ -0,0 +1,134 @@ +// clang-format off +#pragma once + +#include "configs.cuh" +#include "exception.cuh" + +namespace deep_ep { + +template +struct Buffer { +private: + uint8_t* ptr; + +public: + int64_t total_bytes; + + __device__ __forceinline__ Buffer() : ptr(nullptr), total_bytes(0) {} + + __device__ __forceinline__ Buffer(void*& gbl_ptr, int num_elems, int offset = 0) { + total_bytes = num_elems * sizeof(dtype_t); + ptr = static_cast(gbl_ptr) + offset * sizeof(dtype_t); + gbl_ptr = static_cast(gbl_ptr) + total_bytes; + } + + __device__ __forceinline__ Buffer advance_also(void*& gbl_ptr) { + gbl_ptr = static_cast(gbl_ptr) + total_bytes; + return *this; + } + + __device__ __forceinline__ dtype_t* buffer() { return reinterpret_cast(ptr); } + + __device__ __forceinline__ dtype_t& operator[](int idx) { return buffer()[idx]; } +}; + +template +struct AsymBuffer { +private: + uint8_t* ptrs[kNumRanks]; + int64_t num_bytes; + +public: + int64_t total_bytes; + + __device__ __forceinline__ AsymBuffer(void*& gbl_ptr, int num_elems, int num_ranks, int sm_id = 0, int num_sms = 1, int offset = 0) { + EP_STATIC_ASSERT(kNumRanks == 1, ""); + num_bytes = num_elems * sizeof(dtype_t); + + int64_t per_channel_bytes = num_bytes * num_ranks; + total_bytes = per_channel_bytes * num_sms; + ptrs[0] = static_cast(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset; + gbl_ptr = static_cast(gbl_ptr) + total_bytes; + } + + __device__ __forceinline__ AsymBuffer(void** gbl_ptrs, int num_elems, int num_ranks, int sm_id = 0, int num_sms = 1, int offset = 0) { + EP_STATIC_ASSERT(kNumRanks > 1, ""); + num_bytes = num_elems * sizeof(dtype_t); + + int64_t per_channel_bytes = num_bytes * num_ranks; + total_bytes = per_channel_bytes * num_sms; + for (int i = 0; i < kNumRanks; ++i) { + ptrs[i] = static_cast(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset; + gbl_ptrs[i] = static_cast(gbl_ptrs[i]) + total_bytes; + } + } + + __device__ __forceinline__ void advance(int shift) { + #pragma unroll + for (int i = 0; i < kNumRanks; ++i) + ptrs[i] = ptrs[i] + shift * sizeof(dtype_t); + } + + 
__device__ __forceinline__ AsymBuffer advance_also(void*& gbl_ptr) { + gbl_ptr = static_cast(gbl_ptr) + total_bytes; + return *this; + } + + template + __device__ __forceinline__ AsymBuffer advance_also(void** gbl_ptrs) { + for (int i = 0; i < kNumAlsoRanks; ++i) + gbl_ptrs[i] = static_cast(gbl_ptrs[i]) + total_bytes; + return *this; + } + + __device__ __forceinline__ dtype_t* buffer(int idx = 0) { + EP_STATIC_ASSERT(kNumRanks == 1, "`buffer` is only available for single rank case"); + return reinterpret_cast(ptrs[0] + num_bytes * idx); + } + + __device__ __forceinline__ dtype_t* buffer_by(int rank_idx, int idx = 0) { + EP_STATIC_ASSERT(kNumRanks > 1, "`buffer` is only available for single rank case"); + return reinterpret_cast(ptrs[rank_idx] + num_bytes * idx); + } +}; + +template +struct SymBuffer { +private: + // NOTES: for non-decoupled case, `recv_ptr` is not used + uint8_t* send_ptr; + uint8_t* recv_ptr; + int64_t num_bytes; + +public: + int64_t total_bytes; + + __device__ __forceinline__ SymBuffer(void*& gbl_ptr, int num_elems, int num_ranks, int sm_id = 0, int num_sms = 1) { + num_bytes = num_elems * sizeof(dtype_t); + + int64_t per_channel_bytes = num_bytes * num_ranks; + total_bytes = per_channel_bytes * num_sms * (static_cast(kDecoupled) + 1); + send_ptr = static_cast(gbl_ptr) + per_channel_bytes * sm_id; + recv_ptr = static_cast(gbl_ptr) + per_channel_bytes * (sm_id + num_sms); + gbl_ptr = static_cast(gbl_ptr) + total_bytes; + } + + __device__ __forceinline__ dtype_t* send_buffer(int idx = 0) { + EP_STATIC_ASSERT(kDecoupled, "`send_buffer` is only available for non-decoupled case"); + return reinterpret_cast(send_ptr + num_bytes * idx); + } + + __device__ __forceinline__ dtype_t* recv_buffer(int idx = 0) { + EP_STATIC_ASSERT(kDecoupled, "`recv_buffer` is only available for non-decoupled case"); + return reinterpret_cast(recv_ptr + num_bytes * idx); + } + + __device__ __forceinline__ dtype_t* buffer(int idx = 0) { + EP_STATIC_ASSERT(not kDecoupled, "`buffer` is only available for decoupled case"); + return reinterpret_cast(send_ptr + num_bytes * idx); + } +}; + +} // namespace deep_ep + +// clang-format on diff --git a/3rdparty/deep_ep/kernels/configs.cuh b/3rdparty/deep_ep/kernels/configs.cuh new file mode 100644 index 0000000000..9669120dcf --- /dev/null +++ b/3rdparty/deep_ep/kernels/configs.cuh @@ -0,0 +1,81 @@ +#pragma once + +#define NUM_MAX_NVL_PEERS 8 +#define NUM_MAX_RDMA_PEERS 20 +#define NUM_WORKSPACE_BYTES (32 * 1024 * 1024) +#define NUM_MAX_LOCAL_EXPERTS 1024 +#define NUM_BUFFER_ALIGNMENT_BYTES 128 + +#define FINISHED_SUM_TAG 1024 +#define NUM_WAIT_NANOSECONDS 500 + +#ifndef ENABLE_FAST_DEBUG +#define NUM_CPU_TIMEOUT_SECS 100 +#define NUM_TIMEOUT_CYCLES 200000000000ull // 200G cycles ~= 100s +#else +#define NUM_CPU_TIMEOUT_SECS 10 +#define NUM_TIMEOUT_CYCLES 20000000000ull // 20G cycles ~= 10s +#endif + +#define LOW_LATENCY_SEND_PHASE 1 +#define LOW_LATENCY_RECV_PHASE 2 + +// Make CLion CUDA indexing work +#ifdef __CLION_IDE__ +#define __CUDA_ARCH__ 900 // NOLINT(*-reserved-identifier) +#define __CUDACC_RDC__ // NOLINT(*-reserved-identifier) +#endif + +// Define __CUDACC_RDC__ to ensure proper extern declarations for NVSHMEM device symbols +#ifndef DISABLE_NVSHMEM +#ifndef __CUDACC_RDC__ +#define __CUDACC_RDC__ // NOLINT(*-reserved-identifier) +#endif +#endif + +// Remove Torch restrictions +#ifdef __CUDA_NO_HALF_CONVERSIONS__ +#undef __CUDA_NO_HALF_CONVERSIONS__ +#endif +#ifdef __CUDA_NO_HALF_OPERATORS__ +#undef __CUDA_NO_HALF_OPERATORS__ +#endif +#ifdef 
__CUDA_NO_HALF2_OPERATORS__ +#undef __CUDA_NO_HALF2_OPERATORS__ +#endif +#ifdef __CUDA_NO_BFLOAT16_CONVERSIONS__ +#undef __CUDA_NO_BFLOAT16_CONVERSIONS__ +#endif +#ifdef __CUDA_NO_BFLOAT162_OPERATORS__ +#undef __CUDA_NO_BFLOAT162_OPERATORS__ +#endif + +#include +#include + +#include + +#ifndef DISABLE_SM90_FEATURES +#include +#else +// Ampere does not support FP8 features +#define __NV_E4M3 0 +#define __NV_E5M2 1 +typedef int __nv_fp8_interpretation_t; +typedef int __nv_fp8x4_e4m3; +typedef uint8_t __nv_fp8_storage_t; +#endif + +namespace deep_ep { + +#ifndef TOPK_IDX_BITS +#define TOPK_IDX_BITS 64 +#endif + +#define INT_BITS_T2(bits) int##bits##_t +#define INT_BITS_T(bits) INT_BITS_T2(bits) +typedef INT_BITS_T(TOPK_IDX_BITS) topk_idx_t; // int32_t or int64_t +#undef INT_BITS_T +#undef INT_BITS_T2 + +} // namespace deep_ep diff --git a/3rdparty/deep_ep/kernels/exception.cuh b/3rdparty/deep_ep/kernels/exception.cuh new file mode 100644 index 0000000000..d6086f4343 --- /dev/null +++ b/3rdparty/deep_ep/kernels/exception.cuh @@ -0,0 +1,76 @@ +// clang-format off +#pragma once + +#include +#include + +#include "configs.cuh" + +#ifndef EP_STATIC_ASSERT +#define EP_STATIC_ASSERT(cond, reason) static_assert(cond, reason) +#endif + +class EPException : public std::exception { +private: + std::string message = {}; + +public: + explicit EPException(const char* name, const char* file, const int line, const std::string& error) { + message = std::string("Failed: ") + name + " error " + file + ":" + std::to_string(line) + " '" + error + "'"; + } + + const char* what() const noexcept override { return message.c_str(); } +}; + +#ifndef CUDA_CHECK +#define CUDA_CHECK(cmd) \ + do { \ + cudaError_t e = (cmd); \ + if (e != cudaSuccess) { \ + throw EPException("CUDA", __FILE__, __LINE__, cudaGetErrorString(e)); \ + } \ + } while (0) +#endif + +#ifndef CU_CHECK +#define CU_CHECK(cmd) \ + do { \ + CUresult e = (cmd); \ + if (e != CUDA_SUCCESS) { \ + const char* error_str = NULL; \ + cuGetErrorString(e, &error_str); \ + throw EPException("CU", __FILE__, __LINE__, std::string(error_str)); \ + } \ + } while (0) +#endif + +#ifndef EP_HOST_ASSERT +#define EP_HOST_ASSERT(cond) \ + do { \ + if (not(cond)) { \ + throw EPException("Assertion", __FILE__, __LINE__, #cond); \ + } \ + } while (0) +#endif + +#ifndef EP_DEVICE_ASSERT +#define EP_DEVICE_ASSERT(cond) \ + do { \ + if (not(cond)) { \ + printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \ + asm("trap;"); \ + } \ + } while (0) +#endif + +#ifndef NCCL_CHECK +#define NCCL_CHECK(cmd) \ + do { \ + ncclResult_t e = (cmd); \ + if (e != ncclSuccess) { \ + throw EPException("NCCL", __FILE__, __LINE__, ncclGetErrorString(e)); \ + } \ + } while (0) +#endif + +// clang-format on diff --git a/3rdparty/deep_ep/kernels/internode.cu b/3rdparty/deep_ep/kernels/internode.cu new file mode 100644 index 0000000000..e0f7f0b2f1 --- /dev/null +++ b/3rdparty/deep_ep/kernels/internode.cu @@ -0,0 +1,2611 @@ +// clang-format off +#include +#include + +#include "buffer.cuh" +#include "configs.cuh" + +#include +#include +#include "exception.cuh" +#include "launch.cuh" +#include "utils.cuh" + +#include +#include "../gin_backend.h" + + +using namespace cooperative_groups; + +namespace deep_ep { + +namespace internode { + +struct SourceMeta { + int src_rdma_rank, is_token_in_nvl_rank_bits; + + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "Invalid number of maximum NVL peers"); + + __forceinline__ SourceMeta() = default; + + // TODO: faster encoding + __device__ 
__forceinline__ SourceMeta(int rdma_rank, const bool* is_token_in_nvl_ranks) { + src_rdma_rank = rdma_rank; + is_token_in_nvl_rank_bits = is_token_in_nvl_ranks[0]; + #pragma unroll + for (int i = 1; i < NUM_MAX_NVL_PEERS; ++i) + is_token_in_nvl_rank_bits |= is_token_in_nvl_ranks[i] << i; + } + + __device__ __forceinline__ bool is_token_in_nvl_rank(int nvl_rank) const { return (is_token_in_nvl_rank_bits >> nvl_rank) & 1; } +}; + +EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0, "Invalid size of `SourceMeta`"); + +int get_source_meta_bytes() { + return sizeof(SourceMeta); +} + +__host__ __device__ __forceinline__ int get_num_bytes_per_token(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights) { + return static_cast(align_up(hidden_int4 * sizeof(int4) + sizeof(SourceMeta) + num_scales * sizeof(float) + + num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float), + sizeof(int4))); +} + +__host__ __device__ __forceinline__ std::pair get_rdma_clean_meta(int hidden_int4, + int num_scales, + int num_topk_idx, + int num_topk_weights, + int num_rdma_ranks, + int num_rdma_recv_buffer_tokens, + int num_channels) { + // Return `int32_t` offset and count to clean + return {(get_num_bytes_per_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) * num_rdma_recv_buffer_tokens * + num_rdma_ranks * 2 * num_channels) / + sizeof(int), + (NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_channels}; +} + +__host__ __device__ __forceinline__ std::pair get_nvl_clean_meta(int hidden_int4, + int num_scales, + int num_topk_idx, + int num_topk_weights, + int num_rdma_ranks, + int num_nvl_ranks, + int num_nvl_recv_buffer_tokens, + int num_channels, + bool is_dispatch) { + // Return `int32_t` offset and to clean + EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0, "Invalid size of `SourceMeta`"); + + return { + (num_nvl_recv_buffer_tokens * get_num_bytes_per_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) * num_nvl_ranks * + num_channels) / + sizeof(int), + num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_channels, + }; +} + +template +__forceinline__ __device__ int translate_dst_rdma_rank(const int dst_rdma_rank, const int nvl_rank) { + return kLowLatencyMode ? 
(dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank) : dst_rdma_rank; +} + +template +__forceinline__ __device__ void sync_with_same_gpu_idx(ncclDevComm& dev_comm) { + // Barrier before cleaning (in case of unfinished chunked EP) + ncclGin net(dev_comm, 0); + + if (kLowLatencyMode) { + // Use rank as session ID for symmetric synchronization + // This ensures each rank has a unique session ID for symmetric barriers + int session_id = dev_comm.lsaRank; + + // Use GIN barrier session directly with symmetric team + ncclGinBarrierSession barrier(ncclCoopThread(), net, ncclTeamTagRail(), session_id); + barrier.sync(ncclCoopThread(), cuda::memory_order_relaxed, ncclGinFenceLevel::Relaxed); + } else { + // World barrier - synchronizes all ranks + ncclBarrierSession barrier(ncclCoopThread(), ncclTeamTagWorld(), net, 0); + barrier.sync(ncclCoopThread(), cuda::memory_order_relaxed, ncclGinFenceLevel::Relaxed); + } +} + +template +__global__ void notify_dispatch(const int* num_tokens_per_rank, + int* moe_recv_counter_mapped, + int num_ranks, + const int* num_tokens_per_rdma_rank, + int* moe_recv_rdma_counter_mapped, + const int* num_tokens_per_expert, + int* moe_recv_expert_counter_mapped, + int* moe_recv_expert_counter_ten, + int num_experts, + const bool* is_token_in_rank, + int num_tokens, + int num_worst_tokens, + int num_channels, + int expert_alignment, + const int rdma_clean_offset, + const int rdma_num_int_clean, + const int nvl_clean_offset, + const int nvl_num_int_clean, + int* rdma_channel_prefix_matrix, + int* recv_rdma_rank_prefix_sum, + int* gbl_channel_prefix_matrix, + int* recv_gbl_rank_prefix_sum, + void* rdma_buffer_ptr, + void** buffer_ptrs, + int** barrier_signal_ptrs, + int rank, + void* gin_base_ptr, + ncclDevComm dev_comm, + const ncclWindow_t nccl_win, + unsigned signals_base +) { + auto sm_id = static_cast(blockIdx.x); + auto thread_id = static_cast(threadIdx.x), warp_id = thread_id / 32, lane_id = get_lane_id(); + auto num_threads = static_cast(blockDim.x), num_warps = num_threads / 32; + + auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; + auto num_rdma_experts = num_experts / kNumRDMARanks, num_nvl_experts = num_rdma_experts / NUM_MAX_NVL_PEERS; + + if (sm_id == 0) { + // Communication with others + // Global barrier: the first warp does intra-node sync, the second warp does internode sync + EP_DEVICE_ASSERT(num_warps > 1); + EP_DEVICE_ASSERT(kNumRDMARanks <= num_threads); + + // waiting for all previous inflight wrs to complete, + // in case of rewriting cleared rdma_buffer + for (int i = thread_id; i < dev_comm.ginConnectionCount; i += num_threads) { + ncclGin net(dev_comm, i); + net.flush(ncclCoopThread(), cuda::std::memory_order_acquire); + } + __syncthreads(); + if (thread_id == 32) + sync_with_same_gpu_idx(dev_comm); + + barrier_block(barrier_signal_ptrs, nvl_rank); + + // Send numbers of tokens per rank/expert to RDMA ranks + auto rdma_buffer_ptr_int = static_cast(rdma_buffer_ptr); + auto rdma_recv_num_tokens_mixed = SymBuffer(rdma_buffer_ptr, NUM_MAX_NVL_PEERS + num_rdma_experts + 1, kNumRDMARanks); + + // Clean up for later data dispatch + EP_DEVICE_ASSERT(rdma_recv_num_tokens_mixed.total_bytes <= rdma_clean_offset * sizeof(int)); + #pragma unroll + for (int i = thread_id; i < rdma_num_int_clean; i += num_threads) + rdma_buffer_ptr_int[rdma_clean_offset + i] = 0; + + // Clean signals here + // For each channel we have kNumRDMARanks head and tail signals + int num_signals = kNumRDMARanks * num_channels * 2; + EP_DEVICE_ASSERT(num_signals <= 
num_threads); + EP_DEVICE_ASSERT(num_channels * 2 <= dev_comm.ginContextCount); + + // NCCL backend use signal to update head/tail + // Each thread handles one specific signal across all contexts + if (thread_id < num_signals) { + auto signal_id = signals_base + thread_id; + + // Derive channel_id from signal_id + // Signal layout: [all head signals][all tail signals] + int signal_offset = thread_id; + int head_signal_count = kNumRDMARanks * num_channels; + bool is_head_signal = (signal_offset < head_signal_count); + int channel_id = is_head_signal ? signal_offset / kNumRDMARanks : (signal_offset - head_signal_count) / kNumRDMARanks; + + auto context_id = is_head_signal ? channel_id + num_channels : channel_id ; + ncclGin net(dev_comm, context_id); + net.resetSignal(signal_id); + } + __syncthreads(); + + // Copy to send buffer + #pragma unroll + for (int i = thread_id; i < num_ranks; i += num_threads) + rdma_recv_num_tokens_mixed.send_buffer(i / NUM_MAX_NVL_PEERS)[i % NUM_MAX_NVL_PEERS] = num_tokens_per_rank[i]; + #pragma unroll + for (int i = thread_id; i < num_experts; i += num_threads) + rdma_recv_num_tokens_mixed.send_buffer(i / num_rdma_experts)[NUM_MAX_NVL_PEERS + i % num_rdma_experts] = + num_tokens_per_expert[i]; + if (thread_id < kNumRDMARanks) + rdma_recv_num_tokens_mixed.send_buffer(thread_id)[NUM_MAX_NVL_PEERS + num_rdma_experts] = num_tokens_per_rdma_rank[thread_id]; + __syncthreads(); + + // Issue send + // TODO: more light fence or barrier or signaling + // TODO: overlap EP barrier and NVL cleaning + for (int i = warp_id; i < kNumRDMARanks; i += num_warps) { + if (i != rdma_rank) { + // Distribute work across GIN contexts + int dst_rank = translate_dst_rdma_rank(i, nvl_rank); + size_t src_offset = + reinterpret_cast(rdma_recv_num_tokens_mixed.send_buffer(i)) - reinterpret_cast(gin_base_ptr); + size_t dst_offset = reinterpret_cast(rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank)) - + reinterpret_cast(gin_base_ptr); + size_t bytes = (NUM_MAX_NVL_PEERS + num_rdma_experts + 1) * sizeof(int); + + ncclGin net(dev_comm, 0); + ncclTeam world = ncclTeamWorld(dev_comm); + net.put(world, + dst_rank, + nccl_win, + dst_offset, + nccl_win, + src_offset, + bytes, + ncclGin_None{}, // no signal + ncclGin_None{}, // no counter + ncclCoopWarp()); + } else { + UNROLLED_WARP_COPY(1, + lane_id, + NUM_MAX_NVL_PEERS + num_rdma_experts + 1, + rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank), + rdma_recv_num_tokens_mixed.send_buffer(i), + ld_volatile_global, + st_na_global); + } + } + __syncthreads(); + + // Wait previous operations to be finished + // Flush all contexts + ncclGin net(dev_comm, 0); + net.flush(ncclCoopCta(), cuda::std::memory_order_acquire); + + // Barrier + if (thread_id == 0) + sync_with_same_gpu_idx(dev_comm); + __syncthreads(); + + // NVL buffers + auto nvl_send_buffer = thread_id < NUM_MAX_NVL_PEERS ? 
buffer_ptrs[thread_id] : nullptr; + auto nvl_recv_buffer = buffer_ptrs[nvl_rank]; + auto nvl_reduced_num_tokens_per_expert = Buffer(nvl_recv_buffer, num_rdma_experts).advance_also(nvl_send_buffer); + auto nvl_send_num_tokens_per_rank = AsymBuffer(nvl_send_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS); + auto nvl_send_num_tokens_per_expert = AsymBuffer(nvl_send_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS); + auto nvl_recv_num_tokens_per_rank = AsymBuffer(nvl_recv_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS); + auto nvl_recv_num_tokens_per_expert = AsymBuffer(nvl_recv_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS); + + // Clean up for later data dispatch + auto nvl_buffer_ptr_int = static_cast(buffer_ptrs[nvl_rank]); + EP_DEVICE_ASSERT(nvl_reduced_num_tokens_per_expert.total_bytes + nvl_send_num_tokens_per_rank.total_bytes + + nvl_send_num_tokens_per_expert.total_bytes <= + nvl_clean_offset * sizeof(int)); + #pragma unroll + for (int i = thread_id; i < nvl_num_int_clean; i += num_threads) + nvl_buffer_ptr_int[nvl_clean_offset + i] = 0; + + // Reduce number of tokens per expert into the NVL send buffer + // TODO: may use NVSHMEM reduction + EP_DEVICE_ASSERT(num_rdma_experts <= num_threads); + if (thread_id < num_rdma_experts) { + int sum = 0; + #pragma unroll + for (int i = 0; i < kNumRDMARanks; ++i) + sum += rdma_recv_num_tokens_mixed.recv_buffer(i)[NUM_MAX_NVL_PEERS + thread_id]; + nvl_reduced_num_tokens_per_expert[thread_id] = sum; + } + __syncthreads(); + + // Reduce RDMA received tokens + if (thread_id == 0) { + int sum = 0; + #pragma unroll + for (int i = 0; i < kNumRDMARanks; ++i) { + sum += rdma_recv_num_tokens_mixed.recv_buffer(i)[NUM_MAX_NVL_PEERS + num_rdma_experts]; + recv_rdma_rank_prefix_sum[i] = sum; + } + if (num_worst_tokens == 0) { + while (ld_volatile_global(moe_recv_rdma_counter_mapped) != -1) + ; + *moe_recv_rdma_counter_mapped = sum; + } + } + + // Send numbers of tokens per rank/expert to NVL ranks + EP_DEVICE_ASSERT(NUM_MAX_NVL_PEERS <= num_threads); + if (thread_id < NUM_MAX_NVL_PEERS) { + #pragma unroll + for (int i = 0; i < kNumRDMARanks; ++i) + nvl_send_num_tokens_per_rank.buffer(nvl_rank)[i] = rdma_recv_num_tokens_mixed.recv_buffer(i)[thread_id]; + #pragma unroll + for (int i = 0; i < num_nvl_experts; ++i) + nvl_send_num_tokens_per_expert.buffer(nvl_rank)[i] = nvl_reduced_num_tokens_per_expert[thread_id * num_nvl_experts + i]; + } + barrier_block(barrier_signal_ptrs, nvl_rank); + + // Reduce the number of tokens per rank/expert + EP_DEVICE_ASSERT(num_nvl_experts <= num_threads); + if (thread_id == 0) { + int sum = 0; + #pragma unroll + for (int i = 0; i < num_ranks; ++i) { + int src_rdma_rank = i / NUM_MAX_NVL_PEERS, src_nvl_rank = i % NUM_MAX_NVL_PEERS; + sum += nvl_recv_num_tokens_per_rank.buffer(src_nvl_rank)[src_rdma_rank]; + recv_gbl_rank_prefix_sum[i] = sum; + } + if (num_worst_tokens == 0) { + while (ld_volatile_global(moe_recv_counter_mapped) != -1) + ; + *moe_recv_counter_mapped = sum; + } + } + if (thread_id < num_nvl_experts) { + int sum = 0; + #pragma unroll + for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i) + sum += nvl_recv_num_tokens_per_expert.buffer(i)[thread_id]; + sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment; + if (num_worst_tokens == 0) { + while (ld_volatile_global(moe_recv_expert_counter_mapped + thread_id) != -1) + ; + moe_recv_expert_counter_mapped[thread_id] = sum; + moe_recv_expert_counter_ten[thread_id] = sum; + } + } + + // Finally barrier + if (thread_id == 32) + sync_with_same_gpu_idx(dev_comm); + 
barrier_block(barrier_signal_ptrs, nvl_rank); + } else { + // Calculate meta data + int dst_rdma_rank = sm_id - 1; + for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) { + int token_start_idx, token_end_idx; + get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx); + + // Iterate over tokens + int total_count = 0, per_nvl_rank_count[NUM_MAX_NVL_PEERS] = {0}; + for (int64_t i = token_start_idx + lane_id; i < token_end_idx; i += 32) { + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers"); + auto is_token_in_rank_uint64 = + *reinterpret_cast(is_token_in_rank + i * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS); + auto is_token_in_rank_values = reinterpret_cast(&is_token_in_rank_uint64); + #pragma unroll + for (int j = 0; j < NUM_MAX_NVL_PEERS; ++j) + per_nvl_rank_count[j] += is_token_in_rank_values[j]; + total_count += (is_token_in_rank_uint64 != 0); + } + + // Warp reduce + total_count = warp_reduce_sum(total_count); + #pragma unroll + for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i) + per_nvl_rank_count[i] = warp_reduce_sum(per_nvl_rank_count[i]); + + // Write into channel matrix + if (elect_one_sync()) { + #pragma unroll + for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i) + gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + i) * num_channels + channel_id] = per_nvl_rank_count[i]; + rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] = total_count; + } + } + + // Calculate prefix sum + __syncthreads(); + if (thread_id == 0) { + auto prefix_row = rdma_channel_prefix_matrix + dst_rdma_rank * num_channels; + #pragma unroll + for (int i = 1; i < num_channels; ++i) + prefix_row[i] += prefix_row[i - 1]; + } + + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Invalid number of NVL peers"); + if (thread_id < NUM_MAX_NVL_PEERS) { + auto prefix_row = gbl_channel_prefix_matrix + (dst_rdma_rank * NUM_MAX_NVL_PEERS + thread_id) * num_channels; + #pragma unroll + for (int i = 1; i < num_channels; ++i) + prefix_row[i] += prefix_row[i - 1]; + } + } +} + +void notify_dispatch(const int* num_tokens_per_rank, + int* moe_recv_counter_mapped, + int num_ranks, + const int* num_tokens_per_rdma_rank, + int* moe_recv_rdma_counter_mapped, + const int* num_tokens_per_expert, + int* moe_recv_expert_counter_mapped, + int* moe_recv_expert_counter_ten, + int num_experts, + const bool* is_token_in_rank, + int num_tokens, + int num_worst_tokens, + int num_channels, + int hidden_int4, + int num_scales, + int num_topk, + int expert_alignment, + int* rdma_channel_prefix_matrix, + int* recv_rdma_rank_prefix_sum, + int* gbl_channel_prefix_matrix, + int* recv_gbl_rank_prefix_sum, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_recv_tokens, + int** barrier_signal_ptrs, + int rank, + cudaStream_t stream, + int64_t num_rdma_bytes, + int64_t num_nvl_bytes, + bool low_latency_mode, + void* gin_base_ptr, + ncclDevComm dev_comm, + ncclWindow_t nccl_win, + unsigned signals_base) +{ +#define NOTIFY_DISPATCH_LAUNCH_CASE(num_rdma_ranks) \ + { \ + auto notify_dispatch_func = \ + low_latency_mode ? 
notify_dispatch : notify_dispatch; \ + LAUNCH_KERNEL(&cfg, \ + notify_dispatch_func, \ + num_tokens_per_rank, \ + moe_recv_counter_mapped, \ + num_ranks, \ + num_tokens_per_rdma_rank, \ + moe_recv_rdma_counter_mapped, \ + num_tokens_per_expert, \ + moe_recv_expert_counter_mapped, \ + moe_recv_expert_counter_ten, \ + num_experts, \ + is_token_in_rank, \ + num_tokens, \ + num_worst_tokens, \ + num_channels, \ + expert_alignment, \ + rdma_clean_meta.first, \ + rdma_clean_meta.second, \ + nvl_clean_meta.first, \ + nvl_clean_meta.second, \ + rdma_channel_prefix_matrix, \ + recv_rdma_rank_prefix_sum, \ + gbl_channel_prefix_matrix, \ + recv_gbl_rank_prefix_sum, \ + rdma_buffer_ptr, \ + buffer_ptrs, \ + barrier_signal_ptrs, \ + rank, \ + gin_base_ptr, \ + dev_comm, \ + nccl_win, \ + signals_base); \ + } \ + break + + constexpr int kNumThreads = 512; + const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; + + // Get clean meta + auto rdma_clean_meta = get_rdma_clean_meta( + hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks, num_max_rdma_chunked_recv_tokens, num_channels); + auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4, + num_scales, + num_topk, + num_topk, + num_rdma_ranks, + NUM_MAX_NVL_PEERS, + num_max_nvl_chunked_recv_tokens, + num_channels, + true); + EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes); + EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= num_nvl_bytes); + EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits::max()); + EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits::max()); + + // Launch kernel + SETUP_LAUNCH_CONFIG(1 + num_rdma_ranks, kNumThreads, stream); + SWITCH_RDMA_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE); +#undef NOTIFY_DISPATCH_LAUNCH_CASE +} + +// At most 8 RDMA ranks to be sent +constexpr int get_num_topk_rdma_ranks(int num_rdma_ranks) { + return num_rdma_ranks < 8 ? 
num_rdma_ranks : 8; +} + +template +__global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32), 1) + dispatch(int4* recv_x, + float* recv_x_scales, + topk_idx_t* recv_topk_idx, + float* recv_topk_weights, + SourceMeta* recv_src_meta, + const int4* x, + const float* x_scales, + const topk_idx_t* topk_idx, + const float* topk_weights, + int* send_rdma_head, + int* send_nvl_head, + int* recv_rdma_channel_prefix_matrix, + int* recv_gbl_channel_prefix_matrix, + const int* rdma_channel_prefix_matrix, + const int* recv_rdma_rank_prefix_sum, + const int* gbl_channel_prefix_matrix, + const int* recv_gbl_rank_prefix_sum, + const bool* is_token_in_rank, + int num_tokens, + int num_worst_tokens, + int hidden_int4, + int num_scales, + int num_topk, + int num_experts, + int scale_token_stride, + int scale_hidden_stride, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_send_tokens, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_send_tokens, + int num_max_nvl_chunked_recv_tokens, + int rank, + int num_ranks, + void* gin_base_ptr, + ncclDevComm dev_comm, + ncclWindow_t nccl_win, + unsigned signals_base + ) { + enum class WarpRole { kRDMASender, kRDMASenderCoordinator, kRDMAAndNVLForwarder, kForwarderCoordinator, kNVLReceivers }; + + const auto num_sms = static_cast(gridDim.x); + const auto sm_id = static_cast(blockIdx.x); + const auto num_threads = static_cast(blockDim.x), num_warps = num_threads / 32; + const auto thread_id = static_cast(threadIdx.x), warp_id = thread_id / 32, lane_id = get_lane_id(); + const auto num_channels = num_sms / 2, channel_id = sm_id / 2; + const bool is_forwarder = sm_id % 2 == 0; + const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; + + const auto role_meta = [=]() -> std::pair { + if (is_forwarder) { + if (warp_id < NUM_MAX_NVL_PEERS) { + return {WarpRole::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS}; + } else { + return {WarpRole::kForwarderCoordinator, warp_id - NUM_MAX_NVL_PEERS}; + } + } else if (warp_id < kNumDispatchRDMASenderWarps) { + return {WarpRole::kRDMASender, -1}; + } else if (warp_id == kNumDispatchRDMASenderWarps) { + return {WarpRole::kRDMASenderCoordinator, -1}; + } else { + return {WarpRole::kNVLReceivers, (warp_id + channel_id - kNumDispatchRDMASenderWarps) % NUM_MAX_NVL_PEERS}; + } + }(); + auto warp_role = role_meta.first; + auto target_rank = role_meta.second; // Not applicable for RDMA senders + EP_DEVICE_ASSERT(num_warps == kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS); + + // Data checks + EP_DEVICE_ASSERT(num_topk <= 32); + + // RDMA symmetric layout + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers"); + auto hidden_bytes = hidden_int4 * sizeof(int4); + auto scale_bytes = num_scales * sizeof(float); + auto num_bytes_per_token = get_num_bytes_per_token(hidden_int4, num_scales, num_topk, num_topk); + auto rdma_channel_data = SymBuffer( + rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_token, kNumRDMARanks, channel_id, num_channels); + auto rdma_channel_meta = SymBuffer(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, channel_id, num_channels); + auto rdma_channel_head = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); + auto rdma_channel_tail = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); + + // Using signals for RDMA head and tail instead with gin_signals + auto gin_signals_head = 
signals_base + kNumRDMARanks * channel_id; // move the signals to the corresponding channel + auto gin_signals_tail = signals_base + kNumRDMARanks * num_channels + + kNumRDMARanks * channel_id; // move the signals to the corresponding channel after passing all head signals + + // Use a diff GIN context and window for each channel/SM + ncclGin net(dev_comm, channel_id); + ncclTeam world = ncclTeamWorld(dev_comm); + // Using different communicator for reading/writing head pointers + ncclGin net_head(dev_comm, channel_id + num_channels); + ncclTeam world_head = ncclTeamWorld(dev_comm); + + + // NVL buffer layouts + // NOTES: `rs_wr_buffer_ptr` means "Read for Senders, Write for Receivers", `ws_rr_buffer_ptr` means "Write for Senders, Read for + // Receivers" + void *rs_wr_buffer_ptr = nullptr, *ws_rr_buffer_ptr = nullptr; + int rs_wr_rank = 0, ws_rr_rank = 0; + if (warp_role == WarpRole::kRDMAAndNVLForwarder) + rs_wr_buffer_ptr = buffer_ptrs[nvl_rank], ws_rr_buffer_ptr = buffer_ptrs[target_rank], rs_wr_rank = nvl_rank, + ws_rr_rank = target_rank; + if (warp_role == WarpRole::kNVLReceivers) + rs_wr_buffer_ptr = buffer_ptrs[target_rank], ws_rr_buffer_ptr = buffer_ptrs[nvl_rank], rs_wr_rank = target_rank, + ws_rr_rank = nvl_rank; + + // Allocate buffers + auto nvl_channel_x = AsymBuffer(ws_rr_buffer_ptr, + num_max_nvl_chunked_recv_tokens * num_bytes_per_token, + NUM_MAX_NVL_PEERS, + channel_id, + num_channels, + rs_wr_rank) + .advance_also(rs_wr_buffer_ptr); + auto nvl_channel_prefix_start = + AsymBuffer(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank) + .advance_also(rs_wr_buffer_ptr); + auto nvl_channel_prefix_end = AsymBuffer(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank) + .advance_also(rs_wr_buffer_ptr); + auto nvl_channel_head = + AsymBuffer(rs_wr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, ws_rr_rank).advance_also(ws_rr_buffer_ptr); + auto nvl_channel_tail = + AsymBuffer(ws_rr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); + + // RDMA sender warp synchronization + // NOTES: `rdma_send_channel_tail` means the latest released tail + // NOTES: `rdma_send_channel_window` means the ongoing 32 transactions' status + __shared__ int rdma_send_channel_lock[kNumRDMARanks]; + __shared__ int rdma_send_channel_tail[kNumRDMARanks]; + __shared__ uint32_t rdma_send_channel_window[kNumRDMARanks]; + auto sync_rdma_sender_smem = []() { asm volatile("barrier.sync 0, %0;" ::"r"((kNumDispatchRDMASenderWarps + 1) * 32)); }; + + // TMA stuffs + extern __shared__ __align__(1024) uint8_t smem_tma_buffer[]; + auto tma_buffer = smem_tma_buffer + target_rank * kNumTMABytesPerWarp; + auto tma_mbarrier = reinterpret_cast(tma_buffer + num_bytes_per_token); + uint32_t tma_phase = 0; + if ((warp_role == WarpRole::kRDMAAndNVLForwarder or warp_role == WarpRole::kNVLReceivers) and elect_one_sync()) { + mbarrier_init(tma_mbarrier, 1); + fence_barrier_init(); + EP_DEVICE_ASSERT(num_bytes_per_token + sizeof(uint64_t) <= kNumTMABytesPerWarp); + } + __syncwarp(); + + // Forward warp synchronization + __shared__ volatile int forward_channel_head[NUM_MAX_NVL_PEERS][kNumRDMARanks]; + __shared__ volatile bool forward_channel_retired[NUM_MAX_NVL_PEERS]; + auto sync_forwarder_smem = []() { asm volatile("barrier.sync 1, %0;" ::"r"((NUM_MAX_NVL_PEERS + 1) * 32)); }; + + if (warp_role == WarpRole::kRDMASender) { + // Get tasks + int token_start_idx, token_end_idx; + 
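// NOTES (illustrative sketch): `get_channel_task_range` is defined elsewhere in this patch; the
// sender warps only rely on it producing contiguous, non-overlapping [start, end) ranges that
// together cover all tokens. A ceil-div split satisfying that contract could look like the
// hypothetical helper below.
#include <algorithm>

void channel_task_range(int num_tokens, int num_channels, int channel_id, int& token_start_idx, int& token_end_idx) {
    const int tokens_per_channel = (num_tokens + num_channels - 1) / num_channels;  // ceil_div
    token_start_idx = std::min(tokens_per_channel * channel_id, num_tokens);
    token_end_idx = std::min(token_start_idx + tokens_per_channel, num_tokens);
}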
get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx); + + // Send number of tokens in this channel by `-value - 1` + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= 32, "Invalid number of NVL peers"); + for (int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) { + auto dst_ptr = + dst_rdma_rank == rdma_rank ? rdma_channel_meta.recv_buffer(dst_rdma_rank) : rdma_channel_meta.send_buffer(dst_rdma_rank); + if (lane_id < NUM_MAX_NVL_PEERS) { + dst_ptr[lane_id] = + -(channel_id == 0 + ? 0 + : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - + 1; + } else if (lane_id < NUM_MAX_NVL_PEERS * 2) { + dst_ptr[lane_id] = + -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + + channel_id] - + 1; + } else if (lane_id == NUM_MAX_NVL_PEERS * 2) { + dst_ptr[lane_id] = -(channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1; + } else if (lane_id == NUM_MAX_NVL_PEERS * 2 + 1) { + dst_ptr[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1; + } + __syncwarp(); + + // Issue RDMA for non-local ranks + if (dst_rdma_rank != rdma_rank) { + // kRDMASender: These are channel-specific routing metadata + int dst_rank = translate_dst_rdma_rank(dst_rdma_rank, nvl_rank); + size_t src_offset = + reinterpret_cast(rdma_channel_meta.send_buffer(dst_rdma_rank)) - reinterpret_cast(gin_base_ptr); + size_t dst_offset = + reinterpret_cast(rdma_channel_meta.recv_buffer(rdma_rank)) - reinterpret_cast(gin_base_ptr); + size_t bytes = sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2); + + net.put(world, + dst_rank, + nccl_win, + dst_offset, + nccl_win, + src_offset, + bytes, + ncclGin_None{}, // no signal + ncclGin_None{}, // no counter + ncclCoopWarp()); + } + } + sync_rdma_sender_smem(); + + // Iterate over tokens and copy into buffer + int64_t token_idx; + int cached_rdma_channel_head = 0, global_rdma_tail_idx = 0; + auto send_buffer = lane_id == rdma_rank ? rdma_channel_data.recv_buffer(lane_id) : rdma_channel_data.send_buffer(lane_id); + for (token_idx = token_start_idx; token_idx < token_end_idx; ++token_idx) { + // Read RDMA rank existence + uint64_t is_token_in_rank_uint64 = 0; + if (lane_id < kNumRDMARanks) { + is_token_in_rank_uint64 = + __ldg(reinterpret_cast(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS)); + global_rdma_tail_idx += (is_token_in_rank_uint64 != 0); + } + __syncwarp(); + + // Skip the token which does not belong to this warp + if ((token_idx - token_start_idx) % kNumDispatchRDMASenderWarps != warp_id) + continue; + auto rdma_tail_idx = is_token_in_rank_uint64 == 0 ? 
-1 : global_rdma_tail_idx - 1; + + // Wait the remote buffer to be released + auto start_time = clock64(); + while (is_token_in_rank_uint64 != 0 and rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) { + // kRDMASender: Check available space with head pointers to avoid overflow + auto signal_id = gin_signals_head + lane_id; + uint64_t signal_value = net_head.readSignal(signal_id); + cached_rdma_channel_head = static_cast(signal_value); + + // Timeout check + if (clock64() - start_time >= NUM_TIMEOUT_CYCLES) { + printf("DeepEP dispatch RDMA sender timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA lane: %d, head: %d, tail: %d\n", + channel_id, + rdma_rank, + nvl_rank, + lane_id, + cached_rdma_channel_head, + rdma_tail_idx); + trap(); + } + } + __syncwarp(); + + // Store RDMA head for combine + if (lane_id < kNumRDMARanks and not kCachedMode) + send_rdma_head[token_idx * kNumRDMARanks + lane_id] = rdma_tail_idx; + + // Broadcast tails + SourceMeta src_meta; + int num_topk_ranks = 0, topk_ranks[kNumTopkRDMARanks]; + void* dst_send_buffers[kNumTopkRDMARanks]; + #pragma unroll + for (int i = 0, slot_idx; i < kNumRDMARanks; ++i) + if ((slot_idx = __shfl_sync(0xffffffff, rdma_tail_idx, i)) >= 0) { + slot_idx = slot_idx % num_max_rdma_chunked_recv_tokens; + topk_ranks[num_topk_ranks] = i; + auto recv_is_token_in_rank_uint64 = broadcast(is_token_in_rank_uint64, i); + auto recv_is_token_in_rank_values = reinterpret_cast(&recv_is_token_in_rank_uint64); + if (lane_id == num_topk_ranks) + src_meta = SourceMeta(rdma_rank, recv_is_token_in_rank_values); + dst_send_buffers[num_topk_ranks++] = + reinterpret_cast(broadcast(send_buffer, i)) + slot_idx * num_bytes_per_token; + } + EP_DEVICE_ASSERT(num_topk_ranks <= kNumTopkRDMARanks); + + // Copy `x` into symmetric send buffer + auto st_broadcast = [=](const int key, const int4& value) { + #pragma unroll + for (int j = 0; j < num_topk_ranks; ++j) + st_na_global(reinterpret_cast(dst_send_buffers[j]) + key, value); + }; + UNROLLED_WARP_COPY(5, lane_id, hidden_int4, 0, x + token_idx * hidden_int4, ld_nc_global, st_broadcast); + #pragma unroll + for (int i = 0; i < num_topk_ranks; ++i) + dst_send_buffers[i] = reinterpret_cast(dst_send_buffers[i]) + hidden_int4; + + // Copy `x_scales` into symmetric send buffer + #pragma unroll + for (int i = lane_id; i < num_scales; i += 32) { + auto offset = token_idx * scale_token_stride + i * scale_hidden_stride; + auto value = ld_nc_global(x_scales + offset); + #pragma unroll + for (int j = 0; j < num_topk_ranks; ++j) + st_na_global(reinterpret_cast(dst_send_buffers[j]) + i, value); + } + #pragma unroll + for (int i = 0; i < num_topk_ranks; ++i) + dst_send_buffers[i] = reinterpret_cast(dst_send_buffers[i]) + num_scales; + + // Copy source metadata into symmetric send buffer + if (lane_id < num_topk_ranks) + st_na_global(reinterpret_cast(dst_send_buffers[lane_id]), src_meta); + #pragma unroll + for (int i = 0; i < num_topk_ranks; ++i) + dst_send_buffers[i] = reinterpret_cast(dst_send_buffers[i]) + 1; + + // Copy `topk_idx` and `topk_weights` into symmetric send buffer + #pragma unroll + for (int i = lane_id; i < num_topk * num_topk_ranks; i += 32) { + auto rank_idx = i / num_topk, copy_idx = i % num_topk; + auto idx_value = static_cast(ld_nc_global(topk_idx + token_idx * num_topk + copy_idx)); + auto weight_value = ld_nc_global(topk_weights + token_idx * num_topk + copy_idx); + st_na_global(reinterpret_cast(dst_send_buffers[rank_idx]) + copy_idx, idx_value); + 
st_na_global(reinterpret_cast(dst_send_buffers[rank_idx]) + num_topk + copy_idx, weight_value); + } + __syncwarp(); + + // Release the transaction in the window + if (is_token_in_rank_uint64 != 0) { + // Acquire lock first + acquire_lock(rdma_send_channel_lock + lane_id); + auto latest_tail = rdma_send_channel_tail[lane_id]; + auto offset = rdma_tail_idx - latest_tail; + while (offset >= 32) { + release_lock(rdma_send_channel_lock + lane_id); + acquire_lock(rdma_send_channel_lock + lane_id); + latest_tail = rdma_send_channel_tail[lane_id]; + offset = rdma_tail_idx - latest_tail; + } + + // Release the transaction slot + // Add the bit and move the ones if possible + auto window = rdma_send_channel_window[lane_id] | (1u << offset); + if (offset == 0) { + auto num_empty_slots = (~window) == 0 ? 32 : __ffs(~window) - 1; + st_release_cta(rdma_send_channel_tail + lane_id, latest_tail + num_empty_slots); + window >>= num_empty_slots; + } + rdma_send_channel_window[lane_id] = window; + + // Release lock + release_lock(rdma_send_channel_lock + lane_id); + } + __syncwarp(); + } + } else if (warp_role == WarpRole::kRDMASenderCoordinator) { + // NOTES: in case of splitting, the issued put at the end of the buffer + EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0); + + // Clean shared memory + EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA ranks"); + (lane_id < kNumRDMARanks) ? (rdma_send_channel_lock[lane_id] = 0) : 0; + (lane_id < kNumRDMARanks) ? (rdma_send_channel_tail[lane_id] = 0) : 0; + (lane_id < kNumRDMARanks) ? (rdma_send_channel_window[lane_id] = 0) : 0; + + // Synchronize shared memory + sync_rdma_sender_smem(); + + // Get number of tokens to send for each RDMA rank + int num_tokens_to_send = 0; + if (lane_id < kNumRDMARanks) { + num_tokens_to_send = rdma_channel_prefix_matrix[lane_id * num_channels + channel_id]; + if (channel_id > 0) + num_tokens_to_send -= rdma_channel_prefix_matrix[lane_id * num_channels + channel_id - 1]; + } + + // Iterate all RDMA ranks + int last_issued_tail = 0; + auto start_time = clock64(); + while (__any_sync(0xffffffff, num_tokens_to_send > 0)) { + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) { + printf("DeepEP RDMA sender coordinator timeout, channel: %d, IB: %d, nvl %d, dst IB: %d, tail: %d, remaining: %d\n", + channel_id, + rdma_rank, + nvl_rank, + lane_id, + last_issued_tail, + num_tokens_to_send); + trap(); + } + + // TODO: try thread-level `put_nbi`? 
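// NOTES (illustrative host-side sketch): the loop below only issues a put once a destination rank
// has either drained all of its remaining tokens or accumulated at least one full chunk in the
// send buffer. A scalar model of that decision, with the same variable roles as the device code
// (`tokens_to_issue` is a hypothetical name):
#include <algorithm>

// Returns how many tokens to issue now, or 0 if no chunk is ready for this destination yet.
int tokens_to_issue(int processed_tail, int last_issued_tail, int remaining_to_send, int chunk_size) {
    const int processed = processed_tail - last_issued_tail;    // copied into the send buffer, not yet sent
    const bool is_final_partial_chunk = (processed == remaining_to_send);
    if (!is_final_partial_chunk && processed < chunk_size)
        return 0;                                               // keep accumulating until a full chunk is ready
    return std::min(processed, chunk_size);                     // mirrors `min(num_tokens_processed, num_max_rdma_chunked_send_tokens)`
}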
+ for (int i = 0, synced_num_tokens_to_send; i < kNumRDMARanks; ++i) { + // To mitigate incast congestion, shuffle the starting index of target rank for different ranks and channels + int dst_rdma_rank = (i + channel_id + rdma_rank) % kNumRDMARanks; + synced_num_tokens_to_send = __shfl_sync(0xffffffff, num_tokens_to_send, dst_rdma_rank); + if (synced_num_tokens_to_send == 0) + continue; + + // Read the latest progress + // NOTES: `rdma_send_channel_tail` does not need to be protected by lock + auto processed_tail = + __shfl_sync(0xffffffff, ld_acquire_cta(const_cast(rdma_send_channel_tail + dst_rdma_rank)), 0); + auto synced_last_issued_tail = __shfl_sync(0xffffffff, last_issued_tail, dst_rdma_rank); + auto num_tokens_processed = processed_tail - synced_last_issued_tail; + if (num_tokens_processed != synced_num_tokens_to_send and num_tokens_processed < num_max_rdma_chunked_send_tokens) + continue; + + // Issue RDMA send + auto num_tokens_to_issue = min(num_tokens_processed, num_max_rdma_chunked_send_tokens); + EP_DEVICE_ASSERT(num_tokens_to_issue >= 0 and num_tokens_to_issue <= synced_num_tokens_to_send); + if (dst_rdma_rank != rdma_rank) { + auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens; + EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens); + const size_t num_bytes_per_msg = num_bytes_per_token * num_tokens_to_issue; + const auto dst_ptr = + reinterpret_cast(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_token); + const auto src_ptr = + reinterpret_cast(rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_token); + // kRDMASenderCoordinator: Send tokens to remote RDMA ranks + int dst_rank = translate_dst_rdma_rank(dst_rdma_rank, nvl_rank); + size_t src_offset = + reinterpret_cast(rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_token) - + reinterpret_cast(gin_base_ptr); + size_t dst_offset = + reinterpret_cast(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_token) - + reinterpret_cast(gin_base_ptr); + net.put(world, + dst_rank, + nccl_win, + dst_offset, + nccl_win, + src_offset, + num_bytes_per_msg, + ncclGin_None{}, // no signal + ncclGin_None{}, // no counter + ncclCoopWarp()); + } else { + // Lighter fence for local RDMA rank + memory_fence(); + } + __syncwarp(); + + // Update tails + if (lane_id == dst_rdma_rank) { + last_issued_tail += num_tokens_to_issue; + num_tokens_to_send -= num_tokens_to_issue; + // kRDMASenderCoordinator:Update tails + auto dst_rank = translate_dst_rdma_rank(dst_rdma_rank, nvl_rank); + auto signal_id = gin_signals_tail + rdma_rank; + + net.signal(world, // team + dst_rank, // destination rank + ncclGin_SignalAdd{signal_id, (uint64_t)num_tokens_to_issue}, // signal + value + ncclCoopThread(), // cooperation scope (default) + ncclGin_None{}, // no descriptor (default) + cuda::thread_scope_thread, // alreadyReleased (default) + cuda::thread_scope_thread // expected_scope (default) + ); + } + __syncwarp(); + } + } + } else if (warp_role == WarpRole::kRDMAAndNVLForwarder) { + // RDMA consumers and NVL producers + const auto dst_nvl_rank = target_rank; + + // Wait counters to arrive + int num_tokens_to_recv_from_rdma = 0, src_rdma_channel_prefix = 0; + EP_DEVICE_ASSERT(kNumRDMARanks <= 32); + auto start_time = clock64(); + if (lane_id < kNumRDMARanks) { + while (true) { + auto meta_0 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + dst_nvl_rank); + auto meta_1 = 
ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS + dst_nvl_rank); + auto meta_2 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2); + auto meta_3 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2 + 1); + if (meta_0 < 0 and meta_1 < 0 and meta_2 < 0 and meta_3 < 0) { + // Notify NVL ranks + int start_sum = -meta_0 - 1, end_sum = -meta_1 - 1; + EP_DEVICE_ASSERT(start_sum >= 0 and end_sum >= 0 and end_sum >= start_sum); + st_relaxed_sys_global(nvl_channel_prefix_start.buffer() + lane_id, -start_sum - 1); + st_relaxed_sys_global(nvl_channel_prefix_end.buffer() + lane_id, -end_sum - 1); + + // Save RDMA channel received token count + src_rdma_channel_prefix = -meta_2 - 1; + auto src_rdma_channel_prefix_1 = -meta_3 - 1; + num_tokens_to_recv_from_rdma = src_rdma_channel_prefix_1 - src_rdma_channel_prefix; + if (not kCachedMode) + recv_rdma_channel_prefix_matrix[lane_id * num_channels + channel_id] = src_rdma_channel_prefix_1; + src_rdma_channel_prefix += lane_id == 0 ? 0 : recv_rdma_rank_prefix_sum[lane_id - 1]; + EP_DEVICE_ASSERT(num_tokens_to_recv_from_rdma >= 0); + break; + } + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { + printf( + "DeepEP dispatch forwarder timeout (RDMA meta), channel: %d, RDMA: %d, nvl: %d, src RDMA lane: %d, dst NVL: %d, " + "meta: %d, %d, %d, %d\n", + channel_id, + rdma_rank, + nvl_rank, + lane_id, + dst_nvl_rank, + meta_0, + meta_1, + meta_2, + meta_3); + trap(); + } + } + } + __syncwarp(); + + // Shift cached head + send_nvl_head += src_rdma_channel_prefix * NUM_MAX_NVL_PEERS + dst_nvl_rank; + + // Wait shared memory to be cleaned + sync_forwarder_smem(); + + // Forward tokens from RDMA buffer + // NOTES: always start from the local rank + int src_rdma_rank = sm_id % kNumRDMARanks; + int cached_rdma_channel_head = 0, cached_rdma_channel_tail = 0; + int cached_nvl_channel_head = 0, cached_nvl_channel_tail = 0, rdma_nvl_token_idx = 0; + while (__any_sync(0xffffffff, num_tokens_to_recv_from_rdma > 0)) { + // Check destination queue emptiness, or wait a buffer to be released + start_time = clock64(); + while (true) { + const int num_used_slots = cached_nvl_channel_tail - cached_nvl_channel_head; + if (num_max_nvl_chunked_recv_tokens - num_used_slots >= num_max_nvl_chunked_send_tokens) + break; + cached_nvl_channel_head = __shfl_sync(0xffffffffu, ld_volatile_global(nvl_channel_head.buffer()), 0); + + // Timeout check + if (elect_one_sync() and clock64() - start_time > NUM_TIMEOUT_CYCLES) { + printf( + "DeepEP dispatch forwarder timeout (NVL check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, head: %d, tail: %d\n", + channel_id, + rdma_rank, + nvl_rank, + dst_nvl_rank, + ld_volatile_global(nvl_channel_head.buffer()), + cached_nvl_channel_tail); + trap(); + } + } + + // Find next source RDMA rank (round-robin) + start_time = clock64(); + while (true) { + src_rdma_rank = (src_rdma_rank + 1) % kNumRDMARanks; + if (__shfl_sync(0xffffffff, num_tokens_to_recv_from_rdma, src_rdma_rank) > 0) { + if (lane_id == src_rdma_rank and cached_rdma_channel_head == cached_rdma_channel_tail) { + // kRDMAAndNVLForwarder: Read local tail for availability + auto signal_id = gin_signals_tail + src_rdma_rank; + uint64_t signal_value = net.readSignal(signal_id); + cached_rdma_channel_tail = static_cast(signal_value); + } + + if (__shfl_sync(0xffffffff, cached_rdma_channel_tail > cached_rdma_channel_head, src_rdma_rank)) + break; + } + + // Timeout check + if (clock64() - 
start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) { + printf( + "DeepEP dispatch forwarder timeout (RDMA check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, src RDMA lane: %d, " + "head: %d, tail: %d, expected: %d\n", + channel_id, + rdma_rank, + nvl_rank, + dst_nvl_rank, + lane_id, + cached_rdma_channel_head, + cached_rdma_channel_tail, + num_tokens_to_recv_from_rdma); + trap(); + } + } + auto src_rdma_head = __shfl_sync(0xffffffff, cached_rdma_channel_head, src_rdma_rank); + auto src_rdma_tail = __shfl_sync(0xffffffff, cached_rdma_channel_tail, src_rdma_rank); + + // Iterate over every token from the RDMA buffer + for (int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++i) { + auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens; + auto shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_token; + auto src_meta = ld_nc_global(reinterpret_cast(shifted + hidden_bytes + scale_bytes)); + lane_id == src_rdma_rank ? (num_tokens_to_recv_from_rdma -= 1) : 0; + bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank); + if (lane_id == src_rdma_rank) { + auto cached_head = is_in_dst_nvl_rank ? rdma_nvl_token_idx : -1; + rdma_nvl_token_idx += is_in_dst_nvl_rank; + if (not kCachedMode) + send_nvl_head[i * NUM_MAX_NVL_PEERS] = cached_head; + } + if (not is_in_dst_nvl_rank) + continue; + + // Get an empty slot + int dst_slot_idx = (cached_nvl_channel_tail++) % num_max_nvl_chunked_recv_tokens; + auto dst_shifted = nvl_channel_x.buffer() + dst_slot_idx * num_bytes_per_token; + + // Copy data + if (elect_one_sync()) { + tma_load_1d(tma_buffer, shifted, tma_mbarrier, num_bytes_per_token, false); + mbarrier_arrive_and_expect_tx(tma_mbarrier, num_bytes_per_token); + } + __syncwarp(); + mbarrier_wait(tma_mbarrier, tma_phase); + if (elect_one_sync()) + tma_store_1d(tma_buffer, dst_shifted, num_bytes_per_token); + __syncwarp(); + + // In case of insufficient NVL buffers, early stopping + if ((++num_tokens_sent) == num_max_nvl_chunked_send_tokens) + src_rdma_tail = i + 1; + + // Wait TMA to be finished + tma_store_wait<0>(); + __syncwarp(); + } + + // Sync head index + if (lane_id == src_rdma_rank) + forward_channel_head[dst_nvl_rank][src_rdma_rank] = (cached_rdma_channel_head = src_rdma_tail); + + // Move tail index + __syncwarp(); + if (elect_one_sync()) + st_release_sys_global(nvl_channel_tail.buffer(), cached_nvl_channel_tail); + } + + // Retired + __syncwarp(); + if (elect_one_sync()) + forward_channel_retired[dst_nvl_rank] = true; + } else if (warp_role == WarpRole::kForwarderCoordinator) { + // Extra warps for forwarder coordinator should exit directly + if (target_rank > 0) + return; + + // Forward warp coordinator + EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers"); + + // Clean shared memory + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Invalid number of NVL peers"); + #pragma unroll + for (int i = lane_id; i < kNumRDMARanks * NUM_MAX_NVL_PEERS; i += 32) + forward_channel_head[i % NUM_MAX_NVL_PEERS][i / NUM_MAX_NVL_PEERS] = 0; + if (lane_id < NUM_MAX_NVL_PEERS) + forward_channel_retired[lane_id] = false; + sync_forwarder_smem(); + + int last_head = 0, target_rdma = lane_id < kNumRDMARanks ? 
lane_id : 0; + while (true) { + // Find minimum head + int min_head = std::numeric_limits::max(); + #pragma unroll + for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i) + if (not forward_channel_retired[i]) + min_head = min(min_head, forward_channel_head[i][target_rdma]); + if (__all_sync(0xffffffff, min_head == std::numeric_limits::max())) + break; + + // Update remote head + if (min_head != std::numeric_limits::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and + lane_id < kNumRDMARanks) { + // kForwarderCoordinator: Update remote head + auto dst_rank = translate_dst_rdma_rank(lane_id, nvl_rank); + auto signal_id = gin_signals_head + rdma_rank; + net_head.signal(world_head, // team + dst_rank, // destination rank + ncclGin_SignalAdd{signal_id, (uint64_t)min_head - (uint64_t)last_head}, // signal + value + ncclCoopThread(), // cooperation scope (default) + ncclGin_None{}, // no descriptor (default) + cuda::thread_scope_thread, // alreadyReleased (default) + cuda::thread_scope_thread // expected_scope (default) + ); + last_head = min_head; + } + + // Nanosleep and let other warps work + __nanosleep(NUM_WAIT_NANOSECONDS); + } + } else { + // NVL consumers + // Retrieve rank offset from barrier results (each lane's register stores an RDMA rank) + int src_nvl_rank = target_rank, total_offset = 0; + const int local_expert_begin = rank * (num_experts / num_ranks); + const int local_expert_end = local_expert_begin + (num_experts / num_ranks); + + EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers"); + if (lane_id < kNumRDMARanks and lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank > 0) + total_offset = recv_gbl_rank_prefix_sum[lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank - 1]; + + // Receive channel offsets + int start_offset = 0, end_offset = 0, num_tokens_to_recv; + auto start_time = clock64(); + while (lane_id < kNumRDMARanks) { + start_offset = ld_volatile_global(nvl_channel_prefix_start.buffer() + lane_id); + end_offset = ld_volatile_global(nvl_channel_prefix_end.buffer() + lane_id); + if (start_offset < 0 and end_offset < 0) { + start_offset = -start_offset - 1, end_offset = -end_offset - 1; + total_offset += start_offset; + break; + } + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { + printf( + "DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, src nvl: %d, start: %d, end: %d\n", + channel_id, + rdma_rank, + nvl_rank, + lane_id, + src_nvl_rank, + start_offset, + end_offset); + trap(); + } + } + num_tokens_to_recv = warp_reduce_sum(end_offset - start_offset); + + // Save for combine usage + if (lane_id < kNumRDMARanks and not kCachedMode) + recv_gbl_channel_prefix_matrix[(lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank) * num_channels + channel_id] = total_offset; + __syncwarp(); + + int cached_channel_head_idx = 0, cached_channel_tail_idx = 0; + while (num_tokens_to_recv > 0) { + // Check channel status by lane 0 + start_time = clock64(); + while (true) { + // Ready to copy + if (cached_channel_head_idx != cached_channel_tail_idx) + break; + cached_channel_tail_idx = __shfl_sync(0xffffffff, ld_acquire_sys_global(nvl_channel_tail.buffer()), 0); + + // Timeout check + if (elect_one_sync() and clock64() - start_time > NUM_TIMEOUT_CYCLES) { + printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, head: %d, tail: %d\n", + channel_id, + rdma_rank, + nvl_rank, + src_nvl_rank, + cached_channel_head_idx, + cached_channel_tail_idx); + trap(); + } + } + + // Copy data + int num_recv_tokens 
= cached_channel_tail_idx - cached_channel_head_idx; + for (int chunk_idx = 0; chunk_idx < num_recv_tokens; ++chunk_idx, --num_tokens_to_recv) { + int token_idx_in_buffer = (cached_channel_head_idx++) % num_max_nvl_chunked_recv_tokens; + auto shifted = nvl_channel_x.buffer() + token_idx_in_buffer * num_bytes_per_token; + auto meta = ld_nc_global(reinterpret_cast(shifted + hidden_bytes + scale_bytes)); + int64_t recv_token_idx = __shfl_sync(0xffffffff, total_offset, meta.src_rdma_rank); + (lane_id == meta.src_rdma_rank) ? (total_offset += 1) : 0; + + bool scale_aligned = (scale_bytes % 16 == 0); + auto tma_load_bytes = hidden_bytes + (scale_aligned ? scale_bytes : 0); + + // Copy data + if (elect_one_sync()) { + tma_load_1d(tma_buffer, shifted, tma_mbarrier, tma_load_bytes); + mbarrier_arrive_and_expect_tx(tma_mbarrier, tma_load_bytes); + } + __syncwarp(); + mbarrier_wait(tma_mbarrier, tma_phase); + if (elect_one_sync()) { + tma_store_1d(tma_buffer, recv_x + recv_token_idx * hidden_int4, hidden_bytes, false); + if (scale_aligned) + tma_store_1d(tma_buffer + hidden_bytes, recv_x_scales + recv_token_idx * num_scales, scale_bytes, false); + } + __syncwarp(); + shifted += hidden_bytes; + + // Copy scales + // TODO: make it as templated + if (not scale_aligned) { + UNROLLED_WARP_COPY(1, + lane_id, + num_scales, + recv_x_scales + recv_token_idx * num_scales, + reinterpret_cast(shifted), + ld_nc_global, + st_na_global); + } + shifted += scale_bytes; + + // Copy source meta + if (not kCachedMode and elect_one_sync()) + st_na_global(recv_src_meta + recv_token_idx, meta); + shifted += sizeof(SourceMeta); + + // Copy `topk_idx` and `topk_weights` + if (lane_id < num_topk) { + // Read + auto idx_value = static_cast(ld_nc_global(reinterpret_cast(shifted) + lane_id)); + auto weight_value = ld_nc_global(reinterpret_cast(shifted + sizeof(int) * num_topk) + lane_id); + auto recv_idx = recv_token_idx * num_topk + lane_id; + + // Transform and write + idx_value = (idx_value >= local_expert_begin and idx_value < local_expert_end) ? idx_value - local_expert_begin : -1; + weight_value = idx_value >= 0 ? 
weight_value : 0.0f; + st_na_global(recv_topk_idx + recv_idx, idx_value); + st_na_global(recv_topk_weights + recv_idx, weight_value); + } + + // Wait TMA to be finished + tma_store_wait<0>(); + __syncwarp(); + } + + // Move queue + if (elect_one_sync()) + st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx); + } + } + + // Clean unused `recv_topk_idx` as -1 + if (num_worst_tokens > 0) { + if (is_forwarder) + return; + // get the actual number of num_recv_tokens on the current rank + int num_recv_tokens = recv_gbl_rank_prefix_sum[num_ranks - 1]; + // some ForwarderCoordinator threads exit early, so we only use non-forwarder in clean-up + // channel_id * num_threads is the offset of the current non-forwarder sms + const auto clean_start = num_recv_tokens * num_topk + channel_id * num_threads; + const auto clean_end = num_worst_tokens * num_topk; + const auto clean_stride = num_channels * num_threads; + #pragma unroll + for (int i = clean_start + thread_id; i < clean_end; i += clean_stride) + recv_topk_idx[i] = -1; + } +} + +void dispatch(void* recv_x, + float* recv_x_scales, + topk_idx_t* recv_topk_idx, + float* recv_topk_weights, + void* recv_src_meta, + const void* x, + const float* x_scales, + const topk_idx_t* topk_idx, + const float* topk_weights, + int* send_rdma_head, + int* send_nvl_head, + int* recv_rdma_channel_prefix_matrix, + int* recv_gbl_channel_prefix_matrix, + const int* rdma_channel_prefix_matrix, + const int* recv_rdma_rank_prefix_sum, + const int* gbl_channel_prefix_matrix, + const int* recv_gbl_rank_prefix_sum, + const bool* is_token_in_rank, + int num_tokens, + int num_worst_tokens, + int hidden_int4, + int num_scales, + int num_topk, + int num_experts, + int scale_token_stride, + int scale_hidden_stride, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_send_tokens, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_send_tokens, + int num_max_nvl_chunked_recv_tokens, + int rank, + int num_ranks, + bool is_cached_dispatch, + cudaStream_t stream, + int num_channels, + bool low_latency_mode, + void* gin_base_ptr, + ncclDevComm dev_comm, + ncclWindow_t nccl_win, + unsigned signals_base) +{ + constexpr int kNumDispatchRDMASenderWarps = 7; + constexpr int kNumTMABytesPerWarp = 16384; + constexpr int smem_size = kNumTMABytesPerWarp * NUM_MAX_NVL_PEERS; + + // Make sure never OOB + EP_HOST_ASSERT(static_cast(num_scales) * scale_hidden_stride < std::numeric_limits::max()); + +#define DISPATCH_LAUNCH_CASE(num_rdma_ranks) \ + { \ + auto dispatch_func = \ + low_latency_mode ? \ + (is_cached_dispatch ? \ + dispatch : \ + dispatch) : \ + (is_cached_dispatch ? 
\ + dispatch : \ + dispatch); \ + SET_SHARED_MEMORY_FOR_TMA(dispatch_func); \ + LAUNCH_KERNEL(&cfg, \ + dispatch_func, \ + reinterpret_cast(recv_x), \ + recv_x_scales, \ + recv_topk_idx, \ + recv_topk_weights, \ + reinterpret_cast(recv_src_meta), \ + reinterpret_cast(x), \ + x_scales, \ + topk_idx, \ + topk_weights, \ + send_rdma_head, \ + send_nvl_head, \ + recv_rdma_channel_prefix_matrix, \ + recv_gbl_channel_prefix_matrix, \ + rdma_channel_prefix_matrix, \ + recv_rdma_rank_prefix_sum, \ + gbl_channel_prefix_matrix, \ + recv_gbl_rank_prefix_sum, \ + is_token_in_rank, \ + num_tokens, \ + num_worst_tokens, \ + hidden_int4, \ + num_scales, \ + num_topk, \ + num_experts, \ + scale_token_stride, \ + scale_hidden_stride, \ + rdma_buffer_ptr, \ + num_max_rdma_chunked_send_tokens, \ + num_max_rdma_chunked_recv_tokens, \ + buffer_ptrs, \ + num_max_nvl_chunked_send_tokens, \ + num_max_nvl_chunked_recv_tokens, \ + rank, \ + num_ranks, \ + gin_base_ptr, \ + dev_comm, \ + nccl_win, \ + signals_base); \ + } \ + break + + EP_HOST_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr)); + EP_HOST_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr)); + + SETUP_LAUNCH_CONFIG(num_channels * 2, (kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32, stream); + SWITCH_RDMA_RANKS(DISPATCH_LAUNCH_CASE); +#undef DISPATCH_LAUNCH_CASE +} + +template +__global__ void cached_notify(const int rdma_clean_offset, + const int rdma_num_int_clean, + const int nvl_clean_offset, + const int nvl_num_int_clean, + int* combined_rdma_head, + int num_combined_tokens, + int num_channels, + const int* rdma_channel_prefix_matrix, + const int* rdma_rank_prefix_sum, + int* combined_nvl_head, + void* rdma_buffer_ptr, + void** buffer_ptrs, + int** barrier_signal_ptrs, + int rank, + int num_ranks, + bool is_cached_dispatch, + void* gin_base_ptr, + ncclDevComm dev_comm, + ncclWindow_t nccl_win, + unsigned signals_base +) { + auto sm_id = static_cast(blockIdx.x); + auto thread_id = static_cast(threadIdx.x); + auto num_threads = static_cast(blockDim.x); + auto num_warps = num_threads / 32; + auto warp_id = thread_id / 32; + auto lane_id = get_lane_id(); + + auto nvl_rank = rank % NUM_MAX_NVL_PEERS; + auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; + + // Using two SMs, which clean the RDMA/NVL buffer respectively + if (sm_id == 0) { + for (int i = thread_id; i < dev_comm.ginConnectionCount; i += num_threads) { + ncclGin net(dev_comm, i); + net.flush(ncclCoopThread(), cuda::std::memory_order_acquire); + } + __syncthreads(); + + // Barrier for RDMA + if (thread_id == 32) + sync_with_same_gpu_idx(dev_comm); + // Barrier for NVL + barrier_block(barrier_signal_ptrs, nvl_rank); + + // Clean RDMA buffer + auto rdma_buffer_ptr_int = static_cast(rdma_buffer_ptr); + #pragma unroll + for (int i = thread_id; i < rdma_num_int_clean; i += num_threads) + rdma_buffer_ptr_int[rdma_clean_offset + i] = 0; + + // Clean signals here + // For each channel we have num_rdma_ranks head and tail signals + int num_signals = num_rdma_ranks * num_channels * 2; + EP_DEVICE_ASSERT(num_signals <= num_threads); + + // Each thread handles one specific signal across all contexts + if (thread_id < num_signals) { + auto signal_id = signals_base + thread_id; + + // Derive channel_id from signal_id + // Signal layout: [all head signals][all tail signals] + int signal_offset = thread_id; + int head_signal_count = num_rdma_ranks * num_channels; + bool is_head_signal = (signal_offset < head_signal_count); + int channel_id = is_head_signal ? 
signal_offset / num_rdma_ranks : (signal_offset - head_signal_count) / num_rdma_ranks; + + auto context_id = is_head_signal ? channel_id + num_channels : channel_id ; + ncclGin net(dev_comm, context_id); + net.resetSignal(signal_id); + } + __syncthreads(); + + // Clean NVL buffer + auto nvl_buffer_ptr_int = static_cast(buffer_ptrs[nvl_rank]); + #pragma unroll + for (int i = thread_id; i < nvl_num_int_clean; i += num_threads) + nvl_buffer_ptr_int[nvl_clean_offset + i] = 0; + __syncthreads(); + + // Barrier again + if (thread_id == 32) + sync_with_same_gpu_idx(dev_comm); + barrier_block(barrier_signal_ptrs, nvl_rank); + } else if (sm_id == 1) { + if (is_cached_dispatch) + return; + + EP_DEVICE_ASSERT(num_warps >= num_channels); + EP_DEVICE_ASSERT(num_rdma_ranks <= 32); + + // Iterate in reverse order + if (lane_id < num_rdma_ranks and warp_id < num_channels) { + int token_start_idx, token_end_idx; + get_channel_task_range(num_combined_tokens, num_channels, warp_id, token_start_idx, token_end_idx); + + // NOTES: `1 << 25` is a heuristic large number + int last_head = 1 << 25; + for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) { + auto current_head = __ldg(combined_rdma_head + token_idx * num_rdma_ranks + lane_id); + if (current_head < 0) { + combined_rdma_head[token_idx * num_rdma_ranks + lane_id] = -last_head - 1; + } else { + last_head = current_head; + } + } + } + } else { + if (is_cached_dispatch) + return; + + EP_DEVICE_ASSERT(num_warps >= num_channels); + EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and rdma_rank_prefix_sum != nullptr); + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Too many NVL peers"); + + if (warp_id < num_channels) { + constexpr int tma_batch_size = kNumTMABytesPerWarp - sizeof(uint64_t); + constexpr int num_bytes_per_token = sizeof(int) * NUM_MAX_NVL_PEERS; + constexpr int num_tokens_per_batch = tma_batch_size / num_bytes_per_token; + EP_STATIC_ASSERT(num_bytes_per_token % 16 == 0, "num_bytes_per_token should be divisible by 16"); + + // TMA stuffs + extern __shared__ __align__(1024) uint8_t smem_tma_buffer[]; + auto tma_buffer = smem_tma_buffer + warp_id * kNumTMABytesPerWarp; + auto tma_mbarrier = reinterpret_cast(tma_buffer + tma_batch_size); + uint32_t tma_phase = 0; + if (elect_one_sync()) { + mbarrier_init(tma_mbarrier, 1); + fence_barrier_init(); + } + __syncwarp(); + + for (int dst_rdma_rank = sm_id - 2; dst_rdma_rank < num_rdma_ranks; dst_rdma_rank += num_channels * 2 - 2) { + // Iterate in reverse order + int token_start_idx = warp_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id - 1]; + int token_end_idx = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id]; + int shift = dst_rdma_rank == 0 ? 
0 : rdma_rank_prefix_sum[dst_rdma_rank - 1]; + token_start_idx += shift, token_end_idx += shift; + + // NOTES: `1 << 25` is a heuristic large number + int last_head = 1 << 25; + for (int batch_end_idx = token_end_idx; batch_end_idx > token_start_idx; batch_end_idx -= num_tokens_per_batch) { + auto batch_start_idx = max(token_start_idx, batch_end_idx - num_tokens_per_batch); + + if (elect_one_sync()) { + tma_load_1d(tma_buffer, + combined_nvl_head + batch_start_idx * NUM_MAX_NVL_PEERS, + tma_mbarrier, + (batch_end_idx - batch_start_idx) * num_bytes_per_token); + mbarrier_arrive_and_expect_tx(tma_mbarrier, (batch_end_idx - batch_start_idx) * num_bytes_per_token); + } + mbarrier_wait(tma_mbarrier, tma_phase); + __syncwarp(); + + for (int token_idx = batch_end_idx - 1; token_idx >= batch_start_idx; --token_idx) { + if (lane_id < NUM_MAX_NVL_PEERS) { + auto current_head = + reinterpret_cast(tma_buffer)[(token_idx - batch_start_idx) * NUM_MAX_NVL_PEERS + lane_id]; + if (current_head < 0) { + reinterpret_cast(tma_buffer)[(token_idx - batch_start_idx) * NUM_MAX_NVL_PEERS + lane_id] = + -last_head - 1; + } else { + last_head = current_head; + } + } + } + tma_store_fence(); + __syncwarp(); + + if (elect_one_sync()) + tma_store_1d(tma_buffer, + combined_nvl_head + batch_start_idx * NUM_MAX_NVL_PEERS, + (batch_end_idx - batch_start_idx) * num_bytes_per_token); + tma_store_wait<0>(); + __syncwarp(); + } + } + } + } +} + +void cached_notify(int hidden_int4, + int num_scales, + int num_topk_idx, + int num_topk_weights, + int num_ranks, + int num_channels, + int num_combined_tokens, + int* combined_rdma_head, + const int* rdma_channel_prefix_matrix, + const int* rdma_rank_prefix_sum, + int* combined_nvl_head, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_recv_tokens, + int** barrier_signal_ptrs, + int rank, + cudaStream_t stream, + int64_t num_rdma_bytes, + int64_t num_nvl_bytes, + bool is_cached_dispatch, + bool low_latency_mode, + void* gin_base_ptr, + ncclDevComm dev_comm, + ncclWindow_t nccl_win, + unsigned signals_base) { + const int num_threads = std::max(128, 32 * num_channels); + const int num_warps = num_threads / 32; + const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; + const int kNumTMABytesPerWarp = 8192; + const int smem_size = kNumTMABytesPerWarp * num_warps; + + // Get clean meta + auto rdma_clean_meta = get_rdma_clean_meta( + hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks, num_max_rdma_chunked_recv_tokens, num_channels); + auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4, + num_scales, + num_topk_idx, + num_topk_weights, + num_rdma_ranks, + NUM_MAX_NVL_PEERS, + num_max_nvl_chunked_recv_tokens, + num_channels, + is_cached_dispatch); + EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes); + EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= num_nvl_bytes); + EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits::max()); + EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits::max()); + EP_HOST_ASSERT(num_channels * 2 > 3); + + // Launch kernel + auto cached_notify_func = low_latency_mode ? 
cached_notify : cached_notify; + SETUP_LAUNCH_CONFIG(num_channels * 2, num_threads, stream); + SET_SHARED_MEMORY_FOR_TMA(cached_notify_func); + + LAUNCH_KERNEL(&cfg, + cached_notify_func, + rdma_clean_meta.first, + rdma_clean_meta.second, + nvl_clean_meta.first, + nvl_clean_meta.second, + combined_rdma_head, + num_combined_tokens, + num_channels, + rdma_channel_prefix_matrix, + rdma_rank_prefix_sum, + combined_nvl_head, + rdma_buffer_ptr, + buffer_ptrs, + barrier_signal_ptrs, + rank, + num_ranks, + is_cached_dispatch, + gin_base_ptr, + dev_comm, + nccl_win, + signals_base); +} + +template +__device__ int combine_token(bool is_token_in_rank, + int head_idx, + int lane_id, + int hidden_int4, + int num_topk, + int4* combined_row, + float* combined_topk_weights, + const int4* bias_0_int4, + const int4* bias_1_int4, + int num_max_recv_tokens, + const GetAddrFn& get_addr_fn, + const ReceiveTWFn& recv_tw_fn, + uint8_t* smem_ptr, + uint32_t (&tma_phase)[kNumStages]) { + constexpr auto kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t); + + // Broadcast current heads + // Lane `i` holds the head of rank `i` and `is_token_in_rank` + EP_STATIC_ASSERT(kMaxNumRanks <= 32, "Too many ranks"); + int num_topk_ranks = 0, topk_ranks[kMaxNumRanks], slot_indices[kMaxNumRanks]; + #pragma unroll + for (int i = 0; i < kNumRanks; ++i) + if (__shfl_sync(0xffffffff, is_token_in_rank, i)) { + slot_indices[num_topk_ranks] = __shfl_sync(0xffffffff, head_idx, i) % num_max_recv_tokens; + topk_ranks[num_topk_ranks++] = i; + } + EP_DEVICE_ASSERT(num_topk_ranks <= kMaxNumRanks); + EP_STATIC_ASSERT(not(kUseTMA and kMaybeWithBias), "TMA cannot be used by receiver warps"); + EP_STATIC_ASSERT(kNumStages == 2, "Only support 2 stages now"); + + // Reduce data + if constexpr (kUseTMA) { + constexpr int kNumTMABufferBytesPerStage = kNumTMALoadBytes * (NUM_MAX_NVL_PEERS + 1) + 16; + EP_DEVICE_ASSERT(hidden_int4 % 32 == 0); + + auto tma_load_buffer = [=](const int& i, const int& j) -> int4* { + return reinterpret_cast(smem_ptr + i * kNumTMABufferBytesPerStage + j * kNumTMALoadBytes); + }; + auto tma_store_buffer = [=](const int& i) -> int4* { + return reinterpret_cast(smem_ptr + i * kNumTMABufferBytesPerStage + NUM_MAX_NVL_PEERS * kNumTMALoadBytes); + }; + auto tma_mbarrier = [=](const int& i) -> uint64_t* { + return reinterpret_cast(smem_ptr + i * kNumTMABufferBytesPerStage + (NUM_MAX_NVL_PEERS + 1) * kNumTMALoadBytes); + }; + + // Prefetch + if (lane_id < num_topk_ranks) + tma_load_1d( + tma_load_buffer(0, lane_id), get_addr_fn(topk_ranks[lane_id], slot_indices[lane_id], 0), tma_mbarrier(0), kNumTMALoadBytes); + mbarrier_arrive_and_expect_tx(tma_mbarrier(0), lane_id < num_topk_ranks ? kNumTMALoadBytes : 0); + __syncwarp(); + + for (int shifted = 0, iter = 0; shifted < hidden_int4; shifted += 32, iter += 1) { + const int stage_idx = iter % kNumStages; + const int next_stage_idx = (iter + 1) % kNumStages; + + // Prefetch next stage + if (shifted + 32 < hidden_int4) { + if (lane_id < num_topk_ranks) + tma_load_1d(tma_load_buffer(next_stage_idx, lane_id), + get_addr_fn(topk_ranks[lane_id], slot_indices[lane_id], shifted + 32), + tma_mbarrier(next_stage_idx), + kNumTMALoadBytes); + mbarrier_arrive_and_expect_tx(tma_mbarrier(next_stage_idx), lane_id < num_topk_ranks ? 
kNumTMALoadBytes : 0); + __syncwarp(); + } + + mbarrier_wait(tma_mbarrier(stage_idx), tma_phase[stage_idx]); + float values[kDtypePerInt4] = {0}; + #pragma unroll + for (int j = 0; j < num_topk_ranks; ++j) { + auto recv_value_dtypes = reinterpret_cast(tma_load_buffer(stage_idx, j) + lane_id); + #pragma unroll + for (int k = 0; k < kDtypePerInt4; ++k) + values[k] += static_cast(recv_value_dtypes[k]); + } + + // Wait shared memory to be released + tma_store_wait(); + + // Copy into shared and issue TMA + auto out_dtypes = reinterpret_cast(tma_store_buffer(stage_idx) + lane_id); + #pragma unroll + for (int j = 0; j < kDtypePerInt4; ++j) + out_dtypes[j] = static_cast(values[j]); + tma_store_fence(); + __syncwarp(); + + if (elect_one_sync()) + tma_store_1d(tma_store_buffer(stage_idx), combined_row + shifted, kNumTMALoadBytes); + __syncwarp(); + } + + // Flush all writes + tma_store_wait<0>(); + } else { + #pragma unroll + for (int i = lane_id; i < hidden_int4; i += 32) { + // Read bias + // TODO: make it as a finer-grained template + int4 bias_0_value_int4, bias_1_value_int4; + if constexpr (kMaybeWithBias) { + bias_0_value_int4 = bias_0_int4 != nullptr ? ld_nc_global(bias_0_int4 + i) : make_int4(0, 0, 0, 0); + bias_1_value_int4 = bias_1_int4 != nullptr ? ld_nc_global(bias_1_int4 + i) : make_int4(0, 0, 0, 0); + } + + // Read buffers + // TODO: maybe too many registers here + int4 recv_value_int4[kMaxNumRanks]; + #pragma unroll + for (int j = 0; j < num_topk_ranks; ++j) + recv_value_int4[j] = ld_nc_global(get_addr_fn(topk_ranks[j], slot_indices[j], i)); + + // Clean + // Reduce bias + float values[kDtypePerInt4] = {0}; + if constexpr (kMaybeWithBias) { + auto bias_0_values = reinterpret_cast(&bias_0_value_int4); + auto bias_1_values = reinterpret_cast(&bias_1_value_int4); + #pragma unroll + for (int j = 0; j < kDtypePerInt4; ++j) + values[j] = static_cast(bias_0_values[j]) + static_cast(bias_1_values[j]); + } + + // Reduce all-to-all results + #pragma unroll + for (int j = 0; j < num_topk_ranks; ++j) { + auto recv_value_dtypes = reinterpret_cast(&recv_value_int4[j]); + #pragma unroll + for (int k = 0; k < kDtypePerInt4; ++k) + values[k] += static_cast(recv_value_dtypes[k]); + } + + // Cast back to `dtype_t` and write + int4 out_int4; + auto out_dtypes = reinterpret_cast(&out_int4); + #pragma unroll + for (int j = 0; j < kDtypePerInt4; ++j) + out_dtypes[j] = static_cast(values[j]); + st_na_global(combined_row + i, out_int4); + } + } + + // Reduce `topk_weights` + if (lane_id < num_topk) { + float value = 0; + #pragma unroll + for (int i = 0; i < num_topk_ranks; ++i) + value += recv_tw_fn(topk_ranks[i], slot_indices[i], lane_id); + st_na_global(combined_topk_weights + lane_id, value); + } + + // Return the minimum top-k rank + return topk_ranks[0]; +} + +template 0) ? 
kNumCombineForwarderWarps / kNumRDMARanks : 1, + int kNumForwarders = kNumRDMARanks* kNumWarpsPerForwarder, + int kNumRDMAReceivers = kNumForwarders - NUM_MAX_NVL_PEERS> +__global__ void __launch_bounds__((kNumForwarders + 1) * 32, 1) combine(int4* combined_x, + float* combined_topk_weights, + const bool* is_combined_token_in_rank, + const int4* x, + const float* topk_weights, + const int4* bias_0, + const int4* bias_1, + const int* combined_rdma_head, + const int* combined_nvl_head, + const SourceMeta* src_meta, + const int* rdma_channel_prefix_matrix, + const int* rdma_rank_prefix_sum, + const int* gbl_channel_prefix_matrix, + int num_tokens, + int num_combined_tokens, + int hidden, + int num_topk, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_send_tokens, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_send_tokens, + int num_max_nvl_chunked_recv_tokens, + int rank, + int num_ranks, + void* gin_base_ptr, + ncclDevComm dev_comm, + ncclWindow_t nccl_win, + unsigned signals_base +) { + enum class WarpRole { kNVLSender, kNVLAndRDMAForwarder, kRDMAReceiver, kCoordinator }; + + const auto sm_id = static_cast(blockIdx.x); + const auto num_threads = static_cast(blockDim.x), num_warps = num_threads / 32; + const auto thread_id = static_cast(threadIdx.x), lane_id = get_lane_id(); + const auto num_channels = static_cast(gridDim.x) / 2, channel_id = sm_id / 2; + const bool is_forwarder_sm = sm_id % 2 == 1; + + EP_DEVICE_ASSERT(num_topk <= 32); + EP_DEVICE_ASSERT(hidden % (sizeof(int4) / sizeof(dtype_t)) == 0); + const auto hidden_int4 = hidden / (sizeof(int4) / sizeof(dtype_t)); + const auto hidden_bytes = hidden_int4 * sizeof(int4); + const auto num_bytes_per_token = get_num_bytes_per_token(hidden_int4, 0, 0, num_topk); + + EP_DEVICE_ASSERT(num_channels * 2 <= dev_comm.ginContextCount); + // Use a diff GIN context for each channel/SM + ncclGin net(dev_comm, channel_id); + ncclTeam world = ncclTeamWorld(dev_comm); + // Using different communicator for reading/writing head pointers + ncclGin net_head(dev_comm, channel_id + num_channels); + ncclTeam world_head = ncclTeamWorld(dev_comm); + + // NOTES: we decouple a channel into 2 SMs + const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; + auto role_meta = [=]() -> std::pair { + auto warp_id = thread_id / 32; + if (not is_forwarder_sm) { + if (warp_id < NUM_MAX_NVL_PEERS) { + auto shuffled_warp_id = warp_id; + shuffled_warp_id = (shuffled_warp_id + channel_id) % NUM_MAX_NVL_PEERS; + return {WarpRole::kNVLSender, shuffled_warp_id}; + } else if (warp_id < kNumForwarders) { + return {WarpRole::kRDMAReceiver, warp_id - NUM_MAX_NVL_PEERS}; + } else { + return {WarpRole::kCoordinator, 0}; + } + } else { + if (warp_id < kNumForwarders) { + auto shuffled_warp_id = (warp_id + channel_id) % kNumForwarders; + return {WarpRole::kNVLAndRDMAForwarder, shuffled_warp_id}; + } else { + return {WarpRole::kCoordinator, 0}; + } + } + }(); + auto warp_role = role_meta.first; + auto warp_id = role_meta.second; + + EP_DEVICE_ASSERT(num_warps == kNumForwarders + 1); + auto num_max_nvl_chunked_recv_tokens_per_rdma = num_max_nvl_chunked_recv_tokens / kNumRDMARanks; + + if (warp_role == WarpRole::kNVLSender) { + // NVL producers + const auto dst_nvl_rank = warp_id; + + // NVL layouts + // NOTES: to avoid deadlocks, we use separate NVL buffers for different RDMA sources + auto dst_buffer_ptr = buffer_ptrs[dst_nvl_rank], local_buffer_ptr = buffer_ptrs[nvl_rank]; + auto nvl_channel_x = 
AsymBuffer(dst_buffer_ptr, + num_max_nvl_chunked_recv_tokens * num_bytes_per_token, + NUM_MAX_NVL_PEERS, + channel_id, + num_channels, + nvl_rank) + .advance_also(local_buffer_ptr); + auto nvl_channel_head = AsymBuffer(local_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, dst_nvl_rank) + .advance_also(dst_buffer_ptr); + auto nvl_channel_tail = AsymBuffer(dst_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank) + .advance_also(local_buffer_ptr); + + // TMA stuffs + extern __shared__ __align__(1024) uint8_t smem_tma_buffer[]; + auto tma_buffer = smem_tma_buffer + dst_nvl_rank * kNumTMABytesPerSenderWarp; + auto tma_mbarrier = reinterpret_cast(tma_buffer + num_bytes_per_token); + uint32_t tma_phase = 0; + if (elect_one_sync()) { + mbarrier_init(tma_mbarrier, 1); + fence_barrier_init(); + EP_DEVICE_ASSERT(num_bytes_per_token + sizeof(uint64_t) <= kNumTMABytesPerSenderWarp); + } + __syncwarp(); + + // Get tasks for each RDMA lane + int token_start_idx = 0, token_end_idx = 0; + if (lane_id < kNumRDMARanks) { + int prefix_idx = (lane_id * NUM_MAX_NVL_PEERS + dst_nvl_rank) * num_channels + channel_id; + token_start_idx = gbl_channel_prefix_matrix[prefix_idx]; + token_end_idx = (prefix_idx == num_channels * num_ranks - 1) ? num_tokens : gbl_channel_prefix_matrix[prefix_idx + 1]; + } + __syncwarp(); + + // NOTES: here the cached value of each lane is only responsible for a single RDMA buffer + int cached_channel_head_idx = 0, cached_channel_tail_idx = 0; + EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers"); + + // Iterate over all tokens and send by chunks + int current_rdma_idx = channel_id % kNumRDMARanks; + while (true) { + // Exit if possible + if (__all_sync(0xffffffff, token_start_idx >= token_end_idx)) + break; + + // Decide the next RDMA buffer to send + bool is_lane_ready = false; + auto start_time = clock64(); + while (true) { + int num_used_slots = cached_channel_tail_idx - cached_channel_head_idx; + is_lane_ready = lane_id < kNumRDMARanks and token_start_idx < token_end_idx and + num_max_nvl_chunked_recv_tokens_per_rdma - num_used_slots >= num_max_nvl_chunked_send_tokens; + if (__any_sync(0xffffffff, is_lane_ready)) + break; + + // Retry + if (lane_id < kNumRDMARanks and token_start_idx < token_end_idx) + cached_channel_head_idx = ld_volatile_global(nvl_channel_head.buffer() + lane_id); + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) { + printf( + "DeepEP combine NVL sender timeout, channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, RDMA lane: %d, head: %d, tail: " + "%d, start: %d, end: %d\n", + channel_id, + rdma_rank, + nvl_rank, + dst_nvl_rank, + lane_id, + ld_volatile_global(nvl_channel_head.buffer() + lane_id), + cached_channel_tail_idx, + token_start_idx, + token_end_idx); + trap(); + } + } + + // Sync token start index and count + for (int i = 0; i < kNumRDMARanks; ++i) { + current_rdma_idx = (current_rdma_idx + 1) % kNumRDMARanks; + if (__shfl_sync(0xffffffff, (token_start_idx >= token_end_idx) or (not is_lane_ready), current_rdma_idx)) + continue; + + // Sync token start index + auto token_idx = static_cast(__shfl_sync(0xffffffff, token_start_idx, current_rdma_idx)); + int num_tokens_in_chunk = + __shfl_sync(0xffffffff, min(num_max_nvl_chunked_send_tokens, token_end_idx - token_start_idx), current_rdma_idx); + + // Send by chunk + for (int chunk_idx = 0; chunk_idx < num_tokens_in_chunk; ++chunk_idx, ++token_idx) { + // Get an empty slot + int dst_slot_idx = 0; + 
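// NOTES (descriptive, illustrative numbers only): the destination NVL ring buffer is statically split into `kNumRDMARanks` sub-rings of `num_max_nvl_chunked_recv_tokens_per_rdma` slots each, so chunks coming from different RDMA sources never overwrite one another. The claimed slot below is the per-source tail modulo the sub-ring size, shifted by the sub-ring base; e.g. with 4 RDMA ranks and 32 slots per sub-ring, source 2 at tail 35 lands in slot 2 * 32 + (35 % 32) = 67.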
if (lane_id == current_rdma_idx) { + dst_slot_idx = (cached_channel_tail_idx++) % num_max_nvl_chunked_recv_tokens_per_rdma; + dst_slot_idx = current_rdma_idx * num_max_nvl_chunked_recv_tokens_per_rdma + dst_slot_idx; + } + dst_slot_idx = __shfl_sync(0xffffffff, dst_slot_idx, current_rdma_idx); + + // Load data + auto shifted_x_buffers = nvl_channel_x.buffer() + dst_slot_idx * num_bytes_per_token; + auto shifted_x = x + token_idx * hidden_int4; + tma_store_wait<0>(); + if (elect_one_sync()) { + tma_load_1d(tma_buffer, shifted_x, tma_mbarrier, hidden_bytes); + mbarrier_arrive_and_expect_tx(tma_mbarrier, hidden_bytes); + } + __syncwarp(); + mbarrier_wait(tma_mbarrier, tma_phase); + + // Load source meta + if (lane_id == num_topk) + *reinterpret_cast(tma_buffer + hidden_bytes) = ld_nc_global(src_meta + token_idx); + + // Load `topk_weights` + if (lane_id < num_topk) + *reinterpret_cast(tma_buffer + hidden_bytes + sizeof(SourceMeta) + lane_id * sizeof(float)) = + ld_nc_global(topk_weights + token_idx * num_topk + lane_id); + + // Issue TMA store + tma_store_fence(); + __syncwarp(); + if (elect_one_sync()) + tma_store_1d(tma_buffer, shifted_x_buffers, num_bytes_per_token, false); + } + lane_id == current_rdma_idx ? (token_start_idx = static_cast(token_idx)) : 0; + } + + // Move queue tail + tma_store_wait<0>(); + __syncwarp(); + if (lane_id < kNumRDMARanks and is_lane_ready) + st_release_sys_global(nvl_channel_tail.buffer() + lane_id, cached_channel_tail_idx); + } + } else { + // Combiners and coordinators + // RDMA symmetric layout + auto rdma_channel_data = SymBuffer( + rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_token, kNumRDMARanks, channel_id, num_channels); + auto rdma_channel_head = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); + auto rdma_channel_tail = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); + // Using signals for RDMA head and tail instead with gin_signals + auto gin_signals_head = signals_base + kNumRDMARanks * channel_id; // move the signals to the corresponding channel + auto gin_signals_tail = signals_base + kNumRDMARanks * num_channels + + kNumRDMARanks * channel_id; // move the signals to the corresponding channel after passing all head signals + + // NVL layouts + void* local_nvl_buffer = buffer_ptrs[nvl_rank]; + void* nvl_buffers[NUM_MAX_NVL_PEERS]; + #pragma unroll + for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i) + nvl_buffers[i] = buffer_ptrs[i]; + auto nvl_channel_x = + AsymBuffer( + local_nvl_buffer, num_max_nvl_chunked_recv_tokens * num_bytes_per_token, NUM_MAX_NVL_PEERS, channel_id, num_channels) + .advance_also(nvl_buffers); + auto nvl_channel_head = + AsymBuffer(nvl_buffers, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank) + .advance_also(local_nvl_buffer); + auto nvl_channel_tail = AsymBuffer(local_nvl_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels) + .advance_also(nvl_buffers); + + // Combiner warp synchronization + __shared__ volatile int forwarder_nvl_head[kNumForwarders][NUM_MAX_NVL_PEERS]; + __shared__ volatile bool forwarder_retired[kNumForwarders]; + __shared__ volatile int rdma_receiver_rdma_head[kNumRDMAReceivers][kNumRDMARanks]; + __shared__ volatile bool rdma_receiver_retired[kNumRDMAReceivers]; + auto sync_forwarder_smem = [=]() { asm volatile("barrier.sync 0, %0;" ::"r"((kNumForwarders + 1) * 32)); }; + auto sync_rdma_receiver_smem = [=]() { asm volatile("barrier.sync 1, %0;" ::"r"((kNumRDMAReceivers + 1) * 32)); }; + + if (warp_role == 
WarpRole::kNVLAndRDMAForwarder) { + // Receive from NVL ranks and forward to RDMA ranks + // NOTES: this part is using "large warps" for each RDMA rank + const auto dst_rdma_rank = warp_id / kNumWarpsPerForwarder; + const auto sub_warp_id = warp_id % kNumWarpsPerForwarder; + auto send_buffer = + dst_rdma_rank == rdma_rank ? rdma_channel_data.recv_buffer(dst_rdma_rank) : rdma_channel_data.send_buffer(dst_rdma_rank); + auto sync_large_warp = [=]() { + if (kNumWarpsPerForwarder == 1) { + __syncwarp(); + } else { + asm volatile("bar.sync %0, %1;" ::"r"(dst_rdma_rank + 2), "r"(kNumWarpsPerForwarder * 32)); + } + }; + EP_STATIC_ASSERT(kNumWarpsPerForwarder == 1 or kNumRDMARanks + 2 <= 16, "Barriers are not enough"); + + // TMA stuffs + constexpr int kNumStages = 2; + constexpr int kNumTMALoadBytes = sizeof(int4) * 32; + constexpr int kNumTMABufferBytesPerStage = kNumTMALoadBytes * (NUM_MAX_NVL_PEERS + 1) + 16; + EP_STATIC_ASSERT(kNumTMABufferBytesPerStage * kNumStages <= kNumTMABytesPerForwarderWarp, "TMA buffer is not large enough"); + + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + auto smem_ptr = smem_buffer + warp_id * kNumStages * kNumTMABufferBytesPerStage; + auto tma_mbarrier = [=](const int& i) { + return reinterpret_cast(smem_ptr + i * kNumTMABufferBytesPerStage + kNumTMALoadBytes * (NUM_MAX_NVL_PEERS + 1)); + }; + uint32_t tma_phase[kNumStages] = {0}; + if (lane_id < kNumStages) { + mbarrier_init(tma_mbarrier(lane_id), 32); + fence_barrier_init(); + } + __syncwarp(); + + // Advance to the corresponding NVL buffer + nvl_channel_x.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * num_bytes_per_token); + nvl_channel_head.advance(dst_rdma_rank); + nvl_channel_tail.advance(dst_rdma_rank); + + // Clean shared memory and sync + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Invalid number of NVL peers"); + lane_id < NUM_MAX_NVL_PEERS ? (forwarder_nvl_head[warp_id][lane_id] = 0) : 0; + lane_id == 0 ? (forwarder_retired[warp_id] = false) : false; + sync_forwarder_smem(); + + // Get count and cached head + int cached_nvl_channel_tail_idx = 0; + int num_tokens_to_combine = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id]; + int num_tokens_prefix = channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]; + num_tokens_to_combine -= num_tokens_prefix; + num_tokens_prefix += dst_rdma_rank == 0 ? 
0 : rdma_rank_prefix_sum[dst_rdma_rank - 1]; + combined_nvl_head += num_tokens_prefix * NUM_MAX_NVL_PEERS; + + // Iterate over all tokens and combine by chunks + for (int token_start_idx = 0; token_start_idx < num_tokens_to_combine; token_start_idx += num_max_rdma_chunked_send_tokens) { + // Check destination queue emptiness, or wait a buffer to be released + auto token_end_idx = min(token_start_idx + num_max_rdma_chunked_send_tokens, num_tokens_to_combine); + auto num_chunked_tokens = token_end_idx - token_start_idx; + auto start_time = clock64(); + while (sub_warp_id == 0 and lane_id == 0) { + // Inequality: `num_max_rdma_chunked_recv_tokens - (tail - head) >= num_chunked_tokens` + // Here, `token_start_idx` is the actual tail + // kNVLAndRDMAForwarder: Check if RDMA receive buffer has space before sending data + auto signal_id = gin_signals_head + dst_rdma_rank; + uint64_t signal_value = net_head.readSignal(signal_id); + int num_used_slots = token_start_idx - signal_value; + if (num_max_rdma_chunked_recv_tokens - num_used_slots >= num_chunked_tokens) + break; + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { + // kNVLAndRDMAForwarder: debugging + auto signal_id = gin_signals_head + dst_rdma_rank; + uint64_t signal_value = net_head.readSignal(signal_id); + printf( + "DeepEP combine forwarder (RDMA check) timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA: %d, head: %ld, tail: " + "%d, chunked: %d\n", + channel_id, + rdma_rank, + nvl_rank, + dst_rdma_rank, + signal_value, + token_start_idx, + num_chunked_tokens); + trap(); + } + } + sync_large_warp(); + + // Combine and write to the RDMA buffer + for (int token_idx = token_start_idx + sub_warp_id; token_idx < token_end_idx; token_idx += kNumWarpsPerForwarder) { + // Read expected head + EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers"); + int expected_head = -1; + if (lane_id < NUM_MAX_NVL_PEERS) { + expected_head = ld_nc_global(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id); + expected_head < 0 ? 
(forwarder_nvl_head[warp_id][lane_id] = -expected_head - 1) + : (forwarder_nvl_head[warp_id][lane_id] = expected_head); + } + + // Wait lanes to be ready + start_time = clock64(); + while (cached_nvl_channel_tail_idx <= expected_head) { + cached_nvl_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer(lane_id)); + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < NUM_MAX_NVL_PEERS) { + printf( + "DeepEP combine forwarder (NVL check) timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, dst RDMA: %d, " + "tail: %d, waiting: %d, total: %d, sub: %d, large: %d, expected: %d\n", + channel_id, + rdma_rank, + nvl_rank, + lane_id, + dst_rdma_rank, + cached_nvl_channel_tail_idx, + token_idx, + num_tokens_to_combine, + sub_warp_id, + kNumWarpsPerForwarder, + expected_head); + trap(); + } + } + + // Combine current token + auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens; + void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_token; + auto get_addr_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4* { + return reinterpret_cast(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * num_bytes_per_token) + + hidden_int4_idx; + }; + auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { + return ld_nc_global(reinterpret_cast(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * num_bytes_per_token + + hidden_bytes + sizeof(SourceMeta)) + + topk_idx); + }; + combine_token( + expected_head >= 0, + expected_head, + lane_id, + hidden_int4, + num_topk, + static_cast(shifted), + reinterpret_cast(static_cast(shifted) + hidden_bytes + sizeof(SourceMeta)), + nullptr, + nullptr, + num_max_nvl_chunked_recv_tokens_per_rdma, + get_addr_fn, + recv_tw_fn, + smem_ptr, + tma_phase); + + // Update head + if (lane_id < NUM_MAX_NVL_PEERS) + expected_head < 0 ? 
(forwarder_nvl_head[warp_id][lane_id] = -expected_head - 1) + : (forwarder_nvl_head[warp_id][lane_id] = expected_head + 1); + } + sync_large_warp(); + + // Issue RDMA send + if (sub_warp_id == kNumWarpsPerForwarder - 1) { + if (dst_rdma_rank != rdma_rank) { + auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens; + const size_t num_bytes_per_msg = num_chunked_tokens * num_bytes_per_token; + const auto dst_ptr = + reinterpret_cast(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_token); + const auto src_ptr = + reinterpret_cast(rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_token); + // kNVLAndRDMAForwarder: Transfer combined token data to remote RDMA rank + int dst_rank = translate_dst_rdma_rank(dst_rdma_rank, nvl_rank); + size_t src_offset = reinterpret_cast(rdma_channel_data.send_buffer(dst_rdma_rank) + + rdma_slot_idx * num_bytes_per_token) - + reinterpret_cast(gin_base_ptr); + size_t dst_offset = + reinterpret_cast(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_token) - + reinterpret_cast(gin_base_ptr); + net.put(world, + dst_rank, + nccl_win, + dst_offset, + nccl_win, + src_offset, + num_bytes_per_msg, + ncclGin_None{}, // no signal + ncclGin_None{}, // no counter + ncclCoopWarp()); + + } else { + memory_fence(); + } + + // Write new RDMA tail + __syncwarp(); + if (elect_one_sync()) { + auto dst_rank = translate_dst_rdma_rank(dst_rdma_rank, nvl_rank); + auto signal_id = gin_signals_tail + rdma_rank; + net.signal(world, // team + dst_rank, // destination rank + ncclGin_SignalAdd{signal_id, (uint64_t)num_chunked_tokens}, // signal + value + ncclCoopThread(), // cooperation scope (default) + ncclGin_None{}, // no descriptor (default) + cuda::thread_scope_thread, // alreadyReleased (default) + cuda::thread_scope_thread // expected_scope (default) + ); + } + } + } + + // Retired + __syncwarp(); + if (elect_one_sync()) + forwarder_retired[warp_id] = true; + } else if (warp_role == WarpRole::kRDMAReceiver) { + // Receive from RDMA ranks and write to the output tensor + // Clean shared memory and sync + EP_DEVICE_ASSERT(kNumRDMARanks <= 32); + lane_id < kNumRDMARanks ? (rdma_receiver_rdma_head[warp_id][lane_id] = 0) : 0; + lane_id == 0 ? (rdma_receiver_retired[warp_id] = false) : 0; + sync_rdma_receiver_smem(); + + // The same tokens as the dispatch process + int token_start_idx, token_end_idx; + get_channel_task_range(num_combined_tokens, num_channels, channel_id, token_start_idx, token_end_idx); + + // Iterate over all tokens and combine + int cached_channel_tail_idx = 0; + for (int64_t token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumRDMAReceivers) { + // Read expected head + EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers"); + int expected_head = -1; + if (lane_id < kNumRDMARanks) { + expected_head = ld_nc_global(combined_rdma_head + token_idx * kNumRDMARanks + lane_id); + (expected_head < 0) ? 
(rdma_receiver_rdma_head[warp_id][lane_id] = -expected_head - 1) + : (rdma_receiver_rdma_head[warp_id][lane_id] = expected_head); + } + + // Wait lanes to be ready + auto start_time = clock64(); + while (cached_channel_tail_idx <= expected_head) { + // kRDMAReceiver: Check if data is available from remote RDMA rank (check tail pointer) + auto signal_id = gin_signals_tail + lane_id; + uint64_t signal_value = net.readSignal(signal_id); + cached_channel_tail_idx = static_cast(signal_value); + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { + printf( + "DeepEP combine RDMA receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, tail: %d, waiting: %ld, " + "expect: %d\n", + channel_id, + rdma_rank, + nvl_rank, + lane_id, + cached_channel_tail_idx, + token_idx, + expected_head); + trap(); + } + } + __syncwarp(); + + // Combine current token + auto get_addr_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4* { + return reinterpret_cast(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_token) + + hidden_int4_idx; + }; + auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { + return ld_nc_global(reinterpret_cast(rdma_channel_data.recv_buffer(src_rdma_rank) + + slot_idx * num_bytes_per_token + hidden_bytes + sizeof(SourceMeta)) + + topk_idx); + }; + uint32_t dummy_tma_phases[2]; + combine_token( + expected_head >= 0, + expected_head, + lane_id, + hidden_int4, + num_topk, + combined_x + token_idx * hidden_int4, + combined_topk_weights + token_idx * num_topk, + bias_0 == nullptr ? nullptr : bias_0 + token_idx * hidden_int4, + bias_1 == nullptr ? nullptr : bias_1 + token_idx * hidden_int4, + num_max_rdma_chunked_recv_tokens, + get_addr_fn, + recv_tw_fn, + nullptr, + dummy_tma_phases); + } + + // Retired + __syncwarp(); + if (elect_one_sync()) + rdma_receiver_retired[warp_id] = true; + } else { + // Coordinator + // Sync shared memory status + is_forwarder_sm ? sync_forwarder_smem() : sync_rdma_receiver_smem(); + const auto num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks; + + int last_rdma_head = 0; + int last_nvl_head[kNumRDMARanks] = {0}; + int dst_rdma_rank = lane_id < kNumRDMARanks ? lane_id : 0; + int dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? 
lane_id : 0; + EP_STATIC_ASSERT(kNumCombineForwarderWarps <= 32, "Invalid number of forwarder warps"); + while (true) { + // Retired + if (not is_forwarder_sm and __all_sync(0xffffffff, lane_id >= kNumRDMAReceivers or rdma_receiver_retired[lane_id])) + break; + if (is_forwarder_sm and __all_sync(0xffffffff, lane_id >= kNumForwarders or forwarder_retired[lane_id])) + break; + + // Find minimum head for RDMA ranks + if (not is_forwarder_sm) { + int min_head = std::numeric_limits::max(); + #pragma unroll + for (int i = 0; i < kNumRDMAReceivers; ++i) + if (not rdma_receiver_retired[i]) + min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]); + if (min_head != std::numeric_limits::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and + lane_id < kNumRDMARanks) { + // Coordinator: Notify remote rank that buffer space has been freed (update head pointer) + auto dst_rank = translate_dst_rdma_rank(dst_rdma_rank, nvl_rank); + auto signal_id = gin_signals_head + rdma_rank; + net_head.signal(world_head, // team + dst_rank, // destination rank + ncclGin_SignalAdd{signal_id, (uint64_t)min_head - (uint64_t)last_rdma_head}, // signal + value + ncclCoopThread(), // cooperation scope (default) + ncclGin_None{}, // no descriptor (default) + cuda::thread_scope_thread, // alreadyReleased (default) + cuda::thread_scope_thread // expected_scope (default) + ); + last_rdma_head = min_head; + } + } else { + // Find minimum head for NVL ranks + #pragma unroll + for (int i = 0; i < kNumRDMARanks; ++i) { + int min_head = std::numeric_limits::max(); + #pragma unroll + for (int j = 0; j < num_warps_per_rdma_rank; ++j) + if (not forwarder_retired[i * num_warps_per_rdma_rank + j]) + min_head = min(min_head, forwarder_nvl_head[i * num_warps_per_rdma_rank + j][dst_nvl_rank]); + if (min_head != std::numeric_limits::max() and min_head > last_nvl_head[i] and lane_id < NUM_MAX_NVL_PEERS) + st_relaxed_sys_global(nvl_channel_head.buffer_by(dst_nvl_rank) + i, last_nvl_head[i] = min_head); + } + } + + // Nanosleep and let other warps work + __nanosleep(NUM_WAIT_NANOSECONDS); + } + } + } +} + +void combine(cudaDataType_t type, + void* combined_x, + float* combined_topk_weights, + const bool* is_combined_token_in_rank, + const void* x, + const float* topk_weights, + const void* bias_0, + const void* bias_1, + const int* combined_rdma_head, + const int* combined_nvl_head, + const void* src_meta, + const int* rdma_channel_prefix_matrix, + const int* rdma_rank_prefix_sum, + const int* gbl_channel_prefix_matrix, + int num_tokens, + int num_combined_tokens, + int hidden, + int num_topk, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_send_tokens, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_send_tokens, + int num_max_nvl_chunked_recv_tokens, + int rank, + int num_ranks, + cudaStream_t stream, + int num_channels, + bool low_latency_mode, + void* gin_base_ptr, + ncclDevComm dev_comm, + ncclWindow_t nccl_win, + unsigned signals_base) { + constexpr int kNumCombineForwarderWarps = 24; + constexpr int kNumTMABytesPerSenderWarp = 16384; + constexpr int kNumTMABytesPerForwarderWarp = 9248; + constexpr int smem_size = + std::max(kNumTMABytesPerSenderWarp * NUM_MAX_NVL_PEERS, kNumTMABytesPerForwarderWarp * kNumCombineForwarderWarps); + +#define COMBINE_LAUNCH_CASE(num_rdma_ranks) \ + { \ + auto combine_func = low_latency_mode ? 
combine \ + : combine; \ + SET_SHARED_MEMORY_FOR_TMA(combine_func); \ + LAUNCH_KERNEL(&cfg, \ + combine_func, \ + reinterpret_cast(combined_x), \ + combined_topk_weights, \ + is_combined_token_in_rank, \ + reinterpret_cast(x), \ + topk_weights, \ + reinterpret_cast(bias_0), \ + reinterpret_cast(bias_1), \ + combined_rdma_head, \ + combined_nvl_head, \ + reinterpret_cast(src_meta), \ + rdma_channel_prefix_matrix, \ + rdma_rank_prefix_sum, \ + gbl_channel_prefix_matrix, \ + num_tokens, \ + num_combined_tokens, \ + hidden, \ + num_topk, \ + rdma_buffer_ptr, \ + num_max_rdma_chunked_send_tokens, \ + num_max_rdma_chunked_recv_tokens, \ + buffer_ptrs, \ + num_max_nvl_chunked_send_tokens, \ + num_max_nvl_chunked_recv_tokens, \ + rank, \ + num_ranks, \ + gin_base_ptr, \ + dev_comm, \ + nccl_win, \ + signals_base); \ + } \ + break + + int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; + auto num_warps_per_forwarder = std::max(kNumCombineForwarderWarps / num_rdma_ranks, 1); + int num_forwarder_warps = num_rdma_ranks * num_warps_per_forwarder; + EP_HOST_ASSERT(num_rdma_ranks <= kNumCombineForwarderWarps); + EP_HOST_ASSERT(num_forwarder_warps > NUM_MAX_NVL_PEERS and num_forwarder_warps % num_rdma_ranks == 0); + EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0); + EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks > + std::max(num_max_rdma_chunked_send_tokens, num_max_nvl_chunked_send_tokens)); + EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks - num_warps_per_forwarder >= num_max_nvl_chunked_send_tokens); + EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens >= num_warps_per_forwarder); + EP_HOST_ASSERT(type == CUDA_R_16BF); + + SETUP_LAUNCH_CONFIG(num_channels * 2, (num_forwarder_warps + 1) * 32, stream); + SWITCH_RDMA_RANKS(COMBINE_LAUNCH_CASE); +#undef COMBINE_LAUNCH_CASE +} + +} // namespace internode + +} // namespace deep_ep + +// clang-format on diff --git a/3rdparty/deep_ep/kernels/internode_ll.cu b/3rdparty/deep_ep/kernels/internode_ll.cu new file mode 100644 index 0000000000..bd1209f61b --- /dev/null +++ b/3rdparty/deep_ep/kernels/internode_ll.cu @@ -0,0 +1,1347 @@ +// clang-format off +#include "configs.cuh" +#include "exception.cuh" +#include "launch.cuh" +#include "utils.cuh" + +#include +#include +#include +#include + +using namespace cooperative_groups; +namespace cg = cooperative_groups; + +namespace deep_ep { + +namespace internode_ll { + +template +__forceinline__ __device__ bool is_rank_masked(int* mask_buffer_ptr, int rank) { + if (mask_buffer_ptr == nullptr) { + return false; + } + if constexpr (use_warp_sync) { + return __shfl_sync(0xffffffff, ld_acquire_global(mask_buffer_ptr + rank), 0) != 0; + } else { + return ld_acquire_global(mask_buffer_ptr + rank) != 0; + } +} + +// Device constant for P2P/NVLink disabled flag +// Set to true to force RDMA path, false to allow P2P when available +// Default is false (P2P enabled), updated from host via CLI option +__device__ __constant__ bool d_p2p_disabled = false; + +// Get peer-to-peer pointer for NCCL +// Returns dst_ptr if NVLink is available, 0 otherwise +// offset parameter allows callers to pass a pre-calculated offset for the destination +__device__ __forceinline__ uint64_t nccl_get_p2p_ptr(const uint64_t& dst_ptr, + const size_t& offset, + const int& rank, + const int& dst_rank, + const ncclWindow_t dev_win, + ncclDevComm dev_comm) +{ + // Local rank, no need for peer mapping + if (rank == dst_rank) + return dst_ptr; + + // If P2P is globally disabled, always use RDMA path + if 
(d_p2p_disabled) + return 0; + + // P2P/NVLink only works between ranks on the same node (LSA team) + // Use NCCL team APIs to check if dst_rank is in the same LSA team + ncclTeam lsa = ncclTeamLsa(dev_comm); + ncclTeam world = ncclTeamWorld(dev_comm); + if (!ncclTeamRankIsMember(lsa, world, dst_rank)) { + return 0; // Different nodes (not in same LSA team), must use RDMA + + } + + auto const p2p_ptr = reinterpret_cast(ncclGetPeerPointer(dev_win, offset, dst_rank)); + return p2p_ptr ? p2p_ptr : 0; +} + + +template +__global__ __launch_bounds__(1024, 1) void dispatch(void* packed_recv_x, + void* packed_recv_x_scales, + int* packed_recv_src_info, + int64_t* packed_recv_layout_range, + int* packed_recv_count, + int* mask_buffer_ptr, + int* cumulative_local_expert_recv_stats, + int64_t* dispatch_wait_recv_cost_stats, + void* rdma_recv_x, + int* rdma_recv_count, + void* rdma_x, + size_t rdma_recv_x_offset, /* nccl backend*/ + size_t rdma_recv_count_offset, + size_t rdma_x_offset, + const void* x, + const topk_idx_t* topk_idx, + int* atomic_counter_per_expert, + int* atomic_finish_counter_per_expert, + int* next_clean, + int num_next_clean_int, + int num_tokens, + int num_max_dispatch_tokens_per_rank, + int num_topk, + int num_experts, + int rank, + int num_ranks, + int num_warp_groups, + int num_warps_per_group, + bool round_scale, + int phases, + ncclDevComm dev_comm, + const ncclWindow_t nccl_win, + unsigned signals_base +) { + const auto sm_id = static_cast(blockIdx.x); + const auto thread_id = static_cast(threadIdx.x); + const auto warp_id = thread_id / 32, lane_id = get_lane_id(); + const auto num_sms = static_cast(gridDim.x); + const auto num_warps = num_warp_groups * num_warps_per_group; + const auto num_local_experts = num_experts / num_ranks; + const auto warp_group_id = warp_id / num_warps_per_group; + const auto sub_warp_id = warp_id % num_warps_per_group; + const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id; + + // May extract UE8M0 from the scales + using scale_t = std::conditional_t; + using packed_t = std::conditional_t; + EP_STATIC_ASSERT(sizeof(packed_t) % sizeof(scale_t) == 0, "Invalid vector length"); + + // FP8 staffs + constexpr int kNumPerChannels = 128; + const int num_scales = kHidden / kNumPerChannels; + const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__nv_fp8_storage_t) : sizeof(nv_bfloat16)); + const size_t hidden_int4 = hidden_bytes / sizeof(int4); + + // Message package: index at source (int), 3 reserved int fields, hidden data, FP8 scales + // NOTES: currently we have 3 reserved int fields for future use + using vec_t = std::conditional_t; + const size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(nv_bfloat16))); + const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4); + EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0); + + // Expert counts + constexpr int kNumMaxWarpGroups = 32; + __shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups]; + + // Sending phase + if ((phases & LOW_LATENCY_SEND_PHASE) == 0) + goto LOW_LATENCY_DISPATCH_RECV; + + // There are 2 kinds of warps in this part: + // 1. The first-kind warps for FP8 cast and sending top-k tokens + // 2. 
The last warp for reading `topk_idx` and count for per-expert information + if (warp_id < num_warps - 1) { + constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(nv_bfloat16); + EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerRead) == 0, "Invalid hidden"); + EP_STATIC_ASSERT(kNumElemsPerRead * 32 % kNumPerChannels == 0, "Invalid vectorization"); + const auto num_threads = (num_warps - 1) * 32; + const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead; + + for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) { + const auto x_int4 = static_cast(x) + token_idx * hidden_bf16_int4; + const auto rdma_x_src_idx = reinterpret_cast(static_cast(rdma_x) + token_idx * num_bytes_per_msg); + const auto rdma_x_vec = reinterpret_cast(reinterpret_cast(rdma_x_src_idx) + sizeof(int4)); + const auto rdma_x_scales = reinterpret_cast(reinterpret_cast(rdma_x_vec) + hidden_bytes); + + // Overlap top-k index read and source token index writes + auto dst_expert_idx = warp_id < num_topk ? static_cast(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1; + thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0; + + // FP8 cast + EP_STATIC_ASSERT(hidden_bf16_int4 % 32 == 0, "Must use the full warp to reduce"); + #pragma unroll + for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) { + // Read + auto int4_value = __ldg(x_int4 + i); + + if constexpr (kUseFP8) { + // Calculate local amax + auto bf16_values = reinterpret_cast(&int4_value); + float fp32_values[kNumElemsPerRead]; + float amax = kFP8Margin, scale, scale_inv; + #pragma unroll + for (int j = 0; j < kNumElemsPerRead; ++j) { + fp32_values[j] = static_cast(bf16_values[j]); + amax = fmaxf(amax, fabsf(fp32_values[j])); + } + + // Reduce amax and scale + EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization"); + amax = warp_reduce_max<16>(amax); + calculate_fp8_scales(amax, scale, scale_inv, round_scale); + if (lane_id == 0 or lane_id == 16) + rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv; + + // Cast into send buffer + vec_t int2_value; + auto fp8x2_values = reinterpret_cast<__nv_fp8x2_storage_t*>(&int2_value); + #pragma unroll + for (int j = 0; j < kNumElemsPerRead; j += 2) { + float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale}; + fp8x2_values[j / 2] = __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE, __NV_E4M3); + } + rdma_x_vec[i] = int2_value; + } else { + // Reinterpret-cast is for C++14 compatibility + rdma_x_vec[i] = *reinterpret_cast(&int4_value); + } + } + asm volatile("bar.sync 1, %0;" ::"r"(num_threads)); + + // Issue IBGDA sends + if (dst_expert_idx >= 0) { + int slot_idx = lane_id == 0 ? 
atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0; + slot_idx = __shfl_sync(0xffffffff, slot_idx, 0); + const auto dst_rank = dst_expert_idx / num_local_experts; + const auto dst_expert_local_idx = dst_expert_idx % num_local_experts; + const auto src_ptr = reinterpret_cast(rdma_x_src_idx); + const auto dst_ptr = reinterpret_cast(rdma_recv_x) + + dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + + rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + slot_idx * num_bytes_per_msg; + + size_t expected_dst_offset = rdma_recv_x_offset + + dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + + rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + slot_idx * num_bytes_per_msg; + const auto dst_p2p_ptr = + nccl_get_p2p_ptr(dst_ptr, expected_dst_offset, rank, dst_rank, nccl_win, dev_comm); + + if (not is_rank_masked(mask_buffer_ptr, dst_rank)) { + if (dst_p2p_ptr == 0) { + size_t expected_src_offset = rdma_x_offset + token_idx * num_bytes_per_msg; + ncclGin net(dev_comm, dst_expert_local_idx); + ncclTeam world = ncclTeamWorld(dev_comm); + net.put(world, + dst_rank, + nccl_win, + expected_dst_offset, + nccl_win, + expected_src_offset, + num_bytes_per_msg, + ncclGin_None{}, // no signal + ncclGin_None{}, // no counter + ncclCoopWarp()); + } else { + // NOTES: only 2 load iterations for 7K hidden with 8 unrolls + const auto* src_int4_ptr = reinterpret_cast(src_ptr); + const auto* dst_int4_ptr = reinterpret_cast(dst_p2p_ptr); + UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global); + } + } + + // Increase counter after finishing + __syncwarp(); + lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0; + } + } + } else if (warp_id == num_warps - 1) { + EP_DEVICE_ASSERT(num_sms > 1); + if (sm_id == 0) { + // The first SM is also responsible for cleaning the next buffer + #pragma unroll + for (int i = lane_id; i < num_next_clean_int; i += 32) + next_clean[i] = 0; + // Notify before executing `int_p` + __syncwarp(); + #pragma unroll + for (int i = lane_id; i < num_experts; i += 32) + atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG); + } + + // This SM should be responsible for some destination experts, read `topk_idx` for them + int expert_count[kNumMaxWarpGroups] = {0}; + const auto expert_begin_idx = sm_id * num_warp_groups; + const auto expert_end_idx = min(expert_begin_idx + num_warp_groups, num_experts); + + // Per lane count + #pragma unroll 8 + for (int i = lane_id; i < num_tokens * num_topk; i += 32) { + auto idx = static_cast(__ldg(topk_idx + i)); + if (idx >= expert_begin_idx and idx < expert_end_idx) + expert_count[idx - expert_begin_idx]++; + } + + // Warp reduce + #pragma unroll + for (int i = expert_begin_idx; i < expert_end_idx; ++i) { + auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]); + if (lane_id == 0) { + shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum; + atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum); + } + } + } + __syncthreads(); + + // Issue count sends + if (responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) { + const auto dst_rank = responsible_expert_idx / num_local_experts; + const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts; + const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * 
num_warp_groups]; + + // Wait for local sends to be issued and send expert counts + while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2) + ; + auto dst_ptr = reinterpret_cast(rdma_recv_count + dst_expert_local_idx * num_ranks + rank); + + size_t dst_offset = rdma_recv_count_offset + (dst_expert_local_idx * num_ranks + rank) * sizeof(int); + const auto dst_p2p_ptr = nccl_get_p2p_ptr(dst_ptr, dst_offset, rank, dst_rank, nccl_win, dev_comm); + + if (not is_rank_masked(mask_buffer_ptr, dst_rank)) { + if (dst_p2p_ptr == 0) { // if (rank != dst_rank) { + auto signal_id = signals_base + dst_expert_local_idx * num_ranks + rank; + ncclGin net(dev_comm, dst_expert_local_idx); + ncclTeam world = ncclTeamWorld(dev_comm); + // NOTE: net.signal() is semantically cleaner but adds latency to Dispatch-Send + // and Combine-Send compared to net.put() with 0 bytes + // net.signal(world, + // dst_rank, + // ncclGin_SignalAdd{signal_id, (uint64_t)num_tokens_sent + 1}, + // ncclCoopThread(), + // ncclGin_None(), + // cuda::thread_scope_system); + net.put(world, + dst_rank, + nccl_win, + dst_offset, + nccl_win, + 0, + 0, // 0 bytes transfer + ncclGin_SignalAdd{signal_id, (uint64_t)num_tokens_sent + 1}, + ncclGin_None{}, // no counter + ncclCoopThread()); + } else { + st_release_sys_global(reinterpret_cast(dst_p2p_ptr), -num_tokens_sent - 1); + } + } + + // Clean workspace for next use + atomic_counter_per_expert[responsible_expert_idx] = 0; + atomic_finish_counter_per_expert[responsible_expert_idx] = 0; + + // Clean `packed_recv_count` + if (dst_rank == 0) + packed_recv_count[dst_expert_local_idx] = 0; + } + __syncwarp(); + +// Receiving phase +LOW_LATENCY_DISPATCH_RECV: + if ((phases & LOW_LATENCY_RECV_PHASE) == 0) + return; + + // For send-and-recv kernels, we need a grid sync for making `packed_recv_count` visible + if (phases & LOW_LATENCY_SEND_PHASE) + cg::this_grid().sync(); + + // Receiving and packing + if (responsible_expert_idx < num_experts) { + const auto src_rank = responsible_expert_idx / num_local_experts; + const auto local_expert_idx = responsible_expert_idx % num_local_experts; + const auto rdma_recv_x_uint8 = static_cast(rdma_recv_x) + + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + + src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg; + const auto recv_x_int4 = + static_cast(packed_recv_x) + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4; + const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank; + const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks; + const auto num_aligned_scales = align_up(num_scales, sizeof(float) / sizeof(scale_t)); + const auto recv_x_scales = static_cast(packed_recv_x_scales) + + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales; + + // Shared between sub-warps in warp groups + __shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups]; + + // Wait for tokens to arrive + // NOTES: using sub-warp 1 to overlap with sub-warp 0 + int num_recv_tokens = 0, recv_token_begin_idx; + EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 15); + if (sub_warp_id == 1 and lane_id == 0) { + auto start_time = clock64(); + uint64_t wait_recv_cost = 0; + if (not is_rank_masked(mask_buffer_ptr, src_rank)) { + size_t src_offset = rdma_recv_count_offset + (local_expert_idx * num_ranks + 
src_rank) * sizeof(int); + auto src_p2p_ptr = nccl_get_p2p_ptr(0x01, src_offset, rank, src_rank, nccl_win, dev_comm); + if (src_p2p_ptr == 0) { + ncclGin net(dev_comm, local_expert_idx); + uint64_t cur_value; + do { + cur_value = net.readSignal(signals_base + local_expert_idx * num_ranks + src_rank); + } while (cur_value < 1 // data not arrived + && (wait_recv_cost = clock64() - start_time) <= NUM_TIMEOUT_CYCLES // not timeout + ); + net.resetSignal(signals_base + local_expert_idx * num_ranks + src_rank); + num_recv_tokens = -(int)cur_value; + } else { + while ((num_recv_tokens = ld_acquire_sys_global((rdma_recv_count + local_expert_idx * num_ranks + src_rank))) == + 0 // data not arrived + && (wait_recv_cost = clock64() - start_time) <= NUM_TIMEOUT_CYCLES // not timeout + ); + } + } + // Do not receive tokens if rank timeout or masked + if (num_recv_tokens == 0) + num_recv_tokens = -1; + // Mask rank if timeout + if (wait_recv_cost > NUM_TIMEOUT_CYCLES) { + printf("Warning: DeepEP timeout for dispatch receive, rank %d, local_expert_idx %d, src_rank %d\n", + rank, + local_expert_idx, + src_rank); + if (mask_buffer_ptr == nullptr) + trap(); + atomicExch(mask_buffer_ptr + src_rank, 1); + } + + num_recv_tokens = -num_recv_tokens - 1; + recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens); + shared_num_recv_tokens[warp_group_id] = num_recv_tokens; + shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx; + recv_range[src_rank] = pack2(num_recv_tokens, recv_token_begin_idx); + + // Add stats for diagnosis + if (cumulative_local_expert_recv_stats != nullptr) + atomicAdd(cumulative_local_expert_recv_stats + local_expert_idx, num_recv_tokens); + if (dispatch_wait_recv_cost_stats != nullptr) + atomicAdd(reinterpret_cast(dispatch_wait_recv_cost_stats + src_rank), wait_recv_cost); + } + asm volatile("bar.sync %0, %1;" ::"r"(warp_group_id + 2), "r"(num_warps_per_group * 32)); + num_recv_tokens = shared_num_recv_tokens[warp_group_id]; + recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id]; + + // Copy tokens + EP_DEVICE_ASSERT(num_scales <= 64); + for (int i = sub_warp_id; i < num_recv_tokens; i += num_warps_per_group) { + // Copy source info + const auto src_src_idx = reinterpret_cast(rdma_recv_x_uint8 + i * num_bytes_per_msg); + if (lane_id == 0) + recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx); + __syncwarp(); + + // Copy data + // NOTES: only 2 load iterations for 7K hidden with 7 unrolls + const auto src_data = reinterpret_cast(reinterpret_cast(src_src_idx) + sizeof(int4)); + const auto dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4; + UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global); + + // Copy scales + if constexpr (kUseFP8) { + // Equivalent CuTe layout: + // (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1)) + const auto src_scales = reinterpret_cast(reinterpret_cast(src_data) + hidden_bytes); + const auto num_elems_per_pack = static_cast(sizeof(packed_t) / sizeof(scale_t)); + const auto token_idx = recv_token_begin_idx + i; + const auto token_stride = num_elems_per_pack; + const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack; + if (lane_id < num_scales) { + const auto pack_idx = lane_id / num_elems_per_pack; + const auto elem_idx = lane_id % num_elems_per_pack; + auto scale = extract_required_scale_format(ld_nc_global(src_scales + lane_id)); + 
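// NOTES (descriptive, assuming the usual type choices rather than stating them): the store below realizes the packed layout sketched above. If the UE8M0 path uses a 4-byte `packed_t` and a 1-byte `scale_t` (so `num_elems_per_pack == 4`), then scale element 6 of token 3 maps to pack_idx = 1, elem_idx = 2, i.e. index 3 * 4 + 1 * pack_stride + 2; in the FP32 path `num_elems_per_pack == 1`, so each scale channel stores its values for all tokens contiguously, with stride `pack_stride` between channels.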
recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale; + } + if (lane_id + 32 < num_scales) { + const auto pack_idx = (lane_id + 32) / num_elems_per_pack; + const auto elem_idx = (lane_id + 32) % num_elems_per_pack; + auto scale = extract_required_scale_format(ld_nc_global(src_scales + lane_id + 32)); + recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale; + } + } + } + } +} + +void dispatch(void* packed_recv_x, + void* packed_recv_x_scales, + int* packed_recv_src_info, + int64_t* packed_recv_layout_range, + int* packed_recv_count, + int* mask_buffer_ptr, + int* cumulative_local_expert_recv_stats, + int64_t* dispatch_wait_recv_cost_stats, + void* rdma_recv_x, + int* rdma_recv_count, + void* rdma_x, + size_t rdma_recv_x_offset, + size_t rdma_recv_count_offset, + size_t rdma_x_offset, + const void* x, + const topk_idx_t* topk_idx, + int* next_clean, + int num_next_clean_int, + int num_tokens, + int hidden, + int num_max_dispatch_tokens_per_rank, + int num_topk, + int num_experts, + int rank, + int num_ranks, + bool use_fp8, + bool round_scale, + bool use_ue8m0, + void* workspace, + int num_device_sms, + ncclWindow_t nccl_win, + ncclDevComm dev_comm, + unsigned signals_base, + cudaStream_t stream, + int phases) +{ + constexpr int kNumMaxTopK = 11; + const int num_warp_groups = ceil_div(num_experts, num_device_sms); + const int num_warps_per_group = 32 / num_warp_groups; + EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0); + EP_HOST_ASSERT(kNumMaxTopK + 1 <= num_warp_groups * num_warps_per_group); + + const auto num_warps = num_warp_groups * num_warps_per_group; + const auto num_sms = ceil_div(num_experts, num_warp_groups); + EP_HOST_ASSERT(num_topk <= kNumMaxTopK); + + // Workspace checks + auto atomic_counter_per_expert = static_cast(workspace); + auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts; + EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES); + + // FP8 checks + if (use_ue8m0) + EP_HOST_ASSERT(round_scale and "UE8M0 SF requires `round_scale=True`"); + +#define DISPATCH_LAUNCH_CASE(hidden) \ + { \ + auto dispatch_func = dispatch