Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class Settings(BaseSettings):
WEBHOOK_SIG: str = os.getenv("WEBHOOK_SIG", "whsec_test_1234567890")
ENABLE_METRICS: bool = os.getenv("ENABLE_METRICS", "false") == "true"
PROMETHEUS_API_KEY: str = os.getenv("PROMETHEUS_API_KEY", "")
RATE_LIMIT_ENABLED: bool = os.getenv("RATE_LIMIT_ENABLED", "true") == "true"

model_config = ConfigDict(env_file=".env", extra="ignore")
main_route: str = os.getenv("LAGOON_ROUTE", "http://localhost:8800")
Expand Down
5 changes: 5 additions & 0 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from app.middleware.caching import CacheControlMiddleware
from app.middleware.prometheus import PrometheusMiddleware
from app.middleware.auth import AuthMiddleware
from app.middleware.rate_limit import RateLimitMiddleware
from app.core.worker import monitor_teams, hard_delete_expired_teams
from app.core.locking import try_acquire_lock, release_lock
from app.__version__ import __version__
Expand Down Expand Up @@ -189,6 +190,10 @@ async def hard_delete_teams_job():
# Add HTTPS redirect middleware first
app.add_middleware(HTTPSRedirectMiddleware)

# Add rate limiting middleware for public endpoints
if settings.RATE_LIMIT_ENABLED:
app.add_middleware(RateLimitMiddleware)

# Add Auth middleware (must be before Prometheus and Audit middleware)
app.add_middleware(AuthMiddleware)

Expand Down
114 changes: 114 additions & 0 deletions app/middleware/rate_limit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import time
import threading
from collections import defaultdict

from fastapi import Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

import logging

logger = logging.getLogger(__name__)

# Rate limit rules for public endpoints: path -> (max_requests, window_seconds)
# These protect against brute force attacks, spam, and resource abuse.
DEFAULT_RATE_LIMITS: dict[str, tuple[int, int]] = {
"/auth/login": (10, 60),
"/auth/register": (5, 60),
"/auth/validate-email": (5, 60),
"/auth/sign-in": (10, 60),
"/auth/generate-trial-access": (5, 60),
}

# How many entries to keep in the timestamp map before triggering cleanup.
# When the map grows beyond this threshold old, exhausted keys are pruned.
_CLEANUP_THRESHOLD = 10_000


class RateLimitMiddleware(BaseHTTPMiddleware):
"""
Rate limiting middleware to protect public endpoints from abuse.

Uses a per-IP sliding window counter. Only requests to the configured
public auth endpoints are subject to rate limiting.

NOTE: The client IP is determined from the ``X-Forwarded-For`` header when
present, falling back to the direct connection address. This assumes the
application is deployed behind a trusted reverse proxy that strips any
client-supplied ``X-Forwarded-For`` headers before forwarding requests.
"""

def __init__(self, app, rate_limits: dict[str, tuple[int, int]] | None = None):
super().__init__(app)
self._rate_limits = rate_limits if rate_limits is not None else DEFAULT_RATE_LIMITS
self._lock = threading.Lock()
# {(ip, path): [timestamp, ...]}
self._request_timestamps: dict[tuple[str, str], list[float]] = defaultdict(list)

def _get_client_ip(self, request: Request) -> str:
"""Return the client IP, honouring X-Forwarded-For when present.

The first IP in the header is used, as that is the original client IP
appended by a trusted upstream proxy. The application must be deployed
behind a reverse proxy that is configured to strip any client-supplied
X-Forwarded-For headers.
"""
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
return forwarded.split(",")[0].strip()
return request.client.host if request.client else "unknown"

def _check_rate_limit(self, ip: str, path: str, limit: int, window: int) -> tuple[bool, int]:
"""
Record the current request and check whether the rate limit is exceeded.

Returns:
(is_limited, retry_after_seconds)
"""
now = time.time()
key = (ip, path)

with self._lock:
# Evict timestamps outside the current window
self._request_timestamps[key] = [
ts for ts in self._request_timestamps[key] if now - ts < window
]

