Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 114 additions & 4 deletions nerve/agent/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,36 @@ def _parse_mcp_tool_name(tool_name: str) -> tuple[str, str] | None:
return None


def _model_family(model: str) -> str:
"""Normalize a model identifier to a comparable family name.

The same model family shows up under many identifiers depending on
provider routing and release: a bare alias, a dated release id, a
Bedrock inference-profile id, or a context-window-suffixed alias.
Serving-model change detection compares *families*, so none of these
synonyms may register as a change:

example-model-2 -> example-model-2
example-model-2-20260101 -> example-model-2
us.provider.example-model-2-20260101-v1:0 -> example-model-2
example-model-2[1m] -> example-model-2
example-model-2-latest -> example-model-2
"""
m = model.strip().lower()
# Provider routing prefix: "us.anthropic.", "global.anthropic.", ...
if "anthropic." in m:
m = m.rsplit("anthropic.", 1)[1]
# Context-window suffix: "...[1m]"
m = m.split("[", 1)[0]
# Bedrock version suffix: "-v1:0" / "-v2"
m = re.sub(r"-v\d+(?::\d+)?$", "", m)
# "-latest" alias suffix
m = re.sub(r"-latest$", "", m)
# Trailing release date: "-20260601"
m = re.sub(r"-20\d{6}$", "", m)
return m


@dataclass
class _TurnState:
"""Accumulates one agent turn's worth of streamed content.
Expand Down Expand Up @@ -331,6 +361,14 @@ def __init__(self, config: NerveConfig, db: Database):
# detect mid-session model switches (the CLI fixes its model at
# connect time, so a change requires recreating the client).
self._session_models: dict[str, str] = {}
# Last model *observed* serving each session (from
# AssistantMessage.model). The API may silently serve a different
# model than requested — e.g. a capacity fallback from a frontier
# model to the previous tier — and switch back later. Transitions
# are surfaced as model_change blocks/events (_track_serving_model).
# Seeded from session metadata on client creation so detection
# survives restarts without re-firing on every resume.
self._observed_models: dict[str, str] = {}
self._router = None # ChannelRouter — lazy-initialized via .router property
self._mcp_servers_cache = list(config.mcp_servers) # hot-reloadable
self._claude_code_plugins: list[dict[str, str]] = [] # plugin dirs
Expand Down Expand Up @@ -1550,13 +1588,28 @@ async def _get_or_create_client(
unregister_handler(session_id)
await self._safe_disconnect(client)
client = None
# Deliberate switch — drop the observed-model baseline so
# the first message on the new model doesn't fire a
# model_change event for a change the user asked for.
self._observed_models.pop(session_id, None)
else:
return client

# Check for stored SDK session ID for resume
session = await self.db.get_session(session_id)
sdk_resume_id = session.get("sdk_session_id") if session else None

# Seed the serving-model baseline from the last persisted
# observation so downgrade detection survives restarts without
# re-firing an event on every resumed session.
if session and session_id not in self._observed_models:
try:
_meta = json.loads(session.get("metadata") or "{}")
except (TypeError, ValueError):
_meta = {}
if _meta.get("observed_model"):
self._observed_models[session_id] = _meta["observed_model"]

# For forks, use the source session's SDK ID
if fork_from and not sdk_resume_id:
sdk_resume_id = fork_from
Expand Down Expand Up @@ -1863,6 +1916,54 @@ def _merge_tool_results(
# Shared per-message processing (user runs + autonomous turns) #
# ------------------------------------------------------------------ #

async def _track_serving_model(
self, session_id: str, model: str, st: _TurnState,
) -> None:
"""Detect serving-model transitions and surface them.

The API can serve a session with a different model than the one
configured — e.g. a capacity fallback from a frontier model to
the previous tier — and later switch back, all without any
explicit signal beyond ``AssistantMessage.model``. Compare each
main-agent message's model against the last observed one (or the
configured model for the first observation) and, when the model
*family* changes:

- append a ``model_change`` block to the turn (persisted with the
message, so the transition stays visible in history),
- broadcast a ``model_changed`` event for the live UI,
- log it (warning when it moves away from the configured model,
info when it returns).

