Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/helm-release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ jobs:
cd helm-charts
gh release create "helm-chart-v${{ steps.chart-version.outputs.version }}" \
bifrost-${{ steps.chart-version.outputs.version }}.tgz \
--target ${{ github.sha }} \
--title "Helm Chart v${{ steps.chart-version.outputs.version }}" \
--notes "Helm chart release for Bifrost v${{ steps.chart-version.outputs.version }}"
env:
Expand Down
17 changes: 14 additions & 3 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -514,9 +514,20 @@ func (bifrost *Bifrost) ListAllModels(ctx *schemas.BifrostContext, req *schemas.

response, bifrostErr := bifrost.ListModelsRequest(providerCtx, providerRequest)
if bifrostErr != nil {
// Skip logging "no keys found" and "not supported" errors as they are expected when a provider is not configured
if !strings.Contains(bifrostErr.Error.Message, "no keys found") &&
!strings.Contains(bifrostErr.Error.Message, "not supported") {
// Some per-provider failures are expected when fanning out across all
// configured providers and must not be surfaced as a top-level error
errType := ""
if bifrostErr.Type != nil {
errType = *bifrostErr.Type
}
errMsg := ""
if bifrostErr.Error != nil {
errMsg = bifrostErr.Error.Message
}
isExpected := strings.Contains(errMsg, "no keys found") ||
strings.Contains(errMsg, "not supported") ||
errType == "provider_blocked"
if !isExpected {
providerErr = bifrostErr
bifrost.logger.Warn("failed to list models for provider %s: %s", providerKey, bifrostErr.GetErrorString())
}
Expand Down
2 changes: 1 addition & 1 deletion core/bifrost_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ func (ma *MockAccount) AddProviderWithBaseURL(provider schemas.ModelProvider, co
ma.configs[provider] = &schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
BaseURL: baseURL,
DefaultRequestTimeoutInSeconds: 30,
DefaultRequestTimeoutInSeconds: 300,
MaxRetries: 3,
RetryBackoffInitial: 500 * time.Millisecond,
RetryBackoffMax: 5 * time.Second,
Expand Down
16 changes: 8 additions & 8 deletions core/providers/bedrock/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ func generateTestCACert(t *testing.T) string {
func TestBedrockTransportHTTP2Config(t *testing.T) {
config := &schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
DefaultRequestTimeoutInSeconds: 30,
DefaultRequestTimeoutInSeconds: 300,
MaxConnsPerHost: 5000,
EnforceHTTP2: true,
},
Expand All @@ -570,7 +570,7 @@ func TestBedrockTransportHTTP2Config(t *testing.T) {
func TestBedrockTransportCustomMaxConns(t *testing.T) {
config := &schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
DefaultRequestTimeoutInSeconds: 30,
DefaultRequestTimeoutInSeconds: 300,
MaxConnsPerHost: 50,
},
}
Expand All @@ -590,7 +590,7 @@ func TestBedrockTransportCustomMaxConns(t *testing.T) {
func TestBedrockTransportDefaultMaxConns(t *testing.T) {
config := &schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
DefaultRequestTimeoutInSeconds: 30,
DefaultRequestTimeoutInSeconds: 300,
// MaxConnsPerHost left as 0 — should default to 5000
},
}
Expand All @@ -612,7 +612,7 @@ func TestBedrockTransportDefaultMaxConns(t *testing.T) {
func TestBedrockTransportTLSInsecureSkipVerify(t *testing.T) {
config := &schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
DefaultRequestTimeoutInSeconds: 30,
DefaultRequestTimeoutInSeconds: 300,
InsecureSkipVerify: true,
EnforceHTTP2: true,
},
Expand All @@ -636,7 +636,7 @@ func TestBedrockTransportTLSCACert(t *testing.T) {

config := &schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
DefaultRequestTimeoutInSeconds: 30,
DefaultRequestTimeoutInSeconds: 300,
CACertPEM: schemas.NewEnvVar(testCACert),
EnforceHTTP2: true,
},
Expand All @@ -657,7 +657,7 @@ func TestBedrockTransportTLSCACert(t *testing.T) {
func TestBedrockTransportDefaultTLS(t *testing.T) {
config := &schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
DefaultRequestTimeoutInSeconds: 30,
DefaultRequestTimeoutInSeconds: 300,
// No TLS settings — should use system defaults
},
}
Expand All @@ -677,7 +677,7 @@ func TestBedrockTransportDefaultTLS(t *testing.T) {
func TestBedrockTransportEnforceHTTP2(t *testing.T) {
config := &schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
DefaultRequestTimeoutInSeconds: 30,
DefaultRequestTimeoutInSeconds: 300,
EnforceHTTP2: true,
},
}
Expand All @@ -696,7 +696,7 @@ func TestBedrockTransportEnforceHTTP2(t *testing.T) {
func TestBedrockTransportEnforceHTTP2Disabled(t *testing.T) {
config := &schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
DefaultRequestTimeoutInSeconds: 30,
DefaultRequestTimeoutInSeconds: 300,
EnforceHTTP2: false,
},
}
Expand Down
2 changes: 1 addition & 1 deletion core/providers/fireworks/fireworks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ func newTestFireworksProvider(t *testing.T, baseURL string) *fireworksprovider.F
provider, err := fireworksprovider.NewFireworksProvider(&schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
BaseURL: baseURL,
DefaultRequestTimeoutInSeconds: 30,
DefaultRequestTimeoutInSeconds: 300,
},
}, bifrost.NewNoOpLogger())
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions core/providers/mistral/ocr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ func TestOCRWithMockServer(t *testing.T) {
provider := NewMistralProvider(&schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
BaseURL: server.URL,
DefaultRequestTimeoutInSeconds: 30,
DefaultRequestTimeoutInSeconds: 300,
},
}, &testLogger{})

Expand Down Expand Up @@ -639,7 +639,7 @@ func TestOCRNilInput(t *testing.T) {
provider := NewMistralProvider(&schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
BaseURL: "https://api.mistral.ai",
DefaultRequestTimeoutInSeconds: 30,
DefaultRequestTimeoutInSeconds: 300,
},
}, &testLogger{})

Expand Down Expand Up @@ -686,7 +686,7 @@ func TestOCRRequestValidation(t *testing.T) {
provider := NewMistralProvider(&schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
BaseURL: server.URL,
DefaultRequestTimeoutInSeconds: 30,
DefaultRequestTimeoutInSeconds: 300,
},
}, &testLogger{})

Expand Down
12 changes: 6 additions & 6 deletions core/providers/mistral/transcription_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ func TestTranscriptionWithMockServer(t *testing.T) {
provider := NewMistralProvider(&schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
BaseURL: server.URL,
DefaultRequestTimeoutInSeconds: 30,
DefaultRequestTimeoutInSeconds: 300,
},
}, &testLogger{})

