Skip to content
47 changes: 44 additions & 3 deletions app/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
import time
import uuid
from datetime import datetime
from datetime import datetime, timedelta, UTC
import email_validator

from typing import Optional, List, Union
Expand Down Expand Up @@ -36,7 +36,7 @@
from app.core.worker import generate_pricing_url

from app.db.database import get_db
from app.db.models import DBUser, DBAPIToken, DBRegion, DBTeam
from app.db.models import DBUser, DBAPIToken, DBRegion, DBTeam, DBAPITokenExpiryOption

from app.services.litellm import LiteLLMService
from app.services.dynamodb import DynamoDBService
Expand All @@ -50,6 +50,7 @@
APIToken,
APITokenCreate,
APITokenResponse,
APITokenExpiryOption,
UserUpdate,
EmailValidation,
LoginData,
Expand Down Expand Up @@ -519,6 +520,20 @@ def generate_api_token() -> str:
return secrets.token_urlsafe(32)


@router.get("/token/expiry-options", response_model=List[APITokenExpiryOption])
async def list_expiry_options(
Comment thread
dan2k3k4 marked this conversation as resolved.
current_user=Depends(get_current_user_from_auth),
db: Session = Depends(get_db),
):
"""List available API token expiry options"""
return (
db.query(DBAPITokenExpiryOption)
.filter(DBAPITokenExpiryOption.is_active)
.order_by(DBAPITokenExpiryOption.id)
.all()
)


