diff --git a/app/core/worker.py b/app/core/worker.py
index a4efc482..0df208ba 100644
--- a/app/core/worker.py
+++ b/app/core/worker.py
@@ -1379,3 +1379,105 @@ def generate_pricing_url(admin_email: str, validity_hours: int = 24) -> str:
 
     # Add the token as a query parameter
     return f"{url}?token={token}"
+
+async def expire_trial_user_keys(db: Session, user: DBUser):
+    """Expire all LiteLLM keys owned by a trial user (best-effort).
+
+    Each key that has both a LiteLLM token and a region gets its duration
+    set to "0d". A failed LiteLLM call for one key is logged and does not
+    stop the remaining keys from being processed.
+    """
+    keys = db.query(DBPrivateAIKey).filter(DBPrivateAIKey.owner_id == user.id).all()
+    for key in keys:
+        # Keys without a token or region are skipped
+        if not (key.litellm_token and key.region):
+            continue
+        try:
+            litellm_service = LiteLLMService(
+                api_url=key.region.litellm_api_url,
+                api_key=key.region.litellm_api_key
+            )
+            await litellm_service.update_key_duration(key.litellm_token, "0d")
+            logger.info(f"Set duration to 0d for key {key.id}")
+        except Exception as e:
+            logger.error(f"Failed to expire key {key.id}: {e}")
+
+
+async def deactivate_trial_user(db: Session, user: DBUser):
+    """Mark a trial user as inactive on the ORM object (no commit, no key expiry).
+
+    The caller commits the change and then expires the user's LiteLLM keys
+    via expire_trial_user_keys(), so that the user is marked inactive even
+    if the external LiteLLM call fails.
+    """
+    user.is_active = False
+    user.updated_at = datetime.now(UTC)
+
+
+async def monitor_trial_users(db: Session):
+    """
+    Monitor trial users and deactivate those that have used up their budget.
+
+    Over-budget users are marked inactive and committed first; their LiteLLM
+    keys are expired afterwards on a best-effort basis.
+
+    NOTE(review): only active users are queried, so keys whose expiry failed
+    are not retried on later runs - confirm this is acceptable.
+    """
+    logger.info("Monitoring trial users")
+    try:
+        # Get trial team (only consider active teams)
+        trial_team = db.query(DBTeam).filter(
+            DBTeam.admin_email == settings.AI_TRIAL_TEAM_EMAIL,
+            DBTeam.is_active.is_(True)
+        ).first()
+        if not trial_team:
+            logger.info("Trial team not found, skipping")
+            return
+
+        # Get all active users in the trial team (excluding admin)
+        users = db.query(DBUser).filter(
+            DBUser.team_id == trial_team.id,
+            DBUser.is_active.is_(True),
+            DBUser.role == "user"
+        ).all()
+
+        # Fetch all user budget limits in one query
+        user_limits = db.query(DBLimitedResource).filter(
+            DBLimitedResource.owner_type == OwnerType.USER,
+            DBLimitedResource.owner_id.in_([user.id for user in users]),
+            DBLimitedResource.resource == ResourceType.BUDGET
+        ).all()
+
+        user_limit_map = {limit.owner_id: limit for limit in user_limits}
+
+        # Collect over-budget users and mark them inactive in the DB
+        over_budget_users = []
+        for user in users:
+            user_limit = user_limit_map.get(user.id)
+            # Skip users without a usable limit (a None bound would make >= raise)
+            if (
+                user_limit is None
+                or user_limit.current_value is None
+                or user_limit.max_value is None
+            ):
+                continue
+            if user_limit.current_value >= user_limit.max_value:
+                logger.info(f"Trial user {user.email} (ID: {user.id}) has fully used up their budget ({user_limit.current_value} >= {user_limit.max_value}). Setting for removal.")
+                await deactivate_trial_user(db, user)
+                over_budget_users.append(user)
+
+        # Commit DB changes first so users are deactivated even if key expiry fails
+        db.commit()
+
+        # Expire LiteLLM keys after the DB commit (best-effort)
+        for user in over_budget_users:
+            await expire_trial_user_keys(db, user)
+
+        logger.info(f"Finished monitoring {len(users)} trial users")
+
+    except Exception as e:
+        logger.error(f"Error in trial user monitoring: {e}")
+        db.rollback()
+        raise
+
diff --git a/scripts/trigger_trial_recon_job.py b/scripts/trigger_trial_recon_job.py
new file mode 100644
index 00000000..cbe56494
--- /dev/null
+++ b/scripts/trigger_trial_recon_job.py
@@ -0,0 +1,86 @@
+#!/usr/bin/env python3
+
+import os
+import sys
+import asyncio
+import logging
+
+# Add the parent directory to the Python path
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+from sqlalchemy.orm import sessionmaker
+from app.db.database import engine
+from app.core.worker import monitor_trial_users
+from app.core.locking import try_acquire_lock, release_lock
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+
+logger = logging.getLogger(__name__)
+
+async def trigger_trial_recon_job():
+    """Manually trigger the trial recon job (monitor_trial_users).
+
+    Returns True when the job ran to completion, False when the lock could
+    not be acquired. Exceptions raised by the job are logged and re-raised.
+    """
+
+    # Create database session
+    SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
+    db = SessionLocal()
+
+    lock_name = "monitor_trial_users"
+
+    try:
+        logger.info("Starting manual trial recon job trigger...")
+
+        # Try to acquire the lock
+        if try_acquire_lock(lock_name, db, lock_timeout=10):
+            logger.info("Acquired monitor_trial_users lock, executing job")
+            try:
+                await monitor_trial_users(db)
+                logger.info("Trial recon job completed successfully")
+            except Exception as e:
+                logger.error(f"Error in trial recon job execution: {str(e)}")
+                raise
+            finally:
+                # Always release the lock when done
+                release_lock(lock_name, db)
+                logger.info("Released monitor_trial_users lock")
+        else:
+            logger.warning("Another process has the monitor_trial_users lock, cannot execute job")
+            return False
+
+    except Exception as e:
+        # The lock is released only by the inner finally above, i.e. only when
+        # this process acquired it; releasing here would do so a second time,
+        # or without this process ever having held the lock.
+        logger.error(f"Error in trial recon job trigger: {str(e)}")
+        raise
+    finally:
+        db.close()
+
+    return True
+
+def main():
+    """Main function to run the script"""
+    try:
+        logger.info("Triggering trial recon job manually...")
+        success = asyncio.run(trigger_trial_recon_job())
+
+        if success:
+            logger.info("✅ Trial recon job completed successfully")
+            sys.exit(0)
+        else:
+            logger.info("⚠️ Trial recon job could not be executed (lock held by another process)")
+            sys.exit(1)
+
+    except Exception as e:
+        logger.error(f"❌ Script failed: {str(e)}")
+        sys.exit(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/test_monitor_trial_users.py b/tests/test_monitor_trial_users.py
new file mode 100644
index 00000000..6425bfda
--- /dev/null
+++ b/tests/test_monitor_trial_users.py
@@ -0,0 +1,164 @@
+import pytest
+from unittest.mock import AsyncMock, patch
+from sqlalchemy.orm import Session
+from app.core.worker import monitor_trial_users
+from app.db.models import DBTeam, DBUser, DBLimitedResource, DBPrivateAIKey, DBRegion
+from app.schemas.limits import ResourceType, OwnerType, LimitType, UnitType, LimitSource
+from app.core.config import settings
+
+@pytest.fixture
+def trial_team(db: Session):
+    team = DBTeam(
+        name="AI Trial Team",
+        admin_email=settings.AI_TRIAL_TEAM_EMAIL,
+        is_active=True
+    )
+    db.add(team)
+    db.commit()
+    db.refresh(team)
+    return team
+
+@pytest.fixture
+def trial_user(db: Session, trial_team: DBTeam):
+    user = DBUser(
+        email="trial-user@example.com",
+        team_id=trial_team.id,
+        is_active=True,
+        role="user"
+    )
+    db.add(user)
+    db.commit()
+    db.refresh(user)
+    return user
+
+@pytest.fixture
+def trial_region(db: Session):
+    region = DBRegion(
+        name="test-trial-region",
+        litellm_api_url="http://mock-litellm",
+        litellm_api_key="mock-key",
+        is_active=True
+    )
+    db.add(region)
+    db.commit()
+    db.refresh(region)
+    return region
+
+@pytest.fixture
+def trial_key(db: Session, trial_user: DBUser, trial_region: DBRegion, trial_team: DBTeam):
+    key = DBPrivateAIKey(
+        owner_id=trial_user.id,
+        team_id=trial_team.id,
+        region_id=trial_region.id,
+        litellm_token="mock-token",
+        name="Trial Key"
+    )
+    db.add(key)
+    db.commit()
+    db.refresh(key)
+    return key
+
+@pytest.fixture
+def user_budget_limit(db: Session, trial_user: DBUser):
+    limit = DBLimitedResource(
+        resource=ResourceType.BUDGET,
+        limit_type=LimitType.DATA_PLANE,
+        unit=UnitType.DOLLAR,
+        owner_type=OwnerType.USER,
+        owner_id=trial_user.id,
+        max_value=10.0,
+        current_value=0.0,
+        limited_by=LimitSource.MANUAL,
+        set_by="test"
+    )
+    db.add(limit)
+    db.commit()
+    db.refresh(limit)
+    return limit
+
+@pytest.fixture
+def mock_litellm():
+    """Fixture to mock LiteLLMService."""
+    with patch('app.core.worker.LiteLLMService', autospec=True) as MockLiteLLM:
+        mock_instance = MockLiteLLM.return_value
+        mock_instance.update_key_duration = AsyncMock()
+        yield mock_instance
+
+@pytest.mark.asyncio
+async def test_monitor_trial_users_no_overage(db, trial_team, trial_user, trial_key, user_budget_limit, mock_litellm):
+    """Test that users within budget are not affected."""
+    # usage is 5.0, max is 10.0
+    user_budget_limit.current_value = 5.0
+    db.commit()
+
+    await monitor_trial_users(db)
+
+    # Verify user is still active
+    db.refresh(trial_user)
+    assert trial_user.is_active is True
+
+    # Verify LiteLLM was not called
+    assert mock_litellm.update_key_duration.call_count == 0
+
+@pytest.mark.asyncio
+async def test_monitor_trial_users_with_overage(db, trial_team, trial_user, trial_key, user_budget_limit, mock_litellm):
+    """Test that users over budget are disabled and keys expired."""
+    # usage is 10.0, max is 10.0 (limit reached)
+    user_budget_limit.current_value = 10.0
+    db.commit()
+
+    await monitor_trial_users(db)
+
+    # Verify user is deactivated
+    db.refresh(trial_user)
+    assert trial_user.is_active is False
+
+    # Verify LiteLLM called with 0d duration
+    mock_litellm.update_key_duration.assert_called_once_with("mock-token", "0d")
+
+@pytest.mark.asyncio
+async def test_monitor_trial_users_skips_admin(db, trial_team, mock_litellm):
+    """Test that admin user is skipped even if over budget."""
+    admin_user = DBUser(
+        email="admin@example.com",
+        team_id=trial_team.id,
+        is_active=True,
+        role="admin"
+    )
+    db.add(admin_user)
+    db.commit()
+
+    # Even if we add a limit
+    limit = DBLimitedResource(
+        resource=ResourceType.BUDGET,
+        limit_type=LimitType.DATA_PLANE,
+        unit=UnitType.DOLLAR,
+        owner_type=OwnerType.USER,
+        owner_id=admin_user.id,
+        max_value=10.0,
+        current_value=15.0,
+        limited_by=LimitSource.MANUAL,
+        set_by="test"
+    )
+    db.add(limit)
+    db.commit()
+
+    await monitor_trial_users(db)
+
+    db.refresh(admin_user)
+    assert admin_user.is_active is True
+    assert mock_litellm.update_key_duration.call_count == 0
+
+@pytest.mark.asyncio
+async def test_monitor_trial_users_key_expiry_failure(db, trial_team, trial_user, trial_key, user_budget_limit, mock_litellm):
+    """Test that users are still deactivated when LiteLLM key expiry fails."""
+    user_budget_limit.current_value = 10.0
+    db.commit()
+    mock_litellm.update_key_duration.side_effect = Exception("LiteLLM unavailable")
+
+    await monitor_trial_users(db)
+
+    # Deactivation is committed before key expiry, so it must survive the failure
+    db.refresh(trial_user)
+    assert trial_user.is_active is False
+    mock_litellm.update_key_duration.assert_called_once_with("mock-token", "0d")