Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions .github/instructions/style-guide.instructions.md
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,26 @@ def process_items(self, *, items: list[str]) -> list[str]:
return []
```

## Deprecations

When deprecating a public class, function, method, parameter, or module path, use `pyrit.common.deprecation.print_deprecation_message` — never `warnings.warn` directly. It wraps `warnings.warn(..., DeprecationWarning, stacklevel=3)` with a consistent format so filtering still works.

Set `removed_in` to **current version + 2 minor versions** (e.g. `0.14.x` → `removed_in="0.16.0"`). This gives one full release cycle of warning before removal.

```python
from pyrit.common.deprecation import print_deprecation_message

def old_method(self, *, foo: str) -> None:
print_deprecation_message(
old_item="MyClass.old_method",
new_item="MyClass.new_method",
removed_in="0.16.0",
)
...
```

`old_item` / `new_item` accept a class/callable (qualified name is generated) or a string.

## Pythonic Patterns

### List Comprehensions
Expand Down
10 changes: 7 additions & 3 deletions pyrit/backend/services/scenario_run_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,9 +409,13 @@ def _build_response_from_db(self, *, scenario_result: ScenarioResult) -> Scenari
error = scenario_result.error_message
error_type = scenario_result.error_type

# Fallback: look up error from persisted error AttackResults
if not error and scenario_result.error_attack_result_ids:
error_ars = self._memory.get_attack_results(attack_result_ids=scenario_result.error_attack_result_ids)
# Fallback: look up error from any persisted error AttackResults linked
# to this scenario via the new FK.
if not error:
error_ars = self._memory.get_attack_results(
scenario_result_id=scenario_result_id,
outcome=AttackOutcome.ERROR,
)
if error_ars:
error = error_ars[0].error_message
error_type = error_ars[0].error_type
Expand Down
33 changes: 29 additions & 4 deletions pyrit/executor/attack/core/attack_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""

import asyncio
from collections.abc import Iterator, Sequence
from collections.abc import Callable, Iterator, Sequence
from dataclasses import dataclass, field
from typing import (
TYPE_CHECKING,
Expand All @@ -24,6 +24,7 @@
AttackStrategyContextT,
AttackStrategyResultT,
)
from pyrit.executor.attack.core.scenario_execution_attribution import ScenarioExecutionAttribution
from pyrit.models import SeedAttackGroup

if TYPE_CHECKING:
Expand Down Expand Up @@ -142,6 +143,7 @@ async def execute_attack_from_seed_groups_async(
objective_scorer: Optional["TrueFalseScorer"] = None,
field_overrides: Optional[Sequence[dict[str, Any]]] = None,
return_partial_on_failure: bool = False,
attribution_factory: Optional[Callable[[int], ScenarioExecutionAttribution]] = None,
**broadcast_fields: Any,
) -> AttackExecutorResult[AttackStrategyResultT]:
"""
Expand All @@ -163,6 +165,12 @@ async def execute_attack_from_seed_groups_async(
from_seed_group() as overrides.
return_partial_on_failure: If True, returns partial results when some
objectives fail. If False (default), raises the first exception.
attribution_factory: Optional callable that maps an input index (the
seed group's original index, parallel-safe and deterministic) to
a ``ScenarioExecutionAttribution``. When provided, each per-task
``AttackContext`` is stamped with the attribution so the
resulting ``AttackResultEntry`` row carries the scenario FK +
scenario_data. When ``None``, no attribution is applied.
**broadcast_fields: Fields applied to all seed groups (e.g., memory_labels).
Per-seed-group field_overrides take precedence.

Expand Down Expand Up @@ -205,6 +213,7 @@ async def build_params(i: int, sg: SeedAttackGroup) -> AttackParameters:
attack=attack,
params_list=params_list,
return_partial_on_failure=return_partial_on_failure,
attribution_factory=attribution_factory,
)

async def execute_attack_async(
Expand All @@ -214,6 +223,7 @@ async def execute_attack_async(
objectives: Sequence[str],
field_overrides: Optional[Sequence[dict[str, Any]]] = None,
return_partial_on_failure: bool = False,
attribution_factory: Optional[Callable[[int], ScenarioExecutionAttribution]] = None,
**broadcast_fields: Any,
) -> AttackExecutorResult[AttackStrategyResultT]:
"""
Expand All @@ -228,6 +238,10 @@ async def execute_attack_async(
must match the length of objectives.
return_partial_on_failure: If True, returns partial results when some
objectives fail. If False (default), raises the first exception.
attribution_factory: Optional callable mapping each input index to
a ScenarioExecutionAttribution. When provided, the per-task context is
stamped with the attribution so the persistence path can record
scenario linkage.
Comment thread
rlundeen2 marked this conversation as resolved.
Outdated
**broadcast_fields: Fields applied to all objectives (e.g., memory_labels).
Per-objective field_overrides take precedence.

Expand Down Expand Up @@ -268,6 +282,7 @@ async def execute_attack_async(
attack=attack,
params_list=params_list,
return_partial_on_failure=return_partial_on_failure,
attribution_factory=attribution_factory,
)

async def _execute_with_params_list_async(
Expand All @@ -276,6 +291,7 @@ async def _execute_with_params_list_async(
attack: AttackStrategy[AttackStrategyContextT, AttackStrategyResultT],
params_list: Sequence[AttackParameters],
return_partial_on_failure: bool = False,
attribution_factory: Optional[Callable[[int], ScenarioExecutionAttribution]] = None,
) -> AttackExecutorResult[AttackStrategyResultT]:
"""
Execute attacks in parallel with a list of pre-built parameters.
Expand All @@ -287,19 +303,28 @@ async def _execute_with_params_list_async(
attack: The attack strategy to execute.
params_list: List of AttackParameters, one per execution.
return_partial_on_failure: If True, returns partial results on failure.
attribution_factory: Optional callable mapping each input index to
a ScenarioExecutionAttribution. When provided, the per-task context is
stamped with the attribution so the persistence path can record
scenario linkage.

Returns:
AttackExecutorResult with completed results and any incomplete objectives.
"""
semaphore = asyncio.Semaphore(self._max_concurrency)

async def run_one(params: AttackParameters) -> AttackStrategyResultT:
async def run_one(index: int, params: AttackParameters) -> AttackStrategyResultT:
async with semaphore:
# Create context with params
# Create context with params and stamp attribution (if any). The
# input index is the seed group's original position and is
# deterministic and parallel-safe — assigned BEFORE the task
# runs, not from completion order.
context = attack._context_type(params=params)
if attribution_factory is not None:
context._attribution = attribution_factory(index)
return await attack.execute_with_context_async(context=context)

tasks = [run_one(p) for p in params_list]
tasks = [run_one(i, p) for i, p in enumerate(params_list)]
results_or_exceptions = await asyncio.gather(*tasks, return_exceptions=True)

return self._process_execution_results(
Expand Down
51 changes: 43 additions & 8 deletions pyrit/executor/attack/core/attack_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

if TYPE_CHECKING:
from pyrit.executor.attack.core.attack_config import AttackScoringConfig
from pyrit.executor.attack.core.scenario_execution_attribution import ScenarioExecutionAttribution
from pyrit.prompt_target import PromptTarget

AttackStrategyContextT = TypeVar("AttackStrategyContextT", bound="AttackContext[Any]")
Expand Down Expand Up @@ -70,8 +71,12 @@ class AttackContext(StrategyContext, ABC, Generic[AttackParamsT]):
_prepended_conversation_override: Optional[list[Message]] = None
_memory_labels_override: Optional[dict[str, str]] = None

# Set by the ON_ERROR handler to link error AttackResults to ScenarioResults
_error_attack_result_id: str | None = None
# Optional attribution from an upstream orchestrator (e.g. Scenario). When
# set, the persistence path stamps scenario_result_id + scenario_data onto
# the resulting AttackResult so it can be located later for hydration and
# resume. Set by AttackExecutor per-task before scheduling. Stays None for
# ad-hoc/direct attack execution outside any scenario.
_attribution: Optional[ScenarioExecutionAttribution] = None
Comment thread
rlundeen2 marked this conversation as resolved.
Outdated

# Convenience properties that delegate to params or overrides
@property
Expand Down Expand Up @@ -223,11 +228,43 @@ async def _on_post_execute(
event_data.result.retry_events = collector.events
event_data.result.total_retries = len(collector.events)

# Stamp scenario attribution onto the result before persistence so the
# AttackResultEntry row carries the FK + scenario_data. Outside scenarios
# _attribution is None and both fields stay None.
self._stamp_attribution(context=event_data.context, result=event_data.result)

self._logger.debug(f"Attack execution completed in {execution_time_ms}ms")

self._log_attack_outcome(event_data.result)
self._memory.add_attack_results_to_memory(attack_results=[event_data.result])

@staticmethod
def _stamp_attribution(
Comment thread
rlundeen2 marked this conversation as resolved.
Outdated
*,
context: AttackStrategyContextT,
result: AttackResult,
) -> None:
"""
Copy scenario attribution from the AttackContext onto the AttackResult.

Reads ``context._attribution`` (a ``ScenarioExecutionAttribution`` set by the
AttackExecutor when running inside a Scenario). When present, writes
``scenario_result_id`` and a fixed-schema ``scenario_data`` dict onto
the result so they round-trip into ``AttackResultEntry``.

Args:
context: The per-task AttackContext.
result: The AttackResult that is about to be persisted.
"""
attribution = context._attribution
if attribution is None:
return
result.scenario_result_id = attribution.scenario_result_id
result.scenario_data = {
"atomic_attack_name": attribution.atomic_attack_name,
"objective_index": attribution.objective_index,
}

def _log_attack_outcome(self, result: AttackResult) -> None:
"""
Log the outcome of the attack.
Expand Down Expand Up @@ -267,9 +304,6 @@ async def _on_error_async(
if not error or not context:
return

# Clear any stale ID from a previous execution
context._error_attack_result_id = None

# Collect retry events (visible via inherited ContextVar copy)
collector = get_retry_collector()
retry_events = collector.events if collector else []
Expand All @@ -295,10 +329,11 @@ async def _on_error_async(
if context.start_time:
error_result.execution_time_ms = int((end_time - context.start_time) * 1000)

# Persist first, then set the ID on the context so scenario-level code
# only sees the reference if the write succeeded.
# Stamp scenario attribution onto the error result so it is locatable
# via the scenario FK on resume.
self._stamp_attribution(context=context, result=error_result)

self._memory.add_attack_results_to_memory(attack_results=[error_result])
context._error_attack_result_id = error_result.attack_result_id

self._logger.error(f"Attack failed with {type(error).__name__}: {error}")

Expand Down
43 changes: 43 additions & 0 deletions pyrit/executor/attack/core/scenario_execution_attribution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Typed attribution metadata used to link a persisted ``AttackResult`` to the
``Scenario`` that produced it.

Lives in the ``executor`` layer (rather than ``scenario``) so the attack
persistence path — the consumer — does not introduce a dependency on
``pyrit.scenario``.
"""

from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class ScenarioExecutionAttribution:
"""
Scenario-linkage metadata stamped onto an ``AttackContext`` by the
``AttackExecutor`` and copied onto the resulting ``AttackResult`` by the
attack persistence path so the row carries the scenario FK + scenario_data.

Attributes:
scenario_result_id (str): The ID of the scenario result that produced
this attack execution. Persisted to
``AttackResultEntry.scenario_result_id`` so per-scenario hydration
and resume can locate the row directly without relying on a JSON
manifest written at the end of an atomic attack.
atomic_attack_name (str): The unique key of the atomic attack within
the scenario (matches ``AtomicAttack.atomic_attack_name``).
Persisted into ``AttackResultEntry.scenario_data``.
objective_index (int): The 0-based original seed-group index (the
``input_indices`` value from ``AttackExecutorResult``). Assigned
**before** task execution so it is deterministic and parallel-safe.
Persisted into ``AttackResultEntry.scenario_data`` and used as the
stable resume key (instead of the easily-duplicated objective text).
"""

scenario_result_id: str
atomic_attack_name: str
objective_index: int
10 changes: 2 additions & 8 deletions pyrit/executor/core/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,7 @@


class _StrategyRuntimeError(RuntimeError):
"""RuntimeError subclass that carries an optional error_attack_result_id."""

def __init__(self, message: str, *, error_attack_result_id: str | None = None) -> None:
super().__init__(message)
self.error_attack_result_id = error_attack_result_id
"""RuntimeError subclass for strategy execution failures."""


@dataclass
Expand Down Expand Up @@ -386,9 +382,7 @@ async def execute_with_context_async(self, *, context: StrategyContextT) -> Stra
else:
error_message = f"Strategy execution failed for {self.__class__.__name__}: {str(e)}"

# Attach the error attack result ID if the ON_ERROR handler created one
error_attack_result_id = getattr(context, "_error_attack_result_id", None)
runtime_error = _StrategyRuntimeError(error_message, error_attack_result_id=error_attack_result_id)
runtime_error = _StrategyRuntimeError(error_message)
raise runtime_error from e

async def execute_async(self, **kwargs: Any) -> StrategyResultT:
Expand Down
Loading
Loading