def torch_pre_attn_qkv_a2a_reference(group, q_input, k_input, v_input, skip_q_a2a=False, debug=False, tensor_name="", rank=0):
    """
    Golden PyTorch implementation of the pre-attention Q/K/V all-to-all.

    Inputs are laid out as [BATCH_SIZE, SEQ_PER_PE, NUM_HEADS, HEAD_DIM]
    (Q_NUM_HEADS for q, KV_NUM_HEADS for k/v). Outputs are laid out as
    [BATCH_SIZE, SEQ_LEN, NUM_HEADS_PER_PE, HEAD_DIM]: the sequence shards of
    all PEs are gathered while the heads are scattered across PEs.

    Args:
        group: Distributed process group used for the collective.
        q_input: Q tensor held by the current PE.
        k_input: K tensor held by the current PE.
        v_input: V tensor held by the current PE.
        skip_q_a2a: When set, Q is not exchanged and its output slot is None.
        debug: Print shapes and values of the intermediate tensors (rank 0 only).
        tensor_name: Prefix used in the debug banner.
        rank: Rank of the caller, only consulted for debug printing.

    Returns:
        The list [q_output, k_output, v_output].
    """
    num_pes = dist.get_world_size(group)
    verbose = debug and rank == 0

    def _exchange(tensor, label=""):
        # [B, S_local, H, D] -> [H, S_local, B, D] -> all_to_all_single -> [B, S_global, H_local, D]
        if verbose:
            print(f"\n=== {tensor_name} {label} ===")
            print(f"Input shape: {tensor.shape}")
            print(f"Input tensor:\n{tensor}")

        send_buf = tensor.permute(2, 1, 0, 3).contiguous()
        if verbose:
            print(f"After permute [2,1,0,3]: {send_buf.shape}")

        if send_buf.shape[0] < num_pes:
            # Fewer heads than PEs: replicate heads so every PE receives one.
            assert num_pes % send_buf.shape[0] == 0
            repeats = num_pes // send_buf.shape[0]
            send_buf = torch.repeat_interleave(send_buf, repeats=repeats, dim=0).contiguous()
            if verbose:
                print(f"After repeat_interleave (repeats={repeats}): {send_buf.shape}")

        heads, seq_local, batch, dim = send_buf.shape
        assert heads % num_pes == 0
        heads_local = heads // num_pes

        recv_buf = torch.empty(
            (num_pes, heads_local, seq_local, batch, dim),
            dtype=send_buf.dtype,
            device=send_buf.device,
            requires_grad=False,
        )
        if verbose:
            print(f"Before all_to_all_single - send shape: {send_buf.shape}, recv shape: {recv_buf.shape}")

        dist.all_to_all_single(recv_buf, send_buf, group=group)
        if verbose:
            print(f"After all_to_all_single: {recv_buf.shape}")

        # [PE, H_local, S_local, B, D] -> [B, PE, S_local, H_local, D] -> [B, S_global, H_local, D]
        gathered = recv_buf.permute(3, 0, 2, 1, 4)
        gathered = gathered.reshape(batch, seq_local * num_pes, heads_local, dim).contiguous()
        if verbose:
            print(f"After final reshape: {gathered.shape}")
            print(f"Output tensor:\n{gathered}")
        return gathered

    # The collectives must be issued in the same Q, K, V order on every rank.
    if skip_q_a2a:
        q_output = None
    else:
        q_output = _exchange(q_input, "(Q)")
    k_output = _exchange(k_input, "(K)")
    v_output = _exchange(v_input, "(V)")
    return [q_output, k_output, v_output]
def _cp_engine_copy_heads_by_batch(
    dst_base_ptr,
    src_base_ptr,
    dst_batch_stride_bytes,
    src_batch_stride_bytes,
    dst_seq_offset,
    src_head_offset,
    seq_per_pe,
    heads_per_pe,
    total_src_heads,
    head_dim,
    dtype_itemsize,
    batch_size,
    stream,
):
    """
    Issue one 2D async copy per batch entry, moving a head slice
    [B, S_local, H_slice, D] out of a source laid out as [B, S_local, H_total, D]
    into a destination laid out as [B, S_global, H_slice, D].

    All pointers and strides are raw byte addresses/offsets; offsets given in
    heads or sequence rows are converted to bytes here.
    """
    head_bytes = head_dim * dtype_itemsize
    # One destination row holds exactly the copied head slice, so the copy
    # width doubles as the destination pitch.
    width_bytes = heads_per_pe * head_bytes
    src_pitch_bytes = total_src_heads * head_bytes

    # Addresses for batch 0; advanced by the per-batch strides below.
    src_cursor = src_base_ptr + src_head_offset * head_bytes
    dst_cursor = dst_base_ptr + dst_seq_offset * width_bytes

    for _ in range(batch_size):
        (status,) = cudart.cudaMemcpy2DAsync(
            dst_cursor,
            width_bytes,
            src_cursor,
            src_pitch_bytes,
            width_bytes,
            seq_per_pe,
            cudart.cudaMemcpyKind.cudaMemcpyDefault,
            stream.cuda_stream,
        )
        CUDA_CHECK(status)
        src_cursor += src_batch_stride_bytes
        dst_cursor += dst_batch_stride_bytes
