Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 59 additions & 9 deletions examples/devops/devops/agents/sre/tasks/review.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from switchplane import Task
from switchplane.agent_runtime import AgentContext
from switchplane.llm import build_llm
from switchplane.usage import estimate_text_tokens, llm_usage_from_response

# -- Mock data generation ------------------------------------------------------
#
Expand Down Expand Up @@ -337,30 +338,78 @@ def _strip_json_fences(text: str) -> str:
class ReviewState(TypedDict):
current: Any # pd.DataFrame (not serializable, ephemeral task)
previous: Any # pd.DataFrame
rows_processed: int
estimated_raw_prompt_tokens: int
analysis: dict | None
formatted: str
summary: dict | None
report: str


def _build_graph(llm) -> StateGraph:
def _build_graph(llm, ctx: AgentContext, model: str) -> StateGraph:
"""Wire the review graph. 3 deterministic nodes, 1 LLM node."""

def fetch_metrics(state: ReviewState) -> dict:
current, previous = generate_metrics()
return {"current": current, "previous": previous}
raw_prompt = "\n\n".join(
[
"CURRENT WEEK RAW METRICS CSV:",
current.to_csv(index=False),
"PREVIOUS WEEK RAW METRICS CSV:",
previous.to_csv(index=False),
]
)
return {
"current": current,
"previous": previous,
"rows_processed": len(current) + len(previous),
"estimated_raw_prompt_tokens": estimate_text_tokens(raw_prompt),
}

def analyze_metrics(state: ReviewState) -> dict:
result = analyze(state["current"], state["previous"])
return {"analysis": result, "formatted": format_analysis(result)}

