Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- avoid repeatedly scanning sharded model families during directory scans
- keep shard sibling discovery within the requested scan root
- preserve per-shard metadata when aggregating sharded model families
- preserve HuggingFace snapshot shard paths while grouping cache-backed families
- stop flagging a false-positive ONNX Python operator when tensor weight bytes coincidentally spell `PyOp`
- distinguish ASCII-serialized Torch7 artifacts from plain PyTorch source text

Expand Down
56 changes: 51 additions & 5 deletions modelaudit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,11 +592,13 @@ def scan_model_directory_or_file(
hf_cache_root = _find_hf_cache_root(base_dir)
is_hf_cache = hf_cache_root is not None
scanned_paths: set[str] = set()
hf_shard_blob_paths: set[str] = set()

# First pass: collect all file paths that need scanning
files_to_scan: list[str] = []
shard_family_representatives: dict[_ShardFamilyKey, str] = {}
shard_family_paths: dict[_ShardFamilyKey, set[str]] = {}
complete_hf_shard_families: set[_ShardFamilyKey] = set()
directory_discovery_started_at = _start_phase_timing(phase_timings)
for root, _, files in os.walk(path, followlinks=False):
for file in files:
Expand All @@ -617,6 +619,12 @@ def scan_model_directory_or_file(
)
if resolved_file is None:
continue
snapshot_path = Path(file_path).absolute()
snapshot_shard_family_key = _shard_family_key_for_path(str(snapshot_path))
route_hf_shard_alias = (
is_hf_cache_symlink and resolved_file.exists() and snapshot_shard_family_key is not None
)
scan_source = snapshot_path if route_hf_shard_alias else resolved_file

# Skip non-model files early if filtering is enabled
# Note: skip_file_types parameter already contains the correct value
Expand Down Expand Up @@ -647,17 +655,26 @@ def scan_model_directory_or_file(
continue

# Handle DVC files and get target paths
target_paths = [resolved_file]
target_paths = [scan_source]
if file.endswith(".dvc"):
dvc_targets = resolve_dvc_file(file_path)
if dvc_targets:
target_paths = [Path(t).resolve() for t in dvc_targets]

for target_path in target_paths:
target_str = str(target_path)
if target_str in scanned_paths:
shard_family_key = _shard_family_key_for_path(target_str)
is_hf_shard_alias = route_hf_shard_alias and target_path == scan_source
if is_hf_shard_alias:
hf_shard_blob_paths.add(str(target_path.resolve()))
dedupe_target_str = (
str(target_path.resolve())
if is_hf_cache_symlink and shard_family_key is None
else target_str
Comment thread
mldangelo-oai marked this conversation as resolved.
)
if dedupe_target_str in scanned_paths:
continue
scanned_paths.add(target_str)
scanned_paths.add(dedupe_target_str)

if not is_hf_cache_symlink and not is_within_directory(str(base_dir), str(target_path)):
_add_issue_to_model(
Expand All @@ -670,14 +687,22 @@ def scan_model_directory_or_file(
continue

# Add to files to scan list instead of scanning immediately
shard_family_key = _shard_family_key_for_path(target_str)
if shard_family_key is not None:
family_paths = shard_family_paths.setdefault(shard_family_key, set())
family_paths.add(target_str)
if shard_family_key not in shard_family_representatives:
shard_family_representatives[shard_family_key] = target_str
shard_info = ShardedModelDetector.detect_shards(target_str)
if shard_info is not None:
expected_total_shards = shard_info.get("expected_total_shards")
if (
is_hf_cache_symlink
and isinstance(expected_total_shards, int)
and shard_info.get("total_shards") == expected_total_shards
and "missing_shard_count" not in shard_info
and "inconsistent_expected_total_shards" not in shard_info
):
complete_hf_shard_families.add(shard_family_key)
for shard_path in shard_info.get("shards", []):
if isinstance(shard_path, str):
resolved_shard_path = str(Path(shard_path).resolve())
Expand All @@ -690,7 +715,9 @@ def scan_model_directory_or_file(
resolved_shard_path,
)
)
if shard_in_base_dir or shard_in_hf_blobs:
if shard_in_hf_blobs:
family_paths.add(str(Path(shard_path).absolute()))
elif shard_in_base_dir:
family_paths.add(resolved_shard_path)
else:
_add_issue_to_model(
Expand All @@ -705,9 +732,28 @@ def scan_model_directory_or_file(
files_to_scan.append(target_str)
_finish_phase_timing(phase_timings, "directory_discovery", directory_discovery_started_at)

if hf_shard_blob_paths:
files_to_scan = [
file_path
for file_path in files_to_scan
if str(Path(file_path).resolve()) not in hf_shard_blob_paths
]
scan_entries: list[_ScanEntry] = [(file_path, [file_path], None) for file_path in files_to_scan]
seen_complete_hf_shard_families: set[tuple[str, tuple[str, ...]]] = set()
for shard_family_key, representative_file in shard_family_representatives.items():
ordered_family_paths = sorted(shard_family_paths.get(shard_family_key, {representative_file}))
expected_total_shards = shard_family_key[2]
if (
is_hf_cache
and expected_total_shards is not None
and shard_family_key in complete_hf_shard_families
and len(ordered_family_paths) == expected_total_shards
):
resolved_family_paths = tuple(sorted(str(Path(path).resolve()) for path in ordered_family_paths))
family_dedupe_key = (shard_family_key[1], resolved_family_paths)
if family_dedupe_key in seen_complete_hf_shard_families:
continue
Comment thread
mldangelo-oai marked this conversation as resolved.
seen_complete_hf_shard_families.add(family_dedupe_key)
scan_entries.append((representative_file, ordered_family_paths, shard_family_key))

# Second pass: scan every non-shard path independently and every shard
Expand Down
33 changes: 28 additions & 5 deletions modelaudit/utils/file/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,21 +141,30 @@ def detect_shards(
shard_info: dict[str, Any] = {"pattern": pattern, "current_file": file_path, "shards": []}
expected_totals: set[int] = set()
present_indices: set[int] = set()
total_size = 0

# Find all related shards
for file in dir_path.glob("*"):
file_match = re.fullmatch(pattern, file.name)
if file_match:
if allowed_path_set is not None and str(file.resolve()) not in allowed_path_set:
continue
try:
shard_size = os.path.getsize(file)
except OSError:
continue
Comment thread
mldangelo-oai marked this conversation as resolved.
shard_info["shards"].append(str(file))
total_size += shard_size
if file_match.lastindex:
with suppress(IndexError, ValueError):
present_indices.add(int(file_match.group(1)))
if (file_match.lastindex or 0) >= 2:
with suppress(IndexError, ValueError):
expected_totals.add(int(file_match.group(2)))

if not shard_info["shards"]:
return None

shard_info["shards"].sort()
shard_info["total_shards"] = len(shard_info["shards"])
if expected_totals:
Expand All @@ -173,8 +182,6 @@ def detect_shards(
shard_info["missing_shard_indices"] = missing_indices
shard_info["missing_shard_indices_truncated"] = missing_indices_truncated

# Calculate total size
total_size = sum(os.path.getsize(s) for s in shard_info["shards"])
shard_info["total_size"] = total_size

return shard_info
Expand Down Expand Up @@ -340,16 +347,23 @@ def _analyze_window(self, data: bytes, offset: int) -> "ScanResult":
class ParallelShardHandler:
"""Scan multiple model shards in parallel."""

def __init__(self, shard_info: dict[str, Any], scanner_class: type):
def __init__(
self,
shard_info: dict[str, Any],
scanner_class: type,
scanner_config: dict[str, Any] | None = None,
):
"""
Initialize parallel shard scanner.

Args:
shard_info: Information about model shards
scanner_class: Scanner class to use
scanner_config: Configuration to preserve for each shard scanner
"""
self.shard_info = shard_info
self.scanner_class = scanner_class
self.scanner_config = scanner_config

def scan_shards(self, progress_callback: Callable[[str, float], None] | None = None) -> "ScanResult":
from ...scanner_results import IssueSeverity, ScanResult
Expand Down Expand Up @@ -425,7 +439,11 @@ def _scan_single_shard(self, shard_path: str) -> "ScanResult":
from ...scanner_results import ScanResult

"""Scan a single shard file."""
scanner = self.scanner_class()
scanner = (
self.scanner_class(config=dict(self.scanner_config))
if self.scanner_config is not None
else self.scanner_class()
)
result: ScanResult = scanner.scan(shard_path)
return result

Expand Down Expand Up @@ -560,7 +578,12 @@ def _scan_sharded_model(self) -> "ScanResult":
# Scan shards in parallel
shard_scan_success = True
if self.shard_info:
parallel_scanner = ParallelShardHandler(self.shard_info, self.scanner.__class__)
scanner_config = getattr(self.scanner, "config", None)
parallel_scanner = ParallelShardHandler(
self.shard_info,
self.scanner.__class__,
scanner_config=dict(scanner_config) if isinstance(scanner_config, dict) else None,
)
shard_results = parallel_scanner.scan_shards(self.progress_callback)
shard_scan_success = bool(shard_results.success)
result.merge(shard_results)
Expand Down
Loading
Loading