diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 614473632..5437d2050 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -194,6 +194,7 @@ jobs: ${{ inputs.env-vars }} HELION_PRINT_OUTPUT_CODE=1 HELION_ASSERT_CACHE_HIT=1 python benchmarks/run.py \ --op $kernel \ --metrics speedup,accuracy \ + --measure-compile-time \ --latency-measure-mode triton_do_bench \ --cudagraph \ --only $IMPLS \ diff --git a/benchmarks/run.py b/benchmarks/run.py index ef8617517..88a7841db 100644 --- a/benchmarks/run.py +++ b/benchmarks/run.py @@ -46,6 +46,8 @@ from torch.utils._pytree import tree_map from helion._compat import get_device_name +from helion._compile_time import get_total_time as get_compile_total_time +from helion._compile_time import reset as reset_compile_time from helion._testing import get_nvidia_gpu_model from helion._utils import counters from helion.autotuner.metrics import AutotuneMetrics @@ -1023,6 +1025,7 @@ def run_kernel( results: list[RunResult], kernel_mappings: dict[str, tuple[str, ...]] | None = None, kernel_metric_mappings: dict[str, dict[str, str]] | None = None, + measure_compile_time: bool = False, ) -> None: """Run a kernel benchmark, handling both single and multiple variants.""" # Use provided mappings or default to global mappings @@ -1081,6 +1084,7 @@ def run_kernel( operator_args, results, active_metrics, + measure_compile_time=measure_compile_time, ) @@ -1093,6 +1097,7 @@ def run_kernel_variants( operator_args: dict[str, Any] | None, results: list[RunResult], kernel_metric_mappings: dict[str, dict[str, str]] | None = None, + measure_compile_time: bool = False, ) -> None: """Run kernel variants in the same benchmark run.""" @@ -1183,6 +1188,9 @@ def run_kernel_variants( # pyrefly: ignore [missing-import] from tritonbench.utils.triton_op import register_benchmark + # Compile time tracking per variant + variant_compile_times: dict[str, list[float]] = {} + # Register all variants as separate methods for module_path, func_name in variants: # Import the kernel function @@ -1207,6 +1215,7 @@ def run_kernel_variants( def create_helion_method( mod: Any, # noqa: ANN401 kfunc: Callable[..., Any], + compile_time_list: list[float] | None = None, ) -> Callable[..., Any]: def helion_method( self: object, @@ -1245,6 +1254,29 @@ def helion_method( measured_func_callable = kfunc(self, *args, **kwargs) assert callable(measured_func_callable) + + if compile_time_list is not None: + original = measured_func_callable + first_call = True + ct_list = compile_time_list + + def timed_callable() -> object: + nonlocal first_call + if first_call: + first_call = False + torch.cuda.synchronize() + reset_compile_time() + try: + result = original() + except Exception: + ct_list.append(get_compile_total_time()) + raise + ct_list.append(get_compile_total_time()) + return result + return original() + + return timed_callable + return measured_func_callable return helion_method @@ -1253,6 +1285,12 @@ def helion_method( variant_name = func_name helion_method_name = f"helion_{variant_name}" + # Set up compile time tracking for this variant + compile_times: list[float] | None = None + if measure_compile_time: + compile_times = [] + variant_compile_times[func_name] = compile_times + # Use register_benchmark decorator decorated_method = register_benchmark( operator_name=operator_name, @@ -1261,7 +1299,7 @@ def helion_method( enabled=True, fwd_only=False, label=helion_method_name, - )(create_helion_method(module, kernel_func)) + )(create_helion_method(module, kernel_func, compile_times)) # Set the decorated method on the Operator class setattr(Operator, helion_method_name, decorated_method) @@ -1350,6 +1388,40 @@ def accuracy_fail_hook( except Exception: logger.exception("failed to process results") + # Add compile time metrics (per-shape, same format as speedup) + if measure_compile_time and variant_compile_times: + # Get shapes from the most recent result for this kernel + kernel_results = [r for r in results if r.model == kernel_name] + shapes = kernel_results[-1].shape if kernel_results else [] + device = get_device_name() or "unknown" + for func_name, times in variant_compile_times.items(): + if not times: + continue + # Align compile times with shapes (both are in input order) + if len(times) != len(shapes): + logger.warning( + f"Compile time count ({len(times)}) != shape count " + f"({len(shapes)}) for {kernel_name}/{func_name}, skipping" + ) + continue + metric_name = "helion_compile_time_s" + if len(variants) > 1: + metric_name = f"helion_{func_name}_compile_time_s" + results.append( + RunResult( + model=kernel_name, + device=device, + shape=shapes, + metrics={metric_name: times}, + ) + ) + print( + f"Compile time for {kernel_name}/{func_name}: " + f"{', '.join(f'{t:.3f}s' for t in times)} " + f"({len(times)} shapes)", + file=sys.stderr, + ) + # Force garbage collection multiple times to ensure memory is freed for _ in range(3): gc.collect() @@ -1607,6 +1679,12 @@ def main() -> None: help="Export autotune metrics to a JSON file at the given path. " "Also set via HELION_AUTOTUNE_METRICS_JSON=.", ) + parser.add_argument( + "--measure-compile-time", + action="store_true", + help="Measure and report Helion kernel compile time (seconds) for each input shape. " + "Results are included in JSON output as helion_compile_time_s metric.", + ) # Parse known args to get the kernel name, pass rest to tritonbench args, tritonbench_args = parser.parse_known_args() @@ -1729,6 +1807,9 @@ def main() -> None: results: list[RunResult] = [] + if args.measure_compile_time: + os.environ["HELION_MEASURE_COMPILE_TIME"] = "1" + collected_metrics: list[AutotuneMetrics] = [] if args.autotune_metrics or args.autotune_metrics_json: register_post_autotune_hook(collected_metrics.append) @@ -1759,6 +1840,7 @@ def main() -> None: results, active_kernel_mappings, active_metric_mappings, + measure_compile_time=args.measure_compile_time, ) else: print( @@ -1776,6 +1858,7 @@ def main() -> None: results, active_kernel_mappings, active_metric_mappings, + measure_compile_time=args.measure_compile_time, ) else: # Run all kernels @@ -1793,6 +1876,7 @@ def main() -> None: results, active_kernel_mappings, active_metric_mappings, + measure_compile_time=args.measure_compile_time, ) if args.output: diff --git a/helion/_compile_time.py b/helion/_compile_time.py index 2e4e4cf0f..4b151c781 100644 --- a/helion/_compile_time.py +++ b/helion/_compile_time.py @@ -207,6 +207,10 @@ def _print_line(self, name: str, elapsed: float, total: float, indent: int) -> N file=sys.stderr, ) + def get_total_time(self) -> float: + """Get total top-level compile time in seconds.""" + return sum(self._timings.get(name, 0.0) for name in self._TOP_LEVEL) + def reset(self) -> None: """Reset all timing data.""" with self._timer_lock: @@ -298,3 +302,8 @@ def print_report() -> None: def reset() -> None: """Reset all timing data.""" get_tracker().reset() + + +def get_total_time() -> float: + """Get total top-level compile time from the global tracker.""" + return get_tracker().get_total_time()