Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
718 changes: 718 additions & 0 deletions doc/code/targets/round_robin_target.ipynb

Large diffs are not rendered by default.

271 changes: 271 additions & 0 deletions doc/code/targets/round_robin_target.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
# ---
# jupyter:
# jupytext:
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.19.1
# ---

# %% [markdown]
# # Round Robin Target
#
# The `RoundRobinTarget` distributes requests across multiple inner targets using weighted round-robin
# selection. This is useful for load-balancing across multiple deployments of the same model (e.g.,
# Azure OpenAI endpoints in different regions) to avoid rate limits or spread cost.
#
# **Key considerations:**
# - All inner targets must be the same concrete class (e.g., all `OpenAIChatTarget`).
# - All inner targets must have identical TargetConfigurations (capabilities, policy, and normalization pipeline)
# - All inner targets must support multi-turn conversations and editable history.
# - Inner targets must have the same behavioral parameters (model, temperature, top_p) used for evaluation hashing. This allows
# users to evaluate round-robin targets for scoring and attack evaluation with confidence that results are comparable to using the
# inner targets directly.
# - Requests are distributed per-call, not per-conversation — any target can handle any turn.
# - Memory entries use the round-robin's identifier. The inner target that handled each
# request is recorded in `prompt_metadata["inner_target_identifier"]`.
# - Optional integer weights control the distribution ratio.

# %% [markdown]
# ## Basic Usage
#
# In this example, we create two `OpenAIChatTarget` instances pointing to different endpoints
# (simulating two regional deployments of the same model) and wrap them in a `RoundRobinTarget`.
# We then send multiple prompts and show which inner target handled each one.

# %%
import os

from pyrit.auth import get_azure_openai_auth
from pyrit.models import Message
from pyrit.prompt_normalizer import PromptNormalizer
from pyrit.prompt_target import OpenAIChatTarget, RoundRobinTarget
from pyrit.setup import IN_MEMORY, initialize_pyrit_async

await initialize_pyrit_async(memory_db_type=IN_MEMORY) # type: ignore

# Create two targets pointing to different regional deployments of the same model.
endpoint_a = os.environ["AZURE_OPENAI_GPT4O_ENDPOINT"]
endpoint_b = os.environ["AZURE_OPENAI_GPT4O_ENDPOINT2"]

target_a = OpenAIChatTarget(
endpoint=endpoint_a,
api_key=get_azure_openai_auth(endpoint_a),
model_name=os.environ["AZURE_OPENAI_GPT4O_MODEL"],
underlying_model=os.environ["AZURE_OPENAI_GPT4O_UNDERLYING_MODEL"],
)
target_b = OpenAIChatTarget(
endpoint=endpoint_b,
api_key=get_azure_openai_auth(endpoint_b),
model_name=os.environ["AZURE_OPENAI_GPT4O_MODEL2"],
underlying_model=os.environ["AZURE_OPENAI_GPT4O_UNDERLYING_MODEL2"],
)

# Wrap them in a RoundRobinTarget
rr_target = RoundRobinTarget(targets=[target_a, target_b])

# Send 4 prompts and observe the round-robin distribution
normalizer = PromptNormalizer()
prompts = [
"What is 2 + 2?",
"What color is the sky?",
"Name a prime number.",
"What is the capital of France?",
]

for i, prompt in enumerate(prompts):
message = Message.from_prompt(prompt=prompt, role="user")
response = await normalizer.send_prompt_async(message=message, target=rr_target) # type: ignore

# Show which inner target handled this request
inner_hash = response.message_pieces[0].prompt_metadata.get("inner_target_identifier", "N/A")
target_label = "Target A" if inner_hash == target_a.get_identifier().hash else "Target B"
Comment thread
jsong468 marked this conversation as resolved.
print(f"Prompt {i + 1}: '{prompt}' → handled by {target_label}")
print(f" Response: {response.message_pieces[0].converted_value[:80]}...")
print()

# %% [markdown]
# ## Weighted Distribution
#
# You can pass `weights` to control the distribution ratio. For example, `weights=[2, 1]`
# sends roughly twice as many requests to the first target. This is useful when one
# deployment has higher rate limits or capacity.

# %%
await initialize_pyrit_async(memory_db_type=IN_MEMORY) # type: ignore

