
[Example] Add CP example #37

Merged
chengyupku merged 10 commits into main from yu/dev on Nov 5, 2025

Conversation

@chengyupku

@chengyupku chengyupku commented Nov 5, 2025

This pull request refactors the threadblock swizzling logic in the example_gemm_rs_overlapped.py GEMM kernel example to improve code clarity and maintainability. The main change is the extraction of the 2D swizzling logic into a dedicated helper function, which simplifies the main kernel code and makes the swizzling process more modular.

Refactoring and code clarity:

  • Extracted the 2D swizzling logic into a new swizzle_2d helper function, replacing the inlined calculation in the main kernel. This makes the code easier to read and maintain.
  • Updated the main kernel logic to use the new swizzle_2d function for calculating pid_m_ and pid_n, and simplified the calculation of pid_m_offset for better clarity.

Summary by CodeRabbit

  • New Features

    • Added distributed sequence parallel attention module with intra-node support
    • Support for multiple attention execution paths including zigzag and specialized variants
    • New example demonstrating fused and reference attention implementations
  • Improvements

    • Optimized thread block mapping calculation for enhanced kernel performance

@github-actions

github-actions Bot commented Nov 5, 2025

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai

coderabbitai Bot commented Nov 5, 2025

Walkthrough

This PR adds intra-node distributed sequence-parallel attention capabilities by introducing a fused flash-attention kernel module with barrier synchronization and all-gather coordination, along with example implementations comparing fused and reference torch-based attention. It also refactors an existing GEMM example to use 2D swizzle mapping for improved threadblock scheduling.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Fused Attention Kernel Module (examples/distributed/sp_ag_attention_intra_node.py) | New file introducing flashattn kernel with zigzag and specialized variants, barrier utilities (barrier_all_blocks_sys_kernel, barrier_all_on_stream), all-gather producer (cp_engine_producer_kv_all_gather), context dataclass (SPAllGatherAttentionContextIntraNode), and main fused routine (fused_sp_ag_attn_intra_node) for distributed sequence-parallel attention with CUDA stream/barrier coordination. |
| Distributed Attention Example (examples/distributed/example_sp_ag_attention_intra_node.py) | New file providing example usage with two parallel attention implementations: FusedSequenceParallelAttn (uses fused kernel module) and TorchSequenceParallelAttn (reference torch implementation with manual KV gathering). Includes distributed harness with rank initialization, input shard generation, output validation, and performance profiling. |
| GEMM Overlapped Refactor (examples/distributed/example_gemm_rs_overlapped.py) | Modified to introduce swizzle_2d nested function for 2D threadblock mapping. Replaces manual pid_m_ / pid_n_ calculation with swizzle output and updates pid_m_offset to use formula (local_rank + 1) * M_per_rank // block_M. |

Sequence Diagram(s)

sequenceDiagram
    participant Rank0 as Rank 0
    participant Rank1 as Rank 1
    participant Kernel as Fused Kernel
    participant Sync as Barrier/Stream

    Rank0->>Rank0: Prepare Q shard
    Rank1->>Rank1: Prepare Q shard
    Rank0->>Kernel: cp_engine_producer_kv_all_gather (K/V shard)
    Rank1->>Kernel: cp_engine_producer_kv_all_gather (K/V shard)
    Rank0->>Sync: Barrier (await all K/V data)
    Rank1->>Sync: Barrier (await all K/V data)
    Rank0->>Kernel: Launch flashattn (Q shard + all K/V)
    Rank1->>Kernel: Launch flashattn (Q shard + all K/V)
    Kernel->>Kernel: Compute fused attention
    Kernel->>Rank0: Output buffer
    Kernel->>Rank1: Output buffer

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

  • Key areas requiring attention:
    • sp_ag_attention_intra_node.py: Dense kernel implementations with multiple flash-attention variants, barrier synchronization logic, and CUDA stream/buffer coordination. Requires careful verification of memory safety, tensor shape broadcasting, and correctness of zigzag/specialized tiling strategies.
    • cp_engine_producer_kv_all_gather: All-gather semantics with asynchronous CUDA copy operations and barrier synchronization patterns—potential for synchronization bugs.
    • example_sp_ag_attention_intra_node.py: Distributed process group initialization, data shard alignment, and numerical tolerance validation between fused and reference implementations.
    • Heterogeneous changes across three files with varying levels of abstraction (low-level kernel primitives, high-level distributed wrappers, existing refactor).

Poem

🐰 A kernel swizzles fine, barriers sync just right,
All-gather dances swift through the CUDA night,
Attention fused and swift, from shard to shard,
Parallel sequences leap—no task too hard! ✨

Pre-merge checks and finishing touches

❌ Failed checks (1 inconclusive)
| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Title check | ❓ Inconclusive | The PR title '[Example] Add CP example' is vague and generic, using the non-descriptive term 'CP example' without clarifying what CP means or what functionality is being added. | Clarify the title to specify what CP refers to (e.g., 'Add sequence parallel attention example' or 'Add collective primitive example') and use more descriptive language. |
✅ Passed checks (1 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (1)
examples/distributed/example_gemm_rs_overlapped.py (1)

33-40: Consider adding documentation for the swizzling pattern.

The 2D swizzling logic is correct and handles boundary conditions properly (when the last group has fewer rows than GROUP_SIZE_M). However, the pattern is subtle and would benefit from a brief docstring or comment explaining the grouping strategy and how tiles map to (pid_m, pid_n) coordinates.

Example addition:

 def swizzle_2d(tile_id, num_pid_m, num_pid_n):
+    """
+    Maps 1D tile_id to 2D (pid_m, pid_n) coordinates using grouped swizzling.
+    Tiles are grouped by GROUP_SIZE_M rows to improve cache locality.
+    """
     num_pid_in_group = GROUP_SIZE_M * num_pid_n
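The grouping strategy the suggested docstring refers to can be sketched in plain Python so the mapping can be checked off-device. This assumes the common grouped-ordering scheme (the one used in the Triton matmul tutorial). Only the function signature and the first line of the body appear in this PR excerpt, so the rest of the body and the value `GROUP_SIZE_M = 2` are assumptions, not the PR's code.

```python
GROUP_SIZE_M = 2  # assumed value, for illustration only


def swizzle_2d(tile_id, num_pid_m, num_pid_n):
    """Map a 1D tile_id to (pid_m, pid_n) using grouped ordering."""
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = tile_id // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last group may hold fewer than GROUP_SIZE_M rows.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    local_id = tile_id % num_pid_in_group
    pid_m = first_pid_m + local_id % group_size_m
    pid_n = local_id // group_size_m
    return pid_m, pid_n


# 3 x 2 tile grid: the second (last) group contains a single row.
order = [swizzle_2d(t, 3, 2) for t in range(6)]
print(order)
```

Within a group, consecutive tile ids walk down the rows before moving to the next column, so neighbouring thread blocks tend to reuse the same column tile of the second operand.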
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between fae99e9 and 3b29e36.

📒 Files selected for processing (3)
  • examples/distributed/example_gemm_rs_overlapped.py (2 hunks)
  • examples/distributed/example_sp_ag_attention_intra_node.py (1 hunks)
  • examples/distributed/sp_ag_attention_intra_node.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
examples/distributed/example_sp_ag_attention_intra_node.py (5)
tilelang/distributed/utils.py (2)
  • init_dist (40-62)
  • perf_fn (231-243)
examples/distributed/sp_ag_attention_intra_node.py (2)
  • create_sp_ag_attention_context_intra_node (560-608)
  • fused_sp_ag_attn_intra_node (696-783)
tilelang/env.py (1)
  • disable_cache (286-287)
tilelang/utils/allocator.py (1)
  • get_allocator (236-248)
tilelang/utils/tensor.py (1)
  • tensor (43-56)
examples/distributed/sp_ag_attention_intra_node.py (7)
tilelang/distributed/utils.py (1)
  • CUDA_CHECK (246-254)
tilelang/env.py (1)
  • disable_cache (286-287)
tilelang/jit/__init__.py (1)
  • jit (233-306)
tilelang/language/builtin.py (1)
  • barrier_all_blocks_sys (625-632)
tilelang/language/allocate.py (1)
  • alloc_barrier (120-129)
tilelang/layout/swizzle.py (1)
  • make_swizzled_layout (10-18)
tilelang/utils/tensor.py (1)
  • tensor (43-56)
🪛 Ruff (0.14.3)
examples/distributed/example_sp_ag_attention_intra_node.py

171-171: Unpacked variable total_q_shard is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


172-172: Unpacked variable total_kv_shard is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


323-323: Consider [0, *list(accumulate(seqlens_q))] instead of concatenation

Replace with [0, *list(accumulate(seqlens_q))]

(RUF005)


325-325: Consider [0, *list(accumulate(seqlens_k))] instead of concatenation

Replace with [0, *list(accumulate(seqlens_k))]

(RUF005)
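As a small illustration of the RUF005 findings above, the two spellings build the same cumulative-length list. The `seqlens_q` values are made up for the example. The inner `list(...)` in Ruff's suggested replacement is optional, since unpacking accepts any iterable.

```python
from itertools import accumulate

seqlens_q = [3, 5, 2]  # assumed per-sequence lengths, for illustration

# Concatenation form flagged by Ruff.
concatenated = [0] + list(accumulate(seqlens_q))

# Unpacking form; no intermediate list is needed.
unpacked = [0, *accumulate(seqlens_q)]

print(concatenated, unpacked)
```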

examples/distributed/sp_ag_attention_intra_node.py

635-635: Unpacked variable total_kv_shard is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


705-705: Unused function argument: max_seqlen_k

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: build-test-nvidia
🔇 Additional comments (1)
examples/distributed/example_gemm_rs_overlapped.py (1)

60-62: Inconsistent pid_m_offset calculation compared to similar code in repository.

The formula at line 61 differs significantly from the pattern used in example_allgather_gemm_overlapped.py (lines 88-89, 130-131). The reviewed code uses (local_rank + 1) * M_per_rank // block_M with floor division, while the related examples use T.ceildiv(local_rank * M_per_rank, block_M) with ceiling division and no +1 offset.

For rank 0, the reviewed code produces 1 * M_per_rank // block_M, while the pattern in the related code produces 0. Confirm whether this difference is intentional (i.e., RS overlapped kernels require a different offset strategy than the allgather variants) or whether it indicates a bug in either implementation.
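To make the comparison concrete, here is a minimal sketch that evaluates both formulas side by side. The sizes `M_per_rank = 1024` and `block_M = 128` are assumed, and the Python `ceildiv` helper stands in for `T.ceildiv`. The sketch only shows the arithmetic and does not decide which offset is the intended one.

```python
def ceildiv(a, b):
    """Ceiling division for non-negative integers (stand-in for T.ceildiv)."""
    return -(-a // b)


def offset_rs(local_rank, M_per_rank, block_M):
    # Formula used in example_gemm_rs_overlapped.py after this PR.
    return (local_rank + 1) * M_per_rank // block_M


def offset_ag(local_rank, M_per_rank, block_M):
    # Pattern cited from example_allgather_gemm_overlapped.py.
    return ceildiv(local_rank * M_per_rank, block_M)


M_per_rank, block_M = 1024, 128  # assumed sizes
rows = [(r, offset_rs(r, M_per_rank, block_M), offset_ag(r, M_per_rank, block_M))
        for r in range(4)]
for rank, rs, ag in rows:
    print(f"rank {rank}: rs offset = {rs}, ag offset = {ag}")
```

When M_per_rank is a multiple of block_M, the two offsets differ by exactly one rank's worth of tiles (M_per_rank // block_M).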

Comment on lines +50 to +52
assert (max_seqlen_q % self.world_size == 0 and max_seqlen_q % self.world_size
== 0), f"sequence length should be multiple of world_size({self.world_size})"
self.max_q_shard_len = self.max_seqlen_q // self.world_size

⚠️ Potential issue | 🔴 Critical

Guard both Q and KV maxima

This assertion currently checks max_seqlen_q twice. If max_seqlen_k is not divisible by world_size we keep going, but every downstream divide-by-world-size (e.g. in cp_engine_producer_kv_all_gather) will truncate and corrupt the gathered KV layout. Please assert on max_seqlen_k as well so we fail fast with a clear error.

-        assert (max_seqlen_q % self.world_size == 0 and max_seqlen_q % self.world_size
-                == 0), f"sequence length should be multiple of world_size({self.world_size})"
+        assert (max_seqlen_q % self.world_size == 0 and max_seqlen_k % self.world_size == 0), (
+            f"max_seqlen_q and max_seqlen_k must be multiples of world_size ({self.world_size})"
+        )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-        assert (max_seqlen_q % self.world_size == 0 and max_seqlen_q % self.world_size
-                == 0), f"sequence length should be multiple of world_size({self.world_size})"
+        assert (max_seqlen_q % self.world_size == 0 and max_seqlen_k % self.world_size == 0), (
+            f"max_seqlen_q and max_seqlen_k must be multiples of world_size ({self.world_size})"
+        )
         self.max_q_shard_len = self.max_seqlen_q // self.world_size
🤖 Prompt for AI Agents
In examples/distributed/example_sp_ag_attention_intra_node.py around lines
50-52, the assertion mistakenly checks max_seqlen_q twice; add a guard that also
asserts max_seqlen_k % self.world_size == 0 so both Q and KV maxima are
divisible by world_size, and after the assertions compute and store the
per-shard KV length (e.g., set self.max_k_shard_len = self.max_seqlen_k //
self.world_size) to mirror the Q shard calculation so downstream code doesn’t
truncate or corrupt KV layout.
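The failure mode described above can be reproduced with a standalone sketch. `shard_lengths` is a hypothetical helper written for this illustration and is not part of the PR. It applies the corrected assertion and returns the per-rank Q and KV shard lengths.

```python
def shard_lengths(max_seqlen_q, max_seqlen_k, world_size):
    """Hypothetical helper: validate both maxima, then compute per-rank shard lengths."""
    assert max_seqlen_q % world_size == 0 and max_seqlen_k % world_size == 0, (
        f"max_seqlen_q and max_seqlen_k must be multiples of world_size ({world_size})"
    )
    return max_seqlen_q // world_size, max_seqlen_k // world_size


# Without the KV check, floor division silently truncates:
# 10 KV tokens over 4 ranks gives shards of 2, covering only 8 tokens.
dropped = 10 - (10 // 4) * 4

print(shard_lengths(8, 12, 4), dropped)
```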

Comment on lines +139 to +142
self.max_q_shard_len = max_seqlen_q // self.world_size
self.max_kv_shard_ken = max_seqlen_q // self.world_size
assert (max_seqlen_q % self.world_size == 0 and max_seqlen_q % self.world_size
== 0), f"sequence length should be multiple of world_size({self.world_size})"

⚠️ Potential issue | 🔴 Critical

Add the missing KV divisibility check

Same copy/paste slip here: without validating max_seqlen_k, the integer divisions that follow silently drop remainder tokens and misalign the gathered buffers. Please assert on max_seqlen_k alongside max_seqlen_q.

-        assert (max_seqlen_q % self.world_size == 0 and max_seqlen_q % self.world_size
-                == 0), f"sequence length should be multiple of world_size({self.world_size})"
+        assert (max_seqlen_q % self.world_size == 0 and max_seqlen_k % self.world_size == 0), (
+            f"max_seqlen_q and max_seqlen_k must be multiples of world_size ({self.world_size})"
+        )
🤖 Prompt for AI Agents
In examples/distributed/example_sp_ag_attention_intra_node.py around lines 139
to 142, the code only validates and divides max_seqlen_q but not max_seqlen_k
(and also has a typo max_kv_shard_ken); add the missing check and correct the
variable names: assert that both max_seqlen_q and max_seqlen_k are divisible by
world_size, compute self.max_q_shard_len = max_seqlen_q // self.world_size and
self.max_kv_shard_len = max_seqlen_k // self.world_size, and update the assert
message/context to mention both sequence lengths.

Comment on lines +212 to +299
ag_k0 = self.ag_k_buffer.reshape(-1)[:buffer_size].reshape(
half_kv_shard_len * self.world_size, kv_head, head_dim)
ag_k1 = self.ag_k_buffer.reshape(-1)[buffer_size:2 * buffer_size].reshape(
half_kv_shard_len * self.world_size, kv_head, head_dim)
ag_v0 = self.ag_v_buffer.reshape(-1)[:buffer_size].reshape(
half_kv_shard_len * self.world_size, kv_head, head_dim)
ag_v1 = self.ag_v_buffer.reshape(-1)[buffer_size:2 * buffer_size].reshape(
half_kv_shard_len * self.world_size, kv_head, head_dim)
torch.distributed.all_gather_into_tensor(
ag_k0,
k0_shard,
group=self.pg,
)
torch.distributed.all_gather_into_tensor(
ag_k1,
k1_shard,
group=self.pg,
)
torch.distributed.all_gather_into_tensor(
ag_v0,
v0_shard,
group=self.pg,
)
torch.distributed.all_gather_into_tensor(
ag_v1,
v1_shard,
group=self.pg,
)
ag_k1 = ag_k1.reshape(self.world_size, half_kv_shard_len, kv_head, head_dim)
ag_k1 = torch.flip(ag_k1, [0]).reshape(self.world_size * half_kv_shard_len, kv_head,
head_dim)
ag_k = torch.cat((ag_k0, ag_k1), dim=0)
ag_k = torch.permute(ag_k.reshape(1, kv_len, kv_head, head_dim),
(0, 2, 1, 3)).contiguous()
ag_k = ag_k.repeat_interleave(q_head // kv_head, -3)

ag_v1 = ag_v1.reshape(self.world_size, half_kv_shard_len, kv_head, head_dim)
ag_v1 = torch.flip(ag_v1, [0]).reshape(self.world_size * half_kv_shard_len, kv_head,
head_dim)
ag_v = torch.cat((ag_v0, ag_v1), dim=0)
ag_v = torch.permute(ag_v.reshape(1, kv_len, kv_head, head_dim),
(0, 2, 1, 3)).contiguous()
ag_v = ag_v.repeat_interleave(q_head // kv_head, -3)

offset_q0 = half_q_shard_len * self.rank
offset_q1 = q_len - half_q_shard_len * (self.rank + 1)
prefix = kv_len - q_len
mask0 = _gen_mask(prefix + offset_q0, half_q_shard_len, kv_len)
mask1 = _gen_mask(prefix + offset_q1, half_q_shard_len, kv_len)
out0 = torch.nn.functional.scaled_dot_product_attention(
q0_shard_permute, ag_k, ag_v, attn_mask=mask0)
out1 = torch.nn.functional.scaled_dot_product_attention(
q1_shard_permute, ag_k, ag_v, attn_mask=mask1)
out = torch.cat((out0, out1), dim=2) # [1, q_head, q_shard_len, head_dim]
else:
cu_q_shard = q_shard[cu_seqlens_q_start:cu_seqlens_q_end, :, :].clone()
cu_q_shard_permute = torch.permute(
cu_q_shard.reshape(1, q_shard_len, q_head, head_dim),
(0, 2, 1, 3)).contiguous()

total_size = kv_len * kv_head * head_dim
ag_k = self.ag_k_buffer.reshape(-1)[:total_size].reshape(kv_len, kv_head, head_dim)
cu_k_shard = k_shard[cu_seqlens_k_start:cu_seqlens_k_end, :, :].clone()
torch.distributed.all_gather_into_tensor(
ag_k,
cu_k_shard,
group=self.pg,
)
ag_v = self.ag_v_buffer.reshape(-1)[:total_size].reshape(kv_len, kv_head, head_dim)
cu_v_shard = v_shard[cu_seqlens_k_start:cu_seqlens_k_end, :, :].clone()
torch.distributed.all_gather_into_tensor(
ag_v,
cu_v_shard,
group=self.pg,
)
ag_k = torch.permute(ag_k.reshape(1, kv_len, kv_head, head_dim),
(0, 2, 1, 3)).contiguous()
ag_k = ag_k.repeat_interleave(q_head // kv_head, -3)
ag_v = torch.permute(ag_v.reshape(1, kv_len, kv_head, head_dim),
(0, 2, 1, 3)).contiguous()
ag_v = ag_v.repeat_interleave(q_head // kv_head, -3)

offset = self.rank * q_shard_len
prefix = kv_len - q_len
mask = _gen_mask(prefix + offset, q_shard_len, kv_len)
out = torch.nn.functional.scaled_dot_product_attention(
cu_q_shard_permute, ag_k, ag_v,
attn_mask=mask) # [1, q_head, q_shard_len, head_dim]

⚠️ Potential issue | 🔴 Critical

Fix all_gather_into_tensor output shape

torch.distributed.all_gather_into_tensor requires its destination tensor’s leading dimension to equal the process-group size. Reshaping the scratch buffers to half_kv_shard_len * world_size (and similarly in the non–zig-zag path) breaks that contract and will raise as soon as world_size > 1. Please keep the [world_size, …] structure for the gather and only collapse it afterwards so the API preconditions hold. Apply the same adjustment to the V-path and the non–zig-zag branch.

-                ag_k0 = self.ag_k_buffer.reshape(-1)[:buffer_size].reshape(
-                    half_kv_shard_len * self.world_size, kv_head, head_dim)
+                ag_k0 = self.ag_k_buffer.reshape(-1)[:buffer_size].reshape(
+                    self.world_size, half_kv_shard_len, kv_head, head_dim)
                 torch.distributed.all_gather_into_tensor(
                     ag_k0,
                     k0_shard,
                     group=self.pg,
                 )
-                ag_k1 = self.ag_k_buffer.reshape(-1)[buffer_size:2 * buffer_size].reshape(
-                    half_kv_shard_len * self.world_size, kv_head, head_dim)
+                ag_k1 = self.ag_k_buffer.reshape(-1)[buffer_size:2 * buffer_size].reshape(
+                    self.world_size, half_kv_shard_len, kv_head, head_dim)
                 torch.distributed.all_gather_into_tensor(
                     ag_k1,
                     k1_shard,
                     group=self.pg,
                 )
-                ag_k1 = ag_k1.reshape(self.world_size, half_kv_shard_len, kv_head, head_dim)
-                ag_k1 = torch.flip(ag_k1, [0]).reshape(self.world_size * half_kv_shard_len, kv_head,
-                                                       head_dim)
-                ag_k = torch.cat((ag_k0, ag_k1), dim=0)
+                ag_k1 = torch.flip(ag_k1, [0])
+                ag_k = torch.cat(
+                    (ag_k0.reshape(-1, kv_head, head_dim),
+                     ag_k1.reshape(-1, kv_head, head_dim)),
+                    dim=0,
+                )

Committable suggestion skipped: line range outside the PR's diff.
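The gather-flip-concatenate step in the zig-zag branch can be traced with plain Python lists of token indices, with no process group required. The sizes are assumed, and the sketch assumes KV is sharded the same zig-zag way as the Q offsets in the quoted code (rank r holds the r-th chunk from the front and the r-th chunk from the back). For reference, PyTorch's documentation for all_gather_into_tensor lists both a concatenated and a stacked output form. The stacked `[world_size][half]` view used here holds the same elements, in the same row-major order, as the flattened one.

```python
world_size, half = 4, 2          # assumed sizes, for illustration
kv_len = 2 * half * world_size   # 16 tokens in total


def zigzag_shards(rank):
    """Token indices held by `rank`: one chunk from the front, one from the back."""
    front = list(range(rank * half, (rank + 1) * half))
    start = kv_len - (rank + 1) * half
    back = list(range(start, start + half))
    return front, back


def flatten(blocks):
    return [token for block in blocks for token in block]


# An all-gather places each rank's shard in rank order.
ag0 = [zigzag_shards(r)[0] for r in range(world_size)]
ag1 = [zigzag_shards(r)[1] for r in range(world_size)]

# Back chunks arrive in reverse sequence order, so flip along the rank axis.
ag1_flipped = ag1[::-1]
full = flatten(ag0) + flatten(ag1_flipped)
print(full)
```

After the flip, concatenating the two halves restores the original token order, which is what the reference path relies on before applying the causal mask.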

🤖 Prompt for AI Agents
In examples/distributed/example_sp_ag_attention_intra_node.py around lines
212-299, the destination tensors passed to
torch.distributed.all_gather_into_tensor are reshaped so their leading dimension
is half_kv_shard_len * world_size (or flattened), which violates the API
requirement that the leading dimension equals world_size; this will break when
world_size > 1. Fix by keeping the gather-destination views with leading dim ==
self.world_size (e.g., reshape the buffers to (self.world_size,
half_kv_shard_len, kv_head, head_dim) for ag_k0/ag_k1/ag_v0/ag_v1 before calling
all_gather_into_tensor, call all_gather_into_tensor into those [world_size,...]
views, then collapse/reshape/permute the gathered tensor afterwards into the
existing ag_k/ag_v shapes; apply the same pattern in the non–zig-zag branch (use
a (self.world_size, kv_len_per_rank, kv_head, head_dim) view for ag_k and ag_v
during gather, then collapse and repeat_interleave afterward).

@chengyupku chengyupku merged commit dafaa6f into main Nov 5, 2025
3 of 4 checks passed
