Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions csrc/elastic/buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class ElasticBuffer {
int prev_rank_idx = 0, next_rank_idx = 0;
int64_t num_max_pp_tensor_bytes = 0;
int num_max_pp_inflight_tensors = 0;
int num_pp_qps = 0;

// AGRS session settings
int64_t num_max_agrs_session_bytes = 0;
Expand Down Expand Up @@ -294,20 +295,23 @@ class ElasticBuffer {
};
}

void pp_set_config(const int64_t& num_max_tensor_bytes, const int& num_max_inflight_tensors) {
void pp_set_config(const int64_t& num_max_tensor_bytes, const int& num_max_inflight_tensors, const int& num_qps) {
// Flush previous operations
barrier(false, true);

EP_HOST_ASSERT(num_max_tensor_bytes > 0 and num_max_inflight_tensors > 0);
EP_HOST_ASSERT(num_max_tensor_bytes * num_max_inflight_tensors * 2 * 2 <= num_buffer_bytes);
const auto actual_num_qps = num_qps == 0 ? nccl_context->num_allocated_qps : num_qps;
EP_HOST_ASSERT(actual_num_qps > 0 and actual_num_qps <= nccl_context->num_allocated_qps);
this->prev_rank_idx = (nccl_context->rank_idx + nccl_context->num_ranks - 1) % nccl_context->num_ranks;
this->next_rank_idx = (nccl_context->rank_idx + 1) % nccl_context->num_ranks;
this->num_max_pp_tensor_bytes = math::align<int64_t>(num_max_tensor_bytes, 32);
this->num_max_pp_inflight_tensors = num_max_inflight_tensors;
this->num_pp_qps = actual_num_qps;
}

void pp_send(const torch::Tensor& x, const int& dst_rank_idx, const int& num_sms) const {
EP_HOST_ASSERT(num_max_pp_tensor_bytes > 0 and num_max_pp_inflight_tensors > 0);
EP_HOST_ASSERT(num_max_pp_tensor_bytes > 0 and num_max_pp_inflight_tensors > 0 and num_pp_qps > 0);
EP_HOST_ASSERT(x.is_cuda() and x.is_contiguous() and x.nbytes() <= num_max_pp_tensor_bytes);
EP_HOST_ASSERT(dst_rank_idx == prev_rank_idx or dst_rank_idx == next_rank_idx);

Expand All @@ -319,14 +323,15 @@ class ElasticBuffer {
num_max_pp_tensor_bytes,
num_max_pp_inflight_tensors,
num_sms == 0 ? jit::device_runtime->get_num_sms() : num_sms,
num_pp_qps,
num_gpu_timeout_cycles,
jit::device_runtime->get_num_smem_bytes(),
at::cuda::getCurrentCUDAStream()
);
}

void pp_recv(const torch::Tensor& x, const int& src_rank_idx, const int& num_sms) const {
EP_HOST_ASSERT(num_max_pp_tensor_bytes > 0 and num_max_pp_inflight_tensors > 0);
EP_HOST_ASSERT(num_max_pp_tensor_bytes > 0 and num_max_pp_inflight_tensors > 0 and num_pp_qps > 0);
EP_HOST_ASSERT(x.is_cuda() and x.is_contiguous() and x.nbytes() <= num_max_pp_tensor_bytes);
EP_HOST_ASSERT(src_rank_idx == prev_rank_idx or src_rank_idx == next_rank_idx);

Expand All @@ -338,6 +343,7 @@ class ElasticBuffer {
num_max_pp_tensor_bytes,
num_max_pp_inflight_tensors,
num_sms == 0 ? jit::device_runtime->get_num_sms() : num_sms,
num_pp_qps,
num_gpu_timeout_cycles,
jit::device_runtime->get_num_smem_bytes(),
at::cuda::getCurrentCUDAStream()
Expand Down Expand Up @@ -1314,9 +1320,12 @@ static void register_apis(pybind11::module_& m) {
.def("barrier", &ElasticBuffer::barrier)
.def("engram_write", &ElasticBuffer::engram_write)
.def("engram_fetch", &ElasticBuffer::engram_fetch)
.def("pp_set_config", &ElasticBuffer::pp_set_config)
.def("pp_send", &ElasticBuffer::pp_send)
.def("pp_recv", &ElasticBuffer::pp_recv)
.def("pp_set_config", &ElasticBuffer::pp_set_config,
pybind11::arg("num_max_tensor_bytes"), pybind11::arg("num_max_inflight_tensors"), pybind11::arg("num_qps") = 0)
.def("pp_send", &ElasticBuffer::pp_send,
pybind11::arg("x"), pybind11::arg("dst_rank_idx"), pybind11::arg("num_sms") = 0)
.def("pp_recv", &ElasticBuffer::pp_recv,
pybind11::arg("x"), pybind11::arg("src_rank_idx"), pybind11::arg("num_sms") = 0)
.def("create_agrs_session", &ElasticBuffer::create_agrs_session)
.def("destroy_agrs_session", &ElasticBuffer::destroy_agrs_session)
.def("agrs_set_config", &ElasticBuffer::agrs_set_config)
Expand Down
12 changes: 10 additions & 2 deletions csrc/kernels/elastic/pp_send_recv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class PPSendRuntime final : public jit::LaunchRuntime<PPSendRuntime> {
// Templated arguments
int num_ranks;
int num_smem_bytes;
int num_qps;
int64_t num_timeout_cycles;

// Parameters
Expand All @@ -41,11 +42,12 @@ class PPSendRuntime final : public jit::LaunchRuntime<PPSendRuntime> {
using namespace deep_ep::elastic;

static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&pp_send_impl<{}, {}, {}, {}>);
auto ptr = reinterpret_cast<void*>(&pp_send_impl<{}, {}, {}, {}, {}>);
}}
)", args.launch_args.grid_dim.first,
args.num_ranks,
args.num_smem_bytes,
args.num_qps,
args.num_timeout_cycles);
}

