diff --git a/src/agent/agent.ts b/src/agent/agent.ts index c335024b8..36696a1ea 100644 --- a/src/agent/agent.ts +++ b/src/agent/agent.ts @@ -36,6 +36,7 @@ const OVERFLOW_KEEP_ROUNDS = 3; */ export class Agent { private readonly model: string; + private readonly modelProvider?: string; private readonly maxIterations: number; private readonly tools: StructuredToolInterface[]; private readonly toolMap: Map; @@ -53,6 +54,7 @@ export class Agent { concurrencyMap: Map, ) { this.model = config.model ?? DEFAULT_MODEL; + this.modelProvider = config.modelProvider; this.maxIterations = config.maxIterations ?? DEFAULT_MAX_ITERATIONS; this.tools = tools; this.toolMap = new Map(tools.map(t => [t.name, t])); @@ -160,7 +162,7 @@ export class Agent { } const totalTime = Date.now() - ctx.startTime; - const provider = resolveProvider(this.model).displayName; + const provider = resolveProvider(this.model, this.modelProvider).displayName; yield { type: 'done', answer: `Error: ${formatUserFacingError(errorMessage, provider)}`, @@ -300,6 +302,7 @@ export class Agent { for await (const chunk of streamLlmWithMessages(messages, { model: this.model, + modelProvider: this.modelProvider, tools: this.tools, signal: this.signal, })) { @@ -345,6 +348,7 @@ export class Agent { ): Promise<{ response: AIMessage; usage?: TokenUsage }> { const result = await callLlmWithMessages(messages, { model: this.model, + modelProvider: this.modelProvider, tools: this.tools, signal: this.signal, }); @@ -541,7 +545,7 @@ export class Agent { : estimateTokens(messageState.messages.map(m => typeof m.content === 'string' ? m.content : JSON.stringify(m.content), ).join('\n')); - const threshold = getAutoCompactThreshold(this.model); + const threshold = getAutoCompactThreshold(this.model, this.modelProvider); if (estimatedContextTokens <= threshold) { return; @@ -560,6 +564,7 @@ export class Agent { yield { type: 'memory_flush', phase: 'start' }; const flushResult = await runMemoryFlush({ model: this.model, + modelProvider: this.modelProvider, systemPrompt: this.systemPrompt, query, toolResults: fullToolResults, @@ -583,6 +588,7 @@ export class Agent { try { const result = await compactContext({ model: this.model, + modelProvider: this.modelProvider, systemPrompt: this.systemPrompt, query, toolResults: fullToolResults, @@ -611,7 +617,7 @@ export class Agent { success: true, preCompactTokens: estimatedContextTokens, postCompactTokens, - compactionModel: resolveProvider(this.model).fastModel ?? this.model, + compactionModel: resolveProvider(this.model, this.modelProvider).fastModel ?? this.model, }; return; diff --git a/src/agent/compact.ts b/src/agent/compact.ts index 4fc1b8669..5125b8aa7 100644 --- a/src/agent/compact.ts +++ b/src/agent/compact.ts @@ -168,6 +168,8 @@ Continue working toward answering the query without asking the user any further export interface CompactContextParams { /** Main model name (used to resolve provider and fast model). */ model: string; + /** Explicit provider override from agent settings. */ + modelProvider?: string; /** System prompt for the compaction call. */ systemPrompt: string; /** Original user query. */ @@ -192,10 +194,10 @@ export interface CompactResult { * Throws on failure — caller is responsible for fallback to clearing. */ export async function compactContext(params: CompactContextParams): Promise { - const { model, systemPrompt, query, toolResults, signal } = params; + const { model, modelProvider, systemPrompt, query, toolResults, signal } = params; // Resolve fast model for the current provider - const provider = resolveProvider(model); + const provider = resolveProvider(model, modelProvider); const fastModel = provider.fastModel ?? model; // Build the compaction prompt @@ -204,6 +206,7 @@ export async function compactContext(params: CompactContextParams): Promise export function getChatModel( modelName: string = DEFAULT_MODEL, - streaming: boolean = false + streaming: boolean = false, + providerOverride?: string, ): BaseChatModel { const opts: ModelOpts = { streaming }; - const provider = resolveProvider(modelName); + const provider = resolveProvider(modelName, providerOverride); const factory = MODEL_FACTORIES[provider.id] ?? DEFAULT_FACTORY; return factory(modelName, opts); } interface CallLlmOptions { model?: string; + modelProvider?: string; systemPrompt?: string; outputSchema?: z.ZodType; tools?: StructuredToolInterface[]; @@ -213,10 +215,10 @@ function buildAnthropicMessages(systemPrompt: string, userPrompt: string) { } export async function callLlm(prompt: string, options: CallLlmOptions = {}): Promise { - const { model = DEFAULT_MODEL, systemPrompt, outputSchema, tools, signal } = options; + const { model = DEFAULT_MODEL, modelProvider, systemPrompt, outputSchema, tools, signal } = options; const finalSystemPrompt = systemPrompt || DEFAULT_SYSTEM_PROMPT; - const llm = getChatModel(model, false); + const llm = getChatModel(model, false, modelProvider); // eslint-disable-next-line @typescript-eslint/no-explicit-any let runnable: Runnable = llm; @@ -228,7 +230,7 @@ export async function callLlm(prompt: string, options: CallLlmOptions = {}): Pro } const invokeOpts = signal ? { signal } : undefined; - const provider = resolveProvider(model); + const provider = resolveProvider(model, modelProvider); let result; if (provider.id === 'anthropic') { @@ -287,6 +289,7 @@ function annotateSystemMessageForCaching(messages: BaseMessage[]): BaseMessage[] interface CallLlmWithMessagesOptions { model?: string; + modelProvider?: string; tools?: StructuredToolInterface[]; signal?: AbortSignal; } @@ -306,9 +309,9 @@ export async function callLlmWithMessages( messages: BaseMessage[], options: CallLlmWithMessagesOptions = {}, ): Promise { - const { model = DEFAULT_MODEL, tools, signal } = options; + const { model = DEFAULT_MODEL, modelProvider, tools, signal } = options; - const llm = getChatModel(model, false); + const llm = getChatModel(model, false, modelProvider); // eslint-disable-next-line @typescript-eslint/no-explicit-any let runnable: Runnable = llm; @@ -318,7 +321,7 @@ export async function callLlmWithMessages( } const invokeOpts = signal ? { signal } : undefined; - const provider = resolveProvider(model); + const provider = resolveProvider(model, modelProvider); // For Anthropic: annotate SystemMessage with cache_control for prompt caching const finalMessages = provider.id === 'anthropic' @@ -349,9 +352,9 @@ export async function* streamLlmWithMessages( messages: BaseMessage[], options: CallLlmWithMessagesOptions = {}, ): AsyncGenerator { - const { model = DEFAULT_MODEL, tools, signal } = options; + const { model = DEFAULT_MODEL, modelProvider, tools, signal } = options; - const llm = getChatModel(model, true); + const llm = getChatModel(model, true, modelProvider); // eslint-disable-next-line @typescript-eslint/no-explicit-any let runnable: Runnable = llm; @@ -361,7 +364,7 @@ export async function* streamLlmWithMessages( } const invokeOpts = signal ? { signal } : undefined; - const provider = resolveProvider(model); + const provider = resolveProvider(model, modelProvider); const finalMessages = provider.id === 'anthropic' ? annotateSystemMessageForCaching(messages) diff --git a/src/providers.test.ts b/src/providers.test.ts new file mode 100644 index 000000000..45aa3e8d6 --- /dev/null +++ b/src/providers.test.ts @@ -0,0 +1,22 @@ +import { describe, expect, test } from 'bun:test'; +import { resolveProvider } from './providers.js'; + +describe('resolveProvider', () => { + test('uses explicit provider before model prefix routing', () => { + const provider = resolveProvider('deepseek-v4-flash', 'openai'); + + expect(provider.id).toBe('openai'); + }); + + test('falls back to model prefix routing without override', () => { + const provider = resolveProvider('deepseek-v4-flash'); + + expect(provider.id).toBe('deepseek'); + }); + + test('ignores unknown provider override', () => { + const provider = resolveProvider('claude-sonnet-4-5', 'unknown-provider'); + + expect(provider.id).toBe('anthropic'); + }); +}); diff --git a/src/providers.ts b/src/providers.ts index 0fd945802..165112a1c 100644 --- a/src/providers.ts +++ b/src/providers.ts @@ -86,10 +86,16 @@ export const PROVIDERS: ProviderDef[] = [ const defaultProvider = PROVIDERS.find((p) => p.id === 'openai')!; /** - * Resolve the provider for a given model name based on its prefix. - * Falls back to OpenAI when no prefix matches. + * Resolve the provider for a given model name. + * Explicit provider settings take precedence over model-name prefix routing. + * Falls back to OpenAI when no prefix matches or an override is unknown. */ -export function resolveProvider(modelName: string): ProviderDef { +export function resolveProvider(modelName: string, providerOverride?: string): ProviderDef { + if (providerOverride) { + const provider = getProviderById(providerOverride); + if (provider) return provider; + } + return ( PROVIDERS.find((p) => p.modelPrefix && modelName.startsWith(p.modelPrefix)) ?? defaultProvider diff --git a/src/utils/tokens.ts b/src/utils/tokens.ts index 247d27bd6..4e0c5fa39 100644 --- a/src/utils/tokens.ts +++ b/src/utils/tokens.ts @@ -36,8 +36,8 @@ const DEFAULT_CONTEXT_WINDOW = 128_000; * Get the effective context window size for a model, accounting for * reserved output tokens. */ -export function getEffectiveContextWindow(model: string): number { - const provider = resolveProvider(model); +export function getEffectiveContextWindow(model: string, modelProvider?: string): number { + const provider = resolveProvider(model, modelProvider); const contextWindow = provider.contextWindow ?? DEFAULT_CONTEXT_WINDOW; return contextWindow - MAX_OUTPUT_TOKENS_FOR_SUMMARY; } @@ -47,8 +47,8 @@ export function getEffectiveContextWindow(model: string): number { * This is the token count at which compaction should trigger. * Formula: effectiveWindow - 13K buffer. */ -export function getAutoCompactThreshold(model: string): number { - return getEffectiveContextWindow(model) - AUTOCOMPACT_BUFFER_TOKENS; +export function getAutoCompactThreshold(model: string, modelProvider?: string): number { + return getEffectiveContextWindow(model, modelProvider) - AUTOCOMPACT_BUFFER_TOKENS; } // ---------------------------------------------------------------------------