diff --git a/.github/instructions/style-guide.instructions.md b/.github/instructions/style-guide.instructions.md index f5f245f9f..42fdab331 100644 --- a/.github/instructions/style-guide.instructions.md +++ b/.github/instructions/style-guide.instructions.md @@ -293,6 +293,26 @@ def process_items(self, *, items: list[str]) -> list[str]: return [] ``` +## Deprecations + +When deprecating a public class, function, method, parameter, or module path, use `pyrit.common.deprecation.print_deprecation_message` — never `warnings.warn` directly. It wraps `warnings.warn(..., DeprecationWarning, stacklevel=3)` with a consistent format so filtering still works. + +Set `removed_in` to **current version + 2 minor versions** (e.g. `0.14.x` → `removed_in="0.16.0"`). This gives one full release cycle of warning before removal. + +```python +from pyrit.common.deprecation import print_deprecation_message + +def old_method(self, *, foo: str) -> None: + print_deprecation_message( + old_item="MyClass.old_method", + new_item="MyClass.new_method", + removed_in="0.16.0", + ) + ... +``` + +`old_item` / `new_item` accept a class/callable (qualified name is generated) or a string. 
+ ## Pythonic Patterns ### List Comprehensions diff --git a/pyrit/backend/services/scenario_run_service.py b/pyrit/backend/services/scenario_run_service.py index 37f0ff1b7..6b04f2a71 100644 --- a/pyrit/backend/services/scenario_run_service.py +++ b/pyrit/backend/services/scenario_run_service.py @@ -409,9 +409,13 @@ def _build_response_from_db(self, *, scenario_result: ScenarioResult) -> Scenari error = scenario_result.error_message error_type = scenario_result.error_type - # Fallback: look up error from persisted error AttackResults - if not error and scenario_result.error_attack_result_ids: - error_ars = self._memory.get_attack_results(attack_result_ids=scenario_result.error_attack_result_ids) + # Fallback: look up error from any persisted error AttackResults linked + # to this scenario via the new attribution_parent_id foreign key. + if not error: + error_ars = self._memory.get_attack_results( + scenario_result_id=scenario_result_id, + outcome=AttackOutcome.ERROR, + ) if error_ars: error = error_ars[0].error_message error_type = error_ars[0].error_type diff --git a/pyrit/executor/attack/core/attack_executor.py b/pyrit/executor/attack/core/attack_executor.py index fea9315e0..6405b6816 100644 --- a/pyrit/executor/attack/core/attack_executor.py +++ b/pyrit/executor/attack/core/attack_executor.py @@ -19,6 +19,7 @@ ) from pyrit.executor.attack.core.attack_parameters import AttackParameters +from pyrit.executor.attack.core.attack_result_attribution import AttackResultAttribution from pyrit.executor.attack.core.attack_strategy import ( AttackStrategy, AttackStrategyContextT, @@ -142,6 +143,7 @@ async def execute_attack_from_seed_groups_async( objective_scorer: Optional["TrueFalseScorer"] = None, field_overrides: Optional[Sequence[dict[str, Any]]] = None, return_partial_on_failure: bool = False, + attribution: Optional[AttackResultAttribution] = None, **broadcast_fields: Any, ) -> AttackExecutorResult[AttackStrategyResultT]: """ @@ -163,6 +165,12 @@ async def 
execute_attack_from_seed_groups_async( from_seed_group() as overrides. return_partial_on_failure: If True, returns partial results when some objectives fail. If False (default), raises the first exception. + attribution: Optional ``AttackResultAttribution`` stamped onto every + per-task ``AttackContext`` so the persisted ``AttackResultEntry`` + row carries ``attribution_parent_id`` + ``attribution_data``. + When ``None`` (default), no attribution is applied. The same + attribution is shared across all tasks; per-task identity is + reconstructed from the row's own ``objective_sha256``. **broadcast_fields: Fields applied to all seed groups (e.g., memory_labels). Per-seed-group field_overrides take precedence. @@ -205,6 +213,7 @@ async def build_params(i: int, sg: SeedAttackGroup) -> AttackParameters: attack=attack, params_list=params_list, return_partial_on_failure=return_partial_on_failure, + attribution=attribution, ) async def execute_attack_async( @@ -214,6 +223,7 @@ async def execute_attack_async( objectives: Sequence[str], field_overrides: Optional[Sequence[dict[str, Any]]] = None, return_partial_on_failure: bool = False, + attribution: Optional[AttackResultAttribution] = None, **broadcast_fields: Any, ) -> AttackExecutorResult[AttackStrategyResultT]: """ @@ -228,6 +238,9 @@ async def execute_attack_async( must match the length of objectives. return_partial_on_failure: If True, returns partial results when some objectives fail. If False (default), raises the first exception. + attribution: Optional ``AttackResultAttribution`` stamped onto every + per-task ``AttackContext`` so the persistence path can record + orchestrator linkage. When ``None``, no attribution is applied. **broadcast_fields: Fields applied to all objectives (e.g., memory_labels). Per-objective field_overrides take precedence. 
@@ -268,6 +281,7 @@ async def execute_attack_async( attack=attack, params_list=params_list, return_partial_on_failure=return_partial_on_failure, + attribution=attribution, ) async def _execute_with_params_list_async( @@ -276,6 +290,7 @@ async def _execute_with_params_list_async( attack: AttackStrategy[AttackStrategyContextT, AttackStrategyResultT], params_list: Sequence[AttackParameters], return_partial_on_failure: bool = False, + attribution: Optional[AttackResultAttribution] = None, ) -> AttackExecutorResult[AttackStrategyResultT]: """ Execute attacks in parallel with a list of pre-built parameters. @@ -287,19 +302,23 @@ async def _execute_with_params_list_async( attack: The attack strategy to execute. params_list: List of AttackParameters, one per execution. return_partial_on_failure: If True, returns partial results on failure. + attribution: Optional ``AttackResultAttribution`` stamped onto every + per-task ``AttackContext`` so the persistence path can record + orchestrator linkage. Returns: AttackExecutorResult with completed results and any incomplete objectives.
""" semaphore = asyncio.Semaphore(self._max_concurrency) async def run_one(params: AttackParameters) -> AttackStrategyResultT: async with semaphore: - # Create context with params context = attack._context_type(params=params) + if attribution is not None: + context._attribution = attribution return await attack.execute_with_context_async(context=context) tasks = [run_one(p) for p in params_list] results_or_exceptions = await asyncio.gather(*tasks, return_exceptions=True) return self._process_execution_results( diff --git a/pyrit/executor/attack/core/attack_result_attribution.py b/pyrit/executor/attack/core/attack_result_attribution.py new file mode 100644 index 000000000..2953f7160 --- /dev/null +++ b/pyrit/executor/attack/core/attack_result_attribution.py @@ -0,0 +1,51 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Generic attribution metadata an orchestrator stamps onto an +``AttackContext`` so the persisted ``AttackResult`` carries linkage back to +whatever produced it. +""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class AttackResultAttribution: + """ + Attribution copied onto an ``AttackResult`` by the persistence path so the + DB row records its lineage. + + All fields are opaque to the attack layer; the orchestrator chooses what + they mean. Together, ``(parent_collection, parent_eval_hash)`` form the + "what was running" key the orchestrator can use to scope per-parent + work on resume. Two ``AtomicAttack`` instances sharing a name but using + different techniques (e.g. base64 vs hex encoders) get distinct + ``parent_eval_hash`` values, so resume bookkeeping never cross-pollinates. + + Attributes: + parent_id (str): The ID of the parent entity.
Persisted to + ``AttackResultEntry.attribution_parent_id`` (foreign key to + ``ScenarioResultEntries.id``) so per-parent loading is indexed. + ``Scenario`` sets this to the scenario result UUID, e.g. + ``self._scenario_result_id`` + (``"8a8d17b9-b671-4a3d-8170-e65ea9b44053"``). + parent_collection (str): Free-form label naming the per-parent + collection this result belongs to. Persisted into + ``AttackResultEntry.attribution_data``. ``Scenario`` sets this + to the atomic attack name, e.g. ``self.atomic_attack_name`` + (``"encoded_jailbreaks"``). + parent_eval_hash (str | None): Optional content-addressed hash that + disambiguates configurations sharing the same + ``parent_collection``. Persisted into + ``AttackResultEntry.attribution_data``. ``Scenario`` sets this + to the atomic attack's technique evaluation hash, e.g. + ``self.technique_eval_hash`` (computed via + ``AtomicAttackEvaluationIdentifier``). + """ + + parent_id: str + parent_collection: str + parent_eval_hash: str | None = None diff --git a/pyrit/executor/attack/core/attack_strategy.py b/pyrit/executor/attack/core/attack_strategy.py index a054ace81..ee8ae379e 100644 --- a/pyrit/executor/attack/core/attack_strategy.py +++ b/pyrit/executor/attack/core/attack_strategy.py @@ -36,6 +36,7 @@ if TYPE_CHECKING: from pyrit.executor.attack.core.attack_config import AttackScoringConfig + from pyrit.executor.attack.core.attack_result_attribution import AttackResultAttribution from pyrit.prompt_target import PromptTarget AttackStrategyContextT = TypeVar("AttackStrategyContextT", bound="AttackContext[Any]") @@ -70,8 +71,12 @@ class AttackContext(StrategyContext, ABC, Generic[AttackParamsT]): _prepended_conversation_override: Optional[list[Message]] = None _memory_labels_override: Optional[dict[str, str]] = None - # Set by the ON_ERROR handler to link error AttackResults to ScenarioResults - _error_attack_result_id: str | None = None + # Optional attribution from an upstream orchestrator (e.g. Scenario). 
When + # set, the persistence path stamps attribution_parent_id + attribution_data + # onto the resulting AttackResult so it can be located later for hydration + # and resume. Set by AttackExecutor per-task before scheduling. Stays None + # for ad-hoc/direct attack execution outside any orchestrator. + _attribution: Optional[AttackResultAttribution] = None # Convenience properties that delegate to params or overrides @property @@ -223,11 +228,46 @@ async def _on_post_execute( event_data.result.retry_events = collector.events event_data.result.total_retries = len(collector.events) + # Stamp attribution onto the result before persistence so the + # AttackResultEntry row records its lineage. Outside an orchestrator + # _attribution is None and both attribution fields stay None. + self._apply_attribution(context=event_data.context, result=event_data.result) + self._logger.debug(f"Attack execution completed in {execution_time_ms}ms") self._log_attack_outcome(event_data.result) self._memory.add_attack_results_to_memory(attack_results=[event_data.result]) + @staticmethod + def _apply_attribution( + *, + context: AttackStrategyContextT, + result: AttackResult, + ) -> None: + """ + Copy attribution from the AttackContext onto the AttackResult. + + Reads ``context._attribution`` (an ``AttackResultAttribution`` set by + the AttackExecutor when an upstream orchestrator supplied one). + When present, writes ``attribution_parent_id`` and a fixed-schema + ``attribution_data`` dict onto the result so they round-trip into + ``AttackResultEntry``. + + Args: + context: The per-task AttackContext. + result: The AttackResult that is about to be persisted.
+ """ + attribution = context._attribution + if attribution is None: + return + result.attribution_parent_id = attribution.parent_id + attribution_data: dict[str, Any] = { + "parent_collection": attribution.parent_collection, + } + if attribution.parent_eval_hash is not None: + attribution_data["parent_eval_hash"] = attribution.parent_eval_hash + result.attribution_data = attribution_data + def _log_attack_outcome(self, result: AttackResult) -> None: """ Log the outcome of the attack. @@ -267,9 +307,6 @@ async def _on_error_async( if not error or not context: return - # Clear any stale ID from a previous execution - context._error_attack_result_id = None - # Collect retry events (visible via inherited ContextVar copy) collector = get_retry_collector() retry_events = collector.events if collector else [] @@ -295,10 +332,11 @@ async def _on_error_async( if context.start_time: error_result.execution_time_ms = int((end_time - context.start_time) * 1000) - # Persist first, then set the ID on the context so scenario-level code - # only sees the reference if the write succeeded. + # Stamp attribution onto the error result so it is locatable via the + # attribution_parent_id foreign key on resume. 
+ self._apply_attribution(context=context, result=error_result) + self._memory.add_attack_results_to_memory(attack_results=[error_result]) - context._error_attack_result_id = error_result.attack_result_id self._logger.error(f"Attack failed with {type(error).__name__}: {error}") diff --git a/pyrit/executor/core/strategy.py b/pyrit/executor/core/strategy.py index 3aa9ede03..fe299e640 100644 --- a/pyrit/executor/core/strategy.py +++ b/pyrit/executor/core/strategy.py @@ -31,11 +31,7 @@ class _StrategyRuntimeError(RuntimeError): - """RuntimeError subclass that carries an optional error_attack_result_id.""" - - def __init__(self, message: str, *, error_attack_result_id: str | None = None) -> None: - super().__init__(message) - self.error_attack_result_id = error_attack_result_id + """RuntimeError subclass for strategy execution failures.""" @dataclass @@ -386,9 +382,7 @@ async def execute_with_context_async(self, *, context: StrategyContextT) -> Stra else: error_message = f"Strategy execution failed for {self.__class__.__name__}: {str(e)}" - # Attach the error attack result ID if the ON_ERROR handler created one - error_attack_result_id = getattr(context, "_error_attack_result_id", None) - runtime_error = _StrategyRuntimeError(error_message, error_attack_result_id=error_attack_result_id) + runtime_error = _StrategyRuntimeError(error_message) raise runtime_error from e async def execute_async(self, **kwargs: Any) -> StrategyResultT: diff --git a/pyrit/memory/alembic/versions/9c8b7a6d5e4f_add_attribution_to_attack_results.py b/pyrit/memory/alembic/versions/9c8b7a6d5e4f_add_attribution_to_attack_results.py new file mode 100644 index 000000000..06559825c --- /dev/null +++ b/pyrit/memory/alembic/versions/9c8b7a6d5e4f_add_attribution_to_attack_results.py @@ -0,0 +1,170 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +""" +Add attribution_parent_id (foreign key) + attribution_data (JSON) to +AttackResultEntries; on ScenarioResultEntries drop error_attack_result_ids_json +and add scenario_metadata; backfill linkage from the attack_results_json manifest. + +Revision ID: 9c8b7a6d5e4f +Revises: 7a1b2c3d4e5f +Create Date: 2026-05-18 15:00:00.000000 +""" + +from __future__ import annotations + +import json +import logging +from collections.abc import Sequence # noqa: TC003 +from typing import Any + +import sqlalchemy as sa +from alembic import op + +from pyrit.memory.memory_models import CustomUUID + +# revision identifiers, used by Alembic. +revision: str = "9c8b7a6d5e4f" +down_revision: str | None = "7a1b2c3d4e5f" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +logger = logging.getLogger(__name__) + + +def upgrade() -> None: + """Apply this schema upgrade.""" + # AttackResultEntries: attribution / parent linkage columns. + op.add_column( + "AttackResultEntries", + sa.Column("attribution_parent_id", CustomUUID(), nullable=True), + ) + op.add_column( + "AttackResultEntries", + sa.Column("attribution_data", sa.JSON(), nullable=True), + ) + op.create_index( + "ix_AttackResultEntries_attribution_parent_id", + "AttackResultEntries", + ["attribution_parent_id"], + ) + + # Foreign key with ON DELETE SET NULL: deleting a scenario nulls the + # attribution_parent_id on its AttackResults; attribution_data is retained + # as historical provenance. Use a batch operation for SQLite portability + # (no plain ALTER TABLE ADD CONSTRAINT for foreign keys on SQLite). + with op.batch_alter_table("AttackResultEntries") as batch_op: + batch_op.create_foreign_key( + "fk_attack_results_attribution_parent", + "ScenarioResultEntries", + ["attribution_parent_id"], + ["id"], + ondelete="SET NULL", + ) + + # ScenarioResultEntries: drop the not-yet-released error_attack_result_ids_json, + # and add scenario_metadata for free-form scenario-level JSON state (e.g.
+ # the persisted objective_hashes used for resume). + # Error AttackResults are now linkable via the new attribution_parent_id + # foreign key; the per-scenario manifest column is no longer used. + # Wrapped in a batch op for SQLite. + with op.batch_alter_table("ScenarioResultEntries") as batch_op: + batch_op.drop_column("error_attack_result_ids_json") + batch_op.add_column(sa.Column("scenario_metadata", sa.JSON(), nullable=True)) + + # Backfill attribution linkage from the existing attack_results_json manifest. + _backfill_attribution_linkage() + + +def downgrade() -> None: + """Revert this schema upgrade.""" + # Re-add error_attack_result_ids_json on ScenarioResultEntries and drop scenario_metadata. + with op.batch_alter_table("ScenarioResultEntries") as batch_op: + batch_op.add_column(sa.Column("error_attack_result_ids_json", sa.Unicode(), nullable=True)) + batch_op.drop_column("scenario_metadata") + + # Drop foreign key + columns from AttackResultEntries. + with op.batch_alter_table("AttackResultEntries") as batch_op: + batch_op.drop_constraint("fk_attack_results_attribution_parent", type_="foreignkey") + + op.drop_index("ix_AttackResultEntries_attribution_parent_id", table_name="AttackResultEntries") + op.drop_column("AttackResultEntries", "attribution_data") + op.drop_column("AttackResultEntries", "attribution_parent_id") + + +def _backfill_attribution_linkage() -> None: + """ + Walk every ScenarioResultEntry and copy its attack_results_json manifest + into the new attribution_parent_id + attribution_data columns on + AttackResultEntries. + + Idempotent: the ``WHERE attribution_parent_id IS NULL`` guard prevents + clobbering rows that were already linked (e.g. by a re-run of the + migration, or by code that ran after the schema change but before this + backfill). ``conversation_id`` is logically unique per AttackResult but is + not DB-enforced, so the guard is purely defensive and a WARNING is logged + if any duplicate match is observed in the wild. 
+ """ + bind = op.get_bind() + + scenarios = bind.execute(sa.text('SELECT id, attack_results_json FROM "ScenarioResultEntries"')).fetchall() + + update_stmt = sa.text( + 'UPDATE "AttackResultEntries" ' + "SET attribution_parent_id = :sid, attribution_data = :sdata " + "WHERE conversation_id = :cid AND attribution_parent_id IS NULL" + ) + + total_updates = 0 + duplicate_warnings = 0 + + for row in scenarios: + scenario_id = row[0] + manifest_json = row[1] + if not manifest_json: + continue + try: + manifest: dict[str, Any] = json.loads(manifest_json) + except (TypeError, ValueError): + logger.warning(f"Skipping scenario {scenario_id}: attack_results_json is not valid JSON") + continue + + for atomic_attack_name, conversation_ids in manifest.items(): + if not isinstance(conversation_ids, list): + continue + for conversation_id in conversation_ids: + if not isinstance(conversation_id, str): + continue + # Check for duplicate conversation_id matches (data anomaly). + match_count = bind.execute( + sa.text( + 'SELECT COUNT(*) FROM "AttackResultEntries" ' + "WHERE conversation_id = :cid AND attribution_parent_id IS NULL" + ), + {"cid": conversation_id}, + ).scalar() + if isinstance(match_count, int) and match_count > 1: + duplicate_warnings += 1 + logger.warning( + f"Backfill: conversation_id {conversation_id!r} matches {match_count} " + f"unlinked AttackResultEntries rows; conversation_id should be unique. " + f"All matching rows will be linked to scenario {scenario_id}." + ) + + attribution_data = json.dumps({"parent_collection": atomic_attack_name}) + result = bind.execute( + update_stmt, + { + "sid": str(scenario_id), + "sdata": attribution_data, + "cid": conversation_id, + }, + ) + total_updates += result.rowcount or 0 + + if total_updates or duplicate_warnings: + logger.info( + f"Attribution linkage backfill: linked {total_updates} AttackResultEntries row(s); " + f"{duplicate_warnings} duplicate-conversation_id warning(s)." 
+ ) diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 65d4480aa..cc68f8250 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -1701,6 +1701,7 @@ def get_attack_results( targeted_harm_categories: Optional[Sequence[str]] = None, labels: Optional[dict[str, str | Sequence[str]]] = None, identifier_filters: Optional[Sequence[IdentifierFilter]] = None, + scenario_result_id: Optional[str] = None, ) -> Sequence[AttackResult]: """ Retrieve a list of AttackResult objects based on the specified filters. @@ -1750,6 +1751,10 @@ def get_attack_results( identifier_filters (Optional[Sequence[IdentifierFilter]], optional): A sequence of IdentifierFilter objects that allows filtering by various attack identifier JSON properties. Defaults to None. + scenario_result_id (Optional[str], optional): Filter to attack results linked to a + specific scenario via the ``AttackResultEntry.attribution_parent_id`` foreign key. + Combined with ``outcome=AttackOutcome.ERROR`` this is the replacement for the + removed per-scenario error_attack_result_ids manifest. Defaults to None. Returns: Sequence[AttackResult]: A list of AttackResult objects that match the specified filters. 
@@ -1783,6 +1788,8 @@ def get_attack_results( conditions.append(AttackResultEntry.objective.contains(objective)) if outcome: conditions.append(AttackResultEntry.outcome == outcome) + if scenario_result_id: + conditions.append(AttackResultEntry.attribution_parent_id == uuid.UUID(scenario_result_id)) if attack_classes: # Case-insensitive to mirror converter_classes; forgives casing drift in @@ -1958,64 +1965,6 @@ def add_scenario_results_to_memory(self, *, scenario_results: Sequence[ScenarioR entries=[ScenarioResultEntry(entry=scenario_result) for scenario_result in scenario_results] ) - def add_attack_results_to_scenario( - self, - *, - scenario_result_id: str, - atomic_attack_name: str, - attack_results: Sequence[AttackResult], - ) -> bool: - """ - Add attack results to an existing scenario result in memory. - - This method efficiently updates a scenario result by appending new attack results - to a specific atomic attack name without requiring a full retrieve-modify-save cycle. - - Args: - scenario_result_id (str): The ID of the scenario result to update. - atomic_attack_name (str): The name of the atomic attack to add results for. - attack_results (Sequence[AttackResult]): The attack results to add. - - Returns: - bool: True if the update was successful, False otherwise. - - Example: - >>> memory.add_attack_results_to_scenario( - ... scenario_result_id="123e4567-e89b-12d3-a456-426614174000", - ... atomic_attack_name="base64_attack", - ... attack_results=[result1, result2] - ... 
) - """ - try: - # Retrieve current scenario result - scenario_results = self.get_scenario_results(scenario_result_ids=[scenario_result_id]) - - if not scenario_results: - logger.error(f"Scenario result with ID {scenario_result_id} not found in memory") - return False - - scenario_result = scenario_results[0] - - # Update attack results for this atomic attack name - if atomic_attack_name not in scenario_result.attack_results: - scenario_result.attack_results[atomic_attack_name] = [] - - scenario_result.attack_results[atomic_attack_name].extend(list(attack_results)) - - # Save updated result back to memory using update - entry = ScenarioResultEntry(entry=scenario_result) - self._update_entry(entry) - - logger.info( - f"Added {len(attack_results)} attack results to scenario {scenario_result_id} " - f"for atomic attack '{atomic_attack_name}'" - ) - return True - - except Exception as e: - logger.exception(f"Failed to add attack results to scenario {scenario_result_id}: {str(e)}") - raise - def update_scenario_run_state( self, *, @@ -2027,6 +1976,13 @@ def update_scenario_run_state( """ Update the run state of an existing scenario result. + Performs a targeted UPDATE of only the state/error columns instead of + rebuilding the entire ``ScenarioResultEntry`` row. The full-row rebuild + used to read the stored row, mutate the ScenarioResult, and re-serialize + every column — including ``attack_results_json`` which is being phased + out and could be stale during the deprecation window. A targeted UPDATE + avoids clobbering manifest data and is also cheaper. + Args: scenario_result_id (str): The ID of the scenario result to update. scenario_run_state (str): The new state for the scenario @@ -2037,62 +1993,50 @@ def update_scenario_run_state( Raises: ValueError: If the scenario result is not found. 
""" - scenario_results = self.get_scenario_results(scenario_result_ids=[scenario_result_id]) - - if not scenario_results: - raise ValueError(f"Scenario result with ID {scenario_result_id} not found in memory") - - scenario_result = scenario_results[0] + with closing(self.get_session()) as session: + entry = session.query(ScenarioResultEntry).filter_by(id=scenario_result_id).first() - # Update the scenario run state - scenario_result.scenario_run_state = scenario_run_state # type: ignore[ty:invalid-assignment] + if not entry: + raise ValueError(f"Scenario result with ID {scenario_result_id} not found in memory") - if error_message is not None: - scenario_result.error_message = error_message - if error_type is not None: - scenario_result.error_type = error_type + entry.scenario_run_state = scenario_run_state + if error_message is not None: + entry.error_message = error_message + if error_type is not None: + entry.error_type = error_type - # Save updated result back to memory using update - entry = ScenarioResultEntry(entry=scenario_result) - self._update_entry(entry) + session.commit() logger.info(f"Updated scenario {scenario_result_id} state to '{scenario_run_state}'") - def update_scenario_error_attacks(self, *, scenario_result_id: str, error_attack_result_ids: list[str]) -> None: + def update_scenario_metadata( + self, + *, + scenario_result_id: str, + metadata: dict[str, Any], + ) -> None: """ - Update the error attack result IDs on an existing scenario result. - - This links failed AttackResults to the ScenarioResult so the REST API - can quickly find error details without scanning all attacks. + Replace the ``scenario_metadata`` JSON blob on an existing scenario result. - Performs the read-modify-write within a single DB session to avoid - inter-session consistency issues. + Used by the scenario layer to persist first-run state (e.g. + ``objective_hashes``) that resume needs to replay. Performs a + targeted UPDATE so it doesn't clobber other columns. 
Args: - scenario_result_id: The ID of the scenario result to update. - error_attack_result_ids: IDs of AttackResults that contain error information. + scenario_result_id (str): The ID of the scenario result to update. + metadata (dict[str, Any]): The full metadata dict to store. Pass the + merged dict, not just the new keys — this writes the whole value. Raises: ValueError: If the scenario result is not found. """ - import json - with closing(self.get_session()) as session: entry = session.query(ScenarioResultEntry).filter_by(id=scenario_result_id).first() - if not entry: raise ValueError(f"Scenario result with ID {scenario_result_id} not found in memory") - - existing: list[str] = ( - json.loads(entry.error_attack_result_ids_json) if entry.error_attack_result_ids_json else [] - ) - merged = list(dict.fromkeys(existing + error_attack_result_ids)) - entry.error_attack_result_ids_json = json.dumps(merged) - + entry.scenario_metadata = metadata if metadata else None session.commit() - logger.info(f"Updated scenario {scenario_result_id} with {len(error_attack_result_ids)} error attack result(s)") - def get_scenario_results( self, *, @@ -2144,6 +2088,57 @@ def get_scenario_results( if scenario_result_ids is not None and len(scenario_result_ids) == 0: return [] + conditions = self._build_scenario_result_query_conditions( + scenario_name=scenario_name, + scenario_version=scenario_version, + pyrit_version=pyrit_version, + added_after=added_after, + added_before=added_before, + labels=labels, + objective_target_endpoint=objective_target_endpoint, + objective_target_model_name=objective_target_model_name, + identifier_filters=identifier_filters, + ) + + try: + entries = self._query_scenario_result_entries( + scenario_result_ids=scenario_result_ids, + conditions=conditions, + limit=limit, + ) + + attack_results_by_scenario = self._get_attack_results_by_scenario(entries=entries) + + scenario_results: list[ScenarioResult] = [] + for entry in entries: + scenario_result = 
entry.get_scenario_result() + scenario_result.attack_results = attack_results_by_scenario.get(entry.id, {}) + scenario_results.append(scenario_result) + + return scenario_results + except Exception as e: + logger.exception(f"Failed to retrieve scenario results with error {e}") + raise + + def _build_scenario_result_query_conditions( + self, + *, + scenario_name: str | None, + scenario_version: int | None, + pyrit_version: str | None, + added_after: datetime | None, + added_before: datetime | None, + labels: dict[str, str] | None, + objective_target_endpoint: str | None, + objective_target_model_name: str | None, + identifier_filters: Sequence[IdentifierFilter] | None, + ) -> "list[ColumnElement[bool]]": + """ + Build the WHERE conditions for ``get_scenario_results``. + + Returns: + list[ColumnElement[bool]]: SQLAlchemy WHERE clauses derived from the supplied filters. + """ conditions: list[ColumnElement[bool]] = [] if scenario_name: @@ -2197,64 +2192,97 @@ def get_scenario_results( ) ) - try: - order_by_clause = ScenarioResultEntry.completion_time.desc() - - # Handle scenario_result_ids with batched queries if needed - if scenario_result_ids: - entries = self._execute_batched_query( - ScenarioResultEntry, - batch_column=ScenarioResultEntry.id, - batch_values=list(scenario_result_ids), - other_conditions=conditions, - order_by=order_by_clause, - limit=limit, - ) - else: - entries = self._query_entries( - ScenarioResultEntry, - conditions=and_(*conditions) if conditions else None, - order_by=order_by_clause, - limit=limit, - ) + return conditions - # Convert entries to ScenarioResults and populate attack_results efficiently - scenario_results = [] - for entry in entries: - scenario_result = entry.get_scenario_result() + def _query_scenario_result_entries( + self, + *, + scenario_result_ids: Sequence[str] | None, + conditions: "list[ColumnElement[bool]]", + limit: int | None, + ) -> Sequence[ScenarioResultEntry]: + """ + Run the (possibly batched) ScenarioResultEntry 
query. - # Get conversation IDs grouped by attack name - conversation_ids_by_attack = entry.get_conversation_ids_by_attack_name() + Returns: + Sequence[ScenarioResultEntry]: The matching rows ordered by completion_time descending. + """ + order_by_clause = ScenarioResultEntry.completion_time.desc() - # Collect all conversation IDs to query in a single batch - all_conversation_ids = [] - for conv_ids in conversation_ids_by_attack.values(): - all_conversation_ids.extend(conv_ids) + if scenario_result_ids: + return self._execute_batched_query( + ScenarioResultEntry, + batch_column=ScenarioResultEntry.id, + batch_values=list(scenario_result_ids), + other_conditions=conditions, + order_by=order_by_clause, + limit=limit, + ) - # Query all AttackResults using batched queries if needed - if all_conversation_ids: - attack_entries = self._execute_batched_query( - AttackResultEntry, - batch_column=AttackResultEntry.conversation_id, - batch_values=all_conversation_ids, - ) + return self._query_entries( + ScenarioResultEntry, + conditions=and_(*conditions) if conditions else None, + order_by=order_by_clause, + limit=limit, + ) - # Build a dict for quick lookup - attack_results_dict = {entry.conversation_id: entry.get_attack_result() for entry in attack_entries} + def _get_attack_results_by_scenario( + self, + *, + entries: Sequence[ScenarioResultEntry], + ) -> dict[uuid.UUID, dict[str, list[AttackResult]]]: + """ + Fetch every ``AttackResult`` linked to the given scenarios via the + ``AttackResultEntry.attribution_parent_id`` foreign key in a single + batched query, then group by scenario + ``parent_collection`` (which + the scenario layer uses for the atomic attack name) and sort each + group by ``AttackResultEntry.timestamp``. 
- # Populate attack_results by attack name, preserving order - scenario_result.attack_results = {} - for attack_name, conv_ids in conversation_ids_by_attack.items(): - scenario_result.attack_results[attack_name] = [ - attack_results_dict[conv_id] for conv_id in conv_ids if conv_id in attack_results_dict - ] + Foreign-key linkage is the sole source of truth — set at write-time by + the attack persistence path when an ``AttackResultAttribution`` is on + the context. Rows without a valid ``attribution_data`` payload are + skipped (and logged) rather than guessed at. - scenario_results.append(scenario_result) + Returns: + dict[uuid.UUID, dict[str, list[AttackResult]]]: Mapping of + ``scenario_result_id`` → ``atomic_attack_name`` → ordered list of + ``AttackResult`` objects. Scenarios with no linked rows map to ``{}``. + """ + if not entries: + return {} - return scenario_results - except Exception as e: - logger.exception(f"Failed to retrieve scenario results with error {e}") - raise + scenario_ids = [entry.id for entry in entries] + attack_rows = self._execute_batched_query( + AttackResultEntry, + batch_column=AttackResultEntry.attribution_parent_id, + batch_values=scenario_ids, + ) + + grouped: dict[uuid.UUID, dict[str, list[tuple[datetime, AttackResult]]]] = {entry.id: {} for entry in entries} + + for row in attack_rows: + scenario_id = row.attribution_parent_id + if scenario_id is None or scenario_id not in grouped: + continue + + data = row.attribution_data or {} + name = data.get("parent_collection") + if not name: + logger.debug( + f"Skipping AttackResultEntry {row.id} during scenario load: " + "attribution_data missing parent_collection" + ) + continue + + sort_key = row.timestamp or datetime.min.replace(tzinfo=timezone.utc) + grouped[scenario_id].setdefault(name, []).append((sort_key, row.get_attack_result())) + + return { + scenario_id: { + name: [ar for _, ar in sorted(bucket, key=lambda kv: kv[0])] for name, bucket in name_buckets.items() + } + for 
scenario_id, name_buckets in grouped.items() + } def print_schema(self) -> None: """ diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index efb9ec72a..2afeb4533 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -741,6 +741,19 @@ class AttackResultEntry(Base): retry_events_json: Mapped[str | None] = mapped_column(Unicode, nullable=True) total_retries = mapped_column(INTEGER, nullable=True, default=0) + # Attribution / parent linkage (set when the AttackResult is produced + # inside an orchestrator that supplies an AttackResultAttribution, e.g. a + # Scenario). attribution_parent_id is an indexed foreign key so per-parent + # hydration and resume queries are direct lookups (no JSON manifest + # required, no orphaning if the orchestrator is interrupted mid-run). + # attribution_data is a documented-fixed-schema JSON blob keyed by + # parent_collection (str). When the AttackResult is created outside an + # orchestrator both fields remain NULL. 
+ attribution_parent_id: Mapped[uuid.UUID | None] = mapped_column( + CustomUUID, ForeignKey("ScenarioResultEntries.id", ondelete="SET NULL"), nullable=True, index=True + ) + attribution_data: Mapped[dict[str, Any | None]] = mapped_column(JSON, nullable=True) + last_response: Mapped["PromptMemoryEntry | None"] = relationship( "PromptMemoryEntry", foreign_keys=[last_response_id], @@ -815,6 +828,11 @@ def __init__(self, *, entry: AttackResult) -> None: ) self.total_retries = entry.total_retries + # Attribution / parent linkage (set by the attack persistence path when + # an AttackResultAttribution is present on the AttackContext; otherwise None) + self.attribution_parent_id = uuid.UUID(entry.attribution_parent_id) if entry.attribution_parent_id else None + self.attribution_data = entry.attribution_data + @staticmethod def _get_id_as_uuid(obj: Any) -> uuid.UUID | None: """ @@ -927,6 +945,8 @@ def get_attack_result(self) -> AttackResult: error_traceback=self.error_traceback, retry_events=retry_events, total_retries=self.total_retries or 0, + attribution_parent_id=str(self.attribution_parent_id) if self.attribution_parent_id else None, + attribution_data=self.attribution_data, ) @@ -988,13 +1008,18 @@ class ScenarioResultEntry(Base): completion_time = mapped_column(DateTime, nullable=False) timestamp = mapped_column(DateTime, nullable=False) - # Pointer to failed attack result(s) — avoids scanning all attacks for error info - error_attack_result_ids_json: Mapped[str | None] = mapped_column(Unicode, nullable=True) - # Scenario-level error info (persisted so it survives process restarts) error_message: Mapped[str | None] = mapped_column(Unicode, nullable=True) error_type: Mapped[str | None] = mapped_column(String, nullable=True) + # Free-form JSON metadata stamped by the scenario. 
Currently used to record + # ``objective_hashes`` — the objective sha256 set chosen on the + # first run, replayed on resume so a fresh ``random.sample`` can't + # silently change which objectives the scenario operates on. Column is + # named ``scenario_metadata`` because SQLAlchemy's ``DeclarativeBase`` + # reserves ``metadata`` as a class attribute on the model. + scenario_metadata: Mapped[dict[str, Any | None]] = mapped_column(JSON, nullable=True) + def __init__(self, *, entry: ScenarioResult) -> None: """ Initialize a ScenarioResultEntry from a ScenarioResult object. @@ -1044,13 +1069,9 @@ def __init__(self, *, entry: ScenarioResult) -> None: # Serialize display_group_map if present self.display_group_map_json = json.dumps(entry._display_group_map) if entry._display_group_map else None - # Serialize error_attack_result_ids if present - self.error_attack_result_ids_json = ( - json.dumps(entry.error_attack_result_ids) if entry.error_attack_result_ids else None - ) - self.error_message = entry.error_message self.error_type = entry.error_type + self.scenario_metadata = entry.metadata if entry.metadata else None self.timestamp = datetime.now(tz=timezone.utc) @@ -1093,11 +1114,6 @@ def get_scenario_result(self) -> ScenarioResult: if self.display_group_map_json: display_group_map = json.loads(self.display_group_map_json) - # Deserialize error_attack_result_ids if stored - error_attack_result_ids: list[str] | None = None - if self.error_attack_result_ids_json: - error_attack_result_ids = json.loads(self.error_attack_result_ids_json) - return ScenarioResult( id=self.id, scenario_identifier=scenario_identifier, @@ -1110,9 +1126,9 @@ def get_scenario_result(self) -> ScenarioResult: number_tries=self.number_tries, completion_time=self.completion_time, display_group_map=display_group_map, - error_attack_result_ids=error_attack_result_ids, error_message=self.error_message, error_type=self.error_type, + metadata=dict(self.scenario_metadata) if self.scenario_metadata else 
None, ) def get_conversation_ids_by_attack_name(self) -> dict[str, list[str]]: diff --git a/pyrit/models/attack_result.py b/pyrit/models/attack_result.py index 703a2b90a..02fad197d 100644 --- a/pyrit/models/attack_result.py +++ b/pyrit/models/attack_result.py @@ -105,6 +105,14 @@ class AttackResult(StrategyResult): retry_events: list[RetryEvent] = field(default_factory=list) total_retries: int = 0 + # Attribution / parent linkage (infrastructure-managed). Set by the attack + # persistence path when an AttackResultAttribution is present on the + # AttackContext. User code should not set these directly; ad-hoc + # AttackResults created outside an orchestrator leave both fields as None + # and the corresponding DB columns remain NULL. + attribution_parent_id: str | None = None + attribution_data: dict[str, Any] | None = None + @property def attack_identifier(self) -> Optional[ComponentIdentifier]: """ diff --git a/pyrit/models/scenario_result.py b/pyrit/models/scenario_result.py index ef971e04b..731003498 100644 --- a/pyrit/models/scenario_result.py +++ b/pyrit/models/scenario_result.py @@ -105,9 +105,10 @@ def __init__( number_tries: int = 0, id: uuid.UUID | None = None, # noqa: A002 display_group_map: dict[str, str] | None = None, - error_attack_result_ids: list[str] | None = None, error_message: str | None = None, error_type: str | None = None, + error_attack_result_ids: list[str] | None = None, + metadata: dict[str, Any] | None = None, ) -> None: """ Initialize a scenario result. @@ -127,11 +128,17 @@ def __init__( display_group_map (Optional[dict[str, str]]): Optional mapping of atomic_attack_name → display group label. Used by the console printer to aggregate results for user-facing output. - error_attack_result_ids (Optional[list[str]]): IDs of AttackResults that - contain error information. Used for quick error lookup without scanning - all attack results. error_message (Optional[str]): Scenario-level error message when the run fails. 
error_type (Optional[str]): Exception class name when the run fails. + error_attack_result_ids (Optional[list[str]]): IDs of attack results that + errored during the scenario run. Defaults to an empty list. + metadata (Optional[dict[str, Any]]): Free-form JSON metadata persisted + with the scenario result. Currently used to record + ``objective_hashes`` — the objective ``sha256`` set chosen + on the first run, replayed on resume so a fresh + ``random.sample`` can't silently change which objectives the + scenario operates on. Keys are not part of any public contract + and may evolve. """ self.id = id if id is not None else uuid.uuid4() @@ -148,9 +155,10 @@ def __init__( self.completion_time = completion_time if completion_time is not None else datetime.now(timezone.utc) self.number_tries = number_tries self._display_group_map = display_group_map or {} - self.error_attack_result_ids = error_attack_result_ids or [] self.error_message = error_message self.error_type = error_type + self.error_attack_result_ids: list[str] = list(error_attack_result_ids) if error_attack_result_ids else [] + self.metadata: dict[str, Any] = metadata if metadata is not None else {} @property def display_group_map(self) -> dict[str, str]: diff --git a/pyrit/scenario/core/atomic_attack.py b/pyrit/scenario/core/atomic_attack.py index ef61b8b0b..1138671f3 100644 --- a/pyrit/scenario/core/atomic_attack.py +++ b/pyrit/scenario/core/atomic_attack.py @@ -17,8 +17,10 @@ from typing import TYPE_CHECKING, Any, Optional from pyrit.common.deprecation import print_deprecation_message +from pyrit.common.utils import to_sha256 from pyrit.executor.attack import AttackExecutor, AttackStrategy from pyrit.executor.attack.core.attack_executor import AttackExecutorResult +from pyrit.executor.attack.core.attack_result_attribution import AttackResultAttribution from pyrit.identifiers import build_atomic_attack_identifier from pyrit.identifiers.evaluation_identifier import AtomicAttackEvaluationIdentifier from 
pyrit.memory import CentralMemory @@ -117,21 +119,93 @@ def __init__( sg.validate() self._seed_groups = seed_groups + self._validate_unique_objective_hashes() self._adversarial_chat = adversarial_chat self._objective_scorer = objective_scorer self._memory_labels = memory_labels or {} self._attack_execute_params = attack_execute_params + # Set via set_scenario_result_id() by Scenario._execute_scenario_async + # before run_async. When set, each persisted AttackResult is linked to + # the scenario via the attribution_parent_id foreign key on + # AttackResultEntry. + self._scenario_result_id: str | None = None logger.info( f"Initialized atomic attack with {len(self._seed_groups)} seed groups, " f"attack type: {type(self._attack_technique.attack).__name__}" ) + def set_scenario_result_id(self, scenario_result_id: str | None) -> None: + """ + Bind this atomic attack to a scenario result for attribution. + + Called by ``Scenario._execute_scenario_async`` before each + ``run_async`` so persisted ``AttackResult`` rows carry the + ``attribution_parent_id`` foreign key back to the scenario. Pass + ``None`` to clear the binding (e.g. when running an atomic attack + outside of a scenario). + + Args: + scenario_result_id (str | None): The scenario result UUID this + atomic attack belongs to, or ``None`` to detach. + """ + self._scenario_result_id = scenario_result_id + + def _validate_unique_objective_hashes(self) -> None: + """ + Ensure each seed group in this atomic attack has a unique objective hash. + + Within a single ``AtomicAttack`` (one ``atomic_attack_name``, one + technique), the objective text identifies a unit of work. Duplicates + would mean two indistinguishable rows on the write side, which makes + resume reconciliation ambiguous — the hash-based resume key treats a + set of hashes as already-done, with no way to distinguish which of two + duplicate rows is "the one" that is still outstanding. + + The hash is currently derived from objective text only. 
A future + iteration may hash the full ``SeedGroup`` (minus technique-specific + fields) so two seed groups that share an objective string but differ + in other inputs can coexist in one atomic attack. + + Raises: + ValueError: If two seed groups share the same ``objective_sha256``. + """ + seen: dict[str, int] = {} + for sg in self._seed_groups: + if sg.objective is None: + continue + sha = to_sha256(sg.objective.value) + if sha in seen: + raise ValueError( + f"AtomicAttack '{self.atomic_attack_name}' has duplicate objective hash " + f"{sha[:12]}... across seed_groups; each (objective, technique) pair must be unique." + ) + seen[sha] = 1 + @property def attack_technique(self) -> AttackTechnique: """Get the attack technique for this atomic attack.""" return self._attack_technique + @property + def technique_eval_hash(self) -> str: + """ + Behavioral evaluation hash for this atomic attack's technique configuration. + + Builds an ``AtomicAttack`` identifier from this attack's technique + (without any seed group) and runs it through + ``AtomicAttackEvaluationIdentifier`` so target/scorer/seed-identifier + noise is stripped per the standard atomic-attack eval rules. The + result is stable across resume runs and across different seed groups, + which is what makes it usable as the resume disambiguator alongside + ``atomic_attack_name``. + """ + composite = build_atomic_attack_identifier( + technique_identifier=self._attack_technique.get_identifier(), + seed_group=None, + ) + return AtomicAttackEvaluationIdentifier(composite).eval_hash + @property def objectives(self) -> list[str]: """ @@ -152,20 +226,78 @@ def seed_groups(self) -> list[SeedAttackGroup]: """ return list(self._seed_groups) + def drop_seed_groups_with_hashes(self, *, hashes: set[str]) -> None: + """ + Drop seed groups whose ``objective_sha256`` is in ``hashes``. + + This is the resume filter: within an atomic attack, ``objective_sha256`` + is the stable identity (enforced unique by ``__init__``). 
Content-derived + keys are robust to reordering and resampling, so resume produces the + right remaining-work set even when ``get_seed_groups()`` is rebuilt + from scratch on each ``run_async()``. + + Args: + hashes (set[str]): SHA256 hashes of objective text for seed groups + to drop (typically those that have already produced a + non-error ``AttackResult``). + """ + self._seed_groups = [ + sg for sg in self._seed_groups if sg.objective is None or to_sha256(sg.objective.value) not in hashes + ] + def filter_seed_groups_by_objectives(self, *, remaining_objectives: list[str]) -> None: """ Filter seed groups to only those with objectives in the remaining list. - This is used for scenario resumption to skip already completed objectives. + .. deprecated:: + Use ``drop_seed_groups_with_hashes`` (or ``keep_seed_groups_with_hashes``) + which keys on content-addressed ``objective_sha256`` instead of + objective text. Scheduled for removal in 0.16.0. Args: remaining_objectives (List[str]): List of objectives that still need to be executed. """ + print_deprecation_message( + old_item="AtomicAttack.filter_seed_groups_by_objectives(remaining_objectives=...)", + new_item="AtomicAttack.keep_seed_groups_with_hashes(hashes=...)", + removed_in="0.16.0", + ) remaining_set = set(remaining_objectives) self._seed_groups = [ sg for sg in self._seed_groups if sg.objective is not None and sg.objective.value in remaining_set ] + def keep_seed_groups_with_hashes(self, *, hashes: set[str]) -> set[str]: + """ + Keep only seed groups whose ``objective_sha256`` is in ``hashes``. + + Inverse of ``drop_seed_groups_with_hashes``: used on resume to + replay the originally-sampled subset and ignore any seed groups that + were added since (or that landed in this run's fresh ``random.sample`` + draw and are no longer in the persisted set). + + Args: + hashes (set[str]): SHA256 hashes of objective text for seed + groups to keep. 
+ + Returns: + set[str]: The hashes that were actually retained (intersection of + ``hashes`` and the current seed_groups' hashes). The caller can + union these across atomic attacks to detect persisted hashes that + no longer exist in the dataset. + """ + retained: set[str] = set() + new_groups: list[SeedAttackGroup] = [] + for sg in self._seed_groups: + if sg.objective is None: + continue + sha = to_sha256(sg.objective.value) + if sha in hashes: + retained.add(sha) + new_groups.append(sg) + self._seed_groups = new_groups + return retained + async def run_async( self, *, @@ -220,6 +352,18 @@ async def run_async( else: execution_seed_groups = self._seed_groups + # Build attribution when this atomic attack is being executed inside + # a Scenario. The same attribution object is stamped on every + # per-task AttackContext; per-task identity is reconstructed from + # the row's own objective_sha256 (no positional state required). + attribution: AttackResultAttribution | None = None + if self._scenario_result_id is not None: + attribution = AttackResultAttribution( + parent_id=self._scenario_result_id, + parent_collection=self.atomic_attack_name, + parent_eval_hash=self.technique_eval_hash, + ) + results = await executor.execute_attack_from_seed_groups_async( attack=technique.attack, seed_groups=execution_seed_groups, @@ -227,6 +371,7 @@ async def run_async( objective_scorer=self._objective_scorer, memory_labels=self._memory_labels, return_partial_on_failure=return_partial_on_failure, + attribution=attribution, **self._attack_execute_params, ) diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index b82f16af9..ef5a374cc 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -8,7 +8,6 @@ AtomicAttack instances sequentially, enabling comprehensive security testing campaigns. 
""" -import asyncio import copy import json import logging @@ -25,10 +24,11 @@ from pyrit.common import REQUIRED_VALUE, Parameter, apply_defaults from pyrit.common.deprecation import print_deprecation_message from pyrit.common.parameter import coerce_value, validate_param_type +from pyrit.common.utils import to_sha256 from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack from pyrit.memory import CentralMemory from pyrit.memory.memory_models import ScenarioResultEntry -from pyrit.models import AttackResult, SeedAttackGroup +from pyrit.models import AttackOutcome, AttackResult, SeedAttackGroup from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult from pyrit.prompt_target import PromptTarget from pyrit.prompt_target.common.target_requirements import TargetRequirements @@ -218,15 +218,10 @@ def __init__( self._memory = CentralMemory.get_memory_instance() self._atomic_attacks: list[AtomicAttack] = [] self._scenario_result_id: Optional[str] = str(scenario_result_id) if scenario_result_id else None - self._result_lock = asyncio.Lock() # Store prepared strategies for use in _get_atomic_attacks_async self._scenario_strategies: list[ScenarioStrategy] = [] - # Store original objectives for each atomic attack (before any mutations) - # Key: atomic_attack_name, Value: tuple of original objectives - self._original_objectives_map: dict[str, tuple[str, ...]] = {} - # Maps atomic_attack_name → display_group for user-facing aggregation self._display_group_map: dict[str, str] = {} @@ -699,11 +694,6 @@ async def initialize_async( seed_groups = self._dataset_config.get_all_seed_attack_groups() self._atomic_attacks.insert(0, self._build_baseline_atomic_attack(seed_groups=seed_groups)) - # Store original objectives for each atomic attack (before any mutations during execution) - self._original_objectives_map = { - atomic_attack.atomic_attack_name: tuple(atomic_attack.objectives) for atomic_attack in self._atomic_attacks - } - # Snapshot 
params onto the identifier before the resume branch so the identifier # is fully populated regardless of which branch we take. Deep-copy avoids sharing # mutable state with self.params. @@ -724,6 +714,7 @@ async def initialize_async( ) self._validate_stored_scenario(stored_result=existing_results[0]) + self._apply_persisted_objectives(stored_result=existing_results[0]) return # Valid resume - skip creating new scenario result # Build display group mapping from atomic attacks @@ -742,12 +733,84 @@ async def initialize_async( attack_results=attack_results, scenario_run_state="CREATED", display_group_map=self._display_group_map, + metadata=self._build_initial_scenario_metadata(), ) self._memory.add_scenario_results_to_memory(scenario_results=[result]) self._scenario_result_id = str(result.id) logger.info(f"Created new scenario result with ID: {self._scenario_result_id}") + def _build_initial_scenario_metadata(self) -> dict[str, Any]: + """ + Build the metadata dict persisted with a freshly-created ``ScenarioResult``. + + When ``max_dataset_size`` is in effect, the dataset config draws an + unseeded ``random.sample`` and the chosen subset would silently change + on the next run (e.g. a resume). To make resume reliable, snapshot the + chosen objective hashes here so the next ``_setup_scenario_async`` can + replay them via ``keep_seed_groups_with_hashes``. + + When ``max_dataset_size`` is not set, the sample equals the dataset and + nothing needs pinning; the dict is empty. + + Returns: + dict[str, Any]: Metadata payload for the new ScenarioResult. 
+ """ + metadata: dict[str, Any] = {} + if getattr(self._dataset_config, "max_dataset_size", None) is None: + return metadata + hashes: list[str] = [] + seen: set[str] = set() + for aa in self._atomic_attacks: + for sg in aa.seed_groups: + if sg.objective is None: + continue + sha = to_sha256(sg.objective.value) + if sha not in seen: + seen.add(sha) + hashes.append(sha) + metadata["objective_hashes"] = hashes + return metadata + + def _apply_persisted_objectives(self, *, stored_result: ScenarioResult) -> None: + """ + On resume, replay the originally-sampled objective subset. + + When the first run used ``max_dataset_size``, the chosen subset was + recorded in ``ScenarioResult.metadata["objective_hashes"]``. + Restrict each atomic attack's freshly-resolved seed_groups to that set + so a fresh ``random.sample`` draw on resume can't silently shift which + objectives the scenario operates on. If any persisted hash is no longer + present in the dataset, refuse to resume — running a smaller subset + than the user committed to would silently produce different results. + + Args: + stored_result (ScenarioResult): The scenario result loaded from memory. + + Raises: + ValueError: If any persisted objective hash is missing from the + currently-resolved dataset. + """ + metadata = stored_result.metadata or {} + persisted = metadata.get("objective_hashes") + if not persisted: + return + + persisted_hashes: set[str] = set(persisted) + retained: set[str] = set() + for aa in self._atomic_attacks: + retained |= aa.keep_seed_groups_with_hashes(hashes=persisted_hashes) + + missing = persisted_hashes - retained + if missing: + sample = sorted(missing)[:3] + raise ValueError( + f"Scenario result id '{self._scenario_result_id}' cannot resume: " + f"{len(missing)} persisted objective hash(es) are no longer present in the dataset " + f"(missing examples: {', '.join(h[:12] + '...' for h in sample)}). 
" + f"Either restore the missing objectives or drop scenario_result_id to start a new scenario." + ) + def _build_baseline_atomic_attack(self, *, seed_groups: list[SeedAttackGroup]) -> AtomicAttack: """ Build the baseline AtomicAttack from pre-resolved seed groups. @@ -847,45 +910,69 @@ def _validate_stored_scenario(self, *, stored_result: ScenarioResult) -> None: f"(ID: {self._scenario_result_id}, state: {stored_result.scenario_run_state})" ) - def _get_completed_objectives_for_attack(self, *, atomic_attack_name: str) -> set[str]: + def _get_completed_objective_hashes_for_attack(self, *, atomic_attack: AtomicAttack) -> set[str]: """ - Get the set of objectives that have already been completed for a specific atomic attack. + Return the set of ``objective_sha256`` values already completed (non-error) + for a specific atomic attack inside this scenario. + + Queries ``AttackResultEntry`` rows directly by ``attribution_parent_id`` — + which is stamped at write-time by the attack persistence path — so + results from an interrupted run are visible even though the + ``ScenarioResult.attack_results`` aggregate may not yet reflect them. + Identity is content-derived (``to_sha256(objective)``), so it stays + stable even if ``get_seed_groups()`` reorders or resamples between runs. + + Rows are matched on ``(parent_collection, parent_eval_hash)`` so that + two ``AtomicAttack`` instances sharing a name but using different + techniques (e.g. base64 vs hex encoders) never cross-pollinate their + completed-hash sets on resume. Rows persisted before + ``parent_eval_hash`` was introduced (or by callers that don't supply + one) match name-only as a backward-compatible fallback. Args: - atomic_attack_name (str): The name of the atomic attack to check. + atomic_attack (AtomicAttack): The live atomic attack whose + ``atomic_attack_name`` and technique identifier scope the query. Returns: - Set[str]: Set of objective strings that have been completed. 
+ set[str]: ``objective_sha256`` hex strings for completed-without-error rows. """ if not self._scenario_result_id: return set() - completed_objectives: set[str] = set() + atomic_attack_name = atomic_attack.atomic_attack_name + expected_eval_hash = atomic_attack.technique_eval_hash + completed_hashes: set[str] = set() try: - # Retrieve the scenario result from memory - scenario_results = self._memory.get_scenario_results(scenario_result_ids=[self._scenario_result_id]) - - if scenario_results: - scenario_result = scenario_results[0] - # Get completed objectives for this atomic attack name - if atomic_attack_name in scenario_result.attack_results: - completed_objectives = { - result.objective for result in scenario_result.attack_results[atomic_attack_name] - } + rows = self._memory.get_attack_results(scenario_result_id=self._scenario_result_id) + for row in rows: + if row.outcome == AttackOutcome.ERROR: + continue + if row.attribution_data is None: + continue + if row.attribution_data.get("parent_collection") != atomic_attack_name: + continue + row_eval_hash = row.attribution_data.get("parent_eval_hash") + if row_eval_hash is not None and row_eval_hash != expected_eval_hash: + continue + if row.objective: + completed_hashes.add(to_sha256(row.objective)) except Exception as e: logger.warning( - f"Failed to retrieve completed objectives for atomic attack '{atomic_attack_name}': {str(e)}" + f"Failed to retrieve completed objective hashes for atomic attack '{atomic_attack_name}': {str(e)}" ) - return completed_objectives + return completed_hashes async def _get_remaining_atomic_attacks_async(self) -> list[AtomicAttack]: """ Get the list of atomic attacks that still have objectives to complete. - This method filters out atomic attacks where all objectives have been completed, - and updates the objectives list for atomic attacks that are partially complete. 
+ Uses ``objective_sha256`` as the stable identity for resume: each + atomic attack enforces uniqueness of objective hashes at construction + time, and the executor stamps ``attribution_parent_id`` + + ``attribution_data["parent_collection"]`` on the row so a content-hash + join is sufficient. Returns: List[AtomicAttack]: List of atomic attacks with uncompleted objectives. @@ -897,64 +984,27 @@ async def _get_remaining_atomic_attacks_async(self) -> list[AtomicAttack]: remaining_attacks: list[AtomicAttack] = [] for atomic_attack in self._atomic_attacks: - # Get completed objectives for this atomic attack name - completed_objectives = self._get_completed_objectives_for_attack( - atomic_attack_name=atomic_attack.atomic_attack_name - ) - - # Get ORIGINAL objectives (before any mutations) from stored map - original_objectives = self._original_objectives_map.get(atomic_attack.atomic_attack_name, ()) + completed_hashes = self._get_completed_objective_hashes_for_attack(atomic_attack=atomic_attack) - # Calculate remaining objectives - remaining_objectives = [obj for obj in original_objectives if obj not in completed_objectives] - - if remaining_objectives: - # If there are remaining objectives, update the atomic attack - if len(remaining_objectives) < len(original_objectives): + if completed_hashes: + original_count = len(atomic_attack.seed_groups) + atomic_attack.drop_seed_groups_with_hashes(hashes=completed_hashes) + remaining_count = len(atomic_attack.seed_groups) + if remaining_count == 0: + logger.info( + f"Atomic attack '{atomic_attack.atomic_attack_name}' has all objectives completed, skipping" + ) + continue + if remaining_count < original_count: logger.info( f"Atomic attack '{atomic_attack.atomic_attack_name}' has " - f"{len(remaining_objectives)}/{len(original_objectives)} objectives remaining" + f"{remaining_count}/{original_count} objectives remaining" ) - # Update the objectives for this atomic attack to only include remaining ones - 
atomic_attack.filter_seed_groups_by_objectives(remaining_objectives=remaining_objectives) - remaining_attacks.append(atomic_attack) - else: - logger.info( - f"Atomic attack '{atomic_attack.atomic_attack_name}' has all objectives completed, skipping" - ) + remaining_attacks.append(atomic_attack) return remaining_attacks - async def _update_scenario_result_async( - self, *, atomic_attack_name: str, attack_results: list[AttackResult] - ) -> None: - """ - Update the scenario result in memory with new attack results (thread-safe). - - This method is thread-safe and can be called from parallel executions. - - Args: - atomic_attack_name (str): The name of the atomic attack. - attack_results (List[AttackResult]): The list of new attack results to add. - """ - if not self._scenario_result_id: - logger.warning("Cannot update scenario result: no scenario result ID available") - return - - async with self._result_lock: - success = self._memory.add_attack_results_to_scenario( - scenario_result_id=self._scenario_result_id, - atomic_attack_name=atomic_attack_name, - attack_results=attack_results, - ) - - if not success: - logger.error( - f"Failed to update scenario result with {len(attack_results)} results " - f"for atomic attack '{atomic_attack_name}'" - ) - async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: """ Build atomic attacks from the cross-product of selected techniques and datasets. @@ -1191,6 +1241,12 @@ async def _execute_scenario_async(self) -> ScenarioResult: ), start=completed_count + 1, ): + # Stamp the scenario id onto the atomic attack so each persisted + # AttackResult carries the attribution_parent_id linkage. This + # is what enables mid-run interruption recovery (results are + # visible without the post-atomic-attack bulk manifest write). 
+ atomic_attack.set_scenario_result_id(scenario_result_id) + logger.info( f"Executing atomic attack {i}/{len(self._atomic_attacks)} " f"('{atomic_attack.atomic_attack_name}') in scenario '{self._name}'" @@ -1202,12 +1258,8 @@ async def _execute_scenario_async(self) -> ScenarioResult: return_partial_on_failure=True, ) - # Always save completed results, even if some objectives didn't complete - if atomic_results.completed_results: - await self._update_scenario_result_async( - atomic_attack_name=atomic_attack.atomic_attack_name, - attack_results=atomic_results.completed_results, - ) + # Per-result scenario linkage is now stamped by the attack + # event handler at write time; no post-atomic bulk update. # Check if there were any incomplete objectives if atomic_results.has_incomplete: @@ -1224,19 +1276,12 @@ async def _execute_scenario_async(self) -> ScenarioResult: for obj, exc in atomic_results.incomplete_objectives: logger.error(f" Incomplete objective '{obj[:50]}...': {str(exc)}") - # Collect error attack result IDs from the exceptions - error_ids = [] - for _, exc in atomic_results.incomplete_objectives: - error_id = getattr(exc, "error_attack_result_id", None) - if error_id: - error_ids.append(error_id) - - # Link error attack results to the scenario result - if error_ids: - self._memory.update_scenario_error_attacks( - scenario_result_id=scenario_result_id, - error_attack_result_ids=error_ids, - ) + # Error AttackResults are linked to this scenario via the + # attribution_parent_id foreign key on AttackResultEntry + # (stamped by the attack event handler when an + # AttackResultAttribution is on the context). The + # previous per-scenario error_id manifest is no longer + # needed. 
# Mark scenario as failed error_msg = ( diff --git a/tests/unit/backend/test_scenario_run_service.py b/tests/unit/backend/test_scenario_run_service.py index 29d2855cd..65169c90f 100644 --- a/tests/unit/backend/test_scenario_run_service.py +++ b/tests/unit/backend/test_scenario_run_service.py @@ -19,6 +19,7 @@ _DEFAULT_MAX_CONCURRENT_RUNS, ScenarioRunService, ) +from pyrit.models import AttackOutcome _REGISTRY_PATCH_BASE = "pyrit.registry" _MEMORY_PATCH = "pyrit.memory.CentralMemory.get_memory_instance" @@ -74,7 +75,6 @@ def _make_db_scenario_result( sr.display_group_map = {} sr.error_message = None sr.error_type = None - sr.error_attack_result_ids = [] return sr @@ -83,6 +83,9 @@ def mock_memory(): """Patch CentralMemory.get_memory_instance to return a mock.""" mock = MagicMock() mock.get_scenario_results.return_value = [] + # Default: no error AttackResults linked to any scenario. Tests that exercise + # the error fallback path explicitly set get_attack_results.return_value. + mock.get_attack_results.return_value = [] with patch(_MEMORY_PATCH, return_value=mock): yield mock @@ -298,9 +301,14 @@ def test_get_run_returns_existing_run(self, mock_memory) -> None: assert fetched.status == ScenarioRunStatus.IN_PROGRESS def test_get_run_falls_back_to_persisted_error(self, mock_memory) -> None: - """Test that get_run extracts error from persisted error AttackResult when no active task.""" + """Test that get_run extracts error from persisted error AttackResult when no active task. + + After the foreign-key-based scenario linkage refactor, error + AttackResults are located via + ``get_attack_results(scenario_result_id=..., outcome=ERROR)`` rather + than via a per-scenario error_attack_result_ids manifest. 
+ """ db_result = _make_db_scenario_result(result_id="sr-fail", run_state="FAILED") - db_result.error_attack_result_ids = ["err-ar-1"] # Mock the error AttackResult lookup error_ar = MagicMock() @@ -315,7 +323,10 @@ def test_get_run_falls_back_to_persisted_error(self, mock_memory) -> None: assert fetched is not None assert fetched.error == "Connection refused" assert fetched.error_type == "ConnectionError" - mock_memory.get_attack_results.assert_called_once_with(attack_result_ids=["err-ar-1"]) + mock_memory.get_attack_results.assert_called_once_with( + scenario_result_id="sr-fail", + outcome=AttackOutcome.ERROR, + ) class TestScenarioRunServiceListRuns: diff --git a/tests/unit/executor/attack/core/test_attack_executor.py b/tests/unit/executor/attack/core/test_attack_executor.py index f949f69bf..703473987 100644 --- a/tests/unit/executor/attack/core/test_attack_executor.py +++ b/tests/unit/executor/attack/core/test_attack_executor.py @@ -336,6 +336,99 @@ async def test_validates_explicit_empty_field_overrides_for_seed_groups(self): ) +@pytest.mark.usefixtures("patch_central_database") +class TestAttributionPropagation: + """Tests for AttackResultAttribution propagation through the AttackExecutor. + + The executor stamps the same ``AttackResultAttribution`` on every per-task + context. Per-task identity is reconstructed from each row's own + ``objective_sha256`` at hydration/resume time, so no positional state is + threaded through the executor. 
+ """ + + async def test_attribution_stamps_every_per_task_context(self): + from pyrit.executor.attack.core.attack_result_attribution import AttackResultAttribution + + attack = create_mock_attack() + seen_parent_ids: list[str] = [] + seen_collections: list[str] = [] + + async def capture(context): + attr = context._attribution + assert attr is not None + seen_parent_ids.append(attr.parent_id) + seen_collections.append(attr.parent_collection) + return create_attack_result(context.params.objective) + + attack.execute_with_context_async = AsyncMock(side_effect=capture) + + seed_groups = [create_seed_group(f"obj-{i}") for i in range(4)] + attribution = AttackResultAttribution(parent_id="sid", parent_collection="atomic") + + executor = AttackExecutor(max_concurrency=1) + result = await executor.execute_attack_from_seed_groups_async( + attack=attack, + seed_groups=seed_groups, + attribution=attribution, + ) + + assert seen_parent_ids == ["sid"] * 4 + assert seen_collections == ["atomic"] * 4 + assert len(result.completed_results) == 4 + + async def test_attribution_parallel_safe_with_high_concurrency(self): + """At max_concurrency > 1, every task still sees the same attribution + regardless of completion order — there is no per-task positional state. + """ + from pyrit.executor.attack.core.attack_result_attribution import AttackResultAttribution + + attack = create_mock_attack() + seen: dict[str, AttackResultAttribution] = {} + + async def out_of_order(context): + attr = context._attribution + assert attr is not None + # Reverse-delay tasks so completion order is inverse of input order. 
+ i = int(context.params.objective.split("-")[1]) + await asyncio.sleep(0.005 * (10 - i)) + seen[context.params.objective] = attr + return create_attack_result(context.params.objective) + + attack.execute_with_context_async = AsyncMock(side_effect=out_of_order) + + seed_groups = [create_seed_group(f"obj-{i}") for i in range(6)] + attribution = AttackResultAttribution(parent_id="sid", parent_collection="atomic") + + executor = AttackExecutor(max_concurrency=6) + await executor.execute_attack_from_seed_groups_async( + attack=attack, + seed_groups=seed_groups, + attribution=attribution, + ) + + for i in range(6): + attr = seen[f"obj-{i}"] + assert attr.parent_id == "sid" + assert attr.parent_collection == "atomic" + + async def test_no_attribution_leaves_context_attribution_none(self): + attack = create_mock_attack() + + async def capture(context): + attr = context._attribution + assert attr is None + return create_attack_result(context.params.objective) + + attack.execute_with_context_async = AsyncMock(side_effect=capture) + + seed_groups = [create_seed_group("obj-0"), create_seed_group("obj-1")] + executor = AttackExecutor(max_concurrency=2) + await executor.execute_attack_from_seed_groups_async( + attack=attack, + seed_groups=seed_groups, + ) + + @pytest.mark.usefixtures("patch_central_database") class TestPartialFailureHandling: """Tests for partial failure handling.""" diff --git a/tests/unit/executor/attack/core/test_attack_strategy.py b/tests/unit/executor/attack/core/test_attack_strategy.py index 33370f0c9..e932da3da 100644 --- a/tests/unit/executor/attack/core/test_attack_strategy.py +++ b/tests/unit/executor/attack/core/test_attack_strategy.py @@ -610,6 +610,88 @@ async def test_on_event_handles_other_events(self, event_handler, sample_attack_ f"Attack is in '{StrategyEvent.ON_PRE_VALIDATE.value}' stage for {event_handler.__class__.__name__}" ) + async def test_on_post_execute_stamps_scenario_attribution_when_present( + self, sample_attack_context, 
sample_attack_result, mock_memory + ): + """When the context carries an AttackResultAttribution, the persisted + AttackResult must have attribution_parent_id + attribution_data populated.""" + from pyrit.executor.attack.core.attack_result_attribution import AttackResultAttribution + + with patch("pyrit.memory.central_memory.CentralMemory.get_memory_instance", return_value=mock_memory): + handler = _DefaultAttackStrategyEventHandler() + sample_attack_context.start_time = 100.0 + sample_attack_context._attribution = AttackResultAttribution( + parent_id="scenario-1", + parent_collection="atomic_a", + ) + + event_data = StrategyEventData( + event=StrategyEvent.ON_POST_EXECUTE, + strategy_name="TestStrategy", + strategy_id="test-id", + context=sample_attack_context, + result=sample_attack_result, + ) + await handler.on_event(event_data) + + assert sample_attack_result.attribution_parent_id == "scenario-1" + assert sample_attack_result.attribution_data == { + "parent_collection": "atomic_a", + } + + async def test_on_post_execute_no_attribution_leaves_fields_none( + self, sample_attack_context, sample_attack_result, mock_memory + ): + """Outside a Scenario, _attribution is None and the attribution fields + on the persisted AttackResult must stay None.""" + with patch("pyrit.memory.central_memory.CentralMemory.get_memory_instance", return_value=mock_memory): + handler = _DefaultAttackStrategyEventHandler() + sample_attack_context.start_time = 100.0 + # _attribution defaults to None — no scenario stamping should happen. 
+ + event_data = StrategyEventData( + event=StrategyEvent.ON_POST_EXECUTE, + strategy_name="TestStrategy", + strategy_id="test-id", + context=sample_attack_context, + result=sample_attack_result, + ) + await handler.on_event(event_data) + + assert sample_attack_result.attribution_parent_id is None + assert sample_attack_result.attribution_data is None + + async def test_on_error_stamps_scenario_attribution_when_present(self, sample_attack_context, mock_memory): + """Error AttackResults must also carry the attribution foreign key so + error lookups via get_attack_results(scenario_result_id=..., outcome=ERROR) work.""" + from pyrit.executor.attack.core.attack_result_attribution import AttackResultAttribution + + with patch("pyrit.memory.central_memory.CentralMemory.get_memory_instance", return_value=mock_memory): + handler = _DefaultAttackStrategyEventHandler() + sample_attack_context.start_time = 100.0 + sample_attack_context._attribution = AttackResultAttribution( + parent_id="scenario-err", + parent_collection="atomic_err", + ) + + event_data = StrategyEventData( + event=StrategyEvent.ON_ERROR, + strategy_name="TestStrategy", + strategy_id="test-id", + context=sample_attack_context, + error=RuntimeError("boom"), + ) + await handler.on_event(event_data) + + # The error AttackResult was persisted; inspect what was sent to memory. 
+ call = mock_memory.add_attack_results_to_memory.call_args + persisted = call.kwargs["attack_results"][0] + assert persisted.outcome == AttackOutcome.ERROR + assert persisted.attribution_parent_id == "scenario-err" + assert persisted.attribution_data == { + "parent_collection": "atomic_err", + } + @pytest.mark.usefixtures("patch_central_database") class TestAttackStrategyIntegration: diff --git a/tests/unit/memory/memory_interface/test_interface_scenario_results.py b/tests/unit/memory/memory_interface/test_interface_scenario_results.py index 97a240e37..6ce891943 100644 --- a/tests/unit/memory/memory_interface/test_interface_scenario_results.py +++ b/tests/unit/memory/memory_interface/test_interface_scenario_results.py @@ -160,21 +160,24 @@ def test_empty_ids_returns_empty(sqlite_instance: MemoryInterface): def test_attack_results_populated_correctly(sqlite_instance: MemoryInterface): """Test that retrieving scenario results populates attack_results correctly.""" - # Create and add attack results - attack_result1 = create_attack_result("conv_1", "Objective 1", AttackOutcome.SUCCESS) - attack_result2 = create_attack_result("conv_2", "Objective 2", AttackOutcome.FAILURE) - attack_result3 = create_attack_result("conv_3", "Objective 3", AttackOutcome.SUCCESS) - sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result1, attack_result2, attack_result3]) + scenario_result = create_scenario_result(name="Multi-Attack Scenario", attack_results={}) + sqlite_instance.add_scenario_results_to_memory(scenario_results=[scenario_result]) - # Create scenario result with multiple attacks - scenario_result = create_scenario_result( - name="Multi-Attack Scenario", - attack_results={ - "PromptInjection": [attack_result1, attack_result2], - "Crescendo": [attack_result3], - }, + sid = scenario_result.id + attack_result1 = _make_attack_result_for_scenario( + scenario_result_id=sid, atomic_attack_name="PromptInjection", objective_index=0, conversation_id="conv_1" ) - 
sqlite_instance.add_scenario_results_to_memory(scenario_results=[scenario_result]) + attack_result2 = _make_attack_result_for_scenario( + scenario_result_id=sid, + atomic_attack_name="PromptInjection", + objective_index=1, + conversation_id="conv_2", + outcome=AttackOutcome.FAILURE, + ) + attack_result3 = _make_attack_result_for_scenario( + scenario_result_id=sid, atomic_attack_name="Crescendo", objective_index=0, conversation_id="conv_3" + ) + sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result1, attack_result2, attack_result3]) # Retrieve and verify attack_results are populated results = sqlite_instance.get_scenario_results() @@ -198,40 +201,42 @@ def test_attack_results_populated_correctly(sqlite_instance: MemoryInterface): def test_attack_order_preserved(sqlite_instance: MemoryInterface): - """Test that attack results maintain their order within each attack name.""" - # Create and add attack results - attack_results = [create_attack_result(f"conv_{i}", f"Objective {i}") for i in range(5)] - sqlite_instance.add_attack_results_to_memory(attack_results=attack_results) - - # Create scenario result with ordered attacks - scenario_result = create_scenario_result( - name="Ordered Scenario", - attack_results={ - "Attack1": attack_results, - }, - ) + """Hydration sorts each atomic attack's results by ``timestamp`` (which + monotonically tracks insertion order under normal sequential execution).""" + scenario_result = create_scenario_result(name="Ordered Scenario", attack_results={}) sqlite_instance.add_scenario_results_to_memory(scenario_results=[scenario_result]) - # Retrieve and verify order is preserved + sid = scenario_result.id + # Insert in a specific order; hydration must surface them in the same order. 
+ attack_results = [ + _make_attack_result_for_scenario( + scenario_result_id=sid, atomic_attack_name="Attack1", objective_index=i, conversation_id=f"conv_{i}" + ) + for i in range(5) + ] + for ar in attack_results: + sqlite_instance.add_attack_results_to_memory(attack_results=[ar]) + results = sqlite_instance.get_scenario_results() retrieved_attacks = results[0].attack_results["Attack1"] - # Verify the conversation IDs are in the same order retrieved_conv_ids = [ar.conversation_id for ar in retrieved_attacks] - original_conv_ids = [ar.conversation_id for ar in attack_results] - assert retrieved_conv_ids == original_conv_ids + assert retrieved_conv_ids == [f"conv_{i}" for i in range(5)] -def test_stores_conversation_ids_only(sqlite_instance: MemoryInterface, sample_attack_results): - """Test that scenario results store only conversation IDs, not full AttackResult objects.""" - # Create and add scenario result - scenario_result = create_scenario_result( - name="Test Scenario", - attack_results={"Attack1": [sample_attack_results[0]]}, - ) +def test_stores_conversation_ids_only(sqlite_instance: MemoryInterface): + """Test that scenario results expose AttackResult objects with conversation IDs after hydration.""" + scenario_result = create_scenario_result(name="Test Scenario", attack_results={}) sqlite_instance.add_scenario_results_to_memory(scenario_results=[scenario_result]) - # Retrieve the scenario result to verify structure + ar = _make_attack_result_for_scenario( + scenario_result_id=scenario_result.id, + atomic_attack_name="Attack1", + objective_index=0, + conversation_id="conv_1", + ) + sqlite_instance.add_attack_results_to_memory(attack_results=[ar]) + results = sqlite_instance.get_scenario_results(scenario_result_ids=[str(scenario_result.id)]) assert len(results) == 1 @@ -299,24 +304,29 @@ def test_preserves_metadata(sqlite_instance: MemoryInterface): def test_multiple_scenarios_with_attacks(sqlite_instance: MemoryInterface): """Test retrieving multiple 
scenarios with their attack results populated.""" - # Create attack results for multiple scenarios - attack_results_scenario1 = [create_attack_result(f"conv_s1_{i}", f"S1 Objective {i}") for i in range(5)] - attack_results_scenario2 = [create_attack_result(f"conv_s2_{i}", f"S2 Objective {i}") for i in range(3)] + scenario1 = create_scenario_result(name="Scenario 1", attack_results={}) + scenario2 = create_scenario_result(name="Scenario 2", attack_results={}) + sqlite_instance.add_scenario_results_to_memory(scenario_results=[scenario1, scenario2]) - all_attack_results = attack_results_scenario1 + attack_results_scenario2 + all_attack_results = [ + _make_attack_result_for_scenario( + scenario_result_id=scenario1.id, + atomic_attack_name="Attack1", + objective_index=i, + conversation_id=f"conv_s1_{i}", + ) + for i in range(5) + ] + [ + _make_attack_result_for_scenario( + scenario_result_id=scenario2.id, + atomic_attack_name="Attack2", + objective_index=i, + conversation_id=f"conv_s2_{i}", + ) + for i in range(3) + ] sqlite_instance.add_attack_results_to_memory(attack_results=all_attack_results) - # Create multiple scenario results - scenario1 = create_scenario_result( - name="Scenario 1", - attack_results={"Attack1": attack_results_scenario1}, - ) - scenario2 = create_scenario_result( - name="Scenario 2", - attack_results={"Attack2": attack_results_scenario2}, - ) - sqlite_instance.add_scenario_results_to_memory(scenario_results=[scenario1, scenario2]) - # Retrieve all scenarios results = sqlite_instance.get_scenario_results() assert len(results) == 2 @@ -648,58 +658,147 @@ def test_combined_filters(sqlite_instance: MemoryInterface): assert "gpt-4" in results[0].objective_target_identifier.params["model_name"] -def test_update_scenario_error_attacks_success(sqlite_instance: MemoryInterface, sample_attack_results): - """Test successfully linking error attack result IDs to a scenario result.""" - scenario_result = create_scenario_result( - name="Error Scenario", - 
attack_results={"Attack1": [sample_attack_results[0]]}, - ) - sqlite_instance.add_scenario_results_to_memory(scenario_results=[scenario_result]) +# ============================================================================= +# Scenario linkage (attribution_parent_id foreign key + attribution_data on +# AttackResultEntry) hydration tests +# ============================================================================= - error_ids = ["error-ar-1", "error-ar-2"] - sqlite_instance.update_scenario_error_attacks( - scenario_result_id=str(scenario_result.id), - error_attack_result_ids=error_ids, - ) - # Verify the error IDs were persisted - results = sqlite_instance.get_scenario_results(scenario_result_ids=[str(scenario_result.id)]) - assert len(results) == 1 - assert results[0].error_attack_result_ids == error_ids +def _make_attack_result_for_scenario( + *, + scenario_result_id, + atomic_attack_name, + objective_index, + conversation_id=None, + outcome=AttackOutcome.SUCCESS, +): + """Build an AttackResult pre-stamped with scenario linkage (mirrors what + the event handler does when an AttackResultAttribution is on the context).""" + return AttackResult( + conversation_id=conversation_id or f"conv-{atomic_attack_name}-{objective_index}", + objective=f"objective-{atomic_attack_name}-{objective_index}", + outcome=outcome, + executed_turns=1, + attribution_parent_id=str(scenario_result_id), + attribution_data={"parent_collection": atomic_attack_name}, + ) -def test_update_scenario_error_attacks_appends_to_existing(sqlite_instance: MemoryInterface, sample_attack_results): - """Test that updating error attacks appends to existing IDs without duplicates.""" +def test_get_scenario_results_loads_attack_results_via_foreign_key(sqlite_instance: MemoryInterface): + """When AttackResultEntry rows carry the attribution_parent_id foreign key, + hydration picks them up directly — without needing the legacy + attack_results_json manifest. 
This is the path that makes mid-AtomicAttack + interruption-recovery work.""" scenario_result = create_scenario_result( - name="Error Scenario", - attack_results={"Attack1": [sample_attack_results[0]]}, + name="ForeignKey-only Scenario", + attack_results={}, # manifest intentionally empty ) sqlite_instance.add_scenario_results_to_memory(scenario_results=[scenario_result]) - # First update - sqlite_instance.update_scenario_error_attacks( - scenario_result_id=str(scenario_result.id), - error_attack_result_ids=["error-ar-1"], + sid = scenario_result.id + ar1 = _make_attack_result_for_scenario(scenario_result_id=sid, atomic_attack_name="a", objective_index=0) + ar2 = _make_attack_result_for_scenario(scenario_result_id=sid, atomic_attack_name="a", objective_index=1) + ar3 = _make_attack_result_for_scenario(scenario_result_id=sid, atomic_attack_name="b", objective_index=0) + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3]) + + [result] = sqlite_instance.get_scenario_results(scenario_result_ids=[str(sid)]) + assert set(result.attack_results.keys()) == {"a", "b"} + assert [r.conversation_id for r in result.attack_results["a"]] == [ + "conv-a-0", + "conv-a-1", + ] + assert [r.conversation_id for r in result.attack_results["b"]] == ["conv-b-0"] + + +def test_get_attack_results_filters_by_scenario_result_id(sqlite_instance: MemoryInterface): + """get_attack_results gains a scenario_result_id filter — replaces the + removed error_attack_result_ids_json lookup path.""" + scenario_result = create_scenario_result(name="Filter Scenario") + sqlite_instance.add_scenario_results_to_memory(scenario_results=[scenario_result]) + sid = scenario_result.id + + ok = _make_attack_result_for_scenario(scenario_result_id=sid, atomic_attack_name="a", objective_index=0) + err = _make_attack_result_for_scenario( + scenario_result_id=sid, + atomic_attack_name="a", + objective_index=1, + outcome=AttackOutcome.ERROR, ) + # An unrelated AttackResult NOT linked to this 
scenario should be excluded. + unrelated = create_attack_result("unrelated-conv", "unrelated-obj") + sqlite_instance.add_attack_results_to_memory(attack_results=[ok, err, unrelated]) + + all_for_scenario = sqlite_instance.get_attack_results(scenario_result_id=str(sid)) + assert {r.conversation_id for r in all_for_scenario} == {ok.conversation_id, err.conversation_id} - # Second update with overlap and new ID - sqlite_instance.update_scenario_error_attacks( - scenario_result_id=str(scenario_result.id), - error_attack_result_ids=["error-ar-1", "error-ar-2"], + only_errors = sqlite_instance.get_attack_results( + scenario_result_id=str(sid), + outcome=AttackOutcome.ERROR.value, ) + assert [r.conversation_id for r in only_errors] == [err.conversation_id] - results = sqlite_instance.get_scenario_results(scenario_result_ids=[str(scenario_result.id)]) - # Should be deduplicated: ["error-ar-1", "error-ar-2"] - assert results[0].error_attack_result_ids == ["error-ar-1", "error-ar-2"] +def test_delete_scenario_sets_attack_result_foreign_key_to_null(sqlite_instance: MemoryInterface): + """ON DELETE SET NULL: deleting the parent ScenarioResultEntry nulls the + attribution_parent_id foreign key on its linked AttackResultEntries but + the AttackResultEntries survive (attribution_data is retained as + historical provenance). -def test_update_scenario_error_attacks_not_found(sqlite_instance: MemoryInterface): - """Test that updating a nonexistent scenario result raises ValueError.""" - with pytest.raises(ValueError, match="not found in memory"): - sqlite_instance.update_scenario_error_attacks( - scenario_result_id="nonexistent-id", - error_attack_result_ids=["error-ar-1"], - ) + Note: SQLite does not enforce foreign keys by default; this test enables + them on the session for the duration of the delete to verify the + ON DELETE SET NULL clause works. Production deployments using SQL Server + enforce foreign keys by default. 
+ """ + from contextlib import closing + + from sqlalchemy import text as _sql_text + + from pyrit.memory.memory_models import AttackResultEntry, ScenarioResultEntry + + scenario_result = create_scenario_result(name="To Be Deleted") + sqlite_instance.add_scenario_results_to_memory(scenario_results=[scenario_result]) + sid = scenario_result.id + + ar = _make_attack_result_for_scenario(scenario_result_id=sid, atomic_attack_name="a", objective_index=0) + sqlite_instance.add_attack_results_to_memory(attack_results=[ar]) + + # Enable foreign keys for the delete and verify the SET NULL clause fires. + with closing(sqlite_instance.get_session()) as session: + session.execute(_sql_text("PRAGMA foreign_keys = ON")) + session.query(ScenarioResultEntry).filter_by(id=sid).delete() + session.commit() + + # The AttackResult survives, but its foreign key is now NULL. + # attribution_data is retained as historical provenance. + with closing(sqlite_instance.get_session()) as session: + entry = session.query(AttackResultEntry).filter_by(conversation_id=ar.conversation_id).one() + assert entry.attribution_parent_id is None + assert entry.attribution_data == {"parent_collection": "a"} + + +def test_update_scenario_run_state_targeted_update_preserves_manifest(sqlite_instance: MemoryInterface): + """update_scenario_run_state must be a targeted UPDATE — it must not + re-serialize the whole row and clobber the manifest column during the + deprecation window.""" + scenario_result = create_scenario_result( + name="Targeted Update", + attack_results={"a": []}, # baseline manifest + ) + sqlite_instance.add_scenario_results_to_memory(scenario_results=[scenario_result]) + sid = str(scenario_result.id) + + sqlite_instance.update_scenario_run_state( + scenario_result_id=sid, + scenario_run_state="FAILED", + error_message="boom", + error_type="RuntimeError", + ) + + # State and error fields updated. 
+ [hydrated] = sqlite_instance.get_scenario_results(scenario_result_ids=[sid]) + assert hydrated.scenario_run_state == "FAILED" + assert hydrated.error_message == "boom" + assert hydrated.error_type == "RuntimeError" def test_get_scenario_results_by_target_identifier_filter_hash(sqlite_instance: MemoryInterface): diff --git a/tests/unit/memory/test_memory_models.py b/tests/unit/memory/test_memory_models.py index b382bf864..f143a69e9 100644 --- a/tests/unit/memory/test_memory_models.py +++ b/tests/unit/memory/test_memory_models.py @@ -474,17 +474,3 @@ def test_init_with_empty_attack_results(self): entry = ScenarioResultEntry(entry=sr) conv_ids = entry.get_conversation_ids_by_attack_name() assert conv_ids == {} - - def test_roundtrip_error_attack_result_ids(self): - sr = self._make_scenario_result(error_attack_result_ids=["err-1", "err-2"]) - entry = ScenarioResultEntry(entry=sr) - assert entry.error_attack_result_ids_json is not None - recovered = entry.get_scenario_result() - assert recovered.error_attack_result_ids == ["err-1", "err-2"] - - def test_roundtrip_error_attack_result_ids_none(self): - sr = self._make_scenario_result() - entry = ScenarioResultEntry(entry=sr) - assert entry.error_attack_result_ids_json is None - recovered = entry.get_scenario_result() - assert recovered.error_attack_result_ids == [] diff --git a/tests/unit/memory/test_migration.py b/tests/unit/memory/test_migration.py index 0140ce5b1..386775d9f 100644 --- a/tests/unit/memory/test_migration.py +++ b/tests/unit/memory/test_migration.py @@ -205,6 +205,263 @@ def test_migration_downgrade_creates_proper_structure(): engine.dispose() +# ============================================================================= +# Backfill tests for the attribution_parent_id foreign key migration +# ============================================================================= + + +_SCENARIO_LINKAGE_REV = "9c8b7a6d5e4f" +_PREV_REV = "7a1b2c3d4e5f" + + +def _seed_pre_migration_scenario(connection, *, 
scenario_id, manifest_json): + """Insert a ScenarioResultEntry row at the pre-migration revision.""" + connection.execute( + text( + 'INSERT INTO "ScenarioResultEntries" ' + "(id, scenario_name, scenario_description, scenario_version, pyrit_version, " + "objective_target_identifier, scenario_run_state, attack_results_json, " + "number_tries, completion_time, timestamp) " + "VALUES (:id, :name, '', 1, '0.14.0.dev0', '{}', 'COMPLETED', :manifest, 0, '2026-05-18', '2026-05-18')" + ), + {"id": scenario_id, "name": "Backfill Test", "manifest": manifest_json}, + ) + + +def _seed_pre_migration_attack_result(connection, *, attack_id, conversation_id): + """Insert an AttackResultEntry row at the pre-migration revision.""" + connection.execute( + text( + 'INSERT INTO "AttackResultEntries" ' + "(id, conversation_id, objective, attack_identifier, objective_sha256, executed_turns, " + "execution_time_ms, outcome, timestamp) " + "VALUES (:id, :conv, 'obj', '{}', 'sha', 1, 0, 'success', '2026-05-18')" + ), + {"id": attack_id, "conv": conversation_id}, + ) + + +def _config_for(connection): + pyrit_root = Path(__file__).resolve().parent.parent.parent.parent / "pyrit" + script_location = pyrit_root / "memory" / "alembic" + config = Config() + config.set_main_option("script_location", str(script_location)) + config.attributes["connection"] = connection + config.attributes["version_table"] = "pyrit_memory_alembic_version" + return config + + +def test_backfill_links_attack_results_via_conversation_id(): + """Upgrading from the pre-foreign-key revision backfills + attribution_parent_id + attribution_data on AttackResultEntries by + matching conversation_id.""" + import json + + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "backfill-test.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + sid = str(uuid.uuid4()) + ar1_id = str(uuid.uuid4()) + ar2_id = str(uuid.uuid4()) + ar3_id = str(uuid.uuid4()) + + with engine.begin() as connection: + 
config = _config_for(connection) + # Step the schema up to JUST before the linkage migration. + command.upgrade(config, _PREV_REV) + + _seed_pre_migration_attack_result(connection, attack_id=ar1_id, conversation_id="conv-a-0") + _seed_pre_migration_attack_result(connection, attack_id=ar2_id, conversation_id="conv-a-1") + _seed_pre_migration_attack_result(connection, attack_id=ar3_id, conversation_id="conv-b-0") + _seed_pre_migration_scenario( + connection, + scenario_id=sid, + manifest_json=json.dumps({"a": ["conv-a-0", "conv-a-1"], "b": ["conv-b-0"]}), + ) + + command.upgrade(config, _SCENARIO_LINKAGE_REV) + + rows = connection.execute( + text( + "SELECT conversation_id, attribution_parent_id, attribution_data " + 'FROM "AttackResultEntries" ORDER BY conversation_id' + ) + ).fetchall() + + results_by_conv = {r[0]: (r[1], r[2]) for r in rows} + + # All three rows now point at the scenario via the new foreign key. + for conv in ("conv-a-0", "conv-a-1", "conv-b-0"): + assert results_by_conv[conv][0] == sid, f"{conv} should be backfilled" + + # attribution_data carries parent_collection (the atomic attack name). + sd_a0 = json.loads(results_by_conv["conv-a-0"][1]) + sd_a1 = json.loads(results_by_conv["conv-a-1"][1]) + sd_b0 = json.loads(results_by_conv["conv-b-0"][1]) + + assert sd_a0 == {"parent_collection": "a"} + assert sd_a1 == {"parent_collection": "a"} + assert sd_b0 == {"parent_collection": "b"} + finally: + engine.dispose() + + +def test_backfill_is_idempotent_and_does_not_clobber_existing_linkage(): + """The backfill is safe to re-run: rows that already carry an + ``attribution_parent_id`` are not overwritten (the WHERE IS NULL guard). 
We + verify by retargeting a row by hand, adding a second scenario whose manifest + references the same conversation, and re-running the backfill function directly.""" + import json + + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "idempotent.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + sid_old = str(uuid.uuid4()) + sid_manual = str(uuid.uuid4()) + ar_id = str(uuid.uuid4()) + + with engine.begin() as connection: + config = _config_for(connection) + command.upgrade(config, _PREV_REV) + _seed_pre_migration_attack_result(connection, attack_id=ar_id, conversation_id="conv-shared") + _seed_pre_migration_scenario( + connection, scenario_id=sid_old, manifest_json=json.dumps({"a": ["conv-shared"]}) + ) + command.upgrade(config, _SCENARIO_LINKAGE_REV) + + # Manually retarget the row to a DIFFERENT attribution_parent_id — + # simulate code that already linked it post-upgrade. + connection.execute( + text('UPDATE "AttackResultEntries" SET attribution_parent_id = :sid WHERE conversation_id = :conv'), + {"sid": sid_manual, "conv": "conv-shared"}, + ) + + # Downgrade then re-upgrade to re-run the backfill. + command.downgrade(config, _PREV_REV) + + # Downgrading drops the foreign key column and its data, so + # the manual retarget above is lost, while the manifest still + # references conv-shared. The re-upgrade below therefore + # starts the row at NULL and the backfill links it to sid_old + # again. That round trip cannot exercise the IS NULL guard. + # The property under test (re-running the backfill while a + # row already has a non-NULL foreign key does not overwrite + # it) is checked further down by invoking the backfill + # function directly. + command.upgrade(config, _SCENARIO_LINKAGE_REV) + + # First upgrade after downgrade re-links it to sid_old (the + # manifest source). Now manually retarget again. 
+ connection.execute( + text('UPDATE "AttackResultEntries" SET attribution_parent_id = :sid WHERE conversation_id = :conv'), + {"sid": sid_manual, "conv": "conv-shared"}, + ) + + # The backfill must NOT re-stamp this row, because its column + # is already non-NULL. A downgrade+upgrade round trip cannot + # show that (it resets the column to NULL), so exercise the + # IS NULL guard directly instead: add another scenario + # referencing the same conversation_id and re-run the + # backfill function only. + connection.execute( + text( + 'INSERT INTO "ScenarioResultEntries" ' + "(id, scenario_name, scenario_description, scenario_version, pyrit_version, " + "objective_target_identifier, scenario_run_state, attack_results_json, " + "number_tries, completion_time, timestamp) " + "VALUES (:id, 'Other', '', 1, '0.14.0.dev0', '{}', 'COMPLETED', :manifest, 0, " + "'2026-05-18', '2026-05-18')" + ), + {"id": str(uuid.uuid4()), "manifest": json.dumps({"x": ["conv-shared"]})}, + ) + + # Manually call the backfill function (loaded from its file path + # via importlib — modules with leading-digit filenames are + # not importable through normal Python import). + from importlib.util import module_from_spec, spec_from_file_location + + migration_path = ( + Path(__file__).resolve().parent.parent.parent.parent + / "pyrit" + / "memory" + / "alembic" + / "versions" + / "9c8b7a6d5e4f_add_attribution_to_attack_results.py" + ) + spec = spec_from_file_location("scenario_linkage_migration", migration_path) + assert spec is not None and spec.loader is not None + mig = module_from_spec(spec) + spec.loader.exec_module(mig) + + from alembic import op as _op_mod + + _original_get_bind = _op_mod.get_bind + _op_mod.get_bind = lambda: connection + try: + mig._backfill_attribution_linkage() + finally: + _op_mod.get_bind = _original_get_bind + + # The row's manual retarget MUST survive — the IS NULL guard + # prevents the backfill from overwriting it.
+ row = connection.execute( + text('SELECT attribution_parent_id FROM "AttackResultEntries" WHERE conversation_id = :conv'), + {"conv": "conv-shared"}, + ).scalar_one() + assert row == sid_manual + finally: + engine.dispose() + + +def test_migration_drops_error_attack_result_ids_json_column(): + """The not-yet-released error_attack_result_ids_json column is removed + in this migration (no deprecation window needed).""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "drop-col.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + with engine.begin() as connection: + config = _config_for(connection) + command.upgrade(config, _PREV_REV) + cols_before = {c["name"] for c in inspect(connection).get_columns("ScenarioResultEntries")} + assert "error_attack_result_ids_json" in cols_before + + command.upgrade(config, _SCENARIO_LINKAGE_REV) + cols_after = {c["name"] for c in inspect(connection).get_columns("ScenarioResultEntries")} + assert "error_attack_result_ids_json" not in cols_after + finally: + engine.dispose() + + +def test_migration_downgrade_restores_dropped_column(): + """Downgrading from the linkage revision re-adds error_attack_result_ids_json + and removes the new AttackResultEntries columns.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "downgrade-linkage.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + with engine.begin() as connection: + config = _config_for(connection) + command.upgrade(config, _SCENARIO_LINKAGE_REV) + + attack_cols_up = {c["name"] for c in inspect(connection).get_columns("AttackResultEntries")} + assert "attribution_parent_id" in attack_cols_up + assert "attribution_data" in attack_cols_up + + command.downgrade(config, _PREV_REV) + + attack_cols_down = {c["name"] for c in inspect(connection).get_columns("AttackResultEntries")} + assert "attribution_parent_id" not in attack_cols_down + assert "attribution_data" not in attack_cols_down + + 
scenario_cols = {c["name"] for c in inspect(connection).get_columns("ScenarioResultEntries")} + assert "error_attack_result_ids_json" in scenario_cols + finally: + engine.dispose() + + def test_check_schema_migrations_calls_alembic_check(): with tempfile.TemporaryDirectory() as temp_dir: db_path = os.path.join(temp_dir, "check-test.db") diff --git a/tests/unit/scenario/test_atomic_attack.py b/tests/unit/scenario/test_atomic_attack.py index 5f8639175..b17db749f 100644 --- a/tests/unit/scenario/test_atomic_attack.py +++ b/tests/unit/scenario/test_atomic_attack.py @@ -980,3 +980,259 @@ async def test_enrichment_skips_db_update_when_no_attack_result_id(self, mock_at await atomic.run_async() mock_memory.update_attack_result_by_id.assert_not_called() + + +@pytest.mark.usefixtures("patch_central_database") +class TestAtomicAttackFilterSeedGroupsByCompletedHashes: + """Tests for ``drop_seed_groups_with_hashes`` — the hash-based + resume filter.""" + + def test_filters_out_completed_hashes(self, mock_attack, sample_seed_groups): + from pyrit.common.utils import to_sha256 + + atomic = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=sample_seed_groups, + atomic_attack_name="test", + ) + completed = {to_sha256("objective1"), to_sha256("objective3")} + atomic.drop_seed_groups_with_hashes(hashes=completed) + + assert atomic.seed_groups == [sample_seed_groups[1]] + + def test_empty_completed_hashes_is_noop(self, mock_attack, sample_seed_groups): + atomic = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=sample_seed_groups, + atomic_attack_name="test", + ) + + atomic.drop_seed_groups_with_hashes(hashes=set()) + + assert atomic.seed_groups == sample_seed_groups + + def test_all_hashes_completed_clears_seed_groups(self, mock_attack, sample_seed_groups): + from pyrit.common.utils import to_sha256 + + atomic = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=sample_seed_groups, + 
atomic_attack_name="test", + ) + + atomic.drop_seed_groups_with_hashes(hashes={to_sha256(f"objective{i}") for i in range(1, 4)}) + + assert atomic.seed_groups == [] + + def test_filter_is_stable_across_resampling(self, mock_attack, sample_seed_groups): + """Identity is content-derived, so reordering ``_seed_groups`` between + two calls (e.g. a fresh ``random.sample``) doesn't break the filter.""" + from pyrit.common.utils import to_sha256 + + atomic = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=sample_seed_groups, + atomic_attack_name="test", + ) + # Simulate a re-sample by reversing the internal list. + atomic._seed_groups = list(reversed(atomic._seed_groups)) + + atomic.drop_seed_groups_with_hashes(hashes={to_sha256("objective1")}) + kept_objectives = [sg.objective.value for sg in atomic.seed_groups] + assert "objective1" not in kept_objectives + assert set(kept_objectives) == {"objective2", "objective3"} + + +@pytest.mark.usefixtures("patch_central_database") +class TestAtomicAttackRestrictSeedGroupsToHashes: + """Tests for ``keep_seed_groups_with_hashes`` — the keep-set inverse used + on resume to replay the originally-sampled subset.""" + + def test_keeps_only_listed_hashes(self, mock_attack, sample_seed_groups): + from pyrit.common.utils import to_sha256 + + atomic = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=sample_seed_groups, + atomic_attack_name="test", + ) + keep = {to_sha256("objective1"), to_sha256("objective3")} + retained = atomic.keep_seed_groups_with_hashes(hashes=keep) + + assert {sg.objective.value for sg in atomic.seed_groups} == {"objective1", "objective3"} + assert retained == keep + + def test_retained_set_excludes_missing_hashes(self, mock_attack, sample_seed_groups): + from pyrit.common.utils import to_sha256 + + atomic = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=sample_seed_groups, + atomic_attack_name="test", + ) + 
keep = {to_sha256("objective1"), to_sha256("not-in-dataset")} + retained = atomic.keep_seed_groups_with_hashes(hashes=keep) + + assert {sg.objective.value for sg in atomic.seed_groups} == {"objective1"} + assert retained == {to_sha256("objective1")} + + +@pytest.mark.usefixtures("patch_central_database") +class TestAtomicAttackDuplicateObjectiveValidation: + """``AtomicAttack.__init__`` enforces objective-hash uniqueness within a + single atomic attack so resume can use the hash as a stable identity.""" + + def test_constructing_with_duplicate_objective_raises(self, mock_attack): + duplicate_groups = [ + SeedAttackGroup(seeds=[SeedObjective(value="same-objective")]), + SeedAttackGroup(seeds=[SeedObjective(value="same-objective")]), + ] + with pytest.raises(ValueError, match="duplicate objective hash"): + AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=duplicate_groups, + atomic_attack_name="dup", + ) + + def test_constructing_with_unique_objectives_succeeds(self, mock_attack, sample_seed_groups): + atomic = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=sample_seed_groups, + atomic_attack_name="ok", + ) + assert len(atomic.seed_groups) == 3 + + +@pytest.mark.usefixtures("patch_central_database") +class TestAtomicAttackAttributionStamping: + """Tests for how ``run_async`` builds the ``AttackResultAttribution`` it + passes to the executor.""" + + async def test_no_attribution_when_scenario_result_id_unset( + self, mock_attack, sample_seed_groups, sample_attack_results + ): + """Outside a Scenario, ``_scenario_result_id`` is None and the + executor must receive ``attribution=None``.""" + atomic = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=sample_seed_groups, + atomic_attack_name="test", + ) + assert atomic._scenario_result_id is None + + with patch.object(AttackExecutor, "execute_attack_from_seed_groups_async", new_callable=AsyncMock) as mock_exec: + 
mock_exec.return_value = wrap_results(sample_attack_results) + await atomic.run_async() + + assert mock_exec.call_args.kwargs["attribution"] is None + + async def test_attribution_built_when_scenario_result_id_set( + self, mock_attack, sample_seed_groups, sample_attack_results + ): + """When the Scenario stamps ``_scenario_result_id`` onto the atomic + attack, ``run_async`` must build and pass a single attribution object.""" + from pyrit.executor.attack.core.attack_result_attribution import AttackResultAttribution + + atomic = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=sample_seed_groups, + atomic_attack_name="MyAtomicAttack", + ) + atomic._scenario_result_id = "00000000-0000-0000-0000-000000000abc" + + with patch.object(AttackExecutor, "execute_attack_from_seed_groups_async", new_callable=AsyncMock) as mock_exec: + mock_exec.return_value = wrap_results(sample_attack_results) + await atomic.run_async() + + attribution = mock_exec.call_args.kwargs["attribution"] + assert isinstance(attribution, AttackResultAttribution) + assert attribution.parent_id == "00000000-0000-0000-0000-000000000abc" + assert attribution.parent_collection == "MyAtomicAttack" + + async def test_attribution_includes_technique_eval_hash( + self, mock_attack, sample_seed_groups, sample_attack_results + ): + """The stamped attribution must carry ``parent_eval_hash`` equal to + ``technique_eval_hash`` so resume disambiguates between two atomic + attacks that share a name but use different techniques.""" + atomic = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=sample_seed_groups, + atomic_attack_name="MyAtomicAttack", + ) + atomic._scenario_result_id = "00000000-0000-0000-0000-000000000abc" + + with patch.object(AttackExecutor, "execute_attack_from_seed_groups_async", new_callable=AsyncMock) as mock_exec: + mock_exec.return_value = wrap_results(sample_attack_results) + await atomic.run_async() + + attribution = 
mock_exec.call_args.kwargs["attribution"] + assert attribution.parent_eval_hash is not None + assert attribution.parent_eval_hash == atomic.technique_eval_hash + + +@pytest.mark.usefixtures("patch_central_database") +class TestAtomicAttackTechniqueEvalHash: + """``technique_eval_hash`` must be stable across seed groups and differ + between distinct technique configurations — it's the resume bucket key.""" + + def test_hash_is_independent_of_seed_groups(self, mock_attack, sample_seed_groups): + a1 = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=sample_seed_groups, + atomic_attack_name="same", + ) + a2 = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=[SeedAttackGroup(seeds=[SeedObjective(value="different-objective")])], + atomic_attack_name="same", + ) + assert a1.technique_eval_hash == a2.technique_eval_hash + + def test_hash_differs_for_different_attacks(self, sample_seed_groups): + attack_a = MagicMock(spec=AttackStrategy) + attack_a.get_identifier.return_value = ComponentIdentifier(class_name="AttackA", class_module="pyrit.test") + attack_b = MagicMock(spec=AttackStrategy) + attack_b.get_identifier.return_value = ComponentIdentifier(class_name="AttackB", class_module="pyrit.test") + + a1 = AtomicAttack( + attack_technique=AttackTechnique(attack=attack_a), + seed_groups=sample_seed_groups, + atomic_attack_name="same", + ) + a2 = AtomicAttack( + attack_technique=AttackTechnique(attack=attack_b), + seed_groups=sample_seed_groups, + atomic_attack_name="same", + ) + assert a1.technique_eval_hash != a2.technique_eval_hash + + +@pytest.mark.usefixtures("patch_central_database") +class TestAtomicAttackFilterSeedGroupsByObjectivesDeprecation: + """Tests for the deprecated ``filter_seed_groups_by_objectives`` shim + whose removal is scheduled for 0.16.0.""" + + def test_emits_deprecation_warning(self, mock_attack, sample_seed_groups): + atomic = AtomicAttack( +
attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=sample_seed_groups, + atomic_attack_name="test", + ) + with patch("pyrit.scenario.core.atomic_attack.print_deprecation_message") as mock_dep: + atomic.filter_seed_groups_by_objectives(remaining_objectives=["objective1"]) + assert mock_dep.call_count == 1 + kwargs = mock_dep.call_args.kwargs + assert "filter_seed_groups_by_objectives" in kwargs["old_item"] + assert "keep_seed_groups_with_hashes" in kwargs["new_item"] + assert kwargs["removed_in"] == "0.16.0" + + def test_filters_by_text_match(self, mock_attack, sample_seed_groups): + atomic = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=sample_seed_groups, + atomic_attack_name="test", + ) + with patch("pyrit.scenario.core.atomic_attack.print_deprecation_message"): + atomic.filter_seed_groups_by_objectives(remaining_objectives=["objective2"]) + assert [sg.objective.value for sg in atomic.seed_groups] == ["objective2"] diff --git a/tests/unit/scenario/test_scenario.py b/tests/unit/scenario/test_scenario.py index e7042183d..9f3eaf9be 100644 --- a/tests/unit/scenario/test_scenario.py +++ b/tests/unit/scenario/test_scenario.py @@ -29,11 +29,36 @@ def save_attack_results_to_memory(attack_results): memory.add_attack_results_to_memory(attack_results=attack_results) -def create_mock_run_async(attack_results): - """Create a mock run_async that saves results to memory before returning.""" +def _stamp_scenario_linkage(*, attack_results, atomic_attack): + """ + Stamp attribution_parent_id + attribution_data on each AttackResult the + same way the real attack persistence path does. Mirrors what + ``_DefaultAttackStrategyEventHandler._apply_attribution`` does at runtime + so test fixtures that mock out the executor still produce DB rows the new + foreign-key-based hydration can find. 
+ """ + sid = getattr(atomic_attack, "_scenario_result_id", None) + name = getattr(atomic_attack, "atomic_attack_name", None) + if not sid or not name: + return + for r in attack_results: + r.attribution_parent_id = sid + r.attribution_data = {"parent_collection": name} + + +def create_mock_run_async(attack_results, *, atomic_attack=None): + """ + Create a mock ``run_async`` that stamps + saves results to memory. + + Pass ``atomic_attack`` (the AtomicAttack MagicMock) so the helper can copy + its ``_scenario_result_id`` (set by ``Scenario._execute_scenario_async``) + and ``atomic_attack_name`` onto each result. Without those the foreign-key- + based hydration in ``get_scenario_results`` won't see the rows. + """ async def mock_run_async(*args, **kwargs): - # Save results to memory (mimics what real attacks do) + if atomic_attack is not None: + _stamp_scenario_linkage(attack_results=attack_results, atomic_attack=atomic_attack) save_attack_results_to_memory(attack_results) return AttackExecutorResult(completed_results=attack_results, incomplete_objectives=[]) @@ -52,18 +77,24 @@ def mock_atomic_attacks(): run1.atomic_attack_name = "attack_run_1" run1.display_group = "attack_run_1" run1._attack = mock_attack + run1._scenario_result_id = None + run1.set_scenario_result_id = MagicMock(side_effect=lambda sid: setattr(run1, "_scenario_result_id", sid)) type(run1).objectives = PropertyMock(return_value=["objective1"]) run2 = MagicMock(spec=AtomicAttack) run2.atomic_attack_name = "attack_run_2" run2.display_group = "attack_run_2" run2._attack = mock_attack + run2._scenario_result_id = None + run2.set_scenario_result_id = MagicMock(side_effect=lambda sid: setattr(run2, "_scenario_result_id", sid)) type(run2).objectives = PropertyMock(return_value=["objective2"]) run3 = MagicMock(spec=AtomicAttack) run3.atomic_attack_name = "attack_run_3" run3.display_group = "attack_run_3" run3._attack = mock_attack + run3._scenario_result_id = None + run3.set_scenario_result_id = 
MagicMock(side_effect=lambda sid: setattr(run3, "_scenario_result_id", sid)) type(run3).objectives = PropertyMock(return_value=["objective3"]) return [run1, run2, run3] @@ -327,7 +358,7 @@ async def test_run_async_executes_all_runs(self, mock_atomic_attacks, sample_att """Test that run_async executes all atomic attacks sequentially.""" # Configure each run to return different results for i, run in enumerate(mock_atomic_attacks): - run.run_async = create_mock_run_async([sample_attack_results[i]]) + run.run_async = create_mock_run_async([sample_attack_results[i]], atomic_attack=run) scenario = ConcreteScenario( name="Test Scenario", @@ -359,7 +390,7 @@ async def test_run_async_with_custom_concurrency( ): """Test that max_concurrency from init is passed to each atomic attack.""" for i, run in enumerate(mock_atomic_attacks): - run.run_async = create_mock_run_async([sample_attack_results[i]]) + run.run_async = create_mock_run_async([sample_attack_results[i]], atomic_attack=run) scenario = ConcreteScenario( name="Test Scenario", @@ -383,9 +414,15 @@ async def test_run_async_aggregates_multiple_results( ): """Test that results from multiple atomic attacks are properly aggregated.""" # Configure runs to return different numbers of results - mock_atomic_attacks[0].run_async = create_mock_run_async(sample_attack_results[0:2]) - mock_atomic_attacks[1].run_async = create_mock_run_async(sample_attack_results[2:4]) - mock_atomic_attacks[2].run_async = create_mock_run_async(sample_attack_results[4:5]) + mock_atomic_attacks[0].run_async = create_mock_run_async( + sample_attack_results[0:2], atomic_attack=mock_atomic_attacks[0] + ) + mock_atomic_attacks[1].run_async = create_mock_run_async( + sample_attack_results[2:4], atomic_attack=mock_atomic_attacks[1] + ) + mock_atomic_attacks[2].run_async = create_mock_run_async( + sample_attack_results[4:5], atomic_attack=mock_atomic_attacks[2] + ) scenario = ConcreteScenario( name="Test Scenario", @@ -441,7 +478,7 @@ async def 
test_run_async_returns_scenario_result_with_identifier( ): """Test that run_async returns ScenarioResult with proper identifier.""" for i, run in enumerate(mock_atomic_attacks): - run.run_async = create_mock_run_async([sample_attack_results[i]]) + run.run_async = create_mock_run_async([sample_attack_results[i]], atomic_attack=run) scenario = ConcreteScenario( name="Test Scenario", @@ -502,6 +539,10 @@ async def test_atomic_attack_count_with_different_sizes(self, mock_objective_tar single_run_mock.atomic_attack_name = "attack_1" single_run_mock.display_group = "attack_1" single_run_mock._attack = mock_attack + single_run_mock._scenario_result_id = None + single_run_mock.set_scenario_result_id = MagicMock( + side_effect=lambda sid: setattr(single_run_mock, "_scenario_result_id", sid) + ) type(single_run_mock).objectives = PropertyMock(return_value=["obj1"]) single_run = [single_run_mock] @@ -519,6 +560,11 @@ async def test_atomic_attack_count_with_different_sizes(self, mock_objective_tar run.atomic_attack_name = f"attack_{i}" run.display_group = f"attack_{i}" run._attack = mock_attack + run._scenario_result_id = None + # Capture run by default arg to avoid late-binding in the closure. 
+ run.set_scenario_result_id = MagicMock( + side_effect=lambda sid, _run=run: setattr(_run, "_scenario_result_id", sid) + ) type(run).objectives = PropertyMock(return_value=[f"obj{i}"]) many_runs.append(run) @@ -779,7 +825,9 @@ async def test_baseline_only_execution_runs_successfully(self, mock_objective_ta ) # Mock the baseline attack's run_async - scenario._atomic_attacks[0].run_async = create_mock_run_async([sample_attack_results[0]]) + scenario._atomic_attacks[0].run_async = create_mock_run_async( + [sample_attack_results[0]], atomic_attack=scenario._atomic_attacks[0] + ) # Run the scenario result = await scenario.run_async() diff --git a/tests/unit/scenario/test_scenario_partial_results.py b/tests/unit/scenario/test_scenario_partial_results.py index a18625dc8..f14683093 100644 --- a/tests/unit/scenario/test_scenario_partial_results.py +++ b/tests/unit/scenario/test_scenario_partial_results.py @@ -35,8 +35,21 @@ def mock_objective_target(): return target -def save_attack_results_to_memory(attack_results): - """Helper function to save attack results to memory.""" +def save_attack_results_to_memory(attack_results, *, atomic_attack=None): + """ + Helper function to save attack results to memory. When ``atomic_attack`` is + provided, also stamps ``attribution_parent_id`` and ``attribution_data`` on + each result the same way the real attack persistence path does — so + foreign-key-based + hydration in ``get_scenario_results`` finds them. 
+ """ + if atomic_attack is not None: + sid = getattr(atomic_attack, "_scenario_result_id", None) + name = getattr(atomic_attack, "atomic_attack_name", None) + if sid and name: + for r in attack_results: + r.attribution_parent_id = sid + r.attribution_data = {"parent_collection": name} memory = CentralMemory.get_memory_instance() memory.add_attack_results_to_memory(attack_results=attack_results) @@ -44,8 +57,11 @@ def save_attack_results_to_memory(attack_results): def create_mock_atomic_attack(name: str, objectives: list[str]) -> MagicMock: """Create a mock AtomicAttack with required attributes for baseline creation. - The mock tracks its objectives and properly updates when filter_seed_groups_by_objectives is called. + The mock tracks its objectives and properly updates when + drop_seed_groups_with_hashes is called. """ + from pyrit.common.utils import to_sha256 + mock_attack_strategy = MagicMock() mock_attack_strategy.get_objective_target.return_value = MagicMock() mock_attack_strategy.get_attack_scoring_config.return_value = MagicMock() @@ -54,19 +70,24 @@ def create_mock_atomic_attack(name: str, objectives: list[str]) -> MagicMock: attack.atomic_attack_name = name attack.display_group = name attack._attack = mock_attack_strategy + attack._scenario_result_id = None + + def _set_scenario_result_id(scenario_result_id): + attack._scenario_result_id = scenario_result_id + + attack.set_scenario_result_id = MagicMock(side_effect=_set_scenario_result_id) - # Track current objectives in a mutable container so it can be updated + original_objectives = list(objectives) current_objectives = {"value": list(objectives)} - # Configure objectives property to return current objectives type(attack).objectives = PropertyMock(side_effect=lambda: current_objectives["value"]) + type(attack).seed_groups = PropertyMock(side_effect=lambda: current_objectives["value"]) - # Configure filter_seed_groups_by_objectives to update the tracked objectives - def filter_objectives(*, 
remaining_objectives): - remaining_set = set(remaining_objectives) - current_objectives["value"] = [obj for obj in current_objectives["value"] if obj in remaining_set] + def drop_hashes(*, hashes): + current_objectives["value"] = [o for o in current_objectives["value"] if to_sha256(o) not in hashes] - attack.filter_seed_groups_by_objectives = MagicMock(side_effect=filter_objectives) + attack.drop_seed_groups_with_hashes = MagicMock(side_effect=drop_hashes) + attack._original_objectives = original_objectives return attack @@ -142,7 +163,7 @@ async def mock_run(*args, **kwargs): incomplete = [("obj3", ValueError("Failed to complete obj3"))] # Save completed results to memory - save_attack_results_to_memory(completed) + save_attack_results_to_memory(completed, atomic_attack=atomic_attack) return AttackExecutorResult(completed_results=completed, incomplete_objectives=incomplete) # Retry: complete the remaining objective @@ -154,7 +175,7 @@ async def mock_run(*args, **kwargs): executed_turns=1, ) ] - save_attack_results_to_memory(completed) + save_attack_results_to_memory(completed, atomic_attack=atomic_attack) return AttackExecutorResult(completed_results=completed, incomplete_objectives=[]) atomic_attack.run_async = mock_run @@ -200,7 +221,7 @@ async def mock_run(*args, **kwargs): incomplete = [("obj3", RuntimeError("Failed obj3")), ("obj4", RuntimeError("Failed obj4"))] # Save completed results to memory - save_attack_results_to_memory(completed) + save_attack_results_to_memory(completed, atomic_attack=atomic_attack) return AttackExecutorResult(completed_results=completed, incomplete_objectives=incomplete) @@ -257,7 +278,7 @@ async def mock_run(*args, **kwargs): ] incomplete = [("obj4", Exception("Failed obj4")), ("obj5", Exception("Failed obj5"))] - save_attack_results_to_memory(completed) + save_attack_results_to_memory(completed, atomic_attack=atomic_attack) return AttackExecutorResult(completed_results=completed, incomplete_objectives=incomplete) # Retry: 
complete remaining objectives @@ -271,7 +292,7 @@ async def mock_run(*args, **kwargs): for i in [4, 5] ] - save_attack_results_to_memory(completed) + save_attack_results_to_memory(completed, atomic_attack=atomic_attack) return AttackExecutorResult(completed_results=completed, incomplete_objectives=[]) @@ -313,10 +334,12 @@ async def test_multiple_atomic_attacks_with_partial_results(self, mock_objective attack3 = create_mock_atomic_attack("attack_3", ["a3_obj1"]) call_counts = {"attack_1": 0, "attack_2": 0, "attack_3": 0} + attacks_by_name = {"attack_1": attack1, "attack_2": attack2, "attack_3": attack3} async def make_mock_run(attack_name, objectives): async def mock_run(*args, **kwargs): call_counts[attack_name] += 1 + this_attack = attacks_by_name[attack_name] if attack_name == "attack_2" and call_counts[attack_name] == 1: # Attack 2 fails partially on first attempt @@ -330,7 +353,7 @@ async def mock_run(*args, **kwargs): ] incomplete = [("a2_obj2", Exception("Failed a2_obj2")), ("a2_obj3", Exception("Failed a2_obj3"))] - save_attack_results_to_memory(completed) + save_attack_results_to_memory(completed, atomic_attack=this_attack) return AttackExecutorResult(completed_results=completed, incomplete_objectives=incomplete) # All other attempts succeed fully @@ -341,12 +364,10 @@ async def mock_run(*args, **kwargs): outcome=AttackOutcome.SUCCESS, executed_turns=1, ) - for obj in ( - attack1 if attack_name == "attack_1" else (attack2 if attack_name == "attack_2" else attack3) - ).objectives + for obj in this_attack.objectives ] - save_attack_results_to_memory(completed) + save_attack_results_to_memory(completed, atomic_attack=this_attack) return AttackExecutorResult(completed_results=completed, incomplete_objectives=[]) diff --git a/tests/unit/scenario/test_scenario_retry.py b/tests/unit/scenario/test_scenario_retry.py index 836503ff5..0add8d4e6 100644 --- a/tests/unit/scenario/test_scenario_retry.py +++ b/tests/unit/scenario/test_scenario_retry.py @@ -40,8 +40,20 @@ 
def mock_objective_scorer(): # Helper functions -def save_attack_results_to_memory(attack_results): - """Helper function to save attack results to memory (mimics what real attacks do).""" +def save_attack_results_to_memory(attack_results, *, atomic_attack=None): + """Helper function to save attack results to memory. + + When ``atomic_attack`` is provided, stamps ``attribution_parent_id`` and + ``attribution_data`` onto each result (mirrors the real attack persistence + path so foreign-key-based hydration sees the rows). + """ + if atomic_attack is not None: + sid = getattr(atomic_attack, "_scenario_result_id", None) + name = getattr(atomic_attack, "atomic_attack_name", None) + if sid and name: + for r in attack_results: + r.attribution_parent_id = sid + r.attribution_data = {"parent_collection": name} memory = CentralMemory.get_memory_instance() memory.add_attack_results_to_memory(attack_results=attack_results) @@ -86,19 +98,21 @@ def create_attack_results_list(count: int, start_index: int = 1) -> list[AttackR return [create_attack_result(i) for i in range(start_index, start_index + count)] -def create_mock_run_async(attack_results): - """Create a mock run_async that saves results to memory before returning. +def create_mock_run_async(attack_results, *, atomic_attack=None): + """Create a mock run_async that stamps + saves results to memory before returning. Args: attack_results: List of AttackResult objects to return + atomic_attack: Optional AtomicAttack mock. When provided, results are + stamped with attribution_parent_id and attribution_data so + foreign-key-based hydration finds them. 
Returns: AsyncMock configured to return the results """ async def mock_run_async(*args, **kwargs): - # Save results to memory (mimics what real attacks do) - save_attack_results_to_memory(attack_results) + save_attack_results_to_memory(attack_results, atomic_attack=atomic_attack) return AttackExecutorResult(completed_results=attack_results, incomplete_objectives=[]) return AsyncMock(side_effect=mock_run_async) @@ -124,10 +138,25 @@ def create_mock_atomic_attack(name: str, objectives: list[str], run_async_mock: attack.atomic_attack_name = name attack.display_group = name attack._attack = mock_attack_strategy - type(attack).objectives = PropertyMock(return_value=objectives) + attack._scenario_result_id = None + + def _set_scenario_result_id(scenario_result_id): + attack._scenario_result_id = scenario_result_id + + attack.set_scenario_result_id = MagicMock(side_effect=_set_scenario_result_id) + + # Track objectives + objective-hash mapping so the hash-based filter + # behaves correctly in resume tests. 
+ from pyrit.common.utils import to_sha256 + + current_objectives = {"value": list(objectives)} + type(attack).objectives = PropertyMock(side_effect=lambda: current_objectives["value"]) + type(attack).seed_groups = PropertyMock(side_effect=lambda: current_objectives["value"]) + + def drop_hashes(*, hashes): + current_objectives["value"] = [o for o in current_objectives["value"] if to_sha256(o) not in hashes] - # Configure filter_seed_groups_by_objectives - needed for scenario retry filtering - attack.filter_seed_groups_by_objectives = MagicMock() + attack.drop_seed_groups_with_hashes = MagicMock(side_effect=drop_hashes) if run_async_mock: attack.run_async = run_async_mock @@ -214,7 +243,7 @@ async def test_no_retry_on_success(self, mock_atomic_attacks, sample_attack_resu """Test that scenario doesn't retry when execution succeeds.""" # Configure successful execution for i, run in enumerate(mock_atomic_attacks): - run.run_async = create_mock_run_async([sample_attack_results[i]]) + run.run_async = create_mock_run_async([sample_attack_results[i]], atomic_attack=run) scenario = ConcreteScenario( name="Test Scenario", @@ -247,11 +276,13 @@ async def mock_run_with_retry(*args, **kwargs): raise Exception("Test failure") # Retry succeeds results = [sample_attack_results[0]] - save_attack_results_to_memory(results) + save_attack_results_to_memory(results, atomic_attack=mock_atomic_attacks[0]) return AttackExecutorResult(completed_results=results, incomplete_objectives=[]) mock_atomic_attacks[0].run_async = mock_run_with_retry - mock_atomic_attacks[1].run_async = create_mock_run_async([sample_attack_results[1]]) + mock_atomic_attacks[1].run_async = create_mock_run_async( + [sample_attack_results[1]], atomic_attack=mock_atomic_attacks[1] + ) scenario = ConcreteScenario( name="Test Scenario", @@ -326,11 +357,13 @@ async def mock_run_with_multiple_retries(*args, **kwargs): raise Exception("Test failure") # Third attempt succeeds results = [sample_attack_results[0]] - 
save_attack_results_to_memory(results) + save_attack_results_to_memory(results, atomic_attack=mock_atomic_attacks[0]) return AttackExecutorResult(completed_results=results, incomplete_objectives=[]) mock_atomic_attacks[0].run_async = mock_run_with_multiple_retries - mock_atomic_attacks[1].run_async = create_mock_run_async([sample_attack_results[1]]) + mock_atomic_attacks[1].run_async = create_mock_run_async( + [sample_attack_results[1]], atomic_attack=mock_atomic_attacks[1] + ) scenario = ConcreteScenario( name="Test Scenario", @@ -360,11 +393,13 @@ async def mock_run_with_logged_failure(*args, **kwargs): raise ValueError("First failure") # Retry succeeds results = [sample_attack_results[0]] - save_attack_results_to_memory(results) + save_attack_results_to_memory(results, atomic_attack=mock_atomic_attacks[0]) return AttackExecutorResult(completed_results=results, incomplete_objectives=[]) mock_atomic_attacks[0].run_async = mock_run_with_logged_failure - mock_atomic_attacks[1].run_async = create_mock_run_async([sample_attack_results[1]]) + mock_atomic_attacks[1].run_async = create_mock_run_async( + [sample_attack_results[1]], atomic_attack=mock_atomic_attacks[1] + ) scenario = ConcreteScenario( name="Test Scenario", @@ -407,12 +442,12 @@ async def mock_run_with_partial_completion(*args, **kwargs): # First attempt: complete 2 objectives, then fail executed_objectives.extend(["obj1", "obj2"]) results = [create_attack_result(i, objective=f"obj{i}") for i in [1, 2]] - save_attack_results_to_memory(results) + save_attack_results_to_memory(results, atomic_attack=atomic_attack) raise Exception("Failed after 2 objectives") # Retry: should only execute remaining objectives (obj3, obj4) executed_objectives.extend(["obj3", "obj4"]) results = [create_attack_result(i, objective=f"obj{i}") for i in [3, 4]] - save_attack_results_to_memory(results) + save_attack_results_to_memory(results, atomic_attack=atomic_attack) return AttackExecutorResult(completed_results=results, 
incomplete_objectives=[]) atomic_attack.run_async = mock_run_with_partial_completion @@ -448,7 +483,7 @@ async def test_resumes_skipping_completed_atomic_attacks(self, mock_objective_ta async def mock_run_attack1(*args, **kwargs): call_count["attack_1"] += 1 results = [create_attack_result(1, objective="objective1")] - save_attack_results_to_memory(results) + save_attack_results_to_memory(results, atomic_attack=attack1) return AttackExecutorResult(completed_results=results, incomplete_objectives=[]) # Attack 2: Succeeds on first attempt, should not be retried @@ -456,7 +491,7 @@ async def mock_run_attack2(*args, **kwargs): call_count["attack_2"] += 1 if call_count["attack_2"] == 1: results = [create_attack_result(2, objective="objective2")] - save_attack_results_to_memory(results) + save_attack_results_to_memory(results, atomic_attack=attack2) return AttackExecutorResult(completed_results=results, incomplete_objectives=[]) raise AssertionError("Attack 2 should not be retried after completion") @@ -466,7 +501,7 @@ async def mock_run_attack3(*args, **kwargs): if call_count["attack_3"] == 1: raise Exception("Attack 3 failed on first attempt") results = [create_attack_result(3, objective="objective3")] - save_attack_results_to_memory(results) + save_attack_results_to_memory(results, atomic_attack=attack3) return AttackExecutorResult(completed_results=results, incomplete_objectives=[]) attack1.run_async = mock_run_attack1 @@ -509,7 +544,7 @@ async def test_resumes_with_multiple_failures_across_attacks(self, mock_objectiv async def mock_run_attack1(*args, **kwargs): call_count["attack_1"] += 1 results = [create_attack_result(1, objective="objective1")] - save_attack_results_to_memory(results) + save_attack_results_to_memory(results, atomic_attack=attacks[0]) return AttackExecutorResult(completed_results=results, incomplete_objectives=[]) # Attack 2: Fails on first attempt, succeeds on retry @@ -518,21 +553,21 @@ async def mock_run_attack2(*args, **kwargs): if 
call_count["attack_2"] == 1: raise Exception("Attack 2 failed") results = [create_attack_result(2, objective="objective2")] - save_attack_results_to_memory(results) + save_attack_results_to_memory(results, atomic_attack=attacks[1]) return AttackExecutorResult(completed_results=results, incomplete_objectives=[]) # Attack 3: Only called on retry (after attack 2 succeeds) async def mock_run_attack3(*args, **kwargs): call_count["attack_3"] += 1 results = [create_attack_result(3, objective="objective3")] - save_attack_results_to_memory(results) + save_attack_results_to_memory(results, atomic_attack=attacks[2]) return AttackExecutorResult(completed_results=results, incomplete_objectives=[]) # Attack 4: Only called on retry async def mock_run_attack4(*args, **kwargs): call_count["attack_4"] += 1 results = [create_attack_result(4, objective="objective4")] - save_attack_results_to_memory(results) + save_attack_results_to_memory(results, atomic_attack=attacks[3]) return AttackExecutorResult(completed_results=results, incomplete_objectives=[]) attacks[0].run_async = mock_run_attack1 @@ -564,3 +599,283 @@ async def mock_run_attack4(*args, **kwargs): assert call_count["attack_4"] == 1 # All four attacks should be in results assert len(result.attack_results) == 4 + + +@pytest.mark.usefixtures("patch_central_database") +class TestScenarioForeignKeyResumeRegression: + """Regression tests for the foreign-key-based scenario linkage resume path. + + The bug being regression-tested: when a Scenario is interrupted mid- + AtomicAttack (Ctrl-C, OOM, crash), AttackResults already persisted to the + DB used to be invisible to the scenario because the scenario→attack-result + link only lived in a JSON manifest written after the whole AtomicAttack + returned. On resume, those objectives were re-executed (wasted compute). 
+ + After the refactor, ``attribution_parent_id`` is stamped on each + ``AttackResultEntry`` at write time, so resume reads them directly and + skips the already-done work even when the manifest was never updated. + """ + + async def test_resume_skips_objectives_persisted_before_interruption(self, mock_objective_target): + """Simulate Ctrl-C after some objectives in an atomic attack persisted + results but before the manifest was bulk-written. On resume, only the + missing objectives are re-executed.""" + atomic_attack = create_mock_atomic_attack("partial", ["o1", "o2", "o3", "o4"]) + + async def first_run(*args, **kwargs): + partials = [ + create_attack_result(0, conversation_id="c1", objective="o1"), + create_attack_result(1, conversation_id="c2", objective="o2"), + ] + save_attack_results_to_memory(partials, atomic_attack=atomic_attack) + raise Exception("simulated crash after partial persistence") + + atomic_attack.run_async = first_run + + scenario = ConcreteScenario( + name="Interrupted Scenario", + version=1, + atomic_attacks_to_return=[atomic_attack], + ) + await scenario.initialize_async(objective_target=mock_objective_target, max_retries=0) + + with pytest.raises(Exception, match="simulated crash"): + await scenario.run_async() + + scenario_result_id = scenario._scenario_result_id + assert scenario_result_id is not None + + # === Resume by scenario_result_id === + atomic_attack_resume = create_mock_atomic_attack("partial", ["o1", "o2", "o3", "o4"]) + executed: list[str] = [] + + async def second_run(*args, **kwargs): + executed.extend(atomic_attack_resume.objectives) + results = [ + create_attack_result(i, conversation_id=f"c{i + 1}", objective=obj) + for i, obj in enumerate(atomic_attack_resume.objectives, start=2) + ] + save_attack_results_to_memory(results, atomic_attack=atomic_attack_resume) + return AttackExecutorResult(completed_results=results, incomplete_objectives=[]) + + atomic_attack_resume.run_async = second_run + + scenario_resumed = 
ConcreteScenario( + name="Interrupted Scenario", + version=1, + atomic_attacks_to_return=[atomic_attack_resume], + scenario_result_id=scenario_result_id, + ) + await scenario_resumed.initialize_async(objective_target=mock_objective_target, max_retries=0) + await scenario_resumed.run_async() + + # Resume executed only the missing objectives — the core fix. + assert executed == ["o3", "o4"] + + async def test_duplicate_objective_text_in_atomic_attack_is_rejected(self, mock_objective_target): + """Resume identity is the objective sha256 within an AtomicAttack, so + the real ``AtomicAttack.__init__`` refuses to construct with duplicate + objective text. We exercise the production constructor here to lock + that contract in (the resume mocks bypass it intentionally).""" + from pyrit.executor.attack import AttackStrategy + from pyrit.models import SeedAttackGroup, SeedObjective + from pyrit.scenario import AtomicAttack + from pyrit.scenario.core.attack_technique import AttackTechnique + + mock_attack = MagicMock(spec=AttackStrategy) + duplicate_groups = [ + SeedAttackGroup(seeds=[SeedObjective(value="dup-obj")]), + SeedAttackGroup(seeds=[SeedObjective(value="dup-obj")]), + ] + with pytest.raises(ValueError, match="duplicate objective hash"): + AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=duplicate_groups, + atomic_attack_name="dup_attack", + ) + + async def test_duplicate_atomic_attack_name_does_not_warn(self, mock_objective_target, caplog): + """Duplicate ``atomic_attack_name`` is supported: resume disambiguates + rows by ``(parent_collection, parent_eval_hash)``, so two atomic + attacks sharing a name with different techniques don't cross-pollinate + their completed-hash sets. 
No warning is emitted.""" + dup1 = create_mock_atomic_attack("dup_name", ["objA"]) + dup2 = create_mock_atomic_attack("dup_name", ["objB"]) + + async def noop_run(*args, **kwargs): + return AttackExecutorResult(completed_results=[], incomplete_objectives=[]) + + dup1.run_async = noop_run + dup2.run_async = noop_run + + scenario = ConcreteScenario( + name="Dup Name Scenario", + version=1, + atomic_attacks_to_return=[dup1, dup2], + ) + + with caplog.at_level("WARNING"): + await scenario.initialize_async(objective_target=mock_objective_target) + + assert not any("duplicate atomic_attack_name" in record.message for record in caplog.records), ( + "Duplicate atomic_attack_name should be supported without warning" + ) + + +@pytest.mark.usefixtures("patch_central_database") +class TestGetCompletedObjectiveHashesForAttack: + """Direct tests for ``Scenario._get_completed_objective_hashes_for_attack`` + — the filter that excludes already-completed objectives on resume. + + Covers the row-filtering branches: outcome=ERROR rows, rows without + attribution_data, and the technique-disambiguation branch where two + atomic attacks share a name but differ in technique eval hash. 
+ """ + + def _make_scenario(self, scenario_result_id="scn-1"): + scenario = ConcreteScenario(name="S", version=1, atomic_attacks_to_return=[]) + scenario._scenario_result_id = scenario_result_id + scenario._memory = MagicMock() + return scenario + + def _make_atomic(self, name, eval_hash="hash-A"): + atomic = MagicMock(spec=AtomicAttack) + atomic.atomic_attack_name = name + type(atomic).technique_eval_hash = PropertyMock(return_value=eval_hash) + return atomic + + def _row(self, *, objective, outcome=AttackOutcome.SUCCESS, attribution_data=None): + row = MagicMock() + row.outcome = outcome + row.attribution_data = attribution_data + row.objective = objective + return row + + def test_returns_empty_when_scenario_result_id_unset(self): + scenario = ConcreteScenario(name="S", version=1, atomic_attacks_to_return=[]) + scenario._scenario_result_id = None + result = scenario._get_completed_objective_hashes_for_attack( + atomic_attack=self._make_atomic("a"), + ) + assert result == set() + + def test_skips_error_rows(self): + from pyrit.common.utils import to_sha256 + + scenario = self._make_scenario() + scenario._memory.get_attack_results.return_value = [ + self._row( + objective="ok", + outcome=AttackOutcome.SUCCESS, + attribution_data={"parent_collection": "a", "parent_eval_hash": "hash-A"}, + ), + self._row( + objective="failed", + outcome=AttackOutcome.ERROR, + attribution_data={"parent_collection": "a", "parent_eval_hash": "hash-A"}, + ), + ] + result = scenario._get_completed_objective_hashes_for_attack( + atomic_attack=self._make_atomic("a"), + ) + assert result == {to_sha256("ok")} + + def test_skips_rows_without_attribution_data(self): + from pyrit.common.utils import to_sha256 + + scenario = self._make_scenario() + scenario._memory.get_attack_results.return_value = [ + self._row(objective="legacy", attribution_data=None), + self._row( + objective="new", + attribution_data={"parent_collection": "a", "parent_eval_hash": "hash-A"}, + ), + ] + result = 
scenario._get_completed_objective_hashes_for_attack( + atomic_attack=self._make_atomic("a"), + ) + assert result == {to_sha256("new")} + + def test_skips_rows_with_mismatched_eval_hash(self): + """Two atomic attacks with the same name but different techniques + must not cross-pollinate completed hashes. This is the core Option-B + guarantee.""" + from pyrit.common.utils import to_sha256 + + scenario = self._make_scenario() + scenario._memory.get_attack_results.return_value = [ + self._row( + objective="mine", + attribution_data={"parent_collection": "encoding", "parent_eval_hash": "hash-base64"}, + ), + self._row( + objective="theirs", + attribution_data={"parent_collection": "encoding", "parent_eval_hash": "hash-hex"}, + ), + ] + result = scenario._get_completed_objective_hashes_for_attack( + atomic_attack=self._make_atomic("encoding", eval_hash="hash-base64"), + ) + assert result == {to_sha256("mine")} + + def test_backward_compat_matches_name_only_when_eval_hash_missing(self): + """Rows persisted before ``parent_eval_hash`` shipped match name-only + so pre-existing resume runs aren't stranded.""" + from pyrit.common.utils import to_sha256 + + scenario = self._make_scenario() + scenario._memory.get_attack_results.return_value = [ + self._row( + objective="old", + attribution_data={"parent_collection": "a"}, # no parent_eval_hash + ), + ] + result = scenario._get_completed_objective_hashes_for_attack( + atomic_attack=self._make_atomic("a", eval_hash="hash-A"), + ) + assert result == {to_sha256("old")} + + +@pytest.mark.usefixtures("patch_central_database") +class TestApplyPersistedObjectives: + """Direct tests for ``Scenario._apply_persisted_objectives`` — the + resume-time replay that locks subsequent runs to the originally-sampled + objective subset.""" + + def _make_scenario_with_atomics(self, atomics): + scenario = ConcreteScenario(name="S", version=1, atomic_attacks_to_return=[]) + scenario._scenario_result_id = "scn-1" + scenario._atomic_attacks = atomics + 
return scenario + + def test_noop_when_metadata_has_no_persisted_hashes(self): + atomic = MagicMock(spec=AtomicAttack) + scenario = self._make_scenario_with_atomics([atomic]) + stored = MagicMock() + stored.metadata = {} + scenario._apply_persisted_objectives(stored_result=stored) + atomic.keep_seed_groups_with_hashes.assert_not_called() + + def test_replays_persisted_subset_across_atomics(self): + atomic_a = MagicMock(spec=AtomicAttack) + atomic_a.keep_seed_groups_with_hashes.return_value = {"h1", "h2"} + atomic_b = MagicMock(spec=AtomicAttack) + atomic_b.keep_seed_groups_with_hashes.return_value = {"h3"} + scenario = self._make_scenario_with_atomics([atomic_a, atomic_b]) + + stored = MagicMock() + stored.metadata = {"objective_hashes": ["h1", "h2", "h3"]} + scenario._apply_persisted_objectives(stored_result=stored) + + atomic_a.keep_seed_groups_with_hashes.assert_called_once_with(hashes={"h1", "h2", "h3"}) + atomic_b.keep_seed_groups_with_hashes.assert_called_once_with(hashes={"h1", "h2", "h3"}) + + def test_raises_when_persisted_hash_is_missing(self): + atomic = MagicMock(spec=AtomicAttack) + atomic.keep_seed_groups_with_hashes.return_value = {"h1"} # h2 missing + scenario = self._make_scenario_with_atomics([atomic]) + + stored = MagicMock() + stored.metadata = {"objective_hashes": ["h1", "h2"]} + with pytest.raises(ValueError, match="cannot resume"): + scenario._apply_persisted_objectives(stored_result=stored)