Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/agent/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ const OVERFLOW_KEEP_ROUNDS = 3;
*/
export class Agent {
private readonly model: string;
private readonly modelProvider?: string;
private readonly maxIterations: number;
private readonly tools: StructuredToolInterface[];
private readonly toolMap: Map<string, StructuredToolInterface>;
Expand All @@ -53,6 +54,7 @@ export class Agent {
concurrencyMap: Map<string, boolean>,
) {
this.model = config.model ?? DEFAULT_MODEL;
this.modelProvider = config.modelProvider;
this.maxIterations = config.maxIterations ?? DEFAULT_MAX_ITERATIONS;
this.tools = tools;
this.toolMap = new Map(tools.map(t => [t.name, t]));
Expand Down Expand Up @@ -160,7 +162,7 @@ export class Agent {
}

const totalTime = Date.now() - ctx.startTime;
const provider = resolveProvider(this.model).displayName;
const provider = resolveProvider(this.model, this.modelProvider).displayName;
yield {
type: 'done',
answer: `Error: ${formatUserFacingError(errorMessage, provider)}`,
Expand Down Expand Up @@ -300,6 +302,7 @@ export class Agent {

for await (const chunk of streamLlmWithMessages(messages, {
model: this.model,
modelProvider: this.modelProvider,
tools: this.tools,
signal: this.signal,
})) {
Expand Down Expand Up @@ -345,6 +348,7 @@ export class Agent {
): Promise<{ response: AIMessage; usage?: TokenUsage }> {
const result = await callLlmWithMessages(messages, {
model: this.model,
modelProvider: this.modelProvider,
tools: this.tools,
signal: this.signal,
});
Expand Down Expand Up @@ -541,7 +545,7 @@ export class Agent {
: estimateTokens(messageState.messages.map(m =>
typeof m.content === 'string' ? m.content : JSON.stringify(m.content),
).join('\n'));
const threshold = getAutoCompactThreshold(this.model);
const threshold = getAutoCompactThreshold(this.model, this.modelProvider);

if (estimatedContextTokens <= threshold) {
return;
Expand All @@ -560,6 +564,7 @@ export class Agent {
yield { type: 'memory_flush', phase: 'start' };
const flushResult = await runMemoryFlush({
model: this.model,
modelProvider: this.modelProvider,
systemPrompt: this.systemPrompt,
query,
toolResults: fullToolResults,
Expand All @@ -583,6 +588,7 @@ export class Agent {
try {
const result = await compactContext({
model: this.model,
modelProvider: this.modelProvider,
systemPrompt: this.systemPrompt,
query,
toolResults: fullToolResults,
Expand Down Expand Up @@ -611,7 +617,7 @@ export class Agent {
success: true,
preCompactTokens: estimatedContextTokens,
postCompactTokens,
compactionModel: resolveProvider(this.model).fastModel ?? this.model,
compactionModel: resolveProvider(this.model, this.modelProvider).fastModel ?? this.model,
};

return;
Expand Down
7 changes: 5 additions & 2 deletions src/agent/compact.ts
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ Continue working toward answering the query without asking the user any further
export interface CompactContextParams {
/** Main model name (used to resolve provider and fast model). */
model: string;
/** Explicit provider override from agent settings. */
modelProvider?: string;
/** System prompt for the compaction call. */
systemPrompt: string;
/** Original user query. */
Expand All @@ -192,10 +194,10 @@ export interface CompactResult {
* Throws on failure — caller is responsible for fallback to clearing.
*/
export async function compactContext(params: CompactContextParams): Promise<CompactResult> {
const { model, systemPrompt, query, toolResults, signal } = params;
const { model, modelProvider, systemPrompt, query, toolResults, signal } = params;

// Resolve fast model for the current provider
const provider = resolveProvider(model);
const provider = resolveProvider(model, modelProvider);
const fastModel = provider.fastModel ?? model;

// Build the compaction prompt
Expand All @@ -204,6 +206,7 @@ export async function compactContext(params: CompactContextParams): Promise<Comp
// Call LLM with no tools bound — callLlm returns string in this case
const result = await callLlm(prompt, {
model: fastModel,
modelProvider: provider.id,
systemPrompt,
signal,
});
Expand Down
2 changes: 2 additions & 0 deletions src/memory/flush.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ export function shouldRunMemoryFlush(params: {

export async function runMemoryFlush(params: {
model: string;
modelProvider?: string;
systemPrompt: string;
query: string;
toolResults: string;
Expand All @@ -55,6 +56,7 @@ ${MEMORY_FLUSH_PROMPT}

const result = await callLlm(prompt, {
model: params.model,
modelProvider: params.modelProvider,
systemPrompt: params.systemPrompt,
signal: params.signal,
});
Expand Down
25 changes: 14 additions & 11 deletions src/model/llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,18 @@ const DEFAULT_FACTORY: ModelFactory = (name, opts) =>

export function getChatModel(
modelName: string = DEFAULT_MODEL,
streaming: boolean = false
streaming: boolean = false,
providerOverride?: string,
): BaseChatModel {
const opts: ModelOpts = { streaming };
const provider = resolveProvider(modelName);
const provider = resolveProvider(modelName, providerOverride);
const factory = MODEL_FACTORIES[provider.id] ?? DEFAULT_FACTORY;
return factory(modelName, opts);
}

interface CallLlmOptions {
model?: string;
modelProvider?: string;
systemPrompt?: string;
outputSchema?: z.ZodType<unknown>;
tools?: StructuredToolInterface[];
Expand Down Expand Up @@ -213,10 +215,10 @@ function buildAnthropicMessages(systemPrompt: string, userPrompt: string) {
}

export async function callLlm(prompt: string, options: CallLlmOptions = {}): Promise<LlmResult> {
const { model = DEFAULT_MODEL, systemPrompt, outputSchema, tools, signal } = options;
const { model = DEFAULT_MODEL, modelProvider, systemPrompt, outputSchema, tools, signal } = options;
const finalSystemPrompt = systemPrompt || DEFAULT_SYSTEM_PROMPT;

const llm = getChatModel(model, false);
const llm = getChatModel(model, false, modelProvider);

// eslint-disable-next-line @typescript-eslint/no-explicit-any
let runnable: Runnable<any, any> = llm;
Expand All @@ -228,7 +230,7 @@ export async function callLlm(prompt: string, options: CallLlmOptions = {}): Pro
}

const invokeOpts = signal ? { signal } : undefined;
const provider = resolveProvider(model);
const provider = resolveProvider(model, modelProvider);
let result;

if (provider.id === 'anthropic') {
Expand Down Expand Up @@ -287,6 +289,7 @@ function annotateSystemMessageForCaching(messages: BaseMessage[]): BaseMessage[]

/**
 * Options accepted by the message-based LLM helpers
 * (`callLlmWithMessages` and `streamLlmWithMessages`).
 * Every field is optional; the helpers default `options` to `{}`.
 */
interface CallLlmWithMessagesOptions {
/** Model name; the helpers substitute DEFAULT_MODEL when this is omitted. */
model?: string;
/**
 * Explicit provider id. Forwarded to getChatModel and resolveProvider as the
 * provider override, so it takes part in both model construction and the
 * provider-specific message handling.
 */
modelProvider?: string;
/** Tools offered to the model. NOTE(review): the binding code is outside this excerpt — confirm how an empty array is treated. */
tools?: StructuredToolInterface[];
/** Abort signal; when present it is handed to the call as its invoke options. */
signal?: AbortSignal;
}
Expand All @@ -306,9 +309,9 @@ export async function callLlmWithMessages(
messages: BaseMessage[],
options: CallLlmWithMessagesOptions = {},
): Promise<LlmResult> {
const { model = DEFAULT_MODEL, tools, signal } = options;
const { model = DEFAULT_MODEL, modelProvider, tools, signal } = options;

const llm = getChatModel(model, false);
const llm = getChatModel(model, false, modelProvider);

// eslint-disable-next-line @typescript-eslint/no-explicit-any
let runnable: Runnable<any, any> = llm;
Expand All @@ -318,7 +321,7 @@ export async function callLlmWithMessages(
}

const invokeOpts = signal ? { signal } : undefined;
const provider = resolveProvider(model);
const provider = resolveProvider(model, modelProvider);

// For Anthropic: annotate SystemMessage with cache_control for prompt caching
const finalMessages = provider.id === 'anthropic'
Expand Down Expand Up @@ -349,9 +352,9 @@ export async function* streamLlmWithMessages(
messages: BaseMessage[],
options: CallLlmWithMessagesOptions = {},
): AsyncGenerator<AIMessageChunk, void> {
const { model = DEFAULT_MODEL, tools, signal } = options;
const { model = DEFAULT_MODEL, modelProvider, tools, signal } = options;

const llm = getChatModel(model, true);
const llm = getChatModel(model, true, modelProvider);

// eslint-disable-next-line @typescript-eslint/no-explicit-any
let runnable: Runnable<any, any> = llm;
Expand All @@ -361,7 +364,7 @@ export async function* streamLlmWithMessages(
}

const invokeOpts = signal ? { signal } : undefined;
const provider = resolveProvider(model);
const provider = resolveProvider(model, modelProvider);

const finalMessages = provider.id === 'anthropic'
? annotateSystemMessageForCaching(messages)
Expand Down
22 changes: 22 additions & 0 deletions src/providers.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import { describe, expect, test } from 'bun:test';
import { resolveProvider } from './providers.js';

// Provider-routing checks for resolveProvider: a recognised explicit override
// wins, a missing override falls back to the model-name prefix, and an
// unrecognised override is ignored in favour of the prefix.
describe('resolveProvider', () => {
  test('uses explicit provider before model prefix routing', () => {
    // The model name carries a DeepSeek prefix, but the override names OpenAI.
    const { id } = resolveProvider('deepseek-v4-flash', 'openai');

    expect(id).toBe('openai');
  });

  test('falls back to model prefix routing without override', () => {
    // No second argument: routing is decided by the model-name prefix alone.
    const resolved = resolveProvider('deepseek-v4-flash');

    expect(resolved.id).toBe('deepseek');
  });

  test('ignores unknown provider override', () => {
    // An override that matches no provider must not change the prefix-based result.
    expect(resolveProvider('claude-sonnet-4-5', 'unknown-provider').id).toBe('anthropic');
  });
});
12 changes: 9 additions & 3 deletions src/providers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,16 @@ export const PROVIDERS: ProviderDef[] = [
const defaultProvider = PROVIDERS.find((p) => p.id === 'openai')!;

/**
* Resolve the provider for a given model name based on its prefix.
* Falls back to OpenAI when no prefix matches.
* Resolve the provider for a given model name.
* Explicit provider settings take precedence over model-name prefix routing.
* Falls back to OpenAI when no prefix matches or an override is unknown.
*/
export function resolveProvider(modelName: string): ProviderDef {
export function resolveProvider(modelName: string, providerOverride?: string): ProviderDef {
if (providerOverride) {
const provider = getProviderById(providerOverride);
if (provider) return provider;
}

return (
PROVIDERS.find((p) => p.modelPrefix && modelName.startsWith(p.modelPrefix)) ??
defaultProvider
Expand Down
8 changes: 4 additions & 4 deletions src/utils/tokens.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ const DEFAULT_CONTEXT_WINDOW = 128_000;
* Get the effective context window size for a model, accounting for
* reserved output tokens.
*/
export function getEffectiveContextWindow(model: string): number {
const provider = resolveProvider(model);
export function getEffectiveContextWindow(model: string, modelProvider?: string): number {
const provider = resolveProvider(model, modelProvider);
const contextWindow = provider.contextWindow ?? DEFAULT_CONTEXT_WINDOW;
return contextWindow - MAX_OUTPUT_TOKENS_FOR_SUMMARY;
}
Expand All @@ -47,8 +47,8 @@ export function getEffectiveContextWindow(model: string): number {
* This is the token count at which compaction should trigger.
* Formula: effectiveWindow - 13K buffer.
*/
export function getAutoCompactThreshold(model: string): number {
return getEffectiveContextWindow(model) - AUTOCOMPACT_BUFFER_TOKENS;
export function getAutoCompactThreshold(model: string, modelProvider?: string): number {
return getEffectiveContextWindow(model, modelProvider) - AUTOCOMPACT_BUFFER_TOKENS;
}

// ---------------------------------------------------------------------------
Expand Down
Loading