From 688197dc955042dde9e2a2816866b03f47b6aa57 Mon Sep 17 00:00:00 2001 From: zerob13 Date: Thu, 2 Apr 2026 21:12:12 +0800 Subject: [PATCH 1/3] feat(cache): add provider prompt cache planner --- .../deepchatAgentPresenter/accumulator.ts | 1 + .../presenter/deepchatAgentPresenter/index.ts | 3 +- .../deepchatAgentPresenter/process.ts | 3 + .../promptCacheCapabilities.ts | 42 +++ .../promptCacheStrategy.ts | 328 ++++++++++++++++++ .../providers/anthropicProvider.ts | 125 ++++++- .../providers/awsBedrockProvider.ts | 103 +++++- .../providers/openAICompatibleProvider.ts | 69 +++- .../providers/openAIResponsesProvider.ts | 63 +++- src/main/presenter/newAgentPresenter/index.ts | 8 +- .../tables/deepchatUsageStats.ts | 14 +- src/main/presenter/usageStats.ts | 27 +- src/shared/types/agent-interface.d.ts | 1 + src/shared/types/core/llm-events.ts | 2 + .../accumulator.test.ts | 9 +- .../messageStore.test.ts | 7 +- .../anthropicProvider.test.ts | 129 ++++++- .../awsBedrockProvider.test.ts | 207 +++++++++++ .../openAICompatibleProvider.test.ts | 128 +++++++ .../openAIResponsesProvider.test.ts | 49 ++- .../promptCacheStrategy.test.ts | 171 +++++++++ .../newAgentPresenter/usageDashboard.test.ts | 12 +- test/main/presenter/sqlitePresenter.test.ts | 100 ++++++ test/main/presenter/usageStats.test.ts | 107 ++++++ 24 files changed, 1651 insertions(+), 57 deletions(-) create mode 100644 src/main/presenter/llmProviderPresenter/promptCacheCapabilities.ts create mode 100644 src/main/presenter/llmProviderPresenter/promptCacheStrategy.ts create mode 100644 test/main/presenter/llmProviderPresenter/awsBedrockProvider.test.ts create mode 100644 test/main/presenter/llmProviderPresenter/promptCacheStrategy.test.ts create mode 100644 test/main/presenter/usageStats.test.ts diff --git a/src/main/presenter/deepchatAgentPresenter/accumulator.ts b/src/main/presenter/deepchatAgentPresenter/accumulator.ts index c5b2cddda..b91905170 100644 --- a/src/main/presenter/deepchatAgentPresenter/accumulator.ts +++ b/src/main/presenter/deepchatAgentPresenter/accumulator.ts @@ -168,6 +168,7 @@ export function accumulate(state: StreamState, event: LLMCoreStreamEvent): void state.metadata.outputTokens = event.usage.completion_tokens state.metadata.totalTokens = event.usage.total_tokens state.metadata.cachedInputTokens = event.usage.cached_tokens + state.metadata.cacheWriteInputTokens = event.usage.cache_write_tokens break } case 'stop': { diff --git a/src/main/presenter/deepchatAgentPresenter/index.ts b/src/main/presenter/deepchatAgentPresenter/index.ts index 54c69daac..e7668a163 100644 --- a/src/main/presenter/deepchatAgentPresenter/index.ts +++ b/src/main/presenter/deepchatAgentPresenter/index.ts @@ -1405,7 +1405,8 @@ export class DeepChatAgentPresenter implements IAgentImplementation { maxTokens: generationSettings.maxTokens, thinkingBudget: generationSettings.thinkingBudget, reasoningEffort: generationSettings.reasoningEffort, - verbosity: generationSettings.verbosity + verbosity: generationSettings.verbosity, + conversationId: sessionId } const traceEnabled = this.configPresenter.getSetting('traceDebugEnabled') === true diff --git a/src/main/presenter/deepchatAgentPresenter/process.ts b/src/main/presenter/deepchatAgentPresenter/process.ts index 6b341fabd..c075da570 100644 --- a/src/main/presenter/deepchatAgentPresenter/process.ts +++ b/src/main/presenter/deepchatAgentPresenter/process.ts @@ -308,5 +308,8 @@ function buildUsageSnapshot(state: StreamState): Record { if (typeof state.metadata.cachedInputTokens === 'number') { 
usage.cachedInputTokens = state.metadata.cachedInputTokens } + if (typeof state.metadata.cacheWriteInputTokens === 'number') { + usage.cacheWriteInputTokens = state.metadata.cacheWriteInputTokens + } return usage } diff --git a/src/main/presenter/llmProviderPresenter/promptCacheCapabilities.ts b/src/main/presenter/llmProviderPresenter/promptCacheCapabilities.ts new file mode 100644 index 000000000..b06f597cb --- /dev/null +++ b/src/main/presenter/llmProviderPresenter/promptCacheCapabilities.ts @@ -0,0 +1,42 @@ +export type PromptCacheMode = + | 'disabled' + | 'openai_implicit' + | 'anthropic_auto' + | 'anthropic_explicit' + +function normalizeId(value: string | undefined): string { + return value?.trim().toLowerCase() ?? '' +} + +function isClaudeModel(modelId: string): boolean { + return modelId.includes('claude') +} + +export function resolvePromptCacheMode(providerId: string, modelId: string): PromptCacheMode { + const normalizedProviderId = normalizeId(providerId) + const normalizedModelId = normalizeId(modelId) + + if (normalizedProviderId === 'openai') { + return 'openai_implicit' + } + + if (normalizedProviderId === 'anthropic' && isClaudeModel(normalizedModelId)) { + return 'anthropic_auto' + } + + if ( + normalizedProviderId === 'aws-bedrock' && + (normalizedModelId.includes('anthropic.claude') || isClaudeModel(normalizedModelId)) + ) { + return 'anthropic_explicit' + } + + if ( + normalizedProviderId === 'openrouter' && + (normalizedModelId.startsWith('anthropic/') || isClaudeModel(normalizedModelId)) + ) { + return 'anthropic_explicit' + } + + return 'disabled' +} diff --git a/src/main/presenter/llmProviderPresenter/promptCacheStrategy.ts b/src/main/presenter/llmProviderPresenter/promptCacheStrategy.ts new file mode 100644 index 000000000..8de06538c --- /dev/null +++ b/src/main/presenter/llmProviderPresenter/promptCacheStrategy.ts @@ -0,0 +1,328 @@ +import { createHash } from 'crypto' +import Anthropic from '@anthropic-ai/sdk' +import type { ChatCompletionContentPart, ChatCompletionMessageParam } from 'openai/resources' +import type { MCPToolDefinition } from '@shared/presenter' +import { resolvePromptCacheMode, type PromptCacheMode } from './promptCacheCapabilities' + +export type PromptCacheApiType = 'openai_chat' | 'openai_responses' | 'anthropic' +export type PromptCacheTtl = '5m' + +export interface PromptCacheBreakpointPlan { + messageIndex: number + contentIndex: number +} + +export interface PromptCachePlan { + mode: PromptCacheMode + ttl: PromptCacheTtl | null + cacheKey?: string + breakpointPlan?: PromptCacheBreakpointPlan +} + +export interface ResolvePromptCachePlanParams { + providerId: string + apiType: PromptCacheApiType + modelId: string + messages: unknown[] + tools?: MCPToolDefinition[] + conversationId?: string +} + +type EphemeralCacheControl = { type: 'ephemeral' } + +const EPHEMERAL_CACHE_CONTROL: EphemeralCacheControl = { type: 'ephemeral' } + +type AnthropicTextBlockWithCache = Anthropic.TextBlockParam & { + cache_control?: EphemeralCacheControl +} + +function normalizeId(value: string | undefined): string { + return value?.trim().toLowerCase() ?? 
'' +} + +function buildPromptCacheKey( + providerId: string, + modelId: string, + conversationId?: string +): string | undefined { + const normalizedConversationId = conversationId?.trim() + if (!normalizedConversationId) { + return undefined + } + + const digest = createHash('sha256') + .update(`${normalizeId(providerId)}:${normalizeId(modelId)}:${normalizedConversationId}`) + .digest('hex') + .slice(0, 20) + + return `deepchat:${normalizeId(providerId)}:${normalizeId(modelId)}:${digest}` +} + +function findOpenAIChatBreakpoint( + messages: ChatCompletionMessageParam[] +): PromptCacheBreakpointPlan | undefined { + let prefixEnd = messages.length + + while (prefixEnd > 0) { + const role = messages[prefixEnd - 1]?.role + if (role === 'user' || role === 'tool') { + prefixEnd -= 1 + continue + } + break + } + + for (let messageIndex = prefixEnd - 1; messageIndex >= 0; messageIndex -= 1) { + const message = messages[messageIndex] + if (!message || message.role === 'tool') { + continue + } + + const content = 'content' in message ? message.content : undefined + if (typeof content === 'string') { + if (content.trim()) { + return { messageIndex, contentIndex: 0 } + } + continue + } + + if (!Array.isArray(content)) { + continue + } + + for (let contentIndex = content.length - 1; contentIndex >= 0; contentIndex -= 1) { + const part = content[contentIndex] + if (part?.type === 'text' && typeof part.text === 'string' && part.text.trim()) { + return { messageIndex, contentIndex } + } + } + } + + return undefined +} + +function findAnthropicBreakpoint( + messages: Anthropic.MessageParam[] +): PromptCacheBreakpointPlan | undefined { + let prefixEnd = messages.length + + while (prefixEnd > 0) { + const role = messages[prefixEnd - 1]?.role + if (role === 'user') { + prefixEnd -= 1 + continue + } + break + } + + for (let messageIndex = prefixEnd - 1; messageIndex >= 0; messageIndex -= 1) { + const message = messages[messageIndex] + if (!message) { + continue + } + + const content = message.content + if (typeof content === 'string') { + if (content.trim()) { + return { messageIndex, contentIndex: 0 } + } + continue + } + + if (!Array.isArray(content)) { + continue + } + + for (let contentIndex = content.length - 1; contentIndex >= 0; contentIndex -= 1) { + const block = content[contentIndex] + if (block?.type === 'text' && typeof block.text === 'string' && block.text.trim()) { + return { messageIndex, contentIndex } + } + } + } + + return undefined +} + +export function resolvePromptCachePlan(params: ResolvePromptCachePlanParams): PromptCachePlan { + const mode = resolvePromptCacheMode(params.providerId, params.modelId) + + if (mode === 'disabled') { + return { mode, ttl: null } + } + + if (mode === 'openai_implicit') { + return { + mode, + ttl: null, + cacheKey: buildPromptCacheKey(params.providerId, params.modelId, params.conversationId) + } + } + + if (mode === 'anthropic_auto') { + return { + mode, + ttl: '5m' + } + } + + const breakpointPlan = + params.apiType === 'anthropic' + ? 
findAnthropicBreakpoint(params.messages as Anthropic.MessageParam[])
+      : findOpenAIChatBreakpoint(params.messages as ChatCompletionMessageParam[])
+
+  return {
+    mode,
+    ttl: '5m',
+    breakpointPlan
+  }
+}
+
+export function applyOpenAIPromptCacheKey<T extends Record<string, unknown>>(
+  requestParams: T,
+  plan: PromptCachePlan
+): T {
+  if (plan.mode !== 'openai_implicit' || !plan.cacheKey) {
+    return requestParams
+  }
+
+  return {
+    ...requestParams,
+    prompt_cache_key: plan.cacheKey
+  }
+}
+
+export function applyAnthropicTopLevelCacheControl<T extends Record<string, unknown>>(
+  requestParams: T,
+  plan: PromptCachePlan
+): T {
+  if (plan.mode !== 'anthropic_auto') {
+    return requestParams
+  }
+
+  return {
+    ...requestParams,
+    cache_control: EPHEMERAL_CACHE_CONTROL
+  }
+}
+
+export function applyOpenAIChatExplicitCacheBreakpoint(
+  messages: ChatCompletionMessageParam[],
+  plan: PromptCachePlan
+): ChatCompletionMessageParam[] {
+  if (plan.mode !== 'anthropic_explicit' || !plan.breakpointPlan) {
+    return messages
+  }
+
+  const { messageIndex, contentIndex } = plan.breakpointPlan
+  const target = messages[messageIndex]
+
+  if (!target || !('content' in target)) {
+    return messages
+  }
+
+  const content = target.content
+  let nextContent: ChatCompletionMessageParam['content'] =
+    content as ChatCompletionMessageParam['content']
+
+  if (typeof content === 'string') {
+    if (!content.trim() || contentIndex !== 0) {
+      return messages
+    }
+
+    nextContent = [
+      {
+        type: 'text',
+        text: content,
+        cache_control: EPHEMERAL_CACHE_CONTROL
+      } as unknown as ChatCompletionContentPart
+    ]
+  } else if (Array.isArray(content)) {
+    nextContent = content.map((part, index) => {
+      if (
+        index !== contentIndex ||
+        part?.type !== 'text' ||
+        typeof part.text !== 'string' ||
+        !part.text.trim()
+      ) {
+        return part
+      }
+
+      return {
+        ...part,
+        cache_control: EPHEMERAL_CACHE_CONTROL
+      } as unknown as ChatCompletionContentPart
+    }) as ChatCompletionMessageParam['content']
+  } else {
+    return messages
+  }
+
+  return messages.map((message, index) =>
+    index === messageIndex
+      ? ({
+          ...message,
+          content: nextContent
+        } as ChatCompletionMessageParam)
+      : message
+  )
+}
+
+export function applyAnthropicExplicitCacheBreakpoint(
+  messages: Anthropic.MessageParam[],
+  plan: PromptCachePlan
+): Anthropic.MessageParam[] {
+  if (plan.mode !== 'anthropic_explicit' || !plan.breakpointPlan) {
+    return messages
+  }
+
+  const { messageIndex, contentIndex } = plan.breakpointPlan
+  const target = messages[messageIndex]
+
+  if (!target) {
+    return messages
+  }
+
+  const content = target.content
+  let nextContent: Anthropic.MessageParam['content'] = content
+
+  if (typeof content === 'string') {
+    if (!content.trim() || contentIndex !== 0) {
+      return messages
+    }
+
+    nextContent = [
+      {
+        type: 'text',
+        text: content,
+        cache_control: EPHEMERAL_CACHE_CONTROL
+      } satisfies AnthropicTextBlockWithCache
+    ]
+  } else if (Array.isArray(content)) {
+    nextContent = content.map((block, index) => {
+      if (
+        index !== contentIndex ||
+        block?.type !== 'text' ||
+        typeof block.text !== 'string' ||
+        !block.text.trim()
+      ) {
+        return block
+      }
+
+      return {
+        ...block,
+        cache_control: EPHEMERAL_CACHE_CONTROL
+      } satisfies AnthropicTextBlockWithCache
+    })
+  } else {
+    return messages
+  }
+
+  return messages.map((message, index) =>
+    index === messageIndex
+      ?
({ + ...message, + content: nextContent + } as Anthropic.MessageParam) + : message + ) +} diff --git a/src/main/presenter/llmProviderPresenter/providers/anthropicProvider.ts b/src/main/presenter/llmProviderPresenter/providers/anthropicProvider.ts index 296713ee7..cce04e56f 100644 --- a/src/main/presenter/llmProviderPresenter/providers/anthropicProvider.ts +++ b/src/main/presenter/llmProviderPresenter/providers/anthropicProvider.ts @@ -15,6 +15,63 @@ import { proxyConfig } from '../../proxyConfig' import { ProxyAgent } from 'undici' import type { Usage } from '@anthropic-ai/sdk/resources' import type { ProviderMcpRuntimePort } from '../runtimePorts' +import { applyAnthropicTopLevelCacheControl, resolvePromptCachePlan } from '../promptCacheStrategy' + +type CacheAwareAnthropicUsage = Usage & { + cache_read_input_tokens?: number + cache_creation_input_tokens?: number + cacheReadInputTokens?: number + cacheWriteInputTokens?: number +} + +function getAnthropicUsageNumber( + usage: CacheAwareAnthropicUsage | undefined, + snakeKey: 'cache_read_input_tokens' | 'cache_creation_input_tokens', + camelKey: 'cacheReadInputTokens' | 'cacheWriteInputTokens' +): number { + const value = usage?.[snakeKey] ?? usage?.[camelKey] + return typeof value === 'number' && Number.isFinite(value) ? value : 0 +} + +function buildAnthropicUsageSnapshot(usage: CacheAwareAnthropicUsage | undefined): { + prompt_tokens: number + completion_tokens: number + total_tokens: number + cached_tokens?: number + cache_write_tokens?: number +} | null { + if (!usage) { + return null + } + + const uncachedInputTokens = + typeof usage.input_tokens === 'number' && Number.isFinite(usage.input_tokens) + ? usage.input_tokens + : 0 + const completionTokens = + typeof usage.output_tokens === 'number' && Number.isFinite(usage.output_tokens) + ? usage.output_tokens + : 0 + const cachedTokens = getAnthropicUsageNumber( + usage, + 'cache_read_input_tokens', + 'cacheReadInputTokens' + ) + const cacheWriteTokens = getAnthropicUsageNumber( + usage, + 'cache_creation_input_tokens', + 'cacheWriteInputTokens' + ) + const promptTokens = uncachedInputTokens + cachedTokens + cacheWriteTokens + + return { + prompt_tokens: promptTokens, + completion_tokens: completionTokens, + total_tokens: promptTokens + completionTokens, + ...(cachedTokens > 0 ? { cached_tokens: cachedTokens } : {}), + ...(cacheWriteTokens > 0 ? 
{ cache_write_tokens: cacheWriteTokens } : {}) + } +} export class AnthropicProvider extends BaseLLMProvider { private anthropic!: Anthropic @@ -42,6 +99,22 @@ export class AnthropicProvider extends BaseLLMProvider { } } + private applyPromptCache>( + requestParams: T, + modelId: string, + messages: Anthropic.MessageParam[], + conversationId?: string + ): T { + const plan = resolvePromptCachePlan({ + providerId: this.provider.id, + apiType: 'anthropic', + modelId, + messages: messages as unknown[], + conversationId + }) + return applyAnthropicTopLevelCacheControl(requestParams, plan) + } + public onProxyResolved(): void { this.init() } @@ -458,10 +531,16 @@ export class AnthropicProvider extends BaseLLMProvider { requestParams.system = formattedMessages.system } + const cachedRequestParams = this.applyPromptCache( + requestParams, + modelId, + formattedMessages.messages + ) + if (!this.anthropic) { throw new Error('Anthropic client is not initialized') } - const response = await this.anthropic.messages.create(requestParams) + const response = await this.anthropic.messages.create(cachedRequestParams) const resultResp: LLMResponse = { content: '' @@ -469,10 +548,13 @@ export class AnthropicProvider extends BaseLLMProvider { // 添加usage信息 if (response.usage) { + const usageSnapshot = buildAnthropicUsageSnapshot( + response.usage as CacheAwareAnthropicUsage + ) resultResp.totalUsage = { - prompt_tokens: response.usage.input_tokens, - completion_tokens: response.usage.output_tokens, - total_tokens: response.usage.input_tokens + response.usage.output_tokens + prompt_tokens: usageSnapshot?.prompt_tokens ?? 0, + completion_tokens: usageSnapshot?.completion_tokens ?? 0, + total_tokens: usageSnapshot?.total_tokens ?? 0 } } @@ -546,10 +628,16 @@ ${text} requestParams.system = systemPrompt } + const cachedRequestParams = this.applyPromptCache( + requestParams, + modelId, + requestParams.messages + ) + if (!this.anthropic) { throw new Error('Anthropic client is not initialized') } - const response = await this.anthropic.messages.create(requestParams) + const response = await this.anthropic.messages.create(cachedRequestParams) return { content: response.content @@ -588,10 +676,16 @@ ${context} requestParams.system = systemPrompt } + const cachedRequestParams = this.applyPromptCache( + requestParams, + modelId, + requestParams.messages + ) + if (!this.anthropic) { throw new Error('Anthropic client is not initialized') } - const response = await this.anthropic.messages.create(requestParams) + const response = await this.anthropic.messages.create(cachedRequestParams) const suggestions = response.content .filter((block: any) => block.type === 'text') @@ -657,14 +751,20 @@ ${context} // @ts-ignore - 类型不匹配,但格式是正确的 streamParams.tools = anthropicTools } + const cachedStreamParams = this.applyPromptCache( + streamParams as unknown as Record, + modelId, + formattedMessagesObject.messages, + modelConfig.conversationId + ) as unknown as Anthropic.Messages.MessageCreateParamsStreaming await this.emitRequestTrace(modelConfig, { endpoint: this.buildAnthropicEndpoint(), headers: this.buildAnthropicApiKeyHeaders(), - body: streamParams + body: cachedStreamParams }) // console.log('streamParams', JSON.stringify(streamParams.messages)) // 创建Anthropic流 - const stream = await this.anthropic.messages.create(streamParams) + const stream = await this.anthropic.messages.create(cachedStreamParams) // 状态变量 let accumulatedJson = '' @@ -820,11 +920,10 @@ ${context} } } if (usageMetadata) { - yield createStreamEvent.usage({ - 
prompt_tokens: usageMetadata.input_tokens, - completion_tokens: usageMetadata.output_tokens, - total_tokens: usageMetadata.input_tokens + usageMetadata.output_tokens - }) + const usageSnapshot = buildAnthropicUsageSnapshot(usageMetadata as CacheAwareAnthropicUsage) + if (usageSnapshot) { + yield createStreamEvent.usage(usageSnapshot) + } } // 发送停止事件 yield createStreamEvent.stop(toolUseDetected ? 'tool_use' : 'complete') diff --git a/src/main/presenter/llmProviderPresenter/providers/awsBedrockProvider.ts b/src/main/presenter/llmProviderPresenter/providers/awsBedrockProvider.ts index f11b727e8..00ee51071 100644 --- a/src/main/presenter/llmProviderPresenter/providers/awsBedrockProvider.ts +++ b/src/main/presenter/llmProviderPresenter/providers/awsBedrockProvider.ts @@ -20,6 +20,66 @@ import { import Anthropic from '@anthropic-ai/sdk' import { Usage } from '@anthropic-ai/sdk/resources/messages' import type { ProviderMcpRuntimePort } from '../runtimePorts' +import { + applyAnthropicExplicitCacheBreakpoint, + resolvePromptCachePlan +} from '../promptCacheStrategy' + +type CacheAwareBedrockUsage = Usage & { + cache_read_input_tokens?: number + cache_creation_input_tokens?: number + cacheReadInputTokens?: number + cacheWriteInputTokens?: number +} + +function getBedrockUsageNumber( + usage: CacheAwareBedrockUsage | undefined, + snakeKey: 'cache_read_input_tokens' | 'cache_creation_input_tokens', + camelKey: 'cacheReadInputTokens' | 'cacheWriteInputTokens' +): number { + const value = usage?.[snakeKey] ?? usage?.[camelKey] + return typeof value === 'number' && Number.isFinite(value) ? value : 0 +} + +function buildBedrockUsageSnapshot(usage: CacheAwareBedrockUsage | undefined): { + prompt_tokens: number + completion_tokens: number + total_tokens: number + cached_tokens?: number + cache_write_tokens?: number +} | null { + if (!usage) { + return null + } + + const uncachedInputTokens = + typeof usage.input_tokens === 'number' && Number.isFinite(usage.input_tokens) + ? usage.input_tokens + : 0 + const completionTokens = + typeof usage.output_tokens === 'number' && Number.isFinite(usage.output_tokens) + ? usage.output_tokens + : 0 + const cachedTokens = getBedrockUsageNumber( + usage, + 'cache_read_input_tokens', + 'cacheReadInputTokens' + ) + const cacheWriteTokens = getBedrockUsageNumber( + usage, + 'cache_creation_input_tokens', + 'cacheWriteInputTokens' + ) + const promptTokens = uncachedInputTokens + cachedTokens + cacheWriteTokens + + return { + prompt_tokens: promptTokens, + completion_tokens: completionTokens, + total_tokens: promptTokens + completionTokens, + ...(cachedTokens > 0 ? { cached_tokens: cachedTokens } : {}), + ...(cacheWriteTokens > 0 ? 
{ cache_write_tokens: cacheWriteTokens } : {}) + } +} export class AwsBedrockProvider extends BaseLLMProvider { private bedrock!: BedrockClient @@ -64,6 +124,21 @@ export class AwsBedrockProvider extends BaseLLMProvider { return body } + private applyPromptCache( + messages: Anthropic.MessageParam[], + modelId: string, + conversationId?: string + ): Anthropic.MessageParam[] { + const plan = resolvePromptCachePlan({ + providerId: this.provider.id, + apiType: 'anthropic', + modelId, + messages: messages as unknown[], + conversationId + }) + return applyAnthropicExplicitCacheBreakpoint(messages, plan) + } + public onProxyResolved(): void { this.init() } @@ -539,6 +614,7 @@ ${text} } const formattedMessages = this.formatMessages(messages) + const cachedMessages = this.applyPromptCache(formattedMessages.messages, modelId) // 创建基本请求参数 const payload = { @@ -546,7 +622,7 @@ ${text} max_tokens: maxTokens, temperature, system: formattedMessages.system, - messages + messages: cachedMessages } const command = new InvokeModelCommand({ @@ -566,10 +642,11 @@ ${text} // 添加usage信息 if (response.usage) { + const usageSnapshot = buildBedrockUsageSnapshot(response.usage as CacheAwareBedrockUsage) resultResp.totalUsage = { - prompt_tokens: response.usage.input_tokens, - completion_tokens: response.usage.output_tokens, - total_tokens: response.usage.input_tokens + response.usage.output_tokens + prompt_tokens: usageSnapshot?.prompt_tokens ?? 0, + completion_tokens: usageSnapshot?.completion_tokens ?? 0, + total_tokens: usageSnapshot?.total_tokens ?? 0 } } @@ -628,6 +705,11 @@ ${text} try { // 格式化消息 const formattedMessagesObject = this.formatMessages(messages) + const cachedMessages = this.applyPromptCache( + formattedMessagesObject.messages, + modelId, + modelConfig.conversationId + ) console.log('formattedMessagesObject', JSON.stringify(formattedMessagesObject)) // 将MCP工具转换为Anthropic工具格式 @@ -648,8 +730,8 @@ ${text} anthropic_version: 'bedrock-2023-05-31', max_tokens: maxTokens || 1024, temperature: temperature ?? 0.7, - // system: formattedMessagesObject.system, - messages: formattedMessagesObject.messages, + system: formattedMessagesObject.system, + messages: cachedMessages, thinking: undefined as any, tools: undefined as any } @@ -888,11 +970,10 @@ ${text} } } if (usageMetadata) { - yield createStreamEvent.usage({ - prompt_tokens: usageMetadata.input_tokens, - completion_tokens: usageMetadata.output_tokens, - total_tokens: usageMetadata.input_tokens + usageMetadata.output_tokens - }) + const usageSnapshot = buildBedrockUsageSnapshot(usageMetadata as CacheAwareBedrockUsage) + if (usageSnapshot) { + yield createStreamEvent.usage(usageSnapshot) + } } // 发送停止事件 yield createStreamEvent.stop(toolUseDetected ? 
'tool_use' : 'complete') diff --git a/src/main/presenter/llmProviderPresenter/providers/openAICompatibleProvider.ts b/src/main/presenter/llmProviderPresenter/providers/openAICompatibleProvider.ts index 8aedd633b..8d6c1640c 100644 --- a/src/main/presenter/llmProviderPresenter/providers/openAICompatibleProvider.ts +++ b/src/main/presenter/llmProviderPresenter/providers/openAICompatibleProvider.ts @@ -33,6 +33,11 @@ import { proxyConfig } from '../../proxyConfig' import { modelCapabilities } from '../../configPresenter/modelCapabilities' import { ProxyAgent } from 'undici' import type { ProviderMcpRuntimePort } from '../runtimePorts' +import { + applyOpenAIChatExplicitCacheBreakpoint, + applyOpenAIPromptCacheKey, + resolvePromptCachePlan +} from '../promptCacheStrategy' const OPENAI_REASONING_MODELS = [ 'o4-mini', @@ -80,6 +85,17 @@ export function normalizeExtractedImageText(content: string): string { } function getOpenAIChatCachedTokens(usage: unknown): number | undefined { + return getOpenAIChatUsageDetail(usage, 'cached_tokens') +} + +function getOpenAIChatCacheWriteTokens(usage: unknown): number | undefined { + return getOpenAIChatUsageDetail(usage, 'cache_write_tokens') +} + +function getOpenAIChatUsageDetail( + usage: unknown, + key: 'cached_tokens' | 'cache_write_tokens' +): number | undefined { if (!usage || typeof usage !== 'object') { return undefined } @@ -88,11 +104,11 @@ function getOpenAIChatCachedTokens(usage: unknown): number | undefined { const inputTokensDetails = (usage as { input_tokens_details?: unknown }).input_tokens_details const promptCachedTokens = promptTokensDetails && typeof promptTokensDetails === 'object' - ? (promptTokensDetails as { cached_tokens?: unknown }).cached_tokens + ? (promptTokensDetails as Record)[key] : undefined const inputCachedTokens = inputTokensDetails && typeof inputTokensDetails === 'object' - ? (inputTokensDetails as { cached_tokens?: unknown }).cached_tokens + ? (inputTokensDetails as Record)[key] : undefined const cachedTokens = typeof promptCachedTokens === 'number' ? promptCachedTokens : inputCachedTokens @@ -768,7 +784,7 @@ export class OpenAICompatibleProvider extends BaseLLMProvider { const modelConfig = this.configPresenter.getModelConfig(modelId, this.provider.id) const supportsFunctionCall = modelConfig?.functionCall || false - const requestParams: OpenAI.Chat.ChatCompletionCreateParams = { + const requestParams: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming = { messages: this.formatMessages(messages, supportsFunctionCall), model: modelId, stream: false, @@ -781,12 +797,27 @@ export class OpenAICompatibleProvider extends BaseLLMProvider { ? 
{ max_completion_tokens: maxTokens } : { max_tokens: maxTokens }) } + const promptCachePlan = resolvePromptCachePlan({ + providerId: this.provider.id, + apiType: 'openai_chat', + modelId, + messages: requestParams.messages as unknown[], + conversationId: modelConfig?.conversationId + }) + requestParams.messages = applyOpenAIChatExplicitCacheBreakpoint( + requestParams.messages as ChatCompletionMessageParam[], + promptCachePlan + ) OPENAI_REASONING_MODELS.forEach((noTempId) => { if (modelId.startsWith(noTempId)) { delete requestParams.temperature } }) - const completion = await this.openai.chat.completions.create(requestParams) + const cachedRequestParams = applyOpenAIPromptCacheKey( + requestParams as unknown as Record, + promptCachePlan + ) as unknown as OpenAI.Chat.ChatCompletionCreateParamsNonStreaming + const completion = await this.openai.chat.completions.create(cachedRequestParams) const message = completion.choices[0].message as ChatCompletionMessage & { reasoning_content?: string @@ -1012,7 +1043,8 @@ export class OpenAICompatibleProvider extends BaseLLMProvider { prompt_tokens: result.usage.input_tokens || 0, completion_tokens: result.usage.output_tokens || 0, total_tokens: result.usage.total_tokens || 0, - cached_tokens: getOpenAIChatCachedTokens(result.usage) + cached_tokens: getOpenAIChatCachedTokens(result.usage), + cache_write_tokens: getOpenAIChatCacheWriteTokens(result.usage) }) } @@ -1081,7 +1113,7 @@ export class OpenAICompatibleProvider extends BaseLLMProvider { : undefined // 构建请求参数 - const requestParams: OpenAI.Chat.ChatCompletionCreateParams = { + const requestParams: OpenAI.Chat.ChatCompletionCreateParamsStreaming = { messages: processedMessages, model: modelId, stream: true, @@ -1132,15 +1164,32 @@ export class OpenAICompatibleProvider extends BaseLLMProvider { // 如果存在 API 工具且支持函数调用,则添加到请求参数中 if (apiTools && apiTools.length > 0 && supportsFunctionCall) requestParams.tools = apiTools + const promptCachePlan = resolvePromptCachePlan({ + providerId: this.provider.id, + apiType: 'openai_chat', + modelId, + messages: processedMessages as unknown[], + tools, + conversationId: modelConfig?.conversationId + }) + requestParams.messages = applyOpenAIChatExplicitCacheBreakpoint( + requestParams.messages as ChatCompletionMessageParam[], + promptCachePlan + ) + const cachedRequestParams = applyOpenAIPromptCacheKey( + requestParams as unknown as Record, + promptCachePlan + ) as unknown as OpenAI.Chat.ChatCompletionCreateParamsStreaming + await this.emitRequestTrace(modelConfig, { endpoint: this.buildChatCompletionsEndpoint(), headers: this.buildChatCompletionsTraceHeaders(), - body: requestParams + body: cachedRequestParams }) // console.log('[handleChatCompletion] requestParams', JSON.stringify(requestParams)) // 发起 OpenAI 聊天补全请求 - const stream = await this.openai.chat.completions.create(requestParams) + const stream = await this.openai.chat.completions.create(cachedRequestParams) //----------------------------------------------------------------------------------------------------- // 流处理状态定义 (已将相关变量声明提升到顶部,确保可见性) @@ -1179,6 +1228,7 @@ export class OpenAICompatibleProvider extends BaseLLMProvider { completion_tokens: number total_tokens: number cached_tokens?: number + cache_write_tokens?: number } | undefined = undefined @@ -1195,7 +1245,8 @@ export class OpenAICompatibleProvider extends BaseLLMProvider { if (chunk.usage) { usage = { ...chunk.usage, - cached_tokens: getOpenAIChatCachedTokens(chunk.usage) + cached_tokens: getOpenAIChatCachedTokens(chunk.usage), + 
cache_write_tokens: getOpenAIChatCacheWriteTokens(chunk.usage) } } diff --git a/src/main/presenter/llmProviderPresenter/providers/openAIResponsesProvider.ts b/src/main/presenter/llmProviderPresenter/providers/openAIResponsesProvider.ts index e4e4daa6c..67a477f4b 100644 --- a/src/main/presenter/llmProviderPresenter/providers/openAIResponsesProvider.ts +++ b/src/main/presenter/llmProviderPresenter/providers/openAIResponsesProvider.ts @@ -24,6 +24,7 @@ import { proxyConfig } from '../../proxyConfig' import { ProxyAgent } from 'undici' import { modelCapabilities } from '../../configPresenter/modelCapabilities' import type { ProviderMcpRuntimePort } from '../runtimePorts' +import { applyOpenAIPromptCacheKey, resolvePromptCachePlan } from '../promptCacheStrategy' const OPENAI_REASONING_MODELS = [ 'o4-mini', @@ -62,7 +63,9 @@ function getOpenAIResponseCachedTokens( | { input_tokens_details?: { cached_tokens?: number + cache_write_tokens?: number } + cache_write_tokens?: number } | null | undefined @@ -73,6 +76,24 @@ function getOpenAIResponseCachedTokens( : undefined } +function getOpenAIResponseCacheWriteTokens(usage: unknown): number | undefined { + if (!usage || typeof usage !== 'object') { + return undefined + } + + const inputTokensDetails = (usage as { input_tokens_details?: unknown }).input_tokens_details + const nestedCacheWriteTokens = + inputTokensDetails && typeof inputTokensDetails === 'object' + ? (inputTokensDetails as Record).cache_write_tokens + : undefined + const topLevelCacheWriteTokens = (usage as Record).cache_write_tokens + const cacheWriteTokens = + typeof nestedCacheWriteTokens === 'number' ? nestedCacheWriteTokens : topLevelCacheWriteTokens + return typeof cacheWriteTokens === 'number' && Number.isFinite(cacheWriteTokens) + ? cacheWriteTokens + : undefined +} + export class OpenAIResponsesProvider extends BaseLLMProvider { protected openai!: OpenAI private isNoModelsApi: boolean = false @@ -305,7 +326,7 @@ export class OpenAIResponsesProvider extends BaseLLMProvider { } const formattedMessages = this.formatMessages(messages) - const requestParams: OpenAI.Responses.ResponseCreateParams = { + const requestParams: OpenAI.Responses.ResponseCreateParamsNonStreaming = { model: modelId, input: formattedMessages, temperature: temperature, @@ -314,6 +335,13 @@ export class OpenAIResponsesProvider extends BaseLLMProvider { } const modelConfig = this.configPresenter.getModelConfig(modelId, this.provider.id) + const promptCachePlan = resolvePromptCachePlan({ + providerId: this.provider.id, + apiType: 'openai_responses', + modelId, + messages: formattedMessages as unknown[], + conversationId: modelConfig?.conversationId + }) if (modelConfig.reasoningEffort && this.supportsEffortParameter(modelId)) { ;(requestParams as any).reasoning = { effort: modelConfig.reasoningEffort @@ -333,7 +361,12 @@ export class OpenAIResponsesProvider extends BaseLLMProvider { } }) - const response = await this.openai.responses.create(requestParams) + const cachedRequestParams = applyOpenAIPromptCacheKey( + requestParams as unknown as Record, + promptCachePlan + ) as unknown as OpenAI.Responses.ResponseCreateParamsNonStreaming + + const response = await this.openai.responses.create(cachedRequestParams) const resultResp: LLMResponse = { content: '' } @@ -571,7 +604,8 @@ export class OpenAIResponsesProvider extends BaseLLMProvider { prompt_tokens: result.usage.input_tokens || 0, completion_tokens: result.usage.output_tokens || 0, total_tokens: result.usage.total_tokens || 0, - cached_tokens: 
getOpenAIResponseCachedTokens(result.usage) + cached_tokens: getOpenAIResponseCachedTokens(result.usage), + cache_write_tokens: getOpenAIResponseCacheWriteTokens(result.usage) }) } @@ -631,13 +665,21 @@ export class OpenAIResponsesProvider extends BaseLLMProvider { ? await this.mcpRuntime?.mcpToolsToOpenAIResponsesTools(tools, this.provider.id) : undefined - const requestParams: OpenAI.Responses.ResponseCreateParams = { + const requestParams: OpenAI.Responses.ResponseCreateParamsStreaming = { model: modelId, input: processedMessages, temperature, max_output_tokens: maxTokens, stream: true } + const promptCachePlan = resolvePromptCachePlan({ + providerId: this.provider.id, + apiType: 'openai_responses', + modelId, + messages: processedMessages as unknown[], + tools, + conversationId: modelConfig?.conversationId + }) // 如果模型支持函数调用且有工具,添加 tools 参数 if (tools.length > 0 && supportsFunctionCall && apiTools) { @@ -660,13 +702,18 @@ export class OpenAIResponsesProvider extends BaseLLMProvider { if (modelId.startsWith(noTempId)) delete requestParams.temperature }) + const cachedRequestParams = applyOpenAIPromptCacheKey( + requestParams as unknown as Record, + promptCachePlan + ) as unknown as OpenAI.Responses.ResponseCreateParamsStreaming + await this.emitRequestTrace(modelConfig, { endpoint: this.buildResponsesEndpoint(), headers: this.buildResponsesTraceHeaders(), - body: requestParams + body: cachedRequestParams }) - const stream = await this.openai.responses.create(requestParams) + const stream = await this.openai.responses.create(cachedRequestParams) // --- State Variables --- type TagState = 'none' | 'start' | 'inside' | 'end' @@ -696,6 +743,7 @@ export class OpenAIResponsesProvider extends BaseLLMProvider { completion_tokens: number total_tokens: number cached_tokens?: number + cache_write_tokens?: number } | undefined = undefined @@ -1006,7 +1054,8 @@ export class OpenAIResponsesProvider extends BaseLLMProvider { prompt_tokens: response.usage.input_tokens || 0, completion_tokens: response.usage.output_tokens || 0, total_tokens: response.usage.total_tokens || 0, - cached_tokens: getOpenAIResponseCachedTokens(response.usage) + cached_tokens: getOpenAIResponseCachedTokens(response.usage), + cache_write_tokens: getOpenAIResponseCacheWriteTokens(response.usage) } yield createStreamEvent.usage(usage) } diff --git a/src/main/presenter/newAgentPresenter/index.ts b/src/main/presenter/newAgentPresenter/index.ts index 0b1d776a6..2ac0f7f11 100644 --- a/src/main/presenter/newAgentPresenter/index.ts +++ b/src/main/presenter/newAgentPresenter/index.ts @@ -2143,6 +2143,7 @@ export class NewAgentPresenter { inputTokens?: number outputTokens?: number cachedInputTokens?: number + cacheWriteInputTokens?: number generationTime?: number firstTokenTime?: number tokensPerSecond?: number @@ -2158,6 +2159,10 @@ export class NewAgentPresenter { outputTokens: typeof parsed.outputTokens === 'number' ? parsed.outputTokens : undefined, cachedInputTokens: typeof parsed.cachedInputTokens === 'number' ? parsed.cachedInputTokens : undefined, + cacheWriteInputTokens: + typeof parsed.cacheWriteInputTokens === 'number' + ? parsed.cacheWriteInputTokens + : undefined, generationTime: typeof parsed.generationTime === 'number' ? 
parsed.generationTime : undefined, firstTokenTime: @@ -2208,7 +2213,8 @@ export class NewAgentPresenter { modelId, metadata: { ...metadata, - cachedInputTokens: 0 + cachedInputTokens: 0, + cacheWriteInputTokens: 0 }, source: 'backfill' }) diff --git a/src/main/presenter/sqlitePresenter/tables/deepchatUsageStats.ts b/src/main/presenter/sqlitePresenter/tables/deepchatUsageStats.ts index edada206e..a071f9474 100644 --- a/src/main/presenter/sqlitePresenter/tables/deepchatUsageStats.ts +++ b/src/main/presenter/sqlitePresenter/tables/deepchatUsageStats.ts @@ -12,6 +12,7 @@ export interface DeepChatUsageStatsRow { output_tokens: number total_tokens: number cached_input_tokens: number + cache_write_input_tokens: number estimated_cost_usd: number | null source: 'backfill' | 'live' created_at: number @@ -93,6 +94,7 @@ export class DeepChatUsageStatsTable extends BaseTable { output_tokens INTEGER NOT NULL DEFAULT 0, total_tokens INTEGER NOT NULL DEFAULT 0, cached_input_tokens INTEGER NOT NULL DEFAULT 0, + cache_write_input_tokens INTEGER NOT NULL DEFAULT 0, estimated_cost_usd REAL, source TEXT NOT NULL DEFAULT 'live', created_at INTEGER NOT NULL, @@ -108,11 +110,16 @@ export class DeepChatUsageStatsTable extends BaseTable { if (version === 17) { return this.getCreateTableSQL() } + if (version === 22) { + return this.hasColumn('cache_write_input_tokens') + ? null + : `ALTER TABLE deepchat_usage_stats ADD COLUMN cache_write_input_tokens INTEGER NOT NULL DEFAULT 0;` + } return null } getLatestVersion(): number { - return 17 + return 22 } upsert(row: UsageStatsRecordInput): void { @@ -128,11 +135,12 @@ export class DeepChatUsageStatsTable extends BaseTable { output_tokens, total_tokens, cached_input_tokens, + cache_write_input_tokens, estimated_cost_usd, source, created_at, updated_at - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(message_id) DO UPDATE SET session_id = excluded.session_id, usage_date = excluded.usage_date, @@ -142,6 +150,7 @@ export class DeepChatUsageStatsTable extends BaseTable { output_tokens = excluded.output_tokens, total_tokens = excluded.total_tokens, cached_input_tokens = excluded.cached_input_tokens, + cache_write_input_tokens = excluded.cache_write_input_tokens, estimated_cost_usd = excluded.estimated_cost_usd, source = excluded.source, created_at = excluded.created_at, @@ -157,6 +166,7 @@ export class DeepChatUsageStatsTable extends BaseTable { row.outputTokens, row.totalTokens, row.cachedInputTokens, + row.cacheWriteInputTokens, row.estimatedCostUsd, row.source, row.createdAt, diff --git a/src/main/presenter/usageStats.ts b/src/main/presenter/usageStats.ts index 495465737..9c550770a 100644 --- a/src/main/presenter/usageStats.ts +++ b/src/main/presenter/usageStats.ts @@ -22,6 +22,7 @@ export interface UsageStatsRecordInput { outputTokens: number totalTokens: number cachedInputTokens: number + cacheWriteInputTokens: number estimatedCostUsd: number | null source: UsageStatsSource createdAt: number @@ -152,18 +153,24 @@ export function normalizeUsageCounts(metadata: MessageMetadata): { outputTokens: number totalTokens: number cachedInputTokens: number + cacheWriteInputTokens: number } { const inputTokens = toNonNegativeInteger(metadata.inputTokens) ?? 0 const outputTokens = toNonNegativeInteger(metadata.outputTokens) ?? 0 const totalTokens = toNonNegativeInteger(metadata.totalTokens) ?? inputTokens + outputTokens const rawCached = toNonNegativeInteger(metadata.cachedInputTokens) ?? 
0 - const cachedInputTokens = inputTokens > 0 ? Math.min(rawCached, inputTokens) : rawCached + const rawCacheWrite = toNonNegativeInteger(metadata.cacheWriteInputTokens) ?? 0 + const cappedCachedInputTokens = inputTokens > 0 ? Math.min(rawCached, inputTokens) : rawCached + const remainingInputTokens = Math.max(inputTokens - cappedCachedInputTokens, 0) + const cacheWriteInputTokens = + inputTokens > 0 ? Math.min(rawCacheWrite, remainingInputTokens) : rawCacheWrite return { inputTokens, outputTokens, totalTokens, - cachedInputTokens + cachedInputTokens: cappedCachedInputTokens, + cacheWriteInputTokens } } @@ -195,6 +202,7 @@ export function estimateUsageCostUsd(params: { inputTokens: number outputTokens: number cachedInputTokens: number + cacheWriteInputTokens: number }): number | null { const model = resolvePricedModel(params.providerId, params.modelId) const inputRate = getCostNumber(model, 'input') @@ -205,12 +213,17 @@ export function estimateUsageCostUsd(params: { } const cacheReadRate = getCostNumber(model, 'cache_read') - const billableInput = Math.max(params.inputTokens - params.cachedInputTokens, 0) + const cacheWriteRate = getCostNumber(model, 'cache_write') + const uncachedInput = Math.max( + params.inputTokens - params.cachedInputTokens - params.cacheWriteInputTokens, + 0 + ) return ( - (billableInput * inputRate + + (uncachedInput * inputRate + params.outputTokens * outputRate + - params.cachedInputTokens * (cacheReadRate ?? inputRate)) / + params.cachedInputTokens * (cacheReadRate ?? inputRate) + + params.cacheWriteInputTokens * (cacheWriteRate ?? inputRate)) / 1_000_000 ) } @@ -247,12 +260,14 @@ export function buildUsageStatsRecord(params: { outputTokens: usage.outputTokens, totalTokens: usage.totalTokens, cachedInputTokens: usage.cachedInputTokens, + cacheWriteInputTokens: usage.cacheWriteInputTokens, estimatedCostUsd: estimateUsageCostUsd({ providerId, modelId, inputTokens: usage.inputTokens, outputTokens: usage.outputTokens, - cachedInputTokens: usage.cachedInputTokens + cachedInputTokens: usage.cachedInputTokens, + cacheWriteInputTokens: usage.cacheWriteInputTokens }), source: params.source, createdAt: params.createdAt, diff --git a/src/shared/types/agent-interface.d.ts b/src/shared/types/agent-interface.d.ts index 6092aaa64..ed6b299a9 100644 --- a/src/shared/types/agent-interface.d.ts +++ b/src/shared/types/agent-interface.d.ts @@ -280,6 +280,7 @@ export interface MessageMetadata { inputTokens?: number outputTokens?: number cachedInputTokens?: number + cacheWriteInputTokens?: number generationTime?: number firstTokenTime?: number reasoningStartTime?: number diff --git a/src/shared/types/core/llm-events.ts b/src/shared/types/core/llm-events.ts index f1999ab89..b8645253d 100644 --- a/src/shared/types/core/llm-events.ts +++ b/src/shared/types/core/llm-events.ts @@ -58,6 +58,7 @@ export interface UsageStreamEvent { completion_tokens: number total_tokens: number cached_tokens?: number + cache_write_tokens?: number } } @@ -136,6 +137,7 @@ export const createStreamEvent = { completion_tokens: number total_tokens: number cached_tokens?: number + cache_write_tokens?: number }): UsageStreamEvent => ({ type: 'usage', usage diff --git a/test/main/presenter/deepchatAgentPresenter/accumulator.test.ts b/test/main/presenter/deepchatAgentPresenter/accumulator.test.ts index 7e2f99f18..29672a968 100644 --- a/test/main/presenter/deepchatAgentPresenter/accumulator.test.ts +++ b/test/main/presenter/deepchatAgentPresenter/accumulator.test.ts @@ -186,13 +186,20 @@ describe('accumulate', () => { 
it('usage sets metadata', () => { accumulate(state, { type: 'usage', - usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15, cached_tokens: 3 } + usage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15, + cached_tokens: 3, + cache_write_tokens: 2 + } }) expect(state.metadata.inputTokens).toBe(10) expect(state.metadata.outputTokens).toBe(5) expect(state.metadata.totalTokens).toBe(15) expect(state.metadata.cachedInputTokens).toBe(3) + expect(state.metadata.cacheWriteInputTokens).toBe(2) }) it('stop sets stopReason', () => { diff --git a/test/main/presenter/deepchatAgentPresenter/messageStore.test.ts b/test/main/presenter/deepchatAgentPresenter/messageStore.test.ts index 907984cb8..fe1919fe7 100644 --- a/test/main/presenter/deepchatAgentPresenter/messageStore.test.ts +++ b/test/main/presenter/deepchatAgentPresenter/messageStore.test.ts @@ -141,7 +141,8 @@ describe('DeepChatMessageStore', () => { inputTokens: 120, outputTokens: 30, totalTokens: 150, - cachedInputTokens: 20 + cachedInputTokens: 20, + cacheWriteInputTokens: 12 }) ) @@ -155,6 +156,7 @@ describe('DeepChatMessageStore', () => { outputTokens: 30, totalTokens: 150, cachedInputTokens: 20, + cacheWriteInputTokens: 12, source: 'live' }) ) @@ -184,7 +186,8 @@ describe('DeepChatMessageStore', () => { inputTokens: 120, outputTokens: 30, totalTokens: 150, - cachedInputTokens: 20 + cachedInputTokens: 20, + cacheWriteInputTokens: 12 }) ) ).not.toThrow() diff --git a/test/main/presenter/llmProviderPresenter/anthropicProvider.test.ts b/test/main/presenter/llmProviderPresenter/anthropicProvider.test.ts index f4bd86b5c..825f12644 100644 --- a/test/main/presenter/llmProviderPresenter/anthropicProvider.test.ts +++ b/test/main/presenter/llmProviderPresenter/anthropicProvider.test.ts @@ -1,5 +1,5 @@ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' -import type { IConfigPresenter, LLM_PROVIDER } from '../../../../src/shared/presenter' +import type { IConfigPresenter, LLM_PROVIDER, ModelConfig } from '../../../../src/shared/presenter' import { AnthropicProvider } from '../../../../src/main/presenter/llmProviderPresenter/providers/anthropicProvider' const { mockAnthropicConstructor, mockMessagesCreate, mockModelsList, mockGetProxyUrl } = @@ -61,6 +61,14 @@ const createConfigPresenter = () => getModelStatus: vi.fn().mockReturnValue(true) }) as unknown as IConfigPresenter +const createAsyncStream = (chunks: Array>) => ({ + async *[Symbol.asyncIterator]() { + for (const chunk of chunks) { + yield chunk + } + } +}) + const createProvider = (overrides?: Partial): LLM_PROVIDER => ({ id: 'anthropic', name: 'Anthropic', @@ -73,6 +81,15 @@ const createProvider = (overrides?: Partial): LLM_PROVIDER => ({ describe('AnthropicProvider API-only behavior', () => { const originalEnvKey = process.env.ANTHROPIC_API_KEY + const streamModelConfig: ModelConfig = { + maxTokens: 1024, + contextLength: 8192, + vision: false, + functionCall: false, + reasoning: false, + type: 'chat', + conversationId: 'session-1' + } beforeEach(() => { vi.clearAllMocks() @@ -172,4 +189,114 @@ describe('AnthropicProvider API-only behavior', () => { }) ) }) + + it('adds top-level cache_control for Claude streaming requests', async () => { + mockMessagesCreate.mockResolvedValue( + createAsyncStream([ + { + type: 'message_start', + message: { + usage: { + input_tokens: 10, + output_tokens: 2 + } + } + }, + { + type: 'content_block_delta', + delta: { + type: 'text_delta', + text: 'hello' + } + } + ]) + ) + + const provider = new 
AnthropicProvider( + createProvider({ enable: false }), + createConfigPresenter() + ) + ;(provider as any).anthropic = { + messages: { create: mockMessagesCreate }, + models: { list: mockModelsList } + } + + const events = [] + for await (const event of provider.coreStream( + [{ role: 'user', content: 'hi' }], + 'claude-sonnet-4-5-20250929', + streamModelConfig, + 0.2, + 64, + [] + )) { + events.push(event) + } + + const request = mockMessagesCreate.mock.calls.at(-1)?.[0] + expect(request).toMatchObject({ + cache_control: { + type: 'ephemeral' + } + }) + expect(events.some((event) => event.type === 'text')).toBe(true) + }) + + it('normalizes cache read and cache write usage metadata for streams', async () => { + mockMessagesCreate.mockResolvedValue( + createAsyncStream([ + { + type: 'message_start', + message: { + usage: { + input_tokens: 10, + output_tokens: 5, + cache_read_input_tokens: 20, + cache_creation_input_tokens: 30 + } + } + }, + { + type: 'content_block_delta', + delta: { + type: 'text_delta', + text: 'hello' + } + } + ]) + ) + + const provider = new AnthropicProvider( + createProvider({ enable: false }), + createConfigPresenter() + ) + ;(provider as any).anthropic = { + messages: { create: mockMessagesCreate }, + models: { list: mockModelsList } + } + + const events = [] + for await (const event of provider.coreStream( + [{ role: 'user', content: 'hi' }], + 'claude-sonnet-4-5-20250929', + streamModelConfig, + 0.2, + 64, + [] + )) { + events.push(event) + } + + const usageEvent = events.find((event) => event.type === 'usage') + expect(usageEvent).toMatchObject({ + type: 'usage', + usage: { + prompt_tokens: 60, + completion_tokens: 5, + total_tokens: 65, + cached_tokens: 20, + cache_write_tokens: 30 + } + }) + }) }) diff --git a/test/main/presenter/llmProviderPresenter/awsBedrockProvider.test.ts b/test/main/presenter/llmProviderPresenter/awsBedrockProvider.test.ts new file mode 100644 index 000000000..bad73be9e --- /dev/null +++ b/test/main/presenter/llmProviderPresenter/awsBedrockProvider.test.ts @@ -0,0 +1,207 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import type { + AWS_BEDROCK_PROVIDER, + ChatMessage, + IConfigPresenter, + ModelConfig +} from '../../../../src/shared/presenter' +import { AwsBedrockProvider } from '../../../../src/main/presenter/llmProviderPresenter/providers/awsBedrockProvider' + +const { mockBedrockRuntimeSend, mockGetProxyUrl } = vi.hoisted(() => ({ + mockBedrockRuntimeSend: vi.fn(), + mockGetProxyUrl: vi.fn().mockReturnValue(null) +})) + +vi.mock('@aws-sdk/client-bedrock', () => ({ + BedrockClient: vi.fn(), + ListFoundationModelsCommand: class ListFoundationModelsCommand { + input: unknown + + constructor(input: unknown) { + this.input = input + } + } +})) + +vi.mock('@aws-sdk/client-bedrock-runtime', () => ({ + BedrockRuntimeClient: vi.fn(), + InvokeModelCommand: class InvokeModelCommand { + input: Record + + constructor(input: Record) { + this.input = input + } + }, + InvokeModelWithResponseStreamCommand: class InvokeModelWithResponseStreamCommand { + input: Record + + constructor(input: Record) { + this.input = input + } + } +})) + +vi.mock('../../../../src/main/presenter/proxyConfig', () => ({ + proxyConfig: { + getProxyUrl: mockGetProxyUrl + } +})) + +const createConfigPresenter = () => + ({ + getProviderModels: vi.fn().mockReturnValue([]), + getCustomModels: vi.fn().mockReturnValue([]), + getModelConfig: vi.fn().mockReturnValue(undefined), + getDbProviderModels: vi.fn().mockReturnValue([]), + getSetting: 
vi.fn().mockReturnValue(undefined), + setProviderModels: vi.fn(), + getModelStatus: vi.fn().mockReturnValue(true) + }) as unknown as IConfigPresenter + +const createProvider = (overrides?: Partial): AWS_BEDROCK_PROVIDER => ({ + id: 'aws-bedrock', + name: 'AWS Bedrock', + apiType: 'aws-bedrock', + enable: false, + credential: { + accessKeyId: 'test-access-key', + secretAccessKey: 'test-secret', + region: 'us-east-1' + }, + ...overrides +}) + +const createAsyncStream = (chunks: Array>) => ({ + async *[Symbol.asyncIterator]() { + for (const chunk of chunks) { + yield chunk + } + } +}) + +const createBedrockChunk = (chunk: Record) => ({ + chunk: { + bytes: new TextEncoder().encode(JSON.stringify(chunk)) + } +}) + +describe('AwsBedrockProvider prompt cache behavior', () => { + const modelConfig: ModelConfig = { + maxTokens: 1024, + contextLength: 8192, + vision: false, + functionCall: false, + reasoning: false, + type: 'chat', + conversationId: 'session-1' + } + + const messages: ChatMessage[] = [ + { role: 'system', content: 'system prompt' }, + { role: 'user', content: 'history' }, + { role: 'assistant', content: 'stable reply' }, + { role: 'user', content: 'latest question' } + ] + + beforeEach(() => { + vi.clearAllMocks() + mockGetProxyUrl.mockReturnValue(null) + mockBedrockRuntimeSend.mockResolvedValue({ + body: Promise.resolve( + createAsyncStream([ + createBedrockChunk({ + type: 'message_start', + message: { + usage: { + input_tokens: 10, + output_tokens: 5, + cacheReadInputTokens: 20, + cacheWriteInputTokens: 30 + } + } + }), + createBedrockChunk({ + type: 'content_block_delta', + delta: { + type: 'text_delta', + text: 'hello' + } + }) + ]) + ) + }) + }) + + it('adds an explicit cache_control breakpoint before the latest user turn', async () => { + const provider = new AwsBedrockProvider(createProvider(), createConfigPresenter()) + ;(provider as any).bedrockRuntime = { + send: mockBedrockRuntimeSend + } + + const events = [] + for await (const event of provider.coreStream( + messages, + 'anthropic.claude-3-5-sonnet-20240620-v1:0', + modelConfig, + 0.2, + 64, + [] + )) { + events.push(event) + } + + const command = mockBedrockRuntimeSend.mock.calls[0][0] as { + input: { + body: string + } + } + const payload = JSON.parse(command.input.body) + + expect(payload).not.toHaveProperty('cache_control') + expect(payload.system).toBe('system prompt\n') + expect(payload.messages[1]).toMatchObject({ + role: 'assistant', + content: [ + { + type: 'text', + text: 'stable reply', + cache_control: { + type: 'ephemeral' + } + } + ] + }) + expect(events.some((event) => event.type === 'text')).toBe(true) + }) + + it('normalizes cache read and cache write usage fields from Bedrock streams', async () => { + const provider = new AwsBedrockProvider(createProvider(), createConfigPresenter()) + ;(provider as any).bedrockRuntime = { + send: mockBedrockRuntimeSend + } + + const events = [] + for await (const event of provider.coreStream( + messages, + 'anthropic.claude-3-5-sonnet-20240620-v1:0', + modelConfig, + 0.2, + 64, + [] + )) { + events.push(event) + } + + const usageEvent = events.find((event) => event.type === 'usage') + expect(usageEvent).toMatchObject({ + type: 'usage', + usage: { + prompt_tokens: 60, + completion_tokens: 5, + total_tokens: 65, + cached_tokens: 20, + cache_write_tokens: 30 + } + }) + }) +}) diff --git a/test/main/presenter/llmProviderPresenter/openAICompatibleProvider.test.ts b/test/main/presenter/llmProviderPresenter/openAICompatibleProvider.test.ts index e751dfc1d..aa98cfb2d 100644 
--- a/test/main/presenter/llmProviderPresenter/openAICompatibleProvider.test.ts +++ b/test/main/presenter/llmProviderPresenter/openAICompatibleProvider.test.ts @@ -351,3 +351,131 @@ describe('normalizeExtractedImageText', () => { expect(normalizeExtractedImageText('[]()')).toBe('') }) }) + +describe('OpenAICompatibleProvider prompt cache behavior', () => { + beforeEach(() => { + vi.clearAllMocks() + mockModelsList.mockResolvedValue({ data: [] }) + mockGetProxyUrl.mockReturnValue(null) + mockChatCompletionsCreate.mockResolvedValue( + createAsyncStream([ + { + choices: [ + { + delta: { + content: 'ok' + }, + finish_reason: 'stop' + } + ], + usage: { + prompt_tokens: 80, + completion_tokens: 12, + total_tokens: 92, + prompt_tokens_details: { + cached_tokens: 24, + cache_write_tokens: 16 + } + } + } + ]) + ) + }) + + it('injects prompt_cache_key only for official OpenAI chat completions', async () => { + const provider = new OpenAICompatibleProvider( + { + id: 'openai', + name: 'OpenAI', + apiType: 'openai-compatible', + apiKey: 'test-key', + baseUrl: 'https://api.openai.com/v1', + enable: false + }, + createConfigPresenter([]) + ) + ;(provider as any).isInitialized = true + + const modelConfig: ModelConfig = { + maxTokens: 1024, + contextLength: 8192, + vision: false, + functionCall: false, + reasoning: false, + type: 'chat', + conversationId: 'session-1' + } + + const events = await collectEvents( + provider, + 'gpt-5', + modelConfig, + [{ role: 'user', content: 'cache me' }], + [] + ) + const usageEvent = events.find((event) => event.type === 'usage') + const requestParams = mockChatCompletionsCreate.mock.calls[0]?.[0] + + expect(requestParams.prompt_cache_key).toMatch(/^deepchat:openai:gpt-5:/) + expect(usageEvent).toMatchObject({ + type: 'usage', + usage: { + cached_tokens: 24, + cache_write_tokens: 16 + } + }) + }) + + it('adds explicit cache_control breakpoint for OpenRouter Claude without top-level cache_control', async () => { + const provider = new OpenRouterProvider( + { + id: 'openrouter', + name: 'OpenRouter', + apiType: 'openai-compatible', + apiKey: 'test-key', + baseUrl: 'https://openrouter.ai/api/v1', + enable: false + }, + createConfigPresenter([]) + ) + ;(provider as any).isInitialized = true + + const modelConfig: ModelConfig = { + maxTokens: 1024, + contextLength: 8192, + vision: false, + functionCall: false, + reasoning: false, + type: 'chat', + conversationId: 'session-2' + } + + await collectEvents( + provider, + 'anthropic/claude-sonnet-4', + modelConfig, + [ + { role: 'user', content: 'history' }, + { role: 'assistant', content: 'stable reply' }, + { role: 'user', content: 'latest question' } + ], + [] + ) + + const requestParams = mockChatCompletionsCreate.mock.calls[0]?.[0] + expect(requestParams).not.toHaveProperty('cache_control') + expect(requestParams).not.toHaveProperty('prompt_cache_key') + expect(requestParams.messages[1]).toMatchObject({ + role: 'assistant', + content: [ + { + type: 'text', + text: 'stable reply', + cache_control: { + type: 'ephemeral' + } + } + ] + }) + }) +}) diff --git a/test/main/presenter/llmProviderPresenter/openAIResponsesProvider.test.ts b/test/main/presenter/llmProviderPresenter/openAIResponsesProvider.test.ts index 961404e31..b3a1318a0 100644 --- a/test/main/presenter/llmProviderPresenter/openAIResponsesProvider.test.ts +++ b/test/main/presenter/llmProviderPresenter/openAIResponsesProvider.test.ts @@ -174,7 +174,8 @@ describe('OpenAIResponsesProvider tool call id mapping', () => { output_tokens: 5, total_tokens: 15, 
input_tokens_details: { - cached_tokens: 4 + cached_tokens: 4, + cache_write_tokens: 6 } } } @@ -218,7 +219,8 @@ describe('OpenAIResponsesProvider tool call id mapping', () => { expect(usageEvent).toMatchObject({ type: 'usage', usage: expect.objectContaining({ - cached_tokens: 4 + cached_tokens: 4, + cache_write_tokens: 6 }) }) expect(stopEvent?.stop_reason).toBe('tool_use') @@ -530,4 +532,47 @@ describe('OpenAIResponsesProvider tool call id mapping', () => { stream: true }) }) + + it('injects prompt_cache_key for official OpenAI Responses requests', async () => { + mockResponsesCreate.mockResolvedValue( + createAsyncStream([ + { + type: 'response.completed', + response: { + usage: { + input_tokens: 8, + output_tokens: 2, + total_tokens: 10 + } + } + } + ]) + ) + + const provider = new OpenAIResponsesProvider( + mockProvider, + mockConfigPresenter, + mcpRuntime as any + ) + ;(provider as any).isInitialized = true + + const promptCacheModelConfig: ModelConfig = { + ...modelConfig, + conversationId: 'session-1' + } + + for await (const _event of provider.coreStream( + [{ role: 'user', content: 'cache me' }], + 'gpt-5', + promptCacheModelConfig, + 0.7, + 512, + [] + )) { + // consume stream + } + + const requestParams = mockResponsesCreate.mock.calls[0][0] as Record + expect(requestParams.prompt_cache_key).toMatch(/^deepchat:openai:gpt-5:/) + }) }) diff --git a/test/main/presenter/llmProviderPresenter/promptCacheStrategy.test.ts b/test/main/presenter/llmProviderPresenter/promptCacheStrategy.test.ts new file mode 100644 index 000000000..22d9ee581 --- /dev/null +++ b/test/main/presenter/llmProviderPresenter/promptCacheStrategy.test.ts @@ -0,0 +1,171 @@ +import { describe, expect, it } from 'vitest' +import { + applyAnthropicExplicitCacheBreakpoint, + applyOpenAIChatExplicitCacheBreakpoint, + resolvePromptCachePlan +} from '../../../../src/main/presenter/llmProviderPresenter/promptCacheStrategy' + +describe('promptCacheStrategy', () => { + it('builds prompt_cache_key for official OpenAI models', () => { + const plan = resolvePromptCachePlan({ + providerId: 'openai', + apiType: 'openai_chat', + modelId: 'gpt-5', + messages: [], + conversationId: 'session-1' + }) + + expect(plan).toMatchObject({ + mode: 'openai_implicit', + ttl: null + }) + expect(plan.cacheKey).toMatch(/^deepchat:openai:gpt-5:/) + }) + + it('enables top-level automatic cache control for Anthropic Claude', () => { + const plan = resolvePromptCachePlan({ + providerId: 'anthropic', + apiType: 'anthropic', + modelId: 'claude-sonnet-4-5-20250929', + messages: [] + }) + + expect(plan).toEqual({ + mode: 'anthropic_auto', + ttl: '5m' + }) + }) + + it('creates a single explicit breakpoint plan for Bedrock Claude', () => { + const plan = resolvePromptCachePlan({ + providerId: 'aws-bedrock', + apiType: 'anthropic', + modelId: 'anthropic.claude-3-5-sonnet-20240620-v1:0', + messages: [ + { role: 'user', content: [{ type: 'text', text: 'history' }] }, + { role: 'assistant', content: [{ type: 'text', text: 'stable reply' }] }, + { role: 'user', content: [{ type: 'text', text: 'latest question' }] } + ] + }) + + expect(plan).toEqual({ + mode: 'anthropic_explicit', + ttl: '5m', + breakpointPlan: { + messageIndex: 1, + contentIndex: 0 + } + }) + }) + + it('creates a single explicit breakpoint plan for OpenRouter Claude', () => { + const plan = resolvePromptCachePlan({ + providerId: 'openrouter', + apiType: 'openai_chat', + modelId: 'anthropic/claude-sonnet-4', + messages: [ + { role: 'user', content: 'history' }, + { role: 'assistant', content: 
'stable reply' }, + { role: 'user', content: 'latest question' } + ] + }) + + expect(plan).toEqual({ + mode: 'anthropic_explicit', + ttl: '5m', + breakpointPlan: { + messageIndex: 1, + contentIndex: 0 + } + }) + }) + + it('keeps non-Claude OpenRouter models disabled for explicit request mutation', () => { + const plan = resolvePromptCachePlan({ + providerId: 'openrouter', + apiType: 'openai_chat', + modelId: 'openai/gpt-4o', + messages: [] + }) + + expect(plan).toEqual({ + mode: 'disabled', + ttl: null + }) + }) + + it('returns disabled for unsupported providers', () => { + const plan = resolvePromptCachePlan({ + providerId: 'gemini', + apiType: 'openai_chat', + modelId: 'gemini-2.5-pro', + messages: [] + }) + + expect(plan).toEqual({ + mode: 'disabled', + ttl: null + }) + }) + + it('annotates the selected text block for explicit cache breakpoints', () => { + const anthropicPlan = resolvePromptCachePlan({ + providerId: 'aws-bedrock', + apiType: 'anthropic', + modelId: 'anthropic.claude-3-5-sonnet-20240620-v1:0', + messages: [ + { role: 'user', content: [{ type: 'text', text: 'history' }] }, + { role: 'assistant', content: [{ type: 'text', text: 'stable reply' }] }, + { role: 'user', content: [{ type: 'text', text: 'latest question' }] } + ] + }) + const anthropicMessages = applyAnthropicExplicitCacheBreakpoint( + [ + { role: 'user', content: [{ type: 'text', text: 'history' }] }, + { role: 'assistant', content: [{ type: 'text', text: 'stable reply' }] }, + { role: 'user', content: [{ type: 'text', text: 'latest question' }] } + ] as any, + anthropicPlan + ) + + expect(anthropicMessages[1]?.content?.[0]).toMatchObject({ + type: 'text', + text: 'stable reply', + cache_control: { + type: 'ephemeral' + } + }) + + const openAIPlan = resolvePromptCachePlan({ + providerId: 'openrouter', + apiType: 'openai_chat', + modelId: 'anthropic/claude-sonnet-4', + messages: [ + { role: 'user', content: 'history' }, + { role: 'assistant', content: 'stable reply' }, + { role: 'user', content: 'latest question' } + ] + }) + const openAIMessages = applyOpenAIChatExplicitCacheBreakpoint( + [ + { role: 'user', content: 'history' }, + { role: 'assistant', content: 'stable reply' }, + { role: 'user', content: 'latest question' } + ] as any, + openAIPlan + ) + + expect(openAIMessages[1]).toMatchObject({ + role: 'assistant', + content: [ + { + type: 'text', + text: 'stable reply', + cache_control: { + type: 'ephemeral' + } + } + ] + }) + }) +}) diff --git a/test/main/presenter/newAgentPresenter/usageDashboard.test.ts b/test/main/presenter/newAgentPresenter/usageDashboard.test.ts index 7c1fc1d37..c20e9bd5c 100644 --- a/test/main/presenter/newAgentPresenter/usageDashboard.test.ts +++ b/test/main/presenter/newAgentPresenter/usageDashboard.test.ts @@ -95,6 +95,7 @@ type UsageStatsRow = { output_tokens: number total_tokens: number cached_input_tokens: number + cache_write_input_tokens: number estimated_cost_usd: number | null source: 'backfill' | 'live' created_at: number @@ -277,6 +278,7 @@ function createMockSqlitePresenter() { output_tokens: input.outputTokens, total_tokens: input.totalTokens, cached_input_tokens: input.cachedInputTokens, + cache_write_input_tokens: input.cacheWriteInputTokens, estimated_cost_usd: input.estimatedCostUsd, source: input.source, created_at: input.createdAt, @@ -504,7 +506,8 @@ describe('NewAgentPresenter usage dashboard', () => { inputTokens: 140, outputTokens: 60, totalTokens: 200, - cachedInputTokens: 20 + cachedInputTokens: 20, + cacheWriteInputTokens: 0 }) ) @@ -563,6 +566,7 @@ 
describe('NewAgentPresenter usage dashboard', () => { outputTokens: 80, totalTokens: 200, cachedInputTokens: 0, + cacheWriteInputTokens: 0, estimatedCostUsd: 0.01, source: 'live', createdAt: Date.UTC(2026, 2, 3, 8, 0, 0), @@ -578,6 +582,7 @@ describe('NewAgentPresenter usage dashboard', () => { outputTokens: 40, totalTokens: 100, cachedInputTokens: 0, + cacheWriteInputTokens: 0, estimatedCostUsd: 0.004, source: 'live', createdAt: Date.UTC(2026, 2, 3, 8, 1, 0), @@ -593,6 +598,7 @@ describe('NewAgentPresenter usage dashboard', () => { outputTokens: 20, totalTokens: 50, cachedInputTokens: 0, + cacheWriteInputTokens: 0, estimatedCostUsd: 0.002, source: 'live', createdAt: Date.UTC(2026, 2, 4, 8, 0, 0), @@ -622,6 +628,7 @@ describe('NewAgentPresenter usage dashboard', () => { outputTokens: 10, totalTokens: 20, cachedInputTokens: 0, + cacheWriteInputTokens: 0, estimatedCostUsd: null, source: 'live', createdAt: Date.UTC(2026, 2, 5, 8, 0, 0), @@ -637,6 +644,7 @@ describe('NewAgentPresenter usage dashboard', () => { outputTokens: 10, totalTokens: 20, cachedInputTokens: 0, + cacheWriteInputTokens: 0, estimatedCostUsd: null, source: 'live', createdAt: Date.UTC(2026, 2, 5, 8, 1, 0), @@ -652,6 +660,7 @@ describe('NewAgentPresenter usage dashboard', () => { outputTokens: 10, totalTokens: 20, cachedInputTokens: 0, + cacheWriteInputTokens: 0, estimatedCostUsd: null, source: 'live', createdAt: Date.UTC(2026, 2, 6, 8, 0, 0), @@ -667,6 +676,7 @@ describe('NewAgentPresenter usage dashboard', () => { outputTokens: 10, totalTokens: 20, cachedInputTokens: 0, + cacheWriteInputTokens: 0, estimatedCostUsd: null, source: 'live', createdAt: Date.UTC(2026, 2, 6, 8, 1, 0), diff --git a/test/main/presenter/sqlitePresenter.test.ts b/test/main/presenter/sqlitePresenter.test.ts index 392e98f7a..e42c7998b 100644 --- a/test/main/presenter/sqlitePresenter.test.ts +++ b/test/main/presenter/sqlitePresenter.test.ts @@ -637,4 +637,104 @@ describeIfSqlite('SQLitePresenter legacy schema bootstrap', () => { presenter.close() }) + + it('migrates deepchat_usage_stats to include cache_write_input_tokens without losing rows', async () => { + const tempDir = fs.mkdtempSync(path.join(os.tmpdir(), 'deepchat-sqlite-presenter-')) + tempDirs.push(tempDir) + + const dbPath = path.join(tempDir, 'agent.db') + const bootstrapDb = new DatabaseCtor(dbPath) + bootstrapDb.exec(` + CREATE TABLE IF NOT EXISTS schema_versions ( + version INTEGER PRIMARY KEY, + applied_at INTEGER NOT NULL + ); + INSERT INTO schema_versions (version, applied_at) VALUES (21, ${Date.now()}); + CREATE TABLE IF NOT EXISTS deepchat_usage_stats ( + message_id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + usage_date TEXT NOT NULL, + provider_id TEXT NOT NULL, + model_id TEXT NOT NULL, + input_tokens INTEGER NOT NULL DEFAULT 0, + output_tokens INTEGER NOT NULL DEFAULT 0, + total_tokens INTEGER NOT NULL DEFAULT 0, + cached_input_tokens INTEGER NOT NULL DEFAULT 0, + estimated_cost_usd REAL, + source TEXT NOT NULL DEFAULT 'live', + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL + ); + INSERT INTO deepchat_usage_stats ( + message_id, + session_id, + usage_date, + provider_id, + model_id, + input_tokens, + output_tokens, + total_tokens, + cached_input_tokens, + estimated_cost_usd, + source, + created_at, + updated_at + ) VALUES ( + 'message-1', + 'session-1', + '2026-03-10', + 'openai', + 'gpt-4o', + 120, + 30, + 150, + 20, + 0.01, + 'live', + 1000, + 2000 + ); + `) + bootstrapDb.close() + + const presenter = new SQLitePresenterCtor(dbPath) + presenter.close() + + const 
checkDb = new DatabaseCtor(dbPath) + const usageColumns = checkDb.prepare('PRAGMA table_info(deepchat_usage_stats)').all() as Array<{ + name: string + }> + const columnNames = new Set(usageColumns.map((column) => column.name)) + const row = checkDb + .prepare( + `SELECT + message_id, + cached_input_tokens, + cache_write_input_tokens, + estimated_cost_usd + FROM deepchat_usage_stats + WHERE message_id = ?` + ) + .get('message-1') as + | { + message_id: string + cached_input_tokens: number + cache_write_input_tokens: number + estimated_cost_usd: number | null + } + | undefined + const versions = checkDb + .prepare('SELECT version FROM schema_versions ORDER BY version ASC') + .all() as Array<{ version: number }> + + expect(columnNames.has('cache_write_input_tokens')).toBe(true) + expect(row).toEqual({ + message_id: 'message-1', + cached_input_tokens: 20, + cache_write_input_tokens: 0, + estimated_cost_usd: 0.01 + }) + expect(versions.map((entry) => entry.version)).toContain(22) + checkDb.close() + }) }) diff --git a/test/main/presenter/usageStats.test.ts b/test/main/presenter/usageStats.test.ts new file mode 100644 index 000000000..f22960a42 --- /dev/null +++ b/test/main/presenter/usageStats.test.ts @@ -0,0 +1,107 @@ +import { describe, expect, it, vi } from 'vitest' + +vi.mock('../../../src/main/presenter/configPresenter/providerDbLoader', () => ({ + providerDbLoader: { + getModel: vi.fn((providerId: string, modelId: string) => { + if (providerId === 'anthropic' && modelId === 'claude-sonnet') { + return { + cost: { + input: 3, + output: 15, + cache_read: 0.3, + cache_write: 3.75 + } + } + } + if (providerId === 'bedrock' && modelId === 'claude-bedrock') { + return { + cost: { + input: 4, + output: 20, + cache_read: 0.5 + } + } + } + return undefined + }) + } +})) + +import { + buildUsageStatsRecord, + estimateUsageCostUsd, + normalizeUsageCounts +} from '../../../src/main/presenter/usageStats' + +describe('usageStats cache pricing', () => { + it('charges uncached input, cache read, cache write, and output separately', () => { + const cost = estimateUsageCostUsd({ + providerId: 'anthropic', + modelId: 'claude-sonnet', + inputTokens: 1_000, + outputTokens: 200, + cachedInputTokens: 400, + cacheWriteInputTokens: 100 + }) + + expect(cost).toBeCloseTo((500 * 3 + 400 * 0.3 + 100 * 3.75 + 200 * 15) / 1_000_000) + }) + + it('falls back to the input price when cache_write pricing is unavailable', () => { + const cost = estimateUsageCostUsd({ + providerId: 'bedrock', + modelId: 'claude-bedrock', + inputTokens: 900, + outputTokens: 100, + cachedInputTokens: 300, + cacheWriteInputTokens: 200 + }) + + expect(cost).toBeCloseTo((400 * 4 + 300 * 0.5 + 200 * 4 + 100 * 20) / 1_000_000) + }) + + it('caps cached and cache-write counts against total input tokens', () => { + expect( + normalizeUsageCounts({ + inputTokens: 100, + outputTokens: 20, + totalTokens: 120, + cachedInputTokens: 90, + cacheWriteInputTokens: 50 + }) + ).toEqual({ + inputTokens: 100, + outputTokens: 20, + totalTokens: 120, + cachedInputTokens: 90, + cacheWriteInputTokens: 10 + }) + }) + + it('stores cache_write_input_tokens in usage records', () => { + const record = buildUsageStatsRecord({ + messageId: 'message-1', + sessionId: 'session-1', + createdAt: Date.UTC(2026, 2, 10, 8, 0, 0), + updatedAt: Date.UTC(2026, 2, 10, 8, 0, 1), + providerId: 'anthropic', + modelId: 'claude-sonnet', + metadata: { + inputTokens: 1_000, + outputTokens: 200, + totalTokens: 1_200, + cachedInputTokens: 400, + cacheWriteInputTokens: 100 + }, + source: 'live' + 
}) + + expect(record).toMatchObject({ + cachedInputTokens: 400, + cacheWriteInputTokens: 100 + }) + expect(record?.estimatedCostUsd).toBeCloseTo( + (500 * 3 + 400 * 0.3 + 100 * 3.75 + 200 * 15) / 1_000_000 + ) + }) +}) From 6b1e1a22c715dab52c6618cae928459fedc034c5 Mon Sep 17 00:00:00 2001 From: zerob13 Date: Thu, 2 Apr 2026 21:59:37 +0800 Subject: [PATCH 2/3] feat(provider): refine zenmux cache and UI --- .../promptCacheCapabilities.ts | 8 + .../providers/anthropicProvider.ts | 16 +- .../providers/zenmuxProvider.ts | 217 +++++++++++- src/main/presenter/newAgentPresenter/index.ts | 4 +- .../settings/components/ProviderApiConfig.vue | 43 ++- src/renderer/src/i18n/da-DK/settings.json | 8 + src/renderer/src/i18n/en-US/settings.json | 8 + src/renderer/src/i18n/fa-IR/settings.json | 8 + src/renderer/src/i18n/fr-FR/settings.json | 8 + src/renderer/src/i18n/he-IL/settings.json | 8 + src/renderer/src/i18n/ja-JP/settings.json | 8 + src/renderer/src/i18n/ko-KR/settings.json | 8 + src/renderer/src/i18n/pt-BR/settings.json | 8 + src/renderer/src/i18n/ru-RU/settings.json | 8 + src/renderer/src/i18n/zh-CN/settings.json | 8 + src/renderer/src/i18n/zh-HK/settings.json | 8 + src/renderer/src/i18n/zh-TW/settings.json | 8 + src/types/i18n.d.ts | 90 ++++- .../promptCacheStrategy.test.ts | 22 ++ .../zenmuxProvider.test.ts | 334 ++++++++++++++++++ .../newAgentPresenter/usageDashboard.test.ts | 7 +- .../components/ProviderApiConfig.test.ts | 236 +++++++++++++ 22 files changed, 1057 insertions(+), 16 deletions(-) create mode 100644 test/main/presenter/llmProviderPresenter/zenmuxProvider.test.ts create mode 100644 test/renderer/components/ProviderApiConfig.test.ts diff --git a/src/main/presenter/llmProviderPresenter/promptCacheCapabilities.ts b/src/main/presenter/llmProviderPresenter/promptCacheCapabilities.ts index b06f597cb..242de3dfa 100644 --- a/src/main/presenter/llmProviderPresenter/promptCacheCapabilities.ts +++ b/src/main/presenter/llmProviderPresenter/promptCacheCapabilities.ts @@ -24,6 +24,14 @@ export function resolvePromptCacheMode(providerId: string, modelId: string): Pro return 'anthropic_auto' } + if ( + normalizedProviderId === 'zenmux' && + normalizedModelId.startsWith('anthropic/') && + isClaudeModel(normalizedModelId) + ) { + return 'anthropic_explicit' + } + if ( normalizedProviderId === 'aws-bedrock' && (normalizedModelId.includes('anthropic.claude') || isClaudeModel(normalizedModelId)) diff --git a/src/main/presenter/llmProviderPresenter/providers/anthropicProvider.ts b/src/main/presenter/llmProviderPresenter/providers/anthropicProvider.ts index cce04e56f..5b29a7708 100644 --- a/src/main/presenter/llmProviderPresenter/providers/anthropicProvider.ts +++ b/src/main/presenter/llmProviderPresenter/providers/anthropicProvider.ts @@ -15,7 +15,11 @@ import { proxyConfig } from '../../proxyConfig' import { ProxyAgent } from 'undici' import type { Usage } from '@anthropic-ai/sdk/resources' import type { ProviderMcpRuntimePort } from '../runtimePorts' -import { applyAnthropicTopLevelCacheControl, resolvePromptCachePlan } from '../promptCacheStrategy' +import { + applyAnthropicExplicitCacheBreakpoint, + applyAnthropicTopLevelCacheControl, + resolvePromptCachePlan +} from '../promptCacheStrategy' type CacheAwareAnthropicUsage = Usage & { cache_read_input_tokens?: number @@ -112,7 +116,15 @@ export class AnthropicProvider extends BaseLLMProvider { messages: messages as unknown[], conversationId }) - return applyAnthropicTopLevelCacheControl(requestParams, plan) + const nextRequestParams = + plan.mode === 
'anthropic_explicit'
+        ? {
+            ...requestParams,
+            messages: applyAnthropicExplicitCacheBreakpoint(messages, plan)
+          }
+        : requestParams
+
+    return applyAnthropicTopLevelCacheControl(nextRequestParams, plan)
   }
 
   public onProxyResolved(): void {
diff --git a/src/main/presenter/llmProviderPresenter/providers/zenmuxProvider.ts b/src/main/presenter/llmProviderPresenter/providers/zenmuxProvider.ts
index be1bf5c42..7affab906 100644
--- a/src/main/presenter/llmProviderPresenter/providers/zenmuxProvider.ts
+++ b/src/main/presenter/llmProviderPresenter/providers/zenmuxProvider.ts
@@ -1,21 +1,230 @@
-import { IConfigPresenter, LLM_PROVIDER, MODEL_META } from '@shared/presenter'
+import Anthropic from '@anthropic-ai/sdk'
+import {
+  ChatMessage,
+  IConfigPresenter,
+  KeyStatus,
+  LLM_EMBEDDING_ATTRS,
+  LLM_PROVIDER,
+  LLMResponse,
+  LLMCoreStreamEvent,
+  MCPToolDefinition,
+  MODEL_META,
+  ModelConfig
+} from '@shared/presenter'
+import { ProxyAgent } from 'undici'
+import { BaseLLMProvider } from '../baseProvider'
+import { proxyConfig } from '../../proxyConfig'
+import { AnthropicProvider } from './anthropicProvider'
 import { OpenAICompatibleProvider } from './openAICompatibleProvider'
 import type { ProviderMcpRuntimePort } from '../runtimePorts'
 
-export class ZenmuxProvider extends OpenAICompatibleProvider {
+const ZENMUX_ANTHROPIC_BASE_URL = 'https://zenmux.ai/api/anthropic'
+
+class ZenmuxOpenAIDelegate extends OpenAICompatibleProvider {
+  protected override async init() {
+    this.isInitialized = true
+  }
+
+  public async fetchZenmuxModels(options?: { timeout: number }): Promise<MODEL_META[]> {
+    return super.fetchOpenAIModels(options)
+  }
+}
+
+class ZenmuxAnthropicDelegate extends AnthropicProvider {
+  private clientInitialized = false
+
+  protected override async init() {}
+
+  public async ensureClientInitialized(): Promise<void> {
+    const apiKey = this.provider.apiKey || process.env.ANTHROPIC_API_KEY || null
+    if (!apiKey) {
+      this.clientInitialized = false
+      this.isInitialized = false
+      return
+    }
+
+    const proxyUrl = proxyConfig.getProxyUrl()
+    const fetchOptions: { dispatcher?: ProxyAgent } = {}
+
+    if (proxyUrl) {
+      const proxyAgent = new ProxyAgent(proxyUrl)
+      fetchOptions.dispatcher = proxyAgent
+    }
+
+    const self = this as unknown as { anthropic?: Anthropic }
+    self.anthropic = new Anthropic({
+      apiKey,
+      baseURL: this.provider.baseUrl || ZENMUX_ANTHROPIC_BASE_URL,
+      defaultHeaders: this.defaultHeaders,
+      fetchOptions
+    })
+
+    this.clientInitialized = true
+    this.isInitialized = true
+  }
+
+  public isClientInitialized(): boolean {
+    return this.clientInitialized
+  }
+
+  public override onProxyResolved(): void {
+    void this.ensureClientInitialized()
+  }
+}
+
+export class ZenmuxProvider extends BaseLLMProvider {
+  private readonly openaiDelegate: ZenmuxOpenAIDelegate
+  private readonly anthropicDelegate: ZenmuxAnthropicDelegate
+
   constructor(
     provider: LLM_PROVIDER,
     configPresenter: IConfigPresenter,
     mcpRuntime?: ProviderMcpRuntimePort
   ) {
     super(provider, configPresenter, mcpRuntime)
+
+    this.openaiDelegate = new ZenmuxOpenAIDelegate(provider, configPresenter, mcpRuntime)
+    this.anthropicDelegate = new ZenmuxAnthropicDelegate(
+      {
+        ...provider,
+        apiType: 'anthropic',
+        baseUrl: ZENMUX_ANTHROPIC_BASE_URL
+      },
+      configPresenter,
+      mcpRuntime
+    )
+
+    this.init()
+  }
+
+  private isAnthropicModel(modelId: string): boolean {
+    return modelId.trim().toLowerCase().startsWith('anthropic/')
+  }
+
+  private async ensureAnthropicDelegateReady(): Promise<ZenmuxAnthropicDelegate> {
+    await this.anthropicDelegate.ensureClientInitialized()
+
+    if 
(!this.anthropicDelegate.isClientInitialized()) {
+      throw new Error('Anthropic SDK not initialized')
+    }
+
+    return this.anthropicDelegate
   }
 
-  protected async fetchOpenAIModels(options?: { timeout: number }): Promise<MODEL_META[]> {
-    const models = await super.fetchOpenAIModels(options)
+  protected async fetchProviderModels(): Promise<MODEL_META[]> {
+    const models = await this.openaiDelegate.fetchZenmuxModels()
     return models.map((model) => ({
       ...model,
       group: 'ZenMux'
     }))
   }
+
+  public onProxyResolved(): void {
+    this.openaiDelegate.onProxyResolved()
+
+    if (this.anthropicDelegate.isClientInitialized()) {
+      this.anthropicDelegate.onProxyResolved()
+    }
+  }
+
+  public async check(): Promise<{ isOk: boolean; errorMsg: string | null }> {
+    return this.openaiDelegate.check()
+  }
+
+  public async summaryTitles(messages: ChatMessage[], modelId: string): Promise<string> {
+    if (this.isAnthropicModel(modelId)) {
+      const delegate = await this.ensureAnthropicDelegateReady()
+      return delegate.summaryTitles(messages, modelId)
+    }
+
+    return this.openaiDelegate.summaryTitles(messages, modelId)
+  }
+
+  public async completions(
+    messages: ChatMessage[],
+    modelId: string,
+    temperature?: number,
+    maxTokens?: number
+  ): Promise<LLMResponse> {
+    if (this.isAnthropicModel(modelId)) {
+      const delegate = await this.ensureAnthropicDelegateReady()
+      return delegate.completions(messages, modelId, temperature, maxTokens)
+    }
+
+    return this.openaiDelegate.completions(messages, modelId, temperature, maxTokens)
+  }
+
+  public async summaries(
+    text: string,
+    modelId: string,
+    temperature?: number,
+    maxTokens?: number
+  ): Promise<LLMResponse> {
+    if (this.isAnthropicModel(modelId)) {
+      const delegate = await this.ensureAnthropicDelegateReady()
+      return delegate.summaries(text, modelId, temperature, maxTokens)
+    }
+
+    return this.openaiDelegate.summaries(text, modelId, temperature, maxTokens)
+  }
+
+  public async generateText(
+    prompt: string,
+    modelId: string,
+    temperature?: number,
+    maxTokens?: number
+  ): Promise<LLMResponse> {
+    if (this.isAnthropicModel(modelId)) {
+      const delegate = await this.ensureAnthropicDelegateReady()
+      return delegate.generateText(prompt, modelId, temperature, maxTokens)
+    }
+
+    return this.openaiDelegate.generateText(prompt, modelId, temperature, maxTokens)
+  }
+
+  public async *coreStream(
+    messages: ChatMessage[],
+    modelId: string,
+    modelConfig: ModelConfig,
+    temperature: number,
+    maxTokens: number,
+    tools: MCPToolDefinition[]
+  ): AsyncGenerator<LLMCoreStreamEvent> {
+    if (this.isAnthropicModel(modelId)) {
+      const delegate = await this.ensureAnthropicDelegateReady()
+      yield* delegate.coreStream(messages, modelId, modelConfig, temperature, maxTokens, tools)
+      return
+    }
+
+    yield* this.openaiDelegate.coreStream(
+      messages,
+      modelId,
+      modelConfig,
+      temperature,
+      maxTokens,
+      tools
+    )
+  }
+
+  public async getEmbeddings(modelId: string, texts: string[]): Promise<number[][]> {
+    if (this.isAnthropicModel(modelId)) {
+      const delegate = await this.ensureAnthropicDelegateReady()
+      return delegate.getEmbeddings(modelId, texts)
+    }
+
+    return this.openaiDelegate.getEmbeddings(modelId, texts)
+  }
+
+  public async getDimensions(modelId: string): Promise<LLM_EMBEDDING_ATTRS> {
+    if (this.isAnthropicModel(modelId)) {
+      const delegate = await this.ensureAnthropicDelegateReady()
+      return delegate.getDimensions(modelId)
+    }
+
+    return this.openaiDelegate.getDimensions(modelId)
+  }
+
+  public async getKeyStatus(): Promise<KeyStatus> {
+    return this.openaiDelegate.getKeyStatus()
+  }
 }
diff --git a/src/main/presenter/newAgentPresenter/index.ts b/src/main/presenter/newAgentPresenter/index.ts
index 
2ac0f7f11..c3860d7ee 100644 --- a/src/main/presenter/newAgentPresenter/index.ts +++ b/src/main/presenter/newAgentPresenter/index.ts @@ -2213,8 +2213,8 @@ export class NewAgentPresenter { modelId, metadata: { ...metadata, - cachedInputTokens: 0, - cacheWriteInputTokens: 0 + cachedInputTokens: metadata.cachedInputTokens ?? 0, + cacheWriteInputTokens: metadata.cacheWriteInputTokens ?? 0 }, source: 'backfill' }) diff --git a/src/renderer/settings/components/ProviderApiConfig.vue b/src/renderer/settings/components/ProviderApiConfig.vue index f62517f6f..cf912291d 100644 --- a/src/renderer/settings/components/ProviderApiConfig.vue +++ b/src/renderer/settings/components/ProviderApiConfig.vue @@ -26,7 +26,21 @@ {{ t('settings.provider.delete') }} +
+
+ + {{ apiHost || t('settings.provider.urlPlaceholder') }} + +
+ +
- + ' +}) + +const labelStub = defineComponent({ + name: 'Label', + inheritAttrs: false, + template: '' +}) + +const createProvider = (overrides?: Partial): LLM_PROVIDER => ({ + id: 'deepseek', + name: 'DeepSeek', + apiType: 'openai-compatible', + apiKey: 'test-key', + baseUrl: 'https://api.deepseek.com/v1', + enable: true, + custom: false, + ...overrides +}) + +async function setup(options?: { + provider?: LLM_PROVIDER + providerWebsites?: { + official: string + apiKey: string + docs: string + models: string + defaultBaseUrl: string + } +}) { + vi.resetModules() + + const llmproviderPresenter = { + getKeyStatus: vi.fn().mockResolvedValue(null), + refreshModels: vi.fn().mockResolvedValue(undefined) + } + const modelCheckStore = { + openDialog: vi.fn() + } + + vi.doMock('vue-i18n', () => ({ + useI18n: () => ({ + t: (key: string, params?: Record) => { + if (key === 'settings.provider.modifyBaseUrl') return 'Modify' + if (key === 'settings.provider.baseUrlLockedHint') { + return 'This provider is pinned to the recommended Base URL.' + } + if (key === 'settings.provider.urlPlaceholder') return 'Enter API URL' + if (key === 'settings.provider.urlFormat') { + return `Default: ${params?.defaultUrl ?? ''}` + } + if (key === 'settings.provider.urlFormatFill') return 'Fill into API URL' + if (key === 'settings.provider.dialog.baseUrlUnlock.confirm') return 'Continue' + return key + } + }) + })) + + vi.doMock('@/composables/usePresenter', () => ({ + usePresenter: (name: string) => { + if (name === 'llmproviderPresenter') return llmproviderPresenter + throw new Error(`Unexpected presenter: ${name}`) + } + })) + + vi.doMock('@/stores/modelCheck', () => ({ + useModelCheckStore: () => modelCheckStore + })) + + vi.doMock('@shadcn/components/ui/input', () => ({ + Input: createInputStub() + })) + vi.doMock('@shadcn/components/ui/button', () => ({ + Button: buttonStub + })) + vi.doMock('@shadcn/components/ui/label', () => ({ + Label: labelStub + })) + vi.doMock('@shadcn/components/ui/tooltip', () => ({ + Tooltip: passthrough('Tooltip'), + TooltipContent: passthrough('TooltipContent'), + TooltipProvider: passthrough('TooltipProvider'), + TooltipTrigger: passthrough('TooltipTrigger') + })) + vi.doMock('@iconify/vue', () => ({ + Icon: defineComponent({ + name: 'Icon', + template: '' + }) + })) + + const ProviderApiConfig = ( + await import('../../../src/renderer/settings/components/ProviderApiConfig.vue') + ).default + + const wrapper = mount(ProviderApiConfig, { + props: { + provider: options?.provider ?? createProvider(), + providerWebsites: options?.providerWebsites ?? 
{ + official: 'https://example.com', + apiKey: 'https://example.com/key', + docs: 'https://example.com/docs', + models: 'https://example.com/models', + defaultBaseUrl: 'https://api.deepseek.com/v1' + } + }, + global: { + stubs: { + GitHubCopilotOAuth: true + } + } + }) + + await flushPromises() + + return { + wrapper, + llmproviderPresenter, + modelCheckStore + } +} + +function findButtonByText(wrapper: ReturnType, text: string) { + return wrapper.findAll('button').find((button) => button.text().trim() === text) +} + +describe('ProviderApiConfig', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('shows a locked Base URL display for built-in providers outside the allowlist', async () => { + const { wrapper, llmproviderPresenter } = await setup() + + expect(wrapper.find('input#deepseek-url').exists()).toBe(false) + expect(wrapper.text()).toContain('This provider is pinned to the recommended Base URL.') + expect(findButtonByText(wrapper, 'Modify')).toBeDefined() + expect(wrapper.html()).not.toContain('Fill into API URL') + expect(llmproviderPresenter.getKeyStatus).toHaveBeenCalledWith('deepseek') + }) + + it('switches directly into edit mode and hides the modify button', async () => { + const { wrapper } = await setup() + const modifyButton = findButtonByText(wrapper, 'Modify') + + expect(modifyButton).toBeDefined() + await modifyButton!.trigger('click') + await flushPromises() + + expect(wrapper.find('input#deepseek-url').exists()).toBe(true) + expect(findButtonByText(wrapper, 'Modify')).toBeUndefined() + }) + + it('preserves the existing save behavior after unlocking', async () => { + const { wrapper } = await setup() + const modifyButton = findButtonByText(wrapper, 'Modify') + + expect(modifyButton).toBeDefined() + await modifyButton!.trigger('click') + await flushPromises() + + const input = wrapper.get('input#deepseek-url') + await input.setValue('https://custom.deepseek.com/v1') + await input.trigger('blur') + + expect(wrapper.emitted('api-host-change')).toEqual([['https://custom.deepseek.com/v1']]) + }) + + it('keeps OpenAI Responses editable without the lock prompt', async () => { + const { wrapper } = await setup({ + provider: createProvider({ + id: 'openai-responses', + name: 'OpenAI Responses', + baseUrl: 'https://api.openai.com/v1' + }) + }) + + expect(wrapper.find('input#openai-responses-url').exists()).toBe(true) + expect(findButtonByText(wrapper, 'Modify')).toBeUndefined() + expect(wrapper.text()).not.toContain('This provider is pinned to the recommended Base URL.') + }) + + it('keeps custom providers editable by default', async () => { + const { wrapper } = await setup({ + provider: createProvider({ + id: 'custom-demo', + name: 'Custom Demo', + custom: true, + baseUrl: 'https://custom.example.com/v1' + }) + }) + + expect(wrapper.find('input#custom-demo-url').exists()).toBe(true) + expect(findButtonByText(wrapper, 'Modify')).toBeUndefined() + }) +}) From 6071c3937e7836d1d9885491c00981d0acdd4893 Mon Sep 17 00:00:00 2001 From: zerob13 Date: Thu, 2 Apr 2026 22:18:17 +0800 Subject: [PATCH 3/3] fix(zenmux): guard anthropic embeddings --- .../providers/zenmuxProvider.ts | 6 ++---- .../llmProviderPresenter/zenmuxProvider.test.ts | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/src/main/presenter/llmProviderPresenter/providers/zenmuxProvider.ts b/src/main/presenter/llmProviderPresenter/providers/zenmuxProvider.ts index 7affab906..37b4d7efc 100644 --- a/src/main/presenter/llmProviderPresenter/providers/zenmuxProvider.ts +++ 
b/src/main/presenter/llmProviderPresenter/providers/zenmuxProvider.ts
@@ -208,8 +208,7 @@ export class ZenmuxProvider extends BaseLLMProvider {
   public async getEmbeddings(modelId: string, texts: string[]): Promise<number[][]> {
     if (this.isAnthropicModel(modelId)) {
-      const delegate = await this.ensureAnthropicDelegateReady()
-      return delegate.getEmbeddings(modelId, texts)
+      throw new Error(`Embeddings not supported for Anthropic models: ${modelId}`)
     }
 
     return this.openaiDelegate.getEmbeddings(modelId, texts)
@@ -217,8 +216,7 @@ export class ZenmuxProvider extends BaseLLMProvider {
 
   public async getDimensions(modelId: string): Promise<LLM_EMBEDDING_ATTRS> {
     if (this.isAnthropicModel(modelId)) {
-      const delegate = await this.ensureAnthropicDelegateReady()
-      return delegate.getDimensions(modelId)
+      throw new Error(`Embeddings not supported for Anthropic models: ${modelId}`)
     }
 
     return this.openaiDelegate.getDimensions(modelId)
diff --git a/test/main/presenter/llmProviderPresenter/zenmuxProvider.test.ts b/test/main/presenter/llmProviderPresenter/zenmuxProvider.test.ts
index 869ebf1a4..b8977106b 100644
--- a/test/main/presenter/llmProviderPresenter/zenmuxProvider.test.ts
+++ b/test/main/presenter/llmProviderPresenter/zenmuxProvider.test.ts
@@ -331,4 +331,20 @@ describe('ZenmuxProvider', () => {
     expect(openaiProxySpy).toHaveBeenCalledTimes(1)
     expect(anthropicProxySpy).toHaveBeenCalledTimes(1)
   })
+
+  it('fails fast for embeddings on anthropic/* models', async () => {
+    const provider = new ZenmuxProvider(createProvider(), createConfigPresenter())
+
+    await expect(provider.getEmbeddings('anthropic/claude-sonnet-4-5', ['hello'])).rejects.toThrow(
+      'Embeddings not supported for Anthropic models: anthropic/claude-sonnet-4-5'
+    )
+  })
+
+  it('fails fast for embedding dimensions on anthropic/* models', async () => {
+    const provider = new ZenmuxProvider(createProvider(), createConfigPresenter())
+
+    await expect(provider.getDimensions('anthropic/claude-sonnet-4-5')).rejects.toThrow(
+      'Embeddings not supported for Anthropic models: anthropic/claude-sonnet-4-5'
+    )
+  })
 })