Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion server/api/drizzle/0031_add_wiki_compose_sessions.sql
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ END $$;
DO $$ BEGIN
ALTER TABLE "wiki_compose_sessions"
ADD CONSTRAINT "wiki_compose_sessions_user_id_users_id_fk"
FOREIGN KEY ("user_id") REFERENCES "users"("id") ON DELETE cascade ON UPDATE no action;
FOREIGN KEY ("user_id") REFERENCES "user"("id") ON DELETE cascade ON UPDATE no action;
EXCEPTION
WHEN duplicate_object THEN NULL;
END $$;
Expand Down
2 changes: 1 addition & 1 deletion server/api/drizzle/0032_add_user_ai_credentials.sql
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ CREATE TABLE IF NOT EXISTS "user_ai_credentials" (
DO $$ BEGIN
ALTER TABLE "user_ai_credentials"
ADD CONSTRAINT "user_ai_credentials_user_id_users_id_fk"
FOREIGN KEY ("user_id") REFERENCES "users"("id") ON DELETE cascade ON UPDATE no action;
FOREIGN KEY ("user_id") REFERENCES "user"("id") ON DELETE cascade ON UPDATE no action;
EXCEPTION
WHEN duplicate_object THEN NULL;
END $$;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,8 @@ import { describe, expect, it, vi, beforeEach } from "vitest";
import { HTTPException } from "hono/http-exception";
import { assertComposeBackendReady } from "../../../agents/core/composeBackendValidation.js";

const mockValidateModelAccess = vi.fn();
const mockGetUserAiCredentialPlaintext = vi.fn();

vi.mock("../../../services/usageService.js", () => ({
validateModelAccess: (...args: unknown[]) => mockValidateModelAccess(...args),
}));

vi.mock("../../../services/userAiCredentialService.js", () => ({
getUserAiCredentialPlaintext: (...args: unknown[]) => mockGetUserAiCredentialPlaintext(...args),
}));
Expand All @@ -18,12 +13,6 @@ describe("assertComposeBackendReady", () => {

beforeEach(() => {
vi.clearAllMocks();
mockValidateModelAccess.mockResolvedValue({
provider: "anthropic",
apiModelId: "claude-3-5-haiku",
inputCostUnits: 1,
outputCostUnits: 2,
});
mockGetUserAiCredentialPlaintext.mockResolvedValue("sk-user");
});

Expand All @@ -35,25 +24,29 @@ describe("assertComposeBackendReady", () => {
tier: "free",
db,
});
expect(mockValidateModelAccess).not.toHaveBeenCalled();
expect(mockGetUserAiCredentialPlaintext).not.toHaveBeenCalled();
});

it("throws 400 when model provider mismatches BYOK backend", async () => {
mockValidateModelAccess.mockResolvedValue({
provider: "openai",
apiModelId: "gpt-4o-mini",
inputCostUnits: 1,
outputCostUnits: 2,
it("allows BYOK backend without static env model provider mismatch (#972)", async () => {
await assertComposeBackendReady({
backend: "user_openai",
graphId: "wiki-compose-research",
userId: "u1",
tier: "free",
db,
});
await expect(
assertComposeBackendReady({
backend: "user_anthropic",
graphId: "wiki-compose-research",
userId: "u1",
tier: "free",
db,
}),
).rejects.toMatchObject({ status: 400 });
expect(mockGetUserAiCredentialPlaintext).toHaveBeenCalledWith("u1", "openai", db);
});

// Graphs without compose models (here: wiki-maintenance) should pass the BYOK
// pre-flight without the credential lookup ever being invoked.
it("skips credential check for model-less graphs (wiki-maintenance)", async () => {
  const modelLessRequest = {
    backend: "user_anthropic",
    graphId: "wiki-maintenance",
    userId: "u1",
    tier: "free",
    db,
  } as const;

  await assertComposeBackendReady(modelLessRequest);

  expect(mockGetUserAiCredentialPlaintext).not.toHaveBeenCalled();
});

it("throws 400 when credential is missing", async () => {
Expand Down
31 changes: 30 additions & 1 deletion server/api/src/__tests__/agents/core/llm/zediChatModel.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ describe("ZediChatModel._streamResponseChunks", () => {
});

it("uses 'incomplete' finishReason when the provider stream ends without done=true", async () => {
const { db } = asDb([undefined, undefined]);
const { db, chains } = asDb([undefined, undefined]);
const model = new ZediChatModel({
provider: "openai",
apiKey: "k",
Expand All @@ -250,6 +250,35 @@ describe("ZediChatModel._streamResponseChunks", () => {
response_metadata?: { finishReason?: string };
};
expect(last.response_metadata?.finishReason).toBe("incomplete");
expect(chains.length).toBe(0);
});

// A provider error raised mid-stream must reach the consumer unchanged, and
// `chains` must stay empty afterwards (presumably the DB calls captured by the
// `asDb` fake — confirm against its definition).
it("does not record usage when the provider stream throws", async () => {
  const { db, chains } = asDb([undefined, undefined]);
  const model = new ZediChatModel({
    provider: "openai",
    apiKey: "k",
    apiModelId: "m",
    modelRowId: "m",
    inputCostUnits: 1,
    outputCostUnits: 2,
    userId: "u",
    tier: "free",
    db,
    feature: "x",
    // Emits one partial chunk, then fails before any terminal chunk is produced.
    streamProvider: async function* () {
      yield { content: "partial" };
      throw new Error("provider 502");
    },
  });

  const stream = await model.stream([new HumanMessage("hi")]);

  // Consume the whole stream so the provider failure surfaces to the caller.
  const drainStream = async () => {
    for await (const _chunk of stream) {
      /* drain */
    }
  };

  await expect(drainStream).rejects.toThrow("provider 502");
  expect(chains.length).toBe(0);
});
});

Expand Down
11 changes: 11 additions & 0 deletions server/api/src/__tests__/agents/runner/sseMapper.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,17 @@ describe("mapLangGraphEvent", () => {
]);
});

// The "conflict" phase must be forwarded as a typed compose_phase event rather
// than being dropped like an unknown phase value.
it("maps on_custom_event compose_phase conflict to a typed compose_phase event", () => {
  const conflictEntered: LangGraphRuntimeEvent = {
    event: "on_custom_event",
    name: "compose_phase",
    data: { phase: "conflict", status: "entered" },
  };

  const mapped = mapLangGraphEvent(conflictEntered);

  expect(mapped).toEqual([{ type: "compose_phase", phase: "conflict", status: "entered" }]);
});

it("drops compose_phase with an unknown phase value", () => {
const ev: LangGraphRuntimeEvent = {
event: "on_custom_event",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,5 +81,6 @@ describe("projectComposeStateValues", () => {
});
expect(projection.completedMarkdown).toBe("## A\n\nBody");
expect(projection.draftedSections).toHaveLength(1);
expect(projection.phase).toBe("completed");
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@ export function getPostgresCheckpointer(schema: string = "public"): PostgresSave
export async function ensurePostgresCheckpointerSetup(schema: string = "public"): Promise<void> {
  // `setupOnce` caches the setup promise so concurrent and later callers share
  // a single `setup()` run instead of each starting their own.
  // NOTE(review): the cache is not keyed by `schema`; a later call with a
  // different schema presumably reuses the first setup — confirm callers only
  // ever use one schema.
  if (!setupOnce) {
    const saver = getPostgresCheckpointer(schema);
    // Start setup exactly once. On failure, clear the cached promise so the
    // next call retries rather than re-awaiting a permanently rejected promise,
    // then rethrow so the current caller still observes the error.
    setupOnce = saver.setup().catch((error: unknown) => {
      setupOnce = null;
      throw error;
    });
  }
  return setupOnce;
}
Expand Down
31 changes: 16 additions & 15 deletions server/api/src/agents/core/composeBackendValidation.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
/**
* Validate BYOK backend against compose graph model providers (#951).
* Compose グラフのモデル provider と BYOK backend の整合を検証する。
* Pre-flight BYOK checks for Wiki Compose session creation (#951).
* Wiki Compose セッション作成前の BYOK 事前チェック(#951)。
*
* Static env model ids are not validated here — provider matching is enforced at
* runtime via {@link resolveComposeModelId}. This function only verifies that
* the user has a stored credential when the graph will call an LLM.
*
* 静的 env モデル id との provider 照合は行わない。実行時の
* `resolveComposeModelId` が provider 整合を担保する。本関数は LLM を呼ぶ
* グラフで credential が存在するかだけを確認する。
*/
import { HTTPException } from "hono/http-exception";
import { validateModelAccess } from "../../services/usageService.js";
import type { Database, UserTier } from "../../types/index.js";
import { getComposeModelIdsForGraph } from "./composeModelConfig.js";
import {
Expand All @@ -14,8 +21,8 @@ import {
import { getUserAiCredentialPlaintext } from "../../services/userAiCredentialService.js";

/**
* Ensure BYOK backend matches configured compose models and credentials exist.
* BYOK backend が compose モデルと credential と整合するか検証する
* Ensure a BYOK backend has stored credentials when the target graph uses LLMs.
* LLM を使うグラフ向け BYOK backend credential が存在するか検証する
*/
export async function assertComposeBackendReady(input: {
backend: ExecutionBackend;
Expand All @@ -26,18 +33,12 @@ export async function assertComposeBackendReady(input: {
}): Promise<void> {
if (!isUserByokBackend(input.backend)) return;

const expectedProvider = backendToCredentialProvider(input.backend);
const modelIds = getComposeModelIdsForGraph(input.graphId);
// Model-less graphs (e.g. wiki-maintenance) never call `createZediChatModel`.
// LLM を呼ばないグラフ(wiki-maintenance 等)は credential 不要。
if (modelIds.length === 0) return;
Comment thread
coderabbitai[bot] marked this conversation as resolved.

for (const modelId of modelIds) {
const modelInfo = await validateModelAccess(modelId, input.tier, input.db);
if (modelInfo.provider !== expectedProvider) {
throw new HTTPException(400, {
message: `Backend "${input.backend}" does not match compose model "${modelId}" (provider ${modelInfo.provider})`,
});
}
}

const expectedProvider = backendToCredentialProvider(input.backend);
const key = await getUserAiCredentialPlaintext(input.userId, expectedProvider, input.db);
if (!key?.trim()) {
throw new HTTPException(400, {
Expand Down
56 changes: 40 additions & 16 deletions server/api/src/agents/core/llm/zediChatModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,15 @@ import type { BaseMessage } from "@langchain/core/messages";
import { AIMessage, AIMessageChunk } from "@langchain/core/messages";
import { ChatGenerationChunk, type ChatResult } from "@langchain/core/outputs";
import { callProvider, streamProvider } from "../../../services/aiProviders.js";
import { calculateCost } from "../../../services/usageService.js";
import type {
AIChatOptions,
AIProviderType,
ApiMode,
Database,
UserTier,
} from "../../../types/index.js";
import { recordZediUsage, toZediMessages } from "./usageCallback.js";
import { recordZediUsage, toZediMessages, type RecordZediUsageResult } from "./usageCallback.js";

/**
* `callProvider` / `streamProvider` のインジェクション型。テストでは fake を
Expand Down Expand Up @@ -290,9 +291,9 @@ export class ZediChatModel extends BaseChatModel<BaseChatModelCallOptions> {
text: chunk.content,
message: new AIMessageChunk({ content: chunk.content }),
});
// LangChain callback / SSE 向けにトークン delta を先に流す。
// Surface incremental tokens to LangChain callback consumers so any
// `streamEvents` listener (e.g. SSE mapper) sees deltas before the
// final usage event.
// `streamEvents` listener (e.g. SSE mapper) sees deltas before usage.
await runManager?.handleLLMNewToken(
chunk.content,
undefined,
Expand All @@ -312,26 +313,49 @@ export class ZediChatModel extends BaseChatModel<BaseChatModelCallOptions> {
}
}

// Streaming providers in this codebase do not return token counts; mirror
// chat.ts and estimate. Pre-encoded message length is what the user "paid"
// for, response length is what they received.
const promptLength = zediMessages.reduce((sum, m) => sum + m.content.length, 0);
const inputTokens = Math.ceil(promptLength / 4);
const outputTokens = Math.ceil(accumulated.length / 4);

const usage = await recordZediUsage({
db: this.db,
userId: this.userId,
modelId: this.modelRowId,
feature: this.feature,
usage: { inputTokens, outputTokens },
inputCostUnits: this.inputCostUnits,
outputCostUnits: this.outputCostUnits,
apiMode: this.apiMode,
});
let usage: RecordZediUsageResult;
if (done) {
try {
usage = await recordZediUsage({
db: this.db,
userId: this.userId,
modelId: this.modelRowId,
feature: this.feature,
usage: { inputTokens, outputTokens },
inputCostUnits: this.inputCostUnits,
outputCostUnits: this.outputCostUnits,
apiMode: this.apiMode,
});
} catch (err) {
// Billing failure must not mask a successful stream.
// 課金記録失敗で成功ストリームを潰さない。
console.error("Failed to record streaming usage", err);
usage = {
inputTokens,
outputTokens,
costUnits:
this.apiMode === "user_key"
? 0
: calculateCost(
{ inputTokens, outputTokens },
this.inputCostUnits,
this.outputCostUnits,
),
};
}
} else {
// Stream ended without `done` — expose metadata only, no DB billing (chat.ts 同様).
// `done` 未到達で終了した incomplete ストリームは DB 課金しない。
usage = { inputTokens, outputTokens, costUnits: 0 };
}

// Final chunk surfaces aggregate usage so downstream consumers (sseMapper,
// LangChain callbacks) can read totals from a single ChatGenerationChunk.
// 集計 usage を最終チャンクで返し、sseMapper 等が 1 箇所から読めるようにする。
yield new ChatGenerationChunk({
text: "",
message: new AIMessageChunk({
Expand Down
24 changes: 15 additions & 9 deletions server/api/src/agents/core/tools/resolveWebSearchModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,19 @@ export async function resolveWebSearchModelId(
)
.orderBy(asc(aiModels.inputCostUnits), asc(aiModels.outputCostUnits));

// Prefer OpenAI when costs tie, since `useWebSearch` (chat completions
// `web_search_options`) is well-tested in `aiProviders.ts`; Google's
// `googleSearch` tool is also supported but requires the `tools` payload.
// 同コストなら OpenAI を優先(`useWebSearch` 経路が安定)。
const openai = rows.find((r) => r.provider === "openai");
if (openai) return openai.id;
const google = rows.find((r) => r.provider === "google");
if (google) return google.id;
return null;
if (rows.length === 0) return null;

const [first] = rows;
if (!first) return null;

// Prefer OpenAI only among cheapest rows (cost tie-break), since `useWebSearch`
// is well-tested in `aiProviders.ts`.
// 最安行の中でのみ OpenAI を優先(`aiProviders.ts` の useWebSearch 経路が安定)。
const cheapestInput = first.inputCostUnits;
const cheapestOutput = first.outputCostUnits;
const cheapest = rows.filter(
(r) => r.inputCostUnits === cheapestInput && r.outputCostUnits === cheapestOutput,
);
const preferred = cheapest.find((r) => r.provider === "openai") ?? cheapest[0];
return preferred?.id ?? null;
}
2 changes: 2 additions & 0 deletions server/api/src/agents/core/types/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,7 @@ export {
type SseResearchIterationEvent,
type SseResearchEvaluationEvent,
type SseResearchBatchEvent,
type SseComposePhaseEvent,
type SseComposeSectionEvent,
SSE_EVENT_NAMES,
} from "./sseEvents.js";
1 change: 1 addition & 0 deletions server/api/src/agents/runner/sseMapper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ function mapComposePhase(data: Record<string, unknown>): SseComposePhaseEvent[]
if (
phase !== "brief" &&
phase !== "research" &&
phase !== "conflict" &&
phase !== "structure" &&
phase !== "draft" &&
phase !== "completed"
Expand Down
22 changes: 17 additions & 5 deletions server/api/src/agents/subgraphs/research/resumeSchema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,23 @@ import { z } from "zod";
* required (empty array means "reject all"); `rejectedSourceIds` defaults to
* the empty array; `note` is free-form metadata.
*/
export const researchResumeSchema = z
  .object({
    // Required, with no default: the caller must state approvals explicitly.
    // An empty array is valid and means "reject all".
    approvedSourceIds: z.array(z.string().min(1)),
    // Omitted input is normalised to an empty array.
    rejectedSourceIds: z.array(z.string().min(1)).optional().default([]),
    // Free-form metadata; not validated beyond being a string.
    note: z.string().optional(),
  })
  .superRefine((value, ctx) => {
    // A source id cannot be both approved and rejected. Collect the offending
    // ids once each so duplicated input ids do not repeat in the message.
    const rejected = new Set(value.rejectedSourceIds);
    const overlap = [...new Set(value.approvedSourceIds.filter((id) => rejected.has(id)))];
    if (overlap.length > 0) {
      ctx.addIssue({
        code: "custom",
        path: ["rejectedSourceIds"],
        message: `approvedSourceIds and rejectedSourceIds must not overlap: ${overlap.join(", ")}`,
      });
    }
  });

/** Inferred TS type for the parsed resume payload. */
export type ResearchResumeParsed = z.infer<typeof researchResumeSchema>;
Loading
Loading