diff --git a/app/core/config.py b/app/core/config.py
index 4c2319b4..52fe693e 100644
--- a/app/core/config.py
+++ b/app/core/config.py
@@ -40,6 +40,7 @@ class Settings(BaseSettings):
     WEBHOOK_SIG: str = os.getenv("WEBHOOK_SIG", "whsec_test_1234567890")
     ENABLE_METRICS: bool = os.getenv("ENABLE_METRICS", "false") == "true"
     PROMETHEUS_API_KEY: str = os.getenv("PROMETHEUS_API_KEY", "")
+    RATE_LIMIT_ENABLED: bool = os.getenv("RATE_LIMIT_ENABLED", "true") == "true"
 
     model_config = ConfigDict(env_file=".env", extra="ignore")
     main_route: str = os.getenv("LAGOON_ROUTE", "http://localhost:8800")
diff --git a/app/main.py b/app/main.py
index b3679997..490776a9 100644
--- a/app/main.py
+++ b/app/main.py
@@ -12,6 +12,7 @@
 from app.middleware.caching import CacheControlMiddleware
 from app.middleware.prometheus import PrometheusMiddleware
 from app.middleware.auth import AuthMiddleware
+from app.middleware.rate_limit import RateLimitMiddleware
 from app.core.worker import monitor_teams, hard_delete_expired_teams
 from app.core.locking import try_acquire_lock, release_lock
 from app.__version__ import __version__
@@ -189,6 +190,10 @@ async def hard_delete_teams_job():
 # Add HTTPS redirect middleware first
 app.add_middleware(HTTPSRedirectMiddleware)
 
+# Add rate limiting middleware for public endpoints
+if settings.RATE_LIMIT_ENABLED:
+    app.add_middleware(RateLimitMiddleware)
+
 # Add Auth middleware (must be before Prometheus and Audit middleware)
 app.add_middleware(AuthMiddleware)
 