+ """ + rank_orders = [(local_rank + i) % local_world_size for i in range(local_world_size)] + + # Metadata + batch_size, _, total_heads, head_dim = data_src_peers[local_rank].shape + dtype_itemsize = data_src_peers[local_rank].dtype.itemsize + src_batch_stride_bytes = data_src_peers[local_rank].stride(0) * dtype_itemsize + dst_batch_stride_bytes = data_dst_peers[local_rank].stride(0) * dtype_itemsize + + with torch.cuda.stream(stream): + for src_rank in rank_orders: + src_tensor = data_src_peers[src_rank] + dst_tensor = data_dst_peers[local_rank] + _cp_engine_copy_heads_by_batch( + dst_base_ptr=dst_tensor.data_ptr(), + src_base_ptr=src_tensor.data_ptr(), + dst_batch_stride_bytes=dst_batch_stride_bytes, + src_batch_stride_bytes=src_batch_stride_bytes, + dst_seq_offset=src_rank * SEQ_PER_PE, + src_head_offset=local_rank * HEADS_PER_PE, + seq_per_pe=SEQ_PER_PE, + heads_per_pe=HEADS_PER_PE, + total_src_heads=total_heads, + head_dim=head_dim, + dtype_itemsize=dtype_itemsize, + batch_size=batch_size, + stream=stream, + ) + + +def custom_ipc_pre_attn_qkv_a2a( + q_src_peers, + k_src_peers, + v_src_peers, + q_dst_peers, + k_dst_peers, + v_dst_peers, + local_rank, + local_world_size, + q_heads_per_pe, + kv_heads_per_pe, + seq_per_pe, + stream, + skip_q_a2a=False, +): + """ + Fused P2P IPC-based all-to-all dimension swap matching the PyTorch golden reference. + Uses cudaMemcpyAsync for minimal Python overhead and maximal GPU utilization. 
+ """ + rank_orders = [(local_rank + i) % local_world_size for i in range(local_world_size)] + + # Metadata + batch_size = q_src_peers[local_rank].shape[0] + head_dim = q_src_peers[local_rank].shape[3] + q_total_heads = q_src_peers[local_rank].shape[2] + kv_total_heads = k_src_peers[local_rank].shape[2] + + dtype_itemsize = q_src_peers[local_rank].dtype.itemsize + k_dtype_itemsize = k_src_peers[local_rank].dtype.itemsize + + q_src_batch_stride_bytes = q_src_peers[local_rank].stride(0) * dtype_itemsize + q_dst_batch_stride_bytes = q_dst_peers[local_rank].stride(0) * dtype_itemsize + k_src_batch_stride_bytes = k_src_peers[local_rank].stride(0) * k_dtype_itemsize + k_dst_batch_stride_bytes = k_dst_peers[local_rank].stride(0) * k_dtype_itemsize + v_src_batch_stride_bytes = v_src_peers[local_rank].stride(0) * k_dtype_itemsize + v_dst_batch_stride_bytes = v_dst_peers[local_rank].stride(0) * k_dtype_itemsize + + with torch.cuda.stream(stream): + for src_rank in rank_orders: + if not skip_q_a2a: + q_src = q_src_peers[src_rank] + q_dst = q_dst_peers[local_rank] + _cp_engine_copy_heads_by_batch( + dst_base_ptr=q_dst.data_ptr(), + src_base_ptr=q_src.data_ptr(), + dst_batch_stride_bytes=q_dst_batch_stride_bytes, + src_batch_stride_bytes=q_src_batch_stride_bytes, + dst_seq_offset=src_rank * seq_per_pe, + src_head_offset=local_rank * q_heads_per_pe, + seq_per_pe=seq_per_pe, + heads_per_pe=q_heads_per_pe, + total_src_heads=q_total_heads, + head_dim=head_dim, + dtype_itemsize=dtype_itemsize, + batch_size=batch_size, + stream=stream, + ) + + k_src = k_src_peers[src_rank] + k_dst = k_dst_peers[local_rank] + _cp_engine_copy_heads_by_batch( + dst_base_ptr=k_dst.data_ptr(), + src_base_ptr=k_src.data_ptr(), + dst_batch_stride_bytes=k_dst_batch_stride_bytes, + src_batch_stride_bytes=k_src_batch_stride_bytes, + dst_seq_offset=src_rank * seq_per_pe, + src_head_offset=local_rank * kv_heads_per_pe, + seq_per_pe=seq_per_pe, + heads_per_pe=kv_heads_per_pe, + total_src_heads=kv_total_heads, + 
def pack_local_qkv_for_fused_a2a(
    q_input,
    k_input,
    v_input,
    qkv_src_local,
    local_world_size,
    q_heads_per_pe,
    kv_heads_per_pe,
    skip_q_a2a=False,
):
    """
    Write the local Q/K/V heads into ``qkv_src_local`` grouped by destination rank:
    [rank0(Q,K,V), rank1(Q,K,V), ...], or [ranki(K,V)] when ``skip_q_a2a`` is set.

    With this layout the heads destined for one receiver are adjacent along
    the head axis. The buffer is zero-filled first and modified in place;
    nothing is returned.
    """
    # (tensor, heads per destination rank) in the order they appear inside a rank's slot.
    sources = [(k_input, kv_heads_per_pe), (v_input, kv_heads_per_pe)]
    if not skip_q_a2a:
        sources.insert(0, (q_input, q_heads_per_pe))
    slot_heads = sum(width for _, width in sources)

    qkv_src_local.fill_(0)
    for dst_rank in range(local_world_size):
        cursor = dst_rank * slot_heads
        for tensor, width in sources:
            chunk = tensor[:, :, dst_rank * width : (dst_rank + 1) * width, :]
            qkv_src_local[:, :, cursor : cursor + width, :].copy_(chunk)
            cursor += width
def split_packed_qkv_buffer(
    packed_qkv,
    q_heads_per_pe,
    kv_heads_per_pe,
    skip_q_a2a=False,
):
    """
    Slice a packed QKV tensor into (q, k, v) along the head axis (dim 2).

    The head axis is expected to hold [Q, K, V], or [K, V] when Q was skipped.
    The results are slices of ``packed_qkv`` (no copy is made); the Q entry is
    None when ``skip_q_a2a`` is set.
    """
    if skip_q_a2a:
        q_view = None
        k_begin = 0
    else:
        q_view = packed_qkv[:, :, :q_heads_per_pe, :]
        k_begin = q_heads_per_pe

    v_begin = k_begin + kv_heads_per_pe
    v_end = k_begin + 2 * kv_heads_per_pe
    k_view = packed_qkv[:, :, k_begin:v_begin, :]
    v_view = packed_qkv[:, :, v_begin:v_end, :]
    return q_view, k_view, v_view
def parse_args():
    """Parse the command-line options of the IPC all-to-all example.

    ``--num_processes`` defaults to ``None`` rather than a fixed count. The
    script entry point already replaces a ``None`` value with
    ``LOCAL_WORLD_SIZE`` or the number of visible GPUs; with the previous
    hard-coded default of 4 that fallback could never run.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num_processes",
        type=int,
        default=None,
        help="Number of GPUs to spawn (default: LOCAL_WORLD_SIZE or the number of visible GPUs)",
    )
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--seq_len", type=int, default=256)
    parser.add_argument("--num_heads", type=int, default=16)
    parser.add_argument("--gqa", type=int, default=1, help="group size of group query attn")
    parser.add_argument("--head_dim", type=int, default=64)
    parser.add_argument("--skip_q_a2a", default=False, action="store_true", help="skip q all-to-all")
    parser.add_argument("--dtype", default="float16")
    parser.add_argument("--debug", default=False, action="store_true", help="print debug info with sequential integer inputs")
    parser.add_argument("--warmup", type=int, default=5, help="Number of warmup iterations for profiling")
    parser.add_argument("--iters", type=int, default=10, help="Number of iterations for profiling")
    return parser.parse_args()
