diff --git a/Dockerfile b/Dockerfile index 96b4b4f5..228b3e1d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,7 +1,6 @@ FROM uselagoon/python-3.12:latest@sha256:5ab457220705f7b4c072ee746b5920779a385a70175e0471b9a263c840ff1070 -RUN apk add bash --no-cache -RUN apk add curl --no-cache +RUN apk add bash curl postgresql-client --no-cache WORKDIR /app diff --git a/app/api/private_ai_keys.py b/app/api/private_ai_keys.py index e69870df..94bde38d 100644 --- a/app/api/private_ai_keys.py +++ b/app/api/private_ai_keys.py @@ -23,9 +23,9 @@ DBPoolPurchase, DBPrivateAIKey, DBRegion, + DBSpendCap, DBUser, DBTeam, - DBSpendCap, ) from app.services.litellm import LiteLLMService from app.core.security import ( @@ -917,7 +917,17 @@ async def delete_private_ai_key( return {"message": "Private AI Key deleted successfully"} -@router.get("/{key_id}/spend", response_model=PrivateAIKeySpendBasic) +@router.get( + "/{key_id}/spend", + response_model=PrivateAIKeySpendBasic, + deprecated=True, + summary="Legacy: use GET /spend/{region_id}/key/{key_id} instead", + description=( + "Returns spend and budget metadata for a specific key. " + "This endpoint is legacy. Prefer /spend/{{region_id}}/key/{{key_id}} " + "which provides richer budget metadata and team-level context." + ), +) async def get_private_ai_key_spend( key_id: int, current_user=Depends(get_current_user_from_auth), @@ -942,6 +952,24 @@ async def get_private_ai_key_spend( data = await litellm_service.get_key_info(private_ai_key.litellm_token) info = data.get("info", {}) + # Override max_budget with the value from spend_caps DB table if present. + # This ensures the configured cap (which may differ from LiteLLM's value + # for purchase-gated teams) is returned to the caller. + # The unique index for key-scope caps is on (region_id, key_id); team_id + # and user_id are metadata only and must NOT be used as lookup filters. + configured_cap = ( + db.query(DBSpendCap.max_budget) + .filter( + DBSpendCap.scope == "key", + DBSpendCap.region_id == private_ai_key.region_id, + DBSpendCap.key_id == private_ai_key.id, + ) + .first() + ) + if configured_cap is not None and configured_cap[0] is not None: + info = dict(info) + info["max_budget"] = round(float(configured_cap[0]), 4) + # Only set default for spend field spend_info = {"spend": info.get("spend", 0.0), **info} diff --git a/app/api/spend.py b/app/api/spend.py index 2dbbf1f4..2e4212cd 100644 --- a/app/api/spend.py +++ b/app/api/spend.py @@ -254,17 +254,16 @@ def _get_spend_cap_max_budget( user_id: int | None = None, key_id: int | None = None, ) -> float | None: - cap = ( - db.query(DBSpendCap.max_budget) - .filter( - DBSpendCap.scope == scope, - DBSpendCap.region_id == region_id, - DBSpendCap.team_id == team_id, - DBSpendCap.user_id == user_id, - DBSpendCap.key_id == key_id, - ) - .first() - ) + # Match on unique-index columns only (see _upsert_spend_cap for rationale). + filters = [DBSpendCap.scope == scope, DBSpendCap.region_id == region_id] + if scope == "team": + filters.append(DBSpendCap.team_id == team_id) + elif scope == "team_member": + filters.extend([DBSpendCap.team_id == team_id, DBSpendCap.user_id == user_id]) + elif scope == "key": + filters.append(DBSpendCap.key_id == key_id) + + cap = db.query(DBSpendCap.max_budget).filter(*filters).first() if cap is None or cap[0] is None: return None return float(cap[0]) @@ -384,17 +383,20 @@ def _upsert_spend_cap( month_anchor: date | None = None, month_start_spend: float | None = None, ) -> None: - cap = ( - db.query(DBSpendCap) - .filter( - DBSpendCap.scope == scope, - DBSpendCap.region_id == region_id, - DBSpendCap.team_id == team_id, - DBSpendCap.user_id == user_id, - DBSpendCap.key_id == key_id, - ) - .first() - ) + # Look up the existing row using the same columns as the partial unique + # index for this scope. A previous implementation filtered on ALL four + # columns (team_id, user_id, key_id) which could miss a row whose + # team_id had been repaired from NULL to a real value, causing a + # UniqueViolation on INSERT. + filters = [DBSpendCap.scope == scope, DBSpendCap.region_id == region_id] + if scope == "team": + filters.append(DBSpendCap.team_id == team_id) + elif scope == "team_member": + filters.extend([DBSpendCap.team_id == team_id, DBSpendCap.user_id == user_id]) + elif scope == "key": + filters.append(DBSpendCap.key_id == key_id) + + cap = db.query(DBSpendCap).filter(*filters).first() if cap is None: cap = DBSpendCap( scope=scope, @@ -403,6 +405,17 @@ def _upsert_spend_cap( user_id=user_id, key_id=key_id, ) + else: + # Repair stale columns so the row stays consistent with the + # current key/team/user relationship. Normalize all relationship + # columns to the requested values, including clearing stale values + # to None when they are not part of the current scope. + if cap.team_id != team_id: + cap.team_id = team_id + if cap.user_id != user_id: + cap.user_id = user_id + if cap.key_id != key_id: + cap.key_id = key_id cap.max_budget = max_budget cap.budget_duration = budget_duration cap.month_anchor = month_anchor @@ -421,17 +434,16 @@ def _delete_spend_cap( user_id: int | None = None, key_id: int | None = None, ) -> None: - ( - db.query(DBSpendCap) - .filter( - DBSpendCap.scope == scope, - DBSpendCap.region_id == region_id, - DBSpendCap.team_id == team_id, - DBSpendCap.user_id == user_id, - DBSpendCap.key_id == key_id, - ) - .delete() - ) + # Match on unique-index columns only (see _upsert_spend_cap for rationale). + filters = [DBSpendCap.scope == scope, DBSpendCap.region_id == region_id] + if scope == "team": + filters.append(DBSpendCap.team_id == team_id) + elif scope == "team_member": + filters.extend([DBSpendCap.team_id == team_id, DBSpendCap.user_id == user_id]) + elif scope == "key": + filters.append(DBSpendCap.key_id == key_id) + + db.query(DBSpendCap).filter(*filters).delete() # Defer commit to the endpoint so DB changes and remote sync share one boundary. db.flush() diff --git a/scripts/detect_orphaned_keys.py b/scripts/detect_orphaned_keys.py new file mode 100644 index 00000000..508e7917 --- /dev/null +++ b/scripts/detect_orphaned_keys.py @@ -0,0 +1,267 @@ +#!/usr/bin/env python3 +""" +One-time script to detect and handle orphaned AI tokens. + +An orphaned key is one where the litellm_token stored in the amazee.ai DB +no longer exists in the corresponding LiteLLM instance. + +What this script does: +1. Iterates over all ai_tokens with a litellm_token +2. Checks each token against its region's LiteLLM via /key/info +3. For tokens that return 401 or 404 (not found): nullifies litellm_token and litellm_api_url +4. Also cleans up any spend_caps rows for orphaned keys + +Default mode is dry-run. Use --apply to execute changes. + +Safety: + - Only touches keys where LiteLLM explicitly returns 401 or 404 "key does not exist" + - Transient errors (502, timeouts, 403) are treated as failures, NOT orphans + - The ai_tokens row is preserved — only litellm_token and litellm_api_url are nulled + - spend_caps for orphaned keys are deleted (no budget to enforce for a dead key) + +Usage: + python scripts/detect_orphaned_keys.py + python scripts/detect_orphaned_keys.py --apply + python scripts/detect_orphaned_keys.py --region-id 2 + python scripts/detect_orphaned_keys.py --limit 50 --apply +""" + +import argparse +import asyncio +import json +import os +import sys +from datetime import datetime, timezone + + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +from app.db.database import SessionLocal +from app.db.models import DBPrivateAIKey, DBRegion, DBSpendCap +from app.services.litellm import LiteLLMService + + +def parse_status_from_exc(exc: Exception) -> int | None: + detail = getattr(exc, "detail", "") or str(exc) + if "Status 401" in detail: + return 401 + if "Status 404" in detail: + return 404 + if "Status 403" in detail: + return 403 + if "Status 502" in detail: + return 502 + return None + + +async def check_token_exists(service: LiteLLMService, token: str) -> str: + """ + Check if a token exists in LiteLLM. + + Returns: + "exists" — token is valid + "orphaned" — token does not exist (401/404) + "error" — transient error (should not mark as orphan) + """ + try: + await service.get_key_info(token) + return "exists" + except Exception as exc: + status = parse_status_from_exc(exc) + if status in (401, 404): + return "orphaned" + return "error" + + +async def main() -> int: + parser = argparse.ArgumentParser( + description="Detect and handle orphaned AI tokens (litellm_token not found in LiteLLM)" + ) + parser.add_argument( + "--apply", + action="store_true", + help="Apply changes. Without this flag, script runs in dry-run mode.", + ) + parser.add_argument( + "--region-id", + type=int, + default=None, + help="Scope to a single region ID", + ) + parser.add_argument( + "--key-id", + type=int, + default=None, + help="Scope to a single key ID", + ) + parser.add_argument( + "--limit", + type=int, + default=None, + help="Max keys to process", + ) + parser.add_argument( + "--output-json", + default="/tmp/orphaned-keys-report.json", + help="Path to write the full report JSON", + ) + args = parser.parse_args() + + dry_run = not args.apply + mode = "DRY-RUN" if dry_run else "APPLY" + print(f"Starting orphaned key detection in {mode} mode") + print(f"Report will be written to: {args.output_json}") + print() + + session = SessionLocal() + + query = session.query(DBPrivateAIKey).filter( + DBPrivateAIKey.litellm_token.isnot(None) + ) + if args.region_id is not None: + query = query.filter(DBPrivateAIKey.region_id == args.region_id) + if args.key_id is not None: + query = query.filter(DBPrivateAIKey.id == args.key_id) + query = query.order_by(DBPrivateAIKey.id.asc()) + if args.limit is not None: + query = query.limit(args.limit) + keys = query.all() + + print(f"Found {len(keys)} keys with litellm_token to check") + + counters = { + "processed": 0, + "exists": 0, + "orphaned": 0, + "errors": 0, + "caps_deleted": 0, + } + report = { + "run_at": datetime.now(timezone.utc).isoformat(), + "mode": mode, + "orphaned_keys": [], + "error_keys": [], + } + + try: + # Preload all active regions once and build a LiteLLMService per region + # to avoid an N+1 DB query and redundant service construction per key. + region_query = session.query(DBRegion).filter(DBRegion.is_active.is_(True)) + if args.region_id is not None: + region_query = region_query.filter(DBRegion.id == args.region_id) + regions_by_id: dict[int, DBRegion] = {r.id: r for r in region_query.all()} + services_by_region_id: dict[int, LiteLLMService] = { + r.id: LiteLLMService(r.litellm_api_url, r.litellm_api_key) + for r in regions_by_id.values() + } + + for key in keys: + counters["processed"] += 1 + + region = regions_by_id.get(key.region_id) + if not region: + print( + f" key={key.id:5d} | SKIP | region={key.region_id} inactive/missing" + ) + continue + + service = services_by_region_id[region.id] + status = await check_token_exists(service, key.litellm_token) + + if status == "exists": + counters["exists"] += 1 + # Only log every 100th to reduce noise + if counters["processed"] % 100 == 0: + print( + f" key={key.id:5d} | EXISTS | region={region.name} | processed={counters['processed']}" + ) + continue + + if status == "orphaned": + counters["orphaned"] += 1 + token = key.litellm_token or "" + redacted_token = f"...{token[-4:]}" if len(token) >= 4 else "****" + entry = { + "key_id": key.id, + "key_name": key.name, + "region_id": region.id, + "region_name": region.name, + "litellm_token_hint": redacted_token, + "owner_id": key.owner_id, + "team_id": key.team_id, + "has_spend_cap": False, + "action": "would_nullify" if dry_run else "nullified", + } + + # Check for spend_caps + caps = ( + session.query(DBSpendCap) + .filter( + DBSpendCap.scope == "key", + DBSpendCap.key_id == key.id, + ) + .all() + ) + if caps: + entry["has_spend_cap"] = True + entry["spend_cap_ids"] = [c.id for c in caps] + + action_word = "WOULD ORPHAN" if dry_run else "ORPHANING" + print( + f" key={key.id:5d} | {action_word} | region={region.name} | " + f"name={key.name} | owner={key.owner_id} | team={key.team_id}" + f"{' | has_spend_cap' if caps else ''}" + ) + + if not dry_run: + # Nullify the litellm fields — row stays but token is gone + key.litellm_token = None + key.litellm_api_url = None + session.add(key) + + # Delete any spend_caps for this key + if caps: + for cap in caps: + session.delete(cap) + counters["caps_deleted"] += len(caps) + + report["orphaned_keys"].append(entry) + + elif status == "error": + counters["errors"] += 1 + entry = { + "key_id": key.id, + "key_name": key.name, + "region_id": region.id, + "region_name": region.name, + } + print( + f" key={key.id:5d} | ERROR | region={region.name} | could not check token" + ) + report["error_keys"].append(entry) + + if not dry_run: + session.commit() + + finally: + session.close() + + # Write report + report["summary"] = counters + with open(args.output_json, "w", encoding="utf-8") as fp: + json.dump(report, fp, indent=2, default=str) + + print() + print("Summary:") + print(f" Processed: {counters['processed']}") + print(f" Exists: {counters['exists']}") + print(f" Orphaned: {counters['orphaned']}") + print(f" Errors: {counters['errors']}") + print(f" Caps del: {counters['caps_deleted']}") + print(f" Report: {args.output_json}") + + return 0 if not report["error_keys"] else 1 + + +if __name__ == "__main__": + raise SystemExit(asyncio.run(main())) diff --git a/tests/test_spend.py b/tests/test_spend.py index b9ae1269..cffc2c21 100644 --- a/tests/test_spend.py +++ b/tests/test_spend.py @@ -5,6 +5,7 @@ from datetime import UTC, datetime, timedelta from sqlalchemy.exc import IntegrityError +from app.api.spend import _upsert_spend_cap from app.core.security import get_password_hash from app.db.models import ( BudgetType, @@ -1814,6 +1815,72 @@ def test_spend_caps_unique_key_scope_enforced(db, test_region, test_team_user): db.rollback() +def test_upsert_spend_cap_repairs_stale_team_and_user_columns( + db, test_region, test_team, test_team_user +): + """ + Regression: _upsert_spend_cap must find an existing key-scope row by + (region_id, key_id) even when team_id/user_id were NULL (stale), update + those columns in-place, and NOT insert a second row (which would raise a + UniqueViolation on the uq_spend_caps_key_scope index). + """ + key = DBPrivateAIKey( + name="stale-repair-key", + litellm_token="stale-repair-token", + region_id=test_region.id, + owner_id=test_team_user.id, + team_id=test_team.id, + ) + db.add(key) + db.commit() + db.refresh(key) + + # Insert an initial key-scope cap with NULL team_id/user_id (simulating + # the stale state that caused UniqueViolation before the fix). + stale_cap = DBSpendCap( + scope="key", + region_id=test_region.id, + key_id=key.id, + team_id=None, + user_id=None, + max_budget=10.0, + ) + db.add(stale_cap) + db.commit() + db.refresh(stale_cap) + stale_cap_id = stale_cap.id + + # Call _upsert_spend_cap with the correct team/user values. + # Before the fix this would miss the stale row and attempt an INSERT, + # causing a UniqueViolation. After the fix it should repair in-place. + _upsert_spend_cap( + db, + scope="key", + region_id=test_region.id, + key_id=key.id, + team_id=test_team.id, + user_id=test_team_user.id, + max_budget=20.0, + ) + db.commit() + + # Only one row for this key should exist + caps = ( + db.query(DBSpendCap) + .filter(DBSpendCap.scope == "key", DBSpendCap.key_id == key.id) + .all() + ) + assert len(caps) == 1 + repaired = caps[0] + # Same row, not a new insert + assert repaired.id == stale_cap_id + # Stale columns repaired + assert repaired.team_id == test_team.id + assert repaired.user_id == test_team_user.id + # Budget updated + assert repaired.max_budget == 20.0 + + @patch("app.api.spend.LiteLLMService.get_team_info", new_callable=AsyncMock) def test_get_team_spend_uses_db_key_cap_regardless_of_team_type( mock_get_team_info, client, admin_token, test_team, test_region, db