diff --git a/.claude/skills/ncu-report/SKILL.md b/.claude/skills/ncu-report/SKILL.md index b84c26cc..9fbccf1f 100644 --- a/.claude/skills/ncu-report/SKILL.md +++ b/.claude/skills/ncu-report/SKILL.md @@ -14,7 +14,7 @@ user-invocable: true This skill handles two modes: -1. **Analyze an existing report**: The user provides a path to an `.ncu-rep` file (or one exists under `examples/`). Use the `ncu` CLI to extract and present metrics. All `ncu` commands MUST use `TMPDIR=/tmp/ncu_tmp` prefix to avoid temp file errors. +1. **Analyze an existing report**: The user provides a path to an `.ncu-rep` file (or one exists under `examples/`). Use the `ncu` CLI to extract and present metrics. 2. **Generate a new report**: The user specifies a script or kernel to profile but does NOT provide a `.ncu-rep` file. In this case, set up profiling using `tilus.utils.ncu_utils.ncu_run()`, run it, then analyze the resulting report. @@ -84,10 +84,10 @@ Run these in parallel: ```bash # List all kernels with timing -TMPDIR=/tmp/ncu_tmp ncu -i --page raw --csv --metrics gpu__time_duration.sum 2>&1 +ncu -i --page raw --csv --metrics gpu__time_duration.sum 2>&1 # Session/device info -TMPDIR=/tmp/ncu_tmp ncu -i --page session --csv 2>&1 +ncu -i --page session --csv 2>&1 ``` Present a summary table: @@ -97,7 +97,7 @@ Present a summary table: ### Step 2: Speed of Light — Top-level throughput ```bash -TMPDIR=/tmp/ncu_tmp ncu -i --page details --csv --section SpeedOfLight 2>&1 +ncu -i --page details --csv --section SpeedOfLight 2>&1 ``` Key metrics to highlight per kernel: @@ -113,10 +113,10 @@ Key metrics to highlight per kernel: ```bash # Compute workload -TMPDIR=/tmp/ncu_tmp ncu -i --page details --csv --section ComputeWorkloadAnalysis 2>&1 +ncu -i --page details --csv --section ComputeWorkloadAnalysis 2>&1 # Memory workload -TMPDIR=/tmp/ncu_tmp ncu -i --page details --csv --section MemoryWorkloadAnalysis 2>&1 +ncu -i --page details --csv --section MemoryWorkloadAnalysis 2>&1 ``` Key compute 
metrics: Executed IPC Active, SM Busy %, Issue Slots Busy % @@ -125,7 +125,7 @@ Key memory metrics: Mem Busy %, Max Bandwidth %, L1/L2 hit rates ### Step 4: Occupancy ```bash -TMPDIR=/tmp/ncu_tmp ncu -i --page details --csv --section Occupancy 2>&1 +ncu -i --page details --csv --section Occupancy 2>&1 ``` Report: Theoretical Occupancy, Achieved Occupancy, and limiters (registers, shared memory, block size). @@ -134,24 +134,24 @@ Report: Theoretical Occupancy, Achieved Occupancy, and limiters (registers, shar To extract specific raw metrics: ```bash -TMPDIR=/tmp/ncu_tmp ncu -i --page raw --csv --metrics ,,... 2>&1 +ncu -i --page raw --csv --metrics ,,... 2>&1 ``` To filter by kernel: ```bash -TMPDIR=/tmp/ncu_tmp ncu -i --page raw --csv --metrics --kernel-name regex: 2>&1 +ncu -i --page raw --csv --metrics --kernel-name regex: 2>&1 ``` ### Step 6: Source-level analysis (on demand) SASS-only (default, always available): ```bash -TMPDIR=/tmp/ncu_tmp ncu -i --page source --csv --kernel-name regex: 2>&1 +ncu -i --page source --csv --kernel-name regex: 2>&1 ``` CUDA source correlated with SASS (requires `--import-source yes` during profiling): ```bash -TMPDIR=/tmp/ncu_tmp ncu -i --page source --csv --print-source cuda,sass --kernel-name regex: 2>&1 +ncu -i --page source --csv --print-source cuda,sass --kernel-name regex: 2>&1 ``` Source output columns include per-instruction: Warp Stall Sampling, Instructions Executed, Thread Instructions Executed, stall reasons (stall_barrier, stall_math, stall_wait, etc.), shared memory conflicts, and more. @@ -160,12 +160,12 @@ Source output columns include per-instruction: Warp Stall Sampling, Instructions Rules are included in the details page output. Look for non-empty "Rule Name" column entries. 
```bash -TMPDIR=/tmp/ncu_tmp ncu -i --page details --csv --print-rule-details 2>&1 | grep -v '^"[0-9]' | head -5 # header +ncu -i --page details --csv --print-rule-details 2>&1 | grep -v '^"[0-9]' | head -5 # header ``` To see all rule results with descriptions: ```bash -TMPDIR=/tmp/ncu_tmp ncu -i --page details --csv --print-rule-details 2>&1 +ncu -i --page details --csv --print-rule-details 2>&1 ``` Filter for rows where column 17 (Rule Name) is non-empty. diff --git a/CLAUDE.md b/CLAUDE.md index d1a2ecd4..4331facc 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -70,7 +70,7 @@ Call `tilus.option.debug.dump_ir()` before running the kernel. The IR after each ```python ... self.store_shared(s_c, ...) - self.fence.async_view(space="shared") # fence.proxy.async.shared::cta + self.fence.proxy_async(space="shared") # fence.proxy.async.shared::cta self.sync() with self.single_thread(): self.tma.shared_to_global(s_c, g_c, ...) diff --git a/examples/blackwell_matmul/matmul_v5.py b/examples/blackwell_matmul/matmul_v5.py index 67772d54..2023b324 100644 --- a/examples/blackwell_matmul/matmul_v5.py +++ b/examples/blackwell_matmul/matmul_v5.py @@ -233,7 +233,7 @@ def __call__( r_acc = self.tcgen05.load(t_acc) self.tcgen05.wait_load() self.store_shared(s_c, r_acc.to(float16)) - self.fence.async_view() + self.fence.proxy_async() self.sync() with self.single_thread(): self.tma.shared_to_global( diff --git a/examples/blackwell_matmul/matmul_v6.py b/examples/blackwell_matmul/matmul_v6.py index 9cd499da..90a36025 100644 --- a/examples/blackwell_matmul/matmul_v6.py +++ b/examples/blackwell_matmul/matmul_v6.py @@ -278,7 +278,7 @@ def __call__( r_acc = self.tcgen05.load(t_acc) self.tcgen05.wait_load() self.store_shared(s_c, r_acc.to(float16)) - self.fence.async_view() + self.fence.proxy_async() self.sync() with self.single_thread(): self.tma.shared_to_global( diff --git a/examples/blackwell_matmul/matmul_v7.py b/examples/blackwell_matmul/matmul_v7.py index fa815fb6..eee7ae56 100644 --- 
a/examples/blackwell_matmul/matmul_v7.py +++ b/examples/blackwell_matmul/matmul_v7.py @@ -95,7 +95,7 @@ def query_clc_response(self, s_clc_response: SharedTensor, pipe: Pipeline): pipe.consumer_acquire(scope="cluster") response = s_clc_response[pipe.consumer_stage] is_valid, new_blockIdx = self.clc.query_response(response) - self.fence.async_view(space="shared") + self.fence.proxy_async(space="shared") self.mbarrier.arrive_and_expect_tx_remote( pipe.consumer_barrier(), transaction_bytes=0, target_rank=0 ) @@ -278,7 +278,7 @@ def __call__( r_acc = self.tcgen05.load(t_acc_slice) self.tcgen05.wait_load() self.store_shared(s_c, r_acc.to(float16)) - self.fence.async_view(space="shared") + self.fence.proxy_async(space="shared") self.sync() with self.single_thread(): self.tma.shared_to_global( diff --git a/examples/blackwell_matmul/matmul_v8.py b/examples/blackwell_matmul/matmul_v8.py index c07766e7..dc9a762a 100644 --- a/examples/blackwell_matmul/matmul_v8.py +++ b/examples/blackwell_matmul/matmul_v8.py @@ -31,11 +31,12 @@ def __init__( self.producer_phase: uint32 = self.mbarrier.producer_initial_phase self.consumer_phase: uint32 = self.mbarrier.consumer_initial_phase - def producer_acquire(self, scope: str = "cta"): + def producer_acquire(self): self.mbarrier.wait( barrier=self.empty_barriers[self.producer_stage], phase=self.producer_phase, - scope=scope, + sem="relaxed", + scope="cta", ) def producer_barrier(self) -> RegisterTensor: @@ -45,11 +46,12 @@ def producer_advance(self): self.producer_stage = (self.producer_stage + 1) % self.num_stages self.producer_phase = self.producer_phase ^ (self.producer_stage == 0) - def consumer_acquire(self, scope: str = "cta"): + def consumer_acquire(self): self.mbarrier.wait( barrier=self.full_barriers[self.consumer_stage], phase=self.consumer_phase, - scope=scope, + sem="relaxed", + scope="cta", ) def consumer_barrier(self) -> RegisterTensor: @@ -67,15 +69,15 @@ def consumer_advance(self): @tilus.autotune("mma_stages", [2]) 
@tilus.autotune("swizzle_size", [4, 8, 16]) class BlackwellMatmulV8(tilus.Script): - debug_schedule = dict( - block_m=256, - block_n=256, - block_k=64, - tma_stages=6, - mma_stages=2, - e_block_n=32, - swizzle_size=8, - ) + # debug_schedule = dict( + # block_m=256, + # block_n=256, + # block_k=64, + # tma_stages=6, + # mma_stages=2, + # e_block_n=32, + # swizzle_size=8, + # ) def __init__( self, @@ -113,10 +115,10 @@ def compute_block_coord( return m_block, n_block def query_clc_response(self, s_clc_response: SharedTensor, pipe: Pipeline): - pipe.consumer_acquire(scope="cluster") + pipe.consumer_acquire() response = s_clc_response[pipe.consumer_stage] is_valid, new_blockIdx = self.clc.query_response(response) - self.fence.async_view(space="shared") + self.fence.proxy_async(space="shared") self.mbarrier.arrive_and_expect_tx_remote( pipe.consumer_barrier(), transaction_bytes=0, target_rank=0 ) @@ -270,9 +272,7 @@ def __call__( with self.single_warp(2): # scheduler while True: if cta_rank == 0: - clc_pipe.producer_acquire( - scope="cluster" - ) # peer cta will arrive this barrier, need 'cluster'scoped acquire + clc_pipe.producer_acquire() self.mbarrier.arrive_and_expect_tx_multicast( clc_pipe.producer_barrier(), transaction_bytes=16, @@ -310,7 +310,7 @@ def __call__( r_acc = self.tcgen05.load(t_acc_slice) self.tcgen05.wait_load() self.store_shared(s_c, r_acc.to(float16)) - self.fence.async_view(space="shared") + self.fence.proxy_async(space="shared") self.sync() with self.single_thread(): self.tma.shared_to_global( @@ -320,7 +320,7 @@ def __call__( dims=[0, 1], ) self.tma.commit_group() - self.tma.wait_group(n=0) + self.tma.wait_group(n=0, read=True) self.sync() self.mbarrier.arrive(mma_pipe.consumer_barrier()) @@ -372,7 +372,7 @@ def main(bench=True): ("torch", lambda: torch.matmul(a, b.T, out=c_expected)), ("tilus", lambda: matmul(m_size, n_size, k_size, a, b, c_actual)), ]: - latency = benchmark_func(func, warmup=5, repeat=20) + latency = benchmark_func(func, 
warmup=5, repeat=200) tflops = 2 * m_size * n_size * k_size / latency * 1e-9 rows.append([m_size, n_size, k_size, name, latency, tflops]) @@ -384,4 +384,4 @@ def main(bench=True): if __name__ == "__main__": main(bench=True) # main(bench=False) - # ncu_run(main, bench=False, kernel_regex="hidet|nvjet") + # ncu_run(main, bench=False, kernel_regex="tilus|nvjet") diff --git a/examples/blackwell_matmul/matmul_v9.py b/examples/blackwell_matmul/matmul_v9.py deleted file mode 100644 index a06f2d9e..00000000 --- a/examples/blackwell_matmul/matmul_v9.py +++ /dev/null @@ -1,463 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -import os - -import pandas -import tilus -import torch -from tilus import RegisterTensor, SharedTensor, float16, float32, int32, uint32 -from tilus.utils import benchmark_func, cdiv -from tilus.utils.ncu_utils import ncu_run - -tilus.option.cache_dir(os.path.join(os.path.dirname(__file__), "cache")) -tilus.option.debug.dump_ir() - - -class Pipeline(tilus.Class): - def __init__( - self, - num_stages: int, - producer_arrive_count: int = 1, - consumer_arrive_count: int = 1, - ): - self.num_stages: int = num_stages - self.empty_barriers = self.mbarrier.alloc( - [consumer_arrive_count for _ in range(num_stages)] - ) - self.full_barriers = self.mbarrier.alloc( - [producer_arrive_count for _ in range(num_stages)] - ) - self.producer_stage: int32 = 0 - self.consumer_stage: int32 = 0 - self.producer_phase: uint32 = self.mbarrier.producer_initial_phase - self.consumer_phase: uint32 = self.mbarrier.consumer_initial_phase - - def producer_acquire(self, scope: str = "cta", sem: str = "acquire"): - self.mbarrier.wait( - barrier=self.empty_barriers[self.producer_stage], - phase=self.producer_phase, - sem=sem, - scope=scope, - ) - - def producer_barrier(self) -> RegisterTensor: - return self.full_barriers[self.producer_stage] - - def producer_advance(self): - 
self.producer_stage = (self.producer_stage + 1) % self.num_stages - self.producer_phase = self.producer_phase ^ (self.producer_stage == 0) - - def consumer_acquire(self, scope: str = "cta", sem: str = "acquire"): - self.mbarrier.wait( - barrier=self.full_barriers[self.consumer_stage], - phase=self.consumer_phase, - sem=sem, - scope=scope, - ) - - def consumer_barrier(self) -> RegisterTensor: - return self.empty_barriers[self.consumer_stage] - - def consumer_advance(self): - self.consumer_stage = (self.consumer_stage + 1) % self.num_stages - self.consumer_phase = self.consumer_phase ^ (self.consumer_stage == 0) - - -@tilus.autotune("block_m", [256]) -@tilus.autotune("block_n, e_block_n", [[256, 16], [256, 32]]) -@tilus.autotune("block_k", [64]) -@tilus.autotune("tma_stages", [3, 4]) -@tilus.autotune("swizzle_size", [4, 8, 16]) -class BlackwellMatmulV9(tilus.Script): - """ - nvjet-style schedule: each cluster (2 CTAs) computes [block_m, block_n x 2] - by reusing A tiles across two adjacent N-tiles. This reduces TMA loads from - 8 to 6 per k-step for a [block_m, block_n x 2] output region. - - Per k-step, each CTA issues 3 TMA loads (vs 2 in v8): - - a: (block_m//2, block_k) -- shared across both N-tiles - - b_n0: (block_n//2, block_k) -- B for first N-tile - - b_n1: (block_n//2, block_k) -- B for second N-tile - - The tmem budget formerly used for mma_stages=2 double-buffering is now used - for two accumulators (one per N-tile), keeping the total tmem the same as v8. - tma_stages is reduced from 6 to 4 to keep smem usage constant. 
- """ - - debug_schedule = dict( - block_m=256, - block_n=256, - block_k=64, - tma_stages=4, - e_block_n=32, - swizzle_size=4, - ) - - def __init__( - self, - block_m: int, - block_n: int, - block_k: int, - tma_stages: int, - e_block_n: int, - swizzle_size: int, - ): - super().__init__() - self.block_m = block_m - self.block_n = block_n - self.block_k = block_k - self.e_block_n = e_block_n - self.tma_stages = tma_stages - self.swizzle_size = swizzle_size - self.clc_stages = 1 - - def compute_block_coord( - self, linear_idx: int32, num_m_blocks: int32, num_n_block_pairs: int - ): - swizzle_size = self.swizzle_size - tiles_per_group = num_m_blocks * swizzle_size - group_idx = linear_idx // tiles_per_group - in_group_idx = linear_idx % tiles_per_group - first_n = group_idx * swizzle_size - group_width = num_n_block_pairs - first_n - if group_width > swizzle_size: - group_width = swizzle_size - m_block = in_group_idx // group_width - n_block_pair = first_n + in_group_idx % group_width - return m_block, n_block_pair - - def query_clc_response(self, s_clc_response: SharedTensor, pipe: Pipeline): - pipe.consumer_acquire(scope="cluster", sem="relaxed") - response = s_clc_response[pipe.consumer_stage] - is_valid, new_blockIdx = self.clc.query_response(response) - self.fence.async_view(space="shared") - self.mbarrier.arrive_and_expect_tx_remote( - pipe.consumer_barrier(), transaction_bytes=0, target_rank=0 - ) - pipe.consumer_advance() - return is_valid, new_blockIdx - - def __call__( - self, - m_size: int32, - n_size: int, - k_size: int, - a_ptr: ~float16, - b_ptr: ~float16, - c_ptr: ~float16, - ): - """ - Each cluster (2 CTAs) computes [block_m, block_n x 2] output by reusing - A tiles across two adjacent N-tiles. 
- - CTA0 = CTA with last bit of cluster rank = 0 - CTA1 = CTA with last bit of cluster rank = 1 - - Input B (K, N) - ┌──────┬──────┬──────┬──────┐ - │b0_n0 │b1_n0 │b0_n1 │b1_n1 │ - │[CTA0]│[CTA1]│[CTA0]│[CTA1]│ - └──────┴──────┴──────┴──────┘ - N-tile 0 N-tile 1 - ┌──────────────┐ ┌──────────────┬──────────────┐ - │ a0 (M/2, K) │ │ d0 (M/2, N) │ d0 (M/2, N) │ - │ [CTA0] │ │ [CTA0] │ [CTA0] │ - ├──────────────┤ ├──────────────┼──────────────┤ - │ a1 (M/2, K) │ │ d1 (M/2, N) │ d1 (M/2, N) │ - │ [CTA1] │ │ [CTA1] │ [CTA1] │ - └──────────────┘ └──────────────┴──────────────┘ - Input A (M, K) Output D (M, N x 2) - - Per k-step TMA loads (6 total, vs 4 in v8 for same block_n): - CTA0: a0, b0_n0, b0_n1 (3 loads) - CTA1: a1, b1_n0, b1_n1 (3 loads) - """ - num_m_blocks = cdiv(m_size, self.block_m) - num_n_blocks = cdiv(n_size, self.block_n) - num_n_block_pairs = cdiv(num_n_blocks, 2) - - # Each cluster handles 2 adjacent N-tiles - self.attrs.blocks = num_m_blocks * num_n_block_pairs * 2, 1 - self.attrs.cluster_blocks = 2 - self.attrs.warps = 8 - - block_m = self.block_m - block_n = self.block_n - block_k = self.block_k - e_block_n = self.e_block_n - tma_stages = self.tma_stages - clc_stages = self.clc_stages - - g_a = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size]) - g_b = self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size]) - g_c = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size]) - - s_a = self.shared_tensor(dtype=float16, shape=[tma_stages, block_m // 2, block_k]) - s_b_n0 = self.shared_tensor( - dtype=float16, shape=[tma_stages, block_n // 2, block_k] - ) - s_b_n1 = self.shared_tensor( - dtype=float16, shape=[tma_stages, block_n // 2, block_k] - ) - - # Two accumulators for two N-tiles (replaces mma_stages=2 double-buffering) - t_acc_n0 = self.tcgen05.alloc( - dtype=float32, shape=[block_m // 2, block_n], cta_group=2 - ) - t_acc_n1 = self.tcgen05.alloc( - dtype=float32, shape=[block_m // 2, block_n], cta_group=2 - ) - - 
s_clc_response = self.shared_tensor(dtype=int32, shape=[clc_stages, 4]) - - tma_pipe = Pipeline(tma_stages) - mma_pipe = Pipeline( - 1, consumer_arrive_count=128 - ) # mma_stages=1, 4 warps (epilogue warps) - clc_pipe = Pipeline( - clc_stages, consumer_arrive_count=224 * 2 - ) # 7 warps * 2 blocks - - cta_rank = self.cluster.blockRank - - self.cluster_sync() - - with self.single_warp(0): # tma worker (gmem -> smem) - m_block_0, n_pair_0 = self.compute_block_coord( - self.blockIdx.x // 2, num_m_blocks, num_n_block_pairs - ) - offset_m_a = (m_block_0 * 2 + cta_rank) * (block_m // 2) - offset_n_b0 = (n_pair_0 * 2) * block_n + cta_rank * (block_n // 2) - offset_n_b1 = (n_pair_0 * 2 + 1) * block_n + cta_rank * (block_n // 2) - while True: - for offset_k in range(0, k_size, block_k): - tma_pipe.producer_acquire() - mbarrier = tma_pipe.producer_barrier() - if cta_rank == 0: - with self.single_thread(): - # track all 6 TMA loads (3 per CTA x 2 CTAs) on CTA0's mbarrier - transaction_bytes = ( - s_a[0].nbytes + s_b_n0[0].nbytes + s_b_n1[0].nbytes - ) * 2 - self.mbarrier.arrive_and_expect_tx( - mbarrier, transaction_bytes - ) - else: - mbarrier = self.cluster.map_shared_addr(mbarrier, target_rank=0) - with self.single_thread(): - self.tma.global_to_shared( - src=g_a, - dst=s_a[tma_pipe.producer_stage], - offsets=[offset_m_a, offset_k], - mbarrier=mbarrier, - cta_group=2, - ) - self.tma.global_to_shared( - src=g_b, - dst=s_b_n0[tma_pipe.producer_stage], - offsets=[offset_n_b0, offset_k], - mbarrier=mbarrier, - cta_group=2, - ) - self.tma.global_to_shared( - src=g_b, - dst=s_b_n1[tma_pipe.producer_stage], - offsets=[offset_n_b1, offset_k], - mbarrier=mbarrier, - cta_group=2, - ) - tma_pipe.producer_advance() - - is_valid, new_blockIdx = self.query_clc_response(s_clc_response, clc_pipe) - if not is_valid: - break - m_block_0, n_pair_0 = self.compute_block_coord( - new_blockIdx.x // 2, num_m_blocks, num_n_block_pairs - ) - offset_m_a = (m_block_0 * 2 + cta_rank) * (block_m // 2) 
- offset_n_b0 = (n_pair_0 * 2) * block_n + cta_rank * (block_n // 2) - offset_n_b1 = (n_pair_0 * 2 + 1) * block_n + cta_rank * (block_n // 2) - - with self.single_warp(1): # mma worker (smem -> tmem) - while True: - with self.single_thread(): - if cta_rank == 0: - mma_pipe.producer_acquire() - for offset_k in range(0, k_size, block_k): - tma_pipe.consumer_acquire() - # MMA for first N-tile - self.tcgen05.mma( - s_a[tma_pipe.consumer_stage], - s_b_n0[tma_pipe.consumer_stage].transpose(), - t_acc_n0, - enable_input_d=offset_k != 0, - cta_group=2, - ) - # MMA for second N-tile (reuses same A) - self.tcgen05.mma( - s_a[tma_pipe.consumer_stage], - s_b_n1[tma_pipe.consumer_stage].transpose(), - t_acc_n1, - enable_input_d=offset_k != 0, - cta_group=2, - ) - self.tcgen05.commit( - mbarrier=tma_pipe.consumer_barrier(), - cta_group=2, - multicast_mask=0b11, - ) - tma_pipe.consumer_advance() - self.tcgen05.commit( - mbarrier=mma_pipe.producer_barrier(), - cta_group=2, - multicast_mask=0b11, - ) - mma_pipe.producer_advance() - - is_valid, new_blockIdx = self.query_clc_response(s_clc_response, clc_pipe) - if not is_valid: - break - - with self.single_warp(2): # scheduler - while True: - if cta_rank == 0: - clc_pipe.producer_acquire( - scope="cluster", sem="relaxed" - ) # peer cta will arrive this barrier, relaxed is sufficient (shared memory only) - self.mbarrier.arrive_and_expect_tx_multicast( - clc_pipe.producer_barrier(), - transaction_bytes=16, - multicast_mask=0b11, - ) - with self.single_thread(): - self.clc.try_cancel( - s_clc_response[clc_pipe.producer_stage], - mbarrier=clc_pipe.producer_barrier(), - multicast=True, - ) - clc_pipe.producer_advance() - - is_valid, new_blockIdx = self.query_clc_response(s_clc_response, clc_pipe) - if not is_valid: - break - - with self.warp_group(warp_begin=4, num_warps=4): # epilogue (tmem -> gmem) - s_c = self.shared_tensor(dtype=float16, shape=[block_m // 2, self.e_block_n]) - m_block_e, n_pair_e = self.compute_block_coord( - 
self.blockIdx.x // 2, num_m_blocks, num_n_block_pairs - ) - offset_m_c = (m_block_e * 2 + cta_rank) * (block_m // 2) - offset_n_c_n0 = (n_pair_e * 2) * block_n - offset_n_c_n1 = (n_pair_e * 2 + 1) * block_n - while True: - mma_pipe.consumer_acquire() - - # Store first N-tile - for e_offset_n in range(0, block_n, e_block_n): - t_acc_slice = self.tcgen05.slice( - t_acc_n0, - offsets=[0, e_offset_n], - shape=[block_m // 2, e_block_n], - dims=[0, 1], - ) - r_acc = self.tcgen05.load(t_acc_slice) - self.tcgen05.wait_load() - self.store_shared(s_c, r_acc.to(float16)) - self.fence.async_view(space="shared") - self.sync() - with self.single_thread(): - self.tma.shared_to_global( - s_c, - g_c, - offsets=[offset_m_c, offset_n_c_n0 + e_offset_n], - dims=[0, 1], - ) - self.tma.commit_group() - self.tma.wait_group(n=0, read=True) - self.sync() - - # Store second N-tile - for e_offset_n in range(0, block_n, e_block_n): - t_acc_slice = self.tcgen05.slice( - t_acc_n1, - offsets=[0, e_offset_n], - shape=[block_m // 2, e_block_n], - dims=[0, 1], - ) - r_acc = self.tcgen05.load(t_acc_slice) - self.tcgen05.wait_load() - self.store_shared(s_c, r_acc.to(float16)) - self.fence.async_view(space="shared") - self.sync() - with self.single_thread(): - self.tma.shared_to_global( - s_c, - g_c, - offsets=[offset_m_c, offset_n_c_n1 + e_offset_n], - dims=[0, 1], - ) - self.tma.commit_group() - self.tma.wait_group(n=0, read=True) - self.sync() - - self.mbarrier.arrive(mma_pipe.consumer_barrier()) - mma_pipe.consumer_advance() - - is_valid, new_blockIdx = self.query_clc_response(s_clc_response, clc_pipe) - if not is_valid: - break - m_block_e, n_pair_e = self.compute_block_coord( - new_blockIdx.x // 2, num_m_blocks, num_n_block_pairs - ) - offset_m_c = (m_block_e * 2 + cta_rank) * (block_m // 2) - offset_n_c_n0 = (n_pair_e * 2) * block_n - offset_n_c_n1 = (n_pair_e * 2 + 1) * block_n - - # all allocated tensor memory must be deallocated - self.sync() - self.tcgen05.dealloc(t_acc_n0) - 
self.tcgen05.dealloc(t_acc_n1) - - -def main(bench=True): - matmul = BlackwellMatmulV9() - - headers = ["m", "n", "k", "name", "latency (ms)", "tflops"] - rows: list = [] - - for m_size, n_size, k_size in [ - [10240, 10240, 10240], - ]: - print(f"Running with m_size={m_size}, n_size={n_size}, k_size={k_size}") - a = torch.randn(m_size, k_size, dtype=torch.float16, device="cuda") - b = torch.randn(n_size, k_size, dtype=torch.float16, device="cuda") - c_actual = torch.empty(m_size, n_size, dtype=torch.float16, device="cuda") - c_expected = torch.empty(m_size, n_size, dtype=torch.float16, device="cuda") - - torch.matmul(a, b.T, out=c_expected) - torch.cuda.synchronize() - - matmul(m_size, n_size, k_size, a, b, c_actual) - torch.cuda.synchronize() - - torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=1e-2) - - # benchmark - if bench: - for name, func in [ - ("torch", lambda: torch.matmul(a, b.T, out=c_expected)), - ("tilus", lambda: matmul(m_size, n_size, k_size, a, b, c_actual)), - ]: - latency = benchmark_func(func, warmup=5, repeat=200) - tflops = 2 * m_size * n_size * k_size / latency * 1e-9 - rows.append([m_size, n_size, k_size, name, latency, tflops]) - - if bench: - df = pandas.DataFrame(rows, columns=headers) - print(df) - - -if __name__ == "__main__": - # main(bench=True) - # main(bench=False) - ncu_run(main, bench=False, kernel_regex="tilus|nvjet") diff --git a/python/tilus/backends/emitters/cuda/mbarrier.py b/python/tilus/backends/emitters/cuda/mbarrier.py index 3d2c7cca..02bd95a9 100644 --- a/python/tilus/backends/emitters/cuda/mbarrier.py +++ b/python/tilus/backends/emitters/cuda/mbarrier.py @@ -16,14 +16,14 @@ from tilus.backends.emitter import BaseInstEmitter, register_emitter from tilus.hidet.ir.dtypes import uint32, uint64 from tilus.hidet.ir.expr import Var -from tilus.hidet.ir.primitives.cuda.fence import fence_view_async +from tilus.hidet.ir.primitives.cuda.fence import fence_proxy_async, fence_proxy_async_generic_release_shared from 
tilus.hidet.ir.primitives.cuda.mapa import mapa_shared from tilus.hidet.ir.primitives.cuda.mbarrier import ( mbarrier_arrive, mbarrier_arrive_expect_tx, mbarrier_wait, ) -from tilus.ir.instructions.cuda.fence import FenceViewAsync +from tilus.ir.instructions.cuda.fence import FenceProxyAsync, FenceProxyAsyncRelease from tilus.ir.instructions.cuda.mbarrier import ( AllocBarrierInst, ArriveBarrierInst, @@ -111,7 +111,13 @@ def emit(self, inst: WaitBarrierInst) -> None: self.append(mbarrier_wait(inst.barrier, inst.phase, sem=inst.sem, scope=inst.scope)) -@register_emitter(FenceViewAsync, target=nvgpu_sm80) -class FenceViewAsyncEmitter(BaseInstEmitter): - def emit(self, inst: FenceViewAsync) -> None: - self.append(fence_view_async(scope=inst.space)) +@register_emitter(FenceProxyAsync, target=nvgpu_sm80) +class FenceProxyAsyncEmitter(BaseInstEmitter): + def emit(self, inst: FenceProxyAsync) -> None: + self.append(fence_proxy_async(space=inst.space)) + + +@register_emitter(FenceProxyAsyncRelease, target=nvgpu_sm90) +class FenceProxyAsyncReleaseEmitter(BaseInstEmitter): + def emit(self, inst: FenceProxyAsyncRelease) -> None: + self.append(fence_proxy_async_generic_release_shared()) diff --git a/python/tilus/hidet/ir/primitives/cuda/fence.py b/python/tilus/hidet/ir/primitives/cuda/fence.py index 62cdb837..93c8c367 100644 --- a/python/tilus/hidet/ir/primitives/cuda/fence.py +++ b/python/tilus/hidet/ir/primitives/cuda/fence.py @@ -43,7 +43,17 @@ def cuda_fence_proxy_async(): register_primitive_function(name=func_name, func_or_type=cuda_fence_proxy_async) - for func in [cuda_fence_mbarrier_init_cluster]: + @no_type_check + @script + def cuda_fence_proxy_async_generic_release_shared(): + attrs.func_kind = "cuda_internal" + asm( + "fence.proxy.async::generic.release.sync_restrict::shared::cta.cluster;", + is_volatile=True, + memory_fence=True, + ) + + for func in [cuda_fence_mbarrier_init_cluster, cuda_fence_proxy_async_generic_release_shared]: 
register_primitive_function(func.name, func) @@ -58,14 +68,16 @@ def fence_mbarrier_init_cluster() -> Expr: return call_primitive_func("cuda_fence_mbarrier_init_cluster", args=[]) -def fence_view_async(scope: str) -> Expr: +def fence_proxy_async(space: str) -> Expr: """ - Emit a proxy fence for async memory operations. + Emit a bidirectional proxy fence for async memory operations. + + PTX: fence.proxy.async.{space} Parameters ---------- - scope : str - The scope of the fence: 'shared' for fence.proxy.async.shared::cta, + space : str + The space of the fence: 'shared' for fence.proxy.async.shared::cta, 'global' for fence.proxy.async.global. Returns @@ -73,5 +85,22 @@ def fence_view_async(scope: str) -> Expr: ret : Expr A call expression that performs the fence operation. """ - func_name = "cuda_fence_proxy_async_{}".format(scope) + func_name = "cuda_fence_proxy_async_{}".format(space) return call_primitive_func(func_name, args=[]) + + +def fence_proxy_async_generic_release_shared() -> Expr: + """ + Emit a unidirectional generic-to-async release proxy fence for shared memory. + + PTX: fence.proxy.async::generic.release.sync_restrict::shared::cta.cluster + + This is a lighter-weight fence that only ensures prior generic proxy writes + to shared::cta memory are visible to subsequent async proxy reads. + + Returns + ------- + ret : Expr + A call expression that performs the fence operation. 
+ """ + return call_primitive_func("cuda_fence_proxy_async_generic_release_shared", args=[]) diff --git a/python/tilus/ir/builders/stmt_builder.py b/python/tilus/ir/builders/stmt_builder.py index 98dc3859..efe29df6 100644 --- a/python/tilus/ir/builders/stmt_builder.py +++ b/python/tilus/ir/builders/stmt_builder.py @@ -45,7 +45,7 @@ CopyAsyncTensorSharedToGlobalInst, CopyAsyncTensorWaitGroupInst, ) -from tilus.ir.instructions.cuda.fence import FenceViewAsync +from tilus.ir.instructions.cuda.fence import FenceProxyAsync, FenceProxyAsyncRelease from tilus.ir.instructions.cuda.mapa import MapSharedAddrInst from tilus.ir.instructions.cuda.mbarrier import ( AllocBarrierInst, @@ -1272,8 +1272,12 @@ def arrive_expect_tx_remote_barrier( ) self.append(inst) - def fence_view_async(self, space: str) -> None: - inst = FenceViewAsync.create(scope=space) + def fence_proxy_async(self, space: str) -> None: + inst = FenceProxyAsync.create(space=space) + self.append(inst) + + def fence_proxy_async_release(self) -> None: + inst = FenceProxyAsyncRelease.create() self.append(inst) def cluster_launch_control_try_cancel( diff --git a/python/tilus/ir/instructions/cuda/fence.py b/python/tilus/ir/instructions/cuda/fence.py index 6afb79c1..dcde7881 100644 --- a/python/tilus/ir/instructions/cuda/fence.py +++ b/python/tilus/ir/instructions/cuda/fence.py @@ -20,12 +20,30 @@ @dataclass(frozen=True, eq=False) -class FenceViewAsync(Instruction): +class FenceProxyAsync(Instruction): + """Bidirectional async proxy fence: fence.proxy.async.{space}.""" + space: str @staticmethod - def create(scope: str) -> FenceViewAsync: - assert scope in ("shared", "global"), ( - f"Invalid scope for async fence view: {scope}. Supported candidates are 'shared' and 'global'." + def create(space: str) -> FenceProxyAsync: + assert space in ("shared", "global"), ( + f"Invalid space for async proxy fence: {space}. Supported candidates are 'shared' and 'global'." 
) - return FenceViewAsync(output=None, inputs=(), space=scope) + return FenceProxyAsync(output=None, inputs=(), space=space) + + +@dataclass(frozen=True, eq=False) +class FenceProxyAsyncRelease(Instruction): + """Unidirectional generic-to-async release proxy fence. + + PTX: fence.proxy.async::generic.release.sync_restrict::shared::cta.cluster + + This is a lighter-weight alternative to the bidirectional fence.proxy.async.shared::cta. + It only ensures that prior generic proxy writes to shared::cta memory are visible to + subsequent async proxy reads, with release semantics at cluster scope. + """ + + @staticmethod + def create() -> FenceProxyAsyncRelease: + return FenceProxyAsyncRelease(output=None, inputs=()) diff --git a/python/tilus/lang/instantiated_script.py b/python/tilus/lang/instantiated_script.py index 869aec7e..cc74dcd4 100644 --- a/python/tilus/lang/instantiated_script.py +++ b/python/tilus/lang/instantiated_script.py @@ -659,7 +659,13 @@ def _pick_best_program(self, args: Sequence[Any]) -> CompiledProgram: ): compiled_func = compiled_program.get_launch_func() try: - latency.append(benchmark_func(lambda: compiled_func(*kernel_args), warmup=1, repeat=10)) # type: ignore + latency.append( + benchmark_func( + lambda: compiled_func(*kernel_args), + warmup=tilus.option.get_option("bench_warmup"), + repeat=tilus.option.get_option("bench_repeat"), + ) + ) # type: ignore except RuntimeError as e: raise RuntimeError( f"Failed to benchmark the kernel {self.instance_name} with schedule: \n" diff --git a/python/tilus/lang/instructions/fence.py b/python/tilus/lang/instructions/fence.py index c2c8cb77..533b19cf 100644 --- a/python/tilus/lang/instructions/fence.py +++ b/python/tilus/lang/instructions/fence.py @@ -18,9 +18,63 @@ class FenceInstructionGroup(InstructionGroup): - def async_view(self, space: str = "shared") -> None: + """Fence instructions for memory ordering. 
+ + + Naming convention follows the PTX fence instruction taxonomy + (PTX ISA 9.7.13.4: membar / fence). Each tilus method corresponds + to one PTX instruction template with parameters for meaningful variation. + + Implemented + ----------- + proxy_async(space) + Bidirectional async proxy fence: ``fence.proxy.async.{space}`` + proxy_async_release() + Unidirectional generic-to-async release: + ``fence.proxy.async::generic.release.sync_restrict::shared::cta.cluster`` + + TODO + ---- + thread(sem, scope, sync_restrict=None) + Thread fence: ``fence.{sem}.{scope}`` + When sync_restrict is set: + - ``fence.acquire.sync_restrict::shared::cluster.cluster`` + - ``fence.release.sync_restrict::shared::cta.cluster`` + proxy_alias() + Bidirectional alias proxy fence: ``fence.proxy.alias`` + proxy_async_acquire() + Unidirectional generic-to-async acquire: + ``fence.proxy.async::generic.acquire.sync_restrict::shared::cluster.cluster`` + proxy_tensormap_release(scope) + Unidirectional tensormap release: + ``fence.proxy.tensormap::generic.release.{scope}`` + proxy_tensormap_acquire(addr, scope) + Unidirectional tensormap acquire: + ``fence.proxy.tensormap::generic.acquire.{scope} [addr], 128`` + """ + + def proxy_async(self, space: str = "shared") -> None: + """Bidirectional async proxy fence. + + PTX: ``fence.proxy.async.{space}`` + + Establishes ordering between the async proxy and the generic proxy + for memory operations in the specified state space. + """ if space not in ("shared", "global"): raise ValueError( - f"Invalid scope for async fence view: {space}. Supported candidates are 'shared' and 'global'." + f"Invalid space for async proxy fence: {space}. Supported candidates are 'shared' and 'global'." ) - self._builder.fence_view_async(space=space) + self._builder.fence_proxy_async(space=space) + + def proxy_async_release(self) -> None: + """Unidirectional generic-to-async release proxy fence for shared memory. 
+ + PTX: ``fence.proxy.async::generic.release.sync_restrict::shared::cta.cluster`` + + A lighter-weight alternative to proxy_async(). Only ensures that prior + generic proxy writes to shared::cta memory are visible to subsequent + async proxy reads (e.g., stmatrix followed by tma.shared_to_global). + + Requires sm_90 or higher. + """ + self._builder.fence_proxy_async_release() diff --git a/python/tilus/option.py b/python/tilus/option.py index 10d43e33..067ae766 100644 --- a/python/tilus/option.py +++ b/python/tilus/option.py @@ -74,6 +74,18 @@ def _register_options(): default_value=False, description="Whether to block the launch of the kernel until the kernel is finished.", ) + _register_hidet_option( + "tilus.bench_warmup", + type_hint="int", + default_value=5, + description="The number of warmup iterations before benchmarking during autotuning.", + ) + _register_hidet_option( + "tilus.bench_repeat", + type_hint="int", + default_value=50, + description="The number of repeat iterations for benchmarking during autotuning.", + ) _register_options() @@ -120,6 +132,30 @@ def parallel_workers(n: int) -> None: return _set_hidet_option("tilus.parallel_workers", n) +def bench_warmup(n: int) -> None: + """ + Set the number of warmup iterations before benchmarking during autotuning. + + Parameters + ---------- + n: int + The number of warmup iterations. + """ + _set_hidet_option("tilus.bench_warmup", n) + + +def bench_repeat(n: int) -> None: + """ + Set the number of repeat iterations for benchmarking during autotuning. + + Parameters + ---------- + n: int + The number of repeat iterations. 
+ """ + _set_hidet_option("tilus.bench_repeat", n) + + class debug: @staticmethod def dump_ir(enable: bool = True) -> None: diff --git a/tests/examples/test_examples.py b/tests/examples/test_examples.py index ef61ae82..84e1ab07 100644 --- a/tests/examples/test_examples.py +++ b/tests/examples/test_examples.py @@ -59,7 +59,6 @@ ("blackwell_matmul", "matmul_v6.py", nvgpu_sm100a), ("blackwell_matmul", "matmul_v7.py", nvgpu_sm100a), ("blackwell_matmul", "matmul_v8.py", nvgpu_sm100a), - ("blackwell_matmul", "matmul_v9.py", nvgpu_sm100a), # hopper matmul example (SM 9.0) ("hopper_matmul", "matmul_v0.py", nvgpu_sm90a), ("hopper_matmul", "matmul_v1.py", nvgpu_sm90a), diff --git a/tests/instructions/test_cluster_launch_control.py b/tests/instructions/test_cluster_launch_control.py index 1e4fa05f..ec9b418f 100644 --- a/tests/instructions/test_cluster_launch_control.py +++ b/tests/instructions/test_cluster_launch_control.py @@ -90,7 +90,7 @@ def __call__(self, n: int32, p_out: ~int32) -> None: # consumer of clc pipeline self.mbarrier.wait(consumer_mbarriers[consumer_stage], phase=consumer_phase) is_valid, blockIdx = self.clc.query_response(cancel_response[consumer_stage]) - self.fence.async_view() # ensure the visibility of the stored result before issuing the next clc operation + self.fence.proxy_async() # ensure the visibility of the stored result before issuing the next clc operation self.mbarrier.arrive(producer_mbarriers[consumer_stage]) consumer_stage = (1 + consumer_stage) % self.num_stages consumer_phase = consumer_phase ^ (consumer_stage == 0) diff --git a/tests/instructions/test_copy_async_tensor.py b/tests/instructions/test_copy_async_tensor.py index 4066aa90..1aeaca72 100644 --- a/tests/instructions/test_copy_async_tensor.py +++ b/tests/instructions/test_copy_async_tensor.py @@ -57,7 +57,7 @@ def __call__(self, m_size: int32, n_size: int, x_ptr: ~float16, y_ptr: ~float16) x += 1 self.store_shared(s_y, x) - self.fence.async_view() + self.fence.proxy_async() 
self.sync() with self.single_thread():