Skip to content

[Example] Update intra-node GEMM-RS example#35

Merged
chengyupku merged 5 commits into
mainfrom
yu/dev
Oct 30, 2025
Merged

[Example] Update intra-node GEMM-RS example#35
chengyupku merged 5 commits into
mainfrom
yu/dev

Conversation

@chengyupku

@chengyupku chengyupku commented Oct 30, 2025

Copy link
Copy Markdown

This pull request introduces support for new tilelang CUDA intrinsics for atomic operations and stores, and demonstrates their usage in distributed GEMM examples. The changes include adding new built-in operators, corresponding CUDA code generation and PTX implementations, and updating distributed example scripts to leverage these features. Additionally, a new distributed GEMM example with overlapped reduce-scatter is provided, while an older example is removed.

Tilelang CUDA atomic intrinsics support

  • Added new built-in operators atom_add and st to tilelang, with corresponding registration in src/op/builtin.cc and header declarations in src/op/builtin.h. These operators represent atomic add (returning the original value) and atomic store with semantics. [1] [2]
  • Implemented PTX-level CUDA functions for atomic add and store with various memory semantics in src/tl_templates/cuda/atomic.h and src/tl_templates/cuda/sync.h. [1] [2]
  • Updated CUDA code generation in src/target/codegen_cuda.cc to emit calls to the new atomic intrinsics, mapping tilelang operators to the correct PTX functions.

Distributed GEMM examples update

  • Added a new example example_gemm_rs_overlapped.py demonstrating overlapped GEMM and reduce-scatter using the new atomic intrinsics and synchronization primitives.
  • Removed the legacy example example_gemm_rs.py, which is now superseded by the new version.

Miscellaneous

  • Minor utility import added to example_allgather_gemm_overlapped.py for CUDA error checking. [1] [2]
  • Utility module tilelang/distributed/utils.py updated with additional imports for threading and subprocess management.

Summary by CodeRabbit

Release Notes

  • New Features

    • Added overlapped distributed GEMM example with multi-GPU support
    • Introduced 2D reduce-scatter runtime for efficient distributed GPU coordination
    • Added atomic operations and synchronization primitives for GPU memory operations
    • Integrated NVLink topology detection for cluster configuration validation
  • Bug Fixes

    • Added error checking for GPU stream operations
  • Chores

    • Removed legacy GEMM example

@github-actions

Copy link
Copy Markdown

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai

coderabbitai Bot commented Oct 30, 2025

Copy link
Copy Markdown

Caution

Review failed

The pull request is closed.

Walkthrough

This PR adds distributed GEMM infrastructure with overlapped computation support. It introduces atomic and store synchronization primitives at PTX and IR levels, creates a comprehensive 2D reduce-scatter framework with NVLink topology detection, provides new distributed GEMM examples, and updates an existing example with error checking.

Changes

Cohort / File(s) Summary
Atomic and Store Intrinsics (PTX Layer)
src/tl_templates/cuda/atomic.h, src/tl_templates/cuda/sync.h
Added 8 PTX-based atomic add functions (ptx_atom_add_) with variants for gpu/sys scopes and relaxed/acquire/release/acq_rel semantics; added 4 device-level store functions (st_) with release/relaxed variants for gpu/sys scopes.
TileLang IR and Codegen
src/op/builtin.h, src/op/builtin.cc, src/target/codegen_cuda.cc
Declared and defined two new TL builtins (atom_add, st) marked as opaque; added CUDA codegen branches to emit ptx_atom_add_* and st_* calls from CallNode intrinsics.
Python API Layer
tilelang/language/builtin.py
Exposed atom_add() and st() as public high-level functions with input validation for scope and memory semantic parameters.
Distributed Utilities
tilelang/distributed/utils.py
Added NVML initialization helpers, NvidiaSmiUtil class for topology parsing, and cached has_fullmesh_nvlink() function to detect full NVLink connectivity across ranks.
Distributed Examples
examples/distributed/example_allgather_gemm_overlapped.py
Added CUDA_CHECK error checking immediately after cuStreamWriteValue32 call.
Reduce-Scatter Framework
examples/distributed/reduce_scatter.py
New comprehensive 2D reduce-scatter runtime with ReduceScatter2DContext, intra-node scatter, ring-based reduction kernels (ring_reduce_tma), signal synchronization utilities, and multi-node orchestration.
Distributed GEMM Examples
examples/distributed/example_gemm_rs_overlapped.py
New complete distributed GEMM with overlapped RS example; defines gemm_kernel (JIT-compiled tiled GEMM), gemm_rs_op orchestrator, PyTorch reference path, performance measurement, and CLI entry point.
Removed Example
examples/distributed/example_gemm_rs.py
Entire file removed (standalone GEMM RS example with torch_gemm_rs helper, GemmRS class, CLI, and benchmarking harness).