Expand All @@ -69,13 +71,15 @@ static void launch_pp_send(const ncclDevComm_t& nccl_dev_comm,
const int64_t& num_max_tensor_bytes,
const int num_max_inflight_tensors,
const int& num_sms,
const int& num_qps,
const int64_t& num_timeout_cycles,
const int& num_smem_bytes,
const at::cuda::CUDAStream& stream) {
// Generate, build and launch
const PPSendRuntime::Args args = {
.num_ranks = num_ranks,
.num_smem_bytes = num_smem_bytes,
.num_qps = num_qps,
.num_timeout_cycles = num_timeout_cycles,
.nccl_dev_comm = nccl_dev_comm,
.nccl_window = nccl_window,
Expand All @@ -100,6 +104,7 @@ class PPRecvRuntime final : public jit::LaunchRuntime<PPRecvRuntime> {
// Templated arguments
int num_ranks;
int num_smem_bytes;
int num_qps;
int64_t num_timeout_cycles;

// Parameters
Expand All @@ -124,11 +129,12 @@ class PPRecvRuntime final : public jit::LaunchRuntime<PPRecvRuntime> {
using namespace deep_ep::elastic;

static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&pp_recv_impl<{}, {}, {}, {}>);
auto ptr = reinterpret_cast<void*>(&pp_recv_impl<{}, {}, {}, {}, {}>);
}}
)", args.launch_args.grid_dim.first,
args.num_ranks,
args.num_smem_bytes,
args.num_qps,
args.num_timeout_cycles);
}

