diff --git a/CHANGELOG.md b/CHANGELOG.md index 3e039c638..0bebeb4f3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - avoid repeatedly scanning sharded model families during directory scans - keep shard sibling discovery within the requested scan root - preserve per-shard metadata when aggregating sharded model families +- prevent picklescan call-graph alias cycles from hanging scans - preserve HuggingFace snapshot shard paths while grouping cache-backed families - stop flagging a false-positive ONNX Python operator when tensor weight bytes coincidentally spell `PyOp` - distinguish ASCII-serialized Torch7 artifacts from plain PyTorch source text diff --git a/packages/modelaudit-picklescan/CHANGELOG.md b/packages/modelaudit-picklescan/CHANGELOG.md index dffbebf01..fb67942d3 100644 --- a/packages/modelaudit-picklescan/CHANGELOG.md +++ b/packages/modelaudit-picklescan/CHANGELOG.md @@ -51,6 +51,7 @@ and this package adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Bug Fixes +- prevent call-graph alias cycles from hanging scans - detect nested brace-format lookups that reach tracked `defaultdict` factories - avoid `str.format` false positives when a `ChainMap` shadows a `defaultdict` - block `statistics.quantiles` call-iterator consumption in call-graph analysis diff --git a/packages/modelaudit-picklescan/src/modelaudit_picklescan/api.py b/packages/modelaudit-picklescan/src/modelaudit_picklescan/api.py index 9a8362b46..a619be21e 100644 --- a/packages/modelaudit-picklescan/src/modelaudit_picklescan/api.py +++ b/packages/modelaudit-picklescan/src/modelaudit_picklescan/api.py @@ -14,6 +14,7 @@ CallGraphFinding, StartupHookWriteFinding, UnanalyzedCallGraphReference, + _CallGraphAnalysisLimitError, find_dangerous_call_graphs, find_startup_hook_write_call_graphs, find_unanalyzed_callable_call_graph_references, @@ -1012,6 +1013,9 @@ def _with_call_graph_findings(report: PickleReport) -> PickleReport: with shared_source_sensitive_caches(): try: call_graph_findings = find_dangerous_call_graphs(import_references, callable_invocations) + except _CallGraphAnalysisLimitError as error: + call_graph_findings = error.partial_findings + enrichment_errors.append(("python_call_graph", error)) except Exception as error: call_graph_findings = () enrichment_errors.append(("python_call_graph", error)) @@ -1020,6 +1024,9 @@ def _with_call_graph_findings(report: PickleReport) -> PickleReport: import_references, callable_invocations, ) + except _CallGraphAnalysisLimitError as error: + startup_hook_write_findings = error.partial_startup_hook_write_findings + enrichment_errors.append(("python_call_graph_startup_hook_write", error)) except Exception as error: startup_hook_write_findings = () enrichment_errors.append(("python_call_graph_startup_hook_write", error)) diff --git a/packages/modelaudit-picklescan/src/modelaudit_picklescan/call_graph.py b/packages/modelaudit-picklescan/src/modelaudit_picklescan/call_graph.py index 629bb203f..aedab5e44 100644 --- a/packages/modelaudit-picklescan/src/modelaudit_picklescan/call_graph.py +++ b/packages/modelaudit-picklescan/src/modelaudit_picklescan/call_graph.py @@ -30,6 +30,7 @@ _MAX_VISITED_FUNCTIONS = 64 _MAX_CALLS_PER_FUNCTION = 128 _MAX_ASSIGNMENT_ALIASES = 128 +_MAX_ASSIGNMENT_ALIAS_PASSES = 256 _MAX_FUNCTION_INSTANCE_ALIASES = 32 _MAX_CLASS_INSTANCE_ALIASES = 128 _MAX_INHERITED_CLASS_METHODS = 128 @@ -66,6 +67,23 @@ def cache_clear(self) -> None: _SOURCE_SENSITIVE_CACHED_FUNCTIONS: set[_CacheClearable] = set() +class _CallGraphAnalysisLimitError(RuntimeError): + """Raised when bounded call-graph enrichment cannot complete safely.""" + + def __init__( + self, + message: str, + *, + partial_findings: tuple[CallGraphFinding, ...] = (), + partial_startup_hook_write_findings: tuple[StartupHookWriteFinding, ...] = (), + partial_path: tuple[str, ...] | None = None, + ) -> None: + super().__init__(message) + self.partial_findings = partial_findings + self.partial_startup_hook_write_findings = partial_startup_hook_write_findings + self.partial_path = partial_path + + def _register_source_sensitive_cache(function: _CachedFunctionT) -> _CachedFunctionT: _SOURCE_SENSITIVE_CACHED_FUNCTIONS.add(cast(_CacheClearable, function)) return function @@ -320,26 +338,42 @@ def find_dangerous_call_graphs( if str(reference.get("module", "")) and str(reference.get("name", "")) } + analysis_limit_error: _CallGraphAnalysisLimitError | None = None for reference in _iter_call_graph_references(import_references, callable_references, invoked_references): module = str(reference.get("module", "")) name = str(reference.get("name", "")) if not module or not name: continue - entrypoints = _call_graph_entrypoints_for_reference(module, name, reference) + try: + entrypoints = _call_graph_entrypoints_for_reference(module, name, reference) + except _CallGraphAnalysisLimitError as error: + if analysis_limit_error is None: + analysis_limit_error = error + continue if not entrypoints: continue allow_invoked_non_lifecycle_entrypoint = _is_explicit_method_import_reference(name) - sink_path = _first_matching_path(entrypoints, _find_sink_path) + try: + sink_path = _first_matching_path(entrypoints, _find_sink_path) + except _CallGraphAnalysisLimitError as error: + if analysis_limit_error is None: + analysis_limit_error = error + sink_path = error.partial_path if sink_path is None: for positional_arg_count in positional_arg_counts.get((module, name), ()): - sink_path = _first_matching_path( - entrypoints, - _invoked_import_execution_path_callback( - positional_arg_count, - allow_non_lifecycle_entrypoint=allow_invoked_non_lifecycle_entrypoint, - ), - ) + try: + sink_path = _first_matching_path( + entrypoints, + _invoked_import_execution_path_callback( + positional_arg_count, + allow_non_lifecycle_entrypoint=allow_invoked_non_lifecycle_entrypoint, + ), + ) + except _CallGraphAnalysisLimitError as error: + if analysis_limit_error is None: + analysis_limit_error = error + sink_path = error.partial_path if sink_path is not None: break if sink_path is None: @@ -362,6 +396,11 @@ def find_dangerous_call_graphs( ) if len(findings) >= _MAX_IMPORT_REFERENCES: break + if analysis_limit_error is not None: + raise _CallGraphAnalysisLimitError( + str(analysis_limit_error), + partial_findings=tuple(findings), + ) from analysis_limit_error return tuple(findings) @@ -380,6 +419,7 @@ def find_startup_hook_write_call_graphs( for reference in _iter_callable_invocation_references(callable_invocations) } require_invocations = callable_invocations is not None and callable_invocations_complete + analysis_limit_error: _CallGraphAnalysisLimitError | None = None for reference in _iter_import_references(import_references): module = str(reference.get("module", "")) name = str(reference.get("name", "")) @@ -389,12 +429,28 @@ def find_startup_hook_write_call_graphs( continue seen.add((module, name)) - entrypoints = _safe_call_graph_entrypoints(f"{module}.{name}") + try: + entrypoints = _safe_call_graph_entrypoints(f"{module}.{name}") + except _CallGraphAnalysisLimitError as error: + if analysis_limit_error is None: + analysis_limit_error = error + continue if not entrypoints: continue - if _first_matching_path(entrypoints, _find_sink_path) is not None: + try: + has_sink_path = _first_matching_path(entrypoints, _find_sink_path) is not None + except _CallGraphAnalysisLimitError as error: + if analysis_limit_error is None: + analysis_limit_error = error + has_sink_path = error.partial_path is not None + if has_sink_path: continue - open_path = _first_matching_path(entrypoints, _find_file_open_path) + try: + open_path = _first_matching_path(entrypoints, _find_file_open_path) + except _CallGraphAnalysisLimitError as error: + if analysis_limit_error is None: + analysis_limit_error = error + open_path = error.partial_path if open_path is not None: openers.append( _ImportCallPath( @@ -404,7 +460,12 @@ def find_startup_hook_write_call_graphs( call_path=open_path, ) ) - write_path = _first_matching_path(entrypoints, _find_file_write_path) + try: + write_path = _first_matching_path(entrypoints, _find_file_write_path) + except _CallGraphAnalysisLimitError as error: + if analysis_limit_error is None: + analysis_limit_error = error + write_path = error.partial_path if write_path is not None: writers.append( _ImportCallPath( @@ -415,6 +476,19 @@ def find_startup_hook_write_call_graphs( ) ) + findings = _materialize_startup_hook_write_findings(openers, writers) + if analysis_limit_error is not None: + raise _CallGraphAnalysisLimitError( + str(analysis_limit_error), + partial_startup_hook_write_findings=findings, + ) from analysis_limit_error + return findings + + +def _materialize_startup_hook_write_findings( + openers: list[_ImportCallPath], + writers: list[_ImportCallPath], +) -> tuple[StartupHookWriteFinding, ...]: if not openers or not writers: return () @@ -497,6 +571,8 @@ def shared_source_sensitive_caches() -> Iterator[None]: def _safe_call_graph_entrypoints(function_name: str) -> tuple[str, ...]: try: return _call_graph_entrypoints(function_name) + except _CallGraphAnalysisLimitError: + raise except Exception: return () @@ -505,13 +581,24 @@ def _first_matching_path( entrypoints: Iterable[str], path_for: Callable[[str], tuple[str, ...] | None], ) -> tuple[str, ...] | None: + analysis_limit_error: _CallGraphAnalysisLimitError | None = None for entrypoint in entrypoints: try: path = path_for(entrypoint) + except _CallGraphAnalysisLimitError as error: + if analysis_limit_error is None: + analysis_limit_error = error + continue except Exception: continue if path is not None: + if analysis_limit_error is not None: + raise _CallGraphAnalysisLimitError( + str(analysis_limit_error), partial_path=path + ) from analysis_limit_error return path + if analysis_limit_error is not None: + raise analysis_limit_error return None @@ -727,7 +814,10 @@ def _is_skippable_torch_extension_global_reference(module: str, name: str) -> bo source_path = _resolve_module_source(module) if source_path is not None and not _is_library_source_path(str(source_path)): return False - return not _has_static_torch_extension_global_target(module, name) + try: + return not _has_static_torch_extension_global_target(module, name) + except _CallGraphAnalysisLimitError: + return False @_register_source_sensitive_cache @@ -1653,6 +1743,307 @@ def _local_class_node_from_target( return local_class_nodes.get(class_name) +def _contains_current_loop_break(nodes: Iterable[ast.stmt]) -> bool: + """Return whether this loop body can break without entering a nested scope or loop.""" + + def contains_break(node: ast.AST) -> bool: + if isinstance(node, ast.Break): + return True + if isinstance( + node, + ast.For | ast.AsyncFor | ast.While | ast.FunctionDef | ast.AsyncFunctionDef | ast.Lambda | ast.ClassDef, + ): + return False + return any(contains_break(child) for child in ast.iter_child_nodes(node)) + + return any(contains_break(node) for node in nodes) + + +def _is_exhaustive_match(node: ast.Match) -> bool: + return any( + isinstance(case.pattern, ast.MatchAs) and case.pattern.pattern is None and case.guard is None + for case in node.cases + ) + + +_TerminalAssignment = ast.Assign | ast.AnnAssign +_TerminalAssignmentGroup = tuple[str, tuple[_TerminalAssignment | None, ...]] + + +def _can_complete_normally(branch_body: Iterable[ast.stmt]) -> bool: + statements = tuple(branch_body) + return not statements or not isinstance(statements[-1], ast.Raise | ast.Return) + + +def _assignment_alias_targets(branch_body: Iterable[ast.stmt]) -> set[str]: + targets: set[str] = set() + for statement in _definition_scope_statements(branch_body): + targets.update(_assignment_alias_target_names(statement)) + return targets + + +def _terminal_assignment_groups( + branch_bodies: tuple[tuple[ast.stmt, ...], ...], +) -> tuple[_TerminalAssignmentGroup, ...]: + terminal_assignments: list[dict[str, _TerminalAssignment]] = [] + for branch_body in branch_bodies: + suffix_assignments: dict[str, _TerminalAssignment] = {} + for statement in reversed(branch_body): + if isinstance(statement, ast.Expr | ast.Pass): + continue + if not isinstance(statement, ast.Assign | ast.AnnAssign) or statement.value is None: + break + for target_name in _assignment_alias_target_names(statement): + suffix_assignments.setdefault(target_name, statement) + terminal_assignments.append(suffix_assignments) + if not terminal_assignments: + return () + terminal_targets = set().union(*(set(assignments) for assignments in terminal_assignments)) + return tuple( + (target_name, tuple(assignments.get(target_name) for assignments in terminal_assignments)) + for target_name in sorted(terminal_targets) + ) + + +def _conditionally_rebound_assignment_nodes( + nodes: Iterable[ast.AST], +) -> tuple[dict[str, set[int]], tuple[tuple[_TerminalAssignmentGroup, ...], ...]]: + """Return alternate-path assignment nodes grouped by ambiguously rebound name.""" + node_list = tuple(nodes) + ambiguous_assignment_nodes: dict[str, set[int]] = {} + terminal_assignment_group_sets: list[tuple[_TerminalAssignmentGroup, ...]] = [] + for node in node_list: + alternate_bodies: tuple[Iterable[ast.stmt], ...] + branch_bodies: tuple[Iterable[ast.stmt], ...] + terminating_bodies: tuple[Iterable[ast.stmt], ...] + deterministic_terminal_bodies: tuple[Iterable[ast.stmt], ...] | None = None + if isinstance(node, ast.If): + alternate_bodies = (node.body, node.orelse) + terminating_bodies = tuple( + branch_body for branch_body in alternate_bodies if not _can_complete_normally(branch_body) + ) + branch_bodies = tuple( + branch_body for branch_body in alternate_bodies if _can_complete_normally(branch_body) + ) + continuing_targets = set().union(*(_assignment_alias_targets(branch_body) for branch_body in branch_bodies)) + terminating_targets = set().union( + *(_assignment_alias_targets(branch_body) for branch_body in terminating_bodies) + ) + if continuing_targets & terminating_targets: + branch_bodies = alternate_bodies + if len(branch_bodies) < 2: + continue + deterministic_terminal_bodies = branch_bodies + elif isinstance(node, ast.Try) and node.handlers: + alternate_bodies = ( + (*node.body, *node.orelse), + *(handler.body for handler in node.handlers), + ) + terminating_bodies = tuple( + branch_body for branch_body in alternate_bodies if not _can_complete_normally(branch_body) + ) + branch_bodies = tuple( + branch_body for branch_body in alternate_bodies if _can_complete_normally(branch_body) + ) + continuing_targets = set().union(*(_assignment_alias_targets(branch_body) for branch_body in branch_bodies)) + terminating_targets = set().union( + *(_assignment_alias_targets(branch_body) for branch_body in terminating_bodies) + ) + if continuing_targets & terminating_targets: + branch_bodies = alternate_bodies + if len(branch_bodies) < 2: + continue + deterministic_terminal_bodies = branch_bodies + elif isinstance(node, ast.For | ast.AsyncFor | ast.While) and _contains_current_loop_break(node.body): + branch_bodies = (node.body, node.orelse) + if node.body and isinstance(node.body[-1], ast.Break) and not _contains_current_loop_break(node.body[:-1]): + deterministic_terminal_bodies = (node.body[:-1], node.orelse) + elif isinstance(node, ast.Match): + branch_bodies = tuple(case.body for case in node.cases) + if _is_exhaustive_match(node): + deterministic_terminal_bodies = branch_bodies + else: + continue + + branch_statement_bodies = tuple(tuple(branch_body) for branch_body in branch_bodies) + for branch_body in branch_statement_bodies: + for statement in _definition_scope_statements(branch_body): + for target_name in _assignment_alias_target_names(statement): + ambiguous_assignment_nodes.setdefault(target_name, set()).add(id(statement)) + + if deterministic_terminal_bodies is not None: + groups = _terminal_assignment_groups( + tuple(tuple(branch_body) for branch_body in deterministic_terminal_bodies) + ) + if groups: + terminal_assignment_group_sets.append(groups) + return ambiguous_assignment_nodes, tuple(terminal_assignment_group_sets) + + +def _resolved_terminal_assignment_nodes( + nodes: tuple[ast.AST, ...], + terminal_assignment_group_sets: tuple[tuple[_TerminalAssignmentGroup, ...], ...], + ambiguous_assignment_nodes: dict[str, set[int]], + module_name: str, + aliases: dict[str, str], + local_defs: set[str], + local_class_targets: set[str], + *, + class_name: str | None, +) -> dict[str, set[int]]: + deterministic_node_ids: dict[str, set[int]] = {} + + def incoming_alias_value( + target_name: str, + statements: tuple[_TerminalAssignment | None, ...], + effective_ambiguous_node_ids: dict[str, set[int]], + ) -> tuple[bool, str | None]: + statement_ids = {id(statement) for statement in statements if statement is not None} + first_index = min(index for index, node in enumerate(nodes) if id(node) in statement_ids) + found_prior_assignment = False + incoming_value: str | None = None + for node in nodes[:first_index]: + if target_name not in _assignment_alias_target_names(node): + continue + if id(node) in effective_ambiguous_node_ids.get(target_name, set()): + continue + found_prior_assignment = True + incoming_value = _assignment_alias_value( + node, + module_name, + aliases, + local_defs, + local_class_targets, + class_name=class_name, + ) + return found_prior_assignment, incoming_value + + while True: + effective_ambiguous_node_ids = { + target_name: node_ids - deterministic_node_ids.get(target_name, set()) + for target_name, node_ids in ambiguous_assignment_nodes.items() + if node_ids - deterministic_node_ids.get(target_name, set()) + } + active_ambiguous_targets: set[str] = set() + ambiguous_before_statement: dict[int, set[str]] = {} + for node in nodes: + target_names = _assignment_alias_target_names(node) + if not target_names: + continue + ambiguous_before_statement[id(node)] = set(active_ambiguous_targets) + for target_name in target_names: + if id(node) in effective_ambiguous_node_ids.get(target_name, set()): + active_ambiguous_targets.add(target_name) + else: + active_ambiguous_targets.discard(target_name) + + changed = False + for terminal_assignment_groups in terminal_assignment_group_sets: + branch_count = len(terminal_assignment_groups[0][1]) + resolved_by_branch: list[dict[str, str]] = [] + for branch_index in range(branch_count): + branch_statements: dict[int, _TerminalAssignment] = {} + for _, statements in terminal_assignment_groups: + statement = statements[branch_index] + if statement is not None: + branch_statements[id(statement)] = statement + branch_aliases = dict(aliases) + branch_local_targets: set[str] = set() + branch_resolved: dict[str, str] = {} + for statement in sorted( + branch_statements.values(), + key=lambda statement: (statement.lineno, statement.col_offset), + ): + blocked_dependencies = ( + _assignment_value_read_names(statement) + & ambiguous_before_statement.get(id(statement), set()) - branch_local_targets + ) + if blocked_dependencies: + continue + resolved = _assignment_alias_value( + statement, + module_name, + branch_aliases, + local_defs, + local_class_targets, + class_name=class_name, + ) + if resolved is None: + continue + for target_name in _assignment_alias_target_names(statement): + branch_aliases[target_name] = resolved + branch_local_targets.add(target_name) + branch_resolved[target_name] = resolved + resolved_by_branch.append(branch_resolved) + + for target_name, statements in terminal_assignment_groups: + present_values = tuple( + resolved_by_branch[index].get(target_name) + for index, statement in enumerate(statements) + if statement is not None + ) + if None in present_values: + continue + resolved_values = list(present_values) + if any(statement is None for statement in statements): + present_statements = tuple(statement for statement in statements if statement is not None) + if any( + target_name in ambiguous_before_statement.get(id(statement), set()) + for statement in present_statements + ): + continue + found_incoming, incoming_value = incoming_alias_value( + target_name, + statements, + effective_ambiguous_node_ids, + ) + if found_incoming: + if incoming_value is None: + continue + resolved_values.append(incoming_value) + if resolved_values and len(set(resolved_values)) == 1: + node_ids = deterministic_node_ids.setdefault(target_name, set()) + prior_count = len(node_ids) + node_ids.update(id(statement) for statement in statements if statement is not None) + changed = changed or len(node_ids) != prior_count + if not changed: + break + return deterministic_node_ids + + +def _propagated_ambiguous_assignment_nodes( + nodes: tuple[ast.AST, ...], + ambiguous_assignment_nodes: dict[str, set[int]], + deterministic_node_ids: dict[str, set[int]], + *, + propagate_reads: bool, +) -> tuple[dict[str, set[int]], dict[str, set[int]]]: + effective_ambiguous_node_ids = { + target_name: node_ids - deterministic_node_ids.get(target_name, set()) + for target_name, node_ids in ambiguous_assignment_nodes.items() + if node_ids - deterministic_node_ids.get(target_name, set()) + } + propagated_assignment_nodes: dict[str, set[int]] = {} + if not propagate_reads: + return effective_ambiguous_node_ids, propagated_assignment_nodes + active_ambiguous_targets: set[str] = set() + for node in nodes: + target_names = _assignment_alias_target_names(node) + if not target_names: + continue + reads_ambiguous_target = bool(_assignment_alias_read_names(node) & active_ambiguous_targets) + for target_name in target_names: + conditional_node_ids = effective_ambiguous_node_ids.get(target_name, set()) + if id(node) in conditional_node_ids or reads_ambiguous_target: + if reads_ambiguous_target: + effective_ambiguous_node_ids.setdefault(target_name, set()).add(id(node)) + propagated_assignment_nodes.setdefault(target_name, set()).add(id(node)) + active_ambiguous_targets.add(target_name) + else: + active_ambiguous_targets.discard(target_name) + return effective_ambiguous_node_ids, propagated_assignment_nodes + + def _collect_assignment_aliases( nodes: Iterable[ast.AST], module_name: str, @@ -1664,10 +2055,22 @@ def _collect_assignment_aliases( ) -> dict[str, str]: node_list = tuple(nodes) assignment_aliases: dict[str, str] = {} + source_path = _resolve_module_source(module_name) + conditionally_rebound_node_ids, terminal_assignment_group_sets = _conditionally_rebound_assignment_nodes(node_list) + seen_states: set[tuple[tuple[str, str], ...]] = {()} + passes = 0 changed = True while changed and len(assignment_aliases) < _MAX_ASSIGNMENT_ALIASES: + if passes >= _MAX_ASSIGNMENT_ALIAS_PASSES: + raise _CallGraphAnalysisLimitError( + f"assignment alias analysis exceeded {_MAX_ASSIGNMENT_ALIAS_PASSES} propagation passes" + ) + passes += 1 + state = tuple(sorted(assignment_aliases.items())) changed = False + last_changed_node_ids: dict[str, int] = {} + last_resolved_node_ids: dict[str, int] = {} scoped_aliases = {**aliases, **assignment_aliases} for node in node_list: resolved = _assignment_alias_value( @@ -1681,12 +2084,50 @@ def _collect_assignment_aliases( if resolved is None: continue for target_name in _assignment_alias_target_names(node): + last_resolved_node_ids[target_name] = id(node) if assignment_aliases.get(target_name) == resolved: continue assignment_aliases[target_name] = resolved changed = True + last_changed_node_ids[target_name] = id(node) if len(assignment_aliases) >= _MAX_ASSIGNMENT_ALIASES: break + next_state = tuple(sorted(assignment_aliases.items())) + if next_state == state: + deterministic_node_ids = _resolved_terminal_assignment_nodes( + node_list, + terminal_assignment_group_sets, + conditionally_rebound_node_ids, + module_name, + {**aliases, **assignment_aliases}, + local_defs, + local_class_targets, + class_name=class_name, + ) + effective_conditionally_rebound_node_ids, propagated_rebound_node_ids = ( + _propagated_ambiguous_assignment_nodes( + node_list, + conditionally_rebound_node_ids, + deterministic_node_ids, + propagate_reads=source_path is None or not _is_stdlib_source_path(str(source_path)), + ) + ) + changed_conditionally = any( + node_id in effective_conditionally_rebound_node_ids.get(target_name, set()) + for target_name, node_id in last_changed_node_ids.items() + ) + resolved_from_conditional_read = any( + node_id in propagated_rebound_node_ids.get(target_name, set()) + for target_name, node_id in last_resolved_node_ids.items() + ) + if changed_conditionally or resolved_from_conditional_read: + raise _CallGraphAnalysisLimitError( + "assignment alias analysis encountered ambiguous conditional rebinding" + ) + break + if next_state in seen_states: + raise _CallGraphAnalysisLimitError("assignment alias analysis entered a propagation cycle") + seen_states.add(next_state) return assignment_aliases @@ -1755,6 +2196,73 @@ def _assignment_alias_target_names(node: ast.AST) -> set[str]: return set() +def _assignment_alias_read_names(node: ast.AST) -> set[str]: + if not isinstance(node, ast.Assign | ast.AnnAssign) or node.value is None: + return set() + value = node.value + if isinstance(value, ast.Call): + value = value.func + while isinstance(value, ast.Attribute): + value = value.value + if isinstance(value, ast.Name): + return {value.id} + return set() + + +def _assignment_value_read_names(node: ast.Assign | ast.AnnAssign) -> set[str]: + if node.value is None: + return set() + + names: set[str] = set() + + def binding_names(target: ast.AST) -> set[str]: + return {child.id for child in ast.walk(target) if isinstance(child, ast.Name)} + + def visit(value: ast.AST, bound_names: set[str]) -> None: + if isinstance(value, ast.Name): + if isinstance(value.ctx, ast.Load) and value.id not in bound_names: + names.add(value.id) + return + if isinstance(value, ast.Lambda): + for default in ( + *value.args.defaults, + *(default for default in value.args.kw_defaults if default is not None), + ): + visit(default, bound_names) + lambda_bound_names = { + argument.arg + for argument in ( + *value.args.posonlyargs, + *value.args.args, + *value.args.kwonlyargs, + ) + } + if value.args.vararg is not None: + lambda_bound_names.add(value.args.vararg.arg) + if value.args.kwarg is not None: + lambda_bound_names.add(value.args.kwarg.arg) + visit(value.body, bound_names | lambda_bound_names) + return + if isinstance(value, ast.ListComp | ast.SetComp | ast.GeneratorExp | ast.DictComp): + comprehension_bound_names = set(bound_names) + for generator in value.generators: + visit(generator.iter, comprehension_bound_names) + comprehension_bound_names.update(binding_names(generator.target)) + for condition in generator.ifs: + visit(condition, comprehension_bound_names) + if isinstance(value, ast.DictComp): + visit(value.key, comprehension_bound_names) + visit(value.value, comprehension_bound_names) + else: + visit(value.elt, comprehension_bound_names) + return + for child in ast.iter_child_nodes(value): + visit(child, bound_names) + + visit(node.value, set()) + return names + + def _is_local_class_member_alias(resolved: str, local_class_targets: set[str]) -> bool: return any( resolved == class_target or resolved.startswith(f"{class_target}.") for class_target in local_class_targets diff --git a/packages/modelaudit-picklescan/tests/test_api.py b/packages/modelaudit-picklescan/tests/test_api.py index 9dac5378b..4324fb260 100644 --- a/packages/modelaudit-picklescan/tests/test_api.py +++ b/packages/modelaudit-picklescan/tests/test_api.py @@ -36,7 +36,12 @@ scan_bytes, scan_file, ) -from modelaudit_picklescan.call_graph import find_startup_hook_write_call_graphs +from modelaudit_picklescan.call_graph import ( + CallGraphFinding, + StartupHookWriteFinding, + _CallGraphAnalysisLimitError, + find_startup_hook_write_call_graphs, +) def _expected_system_global() -> str: @@ -1371,8 +1376,9 @@ def test_scan_file_detects_hidden_pytorch_zip_pickle_member_without_data_pickle( def test_scan_file_leaves_hidden_pickle_like_zip_without_pytorch_metadata_unrecognized(tmp_path: Path) -> None: archive_path = tmp_path / "hidden-only.zip" + entry = zipfile.ZipInfo("archive/payload", (1980, 1, 1, 0, 0, 0)) with zipfile.ZipFile(archive_path, "w") as archive: - archive.writestr("archive/payload", pickle.dumps({"weights": [1, 2, 3]}, protocol=4)) + archive.writestr(entry, pickle.dumps({"weights": [1, 2, 3]}, protocol=4)) report = scan_file(archive_path) @@ -3513,6 +3519,48 @@ def raise_call_graph_error(*_args: object, **_kwargs: object) -> tuple[()]: ) +def test_with_call_graph_findings_preserves_critical_findings_before_limit_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: + partial_finding = CallGraphFinding( + module="click", + name="edit", + import_reference="click.edit", + sink="os.system", + call_path=("click.edit", "os.system"), + ) + + def raise_call_graph_limit(*_args: object, **_kwargs: object) -> tuple[CallGraphFinding, ...]: + raise _CallGraphAnalysisLimitError( + "assignment alias analysis entered a propagation cycle", + partial_findings=(partial_finding,), + ) + + monkeypatch.setattr(package_api, "find_dangerous_call_graphs", raise_call_graph_limit) + report = PickleReport( + source="partial-call-graph-findings.pkl", + status=ScanStatus.COMPLETE, + verdict=SafetyVerdict.CLEAN, + metadata={"import_references": ()}, + ) + + updated = package_api._with_call_graph_findings(report) + + assert updated.status == ScanStatus.INCONCLUSIVE + assert updated.verdict == SafetyVerdict.MALICIOUS + assert updated.metadata["analysis_incomplete"] is True + call_graph_findings = [finding for finding in updated.findings if finding.rule_code == "DANGEROUS_CALL_GRAPH"] + assert len(call_graph_findings) == 1 + assert call_graph_findings[0].details["module"] == "click" + assert call_graph_findings[0].details["name"] == "edit" + assert any( + error.category == "call_graph_analysis_error" + and error.exception_type == "_CallGraphAnalysisLimitError" + and error.details["analysis_incomplete"] is True + for error in updated.errors + ) + + def test_with_call_graph_findings_marks_startup_hook_enrichment_failures_incomplete( monkeypatch: pytest.MonkeyPatch, ) -> None: @@ -3548,6 +3596,55 @@ def raise_startup_hook_error(*_args: object, **_kwargs: object) -> tuple[()]: ) +def test_with_call_graph_findings_preserves_startup_hook_findings_before_limit_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: + partial_finding = StartupHookWriteFinding( + opener_module="click", + opener_name="open_file", + writer_module="click", + writer_name="echo", + opener_import_reference="click.open_file", + writer_import_reference="click.echo", + open_sink="builtins.open", + write_sink="binary_file.write", + opener_call_path=("click.open_file", "builtins.open"), + writer_call_path=("click.echo", "binary_file.write"), + ) + + def raise_startup_hook_limit(*_args: object, **_kwargs: object) -> tuple[StartupHookWriteFinding, ...]: + raise _CallGraphAnalysisLimitError( + "assignment alias analysis entered a propagation cycle", + partial_startup_hook_write_findings=(partial_finding,), + ) + + monkeypatch.setattr(package_api, "find_startup_hook_write_call_graphs", raise_startup_hook_limit) + report = PickleReport( + source="partial-startup-hook-findings.pkl", + status=ScanStatus.COMPLETE, + verdict=SafetyVerdict.CLEAN, + metadata={"import_references": ()}, + ) + + updated = package_api._with_call_graph_findings(report) + + assert updated.status == ScanStatus.INCONCLUSIVE + assert updated.verdict == SafetyVerdict.MALICIOUS + assert updated.metadata["analysis_incomplete"] is True + startup_findings = [ + finding for finding in updated.findings if finding.rule_code == "DANGEROUS_CALL_GRAPH_FILE_WRITE" + ] + assert len(startup_findings) == 1 + assert startup_findings[0].details["opener_name"] == "open_file" + assert startup_findings[0].details["name"] == "echo" + assert any( + error.category == "call_graph_analysis_error" + and error.exception_type == "_CallGraphAnalysisLimitError" + and error.details["analysis"] == "python_call_graph_startup_hook_write" + for error in updated.errors + ) + + def test_with_call_graph_findings_marks_source_unavailable_enrichment_failures_incomplete( monkeypatch: pytest.MonkeyPatch, ) -> None: diff --git a/packages/modelaudit-picklescan/tests/test_call_graph_assignment_alias_cycle.py b/packages/modelaudit-picklescan/tests/test_call_graph_assignment_alias_cycle.py new file mode 100644 index 000000000..40b1a5857 --- /dev/null +++ b/packages/modelaudit-picklescan/tests/test_call_graph_assignment_alias_cycle.py @@ -0,0 +1,1179 @@ +"""Regression: assignment-alias fixpoint must terminate on oscillating binds. + +A module that binds the same name in both branches of an ``if``/``else`` +(e.g. the ``imaplib``/``http.server``/``nntplib`` stdlib ``__main__`` blocks) +makes ``_collect_assignment_aliases`` oscillate the alias between two values +without ever growing the dict. The dict-size guard never trips, so the +fixpoint loop spins forever and the whole scan hangs (GitHub issue #1247). +""" + +from __future__ import annotations + +import ast +import threading +from collections.abc import Iterable +from importlib.util import find_spec + +import pytest + +from modelaudit_picklescan import PickleReport, Severity, scan_bytes +from modelaudit_picklescan.api import _RUST_EXTENSION_MODULE +from modelaudit_picklescan.call_graph import ( + CallGraphFinding, + StartupHookWriteFinding, + _CallGraphAnalysisLimitError, + _collect_assignment_aliases, + _collect_local_defs, + _first_matching_path, + _module_level_statements, + _safe_call_graph_entrypoints, + find_dangerous_call_graphs, + find_startup_hook_write_call_graphs, +) + +_OSCILLATING_MODULE_SOURCE = """\ +class A: + pass + + +class B: + pass + + +if cond: + m = A() +else: + m = B() +""" + +_DEPENDENT_CYCLE_SOURCE = """\ +class A: + pass + + +class B: + pass + + +a = A() +a = b +b = B() +b = a +c = a +""" + +_SEQUENTIAL_REBIND_SOURCE = """\ +class A: + pass + + +class B: + pass + + +m = A() +m = B() +""" + +_TRY_REBIND_SOURCE = """\ +class A: + pass + + +class B: + pass + + +try: + m = A() + operation() +except Exception: + m = B() +""" + +_LOOP_ELSE_REBIND_SOURCE = """\ +class A: + pass + + +class B: + pass + + +for value in values: + m = A() + break +else: + m = B() +""" + +_LOOP_ELSE_COMPLETION_SOURCE = """\ +class A: + pass + + +class B: + pass + + +for value in values: + m = A() +else: + m = B() +""" + +_LOOP_MATCHING_TERMINAL_REBIND_SOURCE = """\ +class A: + pass + + +class Final: + pass + + +m = A() +for value in values: + m = Final() + break +else: + m = Final() +exposed = m +""" + +_LOOP_EARLY_BREAK_BEFORE_TERMINAL_REBIND_SOURCE = """\ +class A: + pass + + +class Final: + pass + + +m = A() +for value in values: + if cond: + break + m = Final() + break +else: + m = Final() +exposed = m +""" + +_UNCONDITIONAL_OVERWRITE_SOURCE = """\ +class A: + pass + + +class B: + pass + + +class Final: + pass + + +if cond: + m = A() +else: + m = B() +m = Final() +""" + +_READ_BEFORE_UNCONDITIONAL_OVERWRITE_SOURCE = """\ +class A: + pass + + +class B: + pass + + +class Final: + pass + + +if cond: + m = A() +else: + m = B() +exposed = m +m = Final() +""" + +_CALL_READ_BEFORE_UNCONDITIONAL_OVERWRITE_SOURCE = """\ +class A: + pass + + +class B: + pass + + +class Final: + pass + + +if cond: + m = A() +else: + m = B() +exposed = m() +m = Final() +""" + +_TERMINATING_ALIAS_BRANCH_SOURCE = """\ +class A: + pass + + +class B: + pass + + +if cond: + m = A() + consumed = m + raise ImportError +else: + m = B() +exposed = m +""" + +_OVERWRITE_BEFORE_READ_SOURCE = """\ +class A: + pass + + +class B: + pass + + +class Final: + pass + + +if cond: + m = A() +else: + m = B() +m = Final() +exposed = m +""" + +_ONE_SIDED_REBIND_SOURCE = """\ +class A: + pass + + +class B: + pass + + +m = A() +if cond: + m = B() +exposed = m +""" + +_NONEXHAUSTIVE_MATCH_OVERWRITE_SOURCE = """\ +class A: + pass + + +class B: + pass + + +m = A() +match value: + case 1: + m = B() + case 2: + m = B() +exposed = m +""" + +_CONDITIONALLY_REBOUND_TERMINAL_DEPENDENCY_SOURCE = """\ +class A: + pass + + +class B: + pass + + +class Final: + pass + + +if cond: + F = A +else: + F = B +if cond: + m = F() +else: + m = F() +F = Final +exposed = m +""" + +_BRANCH_READ_BEFORE_MATCHING_TERMINAL_OVERWRITE_SOURCE = """\ +class A: + pass + + +class B: + pass + + +class Final: + pass + + +if cond: + m = A() + exposed = m + m = Final() +else: + m = B() + exposed = m + m = Final() +""" + +_MATCHING_BRANCH_OVERWRITE_SOURCE = """\ +class A: + pass + + +class B: + pass + + +class Final: + pass + + +if cond: + m = A() + m = Final() +else: + m = B() + m = Final() +exposed = m +""" + +_MATCHING_BRANCH_EPILOGUE_SOURCE = """\ +class A: + pass + + +class B: + pass + + +class Final: + pass + + +if cond: + m = A() + m = Final() + log() +else: + m = B() + m = Final() + log() +exposed = m +""" + +_RESOLVED_MATCHING_BRANCH_OVERWRITE_SOURCE = """\ +class Final: + pass + + +F = Final +if cond: + m = F() +else: + m = Final() +exposed = m +""" + +_SAME_LINE_RESOLVED_MATCHING_BRANCH_OVERWRITE_SOURCE = """\ +class Final: + pass + + +if cond: + F = Final; m = F() +else: + F = Final; m = F() +exposed = m +""" + +_OVERWRITTEN_TERMINAL_DEPENDENCY_SOURCE = """\ +class A: + pass + + +class B: + pass + + +class Final: + pass + + +if cond: + F = A +else: + F = B +F = Final +if cond: + m = F() +else: + m = F() +exposed = m +""" + +_MUTUALLY_DEPENDENT_TERMINAL_ALIAS_SOURCE = """\ +class Final: + pass + + +if cond: + m = Final() + exposed = m +else: + exposed = Final() + m = exposed +""" + +_TERMINATING_ALTERNATIVE_SOURCE = """\ +class Final: + pass + + +if cond: + m = Final() +else: + raise ImportError +exposed = m +""" + +_SCOPED_TERMINAL_DEPENDENCY_SOURCE = """\ +class Final: + pass + + +if cond: + F = first_value +else: + F = second_value +if cond: + m = Final([F for F in values]) +else: + m = Final(lambda F: F) +exposed = m +""" + +_MATCHING_ONE_SIDED_REBIND_SOURCE = """\ +class Final: + pass + + +m = Final() +if cond: + m = Final() +exposed = m +""" + +_UNBOUND_ONE_SIDED_REBIND_SOURCE = """\ +class Final: + pass + + +if cond: + m = Final() +exposed = m +""" + +_TERMINATING_UNRELATED_ALIAS_SOURCE = """\ +class A: + pass + + +class Final: + pass + + +if cond: + unused = A() + raise ImportError +else: + m = Final() +exposed = m +""" + +_MATCHING_TRY_EXCEPT_OVERWRITE_SOURCE = """\ +class A: + pass + + +class B: + pass + + +class Final: + pass + + +try: + m = A() + m = Final() +except Exception: + m = B() + m = Final() +exposed = m +""" + +_EXHAUSTIVE_MATCH_OVERWRITE_SOURCE = """\ +class A: + pass + + +class B: + pass + + +class Final: + pass + + +match value: + case 1: + m = A() + m = Final() + exposed = m + case _: + m = B() + m = Final() + exposed = m +""" + +_TRY_FINALLY_NO_HANDLER_SOURCE = """\ +class Final: + pass + + +try: + m = Final() +finally: + cleanup() +exposed = m +""" + +_TRY_FINALLY_REBIND_SOURCE = """\ +class A: + pass + + +class B: + pass + + +class Final: + pass + + +try: + m = A() +except Exception: + m = B() +finally: + m = Final() +""" + + +def _run_with_timeout(target: object, timeout: float = 10.0) -> None: + thread = threading.Thread(target=target) # type: ignore[arg-type] + thread.daemon = True + thread.start() + thread.join(timeout) + if thread.is_alive(): + pytest.fail(f"call-graph analysis did not terminate within {timeout}s") + + +@pytest.mark.parametrize( + "source", + ( + _OSCILLATING_MODULE_SOURCE, + _TRY_REBIND_SOURCE, + _LOOP_ELSE_REBIND_SOURCE, + _LOOP_EARLY_BREAK_BEFORE_TERMINAL_REBIND_SOURCE, + _READ_BEFORE_UNCONDITIONAL_OVERWRITE_SOURCE, + _CALL_READ_BEFORE_UNCONDITIONAL_OVERWRITE_SOURCE, + _TERMINATING_ALIAS_BRANCH_SOURCE, + _ONE_SIDED_REBIND_SOURCE, + _NONEXHAUSTIVE_MATCH_OVERWRITE_SOURCE, + _CONDITIONALLY_REBOUND_TERMINAL_DEPENDENCY_SOURCE, + _BRANCH_READ_BEFORE_MATCHING_TERMINAL_OVERWRITE_SOURCE, + ), + ids=( + "if-else", + "try-except", + "loop-else", + "loop-early-break", + "read-before-overwrite", + "call-read-before-overwrite", + "terminating-alias-branch", + "one-sided-rebind", + "match-no-default", + "conditional-terminal-dependency", + "branch-read-before-terminal-overwrite", + ), +) +def test_collect_assignment_aliases_fails_closed_on_stable_branch_rebind(source: str) -> None: + tree = ast.parse(source) + statements = _module_level_statements(tree) + local_defs = _collect_local_defs(statements) + local_class_targets = {"testmod.A", "testmod.B", "testmod.Final"} + + result: dict[str, bool] = {} + + def _collect() -> None: + with pytest.raises(_CallGraphAnalysisLimitError, match="ambiguous conditional rebinding"): + _collect_assignment_aliases( + statements, + "testmod", + {}, + local_defs, + local_class_targets, + ) + result["limited"] = True + + _run_with_timeout(_collect) + + assert result == {"limited": True} + + +@pytest.mark.parametrize( + ("source", "expected_target"), + ( + (_SEQUENTIAL_REBIND_SOURCE, "testmod.B"), + (_LOOP_ELSE_COMPLETION_SOURCE, "testmod.B"), + (_UNCONDITIONAL_OVERWRITE_SOURCE, "testmod.Final"), + (_TRY_FINALLY_REBIND_SOURCE, "testmod.Final"), + ), + ids=("sequential", "loop-else-completion", "unconditional-overwrite", "try-finally-overwrite"), +) +def test_collect_assignment_aliases_converges_on_deterministic_final_rebind( + source: str, + expected_target: str, +) -> None: + tree = ast.parse(source) + statements = _module_level_statements(tree) + local_defs = _collect_local_defs(statements) + local_class_targets = {"testmod.A", "testmod.B", "testmod.Final"} + + aliases = _collect_assignment_aliases( + statements, + "testmod", + {}, + local_defs, + local_class_targets, + ) + + assert aliases["m"] == expected_target + + +def test_collect_assignment_aliases_allows_alias_read_after_deterministic_overwrite() -> None: + tree = ast.parse(_OVERWRITE_BEFORE_READ_SOURCE) + statements = _module_level_statements(tree) + local_defs = _collect_local_defs(statements) + + aliases = _collect_assignment_aliases( + statements, + "testmod", + {}, + local_defs, + {"testmod.A", "testmod.B", "testmod.Final"}, + ) + + assert aliases["m"] == "testmod.Final" + assert aliases["exposed"] == "testmod.Final" + + +@pytest.mark.parametrize( + "source", + ( + _MATCHING_BRANCH_OVERWRITE_SOURCE, + _MATCHING_BRANCH_EPILOGUE_SOURCE, + _RESOLVED_MATCHING_BRANCH_OVERWRITE_SOURCE, + _SAME_LINE_RESOLVED_MATCHING_BRANCH_OVERWRITE_SOURCE, + _OVERWRITTEN_TERMINAL_DEPENDENCY_SOURCE, + _MUTUALLY_DEPENDENT_TERMINAL_ALIAS_SOURCE, + _TERMINATING_ALTERNATIVE_SOURCE, + _SCOPED_TERMINAL_DEPENDENCY_SOURCE, + _MATCHING_ONE_SIDED_REBIND_SOURCE, + _UNBOUND_ONE_SIDED_REBIND_SOURCE, + _TERMINATING_UNRELATED_ALIAS_SOURCE, + _MATCHING_TRY_EXCEPT_OVERWRITE_SOURCE, + _EXHAUSTIVE_MATCH_OVERWRITE_SOURCE, + _LOOP_MATCHING_TERMINAL_REBIND_SOURCE, + _TRY_FINALLY_NO_HANDLER_SOURCE, + ), + ids=( + "if-else", + "if-expression-epilogue", + "if-semantic-resolution", + "if-same-line-semantic-resolution", + "overwritten-terminal-dependency", + "mutually-dependent-terminal-alias", + "terminating-alternative", + "scoped-terminal-dependency", + "matching-one-sided-rebind", + "unbound-one-sided-rebind", + "terminating-unrelated-alias", + "try-except", + "exhaustive-match", + "loop-terminal-break", + "try-finally-no-handler", + ), +) +def test_collect_assignment_aliases_allows_matching_terminal_branch_overwrites(source: str) -> None: + tree = ast.parse(source) + statements = _module_level_statements(tree) + local_defs = _collect_local_defs(statements) + + aliases = _collect_assignment_aliases( + statements, + "testmod", + {}, + local_defs, + {"testmod.A", "testmod.B", "testmod.Final"}, + ) + + assert aliases["m"] == "testmod.Final" + assert aliases["exposed"] == "testmod.Final" + + +def test_collect_assignment_aliases_fails_closed_on_cyclic_dependency_propagation() -> None: + tree = ast.parse(_DEPENDENT_CYCLE_SOURCE) + statements = _module_level_statements(tree) + local_defs = _collect_local_defs(statements) + local_class_targets = {"testmod.A", "testmod.B"} + + result: dict[str, bool] = {} + + def _collect() -> None: + with pytest.raises(_CallGraphAnalysisLimitError, match="entered a propagation cycle"): + _collect_assignment_aliases( + statements, + "testmod", + {}, + local_defs, + local_class_targets, + ) + result["limited"] = True + + _run_with_timeout(_collect) + + assert result == {"limited": True} + + +def test_collect_assignment_aliases_fails_closed_on_long_period_cycles() -> None: + periods = (7, 11, 13, 17, 19) + source_lines: list[str] = [] + local_class_targets: set[str] = set() + for ring_index, period in enumerate(periods): + for position in range(period): + class_name = f"Ring{ring_index}Class{position}" + variable_name = f"ring_{ring_index}_{position}" + source_lines.extend((f"class {class_name}:", " pass", "", f"{variable_name} = {class_name}()", "")) + local_class_targets.add(f"testmod.{class_name}") + for position in range(period): + next_position = (position + 1) % period + source_lines.append(f"ring_{ring_index}_{position} = ring_{ring_index}_{next_position}") + + tree = ast.parse("\n".join(source_lines)) + statements = _module_level_statements(tree) + local_defs = _collect_local_defs(statements) + result: dict[str, bool] = {} + + def _collect() -> None: + with pytest.raises(_CallGraphAnalysisLimitError): + _collect_assignment_aliases( + statements, + "testmod", + {}, + local_defs, + local_class_targets, + ) + result["limited"] = True + + _run_with_timeout(_collect) + + assert result == {"limited": True} + + +def test_assignment_alias_limit_is_not_hidden_by_safe_entrypoint_wrapper(monkeypatch: pytest.MonkeyPatch) -> None: + def _raise_limit(_function_name: str) -> tuple[str, ...]: + raise _CallGraphAnalysisLimitError("assignment alias limit") + + monkeypatch.setattr("modelaudit_picklescan.call_graph._call_graph_entrypoints", _raise_limit) + _safe_call_graph_entrypoints.cache_clear() + + with pytest.raises(_CallGraphAnalysisLimitError, match="assignment alias limit"): + _safe_call_graph_entrypoints("long_period.module") + + +def test_assignment_alias_limit_is_not_hidden_by_path_search() -> None: + def _raise_limit(_entrypoint: str) -> tuple[str, ...] | None: + raise _CallGraphAnalysisLimitError("assignment alias limit") + + with pytest.raises(_CallGraphAnalysisLimitError, match="assignment alias limit"): + _first_matching_path(("wrapper.entrypoint",), _raise_limit) + + +def test_assignment_alias_limit_retains_later_matching_entrypoint_path() -> None: + def _path(entrypoint: str) -> tuple[str, ...] | None: + if entrypoint == "constructor.__new__": + raise _CallGraphAnalysisLimitError("assignment alias limit") + return ("constructor.__init__", "builtins.exec") + + with pytest.raises(_CallGraphAnalysisLimitError, match="assignment alias limit") as exc_info: + _first_matching_path(("constructor.__new__", "constructor.__init__"), _path) + + assert exc_info.value.partial_path == ("constructor.__init__", "builtins.exec") + + +def test_assignment_alias_limit_preserves_prior_call_graph_findings( + monkeypatch: pytest.MonkeyPatch, +) -> None: + references = ( + {"module": "dangerous", "name": "entry"}, + {"module": "limited", "name": "entry"}, + ) + + def _iter_references( + _import_references: object, + _callable_references: tuple[dict[str, object], ...], + _invoked_references: set[tuple[str, str]], + ) -> tuple[dict[str, str], ...]: + return references + + def _entrypoints(module: str, name: str, _reference: dict[str, object]) -> tuple[str, ...]: + return (f"{module}.{name}",) + + def _path(entrypoints: Iterable[str], _path_for: object) -> tuple[str, ...] | None: + if tuple(entrypoints) == ("limited.entry",): + raise _CallGraphAnalysisLimitError("assignment alias limit") + return ("dangerous.entry", "builtins.exec") + + monkeypatch.setattr("modelaudit_picklescan.call_graph._iter_call_graph_references", _iter_references) + monkeypatch.setattr("modelaudit_picklescan.call_graph._call_graph_entrypoints_for_reference", _entrypoints) + monkeypatch.setattr("modelaudit_picklescan.call_graph._first_matching_path", _path) + + with pytest.raises(_CallGraphAnalysisLimitError, match="assignment alias limit") as exc_info: + find_dangerous_call_graphs(()) + + assert exc_info.value.partial_findings == ( + CallGraphFinding( + module="dangerous", + name="entry", + import_reference="dangerous.entry", + sink="builtins.exec", + call_path=("dangerous.entry", "builtins.exec"), + ), + ) + + +def test_assignment_alias_limit_preserves_later_call_graph_findings( + monkeypatch: pytest.MonkeyPatch, +) -> None: + references = ( + {"module": "limited", "name": "entry"}, + {"module": "dangerous", "name": "entry"}, + ) + + def _iter_references( + _import_references: object, + _callable_references: tuple[dict[str, object], ...], + _invoked_references: set[tuple[str, str]], + ) -> tuple[dict[str, str], ...]: + return references + + def _entrypoints(module: str, name: str, _reference: dict[str, object]) -> tuple[str, ...]: + return (f"{module}.{name}",) + + def _path(entrypoints: Iterable[str], _path_for: object) -> tuple[str, ...] | None: + if tuple(entrypoints) == ("limited.entry",): + raise _CallGraphAnalysisLimitError("assignment alias limit") + return ("dangerous.entry", "builtins.exec") + + monkeypatch.setattr("modelaudit_picklescan.call_graph._iter_call_graph_references", _iter_references) + monkeypatch.setattr("modelaudit_picklescan.call_graph._call_graph_entrypoints_for_reference", _entrypoints) + monkeypatch.setattr("modelaudit_picklescan.call_graph._first_matching_path", _path) + + with pytest.raises(_CallGraphAnalysisLimitError, match="assignment alias limit") as exc_info: + find_dangerous_call_graphs(()) + + assert exc_info.value.partial_findings == ( + CallGraphFinding( + module="dangerous", + name="entry", + import_reference="dangerous.entry", + sink="builtins.exec", + call_path=("dangerous.entry", "builtins.exec"), + ), + ) + + +def test_assignment_alias_limit_preserves_invoked_finding_after_sink_limit( + monkeypatch: pytest.MonkeyPatch, +) -> None: + reference = {"module": "invoked", "name": "entry"} + path_calls = 0 + + def _iter_references( + _import_references: object, + _callable_references: tuple[dict[str, object], ...], + _invoked_references: set[tuple[str, str]], + ) -> tuple[dict[str, str], ...]: + return (reference,) + + def _entrypoints(module: str, name: str, _reference: dict[str, object]) -> tuple[str, ...]: + return (f"{module}.{name}",) + + def _path(_entrypoints: Iterable[str], _path_for: object) -> tuple[str, ...] | None: + nonlocal path_calls + path_calls += 1 + if path_calls == 1: + raise _CallGraphAnalysisLimitError("assignment alias limit") + return ("invoked.entry", "builtins.exec") + + monkeypatch.setattr("modelaudit_picklescan.call_graph._iter_call_graph_references", _iter_references) + monkeypatch.setattr("modelaudit_picklescan.call_graph._call_graph_entrypoints_for_reference", _entrypoints) + monkeypatch.setattr("modelaudit_picklescan.call_graph._first_matching_path", _path) + + with pytest.raises(_CallGraphAnalysisLimitError, match="assignment alias limit") as exc_info: + find_dangerous_call_graphs( + (), + ({"module": "invoked", "name": "entry", "positional_arg_count": 1},), + ) + + assert exc_info.value.partial_findings == ( + CallGraphFinding( + module="invoked", + name="entry", + import_reference="invoked.entry", + sink="builtins.exec", + call_path=("invoked.entry", "builtins.exec"), + ), + ) + + +def test_assignment_alias_limit_preserves_later_entrypoint_finding( + monkeypatch: pytest.MonkeyPatch, +) -> None: + reference = {"module": "constructor", "name": "Type"} + + def _iter_references( + _import_references: object, + _callable_references: tuple[dict[str, object], ...], + _invoked_references: set[tuple[str, str]], + ) -> tuple[dict[str, str], ...]: + return (reference,) + + def _entrypoints(_module: str, _name: str, _reference: dict[str, object]) -> tuple[str, ...]: + return ("constructor.Type.__new__", "constructor.Type.__init__") + + def _path(entrypoint: str) -> tuple[str, ...] | None: + if entrypoint.endswith(".__new__"): + raise _CallGraphAnalysisLimitError("assignment alias limit") + return ("constructor.Type.__init__", "builtins.exec") + + monkeypatch.setattr("modelaudit_picklescan.call_graph._iter_call_graph_references", _iter_references) + monkeypatch.setattr("modelaudit_picklescan.call_graph._call_graph_entrypoints_for_reference", _entrypoints) + monkeypatch.setattr("modelaudit_picklescan.call_graph._find_sink_path", _path) + + with pytest.raises(_CallGraphAnalysisLimitError, match="assignment alias limit") as exc_info: + find_dangerous_call_graphs(()) + + assert exc_info.value.partial_findings == ( + CallGraphFinding( + module="constructor", + name="Type", + import_reference="constructor.Type", + sink="builtins.exec", + call_path=("constructor.Type.__init__", "builtins.exec"), + ), + ) + + +def test_assignment_alias_limit_preserves_prior_startup_hook_findings( + monkeypatch: pytest.MonkeyPatch, +) -> None: + references = ( + {"module": "opener", "name": "entry"}, + {"module": "writer", "name": "entry"}, + {"module": "limited", "name": "entry"}, + ) + + def _entrypoints(function_name: str) -> tuple[str, ...]: + if function_name == "limited.entry": + raise _CallGraphAnalysisLimitError("assignment alias limit") + return (function_name,) + + def _path(entrypoints: Iterable[str], path_for: object) -> tuple[str, ...] | None: + entrypoint = next(iter(entrypoints)) + path_name = getattr(path_for, "__name__", "") + if path_name == "_find_file_open_path" and entrypoint == "opener.entry": + return ("opener.entry", "builtins.open") + if path_name == "_find_file_write_path" and entrypoint == "writer.entry": + return ("writer.entry", "binary_file.write") + return None + + monkeypatch.setattr("modelaudit_picklescan.call_graph._safe_call_graph_entrypoints", _entrypoints) + monkeypatch.setattr("modelaudit_picklescan.call_graph._first_matching_path", _path) + + with pytest.raises(_CallGraphAnalysisLimitError, match="assignment alias limit") as exc_info: + find_startup_hook_write_call_graphs(references) + + assert exc_info.value.partial_startup_hook_write_findings == ( + StartupHookWriteFinding( + opener_module="opener", + opener_name="entry", + writer_module="writer", + writer_name="entry", + opener_import_reference="opener.entry", + writer_import_reference="writer.entry", + open_sink="builtins.open", + write_sink="binary_file.write", + opener_call_path=("opener.entry", "builtins.open"), + writer_call_path=("writer.entry", "binary_file.write"), + ), + ) + + +def test_assignment_alias_limit_preserves_later_startup_hook_findings( + monkeypatch: pytest.MonkeyPatch, +) -> None: + references = ( + {"module": "limited", "name": "entry"}, + {"module": "opener", "name": "entry"}, + {"module": "writer", "name": "entry"}, + ) + + def _entrypoints(function_name: str) -> tuple[str, ...]: + if function_name == "limited.entry": + raise _CallGraphAnalysisLimitError("assignment alias limit") + return (function_name,) + + def _path(entrypoints: Iterable[str], path_for: object) -> tuple[str, ...] | None: + entrypoint = next(iter(entrypoints)) + path_name = getattr(path_for, "__name__", "") + if path_name == "_find_file_open_path" and entrypoint == "opener.entry": + return ("opener.entry", "builtins.open") + if path_name == "_find_file_write_path" and entrypoint == "writer.entry": + return ("writer.entry", "binary_file.write") + return None + + monkeypatch.setattr("modelaudit_picklescan.call_graph._safe_call_graph_entrypoints", _entrypoints) + monkeypatch.setattr("modelaudit_picklescan.call_graph._first_matching_path", _path) + + with pytest.raises(_CallGraphAnalysisLimitError, match="assignment alias limit") as exc_info: + find_startup_hook_write_call_graphs(references) + + assert exc_info.value.partial_startup_hook_write_findings == ( + StartupHookWriteFinding( + opener_module="opener", + opener_name="entry", + writer_module="writer", + writer_name="entry", + opener_import_reference="opener.entry", + writer_import_reference="writer.entry", + open_sink="builtins.open", + write_sink="binary_file.write", + opener_call_path=("opener.entry", "builtins.open"), + writer_call_path=("writer.entry", "binary_file.write"), + ), + ) + + +def test_assignment_alias_limit_preserves_startup_hook_findings_after_sink_limit( + monkeypatch: pytest.MonkeyPatch, +) -> None: + references = ( + {"module": "opener", "name": "entry"}, + {"module": "writer", "name": "entry"}, + ) + + def _entrypoints(function_name: str) -> tuple[str, ...]: + return (function_name,) + + def _path(entrypoints: Iterable[str], path_for: object) -> tuple[str, ...] | None: + entrypoint = next(iter(entrypoints)) + path_name = getattr(path_for, "__name__", "") + if path_name == "_find_sink_path": + raise _CallGraphAnalysisLimitError("assignment alias limit") + if path_name == "_find_file_open_path" and entrypoint == "opener.entry": + return ("opener.entry", "builtins.open") + if path_name == "_find_file_write_path" and entrypoint == "writer.entry": + return ("writer.entry", "binary_file.write") + return None + + monkeypatch.setattr("modelaudit_picklescan.call_graph._safe_call_graph_entrypoints", _entrypoints) + monkeypatch.setattr("modelaudit_picklescan.call_graph._first_matching_path", _path) + + with pytest.raises(_CallGraphAnalysisLimitError, match="assignment alias limit") as exc_info: + find_startup_hook_write_call_graphs(references) + + assert exc_info.value.partial_startup_hook_write_findings == ( + StartupHookWriteFinding( + opener_module="opener", + opener_name="entry", + writer_module="writer", + writer_name="entry", + opener_import_reference="opener.entry", + writer_import_reference="writer.entry", + open_sink="builtins.open", + write_sink="binary_file.write", + opener_call_path=("opener.entry", "builtins.open"), + writer_call_path=("writer.entry", "binary_file.write"), + ), + ) + + +def _stack_global_payload(module: str, name: str) -> bytes: + def _operand(value: str) -> bytes: + data = value.encode() + return b"\x8c" + bytes([len(data)]) + data + + return b"\x80\x04" + _operand(module) + _operand(name) + b"\x93" + _operand("proof_of_bypass") + b"\x85R." + + +@pytest.mark.skipif( + find_spec(_RUST_EXTENSION_MODULE) is None, + reason="Rust picklescan extension is not built", +) +@pytest.mark.skipif( + find_spec("imaplib") is None, + reason="imaplib stdlib module is unavailable", +) +def test_scan_imaplib_reference_terminates_and_flags() -> None: + """The issue #1247 proof-of-concept must finish and stay flagged.""" + payload = _stack_global_payload("imaplib", "test") + result: dict[str, PickleReport] = {} + + def _scan() -> None: + result["report"] = scan_bytes(payload) + + _run_with_timeout(_scan) + + report = result["report"] + severities = {finding.severity for finding in report.findings} + assert Severity.CRITICAL in severities diff --git a/packages/modelaudit-picklescan/tests/test_call_graph_import_statements.py b/packages/modelaudit-picklescan/tests/test_call_graph_import_statements.py index ea9f87ace..c950cae46 100644 --- a/packages/modelaudit-picklescan/tests/test_call_graph_import_statements.py +++ b/packages/modelaudit-picklescan/tests/test_call_graph_import_statements.py @@ -2102,6 +2102,111 @@ def test_scan_bytes_analyzes_shadowed_torch_extension_callable_invocation( assert _has_critical_call_graph_finding(report, "torch", "device", "os.system") +def test_scan_bytes_fails_closed_when_shadowed_torch_extension_analysis_hits_alias_limit( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + module_dir = tmp_path / "modules" + module_dir.mkdir() + (module_dir / "torch.py").write_text( + """\ +import os + + +class Dangerous: + def __call__(self, command): + os.system(command) + + +class Safe: + def __call__(self, command): + return command + + +if cond: + device = Dangerous() +else: + device = Safe() +""", + encoding="utf-8", + ) + monkeypatch.syspath_prepend(str(module_dir)) + importlib.invalidate_caches() + _clear_call_graph_caches() + + try: + report = scan_bytes( + _global_call_payload("torch", "device", _unicode_operand("echo shadowed-torch-device")), + source="shadowed-torch-device-alias-limit.pkl", + ) + finally: + _clear_call_graph_caches() + + assert report.status == ScanStatus.INCONCLUSIVE + assert report.metadata["analysis_incomplete"] is True + assert not any(error.category == "rust_engine_error" for error in report.errors) + assert any( + error.category == "call_graph_analysis_error" and error.exception_type == "_CallGraphAnalysisLimitError" + for error in report.errors + ) + + +def test_scan_bytes_fails_closed_on_installed_package_alias_read_before_overwrite( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + module_dir = tmp_path / "site-packages" + module_dir.mkdir() + module_name = "installed_alias_read_before_overwrite" + (module_dir / f"{module_name}.py").write_text( + """\ +import os + + +class Dangerous: + def __call__(self, command): + os.system(command) + + +class Safe: + def __call__(self, command): + return command + + +class Final: + def __call__(self, command): + return command + + +if cond: + entry = Dangerous() +else: + entry = Safe() +exposed = entry +entry = Final() +""", + encoding="utf-8", + ) + monkeypatch.syspath_prepend(str(module_dir)) + importlib.invalidate_caches() + _clear_call_graph_caches() + + try: + report = scan_bytes( + _global_call_payload(module_name, "exposed", _unicode_operand("echo hidden-branch")), + source="installed-alias-read-before-overwrite.pkl", + ) + finally: + _clear_call_graph_caches() + + assert report.status == ScanStatus.INCONCLUSIVE + assert report.metadata["analysis_incomplete"] is True + assert any( + error.category == "call_graph_analysis_error" and error.exception_type == "_CallGraphAnalysisLimitError" + for error in report.errors + ) + + def test_call_graph_analyzes_shadowed_torch_storage_persistent_id_reference( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, diff --git a/tests/conftest.py b/tests/conftest.py index 5392fa51d..805d9abfa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -192,6 +192,7 @@ def pytest_runtest_setup(item): "test_call_graph_local_imports.py", # standalone picklescan function-local import RCE regressions "test_call_graph_six.py", # standalone picklescan six.moves alias RCE regressions "test_call_graph_tkinter.py", # standalone picklescan Tcl call-graph RCE regressions + "test_call_graph_assignment_alias_cycle.py", # standalone picklescan alias fixpoint termination regressions "test_dill_joblib_enhanced.py", # Dill/joblib pickle routing regression tests "test_pickle_context_filtering.py", # Pickle context filtering regression tests "test_xdist_status.py", # xdist worker progress reporting tests