diff --git a/app/middleware/rate_limit.py b/app/middleware/rate_limit.py
new file mode 100644
index 00000000..4df3089d
--- /dev/null
+++ b/app/middleware/rate_limit.py
@@ -0,0 +1,137 @@
+import logging
+import threading
+import time
+from collections import defaultdict
+
+from fastapi import Request
+from fastapi.responses import JSONResponse
+from starlette.middleware.base import BaseHTTPMiddleware
+
+logger = logging.getLogger(__name__)
+
+# Rate limit rules for public endpoints: path -> (max_requests, window_seconds)
+# These protect against brute force attacks, spam, and resource abuse.
+DEFAULT_RATE_LIMITS: dict[str, tuple[int, int]] = {
+    "/auth/login": (10, 60),
+    "/auth/register": (5, 60),
+    "/auth/validate-email": (5, 60),
+    "/auth/sign-in": (10, 60),
+    "/auth/generate-trial-access": (5, 60),
+}
+
+# How many keys to keep in the timestamp map before triggering cleanup.
+# When the map grows beyond this threshold, keys with no recent requests
+# are pruned.
+_CLEANUP_THRESHOLD = 10_000
+
+
+class RateLimitMiddleware(BaseHTTPMiddleware):
+    """
+    Rate limiting middleware to protect public endpoints from abuse.
+
+    Uses a per-IP sliding window counter. Only requests to the configured
+    public auth endpoints are subject to rate limiting.
+
+    State is kept in process memory, so each worker process enforces its own
+    independent counters.
+
+    NOTE: The client IP is determined from the ``X-Forwarded-For`` header when
+    present, falling back to the direct connection address. This assumes the
+    application is deployed behind a trusted reverse proxy that strips any
+    client-supplied ``X-Forwarded-For`` headers before forwarding requests.
+    """
+
+    def __init__(self, app, rate_limits: dict[str, tuple[int, int]] | None = None):
+        super().__init__(app)
+        self._rate_limits = rate_limits if rate_limits is not None else DEFAULT_RATE_LIMITS
+        self._lock = threading.Lock()
+        # {(ip, path): [timestamp, ...]}, timestamps in ascending order
+        self._request_timestamps: dict[tuple[str, str], list[float]] = defaultdict(list)
+
+    def _get_client_ip(self, request: Request) -> str:
+        """Return the client IP, honouring X-Forwarded-For when present.
+
+        The first IP in the header is used, as that is the original client IP
+        appended by a trusted upstream proxy. The application must be deployed
+        behind a reverse proxy that is configured to strip any client-supplied
+        X-Forwarded-For headers.
+        """
+        forwarded = request.headers.get("X-Forwarded-For")
+        if forwarded:
+            first_hop = forwarded.split(",")[0].strip()
+            # A blank first entry (e.g. ", 10.0.0.1") must not become a shared key.
+            if first_hop:
+                return first_hop
+        return request.client.host if request.client else "unknown"
+
+    def _prune_stale_keys(self, now: float, window: int) -> None:
+        """Drop keys whose newest timestamp is older than every known window.
+
+        Must be called with ``self._lock`` held.
+        """
+        horizon = max([window, *(w for _, w in self._rate_limits.values())])
+        stale_keys = [
+            k for k, stamps in self._request_timestamps.items()
+            if not stamps or now - stamps[-1] >= horizon
+        ]
+        for k in stale_keys:
+            del self._request_timestamps[k]
+
+    def _check_rate_limit(self, ip: str, path: str, limit: int, window: int) -> tuple[bool, int]:
+        """
+        Record the current request and check whether the rate limit is exceeded.
+
+        Requests that are rejected are not recorded, so they do not extend the
+        time a client stays blocked.
+
+        Returns:
+            (is_limited, retry_after_seconds)
+        """
+        key = (ip, path)
+
+        with self._lock:
+            # Monotonic clock: immune to wall-clock adjustments. Read under the
+            # lock so timestamps are appended in ascending order.
+            now = time.monotonic()
+
+            # Evict timestamps outside the current window
+            recent = [ts for ts in self._request_timestamps[key] if now - ts < window]
+            self._request_timestamps[key] = recent
+
+            if len(recent) >= limit:
+                # ``recent`` is only empty here when limit <= 0.
+                oldest = recent[0] if recent else now
+                retry_after = int(window - (now - oldest)) + 1
+                return True, retry_after
+
+            # Record this request
+            recent.append(now)
+
+            # Once the map is large, drop keys that have gone quiet. Lists are
+            # only trimmed when their own key is hit again, so idle keys must be
+            # checked against the window here or they would never be removed.
+            if len(self._request_timestamps) > _CLEANUP_THRESHOLD:
+                self._prune_stale_keys(now, window)
+
+        return False, 0
+
+    async def dispatch(self, request: Request, call_next):
+        path = request.url.path
+
+        if path not in self._rate_limits:
+            return await call_next(request)
+
+        limit, window = self._rate_limits[path]
+        ip = self._get_client_ip(request)
+
+        is_limited, retry_after = self._check_rate_limit(ip, path, limit, window)
+
+        if is_limited:
+            logger.warning("Rate limit exceeded for IP %s on %s", ip, path)
+            return JSONResponse(
+                status_code=429,
+                content={"detail": "Too many requests. Please try again later."},
+                headers={"Retry-After": str(retry_after)},
+            )
+
+        return await call_next(request)
diff --git a/tests/test_rate_limit.py b/tests/test_rate_limit.py
new file mode 100644
index 00000000..7665c081
--- /dev/null
+++ b/tests/test_rate_limit.py
@@ -0,0 +1,214 @@
+"""
+Tests for RateLimitMiddleware.
+
+These tests exercise the middleware directly using the test client, verifying
+that public auth endpoints return HTTP 429 after the configured number of
+requests and that the Retry-After header is present.
+"""
+import time
+import pytest
+from unittest.mock import MagicMock
+from fastapi.testclient import TestClient
+
+from app.middleware.rate_limit import RateLimitMiddleware, DEFAULT_RATE_LIMITS
+
+
+class TestRateLimitMiddlewareDirect:
+    """Unit tests for the RateLimitMiddleware helper methods."""
+
+    def _make_middleware(self, limits=None):
+        """Return a RateLimitMiddleware instance without a real ASGI app."""
+        mock_app = MagicMock()
+        return RateLimitMiddleware(mock_app, rate_limits=limits or DEFAULT_RATE_LIMITS)
+
+    def test_get_client_ip_from_forwarded_header(self):
+        """X-Forwarded-For header should take priority over request.client.host."""
+        mw = self._make_middleware()
+        req = MagicMock()
+        req.headers = {"X-Forwarded-For": "1.2.3.4, 10.0.0.1"}
+        req.client = MagicMock(host="127.0.0.1")
+
+        assert mw._get_client_ip(req) == "1.2.3.4"
+
+    def test_get_client_ip_fallback_to_client_host(self):
+        """Without X-Forwarded-For, fall back to request.client.host."""
+        mw = self._make_middleware()
+        req = MagicMock()
+        req.headers = {}
+        req.client = MagicMock(host="5.6.7.8")
+
+        assert mw._get_client_ip(req) == "5.6.7.8"
+
+    def test_get_client_ip_no_client(self):
+        """Return 'unknown' when there is no client information at all."""
+        mw = self._make_middleware()
+        req = MagicMock()
+        req.headers = {}
+        req.client = None
+
+        assert mw._get_client_ip(req) == "unknown"
+
+    def test_check_rate_limit_allows_requests_within_limit(self):
+        """
+        Given a limit of 3 requests per 60 s
+        When 3 requests arrive
+        Then none of them should be rate-limited.
+        """
+        mw = self._make_middleware()
+        for _ in range(3):
+            is_limited, _ = mw._check_rate_limit("1.2.3.4", "/auth/login", 3, 60)
+            assert not is_limited
+
+    def test_check_rate_limit_blocks_on_exceeding_limit(self):
+        """
+        Given a limit of 3 requests per 60 s
+        When a 4th request arrives
+        Then it should be rate-limited.
+        """
+        mw = self._make_middleware()
+        for _ in range(3):
+            mw._check_rate_limit("1.2.3.4", "/auth/login", 3, 60)
+
+        is_limited, retry_after = mw._check_rate_limit("1.2.3.4", "/auth/login", 3, 60)
+        assert is_limited
+        assert retry_after > 0
+
+    def test_check_rate_limit_different_ips_are_independent(self):
+        """Different IPs must have independent counters."""
+        mw = self._make_middleware()
+        for _ in range(3):
+            mw._check_rate_limit("1.1.1.1", "/auth/login", 3, 60)
+
+        # Exhausted for 1.1.1.1
+        is_limited_a, _ = mw._check_rate_limit("1.1.1.1", "/auth/login", 3, 60)
+        # Not exhausted for 2.2.2.2
+        is_limited_b, _ = mw._check_rate_limit("2.2.2.2", "/auth/login", 3, 60)
+
+        assert is_limited_a
+        assert not is_limited_b
+
+    def test_check_rate_limit_different_paths_are_independent(self):
+        """Different paths must have independent counters for the same IP."""
+        mw = self._make_middleware()
+        for _ in range(3):
+            mw._check_rate_limit("1.2.3.4", "/auth/login", 3, 60)
+
+        # Exhausted for /auth/login
+        is_limited_login, _ = mw._check_rate_limit("1.2.3.4", "/auth/login", 3, 60)
+        # Not exhausted for /auth/register
+        is_limited_register, _ = mw._check_rate_limit("1.2.3.4", "/auth/register", 3, 60)
+
+        assert is_limited_login
+        assert not is_limited_register
+
+    def test_cleanup_prunes_keys_with_expired_timestamps(self, monkeypatch):
+        """Keys whose requests have all expired must be dropped during cleanup."""
+        monkeypatch.setattr("app.middleware.rate_limit._CLEANUP_THRESHOLD", 1)
+        mw = self._make_middleware()
+        stale_key = ("9.9.9.9", "/auth/login")
+        # Seed a key whose only request is far older than any configured window.
+        mw._request_timestamps[stale_key] = [time.monotonic() - 3600]
+
+        mw._check_rate_limit("1.2.3.4", "/auth/login", 3, 60)
+
+        assert stale_key not in mw._request_timestamps
+        assert ("1.2.3.4", "/auth/login") in mw._request_timestamps
+
+
+class TestRateLimitMiddlewareIntegration:
+    """Integration tests using the real FastAPI TestClient."""
+
+    def test_login_endpoint_rate_limited_after_threshold(self):
+        """
+        Given the /auth/login endpoint with a custom low limit
+        When requests exceed the limit
+        Then HTTP 429 is returned with a Retry-After header.
+        """
+        low_limit = {"/auth/login": (3, 60)}
+
+        # Drive the private helper directly with a fresh middleware instance
+        # to validate the core logic without needing an active HTTP server.
+        mw = RateLimitMiddleware(MagicMock(), rate_limits=low_limit)
+        for _ in range(3):
+            is_limited, _ = mw._check_rate_limit("127.0.0.1", "/auth/login", 3, 60)
+            assert not is_limited
+
+        is_limited, retry_after = mw._check_rate_limit("127.0.0.1", "/auth/login", 3, 60)
+        assert is_limited
+        assert retry_after >= 1
+
+    def test_rate_limit_returns_429_and_retry_after(self, client: TestClient):
+        """
+        Given a rate limit of 2 per 60 s on /auth/login
+        When more than 2 requests come from the same IP
+        Then the server responds with 429 and a Retry-After header.
+        """
+        from app.main import app as fastapi_app
+
+        # Find the RateLimitMiddleware in the Starlette middleware stack
+        mw_instance = None
+        stack = fastapi_app.middleware_stack
+        # Walk the chain
+        current = stack
+        while current is not None:
+            if isinstance(current, RateLimitMiddleware):
+                mw_instance = current
+                break
+            current = getattr(current, "app", None)
+
+        if mw_instance is None:
+            pytest.skip("RateLimitMiddleware not found in middleware stack (may be disabled)")
+
+        # Override with a very low limit for this test
+        original_limits = mw_instance._rate_limits
+        mw_instance._rate_limits = {"/auth/login": (2, 60)}
+        # Clear any state from previous tests
+        mw_instance._request_timestamps.clear()
+
+        try:
+            # The TestClient uses a fixed IP (127.0.0.1 / testclient).
+            # First 2 requests should pass (status is not 429).
+            for _ in range(2):
+                resp = client.post(
+                    "/auth/login",
+                    json={"username": "x@x.com", "password": "wrong"},
+                )
+                assert resp.status_code != 429
+
+            # Third request must be rate-limited.
+            resp = client.post(
+                "/auth/login",
+                json={"username": "x@x.com", "password": "wrong"},
+            )
+            assert resp.status_code == 429
+            assert "Retry-After" in resp.headers
+            assert int(resp.headers["Retry-After"]) >= 1
+            assert "Too many requests" in resp.json()["detail"]
+        finally:
+            mw_instance._rate_limits = original_limits
+            mw_instance._request_timestamps.clear()
+
+    def test_unprotected_endpoint_not_rate_limited(self, client: TestClient):
+        """
+        /health is not in the rate-limited path list and must never return 429.
+        """
+        for _ in range(20):
+            resp = client.get("/health")
+            assert resp.status_code == 200
+
+    def test_default_rate_limits_cover_expected_endpoints(self):
+        """DEFAULT_RATE_LIMITS must include all sensitive public auth endpoints."""
+        expected = {
+            "/auth/login",
+            "/auth/register",
+            "/auth/validate-email",
+            "/auth/sign-in",
+            "/auth/generate-trial-access",
+        }
+        assert expected.issubset(set(DEFAULT_RATE_LIMITS.keys()))
+
+    def test_rate_limit_config_has_positive_values(self):
+        """Every entry in DEFAULT_RATE_LIMITS must have positive limit and window."""
+        for path, (limit, window) in DEFAULT_RATE_LIMITS.items():
+            assert limit > 0, f"limit for {path} must be positive"
+            assert window > 0, f"window for {path} must be positive"