Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions app/core/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1379,3 +1379,85 @@ def generate_pricing_url(admin_email: str, validity_hours: int = 24) -> str:

# Add the token as a query parameter
return f"{url}?token={token}"

async def expire_trial_user_keys(db: Session, user: DBUser):
"""Expire all LiteLLM keys for a trial user (best-effort, failures are logged)."""
keys = db.query(DBPrivateAIKey).filter(DBPrivateAIKey.owner_id == user.id).all()
for key in keys:
if key.litellm_token and key.region:
try:
litellm_service = LiteLLMService(
api_url=key.region.litellm_api_url,
api_key=key.region.litellm_api_key
)
await litellm_service.update_key_duration(key.litellm_token, "0d")
logger.info(f"Set duration to 0d for key {key.id}")
except Exception as e:
logger.error(f"Failed to expire key {key.id}: {e}")


async def deactivate_trial_user(db: Session, user: DBUser):
"""Deactivate a trial user in the DB and expire all their LiteLLM keys.

The DB deactivation is committed by the caller before key expiry so that
the user is marked inactive even if the external LiteLLM call fails.
"""
user.is_active = False
user.updated_at = datetime.now(UTC)


async def monitor_trial_users(db: Session):
"""
Monitor trial users and expire them if they have exceeded their budget.
"""
logger.info("Monitoring trial users")
try:
# Get trial team (only consider active teams)
trial_team = db.query(DBTeam).filter(
DBTeam.admin_email == settings.AI_TRIAL_TEAM_EMAIL,
DBTeam.is_active.is_(True)
).first()
if not trial_team:
logger.info("Trial team not found, skipping")
return

# Get all active users in the trial team (excluding admin)
users = db.query(DBUser).filter(
DBUser.team_id == trial_team.id,
DBUser.is_active,
DBUser.role == "user"
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thought (separate PR): could be worth to have an enum for DBUser.role.

).all()

# Fetch all user budget limits in one query
user_limits = db.query(DBLimitedResource).filter(
DBLimitedResource.owner_type == OwnerType.USER,
DBLimitedResource.owner_id.in_([user.id for user in users]),
DBLimitedResource.resource == ResourceType.BUDGET
).all()

user_limit_map = {limit.owner_id: limit for limit in user_limits}

# Collect over-budget users and mark them inactive in the DB
over_budget_users = []
for user in users:
user_limit = user_limit_map.get(user.id)
if user_limit and user_limit.current_value is not None:
if user_limit.current_value >= user_limit.max_value:
logger.info(f"Trial user {user.email} (ID: {user.id}) has fully used up their budget ({user_limit.current_value} >= {user_limit.max_value}). Setting for removal.")
await deactivate_trial_user(db, user)
over_budget_users.append(user)

# Commit DB changes first so users are deactivated even if key expiry fails
db.commit()

# Expire LiteLLM keys after the DB commit (best-effort)
for user in over_budget_users:
await expire_trial_user_keys(db, user)

logger.info(f"Finished monitoring {len(users)} trial users")

except Exception as e:
logger.error(f"Error in trial user monitoring: {e}")
db.rollback()
raise

85 changes: 85 additions & 0 deletions scripts/trigger_trial_recon_job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#!/usr/bin/env python3

import os
import sys
import asyncio
import logging

# Add the parent directory to the Python path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from sqlalchemy.orm import sessionmaker
from app.db.database import engine
from app.core.worker import monitor_trial_users
from app.core.locking import try_acquire_lock, release_lock

# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

logger = logging.getLogger(__name__)

async def trigger_trial_recon_job():
"""Manually trigger the trial recon job (monitor_trial_users)"""

# Create database session
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
db = SessionLocal()

lock_name = "monitor_trial_users"

try:
logger.info("Starting manual trial recon job trigger...")

# Try to acquire the lock
if try_acquire_lock(lock_name, db, lock_timeout=10):
logger.info("Acquired monitor_trial_users lock, executing job")
try:
await monitor_trial_users(db)
logger.info("Trial recon job completed successfully")
except Exception as e:
logger.error(f"Error in trial recon job execution: {str(e)}")
raise
finally:
# Always release the lock when done
release_lock(lock_name, db)
logger.info("Released monitor_trial_users lock")
else:
logger.warning("Another process has the monitor_trial_users lock, cannot execute job")
return False

except Exception as e:
logger.error(f"Error in trial recon job trigger: {str(e)}")
# Try to release lock in case of error
try:
release_lock(lock_name, db)
logger.info("Released lock after error")
except Exception as release_error:
logger.error(f"Error releasing lock: {str(release_error)}")
raise
finally:
Comment on lines +42 to +62
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Double lock release on error path

When monitor_trial_users raises, the inner finally block (line 48) already calls release_lock before re-raising. The outer except block (line 55) then calls release_lock a second time. release_lock sets lock.value = "false" and commits, so two calls are idempotent in practice — but the second call triggers a superfluous DB round-trip and makes the intent unclear. The outer except release guard was presumably added for cases where the lock was acquired but the inner finally never ran, which cannot actually happen here given the control flow.

