Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions aiter/ops/flydsl/gemm_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from ..shuffle import shuffle_weight
from .kernels.splitk_hgemm import compile_hgemm_kernel
from .kernels.tensor_shim import _run_compiled
from .utils import get_shared_memory_per_block, is_flydsl_available

__all__ = [
Expand Down Expand Up @@ -571,10 +572,8 @@ def launcher(
_check_split_k_counter_capacity(runtime_m, n, tile_m, tile_n, split_k)
launch_stream = _normalize_launch_stream(a.device, stream)
semaphore = _get_split_k_global_semaphore(launch_stream)
exe_compiled = kernel.compile(
out, a, b, runtime_m, semaphore, signal_state, stream
)
return exe_compiled(
return _run_compiled(
kernel,
out,
a,
b,
Expand Down
35 changes: 1 addition & 34 deletions aiter/ops/flydsl/kernels/gdr_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,37 +420,4 @@ def launch_gdr_decode_kernel(
batch_size,
).launch(grid=(gx, 1, 1), block=(BLOCK_THREADS, 1, 1), stream=stream)

_compile_hints = {}

def _launch(*args, **kwargs):
with CompilationContext.compile_hints(_compile_hints):
return launch_gdr_decode_kernel(*args, **kwargs)

_compile_cache = {}

def _compile(
query, key, value, a, b, dt_bias, A_log, indices, state, out, batch_size, stream
):
with CompilationContext.compile_hints(_compile_hints):
lookup_key = (query.dtype, batch_size)
if _compile_cache.get(lookup_key, None) is None:
_compile_cache[lookup_key] = flyc.compile(
launch_gdr_decode_kernel,
query,
key,
value,
a,
b,
dt_bias,
A_log,
indices,
state.clone(),
out,
batch_size,
stream,
)
return _compile_cache[lookup_key]

_launch.compile = _compile

return _launch
return launch_gdr_decode_kernel
33 changes: 1 addition & 32 deletions aiter/ops/flydsl/kernels/splitk_hgemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,35 +917,4 @@ def launch_hgemm_kernel(
stream=stream,
)

_compile_hints = {
"llvm_options": {
"enable-post-misched": False,
"lsr-drop-solution": True,
},
}

def _launch(*args, **kwargs):
with CompilationContext.compile_hints(_compile_hints):
return launch_hgemm_kernel(*args, **kwargs)

_compile_cache = {}

def _compile(C, A, B, m, COUNTER, signal_state, stream):
with CompilationContext.compile_hints(_compile_hints):
if _compile_cache.get(m, None) is None:
try:
_compile_cache[m] = flyc.compile(
launch_hgemm_kernel, C, A, B, m, COUNTER, signal_state, stream
)
except Exception as e:
raise RuntimeError(
f"{KERNEL_NAME} failed "
f"(arch={GPU_ARCH}, n={n}, k={k}, TILE_M={TILE_M}, TILE_N={TILE_N}, "
f"TILE_K={TILE_K}, SPLIT_K={SPLIT_K}, B_TO_LDS={B_TO_LDS}, "
f"SMEM_USE={SMEM_USE}, SMEM_LIMIT={smem_limit}): {e}",
) from e
return _compile_cache[m]

_launch.compile = _compile

return _launch
return launch_hgemm_kernel
13 changes: 13 additions & 0 deletions aiter/ops/flydsl/kernels/tensor_shim.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import torch
import numpy as np
import flydsl.compiler as flyc
from itertools import product
from abc import ABC, abstractmethod

Expand All @@ -12,6 +13,18 @@
from flydsl.expr import buffer_ops, range_constexpr, vector


def _run_compiled(exe, *args):
"""First call: ``flyc.compile(exe, *args)`` compiles **and** executes the kernel.
Subsequent calls: fast dispatch via the cached ``CompiledFunction``.
"""
cf = getattr(exe, "_cf", None)
if cf is None:
cf = flyc.compile(exe, *args)
exe._cf = cf
else:
cf(*args)


def _to_raw(v):
"""Convert ArithValue / Numeric (Int32, Boolean, …) to raw ir.Value."""
if isinstance(v, ir.Value):
Expand Down
19 changes: 3 additions & 16 deletions aiter/ops/flydsl/linear_attention_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


from .kernels.gdr_decode import create_shuffle_gdr_decode_kernel
from .kernels.tensor_shim import get_dtype_str
from .kernels.tensor_shim import get_dtype_str, _run_compiled

__all__ = [
"flydsl_gdr_decode",
Expand Down Expand Up @@ -72,21 +72,8 @@ def flydsl_gdr_decode(
use_qk_l2norm,
**kwargs,
)
exe_compiled = exe.compile(
query,
key,
value,
a,
b,
dt_bias,
A_log,
indices,
state_,
out,
batch_size,
stream,
)
exe_compiled(
_run_compiled(
exe,
query,
key,
value,
Expand Down
Loading