diff --git a/Makefile b/Makefile index b5c47bfc..730076d7 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: backend-test backend-test-build test-clean test-network test-postgres frontend-test frontend-test-build migration-create migration-upgrade migration-downgrade migration-stamp +.PHONY: backend-test backend-test-build test-clean test-network test-postgres test-redis frontend-test frontend-test-build migration-create migration-upgrade migration-downgrade migration-stamp # Default target all: backend-test @@ -23,9 +23,17 @@ test-postgres: test-clean test-network pgvector/pgvector:pg16 && \ sleep 5 +# Start Redis container for testing +test-redis: test-network + docker run -d \ + --name amazee-test-redis \ + --network amazeeai_default \ + redis:alpine && \ + sleep 2 + # Run backend tests for a specific regex # Usage: make backend-test-regex regex="test_pattern" -backend-test-regex: test-clean backend-test-build test-postgres +backend-test-regex: test-clean backend-test-build test-postgres test-redis @if [ -z "$(regex)" ]; then \ echo "Error: regex parameter is required. Usage: make backend-test-regex regex=\"test_pattern\""; \ exit 1; \ @@ -33,6 +41,7 @@ backend-test-regex: test-clean backend-test-build test-postgres docker run --rm \ --network amazeeai_default \ -e DATABASE_URL="postgresql://postgres:postgres@amazee-test-postgres/postgres_service" \ + -e REDIS_URL="redis://amazee-test-redis:6379" \ -e SECRET_KEY="test-secret-key" \ -e POSTGRES_HOST="amazee-test-postgres" \ -e POSTGRES_USER="postgres" \ @@ -47,10 +56,11 @@ backend-test-regex: test-clean backend-test-build test-postgres amazee-backend-test pytest -vv -k "$(regex)" # Run backend tests in a new container -backend-test: test-clean backend-test-build test-postgres +backend-test: test-clean backend-test-build test-postgres test-redis docker run --rm \ --network amazeeai_default \ -e DATABASE_URL="postgresql://postgres:postgres@amazee-test-postgres/postgres_service" \ + -e REDIS_URL="redis://amazee-test-redis:6379" \ -e SECRET_KEY="test-secret-key" \ -e POSTGRES_HOST="amazee-test-postgres" \ -e POSTGRES_USER="postgres" \ @@ -65,10 +75,11 @@ backend-test: test-clean backend-test-build test-postgres amazee-backend-test # Run backend tests with coverage report -backend-test-cov: test-clean backend-test-build test-postgres +backend-test-cov: test-clean backend-test-build test-postgres test-redis docker run --rm \ --network amazeeai_default \ -e DATABASE_URL="postgresql://postgres:postgres@amazee-test-postgres/postgres_service" \ + -e REDIS_URL="redis://amazee-test-redis:6379" \ -e SECRET_KEY="test-secret-key" \ -e POSTGRES_HOST="amazee-test-postgres" \ -e POSTGRES_USER="postgres" \ @@ -99,6 +110,8 @@ test-all: backend-test frontend-test test-clean: docker stop amazee-test-postgres 2>/dev/null || true docker rm amazee-test-postgres 2>/dev/null || true + docker stop amazee-test-redis 2>/dev/null || true + docker rm amazee-test-redis 2>/dev/null || true docker network rm amazeeai_default 2>/dev/null || true docker rmi amazee-backend-test 2>/dev/null || true docker rmi amazeeai-frontend-test 2>/dev/null || true diff --git a/app/api/auth.py b/app/api/auth.py index cdd1bd5d..3b7354a5 100644 --- a/app/api/auth.py +++ b/app/api/auth.py @@ -6,6 +6,7 @@ import uuid from datetime import datetime import email_validator +from fastapi_limiter.depends import RateLimiter from typing import Optional, List, Union @@ -15,6 +16,10 @@ from urllib.parse import urlparse from jose import JWTError, jwt +from app.api.teams import register_team +from app.api.users import _create_user_in_db, get_user_by_email +from app.api.private_ai_keys import create_private_ai_key + from app.core.config import settings from app.core.dependencies import get_limit_service from app.core.roles import UserRole @@ -59,20 +64,14 @@ BudgetType, ) -from app.api.teams import register_team -from app.api.users import _create_user_in_db, get_user_by_email -from app.api.private_ai_keys import create_private_ai_key - auth_logger = logging.getLogger(__name__) router = APIRouter(tags=["auth"]) - def get_cookie_domain(): """Extract domain from COOKIE_DOMAIN or LAGOON_ROUTES for cookie settings.""" # First check for explicit cookie domain setting - cookie_domain = os.getenv("COOKIE_DOMAIN") - if cookie_domain: + if (cookie_domain := os.getenv("COOKIE_DOMAIN")): return cookie_domain # Fall back to extracting from LAGOON_ROUTES @@ -141,7 +140,6 @@ def create_and_set_access_token( access_token = create_access_token(data={"sub": user_email.lower()}) # Get cookie domain from LAGOON_ROUTES - cookie_domain = get_cookie_domain() # Set cookie expiration based on user role # System administrators get 8 hours (28800 seconds), regular users get 30 minutes (1800 seconds) @@ -160,7 +158,7 @@ def create_and_set_access_token( } # Only set domain if we got one from LAGOON_ROUTES - if cookie_domain: + if (cookie_domain := get_cookie_domain()): cookie_settings["domain"] = cookie_domain # Set cookie with appropriate settings @@ -467,6 +465,9 @@ async def validate_email( email_data: Optional[EmailValidation] = None, email: Optional[str] = Form(None), db: Session = Depends(get_db), + _: None = Depends( + RateLimiter(times=settings.RATE_LIMIT_VALIDATE_EMAIL, seconds=60) + ), ): """ Validate an email address and generate a validation code. diff --git a/app/core/config.py b/app/core/config.py index 2b255eec..b3a93ebe 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -56,6 +56,11 @@ class Settings(BaseSettings): DEDICATED_DEFAULT_SERVICE_KEYS: float | None = None DEDICATED_DEFAULT_VECTOR_DB_COUNT: float | None = None DEDICATED_DEFAULT_RPM_PER_KEY: float | None = None + REDIS_URL: str = os.getenv( + "REDIS_URL", + f"redis://{os.getenv('REDIS_HOST', 'localhost')}:{os.getenv('REDIS_PORT', '6379')}/0", + ) + RATE_LIMIT_VALIDATE_EMAIL: int = int(os.getenv("RATE_LIMIT_VALIDATE_EMAIL", "5")) model_config = ConfigDict(env_file=".env", extra="ignore") main_route: str = os.getenv("LAGOON_ROUTE", "http://localhost:8800") diff --git a/app/main.py b/app/main.py index fdf0c5f2..82b020c1 100644 --- a/app/main.py +++ b/app/main.py @@ -1,6 +1,7 @@ import logging import os from contextlib import asynccontextmanager +from datetime import UTC from app.__version__ import __version__ from app.api import ( @@ -18,17 +19,24 @@ teams, users, ) +from app.core.locking import release_lock, try_acquire_lock from app.core.config import settings +from app.core.worker import hard_delete_expired_teams, monitor_teams +from app.db.database import get_db from app.middleware.audit import AuditLogMiddleware from app.middleware.auth import AuthMiddleware from app.middleware.caching import CacheControlMiddleware from app.middleware.prometheus import PrometheusMiddleware +from apscheduler.schedulers.asyncio import AsyncIOScheduler +from apscheduler.triggers.cron import CronTrigger from fastapi import FastAPI +from fastapi_limiter import FastAPILimiter from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.trustedhost import TrustedHostMiddleware from fastapi.openapi.docs import get_swagger_ui_html from fastapi.openapi.utils import get_openapi from prometheus_fastapi_instrumentator import Instrumentator, metrics +from redis.asyncio import Redis as AsyncRedis from starlette.middleware.base import BaseHTTPMiddleware # Set timezone environment variable to prevent tzlocal warning @@ -52,8 +60,128 @@ async def dispatch(self, request, call_next): @asynccontextmanager async def lifespan(app: FastAPI): + # Initialize rate limiter + await FastAPILimiter.init( + redis=AsyncRedis.from_url(settings.REDIS_URL, decode_responses=True) + ) + + # Create scheduler + scheduler = AsyncIOScheduler() + + async def monitor_teams_job(): + lock_name = "monitor_teams" + lock_acquired = False + + try: + # Try to acquire the lock using a dedicated short-lived session + lock_db = next(get_db()) + try: + lock_acquired = try_acquire_lock(lock_name, lock_db, lock_timeout=10) + finally: + lock_db.close() + + if lock_acquired: + logger.info("Acquired monitor_teams lock, executing job") + job_db = next(get_db()) + try: + await monitor_teams(job_db) + except Exception as e: + job_db.rollback() + logger.error(f"Error in monitor_teams background task: {str(e)}") + finally: + job_db.close() + else: + logger.info("Another process has the monitor_teams lock, skipping execution") + except Exception as e: + logger.error(f"Error in monitor_teams job: {str(e)}") + finally: + # Always release the lock when it was acquired, using a separate session + if lock_acquired: + release_db = next(get_db()) + try: + release_lock(lock_name, release_db) + except Exception as release_error: + logger.error(f"Error releasing lock: {str(release_error)}") + finally: + release_db.close() + + # Set schedule based on environment + if settings.ENV_SUFFIX == "local": + cron_trigger = CronTrigger(minute='*/10', timezone=UTC, jitter=180) + else: + # Run every hour in other environments with jitter + cron_trigger = CronTrigger(hour='*', minute=0, timezone=UTC, jitter=60) + + scheduler.add_job( + monitor_teams_job, + trigger=cron_trigger, + id='monitor_teams', + replace_existing=True + ) + + # Hard delete job for teams that have been soft-deleted for 60+ days + async def hard_delete_teams_job(): + lock_name = "hard_delete_teams" + lock_acquired = False + + try: + # Try to acquire the lock using a dedicated short-lived session + lock_db = next(get_db()) + try: + lock_acquired = try_acquire_lock(lock_name, lock_db, lock_timeout=10) + finally: + lock_db.close() + + if lock_acquired: + logger.info("Acquired hard_delete_teams lock, executing job") + job_db = next(get_db()) + try: + await hard_delete_expired_teams(job_db) + except Exception as e: + job_db.rollback() + logger.error(f"Error in hard_delete_expired_teams background task: {str(e)}") + finally: + job_db.close() + else: + logger.info("Another process has the hard_delete_teams lock, skipping execution") + except Exception as e: + logger.error(f"Error in hard_delete_teams job: {str(e)}") + finally: + # Always release the lock when it was acquired, using a separate session + if lock_acquired: + release_db = next(get_db()) + try: + release_lock(lock_name, release_db) + except Exception as release_error: + logger.error(f"Error releasing lock: {str(release_error)}") + finally: + release_db.close() + + # Set schedule based on environment for hard delete job + if settings.ENV_SUFFIX == "local": + # In local env, run every hour at :30 for testing + hard_delete_trigger = CronTrigger(hour='*', minute=30, timezone=UTC) + else: + # In production, run daily at 3 AM + hard_delete_trigger = CronTrigger(hour=3, minute=0, timezone=UTC) + + scheduler.add_job( + hard_delete_teams_job, + trigger=hard_delete_trigger, + id='hard_delete_teams', + replace_existing=True + ) + + # Start the scheduler + scheduler.start() + yield + # Shutdown the rate limiter + await FastAPILimiter.close() + + # Shutdown the scheduler + scheduler.shutdown() app = FastAPI( title="Private AI Keys as a Service", diff --git a/docker-compose.yml b/docker-compose.yml index 91550efe..e12a3c55 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -17,6 +17,16 @@ services: labels: lagoon.type: postgres + redis: + image: uselagoon/redis-7 + labels: + lagoon.type: redis + healthcheck: + test: [ "CMD", "redis-cli", "ping" ] + interval: 5s + timeout: 5s + retries: 5 + backend: build: context: . @@ -26,6 +36,7 @@ services: AMAZEEAI_JWT_SECRET: ${AMAZEEAI_JWT_SECRET} ENABLE_METRICS: "true" # Enable Prometheus metrics AI_TRIAL_REGION: ${AI_TRIAL_REGION:-eu-west} + REDIS_URL: "redis://redis:6379/0" ports: - "8800:8800" volumes: @@ -33,6 +44,8 @@ services: depends_on: postgres: condition: service_healthy + redis: + condition: service_healthy labels: lagoon.type: python diff --git a/requirements.txt b/requirements.txt index d375bf5f..ed2bcb4b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,3 +20,6 @@ prometheus-fastapi-instrumentator==7.1.0 stripe==14.4.0 six==1.17.0 httpx==0.28.1 +fastapi-limiter==0.1.6 +redis==4.6.0 +apscheduler==3.11.2 diff --git a/tests/conftest.py b/tests/conftest.py index e9dca824..833e701d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,9 +4,11 @@ os.environ["AMAZEEAI_JWT_SECRET"] = "test-secret-key-for-tests" import pytest +from fastapi import HTTPException from fastapi.testclient import TestClient from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker +from starlette.status import HTTP_429_TOO_MANY_REQUESTS from app.main import app from app.db.database import get_db from app.db.models import Base, DBRegion, DBUser, DBTeam, DBProduct @@ -14,6 +16,64 @@ from datetime import datetime, UTC, timedelta from unittest.mock import patch, MagicMock, Mock, AsyncMock +# Track rate limit counts per key for testing +_rate_limit_counts = {} + +# Mock FastAPILimiter and RateLimiter to bypass Redis rate limiting in tests +@pytest.fixture(autouse=True) +def mock_rate_limiting(request): + """Fixture to mock rate limiting for all tests - tracks calls but doesn't block""" + # Check if this is a rate limit test - if so, don't mock the blocking behavior + # Match both 'rate_limit' and 'ratelimit' patterns + nodeid = request.node.nodeid.lower() + is_rate_limit_test = 'rate_limit' in nodeid or 'ratelimit' in nodeid + + _rate_limit_counts.clear() + + # Create a mock Redis that supports async operations and tracks calls + async def mock_evalsha(sha, numkeys, key, times, milliseconds): + # Track how many times each key has been called + if key not in _rate_limit_counts: + _rate_limit_counts[key] = 0 + _rate_limit_counts[key] += 1 + + # For rate limit tests, actually enforce the limit based on the configured 'times' value + # Note: fastapi_limiter passes times as a string, so convert to int for comparison + if is_rate_limit_test and _rate_limit_counts[key] > int(times): + # Return a value that indicates rate limiting (simulating Redis response) + return 1000 # Return positive value indicating wait time in ms + + # Return 0 to indicate not rate limited (allow request) + return 0 + + async def mock_http_callback(request, response, pexpire): + raise HTTPException(status_code=HTTP_429_TOO_MANY_REQUESTS, detail="Too Many Requests") + + mock_redis = AsyncMock() + mock_redis.evalsha = mock_evalsha + + # Patch FastAPILimiter in main module (for lifespan) + with patch('app.main.FastAPILimiter') as mock_main_fl: + mock_main_fl.redis = mock_redis + mock_main_fl.init = AsyncMock() + mock_main_fl.close = AsyncMock() + mock_main_fl.lua_sha = "mock_sha" + mock_main_fl.prefix = "test" + mock_main_fl.identifier = AsyncMock(return_value="test-key") + mock_main_fl.http_callback = mock_http_callback + + # Patch FastAPILimiter in depends module (for RateLimiter class) + with patch('fastapi_limiter.depends.FastAPILimiter') as mock_depends_fl: + mock_depends_fl.redis = mock_redis + mock_depends_fl.init = AsyncMock() + mock_depends_fl.close = AsyncMock() + mock_depends_fl.lua_sha = "mock_sha" + mock_depends_fl.prefix = "test" + mock_depends_fl.identifier = AsyncMock(return_value="test-key") + mock_depends_fl.http_callback = mock_http_callback + + yield _rate_limit_counts + # Get database URL from environment DATABASE_URL = os.getenv( "DATABASE_URL", @@ -363,6 +423,13 @@ def mock_httpx_combined_client(): return mock_client +@pytest.fixture +def mock_client_class(): + """Mock httpx.AsyncClient class for patching""" + with patch('httpx.AsyncClient') as mock_class: + yield mock_class + + # Helper function for soft-deleting teams in tests def soft_delete_team_for_test(db, team: DBTeam, deleted_at: datetime = None): """ diff --git a/tests/test_auth.py b/tests/test_auth.py index 8b37dd4a..d8107384 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -266,6 +266,25 @@ def test_validate_email_invalid_format(client, mock_dynamodb): mock_dynamodb.write_validation_code.assert_not_called() +def test_validate_email_rate_limit(client, mock_dynamodb, mock_ses): + """Test that /auth/validate-email returns HTTP 429 after exceeding the rate limit.""" + from app.core.config import settings + + email = "ratelimit@example.com" + limit = settings.RATE_LIMIT_VALIDATE_EMAIL + + # Requests up to the limit should succeed + for attempt in range(limit): + response = client.post("/auth/validate-email", json={"email": email}) + assert response.status_code == 200, ( + f"Request {attempt + 1} of {limit} should succeed, got {response.status_code}" + ) + + # The next request should be rate limited + response = client.post("/auth/validate-email", json={"email": email}) + assert response.status_code == 429 + + def test_sign_in_success(client, test_user, mock_dynamodb): # First, generate a validation code email = test_user.email @@ -696,7 +715,6 @@ def test_login_cookie_expiration_regular_user(client, test_user): set_cookie_header = response.headers.get("set-cookie", "") assert "Max-Age=1800" in set_cookie_header or "max-age=1800" in set_cookie_header - def test_login_cookie_expiration_system_admin(client, test_admin): """ Given a system administrator diff --git a/tests/test_team_keys_policy.py b/tests/test_team_keys_policy.py index 97cb5e8d..fcb33084 100644 --- a/tests/test_team_keys_policy.py +++ b/tests/test_team_keys_policy.py @@ -1,6 +1,5 @@ from unittest.mock import patch - @patch("httpx.AsyncClient") def test_create_key_force_user_keys_enabled( mock_client_class, @@ -78,7 +77,6 @@ def test_create_key_force_user_keys_enabled( assert found_generate_call - @patch("httpx.AsyncClient") def test_create_token_force_user_keys_enabled( mock_client_class, @@ -122,7 +120,6 @@ def test_create_token_force_user_keys_enabled( assert data["owner_id"] == user_id assert data["team_id"] is None - @patch("httpx.AsyncClient") def test_create_vector_db_force_user_keys_enabled( mock_client_class,