From 11569663c4c59bf014ee65efd72b63da57aecc5d Mon Sep 17 00:00:00 2001 From: Michael D'Angelo Date: Sun, 17 May 2026 15:09:01 -0700 Subject: [PATCH 1/6] fix(core): group HF cache shard symlinks --- modelaudit/core.py | 7 +++++-- tests/test_core.py | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/modelaudit/core.py b/modelaudit/core.py index df49ffea4..f5e499428 100644 --- a/modelaudit/core.py +++ b/modelaudit/core.py @@ -617,6 +617,7 @@ def scan_model_directory_or_file( ) if resolved_file is None: continue + scan_source = Path(file_path).absolute() if is_hf_cache_symlink else resolved_file # Skip non-model files early if filtering is enabled # Note: skip_file_types parameter already contains the correct value @@ -647,7 +648,7 @@ def scan_model_directory_or_file( continue # Handle DVC files and get target paths - target_paths = [resolved_file] + target_paths = [scan_source] if file.endswith(".dvc"): dvc_targets = resolve_dvc_file(file_path) if dvc_targets: @@ -690,7 +691,9 @@ def scan_model_directory_or_file( resolved_shard_path, ) ) - if shard_in_base_dir or shard_in_hf_blobs: + if shard_in_hf_blobs: + family_paths.add(str(Path(shard_path).absolute())) + elif shard_in_base_dir: family_paths.add(resolved_shard_path) else: _add_issue_to_model( diff --git a/tests/test_core.py b/tests/test_core.py index 1672549f3..0e4d5f3d1 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -208,6 +208,52 @@ def fake_scan_file(path: str, config: dict[str, Any] | None = None) -> ScanResul ) +@pytest.mark.usefixtures("requires_symlinks") +def test_directory_scan_groups_hf_cache_sharded_symlinks( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + hf_home = tmp_path / "hf-home" + monkeypatch.setenv("HF_HOME", str(hf_home)) + cache_dir = hf_home / "hub" / "models--org--model" + snapshots_dir = cache_dir / "snapshots" / "abc123" + blobs_dir = cache_dir / "blobs" + snapshots_dir.mkdir(parents=True) + blobs_dir.mkdir() + + blob_paths: list[Path] = [] + shard_links: list[Path] = [] + for shard_index in range(1, 3): + blob_path = blobs_dir / f"blob-{shard_index}" + blob_path.write_bytes(f"hf-shard-{shard_index}".encode()) + shard_link = snapshots_dir / f"model-{shard_index:05d}-of-00002.safetensors" + shard_link.symlink_to(Path("../../blobs") / blob_path.name) + blob_paths.append(blob_path.resolve()) + shard_links.append(shard_link) + + captured_configs: list[dict[str, Any]] = [] + calls: list[str] = [] + + def fake_scan_file(path: str, config: dict[str, Any] | None = None) -> ScanResult: + calls.append(path) + captured_configs.append(dict(config or {})) + return _mock_sharded_scan_result(sum(blob_path.stat().st_size for blob_path in blob_paths)) + + monkeypatch.setattr(core_module, "scan_file", fake_scan_file) + + result = core_module.scan_model_directory_or_file(str(snapshots_dir), cache_scan_results=False) + + material_config = normalize_material_scan_config(captured_configs[0]) + fingerprint = material_config[core_module._SHARD_FAMILY_CACHE_FINGERPRINT_CONFIG_KEY] + assert len(calls) == 1 + assert Path(calls[0]).name in {shard_link.name for shard_link in shard_links} + assert result.files_scanned == len(shard_links) + assert set(result.file_metadata) == {str(shard_link) for shard_link in shard_links} + assert {asset.path for asset in result.assets} == {str(shard_link) for shard_link in shard_links} + assert {member["path"] for member in fingerprint["members"]} == {str(blob_path) for blob_path in blob_paths} + assert not any("path traversal" in issue.message.lower() for issue in result.issues) + + def test_scan_file_passes_shard_allowlist_to_advanced_handler( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, From 54b1fc1acc4d6a27e1bc676e94c1be9e16db22b0 Mon Sep 17 00:00:00 2001 From: Michael D'Angelo Date: Fri, 22 May 2026 03:11:34 -0400 Subject: [PATCH 2/6] fix: preserve dedupe for nonsharded HF aliases --- CHANGELOG.md | 1 + modelaudit/core.py | 7 ++++++- tests/test_core.py | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 39 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2adee9988..871385b15 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - avoid repeatedly scanning sharded model families during directory scans - keep shard sibling discovery within the requested scan root - preserve per-shard metadata when aggregating sharded model families +- preserve HuggingFace snapshot shard paths while grouping cache-backed families - reject plain PyTorch source text from Torch7 content routing ## [0.2.45](https://github.com/promptfoo/modelaudit/compare/v0.2.44...v0.2.45) (2026-05-03) diff --git a/modelaudit/core.py b/modelaudit/core.py index f5e499428..a045763d0 100644 --- a/modelaudit/core.py +++ b/modelaudit/core.py @@ -617,7 +617,12 @@ def scan_model_directory_or_file( ) if resolved_file is None: continue - scan_source = Path(file_path).absolute() if is_hf_cache_symlink else resolved_file + snapshot_path = Path(file_path).absolute() + scan_source = ( + snapshot_path + if is_hf_cache_symlink and _shard_family_key_for_path(str(snapshot_path)) is not None + else resolved_file + ) # Skip non-model files early if filtering is enabled # Note: skip_file_types parameter already contains the correct value diff --git a/tests/test_core.py b/tests/test_core.py index 0e4d5f3d1..a6f171627 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -254,6 +254,38 @@ def fake_scan_file(path: str, config: dict[str, Any] | None = None) -> ScanResul assert not any("path traversal" in issue.message.lower() for issue in result.issues) +@pytest.mark.usefixtures("requires_symlinks") +def test_directory_scan_keeps_nonsharded_hf_snapshot_aliases_deduplicated( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + hf_home = tmp_path / "hf-home" + monkeypatch.setenv("HF_HOME", str(hf_home)) + cache_dir = hf_home / "hub" / "models--org--model" + blobs_dir = cache_dir / "blobs" + blobs_dir.mkdir(parents=True) + blob_path = blobs_dir / "shared-blob" + blob_path.write_bytes(b"shared-model") + + for revision in ("abc123", "def456"): + snapshots_dir = cache_dir / "snapshots" / revision + snapshots_dir.mkdir(parents=True) + (snapshots_dir / "model.safetensors").symlink_to(Path("../../blobs") / blob_path.name) + + calls: list[str] = [] + + def fake_scan_file(path: str, config: dict[str, Any] | None = None) -> ScanResult: + calls.append(path) + return _mock_sharded_scan_result(blob_path.stat().st_size) + + monkeypatch.setattr(core_module, "scan_file", fake_scan_file) + + result = core_module.scan_model_directory_or_file(str(cache_dir / "snapshots"), cache_scan_results=False) + + assert calls == [str(blob_path.resolve())] + assert result.files_scanned == 1 + + def test_scan_file_passes_shard_allowlist_to_advanced_handler( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, From 0698a84e3ae64fe2a6856506adb471b228628158 Mon Sep 17 00:00:00 2001 From: Michael D'Angelo Date: Fri, 22 May 2026 14:42:35 -0400 Subject: [PATCH 3/6] fix: dedupe HF shard families across snapshots --- modelaudit/core.py | 5 +++-- tests/test_core.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/modelaudit/core.py b/modelaudit/core.py index a045763d0..788e8cc3d 100644 --- a/modelaudit/core.py +++ b/modelaudit/core.py @@ -661,9 +661,10 @@ def scan_model_directory_or_file( for target_path in target_paths: target_str = str(target_path) - if target_str in scanned_paths: + dedupe_target_str = str(target_path.resolve()) if is_hf_cache_symlink else target_str + if dedupe_target_str in scanned_paths: continue - scanned_paths.add(target_str) + scanned_paths.add(dedupe_target_str) if not is_hf_cache_symlink and not is_within_directory(str(base_dir), str(target_path)): _add_issue_to_model( diff --git a/tests/test_core.py b/tests/test_core.py index a6f171627..e82f25883 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -286,6 +286,48 @@ def fake_scan_file(path: str, config: dict[str, Any] | None = None) -> ScanResul assert result.files_scanned == 1 +@pytest.mark.usefixtures("requires_symlinks") +def test_directory_scan_deduplicates_identical_hf_shard_families_across_snapshots( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + hf_home = tmp_path / "hf-home" + monkeypatch.setenv("HF_HOME", str(hf_home)) + cache_dir = hf_home / "hub" / "models--org--model" + blobs_dir = cache_dir / "blobs" + blobs_dir.mkdir(parents=True) + blob_paths: list[Path] = [] + for shard_index in range(1, 3): + blob_path = blobs_dir / f"blob-{shard_index}" + blob_path.write_bytes(f"shared-hf-shard-{shard_index}".encode()) + blob_paths.append(blob_path.resolve()) + for revision in ("abc123", "def456"): + snapshots_dir = cache_dir / "snapshots" / revision + snapshots_dir.mkdir(parents=True, exist_ok=True) + (snapshots_dir / f"model-{shard_index:05d}-of-00002.safetensors").symlink_to( + Path("../../blobs") / blob_path.name + ) + + captured_configs: list[dict[str, Any]] = [] + calls: list[str] = [] + + def fake_scan_file(path: str, config: dict[str, Any] | None = None) -> ScanResult: + calls.append(path) + captured_configs.append(dict(config or {})) + return _mock_sharded_scan_result(sum(blob_path.stat().st_size for blob_path in blob_paths)) + + monkeypatch.setattr(core_module, "scan_file", fake_scan_file) + + result = core_module.scan_model_directory_or_file(str(cache_dir / "snapshots"), cache_scan_results=False) + + material_config = normalize_material_scan_config(captured_configs[0]) + fingerprint = material_config[core_module._SHARD_FAMILY_CACHE_FINGERPRINT_CONFIG_KEY] + assert len(calls) == 1 + assert result.files_scanned == len(blob_paths) + assert result.bytes_scanned == sum(blob_path.stat().st_size for blob_path in blob_paths) + assert {member["path"] for member in fingerprint["members"]} == {str(blob_path) for blob_path in blob_paths} + + def test_scan_file_passes_shard_allowlist_to_advanced_handler( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, From 465bf521375940f60df235b2e3c1534f6dd8ef57 Mon Sep 17 00:00:00 2001 From: Michael D'Angelo Date: Fri, 22 May 2026 23:19:03 -0400 Subject: [PATCH 4/6] fix: preserve hf shard coverage and scanner config --- modelaudit/core.py | 30 ++++++++++- modelaudit/utils/file/handlers.py | 22 ++++++-- tests/test_core.py | 52 +++++++++++++++++++ .../utils/file/test_advanced_file_handler.py | 21 ++++++++ 4 files changed, 120 insertions(+), 5 deletions(-) diff --git a/modelaudit/core.py b/modelaudit/core.py index 788e8cc3d..18a30408d 100644 --- a/modelaudit/core.py +++ b/modelaudit/core.py @@ -597,6 +597,7 @@ def scan_model_directory_or_file( files_to_scan: list[str] = [] shard_family_representatives: dict[_ShardFamilyKey, str] = {} shard_family_paths: dict[_ShardFamilyKey, set[str]] = {} + complete_hf_shard_families: set[_ShardFamilyKey] = set() directory_discovery_started_at = _start_phase_timing(phase_timings) for root, _, files in os.walk(path, followlinks=False): for file in files: @@ -661,7 +662,12 @@ def scan_model_directory_or_file( for target_path in target_paths: target_str = str(target_path) - dedupe_target_str = str(target_path.resolve()) if is_hf_cache_symlink else target_str + shard_family_key = _shard_family_key_for_path(target_str) + dedupe_target_str = ( + str(target_path.resolve()) + if is_hf_cache_symlink and shard_family_key is None + else target_str + ) if dedupe_target_str in scanned_paths: continue scanned_paths.add(dedupe_target_str) @@ -677,7 +683,6 @@ def scan_model_directory_or_file( continue # Add to files to scan list instead of scanning immediately - shard_family_key = _shard_family_key_for_path(target_str) if shard_family_key is not None: family_paths = shard_family_paths.setdefault(shard_family_key, set()) family_paths.add(target_str) @@ -685,6 +690,15 @@ def scan_model_directory_or_file( shard_family_representatives[shard_family_key] = target_str shard_info = ShardedModelDetector.detect_shards(target_str) if shard_info is not None: + expected_total_shards = shard_info.get("expected_total_shards") + if ( + is_hf_cache_symlink + and isinstance(expected_total_shards, int) + and shard_info.get("total_shards") == expected_total_shards + and "missing_shard_count" not in shard_info + and "inconsistent_expected_total_shards" not in shard_info + ): + complete_hf_shard_families.add(shard_family_key) for shard_path in shard_info.get("shards", []): if isinstance(shard_path, str): resolved_shard_path = str(Path(shard_path).resolve()) @@ -715,8 +729,20 @@ def scan_model_directory_or_file( _finish_phase_timing(phase_timings, "directory_discovery", directory_discovery_started_at) scan_entries: list[_ScanEntry] = [(file_path, [file_path], None) for file_path in files_to_scan] + seen_complete_hf_shard_families: set[tuple[str, ...]] = set() for shard_family_key, representative_file in shard_family_representatives.items(): ordered_family_paths = sorted(shard_family_paths.get(shard_family_key, {representative_file})) + expected_total_shards = shard_family_key[2] + if ( + is_hf_cache + and expected_total_shards is not None + and shard_family_key in complete_hf_shard_families + and len(ordered_family_paths) == expected_total_shards + ): + resolved_family_paths = tuple(sorted(str(Path(path).resolve()) for path in ordered_family_paths)) + if resolved_family_paths in seen_complete_hf_shard_families: + continue + seen_complete_hf_shard_families.add(resolved_family_paths) scan_entries.append((representative_file, ordered_family_paths, shard_family_key)) # Second pass: scan every non-shard path independently and every shard diff --git a/modelaudit/utils/file/handlers.py b/modelaudit/utils/file/handlers.py index 8c7fc6999..612dab583 100644 --- a/modelaudit/utils/file/handlers.py +++ b/modelaudit/utils/file/handlers.py @@ -340,16 +340,23 @@ def _analyze_window(self, data: bytes, offset: int) -> "ScanResult": class ParallelShardHandler: """Scan multiple model shards in parallel.""" - def __init__(self, shard_info: dict[str, Any], scanner_class: type): + def __init__( + self, + shard_info: dict[str, Any], + scanner_class: type, + scanner_config: dict[str, Any] | None = None, + ): """ Initialize parallel shard scanner. Args: shard_info: Information about model shards scanner_class: Scanner class to use + scanner_config: Configuration to preserve for each shard scanner """ self.shard_info = shard_info self.scanner_class = scanner_class + self.scanner_config = scanner_config def scan_shards(self, progress_callback: Callable[[str, float], None] | None = None) -> "ScanResult": from ...scanner_results import IssueSeverity, ScanResult @@ -425,7 +432,11 @@ def _scan_single_shard(self, shard_path: str) -> "ScanResult": from ...scanner_results import ScanResult """Scan a single shard file.""" - scanner = self.scanner_class() + scanner = ( + self.scanner_class(config=dict(self.scanner_config)) + if self.scanner_config is not None + else self.scanner_class() + ) result: ScanResult = scanner.scan(shard_path) return result @@ -560,7 +571,12 @@ def _scan_sharded_model(self) -> "ScanResult": # Scan shards in parallel shard_scan_success = True if self.shard_info: - parallel_scanner = ParallelShardHandler(self.shard_info, self.scanner.__class__) + scanner_config = getattr(self.scanner, "config", None) + parallel_scanner = ParallelShardHandler( + self.shard_info, + self.scanner.__class__, + scanner_config=dict(scanner_config) if isinstance(scanner_config, dict) else None, + ) shard_results = parallel_scanner.scan_shards(self.progress_callback) shard_scan_success = bool(shard_results.success) result.merge(shard_results) diff --git a/tests/test_core.py b/tests/test_core.py index e82f25883..962cd3d7f 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -328,6 +328,58 @@ def fake_scan_file(path: str, config: dict[str, Any] | None = None) -> ScanResul assert {member["path"] for member in fingerprint["members"]} == {str(blob_path) for blob_path in blob_paths} +@pytest.mark.usefixtures("requires_symlinks") +def test_directory_scan_reports_incomplete_hf_snapshot_after_shared_blob_dedupe( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + hf_home = tmp_path / "hf-home" + monkeypatch.setenv("HF_HOME", str(hf_home)) + cache_dir = hf_home / "hub" / "models--org--model" + blobs_dir = cache_dir / "blobs" + blobs_dir.mkdir(parents=True) + blob_paths: list[Path] = [] + for shard_index in range(1, 3): + blob_path = blobs_dir / f"blob-{shard_index}" + blob_path.write_bytes(f"shared-hf-shard-{shard_index}".encode()) + blob_paths.append(blob_path.resolve()) + full_snapshot = cache_dir / "snapshots" / "abc123" + full_snapshot.mkdir(parents=True, exist_ok=True) + (full_snapshot / f"model-{shard_index:05d}-of-00002.safetensors").symlink_to( + Path("../../blobs") / blob_path.name + ) + + partial_snapshot = cache_dir / "snapshots" / "def456" + partial_snapshot.mkdir(parents=True) + (partial_snapshot / "model-00001-of-00002.safetensors").symlink_to(Path("../../blobs") / blob_paths[0].name) + + captured_configs: list[dict[str, Any]] = [] + calls: list[str] = [] + + def fake_scan_file(path: str, config: dict[str, Any] | None = None) -> ScanResult: + calls.append(path) + captured_configs.append(dict(config or {})) + material_config = normalize_material_scan_config(captured_configs[-1]) + fingerprint = material_config[core_module._SHARD_FAMILY_CACHE_FINGERPRINT_CONFIG_KEY] + member_paths = [Path(member["path"]) for member in fingerprint["members"]] + return _mock_sharded_scan_result( + sum(member_path.stat().st_size for member_path in member_paths), + missing_shards=1 if len(member_paths) == 1 else 0, + ) + + monkeypatch.setattr(core_module, "scan_file", fake_scan_file) + + result = core_module.scan_model_directory_or_file(str(cache_dir / "snapshots"), cache_scan_results=False) + + material_configs = [normalize_material_scan_config(config) for config in captured_configs] + fingerprints = [config[core_module._SHARD_FAMILY_CACHE_FINGERPRINT_CONFIG_KEY] for config in material_configs] + coverage_checks = [check for check in result.checks if check.name == "Sharded Model Coverage Check"] + assert len(calls) == 2 + assert sorted(len(fingerprint["members"]) for fingerprint in fingerprints) == [1, 2] + assert len(coverage_checks) == 1 + assert coverage_checks[0].details["missing_shard_count"] == 1 + + def test_scan_file_passes_shard_allowlist_to_advanced_handler( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, diff --git a/tests/utils/file/test_advanced_file_handler.py b/tests/utils/file/test_advanced_file_handler.py index d7f04bff7..d638a2d6c 100644 --- a/tests/utils/file/test_advanced_file_handler.py +++ b/tests/utils/file/test_advanced_file_handler.py @@ -353,6 +353,27 @@ def test_sharded_model_honors_allowed_shard_paths(self, tmp_path: Path) -> None: assert shard_detection.details["shards"] == [str(shard_one)] assert result.bytes_scanned == shard_one.stat().st_size + def test_sharded_model_preserves_scanner_config_for_each_shard(self, tmp_path: Path) -> None: + """Shard fanout should retain caller configuration for each scanner instance.""" + shard_one = tmp_path / "model-00001-of-00002.safetensors" + shard_two = tmp_path / "model-00002-of-00002.safetensors" + shard_one.write_bytes(b"one") + shard_two.write_bytes(b"two") + captured_configs: list[dict[str, Any]] = [] + + class ConfiguredShardScanner(CompletingShardScanner): + def __init__(self, config: dict[str, Any] | None = None) -> None: + self.config = dict(config or {}) + captured_configs.append(self.config) + + scanner = ConfiguredShardScanner({"max_tensor_bytes": 7}) + captured_configs.clear() + + result = AdvancedFileHandler(str(shard_one), scanner).scan() + + assert result.success is True + assert captured_configs == [{"max_tensor_bytes": 7}, {"max_tensor_bytes": 7}] + def test_cached_advanced_scan_keys_allowed_shard_paths(self, tmp_path: Path) -> None: """Different validated shard allowlists must not share advanced-scan cache entries.""" shard_one = tmp_path / "checkpoint_1.pt" From 4f037a7f70015c4dff392aa304a80c61bbce83c1 Mon Sep 17 00:00:00 2001 From: Michael D'Angelo Date: Fri, 22 May 2026 23:30:30 -0400 Subject: [PATCH 5/6] fix: handle hf shard alias edge cases --- modelaudit/core.py | 25 ++-- modelaudit/utils/file/handlers.py | 11 +- tests/test_core.py | 117 ++++++++++++++++++ .../utils/file/test_advanced_file_handler.py | 20 +++ 4 files changed, 164 insertions(+), 9 deletions(-) diff --git a/modelaudit/core.py b/modelaudit/core.py index 18a30408d..18a5f57a4 100644 --- a/modelaudit/core.py +++ b/modelaudit/core.py @@ -592,6 +592,7 @@ def scan_model_directory_or_file( hf_cache_root = _find_hf_cache_root(base_dir) is_hf_cache = hf_cache_root is not None scanned_paths: set[str] = set() + hf_shard_blob_paths: set[str] = set() # First pass: collect all file paths that need scanning files_to_scan: list[str] = [] @@ -619,11 +620,11 @@ def scan_model_directory_or_file( if resolved_file is None: continue snapshot_path = Path(file_path).absolute() - scan_source = ( - snapshot_path - if is_hf_cache_symlink and _shard_family_key_for_path(str(snapshot_path)) is not None - else resolved_file + snapshot_shard_family_key = _shard_family_key_for_path(str(snapshot_path)) + route_hf_shard_alias = ( + is_hf_cache_symlink and resolved_file.exists() and snapshot_shard_family_key is not None ) + scan_source = snapshot_path if route_hf_shard_alias else resolved_file # Skip non-model files early if filtering is enabled # Note: skip_file_types parameter already contains the correct value @@ -663,6 +664,9 @@ def scan_model_directory_or_file( for target_path in target_paths: target_str = str(target_path) shard_family_key = _shard_family_key_for_path(target_str) + is_hf_shard_alias = route_hf_shard_alias and target_path == scan_source + if is_hf_shard_alias: + hf_shard_blob_paths.add(str(target_path.resolve())) dedupe_target_str = ( str(target_path.resolve()) if is_hf_cache_symlink and shard_family_key is None @@ -728,8 +732,14 @@ def scan_model_directory_or_file( files_to_scan.append(target_str) _finish_phase_timing(phase_timings, "directory_discovery", directory_discovery_started_at) + if hf_shard_blob_paths: + files_to_scan = [ + file_path + for file_path in files_to_scan + if str(Path(file_path).resolve()) not in hf_shard_blob_paths + ] scan_entries: list[_ScanEntry] = [(file_path, [file_path], None) for file_path in files_to_scan] - seen_complete_hf_shard_families: set[tuple[str, ...]] = set() + seen_complete_hf_shard_families: set[tuple[str, tuple[str, ...]]] = set() for shard_family_key, representative_file in shard_family_representatives.items(): ordered_family_paths = sorted(shard_family_paths.get(shard_family_key, {representative_file})) expected_total_shards = shard_family_key[2] @@ -740,9 +750,10 @@ def scan_model_directory_or_file( and len(ordered_family_paths) == expected_total_shards ): resolved_family_paths = tuple(sorted(str(Path(path).resolve()) for path in ordered_family_paths)) - if resolved_family_paths in seen_complete_hf_shard_families: + family_dedupe_key = (shard_family_key[1], resolved_family_paths) + if family_dedupe_key in seen_complete_hf_shard_families: continue - seen_complete_hf_shard_families.add(resolved_family_paths) + seen_complete_hf_shard_families.add(family_dedupe_key) scan_entries.append((representative_file, ordered_family_paths, shard_family_key)) # Second pass: scan every non-shard path independently and every shard diff --git a/modelaudit/utils/file/handlers.py b/modelaudit/utils/file/handlers.py index 612dab583..01e8d0e22 100644 --- a/modelaudit/utils/file/handlers.py +++ b/modelaudit/utils/file/handlers.py @@ -141,6 +141,7 @@ def detect_shards( shard_info: dict[str, Any] = {"pattern": pattern, "current_file": file_path, "shards": []} expected_totals: set[int] = set() present_indices: set[int] = set() + total_size = 0 # Find all related shards for file in dir_path.glob("*"): @@ -148,7 +149,12 @@ def detect_shards( if file_match: if allowed_path_set is not None and str(file.resolve()) not in allowed_path_set: continue + try: + shard_size = os.path.getsize(file) + except OSError: + continue shard_info["shards"].append(str(file)) + total_size += shard_size if file_match.lastindex: with suppress(IndexError, ValueError): present_indices.add(int(file_match.group(1))) @@ -156,6 +162,9 @@ def detect_shards( with suppress(IndexError, ValueError): expected_totals.add(int(file_match.group(2))) + if not shard_info["shards"]: + return None + shard_info["shards"].sort() shard_info["total_shards"] = len(shard_info["shards"]) if expected_totals: @@ -173,8 +182,6 @@ def detect_shards( shard_info["missing_shard_indices"] = missing_indices shard_info["missing_shard_indices_truncated"] = missing_indices_truncated - # Calculate total size - total_size = sum(os.path.getsize(s) for s in shard_info["shards"]) shard_info["total_size"] = total_size return shard_info diff --git a/tests/test_core.py b/tests/test_core.py index 962cd3d7f..07661cec0 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -380,6 +380,123 @@ def fake_scan_file(path: str, config: dict[str, Any] | None = None) -> ScanResul assert coverage_checks[0].details["missing_shard_count"] == 1 +@pytest.mark.usefixtures("requires_symlinks") +def test_directory_scan_keeps_distinct_hf_shard_filename_patterns( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + hf_home = tmp_path / "hf-home" + monkeypatch.setenv("HF_HOME", str(hf_home)) + cache_dir = hf_home / "hub" / "models--org--model" + blobs_dir = cache_dir / "blobs" + blobs_dir.mkdir(parents=True) + blob_paths: list[Path] = [] + for shard_index in range(1, 3): + blob_path = blobs_dir / f"blob-{shard_index}" + blob_path.write_bytes(f"shared-hf-shard-{shard_index}".encode()) + blob_paths.append(blob_path.resolve()) + for revision, filename in ( + ("abc123", f"model-{shard_index:05d}-of-00002.safetensors"), + ("def456", f"pytorch_model-{shard_index:05d}-of-00002.bin"), + ): + snapshot = cache_dir / "snapshots" / revision + snapshot.mkdir(parents=True, exist_ok=True) + (snapshot / filename).symlink_to(Path("../../blobs") / blob_path.name) + + calls: list[str] = [] + + def fake_scan_file(path: str, config: dict[str, Any] | None = None) -> ScanResult: + calls.append(path) + return _mock_sharded_scan_result(sum(blob_path.stat().st_size for blob_path in blob_paths)) + + monkeypatch.setattr(core_module, "scan_file", fake_scan_file) + + result = core_module.scan_model_directory_or_file(str(cache_dir / "snapshots"), cache_scan_results=False) + + assert len(calls) == 2 + assert {Path(call).suffix for call in calls} == {".bin", ".safetensors"} + assert result.files_scanned == 4 + + +@pytest.mark.usefixtures("requires_symlinks") +def test_directory_scan_deduplicates_hf_shard_aliases_against_raw_blobs( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + hf_home = tmp_path / "hf-home" + monkeypatch.setenv("HF_HOME", str(hf_home)) + cache_dir = hf_home / "hub" / "models--org--model" + snapshot = cache_dir / "snapshots" / "abc123" + blobs_dir = cache_dir / "blobs" + snapshot.mkdir(parents=True) + blobs_dir.mkdir() + blob_paths: list[Path] = [] + for shard_index in range(1, 3): + blob_path = blobs_dir / f"blob-{shard_index}.safetensors" + blob_path.write_bytes(f"hf-shard-{shard_index}".encode()) + blob_paths.append(blob_path.resolve()) + (snapshot / f"model-{shard_index:05d}-of-00002.safetensors").symlink_to(Path("../../blobs") / blob_path.name) + + calls: list[str] = [] + + def fake_scan_file(path: str, config: dict[str, Any] | None = None) -> ScanResult: + calls.append(path) + return _mock_sharded_scan_result(sum(blob_path.stat().st_size for blob_path in blob_paths)) + + monkeypatch.setattr(core_module, "scan_file", fake_scan_file) + + result = core_module.scan_model_directory_or_file(str(cache_dir), cache_scan_results=False) + + assert len(calls) == 1 + assert Path(calls[0]).parent == snapshot + assert result.files_scanned == 2 + + +@pytest.mark.usefixtures("requires_symlinks") +def test_directory_scan_handles_broken_hf_shard_alias_per_file( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + hf_home = tmp_path / "hf-home" + monkeypatch.setenv("HF_HOME", str(hf_home)) + cache_dir = hf_home / "hub" / "models--org--model" + snapshot = cache_dir / "snapshots" / "abc123" + blobs_dir = cache_dir / "blobs" + snapshot.mkdir(parents=True) + blobs_dir.mkdir() + blob_path = blobs_dir / "blob-1" + blob_path.write_bytes(b"hf-shard-1") + (snapshot / "model-00001-of-00002.safetensors").symlink_to(Path("../../blobs") / blob_path.name) + missing_blob = blobs_dir / "missing-blob" + (snapshot / "model-00002-of-00002.safetensors").symlink_to(Path("../../blobs") / missing_blob.name) + + calls: list[str] = [] + + def fake_scan_file(path: str, config: dict[str, Any] | None = None) -> ScanResult: + calls.append(path) + if Path(path) == missing_blob: + result = ScanResult(scanner_name="error") + result.add_check( + name="File Size Check", + passed=False, + message="Error checking file size: missing blob", + severity=IssueSeverity.INFO, + ) + result.finish(success=False) + return result + return _mock_sharded_scan_result(blob_path.stat().st_size, missing_shards=1) + + monkeypatch.setattr(core_module, "scan_file", fake_scan_file) + + result = core_module.scan_model_directory_or_file(str(snapshot), cache_scan_results=False) + + coverage_checks = [check for check in result.checks if check.name == "Sharded Model Coverage Check"] + assert len(calls) == 2 + assert str(missing_blob) in calls + assert len(coverage_checks) == 1 + assert any(check.name == "File Size Check" for check in result.checks) + + def test_scan_file_passes_shard_allowlist_to_advanced_handler( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, diff --git a/tests/utils/file/test_advanced_file_handler.py b/tests/utils/file/test_advanced_file_handler.py index d638a2d6c..1cffad29f 100644 --- a/tests/utils/file/test_advanced_file_handler.py +++ b/tests/utils/file/test_advanced_file_handler.py @@ -335,6 +335,26 @@ def test_sharded_model_missing_shards_marks_scan_inconclusive(self, tmp_path: Pa assert coverage_checks[0].details["missing_shard_indices"] == [2] assert coverage_checks[0].details["missing_shard_indices_truncated"] is False + def test_sharded_model_broken_shard_marks_scan_inconclusive( + self, + tmp_path: Path, + requires_symlinks: None, + ) -> None: + """Unreadable shard links should be reported as missing instead of aborting expansion.""" + shard_one = tmp_path / "model-00001-of-00002.safetensors" + shard_two = tmp_path / "model-00002-of-00002.safetensors" + shard_one.write_bytes(b"safe") + shard_two.symlink_to(tmp_path / "missing-shard") + + handler = AdvancedFileHandler(str(shard_one), CompletingShardScanner()) + result = handler.scan() + + coverage_checks = [check for check in result.checks if check.name == "Sharded Model Coverage Check"] + assert result.success is False + assert result.bytes_scanned == shard_one.stat().st_size + assert len(coverage_checks) == 1 + assert coverage_checks[0].details["missing_shard_count"] == 1 + def test_sharded_model_honors_allowed_shard_paths(self, tmp_path: Path) -> None: """Restricted shard scans must not expand beyond the validated allowlist.""" shard_one = tmp_path / "model-00001-of-00002.safetensors" From d3b478e6a6677e4f773b3569f7d6aeb6ba8282bd Mon Sep 17 00:00:00 2001 From: Michael D'Angelo Date: Fri, 22 May 2026 23:39:40 -0400 Subject: [PATCH 6/6] fix: report unreadable shards without total counts --- modelaudit/utils/file/handlers.py | 24 +++++++++++++++++++ .../utils/file/test_advanced_file_handler.py | 24 +++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/modelaudit/utils/file/handlers.py b/modelaudit/utils/file/handlers.py index 01e8d0e22..568a9f3f9 100644 --- a/modelaudit/utils/file/handlers.py +++ b/modelaudit/utils/file/handlers.py @@ -141,6 +141,7 @@ def detect_shards( shard_info: dict[str, Any] = {"pattern": pattern, "current_file": file_path, "shards": []} expected_totals: set[int] = set() present_indices: set[int] = set() + unreadable_shards: list[str] = [] total_size = 0 # Find all related shards @@ -152,6 +153,7 @@ def detect_shards( try: shard_size = os.path.getsize(file) except OSError: + unreadable_shards.append(str(file)) continue shard_info["shards"].append(str(file)) total_size += shard_size @@ -167,6 +169,9 @@ def detect_shards( shard_info["shards"].sort() shard_info["total_shards"] = len(shard_info["shards"]) + if unreadable_shards: + shard_info["unreadable_shards"] = sorted(unreadable_shards) + shard_info["unreadable_shard_count"] = len(unreadable_shards) if expected_totals: expected_total = max(expected_totals) shard_info["expected_total_shards"] = expected_total @@ -588,6 +593,7 @@ def _scan_sharded_model(self) -> "ScanResult": shard_scan_success = bool(shard_results.success) result.merge(shard_results) missing_count = self.shard_info.get("missing_shard_count") + unreadable_count = self.shard_info.get("unreadable_shard_count") if isinstance(missing_count, int) and missing_count > 0: _mark_inconclusive_scan_outcome(result, "missing_model_shards") result.add_check( @@ -603,11 +609,29 @@ def _scan_sharded_model(self) -> "ScanResult": "missing_shard_indices_truncated": self.shard_info.get( "missing_shard_indices_truncated", False ), + "unreadable_shard_count": self.shard_info.get("unreadable_shard_count", 0), + "unreadable_shards": self.shard_info.get("unreadable_shards", []), "analysis_incomplete": True, "scan_outcome": "inconclusive", "scan_outcome_reason": "missing_model_shards", }, ) + elif isinstance(unreadable_count, int) and unreadable_count > 0: + _mark_inconclusive_scan_outcome(result, "unreadable_model_shards") + result.add_check( + name="Sharded Model Coverage Check", + passed=False, + message=f"Unable to read {unreadable_count} model shard(s); scan coverage is incomplete.", + severity=IssueSeverity.INFO, + details={ + "present_total_shards": self.shard_info.get("total_shards"), + "unreadable_shard_count": unreadable_count, + "unreadable_shards": self.shard_info.get("unreadable_shards", []), + "analysis_incomplete": True, + "scan_outcome": "inconclusive", + "scan_outcome_reason": "unreadable_model_shards", + }, + ) result.finish(success=shard_scan_success and not result.has_errors and "scan_outcome" not in result.metadata) return result diff --git a/tests/utils/file/test_advanced_file_handler.py b/tests/utils/file/test_advanced_file_handler.py index 1cffad29f..daedebe57 100644 --- a/tests/utils/file/test_advanced_file_handler.py +++ b/tests/utils/file/test_advanced_file_handler.py @@ -354,6 +354,30 @@ def test_sharded_model_broken_shard_marks_scan_inconclusive( assert result.bytes_scanned == shard_one.stat().st_size assert len(coverage_checks) == 1 assert coverage_checks[0].details["missing_shard_count"] == 1 + assert coverage_checks[0].details["unreadable_shard_count"] == 1 + assert coverage_checks[0].details["unreadable_shards"] == [str(shard_two)] + + def test_sharded_model_broken_shard_without_declared_total_marks_scan_inconclusive( + self, + tmp_path: Path, + requires_symlinks: None, + ) -> None: + """Unreadable members cannot be silently dropped when a family has no declared total.""" + shard_one = tmp_path / "checkpoint_1.pt" + shard_two = tmp_path / "checkpoint_2.pt" + shard_one.write_bytes(b"safe") + shard_two.symlink_to(tmp_path / "missing-shard") + + handler = AdvancedFileHandler(str(shard_one), CompletingShardScanner()) + result = handler.scan() + + coverage_checks = [check for check in result.checks if check.name == "Sharded Model Coverage Check"] + assert result.success is False + assert result.bytes_scanned == shard_one.stat().st_size + assert "unreadable_model_shards" in result.metadata["scan_outcome_reasons"] + assert len(coverage_checks) == 1 + assert coverage_checks[0].details["unreadable_shard_count"] == 1 + assert coverage_checks[0].details["unreadable_shards"] == [str(shard_two)] def test_sharded_model_honors_allowed_shard_paths(self, tmp_path: Path) -> None: """Restricted shard scans must not expand beyond the validated allowlist."""