From 422571e3e4ec59eeef9a348d2c79c4d3cec614dd Mon Sep 17 00:00:00 2001 From: Yu Cheng Date: Tue, 28 Oct 2025 20:31:16 +0800 Subject: [PATCH 1/4] [README] Update dependency in installation doc --- docs/get_started/Installation.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/get_started/Installation.md b/docs/get_started/Installation.md index f5135f56d1..9e9b296ac7 100644 --- a/docs/get_started/Installation.md +++ b/docs/get_started/Installation.md @@ -21,6 +21,8 @@ cd tilescale 2. **Install Project**: ```bash +pip install cuda-python==12.9 # should align with your nvcc version +pip install scikit-build-core CMake torch ninja Cython pip install -e . --no-build-isolation ``` From 570e0b77837fe26cceae324f4af8a54a9b70dd42 Mon Sep 17 00:00:00 2001 From: Yu Cheng Date: Thu, 30 Oct 2025 15:04:25 +0800 Subject: [PATCH 2/4] [Example] Update intra-node GEMM-RS example --- .../example_allgather_gemm_overlapped.py | 2 + examples/distributed/example_gemm_rs.py | 184 -------- .../distributed/example_gemm_rs_overlapped.py | 219 +++++++++ examples/distributed/reduce_scatter.py | 421 ++++++++++++++++++ src/op/builtin.cc | 9 + src/op/builtin.h | 16 + src/target/codegen_cuda.cc | 11 + src/tl_templates/cuda/atomic.h | 67 +++ src/tl_templates/cuda/sync.h | 28 ++ tilelang/distributed/utils.py | 101 +++++ tilelang/language/builtin.py | 25 ++ 11 files changed, 899 insertions(+), 184 deletions(-) delete mode 100644 examples/distributed/example_gemm_rs.py create mode 100644 examples/distributed/example_gemm_rs_overlapped.py create mode 100644 examples/distributed/reduce_scatter.py diff --git a/examples/distributed/example_allgather_gemm_overlapped.py b/examples/distributed/example_allgather_gemm_overlapped.py index e8a3a94cf8..cebf58ed1a 100644 --- a/examples/distributed/example_allgather_gemm_overlapped.py +++ b/examples/distributed/example_allgather_gemm_overlapped.py @@ -6,6 +6,7 @@ import torch.distributed as dist import torch.multiprocessing from tilelang.distributed import init_dist +from tilelang.distributed.utils import CUDA_CHECK from tilelang.carver.arch import driver import importlib.metadata @@ -171,6 +172,7 @@ def cp_engine_producer_all_gather_full_mesh_pull( signal_target, cuda.CUstreamWriteValue_flags.CU_STREAM_WRITE_VALUE_DEFAULT, ) + CUDA_CHECK(err) def ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, N, signal_target, local_rank, diff --git a/examples/distributed/example_gemm_rs.py b/examples/distributed/example_gemm_rs.py deleted file mode 100644 index 328282903b..0000000000 --- a/examples/distributed/example_gemm_rs.py +++ /dev/null @@ -1,184 +0,0 @@ -import torch -import argparse -import os -import tilelang -from tilelang.distributed import init_distributed, generate_data -from gemm_rs_kernel import gemm_rs -from gemm_rs_utils import create_gemm_rs_context -from typing import Optional - -tilelang.disable_cache() - - -def torch_gemm_rs( - input: torch.Tensor, # [M, local_k] - weight: torch.Tensor, # [N, local_K] - bias: Optional[torch.Tensor], - TP_GROUP, -): - M, local_K = input.shape - N = weight.shape[0] - output = torch.matmul(input, weight.T) - if bias: - output = output + bias - rs_output = torch.empty((M // WORLD_SIZE, N), dtype=output.dtype, device=input.device) - torch.distributed.reduce_scatter_tensor(rs_output, output, group=TP_GROUP) - return rs_output - - -class GemmRS(torch.nn.Module): - - def __init__( - self, - tp_group: torch.distributed.ProcessGroup, - max_M: int, - N: int, - K: int, - input_dtype: torch.dtype, - output_dtype: torch.dtype, - local_world_size: int = -1, - ): - super().__init__() - self.tp_group = tp_group - self.rank: int = tp_group.rank() - self.world_size = tp_group.size() - self.local_world_size = local_world_size if local_world_size != -1 else self.world_size - self.local_rank = self.rank % self.local_world_size - - self.max_M: int = max_M - self.N = N - self.K = K - self.input_dtype = input_dtype - self.output_dtype = output_dtype - - self.rs_stream: torch.cuda.Stream = torch.cuda.Stream(priority=-1) - - self.ctx = create_gemm_rs_context(max_M, N, self.rank, self.world_size, - self.local_world_size, output_dtype, self.rs_stream) - - def forward( - self, - input: torch.Tensor, # [M, local_K] - weight: torch.Tensor, # [N, local_K] - bias: Optional[torch.Tensor] = None, - persistent: bool = True, - ): - assert input.shape[0] <= self.max_M and weight.shape[0] == self.N - - return gemm_rs(input, weight, self.ctx, persistent) - - -DTYPE_MAP = { - "bfloat16": torch.bfloat16, - "float16": torch.float16, - "float8_e4m3fn": torch.float8_e4m3fn, - "float8_e5m2": torch.float8_e5m2, - "s8": torch.int8, - "s32": torch.int32, -} - -THRESHOLD_MAP = { - torch.float16: 1e-2, - torch.bfloat16: 6e-2, - torch.float8_e4m3fn: 1e-2, - torch.float8_e5m2: 1e-2, -} - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("M", type=int) - parser.add_argument("N", type=int) - parser.add_argument("K", type=int) - parser.add_argument("--warmup", default=20, type=int, help="warmup iterations") - parser.add_argument("--iters", default=100, type=int, help="perf iterations") - parser.add_argument("--dtype", default="float16", type=str, help="data type") - - parser.add_argument( - "--profile", default=False, action="store_true", help="dump torch.profiler.profile") - parser.add_argument("--check", default=False, action="store_true", help="correctness check") - parser.add_argument("--verify-iters", default=1, type=int) - parser.add_argument( - "--persistent", - action=argparse.BooleanOptionalAction, - default=torch.cuda.get_device_capability() >= (9, 0)) - - parser.add_argument( - "--transpose_weight", - dest="transpose_weight", - action=argparse.BooleanOptionalAction, - help="transpose weight", - default=True, - ) - parser.add_argument("--has_bias", default=False, action="store_true", help="whether have bias") - parser.add_argument("--seed", type=int, default=42) - - return parser.parse_args() - - -if __name__ == "__main__": - # init - args = parse_args() - - # TODO: remove this - args.persistent = False - - RANK = int(os.environ.get("RANK", 0)) - LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0)) - WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1)) - LOCAL_WORLD_SIZE = int(os.environ.get("LOCAL_WORLD_SIZE", 1)) - - WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP = init_distributed(return_tp_group=True) - - input_dtype = DTYPE_MAP[args.dtype] - output_dtype = input_dtype - atol = THRESHOLD_MAP[output_dtype] - rtol = THRESHOLD_MAP[output_dtype] - - assert args.M % TP_GROUP.size() == 0 - assert args.K % TP_GROUP.size() == 0 - local_K = args.K // TP_GROUP.size() - - scale = TP_GROUP.rank() + 1 - - def _make_data(M): - data_config = [ - ((M, local_K), input_dtype, (0.01 * scale, 0)), # A - ((args.N, local_K), input_dtype, (0.01 * scale, 0)), # B - ( # bias - None if not args.has_bias else ((M, args.N), input_dtype, (1, 0))), - ] - generator = generate_data(data_config) - input, weight, bias = next(generator) - return input, weight, bias - - gemm_rs_op = GemmRS(TP_GROUP, args.M, args.N, args.K, input_dtype, output_dtype, - LOCAL_WORLD_SIZE) - - torch.cuda.empty_cache() - input_list = [ - # _make_data(random.randint(1, args.M // WORLD_SIZE) * WORLD_SIZE) for _ in range(args.verify_iters) - _make_data(args.M) for _ in range(args.verify_iters) - ] - dist_out_list, torch_out_list = [], [] - - # torch impl - for input, weight, bias in input_list: - torch_out = torch_gemm_rs( - input, - weight, - bias, - TP_GROUP, - ) - torch_out_list.append(torch_out) - - # dist triton impl - for input, weight, bias in input_list: - dist_out = gemm_rs_op.forward(input, weight, bias, args.persistent) - dist_out_list.append(dist_out) - # verify - # for idx, (torch_out, dist_out) in enumerate(zip(torch_out_list, dist_out_list)): - # assert_allclose(torch_out, dist_out, atol=atol, rtol=rtol, verbose=False) - - # print(f"RANK[{RANK}]: pass.") - # exit(0) diff --git a/examples/distributed/example_gemm_rs_overlapped.py b/examples/distributed/example_gemm_rs_overlapped.py new file mode 100644 index 0000000000..87abff57b9 --- /dev/null +++ b/examples/distributed/example_gemm_rs_overlapped.py @@ -0,0 +1,219 @@ +from __future__ import annotations + +import tilelang +import tilelang.language as T +import argparse +import torch +import torch.distributed as dist +import torch.multiprocessing +from tilelang.distributed import init_dist +from tilelang.carver.arch import driver +import importlib.metadata + +cuda_python_version = importlib.metadata.version("cuda-python") +from packaging import version +if version.parse(cuda_python_version) >= version.parse("12.8.0"): + from cuda.bindings import driver as cuda +else: + from cuda import cuda +from tilelang.distributed import perf_fn +from reduce_scatter import reduce_scatter_2d_op, create_reduce_scater_2d_ctx + +tilelang.disable_cache() + + +@tilelang.jit +def gemm_kernel(M, + N, + K, + local_rank, + num_local_rank, + block_M, + block_N, + block_K, + threads, + persistent=False, + dtype="float16", + accum_dtype="float"): + + M_per_rank = T.ceildiv(M, num_local_rank) + GROUP_SIZE_M = 8 + + @T.prim_func + def main( + A: T.Tensor((M, K // num_local_rank), dtype), + B: T.Tensor((K // num_local_rank, N), dtype), + scatter_signal_buf: T.Tensor((num_local_rank), "uint32"), + counter_signal_buf: T.Tensor((num_local_rank), "uint32"), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel( + T.ceildiv(M, block_M) * T.ceildiv(N, block_N), + threads=threads) as (bid): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_shared = T.alloc_shared((block_M, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + val = T.alloc_local((1,), "uint32") + + num_pid_m = T.ceildiv(M, block_M) + num_pid_n = T.ceildiv(N, block_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = bid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = T.min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m_ = first_pid_m + ((bid % num_pid_in_group) % group_size_m) + pid_n_ = (bid % num_pid_in_group) // group_size_m + + # threadblock swizzle + # no stream-k support. only split by m x n + m_offset = M_per_rank * local_rank + pid_m_offset = T.ceildiv(m_offset, block_M) + pid_m = (pid_m_ + pid_m_offset) % num_pid_m + pid_n = pid_n_ + + tid = T.get_thread_binding(0) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K // num_local_rank, block_K), num_stages=3): + T.copy(A[pid_m * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, pid_n * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + T.copy(C_local, C_shared) + T.copy(C_shared, C[pid_m * block_M, pid_n * block_N]) + + # inc barrier + segment_start = pid_m * block_M // M_per_rank + segment_end = (T.min((pid_m + 1) * block_M, M) - 1) // M_per_rank + segment = segment_start + tid + if segment <= segment_end: + m_start = M_per_rank * segment + m_end = M_per_rank * (segment + 1) - 1 + tiled_m_start = m_start // block_M + tiled_m_end = m_end // block_M + tiled_m_size = tiled_m_end - tiled_m_start + 1 + val[0] = T.atom_add(counter_signal_buf[segment], 1, scope="gpu", sem="release") + if T.Cast("int32", val[0]) == num_pid_n * tiled_m_size - 1: + T.st(scatter_signal_buf[segment], 1, scope="gpu", sem="release") + + return main + + +def gemm_rs_op(A, B, C, output, ctx, gemm_kernel, gemm_stream, rs_stream, local_rank, print_source=False): + + current_stream = torch.cuda.current_stream() + rs_stream.wait_stream(gemm_stream) + + gemm_kernel(A, B, ctx.scatter_signal_bufs[local_rank], ctx.counter_bufs[local_rank], C, stream=gemm_stream.cuda_stream) + + if print_source and local_rank == 1: + print(gemm_kernel.get_kernel_source()) + + with torch.cuda.stream(rs_stream): + # don't allocate memory on other stream: error-prune + output = reduce_scatter_2d_op(C, ctx, output) + gemm_stream.wait_stream(rs_stream) + current_stream.wait_stream(rs_stream) + + return output + + +def torch_gemm_rs( + pg: torch.distributed.ProcessGroup, + input: torch.Tensor, # [M, local_k] + weight: torch.Tensor, # [local_K, N] + bias: torch.Tensor | None, + num_local_ranks: int, +): + M, local_K = input.shape + N = weight.shape[1] + output = torch.matmul(input, weight) + if bias: + output = output + bias + rs_output = torch.empty((M // num_local_ranks, N), dtype=output.dtype, device=input.device) + torch.distributed.reduce_scatter_tensor(rs_output, output, group=pg) + return rs_output + + +def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): + dtype = torch.float16 + M = args.M if args else 8192 + N = args.N if args else 8192 + K = args.K if args else 8192 + persistent = args.persistent + M_per_rank = M // num_local_ranks + K_per_rank = K // num_local_ranks + + BLOCK_M = 128 + BLOCK_N = 256 + BLOCK_K = 64 + threads = 256 + + rank, num_ranks, group = init_dist(local_rank, num_local_ranks) + assert rank == local_rank and num_ranks == num_local_ranks, "only support single node for now" + allocator = tilelang.get_allocator( + size=2**30, + device="cuda", + is_distributed=True, + local_rank=local_rank, + num_local_ranks=num_local_ranks, + group=group) + gemm_func = gemm_kernel(M, N, K, local_rank, num_local_ranks, BLOCK_M, BLOCK_N, BLOCK_K, + threads, persistent) + gemm_func.initialize(allocator=allocator) + + A = tilelang.tensor((M, K_per_rank), dtype, allocator=allocator).normal_() / 10 + B = tilelang.tensor((K_per_rank, N), dtype, allocator=allocator).normal_() / 10 + C = tilelang.tensor((M, N), dtype, allocator=allocator) + output = tilelang.tensor((M_per_rank, N), dtype, allocator=allocator) + gemm_stream = torch.cuda.Stream() + rs_stream = torch.cuda.Stream(priority=-1) + ctx = create_reduce_scater_2d_ctx( + M, + N, + local_rank, + num_local_ranks, + num_local_ranks, + dtype, + allocator, + overlap_with_gemm=True, + num_reduction_sms=15 + ) + + dist.barrier() + + tilelang_out = gemm_rs_op(A, B, C, output, ctx, gemm_func, gemm_stream, rs_stream, local_rank, print_source=True) + torch_out = torch_gemm_rs(group, A, B, None, num_local_ranks) + + atol = 1e-2 + rtol = 1e-2 + if torch.allclose(torch_out, tilelang_out, atol=atol, rtol=rtol): + print(f"rank {local_rank} check passed.✅") + else: + print(f"rank {local_rank} check failed.❌") + print(f"torch_out: {torch_out}, tilelang_out: {tilelang_out}") + + _, tl_t = perf_fn( + lambda: + gemm_rs_op(A, B, C, output, ctx, gemm_func, gemm_stream, rs_stream, local_rank), + warmup=5, + rep=5) + + print( + f"rank {local_rank} tilelang gemm_rs time: {tl_t:.2f} ms, TFLOPS: {2*M*N*K/1e9/(tl_t)/num_local_ranks:.2f}" + ) + + dist.destroy_process_group() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + '--num-processes', type=int, default=2, help='Number of processes to spawn (default: 2)') + parser.add_argument('--M', type=int, default=8192, help='M dimension') + parser.add_argument('--N', type=int, default=8192, help='N dimension') + parser.add_argument('--K', type=int, default=29568, help='K dimension') + parser.add_argument('--persistent', action='store_true', help='Use persistent kernel') + args = parser.parse_args() + num_processes = args.num_processes + + torch.multiprocessing.spawn(main, args=(num_processes, args), nprocs=num_processes) diff --git a/examples/distributed/reduce_scatter.py b/examples/distributed/reduce_scatter.py new file mode 100644 index 0000000000..5b7270d178 --- /dev/null +++ b/examples/distributed/reduce_scatter.py @@ -0,0 +1,421 @@ +import dataclasses +from typing import List, Optional + +import torch +import importlib.metadata +from packaging import version +cuda_python_version = importlib.metadata.version("cuda-python") +if version.parse(cuda_python_version) >= version.parse("12.8.0"): + from cuda.bindings import driver as cuda + from cuda.bindings import runtime as cudart +else: + from cuda import cuda, cudart +import tilelang +from tilelang.distributed.utils import CUDA_CHECK, has_fullmesh_nvlink +from tilelang.utils.target import target_is_hopper, determine_target +import torch.distributed as dist +import tilelang +import tilelang.language as T + +tilelang.disable_cache() + +@dataclasses.dataclass +class ReduceScatter2DContext: + max_M: int + N: int + rank: int + world_size: int + local_world_size: int + dtype: torch.dtype + overlap_with_gemm: bool + + # comm buffer + scatter_bufs: List[torch.Tensor] + rs_per_node_bufs: List[torch.Tensor] + p2p_bufs: List[torch.Tensor] + + # barrier bufs + signal_bufs: List[torch.Tensor] # need reset: signal_buf = scatter_signal | rs_per_node_signal + + counter_bufs: List[torch.Tensor] + + # intra-node barrier + barrier: List[torch.Tensor] + + # stream + reduction_stream: torch.cuda.Stream + + # sms + num_sync_sms: int + num_p2p_sms: int + num_reduction_sms: int + + # preprocess to reduce cpu overhead + # comm barriers + scatter_signal_bufs: List[torch.Tensor] = dataclasses.field(init=False) + rs_per_node_signal_bufs: List[torch.Tensor] = dataclasses.field(init=False) + + local_rank: int = dataclasses.field(init=False) + node_id: int = dataclasses.field(init=False) + nnodes: int = dataclasses.field(init=False) + + scatter_signal_buf_list_for_each_node: List[torch.Tensor] = dataclasses.field(init=False) + + def __post_init__(self): + self.local_rank = self.rank % self.local_world_size + self.node_id = self.rank // self.local_world_size + assert self.world_size % self.local_world_size == 0 + assert self.max_M % self.world_size == 0 + assert len(self.signal_bufs) == self.local_world_size + self.nnodes = self.world_size // self.local_world_size + self.scatter_signal_buf_list_for_each_node = [] + for buf in self.signal_bufs: + assert buf.shape[0] >= 2 * self.world_size + + self.scatter_signal_bufs = [buf[:self.world_size] for buf in self.signal_bufs] + self.rs_per_node_signal_bufs = [buf[self.world_size:self.world_size * 2] for buf in self.signal_bufs] + + for node_id in range(self.nnodes): + self.scatter_signal_buf_list_for_each_node.append( + self.scatter_signal_bufs[self.local_rank][node_id * self.local_world_size:(node_id + 1) * + self.local_world_size]) + + def reset_barriers(self): + self.signal_bufs[self.local_rank].fill_(0) + self.counter_bufs[self.local_rank].fill_(0) + + def get_scatter_bufs_and_signal_for_each_node(self, input, node_id): + M = input.shape[0] + M_per_rank = M // self.world_size + M_per_node = M_per_rank * self.local_world_size + M_start = node_id * M_per_node + M_end = M_start + M_per_node + scatter_bufs_intra_node = [self.scatter_bufs[i][M_start:M_end] for i in range(self.local_world_size)] + return scatter_bufs_intra_node, self.scatter_signal_buf_list_for_each_node[node_id] + + @property + def rs_per_node_buf(self) -> torch.Tensor: + return self.rs_per_node_bufs[self.local_rank] + + @property + def rs_per_node_signal_buf(self) -> torch.Tensor: + return self.rs_per_node_signal_bufs[self.local_rank] + + @property + def p2p_buf(self) -> torch.Tensor: + return self.p2p_bufs[self.local_rank] + + @property + def num_rs_sms(self) -> int: + if self.nnodes > 1: + return self.num_sync_sms + self.num_p2p_sms + self.num_reduction_sms + else: + # for intra node rs, no need sm. + return 0 + + @property + def scatter_signal_buf(self) -> torch.Tensor: + return self.scatter_signal_bufs[self.local_rank] + + +def create_reduce_scater_2d_ctx(max_M, N, rank, world_size, local_world_size, dtype, allocator, overlap_with_gemm=True, + num_reduction_sms=15) -> ReduceScatter2DContext: + """ + for num_reduction_sms: tunable param, 16 are enough for H800 + For H800, we overlap local reduce and inter-node p2p with intra-node scatter. + The reduction kernel bandwidth is not a bottleneck if it exceeds 450GB, so only a few SMs are needed. + For machines with higher intra_node bandwidth(e.g. H100), we may need to increase the number of SMs or redesign overlapping. + """ + assert world_size % local_world_size == 0 + assert max_M % world_size == 0 + + scatter_bufs = tilelang.tensor((max_M, N), dtype, allocator=allocator, return_peers=True) + rs_per_node_bufs = tilelang.tensor((max_M // local_world_size, N), dtype, allocator=allocator, return_peers=True) + p2p_bufs = tilelang.tensor((max_M // local_world_size, N), dtype, allocator=allocator, return_peers=True) + + # signal_buf: scatter_signal | rs_per_node_signal + num_signal_bufs = 2 + signal_bufs = tilelang.tensor((world_size * num_signal_bufs), dtype=torch.uint32, allocator=allocator, return_peers=True) + symm_barriers = tilelang.tensor((local_world_size, ), torch.int32, allocator=allocator, return_peers=True) + symm_barriers[rank] = 0 + + counter_signal_buf = tilelang.tensor((local_world_size), dtype=torch.uint32, allocator=allocator, return_peers=True) + + dist.barrier() + + reduction_stream: torch.cuda.Stream = torch.cuda.Stream(priority=-1) + + num_sync_sms = 0 + num_p2p_sms = 1 + ctx = ReduceScatter2DContext(max_M=max_M, N=N, rank=rank, world_size=world_size, local_world_size=local_world_size, + dtype=dtype, overlap_with_gemm=overlap_with_gemm, scatter_bufs=scatter_bufs, + rs_per_node_bufs=rs_per_node_bufs, p2p_bufs=p2p_bufs, signal_bufs=signal_bufs, + counter_bufs=counter_signal_buf, + barrier=symm_barriers, reduction_stream=reduction_stream, + num_sync_sms=num_sync_sms, num_p2p_sms=num_p2p_sms, + num_reduction_sms=num_reduction_sms) + return ctx + + +@tilelang.jit +def kernel_ring_reduce_tma( + M_per_rank, + N, + block_M, + block_N, + begin_idx, + num_splits, + threads, + persistent=False, + dtype="float16", + accum_dtype="float"): + + @T.prim_func + def _kernel_ring_reduce_tma( + C: T.Tensor((M_per_rank * num_splits, N), dtype), + output: T.Tensor((M_per_rank, N), dtype), + ): + with T.Kernel(T.ceildiv(M_per_rank, block_M), T.ceildiv(N, block_N), + threads=threads) as (bx, by): + data_shared = T.alloc_shared((block_M, block_N), dtype) + init_shared = T.alloc_shared((block_M, block_N), dtype) + data_local = T.alloc_fragment((block_M, block_N), dtype) + accum = T.alloc_fragment((block_M, block_N), accum_dtype) + + cur_rank = (begin_idx + 1) % num_splits + T.copy(C[cur_rank * M_per_rank + bx * block_M, by * block_N], init_shared) + T.copy(init_shared, accum) + for i in T.Pipelined(num_splits - 1, num_stages=0): + cur_rank = (i + 1 + begin_idx + 1) % num_splits + T.copy(C[cur_rank * M_per_rank + bx * block_M, by * block_N], data_shared) + T.copy(data_shared, data_local) + for i, j in T.Parallel(block_M, block_N): + accum[i, j] += data_local[i, j] + T.copy(accum, output[bx * block_M, by * block_N]) + + return _kernel_ring_reduce_tma + + +def _wait_eq_cuda(signal_tensor: torch.Tensor, signal: int, stream: Optional[torch.cuda.Stream] = None, + require_i64=False): + stream = stream or torch.cuda.current_stream() + if signal_tensor.dtype in (torch.int32, torch.uint32): + (err, ) = cuda.cuStreamWaitValue32( + stream.cuda_stream, + signal_tensor.data_ptr(), + signal, + cuda.CUstreamWaitValue_flags.CU_STREAM_WAIT_VALUE_EQ, + ) + CUDA_CHECK(err) + elif signal_tensor.dtype in (torch.int64, torch.uint64): + (err, ) = cuda.cuStreamWaitValue64( + stream.cuda_stream, + signal_tensor.data_ptr(), + signal, + cuda.CUstreamWaitValue_flags.CU_STREAM_WAIT_VALUE_EQ, + ) + CUDA_CHECK(err) + else: + raise Exception(f"Unsupported signal dtype {signal_tensor.dtype}") + +def intra_node_scatter(input_intra_node, scatter_bufs_intra_node: List[torch.Tensor], + scatter_signal_buf_intra_node: torch.Tensor, local_rank, overlap_with_gemm=True): + M, N = input_intra_node.shape + local_world_size = len(scatter_bufs_intra_node) + M_per_rank = M // local_world_size + + stream = torch.cuda.current_stream() + + for i in range(local_world_size): + remote_local_rank = (local_rank + i + 1) % local_world_size + + # print(f"scatter_signal_buf_intra_node[remote_local_rank]: {scatter_signal_buf_intra_node[remote_local_rank]}") + if overlap_with_gemm: + _wait_eq_cuda(scatter_signal_buf_intra_node[remote_local_rank], 1, stream) + src = input_intra_node[ + remote_local_rank * M_per_rank : (remote_local_rank + 1) * M_per_rank, : + ] + dst = scatter_bufs_intra_node[remote_local_rank][ + local_rank * M_per_rank : (local_rank + 1) * M_per_rank, : + ] + with torch.cuda.stream(stream): + dst.copy_(src) + +def ring_reduce_non_tma( + input: torch.Tensor, # [M_per_node, N] + output: torch.Tensor, # [M_per_rank, N] + begin_idx, + num_splits, + num_sms=16, +): + total_M, N = input.shape + M_per_split = total_M // num_splits + assert output.shape[0] == M_per_split and total_M % num_splits == 0, output.shape + num_warps = 32 + kernel_ring_reduce_non_tma[(num_sms, )]( + input, + output, + M_per_split * N, + begin_idx, + num_splits, + BLOCK_SIZE=32 * num_warps * 8 * 4, + num_warps=num_warps, + ) + return output + + +def ring_reduce_tma( + input: torch.Tensor, # [M_per_node, N] + output: torch.Tensor, # [M_per_rank, N] + begin_idx, + num_splits, + num_sms=-1, +): + total_M, N = input.shape + M_per_split = total_M // num_splits + assert output.shape[0] == M_per_split and total_M % num_splits == 0, f"{output.shape}, {total_M}, {num_splits}" + + def alloc_fn(size, alignment, stream): + return torch.empty(size, device="cuda", dtype=torch.int8) + + if num_sms == -1: + ring_reduce_tma_func = kernel_ring_reduce_tma( + M_per_split, + N, + block_M=64, + block_N=64, + begin_idx=begin_idx, + num_splits=num_splits, + threads=128) + # if begin_idx == 0: + # print(ring_reduce_tma_func.get_kernel_source()) + ring_reduce_tma_func(input, output, stream=torch.cuda.current_stream().cuda_stream) + else: + raise NotImplementedError("Currently only support num_sms = -1 for TMA ring reduce.") + # grid = lambda META: (min( + # triton.cdiv(M_per_split, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), num_sms), ) + # kernel_ring_reduce_tma[grid]( + # input, + # output, + # M_per_split, + # N, + # begin_idx, + # num_splits, + # BLOCK_SIZE_M=256, + # BLOCK_SIZE_N=128, + # num_warps=8, + # ) + + return output + +target = determine_target(return_object=True) + +def ring_reduce( + input, # [M_per_node, N] + output, # [M_per_rank, N] + begin_idx, + num_splits, + num_sms=-1, +): + if target_is_hopper(target): + return ring_reduce_tma(input, output, begin_idx, num_splits, num_sms) + else: + return ring_reduce_non_tma(input, output, begin_idx, num_splits, 16 if num_sms == -1 else num_sms) + +def reduce_scatter_for_each_node(input: torch.Tensor, ctx: ReduceScatter2DContext, + output: Optional[torch.Tensor] = None): + world_size = ctx.world_size + local_world_size = ctx.local_world_size + local_rank = ctx.local_rank + reduction_stream = ctx.reduction_stream + num_reduction_sms = ctx.num_reduction_sms + M, N = input.shape + M_per_rank = M // world_size + M_per_node = M_per_rank * local_world_size + nnodes = ctx.nnodes + node_id = ctx.node_id + rs_per_node_buf = ctx.rs_per_node_buf + p2p_buf = ctx.p2p_buf + + print(f"reduce_scatter_for_each_node: M={M}, N={N}, M_per_rank={M_per_rank}, M_per_node={M_per_node}, nnodes={nnodes}, node_id={node_id}, local_rank={local_rank}") + + stream = torch.cuda.current_stream() + for n in range(0, nnodes): + cur_node_id = (node_id + n + 1) % nnodes + input_intra_node = input[cur_node_id * M_per_node:(cur_node_id + 1) * M_per_node] + scatter_bufs_intra_node, scatter_signal_buf_intra_node = ctx.get_scatter_bufs_and_signal_for_each_node( + input, cur_node_id) + intra_node_scatter(input_intra_node, scatter_bufs_intra_node, scatter_signal_buf_intra_node, local_rank, + overlap_with_gemm=ctx.overlap_with_gemm) + + # ring reduce intra node + rs_buf_cur_node = rs_per_node_buf[M_per_rank * cur_node_id:(cur_node_id + 1) * M_per_rank] + # nvshmem_barrier_all_on_stream(stream) + reduction_stream.wait_stream(stream) + with torch.cuda.stream(reduction_stream): + reduce_out_buf = output if nnodes == 1 else rs_buf_cur_node + ring_reduce(scatter_bufs_intra_node[local_rank], reduce_out_buf, local_rank, local_world_size, + num_sms=-1 if n == nnodes - 1 else num_reduction_sms) + + # inter node p2p + if nnodes > 1: + if n == nnodes - 1: + p2p_buf[M_per_rank * node_id:M_per_rank * (node_id + 1)].copy_( + rs_per_node_buf[M_per_rank * node_id:M_per_rank * (node_id + 1)]) + else: + grid = lambda META: (ctx.num_p2p_sms, ) + kernel_inter_node_p2p_for_same_local_rank[grid]( + n, + local_world_size, + M_per_rank, + N, + rs_per_node_buf, + p2p_buf, + num_warps=16, + ) + + stream.wait_stream(reduction_stream) + if nnodes == 1: + return output + return p2p_buf[:M_per_rank * nnodes] + + +def reduce_scatter_multi_node(input: torch.Tensor, ctx: ReduceScatter2DContext, output: Optional[torch.Tensor] = None): + """ + A hierarchical reduce-scatter implementation that overlaps the intra-node scatter + with the local reduce and the inter-node p2p(after reduce). It also provides a rank-wise + signal and supports overlap with gemm. + """ + M, N = input.shape + M_per_rank = M // ctx.world_size + + current_stream = torch.cuda.current_stream() + ctx.reduction_stream.wait_stream(current_stream) + + # directly reduce_scatter to output if nnodes == 1 + out_each_node = output if ctx.nnodes == 1 else None + if not has_fullmesh_nvlink(): + raise Exception("Only support fullmesh nvlink topology for now.") + else: + print("Using fullmesh nvlink reduce_scatter.") + rs_result_per_node = reduce_scatter_for_each_node(input, ctx, out_each_node) + + if ctx.nnodes == 1: + return rs_result_per_node + + # nvshmem_barrier_all_on_stream(current_stream) + if output is None: + output = torch.empty((M_per_rank, N), dtype=input.dtype, device=input.device) + ring_reduce(rs_result_per_node, output, ctx.node_id, ctx.nnodes) + return output + +def reduce_scatter_2d_op(input: torch.Tensor, ctx: ReduceScatter2DContext, output: Optional[torch.Tensor] = None): + M, N = input.shape + assert input.dtype == ctx.dtype + assert ctx.max_M >= M and ctx.N == N + assert M % ctx.world_size == 0 + + # nvshmem_barrier_all_on_stream(torch.cuda.current_stream()) + output = reduce_scatter_multi_node(input, ctx, output) + ctx.reset_barriers() + return output diff --git a/src/op/builtin.cc b/src/op/builtin.cc index 6241613fe5..646b300e5b 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -312,5 +312,14 @@ TIR_DEFINE_TL_BUILTIN(atomicadd_elem_op) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(atom_add) + .set_num_inputs(4) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(st) + .set_num_inputs(4) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); } // namespace tl } // namespace tvm diff --git a/src/op/builtin.h b/src/op/builtin.h index e46b6da13e..85a95f11a9 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -537,6 +537,22 @@ TVM_DLL const Op &increase_descriptor_offset(); */ TVM_DLL const Op &atomicadd_elem_op(); +/*! + * \brief tilelang intrinsic for atomic add that returns the original value. + * + * This op is used to represent an atomic add operation that returns the + * original value before addition in tilelang. + */ +TVM_DLL const Op &atom_add(); + +/*! + * \brief tilelang intrinsic for atomic store with semantic. + * + * This op is used to represent an atomic store operation with semantic in + * tilelang. + */ +TVM_DLL const Op &st(); + } // namespace tl } // namespace tvm diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 2aece8f15b..d942e8ab8b 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -1516,6 +1516,17 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { this->PrintIndent(); this->stream << "tl::wait_eq(" << this->PrintExpr(op->args[0]) << ", " << this->PrintExpr(op->args[1]) << ");\n"; + } else if (op->op.same_as(tl::atom_add())) { + std::string func_name = "tl::ptx_atom_add_" + op->args[2].as()->value + + "_" + op->args[3].as()->value; + os << func_name << "(" << this->PrintExpr(op->args[0]) << ", " + << this->PrintExpr(op->args[1]) << ")"; + } else if (op->op.same_as(tl::st())) { + this->PrintIndent(); + std::string func_name = "tl::st_" + op->args[2].as()->value + + "_" + op->args[3].as()->value; + this->stream << func_name << "(" << this->PrintExpr(op->args[0]) << ", " + << this->PrintExpr(op->args[1]) << ");\n"; } else if (op->op.same_as(tl::get_clock())) { os << "get_clock()"; } else if (op->op.same_as(tl::loop_break())) { diff --git a/src/tl_templates/cuda/atomic.h b/src/tl_templates/cuda/atomic.h index 4ee85a1ad9..97f0608f69 100644 --- a/src/tl_templates/cuda/atomic.h +++ b/src/tl_templates/cuda/atomic.h @@ -202,3 +202,70 @@ TL_DEVICE void AtomicStore(T1 &ref, T2 value, int memory_order) { cuda::atomic_ref aref(ref); aref.store(cuda_cast(value), cuda::memory_order(memory_order)); } + +namespace tl { + +TL_DEVICE uint32_t ptx_atom_add_relaxed_gpu(const uint32_t *ptr, uint32_t value) { + uint32_t ret; + asm volatile("atom.add.relaxed.gpu.global.u32 %0, [%1], %2;\n" + : "=r"(ret) + : "l"(ptr), "r"(value)); + return ret; +} + +TL_DEVICE uint32_t ptx_atom_add_acquire_gpu(const uint32_t *ptr, uint32_t value) { + uint32_t ret; + asm volatile("atom.add.acquire.gpu.global.u32 %0, [%1], %2;\n" + : "=r"(ret) + : "l"(ptr), "r"(value)); + return ret; +} + +TL_DEVICE uint32_t ptx_atom_add_release_gpu(const uint32_t *ptr, uint32_t value) { + uint32_t ret; + asm volatile("atom.add.release.gpu.global.u32 %0, [%1], %2;\n" + : "=r"(ret) + : "l"(ptr), "r"(value)); + return ret; +} + +TL_DEVICE uint32_t ptx_atom_add_acq_rel_gpu(const uint32_t *ptr, uint32_t value) { + uint32_t ret; + asm volatile("atom.add.acq_rel.gpu.global.u32 %0, [%1], %2;\n" + : "=r"(ret) + : "l"(ptr), "r"(value)); + return ret; +} + +TL_DEVICE uint32_t ptx_atom_add_relaxed_sys(const uint32_t *ptr, uint32_t value) { + uint32_t ret; + asm volatile("atom.add.relaxed.sys.global.u32 %0, [%1], %2;\n" + : "=r"(ret) + : "l"(ptr), "r"(value)); + return ret; +} + +TL_DEVICE uint32_t ptx_atom_add_acquire_sys(const uint32_t *ptr, uint32_t value) { + uint32_t ret; + asm volatile("atom.add.acquire.sys.global.u32 %0, [%1], %2;\n" + : "=r"(ret) + : "l"(ptr), "r"(value)); + return ret; +} + +TL_DEVICE uint32_t ptx_atom_add_release_sys(const uint32_t *ptr, uint32_t value) { + uint32_t ret; + asm volatile("atom.add.release.sys.global.u32 %0, [%1], %2;\n" + : "=r"(ret) + : "l"(ptr), "r"(value)); + return ret; +} + +TL_DEVICE uint32_t ptx_atom_add_acq_rel_sys(const uint32_t *ptr, uint32_t value) { + uint32_t ret; + asm volatile("atom.add.acq_rel.sys.global.u32 %0, [%1], %2;\n" + : "=r"(ret) + : "l"(ptr), "r"(value)); + return ret; +} +} // namespace tl \ No newline at end of file diff --git a/src/tl_templates/cuda/sync.h b/src/tl_templates/cuda/sync.h index 6a632b1bd1..9cf86fb9ad 100644 --- a/src/tl_templates/cuda/sync.h +++ b/src/tl_templates/cuda/sync.h @@ -186,4 +186,32 @@ template TL_DEVICE void wait_eq(void *barrier, T val = 1) { } } +TL_DEVICE void st_release_gpu(uint32_t *ptr, uint32_t value) { + asm volatile("st.release.gpu.global.b32 [%0], %1;" + : + : "l"(ptr), "r"(value) + : "memory"); +} + +TL_DEVICE void st_relaxed_gpu(uint32_t *ptr, uint32_t value) { + asm volatile("st.relaxed.gpu.global.b32 [%0], %1;" + : + : "l"(ptr), "r"(value) + : "memory"); +} + +TL_DEVICE void st_release_sys(uint32_t *ptr, uint32_t value) { + asm volatile("st.release.sys.global.b32 [%0], %1;" + : + : "l"(ptr), "r"(value) + : "memory"); +} + +TL_DEVICE void st_relaxed_sys(uint32_t *ptr, uint32_t value) { + asm volatile("st.relaxed.sys.global.b32 [%0], %1;" + : + : "l"(ptr), "r"(value) + : "memory"); +} + } // namespace tl diff --git a/tilelang/distributed/utils.py b/tilelang/distributed/utils.py index 152af7c425..1ac69f722f 100644 --- a/tilelang/distributed/utils.py +++ b/tilelang/distributed/utils.py @@ -20,7 +20,11 @@ import ctypes from tilescale_ext import _create_tensor, _create_ipc_handle, _sync_ipc_handles +import functools from functools import lru_cache +from threading import Lock +import subprocess +import warnings dtype_map = { "bfloat16": torch.bfloat16, @@ -293,3 +297,100 @@ def cuda_stream_max_priority(): ret = cudart.cudaDeviceGetStreamPriorityRange() CUDA_CHECK(ret[0]) return ret[2] # (leastPriority, greatestPriority) -> greatestPriority is max priority + +def CUDA_CHECK(err): + if isinstance(err, cuda.CUresult): + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError(f"Cuda Error: {err}: {cuda.cuGetErrorName(err)}") + elif isinstance(err, cudart.cudaError_t): + if err != cudart.cudaError_t.cudaSuccess: + raise RuntimeError(f"Cuda Error: {err}: {cudart.cudaGetErrorString(err)}") + else: + raise RuntimeError(f"Unknown error type: {err}") + +_pynvml_initialized = False +_lock = Lock() + +def ensure_nvml_initialized(): + global _pynvml_initialized + if not _pynvml_initialized: + with _lock: + if not _pynvml_initialized: + import pynvml + + pynvml.nvmlInit() + _pynvml_initialized = True + +@functools.lru_cache() +def has_fullmesh_nvlink_pynvml(): + num_devices = torch.cuda.device_count() + + ensure_nvml_initialized() + import pynvml + + try: + handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in range(num_devices)] + for cur_device in range(num_devices): + cur_handle = handles[cur_device] + for remote_device in range(num_devices): + if remote_device == cur_device: + continue + remote_handle = handles[remote_device] + p2p_status = pynvml.nvmlDeviceGetP2PStatus(cur_handle, remote_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK) + if p2p_status != pynvml.NVML_P2P_STATUS_OK: + return False + return True + except pynvml.NVMLError_NotSupported: + return False + + +class NvidiaSmiUtil: + + @staticmethod + def get_nvlink_adjacency_matrix(): + output = subprocess.check_output(["nvidia-smi", "topo", "-m"], text=True) + lines = [line.strip() for line in output.split("\n") if line.startswith("GPU")] + + device_count = len(lines) + matrix = [[-1 for _ in range(device_count)] for _ in range(device_count)] + + # 解析每行数据 + for i, line in enumerate(lines): + parts = line.split() + for j in range(1, len(parts)): + if "NV" in parts[j]: + matrix[i][j - 1] = 1 + + return matrix + + @staticmethod + def get_gpu_numa_node(gpu_index=0): + try: + cmd = f"nvidia-smi --query-gpu=pci.bus_id --format=csv,noheader,nounits -i {gpu_index}" + pci_id = subprocess.check_output(cmd, shell=True).decode().strip() + pci_address = pci_id.replace("00000000:", "").lower() + + numa_node_path = f"/sys/bus/pci/devices/0000:{pci_address}/numa_node" + with open(numa_node_path, "r") as f: + numa_node = int(f.read().strip()) + + assert numa_node >= 0 + return numa_node if numa_node >= 0 else 0 + + except Exception as e: + print(f"Error: {e}") + return -1 + +@functools.lru_cache() +def has_fullmesh_nvlink(): + try: + return has_fullmesh_nvlink_pynvml() + except Exception: + nvlink_matrix = NvidiaSmiUtil.get_nvlink_adjacency_matrix() + has_nvlink = any([any(x == 1 for x in row) for row in nvlink_matrix]) + _has_fullmesh_nvlink = all([i == j or v == 1 for i, row in enumerate(nvlink_matrix) for j, v in enumerate(row)]) + if has_nvlink and not _has_fullmesh_nvlink: + warnings.warn( + "⚠️ found NVLink but not fullmesh NVLink, this may cause undefined behavior, please check your GPU topology" + ) + return _has_fullmesh_nvlink \ No newline at end of file diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index 10d8580562..11ed7893a5 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -725,3 +725,28 @@ def cp_async_barrier_noinc(barrier_id: int | PrimExpr | tir.Call): """Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc. """ return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id) + + +def atom_add(barrier: PrimExpr, value: PrimExpr, scope: str = "gpu", sem: str = "relaxed"): + """Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc. + """ + assert scope in ["gpu", "sys"], "Scope must be one of 'gpu', or 'sys'." + assert sem in ["relaxed", "acquire", "release", "acq_rel"], "Semantic must be one of 'relaxed', 'acquire', 'release', or 'acq_rel'." + return tir.call_intrin("uint32", tir.op.Op.get("tl.atom_add"), address_of(barrier), value, sem, scope) + + +def st(barrier: PrimExpr, value: PrimExpr, scope: str = "gpu", sem: str = "relaxed"): + """Store a value to a given address with specified scope and semantic. + + Args: + address: The address to store the value to + value: The value to store + scope: The memory scope (default is "gpu") + semantic: The memory semantic (default is "relaxed") + + Returns: + tir.Call: A handle to the store operation + """ + assert scope in ["gpu", "sys"], "Scope must be one of 'gpu', or 'sys'." + assert sem in ["relaxed", "release"], "Semantic must be one of 'relaxed', or 'release'." + return tir.call_intrin("handle", tir.op.Op.get("tl.st"), address_of(barrier), value, sem, scope) \ No newline at end of file From 36290b1a7c7610a01a94a164ebec662a0d031d65 Mon Sep 17 00:00:00 2001 From: Yu Cheng Date: Thu, 30 Oct 2025 15:11:26 +0800 Subject: [PATCH 3/4] lint --- .../distributed/example_gemm_rs_overlapped.py | 63 ++--- examples/distributed/reduce_scatter.py | 250 ++++++++++-------- src/op/builtin.cc | 12 +- src/target/codegen_cuda.cc | 5 +- src/tl_templates/cuda/atomic.h | 26 +- src/tl_templates/cuda/sync.h | 24 +- tilelang/distributed/utils.py | 33 ++- tilelang/language/builtin.py | 8 +- 8 files changed, 234 insertions(+), 187 deletions(-) diff --git a/examples/distributed/example_gemm_rs_overlapped.py b/examples/distributed/example_gemm_rs_overlapped.py index 87abff57b9..27519fdd51 100644 --- a/examples/distributed/example_gemm_rs_overlapped.py +++ b/examples/distributed/example_gemm_rs_overlapped.py @@ -7,15 +7,6 @@ import torch.distributed as dist import torch.multiprocessing from tilelang.distributed import init_dist -from tilelang.carver.arch import driver -import importlib.metadata - -cuda_python_version = importlib.metadata.version("cuda-python") -from packaging import version -if version.parse(cuda_python_version) >= version.parse("12.8.0"): - from cuda.bindings import driver as cuda -else: - from cuda import cuda from tilelang.distributed import perf_fn from reduce_scatter import reduce_scatter_2d_op, create_reduce_scater_2d_ctx @@ -47,9 +38,7 @@ def main( counter_signal_buf: T.Tensor((num_local_rank), "uint32"), C: T.Tensor((M, N), dtype), ): - with T.Kernel( - T.ceildiv(M, block_M) * T.ceildiv(N, block_N), - threads=threads) as (bid): + with T.Kernel(T.ceildiv(M, block_M) * T.ceildiv(N, block_N), threads=threads) as (bid): A_shared = T.alloc_shared((block_M, block_K), dtype) B_shared = T.alloc_shared((block_K, block_N), dtype) C_shared = T.alloc_shared((block_M, block_N), dtype) @@ -80,7 +69,7 @@ def main( T.gemm(A_shared, B_shared, C_local) T.copy(C_local, C_shared) T.copy(C_shared, C[pid_m * block_M, pid_n * block_N]) - + # inc barrier segment_start = pid_m * block_M // M_per_rank segment_end = (T.min((pid_m + 1) * block_M, M) - 1) // M_per_rank @@ -98,22 +87,37 @@ def main( return main -def gemm_rs_op(A, B, C, output, ctx, gemm_kernel, gemm_stream, rs_stream, local_rank, print_source=False): - +def gemm_rs_op(A, + B, + C, + output, + ctx, + gemm_kernel, + gemm_stream, + rs_stream, + local_rank, + print_source=False): + current_stream = torch.cuda.current_stream() rs_stream.wait_stream(gemm_stream) - - gemm_kernel(A, B, ctx.scatter_signal_bufs[local_rank], ctx.counter_bufs[local_rank], C, stream=gemm_stream.cuda_stream) - + + gemm_kernel( + A, + B, + ctx.scatter_signal_bufs[local_rank], + ctx.counter_bufs[local_rank], + C, + stream=gemm_stream.cuda_stream) + if print_source and local_rank == 1: print(gemm_kernel.get_kernel_source()) - + with torch.cuda.stream(rs_stream): # don't allocate memory on other stream: error-prune output = reduce_scatter_2d_op(C, ctx, output) gemm_stream.wait_stream(rs_stream) current_stream.wait_stream(rs_stream) - + return output @@ -169,19 +173,19 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): rs_stream = torch.cuda.Stream(priority=-1) ctx = create_reduce_scater_2d_ctx( M, - N, - local_rank, - num_local_ranks, - num_local_ranks, - dtype, + N, + local_rank, + num_local_ranks, + num_local_ranks, + dtype, allocator, overlap_with_gemm=True, - num_reduction_sms=15 - ) + num_reduction_sms=15) dist.barrier() - tilelang_out = gemm_rs_op(A, B, C, output, ctx, gemm_func, gemm_stream, rs_stream, local_rank, print_source=True) + tilelang_out = gemm_rs_op( + A, B, C, output, ctx, gemm_func, gemm_stream, rs_stream, local_rank, print_source=True) torch_out = torch_gemm_rs(group, A, B, None, num_local_ranks) atol = 1e-2 @@ -193,8 +197,7 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): print(f"torch_out: {torch_out}, tilelang_out: {tilelang_out}") _, tl_t = perf_fn( - lambda: - gemm_rs_op(A, B, C, output, ctx, gemm_func, gemm_stream, rs_stream, local_rank), + lambda: gemm_rs_op(A, B, C, output, ctx, gemm_func, gemm_stream, rs_stream, local_rank), warmup=5, rep=5) diff --git a/examples/distributed/reduce_scatter.py b/examples/distributed/reduce_scatter.py index 5b7270d178..7527e5c71c 100644 --- a/examples/distributed/reduce_scatter.py +++ b/examples/distributed/reduce_scatter.py @@ -4,21 +4,21 @@ import torch import importlib.metadata from packaging import version + cuda_python_version = importlib.metadata.version("cuda-python") if version.parse(cuda_python_version) >= version.parse("12.8.0"): from cuda.bindings import driver as cuda - from cuda.bindings import runtime as cudart else: - from cuda import cuda, cudart + from cuda import cuda import tilelang from tilelang.distributed.utils import CUDA_CHECK, has_fullmesh_nvlink from tilelang.utils.target import target_is_hopper, determine_target import torch.distributed as dist -import tilelang import tilelang.language as T tilelang.disable_cache() + @dataclasses.dataclass class ReduceScatter2DContext: max_M: int @@ -36,7 +36,7 @@ class ReduceScatter2DContext: # barrier bufs signal_bufs: List[torch.Tensor] # need reset: signal_buf = scatter_signal | rs_per_node_signal - + counter_bufs: List[torch.Tensor] # intra-node barrier @@ -73,13 +73,16 @@ def __post_init__(self): assert buf.shape[0] >= 2 * self.world_size self.scatter_signal_bufs = [buf[:self.world_size] for buf in self.signal_bufs] - self.rs_per_node_signal_bufs = [buf[self.world_size:self.world_size * 2] for buf in self.signal_bufs] + self.rs_per_node_signal_bufs = [ + buf[self.world_size:self.world_size * 2] for buf in self.signal_bufs + ] for node_id in range(self.nnodes): self.scatter_signal_buf_list_for_each_node.append( - self.scatter_signal_bufs[self.local_rank][node_id * self.local_world_size:(node_id + 1) * + self.scatter_signal_bufs[self.local_rank][node_id * + self.local_world_size:(node_id + 1) * self.local_world_size]) - + def reset_barriers(self): self.signal_bufs[self.local_rank].fill_(0) self.counter_bufs[self.local_rank].fill_(0) @@ -90,7 +93,9 @@ def get_scatter_bufs_and_signal_for_each_node(self, input, node_id): M_per_node = M_per_rank * self.local_world_size M_start = node_id * M_per_node M_end = M_start + M_per_node - scatter_bufs_intra_node = [self.scatter_bufs[i][M_start:M_end] for i in range(self.local_world_size)] + scatter_bufs_intra_node = [ + self.scatter_bufs[i][M_start:M_end] for i in range(self.local_world_size) + ] return scatter_bufs_intra_node, self.scatter_signal_buf_list_for_each_node[node_id] @property @@ -118,7 +123,14 @@ def scatter_signal_buf(self) -> torch.Tensor: return self.scatter_signal_bufs[self.local_rank] -def create_reduce_scater_2d_ctx(max_M, N, rank, world_size, local_world_size, dtype, allocator, overlap_with_gemm=True, +def create_reduce_scater_2d_ctx(max_M, + N, + rank, + world_size, + local_world_size, + dtype, + allocator, + overlap_with_gemm=True, num_reduction_sms=15) -> ReduceScatter2DContext: """ for num_reduction_sms: tunable param, 16 are enough for H800 @@ -129,17 +141,32 @@ def create_reduce_scater_2d_ctx(max_M, N, rank, world_size, local_world_size, dt assert world_size % local_world_size == 0 assert max_M % world_size == 0 - scatter_bufs = tilelang.tensor((max_M, N), dtype, allocator=allocator, return_peers=True) - rs_per_node_bufs = tilelang.tensor((max_M // local_world_size, N), dtype, allocator=allocator, return_peers=True) - p2p_bufs = tilelang.tensor((max_M // local_world_size, N), dtype, allocator=allocator, return_peers=True) + scatter_bufs = tilelang.tensor((max_M, N), dtype, allocator=allocator, return_peers=True) + rs_per_node_bufs = tilelang.tensor((max_M // local_world_size, N), + dtype, + allocator=allocator, + return_peers=True) + p2p_bufs = tilelang.tensor((max_M // local_world_size, N), + dtype, + allocator=allocator, + return_peers=True) # signal_buf: scatter_signal | rs_per_node_signal num_signal_bufs = 2 - signal_bufs = tilelang.tensor((world_size * num_signal_bufs), dtype=torch.uint32, allocator=allocator, return_peers=True) - symm_barriers = tilelang.tensor((local_world_size, ), torch.int32, allocator=allocator, return_peers=True) + signal_bufs = tilelang.tensor((world_size * num_signal_bufs), + dtype=torch.uint32, + allocator=allocator, + return_peers=True) + symm_barriers = tilelang.tensor((local_world_size,), + torch.int32, + allocator=allocator, + return_peers=True) symm_barriers[rank] = 0 - - counter_signal_buf = tilelang.tensor((local_world_size), dtype=torch.uint32, allocator=allocator, return_peers=True) + + counter_signal_buf = tilelang.tensor((local_world_size), + dtype=torch.uint32, + allocator=allocator, + return_peers=True) dist.barrier() @@ -147,41 +174,51 @@ def create_reduce_scater_2d_ctx(max_M, N, rank, world_size, local_world_size, dt num_sync_sms = 0 num_p2p_sms = 1 - ctx = ReduceScatter2DContext(max_M=max_M, N=N, rank=rank, world_size=world_size, local_world_size=local_world_size, - dtype=dtype, overlap_with_gemm=overlap_with_gemm, scatter_bufs=scatter_bufs, - rs_per_node_bufs=rs_per_node_bufs, p2p_bufs=p2p_bufs, signal_bufs=signal_bufs, - counter_bufs=counter_signal_buf, - barrier=symm_barriers, reduction_stream=reduction_stream, - num_sync_sms=num_sync_sms, num_p2p_sms=num_p2p_sms, - num_reduction_sms=num_reduction_sms) + ctx = ReduceScatter2DContext( + max_M=max_M, + N=N, + rank=rank, + world_size=world_size, + local_world_size=local_world_size, + dtype=dtype, + overlap_with_gemm=overlap_with_gemm, + scatter_bufs=scatter_bufs, + rs_per_node_bufs=rs_per_node_bufs, + p2p_bufs=p2p_bufs, + signal_bufs=signal_bufs, + counter_bufs=counter_signal_buf, + barrier=symm_barriers, + reduction_stream=reduction_stream, + num_sync_sms=num_sync_sms, + num_p2p_sms=num_p2p_sms, + num_reduction_sms=num_reduction_sms) return ctx @tilelang.jit -def kernel_ring_reduce_tma( - M_per_rank, - N, - block_M, - block_N, - begin_idx, - num_splits, - threads, - persistent=False, - dtype="float16", - accum_dtype="float"): +def kernel_ring_reduce_tma(M_per_rank, + N, + block_M, + block_N, + begin_idx, + num_splits, + threads, + persistent=False, + dtype="float16", + accum_dtype="float"): @T.prim_func def _kernel_ring_reduce_tma( - C: T.Tensor((M_per_rank * num_splits, N), dtype), - output: T.Tensor((M_per_rank, N), dtype), + C: T.Tensor((M_per_rank * num_splits, N), dtype), + output: T.Tensor((M_per_rank, N), dtype), ): - with T.Kernel(T.ceildiv(M_per_rank, block_M), T.ceildiv(N, block_N), - threads=threads) as (bx, by): + with T.Kernel( + T.ceildiv(M_per_rank, block_M), T.ceildiv(N, block_N), threads=threads) as (bx, by): data_shared = T.alloc_shared((block_M, block_N), dtype) init_shared = T.alloc_shared((block_M, block_N), dtype) data_local = T.alloc_fragment((block_M, block_N), dtype) accum = T.alloc_fragment((block_M, block_N), accum_dtype) - + cur_rank = (begin_idx + 1) % num_splits T.copy(C[cur_rank * M_per_rank + bx * block_M, by * block_N], init_shared) T.copy(init_shared, accum) @@ -192,15 +229,17 @@ def _kernel_ring_reduce_tma( for i, j in T.Parallel(block_M, block_N): accum[i, j] += data_local[i, j] T.copy(accum, output[bx * block_M, by * block_N]) - + return _kernel_ring_reduce_tma - -def _wait_eq_cuda(signal_tensor: torch.Tensor, signal: int, stream: Optional[torch.cuda.Stream] = None, + +def _wait_eq_cuda(signal_tensor: torch.Tensor, + signal: int, + stream: Optional[torch.cuda.Stream] = None, require_i64=False): stream = stream or torch.cuda.current_stream() if signal_tensor.dtype in (torch.int32, torch.uint32): - (err, ) = cuda.cuStreamWaitValue32( + (err,) = cuda.cuStreamWaitValue32( stream.cuda_stream, signal_tensor.data_ptr(), signal, @@ -208,7 +247,7 @@ def _wait_eq_cuda(signal_tensor: torch.Tensor, signal: int, stream: Optional[tor ) CUDA_CHECK(err) elif signal_tensor.dtype in (torch.int64, torch.uint64): - (err, ) = cuda.cuStreamWaitValue64( + (err,) = cuda.cuStreamWaitValue64( stream.cuda_stream, signal_tensor.data_ptr(), signal, @@ -217,9 +256,13 @@ def _wait_eq_cuda(signal_tensor: torch.Tensor, signal: int, stream: Optional[tor CUDA_CHECK(err) else: raise Exception(f"Unsupported signal dtype {signal_tensor.dtype}") - -def intra_node_scatter(input_intra_node, scatter_bufs_intra_node: List[torch.Tensor], - scatter_signal_buf_intra_node: torch.Tensor, local_rank, overlap_with_gemm=True): + + +def intra_node_scatter(input_intra_node, + scatter_bufs_intra_node: List[torch.Tensor], + scatter_signal_buf_intra_node: torch.Tensor, + local_rank, + overlap_with_gemm=True): M, N = input_intra_node.shape local_world_size = len(scatter_bufs_intra_node) M_per_rank = M // local_world_size @@ -228,40 +271,16 @@ def intra_node_scatter(input_intra_node, scatter_bufs_intra_node: List[torch.Ten for i in range(local_world_size): remote_local_rank = (local_rank + i + 1) % local_world_size - + # print(f"scatter_signal_buf_intra_node[remote_local_rank]: {scatter_signal_buf_intra_node[remote_local_rank]}") if overlap_with_gemm: _wait_eq_cuda(scatter_signal_buf_intra_node[remote_local_rank], 1, stream) - src = input_intra_node[ - remote_local_rank * M_per_rank : (remote_local_rank + 1) * M_per_rank, : - ] - dst = scatter_bufs_intra_node[remote_local_rank][ - local_rank * M_per_rank : (local_rank + 1) * M_per_rank, : - ] + src = input_intra_node[remote_local_rank * M_per_rank:(remote_local_rank + 1) * + M_per_rank, :] + dst = scatter_bufs_intra_node[remote_local_rank][local_rank * M_per_rank:(local_rank + 1) * + M_per_rank, :] with torch.cuda.stream(stream): dst.copy_(src) - -def ring_reduce_non_tma( - input: torch.Tensor, # [M_per_node, N] - output: torch.Tensor, # [M_per_rank, N] - begin_idx, - num_splits, - num_sms=16, -): - total_M, N = input.shape - M_per_split = total_M // num_splits - assert output.shape[0] == M_per_split and total_M % num_splits == 0, output.shape - num_warps = 32 - kernel_ring_reduce_non_tma[(num_sms, )]( - input, - output, - M_per_split * N, - begin_idx, - num_splits, - BLOCK_SIZE=32 * num_warps * 8 * 4, - num_warps=num_warps, - ) - return output def ring_reduce_tma( @@ -273,13 +292,14 @@ def ring_reduce_tma( ): total_M, N = input.shape M_per_split = total_M // num_splits - assert output.shape[0] == M_per_split and total_M % num_splits == 0, f"{output.shape}, {total_M}, {num_splits}" + assert output.shape[ + 0] == M_per_split and total_M % num_splits == 0, f"{output.shape}, {total_M}, {num_splits}" def alloc_fn(size, alignment, stream): return torch.empty(size, device="cuda", dtype=torch.int8) if num_sms == -1: - ring_reduce_tma_func = kernel_ring_reduce_tma( + ring_reduce_tma_func = kernel_ring_reduce_tma( M_per_split, N, block_M=64, @@ -308,8 +328,10 @@ def alloc_fn(size, alignment, stream): return output + target = determine_target(return_object=True) - + + def ring_reduce( input, # [M_per_node, N] output, # [M_per_rank, N] @@ -320,9 +342,11 @@ def ring_reduce( if target_is_hopper(target): return ring_reduce_tma(input, output, begin_idx, num_splits, num_sms) else: - return ring_reduce_non_tma(input, output, begin_idx, num_splits, 16 if num_sms == -1 else num_sms) - -def reduce_scatter_for_each_node(input: torch.Tensor, ctx: ReduceScatter2DContext, + raise NotImplementedError("Only Hopper ring reduce is implemented now.") + + +def reduce_scatter_for_each_node(input: torch.Tensor, + ctx: ReduceScatter2DContext, output: Optional[torch.Tensor] = None): world_size = ctx.world_size local_world_size = ctx.local_world_size @@ -336,8 +360,10 @@ def reduce_scatter_for_each_node(input: torch.Tensor, ctx: ReduceScatter2DContex node_id = ctx.node_id rs_per_node_buf = ctx.rs_per_node_buf p2p_buf = ctx.p2p_buf - - print(f"reduce_scatter_for_each_node: M={M}, N={N}, M_per_rank={M_per_rank}, M_per_node={M_per_node}, nnodes={nnodes}, node_id={node_id}, local_rank={local_rank}") + + print( + f"reduce_scatter_for_each_node: M={M}, N={N}, M_per_rank={M_per_rank}, M_per_node={M_per_node}, nnodes={nnodes}, node_id={node_id}, local_rank={local_rank}" + ) stream = torch.cuda.current_stream() for n in range(0, nnodes): @@ -345,8 +371,12 @@ def reduce_scatter_for_each_node(input: torch.Tensor, ctx: ReduceScatter2DContex input_intra_node = input[cur_node_id * M_per_node:(cur_node_id + 1) * M_per_node] scatter_bufs_intra_node, scatter_signal_buf_intra_node = ctx.get_scatter_bufs_and_signal_for_each_node( input, cur_node_id) - intra_node_scatter(input_intra_node, scatter_bufs_intra_node, scatter_signal_buf_intra_node, local_rank, - overlap_with_gemm=ctx.overlap_with_gemm) + intra_node_scatter( + input_intra_node, + scatter_bufs_intra_node, + scatter_signal_buf_intra_node, + local_rank, + overlap_with_gemm=ctx.overlap_with_gemm) # ring reduce intra node rs_buf_cur_node = rs_per_node_buf[M_per_rank * cur_node_id:(cur_node_id + 1) * M_per_rank] @@ -354,25 +384,30 @@ def reduce_scatter_for_each_node(input: torch.Tensor, ctx: ReduceScatter2DContex reduction_stream.wait_stream(stream) with torch.cuda.stream(reduction_stream): reduce_out_buf = output if nnodes == 1 else rs_buf_cur_node - ring_reduce(scatter_bufs_intra_node[local_rank], reduce_out_buf, local_rank, local_world_size, - num_sms=-1 if n == nnodes - 1 else num_reduction_sms) - + ring_reduce( + scatter_bufs_intra_node[local_rank], + reduce_out_buf, + local_rank, + local_world_size, + num_sms=-1 if n == nnodes - 1 else num_reduction_sms) + # inter node p2p if nnodes > 1: - if n == nnodes - 1: - p2p_buf[M_per_rank * node_id:M_per_rank * (node_id + 1)].copy_( - rs_per_node_buf[M_per_rank * node_id:M_per_rank * (node_id + 1)]) - else: - grid = lambda META: (ctx.num_p2p_sms, ) - kernel_inter_node_p2p_for_same_local_rank[grid]( - n, - local_world_size, - M_per_rank, - N, - rs_per_node_buf, - p2p_buf, - num_warps=16, - ) + raise NotImplementedError("Inter-node p2p is not implemented yet.") + # if n == nnodes - 1: + # p2p_buf[M_per_rank * node_id:M_per_rank * (node_id + 1)].copy_( + # rs_per_node_buf[M_per_rank * node_id:M_per_rank * (node_id + 1)]) + # else: + # grid = lambda META: (ctx.num_p2p_sms,) + # kernel_inter_node_p2p_for_same_local_rank[grid]( + # n, + # local_world_size, + # M_per_rank, + # N, + # rs_per_node_buf, + # p2p_buf, + # num_warps=16, + # ) stream.wait_stream(reduction_stream) if nnodes == 1: @@ -380,7 +415,9 @@ def reduce_scatter_for_each_node(input: torch.Tensor, ctx: ReduceScatter2DContex return p2p_buf[:M_per_rank * nnodes] -def reduce_scatter_multi_node(input: torch.Tensor, ctx: ReduceScatter2DContext, output: Optional[torch.Tensor] = None): +def reduce_scatter_multi_node(input: torch.Tensor, + ctx: ReduceScatter2DContext, + output: Optional[torch.Tensor] = None): """ A hierarchical reduce-scatter implementation that overlaps the intra-node scatter with the local reduce and the inter-node p2p(after reduce). It also provides a rank-wise @@ -409,7 +446,10 @@ def reduce_scatter_multi_node(input: torch.Tensor, ctx: ReduceScatter2DContext, ring_reduce(rs_result_per_node, output, ctx.node_id, ctx.nnodes) return output -def reduce_scatter_2d_op(input: torch.Tensor, ctx: ReduceScatter2DContext, output: Optional[torch.Tensor] = None): + +def reduce_scatter_2d_op(input: torch.Tensor, + ctx: ReduceScatter2DContext, + output: Optional[torch.Tensor] = None): M, N = input.shape assert input.dtype == ctx.dtype assert ctx.max_M >= M and ctx.N == N diff --git a/src/op/builtin.cc b/src/op/builtin.cc index 646b300e5b..a3c8a024e2 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -312,14 +312,10 @@ TIR_DEFINE_TL_BUILTIN(atomicadd_elem_op) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_TL_BUILTIN(atom_add) - .set_num_inputs(4) - .set_attr("TCallEffectKind", - Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(atom_add).set_num_inputs(4).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_TL_BUILTIN(st) - .set_num_inputs(4) - .set_attr("TCallEffectKind", - Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(st).set_num_inputs(4).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kOpaque)); } // namespace tl } // namespace tvm diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index d942e8ab8b..8c38f3c719 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -1517,8 +1517,9 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { this->stream << "tl::wait_eq(" << this->PrintExpr(op->args[0]) << ", " << this->PrintExpr(op->args[1]) << ");\n"; } else if (op->op.same_as(tl::atom_add())) { - std::string func_name = "tl::ptx_atom_add_" + op->args[2].as()->value + - "_" + op->args[3].as()->value; + std::string func_name = "tl::ptx_atom_add_" + + op->args[2].as()->value + "_" + + op->args[3].as()->value; os << func_name << "(" << this->PrintExpr(op->args[0]) << ", " << this->PrintExpr(op->args[1]) << ")"; } else if (op->op.same_as(tl::st())) { diff --git a/src/tl_templates/cuda/atomic.h b/src/tl_templates/cuda/atomic.h index 97f0608f69..fe46070201 100644 --- a/src/tl_templates/cuda/atomic.h +++ b/src/tl_templates/cuda/atomic.h @@ -205,7 +205,8 @@ TL_DEVICE void AtomicStore(T1 &ref, T2 value, int memory_order) { namespace tl { -TL_DEVICE uint32_t ptx_atom_add_relaxed_gpu(const uint32_t *ptr, uint32_t value) { +TL_DEVICE uint32_t ptx_atom_add_relaxed_gpu(const uint32_t *ptr, + uint32_t value) { uint32_t ret; asm volatile("atom.add.relaxed.gpu.global.u32 %0, [%1], %2;\n" : "=r"(ret) @@ -213,7 +214,8 @@ TL_DEVICE uint32_t ptx_atom_add_relaxed_gpu(const uint32_t *ptr, uint32_t value) return ret; } -TL_DEVICE uint32_t ptx_atom_add_acquire_gpu(const uint32_t *ptr, uint32_t value) { +TL_DEVICE uint32_t ptx_atom_add_acquire_gpu(const uint32_t *ptr, + uint32_t value) { uint32_t ret; asm volatile("atom.add.acquire.gpu.global.u32 %0, [%1], %2;\n" : "=r"(ret) @@ -221,7 +223,8 @@ TL_DEVICE uint32_t ptx_atom_add_acquire_gpu(const uint32_t *ptr, uint32_t value) return ret; } -TL_DEVICE uint32_t ptx_atom_add_release_gpu(const uint32_t *ptr, uint32_t value) { +TL_DEVICE uint32_t ptx_atom_add_release_gpu(const uint32_t *ptr, + uint32_t value) { uint32_t ret; asm volatile("atom.add.release.gpu.global.u32 %0, [%1], %2;\n" : "=r"(ret) @@ -229,7 +232,8 @@ TL_DEVICE uint32_t ptx_atom_add_release_gpu(const uint32_t *ptr, uint32_t value) return ret; } -TL_DEVICE uint32_t ptx_atom_add_acq_rel_gpu(const uint32_t *ptr, uint32_t value) { +TL_DEVICE uint32_t ptx_atom_add_acq_rel_gpu(const uint32_t *ptr, + uint32_t value) { uint32_t ret; asm volatile("atom.add.acq_rel.gpu.global.u32 %0, [%1], %2;\n" : "=r"(ret) @@ -237,7 +241,8 @@ TL_DEVICE uint32_t ptx_atom_add_acq_rel_gpu(const uint32_t *ptr, uint32_t value) return ret; } -TL_DEVICE uint32_t ptx_atom_add_relaxed_sys(const uint32_t *ptr, uint32_t value) { +TL_DEVICE uint32_t ptx_atom_add_relaxed_sys(const uint32_t *ptr, + uint32_t value) { uint32_t ret; asm volatile("atom.add.relaxed.sys.global.u32 %0, [%1], %2;\n" : "=r"(ret) @@ -245,7 +250,8 @@ TL_DEVICE uint32_t ptx_atom_add_relaxed_sys(const uint32_t *ptr, uint32_t value) return ret; } -TL_DEVICE uint32_t ptx_atom_add_acquire_sys(const uint32_t *ptr, uint32_t value) { +TL_DEVICE uint32_t ptx_atom_add_acquire_sys(const uint32_t *ptr, + uint32_t value) { uint32_t ret; asm volatile("atom.add.acquire.sys.global.u32 %0, [%1], %2;\n" : "=r"(ret) @@ -253,7 +259,8 @@ TL_DEVICE uint32_t ptx_atom_add_acquire_sys(const uint32_t *ptr, uint32_t value) return ret; } -TL_DEVICE uint32_t ptx_atom_add_release_sys(const uint32_t *ptr, uint32_t value) { +TL_DEVICE uint32_t ptx_atom_add_release_sys(const uint32_t *ptr, + uint32_t value) { uint32_t ret; asm volatile("atom.add.release.sys.global.u32 %0, [%1], %2;\n" : "=r"(ret) @@ -261,11 +268,12 @@ TL_DEVICE uint32_t ptx_atom_add_release_sys(const uint32_t *ptr, uint32_t value) return ret; } -TL_DEVICE uint32_t ptx_atom_add_acq_rel_sys(const uint32_t *ptr, uint32_t value) { +TL_DEVICE uint32_t ptx_atom_add_acq_rel_sys(const uint32_t *ptr, + uint32_t value) { uint32_t ret; asm volatile("atom.add.acq_rel.sys.global.u32 %0, [%1], %2;\n" : "=r"(ret) : "l"(ptr), "r"(value)); return ret; } -} // namespace tl \ No newline at end of file +} // namespace tl \ No newline at end of file diff --git a/src/tl_templates/cuda/sync.h b/src/tl_templates/cuda/sync.h index 9cf86fb9ad..b7b4b8cb9b 100644 --- a/src/tl_templates/cuda/sync.h +++ b/src/tl_templates/cuda/sync.h @@ -188,30 +188,30 @@ template TL_DEVICE void wait_eq(void *barrier, T val = 1) { TL_DEVICE void st_release_gpu(uint32_t *ptr, uint32_t value) { asm volatile("st.release.gpu.global.b32 [%0], %1;" - : - : "l"(ptr), "r"(value) - : "memory"); + : + : "l"(ptr), "r"(value) + : "memory"); } TL_DEVICE void st_relaxed_gpu(uint32_t *ptr, uint32_t value) { asm volatile("st.relaxed.gpu.global.b32 [%0], %1;" - : - : "l"(ptr), "r"(value) - : "memory"); + : + : "l"(ptr), "r"(value) + : "memory"); } TL_DEVICE void st_release_sys(uint32_t *ptr, uint32_t value) { asm volatile("st.release.sys.global.b32 [%0], %1;" - : - : "l"(ptr), "r"(value) - : "memory"); + : + : "l"(ptr), "r"(value) + : "memory"); } TL_DEVICE void st_relaxed_sys(uint32_t *ptr, uint32_t value) { asm volatile("st.relaxed.sys.global.b32 [%0], %1;" - : - : "l"(ptr), "r"(value) - : "memory"); + : + : "l"(ptr), "r"(value) + : "memory"); } } // namespace tl diff --git a/tilelang/distributed/utils.py b/tilelang/distributed/utils.py index 1ac69f722f..0836e289fe 100644 --- a/tilelang/distributed/utils.py +++ b/tilelang/distributed/utils.py @@ -298,19 +298,11 @@ def cuda_stream_max_priority(): CUDA_CHECK(ret[0]) return ret[2] # (leastPriority, greatestPriority) -> greatestPriority is max priority -def CUDA_CHECK(err): - if isinstance(err, cuda.CUresult): - if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError(f"Cuda Error: {err}: {cuda.cuGetErrorName(err)}") - elif isinstance(err, cudart.cudaError_t): - if err != cudart.cudaError_t.cudaSuccess: - raise RuntimeError(f"Cuda Error: {err}: {cudart.cudaGetErrorString(err)}") - else: - raise RuntimeError(f"Unknown error type: {err}") _pynvml_initialized = False _lock = Lock() + def ensure_nvml_initialized(): global _pynvml_initialized if not _pynvml_initialized: @@ -320,8 +312,9 @@ def ensure_nvml_initialized(): pynvml.nvmlInit() _pynvml_initialized = True - -@functools.lru_cache() + + +@functools.lru_cache def has_fullmesh_nvlink_pynvml(): num_devices = torch.cuda.device_count() @@ -336,7 +329,8 @@ def has_fullmesh_nvlink_pynvml(): if remote_device == cur_device: continue remote_handle = handles[remote_device] - p2p_status = pynvml.nvmlDeviceGetP2PStatus(cur_handle, remote_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK) + p2p_status = pynvml.nvmlDeviceGetP2PStatus(cur_handle, remote_handle, + pynvml.NVML_P2P_CAPS_INDEX_NVLINK) if p2p_status != pynvml.NVML_P2P_STATUS_OK: return False return True @@ -371,7 +365,7 @@ def get_gpu_numa_node(gpu_index=0): pci_address = pci_id.replace("00000000:", "").lower() numa_node_path = f"/sys/bus/pci/devices/0000:{pci_address}/numa_node" - with open(numa_node_path, "r") as f: + with open(numa_node_path) as f: numa_node = int(f.read().strip()) assert numa_node >= 0 @@ -380,17 +374,20 @@ def get_gpu_numa_node(gpu_index=0): except Exception as e: print(f"Error: {e}") return -1 - -@functools.lru_cache() + + +@functools.lru_cache def has_fullmesh_nvlink(): try: return has_fullmesh_nvlink_pynvml() except Exception: nvlink_matrix = NvidiaSmiUtil.get_nvlink_adjacency_matrix() has_nvlink = any([any(x == 1 for x in row) for row in nvlink_matrix]) - _has_fullmesh_nvlink = all([i == j or v == 1 for i, row in enumerate(nvlink_matrix) for j, v in enumerate(row)]) + _has_fullmesh_nvlink = all( + [i == j or v == 1 for i, row in enumerate(nvlink_matrix) for j, v in enumerate(row)]) if has_nvlink and not _has_fullmesh_nvlink: warnings.warn( - "⚠️ found NVLink but not fullmesh NVLink, this may cause undefined behavior, please check your GPU topology" + "⚠️ found NVLink but not fullmesh NVLink, this may cause undefined behavior, please check your GPU topology", + stacklevel=2, ) - return _has_fullmesh_nvlink \ No newline at end of file + return _has_fullmesh_nvlink diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index 11ed7893a5..1b088beff0 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -731,8 +731,10 @@ def atom_add(barrier: PrimExpr, value: PrimExpr, scope: str = "gpu", sem: str = """Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc. """ assert scope in ["gpu", "sys"], "Scope must be one of 'gpu', or 'sys'." - assert sem in ["relaxed", "acquire", "release", "acq_rel"], "Semantic must be one of 'relaxed', 'acquire', 'release', or 'acq_rel'." - return tir.call_intrin("uint32", tir.op.Op.get("tl.atom_add"), address_of(barrier), value, sem, scope) + assert sem in ["relaxed", "acquire", "release", "acq_rel" + ], "Semantic must be one of 'relaxed', 'acquire', 'release', or 'acq_rel'." + return tir.call_intrin("uint32", tir.op.Op.get("tl.atom_add"), address_of(barrier), value, sem, + scope) def st(barrier: PrimExpr, value: PrimExpr, scope: str = "gpu", sem: str = "relaxed"): @@ -749,4 +751,4 @@ def st(barrier: PrimExpr, value: PrimExpr, scope: str = "gpu", sem: str = "relax """ assert scope in ["gpu", "sys"], "Scope must be one of 'gpu', or 'sys'." assert sem in ["relaxed", "release"], "Semantic must be one of 'relaxed', or 'release'." - return tir.call_intrin("handle", tir.op.Op.get("tl.st"), address_of(barrier), value, sem, scope) \ No newline at end of file + return tir.call_intrin("handle", tir.op.Op.get("tl.st"), address_of(barrier), value, sem, scope) From 2ce2af8fa084eaff11cd02d0cc3d3df9aee1e126 Mon Sep 17 00:00:00 2001 From: Yu Cheng Date: Thu, 30 Oct 2025 15:23:09 +0800 Subject: [PATCH 4/4] lint --- examples/distributed/reduce_scatter.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/examples/distributed/reduce_scatter.py b/examples/distributed/reduce_scatter.py index 7527e5c71c..fcb8e997f4 100644 --- a/examples/distributed/reduce_scatter.py +++ b/examples/distributed/reduce_scatter.py @@ -361,10 +361,6 @@ def reduce_scatter_for_each_node(input: torch.Tensor, rs_per_node_buf = ctx.rs_per_node_buf p2p_buf = ctx.p2p_buf - print( - f"reduce_scatter_for_each_node: M={M}, N={N}, M_per_rank={M_per_rank}, M_per_node={M_per_node}, nnodes={nnodes}, node_id={node_id}, local_rank={local_rank}" - ) - stream = torch.cuda.current_stream() for n in range(0, nnodes): cur_node_id = (node_id + n + 1) % nnodes