Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions python/flydsl/compiler/backends/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,22 @@ def _format_pass_opts(opts: dict) -> str:
"""Format {key: value, ...} as 'key=value key2=value2' for MLIR pass options."""
return " ".join(f"{k}={v}" for k, v in opts.items())

def _pipeline_parts(self, *, compile_hints: dict) -> Tuple[List[str], str]:
chip = self.target.arch
def _bin_cli_opts(self, *, compile_hints: dict) -> List[str]:
    """Collect the extra command-line options passed to the binary step.

    Args:
        compile_hints: Per-kernel hint mapping; only ``waves_per_eu`` and
            ``maxnreg`` are read here.

    Returns:
        A new list containing ``-g`` when ``env.debug.enable_debug_info`` is
        truthy, followed by one option for each of the two hints that is
        present and truthy (in that order).
    """
    occupancy_hint = compile_hints.get("waves_per_eu")
    register_cap = compile_hints.get("maxnreg")

    options: List[str] = ["-g"] if env.debug.enable_debug_info else []
    # Falsy hints (missing, None, 0) contribute nothing.
    if occupancy_hint:
        options += [f"--amdgpu-waves-per-eu={occupancy_hint}"]
    if register_cap:
        options += [f"--amdgpu-num-vgpr={register_cap}"]
    return options