= WORLD_SIZE + assert args.seq_len % PE_num == 0 + assert args.num_heads % PE_num == 0 + assert args.gqa > 0 + assert args.num_heads % args.gqa == 0 + + SEQ_PER_PE = args.seq_len // PE_num + Q_HEADS_PER_PE = args.num_heads // PE_num + kv_num_heads = args.num_heads // args.gqa + assert kv_num_heads % PE_num == 0 + KV_HEADS_PER_PE = kv_num_heads // PE_num + local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", WORLD_SIZE)) + assert PE_num == local_world_size, "IPC mode in this example only supports single-node runs" + + allocator = tilelang.get_allocator( + size=2**30, + device="cuda", + is_distributed=True, + local_rank=LOCAL_RANK, + num_local_ranks=local_world_size, + group=LC_GROUP, + ) + + if RANK == 0: + print("=== All-to-All with IPC and PyTorch Golden Reference ===") + print(f"Batch size: {args.batch_size}") + print(f"Sequence length: {args.seq_len}") + print(f"Number of heads: {args.num_heads}") + print(f"KV heads: {kv_num_heads}") + print(f"GQA group size: {args.gqa}") + print(f"Head dimension: {args.head_dim}") + print(f"PE count: {PE_num}") + print(f"Sequence per PE: {SEQ_PER_PE}") + print(f"Q heads per PE: {Q_HEADS_PER_PE}") + print(f"KV heads per PE: {KV_HEADS_PER_PE}") + print(f"Skip q all-to-all: {args.skip_q_a2a}") + + dtype_torch = dtype_map[args.dtype] + + # PyTorch reference inputs + if args.debug: + # Use sequential integers for easy visualization + q_size = args.batch_size * SEQ_PER_PE * args.num_heads * args.head_dim + q_input = torch.arange(q_size, dtype=dtype_torch, device="cuda").reshape( + [args.batch_size, SEQ_PER_PE, args.num_heads, args.head_dim] + ) + + k_size = args.batch_size * SEQ_PER_PE * kv_num_heads * args.head_dim + k_input = torch.arange(k_size, dtype=dtype_torch, device="cuda").reshape( + [args.batch_size, SEQ_PER_PE, kv_num_heads, args.head_dim] + ) + v_input = torch.arange(k_size, dtype=dtype_torch, device="cuda").reshape( + [args.batch_size, SEQ_PER_PE, kv_num_heads, args.head_dim] + ) + else: + q_input = 
torch.rand([args.batch_size, SEQ_PER_PE, args.num_heads, args.head_dim], dtype=dtype_torch, device="cuda") + k_input = torch.rand([args.batch_size, SEQ_PER_PE, kv_num_heads, args.head_dim], dtype=dtype_torch, device="cuda") + v_input = torch.rand([args.batch_size, SEQ_PER_PE, kv_num_heads, args.head_dim], dtype=dtype_torch, device="cuda") + + # PyTorch Reference run + dist.barrier(TP_GROUP) + torch_q_out, torch_k_out, torch_v_out = torch_pre_attn_qkv_a2a_reference( + TP_GROUP, + q_input, + k_input, + v_input, + skip_q_a2a=args.skip_q_a2a, + debug=args.debug, + tensor_name=f"Rank {RANK}", + rank=RANK, + ) + + # Packed fused IPC memory allocation + q_block = 0 if args.skip_q_a2a else Q_HEADS_PER_PE + packed_heads_per_rank = q_block + 2 * KV_HEADS_PER_PE + packed_total_src_heads = packed_heads_per_rank * local_world_size + + qkv_src_peers = tilelang.tensor( + (args.batch_size, SEQ_PER_PE, packed_total_src_heads, args.head_dim), + dtype_torch, + allocator=allocator, + return_peers=True, + ) + qkv_dst_peers = tilelang.tensor( + (args.batch_size, args.seq_len, packed_heads_per_rank, args.head_dim), + dtype_torch, + allocator=allocator, + return_peers=True, + ) + + # Local initialization: pack once, then fused all-to-all copies packed payload + pack_local_qkv_for_fused_a2a( + q_input, + k_input, + v_input, + qkv_src_peers[LOCAL_RANK], + local_world_size, + Q_HEADS_PER_PE, + KV_HEADS_PER_PE, + skip_q_a2a=args.skip_q_a2a, + ) + qkv_dst_peers[LOCAL_RANK].fill_(0.0) + + torch.cuda.synchronize() + dist.barrier(LC_GROUP) + + # Run IPC data transfer + stream = torch.cuda.Stream() + custom_ipc_pre_attn_qkv_a2a_packed_fused( + qkv_src_peers, + qkv_dst_peers, + LOCAL_RANK, + local_world_size, + packed_heads_per_rank, + SEQ_PER_PE, + stream, + ) + stream.synchronize() + dist.barrier(LC_GROUP) + + qkv_dst_local = qkv_dst_peers[LOCAL_RANK] + custom_q_out, custom_k_out, custom_v_out = split_packed_qkv_buffer( + qkv_dst_local, + Q_HEADS_PER_PE, + KV_HEADS_PER_PE, + 
skip_q_a2a=args.skip_q_a2a, + ) + + if RANK == 0: + print("Finished IPC Q/K/V all-to-all.") + + # Collect verification results + results = [] + if not args.skip_q_a2a: + results.append(verify_results(custom_q_out, torch_q_out, RANK, "(Q)")) + results.append(verify_results(custom_k_out, torch_k_out, RANK, "(K)")) + results.append(verify_results(custom_v_out, torch_v_out, RANK, "(V)")) + + # Output unified result + if all(results): + if RANK == 0: + print("✅ All Verification PASSED!") + else: + print(f"❌ PE {RANK} Verification FAILED!") + + # --- Profiling --- + warmup = args.warmup + iters = args.iters + + # Profiling Torch Baseline + torch_times = [] + for i in range(warmup + iters): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + torch.cuda.synchronize() + dist.barrier(TP_GROUP) + start_event.record() + + torch_pre_attn_qkv_a2a_reference( + TP_GROUP, + q_input, + k_input, + v_input, + skip_q_a2a=args.skip_q_a2a, + ) + + end_event.record() + end_event.synchronize() + + if i >= warmup: + torch_times.append(start_event.elapsed_time(end_event)) + + # Profiling Custom IPC All-to-All + custom_times = [] + profile_stream = torch.cuda.Stream() + for i in range(warmup + iters): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + torch.cuda.synchronize() + dist.barrier(LC_GROUP) + start_event.record(profile_stream) + + custom_ipc_pre_attn_qkv_a2a_packed_fused( + qkv_src_peers, + qkv_dst_peers, + LOCAL_RANK, + local_world_size, + packed_heads_per_rank, + SEQ_PER_PE, + profile_stream, + ) + + profile_stream.synchronize() + # Important: Add barrier to ensure all peers have finished pulling before advancing + dist.barrier(LC_GROUP) + end_event.record() + end_event.synchronize() + + if i >= warmup: + custom_times.append(start_event.elapsed_time(end_event)) + + if RANK == 0: + torch_avg = sum(torch_times) / iters + custom_avg = sum(custom_times) / iters + 
def test_all_to_all_with_golden_reference_spawn(local_rank: int, num_local_ranks: int, args):
    """Per-process entry point handed to ``torch.multiprocessing.spawn``.

    Args:
        local_rank: Index of this process, supplied by ``spawn``.
        num_local_ranks: Number of spawned processes.
        args: Parsed command-line namespace forwarded to the run function.
    """
    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
    # The same group is passed for both the TP_GROUP and LC_GROUP parameters.
    run_all_to_all_with_golden_reference(args, num_ranks, rank, local_rank, group, group)