@router.post("/token", response_model=APIToken)
async def create_token(
token_create: APITokenCreate,
Expand Down Expand Up @@ -552,8 +567,34 @@ async def create_token(
# Create token for the current user
user_id = current_user.id

# Fetch expiry option from DB (only active options are valid)
expiry_slug = token_create.expiry or "forever"
db_expiry_opt = (
db.query(DBAPITokenExpiryOption)
.filter(
DBAPITokenExpiryOption.slug == expiry_slug,
DBAPITokenExpiryOption.is_active,
)
.first()
)

if not db_expiry_opt:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid or inactive expiry option: {expiry_slug}",
)

# Calculate expiration date
expires_at = None
if db_expiry_opt.days is not None:
expires_at = datetime.now(UTC) + timedelta(days=db_expiry_opt.days)

db_token = DBAPIToken(
name=token_create.name, token=generate_api_token(), user_id=user_id
name=token_create.name,
token=generate_api_token(),
user_id=user_id,
expires_at=expires_at,
expiry_option=db_expiry_opt.slug,
)
db.add(db_token)
db.commit()
Expand Down
4 changes: 3 additions & 1 deletion app/api/public.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ def _to_display_name(model_id: str, aliases: list[str] | None = None) -> str:
# Replace hyphenated number sequences with dotted versions before splitting.
# Use word-boundary anchors so that e.g. "3-5" does not corrupt "123-5".
modified_id = model_id
for hyphenated, dotted in sorted(dot_replacements.items(), key=lambda x: -len(x[0])):
for hyphenated, dotted in sorted(
dot_replacements.items(), key=lambda x: -len(x[0])
):
modified_id = re.sub(
r"(?<![0-9])" + re.escape(hyphenated) + r"(?![0-9])",
dotted,
Expand Down
30 changes: 12 additions & 18 deletions app/api/subscription.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import logging
from datetime import UTC, datetime, timedelta

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session

from app.core.config import settings
from app.core.security import get_role_min_system_admin
from app.core.team_service import get_team_region_litellm_keys
from app.core.worker import (
_record_periodic_payment_direct,
Expand All @@ -32,17 +31,6 @@

router = APIRouter()
logger = logging.getLogger(__name__)
security = HTTPBearer()


def _verify_moad_api_key(
credentials: HTTPAuthorizationCredentials = Depends(security),
):
if credentials.credentials != settings.MOAD_API_KEY:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
)
return credentials.credentials


def _write_audit_log(
Expand Down Expand Up @@ -73,11 +61,14 @@ def _write_audit_log(
logger.warning("Failed to rollback audit log transaction: %s", rollback_exc)


@router.post("/cycle", response_model=SubscriptionCycleResponse)
@router.post(
"/cycle",
response_model=SubscriptionCycleResponse,
dependencies=[Depends(get_role_min_system_admin)],
)
async def subscription_cycle(
request: SubscriptionCycleRequest,
db: Session = Depends(get_db),
_: str = Depends(_verify_moad_api_key),
):
logger.info("subscription.cycle called: %s", request.model_dump())

Expand Down Expand Up @@ -232,11 +223,14 @@ async def subscription_cycle(
raise HTTPException(status_code=500, detail=f"Subscription cycle failed: {exc}")


@router.post("/deactivate", response_model=SubscriptionDeactivateResponse)
@router.post(
"/deactivate",
response_model=SubscriptionDeactivateResponse,
dependencies=[Depends(get_role_min_system_admin)],
)
async def subscription_deactivate(
request: SubscriptionDeactivateRequest,
db: Session = Depends(get_db),
_: str = Depends(_verify_moad_api_key),
):
logger.info("subscription.deactivate called: %s", request.model_dump())

Expand Down
12 changes: 11 additions & 1 deletion app/core/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,14 @@ async def get_current_user_from_auth(
.first()
)
if db_token:
# Check if token is expired
if db_token.expires_at and db_token.expires_at < datetime.now(UTC):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="API token has expired",
headers={"WWW-Authenticate": "Bearer"},
)

# Update last used timestamp
db_token.last_used_at = datetime.now(UTC)
_check_user_team_not_suspended(db_token.owner)
Expand All @@ -150,7 +158,9 @@ async def get_current_user_from_auth(
except HTTPException:
raise
except Exception:
pass
logger.exception(
"Unexpected error during API token validation; falling back to JWT validation"
)

# If API token validation fails, try JWT validation
try:
Expand Down
60 changes: 56 additions & 4 deletions app/db/init_db.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,63 @@
from app.db.models import Base
from app.db.database import engine
import logging

from app.db.models import Base, DBAPITokenExpiryOption
from app.db.database import engine, SessionLocal

logger = logging.getLogger(__name__)


def init_api_token_expiry_options(db=None):
logger.info("Initializing API token expiry options...")
own_db = False
if db is None:
db = SessionLocal()
own_db = True

try:
options = [
{"name": "1 day", "slug": "1_day", "days": 1},
{"name": "1 week", "slug": "1_week", "days": 7},
{"name": "1 month", "slug": "1_month", "days": 30},
{"name": "2 months", "slug": "2_months", "days": 60},
{"name": "3 months", "slug": "3_months", "days": 90},
{"name": "4 months", "slug": "4_months", "days": 120},
{"name": "5 months", "slug": "5_months", "days": 150},
{"name": "6 months", "slug": "6_months", "days": 180},
{"name": "7 months", "slug": "7_months", "days": 210},
{"name": "8 months", "slug": "8_months", "days": 240},
{"name": "9 months", "slug": "9_months", "days": 270},
{"name": "10 months", "slug": "10_months", "days": 300},
{"name": "11 months", "slug": "11_months", "days": 330},
{"name": "1 year", "slug": "1_year", "days": 365},
{"name": "forever", "slug": "forever", "days": None},
]

for opt_data in options:
existing = (
db.query(DBAPITokenExpiryOption)
.filter(DBAPITokenExpiryOption.slug == opt_data["slug"])
.first()
)
if not existing:
db_opt = DBAPITokenExpiryOption(**opt_data)
db.add(db_opt)

db.commit()
logger.info("API token expiry options initialized successfully!")
except Exception:
logger.exception("Error initializing API token expiry options")
db.rollback()
raise
finally:
if own_db:
db.close()


def init_db():
print("Creating database tables...")
logger.info("Creating database tables...")
Base.metadata.create_all(bind=engine)
print("Database tables created successfully!")
logger.info("Database tables created successfully!")
init_api_token_expiry_options()


if __name__ == "__main__":
Expand Down
12 changes: 12 additions & 0 deletions app/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,16 @@ class DBRegion(Base):
admin_users = relationship("DBUserAdminRegion", back_populates="region")


class DBAPITokenExpiryOption(Base):
__tablename__ = "api_token_expiry_options"

id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False)
slug = Column(String, unique=True, index=True, nullable=False)
days = Column(Integer, nullable=True) # None for forever
is_active = Column(Boolean, nullable=False, default=True)


class DBAPIToken(Base):
__tablename__ = "api_tokens"

Expand All @@ -132,6 +142,8 @@ class DBAPIToken(Base):
token = Column(String, unique=True, index=True)
created_at = Column(DateTime(timezone=True), default=func.now())
last_used_at = Column(DateTime(timezone=True), nullable=True)
expires_at = Column(DateTime(timezone=True), nullable=True)
expiry_option = Column(String, default="forever", nullable=False)
user_id = Column(Integer, ForeignKey("users.id"))

owner = relationship("DBUser", back_populates="api_tokens")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""add_api_token_expiry_options

Revision ID: daf5bf0b03c2
Revises: a1b2c3d4e5f6
Create Date: 2026-04-08 12:59:46.600703+00:00

"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "daf5bf0b03c2"
Comment thread
dan2k3k4 marked this conversation as resolved.
Dismissed
down_revision: Union[str, None] = "a1b2c3d4e5f6"
Comment thread
dan2k3k4 marked this conversation as resolved.
Dismissed
branch_labels: Union[str, Sequence[str], None] = None
Comment thread
dan2k3k4 marked this conversation as resolved.
Dismissed
depends_on: Union[str, Sequence[str], None] = None
Comment thread
dan2k3k4 marked this conversation as resolved.
Dismissed


def upgrade() -> None:
# Create api_token_expiry_options table
op.create_table(
"api_token_expiry_options",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("slug", sa.String(), nullable=False),
sa.Column("days", sa.Integer(), nullable=True),
sa.Column("is_active", sa.Boolean(), nullable=False, server_default=sa.true()),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
Comment thread
dan2k3k4 marked this conversation as resolved.
op.f("ix_api_token_expiry_options_id"),
"api_token_expiry_options",
["id"],
unique=False,
)
op.create_index(
op.f("ix_api_token_expiry_options_slug"),
"api_token_expiry_options",
["slug"],
unique=True,
)

# Seed default expiry options so the API is usable immediately after migration
expiry_options_table = sa.table(
"api_token_expiry_options",
sa.column("name", sa.String()),
sa.column("slug", sa.String()),
sa.column("days", sa.Integer()),
sa.column("is_active", sa.Boolean()),
)
op.bulk_insert(
expiry_options_table,
[
{"name": "1 day", "slug": "1_day", "days": 1, "is_active": True},
{"name": "1 week", "slug": "1_week", "days": 7, "is_active": True},
{"name": "1 month", "slug": "1_month", "days": 30, "is_active": True},
{"name": "2 months", "slug": "2_months", "days": 60, "is_active": True},
{"name": "3 months", "slug": "3_months", "days": 90, "is_active": True},
{"name": "4 months", "slug": "4_months", "days": 120, "is_active": True},
{"name": "5 months", "slug": "5_months", "days": 150, "is_active": True},
{"name": "6 months", "slug": "6_months", "days": 180, "is_active": True},
{"name": "7 months", "slug": "7_months", "days": 210, "is_active": True},
{"name": "8 months", "slug": "8_months", "days": 240, "is_active": True},
{"name": "9 months", "slug": "9_months", "days": 270, "is_active": True},
{"name": "10 months", "slug": "10_months", "days": 300, "is_active": True},
{"name": "11 months", "slug": "11_months", "days": 330, "is_active": True},
{"name": "1 year", "slug": "1_year", "days": 365, "is_active": True},
{"name": "forever", "slug": "forever", "days": None, "is_active": True},
],
)

# Update api_tokens table - these might already exist in some environments but let's ensure they are there
# Check if columns exist first to be safe
conn = op.get_bind()
inspector = sa.inspect(conn)
columns = [c["name"] for c in inspector.get_columns("api_tokens")]

if "expires_at" not in columns:
op.add_column(
"api_tokens",
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=True),
)
if "expiry_option" not in columns:
op.add_column(
"api_tokens",
sa.Column(
"expiry_option", sa.String(), nullable=False, server_default="forever"
),
)


def downgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
columns = [c["name"] for c in inspector.get_columns("api_tokens")]

if "expiry_option" in columns:
op.drop_column("api_tokens", "expiry_option")
if "expires_at" in columns:
op.drop_column("api_tokens", "expires_at")
op.drop_index(
op.f("ix_api_token_expiry_options_slug"), table_name="api_token_expiry_options"
)
op.drop_index(
op.f("ix_api_token_expiry_options_id"), table_name="api_token_expiry_options"
)
op.drop_table("api_token_expiry_options")
Loading