Sequence Diagram

sequenceDiagram
    participant Main as Main Process
    participant Distributed as Distributed Init
    participant GEMM as GEMM Kernel Build
    participant Exec as Execution (Multi-rank)
    participant TL as TileLang GEMM RS
    participant RS as Reduce-Scatter
    participant PT as PyTorch Path
    participant Verify as Verification

    Main->>Distributed: Initialize process group & NVLink topology check
    Distributed->>Distributed: has_fullmesh_nvlink() validation
    Distributed-->>Main: Ready
    Main->>GEMM: Compile gemm_kernel (JIT)
    GEMM-->>Main: Kernel ready
    Main->>Exec: Spawn processes per rank
    Exec->>TL: Execute gemm_rs_op(A, B, C)
    activate TL
    TL->>TL: Run gemm_kernel on gemm_stream
    TL->>RS: Synchronize, initiate reduce-scatter
    RS->>RS: intra_node_scatter with signal sync
    RS->>RS: ring_reduce_tma (per-node reduction)
    RS->>RS: ring_reduce (inter-node if multi-node)
    RS-->>TL: Output ready
    deactivate TL
    TL-->>Exec: TileLang result
    Exec->>PT: Execute torch_gemm_rs (baseline)
    PT-->>Exec: PyTorch result
    Exec->>Verify: Compare results (allclose)
    Verify-->>Exec: Pass/Fail status
    Exec-->>Main: Benchmark & metrics
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

  • Reduce-scatter infrastructure (examples/distributed/reduce_scatter.py): High-density logic with signal synchronization, multi-node orchestration, and kernel factory patterns requires careful validation of synchronization semantics and topology assumptions.
  • New distributed GEMM example (examples/distributed/example_gemm_rs_overlapped.py): Substantial kernel orchestration and stream coordination; correctness depends on proper synchronization with reduce-scatter.
  • Atomic/store primitive stack: Multiple coordinated layers (PTX → ops → codegen → Python API) with scope and semantic parameters; scope parameter values must be consistent across layers.
  • NVLink topology detection: New NVML-based fallback path with thread-safety and caching logic; hardcoded assumptions (fullmesh requirement in reduce_scatter_multi_node) should be validated.
  • Interconnected changes: Modifications span C++, CUDA headers, and Python across multiple abstraction levels; changes in one layer directly impact others.

Specific areas requiring extra attention:

  • Correctness of signal synchronization in _wait_eq_cuda and barrier reset logic in reduce_scatter_2d_op
  • Multi-node reduce-scatter path and assumptions about NVLink topology (inter-node p2p currently raises NotImplementedError)
  • Consistency of scope/semantic string literals across atomic.h, sync.h, codegen, and builtin.py
  • NVML initialization thread-safety and fallback paths in utils.py

Possibly related PRs

Suggested reviewers

  • tzj-fxz

Poem

🐰 Hops through atomic fences bright,
Scatter signals through the night,
NVLink meshes, kernels dance,
Overlapped ranks in sync—a chance!
DeviceSync and atoms gleam,
Building faster GPU dreams!

✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch yu/dev

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between aa916f4 and 2ce2af8.

📒 Files selected for processing (11)
  • examples/distributed/example_allgather_gemm_overlapped.py (2 hunks)
  • examples/distributed/example_gemm_rs.py (0 hunks)
  • examples/distributed/example_gemm_rs_overlapped.py (1 hunks)
  • examples/distributed/reduce_scatter.py (1 hunks)
  • src/op/builtin.cc (1 hunks)
  • src/op/builtin.h (1 hunks)
  • src/target/codegen_cuda.cc (1 hunks)
  • src/tl_templates/cuda/atomic.h (1 hunks)
  • src/tl_templates/cuda/sync.h (1 hunks)
  • tilelang/distributed/utils.py (2 hunks)
  • tilelang/language/builtin.py (1 hunks)

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@chengyupku chengyupku merged commit fae99e9 into main Oct 30, 2025
2 of 3 checks passed
@coderabbitai coderabbitai Bot mentioned this pull request Dec 20, 2025
7 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant