diff --git a/csrc/elastic/buffer.hpp b/csrc/elastic/buffer.hpp index f3d73a7f..7e7597f7 100644 --- a/csrc/elastic/buffer.hpp +++ b/csrc/elastic/buffer.hpp @@ -67,6 +67,7 @@ class ElasticBuffer { int prev_rank_idx = 0, next_rank_idx = 0; int64_t num_max_pp_tensor_bytes = 0; int num_max_pp_inflight_tensors = 0; + int num_pp_qps = 0; // AGRS session settings int64_t num_max_agrs_session_bytes = 0; @@ -294,20 +295,23 @@ class ElasticBuffer { }; } - void pp_set_config(const int64_t& num_max_tensor_bytes, const int& num_max_inflight_tensors) { + void pp_set_config(const int64_t& num_max_tensor_bytes, const int& num_max_inflight_tensors, const int& num_qps) { // Flush previous operations barrier(false, true); EP_HOST_ASSERT(num_max_tensor_bytes > 0 and num_max_inflight_tensors > 0); EP_HOST_ASSERT(num_max_tensor_bytes * num_max_inflight_tensors * 2 * 2 <= num_buffer_bytes); + const auto actual_num_qps = num_qps == 0 ? nccl_context->num_allocated_qps : num_qps; + EP_HOST_ASSERT(actual_num_qps > 0 and actual_num_qps <= nccl_context->num_allocated_qps); this->prev_rank_idx = (nccl_context->rank_idx + nccl_context->num_ranks - 1) % nccl_context->num_ranks; this->next_rank_idx = (nccl_context->rank_idx + 1) % nccl_context->num_ranks; this->num_max_pp_tensor_bytes = math::align(num_max_tensor_bytes, 32); this->num_max_pp_inflight_tensors = num_max_inflight_tensors; + this->num_pp_qps = actual_num_qps; } void pp_send(const torch::Tensor& x, const int& dst_rank_idx, const int& num_sms) const { - EP_HOST_ASSERT(num_max_pp_tensor_bytes > 0 and num_max_pp_inflight_tensors > 0); + EP_HOST_ASSERT(num_max_pp_tensor_bytes > 0 and num_max_pp_inflight_tensors > 0 and num_pp_qps > 0); EP_HOST_ASSERT(x.is_cuda() and x.is_contiguous() and x.nbytes() <= num_max_pp_tensor_bytes); EP_HOST_ASSERT(dst_rank_idx == prev_rank_idx or dst_rank_idx == next_rank_idx); @@ -319,6 +323,7 @@ class ElasticBuffer { num_max_pp_tensor_bytes, num_max_pp_inflight_tensors, num_sms == 0 ? 
jit::device_runtime->get_num_sms() : num_sms, + num_pp_qps, num_gpu_timeout_cycles, jit::device_runtime->get_num_smem_bytes(), at::cuda::getCurrentCUDAStream() @@ -326,7 +331,7 @@ class ElasticBuffer { } void pp_recv(const torch::Tensor& x, const int& src_rank_idx, const int& num_sms) const { - EP_HOST_ASSERT(num_max_pp_tensor_bytes > 0 and num_max_pp_inflight_tensors > 0); + EP_HOST_ASSERT(num_max_pp_tensor_bytes > 0 and num_max_pp_inflight_tensors > 0 and num_pp_qps > 0); EP_HOST_ASSERT(x.is_cuda() and x.is_contiguous() and x.nbytes() <= num_max_pp_tensor_bytes); EP_HOST_ASSERT(src_rank_idx == prev_rank_idx or src_rank_idx == next_rank_idx); @@ -338,6 +343,7 @@ class ElasticBuffer { num_max_pp_tensor_bytes, num_max_pp_inflight_tensors, num_sms == 0 ? jit::device_runtime->get_num_sms() : num_sms, + num_pp_qps, num_gpu_timeout_cycles, jit::device_runtime->get_num_smem_bytes(), at::cuda::getCurrentCUDAStream() @@ -1314,9 +1320,12 @@ static void register_apis(pybind11::module_& m) { .def("barrier", &ElasticBuffer::barrier) .def("engram_write", &ElasticBuffer::engram_write) .def("engram_fetch", &ElasticBuffer::engram_fetch) - .def("pp_set_config", &ElasticBuffer::pp_set_config) - .def("pp_send", &ElasticBuffer::pp_send) - .def("pp_recv", &ElasticBuffer::pp_recv) + .def("pp_set_config", &ElasticBuffer::pp_set_config, + pybind11::arg("num_max_tensor_bytes"), pybind11::arg("num_max_inflight_tensors"), pybind11::arg("num_qps") = 0) + .def("pp_send", &ElasticBuffer::pp_send, + pybind11::arg("x"), pybind11::arg("dst_rank_idx"), pybind11::arg("num_sms") = 0) + .def("pp_recv", &ElasticBuffer::pp_recv, + pybind11::arg("x"), pybind11::arg("src_rank_idx"), pybind11::arg("num_sms") = 0) .def("create_agrs_session", &ElasticBuffer::create_agrs_session) .def("destroy_agrs_session", &ElasticBuffer::destroy_agrs_session) .def("agrs_set_config", &ElasticBuffer::agrs_set_config) diff --git a/csrc/kernels/elastic/pp_send_recv.hpp b/csrc/kernels/elastic/pp_send_recv.hpp index 
06cbea43..c03e150d 100644 --- a/csrc/kernels/elastic/pp_send_recv.hpp +++ b/csrc/kernels/elastic/pp_send_recv.hpp @@ -18,6 +18,7 @@ class PPSendRuntime final : public jit::LaunchRuntime { // Templated arguments int num_ranks; int num_smem_bytes; + int num_qps; int64_t num_timeout_cycles; // Parameters @@ -41,11 +42,12 @@ class PPSendRuntime final : public jit::LaunchRuntime { using namespace deep_ep::elastic; static void __instantiate_kernel() {{ - auto ptr = reinterpret_cast(&pp_send_impl<{}, {}, {}, {}>); + auto ptr = reinterpret_cast(&pp_send_impl<{}, {}, {}, {}, {}>); }} )", args.launch_args.grid_dim.first, args.num_ranks, args.num_smem_bytes, + args.num_qps, args.num_timeout_cycles); } @@ -69,6 +71,7 @@ static void launch_pp_send(const ncclDevComm_t& nccl_dev_comm, const int64_t& num_max_tensor_bytes, const int num_max_inflight_tensors, const int& num_sms, + const int& num_qps, const int64_t& num_timeout_cycles, const int& num_smem_bytes, const at::cuda::CUDAStream& stream) { @@ -76,6 +79,7 @@ static void launch_pp_send(const ncclDevComm_t& nccl_dev_comm, const PPSendRuntime::Args args = { .num_ranks = num_ranks, .num_smem_bytes = num_smem_bytes, + .num_qps = num_qps, .num_timeout_cycles = num_timeout_cycles, .nccl_dev_comm = nccl_dev_comm, .nccl_window = nccl_window, @@ -100,6 +104,7 @@ class PPRecvRuntime final : public jit::LaunchRuntime { // Templated arguments int num_ranks; int num_smem_bytes; + int num_qps; int64_t num_timeout_cycles; // Parameters @@ -124,11 +129,12 @@ class PPRecvRuntime final : public jit::LaunchRuntime { using namespace deep_ep::elastic; static void __instantiate_kernel() {{ - auto ptr = reinterpret_cast(&pp_recv_impl<{}, {}, {}, {}>); + auto ptr = reinterpret_cast(&pp_recv_impl<{}, {}, {}, {}, {}>); }} )", args.launch_args.grid_dim.first, args.num_ranks, args.num_smem_bytes, + args.num_qps, args.num_timeout_cycles); } @@ -154,6 +160,7 @@ static void launch_pp_recv(const ncclDevComm_t& nccl_dev_comm, const int64_t& 
def _normalize_pp_num_qps(self, num_qps: int) -> int:
    """
    Resolve the requested number of PP QPs into the number actually used.

    Arguments:
        num_qps: the requested number of QPs; 0 selects all of `self.num_allocated_qps`.

    Returns:
        `self.num_allocated_qps` when `num_qps` is 0, otherwise `num_qps` itself.

    Raises:
        AssertionError: if `num_qps` is negative or larger than `self.num_allocated_qps`.
    """
    assert num_qps >= 0, 'Number of PP QPs must be non-negative'
    num_available = self.num_allocated_qps
    # Zero is the "use everything" sentinel, so it trivially satisfies the upper bound below
    resolved = num_available if num_qps == 0 else num_qps
    assert resolved <= num_available, 'Allocated QPs are not enough'
    return resolved
// Compute the slice of a `num_bytes` payload that QP `qp_idx` is responsible for.
//
// The payload is counted in blocks of `ptx::kNumTMAAlignBytes` bytes; each QP gets
// `ceil_div(num_blocks, kNumQPs)` consecutive blocks, and the last non-empty QP gets whatever
// is left. Returns `{offset, chunk_bytes}` (both in bytes), or `{0, 0}` when this QP's first
// block index is already past the end, i.e. it has nothing to send.
//
// NOTE(review): the template parameter list and the type arguments of `std::pair` /
// `static_cast` are not visible in this view; `kNumQPs` is used as a compile-time constant.
// NOTE(review): `num_bytes / ptx::kNumTMAAlignBytes` is a truncating division, so if `num_bytes`
// is not a multiple of `ptx::kNumTMAAlignBytes` the trailing remainder belongs to no chunk.
// Confirm that callers guarantee that alignment, otherwise the tail is never sent.
template
__device__ __forceinline__ std::pair get_pp_send_chunk(
    const int& qp_idx, const int64_t& num_bytes) {
    EP_STATIC_ASSERT(kNumQPs > 0, "Invalid number of QPs");

    // Whole alignment-sized blocks in the payload, and the per-QP share (rounded up)
    const auto num_tma_blocks = num_bytes / ptx::kNumTMAAlignBytes;
    const auto num_tma_blocks_per_qp = math::ceil_div(num_tma_blocks, static_cast(kNumQPs));
    const auto start_block_idx = static_cast(qp_idx) * num_tma_blocks_per_qp;
    // Because the share is rounded up, trailing QPs can end up with no blocks at all
    if (start_block_idx >= num_tma_blocks)
        return {0, 0};

    // Clamp the end so the last non-empty QP does not run past the final block
    const auto end_block_idx = std::min(start_block_idx + num_tma_blocks_per_qp, num_tma_blocks);
    // Convert block indices back to byte units
    const auto offset = start_block_idx * ptx::kNumTMAAlignBytes;
    const auto chunk_bytes = static_cast((end_block_idx - start_block_idx) * ptx::kNumTMAAlignBytes);
    return {offset, chunk_bytes};
}
const auto gin = handle::NCCLGin(nccl_dev_comm, nccl_window, 0, NCCL_GIN_RESOURCE_SHARING_CTA); + const auto signal_gin = handle::NCCLGin(nccl_dev_comm, nccl_window, 0, NCCL_GIN_RESOURCE_SHARING_CTA); // Buffer offsets const auto send_count_ptr = workspace_layout.get_pp_send_count_ptr(dst_idx_in_local); @@ -140,7 +158,7 @@ pp_send_impl(const ncclDevComm_t nccl_dev_comm, const ncclWindow_t nccl_window, // Wait buffer slot release and do TMA if (ptx::elect_one_sync()) { check_signal( - gin, + signal_gin, static_cast(kNumRanks + dst_idx_in_local + 2), send_count - num_max_inflight_tensors + 1, // TODO: print more info, and control the SM who prints it @@ -150,15 +168,28 @@ pp_send_impl(const ncclDevComm_t nccl_dev_comm, const ncclWindow_t nccl_window, } cooperative_groups::this_grid().sync(); - // Issue RDMA put + // Issue RDMA puts across all requested QPs. Each QP signals completion independently, + // because there is no cross-QP ordering guarantee. + if (ptx::elect_one_sync()) { + for (int qp_idx = sm_idx; qp_idx < kNumQPs; qp_idx += kNumSMs) { + const auto [offset, chunk_bytes] = get_pp_send_chunk(qp_idx, num_x_bytes); + const auto data_gin = handle::NCCLGin(nccl_dev_comm, nccl_window, qp_idx, NCCL_GIN_RESOURCE_SHARING_CTA); + const auto recv_signal = static_cast(kNumRanks + local_idx_in_dst); + + if (chunk_bytes == 0) { + data_gin.signal(dst_rank_idx, ncclGin_SignalInc{recv_signal}); + } else { + data_gin.put( + math::advance_ptr(recv_buffer_ptr, offset), + math::advance_ptr(send_buffer_ptr, offset), + chunk_bytes, dst_rank_idx, + 0, + ncclGin_SignalInc{recv_signal}); + } + } + } + if (sm_idx == 0 and ptx::elect_one_sync()) { - gin.put( - recv_buffer_ptr, - send_buffer_ptr, - num_x_bytes, dst_rank_idx, - 0, - // TODO: is this signal highly optimized? 
- ncclGin_SignalInc(static_cast(local_idx_in_dst + kNumRanks))); *send_count_ptr += 1; } } @@ -166,6 +197,7 @@ pp_send_impl(const ncclDevComm_t nccl_dev_comm, const ncclWindow_t nccl_window, template __global__ void __launch_bounds__(32, 1) pp_recv_impl(const ncclDevComm_t nccl_dev_comm, const ncclWindow_t nccl_window, @@ -178,8 +210,9 @@ pp_recv_impl(const ncclDevComm_t nccl_dev_comm, const ncclWindow_t nccl_window, const auto workspace_layout = layout::WorkspaceLayout(workspace, 1, kNumRanks, 0); const auto [src_idx_in_local, local_idx_in_src] = get_buffer_offset(src_rank_idx, rank_idx); - // Gin handle - const auto gin = handle::NCCLGin(nccl_dev_comm, nccl_window, 0, NCCL_GIN_RESOURCE_SHARING_CTA); + // QP 0 is used for release notification. Payload completion signals are + // polled on the same QP/context that issued each RDMA put. + const auto signal_gin = handle::NCCLGin(nccl_dev_comm, nccl_window, 0, NCCL_GIN_RESOURCE_SHARING_CTA); // Buffer offsets const auto recv_count_ptr = workspace_layout.get_pp_recv_count_ptr(src_idx_in_local); @@ -190,21 +223,29 @@ pp_recv_impl(const ncclDevComm_t nccl_dev_comm, const ncclWindow_t nccl_window, // Copy from the buffer into a new tensor if (ptx::elect_one_sync()) { - check_signal( - gin, - static_cast(src_idx_in_local + kNumRanks), - recv_count + 1, - // TODO: print more info, and control the SM who prints it - []() { printf("DeepEP PP recv timeout, recv buffer is empty\n"); } - ); + for (int qp_idx = sm_idx; qp_idx < kNumQPs; qp_idx += kNumSMs) { + const auto data_gin = handle::NCCLGin(nccl_dev_comm, nccl_window, qp_idx, NCCL_GIN_RESOURCE_SHARING_CTA); + check_signal( + data_gin, + static_cast(src_idx_in_local + kNumRanks), + recv_count + 1, + // TODO: print more info, and control the SM who prints it + []() { printf("DeepEP PP recv timeout, recv buffer is empty\n"); } + ); + } + } + cooperative_groups::this_grid().sync(); + + if (ptx::elect_one_sync()) { tma_copy(recv_buffer_ptr, x, num_x_bytes, sm_idx); } 
# Problem-size and stress-test knobs of the PP test. argparse exposes each flag as
# `args.<name>` with hyphens replaced by underscores.
parser.add_argument('--num-tokens', type=int, default=4096)
parser.add_argument('--hidden', type=int, default=7168)
parser.add_argument('--num-max-inflight-tensors', type=int, default=4)
# Number of QPs requested for PP (0 for all allocated QPs). Spelled with hyphens like the
# neighbouring flags; the underscore spelling stays as an alias so existing invocations keep
# working. The destination is taken from the first long option, so it is still `args.num_qps`.
parser.add_argument('--num-qps', '--num_qps', type=int, default=0)
parser.add_argument('--num-stress-iterations', type=int, default=4)
parser.add_argument('--num-sends', type=int, default=128)
parser.add_argument('--num-sleep-cycles', type=int, default=10 ** 7)