Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions csrc/kernels/legacy/internode.cu
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,8 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + LEGACY_NUM
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / 32;
const auto thread_id = static_cast<int>(threadIdx.x), warp_id = thread_id / 32, lane_id = get_lane_id();
const auto num_devs = ibgda_get_state()->num_devices_initialized;
const auto rc_per_pe = ibgda_get_state()->num_rc_per_pe;
const auto num_channels = num_sms / 2, channel_id = sm_id / 2;
const bool is_forwarder = sm_id % 2 == 0;
const auto rdma_rank = rank / LEGACY_NUM_MAX_NVL_PEERS, nvl_rank = rank % LEGACY_NUM_MAX_NVL_PEERS;
Expand Down Expand Up @@ -811,6 +813,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + LEGACY_NUM

// Issue RDMA send
auto num_tokens_to_issue = min(num_tokens_processed, num_max_rdma_chunked_send_tokens);
const auto qp_id = channel_id + (channel_id % num_devs) * rc_per_pe;
EP_DEVICE_ASSERT(num_tokens_to_issue >= 0 and num_tokens_to_issue <= synced_num_tokens_to_send);
if (dst_rdma_rank != rdma_rank) {
auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens;
Expand All @@ -824,7 +827,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + LEGACY_NUM
src_ptr,
num_bytes_per_msg,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank),
channel_id,
qp_id,
lane_id,
0);
} else {
Expand All @@ -840,7 +843,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + LEGACY_NUM
nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank),
num_tokens_to_issue,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank),
channel_id,
qp_id,
dst_rdma_rank == rdma_rank);
}
__syncwarp();
Expand Down Expand Up @@ -1045,12 +1048,13 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + LEGACY_NUM
break;