Family comparison (see ``_model_family``) keeps alias/dated/
Bedrock spellings of the same model from registering as changes.
"""
prev = self._observed_models.get(session_id)
self._observed_models[session_id] = model
configured = self._session_models.get(session_id)
baseline = prev or configured
if not baseline or _model_family(model) == _model_family(baseline):
return
downgrade = bool(
configured and _model_family(model) != _model_family(configured),
)
log = logger.warning if downgrade else logger.info
log(
"Session %s serving model changed: %s → %s%s",
session_id, baseline, model,
f" (away from configured {configured})" if downgrade else "",
)
st.ordered_blocks.append({
"type": "model_change",
"from": baseline,
"to": model,
"downgrade": downgrade,
})
await broadcaster.broadcast_model_changed(
session_id, from_model=baseline, to_model=model,
downgrade=downgrade,
)

async def _process_sdk_message(
self, session_id: str, message: Any, st: _TurnState,
) -> bool:
Expand All @@ -1885,13 +1986,18 @@ async def _process_sdk_message(

if isinstance(message, AssistantMessage):
st.got_content = True
# Capture model from assistant message (more reliable than config)
msg_model = getattr(message, "model", None)
if msg_model:
st.last_model = msg_model
# Extract parent_tool_use_id — set when this message comes from
# a sub-agent (Task/Agent) rather than the main agent
parent_id = getattr(message, "parent_tool_use_id", None)
# Capture model from assistant message (more reliable than
# config). Main-agent messages only: sub-agents legitimately
# run different models (Agent tool `model` opt, built-in agent
# defaults), which must not pollute turn cost attribution or
# fire serving-model change events.
msg_model = getattr(message, "model", None)
if msg_model and parent_id is None:
st.last_model = msg_model
await self._track_serving_model(session_id, msg_model, st)

for block in message.content:
if isinstance(block, TextBlock):
Expand Down Expand Up @@ -2330,6 +2436,10 @@ async def _finalize_turn(
session_record = await self.db.get_session(session_id)
meta = json.loads(session_record.get("metadata") or "{}") if session_record else {}
meta["last_usage"] = usage_data
if st.last_model:
# Baseline for serving-model change detection across
# restarts (see _track_serving_model).
meta["observed_model"] = st.last_model

# Extract server_tool_use counts
server_tool = st.last_usage.get("server_tool_use") or {}
Expand Down
19 changes: 19 additions & 0 deletions nerve/agent/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,25 @@ async def broadcast_wakeup(self, session_id: str) -> None:
"""Mark the start of a turn fired by a self-scheduled ScheduleWakeup."""
await self.broadcast(session_id, {"type": "wakeup", "session_id": session_id})

async def broadcast_model_changed(
self,
session_id: str,
from_model: str,
to_model: str,
downgrade: bool = False,
) -> None:
"""The model serving this session changed (e.g. a silent API
downgrade to a fallback tier, or the later recovery back to the
configured model). ``downgrade`` is True when the new model moved
*away* from the session's configured model."""
await self.broadcast(session_id, {
"type": "model_changed",
"session_id": session_id,
"from_model": from_model,
"to_model": to_model,
"downgrade": downgrade,
})

