diff --git a/examples/devops/devops/agents/sre/tasks/review.py b/examples/devops/devops/agents/sre/tasks/review.py index 5456e23..f8349f6 100644 --- a/examples/devops/devops/agents/sre/tasks/review.py +++ b/examples/devops/devops/agents/sre/tasks/review.py @@ -25,6 +25,7 @@ from switchplane import Task from switchplane.agent_runtime import AgentContext from switchplane.llm import build_llm +from switchplane.usage import estimate_text_tokens, llm_usage_from_response # -- Mock data generation ------------------------------------------------------ # @@ -337,18 +338,33 @@ def _strip_json_fences(text: str) -> str: class ReviewState(TypedDict): current: Any # pd.DataFrame (not serializable, ephemeral task) previous: Any # pd.DataFrame + rows_processed: int + estimated_raw_prompt_tokens: int analysis: dict | None formatted: str summary: dict | None report: str -def _build_graph(llm) -> StateGraph: +def _build_graph(llm, ctx: AgentContext, model: str) -> StateGraph: """Wire the review graph. 3 deterministic nodes, 1 LLM node.""" def fetch_metrics(state: ReviewState) -> dict: current, previous = generate_metrics() - return {"current": current, "previous": previous} + raw_prompt = "\n\n".join( + [ + "CURRENT WEEK RAW METRICS CSV:", + current.to_csv(index=False), + "PREVIOUS WEEK RAW METRICS CSV:", + previous.to_csv(index=False), + ] + ) + return { + "current": current, + "previous": previous, + "rows_processed": len(current) + len(previous), + "estimated_raw_prompt_tokens": estimate_text_tokens(raw_prompt), + } def analyze_metrics(state: ReviewState) -> dict: result = analyze(state["current"], state["previous"]) @@ -356,11 +372,44 @@ def analyze_metrics(state: ReviewState) -> dict: async def summarize(state: ReviewState) -> dict: prompt = _ANALYSIS_PROMPT.format(formatted_data=state["formatted"]) - response = await llm.ainvoke( - [ - {"role": "system", "content": _SYSTEM_PROMPT}, - {"role": "user", "content": prompt}, - ] + messages = [ + {"role": "system", "content": _SYSTEM_PROMPT}, + {"role": "user", "content": prompt}, + ] + response = await llm.ainvoke(messages) + usage = llm_usage_from_response( + response, + task_id=ctx.task_id, + model=model, + node_name="summarize", + fallback_prompt_text=f"{_SYSTEM_PROMPT}\n\n{prompt}", + fallback_completion_text=str(response.content), + estimated_raw_prompt_tokens=state["estimated_raw_prompt_tokens"], + metadata={ + "deterministic_nodes": 3, + "llm_nodes": 1, + "rows_processed": state["rows_processed"], + "formatted_prompt_tokens_estimate": estimate_text_tokens(prompt), + }, + ) + ctx.record_llm_usage( + model=usage.model, + node_name=usage.node_name, + prompt_tokens=usage.prompt_tokens, + completion_tokens=usage.completion_tokens, + total_tokens=usage.total_tokens, + estimated_cost_usd=usage.estimated_cost_usd, + estimated_raw_prompt_tokens=usage.estimated_raw_prompt_tokens, + estimated_tokens_saved=usage.estimated_tokens_saved, + metadata=usage.metadata, + ) + ctx.progress( + "LLM usage recorded", + prompt_tokens=usage.prompt_tokens, + completion_tokens=usage.completion_tokens, + total_tokens=usage.total_tokens, + estimated_cost_usd=usage.estimated_cost_usd, + estimated_tokens_saved=usage.estimated_tokens_saved, ) try: summary = json.loads(_strip_json_fences(response.content)) @@ -431,16 +480,17 @@ async def run(self, ctx: AgentContext) -> None: ctx.progress(f"Starting ops review (model: {model})") ctx.progress("Generating mock NewRelic metrics (2 weeks, 5-min granularity)...") - graph = _build_graph(llm).compile() - initial: ReviewState = { "current": None, "previous": None, + "rows_processed": 0, + "estimated_raw_prompt_tokens": 0, "analysis": None, "formatted": "", "summary": None, "report": "", } + graph = _build_graph(llm, ctx, model).compile() result = await graph.ainvoke(initial) ctx.complete(result["report"]) diff --git a/src/switchplane/__init__.py b/src/switchplane/__init__.py index 79605a8..d6f2dbf 100644 --- a/src/switchplane/__init__.py +++ b/src/switchplane/__init__.py @@ -8,5 +8,6 @@ from switchplane.app import Application from switchplane.shell import Shell from switchplane.task import Task, command +from switchplane.usage import LLMUsageRecord -__all__ = ["Application", "Field", "Shell", "Task", "command", "fmt"] +__all__ = ["Application", "Field", "LLMUsageRecord", "Shell", "Task", "command", "fmt"] diff --git a/src/switchplane/agent_runtime.py b/src/switchplane/agent_runtime.py index 7916538..aaeb8ad 100644 --- a/src/switchplane/agent_runtime.py +++ b/src/switchplane/agent_runtime.py @@ -21,6 +21,7 @@ from switchplane._util import MAX_MESSAGE_SIZE from switchplane.protocol import AgentCommand, AgentEvent +from switchplane.usage import LLMUsageRecord, estimate_cost_usd _logger = structlog.get_logger() @@ -243,6 +244,41 @@ def command_result(self, action: str, result: dict[str, Any]) -> None: """ self.emit("task.command_result", {"action": action, "result": result}) + def record_llm_usage( + self, + *, + model: str, + node_name: str, + prompt_tokens: int, + completion_tokens: int, + total_tokens: int | None = None, + estimated_cost_usd: float | None = None, + estimated_raw_prompt_tokens: int | None = None, + estimated_tokens_saved: int | None = None, + metadata: dict[str, Any] | None = None, + ) -> LLMUsageRecord: + """Emit a structured ``llm.usage`` event for one model call.""" + + total = total_tokens if total_tokens is not None else prompt_tokens + completion_tokens + cost = estimated_cost_usd + if cost is None: + cost = estimate_cost_usd(model, prompt_tokens, completion_tokens) + + record = LLMUsageRecord( + task_id=self.task_id, + model=model, + node_name=node_name, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total, + estimated_cost_usd=cost, + estimated_raw_prompt_tokens=estimated_raw_prompt_tokens, + estimated_tokens_saved=estimated_tokens_saved, + metadata=metadata or {}, + ) + self.emit("llm.usage", record.model_dump(mode="json")) + return record + async def wait_for_input(self, prompt: str | None = None) -> str: """Block until the user sends freeform text input. diff --git a/src/switchplane/protocol.py b/src/switchplane/protocol.py index a8c3490..699b079 100644 --- a/src/switchplane/protocol.py +++ b/src/switchplane/protocol.py @@ -57,6 +57,7 @@ class AgentEvent(BaseModel): "task.interrupted", "task.resumed", "checkpoint.save", + "llm.usage", "log", "task.command_result", ] diff --git a/src/switchplane/subprocess_manager.py b/src/switchplane/subprocess_manager.py index e9ae2d5..da3ef9d 100644 --- a/src/switchplane/subprocess_manager.py +++ b/src/switchplane/subprocess_manager.py @@ -315,6 +315,8 @@ async def _handle_event(self, event: AgentEvent) -> int: pass case "checkpoint.save": pass # TODO: Forward to checkpoint saver + case "llm.usage": + pass # Stored as regular event for task-level cost accounting case "task.command_result": pass # Stored as regular event case "log": diff --git a/src/switchplane/usage.py b/src/switchplane/usage.py new file mode 100644 index 0000000..f698364 --- /dev/null +++ b/src/switchplane/usage.py @@ -0,0 +1,159 @@ +"""LLM usage accounting helpers.""" + +from __future__ import annotations + +import math +from datetime import UTC, datetime +from typing import Any, NamedTuple + +from pydantic import BaseModel, ConfigDict, Field + + +class ModelPricing(NamedTuple): + """USD prices per one million tokens.""" + + input_per_million: float + output_per_million: float + + +# Public list-price approximations. Keep this deliberately small; unknown +# models still produce token records, just without an estimated dollar cost. +MODEL_PRICING: dict[str, ModelPricing] = { + "claude-sonnet-4-20250514": ModelPricing(3.0, 15.0), + "claude-sonnet-4-5-20250929": ModelPricing(3.0, 15.0), + "claude-sonnet-4-6": ModelPricing(3.0, 15.0), + "claude-opus-4-20250514": ModelPricing(15.0, 75.0), + "claude-opus-4-6-v1": ModelPricing(15.0, 75.0), + "claude-haiku-4-5-20251001": ModelPricing(1.0, 5.0), + "gpt-4o": ModelPricing(2.5, 10.0), + "gpt-4o-mini": ModelPricing(0.15, 0.60), + "gemini-2.0-flash": ModelPricing(0.10, 0.40), + "gemini-2.5-flash": ModelPricing(0.30, 2.50), + "gemini-2.5-pro": ModelPricing(1.25, 10.0), +} + + +class LLMUsageRecord(BaseModel): + """Structured accounting record for a single LLM call.""" + + model_config = ConfigDict(str_strip_whitespace=True) + + task_id: str + model: str + node_name: str + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + estimated_cost_usd: float | None = None + timestamp: datetime = Field(default_factory=lambda: datetime.now(UTC)) + estimated_raw_prompt_tokens: int | None = None + estimated_tokens_saved: int | None = None + metadata: dict[str, Any] = Field(default_factory=dict) + + +def estimate_text_tokens(text: str) -> int: + """Cheap text token estimate for before/after comparisons. + + Providers expose exact usage after a call, but savings estimates often need + a pre-call approximation. Four characters per token is a conservative, + model-agnostic rule of thumb for English/code-like text. + """ + + if not text: + return 0 + return max(1, math.ceil(len(text) / 4)) + + +def estimate_cost_usd(model: str, prompt_tokens: int, completion_tokens: int) -> float | None: + """Estimate USD cost for a model if pricing is known.""" + + pricing = MODEL_PRICING.get(model) + if pricing is None: + return None + cost = (prompt_tokens / 1_000_000 * pricing.input_per_million) + ( + completion_tokens / 1_000_000 * pricing.output_per_million + ) + return round(cost, 6) + + +def _coerce_int(value: Any) -> int | None: + if value is None: + return None + try: + return int(value) + except (TypeError, ValueError): + return None + + +def extract_token_counts(response: Any) -> tuple[int | None, int | None, int | None]: + """Extract provider-reported token counts from common LangChain responses.""" + + usage_metadata = getattr(response, "usage_metadata", None) + if isinstance(usage_metadata, dict): + prompt = _coerce_int(usage_metadata.get("input_tokens") or usage_metadata.get("prompt_tokens")) + completion = _coerce_int( + usage_metadata.get("output_tokens") + or usage_metadata.get("completion_tokens") + or usage_metadata.get("generated_tokens") + ) + total = _coerce_int(usage_metadata.get("total_tokens")) + if prompt is not None or completion is not None or total is not None: + return prompt, completion, total + + response_metadata = getattr(response, "response_metadata", None) + if isinstance(response_metadata, dict): + token_usage = response_metadata.get("token_usage") or response_metadata.get("usage") + if isinstance(token_usage, dict): + prompt = _coerce_int(token_usage.get("prompt_tokens") or token_usage.get("input_tokens")) + completion = _coerce_int(token_usage.get("completion_tokens") or token_usage.get("output_tokens")) + total = _coerce_int(token_usage.get("total_tokens")) + if prompt is not None or completion is not None or total is not None: + return prompt, completion, total + + return None, None, None + + +def llm_usage_from_response( + response: Any, + *, + task_id: str, + model: str, + node_name: str, + fallback_prompt_text: str = "", + fallback_completion_text: str = "", + estimated_raw_prompt_tokens: int | None = None, + metadata: dict[str, Any] | None = None, +) -> LLMUsageRecord: + """Build an ``LLMUsageRecord`` from a LangChain response.""" + + prompt_tokens, completion_tokens, total_tokens = extract_token_counts(response) + token_source = "provider" + + if prompt_tokens is None: + prompt_tokens = estimate_text_tokens(fallback_prompt_text) + token_source = "estimated" + if completion_tokens is None: + completion_tokens = estimate_text_tokens(fallback_completion_text) + token_source = "estimated" + if total_tokens is None: + total_tokens = prompt_tokens + completion_tokens + + estimated_tokens_saved = None + if estimated_raw_prompt_tokens is not None: + estimated_tokens_saved = max(0, estimated_raw_prompt_tokens - prompt_tokens) + + meta = dict(metadata or {}) + meta.setdefault("token_source", token_source) + + return LLMUsageRecord( + task_id=task_id, + model=model, + node_name=node_name, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + estimated_cost_usd=estimate_cost_usd(model, prompt_tokens, completion_tokens), + estimated_raw_prompt_tokens=estimated_raw_prompt_tokens, + estimated_tokens_saved=estimated_tokens_saved, + metadata=meta, + ) diff --git a/tests/test_agent_runtime.py b/tests/test_agent_runtime.py index a36a651..6d4c779 100644 --- a/tests/test_agent_runtime.py +++ b/tests/test_agent_runtime.py @@ -175,6 +175,33 @@ def test_command_result(self, ctx, socketpair): assert event.payload["result"]["lat"] == 1.0 +class TestAgentContextLLMUsage: + def test_record_llm_usage(self, ctx, socketpair): + cp_sock, _ = socketpair + record = ctx.record_llm_usage( + model="gpt-4o-mini", + node_name="summarize", + prompt_tokens=100, + completion_tokens=20, + estimated_raw_prompt_tokens=500, + estimated_tokens_saved=400, + metadata={"rows_processed": 10}, + ) + + data = _recv_message(cp_sock) + event = AgentEvent.model_validate_json(data) + assert event.type == "llm.usage" + assert event.payload["task_id"] == "task1" + assert event.payload["model"] == "gpt-4o-mini" + assert event.payload["node_name"] == "summarize" + assert event.payload["prompt_tokens"] == 100 + assert event.payload["completion_tokens"] == 20 + assert event.payload["total_tokens"] == 120 + assert event.payload["estimated_tokens_saved"] == 400 + assert event.payload["metadata"]["rows_processed"] == 10 + assert record.total_tokens == 120 + + class TestAgentContextProperties: def test_config(self, ctx): assert ctx.config == {"llm": {"model": "test"}} diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 8c4046e..8afc9d3 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -83,6 +83,7 @@ def test_valid_types(self): "task.failed", "task.cancelled", "checkpoint.save", + "llm.usage", "log", "task.command_result", ] diff --git a/tests/test_usage.py b/tests/test_usage.py new file mode 100644 index 0000000..454000f --- /dev/null +++ b/tests/test_usage.py @@ -0,0 +1,66 @@ +from types import SimpleNamespace + +from switchplane.usage import estimate_cost_usd, estimate_text_tokens, extract_token_counts, llm_usage_from_response + + +def test_estimate_text_tokens_uses_four_chars_per_token(): + assert estimate_text_tokens("") == 0 + assert estimate_text_tokens("abcd") == 1 + assert estimate_text_tokens("abcde") == 2 + + +def test_extract_token_counts_from_usage_metadata(): + response = SimpleNamespace(usage_metadata={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}) + + assert extract_token_counts(response) == (10, 5, 15) + + +def test_extract_token_counts_from_response_metadata(): + response = SimpleNamespace(response_metadata={"token_usage": {"prompt_tokens": 20, "completion_tokens": 7}}) + + assert extract_token_counts(response) == (20, 7, None) + + +def test_estimate_cost_for_known_model(): + assert estimate_cost_usd("gpt-4o-mini", 1_000_000, 1_000_000) == 0.75 + + +def test_usage_from_response_uses_provider_counts(): + response = SimpleNamespace( + content="ok", + usage_metadata={"input_tokens": 100, "output_tokens": 25, "total_tokens": 125}, + ) + + usage = llm_usage_from_response( + response, + task_id="task1", + model="claude-sonnet-4-20250514", + node_name="summarize", + estimated_raw_prompt_tokens=1_000, + ) + + assert usage.task_id == "task1" + assert usage.prompt_tokens == 100 + assert usage.completion_tokens == 25 + assert usage.total_tokens == 125 + assert usage.estimated_tokens_saved == 900 + assert usage.metadata["token_source"] == "provider" + + +def test_usage_from_response_falls_back_to_estimates(): + response = SimpleNamespace(content="hello world") + + usage = llm_usage_from_response( + response, + task_id="task1", + model="unknown-model", + node_name="summarize", + fallback_prompt_text="abcd", + fallback_completion_text=response.content, + ) + + assert usage.prompt_tokens == 1 + assert usage.completion_tokens == 3 + assert usage.total_tokens == 4 + assert usage.estimated_cost_usd is None + assert usage.metadata["token_source"] == "estimated"