def torch_attention_reference(q_out, k_out, v_out, is_causal, q_start_offsets=None):
    """
    Reference attention computed one batch entry at a time with PyTorch SDPA.

    q_out is [B, S, Hq_local, D]; k_out / v_out are [B, S, Hkv_local, D]. Each
    KV head is repeated Hq_local // Hkv_local times so it serves a whole group
    of query heads. When ``is_causal`` is set, key position j is visible to
    query position i only if j <= i. ``q_start_offsets`` is accepted but not
    used by this reference.

    Returns:
        Tensor of shape [B*S, Hq_local, D] (batch entries concatenated on dim 0).
    """
    num_batches, q_len, num_q_heads, _ = q_out.shape
    num_kv_heads = k_out.shape[2]
    assert num_q_heads % num_kv_heads == 0
    group_size = num_q_heads // num_kv_heads

    # The mask depends only on the sequence lengths, so build it once.
    causal_mask = None
    if is_causal:
        rows = torch.arange(q_len, device=q_out.device)[:, None]
        cols = torch.arange(k_out.shape[1], device=q_out.device)[None, :]
        causal_mask = cols <= rows

    def _heads_first(sample):
        # [S, H, D] -> [1, H, S, D]
        return sample.permute(1, 0, 2).unsqueeze(0).contiguous()

    per_batch = []
    for idx in range(num_batches):
        query = _heads_first(q_out[idx])
        key = _heads_first(k_out[idx]).repeat_interleave(group_size, dim=1)
        value = _heads_first(v_out[idx]).repeat_interleave(group_size, dim=1)

        attended = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask)
        # [1, Hq, S, D] -> [S, Hq, D]
        per_batch.append(attended.squeeze(0).permute(1, 0, 2).contiguous())

    return torch.cat(per_batch, dim=0)
