diff --git a/src/core/schemas.py b/src/core/schemas.py new file mode 100644 index 00000000..fd52eddd --- /dev/null +++ b/src/core/schemas.py @@ -0,0 +1,53 @@ +from typing import List, Dict, Any, Optional +from pydantic import BaseModel, Field, HttpUrl + +# --- Core Data Models for the Financial Research Agent System --- +# This module implements the "pure Pydantic implementation for multi-agent systems" +# required to standardize data flow and enhance commercial reliability. + +class ResearchGoal(BaseModel): + """ + Defines the structured input request received from the Gradio UI or the MCP client. + This serves as the initial contract for the entire agent workflow. + """ + query: str = Field(..., description="The high-level financial research question or task.") + investment_target: Optional[str] = Field(None, description="Specific company, sector, or asset under investigation.") + time_horizon: str = Field("Next 12 months", description="The required time frame for the analysis (e.g., '6 months', 'long-term').") + required_format: str = Field("Comprehensive Report", description="The desired output format (e.g., 'Summary', 'Detailed Analysis', 'Presentation Slides').") + +class ToolUsage(BaseModel): + """ + Details of a specific tool utilized during a research step (e.g., using a data acquisition tool like Akshare or Baostock, which were relevant in the ModelScope context). + """ + tool_name: str = Field(..., description="The name of the external tool or function used.") + arguments: Dict[str, Any] = Field(..., description="The arguments passed to the tool.") + result_summary: str = Field(..., description="A summary of the information retrieved or action taken by the tool.") + +class AnalysisStep(BaseModel): + """ + Represents an intermediate step in the multi-agent research process (the iterative search-and-judge loops). + """ + agent_id: str = Field(..., description="Identifier of the agent responsible for this step.") + action: str = Field(..., description="Description of the agent's action (e.g., 'Searching market data', 'Synthesizing conflicting reports').") + tools_used: List[ToolUsage] = Field(default_factory=list, description="List of specific tool calls made during this step.") + reasoning: str = Field(..., description="The rationale for the agent's action.") + +class FinancialAnalysisResult(BaseModel): + """ + Defines the final, structured output (the monetizable product) delivered by the agent system. + This structure ensures the output is professional and dependable for enterprise users. + """ + summary: str = Field(..., description="A concise executive summary of the findings.") + key_recommendation: str = Field(..., description="The primary investment or business recommendation based on the research.") + analysis_steps: List[AnalysisStep] = Field(default_factory=list, description="A verifiable trace of all steps taken by the agents.") + data_sources: List[str] = Field(default_factory=list, description="List of reliable sources used, including links or references.") + confidence_score: float = Field(..., description="A quantitative score (0.0 to 1.0) reflecting the system's confidence in the recommendation.") + +class MCPToolDefinition(BaseModel): + """ + Schema for defining a tool that is exposed via the Model Context Protocol (MCP) server endpoint. + This facilitates integration with external clients like Claude Desktop. + """ + name: str = Field(..., description="The name of the tool exposed via MCP.") + description: str = Field(..., description="A brief description of what the tool does.") + endpoint_url: HttpUrl = Field(..., description="The API endpoint URL for the tool.") diff --git a/tests/core/test_schema.py b/tests/core/test_schema.py new file mode 100644 index 00000000..222feef9 --- /dev/null +++ b/tests/core/test_schema.py @@ -0,0 +1,143 @@ +import unittest +import os +import sys +from pydantic import ValidationError + +# Add the src directory to the path so we can import the new schemas module +# Note: This is necessary because src/core is a new folder structure +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))) + +# Import the newly created Pydantic models +# These schemas are foundational for the 'pure Pydantic implementation' for multi-agent systems +from src.core.schemas import ( + ResearchGoal, + ToolUsage, + AnalysisStep, + FinancialAnalysisResult, + MCPToolDefinition +) + +class TestAgentDataSchemas(unittest.TestCase): + """ + Tests the integrity of the core Pydantic data models used for agent communication + and output validation, ensuring data reliability for enterprise use. + """ + + def test_research_goal_successful_validation(self): + """Should validate a ResearchGoal with all fields provided.""" + goal_data = { + "query": "Analyze Qwen's market position for the next quarter.", + "investment_target": "Qwen-LLM", + "time_horizon": "Q3 2025", + "required_format": "Detailed Analysis" + } + try: + ResearchGoal(**goal_data) + except ValidationError as e: + self.fail(f"ResearchGoal validation failed unexpectedly: {e}") + + def test_research_goal_missing_required_field(self): + """Should fail validation if the 'query' field is missing.""" + invalid_data = { + "investment_target": "Alibaba stock", + "time_horizon": "6 months" + } + with self.assertRaises(ValidationError) as cm: + ResearchGoal(**invalid_data) + + # Check that the error pertains to the missing 'query' field + self.assertIn('query', str(cm.exception)) + + def test_tool_usage_successful_validation(self): + """Should validate a ToolUsage instance.""" + tool_data = { + "tool_name": "AkshareDataGetter", + "arguments": {"stock_code": "600000.SH", "period": "weekly"}, + "result_summary": "Retrieved 52 weeks of trading data." + } + try: + ToolUsage(**tool_data) + except ValidationError as e: + self.fail(f"ToolUsage validation failed unexpectedly: {e}") + + def test_analysis_step_with_nested_tool_usage(self): + """Should validate an AnalysisStep that includes ToolUsage records.""" + tool_data = ToolUsage( + tool_name="BaostockAPI", + arguments={"query": "financial reports"}, + result_summary="Acquired 2024 earnings report." + ).model_dump() # Use model_dump() for Pydantic V2 + + step_data = { + "agent_id": "DataGatherer-A", + "action": "Acquiring Q4 2024 reports.", + "tools_used": [tool_data], + "reasoning": "Need current financials for valuation model." + } + + try: + AnalysisStep(**step_data) + except ValidationError as e: + self.fail(f"AnalysisStep validation failed unexpectedly: {e}") + + def test_financial_analysis_result_successful_validation(self): + """Should validate the final report structure, including float confidence score.""" + valid_result = { + "summary": "Qwen LLM market share is growing rapidly in Asia.", + "key_recommendation": "Strong Buy signal.", + "analysis_steps": [], # Optional list, can be empty + "data_sources": ["ModelScope.cn", "Official Press Release"], + "confidence_score": 0.85 + } + try: + result = FinancialAnalysisResult(**valid_result) + self.assertIsInstance(result.confidence_score, float) + except ValidationError as e: + self.fail(f"FinancialAnalysisResult validation failed unexpectedly: {e}") + + def test_financial_analysis_result_invalid_confidence_score_type(self): + """ + Should fail if confidence_score is not a valid number (e.g., a string). + (Updated assertion for Pydantic V2 error message) + """ + invalid_result = { + "summary": "Test", + "key_recommendation": "Hold", + "analysis_steps": [], + "data_sources": [], + "confidence_score": "high" # Should be float + } + with self.assertRaises(ValidationError) as cm: + FinancialAnalysisResult(**invalid_result) + # V2 error message often contains 'unable to parse string as a number' + self.assertIn('unable to parse string as a number', str(cm.exception)) + + def test_mcp_tool_definition_successful_validation(self): + """Should validate the MCP definition, especially the HttpUrl field.""" + # The application exposes an MCP server endpoint + mcp_data = { + "name": "TheDeterminatorsSearch", + "description": "Performs deep financial research.", + "endpoint_url": "http://localhost:7860/gradio_api/mcp/" + } + try: + MCPToolDefinition(**mcp_data) + except ValidationError as e: + self.fail(f"MCPToolDefinition validation failed unexpectedly: {e}") + + def test_mcp_tool_definition_invalid_url(self): + """Should fail if the endpoint_url is not a valid URL format.""" + invalid_mcp_data = { + "name": "InvalidTool", + "description": "Test", + "endpoint_url": "not a url" + } + with self.assertRaises(ValidationError) as cm: + MCPToolDefinition(**invalid_mcp_data) + + # Pydantic V2 URL validation error + self.assertIn('url_parsing', str(cm.exception)) + +if __name__ == '__main__': + unittest.main() + diff --git a/tests/integration/test_deep_research.py b/tests/integration/test_deep_research.py index b7bf173b..0ac239b2 100644 --- a/tests/integration/test_deep_research.py +++ b/tests/integration/test_deep_research.py @@ -3,7 +3,7 @@ Tests the complete deep research pattern: plan → parallel loops → synthesis. """ -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -12,13 +12,56 @@ from src.utils.models import ReportPlan, ReportPlanSection +def _create_mock_planner_agent(): + """Create a mock planner agent for testing.""" + mock_agent = MagicMock() + mock_agent.run = AsyncMock() + return mock_agent + + +def _create_mock_long_writer_agent(): + """Create a mock long writer agent for testing.""" + mock_agent = MagicMock() + mock_agent.write_report = AsyncMock() + return mock_agent + + +def _create_mock_proofreader_agent(): + """Create a mock proofreader agent for testing.""" + mock_agent = MagicMock() + mock_agent.proofread = AsyncMock() + return mock_agent + + +def _create_mock_judge_handler(): + """Create a mock judge handler for testing.""" + mock_handler = MagicMock() + mock_handler.assess = AsyncMock(return_value=MagicMock(is_sufficient=True)) + return mock_handler + + @pytest.mark.integration class TestDeepResearchFlow: """Integration tests for DeepResearchFlow.""" @pytest.mark.asyncio - async def test_deep_research_creates_plan(self) -> None: + @patch("src.orchestrator.research_flow.create_planner_agent") + @patch("src.orchestrator.research_flow.create_long_writer_agent") + @patch("src.orchestrator.research_flow.create_proofreader_agent") + @patch("src.orchestrator.research_flow.create_judge_handler") + async def test_deep_research_creates_plan( + self, + mock_judge_factory, + mock_proofreader_factory, + mock_long_writer_factory, + mock_planner_factory, + ) -> None: """Test that deep research creates a report plan.""" + mock_planner_factory.return_value = _create_mock_planner_agent() + mock_long_writer_factory.return_value = _create_mock_long_writer_agent() + mock_proofreader_factory.return_value = _create_mock_proofreader_agent() + mock_judge_factory.return_value = _create_mock_judge_handler() + # Initialize workflow state init_workflow_state() @@ -66,8 +109,36 @@ async def mock_iterative_run(query: str, **kwargs: dict) -> str: assert plan.report_outline[0].title == "Section 1" @pytest.mark.asyncio - async def test_deep_research_parallel_loops_state_synchronization(self) -> None: + @patch("src.orchestrator.research_flow.create_planner_agent") + @patch("src.orchestrator.research_flow.create_long_writer_agent") + @patch("src.orchestrator.research_flow.create_proofreader_agent") + @patch("src.orchestrator.research_flow.create_judge_handler") + @patch("src.orchestrator.research_flow.create_knowledge_gap_agent") + @patch("src.orchestrator.research_flow.create_tool_selector_agent") + @patch("src.orchestrator.research_flow.create_thinking_agent") + @patch("src.orchestrator.research_flow.create_writer_agent") + async def test_deep_research_parallel_loops_state_synchronization( + self, + mock_writer_factory, + mock_thinking_factory, + mock_tool_selector_factory, + mock_knowledge_gap_factory, + mock_judge_factory, + mock_proofreader_factory, + mock_long_writer_factory, + mock_planner_factory, + ) -> None: """Test that parallel loops properly synchronize state.""" + mock_planner_factory.return_value = _create_mock_planner_agent() + mock_long_writer_factory.return_value = _create_mock_long_writer_agent() + mock_proofreader_factory.return_value = _create_mock_proofreader_agent() + mock_judge_factory.return_value = _create_mock_judge_handler() + # Mocks for agents created by IterativeResearchFlow + mock_knowledge_gap_factory.return_value = AsyncMock() + mock_tool_selector_factory.return_value = AsyncMock() + mock_thinking_factory.return_value = AsyncMock() + mock_writer_factory.return_value = AsyncMock() + # Initialize workflow state state = init_workflow_state() @@ -120,16 +191,33 @@ async def mock_iterative_run(query: str, **kwargs: dict) -> str: # Verify parallel execution assert len(section_drafts) == 2 - assert "Question 1" in section_drafts[0] - assert "Question 2" in section_drafts[1] + # Order is not guaranteed in parallel execution, check for presence of both drafts + all_drafts = "".join(section_drafts) + assert "Question 1" in all_drafts + assert "Question 2" in all_drafts # Verify state has evidence from both sections # Note: In real execution, evidence would be synced via WorkflowManager # This test verifies the structure works @pytest.mark.asyncio - async def test_deep_research_synthesizes_final_report(self) -> None: + @patch("src.orchestrator.research_flow.create_planner_agent") + @patch("src.orchestrator.research_flow.create_long_writer_agent") + @patch("src.orchestrator.research_flow.create_proofreader_agent") + @patch("src.orchestrator.research_flow.create_judge_handler") + async def test_deep_research_synthesizes_final_report( + self, + mock_judge_factory, + mock_proofreader_factory, + mock_long_writer_factory, + mock_planner_factory, + ) -> None: """Test that deep research synthesizes final report from section drafts.""" + mock_planner_factory.return_value = _create_mock_planner_agent() + mock_long_writer_factory.return_value = _create_mock_long_writer_agent() + mock_proofreader_factory.return_value = _create_mock_proofreader_agent() + mock_judge_factory.return_value = _create_mock_judge_handler() + flow = DeepResearchFlow( max_iterations=1, max_time_minutes=2, @@ -177,8 +265,36 @@ async def test_deep_research_synthesizes_final_report(self) -> None: assert len(call_args.kwargs["report_draft"].sections) == 2 @pytest.mark.asyncio - async def test_deep_research_agent_chains_full_flow(self) -> None: + @patch("src.orchestrator.research_flow.create_planner_agent") + @patch("src.orchestrator.research_flow.create_long_writer_agent") + @patch("src.orchestrator.research_flow.create_proofreader_agent") + @patch("src.orchestrator.research_flow.create_judge_handler") + @patch("src.orchestrator.research_flow.create_knowledge_gap_agent") + @patch("src.orchestrator.research_flow.create_tool_selector_agent") + @patch("src.orchestrator.research_flow.create_thinking_agent") + @patch("src.orchestrator.research_flow.create_writer_agent") + async def test_deep_research_agent_chains_full_flow( + self, + mock_writer_factory, + mock_thinking_factory, + mock_tool_selector_factory, + mock_knowledge_gap_factory, + mock_judge_factory, + mock_proofreader_factory, + mock_long_writer_factory, + mock_planner_factory, + ) -> None: """Test full deep research flow with agent chains (mocked).""" + mock_planner_factory.return_value = _create_mock_planner_agent() + mock_long_writer_factory.return_value = _create_mock_long_writer_agent() + mock_proofreader_factory.return_value = _create_mock_proofreader_agent() + mock_judge_factory.return_value = _create_mock_judge_handler() + # Mocks for agents created by IterativeResearchFlow + mock_knowledge_gap_factory.return_value = AsyncMock() + mock_tool_selector_factory.return_value = AsyncMock() + mock_thinking_factory.return_value = AsyncMock() + mock_writer_factory.return_value = AsyncMock() + # Initialize workflow state init_workflow_state() @@ -224,8 +340,36 @@ async def mock_iterative_run(query: str, **kwargs: dict) -> str: flow.long_writer_agent.write_report.assert_called_once() @pytest.mark.asyncio - async def test_deep_research_handles_multiple_sections(self) -> None: + @patch("src.orchestrator.research_flow.create_planner_agent") + @patch("src.orchestrator.research_flow.create_long_writer_agent") + @patch("src.orchestrator.research_flow.create_proofreader_agent") + @patch("src.orchestrator.research_flow.create_judge_handler") + @patch("src.orchestrator.research_flow.create_knowledge_gap_agent") + @patch("src.orchestrator.research_flow.create_tool_selector_agent") + @patch("src.orchestrator.research_flow.create_thinking_agent") + @patch("src.orchestrator.research_flow.create_writer_agent") + async def test_deep_research_handles_multiple_sections( + self, + mock_writer_factory, + mock_thinking_factory, + mock_tool_selector_factory, + mock_knowledge_gap_factory, + mock_judge_factory, + mock_proofreader_factory, + mock_long_writer_factory, + mock_planner_factory, + ) -> None: """Test that deep research handles multiple sections correctly.""" + mock_planner_factory.return_value = _create_mock_planner_agent() + mock_long_writer_factory.return_value = _create_mock_long_writer_agent() + mock_proofreader_factory.return_value = _create_mock_proofreader_agent() + mock_judge_factory.return_value = _create_mock_judge_handler() + # Mocks for agents created by IterativeResearchFlow + mock_knowledge_gap_factory.return_value = AsyncMock() + mock_tool_selector_factory.return_value = AsyncMock() + mock_thinking_factory.return_value = AsyncMock() + mock_writer_factory.return_value = AsyncMock() + flow = DeepResearchFlow( max_iterations=1, max_time_minutes=2, @@ -263,8 +407,35 @@ async def mock_iterative_run(query: str, **kwargs: dict) -> str: assert f"Section {i}" in draft or f"section {i}" in draft.lower() @pytest.mark.asyncio - async def test_deep_research_workflow_manager_integration(self) -> None: + @patch("src.orchestrator.research_flow.create_planner_agent") + @patch("src.orchestrator.research_flow.create_long_writer_agent") + @patch("src.orchestrator.research_flow.create_proofreader_agent") + @patch("src.orchestrator.research_flow.create_judge_handler") + @patch("src.orchestrator.research_flow.create_knowledge_gap_agent") + @patch("src.orchestrator.research_flow.create_tool_selector_agent") + @patch("src.orchestrator.research_flow.create_thinking_agent") + @patch("src.orchestrator.research_flow.create_writer_agent") + async def test_deep_research_workflow_manager_integration( + self, + mock_writer_factory, + mock_thinking_factory, + mock_tool_selector_factory, + mock_knowledge_gap_factory, + mock_judge_factory, + mock_proofreader_factory, + mock_long_writer_factory, + mock_planner_factory, + ) -> None: """Test that deep research properly uses WorkflowManager.""" + mock_planner_factory.return_value = _create_mock_planner_agent() + mock_long_writer_factory.return_value = _create_mock_long_writer_agent() + mock_proofreader_factory.return_value = _create_mock_proofreader_agent() + mock_judge_factory.return_value = _create_mock_judge_handler() + # Mocks for agents created by IterativeResearchFlow + mock_knowledge_gap_factory.return_value = AsyncMock() + mock_tool_selector_factory.return_value = AsyncMock() + mock_thinking_factory.return_value = AsyncMock() + mock_writer_factory.return_value = AsyncMock() # Initialize workflow state init_workflow_state() @@ -308,8 +479,36 @@ async def mock_iterative_run(query: str, **kwargs: dict) -> str: assert all(isinstance(draft, str) for draft in section_drafts) @pytest.mark.asyncio - async def test_deep_research_state_initialization(self) -> None: + @patch("src.orchestrator.research_flow.create_planner_agent") + @patch("src.orchestrator.research_flow.create_long_writer_agent") + @patch("src.orchestrator.research_flow.create_proofreader_agent") + @patch("src.orchestrator.research_flow.create_judge_handler") + @patch("src.orchestrator.research_flow.create_knowledge_gap_agent") + @patch("src.orchestrator.research_flow.create_tool_selector_agent") + @patch("src.orchestrator.research_flow.create_thinking_agent") + @patch("src.orchestrator.research_flow.create_writer_agent") + async def test_deep_research_state_initialization( + self, + mock_writer_factory, + mock_thinking_factory, + mock_tool_selector_factory, + mock_knowledge_gap_factory, + mock_judge_factory, + mock_proofreader_factory, + mock_long_writer_factory, + mock_planner_factory, + ) -> None: """Test that deep research properly initializes workflow state.""" + mock_planner_factory.return_value = _create_mock_planner_agent() + mock_long_writer_factory.return_value = _create_mock_long_writer_agent() + mock_proofreader_factory.return_value = _create_mock_proofreader_agent() + mock_judge_factory.return_value = _create_mock_judge_handler() + # Mocks for agents created by IterativeResearchFlow + mock_knowledge_gap_factory.return_value = AsyncMock() + mock_tool_selector_factory.return_value = AsyncMock() + mock_thinking_factory.return_value = AsyncMock() + mock_writer_factory.return_value = AsyncMock() + flow = DeepResearchFlow( max_iterations=1, max_time_minutes=2, diff --git a/tests/unit/agent_factory/test_judges_factory.py b/tests/unit/agent_factory/test_judges_factory.py index 3cc7e331..e150df90 100644 --- a/tests/unit/agent_factory/test_judges_factory.py +++ b/tests/unit/agent_factory/test_judges_factory.py @@ -23,7 +23,9 @@ def mock_settings(): def test_get_model_openai(mock_settings): """Test that OpenAI model is returned when provider is openai.""" - mock_settings.llm_provider = "openai" + mock_settings.hf_token = None + mock_settings.huggingface_api_key = None + mock_settings.has_openai_key = True mock_settings.openai_api_key = "sk-test" mock_settings.openai_model = "gpt-5.1" @@ -34,7 +36,10 @@ def test_get_model_openai(mock_settings): def test_get_model_anthropic(mock_settings): """Test that Anthropic model is returned when provider is anthropic.""" - mock_settings.llm_provider = "anthropic" + mock_settings.hf_token = None + mock_settings.huggingface_api_key = None + mock_settings.has_openai_key = False + mock_settings.has_anthropic_key = True mock_settings.anthropic_api_key = "sk-ant-test" mock_settings.anthropic_model = "claude-sonnet-4-5-20250929" diff --git a/tests/unit/agents/test_input_parser.py b/tests/unit/agents/test_input_parser.py index 95324026..c7ea45db 100644 --- a/tests/unit/agents/test_input_parser.py +++ b/tests/unit/agents/test_input_parser.py @@ -75,21 +75,34 @@ def input_parser_agent(mock_model: MagicMock) -> InputParserAgent: class TestInputParserAgentInit: """Test InputParserAgent initialization.""" - def test_input_parser_agent_init_with_model(self, mock_model: MagicMock) -> None: + @patch("src.agents.input_parser.Agent") + def test_input_parser_agent_init_with_model(self, mock_agent_class: MagicMock, mock_model: MagicMock) -> None: """Test InputParserAgent initialization with provided model.""" + mock_agent_instance = MagicMock() + mock_agent_class.return_value = mock_agent_instance + agent = InputParserAgent(model=mock_model) + assert agent.model == mock_model - assert agent.agent is not None + assert agent.agent == mock_agent_instance + mock_agent_class.assert_called_once() + @patch("src.agents.input_parser.Agent") @patch("src.agents.input_parser.get_model") def test_input_parser_agent_init_without_model( - self, mock_get_model: MagicMock, mock_model: MagicMock + self, mock_get_model: MagicMock, mock_agent_class: MagicMock, mock_model: MagicMock ) -> None: """Test InputParserAgent initialization without model (uses default).""" mock_get_model.return_value = mock_model + mock_agent_instance = MagicMock() + mock_agent_class.return_value = mock_agent_instance + agent = InputParserAgent() + assert agent.model == mock_model + assert agent.agent == mock_agent_instance mock_get_model.assert_called_once() + mock_agent_class.assert_called_once() def test_input_parser_agent_has_correct_system_prompt( self, input_parser_agent: InputParserAgent @@ -254,26 +267,36 @@ async def test_parse_heuristic_iterative_mode( class TestCreateInputParserAgent: """Test create_input_parser_agent() factory function.""" + @patch("src.agents.input_parser.InputParserAgent") @patch("src.agents.input_parser.get_model") def test_create_input_parser_agent_with_model( - self, mock_get_model: MagicMock, mock_model: MagicMock + self, mock_get_model: MagicMock, mock_input_parser_agent_class: MagicMock, mock_model: MagicMock ) -> None: """Test factory function with provided model.""" + mock_agent_instance = MagicMock() + mock_input_parser_agent_class.return_value = mock_agent_instance + agent = create_input_parser_agent(model=mock_model) - assert isinstance(agent, InputParserAgent) - assert agent.model == mock_model + + assert agent == mock_agent_instance + mock_input_parser_agent_class.assert_called_once_with(model=mock_model) mock_get_model.assert_not_called() + @patch("src.agents.input_parser.InputParserAgent") @patch("src.agents.input_parser.get_model") def test_create_input_parser_agent_without_model( - self, mock_get_model: MagicMock, mock_model: MagicMock + self, mock_get_model: MagicMock, mock_input_parser_agent_class: MagicMock, mock_model: MagicMock ) -> None: """Test factory function without model (uses default).""" mock_get_model.return_value = mock_model + mock_agent_instance = MagicMock() + mock_input_parser_agent_class.return_value = mock_agent_instance + agent = create_input_parser_agent() - assert isinstance(agent, InputParserAgent) - assert agent.model == mock_model + + assert agent == mock_agent_instance mock_get_model.assert_called_once() + mock_input_parser_agent_class.assert_called_once_with(model=mock_model) @patch("src.agents.input_parser.get_model") def test_create_input_parser_agent_handles_error(self, mock_get_model: MagicMock) -> None: diff --git a/tests/unit/agents/test_long_writer.py b/tests/unit/agents/test_long_writer.py index b99b87db..fd986006 100644 --- a/tests/unit/agents/test_long_writer.py +++ b/tests/unit/agents/test_long_writer.py @@ -62,21 +62,34 @@ def sample_report_draft() -> ReportDraft: class TestLongWriterAgentInit: """Test LongWriterAgent initialization.""" - def test_long_writer_agent_init_with_model(self, mock_model: MagicMock) -> None: + @patch("src.agents.long_writer.Agent") + def test_long_writer_agent_init_with_model(self, mock_agent_class: MagicMock, mock_model: MagicMock) -> None: """Test LongWriterAgent initialization with provided model.""" + mock_agent_instance = MagicMock() + mock_agent_class.return_value = mock_agent_instance + agent = LongWriterAgent(model=mock_model) + assert agent.model == mock_model - assert agent.agent is not None + assert agent.agent == mock_agent_instance + mock_agent_class.assert_called_once() + @patch("src.agents.long_writer.Agent") @patch("src.agents.long_writer.get_model") def test_long_writer_agent_init_without_model( - self, mock_get_model: MagicMock, mock_model: MagicMock + self, mock_get_model: MagicMock, mock_agent_class: MagicMock, mock_model: MagicMock ) -> None: """Test LongWriterAgent initialization without model (uses default).""" mock_get_model.return_value = mock_model + mock_agent_instance = MagicMock() + mock_agent_class.return_value = mock_agent_instance + agent = LongWriterAgent() + assert agent.model == mock_model + assert agent.agent == mock_agent_instance mock_get_model.assert_called_once() + mock_agent_class.assert_called_once() def test_long_writer_agent_has_structured_output( self, long_writer_agent: LongWriterAgent diff --git a/tests/unit/agents/test_proofreader.py b/tests/unit/agents/test_proofreader.py index e16ad72f..2969f213 100644 --- a/tests/unit/agents/test_proofreader.py +++ b/tests/unit/agents/test_proofreader.py @@ -62,21 +62,34 @@ def sample_report_draft() -> ReportDraft: class TestProofreaderAgentInit: """Test ProofreaderAgent initialization.""" - def test_proofreader_agent_init_with_model(self, mock_model: MagicMock) -> None: + @patch("src.agents.proofreader.Agent") + def test_proofreader_agent_init_with_model(self, mock_agent_class: MagicMock, mock_model: MagicMock) -> None: """Test ProofreaderAgent initialization with provided model.""" + mock_agent_instance = MagicMock() + mock_agent_class.return_value = mock_agent_instance + agent = ProofreaderAgent(model=mock_model) + assert agent.model == mock_model - assert agent.agent is not None + assert agent.agent == mock_agent_instance + mock_agent_class.assert_called_once() + @patch("src.agents.proofreader.Agent") @patch("src.agents.proofreader.get_model") def test_proofreader_agent_init_without_model( - self, mock_get_model: MagicMock, mock_model: MagicMock + self, mock_get_model: MagicMock, mock_agent_class: MagicMock, mock_model: MagicMock ) -> None: """Test ProofreaderAgent initialization without model (uses default).""" mock_get_model.return_value = mock_model + mock_agent_instance = MagicMock() + mock_agent_class.return_value = mock_agent_instance + agent = ProofreaderAgent() + assert agent.model == mock_model + assert agent.agent == mock_agent_instance mock_get_model.assert_called_once() + mock_agent_class.assert_called_once() def test_proofreader_agent_has_correct_system_prompt( self, proofreader_agent: ProofreaderAgent diff --git a/tests/unit/agents/test_writer.py b/tests/unit/agents/test_writer.py index 766420df..f149d223 100644 --- a/tests/unit/agents/test_writer.py +++ b/tests/unit/agents/test_writer.py @@ -29,27 +29,45 @@ def mock_agent_result() -> AgentRunResult[Any]: @pytest.fixture def writer_agent(mock_model: MagicMock) -> WriterAgent: """Create a WriterAgent instance with mocked model.""" - return WriterAgent(model=mock_model) + with patch("src.agents.writer.Agent") as mock_agent_class: + mock_agent_instance = MagicMock() + # The .run method needs to be an async mock for the tests + mock_agent_instance.run = AsyncMock() + mock_agent_class.return_value = mock_agent_instance + yield WriterAgent(model=mock_model) class TestWriterAgentInit: """Test WriterAgent initialization.""" - def test_writer_agent_init_with_model(self, mock_model: MagicMock) -> None: + @patch("src.agents.writer.Agent") + def test_writer_agent_init_with_model(self, mock_agent_class: MagicMock, mock_model: MagicMock) -> None: """Test WriterAgent initialization with provided model.""" + mock_agent_instance = MagicMock() + mock_agent_class.return_value = mock_agent_instance + agent = WriterAgent(model=mock_model) + assert agent.model == mock_model - assert agent.agent is not None + assert agent.agent == mock_agent_instance + mock_agent_class.assert_called_once() + @patch("src.agents.writer.Agent") @patch("src.agents.writer.get_model") def test_writer_agent_init_without_model( - self, mock_get_model: MagicMock, mock_model: MagicMock + self, mock_get_model: MagicMock, mock_agent_class: MagicMock, mock_model: MagicMock ) -> None: """Test WriterAgent initialization without model (uses default).""" mock_get_model.return_value = mock_model + mock_agent_instance = MagicMock() + mock_agent_class.return_value = mock_agent_instance + agent = WriterAgent() + assert agent.model == mock_model + assert agent.agent == mock_agent_instance mock_get_model.assert_called_once() + mock_agent_class.assert_called_once() def test_writer_agent_has_correct_system_prompt(self, writer_agent: WriterAgent) -> None: """Test that WriterAgent has correct system prompt.""" diff --git a/tests/unit/orchestrator/test_graph_orchestrator.py b/tests/unit/orchestrator/test_graph_orchestrator.py index 4136663f..d3dae8b4 100644 --- a/tests/unit/orchestrator/test_graph_orchestrator.py +++ b/tests/unit/orchestrator/test_graph_orchestrator.py @@ -209,10 +209,22 @@ async def test_run_handles_errors(self): from src.orchestrator.research_flow import IterativeResearchFlow # Create flow and patch its run method to raise exception - original_flow = IterativeResearchFlow( - max_iterations=2, - max_time_minutes=5, - ) + with ( + patch("src.orchestrator.research_flow.create_knowledge_gap_agent") as mock_kg, + patch("src.orchestrator.research_flow.create_tool_selector_agent") as mock_ts, + patch("src.orchestrator.research_flow.create_thinking_agent") as mock_thinking, + patch("src.orchestrator.research_flow.create_writer_agent") as mock_writer, + patch("src.orchestrator.research_flow.create_judge_handler") as mock_judge + ): + mock_kg.return_value = AsyncMock() + mock_ts.return_value = AsyncMock() + mock_thinking.return_value = AsyncMock() + mock_writer.return_value = AsyncMock() + mock_judge.return_value = AsyncMock() + original_flow = IterativeResearchFlow( + max_iterations=2, + max_time_minutes=5, + ) orchestrator._iterative_flow = original_flow with patch.object(original_flow, "run", side_effect=Exception("Test error")): diff --git a/tests/unit/orchestrator/test_research_flow.py b/tests/unit/orchestrator/test_research_flow.py index 2691ec15..78a30df8 100644 --- a/tests/unit/orchestrator/test_research_flow.py +++ b/tests/unit/orchestrator/test_research_flow.py @@ -37,6 +37,7 @@ def flow(self, mock_agents): patch("src.orchestrator.research_flow.create_thinking_agent") as mock_thinking, patch("src.orchestrator.research_flow.create_writer_agent") as mock_writer, patch("src.orchestrator.research_flow.execute_tool_tasks") as mock_execute, + patch("src.orchestrator.research_flow.create_judge_handler") as mock_judge, ): mock_kg.return_value = mock_agents["knowledge_gap"] mock_ts.return_value = mock_agents["tool_selector"] @@ -45,6 +46,7 @@ def flow(self, mock_agents): mock_execute.return_value = { "task_1": ToolAgentOutput(output="Finding 1", sources=["url1"]), } + mock_judge.return_value = AsyncMock() yield IterativeResearchFlow(max_iterations=2, max_time_minutes=5) @@ -203,10 +205,12 @@ def flow(self, mock_agents): patch("src.orchestrator.research_flow.create_planner_agent") as mock_planner, patch("src.orchestrator.research_flow.create_long_writer_agent") as mock_long_writer, patch("src.orchestrator.research_flow.create_proofreader_agent") as mock_proofreader, + patch("src.orchestrator.research_flow.create_judge_handler") as mock_judge_handler, ): mock_planner.return_value = mock_agents["planner"] mock_long_writer.return_value = mock_agents["long_writer"] mock_proofreader.return_value = mock_agents["proofreader"] + mock_judge_handler.return_value = AsyncMock() yield DeepResearchFlow(max_iterations=2, max_time_minutes=5) diff --git a/tests/unit/test_app_smoke.py b/tests/unit/test_app_smoke.py index 3fb347f9..a3e7ee2a 100644 --- a/tests/unit/test_app_smoke.py +++ b/tests/unit/test_app_smoke.py @@ -5,6 +5,8 @@ that wouldn't be caught by unit tests. """ +from unittest.mock import MagicMock, patch + import pytest @@ -12,7 +14,8 @@ class TestAppSmoke: """Smoke tests for app initialization.""" - def test_app_creates_demo(self) -> None: + @patch("gradio.LoginButton") + def test_app_creates_demo(self, mock_login_button: MagicMock) -> None: """App should create Gradio demo without crashing. This catches: