Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- avoid repeatedly scanning sharded model families during directory scans
- keep shard sibling discovery within the requested scan root
- preserve per-shard metadata when aggregating sharded model families
- preserve HuggingFace snapshot shard paths while grouping cache-backed families
- stop flagging a false-positive ONNX Python operator when tensor weight bytes coincidentally spell `PyOp`
- distinguish ASCII-serialized Torch7 artifacts from plain PyTorch source text

Expand Down
56 changes: 51 additions & 5 deletions modelaudit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,11 +592,13 @@ def scan_model_directory_or_file(
hf_cache_root = _find_hf_cache_root(base_dir)
is_hf_cache = hf_cache_root is not None
scanned_paths: set[str] = set()
hf_shard_blob_paths: set[str] = set()

# First pass: collect all file paths that need scanning
files_to_scan: list[str] = []
shard_family_representatives: dict[_ShardFamilyKey, str] = {}
shard_family_paths: dict[_ShardFamilyKey, set[str]] = {}
complete_hf_shard_families: set[_ShardFamilyKey] = set()
directory_discovery_started_at = _start_phase_timing(phase_timings)
for root, _, files in os.walk(path, followlinks=False):
for file in files:
Expand All @@ -617,6 +619,12 @@ def scan_model_directory_or_file(
)
if resolved_file is None:
continue
snapshot_path = Path(file_path).absolute()
snapshot_shard_family_key = _shard_family_key_for_path(str(snapshot_path))
route_hf_shard_alias = (
is_hf_cache_symlink and resolved_file.exists() and snapshot_shard_family_key is not None
)
scan_source = snapshot_path if route_hf_shard_alias else resolved_file

# Skip non-model files early if filtering is enabled
# Note: skip_file_types parameter already contains the correct value
Expand Down Expand Up @@ -647,17 +655,26 @@ def scan_model_directory_or_file(
continue

# Handle DVC files and get target paths
target_paths = [resolved_file]
target_paths = [scan_source]
if file.endswith(".dvc"):
dvc_targets = resolve_dvc_file(file_path)
if dvc_targets:
target_paths = [Path(t).resolve() for t in dvc_targets]

for target_path in target_paths:
target_str = str(target_path)
if target_str in scanned_paths:
shard_family_key = _shard_family_key_for_path(target_str)
is_hf_shard_alias = route_hf_shard_alias and target_path == scan_source
if is_hf_shard_alias:
hf_shard_blob_paths.add(str(target_path.resolve()))
dedupe_target_str = (
str(target_path.resolve())
if is_hf_cache_symlink and shard_family_key is None
else target_str
Comment thread
mldangelo-oai marked this conversation as resolved.
)
if dedupe_target_str in scanned_paths:
continue
scanned_paths.add(target_str)
scanned_paths.add(dedupe_target_str)

if not is_hf_cache_symlink and not is_within_directory(str(base_dir), str(target_path)):
_add_issue_to_model(
Expand All @@ -670,14 +687,22 @@ def scan_model_directory_or_file(
continue

# Add to files to scan list instead of scanning immediately
shard_family_key = _shard_family_key_for_path(target_str)
if shard_family_key is not None:
family_paths = shard_family_paths.setdefault(shard_family_key, set())
family_paths.add(target_str)
if shard_family_key not in shard_family_representatives:
shard_family_representatives[shard_family_key] = target_str
shard_info = ShardedModelDetector.detect_shards(target_str)
if shard_info is not None:
expected_total_shards = shard_info.get("expected_total_shards")
if (
is_hf_cache_symlink
and isinstance(expected_total_shards, int)
and shard_info.get("total_shards") == expected_total_shards
and "missing_shard_count" not in shard_info
and "inconsistent_expected_total_shards" not in shard_info
):
complete_hf_shard_families.add(shard_family_key)
for shard_path in shard_info.get("shards", []):
if isinstance(shard_path, str):
resolved_shard_path = str(Path(shard_path).resolve())
Expand All @@ -690,7 +715,9 @@ def scan_model_directory_or_file(
resolved_shard_path,
)
)
if shard_in_base_dir or shard_in_hf_blobs:
if shard_in_hf_blobs:
family_paths.add(str(Path(shard_path).absolute()))
elif shard_in_base_dir:
family_paths.add(resolved_shard_path)
else:
_add_issue_to_model(
Expand All @@ -705,9 +732,28 @@ def scan_model_directory_or_file(
files_to_scan.append(target_str)
_finish_phase_timing(phase_timings, "directory_discovery", directory_discovery_started_at)

if hf_shard_blob_paths:
files_to_scan = [
file_path
for file_path in files_to_scan
if str(Path(file_path).resolve()) not in hf_shard_blob_paths
]
scan_entries: list[_ScanEntry] = [(file_path, [file_path], None) for file_path in files_to_scan]
seen_complete_hf_shard_families: set[tuple[str, tuple[str, ...]]] = set()
for shard_family_key, representative_file in shard_family_representatives.items():
ordered_family_paths = sorted(shard_family_paths.get(shard_family_key, {representative_file}))
expected_total_shards = shard_family_key[2]
if (
is_hf_cache
and expected_total_shards is not None
and shard_family_key in complete_hf_shard_families
and len(ordered_family_paths) == expected_total_shards
):
resolved_family_paths = tuple(sorted(str(Path(path).resolve()) for path in ordered_family_paths))
family_dedupe_key = (shard_family_key[1], resolved_family_paths)
if family_dedupe_key in seen_complete_hf_shard_families:
continue
Comment thread
mldangelo-oai marked this conversation as resolved.
seen_complete_hf_shard_families.add(family_dedupe_key)
scan_entries.append((representative_file, ordered_family_paths, shard_family_key))

# Second pass: scan every non-shard path independently and every shard
Expand Down
57 changes: 52 additions & 5 deletions modelaudit/utils/file/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,23 +141,37 @@ def detect_shards(
shard_info: dict[str, Any] = {"pattern": pattern, "current_file": file_path, "shards": []}
expected_totals: set[int] = set()
present_indices: set[int] = set()
unreadable_shards: list[str] = []
total_size = 0

# Find all related shards
for file in dir_path.glob("*"):
file_match = re.fullmatch(pattern, file.name)
if file_match:
if allowed_path_set is not None and str(file.resolve()) not in allowed_path_set:
continue
try:
shard_size = os.path.getsize(file)
except OSError:
unreadable_shards.append(str(file))
continue
Comment thread
mldangelo-oai marked this conversation as resolved.
shard_info["shards"].append(str(file))
total_size += shard_size
if file_match.lastindex:
with suppress(IndexError, ValueError):
present_indices.add(int(file_match.group(1)))
if (file_match.lastindex or 0) >= 2:
with suppress(IndexError, ValueError):
expected_totals.add(int(file_match.group(2)))

if not shard_info["shards"]:
return None

shard_info["shards"].sort()
shard_info["total_shards"] = len(shard_info["shards"])
if unreadable_shards:
shard_info["unreadable_shards"] = sorted(unreadable_shards)
shard_info["unreadable_shard_count"] = len(unreadable_shards)
if expected_totals:
expected_total = max(expected_totals)
shard_info["expected_total_shards"] = expected_total
Expand All @@ -173,8 +187,6 @@ def detect_shards(
shard_info["missing_shard_indices"] = missing_indices
shard_info["missing_shard_indices_truncated"] = missing_indices_truncated

# Calculate total size
total_size = sum(os.path.getsize(s) for s in shard_info["shards"])
shard_info["total_size"] = total_size

return shard_info
Expand Down Expand Up @@ -340,16 +352,23 @@ def _analyze_window(self, data: bytes, offset: int) -> "ScanResult":
class ParallelShardHandler:
"""Scan multiple model shards in parallel."""

def __init__(self, shard_info: dict[str, Any], scanner_class: type):
def __init__(
self,
shard_info: dict[str, Any],
scanner_class: type,
scanner_config: dict[str, Any] | None = None,
):
"""
Initialize parallel shard scanner.

Args:
shard_info: Information about model shards
scanner_class: Scanner class to use
scanner_config: Configuration to preserve for each shard scanner
"""
self.shard_info = shard_info
self.scanner_class = scanner_class
self.scanner_config = scanner_config

def scan_shards(self, progress_callback: Callable[[str, float], None] | None = None) -> "ScanResult":
from ...scanner_results import IssueSeverity, ScanResult
Expand Down Expand Up @@ -425,7 +444,11 @@ def _scan_single_shard(self, shard_path: str) -> "ScanResult":
from ...scanner_results import ScanResult

"""Scan a single shard file."""
scanner = self.scanner_class()
scanner = (
self.scanner_class(config=dict(self.scanner_config))
if self.scanner_config is not None
else self.scanner_class()
)
result: ScanResult = scanner.scan(shard_path)
return result

Expand Down Expand Up @@ -560,11 +583,17 @@ def _scan_sharded_model(self) -> "ScanResult":
# Scan shards in parallel
shard_scan_success = True
if self.shard_info:
parallel_scanner = ParallelShardHandler(self.shard_info, self.scanner.__class__)
scanner_config = getattr(self.scanner, "config", None)
parallel_scanner = ParallelShardHandler(
self.shard_info,
self.scanner.__class__,
scanner_config=dict(scanner_config) if isinstance(scanner_config, dict) else None,
)
shard_results = parallel_scanner.scan_shards(self.progress_callback)
shard_scan_success = bool(shard_results.success)
result.merge(shard_results)
missing_count = self.shard_info.get("missing_shard_count")
unreadable_count = self.shard_info.get("unreadable_shard_count")
if isinstance(missing_count, int) and missing_count > 0:
_mark_inconclusive_scan_outcome(result, "missing_model_shards")
result.add_check(
Expand All @@ -580,11 +609,29 @@ def _scan_sharded_model(self) -> "ScanResult":
"missing_shard_indices_truncated": self.shard_info.get(
"missing_shard_indices_truncated", False
),
"unreadable_shard_count": self.shard_info.get("unreadable_shard_count", 0),
"unreadable_shards": self.shard_info.get("unreadable_shards", []),
"analysis_incomplete": True,
"scan_outcome": "inconclusive",
"scan_outcome_reason": "missing_model_shards",
},
)
elif isinstance(unreadable_count, int) and unreadable_count > 0:
_mark_inconclusive_scan_outcome(result, "unreadable_model_shards")
result.add_check(
name="Sharded Model Coverage Check",
passed=False,
message=f"Unable to read {unreadable_count} model shard(s); scan coverage is incomplete.",
severity=IssueSeverity.INFO,
details={
"present_total_shards": self.shard_info.get("total_shards"),
"unreadable_shard_count": unreadable_count,
"unreadable_shards": self.shard_info.get("unreadable_shards", []),
"analysis_incomplete": True,
"scan_outcome": "inconclusive",
"scan_outcome_reason": "unreadable_model_shards",
},
)

result.finish(success=shard_scan_success and not result.has_errors and "scan_outcome" not in result.metadata)
return result
Expand Down
Loading
Loading