class FusedSequenceParallelAll2AllAttn(torch.nn.Module):
    """
    Module wrapper around the fused sequence-parallel all-to-all attention path.

    Construction derives the per-rank head counts and Q shard length from the
    process group size and creates the intra-node context once; ``forward``
    runs the fused op and returns the context's attention output buffer.
    """

    def __init__(
        self,
        pg: torch.distributed.ProcessGroup,
        batch_size: int,
        q_head: int,
        kv_head: int,
        max_seqlen_q: int,
        max_seqlen_k: int,
        head_dim: int,
        input_dtype=torch.float16,
        output_dtype=torch.float16,
        device="cuda",
        is_causal=True,
        enable_zig_zag=True,
        allocator=None,
    ):
        super().__init__()

        # Process-group facts.
        self.pg = pg
        self.rank = pg.rank()
        self.world_size = pg.size()

        # Problem description, stored verbatim.
        self.batch_size = batch_size
        self.q_head = q_head
        self.kv_head = kv_head
        self.max_seqlen_q = max_seqlen_q
        self.max_seqlen_k = max_seqlen_k
        self.head_dim = head_dim
        self.input_dtype = input_dtype
        self.output_dtype = output_dtype
        self.device = device
        self.is_causal = is_causal
        self.enable_zig_zag = enable_zig_zag
        self.allocator = allocator

        # Heads are split evenly across ranks.
        assert self.q_head % self.world_size == 0, "q_head should be divisible by world_size"
        assert self.kv_head % self.world_size == 0, "kv_head should be divisible by world_size"
        self.q_head_per_rank = self.q_head // self.world_size
        self.kv_head_per_rank = self.kv_head // self.world_size
        self.max_q_shard_len = self.max_seqlen_q // self.world_size

        self.ctx = create_sp_all2all_attention_context_intra_node(
            self.batch_size,
            self.q_head_per_rank,
            self.kv_head_per_rank,
            self.max_seqlen_k,
            self.max_q_shard_len,
            self.head_dim,
            self.input_dtype,
            self.output_dtype,
            self.rank,
            self.world_size,
            self.device,
            self.allocator,
        )

    def forward(self, packed_qkv_shards, cu_seqlens_q, cu_seqlens_k, print_source=False):
        """Run the fused op into the context's output buffer and return that buffer."""
        result = self.ctx.attn_output_buffer
        fused_sp_all2all_attn_intra_node(
            self.ctx,
            packed_qkv_shards,
            result,
            cu_seqlens_q,
            cu_seqlens_k,
            self.max_q_shard_len,
            self.rank,
            self.world_size,
            self.q_head_per_rank,
            self.kv_head_per_rank,
            self.is_causal,
            self.enable_zig_zag,
            print_source,
        )
        return result
