diff --git a/bundlesize.config.json b/bundlesize.config.json index 1dfc56e30e..666b80171d 100644 --- a/bundlesize.config.json +++ b/bundlesize.config.json @@ -10,11 +10,11 @@ }, { "path": "./packages/instantsearch.js/dist/instantsearch.production.min.js", - "maxSize": "121.8 kB" + "maxSize": "122.75 kB" }, { "path": "./packages/instantsearch.js/dist/instantsearch.development.js", - "maxSize": "253.4 kB" + "maxSize": "255 kB" }, { "path": "packages/react-instantsearch-core/dist/umd/ReactInstantSearchCore.min.js", diff --git a/packages/instantsearch-ui-components/src/components/chat/ChatMessage.tsx b/packages/instantsearch-ui-components/src/components/chat/ChatMessage.tsx index 7877186fc9..fcfc64b74a 100644 --- a/packages/instantsearch-ui-components/src/components/chat/ChatMessage.tsx +++ b/packages/instantsearch-ui-components/src/components/chat/ChatMessage.tsx @@ -7,6 +7,7 @@ import { createButtonComponent } from '../Button'; import { MenuIcon } from './icons'; import type { ComponentProps, Renderer, VNode } from '../../types'; +import type { ChatMessageLoaderProps } from './ChatMessageLoader'; import type { AddToolResultWithOutput, ChatMessageBase, @@ -135,6 +136,10 @@ export type ChatMessageProps = ComponentProps<'article'> & { * Array of tools available for the assistant (for tool messages) */ tools: ClientSideTools; + /** + * Loader component passed to tool layout components + */ + loaderComponent: (props: ChatMessageLoaderProps) => JSX.Element; /** * Optional suggestions element */ @@ -171,6 +176,7 @@ export function createChatMessageComponent({ createElement }: Renderer) { indexUiState, setIndexUiState, onClose, + loaderComponent: LoaderComponent, translations: userTranslations, suggestionsElement, ...props @@ -237,6 +243,13 @@ export function createChatMessageComponent({ createElement }: Renderer) { toolCallId: toolMessage.toolCallId, }); + if ( + toolMessage.state === 'input-streaming' && + !tool.streamInput + ) { + return null; + } + if (!ToolLayoutComponent) { return null; } @@ -252,8 +265,9 @@ export function createChatMessageComponent({ createElement }: Renderer) { setIndexUiState={setIndexUiState} addToolResult={boundAddToolResult} applyFilters={tool.applyFilters} - sendEvent={tool.sendEvent || (() => {})} + sendEvent={tool.sendEvent || (() => { })} onClose={onClose} + loaderComponent={LoaderComponent} /> ); diff --git a/packages/instantsearch-ui-components/src/components/chat/ChatMessageLoader.tsx b/packages/instantsearch-ui-components/src/components/chat/ChatMessageLoader.tsx index 32ce1590a3..db193f9e59 100644 --- a/packages/instantsearch-ui-components/src/components/chat/ChatMessageLoader.tsx +++ b/packages/instantsearch-ui-components/src/components/chat/ChatMessageLoader.tsx @@ -24,7 +24,7 @@ export function createChatMessageLoaderComponent({ return function ChatMessageLoader(userProps: ChatMessageLoaderProps) { const { translations: userTranslations, ...props } = userProps; const translations: Required = { - loaderText: 'Thinking...', + loaderText: '', ...userTranslations, }; diff --git a/packages/instantsearch-ui-components/src/components/chat/ChatMessages.tsx b/packages/instantsearch-ui-components/src/components/chat/ChatMessages.tsx index 8dee18329f..ef97805c3b 100644 --- a/packages/instantsearch-ui-components/src/components/chat/ChatMessages.tsx +++ b/packages/instantsearch-ui-components/src/components/chat/ChatMessages.tsx @@ -2,9 +2,11 @@ import { cx } from '../../lib'; import { + findTool, getTextContent, hasTextContent, isPartText, + isPartTool, } from '../../lib/utils/chat'; import { createButtonComponent } from '../Button'; @@ -225,6 +227,7 @@ function createDefaultMessageComponent< onFeedback, feedbackState, actionsComponent, + loaderComponent, classNames, messageTranslations, translations, @@ -243,6 +246,7 @@ function createDefaultMessageComponent< onFeedback?: (messageId: string, vote: 0 | 1) => void; feedbackState?: Record; actionsComponent?: ChatMessageProps['actionsComponent']; + loaderComponent: ChatMessageProps['loaderComponent']; translations: ChatMessagesTranslations; classNames?: Partial; messageTranslations?: Partial; @@ -326,6 +330,7 @@ function createDefaultMessageComponent< onClose={onClose} actions={defaultActions} actionsComponent={actionsComponent} + loaderComponent={loaderComponent} data-role={message.role} classNames={classNames} translations={messageTranslations} @@ -411,15 +416,7 @@ export function createChatMessagesComponent({ const lastMessage = messages[messages.length - 1]; const lastPart = lastMessage?.parts?.[lastMessage.parts.length - 1]; - const isWaitingForResponse = status === 'submitted'; - const isStreamingWithNoContent = status === 'streaming' && !lastPart; - const isStreamingNonTextContent = - status === 'streaming' && lastPart && !isPartText(lastPart); - - const showLoader = - isWaitingForResponse || - isStreamingWithNoContent || - isStreamingNonTextContent; + const showLoader = getShowLoader(status, lastPart, tools); const DefaultMessage = MessageComponent || DefaultMessageComponent; const DefaultLoader = LoaderComponent || DefaultLoaderComponent; @@ -463,6 +460,7 @@ export function createChatMessagesComponent({ onFeedback={onFeedback} feedbackState={feedbackState} actionsComponent={ActionsComponent} + loaderComponent={DefaultLoader} onClose={onClose} translations={translations} classNames={messageClassNames} @@ -506,3 +504,27 @@ export function createChatMessagesComponent({ ); }; } + +const getShowLoader = ( + status: ChatStatus, + lastPart: ChatMessageBase['parts'][number] | undefined, + tools: ClientSideTools +): boolean => { + if (status !== 'submitted' && status !== 'streaming') return false; + if (status === 'submitted') return true; + + if (!lastPart) return true; + if (isPartText(lastPart)) return false; + + if (isPartTool(lastPart)) { + if (lastPart.state === 'output-available') return false; + if (lastPart.state === 'input-streaming') { + const tool = findTool(lastPart.type, tools); + return !tool?.streamInput; + } + return true; + } + + return true; +}; + diff --git a/packages/instantsearch-ui-components/src/components/chat/__tests__/ChatMessage.test.tsx b/packages/instantsearch-ui-components/src/components/chat/__tests__/ChatMessage.test.tsx index 1dcc0e2440..2a58ca706e 100644 --- a/packages/instantsearch-ui-components/src/components/chat/__tests__/ChatMessage.test.tsx +++ b/packages/instantsearch-ui-components/src/components/chat/__tests__/ChatMessage.test.tsx @@ -21,6 +21,7 @@ describe('ChatMessage', () => { message={{ role: 'user', id: '1', parts: [] }} status="ready" tools={{}} + loaderComponent={jest.fn()} onClose={jest.fn()} /> ); @@ -66,6 +67,7 @@ describe('ChatMessage', () => { actions: 'actions', }} tools={{}} + loaderComponent={jest.fn()} onClose={jest.fn()} /> ); @@ -104,6 +106,7 @@ describe('ChatMessage', () => { }} status="ready" tools={{}} + loaderComponent={jest.fn()} onClose={jest.fn()} /> { }} status="ready" tools={{}} + loaderComponent={jest.fn()} onClose={jest.fn()} /> { }} status="ready" tools={{}} + loaderComponent={jest.fn()} onClose={jest.fn()} /> @@ -235,6 +240,7 @@ describe('ChatMessage', () => { applyFilters: jest.fn(), }, }} + loaderComponent={jest.fn()} onClose={jest.fn()} /> ); diff --git a/packages/instantsearch-ui-components/src/components/chat/types.ts b/packages/instantsearch-ui-components/src/components/chat/types.ts index 011321cb27..37f42e3691 100644 --- a/packages/instantsearch-ui-components/src/components/chat/types.ts +++ b/packages/instantsearch-ui-components/src/components/chat/types.ts @@ -1,4 +1,5 @@ import type { ComponentProps, SendEventForHits } from '../../types'; +import type { ChatMessageLoaderProps } from './ChatMessageLoader'; import type { SearchParameters } from 'algoliasearch-helper'; export type ChatStatus = 'ready' | 'submitted' | 'streaming' | 'error'; @@ -485,6 +486,7 @@ export type ClientSideToolComponentProps = { onClose: () => void; addToolResult: AddToolResultWithOutput; applyFilters: (params: ApplyFiltersParams) => SearchParameters; + loaderComponent: (props: ChatMessageLoaderProps) => JSX.Element; sendEvent: SendEventForHits; }; @@ -494,6 +496,7 @@ export type ClientSideToolComponent = ( export type ClientSideTool = { layoutComponent?: ClientSideToolComponent; + streamInput?: boolean; addToolResult: AddToolResult; sendEvent?: SendEventForHits; onToolCall?: ( diff --git a/packages/instantsearch-ui-components/src/lib/utils/chat.ts b/packages/instantsearch-ui-components/src/lib/utils/chat.ts index 0e170dd500..abf89a6145 100644 --- a/packages/instantsearch-ui-components/src/lib/utils/chat.ts +++ b/packages/instantsearch-ui-components/src/lib/utils/chat.ts @@ -1,4 +1,11 @@ +import { startsWith } from './startsWith'; + import type { ChatMessageBase } from '../../components'; +import type { + ChatToolMessage, + ClientSideTool, + ClientSideTools, +} from '../../components/chat/types'; export const getTextContent = (message: ChatMessageBase) => { return message.parts @@ -15,3 +22,23 @@ export const isPartText = ( ): part is Extract => { return part.type === 'text'; }; + +export const isPartTool = ( + part: ChatMessageBase['parts'][number] +): part is ChatToolMessage => { + return startsWith(part.type, 'tool-'); +}; + +export const findTool = ( + partType: string, + tools: ClientSideTools +): ClientSideTool | undefined => { + const toolName = partType.replace('tool-', ''); + let tool: ClientSideTool | undefined = tools[toolName]; + if (!tool) { + tool = Object.entries(tools).find(([key]) => + startsWith(toolName, `${key}_`) + )?.[1]; + } + return tool; +}; diff --git a/packages/instantsearch.js/src/connectors/chat/__tests__/connectChat-test.ts b/packages/instantsearch.js/src/connectors/chat/__tests__/connectChat-test.ts index 38b73e1d11..aff5df1658 100644 --- a/packages/instantsearch.js/src/connectors/chat/__tests__/connectChat-test.ts +++ b/packages/instantsearch.js/src/connectors/chat/__tests__/connectChat-test.ts @@ -506,6 +506,194 @@ data: [DONE]`, ); }); }); + + it('streams tool input parts from tool-input-delta without tool-input-available', async () => { + const { widget } = getInitializedWidget({ + agentId: undefined, + transport: { + fetch: () => + Promise.resolve( + new Response( + `data: {"type": "start", "messageId": "test-id"} + +data: {"type": "start-step"} + +data: {"type": "tool-input-start", "toolCallId": "call_1", "toolName": "displayResults"} + +data: {"type": "tool-input-delta", "toolCallId": "call_1", "toolName": "displayResults", "inputTextDelta": "{}"} + +data: {"type": "finish-step"} + +data: {"type": "finish"} + +data: [DONE]`, + { + headers: { 'Content-Type': 'text/event-stream' }, + } + ) + ), + }, + }); + + const { chatInstance } = widget; + + await chatInstance.sendMessage({ + id: 'message-id', + role: 'user', + parts: [{ type: 'text', text: 'Show me product groups' }], + }); + + await waitFor(() => { + const lastMessage = chatInstance.messages[chatInstance.messages.length - 1]; + expect(lastMessage?.role).toBe('assistant'); + + const toolPart = lastMessage?.parts.find( + (part) => + 'type' in part && + part.type === 'tool-displayResults' && + 'toolCallId' in part && + part.toolCallId === 'call_1' + ) as + | { + state: string; + rawInput?: string; + input?: Record; + } + | undefined; + + expect(toolPart?.state).toBe('input-streaming'); + expect(toolPart?.input).toEqual({}); + }); + }); + + it('skips JSON repair for tools without streamInput (default)', async () => { + const { widget } = getInitializedWidget({ + agentId: undefined, + tools: { + myTool: {}, + }, + transport: { + fetch: () => + Promise.resolve( + new Response( + `data: {"type": "start", "messageId": "test-id"} + +data: {"type": "start-step"} + +data: {"type": "tool-input-start", "toolCallId": "call_1", "toolName": "myTool"} + +data: {"type": "tool-input-delta", "toolCallId": "call_1", "toolName": "myTool", "inputTextDelta": "{\\"query\\": \\"sho"} + +data: {"type": "finish-step"} + +data: {"type": "finish"} + +data: [DONE]`, + { + headers: { 'Content-Type': 'text/event-stream' }, + } + ) + ), + }, + }); + + const { chatInstance } = widget; + + await chatInstance.sendMessage({ + id: 'message-id', + role: 'user', + parts: [{ type: 'text', text: 'search' }], + }); + + await waitFor(() => { + const lastMessage = + chatInstance.messages[chatInstance.messages.length - 1]; + const toolPart = lastMessage?.parts.find( + (part) => + 'type' in part && + part.type === 'tool-myTool' && + 'toolCallId' in part && + part.toolCallId === 'call_1' + ) as + | { + state: string; + rawInput?: string; + input?: unknown; + } + | undefined; + + expect(toolPart?.state).toBe('input-streaming'); + // Input is not repaired since streamInput is not set (default) + expect(toolPart?.input).toBeUndefined(); + // Raw input is still accumulated + expect(toolPart?.rawInput).toBe('{"query": "sho'); + }); + }); + + it('repairs JSON for tools with streamInput set to true', async () => { + const { widget } = getInitializedWidget({ + agentId: undefined, + tools: { + myTool: { + streamInput: true, + }, + }, + transport: { + fetch: () => + Promise.resolve( + new Response( + `data: {"type": "start", "messageId": "test-id"} + +data: {"type": "start-step"} + +data: {"type": "tool-input-start", "toolCallId": "call_1", "toolName": "myTool"} + +data: {"type": "tool-input-delta", "toolCallId": "call_1", "toolName": "myTool", "inputTextDelta": "{\\"query\\": \\"sho"} + +data: {"type": "finish-step"} + +data: {"type": "finish"} + +data: [DONE]`, + { + headers: { 'Content-Type': 'text/event-stream' }, + } + ) + ), + }, + }); + + const { chatInstance } = widget; + + await chatInstance.sendMessage({ + id: 'message-id', + role: 'user', + parts: [{ type: 'text', text: 'search' }], + }); + + await waitFor(() => { + const lastMessage = + chatInstance.messages[chatInstance.messages.length - 1]; + const toolPart = lastMessage?.parts.find( + (part) => + 'type' in part && + part.type === 'tool-myTool' && + 'toolCallId' in part && + part.toolCallId === 'call_1' + ) as + | { + state: string; + rawInput?: string; + input?: unknown; + } + | undefined; + + expect(toolPart?.state).toBe('input-streaming'); + // Input is repaired since streamInput is true + expect(toolPart?.input).toEqual({ query: 'sho' }); + expect(toolPart?.rawInput).toBe('{"query": "sho'); + }); + }); }); describe('transport configuration', () => { diff --git a/packages/instantsearch.js/src/connectors/chat/connectChat.ts b/packages/instantsearch.js/src/connectors/chat/connectChat.ts index 9d88d3cce5..2e493e4c3c 100644 --- a/packages/instantsearch.js/src/connectors/chat/connectChat.ts +++ b/packages/instantsearch.js/src/connectors/chat/connectChat.ts @@ -420,6 +420,14 @@ export default (function connectChat( ...options, transport, sendAutomaticallyWhen: lastAssistantMessageIsCompleteWithToolCalls, + shouldRepairToolInput(toolName) { + let tool = tools[toolName]; + if (!tool && toolName.startsWith(`${SearchIndexToolType}_`)) { + tool = tools[SearchIndexToolType]; + } + if (!tool) return true; + return Boolean(tool.streamInput); + }, onToolCall({ toolCall }) { let tool = tools[toolCall.toolName]; diff --git a/packages/instantsearch.js/src/lib/ai-lite/abstract-chat.ts b/packages/instantsearch.js/src/lib/ai-lite/abstract-chat.ts index d31307101e..4cbdb794ea 100644 --- a/packages/instantsearch.js/src/lib/ai-lite/abstract-chat.ts +++ b/packages/instantsearch.js/src/lib/ai-lite/abstract-chat.ts @@ -26,6 +26,96 @@ type ActiveResponse = { stream?: ReadableStream; }; +const tryParseJson = (value: string): unknown | undefined => { + try { + return JSON.parse(value); + } catch { + return undefined; + } +}; + +const repairPartialJson = (value: string): string => { + let repaired = value.trim(); + + if (!repaired) { + return repaired; + } + + let inString = false; + let isEscaped = false; + const stack: Array<'{' | '['> = []; + + for (let index = 0; index < repaired.length; index++) { + const char = repaired[index]; + if (inString) { + if (isEscaped) { + isEscaped = false; + } else if (char === '\\') { + isEscaped = true; + } else if (char === '"') { + inString = false; + } + continue; + } + + if (char === '"') { + inString = true; + continue; + } + + if (char === '{' || char === '[') { + stack.push(char); + continue; + } + + if (char === '}' && stack[stack.length - 1] === '{') { + stack.pop(); + continue; + } + + if (char === ']' && stack[stack.length - 1] === '[') { + stack.pop(); + } + } + + if (inString && !isEscaped) { + repaired += '"'; + } + + repaired = repaired.replace(/,\s*$/u, ''); + + if (stack.length > 0) { + repaired += stack + .reverse() + .map((opening) => (opening === '{' ? '}' : ']')) + .join(''); + } + + return repaired.replace(/,\s*([}\]])/gu, '$1'); +}; + +const parseToolInputDelta = ( + accumulatedRawInput: string, + fallbackInput: unknown +): unknown => { + const normalized = accumulatedRawInput.trim(); + if (!normalized) { + return fallbackInput; + } + + const directParsed = tryParseJson(normalized); + if (directParsed !== undefined) { + return directParsed; + } + + const repairedParsed = tryParseJson(repairPartialJson(normalized)); + if (repairedParsed !== undefined) { + return repairedParsed; + } + + return fallbackInput; +}; + /** * Abstract base class for chat implementations. */ @@ -42,6 +132,7 @@ export abstract class AbstractChat { private sendAutomaticallyWhen?: (options: { messages: TUIMessage[]; }) => boolean | PromiseLike; + private shouldRepairToolInput?: (toolName: string) => boolean; private activeResponse: ActiveResponse | null = null; private jobExecutor = new SerialJobExecutor(); @@ -56,6 +147,7 @@ export abstract class AbstractChat { onFinish, onData, sendAutomaticallyWhen, + shouldRepairToolInput, }: Omit, 'messages'> & { state: ChatState; }) { @@ -68,6 +160,7 @@ export abstract class AbstractChat { this.onFinish = onFinish; this.onData = onData; this.sendAutomaticallyWhen = sendAutomaticallyWhen; + this.shouldRepairToolInput = shouldRepairToolInput; } /** @@ -437,6 +530,7 @@ export abstract class AbstractChat { // Track current text/reasoning part state let currentTextPartId: string | undefined; let currentReasoningPartId: string | undefined; + const toolRawInputByCallId: Record = {}; // Promise chain for handling tool calls that return promises let pendingToolCall: Promise = Promise.resolve(); @@ -623,11 +717,21 @@ export abstract class AbstractChat { case 'tool-input-start': { if (!currentMessage) break; + const initialRawInput = + typeof chunk.input === 'string' + ? chunk.input + : chunk.input !== undefined + ? JSON.stringify(chunk.input) + : ''; + + toolRawInputByCallId[chunk.toolCallId] = initialRawInput; + const toolPart = { type: `tool-${chunk.toolName}` as const, toolCallId: chunk.toolCallId, state: 'input-streaming' as const, input: chunk.input, + rawInput: initialRawInput || undefined, providerExecuted: chunk.providerExecuted, }; @@ -640,14 +744,67 @@ export abstract class AbstractChat { } case 'tool-input-delta': { - // Tool input streaming - we'd need to parse partial JSON - // For now, we'll wait for tool-input-available + if (!currentMessage) break; + + const toolIndex = currentMessage.parts.findIndex( + (p) => 'toolCallId' in p && p.toolCallId === chunk.toolCallId + ); + + const existingPart = + toolIndex >= 0 + ? (currentMessage.parts[toolIndex] as any) + : null; + const previousRawInput = + existingPart?.rawInput ?? + toolRawInputByCallId[chunk.toolCallId] ?? + ''; + const nextRawInput = `${previousRawInput}${chunk.inputTextDelta}`; + toolRawInputByCallId[chunk.toolCallId] = nextRawInput; + + const toolName = + chunk.toolName ?? + existingPart?.type?.replace('tool-', ''); + const shouldRepair = + toolName + ? (this.shouldRepairToolInput?.(toolName) ?? true) + : true; + const parsedInput = shouldRepair + ? parseToolInputDelta(nextRawInput, existingPart?.input) + : existingPart?.input; + + const nextToolPart = { + ...(existingPart ?? { + type: `tool-${chunk.toolName}` as const, + toolCallId: chunk.toolCallId, + }), + state: 'input-streaming' as const, + input: parsedInput, + rawInput: nextRawInput, + }; + + if (toolIndex >= 0) { + const updatedParts = [...currentMessage.parts]; + updatedParts[toolIndex] = nextToolPart; + currentMessage = { + ...currentMessage, + parts: updatedParts, + } as TUIMessage; + } else { + currentMessage = { + ...currentMessage, + parts: [...currentMessage.parts, nextToolPart], + } as TUIMessage; + } + + this.state.replaceMessage(currentMessageIndex, currentMessage); break; } case 'tool-input-available': { if (!currentMessage) break; + delete toolRawInputByCallId[chunk.toolCallId]; + // Find existing tool part or create new one const existingIndex = currentMessage.parts.findIndex( (p) => 'toolCallId' in p && p.toolCallId === chunk.toolCallId @@ -702,6 +859,8 @@ export abstract class AbstractChat { ); if (toolIndex >= 0) { + delete toolRawInputByCallId[chunk.toolCallId]; + const updatedParts = [...currentMessage.parts]; const existingPart = updatedParts[toolIndex] as any; updatedParts[toolIndex] = { @@ -728,6 +887,8 @@ export abstract class AbstractChat { ); if (toolIndex >= 0) { + delete toolRawInputByCallId[chunk.toolCallId]; + const updatedParts = [...currentMessage.parts]; const existingPart = updatedParts[toolIndex] as any; updatedParts[toolIndex] = { diff --git a/packages/instantsearch.js/src/lib/ai-lite/types.ts b/packages/instantsearch.js/src/lib/ai-lite/types.ts index a317953ae4..12d83e7896 100644 --- a/packages/instantsearch.js/src/lib/ai-lite/types.ts +++ b/packages/instantsearch.js/src/lib/ai-lite/types.ts @@ -84,6 +84,7 @@ export type ToolUIPart = ValueOf<{ | { state: 'input-streaming'; input: DeepPartial | undefined; + rawInput?: string; providerExecuted?: boolean; output?: never; errorText?: never; @@ -125,6 +126,7 @@ export type DynamicToolUIPart = { | { state: 'input-streaming'; input: unknown | undefined; + rawInput?: string; output?: never; errorText?: never; } @@ -267,7 +269,7 @@ export type UIMessageChunk< type: 'tool-input-delta'; toolName: string; toolCallId: string; - inputDelta: string; + inputTextDelta: string; } | { type: 'tool-output-available'; @@ -432,6 +434,7 @@ export interface ChatInit { sendAutomaticallyWhen?: (options: { messages: UI_MESSAGE[]; }) => boolean | PromiseLike; + shouldRepairToolInput?: (toolName: string) => boolean; } export type CreateUIMessage = Omit< diff --git a/packages/instantsearch.js/src/widgets/chat/chat.tsx b/packages/instantsearch.js/src/widgets/chat/chat.tsx index 9b9278f0f5..2cb54002d8 100644 --- a/packages/instantsearch.js/src/widgets/chat/chat.tsx +++ b/packages/instantsearch.js/src/widgets/chat/chat.tsx @@ -97,7 +97,7 @@ function createCarouselTool< applyFilters, onClose, sendEvent, - }: ClientSideToolComponentProps) { + }: ClientSideToolTemplateData) { const input = message?.input as | { query: string; @@ -539,12 +539,16 @@ const createRenderer = ({ layoutComponent: ( layoutComponentProps: ClientSideToolComponentProps ) => { + const { loaderComponent: Loader, ...restProps } = layoutComponentProps; return ( , + }} /> ); }, @@ -962,8 +966,15 @@ const createRenderer = ({ }; }; +export type ClientSideToolTemplateData = Omit< + ClientSideToolComponentProps, + 'loaderComponent' +> & { + loader: () => JSX.Element; +}; + export type UserClientSideToolTemplates = Partial<{ - layout: TemplateWithBindEvent; + layout: TemplateWithBindEvent; }>; type UserClientSideToolWithTemplate = Omit< diff --git a/packages/react-instantsearch/src/widgets/chat/tools/SearchIndexTool.tsx b/packages/react-instantsearch/src/widgets/chat/tools/SearchIndexTool.tsx index d55f3a7c33..67adad96fb 100644 --- a/packages/react-instantsearch/src/widgets/chat/tools/SearchIndexTool.tsx +++ b/packages/react-instantsearch/src/widgets/chat/tools/SearchIndexTool.tsx @@ -49,6 +49,7 @@ function createCarouselTool( | { query: string; number_of_results?: number; + facet_filters?: string[][]; } | undefined; diff --git a/packages/react-instantsearch/src/widgets/chat/tools/__tests__/SearchIndexTool.test.tsx b/packages/react-instantsearch/src/widgets/chat/tools/__tests__/SearchIndexTool.test.tsx index ea1b1f3039..12c6858342 100644 --- a/packages/react-instantsearch/src/widgets/chat/tools/__tests__/SearchIndexTool.test.tsx +++ b/packages/react-instantsearch/src/widgets/chat/tools/__tests__/SearchIndexTool.test.tsx @@ -47,6 +47,7 @@ describe('createCarouselTool', () => { indexUiState={{}} addToolResult={jest.fn()} setIndexUiState={jest.fn()} + loaderComponent={jest.fn()} sendEvent={jest.fn()} /> ); @@ -81,6 +82,7 @@ describe('createCarouselTool', () => { indexUiState={{}} addToolResult={jest.fn()} setIndexUiState={jest.fn()} + loaderComponent={jest.fn()} sendEvent={jest.fn()} /> ); diff --git a/tests/common/widgets/chat/options.tsx b/tests/common/widgets/chat/options.tsx index 7e36f6a9a9..71e4f73a07 100644 --- a/tests/common/widgets/chat/options.tsx +++ b/tests/common/widgets/chat/options.tsx @@ -385,7 +385,7 @@ export function createOptionsTests( ).toBeInTheDocument(); }); - test('shows loader during streaming when last part is a tool without output', async () => { + test('shows loader during streaming when last part is a tool with streaming input', async () => { const searchClient = createSearchClient(); const chat = new Chat({}); @@ -415,7 +415,7 @@ export function createOptionsTests( role: 'assistant', parts: [ { - type: `tool-${SearchIndexToolType}`, + type: 'tool-some_tool', toolCallId: '1', state: 'input-streaming', input: undefined, @@ -432,7 +432,54 @@ export function createOptionsTests( ).toBeInTheDocument(); }); - test('shows loader during streaming when last part is a tool with output', async () => { + test('shows loader during streaming when last part is a tool with input available', async () => { + const searchClient = createSearchClient(); + const chat = new Chat({}); + + await setup({ + instantSearchOptions: { + indexName: 'indexName', + searchClient, + }, + widgetParams: { + javascript: createDefaultWidgetParams(chat), + react: createDefaultWidgetParams(chat), + vue: {}, + }, + }); + + await openChat(act); + + await act(async () => { + chat._state.messages = [ + { + id: '1', + role: 'user', + parts: [{ type: 'text', text: 'Hello' }], + }, + { + id: '2', + role: 'assistant', + parts: [ + { + type: `tool-${SearchIndexToolType}`, + toolCallId: '1', + state: 'input-available', + input: { query: 'shoes' }, + }, + ], + }, + ] as any; + chat._state.status = 'streaming'; + await wait(0); + }); + + expect( + document.querySelector('.ais-ChatMessageLoader') + ).toBeInTheDocument(); + }); + + test('does not show loader during streaming when last part is a tool with output', async () => { const searchClient = createSearchClient(); const chat = new Chat({}); @@ -478,7 +525,7 @@ export function createOptionsTests( expect( document.querySelector('.ais-ChatMessageLoader') - ).toBeInTheDocument(); + ).not.toBeInTheDocument(); }); test('does not show loader during streaming when last part is text', async () => { @@ -986,6 +1033,160 @@ export function createOptionsTests( ); }); + test('passes loaderComponent to tool layout components', async () => { + const searchClient = createSearchClient(); + + const chat = new Chat({ + messages: [ + { + id: '1', + role: 'assistant', + parts: [ + { + type: 'tool-hello', + toolCallId: '1', + input: { text: 'hello' }, + state: 'output-available', + output: 'hello', + }, + ], + }, + ], + id: 'chat-id', + }); + + await setup({ + instantSearchOptions: { + indexName: 'indexName', + searchClient, + }, + widgetParams: { + javascript: { + ...createDefaultWidgetParams(chat), + tools: { + hello: { + templates: { + layout: ({ loader }: any, { html }: any) => + html`
+ ${loader ? loader() : null} +
`, + }, + }, + }, + }, + react: { + ...createDefaultWidgetParams(chat), + tools: { + hello: { + layoutComponent: ({ + loaderComponent: Loader, + }: { + loaderComponent?: (props: any) => React.ReactElement; + }) => ( +
+ {Loader ? : null} +
+ ), + }, + }, + }, + vue: {}, + }, + }); + + await openChat(act); + + expect( + document.querySelector('#tool-with-loader') + ).toBeInTheDocument(); + expect( + document.querySelector( + '#tool-with-loader .ais-ChatMessageLoader' + ) + ).toBeInTheDocument(); + }); + + test('loaderComponent renders custom loader when messagesLoaderComponent is provided', async () => { + const searchClient = createSearchClient(); + + const chat = new Chat({ + messages: [ + { + id: '1', + role: 'assistant', + parts: [ + { + type: 'tool-hello', + toolCallId: '1', + input: { text: 'hello' }, + state: 'output-available', + output: 'hello', + }, + ], + }, + ], + id: 'chat-id', + }); + + const CustomLoader = () => ( +
Loading...
+ ); + + await setup({ + instantSearchOptions: { + indexName: 'indexName', + searchClient, + }, + widgetParams: { + javascript: { + ...createDefaultWidgetParams(chat), + templates: { + messages: { + loader: '
Loading...
', + }, + }, + tools: { + hello: { + templates: { + layout: ({ loader }: any, { html }: any) => + html`
+ ${loader ? loader() : null} +
`, + }, + }, + }, + }, + react: { + ...createDefaultWidgetParams(chat), + messagesLoaderComponent: CustomLoader, + tools: { + hello: { + layoutComponent: ({ + loaderComponent: Loader, + }: { + loaderComponent?: (props: any) => React.ReactElement; + }) => ( +
+ {Loader ? : null} +
+ ), + }, + }, + }, + vue: {}, + }, + }); + + await openChat(act); + + expect( + document.querySelector('#tool-custom-loader') + ).toBeInTheDocument(); + expect( + document.querySelector('#tool-custom-loader .custom-loader') + ).toBeInTheDocument(); + }); + test('shows actions for assistant messages when status is ready', async () => { const searchClient = createSearchClient();