target_a = OpenAIChatTarget(
endpoint=endpoint_a,
api_key=get_azure_openai_auth(endpoint_a),
model_name=os.environ["AZURE_OPENAI_GPT4O_MODEL"],
underlying_model=os.environ["AZURE_OPENAI_GPT4O_UNDERLYING_MODEL"],
)
target_b = OpenAIChatTarget(
endpoint=endpoint_b,
api_key=get_azure_openai_auth(endpoint_b),
model_name=os.environ["AZURE_OPENAI_GPT4O_MODEL2"],
underlying_model=os.environ["AZURE_OPENAI_GPT4O_UNDERLYING_MODEL2"],
)

# Target A gets 2x the traffic
rr_weighted = RoundRobinTarget(targets=[target_a, target_b], weights=[2, 1])

normalizer = PromptNormalizer()
prompts = ["Prompt 1", "Prompt 2", "Prompt 3", "Prompt 4", "Prompt 5", "Prompt 6"]

target_a_hash = target_a.get_identifier().hash
counts = {"Target A": 0, "Target B": 0}

for prompt in prompts:
message = Message.from_prompt(prompt=prompt, role="user")
response = await normalizer.send_prompt_async(message=message, target=rr_weighted) # type: ignore
inner_hash = response.message_pieces[0].prompt_metadata.get("inner_target_identifier", "N/A")
label = "Target A" if inner_hash == target_a_hash else "Target B"
counts[label] += 1
print(f" '{prompt}' → {label}")

print(f"\nDistribution: Target A = {counts['Target A']}, Target B = {counts['Target B']}")

# %% [markdown]
# ## Multi-Turn Attack (Crescendo)
#
# The `RoundRobinTarget` works seamlessly with multi-turn attacks like Crescendo. Because
# round-robin targets require editable history, any inner target can reconstruct the full
# conversation from shared memory on each turn. This means different turns of the same
# conversation may be handled by different inner targets — true load-balancing even within
# a single multi-turn interaction.
#
# Note that using a `RoundRobinTarget` within a multi-turn attack can lead to greater API costs
# due to loss of prompt caching. For multi-turn attacks like Crescendo with many objectives, this
# can significantly increase API cost compared to pinning each conversation to a single target.
# This is a cost/latency vs. throughput trade-off — round-robin avoids per-endpoint rate limits at
# the expense of server-side caching. Users who need cache-efficient multi-turn conversations should
# assign individual targets at the attack or scenario level rather than using round-robin for those workloads.
#

# %%
from pyrit.executor.attack import AttackAdversarialConfig, CrescendoAttack
from pyrit.output import output_attack_async

await initialize_pyrit_async(memory_db_type=IN_MEMORY) # type: ignore

# The objective target is a round-robin across two deployments
target_a = OpenAIChatTarget(
endpoint=endpoint_a,
api_key=get_azure_openai_auth(endpoint_a),
model_name=os.environ["AZURE_OPENAI_GPT4O_MODEL"],
underlying_model=os.environ["AZURE_OPENAI_GPT4O_UNDERLYING_MODEL"],
)
target_b = OpenAIChatTarget(
endpoint=endpoint_b,
api_key=get_azure_openai_auth(endpoint_b),
model_name=os.environ["AZURE_OPENAI_GPT4O_MODEL2"],
underlying_model=os.environ["AZURE_OPENAI_GPT4O_UNDERLYING_MODEL2"],
)
rr_target = RoundRobinTarget(targets=[target_a, target_b])

# The adversarial chat (used to generate attack prompts) can also be a round-robin,
# or a single target — it's independent of the objective target.
adversarial_chat = OpenAIChatTarget(
endpoint=endpoint_a,
api_key=get_azure_openai_auth(endpoint_a),
model_name=os.environ["AZURE_OPENAI_GPT4O_MODEL"],
underlying_model=os.environ["AZURE_OPENAI_GPT4O_UNDERLYING_MODEL"],
)
adversarial_config = AttackAdversarialConfig(target=adversarial_chat)

attack = CrescendoAttack(
objective_target=rr_target,
attack_adversarial_config=adversarial_config,
max_turns=5,
max_backtracks=3,
)

result = await attack.execute_async( # type: ignore
objective="Explain how to make a paper airplane that flies really far"
)
await output_attack_async(result)

# %% [markdown]
# ## Scoring Multiple Responses
#
# When using a `RoundRobinTarget` as the scorer's chat target, the scoring requests
# themselves are load-balanced across the inner targets. This is especially useful when
# scoring many responses in batch — the round-robin distributes the scoring LLM calls
# across deployments.