device=device) + + allocator = tilelang.get_allocator( + size=2**30, + device=device, + is_distributed=True, + local_rank=local_rank, + num_local_ranks=num_local_ranks, + group=group, + ) + + seq_per_rank_max = max_seqlen_k // num_ranks + q_head_per_rank = q_head // num_ranks + kv_head_per_rank = kv_head // num_ranks + packed_heads_per_rank = q_head_per_rank + 2 * kv_head_per_rank + packed_total_heads = packed_heads_per_rank * num_ranks + + packed_qkv_shards = tilelang.tensor( + (batch_size * seq_per_rank_max, packed_total_heads, head_dim), + dtype=dtype, + allocator=allocator, + return_peers=True, + ) + + q_input = torch.randn((batch_size, seq_per_rank_max, q_head, head_dim), dtype=dtype, device=device) + k_input = torch.randn((batch_size, seq_per_rank_max, kv_head, head_dim), dtype=dtype, device=device) + v_input = torch.randn((batch_size, seq_per_rank_max, kv_head, head_dim), dtype=dtype, device=device) + local_q_batch_lens = [s // num_ranks for s in seqlens_q] + local_q_start_offsets = [rank * local_len for local_len in local_q_batch_lens] + + pack_local_qkv_for_all2all( + q_input, + k_input, + v_input, + packed_qkv_shards[local_rank], + num_ranks, + q_head_per_rank, + kv_head_per_rank, + ) + + dist.barrier(group) + + tilescale_module = FusedSequenceParallelAll2AllAttn( + group, + batch_size, + q_head, + kv_head, + max_seqlen_q, + max_seqlen_k, + head_dim, + dtype, + dtype, + device, + is_causal, + enable_zig_zag, + allocator=allocator, + ) + + tilescale_out = tilescale_module(packed_qkv_shards, cu_seqlens_q, cu_seqlens_k, print_source=args.print_source) + # valid_q_tokens = int(cu_seqlens_q[-1].item()) + + torch_q_out, torch_k_out, torch_v_out = torch_pre_attn_qkv_a2a_reference(group, q_input, k_input, v_input) + torch_out = torch_attention_reference(torch_q_out, torch_k_out, torch_v_out, is_causal, local_q_start_offsets) + + torch_out_local = [] + tilescale_out_local = [] + torch_cursor = 0 + tilescale_cursor = 0 + for local_len in local_q_batch_lens: + 
def build_arg_parser():
    """Build the command-line parser for this example script.

    ``--zig-zag`` is registered once with ``argparse.BooleanOptionalAction``,
    which generates the matching ``--no-zig-zag`` flag by itself. Passing
    ``--no-zig-zag`` explicitly as a second option string makes the action
    derive a ``--no-no-zig-zag`` alias and a duplicated ``--no-zig-zag``
    entry, so only the positive spelling is listed here.

    Returns:
        argparse.ArgumentParser: parser with every example option registered.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-processes", type=int, default=2, help="Number of processes to spawn (default: 2)")
    parser.add_argument("--batch_size", type=int, default=2, help="batch size")
    parser.add_argument("--q_head", type=int, default=32, help="local num q heads per rank")
    parser.add_argument("--kv_head", type=int, default=8, help="local num kv heads per rank")
    parser.add_argument("--max_seqlen_q", type=int, default=4096, help="max sequence length of q")
    parser.add_argument("--max_seqlen_k", type=int, default=4096, help="max sequence length of k/v")
    parser.add_argument("--head_dim", type=int, default=128, help="head dim")
    parser.add_argument("--seqlens_q", type=int, nargs="+", default=[4096, 4096], help="sequence lengths of q")
    parser.add_argument("--seqlens_k", type=int, nargs="+", default=[4096, 4096], help="sequence lengths of k/v")
    parser.add_argument("--is_causal", action="store_true", help="causal")
    # BooleanOptionalAction adds the "--no-zig-zag" negation automatically.
    parser.add_argument(
        "--zig-zag",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="enable zig zag opt",
    )
    parser.add_argument("--print_source", action="store_true", help="print kernel source")
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    num_processes = args.num_processes

    # One worker per process; each receives (local_rank, num_processes, args).
    torch.multiprocessing.spawn(main, args=(num_processes, args), nprocs=num_processes)
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + packed_heads = heads + 2 * head_kv + packed_shape = [UQ, packed_heads, dim] + o_shape = [UQ, heads, dim] + dtype = "float16" + accum_dtype = "float" + + @T.macro + def inner_packed( + Packed_unpad: T.Tensor(packed_shape, dtype), + Output_unpad: T.Tensor(o_shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + V_shared: T.SharedBuffer([block_N, dim], dtype), + O_shared: T.SharedBuffer([block_M, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + q_load_start_idx: T.int32, + q_write_start_idx: T.int32, + k_start_idx: T.int32, + q_current_seqlen: T.int32, + k_current_seqlen: T.int32, + bx: T.int32, + head_idx: T.int32, + kv_head_idx: T.int32, + global_offset_q: T.int32, + kv_len_per_sp_block: T.int32, + ): + q_head_offset = 0 + k_head_offset = heads + v_head_offset = heads + head_kv + q_token_offset = rank * q_current_seqlen + + T.copy( + Packed_unpad[ + q_load_start_idx + q_token_offset + bx * block_M : q_load_start_idx + q_token_offset + (bx + 1) * block_M, + q_head_offset + head_idx, + :, + ], + Q_shared, + ) + + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + prefix_len = k_current_seqlen - q_current_seqlen * num_ranks + loop_range = ( + T.ceildiv(prefix_len + global_offset_q + (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(k_current_seqlen, block_N) + ) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + sp_block_idx = (k * block_N) // 
kv_len_per_sp_block + wait_rank = sp_block_idx if sp_block_idx < num_ranks else 2 * num_ranks - sp_block_idx - 1 + kv_load_offset = ( + (k * block_N) % kv_len_per_sp_block + + sp_block_idx // num_ranks * kv_len_per_sp_block + + wait_rank * (k_current_seqlen // num_ranks) + ) + T.copy( + Packed_unpad[k_start_idx + kv_load_offset : k_start_idx + kv_load_offset + block_N, k_head_offset + kv_head_idx, :], + K_shared, + ) + + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + (prefix_len + global_offset_q + bx * block_M + i < k * block_N + j) + or (bx * block_M + i >= q_current_seqlen or k * block_N + j >= k_current_seqlen), + -1e9, + 0, + ) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + (bx * block_M + i >= q_current_seqlen or k * block_N + j >= k_current_seqlen), + -1e9, + 0, + ) + + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + T.copy( + Packed_unpad[k_start_idx + kv_load_offset : k_start_idx + kv_load_offset + block_N, v_head_offset + kv_head_idx, :], + V_shared, + ) + + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + + for i, d in 
T.Parallel(block_M, dim): + if bx * block_M + i < q_current_seqlen: + Output_unpad[q_write_start_idx + bx * block_M + i, head_idx, d] = O_shared[i, d] + + @T.prim_func + def main_packed( + Packed_unpad: T.Tensor(packed_shape, dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], "int32"), + cu_seqlens_k: T.Tensor([batch_size + 1], "int32"), + max_seqlen_q: T.int32, + Output_unpad: T.Tensor(o_shape, dtype), + ): + with T.Kernel(T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + batch_idx = bz + head_idx = by + kv_head_idx = head_idx // groups + + q_write_start_idx = cu_seqlens_q[batch_idx] + q_load_start_idx = cu_seqlens_k[batch_idx] + k_start_idx = cu_seqlens_k[batch_idx] + q_end_idx = cu_seqlens_q[batch_idx + 1] + k_end_idx = cu_seqlens_k[batch_idx + 1] + + q_current_seqlen = q_end_idx - q_write_start_idx + k_current_seqlen = k_end_idx - k_start_idx + + global_offset_q = q_current_seqlen * rank + kv_len_per_sp_block = k_current_seqlen // num_ranks + + inner_packed( + Packed_unpad, + Output_unpad, + Q_shared, + K_shared, + V_shared, + O_shared, + acc_s, + acc_s_cast, + acc_o, + scores_max, + scores_max_prev, + scores_scale, + scores_sum, + logsum, + q_load_start_idx, + q_write_start_idx, + k_start_idx, + q_current_seqlen, + k_current_seqlen, + bx, + head_idx, + kv_head_idx, + 
global_offset_q, + kv_len_per_sp_block, + ) + + @T.prim_func + def main_packed_zigzag( + Packed_unpad: T.Tensor(packed_shape, dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], "int32"), + cu_seqlens_k: T.Tensor([batch_size + 1], "int32"), + max_seqlen_q: T.int32, + Output_unpad: T.Tensor(o_shape, dtype), + ): + with T.Kernel(T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + batch_idx = bz + head_idx = by + kv_head_idx = head_idx // groups + + q_write_start_idx = cu_seqlens_q[batch_idx] + q_load_start_idx = cu_seqlens_k[batch_idx] + k_start_idx = cu_seqlens_k[batch_idx] + q_end_idx = cu_seqlens_q[batch_idx + 1] + k_end_idx = cu_seqlens_k[batch_idx + 1] + + q_current_seqlen = q_end_idx - q_write_start_idx + k_current_seqlen = k_end_idx - k_start_idx + + half_q_shard_len = q_current_seqlen // 2 + global_offset_q = ( + rank * half_q_shard_len if bx * block_M < half_q_shard_len else q_current_seqlen * num_ranks - (rank + 2) * half_q_shard_len + ) + kv_len_per_sp_block = k_current_seqlen // (2 * num_ranks) + + inner_packed( + Packed_unpad, + Output_unpad, + Q_shared, + K_shared, + V_shared, + O_shared, + acc_s, + acc_s_cast, + acc_o, + scores_max, + scores_max_prev, + scores_scale, + scores_sum, + logsum, + q_load_start_idx, + q_write_start_idx, + k_start_idx, + q_current_seqlen, 
+ k_current_seqlen, + bx, + head_idx, + kv_head_idx, + global_offset_q, + kv_len_per_sp_block, + ) + + return main_packed if not enable_zig_zag else main_packed_zigzag + + +# def packed_sp_all2all_attention( +# packed_qkv: torch.Tensor, +# output: torch.Tensor, +# cu_seqlens_q: torch.Tensor, +# cu_seqlens_k: torch.Tensor, +# max_seqlen_q: int, +# batch_size: int, +# q_heads: int, +# kv_heads: int, +# groups: int, +# rank: int, +# num_ranks: int, +# is_causal: bool = True, +# enable_zig_zag: bool = True, +# block_M: int = 128, +# block_N: int = 128, +# num_stages: int = 2, +# threads: int = 256, +# ): +# kernel = flashattn_packed( +# batch_size, +# groups, +# packed_qkv.shape[0], +# packed_qkv.shape[0], +# q_heads, +# packed_qkv.shape[-1], +# is_causal, +# enable_zig_zag, +# rank, +# num_ranks, +# block_M=block_M, +# block_N=block_N, +# num_stages=num_stages, +# threads=threads, +# ) +# kernel(packed_qkv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, output) +# return output + + +@dataclass +class SPAll2AllAttentionContextIntraNode: + ag_packed_buffers: List[torch.Tensor] + ag_packed_buffer: torch.Tensor + attn_output_buffer: torch.Tensor + ag_stream: torch.cuda.Stream + barrier: torch.Tensor + + +def create_sp_all2all_attention_context_intra_node( + batch_size, + q_head, + kv_head, + max_seqlen_k, + max_q_shard_len, + head_dim, + input_dtype, + output_dtype, + rank, + world_size, + device, + allocator, +): + packed_heads = q_head + 2 * kv_head + ag_packed_buffers = tilelang.tensor( + (batch_size * max_seqlen_k, packed_heads, head_dim), + dtype=input_dtype, + allocator=allocator, + return_peers=True, + ) + ag_packed_buffer = ag_packed_buffers[rank] + + attn_output_buffer = torch.empty( + batch_size * max_seqlen_k, + q_head, + head_dim, + dtype=output_dtype, + device=device, + ) + + barrier = tilelang.tensor((world_size), dtype=torch.int32, allocator=allocator) + ag_stream = torch.cuda.Stream() + + return SPAll2AllAttentionContextIntraNode( + 
def barrier_all_on_stream(barrier: torch.Tensor, stream: torch.cuda.Stream, world_size: int):
    """Launch the cross-rank barrier kernel on ``stream``.

    Args:
        barrier: Tensor handed to the barrier kernel.
        stream: CUDA stream the barrier kernel is enqueued on.
        world_size: Number of ranks; used to specialize the barrier kernel.
    """
    barrier_kernel = barrier_all_blocks_sys_kernel(world_size)
    with torch.cuda.stream(stream):
        barrier_kernel(barrier)


def cp_engine_producer_packed_all2all(
    packed_shards: list[torch.Tensor],
    packed_buffer: torch.Tensor,
    packed_buffers: list[torch.Tensor],
    cu_seqlens_k: torch.Tensor,
    rank: int,
    world_size: int,
    ag_stream: torch.cuda.Stream,
    compute_stream: torch.cuda.Stream,
    barrier: torch.Tensor,
):
    """Copy this rank's head slice of every rank's packed shard into its buffer.

    For each sequence ``i`` and each source rank, one strided 2D copy reads
    the ``rank``-th group of ``packed_heads_total // world_size`` heads from
    ``packed_shards[src_rank]`` and writes it densely into
    ``packed_buffers[rank]`` starting at token
    ``cu_seqlens_k[i] + src_rank * local_seq_len``, where ``local_seq_len`` is
    the sequence length floor-divided by ``world_size``.

    The local shard is copied on ``compute_stream`` and the remote shards on
    ``ag_stream``. A barrier is then enqueued on ``ag_stream`` and
    ``compute_stream`` is made to wait for ``ag_stream``.

    Args:
        packed_shards: Per-rank source tensors; ``packed_shards[rank]`` must be
            contiguous and is unpacked as ``(tokens, heads, head_dim)``.
        packed_buffer: Only checked for contiguity; the copies are written
            through ``packed_buffers[rank]``.
        packed_buffers: Per-rank destination buffers.
        cu_seqlens_k: Cumulative sequence offsets, ``batch_size + 1`` entries.
        rank: Index of the local rank.
        world_size: Number of ranks.
        ag_stream: Stream used for the remote copies and the barrier.
        compute_stream: Stream used for the local copy.
        barrier: Barrier tensor forwarded to ``barrier_all_on_stream``.
    """
    assert packed_buffer.is_contiguous()
    assert packed_shards[rank].is_contiguous()

    _, packed_heads_total, head_dim = packed_shards[rank].shape
    assert packed_heads_total % world_size == 0
    packed_heads_per_rank = packed_heads_total // world_size
    dtype_itemsize = packed_shards[rank].dtype.itemsize
    src_token_bytes = packed_heads_total * head_dim * dtype_itemsize
    dst_token_bytes = packed_heads_per_rank * head_dim * dtype_itemsize
    # Byte offset of this rank's head group inside one source token.
    src_head_offset_bytes = rank * packed_heads_per_rank * head_dim * dtype_itemsize

    # Read the offsets once: a single host sync instead of one blocking
    # ``.item()`` per value in between the asynchronous copies.
    with torch.cuda.stream(compute_stream):
        seq_offsets = cu_seqlens_k.tolist()

    # (sequence start, source-token start, tokens per rank) for each sequence.
    segments = [
        (seq_start, seq_start // world_size, (seq_end - seq_start) // world_size)
        for seq_start, seq_end in zip(seq_offsets, seq_offsets[1:])
    ]

    def _enqueue_slice_copy(src_rank, seq_start, src_token_start, local_seq_len, stream):
        # Rows are tokens: the source pitch is a full packed token, the
        # destination pitch (and copied width) is one rank's head group.
        src_ptr = packed_shards[src_rank].data_ptr() + src_token_start * src_token_bytes + src_head_offset_bytes
        dst_token_start = seq_start + src_rank * local_seq_len
        dst_ptr = packed_buffers[rank].data_ptr() + dst_token_start * dst_token_bytes
        (err,) = cudart.cudaMemcpy2DAsync(
            dst_ptr,
            dst_token_bytes,
            src_ptr,
            src_token_bytes,
            dst_token_bytes,
            local_seq_len,
            cudart.cudaMemcpyKind.cudaMemcpyDefault,
            stream.cuda_stream,
        )
        CUDA_CHECK(err)

    # self copy: src rank == local rank
    with torch.cuda.stream(compute_stream):
        for seq_start, src_token_start, local_seq_len in segments:
            _enqueue_slice_copy(rank, seq_start, src_token_start, local_seq_len, compute_stream)

    # pull from remote src ranks into local destination sequence slots
    with torch.cuda.stream(ag_stream):
        for seq_start, src_token_start, local_seq_len in segments:
            for offset in range(1, world_size):
                src_rank = (rank + offset) % world_size
                _enqueue_slice_copy(src_rank, seq_start, src_token_start, local_seq_len, ag_stream)

    barrier_all_on_stream(barrier, ag_stream, world_size)
    compute_stream.wait_stream(ag_stream)


def fused_sp_all2all_attn_intra_node(
    ctx: SPAll2AllAttentionContextIntraNode,
    packed_qkv_shards: list[torch.Tensor],
    output: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    rank: int,
    world_size: int,
    q_head: int,
    kv_head: int,
    is_causal: bool = True,
    enable_zig_zag: bool = True,
    print_source: bool = False,
    block_M: int = 128,
    block_N: int = 128,
    num_stages: int = 2,
    threads: int = 256,
):
    """Run the packed all-to-all copy followed by the packed attention kernel.

    The copy producer fills ``ctx.ag_packed_buffers[rank]``; the
    ``flashattn_packed`` kernel then reads that buffer and writes ``output``.
    The function returns nothing.

    Args:
        ctx: Context providing the packed buffers, the copy stream and the
            barrier tensor.
        packed_qkv_shards: Per-rank packed source tensors; the last dimension
            of ``packed_qkv_shards[rank]`` is used as the head dimension.
        output: Tensor passed to the kernel as its output argument.
        cu_seqlens_q: Cumulative query offsets, ``batch + 1`` entries.
        cu_seqlens_k: Cumulative key/value offsets.
        max_seqlen_q: Forwarded to the kernel.
        rank: Index of the local rank.
        world_size: Number of ranks.
        q_head: Query head count given to the kernel.
        kv_head: Key/value head count; ``q_head // kv_head`` is the group size.
        is_causal: Forwarded to ``flashattn_packed``.
        enable_zig_zag: Forwarded to ``flashattn_packed``.
        print_source: When true, rank 0 prints the generated kernel source.
        block_M: Kernel tile size, forwarded to ``flashattn_packed``.
        block_N: Kernel tile size, forwarded to ``flashattn_packed``.
        num_stages: Pipeline stage count, forwarded to ``flashattn_packed``.
        threads: Threads per block, forwarded to ``flashattn_packed``.
    """
    compute_stream = torch.cuda.current_stream()
    ag_packed = ctx.ag_packed_buffers[rank]

    # The copy stream must not start before work already queued on the
    # compute stream has finished.
    ctx.ag_stream.wait_stream(compute_stream)
    cp_engine_producer_packed_all2all(
        packed_qkv_shards,
        ag_packed,
        ctx.ag_packed_buffers,
        cu_seqlens_k,
        rank,
        world_size,
        ctx.ag_stream,
        compute_stream,
        ctx.barrier,
    )

    head_dim = packed_qkv_shards[rank].shape[-1]
    groups = q_head // kv_head
    batch = cu_seqlens_q.shape[0] - 1

    with torch.cuda.stream(compute_stream):
        kernel = flashattn_packed(
            batch,
            groups,
            ag_packed.shape[0],
            ag_packed.shape[0],
            q_head,
            head_dim,
            is_causal,
            enable_zig_zag,
            rank,
            world_size,
            block_M=block_M,
            block_N=block_N,
            num_stages=num_stages,
            threads=threads,
        )

        if rank == 0 and print_source:
            print(kernel.get_kernel_source())

        kernel(ag_packed, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, output)

    compute_stream.wait_stream(ctx.ag_stream)
    barrier_all_on_stream(ctx.barrier, compute_stream, world_size)