Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- avoid repeatedly scanning sharded model families during directory scans
- keep shard sibling discovery within the requested scan root
- preserve per-shard metadata when aggregating sharded model families
- preserve HuggingFace snapshot shard paths while grouping cache-backed families
- distinguish ASCII-serialized Torch7 artifacts from plain PyTorch source text

## [0.2.45](https://github.com/promptfoo/modelaudit/compare/v0.2.44...v0.2.45) (2026-05-03)
Expand Down
45 changes: 40 additions & 5 deletions modelaudit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,7 @@ def scan_model_directory_or_file(
files_to_scan: list[str] = []
shard_family_representatives: dict[_ShardFamilyKey, str] = {}
shard_family_paths: dict[_ShardFamilyKey, set[str]] = {}
complete_hf_shard_families: set[_ShardFamilyKey] = set()
directory_discovery_started_at = _start_phase_timing(phase_timings)
for root, _, files in os.walk(path, followlinks=False):
for file in files:
Expand All @@ -617,6 +618,12 @@ def scan_model_directory_or_file(
)
if resolved_file is None:
continue
snapshot_path = Path(file_path).absolute()
scan_source = (
snapshot_path
if is_hf_cache_symlink and _shard_family_key_for_path(str(snapshot_path)) is not None
else resolved_file
Comment thread
mldangelo-oai marked this conversation as resolved.
Outdated
Comment thread
mldangelo-oai marked this conversation as resolved.
Outdated
)

# Skip non-model files early if filtering is enabled
# Note: skip_file_types parameter already contains the correct value
Expand Down Expand Up @@ -647,17 +654,23 @@ def scan_model_directory_or_file(
continue

# Handle DVC files and get target paths
target_paths = [resolved_file]
target_paths = [scan_source]
if file.endswith(".dvc"):
dvc_targets = resolve_dvc_file(file_path)
if dvc_targets:
target_paths = [Path(t).resolve() for t in dvc_targets]

for target_path in target_paths:
target_str = str(target_path)
if target_str in scanned_paths:
shard_family_key = _shard_family_key_for_path(target_str)
dedupe_target_str = (
str(target_path.resolve())
if is_hf_cache_symlink and shard_family_key is None
else target_str
Comment thread
mldangelo-oai marked this conversation as resolved.
)
if dedupe_target_str in scanned_paths:
continue
scanned_paths.add(target_str)
scanned_paths.add(dedupe_target_str)

if not is_hf_cache_symlink and not is_within_directory(str(base_dir), str(target_path)):
_add_issue_to_model(
Expand All @@ -670,14 +683,22 @@ def scan_model_directory_or_file(
continue

# Add to files to scan list instead of scanning immediately
shard_family_key = _shard_family_key_for_path(target_str)
if shard_family_key is not None:
family_paths = shard_family_paths.setdefault(shard_family_key, set())
family_paths.add(target_str)
if shard_family_key not in shard_family_representatives:
shard_family_representatives[shard_family_key] = target_str
shard_info = ShardedModelDetector.detect_shards(target_str)
if shard_info is not None:
expected_total_shards = shard_info.get("expected_total_shards")
if (
is_hf_cache_symlink
and isinstance(expected_total_shards, int)
and shard_info.get("total_shards") == expected_total_shards
and "missing_shard_count" not in shard_info
and "inconsistent_expected_total_shards" not in shard_info
):
complete_hf_shard_families.add(shard_family_key)
for shard_path in shard_info.get("shards", []):
if isinstance(shard_path, str):
resolved_shard_path = str(Path(shard_path).resolve())
Expand All @@ -690,7 +711,9 @@ def scan_model_directory_or_file(
resolved_shard_path,
)
)
if shard_in_base_dir or shard_in_hf_blobs:
if shard_in_hf_blobs:
family_paths.add(str(Path(shard_path).absolute()))
elif shard_in_base_dir:
family_paths.add(resolved_shard_path)
else:
_add_issue_to_model(
Expand All @@ -706,8 +729,20 @@ def scan_model_directory_or_file(
_finish_phase_timing(phase_timings, "directory_discovery", directory_discovery_started_at)

scan_entries: list[_ScanEntry] = [(file_path, [file_path], None) for file_path in files_to_scan]
seen_complete_hf_shard_families: set[tuple[str, ...]] = set()
for shard_family_key, representative_file in shard_family_representatives.items():
ordered_family_paths = sorted(shard_family_paths.get(shard_family_key, {representative_file}))
expected_total_shards = shard_family_key[2]
if (
is_hf_cache
and expected_total_shards is not None
and shard_family_key in complete_hf_shard_families
and len(ordered_family_paths) == expected_total_shards
):
resolved_family_paths = tuple(sorted(str(Path(path).resolve()) for path in ordered_family_paths))
if resolved_family_paths in seen_complete_hf_shard_families:
continue
Comment thread
mldangelo-oai marked this conversation as resolved.
seen_complete_hf_shard_families.add(resolved_family_paths)
scan_entries.append((representative_file, ordered_family_paths, shard_family_key))

# Second pass: scan every non-shard path independently and every shard
Expand Down
22 changes: 19 additions & 3 deletions modelaudit/utils/file/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,16 +340,23 @@ def _analyze_window(self, data: bytes, offset: int) -> "ScanResult":
class ParallelShardHandler:
"""Scan multiple model shards in parallel."""

def __init__(self, shard_info: dict[str, Any], scanner_class: type):
def __init__(
self,
shard_info: dict[str, Any],
scanner_class: type,
scanner_config: dict[str, Any] | None = None,
):
"""
Initialize parallel shard scanner.

Args:
shard_info: Information about model shards
scanner_class: Scanner class to use
scanner_config: Configuration to preserve for each shard scanner
"""
self.shard_info = shard_info
self.scanner_class = scanner_class
self.scanner_config = scanner_config

def scan_shards(self, progress_callback: Callable[[str, float], None] | None = None) -> "ScanResult":
from ...scanner_results import IssueSeverity, ScanResult
Expand Down Expand Up @@ -425,7 +432,11 @@ def _scan_single_shard(self, shard_path: str) -> "ScanResult":
from ...scanner_results import ScanResult

"""Scan a single shard file."""
scanner = self.scanner_class()
scanner = (
self.scanner_class(config=dict(self.scanner_config))
if self.scanner_config is not None
else self.scanner_class()
)
result: ScanResult = scanner.scan(shard_path)
return result

Expand Down Expand Up @@ -560,7 +571,12 @@ def _scan_sharded_model(self) -> "ScanResult":
# Scan shards in parallel
shard_scan_success = True
if self.shard_info:
parallel_scanner = ParallelShardHandler(self.shard_info, self.scanner.__class__)
scanner_config = getattr(self.scanner, "config", None)
parallel_scanner = ParallelShardHandler(
self.shard_info,
self.scanner.__class__,
scanner_config=dict(scanner_config) if isinstance(scanner_config, dict) else None,
)
shard_results = parallel_scanner.scan_shards(self.progress_callback)
shard_scan_success = bool(shard_results.success)
result.merge(shard_results)
Expand Down
172 changes: 172 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,178 @@ def fake_scan_file(path: str, config: dict[str, Any] | None = None) -> ScanResul
)


@pytest.mark.usefixtures("requires_symlinks")
def test_directory_scan_groups_hf_cache_sharded_symlinks(
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
hf_home = tmp_path / "hf-home"
monkeypatch.setenv("HF_HOME", str(hf_home))
cache_dir = hf_home / "hub" / "models--org--model"
snapshots_dir = cache_dir / "snapshots" / "abc123"
blobs_dir = cache_dir / "blobs"
snapshots_dir.mkdir(parents=True)
blobs_dir.mkdir()

blob_paths: list[Path] = []
shard_links: list[Path] = []
for shard_index in range(1, 3):
blob_path = blobs_dir / f"blob-{shard_index}"
blob_path.write_bytes(f"hf-shard-{shard_index}".encode())
shard_link = snapshots_dir / f"model-{shard_index:05d}-of-00002.safetensors"
shard_link.symlink_to(Path("../../blobs") / blob_path.name)
blob_paths.append(blob_path.resolve())
shard_links.append(shard_link)

captured_configs: list[dict[str, Any]] = []
calls: list[str] = []

def fake_scan_file(path: str, config: dict[str, Any] | None = None) -> ScanResult:
calls.append(path)
captured_configs.append(dict(config or {}))
return _mock_sharded_scan_result(sum(blob_path.stat().st_size for blob_path in blob_paths))

monkeypatch.setattr(core_module, "scan_file", fake_scan_file)

result = core_module.scan_model_directory_or_file(str(snapshots_dir), cache_scan_results=False)

material_config = normalize_material_scan_config(captured_configs[0])
fingerprint = material_config[core_module._SHARD_FAMILY_CACHE_FINGERPRINT_CONFIG_KEY]
assert len(calls) == 1
assert Path(calls[0]).name in {shard_link.name for shard_link in shard_links}
assert result.files_scanned == len(shard_links)
assert set(result.file_metadata) == {str(shard_link) for shard_link in shard_links}
assert {asset.path for asset in result.assets} == {str(shard_link) for shard_link in shard_links}
assert {member["path"] for member in fingerprint["members"]} == {str(blob_path) for blob_path in blob_paths}
assert not any("path traversal" in issue.message.lower() for issue in result.issues)


@pytest.mark.usefixtures("requires_symlinks")
def test_directory_scan_keeps_nonsharded_hf_snapshot_aliases_deduplicated(
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
hf_home = tmp_path / "hf-home"
monkeypatch.setenv("HF_HOME", str(hf_home))
cache_dir = hf_home / "hub" / "models--org--model"
blobs_dir = cache_dir / "blobs"
blobs_dir.mkdir(parents=True)
blob_path = blobs_dir / "shared-blob"
blob_path.write_bytes(b"shared-model")

for revision in ("abc123", "def456"):
snapshots_dir = cache_dir / "snapshots" / revision
snapshots_dir.mkdir(parents=True)
(snapshots_dir / "model.safetensors").symlink_to(Path("../../blobs") / blob_path.name)

calls: list[str] = []

def fake_scan_file(path: str, config: dict[str, Any] | None = None) -> ScanResult:
calls.append(path)
return _mock_sharded_scan_result(blob_path.stat().st_size)

monkeypatch.setattr(core_module, "scan_file", fake_scan_file)

result = core_module.scan_model_directory_or_file(str(cache_dir / "snapshots"), cache_scan_results=False)

assert calls == [str(blob_path.resolve())]
assert result.files_scanned == 1


@pytest.mark.usefixtures("requires_symlinks")
def test_directory_scan_deduplicates_identical_hf_shard_families_across_snapshots(
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
hf_home = tmp_path / "hf-home"
monkeypatch.setenv("HF_HOME", str(hf_home))
cache_dir = hf_home / "hub" / "models--org--model"
blobs_dir = cache_dir / "blobs"
blobs_dir.mkdir(parents=True)
blob_paths: list[Path] = []
for shard_index in range(1, 3):
blob_path = blobs_dir / f"blob-{shard_index}"
blob_path.write_bytes(f"shared-hf-shard-{shard_index}".encode())
blob_paths.append(blob_path.resolve())
for revision in ("abc123", "def456"):
snapshots_dir = cache_dir / "snapshots" / revision
snapshots_dir.mkdir(parents=True, exist_ok=True)
(snapshots_dir / f"model-{shard_index:05d}-of-00002.safetensors").symlink_to(
Path("../../blobs") / blob_path.name
)

captured_configs: list[dict[str, Any]] = []
calls: list[str] = []

def fake_scan_file(path: str, config: dict[str, Any] | None = None) -> ScanResult:
calls.append(path)
captured_configs.append(dict(config or {}))
return _mock_sharded_scan_result(sum(blob_path.stat().st_size for blob_path in blob_paths))

monkeypatch.setattr(core_module, "scan_file", fake_scan_file)

result = core_module.scan_model_directory_or_file(str(cache_dir / "snapshots"), cache_scan_results=False)

material_config = normalize_material_scan_config(captured_configs[0])
fingerprint = material_config[core_module._SHARD_FAMILY_CACHE_FINGERPRINT_CONFIG_KEY]
assert len(calls) == 1
assert result.files_scanned == len(blob_paths)
assert result.bytes_scanned == sum(blob_path.stat().st_size for blob_path in blob_paths)
assert {member["path"] for member in fingerprint["members"]} == {str(blob_path) for blob_path in blob_paths}


@pytest.mark.usefixtures("requires_symlinks")
def test_directory_scan_reports_incomplete_hf_snapshot_after_shared_blob_dedupe(
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
hf_home = tmp_path / "hf-home"
monkeypatch.setenv("HF_HOME", str(hf_home))
cache_dir = hf_home / "hub" / "models--org--model"
blobs_dir = cache_dir / "blobs"
blobs_dir.mkdir(parents=True)
blob_paths: list[Path] = []
for shard_index in range(1, 3):
blob_path = blobs_dir / f"blob-{shard_index}"
blob_path.write_bytes(f"shared-hf-shard-{shard_index}".encode())
blob_paths.append(blob_path.resolve())
full_snapshot = cache_dir / "snapshots" / "abc123"
full_snapshot.mkdir(parents=True, exist_ok=True)
(full_snapshot / f"model-{shard_index:05d}-of-00002.safetensors").symlink_to(
Path("../../blobs") / blob_path.name
)

partial_snapshot = cache_dir / "snapshots" / "def456"
partial_snapshot.mkdir(parents=True)
(partial_snapshot / "model-00001-of-00002.safetensors").symlink_to(Path("../../blobs") / blob_paths[0].name)

captured_configs: list[dict[str, Any]] = []
calls: list[str] = []

def fake_scan_file(path: str, config: dict[str, Any] | None = None) -> ScanResult:
calls.append(path)
captured_configs.append(dict(config or {}))
material_config = normalize_material_scan_config(captured_configs[-1])
fingerprint = material_config[core_module._SHARD_FAMILY_CACHE_FINGERPRINT_CONFIG_KEY]
member_paths = [Path(member["path"]) for member in fingerprint["members"]]
return _mock_sharded_scan_result(
sum(member_path.stat().st_size for member_path in member_paths),
missing_shards=1 if len(member_paths) == 1 else 0,
)

monkeypatch.setattr(core_module, "scan_file", fake_scan_file)

result = core_module.scan_model_directory_or_file(str(cache_dir / "snapshots"), cache_scan_results=False)

material_configs = [normalize_material_scan_config(config) for config in captured_configs]
fingerprints = [config[core_module._SHARD_FAMILY_CACHE_FINGERPRINT_CONFIG_KEY] for config in material_configs]
coverage_checks = [check for check in result.checks if check.name == "Sharded Model Coverage Check"]
assert len(calls) == 2
assert sorted(len(fingerprint["members"]) for fingerprint in fingerprints) == [1, 2]
assert len(coverage_checks) == 1
assert coverage_checks[0].details["missing_shard_count"] == 1


def test_scan_file_passes_shard_allowlist_to_advanced_handler(
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
Expand Down
21 changes: 21 additions & 0 deletions tests/utils/file/test_advanced_file_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,27 @@ def test_sharded_model_honors_allowed_shard_paths(self, tmp_path: Path) -> None:
assert shard_detection.details["shards"] == [str(shard_one)]
assert result.bytes_scanned == shard_one.stat().st_size

def test_sharded_model_preserves_scanner_config_for_each_shard(self, tmp_path: Path) -> None:
"""Shard fanout should retain caller configuration for each scanner instance."""
shard_one = tmp_path / "model-00001-of-00002.safetensors"
shard_two = tmp_path / "model-00002-of-00002.safetensors"
shard_one.write_bytes(b"one")
shard_two.write_bytes(b"two")
captured_configs: list[dict[str, Any]] = []

class ConfiguredShardScanner(CompletingShardScanner):
def __init__(self, config: dict[str, Any] | None = None) -> None:
self.config = dict(config or {})
captured_configs.append(self.config)

scanner = ConfiguredShardScanner({"max_tensor_bytes": 7})
captured_configs.clear()

result = AdvancedFileHandler(str(shard_one), scanner).scan()

assert result.success is True
assert captured_configs == [{"max_tensor_bytes": 7}, {"max_tensor_bytes": 7}]

def test_cached_advanced_scan_keys_allowed_shard_paths(self, tmp_path: Path) -> None:
"""Different validated shard allowlists must not share advanced-scan cache entries."""
shard_one = tmp_path / "checkpoint_1.pt"
Expand Down
Loading