// Update remote head
const auto qp_id = channel_id + (channel_id % num_devs) * rc_per_pe + num_channels;
if (min_head != std::numeric_limits<int>::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and
lane_id < kNumRDMARanks) {
nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank),
min_head - last_head,
translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank),
channel_id + num_channels,
qp_id,
lane_id == rdma_rank);
last_head = min_head;
}
Expand Down Expand Up @@ -1749,6 +1753,8 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * 32, 1) combine(int4* co
const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / 32;
const auto thread_id = static_cast<int>(threadIdx.x), lane_id = get_lane_id();
const auto num_channels = static_cast<int>(gridDim.x) / 2, channel_id = sm_id / 2;
const auto num_devs = ibgda_get_state()->num_devices_initialized;
const auto rc_per_pe = ibgda_get_state()->num_rc_per_pe;
const bool is_forwarder_sm = sm_id % 2 == 1;

EP_DEVICE_ASSERT(num_topk <= 32);
Expand Down Expand Up @@ -2113,6 +2119,7 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * 32, 1) combine(int4* co

// Issue RDMA send
if (sub_warp_id == kNumWarpsPerForwarder - 1) {
const auto qp_id = channel_id + (channel_id % num_devs) * rc_per_pe;
if (dst_rdma_rank != rdma_rank) {
auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens;
const size_t num_bytes_per_msg = num_chunked_tokens * num_bytes_per_token;
Expand All @@ -2124,7 +2131,7 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * 32, 1) combine(int4* co
src_ptr,
num_bytes_per_msg,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank),
channel_id,
qp_id,
lane_id,
0);
} else {
Expand All @@ -2137,7 +2144,7 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * 32, 1) combine(int4* co
nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank),
num_chunked_tokens,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank),
channel_id,
qp_id,
dst_rdma_rank == rdma_rank);
}
}
Expand Down Expand Up @@ -2250,12 +2257,13 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * 32, 1) combine(int4* co
for (int i = 0; i < kNumRDMAReceivers; ++i)
if (not rdma_receiver_retired[i])
min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
const auto qp_id = channel_id + (channel_id % num_devs) * rc_per_pe + num_channels;
if (min_head != std::numeric_limits<int>::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and
lane_id < kNumRDMARanks) {
nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank),
min_head - last_rdma_head,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank),
channel_id + num_channels,
qp_id,
dst_rdma_rank == rdma_rank);
last_rdma_head = min_head;
}
Expand Down
17 changes: 12 additions & 5 deletions csrc/kernels/legacy/internode_ll.cu
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ __global__ __launch_bounds__(1024, 1) void dispatch(void* packed_recv_x,
const auto num_sms = static_cast<int>(gridDim.x);
const auto num_warps = num_warp_groups * num_warps_per_group;
const auto num_local_experts = num_experts / num_ranks;
const auto num_devs = ibgda_get_state()->num_devices_initialized;
const auto warp_group_id = warp_id / num_warps_per_group;
const auto sub_warp_id = warp_id % num_warps_per_group;
const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
Expand Down Expand Up @@ -261,9 +262,10 @@ __global__ __launch_bounds__(1024, 1) void dispatch(void* packed_recv_x,
dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + slot_idx * num_bytes_per_msg;
const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
const auto qp_id = dst_expert_local_idx + (dst_expert_local_idx % num_devs) * num_local_experts;
if (not is_rank_masked<true>(mask_buffer_ptr, dst_rank)) {
if (dst_p2p_ptr == 0) {
nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx);
nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, qp_id, lane_id, slot_idx);
} else {
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
Expand Down Expand Up @@ -331,9 +333,10 @@ __global__ __launch_bounds__(1024, 1) void dispatch(void* packed_recv_x,
;
auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank);
auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
const auto qp_id = dst_expert_local_idx + (dst_expert_local_idx % num_devs) * num_local_experts;
if (not is_rank_masked(mask_buffer_ptr, dst_rank)) {
if (dst_p2p_ptr == 0) {
nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast<int*>(dst_ptr), -num_tokens_sent - 1, dst_rank, dst_expert_local_idx);
nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast<int*>(dst_ptr), -num_tokens_sent - 1, dst_rank, qp_id);
} else {
st_release_sys_global(reinterpret_cast<int*>(dst_p2p_ptr), -num_tokens_sent - 1);
}
Expand Down Expand Up @@ -743,6 +746,7 @@ __global__ __launch_bounds__(1024, 1) void combine(void* combined_x,
const auto num_threads = __shfl_sync(0xffffffff, static_cast<int>(blockDim.x), 0);
const auto warp_id = __shfl_sync(0xffffffff, thread_id / 32, 0), lane_id = get_lane_id();
const auto num_local_experts = num_experts / num_ranks;
const auto num_devs = ibgda_get_state()->num_devices_initialized;
const auto warp_group_id = warp_id / num_warps_per_group;
const auto sub_warp_id = warp_id % num_warps_per_group;
const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
Expand Down Expand Up @@ -907,8 +911,10 @@ __global__ __launch_bounds__(1024, 1) void combine(void* combined_x,

// Issue RDMA
// NOTES: for zero-copy mode, we assume the data is already in the send buffer
if (dst_p2p_ptr == 0)
nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, num_send_bytes, dst_rank, local_expert_idx, lane_id, token_idx - offset);
if (dst_p2p_ptr == 0) {
const auto qp_id = local_expert_idx + (local_expert_idx % num_devs) * num_local_experts;
nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, num_send_bytes, dst_rank, qp_id, lane_id, token_idx - offset);
}
}
}

Expand All @@ -920,9 +926,10 @@ __global__ __launch_bounds__(1024, 1) void combine(void* combined_x,
;
auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_flag + global_expert_idx);
auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
const auto qp_id = local_expert_idx + (local_expert_idx % num_devs) * num_local_experts;
if (not is_rank_masked(mask_buffer_ptr, dst_rank)) {
if (dst_p2p_ptr == 0) {
nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast<int*>(dst_ptr), 1, dst_rank, local_expert_idx);
nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast<int*>(dst_ptr), 1, dst_rank, qp_id);
} else {
st_release_sys_global(reinterpret_cast<int*>(dst_p2p_ptr), 1);
}
Expand Down