async def summarize(state: ReviewState) -> dict:
prompt = _ANALYSIS_PROMPT.format(formatted_data=state["formatted"])
response = await llm.ainvoke(
[
{"role": "system", "content": _SYSTEM_PROMPT},
{"role": "user", "content": prompt},
]
messages = [
{"role": "system", "content": _SYSTEM_PROMPT},
{"role": "user", "content": prompt},
]
response = await llm.ainvoke(messages)
usage = llm_usage_from_response(
response,
task_id=ctx.task_id,
model=model,
node_name="summarize",
fallback_prompt_text=f"{_SYSTEM_PROMPT}\n\n{prompt}",
fallback_completion_text=str(response.content),
estimated_raw_prompt_tokens=state["estimated_raw_prompt_tokens"],
metadata={
"deterministic_nodes": 3,
"llm_nodes": 1,
"rows_processed": state["rows_processed"],
"formatted_prompt_tokens_estimate": estimate_text_tokens(prompt),
},
)
ctx.record_llm_usage(
model=usage.model,
node_name=usage.node_name,
prompt_tokens=usage.prompt_tokens,
completion_tokens=usage.completion_tokens,
total_tokens=usage.total_tokens,
estimated_cost_usd=usage.estimated_cost_usd,
estimated_raw_prompt_tokens=usage.estimated_raw_prompt_tokens,
estimated_tokens_saved=usage.estimated_tokens_saved,
metadata=usage.metadata,
)
ctx.progress(
"LLM usage recorded",
prompt_tokens=usage.prompt_tokens,
completion_tokens=usage.completion_tokens,
total_tokens=usage.total_tokens,
estimated_cost_usd=usage.estimated_cost_usd,
estimated_tokens_saved=usage.estimated_tokens_saved,
)
Comment on lines +380 to 413
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the core problem with the approach — the summarize node went from ~5 lines to 30+ lines of usage-tracking ceremony. Every LLM node in every task would need this same boilerplate.

If usage tracking is a framework concern, it should be transparent. A LangGraph callback on on_llm_end could emit the llm.usage event automatically with zero changes to task code.

try:
summary = json.loads(_strip_json_fences(response.content))
Expand Down Expand Up @@ -431,16 +480,17 @@ async def run(self, ctx: AgentContext) -> None:
ctx.progress(f"Starting ops review (model: {model})")
ctx.progress("Generating mock NewRelic metrics (2 weeks, 5-min granularity)...")

graph = _build_graph(llm).compile()

initial: ReviewState = {
"current": None,
"previous": None,
"rows_processed": 0,
"estimated_raw_prompt_tokens": 0,
"analysis": None,
"formatted": "",
"summary": None,
"report": "",
}

graph = _build_graph(llm, ctx, model).compile()
result = await graph.ainvoke(initial)
ctx.complete(result["report"])
3 changes: 2 additions & 1 deletion src/switchplane/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@
from switchplane.app import Application
from switchplane.shell import Shell
from switchplane.task import Task, command
from switchplane.usage import LLMUsageRecord

__all__ = ["Application", "Field", "Shell", "Task", "command", "fmt"]
__all__ = ["Application", "Field", "LLMUsageRecord", "Shell", "Task", "command", "fmt"]
36 changes: 36 additions & 0 deletions src/switchplane/agent_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from switchplane._util import MAX_MESSAGE_SIZE
from switchplane.protocol import AgentCommand, AgentEvent
from switchplane.usage import LLMUsageRecord, estimate_cost_usd

_logger = structlog.get_logger()

Expand Down Expand Up @@ -243,6 +244,41 @@ def command_result(self, action: str, result: dict[str, Any]) -> None:
"""
self.emit("task.command_result", {"action": action, "result": result})

def record_llm_usage(
    self,
    *,
    model: str,
    node_name: str,
    prompt_tokens: int,
    completion_tokens: int,
    total_tokens: int | None = None,
    estimated_cost_usd: float | None = None,
    estimated_raw_prompt_tokens: int | None = None,
    estimated_tokens_saved: int | None = None,
    metadata: dict[str, Any] | None = None,
) -> LLMUsageRecord:
    """Build an ``LLMUsageRecord`` for one model call and emit it as ``llm.usage``.

    Args:
        model: Model identifier; also used to look up a price when no cost
            is supplied.
        node_name: Name of the graph node that made the call.
        prompt_tokens: Tokens sent to the model.
        completion_tokens: Tokens produced by the model.
        total_tokens: Overall token count; when ``None`` it is computed as
            ``prompt_tokens + completion_tokens``.
        estimated_cost_usd: Cost of the call; when ``None`` it is derived via
            ``estimate_cost_usd`` (which may itself return ``None``).
        estimated_raw_prompt_tokens: Optional estimate stored on the record
            unchanged.
        estimated_tokens_saved: Optional estimate stored on the record
            unchanged.
        metadata: Extra key/value data; ``None`` (or an empty dict) is
            recorded as ``{}``.

    Returns:
        The record whose JSON dump was emitted.
    """
    if total_tokens is None:
        total_tokens = prompt_tokens + completion_tokens
    if estimated_cost_usd is None:
        estimated_cost_usd = estimate_cost_usd(model, prompt_tokens, completion_tokens)

    usage_record = LLMUsageRecord(
        task_id=self.task_id,
        model=model,
        node_name=node_name,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=total_tokens,
        estimated_cost_usd=estimated_cost_usd,
        estimated_raw_prompt_tokens=estimated_raw_prompt_tokens,
        estimated_tokens_saved=estimated_tokens_saved,
        metadata=metadata or {},
    )
    # The event payload is the JSON-mode dump of the record.
    payload = usage_record.model_dump(mode="json")
    self.emit("llm.usage", payload)
    return usage_record

async def wait_for_input(self, prompt: str | None = None) -> str:
"""Block until the user sends freeform text input.

Expand Down
1 change: 1 addition & 0 deletions src/switchplane/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class AgentEvent(BaseModel):
"task.interrupted",
"task.resumed",
"checkpoint.save",
"llm.usage",
"log",
"task.command_result",
]
Expand Down
2 changes: 2 additions & 0 deletions src/switchplane/subprocess_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,8 @@ async def _handle_event(self, event: AgentEvent) -> int:
pass
case "checkpoint.save":
pass # TODO: Forward to checkpoint saver
case "llm.usage":
pass # Stored as regular event for task-level cost accounting
case "task.command_result":
pass # Stored as regular event
case "log":
Expand Down
159 changes: 159 additions & 0 deletions src/switchplane/usage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
"""LLM usage accounting helpers."""

