Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion server/api/drizzle/0031_add_wiki_compose_sessions.sql
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ END $$;
DO $$ BEGIN
ALTER TABLE "wiki_compose_sessions"
ADD CONSTRAINT "wiki_compose_sessions_user_id_users_id_fk"
FOREIGN KEY ("user_id") REFERENCES "users"("id") ON DELETE cascade ON UPDATE no action;
FOREIGN KEY ("user_id") REFERENCES "user"("id") ON DELETE cascade ON UPDATE no action;
EXCEPTION
WHEN duplicate_object THEN NULL;
END $$;
Expand Down
2 changes: 1 addition & 1 deletion server/api/drizzle/0032_add_user_ai_credentials.sql
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ CREATE TABLE IF NOT EXISTS "user_ai_credentials" (
DO $$ BEGIN
ALTER TABLE "user_ai_credentials"
ADD CONSTRAINT "user_ai_credentials_user_id_users_id_fk"
FOREIGN KEY ("user_id") REFERENCES "users"("id") ON DELETE cascade ON UPDATE no action;
FOREIGN KEY ("user_id") REFERENCES "user"("id") ON DELETE cascade ON UPDATE no action;
EXCEPTION
WHEN duplicate_object THEN NULL;
END $$;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,8 @@ import { describe, expect, it, vi, beforeEach } from "vitest";
import { HTTPException } from "hono/http-exception";
import { assertComposeBackendReady } from "../../../agents/core/composeBackendValidation.js";

const mockValidateModelAccess = vi.fn();
const mockGetUserAiCredentialPlaintext = vi.fn();

vi.mock("../../../services/usageService.js", () => ({
validateModelAccess: (...args: unknown[]) => mockValidateModelAccess(...args),
}));

vi.mock("../../../services/userAiCredentialService.js", () => ({
getUserAiCredentialPlaintext: (...args: unknown[]) => mockGetUserAiCredentialPlaintext(...args),
}));
Expand All @@ -18,12 +13,6 @@ describe("assertComposeBackendReady", () => {

beforeEach(() => {
vi.clearAllMocks();
mockValidateModelAccess.mockResolvedValue({
provider: "anthropic",
apiModelId: "claude-3-5-haiku",
inputCostUnits: 1,
outputCostUnits: 2,
});
mockGetUserAiCredentialPlaintext.mockResolvedValue("sk-user");
});

Expand All @@ -35,25 +24,29 @@ describe("assertComposeBackendReady", () => {
tier: "free",
db,
});
expect(mockValidateModelAccess).not.toHaveBeenCalled();
expect(mockGetUserAiCredentialPlaintext).not.toHaveBeenCalled();
});

it("throws 400 when model provider mismatches BYOK backend", async () => {
mockValidateModelAccess.mockResolvedValue({
provider: "openai",
apiModelId: "gpt-4o-mini",
inputCostUnits: 1,
outputCostUnits: 2,
it("allows BYOK backend without static env model provider mismatch (#972)", async () => {
await assertComposeBackendReady({
backend: "user_openai",
graphId: "wiki-compose-research",
userId: "u1",
tier: "free",
db,
});
await expect(
assertComposeBackendReady({
backend: "user_anthropic",
graphId: "wiki-compose-research",
userId: "u1",
tier: "free",
db,
}),
).rejects.toMatchObject({ status: 400 });
expect(mockGetUserAiCredentialPlaintext).toHaveBeenCalledWith("u1", "openai", db);
});

it("skips credential check for model-less graphs (wiki-maintenance)", async () => {
await assertComposeBackendReady({
backend: "user_anthropic",
graphId: "wiki-maintenance",
userId: "u1",
tier: "free",
db,
});
expect(mockGetUserAiCredentialPlaintext).not.toHaveBeenCalled();
});

it("throws 400 when credential is missing", async () => {
Expand Down
11 changes: 11 additions & 0 deletions server/api/src/__tests__/agents/runner/sseMapper.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,17 @@ describe("mapLangGraphEvent", () => {
]);
});

it("maps on_custom_event compose_phase conflict to a typed compose_phase event", () => {
const ev: LangGraphRuntimeEvent = {
event: "on_custom_event",
name: "compose_phase",
data: { phase: "conflict", status: "entered" },
};
expect(mapLangGraphEvent(ev)).toEqual([
{ type: "compose_phase", phase: "conflict", status: "entered" },
]);
});

it("drops compose_phase with an unknown phase value", () => {
const ev: LangGraphRuntimeEvent = {
event: "on_custom_event",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,5 +81,6 @@ describe("projectComposeStateValues", () => {
});
expect(projection.completedMarkdown).toBe("## A\n\nBody");
expect(projection.draftedSections).toHaveLength(1);
expect(projection.phase).toBe("completed");
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@ export function getPostgresCheckpointer(schema: string = "public"): PostgresSave
export async function ensurePostgresCheckpointerSetup(schema: string = "public"): Promise<void> {
if (!setupOnce) {
const saver = getPostgresCheckpointer(schema);
setupOnce = saver.setup();
setupOnce = saver.setup().catch((error: unknown) => {
setupOnce = null;
throw error;
});
}
return setupOnce;
}
Expand Down
15 changes: 4 additions & 11 deletions server/api/src/agents/core/composeBackendValidation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
* Compose グラフのモデル provider と BYOK backend の整合を検証する。
*/
import { HTTPException } from "hono/http-exception";
import { validateModelAccess } from "../../services/usageService.js";
import type { Database, UserTier } from "../../types/index.js";
import { getComposeModelIdsForGraph } from "./composeModelConfig.js";
import {
Expand All @@ -26,18 +25,12 @@ export async function assertComposeBackendReady(input: {
}): Promise<void> {
if (!isUserByokBackend(input.backend)) return;

const expectedProvider = backendToCredentialProvider(input.backend);
const modelIds = getComposeModelIdsForGraph(input.graphId);
// Model-less graphs (e.g. wiki-maintenance) never call `createZediChatModel`.
// Provider matching for BYOK runs is enforced at runtime via `resolveComposeModelId`.
if (modelIds.length === 0) return;
Comment thread
coderabbitai[bot] marked this conversation as resolved.

for (const modelId of modelIds) {
const modelInfo = await validateModelAccess(modelId, input.tier, input.db);
if (modelInfo.provider !== expectedProvider) {
throw new HTTPException(400, {
message: `Backend "${input.backend}" does not match compose model "${modelId}" (provider ${modelInfo.provider})`,
});
}
}

const expectedProvider = backendToCredentialProvider(input.backend);
const key = await getUserAiCredentialPlaintext(input.userId, expectedProvider, input.db);
if (!key?.trim()) {
throw new HTTPException(400, {
Expand Down
126 changes: 64 additions & 62 deletions server/api/src/agents/core/llm/zediChatModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -283,72 +283,74 @@ export class ZediChatModel extends BaseChatModel<BaseChatModelCallOptions> {
let finishReason: string | undefined;
let done = false;

for await (const chunk of gen) {
if (chunk.content) {
accumulated += chunk.content;
const chatChunk = new ChatGenerationChunk({
text: chunk.content,
message: new AIMessageChunk({ content: chunk.content }),
});
// Surface incremental tokens to LangChain callback consumers so any
// `streamEvents` listener (e.g. SSE mapper) sees deltas before the
// final usage event.
await runManager?.handleLLMNewToken(
chunk.content,
undefined,
undefined,
undefined,
undefined,
{
chunk: chatChunk,
},
);
yield chatChunk;
}
if (chunk.done) {
finishReason = chunk.finishReason;
done = true;
break;
try {
for await (const chunk of gen) {
if (chunk.content) {
accumulated += chunk.content;
const chatChunk = new ChatGenerationChunk({
text: chunk.content,
message: new AIMessageChunk({ content: chunk.content }),
});
// Surface incremental tokens to LangChain callback consumers so any
// `streamEvents` listener (e.g. SSE mapper) sees deltas before the
// final usage event.
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
await runManager?.handleLLMNewToken(
chunk.content,
undefined,
undefined,
undefined,
undefined,
{
chunk: chatChunk,
},
);
yield chatChunk;
}
if (chunk.done) {
finishReason = chunk.finishReason;
done = true;
break;
}
}
}
} finally {
// Streaming providers in this codebase do not return token counts; mirror
// chat.ts and estimate. Pre-encoded message length is what the user "paid"
// for, response length is what they received.
const promptLength = zediMessages.reduce((sum, m) => sum + m.content.length, 0);
const inputTokens = Math.ceil(promptLength / 4);
const outputTokens = Math.ceil(accumulated.length / 4);

// Streaming providers in this codebase do not return token counts; mirror
// chat.ts and estimate. Pre-encoded message length is what the user "paid"
// for, response length is what they received.
const promptLength = zediMessages.reduce((sum, m) => sum + m.content.length, 0);
const inputTokens = Math.ceil(promptLength / 4);
const outputTokens = Math.ceil(accumulated.length / 4);

const usage = await recordZediUsage({
db: this.db,
userId: this.userId,
modelId: this.modelRowId,
feature: this.feature,
usage: { inputTokens, outputTokens },
inputCostUnits: this.inputCostUnits,
outputCostUnits: this.outputCostUnits,
apiMode: this.apiMode,
});
const usage = await recordZediUsage({
db: this.db,
userId: this.userId,
modelId: this.modelRowId,
feature: this.feature,

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Skip usage billing when streaming call errors

recordZediUsage now runs unconditionally in finally, so it executes even when the provider stream throws before completion; this is reachable because streamOpenAI/streamAnthropic/streamGoogle throw on non-2xx responses in server/api/src/services/aiProviders.ts. In those failure paths the request fails but we still charge estimated prompt tokens, which can inflate usage/budget accounting and bill failed calls; billing should be limited to successful completion (or a narrowly-detected abort case).

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

対応しました。finally による無条件課金をやめ、routes/ai/chat.ts と同様に done === true のときだけ recordZediUsage を呼ぶようにしました。プロバイダ例外は for await からそのまま伝播するため課金コードに到達しません。incomplete 終了は usage メタデータのみ返し DB 課金しません。テスト does not record usage when the provider stream throws を追加済みです。

usage: { inputTokens, outputTokens },
inputCostUnits: this.inputCostUnits,
outputCostUnits: this.outputCostUnits,
apiMode: this.apiMode,
});

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

finally ブロック内で recordZediUsage を await していますが、この処理が失敗した場合に try ブロックで発生した元のエラーが上書き(マスク)されてしまう可能性があります。利用状況の記録は重要ですが、ここでの例外が原因で最終的なチャンクの yield や後続のクリーンアップが阻害されるのを防ぐため、recordZediUsage を try-catch で囲んでエラーをハンドル(ログ出力など)することを検討してください。

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

対応しました。finally を廃止し、done === true 時のみ recordZediUsage を呼びます。recordZediUsage 自体は try/catch で囲み、失敗時は console.error のうえ推定トークンで最終チャンクを返します(成功ストリームの例外マスクを防止)。

Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

// Final chunk surfaces aggregate usage so downstream consumers (sseMapper,
// LangChain callbacks) can read totals from a single ChatGenerationChunk.
yield new ChatGenerationChunk({
text: "",
message: new AIMessageChunk({
content: "",
response_metadata: {
// Final chunk surfaces aggregate usage so downstream consumers (sseMapper,
// LangChain callbacks) can read totals from a single ChatGenerationChunk.
yield new ChatGenerationChunk({
text: "",
message: new AIMessageChunk({
content: "",
response_metadata: {
finishReason: finishReason ?? (done ? "stop" : "incomplete"),
},
usage_metadata: {
input_tokens: usage.inputTokens,
output_tokens: usage.outputTokens,
total_tokens: usage.inputTokens + usage.outputTokens,
},
}),
generationInfo: {
finishReason: finishReason ?? (done ? "stop" : "incomplete"),
costUnits: usage.costUnits,
},
usage_metadata: {
input_tokens: usage.inputTokens,
output_tokens: usage.outputTokens,
total_tokens: usage.inputTokens + usage.outputTokens,
},
}),
generationInfo: {
finishReason: finishReason ?? (done ? "stop" : "incomplete"),
costUnits: usage.costUnits,
},
});
});
}
}
}
23 changes: 14 additions & 9 deletions server/api/src/agents/core/tools/resolveWebSearchModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,18 @@ export async function resolveWebSearchModelId(
)
.orderBy(asc(aiModels.inputCostUnits), asc(aiModels.outputCostUnits));

// Prefer OpenAI when costs tie, since `useWebSearch` (chat completions
// `web_search_options`) is well-tested in `aiProviders.ts`; Google's
// `googleSearch` tool is also supported but requires the `tools` payload.
// 同コストなら OpenAI を優先(`useWebSearch` 経路が安定)。
const openai = rows.find((r) => r.provider === "openai");
if (openai) return openai.id;
const google = rows.find((r) => r.provider === "google");
if (google) return google.id;
return null;
if (rows.length === 0) return null;

const first = rows[0];
if (!first) return null;

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

直前の 94 行目で rows.length === 0 の場合に null を返しているため、この if (!first) チェックは冗長です。コードを簡潔にするために削除を検討してください。

References
  1. 簡潔で意図が明確なコードを推奨します。 (link)
  2. 不要な条件分岐や到達不能なコードを削除することで、コードの可読性と保守性を向上させます。

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

対応しました。rows.length === 0 の直後に rows[0] を直接参照する形に整理し、冗長な if (!first) を削除しました。


// Prefer OpenAI only among cheapest rows (cost tie-break), since `useWebSearch`
// is well-tested in `aiProviders.ts`.
const cheapestInput = first.inputCostUnits;
const cheapestOutput = first.outputCostUnits;
const cheapest = rows.filter(
(r) => r.inputCostUnits === cheapestInput && r.outputCostUnits === cheapestOutput,
);
const preferred = cheapest.find((r) => r.provider === "openai") ?? cheapest[0];
return preferred?.id ?? null;
}
2 changes: 2 additions & 0 deletions server/api/src/agents/core/types/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,7 @@ export {
type SseResearchIterationEvent,
type SseResearchEvaluationEvent,
type SseResearchBatchEvent,
type SseComposePhaseEvent,
type SseComposeSectionEvent,
SSE_EVENT_NAMES,
} from "./sseEvents.js";
1 change: 1 addition & 0 deletions server/api/src/agents/runner/sseMapper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ function mapComposePhase(data: Record<string, unknown>): SseComposePhaseEvent[]
if (
phase !== "brief" &&
phase !== "research" &&
phase !== "conflict" &&
phase !== "structure" &&
phase !== "draft" &&
phase !== "completed"
Expand Down
22 changes: 17 additions & 5 deletions server/api/src/agents/subgraphs/research/resumeSchema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,23 @@ import { z } from "zod";
* required (empty array means "reject all"); `rejectedSourceIds` defaults to
* the empty array; `note` is free-form metadata.
*/
export const researchResumeSchema = z.object({
approvedSourceIds: z.array(z.string().min(1)).default([]),
rejectedSourceIds: z.array(z.string().min(1)).optional().default([]),
note: z.string().optional(),
});
export const researchResumeSchema = z
.object({
approvedSourceIds: z.array(z.string().min(1)),
rejectedSourceIds: z.array(z.string().min(1)).optional().default([]),
note: z.string().optional(),
})
.superRefine((value, ctx) => {
const rejected = new Set(value.rejectedSourceIds);
const overlap = value.approvedSourceIds.filter((id) => rejected.has(id));
if (overlap.length > 0) {
ctx.addIssue({
code: "custom",
path: ["rejectedSourceIds"],
message: `approvedSourceIds and rejectedSourceIds must not overlap: ${overlap.join(", ")}`,
});
}
});

/** Inferred TS type for the parsed resume payload. */
export type ResearchResumeParsed = z.infer<typeof researchResumeSchema>;
5 changes: 4 additions & 1 deletion server/api/src/routes/composeSessionProjection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,10 @@ export function projectComposeStateValues(
// Interrupt-derived phase wins; row `phase` is only a fallback.
// interrupt 由来の phase を優先し、行の phase はフォールバックのみ。
if (typeof state.phase === "string" && projection.phase === undefined) {
projection.phase = phaseFromSessionRow(state.phase, "interrupted");
projection.phase =
state.phase.startsWith("completed") || state.phase === "completed"

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

state.phase.startsWith("completed") が真であれば state.phase === "completed" の条件も満たされるため、後者のチェックは冗長です。コードを簡潔にするために削除を検討してください。

References
  1. 冗長な論理演算を避けることを推奨します。 (link)
  2. 論理的に冗長な条件式を整理することで、コードの意図をより明確にします。

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

対応しました。startsWith("completed") のみに簡略化しました。

? "completed"
: phaseFromSessionRow(state.phase, "interrupted");
}

return projection;
Expand Down
Loading
Loading