diff --git a/backend/agents/create_agent_info.py b/backend/agents/create_agent_info.py index ea3ba24e8..ca2f25ee8 100644 --- a/backend/agents/create_agent_info.py +++ b/backend/agents/create_agent_info.py @@ -25,12 +25,12 @@ from database.agent_version_db import query_current_version_no from database.tool_db import search_tools_for_sub_agent from database.model_management_db import get_model_records, get_model_by_model_id +from database.knowledge_db import get_knowledge_name_map_by_index_names from database.client import minio_client from utils.model_name_utils import add_repo_to_name from utils.prompt_template_utils import get_agent_prompt_template from utils.config_utils import tenant_config_manager, get_model_name_from_config from consts.const import LOCAL_MCP_SERVER, MODEL_CONFIG_MAPPING, LANGUAGE, DATA_PROCESS_SERVICE -import re logger = logging.getLogger("create_agent_info") logger.setLevel(logging.DEBUG) @@ -358,11 +358,15 @@ async def create_agent_config( if "KnowledgeBaseSearchTool" == tool.class_name: index_names = tool.params.get("index_names") if index_names: + # Reuse the index_name -> display_name mapping from tool.metadata + # (already computed in create_tool_config_list to avoid redundant DB query) + index_name_to_display_map = tool.metadata.get("index_name_to_display_map", {}) if tool.metadata else {} for index_name in index_names: try: + display_name = index_name_to_display_map.get(index_name, index_name) message = ElasticSearchService().get_summary(index_name=index_name) summary = message.get("summary", "") - knowledge_base_summary += f"**{index_name}**: {summary}\n\n" + knowledge_base_summary += f"**{display_name}**: {summary}\n\n" except Exception as e: logger.warning( f"Failed to get summary for knowledge base {index_name}: {e}") @@ -458,10 +462,24 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int tenant_id=tenant_id, model_name=rerank_model_name ) + # Build display_name to index_name mapping for LLM parameter conversion + # Also build reverse mapping (index_name -> display_name) for knowledge_base_summary + index_names = param_dict.get("index_names", []) + display_name_to_index_map = {} + index_name_to_display_map = {} + if index_names: + knowledge_name_map = get_knowledge_name_map_by_index_names(index_names) + # Reverse the mapping: display_name (knowledge_name) -> index_name + for idx_name, kb_name in knowledge_name_map.items(): + display_name_to_index_map[kb_name] = idx_name + index_name_to_display_map[idx_name] = kb_name + tool_config.metadata = { "vdb_core": get_vector_db_core(), "embedding_model": get_embedding_model(tenant_id=tenant_id), "rerank_model": rerank_model, + "display_name_to_index_map": display_name_to_index_map, + "index_name_to_display_map": index_name_to_display_map, } elif tool_config.class_name in ["DifySearchTool", "DataMateSearchTool"]: rerank = param_dict.get("rerank", False) diff --git a/backend/database/knowledge_db.py b/backend/database/knowledge_db.py index df42e1888..0d13eb9f7 100644 --- a/backend/database/knowledge_db.py +++ b/backend/database/knowledge_db.py @@ -374,3 +374,42 @@ def get_index_name_by_knowledge_name(knowledge_name: str, tenant_id: str) -> str ) except SQLAlchemyError as e: raise e + + +def get_knowledge_name_map_by_index_names(index_names: List[str]) -> Dict[str, str]: + """ + Get a mapping from index_name to knowledge_name (display name) for the given index_names. + Used to build user-friendly knowledge base summaries in prompts. + + Args: + index_names: List of internal index names + + Returns: + Dict[str, str]: Mapping of index_name -> knowledge_name. + If a knowledge base is not found in the database, + the index_name itself is used as the fallback value. + """ + if not index_names: + return {} + + try: + with get_db_session() as session: + result = session.query( + KnowledgeRecord.index_name, + KnowledgeRecord.knowledge_name + ).filter( + KnowledgeRecord.index_name.in_(index_names), + KnowledgeRecord.delete_flag != 'Y' + ).all() + + knowledge_name_map = {} + for row in result: + knowledge_name_map[row.index_name] = row.knowledge_name + + for index_name in index_names: + if index_name not in knowledge_name_map: + knowledge_name_map[index_name] = index_name + + return knowledge_name_map + except SQLAlchemyError as e: + raise e diff --git a/backend/services/tool_configuration_service.py b/backend/services/tool_configuration_service.py index e3a4cfa4f..e77a05643 100644 --- a/backend/services/tool_configuration_service.py +++ b/backend/services/tool_configuration_service.py @@ -34,6 +34,7 @@ search_last_tool_instance_by_tool_id, update_tool_table_from_scan_tool_list, ) +from database.knowledge_db import get_knowledge_name_map_by_index_names from mcpadapt.smolagents_adapter import _sanitize_function_name from services.file_management_service import get_llm_model from services.vectordatabase_service import get_embedding_model, get_rerank_model, get_vector_db_core @@ -712,11 +713,20 @@ def _validate_local_tool( if rerank and rerank_model_name: rerank_model = get_rerank_model(tenant_id=tenant_id, model_name=rerank_model_name) + # Build display_name to index_name mapping for LLM parameter conversion + index_names = instantiation_params.get("index_names", []) + display_name_to_index_map = {} + if index_names: + knowledge_name_map = get_knowledge_name_map_by_index_names(index_names) + for idx_name, kb_name in knowledge_name_map.items(): + display_name_to_index_map[kb_name] = idx_name + params = { **instantiation_params, 'vdb_core': vdb_core, 'embedding_model': embedding_model, 'rerank_model': rerank_model, + 'display_name_to_index_map': display_name_to_index_map, } tool_instance = tool_class(**params) elif tool_name in ["dify_search", "datamate_search"]: diff --git a/sdk/nexent/core/agents/nexent_agent.py b/sdk/nexent/core/agents/nexent_agent.py index 6ba851a02..4b1870ee5 100644 --- a/sdk/nexent/core/agents/nexent_agent.py +++ b/sdk/nexent/core/agents/nexent_agent.py @@ -73,7 +73,7 @@ def create_local_tool(self, tool_config: ToolConfig): # These parameters have exclude=True and cannot be passed to __init__ # due to smolagents.tools.Tool wrapper restrictions filtered_params = {k: v for k, v in params.items() - if k not in ["vdb_core", "embedding_model", "observer", "rerank_model"]} + if k not in ["vdb_core", "embedding_model", "observer", "rerank_model", "display_name_to_index_map"]} # Create instance with only non-excluded parameters tools_obj = tool_class(**filtered_params) # Set excluded parameters directly as attributes after instantiation @@ -85,6 +85,8 @@ def create_local_tool(self, tool_config: ToolConfig): "embedding_model", None) if tool_config.metadata else None tools_obj.rerank_model = tool_config.metadata.get( "rerank_model", None) if tool_config.metadata else None + tools_obj.display_name_to_index_map = tool_config.metadata.get( + "display_name_to_index_map", {}) if tool_config.metadata else {} elif class_name in ["DifySearchTool", "DataMateSearchTool"]: # These parameters have exclude=True and cannot be passed to __init__ filtered_params = {k: v for k, v in params.items() diff --git a/sdk/nexent/core/tools/knowledge_base_search_tool.py b/sdk/nexent/core/tools/knowledge_base_search_tool.py index a8863caaf..e3fb2916c 100644 --- a/sdk/nexent/core/tools/knowledge_base_search_tool.py +++ b/sdk/nexent/core/tools/knowledge_base_search_tool.py @@ -86,12 +86,18 @@ def __init__( description="The rerank model to use", default=None, exclude=True), vdb_core: VectorDatabaseCore = Field( description="Vector database client", default=None, exclude=True), + display_name_to_index_map: dict = Field( + description="Mapping from display_name (knowledge_name) to index_name", + default_factory=dict, exclude=True), ): """Initialize the KBSearchTool. Args: top_k (int, optional): Number of results to return. Defaults to 3. observer (MessageObserver, optional): Message observer instance. Defaults to None. + display_name_to_index_map (dict, optional): Mapping from display_name to index_name. + When LLM passes display_name as index_names parameter, it will be converted + to the actual index_name for ES queries. Raises: ValueError: If language is not supported @@ -106,16 +112,49 @@ def __init__( self.rerank = rerank self.rerank_model_name = rerank_model_name self.rerank_model = rerank_model + self.display_name_to_index_map = display_name_to_index_map self.record_ops = 1 # To record serial number self.running_prompt_zh = "知识库检索中..." self.running_prompt_en = "Searching the knowledge base..." + def _convert_to_index_names(self, names: List[str]) -> List[str]: + """Convert display names (knowledge_name) to index names if necessary. + + When LLM passes display_name as the index_names parameter, + this method converts it to the actual index_name for ES queries. + + Args: + names: List of names that could be either display_name or index_name + + Returns: + List of actual index_names for ES queries + """ + display_map = self.display_name_to_index_map + if isinstance(display_map, FieldInfo): + if display_map.default_factory is not None: + display_map = display_map.default_factory() + else: + display_map = display_map.default + if not display_map: + return names + + converted_names = [] + for name in names: + if name in display_map: + converted_names.append(display_map[name]) + else: + converted_names.append(name) + return converted_names + def forward(self, query: str, index_names: Optional[List[str]] = None) -> str: # Parse index_names from string (always required) search_index_names = index_names if index_names is not None else self.index_names + # Convert display names to index names if necessary + search_index_names = self._convert_to_index_names(search_index_names) + # Use the instance search_mode search_mode = self.search_mode @@ -138,9 +177,15 @@ def forward(self, query: str, index_names: Optional[List[str]] = None) -> str: effective_top_k = self.top_k is_rerank = self.rerank if isinstance(effective_top_k, FieldInfo): - effective_top_k = effective_top_k.default + if effective_top_k.default_factory is not None: + effective_top_k = effective_top_k.default_factory() + else: + effective_top_k = effective_top_k.default if isinstance(is_rerank, FieldInfo): - is_rerank = is_rerank.default + if is_rerank.default_factory is not None: + is_rerank = is_rerank.default_factory() + else: + is_rerank = is_rerank.default if is_rerank: effective_top_k = effective_top_k * RERANK_OVERSEARCH_MULTIPLIER diff --git a/test/backend/agents/test_create_agent_info.py b/test/backend/agents/test_create_agent_info.py index a0183d59e..b92f62571 100644 --- a/test/backend/agents/test_create_agent_info.py +++ b/test/backend/agents/test_create_agent_info.py @@ -97,6 +97,8 @@ def _create_stub_module(name: str, **attrs): ) sys.modules['database.a2a_agent_db'] = a2a_agent_db_stub database_module.a2a_agent_db = a2a_agent_db_stub +sys.modules['database.knowledge_db'] = MagicMock() +sys.modules['database.knowledge_db'].get_knowledge_name_map_by_index_names = MagicMock() sys.modules['services.vectordatabase_service'] = MagicMock() sys.modules['services.tenant_config_service'] = MagicMock() sys.modules['utils.prompt_template_utils'] = MagicMock() @@ -741,13 +743,13 @@ async def test_create_tool_config_list_with_knowledge_base_tool_metadata(self): mock_get_vector_db_core.assert_called_once() mock_embedding.assert_called_once_with(tenant_id="tenant_1") - # Verify metadata contains vdb_core, embedding_model and rerank_model - expected_metadata = { - "vdb_core": mock_vdb_core, - "embedding_model": mock_embedding_model, - "rerank_model": mock_rerank.return_value, - } - assert mock_tool_instance.metadata == expected_metadata + # Verify metadata contains vdb_core, embedding_model, rerank_model and display_name_to_index_map + assert "vdb_core" in mock_tool_instance.metadata + assert "embedding_model" in mock_tool_instance.metadata + assert "rerank_model" in mock_tool_instance.metadata + assert "display_name_to_index_map" in mock_tool_instance.metadata + # display_name_to_index_map should be empty dict when index_names is empty + assert mock_tool_instance.metadata["display_name_to_index_map"] == {} # Explicitly verify that old fields are NOT present assert "index_names" not in mock_tool_instance.metadata @@ -808,12 +810,11 @@ async def test_create_tool_config_list_with_knowledge_base_tool_multiple_tools(s assert len(result) == 2 - # Verify KnowledgeBaseSearchTool has correct metadata - assert mock_tool_kb.metadata == { - "vdb_core": "vdb_core_instance", - "embedding_model": "embedding_instance", - "rerank_model": mock_rerank.return_value, - } + # Verify KnowledgeBaseSearchTool has correct metadata including display_name_to_index_map + assert "vdb_core" in mock_tool_kb.metadata + assert "embedding_model" in mock_tool_kb.metadata + assert "rerank_model" in mock_tool_kb.metadata + assert "display_name_to_index_map" in mock_tool_kb.metadata # Verify OtherTool has no special metadata (should not have metadata attribute set) # Note: MagicMock will return a new MagicMock for unset attributes, so we check call_args @@ -861,11 +862,9 @@ async def test_create_tool_config_list_with_knowledge_base_tool_mixed_sources(se assert len(result) == 1 # Even for MCP-sourced KnowledgeBaseSearchTool, metadata should be set - assert mock_tool_instance.metadata == { - "vdb_core": "vdb_core", - "embedding_model": "embedding", - "rerank_model": mock_rerank.return_value, - } + assert "vdb_core" in mock_tool_instance.metadata + assert "embedding_model" in mock_tool_instance.metadata + assert "display_name_to_index_map" in mock_tool_instance.metadata @pytest.mark.asyncio async def test_create_tool_config_list_with_datamate_tool(self): @@ -1010,14 +1009,13 @@ async def test_create_tool_config_list_multiple_tools_same_type(self): assert len(result) == 2 - # Both tools should have the same simplified metadata - expected_metadata = { - "vdb_core": "vdb_core", - "embedding_model": "embedding", - "rerank_model": mock_rerank.return_value, - } - assert mock_tool_1.metadata == expected_metadata - assert mock_tool_2.metadata == expected_metadata + # Both tools should have the same metadata including display_name_to_index_map + assert "vdb_core" in mock_tool_1.metadata + assert "embedding_model" in mock_tool_1.metadata + assert "rerank_model" in mock_tool_1.metadata + assert "display_name_to_index_map" in mock_tool_1.metadata + assert mock_tool_1.metadata["display_name_to_index_map"] == {} + assert mock_tool_2.metadata["display_name_to_index_map"] == {} @pytest.mark.asyncio async def test_create_tool_config_list_with_dify_tool(self): @@ -1929,6 +1927,9 @@ async def test_create_agent_config_with_knowledge_base_summary_filtering(self): kb_tool_1.class_name = "KnowledgeBaseSearchTool" kb_tool_1.name = "kb_tool_1" kb_tool_1.params = {"index_names": ["idx_a", "idx_b"]} + kb_tool_1.metadata = { + "index_name_to_display_map": {"idx_a": "idx_a", "idx_b": "idx_b"} + } other_tool = Mock() other_tool.class_name = "OtherTool" @@ -1939,6 +1940,9 @@ async def test_create_agent_config_with_knowledge_base_summary_filtering(self): kb_tool_2.class_name = "KnowledgeBaseSearchTool" kb_tool_2.name = "kb_tool_2" kb_tool_2.params = {"index_names": ["idx_c"]} + kb_tool_2.metadata = { + "index_name_to_display_map": {"idx_c": "idx_c"} + } mock_create_tools.return_value = [kb_tool_1, other_tool, kb_tool_2] mock_get_template.return_value = {"system_prompt": "{{ knowledge_base_summary }}"} @@ -1977,6 +1981,214 @@ async def test_create_agent_config_with_knowledge_base_summary_filtering(self): # Ensure only the first KnowledgeBaseSearchTool is processed. assert "idx_c" not in str(mock_es_instance.get_summary.call_args_list) + @pytest.mark.asyncio + async def test_create_agent_config_uses_metadata_index_name_to_display_map(self): + """Test that create_agent_config uses index_name_to_display_map from tool.metadata. + + This test verifies the refactored behavior where create_agent_config + reuses the index_name -> display_name mapping from tool.metadata instead of + making redundant database queries. + """ + with ( + patch( + "backend.agents.create_agent_info.search_agent_info_by_agent_id" + ) as mock_search_agent, + patch( + "backend.agents.create_agent_info.query_sub_agents_id_list" + ) as mock_query_sub, + patch( + "backend.agents.create_agent_info.create_tool_config_list" + ) as mock_create_tools, + patch( + "backend.agents.create_agent_info.get_agent_prompt_template" + ) as mock_get_template, + patch( + "backend.agents.create_agent_info.tenant_config_manager" + ) as mock_tenant_config, + patch( + "backend.agents.create_agent_info.build_memory_context" + ) as mock_build_memory, + patch( + "backend.agents.create_agent_info.ElasticSearchService" + ) as mock_es_service, + patch( + "backend.agents.create_agent_info.prepare_prompt_templates" + ) as mock_prepare_templates, + patch( + "backend.agents.create_agent_info.get_model_by_model_id" + ) as mock_get_model_by_id, + patch( + "backend.agents.create_agent_info._get_skills_for_template" + ) as mock_get_skills, + patch( + "backend.agents.create_agent_info._get_skill_script_tools" + ) as mock_get_skill_tools, + patch( + "backend.agents.create_agent_info.get_knowledge_name_map_by_index_names" + ) as mock_get_knowledge_name_map, + ): + mock_search_agent.return_value = { + "name": "test_agent", + "description": "test description", + "duty_prompt": "test duty", + "constraint_prompt": "test constraint", + "few_shots_prompt": "test few shots", + "max_steps": 5, + "model_id": 123, + "provide_run_summary": True, + } + mock_query_sub.return_value = [] + + # Create a tool with index_name_to_display_map in metadata + kb_tool = Mock() + kb_tool.class_name = "KnowledgeBaseSearchTool" + kb_tool.name = "kb_tool" + kb_tool.params = {"index_names": ["idx1", "idx2"]} + # The tool.metadata contains the index_name -> display_name mapping + kb_tool.metadata = { + "index_name_to_display_map": { + "idx1": "Custom Name 1", + "idx2": "Custom Name 2" + } + } + + mock_create_tools.return_value = [kb_tool] + mock_get_template.return_value = {"system_prompt": "{{ knowledge_base_summary }}"} + mock_tenant_config.get_app_config.side_effect = ["TestApp", "Test Description"] + mock_build_memory.return_value = Mock( + user_config=Mock(memory_switch=False), + memory_config={}, + tenant_id="tenant_1", + user_id="user_1", + agent_id="agent_1", + ) + mock_prepare_templates.return_value = {"system_prompt": "populated_system_prompt"} + mock_get_model_by_id.return_value = {"display_name": "test_model"} + mock_get_skills.return_value = [] + mock_get_skill_tools.return_value = [] + # This should NOT be called when tool.metadata has index_name_to_display_map + mock_get_knowledge_name_map.return_value = {"idx1": "idx1", "idx2": "idx2"} + + mock_es_instance = Mock() + mock_es_instance.get_summary.side_effect = [ + {"summary": "Summary 1"}, + {"summary": "Summary 2"}, + ] + mock_es_service.return_value = mock_es_instance + + await create_agent_config("agent_1", "tenant_1", "user_1", "zh", "test query") + + # Verify ElasticSearchService was called for both indices + assert mock_es_instance.get_summary.call_count == 2 + + # Verify get_knowledge_name_map_by_index_names was NOT called + # because we're using the mapping from tool.metadata + mock_get_knowledge_name_map.assert_not_called() + + # Verify the system prompt uses the display names from metadata + mock_prepare_templates.assert_called_once() + system_prompt = mock_prepare_templates.call_args[1]["system_prompt"] + assert "**Custom Name 1**" in system_prompt + assert "**Custom Name 2**" in system_prompt + assert "idx1" not in system_prompt + assert "idx2" not in system_prompt + + @pytest.mark.asyncio + async def test_create_agent_config_metadata_without_index_name_to_display_map(self): + """Test that create_agent_config handles missing index_name_to_display_map gracefully. + + When tool.metadata exists but doesn't have index_name_to_display_map, + it should fall back to using index_name as display_name. + """ + with ( + patch( + "backend.agents.create_agent_info.search_agent_info_by_agent_id" + ) as mock_search_agent, + patch( + "backend.agents.create_agent_info.query_sub_agents_id_list" + ) as mock_query_sub, + patch( + "backend.agents.create_agent_info.create_tool_config_list" + ) as mock_create_tools, + patch( + "backend.agents.create_agent_info.get_agent_prompt_template" + ) as mock_get_template, + patch( + "backend.agents.create_agent_info.tenant_config_manager" + ) as mock_tenant_config, + patch( + "backend.agents.create_agent_info.build_memory_context" + ) as mock_build_memory, + patch( + "backend.agents.create_agent_info.ElasticSearchService" + ) as mock_es_service, + patch( + "backend.agents.create_agent_info.prepare_prompt_templates" + ) as mock_prepare_templates, + patch( + "backend.agents.create_agent_info.get_model_by_model_id" + ) as mock_get_model_by_id, + patch( + "backend.agents.create_agent_info._get_skills_for_template" + ) as mock_get_skills, + patch( + "backend.agents.create_agent_info._get_skill_script_tools" + ) as mock_get_skill_tools, + patch( + "backend.agents.create_agent_info.get_knowledge_name_map_by_index_names" + ) as mock_get_knowledge_name_map, + ): + mock_search_agent.return_value = { + "name": "test_agent", + "description": "test description", + "duty_prompt": "test duty", + "constraint_prompt": "test constraint", + "few_shots_prompt": "test few shots", + "max_steps": 5, + "model_id": 123, + "provide_run_summary": True, + } + mock_query_sub.return_value = [] + + # Create a tool with empty metadata (no index_name_to_display_map) + kb_tool = Mock() + kb_tool.class_name = "KnowledgeBaseSearchTool" + kb_tool.name = "kb_tool" + kb_tool.params = {"index_names": ["idx1", "idx2"]} + kb_tool.metadata = {} # Empty metadata + + mock_create_tools.return_value = [kb_tool] + mock_get_template.return_value = {"system_prompt": "{{ knowledge_base_summary }}"} + mock_tenant_config.get_app_config.side_effect = ["TestApp", "Test Description"] + mock_build_memory.return_value = Mock( + user_config=Mock(memory_switch=False), + memory_config={}, + tenant_id="tenant_1", + user_id="user_1", + agent_id="agent_1", + ) + mock_prepare_templates.return_value = {"system_prompt": "populated_system_prompt"} + mock_get_model_by_id.return_value = {"display_name": "test_model"} + mock_get_skills.return_value = [] + mock_get_skill_tools.return_value = [] + mock_get_knowledge_name_map.return_value = {} + + mock_es_instance = Mock() + mock_es_instance.get_summary.side_effect = [ + {"summary": "Summary 1"}, + {"summary": "Summary 2"}, + ] + mock_es_service.return_value = mock_es_instance + + await create_agent_config("agent_1", "tenant_1", "user_1", "zh", "test query") + + # When metadata is empty, it should fall back to using index_name + # as the display_name (no mapping available) + mock_prepare_templates.assert_called_once() + system_prompt = mock_prepare_templates.call_args[1]["system_prompt"] + assert "**idx1**" in system_prompt + assert "**idx2**" in system_prompt + @pytest.mark.parametrize( "language,expected_message", [ @@ -3308,5 +3520,366 @@ def test_get_external_a2a_agents_exception_handling(self): assert "Database error" in mock_logger.error.call_args[0][0] +class TestCreateToolConfigListWithDisplayNameMap: + """Tests for create_tool_config_list with display_name_to_index_map functionality""" + + @pytest.mark.asyncio + async def test_knowledge_base_with_display_name_to_index_map(self): + """Test that KnowledgeBaseSearchTool gets correct display_name_to_index_map from index_names""" + mock_tool_instance = MagicMock() + mock_tool_instance.class_name = "KnowledgeBaseSearchTool" + + with patch('backend.agents.create_agent_info.ToolConfig') as mock_tool_config, \ + patch('backend.agents.create_agent_info.discover_langchain_tools', return_value=[]), \ + patch('backend.agents.create_agent_info.search_tools_for_sub_agent') as mock_search_tools, \ + patch('backend.agents.create_agent_info.get_vector_db_core') as mock_get_vector_db_core, \ + patch('backend.agents.create_agent_info.get_embedding_model') as mock_embedding, \ + patch('backend.agents.create_agent_info.get_rerank_model') as mock_rerank, \ + patch('backend.agents.create_agent_info.get_knowledge_name_map_by_index_names') as mock_get_knowledge_map: + + mock_tool_config.return_value = mock_tool_instance + + mock_search_tools.return_value = [ + { + "class_name": "KnowledgeBaseSearchTool", + "name": "knowledge_search", + "description": "Knowledge search tool", + "inputs": "string", + "output_type": "string", + "params": [ + {"name": "index_names", "default": ["idx1", "idx2"]}, + {"name": "rerank", "default": False}, + ], + "source": "local", + "usage": None + } + ] + mock_get_vector_db_core.return_value = "vdb_core_instance" + mock_embedding.return_value = "embedding_instance" + mock_rerank.return_value = None + # Mock the knowledge name map: index_name -> knowledge_name (display_name) + mock_get_knowledge_map.return_value = { + "idx1": "Knowledge Base 1", + "idx2": "Knowledge Base 2" + } + + result = await create_tool_config_list("agent_1", "tenant_1", "user_1") + + assert len(result) == 1 + # Verify get_knowledge_name_map_by_index_names was called + mock_get_knowledge_map.assert_called_once_with(["idx1", "idx2"]) + # Verify display_name_to_index_map contains reversed mapping + assert result[0].metadata["display_name_to_index_map"] == { + "Knowledge Base 1": "idx1", + "Knowledge Base 2": "idx2" + } + + @pytest.mark.asyncio + async def test_knowledge_base_with_empty_index_names(self): + """Test that KnowledgeBaseSearchTool gets empty display_name_to_index_map when no index_names""" + mock_tool_instance = MagicMock() + mock_tool_instance.class_name = "KnowledgeBaseSearchTool" + + with patch('backend.agents.create_agent_info.ToolConfig') as mock_tool_config, \ + patch('backend.agents.create_agent_info.discover_langchain_tools', return_value=[]), \ + patch('backend.agents.create_agent_info.search_tools_for_sub_agent') as mock_search_tools, \ + patch('backend.agents.create_agent_info.get_vector_db_core') as mock_get_vector_db_core, \ + patch('backend.agents.create_agent_info.get_embedding_model') as mock_embedding, \ + patch('backend.agents.create_agent_info.get_rerank_model') as mock_rerank, \ + patch('backend.agents.create_agent_info.get_knowledge_name_map_by_index_names') as mock_get_knowledge_map: + + mock_tool_config.return_value = mock_tool_instance + + mock_search_tools.return_value = [ + { + "class_name": "KnowledgeBaseSearchTool", + "name": "knowledge_search", + "description": "Knowledge search tool", + "inputs": "string", + "output_type": "string", + "params": [ + {"name": "index_names", "default": []}, + {"name": "rerank", "default": False}, + ], + "source": "local", + "usage": None + } + ] + mock_get_vector_db_core.return_value = "vdb_core_instance" + mock_embedding.return_value = "embedding_instance" + mock_rerank.return_value = None + + result = await create_tool_config_list("agent_1", "tenant_1", "user_1") + + # get_knowledge_name_map_by_index_names should NOT be called with empty index_names + mock_get_knowledge_map.assert_not_called() + assert result[0].metadata["display_name_to_index_map"] == {} + + @pytest.mark.asyncio + async def test_knowledge_base_with_partial_name_mapping(self): + """Test that KnowledgeBaseSearchTool handles partial name mapping correctly""" + mock_tool_instance = MagicMock() + mock_tool_instance.class_name = "KnowledgeBaseSearchTool" + + with patch('backend.agents.create_agent_info.ToolConfig') as mock_tool_config, \ + patch('backend.agents.create_agent_info.discover_langchain_tools', return_value=[]), \ + patch('backend.agents.create_agent_info.search_tools_for_sub_agent') as mock_search_tools, \ + patch('backend.agents.create_agent_info.get_vector_db_core') as mock_get_vector_db_core, \ + patch('backend.agents.create_agent_info.get_embedding_model') as mock_embedding, \ + patch('backend.agents.create_agent_info.get_rerank_model') as mock_rerank, \ + patch('backend.agents.create_agent_info.get_knowledge_name_map_by_index_names') as mock_get_knowledge_map: + + mock_tool_config.return_value = mock_tool_instance + + mock_search_tools.return_value = [ + { + "class_name": "KnowledgeBaseSearchTool", + "name": "knowledge_search", + "description": "Knowledge search tool", + "inputs": "string", + "output_type": "string", + "params": [ + {"name": "index_names", "default": ["idx1", "idx2", "idx3"]}, + {"name": "rerank", "default": False}, + ], + "source": "local", + "usage": None + } + ] + mock_get_vector_db_core.return_value = "vdb_core_instance" + mock_embedding.return_value = "embedding_instance" + mock_rerank.return_value = None + # Only idx1 is found in database, idx2 and idx3 are not found + mock_get_knowledge_map.return_value = { + "idx1": "Knowledge Base 1" + } + + result = await create_tool_config_list("agent_1", "tenant_1", "user_1") + + # display_name_to_index_map should only contain the found mappings + # Unfound indices will use index_name as fallback (which is not in get_knowledge_name_map result) + assert "Knowledge Base 1" in result[0].metadata["display_name_to_index_map"] + + @pytest.mark.asyncio + async def test_knowledge_base_with_index_name_to_display_map(self): + """Test that KnowledgeBaseSearchTool gets correct index_name_to_display_map from index_names. + + This test verifies the reverse mapping (index_name -> display_name) that was added + to avoid redundant database queries when building knowledge_base_summary. + """ + mock_tool_instance = MagicMock() + mock_tool_instance.class_name = "KnowledgeBaseSearchTool" + + with patch('backend.agents.create_agent_info.ToolConfig') as mock_tool_config, \ + patch('backend.agents.create_agent_info.discover_langchain_tools', return_value=[]), \ + patch('backend.agents.create_agent_info.search_tools_for_sub_agent') as mock_search_tools, \ + patch('backend.agents.create_agent_info.get_vector_db_core') as mock_get_vector_db_core, \ + patch('backend.agents.create_agent_info.get_embedding_model') as mock_embedding, \ + patch('backend.agents.create_agent_info.get_rerank_model') as mock_rerank, \ + patch('backend.agents.create_agent_info.get_knowledge_name_map_by_index_names') as mock_get_knowledge_map: + + mock_tool_config.return_value = mock_tool_instance + + mock_search_tools.return_value = [ + { + "class_name": "KnowledgeBaseSearchTool", + "name": "knowledge_search", + "description": "Knowledge search tool", + "inputs": "string", + "output_type": "string", + "params": [ + {"name": "index_names", "default": ["idx1", "idx2"]}, + {"name": "rerank", "default": False}, + ], + "source": "local", + "usage": None + } + ] + mock_get_vector_db_core.return_value = "vdb_core_instance" + mock_embedding.return_value = "embedding_instance" + mock_rerank.return_value = None + # Mock the knowledge name map: index_name -> knowledge_name (display_name) + mock_get_knowledge_map.return_value = { + "idx1": "Knowledge Base 1", + "idx2": "Knowledge Base 2" + } + + result = await create_tool_config_list("agent_1", "tenant_1", "user_1") + + assert len(result) == 1 + # Verify display_name_to_index_map (original mapping) + assert result[0].metadata["display_name_to_index_map"] == { + "Knowledge Base 1": "idx1", + "Knowledge Base 2": "idx2" + } + # Verify index_name_to_display_map (new reverse mapping) + assert result[0].metadata["index_name_to_display_map"] == { + "idx1": "Knowledge Base 1", + "idx2": "Knowledge Base 2" + } + # Both maps should be present + assert "display_name_to_index_map" in result[0].metadata + assert "index_name_to_display_map" in result[0].metadata + + @pytest.mark.asyncio + async def test_knowledge_base_with_partial_index_name_mapping(self): + """Test that KnowledgeBaseSearchTool handles partial index_name_to_display_map correctly. + + When some index_names are not found in the database, they should not be + added to the index_name_to_display_map. + """ + mock_tool_instance = MagicMock() + mock_tool_instance.class_name = "KnowledgeBaseSearchTool" + + with patch('backend.agents.create_agent_info.ToolConfig') as mock_tool_config, \ + patch('backend.agents.create_agent_info.discover_langchain_tools', return_value=[]), \ + patch('backend.agents.create_agent_info.search_tools_for_sub_agent') as mock_search_tools, \ + patch('backend.agents.create_agent_info.get_vector_db_core') as mock_get_vector_db_core, \ + patch('backend.agents.create_agent_info.get_embedding_model') as mock_embedding, \ + patch('backend.agents.create_agent_info.get_rerank_model') as mock_rerank, \ + patch('backend.agents.create_agent_info.get_knowledge_name_map_by_index_names') as mock_get_knowledge_map: + + mock_tool_config.return_value = mock_tool_instance + + mock_search_tools.return_value = [ + { + "class_name": "KnowledgeBaseSearchTool", + "name": "knowledge_search", + "description": "Knowledge search tool", + "inputs": "string", + "output_type": "string", + "params": [ + {"name": "index_names", "default": ["idx1", "idx2", "idx3"]}, + {"name": "rerank", "default": False}, + ], + "source": "local", + "usage": None + } + ] + mock_get_vector_db_core.return_value = "vdb_core_instance" + mock_embedding.return_value = "embedding_instance" + mock_rerank.return_value = None + # Only idx1 and idx2 are found, idx3 is not in the database + mock_get_knowledge_map.return_value = { + "idx1": "Knowledge Base 1", + "idx2": "Knowledge Base 2" + } + + result = await create_tool_config_list("agent_1", "tenant_1", "user_1") + + # Verify both mappings contain only found entries + assert "idx1" in result[0].metadata["index_name_to_display_map"] + assert "idx2" in result[0].metadata["index_name_to_display_map"] + # idx3 was not found, so it should not be in the map + assert "idx3" not in result[0].metadata["index_name_to_display_map"] + + # Verify reverse mapping also contains only found entries + assert "Knowledge Base 1" in result[0].metadata["display_name_to_index_map"] + assert "Knowledge Base 2" in result[0].metadata["display_name_to_index_map"] + assert "idx3" not in result[0].metadata["display_name_to_index_map"] + + +class TestFilterMcpServersAndTools: + """Tests for filter_mcp_servers_and_tools function""" + + def test_filter_mcp_servers_with_multiple_tools(self): + """Test filtering with multiple MCP tools""" + mock_tool1 = MagicMock() + mock_tool1.source = "mcp" + mock_tool1.usage = "server1" + + mock_tool2 = MagicMock() + mock_tool2.source = "local" + mock_tool2.usage = None + + mock_tool3 = MagicMock() + mock_tool3.source = "mcp" + mock_tool3.usage = "server2" + + mock_sub_agent = MagicMock() + mock_sub_agent.tools = [] + mock_sub_agent.managed_agents = [] + + mock_agent_config = MagicMock() + mock_agent_config.tools = [mock_tool1, mock_tool2, mock_tool3] + mock_agent_config.managed_agents = [mock_sub_agent] + + mcp_info_dict = { + "server1": {"remote_mcp_server": "http://server1.example.com"}, + "server2": {"remote_mcp_server": "http://server2.example.com"}, + } + + result = filter_mcp_servers_and_tools(mock_agent_config, mcp_info_dict) + + assert len(result) == 2 + assert "http://server1.example.com" in result + assert "http://server2.example.com" in result + + def test_filter_mcp_servers_with_nested_sub_agents(self): + """Test filtering with nested sub-agents""" + mock_tool1 = MagicMock() + mock_tool1.source = "mcp" + mock_tool1.usage = "nested_server" + + mock_sub_sub_agent = MagicMock() + mock_sub_sub_agent.tools = [mock_tool1] + mock_sub_sub_agent.managed_agents = [] + + mock_sub_agent = MagicMock() + mock_sub_agent.tools = [] + mock_sub_agent.managed_agents = [mock_sub_sub_agent] + + mock_agent_config = MagicMock() + mock_agent_config.tools = [] + mock_agent_config.managed_agents = [mock_sub_agent] + + mcp_info_dict = { + "nested_server": {"remote_mcp_server": "http://nested.example.com"}, + } + + result = filter_mcp_servers_and_tools(mock_agent_config, mcp_info_dict) + + assert len(result) == 1 + assert "http://nested.example.com" in result + + def test_filter_mcp_servers_with_disabled_server(self): + """Test filtering excludes servers not in mcp_info_dict""" + mock_tool1 = MagicMock() + mock_tool1.source = "mcp" + mock_tool1.usage = "enabled_server" + + mock_tool2 = MagicMock() + mock_tool2.source = "mcp" + mock_tool2.usage = "disabled_server" + + mock_agent_config = MagicMock() + mock_agent_config.tools = [mock_tool1, mock_tool2] + mock_agent_config.managed_agents = [] + + mcp_info_dict = { + "enabled_server": {"remote_mcp_server": "http://enabled.example.com"}, + # disabled_server is not in the dict + } + + result = filter_mcp_servers_and_tools(mock_agent_config, mcp_info_dict) + + assert len(result) == 1 + assert "http://enabled.example.com" in result + + def test_filter_mcp_servers_with_empty_tools(self): + """Test filtering with no tools returns empty list""" + mock_agent_config = MagicMock() + mock_agent_config.tools = [] + mock_agent_config.managed_agents = [] + + mcp_info_dict = { + "server1": {"remote_mcp_server": "http://server1.example.com"}, + } + + result = filter_mcp_servers_and_tools(mock_agent_config, mcp_info_dict) + + assert result == [] + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/backend/database/test_knowledge_db.py b/test/backend/database/test_knowledge_db.py index 496e04b19..724a62c68 100644 --- a/test/backend/database/test_knowledge_db.py +++ b/test/backend/database/test_knowledge_db.py @@ -51,7 +51,8 @@ get_index_name_by_knowledge_name, get_knowledge_info_by_tenant_and_source, upsert_knowledge_record, - _generate_index_name + _generate_index_name, + get_knowledge_name_map_by_index_names, ) @@ -1948,3 +1949,140 @@ def mock_exit(exc_type, exc_val, exc_tb): with pytest.raises(MockSQLAlchemyError, match="Database error"): get_knowledge_info_by_tenant_and_source("tenant1", "datamate") + + +def test_get_knowledge_name_map_by_index_names_success(monkeypatch, mock_session): + """Test successfully getting knowledge name map by index names""" + session, query = mock_session + + # Create mock records with index_name and knowledge_name + class MockRow: + def __init__(self, index_name, knowledge_name): + self.index_name = index_name + self.knowledge_name = knowledge_name + + mock_rows = [ + MockRow("index1", "Knowledge Base 1"), + MockRow("index2", "Knowledge Base 2"), + ] + + mock_filter = MagicMock() + mock_filter.all.return_value = mock_rows + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + + def mock_exit(exc_type, exc_val, exc_tb): + if exc_type is not None: + session.rollback() + return None + mock_ctx.__exit__.side_effect = mock_exit + monkeypatch.setattr( + "backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + index_names = ["index1", "index2"] + result = get_knowledge_name_map_by_index_names(index_names) + + expected = { + "index1": "Knowledge Base 1", + "index2": "Knowledge Base 2", + } + assert result == expected + + +def test_get_knowledge_name_map_by_index_names_with_fallback(monkeypatch, mock_session): + """Test get_knowledge_name_map_by_index_names uses index_name as fallback when not found""" + session, query = mock_session + + # Only return one of the two index names + class MockRow: + def __init__(self, index_name, knowledge_name): + self.index_name = index_name + self.knowledge_name = knowledge_name + + mock_rows = [ + MockRow("index1", "Knowledge Base 1"), + # index2 is not found in database + ] + + mock_filter = MagicMock() + mock_filter.all.return_value = mock_rows + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + + def mock_exit(exc_type, exc_val, exc_tb): + if exc_type is not None: + session.rollback() + return None + mock_ctx.__exit__.side_effect = mock_exit + monkeypatch.setattr( + "backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + index_names = ["index1", "index2"] + result = get_knowledge_name_map_by_index_names(index_names) + + expected = { + "index1": "Knowledge Base 1", + "index2": "index2", # Falls back to index_name + } + assert result == expected + + +def test_get_knowledge_name_map_by_index_names_empty_list(monkeypatch): + """Test get_knowledge_name_map_by_index_names with empty list returns empty dict""" + result = get_knowledge_name_map_by_index_names([]) + + assert result == {} + + +def test_get_knowledge_name_map_by_index_names_no_results(monkeypatch, mock_session): + """Test get_knowledge_name_map_by_index_names when no records found""" + session, query = mock_session + + mock_filter = MagicMock() + mock_filter.all.return_value = [] + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + + def mock_exit(exc_type, exc_val, exc_tb): + if exc_type is not None: + session.rollback() + return None + mock_ctx.__exit__.side_effect = mock_exit + monkeypatch.setattr( + "backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + index_names = ["nonexistent1", "nonexistent2"] + result = get_knowledge_name_map_by_index_names(index_names) + + # Should return index_names as fallback for all + expected = { + "nonexistent1": "nonexistent1", + "nonexistent2": "nonexistent2", + } + assert result == expected + + +def test_get_knowledge_name_map_by_index_names_exception(monkeypatch, mock_session): + """Test exception during get_knowledge_name_map_by_index_names""" + session, query = mock_session + query.filter.side_effect = MockSQLAlchemyError("Database error") + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + + def mock_exit(exc_type, exc_val, exc_tb): + if exc_type is not None: + session.rollback() + return None + mock_ctx.__exit__.side_effect = mock_exit + monkeypatch.setattr( + "backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + with pytest.raises(MockSQLAlchemyError, match="Database error"): + get_knowledge_name_map_by_index_names(["index1", "index2"]) diff --git a/test/backend/services/test_tool_configuration_service.py b/test/backend/services/test_tool_configuration_service.py index 7dedc9dba..9b9b5f485 100644 --- a/test/backend/services/test_tool_configuration_service.py +++ b/test/backend/services/test_tool_configuration_service.py @@ -269,6 +269,12 @@ def validate(self): sys.modules['nexent.storage.storage_client_factory'] = storage_factory_module sys.modules['nexent.storage.minio_config'] = storage_config_module +# Mock nexent.memory module to break import chain before loading backend modules +memory_service_module = types.ModuleType('nexent.memory.memory_service') +memory_service_module.clear_memory = MagicMock() +sys.modules['nexent.memory'] = _create_package_mock('nexent.memory') +sys.modules['nexent.memory.memory_service'] = memory_service_module + # Load actual backend modules so that patch targets resolve correctly import importlib # noqa: E402 backend_module = importlib.import_module('backend') @@ -321,6 +327,7 @@ def validate(self): patch('services.tenant_config_service.build_knowledge_name_mapping', MagicMock()).start() patch('services.image_service.get_vlm_model', MagicMock()).start() +patch('backend.database.knowledge_db.get_knowledge_name_map_by_index_names', MagicMock()).start() # Import consts after patching dependencies from consts.model import ToolInfo, ToolSourceEnum, ToolInstanceInfoRequest, ToolValidateRequest # noqa: E402 @@ -2195,12 +2202,13 @@ async def test_validate_tool_langchain_tool_not_found(self, mock_validate_tool_i class TestValidateLocalToolKnowledgeBaseSearch: """Test cases for _validate_local_tool function with knowledge_base_search tool""" + @patch('backend.services.tool_configuration_service.get_knowledge_name_map_by_index_names') @patch('backend.services.tool_configuration_service._get_tool_class_by_name') @patch('backend.services.tool_configuration_service.inspect.signature') @patch('backend.services.tool_configuration_service.get_embedding_model') @patch('backend.services.tool_configuration_service.get_vector_db_core') def test_validate_local_tool_knowledge_base_search_success(self, mock_get_vector_db_core, mock_get_embedding_model, - mock_signature, mock_get_class): + mock_signature, mock_get_class, mock_get_knowledge_map): """Test successful knowledge_base_search tool validation with proper dependencies""" # Mock tool class mock_tool_class = Mock() @@ -2228,6 +2236,9 @@ def test_validate_local_tool_knowledge_base_search_success(self, mock_get_vector mock_vdb_core = Mock() mock_get_vector_db_core.return_value = mock_vdb_core + # Mock knowledge name map to return empty dict for this test + mock_get_knowledge_map.return_value = {} + from backend.services.tool_configuration_service import _validate_local_tool result = _validate_local_tool( @@ -2248,6 +2259,7 @@ def test_validate_local_tool_knowledge_base_search_success(self, mock_get_vector "vdb_core": mock_vdb_core, "embedding_model": "mock_embedding_model", "rerank_model": None, + "display_name_to_index_map": {}, } mock_tool_class.assert_called_once_with(**expected_params) mock_tool_instance.forward.assert_called_once_with(query="test query") @@ -2255,6 +2267,61 @@ def test_validate_local_tool_knowledge_base_search_success(self, mock_get_vector # Verify service calls mock_get_embedding_model.assert_called_once_with(tenant_id="tenant1") + @patch('backend.services.tool_configuration_service.get_knowledge_name_map_by_index_names') + @patch('backend.services.tool_configuration_service._get_tool_class_by_name') + @patch('backend.services.tool_configuration_service.get_embedding_model') + @patch('backend.services.tool_configuration_service.get_vector_db_core') + def test_validate_local_tool_knowledge_base_search_with_display_name_mapping( + self, mock_get_vector_db_core, mock_get_embedding_model, mock_get_class, mock_get_knowledge_map): + """Test knowledge_base_search tool with display_name_to_index_map parameter""" + mock_tool_class = Mock() + mock_tool_instance = Mock() + mock_tool_instance.forward.return_value = "mapped knowledge result" + mock_tool_class.return_value = mock_tool_instance + mock_get_class.return_value = mock_tool_class + + mock_get_embedding_model.return_value = "mock_embedding_model" + mock_vdb_core = Mock() + mock_get_vector_db_core.return_value = mock_vdb_core + + # Mock the knowledge name map for display_name to index_name mapping + mock_get_knowledge_map.return_value = { + "test_index_1": "Display Knowledge 1", + "test_index_2": "Display Knowledge 2" + } + + from backend.services.tool_configuration_service import _validate_local_tool + + result = _validate_local_tool( + "knowledge_base_search", + {"query": "test query"}, + {"index_names": ["test_index_1", "test_index_2"]}, + "tenant1", + "user1" + ) + + assert result == "mapped knowledge result" + + # Verify tool class was called exactly once + assert mock_tool_class.call_count == 1, f"Expected 1 call, got {mock_tool_class.call_count}" + + # Get the actual call arguments + actual_call = mock_tool_class.call_args + actual_kwargs = actual_call.kwargs if actual_call.kwargs else actual_call[1] + + # Verify each expected parameter + assert actual_kwargs.get("index_names") == ["test_index_1", "test_index_2"] + assert actual_kwargs.get("vdb_core") == mock_vdb_core + assert actual_kwargs.get("embedding_model") == "mock_embedding_model" + assert actual_kwargs.get("rerank_model") is None + assert actual_kwargs.get("display_name_to_index_map") == { + "Display Knowledge 1": "test_index_1", + "Display Knowledge 2": "test_index_2" + } + + # Verify knowledge name map was called with index_names + mock_get_knowledge_map.assert_called_once_with(["test_index_1", "test_index_2"]) + @patch('backend.services.tool_configuration_service._get_tool_class_by_name') @patch('backend.services.tool_configuration_service.get_embedding_model') @patch('backend.services.tool_configuration_service.get_vector_db_core') @@ -2339,6 +2406,7 @@ def test_validate_local_tool_knowledge_base_search_missing_both_ids(self, mock_g assert result == "knowledge base search result" + @patch('backend.services.tool_configuration_service.get_knowledge_name_map_by_index_names') @patch('backend.services.tool_configuration_service._get_tool_class_by_name') @patch('backend.services.tool_configuration_service.inspect.signature') @patch('backend.services.tool_configuration_service.get_embedding_model') @@ -2346,7 +2414,8 @@ def test_validate_local_tool_knowledge_base_search_missing_both_ids(self, mock_g def test_validate_local_tool_knowledge_base_search_empty_knowledge_list(self, mock_get_vector_db_core, mock_get_embedding_model, mock_signature, - mock_get_class): + mock_get_class, + mock_get_knowledge_map): """Test knowledge_base_search tool validation with empty knowledge list""" # Mock tool class mock_tool_class = Mock() @@ -2392,11 +2461,13 @@ def test_validate_local_tool_knowledge_base_search_empty_knowledge_list(self, mo "vdb_core": mock_vdb_core, "embedding_model": "mock_embedding_model", "rerank_model": None, + "display_name_to_index_map": {}, } mock_tool_class.assert_called_once_with(**expected_params) mock_tool_instance.forward.assert_called_once_with(query="test query") + @patch('backend.services.tool_configuration_service.get_knowledge_name_map_by_index_names') @patch('backend.services.tool_configuration_service._get_tool_class_by_name') @patch('backend.services.tool_configuration_service.inspect.signature') @patch('backend.services.tool_configuration_service.get_embedding_model') @@ -2404,7 +2475,8 @@ def test_validate_local_tool_knowledge_base_search_empty_knowledge_list(self, mo def test_validate_local_tool_knowledge_base_search_execution_error(self, mock_get_vector_db_core, mock_get_embedding_model, mock_signature, - mock_get_class): + mock_get_class, + mock_get_knowledge_map): """Test knowledge_base_search tool validation when execution fails""" # Mock tool class mock_tool_class = Mock() diff --git a/test/sdk/core/agents/test_nexent_agent.py b/test/sdk/core/agents/test_nexent_agent.py index 474fa8baa..d198fb067 100644 --- a/test/sdk/core/agents/test_nexent_agent.py +++ b/test/sdk/core/agents/test_nexent_agent.py @@ -880,6 +880,122 @@ def test_create_local_tool_knowledge_base_search_tool_with_none_defaults(nexent_ assert result == mock_kb_tool_instance +def test_create_local_tool_knowledge_base_with_display_name_map(nexent_agent_instance): + """Test KnowledgeBaseSearchTool creation sets display_name_to_index_map from metadata.""" + mock_kb_tool_class = MagicMock() + mock_kb_tool_instance = MagicMock() + mock_kb_tool_class.return_value = mock_kb_tool_instance + + display_name_map = { + "Knowledge A": "es_index_knowledge_a", + "Knowledge B": "es_index_knowledge_b", + } + + tool_config = ToolConfig( + class_name="KnowledgeBaseSearchTool", + name="knowledge_base_search", + description="desc", + inputs="{}", + output_type="string", + params={"top_k": 10}, + source="local", + metadata={ + "vdb_core": "mock_vdb_core", + "embedding_model": "mock_embedding_model", + "rerank_model": "mock_rerank_model", + "display_name_to_index_map": display_name_map, + }, + ) + + original_value = nexent_agent.__dict__.get("KnowledgeBaseSearchTool") + nexent_agent.__dict__["KnowledgeBaseSearchTool"] = mock_kb_tool_class + + try: + result = nexent_agent_instance.create_local_tool(tool_config) + finally: + if original_value is not None: + nexent_agent.__dict__["KnowledgeBaseSearchTool"] = original_value + elif "KnowledgeBaseSearchTool" in nexent_agent.__dict__: + del nexent_agent.__dict__["KnowledgeBaseSearchTool"] + + # Verify display_name_to_index_map was set correctly from metadata + assert result.display_name_to_index_map == display_name_map + assert result.vdb_core == "mock_vdb_core" + assert result.embedding_model == "mock_embedding_model" + assert result.rerank_model == "mock_rerank_model" + + +def test_create_local_tool_knowledge_base_with_empty_display_name_map(nexent_agent_instance): + """Test KnowledgeBaseSearchTool creation handles empty display_name_to_index_map.""" + mock_kb_tool_class = MagicMock() + mock_kb_tool_instance = MagicMock() + mock_kb_tool_class.return_value = mock_kb_tool_instance + + tool_config = ToolConfig( + class_name="KnowledgeBaseSearchTool", + name="knowledge_base_search", + description="desc", + inputs="{}", + output_type="string", + params={"top_k": 10}, + source="local", + metadata={ + "vdb_core": "mock_vdb_core", + "embedding_model": "mock_embedding_model", + "display_name_to_index_map": {}, + }, + ) + + original_value = nexent_agent.__dict__.get("KnowledgeBaseSearchTool") + nexent_agent.__dict__["KnowledgeBaseSearchTool"] = mock_kb_tool_class + + try: + result = nexent_agent_instance.create_local_tool(tool_config) + finally: + if original_value is not None: + nexent_agent.__dict__["KnowledgeBaseSearchTool"] = original_value + elif "KnowledgeBaseSearchTool" in nexent_agent.__dict__: + del nexent_agent.__dict__["KnowledgeBaseSearchTool"] + + # Verify empty display_name_to_index_map was set + assert result.display_name_to_index_map == {} + + +def test_create_local_tool_knowledge_base_without_metadata(nexent_agent_instance): + """Test KnowledgeBaseSearchTool creation handles missing metadata.""" + mock_kb_tool_class = MagicMock() + mock_kb_tool_instance = MagicMock() + mock_kb_tool_class.return_value = mock_kb_tool_instance + + tool_config = ToolConfig( + class_name="KnowledgeBaseSearchTool", + name="knowledge_base_search", + description="desc", + inputs="{}", + output_type="string", + params={"top_k": 10}, + source="local", + metadata=None, + ) + + original_value = nexent_agent.__dict__.get("KnowledgeBaseSearchTool") + nexent_agent.__dict__["KnowledgeBaseSearchTool"] = mock_kb_tool_class + + try: + result = nexent_agent_instance.create_local_tool(tool_config) + finally: + if original_value is not None: + nexent_agent.__dict__["KnowledgeBaseSearchTool"] = original_value + elif "KnowledgeBaseSearchTool" in nexent_agent.__dict__: + del nexent_agent.__dict__["KnowledgeBaseSearchTool"] + + # Verify defaults were set when metadata is None + assert result.display_name_to_index_map == {} + assert result.vdb_core is None + assert result.embedding_model is None + assert result.rerank_model is None + + def test_create_local_tool_analyze_text_file_tool(nexent_agent_instance): """Test AnalyzeTextFileTool creation injects observer and metadata.""" mock_analyze_tool_class = MagicMock() diff --git a/test/sdk/core/tools/test_knowledge_base_search_tool.py b/test/sdk/core/tools/test_knowledge_base_search_tool.py index ad6c7987b..bcfeaddc4 100644 --- a/test/sdk/core/tools/test_knowledge_base_search_tool.py +++ b/test/sdk/core/tools/test_knowledge_base_search_tool.py @@ -40,6 +40,7 @@ def knowledge_base_search_tool(mock_observer, mock_vdb_core, mock_embedding_mode vdb_core=mock_vdb_core, search_mode="hybrid", rerank=False, + display_name_to_index_map={}, ) return tool @@ -395,6 +396,7 @@ def test_forward_with_rerank_enabled(self, mock_observer, mock_vdb_core, mock_em vdb_core=mock_vdb_core, embedding_model=mock_embedding_model, observer=mock_observer, + display_name_to_index_map={}, ) result = tool.forward("test query") @@ -433,6 +435,7 @@ def test_forward_rerank_disabled(self, mock_observer, mock_vdb_core, mock_embedd vdb_core=mock_vdb_core, embedding_model=mock_embedding_model, observer=mock_observer, + display_name_to_index_map={}, ) result = tool.forward("test query") @@ -472,6 +475,7 @@ def test_forward_rerank_error_continues(self, mock_observer, mock_vdb_core, mock vdb_core=mock_vdb_core, embedding_model=mock_embedding_model, observer=mock_observer, + display_name_to_index_map={}, ) # Should not raise, should continue with original results @@ -536,3 +540,679 @@ def test_forward_with_whitespace_in_index_names(self, knowledge_base_search_tool embedding_model=knowledge_base_search_tool.embedding_model, top_k=5 ) + + +class TestConvertToIndexNames: + """Tests for _convert_to_index_names method.""" + + def test_convert_with_empty_map(self, mock_observer, mock_vdb_core, mock_embedding_model): + """Test conversion when display_name_to_index_map is empty.""" + tool = KnowledgeBaseSearchTool( + index_names=["index1", "index2"], + search_mode="hybrid", + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + observer=mock_observer, + display_name_to_index_map={}, + ) + + result = tool._convert_to_index_names(["index1", "index2"]) + + assert result == ["index1", "index2"] + + def test_convert_with_matching_names(self, mock_observer, mock_vdb_core, mock_embedding_model): + """Test conversion when names are in the map.""" + tool = KnowledgeBaseSearchTool( + index_names=[], + search_mode="hybrid", + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + observer=mock_observer, + display_name_to_index_map={ + "Knowledge A": "es_index_knowledge_a", + "Knowledge B": "es_index_knowledge_b", + }, + ) + + result = tool._convert_to_index_names(["Knowledge A", "Knowledge B"]) + + assert result == ["es_index_knowledge_a", "es_index_knowledge_b"] + + def test_convert_with_mixed_names(self, mock_observer, mock_vdb_core, mock_embedding_model): + """Test conversion when some names are in the map and some are not.""" + tool = KnowledgeBaseSearchTool( + index_names=[], + search_mode="hybrid", + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + observer=mock_observer, + display_name_to_index_map={ + "Knowledge A": "es_index_knowledge_a", + }, + ) + + result = tool._convert_to_index_names(["Knowledge A", "raw_index_name"]) + + assert result == ["es_index_knowledge_a", "raw_index_name"] + + def test_convert_with_unmatched_names(self, mock_observer, mock_vdb_core, mock_embedding_model): + """Test conversion when no names are in the map.""" + tool = KnowledgeBaseSearchTool( + index_names=[], + search_mode="hybrid", + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + observer=mock_observer, + display_name_to_index_map={ + "Knowledge A": "es_index_knowledge_a", + }, + ) + + result = tool._convert_to_index_names(["raw_index1", "raw_index2"]) + + assert result == ["raw_index1", "raw_index2"] + + def test_convert_forward_integration(self, mock_observer, mock_vdb_core, mock_embedding_model): + """Test that forward method uses _convert_to_index_names correctly.""" + mock_results = create_mock_search_result(1) + mock_vdb_core.hybrid_search.return_value = mock_results + + tool = KnowledgeBaseSearchTool( + index_names=[], + search_mode="hybrid", + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + observer=mock_observer, + display_name_to_index_map={ + "Knowledge A": "es_index_knowledge_a", + }, + ) + + tool.forward("test query", index_names=["Knowledge A"]) + + mock_vdb_core.hybrid_search.assert_called_once_with( + index_names=["es_index_knowledge_a"], + query_text="test query", + embedding_model=mock_embedding_model, + top_k=3 + ) + + +class TestEffectiveTopK: + """Tests for effective_top_k calculation with rerank.""" + + def test_effective_top_k_increases_with_rerank(self, mock_observer, mock_vdb_core, mock_embedding_model): + """Test that effective_top_k is multiplied when rerank is enabled.""" + from sdk.nexent.core.utils.constants import RERANK_OVERSEARCH_MULTIPLIER + + mock_results = create_mock_search_result(10) + mock_vdb_core.hybrid_search.return_value = mock_results + + tool = KnowledgeBaseSearchTool( + index_names=["kb1"], + search_mode="hybrid", + top_k=5, + rerank=True, + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + observer=mock_observer, + display_name_to_index_map={}, + ) + + tool.forward("test query") + + call_kwargs = mock_vdb_core.hybrid_search.call_args[1] + assert call_kwargs["top_k"] == 5 * RERANK_OVERSEARCH_MULTIPLIER + + def test_effective_top_k_unchanged_without_rerank(self, mock_observer, mock_vdb_core, mock_embedding_model): + """Test that effective_top_k remains the same when rerank is disabled.""" + mock_results = create_mock_search_result(5) + mock_vdb_core.hybrid_search.return_value = mock_results + + tool = KnowledgeBaseSearchTool( + index_names=["kb1"], + search_mode="hybrid", + top_k=5, + rerank=False, + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + observer=mock_observer, + display_name_to_index_map={}, + ) + + tool.forward("test query") + + call_kwargs = mock_vdb_core.hybrid_search.call_args[1] + assert call_kwargs["top_k"] == 5 + + +class TestSourceTypeConversion: + """Tests for source_type conversion (local/minio -> file).""" + + def test_source_type_local_converted_to_file(self, knowledge_base_search_tool, mock_vdb_core): + """Test that source_type 'local' is converted to 'file'.""" + mock_results = [ + { + "document": { + "title": "Local Doc", + "content": "Content from local file", + "filename": "local.txt", + "path_or_url": "/path/local.txt", + "create_time": "2024-01-01T12:00:00Z", + "source_type": "local" + }, + "score": 0.9, + "index": "kb1" + } + ] + mock_vdb_core.hybrid_search.return_value = mock_results + knowledge_base_search_tool.vdb_core = mock_vdb_core + + knowledge_base_search_tool.forward("test query", index_names=["kb1"]) + + # Check the SEARCH_CONTENT message which contains full results via to_dict() + search_content_call = [ + call for call in knowledge_base_search_tool.observer.add_message.call_args_list + if call[0][1] == ProcessType.SEARCH_CONTENT + ][0] + full_results = json.loads(search_content_call[0][2]) + + assert full_results[0]["source_type"] == "file" + + def test_source_type_minio_converted_to_file(self, knowledge_base_search_tool, mock_vdb_core): + """Test that source_type 'minio' is converted to 'file'.""" + mock_results = [ + { + "document": { + "title": "Minio Doc", + "content": "Content from minio storage", + "filename": "minio.txt", + "path_or_url": "/minio/bucket/minio.txt", + "create_time": "2024-01-01T12:00:00Z", + "source_type": "minio" + }, + "score": 0.9, + "index": "kb1" + } + ] + mock_vdb_core.hybrid_search.return_value = mock_results + knowledge_base_search_tool.vdb_core = mock_vdb_core + + knowledge_base_search_tool.forward("test query", index_names=["kb1"]) + + # Check the SEARCH_CONTENT message + search_content_call = [ + call for call in knowledge_base_search_tool.observer.add_message.call_args_list + if call[0][1] == ProcessType.SEARCH_CONTENT + ][0] + full_results = json.loads(search_content_call[0][2]) + + assert full_results[0]["source_type"] == "file" + + def test_source_type_other_unchanged(self, knowledge_base_search_tool, mock_vdb_core): + """Test that source_type other than local/minio remains unchanged.""" + mock_results = [ + { + "document": { + "title": "Web Doc", + "content": "Content from web", + "filename": "web.html", + "path_or_url": "https://example.com/page.html", + "create_time": "2024-01-01T12:00:00Z", + "source_type": "web" + }, + "score": 0.9, + "index": "kb1" + } + ] + mock_vdb_core.hybrid_search.return_value = mock_results + knowledge_base_search_tool.vdb_core = mock_vdb_core + + knowledge_base_search_tool.forward("test query", index_names=["kb1"]) + + # Check the SEARCH_CONTENT message + search_content_call = [ + call for call in knowledge_base_search_tool.observer.add_message.call_args_list + if call[0][1] == ProcessType.SEARCH_CONTENT + ][0] + full_results = json.loads(search_content_call[0][2]) + + assert full_results[0]["source_type"] == "web" + + +class TestRecordOps: + """Tests for record_ops counter functionality.""" + + def test_record_ops_increments_by_result_count(self, knowledge_base_search_tool): + """Test that record_ops increases by the number of results returned.""" + mock_results = create_mock_search_result(2) + knowledge_base_search_tool.vdb_core.hybrid_search.return_value = mock_results + + initial_ops = knowledge_base_search_tool.record_ops + + knowledge_base_search_tool.forward("test query", index_names=["kb1"]) + + assert knowledge_base_search_tool.record_ops == initial_ops + 2 + + def test_record_ops_accumulates_across_calls(self, knowledge_base_search_tool): + """Test that record_ops accumulates across multiple forward calls.""" + mock_results = create_mock_search_result(1) + knowledge_base_search_tool.vdb_core.hybrid_search.return_value = mock_results + + knowledge_base_search_tool.record_ops = 0 + knowledge_base_search_tool.forward("query1", index_names=["kb1"]) + first_call_ops = knowledge_base_search_tool.record_ops + + knowledge_base_search_tool.forward("query2", index_names=["kb1"]) + second_call_ops = knowledge_base_search_tool.record_ops + + # Each call with 1 result adds 1 to record_ops + assert first_call_ops == 1 + assert second_call_ops == 2 + + def test_cite_index_in_results(self, knowledge_base_search_tool): + """Test that cite_index in results starts from record_ops + index + 1.""" + mock_results = create_mock_search_result(2) + knowledge_base_search_tool.vdb_core.hybrid_search.return_value = mock_results + + # record_ops starts at 1, so cite_index should be 1+0+1=1, 1+1+1=2 + knowledge_base_search_tool.forward("test query", index_names=["kb1"]) + + # Check the SEARCH_CONTENT message for cite_index values + search_content_call = [ + call for call in knowledge_base_search_tool.observer.add_message.call_args_list + if call[0][1] == ProcessType.SEARCH_CONTENT + ][0] + full_results = json.loads(search_content_call[0][2]) + + assert full_results[0]["cite_index"] == 1 + assert full_results[1]["cite_index"] == 2 + + +class TestSearchContentObserver: + """Tests for SEARCH_CONTENT observer message.""" + + def test_forward_sends_search_content_to_observer(self, knowledge_base_search_tool): + """Test that forward sends SEARCH_CONTENT message to observer.""" + mock_results = create_mock_search_result(1) + knowledge_base_search_tool.vdb_core.hybrid_search.return_value = mock_results + + knowledge_base_search_tool.forward("test query", index_names=["kb1"]) + + search_content_calls = [ + call for call in knowledge_base_search_tool.observer.add_message.call_args_list + if call[0][1] == ProcessType.SEARCH_CONTENT + ] + + assert len(search_content_calls) == 1 + message = search_content_calls[0][0][2] + parsed = json.loads(message) + assert isinstance(parsed, list) + assert len(parsed) == 1 + + def test_forward_no_search_content_without_observer(self, mock_vdb_core, mock_embedding_model): + """Test that forward works without observer and doesn't send SEARCH_CONTENT.""" + mock_results = create_mock_search_result(1) + mock_vdb_core.hybrid_search.return_value = mock_results + + tool = KnowledgeBaseSearchTool( + index_names=["kb1"], + search_mode="hybrid", + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + observer=None, + display_name_to_index_map={}, + ) + + result = tool.forward("test query") + + assert result is not None + + +class TestToolMetadata: + """Tests for tool metadata attributes.""" + + def test_tool_name(self, knowledge_base_search_tool): + """Test tool name is correctly set.""" + assert knowledge_base_search_tool.name == "knowledge_base_search" + + def test_tool_category(self, knowledge_base_search_tool): + """Test tool category is SEARCH.""" + from sdk.nexent.core.utils.tools_common_message import ToolCategory + assert knowledge_base_search_tool.category == ToolCategory.SEARCH.value + + def test_tool_sign(self, knowledge_base_search_tool): + """Test tool_sign is KNOWLEDGE_BASE.""" + from sdk.nexent.core.utils.tools_common_message import ToolSign + assert knowledge_base_search_tool.tool_sign == ToolSign.KNOWLEDGE_BASE.value + + def test_output_type(self, knowledge_base_search_tool): + """Test output_type is string.""" + assert knowledge_base_search_tool.output_type == "string" + + def test_inputs_contain_required_fields(self): + """Test that inputs dict contains required fields.""" + assert "query" in KnowledgeBaseSearchTool.inputs + assert "index_names" in KnowledgeBaseSearchTool.inputs + assert KnowledgeBaseSearchTool.inputs["query"]["type"] == "string" + assert KnowledgeBaseSearchTool.inputs["index_names"]["type"] == "array" + + def test_running_prompts(self, knowledge_base_search_tool): + """Test running prompts for both languages.""" + assert knowledge_base_search_tool.running_prompt_zh == "知识库检索中..." + assert knowledge_base_search_tool.running_prompt_en == "Searching the knowledge base..." + + +class TestEdgeCases: + """Tests for edge cases and boundary conditions.""" + + def test_forward_with_score_details(self, knowledge_base_search_tool, mock_vdb_core): + """Test forward includes score_details in results via SEARCH_CONTENT.""" + mock_results = [ + { + "document": { + "title": "Doc", + "content": "Content", + "filename": "doc.txt", + "path_or_url": "/path/doc.txt", + "create_time": "2024-01-01T12:00:00Z", + "source_type": "file", + "score_details": {"bm25": 0.5, "knn": 0.4} + }, + "score": 0.9, + "index": "kb1" + } + ] + mock_vdb_core.hybrid_search.return_value = mock_results + knowledge_base_search_tool.vdb_core = mock_vdb_core + + knowledge_base_search_tool.forward("test query", index_names=["kb1"]) + + # Check the SEARCH_CONTENT message which contains full results via to_dict() + search_content_call = [ + call for call in knowledge_base_search_tool.observer.add_message.call_args_list + if call[0][1] == ProcessType.SEARCH_CONTENT + ][0] + full_results = json.loads(search_content_call[0][2]) + + assert "score_details" in full_results[0] + assert full_results[0]["score_details"]["bm25"] == 0.5 + + def test_forward_with_empty_content(self, knowledge_base_search_tool, mock_vdb_core): + """Test forward handles empty content gracefully.""" + mock_results = [ + { + "document": { + "title": "Doc with no content", + "content": "", + "filename": "empty.txt", + "path_or_url": "/path/empty.txt", + "create_time": "2024-01-01T12:00:00Z", + "source_type": "file" + }, + "score": 0.5, + "index": "kb1" + } + ] + mock_vdb_core.hybrid_search.return_value = mock_results + knowledge_base_search_tool.vdb_core = mock_vdb_core + + result = knowledge_base_search_tool.forward("test query", index_names=["kb1"]) + search_results = json.loads(result) + + assert search_results[0]["text"] == "" + + def test_forward_multiple_indices(self, knowledge_base_search_tool, mock_vdb_core): + """Test forward searches across multiple indices.""" + mock_results = [ + { + "document": { + "title": "Doc from index1", + "content": "Content", + "filename": "doc1.txt", + "path_or_url": "/path/doc1.txt", + "create_time": "2024-01-01T12:00:00Z", + "source_type": "file", + }, + "score": 0.9, + "index": "index1" + }, + { + "document": { + "title": "Doc from index2", + "content": "Content", + "filename": "doc2.txt", + "path_or_url": "/path/doc2.txt", + "create_time": "2024-01-01T12:00:00Z", + "source_type": "file", + }, + "score": 0.8, + "index": "index2" + } + ] + mock_vdb_core.hybrid_search.return_value = mock_results + knowledge_base_search_tool.vdb_core = mock_vdb_core + + result = knowledge_base_search_tool.forward("test query", index_names=["index1", "index2"]) + search_results = json.loads(result) + + assert len(search_results) == 2 + + def test_rerank_trims_to_top_k(self, mock_observer, mock_vdb_core, mock_embedding_model): + """Test that rerank results are trimmed to original top_k.""" + mock_results = create_mock_search_result(10) + mock_vdb_core.hybrid_search.return_value = mock_results + + mock_rerank_model = MagicMock() + mock_rerank_model.rerank.return_value = [ + {"index": i, "relevance_score": 0.9 - i * 0.05} + for i in range(10) + ] + + tool = KnowledgeBaseSearchTool( + index_names=["kb1"], + search_mode="hybrid", + top_k=3, + rerank=True, + rerank_model=mock_rerank_model, + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + observer=mock_observer, + display_name_to_index_map={}, + ) + + result = tool.forward("test query") + search_results = json.loads(result) + + assert len(search_results) == 3 + + +class TestFieldInfoDefaultFactory: + """Tests for FieldInfo default_factory handling. + + smolagents Tool may not properly expand Field defaults, so the code + handles FieldInfo objects with both .default and .default_factory attributes. + These tests verify the correct handling of both cases. + """ + + def test_convert_to_index_names_with_fieldinfo_default_factory(self, mock_observer, mock_vdb_core, mock_embedding_model): + """Test _convert_to_index_names handles FieldInfo with default_factory correctly.""" + try: + from pydantic import FieldInfo + except ImportError: + from pydantic.fields import FieldInfo + + # Create a FieldInfo with default_factory only (Pydantic doesn't allow both) + field_info_with_factory = FieldInfo( + default_factory=lambda: {"Knowledge X": "es_index_x", "Knowledge Y": "es_index_y"} + ) + + tool = KnowledgeBaseSearchTool( + index_names=[], + search_mode="hybrid", + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + observer=mock_observer, + display_name_to_index_map=field_info_with_factory, + ) + + result = tool._convert_to_index_names(["Knowledge X", "Knowledge Y"]) + + # Should convert using the factory result + assert result == ["es_index_x", "es_index_y"] + + def test_convert_to_index_names_with_fieldinfo_default_only(self, mock_observer, mock_vdb_core, mock_embedding_model): + """Test _convert_to_index_names handles FieldInfo with only default correctly.""" + try: + from pydantic import FieldInfo + except ImportError: + from pydantic.fields import FieldInfo + + # Create a FieldInfo with default only (no factory) + field_info_with_default = FieldInfo( + default={"Knowledge A": "es_index_a"} + ) + + tool = KnowledgeBaseSearchTool( + index_names=[], + search_mode="hybrid", + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + observer=mock_observer, + display_name_to_index_map=field_info_with_default, + ) + + result = tool._convert_to_index_names(["Knowledge A"]) + + # Should convert using the default value + assert result == ["es_index_a"] + + def test_forward_with_fieldinfo_top_k_default_factory(self, mock_observer, mock_vdb_core, mock_embedding_model): + """Test forward handles FieldInfo top_k with default_factory correctly.""" + try: + from pydantic import FieldInfo + except ImportError: + from pydantic.fields import FieldInfo + + mock_results = create_mock_search_result(3) + mock_vdb_core.hybrid_search.return_value = mock_results + + # Create FieldInfo with default_factory only (Pydantic doesn't allow both) + field_info_top_k = FieldInfo( + default_factory=lambda: 5 + ) + + tool = KnowledgeBaseSearchTool( + index_names=["kb1"], + search_mode="hybrid", + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + observer=mock_observer, + display_name_to_index_map={}, + ) + # Override top_k with FieldInfo + tool.top_k = field_info_top_k + + result = tool.forward("test query") + + # Should use the factory result (5) for top_k + call_kwargs = mock_vdb_core.hybrid_search.call_args[1] + assert call_kwargs["top_k"] == 5 + + def test_forward_with_fieldinfo_rerank_default_factory(self, mock_observer, mock_vdb_core, mock_embedding_model): + """Test forward handles FieldInfo rerank with default_factory correctly.""" + try: + from pydantic import FieldInfo + except ImportError: + from pydantic.fields import FieldInfo + + mock_results = create_mock_search_result(10) + mock_vdb_core.hybrid_search.return_value = mock_results + + # Create FieldInfo with default_factory only (Pydantic doesn't allow both) + field_info_rerank = FieldInfo( + default_factory=lambda: True + ) + + tool = KnowledgeBaseSearchTool( + index_names=["kb1"], + search_mode="hybrid", + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + observer=mock_observer, + display_name_to_index_map={}, + ) + # Override rerank with FieldInfo + tool.rerank = field_info_rerank + + from sdk.nexent.core.utils.constants import RERANK_OVERSEARCH_MULTIPLIER + + result = tool.forward("test query") + + # Should use the factory result (True) and multiply top_k + call_kwargs = mock_vdb_core.hybrid_search.call_args[1] + # top_k from default is 3, multiplied by RERANK_OVERSEARCH_MULTIPLIER + assert call_kwargs["top_k"] == 3 * RERANK_OVERSEARCH_MULTIPLIER + + def test_forward_with_fieldinfo_top_k_default_only(self, mock_observer, mock_vdb_core, mock_embedding_model): + """Test forward handles FieldInfo top_k with only default correctly.""" + try: + from pydantic import FieldInfo + except ImportError: + from pydantic.fields import FieldInfo + + mock_results = create_mock_search_result(5) + mock_vdb_core.hybrid_search.return_value = mock_results + + # Create FieldInfo with default only (no factory) + field_info_top_k = FieldInfo(default=10) + + tool = KnowledgeBaseSearchTool( + index_names=["kb1"], + search_mode="hybrid", + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + observer=mock_observer, + display_name_to_index_map={}, + ) + # Override top_k with FieldInfo + tool.top_k = field_info_top_k + + result = tool.forward("test query") + + # Should use the default value (10) + call_kwargs = mock_vdb_core.hybrid_search.call_args[1] + assert call_kwargs["top_k"] == 10 + + def test_forward_with_fieldinfo_rerank_default_only(self, mock_observer, mock_vdb_core, mock_embedding_model): + """Test forward handles FieldInfo rerank with only default correctly.""" + try: + from pydantic import FieldInfo + except ImportError: + from pydantic.fields import FieldInfo + + mock_results = create_mock_search_result(5) + mock_vdb_core.hybrid_search.return_value = mock_results + + # Create FieldInfo with default only (no factory) + field_info_rerank = FieldInfo(default=True) + + tool = KnowledgeBaseSearchTool( + index_names=["kb1"], + search_mode="hybrid", + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + observer=mock_observer, + display_name_to_index_map={}, + ) + # Override rerank with FieldInfo + tool.rerank = field_info_rerank + + from sdk.nexent.core.utils.constants import RERANK_OVERSEARCH_MULTIPLIER + + result = tool.forward("test query") + + # Should use the default value (True) and multiply top_k + call_kwargs = mock_vdb_core.hybrid_search.call_args[1] + # top_k from default is 3, multiplied by RERANK_OVERSEARCH_MULTIPLIER + assert call_kwargs["top_k"] == 3 * RERANK_OVERSEARCH_MULTIPLIER