diff --git a/examples/distributed/example_gemm_rs_overlapped.py b/examples/distributed/example_gemm_rs_overlapped.py index 27519fdd51..4fb1c6d434 100644 --- a/examples/distributed/example_gemm_rs_overlapped.py +++ b/examples/distributed/example_gemm_rs_overlapped.py @@ -30,6 +30,15 @@ def gemm_kernel(M, M_per_rank = T.ceildiv(M, num_local_rank) GROUP_SIZE_M = 8 + def swizzle_2d(tile_id, num_pid_m, num_pid_n): + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = T.min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + return pid_m, pid_n + @T.prim_func def main( A: T.Tensor((M, K // num_local_rank), dtype), @@ -47,19 +56,10 @@ def main( num_pid_m = T.ceildiv(M, block_M) num_pid_n = T.ceildiv(N, block_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = bid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = T.min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m_ = first_pid_m + ((bid % num_pid_in_group) % group_size_m) - pid_n_ = (bid % num_pid_in_group) // group_size_m - - # threadblock swizzle - # no stream-k support. only split by m x n - m_offset = M_per_rank * local_rank - pid_m_offset = T.ceildiv(m_offset, block_M) + + pid_m_, pid_n = swizzle_2d(bid, num_pid_m, num_pid_n) + pid_m_offset = (local_rank + 1) * M_per_rank // block_M pid_m = (pid_m_ + pid_m_offset) % num_pid_m - pid_n = pid_n_ tid = T.get_thread_binding(0) T.clear(C_local) diff --git a/examples/distributed/example_sp_ag_attention_intra_node.py b/examples/distributed/example_sp_ag_attention_intra_node.py new file mode 100644 index 0000000000..c4d120fea4 --- /dev/null +++ b/examples/distributed/example_sp_ag_attention_intra_node.py @@ -0,0 +1,449 @@ +from __future__ import annotations + +import tilelang +import argparse +from itertools import accumulate +import torch +import torch.distributed as dist +import torch.multiprocessing +from tilelang.distributed import init_dist +from tilelang.distributed import perf_fn +from sp_ag_attention_intra_node import ( + create_sp_ag_attention_context_intra_node, + fused_sp_ag_attn_intra_node, +) + +tilelang.disable_cache() + + +class FusedSequenceParallelAttn(torch.nn.Module): + + def __init__( + self, + pg: torch.distributed.ProcessGroup, + batch_size: int, + q_head: int, + kv_head: int, + max_seqlen_q: int, + max_seqlen_k: int, + head_dim: int, + input_dtype=torch.float16, + output_dtype=torch.float16, + device="cuda", + is_causal=True, + enable_zig_zag=True, + enable_specialized=False, + allocator=None, + ): + super(FusedSequenceParallelAttn, self).__init__() + self.pg = pg + self.rank = pg.rank() + self.world_size = pg.size() + + self.batch_size = batch_size + self.q_head = q_head + self.kv_head = kv_head + self.max_seqlen_q = max_seqlen_q + self.max_seqlen_k = max_seqlen_k + self.head_dim = head_dim + + assert (max_seqlen_q % self.world_size == 0 and max_seqlen_q % self.world_size + == 0), f"sequence length should be multiple of world_size({self.world_size})" + self.max_q_shard_len = self.max_seqlen_q // self.world_size + + self.input_dtype = input_dtype + self.output_dtype = output_dtype + assert self.input_dtype == self.output_dtype + self.device = device + self.is_causal = is_causal + self.enable_zig_zag = enable_zig_zag + self.enable_specialized = enable_specialized + self.allocator = allocator + + self.ctx = create_sp_ag_attention_context_intra_node( + self.batch_size, + self.q_head, + self.kv_head, + self.max_seqlen_k, + self.max_q_shard_len, + self.head_dim, + self.input_dtype, + self.output_dtype, + self.rank, + self.world_size, + self.device, + self.allocator, + ) + + def forward(self, q_shard, k_shards, v_shards, cu_seqlens_q, cu_seqlens_k, print_source=False): + total_q_shard = cu_seqlens_q[-1] + output_buffer = self.ctx.attn_output_buffer[:total_q_shard] + + fused_sp_ag_attn_intra_node( + self.ctx, + q_shard, + k_shards, + v_shards, + output_buffer, + cu_seqlens_q, + cu_seqlens_k, + self.max_q_shard_len, + self.max_seqlen_k, + self.rank, + self.world_size, + self.is_causal, + self.enable_zig_zag, + self.enable_specialized, + print_source, + ) + + return output_buffer + + +class TorchSequenceParallelAttn(torch.nn.Module): + + def __init__( + self, + pg: torch.distributed.ProcessGroup, + batch_size: int, + q_head: int, + kv_head: int, + max_seqlen_q: int, + max_seqlen_k: int, + head_dim: int, + input_dtype=torch.float16, + output_dtype=torch.float16, + device="cuda", + is_causal=True, + enable_zig_zag=True, + ): + super(TorchSequenceParallelAttn, self).__init__() + self.pg = pg + self.rank = pg.rank() + self.world_size = pg.size() + + self.batch_size = batch_size + self.q_head = q_head + self.kv_head = kv_head + self.max_seqlen_q = max_seqlen_q + self.max_seqlen_k = max_seqlen_k + self.head_dim = head_dim + + self.input_dtype = input_dtype + self.output_dtype = output_dtype + self.device = device + self.is_causal = is_causal + self.enable_zig_zag = enable_zig_zag + assert self.input_dtype == self.output_dtype + + self.max_q_shard_len = max_seqlen_q // self.world_size + self.max_kv_shard_ken = max_seqlen_q // self.world_size + assert (max_seqlen_q % self.world_size == 0 and max_seqlen_q % self.world_size + == 0), f"sequence length should be multiple of world_size({self.world_size})" + + self.ag_k_buffer: torch.Tensor = torch.empty( + self.batch_size * self.max_seqlen_k, + self.kv_head, + self.head_dim, + dtype=self.input_dtype, + device=self.device, + ) + self.ag_v_buffer: torch.Tensor = torch.empty( + self.batch_size * self.max_seqlen_k, + self.kv_head, + self.head_dim, + dtype=self.input_dtype, + device=self.device, + ) + + def forward(self, q_shard, k_shard, v_shard, cu_seqlens_q, cu_seqlens_k): + # construct casual mask with offset + def _gen_mask(offset, q_shard_len, kv_len): + if self.is_causal: + mask = torch.zeros((q_shard_len, kv_len), dtype=torch.bool, device=self.device) + mask[:, :offset + q_shard_len] = True + if offset < kv_len: + mask[:, offset:offset + q_shard_len].tril_() + return mask + return None + + batch_size = cu_seqlens_q.shape[0] - 1 + total_q_shard, q_head, head_dim = q_shard.shape + total_kv_shard, kv_head, head_dim = k_shard.shape + out_list = [] + for i in range(batch_size): + cu_seqlens_q_start = cu_seqlens_q[i].item() + cu_seqlens_q_end = cu_seqlens_q[i + 1].item() + q_shard_len = cu_seqlens_q_end - cu_seqlens_q_start + q_len = q_shard_len * self.world_size + + cu_seqlens_k_start = cu_seqlens_k[i].item() // self.world_size + cu_seqlens_k_end = cu_seqlens_k[i + 1].item() // self.world_size + kv_shard_len = cu_seqlens_k_end - cu_seqlens_k_start + kv_len = kv_shard_len * self.world_size + + if self.enable_zig_zag: + half_q_shard_len = q_shard_len // 2 + half_kv_shard_len = kv_shard_len // 2 + + q0_shard = q_shard[cu_seqlens_q_start:cu_seqlens_q_start + + half_q_shard_len, :, :].clone() + q1_shard = q_shard[cu_seqlens_q_start + + half_q_shard_len:cu_seqlens_q_end, :, :].clone() + + q0_shard_permute = torch.permute( + q0_shard.reshape(1, half_q_shard_len, q_head, head_dim), + (0, 2, 1, 3)).contiguous() + q1_shard_permute = torch.permute( + q1_shard.reshape(1, half_q_shard_len, q_head, head_dim), + (0, 2, 1, 3)).contiguous() + + k0_shard = k_shard[cu_seqlens_k_start:cu_seqlens_k_start + + half_kv_shard_len, :, :].clone() + k1_shard = k_shard[cu_seqlens_k_start + + half_kv_shard_len:cu_seqlens_k_end, :, :].clone() + v0_shard = v_shard[cu_seqlens_k_start:cu_seqlens_k_start + + half_kv_shard_len, :, :].clone() + v1_shard = v_shard[cu_seqlens_k_start + + half_kv_shard_len:cu_seqlens_k_end, :, :].clone() + + buffer_size = (half_kv_shard_len * kv_head * head_dim * self.world_size) + + ag_k0 = self.ag_k_buffer.reshape(-1)[:buffer_size].reshape( + half_kv_shard_len * self.world_size, kv_head, head_dim) + ag_k1 = self.ag_k_buffer.reshape(-1)[buffer_size:2 * buffer_size].reshape( + half_kv_shard_len * self.world_size, kv_head, head_dim) + ag_v0 = self.ag_v_buffer.reshape(-1)[:buffer_size].reshape( + half_kv_shard_len * self.world_size, kv_head, head_dim) + ag_v1 = self.ag_v_buffer.reshape(-1)[buffer_size:2 * buffer_size].reshape( + half_kv_shard_len * self.world_size, kv_head, head_dim) + torch.distributed.all_gather_into_tensor( + ag_k0, + k0_shard, + group=self.pg, + ) + torch.distributed.all_gather_into_tensor( + ag_k1, + k1_shard, + group=self.pg, + ) + torch.distributed.all_gather_into_tensor( + ag_v0, + v0_shard, + group=self.pg, + ) + torch.distributed.all_gather_into_tensor( + ag_v1, + v1_shard, + group=self.pg, + ) + ag_k1 = ag_k1.reshape(self.world_size, half_kv_shard_len, kv_head, head_dim) + ag_k1 = torch.flip(ag_k1, [0]).reshape(self.world_size * half_kv_shard_len, kv_head, + head_dim) + ag_k = torch.cat((ag_k0, ag_k1), dim=0) + ag_k = torch.permute(ag_k.reshape(1, kv_len, kv_head, head_dim), + (0, 2, 1, 3)).contiguous() + ag_k = ag_k.repeat_interleave(q_head // kv_head, -3) + + ag_v1 = ag_v1.reshape(self.world_size, half_kv_shard_len, kv_head, head_dim) + ag_v1 = torch.flip(ag_v1, [0]).reshape(self.world_size * half_kv_shard_len, kv_head, + head_dim) + ag_v = torch.cat((ag_v0, ag_v1), dim=0) + ag_v = torch.permute(ag_v.reshape(1, kv_len, kv_head, head_dim), + (0, 2, 1, 3)).contiguous() + ag_v = ag_v.repeat_interleave(q_head // kv_head, -3) + + offset_q0 = half_q_shard_len * self.rank + offset_q1 = q_len - half_q_shard_len * (self.rank + 1) + prefix = kv_len - q_len + mask0 = _gen_mask(prefix + offset_q0, half_q_shard_len, kv_len) + mask1 = _gen_mask(prefix + offset_q1, half_q_shard_len, kv_len) + out0 = torch.nn.functional.scaled_dot_product_attention( + q0_shard_permute, ag_k, ag_v, attn_mask=mask0) + out1 = torch.nn.functional.scaled_dot_product_attention( + q1_shard_permute, ag_k, ag_v, attn_mask=mask1) + out = torch.cat((out0, out1), dim=2) # [1, q_head, q_shard_len, head_dim] + else: + cu_q_shard = q_shard[cu_seqlens_q_start:cu_seqlens_q_end, :, :].clone() + cu_q_shard_permute = torch.permute( + cu_q_shard.reshape(1, q_shard_len, q_head, head_dim), + (0, 2, 1, 3)).contiguous() + + total_size = kv_len * kv_head * head_dim + ag_k = self.ag_k_buffer.reshape(-1)[:total_size].reshape(kv_len, kv_head, head_dim) + cu_k_shard = k_shard[cu_seqlens_k_start:cu_seqlens_k_end, :, :].clone() + torch.distributed.all_gather_into_tensor( + ag_k, + cu_k_shard, + group=self.pg, + ) + ag_v = self.ag_v_buffer.reshape(-1)[:total_size].reshape(kv_len, kv_head, head_dim) + cu_v_shard = v_shard[cu_seqlens_k_start:cu_seqlens_k_end, :, :].clone() + torch.distributed.all_gather_into_tensor( + ag_v, + cu_v_shard, + group=self.pg, + ) + ag_k = torch.permute(ag_k.reshape(1, kv_len, kv_head, head_dim), + (0, 2, 1, 3)).contiguous() + ag_k = ag_k.repeat_interleave(q_head // kv_head, -3) + ag_v = torch.permute(ag_v.reshape(1, kv_len, kv_head, head_dim), + (0, 2, 1, 3)).contiguous() + ag_v = ag_v.repeat_interleave(q_head // kv_head, -3) + + offset = self.rank * q_shard_len + prefix = kv_len - q_len + mask = _gen_mask(prefix + offset, q_shard_len, kv_len) + out = torch.nn.functional.scaled_dot_product_attention( + cu_q_shard_permute, ag_k, ag_v, + attn_mask=mask) # [1, q_head, q_shard_len, head_dim] + + out = torch.permute(out.reshape(q_head, q_shard_len, head_dim), (1, 0, 2)).contiguous() + out_list.append(out) + + output = torch.cat(out_list) + + return output + + +def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): + dtype = torch.float16 + device = "cuda" + + batch_size = args.batch_size + q_head = args.q_head + kv_head = args.kv_head + max_seqlen_q = args.max_seqlen_q + max_seqlen_k = args.max_seqlen_k + head_dim = args.head_dim + is_causal = args.is_causal + enable_zig_zag = args.zig_zag + enable_specialized = args.enable_specialized + seqlens_q = args.seqlens_q + cu_seqlens_q_list = [0] + list(accumulate(seqlens_q)) + seqlens_k = args.seqlens_k + cu_seqlens_k_list = [0] + list(accumulate(seqlens_k)) + + rank, num_ranks, group = init_dist(local_rank, num_local_ranks) + assert rank == local_rank and num_ranks == num_local_ranks, "only support single node for now" + allocator = tilelang.get_allocator( + size=2**30, + device=device, + is_distributed=True, + local_rank=local_rank, + num_local_ranks=num_local_ranks, + group=group) + + cu_seqlens_q = torch.tensor(cu_seqlens_q_list, dtype=torch.int32, device=device) + cu_seqlens_q = cu_seqlens_q // num_local_ranks + cu_seqlens_k = torch.tensor(cu_seqlens_k_list, dtype=torch.int32, device=device) + + q_shard = tilelang.tensor((cu_seqlens_q[-1], q_head, head_dim), + dtype=dtype, + allocator=allocator).normal_( + mean=0.0, std=0.5) + k_shards = tilelang.tensor((cu_seqlens_k[-1] // num_local_ranks, kv_head, head_dim), + dtype=dtype, + allocator=allocator, + return_peers=True) + v_shards = tilelang.tensor((cu_seqlens_k[-1] // num_local_ranks, kv_head, head_dim), + dtype=dtype, + allocator=allocator, + return_peers=True) + k_shards[local_rank].normal_(mean=0.0, std=0.5) + v_shards[local_rank].normal_(mean=0.0, std=0.5) + + dist.barrier() + + tilescale_module = FusedSequenceParallelAttn( + group, + batch_size, + q_head, + kv_head, + max_seqlen_q, + max_seqlen_k, + head_dim, + dtype, + dtype, + device, + is_causal, + enable_zig_zag, + enable_specialized, + allocator=allocator, + ) + torch_module = TorchSequenceParallelAttn( + group, + batch_size, + q_head, + kv_head, + max_seqlen_q, + max_seqlen_k, + head_dim, + dtype, + dtype, + device, + is_causal, + enable_zig_zag, + ) + + tilescale_out = tilescale_module( + q_shard, k_shards, v_shards, cu_seqlens_q, cu_seqlens_k, print_source=True) + print(f"tilescale_out: {tilescale_out.shape}") + + torch_out = torch_module(q_shard, k_shards[local_rank], v_shards[local_rank], cu_seqlens_q, + cu_seqlens_k) + print(f"torch_out: {torch_out.shape}") + + atol = 1e-2 + rtol = 1e-2 + if torch.allclose(torch_out, tilescale_out, atol=atol, rtol=rtol): + print(f"rank {local_rank} check passed.✅") + else: + print(f"rank {local_rank} check failed.❌") + print(f"torch_out: {torch_out}, tilelang_out: {tilescale_out}") + + _, tl_t = perf_fn( + lambda: tilescale_module(q_shard, k_shards, v_shards, cu_seqlens_q, cu_seqlens_k), + warmup=5, + rep=5) + + print(f"rank {local_rank} tilescale time: {tl_t:.2f} ms") + + dist.destroy_process_group() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + '--num-processes', type=int, default=1, help='Number of processes to spawn (default: 2)') + parser.add_argument("--batch_size", type=int, default=2, help="batch size") + parser.add_argument("--q_head", type=int, default=32, help="num q heads") + parser.add_argument("--kv_head", type=int, default=8, help="num kv heads") + parser.add_argument("--max_seqlen_q", type=int, default=8192, help="max sequence length of q") + parser.add_argument( + "--max_seqlen_k", type=int, default=12288, help="max sequence length of k/v") + parser.add_argument("--head_dim", type=int, default=128, help="head dim") + parser.add_argument( + "--seqlens_q", type=int, nargs='+', default=[4096, 8192], help="sequence lengths of q") + parser.add_argument( + "--seqlens_k", type=int, nargs='+', default=[6144, 12288], help="sequence lengths of k/v") + parser.add_argument('--is_causal', action='store_true', help='causal') + parser.add_argument( + "--zig-zag", + "--no-zig-zag", + action=argparse.BooleanOptionalAction, + default=True, + help="enable zig zag opt", + ) + parser.add_argument( + "--enable-specialized", + "--disable-specialized", + action=argparse.BooleanOptionalAction, + default=False, + help="enable specialized optimized version", + ) + + args = parser.parse_args() + num_processes = args.num_processes + + torch.multiprocessing.spawn(main, args=(num_processes, args), nprocs=num_processes) diff --git a/examples/distributed/sp_ag_attention_intra_node.py b/examples/distributed/sp_ag_attention_intra_node.py new file mode 100644 index 0000000000..42e5493e42 --- /dev/null +++ b/examples/distributed/sp_ag_attention_intra_node.py @@ -0,0 +1,783 @@ +import torch +import tilelang +import tilelang.language as T +from typing import List +from dataclasses import dataclass +from cuda import cudart +from tilelang.distributed.utils import CUDA_CHECK + +tilelang.disable_cache() + + +@tilelang.jit +def barrier_all_blocks_sys_kernel(num_local_rank,): + + @T.prim_func + def main(barrier: T.Tensor((num_local_rank), "int32"),): + with T.Kernel(1, threads=32): + T.barrier_all_blocks_sys(barrier) + + return main + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, + compile_flags=[ + "-O3", "-Wno-deprecated-declarations", "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", "--expt-relaxed-constexpr", "--expt-extended-lambda", + "--ptxas-options=-v,--register-usage-level=10", "-DNDEBUG" + ], +) +def flashattn(batch_size, + groups, + UQ, + UKV, + heads, + dim, + is_causal, + enable_zig_zag, + enable_specialized, + rank, + num_ranks, + block_M=64, + block_N=64, + num_stages=1, + threads=128): + scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [UQ, heads, dim] + kv_shape = [UKV, head_kv, dim] + o_shape = [UQ, heads, dim] + dtype = "float16" + accum_dtype = "float" + + @T.macro + def inner( + Q_unpad: T.Tensor(q_shape, dtype), + K_unpad: T.Tensor(kv_shape, dtype), + V_unpad: T.Tensor(kv_shape, dtype), + Output_unpad: T.Tensor(o_shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + V_shared: T.SharedBuffer([block_N, dim], dtype), + O_shared: T.SharedBuffer([block_M, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + q_start_idx: T.int32, + k_start_idx: T.int32, + v_start_idx: T.int32, + q_current_seqlen: T.int32, + k_current_seqlen: T.int32, + bx: T.int32, + head_idx: T.int32, + kv_head_idx: T.int32, + global_offset_q: T.int32, + kv_len_per_sp_block: T.int32, + ): + T.copy(Q_unpad[q_start_idx + bx * block_M:q_start_idx + (bx + 1) * block_M, head_idx, :], + Q_shared) + + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + prefix_len = k_current_seqlen - q_current_seqlen * num_ranks + loop_range = ( + T.ceildiv(prefix_len + global_offset_q + (bx + 1) * block_M, block_N) + if is_causal else T.ceildiv(k_current_seqlen, block_N)) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + sp_block_idx = (k * block_N) // kv_len_per_sp_block + wait_rank = ( + sp_block_idx if sp_block_idx < num_ranks else 2 * num_ranks - sp_block_idx - 1) + kv_load_offset = ((k * block_N) % kv_len_per_sp_block + + sp_block_idx // num_ranks * kv_len_per_sp_block + wait_rank * + (k_current_seqlen // num_ranks)) + T.copy( + K_unpad[k_start_idx + kv_load_offset:k_start_idx + kv_load_offset + block_N, + kv_head_idx, :], K_shared) + + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + (prefix_len + global_offset_q + bx * block_M + i < k * block_N + j) or + (bx * block_M + i >= q_current_seqlen or + k * block_N + j >= k_current_seqlen), -1e9, 0) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or + k * block_N + j >= k_current_seqlen), -1e9, 0) + + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + T.copy( + V_unpad[v_start_idx + kv_load_offset:v_start_idx + kv_load_offset + block_N, + kv_head_idx, :], V_shared) + + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + + for i, d in T.Parallel(block_M, dim): + if bx * block_M + i < q_current_seqlen: + Output_unpad[q_start_idx + bx * block_M + i, head_idx, d] = O_shared[i, d] + + @T.prim_func + def main( + Q_unpad: T.Tensor(q_shape, dtype), + K_unpad: T.Tensor(kv_shape, dtype), + V_unpad: T.Tensor(kv_shape, dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], "int32"), + cu_seqlens_k: T.Tensor([batch_size + 1], "int32"), + max_seqlen_q: T.int32, + Output_unpad: T.Tensor(o_shape, dtype), + ): + with T.Kernel( + T.ceildiv(max_seqlen_q, block_M), heads, batch_size, + threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + batch_idx = bz + head_idx = by + kv_head_idx = head_idx // groups + + q_start_idx = cu_seqlens_q[batch_idx] + k_start_idx = cu_seqlens_k[batch_idx] + v_start_idx = cu_seqlens_k[batch_idx] + q_end_idx = cu_seqlens_q[batch_idx + 1] + k_end_idx = cu_seqlens_k[batch_idx + 1] + + q_current_seqlen = q_end_idx - q_start_idx + k_current_seqlen = k_end_idx - k_start_idx + + global_offset_q = q_current_seqlen * rank + kv_len_per_sp_block = k_current_seqlen // num_ranks + + inner(Q_unpad, K_unpad, V_unpad, Output_unpad, Q_shared, K_shared, V_shared, O_shared, + acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, scores_scale, scores_sum, + logsum, q_start_idx, k_start_idx, v_start_idx, q_current_seqlen, k_current_seqlen, + bx, head_idx, kv_head_idx, global_offset_q, kv_len_per_sp_block) + + @T.prim_func + def main_zigzag( + Q_unpad: T.Tensor(q_shape, dtype), + K_unpad: T.Tensor(kv_shape, dtype), + V_unpad: T.Tensor(kv_shape, dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], "int32"), + cu_seqlens_k: T.Tensor([batch_size + 1], "int32"), + max_seqlen_q: T.int32, + Output_unpad: T.Tensor(o_shape, dtype), + ): + with T.Kernel( + T.ceildiv(max_seqlen_q, block_M), heads, batch_size, + threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + batch_idx = bz + head_idx = by + kv_head_idx = head_idx // groups + + q_start_idx = cu_seqlens_q[batch_idx] + k_start_idx = cu_seqlens_k[batch_idx] + v_start_idx = cu_seqlens_k[batch_idx] + q_end_idx = cu_seqlens_q[batch_idx + 1] + k_end_idx = cu_seqlens_k[batch_idx + 1] + + q_current_seqlen = q_end_idx - q_start_idx + k_current_seqlen = k_end_idx - k_start_idx + + half_q_shard_len = q_current_seqlen // 2 + global_offset_q = rank * half_q_shard_len if bx * block_M < half_q_shard_len else \ + q_current_seqlen * num_ranks - (rank + 2) * half_q_shard_len + kv_len_per_sp_block = k_current_seqlen // (2 * num_ranks) + + inner(Q_unpad, K_unpad, V_unpad, Output_unpad, Q_shared, K_shared, V_shared, O_shared, + acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, scores_scale, scores_sum, + logsum, q_start_idx, k_start_idx, v_start_idx, q_current_seqlen, k_current_seqlen, + bx, head_idx, kv_head_idx, global_offset_q, kv_len_per_sp_block) + + @T.prim_func + def main_specialized( + Q_unpad: T.Tensor(q_shape, dtype), + K_unpad: T.Tensor(kv_shape, dtype), + V_unpad: T.Tensor(kv_shape, dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], "int32"), + cu_seqlens_k: T.Tensor([batch_size + 1], "int32"), + max_seqlen_q: T.int32, + Output_unpad: T.Tensor(o_shape, dtype), + ): + with T.Kernel( + T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=384) as (bx_, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + bar_q_ready = T.alloc_barrier(arrive_count=128) + bar_k_ready = T.alloc_barrier(arrive_count=128) + bar_v_ready = T.alloc_barrier(arrive_count=128) + bar_k_release = T.alloc_barrier(arrive_count=256) + bar_v_release = T.alloc_barrier(arrive_count=256) + + T.annotate_layout({ + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), + }) + + batch_idx = bz + head_idx = by + kv_head_idx = head_idx // groups + + q_start_idx = cu_seqlens_q[batch_idx] + k_start_idx = cu_seqlens_k[batch_idx] + v_start_idx = cu_seqlens_k[batch_idx] + q_end_idx = cu_seqlens_q[batch_idx + 1] + k_end_idx = cu_seqlens_k[batch_idx + 1] + + q_current_seqlen = q_end_idx - q_start_idx + k_current_seqlen = k_end_idx - k_start_idx + + global_offset_q = q_current_seqlen * rank + tid = T.get_thread_binding(0) + + bx = T.ceildiv(max_seqlen_q, block_M) - bx_ - 1 + + if tid < 256: + T.set_max_nreg(240, 1) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + prefix_len = k_current_seqlen - q_current_seqlen * num_ranks + loop_range = ( + T.ceildiv(prefix_len + global_offset_q + (bx + 1) * block_M, block_N) + if is_causal else T.ceildiv(k_current_seqlen, block_N)) + + T.barrier_wait(bar_q_ready, 0) + for k in T.serial(loop_range): + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + (prefix_len + global_offset_q + bx * block_M + i < k * block_N + j) + or (bx * block_M + i >= q_current_seqlen or + k * block_N + j >= k_current_seqlen), -1e9, 0) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or + k * block_N + j >= k_current_seqlen), + -1e9, 0) + + T.barrier_wait(bar_k_ready, k % 2) + T.gemm( + Q_shared, + K_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow) + T.barrier_arrive(bar_k_release) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + T.barrier_wait(bar_v_ready, k % 2) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + T.barrier_arrive(bar_v_release) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + + for i, d in T.Parallel(block_M, dim): + if bx * block_M + i < q_current_seqlen: + Output_unpad[q_start_idx + bx * block_M + i, head_idx, d] = O_shared[i, d] + elif tid < 384: + T.set_max_nreg(24, 0) + prefix_len = k_current_seqlen - q_current_seqlen * num_ranks + loop_range = ( + T.ceildiv(prefix_len + global_offset_q + (bx + 1) * block_M, block_N) + if is_causal else T.ceildiv(k_current_seqlen, block_N)) + T.copy( + Q_unpad[q_start_idx + bx * block_M:q_start_idx + (bx + 1) * block_M, + head_idx, :], Q_shared) + T.barrier_arrive(bar_q_ready) + for k in T.serial(loop_range): + T.barrier_wait(bar_k_release, (k + 1) % 2) + T.copy( + K_unpad[k_start_idx + (k * block_N):k_start_idx + (k * block_N) + block_N, + kv_head_idx, :], K_shared) + T.barrier_arrive(bar_k_ready) + T.barrier_wait(bar_v_release, (k + 1) % 2) + T.copy( + V_unpad[v_start_idx + (k * block_N):v_start_idx + (k * block_N) + block_N, + kv_head_idx, :], V_shared) + T.barrier_arrive(bar_v_ready) + + @T.prim_func + def main_specialized_zigzag( + Q_unpad: T.Tensor(q_shape, dtype), + K_unpad: T.Tensor(kv_shape, dtype), + V_unpad: T.Tensor(kv_shape, dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], "int32"), + cu_seqlens_k: T.Tensor([batch_size + 1], "int32"), + max_seqlen_q: T.int32, + Output_unpad: T.Tensor(o_shape, dtype), + ): + with T.Kernel( + T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=384) as (bx_, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + kv_load_offset = T.alloc_var("int32") + + bar_q_ready = T.alloc_barrier(arrive_count=128) + bar_k_ready = T.alloc_barrier(arrive_count=128) + bar_v_ready = T.alloc_barrier(arrive_count=128) + bar_k_release = T.alloc_barrier(arrive_count=256) + bar_v_release = T.alloc_barrier(arrive_count=256) + + T.annotate_layout({ + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), + }) + + batch_idx = bz + head_idx = by + kv_head_idx = head_idx // groups + + q_start_idx = cu_seqlens_q[batch_idx] + k_start_idx = cu_seqlens_k[batch_idx] + v_start_idx = cu_seqlens_k[batch_idx] + q_end_idx = cu_seqlens_q[batch_idx + 1] + k_end_idx = cu_seqlens_k[batch_idx + 1] + + q_current_seqlen = q_end_idx - q_start_idx + k_current_seqlen = k_end_idx - k_start_idx + + bx = T.ceildiv(max_seqlen_q, block_M) - bx_ - 1 + + half_q_shard_len = q_current_seqlen // 2 + global_offset_q = rank * half_q_shard_len if bx * block_M < half_q_shard_len else \ + q_current_seqlen * num_ranks - (rank + 2) * half_q_shard_len + kv_len_per_sp_block = k_current_seqlen // (2 * num_ranks) + tid = T.get_thread_binding(0) + + if tid < 256: + T.set_max_nreg(240, 1) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + prefix_len = k_current_seqlen - q_current_seqlen * num_ranks + loop_range = ( + T.ceildiv(prefix_len + global_offset_q + (bx + 1) * block_M, block_N) + if is_causal else T.ceildiv(k_current_seqlen, block_N)) + + T.barrier_wait(bar_q_ready, 0) + for k in T.serial(loop_range): + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + (prefix_len + global_offset_q + bx * block_M + i < k * block_N + j) + or (bx * block_M + i >= q_current_seqlen or + k * block_N + j >= k_current_seqlen), -1e9, 0) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or + k * block_N + j >= k_current_seqlen), + -1e9, 0) + + T.barrier_wait(bar_k_ready, k % 2) + T.gemm( + Q_shared, + K_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow) + T.barrier_arrive(bar_k_release) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + T.barrier_wait(bar_v_ready, k % 2) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + T.barrier_arrive(bar_v_release) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + + for i, d in T.Parallel(block_M, dim): + if bx * block_M + i < q_current_seqlen: + Output_unpad[q_start_idx + bx * block_M + i, head_idx, d] = O_shared[i, d] + elif tid < 384: + T.set_max_nreg(24, 0) + prefix_len = k_current_seqlen - q_current_seqlen * num_ranks + loop_range = ( + T.ceildiv(prefix_len + global_offset_q + (bx + 1) * block_M, block_N) + if is_causal else T.ceildiv(k_current_seqlen, block_N)) + T.copy( + Q_unpad[q_start_idx + bx * block_M:q_start_idx + (bx + 1) * block_M, + head_idx, :], Q_shared) + T.barrier_arrive(bar_q_ready) + for k in T.serial(loop_range): + sp_block_idx = (k * block_N) // kv_len_per_sp_block + wait_rank = ( + sp_block_idx if sp_block_idx < num_ranks else 2 * num_ranks - sp_block_idx - + 1) + kv_load_offset = ((k * block_N) % kv_len_per_sp_block + + sp_block_idx // num_ranks * kv_len_per_sp_block + wait_rank * + (k_current_seqlen // num_ranks)) + T.barrier_wait(bar_k_release, (k + 1) % 2) + T.copy( + K_unpad[k_start_idx + kv_load_offset:k_start_idx + kv_load_offset + block_N, + kv_head_idx, :], K_shared) + T.barrier_arrive(bar_k_ready) + T.barrier_wait(bar_v_release, (k + 1) % 2) + T.copy( + V_unpad[v_start_idx + kv_load_offset:v_start_idx + kv_load_offset + block_N, + kv_head_idx, :], V_shared) + T.barrier_arrive(bar_v_ready) + + if enable_specialized: + return main_specialized if not enable_zig_zag else main_specialized_zigzag + return main if not enable_zig_zag else main_zigzag + + +@dataclass +class SPAllGatherAttentionContextIntraNode: + ag_k_buffers: List[torch.Tensor] + ag_k_buffer: torch.Tensor + # ag_k_buffers_ptr: torch.Tensor + ag_v_buffers: List[torch.Tensor] + ag_v_buffer: torch.Tensor + # ag_v_buffers_ptr: torch.Tensor + attn_output_buffer: torch.Tensor + ag_stream: torch.cuda.Stream + barrier: torch.Tensor + + +def create_sp_ag_attention_context_intra_node( + batch_size, + q_head, + kv_head, + max_seqlen_k, + max_q_shard_len, + head_dim, + input_dtype, + output_dtype, + rank, + world_size, + device, + allocator, +): + ag_k_buffers = tilelang.tensor((batch_size * max_seqlen_k, kv_head, head_dim), + dtype=input_dtype, + allocator=allocator, + return_peers=True) + ag_k_buffer = ag_k_buffers[rank] + + ag_v_buffers = tilelang.tensor((batch_size * max_seqlen_k, kv_head, head_dim), + dtype=input_dtype, + allocator=allocator, + return_peers=True) + ag_v_buffer = ag_v_buffers[rank] + + attn_output_buffer = torch.empty( + batch_size * max_q_shard_len, + q_head, + head_dim, + dtype=output_dtype, + device=device, + ) + + barrier = tilelang.tensor((world_size), dtype=torch.int32, allocator=allocator) + + # stream for copy + ag_stream = torch.cuda.Stream() + + ctx = SPAllGatherAttentionContextIntraNode( + ag_k_buffers=ag_k_buffers, + ag_k_buffer=ag_k_buffer, + ag_v_buffers=ag_v_buffers, + ag_v_buffer=ag_v_buffer, + attn_output_buffer=attn_output_buffer, + ag_stream=ag_stream, + barrier=barrier) + + return ctx + + +def barrier_all_on_stream(barrier: torch.Tensor, stream: torch.cuda.Stream, world_size: int): + barrier_all_blocks_sys_func = barrier_all_blocks_sys_kernel(world_size) + barrier_all_blocks_sys_func(barrier, stream=stream.cuda_stream) + + +def cp_engine_producer_kv_all_gather( + k_shards: list[torch.Tensor], # [total_kv_shard, kv_head, head_dim] + v_shards: list[torch.Tensor], # [total_kv_shard, kv_head, head_dim] + k_buffer: torch.Tensor, # [total_kv, kv_head, head_dim] + v_buffer: torch.Tensor, # [total_kv, kv_head, head_dim] + k_buffers: list[torch.Tensor], + v_buffers: list[torch.Tensor], + cu_seqlens_k: torch.Tensor, # kv_full_lens + rank: int, + world_size: int, + ag_stream: torch.cuda.Stream, + compute_stream: torch.cuda.Stream, + barrier: torch.Tensor, +): + assert k_buffer.is_contiguous() + assert v_buffer.is_contiguous() + assert k_shards[rank].is_contiguous() + assert v_shards[rank].is_contiguous() + + total_kv_shard, kv_head, head_dim = k_shards[rank].shape + batch_size = cu_seqlens_k.shape[0] - 1 + byte_per_token = kv_head * head_dim * k_shards[rank].dtype.itemsize + + def _cp_engine_copy_data(dst_ptr, src_ptr, cp_size, stream): + (err,) = cudart.cudaMemcpyAsync( + dst_ptr, + src_ptr, + cp_size, + cudart.cudaMemcpyKind.cudaMemcpyDefault, + stream.cuda_stream, + ) + + CUDA_CHECK(err) + + # local copy in compute stream + with torch.cuda.stream(compute_stream): + for i in range(batch_size): + cu_seqlens_k_start = cu_seqlens_k[i].item() + cu_seqlens_k_end = cu_seqlens_k[i + 1].item() + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + k_shard_len = seqlen_k // world_size + byte_start = cu_seqlens_k_start * byte_per_token + byte_per_rank = k_shard_len * byte_per_token + cp_size = byte_per_rank + + k_dst_ptr = k_buffers[rank].data_ptr() + byte_start + rank * byte_per_rank + k_src_ptr = k_shards[rank].data_ptr() + byte_start // world_size + _cp_engine_copy_data(k_dst_ptr, k_src_ptr, cp_size, compute_stream) + + v_dst_ptr = v_buffers[rank].data_ptr() + byte_start + rank * byte_per_rank + v_src_ptr = v_shards[rank].data_ptr() + byte_start // world_size + _cp_engine_copy_data(v_dst_ptr, v_src_ptr, cp_size, compute_stream) + + # barrier_all_on_stream(barrier, compute_stream, world_size) + # ag_stream.wait_stream(compute_stream) + + with torch.cuda.stream(ag_stream): + for i in range(batch_size): + cu_seqlens_k_start = cu_seqlens_k[i].item() + cu_seqlens_k_end = cu_seqlens_k[i + 1].item() + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + k_shard_len = seqlen_k // world_size + byte_start = cu_seqlens_k_start * byte_per_token + byte_per_rank = k_shard_len * byte_per_token + cp_size = byte_per_rank + for offset in range(1, world_size): + src_rank = (rank + offset) % world_size + + k_src_ptr = (k_shards[src_rank].data_ptr() + byte_start // world_size) + k_dst_ptr = (k_buffers[rank].data_ptr() + byte_start + src_rank * byte_per_rank) + _cp_engine_copy_data(k_dst_ptr, k_src_ptr, cp_size, ag_stream) + + v_src_ptr = (v_shards[src_rank].data_ptr() + byte_start // world_size) + v_dst_ptr = (v_buffers[rank].data_ptr() + byte_start + src_rank * byte_per_rank) + _cp_engine_copy_data(v_dst_ptr, v_src_ptr, cp_size, ag_stream) + + barrier_all_on_stream(barrier, ag_stream, world_size) + compute_stream.wait_stream(ag_stream) + + +def fused_sp_ag_attn_intra_node( + ctx: SPAllGatherAttentionContextIntraNode, + q_shard: torch.Tensor, # [total_q_shard, q_head, head_dim] + k_shards: list[torch.Tensor], # [total_kv_shard, kv_head, head_dim] + v_shards: list[torch.Tensor], # [total_kv_shard, kv_head, head_dim] + output: torch.Tensor, # [total_q_shard, q_head, head_dim] + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + rank: int, + world_size: int, + is_causal: bool = True, + enable_zig_zag: bool = True, + enable_specialized: bool = False, + print_source: bool = False, +): + + BLOCK_M = 128 + BLOCK_N = 128 + num_stages = 2 + threads = 256 + q_tokens = q_shard.shape[0] + assert ctx.ag_k_buffers[rank].shape[0] == ctx.ag_v_buffers[rank].shape[0] + kv_tokens = ctx.ag_k_buffers[rank].shape[0] + q_head = q_shard.shape[1] + kv_head = k_shards[rank].shape[1] + batch = cu_seqlens_q.shape[0] - 1 + + compute_stream = torch.cuda.current_stream() + ag_k = ctx.ag_k_buffers[rank] + ag_v = ctx.ag_v_buffers[rank] + + ctx.ag_stream.wait_stream(compute_stream) + # kv all gather + cp_engine_producer_kv_all_gather( + k_shards, + v_shards, + ag_k, + ag_v, + ctx.ag_k_buffers, + ctx.ag_v_buffers, + cu_seqlens_k, + rank, + world_size, + ctx.ag_stream, + compute_stream, + ctx.barrier, + ) + + HEAD_DIM_Q, HEAD_DIM_K = q_shard.shape[-1], k_shards[rank].shape[-1] + HEAD_DIM_V = v_shards[rank].shape[-1] + assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V + assert HEAD_DIM_K in {16, 32, 64, 128, 256} + + with torch.cuda.stream(compute_stream): + kernel = flashattn( + batch, + q_head // kv_head, + q_tokens, + kv_tokens, + q_head, + HEAD_DIM_Q, + is_causal, + enable_zig_zag, + enable_specialized, + rank, + world_size, + block_M=BLOCK_M, + block_N=BLOCK_N, + num_stages=num_stages, + threads=threads) + + if rank == 0 and print_source: + print(kernel.get_kernel_source()) + + kernel( + q_shard, + ag_k, + ag_v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + output, + stream=compute_stream.cuda_stream) + + compute_stream.wait_stream(ctx.ag_stream) + barrier_all_on_stream(ctx.barrier, compute_stream, world_size)