from __future__ import annotations

import math
from datetime import UTC, datetime
from typing import Any, NamedTuple

from pydantic import BaseModel, ConfigDict, Field


class ModelPricing(NamedTuple):
    """Price of one model, in USD per 1,000,000 tokens."""

    # Rate applied to prompt (input) tokens.
    input_per_million: float
    # Rate applied to completion (output) tokens.
    output_per_million: float


# Approximate public list prices, keyed by exact model ID. The table is kept
# intentionally short: a model that is missing here still gets token records,
# only without an estimated dollar cost.
MODEL_PRICING: dict[str, ModelPricing] = {
    "claude-sonnet-4-20250514": ModelPricing(input_per_million=3.0, output_per_million=15.0),
    "claude-sonnet-4-5-20250929": ModelPricing(input_per_million=3.0, output_per_million=15.0),
    "claude-sonnet-4-6": ModelPricing(input_per_million=3.0, output_per_million=15.0),
    "claude-opus-4-20250514": ModelPricing(input_per_million=15.0, output_per_million=75.0),
    "claude-opus-4-6-v1": ModelPricing(input_per_million=15.0, output_per_million=75.0),
    "claude-haiku-4-5-20251001": ModelPricing(input_per_million=1.0, output_per_million=5.0),
    "gpt-4o": ModelPricing(input_per_million=2.5, output_per_million=10.0),
    "gpt-4o-mini": ModelPricing(input_per_million=0.15, output_per_million=0.60),
    "gemini-2.0-flash": ModelPricing(input_per_million=0.10, output_per_million=0.40),
    "gemini-2.5-flash": ModelPricing(input_per_million=0.30, output_per_million=2.50),
    "gemini-2.5-pro": ModelPricing(input_per_million=1.25, output_per_million=10.0),
}
Comment on lines +22 to +33
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This pricing table will go stale immediately. Prices change frequently, model IDs are versioned, and this doesn't account for caching discounts, batch API pricing, prompt caching writebacks, etc.

If cost tracking is needed, it belongs in config or an external source — not a compiled-in dict that requires code changes to update.



class LLMUsageRecord(BaseModel):
    """Accounting entry describing the token usage of one LLM call."""

    # Strip leading/trailing whitespace from string fields on validation.
    model_config = ConfigDict(str_strip_whitespace=True)

    # Which call this record belongs to.
    task_id: str
    model: str
    node_name: str

    # Token counts; each defaults to zero when not supplied.
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0

    # Dollar estimate; None when no cost could be determined.
    estimated_cost_usd: float | None = None

    # Set to the current UTC time when the record is created.
    timestamp: datetime = Field(default_factory=lambda: datetime.now(tz=UTC))

    # Optional estimates supplied by the caller.
    estimated_raw_prompt_tokens: int | None = None
    estimated_tokens_saved: int | None = None

    # Free-form extra data; a fresh dict per record.
    metadata: dict[str, Any] = Field(default_factory=dict)


def estimate_text_tokens(text: str) -> int:
    """Approximate the token count of ``text`` without running a tokenizer.

    The estimate assumes roughly four characters per token and rounds up, so
    any non-empty text counts as at least one token. It is intended for rough
    before/after comparisons, not for exact accounting.

    Args:
        text: The text to measure.

    Returns:
        ``0`` for empty text, otherwise ``ceil(len(text) / 4)``.
    """
    chars_per_token = 4
    if text:
        # Integer ceiling division; always >= 1 for non-empty text.
        return (len(text) + chars_per_token - 1) // chars_per_token
    return 0


