diff --git a/src/main/presenter/deepchatAgentPresenter/accumulator.ts b/src/main/presenter/deepchatAgentPresenter/accumulator.ts index c5b2cddda..b91905170 100644 --- a/src/main/presenter/deepchatAgentPresenter/accumulator.ts +++ b/src/main/presenter/deepchatAgentPresenter/accumulator.ts @@ -168,6 +168,7 @@ export function accumulate(state: StreamState, event: LLMCoreStreamEvent): void state.metadata.outputTokens = event.usage.completion_tokens state.metadata.totalTokens = event.usage.total_tokens state.metadata.cachedInputTokens = event.usage.cached_tokens + state.metadata.cacheWriteInputTokens = event.usage.cache_write_tokens break } case 'stop': { diff --git a/src/main/presenter/deepchatAgentPresenter/index.ts b/src/main/presenter/deepchatAgentPresenter/index.ts index 54c69daac..e7668a163 100644 --- a/src/main/presenter/deepchatAgentPresenter/index.ts +++ b/src/main/presenter/deepchatAgentPresenter/index.ts @@ -1405,7 +1405,8 @@ export class DeepChatAgentPresenter implements IAgentImplementation { maxTokens: generationSettings.maxTokens, thinkingBudget: generationSettings.thinkingBudget, reasoningEffort: generationSettings.reasoningEffort, - verbosity: generationSettings.verbosity + verbosity: generationSettings.verbosity, + conversationId: sessionId } const traceEnabled = this.configPresenter.getSetting('traceDebugEnabled') === true diff --git a/src/main/presenter/deepchatAgentPresenter/process.ts b/src/main/presenter/deepchatAgentPresenter/process.ts index 6b341fabd..c075da570 100644 --- a/src/main/presenter/deepchatAgentPresenter/process.ts +++ b/src/main/presenter/deepchatAgentPresenter/process.ts @@ -308,5 +308,8 @@ function buildUsageSnapshot(state: StreamState): Record { if (typeof state.metadata.cachedInputTokens === 'number') { usage.cachedInputTokens = state.metadata.cachedInputTokens } + if (typeof state.metadata.cacheWriteInputTokens === 'number') { + usage.cacheWriteInputTokens = state.metadata.cacheWriteInputTokens + } return usage } diff --git a/src/main/presenter/llmProviderPresenter/promptCacheCapabilities.ts b/src/main/presenter/llmProviderPresenter/promptCacheCapabilities.ts new file mode 100644 index 000000000..242de3dfa --- /dev/null +++ b/src/main/presenter/llmProviderPresenter/promptCacheCapabilities.ts @@ -0,0 +1,50 @@ +export type PromptCacheMode = + | 'disabled' + | 'openai_implicit' + | 'anthropic_auto' + | 'anthropic_explicit' + +function normalizeId(value: string | undefined): string { + return value?.trim().toLowerCase() ?? 
'' +} + +function isClaudeModel(modelId: string): boolean { + return modelId.includes('claude') +} + +export function resolvePromptCacheMode(providerId: string, modelId: string): PromptCacheMode { + const normalizedProviderId = normalizeId(providerId) + const normalizedModelId = normalizeId(modelId) + + if (normalizedProviderId === 'openai') { + return 'openai_implicit' + } + + if (normalizedProviderId === 'anthropic' && isClaudeModel(normalizedModelId)) { + return 'anthropic_auto' + } + + if ( + normalizedProviderId === 'zenmux' && + normalizedModelId.startsWith('anthropic/') && + isClaudeModel(normalizedModelId) + ) { + return 'anthropic_explicit' + } + + if ( + normalizedProviderId === 'aws-bedrock' && + (normalizedModelId.includes('anthropic.claude') || isClaudeModel(normalizedModelId)) + ) { + return 'anthropic_explicit' + } + + if ( + normalizedProviderId === 'openrouter' && + (normalizedModelId.startsWith('anthropic/') || isClaudeModel(normalizedModelId)) + ) { + return 'anthropic_explicit' + } + + return 'disabled' +} diff --git a/src/main/presenter/llmProviderPresenter/promptCacheStrategy.ts b/src/main/presenter/llmProviderPresenter/promptCacheStrategy.ts new file mode 100644 index 000000000..8de06538c --- /dev/null +++ b/src/main/presenter/llmProviderPresenter/promptCacheStrategy.ts @@ -0,0 +1,328 @@ +import { createHash } from 'crypto' +import Anthropic from '@anthropic-ai/sdk' +import type { ChatCompletionContentPart, ChatCompletionMessageParam } from 'openai/resources' +import type { MCPToolDefinition } from '@shared/presenter' +import { resolvePromptCacheMode, type PromptCacheMode } from './promptCacheCapabilities' + +export type PromptCacheApiType = 'openai_chat' | 'openai_responses' | 'anthropic' +export type PromptCacheTtl = '5m' + +export interface PromptCacheBreakpointPlan { + messageIndex: number + contentIndex: number +} + +export interface PromptCachePlan { + mode: PromptCacheMode + ttl: PromptCacheTtl | null + cacheKey?: string + breakpointPlan?: PromptCacheBreakpointPlan +} + +export interface ResolvePromptCachePlanParams { + providerId: string + apiType: PromptCacheApiType + modelId: string + messages: unknown[] + tools?: MCPToolDefinition[] + conversationId?: string +} + +type EphemeralCacheControl = { type: 'ephemeral' } + +const EPHEMERAL_CACHE_CONTROL: EphemeralCacheControl = { type: 'ephemeral' } + +type AnthropicTextBlockWithCache = Anthropic.TextBlockParam & { + cache_control?: EphemeralCacheControl +} + +function normalizeId(value: string | undefined): string { + return value?.trim().toLowerCase() ?? 
'' +} + +function buildPromptCacheKey( + providerId: string, + modelId: string, + conversationId?: string +): string | undefined { + const normalizedConversationId = conversationId?.trim() + if (!normalizedConversationId) { + return undefined + } + + const digest = createHash('sha256') + .update(`${normalizeId(providerId)}:${normalizeId(modelId)}:${normalizedConversationId}`) + .digest('hex') + .slice(0, 20) + + return `deepchat:${normalizeId(providerId)}:${normalizeId(modelId)}:${digest}` +} + +function findOpenAIChatBreakpoint( + messages: ChatCompletionMessageParam[] +): PromptCacheBreakpointPlan | undefined { + let prefixEnd = messages.length + + while (prefixEnd > 0) { + const role = messages[prefixEnd - 1]?.role + if (role === 'user' || role === 'tool') { + prefixEnd -= 1 + continue + } + break + } + + for (let messageIndex = prefixEnd - 1; messageIndex >= 0; messageIndex -= 1) { + const message = messages[messageIndex] + if (!message || message.role === 'tool') { + continue + } + + const content = 'content' in message ? message.content : undefined + if (typeof content === 'string') { + if (content.trim()) { + return { messageIndex, contentIndex: 0 } + } + continue + } + + if (!Array.isArray(content)) { + continue + } + + for (let contentIndex = content.length - 1; contentIndex >= 0; contentIndex -= 1) { + const part = content[contentIndex] + if (part?.type === 'text' && typeof part.text === 'string' && part.text.trim()) { + return { messageIndex, contentIndex } + } + } + } + + return undefined +} + +function findAnthropicBreakpoint( + messages: Anthropic.MessageParam[] +): PromptCacheBreakpointPlan | undefined { + let prefixEnd = messages.length + + while (prefixEnd > 0) { + const role = messages[prefixEnd - 1]?.role + if (role === 'user') { + prefixEnd -= 1 + continue + } + break + } + + for (let messageIndex = prefixEnd - 1; messageIndex >= 0; messageIndex -= 1) { + const message = messages[messageIndex] + if (!message) { + continue + } + + const content = message.content + if (typeof content === 'string') { + if (content.trim()) { + return { messageIndex, contentIndex: 0 } + } + continue + } + + if (!Array.isArray(content)) { + continue + } + + for (let contentIndex = content.length - 1; contentIndex >= 0; contentIndex -= 1) { + const block = content[contentIndex] + if (block?.type === 'text' && typeof block.text === 'string' && block.text.trim()) { + return { messageIndex, contentIndex } + } + } + } + + return undefined +} + +export function resolvePromptCachePlan(params: ResolvePromptCachePlanParams): PromptCachePlan { + const mode = resolvePromptCacheMode(params.providerId, params.modelId) + + if (mode === 'disabled') { + return { mode, ttl: null } + } + + if (mode === 'openai_implicit') { + return { + mode, + ttl: null, + cacheKey: buildPromptCacheKey(params.providerId, params.modelId, params.conversationId) + } + } + + if (mode === 'anthropic_auto') { + return { + mode, + ttl: '5m' + } + } + + const breakpointPlan = + params.apiType === 'anthropic' + ? 
findAnthropicBreakpoint(params.messages as Anthropic.MessageParam[])
+      : findOpenAIChatBreakpoint(params.messages as ChatCompletionMessageParam[])
+
+  return {
+    mode,
+    ttl: '5m',
+    breakpointPlan
+  }
+}
+
+export function applyOpenAIPromptCacheKey<T extends Record<string, unknown>>(
+  requestParams: T,
+  plan: PromptCachePlan
+): T {
+  if (plan.mode !== 'openai_implicit' || !plan.cacheKey) {
+    return requestParams
+  }
+
+  return {
+    ...requestParams,
+    prompt_cache_key: plan.cacheKey
+  }
+}
+
+export function applyAnthropicTopLevelCacheControl<T extends Record<string, unknown>>(
+  requestParams: T,
+  plan: PromptCachePlan
+): T {
+  if (plan.mode !== 'anthropic_auto') {
+    return requestParams
+  }
+
+  return {
+    ...requestParams,
+    cache_control: EPHEMERAL_CACHE_CONTROL
+  }
+}
+
+export function applyOpenAIChatExplicitCacheBreakpoint(
+  messages: ChatCompletionMessageParam[],
+  plan: PromptCachePlan
+): ChatCompletionMessageParam[] {
+  if (plan.mode !== 'anthropic_explicit' || !plan.breakpointPlan) {
+    return messages
+  }
+
+  const { messageIndex, contentIndex } = plan.breakpointPlan
+  const target = messages[messageIndex]
+
+  if (!target || !('content' in target)) {
+    return messages
+  }
+
+  const content = target.content
+  let nextContent: ChatCompletionMessageParam['content'] =
+    content as ChatCompletionMessageParam['content']
+
+  if (typeof content === 'string') {
+    if (!content.trim() || contentIndex !== 0) {
+      return messages
+    }
+
+    nextContent = [
+      {
+        type: 'text',
+        text: content,
+        cache_control: EPHEMERAL_CACHE_CONTROL
+      } as unknown as ChatCompletionContentPart
+    ]
+  } else if (Array.isArray(content)) {
+    nextContent = content.map((part, index) => {
+      if (
+        index !== contentIndex ||
+        part?.type !== 'text' ||
+        typeof part.text !== 'string' ||
+        !part.text.trim()
+      ) {
+        return part
+      }
+
+      return {
+        ...part,
+        cache_control: EPHEMERAL_CACHE_CONTROL
+      } as unknown as ChatCompletionContentPart
+    }) as ChatCompletionMessageParam['content']
+  } else {
+    return messages
+  }
+
+  return messages.map((message, index) =>
+    index === messageIndex
+      ? ({
+          ...message,
+          content: nextContent
+        } as ChatCompletionMessageParam)
+      : message
+  )
+}
+
+export function applyAnthropicExplicitCacheBreakpoint(
+  messages: Anthropic.MessageParam[],
+  plan: PromptCachePlan
+): Anthropic.MessageParam[] {
+  if (plan.mode !== 'anthropic_explicit' || !plan.breakpointPlan) {
+    return messages
+  }
+
+  const { messageIndex, contentIndex } = plan.breakpointPlan
+  const target = messages[messageIndex]
+
+  if (!target) {
+    return messages
+  }
+
+  const content = target.content
+  let nextContent: Anthropic.MessageParam['content'] = content
+
+  if (typeof content === 'string') {
+    if (!content.trim() || contentIndex !== 0) {
+      return messages
+    }
+
+    nextContent = [
+      {
+        type: 'text',
+        text: content,
+        cache_control: EPHEMERAL_CACHE_CONTROL
+      } satisfies AnthropicTextBlockWithCache
+    ]
+  } else if (Array.isArray(content)) {
+    nextContent = content.map((block, index) => {
+      if (
+        index !== contentIndex ||
+        block?.type !== 'text' ||
+        typeof block.text !== 'string' ||
+        !block.text.trim()
+      ) {
+        return block
+      }
+
+      return {
+        ...block,
+        cache_control: EPHEMERAL_CACHE_CONTROL
+      } satisfies AnthropicTextBlockWithCache
+    })
+  } else {
+    return messages
+  }
+
+  return messages.map((message, index) =>
+    index === messageIndex
+      ?
({ + ...message, + content: nextContent + } as Anthropic.MessageParam) + : message + ) +} diff --git a/src/main/presenter/llmProviderPresenter/providers/anthropicProvider.ts b/src/main/presenter/llmProviderPresenter/providers/anthropicProvider.ts index 296713ee7..5b29a7708 100644 --- a/src/main/presenter/llmProviderPresenter/providers/anthropicProvider.ts +++ b/src/main/presenter/llmProviderPresenter/providers/anthropicProvider.ts @@ -15,6 +15,67 @@ import { proxyConfig } from '../../proxyConfig' import { ProxyAgent } from 'undici' import type { Usage } from '@anthropic-ai/sdk/resources' import type { ProviderMcpRuntimePort } from '../runtimePorts' +import { + applyAnthropicExplicitCacheBreakpoint, + applyAnthropicTopLevelCacheControl, + resolvePromptCachePlan +} from '../promptCacheStrategy' + +type CacheAwareAnthropicUsage = Usage & { + cache_read_input_tokens?: number + cache_creation_input_tokens?: number + cacheReadInputTokens?: number + cacheWriteInputTokens?: number +} + +function getAnthropicUsageNumber( + usage: CacheAwareAnthropicUsage | undefined, + snakeKey: 'cache_read_input_tokens' | 'cache_creation_input_tokens', + camelKey: 'cacheReadInputTokens' | 'cacheWriteInputTokens' +): number { + const value = usage?.[snakeKey] ?? usage?.[camelKey] + return typeof value === 'number' && Number.isFinite(value) ? value : 0 +} + +function buildAnthropicUsageSnapshot(usage: CacheAwareAnthropicUsage | undefined): { + prompt_tokens: number + completion_tokens: number + total_tokens: number + cached_tokens?: number + cache_write_tokens?: number +} | null { + if (!usage) { + return null + } + + const uncachedInputTokens = + typeof usage.input_tokens === 'number' && Number.isFinite(usage.input_tokens) + ? usage.input_tokens + : 0 + const completionTokens = + typeof usage.output_tokens === 'number' && Number.isFinite(usage.output_tokens) + ? usage.output_tokens + : 0 + const cachedTokens = getAnthropicUsageNumber( + usage, + 'cache_read_input_tokens', + 'cacheReadInputTokens' + ) + const cacheWriteTokens = getAnthropicUsageNumber( + usage, + 'cache_creation_input_tokens', + 'cacheWriteInputTokens' + ) + const promptTokens = uncachedInputTokens + cachedTokens + cacheWriteTokens + + return { + prompt_tokens: promptTokens, + completion_tokens: completionTokens, + total_tokens: promptTokens + completionTokens, + ...(cachedTokens > 0 ? { cached_tokens: cachedTokens } : {}), + ...(cacheWriteTokens > 0 ? { cache_write_tokens: cacheWriteTokens } : {}) + } +} export class AnthropicProvider extends BaseLLMProvider { private anthropic!: Anthropic @@ -42,6 +103,30 @@ export class AnthropicProvider extends BaseLLMProvider { } } + private applyPromptCache>( + requestParams: T, + modelId: string, + messages: Anthropic.MessageParam[], + conversationId?: string + ): T { + const plan = resolvePromptCachePlan({ + providerId: this.provider.id, + apiType: 'anthropic', + modelId, + messages: messages as unknown[], + conversationId + }) + const nextRequestParams = + plan.mode === 'anthropic_explicit' + ? 
{ + ...requestParams, + messages: applyAnthropicExplicitCacheBreakpoint(messages, plan) + } + : requestParams + + return applyAnthropicTopLevelCacheControl(nextRequestParams, plan) + } + public onProxyResolved(): void { this.init() } @@ -458,10 +543,16 @@ export class AnthropicProvider extends BaseLLMProvider { requestParams.system = formattedMessages.system } + const cachedRequestParams = this.applyPromptCache( + requestParams, + modelId, + formattedMessages.messages + ) + if (!this.anthropic) { throw new Error('Anthropic client is not initialized') } - const response = await this.anthropic.messages.create(requestParams) + const response = await this.anthropic.messages.create(cachedRequestParams) const resultResp: LLMResponse = { content: '' @@ -469,10 +560,13 @@ export class AnthropicProvider extends BaseLLMProvider { // 添加usage信息 if (response.usage) { + const usageSnapshot = buildAnthropicUsageSnapshot( + response.usage as CacheAwareAnthropicUsage + ) resultResp.totalUsage = { - prompt_tokens: response.usage.input_tokens, - completion_tokens: response.usage.output_tokens, - total_tokens: response.usage.input_tokens + response.usage.output_tokens + prompt_tokens: usageSnapshot?.prompt_tokens ?? 0, + completion_tokens: usageSnapshot?.completion_tokens ?? 0, + total_tokens: usageSnapshot?.total_tokens ?? 0 } } @@ -546,10 +640,16 @@ ${text} requestParams.system = systemPrompt } + const cachedRequestParams = this.applyPromptCache( + requestParams, + modelId, + requestParams.messages + ) + if (!this.anthropic) { throw new Error('Anthropic client is not initialized') } - const response = await this.anthropic.messages.create(requestParams) + const response = await this.anthropic.messages.create(cachedRequestParams) return { content: response.content @@ -588,10 +688,16 @@ ${context} requestParams.system = systemPrompt } + const cachedRequestParams = this.applyPromptCache( + requestParams, + modelId, + requestParams.messages + ) + if (!this.anthropic) { throw new Error('Anthropic client is not initialized') } - const response = await this.anthropic.messages.create(requestParams) + const response = await this.anthropic.messages.create(cachedRequestParams) const suggestions = response.content .filter((block: any) => block.type === 'text') @@ -657,14 +763,20 @@ ${context} // @ts-ignore - 类型不匹配,但格式是正确的 streamParams.tools = anthropicTools } + const cachedStreamParams = this.applyPromptCache( + streamParams as unknown as Record, + modelId, + formattedMessagesObject.messages, + modelConfig.conversationId + ) as unknown as Anthropic.Messages.MessageCreateParamsStreaming await this.emitRequestTrace(modelConfig, { endpoint: this.buildAnthropicEndpoint(), headers: this.buildAnthropicApiKeyHeaders(), - body: streamParams + body: cachedStreamParams }) // console.log('streamParams', JSON.stringify(streamParams.messages)) // 创建Anthropic流 - const stream = await this.anthropic.messages.create(streamParams) + const stream = await this.anthropic.messages.create(cachedStreamParams) // 状态变量 let accumulatedJson = '' @@ -820,11 +932,10 @@ ${context} } } if (usageMetadata) { - yield createStreamEvent.usage({ - prompt_tokens: usageMetadata.input_tokens, - completion_tokens: usageMetadata.output_tokens, - total_tokens: usageMetadata.input_tokens + usageMetadata.output_tokens - }) + const usageSnapshot = buildAnthropicUsageSnapshot(usageMetadata as CacheAwareAnthropicUsage) + if (usageSnapshot) { + yield createStreamEvent.usage(usageSnapshot) + } } // 发送停止事件 yield createStreamEvent.stop(toolUseDetected ? 
'tool_use' : 'complete') diff --git a/src/main/presenter/llmProviderPresenter/providers/awsBedrockProvider.ts b/src/main/presenter/llmProviderPresenter/providers/awsBedrockProvider.ts index f11b727e8..00ee51071 100644 --- a/src/main/presenter/llmProviderPresenter/providers/awsBedrockProvider.ts +++ b/src/main/presenter/llmProviderPresenter/providers/awsBedrockProvider.ts @@ -20,6 +20,66 @@ import { import Anthropic from '@anthropic-ai/sdk' import { Usage } from '@anthropic-ai/sdk/resources/messages' import type { ProviderMcpRuntimePort } from '../runtimePorts' +import { + applyAnthropicExplicitCacheBreakpoint, + resolvePromptCachePlan +} from '../promptCacheStrategy' + +type CacheAwareBedrockUsage = Usage & { + cache_read_input_tokens?: number + cache_creation_input_tokens?: number + cacheReadInputTokens?: number + cacheWriteInputTokens?: number +} + +function getBedrockUsageNumber( + usage: CacheAwareBedrockUsage | undefined, + snakeKey: 'cache_read_input_tokens' | 'cache_creation_input_tokens', + camelKey: 'cacheReadInputTokens' | 'cacheWriteInputTokens' +): number { + const value = usage?.[snakeKey] ?? usage?.[camelKey] + return typeof value === 'number' && Number.isFinite(value) ? value : 0 +} + +function buildBedrockUsageSnapshot(usage: CacheAwareBedrockUsage | undefined): { + prompt_tokens: number + completion_tokens: number + total_tokens: number + cached_tokens?: number + cache_write_tokens?: number +} | null { + if (!usage) { + return null + } + + const uncachedInputTokens = + typeof usage.input_tokens === 'number' && Number.isFinite(usage.input_tokens) + ? usage.input_tokens + : 0 + const completionTokens = + typeof usage.output_tokens === 'number' && Number.isFinite(usage.output_tokens) + ? usage.output_tokens + : 0 + const cachedTokens = getBedrockUsageNumber( + usage, + 'cache_read_input_tokens', + 'cacheReadInputTokens' + ) + const cacheWriteTokens = getBedrockUsageNumber( + usage, + 'cache_creation_input_tokens', + 'cacheWriteInputTokens' + ) + const promptTokens = uncachedInputTokens + cachedTokens + cacheWriteTokens + + return { + prompt_tokens: promptTokens, + completion_tokens: completionTokens, + total_tokens: promptTokens + completionTokens, + ...(cachedTokens > 0 ? { cached_tokens: cachedTokens } : {}), + ...(cacheWriteTokens > 0 ? 
{ cache_write_tokens: cacheWriteTokens } : {}) + } +} export class AwsBedrockProvider extends BaseLLMProvider { private bedrock!: BedrockClient @@ -64,6 +124,21 @@ export class AwsBedrockProvider extends BaseLLMProvider { return body } + private applyPromptCache( + messages: Anthropic.MessageParam[], + modelId: string, + conversationId?: string + ): Anthropic.MessageParam[] { + const plan = resolvePromptCachePlan({ + providerId: this.provider.id, + apiType: 'anthropic', + modelId, + messages: messages as unknown[], + conversationId + }) + return applyAnthropicExplicitCacheBreakpoint(messages, plan) + } + public onProxyResolved(): void { this.init() } @@ -539,6 +614,7 @@ ${text} } const formattedMessages = this.formatMessages(messages) + const cachedMessages = this.applyPromptCache(formattedMessages.messages, modelId) // 创建基本请求参数 const payload = { @@ -546,7 +622,7 @@ ${text} max_tokens: maxTokens, temperature, system: formattedMessages.system, - messages + messages: cachedMessages } const command = new InvokeModelCommand({ @@ -566,10 +642,11 @@ ${text} // 添加usage信息 if (response.usage) { + const usageSnapshot = buildBedrockUsageSnapshot(response.usage as CacheAwareBedrockUsage) resultResp.totalUsage = { - prompt_tokens: response.usage.input_tokens, - completion_tokens: response.usage.output_tokens, - total_tokens: response.usage.input_tokens + response.usage.output_tokens + prompt_tokens: usageSnapshot?.prompt_tokens ?? 0, + completion_tokens: usageSnapshot?.completion_tokens ?? 0, + total_tokens: usageSnapshot?.total_tokens ?? 0 } } @@ -628,6 +705,11 @@ ${text} try { // 格式化消息 const formattedMessagesObject = this.formatMessages(messages) + const cachedMessages = this.applyPromptCache( + formattedMessagesObject.messages, + modelId, + modelConfig.conversationId + ) console.log('formattedMessagesObject', JSON.stringify(formattedMessagesObject)) // 将MCP工具转换为Anthropic工具格式 @@ -648,8 +730,8 @@ ${text} anthropic_version: 'bedrock-2023-05-31', max_tokens: maxTokens || 1024, temperature: temperature ?? 0.7, - // system: formattedMessagesObject.system, - messages: formattedMessagesObject.messages, + system: formattedMessagesObject.system, + messages: cachedMessages, thinking: undefined as any, tools: undefined as any } @@ -888,11 +970,10 @@ ${text} } } if (usageMetadata) { - yield createStreamEvent.usage({ - prompt_tokens: usageMetadata.input_tokens, - completion_tokens: usageMetadata.output_tokens, - total_tokens: usageMetadata.input_tokens + usageMetadata.output_tokens - }) + const usageSnapshot = buildBedrockUsageSnapshot(usageMetadata as CacheAwareBedrockUsage) + if (usageSnapshot) { + yield createStreamEvent.usage(usageSnapshot) + } } // 发送停止事件 yield createStreamEvent.stop(toolUseDetected ? 
'tool_use' : 'complete') diff --git a/src/main/presenter/llmProviderPresenter/providers/openAICompatibleProvider.ts b/src/main/presenter/llmProviderPresenter/providers/openAICompatibleProvider.ts index 8aedd633b..8d6c1640c 100644 --- a/src/main/presenter/llmProviderPresenter/providers/openAICompatibleProvider.ts +++ b/src/main/presenter/llmProviderPresenter/providers/openAICompatibleProvider.ts @@ -33,6 +33,11 @@ import { proxyConfig } from '../../proxyConfig' import { modelCapabilities } from '../../configPresenter/modelCapabilities' import { ProxyAgent } from 'undici' import type { ProviderMcpRuntimePort } from '../runtimePorts' +import { + applyOpenAIChatExplicitCacheBreakpoint, + applyOpenAIPromptCacheKey, + resolvePromptCachePlan +} from '../promptCacheStrategy' const OPENAI_REASONING_MODELS = [ 'o4-mini', @@ -80,6 +85,17 @@ export function normalizeExtractedImageText(content: string): string { } function getOpenAIChatCachedTokens(usage: unknown): number | undefined { + return getOpenAIChatUsageDetail(usage, 'cached_tokens') +} + +function getOpenAIChatCacheWriteTokens(usage: unknown): number | undefined { + return getOpenAIChatUsageDetail(usage, 'cache_write_tokens') +} + +function getOpenAIChatUsageDetail( + usage: unknown, + key: 'cached_tokens' | 'cache_write_tokens' +): number | undefined { if (!usage || typeof usage !== 'object') { return undefined } @@ -88,11 +104,11 @@ function getOpenAIChatCachedTokens(usage: unknown): number | undefined { const inputTokensDetails = (usage as { input_tokens_details?: unknown }).input_tokens_details const promptCachedTokens = promptTokensDetails && typeof promptTokensDetails === 'object' - ? (promptTokensDetails as { cached_tokens?: unknown }).cached_tokens + ? (promptTokensDetails as Record)[key] : undefined const inputCachedTokens = inputTokensDetails && typeof inputTokensDetails === 'object' - ? (inputTokensDetails as { cached_tokens?: unknown }).cached_tokens + ? (inputTokensDetails as Record)[key] : undefined const cachedTokens = typeof promptCachedTokens === 'number' ? promptCachedTokens : inputCachedTokens @@ -768,7 +784,7 @@ export class OpenAICompatibleProvider extends BaseLLMProvider { const modelConfig = this.configPresenter.getModelConfig(modelId, this.provider.id) const supportsFunctionCall = modelConfig?.functionCall || false - const requestParams: OpenAI.Chat.ChatCompletionCreateParams = { + const requestParams: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming = { messages: this.formatMessages(messages, supportsFunctionCall), model: modelId, stream: false, @@ -781,12 +797,27 @@ export class OpenAICompatibleProvider extends BaseLLMProvider { ? 
{ max_completion_tokens: maxTokens } : { max_tokens: maxTokens }) } + const promptCachePlan = resolvePromptCachePlan({ + providerId: this.provider.id, + apiType: 'openai_chat', + modelId, + messages: requestParams.messages as unknown[], + conversationId: modelConfig?.conversationId + }) + requestParams.messages = applyOpenAIChatExplicitCacheBreakpoint( + requestParams.messages as ChatCompletionMessageParam[], + promptCachePlan + ) OPENAI_REASONING_MODELS.forEach((noTempId) => { if (modelId.startsWith(noTempId)) { delete requestParams.temperature } }) - const completion = await this.openai.chat.completions.create(requestParams) + const cachedRequestParams = applyOpenAIPromptCacheKey( + requestParams as unknown as Record, + promptCachePlan + ) as unknown as OpenAI.Chat.ChatCompletionCreateParamsNonStreaming + const completion = await this.openai.chat.completions.create(cachedRequestParams) const message = completion.choices[0].message as ChatCompletionMessage & { reasoning_content?: string @@ -1012,7 +1043,8 @@ export class OpenAICompatibleProvider extends BaseLLMProvider { prompt_tokens: result.usage.input_tokens || 0, completion_tokens: result.usage.output_tokens || 0, total_tokens: result.usage.total_tokens || 0, - cached_tokens: getOpenAIChatCachedTokens(result.usage) + cached_tokens: getOpenAIChatCachedTokens(result.usage), + cache_write_tokens: getOpenAIChatCacheWriteTokens(result.usage) }) } @@ -1081,7 +1113,7 @@ export class OpenAICompatibleProvider extends BaseLLMProvider { : undefined // 构建请求参数 - const requestParams: OpenAI.Chat.ChatCompletionCreateParams = { + const requestParams: OpenAI.Chat.ChatCompletionCreateParamsStreaming = { messages: processedMessages, model: modelId, stream: true, @@ -1132,15 +1164,32 @@ export class OpenAICompatibleProvider extends BaseLLMProvider { // 如果存在 API 工具且支持函数调用,则添加到请求参数中 if (apiTools && apiTools.length > 0 && supportsFunctionCall) requestParams.tools = apiTools + const promptCachePlan = resolvePromptCachePlan({ + providerId: this.provider.id, + apiType: 'openai_chat', + modelId, + messages: processedMessages as unknown[], + tools, + conversationId: modelConfig?.conversationId + }) + requestParams.messages = applyOpenAIChatExplicitCacheBreakpoint( + requestParams.messages as ChatCompletionMessageParam[], + promptCachePlan + ) + const cachedRequestParams = applyOpenAIPromptCacheKey( + requestParams as unknown as Record, + promptCachePlan + ) as unknown as OpenAI.Chat.ChatCompletionCreateParamsStreaming + await this.emitRequestTrace(modelConfig, { endpoint: this.buildChatCompletionsEndpoint(), headers: this.buildChatCompletionsTraceHeaders(), - body: requestParams + body: cachedRequestParams }) // console.log('[handleChatCompletion] requestParams', JSON.stringify(requestParams)) // 发起 OpenAI 聊天补全请求 - const stream = await this.openai.chat.completions.create(requestParams) + const stream = await this.openai.chat.completions.create(cachedRequestParams) //----------------------------------------------------------------------------------------------------- // 流处理状态定义 (已将相关变量声明提升到顶部,确保可见性) @@ -1179,6 +1228,7 @@ export class OpenAICompatibleProvider extends BaseLLMProvider { completion_tokens: number total_tokens: number cached_tokens?: number + cache_write_tokens?: number } | undefined = undefined @@ -1195,7 +1245,8 @@ export class OpenAICompatibleProvider extends BaseLLMProvider { if (chunk.usage) { usage = { ...chunk.usage, - cached_tokens: getOpenAIChatCachedTokens(chunk.usage) + cached_tokens: getOpenAIChatCachedTokens(chunk.usage), + 
cache_write_tokens: getOpenAIChatCacheWriteTokens(chunk.usage) } } diff --git a/src/main/presenter/llmProviderPresenter/providers/openAIResponsesProvider.ts b/src/main/presenter/llmProviderPresenter/providers/openAIResponsesProvider.ts index e4e4daa6c..67a477f4b 100644 --- a/src/main/presenter/llmProviderPresenter/providers/openAIResponsesProvider.ts +++ b/src/main/presenter/llmProviderPresenter/providers/openAIResponsesProvider.ts @@ -24,6 +24,7 @@ import { proxyConfig } from '../../proxyConfig' import { ProxyAgent } from 'undici' import { modelCapabilities } from '../../configPresenter/modelCapabilities' import type { ProviderMcpRuntimePort } from '../runtimePorts' +import { applyOpenAIPromptCacheKey, resolvePromptCachePlan } from '../promptCacheStrategy' const OPENAI_REASONING_MODELS = [ 'o4-mini', @@ -62,7 +63,9 @@ function getOpenAIResponseCachedTokens( | { input_tokens_details?: { cached_tokens?: number + cache_write_tokens?: number } + cache_write_tokens?: number } | null | undefined @@ -73,6 +76,24 @@ function getOpenAIResponseCachedTokens( : undefined } +function getOpenAIResponseCacheWriteTokens(usage: unknown): number | undefined { + if (!usage || typeof usage !== 'object') { + return undefined + } + + const inputTokensDetails = (usage as { input_tokens_details?: unknown }).input_tokens_details + const nestedCacheWriteTokens = + inputTokensDetails && typeof inputTokensDetails === 'object' + ? (inputTokensDetails as Record).cache_write_tokens + : undefined + const topLevelCacheWriteTokens = (usage as Record).cache_write_tokens + const cacheWriteTokens = + typeof nestedCacheWriteTokens === 'number' ? nestedCacheWriteTokens : topLevelCacheWriteTokens + return typeof cacheWriteTokens === 'number' && Number.isFinite(cacheWriteTokens) + ? cacheWriteTokens + : undefined +} + export class OpenAIResponsesProvider extends BaseLLMProvider { protected openai!: OpenAI private isNoModelsApi: boolean = false @@ -305,7 +326,7 @@ export class OpenAIResponsesProvider extends BaseLLMProvider { } const formattedMessages = this.formatMessages(messages) - const requestParams: OpenAI.Responses.ResponseCreateParams = { + const requestParams: OpenAI.Responses.ResponseCreateParamsNonStreaming = { model: modelId, input: formattedMessages, temperature: temperature, @@ -314,6 +335,13 @@ export class OpenAIResponsesProvider extends BaseLLMProvider { } const modelConfig = this.configPresenter.getModelConfig(modelId, this.provider.id) + const promptCachePlan = resolvePromptCachePlan({ + providerId: this.provider.id, + apiType: 'openai_responses', + modelId, + messages: formattedMessages as unknown[], + conversationId: modelConfig?.conversationId + }) if (modelConfig.reasoningEffort && this.supportsEffortParameter(modelId)) { ;(requestParams as any).reasoning = { effort: modelConfig.reasoningEffort @@ -333,7 +361,12 @@ export class OpenAIResponsesProvider extends BaseLLMProvider { } }) - const response = await this.openai.responses.create(requestParams) + const cachedRequestParams = applyOpenAIPromptCacheKey( + requestParams as unknown as Record, + promptCachePlan + ) as unknown as OpenAI.Responses.ResponseCreateParamsNonStreaming + + const response = await this.openai.responses.create(cachedRequestParams) const resultResp: LLMResponse = { content: '' } @@ -571,7 +604,8 @@ export class OpenAIResponsesProvider extends BaseLLMProvider { prompt_tokens: result.usage.input_tokens || 0, completion_tokens: result.usage.output_tokens || 0, total_tokens: result.usage.total_tokens || 0, - cached_tokens: 
getOpenAIResponseCachedTokens(result.usage) + cached_tokens: getOpenAIResponseCachedTokens(result.usage), + cache_write_tokens: getOpenAIResponseCacheWriteTokens(result.usage) }) } @@ -631,13 +665,21 @@ export class OpenAIResponsesProvider extends BaseLLMProvider { ? await this.mcpRuntime?.mcpToolsToOpenAIResponsesTools(tools, this.provider.id) : undefined - const requestParams: OpenAI.Responses.ResponseCreateParams = { + const requestParams: OpenAI.Responses.ResponseCreateParamsStreaming = { model: modelId, input: processedMessages, temperature, max_output_tokens: maxTokens, stream: true } + const promptCachePlan = resolvePromptCachePlan({ + providerId: this.provider.id, + apiType: 'openai_responses', + modelId, + messages: processedMessages as unknown[], + tools, + conversationId: modelConfig?.conversationId + }) // 如果模型支持函数调用且有工具,添加 tools 参数 if (tools.length > 0 && supportsFunctionCall && apiTools) { @@ -660,13 +702,18 @@ export class OpenAIResponsesProvider extends BaseLLMProvider { if (modelId.startsWith(noTempId)) delete requestParams.temperature }) + const cachedRequestParams = applyOpenAIPromptCacheKey( + requestParams as unknown as Record, + promptCachePlan + ) as unknown as OpenAI.Responses.ResponseCreateParamsStreaming + await this.emitRequestTrace(modelConfig, { endpoint: this.buildResponsesEndpoint(), headers: this.buildResponsesTraceHeaders(), - body: requestParams + body: cachedRequestParams }) - const stream = await this.openai.responses.create(requestParams) + const stream = await this.openai.responses.create(cachedRequestParams) // --- State Variables --- type TagState = 'none' | 'start' | 'inside' | 'end' @@ -696,6 +743,7 @@ export class OpenAIResponsesProvider extends BaseLLMProvider { completion_tokens: number total_tokens: number cached_tokens?: number + cache_write_tokens?: number } | undefined = undefined @@ -1006,7 +1054,8 @@ export class OpenAIResponsesProvider extends BaseLLMProvider { prompt_tokens: response.usage.input_tokens || 0, completion_tokens: response.usage.output_tokens || 0, total_tokens: response.usage.total_tokens || 0, - cached_tokens: getOpenAIResponseCachedTokens(response.usage) + cached_tokens: getOpenAIResponseCachedTokens(response.usage), + cache_write_tokens: getOpenAIResponseCacheWriteTokens(response.usage) } yield createStreamEvent.usage(usage) } diff --git a/src/main/presenter/llmProviderPresenter/providers/zenmuxProvider.ts b/src/main/presenter/llmProviderPresenter/providers/zenmuxProvider.ts index be1bf5c42..37b4d7efc 100644 --- a/src/main/presenter/llmProviderPresenter/providers/zenmuxProvider.ts +++ b/src/main/presenter/llmProviderPresenter/providers/zenmuxProvider.ts @@ -1,21 +1,228 @@ -import { IConfigPresenter, LLM_PROVIDER, MODEL_META } from '@shared/presenter' +import Anthropic from '@anthropic-ai/sdk' +import { + ChatMessage, + IConfigPresenter, + KeyStatus, + LLM_EMBEDDING_ATTRS, + LLM_PROVIDER, + LLMResponse, + LLMCoreStreamEvent, + MCPToolDefinition, + MODEL_META, + ModelConfig +} from '@shared/presenter' +import { ProxyAgent } from 'undici' +import { BaseLLMProvider } from '../baseProvider' +import { proxyConfig } from '../../proxyConfig' +import { AnthropicProvider } from './anthropicProvider' import { OpenAICompatibleProvider } from './openAICompatibleProvider' import type { ProviderMcpRuntimePort } from '../runtimePorts' -export class ZenmuxProvider extends OpenAICompatibleProvider { +const ZENMUX_ANTHROPIC_BASE_URL = 'https://zenmux.ai/api/anthropic' + +class ZenmuxOpenAIDelegate extends OpenAICompatibleProvider { + 
protected override async init() {
+    this.isInitialized = true
+  }
+
+  public async fetchZenmuxModels(options?: { timeout: number }): Promise<MODEL_META[]> {
+    return super.fetchOpenAIModels(options)
+  }
+}
+
+class ZenmuxAnthropicDelegate extends AnthropicProvider {
+  private clientInitialized = false
+
+  protected override async init() {}
+
+  public async ensureClientInitialized(): Promise<void> {
+    const apiKey = this.provider.apiKey || process.env.ANTHROPIC_API_KEY || null
+    if (!apiKey) {
+      this.clientInitialized = false
+      this.isInitialized = false
+      return
+    }
+
+    const proxyUrl = proxyConfig.getProxyUrl()
+    const fetchOptions: { dispatcher?: ProxyAgent } = {}
+
+    if (proxyUrl) {
+      const proxyAgent = new ProxyAgent(proxyUrl)
+      fetchOptions.dispatcher = proxyAgent
+    }
+
+    const self = this as unknown as { anthropic?: Anthropic }
+    self.anthropic = new Anthropic({
+      apiKey,
+      baseURL: this.provider.baseUrl || ZENMUX_ANTHROPIC_BASE_URL,
+      defaultHeaders: this.defaultHeaders,
+      fetchOptions
+    })
+
+    this.clientInitialized = true
+    this.isInitialized = true
+  }
+
+  public isClientInitialized(): boolean {
+    return this.clientInitialized
+  }
+
+  public override onProxyResolved(): void {
+    void this.ensureClientInitialized()
+  }
+}
+
+export class ZenmuxProvider extends BaseLLMProvider {
+  private readonly openaiDelegate: ZenmuxOpenAIDelegate
+  private readonly anthropicDelegate: ZenmuxAnthropicDelegate
+
   constructor(
     provider: LLM_PROVIDER,
     configPresenter: IConfigPresenter,
     mcpRuntime?: ProviderMcpRuntimePort
   ) {
     super(provider, configPresenter, mcpRuntime)
+
+    this.openaiDelegate = new ZenmuxOpenAIDelegate(provider, configPresenter, mcpRuntime)
+    this.anthropicDelegate = new ZenmuxAnthropicDelegate(
+      {
+        ...provider,
+        apiType: 'anthropic',
+        baseUrl: ZENMUX_ANTHROPIC_BASE_URL
+      },
+      configPresenter,
+      mcpRuntime
+    )
+
+    this.init()
+  }
+
+  private isAnthropicModel(modelId: string): boolean {
+    return modelId.trim().toLowerCase().startsWith('anthropic/')
+  }
+
+  private async ensureAnthropicDelegateReady(): Promise<ZenmuxAnthropicDelegate> {
+    await this.anthropicDelegate.ensureClientInitialized()
+
+    if (!this.anthropicDelegate.isClientInitialized()) {
+      throw new Error('Anthropic SDK not initialized')
+    }
+
+    return this.anthropicDelegate
   }
 
-  protected async fetchOpenAIModels(options?: { timeout: number }): Promise<MODEL_META[]> {
-    const models = await super.fetchOpenAIModels(options)
+  protected async fetchProviderModels(): Promise<MODEL_META[]> {
+    const models = await this.openaiDelegate.fetchZenmuxModels()
     return models.map((model) => ({
       ...model,
       group: 'ZenMux'
     }))
   }
+
+  public onProxyResolved(): void {
+    this.openaiDelegate.onProxyResolved()
+
+    if (this.anthropicDelegate.isClientInitialized()) {
+      this.anthropicDelegate.onProxyResolved()
+    }
+  }
+
+  public async check(): Promise<{ isOk: boolean; errorMsg: string | null }> {
+    return this.openaiDelegate.check()
+  }
+
+  public async summaryTitles(messages: ChatMessage[], modelId: string): Promise<string> {
+    if (this.isAnthropicModel(modelId)) {
+      const delegate = await this.ensureAnthropicDelegateReady()
+      return delegate.summaryTitles(messages, modelId)
+    }
+
+    return this.openaiDelegate.summaryTitles(messages, modelId)
+  }
+
+  public async completions(
+    messages: ChatMessage[],
+    modelId: string,
+    temperature?: number,
+    maxTokens?: number
+  ): Promise<LLMResponse> {
+    if (this.isAnthropicModel(modelId)) {
+      const delegate = await this.ensureAnthropicDelegateReady()
+      return delegate.completions(messages, modelId, temperature, maxTokens)
+    }
+
+    return this.openaiDelegate.completions(messages, modelId, temperature, maxTokens)
+  }
+
+  public async summaries(
+    text: string,
+    modelId: string,
+    temperature?: number,
+    maxTokens?: number
+  ): Promise<LLMResponse> {
+    if (this.isAnthropicModel(modelId)) {
+      const delegate = await this.ensureAnthropicDelegateReady()
+      return delegate.summaries(text, modelId, temperature, maxTokens)
+    }
+
+    return this.openaiDelegate.summaries(text, modelId, temperature, maxTokens)
+  }
+
+  public async generateText(
+    prompt: string,
+    modelId: string,
+    temperature?: number,
+    maxTokens?: number
+  ): Promise<LLMResponse> {
+    if (this.isAnthropicModel(modelId)) {
+      const delegate = await this.ensureAnthropicDelegateReady()
+      return delegate.generateText(prompt, modelId, temperature, maxTokens)
+    }
+
+    return this.openaiDelegate.generateText(prompt, modelId, temperature, maxTokens)
+  }
+
+  public async *coreStream(
+    messages: ChatMessage[],
+    modelId: string,
+    modelConfig: ModelConfig,
+    temperature: number,
+    maxTokens: number,
+    tools: MCPToolDefinition[]
+  ): AsyncGenerator<LLMCoreStreamEvent, void, unknown> {
+    if (this.isAnthropicModel(modelId)) {
+      const delegate = await this.ensureAnthropicDelegateReady()
+      yield* delegate.coreStream(messages, modelId, modelConfig, temperature, maxTokens, tools)
+      return
+    }
+
+    yield* this.openaiDelegate.coreStream(
+      messages,
+      modelId,
+      modelConfig,
+      temperature,
+      maxTokens,
+      tools
+    )
+  }
+
+  public async getEmbeddings(modelId: string, texts: string[]): Promise<number[][]> {
+    if (this.isAnthropicModel(modelId)) {
+      throw new Error(`Embeddings not supported for Anthropic models: ${modelId}`)
+    }
+
+    return this.openaiDelegate.getEmbeddings(modelId, texts)
+  }
+
+  public async getDimensions(modelId: string): Promise<LLM_EMBEDDING_ATTRS> {
+    if (this.isAnthropicModel(modelId)) {
+      throw new Error(`Embeddings not supported for Anthropic models: ${modelId}`)
+    }
+
+    return this.openaiDelegate.getDimensions(modelId)
+  }
+
+  public async getKeyStatus(): Promise<KeyStatus> {
+    return this.openaiDelegate.getKeyStatus()
+  }
 }
diff --git a/src/main/presenter/newAgentPresenter/index.ts b/src/main/presenter/newAgentPresenter/index.ts
index 0b1d776a6..c3860d7ee 100644
--- a/src/main/presenter/newAgentPresenter/index.ts
+++ b/src/main/presenter/newAgentPresenter/index.ts
@@ -2143,6 +2143,7 @@ export class NewAgentPresenter {
     inputTokens?: number
     outputTokens?: number
     cachedInputTokens?: number
+    cacheWriteInputTokens?: number
     generationTime?: number
     firstTokenTime?: number
     tokensPerSecond?: number
@@ -2158,6 +2159,10 @@
       outputTokens: typeof parsed.outputTokens === 'number' ? parsed.outputTokens : undefined,
       cachedInputTokens:
         typeof parsed.cachedInputTokens === 'number' ? parsed.cachedInputTokens : undefined,
+      cacheWriteInputTokens:
+        typeof parsed.cacheWriteInputTokens === 'number'
+          ? parsed.cacheWriteInputTokens
+          : undefined,
       generationTime:
         typeof parsed.generationTime === 'number' ? parsed.generationTime : undefined,
       firstTokenTime:
@@ -2208,7 +2213,8 @@
         modelId,
         metadata: {
           ...metadata,
-          cachedInputTokens: 0
+          cachedInputTokens: metadata.cachedInputTokens ?? 0,
+          cacheWriteInputTokens: metadata.cacheWriteInputTokens ??
0 }, source: 'backfill' }) diff --git a/src/main/presenter/sqlitePresenter/tables/deepchatUsageStats.ts b/src/main/presenter/sqlitePresenter/tables/deepchatUsageStats.ts index edada206e..a071f9474 100644 --- a/src/main/presenter/sqlitePresenter/tables/deepchatUsageStats.ts +++ b/src/main/presenter/sqlitePresenter/tables/deepchatUsageStats.ts @@ -12,6 +12,7 @@ export interface DeepChatUsageStatsRow { output_tokens: number total_tokens: number cached_input_tokens: number + cache_write_input_tokens: number estimated_cost_usd: number | null source: 'backfill' | 'live' created_at: number @@ -93,6 +94,7 @@ export class DeepChatUsageStatsTable extends BaseTable { output_tokens INTEGER NOT NULL DEFAULT 0, total_tokens INTEGER NOT NULL DEFAULT 0, cached_input_tokens INTEGER NOT NULL DEFAULT 0, + cache_write_input_tokens INTEGER NOT NULL DEFAULT 0, estimated_cost_usd REAL, source TEXT NOT NULL DEFAULT 'live', created_at INTEGER NOT NULL, @@ -108,11 +110,16 @@ export class DeepChatUsageStatsTable extends BaseTable { if (version === 17) { return this.getCreateTableSQL() } + if (version === 22) { + return this.hasColumn('cache_write_input_tokens') + ? null + : `ALTER TABLE deepchat_usage_stats ADD COLUMN cache_write_input_tokens INTEGER NOT NULL DEFAULT 0;` + } return null } getLatestVersion(): number { - return 17 + return 22 } upsert(row: UsageStatsRecordInput): void { @@ -128,11 +135,12 @@ export class DeepChatUsageStatsTable extends BaseTable { output_tokens, total_tokens, cached_input_tokens, + cache_write_input_tokens, estimated_cost_usd, source, created_at, updated_at - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(message_id) DO UPDATE SET session_id = excluded.session_id, usage_date = excluded.usage_date, @@ -142,6 +150,7 @@ export class DeepChatUsageStatsTable extends BaseTable { output_tokens = excluded.output_tokens, total_tokens = excluded.total_tokens, cached_input_tokens = excluded.cached_input_tokens, + cache_write_input_tokens = excluded.cache_write_input_tokens, estimated_cost_usd = excluded.estimated_cost_usd, source = excluded.source, created_at = excluded.created_at, @@ -157,6 +166,7 @@ export class DeepChatUsageStatsTable extends BaseTable { row.outputTokens, row.totalTokens, row.cachedInputTokens, + row.cacheWriteInputTokens, row.estimatedCostUsd, row.source, row.createdAt, diff --git a/src/main/presenter/usageStats.ts b/src/main/presenter/usageStats.ts index 495465737..9c550770a 100644 --- a/src/main/presenter/usageStats.ts +++ b/src/main/presenter/usageStats.ts @@ -22,6 +22,7 @@ export interface UsageStatsRecordInput { outputTokens: number totalTokens: number cachedInputTokens: number + cacheWriteInputTokens: number estimatedCostUsd: number | null source: UsageStatsSource createdAt: number @@ -152,18 +153,24 @@ export function normalizeUsageCounts(metadata: MessageMetadata): { outputTokens: number totalTokens: number cachedInputTokens: number + cacheWriteInputTokens: number } { const inputTokens = toNonNegativeInteger(metadata.inputTokens) ?? 0 const outputTokens = toNonNegativeInteger(metadata.outputTokens) ?? 0 const totalTokens = toNonNegativeInteger(metadata.totalTokens) ?? inputTokens + outputTokens const rawCached = toNonNegativeInteger(metadata.cachedInputTokens) ?? 0 - const cachedInputTokens = inputTokens > 0 ? Math.min(rawCached, inputTokens) : rawCached + const rawCacheWrite = toNonNegativeInteger(metadata.cacheWriteInputTokens) ?? 0 + const cappedCachedInputTokens = inputTokens > 0 ? 
Math.min(rawCached, inputTokens) : rawCached + const remainingInputTokens = Math.max(inputTokens - cappedCachedInputTokens, 0) + const cacheWriteInputTokens = + inputTokens > 0 ? Math.min(rawCacheWrite, remainingInputTokens) : rawCacheWrite return { inputTokens, outputTokens, totalTokens, - cachedInputTokens + cachedInputTokens: cappedCachedInputTokens, + cacheWriteInputTokens } } @@ -195,6 +202,7 @@ export function estimateUsageCostUsd(params: { inputTokens: number outputTokens: number cachedInputTokens: number + cacheWriteInputTokens: number }): number | null { const model = resolvePricedModel(params.providerId, params.modelId) const inputRate = getCostNumber(model, 'input') @@ -205,12 +213,17 @@ export function estimateUsageCostUsd(params: { } const cacheReadRate = getCostNumber(model, 'cache_read') - const billableInput = Math.max(params.inputTokens - params.cachedInputTokens, 0) + const cacheWriteRate = getCostNumber(model, 'cache_write') + const uncachedInput = Math.max( + params.inputTokens - params.cachedInputTokens - params.cacheWriteInputTokens, + 0 + ) return ( - (billableInput * inputRate + + (uncachedInput * inputRate + params.outputTokens * outputRate + - params.cachedInputTokens * (cacheReadRate ?? inputRate)) / + params.cachedInputTokens * (cacheReadRate ?? inputRate) + + params.cacheWriteInputTokens * (cacheWriteRate ?? inputRate)) / 1_000_000 ) } @@ -247,12 +260,14 @@ export function buildUsageStatsRecord(params: { outputTokens: usage.outputTokens, totalTokens: usage.totalTokens, cachedInputTokens: usage.cachedInputTokens, + cacheWriteInputTokens: usage.cacheWriteInputTokens, estimatedCostUsd: estimateUsageCostUsd({ providerId, modelId, inputTokens: usage.inputTokens, outputTokens: usage.outputTokens, - cachedInputTokens: usage.cachedInputTokens + cachedInputTokens: usage.cachedInputTokens, + cacheWriteInputTokens: usage.cacheWriteInputTokens }), source: params.source, createdAt: params.createdAt, diff --git a/src/renderer/settings/components/ProviderApiConfig.vue b/src/renderer/settings/components/ProviderApiConfig.vue index f62517f6f..cf912291d 100644 --- a/src/renderer/settings/components/ProviderApiConfig.vue +++ b/src/renderer/settings/components/ProviderApiConfig.vue @@ -26,7 +26,21 @@ {{ t('settings.provider.delete') }} +
+          <!-- locked Base URL block (original tag markup unrecoverable): a read-only
+               display of {{ apiHost || t('settings.provider.urlPlaceholder') }},
+               rendered with the t('settings.provider.modifyBaseUrl') unlock button
+               and the t('settings.provider.baseUrlLockedHint') tooltip while locked -->
- + ' +}) + +const labelStub = defineComponent({ + name: 'Label', + inheritAttrs: false, + template: '' +}) + +const createProvider = (overrides?: Partial): LLM_PROVIDER => ({ + id: 'deepseek', + name: 'DeepSeek', + apiType: 'openai-compatible', + apiKey: 'test-key', + baseUrl: 'https://api.deepseek.com/v1', + enable: true, + custom: false, + ...overrides +}) + +async function setup(options?: { + provider?: LLM_PROVIDER + providerWebsites?: { + official: string + apiKey: string + docs: string + models: string + defaultBaseUrl: string + } +}) { + vi.resetModules() + + const llmproviderPresenter = { + getKeyStatus: vi.fn().mockResolvedValue(null), + refreshModels: vi.fn().mockResolvedValue(undefined) + } + const modelCheckStore = { + openDialog: vi.fn() + } + + vi.doMock('vue-i18n', () => ({ + useI18n: () => ({ + t: (key: string, params?: Record) => { + if (key === 'settings.provider.modifyBaseUrl') return 'Modify' + if (key === 'settings.provider.baseUrlLockedHint') { + return 'This provider is pinned to the recommended Base URL.' + } + if (key === 'settings.provider.urlPlaceholder') return 'Enter API URL' + if (key === 'settings.provider.urlFormat') { + return `Default: ${params?.defaultUrl ?? ''}` + } + if (key === 'settings.provider.urlFormatFill') return 'Fill into API URL' + if (key === 'settings.provider.dialog.baseUrlUnlock.confirm') return 'Continue' + return key + } + }) + })) + + vi.doMock('@/composables/usePresenter', () => ({ + usePresenter: (name: string) => { + if (name === 'llmproviderPresenter') return llmproviderPresenter + throw new Error(`Unexpected presenter: ${name}`) + } + })) + + vi.doMock('@/stores/modelCheck', () => ({ + useModelCheckStore: () => modelCheckStore + })) + + vi.doMock('@shadcn/components/ui/input', () => ({ + Input: createInputStub() + })) + vi.doMock('@shadcn/components/ui/button', () => ({ + Button: buttonStub + })) + vi.doMock('@shadcn/components/ui/label', () => ({ + Label: labelStub + })) + vi.doMock('@shadcn/components/ui/tooltip', () => ({ + Tooltip: passthrough('Tooltip'), + TooltipContent: passthrough('TooltipContent'), + TooltipProvider: passthrough('TooltipProvider'), + TooltipTrigger: passthrough('TooltipTrigger') + })) + vi.doMock('@iconify/vue', () => ({ + Icon: defineComponent({ + name: 'Icon', + template: '' + }) + })) + + const ProviderApiConfig = ( + await import('../../../src/renderer/settings/components/ProviderApiConfig.vue') + ).default + + const wrapper = mount(ProviderApiConfig, { + props: { + provider: options?.provider ?? createProvider(), + providerWebsites: options?.providerWebsites ?? 
{ + official: 'https://example.com', + apiKey: 'https://example.com/key', + docs: 'https://example.com/docs', + models: 'https://example.com/models', + defaultBaseUrl: 'https://api.deepseek.com/v1' + } + }, + global: { + stubs: { + GitHubCopilotOAuth: true + } + } + }) + + await flushPromises() + + return { + wrapper, + llmproviderPresenter, + modelCheckStore + } +} + +function findButtonByText(wrapper: ReturnType, text: string) { + return wrapper.findAll('button').find((button) => button.text().trim() === text) +} + +describe('ProviderApiConfig', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('shows a locked Base URL display for built-in providers outside the allowlist', async () => { + const { wrapper, llmproviderPresenter } = await setup() + + expect(wrapper.find('input#deepseek-url').exists()).toBe(false) + expect(wrapper.text()).toContain('This provider is pinned to the recommended Base URL.') + expect(findButtonByText(wrapper, 'Modify')).toBeDefined() + expect(wrapper.html()).not.toContain('Fill into API URL') + expect(llmproviderPresenter.getKeyStatus).toHaveBeenCalledWith('deepseek') + }) + + it('switches directly into edit mode and hides the modify button', async () => { + const { wrapper } = await setup() + const modifyButton = findButtonByText(wrapper, 'Modify') + + expect(modifyButton).toBeDefined() + await modifyButton!.trigger('click') + await flushPromises() + + expect(wrapper.find('input#deepseek-url').exists()).toBe(true) + expect(findButtonByText(wrapper, 'Modify')).toBeUndefined() + }) + + it('preserves the existing save behavior after unlocking', async () => { + const { wrapper } = await setup() + const modifyButton = findButtonByText(wrapper, 'Modify') + + expect(modifyButton).toBeDefined() + await modifyButton!.trigger('click') + await flushPromises() + + const input = wrapper.get('input#deepseek-url') + await input.setValue('https://custom.deepseek.com/v1') + await input.trigger('blur') + + expect(wrapper.emitted('api-host-change')).toEqual([['https://custom.deepseek.com/v1']]) + }) + + it('keeps OpenAI Responses editable without the lock prompt', async () => { + const { wrapper } = await setup({ + provider: createProvider({ + id: 'openai-responses', + name: 'OpenAI Responses', + baseUrl: 'https://api.openai.com/v1' + }) + }) + + expect(wrapper.find('input#openai-responses-url').exists()).toBe(true) + expect(findButtonByText(wrapper, 'Modify')).toBeUndefined() + expect(wrapper.text()).not.toContain('This provider is pinned to the recommended Base URL.') + }) + + it('keeps custom providers editable by default', async () => { + const { wrapper } = await setup({ + provider: createProvider({ + id: 'custom-demo', + name: 'Custom Demo', + custom: true, + baseUrl: 'https://custom.example.com/v1' + }) + }) + + expect(wrapper.find('input#custom-demo-url').exists()).toBe(true) + expect(findButtonByText(wrapper, 'Modify')).toBeUndefined() + }) +})
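
How the prompt-cache helpers compose (a minimal sketch, not part of the patch): every provider touched by this diff follows the same three-step sequence — resolve a cache plan from provider/model ids, rewrite the message list for explicit Anthropic-style breakpoints, then attach the implicit OpenAI prompt-cache key. The wrapper name `withPromptCache` below is hypothetical; only the three imported helpers come from the patch itself.

import type { ChatCompletionMessageParam } from 'openai/resources'
import {
  applyOpenAIChatExplicitCacheBreakpoint,
  applyOpenAIPromptCacheKey,
  resolvePromptCachePlan
} from './promptCacheStrategy'

// Hypothetical wrapper illustrating the intended call order for an
// OpenAI-compatible chat provider; the real call sites live in the providers above.
function withPromptCache(
  providerId: string,
  modelId: string,
  messages: ChatCompletionMessageParam[],
  conversationId?: string
): { messages: ChatCompletionMessageParam[]; extraParams: Record<string, unknown> } {
  // Step 1: classify the provider/model into a PromptCacheMode and, for
  // anthropic_explicit, pick the breakpoint (the last stable non-empty text block).
  const plan = resolvePromptCachePlan({
    providerId,
    apiType: 'openai_chat',
    modelId,
    messages: messages as unknown[],
    conversationId
  })

  // Step 2: no-op unless plan.mode === 'anthropic_explicit'.
  const cachedMessages = applyOpenAIChatExplicitCacheBreakpoint(messages, plan)

  // Step 3: no-op unless plan.mode === 'openai_implicit' with a conversation-derived cacheKey.
  const extraParams = applyOpenAIPromptCacheKey<Record<string, unknown>>({}, plan)

  return { messages: cachedMessages, extraParams }
}

Because both helpers return their input untouched outside their own mode, callers can apply them unconditionally — which is exactly what the provider changes above do before invoking `chat.completions.create`.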