Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
714dac5
Add portable Docker build workflow using public TheRock wheels
ethanwee1 Apr 13, 2026
f914b59
Add push trigger on feature branch for testing
ethanwee1 Apr 13, 2026
9e82f4e
Fix Docker Hub secret names to match repo settings
ethanwee1 Apr 13, 2026
89e451a
Fix Docker Hub image name to rocm/pytorch-private
ethanwee1 Apr 13, 2026
6758921
Revert Docker Hub image name to rocmshared/pytorch-private
ethanwee1 Apr 13, 2026
06afdf0
Skip triton version pinning in auto-discovery mode
ethanwee1 Apr 13, 2026
8c1bf52
Add PyTorch smoke test step before Docker push
ethanwee1 Apr 13, 2026
127ec05
Fix smoke test: install pytest and propagate exit code
ethanwee1 Apr 13, 2026
3161ba4
Fix smoke test: run from /tmp to avoid source tree shadowing installe…
ethanwee1 Apr 13, 2026
bc5302b
Update Docker tag format to include branch, commit, OS, and python ve…
ethanwee1 Apr 13, 2026
e703c6f
Install CI test dependencies (requirements-ci.txt) in smoke test step
ethanwee1 Apr 13, 2026
ae63834
Make python_version a choice dropdown supporting 3.10 through 3.14
ethanwee1 Apr 13, 2026
17b02fa
Support Python 3.10-3.14 in Dockerfile via deadsnakes PPA
ethanwee1 Apr 13, 2026
b67dba5
Remove smoke test step for now
ethanwee1 Apr 13, 2026
0d1061b
Fix Docker Hub image name to rocm/pytorch-private
ethanwee1 Apr 13, 2026
5322264
Add nightly matrix builds and auto-discover ROCm version from index
ethanwee1 Apr 13, 2026
8679ddc
Limit nightly matrix to schedule-only, push triggers single build
ethanwee1 Apr 13, 2026
5428672
Remove push trigger, rely on workflow_dispatch and schedule only
ethanwee1 Apr 13, 2026
fa8a4a6
Default to pytorch/pytorch main for manual dispatch
ethanwee1 Apr 13, 2026
303ffce
Use pytorch/pytorch nightly branch as default to match theRock builds
ethanwee1 Apr 13, 2026
4f489ca
Reorder and simplify workflow_dispatch inputs
ethanwee1 Apr 13, 2026
3d7d015
Remove stale build-arg references to deleted inputs
ethanwee1 Apr 13, 2026
3360f8d
Auto-derive torch version prefix from branch name
ethanwee1 Apr 13, 2026
14ce620
Fix version sorting to include ROCm date, not just base version
ethanwee1 Apr 13, 2026
f11f1ca
Make Dockerfile verification step tolerant of torchaudio/torchvision …
ethanwee1 Apr 13, 2026
c0e052c
Two-pass install: pin torch+rocm first, let pip resolve companions
ethanwee1 Apr 13, 2026
ce8a82f
Remove triton from pass 2, torch dependency handles it
ethanwee1 Apr 13, 2026
4b63c1c
Fix Dockerfile verification syntax error — use heredoc for Python script
ethanwee1 Apr 13, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
306 changes: 306 additions & 0 deletions .github/scripts/install_pytorch_wheels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,306 @@
#!/usr/bin/env python3
"""
install_pytorch_wheels.py

Installs PyTorch wheels from a pip index URL.

Usage (from repo root):
python .github/scripts/install_pytorch_wheels.py --index-url <URL> --amdgpu-family <FAMILY> [OPTIONS]

Examples:
# Install latest versions
python .github/scripts/install_pytorch_wheels.py \
--index-url <BASE_INDEX_URL>/whl \
--amdgpu-family gfx1250

# Install specific versions (matching ROCm builds)
python .github/scripts/install_pytorch_wheels.py \
--index-url <BASE_INDEX_URL>/whl \
--amdgpu-family gfx1250 \
--torch-version "2.10.0+devrocm7.12.0.dev0.849eec43b..." \
--torchaudio-version "2.11.0a0+devrocm7.12.0.dev0.849eec43b..." \
--torchvision-version "0.25.0a0+devrocm7.12.0.dev0.849eec43b..."
"""

import argparse
import re
import subprocess
import sys
import urllib.parse
import urllib.request


# Package configuration: (name, always_install)
PACKAGES = {
"torch": True,
"torchaudio": True,
"torchvision": True,
"triton": False,
"rocm[devel]": True,
}
PYTORCH_PKGS = ["torch", "torchaudio", "torchvision", "triton"]


def print_banner(title: str) -> None:
"""Print a formatted banner."""
print("=" * 50)
print(title)
print("=" * 50)


def build_package_spec(name: str, version: str | None) -> str:
"""Build a pip package spec (e.g., 'torch==2.10.0' or 'torch')."""
return f"{name}=={version}" if version else name

def get_latest_package_version_for_rocm(
index_url: str, package_name: str, rocm_version: str, required: bool = True,
version_prefix: str | None = None,
) -> str | None:
"""Return latest package version containing rocm_version by parsing the index HTML.

If version_prefix is set (e.g. "2.9"), only versions whose base part starts
with that prefix are considered.
"""

# Build the URL for this package's index page (e.g. .../gfx1250/torch/).
rocm_tag = f"rocm{rocm_version}"
url = f"{index_url.rstrip('/')}/{package_name}/"
# Fetch the package index page; on failure (e.g. 404, timeout) fail if always_install, else return None.
try:
with urllib.request.urlopen(url, timeout=30) as resp:
html = resp.read().decode("utf-8", errors="ignore")
except Exception as e:
print(f"Error: failed to fetch index for {package_name}: {e}", file=sys.stderr)
sys.exit(1)
# Parse wheel links: format is package-VERSION-...whl (e.g. torch-0.26.0a0+rocm7.12...-cp312-....whl).
# Version can contain dots and + (URL-encoded as %2B), so we capture everything up to .whl.
pattern = re.compile(
re.escape(package_name) + r"-(.+?)\.whl",
re.IGNORECASE,
)
all_suffixes = [m.group(1).strip() for m in pattern.finditer(html)]
# Keep only wheels whose version string contains the requested ROCm tag (e.g. rocm7.12.0a20260224).
# Version is the first segment before "-" in the suffix; decode %2B to + for comparison.
matching = []
for s in all_suffixes:
ver = s.split("-")[0]
if rocm_tag in ver:
matching.append(urllib.parse.unquote(ver))
# Filter by version prefix (e.g. "2.9" matches "2.9.0+...", "2.9.1+...").
if version_prefix and matching:
matching = [v for v in matching if v.split("+")[0].startswith(version_prefix)]
# No matching wheels: if required (always_install), fail; otherwise return None (package will be skipped).
if not matching:
if required:
msg = f"Error: no wheel found for {package_name} with ROCm {rocm_version}"
if version_prefix:
msg += f" and version prefix {version_prefix}"
print(msg, file=sys.stderr)
sys.exit(1)
return None
# Pick the latest version by comparing all numeric parts including the ROCm date.
def _key(v: str) -> tuple[int, ...]:
try:
return tuple(int(x) for x in re.split(r"[.\-a+]", v) if x.isdigit())
except (ValueError, AttributeError):
return (0,)
return max(matching, key=_key)


def run_pip_install(
index_url: str, packages: list[str], break_system_packages: bool = True
) -> None:
"""Run pip install with the given packages."""
cmd = [sys.executable, "-m", "pip", "install", "--index-url", index_url]

if break_system_packages:
cmd.append("--break-system-packages")

cmd.extend(packages)

print(f"Running: {' '.join(cmd)}")
result = subprocess.run(cmd, check=False)

if result.returncode != 0:
print(f"Error: pip install failed with return code {result.returncode}")
sys.exit(result.returncode)


def check_package(name: str) -> tuple[bool, str | None]:
"""Check if a package is installed and return (installed, version)."""
try:
module = __import__(name)
return True, getattr(module, "__version__", "unknown")
except ImportError:
return False, None


def verify_installation() -> bool:
"""Verify PyTorch installation and print version info."""
print_banner("Verifying Installation")

# Check torch separately for ROCm info
try:
import torch as _torch

version = getattr(_torch, "__version__", "unknown")
except ImportError as e:
print(f"Error: torch import failed ({e!r}). If wheels are installed, run rocm-sdk init first.")
return False

print(f"torch: {version}")

hip_version = _torch.version.hip
print(f"ROCm/HIP: {hip_version or 'not available'}")
print(f"Built with ROCm: {hip_version is not None}")

# Check other packages
for name in ["torchaudio", "torchvision", "triton", "rocm"]:
installed, version = check_package(name)
status = version if installed else "not installed"
print(f"{name}: {status}")

return True


def list_installed_packages() -> None:
"""List installed torch-related packages."""
print("\nInstalled PyTorch packages:")
result = subprocess.run(
[sys.executable, "-m", "pip", "list"],
capture_output=True,
text=True,
check=False,
)