def estimate_cost_usd(model: str, prompt_tokens: int, completion_tokens: int) -> float | None:
    """Return the estimated USD cost of a call, or ``None`` for an unpriced model.

    Args:
        model: Model ID; must match a ``MODEL_PRICING`` key exactly.
        prompt_tokens: Number of input tokens.
        completion_tokens: Number of output tokens.

    Returns:
        The input cost plus the output cost, rounded to six decimal places,
        or ``None`` when ``model`` has no entry in ``MODEL_PRICING``.
    """
    rates = MODEL_PRICING.get(model)
    if rates is None:
        return None

    # Rates are quoted per one million tokens.
    input_cost = prompt_tokens / 1_000_000 * rates.input_per_million
    output_cost = completion_tokens / 1_000_000 * rates.output_per_million
    return round(input_cost + output_cost, 6)
Comment on lines +68 to +76
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This len(text) / 4 heuristic is very rough, and the results get stored in the same LLMUsageRecord alongside provider-reported actuals. Downstream consumers of these records have no reliable way to distinguish precision levels.

The estimated_tokens_saved metric (raw prompt tokens from this estimate minus actual prompt tokens from the provider) is comparing a pre-call guess against a post-call actual — not a meaningful comparison.



def _coerce_int(value: Any) -> int | None:
    """Convert ``value`` to ``int``; return ``None`` if it is ``None`` or not convertible."""
    if value is None:
        return None
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        # TypeError: unsupported type; ValueError: non-numeric string or NaN;
        # OverflowError: infinite float.
        return None


def _first_int(counts: dict[str, Any], *keys: str) -> int | None:
    """Return the first integer-convertible value found under ``keys``, in order.

    Presence is decided with ``is not None`` rather than truthiness, so a
    reported count of ``0`` is kept instead of being treated as missing.
    """
    for key in keys:
        count = _coerce_int(counts.get(key))
        if count is not None:
            return count
    return None


def extract_token_counts(response: Any) -> tuple[int | None, int | None, int | None]:
    """Extract provider-reported token counts from common LangChain responses.

    Two locations are checked, in order:

    1. ``response.usage_metadata`` (keys ``input_tokens``/``prompt_tokens``,
       ``output_tokens``/``completion_tokens``/``generated_tokens``,
       ``total_tokens``).
    2. ``response.response_metadata["token_usage"]`` or ``["usage"]`` (keys
       ``prompt_tokens``/``input_tokens``,
       ``completion_tokens``/``output_tokens``, ``total_tokens``).

    The first location that yields at least one count wins.

    Args:
        response: Any object; missing attributes are treated as absent.

    Returns:
        ``(prompt, completion, total)``. Each element is ``None`` when the
        provider did not report it. A reported ``0`` is returned as ``0``.
    """
    usage_metadata = getattr(response, "usage_metadata", None)
    if isinstance(usage_metadata, dict):
        counts = (
            _first_int(usage_metadata, "input_tokens", "prompt_tokens"),
            _first_int(usage_metadata, "output_tokens", "completion_tokens", "generated_tokens"),
            _first_int(usage_metadata, "total_tokens"),
        )
        if any(count is not None for count in counts):
            return counts

    response_metadata = getattr(response, "response_metadata", None)
    if isinstance(response_metadata, dict):
        # An empty ``token_usage`` dict carries nothing, so fall through to ``usage``.
        token_usage = response_metadata.get("token_usage") or response_metadata.get("usage")
        if isinstance(token_usage, dict):
            counts = (
                _first_int(token_usage, "prompt_tokens", "input_tokens"),
                _first_int(token_usage, "completion_tokens", "output_tokens"),
                _first_int(token_usage, "total_tokens"),
            )
            if any(count is not None for count in counts):
                return counts

    return None, None, None
Comment on lines +89 to +113
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LangChain already surfaces token usage on every AIMessage via usage_metadata (input_tokens, output_tokens, total_tokens). This function is manually re-extracting data that the framework already provides — it doesn't add new information.

A LangGraph callback (on_llm_end) could capture this automatically without any custom extraction logic.