Expand Down Expand Up @@ -637,7 +637,7 @@ func TestTranscriptionNilInput(t *testing.T) {
provider := NewMistralProvider(&schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
BaseURL: "https://api.mistral.ai",
DefaultRequestTimeoutInSeconds: 30,
DefaultRequestTimeoutInSeconds: 300,
},
}, &testLogger{})

Expand Down Expand Up @@ -787,7 +787,7 @@ func TestTranscriptionStreamWithMockServer(t *testing.T) {
provider := NewMistralProvider(&schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
BaseURL: server.URL,
DefaultRequestTimeoutInSeconds: 30,
DefaultRequestTimeoutInSeconds: 300,
},
}, &testLogger{})

Expand Down Expand Up @@ -842,7 +842,7 @@ func TestTranscriptionStreamNilInput(t *testing.T) {
provider := NewMistralProvider(&schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
BaseURL: "https://api.mistral.ai",
DefaultRequestTimeoutInSeconds: 30,
DefaultRequestTimeoutInSeconds: 300,
},
}, &testLogger{})

Expand Down Expand Up @@ -1272,7 +1272,7 @@ func TestTranscriptionStreamEdgeCases(t *testing.T) {
provider := NewMistralProvider(&schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
BaseURL: server.URL,
DefaultRequestTimeoutInSeconds: 30,
DefaultRequestTimeoutInSeconds: 300,
},
}, &testLogger{})

Expand Down Expand Up @@ -1342,7 +1342,7 @@ func TestTranscriptionStreamContextCancellation(t *testing.T) {
provider := NewMistralProvider(&schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
BaseURL: server.URL,
DefaultRequestTimeoutInSeconds: 30,
DefaultRequestTimeoutInSeconds: 300,
},
}, &testLogger{})

Expand Down
57 changes: 57 additions & 0 deletions core/providers/openai/chatgpt_passthrough.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package openai

