From bb82c282e3f82e67430ea2da77d2bf879cc60fd5 Mon Sep 17 00:00:00 2001 From: Vincent Wu <284134889@qq.com> Date: Sun, 14 Jun 2026 12:06:49 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20support=20lm=20studio=20models=20#71?= =?UTF-8?q?=E2=80=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bun.lock | 2 +- env.example | 4 ++ src/agent/agent.ts | 3 +- src/components/select-list.ts | 8 +++ src/controllers/model-selection.ts | 13 ++++- src/cron/executor.ts | 5 +- src/gateway/gateway.ts | 6 +-- src/model/llm.ts | 33 +++++++++++- src/providers.ts | 6 +++ src/utils/lm-studio.test.ts | 82 ++++++++++++++++++++++++++++++ src/utils/lm-studio.ts | 54 ++++++++++++++++++++ src/utils/model.ts | 35 ++++++++++++- 12 files changed, 237 insertions(+), 14 deletions(-) create mode 100644 src/utils/lm-studio.test.ts create mode 100644 src/utils/lm-studio.ts diff --git a/bun.lock b/bun.lock index 35abf08a6..78ce9100b 100644 --- a/bun.lock +++ b/bun.lock @@ -917,7 +917,7 @@ "leven": ["leven@3.1.0", "", {}, "sha512-qsda+H8jTaUaN/x5vzW2rzc+8Rw4TAQ/4KjB46IwK5VH+IlVeeeje/EoZRpiXvIqjFgK84QffqPztGI3VBLG1A=="], - "libsignal": ["@whiskeysockets/libsignal-node@github:whiskeysockets/libsignal-node#1c30d7d", { "dependencies": { "curve25519-js": "^0.0.4", "protobufjs": "6.8.8" } }, "WhiskeySockets-libsignal-node-1c30d7d"], + "libsignal": ["@whiskeysockets/libsignal-node@github:whiskeysockets/libsignal-node#1c30d7d", { "dependencies": { "curve25519-js": "^0.0.4", "protobufjs": "6.8.8" } }, "WhiskeySockets-libsignal-node-1c30d7d", "sha512-5q4/OuDQaMYx3RpDqMqS3WYyqjrsSMpU8ipQZtpYnm5l6DwNoLV9oIYMDK0NILKW+tyk3tVCIA11BMYQ+A1+GA=="], "lines-and-columns": ["lines-and-columns@1.2.4", "", {}, "sha512-7ylylesZQ/PV29jhEDl3Ufjo6ZX7gCqJr5F7PKrqc93v7fzSymt1BpwEU8nAUXs8qzzvqhbjhK5QZg6Mt/HkBg=="], diff --git a/env.example b/env.example index 5ea4b4061..842097d76 100644 --- a/env.example +++ b/env.example @@ -10,6 +10,10 @@ DEEPSEEK_API_KEY=your-deepseek-api-key # Ollama (Local LLM) OLLAMA_BASE_URL=http://127.0.0.1:11434 +# LM Studio (Local OpenAI-compatible LLM) +LM_STUDIO_BASE_URL=http://127.0.0.1:1234/v1 +LM_STUDIO_MODEL=your-lm-studio-model-id + # Persistent memory embeddings reuse existing model API keys. # Priority: OpenAI (OPENAI_API_KEY) -> Gemini (GOOGLE_API_KEY) -> Ollama (OLLAMA_BASE_URL) # No additional memory-specific API keys required. diff --git a/src/agent/agent.ts b/src/agent/agent.ts index 84f99934d..781874dc6 100644 --- a/src/agent/agent.ts +++ b/src/agent/agent.ts @@ -1,6 +1,6 @@ import { AIMessage, AIMessageChunk, SystemMessage, HumanMessage, ToolMessage, type BaseMessage } from '@langchain/core/messages'; import { StructuredToolInterface } from '@langchain/core/tools'; -import { callLlmWithMessages, streamLlmWithMessages } from '../model/llm.js'; +import { callLlmWithMessages, streamLlmWithMessages, DEFAULT_MODEL } from '../model/llm.js'; import { getTools, getToolConcurrencyMap } from '../tools/registry.js'; import { buildSystemPrompt, loadSoulDocument, loadRulesDocument } from './prompts.js'; import { extractTextContent, hasToolCalls } from '../utils/ai-message.js'; @@ -20,7 +20,6 @@ import { runMemoryFlush, shouldRunMemoryFlush } from '../memory/flush.js'; import { resolveProvider } from '../providers.js'; -const DEFAULT_MODEL = 'gpt-5.5'; const DEFAULT_MAX_ITERATIONS = 10; const MAX_OVERFLOW_RETRIES = 2; const OVERFLOW_KEEP_ROUNDS = 3; diff --git a/src/components/select-list.ts b/src/components/select-list.ts index 780948742..0ad18c9d0 100644 --- a/src/components/select-list.ts +++ b/src/components/select-list.ts @@ -29,6 +29,14 @@ class EmptyModelSelector extends Container { new Text(theme.muted('Make sure Ollama is running and you have models downloaded.'), 0, 0), ); } + if (providerId === 'lmstudio') { + this.addChild( + new Text(theme.muted('Make sure LM Studio is running and exposing its OpenAI-compatible API.'), 0, 0), + ); + this.addChild( + new Text(theme.muted('You can also preconfigure a default model via LM_STUDIO_MODEL.'), 0, 0), + ); + } this.addChild(new Text(theme.muted('esc to go back'), 0, 0)); } diff --git a/src/controllers/model-selection.ts b/src/controllers/model-selection.ts index 422a41d56..0ee746036 100644 --- a/src/controllers/model-selection.ts +++ b/src/controllers/model-selection.ts @@ -10,6 +10,7 @@ import { type Model, } from '../utils/model.js'; import { getOllamaModels } from '../utils/ollama.js'; +import { getLmStudioModels } from '../utils/lm-studio.js'; import { DEFAULT_MODEL, DEFAULT_PROVIDER } from '../model/llm.js'; import { InMemoryChatHistory } from '../utils/in-memory-chat-history.js'; @@ -109,6 +110,14 @@ export class ModelSelectionController { return; } + if (providerId === 'lmstudio') { + const lmStudioModelIds = await getLmStudioModels(); + this.pendingModelsValue = lmStudioModelIds.map((id) => ({ id, displayName: id })); + this.appStateValue = 'model_select'; + this.emitChange(); + return; + } + this.pendingModelsValue = getModelsForProvider(providerId); this.appStateValue = 'model_select'; this.emitChange(); @@ -124,8 +133,8 @@ export class ModelSelectionController { return; } - if (this.pendingProviderValue === 'ollama') { - this.completeModelSwitch(this.pendingProviderValue, `ollama:${modelId}`); + if (this.pendingProviderValue === 'ollama' || this.pendingProviderValue === 'lmstudio') { + this.completeModelSwitch(this.pendingProviderValue, `${this.pendingProviderValue}:${modelId}`); return; } diff --git a/src/cron/executor.ts b/src/cron/executor.ts index 295560240..54b82cd93 100644 --- a/src/cron/executor.ts +++ b/src/cron/executor.ts @@ -10,6 +10,7 @@ import { resolveSessionStorePath, loadSessionStore, type SessionEntry } from '.. import { cleanMarkdownForWhatsApp } from '../gateway/utils.js'; import { getSetting } from '../utils/config.js'; import { dexterPath } from '../utils/paths.js'; +import { DEFAULT_MODEL, DEFAULT_PROVIDER } from '../model/llm.js'; import { saveCronStore } from './store.js'; import { computeNextRunAtMs } from './schedule.js'; import type { ActiveHours, CronJob, CronStore } from './types.js'; @@ -125,8 +126,8 @@ export async function executeCronJob( } // 3. Resolve model - const model = job.payload.model ?? (getSetting('modelId', 'gpt-5.5') as string); - const modelProvider = job.payload.modelProvider ?? (getSetting('provider', 'openai') as string); + const model = job.payload.model ?? (getSetting('modelId', DEFAULT_MODEL) as string); + const modelProvider = job.payload.modelProvider ?? (getSetting('provider', DEFAULT_PROVIDER) as string); // 4. Build query let query = `[CRON JOB: ${job.name}]\n\n${job.payload.message}`; diff --git a/src/gateway/gateway.ts b/src/gateway/gateway.ts index 9ca9ac6c9..a244ba874 100644 --- a/src/gateway/gateway.ts +++ b/src/gateway/gateway.ts @@ -25,6 +25,7 @@ import type { GroupContext } from '../agent/prompts.js'; import { appendFileSync } from 'node:fs'; import { dexterPath } from '../utils/paths.js'; import { getSetting } from '../utils/config.js'; +import { DEFAULT_MODEL, DEFAULT_PROVIDER } from '../model/llm.js'; const LOG_PATH = dexterPath('gateway-debug.log'); function debugLog(msg: string) { @@ -157,8 +158,8 @@ async function handleInbound(cfg: GatewayConfig, inbound: WhatsAppInboundMessage } console.log(`Processing message with agent...`); - const model = getSetting('modelId', 'gpt-5.5') as string; - const modelProvider = getSetting('provider', 'openai') as string; + const model = getSetting('modelId', DEFAULT_MODEL) as string; + const modelProvider = getSetting('provider', DEFAULT_PROVIDER) as string; // If agent is already running for this session, enqueue for mid-run injection if (isSessionRunning(route.sessionKey)) { @@ -238,4 +239,3 @@ export async function startGateway(params: { configPath?: string } = {}): Promis snapshot: () => manager.getSnapshot(), }; } - diff --git a/src/model/llm.ts b/src/model/llm.ts index 42fa8bc8b..23263f1b7 100644 --- a/src/model/llm.ts +++ b/src/model/llm.ts @@ -15,8 +15,28 @@ import { logger } from '@/utils'; import { classifyError, isNonRetryableError } from '@/utils/errors'; import { resolveProvider, getProviderById } from '@/providers'; -export const DEFAULT_PROVIDER = 'openai'; -export const DEFAULT_MODEL = 'gpt-5.5'; +/** + * 解析 LM Studio 的基础地址,默认指向本地 OpenAI-compatible 端点。 + */ +function getLmStudioBaseUrl(): string { + return process.env.LM_STUDIO_BASE_URL || 'http://127.0.0.1:1234/v1'; +} + +/** + * 读取 LM Studio 配置的默认模型,并补齐内部使用的 provider 前缀。 + */ +function getDefaultLmStudioModelId(): string | null { + const model = process.env.LM_STUDIO_MODEL?.trim(); + if (!model) { + return null; + } + return `lmstudio:${model}`; +} + +const DEFAULT_LM_STUDIO_MODEL = getDefaultLmStudioModelId(); + +export const DEFAULT_PROVIDER = DEFAULT_LM_STUDIO_MODEL ? 'lmstudio' : 'openai'; +export const DEFAULT_MODEL = DEFAULT_LM_STUDIO_MODEL ?? 'gpt-5.5'; /** * Gets the fast model variant for the given provider. @@ -132,6 +152,15 @@ const MODEL_FACTORIES: Record = { ...opts, ...(process.env.OLLAMA_BASE_URL ? { baseUrl: process.env.OLLAMA_BASE_URL } : {}), }), + lmstudio: (name, opts) => + new ChatOpenAI({ + model: name.replace(/^lmstudio:/, ''), + ...opts, + apiKey: process.env.LM_STUDIO_API_KEY || 'lm-studio', + configuration: { + baseURL: getLmStudioBaseUrl(), + }, + }), }; const DEFAULT_FACTORY: ModelFactory = (name, opts) => diff --git a/src/providers.ts b/src/providers.ts index 467b1b054..75f3d969b 100644 --- a/src/providers.ts +++ b/src/providers.ts @@ -81,6 +81,12 @@ export const PROVIDERS: ProviderDef[] = [ modelPrefix: 'ollama:', contextWindow: 128_000, }, + { + id: 'lmstudio', + displayName: 'LM Studio', + modelPrefix: 'lmstudio:', + contextWindow: 128_000, + }, ]; const defaultProvider = PROVIDERS.find((p) => p.id === 'openai')!; diff --git a/src/utils/lm-studio.test.ts b/src/utils/lm-studio.test.ts new file mode 100644 index 000000000..ef947d124 --- /dev/null +++ b/src/utils/lm-studio.test.ts @@ -0,0 +1,82 @@ +import { afterEach, beforeEach, describe, expect, mock, test } from 'bun:test'; +import { resolveProvider } from '../providers.js'; +import { getDefaultModelForProvider, getModelDisplayName, getModelsForProvider } from './model.js'; +import { getLmStudioModels } from './lm-studio.js'; + +const originalFetch = globalThis.fetch; +const originalBaseUrl = process.env.LM_STUDIO_BASE_URL; +const originalModel = process.env.LM_STUDIO_MODEL; + +/** + * 统一恢复 LM Studio 相关环境变量,避免测试之间相互污染。 + */ +function restoreLmStudioEnv() { + if (originalBaseUrl === undefined) { + delete process.env.LM_STUDIO_BASE_URL; + } else { + process.env.LM_STUDIO_BASE_URL = originalBaseUrl; + } + + if (originalModel === undefined) { + delete process.env.LM_STUDIO_MODEL; + } else { + process.env.LM_STUDIO_MODEL = originalModel; + } +} + +describe('LM Studio utilities', () => { + beforeEach(() => { + process.env.LM_STUDIO_BASE_URL = 'http://127.0.0.1:1234/v1'; + delete process.env.LM_STUDIO_MODEL; + }); + + afterEach(() => { + globalThis.fetch = originalFetch; + restoreLmStudioEnv(); + mock.restore(); + }); + + test('resolves the lmstudio provider from its model prefix', () => { + expect(resolveProvider('lmstudio:qwen/qwen3-8b').id).toBe('lmstudio'); + expect(getModelDisplayName('lmstudio:qwen/qwen3-8b')).toBe('qwen/qwen3-8b'); + }); + + test('exposes the configured LM Studio model in provider model lists', () => { + process.env.LM_STUDIO_MODEL = 'qwen/qwen3-8b'; + + expect(getModelsForProvider('lmstudio')).toEqual([ + { id: 'qwen/qwen3-8b', displayName: 'qwen/qwen3-8b' }, + ]); + }); + + test('qualifies the default LM Studio model with its provider prefix', () => { + process.env.LM_STUDIO_MODEL = 'qwen/qwen3-8b'; + + expect(getDefaultModelForProvider('lmstudio')).toBe('lmstudio:qwen/qwen3-8b'); + }); + + test('returns models from the LM Studio API when available', async () => { + globalThis.fetch = mock(async () => + new Response( + JSON.stringify({ + data: [{ id: 'qwen/qwen3-8b' }, { id: 'deepseek/deepseek-r1-0528-qwen3-8b' }], + }), + { status: 200 }, + ), + ) as unknown as typeof fetch; + + await expect(getLmStudioModels()).resolves.toEqual([ + 'qwen/qwen3-8b', + 'deepseek/deepseek-r1-0528-qwen3-8b', + ]); + }); + + test('falls back to LM_STUDIO_MODEL when the API is unreachable', async () => { + process.env.LM_STUDIO_MODEL = 'qwen/qwen3-8b'; + globalThis.fetch = mock(async () => { + throw new Error('connection refused'); + }) as unknown as typeof fetch; + + await expect(getLmStudioModels()).resolves.toEqual(['qwen/qwen3-8b']); + }); +}); diff --git a/src/utils/lm-studio.ts b/src/utils/lm-studio.ts new file mode 100644 index 000000000..79c87ff38 --- /dev/null +++ b/src/utils/lm-studio.ts @@ -0,0 +1,54 @@ +/** + * LM Studio OpenAI-compatible API utilities. + */ + +interface LmStudioModel { + id?: string; +} + +interface LmStudioModelsResponse { + data?: LmStudioModel[]; +} + +/** + * 解析 LM Studio 的模型列表接口地址。 + */ +function getLmStudioBaseUrl(): string { + return process.env.LM_STUDIO_BASE_URL || 'http://127.0.0.1:1234/v1'; +} + +/** + * 返回环境变量中配置的 LM Studio 默认模型。 + */ +function getConfiguredLmStudioModel(): string | null { + const model = process.env.LM_STUDIO_MODEL?.trim(); + return model || null; +} + +/** + * 从 LM Studio 拉取可用模型;当服务不可达时回退到环境变量中的默认模型。 + */ +export async function getLmStudioModels(): Promise { + const configuredModel = getConfiguredLmStudioModel(); + + try { + const response = await fetch(`${getLmStudioBaseUrl()}/models`); + + if (!response.ok) { + return configuredModel ? [configuredModel] : []; + } + + const data = (await response.json()) as LmStudioModelsResponse; + const models = (data.data ?? []) + .map((model) => model.id?.trim()) + .filter((model): model is string => Boolean(model)); + + if (models.length > 0) { + return models; + } + + return configuredModel ? [configuredModel] : []; + } catch { + return configuredModel ? [configuredModel] : []; + } +} diff --git a/src/utils/model.ts b/src/utils/model.ts index e41baa605..7a4a2fe63 100644 --- a/src/utils/model.ts +++ b/src/utils/model.ts @@ -11,6 +11,33 @@ interface Provider { models: Model[]; } +/** + * 为需要 provider 前缀的模型补齐内部模型 ID,确保底层路由可以识别。 + */ +function qualifyModelId(providerId: string, modelId: string): string { + if (providerId === 'lmstudio' && !modelId.startsWith('lmstudio:')) { + return `lmstudio:${modelId}`; + } + if (providerId === 'ollama' && !modelId.startsWith('ollama:')) { + return `ollama:${modelId}`; + } + if (providerId === 'openrouter' && !modelId.startsWith('openrouter:')) { + return `openrouter:${modelId}`; + } + return modelId; +} + +/** + * 读取 LM Studio 在环境变量中配置的默认模型,供本地 provider 直接复用。 + */ +function getLmStudioConfiguredModel(): Model[] { + const modelId = process.env.LM_STUDIO_MODEL?.trim(); + if (!modelId) { + return []; + } + return [{ id: modelId, displayName: modelId }]; +} + const PROVIDER_MODELS: Record = { openai: [ { id: 'gpt-5.5', displayName: 'GPT 5.5' }, @@ -42,6 +69,9 @@ export const PROVIDERS: Provider[] = PROVIDER_DEFS.map((provider) => ({ })); export function getModelsForProvider(providerId: string): Model[] { + if (providerId === 'lmstudio') { + return getLmStudioConfiguredModel(); + } const provider = PROVIDERS.find((entry) => entry.providerId === providerId); return provider?.models ?? []; } @@ -52,11 +82,12 @@ export function getModelIdsForProvider(providerId: string): string[] { export function getDefaultModelForProvider(providerId: string): string | undefined { const models = getModelsForProvider(providerId); - return models[0]?.id; + const modelId = models[0]?.id; + return modelId ? qualifyModelId(providerId, modelId) : undefined; } export function getModelDisplayName(modelId: string): string { - const normalizedId = modelId.replace(/^(ollama|openrouter):/, ''); + const normalizedId = modelId.replace(/^(ollama|openrouter|lmstudio):/, ''); for (const provider of PROVIDERS) { const model = provider.models.find((entry) => entry.id === normalizedId || entry.id === modelId);