def llm_usage_from_response(
    response: Any,
    *,
    task_id: str,
    model: str,
    node_name: str,
    fallback_prompt_text: str = "",
    fallback_completion_text: str = "",
    estimated_raw_prompt_tokens: int | None = None,
    metadata: dict[str, Any] | None = None,
) -> LLMUsageRecord:
    """Create an ``LLMUsageRecord`` describing one LangChain model response.

    Provider-reported counts (via ``extract_token_counts``) are used when
    available. A missing prompt or completion count is replaced by a text
    estimate of the corresponding fallback string, and a missing total is the
    sum of the two.

    Args:
        response: The model response to inspect for token counts.
        task_id: Task the call belongs to.
        model: Model ID, also used for the cost estimate.
        node_name: Graph node that made the call.
        fallback_prompt_text: Text to estimate prompt tokens from when the
            provider reported none.
        fallback_completion_text: Text to estimate completion tokens from
            when the provider reported none.
        estimated_raw_prompt_tokens: Optional estimate of an unreduced prompt;
            when given, ``estimated_tokens_saved`` is its excess over the
            prompt token count, floored at zero.
        metadata: Extra data copied onto the record. ``"token_source"`` is
            added as ``"provider"`` or ``"estimated"`` unless already present.

    Returns:
        The populated record.
    """
    reported_prompt, reported_completion, reported_total = extract_token_counts(response)

    # Any estimated component marks the whole record as estimated.
    used_fallback = reported_prompt is None or reported_completion is None
    token_source = "estimated" if used_fallback else "provider"

    prompt_tokens = (
        reported_prompt
        if reported_prompt is not None
        else estimate_text_tokens(fallback_prompt_text)
    )
    completion_tokens = (
        reported_completion
        if reported_completion is not None
        else estimate_text_tokens(fallback_completion_text)
    )
    total_tokens = (
        reported_total if reported_total is not None else prompt_tokens + completion_tokens
    )

    if estimated_raw_prompt_tokens is None:
        tokens_saved = None
    else:
        tokens_saved = max(0, estimated_raw_prompt_tokens - prompt_tokens)

    # Copy so the caller's dict is never mutated.
    record_metadata = dict(metadata or {})
    if "token_source" not in record_metadata:
        record_metadata["token_source"] = token_source

    return LLMUsageRecord(
        task_id=task_id,
        model=model,
        node_name=node_name,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=total_tokens,
        estimated_cost_usd=estimate_cost_usd(model, prompt_tokens, completion_tokens),
        estimated_raw_prompt_tokens=estimated_raw_prompt_tokens,
        estimated_tokens_saved=tokens_saved,
        metadata=record_metadata,
    )
27 changes: 27 additions & 0 deletions tests/test_agent_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,33 @@ def test_command_result(self, ctx, socketpair):
assert event.payload["result"]["lat"] == 1.0


class TestAgentContextLLMUsage:
    """Checks the ``llm.usage`` event produced by ``record_llm_usage``."""

    def test_record_llm_usage(self, ctx, socketpair):
        receiving_sock, _ = socketpair

        record = ctx.record_llm_usage(
            model="gpt-4o-mini",
            node_name="summarize",
            prompt_tokens=100,
            completion_tokens=20,
            estimated_raw_prompt_tokens=500,
            estimated_tokens_saved=400,
            metadata={"rows_processed": 10},
        )

        event = AgentEvent.model_validate_json(_recv_message(receiving_sock))
        assert event.type == "llm.usage"

        payload = event.payload
        expected_fields = {
            "task_id": "task1",
            "model": "gpt-4o-mini",
            "node_name": "summarize",
            "prompt_tokens": 100,
            "completion_tokens": 20,
            # No total was passed, so it must be prompt + completion.
            "total_tokens": 120,
            "estimated_tokens_saved": 400,
        }
        for field_name, expected_value in expected_fields.items():
            assert payload[field_name] == expected_value
        assert payload["metadata"]["rows_processed"] == 10

        # The returned record carries the same computed total.
        assert record.total_tokens == 120


class TestAgentContextProperties:
def test_config(self, ctx):
assert ctx.config == {"llm": {"model": "test"}}
Expand Down
1 change: 1 addition & 0 deletions tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def test_valid_types(self):
"task.failed",
"task.cancelled",
"checkpoint.save",
"llm.usage",
"log",
"task.command_result",
]
Expand Down
Loading