import (
"encoding/base64"
"strings"

"github.com/bytedance/sonic"
"github.com/maximhq/bifrost/core/schemas"
)

const (
chatGPTAccountIDKey = "chatgpt_account_id"
openAIAuthClaim = "https://api.openai.com/auth"

// ChatGPTCodexURL is the full upstream URL for ChatGPT subscription token requests.
ChatGPTCodexURL = "https://chatgpt.com/backend-api/codex/responses"
)

// ParseChatGPTJWT parses a raw bearer token, checks for the ChatGPT subscription
// JWT claim, and returns the chatgpt_account_id. No signature verification is
// Returns ("", false) for any non-ChatGPT or malformed token.
func ParseChatGPTJWT(token string) (accountID string, ok bool) {
parts := strings.Split(token, ".")
if len(parts) != 3 {
return "", false
}

payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return "", false
}

// Extract the nested claim: {"https://api.openai.com/auth": {"chatgpt_account_id": "..."}}
var claims map[string]interface{}
if err := sonic.Unmarshal(payload, &claims); err != nil {
return "", false
}

authClaim, ok := claims[openAIAuthClaim].(map[string]interface{})
if !ok {
return "", false
}

accountID, ok = authClaim[chatGPTAccountIDKey].(string)
if !ok || accountID == "" {
return "", false
}

return accountID, true
}

// IsChatGPTPassthrough reports whether the current request was auto-detected
// as a ChatGPT subscription token and should be routed to chatgpt.com.
func IsChatGPTPassthrough(ctx *schemas.BifrostContext) bool {
v, _ := ctx.Value(schemas.BifrostContextKeyChatGPTPassthrough).(bool)
return v
}
90 changes: 90 additions & 0 deletions core/providers/openai/chatgpt_passthrough_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package openai

import (
"encoding/base64"
"fmt"
"testing"
)

// makeTestJWT builds a syntactically valid JWT with arbitrary header/payload JSON.
// The signature segment is a fixed placeholder — ParseChatGPTJWT never verifies it.
func makeTestJWT(payloadJSON string) string {
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","typ":"JWT"}`))
payload := base64.RawURLEncoding.EncodeToString([]byte(payloadJSON))
return fmt.Sprintf("%s.%s.fakesig", header, payload)
}

func TestParseChatGPTJWT(t *testing.T) {
validAccountID := "9dce4683-94cd-4aeb-ade4-4ecce82ebac5"

tests := []struct {
name string
token string
wantID string
wantOK bool
}{
{
name: "valid ChatGPT JWT returns account ID",
token: makeTestJWT(fmt.Sprintf(
`{"aud":["https://api.openai.com/v1"],"https://api.openai.com/auth":{"chatgpt_account_id":%q}}`,
validAccountID,
)),
wantID: validAccountID,
wantOK: true,
},
{
name: "JWT missing chatgpt_account_id claim returns false",
token: makeTestJWT(`{"aud":["https://api.openai.com/v1"],"sub":"user-abc"}`),
wantID: "",
wantOK: false,
},
{
name: "JWT with https://api.openai.com/auth but no chatgpt_account_id returns false",
token: makeTestJWT(`{"https://api.openai.com/auth":{"other_field":"value"}}`),
wantID: "",
wantOK: false,
},
{
name: "not a JWT (sk- API key) returns false",
token: "sk-proj-abcdefghijklmnopqrstuvwxyz",
wantID: "",
wantOK: false,
},
{
name: "empty string returns false",
token: "",
wantID: "",
wantOK: false,
},
{
name: "only two segments returns false",
token: "header.payload",
wantID: "",
wantOK: false,
},
{
name: "invalid base64 in payload returns false",
token: "header.!!!invalid!!!.sig",
wantID: "",
wantOK: false,
},
{
name: "payload is valid base64 but not JSON returns false",
token: fmt.Sprintf("header.%s.sig", base64.RawURLEncoding.EncodeToString([]byte("not-json"))),
wantID: "",
wantOK: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotID, gotOK := ParseChatGPTJWT(tt.token)
if gotOK != tt.wantOK {
t.Errorf("ParseChatGPTJWT() ok = %v, want %v", gotOK, tt.wantOK)
}
if gotID != tt.wantID {
t.Errorf("ParseChatGPTJWT() accountID = %q, want %q", gotID, tt.wantID)
}
})
}
}
Loading