async def broadcast_interaction(self, session_id: str, interaction_type: str, interaction_id: str, tool_name: str, tool_input: dict) -> None:
await self.broadcast(session_id, {
"type": "interaction",
Expand Down
2 changes: 2 additions & 0 deletions tests/test_autonomous_turns.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def _make_engine() -> AgentEngine:
engine._workflows = {}
engine._idle_watchers = {}
engine._session_locks = {}
engine._session_models = {}
engine._observed_models = {}
return engine


Expand Down
172 changes: 170 additions & 2 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
import os
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import patch
from unittest.mock import AsyncMock, patch

import pytest
from claude_agent_sdk import AssistantMessage

from nerve.agent.engine import AgentEngine
from nerve.agent.engine import AgentEngine, _TurnState, _model_family


@pytest.mark.parametrize(
Expand Down Expand Up @@ -349,3 +350,170 @@ async def test_no_grant_hook_when_disabled(self):
matchers = {m.matcher for m in hooks["PreToolUse"]}
assert "Edit|Write|NotebookEdit" in matchers
assert "Read" in matchers


# ---------------------------------------------------------------------------
# _model_family — serving-model identifier normalization
# ---------------------------------------------------------------------------


@pytest.mark.parametrize(
"model, expected",
[
# Bare alias passes through
("claude-fable-5", "claude-fable-5"),
("claude-opus-4-8", "claude-opus-4-8"),
# Dated release ids collapse onto the alias
("claude-fable-5-20260601", "claude-fable-5"),
("claude-opus-4-8-20260115", "claude-opus-4-8"),
# Bedrock inference-profile spellings
("us.anthropic.claude-fable-5-20260601-v1:0", "claude-fable-5"),
("global.anthropic.claude-fable-5", "claude-fable-5"),
# Context-window suffix
("claude-sonnet-4-5[1m]", "claude-sonnet-4-5"),
# "-latest" alias
("claude-haiku-4-5-latest", "claude-haiku-4-5"),
# Case / whitespace robustness
(" Claude-Fable-5-20260601 ", "claude-fable-5"),
# Version-looking tails that are NOT dates stay intact
("claude-opus-4", "claude-opus-4"),
("claude-3-5-sonnet-20241022", "claude-3-5-sonnet"),
],
)
def test_model_family(model, expected):
assert _model_family(model) == expected


def test_model_family_distinguishes_real_changes():
# The pair that matters: a frontier model downgrading to the prior tier
# must NOT normalize to the same family.
assert _model_family("claude-fable-5-20260601") != _model_family(
"claude-opus-4-8-20260115",
)


# ---------------------------------------------------------------------------
# _track_serving_model — downgrade / recovery detection via
# _process_sdk_message (the shared user-run + autonomous-turn path)
# ---------------------------------------------------------------------------


def _make_model_engine(configured: str | None = "claude-fable-5") -> AgentEngine:
"""Minimal AgentEngine stub for serving-model tracking tests."""
engine = AgentEngine.__new__(AgentEngine)
engine._session_models = {"s1": configured} if configured else {}
engine._observed_models = {}
engine._workflows = {}
return engine


def _assistant(model: str, parent_tool_use_id: str | None = None) -> AssistantMessage:
return AssistantMessage(
content=[], model=model, parent_tool_use_id=parent_tool_use_id,
)


class TestServingModelTracking:
@pytest.mark.asyncio
async def test_first_message_downgrade_fires_event(self):
engine = _make_model_engine()
st = _TurnState()
with patch("nerve.agent.engine.broadcaster") as bc:
bc.broadcast_model_changed = AsyncMock()
await engine._process_sdk_message(
"s1", _assistant("claude-opus-4-8-20260115"), st,
)
assert st.last_model == "claude-opus-4-8-20260115"
assert st.ordered_blocks == [{
"type": "model_change",
"from": "claude-fable-5",
"to": "claude-opus-4-8-20260115",
"downgrade": True,
}]
bc.broadcast_model_changed.assert_awaited_once_with(
"s1",
from_model="claude-fable-5",
to_model="claude-opus-4-8-20260115",
downgrade=True,
)

@pytest.mark.asyncio
async def test_same_family_dated_id_is_not_a_change(self):
engine = _make_model_engine()
st = _TurnState()
with patch("nerve.agent.engine.broadcaster") as bc:
bc.broadcast_model_changed = AsyncMock()
await engine._process_sdk_message(
"s1", _assistant("claude-fable-5-20260601"), st,
)
assert st.ordered_blocks == []
bc.broadcast_model_changed.assert_not_awaited()
# Baseline still updated for subsequent comparisons
assert engine._observed_models["s1"] == "claude-fable-5-20260601"

@pytest.mark.asyncio
async def test_recovery_back_to_configured_is_not_downgrade(self):
engine = _make_model_engine()
engine._observed_models["s1"] = "claude-opus-4-8-20260115"
st = _TurnState()
with patch("nerve.agent.engine.broadcaster") as bc:
bc.broadcast_model_changed = AsyncMock()
await engine._process_sdk_message(
"s1", _assistant("claude-fable-5-20260601"), st,
)
assert st.ordered_blocks == [{
"type": "model_change",
"from": "claude-opus-4-8-20260115",
"to": "claude-fable-5-20260601",
"downgrade": False,
}]

@pytest.mark.asyncio
async def test_mid_session_transition_uses_observed_baseline(self):
engine = _make_model_engine()
st = _TurnState()
with patch("nerve.agent.engine.broadcaster") as bc:
bc.broadcast_model_changed = AsyncMock()
await engine._process_sdk_message(
"s1", _assistant("claude-fable-5-20260601"), st,
)
await engine._process_sdk_message(
"s1", _assistant("claude-opus-4-8-20260115"), st,
)
# Only the second message fires; "from" is the observed dated id,
# not the configured alias.
assert len(st.ordered_blocks) == 1
assert st.ordered_blocks[0]["from"] == "claude-fable-5-20260601"
assert st.ordered_blocks[0]["downgrade"] is True

@pytest.mark.asyncio
async def test_subagent_messages_are_ignored(self):
engine = _make_model_engine()
st = _TurnState()
with patch("nerve.agent.engine.broadcaster") as bc:
bc.broadcast_model_changed = AsyncMock()
await engine._process_sdk_message(
"s1",
_assistant("claude-haiku-4-5", parent_tool_use_id="tu_1"),
st,
)
# Sub-agents legitimately run other models — no event, no baseline
# pollution, and no cost-attribution model override.
assert st.last_model is None
assert st.ordered_blocks == []
assert "s1" not in engine._observed_models
bc.broadcast_model_changed.assert_not_awaited()

@pytest.mark.asyncio
async def test_no_configured_model_first_message_is_quiet(self):
engine = _make_model_engine(configured=None)
st = _TurnState()
with patch("nerve.agent.engine.broadcaster") as bc:
bc.broadcast_model_changed = AsyncMock()
await engine._process_sdk_message(
"s1", _assistant("claude-opus-4-8-20260115"), st,
)
# Nothing to compare against — record the baseline silently.
assert st.ordered_blocks == []
bc.broadcast_model_changed.assert_not_awaited()
assert engine._observed_models["s1"] == "claude-opus-4-8-20260115"
1 change: 1 addition & 0 deletions web/src/api/websocket.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ export type WSMessage =
| { type: 'workflow_progress'; session_id: string; tool_use_id: string; workflow: WorkflowSnapshot }
| { type: 'wakeup'; session_id: string }
| { type: 'auto_turn'; session_id: string }
| { type: 'model_changed'; session_id: string; from_model: string; to_model: string; downgrade: boolean }
| { type: 'pong' };

type MessageHandler = (msg: WSMessage) => void;
Expand Down
Loading
Loading