Expand All @@ -154,13 +160,15 @@ static void launch_pp_recv(const ncclDevComm_t& nccl_dev_comm,
const int64_t& num_max_tensor_bytes,
const int& num_max_inflight_tensors,
const int& num_sms,
const int& num_qps,
const int64_t& num_timeout_cycles,
const int& num_smem_bytes,
const at::cuda::CUDAStream& stream) {
// Generate, build and launch
const PPRecvRuntime::Args args = {
.num_ranks = num_ranks,
.num_smem_bytes = num_smem_bytes,
.num_qps = num_qps,
.num_timeout_cycles = num_timeout_cycles,
.nccl_dev_comm = nccl_dev_comm,
.nccl_window = nccl_window,
Expand Down
16 changes: 14 additions & 2 deletions deep_ep/buffers/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def __init__(self,
else:
num_allocated_qps = 17
self.num_allocated_qps = num_allocated_qps
self.num_pp_qps = 0

# Create CPU communicator (exchange POSIX FD handles for CPU segments)
cpu_comm = []
Expand Down Expand Up @@ -456,15 +457,26 @@ def engram_fetch(self, indices: torch.Tensor, num_qps: int = 0) -> Callable:
"""
return self.runtime.engram_fetch(indices, num_qps)

def pp_set_config(self, num_max_tensor_bytes: int, num_max_inflight_tensors: int):
def _normalize_pp_num_qps(self, num_qps: int) -> int:
assert num_qps >= 0, 'Number of PP QPs must be non-negative'
if num_qps == 0:
return self.num_allocated_qps
assert num_qps <= self.num_allocated_qps, 'Allocated QPs are not enough'
return num_qps

def pp_set_config(self, num_max_tensor_bytes: int, num_max_inflight_tensors: int, num_qps: int = 0):
"""
(Experimental) Configure pipeline-parallel (PP) send/recv parameters. Includes a barrier to flush previous operations.

Arguments:
num_max_tensor_bytes: the maximum tensor size in bytes per send/recv operation.
num_max_inflight_tensors: the maximum number of in-flight tensors at once.
num_qps: the number of RDMA QPs to use for PP (0 for all allocated QPs).
This is fixed until the next `pp_set_config` call and must match on paired ranks.
"""
self.runtime.pp_set_config(num_max_tensor_bytes, num_max_inflight_tensors)
actual_num_qps = self._normalize_pp_num_qps(num_qps)
self.runtime.pp_set_config(num_max_tensor_bytes, num_max_inflight_tensors, num_qps)
self.num_pp_qps = actual_num_qps

def pp_send(self, t: torch.Tensor, dst_rank_idx: int, num_sms: int = 0) -> None:
"""
Expand Down
83 changes: 62 additions & 21 deletions deep_ep/include/deep_ep/impls/pp_send_recv.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,23 @@ __device__ __forceinline__ std::pair<int, int> get_buffer_offset(
return dst_rank_idx == next_rank_idx ? std::make_pair(0, 1) : std::make_pair(1, 0);
}

template <int kNumQPs>
__device__ __forceinline__ std::pair<int64_t, int> get_pp_send_chunk(
const int& qp_idx, const int64_t& num_bytes) {
EP_STATIC_ASSERT(kNumQPs > 0, "Invalid number of QPs");

const auto num_tma_blocks = num_bytes / ptx::kNumTMAAlignBytes;
const auto num_tma_blocks_per_qp = math::ceil_div<int64_t>(num_tma_blocks, static_cast<int64_t>(kNumQPs));
const auto start_block_idx = static_cast<int64_t>(qp_idx) * num_tma_blocks_per_qp;
if (start_block_idx >= num_tma_blocks)
return {0, 0};

const auto end_block_idx = std::min(start_block_idx + num_tma_blocks_per_qp, num_tma_blocks);
const auto offset = start_block_idx * ptx::kNumTMAAlignBytes;
const auto chunk_bytes = static_cast<int>((end_block_idx - start_block_idx) * ptx::kNumTMAAlignBytes);
return {offset, chunk_bytes};
}

template <int64_t kNumTimeoutCycles, typename timeout_print_t>
__device__ __forceinline__ void check_signal(
const handle::NCCLGin& gin,
Expand Down Expand Up @@ -113,6 +130,7 @@ __device__ __forceinline__ void tma_copy(
template <int kNumSMs,
int kNumRanks,
int kNumSmemBytes,
int kNumQPs,
int64_t kNumTimeoutCycles>
__global__ void __launch_bounds__(32, 1)
pp_send_impl(const ncclDevComm_t nccl_dev_comm, const ncclWindow_t nccl_window,
Expand All @@ -126,7 +144,7 @@ pp_send_impl(const ncclDevComm_t nccl_dev_comm, const ncclWindow_t nccl_window,
const auto [local_idx_in_dst, dst_idx_in_local] = get_buffer_offset<kNumRanks>(rank_idx, dst_rank_idx);

// Gin handle
const auto gin = handle::NCCLGin(nccl_dev_comm, nccl_window, 0, NCCL_GIN_RESOURCE_SHARING_CTA);
const auto signal_gin = handle::NCCLGin(nccl_dev_comm, nccl_window, 0, NCCL_GIN_RESOURCE_SHARING_CTA);

// Buffer offsets
const auto send_count_ptr = workspace_layout.get_pp_send_count_ptr(dst_idx_in_local);
Expand All @@ -140,7 +158,7 @@ pp_send_impl(const ncclDevComm_t nccl_dev_comm, const ncclWindow_t nccl_window,
// Wait buffer slot release and do TMA
if (ptx::elect_one_sync()) {
check_signal<kNumTimeoutCycles>(
gin,
signal_gin,
static_cast<ncclGinSignal_t>(kNumRanks + dst_idx_in_local + 2),
send_count - num_max_inflight_tensors + 1,
// TODO: print more info, and control the SM who prints it
Expand All @@ -150,22 +168,36 @@ pp_send_impl(const ncclDevComm_t nccl_dev_comm, const ncclWindow_t nccl_window,
}
cooperative_groups::this_grid().sync();

// Issue RDMA put
// Issue RDMA puts across all requested QPs. Each QP signals completion independently,
// because there is no cross-QP ordering guarantee.
if (ptx::elect_one_sync()) {
for (int qp_idx = sm_idx; qp_idx < kNumQPs; qp_idx += kNumSMs) {
const auto [offset, chunk_bytes] = get_pp_send_chunk<kNumQPs>(qp_idx, num_x_bytes);
const auto data_gin = handle::NCCLGin(nccl_dev_comm, nccl_window, qp_idx, NCCL_GIN_RESOURCE_SHARING_CTA);
const auto recv_signal = static_cast<ncclGinSignal_t>(kNumRanks + local_idx_in_dst);

if (chunk_bytes == 0) {
data_gin.signal<ncclTeamTagWorld>(dst_rank_idx, ncclGin_SignalInc{recv_signal});
} else {
data_gin.put<ncclTeamTagWorld>(
math::advance_ptr(recv_buffer_ptr, offset),
math::advance_ptr(send_buffer_ptr, offset),
chunk_bytes, dst_rank_idx,
0,
ncclGin_SignalInc{recv_signal});
}
}
}

if (sm_idx == 0 and ptx::elect_one_sync()) {
gin.put<ncclTeamTagWorld>(
recv_buffer_ptr,
send_buffer_ptr,
num_x_bytes, dst_rank_idx,
0,
// TODO: is this signal highly optimized?
ncclGin_SignalInc(static_cast<ncclGinSignal_t>(local_idx_in_dst + kNumRanks)));
*send_count_ptr += 1;
}
}

template <int kNumSMs,
int kNumRanks,
int kNumSmemBytes,
int kNumQPs,
int64_t kNumTimeoutCycles>
__global__ void __launch_bounds__(32, 1)
pp_recv_impl(const ncclDevComm_t nccl_dev_comm, const ncclWindow_t nccl_window,
Expand All @@ -178,8 +210,9 @@ pp_recv_impl(const ncclDevComm_t nccl_dev_comm, const ncclWindow_t nccl_window,
const auto workspace_layout = layout::WorkspaceLayout(workspace, 1, kNumRanks, 0);
const auto [src_idx_in_local, local_idx_in_src] = get_buffer_offset<kNumRanks>(src_rank_idx, rank_idx);

// Gin handle
const auto gin = handle::NCCLGin(nccl_dev_comm, nccl_window, 0, NCCL_GIN_RESOURCE_SHARING_CTA);
// QP 0 is used for release notification. Payload completion signals are
// polled on the same QP/context that issued each RDMA put.
const auto signal_gin = handle::NCCLGin(nccl_dev_comm, nccl_window, 0, NCCL_GIN_RESOURCE_SHARING_CTA);

// Buffer offsets
const auto recv_count_ptr = workspace_layout.get_pp_recv_count_ptr(src_idx_in_local);
Expand All @@ -190,21 +223,29 @@ pp_recv_impl(const ncclDevComm_t nccl_dev_comm, const ncclWindow_t nccl_window,

// Copy from the buffer into a new tensor
if (ptx::elect_one_sync()) {
check_signal<kNumTimeoutCycles>(
gin,
static_cast<ncclGinSignal_t>(src_idx_in_local + kNumRanks),
recv_count + 1,
// TODO: print more info, and control the SM who prints it
[]() { printf("DeepEP PP recv timeout, recv buffer is empty\n"); }
);
for (int qp_idx = sm_idx; qp_idx < kNumQPs; qp_idx += kNumSMs) {
const auto data_gin = handle::NCCLGin(nccl_dev_comm, nccl_window, qp_idx, NCCL_GIN_RESOURCE_SHARING_CTA);
check_signal<kNumTimeoutCycles>(
data_gin,
static_cast<ncclGinSignal_t>(src_idx_in_local + kNumRanks),
recv_count + 1,
// TODO: print more info, and control the SM who prints it
[]() { printf("DeepEP PP recv timeout, recv buffer is empty\n"); }
);
}
}
cooperative_groups::this_grid().sync();

if (ptx::elect_one_sync()) {
tma_copy<kNumSMs, kNumSmemBytes>(recv_buffer_ptr, x, num_x_bytes, sm_idx);
}
cooperative_groups::this_grid().sync();

// TODO: add a comment
if (sm_idx == 0 and ptx::elect_one_sync()) {
gin.signal<ncclTeamTagWorld>(
src_rank_idx, ncclGin_SignalInc(static_cast<ncclGinSignal_t>(kNumRanks + local_idx_in_src + 2))
signal_gin.signal<ncclTeamTagWorld>(
src_rank_idx,
ncclGin_SignalInc(static_cast<ncclGinSignal_t>(kNumRanks + 2 + local_idx_in_src))
);
*recv_count_ptr += 1;
}
Expand Down
6 changes: 5 additions & 1 deletion tests/elastic/test_pp.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,16 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
group, explicitly_destroy=True, allow_hybrid_mode=False,
num_bytes=deep_ep.ElasticBuffer.get_pp_buffer_size_hint(
num_max_tensor_bytes, num_max_inflight_tensors))
buffer.pp_set_config(num_max_tensor_bytes, num_max_inflight_tensors)
requested_num_qps = args.num_qps
buffer.pp_set_config(num_max_tensor_bytes, num_max_inflight_tensors, num_qps=requested_num_qps)
num_qps = buffer.num_pp_qps

# Print configs
assert num_ranks > 1
dist_print(f'Config:\n'
f' > Ranks: {num_ranks}\n'
f' > Shape: {shape}\n'
f' > QPs: {num_qps}/{buffer.num_allocated_qps}\n'
f' > Max inflight tensors: {num_max_inflight_tensors}\n',
once_in_node=True)

Expand Down Expand Up @@ -127,6 +130,7 @@ def loop(_hide_rdma_latency=hide_rdma_latency):
parser.add_argument('--num-tokens', type=int, default=4096)
parser.add_argument('--hidden', type=int, default=7168)
parser.add_argument('--num-max-inflight-tensors', type=int, default=4)
parser.add_argument('--num_qps', type=int, default=0)
parser.add_argument('--num-stress-iterations', type=int, default=4)
parser.add_argument('--num-sends', type=int, default=128)
parser.add_argument('--num-sleep-cycles', type=int, default=10 ** 7)
Expand Down