db.close()

return True

def main():
"""Main function to run the script"""
try:
logger.info("Triggering trial recon job manually...")
success = asyncio.run(trigger_trial_recon_job())

if success:
logger.info("✅ Trial recon job completed successfully")
sys.exit(0)
else:
logger.info("⚠️ Trial recon job could not be executed (lock held by another process)")
sys.exit(1)

except Exception as e:
logger.error(f"❌ Script failed: {str(e)}")
sys.exit(1)

if __name__ == "__main__":
main()
150 changes: 150 additions & 0 deletions tests/test_monitor_trial_users.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import pytest
from unittest.mock import AsyncMock, patch
from sqlalchemy.orm import Session
from app.core.worker import monitor_trial_users
from app.db.models import DBTeam, DBUser, DBLimitedResource, DBPrivateAIKey, DBRegion
from app.schemas.limits import ResourceType, OwnerType, LimitType, UnitType, LimitSource
from app.core.config import settings

@pytest.fixture
def trial_team(db: Session):
team = DBTeam(
name="AI Trial Team",
admin_email=settings.AI_TRIAL_TEAM_EMAIL,
is_active=True
)
db.add(team)
db.commit()
db.refresh(team)
return team
Comment thread
dan2k3k4 marked this conversation as resolved.

@pytest.fixture
def trial_user(db: Session, trial_team: DBTeam):
user = DBUser(
email="trial-user@example.com",
team_id=trial_team.id,
is_active=True,
role="user"
)
db.add(user)
db.commit()
db.refresh(user)
return user

@pytest.fixture
def trial_region(db: Session):
region = DBRegion(
name="test-trial-region",
litellm_api_url="http://mock-litellm",
litellm_api_key="mock-key",
is_active=True
)
db.add(region)
db.commit()
db.refresh(region)
return region

@pytest.fixture
def trial_key(db: Session, trial_user: DBUser, trial_region: DBRegion, trial_team: DBTeam):
key = DBPrivateAIKey(
owner_id=trial_user.id,
team_id=trial_team.id,
region_id=trial_region.id,
litellm_token="mock-token",
name="Trial Key"
)
db.add(key)
db.commit()
db.refresh(key)
return key

@pytest.fixture
def user_budget_limit(db: Session, trial_user: DBUser):
limit = DBLimitedResource(
resource=ResourceType.BUDGET,
limit_type=LimitType.DATA_PLANE,
unit=UnitType.DOLLAR,
owner_type=OwnerType.USER,
owner_id=trial_user.id,
max_value=10.0,
current_value=0.0,
limited_by=LimitSource.MANUAL,
set_by="test"
)
db.add(limit)
db.commit()
db.refresh(limit)
return limit

@pytest.fixture
def mock_litellm():
"""Fixture to mock LiteLLMService."""
with patch('app.core.worker.LiteLLMService', autospec=True) as MockLiteLLM:
mock_instance = MockLiteLLM.return_value
mock_instance.update_key_duration = AsyncMock()
yield mock_instance

@pytest.mark.asyncio
async def test_monitor_trial_users_no_overage(db, trial_team, trial_user, trial_key, user_budget_limit, mock_litellm):
"""Test that users within budget are not affected."""
# usage is 5.0, max is 10.0
user_budget_limit.current_value = 5.0
db.commit()

await monitor_trial_users(db)

# Verify user is still active
db.refresh(trial_user)
assert trial_user.is_active is True

# Verify LiteLLM was not called
assert mock_litellm.update_key_duration.call_count == 0

@pytest.mark.asyncio
async def test_monitor_trial_users_with_overage(db, trial_team, trial_user, trial_key, user_budget_limit, mock_litellm):
"""Test that users over budget are disabled and keys expired."""
# usage is 10.0, max is 10.0 (limit reached)
user_budget_limit.current_value = 10.0
db.commit()

await monitor_trial_users(db)

# Verify user is deactivated
db.refresh(trial_user)
assert trial_user.is_active is False

# Verify LiteLLM called with 0d duration
mock_litellm.update_key_duration.assert_called_once_with("mock-token", "0d")

@pytest.mark.asyncio
async def test_monitor_trial_users_skips_admin(db, trial_team, mock_litellm):
"""Test that admin user is skipped even if over budget."""
admin_user = DBUser(
email="admin@example.com",
team_id=trial_team.id,
is_active=True,
role="admin"
)
db.add(admin_user)
db.commit()

# Even if we add a limit
limit = DBLimitedResource(
resource=ResourceType.BUDGET,
limit_type=LimitType.DATA_PLANE,
unit=UnitType.DOLLAR,
owner_type=OwnerType.USER,
owner_id=admin_user.id,
max_value=10.0,
current_value=15.0,
limited_by=LimitSource.MANUAL,
set_by="test"
)
db.add(limit)
db.commit()

await monitor_trial_users(db)

db.refresh(admin_user)
assert admin_user.is_active is True
assert mock_litellm.update_key_duration.call_count == 0