diff --git a/src/uipath_langchain/agent/react/agent.py b/src/uipath_langchain/agent/react/agent.py index 391bd74a8..5f6511779 100644 --- a/src/uipath_langchain/agent/react/agent.py +++ b/src/uipath_langchain/agent/react/agent.py @@ -9,6 +9,9 @@ from uipath.platform.context_grounding import DeepRagContent from uipath.platform.guardrails import BaseGuardrail +from uipath_langchain.agent.tools.client_side_tool import ClientSideToolInfo +from uipath_langchain.chat.hitl import IS_CONVERSATIONAL_CLIENT_SIDE_TOOL + from ...runtime._citations import cas_deep_rag_citation_wrapper from ..guardrails.actions import GuardrailAction from ..tools.structured_tool_with_output_type import StructuredToolWithOutputType @@ -77,7 +80,24 @@ def create_agent( ) llm_tools: list[BaseTool] = [*agent_tools, *flow_control_tools] - init_node = create_init_node(messages, input_schema, config.is_conversational) + # Derive client-side tool schemas from tools for input validation in the init node. + cs_tools: dict[str, ClientSideToolInfo] | None = None + if config.is_conversational: + cs_tools = {} + for t in agent_tools: + meta = getattr(t, "metadata", None) or {} + if meta.get(IS_CONVERSATIONAL_CLIENT_SIDE_TOOL): + cs_tools[t.name] = { + "input_schema": t.args_schema.model_json_schema() + if hasattr(t, "args_schema") and t.args_schema + else None, + "output_schema": meta.get("output_schema"), + } + cs_tools = cs_tools or None + + init_node = create_init_node( + messages, input_schema, config.is_conversational, cs_tools + ) tool_nodes = create_tool_node(agent_tools) diff --git a/src/uipath_langchain/agent/react/init_node.py b/src/uipath_langchain/agent/react/init_node.py index 2b8a9c77b..4a21b06ae 100644 --- a/src/uipath_langchain/agent/react/init_node.py +++ b/src/uipath_langchain/agent/react/init_node.py @@ -6,6 +6,13 @@ from langgraph.types import Overwrite from pydantic import BaseModel +from uipath_langchain.agent.tools.client_side_tool import ( + UIPATH_CLIENT_SIDE_TOOLS_INPUT_KEY, + ClientSideToolInfo, + available_client_side_tools, + validate_and_apply_tool_filter, +) + from .job_attachments import ( get_job_attachments, parse_attachments_from_conversation_messages, @@ -17,6 +24,7 @@ def create_init_node( | Callable[[Any], Sequence[SystemMessage | HumanMessage]], input_schema: type[BaseModel] | None, is_conversational: bool = False, + client_side_tools: dict[str, ClientSideToolInfo] | None = None, ): def graph_state_init(state: Any) -> Any: resolved_messages: Sequence[SystemMessage | HumanMessage] | Overwrite @@ -63,6 +71,19 @@ def graph_state_init(state: Any) -> Any: ) job_attachments_dict.update(message_attachments) + # Validate client-side tool declarations from the exchange input + if client_side_tools: + client_tools_input = getattr(state, UIPATH_CLIENT_SIDE_TOOLS_INPUT_KEY, None) + if client_tools_input is None: + available_client_side_tools.set(None) + elif not isinstance(client_tools_input, list): + raise ValueError( + f"'{UIPATH_CLIENT_SIDE_TOOLS_INPUT_KEY}' must be a list of tool declarations, " + f"got {type(client_tools_input).__name__}." + ) + else: + validate_and_apply_tool_filter(client_tools_input, client_side_tools) + # Calculate initial message count for tracking new messages initial_message_count = ( len(resolved_messages.value) diff --git a/src/uipath_langchain/agent/tools/client_side_tool.py b/src/uipath_langchain/agent/tools/client_side_tool.py new file mode 100644 index 000000000..a56d4bcc5 --- /dev/null +++ b/src/uipath_langchain/agent/tools/client_side_tool.py @@ -0,0 +1,174 @@ +"""Factory for creating client-side tools that execute on the client SDK.""" + +import json +from contextvars import ContextVar +from typing import Annotated, Any, TypedDict + +from langchain_core.messages import ToolMessage +from langchain_core.tools import InjectedToolCallId, StructuredTool +from uipath.agent.models.agent import AgentClientSideToolResourceConfig +from uipath.eval.mocks import mockable + +from uipath_langchain._utils.durable_interrupt import durable_interrupt +from uipath_langchain.agent.react.jsonschema_pydantic_converter import ( + create_model as create_model_from_schema, +) +from uipath_langchain.chat.hitl import IS_CONVERSATIONAL_CLIENT_SIDE_TOOL + +from .utils import sanitize_tool_name + +# When set, only tools in this set are available for the current exchange. +# None means all client-side tools are available (default for CAS/web UI). +available_client_side_tools: ContextVar[set[str] | None] = ContextVar( + "available_client_side_tools", default=None +) + +UIPATH_CLIENT_SIDE_TOOLS_INPUT_KEY = "uipath__client_side_tools" + + +class ClientSideToolInfo(TypedDict): + input_schema: dict[str, Any] | None + output_schema: dict[str, Any] | None + + +def validate_and_apply_tool_filter( + declared_tools: list[dict[str, Any]], + agent_tools: dict[str, ClientSideToolInfo], +) -> None: + """Validate client-side tool declarations and set the availability filter. + + Compares the client's declared tools against the agent's tool definitions. + Raises ValueError if required tools are missing or schemas don't match. + Sets the available_client_side_tools context variable for tool functions. + + Args: + declared_tools: List of tool declarations from uipath__client_side_tools input. + Each item is a dict with 'name' and optional 'inputSchema'/'outputSchema'. + agent_tools: The agent's client-side tools. + Dict of {tool_name: ClientSideToolInfo}. + """ + declared: dict[str, dict[str, Any]] = {} + for i, t in enumerate(declared_tools): + if isinstance(t, dict): + if "name" not in t: + raise ValueError( + f"Client-side tool declaration at index {i} is missing required 'name' field." + ) + name = t["name"] + elif isinstance(t, str): + name = t + t = {"name": t} + else: + raise ValueError( + f"Client-side tool declaration at index {i} must be a dict or string, got {type(t).__name__}." + ) + if name in declared: + raise ValueError( + f"Duplicate client-side tool declaration: '{name}'." + ) + declared[name] = t + + required = set(agent_tools.keys()) + missing = required - set(declared.keys()) + if missing: + raise ValueError( + f"Missing required client-side tools: {', '.join(sorted(missing))}. " + f"The client must register handlers for all client-side tools defined by the agent." + ) + + for name, decl in declared.items(): + agent_tool = agent_tools.get(name) + if agent_tool is None: + continue # Unknown tool, runtime will ignore it + if decl.get("inputSchema") and agent_tool.get("input_schema"): + if json.dumps(decl["inputSchema"], sort_keys=True) != json.dumps( + agent_tool["input_schema"], sort_keys=True + ): + raise ValueError( + f"Client-side tool '{name}' inputSchema does not match agent definition." + ) + if decl.get("outputSchema") and agent_tool.get("output_schema"): + if json.dumps(decl["outputSchema"], sort_keys=True) != json.dumps( + agent_tool["output_schema"], sort_keys=True + ): + raise ValueError( + f"Client-side tool '{name}' outputSchema does not match agent definition." + ) + + available_client_side_tools.set(set(declared.keys())) + + +def create_client_side_tool( + resource: AgentClientSideToolResourceConfig, +) -> StructuredTool: + """Create a client-side tool that pauses the graph and waits for the client to execute it. + + The tool uses @durable_interrupt to suspend the graph. The client SDK receives + an executingToolCall event, runs its registered handler, and sends endToolCall + back through CAS. The bridge routes that endToolCall to wait_for_resume(), + which unblocks the graph with the client's result. + """ + tool_name = sanitize_tool_name(resource.name) + input_model = create_model_from_schema(resource.input_schema) + + async def client_side_tool_fn( + *, tool_call_id: Annotated[str, InjectedToolCallId], **kwargs: Any + ) -> Any: + allowed = available_client_side_tools.get() + if allowed is not None and tool_name not in allowed: + return ToolMessage( + content=f"Tool '{tool_name}' is not available — the client has not registered a handler for it.", + tool_call_id=tool_call_id, + status="error", + ) + + @mockable( + name=resource.name, + description=resource.description, + input_schema=input_model.model_json_schema(), + output_schema=(resource.output_schema or {}), + example_calls=getattr(resource.properties, "example_calls", None), + ) + async def execute_tool() -> dict[str, Any]: + """Execute client-side tool, pausing for client response.""" + + @durable_interrupt + async def wait_for_client_execution() -> dict[str, Any]: + return { + "tool_call_id": tool_call_id, + "tool_name": tool_name, + "input": kwargs, + "is_execution_phase": True, + } + + result = await wait_for_client_execution() + return result.get("output", result) if isinstance(result, dict) else result + + result = await execute_tool() + + if isinstance(result, dict): + try: + content = json.dumps(result) + except TypeError: + content = str(result) + else: + content = str(result) if result is not None else "" + + return ToolMessage( + content=content, + tool_call_id=tool_call_id, + response_metadata={IS_CONVERSATIONAL_CLIENT_SIDE_TOOL: True}, + ) + + tool = StructuredTool( + name=tool_name, + description=resource.description or f"Client-side tool: {tool_name}", + args_schema=input_model, + coroutine=client_side_tool_fn, + metadata={ + IS_CONVERSATIONAL_CLIENT_SIDE_TOOL: True, + "output_schema": resource.output_schema, + }, + ) + + return tool diff --git a/src/uipath_langchain/agent/tools/tool_factory.py b/src/uipath_langchain/agent/tools/tool_factory.py index 0cbb0135e..f6a7fb4b7 100644 --- a/src/uipath_langchain/agent/tools/tool_factory.py +++ b/src/uipath_langchain/agent/tools/tool_factory.py @@ -5,6 +5,7 @@ from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool from uipath.agent.models.agent import ( + AgentClientSideToolResourceConfig, AgentContextResourceConfig, AgentEscalationResourceConfig, AgentIntegrationToolResourceConfig, @@ -18,6 +19,7 @@ from uipath_langchain.chat.hitl import REQUIRE_CONVERSATIONAL_CONFIRMATION +from .client_side_tool import create_client_side_tool from .context_tool import create_context_tool from .escalation_tool import create_escalation_tool from .extraction_tool import create_ixp_extraction_tool @@ -120,4 +122,7 @@ async def _build_tool_for_resource( elif isinstance(resource, AgentIxpVsEscalationResourceConfig): return create_ixp_escalation_tool(resource) + elif isinstance(resource, AgentClientSideToolResourceConfig): + return create_client_side_tool(resource) + return None diff --git a/src/uipath_langchain/agent/tools/tool_node.py b/src/uipath_langchain/agent/tools/tool_node.py index 88480c5a3..f28a3b28f 100644 --- a/src/uipath_langchain/agent/tools/tool_node.py +++ b/src/uipath_langchain/agent/tools/tool_node.py @@ -23,6 +23,7 @@ find_latest_ai_message, ) from uipath_langchain.chat.hitl import ( + IS_CONVERSATIONAL_CLIENT_SIDE_TOOL, REQUIRE_CONVERSATIONAL_CONFIRMATION, request_conversational_tool_confirmation, ) @@ -279,10 +280,13 @@ async def _afunc(state: AgentGraphState) -> OutputType: tool = getattr(tool_node, "tool", None) - # Preserve tool ref so the runtime can discover which tools need confirmation - # (see runtime.py _get_tool_confirmation_info) + # Preserve tool ref so the runtime can discover tool metadata + # (confirmation requirements, client-side markers, etc.) metadata = getattr(tool, "metadata", None) or {} - if isinstance(tool, BaseTool) and metadata.get(REQUIRE_CONVERSATIONAL_CONFIRMATION): + if isinstance(tool, BaseTool) and ( + metadata.get(REQUIRE_CONVERSATIONAL_CONFIRMATION) + or metadata.get(IS_CONVERSATIONAL_CLIENT_SIDE_TOOL) + ): return RunnableCallableWithTool( func=_func, afunc=_afunc, name=tool_name, tool=tool ) diff --git a/src/uipath_langchain/chat/hitl.py b/src/uipath_langchain/chat/hitl.py index 72a99800e..198f65d35 100644 --- a/src/uipath_langchain/chat/hitl.py +++ b/src/uipath_langchain/chat/hitl.py @@ -14,6 +14,7 @@ CANCELLED_MESSAGE = "Cancelled by user" ARGS_MODIFIED_MESSAGE = "User has modified the tool arguments" +IS_CONVERSATIONAL_CLIENT_SIDE_TOOL = "uipath_client_tool" CONVERSATIONAL_APPROVED_TOOL_ARGS = "conversational_approved_tool_args" REQUIRE_CONVERSATIONAL_CONFIRMATION = "require_conversational_confirmation" @@ -132,6 +133,7 @@ def ask_confirmation(): "tool_call_id": tool_call_id, "tool_name": tool.name, "input": tool_args, + "is_execution_phase": False, } response = ask_confirmation() diff --git a/src/uipath_langchain/runtime/messages.py b/src/uipath_langchain/runtime/messages.py index 5d7d63aa8..2534de455 100644 --- a/src/uipath_langchain/runtime/messages.py +++ b/src/uipath_langchain/runtime/messages.py @@ -24,6 +24,7 @@ UiPathConversationContentPartEndEvent, UiPathConversationContentPartEvent, UiPathConversationContentPartStartEvent, + UiPathConversationExecutingToolCallEvent, UiPathConversationMessage, UiPathConversationMessageData, UiPathConversationMessageEndEvent, @@ -39,6 +40,9 @@ ) from uipath.runtime import UiPathRuntimeStorageProtocol +from uipath_langchain.agent.tools.client_side_tool import ClientSideToolInfo +from uipath_langchain.chat.hitl import IS_CONVERSATIONAL_CLIENT_SIDE_TOOL + from ._citations import CitationStreamProcessor, extract_citations_from_text logger = logging.getLogger(__name__) @@ -60,6 +64,7 @@ def __init__(self, runtime_id: str, storage: UiPathRuntimeStorageProtocol | None self.storage = storage self.current_message: AIMessageChunk | AIMessage self.tools_requiring_confirmation: dict[str, Any] = {} + self.client_side_tools: dict[str, ClientSideToolInfo] = {} self.seen_message_ids: set[str] = set() self._storage_lock = asyncio.Lock() self._citation_stream_processor = CitationStreamProcessor() @@ -436,15 +441,40 @@ async def map_current_message_to_start_tool_call_events(self): tool_name in self.tools_requiring_confirmation ) input_schema = self.tools_requiring_confirmation.get(tool_name) + is_client_side = tool_name in self.client_side_tools + output_schema = ( + self.client_side_tools[tool_name].get("output_schema") + if is_client_side + else None + ) events.append( self.map_tool_call_to_tool_call_start_event( self.current_message.id, tool_call, require_confirmation=require_confirmation or None, input_schema=input_schema, + is_client_side_tool=is_client_side or None, + output_schema=output_schema, ) ) + # Emit executingToolCall from MessageMapper for tools without + # a durable interrupt. Tools with interrupts (client-side, HITL) + # get executingToolCall from the bridge instead. + if not require_confirmation and not is_client_side: + events.append( + UiPathConversationMessageEvent( + message_id=self.current_message.id, + tool_call=UiPathConversationToolCallEvent( + tool_call_id=tool_call["id"], + executing=UiPathConversationExecutingToolCallEvent( + tool_name=tool_call["name"], + input=tool_call["args"], + ), + ), + ) + ) + if self.storage is not None: await self.storage.set_value( self.runtime_id, @@ -476,19 +506,26 @@ async def map_tool_message_to_events( # Keep as string if not valid JSON pass - events = [ - UiPathConversationMessageEvent( - message_id=message_id, - tool_call=UiPathConversationToolCallEvent( - tool_call_id=message.tool_call_id, - end=UiPathConversationToolCallEndEvent( - timestamp=self.get_timestamp(), - output=content_value, - is_error=message.status == "error", + # Suppress endToolCall for client-side tools — the client already has the result (it produced it). + is_client_side = message.response_metadata.get( + IS_CONVERSATIONAL_CLIENT_SIDE_TOOL, False + ) + events: list[UiPathConversationMessageEvent] = [] + + if not is_client_side: + events.append( + UiPathConversationMessageEvent( + message_id=message_id, + tool_call=UiPathConversationToolCallEvent( + tool_call_id=message.tool_call_id, + end=UiPathConversationToolCallEndEvent( + timestamp=self.get_timestamp(), + output=content_value, + is_error=message.status == "error", + ), ), - ), + ) ) - ] if is_last_tool_call: events.append(self.map_to_message_end_event(message_id)) @@ -546,6 +583,8 @@ def map_tool_call_to_tool_call_start_event( *, require_confirmation: bool | None = None, input_schema: Any | None = None, + is_client_side_tool: bool | None = None, + output_schema: Any | None = None, ) -> UiPathConversationMessageEvent: return UiPathConversationMessageEvent( message_id=message_id, @@ -557,6 +596,8 @@ def map_tool_call_to_tool_call_start_event( input=tool_call["args"], require_confirmation=require_confirmation, input_schema=input_schema, + is_client_side_tool=is_client_side_tool, + output_schema=output_schema, ), ), ) diff --git a/src/uipath_langchain/runtime/runtime.py b/src/uipath_langchain/runtime/runtime.py index da8d90918..2baef462e 100644 --- a/src/uipath_langchain/runtime/runtime.py +++ b/src/uipath_langchain/runtime/runtime.py @@ -1,5 +1,6 @@ import logging import os +from collections.abc import Iterator from typing import Any, AsyncGenerator from uuid import uuid4 @@ -30,8 +31,12 @@ ) from uipath.runtime.schema import UiPathRuntimeSchema +from uipath_langchain.agent.tools.client_side_tool import ClientSideToolInfo from uipath_langchain.agent.tools.tool_node import RunnableCallableWithTool -from uipath_langchain.chat.hitl import get_confirmation_schema +from uipath_langchain.chat.hitl import ( + IS_CONVERSATIONAL_CLIENT_SIDE_TOOL, + get_confirmation_schema, +) from uipath_langchain.runtime.errors import LangGraphErrorCode, LangGraphRuntimeError from uipath_langchain.runtime.messages import UiPathChatMessagesMapper from uipath_langchain.runtime.schema import get_entrypoints_schema, get_graph_schema @@ -68,6 +73,7 @@ def __init__( self.callbacks: list[BaseCallbackHandler] = callbacks or [] self.chat = UiPathChatMessagesMapper(self.runtime_id, storage) self.chat.tools_requiring_confirmation = self._get_tool_confirmation_info() + self.chat.client_side_tools = self._get_client_side_tools() self._middleware_node_names: set[str] = self._detect_middleware_nodes() async def execute( @@ -490,38 +496,48 @@ def _detect_middleware_nodes(self) -> set[str]: return middleware_nodes - def _get_tool_confirmation_info(self) -> dict[str, Any]: - """Build {tool_name: input_schema} for tools requiring confirmation. - - Walks compiled graph nodes once at runtime init. This is needed because coded agents - (create_agent) export a compiled graph as the only artifact — there's no side channel - to pass confirmation metadata from the build step to the runtime. - """ - schemas: dict[str, Any] = {} + def _iter_graph_tools(self) -> Iterator[BaseTool]: + """Yield all BaseTool instances from compiled graph nodes.""" for node_spec in self.graph.nodes.values(): bound = getattr(node_spec, "bound", None) if bound is None: continue - # Coded agents: one tool per node - if isinstance(bound, RunnableCallableWithTool): - schema = get_confirmation_schema(bound.tool) - if schema is not None: - schemas[bound.tool.name] = schema + tool = getattr(bound, "tool", None) + if isinstance(tool, BaseTool): + yield tool continue - # Low-code agents: multiple tools in one node tools_by_name = getattr(bound, "tools_by_name", None) if isinstance(tools_by_name, dict): - for tool in tools_by_name.values(): - if not isinstance(tool, BaseTool): - continue - schema = get_confirmation_schema(tool) - if schema is not None: - schemas[tool.name] = schema + for t in tools_by_name.values(): + if isinstance(t, BaseTool): + yield t + def _get_tool_confirmation_info(self) -> dict[str, Any]: + """Build {tool_name: input_schema} for tools requiring confirmation.""" + schemas: dict[str, Any] = {} + for tool in self._iter_graph_tools(): + schema = get_confirmation_schema(tool) + if schema is not None: + schemas[tool.name] = schema return schemas + def _get_client_side_tools(self) -> dict[str, ClientSideToolInfo]: + """Build {tool_name: ClientSideToolInfo} for client-side tools.""" + tools: dict[str, ClientSideToolInfo] = {} + for tool in self._iter_graph_tools(): + metadata = getattr(tool, "metadata", None) or {} + if metadata.get(IS_CONVERSATIONAL_CLIENT_SIDE_TOOL): + input_schema = None + if hasattr(tool, "args_schema") and tool.args_schema: + input_schema = tool.args_schema.model_json_schema() + tools[tool.name] = { + "input_schema": input_schema, + "output_schema": metadata.get("output_schema"), + } + return tools + def _is_middleware_node(self, node_name: str) -> bool: """Check if a node name represents a middleware node.""" return node_name in self._middleware_node_names diff --git a/tests/agent/tools/test_client_side_tool_validation.py b/tests/agent/tools/test_client_side_tool_validation.py new file mode 100644 index 000000000..96fed835a --- /dev/null +++ b/tests/agent/tools/test_client_side_tool_validation.py @@ -0,0 +1,273 @@ +"""Tests for client-side tool validation and filtering logic.""" + +import pytest + +from uipath_langchain.agent.tools.client_side_tool import ( + ClientSideToolInfo, + available_client_side_tools, + validate_and_apply_tool_filter, +) + +AGENT_TOOLS: dict[str, ClientSideToolInfo] = { + "get_weather": { + "input_schema": { + "type": "object", + "properties": {"city": {"type": "string"}}, + }, + "output_schema": { + "type": "object", + "properties": {"temp": {"type": "number"}}, + }, + }, + "show_map": { + "input_schema": None, + "output_schema": None, + }, +} + + +class TestValidateAndApplyToolFilter: + """Tests for validate_and_apply_tool_filter.""" + + def test_valid_declarations_set_filter(self): + declared = [ + {"name": "get_weather"}, + {"name": "show_map"}, + ] + validate_and_apply_tool_filter(declared, AGENT_TOOLS) + + allowed = available_client_side_tools.get() + assert allowed == {"get_weather", "show_map"} + + def test_missing_required_tool_raises(self): + declared = [{"name": "get_weather"}] # missing show_map + + with pytest.raises(ValueError, match="Missing required client-side tools"): + validate_and_apply_tool_filter(declared, AGENT_TOOLS) + + def test_input_schema_mismatch_raises(self): + declared = [ + { + "name": "get_weather", + "inputSchema": { + "type": "object", + "properties": {"location": {"type": "string"}}, + }, + }, + {"name": "show_map"}, + ] + + with pytest.raises(ValueError, match="inputSchema does not match"): + validate_and_apply_tool_filter(declared, AGENT_TOOLS) + + def test_output_schema_mismatch_raises(self): + declared = [ + { + "name": "get_weather", + "outputSchema": { + "type": "object", + "properties": {"temperature": {"type": "string"}}, + }, + }, + {"name": "show_map"}, + ] + + with pytest.raises(ValueError, match="outputSchema does not match"): + validate_and_apply_tool_filter(declared, AGENT_TOOLS) + + def test_unknown_extra_tools_are_ignored(self): + declared = [ + {"name": "get_weather"}, + {"name": "show_map"}, + {"name": "unknown_tool"}, + ] + validate_and_apply_tool_filter(declared, AGENT_TOOLS) + + allowed = available_client_side_tools.get() + assert allowed is not None + assert "unknown_tool" in allowed + assert "get_weather" in allowed + + def test_string_declarations_accepted(self): + declared = ["get_weather", "show_map"] + validate_and_apply_tool_filter(declared, AGENT_TOOLS) + + allowed = available_client_side_tools.get() + assert allowed == {"get_weather", "show_map"} + + def test_missing_name_field_raises(self): + declared = [{"inputSchema": {}}] + + with pytest.raises(ValueError, match="missing required 'name' field"): + validate_and_apply_tool_filter(declared, AGENT_TOOLS) + + def test_invalid_type_raises(self): + declared = [123] + + with pytest.raises(ValueError, match="must be a dict or string"): + validate_and_apply_tool_filter(declared, AGENT_TOOLS) + + def test_duplicate_name_raises(self): + declared = [ + {"name": "get_weather"}, + {"name": "get_weather"}, + {"name": "show_map"}, + ] + + with pytest.raises(ValueError, match="Duplicate client-side tool"): + validate_and_apply_tool_filter(declared, AGENT_TOOLS) + + def test_matching_schemas_pass(self): + declared = [ + { + "name": "get_weather", + "inputSchema": { + "type": "object", + "properties": {"city": {"type": "string"}}, + }, + "outputSchema": { + "type": "object", + "properties": {"temp": {"type": "number"}}, + }, + }, + {"name": "show_map"}, + ] + validate_and_apply_tool_filter(declared, AGENT_TOOLS) + + allowed = available_client_side_tools.get() + assert allowed is not None + assert "get_weather" in allowed + + +class TestToolNotAvailableEnforcement: + """Tests that client_side_tool_fn returns error ToolMessage when tool is filtered out.""" + + def test_tool_not_in_allowed_set_returns_error(self): + token = available_client_side_tools.set({"other_tool"}) + try: + from unittest.mock import AsyncMock, patch + + from uipath.agent.models.agent import AgentClientSideToolResourceConfig + + resource = AgentClientSideToolResourceConfig( + name="my_tool", + description="A test tool", + input_schema={ + "type": "object", + "properties": {"query": {"type": "string"}}, + }, + output_schema=None, + ) + + from uipath_langchain.agent.tools.client_side_tool import ( + create_client_side_tool, + ) + + tool = create_client_side_tool(resource) + + import asyncio + + result = asyncio.get_event_loop().run_until_complete( + tool.coroutine(tool_call_id="tc-1", query="test") + ) + + assert result.status == "error" + assert "not available" in result.content + finally: + available_client_side_tools.reset(token) + + def test_tool_in_allowed_set_proceeds(self): + """When tool IS in the allowed set, it should NOT return an error. + + We can't fully test execution (it would hit durable_interrupt), + but we verify the availability check passes by patching the interrupt. + """ + token = available_client_side_tools.set({"my_tool"}) + try: + from unittest.mock import AsyncMock, patch + + from uipath.agent.models.agent import AgentClientSideToolResourceConfig + + resource = AgentClientSideToolResourceConfig( + name="my_tool", + description="A test tool", + input_schema={ + "type": "object", + "properties": {"query": {"type": "string"}}, + }, + output_schema=None, + ) + + from uipath_langchain.agent.tools.client_side_tool import ( + create_client_side_tool, + ) + + tool = create_client_side_tool(resource) + + import asyncio + + # Patch durable_interrupt to avoid GraphInterrupt + with ( + patch( + "uipath_langchain.agent.tools.client_side_tool.durable_interrupt", + side_effect=lambda fn: fn, + ), + patch( + "uipath_langchain.agent.tools.client_side_tool.mockable", + side_effect=lambda **kw: lambda fn: fn, + ), + ): + # Re-create tool after patching + tool = create_client_side_tool(resource) + result = asyncio.get_event_loop().run_until_complete( + tool.coroutine(tool_call_id="tc-1", query="test") + ) + # Should NOT be an error ToolMessage — it proceeded past the availability check + if hasattr(result, "status"): + assert result.status != "error" + finally: + available_client_side_tools.reset(token) + + def test_none_allowed_set_permits_all(self): + """When available_client_side_tools is None (CAS default), all tools proceed.""" + token = available_client_side_tools.set(None) + try: + from uipath.agent.models.agent import AgentClientSideToolResourceConfig + + resource = AgentClientSideToolResourceConfig( + name="any_tool", + description="A test tool", + input_schema={ + "type": "object", + "properties": {"q": {"type": "string"}}, + }, + output_schema=None, + ) + + from unittest.mock import patch + + from uipath_langchain.agent.tools.client_side_tool import ( + create_client_side_tool, + ) + + with ( + patch( + "uipath_langchain.agent.tools.client_side_tool.durable_interrupt", + side_effect=lambda fn: fn, + ), + patch( + "uipath_langchain.agent.tools.client_side_tool.mockable", + side_effect=lambda **kw: lambda fn: fn, + ), + ): + tool = create_client_side_tool(resource) + + import asyncio + + result = asyncio.get_event_loop().run_until_complete( + tool.coroutine(tool_call_id="tc-1", q="test") + ) + if hasattr(result, "status"): + assert result.status != "error" + finally: + available_client_side_tools.reset(token) diff --git a/tests/runtime/test_chat_message_mapper.py b/tests/runtime/test_chat_message_mapper.py index d2bdee792..0f740c40d 100644 --- a/tests/runtime/test_chat_message_mapper.py +++ b/tests/runtime/test_chat_message_mapper.py @@ -2102,3 +2102,156 @@ async def test_mixed_tools_only_confirmation_has_metadata(self): assert "confirm_tool" in tool_starts assert tool_starts["normal_tool"].require_confirmation is None assert tool_starts["confirm_tool"].require_confirmation is True + + +class TestExecutingToolCallEmission: + """Tests for executingToolCall event emission from MessageMapper.""" + + @pytest.mark.asyncio + async def test_emits_executing_for_normal_tool(self): + """Should emit executingToolCall for a server tool without confirmation or client-side marker.""" + storage = create_mock_storage() + storage.get_value.return_value = {} + mapper = UiPathChatMessagesMapper("test-runtime", storage) + + first_chunk = AIMessageChunk( + content="", + id="msg-1", + tool_calls=[{"id": "tc-1", "name": "server_tool", "args": {"x": 1}}], + ) + await mapper.map_event(first_chunk) + + last_chunk = AIMessageChunk(content="", id="msg-1") + object.__setattr__(last_chunk, "chunk_position", "last") + result = await mapper.map_event(last_chunk) + + assert result is not None + executing_events = [ + e + for e in result + if e.tool_call is not None and e.tool_call.executing is not None + ] + assert len(executing_events) == 1 + assert executing_events[0].tool_call.executing.tool_name == "server_tool" + + @pytest.mark.asyncio + async def test_no_executing_for_confirmation_tool(self): + """Should NOT emit executingToolCall for a tool that requires confirmation.""" + storage = create_mock_storage() + storage.get_value.return_value = {} + mapper = UiPathChatMessagesMapper("test-runtime", storage) + mapper.tools_requiring_confirmation = {"confirm_tool": {}} + + first_chunk = AIMessageChunk( + content="", + id="msg-1", + tool_calls=[{"id": "tc-1", "name": "confirm_tool", "args": {}}], + ) + await mapper.map_event(first_chunk) + + last_chunk = AIMessageChunk(content="", id="msg-1") + object.__setattr__(last_chunk, "chunk_position", "last") + result = await mapper.map_event(last_chunk) + + assert result is not None + executing_events = [ + e + for e in result + if e.tool_call is not None and e.tool_call.executing is not None + ] + assert len(executing_events) == 0 + + @pytest.mark.asyncio + async def test_no_executing_for_client_side_tool(self): + """Should NOT emit executingToolCall for a client-side tool (bridge handles it).""" + storage = create_mock_storage() + storage.get_value.return_value = {} + mapper = UiPathChatMessagesMapper("test-runtime", storage) + mapper.client_side_tools = { + "client_tool": {"input_schema": None, "output_schema": {"type": "object"}} + } + + first_chunk = AIMessageChunk( + content="", + id="msg-1", + tool_calls=[ + {"id": "tc-1", "name": "client_tool", "args": {"title": "Avatar"}} + ], + ) + await mapper.map_event(first_chunk) + + last_chunk = AIMessageChunk(content="", id="msg-1") + object.__setattr__(last_chunk, "chunk_position", "last") + result = await mapper.map_event(last_chunk) + + assert result is not None + executing_events = [ + e + for e in result + if e.tool_call is not None and e.tool_call.executing is not None + ] + assert len(executing_events) == 0 + + +class TestClientSideToolEndSuppression: + """Tests for suppressing endToolCall for client-side tools.""" + + @pytest.mark.asyncio + async def test_client_side_tool_suppresses_end_event(self): + """ToolMessage with IS_CONVERSATIONAL_CLIENT_SIDE_TOOL should NOT emit endToolCall.""" + storage = create_mock_storage() + storage.get_value.return_value = {"tool-1": "msg-123"} + mapper = UiPathChatMessagesMapper("test-runtime", storage) + + tool_msg = ToolMessage( + content='{"rating": 9}', + tool_call_id="tool-1", + response_metadata={"uipath_client_tool": True}, + ) + + result = await mapper.map_event(tool_msg) + + assert result is not None + tool_end_events = [ + e for e in result if e.tool_call is not None and e.tool_call.end is not None + ] + assert len(tool_end_events) == 0 + + @pytest.mark.asyncio + async def test_client_side_tool_still_emits_message_end(self): + """ToolMessage with IS_CONVERSATIONAL_CLIENT_SIDE_TOOL should still emit message end when it's the last tool.""" + storage = create_mock_storage() + storage.get_value.return_value = {"tool-1": "msg-123"} + mapper = UiPathChatMessagesMapper("test-runtime", storage) + + tool_msg = ToolMessage( + content='{"rating": 9}', + tool_call_id="tool-1", + response_metadata={"uipath_client_tool": True}, + ) + + result = await mapper.map_event(tool_msg) + + assert result is not None + message_end_events = [e for e in result if e.end is not None] + assert len(message_end_events) == 1 + + @pytest.mark.asyncio + async def test_normal_tool_emits_end_event(self): + """ToolMessage without IS_CONVERSATIONAL_CLIENT_SIDE_TOOL should emit endToolCall normally.""" + storage = create_mock_storage() + storage.get_value.return_value = {"tool-1": "msg-123"} + mapper = UiPathChatMessagesMapper("test-runtime", storage) + + tool_msg = ToolMessage( + content='{"result": "success"}', + tool_call_id="tool-1", + ) + + result = await mapper.map_event(tool_msg) + + assert result is not None + tool_end_events = [ + e for e in result if e.tool_call is not None and e.tool_call.end is not None + ] + assert len(tool_end_events) == 1 diff --git a/tests/runtime/test_client_side_tool_discovery.py b/tests/runtime/test_client_side_tool_discovery.py new file mode 100644 index 000000000..1fa3008d0 --- /dev/null +++ b/tests/runtime/test_client_side_tool_discovery.py @@ -0,0 +1,96 @@ +"""Tests that _get_client_side_tools discovers client-side tools through RunnableCallableWithTool wrappers. + +Integration guard: if the wrapping pipeline changes and stops preserving the +BaseTool reference for client-side tools, these tests will fail. +""" + +from typing import Any + +from langchain_core.tools import BaseTool +from langgraph.constants import END, START +from langgraph.graph import StateGraph +from pydantic import BaseModel, Field + +from uipath_langchain.agent.tools.tool_node import ( + UiPathToolNode, + wrap_tools_with_error_handling, +) +from uipath_langchain.chat.hitl import IS_CONVERSATIONAL_CLIENT_SIDE_TOOL +from uipath_langchain.runtime.runtime import UiPathLangGraphRuntime + + +class _ClientSideInput(BaseModel): + title: str = Field(description="Movie title") + + +class _ClientSideTool(BaseTool): + name: str = "client_tool" + description: str = "A client-side tool" + args_schema: type[BaseModel] = _ClientSideInput + metadata: dict[str, Any] = { + IS_CONVERSATIONAL_CLIENT_SIDE_TOOL: True, + "output_schema": { + "type": "object", + "properties": {"rating": {"type": "number"}}, + }, + } + + def _run(self, title: str) -> str: + return f"result for {title}" + + +class _NormalTool(BaseTool): + name: str = "normal_tool" + description: str = "A normal server tool" + + def _run(self) -> str: + return "done" + + +class _MinimalState(BaseModel): + value: str = "" + + +def _compile_graph_with_wrapped_tools(tools: list[BaseTool]): + """Build and compile a minimal graph with tools wrapped through the standard pipeline.""" + tool_nodes = {t.name: UiPathToolNode(t) for t in tools} + wrapped = wrap_tools_with_error_handling(tool_nodes) + + builder: StateGraph[_MinimalState] = StateGraph(_MinimalState) + names = list(wrapped.keys()) + for name, node in wrapped.items(): + builder.add_node(name, node) + + builder.add_edge(START, names[0]) + for i in range(len(names) - 1): + builder.add_edge(names[i], names[i + 1]) + builder.add_edge(names[-1], END) + + return builder.compile() + + +class TestClientSideToolDiscovery: + def test_discovers_client_side_tool_through_wrapper(self): + graph = _compile_graph_with_wrapped_tools([_ClientSideTool(), _NormalTool()]) + runtime = UiPathLangGraphRuntime(graph) + + client_tools = runtime.chat.client_side_tools + assert "client_tool" in client_tools + assert "normal_tool" not in client_tools + + def test_schemas_are_preserved(self): + graph = _compile_graph_with_wrapped_tools([_ClientSideTool()]) + runtime = UiPathLangGraphRuntime(graph) + + tool_info = runtime.chat.client_side_tools["client_tool"] + assert tool_info is not None + assert "output_schema" in tool_info + assert "input_schema" in tool_info + assert "rating" in tool_info["output_schema"]["properties"] + assert "title" in tool_info["input_schema"]["properties"] + + def test_empty_when_no_client_side_tools(self): + graph = _compile_graph_with_wrapped_tools([_NormalTool()]) + runtime = UiPathLangGraphRuntime(graph) + + assert runtime.chat.client_side_tools == {}