count = len(self._request_timestamps[key])
if count >= limit:
oldest = self._request_timestamps[key][0]
retry_after = int(window - (now - oldest)) + 1
return True, retry_after

# Record this request
self._request_timestamps[key].append(now)

# Periodically remove keys whose timestamp lists are now empty to
# prevent the dictionary from growing without bound.
if len(self._request_timestamps) > _CLEANUP_THRESHOLD:
empty_keys = [k for k, v in self._request_timestamps.items() if not v]
for k in empty_keys:
del self._request_timestamps[k]

return False, 0

async def dispatch(self, request: Request, call_next):
path = request.url.path

if path not in self._rate_limits:
return await call_next(request)

limit, window = self._rate_limits[path]
ip = self._get_client_ip(request)

is_limited, retry_after = self._check_rate_limit(ip, path, limit, window)

if is_limited:
logger.warning(f"Rate limit exceeded for IP {ip} on {path}")
return JSONResponse(
status_code=429,
content={"detail": "Too many requests. Please try again later."},
headers={"Retry-After": str(retry_after)},
)

return await call_next(request)
200 changes: 200 additions & 0 deletions tests/test_rate_limit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
"""
Tests for RateLimitMiddleware.

These tests exercise the middleware directly using the test client, verifying
that public auth endpoints return HTTP 429 after the configured number of
requests and that the Retry-After header is present.
"""
import pytest
from unittest.mock import MagicMock
from fastapi.testclient import TestClient

from app.middleware.rate_limit import RateLimitMiddleware, DEFAULT_RATE_LIMITS


class TestRateLimitMiddlewareDirect:
"""Unit tests for the RateLimitMiddleware helper methods."""

def _make_middleware(self, limits=None):
"""Return a RateLimitMiddleware instance without a real ASGI app."""
mock_app = MagicMock()
return RateLimitMiddleware(mock_app, rate_limits=limits or DEFAULT_RATE_LIMITS)

def test_get_client_ip_from_forwarded_header(self):
"""X-Forwarded-For header should take priority over request.client.host."""
mw = self._make_middleware()
req = MagicMock()
req.headers = {"X-Forwarded-For": "1.2.3.4, 10.0.0.1"}
req.client = MagicMock(host="127.0.0.1")

assert mw._get_client_ip(req) == "1.2.3.4"

def test_get_client_ip_fallback_to_client_host(self):
"""Without X-Forwarded-For, fall back to request.client.host."""
mw = self._make_middleware()
req = MagicMock()
req.headers = {}
req.client = MagicMock(host="5.6.7.8")

assert mw._get_client_ip(req) == "5.6.7.8"

def test_get_client_ip_no_client(self):
"""Return 'unknown' when there is no client information at all."""
mw = self._make_middleware()
req = MagicMock()
req.headers = {}
req.client = None

assert mw._get_client_ip(req) == "unknown"

def test_check_rate_limit_allows_requests_within_limit(self):
"""
Given a limit of 3 requests per 60 s
When 3 requests arrive
Then none of them should be rate-limited.
"""
mw = self._make_middleware()
for _ in range(3):
is_limited, _ = mw._check_rate_limit("1.2.3.4", "/auth/login", 3, 60)
assert not is_limited

def test_check_rate_limit_blocks_on_exceeding_limit(self):
"""
Given a limit of 3 requests per 60 s
When a 4th request arrives
Then it should be rate-limited.
"""
mw = self._make_middleware()
for _ in range(3):
mw._check_rate_limit("1.2.3.4", "/auth/login", 3, 60)

is_limited, retry_after = mw._check_rate_limit("1.2.3.4", "/auth/login", 3, 60)
assert is_limited
assert retry_after > 0

def test_check_rate_limit_different_ips_are_independent(self):
"""Different IPs must have independent counters."""
mw = self._make_middleware()
for _ in range(3):
mw._check_rate_limit("1.1.1.1", "/auth/login", 3, 60)

# Exhausted for 1.1.1.1
is_limited_a, _ = mw._check_rate_limit("1.1.1.1", "/auth/login", 3, 60)
# Not exhausted for 2.2.2.2
is_limited_b, _ = mw._check_rate_limit("2.2.2.2", "/auth/login", 3, 60)

