diff --git a/python/flydsl/compiler/backends/rocm.py b/python/flydsl/compiler/backends/rocm.py index c32a328bf..d92eaa0a0 100644 --- a/python/flydsl/compiler/backends/rocm.py +++ b/python/flydsl/compiler/backends/rocm.py @@ -3,7 +3,7 @@ from typing import List, Tuple -from ...runtime.device import get_rocm_arch, is_rdna_arch +from ...runtime.device import get_rocm_arch, get_rocm_toolkit_path, is_rdna_arch from ...utils import env from .base import BaseBackend, GPUTarget @@ -90,7 +90,9 @@ def _pipeline_parts(self, *, compile_hints: dict) -> Tuple[List[str], str]: else [] ), ] - binary_fragment = f'gpu-module-to-binary{{format=fatbin opts="{" ".join(bin_cli_opts)}"}}' + toolkit_path = get_rocm_toolkit_path() or "" + toolkit_opt = f" toolkit={toolkit_path}" if toolkit_path else "" + binary_fragment = f'gpu-module-to-binary{{format=fatbin opts="{" ".join(bin_cli_opts)}"{toolkit_opt}}}' return [*pre_binary_fragments, *binary_prep_fragments], binary_fragment def pipeline_fragments(self, *, compile_hints: dict) -> List[str]: diff --git a/python/flydsl/compiler/jit_function.py b/python/flydsl/compiler/jit_function.py index 667cb5845..f06d0c8d3 100644 --- a/python/flydsl/compiler/jit_function.py +++ b/python/flydsl/compiler/jit_function.py @@ -483,8 +483,11 @@ def _dump_isa(*, dump_dir: Path, ctx: ir.Context, asm: str, verify: bool, stage_ di_pass = ( "ensure-debug-info-scope-on-llvm-func{emission-kind=LineTablesOnly}," if env.debug.enable_debug_info else "" ) + from ..runtime.device import get_rocm_toolkit_path + + toolkit_path = get_rocm_toolkit_path() or "" pm = PassManager.parse( - f'builtin.module({di_pass}gpu-module-to-binary{{format=isa opts="{"-g" if env.debug.enable_debug_info else ""}" section= toolkit=}})', + f'builtin.module({di_pass}gpu-module-to-binary{{format=isa opts="{"-g" if env.debug.enable_debug_info else ""}" section= toolkit={toolkit_path}}})', context=ctx, ) pm.enable_verifier(bool(verify)) diff --git a/python/flydsl/runtime/device.py b/python/flydsl/runtime/device.py index ed5fd47b2..1d6080e17 100644 --- a/python/flydsl/runtime/device.py +++ b/python/flydsl/runtime/device.py @@ -4,11 +4,70 @@ import functools import os import subprocess +from pathlib import Path from typing import Optional _ROCM_AGENT_TIMEOUT_S = int(os.environ.get("FLYDSL_ROCM_AGENT_TIMEOUT", "300")) +@functools.lru_cache(maxsize=None) +def get_rocm_toolkit_path() -> Optional[str]: + """Return a directory MLIR's ROCDL backend recognizes as a toolkit. + + MLIR's gpu-module-to-binary expects ``/llvm/bin/ld.lld`` for + linking and ``/amdgcn/bitcode`` for device libraries. The + rocm-sdk Python wheels (``_rocm_sdk_core``) ship both, but at + ``/lib/llvm/bin/ld.lld`` and ``/lib/llvm/amdgcn/bitcode``, so + the layout doesn't directly match. We synthesize a tiny symlink-based + shim under ``~/.flydsl/toolkit`` and return its path. + + Order of preference: + 1. ``FLYDSL_ROCM_TOOLKIT_PATH`` env var (explicit override) + 2. ``ROCM_PATH`` env var + 3. ``/opt/rocm`` if present and well-formed + 4. Synthesized shim pointing at the rocm-sdk Python wheel. + Returns ``None`` if no toolkit can be located. + """ + + def _well_formed(root: Path) -> bool: + return (root / "llvm" / "bin" / "ld.lld").exists() and (root / "amdgcn" / "bitcode").is_dir() + + for env_var in ("FLYDSL_ROCM_TOOLKIT_PATH", "ROCM_PATH"): + val = os.environ.get(env_var, "").strip() + if val and _well_formed(Path(val)): + return val + + opt_rocm = Path("/opt/rocm") + if _well_formed(opt_rocm): + return str(opt_rocm) + + try: + import _rocm_sdk_core # type: ignore[import-not-found] + except ImportError: + return None + + sdk_root = Path(_rocm_sdk_core.__file__).parent + llvm_dir = sdk_root / "lib" / "llvm" + if not (llvm_dir / "bin" / "ld.lld").exists() or not (llvm_dir / "amdgcn" / "bitcode").is_dir(): + return None + + shim_root = Path(os.environ.get("FLYDSL_ROCM_TOOLKIT_SHIM_DIR") or (Path.home() / ".flydsl" / "toolkit")) + shim_root.mkdir(parents=True, exist_ok=True) + (shim_root / "llvm" / "bin").mkdir(parents=True, exist_ok=True) + amdgcn_link = shim_root / "amdgcn" + if not amdgcn_link.exists(): + amdgcn_link.symlink_to(llvm_dir / "amdgcn") + # ``ld.lld`` in the rocm-sdk wheel is a tiny stub that needs to resolve + # its own argv[0] to load companion libraries. Copying it elsewhere + # breaks that lookup, so we drop a thin exec wrapper instead. + wrapper = shim_root / "llvm" / "bin" / "ld.lld" + wrapper_text = f'#!/bin/bash\nexec "{llvm_dir}/bin/ld.lld" "$@"\n' + if not wrapper.exists() or wrapper.read_text() != wrapper_text: + wrapper.write_text(wrapper_text) + wrapper.chmod(0o755) + return str(shim_root) + + def _arch_from_rocm_agent_enumerator() -> Optional[str]: """Query rocm_agent_enumerator (standard ROCm tool) for the first GPU arch.""" try: