diff --git a/miners/gpu_fingerprint.py b/miners/gpu_fingerprint.py index 133ad81ae..c5a31f7eb 100644 --- a/miners/gpu_fingerprint.py +++ b/miners/gpu_fingerprint.py @@ -24,6 +24,8 @@ Author: Elyan Labs (RIP-0308: Proof of Physical AI) """ +from __future__ import annotations + import argparse import json import hashlib @@ -38,14 +40,22 @@ try: import torch - import torch.cuda except ImportError: - print("ERROR: PyTorch with CUDA support required. Install: pip install torch") - sys.exit(1) + torch = None + HAS_TORCH = False +else: + HAS_TORCH = True + try: + import torch.cuda + except ImportError: + pass -if not torch.cuda.is_available(): - print("ERROR: No CUDA-capable GPU detected.") - sys.exit(1) + +def check_requirements(): + if not HAS_TORCH or torch is None: + raise RuntimeError("PyTorch with CUDA support required. Install: pip install torch") + if not hasattr(torch, "cuda") or not torch.cuda.is_available(): + raise RuntimeError("No CUDA-capable GPU detected.") # --------------------------------------------------------------------------- @@ -789,6 +799,7 @@ def cross_validate_gpu(device: torch.device) -> ChannelResult: def run_gpu_fingerprint(device_index: int = 0, samples: int = 200, epoch_salt: str = "") -> GPUFingerprint: """Run all GPU fingerprint channels and return results.""" + check_requirements() device = torch.device(f"cuda:{device_index}") # GPU info @@ -898,16 +909,20 @@ def run_gpu_fingerprint(device_index: int = 0, samples: int = 200, epoch_salt: s help="Epoch salt for privacy (prevents cross-epoch correlation)") args = parser.parse_args() - if args.json: - # Suppress banner output for clean JSON - import io, contextlib - with contextlib.redirect_stdout(io.StringIO()): + try: + if args.json: + # Suppress banner output for clean JSON + import io, contextlib + with contextlib.redirect_stdout(io.StringIO()): + fp = run_gpu_fingerprint(device_index=args.device, samples=args.samples, epoch_salt=args.epoch_salt) + print(json.dumps(fp.to_dict(), indent=2)) + else: fp = run_gpu_fingerprint(device_index=args.device, samples=args.samples, epoch_salt=args.epoch_salt) - print(json.dumps(fp.to_dict(), indent=2)) - else: - fp = run_gpu_fingerprint(device_index=args.device, samples=args.samples, epoch_salt=args.epoch_salt) - # Print channel summary - print("Channel Details:") - for ch in fp.channels: - status = "PASS" if ch["passed"] else "FAIL" - print(f" [{status}] {ch['name']}: {ch['notes']}") + # Print channel summary + print("Channel Details:") + for ch in fp.channels: + status = "PASS" if ch["passed"] else "FAIL" + print(f" [{status}] {ch['name']}: {ch['notes']}") + except RuntimeError as exc: + print(f"ERROR: {exc}", file=sys.stderr) + sys.exit(1) diff --git a/tests/test_gpu_fingerprint_import.py b/tests/test_gpu_fingerprint_import.py new file mode 100644 index 000000000..48d26da12 --- /dev/null +++ b/tests/test_gpu_fingerprint_import.py @@ -0,0 +1,34 @@ +# SPDX-License-Identifier: MIT +"""Regression coverage for importing the GPU fingerprint helper on CPU CI.""" + +import builtins +import importlib.util +import sys +from pathlib import Path + +import pytest + + +def test_gpu_fingerprint_import_without_torch_does_not_exit(monkeypatch): + original_import = builtins.__import__ + + def import_without_torch(name, *args, **kwargs): + if name == "torch" or name.startswith("torch."): + raise ImportError(f"No module named '{name}'") + return original_import(name, *args, **kwargs) + + monkeypatch.delitem(sys.modules, "torch", raising=False) + monkeypatch.delitem(sys.modules, "torch.cuda", raising=False) + monkeypatch.setattr(builtins, "__import__", import_without_torch) + + module_path = Path(__file__).resolve().parents[1] / "miners" / "gpu_fingerprint.py" + spec = importlib.util.spec_from_file_location("gpu_fingerprint_without_torch", module_path) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + monkeypatch.setitem(sys.modules, spec.name, module) + + spec.loader.exec_module(module) + + assert module.HAS_TORCH is False + with pytest.raises(RuntimeError, match="PyTorch with CUDA support required"): + module.check_requirements()