# %%
from pyrit.executor.attack import AttackExecutor, PromptSendingAttack
from pyrit.score import ContentClassifierPaths, SelfAskCategoryScorer

await initialize_pyrit_async(memory_db_type=IN_MEMORY) # type: ignore

# Step 1: Generate some responses to score using a simple attack
objective_target = OpenAIChatTarget(
endpoint=endpoint_a,
api_key=get_azure_openai_auth(endpoint_a),
model_name=os.environ["AZURE_OPENAI_GPT4O_MODEL"],
underlying_model=os.environ["AZURE_OPENAI_GPT4O_UNDERLYING_MODEL"],
)
attack = PromptSendingAttack(objective_target=objective_target)

objectives = [
"What is photosynthesis?",
"Tell me a joke about cats.",
"Explain how computers work.",
"What is the meaning of life?",
]

results = await AttackExecutor().execute_attack_async( # type: ignore
attack=attack,
objectives=objectives,
)

# Step 2: Score all responses using a round-robin scorer target
# The scorer's LLM calls are distributed across both targets
scorer_target_a = OpenAIChatTarget(
endpoint=endpoint_a,
api_key=get_azure_openai_auth(endpoint_a),
model_name=os.environ["AZURE_OPENAI_GPT4O_MODEL"],
underlying_model=os.environ["AZURE_OPENAI_GPT4O_UNDERLYING_MODEL"],
)
scorer_target_b = OpenAIChatTarget(
endpoint=endpoint_b,
api_key=get_azure_openai_auth(endpoint_b),
model_name=os.environ["AZURE_OPENAI_GPT4O_MODEL2"],
underlying_model=os.environ["AZURE_OPENAI_GPT4O_UNDERLYING_MODEL2"],
)
rr_scorer_target = RoundRobinTarget(targets=[scorer_target_a, scorer_target_b], weights=[2, 1])

scorer = SelfAskCategoryScorer(
chat_target=rr_scorer_target,
content_classifier_path=ContentClassifierPaths.HARMFUL_CONTENT_CLASSIFIER.value,
)

# Collect response messages for scoring
response_messages = [r.last_response.to_message() for r in results if r.last_response is not None]

scorer_target_a_hash = scorer_target_a.get_identifier().hash

# Score each response individually so we can track and print which scorer target handled it
# You may want to use `score_prompts_batch_async` like below in practice for efficiency
# await scorer.score_prompts_batch_async(messages=response_messages) # type: ignore
for i, response_message in enumerate(response_messages):
scores = await scorer.score_async(message=response_message) # type: ignore

# The scorer's internal LLM response has inner_target_identifier in metadata.
# We can check the round-robin counter to determine which target was used.
# Since set_system_prompt and send_prompt_async each call _next_target(),
# the counter advances by 2 per scoring call (1 for system prompt, 1 for send).
# We use the counter to show the alternation pattern.
target_idx = rr_scorer_target._rotation[(rr_scorer_target._counter - 1) % len(rr_scorer_target._rotation)]
scorer_label = "Scorer Target A" if target_idx == 0 else "Scorer Target B"

