-
Notifications
You must be signed in to change notification settings - Fork 1
chore: add rate limiter and redis #283
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
a5cb70c
eab0464
4f2e069
c2df63e
755599e
bccfab0
ae779f7
188c3ee
713ee78
a49097b
45c0da6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,6 +6,7 @@ | |
| import uuid | ||
| from datetime import datetime | ||
| import email_validator | ||
| from fastapi_limiter.depends import RateLimiter | ||
|
|
||
| from typing import Optional, List, Union | ||
|
|
||
|
|
@@ -15,6 +16,10 @@ | |
| from urllib.parse import urlparse | ||
| from jose import JWTError, jwt | ||
|
|
||
| from app.api.teams import register_team | ||
| from app.api.users import _create_user_in_db, get_user_by_email | ||
| from app.api.private_ai_keys import create_private_ai_key | ||
|
|
||
| from app.core.config import settings | ||
| from app.core.dependencies import get_limit_service | ||
| from app.core.roles import UserRole | ||
|
|
@@ -59,20 +64,14 @@ | |
| BudgetType, | ||
| ) | ||
|
|
||
| from app.api.teams import register_team | ||
| from app.api.users import _create_user_in_db, get_user_by_email | ||
| from app.api.private_ai_keys import create_private_ai_key | ||
|
|
||
| auth_logger = logging.getLogger(__name__) | ||
|
|
||
| router = APIRouter(tags=["auth"]) | ||
|
|
||
|
|
||
| def get_cookie_domain(): | ||
| """Extract domain from COOKIE_DOMAIN or LAGOON_ROUTES for cookie settings.""" | ||
| # First check for explicit cookie domain setting | ||
| cookie_domain = os.getenv("COOKIE_DOMAIN") | ||
| if cookie_domain: | ||
| if (cookie_domain := os.getenv("COOKIE_DOMAIN")): | ||
| return cookie_domain | ||
|
|
||
| # Fall back to extracting from LAGOON_ROUTES | ||
|
|
@@ -141,7 +140,6 @@ def create_and_set_access_token( | |
| access_token = create_access_token(data={"sub": user_email.lower()}) | ||
|
|
||
| # Get cookie domain from LAGOON_ROUTES | ||
| cookie_domain = get_cookie_domain() | ||
|
|
||
| # Set cookie expiration based on user role | ||
| # System administrators get 8 hours (28800 seconds), regular users get 30 minutes (1800 seconds) | ||
|
|
@@ -160,7 +158,7 @@ def create_and_set_access_token( | |
| } | ||
|
|
||
| # Only set domain if we got one from LAGOON_ROUTES | ||
| if cookie_domain: | ||
| if (cookie_domain := get_cookie_domain()): | ||
| cookie_settings["domain"] = cookie_domain | ||
|
|
||
| # Set cookie with appropriate settings | ||
|
|
@@ -467,6 +465,9 @@ async def validate_email( | |
| email_data: Optional[EmailValidation] = None, | ||
| email: Optional[str] = Form(None), | ||
| db: Session = Depends(get_db), | ||
| _: None = Depends( | ||
| RateLimiter(times=settings.RATE_LIMIT_VALIDATE_EMAIL, seconds=60) | ||
| ), | ||
|
dan2k3k4 marked this conversation as resolved.
Comment on lines
+468
to
+470
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @copilot apply changes based on this feedback
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The test
Comment on lines
+468
to
+470
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The |
||
| ): | ||
| """ | ||
| Validate an email address and generate a validation code. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| import logging | ||
| import os | ||
| from contextlib import asynccontextmanager | ||
| from datetime import UTC | ||
|
|
||
| from app.__version__ import __version__ | ||
| from app.api import ( | ||
|
|
@@ -18,17 +19,24 @@ | |
| teams, | ||
| users, | ||
| ) | ||
| from app.core.locking import release_lock, try_acquire_lock | ||
| from app.core.config import settings | ||
| from app.core.worker import hard_delete_expired_teams, monitor_teams | ||
| from app.db.database import get_db | ||
| from app.middleware.audit import AuditLogMiddleware | ||
| from app.middleware.auth import AuthMiddleware | ||
| from app.middleware.caching import CacheControlMiddleware | ||
| from app.middleware.prometheus import PrometheusMiddleware | ||
| from apscheduler.schedulers.asyncio import AsyncIOScheduler | ||
| from apscheduler.triggers.cron import CronTrigger | ||
| from fastapi import FastAPI | ||
| from fastapi_limiter import FastAPILimiter | ||
| from fastapi.middleware.cors import CORSMiddleware | ||
| from fastapi.middleware.trustedhost import TrustedHostMiddleware | ||
| from fastapi.openapi.docs import get_swagger_ui_html | ||
| from fastapi.openapi.utils import get_openapi | ||
| from prometheus_fastapi_instrumentator import Instrumentator, metrics | ||
| from redis.asyncio import Redis as AsyncRedis | ||
| from starlette.middleware.base import BaseHTTPMiddleware | ||
|
|
||
| # Set timezone environment variable to prevent tzlocal warning | ||
|
|
@@ -52,8 +60,128 @@ async def dispatch(self, request, call_next): | |
|
|
||
| @asynccontextmanager | ||
| async def lifespan(app: FastAPI): | ||
| # Initialize rate limiter | ||
| await FastAPILimiter.init( | ||
| redis=AsyncRedis.from_url(settings.REDIS_URL, decode_responses=True) | ||
| ) | ||
|
Comment on lines
62
to
+66
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
|
||
| # Create scheduler | ||
| scheduler = AsyncIOScheduler() | ||
|
|
||
| async def monitor_teams_job(): | ||
| lock_name = "monitor_teams" | ||
| lock_acquired = False | ||
|
|
||
| try: | ||
| # Try to acquire the lock using a dedicated short-lived session | ||
| lock_db = next(get_db()) | ||
| try: | ||
| lock_acquired = try_acquire_lock(lock_name, lock_db, lock_timeout=10) | ||
| finally: | ||
| lock_db.close() | ||
|
|
||
| if lock_acquired: | ||
| logger.info("Acquired monitor_teams lock, executing job") | ||
| job_db = next(get_db()) | ||
| try: | ||
| await monitor_teams(job_db) | ||
| except Exception as e: | ||
| job_db.rollback() | ||
| logger.error(f"Error in monitor_teams background task: {str(e)}") | ||
| finally: | ||
| job_db.close() | ||
| else: | ||
| logger.info("Another process has the monitor_teams lock, skipping execution") | ||
| except Exception as e: | ||
| logger.error(f"Error in monitor_teams job: {str(e)}") | ||
| finally: | ||
| # Always release the lock when it was acquired, using a separate session | ||
| if lock_acquired: | ||
| release_db = next(get_db()) | ||
| try: | ||
| release_lock(lock_name, release_db) | ||
| except Exception as release_error: | ||
| logger.error(f"Error releasing lock: {str(release_error)}") | ||
| finally: | ||
| release_db.close() | ||
|
|
||
| # Set schedule based on environment | ||
| if settings.ENV_SUFFIX == "local": | ||
| cron_trigger = CronTrigger(minute='*/10', timezone=UTC, jitter=180) | ||
| else: | ||
| # Run every hour in other environments with jitter | ||
| cron_trigger = CronTrigger(hour='*', minute=0, timezone=UTC, jitter=60) | ||
|
|
||
| scheduler.add_job( | ||
| monitor_teams_job, | ||
| trigger=cron_trigger, | ||
| id='monitor_teams', | ||
| replace_existing=True | ||
| ) | ||
|
|
||
| # Hard delete job for teams that have been soft-deleted for 60+ days | ||
| async def hard_delete_teams_job(): | ||
| lock_name = "hard_delete_teams" | ||
| lock_acquired = False | ||
|
|
||
| try: | ||
| # Try to acquire the lock using a dedicated short-lived session | ||
| lock_db = next(get_db()) | ||
| try: | ||
| lock_acquired = try_acquire_lock(lock_name, lock_db, lock_timeout=10) | ||
| finally: | ||
| lock_db.close() | ||
|
|
||
| if lock_acquired: | ||
| logger.info("Acquired hard_delete_teams lock, executing job") | ||
| job_db = next(get_db()) | ||
| try: | ||
| await hard_delete_expired_teams(job_db) | ||
| except Exception as e: | ||
| job_db.rollback() | ||
| logger.error(f"Error in hard_delete_expired_teams background task: {str(e)}") | ||
| finally: | ||
| job_db.close() | ||
| else: | ||
| logger.info("Another process has the hard_delete_teams lock, skipping execution") | ||
| except Exception as e: | ||
| logger.error(f"Error in hard_delete_teams job: {str(e)}") | ||
| finally: | ||
| # Always release the lock when it was acquired, using a separate session | ||
| if lock_acquired: | ||
| release_db = next(get_db()) | ||
| try: | ||
| release_lock(lock_name, release_db) | ||
| except Exception as release_error: | ||
| logger.error(f"Error releasing lock: {str(release_error)}") | ||
| finally: | ||
| release_db.close() | ||
|
|
||
| # Set schedule based on environment for hard delete job | ||
| if settings.ENV_SUFFIX == "local": | ||
| # In local env, run every hour at :30 for testing | ||
| hard_delete_trigger = CronTrigger(hour='*', minute=30, timezone=UTC) | ||
| else: | ||
| # In production, run daily at 3 AM | ||
| hard_delete_trigger = CronTrigger(hour=3, minute=0, timezone=UTC) | ||
|
|
||
| scheduler.add_job( | ||
| hard_delete_teams_job, | ||
| trigger=hard_delete_trigger, | ||
| id='hard_delete_teams', | ||
| replace_existing=True | ||
| ) | ||
|
|
||
| # Start the scheduler | ||
| scheduler.start() | ||
|
|
||
| yield | ||
|
|
||
| # Shutdown the rate limiter | ||
| await FastAPILimiter.close() | ||
|
|
||
| # Shutdown the scheduler | ||
| scheduler.shutdown() | ||
|
|
||
| app = FastAPI( | ||
| title="Private AI Keys as a Service", | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# Get cookie domain from LAGOON_ROUTESis now orphaned — theget_cookie_domain()call it referenced was refactored into the walrus-operator expression below. The blank line and stale comment are misleading.