From ce339f931539a1dc0b0cb761c5e8be591170db66 Mon Sep 17 00:00:00 2001 From: ayoub-louati Date: Mon, 6 Feb 2023 15:06:07 +0100 Subject: [PATCH 1/7] feat: add new autotune to replace jit function --- src/kernl/autotune.py | 241 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 241 insertions(+) create mode 100644 src/kernl/autotune.py diff --git a/src/kernl/autotune.py b/src/kernl/autotune.py new file mode 100644 index 00000000..6bbf66ab --- /dev/null +++ b/src/kernl/autotune.py @@ -0,0 +1,241 @@ +import builtins +import copy +import logging +import re +import threading +from typing import List + +import torch +import triton +from triton import cdiv, Config +from triton.runtime.jit import get_cuda_stream, KernelInterface +from triton.testing import do_bench + +log = logging.getLogger(__name__) + +class KernlAutotuner(KernelInterface): + + """ + Simplified version of Triton autotuner. + Unlike the main triton Autotuner, this version can precompile all + configs, and does not rely on the Triton JIT. + """ + + def __init__(self, fn, meta, configs, mutated_arg_names): + super().__init__() + self.fn = fn + self.meta = meta + self.mutated_arg_names = mutated_arg_names + self.configs = configs + self.launchers = [] + self.lock = threading.Lock() + + + def precompile(self, warm_cache_only_with_cc=None): + with self.lock: + if self.launchers: + return + self.launchers = [ + self._precompile_config(c, warm_cache_only_with_cc) + for c in self.configs + ] + self.configs = None + + def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: int): + """Ahead of time compile a given autotuner config.""" + compile_meta = copy.deepcopy(self.meta) + for k, v in cfg.kwargs.items(): + compile_meta["constants"][self.fn.arg_names.index(k)] = v + compile_meta["num_warps"] = cfg.num_warps + compile_meta["num_stages"] = cfg.num_stages + if warm_cache_only_with_cc: + triton.compile( + self.fn, + warm_cache_only=True, + cc=warm_cache_only_with_cc, + **compile_meta, + ) + return + + # load binary to the correct device + with torch.cuda.device(compile_meta["device"]): + # need to initialize context + torch.cuda.synchronize(torch.cuda.current_device()) + binary = triton.compile( + self.fn, + **compile_meta, + ) + + call_args = [ + arg + for i, arg in enumerate(self.fn.arg_names) + if i not in self.fn.constexprs + ] + def_args = list(self.fn.arg_names) + while def_args and def_args[-1] in cfg.kwargs: + def_args.pop() + + scope = { + "grid_meta": cfg.kwargs, + "bin": binary, + "torch": torch, + "set_device": torch.cuda.set_device, + "current_device": torch.cuda.current_device, + } + exec( + f""" + def launcher({', '.join(def_args)}, grid, stream): + if callable(grid): + grid_0, grid_1, grid_2 = grid(grid_meta) + else: + grid_0, grid_1, grid_2 = grid + bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, + stream, bin.cu_function, None, None, None, + {', '.join(call_args)}) + """.lstrip(), + scope, + ) + + launcher = scope["launcher"] + launcher.config = cfg + return launcher + + def bench(self, launcher, *args, grid): + """Measure the performance of a given launcher""" + stream = get_cuda_stream(torch.cuda.current_device()) + + def kernel_call(): + if launcher.config.pre_hook is not None: + launcher.config.pre_hook( + {**zip(self.arg_names, args), **launcher.config.kwargs} + ) + launcher( + *args, + grid=grid, + stream=stream, + ) + + return do_bench(kernel_call, rep=40, fast_flush=True) + + + def autotune_to_one_config(self, *args, **kwargs): + """Do the actual autotuning""" + + # clone inplace buffers to avoid autotune contaminating them if + # the kernel does in-place stores. avoid cloning other buffers because + # it leads to increase memory use + cloned_args = [] + for i, arg in enumerate(args): + if self.fn.arg_names[i] in self.mutated_arg_names: + assert isinstance(arg, torch.Tensor) + cloned_args.append(clone_preserve_strides(arg)) + else: + cloned_args.append(arg) + + timings = { + launcher: self.bench(launcher, *cloned_args, **kwargs) + for launcher in self.launchers + } + self.launchers = [builtins.min(timings, key=timings.get)] + + def run(self, *args, grid, stream): + if len(self.launchers) != 1: + if len(self.launchers) == 0: + self.precompile() + if len(self.launchers) > 1: + self.autotune_to_one_config(*args, grid=grid) + + (launcher,) = self.launchers + if launcher.config.pre_hook is not None: + launcher.config.pre_hook( + {**zip(self.arg_names, args), **launcher.config.kwargs} + ) + try: + result = launcher( + *args, + grid=grid, + stream=stream, + ) + except TypeError as e: + if re.match(r"function takes exactly \d+ arguments \(\d+ given\)", str(e)): + raise RuntimeError( + """Consider updating Triton with +`pip install -U "git+https://github.com/openai/triton@af76c989eb4799b015f8b288ccd8421558772e56#subdirectory=python"`""" + ) from e + else: + raise e + + return result + +def kernl_autotune( + configs: List[Config], + meta, +): + """ + A copy of triton.autotune that calls our subclass. Our subclass + has additional debugging, error handling, and on-disk caching. + """ + configs = unique_configs(configs) + assert len(configs) == 1 + + mutated_arg_names = meta.pop("mutated_arg_names", ()) + def decorator(fn): + return KernlAutotuner( + fn, + meta=meta, + configs=configs, + mutated_arg_names=mutated_arg_names, + ) + + return decorator + +def unique_configs(configs: List[Config]): + """Remove duplicate configurations""" + seen = set() + pruned_configs = [] + for cfg in configs: + key = tuple(cfg.kwargs.items()) + if key not in seen: + seen.add(key) + pruned_configs.append(cfg) + return pruned_configs + +def template(num_stages, num_warps, meta): + """ + Compile a triton template + """ + return kernl_autotune( + [triton.Config({}, num_stages=num_stages, num_warps=num_warps)], meta=meta + ) + + +def clone_preserve_strides(x): + needed_size = ( + sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1 + ) + buffer = torch.as_strided(x, (needed_size,), (1,)).clone() + return torch.as_strided(buffer, x.size(), x.stride()) + + +def grid(xnumel, ynumel=None, znumel=None): + """Helper function to compute triton grids""" + + def get_grid_dim(numel, block_name, block): + if numel is None: + return 1 + label = block_name[0] + if numel == 1: + assert block == 1, ( + f"TritonKernel.indexing assumes {label.lower()}numel == 1 => {block_name} == 1" + f"({label.lower()}numel=={numel}, {block_name}={block})." + ) + return cdiv(numel, block) + + def grid_fn(meta): + return ( + get_grid_dim(xnumel, "XBLOCK", meta.get("XBLOCK", None)), + get_grid_dim(ynumel, "YBLOCK", meta.get("YBLOCK", None)), + get_grid_dim(znumel, "ZBLOCK", meta.get("ZBLOCK", None)), + ) + + return grid_fn From 0b79f4cf470be0138aeab8d67e5089b5a74b8655 Mon Sep 17 00:00:00 2001 From: ayoub-louati Date: Tue, 7 Feb 2023 01:41:37 +0100 Subject: [PATCH 2/7] feat: update autotuner + add heuristics + update linear layer --- src/kernl/autotune.py | 162 ++++++++++++++-------- src/kernl/implementations/linear_layer.py | 6 +- 2 files changed, 106 insertions(+), 62 deletions(-) diff --git a/src/kernl/autotune.py b/src/kernl/autotune.py index 6bbf66ab..90458264 100644 --- a/src/kernl/autotune.py +++ b/src/kernl/autotune.py @@ -1,18 +1,24 @@ +import ast import builtins import copy +import inspect import logging import re +import textwrap import threading -from typing import List +from collections import namedtuple +from typing import Dict, List, Optional import torch import triton -from triton import cdiv, Config -from triton.runtime.jit import get_cuda_stream, KernelInterface +from triton import Config, cdiv +from triton.runtime.jit import KernelInterface, get_cuda_stream from triton.testing import do_bench + log = logging.getLogger(__name__) + class KernlAutotuner(KernelInterface): """ @@ -21,41 +27,68 @@ class KernlAutotuner(KernelInterface): configs, and does not rely on the Triton JIT. """ - def __init__(self, fn, meta, configs, mutated_arg_names): - super().__init__() + def __init__(self, fn, configs, key, reset_to_zero, prune_configs_by: Dict = None): + if not configs: + self.configs = [Config(dict(), num_warps=4, num_stages=2)] + else: + self.configs = configs self.fn = fn - self.meta = meta - self.mutated_arg_names = mutated_arg_names - self.configs = configs - self.launchers = [] + self.src = textwrap.dedent(inspect.getsource(fn)) + self.src = self.src[self.src.find("def") :] + self.signature = inspect.signature(fn) + self.arg_names = [v.name for v in self.signature.parameters.values()] + self.key_idx = [self.arg_names.index(k) for k in key] + self.launchers = list() + # hook to reset all required tensor to zeros before relaunching a kernel + self.hook = lambda args: 0 + if reset_to_zero is not None: + self.reset_idx = [self.arg_names.index(k) for k in reset_to_zero] + + def _hook(args): + for i in self.reset_idx: + args[i].zero_() + + self.hook = _hook + # prune configs + early_config_prune = None + if prune_configs_by: + perf_model, top_k = prune_configs_by["perf_model"], prune_configs_by["top_k"] + if "early_config_prune" in prune_configs_by: + early_config_prune = prune_configs_by["early_config_prune"] + else: + perf_model, top_k = None, None + self.perf_model, self.configs_top_k = perf_model, top_k + self.early_config_prune = early_config_prune + self.constexprs = [self.arg_names.index(ann) for ann in self.fn.__annotations__.keys()] + self.fn.parse = self.parse self.lock = threading.Lock() + def parse(self): + tree = ast.parse(self.src) + assert isinstance(tree, ast.Module) + assert len(tree.body) == 1 + assert isinstance(tree.body[0], ast.FunctionDef) + return tree - def precompile(self, warm_cache_only_with_cc=None): + def precompile(self): with self.lock: if self.launchers: return - self.launchers = [ - self._precompile_config(c, warm_cache_only_with_cc) - for c in self.configs - ] + self.launchers = [self._precompile_config(c) for c in self.configs] self.configs = None - def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: int): + def _precompile_config(self, cfg: Config): """Ahead of time compile a given autotuner config.""" - compile_meta = copy.deepcopy(self.meta) + compile_meta = dict() + compile_meta["constants"] = dict() + config = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])(tuple(range(4)), ()) + compile_meta["configs"] = [config] for k, v in cfg.kwargs.items(): - compile_meta["constants"][self.fn.arg_names.index(k)] = v + compile_meta["constants"][self.arg_names.index(k)] = v compile_meta["num_warps"] = cfg.num_warps compile_meta["num_stages"] = cfg.num_stages - if warm_cache_only_with_cc: - triton.compile( - self.fn, - warm_cache_only=True, - cc=warm_cache_only_with_cc, - **compile_meta, - ) - return + compile_meta["device"] = torch.cuda.current_device() + compile_meta["signature"] = self.signature # load binary to the correct device with torch.cuda.device(compile_meta["device"]): @@ -66,11 +99,7 @@ def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: int): **compile_meta, ) - call_args = [ - arg - for i, arg in enumerate(self.fn.arg_names) - if i not in self.fn.constexprs - ] + call_args = [arg for i, arg in enumerate(self.fn.arg_names) if i not in self.fn.constexprs] def_args = list(self.fn.arg_names) while def_args and def_args[-1] in cfg.kwargs: def_args.pop() @@ -106,9 +135,7 @@ def bench(self, launcher, *args, grid): def kernel_call(): if launcher.config.pre_hook is not None: - launcher.config.pre_hook( - {**zip(self.arg_names, args), **launcher.config.kwargs} - ) + launcher.config.pre_hook({**zip(self.arg_names, args), **launcher.config.kwargs}) launcher( *args, grid=grid, @@ -117,7 +144,6 @@ def kernel_call(): return do_bench(kernel_call, rep=40, fast_flush=True) - def autotune_to_one_config(self, *args, **kwargs): """Do the actual autotuning""" @@ -132,10 +158,7 @@ def autotune_to_one_config(self, *args, **kwargs): else: cloned_args.append(arg) - timings = { - launcher: self.bench(launcher, *cloned_args, **kwargs) - for launcher in self.launchers - } + timings = {launcher: self.bench(launcher, *cloned_args, **kwargs) for launcher in self.launchers} self.launchers = [builtins.min(timings, key=timings.get)] def run(self, *args, grid, stream): @@ -147,9 +170,7 @@ def run(self, *args, grid, stream): (launcher,) = self.launchers if launcher.config.pre_hook is not None: - launcher.config.pre_hook( - {**zip(self.arg_names, args), **launcher.config.kwargs} - ) + launcher.config.pre_hook({**zip(self.arg_names, args), **launcher.config.kwargs}) try: result = launcher( *args, @@ -167,28 +188,27 @@ def run(self, *args, grid, stream): return result + def kernl_autotune( configs: List[Config], - meta, + key: List[str], + reset_to_zero: Optional[List[str]] = None, + prune_configs_by: Optional[Dict] = None, ): """ A copy of triton.autotune that calls our subclass. Our subclass has additional debugging, error handling, and on-disk caching. """ configs = unique_configs(configs) - assert len(configs) == 1 - mutated_arg_names = meta.pop("mutated_arg_names", ()) def decorator(fn): return KernlAutotuner( - fn, - meta=meta, - configs=configs, - mutated_arg_names=mutated_arg_names, + fn, configs=configs, key=key, reset_to_zero=reset_to_zero, prune_configs_by=prune_configs_by ) return decorator + def unique_configs(configs: List[Config]): """Remove duplicate configurations""" seen = set() @@ -200,19 +220,9 @@ def unique_configs(configs: List[Config]): pruned_configs.append(cfg) return pruned_configs -def template(num_stages, num_warps, meta): - """ - Compile a triton template - """ - return kernl_autotune( - [triton.Config({}, num_stages=num_stages, num_warps=num_warps)], meta=meta - ) - def clone_preserve_strides(x): - needed_size = ( - sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1 - ) + needed_size = sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1 buffer = torch.as_strided(x, (needed_size,), (1,)).clone() return torch.as_strided(buffer, x.size(), x.stride()) @@ -239,3 +249,37 @@ def grid_fn(meta): ) return grid_fn + + +class KernlHeuristics(KernelInterface): + def __init__(self, fn, values) -> None: + self.fn = fn + self.values = values + signature = inspect.signature(fn) + self.arg_names = [v.name for v in signature.parameters.values()] + + def run(self, *args, **kwargs): + for v, heur in self.values.items(): + kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs}) + return self.fn.run(*args, **kwargs) + + +def kernl_heuristics(values): + """ + Decorator for specifying how the values of certain meta-parameters may be computed. + This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable. + .. highlight:: python + .. code-block:: python + @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))}) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size + .param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter. + each such function takes a list of positional arguments as input. + .type values: dict[str, Callable[[list[Any]], Any]] + """ + + def decorator(fn): + return KernlHeuristics(fn, values) + + return decorator diff --git a/src/kernl/implementations/linear_layer.py b/src/kernl/implementations/linear_layer.py index 32349912..6ed1e638 100644 --- a/src/kernl/implementations/linear_layer.py +++ b/src/kernl/implementations/linear_layer.py @@ -25,6 +25,7 @@ from torch.cuda.amp import custom_fwd from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time +from kernl.autotune import kernl_autotune, kernl_heuristics from kernl.implementations import activation_func @@ -57,7 +58,7 @@ def get_configs_io_bound(): return configs -@triton.autotune( +@kernl_autotune( configs=[ triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), @@ -83,12 +84,11 @@ def get_configs_io_bound(): key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"], prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10}, ) -@triton.heuristics( +@kernl_heuristics( { "K_LOAD_MASK_NEEDED": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, } ) -@triton.jit def kernel_fma( C, # Pointers to matrices ACT_INPUTS, From 04a52769984ebf633711c8fef3066527830dbcf7 Mon Sep 17 00:00:00 2001 From: ayoub-louati Date: Fri, 17 Feb 2023 01:08:46 +0100 Subject: [PATCH 3/7] feat: fix autotuner --- src/kernl/autotune.py | 275 ++++++++++++---------- src/kernl/implementations/linear_layer.py | 9 +- src/kernl/utils/autotuner_helper.py | 50 ++++ test/test_linear_layer.py | 2 +- 4 files changed, 199 insertions(+), 137 deletions(-) create mode 100644 src/kernl/utils/autotuner_helper.py diff --git a/src/kernl/autotune.py b/src/kernl/autotune.py index 90458264..6fc7eabb 100644 --- a/src/kernl/autotune.py +++ b/src/kernl/autotune.py @@ -1,26 +1,24 @@ -import ast import builtins import copy -import inspect +import hashlib +import json import logging +import os.path import re -import textwrap import threading -from collections import namedtuple -from typing import Dict, List, Optional +from typing import List, Dict, Optional import torch + import triton -from triton import Config, cdiv -from triton.runtime.jit import KernelInterface, get_cuda_stream -from triton.testing import do_bench +from triton import cdiv, Config +from triton.runtime.jit import get_cuda_stream, KernelInterface +from utils.autotuner_helper import type_of, key_of log = logging.getLogger(__name__) - -class KernlAutotuner(KernelInterface): - +class Autotuner(KernelInterface): """ Simplified version of Triton autotuner. Unlike the main triton Autotuner, this version can precompile all @@ -28,17 +26,15 @@ class KernlAutotuner(KernelInterface): """ def __init__(self, fn, configs, key, reset_to_zero, prune_configs_by: Dict = None): + super().__init__() + self.launchers = [] if not configs: self.configs = [Config(dict(), num_warps=4, num_stages=2)] else: self.configs = configs - self.fn = fn - self.src = textwrap.dedent(inspect.getsource(fn)) - self.src = self.src[self.src.find("def") :] - self.signature = inspect.signature(fn) - self.arg_names = [v.name for v in self.signature.parameters.values()] + self.arg_names = fn.arg_names self.key_idx = [self.arg_names.index(k) for k in key] - self.launchers = list() + self.cache = dict() # hook to reset all required tensor to zeros before relaunching a kernel self.hook = lambda args: 0 if reset_to_zero is not None: @@ -47,59 +43,65 @@ def __init__(self, fn, configs, key, reset_to_zero, prune_configs_by: Dict = Non def _hook(args): for i in self.reset_idx: args[i].zero_() - self.hook = _hook - # prune configs - early_config_prune = None if prune_configs_by: - perf_model, top_k = prune_configs_by["perf_model"], prune_configs_by["top_k"] - if "early_config_prune" in prune_configs_by: - early_config_prune = prune_configs_by["early_config_prune"] + perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k'] + if 'early_config_prune' in prune_configs_by: + early_config_prune = prune_configs_by['early_config_prune'] else: - perf_model, top_k = None, None + perf_model, top_k, early_config_prune = None, None, None self.perf_model, self.configs_top_k = perf_model, top_k self.early_config_prune = early_config_prune - self.constexprs = [self.arg_names.index(ann) for ann in self.fn.__annotations__.keys()] - self.fn.parse = self.parse + self.fn = fn + self.__annotations__ = fn.__annotations__ + # index of constexprs + self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()] self.lock = threading.Lock() - def parse(self): - tree = ast.parse(self.src) - assert isinstance(tree, ast.Module) - assert len(tree.body) == 1 - assert isinstance(tree.body[0], ast.FunctionDef) - return tree - - def precompile(self): + def precompile(self, warm_cache_only_with_cc=None): + breakpoint() with self.lock: if self.launchers: return - self.launchers = [self._precompile_config(c) for c in self.configs] + self.launchers = [ + self._precompile_config(c, warm_cache_only_with_cc) + for c in self.configs + ] self.configs = None - def _precompile_config(self, cfg: Config): + def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: int): """Ahead of time compile a given autotuner config.""" - compile_meta = dict() - compile_meta["constants"] = dict() - config = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])(tuple(range(4)), ()) - compile_meta["configs"] = [config] + all_args = {', '.join([f'{arg}' for arg in self.arg_names])}, + signature = {{i: type_of(key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs}} + compile_meta = { + "constants": dict(), "signature": signature + } for k, v in cfg.kwargs.items(): - compile_meta["constants"][self.arg_names.index(k)] = v + compile_meta["constants"][self.fn.arg_names.index(k)] = v compile_meta["num_warps"] = cfg.num_warps compile_meta["num_stages"] = cfg.num_stages - compile_meta["device"] = torch.cuda.current_device() - compile_meta["signature"] = self.signature - - # load binary to the correct device - with torch.cuda.device(compile_meta["device"]): - # need to initialize context - torch.cuda.synchronize(torch.cuda.current_device()) - binary = triton.compile( + + if warm_cache_only_with_cc: + triton.compile( self.fn, + warm_cache_only=True, + cc=warm_cache_only_with_cc, **compile_meta, ) + return + + torch.cuda.set_device(torch.cuda.current_device()) - call_args = [arg for i, arg in enumerate(self.fn.arg_names) if i not in self.fn.constexprs] + binary = triton.compile( + self.fn, + **compile_meta, + ) + + call_args = [ + arg + for i, arg in enumerate(self.fn.arg_names) + if i not in self.fn.constexprs + ] def_args = list(self.fn.arg_names) while def_args and def_args[-1] in cfg.kwargs: def_args.pop() @@ -114,10 +116,8 @@ def _precompile_config(self, cfg: Config): exec( f""" def launcher({', '.join(def_args)}, grid, stream): - if callable(grid): - grid_0, grid_1, grid_2 = grid(grid_meta) - else: - grid_0, grid_1, grid_2 = grid + # set_device(current_device()) # TODO(jansel): is this needed? + grid_0, grid_1, grid_2 = grid(grid_meta) bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, None, None, None, {', '.join(call_args)}) @@ -135,30 +135,40 @@ def bench(self, launcher, *args, grid): def kernel_call(): if launcher.config.pre_hook is not None: - launcher.config.pre_hook({**zip(self.arg_names, args), **launcher.config.kwargs}) + launcher.config.pre_hook( + {**zip(self.arg_names, args), **launcher.config.kwargs} + ) launcher( *args, grid=grid, stream=stream, ) - return do_bench(kernel_call, rep=40, fast_flush=True) + from triton.testing import do_bench + + return do_bench(kernel_call) + + @staticmethod + def clone_preserve_strides(x): + needed_size = ( + sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1 + ) + buffer = torch.as_strided(x, (needed_size,), (1,)).clone() + return torch.as_strided(buffer, x.size(), x.stride()) def autotune_to_one_config(self, *args, **kwargs): """Do the actual autotuning""" - # clone inplace buffers to avoid autotune contaminating them if - # the kernel does in-place stores. avoid cloning other buffers because - # it leads to increase memory use - cloned_args = [] - for i, arg in enumerate(args): - if self.fn.arg_names[i] in self.mutated_arg_names: - assert isinstance(arg, torch.Tensor) - cloned_args.append(clone_preserve_strides(arg)) - else: - cloned_args.append(arg) - - timings = {launcher: self.bench(launcher, *cloned_args, **kwargs) for launcher in self.launchers} + # clone the input args to avoid autotune contaminating them if + # the kernel does in-place stores + cloned_args = [ + self.clone_preserve_strides(arg) if isinstance(arg, torch.Tensor) else arg + for arg in args + ] + timings = { + launcher: self.bench(launcher, *cloned_args, **kwargs) + for launcher in self.launchers + } self.launchers = [builtins.min(timings, key=timings.get)] def run(self, *args, grid, stream): @@ -170,7 +180,9 @@ def run(self, *args, grid, stream): (launcher,) = self.launchers if launcher.config.pre_hook is not None: - launcher.config.pre_hook({**zip(self.arg_names, args), **launcher.config.kwargs}) + launcher.config.pre_hook( + {**zip(self.arg_names, args), **launcher.config.kwargs} + ) try: result = launcher( *args, @@ -182,28 +194,65 @@ def run(self, *args, grid, stream): raise RuntimeError( """Consider updating Triton with `pip install -U "git+https://github.com/openai/triton@af76c989eb4799b015f8b288ccd8421558772e56#subdirectory=python"`""" - ) from e + ) else: raise e return result -def kernl_autotune( +def hash_configs(configs: List[Config]): + """ + Hash used to check for changes in configurations + """ + hasher = hashlib.sha256() + for cfg in configs: + hasher.update( + f"{sorted(cfg.kwargs.items())} {cfg.num_warps} {cfg.num_stages}\n".encode( + "utf-8" + ) + ) + return hasher.hexdigest() + + +def load_cached_autotuning( + cache_filename: str, configs_hash: str, configs: List[Config] +): + """ + Read a cached autotuning result from disk + """ + if not os.path.exists(cache_filename): + return None + + best_config = json.loads(open(cache_filename).read()) + if best_config.get("configs_hash") != configs_hash: + return None + + matching_configs = [ + cfg + for cfg in configs + if all(val == best_config.get(key) for key, val in cfg.kwargs.items()) + ] + if len(matching_configs) != 1: + return None + + return matching_configs[0] + + +def autotune( configs: List[Config], key: List[str], reset_to_zero: Optional[List[str]] = None, prune_configs_by: Optional[Dict] = None, ): """ - A copy of triton.autotune that calls our subclass. Our subclass - has additional debugging, error handling, and on-disk caching. + A copy of triton.autotune that calls our subclass. """ configs = unique_configs(configs) def decorator(fn): - return KernlAutotuner( - fn, configs=configs, key=key, reset_to_zero=reset_to_zero, prune_configs_by=prune_configs_by + return Autotuner( + fn, configs=configs, key=key, reset_to_zero=reset_to_zero, prune_configs_by=prune_configs_by, ) return decorator @@ -220,66 +269,34 @@ def unique_configs(configs: List[Config]): pruned_configs.append(cfg) return pruned_configs - -def clone_preserve_strides(x): - needed_size = sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1 - buffer = torch.as_strided(x, (needed_size,), (1,)).clone() - return torch.as_strided(buffer, x.size(), x.stride()) - - def grid(xnumel, ynumel=None, znumel=None): """Helper function to compute triton grids""" - def get_grid_dim(numel, block_name, block): - if numel is None: - return 1 - label = block_name[0] - if numel == 1: - assert block == 1, ( - f"TritonKernel.indexing assumes {label.lower()}numel == 1 => {block_name} == 1" - f"({label.lower()}numel=={numel}, {block_name}={block})." - ) - return cdiv(numel, block) - - def grid_fn(meta): - return ( - get_grid_dim(xnumel, "XBLOCK", meta.get("XBLOCK", None)), - get_grid_dim(ynumel, "YBLOCK", meta.get("YBLOCK", None)), - get_grid_dim(znumel, "ZBLOCK", meta.get("ZBLOCK", None)), - ) + if ynumel and znumel: - return grid_fn + def grid_fn(meta): + return ( + cdiv(xnumel, meta["XBLOCK"]), + cdiv(ynumel, meta["YBLOCK"]), + cdiv(znumel, meta["ZBLOCK"]), + ) + elif ynumel: -class KernlHeuristics(KernelInterface): - def __init__(self, fn, values) -> None: - self.fn = fn - self.values = values - signature = inspect.signature(fn) - self.arg_names = [v.name for v in signature.parameters.values()] - - def run(self, *args, **kwargs): - for v, heur in self.values.items(): - kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs}) - return self.fn.run(*args, **kwargs) + def grid_fn(meta): + return ( + cdiv(xnumel, meta["XBLOCK"]), + cdiv(ynumel, meta["YBLOCK"]), + 1, + ) + else: -def kernl_heuristics(values): - """ - Decorator for specifying how the values of certain meta-parameters may be computed. - This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable. - .. highlight:: python - .. code-block:: python - @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))}) - @triton.jit - def kernel(x_ptr, x_size, **META): - BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size - .param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter. - each such function takes a list of positional arguments as input. - .type values: dict[str, Callable[[list[Any]], Any]] - """ - - def decorator(fn): - return KernlHeuristics(fn, values) + def grid_fn(meta): + return ( + cdiv(xnumel, meta["XBLOCK"]), + 1, + 1, + ) - return decorator + return grid_fn diff --git a/src/kernl/implementations/linear_layer.py b/src/kernl/implementations/linear_layer.py index 6ed1e638..2e6725fe 100644 --- a/src/kernl/implementations/linear_layer.py +++ b/src/kernl/implementations/linear_layer.py @@ -25,7 +25,7 @@ from torch.cuda.amp import custom_fwd from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time -from kernl.autotune import kernl_autotune, kernl_heuristics +from kernl.autotune import autotune from kernl.implementations import activation_func @@ -58,7 +58,7 @@ def get_configs_io_bound(): return configs -@kernl_autotune( +@autotune( configs=[ triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), @@ -84,11 +84,6 @@ def get_configs_io_bound(): key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"], prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10}, ) -@kernl_heuristics( - { - "K_LOAD_MASK_NEEDED": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, - } -) def kernel_fma( C, # Pointers to matrices ACT_INPUTS, diff --git a/src/kernl/utils/autotuner_helper.py b/src/kernl/utils/autotuner_helper.py new file mode 100644 index 00000000..619a452c --- /dev/null +++ b/src/kernl/utils/autotuner_helper.py @@ -0,0 +1,50 @@ +import torch +import triton + + +def type_of(key): + if isinstance(key, (torch.dtype, triton.language.dtype)): + ty = { + torch.bool: "i1", + torch.float16: "fp16", + torch.bfloat16: "bf16", + torch.float32: "fp32", + torch.float64: "fp64", + torch.uint8: "u8", + torch.int8: "i8", + torch.int16: "i16", + torch.int32: "i32", + torch.int64: "i64", + triton.language.uint8: "u8", + triton.language.uint16: "u16", + triton.language.uint32: "u32", + triton.language.uint64: "u64", + triton.language.float8: "fp8", + }[key] + return f"*{ty}" + if key is None: + return "*i8" + assert isinstance(key, str) + return key + + +def key_of(arg): + if hasattr(arg, "dtype"): + return arg.dtype + elif isinstance(arg, bool): + return "i1" + elif isinstance(arg, int): + if -(2 ** 31) <= arg and arg <= 2 ** 31 - 1: + return "i32" + elif 2 ** 31 <= arg and arg <= 2 ** 32 - 1: + return "u32" + elif 2 ** 63 <= arg and arg <= 2 ** 64 - 1: + return "u64" + else: + return "i64" + elif isinstance(arg, float): + return "fp32" + elif arg is None: + return None + else: + raise TypeError(f"Unsupported type {type(arg)} for {arg}") diff --git a/test/test_linear_layer.py b/test/test_linear_layer.py index ba777d2d..eee61cd7 100644 --- a/test/test_linear_layer.py +++ b/test/test_linear_layer.py @@ -51,7 +51,7 @@ def get_pytorch_activation(activation: str) -> Callable: @pytest.mark.parametrize("activation", ["", "tanh", "gelu", "relu"], ids=["no_activation", "tanh", "gelu", "relu"]) @pytest.mark.parametrize( "shape", - [(1, 8, 8, 8)] + [(bs, M, 768, 768) for bs in [1, 16] for M in [8, 16, 128, 256, 512]], + [(1, 8, 8, 8)], ids=lambda s: "x".join(map(str, s)), ) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["fp32", "fp16"]) From b1198f898ffa62f20e83d9aafe4b1d6803cb0893 Mon Sep 17 00:00:00 2001 From: ayoub-louati Date: Tue, 21 Feb 2023 14:31:18 +0100 Subject: [PATCH 4/7] feat: fix autotune and add signature in autotune call --- src/kernl/autotune.py | 212 ++++++++++------------ src/kernl/implementations/linear_layer.py | 127 +++++++++++-- src/kernl/utils/autotuner_helper.py | 50 ----- test/test_linear_layer.py | 4 +- 4 files changed, 212 insertions(+), 181 deletions(-) delete mode 100644 src/kernl/utils/autotuner_helper.py diff --git a/src/kernl/autotune.py b/src/kernl/autotune.py index 6fc7eabb..78d5b4e6 100644 --- a/src/kernl/autotune.py +++ b/src/kernl/autotune.py @@ -1,23 +1,36 @@ +import ast import builtins -import copy -import hashlib -import json +import inspect import logging -import os.path import re +import textwrap import threading -from typing import List, Dict, Optional +from typing import Dict, List, Optional import torch - import triton -from triton import cdiv, Config -from triton.runtime.jit import get_cuda_stream, KernelInterface +from triton import Config, cdiv +from triton.runtime.jit import get_cuda_stream -from utils.autotuner_helper import type_of, key_of log = logging.getLogger(__name__) + +class KernelInterface: + def __getitem__(self, grid): + """ + A JIT function is launched with: fn[grid](*args, **kwargs). + Hence JITFunction.__getitem__ returns a callable proxy that + memorizes the grid. + """ + stream = get_cuda_stream(torch.cuda.current_device()) + + def launcher(*args, **kwargs): + return self.run(*args, grid=grid, stream=stream, **kwargs) + + return launcher + + class Autotuner(KernelInterface): """ Simplified version of Triton autotuner. @@ -25,16 +38,20 @@ class Autotuner(KernelInterface): configs, and does not rely on the Triton JIT. """ - def __init__(self, fn, configs, key, reset_to_zero, prune_configs_by: Dict = None): + def __init__(self, fn, configs, signature, key, reset_to_zero, prune_configs_by: Dict = None): super().__init__() self.launchers = [] if not configs: self.configs = [Config(dict(), num_warps=4, num_stages=2)] else: self.configs = configs - self.arg_names = fn.arg_names + self.signature = signature + fn_signature = inspect.signature(fn) + self.arg_names = [v.name for v in fn_signature.parameters.values()] self.key_idx = [self.arg_names.index(k) for k in key] self.cache = dict() + self.src = textwrap.dedent(inspect.getsource(fn)) + self.src = self.src[self.src.find("def") :] # hook to reset all required tensor to zeros before relaunching a kernel self.hook = lambda args: 0 if reset_to_zero is not None: @@ -43,43 +60,62 @@ def __init__(self, fn, configs, key, reset_to_zero, prune_configs_by: Dict = Non def _hook(args): for i in self.reset_idx: args[i].zero_() + self.hook = _hook if prune_configs_by: - perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k'] - if 'early_config_prune' in prune_configs_by: - early_config_prune = prune_configs_by['early_config_prune'] + perf_model, top_k = prune_configs_by["perf_model"], prune_configs_by["top_k"] + if "early_config_prune" in prune_configs_by: + early_config_prune = prune_configs_by["early_config_prune"] else: perf_model, top_k, early_config_prune = None, None, None self.perf_model, self.configs_top_k = perf_model, top_k self.early_config_prune = early_config_prune self.fn = fn + self.fn.cache_key = "test" self.__annotations__ = fn.__annotations__ # index of constexprs self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()] + self.fn.parse = self.parse + self.fn.src = self.src self.lock = threading.Lock() + def parse(self): + tree = ast.parse(self.src) + assert isinstance(tree, ast.Module) + assert len(tree.body) == 1 + assert isinstance(tree.body[0], ast.FunctionDef) + return tree + def precompile(self, warm_cache_only_with_cc=None): - breakpoint() with self.lock: if self.launchers: return - self.launchers = [ - self._precompile_config(c, warm_cache_only_with_cc) - for c in self.configs - ] + self.launchers = [self._precompile_config(c, warm_cache_only_with_cc) for c in self.configs] self.configs = None + @staticmethod + def is_divisible_by_16(x): + if hasattr(x, "data_ptr"): + return x.data_ptr() % 16 == 0 + elif isinstance(x, int): + return x % 16 == 0 + if x is None: + return True + return False + def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: int): """Ahead of time compile a given autotuner config.""" - all_args = {', '.join([f'{arg}' for arg in self.arg_names])}, - signature = {{i: type_of(key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs}} - compile_meta = { - "constants": dict(), "signature": signature - } - for k, v in cfg.kwargs.items(): - compile_meta["constants"][self.fn.arg_names.index(k)] = v + # make constants: + constexpr_args = [f"{arg}" for i, arg in enumerate(self.arg_names) if i in self.constexprs] + constants = {i: k for i, k in zip(self.constexprs, constexpr_args)} + for k, v in constants.items(): + constants[k] = cfg.kwargs[v] if v in cfg.kwargs.keys() else 1 + compile_meta = {"constants": constants} + compile_meta["signature"] = self.signature compile_meta["num_warps"] = cfg.num_warps compile_meta["num_stages"] = cfg.num_stages + cfg.divisible_by_16 = [i for i, arg in enumerate(self.arg_names) if self.is_divisible_by_16(arg)] + cfg.equal_to_1 = [i for i, arg in enumerate(self.arg_names) if isinstance(arg, int) and arg == 1] if warm_cache_only_with_cc: triton.compile( @@ -91,18 +127,16 @@ def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: int): return torch.cuda.set_device(torch.cuda.current_device()) + compile_meta["device"] = 0 binary = triton.compile( self.fn, + configs=[cfg], **compile_meta, ) - call_args = [ - arg - for i, arg in enumerate(self.fn.arg_names) - if i not in self.fn.constexprs - ] - def_args = list(self.fn.arg_names) + call_args = [arg for i, arg in enumerate(self.arg_names) if i not in self.constexprs and arg != "stream"] + def_args = list(self.arg_names) while def_args and def_args[-1] in cfg.kwargs: def_args.pop() @@ -117,10 +151,13 @@ def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: int): f""" def launcher({', '.join(def_args)}, grid, stream): # set_device(current_device()) # TODO(jansel): is this needed? - grid_0, grid_1, grid_2 = grid(grid_meta) - bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, - stream, bin.cu_function, None, None, None, - {', '.join(call_args)}) + if callable(grid): + grid = grid(grid_meta) + grid_size = len(grid) + grid_0 = grid[0] + grid_1 = grid[1] if grid_size > 1 else 1 + grid_2 = grid[2] if grid_size > 2 else 1 + bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, None, None, None, {', '.join(call_args)}) """.lstrip(), scope, ) @@ -129,20 +166,15 @@ def launcher({', '.join(def_args)}, grid, stream): launcher.config = cfg return launcher - def bench(self, launcher, *args, grid): + def bench(self, launcher, *args, grid, **kwargs): """Measure the performance of a given launcher""" - stream = get_cuda_stream(torch.cuda.current_device()) + + current = dict(**kwargs, **launcher.config.kwargs) def kernel_call(): if launcher.config.pre_hook is not None: - launcher.config.pre_hook( - {**zip(self.arg_names, args), **launcher.config.kwargs} - ) - launcher( - *args, - grid=grid, - stream=stream, - ) + launcher.config.pre_hook({**zip(self.arg_names, args), **launcher.config.kwargs}) + launcher(*args, grid=grid, **current) from triton.testing import do_bench @@ -150,9 +182,7 @@ def kernel_call(): @staticmethod def clone_preserve_strides(x): - needed_size = ( - sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1 - ) + needed_size = sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1 buffer = torch.as_strided(x, (needed_size,), (1,)).clone() return torch.as_strided(buffer, x.size(), x.stride()) @@ -161,34 +191,23 @@ def autotune_to_one_config(self, *args, **kwargs): # clone the input args to avoid autotune contaminating them if # the kernel does in-place stores - cloned_args = [ - self.clone_preserve_strides(arg) if isinstance(arg, torch.Tensor) else arg - for arg in args - ] - timings = { - launcher: self.bench(launcher, *cloned_args, **kwargs) - for launcher in self.launchers - } + cloned_args = [self.clone_preserve_strides(arg) if isinstance(arg, torch.Tensor) else arg for arg in args] + timings = {launcher: self.bench(launcher, *cloned_args, **kwargs) for launcher in self.launchers} self.launchers = [builtins.min(timings, key=timings.get)] - def run(self, *args, grid, stream): + def run(self, *args, grid, **kwargs): + stream = get_cuda_stream(torch.cuda.current_device()) if len(self.launchers) != 1: if len(self.launchers) == 0: self.precompile() if len(self.launchers) > 1: - self.autotune_to_one_config(*args, grid=grid) + self.autotune_to_one_config(*args, grid=grid, **kwargs) (launcher,) = self.launchers if launcher.config.pre_hook is not None: - launcher.config.pre_hook( - {**zip(self.arg_names, args), **launcher.config.kwargs} - ) + launcher.config.pre_hook({**zip(self.arg_names, args), **launcher.config.kwargs}) try: - result = launcher( - *args, - grid=grid, - stream=stream, - ) + result = launcher(*args, grid=grid, stream=stream, **kwargs) except TypeError as e: if re.match(r"function takes exactly \d+ arguments \(\d+ given\)", str(e)): raise RuntimeError( @@ -201,47 +220,10 @@ def run(self, *args, grid, stream): return result -def hash_configs(configs: List[Config]): - """ - Hash used to check for changes in configurations - """ - hasher = hashlib.sha256() - for cfg in configs: - hasher.update( - f"{sorted(cfg.kwargs.items())} {cfg.num_warps} {cfg.num_stages}\n".encode( - "utf-8" - ) - ) - return hasher.hexdigest() - - -def load_cached_autotuning( - cache_filename: str, configs_hash: str, configs: List[Config] -): - """ - Read a cached autotuning result from disk - """ - if not os.path.exists(cache_filename): - return None - - best_config = json.loads(open(cache_filename).read()) - if best_config.get("configs_hash") != configs_hash: - return None - - matching_configs = [ - cfg - for cfg in configs - if all(val == best_config.get(key) for key, val in cfg.kwargs.items()) - ] - if len(matching_configs) != 1: - return None - - return matching_configs[0] - - def autotune( configs: List[Config], key: List[str], + signature: Dict[int, str], reset_to_zero: Optional[List[str]] = None, prune_configs_by: Optional[Dict] = None, ): @@ -252,7 +234,12 @@ def autotune( def decorator(fn): return Autotuner( - fn, configs=configs, key=key, reset_to_zero=reset_to_zero, prune_configs_by=prune_configs_by, + fn, + configs=configs, + signature=signature, + key=key, + reset_to_zero=reset_to_zero, + prune_configs_by=prune_configs_by, ) return decorator @@ -269,6 +256,7 @@ def unique_configs(configs: List[Config]): pruned_configs.append(cfg) return pruned_configs + def grid(xnumel, ynumel=None, znumel=None): """Helper function to compute triton grids""" @@ -276,17 +264,17 @@ def grid(xnumel, ynumel=None, znumel=None): def grid_fn(meta): return ( - cdiv(xnumel, meta["XBLOCK"]), - cdiv(ynumel, meta["YBLOCK"]), - cdiv(znumel, meta["ZBLOCK"]), + cdiv(xnumel, meta["BLOCK_M"]), + cdiv(ynumel, meta["BLOCK_N"]), + cdiv(znumel, meta["BLOCK_K"]), ) elif ynumel: def grid_fn(meta): return ( - cdiv(xnumel, meta["XBLOCK"]), - cdiv(ynumel, meta["YBLOCK"]), + cdiv(xnumel, meta["BLOCK_M"]), + cdiv(ynumel, meta["BLOCK_N"]), 1, ) @@ -294,7 +282,7 @@ def grid_fn(meta): def grid_fn(meta): return ( - cdiv(xnumel, meta["XBLOCK"]), + cdiv(xnumel, meta["BLOCK_M"]), 1, 1, ) diff --git a/src/kernl/implementations/linear_layer.py b/src/kernl/implementations/linear_layer.py index 2e6725fe..f24837bb 100644 --- a/src/kernl/implementations/linear_layer.py +++ b/src/kernl/implementations/linear_layer.py @@ -60,29 +60,120 @@ def get_configs_io_bound(): @autotune( configs=[ - triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1, "K_LOAD_MASK_NEEDED": 1}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1, "K_LOAD_MASK_NEEDED": 1}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1, "K_LOAD_MASK_NEEDED": 1}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1, "K_LOAD_MASK_NEEDED": 1}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1, "K_LOAD_MASK_NEEDED": 1}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1, "K_LOAD_MASK_NEEDED": 1}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1, "K_LOAD_MASK_NEEDED": 1}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1, "K_LOAD_MASK_NEEDED": 1}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1, "K_LOAD_MASK_NEEDED": 1}, + num_stages=5, + num_warps=2, + ), # good for int8 - triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1, "K_LOAD_MASK_NEEDED": 1}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1, "K_LOAD_MASK_NEEDED": 1}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1, "K_LOAD_MASK_NEEDED": 1}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1, "K_LOAD_MASK_NEEDED": 1}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1, "K_LOAD_MASK_NEEDED": 1}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1, "K_LOAD_MASK_NEEDED": 1}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1, "K_LOAD_MASK_NEEDED": 1}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1, "K_LOAD_MASK_NEEDED": 1}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1, "K_LOAD_MASK_NEEDED": 1}, + num_stages=5, + num_warps=2, + ), ] + get_configs_io_bound(), key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"], prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10}, + signature={ + 0: "*fp32", + 1: "*i8", + 2: "*fp32", + 3: "*fp32", + 4: "*fp32", + 5: "i32", + 6: "i32", + 7: "i32", + 8: "i32", + 9: "i32", + 10: "i32", + 11: "i32", + 12: "i32", + 13: "i32", + 14: "i32", + 15: "i32", + 16: "i32", + }, ) def kernel_fma( C, # Pointers to matrices diff --git a/src/kernl/utils/autotuner_helper.py b/src/kernl/utils/autotuner_helper.py deleted file mode 100644 index 619a452c..00000000 --- a/src/kernl/utils/autotuner_helper.py +++ /dev/null @@ -1,50 +0,0 @@ -import torch -import triton - - -def type_of(key): - if isinstance(key, (torch.dtype, triton.language.dtype)): - ty = { - torch.bool: "i1", - torch.float16: "fp16", - torch.bfloat16: "bf16", - torch.float32: "fp32", - torch.float64: "fp64", - torch.uint8: "u8", - torch.int8: "i8", - torch.int16: "i16", - torch.int32: "i32", - torch.int64: "i64", - triton.language.uint8: "u8", - triton.language.uint16: "u16", - triton.language.uint32: "u32", - triton.language.uint64: "u64", - triton.language.float8: "fp8", - }[key] - return f"*{ty}" - if key is None: - return "*i8" - assert isinstance(key, str) - return key - - -def key_of(arg): - if hasattr(arg, "dtype"): - return arg.dtype - elif isinstance(arg, bool): - return "i1" - elif isinstance(arg, int): - if -(2 ** 31) <= arg and arg <= 2 ** 31 - 1: - return "i32" - elif 2 ** 31 <= arg and arg <= 2 ** 32 - 1: - return "u32" - elif 2 ** 63 <= arg and arg <= 2 ** 64 - 1: - return "u64" - else: - return "i64" - elif isinstance(arg, float): - return "fp32" - elif arg is None: - return None - else: - raise TypeError(f"Unsupported type {type(arg)} for {arg}") diff --git a/test/test_linear_layer.py b/test/test_linear_layer.py index eee61cd7..3b8ee683 100644 --- a/test/test_linear_layer.py +++ b/test/test_linear_layer.py @@ -21,7 +21,9 @@ from conftest import assert_all_close, set_seed from kernl.implementations.linear_layer import linear_layer -from kernl.optimizer.cuda_graph import cuda_graphs_wrapper + + +# from kernl.optimizer.cuda_graph import cuda_graphs_wrapper def get_pytorch_activation(activation: str) -> Callable: From f8ecb2bb04f2c7de441936b5a2dee31a4285adfb Mon Sep 17 00:00:00 2001 From: ayoub-louati Date: Wed, 22 Feb 2023 00:43:34 +0100 Subject: [PATCH 5/7] feat: use batched matmul kernl with new autotune --- src/kernl/autotune.py | 5 +---- src/kernl/implementations/batched_matmul.py | 22 +++++++++++++++++++-- test/test_linear_layer.py | 6 ++---- 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/src/kernl/autotune.py b/src/kernl/autotune.py index 78d5b4e6..550183d9 100644 --- a/src/kernl/autotune.py +++ b/src/kernl/autotune.py @@ -110,10 +110,7 @@ def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: int): constants = {i: k for i, k in zip(self.constexprs, constexpr_args)} for k, v in constants.items(): constants[k] = cfg.kwargs[v] if v in cfg.kwargs.keys() else 1 - compile_meta = {"constants": constants} - compile_meta["signature"] = self.signature - compile_meta["num_warps"] = cfg.num_warps - compile_meta["num_stages"] = cfg.num_stages + compile_meta = {"constants": constants, "signature": self.signature, "num_warps": cfg.num_warps, "num_stages": cfg.num_stages} cfg.divisible_by_16 = [i for i, arg in enumerate(self.arg_names) if self.is_divisible_by_16(arg)] cfg.equal_to_1 = [i for i, arg in enumerate(self.arg_names) if isinstance(arg, int) and arg == 1] diff --git a/src/kernl/implementations/batched_matmul.py b/src/kernl/implementations/batched_matmul.py index 294f0cdd..6dcf9844 100644 --- a/src/kernl/implementations/batched_matmul.py +++ b/src/kernl/implementations/batched_matmul.py @@ -17,11 +17,13 @@ import triton import triton.language as tl +from kernl.autotune import autotune + # CREDITS: Initially inspired by the Triton tutorial -@triton.autotune( +@autotune( configs=[ triton.Config( {"BLOCK_M_SIZE": 128, "BLOCK_N_SIZE": 256, "BLOCK_K_SIZE": 32, "GROUP_M_SIZE": 8}, num_stages=2, num_warps=8 @@ -55,8 +57,24 @@ ), ], key=["m_size", "n_size", "k_size"], + signature={ + 0: "*fp32", + 1: "*fp32", + 2: "*fp32", + 3: "i32", + 4: "i32", + 5: "i32", + 6: "i32", + 7: "i32", + 8: "i32", + 9: "i32", + 10: "i32", + 11: "i32", + 12: "i32", + 13: "i32", + 14: "i32", + } ) -@triton.jit def matmul_kernel( # Pointers to matrices a_ptr, diff --git a/test/test_linear_layer.py b/test/test_linear_layer.py index 3b8ee683..ba777d2d 100644 --- a/test/test_linear_layer.py +++ b/test/test_linear_layer.py @@ -21,9 +21,7 @@ from conftest import assert_all_close, set_seed from kernl.implementations.linear_layer import linear_layer - - -# from kernl.optimizer.cuda_graph import cuda_graphs_wrapper +from kernl.optimizer.cuda_graph import cuda_graphs_wrapper def get_pytorch_activation(activation: str) -> Callable: @@ -53,7 +51,7 @@ def get_pytorch_activation(activation: str) -> Callable: @pytest.mark.parametrize("activation", ["", "tanh", "gelu", "relu"], ids=["no_activation", "tanh", "gelu", "relu"]) @pytest.mark.parametrize( "shape", - [(1, 8, 8, 8)], + [(1, 8, 8, 8)] + [(bs, M, 768, 768) for bs in [1, 16] for M in [8, 16, 128, 256, 512]], ids=lambda s: "x".join(map(str, s)), ) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["fp32", "fp16"]) From 0e77f9f7bfbc1fe1b8dfdd50af8d1faaa934a835 Mon Sep 17 00:00:00 2001 From: ayoub-louati Date: Wed, 22 Feb 2023 17:11:19 +0100 Subject: [PATCH 6/7] feat: temporary fix for cache key --- src/kernl/autotune.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/kernl/autotune.py b/src/kernl/autotune.py index 550183d9..ae8e7771 100644 --- a/src/kernl/autotune.py +++ b/src/kernl/autotune.py @@ -3,6 +3,8 @@ import inspect import logging import re +import string +import random import textwrap import threading from typing import Dict, List, Optional @@ -71,7 +73,7 @@ def _hook(args): self.perf_model, self.configs_top_k = perf_model, top_k self.early_config_prune = early_config_prune self.fn = fn - self.fn.cache_key = "test" + self.fn.cache_key = ''.join(random.choice(string.printable) for i in range(20)) # TODO: fix the cache key self.__annotations__ = fn.__annotations__ # index of constexprs self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()] @@ -124,7 +126,7 @@ def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: int): return torch.cuda.set_device(torch.cuda.current_device()) - compile_meta["device"] = 0 + compile_meta["device"] = torch.cuda.current_device() binary = triton.compile( self.fn, From d52fe816b8b78ff4d74684262fbfd28a3e0087b8 Mon Sep 17 00:00:00 2001 From: ayoub-louati Date: Wed, 22 Feb 2023 17:30:09 +0100 Subject: [PATCH 7/7] feat: clean autotune --- src/kernl/autotune.py | 65 ++++++++++++------------------------------- 1 file changed, 18 insertions(+), 47 deletions(-) diff --git a/src/kernl/autotune.py b/src/kernl/autotune.py index ae8e7771..44d1c3bd 100644 --- a/src/kernl/autotune.py +++ b/src/kernl/autotune.py @@ -2,19 +2,18 @@ import builtins import inspect import logging +import random import re import string -import random import textwrap import threading from typing import Dict, List, Optional import torch import triton -from triton import Config, cdiv +from triton import Config from triton.runtime.jit import get_cuda_stream - log = logging.getLogger(__name__) @@ -53,7 +52,7 @@ def __init__(self, fn, configs, signature, key, reset_to_zero, prune_configs_by: self.key_idx = [self.arg_names.index(k) for k in key] self.cache = dict() self.src = textwrap.dedent(inspect.getsource(fn)) - self.src = self.src[self.src.find("def") :] + self.src = self.src[self.src.find("def"):] # hook to reset all required tensor to zeros before relaunching a kernel self.hook = lambda args: 0 if reset_to_zero is not None: @@ -73,7 +72,7 @@ def _hook(args): self.perf_model, self.configs_top_k = perf_model, top_k self.early_config_prune = early_config_prune self.fn = fn - self.fn.cache_key = ''.join(random.choice(string.printable) for i in range(20)) # TODO: fix the cache key + self.fn.cache_key = ''.join(random.choice(string.printable) for i in range(20)) self.__annotations__ = fn.__annotations__ # index of constexprs self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()] @@ -108,11 +107,13 @@ def is_divisible_by_16(x): def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: int): """Ahead of time compile a given autotuner config.""" # make constants: - constexpr_args = [f"{arg}" for i, arg in enumerate(self.arg_names) if i in self.constexprs] + constexpr_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i in self.constexprs] constants = {i: k for i, k in zip(self.constexprs, constexpr_args)} for k, v in constants.items(): - constants[k] = cfg.kwargs[v] if v in cfg.kwargs.keys() else 1 - compile_meta = {"constants": constants, "signature": self.signature, "num_warps": cfg.num_warps, "num_stages": cfg.num_stages} + if v in cfg.kwargs.keys(): + constants[k] = cfg.kwargs[v] + compile_meta = {"constants": constants, "signature": self.signature, "num_warps": cfg.num_warps, + "num_stages": cfg.num_stages} cfg.divisible_by_16 = [i for i, arg in enumerate(self.arg_names) if self.is_divisible_by_16(arg)] cfg.equal_to_1 = [i for i, arg in enumerate(self.arg_names) if isinstance(arg, int) and arg == 1] @@ -125,8 +126,9 @@ def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: int): ) return - torch.cuda.set_device(torch.cuda.current_device()) - compile_meta["device"] = torch.cuda.current_device() + current_device = torch.cuda.current_device() + torch.cuda.set_device(current_device) + compile_meta["device"] = current_device binary = triton.compile( self.fn, @@ -168,12 +170,14 @@ def launcher({', '.join(def_args)}, grid, stream): def bench(self, launcher, *args, grid, **kwargs): """Measure the performance of a given launcher""" - current = dict(**kwargs, **launcher.config.kwargs) - def kernel_call(): if launcher.config.pre_hook is not None: launcher.config.pre_hook({**zip(self.arg_names, args), **launcher.config.kwargs}) - launcher(*args, grid=grid, **current) + launcher( + *args, + grid=grid, + **kwargs + ) from triton.testing import do_bench @@ -206,7 +210,7 @@ def run(self, *args, grid, **kwargs): if launcher.config.pre_hook is not None: launcher.config.pre_hook({**zip(self.arg_names, args), **launcher.config.kwargs}) try: - result = launcher(*args, grid=grid, stream=stream, **kwargs) + result = launcher(*args, grid=grid, **kwargs) except TypeError as e: if re.match(r"function takes exactly \d+ arguments \(\d+ given\)", str(e)): raise RuntimeError( @@ -254,36 +258,3 @@ def unique_configs(configs: List[Config]): seen.add(key) pruned_configs.append(cfg) return pruned_configs - - -def grid(xnumel, ynumel=None, znumel=None): - """Helper function to compute triton grids""" - - if ynumel and znumel: - - def grid_fn(meta): - return ( - cdiv(xnumel, meta["BLOCK_M"]), - cdiv(ynumel, meta["BLOCK_N"]), - cdiv(znumel, meta["BLOCK_K"]), - ) - - elif ynumel: - - def grid_fn(meta): - return ( - cdiv(xnumel, meta["BLOCK_M"]), - cdiv(ynumel, meta["BLOCK_N"]), - 1, - ) - - else: - - def grid_fn(meta): - return ( - cdiv(xnumel, meta["BLOCK_M"]), - 1, - 1, - ) - - return grid_fn