From 694f6398f642c9c26b88d0cbdaa3dd8e1c4a4469 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Tue, 26 May 2026 13:30:12 +0530 Subject: [PATCH 1/2] feat(model-config): migrate completion_type from scalar to array and update related endpoints - Added Alembic migration to change completion_type from scalar to array in model_config table. - Updated API documentation for creating, updating, and deleting model configurations to reflect changes in completion_type. - Implemented bulk create and update functionality for model configurations. - Enhanced model configuration routes to support new array-based completion_type and added validation for input modalities. - Added tests for creating, updating, and deleting model configurations with new completion_type structure. --- .../versions/064_completion_type_array.py | 49 ++ .../api/docs/model_config/create_models.md | 53 +++ .../app/api/docs/model_config/delete_model.md | 7 + .../app/api/docs/model_config/get_model.md | 68 +-- .../app/api/docs/model_config/list_models.md | 75 +-- .../docs/model_config/list_models_grouped.md | 53 +-- .../api/docs/model_config/list_providers.md | 17 +- .../app/api/docs/model_config/update_model.md | 21 + .../api/docs/model_config/update_models.md | 19 + backend/app/api/routes/model_config.py | 68 ++- backend/app/crud/model_config.py | 89 +++- backend/app/models/__init__.py | 3 + backend/app/models/llm/constants.py | 23 + backend/app/models/llm/request.py | 8 +- backend/app/models/model_config.py | 80 +++- .../app/tests/api/routes/test_model_config.py | 437 ++++++++++++++++++ 16 files changed, 839 insertions(+), 231 deletions(-) create mode 100644 backend/app/alembic/versions/064_completion_type_array.py create mode 100644 backend/app/api/docs/model_config/create_models.md create mode 100644 backend/app/api/docs/model_config/delete_model.md create mode 100644 backend/app/api/docs/model_config/update_model.md create mode 100644 
backend/app/api/docs/model_config/update_models.md diff --git a/backend/app/alembic/versions/064_completion_type_array.py b/backend/app/alembic/versions/064_completion_type_array.py new file mode 100644 index 000000000..72af12ade --- /dev/null +++ b/backend/app/alembic/versions/064_completion_type_array.py @@ -0,0 +1,49 @@ +"""completion_type scalar to array + +Revision ID: 064 +Revises: 063 +Create Date: 2026-05-26 00:00:00.000000 + +""" + +from alembic import op + +revision = "064" +down_revision = "063" +branch_labels = None +depends_on = None + + +def upgrade(): + # Drop scalar index before altering column + op.execute("DROP INDEX IF EXISTS global.ix_model_config_provider_type_active") + + # Convert scalar enum to array, wrapping existing values + op.execute( + """ + ALTER TABLE global.model_config + ALTER COLUMN completion_type TYPE global.completion_type_enum[] + USING ARRAY[completion_type] + """ + ) + + # GIN index for efficient array containment queries + op.execute( + "CREATE INDEX ix_model_config_completion_type ON global.model_config USING gin (completion_type)" + ) + + +def downgrade(): + op.execute("DROP INDEX IF EXISTS global.ix_model_config_completion_type") + + op.execute( + """ + ALTER TABLE global.model_config + ALTER COLUMN completion_type TYPE global.completion_type_enum + USING completion_type[1] + """ + ) + + op.execute( + "CREATE INDEX ix_model_config_provider_type_active ON global.model_config (provider, completion_type, is_active)" + ) diff --git a/backend/app/api/docs/model_config/create_models.md b/backend/app/api/docs/model_config/create_models.md new file mode 100644 index 000000000..2b5f4aca5 --- /dev/null +++ b/backend/app/api/docs/model_config/create_models.md @@ -0,0 +1,53 @@ +Create one or more model configurations. + +Accepts a single object or an array. Response is always an array. 
+ +**Required:** `provider`, `model_name`, `completion_type`, `config` +**Optional:** `input_modalities`, `output_modalities`, `pricing`, `is_active` + +`(provider, model_name)` must be unique. + +### Example (single) + +```json +{ + "provider": "google", + "model_name": "gemini-2.5-flash", + "completion_type": ["text", "stt"], + "config": { "temperature": { "type": "float", "default": 1.0, "min": 0.0, "max": 2.0 } }, + "input_modalities": ["TEXT", "AUDIO"], + "output_modalities": ["TEXT"], + "pricing": { + "response": { "input_token_cost": 0.3, "output_token_cost": 2.5 } + }, + "is_active": true +} +``` + +### Example (multiple) + +```json +[ + { + "provider": "sarvamai", + "model_name": "saaras:v3", + "completion_type": ["stt"], + "config": {}, + "input_modalities": ["AUDIO"], + "output_modalities": ["TEXT"] + }, + { + "provider": "elevenlabs", + "model_name": "scribe_v2", + "completion_type": ["stt"], + "config": {}, + "input_modalities": ["AUDIO"], + "output_modalities": ["TEXT"] + } +] +``` + +### Errors + +- `422` — Validation error. +- DB integrity error on duplicate `(provider, model_name)`. diff --git a/backend/app/api/docs/model_config/delete_model.md b/backend/app/api/docs/model_config/delete_model.md new file mode 100644 index 000000000..a3bd8bae1 --- /dev/null +++ b/backend/app/api/docs/model_config/delete_model.md @@ -0,0 +1,7 @@ +Permanently delete a model configuration. Hard delete — cannot be undone. + +To hide a model instead, use `PATCH` with `is_active: false`. + +### Errors + +- `404` — Model not found. diff --git a/backend/app/api/docs/model_config/get_model.md b/backend/app/api/docs/model_config/get_model.md index b2f3700cb..718035633 100644 --- a/backend/app/api/docs/model_config/get_model.md +++ b/backend/app/api/docs/model_config/get_model.md @@ -1,67 +1,5 @@ -## Endpoint +Get a single model configuration by `provider` and `model_name`. 
-**GET** `/api/v1/models/{provider}/{model_name}` +### Errors -Retrieve a specific model configuration by provider and model name. - -Returns model details including supported config parameters, input/output modalities, pricing, and active status. - -### Path Parameters - -- **`provider`** (required) — Provider name (e.g. `openai`, `google`) -- **`model_name`** (required) — Model name (e.g. `gpt-4o`, `gpt-4o-mini`) - -### Example Response - -```json -{ - "success": true, - "data": { - "id": 2, - "provider": "openai", - "model_name": "gpt-4o", - "config": { - "temperature": { - "type": "float", - "default": 1.0, - "min": 0.0, - "max": 2.0, - "description": "Controls randomness. Lower = more deterministic." - }, - "top_p": { - "type": "float", - "default": 1.0, - "min": 0.0, - "max": 1.0, - "description": "Nucleus sampling. Use either this or temperature, not both." - }, - "max_output_tokens": { - "type": "int", - "default": 2048, - "min": 1, - "max": 32768, - "description": "Max tokens in the response." - } - }, - "input_modalities": ["TEXT", "IMAGE"], - "output_modalities": ["TEXT"], - "pricing": { - "response": { - "input_token_cost": 2.5, - "output_token_cost": 10 - }, - "batch": { - "input_token_cost": 1.25, - "output_token_cost": 5 - } - }, - "is_active": true, - "inserted_at": "2026-03-12T00:00:00", - "updated_at": "2026-03-12T00:00:00" - } -} -``` - -### Error Response - -- `404 Not Found` — Model not found for the given `provider` and `model_name`. +- `404` — Model not found. diff --git a/backend/app/api/docs/model_config/list_models.md b/backend/app/api/docs/model_config/list_models.md index 321a3d673..522330275 100644 --- a/backend/app/api/docs/model_config/list_models.md +++ b/backend/app/api/docs/model_config/list_models.md @@ -1,74 +1,3 @@ -## Endpoint +List active model configurations. -**GET** `/api/v1/models` - -Retrieve a list of all active model configurations. 
- -Returns model details including provider, model name, supported config parameters, input/output modalities, pricing, and active status. - -Optionally filter by provider (e.g. openai, google). - -### Query Parameters - -- **`provider`** (optional) — Filter by provider name (e.g. `openai`, `google`) -- **`skip`** (optional, default 0) — Number of records to skip for pagination -- **`limit`** (optional, default 100, max 100) — Maximum number of records to return - -### Example Response - -```json -{ - "success": true, - "metadata": { - "has_more": true - }, - "data": { - "data": [ - { - "id": 1, - "provider": "openai", - "model_name": "gpt-4o-mini", - "config": { - "temperature": { - "type": "float", - "default": 1.0, - "min": 0.0, - "max": 2.0, - "description": "Controls randomness. Lower = more deterministic." - }, - "top_p": { - "type": "float", - "default": 1.0, - "min": 0.0, - "max": 1.0, - "description": "Nucleus sampling. Use either this or temperature, not both." - }, - "max_output_tokens": { - "type": "int", - "default": 2048, - "min": 1, - "max": 32768, - "description": "Max tokens in the response." - } - }, - "input_modalities": ["TEXT", "IMAGE"], - "output_modalities": ["TEXT"], - "pricing": { - "response": { - "input_token_cost": 0.15, - "output_token_cost": 0.6 - }, - "batch": { - "input_token_cost": 0.075, - "output_token_cost": 0.3 - } - }, - "is_active": true, - "inserted_at": "2026-03-12T00:00:00", - "updated_at": "2026-03-12T00:00:00" - } - ], - "count": 1 - } -} -``` +Filter by `provider` (`openai` | `google` | `sarvamai` | `elevenlabs`). Paginate with `skip` / `limit` (max 100). Only active models returned, sorted by `provider`, `model_name`. `metadata.has_more` flags more records. 
diff --git a/backend/app/api/docs/model_config/list_models_grouped.md b/backend/app/api/docs/model_config/list_models_grouped.md index 706beba37..9f478f962 100644 --- a/backend/app/api/docs/model_config/list_models_grouped.md +++ b/backend/app/api/docs/model_config/list_models_grouped.md @@ -1,52 +1,5 @@ -## Endpoint +List active models grouped by provider. -**GET** `/api/v1/models/grouped` +Returns a dict keyed by provider; each value is the list of that provider's active models. -Retrieve active models grouped by provider. - -Supports pagination of model rows before grouping: -- `skip` (default `0`) -- `limit` (default `100`, max `100`) - -Returns a dictionary where each key is a provider present in the paginated slice, and each value is a list of active model configurations for that provider. -Includes `metadata.has_more` when additional model rows exist. - -### Example Response - -```json -{ - "success": true, - "metadata": { - "has_more": true - }, - "data": { - "openai": [ - { - "id": 2, - "provider": "openai", - "model_name": "gpt-4o", - "config": { - "temperature": { - "type": "float", - "default": 1.0, - "min": 0.0, - "max": 2.0, - "description": "Controls randomness. Lower = more deterministic." - } - }, - "input_modalities": ["TEXT", "IMAGE"], - "output_modalities": ["TEXT"], - "pricing": { - "response": { - "input_token_cost": 2.5, - "output_token_cost": 10 - } - }, - "is_active": true, - "inserted_at": "2026-03-12T00:00:00", - "updated_at": "2026-03-12T00:00:00" - } - ] - } -} -``` +Pagination (`skip` / `limit`) is applied **before** grouping — adjust `limit` if expecting many models per provider. 
diff --git a/backend/app/api/docs/model_config/list_providers.md b/backend/app/api/docs/model_config/list_providers.md index aa498bf68..77e69ef39 100644 --- a/backend/app/api/docs/model_config/list_providers.md +++ b/backend/app/api/docs/model_config/list_providers.md @@ -1,16 +1 @@ -## Endpoint - -**GET** `/api/v1/models/providers` - -Retrieve the list of providers that currently have active models. - -Returns provider names sorted in ascending order. - -### Example Response - -```json -{ - "success": true, - "data": ["google", "openai"] -} -``` +List provider names that have at least one active model. Sorted ascending. diff --git a/backend/app/api/docs/model_config/update_model.md b/backend/app/api/docs/model_config/update_model.md new file mode 100644 index 000000000..b7ce7427e --- /dev/null +++ b/backend/app/api/docs/model_config/update_model.md @@ -0,0 +1,21 @@ +Partially update a model configuration. Only fields sent are updated; omitted fields stay unchanged. + +**Updatable fields:** `completion_type`, `config`, `input_modalities`, `output_modalities`, `pricing`, `is_active` + +Arrays and objects are **replaced** (no deep merge). `provider` and `model_name` cannot be changed here. + +### Example + +```json +{ + "completion_type": ["text", "stt"], + "pricing": { + "response": { "input_token_cost": 0.5, "output_token_cost": 3.0 } + } +} +``` + +### Errors + +- `404` — Model not found. +- `422` — Validation error. diff --git a/backend/app/api/docs/model_config/update_models.md b/backend/app/api/docs/model_config/update_models.md new file mode 100644 index 000000000..a61ce33ce --- /dev/null +++ b/backend/app/api/docs/model_config/update_models.md @@ -0,0 +1,19 @@ +Update multiple model configurations in one request. + +Each item must include `provider` + `model_name` to identify the target. Other fields are optional and follow the same rules as the single PATCH endpoint (replace semantics, no deep merge). 
+ +Atomic — if any target is missing, no updates are applied. + +### Example + +```json +[ + { "provider": "google", "model_name": "gemini-2.5-flash", "completion_type": ["text", "stt"] }, + { "provider": "sarvamai", "model_name": "saaras:v3", "is_active": false } +] +``` + +### Errors + +- `404` — One or more targets not found. +- `422` — Validation error. diff --git a/backend/app/api/routes/model_config.py b/backend/app/api/routes/model_config.py index 751055839..3cdb34b79 100644 --- a/backend/app/api/routes/model_config.py +++ b/backend/app/api/routes/model_config.py @@ -4,12 +4,23 @@ from fastapi import APIRouter, HTTPException, Query from app.api.deps import SessionDep +from app.models.llm.constants import Provider from app.crud.model_config import ( + bulk_create_model_configs, + bulk_update_model_configs, + delete_model_config, get_model_config, list_active_model_configs, list_all_active_model_configs, + update_model_config, +) +from app.models import ( + ModelConfigBulkUpdateItem, + ModelConfigCreate, + ModelConfigListPublic, + ModelConfigPublic, + ModelConfigUpdate, ) -from app.models import ModelConfigListPublic, ModelConfigPublic from app.utils import APIResponse, load_description logger = logging.getLogger(__name__) @@ -77,7 +88,7 @@ def list_providers( description=load_description("model_config/get_model.md"), ) def get_model( - session: SessionDep, provider: str, model_name: str + session: SessionDep, provider: Provider, model_name: str ) -> APIResponse[ModelConfigPublic]: model = get_model_config(session=session, provider=provider, model_name=model_name) @@ -85,3 +96,56 @@ def get_model( raise HTTPException(status_code=404, detail="Model not found") return APIResponse.success_response(model) + + +@router.post( + "", + response_model=APIResponse[list[ModelConfigPublic]], + description=load_description("model_config/create_models.md"), + status_code=201, +) +def create_models( + session: SessionDep, data: ModelConfigCreate | list[ModelConfigCreate] +) 
-> APIResponse[list[ModelConfigPublic]]: + items = data if isinstance(data, list) else [data] + models = bulk_create_model_configs(session=session, items=items) + return APIResponse.success_response(models) + + +@router.patch( + "/{provider}/{model_name}", + response_model=APIResponse[ModelConfigPublic], + description=load_description("model_config/update_model.md"), +) +def update_model( + session: SessionDep, provider: Provider, model_name: str, data: ModelConfigUpdate +) -> APIResponse[ModelConfigPublic]: + model = update_model_config( + session=session, provider=provider, model_name=model_name, data=data + ) + return APIResponse.success_response(model) + + +@router.patch( + "", + response_model=APIResponse[list[ModelConfigPublic]], + description=load_description("model_config/update_models.md"), + summary="Update Models", +) +def bulk_update_models( + session: SessionDep, items: list[ModelConfigBulkUpdateItem] +) -> APIResponse[list[ModelConfigPublic]]: + models = bulk_update_model_configs(session=session, items=items) + return APIResponse.success_response(models) + + +@router.delete( + "/{provider}/{model_name}", + response_model=APIResponse[None], + description=load_description("model_config/delete_model.md"), +) +def delete_model( + session: SessionDep, provider: Provider, model_name: str +) -> APIResponse[None]: + delete_model_config(session=session, provider=provider, model_name=model_name) + return APIResponse.success_response(None) diff --git a/backend/app/crud/model_config.py b/backend/app/crud/model_config.py index 9c627f7f4..ae2a84a48 100644 --- a/backend/app/crud/model_config.py +++ b/backend/app/crud/model_config.py @@ -1,13 +1,17 @@ +from datetime import datetime from typing import Any, Literal from fastapi import HTTPException from sqlmodel import Session, select from app.models import ModelConfig +from app.models.llm.constants import CompletionType, Provider from app.models.llm.request import ConfigBlob -from app.models.model_config import 
CompletionType - -Provider = Literal["openai", "google", "sarvamai", "elevenlabs"] +from app.models.model_config import ( + ModelConfigBulkUpdateItem, + ModelConfigCreate, + ModelConfigUpdate, +) def _normalize_provider(raw: str) -> str: @@ -68,7 +72,7 @@ def list_supported_models( """Return active model names for a provider + completion type.""" stmt = select(ModelConfig.model_name).where( ModelConfig.provider == provider, - ModelConfig.completion_type == completion_type, + ModelConfig.completion_type.contains([completion_type]), # type: ignore[union-attr] ModelConfig.is_active, ) return list(session.exec(stmt).all()) @@ -80,11 +84,11 @@ def is_model_supported( completion_type: CompletionType, model_name: str, ) -> bool: - """Check whether (provider, model_name) is active and matches the completion type.""" + """Check whether (provider, model_name) is active and supports the completion type.""" stmt = select(ModelConfig.id).where( ModelConfig.provider == provider, ModelConfig.model_name == model_name, - ModelConfig.completion_type == completion_type, + ModelConfig.completion_type.contains([completion_type]), # type: ignore[union-attr] ModelConfig.is_active, ) return session.exec(stmt).first() is not None @@ -164,6 +168,79 @@ def validate_blob_model_or_raise(session: Session, blob: ConfigBlob) -> None: ) +def create_model_config(session: Session, data: ModelConfigCreate) -> ModelConfig: + model = ModelConfig.model_validate(data) + session.add(model) + session.commit() + session.refresh(model) + return model + + +def bulk_create_model_configs( + session: Session, items: list[ModelConfigCreate] +) -> list[ModelConfig]: + models = [ModelConfig.model_validate(item) for item in items] + session.add_all(models) + session.commit() + for m in models: + session.refresh(m) + return models + + +def update_model_config( + session: Session, provider: str, model_name: str, data: ModelConfigUpdate +) -> ModelConfig: + model = get_model_config(session=session, provider=provider, 
model_name=model_name) # type: ignore[arg-type] + if model is None: + raise HTTPException(status_code=404, detail="Model not found") + update_data = data.model_dump(exclude_unset=True) + for field, value in update_data.items(): + setattr(model, field, value) + model.updated_at = datetime.utcnow() + session.add(model) + session.commit() + session.refresh(model) + return model + + +def bulk_update_model_configs( + session: Session, items: list[ModelConfigBulkUpdateItem] +) -> list[ModelConfig]: + keys = [(item.provider, item.model_name) for item in items] + existing: dict[tuple, ModelConfig] = {} + for provider, model_name in keys: + m = get_model_config(session=session, provider=provider, model_name=model_name) + if m is None: + raise HTTPException( + status_code=404, + detail=f"Model '{model_name}' not found for provider='{provider}'", + ) + existing[(provider, model_name)] = m + updated = [] + now = datetime.utcnow() + for item in items: + model = existing[(item.provider, item.model_name)] + for field, value in item.model_dump( + exclude_unset=True, exclude={"provider", "model_name"} + ).items(): + setattr(model, field, value) + model.updated_at = now + session.add(model) + updated.append(model) + session.commit() + for m in updated: + session.refresh(m) + return updated + + +def delete_model_config(session: Session, provider: str, model_name: str) -> None: + model = get_model_config(session=session, provider=provider, model_name=model_name) # type: ignore[arg-type] + if model is None: + raise HTTPException(status_code=404, detail="Model not found") + session.delete(model) + session.commit() + + def is_reasoning_model(session: Session, provider: Provider, model_name: str) -> bool: """Return True if the model is configured with a reasoning `effort` control. 
diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 3d9a2c4c6..24c7f6897 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -131,8 +131,11 @@ from .model_config import ( ModelConfig, ModelConfigBase, + ModelConfigBulkUpdateItem, + ModelConfigCreate, ModelConfigListPublic, ModelConfigPublic, + ModelConfigUpdate, ) from .model_evaluation import ( ModelEvaluation, diff --git a/backend/app/models/llm/constants.py b/backend/app/models/llm/constants.py index 1838da79d..97b7a3c89 100644 --- a/backend/app/models/llm/constants.py +++ b/backend/app/models/llm/constants.py @@ -1,3 +1,26 @@ +from enum import StrEnum + + +class Provider(StrEnum): + OPENAI = "openai" + GOOGLE = "google" + SARVAMAI = "sarvamai" + ELEVENLABS = "elevenlabs" + + +class CompletionType(StrEnum): + TEXT = "text" + STT = "stt" + TTS = "tts" + + +class Modality(StrEnum): + TEXT = "TEXT" + AUDIO = "AUDIO" + IMAGE = "IMAGE" + FILES = "FILES" + + DEFAULT_STT_MODEL = "gemini-2.5-pro" DEFAULT_TTS_MODEL = "gemini-2.5-flash-preview-tts" DEFAULT_TTS_VOICE = "Kore" diff --git a/backend/app/models/llm/request.py b/backend/app/models/llm/request.py index da0c18120..6a8eea13d 100644 --- a/backend/app/models/llm/request.py +++ b/backend/app/models/llm/request.py @@ -13,6 +13,8 @@ DEFAULT_STT_MODEL, DEFAULT_TTS_MODEL, DEFAULT_TTS_VOICE, + CompletionType, + Provider, ) @@ -236,7 +238,7 @@ class NativeCompletionConfig(SQLModel): ..., description="Provider-specific parameters (schema varies by provider), should exactly match the provider's endpoint params structure", ) - type: Literal["text", "stt", "tts"] = Field( + type: CompletionType = Field( ..., description="Completion config type. Params schema varies by type" ) @@ -248,11 +250,11 @@ class KaapiCompletionConfig(SQLModel): Supports multiple providers: OpenAI, Claude, Gemini, etc. 
""" - provider: Literal["openai", "google", "sarvamai", "elevenlabs"] | None = Field( + provider: Provider | None = Field( None, description="LLM provider (openai, google, sarvamai, elevenlabs)" ) - type: Literal["text", "stt", "tts"] = Field( + type: CompletionType = Field( ..., description="Completion config type. Params schema varies by type" ) params: dict[str, Any] = Field( diff --git a/backend/app/models/model_config.py b/backend/app/models/model_config.py index ef284dc5f..01f7ddd08 100644 --- a/backend/app/models/model_config.py +++ b/backend/app/models/model_config.py @@ -1,26 +1,23 @@ from datetime import datetime -from typing import Any, Literal +from typing import Any import sqlalchemy as sa from sqlalchemy.dialects.postgresql import ARRAY, JSONB from sqlmodel import Field, SQLModel from app.core.util import now - -CompletionType = Literal["text", "stt", "tts"] +from app.models.llm.constants import CompletionType, Modality, Provider class ModelConfigBase(SQLModel): - provider: Literal["openai", "google", "sarvamai", "elevenlabs"] = Field( + provider: Provider = Field( default="openai", sa_column=sa.Column( sa.Enum( - "openai", - "google", - "sarvamai", - "elevenlabs", + *[p.value for p in Provider], name="provider_enum", schema="global", + create_type=False, ), nullable=False, comment="provider name (e.g. 
openai, google, sarvamai, elevenlabs)", @@ -36,12 +33,21 @@ class ModelConfigBase(SQLModel): ), ) - completion_type: CompletionType = Field( + completion_type: list[CompletionType] = Field( ..., sa_column=sa.Column( - sa.Enum("text", "stt", "tts", name="completion_type_enum", schema="global"), + ARRAY( + sa.Enum( + "text", + "stt", + "tts", + name="completion_type_enum", + schema="global", + create_type=False, + ) + ), nullable=False, - comment="text | stt | tts — drives routing and validation", + comment="supported completion types: text, stt, tts", ), ) @@ -50,7 +56,7 @@ class ModelConfigBase(SQLModel): sa_column=sa.Column(JSONB, nullable=False, comment="model adhoc configuration"), ) - input_modalities: list[str] = Field( + input_modalities: list[Modality] = Field( default_factory=list, sa_column=sa.Column( ARRAY(sa.String), @@ -60,7 +66,7 @@ class ModelConfigBase(SQLModel): ), ) - output_modalities: list[str] = Field( + output_modalities: list[Modality] = Field( default_factory=list, sa_column=sa.Column( ARRAY(sa.String), @@ -101,10 +107,9 @@ class ModelConfig(ModelConfigBase, table=True): sa.UniqueConstraint("provider", "model_name"), sa.Index("ix_model_config_provider_active", "provider", "is_active"), sa.Index( - "ix_model_config_provider_type_active", - "provider", + "ix_model_config_completion_type", "completion_type", - "is_active", + postgresql_using="gin", ), sa.Index( "ix_model_config_input_modalities", @@ -150,6 +155,49 @@ class ModelConfig(ModelConfigBase, table=True): ) +class ModelConfigCreate(ModelConfigBase): + pass + + +class ModelConfigUpdate(SQLModel): + completion_type: list[CompletionType] | None = None + config: dict[str, Any] | None = None + input_modalities: list[Modality] | None = None + output_modalities: list[Modality] | None = None + pricing: dict[str, Any] | None = None + is_active: bool | None = None + + model_config = { + "json_schema_extra": { + "example": { + "completion_type": ["text", "stt"], + "config": { + "temperature": { + 
"type": "float", + "default": 1.0, + "min": 0.0, + "max": 2.0, + "description": "Controls randomness.", + } + }, + "input_modalities": ["TEXT", "AUDIO"], + "output_modalities": ["TEXT"], + "pricing": { + "response": {"input_token_cost": 0.5, "output_token_cost": 2.0}, + "batch": {"input_token_cost": 0.25, "output_token_cost": 1.0}, + "audio": {"input_token_cost": 1.0, "output_token_cost": 2.0}, + }, + "is_active": True, + } + } + } + + +class ModelConfigBulkUpdateItem(ModelConfigUpdate): + provider: Provider + model_name: str + + class ModelConfigPublic(ModelConfigBase): id: int inserted_at: datetime diff --git a/backend/app/tests/api/routes/test_model_config.py b/backend/app/tests/api/routes/test_model_config.py index e63b0bbbb..07e5d744f 100644 --- a/backend/app/tests/api/routes/test_model_config.py +++ b/backend/app/tests/api/routes/test_model_config.py @@ -1,8 +1,32 @@ +import uuid + from fastapi.testclient import TestClient from app.core.config import settings +def _payload(model_name: str | None = None, **overrides) -> dict: + base = { + "provider": "google", + "model_name": model_name or f"test-{uuid.uuid4().hex[:8]}", + "completion_type": ["text"], + "config": {}, + "input_modalities": ["TEXT"], + "output_modalities": ["TEXT"], + "pricing": None, + "is_active": True, + } + base.update(overrides) + return base + + +def _delete(client: TestClient, headers: dict, provider: str, model_name: str) -> None: + client.delete( + f"{settings.API_V1_STR}/models/{provider}/{model_name}", + headers=headers, + ) + + def test_list_models( client: TestClient, superuser_token_headers: dict[str, str] ) -> None: @@ -146,3 +170,416 @@ def test_list_providers( assert providers == sorted(providers) assert len(providers) == len(set(providers)) assert "openai" in providers + + +def test_create_model_single( + client: TestClient, superuser_token_headers: dict[str, str] +) -> None: + payload = _payload(completion_type=["text", "stt"]) + response = client.post( + 
f"{settings.API_V1_STR}/models", + headers=superuser_token_headers, + json=payload, + ) + assert response.status_code == 201 + data = response.json()["data"] + assert isinstance(data, list) + assert len(data) == 1 + assert data[0]["model_name"] == payload["model_name"] + assert data[0]["completion_type"] == ["text", "stt"] + + _delete(client, superuser_token_headers, payload["provider"], payload["model_name"]) + + +def test_create_model_multiple( + client: TestClient, superuser_token_headers: dict[str, str] +) -> None: + payloads = [_payload(), _payload()] + response = client.post( + f"{settings.API_V1_STR}/models", + headers=superuser_token_headers, + json=payloads, + ) + assert response.status_code == 201 + data = response.json()["data"] + assert len(data) == 2 + names = {m["model_name"] for m in data} + assert names == {p["model_name"] for p in payloads} + + for p in payloads: + _delete(client, superuser_token_headers, p["provider"], p["model_name"]) + + +def test_create_model_invalid_modality( + client: TestClient, superuser_token_headers: dict[str, str] +) -> None: + payload = _payload(input_modalities=["INVALID"]) + response = client.post( + f"{settings.API_V1_STR}/models", + headers=superuser_token_headers, + json=payload, + ) + assert response.status_code == 422 + + +def test_create_model_invalid_completion_type( + client: TestClient, superuser_token_headers: dict[str, str] +) -> None: + payload = _payload(completion_type=["invalid"]) + response = client.post( + f"{settings.API_V1_STR}/models", + headers=superuser_token_headers, + json=payload, + ) + assert response.status_code == 422 + + +def test_update_model_replaces_completion_type( + client: TestClient, superuser_token_headers: dict[str, str] +) -> None: + payload = _payload(completion_type=["stt"]) + client.post( + f"{settings.API_V1_STR}/models", + headers=superuser_token_headers, + json=payload, + ) + + response = client.patch( + 
f"{settings.API_V1_STR}/models/{payload['provider']}/{payload['model_name']}", + headers=superuser_token_headers, + json={"completion_type": ["text", "stt"]}, + ) + assert response.status_code == 200 + assert set(response.json()["data"]["completion_type"]) == {"text", "stt"} + + response = client.patch( + f"{settings.API_V1_STR}/models/{payload['provider']}/{payload['model_name']}", + headers=superuser_token_headers, + json={"completion_type": ["text"]}, + ) + assert response.status_code == 200 + assert response.json()["data"]["completion_type"] == ["text"] + + _delete(client, superuser_token_headers, payload["provider"], payload["model_name"]) + + +def test_update_model_only_sent_fields_changed( + client: TestClient, superuser_token_headers: dict[str, str] +) -> None: + payload = _payload( + completion_type=["text"], + pricing={"response": {"input_token_cost": 1.0, "output_token_cost": 2.0}}, + ) + client.post( + f"{settings.API_V1_STR}/models", + headers=superuser_token_headers, + json=payload, + ) + + response = client.patch( + f"{settings.API_V1_STR}/models/{payload['provider']}/{payload['model_name']}", + headers=superuser_token_headers, + json={"is_active": False}, + ) + assert response.status_code == 200 + data = response.json()["data"] + assert data["is_active"] is False + assert data["completion_type"] == ["text"] + assert data["pricing"]["response"]["input_token_cost"] == 1.0 + + _delete(client, superuser_token_headers, payload["provider"], payload["model_name"]) + + +def test_update_model_not_found( + client: TestClient, superuser_token_headers: dict[str, str] +) -> None: + response = client.patch( + f"{settings.API_V1_STR}/models/google/does-not-exist", + headers=superuser_token_headers, + json={"is_active": False}, + ) + assert response.status_code == 404 + + +def test_update_model_invalid_modality( + client: TestClient, superuser_token_headers: dict[str, str] +) -> None: + payload = _payload() + client.post( + f"{settings.API_V1_STR}/models", + 
headers=superuser_token_headers, + json=payload, + ) + + response = client.patch( + f"{settings.API_V1_STR}/models/{payload['provider']}/{payload['model_name']}", + headers=superuser_token_headers, + json={"input_modalities": ["BOGUS"]}, + ) + assert response.status_code == 422 + + _delete(client, superuser_token_headers, payload["provider"], payload["model_name"]) + + +def test_bulk_update_models( + client: TestClient, superuser_token_headers: dict[str, str] +) -> None: + p1 = _payload(completion_type=["stt"]) + p2 = _payload(completion_type=["tts"]) + client.post( + f"{settings.API_V1_STR}/models", + headers=superuser_token_headers, + json=[p1, p2], + ) + + response = client.patch( + f"{settings.API_V1_STR}/models", + headers=superuser_token_headers, + json=[ + { + "provider": p1["provider"], + "model_name": p1["model_name"], + "completion_type": ["text", "stt"], + }, + { + "provider": p2["provider"], + "model_name": p2["model_name"], + "is_active": False, + }, + ], + ) + assert response.status_code == 200 + data = response.json()["data"] + assert len(data) == 2 + by_name = {m["model_name"]: m for m in data} + assert set(by_name[p1["model_name"]]["completion_type"]) == {"text", "stt"} + assert by_name[p2["model_name"]]["is_active"] is False + + for p in [p1, p2]: + _delete(client, superuser_token_headers, p["provider"], p["model_name"]) + + +def test_bulk_update_atomic_on_missing_target( + client: TestClient, superuser_token_headers: dict[str, str] +) -> None: + payload = _payload(completion_type=["stt"]) + client.post( + f"{settings.API_V1_STR}/models", + headers=superuser_token_headers, + json=payload, + ) + + response = client.patch( + f"{settings.API_V1_STR}/models", + headers=superuser_token_headers, + json=[ + { + "provider": payload["provider"], + "model_name": payload["model_name"], + "completion_type": ["text"], + }, + { + "provider": "google", + "model_name": "definitely-does-not-exist", + "is_active": False, + }, + ], + ) + assert response.status_code 
== 404 + + # First item should not have been updated either + check = client.get( + f"{settings.API_V1_STR}/models/{payload['provider']}/{payload['model_name']}", + headers=superuser_token_headers, + ) + assert check.json()["data"]["completion_type"] == ["stt"] + + _delete(client, superuser_token_headers, payload["provider"], payload["model_name"]) + + +def test_delete_model( + client: TestClient, superuser_token_headers: dict[str, str] +) -> None: + payload = _payload() + client.post( + f"{settings.API_V1_STR}/models", + headers=superuser_token_headers, + json=payload, + ) + + response = client.delete( + f"{settings.API_V1_STR}/models/{payload['provider']}/{payload['model_name']}", + headers=superuser_token_headers, + ) + assert response.status_code == 200 + + check = client.get( + f"{settings.API_V1_STR}/models/{payload['provider']}/{payload['model_name']}", + headers=superuser_token_headers, + ) + assert check.status_code == 404 + + +def test_delete_model_not_found( + client: TestClient, superuser_token_headers: dict[str, str] +) -> None: + response = client.delete( + f"{settings.API_V1_STR}/models/google/does-not-exist", + headers=superuser_token_headers, + ) + assert response.status_code == 404 + + +def test_create_model_duplicate_fails( + client: TestClient, superuser_token_headers: dict[str, str] +) -> None: + payload = _payload() + r1 = client.post( + f"{settings.API_V1_STR}/models", + headers=superuser_token_headers, + json=payload, + ) + assert r1.status_code == 201 + + r2 = client.post( + f"{settings.API_V1_STR}/models", + headers=superuser_token_headers, + json=payload, + ) + assert r2.status_code >= 400 + + _delete(client, superuser_token_headers, payload["provider"], payload["model_name"]) + + +def test_create_invalid_provider( + client: TestClient, superuser_token_headers: dict[str, str] +) -> None: + payload = _payload(provider="nonexistent_provider") + response = client.post( + f"{settings.API_V1_STR}/models", + headers=superuser_token_headers, + 
json=payload, + ) + assert response.status_code == 422 + + +def test_update_empty_body_noop( + client: TestClient, superuser_token_headers: dict[str, str] +) -> None: + payload = _payload(completion_type=["stt"]) + client.post( + f"{settings.API_V1_STR}/models", + headers=superuser_token_headers, + json=payload, + ) + + response = client.patch( + f"{settings.API_V1_STR}/models/{payload['provider']}/{payload['model_name']}", + headers=superuser_token_headers, + json={}, + ) + assert response.status_code == 200 + assert response.json()["data"]["completion_type"] == ["stt"] + + _delete(client, superuser_token_headers, payload["provider"], payload["model_name"]) + + +def test_bulk_update_empty_array( + client: TestClient, superuser_token_headers: dict[str, str] +) -> None: + response = client.patch( + f"{settings.API_V1_STR}/models", + headers=superuser_token_headers, + json=[], + ) + assert response.status_code == 200 + assert response.json()["data"] == [] + + +def test_list_models_invalid_provider_filter( + client: TestClient, superuser_token_headers: dict[str, str] +) -> None: + response = client.get( + f"{settings.API_V1_STR}/models?provider=bogus", + headers=superuser_token_headers, + ) + assert response.status_code == 200 + assert response.json()["data"]["count"] == 0 + + +def test_inactive_model_excluded_from_list( + client: TestClient, superuser_token_headers: dict[str, str] +) -> None: + payload = _payload(is_active=False) + client.post( + f"{settings.API_V1_STR}/models", + headers=superuser_token_headers, + json=payload, + ) + + response = client.get( + f"{settings.API_V1_STR}/models?provider={payload['provider']}", + headers=superuser_token_headers, + ) + names = [m["model_name"] for m in response.json()["data"]["data"]] + assert payload["model_name"] not in names + + _delete(client, superuser_token_headers, payload["provider"], payload["model_name"]) + + +def test_update_pricing_replaces_full_object( + client: TestClient, superuser_token_headers: dict[str, 
str] +) -> None: + payload = _payload( + pricing={ + "response": {"input_token_cost": 1.0, "output_token_cost": 2.0}, + "batch": {"input_token_cost": 0.5, "output_token_cost": 1.0}, + }, + ) + client.post( + f"{settings.API_V1_STR}/models", + headers=superuser_token_headers, + json=payload, + ) + + response = client.patch( + f"{settings.API_V1_STR}/models/{payload['provider']}/{payload['model_name']}", + headers=superuser_token_headers, + json={ + "pricing": { + "response": {"input_token_cost": 5.0, "output_token_cost": 10.0} + } + }, + ) + assert response.status_code == 200 + pricing = response.json()["data"]["pricing"] + assert pricing["response"]["input_token_cost"] == 5.0 + assert "batch" not in pricing # full replace, not merge + + _delete(client, superuser_token_headers, payload["provider"], payload["model_name"]) + + +def test_create_with_multiple_completion_types_and_query( + client: TestClient, superuser_token_headers: dict[str, str] +) -> None: + """Model created with both text + stt should report both completion types on create and on fetch.""" + payload = _payload( + completion_type=["text", "stt"], + input_modalities=["TEXT", "AUDIO"], + ) + response = client.post( + f"{settings.API_V1_STR}/models", + headers=superuser_token_headers, + json=payload, + ) + assert response.status_code == 201 + created = response.json()["data"][0] + assert set(created["completion_type"]) == {"text", "stt"} + + fetched = client.get( + f"{settings.API_V1_STR}/models/{payload['provider']}/{payload['model_name']}", + headers=superuser_token_headers, + ) + assert set(fetched.json()["data"]["completion_type"]) == {"text", "stt"} + + _delete(client, superuser_token_headers, payload["provider"], payload["model_name"]) From 8ca252f5c76d3c8c8958685b1206522fbf65a00f Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Tue, 26 May 2026 13:53:35 +0530 Subject: [PATCH 2/2] feat(model-config): enhance model creation with integrity error
handling and update provider type in list models --- backend/app/api/routes/model_config.py | 2 +- backend/app/crud/model_config.py | 10 +++++++++- backend/app/tests/api/routes/test_model_config.py | 5 ++--- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/backend/app/api/routes/model_config.py b/backend/app/api/routes/model_config.py index 3cdb34b79..7db086d31 100644 --- a/backend/app/api/routes/model_config.py +++ b/backend/app/api/routes/model_config.py @@ -34,7 +34,7 @@ ) def list_models( session: SessionDep, - provider: str | None = None, + provider: Provider | None = None, skip: int = Query(0, ge=0, description="Number of records to skip"), limit: int = Query(100, ge=1, le=100, description="Maximum records to return"), ) -> APIResponse[ModelConfigListPublic]: diff --git a/backend/app/crud/model_config.py b/backend/app/crud/model_config.py index ae2a84a48..2f5c517ae 100644 --- a/backend/app/crud/model_config.py +++ b/backend/app/crud/model_config.py @@ -2,6 +2,7 @@ from typing import Any, Literal from fastapi import HTTPException +from sqlalchemy.exc import IntegrityError from sqlmodel import Session, select from app.models import ModelConfig @@ -181,7 +182,14 @@ def bulk_create_model_configs( ) -> list[ModelConfig]: models = [ModelConfig.model_validate(item) for item in items] session.add_all(models) - session.commit() + try: + session.commit() + except IntegrityError as e: + session.rollback() + raise HTTPException( + status_code=409, + detail="Duplicate (provider, model_name) — entry already exists", + ) from e for m in models: session.refresh(m) return models diff --git a/backend/app/tests/api/routes/test_model_config.py b/backend/app/tests/api/routes/test_model_config.py index 07e5d744f..9d82bbeae 100644 --- a/backend/app/tests/api/routes/test_model_config.py +++ b/backend/app/tests/api/routes/test_model_config.py @@ -446,7 +446,7 @@ def test_create_model_duplicate_fails( headers=superuser_token_headers, json=payload, ) - assert 
r2.status_code >= 400 + assert r2.status_code == 409 _delete(client, superuser_token_headers, payload["provider"], payload["model_name"]) @@ -503,8 +503,7 @@ def test_list_models_invalid_provider_filter( f"{settings.API_V1_STR}/models?provider=bogus", headers=superuser_token_headers, ) - assert response.status_code == 200 - assert response.json()["data"]["count"] == 0 + assert response.status_code == 422 def test_inactive_model_excluded_from_list(