-
Notifications
You must be signed in to change notification settings - Fork 1
refactor: optimize trial user monitor with batch query and deactivate helper #214
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
9f116de
c382ea3
f0cfe74
0b066e8
ad8bbcb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,85 @@ | ||
| #!/usr/bin/env python3 | ||
|
|
||
| import os | ||
| import sys | ||
| import asyncio | ||
| import logging | ||
|
|
||
| # Add the parent directory to the Python path | ||
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) | ||
|
|
||
| from sqlalchemy.orm import sessionmaker | ||
| from app.db.database import engine | ||
| from app.core.worker import monitor_trial_users | ||
| from app.core.locking import try_acquire_lock, release_lock | ||
|
|
||
| # Configure logging | ||
| logging.basicConfig( | ||
| level=logging.INFO, | ||
| format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' | ||
| ) | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| async def trigger_trial_recon_job(): | ||
| """Manually trigger the trial recon job (monitor_trial_users)""" | ||
|
|
||
| # Create database session | ||
| SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) | ||
| db = SessionLocal() | ||
|
|
||
| lock_name = "monitor_trial_users" | ||
|
|
||
| try: | ||
| logger.info("Starting manual trial recon job trigger...") | ||
|
|
||
| # Try to acquire the lock | ||
| if try_acquire_lock(lock_name, db, lock_timeout=10): | ||
| logger.info("Acquired monitor_trial_users lock, executing job") | ||
| try: | ||
| await monitor_trial_users(db) | ||
| logger.info("Trial recon job completed successfully") | ||
| except Exception as e: | ||
| logger.error(f"Error in trial recon job execution: {str(e)}") | ||
| raise | ||
| finally: | ||
| # Always release the lock when done | ||
| release_lock(lock_name, db) | ||
| logger.info("Released monitor_trial_users lock") | ||
| else: | ||
| logger.warning("Another process has the monitor_trial_users lock, cannot execute job") | ||
| return False | ||
|
|
||
| except Exception as e: | ||
| logger.error(f"Error in trial recon job trigger: {str(e)}") | ||
| # Try to release lock in case of error | ||
| try: | ||
| release_lock(lock_name, db) | ||
| logger.info("Released lock after error") | ||
| except Exception as release_error: | ||
| logger.error(f"Error releasing lock: {str(release_error)}") | ||
| raise | ||
| finally: | ||
|
Comment on lines
+42
to
+62
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
When |
||
| db.close() | ||
|
|
||
| return True | ||
|
|
||
| def main(): | ||
| """Main function to run the script""" | ||
| try: | ||
| logger.info("Triggering trial recon job manually...") | ||
| success = asyncio.run(trigger_trial_recon_job()) | ||
|
|
||
| if success: | ||
| logger.info("✅ Trial recon job completed successfully") | ||
| sys.exit(0) | ||
| else: | ||
| logger.info("⚠️ Trial recon job could not be executed (lock held by another process)") | ||
| sys.exit(1) | ||
|
|
||
| except Exception as e: | ||
| logger.error(f"❌ Script failed: {str(e)}") | ||
| sys.exit(1) | ||
|
|
||
| if __name__ == "__main__": | ||
| main() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,150 @@ | ||
| import pytest | ||
| from unittest.mock import AsyncMock, patch | ||
| from sqlalchemy.orm import Session | ||
| from app.core.worker import monitor_trial_users | ||
| from app.db.models import DBTeam, DBUser, DBLimitedResource, DBPrivateAIKey, DBRegion | ||
| from app.schemas.limits import ResourceType, OwnerType, LimitType, UnitType, LimitSource | ||
| from app.core.config import settings | ||
|
|
||
| @pytest.fixture | ||
| def trial_team(db: Session): | ||
| team = DBTeam( | ||
| name="AI Trial Team", | ||
| admin_email=settings.AI_TRIAL_TEAM_EMAIL, | ||
| is_active=True | ||
| ) | ||
| db.add(team) | ||
| db.commit() | ||
| db.refresh(team) | ||
| return team | ||
|
dan2k3k4 marked this conversation as resolved.
|
||
|
|
||
| @pytest.fixture | ||
| def trial_user(db: Session, trial_team: DBTeam): | ||
| user = DBUser( | ||
| email="trial-user@example.com", | ||
| team_id=trial_team.id, | ||
| is_active=True, | ||
| role="user" | ||
| ) | ||
| db.add(user) | ||
| db.commit() | ||
| db.refresh(user) | ||
| return user | ||
|
|
||
| @pytest.fixture | ||
| def trial_region(db: Session): | ||
| region = DBRegion( | ||
| name="test-trial-region", | ||
| litellm_api_url="http://mock-litellm", | ||
| litellm_api_key="mock-key", | ||
| is_active=True | ||
| ) | ||
| db.add(region) | ||
| db.commit() | ||
| db.refresh(region) | ||
| return region | ||
|
|
||
| @pytest.fixture | ||
| def trial_key(db: Session, trial_user: DBUser, trial_region: DBRegion, trial_team: DBTeam): | ||
| key = DBPrivateAIKey( | ||
| owner_id=trial_user.id, | ||
| team_id=trial_team.id, | ||
| region_id=trial_region.id, | ||
| litellm_token="mock-token", | ||
| name="Trial Key" | ||
| ) | ||
| db.add(key) | ||
| db.commit() | ||
| db.refresh(key) | ||
| return key | ||
|
|
||
| @pytest.fixture | ||
| def user_budget_limit(db: Session, trial_user: DBUser): | ||
| limit = DBLimitedResource( | ||
| resource=ResourceType.BUDGET, | ||
| limit_type=LimitType.DATA_PLANE, | ||
| unit=UnitType.DOLLAR, | ||
| owner_type=OwnerType.USER, | ||
| owner_id=trial_user.id, | ||
| max_value=10.0, | ||
| current_value=0.0, | ||
| limited_by=LimitSource.MANUAL, | ||
| set_by="test" | ||
| ) | ||
| db.add(limit) | ||
| db.commit() | ||
| db.refresh(limit) | ||
| return limit | ||
|
|
||
| @pytest.fixture | ||
| def mock_litellm(): | ||
| """Fixture to mock LiteLLMService.""" | ||
| with patch('app.core.worker.LiteLLMService', autospec=True) as MockLiteLLM: | ||
| mock_instance = MockLiteLLM.return_value | ||
| mock_instance.update_key_duration = AsyncMock() | ||
| yield mock_instance | ||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_monitor_trial_users_no_overage(db, trial_team, trial_user, trial_key, user_budget_limit, mock_litellm): | ||
| """Test that users within budget are not affected.""" | ||
| # usage is 5.0, max is 10.0 | ||
| user_budget_limit.current_value = 5.0 | ||
| db.commit() | ||
|
|
||
| await monitor_trial_users(db) | ||
|
|
||
| # Verify user is still active | ||
| db.refresh(trial_user) | ||
| assert trial_user.is_active is True | ||
|
|
||
| # Verify LiteLLM was not called | ||
| assert mock_litellm.update_key_duration.call_count == 0 | ||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_monitor_trial_users_with_overage(db, trial_team, trial_user, trial_key, user_budget_limit, mock_litellm): | ||
| """Test that users over budget are disabled and keys expired.""" | ||
| # usage is 10.0, max is 10.0 (limit reached) | ||
| user_budget_limit.current_value = 10.0 | ||
| db.commit() | ||
|
|
||
| await monitor_trial_users(db) | ||
|
|
||
| # Verify user is deactivated | ||
| db.refresh(trial_user) | ||
| assert trial_user.is_active is False | ||
|
|
||
| # Verify LiteLLM called with 0d duration | ||
| mock_litellm.update_key_duration.assert_called_once_with("mock-token", "0d") | ||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_monitor_trial_users_skips_admin(db, trial_team, mock_litellm): | ||
| """Test that admin user is skipped even if over budget.""" | ||
| admin_user = DBUser( | ||
| email="admin@example.com", | ||
| team_id=trial_team.id, | ||
| is_active=True, | ||
| role="admin" | ||
| ) | ||
| db.add(admin_user) | ||
| db.commit() | ||
|
|
||
| # Even if we add a limit | ||
| limit = DBLimitedResource( | ||
| resource=ResourceType.BUDGET, | ||
| limit_type=LimitType.DATA_PLANE, | ||
| unit=UnitType.DOLLAR, | ||
| owner_type=OwnerType.USER, | ||
| owner_id=admin_user.id, | ||
| max_value=10.0, | ||
| current_value=15.0, | ||
| limited_by=LimitSource.MANUAL, | ||
| set_by="test" | ||
| ) | ||
| db.add(limit) | ||
| db.commit() | ||
|
|
||
| await monitor_trial_users(db) | ||
|
|
||
| db.refresh(admin_user) | ||
| assert admin_user.is_active is True | ||
| assert mock_litellm.update_key_duration.call_count == 0 | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thought (separate PR): could be worth to have an enum for
DBUser.role.