Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions .github/instructions/style-guide.instructions.md
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,26 @@ def process_items(self, *, items: list[str]) -> list[str]:
return []
```

## Deprecations

When deprecating a public class, function, method, parameter, or module path, use `pyrit.common.deprecation.print_deprecation_message` — never `warnings.warn` directly. It wraps `warnings.warn(..., DeprecationWarning, stacklevel=3)` with a consistent format so filtering still works.

Set `removed_in` to **current version + 2 minor versions** (e.g. `0.14.x` → `removed_in="0.16.0"`). This gives one full release cycle of warning before removal.

```python
from pyrit.common.deprecation import print_deprecation_message

def old_method(self, *, foo: str) -> None:
print_deprecation_message(
old_item="MyClass.old_method",
new_item="MyClass.new_method",
removed_in="0.16.0",
)
...
```

`old_item` / `new_item` accept a class/callable (qualified name is generated) or a string.

## Pythonic Patterns

### List Comprehensions
Expand Down
10 changes: 7 additions & 3 deletions pyrit/backend/services/scenario_run_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,9 +409,13 @@ def _build_response_from_db(self, *, scenario_result: ScenarioResult) -> Scenari
error = scenario_result.error_message
error_type = scenario_result.error_type

# Fallback: look up error from persisted error AttackResults
if not error and scenario_result.error_attack_result_ids:
error_ars = self._memory.get_attack_results(attack_result_ids=scenario_result.error_attack_result_ids)
# Fallback: look up error from any persisted error AttackResults linked
# to this scenario via the new attribution_parent_id foreign key.
if not error:
error_ars = self._memory.get_attack_results(
scenario_result_id=scenario_result_id,
outcome=AttackOutcome.ERROR,
)
if error_ars:
error = error_ars[0].error_message
error_type = error_ars[0].error_type
Expand Down
34 changes: 30 additions & 4 deletions pyrit/executor/attack/core/attack_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""

import asyncio
from collections.abc import Iterator, Sequence
from collections.abc import Callable, Iterator, Sequence
from dataclasses import dataclass, field
from typing import (
TYPE_CHECKING,
Expand All @@ -19,6 +19,7 @@
)

from pyrit.executor.attack.core.attack_parameters import AttackParameters
from pyrit.executor.attack.core.attack_result_attribution import AttackResultAttribution
from pyrit.executor.attack.core.attack_strategy import (
AttackStrategy,
AttackStrategyContextT,
Expand Down Expand Up @@ -142,6 +143,7 @@ async def execute_attack_from_seed_groups_async(
objective_scorer: Optional["TrueFalseScorer"] = None,
field_overrides: Optional[Sequence[dict[str, Any]]] = None,
return_partial_on_failure: bool = False,
attribution_factory: Optional[Callable[[int], AttackResultAttribution]] = None,
**broadcast_fields: Any,
) -> AttackExecutorResult[AttackStrategyResultT]:
"""
Expand All @@ -163,6 +165,13 @@ async def execute_attack_from_seed_groups_async(
from_seed_group() as overrides.
return_partial_on_failure: If True, returns partial results when some
objectives fail. If False (default), raises the first exception.
attribution_factory: Optional callable that maps an input index (the
seed group's original index, parallel-safe and deterministic) to
an ``AttackResultAttribution``. When provided, each per-task
``AttackContext`` is stamped with the attribution so the
resulting ``AttackResultEntry`` row carries
``attribution_parent_id`` + ``attribution_data``. When ``None``,
no attribution is applied.
**broadcast_fields: Fields applied to all seed groups (e.g., memory_labels).
Per-seed-group field_overrides take precedence.

Expand Down Expand Up @@ -205,6 +214,7 @@ async def build_params(i: int, sg: SeedAttackGroup) -> AttackParameters:
attack=attack,
params_list=params_list,
return_partial_on_failure=return_partial_on_failure,
attribution_factory=attribution_factory,
)

async def execute_attack_async(
Expand All @@ -214,6 +224,7 @@ async def execute_attack_async(
objectives: Sequence[str],
field_overrides: Optional[Sequence[dict[str, Any]]] = None,
return_partial_on_failure: bool = False,
attribution_factory: Optional[Callable[[int], AttackResultAttribution]] = None,
**broadcast_fields: Any,
) -> AttackExecutorResult[AttackStrategyResultT]:
"""
Expand All @@ -228,6 +239,10 @@ async def execute_attack_async(
must match the length of objectives.
return_partial_on_failure: If True, returns partial results when some
objectives fail. If False (default), raises the first exception.
attribution_factory: Optional callable mapping each input index to
a AttackResultAttribution. When provided, the per-task context is
stamped with the attribution so the persistence path can record
scenario linkage.
Comment thread
rlundeen2 marked this conversation as resolved.
Outdated
**broadcast_fields: Fields applied to all objectives (e.g., memory_labels).
Per-objective field_overrides take precedence.

Expand Down Expand Up @@ -268,6 +283,7 @@ async def execute_attack_async(
attack=attack,
params_list=params_list,
return_partial_on_failure=return_partial_on_failure,
attribution_factory=attribution_factory,
)

async def _execute_with_params_list_async(
Expand All @@ -276,6 +292,7 @@ async def _execute_with_params_list_async(
attack: AttackStrategy[AttackStrategyContextT, AttackStrategyResultT],
params_list: Sequence[AttackParameters],
return_partial_on_failure: bool = False,
attribution_factory: Optional[Callable[[int], AttackResultAttribution]] = None,
) -> AttackExecutorResult[AttackStrategyResultT]:
"""
Execute attacks in parallel with a list of pre-built parameters.
Expand All @@ -287,19 +304,28 @@ async def _execute_with_params_list_async(
attack: The attack strategy to execute.
params_list: List of AttackParameters, one per execution.
return_partial_on_failure: If True, returns partial results on failure.
attribution_factory: Optional callable mapping each input index to
a AttackResultAttribution. When provided, the per-task context is
stamped with the attribution so the persistence path can record
scenario linkage.

Returns:
AttackExecutorResult with completed results and any incomplete objectives.
"""
semaphore = asyncio.Semaphore(self._max_concurrency)

async def run_one(params: AttackParameters) -> AttackStrategyResultT:
async def run_one(index: int, params: AttackParameters) -> AttackStrategyResultT:
async with semaphore:
# Create context with params
# Create context with params and stamp attribution (if any). The
# input index is the seed group's original position and is
# deterministic and parallel-safe — assigned BEFORE the task
# runs, not from completion order.
context = attack._context_type(params=params)
if attribution_factory is not None:
context._attribution = attribution_factory(index)
return await attack.execute_with_context_async(context=context)

tasks = [run_one(p) for p in params_list]
tasks = [run_one(i, p) for i, p in enumerate(params_list)]
results_or_exceptions = await asyncio.gather(*tasks, return_exceptions=True)

return self._process_execution_results(
Expand Down
51 changes: 51 additions & 0 deletions pyrit/executor/attack/core/attack_result_attribution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Generic attribution metadata that an upstream orchestrator can stamp onto an
``AttackContext`` so the persisted ``AttackResult`` carries the linkage back
to whatever produced it.

The attack layer treats this as opaque infrastructure — three string-typed
fields, no scenario semantics. The orchestrator (e.g. ``Scenario``) interprets
them however it likes. Keeping the type in ``executor`` rather than
``scenario`` means the persistence path has no dependency on the
``pyrit.scenario`` package.
Comment thread
rlundeen2 marked this conversation as resolved.
Outdated
"""

from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class AttackResultAttribution:
"""
Attribution stamped onto an ``AttackContext`` by an upstream caller (the
``AttackExecutor`` populates it via an ``attribution_factory``) and copied
onto the resulting ``AttackResult`` by the attack persistence path so the
DB row records its lineage.

All three fields are opaque to the attack layer. The orchestrator chooses
what they mean and how to query them back later. For example,
``Scenario`` uses ``parent_id`` for the scenario result UUID,
``parent_collection`` for the atomic attack name, and ``position`` for
the original 0-based seed-group index.
Comment thread
rlundeen2 marked this conversation as resolved.
Outdated

Attributes:
parent_id (str): The ID of the parent entity that owns this attack
execution. Persisted to ``AttackResultEntry.attribution_parent_id``
and indexed (with a foreign key to ``ScenarioResultEntries.id``)
so per-parent hydration and resume lookups are direct.
parent_collection (str): A free-form label naming the per-parent
collection this result belongs to. Persisted into
``AttackResultEntry.attribution_data``.
position (int): The 0-based position of this result within its
``parent_collection``. Assigned **before** task execution so it is
deterministic and parallel-safe, and used as the stable resume key.
Persisted into ``AttackResultEntry.attribution_data``.
"""

parent_id: str
parent_collection: str
position: int
52 changes: 44 additions & 8 deletions pyrit/executor/attack/core/attack_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

if TYPE_CHECKING:
from pyrit.executor.attack.core.attack_config import AttackScoringConfig
from pyrit.executor.attack.core.attack_result_attribution import AttackResultAttribution
from pyrit.prompt_target import PromptTarget

AttackStrategyContextT = TypeVar("AttackStrategyContextT", bound="AttackContext[Any]")
Expand Down Expand Up @@ -70,8 +71,12 @@ class AttackContext(StrategyContext, ABC, Generic[AttackParamsT]):
_prepended_conversation_override: Optional[list[Message]] = None
_memory_labels_override: Optional[dict[str, str]] = None

# Set by the ON_ERROR handler to link error AttackResults to ScenarioResults
_error_attack_result_id: str | None = None
# Optional attribution from an upstream orchestrator (e.g. Scenario). When
# set, the persistence path stamps attribution_parent_id + attribution_data
# onto the resulting AttackResult so it can be located later for hydration
# and resume. Set by AttackExecutor per-task before scheduling. Stays None
# for ad-hoc/direct attack execution outside any orchestrator.
_attribution: Optional[AttackResultAttribution] = None

# Convenience properties that delegate to params or overrides
@property
Expand Down Expand Up @@ -223,11 +228,44 @@ async def _on_post_execute(
event_data.result.retry_events = collector.events
event_data.result.total_retries = len(collector.events)

# Stamp attribution onto the result before persistence so the
# AttackResultEntry row records its lineage. Outside an orchestrator
# _attribution is None and both attribution fields stay None.
self._stamp_attribution(context=event_data.context, result=event_data.result)

self._logger.debug(f"Attack execution completed in {execution_time_ms}ms")

self._log_attack_outcome(event_data.result)
self._memory.add_attack_results_to_memory(attack_results=[event_data.result])

@staticmethod
def _stamp_attribution(
Comment thread
rlundeen2 marked this conversation as resolved.
Outdated
*,
context: AttackStrategyContextT,
result: AttackResult,
) -> None:
"""
Copy attribution from the AttackContext onto the AttackResult.

Reads ``context._attribution`` (an ``AttackResultAttribution`` set by
the AttackExecutor when an upstream orchestrator supplied a factory).
When present, writes ``attribution_parent_id`` and a fixed-schema
``attribution_data`` dict onto the result so they round-trip into
``AttackResultEntry``.

Args:
context: The per-task AttackContext.
result: The AttackResult that is about to be persisted.
"""
attribution = context._attribution
if attribution is None:
return
result.attribution_parent_id = attribution.parent_id
result.attribution_data = {
"parent_collection": attribution.parent_collection,
"position": attribution.position,
}

def _log_attack_outcome(self, result: AttackResult) -> None:
"""
Log the outcome of the attack.
Expand Down Expand Up @@ -267,9 +305,6 @@ async def _on_error_async(
if not error or not context:
return

# Clear any stale ID from a previous execution
context._error_attack_result_id = None

# Collect retry events (visible via inherited ContextVar copy)
collector = get_retry_collector()
retry_events = collector.events if collector else []
Expand All @@ -295,10 +330,11 @@ async def _on_error_async(
if context.start_time:
error_result.execution_time_ms = int((end_time - context.start_time) * 1000)

# Persist first, then set the ID on the context so scenario-level code
# only sees the reference if the write succeeded.
# Stamp attribution onto the error result so it is locatable via the
# attribution_parent_id foreign key on resume.
self._stamp_attribution(context=context, result=error_result)

self._memory.add_attack_results_to_memory(attack_results=[error_result])
context._error_attack_result_id = error_result.attack_result_id

self._logger.error(f"Attack failed with {type(error).__name__}: {error}")

Expand Down
10 changes: 2 additions & 8 deletions pyrit/executor/core/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,7 @@


class _StrategyRuntimeError(RuntimeError):
"""RuntimeError subclass that carries an optional error_attack_result_id."""

def __init__(self, message: str, *, error_attack_result_id: str | None = None) -> None:
super().__init__(message)
self.error_attack_result_id = error_attack_result_id
"""RuntimeError subclass for strategy execution failures."""


@dataclass
Expand Down Expand Up @@ -386,9 +382,7 @@ async def execute_with_context_async(self, *, context: StrategyContextT) -> Stra
else:
error_message = f"Strategy execution failed for {self.__class__.__name__}: {str(e)}"

# Attach the error attack result ID if the ON_ERROR handler created one
error_attack_result_id = getattr(context, "_error_attack_result_id", None)
runtime_error = _StrategyRuntimeError(error_message, error_attack_result_id=error_attack_result_id)
runtime_error = _StrategyRuntimeError(error_message)
raise runtime_error from e

async def execute_async(self, **kwargs: Any) -> StrategyResultT:
Expand Down
Loading