for score in scores:
print(
f"Prompt {i + 1} scored by {scorer_label} | "
f"Value: {score.get_value()} | "
f"Category: {score.score_category} | "
f"Rationale: {score.score_rationale[:60]}"
)
1 change: 1 addition & 0 deletions doc/myst.yml
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ project:
- file: code/targets/prompt_shield_target.ipynb
- file: code/targets/realtime_target.ipynb
- file: code/targets/use_huggingface_chat_target.ipynb
- file: code/targets/round_robin_target.ipynb
- file: code/converters/0_converters.ipynb
children:
- file: code/converters/1_text_to_text_converters.ipynb
Expand Down
4 changes: 4 additions & 0 deletions pyrit/identifiers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
)
from pyrit.identifiers.component_identifier import ComponentIdentifier, Identifiable, config_hash
from pyrit.identifiers.evaluation_identifier import (
TARGET_EVAL_PARAM_FALLBACKS,
TARGET_EVAL_PARAMS,
AtomicAttackEvaluationIdentifier,
ChildEvalRule,
EvaluationIdentifier,
Expand All @@ -36,6 +38,8 @@
"REGISTRY_NAME_PATTERN",
"ScorerEvaluationIdentifier",
"snake_case_to_class_name",
"TARGET_EVAL_PARAM_FALLBACKS",
"TARGET_EVAL_PARAMS",
"validate_registry_name",
"config_hash",
"IdentifierFilter",
Expand Down
33 changes: 29 additions & 4 deletions pyrit/identifiers/evaluation_identifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@

from pyrit.identifiers.component_identifier import ComponentIdentifier, config_hash

# Behavioral params that define model output quality for scoring.
TARGET_EVAL_PARAMS: frozenset[str] = frozenset({"underlying_model_name", "temperature", "top_p"})
TARGET_EVAL_PARAM_FALLBACKS: dict[str, str] = {"underlying_model_name": "model_name"}


@dataclass(frozen=True)
class ChildEvalRule:
Expand All @@ -45,12 +49,18 @@ class ChildEvalRule:
missing), the fallback key's value from the component's raw params
is used instead. This keeps fallback logic in the eval layer without
changing full component hashes. ``None`` means no fallbacks.
* ``inner_child_name`` — if set, names the sub-child to "look through"
when the child being processed is a wrapper component (e.g.,
``RoundRobinTarget``). The first item of that sub-child list is
substituted before applying param filtering, so the eval hash
matches the unwrapped inner target. ``None`` means no unwrapping.
"""

exclude: bool = False
included_params: Optional[frozenset[str]] = None
included_item_values: Optional[dict[str, Any]] = field(default=None)
param_fallbacks: Optional[dict[str, str]] = field(default=None)
inner_child_name: Optional[str] = field(default=None)


def _build_eval_dict(
Expand Down Expand Up @@ -115,6 +125,19 @@ def _build_eval_dict(

child_list = identifier.get_child_list(name)

# Inner child lookup: if the rule names a sub-child (e.g., "targets"),
# substitute the first item of that sub-child list. This lets wrapper
# components (e.g., RoundRobinTarget) be "seen through".
if rule and rule.inner_child_name:
unwrapped: list[ComponentIdentifier] = []
for c in child_list:
inner = c.get_child_list(rule.inner_child_name)
if inner:
unwrapped.append(inner[0])
else:
unwrapped.append(c)
child_list = unwrapped

# Filter list items by param-value match (e.g., only is_general_technique=True seeds)
if rule and rule.included_item_values:
required = rule.included_item_values
Expand Down Expand Up @@ -238,8 +261,9 @@ class ScorerEvaluationIdentifier(EvaluationIdentifier):

CHILD_EVAL_RULES: ClassVar[dict[str, ChildEvalRule]] = {
"prompt_target": ChildEvalRule(
included_params=frozenset({"underlying_model_name", "temperature", "top_p"}),
param_fallbacks={"underlying_model_name": "model_name"},
included_params=TARGET_EVAL_PARAMS,
param_fallbacks=TARGET_EVAL_PARAM_FALLBACKS,
inner_child_name="targets",
),
}

Expand All @@ -266,10 +290,11 @@ class AtomicAttackEvaluationIdentifier(EvaluationIdentifier):
CHILD_EVAL_RULES: ClassVar[dict[str, ChildEvalRule]] = {
"objective_target": ChildEvalRule(
included_params=frozenset({"temperature"}),
inner_child_name="targets",
),
"adversarial_chat": ChildEvalRule(
included_params=frozenset({"underlying_model_name", "temperature", "top_p"}),
param_fallbacks={"underlying_model_name": "model_name"},
included_params=TARGET_EVAL_PARAMS,
param_fallbacks=TARGET_EVAL_PARAM_FALLBACKS,
),
"objective_scorer": ChildEvalRule(exclude=True),
"seed_identifiers": ChildEvalRule(exclude=True),
Expand Down
2 changes: 2 additions & 0 deletions pyrit/prompt_target/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from pyrit.prompt_target.playwright_copilot_target import CopilotType, PlaywrightCopilotTarget
from pyrit.prompt_target.playwright_target import PlaywrightTarget
from pyrit.prompt_target.prompt_shield_target import PromptShieldTarget
from pyrit.prompt_target.round_robin_target import RoundRobinTarget
from pyrit.prompt_target.text_target import TextTarget
from pyrit.prompt_target.websocket_copilot_target import WebSocketCopilotTarget

Expand Down Expand Up @@ -101,6 +102,7 @@ def __getattr__(name: str) -> object:
"PromptShieldTarget",
"PromptTarget",
"RealtimeTarget",
"RoundRobinTarget",
"TargetCapabilities",
"TargetConfiguration",
"TargetRequirements",
Expand Down
Loading
Loading