diff --git a/.github/workflows/helm-release.yml b/.github/workflows/helm-release.yml index f36b281abc..7ef58f625b 100644 --- a/.github/workflows/helm-release.yml +++ b/.github/workflows/helm-release.yml @@ -105,6 +105,7 @@ jobs: cd helm-charts gh release create "helm-chart-v${{ steps.chart-version.outputs.version }}" \ bifrost-${{ steps.chart-version.outputs.version }}.tgz \ + --target ${{ github.sha }} \ --title "Helm Chart v${{ steps.chart-version.outputs.version }}" \ --notes "Helm chart release for Bifrost v${{ steps.chart-version.outputs.version }}" env: diff --git a/core/bifrost.go b/core/bifrost.go index a67edf9ebc..860a2f256d 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -514,9 +514,20 @@ func (bifrost *Bifrost) ListAllModels(ctx *schemas.BifrostContext, req *schemas. response, bifrostErr := bifrost.ListModelsRequest(providerCtx, providerRequest) if bifrostErr != nil { - // Skip logging "no keys found" and "not supported" errors as they are expected when a provider is not configured - if !strings.Contains(bifrostErr.Error.Message, "no keys found") && - !strings.Contains(bifrostErr.Error.Message, "not supported") { + // Some per-provider failures are expected when fanning out across all + // configured providers and must not be surfaced as a top-level error + errType := "" + if bifrostErr.Type != nil { + errType = *bifrostErr.Type + } + errMsg := "" + if bifrostErr.Error != nil { + errMsg = bifrostErr.Error.Message + } + isExpected := strings.Contains(errMsg, "no keys found") || + strings.Contains(errMsg, "not supported") || + errType == "provider_blocked" + if !isExpected { providerErr = bifrostErr bifrost.logger.Warn("failed to list models for provider %s: %s", providerKey, bifrostErr.GetErrorString()) } diff --git a/core/bifrost_test.go b/core/bifrost_test.go index 049dbce854..09c00508c9 100644 --- a/core/bifrost_test.go +++ b/core/bifrost_test.go @@ -711,7 +711,7 @@ func (ma *MockAccount) AddProviderWithBaseURL(provider schemas.ModelProvider, co 
ma.configs[provider] = &schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ BaseURL: baseURL, - DefaultRequestTimeoutInSeconds: 30, + DefaultRequestTimeoutInSeconds: 300, MaxRetries: 3, RetryBackoffInitial: 500 * time.Millisecond, RetryBackoffMax: 5 * time.Second, diff --git a/core/providers/bedrock/transport_test.go b/core/providers/bedrock/transport_test.go index 346565a61e..9484d9d868 100644 --- a/core/providers/bedrock/transport_test.go +++ b/core/providers/bedrock/transport_test.go @@ -547,7 +547,7 @@ func generateTestCACert(t *testing.T) string { func TestBedrockTransportHTTP2Config(t *testing.T) { config := &schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 30, + DefaultRequestTimeoutInSeconds: 300, MaxConnsPerHost: 5000, EnforceHTTP2: true, }, @@ -570,7 +570,7 @@ func TestBedrockTransportHTTP2Config(t *testing.T) { func TestBedrockTransportCustomMaxConns(t *testing.T) { config := &schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 30, + DefaultRequestTimeoutInSeconds: 300, MaxConnsPerHost: 50, }, } @@ -590,7 +590,7 @@ func TestBedrockTransportCustomMaxConns(t *testing.T) { func TestBedrockTransportDefaultMaxConns(t *testing.T) { config := &schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 30, + DefaultRequestTimeoutInSeconds: 300, // MaxConnsPerHost left as 0 — should default to 5000 }, } @@ -612,7 +612,7 @@ func TestBedrockTransportDefaultMaxConns(t *testing.T) { func TestBedrockTransportTLSInsecureSkipVerify(t *testing.T) { config := &schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 30, + DefaultRequestTimeoutInSeconds: 300, InsecureSkipVerify: true, EnforceHTTP2: true, }, @@ -636,7 +636,7 @@ func TestBedrockTransportTLSCACert(t *testing.T) { config := &schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 30, + 
DefaultRequestTimeoutInSeconds: 300, CACertPEM: schemas.NewEnvVar(testCACert), EnforceHTTP2: true, }, @@ -657,7 +657,7 @@ func TestBedrockTransportTLSCACert(t *testing.T) { func TestBedrockTransportDefaultTLS(t *testing.T) { config := &schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 30, + DefaultRequestTimeoutInSeconds: 300, // No TLS settings — should use system defaults }, } @@ -677,7 +677,7 @@ func TestBedrockTransportDefaultTLS(t *testing.T) { func TestBedrockTransportEnforceHTTP2(t *testing.T) { config := &schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 30, + DefaultRequestTimeoutInSeconds: 300, EnforceHTTP2: true, }, } @@ -696,7 +696,7 @@ func TestBedrockTransportEnforceHTTP2(t *testing.T) { func TestBedrockTransportEnforceHTTP2Disabled(t *testing.T) { config := &schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 30, + DefaultRequestTimeoutInSeconds: 300, EnforceHTTP2: false, }, } diff --git a/core/providers/fireworks/fireworks_test.go b/core/providers/fireworks/fireworks_test.go index f59284960c..def08f5550 100644 --- a/core/providers/fireworks/fireworks_test.go +++ b/core/providers/fireworks/fireworks_test.go @@ -431,7 +431,7 @@ func newTestFireworksProvider(t *testing.T, baseURL string) *fireworksprovider.F provider, err := fireworksprovider.NewFireworksProvider(&schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ BaseURL: baseURL, - DefaultRequestTimeoutInSeconds: 30, + DefaultRequestTimeoutInSeconds: 300, }, }, bifrost.NewNoOpLogger()) if err != nil { diff --git a/core/providers/mistral/ocr_test.go b/core/providers/mistral/ocr_test.go index ccf7223e4c..92195b6816 100644 --- a/core/providers/mistral/ocr_test.go +++ b/core/providers/mistral/ocr_test.go @@ -605,7 +605,7 @@ func TestOCRWithMockServer(t *testing.T) { provider := NewMistralProvider(&schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ 
BaseURL: server.URL, - DefaultRequestTimeoutInSeconds: 30, + DefaultRequestTimeoutInSeconds: 300, }, }, &testLogger{}) @@ -639,7 +639,7 @@ func TestOCRNilInput(t *testing.T) { provider := NewMistralProvider(&schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ BaseURL: "https://api.mistral.ai", - DefaultRequestTimeoutInSeconds: 30, + DefaultRequestTimeoutInSeconds: 300, }, }, &testLogger{}) @@ -686,7 +686,7 @@ func TestOCRRequestValidation(t *testing.T) { provider := NewMistralProvider(&schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ BaseURL: server.URL, - DefaultRequestTimeoutInSeconds: 30, + DefaultRequestTimeoutInSeconds: 300, }, }, &testLogger{}) diff --git a/core/providers/mistral/transcription_test.go b/core/providers/mistral/transcription_test.go index 0056f6315f..9129f5903b 100644 --- a/core/providers/mistral/transcription_test.go +++ b/core/providers/mistral/transcription_test.go @@ -593,7 +593,7 @@ func TestTranscriptionWithMockServer(t *testing.T) { provider := NewMistralProvider(&schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ BaseURL: server.URL, - DefaultRequestTimeoutInSeconds: 30, + DefaultRequestTimeoutInSeconds: 300, }, }, &testLogger{}) @@ -637,7 +637,7 @@ func TestTranscriptionNilInput(t *testing.T) { provider := NewMistralProvider(&schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ BaseURL: "https://api.mistral.ai", - DefaultRequestTimeoutInSeconds: 30, + DefaultRequestTimeoutInSeconds: 300, }, }, &testLogger{}) @@ -787,7 +787,7 @@ func TestTranscriptionStreamWithMockServer(t *testing.T) { provider := NewMistralProvider(&schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ BaseURL: server.URL, - DefaultRequestTimeoutInSeconds: 30, + DefaultRequestTimeoutInSeconds: 300, }, }, &testLogger{}) @@ -842,7 +842,7 @@ func TestTranscriptionStreamNilInput(t *testing.T) { provider := NewMistralProvider(&schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ BaseURL: "https://api.mistral.ai", 
- DefaultRequestTimeoutInSeconds: 30, + DefaultRequestTimeoutInSeconds: 300, }, }, &testLogger{}) @@ -1272,7 +1272,7 @@ func TestTranscriptionStreamEdgeCases(t *testing.T) { provider := NewMistralProvider(&schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ BaseURL: server.URL, - DefaultRequestTimeoutInSeconds: 30, + DefaultRequestTimeoutInSeconds: 300, }, }, &testLogger{}) @@ -1342,7 +1342,7 @@ func TestTranscriptionStreamContextCancellation(t *testing.T) { provider := NewMistralProvider(&schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ BaseURL: server.URL, - DefaultRequestTimeoutInSeconds: 30, + DefaultRequestTimeoutInSeconds: 300, }, }, &testLogger{}) diff --git a/core/providers/utils/utils.go b/core/providers/utils/utils.go index 27f2ecb0d0..fde4fed8c5 100644 --- a/core/providers/utils/utils.go +++ b/core/providers/utils/utils.go @@ -2217,7 +2217,7 @@ func SetupStreamCancellation(ctx *schemas.BifrostContext, bodyStream io.Reader, // before bifrost considers the connection stalled and closes it. This protects // against providers that stop sending data but keep the TCP connection open // (e.g., Azure TPM throttling). -const DefaultStreamIdleTimeout = 60 * time.Second +const DefaultStreamIdleTimeout = 120 * time.Second // SetStreamIdleTimeoutIfEmpty sets the stream idle timeout on the context from // the provider's network config, but only if no valid timeout is already present. 
diff --git a/core/schemas/provider.go b/core/schemas/provider.go index ddb32278f3..7cd4bb8db4 100644 --- a/core/schemas/provider.go +++ b/core/schemas/provider.go @@ -13,12 +13,12 @@ const ( DefaultMaxRetries = 0 DefaultRetryBackoffInitial = 500 * time.Millisecond DefaultRetryBackoffMax = 5 * time.Second - DefaultRequestTimeoutInSeconds = 30 + DefaultRequestTimeoutInSeconds = 300 DefaultMaxConnDurationInSeconds = 300 // 5 minutes — forces connection recycling to prevent stale connections from NAT/LB silent drops DefaultBufferSize = 5000 DefaultConcurrency = 1000 DefaultStreamBufferSize = 256 - DefaultStreamIdleTimeoutInSeconds = 60 // Idle timeout per stream chunk — if no data for this many seconds, bifrost closes the connection + DefaultStreamIdleTimeoutInSeconds = 120 // Idle timeout per stream chunk — if no data for this many seconds, bifrost closes the connection DefaultMaxConnsPerHost = 5000 MaxConnsPerHostUpperBound = 10000 DefaultMaxIdleConnsPerHost = 40 @@ -26,7 +26,7 @@ const ( // Pre-defined errors for provider operations const ( - ErrProviderRequestTimedOut = "request timed out (default is 30 seconds). You can increase it by setting the default_request_timeout_in_seconds in the network_config or in UI - Providers > Provider Name > Network Config." + ErrProviderRequestTimedOut = "request timed out (default is 300 seconds). You can increase it by setting the default_request_timeout_in_seconds in the network_config or in UI - Providers > Provider Name > Network Config." 
ErrRequestCancelled = "request cancelled by caller" ErrRequestBodyConversion = "failed to convert bifrost request to the expected provider request body" ErrProviderRequestMarshal = "failed to marshal request body to JSON" diff --git a/core/schemas/serialization_test.go b/core/schemas/serialization_test.go index 378936cf4d..d744392305 100644 --- a/core/schemas/serialization_test.go +++ b/core/schemas/serialization_test.go @@ -1135,7 +1135,7 @@ func TestNetworkConfig_TLSFieldsRoundTrip(t *testing.T) { // round-trips correctly through JSON marshaling. func TestNetworkConfig_StreamIdleTimeoutRoundTrip(t *testing.T) { nc := NetworkConfig{ - DefaultRequestTimeoutInSeconds: 30, + DefaultRequestTimeoutInSeconds: 300, StreamIdleTimeoutInSeconds: 120, } diff --git a/docs/deployment-guides/config-json.mdx b/docs/deployment-guides/config-json.mdx index a46105fc8f..805efe513b 100644 --- a/docs/deployment-guides/config-json.mdx +++ b/docs/deployment-guides/config-json.mdx @@ -12,23 +12,26 @@ icon: "file-code" --- -## Two Configuration Modes +## Configuration Sources -Bifrost supports **two mutually exclusive modes**. You cannot run both at the same time. +Bifrost stores runtime configuration in a config database by default, so settings can be edited through the Web UI or API. You can also provide a `config.json` file to seed or reconcile that database at startup. To run with only `config.json`, set `config_store.enabled: false`. 
-| Mode | When | Behaviour | +| Setup | When | Behaviour | |------|------|-----------| -| **Web UI / database** | No `config.json`, or `config.json` with `config_store` enabled | Full UI available, configuration stored in SQLite or PostgreSQL | -| **File-based (`config.json`)** | `config.json` present, `config_store` disabled | UI disabled, all config loaded from file at startup, restart required for changes | +| **Web UI / database** | No `config.json` | Bifrost creates a default SQLite config store and runtime changes are saved through the UI or API | +| **DB-backed `config.json`** | `config.json` exists and `config_store` is omitted or enabled | File-backed sections seed or reconcile the config store at startup; UI/API edits remain available | +| **File-only `config.json`** | `config_store.enabled` is `false` | Config is loaded from file into memory at startup; config-backed UI/API changes are unavailable and file changes require restart | -See [Setting Up](/quickstart/gateway/setting-up#two-configuration-modes) for a full explanation of both modes and how `config_store` bootstrapping works. +By default, DB-backed `config.json` uses `source_of_truth: "split"`: unchanged file-backed rows preserve UI/API edits, while changed file-backed rows are applied on the next startup. Use `source_of_truth: "config.json"` only when explicitly present file sections should replace matching DB state. See [Source of Truth & Reconciliation](/deployment-guides/config-json/source-of-truth) for the full rules, including missing-vs-empty section behavior. --- ## Minimal Working Example +This example uses file-only configuration for the smallest self-contained setup. 
+ ```json { "$schema": "https://www.getbifrost.ai/schema", diff --git a/docs/deployment-guides/config-json/governance.mdx b/docs/deployment-guides/config-json/governance.mdx index e039e254d8..d6fddb0161 100644 --- a/docs/deployment-guides/config-json/governance.mdx +++ b/docs/deployment-guides/config-json/governance.mdx @@ -6,6 +6,10 @@ icon: "shield-check" The `governance` block lets you seed all governance resources directly in `config.json`. On startup, Bifrost loads these into the configuration store. This is the recommended approach for GitOps workflows where governance state is managed as code. + +In default split mode, file-backed governance resources seed or update the DB by hash while unrelated DB-only resources are preserved. With `source_of_truth: "config.json"`, only governance sub-sections that are explicitly present in the file are authoritative. Omit a sub-section to leave DB-managed rows alone; set it to an empty array only when you intend to remove stored rows for that sub-section. See [Source of Truth & Reconciliation](/deployment-guides/config-json/source-of-truth). + + **Governance enforcement is always active** in OSS - you do not need a plugin entry to enable it. To require a virtual key on every inference request, set `client.enforce_auth_on_inference: true`. This is the global default, but a more specific inference-auth flag such as `governance.auth_config.disable_auth_on_inference` overrides it; if no specific override is set, `client.enforce_auth_on_inference` applies. diff --git a/docs/deployment-guides/config-json/plugins.mdx b/docs/deployment-guides/config-json/plugins.mdx index 2a2612a893..148758a28f 100644 --- a/docs/deployment-guides/config-json/plugins.mdx +++ b/docs/deployment-guides/config-json/plugins.mdx @@ -10,6 +10,10 @@ icon: "puzzle-piece" **Telemetry, logging, and governance are auto-loaded built-ins** - they are always active and configured via the `client` block and dedicated top-level keys, not the `plugins` array. 
+ +In DB-backed deployments, plugin sync depends on the reconciliation mode. Split mode preserves DB plugins unless a file plugin has a higher `version` or changed placement/order. With `source_of_truth: "config.json"`, a present `plugins` array is authoritative; `plugins: []` removes stored opt-in plugins. See [Source of Truth & Reconciliation](/deployment-guides/config-json/source-of-truth). + + --- ## Auto-Loaded Built-ins diff --git a/docs/deployment-guides/config-json/providers.mdx b/docs/deployment-guides/config-json/providers.mdx index 33f5023b78..a52c974083 100644 --- a/docs/deployment-guides/config-json/providers.mdx +++ b/docs/deployment-guides/config-json/providers.mdx @@ -6,6 +6,10 @@ icon: "plug" All providers are configured under `providers` in `config.json`. Each provider entry contains a `keys` array where every key has a `name`, `value`, `models`, and `weight`, plus optional provider-specific config objects. + +In DB-backed deployments, provider entries from `config.json` are reconciled into the config store at startup. The default `source_of_truth: "split"` mode preserves UI/API edits while matching file-backed providers are unchanged. With `source_of_truth: "config.json"`, a present `providers` section is authoritative and prunes DB-only providers or keys. See [Source of Truth & Reconciliation](/deployment-guides/config-json/source-of-truth). + + **Supplying credentials:** Use the `env.` prefix to reference environment variables - never put API keys directly in `config.json`: diff --git a/docs/deployment-guides/config-json/schema-reference.mdx b/docs/deployment-guides/config-json/schema-reference.mdx index aae662e12f..b679705704 100644 --- a/docs/deployment-guides/config-json/schema-reference.mdx +++ b/docs/deployment-guides/config-json/schema-reference.mdx @@ -17,6 +17,8 @@ This page is a concise reference for every top-level key in `config.json`. 
Click | Key | Type | Description | Guide | |-----|------|-------------|-------| | `$schema` | string | Schema URL for IDE validation. Set to `"https://www.getbifrost.ai/schema"` | - | +| `version` | integer | Compatibility switch for empty allow-list arrays. Omit for current v2 semantics. | [`version`](#version) | +| `source_of_truth` | string | Startup reconciliation mode for DB-backed `config.json`: `"split"` or `"config.json"` | [Source of Truth](/deployment-guides/config-json/source-of-truth) | | `encryption_key` | string | Optional AES-256 key (derived via Argon2id). Accepts `env.VAR` prefix and is also read from `BIFROST_ENCRYPTION_KEY`. If omitted, data is stored in plaintext. | [Client](/deployment-guides/config-json/client#encryption-key) | | `client` | object | Worker pool, logging, CORS, auth enforcement, header filtering, MCP, compat shims | [Client](/deployment-guides/config-json/client) | | `providers` | object | LLM provider API keys, network settings, concurrency | [Providers](/deployment-guides/config-json/providers) | @@ -48,6 +50,28 @@ Omitting `version` uses v2 semantics. Set `"version": 1` only if you are migrati --- +## `source_of_truth` + +Controls how `config.json` is reconciled with the config store at startup. + +| Value | Behaviour | +|-------|-----------| +| `"split"` *(default)* | File-backed rows seed or update the config store by hash, while unchanged file-backed rows preserve UI/API edits | +| `"config.json"` | Explicitly present file sections are authoritative and replace matching DB state on startup | + +Missing and empty sections behave differently when `source_of_truth` is `"config.json"`. A missing section leaves DB rows untouched; a present empty section is authoritative and can prune matching DB rows. + +```json +{ + "source_of_truth": "config.json", + "plugins": [] +} +``` + +The example above makes the `plugins` section present and empty, so stored plugins are removed on startup. 
See [Source of Truth & Reconciliation](/deployment-guides/config-json/source-of-truth) for section-by-section behavior. + +--- + ## `client` Controls the worker pool, logging pipeline, security, and SDK shims. All fields are optional. diff --git a/docs/deployment-guides/config-json/source-of-truth.mdx b/docs/deployment-guides/config-json/source-of-truth.mdx new file mode 100644 index 0000000000..b1ce0a331f --- /dev/null +++ b/docs/deployment-guides/config-json/source-of-truth.mdx @@ -0,0 +1,117 @@ +--- +title: "Source of Truth & Reconciliation" +description: "How config.json, the config store, split mode, and authoritative file sync interact at startup" +icon: "shuffle" +--- + +Bifrost can use `config.json` in two different ways: + +- As the only runtime configuration source, with `config_store.enabled: false`. +- As a declarative bootstrap and reconciliation source for a SQLite or PostgreSQL config store. + +That distinction matters at startup. Bifrost uses `source_of_truth` and per-section reconciliation metadata, including `config_hash` for many config rows, to decide when file-backed configuration should preserve DB edits, update stored values, or prune DB-only rows. + + +`config_hash` is auto-managed. Do not set it manually in `config.json` or API payloads. 
+ + +--- + +## Configuration Setups + +| Setup | Config store | Web UI / API edits | Startup behavior | +|-------|--------------|--------------------|------------------| +| No `config.json` | Default SQLite `config.db` in app-dir | Enabled | Bifrost starts with defaults and stores runtime changes in SQLite | +| `config.json` with `config_store` omitted | Default SQLite `config.db` in app-dir | Enabled | File sections are reconciled into SQLite, then DB state is used at runtime | +| `config.json` with `config_store.enabled: true` | Explicit SQLite or PostgreSQL | Enabled | File sections are reconciled into the configured store, then DB state is used at runtime | +| `config.json` with `config_store.enabled: false` | Disabled | Unavailable for config-backed surfaces | File is loaded into memory at startup; changes require restart | + +`source_of_truth` only affects DB-backed reconciliation. In file-only mode there is no config store to reconcile against, so `config.json` is naturally the runtime source. + +--- + +## Default Split Mode + +The default mode is: + +```json +{ + "source_of_truth": "split" +} +``` + +You can also omit `source_of_truth`; `split` is the default. + +In split mode, Bifrost treats `config.json` as a bootstrap and drift-detection source: + +1. On first startup, file-backed sections from `config.json` are written to the config store. +2. Stored rows keep reconciliation metadata for the file-backed definition. +3. UI/API edits update the DB state without changing the file-backed definition. +4. On later startups, unchanged file-backed definitions preserve DB edits. +5. If the matching file-backed definition changes, the new file version is applied for that section or entity. + +Split mode does not prune DB-only entries just because they are missing from `config.json`. 
Removing a provider, plugin, MCP client, or governance row from the file leaves the stored row in place; use `source_of_truth: "config.json"` when a present file section should prune DB-only rows. + +Use split mode when you want `config.json` to seed or update a deployment while preserving UI/API edits unless the matching file-backed definition changes. + + +Split mode preserves runtime edits only while the matching file-backed section or entity is unchanged. If you edit that entity in `config.json`, the file version wins on the next startup. + + +--- + +## config.json as Source of Truth + +Use this only when the file should actively control the matching DB state: + +```json +{ + "source_of_truth": "config.json" +} +``` + +In this mode, explicitly present sections in `config.json` are authoritative at startup. Bifrost applies the file values even if the stored `config_hash` still matches: the hash only tracks the file-backed definition, so a matching hash does not show that the DB row is still free of UI/API edits. + +This is useful for stricter GitOps setups where the DB should converge back to the file after every restart or redeploy. + +--- + +## Missing vs Empty Sections + +In split mode, missing sections are left alone. Authoritative mode is stricter: a missing section is not the same as an empty section. + +This leaves stored plugins untouched: + +```json +{ + "source_of_truth": "config.json" +} +``` + +This makes the `plugins` section authoritative and empty, so DB-only plugins are removed: + +```json +{ + "source_of_truth": "config.json", + "plugins": [] +} +``` + +The same pattern applies to other supported sections, such as the top-level `providers` and `mcp` sections and the governance sub-sections. + + +Before using `source_of_truth: "config.json"` in production, check whether your file contains empty arrays or empty objects for sections you do not intend to prune. 
+ + +--- + +## Recommended Use + +| Goal | Recommended setup | +|------|-------------------| +| Interactive single-node gateway | Omit `config.json`, or use `config.json` with a config store and default `split` mode | +| Bootstrap from file, then allow UI/API edits | DB-backed config store, explicit or default, with `source_of_truth: "split"` | +| Strict GitOps over selected sections | DB-backed config store, explicit or default, plus `source_of_truth: "config.json"` and only the sections you intend to own | +| File-only OSS multinode deployment | `config_store.enabled: false` with a shared `config.json` | + +For DB-backed deployments, prefer `split` unless you explicitly want restarts to revert UI/API changes back to `config.json`. diff --git a/docs/deployment-guides/config-json/storage.mdx b/docs/deployment-guides/config-json/storage.mdx index cd1ca5b8f1..23356d9c34 100644 --- a/docs/deployment-guides/config-json/storage.mdx +++ b/docs/deployment-guides/config-json/storage.mdx @@ -21,7 +21,7 @@ If you use PostgreSQL for any store, the target database must be **UTF8 encoded* ## config_store -When `config_store` is disabled (or absent), all configuration is loaded from `config.json` at startup only - the Web UI is disabled and changes require a restart. See [Two Configuration Modes](/deployment-guides/config-json#two-configuration-modes). +When `config_store` is omitted, Bifrost creates a default SQLite config store in the app directory. Set `config_store.enabled` to `false` only when you want file-only configuration with no config-backed Web UI/API edits. See [Source of Truth & Reconciliation](/deployment-guides/config-json/source-of-truth). @@ -129,7 +129,7 @@ Use `password_command` for short-lived database credentials such as AWS RDS IAM ### Disabled (file-only mode) -Use this when you want Bifrost to read all configuration from `config.json` only - no database, no Web UI. 
+Use this when you want Bifrost to read all configuration from `config.json` only - no configuration database and no config-backed Web UI/API edits. ```json { diff --git a/docs/deployment-guides/helm/governance.mdx b/docs/deployment-guides/helm/governance.mdx index 930c702a48..21bd301157 100644 --- a/docs/deployment-guides/helm/governance.mdx +++ b/docs/deployment-guides/helm/governance.mdx @@ -406,7 +406,7 @@ bifrost: ``` -In the default split mode, runtime UI and API edits are preserved while the matching Helm-rendered section is unchanged. When Helm changes a section, tier boundaries are replaced from the rendered `config.json`, while keyword lists are merged additively with stored runtime keywords (union with duplicates removed). Use `bifrost.governance.sourceOfTruth: config.json` only when Helm should replace stored governance state. +In the default split mode, runtime UI and API edits are preserved while the matching Helm-rendered section is unchanged. When Helm changes a section, tier boundaries are replaced from the rendered `config.json`, while keyword lists are merged additively with stored runtime keywords (union with duplicates removed). Use `bifrost.sourceOfTruth: config.json` only when Helm should replace stored governance state. See [Source of Truth & Reconciliation](/deployment-guides/config-json/source-of-truth) for the full startup rules. 
--- diff --git a/docs/docs.json b/docs/docs.json index a9acad502e..20fad2fe94 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -528,6 +528,7 @@ "icon": "file-code", "pages": [ "deployment-guides/config-json", + "deployment-guides/config-json/source-of-truth", "deployment-guides/config-json/schema-reference", "deployment-guides/config-json/client", "deployment-guides/config-json/providers", diff --git a/docs/quickstart/gateway/setting-up.mdx b/docs/quickstart/gateway/setting-up.mdx index f71aa5ba59..fe6e56df37 100644 --- a/docs/quickstart/gateway/setting-up.mdx +++ b/docs/quickstart/gateway/setting-up.mdx @@ -127,11 +127,20 @@ curl -X POST http://localhost:8080/v1/chat/completions \ --- -## Two Configuration Modes +## Configuration Modes -Bifrost supports **two configuration approaches**: +Bifrost has two user-facing configuration modes: -**How config and database state interact:** +| Mode | When to use it | Behavior | +|-------|------|----------| +| Web UI / database | You want to manage configuration from the dashboard or management API | Bifrost stores configuration in SQLite or PostgreSQL. If `config.json` is also present, it seeds or reconciles the database at startup. | +| File-only `config.json` | You want declarative, restart-based configuration with no config database | Bifrost loads configuration from the file at startup. Config-backed UI/API edits are unavailable and file changes require restart. | + + +Omitting `config_store` does **not** mean file-only mode. If `config.json` exists and `config_store` is omitted, Bifrost creates the default SQLite config store and treats the file as DB-backed startup configuration. Set `config_store.enabled: false` only when the file should be the only configuration source. + + +**How DB-backed `config.json` interacts with database state:** - Entities defined in `config.json` are stored with a content-based hash on first load. 
@@ -141,21 +150,26 @@ Bifrost supports **two configuration approaches**: `config.json` does not overwrite those database changes, because the content hash is unchanged. If you edit an entity directly in `config.json`, its hash changes and the file-based definition overwrites the database copy. +- Set `source_of_truth: "config.json"` only when explicitly present file sections should replace + matching DB state on startup. Missing sections leave DB rows untouched; present empty sections can + prune matching DB rows. -### Mode 1: Web UI Configuration +For the full reconciliation rules, see [Source of Truth & Reconciliation](/deployment-guides/config-json/source-of-truth). + +### Mode 1: Web UI / Database Configuration ![Configuration via UI](../../media/ui-config.png) **When the UI is available:** - No `config.json` file exists (Bifrost auto-creates SQLite database) -- `config.json` exists with `config_store` configured +- `config.json` exists with `config_store` omitted or enabled -### Mode 2: File-based Configuration +### Optional: DB-backed config.json You can view entire config schema [here](https://www.getbifrost.ai/schema) -**When to use:** Advanced setups, GitOps workflows, or when UI is not needed +**When to use:** Bootstrap from file, then keep runtime UI/API configuration available. 
Create `config.json` in your app directory: @@ -187,21 +201,46 @@ Create `config.json` in your app directory: } ``` -**Without `config_store` in `config.json`:** - -- **UI is disabled** - no real-time configuration possible -- **Read-only mode** - `config.json` is never modified -- **Memory-only** - all configurations loaded into memory at startup -- **Restart required** - changes to `config.json` only apply after restart - -**With `config_store` in `config.json`:** +**With `config_store` omitted or enabled:** - **UI is enabled** - full real-time configuration via web interface - **Database check** - Bifrost reconciles `config.json` against the config store using a content-based hash per entity - **Empty DB**: Bootstraps database with `config.json` settings, then uses DB for runtime reads - - **Existing DB**: Compares each entity's file hash to the stored `ConfigHash`. If the hash is unchanged, the DB copy is kept (UI/API edits are preserved). If you edit an entity in `config.json` so its hash changes, the file definition overwrites the DB copy for that entity. Entries added only via the UI/API (not present in `config.json`) are always preserved. + - **Existing DB**: Compares each entity's file hash to the stored `ConfigHash`. In default split mode, unchanged file-backed entities keep the DB copy so UI/API edits are preserved. If you edit that entity in `config.json`, the file definition overwrites the DB copy for that entity. - **Persistent storage** - all changes saved to database immediately +### Mode 2: File-only config.json + +**When to use:** Advanced setups, GitOps workflows, or when UI is not needed. 
+ +Set `config_store.enabled` to `false` when `config.json` should be the only configuration source: + +```json +{ + "$schema": "https://www.getbifrost.ai/schema", + "providers": { + "openai": { + "keys": [ + { + "name": "openai-key-1", + "value": "env.OPENAI_API_KEY", + "models": ["gpt-4o-mini", "gpt-4o"], + "weight": 1.0 + } + ] + } + }, + "config_store": { + "enabled": false + } +} +``` + +- **Config UI/API edits are unavailable** - no config store is initialized +- **Read-only file** - Bifrost never modifies `config.json` +- **Memory-only config** - settings are loaded into memory at startup +- **Restart required** - changes to `config.json` apply after restart + **The Three Stores Explained:** - **Config Store**: Stores provider configs, API keys, MCP settings - Required for UI functionality diff --git a/framework/configstore/clientconfig.go b/framework/configstore/clientconfig.go index d2b5f8806f..be2466a698 100644 --- a/framework/configstore/clientconfig.go +++ b/framework/configstore/clientconfig.go @@ -98,9 +98,18 @@ type ClientConfig struct { HideDeletedVirtualKeysInFilters bool `json:"hide_deleted_virtual_keys_in_filters"` // Hide deleted virtual keys from logs/MCP filter data RoutingChainMaxDepth int `json:"routing_chain_max_depth"` // Maximum depth for routing rule chain evaluation (default: 10) MCPExternalClientURL *schemas.EnvVar `json:"mcp_external_client_url,omitempty"` // Public base URL used as redirect_uri when Bifrost acts as an OAuth client to upstream MCP servers. Supports env var syntax ("env.MY_VAR") + MCPServerAuthMode tables.MCPServerAuthMode `json:"mcp_server_auth_mode,omitempty"` // How /mcp authenticates inbound clients: headers (default), both, or oauth. + OAuth2ServerConfig *tables.OAuth2ServerConfig `json:"oauth2_server_config,omitempty"` // OAuth2 AS-specific settings (IssuerURL, token TTLs). Only relevant when MCPServerAuthMode is both or oauth. 
ConfigHash string `json:"-"` // Config hash for reconciliation (not serialized) } +// IsMCPOAuthDiscoveryEnabled reports whether the well-known OAuth discovery +// endpoints and JWKS endpoint should be live. True when MCPServerAuthMode is +// both or oauth. +func (c *ClientConfig) IsMCPOAuthDiscoveryEnabled() bool { + return c.MCPServerAuthMode == tables.MCPServerAuthModeBoth || c.MCPServerAuthMode == tables.MCPServerAuthModeOAuth +} + // UnmarshalJSON defaults all bool fields to true when absent from JSON. func (c *ClientConfig) UnmarshalJSON(data []byte) error { type ClientConfigAlias ClientConfig @@ -366,6 +375,26 @@ func (c *ClientConfig) GenerateClientConfigHash() (string, error) { } } + // Only hash non-default values to avoid legacy config hash churn on upgrade — + // existing configs carry an empty auth mode and a nil OAuth2 server config. + if c.MCPServerAuthMode != "" { + hash.Write([]byte("mcpServerAuthMode:" + string(c.MCPServerAuthMode))) + } + // Hash OAuth2ServerConfig field-by-field (not via Marshal) for a stable, + // deterministic byte stream that does not depend on serializer field order. 
+ if c.OAuth2ServerConfig != nil { + oc := c.OAuth2ServerConfig + if oc.IssuerURL.IsSet() { + if oc.IssuerURL.IsFromEnv() { + hash.Write([]byte("oauth2IssuerURL:env:" + oc.IssuerURL.EnvVar)) + } else { + hash.Write([]byte("oauth2IssuerURL:val:" + oc.IssuerURL.GetValue())) + } + } + hash.Write([]byte("oauth2AuthCodeTTL:" + strconv.Itoa(oc.AuthCodeTTL))) + hash.Write([]byte("oauth2AccessTokenTTL:" + strconv.Itoa(oc.AccessTokenTTL))) + } + return hex.EncodeToString(hash.Sum(nil)), nil } diff --git a/framework/configstore/migrations.go b/framework/configstore/migrations.go index 6b2d287d61..db7470eea7 100644 --- a/framework/configstore/migrations.go +++ b/framework/configstore/migrations.go @@ -416,6 +416,7 @@ var configstoreMigrationSteps = []migrationStep{ {IDs: []string{"add_customer_name_unique_constraint_dedup", "add_customer_name_unique_constraint_index"}, run: migrationAddCustomerNameUniqueConstraint}, {IDs: []string{"null_legacy_customer_budget_id_refs"}, run: migrationNullLegacyCustomerBudgetID}, {IDs: []string{"add_skills_repo_tables"}, run: migrationAddSkillsRepoTables}, + {IDs: []string{"add_oauth2_server_tables"}, run: migrationAddOAuth2ServerTables}, } // quoteSQLiteIdentifier quotes a SQLite identifier, escaping any double quotes. 
@@ -10896,3 +10897,48 @@ func migrationAddCustomerNameUniqueConstraint(ctx context.Context, db *gorm.DB, }, }) } + +func migrationAddOAuth2ServerTables(ctx context.Context, db *gorm.DB, logger schemas.Logger) error { + migrationName := "add_oauth2_server_tables" + logger.Info("[configstore] starting migration %s", migrationName) + defer logger.Info("[configstore] finished migration %s", migrationName) + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{ + { + ID: migrationName, + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mg := tx.Migrator() + if !mg.HasColumn(&tables.TableClientConfig{}, "mcp_server_auth_mode") { + if err := mg.AddColumn(&tables.TableClientConfig{}, "MCPServerAuthMode"); err != nil { + return fmt.Errorf("add mcp_server_auth_mode column: %w", err) + } + } + if !mg.HasColumn(&tables.TableClientConfig{}, "oauth2_server_config_json") { + if err := mg.AddColumn(&tables.TableClientConfig{}, "OAuth2ServerConfigJSON"); err != nil { + return fmt.Errorf("add oauth2_server_config_json column: %w", err) + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mg := tx.Migrator() + if mg.HasColumn(&tables.TableClientConfig{}, "oauth2_server_config_json") { + if err := mg.DropColumn(&tables.TableClientConfig{}, "OAuth2ServerConfigJSON"); err != nil { + return fmt.Errorf("drop oauth2_server_config_json column: %w", err) + } + } + if mg.HasColumn(&tables.TableClientConfig{}, "mcp_server_auth_mode") { + if err := mg.DropColumn(&tables.TableClientConfig{}, "MCPServerAuthMode"); err != nil { + return fmt.Errorf("drop mcp_server_auth_mode column: %w", err) + } + } + return nil + }, + }, + }) + if err := m.Migrate(); err != nil { + return fmt.Errorf("error while running db migration %s: %w", migrationName, err) + } + return nil +} diff --git a/framework/configstore/rdb.go b/framework/configstore/rdb.go index ce0a3d02d4..03c2e1215b 100644 --- a/framework/configstore/rdb.go +++ 
b/framework/configstore/rdb.go @@ -2,7 +2,11 @@ package configstore import ( "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" "encoding/json" + "encoding/pem" "errors" "fmt" "sort" @@ -266,6 +270,8 @@ func (s *RDBConfigStore) UpdateClientConfig(ctx context.Context, config *ClientC AllowPerRequestContentStorageOverride: config.AllowPerRequestContentStorageOverride, AllowPerRequestRawOverride: config.AllowPerRequestRawOverride, AllowDirectKeys: config.AllowDirectKeys, + MCPServerAuthMode: config.MCPServerAuthMode, + OAuth2ServerConfig: config.OAuth2ServerConfig, ConfigHash: config.ConfigHash, } // Delete existing client config and create new one in a transaction. @@ -530,6 +536,8 @@ func (s *RDBConfigStore) GetClientConfig(ctx context.Context) (*ClientConfig, er AllowPerRequestContentStorageOverride: dbConfig.AllowPerRequestContentStorageOverride, AllowPerRequestRawOverride: dbConfig.AllowPerRequestRawOverride, AllowDirectKeys: dbConfig.AllowDirectKeys, + MCPServerAuthMode: dbConfig.MCPServerAuthMode, + OAuth2ServerConfig: dbConfig.OAuth2ServerConfig, ConfigHash: dbConfig.ConfigHash, }, nil } @@ -6673,3 +6681,101 @@ func (s *RDBConfigStore) ReconcileMCPHeadersAfterMCPChange(ctx context.Context, return nil }) } + +// GetOAuth2SigningKey returns the signing key, creating and persisting a new +// 2048-bit RSA keypair if none exists yet. +func (s *RDBConfigStore) GetOAuth2SigningKey(ctx context.Context) (*tables.OAuth2SigningKey, error) { + key, err := s.loadOAuth2SigningKey(ctx) + if err != nil { + if errors.Is(err, ErrNotFound) { + // No key persisted yet — generate and store one atomically. + return s.createOAuth2SigningKey(ctx) + } + return nil, err + } + return key, nil +} + +// loadOAuth2SigningKey reads and decrypts the persisted signing key. It returns +// ErrNotFound when no key has been generated yet. 
+func (s *RDBConfigStore) loadOAuth2SigningKey(ctx context.Context) (*tables.OAuth2SigningKey, error) { + row, err := s.GetConfig(ctx, tables.GovernanceConfigKeyOAuth2SigningKey) + if err != nil { + if errors.Is(err, ErrNotFound) { + return nil, ErrNotFound + } + return nil, fmt.Errorf("get oauth2 signing key: %w", err) + } + if row == nil || row.Value == "" { + return nil, ErrNotFound + } + var key tables.OAuth2SigningKey + if err := json.Unmarshal([]byte(row.Value), &key); err != nil { + return nil, fmt.Errorf("unmarshal oauth2 signing key: %w", err) + } + // Decrypt off the stored marker, not the live encrypt.IsEnabled() flag, so a + // key persisted while encryption was disabled is not mangled once encryption + // is later turned on (mirrors the AfterFind hooks on secret-bearing tables). + if err := key.Decrypt(); err != nil { + return nil, err + } + return &key, nil +} + +func (s *RDBConfigStore) createOAuth2SigningKey(ctx context.Context) (*tables.OAuth2SigningKey, error) { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, fmt.Errorf("generate RSA key: %w", err) + } + + privBytes, err := x509.MarshalPKCS8PrivateKey(priv) + if err != nil { + return nil, fmt.Errorf("marshal RSA private key: %w", err) + } + privPEM := string(pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privBytes})) + + pubBytes, err := x509.MarshalPKIXPublicKey(&priv.PublicKey) + if err != nil { + return nil, fmt.Errorf("marshal RSA public key: %w", err) + } + pubPEM := string(pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubBytes})) + + key := &tables.OAuth2SigningKey{ + KID: uuid.New().String(), + PrivateKeyPEM: privPEM, + PublicKeyPEM: pubPEM, + } + + // Encrypt the private key in place and stamp EncryptionStatus before storage + // (mirrors the BeforeSave hooks on secret-bearing tables). 
+ if err := key.Encrypt(); err != nil { + return nil, err + } + + data, err := json.Marshal(key) + if err != nil { + return nil, fmt.Errorf("marshal oauth2 signing key: %w", err) + } + + // Persist atomically with INSERT ... ON CONFLICT DO NOTHING so concurrent + // first-use callers cannot last-writer-wins different keypairs. If the row + // already exists (RowsAffected == 0), another caller won the race — reload + // and return the persisted key so every caller agrees on a single keypair. + res := s.DB().WithContext(ctx).Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "key"}}, + DoNothing: true, + }).Create(&tables.TableGovernanceConfig{ + Key: tables.GovernanceConfigKeyOAuth2SigningKey, + Value: string(data), + }) + if res.Error != nil { + return nil, fmt.Errorf("persist oauth2 signing key: %w", res.Error) + } + if res.RowsAffected == 0 { + return s.loadOAuth2SigningKey(ctx) + } + + // We won the insert — return with plaintext private key for immediate use. + key.PrivateKeyPEM = privPEM + return key, nil +} diff --git a/framework/configstore/store.go b/framework/configstore/store.go index fbb9f80680..dc1c8310da 100644 --- a/framework/configstore/store.go +++ b/framework/configstore/store.go @@ -654,6 +654,10 @@ type ConfigStore interface { // pool, whose connections carry no cached plans. SQLite is a no-op. RefreshConnectionPool(ctx context.Context) error + // GetOAuth2SigningKey returns the signing key, creating and persisting one + // on first call. Always returns a usable key — never nil on a nil error. 
+ GetOAuth2SigningKey(ctx context.Context) (*tables.OAuth2SigningKey, error) + // Cleanup Close(ctx context.Context) error } diff --git a/framework/configstore/tables/clientconfig.go b/framework/configstore/tables/clientconfig.go index ee1a0d395e..26acbb5272 100644 --- a/framework/configstore/tables/clientconfig.go +++ b/framework/configstore/tables/clientconfig.go @@ -48,6 +48,16 @@ type TableClientConfig struct { CompatShouldDropParams bool `gorm:"column:compat_should_drop_params;default:false" json:"-"` CompatShouldConvertParams bool `gorm:"column:compat_should_convert_params;default:false" json:"-"` + // MCPServerAuthMode controls how /mcp authenticates inbound clients. + // Stored as a plain varchar column so it can be read without JSON parsing. + MCPServerAuthMode MCPServerAuthMode `gorm:"column:mcp_server_auth_mode;type:varchar(20);not null;default:'headers'" json:"mcp_server_auth_mode"` + // OAuth2ServerConfigJSON holds the OAuth2 AS-specific settings (IssuerURL, + // AuthCodeTTL, AccessTokenTTL) as a JSON blob. Only relevant when + // MCPServerAuthMode is both or oauth. Deserialized into OAuth2ServerConfig + // by AfterFind. The explicit column name avoids GORM deriving the leading + // acronym as "o_auth2_..." from the field name. 
+ OAuth2ServerConfigJSON string `gorm:"column:oauth2_server_config_json;type:text" json:"-"` + // Config hash is used to detect the changes synced from config.json file // Every time we sync the config.json file, we will update the config hash ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"` @@ -56,14 +66,15 @@ type TableClientConfig struct { UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` // Virtual fields for runtime use (not stored in DB) - PrometheusLabels []string `gorm:"-" json:"prometheus_labels"` - AllowedOrigins []string `gorm:"-" json:"allowed_origins,omitempty"` - AllowedHeaders []string `gorm:"-" json:"allowed_headers,omitempty"` - RequiredHeaders []string `gorm:"-" json:"required_headers,omitempty"` - LoggingHeaders []string `gorm:"-" json:"logging_headers,omitempty"` - WhitelistedRoutes []string `gorm:"-" json:"whitelisted_routes,omitempty"` - HeaderFilterConfig *GlobalHeaderFilterConfig `gorm:"-" json:"header_filter_config,omitempty"` - Metadata map[string]any `gorm:"-" json:"metadata,omitempty"` + PrometheusLabels []string `gorm:"-" json:"prometheus_labels"` + AllowedOrigins []string `gorm:"-" json:"allowed_origins,omitempty"` + AllowedHeaders []string `gorm:"-" json:"allowed_headers,omitempty"` + RequiredHeaders []string `gorm:"-" json:"required_headers,omitempty"` + LoggingHeaders []string `gorm:"-" json:"logging_headers,omitempty"` + WhitelistedRoutes []string `gorm:"-" json:"whitelisted_routes,omitempty"` + HeaderFilterConfig *GlobalHeaderFilterConfig `gorm:"-" json:"header_filter_config,omitempty"` + Metadata map[string]any `gorm:"-" json:"metadata,omitempty"` + OAuth2ServerConfig *OAuth2ServerConfig `gorm:"-" json:"oauth2_server_config,omitempty"` } // TableName sets the table name for each model @@ -152,6 +163,16 @@ func (cc *TableClientConfig) BeforeSave(tx *gorm.DB) error { cc.MetadataJSON = string(data) } + if cc.OAuth2ServerConfig != nil { + data, err := json.Marshal(cc.OAuth2ServerConfig) + if err != nil 
{ + return err + } + cc.OAuth2ServerConfigJSON = string(data) + } else { + cc.OAuth2ServerConfigJSON = "" + } + return nil } @@ -211,5 +232,15 @@ func (cc *TableClientConfig) AfterFind(tx *gorm.DB) error { cc.Metadata = nil } + if cc.OAuth2ServerConfigJSON != "" { + var authCfg OAuth2ServerConfig + if err := json.Unmarshal([]byte(cc.OAuth2ServerConfigJSON), &authCfg); err != nil { + return err + } + cc.OAuth2ServerConfig = &authCfg + } else { + cc.OAuth2ServerConfig = nil + } + return nil } diff --git a/framework/configstore/tables/oauth2_server.go b/framework/configstore/tables/oauth2_server.go new file mode 100644 index 0000000000..5a9d713c41 --- /dev/null +++ b/framework/configstore/tables/oauth2_server.go @@ -0,0 +1,129 @@ +package tables + +import ( + "fmt" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/encrypt" +) + +// MCPServerAuthMode controls how Bifrost's /mcp endpoint authenticates inbound +// MCP clients. It does not affect how Bifrost authenticates to upstream MCP +// servers (governed by MCPClientConfig.AuthType). +type MCPServerAuthMode string + +const ( + DefaultAuthCodeTTL = 600 // 10 minutes + DefaultAccessTokenTTL = 600 // 10 minutes +) + +const ( + // MCPServerAuthModeHeaders accepts header credentials only: x-bf-vk, + // Authorization: Bearer {token}, x-api-key, x-bf-mcp-session-id. + // Discovery endpoints return 404. Default — today's behavior. + MCPServerAuthModeHeaders MCPServerAuthMode = "headers" + + // MCPServerAuthModeBoth accepts both header credentials and Bifrost-issued + // JWTs. Discovery endpoints are live; existing header-credential clients + // that never receive a 401 are unaffected. + MCPServerAuthModeBoth MCPServerAuthMode = "both" + + // MCPServerAuthModeOAuth accepts Bifrost-issued JWTs only. Header + // credentials (VK / api-key / session) are rejected on /mcp. + // WARNING: existing virtual-key MCP integrations will stop working. 
+ MCPServerAuthModeOAuth MCPServerAuthMode = "oauth" +) + +// OAuth2ServerConfig holds the OAuth2 authorization-server settings serialized +// as JSON into config_client.oauth2_server_config_json. Only meaningful when +// MCPServerAuthMode is MCPServerAuthModeBoth or MCPServerAuthModeOAuth. +// Not a table of its own. +type OAuth2ServerConfig struct { + // IssuerURL is Bifrost's OAuth authorization-server identity — it appears + // as the `issuer` in discovery documents and as the `iss` claim in every + // issued JWT. Supports env var syntax ("env.MY_VAR"). When empty, + // BuildBaseURL(request) is used as a per-request fallback, which works for + // single-host / dev deployments. Multi-host or reverse-proxy deployments + // MUST set a stable value; token verification fails when the Host header + // differs across nodes. + IssuerURL *schemas.EnvVar `json:"issuer_url,omitempty"` + + // AuthCodeTTL is the lifetime of the one-time authorization code issued by + // /oauth2/authorize and exchanged at /oauth2/token (seconds, default 600). + // The code is single-use — it is invalidated the moment it is exchanged or + // expires. Short TTL is intentional: if the window lapses the user simply + // re-authenticates. + AuthCodeTTL int `json:"auth_code_ttl"` + + // AccessTokenTTL is the lifetime of the issued JWT Bearer token (seconds, + // default 600 = 10 min). When the token expires the client uses its refresh + // token to silently obtain a new one without any user interaction. 
+ AccessTokenTTL int `json:"access_token_ttl"` + + // Refresh tokens have no hard expiry — they are invalidated only by: + // - rotation on use (each /oauth2/token refresh call issues a new token + // and immediately invalidates the previous one) + // - bf_sub liveness check on refresh (VK deleted / user deactivated → + // invalid_grant, forcing re-authentication) + // - explicit revocation via the Connected Clients UI + // - EnforceAuthOnInference toggled on (revokes all session-mode grants) + // No RefreshTokenTTL field exists by design — there is no timer, only + // explicit invalidation paths. +} + +// DefaultOAuth2ServerConfig returns sensible defaults for the AS-specific settings. +func DefaultOAuth2ServerConfig() *OAuth2ServerConfig { + return &OAuth2ServerConfig{ + AuthCodeTTL: DefaultAuthCodeTTL, + AccessTokenTTL: DefaultAccessTokenTTL, + } +} + +// OAuth2SigningKey holds the single RS256 keypair used to sign Bifrost-issued +// JWTs. Stored as JSON in governance_config under GovernanceConfigKeyOAuth2SigningKey. +// The private key PEM is encrypted via framework/encrypt before storage. +type OAuth2SigningKey struct { + KID string `json:"kid"` // key ID embedded in JWT headers + PrivateKeyPEM string `json:"private_key_pem"` // encrypted at rest via framework/encrypt when EncryptionStatus is "encrypted" + PublicKeyPEM string `json:"public_key_pem"` // plaintext; public key is not sensitive + EncryptionStatus string `json:"encryption_status,omitempty"` // EncryptionStatusPlainText or EncryptionStatusEncrypted — records whether PrivateKeyPEM was encrypted at write time so reads do not depend on the current encrypt.IsEnabled() state +} + +// Encrypt encrypts PrivateKeyPEM in place and stamps EncryptionStatus, mirroring +// the BeforeSave hooks on secret-bearing tables. Because the signing key is +// persisted as a JSON blob inside governance_config (not its own GORM row) it +// cannot rely on GORM hooks, so callers invoke this before marshaling/storing. 
+func (k *OAuth2SigningKey) Encrypt() error { + if encrypt.IsEnabled() && k.PrivateKeyPEM != "" { + if err := encryptString(&k.PrivateKeyPEM); err != nil { + return fmt.Errorf("failed to encrypt oauth2 signing key: %w", err) + } + k.EncryptionStatus = EncryptionStatusEncrypted + } else { + k.EncryptionStatus = EncryptionStatusPlainText + } + return nil +} + +// Decrypt decrypts PrivateKeyPEM in place based on the stored EncryptionStatus +// marker, mirroring the AfterFind hooks on secret-bearing tables. Keys written +// before encryption was enabled are marked plain_text and returned as-is. +// Records written before this marker existed carry an empty status and fall back +// to the historical encrypt.IsEnabled() behavior. +func (k *OAuth2SigningKey) Decrypt() error { + if k.PrivateKeyPEM == "" { + return nil + } + shouldDecrypt := k.EncryptionStatus == EncryptionStatusEncrypted || + (k.EncryptionStatus == "" && encrypt.IsEnabled()) + if shouldDecrypt { + if err := decryptString(&k.PrivateKeyPEM); err != nil { + return fmt.Errorf("failed to decrypt oauth2 signing key: %w", err) + } + } + return nil +} + +// GovernanceConfigKeyOAuth2SigningKey is the governance_config key under which +// the OAuth2 signing keypair is stored. +const GovernanceConfigKeyOAuth2SigningKey = "oauth2_signing_key" diff --git a/framework/logstore/migrations.go b/framework/logstore/migrations.go index c46e94094c..74e9874331 100644 --- a/framework/logstore/migrations.go +++ b/framework/logstore/migrations.go @@ -2939,12 +2939,9 @@ func migrationAddAliasColumn(ctx context.Context, db *gorm.DB, logger schemas.Lo // alias config when the request's model was resolved via alias mapping and the // alias defines them. 
func migrationAddCanonicalModelColumns(ctx context.Context, db *gorm.DB, logger schemas.Logger) error { - migrationName := "logs_recreate_filter_customers_matview_multivalue" + migrationName := "logs_add_canonical_model_columns" logger.Info("[logstore] starting migration %s", migrationName) defer logger.Info("[logstore] finished migration %s", migrationName) - if db.Dialector.Name() != "postgres" { - return nil - } opts := *migrator.DefaultOptions opts.UseTransaction = true m := migrator.New(db, &opts, []*migrator.Migration{{ diff --git a/framework/temptoken/scope.go b/framework/temptoken/scope.go index e7cb789891..bb9660f4d7 100644 --- a/framework/temptoken/scope.go +++ b/framework/temptoken/scope.go @@ -23,6 +23,14 @@ const ( // flow endpoints. Bound resource_id is the headers flow ID. Parallel // of MCPAuthScopeName for the per-user-headers surface. MCPHeadersAuthScopeName = "mcp_headers_auth" + + // OAuth2ConsentScopeName names the scope that authorizes the OAuth2 + // downstream consent page to call the consent flow endpoints. Bound + // resource_id is the authorize-request flow ID. The consent page is + // public (outside /workspace) so it cannot rely on dashboard auth; + // this token is the sole credential binding the browser session to the + // pending authorization request. + OAuth2ConsentScopeName = "oauth2_consent" ) // RoutePattern is one (method, path) pair a Scope grants access to. The path diff --git a/plugins/governance/resolver.go b/plugins/governance/resolver.go index e420034431..59cbc0aede 100644 --- a/plugins/governance/resolver.go +++ b/plugins/governance/resolver.go @@ -272,7 +272,7 @@ func (r *BudgetResolver) EvaluateVirtualKeyRequest(ctx *schemas.BifrostContext, } } // 2. 
Check provider filtering - if requestType != schemas.MCPToolExecutionRequest && !r.isProviderAllowed(vk, provider) { + if requestType != schemas.MCPToolExecutionRequest && requestType != schemas.ListModelsRequest && !r.isProviderAllowed(vk, provider) { return &EvaluationResult{ Decision: DecisionProviderBlocked, Reason: fmt.Sprintf("Provider '%s' is not allowed for this virtual key", provider), diff --git a/plugins/governance/resolver_test.go b/plugins/governance/resolver_test.go index 86fb7a4041..9e5fd94cad 100644 --- a/plugins/governance/resolver_test.go +++ b/plugins/governance/resolver_test.go @@ -92,6 +92,34 @@ func TestBudgetResolver_EvaluateRequest_ProviderBlocked(t *testing.T) { assertVirtualKeyFound(t, result) } +// TestBudgetResolver_EvaluateRequest_ListModelsBypassesProviderBlock verifies that a VK +// with a restricted provider allowlist does NOT block a ListModelsRequest for a provider +// outside its allowlist. List-models fans out across all configured providers and is +// filtered per-VK in the PostHook, so provider gating must be skipped here (resolver.go:275). +func TestBudgetResolver_EvaluateRequest_ListModelsBypassesProviderBlock(t *testing.T) { + logger := NewMockLogger() + + // Same Anthropic-only VK as TestBudgetResolver_EvaluateRequest_ProviderBlocked. + providerConfigs := []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("anthropic", []string{"claude-3-sonnet"}), + } + vk := buildVirtualKeyWithProviders("vk1", "sk-bf-test", "Test VK", providerConfigs) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + }, nil) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, nil, logger, nil) + ctx := &schemas.BifrostContext{} + + // OpenAI is not in the allowlist, but ListModelsRequest must not be provider-blocked. 
+ result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ListModelsRequest, false) + + assert.NotEqual(t, DecisionProviderBlocked, result.Decision, "ListModelsRequest must bypass provider allowlist gating") + assertDecision(t, DecisionAllow, result) +} + // TestBudgetResolver_EvaluateRequest_ModelBlocked tests model filtering func TestBudgetResolver_EvaluateRequest_ModelBlocked(t *testing.T) { logger := NewMockLogger() diff --git a/scripts/litellm-to-bifrost/bifrost.go b/scripts/litellm-to-bifrost/bifrost.go new file mode 100644 index 0000000000..438ba02fa6 --- /dev/null +++ b/scripts/litellm-to-bifrost/bifrost.go @@ -0,0 +1,529 @@ +package main + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" +) + +var httpClient = &http.Client{Timeout: 30 * time.Second} + +// Bifrost sink types: the subset of the management-API request bodies the +// migration writes. Kept as local DTOs so the migration depends only on the +// wire (JSON) contract, not on the transports package. + +// BifrostClient writes entities to the Bifrost management API. +type BifrostClient struct { + BaseURL string // e.g. http://localhost:8080 + APIKey string // optional bearer token +} + +// BifrostCreateCustomerRequest is the body for POST /api/governance/customers. +type BifrostCreateCustomerRequest struct { + Name string `json:"name"` + Budgets []BifrostCreateBudgetRequest `json:"budgets,omitempty"` + RateLimit *BifrostCreateRateLimitRequest `json:"rate_limit,omitempty"` +} + +// BifrostCreateTeamRequest is the body for POST /api/governance/teams. +// CustomerID links the team to a migrated customer (LiteLLM organization); it is +// omitted for a standalone team or one whose customer could not be resolved. 
+type BifrostCreateTeamRequest struct { + Name string `json:"name"` + CustomerID *string `json:"customer_id,omitempty"` + Budgets []BifrostCreateBudgetRequest `json:"budgets,omitempty"` + RateLimit *BifrostCreateRateLimitRequest `json:"rate_limit,omitempty"` +} + +// BifrostCreateUserRequest is the body for POST /api/users. role_id is +// intentionally omitted: LiteLLM's string user_role has no Bifrost numeric +// role_id mapping, and user governance is driven by access profiles. +type BifrostCreateUserRequest struct { + Name string `json:"name"` + Email string `json:"email"` +} + +// BifrostCreateVirtualKeyRequest is the body for +// POST /api/governance/virtual-keys. TeamID and CustomerID are mutually +// exclusive (VK ownership is team XOR customer). The key value is +// server-generated; the LiteLLM token is not carried. +type BifrostCreateVirtualKeyRequest struct { + Name string `json:"name"` + ProviderConfigs []BifrostVKProviderConfigRequest `json:"provider_configs,omitempty"` + TeamID *string `json:"team_id,omitempty"` + CustomerID *string `json:"customer_id,omitempty"` + Budgets []BifrostCreateBudgetRequest `json:"budgets,omitempty"` + RateLimit *BifrostCreateRateLimitRequest `json:"rate_limit,omitempty"` + IsActive *bool `json:"is_active,omitempty"` +} + +// BifrostCreateModelConfigRequest is the body for +// POST /api/governance/model-configs. Scope is global for LiteLLM model-level +// budgets and rate limits. +type BifrostCreateModelConfigRequest struct { + ModelName string `json:"model_name"` + Provider *string `json:"provider,omitempty"` + Scope string `json:"scope,omitempty"` + Budgets []BifrostCreateBudgetRequest `json:"budgets,omitempty"` + RateLimit *BifrostCreateRateLimitRequest `json:"rate_limit,omitempty"` +} + +// bifrostModelConfigScopeGlobal is the Bifrost scope for global model configs. +const bifrostModelConfigScopeGlobal = "global" + +// BifrostVKProviderConfigRequest selects which keys of a provider a VK may use. 
// key_ids ["*"] grants all keys; specific UUIDs limit to those keys; empty denies all.
// allowed_models ["*"] allows all models on the matched keys.
//
// NOTE(review): KeyIDs is tagged omitempty, so a nil or empty slice is dropped
// from the payload instead of being sent as []. Confirm the server treats an
// absent key_ids the same as an empty one (deny all).
type BifrostVKProviderConfigRequest struct {
	Provider      string   `json:"provider"`
	KeyIDs        []string `json:"key_ids,omitempty"`
	AllowedModels []string `json:"allowed_models,omitempty"`
}

// BifrostCreateBudgetRequest is a single Bifrost budget.
type BifrostCreateBudgetRequest struct {
	// MaxLimit is always serialized, including a zero value.
	MaxLimit      float64 `json:"max_limit"`
	ResetDuration string  `json:"reset_duration"` // e.g. "30s", "5m", "1h", "1d", "1w", "1M", "1Y"
}

// BifrostCreateRateLimitRequest is a Bifrost rate limit. Token and request
// dimensions are independent; each is omitted when LiteLLM has no positive
// limit for it.
type BifrostCreateRateLimitRequest struct {
	// Every field is a pointer with omitempty: nil means "not set" and the
	// field is left out of the JSON body.
	TokenMaxLimit        *int64  `json:"token_max_limit,omitempty"`
	TokenResetDuration   *string `json:"token_reset_duration,omitempty"`
	RequestMaxLimit      *int64  `json:"request_max_limit,omitempty"`
	RequestResetDuration *string `json:"request_reset_duration,omitempty"`
}

// bifrostDefaultConcurrency / bifrostDefaultBufferSize mirror
// schemas.DefaultConcurrency / DefaultBufferSize. The provider PUT/POST require
// both to be > 0.
const (
	bifrostDefaultConcurrency = 1000
	bifrostDefaultBufferSize  = 5000
)

// BifrostConcurrencyAndBufferSize mirrors schemas.ConcurrencyAndBufferSize.
type BifrostConcurrencyAndBufferSize struct {
	// Both fields are always serialized (no omitempty).
	Concurrency int `json:"concurrency"`
	BufferSize  int `json:"buffer_size"`
}

// BifrostNetworkConfig is the subset of schemas.NetworkConfig the migration
// sets. BaseURL is provider-level in Bifrost (LiteLLM api_base is
// per-deployment).
type BifrostNetworkConfig struct {
	// BaseURL is left out of the JSON body when empty.
	BaseURL string `json:"base_url,omitempty"`
}

// BifrostCustomProviderConfig mirrors schemas.CustomProviderConfig for a
// synthesized provider that wraps a base provider at a distinct base URL.
type BifrostCustomProviderConfig struct {
	// IsKeyLess is always serialized; the zero value sends "is_key_less": false.
	IsKeyLess bool `json:"is_key_less"`
	// BaseProviderType names the provider being wrapped.
	BaseProviderType string `json:"base_provider_type"`
}

// BifrostProviderUpdatePayload is the body for PUT /api/providers/{provider}
// for standard providers; keys are managed separately via the /keys endpoint.
type BifrostProviderUpdatePayload struct {
	// NetworkConfig is left out of the JSON body when nil.
	NetworkConfig *BifrostNetworkConfig `json:"network_config,omitempty"`
	// ConcurrencyAndBufferSize is a value (not a pointer) and is always sent.
	ConcurrencyAndBufferSize BifrostConcurrencyAndBufferSize `json:"concurrency_and_buffer_size"`
}

// BifrostProviderCreatePayload is the body for POST /api/providers.
type BifrostProviderCreatePayload struct {
	// Provider is the provider name being created.
	Provider string `json:"provider"`
	// CustomProviderConfig and NetworkConfig are left out of the JSON body
	// when nil.
	CustomProviderConfig *BifrostCustomProviderConfig `json:"custom_provider_config,omitempty"`
	NetworkConfig        *BifrostNetworkConfig        `json:"network_config,omitempty"`
	// ConcurrencyAndBufferSize is a value (not a pointer) and is always sent.
	ConcurrencyAndBufferSize BifrostConcurrencyAndBufferSize `json:"concurrency_and_buffer_size"`
}

// BifrostProviderKey is the body for POST /api/providers/{provider}/keys. It
// decodes into schemas.Key; Value is a plain string ("env.FOO" => from
// environment, else a literal value). The *_key_config fields carry provider-
// specific credentials (Azure endpoint, AWS credentials, GCP credentials) or
// the per-key server URL for the keyless self-hosted providers (vllm/ollama).
type BifrostProviderKey struct {
	Name  string `json:"name"`
	Value string `json:"value"`
	// Models has no omitempty: a nil slice is serialized as JSON null and an
	// empty non-nil slice as [].
	Models []string `json:"models"`
	// Weight is always serialized, including a zero value.
	Weight float64 `json:"weight"`
	// The provider-specific configs below are each left out of the JSON body
	// when nil.
	VLLMKeyConfig    *BifrostVLLMKeyConfig    `json:"vllm_key_config,omitempty"`
	OllamaKeyConfig  *BifrostKeyURLConfig     `json:"ollama_key_config,omitempty"`
	AzureKeyConfig   *BifrostAzureKeyConfig   `json:"azure_key_config,omitempty"`
	BedrockKeyConfig *BifrostBedrockKeyConfig `json:"bedrock_key_config,omitempty"`
	VertexKeyConfig  *BifrostVertexKeyConfig  `json:"vertex_key_config,omitempty"`
}

// BifrostVLLMKeyConfig mirrors schemas.VLLMKeyConfig: a per-key vLLM server URL
// plus the exact served model used to select the key.
type BifrostVLLMKeyConfig struct {
	URL string `json:"url"`
	// ModelName is left out of the JSON body when empty.
	ModelName string `json:"model_name,omitempty"`
}

// BifrostKeyURLConfig mirrors schemas.OllamaKeyConfig: a per-key server URL with
// no model selector.
type BifrostKeyURLConfig struct {
	URL string `json:"url"`
}

// BifrostAzureKeyConfig mirrors schemas.AzureKeyConfig. Endpoint is the Azure
// OpenAI service URL (e.g. https://myazure.openai.azure.com/). For Entra ID
// (service principal) auth, set ClientID + ClientSecret + TenantID; for API-key
// auth, leave those nil and set the key Value instead.
type BifrostAzureKeyConfig struct {
	// Endpoint is always serialized (no omitempty).
	Endpoint string `json:"endpoint"`
	// The three Entra ID fields are each left out of the JSON body when nil.
	ClientID     *string `json:"client_id,omitempty"`
	ClientSecret *string `json:"client_secret,omitempty"`
	TenantID     *string `json:"tenant_id,omitempty"`
}

// BifrostBedrockKeyConfig mirrors schemas.BedrockKeyConfig. For static IAM
// credentials set AccessKey + SecretKey (+ optional SessionToken). For role
// assumption set RoleARN (+ optional RoleSessionName). For EC2 instance-profile
// or ECS task-role auth leave AccessKey and SecretKey empty; Bifrost uses the
// SDK default credential chain. Region is optional but recommended.
+type BifrostBedrockKeyConfig struct { + AccessKey string `json:"access_key,omitempty"` + SecretKey string `json:"secret_key,omitempty"` + SessionToken *string `json:"session_token,omitempty"` + Region *string `json:"region,omitempty"` + RoleARN *string `json:"role_arn,omitempty"` + RoleSessionName *string `json:"session_name,omitempty"` +} + +// BifrostVertexKeyConfig mirrors schemas.VertexKeyConfig. ProjectID and Region +// are required. AuthCredentials is the service-account JSON (or a path/env ref +// to it); leave empty to use Application Default Credentials (ADC). +type BifrostVertexKeyConfig struct { + ProjectID string `json:"project_id"` + ProjectNumber string `json:"project_number,omitempty"` + Region string `json:"region"` + AuthCredentials string `json:"auth_credentials"` +} + +// CreateCustomer posts a single customer to POST /api/governance/customers. +func (c *BifrostClient) CreateCustomer(ctx context.Context, in *BifrostCreateCustomerRequest) error { + return c.sendJSON(ctx, http.MethodPost, "/api/governance/customers", in, "customer "+in.Name) +} + +// CreateTeam posts a single team to POST /api/governance/teams. +func (c *BifrostClient) CreateTeam(ctx context.Context, in *BifrostCreateTeamRequest) error { + return c.sendJSON(ctx, http.MethodPost, "/api/governance/teams", in, "team "+in.Name) +} + +// FindCustomerByName resolves a Bifrost customer id by exact name via +// GET /api/governance/customers?search=. It returns ok=false (no error) when no +// customer matches, so the caller can create the team unlinked and warn. 
+func (c *BifrostClient) FindCustomerByName(ctx context.Context, name string) (id string, ok bool, err error) { + body, err := c.getJSON(ctx, "/api/governance/customers?search="+url.QueryEscape(name), fmt.Sprintf("find customer %q", name)) + if err != nil { + return "", false, err + } + + var out struct { + Customers []struct { + ID string `json:"id"` + Name string `json:"name"` + } `json:"customers"` + } + if err := json.Unmarshal(body, &out); err != nil { + return "", false, fmt.Errorf("decode customers: %w", err) + } + // search may match by substring; require an exact name match. + for _, cust := range out.Customers { + if cust.Name == name { + return cust.ID, true, nil + } + } + return "", false, nil +} + +// CreateUser creates a Bifrost user via POST /api/users and returns its id. An +// existing user (409 on duplicate email) is resolved to the existing id via +// FindUserByEmail, so the caller can still link team memberships. +func (c *BifrostClient) CreateUser(ctx context.Context, in *BifrostCreateUserRequest) (string, error) { + body, status, err := c.doRequest(ctx, http.MethodPost, "/api/users", in) + if err != nil { + return "", fmt.Errorf("create user %q: %w", in.Email, err) + } + + if status == http.StatusConflict { + id, ok, ferr := c.FindUserByEmail(ctx, in.Email) + if ferr != nil { + return "", fmt.Errorf("create user %q: already exists but lookup failed: %w", in.Email, ferr) + } + if !ok { + return "", fmt.Errorf("create user %q: conflict but no matching user found", in.Email) + } + return id, nil + } + if status < 200 || status >= 300 { + return "", fmt.Errorf("create user %q: status %d: %s", in.Email, status, strings.TrimSpace(string(body))) + } + + var out struct { + User struct { + ID string `json:"id"` + } `json:"user"` + } + if err := json.Unmarshal(body, &out); err != nil { + return "", fmt.Errorf("decode created user: %w", err) + } + if out.User.ID == "" { + return "", fmt.Errorf("create user %q: empty id in response", in.Email) + } + 
return out.User.ID, nil +} + +// ListProviders returns the names of providers configured in Bifrost via +// GET /api/providers. The VK migration uses this to drop allow-list entries for +// providers that were not migrated (a VK create rejects unknown providers). +func (c *BifrostClient) ListProviders(ctx context.Context) (map[string]bool, error) { + body, err := c.getJSON(ctx, "/api/providers", "list providers") + if err != nil { + return nil, err + } + var out struct { + Providers []struct { + Name string `json:"name"` + } `json:"providers"` + } + if err := json.Unmarshal(body, &out); err != nil { + return nil, fmt.Errorf("decode providers: %w", err) + } + set := make(map[string]bool, len(out.Providers)) + for _, p := range out.Providers { + set[p.Name] = true + } + return set, nil +} + +// BifrostKeyMeta is the key identity returned by GET /api/keys. The value is +// redacted; only the UUID (KeyID) and name are needed for VK key attachment. +type BifrostKeyMeta struct { + KeyID string `json:"key_id"` + Name string `json:"name"` + Provider string `json:"provider"` +} + +// ListAllKeys returns every provider key registered in Bifrost via GET /api/keys. +// The VK migration uses the KeyID UUIDs to attach specific keys to virtual keys. +func (c *BifrostClient) ListAllKeys(ctx context.Context) ([]BifrostKeyMeta, error) { + body, err := c.getJSON(ctx, "/api/keys", "list all keys") + if err != nil { + return nil, err + } + var keys []BifrostKeyMeta + if err := json.Unmarshal(body, &keys); err != nil { + return nil, fmt.Errorf("decode keys: %w", err) + } + return keys, nil +} + +// CreateVirtualKey creates a Bifrost virtual key via +// POST /api/governance/virtual-keys. The key value is server-generated. 
+func (c *BifrostClient) CreateVirtualKey(ctx context.Context, in *BifrostCreateVirtualKeyRequest) error { + return c.sendProvider(ctx, http.MethodPost, "/api/governance/virtual-keys", in, "virtual key "+in.Name) +} + +// CreateModelConfig creates a global Bifrost model config via +// POST /api/governance/model-configs. An existing config is treated as success. +func (c *BifrostClient) CreateModelConfig(ctx context.Context, in ModelConfigPlan) error { + req := BifrostCreateModelConfigRequest{ + ModelName: in.ModelName, + Provider: in.Provider, + Scope: bifrostModelConfigScopeGlobal, + Budgets: in.Budgets, + RateLimit: in.RateLimit, + } + return c.sendProvider(ctx, http.MethodPost, "/api/governance/model-configs", req, "model config "+modelConfigSignature(in)) +} + +// FindUserByEmail resolves a Bifrost user id by exact email via +// GET /api/users?search=. Returns ok=false (no error) when none matches. +func (c *BifrostClient) FindUserByEmail(ctx context.Context, email string) (id string, ok bool, err error) { + body, err := c.getJSON(ctx, "/api/users?search="+url.QueryEscape(email), fmt.Sprintf("find user %q", email)) + if err != nil { + return "", false, err + } + var out struct { + Users []struct { + ID string `json:"id"` + Email string `json:"email"` + } `json:"users"` + } + if err := json.Unmarshal(body, &out); err != nil { + return "", false, fmt.Errorf("decode users: %w", err) + } + for _, u := range out.Users { + if strings.EqualFold(u.Email, email) { + return u.ID, true, nil + } + } + return "", false, nil +} + +// FindTeamByName resolves a Bifrost team id by exact name via +// GET /api/governance/teams?search=. Returns ok=false (no error) when none +// matches, so the caller can warn and skip the membership link. 
+func (c *BifrostClient) FindTeamByName(ctx context.Context, name string) (id string, ok bool, err error) { + body, err := c.getJSON(ctx, "/api/governance/teams?search="+url.QueryEscape(name), fmt.Sprintf("find team %q", name)) + if err != nil { + return "", false, err + } + var out struct { + Teams []struct { + ID string `json:"id"` + Name string `json:"name"` + } `json:"teams"` + } + if err := json.Unmarshal(body, &out); err != nil { + return "", false, fmt.Errorf("decode teams: %w", err) + } + for _, tm := range out.Teams { + if tm.Name == name { + return tm.ID, true, nil + } + } + return "", false, nil +} + +// AddTeamMember links a user to a team via POST /api/teams/{id}/members. An +// existing membership (409) is treated as success. +func (c *BifrostClient) AddTeamMember(ctx context.Context, teamID, userID string) error { + return c.sendProvider(ctx, http.MethodPost, "/api/teams/"+teamID+"/members", map[string]string{"user_id": userID}, "team "+teamID+" member "+userID) +} + +// EnsureProvider creates or upserts a Bifrost provider's config (no keys). A +// custom provider (distinct base URL) is created via POST /api/providers; a +// standard provider is upserted via PUT /api/providers/{provider}. An existing +// provider (409 / already-present) is treated as success. 
+func (c *BifrostClient) EnsureProvider(ctx context.Context, p ProviderPlan) error { + concurrency := BifrostConcurrencyAndBufferSize{Concurrency: bifrostDefaultConcurrency, BufferSize: bifrostDefaultBufferSize} + + var network *BifrostNetworkConfig + if p.BaseURL != "" { + network = &BifrostNetworkConfig{BaseURL: p.BaseURL} + } + + if p.IsCustom { + payload := BifrostProviderCreatePayload{ + Provider: p.Name, + CustomProviderConfig: &BifrostCustomProviderConfig{BaseProviderType: p.BaseProvider}, + NetworkConfig: network, + ConcurrencyAndBufferSize: concurrency, + } + return c.sendProvider(ctx, http.MethodPost, "/api/providers", payload, p.Name) + } + + payload := BifrostProviderUpdatePayload{ + NetworkConfig: network, + ConcurrencyAndBufferSize: concurrency, + } + return c.sendProvider(ctx, http.MethodPut, "/api/providers/"+p.Name, payload, p.Name) +} + +// CreateProviderKey adds a single key to an existing provider via +// POST /api/providers/{provider}/keys. An already-present key (409) is treated +// as success. +func (c *BifrostClient) CreateProviderKey(ctx context.Context, provider string, k KeyPlan) error { + key := BifrostProviderKey{ + Name: k.Name, + Value: k.Value, + Models: k.Models, + Weight: 1.0, + AzureKeyConfig: k.AzureKeyConfig, + BedrockKeyConfig: k.BedrockKeyConfig, + VertexKeyConfig: k.VertexKeyConfig, + } + // Self-hosted providers carry the server URL on the key itself. 
+ if k.URL != "" { + switch provider { + case "vllm": + key.VLLMKeyConfig = &BifrostVLLMKeyConfig{URL: k.URL, ModelName: k.VLLMModelName} + case "ollama": + key.OllamaKeyConfig = &BifrostKeyURLConfig{URL: k.URL} + } + } + return c.sendProvider(ctx, http.MethodPost, "/api/providers/"+provider+"/keys", key, provider+"/"+k.Name) +} + +func (c *BifrostClient) endpoint(path string) string { + return strings.TrimRight(c.BaseURL, "/") + path +} + +func (c *BifrostClient) doRequest(ctx context.Context, method, path string, body any) ([]byte, int, error) { + var rdr io.Reader + if body != nil { + payload, err := json.Marshal(body) + if err != nil { + return nil, 0, err + } + rdr = bytes.NewReader(payload) + } + + req, err := http.NewRequestWithContext(ctx, method, c.endpoint(path), rdr) + if err != nil { + return nil, 0, err + } + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + if c.APIKey != "" { + req.Header.Set("Authorization", "Bearer "+c.APIKey) + } + + resp, err := httpClient.Do(req) + if err != nil { + return nil, 0, err + } + defer resp.Body.Close() + + out, err := io.ReadAll(resp.Body) + if err != nil { + return nil, resp.StatusCode, fmt.Errorf("read response body: %w", err) + } + return out, resp.StatusCode, nil +} + +// getJSON performs a GET with optional bearer and returns the body, mapping +// non-2xx to an error. +func (c *BifrostClient) getJSON(ctx context.Context, path, what string) ([]byte, error) { + body, status, err := c.doRequest(ctx, http.MethodGet, path, nil) + if err != nil { + return nil, fmt.Errorf("%s: %w", what, err) + } + if status < 200 || status >= 300 { + return nil, fmt.Errorf("%s: status %d: %s", what, status, strings.TrimSpace(string(body))) + } + return body, nil +} + +// sendProvider marshals body and performs the request, mapping 2xx and 409 +// (already exists) to success. 
+func (c *BifrostClient) sendProvider(ctx context.Context, method, path string, body any, what string) error { + return c.sendJSON(ctx, method, path, body, what) +} + +func (c *BifrostClient) sendJSON(ctx context.Context, method, path string, body any, what string) error { + out, status, err := c.doRequest(ctx, method, path, body) + if err != nil { + return fmt.Errorf("%s %s: %w", method, what, err) + } + if status == http.StatusConflict { + return nil // already exists + } + if status < 200 || status >= 300 { + return fmt.Errorf("%s %s: status %d: %s", method, what, status, strings.TrimSpace(string(out))) + } + return nil +} diff --git a/scripts/litellm-to-bifrost/go.mod b/scripts/litellm-to-bifrost/go.mod new file mode 100644 index 0000000000..acd903d589 --- /dev/null +++ b/scripts/litellm-to-bifrost/go.mod @@ -0,0 +1,24 @@ +module github.com/maximhq/bifrost/scripts/litellm-to-bifrost + +go 1.26.4 + +require ( + golang.org/x/crypto v0.53.0 + gopkg.in/yaml.v3 v3.0.1 + gorm.io/driver/postgres v1.6.0 + gorm.io/gorm v1.31.1 +) + +require ( + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/pgx/v5 v5.10.0 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/rogpeppe/go-internal v1.15.0 // indirect + golang.org/x/sync v0.21.0 // indirect + golang.org/x/sys v0.46.0 // indirect + golang.org/x/text v0.38.0 // indirect +) diff --git a/scripts/litellm-to-bifrost/go.sum b/scripts/litellm-to-bifrost/go.sum new file mode 100644 index 0000000000..d91cdb409a --- /dev/null +++ b/scripts/litellm-to-bifrost/go.sum @@ -0,0 +1,47 @@ +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 
h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.10.0 h1:VhSvgU2jSli8o3AqIEOTJr7rZwAEUVo4E4XhR94Zfr0= +github.com/jackc/pgx/v5 v5.10.0/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.15.0 h1:D0RCU5rMAp+SpgkiNdrjfJ+LX4J1M32V2NeCY7EJ6hc= +github.com/rogpeppe/go-internal v1.15.0/go.mod h1:DrUVZyrJU+txYW5/1kwtXQSMFio52ZOxX7yM1VHvnxs= +github.com/stretchr/objx v0.1.0/go.mod 
h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +golang.org/x/crypto v0.53.0 h1:QZ4Muo8THX6CizN2vPPd5fBGHyogrdK9fG4wLPFUsto= +golang.org/x/crypto v0.53.0/go.mod h1:DNLU434OwVakk9PzuwV8w62mAJpRJL3vsgcfp4Qnsio= +golang.org/x/sync v0.21.0 h1:HLII4xRRTtCRkxYp4HNFF0Js/Og6q2i++KXbg0gHCwM= +golang.org/x/sync v0.21.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/sys v0.46.0 h1:noSf2Fq6F8DBgS+LysIkx7rIExoNHJsxOAtPp4rthXw= +golang.org/x/sys v0.46.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/text v0.38.0 h1:sXmwo9DwP3OK9EZ7PqAdaooSGozfl/3a6/xJcbzPRhE= +golang.org/x/text v0.38.0/go.mod h1:YXZt3QhHUKYT53r2lLKFIVi6Ao1jdzrTR/KQ09qyxF4= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4= +gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo= +gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= +gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= diff --git a/scripts/litellm-to-bifrost/litellm/crypto.go 
b/scripts/litellm-to-bifrost/litellm/crypto.go new file mode 100644 index 0000000000..ac0cbe4332 --- /dev/null +++ b/scripts/litellm-to-bifrost/litellm/crypto.go @@ -0,0 +1,44 @@ +package litellm + +import ( + "crypto/sha256" + "encoding/base64" + "fmt" + + "golang.org/x/crypto/nacl/secretbox" +) + +// secretBoxNonceLen is the 24-byte NaCl SecretBox nonce that LiteLLM prepends to +// the ciphertext (PyNaCl's box.encrypt returns nonce||ciphertext). +const secretBoxNonceLen = 24 + +// DecryptValue reverses LiteLLM's encrypt_value_helper. The stored string is +// base64 (url-safe, with a standard-base64 fallback for the old format) over a +// NaCl SecretBox sealed message whose 32-byte key is SHA-256(saltKey); the first +// 24 bytes are the nonce. saltKey is LITELLM_SALT_KEY, falling back to the proxy +// master_key. An empty input decrypts to an empty string. +func DecryptValue(encoded, saltKey string) (string, error) { + if encoded == "" { + return "", nil + } + raw, err := base64.URLEncoding.DecodeString(encoded) + if err != nil { + raw, err = base64.StdEncoding.DecodeString(encoded) + if err != nil { + return "", fmt.Errorf("base64 decode: %w", err) + } + } + if len(raw) < secretBoxNonceLen+secretbox.Overhead { + return "", fmt.Errorf("ciphertext too short: %d bytes", len(raw)) + } + + key := sha256.Sum256([]byte(saltKey)) + var nonce [secretBoxNonceLen]byte + copy(nonce[:], raw[:secretBoxNonceLen]) + + out, ok := secretbox.Open(nil, raw[secretBoxNonceLen:], &nonce, &key) + if !ok { + return "", fmt.Errorf("decrypt failed (wrong salt key?)") + } + return string(out), nil +} diff --git a/scripts/litellm-to-bifrost/litellm/db.go b/scripts/litellm-to-bifrost/litellm/db.go new file mode 100644 index 0000000000..cd7dd1b81b --- /dev/null +++ b/scripts/litellm-to-bifrost/litellm/db.go @@ -0,0 +1,418 @@ +package litellm + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "gorm.io/driver/postgres" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +// The LiteLLM 
management API masks secrets (/credentials shows "os****KE", +// /model/info omits api_key entirely), so the real key material is read from two +// places: the config.yaml credential_list/model_list (plaintext env refs and +// literals) and the Postgres tables LiteLLM_CredentialsTable and +// LiteLLM_ProxyModelTable (values encrypted with the salt key). CredentialStore +// loads both and hands back fully resolved credentials. + +// OpenDB opens the LiteLLM Postgres database from its DATABASE_URL DSN. +func OpenDB(dsn string) (*gorm.DB, error) { + db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)}) + if err != nil { + return nil, fmt.Errorf("open litellm db: %w", err) + } + return db, nil +} + +// credentialRow mirrors LiteLLM_CredentialsTable. credential_values holds each +// value encrypted; credential_info is plaintext JSON. +type credentialRow struct { + CredentialName string + CredentialValues string // jsonb: {"api_key": "", "api_base": ""} + CredentialInfo string // jsonb: {"custom_llm_provider": "openai"} +} + +func (credentialRow) TableName() string { return "LiteLLM_CredentialsTable" } + +// proxyModelRow mirrors LiteLLM_ProxyModelTable. Every string value in +// litellm_params is encrypted; numeric values (rpm/tpm) are stored as-is. +type proxyModelRow struct { + ModelID string + ModelName string + LitellmParams string // jsonb: {"model": "", "api_key": "", "api_base": "", ...} +} + +func (proxyModelRow) TableName() string { return "LiteLLM_ProxyModelTable" } + +// dbCredentialValues is the encrypted credential_values payload. +type dbCredentialValues struct { + APIKey *string `json:"api_key"` + APIBase *string `json:"api_base"` +} + +// dbCredentialInfo is the plaintext credential_info payload. 
type dbCredentialInfo struct {
	// CustomLLMProvider is a pointer so an absent or null JSON key stays nil.
	CustomLLMProvider *string `json:"custom_llm_provider"`
}

// dbLitellmParams is the subset of an encrypted litellm_params blob the
// migration needs: inline credential fields plus per-deployment budget. String
// values are encrypted; numeric values (max_budget) are stored as-is.
//
// Every field is a pointer, so a key that is absent from the JSON (or null)
// decodes to nil rather than to a zero value.
type dbLitellmParams struct {
	APIKey            *string  `json:"api_key"`
	APIBase           *string  `json:"api_base"`
	CustomLLMProvider *string  `json:"custom_llm_provider"`
	MaxBudget         *float64 `json:"max_budget"`
	BudgetDuration    *string  `json:"budget_duration"`
	// Azure
	AzureClientID     *string `json:"azure_client_id"`
	AzureClientSecret *string `json:"azure_client_secret"`
	AzureTenantID     *string `json:"azure_tenant_id"`
	AzureADToken      *string `json:"azure_ad_token"`
	// Bedrock (AWS)
	AWSAccessKeyID     *string `json:"aws_access_key_id"`
	AWSSecretAccessKey *string `json:"aws_secret_access_key"`
	AWSRegionName      *string `json:"aws_region_name"`
	AWSRoleName        *string `json:"aws_role_name"`
	AWSSessionName     *string `json:"aws_session_name"`
	AWSSessionToken    *string `json:"aws_session_token"`
	// Vertex (GCP)
	VertexProject     *string `json:"vertex_project"`
	VertexLocation    *string `json:"vertex_location"`
	VertexCredentials *string `json:"vertex_credentials"`
}

// Deployment is the resolved (decrypted) per-deployment view the migration needs
// from litellm_params: the inline credential (when the deployment carries one
// directly) and its budget. Rate limits (rpm/tpm) are read separately from the
// /model/info response, where they are not encrypted.
+type Deployment struct { + APIKey string + APIBase string + CustomLLMProvider string + MaxBudget *float64 + BudgetDuration *string + // Azure + AzureClientID string + AzureClientSecret string + AzureTenantID string + AzureADToken string + // Bedrock (AWS) + AWSAccessKeyID string + AWSSecretAccessKey string + AWSRegionName string + AWSRoleName string + AWSSessionName string + AWSSessionToken string + // Vertex (GCP) + VertexProject string + VertexLocation string + VertexCredentials string +} + +// InlineCredential returns the deployment's inline credential, or nil when it +// carries neither an api_key nor an api_base (i.e. it uses a named credential). +func (d *Deployment) InlineCredential() *LiteLLMModelCredential { + if d == nil || (d.APIKey == "" && d.APIBase == "") { + return nil + } + return &LiteLLMModelCredential{ + CredentialValues: &LiteLLMModelCredentialValues{ApiKey: strPtr(d.APIKey), ApiBase: strPtr(d.APIBase)}, + CredentialInfo: &LiteLLMModelCredentialInfo{CustomLLMProvider: d.CustomLLMProvider}, + } +} + +// CredentialStore resolves LiteLLM credentials from config + database, decrypting +// the database values with the salt key. +type CredentialStore struct { + db *gorm.DB + cfg *LiteLLMConfig + saltKey string + byName map[string]*LiteLLMModelCredential +} + +// NewCredentialStore loads every named credential from the config credential_list +// (plaintext) and the LiteLLM_CredentialsTable (decrypted), with the database +// taking precedence on name collisions. An empty dbURL or nil cfg skips that +// source. saltKey is LITELLM_SALT_KEY (or the proxy master_key). 
+func NewCredentialStore(ctx context.Context, dbURL string, cfg *LiteLLMConfig, saltKey string) (*CredentialStore, error) { + var db *gorm.DB + if strings.TrimSpace(dbURL) != "" { + var err error + if db, err = OpenDB(dbURL); err != nil { + return nil, err + } + } + s := &CredentialStore{db: db, cfg: cfg, saltKey: saltKey, byName: map[string]*LiteLLMModelCredential{}} + if err := s.loadNamed(ctx); err != nil { + if s.db != nil { + _ = s.Close() + } + return nil, err + } + return s, nil +} + +// Named returns the resolved (decrypted) credentials keyed by credential_name. +func (s *CredentialStore) Named() map[string]*LiteLLMModelCredential { return s.byName } + +func (s *CredentialStore) Close() error { + if s.db == nil { + return nil + } + sqlDB, err := s.db.DB() + if err != nil { + return err + } + return sqlDB.Close() +} + +func (s *CredentialStore) loadNamed(ctx context.Context) error { + if s.cfg != nil { + for _, c := range s.cfg.CredentialList { + name := strings.TrimSpace(c.CredentialName) + if name == "" { + continue + } + s.byName[name] = &LiteLLMModelCredential{ + CredentialName: name, + CredentialValues: &LiteLLMModelCredentialValues{ApiKey: strPtr(c.CredentialValues["api_key"]), ApiBase: strPtr(c.CredentialValues["api_base"])}, + CredentialInfo: &LiteLLMModelCredentialInfo{CustomLLMProvider: c.CredentialInfo["custom_llm_provider"]}, + } + } + } + + if s.db == nil { + return nil + } + var rows []credentialRow + if err := s.db.WithContext(ctx).Find(&rows).Error; err != nil { + return fmt.Errorf("read LiteLLM_CredentialsTable: %w", err) + } + for _, r := range rows { + name := strings.TrimSpace(r.CredentialName) + if name == "" { + continue + } + var vals dbCredentialValues + if err := json.Unmarshal([]byte(r.CredentialValues), &vals); err != nil { + return fmt.Errorf("credential %q values: %w", name, err) + } + apiKey, err := s.decryptPtr(vals.APIKey) + if err != nil { + return fmt.Errorf("credential %q api_key: %w", name, err) + } + apiBase, err := 
s.decryptPtr(vals.APIBase) + if err != nil { + return fmt.Errorf("credential %q api_base: %w", name, err) + } + var info dbCredentialInfo + if r.CredentialInfo != "" { + if err = json.Unmarshal([]byte(r.CredentialInfo), &info); err != nil { + return fmt.Errorf("credential %q info: %w", name, err) + } + } + provider := "" + if info.CustomLLMProvider != nil { + provider = *info.CustomLLMProvider + } + s.byName[name] = &LiteLLMModelCredential{ + CredentialName: name, + CredentialValues: &LiteLLMModelCredentialValues{ApiKey: apiKey, ApiBase: apiBase}, + CredentialInfo: &LiteLLMModelCredentialInfo{CustomLLMProvider: provider}, + } + } + return nil +} + +// ResolveCredential returns the fully resolved (decrypted) credential for in's +// credential_name, looked up in the config + database. in is returned unchanged +// when its credential is not found (e.g. an inline deployment with no name). +func (s *CredentialStore) ResolveCredential(in *LiteLLMModelCredential) *LiteLLMModelCredential { + if in == nil { + return nil + } + if c, ok := s.byName[strings.TrimSpace(in.CredentialName)]; ok { + return c + } + return in +} + +// Deployments builds the resolved per-deployment map (inline credential + budget) +// from litellm_params. Database deployments are keyed by model_id (decrypted); +// config deployments, which have no stable id, are keyed by model_name +// (plaintext). Callers look up by model_id first, then by model_name. 
+func (s *CredentialStore) Deployments(ctx context.Context) (map[string]*Deployment, error) { + out := map[string]*Deployment{} + + if s.cfg != nil { + for _, m := range s.cfg.ModelList { + p := m.LiteLLMParams + out[m.ModelName] = &Deployment{ + APIKey: strings.TrimSpace(p.APIKey), + APIBase: strings.TrimSpace(p.APIBase), + CustomLLMProvider: p.CustomLLMProvider, + MaxBudget: p.MaxBudget, + BudgetDuration: p.BudgetDuration, + AzureClientID: strings.TrimSpace(p.AzureClientID), + AzureClientSecret: strings.TrimSpace(p.AzureClientSecret), + AzureTenantID: strings.TrimSpace(p.AzureTenantID), + AzureADToken: strings.TrimSpace(p.AzureADToken), + AWSAccessKeyID: strings.TrimSpace(p.AWSAccessKeyID), + AWSSecretAccessKey: strings.TrimSpace(p.AWSSecretAccessKey), + AWSRegionName: strings.TrimSpace(p.AWSRegionName), + AWSRoleName: strings.TrimSpace(p.AWSRoleName), + AWSSessionName: strings.TrimSpace(p.AWSSessionName), + AWSSessionToken: strings.TrimSpace(p.AWSSessionToken), + VertexProject: strings.TrimSpace(p.VertexProject), + VertexLocation: strings.TrimSpace(p.VertexLocation), + VertexCredentials: strings.TrimSpace(p.VertexCredentials), + } + } + } + + if s.db == nil { + return out, nil + } + var rows []proxyModelRow + if err := s.db.WithContext(ctx).Find(&rows).Error; err != nil { + return nil, fmt.Errorf("read LiteLLM_ProxyModelTable: %w", err) + } + for _, r := range rows { + var p dbLitellmParams + if err := json.Unmarshal([]byte(r.LitellmParams), &p); err != nil { + return nil, fmt.Errorf("model %q params: %w", r.ModelID, err) + } + apiKey, err := s.decryptPtr(p.APIKey) + if err != nil { + return nil, fmt.Errorf("model %q api_key: %w", r.ModelID, err) + } + apiBase, err := s.decryptPtr(p.APIBase) + if err != nil { + return nil, fmt.Errorf("model %q api_base: %w", r.ModelID, err) + } + provider := "" + if p.CustomLLMProvider != nil { + dec, err := s.decrypt(*p.CustomLLMProvider) + if err != nil { + return nil, fmt.Errorf("model %q custom_llm_provider: %w", r.ModelID, 
err) + } + provider = dec + } + budgetDuration, err := s.decryptPtr(p.BudgetDuration) + if err != nil { + return nil, fmt.Errorf("model %q budget_duration: %w", r.ModelID, err) + } + // Azure + azureClientID, err := s.decryptPtr(p.AzureClientID) + if err != nil { + return nil, fmt.Errorf("model %q azure_client_id: %w", r.ModelID, err) + } + azureClientSecret, err := s.decryptPtr(p.AzureClientSecret) + if err != nil { + return nil, fmt.Errorf("model %q azure_client_secret: %w", r.ModelID, err) + } + azureTenantID, err := s.decryptPtr(p.AzureTenantID) + if err != nil { + return nil, fmt.Errorf("model %q azure_tenant_id: %w", r.ModelID, err) + } + azureADToken, err := s.decryptPtr(p.AzureADToken) + if err != nil { + return nil, fmt.Errorf("model %q azure_ad_token: %w", r.ModelID, err) + } + // Bedrock + awsAccessKeyID, err := s.decryptPtr(p.AWSAccessKeyID) + if err != nil { + return nil, fmt.Errorf("model %q aws_access_key_id: %w", r.ModelID, err) + } + awsSecretAccessKey, err := s.decryptPtr(p.AWSSecretAccessKey) + if err != nil { + return nil, fmt.Errorf("model %q aws_secret_access_key: %w", r.ModelID, err) + } + awsRegionName, err := s.decryptPtr(p.AWSRegionName) + if err != nil { + return nil, fmt.Errorf("model %q aws_region_name: %w", r.ModelID, err) + } + awsRoleName, err := s.decryptPtr(p.AWSRoleName) + if err != nil { + return nil, fmt.Errorf("model %q aws_role_name: %w", r.ModelID, err) + } + awsSessionName, err := s.decryptPtr(p.AWSSessionName) + if err != nil { + return nil, fmt.Errorf("model %q aws_session_name: %w", r.ModelID, err) + } + awsSessionToken, err := s.decryptPtr(p.AWSSessionToken) + if err != nil { + return nil, fmt.Errorf("model %q aws_session_token: %w", r.ModelID, err) + } + // Vertex + vertexProject, err := s.decryptPtr(p.VertexProject) + if err != nil { + return nil, fmt.Errorf("model %q vertex_project: %w", r.ModelID, err) + } + vertexLocation, err := s.decryptPtr(p.VertexLocation) + if err != nil { + return nil, fmt.Errorf("model %q 
vertex_location: %w", r.ModelID, err) + } + vertexCredentials, err := s.decryptPtr(p.VertexCredentials) + if err != nil { + return nil, fmt.Errorf("model %q vertex_credentials: %w", r.ModelID, err) + } + out[r.ModelID] = &Deployment{ + APIKey: deref(apiKey), + APIBase: deref(apiBase), + CustomLLMProvider: provider, + MaxBudget: p.MaxBudget, + BudgetDuration: budgetDuration, + AzureClientID: deref(azureClientID), + AzureClientSecret: deref(azureClientSecret), + AzureTenantID: deref(azureTenantID), + AzureADToken: deref(azureADToken), + AWSAccessKeyID: deref(awsAccessKeyID), + AWSSecretAccessKey: deref(awsSecretAccessKey), + AWSRegionName: deref(awsRegionName), + AWSRoleName: deref(awsRoleName), + AWSSessionName: deref(awsSessionName), + AWSSessionToken: deref(awsSessionToken), + VertexProject: deref(vertexProject), + VertexLocation: deref(vertexLocation), + VertexCredentials: deref(vertexCredentials), + } + } + return out, nil +} + +// decryptPtr decrypts a non-nil encrypted value, returning a pointer to the +// plaintext (nil stays nil). 
+func (s *CredentialStore) decryptPtr(v *string) (*string, error) { + if v == nil { + return nil, nil + } + dec, err := s.decrypt(*v) + if err != nil { + return nil, err + } + return &dec, nil +} + +func (s *CredentialStore) decrypt(v string) (string, error) { + return DecryptValue(v, s.saltKey) +} + +func strPtr(s string) *string { + if s == "" { + return nil + } + v := s + return &v +} + +func deref(s *string) string { + if s == nil { + return "" + } + return *s +} diff --git a/scripts/litellm-to-bifrost/litellm/litellm.go b/scripts/litellm-to-bifrost/litellm/litellm.go new file mode 100644 index 0000000000..1765a92d9f --- /dev/null +++ b/scripts/litellm-to-bifrost/litellm/litellm.go @@ -0,0 +1,141 @@ +package litellm + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "strings" + "time" + + "gopkg.in/yaml.v3" +) + +var httpClient = &http.Client{Timeout: 30 * time.Second} + +// DefaultModelProvider is the provider assumed for a deployment whose model has +// no provider prefix and no custom_llm_provider. LiteLLM itself defaults such +// bare names to OpenAI. +const DefaultModelProvider = "openai" + +// doGet is a helper that executes an authenticated HTTP GET request. It appends +// the provided path to the client's BaseURL, attaches the Bearer token, and +// unmarshals the resulting JSON response body into the target destination. 
+func doGet[T any](ctx context.Context, c *LiteLLMClient, path string, target *T) error { + url := strings.TrimRight(c.BaseURL, "/") + path + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return err + } + req.Header.Set("Authorization", "Bearer "+c.APIKey) + + resp, err := httpClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("error while reading response body: %w", err) + } + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + + if err := json.Unmarshal(body, target); err != nil { + return fmt.Errorf("decode error: %w", err) + } + return nil +} + +// LiteLLMClient reads entities from the LiteLLM management API. +type LiteLLMClient struct { + BaseURL string + APIKey string // LiteLLM admin key +} + +// LiteLLMConfig mirrors the parts of LiteLLM's config.yaml the migration consumes. +type LiteLLMConfig struct { + ModelList []LiteLLMModel `yaml:"model_list"` + RouterSettings LiteLLMRouterSettings `yaml:"router_settings"` + CredentialList []LiteLLMConfigCredential `yaml:"credential_list"` +} + +// LiteLLMConfigCredential is one credential_list entry. Config-declared +// credential values are plaintext (env refs like "os.environ/FOO" or literals), +// unlike the encrypted values in LiteLLM_CredentialsTable. +type LiteLLMConfigCredential struct { + CredentialName string `yaml:"credential_name"` + CredentialValues map[string]string `yaml:"credential_values"` + CredentialInfo map[string]string `yaml:"credential_info"` +} + +// LiteLLMModel is one model_list entry. +type LiteLLMModel struct { + ModelName string `yaml:"model_name"` + LiteLLMParams LiteLLMParams `yaml:"litellm_params"` +} + +// LiteLLMParams mirrors litellm_params fields consumed by the migration. 
type LiteLLMParams struct {
	Model             string   `yaml:"model"`               // litellm model name - "openai/gpt-4o", "bedrock/*", "gpt-4.1"
	APIKey            string   `yaml:"api_key"`             // env ref ("os.environ/FOO") or literal
	APIBase           string   `yaml:"api_base"`            // per-deployment base URL
	CustomLLMProvider string   `yaml:"custom_llm_provider"` // explicit provider override
	MaxBudget         *float64 `yaml:"max_budget"`          // per-deployment spend cap
	BudgetDuration    *string  `yaml:"budget_duration"`     // budget reset window, e.g. "30d", "1mo"
	RPM               *int64   `yaml:"rpm"`                 // per-deployment request per minute override
	TPM               *int64   `yaml:"tpm"`                 // per-deployment token per minute override
	// Azure-specific
	AzureClientID     string `yaml:"azure_client_id"`
	AzureClientSecret string `yaml:"azure_client_secret"`
	AzureTenantID     string `yaml:"azure_tenant_id"`
	AzureADToken      string `yaml:"azure_ad_token"`
	// Bedrock-specific (AWS credentials)
	AWSAccessKeyID     string `yaml:"aws_access_key_id"`
	AWSSecretAccessKey string `yaml:"aws_secret_access_key"`
	AWSRegionName      string `yaml:"aws_region_name"`
	AWSRoleName        string `yaml:"aws_role_name"`     // IAM role ARN for STS AssumeRole
	AWSSessionName     string `yaml:"aws_session_name"`  // STS session name
	AWSSessionToken    string `yaml:"aws_session_token"` // temporary session token
	// Vertex-specific (GCP credentials)
	VertexProject     string `yaml:"vertex_project"`
	VertexLocation    string `yaml:"vertex_location"`
	VertexCredentials string `yaml:"vertex_credentials"` // service account JSON or ADC path; empty = ADC
}

// Normalize trims surrounding whitespace from every string field of the params
// and validates that a target model is present.
//
// It returns an error when the trimmed model is empty; in that case only Model
// has been modified. Pointer fields (budgets, rate limits) are left untouched.
func (p *LiteLLMParams) Normalize() error {
	p.Model = strings.TrimSpace(p.Model)
	if p.Model == "" {
		return fmt.Errorf("model is empty")
	}

	// Trim the provider-specific credentials too, so normalized params hold
	// the same values for every field rather than for the generic ones only.
	for _, field := range []*string{
		&p.APIKey, &p.APIBase, &p.CustomLLMProvider,
		&p.AzureClientID, &p.AzureClientSecret, &p.AzureTenantID, &p.AzureADToken,
		&p.AWSAccessKeyID, &p.AWSSecretAccessKey, &p.AWSRegionName,
		&p.AWSRoleName, &p.AWSSessionName, &p.AWSSessionToken,
		&p.VertexProject, &p.VertexLocation, &p.VertexCredentials,
	} {
		*field = strings.TrimSpace(*field)
	}
	return nil
}

// LiteLLMRouterSettings mirrors router_settings.
+type LiteLLMRouterSettings struct { + ModelGroupAlias map[string]string `yaml:"model_group_alias"` +} + +// ReadLiteLLMConfig parses a LiteLLM proxy config YAML from disk. The model +// migration reads secrets here rather than from the management API, which +// redacts them. +func ReadLiteLLMConfig(path string) (*LiteLLMConfig, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read litellm config %q: %w", path, err) + } + var cfg LiteLLMConfig + if err := yaml.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("parse litellm config %q: %w", path, err) + } + return &cfg, nil +} diff --git a/scripts/litellm-to-bifrost/litellm/models.go b/scripts/litellm-to-bifrost/litellm/models.go new file mode 100644 index 0000000000..7bb6bb5abe --- /dev/null +++ b/scripts/litellm-to-bifrost/litellm/models.go @@ -0,0 +1,73 @@ +package litellm + +import ( + "context" + "fmt" +) + +type LiteLLMModelInfo struct { + ModelName string `json:"model_name"` + LiteLLMParams *LiteLLMModelInfoLiteLLMParams `json:"litellm_params"` + ModelInfo *LiteLLMModelInfoParams `json:"model_info"` +} + +type LiteLLMModelInfoLiteLLMParams struct { + UseInPassThrough bool `json:"use_in_pass_through"` + UseLitellmProxy bool `json:"use_litellm_proxy"` + MergeReasoningContentInChoices bool `json:"merge_reasoning_content_in_choices"` + Model string `json:"model"` + CustomLLMProvider string `json:"custom_llm_provider"` + LiteLLMCredentialName *string `json:"litellm_credential_name"` + Tpm *int64 `json:"tpm"` + Rpm *int64 `json:"rpm"` +} + +type LiteLLMModelInfoParams struct { + Id string `json:"id"` + DbModel bool `json:"db_model"` + Key string `json:"key"` + MaxTokens *int `json:"max_tokens"` + MaxInputTokens *int `json:"max_input_tokens"` + MaxOutputTokens *int `json:"max_output_tokens"` + Tpm *int64 `json:"tpm"` + Rpm *int64 `json:"rpm"` +} + +type LiteLLMModelCredential struct { + CredentialName string `json:"credential_name"` + CredentialValues 
*LiteLLMModelCredentialValues `json:"credential_values"` + CredentialInfo *LiteLLMModelCredentialInfo `json:"credential_info"` +} + +type LiteLLMModelCredentialValues struct { + ApiKey *string `json:"api_key"` + ApiBase *string `json:"api_base"` +} + +type LiteLLMModelCredentialInfo struct { + CustomLLMProvider string `json:"custom_llm_provider"` +} + +// ListModelInfo fetches every deployment via GET /model/info +func (c *LiteLLMClient) ListModelInfo(ctx context.Context) ([]LiteLLMModelInfo, error) { + var out struct { + Data []LiteLLMModelInfo `json:"data"` + } + + if err := doGet(ctx, c, "/model/info", &out); err != nil { + return nil, fmt.Errorf("list model info: %w", err) + } + return out.Data, nil +} + +// ListCredentials fetches every stored credential via GET /credentials +func (c *LiteLLMClient) ListCredentials(ctx context.Context) ([]LiteLLMModelCredential, error) { + var out struct { + Credentials []LiteLLMModelCredential `json:"credentials"` + } + + if err := doGet(ctx, c, "/credentials", &out); err != nil { + return nil, fmt.Errorf("list credentials: %w", err) + } + return out.Credentials, nil +} diff --git a/scripts/litellm-to-bifrost/litellm/organization.go b/scripts/litellm-to-bifrost/litellm/organization.go new file mode 100644 index 0000000000..bea803bcb2 --- /dev/null +++ b/scripts/litellm-to-bifrost/litellm/organization.go @@ -0,0 +1,36 @@ +package litellm + +import ( + "context" + "fmt" +) + +// LiteLLMOrganization mirrors LiteLLM_OrganizationTableWithMembers from +// GET /organization/list. Only the fields relevant to a Bifrost customer are +// decoded; everything else (teams, members, keys, models, spend, metadata) is +// intentionally ignored — those are migrated as their own entities. 
+type LiteLLMOrganization struct { + OrganizationID string `json:"organization_id"` + OrganizationAlias string `json:"organization_alias"` + Models []string `json:"models"` // allowed model names (Bifrost gates these on the VK) + Budget *LiteLLMBudget `json:"litellm_budget_table"` +} + +// LiteLLMBudget mirrors the joined LiteLLM_BudgetTable row. In LiteLLM the +// budget carries both the spend limit (max_budget / budget_duration) and the +// rate limits (tpm_limit / rpm_limit) for the organization. +type LiteLLMBudget struct { + MaxBudget *float64 `json:"max_budget"` // spend cap in USD + BudgetDuration *string `json:"budget_duration"` // reset window, e.g. "30d", "1mo" + TPMLimit *int64 `json:"tpm_limit"` // tokens per minute + RPMLimit *int64 `json:"rpm_limit"` // requests per minute +} + +// ListOrganizations fetches every organization via GET /organization/list. +func (c *LiteLLMClient) ListOrganizations(ctx context.Context) ([]LiteLLMOrganization, error) { + var orgs []LiteLLMOrganization + if err := doGet(ctx, c, "/organization/list", &orgs); err != nil { + return nil, fmt.Errorf("list organizations: %w", err) + } + return orgs, nil +} diff --git a/scripts/litellm-to-bifrost/litellm/team.go b/scripts/litellm-to-bifrost/litellm/team.go new file mode 100644 index 0000000000..b02174194d --- /dev/null +++ b/scripts/litellm-to-bifrost/litellm/team.go @@ -0,0 +1,34 @@ +package litellm + +import ( + "context" + "fmt" +) + +// LiteLLMTeam mirrors a row from GET /team/list. Only the fields relevant to a +// Bifrost team are decoded; members, keys, models, metadata and spend are +// intentionally ignored (members/keys/models are their own entities, and +// Bifrost teams carry no membership). +// +// Unlike organizations, a team's spend cap and rate limits are inline on the +// team object (max_budget / budget_duration / tpm_limit / rpm_limit), not in a +// joined litellm_budget_table. 
+type LiteLLMTeam struct { + TeamID string `json:"team_id"` + TeamAlias string `json:"team_alias"` + OrganizationID *string `json:"organization_id"` // nil/empty => standalone team + Models []string `json:"models"` // allowed model names (Bifrost gates these on the VK) + MaxBudget *float64 `json:"max_budget"` + BudgetDuration *string `json:"budget_duration"` + TPMLimit *int64 `json:"tpm_limit"` + RPMLimit *int64 `json:"rpm_limit"` +} + +// ListTeams fetches every team via GET /team/list. +func (c *LiteLLMClient) ListTeams(ctx context.Context) ([]LiteLLMTeam, error) { + var teams []LiteLLMTeam + if err := doGet(ctx, c, "/team/list", &teams); err != nil { + return nil, fmt.Errorf("list teams: %w", err) + } + return teams, nil +} diff --git a/scripts/litellm-to-bifrost/litellm/user.go b/scripts/litellm-to-bifrost/litellm/user.go new file mode 100644 index 0000000000..8727198363 --- /dev/null +++ b/scripts/litellm-to-bifrost/litellm/user.go @@ -0,0 +1,70 @@ +package litellm + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" +) + +// LiteLLMUser mirrors a row from GET /user/list (LiteLLM_UserTable). Only the +// fields relevant to a Bifrost user are decoded. teams holds the LiteLLM +// team_ids the user belongs to; the migration links the user to the Bifrost +// team of the same name (resolved by the caller). +type LiteLLMUser struct { + UserID string `json:"user_id"` + UserEmail *string `json:"user_email"` + UserAlias *string `json:"user_alias"` + UserRole *string `json:"user_role"` + Teams []string `json:"teams"` // LiteLLM team_ids + MaxBudget *float64 `json:"max_budget"` + BudgetDuration *string `json:"budget_duration"` + TPMLimit *int64 `json:"tpm_limit"` + RPMLimit *int64 `json:"rpm_limit"` +} + +// ListUsers fetches every internal user via GET /user/list, following +// pagination because the endpoint returns a fixed page. 
+func (c *LiteLLMClient) ListUsers(ctx context.Context) ([]LiteLLMUser, error) { + const pageSize = 100 + base := strings.TrimRight(c.BaseURL, "/") + "/user/list" + + var all []LiteLLMUser + for page := 1; ; page++ { + endpoint := fmt.Sprintf("%s?page=%d&page_size=%d", base, page, pageSize) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+c.APIKey) + + resp, err := httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("list users: %w", err) + } + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return nil, fmt.Errorf("list users: read body: %w", err) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("list users: status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + + var pageResp struct { + Users []LiteLLMUser `json:"users"` + TotalPages int `json:"total_pages"` + } + if err := json.Unmarshal(body, &pageResp); err != nil { + return nil, fmt.Errorf("decode users: %w", err) + } + all = append(all, pageResp.Users...) + + if page >= pageResp.TotalPages || len(pageResp.Users) == 0 { + break + } + } + return all, nil +} diff --git a/scripts/litellm-to-bifrost/litellm/virtualkey.go b/scripts/litellm-to-bifrost/litellm/virtualkey.go new file mode 100644 index 0000000000..7b864443fd --- /dev/null +++ b/scripts/litellm-to-bifrost/litellm/virtualkey.go @@ -0,0 +1,49 @@ +package litellm + +import ( + "context" + "fmt" +) + +// LiteLLMVirtualKey mirrors a row from GET /key/list (return_full_object=true). +// Only the fields a Bifrost virtual key can carry are decoded. 
+type LiteLLMVirtualKey struct { + KeyAlias *string `json:"key_alias"` + KeyName *string `json:"key_name"` // masked token, used as a name fallback + UserID *string `json:"user_id"` // not linkable on VK create (see report) + TeamID *string `json:"team_id"` + OrgID *string `json:"org_id"` + Models []string `json:"models"` // allowed model names ([] => all) + MaxBudget *float64 `json:"max_budget"` + BudgetDuration *string `json:"budget_duration"` + TPMLimit *int64 `json:"tpm_limit"` + RPMLimit *int64 `json:"rpm_limit"` + Blocked *bool `json:"blocked"` + Expires *string `json:"expires"` +} + +// ListVirtualKeys fetches every virtual key via GET /key/list, following +// pagination. return_full_object=true yields the budget/limit/model fields and +// include_team_keys=true includes team-scoped keys. +func (c *LiteLLMClient) ListVirtualKeys(ctx context.Context) ([]LiteLLMVirtualKey, error) { + const pageSize = 100 + + var all []LiteLLMVirtualKey + for page := 1; ; page++ { + path := fmt.Sprintf("/key/list?include_team_keys=true&return_full_object=true&page=%d&page_size=%d", page, pageSize) + + var pageResp struct { + Keys []LiteLLMVirtualKey `json:"keys"` + TotalPages int `json:"total_pages"` + } + if err := doGet(ctx, c, path, &pageResp); err != nil { + return nil, fmt.Errorf("list virtual keys: %w", err) + } + all = append(all, pageResp.Keys...) + + if page >= pageResp.TotalPages || len(pageResp.Keys) == 0 { + break + } + } + return all, nil +} diff --git a/scripts/litellm-to-bifrost/main.go b/scripts/litellm-to-bifrost/main.go new file mode 100644 index 0000000000..bd94a9fb7d --- /dev/null +++ b/scripts/litellm-to-bifrost/main.go @@ -0,0 +1,624 @@ +package main + +import ( + "context" + "fmt" + "log" + "os" + "sort" + "strings" + "time" + + "github.com/maximhq/bifrost/scripts/litellm-to-bifrost/litellm" +) + +// main parses CLI flags and runs all LiteLLM entity migrations in order. 
+func main() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + cfg := NewMigrationRunConfig() + var migrateErrors []error + if err := runModels(ctx, cfg); err != nil { + log.Printf("ERROR model migration: %v", err) + migrateErrors = append(migrateErrors, err) + } + if err := runOrganizations(ctx, cfg); err != nil { + log.Printf("ERROR organization migration: %v", err) + migrateErrors = append(migrateErrors, err) + } + if err := runTeams(ctx, cfg); err != nil { + log.Printf("ERROR team migration: %v", err) + migrateErrors = append(migrateErrors, err) + } + if err := runUsers(ctx, cfg); err != nil { + log.Printf("ERROR user migration: %v", err) + migrateErrors = append(migrateErrors, err) + } + if err := runVKs(ctx, cfg); err != nil { + log.Printf("ERROR virtual key migration: %v", err) + migrateErrors = append(migrateErrors, err) + } + if len(migrateErrors) > 0 { + log.Fatalf("migration completed with %d error(s); see above", len(migrateErrors)) + } +} + +// MigrationRunConfig carries all CLI/runtime state shared needed for migration. 
+type MigrationRunConfig struct { + LiteLLMClient *litellm.LiteLLMClient + BifrostClient *BifrostClient + LiteLLMConfigPath string + LiteLLMDBURL string + LiteLLMSaltKey string + DefaultProvider string + MaxBudgetPeriod string + DryRun bool +} + +func requireEnv(key string) string { + if val, ok := os.LookupEnv(key); ok { + return val + } + log.Fatalf("%s is required.", key) + return "" +} + +func NewMigrationRunConfig() MigrationRunConfig { + litellmMasterKey := requireEnv("LITELLM_MASTER_KEY") + + litellmSaltKey, ok := os.LookupEnv("LITELLM_SALT_KEY") + if !ok { + litellmSaltKey = litellmMasterKey + } + + defaultProvider, ok := os.LookupEnv("DEFAULT_PROVIDER") + if !ok { + defaultProvider = litellm.DefaultModelProvider + } + maxBudgetPeriod, ok := os.LookupEnv("MAX_BUDGET_PERIOD") + if !ok { + maxBudgetPeriod = defaultMaxBudgetPeriod + } + + return MigrationRunConfig{ + LiteLLMClient: &litellm.LiteLLMClient{ + BaseURL: requireEnv("LITELLM_URL"), + APIKey: litellmMasterKey, + }, + BifrostClient: &BifrostClient{ + BaseURL: requireEnv("BIFROST_URL"), + APIKey: requireEnv("BIFROST_API_KEY"), + }, + LiteLLMConfigPath: requireEnv("LITELLM_CONFIG"), + LiteLLMDBURL: os.Getenv("LITELLM_DB_URL"), + LiteLLMSaltKey: litellmSaltKey, + DefaultProvider: defaultProvider, + MaxBudgetPeriod: maxBudgetPeriod, + DryRun: os.Getenv("DRY_RUN") == "1", + } +} + +// runOrganizations migrates every LiteLLM organization into a Bifrost customer. Per-org +// failures are logged and counted but do not abort the whole run, so one bad +// record never blocks the rest of the migration. 
+func runOrganizations(ctx context.Context, cfg MigrationRunConfig) error { + orgs, err := cfg.LiteLLMClient.ListOrganizations(ctx) + if err != nil { + return err + } + log.Printf("fetched %d organization(s) from LiteLLM", len(orgs)) + + var migrated, skipped, failed int + for _, org := range orgs { + customer, err := LiteLLMOrganizationToBifrostCustomer(org, cfg) + if err != nil { + log.Printf("SKIP org %q: %v", org.OrganizationID, err) + skipped++ + continue + } + + if cfg.DryRun { + log.Printf("DRY-RUN org %q -> customer %q (budgets=%d, rateLimit=%v)", + org.OrganizationID, customer.Name, len(customer.Budgets), customer.RateLimit != nil) + migrated++ + continue + } + + if err := cfg.BifrostClient.CreateCustomer(ctx, customer); err != nil { + log.Printf("FAIL org %q: %v", org.OrganizationID, err) + failed++ + continue + } + log.Printf("OK org %q -> customer %q", org.OrganizationID, customer.Name) + migrated++ + } + + log.Printf("done: %d migrated, %d skipped, %d failed", migrated, skipped, failed) + if failed > 0 { + return fmt.Errorf("%d organization(s) failed to migrate", failed) + } + return nil +} + +// runTeams migrates every LiteLLM team into a Bifrost team. A team inside an +// organization is linked to the migrated customer: team.organization_id -> the +// org's alias -> the Bifrost customer of the same name. When that customer +// cannot be resolved (org has no alias, or was not migrated yet), the team is +// created unlinked and a warning is logged rather than failing the run. +// Per-team failures are logged and counted but do not abort the whole run. 
+func runTeams(ctx context.Context, cfg MigrationRunConfig) error { + teams, err := cfg.LiteLLMClient.ListTeams(ctx) + if err != nil { + return err + } + orgs, err := cfg.LiteLLMClient.ListOrganizations(ctx) + if err != nil { + return err + } + log.Printf("fetched %d team(s) and %d organization(s) from LiteLLM", len(teams), len(orgs)) + + // org id -> alias, to map a team's organization onto its Bifrost customer. + aliasByOrgID := make(map[string]string, len(orgs)) + for _, o := range orgs { + aliasByOrgID[o.OrganizationID] = strings.TrimSpace(o.OrganizationAlias) + } + + var migrated, skipped, failed, unlinked int + for _, team := range teams { + // Resolve the customer link (read-only; safe in dry-run too). + var customerID *string + if team.OrganizationID != nil && *team.OrganizationID != "" { + alias := aliasByOrgID[*team.OrganizationID] + if alias == "" { + log.Printf("WARN team %q: organization %q not found / has no alias; creating unlinked", team.TeamAlias, *team.OrganizationID) + unlinked++ + } else if id, ok, err := cfg.BifrostClient.FindCustomerByName(ctx, alias); err != nil { + log.Printf("WARN team %q: resolving customer %q: %v; creating unlinked", team.TeamAlias, alias, err) + unlinked++ + } else if !ok { + log.Printf("WARN team %q: customer %q not migrated yet; creating unlinked", team.TeamAlias, alias) + unlinked++ + } else { + customerID = &id + } + } + + teamReq, err := LiteLLMTeamToBifrostTeam(team, customerID, cfg) + if err != nil { + log.Printf("SKIP team %q: %v", team.TeamID, err) + skipped++ + continue + } + + if cfg.DryRun { + link := "none" + if customerID != nil { + link = *customerID + } + log.Printf("DRY-RUN team %q -> team %q (customer=%s, budgets=%d, rateLimit=%v)", + team.TeamID, teamReq.Name, link, len(teamReq.Budgets), teamReq.RateLimit != nil) + migrated++ + continue + } + + if err := cfg.BifrostClient.CreateTeam(ctx, teamReq); err != nil { + log.Printf("FAIL team %q: %v", team.TeamID, err) + failed++ + continue + } + log.Printf("OK 
team %q -> team %q (customer=%v)", team.TeamID, teamReq.Name, customerID != nil) + migrated++ + } + + log.Printf("done: %d migrated, %d skipped, %d failed, %d unlinked", migrated, skipped, failed, unlinked) + if failed > 0 { + return fmt.Errorf("%d team(s) failed to migrate", failed) + } + return nil +} + +// runUsers migrates every LiteLLM internal user into a Bifrost user, then links +// each user to the Bifrost teams matching its LiteLLM team memberships: +// LiteLLM user.teams (team_ids) -> team alias -> Bifrost team of the same name. +// This assumes teams were migrated first. Users without an email are skipped +// (Bifrost requires one); unresolvable team links are warned and skipped. A +// user that already exists (duplicate email) is reused so memberships still +// link. Per-user failures are logged and counted but do not abort the run. +func runUsers(ctx context.Context, cfg MigrationRunConfig) error { + users, err := cfg.LiteLLMClient.ListUsers(ctx) + if err != nil { + return err + } + teams, err := cfg.LiteLLMClient.ListTeams(ctx) + if err != nil { + return err + } + log.Printf("fetched %d user(s) and %d team(s) from LiteLLM", len(users), len(teams)) + + // LiteLLM team_id -> alias, to map a user's memberships onto Bifrost teams. 
+ aliasByTeamID := make(map[string]string, len(teams)) + for _, tm := range teams { + aliasByTeamID[tm.TeamID] = strings.TrimSpace(tm.TeamAlias) + } + + plans, report := LiteLLMUsersToBifrostUsers(users) + printUserReport(plans, report) + + if cfg.DryRun { + log.Printf("dry-run: %d user(s) planned, no writes performed", len(plans)) + return nil + } + + var migrated, failed, links, linkSkipped int + for _, p := range plans { + userID, err := cfg.BifrostClient.CreateUser(ctx, &BifrostCreateUserRequest{Name: p.Name, Email: p.Email}) + if err != nil { + log.Printf("FAIL user %q: %v", p.Email, err) + failed++ + continue + } + migrated++ + log.Printf("OK user %q -> %q", p.Email, userID) + + for _, srcTeamID := range p.SourceTeamIDs { + alias := aliasByTeamID[srcTeamID] + if alias == "" { + log.Printf("WARN user %q: team %q not found / has no alias; skipping link", p.Email, srcTeamID) + linkSkipped++ + continue + } + teamID, ok, err := cfg.BifrostClient.FindTeamByName(ctx, alias) + if err != nil { + log.Printf("WARN user %q: resolving team %q: %v; skipping link", p.Email, alias, err) + linkSkipped++ + continue + } + if !ok { + log.Printf("WARN user %q: team %q not migrated yet; skipping link", p.Email, alias) + linkSkipped++ + continue + } + if err := cfg.BifrostClient.AddTeamMember(ctx, teamID, userID); err != nil { + log.Printf("FAIL link user %q -> team %q: %v", p.Email, alias, err) + failed++ + continue + } + links++ + log.Printf("OK link user %q -> team %q", p.Email, alias) + } + } + + log.Printf("done: %d user(s) migrated, %d link(s), %d failed, %d link(s) skipped", migrated, links, failed, linkSkipped) + if failed > 0 { + return fmt.Errorf("%d user/link write(s) failed", failed) + } + return nil +} + +// printUserReport logs the planned users (with their team links) and everything +// the transform could not carry. 
+func printUserReport(plans []UserPlan, r UserMigrationReport) { + for _, p := range plans { + log.Printf("PLAN user %q (%s) -> teams=%v", p.Name, p.Email, p.SourceTeamIDs) + } + logReportSection("skipped (no email; Bifrost requires one)", r.SkippedNoEmail) + logReportSection("dropped roles (no LiteLLM->Bifrost role mapping)", r.DroppedRoles) + logReportSection("dropped user budgets (no field on create-user; use access profiles)", r.DroppedBudgets) + logReportSection("dropped user rate limits (no field on create-user; use access profiles)", r.DroppedRateLimits) +} + +// runVKs migrates every LiteLLM virtual key into a Bifrost virtual key. It +// reads the LiteLLM config to build a model_name -> (provider, model) index, so +// each VK's allow-list (the union of the key's, its team's and its +// organization's allowed models — Bifrost gates models only on the VK) maps to +// Bifrost provider configs. Ownership resolves to a Bifrost team (preferred) or +// customer (mutually exclusive). LiteLLM user_id, expiry and the key value +// itself have no VK-create field and are reported. Per-key failures are logged +// and counted but do not abort the run. +func runVKs(ctx context.Context, cfg MigrationRunConfig) error { + models, err := cfg.LiteLLMClient.ListModelInfo(ctx) + if err != nil { + return err + } + litellmCfg, err := litellm.ReadLiteLLMConfig(cfg.LiteLLMConfigPath) + if err != nil { + return err + } + store, err := litellm.NewCredentialStore(ctx, cfg.LiteLLMDBURL, litellmCfg, cfg.LiteLLMSaltKey) + if err != nil { + return err + } + defer store.Close() + deployments, err := store.Deployments(ctx) + if err != nil { + return err + } + // Build the provider/key plans to derive which keys serve which models. + plans, _, _ := LiteLLMModelsToBifrostProviders(models, store.Named(), deployments, cfg) + + // Fetch key UUIDs from Bifrost so VKs can attach specific keys by ID. 
+ bifrostKeys, err := cfg.BifrostClient.ListAllKeys(ctx) + if err != nil { + return err + } + keyIDByName := make(map[string]string, len(bifrostKeys)) + for _, k := range bifrostKeys { + keyIDByName[k.Name] = k.KeyID + } + keyModelIdx, wildcardKeys := BuildKeyModelIndex(plans, keyIDByName) + + // Only providers that actually exist in Bifrost can be referenced by a VK + // (a create rejects unknown providers). For "all proxy models" VKs, expand + // to ALL providers in Bifrost — not just those found in the model index — + // so new providers added between migration runs are also covered. + bifrostProviders, err := cfg.BifrostClient.ListProviders(ctx) + if err != nil { + return err + } + allProviders := make([]string, 0, len(bifrostProviders)) + for p := range bifrostProviders { + allProviders = append(allProviders, p) + } + sort.Strings(allProviders) + + keys, err := cfg.LiteLLMClient.ListVirtualKeys(ctx) + if err != nil { + return err + } + teams, err := cfg.LiteLLMClient.ListTeams(ctx) + if err != nil { + return err + } + orgs, err := cfg.LiteLLMClient.ListOrganizations(ctx) + if err != nil { + return err + } + log.Printf("fetched %d virtual key(s), %d team(s), %d org(s), %d model deployment(s); key index has %d name(s), %d wildcard key(s), %d provider(s)", + len(keys), len(teams), len(orgs), len(models), len(keyModelIdx), len(wildcardKeys), len(allProviders)) + + // LiteLLM id -> (alias, allowed models, org id) for owner + model folding. 
+ type entity struct { + alias string + models []string + orgID string // for teams: the owning org, so VK inherits org model restrictions + } + teamByID := make(map[string]entity, len(teams)) + for _, t := range teams { + orgID := "" + if t.OrganizationID != nil { + orgID = *t.OrganizationID + } + teamByID[t.TeamID] = entity{alias: strings.TrimSpace(t.TeamAlias), models: t.Models, orgID: orgID} + } + orgByID := make(map[string]entity, len(orgs)) + for _, o := range orgs { + orgByID[o.OrganizationID] = entity{alias: strings.TrimSpace(o.OrganizationAlias), models: o.Models} + } + + var migrated, failed, skipped, unlinked int + var unmappedModels, droppedUserLinks, droppedExpiries, droppedProviders []string + + for _, key := range keys { + var teamModels, orgModels []string + if key.TeamID != nil { + te := teamByID[*key.TeamID] + teamModels = te.models + // Fold in the team's org model restrictions: Bifrost has no org-level + // model gates, so any restriction the org imposes must live on the VK. + if te.orgID != "" { + orgModels = orgByID[te.orgID].models + } + } + if key.OrgID != nil { + orgModels = orgByID[*key.OrgID].models + } + + plan, err := LiteLLMVirtualKeyToBifrost(key, teamModels, orgModels, keyModelIdx, wildcardKeys, allProviders, cfg.MaxBudgetPeriod) + if err != nil { + log.Printf("SKIP virtual key: %v", err) + skipped++ + continue + } + + req := &BifrostCreateVirtualKeyRequest{Name: plan.Name, IsActive: plan.IsActive, RateLimit: plan.RateLimit} + for _, pc := range plan.ProviderConfigs { + // A VK create rejects providers that do not exist in Bifrost (e.g. + // one whose migration failed). Drop and report them. 
+ if !bifrostProviders[pc.Provider] { + droppedProviders = append(droppedProviders, fmt.Sprintf("%s: %s", plan.Name, pc.Provider)) + continue + } + req.ProviderConfigs = append(req.ProviderConfigs, BifrostVKProviderConfigRequest{ + Provider: pc.Provider, + KeyIDs: pc.KeyIDs, + AllowedModels: []string{"*"}, + }) + } + if plan.Budget != nil { + req.Budgets = []BifrostCreateBudgetRequest{*plan.Budget} + } + + // Resolve owner -> Bifrost team (preferred) or customer. + owner := "none" + switch { + case plan.OwnerTeamID != nil: + alias := teamByID[*plan.OwnerTeamID].alias + if id, ok := resolveOwner(ctx, cfg.BifrostClient.FindTeamByName, alias); ok { + req.TeamID = &id + owner = "team:" + alias + } else { + log.Printf("WARN vkey %q: team %q not resolved; creating unlinked", plan.Name, alias) + unlinked++ + } + case plan.OwnerOrgID != nil: + alias := orgByID[*plan.OwnerOrgID].alias + if id, ok := resolveOwner(ctx, cfg.BifrostClient.FindCustomerByName, alias); ok { + req.CustomerID = &id + owner = "customer:" + alias + } else { + log.Printf("WARN vkey %q: customer %q not resolved; creating unlinked", plan.Name, alias) + unlinked++ + } + } + + if plan.SourceUserID != nil && *plan.SourceUserID != "" { + droppedUserLinks = append(droppedUserLinks, fmt.Sprintf("%s (user %s)", plan.Name, *plan.SourceUserID)) + } + if plan.Expiry != nil { + droppedExpiries = append(droppedExpiries, fmt.Sprintf("%s (%s)", plan.Name, *plan.Expiry)) + } + for _, m := range plan.UnmappedModels { + unmappedModels = append(unmappedModels, fmt.Sprintf("%s: %s", plan.Name, m)) + } + + if cfg.DryRun { + log.Printf("DRY-RUN vkey %q owner=%s providers=%d budgets=%d rateLimit=%v active=%v", + plan.Name, owner, len(req.ProviderConfigs), len(req.Budgets), req.RateLimit != nil, req.IsActive) + migrated++ + continue + } + + if err := cfg.BifrostClient.CreateVirtualKey(ctx, req); err != nil { + log.Printf("FAIL vkey %q: %v", plan.Name, err) + failed++ + continue + } + migrated++ + log.Printf("OK vkey %q 
(owner=%s, providers=%d)", plan.Name, owner, len(req.ProviderConfigs)) + } + + logReportSection("unmapped models (not in config; not granted)", unmappedModels) + logReportSection("dropped providers (not migrated to Bifrost)", droppedProviders) + logReportSection("dropped user links (no VK-create field; user gets VKs via access profiles)", droppedUserLinks) + logReportSection("dropped expiries (no VK-create field)", droppedExpiries) + if len(keys) > 0 { + log.Printf("NOTE virtual key values are server-generated; original LiteLLM tokens are NOT carried over") + } + log.Printf("done: %d migrated, %d skipped, %d failed, %d unlinked", migrated, skipped, failed, unlinked) + if failed > 0 { + return fmt.Errorf("%d virtual key(s) failed to migrate", failed) + } + return nil +} + +// resolveOwner looks up a Bifrost owner id by alias using the given finder, +// returning ok=false when the alias is blank or no match is found. +func resolveOwner(ctx context.Context, find func(context.Context, string) (string, bool, error), alias string) (string, bool) { + if alias == "" { + return "", false + } + id, ok, err := find(ctx, alias) + if err != nil || !ok { + return "", false + } + return id, true +} + +// runModels reads a LiteLLM proxy config, transforms its model deployments into +// Bifrost providers, keys and global model configs +func runModels(ctx context.Context, cfg MigrationRunConfig) error { + models, err := cfg.LiteLLMClient.ListModelInfo(ctx) + if err != nil { + return err + } + + // The management API masks secrets, so resolve real credential values from + // the config (plaintext) and the LiteLLM database (decrypted with the salt + // key): named credentials by name, inline ones per deployment. 
+ litellmCfg, err := litellm.ReadLiteLLMConfig(cfg.LiteLLMConfigPath) + if err != nil { + return err + } + store, err := litellm.NewCredentialStore(ctx, cfg.LiteLLMDBURL, litellmCfg, cfg.LiteLLMSaltKey) + if err != nil { + return err + } + defer store.Close() + deployments, err := store.Deployments(ctx) + if err != nil { + return err + } + log.Printf("fetched %d model deployment(s) from LiteLLM; resolved %d named credential(s) and %d deployment param set(s)", len(models), len(store.Named()), len(deployments)) + + plans, modelConfigs, skippedProviders := LiteLLMModelsToBifrostProviders(models, store.Named(), deployments, cfg) + printModelReport(plans, modelConfigs, skippedProviders) + + if cfg.DryRun { + log.Printf("dry-run: %d provider(s), %d model config(s) planned, no writes performed", len(plans), len(modelConfigs)) + return nil + } + + var providersOK, keysOK, modelConfigsOK, failed int + for _, p := range plans { + if err := cfg.BifrostClient.EnsureProvider(ctx, p); err != nil { + log.Printf("FAIL provider %q: %v", p.Name, err) + failed++ + continue + } + providersOK++ + log.Printf("OK provider %q (custom=%v, base_url=%q)", p.Name, p.IsCustom, p.BaseURL) + for _, k := range p.Keys { + if err := cfg.BifrostClient.CreateProviderKey(ctx, p.Name, k); err != nil { + log.Printf("FAIL key %q/%q: %v", p.Name, k.Name, err) + failed++ + continue + } + keysOK++ + log.Printf("OK key %q/%q (models=%v)", p.Name, k.Name, k.Models) + } + } + for _, mc := range modelConfigs { + if err := cfg.BifrostClient.CreateModelConfig(ctx, mc); err != nil { + log.Printf("FAIL model config %q: %v", modelConfigSignature(mc), err) + failed++ + continue + } + modelConfigsOK++ + log.Printf("OK model config %q (budgets=%d, rateLimit=%v)", modelConfigSignature(mc), len(mc.Budgets), mc.RateLimit != nil) + } + + log.Printf("done: %d provider(s), %d key(s), %d model config(s) written, %d failed", providersOK, keysOK, modelConfigsOK, failed) + if failed > 0 { + return fmt.Errorf("%d model migration 
write(s) failed", failed) + } + return nil +} + +// printModelReport logs the planned providers/keys and everything the transform +// could not faithfully carry, so the operator can act on it. +func printModelReport(plans []ProviderPlan, modelConfigs []ModelConfigPlan, skippedProviders []SkippedProvider) { + for _, p := range plans { + log.Printf("PLAN provider %q (custom=%v, base_url=%q): %d key(s)", p.Name, p.IsCustom, p.BaseURL, len(p.Keys)) + for _, k := range p.Keys { + log.Printf(" key %q models=%v", k.Name, k.Models) + } + } + for _, mc := range modelConfigs { + provider := "*" + if mc.Provider != nil { + provider = *mc.Provider + } + log.Printf("PLAN model config source=%q provider=%q model=%q budgets=%d rateLimit=%v", mc.SourceName, provider, mc.ModelName, len(mc.Budgets), mc.RateLimit != nil) + } + var skippedProvidersStr []string + for _, sp := range skippedProviders { + if sp.Reason != "" { + skippedProvidersStr = append(skippedProvidersStr, fmt.Sprintf("%s (%s)", sp.Provider, sp.Reason)) + } else { + skippedProvidersStr = append(skippedProvidersStr, sp.Provider) + } + } + logReportSection("skipped providers", skippedProvidersStr) +} + +// logReportSection prints a titled migration report section when it has items. 
+func logReportSection(title string, items []string) { + if len(items) == 0 { + return + } + log.Printf("REPORT %s (%d):", title, len(items)) + for _, it := range items { + log.Printf(" - %s", it) + } +} diff --git a/scripts/litellm-to-bifrost/model.go b/scripts/litellm-to-bifrost/model.go new file mode 100644 index 0000000000..e7e7bafeff --- /dev/null +++ b/scripts/litellm-to-bifrost/model.go @@ -0,0 +1,1031 @@ +package main + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "net/url" + "regexp" + "sort" + "strings" + + "github.com/maximhq/bifrost/scripts/litellm-to-bifrost/litellm" +) + +// Secrets are redacted by the LiteLLM management API (/credentials masks the +// api_key, /model/info omits it entirely), so the model migration reads the +// config file directly. It carries the unredacted credential references like +// `os.environ/FOO` env refs and literal keys that Bifrost needs. + +// ProviderPlan is one Bifrost provider to ensure, with the keys to add under +// it. Name is the management-API provider key (a standard provider name, or a +// generated name for a custom provider). +type ProviderPlan struct { + Name string + IsCustom bool // true => POST /api/providers with CustomProviderConfig + BaseProvider string // for custom providers: base_provider_type + BaseURL string // network_config.base_url (LiteLLM api_base) + Keys []KeyPlan +} + +// KeyPlan is one Bifrost key under a provider. Value is the wire string Bifrost +// resolves: "env.FOO" => from environment, anything else => literal value. +type KeyPlan struct { + Name string + Value string + Models []string // allowlist; ["*"] when the deployment serves all models + // URL and VLLMModelName carry per-key routing for the keyless self-hosted + // providers (vllm/ollama): these are standard Bifrost providers whose + // server URL lives on each key, so distinct base URLs become distinct keys. + // URL set => emit the provider-matching *_key_config. 
VLLMModelName sets + // vllm_key_config.model_name (vllm selects a key by the exact served model). + URL string + VLLMModelName string + // Provider-specific structured credentials (Azure/Bedrock/Vertex). + // Exactly one of these is set when the provider requires it. + AzureKeyConfig *BifrostAzureKeyConfig + BedrockKeyConfig *BifrostBedrockKeyConfig + VertexKeyConfig *BifrostVertexKeyConfig +} + +// ModelConfigPlan is one global Bifrost model config created from LiteLLM +// deployment-level budgets and rate limits. +type ModelConfigPlan struct { + SourceName string + ModelName string + Provider *string + Budgets []BifrostCreateBudgetRequest + RateLimit *BifrostCreateRateLimitRequest +} + +type SkippedProvider struct { + Provider string + Reason string +} + +type modelMigrationInput struct { + Params *litellm.LiteLLMModelInfoLiteLLMParams + Deployment *litellm.Deployment + RawModel string + APIKey string + APIBase string + CredName string + BaseProvider string + // Azure structured credentials + AzureClientID string + AzureClientSecret string + AzureTenantID string + AzureADToken string + // Bedrock (AWS) structured credentials + AWSAccessKeyID string + AWSSecretAccessKey string + AWSRegionName string + AWSRoleName string // IAM role ARN for STS AssumeRole + AWSSessionName string + AWSSessionToken string + // Vertex (GCP) structured credentials + VertexProject string + VertexLocation string + VertexCredentials string // service-account JSON or ADC ref; empty = ADC +} + +// liteLLMProviderAliases maps LiteLLM provider prefix strings that differ from +// Bifrost standard provider names. Applied in deriveProvider after lowercasing. 
var liteLLMProviderAliases = map[string]string{
	"vertex_ai":              "vertex",
	"vertex_ai_beta":         "vertex",
	"hosted_vllm":            "vllm",
	"ollama_chat":            "ollama",
	"cohere_chat":            "cohere",
	"text-completion-openai": "openai",
	"azure_ai":               "azure",
	"fireworks_ai":           "fireworks",
}

// Provider membership tables, used as sets (a missing key reads as false).
var (
	// specialCredProviders lists the Bifrost standard providers that take
	// structured credentials rather than a plain api_key: Azure (per-key
	// endpoint plus optional Entra ID), Bedrock (AWS IAM credentials) and
	// Vertex (GCP credentials). LiteLLMModelsToBifrostProviders intercepts
	// them before the standard api_key path.
	specialCredProviders = map[string]bool{
		"azure":   true,
		"bedrock": true,
		"vertex":  true,
	}

	// standardProviders is the set of Bifrost standard provider names (from
	// schemas.StandardProviders). A LiteLLM provider that is not listed here
	// is either added as a custom provider or skipped and reported.
	standardProviders = map[string]bool{
		"anthropic":   true,
		"azure":       true,
		"bedrock":     true,
		"cerebras":    true,
		"cohere":      true,
		"elevenlabs":  true,
		"fireworks":   true,
		"gemini":      true,
		"groq":        true,
		"huggingface": true,
		"mistral":     true,
		"nebius":      true,
		"ollama":      true,
		"openai":      true,
		"openrouter":  true,
		"parasail":    true,
		"perplexity":  true,
		"replicate":   true,
		"runway":      true,
		"vertex":      true,
		"vllm":        true,
		"xai":         true,
	}
)

// nonAlphanumRe matches one or more characters outside lowercase ASCII letters
// and digits.
var nonAlphanumRe = regexp.MustCompile(`[^a-z0-9]+`)

// maxProviderNameLen is the Bifrost provider-name column limit (varchar(50)).
const maxProviderNameLen = 50

// customProviders is the subset of Bifrost base providers we allow to back a
// custom provider (a deployment whose credential sets a non-standard api_base).
// A deployment whose base provider is outside this set is skipped.
+var customProviders = map[string]bool{ + "openai": true, + "anthropic": true, + "gemini": true, + "bedrock": true, +} + +// selfHostedURLProviders are keyless standard Bifrost providers that route per +// key: each carries the server URL in its *_key_config rather than at the +// provider level, so a deployment's api_base becomes a per-key URL instead of a +// custom provider. vllm additionally selects a key by the exact served model. +var selfHostedURLProviders = map[string]bool{ + "vllm": true, + "ollama": true, +} + +// LiteLLMModelsToBifrostProviders transforms LiteLLM model deployments into +// Bifrost providers, keys and global model configs. +// +// credByName holds the resolved (decrypted) named credentials keyed by CredentialValues +// +// deployments holds the resolved per-deployment litellm_params (inline credential + budget) +// keyed by model_id (database) and, as a fallback, model_name (config). +func LiteLLMModelsToBifrostProviders(models []litellm.LiteLLMModelInfo, credByName map[string]*litellm.LiteLLMModelCredential, deployments map[string]*litellm.Deployment, cfg MigrationRunConfig) ([]ProviderPlan, []ModelConfigPlan, []SkippedProvider) { + // keyAgg accumulates one Bifrost key's allowlist across the deployments that + // share its credential. wildcard records a "*" deployment, which collapses + // the allowlist to ["*"]. + type keyAgg struct { + plan KeyPlan + modelSet map[string]bool + wildcard bool + } + // provAgg accumulates one Bifrost provider and its keys, keyed by credential + // signature so deployments sharing a credential fold into one key. 
+ type provAgg struct { + plan ProviderPlan + keys map[string]*keyAgg + keyOrder []string + } + + provByName := map[string]*provAgg{} + var provOrder []string + var skippedProviders []SkippedProvider + skippedProviderSet := map[string]bool{} + + var modelConfigs []ModelConfigPlan + mcIndex := map[string]int{} // modelConfigSignature -> index into modelConfigs + + // addModelConfig folds one deployment's tpm/rpm overrides into the global + // model config for its (provider, model), keeping the strictest limits. + addModelConfig := func(sourceName, rawModel string, b litellm.LiteLLMBudget, base, provider string) { + mc, _, err := modelConfigPlan(sourceName, rawModel, b, base, provider, cfg.MaxBudgetPeriod) + if err != nil || mc == nil { + return + } + sig := modelConfigSignature(*mc) + if i, ok := mcIndex[sig]; ok { + mergeModelConfig(&modelConfigs[i], *mc) + return + } + mcIndex[sig] = len(modelConfigs) + modelConfigs = append(modelConfigs, *mc) + } + + addSkippedProvider := func(provider, reason string) { + provider = strings.TrimSpace(provider) + if provider == "" { + provider = "unknown" + } + sig := provider + "\x00" + reason + if skippedProviderSet[sig] { + return + } + skippedProviderSet[sig] = true + skippedProviders = append(skippedProviders, SkippedProvider{Provider: provider, Reason: reason}) + } + + // ensureProv returns the accumulator for a provider, creating it on first + // use. Custom providers carry their base provider + URL; standard providers + // (including the self-hosted keyless ones) carry neither. 
+ ensureProv := func(name string, isCustom bool, baseProvider, baseURL string) *provAgg { + pa := provByName[name] + if pa == nil { + plan := ProviderPlan{Name: name, IsCustom: isCustom} + if isCustom { + plan.BaseProvider = baseProvider + plan.BaseURL = baseURL + } + pa = &provAgg{plan: plan, keys: map[string]*keyAgg{}} + provByName[name] = pa + provOrder = append(provOrder, name) + } + return pa + } + + for _, m := range models { + in, ok := resolveModelMigrationInput(m, credByName, deployments, cfg) + if !ok { + continue + } + p := in.Params + rawModel := in.RawModel + dep := in.Deployment + apiKey := in.APIKey + apiBase := in.APIBase + credName := in.CredName + base := in.BaseProvider + + tpm, rpm := modelInfoRateLimits(m) + // Rate limits come from /model/info (unencrypted); the spend budget comes + // from the deployment's litellm_params (config plaintext / DB decrypted). + budget := litellm.LiteLLMBudget{TPMLimit: tpm, RPMLimit: rpm} + if dep != nil { + budget.MaxBudget = dep.MaxBudget + budget.BudgetDuration = dep.BudgetDuration + } + + // Full wildcard "*" routes to every provider: emit only a global, + // all-providers model config, never a concrete provider/key. + if rawModel == "*" { + addModelConfig(m.ModelName, rawModel, budget, base, base) + continue + } + + // Unsupported / unresolvable provider (incl. "*/..." provider globs): + // nothing to create. + if base == "" || !standardProviders[base] { + addSkippedProvider(base, fmt.Sprintf("unsupported or unresolvable provider for model %q", rawModel)) + continue + } + + // Self-hosted keyless providers (vllm/ollama): a standard provider + // whose server URL lives per key, so distinct base URLs become distinct + // keys. The model string is the full upstream id; only a leading + // "provider/" segment is stripped. 
+ if selfHostedURLProviders[base] { + if apiBase == "" { + addSkippedProvider(base, fmt.Sprintf("missing api_base for self-hosted model %q", rawModel)) + continue // no server URL to route to + } + served := stripProviderPrefix(rawModel, base) + if served != "*" && strings.ContainsAny(served, "*?") { + addSkippedProvider(base, fmt.Sprintf("partial wildcard model %q is not supported", rawModel)) + continue + } + pa := ensureProv(base, false, "", "") + // vllm selects a key by exact served model, so each (URL, model) is + // its own key; ollama groups by URL and unions its models. + keySig := apiBase + if base == "vllm" { + keySig = apiBase + "\x00" + served + } + ka := pa.keys[keySig] + if ka == nil { + value, _ := keyValue(apiKey) + kp := KeyPlan{Name: base + "/" + selfHostedKeyName(base, apiBase, len(pa.keyOrder)), Value: value, URL: apiBase} + if base == "vllm" && served != "*" { + kp.VLLMModelName = served + } + ka = &keyAgg{plan: kp, modelSet: map[string]bool{}} + pa.keys[keySig] = ka + pa.keyOrder = append(pa.keyOrder, keySig) + } + if served == "*" { + ka.wildcard = true + } else { + ka.modelSet[served] = true + } + addModelConfig(m.ModelName, rawModel, budget, base, base) + continue + } + + // Azure, Bedrock, Vertex: per-key structured credentials. + // Azure always intercepted here (api_base is per-key endpoint, not a + // custom-provider URL). Bedrock/Vertex intercepted when no api_key is + // set (IAM / GCP auth); Bedrock API-key auth falls through to the + // standard path below. 
+ if base == "azure" || ((base == "bedrock" || base == "vertex") && apiKey == "") { + model := trimModelPrefix(rawModel) + if model == "" || (model != "*" && strings.ContainsAny(model, "*?")) { + addSkippedProvider(base, fmt.Sprintf("partial wildcard model %q is not supported", rawModel)) + continue + } + + pa := ensureProv(base, false, "", "") + + var credSig string + var kp KeyPlan + + switch base { + case "azure": + endpoint, _ := keyValue(in.APIBase) + if endpoint == "" { + addSkippedProvider("azure", fmt.Sprintf("no api_base for azure model %q; endpoint required", rawModel)) + continue + } + apiKeyVal, _ := keyValue(apiKey) + if apiKeyVal == "" && in.AzureClientID == "" { + addSkippedProvider("azure", fmt.Sprintf("no api_key or azure_client_id for azure model %q", rawModel)) + continue + } + credSig = "azure:" + apiKey + ":" + in.APIBase + ":" + in.AzureClientID + azCfg := &BifrostAzureKeyConfig{Endpoint: endpoint} + if in.AzureClientID != "" { + azCfg.ClientID = ptr(in.AzureClientID) + if in.AzureClientSecret != "" { + azCfg.ClientSecret = ptr(in.AzureClientSecret) + } + if in.AzureTenantID != "" { + azCfg.TenantID = ptr(in.AzureTenantID) + } + } + kp = KeyPlan{Name: "azure/" + azureKeyName(in.APIBase, len(pa.keyOrder)), Value: apiKeyVal, AzureKeyConfig: azCfg} + + case "bedrock": + credSig = "bedrock:" + in.AWSAccessKeyID + ":" + in.AWSRoleName + ":" + in.AWSRegionName + bCfg := &BifrostBedrockKeyConfig{} + if in.AWSAccessKeyID != "" { + bCfg.AccessKey = in.AWSAccessKeyID + bCfg.SecretKey = in.AWSSecretAccessKey + } + if in.AWSSessionToken != "" { + bCfg.SessionToken = ptr(in.AWSSessionToken) + } + if in.AWSRegionName != "" { + bCfg.Region = ptr(in.AWSRegionName) + } + if in.AWSRoleName != "" { + bCfg.RoleARN = ptr(in.AWSRoleName) + if in.AWSSessionName != "" { + bCfg.RoleSessionName = ptr(in.AWSSessionName) + } + } + kp = KeyPlan{Name: "bedrock/" + bedrockKeyName(in, len(pa.keyOrder)), Value: "", BedrockKeyConfig: bCfg} + + case "vertex": + if 
in.VertexProject == "" { + addSkippedProvider("vertex", fmt.Sprintf("no vertex_project for vertex model %q", rawModel)) + continue + } + credSig = "vertex:" + in.VertexProject + ":" + in.VertexLocation + ":" + in.VertexCredentials + vCfg := &BifrostVertexKeyConfig{ + ProjectID: in.VertexProject, + Region: in.VertexLocation, + AuthCredentials: in.VertexCredentials, + } + kp = KeyPlan{Name: "vertex/" + vertexKeyName(in, len(pa.keyOrder)), Value: "", VertexKeyConfig: vCfg} + } + + ka := pa.keys[credSig] + if ka == nil { + ka = &keyAgg{plan: kp, modelSet: map[string]bool{}} + pa.keys[credSig] = ka + pa.keyOrder = append(pa.keyOrder, credSig) + } + if model == "*" { + ka.wildcard = true + } else { + ka.modelSet[model] = true + } + addModelConfig(m.ModelName, rawModel, budget, base, base) + continue + } + + model := trimModelPrefix(p.Model) + // Partial wildcard like openai/gpt-5*: not representable in Bifrost; ignore. + if model == "" || (model != "*" && strings.ContainsAny(model, "*?")) { + addSkippedProvider(base, fmt.Sprintf("partial wildcard model %q is not supported", rawModel)) + continue + } + + // A credential with a custom base URL becomes a Bifrost custom provider, + // but only for base providers Bifrost can wrap; others are ignored. + isCustom := apiBase != "" + provName := base + if isCustom { + if !customProviders[base] { + addSkippedProvider(base, fmt.Sprintf("custom api_base is not supported for model %q", rawModel)) + continue + } + provName = customProviderName(base, apiBase) + } + + pa := ensureProv(provName, isCustom, base, apiBase) + + // Deployments sharing a credential collapse into one key whose allowlist + // is the union of their models. Keyless inline deployments group by value. + credSig := credName + if credSig == "" { + credSig = "inline:" + apiKey + } + ka := pa.keys[credSig] + if ka == nil { + value, _ := keyValue(apiKey) + if value == "" { + // No credential resolved — Bifrost rejects keys with empty values. 
+ // Still emit a model config (rate limits) but skip key creation. + addSkippedProvider(provName, fmt.Sprintf("no credential for model %q; skipping key", rawModel)) + addModelConfig(m.ModelName, rawModel, budget, base, provName) + continue + } + name := credName + if name == "" { + name = keyName(base, apiKey, len(pa.keyOrder)) + } + ka = &keyAgg{plan: KeyPlan{Name: provName + "/" + name, Value: value}, modelSet: map[string]bool{}} + pa.keys[credSig] = ka + pa.keyOrder = append(pa.keyOrder, credSig) + } + if model == "*" { + ka.wildcard = true + } else { + ka.modelSet[model] = true + } + + addModelConfig(m.ModelName, rawModel, budget, base, provName) + } + + plans := make([]ProviderPlan, 0, len(provOrder)) + for _, name := range provOrder { + pa := provByName[name] + plan := pa.plan + for _, sig := range pa.keyOrder { + ka := pa.keys[sig] + if ka.wildcard { + ka.plan.Models = []string{"*"} + } else { + allowed := make([]string, 0, len(ka.modelSet)) + for mdl := range ka.modelSet { + allowed = append(allowed, mdl) + } + sort.Strings(allowed) + ka.plan.Models = allowed + } + plan.Keys = append(plan.Keys, ka.plan) + } + plans = append(plans, plan) + } + + return plans, modelConfigs, skippedProviders +} + +// modelInfoRateLimits returns the deployment's tpm/rpm overrides. The model_info +// block is the base; the litellm_params block carries the deployment-level +// limits and overrides it. +func modelInfoRateLimits(m litellm.LiteLLMModelInfo) (tpm, rpm *int64) { + if m.ModelInfo != nil { + tpm = m.ModelInfo.Tpm + rpm = m.ModelInfo.Rpm + } + if m.LiteLLMParams != nil { + if m.LiteLLMParams.Tpm != nil { + tpm = m.LiteLLMParams.Tpm + } + if m.LiteLLMParams.Rpm != nil { + rpm = m.LiteLLMParams.Rpm + } + } + return tpm, rpm +} + +// ProviderKeyRef is a (provider, key-UUID) pair used to attach specific Bifrost +// provider keys to a virtual key. 
+type ProviderKeyRef struct { + Provider string + KeyID string +} + +// BuildKeyModelIndex builds a model-name → []ProviderKeyRef index from the +// provider/key plans that were migrated and the UUIDs returned by Bifrost. +// A key whose Models list is ["*"] (wildcard) is reachable by every model name; +// such keys are collected in wildcardKeys and merged into every lookup result. +// keyIDByName maps full key names (e.g. "openai/OPENAI_KEY") to Bifrost UUIDs. +func BuildKeyModelIndex(plans []ProviderPlan, keyIDByName map[string]string) (map[string][]ProviderKeyRef, []ProviderKeyRef) { + specific := map[string][]ProviderKeyRef{} // model → refs for keys with explicit model lists + var wildcardKeys []ProviderKeyRef // keys that serve all models + + addRef := func(model, provider, keyID string) { + ref := ProviderKeyRef{Provider: provider, KeyID: keyID} + specific[model] = append(specific[model], ref) + } + + for _, p := range plans { + for _, k := range p.Keys { + keyID, ok := keyIDByName[k.Name] + if !ok { + continue // key was not successfully created in Bifrost; skip + } + isWildcard := len(k.Models) == 1 && k.Models[0] == "*" + if isWildcard { + wildcardKeys = append(wildcardKeys, ProviderKeyRef{Provider: p.Name, KeyID: keyID}) + } else { + for _, m := range k.Models { + addRef(m, p.Name, keyID) + } + } + } + } + return specific, wildcardKeys +} + +// ModelRef is a Bifrost (provider, model) pair that a LiteLLM public model_name +// resolves to. +type ModelRef struct { + Provider string + Model string +} + +// BuildModelIndex maps each LiteLLM public model_name to the Bifrost +// (provider, model) pairs it resolves to, and returns the sorted set of target +// providers. It also indexes by the upstream model name (trimmed from its +// provider prefix), so a VK with models:["gpt-4o"] resolves even when the +// deployment model_name is "openai-gpt-4o". model_group_alias entries from the +// config (e.g. "fast-cheap" -> "openai-gpt-4o") are followed transitively. 
+func BuildModelIndex(models []litellm.LiteLLMModelInfo, credByName map[string]*litellm.LiteLLMModelCredential, deployments map[string]*litellm.Deployment, cfg MigrationRunConfig, litellmCfg *litellm.LiteLLMConfig) (map[string][]ModelRef, []string) { + index := map[string][]ModelRef{} + seen := map[string]map[string]bool{} // source_name -> "provider|model" set + providerSet := map[string]bool{} + + addRef := func(sourceName, provider, model string) { + if seen[sourceName] == nil { + seen[sourceName] = map[string]bool{} + } + sig := provider + "|" + model + if seen[sourceName][sig] { + return + } + seen[sourceName][sig] = true + index[sourceName] = append(index[sourceName], ModelRef{Provider: provider, Model: model}) + if provider != "*" { + providerSet[provider] = true + } + } + + for _, m := range models { + in, ok := resolveModelMigrationInput(m, credByName, deployments, cfg) + if !ok { + continue + } + base := in.BaseProvider + rawModel := in.RawModel + + if rawModel == "*" { + addRef(m.ModelName, "*", "*") + continue + } + if base == "" || !standardProviders[base] { + continue + } + + if selfHostedURLProviders[base] { + if in.APIBase == "" { + continue + } + model := stripProviderPrefix(rawModel, base) + if model != "*" && strings.ContainsAny(model, "*?") { + continue + } + addRef(m.ModelName, base, model) + continue + } + + model := trimModelPrefix(rawModel) + if model == "" || (model != "*" && strings.ContainsAny(model, "*?")) { + continue + } + + provider := base + if in.APIBase != "" { + if !customProviders[base] { + continue + } + provider = customProviderName(base, in.APIBase) + } + addRef(m.ModelName, provider, model) + + // Also index by the upstream model name (e.g. "gpt-4o") so VK + // allowed_model lists that reference the actual model name — rather than + // the LiteLLM deployment model_name — resolve to the right provider. 
+ if model != "*" && model != m.ModelName { + addRef(model, provider, model) + } + } + + // Follow model_group_alias entries: alias -> deployment model_name. + // e.g. "fast-cheap" -> "openai-gpt-4o" -> refs already in index. + if litellmCfg != nil { + for alias, target := range litellmCfg.RouterSettings.ModelGroupAlias { + if refs, ok := index[target]; ok { + for _, r := range refs { + addRef(alias, r.Provider, r.Model) + } + } + } + } + + var providers []string + for p := range providerSet { + providers = append(providers, p) + } + sort.Strings(providers) + return index, providers +} + +func resolveModelMigrationInput(m litellm.LiteLLMModelInfo, credByName map[string]*litellm.LiteLLMModelCredential, deployments map[string]*litellm.Deployment, cfg MigrationRunConfig) (modelMigrationInput, bool) { + if m.LiteLLMParams == nil { + return modelMigrationInput{}, false + } + p := m.LiteLLMParams + rawModel := strings.TrimSpace(p.Model) + + var dep *litellm.Deployment + if m.ModelInfo != nil { + dep = deployments[m.ModelInfo.Id] + } + if dep == nil { + dep = deployments[m.ModelName] + } + + var apiKey, apiBase, credProvider, credName string + var cred *litellm.LiteLLMModelCredential + if p.LiteLLMCredentialName != nil { + credName = strings.TrimSpace(*p.LiteLLMCredentialName) + } + if credName != "" { + cred = credByName[credName] + } else { + if dep != nil { + cred = dep.InlineCredential() + } + // LiteLLM resolves os.environ/ references before storing to DB, so if the + // env var was unset at LiteLLM startup the DB entry has an empty api_key. + // Fall back to the config-keyed deployment (by model_name) which still + // holds the raw env ref string we can pass to Bifrost as "env.VAR". 
+ if cred == nil { + if cfgDep := deployments[m.ModelName]; cfgDep != nil && cfgDep != dep { + cred = cfgDep.InlineCredential() + } + } + } + if cred != nil { + if cred.CredentialValues != nil { + if cred.CredentialValues.ApiKey != nil { + apiKey = strings.TrimSpace(*cred.CredentialValues.ApiKey) + } + if cred.CredentialValues.ApiBase != nil { + apiBase = strings.TrimSpace(*cred.CredentialValues.ApiBase) + } + } + if cred.CredentialInfo != nil { + credProvider = strings.ToLower(strings.TrimSpace(cred.CredentialInfo.CustomLLMProvider)) + } + } + + base := deriveProvider(p.CustomLLMProvider, p.Model, cfg.DefaultProvider) + if (base == "" || !standardProviders[base]) && standardProviders[credProvider] { + base = credProvider + } + + // Resolve Azure/Bedrock/Vertex structured credential fields. + // Try the DB deployment value first (resolved literal); fall back to the + // config deployment which still carries "os.environ/VAR" env refs that + // keyValue converts to the "env.VAR" format Bifrost resolves at runtime. 
+ cfgDep := deployments[m.ModelName] + sf := func(f func(*litellm.Deployment) string) string { + var dbVal, cfgVal string + if dep != nil { + dbVal = f(dep) + } + if cfgDep != nil && cfgDep != dep { + cfgVal = f(cfgDep) + } + if v, _ := keyValue(dbVal); v != "" { + return v + } + v, _ := keyValue(cfgVal) + return v + } + + in := modelMigrationInput{ + Params: p, + Deployment: dep, + RawModel: rawModel, + APIKey: apiKey, + APIBase: apiBase, + CredName: credName, + BaseProvider: base, + // Azure + AzureClientID: sf(func(d *litellm.Deployment) string { return d.AzureClientID }), + AzureClientSecret: sf(func(d *litellm.Deployment) string { return d.AzureClientSecret }), + AzureTenantID: sf(func(d *litellm.Deployment) string { return d.AzureTenantID }), + AzureADToken: sf(func(d *litellm.Deployment) string { return d.AzureADToken }), + // Bedrock + AWSAccessKeyID: sf(func(d *litellm.Deployment) string { return d.AWSAccessKeyID }), + AWSSecretAccessKey: sf(func(d *litellm.Deployment) string { return d.AWSSecretAccessKey }), + AWSRegionName: sf(func(d *litellm.Deployment) string { return d.AWSRegionName }), + AWSRoleName: sf(func(d *litellm.Deployment) string { return d.AWSRoleName }), + AWSSessionName: sf(func(d *litellm.Deployment) string { return d.AWSSessionName }), + AWSSessionToken: sf(func(d *litellm.Deployment) string { return d.AWSSessionToken }), + // Vertex + VertexProject: sf(func(d *litellm.Deployment) string { return d.VertexProject }), + VertexLocation: sf(func(d *litellm.Deployment) string { return d.VertexLocation }), + VertexCredentials: sf(func(d *litellm.Deployment) string { return d.VertexCredentials }), + } + return in, true +} + +// modelConfigPlan maps a LiteLLM deployment's governance fields (carried in b) +// to one global Bifrost model config keyed by the actual upstream model name +// because Bifrost model configs are unique by model_name. sourceName is the +// public model_name and rawModel is litellm_params.model. 
+func modelConfigPlan(sourceName, rawModel string, b litellm.LiteLLMBudget, base, provider, maxBudgetPeriod string) (*ModelConfigPlan, string, error) { + budget, err := toBudget(b, maxBudgetPeriod) + if err != nil { + return nil, "", err + } + rateLimit := toRateLimit(b) + if budget == nil && rateLimit == nil { + return nil, "", nil + } + + modelName := trimModelPrefix(rawModel) + if modelName == "" { + return nil, fmt.Sprintf("%s (%s): empty model_name is not migrated", sourceName, strings.TrimSpace(rawModel)), nil + } + if modelName != "*" && strings.ContainsAny(modelName, "*?") { + return nil, fmt.Sprintf("%s (%s): partial wildcard model limits are not migrated", sourceName, strings.TrimSpace(rawModel)), nil + } + + var providerPtr *string + if strings.TrimSpace(rawModel) == "*" { + provider = "" + } + if provider != "" { + p := provider + providerPtr = &p + } + if base == "*" { + return nil, fmt.Sprintf("%s (%s): wildcard provider patterns are not migrated", sourceName, strings.TrimSpace(rawModel)), nil + } + + plan := &ModelConfigPlan{SourceName: strings.TrimSpace(sourceName), ModelName: modelName, Provider: providerPtr, RateLimit: rateLimit} + if budget != nil { + plan.Budgets = []BifrostCreateBudgetRequest{*budget} + } + return plan, "", nil +} + +// mergeModelConfig folds duplicate Bifrost model configs into the stricter +// limit because Bifrost accepts only one config per model_name. +func mergeModelConfig(existing *ModelConfigPlan, next ModelConfigPlan) { + existing.SourceName = appendSource(existing.SourceName, next.SourceName) + existing.Provider = mergeModelConfigProvider(existing.Provider, next.Provider) + existing.RateLimit = mergeRateLimit(existing.RateLimit, next.RateLimit) + existing.Budgets = mergeBudgets(existing.Budgets, next.Budgets) +} + +// appendSource records every LiteLLM deployment that contributed to a merged +// model config for dry-run visibility. 
+func appendSource(existing, next string) string { + if existing == "" { + return next + } + if next == "" || strings.Contains(","+existing+",", ","+next+",") { + return existing + } + return existing + "," + next +} + +// mergeModelConfigProvider preserves the scoped provider for duplicate configs; +// duplicates only merge when the provider already matches. +func mergeModelConfigProvider(existing, next *string) *string { + if existing == nil || next == nil { + return nil + } + provider := *existing + return &provider +} + +// mergeRateLimit keeps the lowest non-nil RPM and TPM from duplicate LiteLLM +// deployments for the same actual model. +func mergeRateLimit(existing, next *BifrostCreateRateLimitRequest) *BifrostCreateRateLimitRequest { + if existing == nil { + return next + } + if next == nil { + return existing + } + requestLimit, requestReset := minRateLimitDimension(existing.RequestMaxLimit, existing.RequestResetDuration, next.RequestMaxLimit, next.RequestResetDuration) + tokenLimit, tokenReset := minRateLimitDimension(existing.TokenMaxLimit, existing.TokenResetDuration, next.TokenMaxLimit, next.TokenResetDuration) + return &BifrostCreateRateLimitRequest{ + RequestMaxLimit: requestLimit, + RequestResetDuration: requestReset, + TokenMaxLimit: tokenLimit, + TokenResetDuration: tokenReset, + } +} + +// minRateLimitDimension returns the stricter limit and its required reset +// duration for one rate-limit dimension. +func minRateLimitDimension(a *int64, aReset *string, b *int64, bReset *string) (*int64, *string) { + if a == nil { + return b, bReset + } + if b == nil || *a <= *b { + return a, aReset + } + return b, bReset +} + +// mergeBudgets keeps a single lowest max budget when duplicate model configs +// contain budget limits, preserving the reset duration attached to that budget. 
+func mergeBudgets(existing, next []BifrostCreateBudgetRequest) []BifrostCreateBudgetRequest { + if len(existing) == 0 { + return next + } + if len(next) == 0 { + return existing + } + if next[0].MaxLimit < existing[0].MaxLimit { + return next + } + return existing +} + +// modelConfigSignature returns the Bifrost uniqueness key for a global model +// config, intentionally ignoring the governance payload. +func modelConfigSignature(mc ModelConfigPlan) string { + provider := "" + if mc.Provider != nil { + provider = *mc.Provider + } + return provider + "|" + mc.ModelName +} + +// deriveProvider resolves the Bifrost base provider for a deployment: +// custom_llm_provider, else the prefix before the first "/" in model, else the +// default. The result is lowercased and normalised through liteLLMProviderAliases +// so LiteLLM names like "vertex_ai" map to Bifrost names like "vertex". +func deriveProvider(customLLMProvider, model, def string) string { + var raw string + if customLLMProvider != "" { + raw = strings.ToLower(customLLMProvider) + } else if model == "" { + return "" + } else if i := strings.Index(model, "/"); i > 0 { + raw = strings.ToLower(model[:i]) + } else { + raw = strings.ToLower(def) + } + if canonical, ok := liteLLMProviderAliases[raw]; ok { + return canonical + } + return raw +} + +// customProviderName synthesizes a unique, stable provider name for a (base, apiBase) pair +// by hashing the full URL with SHA-256 and using the first 8 bytes as a 16-char hex suffix. +// Format: "-<16 hex chars>", e.g. "openai-3f1a9b2c4d5e6f7a". +// Max length: len("anthropic") + 1 + 16 = 26, well within the varchar(50) column limit. +// The name is deterministic: same inputs always produce the same name, so the migration +// is idempotent across runs. 
+func customProviderName(base, apiBase string) string { + h := sha256.Sum256([]byte(base + "\x00" + apiBase)) + return base + "-" + hex.EncodeToString(h[:8]) +} + +// keyValue maps a LiteLLM api_key onto a Bifrost key value string and reports +// whether it is a literal (plaintext) secret. "os.environ/FOO" and "env.FOO" +// become "env.FOO"; an empty key stays empty (keyless providers are valid); +// anything else is carried as a literal. +func keyValue(apiKey string) (value string, literal bool) { + apiKey = strings.TrimSpace(apiKey) + if apiKey == "" { + return "", false + } + if name, ok := strings.CutPrefix(apiKey, "os.environ/"); ok { + return "env." + name, false + } + if strings.HasPrefix(apiKey, "env.") { + return apiKey, false + } + return apiKey, true +} + +// stripProviderPrefix removes only a leading "provider/" segment from a +// self-hosted model string, preserving HF-style repo ids +// ("meta-llama/Llama-3.1-8B-Instruct" stays intact; "vllm/llama3" -> "llama3"). +func stripProviderPrefix(model, provider string) string { + if rest, ok := strings.CutPrefix(model, provider+"/"); ok { + return rest + } + return model +} + +// selfHostedKeyName names a per-URL key for a self-hosted provider from its +// server host, falling back to a stable per-provider index. +func selfHostedKeyName(base, apiBase string, idx int) string { + host := apiBase + if u, err := url.Parse(apiBase); err == nil && u.Host != "" { + host = u.Host + } + slug := strings.Trim(nonAlphanumRe.ReplaceAllString(strings.ToLower(host), "-"), "-") + if slug == "" { + return fmt.Sprintf("%s-key-%d", base, idx) + } + name := base + "-" + slug + if len(name) > maxProviderNameLen { + name = name[:maxProviderNameLen] + } + return name +} + +// keyName produces a readable key name from the credential: the env var name +// for env refs, else a stable per-provider fallback. 
+func keyName(base, apiKey string, idx int) string { + apiKey = strings.TrimSpace(apiKey) + if name, ok := strings.CutPrefix(apiKey, "os.environ/"); ok { + return name + } + if name, ok := strings.CutPrefix(apiKey, "env."); ok { + return name + } + if apiKey == "" { + return base + "-default" + } + return fmt.Sprintf("%s-key-%d", base, idx) +} + +// azureKeyName derives a stable key name for an Azure deployment from its +// endpoint URL hostname, e.g. "https://myazure.openai.azure.com/" → "myazure-openai-azure-com". +func azureKeyName(apiBase string, idx int) string { + return selfHostedKeyName("azure", apiBase, idx) +} + +// bedrockKeyName derives a stable key name from the Bedrock auth fields. +// Prefers role ARN (last path segment) > access key env var name > fallback index. +func bedrockKeyName(in modelMigrationInput, idx int) string { + if in.AWSRoleName != "" { + name := in.AWSRoleName + if n, ok := strings.CutPrefix(name, "env."); ok { + return n + } + // ARN: arn:aws:iam::123456789012:role/MyRole → "MyRole" + if i := strings.LastIndex(name, "/"); i >= 0 { + if seg := name[i+1:]; seg != "" { + return seg + } + } + return name + } + if in.AWSAccessKeyID != "" { + name := in.AWSAccessKeyID + if n, ok := strings.CutPrefix(name, "env."); ok { + return n + } + return fmt.Sprintf("bedrock-key-%d", idx) + } + return fmt.Sprintf("bedrock-default-%d", idx) +} + +// vertexKeyName derives a stable key name from the Vertex project and location. +// e.g. project="my-project", location="us-central1" → "my-project-us-central1". 
+func vertexKeyName(in modelMigrationInput, idx int) string { + project := in.VertexProject + if n, ok := strings.CutPrefix(project, "env."); ok { + project = n + } + location := in.VertexLocation + if n, ok := strings.CutPrefix(location, "env."); ok { + location = n + } + slug := strings.Trim(nonAlphanumRe.ReplaceAllString(strings.ToLower(project+"-"+location), "-"), "-") + if slug == "" { + return fmt.Sprintf("vertex-key-%d", idx) + } + if len(slug) > maxProviderNameLen-7 { // room for "vertex/" prefix in full key name + slug = slug[:maxProviderNameLen-7] + } + return slug +} diff --git a/scripts/litellm-to-bifrost/modelconformance_test.go b/scripts/litellm-to-bifrost/modelconformance_test.go new file mode 100644 index 0000000000..17e061234b --- /dev/null +++ b/scripts/litellm-to-bifrost/modelconformance_test.go @@ -0,0 +1,811 @@ +// Run all: go test -run TestConformance_Models -count=1 -v . +// Run one: go test -run 'TestConformance_Models/wildcard_linked' -count=1 -v . +package main + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "reflect" + "sort" + "strings" + "testing" + "time" + + "github.com/maximhq/bifrost/scripts/litellm-to-bifrost/litellm" +) + +// ---- live test harness ----------------------------------------------------- + +// conformanceEnv builds the migration config from env. +func conformanceEnv(t *testing.T) MigrationRunConfig { + t.Helper() + for _, key := range []string{ + "LITELLM_MASTER_KEY", + "LITELLM_URL", + "LITELLM_CONFIG", + "LITELLM_DB_URL", + "BIFROST_URL", + "BIFROST_API_KEY", + } { + if strings.TrimSpace(os.Getenv(key)) == "" { + t.Skipf("skipping LiteLLM→Bifrost conformance test: %s is not set", key) + } + } + return NewMigrationRunConfig() +} + +// keepEntities leaves seeded entities behind for inspection (CONF_NO_CLEANUP=1). 
+func keepEntities() bool { return os.Getenv("CONF_NO_CLEANUP") == "1" } + +// rawRequest performs an HTTP request with an optional JSON body and bearer, +// returning the body, status code, and a non-2xx error. +func rawRequest(method, reqURL, apiKey string, body any) ([]byte, int, error) { + var rdr io.Reader + if body != nil { + b, err := json.Marshal(body) + if err != nil { + return nil, 0, err + } + rdr = bytes.NewReader(b) + } + req, err := http.NewRequest(method, reqURL, rdr) + if err != nil { + return nil, 0, err + } + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + if apiKey != "" { + req.Header.Set("Authorization", "Bearer "+apiKey) + } + resp, err := httpClient.Do(req) + if err != nil { + return nil, 0, err + } + defer resp.Body.Close() + out, err := io.ReadAll(resp.Body) + if err != nil { + return nil, resp.StatusCode, fmt.Errorf("%s %s: read body: %w", method, reqURL, err) + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return out, resp.StatusCode, fmt.Errorf("%s %s: status %d: %s", method, reqURL, resp.StatusCode, strings.TrimSpace(string(out))) + } + return out, resp.StatusCode, nil +} + +// doJSON performs a request and decodes the JSON response into T, failing on any +// transport or decode error. +func doJSON[T any](t *testing.T, method, reqURL, apiKey string, body any) T { + t.Helper() + out, _, err := rawRequest(method, reqURL, apiKey, body) + if err != nil { + t.Fatalf("%v", err) + } + var v T + if len(out) > 0 { + if err := json.Unmarshal(out, &v); err != nil { + t.Fatalf("decode %s %s: %v (body: %s)", method, reqURL, err, string(out)) + } + } + return v +} + +// ---- declarative case model ------------------------------------------------ + +// customBase is a real, resolvable host (Bifrost validates the base URL by +// resolving it). Reused across cases since each case cleans up after itself. 
const customBase = "https://example-endpoint-production.up.railway.app/"

// seedCred describes one LiteLLM credential to create for a case.
// Field order matters: the case table builds seedCred with unkeyed literals
// (name, provider, key, base).
type seedCred struct {
	name     string // case-local name; suffixed unique per run
	provider string // credential_info.custom_llm_provider (base provider)
	key      string // credential_values.api_key (env ref by default)
	base     string // credential_values.api_base; "" => no base URL
}

// seedModel describes one LiteLLM deployment. Only set fields are sent; cred
// wires a litellm_credential_name, the rest are inline litellm_params.
type seedModel struct {
	name           string  // public model_name
	model          string  // litellm_params.model, e.g. "openai/gpt-4o" or "openai/*"
	provider       string  // litellm_params.custom_llm_provider (self-hosted: LiteLLM only routes when the provider is here, not in credential_info)
	cred           string  // case-local credential name to link, or "" for none
	key            string  // inline api_key
	base           string  // inline api_base
	rpm, tpm       int     // inline per-deployment rate limits
	maxBudget      float64 // inline per-deployment budget cap
	budgetDuration string  // inline budget reset window, e.g. "30d"
}

// wantProvider is an expected Bifrost provider: base provider type, base URL
// ("" => standard provider whose name equals base), and one allow-list per key
// (order-insensitive). Custom providers must match keys exactly; standard
// providers need only contain the expected allow-lists (the real provider may
// carry unrelated keys).
// Field order matters: the case table builds wantProvider with unkeyed
// literals (base, url, keys).
type wantProvider struct {
	base string     // base provider type, e.g. "openai"
	url  string     // expected base URL; "" => standard provider
	keys [][]string // one model allow-list per expected key
}

// wantModelConfig is an expected global Bifrost model config.
type wantModelConfig struct {
	provider    string  // "" => all providers (nil in Bifrost)
	model       string  // upstream model name, or "*"
	rpm, tpm    int64   // 0 => not expected
	budget      float64 // 0 => not expected
	budgetReset string  // expected reset_duration when budget set
}

// modelCase is one declarative conformance scenario: what to seed into LiteLLM
// (creds, models) and what the migration is expected to produce in Bifrost
// (providers, modelConfigs).
type modelCase struct {
	name         string            // subtest name; also the input to nameSuffix
	creds        []seedCred        // LiteLLM credentials to create
	models       []seedModel       // LiteLLM deployments to create
	providers    []wantProvider    // providers expected after migration
	modelConfigs []wantModelConfig // model configs expected after migration
}

// TestConformance_Models seeds each case into a live LiteLLM instance, runs
// the model migration against a live Bifrost instance, and asserts the
// resulting providers and model configs. conformanceEnv skips the test when
// the required environment variables are not set.
func TestConformance_Models(t *testing.T) {
	env := conformanceEnv(t)

	// envKey backs linked credentials; the Bifrost key is named after the
	// credential, so the literal masked value is all that matters.
	const envKey = "os.environ/OPENAI_API_KEY"
	// inlineKey backs inline deployments. Bifrost resolves env-ref key values at
	// create time (so the var must be in Bifrost's environment) AND names the key
	// after the env var, which must not collide with an existing Bifrost key — so
	// this is a real env var that is not already configured as a Bifrost key.
	const inlineKey = "os.environ/GROQ_API_KEY"

	cases := []modelCase{
		// --- standard providers, linked credentials (expected PASS) ---------
		{
			name:      "standard_linked_openai",
			creds:     []seedCred{{"oa", "openai", envKey, ""}},
			models:    []seedModel{{name: "m", model: "openai/gpt-4o", cred: "oa"}},
			providers: []wantProvider{{"openai", "", [][]string{{"gpt-4o"}}}},
		},
		{
			name:      "standard_linked_anthropic",
			creds:     []seedCred{{"an", "anthropic", envKey, ""}},
			models:    []seedModel{{name: "m", model: "anthropic/claude-sonnet-4-20250514", cred: "an"}},
			providers: []wantProvider{{"anthropic", "", [][]string{{"claude-sonnet-4-20250514"}}}},
		},
		{
			name:      "standard_linked_gemini",
			creds:     []seedCred{{"gm", "gemini", envKey, ""}},
			models:    []seedModel{{name: "m", model: "gemini/gemini-2.0-flash", cred: "gm"}},
			providers: []wantProvider{{"gemini", "", [][]string{{"gemini-2.0-flash"}}}},
		},
		{
			name:  "standard_linked_two_credentials",
			creds: []seedCred{{"a", "openai", envKey, ""}, {"b", "openai", envKey, ""}},
			models: []seedModel{
				{name: "m1", model: "openai/gpt-4o", cred: "a"},
				{name: "m2", model: "openai/gpt-4o-mini", cred: "b"},
			},
			providers: []wantProvider{{"openai", "", [][]string{{"gpt-4o"}, {"gpt-4o-mini"}}}},
		},

		// --- linked credential, multiple models on one key (expected PASS) --
		{
			name:  "linked_credential_unions_models",
			creds: []seedCred{{"oa", "openai", envKey, ""}},
			models: []seedModel{
				{name: "m1", model: "openai/gpt-4o", cred: "oa"},
				{name: "m2", model: "openai/gpt-4o-mini", cred: "oa"},
				{name: "m3", model: "openai/o3", cred: "oa"},
			},
			providers: []wantProvider{{"openai", "", [][]string{{"gpt-4o", "gpt-4o-mini", "o3"}}}},
		},

		// --- custom providers, linked credentials (expected PASS) -----------
		{
			name:      "custom_linked_openai",
			creds:     []seedCred{{"c", "openai", envKey, customBase}},
			models:    []seedModel{{name: "m", model: "openai/gpt-4o", cred: "c"}},
			providers: []wantProvider{{"openai", customBase, [][]string{{"gpt-4o"}}}},
		},
		{
			name:      "custom_linked_anthropic",
			creds:     []seedCred{{"c", "anthropic", envKey, customBase}},
			models:    []seedModel{{name: "m", model: "anthropic/claude-3-5-haiku-20241022", cred: "c"}},
			providers: []wantProvider{{"anthropic", customBase, [][]string{{"claude-3-5-haiku-20241022"}}}},
		},
		{
			name: "custom_two_credentials_same_base",
			creds: []seedCred{
				{"c1", "openai", envKey, customBase},
				{"c2", "openai", envKey, customBase},
			},
			models: []seedModel{
				{name: "m1", model: "openai/gpt-4o", cred: "c1"},
				{name: "m2", model: "openai/gpt-4o-mini", cred: "c2"},
			},
			providers: []wantProvider{{"openai", customBase, [][]string{{"gpt-4o"}, {"gpt-4o-mini"}}}},
		},

		// --- wildcards (expected PASS) --------------------------------------
		{
			name:      "wildcard_linked",
			creds:     []seedCred{{"oa", "openai", envKey, ""}},
			models:    []seedModel{{name: "m", model: "openai/*", cred: "oa"}},
			providers: []wantProvider{{"openai", "", [][]string{{"*"}}}},
		},
		{
			name:      "wildcard_custom_linked",
			creds:     []seedCred{{"c", "openai", envKey, customBase}},
			models:    []seedModel{{name: "m", model: "openai/*", cred: "c"}},
			providers: []wantProvider{{"openai", customBase, [][]string{{"*"}}}},
		},

		// --- inline credentials (resolved from the LiteLLM DB, decrypted with the
		// salt key, since /model/info masks/omits them). inlineKey is a real
		// env-ref Bifrost resolves at create time. ------------------------
		{
			name:      "inline_key_standard",
			models:    []seedModel{{name: "m", model: "openai/gpt-4o", key: inlineKey}},
			providers: []wantProvider{{"openai", "", [][]string{{"gpt-4o"}}}},
		},
		{
			name:      "inline_base_url_custom",
			models:    []seedModel{{name: "m", model: "openai/gpt-4o", key: inlineKey, base: customBase}},
			providers: []wantProvider{{"openai", customBase, [][]string{{"gpt-4o"}}}},
		},
		{
			name:      "wildcard_inline",
			models:    []seedModel{{name: "m", model: "openai/*", key: inlineKey}},
			providers: []wantProvider{{"openai", "", [][]string{{"*"}}}},
		},

		// --- per-deployment budgets / rate limits. Rate limits come from
		// /model/info (unencrypted); the budget (max_budget + budget_duration)
		// comes from the deployment's litellm_params, resolved from the DB. ---
		{
			name:         "rate_limit_rpm_tpm",
			creds:        []seedCred{{"oa", "openai", envKey, ""}},
			models:       []seedModel{{name: "m", model: "openai/gpt-4o", cred: "oa", rpm: 200, tpm: 200000}},
			providers:    []wantProvider{{"openai", "", [][]string{{"gpt-4o"}}}},
			modelConfigs: []wantModelConfig{{provider: "openai", model: "gpt-4o", rpm: 200, tpm: 200000}},
		},
		{
			name:         "budget_30d",
			creds:        []seedCred{{"oa", "openai", envKey, ""}},
			models:       []seedModel{{name: "m", model: "openai/gpt-4o", cred: "oa", maxBudget: 50, budgetDuration: "30d"}},
			providers:    []wantProvider{{"openai", "", [][]string{{"gpt-4o"}}}},
			modelConfigs: []wantModelConfig{{provider: "openai", model: "gpt-4o", budget: 50, budgetReset: "30d"}},
		},
		{
			name:         "budget_and_rate_limit",
			creds:        []seedCred{{"an", "anthropic", envKey, ""}},
			models:       []seedModel{{name: "m", model: "anthropic/claude-sonnet-4-20250514", cred: "an", maxBudget: 10, budgetDuration: "1d", rpm: 50, tpm: 2000}},
			providers:    []wantProvider{{"anthropic", "", [][]string{{"claude-sonnet-4-20250514"}}}},
			modelConfigs: []wantModelConfig{{provider: "anthropic", model: "claude-sonnet-4-20250514", rpm: 50, tpm: 2000, budget: 10, budgetReset: "1d"}},
		},

		// --- self-hosted: vllm/ollama keep the standard provider and carry the
		// server URL on a per-base-URL key. LiteLLM only routes these when the
		// provider is in litellm_params (custom_llm_provider), not the
		// credential, so the seed sets it there. ---------------------------
		{
			name:      "vllm_base_url",
			creds:     []seedCred{{"c", "vllm", envKey, customBase}},
			models:    []seedModel{{name: "m", model: "meta-llama/Llama-3.1-8B-Instruct", provider: "vllm", cred: "c"}},
			providers: []wantProvider{{"vllm", customBase, [][]string{{"meta-llama/Llama-3.1-8B-Instruct"}}}},
		},
		{
			name:      "ollama_base_url",
			creds:     []seedCred{{"c", "ollama", envKey, customBase}},
			models:    []seedModel{{name: "m", model: "llama3.1", provider: "ollama", cred: "c"}},
			providers: []wantProvider{{"ollama", customBase, [][]string{{"llama3.1"}}}},
		},
		// (LiteLLM has no sglang provider — it is served as an OpenAI-compatible
		// endpoint — so there is no self-hosted case beyond vllm/ollama.)
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Unique suffix keeps names from colliding across runs; the local
			// names in the case wire up cred<->model links.
			suffix := nameSuffix(tc.name)
			for _, c := range tc.creds {
				env.seedCred(t, c.name+suffix, c.provider, c.key, c.base)
			}
			for _, m := range tc.models {
				// m is a per-iteration copy, so suffixing it here does not
				// modify the case table.
				if m.cred != "" {
					m.cred += suffix
				}
				m.name += suffix
				env.seedModel(t, m)
			}

			// Migrate only what this subtest seeded (matched by suffix).
			env.migrateModels(t, suffix)

			for _, w := range tc.providers {
				env.assertProvider(t, w)
			}
			for _, w := range tc.modelConfigs {
				env.assertModelConfig(t, w)
			}
		})
	}
}

// ---- seeding ---------------------------------------------------------------

// seedCred creates a LiteLLM credential and registers its deletion.
func (e MigrationRunConfig) seedCred(t *testing.T, name, provider, key, base string) {
	t.Helper()
	// api_base is only sent when the case supplies one.
	values := map[string]any{"api_key": key}
	if base != "" {
		values["api_base"] = base
	}
	doJSON[struct{}](t, http.MethodPost, e.LiteLLMClient.BaseURL+"/credentials", e.LiteLLMClient.APIKey, map[string]any{
		"credential_name":   name,
		"credential_info":   map[string]any{"custom_llm_provider": provider},
		"credential_values": values,
	})
	if keepEntities() {
		return
	}
	t.Cleanup(func() {
		// Best-effort delete: the result is deliberately ignored so a failed
		// cleanup does not fail the test.
		_, _, _ = rawRequest(http.MethodDelete, e.LiteLLMClient.BaseURL+"/credentials/"+name, e.LiteLLMClient.APIKey, nil)
	})
}

// seedModel creates a LiteLLM deployment and registers its deletion by id.
// Only the non-zero fields of m are sent as litellm_params; the test fails if
// the response carries no model_id.
func (e MigrationRunConfig) seedModel(t *testing.T, m seedModel) {
	t.Helper()
	params := map[string]any{"model": m.model}
	if m.provider != "" {
		params["custom_llm_provider"] = m.provider
	}
	if m.cred != "" {
		params["litellm_credential_name"] = m.cred
	}
	if m.key != "" {
		params["api_key"] = m.key
	}
	if m.base != "" {
		params["api_base"] = m.base
	}
	if m.rpm != 0 {
		params["rpm"] = m.rpm
	}
	if m.tpm != 0 {
		params["tpm"] = m.tpm
	}
	if m.maxBudget != 0 {
		params["max_budget"] = m.maxBudget
	}
	if m.budgetDuration != "" {
		params["budget_duration"] = m.budgetDuration
	}
	id := doJSON[struct {
		ModelID string `json:"model_id"`
	}](t, http.MethodPost, e.LiteLLMClient.BaseURL+"/model/new", e.LiteLLMClient.APIKey, map[string]any{
		"model_name":     m.name,
		"litellm_params": params,
	}).ModelID
	if id == "" {
		t.Fatalf("seed model %q: empty model_id in response", m.name)
	}
	if keepEntities() {
		return
	}
	t.Cleanup(func() {
		// Best-effort delete; the result is deliberately ignored.
		_, _, _ = rawRequest(http.MethodPost, e.LiteLLMClient.BaseURL+"/model/delete", e.LiteLLMClient.APIKey, map[string]any{"id": id})
	})
}

// ---- migration -------------------------------------------------------------

// migrateModels runs the real transform and writes the result to Bifrost,
// registering cleanup for every custom provider, standard-provider key and model
// config it creates. It migrates only the entities tagged with this run's
// suffix, so leftover entities from other runs (e.g. CONF_NO_CLEANUP) never leak
// in — each case is hermetic.
func (e MigrationRunConfig) migrateModels(t *testing.T, suffix string) {
	t.Helper()
	ctx := context.Background()

	all, err := e.LiteLLMClient.ListModelInfo(ctx)
	if err != nil {
		t.Fatalf("list model info: %v", err)
	}
	// Keep only the deployments seeded by this case.
	var models []litellm.LiteLLMModelInfo
	for _, m := range all {
		if strings.HasSuffix(m.ModelName, suffix) {
			models = append(models, m)
		}
	}

	allCreds, err := e.LiteLLMClient.ListCredentials(ctx)
	if err != nil {
		t.Fatalf("list credentials: %v", err)
	}

	// Index by slice element address so the map points into allCreds rather
	// than at a loop-variable copy.
	credByName := map[string]*litellm.LiteLLMModelCredential{}
	for i := range allCreds {
		if strings.HasSuffix(allCreds[i].CredentialName, suffix) {
			credByName[allCreds[i].CredentialName] = &allCreds[i]
		}
	}

	// Inline credentials and budgets are masked/omitted by the API, so resolve
	// them from the LiteLLM database (decrypted with the salt key).
	store, err := litellm.NewCredentialStore(ctx, e.LiteLLMDBURL, nil, e.LiteLLMSaltKey)
	if err != nil {
		t.Fatalf("open credential store: %v", err)
	}
	defer store.Close()
	deployments, err := store.Deployments(ctx)
	if err != nil {
		t.Fatalf("resolve deployments: %v", err)
	}

	// NOTE(review): the third result is discarded here; confirm it carries
	// nothing the test should assert on.
	plans, modelConfigs, _ := LiteLLMModelsToBifrostProviders(models, credByName, deployments, e)

	for _, p := range plans {
		if err := e.BifrostClient.EnsureProvider(ctx, p); err != nil {
			t.Fatalf("ensure provider %q: %v", p.Name, err)
		}
		// A custom provider is disposable; deleting it removes its keys too. A
		// standard provider is shared, so clean up only the keys we added.
		if p.IsCustom {
			e.cleanupProvider(t, p.Name)
		}
		for _, k := range p.Keys {
			if err := e.BifrostClient.CreateProviderKey(ctx, p.Name, k); err != nil {
				t.Fatalf("create key %q/%q: %v", p.Name, k.Name, err)
			}
			if !p.IsCustom {
				e.cleanupStandardKey(t, p.Name, k.Name)
			}
		}
	}
	for _, mc := range modelConfigs {
		if err := e.BifrostClient.CreateModelConfig(ctx, mc); err != nil {
			t.Fatalf("create model config %q: %v", modelConfigSignature(mc), err)
		}
		// Cleanup is registered only after the create succeeded.
		e.cleanupModelConfig(t, mc.ModelName, mc.Provider)
	}
}

// ---- assertions ------------------------------------------------------------

// assertProvider checks the migrated provider's type, base URL and key allow-
// lists. Custom providers require an exact key match; standard providers (incl.
// the self-hosted keyless ones) only require the expected allow-lists to be
// present. For self-hosted providers w.url is the per-key server URL, so the
// matching keys' *_key_config is checked too.
func (e MigrationRunConfig) assertProvider(t *testing.T, w wantProvider) {
	t.Helper()
	selfHosted := selfHostedURLProviders[w.base]
	// A URL on a non-self-hosted base means a synthesized custom provider.
	custom := w.url != "" && !selfHosted
	name := w.base
	if custom {
		name = customProviderName(w.base, w.url)
	}

	prov, found := e.fetchProvider(t, name)
	if !found {
		t.Errorf("provider %q not found after migration", name)
		return
	}
	if custom {
		if prov.CustomProviderConfig == nil || prov.CustomProviderConfig.BaseProviderType != w.base {
			t.Errorf("provider %q: custom_provider_config = %+v, want base_provider_type %q", name, prov.CustomProviderConfig, w.base)
		}
		if prov.NetworkConfig.BaseURL != w.url {
			t.Errorf("provider %q: base_url = %q, want %q", name, prov.NetworkConfig.BaseURL, w.url)
		}
	} else if prov.CustomProviderConfig != nil {
		t.Errorf("provider %q: expected standard provider, got custom %+v", name, prov.CustomProviderConfig)
	}

	keys := e.fetchProviderKeys(t, name)
	got := normalizeLists(allowlistsOf(keys))
	want := normalizeLists(w.keys)
	if custom {
		// The custom provider was created by this run, so its keys must match
		// exactly.
		if !reflect.DeepEqual(want, got) {
			t.Errorf("provider %q key allow-lists = %v, want %v", name, got, want)
		}
		return
	}
	for _, wl := range want {
		if !containsList(got, wl) {
			t.Errorf("provider %q: missing expected key allow-list %v (got %v)", name, wl, got)
		}
	}
	if selfHosted {
		e.assertSelfHostedKeyURLs(t, w, keys)
	}
}

// assertSelfHostedKeyURLs verifies that, for each expected allow-list, a key
// exists carrying a per-key server URL in the provider-matching *_key_config
// (and, for vllm, the served model in model_name). Bifrost masks the URL value
// on read, so the URL is asserted present (non-empty, not env-sourced) rather
// than equal to w.url.
func (e MigrationRunConfig) assertSelfHostedKeyURLs(t *testing.T, w wantProvider, keys providerKeysResponse) {
	t.Helper()
	for _, wl := range normalizeLists(w.keys) {
		matched := false
		for _, k := range keys.Keys {
			// Sort a copy so the response slice is left untouched.
			km := append([]string(nil), k.Models...)
			sort.Strings(km)
			if !reflect.DeepEqual(km, wl) {
				continue
			}
			var gotURL maskedValue
			var gotModel string
			present := false
			switch w.base {
			case "vllm":
				if k.VLLMKeyConfig != nil {
					present, gotURL, gotModel = true, k.VLLMKeyConfig.URL, k.VLLMKeyConfig.ModelName
				}
			case "ollama":
				if k.OllamaKeyConfig != nil {
					present, gotURL = true, k.OllamaKeyConfig.URL
				}
			}
			if !present || gotURL.Value == "" || gotURL.FromEnv {
				continue
			}
			// The vllm model_name is only compared for single-model
			// allow-lists.
			if w.base == "vllm" && len(wl) == 1 && gotModel != wl[0] {
				continue
			}
			matched = true
			break
		}
		if !matched {
			t.Errorf("provider %q: no key with allow-list %v carrying a %s per-key url", w.base, wl, w.base)
		}
	}
}

// assertModelConfig checks a global model config exists with the expected
// budget / rate-limit fields.
+func (e MigrationRunConfig) assertModelConfig(t *testing.T, w wantModelConfig) { + t.Helper() + mc, found := e.fetchModelConfig(t, w.model, w.provider) + if !found { + t.Errorf("model config %s/%s not found after migration", orAll(w.provider), w.model) + return + } + if w.budget > 0 { + if len(mc.Budgets) == 0 { + t.Errorf("model config %s/%s: no budget, want %v/%s", orAll(w.provider), w.model, w.budget, w.budgetReset) + } else if mc.Budgets[0].MaxLimit != w.budget || mc.Budgets[0].ResetDuration != w.budgetReset { + t.Errorf("model config %s/%s: budget = %v/%s, want %v/%s", orAll(w.provider), w.model, + mc.Budgets[0].MaxLimit, mc.Budgets[0].ResetDuration, w.budget, w.budgetReset) + } + } + if w.rpm > 0 { + assertLimit(t, "rpm", w.rpm, rateLimitField(mc.RateLimit, func(r bifrostRateLimit) *int64 { return r.RequestMaxLimit })) + } + if w.tpm > 0 { + assertLimit(t, "tpm", w.tpm, rateLimitField(mc.RateLimit, func(r bifrostRateLimit) *int64 { return r.TokenMaxLimit })) + } +} + +func assertLimit(t *testing.T, name string, want int64, got *int64) { + t.Helper() + switch { + case got == nil: + t.Errorf("model config %s = unset, want %d", name, want) + case *got != want: + t.Errorf("model config %s = %d, want %d", name, *got, want) + } +} + +// ---- Bifrost reads --------------------------------------------------------- + +type providerResponse struct { + Name string `json:"name"` + NetworkConfig struct { + BaseURL string `json:"base_url"` + } `json:"network_config"` + CustomProviderConfig *struct { + BaseProviderType string `json:"base_provider_type"` + } `json:"custom_provider_config"` +} + +// maskedValue is Bifrost's wrapper for a secret-bearing field (key value, per-key +// URL): the plaintext is masked on read, so only presence / env-ref metadata is +// assertable. 
type maskedValue struct {
	Value   string `json:"value"`
	EnvVar  string `json:"env_var"`
	FromEnv bool   `json:"from_env"`
}

type (
	// providerKeysResponse is the decoded provider key listing: one entry per
	// key with its id, name, model allow-list, and the optional self-hosted
	// key configs (nil when the key has none).
	providerKeysResponse struct {
		Keys []struct {
			ID            string   `json:"id"`
			Name          string   `json:"name"`
			Models        []string `json:"models"`
			VLLMKeyConfig *struct {
				URL       maskedValue `json:"url"`
				ModelName string      `json:"model_name"`
			} `json:"vllm_key_config"`
			OllamaKeyConfig *struct {
				URL maskedValue `json:"url"`
			} `json:"ollama_key_config"`
		} `json:"keys"`
	}

	// bifrostRateLimit carries the token and request ceilings of a model
	// config; a nil pointer means that dimension was absent from the JSON.
	bifrostRateLimit struct {
		TokenMaxLimit   *int64 `json:"token_max_limit"`
		RequestMaxLimit *int64 `json:"request_max_limit"`
	}

	// bifrostModelConfig is one entry of the governance model-config listing.
	// Provider and RateLimit stay nil when the JSON omits them.
	bifrostModelConfig struct {
		ID        string  `json:"id"`
		ModelName string  `json:"model_name"`
		Provider  *string `json:"provider"`
		Budgets   []struct {
			MaxLimit      float64 `json:"max_limit"`
			ResetDuration string  `json:"reset_duration"`
		} `json:"budgets"`
		RateLimit *bifrostRateLimit `json:"rate_limit"`
	}
)

// fetchProvider returns the provider and whether it exists (404 => not found).
+func (e MigrationRunConfig) fetchProvider(t *testing.T, name string) (providerResponse, bool) { + t.Helper() + body, status, err := e.BifrostClient.doRequest(context.Background(), http.MethodGet, "/api/providers/"+name, nil) + if status == http.StatusNotFound { + return providerResponse{}, false + } + if err != nil { + t.Fatalf("get provider %q: %v", name, err) + } + var p providerResponse + if err := json.Unmarshal(body, &p); err != nil { + t.Fatalf("decode provider %q: %v", name, err) + } + return p, true +} + +func (e MigrationRunConfig) fetchProviderKeys(t *testing.T, name string) providerKeysResponse { + t.Helper() + body, err := e.BifrostClient.getJSON(context.Background(), "/api/providers/"+name+"/keys", "get provider keys "+name) + if err != nil { + t.Fatalf("%v", err) + } + var out providerKeysResponse + if err := json.Unmarshal(body, &out); err != nil { + t.Fatalf("decode provider keys %q: %v (body: %s)", name, err, string(body)) + } + return out +} + +// fetchModelConfig returns the model config matching model+provider, if any. 
+func (e MigrationRunConfig) fetchModelConfig(t *testing.T, model, provider string) (bifrostModelConfig, bool) { + t.Helper() + body, err := e.BifrostClient.getJSON(context.Background(), "/api/governance/model-configs?search="+url.QueryEscape(model), "get model config "+model) + if err != nil { + t.Fatalf("%v", err) + } + var resp struct { + ModelConfigs []bifrostModelConfig `json:"model_configs"` + } + if err := json.Unmarshal(body, &resp); err != nil { + t.Fatalf("decode model configs for %q: %v (body: %s)", model, err, string(body)) + } + for _, mc := range resp.ModelConfigs { + if mc.ModelName == model && providerEq(mc.Provider, provider) { + return mc, true + } + } + return bifrostModelConfig{}, false +} + +// ---- Bifrost cleanup ------------------------------------------------------- + +func (e MigrationRunConfig) cleanupProvider(t *testing.T, name string) { + if keepEntities() { + return + } + t.Cleanup(func() { + _, _, _ = e.BifrostClient.doRequest(context.Background(), http.MethodDelete, "/api/providers/"+name, nil) + }) +} + +func (e MigrationRunConfig) cleanupStandardKey(t *testing.T, provider, keyName string) { + if keepEntities() { + return + } + t.Cleanup(func() { + for _, k := range e.fetchProviderKeys(t, provider).Keys { + if k.Name == keyName { + _, _, _ = e.BifrostClient.doRequest(context.Background(), http.MethodDelete, "/api/providers/"+provider+"/keys/"+k.ID, nil) + } + } + }) +} + +func (e MigrationRunConfig) cleanupModelConfig(t *testing.T, model string, provider *string) { + if keepEntities() { + return + } + want := "" + if provider != nil { + want = *provider + } + t.Cleanup(func() { + if mc, ok := e.fetchModelConfig(t, model, want); ok { + _, _, _ = e.BifrostClient.doRequest(context.Background(), http.MethodDelete, "/api/governance/model-configs/"+mc.ID, nil) + } + }) +} + +// ---- small helpers --------------------------------------------------------- + +// nameSuffix returns a per-case unique suffix for seeded entity names. 
+func nameSuffix(name string) string { + return "-" + name + "-" + time.Now().Format("150405.000000000") +} + +func allowlistsOf(keys providerKeysResponse) [][]string { + out := make([][]string, 0, len(keys.Keys)) + for _, k := range keys.Keys { + out = append(out, k.Models) + } + return out +} + +// normalizeLists sorts each allow-list and then the set of lists, so comparison +// ignores ordering on both axes. +func normalizeLists(lists [][]string) [][]string { + out := make([][]string, 0, len(lists)) + for _, l := range lists { + c := append([]string(nil), l...) + sort.Strings(c) + out = append(out, c) + } + sort.Slice(out, func(i, j int) bool { + return strings.Join(out[i], ",") < strings.Join(out[j], ",") + }) + return out +} + +func containsList(haystack [][]string, needle []string) bool { + for _, h := range haystack { + if reflect.DeepEqual(h, needle) { + return true + } + } + return false +} + +func providerEq(got *string, want string) bool { + if want == "" { + return got == nil + } + return got != nil && *got == want +} + +func rateLimitField(rl *bifrostRateLimit, pick func(bifrostRateLimit) *int64) *int64 { + if rl == nil { + return nil + } + return pick(*rl) +} + +func orAll(provider string) string { + if provider == "" { + return "*" + } + return provider +} diff --git a/scripts/litellm-to-bifrost/organization.go b/scripts/litellm-to-bifrost/organization.go new file mode 100644 index 0000000000..0c19a2fa7b --- /dev/null +++ b/scripts/litellm-to-bifrost/organization.go @@ -0,0 +1,132 @@ +package main + +import ( + "fmt" + "regexp" + "strings" + + "github.com/maximhq/bifrost/scripts/litellm-to-bifrost/litellm" +) + +// rateLimitResetWindow is the Bifrost reset duration for migrated rate limits. +// LiteLLM's tpm_limit / rpm_limit are defined per minute, so both the token and +// request limits reset every minute in Bifrost. 
const rateLimitResetWindow = "1m"

// defaultMaxBudgetPeriod is the Bifrost reset window used when a LiteLLM budget
// has a spend cap but no budget_duration. LiteLLM allows a never-resetting
// budget; Bifrost requires a reset window, so a long default stands in for
// "effectively never".
// NOTE(review): presumably this seeds MigrationRunConfig.MaxBudgetPeriod; the
// wiring is not visible here — confirm.
const defaultMaxBudgetPeriod = "10Y"

// litellmDurationRe matches a LiteLLM budget_duration: a run of decimal digits
// (capture group 1) followed by a unit (capture group 2: s, m, h, d, w, mo).
// "mo" is listed before the single-letter class so "1mo" captures "mo" rather
// than "m". The pattern is intentionally unanchored at the end to mirror
// LiteLLM's own lenient parser (e.g. "1hr" -> "1h"). Named durations like
// "daily" have no leading digit and therefore do not match. Note that \d+ by
// itself does not reject a zero value such as "0d".
var litellmDurationRe = regexp.MustCompile(`^(\d+)(mo|[smhdw])`)

// LiteLLMOrganizationToBifrostCustomer transforms a LiteLLM organization into a Bifrost
// create-customer request. The mapping is pure (no I/O) so each field can be
// unit-tested in isolation.
//
// Mapping:
//   - organization_alias -> name (required)
//   - budget.max_budget -> budgets[0].max_limit (omitted when <= 0)
//   - budget.budget_duration -> budgets[0].reset_duration ("mo" -> "M")
//   - budget.tpm_limit -> rate_limit.token_max_limit (reset "1m")
//   - budget.rpm_limit -> rate_limit.request_max_limit (reset "1m")
//
// Models and allowed MCP servers are intentionally not migrated: a Bifrost
// customer has no models, and MCP linkage is handled separately.
+func LiteLLMOrganizationToBifrostCustomer(org litellm.LiteLLMOrganization, cfg MigrationRunConfig) (*BifrostCreateCustomerRequest, error) { + name := strings.TrimSpace(org.OrganizationAlias) + if name == "" { + return nil, fmt.Errorf("organization %q has no organization_alias; a customer name is required", org.OrganizationID) + } + + req := &BifrostCreateCustomerRequest{Name: name} + + if org.Budget != nil { + budget, err := toBudget(*org.Budget, cfg.MaxBudgetPeriod) + if err != nil { + return nil, fmt.Errorf("organization %q: %w", org.OrganizationID, err) + } + if budget != nil { + req.Budgets = []BifrostCreateBudgetRequest{*budget} + } + req.RateLimit = toRateLimit(*org.Budget) + } + + return req, nil +} + +// toBudget maps the LiteLLM spend cap to a Bifrost budget. +// +// - max_budget nil or <= 0 -> no budget (nil, nil); LiteLLM treats this as +// "no spend cap", and Bifrost requires a positive max_limit. +// - max_budget > 0 with a missing/blank budget_duration -> maxBudgetPeriod; +// LiteLLM allows a never-reset budget but Bifrost requires a window, so the +// operator-supplied default stands in. +// - max_budget > 0 with an unparseable budget_duration -> error. +func toBudget(b litellm.LiteLLMBudget, maxBudgetPeriod string) (*BifrostCreateBudgetRequest, error) { + if b.MaxBudget == nil || *b.MaxBudget <= 0 { + return nil, nil + } + + if b.BudgetDuration == nil || strings.TrimSpace(*b.BudgetDuration) == "" { + return &BifrostCreateBudgetRequest{MaxLimit: *b.MaxBudget, ResetDuration: maxBudgetPeriod}, nil + } + + reset, err := convertBudgetDuration(*b.BudgetDuration) + if err != nil { + return nil, err + } + + return &BifrostCreateBudgetRequest{MaxLimit: *b.MaxBudget, ResetDuration: reset}, nil +} + +// toRateLimit maps tpm_limit / rpm_limit onto a Bifrost rate limit. Each +// dimension is independent: a non-positive or nil limit is treated as "no +// limit" and omitted. Returns nil when neither dimension has a positive limit. 
+func toRateLimit(b litellm.LiteLLMBudget) *BifrostCreateRateLimitRequest { + rl := &BifrostCreateRateLimitRequest{} + set := false + + if b.TPMLimit != nil && *b.TPMLimit > 0 { + v := *b.TPMLimit + reset := rateLimitResetWindow + rl.TokenMaxLimit = &v + rl.TokenResetDuration = &reset + set = true + } + + if b.RPMLimit != nil && *b.RPMLimit > 0 { + v := *b.RPMLimit + reset := rateLimitResetWindow + rl.RequestMaxLimit = &v + rl.RequestResetDuration = &reset + set = true + } + + if !set { + return nil + } + return rl +} + +// convertBudgetDuration translates a LiteLLM budget_duration into Bifrost's +// duration format. The only unit difference is months: LiteLLM uses "mo", +// Bifrost uses "M". All other units (s, m, h, d, w) are identical. +func convertBudgetDuration(d string) (string, error) { + d = strings.TrimSpace(d) + m := litellmDurationRe.FindStringSubmatch(d) + if m == nil { + return "", fmt.Errorf("unsupported budget_duration %q: expected ", d) + } + + value, unit := m[1], m[2] + if unit == "mo" { + unit = "M" + } + return value + unit, nil +} diff --git a/scripts/litellm-to-bifrost/scripts/create_organizations.sh b/scripts/litellm-to-bifrost/scripts/create_organizations.sh new file mode 100755 index 0000000000..2b094c58fa --- /dev/null +++ b/scripts/litellm-to-bifrost/scripts/create_organizations.sh @@ -0,0 +1,140 @@ +#!/usr/bin/env bash +# +# create_organizations.sh - seed a running LiteLLM instance with organizations. +# +# Usage: +# LITELLM_URL=http://localhost:4000 LITELLM_MASTER_KEY=sk-1234 ./create_organizations.sh + +set -euo pipefail + +LITELLM_URL="${LITELLM_URL:-http://localhost:4000}" +LITELLM_MASTER_KEY="${LITELLM_MASTER_KEY:-sk-1234}" +DRY_RUN="${DRY_RUN:-0}" + +create_org() { + local description="$1" body="$2" + echo "==> ${description}" + if [[ "${DRY_RUN}" == "1" ]]; then + jq -c . 
<<<"${body}" 2>/dev/null || { echo " INVALID JSON"; return; } + echo "" + return + fi + curl -sS --fail-with-body \ + --location "${LITELLM_URL}/organization/new" \ + --header "Authorization: Bearer ${LITELLM_MASTER_KEY}" \ + --header "Content-Type: application/json" \ + --data "${body}" \ + && echo "" \ + || echo " (LiteLLM rejected this payload)" + echo "" +} + +# --- Happy path: every field populated ------------------------------------- +# Expected Bifrost customer: name "seed-full", budget {500, "1M"}, +# rate_limit {token 2000 / "1m", request 120 / "1m"}. +create_org "seed-full: budget + tpm + rpm + monthly reset" '{ + "organization_alias": "seed-full", + "max_budget": 500, + "budget_duration": "1mo", + "tpm_limit": 2000, + "rpm_limit": 120 +}' + +# --- Budget only ------------------------------------------------------------ +# Expected: budget {100, "30d"}, no rate_limit. +create_org "seed-budget-30d: budget only, 30d reset" '{ + "organization_alias": "seed-budget-30d", + "max_budget": 100, + "budget_duration": "30d" +}' + +# Expected: budget {250, "1M"} (mo -> M), no rate_limit. +create_org "seed-budget-monthly: budget only, 1mo reset" '{ + "organization_alias": "seed-budget-monthly", + "max_budget": 250, + "budget_duration": "1mo" +}' + +# --- LiteLLM UI "Reset Budget" dropdown values ----------------------------- +# The UI shows daily/weekly/monthly but stores 24h/7d/30d. These migrate +# straight through - Bifrost accepts the same units. +# Expected: budget {10, "24h"}. +create_org "seed-reset-daily: UI \"daily\" -> 24h" '{ + "organization_alias": "seed-reset-daily", + "max_budget": 10, + "budget_duration": "24h" +}' + +# Expected: budget {70, "7d"}. +create_org "seed-reset-weekly: UI \"weekly\" -> 7d" '{ + "organization_alias": "seed-reset-weekly", + "max_budget": 70, + "budget_duration": "7d" +}' + +# Expected: budget {300, "30d"}. 
create_org "seed-reset-monthly: UI \"monthly\" -> 30d" '{
  "organization_alias": "seed-reset-monthly",
  "max_budget": 300,
  "budget_duration": "30d"
}'

# --- Pure customer (no budget, no limits) ----------------------------------
# Expected: name only, no budget, no rate_limit.
create_org "seed-pure: no budget, no limits" '{
  "organization_alias": "seed-pure"
}'

# --- Rate limits only -------------------------------------------------------
# Expected: rate_limit {token 1000 / "1m"} only.
create_org "seed-tpm-only: tpm only" '{
  "organization_alias": "seed-tpm-only",
  "tpm_limit": 1000
}'

# Expected: rate_limit {request 60 / "1m"} only.
create_org "seed-rpm-only: rpm only" '{
  "organization_alias": "seed-rpm-only",
  "rpm_limit": 60
}'

# Expected: rate_limit {token 1500 / "1m", request 90 / "1m"}, no budget.
create_org "seed-tpm-rpm: tpm + rpm, no budget" '{
  "organization_alias": "seed-tpm-rpm",
  "tpm_limit": 1500,
  "rpm_limit": 90
}'

# --- Non-positive values (treated as "not set" by the migration) -----------
# Expected: rate_limit nil (a negative limit is dropped).
create_org "seed-tpm-negative: negative tpm" '{
  "organization_alias": "seed-tpm-negative",
  "tpm_limit": -100
}'

# Expected: rate_limit nil (a zero limit is dropped).
create_org "seed-tpm-zero: zero tpm" '{
  "organization_alias": "seed-tpm-zero",
  "tpm_limit": 0
}'

# Expected: no budget, no rate_limit (budget_duration is ignored without a
# positive max_budget).
create_org "seed-budget-zero: zero max_budget" '{
  "organization_alias": "seed-budget-zero",
  "max_budget": 0,
  "budget_duration": "30d"
}'

# Expected: no budget.
create_org "seed-budget-negative: negative max_budget" '{
  "organization_alias": "seed-budget-negative",
  "max_budget": -50,
  "budget_duration": "30d"
}'

# --- Budget without a reset window -----------------------------------------
# LiteLLM stores max_budget with no reset window here.
# Expected: budget {100, <migration's default max-budget period>} - the
# migration substitutes its configured default because Bifrost needs a window.
create_org "seed-budget-noduration: max_budget without budget_duration" '{
  "organization_alias": "seed-budget-noduration",
  "max_budget": 100
}'

echo "Seeding complete."
\ No newline at end of file
diff --git a/scripts/litellm-to-bifrost/scripts/create_teams.sh b/scripts/litellm-to-bifrost/scripts/create_teams.sh
new file mode 100755
index 0000000000..4af4dee682
--- /dev/null
+++ b/scripts/litellm-to-bifrost/scripts/create_teams.sh
@@ -0,0 +1,319 @@
#!/usr/bin/env bash
#
# create_teams.sh - seed a running LiteLLM instance with teams.
#
# Usage:
#   LITELLM_URL=http://localhost:4000 LITELLM_MASTER_KEY=sk-1234 ./create_teams.sh

set -euo pipefail

LITELLM_URL="${LITELLM_URL:-http://localhost:4000}"
LITELLM_MASTER_KEY="${LITELLM_MASTER_KEY:-sk-1234}"
DRY_RUN="${DRY_RUN:-0}"

# Fixed ids shared by the cases below, so repeated runs target the same entities.
ORG_ID="seed-team-org"
ORG_ID_RESTRICTED="seed-team-org-restricted"
USER_ADMIN="seed-team-user-admin"
USER_MEMBER="seed-team-user-member"

# post issues a single management-API request.
#   $1 METHOD, $2 PATH (appended to LITELLM_URL), $3 DESCRIPTION, $4 JSON BODY
# With DRY_RUN=1 the payload is only validated and printed through jq.
# A rejected request is reported and the script carries on.
post() {
  local method="$1" path="$2" label="$3" payload="$4"
  echo "==> ${label}"

  if [[ "${DRY_RUN}" == "1" ]]; then
    if ! jq -c . <<<"${payload}" 2>/dev/null; then
      echo " INVALID JSON"
      return
    fi
    echo ""
    return
  fi

  if curl -sS --fail-with-body \
    --location "${LITELLM_URL}${path}" \
    --header "Authorization: Bearer ${LITELLM_MASTER_KEY}" \
    --header "Content-Type: application/json" \
    --request "${method}" \
    --data "${payload}"; then
    echo ""
  else
    echo " (LiteLLM rejected this payload — may be a duplicate id or an enterprise-only field)"
  fi
  echo ""
}

# Thin wrappers: $1 is the description, $2 the JSON body.
create_org() {
  post POST "/organization/new" "$1" "$2"
}

create_team() {
  post POST "/team/new" "$1" "$2"
}

create_user() {
  post POST "/user/new" "$1" "$2"
}

# ---------------------------------------------------------------------------
# Prerequisites: a permissive parent org and two users for membership cases.
# ---------------------------------------------------------------------------
echo "### Prerequisites ###"
echo

# Quoting note: each JSON body is a single-quoted string; '"${VAR}"' closes the
# single quotes, splices the shell variable in double quotes, and reopens them.
create_org "prereq org: ${ORG_ID}" '{
  "organization_id": "'"${ORG_ID}"'",
  "organization_alias": "seed-team-org",
  "models": ["all-proxy-models"],
  "max_budget": 100000,
  "budget_duration": "30d",
  "tpm_limit": 100000000,
  "rpm_limit": 1000000
}'

# A second org that exposes only a subset of proxy models (not all-proxy-models).
# Teams created inside it can only grant access to models within this allowlist.
create_org "prereq org (restricted models): ${ORG_ID_RESTRICTED}" '{
  "organization_id": "'"${ORG_ID_RESTRICTED}"'",
  "organization_alias": "seed-team-org-restricted",
  "models": ["openai-gpt-4o", "openai-gpt-4o-mini", "anthropic-claude-sonnet-4"],
  "max_budget": 100000,
  "budget_duration": "30d",
  "tpm_limit": 100000000,
  "rpm_limit": 1000000
}'

create_user "prereq user: ${USER_ADMIN}" '{
  "user_id": "'"${USER_ADMIN}"'",
  "user_email": "seed-team-admin@example.com",
  "user_role": "internal_user",
  "auto_create_key": false
}'

create_user "prereq user: ${USER_MEMBER}" '{
  "user_id": "'"${USER_MEMBER}"'",
  "user_email": "seed-team-member@example.com",
  "user_role": "internal_user",
  "auto_create_key": false
}'

echo "### Standalone teams (no organization) ###"
echo

# --- Case 1: minimal — alias only -----------------------------------------
create_team "case01 minimal: alias only" '{
  "team_alias": "seed-minimal"
}'

# --- Case 2: explicit team_id ----------------------------------------------
create_team "case02 explicit team_id" '{
  "team_id": "seed-team-explicit-id",
  "team_alias": "seed-explicit-id"
}'

# --- Case 3: members_with_roles (admin + user by user_id) ------------------
create_team "case03 members_with_roles (by user_id)" '{
  "team_alias": "seed-members-by-id",
  "members_with_roles": [
    {"role": "admin", "user_id": "'"${USER_ADMIN}"'"},
    {"role": "user", "user_id": "'"${USER_MEMBER}"'"}
  ]
}'

# --- Case 4: member by user_email (auto-creates the user if missing) --------
create_team "case04 member by user_email" '{
  "team_alias": "seed-members-by-email",
  "members_with_roles": [
    {"role": "admin", "user_email": "seed-team-admin@example.com"}
  ]
}'

# --- Case 5: budget — max + soft + duration --------------------------------
create_team "case05 budget: max + soft + duration" '{
  "team_alias": "seed-budget",
  "max_budget": 500,
  "soft_budget": 400,
  "budget_duration": "30d"
}'

# --- Case 6: rate limits — tpm + rpm ---------------------------------------
create_team "case06 rate limits: tpm + rpm" '{
  "team_alias": "seed-rate-limits",
  "tpm_limit": 2000,
  "rpm_limit": 120
}'

# --- Case 7: rate-limit enforcement types ----------------------------------
create_team "case07 limit types: guaranteed_throughput" '{
  "team_alias": "seed-limit-types",
  "tpm_limit": 5000,
  "rpm_limit": 300,
  "tpm_limit_type": "guaranteed_throughput",
  "rpm_limit_type": "best_effort_throughput"
}'

# --- Case 8: per-model tpm/rpm limits --------------------------------------
create_team "case08 per-model tpm/rpm limits" '{
  "team_alias": "seed-per-model-limits",
  "model_tpm_limit": {"gpt-4o": 1000, "claude-3-5-sonnet": 2000},
  "model_rpm_limit": {"gpt-4o": 60, "claude-3-5-sonnet": 120}
}'

# --- Case 9: model access restriction --------------------------------------
create_team "case09 models restriction list" '{
  "team_alias": "seed-models-allowlist",
  "models": ["gpt-4o", "claude-3-5-sonnet", "text-embedding-3-small"]
}'

# --- Case 10: model aliases (team-based routing) ---------------------------
create_team "case10 model_aliases" '{
  "team_alias": "seed-model-aliases",
  "models": ["gpt-4o"],
  "model_aliases": {"gpt-4": "gpt-4o", "fast": "gpt-4o"}
}'

# --- Case 11: metadata ------------------------------------------------------
create_team "case11 metadata" '{
  "team_alias": "seed-metadata",
  "metadata": {"department": "platform", "cost_center": "CC-1234", "extra_info": "seed"}
}'

# --- Case 12: tags (spend tracking / tag routing) --------------------------
create_team "case12 tags" '{
  "team_alias": "seed-tags",
  "tags": ["production", "team-platform", "tier-1"]
}'

# --- Case 13: blocked team --------------------------------------------------
create_team "case13 blocked team" '{
  "team_alias": "seed-blocked",
  "blocked": true
}'

# --- Case 14: per-team-member controls -------------------------------------
create_team "case14 per-member budget/limits/key-duration" '{
  "team_alias": "seed-member-controls",
  "team_member_budget": 50,
  "team_member_budget_duration": "30d",
  "team_member_rpm_limit": 30,
  "team_member_tpm_limit": 10000,
  "team_member_key_duration": "7d"
}'

# Cases 15-18 are labelled enterprise; post reports a rejection and continues
# if LiteLLM refuses them.

# --- Case 15: concurrent budget windows (enterprise) -----------------------
create_team "case15 budget_limits: multiple windows (enterprise)" '{
  "team_alias": "seed-budget-windows",
  "budget_limits": [
    {"budget_limit": 10.0, "time_period": "1d"},
    {"budget_limit": 50.0, "time_period": "7d"}
  ]
}'

# --- Case 16: team_member_permissions (enterprise) -------------------------
create_team "case16 team_member_permissions (enterprise)" '{
  "team_alias": "seed-member-permissions",
  "team_member_permissions": ["/key/generate", "/key/update", "/key/delete"]
}'

# --- Case 17: object_permission — vector stores (enterprise) ---------------
create_team "case17 object_permission: vector_stores (enterprise)" '{
  "team_alias": "seed-object-permission",
  "object_permission": {"vector_stores": ["vector_store_1", "vector_store_2"]}
}'

# --- Case 18: guardrails (enterprise) --------------------------------------
create_team "case18 guardrails (enterprise)" '{
  "team_alias": "seed-guardrails",
  "guardrails": ["my-pii-guard", "my-prompt-injection-guard"]
}'

# --- Case 19: default models for new members -------------------------------
create_team "case19 default_team_member_models" '{
  "team_alias": "seed-default-member-models",
  "models": ["gpt-4o", "claude-3-5-sonnet"],
  "default_team_member_models": ["gpt-4o"]
}'

# --- Case 20: enforced file/batch expiry -----------------------------------
create_team "case20 enforced file/batch expiry" '{
  "team_alias": "seed-file-expiry",
  "enforced_file_expires_after": {"anchor": "created_at", "days": 30},
  "enforced_batch_output_expires_after": {"anchor": "created_at", "days": 7}
}'

echo "### Teams inside the organization (${ORG_ID}) ###"
echo

# --- Case 21: minimal team in org ------------------------------------------
create_team "case21 in-org: minimal" '{
  "team_alias": "seed-org-minimal",
  "organization_id": "'"${ORG_ID}"'"
}'

# --- Case 22: in-org team with members (auto-added to the org) -------------
create_team "case22 in-org: with members (auto-added to org)" '{
  "team_alias": "seed-org-members",
  "organization_id": "'"${ORG_ID}"'",
  "members_with_roles": [
    {"role": "admin", "user_id": "'"${USER_ADMIN}"'"},
    {"role": "user", "user_id": "'"${USER_MEMBER}"'"}
  ]
}'

# --- Case 23: in-org team with budget within the org's budget --------------
create_team "case23 in-org: budget within org limits" '{
  "team_alias": "seed-org-budget",
  "organization_id": "'"${ORG_ID}"'",
  "max_budget": 1000,
  "budget_duration": "30d",
  "tpm_limit": 50000,
  "rpm_limit": 500
}'

# --- Case 24: in-org kitchen sink ------------------------------------------
create_team "case24 in-org: kitchen sink" '{
  "team_id": "seed-org-kitchen-sink",
  "team_alias": "seed-org-kitchen-sink",
  "organization_id": "'"${ORG_ID}"'",
  "members_with_roles": [
    {"role": "admin", "user_id": "'"${USER_ADMIN}"'"},
    {"role": "user", "user_id": "'"${USER_MEMBER}"'"}
  ],
  "models": ["gpt-4o", "claude-3-5-sonnet"],
  "model_aliases": {"gpt-4": "gpt-4o"},
  "max_budget": 2000,
  "soft_budget": 1500,
  "budget_duration": "30d",
  "tpm_limit": 40000,
  "rpm_limit": 400,
  "team_member_budget": 100,
  "team_member_budget_duration": "30d",
  "tags": ["production", "org-team"],
  "metadata": {"department": "research", "tier": "premium"}
}'

echo "### Teams inside the restricted-models organization (${ORG_ID_RESTRICTED}) ###"
echo

# --- Case 25: in-restricted-org team, no models (inherits org's subset) -----
# The team grants no explicit models, so members are bounded by the org's
# allowlist (openai-gpt-4o, openai-gpt-4o-mini, anthropic-claude-sonnet-4).
create_team "case25 in-restricted-org: inherits org model subset" '{
  "team_alias": "seed-restricted-org-inherit",
  "organization_id": "'"${ORG_ID_RESTRICTED}"'"
}'

# --- Case 26: in-restricted-org team with an allowed-models subset ----------
# The team narrows access further to a subset of the org's allowlist.
create_team "case26 in-restricted-org: team model subset" '{
  "team_alias": "seed-restricted-org-subset",
  "organization_id": "'"${ORG_ID_RESTRICTED}"'",
  "models": ["openai-gpt-4o", "anthropic-claude-sonnet-4"]
}'

# --- Case 27: in-restricted-org kitchen sink with model subset -------------
create_team "case27 in-restricted-org: subset + aliases + members" '{
  "team_id": "seed-restricted-org-kitchen-sink",
  "team_alias": "seed-restricted-org-kitchen-sink",
  "organization_id": "'"${ORG_ID_RESTRICTED}"'",
  "members_with_roles": [
    {"role": "admin", "user_id": "'"${USER_ADMIN}"'"},
    {"role": "user", "user_id": "'"${USER_MEMBER}"'"}
  ],
  "models": ["openai-gpt-4o", "openai-gpt-4o-mini"],
  "model_aliases": {"gpt-4": "openai-gpt-4o", "fast": "openai-gpt-4o-mini"},
  "default_team_member_models": ["openai-gpt-4o-mini"],
  "max_budget": 1000,
  "budget_duration": "30d",
  "tags": ["production", "restricted-org-team"],
  "metadata": {"department": "platform", "tier": "restricted"}
}'

echo "Seeding complete."
\ No newline at end of file
diff --git a/scripts/litellm-to-bifrost/scripts/create_virtual_keys.sh b/scripts/litellm-to-bifrost/scripts/create_virtual_keys.sh
new file mode 100755
index 0000000000..a4d7398c60
--- /dev/null
+++ b/scripts/litellm-to-bifrost/scripts/create_virtual_keys.sh
@@ -0,0 +1,302 @@
#!/usr/bin/env bash
#
# create_virtual_keys.sh - seed a running LiteLLM instance with virtual keys
#
# Usage:
#   LITELLM_URL=http://localhost:4000 LITELLM_MASTER_KEY=sk-1234 ./create_virtual_keys.sh

set -euo pipefail

LITELLM_URL="${LITELLM_URL:-http://localhost:4000}"
LITELLM_MASTER_KEY="${LITELLM_MASTER_KEY:-sk-1234}"
DRY_RUN="${DRY_RUN:-0}"

# Fixed ids shared by the cases below, so repeated runs target the same entities.
ORG_ID="seed-key-org"
USER_ID="seed-key-user"
TEAM_ID="seed-key-team"
BUDGET_ID="seed-key-budget"

# post issues a single management-API request.
#   $1 METHOD, $2 PATH (appended to LITELLM_URL), $3 DESCRIPTION, $4 JSON BODY
# With DRY_RUN=1 the payload is only validated and printed through jq.
# A rejected request is reported and never aborts the run (same behaviour as
# create_teams.sh).
post() {
  local method="$1" path="$2" label="$3" payload="$4"
  echo "==> ${label}"

  if [[ "${DRY_RUN}" == "1" ]]; then
    if ! jq -c . <<<"${payload}" 2>/dev/null; then
      echo " INVALID JSON"
      return
    fi
    echo ""
    return
  fi

  if curl -sS --fail-with-body \
    --location "${LITELLM_URL}${path}" \
    --header "Authorization: Bearer ${LITELLM_MASTER_KEY}" \
    --header "Content-Type: application/json" \
    --request "${method}" \
    --data "${payload}"; then
    echo ""
  else
    echo " (LiteLLM rejected this payload — may be a duplicate id/key or an enterprise-only field)"
  fi
  echo ""
}

# create_key generates one virtual key: $1 is the description, $2 the JSON body.
create_key() {
  post POST "/key/generate" "$1" "$2"
}

# ---------------------------------------------------------------------------
# Prerequisites: org, user, budget, and a permissive team for scoped keys.
+# --------------------------------------------------------------------------- +echo "### Prerequisites ###" +echo + +post POST "/organization/new" "prereq org: ${ORG_ID}" '{ + "organization_id": "'"${ORG_ID}"'", + "organization_alias": "seed-key-org", + "models": ["all-proxy-models"], + "max_budget": 100000, + "budget_duration": "30d" +}' + +post POST "/user/new" "prereq user: ${USER_ID}" '{ + "user_id": "'"${USER_ID}"'", + "user_email": "seed-key-user@example.com", + "user_role": "internal_user", + "auto_create_key": false +}' + +post POST "/budget/new" "prereq budget: ${BUDGET_ID}" '{ + "budget_id": "'"${BUDGET_ID}"'", + "max_budget": 250, + "budget_duration": "30d", + "tpm_limit": 5000, + "rpm_limit": 300 +}' + +# Permissive team (empty models = all models allowed) +post POST "/team/new" "prereq team: ${TEAM_ID}" '{ + "team_id": "'"${TEAM_ID}"'", + "team_alias": "seed-key-team", + "organization_id": "'"${ORG_ID}"'" +}' + +echo "### Standalone keys ###" +echo + +# --- Case 1: minimal — auto-generated key, no params ----------------------- +create_key "case01 minimal: auto key" '{}' + +# --- Case 2: key_alias ----------------------------------------------------- +create_key "case02 key_alias" '{ + "key_alias": "seed-key-alias" +}' + +# --- Case 3: custom key value (NOT idempotent) ----------------------------- +create_key "case03 custom key value" '{ + "key": "sk-seed-custom-0001", + "key_alias": "seed-custom-value" +}' + +# --- Case 4: expiry duration ----------------------------------------------- +create_key "case04 expiry: duration 30d" '{ + "key_alias": "seed-expiry", + "duration": "30d" +}' + +# --- Case 5: model allowlist ----------------------------------------------- +create_key "case05 models allowlist" '{ + "key_alias": "seed-models", + "models": ["gpt-4o", "claude-3-5-sonnet", "text-embedding-3-small"] +}' + +# --- Case 6: model aliases (upgrade/downgrade) ----------------------------- +create_key "case06 model aliases" '{ + "key_alias": 
"seed-aliases", + "models": ["gpt-4o"], + "aliases": {"gpt-4": "gpt-4o", "fast": "gpt-4o"} +}' + +# --- Case 7: budget — max + soft + duration -------------------------------- +create_key "case07 budget: max + soft + duration" '{ + "key_alias": "seed-budget", + "max_budget": 100, + "soft_budget": 80, + "budget_duration": "30d" +}' + +# --- Case 8: rate limits — tpm + rpm + parallel ---------------------------- +create_key "case08 rate limits: tpm + rpm + parallel" '{ + "key_alias": "seed-rate-limits", + "tpm_limit": 2000, + "rpm_limit": 120, + "max_parallel_requests": 10 +}' + +# --- Case 9: rate-limit enforcement types ---------------------------------- +create_key "case09 limit types" '{ + "key_alias": "seed-limit-types", + "tpm_limit": 5000, + "rpm_limit": 300, + "tpm_limit_type": "guaranteed_throughput", + "rpm_limit_type": "best_effort_throughput" +}' + +# --- Case 10: per-model budgets -------------------------------------------- +create_key "case10 model_max_budget" '{ + "key_alias": "seed-model-budgets", + "model_max_budget": {"gpt-4o": {"budget_limit": 0.5, "time_period": "30d"}} +}' + +# --- Case 11: per-model rpm/tpm limits ------------------------------------- +create_key "case11 per-model rpm/tpm" '{ + "key_alias": "seed-per-model-limits", + "model_rpm_limit": {"gpt-4o": 60, "claude-3-5-sonnet": 120}, + "model_tpm_limit": {"gpt-4o": 1000, "claude-3-5-sonnet": 2000} +}' + +# --- Case 12: metadata ----------------------------------------------------- +create_key "case12 metadata" '{ + "key_alias": "seed-metadata", + "metadata": {"team": "core-infra", "app": "app2", "email": "owner@example.com"} +}' + +# --- Case 13: tags (spend tracking / tag routing) -------------------------- +create_key "case13 tags" '{ + "key_alias": "seed-tags", + "tags": ["production", "team-platform", "tier-1"] +}' + +# --- Case 14: permissions (pii controls) ----------------------------------- +create_key "case14 permissions" '{ + "key_alias": "seed-permissions", + "permissions": 
{"allow_pii_controls": true} +}' + +# --- Case 15: blocked key -------------------------------------------------- +create_key "case15 blocked" '{ + "key_alias": "seed-blocked", + "blocked": true +}' + +# --- Case 16: allowed_routes ----------------------------------------------- +create_key "case16 allowed_routes" '{ + "key_alias": "seed-allowed-routes", + "allowed_routes": ["/chat/completions", "/embeddings"] +}' + +# --- Case 17: allowed_cache_controls --------------------------------------- +create_key "case17 allowed_cache_controls" '{ + "key_alias": "seed-cache-controls", + "allowed_cache_controls": ["no-cache", "no-store"] +}' + +# --- Case 18: key_type — read_only ----------------------------------------- +create_key "case18 key_type: read_only" '{ + "key_alias": "seed-read-only", + "key_type": "read_only" +}' + +# --- Case 19: key_type — management ---------------------------------------- +create_key "case19 key_type: management" '{ + "key_alias": "seed-management", + "key_type": "management" +}' + +# --- Case 20: config override ---------------------------------------------- +create_key "case20 config override" '{ + "key_alias": "seed-config", + "config": {"num_retries": 3} +}' + +echo "### Scoped keys (team / user / org / budget) ###" +echo + +# --- Case 21: scoped to a team --------------------------------------------- +create_key "case21 team-scoped" '{ + "key_alias": "seed-team-scoped", + "team_id": "'"${TEAM_ID}"'" +}' + +# --- Case 22: scoped to a user --------------------------------------------- +create_key "case22 user-scoped" '{ + "key_alias": "seed-user-scoped", + "user_id": "'"${USER_ID}"'" +}' + +# --- Case 23: scoped to an organization ------------------------------------ +create_key "case23 org-scoped" '{ + "key_alias": "seed-org-scoped", + "organization_id": "'"${ORG_ID}"'" +}' + +# --- Case 24: linked to a shared budget ------------------------------------ +create_key "case24 budget_id-linked" '{ + "key_alias": "seed-budget-linked", + 
"budget_id": "'"${BUDGET_ID}"'" +}' + +# --- Case 25: team + user together ----------------------------------------- +create_key "case25 team + user" '{ + "key_alias": "seed-team-user", + "team_id": "'"${TEAM_ID}"'", + "user_id": "'"${USER_ID}"'" +}' + +echo "### Enterprise-only cases (may be rejected on OSS builds) ###" +echo + +# --- Case 26: guardrails --------------------------------------------------- +create_key "case26 guardrails (enterprise)" '{ + "key_alias": "seed-guardrails", + "guardrails": ["my-pii-guard", "my-prompt-injection-guard"] +}' + +# --- Case 27: object_permission — vector stores ---------------------------- +create_key "case27 object_permission (enterprise)" '{ + "key_alias": "seed-object-permission", + "object_permission": {"vector_stores": ["vector_store_1", "vector_store_2"]} +}' + +# --- Case 28: enforced_params ---------------------------------------------- +create_key "case28 enforced_params (enterprise)" '{ + "key_alias": "seed-enforced-params", + "enforced_params": ["user", "metadata.generation_name"] +}' + +# --- Case 29: concurrent budget windows ------------------------------------ +create_key "case29 budget_limits windows (enterprise)" '{ + "key_alias": "seed-budget-windows", + "budget_limits": [ + {"budget_limit": 10.0, "time_period": "1d"}, + {"budget_limit": 50.0, "time_period": "7d"} + ] +}' + +# --- Case 30: auto-rotation ------------------------------------------------ +create_key "case30 auto_rotate (enterprise)" '{ + "key_alias": "seed-auto-rotate", + "auto_rotate": true, + "rotation_interval": "30d" +}' + +echo "### Kitchen sink ###" +echo + +# --- Case 31: everything together ------------------------------------------ +create_key "case31 kitchen sink" '{ + "key_alias": "seed-kitchen-sink", + "team_id": "'"${TEAM_ID}"'", + "user_id": "'"${USER_ID}"'", + "models": ["gpt-4o", "claude-3-5-sonnet"], + "aliases": {"gpt-4": "gpt-4o"}, + "max_budget": 200, + "soft_budget": 150, + "budget_duration": "30d", + "duration": "90d", + 
"tpm_limit": 4000, + "rpm_limit": 200, + "max_parallel_requests": 20, + "model_rpm_limit": {"gpt-4o": 100}, + "tags": ["production", "kitchen-sink"], + "metadata": {"owner": "platform", "tier": "premium"}, + "permissions": {"allow_pii_controls": true} +}' + +echo "Seeding complete." \ No newline at end of file diff --git a/scripts/litellm-to-bifrost/scripts/delete_bifrost_entities.sh b/scripts/litellm-to-bifrost/scripts/delete_bifrost_entities.sh new file mode 100755 index 0000000000..07d500d0e4 --- /dev/null +++ b/scripts/litellm-to-bifrost/scripts/delete_bifrost_entities.sh @@ -0,0 +1,280 @@ +#!/usr/bin/env bash +# +# delete_bifrost_entities.sh — delete all migration entities from a running Bifrost +# instance: virtual keys, model configs, teams, customers, users, provider keys +# and providers. +# +# Deletion order matters: virtual keys belong to teams/customers/users, model +# configs own budgets/rate limits, teams belong to customers, and provider keys +# belong to providers, so children are removed first. Deleting a user cascades +# its own governance settings, team memberships and access profiles. 
+# +# Usage: +# BIFROST_URL=http://localhost:8080 BIFROST_API_KEY= ./delete_bifrost_entities.sh +# +# DRY_RUN=1 ./delete_bifrost_entities.sh # list what would be deleted, delete nothing +# DELETE_VKS=0 ./delete_bifrost_entities.sh # skip virtual keys (default: 1) +# DELETE_MODEL_CONFIGS=0 ./delete_bifrost_entities.sh # skip model configs +# DELETE_TEAMS=0 ./delete_bifrost_entities.sh # skip teams +# DELETE_CUSTOMERS=0 ./delete_bifrost_entities.sh # skip customers +# DELETE_USERS=0 ./delete_bifrost_entities.sh # skip users +# DELETE_KEYS=0 ./delete_bifrost_entities.sh # skip provider keys +# DELETE_PROVIDERS=0 ./delete_bifrost_entities.sh # skip providers +# +# Endpoints used (Bifrost management API, all require a management bearer token): +# GET {url}/api/governance/virtual-keys DELETE {url}/api/governance/virtual-keys/{id} +# GET {url}/api/governance/model-configs DELETE {url}/api/governance/model-configs/{id} +# GET {url}/api/governance/teams DELETE {url}/api/governance/teams/{id} +# GET {url}/api/governance/customers DELETE {url}/api/governance/customers/{id} +# GET {url}/api/users DELETE {url}/api/users/{id} +# GET {url}/api/providers DELETE {url}/api/providers/{provider} +# GET {url}/api/providers/{provider}/keys DELETE {url}/api/providers/{provider}/keys/{key_id} +# +# Note: Bifrost exposes no batch delete for these entities, so each row is +# deleted individually. Deleting your own user is rejected by the API; such a +# failure is logged and counted, and does not abort the run. +# +# Requires: curl, jq. 
+
+set -euo pipefail
+
+BIFROST_URL="${BIFROST_URL:-http://localhost:8080}"
+BIFROST_API_KEY="${BIFROST_API_KEY:-}"
+DRY_RUN="${DRY_RUN:-0}"
+DELETE_VKS="${DELETE_VKS:-1}"
+DELETE_MODEL_CONFIGS="${DELETE_MODEL_CONFIGS:-1}"
+DELETE_TEAMS="${DELETE_TEAMS:-1}"
+DELETE_CUSTOMERS="${DELETE_CUSTOMERS:-1}"
+DELETE_USERS="${DELETE_USERS:-1}"
+DELETE_KEYS="${DELETE_KEYS:-1}"
+DELETE_PROVIDERS="${DELETE_PROVIDERS:-1}"
+
+# curl_args holds the options shared by every request. The Authorization header is
+# appended only when BIFROST_API_KEY is non-empty, so unauthenticated dev instances work.
+curl_args=(-sS --fail-with-body)
+if [[ -n "${BIFROST_API_KEY}" ]]; then
+  curl_args+=(--header "Authorization: Bearer ${BIFROST_API_KEY}")
+fi
+
+# get issues a GET for path $1 (appended to BIFROST_URL) and prints the response body.
+get() {
+  curl "${curl_args[@]}" --location "${BIFROST_URL}$1"
+}
+
+# urlenc percent-encodes $1 as one URL path segment using jq's @uri filter.
+urlenc() {
+  jq -rn --arg v "$1" '$v|@uri' # -n: never read stdin, which callers use to feed id lists
+}
+
+# delete_one issues a DELETE for path $2, prints an OK/FAIL line tagged with
+# label $1, and returns 0 on success or 1 on failure.
+delete_one() {
+  local label="$1" path="$2"
+  if curl "${curl_args[@]}" --request DELETE --location "${BIFROST_URL}${path}" >/dev/null; then
+    echo " OK deleted ${label}"
+    return 0
+  fi
+  echo " FAIL delete ${label}"
+  return 1
+}
+
+# delete_ids deletes every id read on stdin under path prefix $2 (label $1), or only prints
+# a dry-run line when DRY_RUN=1. Exit status is the failure count, capped at 255.
+delete_ids() {
+  local label="$1" prefix="$2" failed=0 id enc
+  while IFS= read -r id; do
+    [[ -z "${id}" ]] && continue
+    if [[ "${DRY_RUN}" == "1" ]]; then
+      echo " DRY-RUN would delete ${label} ${id}"
+      continue
+    fi
+    enc="$(urlenc "${id}")"
+    delete_one "${label} ${id}" "${prefix}${enc}" || failed=$((failed + 1))
+  done
+  return "$(( failed > 255 ? 255 : failed ))" # statuses wrap mod 256; 256 failures must not read as 0
+}
+
+# delete_all_virtual_keys pages through /api/governance/virtual-keys (limit /
+# offset) and deletes every virtual key individually.
+delete_all_virtual_keys() {
+  echo "==> Listing virtual keys"
+  local ids limit=100 offset=0 page batch
+  ids=""
+  while :; do
+    page="$(get "/api/governance/virtual-keys?limit=${limit}&offset=${offset}")" || { echo " FAIL list virtual keys"; return 1; } # errexit is off when called as `f || true`
+    batch="$(jq -r '.virtual_keys[]?.id' <<<"${page}")"
+    [[ -z "${batch}" ]] && break
+    ids+="${batch}"$'\n'
+    offset=$((offset + limit))
+  done
+
+  local count
+  count="$(grep -c . <<<"${ids}" || true)"
+  echo " found ${count} virtual key(s)"
+  [[ "${count}" -eq 0 ]] && return 0
+
+  local failed=0
+  delete_ids "virtual key" "/api/governance/virtual-keys/" <<<"${ids}" || failed=$?
+  [[ "${failed}" -gt 0 ]] && { echo " ${failed} virtual key(s) failed"; return 1; }
+  return 0
+}
+
+# delete_all_model_configs pages through /api/governance/model-configs and
+# deletes every model config individually (Bifrost also removes the budgets and
+# rate limits a config owns). Returns 1 if listing or any delete fails.
+delete_all_model_configs() {
+  echo "==> Listing model configs"
+  local ids limit=100 offset=0 page batch
+  ids=""
+  while :; do
+    page="$(get "/api/governance/model-configs?limit=${limit}&offset=${offset}")" || { echo " FAIL list model configs"; return 1; }
+    batch="$(jq -r '.model_configs[]?.id' <<<"${page}")"
+    [[ -z "${batch}" ]] && break
+    ids+="${batch}"$'\n'
+    offset=$((offset + limit))
+  done
+
+  local count
+  count="$(grep -c . <<<"${ids}" || true)"
+  echo " found ${count} model config(s)"
+  [[ "${count}" -eq 0 ]] && return 0
+
+  local failed=0
+  delete_ids "model config" "/api/governance/model-configs/" <<<"${ids}" || failed=$?
+  [[ "${failed}" -gt 0 ]] && { echo " ${failed} model config(s) failed"; return 1; }
+  return 0
+}
+
+# delete_all_teams deletes every team listed by /api/governance/teams (fetched in
+# one unpaginated request). Returns 1 if listing or any delete fails.
+delete_all_teams() {
+  echo "==> Listing teams"
+  local ids count
+  ids="$(get "/api/governance/teams" | jq -r '.teams[]?.id')" || { echo " FAIL list teams"; return 1; } # pipefail surfaces a failed GET
+  count="$(grep -c . <<<"${ids}" || true)"
+  echo " found ${count} team(s)"
+  [[ "${count}" -eq 0 ]] && return 0
+
+  local failed=0
+  delete_ids "team" "/api/governance/teams/" <<<"${ids}" || failed=$?
+  [[ "${failed}" -gt 0 ]] && { echo " ${failed} team(s) failed"; return 1; }
+  return 0
+}
+
+# delete_all_customers deletes every customer listed by /api/governance/customers
+# (fetched in one unpaginated request). Returns 1 if listing or any delete fails.
+delete_all_customers() {
+  echo "==> Listing customers"
+  local ids count
+  ids="$(get "/api/governance/customers" | jq -r '.customers[]?.id')" || { echo " FAIL list customers"; return 1; }
+  count="$(grep -c . <<<"${ids}" || true)"
+  echo " found ${count} customer(s)"
+  [[ "${count}" -eq 0 ]] && return 0
+
+  local failed=0
+  delete_ids "customer" "/api/governance/customers/" <<<"${ids}" || failed=$?
+  [[ "${failed}" -gt 0 ]] && { echo " ${failed} customer(s) failed"; return 1; }
+  return 0
+}
+
+# delete_all_users pages through /api/users (page / has_more) and deletes every
+# user individually. Deleting the caller's own user is rejected by the API and
+# counted as a failure. Returns 1 if listing or any delete fails.
+delete_all_users() {
+  echo "==> Listing users"
+  local ids="" limit=100 page=1 resp batch has_more
+  while :; do
+    resp="$(get "/api/users?page=${page}&limit=${limit}")" || { echo " FAIL list users"; return 1; }
+    batch="$(jq -r '.users[]?.id' <<<"${resp}")"
+    [[ -n "${batch}" ]] && ids+="${batch}"$'\n'
+    has_more="$(jq -r '.has_more // false' <<<"${resp}")"
+    [[ "${has_more}" != "true" ]] && break
+    page=$((page + 1))
+  done
+
+  local count
+  count="$(grep -c . <<<"${ids}" || true)"
+  echo " found ${count} user(s)"
+  [[ "${count}" -eq 0 ]] && return 0
+
+  local failed=0
+  delete_ids "user" "/api/users/" <<<"${ids}" || failed=$?
+  [[ "${failed}" -gt 0 ]] && { echo " ${failed} user(s) failed (a self-delete is expected to fail)"; return 1; }
+  return 0
+}
+
+# delete_all_provider_keys deletes every key under every configured provider.
+# Provider deletion usually cascades keys, but deleting keys first gives clearer
+# reset output and works when provider deletion is disabled.
+delete_all_provider_keys() {
+  echo "==> Listing provider keys"
+  local providers provider provider_path ids count total=0 failed=0 key_id key_path
+  providers="$(get "/api/providers" | jq -r '.providers[]?.name')" || { echo " FAIL list providers"; return 1; } # errexit is off when called as `f || true`
+  count="$(grep -c . <<<"${providers}" || true)"
+  echo " found ${count} provider(s) to inspect"
+  [[ "${count}" -eq 0 ]] && return 0
+
+  while IFS= read -r provider; do
+    [[ -z "${provider}" ]] && continue
+    provider_path="$(urlenc "${provider}")"
+    ids="$(get "/api/providers/${provider_path}/keys" | jq -r '.keys[]?.id')" || { echo " FAIL list keys for provider ${provider}"; failed=$((failed + 1)); continue; } # a failed listing counts as a failure
+    count="$(grep -c . <<<"${ids}" || true)"
+    echo " provider ${provider}: found ${count} key(s)"
+    total=$((total + count))
+    while IFS= read -r key_id; do
+      [[ -z "${key_id}" ]] && continue
+      key_path="$(urlenc "${key_id}")"
+      if [[ "${DRY_RUN}" == "1" ]]; then
+        echo " DRY-RUN would delete provider key ${provider}/${key_id}"
+        continue
+      fi
+      delete_one "provider key ${provider}/${key_id}" "/api/providers/${provider_path}/keys/${key_path}" || failed=$((failed + 1))
+    done <<<"${ids}"
+  done <<<"${providers}"
+
+  echo " found ${total} provider key(s)"
+  [[ "${failed}" -gt 0 ]] && { echo " ${failed} provider key(s) failed"; return 1; }
+  return 0
+}
+
+# delete_all_providers deletes every provider listed by /api/providers; returns 1 if listing or any delete fails.
+delete_all_providers() {
+  echo "==> Listing providers"
+  local providers count failed=0 provider provider_path
+  providers="$(get "/api/providers" | jq -r '.providers[]?.name')" || { echo " FAIL list providers"; return 1; }
+  count="$(grep -c . <<<"${providers}" || true)"
+  echo " found ${count} provider(s)"
+  [[ "${count}" -eq 0 ]] && return 0
+
+  while IFS= read -r provider; do
+    [[ -z "${provider}" ]] && continue
+    provider_path="$(urlenc "${provider}")"
+    if [[ "${DRY_RUN}" == "1" ]]; then
+      echo " DRY-RUN would delete provider ${provider}"
+      continue
+    fi
+    delete_one "provider ${provider}" "/api/providers/${provider_path}" || failed=$((failed + 1))
+  done <<<"${providers}"
+
+  [[ "${failed}" -gt 0 ]] && { echo " ${failed} provider(s) failed"; return 1; }
+  return 0
+}
+
+main() {
+  if [[ "${DRY_RUN}" == "1" ]]; then
+    echo "DRY-RUN mode: nothing will be deleted."
+    echo
+  fi
+
+  [[ "${DELETE_VKS}" == "1" ]] && { delete_all_virtual_keys || true; echo; }
+  [[ "${DELETE_MODEL_CONFIGS}" == "1" ]] && { delete_all_model_configs || true; echo; }
+  [[ "${DELETE_TEAMS}" == "1" ]] && { delete_all_teams || true; echo; }
+  [[ "${DELETE_CUSTOMERS}" == "1" ]] && { delete_all_customers || true; echo; }
+  [[ "${DELETE_USERS}" == "1" ]] && { delete_all_users || true; echo; }
+  [[ "${DELETE_KEYS}" == "1" ]] && { delete_all_provider_keys || true; echo; }
+  [[ "${DELETE_PROVIDERS}" == "1" ]] && { delete_all_providers || true; echo; }
+
+  echo "Cleanup complete."
+}
+
+main "$@"
\ No newline at end of file
diff --git a/scripts/litellm-to-bifrost/scripts/delete_entities.sh b/scripts/litellm-to-bifrost/scripts/delete_entities.sh
new file mode 100755
index 0000000000..87a6b898c2
--- /dev/null
+++ b/scripts/litellm-to-bifrost/scripts/delete_entities.sh
@@ -0,0 +1,268 @@
+#!/usr/bin/env bash
+#
+# delete_entities.sh — delete ALL entities from a running LiteLLM instance.
+# Covers, in dependency order: virtual keys, teams, organizations, budgets,
+# (DB-backed) models and credentials. Use this to reset a LiteLLM instance before
+# re-running the org->customer migration (go run ./scripts/migration/litellm).
+# +# Deletion order matters: keys and teams belong to organizations, so they are +# removed first; budgets orphaned by org deletion are swept up afterwards. +# +# Usage: +# LITELLM_URL=http://localhost:4000 LITELLM_MASTER_KEY=sk-1234 ./delete_entities.sh +# +# DRY_RUN=1 ./delete_entities.sh # list what would be deleted, delete nothing +# DELETE_KEYS=0 ./delete_entities.sh # skip virtual keys +# DELETE_TEAMS=0 ./delete_entities.sh # skip teams +# DELETE_ORGS=0 ./delete_entities.sh # skip organizations +# DELETE_BUDGETS=0 ./delete_entities.sh # skip budgets +# DELETE_MODELS=0 ./delete_entities.sh # skip models +# DELETE_CREDENTIALS=0 ./delete_entities.sh # skip credentials +# +# Endpoints used (LiteLLM management API): +# GET {url}/key/list POST {url}/key/delete (batch) +# GET {url}/team/list POST {url}/team/delete (batch) +# GET {url}/organization/list DELETE {url}/organization/delete (batch) +# GET {url}/budget/list POST {url}/budget/delete (one at a time) +# GET {url}/model/info POST {url}/model/delete (one at a time) +# GET {url}/credentials DELETE {url}/credentials/{name} (one at a time) +# +# Note: only DB-backed models (model_info.db_model == true) can be deleted via +# the API; models defined in config.yaml are skipped. +# +# Requires: curl, jq. + +set -euo pipefail + +LITELLM_URL="${LITELLM_URL:-http://localhost:4000}" +LITELLM_MASTER_KEY="${LITELLM_MASTER_KEY:-sk-1234}" +DRY_RUN="${DRY_RUN:-0}" +DELETE_KEYS="${DELETE_KEYS:-1}" +DELETE_TEAMS="${DELETE_TEAMS:-1}" +DELETE_ORGS="${DELETE_ORGS:-1}" +DELETE_BUDGETS="${DELETE_BUDGETS:-1}" +DELETE_MODELS="${DELETE_MODELS:-1}" +DELETE_CREDENTIALS="${DELETE_CREDENTIALS:-1}" + +AUTH_HEADER="Authorization: Bearer ${LITELLM_MASTER_KEY}" + +# get issues an authenticated GET and prints the response body. 
+get() {
+  curl -sS --fail-with-body --location "${LITELLM_URL}$1" --header "${AUTH_HEADER}"
+}
+
+# delete_all_keys collects every virtual key across all pages of /key/list and deletes
+# them in one batch /key/delete request. Returns 1 if a listing or the delete fails.
+delete_all_keys() {
+  echo "==> Listing virtual keys"
+  local page1 total_pages keys count p more
+  page1="$(get "/key/list?page=1")" || { echo " FAIL list virtual keys"; return 1; } # errexit is off when called as `f || true`
+  total_pages="$(jq -r '.total_pages // 1' <<<"${page1}")"
+
+  keys="$(jq -c '[.keys[]? | if type == "string" then . else (.token // .key // .key_name // empty) end]' <<<"${page1}")"
+  for ((p = 2; p <= total_pages; p++)); do
+    more="$(get "/key/list?page=${p}" | jq -c '[.keys[]? | if type == "string" then . else (.token // .key // .key_name // empty) end]')" || { echo " FAIL list virtual keys (page ${p})"; return 1; }
+    keys="$(jq -c --argjson more "${more}" '. + $more' <<<"${keys}")"
+  done
+
+  count="$(jq 'length' <<<"${keys}")"
+  echo " found ${count} virtual key(s)"
+  if [[ "${count}" -eq 0 ]]; then
+    return 0
+  fi
+
+  if [[ "${DRY_RUN}" == "1" ]]; then
+    echo " DRY-RUN would delete ${count} virtual key(s)"
+    return 0
+  fi
+
+  echo "==> Deleting ${count} virtual key(s)"
+  curl -sS --fail-with-body --request POST --location "${LITELLM_URL}/key/delete" \
+    --header "${AUTH_HEADER}" \
+    --header "Content-Type: application/json" \
+    --data "{\"keys\": ${keys}}" >/dev/null \
+    && echo " OK deleted ${count} virtual key(s)" \
+    || { echo " FAIL delete ${count} virtual key(s)"; return 1; }
+}
+
+# delete_all_teams lists every team and deletes them in one batch /team/delete request,
+# skipping null team ids. Returns 1 if the listing or the delete fails.
+delete_all_teams() {
+  echo "==> Listing teams"
+  local teams ids count
+  teams="$(get "/team/list")" || { echo " FAIL list teams"; return 1; }
+  ids="$(jq -c '[.[].team_id | select(. != null)]' <<<"${teams}")"
+  count="$(jq 'length' <<<"${ids}")"
+  echo " found ${count} team(s)"
+  if [[ "${count}" -eq 0 ]]; then
+    return 0
+  fi
+
+  jq -r '.[] | " team \(.)"' <<<"${ids}"
+
+  if [[ "${DRY_RUN}" == "1" ]]; then
+    echo " DRY-RUN would delete ${count} team(s)"
+    return 0
+  fi
+
+  echo "==> Deleting ${count} team(s)"
+  curl -sS --fail-with-body --request POST --location "${LITELLM_URL}/team/delete" \
+    --header "${AUTH_HEADER}" \
+    --header "Content-Type: application/json" \
+    --data "{\"team_ids\": ${ids}}" >/dev/null \
+    && echo " OK deleted ${count} team(s)" \
+    || { echo " FAIL delete ${count} team(s)"; return 1; }
+}
+
+# delete_all_organizations lists every organization and deletes them in one batch
+# /organization/delete request. Returns 1 if the listing or the delete fails.
+delete_all_organizations() {
+  echo "==> Listing organizations"
+  local orgs ids count
+  orgs="$(get "/organization/list")" || { echo " FAIL list organizations"; return 1; }
+  ids="$(jq -c '[.[].organization_id | select(. != null)]' <<<"${orgs}")"
+  count="$(jq 'length' <<<"${ids}")"
+  echo " found ${count} organization(s)"
+  if [[ "${count}" -eq 0 ]]; then
+    return 0
+  fi
+
+  jq -r '.[] | " org \(.organization_id) (\(.organization_alias // "-"))"' <<<"${orgs}"
+
+  if [[ "${DRY_RUN}" == "1" ]]; then
+    echo " DRY-RUN would delete ${count} organization(s)"
+    return 0
+  fi
+
+  echo "==> Deleting ${count} organization(s)"
+  curl -sS --fail-with-body --request DELETE --location "${LITELLM_URL}/organization/delete" \
+    --header "${AUTH_HEADER}" \
+    --header "Content-Type: application/json" \
+    --data "{\"organization_ids\": ${ids}}" >/dev/null \
+    && echo " OK deleted ${count} organization(s)" \
+    || { echo " FAIL delete ${count} organization(s)"; return 1; }
+}
+
+# delete_all_budgets enumerates every budget and deletes them one at a time;
+# LiteLLM exposes no batch delete for budgets.
+delete_all_budgets() {
+  echo "==> Listing budgets"
+  local budgets count failed=0 budget_id
+  budgets="$(get "/budget/list")" || { echo " FAIL list budgets"; return 1; } # errexit is off when called as `f || true`
+  count="$(jq 'length' <<<"${budgets}")"
+  echo " found ${count} budget(s)"
+  if [[ "${count}" -eq 0 ]]; then
+    return 0
+  fi
+
+  while IFS= read -r budget_id; do
+    [[ -z "${budget_id}" ]] && continue
+    if [[ "${DRY_RUN}" == "1" ]]; then
+      echo " DRY-RUN would delete budget ${budget_id}"
+      continue
+    fi
+    if curl -sS --fail-with-body --request POST --location "${LITELLM_URL}/budget/delete" \
+      --header "${AUTH_HEADER}" \
+      --header "Content-Type: application/json" \
+      --data "$(jq -n --arg id "${budget_id}" '{"id": $id}')" >/dev/null; then
+      echo " OK deleted budget ${budget_id}"
+    else
+      echo " FAIL delete budget ${budget_id}"
+      failed=$((failed + 1))
+    fi
+  done < <(jq -r '.[] | select(.budget_id != null) | .budget_id' <<<"${budgets}") # entries without a budget_id are skipped
+
+  if [[ "${failed}" -gt 0 ]]; then
+    echo " ${failed} budget(s) failed to delete"
+    return 1
+  fi
+}
+
+# delete_all_models enumerates DB-backed models and deletes them one at a time;
+# models defined in config.yaml (db_model == false) cannot be deleted via the
+# API and are skipped.
+delete_all_models() {
+  echo "==> Listing models"
+  local models db_ids count skipped failed=0 model_id
+  models="$(get "/model/info")" || { echo " FAIL list models"; return 1; } # errexit is off when called as `f || true`
+  db_ids="$(jq -c '[.data[] | select(.model_info.db_model == true and .model_info.id != null) | .model_info.id]' <<<"${models}")"
+  count="$(jq 'length' <<<"${db_ids}")"
+  skipped="$(jq '[.data[] | select(.model_info.db_model != true)] | length' <<<"${models}")" # anything not flagged db_model == true
+  echo " found ${count} DB-backed model(s) (${skipped} config model(s) skipped)"
+  if [[ "${count}" -eq 0 ]]; then
+    return 0
+  fi
+
+  while IFS= read -r model_id; do
+    [[ -z "${model_id}" ]] && continue
+    if [[ "${DRY_RUN}" == "1" ]]; then
+      echo " DRY-RUN would delete model ${model_id}"
+      continue
+    fi
+    if curl -sS --fail-with-body --request POST --location "${LITELLM_URL}/model/delete" \
+      --header "${AUTH_HEADER}" \
+      --header "Content-Type: application/json" \
+      --data "$(jq -n --arg id "${model_id}" '{"id": $id}')" >/dev/null; then
+      echo " OK deleted model ${model_id}"
+    else
+      echo " FAIL delete model ${model_id}"
+      failed=$((failed + 1))
+    fi
+  done < <(jq -r '.[]' <<<"${db_ids}")
+
+  if [[ "${failed}" -gt 0 ]]; then
+    echo " ${failed} model(s) failed to delete"
+    return 1
+  fi
+}
+
+# delete_all_credentials enumerates every stored credential and deletes them one
+# at a time via DELETE /credentials/{credential_name}.
+delete_all_credentials() {
+  echo "==> Listing credentials"
+  local credentials count failed=0 credential_name credential_path
+  credentials="$(get "/credentials")" || { echo " FAIL list credentials"; return 1; } # errexit is off when called as `f || true`
+  count="$(jq '.credentials // [] | length' <<<"${credentials}")"
+  echo " found ${count} credential(s)"
+  if [[ "${count}" -eq 0 ]]; then
+    return 0
+  fi
+
+  while IFS=$'\t' read -r credential_name credential_path; do
+    [[ -z "${credential_name}" ]] && continue
+    if [[ "${DRY_RUN}" == "1" ]]; then
+      echo " DRY-RUN would delete credential ${credential_name}"
+      continue
+    fi
+    if curl -sS --fail-with-body --request DELETE --location "${LITELLM_URL}/credentials/${credential_path}" \
+      --header "${AUTH_HEADER}" >/dev/null; then
+      echo " OK deleted credential ${credential_name}"
+    else
+      echo " FAIL delete credential ${credential_name}"
+      failed=$((failed + 1))
+    fi
+  done < <(jq -r '.credentials // [] | .[] | select(.credential_name != null) | [.credential_name, (.credential_name | @uri)] | @tsv' <<<"${credentials}") # columns: display name, URI-encoded path segment
+
+  if [[ "${failed}" -gt 0 ]]; then
+    echo " ${failed} credential(s) failed to delete"
+    return 1
+  fi
+}
+
+main() {
+  if [[ "${DRY_RUN}" == "1" ]]; then
+    echo "DRY-RUN mode: nothing will be deleted."
+    echo
+  fi
+
+  [[ "${DELETE_KEYS}" == "1" ]] && { delete_all_keys || true; echo; }
+  [[ "${DELETE_TEAMS}" == "1" ]] && { delete_all_teams || true; echo; }
+  [[ "${DELETE_ORGS}" == "1" ]] && { delete_all_organizations || true; echo; }
+  [[ "${DELETE_BUDGETS}" == "1" ]] && { delete_all_budgets || true; echo; }
+  [[ "${DELETE_MODELS}" == "1" ]] && { delete_all_models || true; echo; }
+  [[ "${DELETE_CREDENTIALS}" == "1" ]] && { delete_all_credentials || true; echo; }
+
+  echo "Cleanup complete."
+} + +main "$@" diff --git a/scripts/litellm-to-bifrost/scripts/seed_models.sh b/scripts/litellm-to-bifrost/scripts/seed_models.sh new file mode 100755 index 0000000000..4806f0a64c --- /dev/null +++ b/scripts/litellm-to-bifrost/scripts/seed_models.sh @@ -0,0 +1,225 @@ +#!/usr/bin/env bash +# +# seed_models.sh - seed the comprehensive model fixture into a running LiteLLM +# proxy via the management API (POST /model/new), instead of via config.yaml. +# +# Models added via /model/new are stored in the DB (db_model: true) and are +# therefore deletable via delete_entities.sh in this directory. +# +# Usage: +# LITELLM_URL=http://localhost:4000 LITELLM_MASTER_KEY=sk-1234 ./seed_models.sh +# +# DRY_RUN=1 ./seed_models.sh # print payloads, POST nothing +# +# The proxy must run with `general_settings.store_model_in_db: True` and a DB +# connection, otherwise /model/new returns 500. +# +# Requires: curl, jq. + +set -euo pipefail + +LITELLM_URL="${LITELLM_URL:-http://localhost:4000}" +LITELLM_MASTER_KEY="${LITELLM_MASTER_KEY:-sk-1234}" +DRY_RUN="${DRY_RUN:-0}" + +added=0 +failed=0 + +# add_model posts a single Deployment to /model/new. Args: +# $1 = human-readable description (logged only) +# $2 = JSON body ({model_name, litellm_params, model_info?}) +# In DRY_RUN mode it pretty-prints the body instead of sending it. +add_model() { + local desc="$1" body="$2" + + if ! jq -e . >/dev/null 2>&1 <<<"${body}"; then + echo " SKIP ${desc}: invalid JSON payload" + failed=$((failed + 1)) + return + fi + + if [[ "${DRY_RUN}" == "1" ]]; then + echo " DRY-RUN ${desc}:" + jq -c . 
<<<"${body}" + added=$((added + 1)) + return + fi + + if curl -sS --fail-with-body \ + --request POST "${LITELLM_URL}/model/new" \ + --header "Authorization: Bearer ${LITELLM_MASTER_KEY}" \ + --header "Content-Type: application/json" \ + --data "${body}" >/dev/null; then + echo " OK ${desc}" + added=$((added + 1)) + else + echo " FAIL ${desc}" + failed=$((failed + 1)) + fi +} + +# ----------------------------------------------------------------------------- +# CASE 1 — Simple chat/completion, one representative model per provider. +# ----------------------------------------------------------------------------- +echo "==> Case 1: simple chat per provider" +add_model "openai chat" '{"model_name":"openai-gpt-4o","litellm_params":{"model":"openai/gpt-4o","api_key":"os.environ/OPENAI_API_KEY"}}' +add_model "anthropic chat" '{"model_name":"anthropic-claude","litellm_params":{"model":"anthropic/claude-sonnet-4-20250514","api_key":"os.environ/ANTHROPIC_API_KEY"}}' +add_model "gemini chat" '{"model_name":"gemini-flash","litellm_params":{"model":"gemini/gemini-2.5-flash","api_key":"os.environ/GEMINI_API_KEY"}}' +add_model "xai chat" '{"model_name":"xai-grok","litellm_params":{"model":"xai/grok-3","api_key":"os.environ/XAI_API_KEY"}}' +add_model "zai chat" '{"model_name":"zai-glm","litellm_params":{"model":"zai/glm-4.6","api_key":"os.environ/ZAI_API_KEY"}}' +add_model "mistral chat" '{"model_name":"mistral-large","litellm_params":{"model":"mistral/mistral-large-latest","api_key":"os.environ/MISTRAL_API_KEY"}}' +add_model "codestral chat" '{"model_name":"codestral","litellm_params":{"model":"codestral/codestral-latest","api_key":"os.environ/CODESTRAL_API_KEY"}}' +add_model "deepseek chat" '{"model_name":"deepseek-chat","litellm_params":{"model":"deepseek/deepseek-chat","api_key":"os.environ/DEEPSEEK_API_KEY"}}' +add_model "groq chat" '{"model_name":"groq-llama","litellm_params":{"model":"groq/llama-3.3-70b-versatile","api_key":"os.environ/GROQ_API_KEY"}}' +add_model "cerebras 
chat" '{"model_name":"cerebras-llama","litellm_params":{"model":"cerebras/llama-3.3-70b","api_key":"os.environ/CEREBRAS_API_KEY"}}' +add_model "cohere chat" '{"model_name":"cohere-command","litellm_params":{"model":"cohere_chat/command-r-plus","api_key":"os.environ/COHERE_API_KEY"}}' +add_model "together chat" '{"model_name":"together-llama","litellm_params":{"model":"together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo","api_key":"os.environ/TOGETHERAI_API_KEY"}}' +add_model "fireworks chat" '{"model_name":"fireworks-llama","litellm_params":{"model":"fireworks_ai/accounts/fireworks/models/llama-v3p3-70b-instruct","api_key":"os.environ/FIREWORKS_AI_API_KEY"}}' +add_model "openrouter chat" '{"model_name":"openrouter-gpt","litellm_params":{"model":"openrouter/openai/gpt-4o","api_key":"os.environ/OPENROUTER_API_KEY"}}' +add_model "perplexity chat" '{"model_name":"perplexity-sonar","litellm_params":{"model":"perplexity/sonar","api_key":"os.environ/PERPLEXITYAI_API_KEY"}}' +add_model "deepinfra chat" '{"model_name":"deepinfra-llama","litellm_params":{"model":"deepinfra/meta-llama/Meta-Llama-3.1-70B-Instruct","api_key":"os.environ/DEEPINFRA_API_KEY"}}' +add_model "sambanova chat" '{"model_name":"sambanova-llama","litellm_params":{"model":"sambanova/Meta-Llama-3.3-70B-Instruct","api_key":"os.environ/SAMBANOVA_API_KEY"}}' +add_model "nvidia nim chat" '{"model_name":"nvidia-nim-llama","litellm_params":{"model":"nvidia_nim/meta/llama-3.1-70b-instruct","api_key":"os.environ/NVIDIA_NIM_API_KEY"}}' +add_model "ai21 chat" '{"model_name":"ai21-jamba","litellm_params":{"model":"ai21_chat/jamba-1.5-large","api_key":"os.environ/AI21_API_KEY"}}' +add_model "replicate chat" '{"model_name":"replicate-llama","litellm_params":{"model":"replicate/meta/meta-llama-3-70b-instruct","api_key":"os.environ/REPLICATE_API_KEY"}}' +add_model "huggingface chat" 
'{"model_name":"huggingface-llama","litellm_params":{"model":"huggingface/meta-llama/Llama-3.1-8B-Instruct","api_key":"os.environ/HUGGINGFACE_API_KEY"}}' +add_model "databricks chat" '{"model_name":"databricks-dbrx","litellm_params":{"model":"databricks/databricks-dbrx-instruct","api_key":"os.environ/DATABRICKS_API_KEY","api_base":"os.environ/DATABRICKS_API_BASE"}}' +add_model "predibase chat" '{"model_name":"predibase-llama","litellm_params":{"model":"predibase/llama-3-8b-instruct","api_key":"os.environ/PREDIBASE_API_KEY","tenant_id":"os.environ/PREDIBASE_TENANT_ID"}}' +add_model "nscale chat" '{"model_name":"nscale-llama","litellm_params":{"model":"nscale/meta-llama/Llama-3.3-70B-Instruct","api_key":"os.environ/NSCALE_API_KEY"}}' +add_model "novita chat" '{"model_name":"novita-llama","litellm_params":{"model":"novita/meta-llama/llama-3.1-70b-instruct","api_key":"os.environ/NOVITA_API_KEY"}}' +add_model "friendliai chat" '{"model_name":"friendliai-llama","litellm_params":{"model":"friendliai/meta-llama-3.1-70b-instruct","api_key":"os.environ/FRIENDLI_TOKEN"}}' +add_model "featherless chat" '{"model_name":"featherless-llama","litellm_params":{"model":"featherless_ai/meta-llama/Meta-Llama-3.1-8B-Instruct","api_key":"os.environ/FEATHERLESS_AI_API_KEY"}}' +add_model "github chat" '{"model_name":"github-gpt","litellm_params":{"model":"github/gpt-4o","api_key":"os.environ/GITHUB_API_KEY"}}' +add_model "github copilot" '{"model_name":"github-copilot-gpt","litellm_params":{"model":"github_copilot/gpt-4o"}}' +add_model "cloudflare chat" '{"model_name":"cloudflare-llama","litellm_params":{"model":"cloudflare/@cf/meta/llama-3.1-8b-instruct","api_key":"os.environ/CLOUDFLARE_API_KEY","api_base":"os.environ/CLOUDFLARE_API_BASE"}}' +add_model "gigachat chat" '{"model_name":"gigachat-pro","litellm_params":{"model":"gigachat/GigaChat-Pro","api_key":"os.environ/GIGACHAT_API_KEY"}}' +add_model "moonshot chat" 
'{"model_name":"moonshot-v1","litellm_params":{"model":"moonshot/moonshot-v1-8k","api_key":"os.environ/MOONSHOT_API_KEY"}}' +add_model "dashscope chat" '{"model_name":"dashscope-qwen","litellm_params":{"model":"dashscope/qwen-max","api_key":"os.environ/DASHSCOPE_API_KEY"}}' +add_model "volcengine chat" '{"model_name":"volcengine-doubao","litellm_params":{"model":"volcengine/doubao-pro-4k","api_key":"os.environ/VOLCENGINE_API_KEY"}}' +add_model "nebius chat" '{"model_name":"nebius-llama","litellm_params":{"model":"nebius/meta-llama/Meta-Llama-3.1-70B-Instruct","api_key":"os.environ/NEBIUS_API_KEY"}}' +add_model "hyperbolic chat" '{"model_name":"hyperbolic-llama","litellm_params":{"model":"hyperbolic/meta-llama/Meta-Llama-3.1-70B-Instruct","api_key":"os.environ/HYPERBOLIC_API_KEY"}}' +add_model "lambda chat" '{"model_name":"lambda-llama","litellm_params":{"model":"lambda_ai/llama3.1-70b-instruct-fp8","api_key":"os.environ/LAMBDA_API_KEY"}}' +add_model "v0 chat" '{"model_name":"v0-md","litellm_params":{"model":"v0/v0-1.5-md","api_key":"os.environ/V0_API_KEY"}}' +add_model "morph chat" '{"model_name":"morph-large","litellm_params":{"model":"morph/morph-v3-large","api_key":"os.environ/MORPH_API_KEY"}}' +add_model "inception chat" '{"model_name":"inception-mercury","litellm_params":{"model":"inception/mercury","api_key":"os.environ/INCEPTION_API_KEY"}}' +add_model "maritalk chat" '{"model_name":"maritalk-sabia","litellm_params":{"model":"maritalk/sabia-3","api_key":"os.environ/MARITALK_API_KEY"}}' +add_model "meta llama api" '{"model_name":"meta-llama-api","litellm_params":{"model":"meta_llama/Llama-3.3-70B-Instruct","api_key":"os.environ/LLAMA_API_KEY"}}' +add_model "galadriel chat" '{"model_name":"galadriel-llama","litellm_params":{"model":"galadriel/llama3.1","api_key":"os.environ/GALADRIEL_API_KEY"}}' +add_model "aiml chat" '{"model_name":"aiml-gpt","litellm_params":{"model":"aiml/gpt-4o","api_key":"os.environ/AIML_API_KEY"}}' +add_model "cometapi chat" 
'{"model_name":"cometapi-gpt","litellm_params":{"model":"cometapi/gpt-4o","api_key":"os.environ/COMETAPI_KEY"}}' +add_model "vercel gateway" '{"model_name":"vercel-gateway-gpt","litellm_params":{"model":"vercel_ai_gateway/openai/gpt-4o","api_key":"os.environ/VERCEL_AI_GATEWAY_API_KEY"}}' +add_model "nano-gpt chat" '{"model_name":"nano-gpt","litellm_params":{"model":"nano-gpt/gpt-4o","api_key":"os.environ/NANO_GPT_API_KEY"}}' +add_model "chutes chat" '{"model_name":"chutes-llama","litellm_params":{"model":"chutes/meta-llama/Llama-3.3-70B-Instruct","api_key":"os.environ/CHUTES_API_KEY"}}' +add_model "minimax chat" '{"model_name":"minimax-text","litellm_params":{"model":"minimax/MiniMax-Text-01","api_key":"os.environ/MINIMAX_API_KEY"}}' +add_model "baseten chat" '{"model_name":"baseten-llama","litellm_params":{"model":"baseten/llama-3-70b-instruct","api_key":"os.environ/BASETEN_API_KEY"}}' + +# ----------------------------------------------------------------------------- +# CASE 2 — Self-hosted / OpenAI-compatible endpoints (require api_base). 
+# ----------------------------------------------------------------------------- +echo "==> Case 2: self-hosted / OpenAI-compatible" +add_model "ollama" '{"model_name":"ollama-llama","litellm_params":{"model":"ollama_chat/llama3.1","api_base":"os.environ/OLLAMA_API_BASE"}}' +add_model "hosted vllm" '{"model_name":"vllm-llama","litellm_params":{"model":"hosted_vllm/meta-llama/Llama-3.1-8B-Instruct","api_base":"os.environ/HOSTED_VLLM_API_BASE"}}' +add_model "lm studio" '{"model_name":"lmstudio-llama","litellm_params":{"model":"lm_studio/llama-3.1-8b-instruct","api_base":"os.environ/LM_STUDIO_API_BASE"}}' +add_model "llamafile" '{"model_name":"llamafile-local","litellm_params":{"model":"llamafile/local-model","api_base":"os.environ/LLAMAFILE_API_BASE"}}' +add_model "triton" '{"model_name":"triton-model","litellm_params":{"model":"triton/ensemble","api_base":"os.environ/TRITON_API_BASE"}}' +add_model "xinference" '{"model_name":"xinference-llama","litellm_params":{"model":"xinference/llama-3-instruct","api_base":"os.environ/XINFERENCE_API_BASE"}}' +add_model "openai-compat" '{"model_name":"openai-compatible-custom","litellm_params":{"model":"openai/my-self-hosted-model","api_key":"os.environ/CUSTOM_LLM_API_KEY","api_base":"os.environ/CUSTOM_LLM_API_BASE"}}' + +# ----------------------------------------------------------------------------- +# CASE 3 — Provider-specific credential shapes. 
+# ----------------------------------------------------------------------------- +echo "==> Case 3: provider-specific credentials" +add_model "vertex ai" '{"model_name":"vertex-gemini","litellm_params":{"model":"vertex_ai/gemini-2.5-pro","vertex_project":"os.environ/VERTEX_PROJECT","vertex_location":"os.environ/VERTEX_LOCATION","vertex_credentials":"os.environ/VERTEX_CREDENTIALS"}}' +add_model "bedrock keys" '{"model_name":"bedrock-claude","litellm_params":{"model":"bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0","aws_access_key_id":"os.environ/AWS_ACCESS_KEY_ID","aws_secret_access_key":"os.environ/AWS_SECRET_ACCESS_KEY","aws_region_name":"os.environ/AWS_REGION_NAME"}}' +add_model "bedrock role" '{"model_name":"bedrock-claude-role","litellm_params":{"model":"bedrock/converse/anthropic.claude-3-5-sonnet-20241022-v2:0","aws_role_name":"os.environ/AWS_ROLE_NAME","aws_session_name":"litellm-session","aws_region_name":"os.environ/AWS_REGION_NAME"}}' +add_model "sagemaker" '{"model_name":"sagemaker-llama","litellm_params":{"model":"sagemaker/jumpstart-dft-meta-textgeneration-llama-3-8b","aws_region_name":"os.environ/AWS_REGION_NAME","input_cost_per_second":0.000420}}' +add_model "watsonx" '{"model_name":"watsonx-llama","litellm_params":{"model":"watsonx/meta-llama/llama-3-3-70b-instruct","api_key":"os.environ/WATSONX_API_KEY","api_base":"os.environ/WATSONX_URL","project_id":"os.environ/WATSONX_PROJECT_ID"}}' +add_model "oci" '{"model_name":"oci-cohere","litellm_params":{"model":"oci/cohere.command-r-plus","oci_region":"os.environ/OCI_REGION","oci_user":"os.environ/OCI_USER","oci_tenancy":"os.environ/OCI_TENANCY","oci_fingerprint":"os.environ/OCI_FINGERPRINT","oci_key_file":"os.environ/OCI_KEY_FILE"}}' +add_model "snowflake" '{"model_name":"snowflake-llama","litellm_params":{"model":"snowflake/llama3.1-70b","api_base":"os.environ/SNOWFLAKE_API_BASE","api_key":"os.environ/SNOWFLAKE_JWT"}}' + +# 
----------------------------------------------------------------------------- +# CASE 4 — Embeddings (mode: embedding). +# ----------------------------------------------------------------------------- +echo "==> Case 4: embeddings" +add_model "openai embed" '{"model_name":"text-embedding-3-small","litellm_params":{"model":"openai/text-embedding-3-small","api_key":"os.environ/OPENAI_API_KEY"},"model_info":{"mode":"embedding","base_model":"text-embedding-3-small"}}' +add_model "cohere embed" '{"model_name":"cohere-embed","litellm_params":{"model":"cohere/embed-english-v3.0","api_key":"os.environ/COHERE_API_KEY"},"model_info":{"mode":"embedding"}}' +add_model "mistral embed" '{"model_name":"mistral-embed","litellm_params":{"model":"mistral/mistral-embed","api_key":"os.environ/MISTRAL_API_KEY"},"model_info":{"mode":"embedding"}}' +add_model "voyage embed" '{"model_name":"voyage-embed","litellm_params":{"model":"voyage/voyage-3","api_key":"os.environ/VOYAGE_API_KEY"},"model_info":{"mode":"embedding"}}' +add_model "jina embed" '{"model_name":"jina-embed","litellm_params":{"model":"jina_ai/jina-embeddings-v3","api_key":"os.environ/JINA_AI_API_KEY"},"model_info":{"mode":"embedding"}}' +add_model "bedrock embed" '{"model_name":"bedrock-titan-embed","litellm_params":{"model":"bedrock/amazon.titan-embed-text-v2:0","aws_region_name":"os.environ/AWS_REGION_NAME"},"model_info":{"mode":"embedding"}}' +add_model "vertex embed" '{"model_name":"vertex-embed","litellm_params":{"model":"vertex_ai/text-embedding-005","vertex_project":"os.environ/VERTEX_PROJECT","vertex_location":"os.environ/VERTEX_LOCATION"},"model_info":{"mode":"embedding"}}' +add_model "infinity embed" '{"model_name":"infinity-embed","litellm_params":{"model":"infinity/BAAI/bge-small-en-v1.5","api_base":"os.environ/INFINITY_API_BASE"},"model_info":{"mode":"embedding"}}' + +# ----------------------------------------------------------------------------- +# CASE 5 — Image generation (mode: image_generation). 
+# ----------------------------------------------------------------------------- +echo "==> Case 5: image generation" +add_model "openai image" '{"model_name":"gpt-image-1","litellm_params":{"model":"openai/gpt-image-1","api_key":"os.environ/OPENAI_API_KEY"},"model_info":{"mode":"image_generation"}}' +add_model "vertex imagen" '{"model_name":"vertex-imagen","litellm_params":{"model":"vertex_ai/imagen-3.0-generate-002","vertex_project":"os.environ/VERTEX_PROJECT","vertex_location":"os.environ/VERTEX_LOCATION"},"model_info":{"mode":"image_generation"}}' +add_model "bedrock sd" '{"model_name":"bedrock-sd","litellm_params":{"model":"bedrock/stability.sd3-large-v1:0","aws_region_name":"os.environ/AWS_REGION_NAME"},"model_info":{"mode":"image_generation"}}' +add_model "flux pro" '{"model_name":"flux-pro","litellm_params":{"model":"black_forest_labs/flux-pro-1.1","api_key":"os.environ/BFL_API_KEY"},"model_info":{"mode":"image_generation"}}' +add_model "recraft image" '{"model_name":"recraft-image","litellm_params":{"model":"recraft/recraftv3","api_key":"os.environ/RECRAFT_API_KEY"},"model_info":{"mode":"image_generation"}}' +add_model "xai image" '{"model_name":"xai-image","litellm_params":{"model":"xai/grok-2-image","api_key":"os.environ/XAI_API_KEY"},"model_info":{"mode":"image_generation"}}' + +# ----------------------------------------------------------------------------- +# CASE 6 — Audio transcription (audio_transcription) & TTS (audio_speech). 
+# ----------------------------------------------------------------------------- +echo "==> Case 6: audio (transcription + speech)" +add_model "openai whisper" '{"model_name":"whisper","litellm_params":{"model":"openai/whisper-1","api_key":"os.environ/OPENAI_API_KEY"},"model_info":{"mode":"audio_transcription"}}' +add_model "groq whisper" '{"model_name":"groq-whisper","litellm_params":{"model":"groq/whisper-large-v3","api_key":"os.environ/GROQ_API_KEY"},"model_info":{"mode":"audio_transcription"}}' +add_model "deepgram" '{"model_name":"deepgram-nova","litellm_params":{"model":"deepgram/nova-3","api_key":"os.environ/DEEPGRAM_API_KEY"},"model_info":{"mode":"audio_transcription"}}' +add_model "assemblyai" '{"model_name":"assemblyai-best","litellm_params":{"model":"assemblyai/best","api_key":"os.environ/ASSEMBLYAI_API_KEY"},"model_info":{"mode":"audio_transcription"}}' +add_model "openai tts" '{"model_name":"openai-tts","litellm_params":{"model":"openai/tts-1","api_key":"os.environ/OPENAI_API_KEY"},"model_info":{"mode":"audio_speech"}}' +add_model "elevenlabs tts" '{"model_name":"elevenlabs-tts","litellm_params":{"model":"elevenlabs/eleven_turbo_v2_5","api_key":"os.environ/ELEVENLABS_API_KEY"},"model_info":{"mode":"audio_speech"}}' + +# ----------------------------------------------------------------------------- +# CASE 7 — Rerank (mode: rerank). 
+# ----------------------------------------------------------------------------- +echo "==> Case 7: rerank" +add_model "cohere rerank" '{"model_name":"cohere-rerank","litellm_params":{"model":"cohere/rerank-english-v3.0","api_key":"os.environ/COHERE_API_KEY"},"model_info":{"mode":"rerank"}}' +add_model "jina rerank" '{"model_name":"jina-rerank","litellm_params":{"model":"jina_ai/jina-reranker-v2-base-multilingual","api_key":"os.environ/JINA_AI_API_KEY"},"model_info":{"mode":"rerank"}}' +add_model "bedrock rerank" '{"model_name":"bedrock-rerank","litellm_params":{"model":"bedrock/amazon.rerank-v1:0","aws_region_name":"os.environ/AWS_REGION_NAME"},"model_info":{"mode":"rerank"}}' +add_model "voyage rerank" '{"model_name":"voyage-rerank","litellm_params":{"model":"voyage/rerank-2","api_key":"os.environ/VOYAGE_API_KEY"},"model_info":{"mode":"rerank"}}' +add_model "infinity rerank" '{"model_name":"infinity-rerank","litellm_params":{"model":"infinity/BAAI/bge-reranker-base","api_base":"os.environ/INFINITY_API_BASE"},"model_info":{"mode":"rerank"}}' + +# ----------------------------------------------------------------------------- +# CASE 8 — Other modes: moderation, completion, responses, realtime. 
+# ----------------------------------------------------------------------------- +echo "==> Case 8: moderation / completion / responses / realtime" +add_model "moderation" '{"model_name":"omni-moderation","litellm_params":{"model":"openai/omni-moderation-latest","api_key":"os.environ/OPENAI_API_KEY"},"model_info":{"mode":"moderation"}}' +add_model "text completion" '{"model_name":"gpt-instruct","litellm_params":{"model":"text-completion-openai/gpt-3.5-turbo-instruct","api_key":"os.environ/OPENAI_API_KEY"},"model_info":{"mode":"completion"}}' +add_model "responses api" '{"model_name":"o1-responses","litellm_params":{"model":"openai/o1-pro","api_key":"os.environ/OPENAI_API_KEY"},"model_info":{"mode":"responses"}}' +add_model "openai realtime" '{"model_name":"openai-realtime","litellm_params":{"model":"openai/gpt-4o-realtime-preview","api_key":"os.environ/OPENAI_API_KEY"},"model_info":{"mode":"realtime"}}' + +# ----------------------------------------------------------------------------- +# CASE 9 — Wildcard / pass-through routing. 
+# ----------------------------------------------------------------------------- +echo "==> Case 9: wildcard routing" +add_model "catch-all" '{"model_name":"*","litellm_params":{"model":"openai/*","api_key":"os.environ/OPENAI_API_KEY"}}' +add_model "anthropic/*" '{"model_name":"anthropic/*","litellm_params":{"model":"anthropic/*","api_key":"os.environ/ANTHROPIC_API_KEY"}}' +add_model "bedrock/*" '{"model_name":"bedrock/*","litellm_params":{"model":"bedrock/*","aws_region_name":"os.environ/AWS_REGION_NAME"}}' +add_model "groq/*" '{"model_name":"groq/*","litellm_params":{"model":"groq/*","api_key":"os.environ/GROQ_API_KEY"}}' +add_model "vertex_ai/*" '{"model_name":"vertex_ai/*","litellm_params":{"model":"vertex_ai/*","vertex_project":"os.environ/VERTEX_PROJECT","vertex_location":"os.environ/VERTEX_LOCATION"}}' + +# ----------------------------------------------------------------------------- +# CASE 10 — Multi-deployment load balancing (one model_name, many backends). +# ----------------------------------------------------------------------------- +echo "==> Case 10: load-balanced deployments" +add_model "balanced openai" '{"model_name":"gpt-4-balanced","litellm_params":{"model":"openai/gpt-4o","api_key":"os.environ/OPENAI_API_KEY","rpm":480},"model_info":{"id":"balanced-openai"}}' + +# ----------------------------------------------------------------------------- +# CASE 11 — Metadata / params: explicit id, base_model, region, limits, +# timeouts, retries, custom pricing, access groups, tags. 
+# ----------------------------------------------------------------------------- +echo "==> Case 11: metadata & params" +add_model "pinned id" '{"model_name":"gpt-4o-pinned-id","litellm_params":{"model":"openai/gpt-4o","api_key":"os.environ/OPENAI_API_KEY"},"model_info":{"id":"my-stable-model-id-001","base_model":"gpt-4o"}}' +add_model "region eu" '{"model_name":"gpt-4o-eu","litellm_params":{"model":"openai/gpt-4o","api_key":"os.environ/OPENAI_API_KEY","region_name":"eu"}}' +add_model "throttled" '{"model_name":"gpt-4o-throttled","litellm_params":{"model":"openai/gpt-4o","api_key":"os.environ/OPENAI_API_KEY","rpm":100,"tpm":100000,"timeout":300,"stream_timeout":60,"max_retries":3}}' +add_model "custom pricing" '{"model_name":"custom-priced-model","litellm_params":{"model":"openai/gpt-4o","api_key":"os.environ/OPENAI_API_KEY","input_cost_per_token":0.0000025,"output_cost_per_token":0.00001}}' +add_model "access groups" '{"model_name":"gpt-4o-restricted","litellm_params":{"model":"openai/gpt-4o","api_key":"os.environ/OPENAI_API_KEY"},"model_info":{"access_groups":["beta-users","internal"],"tags":["production","team-platform"]}}' + +echo +echo "Seeding complete: ${added} model(s) ${DRY_RUN:+(dry-run) }processed, ${failed} failed." + +[[ "${failed}" -gt 0 ]] && exit 1 +exit 0 \ No newline at end of file diff --git a/scripts/litellm-to-bifrost/team.go b/scripts/litellm-to-bifrost/team.go new file mode 100644 index 0000000000..f24f35806d --- /dev/null +++ b/scripts/litellm-to-bifrost/team.go @@ -0,0 +1,51 @@ +package main + +import ( + "fmt" + "strings" + + "github.com/maximhq/bifrost/scripts/litellm-to-bifrost/litellm" +) + +// LiteLLMTeamToBifrostTeam transforms a LiteLLM team into a Bifrost +// create-team request. The mapping is pure (no I/O); customerID is resolved by +// the caller (team.organization_id -> org alias -> Bifrost customer id) and is +// nil for a standalone team or one whose customer could not be resolved. 
+// +// Mapping (mirrors LiteLLMOrganizationToBifrostCustomer; budget/rate-limit +// fields are inline on the team rather than in a budget table): +// - team_alias -> name (required) +// - max_budget -> budgets[0].max_limit (omitted when <= 0) +// - budget_duration -> budgets[0].reset_duration ("mo" -> "M") +// - tpm_limit -> rate_limit.token_max_limit (reset "1m") +// - rpm_limit -> rate_limit.request_max_limit (reset "1m") +// - organization_id -> customer_id (via the resolved customerID) +func LiteLLMTeamToBifrostTeam(team litellm.LiteLLMTeam, customerID *string, cfg MigrationRunConfig) (*BifrostCreateTeamRequest, error) { + name := strings.TrimSpace(team.TeamAlias) + if name == "" { + return nil, fmt.Errorf("team %q has no team_alias; a team name is required", team.TeamID) + } + + req := &BifrostCreateTeamRequest{Name: name, CustomerID: customerID} + + // A team's budget/rate-limit fields are inline; adapt them to the same + // LiteLLMBudget shape the organization migration maps, so the spend-cap and + // rate-limit rules stay identical across entities. + b := litellm.LiteLLMBudget{ + MaxBudget: team.MaxBudget, + BudgetDuration: team.BudgetDuration, + TPMLimit: team.TPMLimit, + RPMLimit: team.RPMLimit, + } + + budget, err := toBudget(b, cfg.MaxBudgetPeriod) + if err != nil { + return nil, fmt.Errorf("team %q: %w", team.TeamID, err) + } + if budget != nil { + req.Budgets = []BifrostCreateBudgetRequest{*budget} + } + req.RateLimit = toRateLimit(b) + + return req, nil +} diff --git a/scripts/litellm-to-bifrost/user.go b/scripts/litellm-to-bifrost/user.go new file mode 100644 index 0000000000..233bc4af67 --- /dev/null +++ b/scripts/litellm-to-bifrost/user.go @@ -0,0 +1,81 @@ +package main + +import ( + "fmt" + "strings" + + "github.com/maximhq/bifrost/scripts/litellm-to-bifrost/litellm" +) + +// UserPlan is the planned Bifrost user (POST /api/users) plus the LiteLLM team +// ids it should be linked to. 
Team-id -> Bifrost-team resolution and the +// membership writes are done by the caller (orchestration), keeping the +// transform pure. +type UserPlan struct { + UserID string // LiteLLM user_id, for logging + Name string // Bifrost user name (required, non-empty) + Email string // Bifrost user email (required, valid) + SourceTeamIDs []string // LiteLLM team_ids to link as memberships +} + +// UserMigrationReport collects everything the transform could not faithfully +// carry, so the operator can act on it. +type UserMigrationReport struct { + SkippedNoEmail []string // LiteLLM has no email; Bifrost requires one + DroppedRoles []string // LiteLLM user_role has no Bifrost numeric role_id mapping + DroppedBudgets []string // user-level spend cap (no field on POST /api/users) + DroppedRateLimits []string // user-level tpm/rpm (no field on POST /api/users) +} + +// LiteLLMUsersToBifrostUsers transforms LiteLLM internal users into Bifrost user +// plans. The mapping is pure (no I/O): +// - user_email -> email (required; user skipped + reported when missing) +// - user_alias -> name (falls back to email when blank) +// - teams -> SourceTeamIDs (linked downstream) +// +// LiteLLM user_role, max_budget/budget_duration and tpm/rpm have no field on +// Bifrost's create-user API (role is a numeric role_id with no LiteLLM mapping; +// user governance is driven by access profiles), so they are reported, not +// carried. 
+func LiteLLMUsersToBifrostUsers(users []litellm.LiteLLMUser) ([]UserPlan, UserMigrationReport) { + var plans []UserPlan + var report UserMigrationReport + + for _, u := range users { + email := "" + if u.UserEmail != nil { + email = strings.TrimSpace(*u.UserEmail) + } + if email == "" { + report.SkippedNoEmail = append(report.SkippedNoEmail, u.UserID) + continue + } + + name := "" + if u.UserAlias != nil { + name = strings.TrimSpace(*u.UserAlias) + } + if name == "" { + name = email + } + + if u.UserRole != nil && strings.TrimSpace(*u.UserRole) != "" { + report.DroppedRoles = append(report.DroppedRoles, fmt.Sprintf("%s (%s)", u.UserID, strings.TrimSpace(*u.UserRole))) + } + if u.MaxBudget != nil && *u.MaxBudget > 0 { + report.DroppedBudgets = append(report.DroppedBudgets, fmt.Sprintf("%s ($%.2f)", u.UserID, *u.MaxBudget)) + } + if (u.TPMLimit != nil && *u.TPMLimit > 0) || (u.RPMLimit != nil && *u.RPMLimit > 0) { + report.DroppedRateLimits = append(report.DroppedRateLimits, u.UserID) + } + + plans = append(plans, UserPlan{ + UserID: u.UserID, + Name: name, + Email: email, + SourceTeamIDs: u.Teams, + }) + } + + return plans, report +} diff --git a/scripts/litellm-to-bifrost/utils.go b/scripts/litellm-to-bifrost/utils.go new file mode 100644 index 0000000000..2ecc3e108b --- /dev/null +++ b/scripts/litellm-to-bifrost/utils.go @@ -0,0 +1,18 @@ +package main + +import "strings" + +// trimModelPrefix strips the "provider/" prefix from a LiteLLM model string. +// wildcards like "*", "openai/*", "bedrock/*" collapse to "*". +func trimModelPrefix(model string) string { + if i := strings.Index(model, "/"); i >= 0 { + model = model[i+1:] + } + if model == "*" { + return "*" + } + return model +} + +// ptr returns a pointer to v. 
+func ptr[T any](v T) *T { return &v } diff --git a/scripts/litellm-to-bifrost/virtualkey.go b/scripts/litellm-to-bifrost/virtualkey.go new file mode 100644 index 0000000000..606e6b9425 --- /dev/null +++ b/scripts/litellm-to-bifrost/virtualkey.go @@ -0,0 +1,229 @@ +package main + +import ( + "fmt" + "sort" + "strings" + + "github.com/maximhq/bifrost/scripts/litellm-to-bifrost/litellm" +) + +// VKProviderConfig is the planned per-provider key attachment for a VK. +// KeyIDs ["*"] grants all keys of that provider; specific UUIDs limit to those +// keys. AllowedModels is always ["*"] (no model restriction on the VK). +type VKProviderConfig struct { + Provider string + KeyIDs []string // ["*"] = all keys; specific UUIDs = only those keys +} + +// VKPlan is the planned Bifrost virtual key. Owner ids are LiteLLM ids; the +// caller resolves them to Bifrost team/customer ids (mutually exclusive) before +// the write. The remaining fields are informational for the migration report. +type VKPlan struct { + Name string + OwnerTeamID *string // LiteLLM team_id (resolve -> Bifrost team) + OwnerOrgID *string // LiteLLM org_id (resolve -> Bifrost customer) + ProviderConfigs []VKProviderConfig + Budget *BifrostCreateBudgetRequest + RateLimit *BifrostCreateRateLimitRequest + IsActive *bool // set false when the LiteLLM key is blocked + + // Informational (reported, not carried onto the VK): + SourceUserID *string // LiteLLM user_id — no VK-create field for it + UnmappedModels []string // model names not found in the model index + Expiry *string // LiteLLM expiry — no VK-create field for it +} + +// llmAllModelSentinels are LiteLLM model-list tokens that mean "all models". +var llmAllModelSentinels = map[string]bool{ + "*": true, "all-proxy-models": true, "all-team-models": true, +} + +// LiteLLMVirtualKeyToBifrost transforms a LiteLLM virtual key into a Bifrost VK +// plan. The mapping is pure (no I/O). 
+// +// Keys: the VK's effective model list (intersection of the key, team and org restrictions) is used to +// select which provider keys to attach. Keys whose model list matches are +// attached; wildcard keys (models=["*"]) are also included for every provider +// that already has a matched key. When the intersection is empty or contains an +// "all models" sentinel, all providers are granted via key_ids=["*"]. +// No model restriction is set on the VK — all models are always allowed. +// +// Owner: team_id wins over org_id (Bifrost VK ownership is team XOR customer); +// user_id has no VK-create field and is reported instead. +func LiteLLMVirtualKeyToBifrost(key litellm.LiteLLMVirtualKey, teamModels, orgModels []string, keyModelIdx map[string][]ProviderKeyRef, wildcardKeys []ProviderKeyRef, allProviders []string, maxBudgetPeriod string) (*VKPlan, error) { + name := "" + if key.KeyAlias != nil { + name = strings.TrimSpace(*key.KeyAlias) + } + if name == "" && key.KeyName != nil { + name = strings.TrimSpace(*key.KeyName) + } + if name == "" { + return nil, fmt.Errorf("virtual key has no key_alias or key_name; a name is required") + } + + plan := &VKPlan{Name: name, SourceUserID: key.UserID} + + if key.TeamID != nil && strings.TrimSpace(*key.TeamID) != "" { + plan.OwnerTeamID = key.TeamID + } else if key.OrgID != nil && strings.TrimSpace(*key.OrgID) != "" { + plan.OwnerOrgID = key.OrgID + } + + // Effective model set is the intersection of all non-empty, non-sentinel + // allowed_models lists across the VK, team, and org layers, mirroring + // LiteLLM's own enforcement: a request must satisfy every layer that has a + // restriction. A layer with an empty list or an "all models" sentinel + // imposes no restriction and is skipped.
+ names := intersectModelNames(key.Models, teamModels, orgModels) + plan.ProviderConfigs, plan.UnmappedModels = resolveProviderKeyConfigs(names, keyModelIdx, wildcardKeys, allProviders) + + b := litellm.LiteLLMBudget{ + MaxBudget: key.MaxBudget, + BudgetDuration: key.BudgetDuration, + TPMLimit: key.TPMLimit, + RPMLimit: key.RPMLimit, + } + budget, err := toBudget(b, maxBudgetPeriod) + if err != nil { + return nil, fmt.Errorf("virtual key %q: %w", name, err) + } + plan.Budget = budget + plan.RateLimit = toRateLimit(b) + + if key.Blocked != nil && *key.Blocked { + inactive := false + plan.IsActive = &inactive + } + + if key.Expires != nil && strings.TrimSpace(*key.Expires) != "" { + plan.Expiry = key.Expires + } + + return plan, nil +} + +// resolveProviderKeyConfigs turns a set of LiteLLM model names into Bifrost VK +// provider configs by selecting which provider keys serve those models. +// +// An empty set (no restriction) or an "all models" sentinel grants key_ids=["*"] +// on every provider in allProviders (allows all keys, all models). +// +// For a specific model list: +// - Keys explicitly modelling those names are attached by UUID. +// - Wildcard keys (models=["*"]) of any provider already matched are also +// attached, since they can serve any request. +// - Models that resolve to no key are returned as unmapped (reported, not granted). +// +// No AllowedModels restriction is set — all models are allowed on attached keys. +func resolveProviderKeyConfigs(names []string, keyModelIdx map[string][]ProviderKeyRef, wildcardKeys []ProviderKeyRef, allProviders []string) ([]VKProviderConfig, []string) { + if len(names) == 0 || hasAllSentinel(names) { + return allProvidersAllKeys(allProviders), nil + } + + // perProvider accumulates UUIDs of keys to attach, keyed by provider. 
+ perProvider := map[string]map[string]bool{} // provider → set of keyIDs + var provOrder []string + var unmapped []string + + for _, n := range names { + refs := keyModelIdx[n] + if len(refs) == 0 { + unmapped = append(unmapped, n) + continue + } + for _, r := range refs { + if perProvider[r.Provider] == nil { + perProvider[r.Provider] = map[string]bool{} + provOrder = append(provOrder, r.Provider) + } + perProvider[r.Provider][r.KeyID] = true + } + } + // Attach wildcard keys only for providers already matched above. + for _, wk := range wildcardKeys { + if perProvider[wk.Provider] != nil { + perProvider[wk.Provider][wk.KeyID] = true + } + } + + sort.Strings(provOrder) + + configs := make([]VKProviderConfig, 0, len(provOrder)) + for _, p := range provOrder { + configs = append(configs, VKProviderConfig{ + Provider: p, + KeyIDs: keysSorted(perProvider[p]), + }) + } + return configs, unmapped +} + +func hasAllSentinel(names []string) bool { + for _, n := range names { + if llmAllModelSentinels[strings.TrimSpace(n)] { + return true + } + } + return false +} + +func allProvidersAllKeys(providers []string) []VKProviderConfig { + configs := make([]VKProviderConfig, 0, len(providers)) + for _, p := range providers { + configs = append(configs, VKProviderConfig{Provider: p, KeyIDs: []string{"*"}}) + } + return configs +} + +// intersectModelNames returns the intersection of all non-empty, non-sentinel +// model lists. A list that is empty or contains any "all models" sentinel +// imposes no restriction at that level and is skipped entirely. If all levels +// are unrestricted, or the restricted lists share no model, nil is returned (caller maps this to "all models"; NOTE(review): disjoint restrictions therefore widen to all models — confirm intended).
+func intersectModelNames(lists ...[]string) []string { + var base map[string]bool + for _, l := range lists { + var active []string + for _, m := range l { + if t := strings.TrimSpace(m); t != "" { + active = append(active, t) + } + } + if len(active) == 0 || hasAllSentinel(active) { + continue + } + if base == nil { + base = make(map[string]bool, len(active)) + for _, m := range active { + base[m] = true + } + } else { + next := make(map[string]bool) + for _, m := range active { + if base[m] { + next[m] = true + } + } + base = next + } + } + if len(base) == 0 { + return nil + } + out := make([]string, 0, len(base)) + for m := range base { + out = append(out, m) + } + sort.Strings(out) + return out +} + +func keysSorted(set map[string]bool) []string { + out := make([]string, 0, len(set)) + for k := range set { + out = append(out, k) + } + sort.Strings(out) + return out +} \ No newline at end of file diff --git a/tests/cmd/e2eseed/go.mod b/tests/cmd/e2eseed/go.mod index 88fcb84760..fe7f6b2a1b 100644 --- a/tests/cmd/e2eseed/go.mod +++ b/tests/cmd/e2eseed/go.mod @@ -80,7 +80,7 @@ require ( github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-sqlite3 v1.14.32 // indirect - github.com/maximhq/bifrost/core v1.5.20 // indirect + github.com/maximhq/bifrost/core v1.5.21 // indirect github.com/maximhq/bifrost/framework v1.3.16 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect diff --git a/tests/cmd/seed/go.mod b/tests/cmd/seed/go.mod index 1700989cfc..edccf3d7f0 100644 --- a/tests/cmd/seed/go.mod +++ b/tests/cmd/seed/go.mod @@ -8,7 +8,7 @@ replace ( ) require ( - github.com/maximhq/bifrost/core v1.5.20 + github.com/maximhq/bifrost/core v1.5.21 github.com/maximhq/bifrost/framework v1.3.16 gorm.io/driver/postgres v1.6.0 gorm.io/driver/sqlite v1.6.0 diff --git a/tests/cmd/seedvks/go.mod 
b/tests/cmd/seedvks/go.mod index d944218601..f60ad8f9bb 100644 --- a/tests/cmd/seedvks/go.mod +++ b/tests/cmd/seedvks/go.mod @@ -9,7 +9,7 @@ replace ( require ( github.com/google/uuid v1.6.0 - github.com/maximhq/bifrost/core v1.5.20 + github.com/maximhq/bifrost/core v1.5.21 github.com/maximhq/bifrost/framework v1.3.16 gorm.io/driver/postgres v1.6.0 gorm.io/gorm v1.31.1 diff --git a/tests/e2e/api/collections/bifrost-api-management.postman_collection.json b/tests/e2e/api/collections/bifrost-api-management.postman_collection.json index cac0b54c30..9a7662282e 100644 --- a/tests/e2e/api/collections/bifrost-api-management.postman_collection.json +++ b/tests/e2e/api/collections/bifrost-api-management.postman_collection.json @@ -2331,6 +2331,57 @@ } } }, + { + "name": "Search Logs", + "event": [ + { + "listen": "test", + "script": { + "exec": [ + "// Content search across logs should return a well-formed paginated response", + "pm.test('Search Logs returns 200', function () {", + " pm.response.to.have.status(200);", + "});", + "pm.test('Search Logs response has logs array and pagination', function () {", + " var jsonData = pm.response.json();", + " pm.expect(jsonData).to.have.property('logs');", + " pm.expect(jsonData.logs).to.be.an('array');", + " pm.expect(jsonData).to.have.property('pagination');", + "});" + ], + "type": "text/javascript" + } + } + ], + "request": { + "method": "GET", + "header": [], + "url": { + "raw": "{{base_url}}/api/logs?content_search=hello&limit=10&offset=0", + "host": [ + "{{base_url}}" + ], + "path": [ + "api", + "logs" + ], + "query": [ + { + "key": "content_search", + "value": "hello" + }, + { + "key": "limit", + "value": "10" + }, + { + "key": "offset", + "value": "0" + } + ] + } + } + }, { "name": "Get Logs Stats", "request": { diff --git a/transports/bifrost-http/handlers/config.go b/transports/bifrost-http/handlers/config.go index ef9253e9cb..766901038b 100644 --- a/transports/bifrost-http/handlers/config.go +++ 
b/transports/bifrost-http/handlers/config.go @@ -497,10 +497,41 @@ func (h *ConfigHandler) updateConfig(ctx *fasthttp.RequestCtx) { updatedConfig.RoutingChainMaxDepth = payload.ClientConfig.RoutingChainMaxDepth } - // Update external base URLs for OAuth server metadata and client redirect_uri (nil clears each override). + // Update external base URL for OAuth client redirect_uri (nil clears the override). // Validation is performed up front in this handler so a failure here cannot leave the process in a partial state. updatedConfig.MCPExternalClientURL = payload.ClientConfig.MCPExternalClientURL + // Validate the inbound MCP auth mode against the allowed enum before persisting + // (config.schema.json is the source of truth: headers | both | oauth). + switch payload.ClientConfig.MCPServerAuthMode { + case "", configstoreTables.MCPServerAuthModeHeaders, configstoreTables.MCPServerAuthModeBoth, configstoreTables.MCPServerAuthModeOAuth: + // valid; empty means the field was omitted from a partial update + default: + SendError(ctx, fasthttp.StatusBadRequest, "mcp_server_auth_mode must be one of: headers, both, oauth") + return + } + + // oauth2_server_config only applies when discovery is enabled (both | oauth). + // Evaluate against the effective mode so a partial update that supplies only + // the config cannot smuggle it in while the stored mode is headers. + effectiveAuthMode := payload.ClientConfig.MCPServerAuthMode + if effectiveAuthMode == "" { + effectiveAuthMode = currentConfig.MCPServerAuthMode + } + if payload.ClientConfig.OAuth2ServerConfig != nil && effectiveAuthMode == configstoreTables.MCPServerAuthModeHeaders { + SendError(ctx, fasthttp.StatusBadRequest, "oauth2_server_config is only valid when mcp_server_auth_mode is both or oauth") + return + } + + // Only update each field when explicitly provided so partial /api/config + // payloads do not clear stored values (matches the MCP field handling above). 
+ if payload.ClientConfig.MCPServerAuthMode != "" { + updatedConfig.MCPServerAuthMode = payload.ClientConfig.MCPServerAuthMode + } + if payload.ClientConfig.OAuth2ServerConfig != nil { + updatedConfig.OAuth2ServerConfig = payload.ClientConfig.OAuth2ServerConfig + } + // Handle HeaderFilterConfig changes if !headerFilterConfigEqual(payload.ClientConfig.HeaderFilterConfig, currentConfig.HeaderFilterConfig) { // Validate that no security headers are in the allowlist or denylist diff --git a/transports/bifrost-http/handlers/middlewares.go b/transports/bifrost-http/handlers/middlewares.go index dc0e4b4e08..b9a978e8a8 100644 --- a/transports/bifrost-http/handlers/middlewares.go +++ b/transports/bifrost-http/handlers/middlewares.go @@ -862,6 +862,11 @@ func (m *AuthMiddleware) APIMiddleware() schemas.BifrostHTTPMiddleware { // credentials securely. Management endpoints under /api/skills (without // /serve/) remain authenticated. "/api/skills/serve/", + // OAuth2 discovery endpoints (RFC 8414 AS metadata, RFC 9728 protected + // resource metadata, RFC 7517 JWKS) must be reachable without auth so + // clients can bootstrap the flow. Each handler still gates availability + // behind discoveryEnabled() and serves 404 when OAuth mode is off. 
+ "/.well-known/", } return m.middleware(func(authConfig *configstore.AuthConfig, url string) bool { if slices.Contains(systemWhitelistedRoutes, url) || diff --git a/transports/bifrost-http/handlers/oauth2_discovery.go b/transports/bifrost-http/handlers/oauth2_discovery.go new file mode 100644 index 0000000000..3610e4fa6e --- /dev/null +++ b/transports/bifrost-http/handlers/oauth2_discovery.go @@ -0,0 +1,182 @@ +package handlers + +import ( + "crypto/rsa" + "crypto/x509" + "encoding/base64" + "encoding/pem" + "fmt" + "math/big" + + "github.com/bytedance/sonic" + "github.com/fasthttp/router" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + "github.com/valyala/fasthttp" +) + +// OAuth2DiscoveryHandler serves the three well-known discovery endpoints that +// make Bifrost's /mcp endpoint a spec-compliant OAuth 2.1 protected resource: +// +// - GET /.well-known/oauth-protected-resource[/{path}] (RFC 9728 PRM) +// - GET /.well-known/oauth-authorization-server[/{path}] (RFC 8414 AS metadata) +// - GET /.well-known/jwks.json (RFC 7517 JWKS) +// +// All three return 404 when MCPServerAuthMode == "headers" (the default), so +// discovery is only available when the operator explicitly enables OAuth mode. +// Discoverability is the feature toggle. +type OAuth2DiscoveryHandler struct { + store *lib.Config +} + +// NewOAuth2DiscoveryHandler creates a new discovery handler. +func NewOAuth2DiscoveryHandler(store *lib.Config) *OAuth2DiscoveryHandler { + return &OAuth2DiscoveryHandler{store: store} +} + +// RegisterRoutes wires all well-known discovery routes. Routes are always +// registered; individual handlers 404 when discovery is disabled. +func (h *OAuth2DiscoveryHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { + // RFC 9728: both root and path-aware well-known forms are required. 
+ r.GET("/.well-known/oauth-protected-resource", lib.ChainMiddlewares(h.handlePRM, middlewares...)) + r.GET("/.well-known/oauth-protected-resource/{path:*}", lib.ChainMiddlewares(h.handlePRM, middlewares...)) + + // RFC 8414: same two forms. + r.GET("/.well-known/oauth-authorization-server", lib.ChainMiddlewares(h.handleASMetadata, middlewares...)) + r.GET("/.well-known/oauth-authorization-server/{path:*}", lib.ChainMiddlewares(h.handleASMetadata, middlewares...)) + + // RFC 7517 JWKS. + r.GET("/.well-known/jwks.json", lib.ChainMiddlewares(h.handleJWKS, middlewares...)) +} + +// discoveryEnabled reports whether OAuth discovery is active, reading the mode +// from the in-memory ClientConfig under the read lock. +func (h *OAuth2DiscoveryHandler) discoveryEnabled() bool { + h.store.Mu.RLock() + enabled := h.store.ClientConfig.IsMCPOAuthDiscoveryEnabled() + h.store.Mu.RUnlock() + return enabled +} + +// handlePRM serves GET /.well-known/oauth-protected-resource[/{path}]. +// RFC 9728 §3 Protected Resource Metadata. +func (h *OAuth2DiscoveryHandler) handlePRM(ctx *fasthttp.RequestCtx) { + if !h.discoveryEnabled() { + ctx.SetStatusCode(fasthttp.StatusNotFound) + return + } + + base := oauth2IssuerURL(ctx, h.store) + doc := map[string]any{ + "resource": oauth2MCPResourceURL(ctx, h.store), + "authorization_servers": []string{base}, + "scopes_supported": []string{"mcp"}, + "bearer_methods_supported": []string{"header"}, + } + data, err := sonic.Marshal(doc) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("marshal protected resource metadata: %v", err)) + return + } + ctx.SetContentType("application/json") + ctx.SetBody(data) +} + +// handleASMetadata serves GET /.well-known/oauth-authorization-server[/{path}]. +// RFC 8414 Authorization Server Metadata. 
+func (h *OAuth2DiscoveryHandler) handleASMetadata(ctx *fasthttp.RequestCtx) { + if !h.discoveryEnabled() { + ctx.SetStatusCode(fasthttp.StatusNotFound) + return + } + + base := oauth2IssuerURL(ctx, h.store) + doc := map[string]any{ + "issuer": base, + "authorization_endpoint": base + "/oauth2/authorize", + "token_endpoint": base + "/oauth2/token", + "registration_endpoint": base + "/oauth2/register", + "jwks_uri": base + "/.well-known/jwks.json", + "response_types_supported": []string{"code"}, + "grant_types_supported": []string{"authorization_code", "refresh_token"}, + "code_challenge_methods_supported": []string{"S256"}, + "token_endpoint_auth_methods_supported": []string{"none"}, + "scopes_supported": []string{"mcp"}, + // RFC 9207: we include iss in authorization responses. + "authorization_response_iss_parameter_supported": true, + } + data, err := sonic.Marshal(doc) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("marshal authorization server metadata: %v", err)) + return + } + ctx.SetContentType("application/json") + ctx.SetBody(data) +} + +// handleJWKS serves GET /.well-known/jwks.json (RFC 7517). 
+func (h *OAuth2DiscoveryHandler) handleJWKS(ctx *fasthttp.RequestCtx) { + if !h.discoveryEnabled() { + ctx.SetStatusCode(fasthttp.StatusNotFound) + return + } + + if h.store.ConfigStore == nil { + ctx.SetContentType("application/json") + ctx.SetBodyString(`{"keys":[]}`) + return + } + + key, err := h.store.ConfigStore.GetOAuth2SigningKey(ctx) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to load signing key: %v", err)) + return + } + + var jwks []map[string]any + if key != nil { + pub, parseErr := parseRSAPublicKeyPEM(key.PublicKeyPEM) + if parseErr != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to parse signing key: %v", parseErr)) + return + } + jwks = []map[string]any{rsaPublicKeyToJWK(key.KID, "RS256", pub)} + } + + data, err := sonic.Marshal(map[string]any{"keys": jwks}) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("marshal jwks: %v", err)) + return + } + ctx.SetContentType("application/json") + ctx.SetBody(data) +} + +// parseRSAPublicKeyPEM decodes a PEM-encoded RSA public key. +func parseRSAPublicKeyPEM(pemStr string) (*rsa.PublicKey, error) { + block, rest := pem.Decode([]byte(pemStr)) + if block == nil || len(rest) > 0 { + return nil, fmt.Errorf("malformed public key PEM") + } + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("parse public key: %w", err) + } + rsaPub, ok := pub.(*rsa.PublicKey) + if !ok { + return nil, fmt.Errorf("expected RSA public key, got %T", pub) + } + return rsaPub, nil +} + +// rsaPublicKeyToJWK encodes an RSA public key as a JWK (RFC 7517 §6.3). 
+func rsaPublicKeyToJWK(kid, alg string, pub *rsa.PublicKey) map[string]any { + return map[string]any{ + "kty": "RSA", + "use": "sig", + "kid": kid, + "alg": alg, + "n": base64.RawURLEncoding.EncodeToString(pub.N.Bytes()), + "e": base64.RawURLEncoding.EncodeToString(big.NewInt(int64(pub.E)).Bytes()), + } +} diff --git a/transports/bifrost-http/handlers/oauth2_utils.go b/transports/bifrost-http/handlers/oauth2_utils.go new file mode 100644 index 0000000000..6a66cb54a5 --- /dev/null +++ b/transports/bifrost-http/handlers/oauth2_utils.go @@ -0,0 +1,45 @@ +package handlers + +import ( + "strings" + + configtables "github.com/maximhq/bifrost/framework/configstore/tables" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + "github.com/valyala/fasthttp" +) + +// oauth2IssuerURL resolves the effective AS issuer URL for a request. +// Uses the explicitly configured IssuerURL when set; falls back to deriving +// it from the request Host header for single-host / dev deployments. +func oauth2IssuerURL(ctx *fasthttp.RequestCtx, store *lib.Config) string { + store.Mu.RLock() + cfg := store.ClientConfig.OAuth2ServerConfig + store.Mu.RUnlock() + if cfg != nil && cfg.IssuerURL.IsSet() { + return cfg.IssuerURL.GetValue() + } + return lib.BuildBaseURL(ctx, "") +} + +// oauth2MCPResourceURL returns the canonical RFC 8707 resource identifier for +// the /mcp endpoint — the single protected resource this server issues tokens +// for. Discovery advertises it, the authorize endpoint pins the request's +// resource parameter to it, and /mcp token verification checks the audience +// against it; routing all three through here keeps them from drifting. +func oauth2MCPResourceURL(ctx *fasthttp.RequestCtx, store *lib.Config) string { + // Trim a trailing slash so a slash-suffixed issuer_url can't produce "//mcp" + // and drift this canonical resource away from what clients normalize to. 
+ return strings.TrimRight(oauth2IssuerURL(ctx, store), "/") + "/mcp" +} + +// oauth2ServerCfg returns the OAuth2 AS-specific config under the read lock, +// falling back to sensible defaults when not yet configured. +func oauth2ServerCfg(store *lib.Config) *configtables.OAuth2ServerConfig { + store.Mu.RLock() + cfg := store.ClientConfig.OAuth2ServerConfig + store.Mu.RUnlock() + if cfg == nil { + return configtables.DefaultOAuth2ServerConfig() + } + return cfg +} diff --git a/transports/bifrost-http/handlers/plugins.go b/transports/bifrost-http/handlers/plugins.go index b74399600e..25fe26fddf 100644 --- a/transports/bifrost-http/handlers/plugins.go +++ b/transports/bifrost-http/handlers/plugins.go @@ -561,36 +561,71 @@ func (h *PluginsHandler) deletePlugin(ctx *fasthttp.RequestCtx) { }) } -// restoreRedactedFromExisting walks the incoming config map and, for any field that -// looks like an EnvVar object whose value ShouldPreserveStored (i.e. it is a redacted -// placeholder like "***"), replaces the value with the corresponding field from the -// existing DB config. This mirrors the mergeUpdatedKey pattern used by provider keys. +// restoreRedactedFromExisting walks the incoming config map and, for any field whose +// value is a redacted placeholder (a masked EnvVar object, or a masked plain string), +// replaces it with the corresponding value from the existing DB +// config so client-side redaction never overwrites real credentials. It descends into +// nested maps AND slices (e.g. the OTEL `profiles` array), and handles header values that +// are stored as plain strings rather than EnvVar objects. Mirrors the mergeUpdatedKey +// pattern used by provider keys. 
func restoreRedactedFromExisting(incoming, existing map[string]any) map[string]any { if len(incoming) == 0 { return incoming } result := make(map[string]any, len(incoming)) for k, v := range incoming { - switch val := v.(type) { - case map[string]any: - if isEnvVarObject(val) { - ev := schemas.NewEnvVar(marshalEnvVarObject(val)) - if ev.ShouldPreserveStored() { - if existingVal, ok := existing[k]; ok { - result[k] = existingVal - continue - } - } - } else if existingNested, ok := existing[k].(map[string]any); ok { - result[k] = restoreRedactedFromExisting(val, existingNested) - continue + result[k] = restoreRedactedValue(v, existing[k]) + } + return result +} + +// restoreRedactedValue restores a single incoming value against its corresponding existing +// value. It recurses through maps and slices, and treats both EnvVar-shaped objects and +// plain redacted strings as placeholders to swap back to the stored original. Returns the +// incoming value unchanged when it is not a redaction placeholder or has no stored match. +func restoreRedactedValue(incoming, existing any) any { + switch val := incoming.(type) { + case map[string]any: + if isEnvVarObject(val) { + if schemas.NewEnvVar(marshalEnvVarObject(val)).ShouldPreserveStored() && existing != nil { + return existing } - result[k] = val - default: - result[k] = v + return val + } + if existingNested, ok := existing.(map[string]any); ok { + return restoreRedactedFromExisting(val, existingNested) + } + return val + case []any: + // Restore element-by-element against the existing slice (index-aligned). New + // elements beyond the existing length carry user-supplied values, so keep them. + existingSlice, ok := existing.([]any) + if !ok { + return val } + out := make([]any, len(val)) + for i, item := range val { + if i < len(existingSlice) { + out[i] = restoreRedactedValue(item, existingSlice[i]) + } else { + out[i] = item + } + } + return out + case string: + // Plain-string secrets (e.g. 
OTEL headers): restore only when the incoming string + // is a redaction artifact and not an intentional env reference. Empty strings are + // left as-is so clearing a value works. + if existingStr, ok := existing.(string); ok { + envVal := schemas.NewEnvVar(val) + if !envVal.IsFromEnv() && envVal.IsRedacted() { + return existingStr + } + } + return val + default: + return incoming } - return result } // isEnvVarObject returns true if m has exactly the shape of a serialised EnvVar: diff --git a/transports/bifrost-http/handlers/plugins_test.go b/transports/bifrost-http/handlers/plugins_test.go index e9d6580844..83efb8f42f 100644 --- a/transports/bifrost-http/handlers/plugins_test.go +++ b/transports/bifrost-http/handlers/plugins_test.go @@ -77,6 +77,69 @@ func buildUpdateRequest(t *testing.T, body any) *fasthttp.RequestCtx { // config over the existing DB config, preserving fields the caller did not send. // This is critical for the plugin_span_filter field: the OTEL config form in the // UI does not send plugin_span_filter, so it must survive a save without being wiped. +// TestRestoreRedacted_OTELProfilesHeaders covers the two gaps that broke OTEL header +// round-trips after the multi-profile change: (1) headers live inside the `profiles` +// array (slice traversal), and (2) header values are plain redacted strings, not EnvVar +// objects. Saving a config whose headers came back redacted must not overwrite the +// stored credentials. 
+func TestRestoreRedacted_OTELProfilesHeaders(t *testing.T) { + realAuth := "Basic-REAL-SUPER-SECRET-VALUE" + realVersion := "4" + maskedAuth := schemas.NewEnvVar(realAuth).Redacted().GetValue() // long -> first4 + **** + last4 + maskedVersion := schemas.NewEnvVar(realVersion).Redacted().GetValue() // "4" -> "*" + + mkConfig := func(auth, version string) map[string]any { + return map[string]any{ + "profiles": []any{ + map[string]any{ + "service_name": "langfuse", + "headers": map[string]any{ + "Authorization": auth, + "x-langfuse-ingestion-version": version, + }, + }, + }, + } + } + + existing := mkConfig(realAuth, realVersion) + incoming := mkConfig(maskedAuth, maskedVersion) // what the UI sends back after a redacted GET + + got := restoreRedactedFromExisting(incoming, existing) + headers := got["profiles"].([]any)[0].(map[string]any)["headers"].(map[string]any) + + if headers["Authorization"] != realAuth { + t.Errorf("Authorization not restored: got %q, want %q", headers["Authorization"], realAuth) + } + if headers["x-langfuse-ingestion-version"] != realVersion { + t.Errorf("version not restored: got %q, want %q", headers["x-langfuse-ingestion-version"], realVersion) + } + + // A genuinely changed (non-redacted) header value must pass through untouched. + changed := mkConfig("Basic-A-BRAND-NEW-KEY-VALUE-1234", "3") + got2 := restoreRedactedFromExisting(changed, existing) + headers2 := got2["profiles"].([]any)[0].(map[string]any)["headers"].(map[string]any) + if headers2["Authorization"] != "Basic-A-BRAND-NEW-KEY-VALUE-1234" { + t.Errorf("new Authorization should pass through, got %q", headers2["Authorization"]) + } + if headers2["x-langfuse-ingestion-version"] != "3" { + t.Errorf("new version should pass through, got %q", headers2["x-langfuse-ingestion-version"]) + } + + // An intentional env.* reference (e.g. credential rotation) must pass through. + // NewEnvVar parses the "env." 
prefix as FromEnv=true, which IsRedacted reports as + // redacted; the IsFromEnv guard must let it through rather than restoring the stored value. + rotated := mkConfig("env.NEW_TOKEN", "env.NEW_VERSION") + got3 := restoreRedactedFromExisting(rotated, existing) + headers3 := got3["profiles"].([]any)[0].(map[string]any)["headers"].(map[string]any) + if headers3["Authorization"] != "env.NEW_TOKEN" { + t.Errorf("env.* Authorization should pass through, got %q", headers3["Authorization"]) + } + if headers3["x-langfuse-ingestion-version"] != "env.NEW_VERSION" { + t.Errorf("env.* version should pass through, got %q", headers3["x-langfuse-ingestion-version"]) + } +} + func TestUpdatePlugin_ConfigMerge(t *testing.T) { SetLogger(&mockLogger{}) diff --git a/transports/bifrost-http/lib/config_test.go b/transports/bifrost-http/lib/config_test.go index eccb201cd9..5624a6ea7b 100644 --- a/transports/bifrost-http/lib/config_test.go +++ b/transports/bifrost-http/lib/config_test.go @@ -425,6 +425,9 @@ func NewMockConfigStore() *MockConfigStore { func (m *MockConfigStore) RefreshConnectionPool(ctx context.Context) error { return nil } +func (m *MockConfigStore) GetOAuth2SigningKey(ctx context.Context) (*tables.OAuth2SigningKey, error) { + return &tables.OAuth2SigningKey{}, nil +} func (m *MockConfigStore) Ping(ctx context.Context) error { return nil } func (m *MockConfigStore) EncryptPlaintextRows(ctx context.Context) error { return nil } func (m *MockConfigStore) Close(ctx context.Context) error { return nil } @@ -14447,7 +14450,7 @@ func TestGenerateProviderHash_RuntimeVsMigrationParity(t *testing.T) { t.Run("NetworkConfig_GORMRoundTrip", func(t *testing.T) { networkConfig := &schemas.NetworkConfig{ BaseURL: "https://api.custom.com", - DefaultRequestTimeoutInSeconds: 30, + DefaultRequestTimeoutInSeconds: 300, } providerToSave := tables.TableProvider{ diff --git a/transports/bifrost-http/server/server.go b/transports/bifrost-http/server/server.go index 223b9fddec..7f61fce0ed 100644 
--- a/transports/bifrost-http/server/server.go +++ b/transports/bifrost-http/server/server.go @@ -1394,6 +1394,7 @@ func (s *BifrostHTTPServer) RegisterAPIRoutes(ctx context.Context, callbacks Ser promptsHandler := handlers.NewPromptsHandler(s.Config.ConfigStore, promptsReloader) featureFlagsHandler := handlers.NewFeatureFlagsHandler(s.Config.FeatureFlags, s.Config.ConfigStore) // Going ahead with API handlers + handlers.NewOAuth2DiscoveryHandler(s.Config).RegisterRoutes(s.Router, middlewares...) healthHandler.RegisterRoutes(s.Router, middlewares...) providerHandler.RegisterRoutes(s.Router, middlewares...) mcpHandler.RegisterRoutes(s.Router, middlewares...) diff --git a/transports/config.schema.json b/transports/config.schema.json index c3262290e9..f250c9612d 100644 --- a/transports/config.schema.json +++ b/transports/config.schema.json @@ -272,6 +272,45 @@ } ], "description": "Public base URL Bifrost uses as the redirect_uri when acting as an OAuth client to upstream MCP servers (Notion, Jira, etc.). Set when Bifrost's callback endpoint is reached via a different URL than its server-side metadata. Supports env var syntax: \"env.MY_VAR\"." + }, + "mcp_server_auth_mode": { + "type": "string", + "enum": ["headers", "both", "oauth"], + "description": "How /mcp authenticates inbound MCP clients. 'headers' (default): VK/api-key/session headers only, discovery disabled. 'both': accepts header credentials and Bifrost-issued JWTs, discovery enabled. 'oauth': Bifrost JWTs only — WARNING: disables VK/header MCP access." + }, + "oauth2_server_config": { + "type": "object", + "properties": { + "issuer_url": { + "anyOf": [ + { "type": "string" }, + { + "type": "object", + "properties": { + "value": { "type": "string" }, + "env_var": { "type": "string" }, + "from_env": { "type": "boolean" } + }, + "additionalProperties": false + } + ], + "description": "Stable public URL advertised as the OAuth2 AS issuer in discovery documents and JWT iss claim. 
Required for multi-host deployments; single-host deployments can omit this (falls back to request Host header). Supports env var syntax: \"env.MY_VAR\"." + }, + "auth_code_ttl": { + "type": "integer", + "minimum": 1, + "default": 600, + "description": "Lifetime of the single-use authorization code in seconds (default: 600)." + }, + "access_token_ttl": { + "type": "integer", + "minimum": 1, + "default": 600, + "description": "Lifetime of the issued JWT Bearer token in seconds (default: 600)." + } + }, + "additionalProperties": false, + "description": "OAuth2 authorization server settings for /mcp. Only relevant when mcp_server_auth_mode is 'both' or 'oauth'." } }, "additionalProperties": false @@ -2935,7 +2974,7 @@ "type": "integer", "minimum": 5, "maximum": 3600, - "description": "Idle timeout per stream chunk in seconds. If no data is received for this many seconds, the stream is closed. Default: 60." + "description": "Idle timeout per stream chunk in seconds. If no data is received for this many seconds, the stream is closed. Default: 120." }, "max_conns_per_host": { "type": "integer", @@ -3004,7 +3043,7 @@ "type": "integer", "minimum": 5, "maximum": 3600, - "description": "Idle timeout per stream chunk in seconds. If no data is received for this many seconds, the stream is closed. Default: 60." + "description": "Idle timeout per stream chunk in seconds. If no data is received for this many seconds, the stream is closed. Default: 120." 
}, "max_conns_per_host": { "type": "integer", diff --git a/ui/app/workspace/providers/dialogs/addNewCustomProviderSheet.tsx b/ui/app/workspace/providers/dialogs/addNewCustomProviderSheet.tsx index a9a5668e0a..4f4b669804 100644 --- a/ui/app/workspace/providers/dialogs/addNewCustomProviderSheet.tsx +++ b/ui/app/workspace/providers/dialogs/addNewCustomProviderSheet.tsx @@ -4,6 +4,7 @@ import { Input } from "@/components/ui/input"; import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select"; import { Sheet, SheetContent, SheetDescription, SheetHeader, SheetTitle } from "@/components/ui/sheet"; import { Switch } from "@/components/ui/switch"; +import { DefaultNetworkConfig } from "@/lib/constants/config"; import { getErrorMessage, useCreateProviderMutation } from "@/lib/store"; import { BaseProvider, ModelProviderName } from "@/lib/types/config"; import { allowedRequestsSchema } from "@/lib/types/schemas"; @@ -102,7 +103,7 @@ export function AddCustomProviderSheetContent({ show = true, onClose, onSave }: network_config: { base_url: data.base_url, allow_private_network: data.allow_private_network ?? 
false, - default_request_timeout_in_seconds: 30, + default_request_timeout_in_seconds: DefaultNetworkConfig.default_request_timeout_in_seconds, max_retries: 0, retry_backoff_initial: 500, retry_backoff_max: 5000, diff --git a/ui/app/workspace/providers/fragments/networkFormFragment.tsx b/ui/app/workspace/providers/fragments/networkFormFragment.tsx index 7f94e846cd..5c4849496a 100644 --- a/ui/app/workspace/providers/fragments/networkFormFragment.tsx +++ b/ui/app/workspace/providers/fragments/networkFormFragment.tsx @@ -104,7 +104,8 @@ export function NetworkFormFragment({ provider }: NetworkFormFragmentProps) { ...provider.network_config, base_url: data.network_config?.base_url || undefined, extra_headers: data.network_config?.extra_headers || undefined, - default_request_timeout_in_seconds: data.network_config?.default_request_timeout_in_seconds ?? 30, + default_request_timeout_in_seconds: + data.network_config?.default_request_timeout_in_seconds ?? DefaultNetworkConfig.default_request_timeout_in_seconds, max_retries: data.network_config?.max_retries ?? 0, retry_backoff_initial: data.network_config?.retry_backoff_initial ?? 500, retry_backoff_max: data.network_config?.retry_backoff_max ?? 10000, diff --git a/ui/lib/constants/config.ts b/ui/lib/constants/config.ts index 6db7013bc2..30bfeed17d 100644 --- a/ui/lib/constants/config.ts +++ b/ui/lib/constants/config.ts @@ -83,13 +83,13 @@ export const isKeyRequiredByProvider: Record = { export const DefaultNetworkConfig = { base_url: "", - default_request_timeout_in_seconds: 30, + default_request_timeout_in_seconds: 300, max_retries: 0, retry_backoff_initial: 1000, retry_backoff_max: 10000, insecure_skip_verify: false, ca_cert_pem: { value: "", env_var: "", from_env: false }, - stream_idle_timeout_in_seconds: 60, + stream_idle_timeout_in_seconds: 120, max_conns_per_host: 5000, enforce_http2: false, allow_private_network: false,