if result.returncode == 0:
keywords = ["torch", "triton", "rocm"]
for line in result.stdout.splitlines():
if any(kw in line.lower() for kw in keywords):
print(f" {line}")


def main() -> int:
"""Main entry point."""
parser = argparse.ArgumentParser(
description="Install PyTorch wheels from a pip index URL",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=__doc__,
)

parser.add_argument(
"--index-url", required=True, help="Base URL for PyTorch wheels index"
)
parser.add_argument(
"--amdgpu-family", required=True, help="AMD GPU family (e.g., gfx1250)"
)
parser.add_argument(
"--rocm-version",
help="Optional. ROCm version (e.g. 7.12.0a20260126). When set without --torch-version: discovers and installs latest torch/torchaudio/torchvision/triton built for this ROCm. ",
)
parser.add_argument(
"--torch-version", help="Specific torch version (default: latest)"
)
parser.add_argument(
"--torch-version-prefix",
help="Torch version prefix for discovery (e.g. '2.9' matches 2.9.x). "
"Only used in auto-discovery mode (--rocm-version without --torch-version).",
)
parser.add_argument(
"--torchaudio-version", help="Specific torchaudio version (default: latest)"
)
parser.add_argument(
"--torchvision-version", help="Specific torchvision version (default: latest)"
)
parser.add_argument(
"--triton-version",
help="Specific triton version (default: from torch dependency)",
)
parser.add_argument(
"--no-break-system-packages",
action="store_true",
help="Don't use --break-system-packages",
)
parser.add_argument(
"--skip-verify", action="store_true", help="Skip verification step"
)

args = parser.parse_args()

# Build the full index URL
index_url = f"{args.index_url.rstrip('/')}/{args.amdgpu_family}/"

rocm = args.rocm_version
rocm_only = bool(rocm and not args.torch_version)
torch_prefix = args.torch_version_prefix if rocm_only else None
break_sys = not args.no_break_system_packages

if rocm_only:
# Two-pass install:
# Pass 1: torch (pinned) + rocm[devel] (pinned)
# Pass 2: torchaudio, torchvision, triton (unpinned — pip resolves compatibility)
torch_version = get_latest_package_version_for_rocm(
index_url, "torch", rocm, required=True, version_prefix=torch_prefix,
)

print_banner("PyTorch Wheels Installation")
print(f"Index URL: {index_url}")
print(f"AMDGPU Family: {args.amdgpu_family}")
print(f"Python: {sys.version_info.major}.{sys.version_info.minor}")
print(f"torch: {torch_version}")
print(f"rocm[devel]: {rocm}")
print(f"torchaudio: (pip resolves)")
print(f"torchvision: (pip resolves)")
print(f"triton: (torch dependency)")
print("=" * 50)

# Pass 1: install torch + rocm[devel] with exact versions.
# torch's declared dependency on triton pulls in the correct build.
primary = [
build_package_spec("torch", torch_version),
build_package_spec("rocm[devel]", rocm),
]
print_banner("Pass 1: torch + rocm[devel]")
print(f"Installing: {', '.join(primary)}")
run_pip_install(index_url, primary, break_sys)

# Pass 2: install torchaudio/torchvision without pinning — pip picks
# versions compatible with the torch that's already installed
companions = ["torchaudio", "torchvision"]
print_banner("Pass 2: torchaudio, torchvision (unpinned)")
print(f"Installing: {', '.join(companions)}")
run_pip_install(index_url, companions, break_sys)
else:
# Explicit versions mode — install everything in one shot
arg_attrs = ["torch_version", "torchaudio_version", "torchvision_version", "triton_version"]
versions = {p: getattr(args, a) for p, a in zip(PYTORCH_PKGS, arg_attrs)}
versions["rocm[devel]"] = rocm if rocm else None

print_banner("PyTorch Wheels Installation")
print(f"Index URL: {index_url}")
print(f"AMDGPU Family: {args.amdgpu_family}")
print(f"Python: {sys.version_info.major}.{sys.version_info.minor}")
for name, version in versions.items():
print(f"{name:14}: {version or 'latest'}")
print("=" * 50)

packages = []
for name, always_install in PACKAGES.items():
version = versions.get(name)
if always_install or version:
packages.append(build_package_spec(name, version))

print(f"Installing: {', '.join(packages)}")
run_pip_install(index_url, packages, break_sys)

# Verify
if not args.skip_verify and not verify_installation():
return 1

list_installed_packages()
print_banner("Installation complete")
return 0


if __name__ == "__main__":
sys.exit(main())
Loading