rocdl_opts = {
"O": 2,
def _rocdl_opts(self, *, compile_hints: dict, opt_level: int = 2) -> dict:
chip = self.target.arch
return {
"O": opt_level,
"abi": 600,
"chip": chip,
"correct-sqrt": "true",
Expand All @@ -61,6 +62,24 @@ def _pipeline_parts(self, *, compile_hints: dict) -> Tuple[List[str], str]:
"wave64": "false" if is_rdna_arch(chip) else "true",
}

def llvm_recodegen_fragments(self, *, compile_hints: dict, opt_level: int = 0) -> Tuple[str, str]:
    """Build the pass fragments that re-codegen a target-less ``gpu.module``.

    The module is expected to already be in the LLVM dialect with no target
    attached. The first fragment attaches a ROCDL target configured with
    ``opt_level``; the second emits the device binary as a fatbin. The
    default ``opt_level`` of 0 serves the custom-LLVM-pass path, which has
    already run its own ``opt`` pipeline and should not be optimized again.

    Args:
        compile_hints: Per-kernel hints forwarded to ``_rocdl_opts`` and
            ``_bin_cli_opts``.
        opt_level: Value forwarded to ``_rocdl_opts``.

    Returns:
        ``(attach_fragment, binary_fragment)`` as pass-pipeline strings.
    """
    target_opts = self._format_pass_opts(
        self._rocdl_opts(compile_hints=compile_hints, opt_level=opt_level)
    )
    cli_opts = " ".join(self._bin_cli_opts(compile_hints=compile_hints))
    return (
        f"rocdl-attach-target{{{target_opts}}}",
        f'gpu-module-to-binary{{format=fatbin opts="{cli_opts}"}}',
    )

def _pipeline_parts(self, *, compile_hints: dict) -> Tuple[List[str], str]:
chip = self.target.arch

bin_cli_opts = self._bin_cli_opts(compile_hints=compile_hints)
rocdl_opts = self._rocdl_opts(compile_hints=compile_hints, opt_level=2)

pre_binary_fragments = [
"fly-rewrite-func-signature",
"fly-canonicalize",
Expand Down
130 changes: 130 additions & 0 deletions python/flydsl/compiler/external_llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,133 @@ def run_mlir_opt(*, pass_pipeline: str, input_path: Path, output_path: Path) ->
finally:
if tmp_dir_obj is not None:
tmp_dir_obj.cleanup()


def llvm_opt_fingerprint(pipeline: str, plugins: Optional[list] = None) -> str:
    """Return a cache fingerprint for a custom LLVM-opt configuration.

    The result is ``llvm-opt:<pipeline>`` followed by one ``<path>:<hash>``
    entry per plugin, joined with ``;``. The hash comes from ``_file_hash``
    applied to the resolved plugin path, so changing either the pipeline or
    a plugin file changes the fingerprint.

    Args:
        pipeline: The pass pipeline string.
        plugins: Plugin paths; ``None`` or empty contributes no entries.

    Returns:
        The fingerprint string.
    """

    def describe(plugin) -> str:
        location = Path(plugin).expanduser()
        try:
            digest = _file_hash(location.resolve())
        except OSError:
            # An unreadable plugin still gets a stable entry instead of failing.
            return f"{location}:<missing>"
        return f"{location}:{digest}"

    return ";".join([f"llvm-opt:{pipeline}", *(describe(p) for p in plugins or [])])


def _run_tool(cmd: list, *, prefix: Path, what: str, work_dir: Path, timeout: int = 600) -> None:
    """Run an external tool and convert failures into ``ExternalLLVMError``.

    Args:
        cmd: Argument vector, executed directly (no shell).
        prefix: Install prefix passed to ``_subprocess_env`` to build the
            child environment; also reported when the tool fails.
        what: Human-readable name of the step, used in error messages.
        work_dir: Directory holding the intermediate files; only reported in
            error messages so the user can inspect them.
        timeout: Maximum run time in seconds. Defaults to 600; the same
            value is used for the subprocess and for the error text.

    Raises:
        ExternalLLVMError: If the tool times out or exits non-zero. The
            message includes the command line and, for a non-zero exit, the
            captured stdout and stderr.
    """
    try:
        subprocess.run(
            cmd,
            check=True,
            capture_output=True,
            text=True,
            timeout=timeout,
            env=_subprocess_env(prefix),
        )
    except subprocess.TimeoutExpired as exc:
        raise ExternalLLVMError(
            f"{what} timed out after {timeout}s.\ncommand: {' '.join(cmd)}\nwork_dir: {work_dir}"
        ) from exc
    except subprocess.CalledProcessError as exc:
        raise ExternalLLVMError(
            f"{what} failed.\nllvm_dir: {prefix}\ncommand: {' '.join(cmd)}\n"
            f"work_dir: {work_dir}\nstdout:\n{exc.stdout}\nstderr:\n{exc.stderr}"
        ) from exc


def run_llvm_opt_then_binary(
    module: ir.Module,
    *,
    llvm_ir: str,
    attach_fragment: str,
    binary_fragment: str,
    pipeline: str,
    plugins: Optional[list] = None,
    llvm_options: Optional[dict] = None,
    work_dir: Optional[Path] = None,
    stage_prefix: str = "llvm_opt",
) -> None:
    """Run a custom LLVM pass pipeline on device IR and rebuild the binary.

    Steps, each run as an external tool from the configured LLVM prefix:

    1. ``opt --passes=<pipeline>`` (plus ``--load-pass-plugin`` per plugin)
       on ``llvm_ir``.
    2. ``mlir-translate --import-llvm`` on the result.
    3. The imported ops are wrapped in a ``gpu.module`` named after the
       original one, re-applying its ``llvm.data_layout`` when readable.
    4. ``mlir-opt`` runs ``attach_fragment`` then ``binary_fragment`` on
       the wrapped module.
    5. The produced module is parsed and handed to
       ``_replace_gpu_module_with_binary_op`` to update *module* in place.

    Args:
        module: Module containing exactly one top-level ``gpu.module``.
        llvm_ir: Textual LLVM IR of the device kernel.
        attach_fragment: Pass fragment that attaches the target.
        binary_fragment: Pass fragment that emits the device binary.
        pipeline: Pass pipeline string handed to ``opt --passes``.
        plugins: Pass-plugin paths; ``~`` is expanded.
        llvm_options: Extra options, formatted by
            ``_format_llvm_cli_options`` and passed to both ``opt`` and
            ``mlir-opt``.
        work_dir: Directory for intermediate files (created if needed).
            When ``None`` a temporary directory is used and removed on exit.
        stage_prefix: File-name prefix for the intermediate files.

    Raises:
        ExternalLLVMError: If ``mlir-opt`` produces no output file; errors
            from ``_run_tool`` propagate unchanged.
    """
    llvm_prefix = _llvm_dir()
    opt_exe = _tool(llvm_prefix, "opt")
    translate_exe = _tool(llvm_prefix, "mlir-translate")
    mlir_opt_exe = _tool(llvm_prefix, "mlir-opt")

    device_module = _single_top_level_op(module, "gpu.module")
    name = _symbol_name(device_module)

    # Best effort: an unreadable data-layout attribute is simply not re-applied.
    data_layout = None
    if "llvm.data_layout" in device_module.attributes:
        try:
            data_layout = ir.StringAttr(device_module.attributes["llvm.data_layout"]).value
        except Exception:
            data_layout = None

    extra_cli_args = [] if not llvm_options else _format_llvm_cli_options(llvm_options)

    # Only a directory we created ourselves is cleaned up at the end.
    scratch_dir = None
    if work_dir is not None:
        work_dir.mkdir(parents=True, exist_ok=True)
    else:
        scratch_dir = tempfile.TemporaryDirectory(prefix="flydsl_llvm_opt_")
        work_dir = Path(scratch_dir.name)

    in_ll = work_dir / f"{stage_prefix}_pre_opt.ll"
    out_ll = work_dir / f"{stage_prefix}_post_opt.ll"
    imported_path = work_dir / f"{stage_prefix}_imported.mlir"
    wrapped_path = work_dir / f"{stage_prefix}_wrapped.mlir"
    bin_path = work_dir / f"{stage_prefix}_binary.mlir"

    try:
        in_ll.write_text(llvm_ir, encoding="utf-8")

        # Step 1: user pipeline on the textual LLVM IR.
        opt_cmd = [str(opt_exe), str(in_ll), "-S", f"--passes={pipeline}"]
        opt_cmd.extend(f"--load-pass-plugin={Path(p).expanduser()}" for p in plugins or [])
        opt_cmd.extend(extra_cli_args)
        opt_cmd += ["-o", str(out_ll)]
        _run_tool(opt_cmd, prefix=llvm_prefix, what="LLVM opt pass pipeline", work_dir=work_dir)

        # Step 2: back into the MLIR LLVM dialect.
        translate_cmd = [str(translate_exe), "--import-llvm", str(out_ll), "-o", str(imported_path)]
        _run_tool(translate_cmd, prefix=llvm_prefix, what="mlir-translate --import-llvm", work_dir=work_dir)

        # Step 3: wrap the imported ops in a gpu.module without a target
        # (attach_fragment adds it) so the binary pass yields gpu.binary @<name>.
        imported = ir.Module.parse(imported_path.read_text(encoding="utf-8"), context=module.context)
        body = "\n".join(op.operation.get_asm() for op in imported.body.operations)
        dl_attr = f' attributes {{llvm.data_layout = "{data_layout}"}}' if data_layout else ""
        wrapped_text = (
            f"module attributes {{gpu.container_module}} {{\n"
            f" gpu.module @{name}{dl_attr} {{\n{body}\n }}\n}}\n"
        )
        wrapped_path.write_text(wrapped_text, encoding="utf-8")

        # Step 4: attach the target and emit the device binary.
        codegen_cmd = [
            str(mlir_opt_exe),
            str(wrapped_path),
            f"--pass-pipeline=builtin.module({attach_fragment},{binary_fragment})",
        ]
        codegen_cmd.extend(extra_cli_args)
        codegen_cmd += ["-o", str(bin_path)]
        _run_tool(
            codegen_cmd,
            prefix=llvm_prefix,
            what="external gpu-module-to-binary codegen",
            work_dir=work_dir,
        )

        # Step 5: splice the produced binary back into the caller's module.
        if not bin_path.is_file():
            raise ExternalLLVMError(f"external codegen did not create output file: {bin_path}")
        binary_module = ir.Module.parse(bin_path.read_text(encoding="utf-8"), context=module.context)
        _replace_gpu_module_with_binary_op(module, binary_module)
    finally:
        if scratch_dir is not None:
            scratch_dir.cleanup()
124 changes: 120 additions & 4 deletions python/flydsl/compiler/jit_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ def _create_mlir_context(*, load_dialects=True):
"HSA_OVERRIDE_GFX_VERSION",
"FLYDSL_DEBUG_ENABLE_DEBUG_INFO",
"FLYDSL_EXTRA_SOURCE_DIRS",
"FLYDSL_COMPILE_LLVM_PASS_PIPELINE",
"FLYDSL_COMPILE_LLVM_PASS_PLUGINS",
)


Expand Down Expand Up @@ -745,6 +747,20 @@ class PipelineConfig:
binary_fragment: Optional[str]
llvm_opts: Optional[dict]
external: bool
llvm_pass_pipeline: str = ""
llvm_pass_plugins: Optional[list] = None


def _effective_llvm_pass_config(hints: dict):
"""Resolve the custom LLVM pass pipeline + plugins, preferring the
@flyc.jit compile_hints over the FLYDSL_COMPILE_LLVM_PASS_* env vars."""
pipeline = hints.get("llvm_pass_pipeline")
if pipeline is None:
pipeline = env.compile.llvm_pass_pipeline
plugins = hints.get("llvm_pass_plugins")
if plugins is None:
plugins = env.compile.llvm_pass_plugins
return (pipeline or "").strip(), list(plugins or [])


def _pipeline_fragments_for_mode(backend) -> PipelineConfig:
Expand All @@ -753,6 +769,22 @@ def _pipeline_fragments_for_mode(backend) -> PipelineConfig:

hints = CompilationContext.get_compile_hints()
llvm_opts = hints.get("llvm_options")
llvm_pass_pipeline, llvm_pass_plugins = _effective_llvm_pass_config(hints)

# Custom LLVM pass pipeline: split off the binary fragment so we can extract
# LLVM IR, run `opt`, and re-codegen externally (see MlirCompiler.compile).
if llvm_pass_pipeline:
pre_binary_fragments, binary_fragment = backend.external_binary_pipeline_fragments(compile_hints=hints)
return PipelineConfig(
fragments=[*pre_binary_fragments, binary_fragment],
pre_binary=pre_binary_fragments,
binary_fragment=binary_fragment,
llvm_opts=llvm_opts,
external=False,
llvm_pass_pipeline=llvm_pass_pipeline,
llvm_pass_plugins=llvm_pass_plugins,
)

if _use_external_binary_codegen():
pre_binary_fragments, binary_fragment = backend.external_binary_pipeline_fragments(compile_hints=hints)
return PipelineConfig(
Expand Down Expand Up @@ -809,6 +841,9 @@ def compile(
"use embedded codegen for kernels that require #fly.explicit_module."
)

if cfg.llvm_pass_pipeline and link_libs:
raise RuntimeError("custom llvm_pass_pipeline does not support extern link_libs yet.")

if link_libs:
link_opt = _format_link_lib_options(link_libs)
fragments, found_attach_target = _append_link_lib_options_to_attach_targets(fragments, link_opt)
Expand All @@ -826,6 +861,10 @@ def compile(
dump_dir = Path(env.debug.dump_dir).resolve()

with _llvm_ctx:
if cfg.llvm_pass_pipeline:
return cls._compile_with_llvm_opt(
module, backend, cfg, func_name=func_name, dump_enabled=dump_enabled, dump_dir=dump_dir
)
if dump_enabled:
asm = module.operation.get_asm(enable_debug_info=True)
kernel_names = _infer_kernel_names_from_asm(asm)
Expand Down Expand Up @@ -937,6 +976,53 @@ def compile(

return module

@classmethod
def _compile_with_llvm_opt(
    cls, module, backend, cfg, *, func_name: str, dump_enabled: bool, dump_dir: Path
) -> ir.Module:
    """Compile *module* through the custom-LLVM-pass path.

    The pre-binary fragments run in-process, the device LLVM IR is
    extracted, and ``run_llvm_opt_then_binary`` applies the user's ``opt``
    pipeline (with plugins), re-codegens the binary externally and splices
    it back into *module*.

    Args:
        module: Module to compile; modified in place.
        backend: Backend providing ``llvm_recodegen_fragments``.
        cfg: Pipeline configuration; ``pre_binary``, ``llvm_pass_pipeline``,
            ``llvm_pass_plugins`` and ``llvm_opts`` are read.
        func_name: Fallback dump sub-directory name.
        dump_enabled: When true, intermediate files are kept under
            ``dump_dir`` instead of a temporary directory.
        dump_dir: Root directory for dumps.

    Returns:
        The same *module*, after verification.

    Raises:
        FlyDSLCompileError: If no device LLVM IR could be extracted.
    """
    from .external_llvm import run_llvm_opt_then_binary
    from .kernel_function import CompilationContext

    hints = CompilationContext.get_compile_hints()

    work_dir = None
    if dump_enabled:
        kernel_names = _infer_kernel_names_from_asm(module.operation.get_asm(enable_debug_info=True))
        # Name the dump directory after the kernel only when it is unambiguous.
        if len(kernel_names) == 1:
            subdir = kernel_names[0]
        else:
            subdir = func_name or "module"
        work_dir = dump_dir / _sanitize_path_component(subdir)
        print(f"[flydsl.compile] FLYDSL_DUMP_IR=1 (llvm_pass_pipeline) dir={work_dir}")

    # In-process part: everything up to, but not including, gpu-module-to-binary.
    _run_pipeline(
        module,
        cfg.pre_binary,
        verifier=env.debug.enable_verifier,
        print_after_all=env.debug.print_after_all,
    )

    llvm_ir = _extract_llvm_ir(module)
    if llvm_ir is None:
        raise FlyDSLCompileError(
            "llvm_pass_pipeline is set but the device LLVM IR could not be extracted from the gpu.module."
        )

    fragments = backend.llvm_recodegen_fragments(compile_hints=hints, opt_level=0)
    attach_fragment, binary_fragment = fragments
    run_llvm_opt_then_binary(
        module,
        llvm_ir=llvm_ir,
        attach_fragment=attach_fragment,
        binary_fragment=binary_fragment,
        pipeline=cfg.llvm_pass_pipeline,
        plugins=cfg.llvm_pass_plugins,
        llvm_options=cfg.llvm_opts,
        work_dir=work_dir,
    )
    module.operation.verify()
    return module


class JitCacheManager:
"""Directory-based cache manager with multi-process safety.
Expand Down Expand Up @@ -1344,6 +1430,14 @@ def _resolve_and_make_cache_key(self, bound_args):
key_parts = [("_env_", _cache_invalidating_env_values()), ("_target_", self._backend_target)]
if self.compile_hints:
key_parts.append(("_hints_", tuple(sorted((k, str(v)) for k, v in self.compile_hints.items()))))
# Fold the effective custom LLVM pass pipeline + plugin content hashes
# (from hints or env) into the key so editing a plugin .so or the
# pipeline invalidates cached artifacts.
eff_pipeline, eff_plugins = _effective_llvm_pass_config(self.compile_hints)
if eff_pipeline:
from .external_llvm import llvm_opt_fingerprint

key_parts.append(("_llvm_pass_", llvm_opt_fingerprint(eff_pipeline, eff_plugins)))

for name, arg in bound_args.items():
param = sig.parameters.get(name)
Expand Down Expand Up @@ -1661,11 +1755,33 @@ def _ensure_stream_arg(jit_args: list) -> bool:
return False


def jit(
    func: Optional[Callable] = None,
    *,
    llvm_pass_pipeline: Optional[str] = None,
    llvm_pass_plugins: Optional[list] = None,
) -> JitFunction:
    """JIT decorator for host launcher functions.

    Usable both bare (``@jit``) and with keyword options (``@jit(...)``);
    in the second form a decorator is returned that carries the same hints.

    ``llvm_pass_pipeline``: optional LLVM new-PM pass pipeline (e.g.
    ``"default<O3>,my-pass"``) run on the device kernel IR before codegen.
    ``llvm_pass_plugins``: optional list of LLVM pass plugin ``.so`` paths
    loaded (``opt --load-pass-plugin``) before running that pipeline. Both
    require ``FLYDSL_COMPILE_LLVM_DIR`` and override the
    ``FLYDSL_COMPILE_LLVM_PASS_*`` env vars.
    """
    hints = {}
    if llvm_pass_pipeline is not None:
        hints["llvm_pass_pipeline"] = llvm_pass_pipeline
    if llvm_pass_plugins is not None:
        # A lone path string is kept as one plugin; list("a.so") would
        # split it into single characters.
        if isinstance(llvm_pass_plugins, str):
            hints["llvm_pass_plugins"] = [llvm_pass_plugins]
        else:
            hints["llvm_pass_plugins"] = list(llvm_pass_plugins)

    def _make(f: Callable) -> JitFunction:
        # No hints at all is passed on as None rather than an empty dict.
        return JitFunction(f, compile_hints=hints or None)

    # Both call forms go through _make so the hints are never dropped.
    if func is None:
        return _make
    return _make(func)


class CompiledFunction:
Expand Down
12 changes: 12 additions & 0 deletions python/flydsl/utils/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,18 @@ class CompileEnvManager(EnvManager):
arch = OptStr("", env_var="ARCH", description="Override target GPU architecture (e.g. gfx942, gfx950)")
backend = OptStr("rocm", description="GPU compile backend id (e.g. rocm)")
llvm_dir = OptStr("", description="External LLVM/MLIR install prefix for final code generation")
llvm_pass_pipeline = OptStr(
"",
description="Custom LLVM new-PM pass pipeline run on the device kernel IR before codegen "
"(e.g. 'default<O3>,my-pass'); requires FLYDSL_COMPILE_LLVM_DIR. Overridden by "
"@flyc.jit(llvm_pass_pipeline=...).",
)
llvm_pass_plugins = OptList(
[],
separator=":",
description="Colon-separated LLVM pass plugin .so paths loaded (opt --load-pass-plugin) "
"before running llvm_pass_pipeline. Overridden by @flyc.jit(llvm_pass_plugins=...).",
)


class DebugEnvManager(EnvManager):
Expand Down
Loading
Loading