assert is_limited_a
assert not is_limited_b

def test_check_rate_limit_different_paths_are_independent(self):
"""Different paths must have independent counters for the same IP."""
mw = self._make_middleware()
for _ in range(3):
mw._check_rate_limit("1.2.3.4", "/auth/login", 3, 60)

# Exhausted for /auth/login
is_limited_login, _ = mw._check_rate_limit("1.2.3.4", "/auth/login", 3, 60)
# Not exhausted for /auth/register
is_limited_register, _ = mw._check_rate_limit("1.2.3.4", "/auth/register", 3, 60)

assert is_limited_login
assert not is_limited_register


class TestRateLimitMiddlewareIntegration:
"""Integration tests using the real FastAPI TestClient."""

def test_login_endpoint_rate_limited_after_threshold(self):
"""
Given the /auth/login endpoint with a custom low limit
When requests exceed the limit
Then HTTP 429 is returned with a Retry-After header.
"""
low_limit = {"/auth/login": (3, 60)}

# Drive the private helper directly with a fresh middleware instance
# to validate the core logic without needing an active HTTP server.
mw = RateLimitMiddleware(MagicMock(), rate_limits=low_limit)
for _ in range(3):
is_limited, _ = mw._check_rate_limit("127.0.0.1", "/auth/login", 3, 60)
assert not is_limited

is_limited, retry_after = mw._check_rate_limit("127.0.0.1", "/auth/login", 3, 60)
assert is_limited
assert retry_after >= 1

def test_rate_limit_returns_429_and_retry_after(self, client: TestClient):
"""
Given a rate limit of 2 per 60 s on /auth/login
When more than 2 requests come from the same IP
Then the server responds with 429 and a Retry-After header.
"""
from app.main import app as fastapi_app

# Find the RateLimitMiddleware in the Starlette middleware stack
mw_instance = None
stack = fastapi_app.middleware_stack
# Walk the chain
current = stack
while current is not None:
if isinstance(current, RateLimitMiddleware):
mw_instance = current
break
current = getattr(current, "app", None)

if mw_instance is None:
pytest.skip("RateLimitMiddleware not found in middleware stack (may be disabled)")

# Override with a very low limit for this test
original_limits = mw_instance._rate_limits
mw_instance._rate_limits = {"/auth/login": (2, 60)}
# Clear any state from previous tests
mw_instance._request_timestamps.clear()

try:
# The TestClient uses a fixed IP (127.0.0.1 / testclient).
# First 2 requests should pass (status is not 429).
for _ in range(2):
resp = client.post(
"/auth/login",
json={"username": "x@x.com", "password": "wrong"},
)
assert resp.status_code != 429

# Third request must be rate-limited.
resp = client.post(
"/auth/login",
json={"username": "x@x.com", "password": "wrong"},
)
assert resp.status_code == 429
assert "Retry-After" in resp.headers
assert int(resp.headers["Retry-After"]) >= 1
assert "Too many requests" in resp.json()["detail"]
finally:
mw_instance._rate_limits = original_limits
mw_instance._request_timestamps.clear()

def test_unprotected_endpoint_not_rate_limited(self, client: TestClient):
"""
/health is not in the rate-limited path list and must never return 429.
"""
for _ in range(20):
resp = client.get("/health")
assert resp.status_code == 200

def test_default_rate_limits_cover_expected_endpoints(self):
"""DEFAULT_RATE_LIMITS must include all sensitive public auth endpoints."""
expected = {
"/auth/login",
"/auth/register",
"/auth/validate-email",
"/auth/sign-in",
"/auth/generate-trial-access",
}
assert expected.issubset(set(DEFAULT_RATE_LIMITS.keys()))

def test_rate_limit_config_has_positive_values(self):
"""Every entry in DEFAULT_RATE_LIMITS must have positive limit and window."""
for path, (limit, window) in DEFAULT_RATE_LIMITS.items():
assert limit > 0, f"limit for {path} must be positive"
assert window > 0, f"window for {path} must be positive"
Loading