Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion server/api/drizzle/0031_add_wiki_compose_sessions.sql
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ END $$;
-- Add the user_id foreign key on "wiki_compose_sessions". The exception
-- handler swallows duplicate_object, so re-running this block when the
-- constraint already exists is a no-op instead of a failure.
DO $$ BEGIN
  ALTER TABLE "wiki_compose_sessions"
    ADD CONSTRAINT "wiki_compose_sessions_user_id_users_id_fk"
    -- Exactly one FOREIGN KEY clause: it references "user"("id"). The
    -- constraint name keeps its historical "users" suffix so the
    -- identifier itself does not change.
    FOREIGN KEY ("user_id") REFERENCES "user"("id") ON DELETE cascade ON UPDATE no action;
EXCEPTION
  WHEN duplicate_object THEN NULL;
END $$;
Expand Down
2 changes: 1 addition & 1 deletion server/api/drizzle/0032_add_user_ai_credentials.sql
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ CREATE TABLE IF NOT EXISTS "user_ai_credentials" (
-- Add the user_id foreign key on "user_ai_credentials". The exception
-- handler swallows duplicate_object, so re-running this block when the
-- constraint already exists is a no-op instead of a failure.
DO $$ BEGIN
  ALTER TABLE "user_ai_credentials"
    ADD CONSTRAINT "user_ai_credentials_user_id_users_id_fk"
    -- Exactly one FOREIGN KEY clause: it references "user"("id"). The
    -- constraint name keeps its historical "users" suffix so the
    -- identifier itself does not change.
    FOREIGN KEY ("user_id") REFERENCES "user"("id") ON DELETE cascade ON UPDATE no action;
EXCEPTION
  WHEN duplicate_object THEN NULL;
END $$;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,8 @@ import { describe, expect, it, vi, beforeEach } from "vitest";
import { HTTPException } from "hono/http-exception";
import { assertComposeBackendReady } from "../../../agents/core/composeBackendValidation.js";

const mockValidateModelAccess = vi.fn();
const mockGetUserAiCredentialPlaintext = vi.fn();

vi.mock("../../../services/usageService.js", () => ({
validateModelAccess: (...args: unknown[]) => mockValidateModelAccess(...args),
}));

vi.mock("../../../services/userAiCredentialService.js", () => ({
getUserAiCredentialPlaintext: (...args: unknown[]) => mockGetUserAiCredentialPlaintext(...args),
}));
Expand All @@ -18,12 +13,6 @@ describe("assertComposeBackendReady", () => {

beforeEach(() => {
vi.clearAllMocks();
mockValidateModelAccess.mockResolvedValue({
provider: "anthropic",
apiModelId: "claude-3-5-haiku",
inputCostUnits: 1,
outputCostUnits: 2,
});
mockGetUserAiCredentialPlaintext.mockResolvedValue("sk-user");
});

Expand All @@ -35,25 +24,29 @@ describe("assertComposeBackendReady", () => {
tier: "free",
db,
});
expect(mockValidateModelAccess).not.toHaveBeenCalled();
expect(mockGetUserAiCredentialPlaintext).not.toHaveBeenCalled();
});

it("throws 400 when model provider mismatches BYOK backend", async () => {
mockValidateModelAccess.mockResolvedValue({
provider: "openai",
apiModelId: "gpt-4o-mini",
inputCostUnits: 1,
outputCostUnits: 2,
it("allows BYOK backend without static env model provider mismatch (#972)", async () => {
await assertComposeBackendReady({
backend: "user_openai",
graphId: "wiki-compose-research",
userId: "u1",
tier: "free",
db,
});
await expect(
assertComposeBackendReady({
backend: "user_anthropic",
graphId: "wiki-compose-research",
userId: "u1",
tier: "free",
db,
}),
).rejects.toMatchObject({ status: 400 });
expect(mockGetUserAiCredentialPlaintext).toHaveBeenCalledWith("u1", "openai", db);
});

it("skips credential check for model-less graphs (wiki-maintenance)", async () => {
await assertComposeBackendReady({
backend: "user_anthropic",
graphId: "wiki-maintenance",
userId: "u1",
tier: "free",
db,
});
expect(mockGetUserAiCredentialPlaintext).not.toHaveBeenCalled();
});

it("throws 400 when credential is missing", async () => {
Expand Down
31 changes: 30 additions & 1 deletion server/api/src/__tests__/agents/core/llm/zediChatModel.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ describe("ZediChatModel._streamResponseChunks", () => {
});

it("uses 'incomplete' finishReason when the provider stream ends without done=true", async () => {
const { db } = asDb([undefined, undefined]);
const { db, chains } = asDb([undefined, undefined]);
const model = new ZediChatModel({
provider: "openai",
apiKey: "k",
Expand All @@ -250,6 +250,35 @@ describe("ZediChatModel._streamResponseChunks", () => {
response_metadata?: { finishReason?: string };
};
expect(last.response_metadata?.finishReason).toBe("incomplete");
expect(chains.length).toBe(0);
});

it("does not record usage when the provider stream throws", async () => {
const { db, chains } = asDb([undefined, undefined]);
const model = new ZediChatModel({
provider: "openai",
apiKey: "k",
apiModelId: "m",
modelRowId: "m",
inputCostUnits: 1,
outputCostUnits: 2,
userId: "u",
tier: "free",
db,
feature: "x",
streamProvider: async function* () {
yield { content: "partial" };
throw new Error("provider 502");
},
});

const stream = await model.stream([new HumanMessage("hi")]);
await expect(async () => {
for await (const _chunk of stream) {
/* drain */
}
}).rejects.toThrow("provider 502");
expect(chains.length).toBe(0);
});
});

Expand Down
11 changes: 11 additions & 0 deletions server/api/src/__tests__/agents/runner/sseMapper.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,17 @@ describe("mapLangGraphEvent", () => {
]);
});

it("maps on_custom_event compose_phase conflict to a typed compose_phase event", () => {
const ev: LangGraphRuntimeEvent = {
event: "on_custom_event",
name: "compose_phase",
data: { phase: "conflict", status: "entered" },
};
expect(mapLangGraphEvent(ev)).toEqual([
{ type: "compose_phase", phase: "conflict", status: "entered" },
]);
});

it("drops compose_phase with an unknown phase value", () => {
const ev: LangGraphRuntimeEvent = {
event: "on_custom_event",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,5 +81,6 @@ describe("projectComposeStateValues", () => {
});
expect(projection.completedMarkdown).toBe("## A\n\nBody");
expect(projection.draftedSections).toHaveLength(1);
expect(projection.phase).toBe("completed");
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@ export function getPostgresCheckpointer(schema: string = "public"): PostgresSave
/**
 * Run the Postgres checkpointer's `setup()` once and share the resulting promise.
 *
 * The first call obtains a saver via `getPostgresCheckpointer(schema)` and
 * caches the in-flight `setup()` promise in the module-level `setupOnce`;
 * later calls return that cached promise. If setup rejects, the cache is
 * cleared so the next call retries instead of replaying the same failure.
 *
 * @param schema - Schema name forwarded to `getPostgresCheckpointer`. It is
 *   only consulted while no setup promise is cached.
 * @returns The shared setup promise.
 * @throws Re-throws whatever `saver.setup()` rejected with.
 */
export async function ensurePostgresCheckpointerSetup(schema: string = "public"): Promise<void> {
  if (!setupOnce) {
    const saver = getPostgresCheckpointer(schema);
    // Start setup exactly once. A second bare `saver.setup()` call would leave
    // an orphaned promise whose rejection nobody observes.
    setupOnce = saver.setup().catch((error: unknown) => {
      // Drop the failed promise so a later call can attempt setup again.
      setupOnce = null;
      throw error;
    });
  }
  return setupOnce;
}
Expand Down
31 changes: 16 additions & 15 deletions server/api/src/agents/core/composeBackendValidation.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
/**
* Pre-flight BYOK checks for Wiki Compose session creation (#951).
* Wiki Compose セッション作成前の BYOK 事前チェック(#951)。
*
* Static env model ids are not validated here — provider matching is enforced at
* runtime via {@link resolveComposeModelId}. This function only verifies that
* the user has a stored credential when the graph will call an LLM.
*
* 静的 env モデル id との provider 照合は行わない。実行時の
* `resolveComposeModelId` が provider 整合を担保する。本関数は LLM を呼ぶ
* グラフで credential が存在するかだけを確認する。
*/
import { HTTPException } from "hono/http-exception";
import { validateModelAccess } from "../../services/usageService.js";
import type { Database, UserTier } from "../../types/index.js";
import { getComposeModelIdsForGraph } from "./composeModelConfig.js";
import {
Expand All @@ -14,8 +21,8 @@ import {
import { getUserAiCredentialPlaintext } from "../../services/userAiCredentialService.js";

/**
* Ensure a BYOK backend has stored credentials when the target graph uses LLMs.
* LLM を使うグラフ向け BYOK backend credential が存在するか検証する。
*/
export async function assertComposeBackendReady(input: {
backend: ExecutionBackend;
Expand All @@ -26,18 +33,12 @@ export async function assertComposeBackendReady(input: {
}): Promise<void> {
if (!isUserByokBackend(input.backend)) return;

const expectedProvider = backendToCredentialProvider(input.backend);
const modelIds = getComposeModelIdsForGraph(input.graphId);
// Model-less graphs (e.g. wiki-maintenance) never call `createZediChatModel`.
// LLM を呼ばないグラフ(wiki-maintenance 等)は credential 不要。
if (modelIds.length === 0) return;
Comment thread
coderabbitai[bot] marked this conversation as resolved.

for (const modelId of modelIds) {
const modelInfo = await validateModelAccess(modelId, input.tier, input.db);
if (modelInfo.provider !== expectedProvider) {
throw new HTTPException(400, {
message: `Backend "${input.backend}" does not match compose model "${modelId}" (provider ${modelInfo.provider})`,
});
}
}

const expectedProvider = backendToCredentialProvider(input.backend);
const key = await getUserAiCredentialPlaintext(input.userId, expectedProvider, input.db);
if (!key?.trim()) {
throw new HTTPException(400, {
Expand Down
56 changes: 40 additions & 16 deletions server/api/src/agents/core/llm/zediChatModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,15 @@ import type { BaseMessage } from "@langchain/core/messages";
import { AIMessage, AIMessageChunk } from "@langchain/core/messages";
import { ChatGenerationChunk, type ChatResult } from "@langchain/core/outputs";
import { callProvider, streamProvider } from "../../../services/aiProviders.js";
import { calculateCost } from "../../../services/usageService.js";
import type {
AIChatOptions,
AIProviderType,
ApiMode,
Database,
UserTier,
} from "../../../types/index.js";
import { recordZediUsage, toZediMessages } from "./usageCallback.js";
import { recordZediUsage, toZediMessages, type RecordZediUsageResult } from "./usageCallback.js";

/**
* `callProvider` / `streamProvider` のインジェクション型。テストでは fake を
Expand Down Expand Up @@ -290,9 +291,9 @@ export class ZediChatModel extends BaseChatModel<BaseChatModelCallOptions> {
text: chunk.content,
message: new AIMessageChunk({ content: chunk.content }),
});
// LangChain callback / SSE 向けにトークン delta を先に流す。
// Surface incremental tokens to LangChain callback consumers so any
// `streamEvents` listener (e.g. SSE mapper) sees deltas before usage.
await runManager?.handleLLMNewToken(
chunk.content,
undefined,
Expand All @@ -312,26 +313,49 @@ export class ZediChatModel extends BaseChatModel<BaseChatModelCallOptions> {
}
}

// Streaming providers in this codebase do not return token counts; mirror
// chat.ts and estimate. Pre-encoded message length is what the user "paid"
// for, response length is what they received.
const promptLength = zediMessages.reduce((sum, m) => sum + m.content.length, 0);
const inputTokens = Math.ceil(promptLength / 4);
const outputTokens = Math.ceil(accumulated.length / 4);

const usage = await recordZediUsage({
db: this.db,
userId: this.userId,
modelId: this.modelRowId,
feature: this.feature,
usage: { inputTokens, outputTokens },
inputCostUnits: this.inputCostUnits,
outputCostUnits: this.outputCostUnits,
apiMode: this.apiMode,
});
let usage: RecordZediUsageResult;
if (done) {
try {
usage = await recordZediUsage({
db: this.db,
userId: this.userId,
modelId: this.modelRowId,
feature: this.feature,
usage: { inputTokens, outputTokens },
inputCostUnits: this.inputCostUnits,
outputCostUnits: this.outputCostUnits,
apiMode: this.apiMode,
});
} catch (err) {
// Billing failure must not mask a successful stream.
// 課金記録失敗で成功ストリームを潰さない。
console.error("Failed to record streaming usage", err);
usage = {
inputTokens,
outputTokens,
costUnits:
this.apiMode === "user_key"
? 0
: calculateCost(
{ inputTokens, outputTokens },
this.inputCostUnits,
this.outputCostUnits,
),
};
}
} else {
// Stream ended without `done` — expose metadata only, no DB billing (chat.ts 同様).
// `done` 未到達で終了した incomplete ストリームは DB 課金しない。
usage = { inputTokens, outputTokens, costUnits: 0 };
}

// Final chunk surfaces aggregate usage so downstream consumers (sseMapper,
// LangChain callbacks) can read totals from a single ChatGenerationChunk.
// 集計 usage を最終チャンクで返し、sseMapper 等が 1 箇所から読めるようにする。
yield new ChatGenerationChunk({
text: "",
message: new AIMessageChunk({
Expand Down
21 changes: 12 additions & 9 deletions server/api/src/agents/core/tools/resolveWebSearchModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,16 @@ export async function resolveWebSearchModelId(
)
.orderBy(asc(aiModels.inputCostUnits), asc(aiModels.outputCostUnits));

// Prefer OpenAI when costs tie, since `useWebSearch` (chat completions
// `web_search_options`) is well-tested in `aiProviders.ts`; Google's
// `googleSearch` tool is also supported but requires the `tools` payload.
// 同コストなら OpenAI を優先(`useWebSearch` 経路が安定)。
const openai = rows.find((r) => r.provider === "openai");
if (openai) return openai.id;
const google = rows.find((r) => r.provider === "google");
if (google) return google.id;
return null;
if (rows.length === 0) return null;

// Prefer OpenAI only among cheapest rows (cost tie-break), since `useWebSearch`
// is well-tested in `aiProviders.ts`.
// 最安行の中でのみ OpenAI を優先(`aiProviders.ts` の useWebSearch 経路が安定)。
const cheapestInput = rows[0].inputCostUnits;
const cheapestOutput = rows[0].outputCostUnits;
const cheapest = rows.filter(
(r) => r.inputCostUnits === cheapestInput && r.outputCostUnits === cheapestOutput,
);
const preferred = cheapest.find((r) => r.provider === "openai") ?? cheapest[0];
return preferred?.id ?? null;
}
2 changes: 2 additions & 0 deletions server/api/src/agents/core/types/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,7 @@ export {
type SseResearchIterationEvent,
type SseResearchEvaluationEvent,
type SseResearchBatchEvent,
type SseComposePhaseEvent,
type SseComposeSectionEvent,
SSE_EVENT_NAMES,
} from "./sseEvents.js";
1 change: 1 addition & 0 deletions server/api/src/agents/runner/sseMapper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ function mapComposePhase(data: Record<string, unknown>): SseComposePhaseEvent[]
if (
phase !== "brief" &&
phase !== "research" &&
phase !== "conflict" &&
phase !== "structure" &&
phase !== "draft" &&
phase !== "completed"
Expand Down
22 changes: 17 additions & 5 deletions server/api/src/agents/subgraphs/research/resumeSchema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,23 @@ import { z } from "zod";
* required (empty array means "reject all"); `rejectedSourceIds` defaults to
* the empty array; `note` is free-form metadata.
*/
export const researchResumeSchema = z
  .object({
    // Required (no default): callers must state approvals explicitly; an
    // empty array means "reject all".
    approvedSourceIds: z.array(z.string().min(1)),
    rejectedSourceIds: z.array(z.string().min(1)).optional().default([]),
    note: z.string().optional(),
  })
  .superRefine((value, ctx) => {
    // A source id may not appear in both lists. Approved ids are de-duplicated
    // first so a repeated id is reported once in the message.
    const rejected = new Set(value.rejectedSourceIds);
    const overlap = [...new Set(value.approvedSourceIds)].filter((id) => rejected.has(id));
    if (overlap.length > 0) {
      ctx.addIssue({
        code: "custom",
        path: ["rejectedSourceIds"],
        message: `approvedSourceIds and rejectedSourceIds must not overlap: ${overlap.join(", ")}`,
      });
    }
  });

/** Inferred TS type for the parsed resume payload. */
export type ResearchResumeParsed = z.infer<typeof researchResumeSchema>;
Loading
Loading