[Example] Add CP example#37
Conversation
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
WalkthroughThis PR adds intra-node distributed sequence-parallel attention capabilities by introducing a fused flash-attention kernel module with barrier synchronization and all-gather coordination, along with example implementations comparing fused and reference torch-based attention. It also refactors an existing GEMM example to use 2D swizzle mapping for improved threadblock scheduling. Changes
Sequence Diagram(s)sequenceDiagram
participant Rank0 as Rank 0
participant Rank1 as Rank 1
participant Kernel as Fused Kernel
participant Sync as Barrier/Stream
Rank0->>Rank0: Prepare Q shard
Rank1->>Rank1: Prepare Q shard
Rank0->>Kernel: cp_engine_producer_kv_all_gather (K/V shard)
Rank1->>Kernel: cp_engine_producer_kv_all_gather (K/V shard)
Rank0->>Sync: Barrier (await all K/V data)
Rank1->>Sync: Barrier (await all K/V data)
Rank0->>Kernel: Launch flashattn (Q shard + all K/V)
Rank1->>Kernel: Launch flashattn (Q shard + all K/V)
Kernel->>Kernel: Compute fused attention
Kernel->>Rank0: Output buffer
Kernel->>Rank1: Output buffer
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes
Poem
Pre-merge checks and finishing touches❌ Failed checks (1 inconclusive)
✅ Passed checks (1 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 3
🧹 Nitpick comments (1)
examples/distributed/example_gemm_rs_overlapped.py (1)
33-40: Consider adding documentation for the swizzling pattern.The 2D swizzling logic is correct and handles boundary conditions properly (when the last group has fewer rows than GROUP_SIZE_M). However, the pattern is subtle and would benefit from a brief docstring or comment explaining the grouping strategy and how tiles map to (pid_m, pid_n) coordinates.
Example addition:
def swizzle_2d(tile_id, num_pid_m, num_pid_n): + """ + Maps 1D tile_id to 2D (pid_m, pid_n) coordinates using grouped swizzling. + Tiles are grouped by GROUP_SIZE_M rows to improve cache locality. + """ num_pid_in_group = GROUP_SIZE_M * num_pid_n
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
examples/distributed/example_gemm_rs_overlapped.py(2 hunks)examples/distributed/example_sp_ag_attention_intra_node.py(1 hunks)examples/distributed/sp_ag_attention_intra_node.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
examples/distributed/example_sp_ag_attention_intra_node.py (5)
tilelang/distributed/utils.py (2)
init_dist(40-62)perf_fn(231-243)examples/distributed/sp_ag_attention_intra_node.py (2)
create_sp_ag_attention_context_intra_node(560-608)fused_sp_ag_attn_intra_node(696-783)tilelang/env.py (1)
disable_cache(286-287)tilelang/utils/allocator.py (1)
get_allocator(236-248)tilelang/utils/tensor.py (1)
tensor(43-56)
examples/distributed/sp_ag_attention_intra_node.py (7)
tilelang/distributed/utils.py (1)
CUDA_CHECK(246-254)tilelang/env.py (1)
disable_cache(286-287)tilelang/jit/__init__.py (1)
jit(233-306)tilelang/language/builtin.py (1)
barrier_all_blocks_sys(625-632)tilelang/language/allocate.py (1)
alloc_barrier(120-129)tilelang/layout/swizzle.py (1)
make_swizzled_layout(10-18)tilelang/utils/tensor.py (1)
tensor(43-56)
🪛 Ruff (0.14.3)
examples/distributed/example_sp_ag_attention_intra_node.py
171-171: Unpacked variable total_q_shard is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
172-172: Unpacked variable total_kv_shard is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
323-323: Consider [0, *list(accumulate(seqlens_q))] instead of concatenation
Replace with [0, *list(accumulate(seqlens_q))]
(RUF005)
325-325: Consider [0, *list(accumulate(seqlens_k))] instead of concatenation
Replace with [0, *list(accumulate(seqlens_k))]
(RUF005)
examples/distributed/sp_ag_attention_intra_node.py
635-635: Unpacked variable total_kv_shard is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
705-705: Unused function argument: max_seqlen_k
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: build-test-nvidia
🔇 Additional comments (1)
examples/distributed/example_gemm_rs_overlapped.py (1)
60-62: Inconsistent pid_m_offset calculation compared to similar code in repository.The formula at line 61 differs significantly from the pattern used in
example_allgather_gemm_overlapped.py(lines 88-89, 130-131). The reviewed code uses(local_rank + 1) * M_per_rank // block_Mwith floor division, while the related examples useT.ceildiv(local_rank * M_per_rank, block_M)with ceiling division and no+1offset.For rank 0: reviewed produces
1*M_per_rank//block_Mwhile the pattern in related code produces0. Confirm this difference is intentional (i.e., RS overlapped kernels require different offset strategy than allgather variants) or if this indicates a bug in either implementation.
| assert (max_seqlen_q % self.world_size == 0 and max_seqlen_q % self.world_size | ||
| == 0), f"sequence length should be multiple of world_size({self.world_size})" | ||
| self.max_q_shard_len = self.max_seqlen_q // self.world_size |
There was a problem hiding this comment.
Guard both Q and KV maxima
This assertion currently checks max_seqlen_q twice. If max_seqlen_k is not divisible by world_size we keep going, but every downstream divide-by-world-size (e.g. in cp_engine_producer_kv_all_gather) will truncate and corrupt the gathered KV layout. Please assert on max_seqlen_k as well so we fail fast with a clear error.
- assert (max_seqlen_q % self.world_size == 0 and max_seqlen_q % self.world_size
- == 0), f"sequence length should be multiple of world_size({self.world_size})"
+ assert (max_seqlen_q % self.world_size == 0 and max_seqlen_k % self.world_size == 0), (
+ f"max_seqlen_q and max_seqlen_k must be multiples of world_size ({self.world_size})"
+ )📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| assert (max_seqlen_q % self.world_size == 0 and max_seqlen_q % self.world_size | |
| == 0), f"sequence length should be multiple of world_size({self.world_size})" | |
| self.max_q_shard_len = self.max_seqlen_q // self.world_size | |
| assert (max_seqlen_q % self.world_size == 0 and max_seqlen_k % self.world_size == 0), ( | |
| f"max_seqlen_q and max_seqlen_k must be multiples of world_size ({self.world_size})" | |
| ) | |
| self.max_q_shard_len = self.max_seqlen_q // self.world_size |
🤖 Prompt for AI Agents
In examples/distributed/example_sp_ag_attention_intra_node.py around lines
50-52, the assertion mistakenly checks max_seqlen_q twice; add a guard that also
asserts max_seqlen_k % self.world_size == 0 so both Q and KV maxima are
divisible by world_size, and after the assertions compute and store the
per-shard KV length (e.g., set self.max_k_shard_len = self.max_seqlen_k //
self.world_size) to mirror the Q shard calculation so downstream code doesn’t
truncate or corrupt KV layout.
| self.max_q_shard_len = max_seqlen_q // self.world_size | ||
| self.max_kv_shard_ken = max_seqlen_q // self.world_size | ||
| assert (max_seqlen_q % self.world_size == 0 and max_seqlen_q % self.world_size | ||
| == 0), f"sequence length should be multiple of world_size({self.world_size})" |
There was a problem hiding this comment.
Add the missing KV divisibility check
Same copy/paste slip here: without validating max_seqlen_k, the integer divisions that follow silently drop remainder tokens and misalign the gathered buffers. Please assert on max_seqlen_k alongside max_seqlen_q.
- assert (max_seqlen_q % self.world_size == 0 and max_seqlen_q % self.world_size
- == 0), f"sequence length should be multiple of world_size({self.world_size})"
+ assert (max_seqlen_q % self.world_size == 0 and max_seqlen_k % self.world_size == 0), (
+ f"max_seqlen_q and max_seqlen_k must be multiples of world_size ({self.world_size})"
+ )🤖 Prompt for AI Agents
In examples/distributed/example_sp_ag_attention_intra_node.py around lines 139
to 142, the code only validates and divides max_seqlen_q but not max_seqlen_k
(and also has a typo max_kv_shard_ken); add the missing check and correct the
variable names: assert that both max_seqlen_q and max_seqlen_k are divisible by
world_size, compute self.max_q_shard_len = max_seqlen_q // self.world_size and
self.max_kv_shard_len = max_seqlen_k // self.world_size, and update the assert
message/context to mention both sequence lengths.
| ag_k0 = self.ag_k_buffer.reshape(-1)[:buffer_size].reshape( | ||
| half_kv_shard_len * self.world_size, kv_head, head_dim) | ||
| ag_k1 = self.ag_k_buffer.reshape(-1)[buffer_size:2 * buffer_size].reshape( | ||
| half_kv_shard_len * self.world_size, kv_head, head_dim) | ||
| ag_v0 = self.ag_v_buffer.reshape(-1)[:buffer_size].reshape( | ||
| half_kv_shard_len * self.world_size, kv_head, head_dim) | ||
| ag_v1 = self.ag_v_buffer.reshape(-1)[buffer_size:2 * buffer_size].reshape( | ||
| half_kv_shard_len * self.world_size, kv_head, head_dim) | ||
| torch.distributed.all_gather_into_tensor( | ||
| ag_k0, | ||
| k0_shard, | ||
| group=self.pg, | ||
| ) | ||
| torch.distributed.all_gather_into_tensor( | ||
| ag_k1, | ||
| k1_shard, | ||
| group=self.pg, | ||
| ) | ||
| torch.distributed.all_gather_into_tensor( | ||
| ag_v0, | ||
| v0_shard, | ||
| group=self.pg, | ||
| ) | ||
| torch.distributed.all_gather_into_tensor( | ||
| ag_v1, | ||
| v1_shard, | ||
| group=self.pg, | ||
| ) | ||
| ag_k1 = ag_k1.reshape(self.world_size, half_kv_shard_len, kv_head, head_dim) | ||
| ag_k1 = torch.flip(ag_k1, [0]).reshape(self.world_size * half_kv_shard_len, kv_head, | ||
| head_dim) | ||
| ag_k = torch.cat((ag_k0, ag_k1), dim=0) | ||
| ag_k = torch.permute(ag_k.reshape(1, kv_len, kv_head, head_dim), | ||
| (0, 2, 1, 3)).contiguous() | ||
| ag_k = ag_k.repeat_interleave(q_head // kv_head, -3) | ||
|
|
||
| ag_v1 = ag_v1.reshape(self.world_size, half_kv_shard_len, kv_head, head_dim) | ||
| ag_v1 = torch.flip(ag_v1, [0]).reshape(self.world_size * half_kv_shard_len, kv_head, | ||
| head_dim) | ||
| ag_v = torch.cat((ag_v0, ag_v1), dim=0) | ||
| ag_v = torch.permute(ag_v.reshape(1, kv_len, kv_head, head_dim), | ||
| (0, 2, 1, 3)).contiguous() | ||
| ag_v = ag_v.repeat_interleave(q_head // kv_head, -3) | ||
|
|
||
| offset_q0 = half_q_shard_len * self.rank | ||
| offset_q1 = q_len - half_q_shard_len * (self.rank + 1) | ||
| prefix = kv_len - q_len | ||
| mask0 = _gen_mask(prefix + offset_q0, half_q_shard_len, kv_len) | ||
| mask1 = _gen_mask(prefix + offset_q1, half_q_shard_len, kv_len) | ||
| out0 = torch.nn.functional.scaled_dot_product_attention( | ||
| q0_shard_permute, ag_k, ag_v, attn_mask=mask0) | ||
| out1 = torch.nn.functional.scaled_dot_product_attention( | ||
| q1_shard_permute, ag_k, ag_v, attn_mask=mask1) | ||
| out = torch.cat((out0, out1), dim=2) # [1, q_head, q_shard_len, head_dim] | ||
| else: | ||
| cu_q_shard = q_shard[cu_seqlens_q_start:cu_seqlens_q_end, :, :].clone() | ||
| cu_q_shard_permute = torch.permute( | ||
| cu_q_shard.reshape(1, q_shard_len, q_head, head_dim), | ||
| (0, 2, 1, 3)).contiguous() | ||
|
|
||
| total_size = kv_len * kv_head * head_dim | ||
| ag_k = self.ag_k_buffer.reshape(-1)[:total_size].reshape(kv_len, kv_head, head_dim) | ||
| cu_k_shard = k_shard[cu_seqlens_k_start:cu_seqlens_k_end, :, :].clone() | ||
| torch.distributed.all_gather_into_tensor( | ||
| ag_k, | ||
| cu_k_shard, | ||
| group=self.pg, | ||
| ) | ||
| ag_v = self.ag_v_buffer.reshape(-1)[:total_size].reshape(kv_len, kv_head, head_dim) | ||
| cu_v_shard = v_shard[cu_seqlens_k_start:cu_seqlens_k_end, :, :].clone() | ||
| torch.distributed.all_gather_into_tensor( | ||
| ag_v, | ||
| cu_v_shard, | ||
| group=self.pg, | ||
| ) | ||
| ag_k = torch.permute(ag_k.reshape(1, kv_len, kv_head, head_dim), | ||
| (0, 2, 1, 3)).contiguous() | ||
| ag_k = ag_k.repeat_interleave(q_head // kv_head, -3) | ||
| ag_v = torch.permute(ag_v.reshape(1, kv_len, kv_head, head_dim), | ||
| (0, 2, 1, 3)).contiguous() | ||
| ag_v = ag_v.repeat_interleave(q_head // kv_head, -3) | ||
|
|
||
| offset = self.rank * q_shard_len | ||
| prefix = kv_len - q_len | ||
| mask = _gen_mask(prefix + offset, q_shard_len, kv_len) | ||
| out = torch.nn.functional.scaled_dot_product_attention( | ||
| cu_q_shard_permute, ag_k, ag_v, | ||
| attn_mask=mask) # [1, q_head, q_shard_len, head_dim] |
There was a problem hiding this comment.
Fix all_gather_into_tensor output shape
torch.distributed.all_gather_into_tensor requires its destination tensor’s leading dimension to equal the process-group size. Reshaping the scratch buffers to half_kv_shard_len * world_size (and similarly in the non–zig-zag path) breaks that contract and will raise as soon as world_size > 1. Please keep the [world_size, …] structure for the gather and only collapse it afterwards so the API preconditions hold. Apply the same adjustment to the V-path and the non–zig-zag branch.
- ag_k0 = self.ag_k_buffer.reshape(-1)[:buffer_size].reshape(
- half_kv_shard_len * self.world_size, kv_head, head_dim)
+ ag_k0 = self.ag_k_buffer.reshape(-1)[:buffer_size].reshape(
+ self.world_size, half_kv_shard_len, kv_head, head_dim)
torch.distributed.all_gather_into_tensor(
ag_k0,
k0_shard,
group=self.pg,
)
- ag_k1 = self.ag_k_buffer.reshape(-1)[buffer_size:2 * buffer_size].reshape(
- half_kv_shard_len * self.world_size, kv_head, head_dim)
+ ag_k1 = self.ag_k_buffer.reshape(-1)[buffer_size:2 * buffer_size].reshape(
+ self.world_size, half_kv_shard_len, kv_head, head_dim)
torch.distributed.all_gather_into_tensor(
ag_k1,
k1_shard,
group=self.pg,
)
- ag_k1 = ag_k1.reshape(self.world_size, half_kv_shard_len, kv_head, head_dim)
- ag_k1 = torch.flip(ag_k1, [0]).reshape(self.world_size * half_kv_shard_len, kv_head,
- head_dim)
- ag_k = torch.cat((ag_k0, ag_k1), dim=0)
+ ag_k1 = torch.flip(ag_k1, [0])
+ ag_k = torch.cat(
+ (ag_k0.reshape(-1, kv_head, head_dim),
+ ag_k1.reshape(-1, kv_head, head_dim)),
+ dim=0,
+ )Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In examples/distributed/example_sp_ag_attention_intra_node.py around lines
212-299, the destination tensors passed to
torch.distributed.all_gather_into_tensor are reshaped so their leading dimension
is half_kv_shard_len * world_size (or flattened), which violates the API
requirement that the leading dimension equals world_size; this will break when
world_size > 1. Fix by keeping the gather-destination views with leading dim ==
self.world_size (e.g., reshape the buffers to (self.world_size,
half_kv_shard_len, kv_head, head_dim) for ag_k0/ag_k1/ag_v0/ag_v1 before calling
all_gather_into_tensor, call all_gather_into_tensor into those [world_size,...]
views, then collapse/reshape/permute the gathered tensor afterwards into the
existing ag_k/ag_v shapes; apply the same pattern in the non–zig-zag branch (use
a (self.world_size, kv_len_per_rank, kv_head, head_dim) view for ag_k and ag_v
during gather, then collapse and repeat_interleave afterward).
This pull request refactors the threadblock swizzling logic in the
example_gemm_rs_overlapped.pyGEMM kernel example to improve code clarity and maintainability. The main change is the extraction of the 2D swizzling logic into a dedicated helper function, which simplifies the main kernel code and makes the swizzling process more modular.Refactoring and code clarity:
swizzle_2dhelper function, replacing the inlined calculation in the main kernel. This makes the code easier to read and maintain.swizzle_2dfunction for calculatingpid_m_andpid_n, and simplified the calculation ofpid_m_offsetfor better clarity.Summary by CodeRabbit
New Features
Improvements