diff --git a/examples/localcowork/.env.example b/examples/localcowork/.env.example
index 16e64bd..7b66b73 100644
--- a/examples/localcowork/.env.example
+++ b/examples/localcowork/.env.example
@@ -8,6 +8,10 @@
 # Directory containing GGUF model files (downloaded from HuggingFace)
 # LOCALCOWORK_MODELS_DIR=~/Projects/_models
 
+# HuggingFace cache directory (defaults to ~/.cache/huggingface)
+# Set to LOCALCOWORK_MODELS_DIR to keep all model files together
+# HF_HOME=
+
 # Text model API endpoint (OpenAI-compatible). Set by start-model.sh.
 # Default when using LFM2 via llama-server: http://localhost:8080/v1
 # Default when using Ollama: http://localhost:11434/v1
diff --git a/examples/localcowork/.gitignore b/examples/localcowork/.gitignore
index 4f4a5d5..281c756 100644
--- a/examples/localcowork/.gitignore
+++ b/examples/localcowork/.gitignore
@@ -17,6 +17,7 @@ src-tauri/gen/
 _models/*.gguf
 _models/*.bin
 _models/*.safetensors
+_models/.cache/
 
 # ─── IDE ──────────────────────────────────────────────────────────────────
 .vscode/
diff --git a/examples/localcowork/_models/config.yaml b/examples/localcowork/_models/config.yaml
index d61d1fb..1005ea9 100644
--- a/examples/localcowork/_models/config.yaml
+++ b/examples/localcowork/_models/config.yaml
@@ -14,6 +14,10 @@ active_model: lfm2-24b-a2b  # Sparse MoE: 24B total, 2.3B active, 64 experts top
 
 # Default model directory for non-Ollama model files (GGUF, MLX, etc.)
 models_dir: "${LOCALCOWORK_MODELS_DIR:-~/Projects/_models}"
+# HuggingFace model downloads default to the user cache directory
+# (typically ~/.cache/huggingface on Linux/macOS) and can be overridden
+# via HF_HOME, HF_HUB_CACHE, or XDG_CACHE_HOME. Pointing these to the
+# project directory will cause large cache files to be included here.
 
 # ─── Tool Surface Curation ──────────────────────────────────────────────────
 # By default, only the curated high-accuracy servers and tools are active.
diff --git a/src-tauri/src/commands/chat.rs b/src-tauri/src/commands/chat.rs
new file mode 100644
index 0000000..0e9ee4c
--- /dev/null
+++ b/src-tauri/src/commands/chat.rs
@@ -0,0 +1,3672 @@
+//! Tauri IPC commands for the chat interface.
+//!
+//! These commands are called from the React frontend via `invoke()`.
+//! They bridge the frontend to the agent core (ConversationManager,
+//! ToolRouter, and InferenceClient).
+
+use std::sync::Mutex;
+
+use futures::StreamExt;
+use serde::Serialize;
+use uuid::Uuid;
+
+use crate::agent_core::permissions::{PermissionScope, PermissionStatus, PermissionStore};
+// NOTE: response_analysis functions (is_incomplete_response, is_deflection_response)
+// remain in the codebase and are tested, but are no longer called from the agent loop.
+// They are available for the Orchestrator (ADR-009) or re-enablement via config.
+// Tests below still exercise them for regression coverage.
+use crate::agent_core::tokens::truncate_utf8;
+use crate::agent_core::tool_router::{generate_preview, is_destructive_action};
+use crate::agent_core::{AuditStatus, ConfirmationRequest, ConfirmationResponse};
+use crate::agent_core::ConversationManager;
+use crate::inference::config::{find_config_path, load_models_config};
+use crate::inference::types::{SamplingOverrides, ToolDefinition};
+use crate::inference::InferenceClient;
+use crate::mcp_client::{CategoryRegistry, McpClient, ToolResolution};
+use crate::{PendingConfirmation, TokioMutex};
+
+// ─── Two-Pass Tool Selection ────────────────────────────────────────────────
+
+/// Tracks the two-pass tool selection state within the agent loop.
+///
+/// In the `Categories` phase, the model sees ~15 category meta-tools (~1,500 tokens).
+/// In `Expanded`, the model sees real tools from the selected categories.
+/// In `Flat` (legacy), all tools are sent every turn (~8,670 tokens).
+#[derive(Debug, Clone)]
+enum ToolSelectionPhase {
+    /// First turn: model sees category meta-tools.
+    Categories {
+        /// The category registry used for expansion.
+        cat_registry: CategoryRegistry,
+    },
+    /// Subsequent turns: model sees real tools from selected categories.
+    Expanded {
+        /// Category names that were selected (retained for diagnostics).
+        _selected_categories: Vec<String>,
+    },
+    /// Legacy flat mode: all tools every turn.
+    Flat,
+}
+
+/// Minimum number of registered tools to activate two-pass mode.
+/// Below this threshold, flat mode is used regardless of config.
+/// Set to 30 because category meta-tools confuse LFM2-24B-A2B at ≤21 tools
+/// (the model responds with text instead of calling tools). Two-pass is only
+/// worthwhile at 67+ tools, where it saves ~7k tokens/turn.
+const TWO_PASS_MIN_TOOLS: usize = 30;
+
+// ─── Response Types ─────────────────────────────────────────────────────────
+
+/// Session start response.
+#[derive(Debug, Serialize)]
+pub struct SessionInfo {
+    pub session_id: String,
+    /// Whether this is a newly created session or a resumed one.
+    pub resumed: bool,
+}
+
+// ─── System prompt ──────────────────────────────────────────────────────────
+
+/// Identity and intro — static portion of the system prompt.
+///
+/// Kept short: research shows small LLMs perform better with concise identity
+/// statements. The capabilities section (dynamic) is inserted after this.
+const SYSTEM_PROMPT_INTRO: &str = "\
+You are LocalCowork, a private on-device AI assistant. You call tools to help the user.";
+
+/// Behavioral rules and few-shot examples — dynamic portion of the system prompt.
+///
+/// Optimized for small LLMs (24B MoE) based on research:
+/// - XML section tags for clear structure (models parse sections, not paragraphs)
+/// - Pre-computed relative dates (model doesn't need to reason about "today")
+/// - ≤6 rules (small models lose track beyond 5-7 instructions)
+/// - Calendar/date example in position 1 (primacy effect)
+/// - Concrete dates in examples (no indirection)
+fn system_prompt_rules(today: &str, tomorrow: &str, week_start: &str, week_end: &str) -> String {
+    let home = dirs::home_dir()
+        .map(|p| p.to_string_lossy().into_owned())
+        .unwrap_or_else(|| {
+            if cfg!(target_os = "windows") {
+                r"C:\Users\user".to_string()
+            } else if cfg!(target_os = "macos") {
+                "/Users/user".to_string()
+            } else {
+                "/home/user".to_string()
+            }
+        });
+
+    format!("\
+<rules>\n\
+1. Use fully-qualified tool names: filesystem.list_dir, NOT list_dir.\n\
+2. Use absolute paths: {home}/Documents/file.txt, NOT ~/Documents/file.txt. \
+If a WORKING FOLDER is set, use ONLY the exact paths listed there.\n\
+3. READ tools: call immediately. WRITE tools: call directly (system shows confirmation).\n\
+4. After a scan returns results, present findings and STOP. Do NOT auto-chain to \
+mutable tools (encrypt, delete, move) without the user asking.\n\
+5. Never call the same tool with the same arguments twice. Use the result you already have.\n\
+6. Be concise. Respond after 1-3 tool calls unless the user asked for exhaustive processing.\n\
+</rules>\n\n\
+<examples>\n\
+Example 1 — calendar query (use pre-computed dates, never ask the user):\n\
+  User: \"What's on my calendar today?\"\n\
+  You call: calendar.list_events({{\"start_date\": \"{today}\", \"end_date\": \"{today}\"}})\n\
+  User: \"Any meetings tomorrow?\"\n\
+  You call: calendar.list_events({{\"start_date\": \"{tomorrow}\", \"end_date\": \"{tomorrow}\"}})\n\
+  User: \"What do I have this week?\"\n\
+  You call: calendar.list_events({{\"start_date\": \"{week_start}\", \"end_date\": \"{week_end}\"}})\n\
+  WRONG: Asking the user what today's date is. You already know it.\n\n\
+Example 2 — file listing:\n\
+  User: \"List my Documents folder.\"\n\
+  You call: filesystem.list_dir({{\"path\": \"{home}/Documents\"}})\n\n\
+Example 3 — security scan:\n\
+  User: \"Scan for secrets.\"\n\
+  You call: security.scan_for_secrets({{\"path\": \"{home}/Projects\"}})\n\
+  Then: present findings and STOP.\n\
+</examples>")
+}
+
+/// Build the system prompt with dynamic tool capabilities from the MCP registry.
+///
+/// Structure (optimized for small LLMs):
+/// ```text
+/// Identity (1 line)
+/// <dates> block with pre-computed dates
+/// <capabilities> from MCP registry
+/// <rules> consolidated behavioral rules
+/// <examples> few-shot examples (calendar first)
+/// ```
+///
+/// Key optimizations for 24B MoE models:
+/// - XML section tags (research: small models parse structured prompts better)
+/// - Pre-computed relative dates (today, tomorrow, week range) — no reasoning needed
+/// - Date block at position 2 (high primacy) and repeated in rules reminder
+/// - ≤6 rules instead of 12 (small models lose track beyond 5-7)
+fn build_system_prompt(
+    registry: &crate::mcp_client::registry::ToolRegistry,
+    two_pass_active: bool,
+) -> String {
+    use chrono::{Datelike, Duration};
+
+    let capabilities = registry.capability_summary();
+
+    // Pre-compute all relative dates so the model never needs to reason about them.
+    // Research: small LLMs fail at date arithmetic; pre-computing eliminates the problem.
+    let now = chrono::Local::now();
+    let today = now.format("%Y-%m-%d").to_string();
+    let day_of_week = now.format("%A").to_string(); // e.g. "Monday"
+    let tomorrow = (now + Duration::days(1)).format("%Y-%m-%d").to_string();
+
+    // Compute Monday (start) and Sunday (end) of the current week
+    let weekday_num = now.weekday().num_days_from_monday(); // Mon=0, Sun=6
+    let week_start = (now - Duration::days(weekday_num as i64))
+        .format("%Y-%m-%d")
+        .to_string();
+    let week_end = (now + Duration::days((6 - weekday_num) as i64))
+        .format("%Y-%m-%d")
+        .to_string();
+
+    let time_str = now.format("%H:%M").to_string();
+
+    // Date block — prominent, structured, with pre-computed values.
+    // Placed immediately after identity for maximum primacy.
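+    // Rendered shape of the block built below (the dates here are hypothetical
+    // examples for orientation; the real values are computed at runtime):
+    //
+    //   <dates>
+    //   today = 2025-06-02 (Monday)
+    //   tomorrow = 2025-06-03
+    //   this_week = 2025-06-02 to 2025-06-08
+    //   current_time = 14:30
+    //   Use these exact values when the user says "today", "tomorrow", "this week".
+    //   NEVER ask the user for a date.
+    //   </dates>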
+    let date_block = format!(
+        "<dates>\n\
+        today = {today} ({day_of_week})\n\
+        tomorrow = {tomorrow}\n\
+        this_week = {week_start} to {week_end}\n\
+        current_time = {time_str}\n\
+        Use these exact values when the user says \"today\", \"tomorrow\", \"this week\".\n\
+        NEVER ask the user for a date.\n\
+        </dates>"
+    );
+
+    let rules = system_prompt_rules(&today, &tomorrow, &week_start, &week_end);
+
+    if two_pass_active {
+        let two_pass_instruction = "\n\nIMPORTANT: You will first see category-level tools \
+            (like file_browse, image_ocr, data_analysis, etc.). Call 1-3 categories that match \
+            the user's request. You will then receive the specific tools within those categories. \
+            Always select the categories FIRST before trying to use specific tools. \
+            After selecting categories and receiving the expanded tools, call the minimum \
+            tools needed to answer the user's question, then provide your response.";
+        format!(
+            "{SYSTEM_PROMPT_INTRO}\n\n\
+            {date_block}\n\n\
+            <capabilities>\n{capabilities}\n</capabilities>\
+            {two_pass_instruction}\n\n\
+            {rules}"
+        )
+    } else {
+        format!(
+            "{SYSTEM_PROMPT_INTRO}\n\n\
+            {date_block}\n\n\
+            <capabilities>\n{capabilities}\n</capabilities>\n\n\
+            {rules}"
+        )
+    }
+}
+
+/// Maximum number of tool-call round-trips per user message.
+///
+/// Each round allows one model response + one set of tool executions.
+/// Complex tasks (e.g., OCR on 10 files) may use many rounds.
+/// The model gets one call per tool per round (it can batch multiple
+/// tool calls in a single response, but typically does one at a time).
+const MAX_TOOL_ROUNDS: usize = 10;
+
+/// Maximum consecutive empty responses before forcing a summary.
+///
+/// If the model returns 0 text AND 0 tool calls this many times in a row,
+/// it's stuck (likely due to context confusion or timeout). We inject a
+/// summary prompt to force text output.
+const MAX_EMPTY_RETRIES: usize = 2;
+
+/// Maximum consecutive rounds with ALL tool calls failing before injecting
+/// a corrective hint.
+///
+/// When the model repeatedly calls the same non-existent tool (e.g.,
+/// `filesystem.rename_file` instead of `filesystem.move_file`), this
+/// prevents burning all `MAX_TOOL_ROUNDS` rounds on the same error. After
+/// this many consecutive all-error rounds, we inject a hint telling the
+/// model which tools actually exist.
+const MAX_CONSECUTIVE_ERROR_ROUNDS: usize = 2;
+
+/// Maximum times a single tool can fail before it's removed from the tool
+/// definitions and the model is told to stop retrying.
+///
+/// This catches the case where the model alternates between a succeeding tool
+/// and a failing one — the per-round counter (`consecutive_error_rounds`) resets
+/// on every success, so this per-tool counter is the only thing that can break
+/// that loop.
+const MAX_SAME_TOOL_FAILURES: usize = 3;
+
+/// Maximum consecutive duplicate tool calls (same tool name with identical
+/// arguments) before the agent loop breaks.
+///
+/// When the model gets stuck calling the same tool repeatedly with identical
+/// params (e.g., `list_directory("~/Downloads")` 3× in a row), the loop
+/// should detect this and exit.
+///
+/// Note: `consecutive_duplicate_count()` returns 1 for the first occurrence,
+/// so a threshold of 2 means "one genuine duplicate" (the tool was called
+/// twice with identical args). Before reaching this hard break, the soft
+/// interception in the tool execution loop will skip the redundant call and
+/// inject a "you already have these results" nudge, giving the model a
+/// chance to produce text.
+const MAX_DUPLICATE_TOOL_CALLS: usize = 2;
+
+/// Minimum remaining token budget to start a new agent loop round.
+///
+/// If the context window has fewer than this many tokens remaining, the
+/// agent loop exits early rather than risk context overflow and degraded
+/// model quality. Set to accommodate a model response (~500 tokens) plus
+/// a tool result (~1000 tokens).
+const MIN_ROUND_TOKEN_BUDGET: u32 = 1500;
+
+/// Configuration for tool result compression.
+const COMPRESSION_THRESHOLD_CHARS: usize = 3_000;
+const MAX_TOOL_RESULT_CHARS: usize = 6_000;
+
+/// Truncate a tool result if it exceeds `MAX_TOOL_RESULT_CHARS`.
+///
+/// Preserves the beginning of the result (which usually contains the most
+/// useful information) and appends a truncation notice.
+fn truncate_tool_result(result: &str, tool_name: &str) -> String {
+    if result.len() <= MAX_TOOL_RESULT_CHARS {
+        return result.to_string();
+    }
+
+    // Try smart compression for known tool types
+    if let Some(summary) = compress_tool_result(result, tool_name) {
+        if summary.len() <= MAX_TOOL_RESULT_CHARS {
+            tracing::info!(
+                tool = %tool_name,
+                original_len = result.len(),
+                compressed_len = summary.len(),
+                "tool result compressed via smart summary"
+            );
+            return summary;
+        }
+    }
+
+    // Fall back to simple truncation. Use the UTF-8-safe helper: a raw byte
+    // slice at MAX_TOOL_RESULT_CHARS can panic mid-codepoint (results may
+    // contain emoji such as the 📁/📄 markers).
+    let truncated = truncate_utf8(result, MAX_TOOL_RESULT_CHARS);
+    tracing::warn!(
+        tool = %tool_name,
+        original_len = result.len(),
+        truncated_to = MAX_TOOL_RESULT_CHARS,
+        "tool result truncated — exceeded MAX_TOOL_RESULT_CHARS"
+    );
+    format!(
+        "{truncated}\n\n[... truncated: showing first {MAX_TOOL_RESULT_CHARS} of {} chars]",
+        result.len()
+    )
+}
+
+/// Compress tool results using smart extraction for known data patterns.
+///
+/// For directory listings, extracts just filenames and counts.
+/// For search results, extracts matches and count.
+/// For JSON/structured data, extracts key summaries.
+///
+/// Returns None if compression isn't beneficial for this tool type.
+fn compress_tool_result(result: &str, tool_name: &str) -> Option<String> {
+    // Only compress for read/search type operations
+    let compressible_tools = [
+        "list_dir",
+        "search_files",
+        "scan_for_secrets",
+        "scan_for_pii",
+        "query_knowledge",
+        "list_events",
+        "list_tasks",
+        "search_emails",
+    ];
+
+    let is_compressible = compressible_tools
+        .iter()
+        .any(|t| tool_name.contains(t));
+
+    if !is_compressible || result.len() < COMPRESSION_THRESHOLD_CHARS {
+        return None;
+    }
+
+    // For directory listings: extract just file/dir names and summary
+    if tool_name.contains("list_dir") {
+        return compress_directory_listing(result);
+    }
+
+    // For search results: extract match counts and key matches
+    if tool_name.contains("search") || tool_name.contains("scan") {
+        return compress_search_results(result);
+    }
+
+    // For JSON-like results: extract key fields
+    if result.starts_with('{') || result.starts_with('[') {
+        return compress_json_result(result);
+    }
+
+    None
+}
+
+/// Compress a directory listing to just names and counts.
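+///
+/// Illustrative input/output, assuming a listing in the emitter's format
+/// (entry names and counts are hypothetical):
+///
+/// ```text
+/// input:  "📁 src\n📄 main.rs (4.2 KB)\n..." (over 3,000 chars, 120 entries)
+/// output: "Total: 120 items (95 files, 25 directories)" followed by at most
+///         20 directories and 30 files, each section ending "... and N more".
+/// ```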
+fn compress_directory_listing(result: &str) -> Option<String> {
+    let lines: Vec<&str> = result.lines().collect();
+    if lines.is_empty() {
+        return None;
+    }
+
+    let mut files = Vec::new();
+    let mut dirs = Vec::new();
+
+    for line in &lines {
+        let trimmed = line.trim();
+        if trimmed.is_empty() {
+            continue;
+        }
+        if trimmed.starts_with("📁") {
+            if let Some(name) = trimmed.strip_prefix("📁").map(|s| s.trim()) {
+                dirs.push(name.to_string());
+            }
+        } else if trimmed.starts_with("📄") {
+            if let Some(name) = trimmed.strip_prefix("📄").map(|s| s.trim()) {
+                files.push(name);
+            }
+        } else {
+            // Try to extract name from parenthetical format: "name (123 KB)"
+            if let Some(paren_idx) = trimmed.find('(') {
+                let name = trimmed[..paren_idx].trim();
+                if !name.is_empty() {
+                    files.push(name);
+                }
+            }
+        }
+    }
+
+    let total = files.len() + dirs.len();
+    let mut summary = format!(
+        "Total: {} items ({} files, {} directories)\n\n",
+        total,
+        files.len(),
+        dirs.len()
+    );
+
+    if !dirs.is_empty() {
+        summary.push_str("Directories:\n");
+        for d in dirs.iter().take(20) {
+            summary.push_str(&format!(" 📁 {}\n", d));
+        }
+        if dirs.len() > 20 {
+            summary.push_str(&format!(" ... and {} more\n", dirs.len() - 20));
+        }
+        summary.push('\n');
+    }
+
+    if !files.is_empty() {
+        summary.push_str("Files:\n");
+        for f in files.iter().take(30) {
+            summary.push_str(&format!(" 📄 {}\n", f));
+        }
+        if files.len() > 30 {
+            summary.push_str(&format!(" ... and {} more\n", files.len() - 30));
+        }
+    }
+
+    Some(summary)
+}
+
+/// Compress search/scan results to match counts and key findings.
+fn compress_search_results(result: &str) -> Option<String> {
+    let lower = result.to_lowercase();
+
+    // Try to extract match count
+    let count_patterns = [
+        ("found ", " matches"),
+        ("matches: ", ""),
+        ("results: ", ""),
+        ("total: ", " items"),
+    ];
+
+    for (prefix, suffix) in &count_patterns {
+        if let Some(idx) = lower.find(prefix) {
+            let after_prefix = &result[idx + prefix.len()..];
+            let end_idx = after_prefix
+                .find(|c: char| !c.is_ascii_digit())
+                .unwrap_or(after_prefix.len());
+            let count = &after_prefix[..end_idx];
+            if !count.is_empty() && count.len() <= 6 {
+                // Found a count, now get first few matches
+                let matches: Vec<&str> = result
+                    .lines()
+                    .filter(|l| {
+                        let l_lower = l.to_lowercase();
+                        !l_lower.contains("found")
+                            && !l_lower.contains("total")
+                            && !l_lower.contains("scan")
+                            && !l_lower.contains("error")
+                            && l.trim().len() > 3
+                    })
+                    .take(10)
+                    .collect();
+
+                let mut summary = format!("{}{}{}\n\n", prefix, count, suffix);
+                if !matches.is_empty() {
+                    summary.push_str("Key findings:\n");
+                    for m in matches {
+                        let trimmed = m.trim();
+                        if trimmed.len() > 100 {
+                            summary.push_str(&format!(" {}\n", &trimmed[..100]));
+                        } else {
+                            summary.push_str(&format!(" {}\n", trimmed));
+                        }
+                    }
+                }
+                return Some(summary);
+            }
+        }
+    }
+
+    // Fallback: just take first 15 lines
+    let lines: Vec<&str> = result.lines().take(15).collect();
+    if lines.is_empty() {
+        return None;
+    }
+
+    let summary = format!(
+        "[... {} lines total ...]\n\n{}",
+        result.lines().count(),
+        lines.join("\n")
+    );
+    Some(summary)
+}
+
+/// Compress JSON results by extracting key fields.
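+///
+/// Hypothetical example (shape only): a 12-element array such as
+/// `[{"name": "a.txt"}, ...]` compresses to `"[12 items]"` followed by the
+/// first five entries rendered as `1. name: a.txt` lines and `... and 7 more`.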
+fn compress_json_result(result: &str) -> Option<String> {
+    let parsed: serde_json::Value = serde_json::from_str(result).ok()?;
+
+    // For arrays, extract count and first few items
+    if let Some(arr) = parsed.as_array() {
+        if arr.is_empty() {
+            return Some("[] (empty)".to_string());
+        }
+
+        let count = arr.len();
+        let mut summary = format!("[{} items]\n\n", count);
+
+        for (i, item) in arr.iter().take(5).enumerate() {
+            summary.push_str(&format!("{}. ", i + 1));
+            if let Some(obj) = item.as_object() {
+                // Extract common "name" or "text" fields
+                if let Some(name) = obj.get("name").and_then(|v| v.as_str()) {
+                    summary.push_str(&format!("name: {}", name));
+                } else if let Some(text) = obj.get("text").and_then(|v| v.as_str()) {
+                    let preview = if text.len() > 50 {
+                        format!("{}...", &text[..50])
+                    } else {
+                        text.to_string()
+                    };
+                    summary.push_str(&preview);
+                } else if let Some(path) = obj.get("path").and_then(|v| v.as_str()) {
+                    summary.push_str(&format!("path: {}", path));
+                } else {
+                    // Just stringify the object
+                    let s = serde_json::to_string(item).ok()?;
+                    let preview = if s.len() > 80 {
+                        format!("{}...", &s[..80])
+                    } else {
+                        s
+                    };
+                    summary.push_str(&preview);
+                }
+            } else if let Some(s) = item.as_str() {
+                let preview = if s.len() > 80 {
+                    format!("{}...", &s[..80])
+                } else {
+                    s.to_string()
+                };
+                summary.push_str(&preview);
+            }
+            summary.push('\n');
+        }
+
+        if count > 5 {
+            summary.push_str(&format!("... and {} more\n", count - 5));
+        }
+
+        return Some(summary);
+    }
+
+    // For objects, extract key fields
+    if let Some(obj) = parsed.as_object() {
+        let key_fields = ["text", "content", "name", "path", "message", "result", "total", "count"];
+        let mut summary = String::new();
+
+        for key in &key_fields {
+            if let Some(val) = obj.get(*key) {
+                if let Some(s) = val.as_str() {
+                    summary.push_str(&format!("{}: {}\n", key, s));
+                } else if let Some(n) = val.as_u64() {
+                    summary.push_str(&format!("{}: {}\n", key, n));
+                }
+            }
+        }
+
+        // If we extracted nothing useful, return original
+        if summary.is_empty() {
+            // Just return first 500 chars
+            let preview = if result.len() > 500 {
+                format!("{}...", &result[..500])
+            } else {
+                result.to_string()
+            };
+            return Some(preview);
+        }
+
+        return Some(summary);
+    }
+
+    None
+}
+
+// ─── Tool Definitions ───────────────────────────────────────────────────────
+
+/// Built-in tool definitions (filesystem operations handled in-process).
+fn builtin_tool_definitions() -> Vec<ToolDefinition> {
+    vec![
+        ToolDefinition {
+            r#type: "function".to_string(),
+            function: crate::inference::types::FunctionDefinition {
+                name: "list_directory".to_string(),
+                description: "List files and directories at the given path. \
+                    Returns name, type (file/dir), size, and modification date \
+                    for each entry. Use ~/path for home-relative paths."
+                    .to_string(),
+                parameters: serde_json::json!({
+                    "type": "object",
+                    "properties": {
+                        "path": {
+                            "type": "string",
+                            "description": "Directory path to list, e.g. ~/Desktop"
+                        }
+                    },
+                    "required": ["path"]
+                }),
+            },
+        },
+        ToolDefinition {
+            r#type: "function".to_string(),
+            function: crate::inference::types::FunctionDefinition {
+                name: "read_file".to_string(),
+                description: "Read the text contents of a file at the given path. \
+                    Returns the file content as a string. Only works for text files."
+                    .to_string(),
+                parameters: serde_json::json!({
+                    "type": "object",
+                    "properties": {
+                        "path": {
+                            "type": "string",
+                            "description": "File path to read, e.g. ~/Desktop/notes.txt"
+                        }
+                    },
+                    "required": ["path"]
+                }),
+            },
+        },
+    ]
+}
+
+/// Build merged tool definitions: built-in + MCP tools from the registry.
+///
+/// Built-in tools (`list_directory`, `read_file`) are suppressed when the MCP
+/// registry already contains their equivalents (`filesystem.list_dir`,
+/// `filesystem.read_file`). This avoids confusing the model with near-duplicate
+/// tools, which causes it to pick the wrong one or get stuck in loops.
+fn build_all_tool_definitions(mcp_client: &McpClient) -> Vec<ToolDefinition> {
+    // Map of built-in tool name → MCP equivalent that supersedes it
+    let builtin_mcp_equivalents: &[(&str, &str)] = &[
+        ("list_directory", "filesystem.list_dir"),
+        ("read_file", "filesystem.read_file"),
+    ];
+
+    // Only include built-ins whose MCP equivalent is NOT in the registry
+    let mut tools: Vec<ToolDefinition> = builtin_tool_definitions()
+        .into_iter()
+        .filter(|tool| {
+            let name = &tool.function.name;
+            !builtin_mcp_equivalents.iter().any(|(builtin, mcp)| {
+                name == builtin && mcp_client.registry.get_tool(mcp).is_some()
+            })
+        })
+        .collect();
+
+    // Append MCP tool definitions from the registry
+    let mcp_tools = mcp_client.registry.to_openai_tools();
+    for mcp_tool_json in mcp_tools {
+        if let Ok(tool_def) = serde_json::from_value::<ToolDefinition>(mcp_tool_json) {
+            tools.push(tool_def);
+        }
+    }
+
+    tools
+}
+
+/// Build tool definitions from category meta-tools (two-pass mode).
+///
+/// Each category becomes a synthetic OpenAI function with a single `"intent"`
+/// parameter. The model calls these to signal which capability areas it needs.
+/// Built-in tools (`list_directory`, `read_file`) are always included.
+fn build_category_tool_definitions(cat_registry: &CategoryRegistry) -> Vec<ToolDefinition> {
+    let mut tools = builtin_tool_definitions();
+
+    let cat_tools = cat_registry.to_openai_tools();
+    for cat_json in cat_tools {
+        if let Ok(tool_def) = serde_json::from_value::<ToolDefinition>(cat_json) {
+            tools.push(tool_def);
+        }
+    }
+
+    tools
+}
+
+// ─── Tool Execution ─────────────────────────────────────────────────────────
+
+/// Execute a built-in tool call and return the result as a string.
+fn execute_builtin_tool(name: &str, arguments: &serde_json::Value) -> String {
+    match name {
+        "list_directory" => {
+            let path = arguments
+                .get("path")
+                .and_then(|v| v.as_str())
+                .unwrap_or(".");
+            match super::filesystem::list_directory(path.to_string()) {
+                Ok(entries) => {
+                    if entries.is_empty() {
+                        "Directory is empty.".to_string()
+                    } else {
+                        let mut lines = Vec::new();
+                        for e in &entries {
+                            let type_icon = if e.entry_type == "dir" {
+                                "📁"
+                            } else {
+                                "📄"
+                            };
+                            let size_str = if e.entry_type == "dir" {
+                                String::new()
+                            } else {
+                                format_file_size(e.size)
+                            };
+                            lines.push(format!(
+                                "{} {} {}",
+                                type_icon, e.name, size_str
+                            ));
+                        }
+                        lines.join("\n")
+                    }
+                }
+                Err(e) => format!("Error: {e}"),
+            }
+        }
+        "read_file" => {
+            let path = arguments
+                .get("path")
+                .and_then(|v| v.as_str())
+                .unwrap_or("");
+            let resolved = if path.starts_with('~') {
+                if let Some(home) = dirs::home_dir() {
+                    home.join(path.strip_prefix("~/").unwrap_or(path))
+                } else {
+                    std::path::PathBuf::from(path)
+                }
+            } else {
+                std::path::PathBuf::from(path)
+            };
+            match std::fs::read_to_string(&resolved) {
+                Ok(content) => {
+                    if content.len() > 8000 {
+                        format!(
+                            "{}\n\n[... truncated, showing first ~8000 chars of {} total]",
+                            truncate_utf8(&content, 8000),
+                            content.len()
+                        )
+                    } else {
+                        content
+                    }
+                }
+                Err(e) => format!("Error reading file: {e}"),
+            }
+        }
+        _ => format!("Unknown built-in tool: {name}"),
+    }
+}
+
+// ─── Tool Execution Outcome ─────────────────────────────────────────────────
+
+/// Typed result from executing a single tool call in the agent loop.
+///
+/// Preserves the success/failure distinction through types instead of string
+/// matching. The agent loop uses this to:
+/// - Feed the right text back to the model (via `model_text()`)
+/// - Track error patterns for loop detection (via `is_error()`)
+/// - Build correction hints from `ToolResolution` suggestions
+#[derive(Debug)]
+#[allow(dead_code)] // tool_name fields used for Debug output and future ToolRouter integration
+enum ToolExecutionOutcome {
+    /// Tool executed successfully.
+    Success { tool_name: String, text: String },
+
+    /// Tool exists but returned an application-level error
+    /// (e.g., "file not found", "permission denied").
+    ToolError { tool_name: String, text: String },
+
+    /// Tool name not found in the registry. `resolution` carries the
+    /// registry's analysis (suggestions, nearest matches, etc.).
+    UnknownTool {
+        tool_name: String,
+        resolution: ToolResolution,
+        text: String,
+    },
+
+    /// Infrastructure error: timeout, server crash, transport failure.
+    InfraError { tool_name: String, text: String },
+}
+
+impl ToolExecutionOutcome {
+    /// The text to feed back to the model as the tool result message.
+    fn model_text(&self) -> &str {
+        match self {
+            Self::Success { text, .. }
+            | Self::ToolError { text, .. }
+            | Self::UnknownTool { text, .. }
+            | Self::InfraError { text, .. } => text,
+        }
+    }
+
+    /// Whether this outcome represents an error (any variant except Success).
+    fn is_error(&self) -> bool {
+        !matches!(self, Self::Success { .. })
+    }
+}
+
+/// Minimum similarity score (0.0–1.0) for auto-correcting tool names.
+///
+/// Below this threshold, the registry returns `NotFound` instead of
+/// `Corrected`. Set conservatively to avoid correcting to the wrong tool.
+const TOOL_RESOLUTION_THRESHOLD: f64 = 0.5;
+
+/// Execute a tool call: built-in tools run in-process, MCP tools route
+/// through the McpClient.
+///
+/// Tool names are resolved via `ToolRegistry::resolve()`, which handles:
+/// - Exact matches (tool exists as-is)
+/// - Unprefixed names (model dropped the `server.` prefix)
+/// - Fuzzy correction (model hallucinated a similar name)
+///
+/// Results are capped at `MAX_TOOL_RESULT_CHARS` to prevent a single large
+/// result from consuming the entire context window budget.
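+///
+/// Sketch of the four resolution outcomes (tool names are illustrative):
+///
+/// ```text
+/// "filesystem.list_dir"   → Exact      → dispatched as-is
+/// "list_dir"              → Unprefixed → "filesystem.list_dir"
+/// "filesystem.list_dirs"  → Corrected  → "filesystem.list_dir" (score ≥ threshold)
+/// "filesystem.shred_file" → NotFound   → UnknownTool outcome with suggestions
+/// ```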
+async fn execute_tool(
+    name: &str,
+    arguments: &serde_json::Value,
+    mcp_client: &mut McpClient,
+) -> ToolExecutionOutcome {
+    // Built-in tools (handled in-process for speed)
+    if name == "list_directory" || name == "read_file" {
+        let text = truncate_tool_result(&execute_builtin_tool(name, arguments), name);
+        return ToolExecutionOutcome::Success {
+            tool_name: name.to_string(),
+            text,
+        };
+    }
+
+    // Resolve tool name via the registry (exact → unprefixed → fuzzy)
+    let resolution = mcp_client.registry.resolve(name, TOOL_RESOLUTION_THRESHOLD);
+
+    let resolved_name = match &resolution {
+        ToolResolution::Exact(n) => n.clone(),
+        ToolResolution::Unprefixed { resolved, original } => {
+            tracing::info!(
+                original = %original,
+                resolved = %resolved,
+                "resolved unprefixed tool name"
+            );
+            resolved.clone()
+        }
+        ToolResolution::Corrected {
+            resolved,
+            original,
+            score,
+        } => {
+            tracing::info!(
+                original = %original,
+                resolved = %resolved,
+                score = score,
+                "auto-corrected tool name via edit distance"
+            );
+            resolved.clone()
+        }
+        ToolResolution::NotFound {
+            original,
+            suggestions,
+        } => {
+            let text = if suggestions.is_empty() {
+                format!(
+                    "Unknown tool: '{original}'. Use fully-qualified names \
+                    (e.g., filesystem.list_dir, security.scan_for_secrets)."
+                )
+            } else {
+                format!(
+                    "Unknown tool: '{original}'. Did you mean: {}?",
+                    suggestions.join(", ")
+                )
+            };
+            return ToolExecutionOutcome::UnknownTool {
+                tool_name: original.clone(),
+                resolution,
+                text,
+            };
+        }
+    };
+
+    // Track whether we auto-corrected the name so we can annotate errors.
+    let correction_context: Option<String> = match &resolution {
+        ToolResolution::Corrected {
+            original, resolved, ..
+        } => Some(format!(
+            "NOTE: '{original}' does not exist. Auto-corrected to '{resolved}'. "
+        )),
+        _ => None,
+    };
+
+    // Expand `~` prefixes in string arguments before MCP dispatch.
+    // Built-in tools handle tilde themselves; MCP servers expect absolute paths.
+    let expanded_arguments = expand_tilde_in_arguments(arguments);
+
+    // Execute via MCP
+    match mcp_client
+        .call_tool(&resolved_name, expanded_arguments)
+        .await
+    {
+        Ok(result) => {
+            let raw_text = if result.success {
+                extract_mcp_result_text(&result.result)
+            } else {
+                result
+                    .error
+                    .unwrap_or_else(|| "Tool execution failed".to_string())
+            };
+            let text = truncate_tool_result(&raw_text, &resolved_name);
+            if result.success {
+                ToolExecutionOutcome::Success {
+                    tool_name: resolved_name,
+                    text,
+                }
+            } else {
+                // Prepend correction context so the model understands the
+                // mis-dispatch: e.g. "rename_file does not exist, corrected
+                // to move_file. ".
+                let annotated = if let Some(ctx) = &correction_context {
+                    format!("{ctx}{text}")
+                } else {
+                    text
+                };
+                ToolExecutionOutcome::ToolError {
+                    tool_name: resolved_name,
+                    text: annotated,
+                }
+            }
+        }
+        Err(e) => {
+            let base = format!("MCP error for '{resolved_name}': {e}");
+            let text = if let Some(ctx) = &correction_context {
+                format!("{ctx}{base}")
+            } else {
+                base
+            };
+            ToolExecutionOutcome::InfraError {
+                tool_name: resolved_name,
+                text,
+            }
+        }
+    }
+}
+
+// `is_incomplete_response` and `is_deflection_response` are now in
+// `agent_core::response_analysis` — no longer called from the agent loop,
+// but still tested for regression coverage and available for the Orchestrator.
+
+/// Detect when a model's final text claims task completion but tool history
+/// disagrees — i.e., the model confabulated a summary.
+///
+/// This catches the pattern where the model says "I've successfully renamed
+/// all 9 files" but `move_file` never appeared in `tool_call_history`.
+///
+/// Returns `true` when the response looks like a confabulated completion.
+///
+/// NOTE: Currently only used by tests. The agent loop no longer calls this
+/// (continuation heuristics were removed in favour of trusting the model).
+/// Retained for the Orchestrator (ADR-009) and regression test coverage.
+#[cfg(test)]
+fn has_unverified_completion(text: &str, tool_call_history: &[String]) -> bool {
+    let lower = text.to_lowercase();
+
+    // Only trigger on text that claims the task is done.
+    let claims_done = [
+        "successfully",
+        "completed",
+        "all files",
+        "renamed",
+        "processed all",
+        "all done",
+        "task complete",
+        "finished processing",
+    ];
+    let claims_completion = claims_done.iter().any(|s| lower.contains(s));
+    if !claims_completion {
+        return false;
+    }
+
+    // Mutable operations the model might claim to have done.
+    // If the model claims completion but never called any of these, it confabulated.
+    // This list covers all mutable tools across all 13 MCP servers.
+    let mutable_tools = [
+        // Filesystem
+        "move_file",
+        "write_file",
+        "copy_file",
+        "create_dir",
+        "move_to_trash",
+        "rename_file",
+        // Task management
+        "create_task",
+        "update_task",
+        "delete_task",
+        "complete_task",
+        // Calendar
+        "create_event",
+        "update_event",
+        "delete_event",
+        // Email
+        "send_email",
+        "draft_email",
+        // Security
+        "encrypt_file",
+        "decrypt_file",
+        "propose_cleanup",
+        // Knowledge
+        "index_document",
+        "delete_index",
+        // Document
+        "convert_document",
+        "merge_documents",
+    ];
+
+    let called_any_mutable = tool_call_history
+        .iter()
+        .any(|t| mutable_tools.iter().any(|m| t.contains(m)));
+
+    // If model claims done AND actually called mutable tools → not confabulated.
+    if called_any_mutable {
+        return false;
+    }
+
+    // If the model never called any mutable tool but claims completion, it
+    // MAY be confabulated. However, we need to distinguish two cases:
+    //
+    // 1. Read-only task genuinely complete: "What files are in Downloads?" →
+    //    model calls list_dir, says "all done" → NOT confabulation.
+    //
+    // 2. Mutable task not executed: "Rename all screenshots" → model calls
+    //    list_dir + OCR but says "all files renamed" → IS confabulation.
+    //
+    // Edge case: ZERO tool calls but model claims completion — always confabulated.
+    // The model literally did nothing but claims to have finished.
+    if tool_call_history.is_empty() {
+        return true;
+    }
+
+    // The model called tools but none were mutable. Heuristic: check if the
+    // completion text specifically claims a mutable action (rename, create,
+    // move, delete, write, send, encrypt, etc.). Generic "all done" /
+    // "completed" without mutable verbs is likely a legitimate read-only
+    // task completion.
+    let mutable_action_claims = [
+        "renamed",
+        "moved",
+        "deleted",
+        "created",
+        "written",
+        "sent",
+        "encrypted",
+        "decrypted",
+        "copied",
+        "converted",
+        "merged",
+        "updated",
+        "modified",
+        "saved",
+    ];
+
+    let claims_mutable_action = mutable_action_claims.iter().any(|v| lower.contains(v));
+
+    // Only confabulation if the model claims a mutable action it never performed.
+ // "All done" after read-only work → not confabulation (let it exit). + // "Successfully renamed all files" after only reading → confabulation. + claims_mutable_action +} + +/// Detect if the model is stuck calling the same tool with the same arguments. +/// +/// Returns the number of consecutive times the last tool call signature has +/// repeated. The caller compares this against `MAX_DUPLICATE_TOOL_CALLS`. +/// +/// A "signature" is `"tool_name|arguments_json"` — if the model calls +/// `list_directory(path="~/Downloads")` three rounds in a row, this returns 3. +fn consecutive_duplicate_count(history: &[(String, String)]) -> usize { + if history.is_empty() { + return 0; + } + let last = &history[history.len() - 1]; + let mut count = 1; + for entry in history.iter().rev().skip(1) { + if entry.0 == last.0 && entry.1 == last.1 { + count += 1; + } else { + break; + } + } + count +} + +/// Format a correction hint from the `ToolResolution` data collected during +/// a round where all tool calls failed. +/// +/// Uses the suggestions already computed by `ToolRegistry::resolve()` — no +/// extra registry queries needed. +fn format_correction_hint(unknown_tools: &[(String, ToolResolution)]) -> String { + if unknown_tools.is_empty() { + return "TOOL ERROR: All tool calls in this round failed. \ + Check your tool names and try again." + .to_string(); + } + + let mut parts = Vec::new(); + for (name, resolution) in unknown_tools { + match resolution { + ToolResolution::NotFound { suggestions, .. } if !suggestions.is_empty() => { + parts.push(format!( + "'{name}' does not exist. Did you mean: {}?", + suggestions.join(", ") + )); + } + _ => { + parts.push(format!("'{name}' does not exist.")); + } + } + } + + format!( + "TOOL ERROR: {}. Use ONLY tools listed in your available tools.", + parts.join(" ") + ) +} + +/// Expand `~` or `~/` prefixes to the user's home directory in any string +/// argument value that looks like a file path. +/// +/// MCP servers expect absolute paths. The LLM frequently generates `~/...` +/// despite system-prompt rules. Rather than relying on each MCP server to +/// handle tildes, we expand them centrally before dispatch. +/// +/// Also fixes cross-platform path hallucination: +/// - `/home//...` on macOS → `/Users//...` +/// - `/Users/{user}/...` (placeholder) → real home dir +/// - `/Users//...` → real home dir +/// +/// Only replaces `~` or `~/...` at the start of a string value. Values like +/// `~other_user/` or `~suffix` are left untouched (we can't resolve those). +fn expand_tilde_in_arguments(args: &serde_json::Value) -> serde_json::Value { + match args { + serde_json::Value::Object(map) => { + let mut out = serde_json::Map::new(); + for (k, v) in map { + out.insert(k.clone(), expand_tilde_in_arguments(v)); + } + serde_json::Value::Object(out) + } + serde_json::Value::String(s) => { + if let Some(fixed) = fix_path_string(s) { + serde_json::Value::String(fixed) + } else { + serde_json::Value::String(s.clone()) + } + } + serde_json::Value::Array(arr) => { + serde_json::Value::Array(arr.iter().map(expand_tilde_in_arguments).collect()) + } + other => other.clone(), + } +} + +/// Fix a single path string: tilde expansion + cross-platform path correction. +/// +/// Returns `Some(fixed)` if the path was modified, `None` if no fix was needed. 
+///
+/// The model hallucinates paths in several forms:
+/// - `~/Documents` → tilde shorthand
+/// - `Projects` → bare relative dir name
+/// - `/home/user/...` → wrong OS prefix (Linux on macOS/Windows)
+/// - `/Users/{user}/...` → template placeholders
+/// - `C:\Users\{user}\...` → template placeholders (Windows)
+///
+/// All corrections use `std::path::Path::join` so separators are always
+/// correct for the target platform.
+fn fix_path_string(s: &str) -> Option<String> {
+    use std::path::MAIN_SEPARATOR;
+
+    let home = dirs::home_dir()?;
+    let home_str = home.to_string_lossy();
+
+    // ── 1. Tilde expansion: ~/... → <home>/... ──────────────────────────────
+    if s.starts_with("~/") || s.starts_with("~\\") {
+        let rest = &s[2..];
+        return Some(home.join(rest).to_string_lossy().into_owned());
+    }
+    if s == "~" {
+        return Some(home_str.into_owned());
+    }
+
+    // ── 2. Bare relative path that matches a well-known home subdirectory ───
+    // Model outputs "Projects" or "Downloads" instead of an absolute path.
+    // Guard: skip strings that look like absolute paths or URLs.
+    let looks_absolute = s.starts_with('/')
+        || s.starts_with('\\')
+        || (s.len() >= 3 && s.as_bytes()[1] == b':'); // C:\ or D:\
+    if !looks_absolute && !s.contains("://") {
+        let first_segment = s.split(&['/', '\\'][..]).next().unwrap_or(s);
+        let well_known = [
+            "Desktop",
+            "Documents",
+            "Downloads",
+            "Projects",
+            "Pictures",
+            "Music",
+            "Videos",  // Windows
+            "Movies",  // macOS
+            "Library", // macOS
+        ];
+        if well_known.iter().any(|d| d.eq_ignore_ascii_case(first_segment)) {
+            return Some(home.join(s).to_string_lossy().into_owned());
+        }
+    }
+
+    // ── 3. Foreign OS home prefix → real home dir ───────────────────────────
+    // LLMs hallucinate Linux-style /home/... on macOS/Windows and
+    // macOS-style /Users/... on Linux/Windows. A foreign prefix means
+    // the entire path is hallucinated — rewrite any username unconditionally.
+    let foreign_prefixes: &[&str] = if cfg!(target_os = "macos") {
+        &["/home/"] // /Users/ is native on macOS — handled separately below
+    } else if cfg!(target_os = "linux") {
+        &["/Users/"] // /home/ is native on Linux — handled separately below
+    } else {
+        &["/home/", "/Users/"] // both are foreign on Windows
+    };
+
+    for prefix in foreign_prefixes {
+        if let Some(after_prefix) = s.strip_prefix(prefix) {
+            if let Some(slash_idx) = after_prefix.find('/') {
+                let rest = &after_prefix[slash_idx + 1..];
+                return Some(home.join(rest).to_string_lossy().into_owned());
+            }
+        }
+    }
+
+    // ── 4. Native OS home prefix with template placeholder ──────────────────
+    // /Users/{user}/... on macOS, /home/{user}/... on Linux.
+    // Only rewrite if the "username" is a known template placeholder —
+    // never silently replace a real username on a multi-user system.
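+    // Illustrative behavior on macOS with home = /Users/alice (all paths
+    // hypothetical):
+    //   /Users/{user}/Documents/a.txt → rewritten to /Users/alice/Documents/a.txt
+    //   /Users/user/Documents/a.txt   → rewritten (known placeholder word)
+    //   /Users/bob/Documents/a.txt    → left untouched (could be a real account)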
+    let native_prefix: &str = if cfg!(target_os = "macos") {
+        "/Users/"
+    } else if cfg!(target_os = "linux") {
+        "/home/"
+    } else {
+        "" // Windows native prefix handled in section 5
+    };
+
+    if !native_prefix.is_empty() && s.starts_with(native_prefix) {
+        // Already matches our home dir — nothing to fix
+        if s.starts_with(&*home_str) {
+            return None;
+        }
+
+        let after_prefix = &s[native_prefix.len()..];
+        if let Some(slash_idx) = after_prefix.find('/') {
+            let placeholder = &after_prefix[..slash_idx];
+            let rest = &after_prefix[slash_idx + 1..];
+
+            let is_template =
+                (placeholder.starts_with('{') && placeholder.ends_with('}'))
+                    || (placeholder.starts_with('<') && placeholder.ends_with('>'))
+                    || (placeholder.starts_with('[') && placeholder.ends_with(']'));
+
+            if is_template {
+                return Some(home.join(rest).to_string_lossy().into_owned());
+            }
+
+            // Common LLM placeholder words (not real usernames)
+            let placeholder_lower = placeholder.to_ascii_lowercase();
+            let known_placeholders = ["user", "username", "your_name", "me"];
+            if known_placeholders.contains(&placeholder_lower.as_str()) {
+                return Some(home.join(rest).to_string_lossy().into_owned());
+            }
+        }
+    }
+
+    // ── 5. Windows C:\Users\{placeholder}\... ───────────────────────────────
+    let win_prefix = "C:\\Users\\";
+    let win_prefix_fwd = "C:/Users/"; // model may use forward slashes on Windows too
+    for prefix in &[win_prefix, win_prefix_fwd] {
+        if let Some(after_prefix) = s.strip_prefix(prefix) {
+            // Already matches our home dir — nothing to fix
+            if s.starts_with(&*home_str) {
+                return None;
+            }
+
+            let sep_idx = after_prefix.find(&['/', '\\'][..]);
+
+            if let Some(idx) = sep_idx {
+                let placeholder = &after_prefix[..idx];
+                let rest = &after_prefix[idx + 1..];
+
+                let is_template =
+                    (placeholder.starts_with('{') && placeholder.ends_with('}'))
+                        || (placeholder.starts_with('<') && placeholder.ends_with('>'))
+                        || (placeholder.starts_with('[') && placeholder.ends_with(']'));
+
+                if is_template {
+                    return Some(home.join(rest).to_string_lossy().into_owned());
+                }
+
+                let placeholder_lower = placeholder.to_ascii_lowercase();
+                let known_placeholders = ["user", "username", "your_name", "me"];
+                if known_placeholders.contains(&placeholder_lower.as_str()) {
+                    return Some(home.join(rest).to_string_lossy().into_owned());
+                }
+            }
+        }
+    }
+
+    // Suppress unused-variable warning on platforms where MAIN_SEPARATOR is `/`
+    let _ = MAIN_SEPARATOR;
+
+    None
+}
+
+/// Extract readable text from an MCP tool result.
+///
+/// MCP results follow the format: `{ "content": [{ "type": "text", "text": "..." }] }`
+/// The `text` field may itself be a JSON-serialized result object (e.g. from Python
+/// pydantic `.model_dump()` + `json.dumps()`), so we attempt to extract a human-readable
+/// summary from known fields like "text", "content", "message", or "result".
+fn extract_mcp_result_text(result: &Option<serde_json::Value>) -> String {
+    let Some(value) = result else {
+        return "No result returned.".to_string();
+    };
+
+    // Try standard MCP content format
+    if let Some(content_arr) = value.get("content").and_then(|c| c.as_array()) {
+        let texts: Vec<&str> = content_arr
+            .iter()
+            .filter_map(|item| item.get("text").and_then(|t| t.as_str()))
+            .collect();
+        if !texts.is_empty() {
+            let raw = texts.join("\n");
+            // The text might be a JSON-serialized tool result (e.g. from json.dumps).
+            // Try to parse it and extract human-readable content.
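+            // e.g. a raw payload like {"text": "Invoice 42", "confidence": 0.9}
+            // (hypothetical) becomes "Invoice 42\n[confidence: 0.9]".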
+            return unwrap_tool_result_json(&raw);
+        }
+    }
+
+    // Fallback: stringify the entire result
+    match serde_json::to_string_pretty(value) {
+        Ok(s) => s,
+        Err(_) => format!("{value:?}"),
+    }
+}
+
+/// If `raw` is a JSON object with known text fields, extract and format them
+/// for human readability. Otherwise return the original string unchanged.
+///
+/// This handles the case where Python MCP servers serialize their result model
+/// via `json.dumps(result.model_dump())`, producing strings like:
+/// `{"text": "extracted text...", "confidence": 0.9, "engine": "lfm_vision"}`
+fn unwrap_tool_result_json(raw: &str) -> String {
+    let Ok(parsed) = serde_json::from_str::<serde_json::Value>(raw) else {
+        return raw.to_string(); // Not JSON, return as-is
+    };
+
+    let obj = match parsed.as_object() {
+        Some(o) => o,
+        None => return raw.to_string(), // JSON but not an object
+    };
+
+    // Look for a primary text field in priority order
+    for key in &["text", "content", "message", "result", "output"] {
+        if let Some(val) = obj.get(*key).and_then(|v| v.as_str()) {
+            if !val.is_empty() {
+                // Build a summary with the primary text and any useful metadata
+                let mut parts = vec![val.to_string()];
+                for meta_key in &["engine", "confidence", "language", "page_count"] {
+                    if let Some(meta_val) = obj.get(*meta_key) {
+                        let display = match meta_val {
+                            serde_json::Value::String(s) => s.clone(),
+                            serde_json::Value::Number(n) => n.to_string(),
+                            serde_json::Value::Bool(b) => b.to_string(),
+                            other => other.to_string(),
+                        };
+                        parts.push(format!("[{meta_key}: {display}]"));
+                    }
+                }
+                return parts.join("\n");
+            }
+        }
+    }
+
+    // JSON object but no recognized text field — return the formatted JSON
+    raw.to_string()
+}
+
+/// Format bytes into human-readable size.
+fn format_file_size(bytes: u64) -> String {
+    if bytes < 1024 {
+        format!("({bytes} B)")
+    } else if bytes < 1024 * 1024 {
+        format!("({:.1} KB)", bytes as f64 / 1024.0)
+    } else {
+        format!("({:.1} MB)", bytes as f64 / (1024.0 * 1024.0))
+    }
+}
+
+/// Emit context budget to the frontend.
+fn emit_context_budget(
+    app_handle: &tauri::AppHandle,
+    mgr: &ConversationManager,
+    session_id: &str,
+) {
+    use tauri::Emitter;
+    if let Ok(budget) = mgr.get_budget(session_id) {
+        let _ = app_handle.emit(
+            "context-budget",
+            serde_json::json!({
+                "total": budget.total,
+                "systemPrompt": budget.system_prompt,
+                "toolDefinitions": budget.tool_definitions,
+                "conversationHistory": budget.conversation_history,
+                "outputReservation": budget.output_reservation,
+                "remaining": budget.remaining,
+            }),
+        );
+    }
+}
+
+// ─── Commands ───────────────────────────────────────────────────────────────
+
+/// Start or resume a chat session.
+///
+/// On first launch, creates a new session. On subsequent app opens,
+/// returns the most recent session that has user messages.
+/// If explicitly called with `force_new = true`, always creates a new session.
+#[tauri::command]
+pub async fn start_session(
+    force_new: Option<bool>,
+    state: tauri::State<'_, Mutex<ConversationManager>>,
+    mcp_state: tauri::State<'_, TokioMutex<McpClient>>,
+) -> Result<SessionInfo, String> {
+    // Phase 1: Check for resumable sessions (lock ConversationManager, then drop).
+    // std::sync::MutexGuard is !Send, so it MUST be dropped before any .await.
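+    //
+    // The pattern used throughout this file, in sketch form:
+    //
+    //   {
+    //       let mgr = state.lock().map_err(...)?;  // sync guard
+    //       ...                                     // no .await in here
+    //   }                                           // guard dropped at block end
+    //   let mcp = mcp_state.lock().await;           // async work afterwards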
+    {
+        let mgr = state.lock().map_err(|e| format!("Lock error: {e}"))?;
+
+        if force_new != Some(true) {
+            if let Ok(sessions) = mgr.db().list_sessions() {
+                for session in &sessions {
+                    if let Ok(count) = mgr.db().message_count(&session.id) {
+                        if count > 1 {
+                            tracing::info!(
+                                session_id = %session.id,
+                                message_count = count,
+                                "resuming existing session"
+                            );
+                            return Ok(SessionInfo {
+                                session_id: session.id.clone(),
+                                resumed: true,
+                            });
+                        }
+                    }
+                }
+            }
+        }
+    } // mgr lock dropped here — safe to .await below
+
+    // Phase 2: Build dynamic system prompt from MCP registry (async lock).
+    // Check if two-pass mode should be noted in the system prompt.
+    let system_prompt = {
+        let mcp = mcp_state.lock().await;
+        let cwd = std::env::current_dir().unwrap_or_default();
+        let two_pass_active = if let Ok(cfg_path) = find_config_path(&cwd) {
+            load_models_config(&cfg_path)
+                .ok()
+                .and_then(|cfg| cfg.two_pass_tool_selection)
+                .unwrap_or(false)
+                && mcp.registry.len() > TWO_PASS_MIN_TOOLS
+        } else {
+            false
+        };
+        build_system_prompt(&mcp.registry, two_pass_active)
+    };
+
+    // Phase 3: Create the new session (re-acquire ConversationManager).
+    let session_id = Uuid::new_v4().to_string();
+
+    {
+        let mut mgr = state.lock().map_err(|e| format!("Lock error: {e}"))?;
+
+        mgr.new_session(&session_id, &system_prompt)
+            .map_err(|e| format!("Failed to create session: {e}"))?;
+
+        // Set accurate system prompt budget from the actual dynamic prompt
+        let actual_prompt_tokens =
+            crate::agent_core::tokens::estimate_system_prompt_tokens(&system_prompt);
+        mgr.set_system_prompt_budget(actual_prompt_tokens);
+
+        tracing::info!(
+            session_id = %session_id,
+            prompt_tokens = actual_prompt_tokens,
+            "new chat session created with dynamic system prompt"
+        );
+    }
+
+    Ok(SessionInfo {
+        session_id,
+        resumed: false,
+    })
+}
+
+/// Send a user message and get an assistant response.
+///
+/// Implements the agent loop:
+/// 1. Persist user message, build history
+/// 2. Call LLM with tool definitions (built-in + MCP)
+/// 3. If model returns tool calls → execute them → feed results back → repeat
+/// 4. When model returns text → stream it to frontend
+#[tauri::command]
+#[allow(clippy::too_many_arguments)]
+pub async fn send_message(
+    session_id: String,
+    content: String,
+    working_directory: Option<String>,
+    app_handle: tauri::AppHandle,
+    state: tauri::State<'_, Mutex<ConversationManager>>,
+    mcp_state: tauri::State<'_, TokioMutex<McpClient>>,
+    permission_state: tauri::State<'_, TokioMutex<PermissionStore>>,
+    pending_confirm: tauri::State<'_, PendingConfirmation>,
+    sampling_state: tauri::State<'_, TokioMutex<crate::SamplingConfig>>,
+    in_flight: tauri::State<'_, crate::InFlightRequests>,
+) -> Result<(), String> {
+    use tauri::Emitter;
+
+    // Request deduplication: check if there's already a request in flight for this session
+    {
+        let mut in_flight_guard = in_flight.lock().await;
+        if in_flight_guard.get(&session_id) == Some(&true) {
+            tracing::warn!(session_id = %session_id, "duplicate request ignored");
+            return Ok(()); // Silently ignore duplicate request
+        }
+        in_flight_guard.insert(session_id.clone(), true);
+    }
+
+    // Generate trace ID for this request (for correlation across logs)
+    let trace_id = uuid::Uuid::new_v4().to_string()[..8].to_string();
+
+    tracing::info!(trace_id = %trace_id, session_id = %session_id, content_len = content.len(), "starting message processing");
+
+    // Read sampling config once at the start of this request.
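+    // (Snapshot semantics: settings changed mid-request apply to the next
+    // send_message call, not this one.)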
+    let sampling_cfg = sampling_state.lock().await.clone();
+    let tool_turn_sampling = SamplingOverrides {
+        temperature: Some(sampling_cfg.tool_temperature),
+        top_p: Some(sampling_cfg.tool_top_p),
+    };
+    let conversational_sampling = SamplingOverrides {
+        temperature: Some(sampling_cfg.conversational_temperature),
+        top_p: Some(sampling_cfg.conversational_top_p),
+    };
+
+    // 1. Persist user message and build conversation history
+    let mut messages = {
+        let mgr = state.lock().map_err(|e| format!("Lock error: {e}"))?;
+
+        mgr.add_user_message(&session_id, &content)
+            .map_err(|e| format!("Failed to save user message: {e}"))?;
+
+        let evicted = mgr
+            .evict_if_needed(&session_id)
+            .map_err(|e| format!("Eviction error: {e}"))?;
+        if evicted > 0 {
+            tracing::info!(evicted_tokens = evicted, "evicted old messages");
+        }
+
+        mgr.build_chat_messages(&session_id)
+            .map_err(|e| format!("Failed to build messages: {e}"))?
+    };
+
+    // 1b. Inject date context directly into the user message when temporal words
+    // are detected. Small LLMs (24B) have strong training priors for 2023/2024
+    // dates and will ignore system prompt dates. Putting the date IN the user
+    // message forces the model to see it as part of the query itself.
+    {
+        use chrono::{Datelike, Duration};
+
+        let content_lower = content.to_lowercase();
+        let has_temporal = content_lower.contains("today")
+            || content_lower.contains("tomorrow")
+            || content_lower.contains("this week")
+            || content_lower.contains("next week")
+            || content_lower.contains("yesterday")
+            || content_lower.contains("calendar")
+            || content_lower.contains("schedule")
+            || content_lower.contains("meeting");
+
+        if has_temporal {
+            let now = chrono::Local::now();
+            let today_str = now.format("%Y-%m-%d").to_string();
+            let tomorrow_str = (now + Duration::days(1)).format("%Y-%m-%d").to_string();
+            let weekday_num = now.weekday().num_days_from_monday();
+            let week_start = (now - Duration::days(weekday_num as i64))
+                .format("%Y-%m-%d")
+                .to_string();
+            let week_end = (now + Duration::days((6 - weekday_num) as i64))
+                .format("%Y-%m-%d")
+                .to_string();
+
+            let date_prefix = format!(
+                "[Today is {today_str}. Tomorrow is {tomorrow_str}. \
+                This week is {week_start} to {week_end}.]\n"
+            );
+
+            // Find the last user message and prepend the date context
+            if let Some(last_user_msg) = messages
+                .iter_mut()
+                .rev()
+                .find(|m| m.role == crate::inference::types::Role::User)
+            {
+                if let Some(ref mut msg_content) = last_user_msg.content {
+                    let original = msg_content.clone();
+                    msg_content.clear();
+                    msg_content.push_str(&date_prefix);
+                    msg_content.push_str(&original);
+                    tracing::info!(
+                        date_injected = %today_str,
+                        "injected date context into user message"
+                    );
+                }
+            }
+        }
+    }
+
+    // 1b2. Inject working folder PATH (not file listing) into the user message.
+    // Same strategy as date injection: small LLMs ignore system prompt paths
+    // and hallucinate /path/to/... from training data. Putting the folder
+    // path IN the user message makes it impossible to ignore.
+    //
+    // IMPORTANT: Only the path goes here, NOT the file listing. If we put
+    // files in the user message, the model skips tool calls (it already has
+    // the answer) and the user never sees the tool trace UI. The full file
+    // listing stays in the system prompt to guide tool argument selection.
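+    //
+    // Resulting user-message shape after both injections (dates and paths
+    // below are hypothetical):
+    //
+    //   [Today is 2025-06-02. Tomorrow is 2025-06-03. This week is 2025-06-02 to 2025-06-08.]
+    //   [Working folder: /Users/alice/Reports. Use tools on files in this folder.]
+    //   summarize the PDFs in this folder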
+    if let Some(ref dir) = working_directory {
+        if let Some(last_user_msg) = messages
+            .iter_mut()
+            .rev()
+            .find(|m| m.role == crate::inference::types::Role::User)
+        {
+            if let Some(ref mut msg_content) = last_user_msg.content {
+                let folder_prefix = format!(
+                    "[Working folder: {dir}. Use tools on files in this folder.]\n"
+                );
+
+                // Prepend — but AFTER any date prefix that may already be there
+                let original = msg_content.clone();
+                msg_content.clear();
+                if original.starts_with("[Today is") {
+                    // Date prefix exists — insert folder after it
+                    if let Some(newline_pos) = original.find("]\n") {
+                        let after_date = newline_pos + 2; // skip "]\n"
+                        msg_content.push_str(&original[..after_date]);
+                        msg_content.push_str(&folder_prefix);
+                        msg_content.push_str(&original[after_date..]);
+                    } else {
+                        msg_content.push_str(&folder_prefix);
+                        msg_content.push_str(&original);
+                    }
+                } else {
+                    msg_content.push_str(&folder_prefix);
+                    msg_content.push_str(&original);
+                }
+
+                tracing::info!(
+                    working_directory = %dir,
+                    "injected working folder into user message"
+                );
+            }
+        }
+    }
+
+    // 1c. Inject working directory context + file listing into the system message.
+    // This is a per-request overlay — not persisted in the DB — so it
+    // automatically reflects the user's current folder selection.
+    // Including the actual file listing is a product-level optimization:
+    // same pattern as Cowork's project indexing — the model sees concrete
+    // file names without needing to call list_dir first.
+    const MAX_FOLDER_ENTRIES: usize = 50;
+
+    if let Some(ref dir) = working_directory {
+        let mut file_count: usize = 0;
+        if let Some(system_msg) = messages.first_mut() {
+            if system_msg.role == crate::inference::types::Role::System {
+                if let Some(ref mut content) = system_msg.content {
+                    // Build the working folder context block with XML tags
+                    let mut folder_ctx = format!(
+                        "<working_folder path={dir}>\n\
+                        Use ONLY the file paths listed below. \
+                        Do NOT invent or guess paths."
+                    );
+
+                    // List directory contents (skip hidden files, cap at 50)
+                    if let Ok(entries) = std::fs::read_dir(dir) {
+                        let mut files: Vec<String> = entries
+                            .filter_map(|e| e.ok())
+                            .filter(|e| {
+                                !e.file_name()
+                                    .to_string_lossy()
+                                    .starts_with('.')
+                            })
+                            .map(|e| {
+                                let full_path =
+                                    e.path().to_string_lossy().into_owned();
+                                if e.path().is_dir() {
+                                    format!(" {full_path}/")
+                                } else {
+                                    format!(" {full_path}")
+                                }
+                            })
+                            .collect();
+                        files.sort();
+
+                        let total = files.len();
+                        file_count = total;
+                        if total > MAX_FOLDER_ENTRIES {
+                            files.truncate(MAX_FOLDER_ENTRIES);
+                            files.push(format!(
+                                " (and {} more files...)",
+                                total - MAX_FOLDER_ENTRIES
+                            ));
+                        }
+                        if !files.is_empty() {
+                            folder_ctx.push_str("\nFiles:\n");
+                            folder_ctx.push_str(&files.join("\n"));
+                        }
+                    }
+
+                    folder_ctx.push_str("</working_folder>\n\n");
+
+                    // RECENCY REMINDER: shorter repetition block for the end.
+                    // Research shows repeating key instructions at the end improves
+                    // accuracy for smaller models (primacy + recency positions).
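+                    // Rendered shape of the reminder (paths hypothetical):
+                    //
+                    //   <working_folder_reminder>
+                    //   working_folder = /Users/alice/Reports
+                    //   Files:
+                    //    /Users/alice/Reports/q1-summary.pdf
+                    //   Use ONLY these paths. Do NOT invent paths.
+                    //   </working_folder_reminder>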
+                    let mut folder_reminder = format!(
+                        "\n\n<working_folder_reminder>\n\
+                        working_folder = {dir}\n\
+                        Files:"
+                    );
+                    if let Ok(entries) = std::fs::read_dir(dir) {
+                        let mut files: Vec<String> = entries
+                            .filter_map(|e| e.ok())
+                            .filter(|e| {
+                                !e.file_name()
+                                    .to_string_lossy()
+                                    .starts_with('.')
+                            })
+                            .map(|e| {
+                                let full_path =
+                                    e.path().to_string_lossy().into_owned();
+                                format!(" {full_path}")
+                            })
+                            .collect();
+                        files.sort();
+                        if files.len() > MAX_FOLDER_ENTRIES {
+                            files.truncate(MAX_FOLDER_ENTRIES);
+                        }
+                        if !files.is_empty() {
+                            folder_reminder.push('\n');
+                            folder_reminder.push_str(&files.join("\n"));
+                        }
+                    }
+                    folder_reminder.push_str(
+                        "\nUse ONLY these paths. Do NOT invent paths.\n\
+                        </working_folder_reminder>"
+                    );
+
+                    // SANDWICH PATTERN: insert working folder at TOP and BOTTOM
+                    // of system prompt. The model sees the file paths at the
+                    // strongest positions (primacy + recency).
+                    let original = content.clone();
+                    content.clear();
+
+                    // TOP: Insert after the first paragraph (identity intro)
+                    if let Some(pos) = original.find("\n\n") {
+                        content.push_str(&original[..pos]);
+                        content.push_str("\n\n");
+                        content.push_str(&folder_ctx);
+                        content.push_str(&original[pos..]);
+                    } else {
+                        content.push_str(&folder_ctx);
+                        content.push_str("\n\n");
+                        content.push_str(&original);
+                    }
+
+                    // BOTTOM: Append reminder at the very end
+                    content.push_str(&folder_reminder);
+                }
+            }
+        }
+        tracing::info!(
+            working_directory = %dir,
+            file_count,
+            "injected working folder into system prompt"
+        );
+    }
+
+    // 2. Create inference client and build merged tool list
+    let cwd = std::env::current_dir().unwrap_or_default();
+    let config_path =
+        find_config_path(&cwd).map_err(|e| format!("Config error: {e}"))?;
+    let config =
+        load_models_config(&config_path).map_err(|e| format!("Config error: {e}"))?;
+    let mut client = InferenceClient::from_config(config.clone())
+        .map_err(|e| format!("Inference client error: {e}"))?;
+
+    // 2a. Build tool definitions — either flat (all tools) or category meta-tools.
+    //     Two-pass mode sends ~15 categories on the first turn (~1,500 tokens)
+    //     instead of all ~67 tools (~8,670 tokens). Selected categories are
+    //     expanded to real tools on subsequent turns.
+    let (mut tool_phase, mut tools) = {
+        let mcp = mcp_state.lock().await;
+        let use_two_pass = config.two_pass_tool_selection.unwrap_or(false)
+            && mcp.registry.len() > TWO_PASS_MIN_TOOLS;
+
+        if use_two_pass {
+            let cat_registry = CategoryRegistry::build(&mcp.registry);
+            let cat_tools = build_category_tool_definitions(&cat_registry);
+            tracing::info!(
+                category_count = cat_registry.len(),
+                tool_count_saved = mcp.registry.len(),
+                "two-pass mode: sending category meta-tools instead of all tools"
+            );
+            (
+                ToolSelectionPhase::Categories { cat_registry },
+                cat_tools,
+            )
+        } else {
+            let all_tools = build_all_tool_definitions(&mcp);
+            (ToolSelectionPhase::Flat, all_tools)
+        }
+    };
+
+    // Measure actual tool definition tokens and update the budget.
+    // The default TOOL_DEFINITIONS_BUDGET (2000) was calibrated for stub schemas.
+    // With real JSON Schema from zod-to-json-schema, 15 tools consume 5000-8000+
+    // tokens. Using the measured value ensures accurate eviction timing.
+    {
+        let tools_json: Vec<serde_json::Value> = tools
+            .iter()
+            .filter_map(|t| serde_json::to_value(t).ok())
+            .collect();
+        let actual_tool_tokens =
+            crate::agent_core::tokens::estimate_tool_definitions_tokens(&tools_json);
+
+        tracing::info!(
+            tool_count = tools.len(),
+            tool_tokens = actual_tool_tokens,
+            two_pass = matches!(tool_phase, ToolSelectionPhase::Categories { .. }),
+            "measured actual tool definition tokens"
+        );
+
+        let mut mgr = state.lock().map_err(|e| format!("Lock error: {e}"))?;
+        mgr.set_tool_definitions_budget(actual_tool_tokens);
+    }
+
+    // Response text — set by either the orchestrator or the agent loop.
+    let mut full_response = String::new();
+    // Set to true when the orchestrator already persisted the response to DB.
+    let mut already_persisted = false;
+
+    // 2b. Dual-model orchestrator (ADR-009) — if enabled, try the planner+router
+    //     pipeline before falling into the single-model agent loop.
+    if let Some(ref orch_config) = config.orchestrator {
+        if orch_config.enabled {
+            tracing::info!("orchestrator enabled — attempting dual-model pipeline");
+            match crate::agent_core::orchestrator::orchestrate_dual_model(
+                &session_id,
+                &content,
+                &messages,
+                &config,
+                orch_config,
+                &app_handle,
+                &state,
+                &mcp_state,
+            )
+            .await
+            {
+                Ok(result) if !result.fell_back => {
+                    // Fix F3: Check if orchestrator "succeeded" but no tools were
+                    // actually called. This happens when the router fails to produce
+                    // bracket-format tool calls for every step.
+                    let any_tool_called = result
+                        .step_results
+                        .iter()
+                        .any(|r| r.tool_called.is_some());
+
+                    if !result.all_steps_succeeded && !any_tool_called {
+                        tracing::warn!(
+                            session_id = %session_id,
+                            failed_steps = result.step_results.len(),
+                            "orchestrator: no tools called — falling back to single-model"
+                        );
+                        // Fall through to single-model agent loop
+                    } else {
+                        tracing::info!(
+                            steps = result.step_results.len(),
+                            all_succeeded = result.all_steps_succeeded,
+                            tools_called = any_tool_called,
+                            "orchestrator completed — skipping single-model loop"
+                        );
+                        // Set the response so the normal completion path (step 5)
+                        // emits the properly-formatted stream-complete event.
+                        // The orchestrator already persisted the message to the DB.
+                        full_response = result.synthesis;
+                        already_persisted = true;
+                    }
+                }
+                Ok(_) => {
+                    tracing::warn!(
+                        "orchestrator fell back — continuing to single-model agent loop"
+                    );
+                }
+                Err(e) => {
+                    tracing::warn!(
+                        error = %e,
+                        "orchestrator error — continuing to single-model agent loop"
+                    );
+                }
+            }
+        }
+    }
+
+    // 3. Agent loop: call model → execute tools → repeat
+    //    Variables used by both the agent loop and the force-summary path.
+    let mut empty_response_count: usize = 0;
+    let mut tool_call_history: Vec<String> = Vec::new();
+
+    // ── Turn-level tool call accumulator ──────────────────────────────
+    // The bracket format emits one tool call per inference round, so a
+    // multi-tool response spans multiple rounds:
+    //   assistant(toolCalls:[A]) → tool(resultA) → assistant(toolCalls:[B]) → ...
+    //
+    // To present this as a single "2 tools executed" block in the UI, we
+    // accumulate all tool calls under a stable message ID and re-emit the
+    // growing list on each round. The frontend upserts by ID.
+    let turn_message_id = chrono::Utc::now().timestamp_millis();
+    let mut turn_tool_calls: Vec<serde_json::Value> = Vec::new();
+
+    // Skip entirely if the orchestrator already produced a response.
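+    //
+    // For reference, an illustrative (hypothetical) event sequence for a
+    // two-tool turn inside the loop below — the same turn_message_id is
+    // re-emitted so the frontend upserts the ToolTrace block in place:
+    //   round 0: "tool-call" { id: T, toolCalls: [A] }
+    //   round 1: "tool-call" { id: T, toolCalls: [A, B] }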
+    if full_response.is_empty() {
+
+        // Track (tool_name, arguments) pairs to detect duplicate calls
+        let mut tool_call_signatures: Vec<(String, String)> = Vec::new();
+        let mut consecutive_error_rounds: usize = 0;
+        let mut tool_failure_counts: std::collections::HashMap<String, usize> =
+            std::collections::HashMap::new();
+
+        for round in 0..MAX_TOOL_ROUNDS {
+            // ── Token budget gate ──────────────────────────────────────────
+            // Before each LLM call, check that we have enough remaining
+            // tokens for a productive round. If not, break early to avoid
+            // context overflow and degraded model quality.
+            {
+                let mgr = state.lock().map_err(|e| format!("Lock error: {e}"))?;
+                let budget = mgr
+                    .get_budget(&session_id)
+                    .map_err(|e| format!("Budget error: {e}"))?;
+                if budget.remaining < MIN_ROUND_TOKEN_BUDGET {
+                    tracing::warn!(
+                        round = round,
+                        remaining = budget.remaining,
+                        threshold = MIN_ROUND_TOKEN_BUDGET,
+                        "token budget exhausted — ending agent loop"
+                    );
+                    break;
+                }
+            }
+
+            tracing::info!(
+                session_id = %session_id,
+                round = round,
+                message_count = messages.len(),
+                total_content_bytes = messages.iter()
+                    .map(|m| m.content.as_deref().unwrap_or("").len())
+                    .sum::<usize>(),
+                "=== AGENT LOOP ROUND START ==="
+            );
+
+            let mut round_text = String::new();
+            let mut tool_calls_detected: Vec<crate::inference::types::ToolCall> = Vec::new();
+
+            // Measure model inference time (from request to full response parsed).
+            let inference_start = std::time::Instant::now();
+
+            match client
+                .chat_completion_stream(messages.clone(), Some(tools.clone()), Some(tool_turn_sampling))
+                .await
+            {
+                Ok(stream) => {
+                    futures::pin_mut!(stream);
+
+                    while let Some(chunk_result) = stream.next().await {
+                        match chunk_result {
+                            Ok(chunk) => {
+                                if let Some(token) = &chunk.token {
+                                    round_text.push_str(token);
+                                    if tool_calls_detected.is_empty() {
+                                        let _ = app_handle.emit(
+                                            "stream-token",
+                                            token.clone(),
+                                        );
+                                    }
+                                }
+                                if let Some(ref calls) = chunk.tool_calls {
+                                    for tc in calls {
+                                        if !tool_calls_detected
+                                            .iter()
+                                            .any(|existing| existing.id == tc.id)
+                                        {
+                                            tool_calls_detected.push(tc.clone());
+                                        }
+                                    }
+                                }
+                            }
+                            Err(e) => {
+                                tracing::warn!(
+                                    round = round,
+                                    error = %e,
+                                    "stream error in agent loop"
+                                );
+                                // Don't abort the whole loop — treat as empty
+                                // response and let retry logic handle it
+                                break;
+                            }
+                        }
+                    }
+                }
+                Err(e) => {
+                    let fallback =
+                        crate::inference::client::static_fallback_response();
+                    if let Some(token) = &fallback.token {
+                        full_response = token.clone();
+                        let _ = app_handle.emit("stream-token", token.clone());
+                    }
+                    tracing::warn!(error = %e, "all models unavailable, using static fallback");
+                    break;
+                }
+            }
+
+            let inference_time_ms = inference_start.elapsed().as_millis() as u64;
+
+            tracing::info!(
+                session_id = %session_id,
+                round = round,
+                round_text_len = round_text.len(),
+                tool_calls_count = tool_calls_detected.len(),
+                tool_names = ?tool_calls_detected.iter().map(|tc| tc.name.as_str()).collect::<Vec<_>>(),
+                inference_time_ms = inference_time_ms,
+                "=== MODEL RESPONSE ==="
+            );
+
+            // ── Handle empty response (0 text AND 0 tool calls) ──────────
+            // This is abnormal — typically caused by timeout, context overflow,
+            // or model confusion. Retry a limited number of times, then force
+            // a summary.
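+            //
+            // Sketch of the retry policy below (MAX_EMPTY_RETRIES is defined
+            // elsewhere in this module; its exact value is not shown here):
+            //   empty response #1 → inject nudge, continue
+            //   empty response #MAX_EMPTY_RETRIES → break, force summary (step 4)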
+ if tool_calls_detected.is_empty() && round_text.trim().is_empty() { + empty_response_count += 1; + tracing::warn!( + round = round, + empty_count = empty_response_count, + max_retries = MAX_EMPTY_RETRIES, + "model returned empty response (0 text, 0 tools)" + ); + + if empty_response_count >= MAX_EMPTY_RETRIES { + tracing::warn!("max empty retries reached — forcing summary"); + break; + } + + // Inject a nudge prompt instead of retrying with identical messages. + // Retrying unchanged context causes the same stall. A new user message + // gives the model fresh input to work from. + let nudge = if tool_call_history.is_empty() { + "You returned an empty response. Please answer the user's question \ + or call the appropriate tool now." + .to_string() + } else { + format!( + "You returned an empty response after processing {} tool call(s). \ + If there are more files to process, call the next tool now. \ + If the task is complete, provide a final summary of what was done.", + tool_call_history.len() + ) + }; + + messages.push(crate::inference::types::ChatMessage { + role: crate::inference::types::Role::User, + content: Some(nudge), + tool_call_id: None, + tool_calls: None, + }); + + tracing::info!( + round = round, + tools_completed = tool_call_history.len(), + "injected nudge prompt after empty response" + ); + continue; + } + + // Reset empty counter on any successful response + empty_response_count = 0; + + // ── Text response (0 tool calls) — accept and exit ───────── + // When the model returns text without tool calls, it has decided + // the task is complete. Trust the model's judgment and exit. + // + // This is the same pattern as Claude Code: model produces text → + // loop ends. If the user wants more, they say "continue." + // + // Previously, heuristic detectors (is_incomplete_response, + // has_unverified_completion, is_deflection_response) would + // second-guess the model and inject continuation prompts. These + // caused more harm than good — a valid 324-char system info + // summary would trigger "FM-3 deflection" because it contained + // "let me know", causing the model to spiral into unnecessary + // tool calls and produce a worse answer. + // + // Multi-step tasks that need continuation belong in the + // Orchestrator (ADR-009), not in heuristic string-matching. + if tool_calls_detected.is_empty() { + full_response.push_str(&round_text); + break; + } + + // ── Two-pass category expansion ───────────────────────────── + // If we're in Categories phase and the model called category meta-tools, + // expand them to real tools for subsequent rounds. Category "tool calls" + // are NOT executed — they just tell us which capability areas are needed. 
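+            // Hedged example of a two-pass exchange (category and tool names
+            // illustrative):
+            //   round 0 tools:  ~15 meta-tools (file_browse, ...)
+            //   model calls:    file_browse              ← not executed
+            //   round 1 tools:  filesystem.list_dir, filesystem.read_file, ...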
+            if let ToolSelectionPhase::Categories { ref cat_registry } = tool_phase {
+                let mut selected_categories: Vec<String> = Vec::new();
+                let mut direct_tool_calls: Vec<crate::inference::types::ToolCall> = Vec::new();
+
+                for tc in &tool_calls_detected {
+                    if cat_registry.is_category(&tc.name) {
+                        selected_categories.push(tc.name.clone());
+                    } else {
+                        // Model called a real tool directly — handle gracefully
+                        direct_tool_calls.push(tc.clone());
+                    }
+                }
+
+                if !selected_categories.is_empty() {
+                    // Expand categories to real tool names
+                    let expanded_names = cat_registry.expand_categories(&selected_categories);
+
+                    // Build expanded tool definitions from the live registry
+                    let expanded_defs = {
+                        let mcp = mcp_state.lock().await;
+                        let mut defs = builtin_tool_definitions();
+                        let mcp_tools = mcp.registry.to_openai_tools_filtered(&expanded_names);
+                        for tool_json in mcp_tools {
+                            if let Ok(td) =
+                                serde_json::from_value::<ToolDefinition>(tool_json)
+                            {
+                                defs.push(td);
+                            }
+                        }
+                        defs
+                    };
+
+                    tracing::info!(
+                        session_id = %session_id,
+                        round = round,
+                        categories = ?selected_categories,
+                        expanded_tool_count = expanded_defs.len(),
+                        "two-pass: expanded categories to real tools"
+                    );
+
+                    // Update token budget for the expanded (smaller) tool set
+                    {
+                        let tools_json: Vec<serde_json::Value> = expanded_defs
+                            .iter()
+                            .filter_map(|t| serde_json::to_value(t).ok())
+                            .collect();
+                        let expanded_tokens =
+                            crate::agent_core::tokens::estimate_tool_definitions_tokens(
+                                &tools_json,
+                            );
+                        let mut mgr =
+                            state.lock().map_err(|e| format!("Lock error: {e}"))?;
+                        mgr.set_tool_definitions_budget(expanded_tokens);
+                        tracing::info!(
+                            expanded_tool_tokens = expanded_tokens,
+                            "updated token budget for expanded tools"
+                        );
+                    }
+
+                    // Transition phase and update tools
+                    tool_phase = ToolSelectionPhase::Expanded {
+                        _selected_categories: selected_categories.clone(),
+                    };
+                    tools = expanded_defs;
+
+                    // Inject an assistant message noting the category selection
+                    // (in-memory only — not persisted, same pattern as continuation prompts)
+                    let cat_text = format!(
+                        "Selected capability areas: {}. Now proceeding with specific tools.",
+                        selected_categories.join(", ")
+                    );
+                    messages.push(crate::inference::types::ChatMessage {
+                        role: crate::inference::types::Role::Assistant,
+                        content: Some(cat_text),
+                        tool_call_id: None,
+                        tool_calls: None,
+                    });
+
+                    // If the model also called real tools directly, process them
+                    if !direct_tool_calls.is_empty() {
+                        tracing::info!(
+                            direct_tool_count = direct_tool_calls.len(),
+                            "two-pass: model also called real tools directly — \
+                            processing as fallback"
+                        );
+                        tool_calls_detected = direct_tool_calls;
+                        // Fall through to normal tool execution below
+                    } else {
+                        // Re-prompt with the expanded real tools — no tool execution
+                        // this round. The model will now see the specific tools.
+                        continue;
+                    }
+                }
+                // If no categories were selected (model called only real tools),
+                // fall through to normal execution — graceful degradation.
+ } + + // ── Tool execution round ────────────────────────────────────── + + if !round_text.is_empty() { + let _ = app_handle.emit("stream-clear", ()); + } + + tracing::info!( + round = round, + tool_count = tool_calls_detected.len(), + "executing tool calls" + ); + + // Persist the assistant's tool-call message + { + let mgr = + state.lock().map_err(|e| format!("Lock error: {e}"))?; + mgr.add_tool_call_message(&session_id, &tool_calls_detected) + .map_err(|e| format!("Failed to save tool call: {e}"))?; + } + + // ── Accumulate tool calls for the turn ───────────────────────── + // Push this round's calls into the turn-level accumulator, then + // emit ALL accumulated calls under the same stable message ID. + // The frontend upserts by ID, so the ToolTrace grows in-place + // rather than spawning a new block each round. + for tc in &tool_calls_detected { + turn_tool_calls.push(serde_json::json!({ + "id": tc.id, + "name": tc.name, + "arguments": tc.arguments, + })); + } + + let _ = app_handle.emit( + "tool-call", + serde_json::json!({ + "id": turn_message_id, + "sessionId": session_id, + "timestamp": chrono::Utc::now().to_rfc3339(), + "role": "assistant", + "toolCalls": turn_tool_calls, + "tokenCount": 10, + }), + ); + + // Execute each tool and collect typed outcomes. + let mut round_error_count: usize = 0; + let round_call_count = tool_calls_detected.len(); + let mut round_unknown_tools: Vec<(String, ToolResolution)> = Vec::new(); + + for tc in &tool_calls_detected { + // Auto-inject session_id into audit tool arguments so the model + // doesn't need to guess it. Audit tools expect a session_id param + // that matches the agent_core audit log's session column. + // Always override — the model often hallucinates placeholder values + // like "SESSION_ID_FROM_CURRENT_CONTEXT" or tool_call_ids. + let mut effective_arguments = if tc.name.starts_with("audit.") { + let mut args = tc.arguments.clone(); + if let Some(obj) = args.as_object_mut() { + obj.insert( + "session_id".to_string(), + serde_json::Value::String(session_id.clone()), + ); + } + args + } else { + tc.arguments.clone() + }; + + // ── HITL confirmation check ────────────────────────────── + // Built-in tools (list_directory, read_file) are always read-only. + // MCP tools check the registry's confirmation_required metadata. + // If the user has previously granted permission, skip the dialog. 
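+                // The request emitted below carries the ConfirmationRequest
+                // fields: request_id, tool_name, arguments, preview,
+                // confirmation_required, undo_supported, is_destructive.
+                // A hypothetical destructive example: tool_name =
+                // "filesystem.move_file", is_destructive = true, with a
+                // human-readable preview of the pending move.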
+ let is_builtin = tc.name == "list_directory" || tc.name == "read_file"; + let needs_confirmation = !is_builtin && { + let mcp = mcp_state.lock().await; + mcp.registry.requires_confirmation(&tc.name) + }; + + let mut user_confirmed = !needs_confirmation; + + if needs_confirmation { + // Check if permission was previously granted + let already_allowed = { + let perms = permission_state.lock().await; + perms.check(&tc.name) == PermissionStatus::Allowed + }; + + if already_allowed { + user_confirmed = true; + tracing::debug!( + tool = %tc.name, + "skipping confirmation — permission granted" + ); + } else { + // Build and emit a confirmation request + let supports_undo = { + let mcp = mcp_state.lock().await; + mcp.registry.supports_undo(&tc.name) + }; + let preview = generate_preview(&tc.name, &effective_arguments); + let is_destructive = is_destructive_action(&tc.name); + + let request = ConfirmationRequest { + request_id: Uuid::new_v4().to_string(), + tool_name: tc.name.clone(), + arguments: effective_arguments.clone(), + preview, + confirmation_required: true, + undo_supported: supports_undo, + is_destructive, + }; + + tracing::info!( + tool = %tc.name, + request_id = %request.request_id, + is_destructive, + "awaiting user confirmation" + ); + + // Create a oneshot channel for this confirmation + let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); + { + let mut pending = pending_confirm.lock().await; + *pending = Some(resp_tx); + } + + // Emit confirmation-request event to frontend + let _ = app_handle.emit("confirmation-request", &request); + + // Wait for user response (blocks the agent loop) + match resp_rx.await { + Ok(ConfirmationResponse::Rejected) => { + tracing::info!( + tool = %tc.name, + "tool call rejected by user" + ); + // Write rejection to audit log + { + let mgr = state + .lock() + .map_err(|e| format!("Lock error: {e}"))?; + let _ = mgr.db().insert_audit_entry( + &session_id, + &tc.name, + &effective_arguments, + None, + AuditStatus::RejectedByUser, + false, + 0, + ); + } + + let rejection_text = + format!("Tool '{}' was rejected by the user.", tc.name); + + // Emit rejection result to frontend + let _ = app_handle.emit( + "tool-result", + serde_json::json!({ + "id": chrono::Utc::now().timestamp_millis(), + "sessionId": session_id, + "timestamp": chrono::Utc::now().to_rfc3339(), + "role": "tool", + "content": rejection_text, + "toolCallId": tc.id, + "toolResult": { + "success": false, + "result": rejection_text, + "toolCallId": tc.id, + "toolName": tc.name, + }, + "tokenCount": rejection_text.len() / 4, + }), + ); + + // Persist rejection so the model knows + { + let mgr = state + .lock() + .map_err(|e| format!("Lock error: {e}"))?; + let result_json = + serde_json::Value::String(rejection_text); + mgr.add_tool_result_message( + &session_id, + &tc.id, + &result_json, + ) + .map_err(|e| { + format!("Failed to save tool result: {e}") + })?; + } + + // Add to conversation history for the LLM + messages.push(crate::inference::types::ChatMessage { + role: crate::inference::types::Role::Tool, + content: Some(format!( + "Tool '{}' was rejected by the user.", + tc.name + )), + tool_call_id: Some(tc.id.clone()), + tool_calls: None, + }); + + round_error_count += 1; + tool_call_history.push(tc.name.clone()); + tool_call_signatures.push(( + tc.name.clone(), + tc.arguments.to_string(), + )); + continue; + } + Ok(ConfirmationResponse::ConfirmedForSession) => { + let mut perms = permission_state.lock().await; + perms.grant(&tc.name, PermissionScope::Session); + user_confirmed = 
true; + } + Ok(ConfirmationResponse::ConfirmedAlways) => { + let mut perms = permission_state.lock().await; + perms.grant(&tc.name, PermissionScope::Always); + user_confirmed = true; + } + Ok(ConfirmationResponse::Confirmed) => { + user_confirmed = true; + } + Ok(ConfirmationResponse::EditedAndConfirmed { + new_arguments, + }) => { + effective_arguments = new_arguments; + user_confirmed = true; + } + Err(_) => { + tracing::warn!( + tool = %tc.name, + "confirmation channel closed — skipping tool" + ); + continue; + } + } + } + } + + // ── Duplicate call interception ────────────────────────── + // If the model is requesting the exact same tool+args as a + // previous call in this conversation, skip execution entirely. + // Instead, feed back a short "you already have this" nudge so + // the model transitions to summarising the results it already + // has. This is cheaper and more robust than executing the + // duplicate and relying on post-hoc detection to break the loop. + let call_sig = (tc.name.clone(), tc.arguments.to_string()); + let is_duplicate = tool_call_signatures.contains(&call_sig); + + if is_duplicate { + tracing::info!( + session_id = %session_id, + round = round, + tool = %tc.name, + "skipping duplicate tool call — returning cached nudge" + ); + + let nudge = format!( + "You already called {} with these exact arguments. \ + The results are in the conversation above. \ + Summarize those results for the user now.", + tc.name + ); + + // Record the signature so the hard-break counter still works + tool_call_history.push(tc.name.clone()); + tool_call_signatures.push(call_sig); + + // Push the nudge as the tool result so the model sees it + messages.push(crate::inference::types::ChatMessage { + role: crate::inference::types::Role::Tool, + content: Some(nudge.clone()), + tool_call_id: Some(tc.id.clone()), + tool_calls: None, + }); + + // Persist so windowed rebuild includes it + { + let mgr = state + .lock() + .map_err(|e| format!("Lock error: {e}"))?; + let nudge_json = serde_json::Value::String(nudge); + mgr.add_tool_result_message(&session_id, &tc.id, &nudge_json) + .map_err(|e| format!("Failed to save nudge: {e}"))?; + } + + continue; // skip to next tool call (or next round) + } + + // ── Execute tool with Error Boundary ─────────────────────── + let tool_start = std::time::Instant::now(); + let outcome: ToolExecutionOutcome = { + let mut mcp = mcp_state.lock().await; + // Error boundary: wrap in a timeout to prevent hung tool executions + // The try_read_with_timeout helper handles both success and error cases + match tokio::time::timeout( + std::time::Duration::from_secs(120), // 2 min timeout per tool + execute_tool(&tc.name, &effective_arguments, &mut mcp), + ) + .await + { + Ok(result) => result, + Err(_elapsed) => { + tracing::error!( + tool = %tc.name, + timeout_secs = 120, + "tool execution timed out — caught by error boundary" + ); + ToolExecutionOutcome::InfraError { + tool_name: tc.name.clone(), + text: format!( + "Tool '{}' timed out after 120 seconds. Please try again or use a different tool.", + tc.name + ), + } + } + } + }; + let execution_time_ms = tool_start.elapsed().as_millis() as u64; + + let is_error = outcome.is_error(); + let result_text = outcome.model_text().to_string(); + + // ── Audit log write ────────────────────────────────────── + // Record every tool execution in the audit_log table so + // audit.get_tool_log / audit.generate_audit_report can read them. 
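+                // Each entry records: session_id, tool name, arguments, the
+                // result (as a JSON string), status (Success | Error), whether
+                // the user confirmed, and execution time in ms — matching the
+                // insert_audit_entry call that follows.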
+ { + let mgr = state + .lock() + .map_err(|e| format!("Lock error: {e}"))?; + let audit_status = if is_error { + AuditStatus::Error + } else { + AuditStatus::Success + }; + let result_val = serde_json::Value::String(result_text.clone()); + if let Err(e) = mgr.db().insert_audit_entry( + &session_id, + &tc.name, + &effective_arguments, + Some(&result_val), + audit_status, + user_confirmed, + execution_time_ms, + ) { + tracing::warn!( + session_id = %session_id, + tool = %tc.name, + error = %e, + "failed to write audit log entry" + ); + } + } + + if is_error { + round_error_count += 1; + *tool_failure_counts.entry(tc.name.clone()).or_default() += 1; + } + + // Collect UnknownTool resolutions for correction hints + if let ToolExecutionOutcome::UnknownTool { + ref tool_name, + ref resolution, + .. + } = outcome + { + round_unknown_tools.push((tool_name.clone(), resolution.clone())); + } + + tool_call_history.push(tc.name.clone()); + tool_call_signatures.push(( + tc.name.clone(), + tc.arguments.to_string(), + )); + + if is_error { + tracing::warn!( + session_id = %session_id, + tool = %tc.name, + tool_call_id = %tc.id, + result_len = result_text.len(), + result_preview = %truncate_utf8(&result_text, 200), + execution_time_ms = execution_time_ms, + tools_completed = tool_call_history.len(), + "tool call FAILED" + ); + } else { + tracing::info!( + session_id = %session_id, + tool = %tc.name, + tool_call_id = %tc.id, + result_len = result_text.len(), + execution_time_ms = execution_time_ms, + tools_completed = tool_call_history.len(), + user_confirmed, + "tool execution complete" + ); + } + + let _ = app_handle.emit( + "tool-result", + serde_json::json!({ + "id": chrono::Utc::now().timestamp_millis(), + "sessionId": session_id, + "timestamp": chrono::Utc::now().to_rfc3339(), + "role": "tool", + "content": result_text, + "toolCallId": tc.id, + "toolResult": { + "success": !is_error, + "result": result_text, + "toolCallId": tc.id, + "toolName": tc.name, + "executionTimeMs": execution_time_ms, + "inferenceTimeMs": inference_time_ms, + }, + "tokenCount": result_text.len() / 4, + }), + ); + + // Persist tool result in conversation + { + let mgr = state + .lock() + .map_err(|e| format!("Lock error: {e}"))?; + let result_json = serde_json::Value::String(result_text); + mgr.add_tool_result_message( + &session_id, + &tc.id, + &result_json, + ) + .map_err(|e| format!("Failed to save tool result: {e}"))?; + } + } + + // ── Consecutive error round tracking ───────────────────────── + // If ALL tool calls in this round errored, the model may be stuck + // in a loop calling a non-existent tool (e.g., filesystem.rename_file). + // After MAX_CONSECUTIVE_ERROR_ROUNDS, inject a corrective hint using + // the suggestions already computed by ToolRegistry::resolve(). 
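+            // Escalation ladder implemented below (constants defined elsewhere
+            // in this module):
+            //   1. every call in a round fails          → consecutive_error_rounds += 1
+            //   2. MAX_CONSECUTIVE_ERROR_ROUNDS reached → inject a correction hint
+            //   3. a tool fails MAX_SAME_TOOL_FAILURES× → drop it from `tools`, hard-stop hint
+            //   4. MAX_DUPLICATE_TOOL_CALLS identical   → break out of the agent loop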
+            if round_error_count > 0 && round_error_count == round_call_count {
+                consecutive_error_rounds += 1;
+                tracing::warn!(
+                    session_id = %session_id,
+                    round = round,
+                    consecutive_error_rounds = consecutive_error_rounds,
+                    failed_tools = ?tool_calls_detected.iter().map(|tc| tc.name.as_str()).collect::<Vec<_>>(),
+                    "all tool calls in round failed"
+                );
+
+                if consecutive_error_rounds >= MAX_CONSECUTIVE_ERROR_ROUNDS {
+                    let hint = format_correction_hint(&round_unknown_tools);
+
+                    tracing::info!(
+                        round = round,
+                        hint_len = hint.len(),
+                        "injecting tool correction hint after repeated failures"
+                    );
+
+                    // Persist the corrective hint as a user message
+                    {
+                        let mgr = state
+                            .lock()
+                            .map_err(|e| format!("Lock error: {e}"))?;
+                        mgr.add_user_message(&session_id, &hint)
+                            .map_err(|e| format!("Failed to save hint: {e}"))?;
+                    }
+
+                    // Reset counter so the model gets another chance
+                    consecutive_error_rounds = 0;
+                }
+            } else {
+                // At least one tool succeeded — reset the counter
+                consecutive_error_rounds = 0;
+            }
+
+            // ── Per-tool failure circuit breaker ──────────────────────────
+            // Even when the per-round counter resets (because the model alternates
+            // between a succeeding tool and a failing one), the per-tool counter
+            // keeps accumulating. Once a tool hits MAX_SAME_TOOL_FAILURES, remove
+            // it from the definitions and inject a hard stop hint.
+            let stuck_tools: Vec<String> = tool_failure_counts
+                .iter()
+                .filter(|(_, &count)| count >= MAX_SAME_TOOL_FAILURES)
+                .map(|(name, _)| name.clone())
+                .collect();
+
+            if !stuck_tools.is_empty() {
+                let hint = format!(
+                    "STOP: The following tools have each failed {} or more times and have been \
+                    removed: {}. Do NOT attempt to call them again. Respond to the user with \
+                    what you know so far, or try a completely different approach.",
+                    MAX_SAME_TOOL_FAILURES,
+                    stuck_tools.join(", ")
+                );
+
+                tracing::warn!(
+                    session_id = %session_id,
+                    round = round,
+                    stuck_tools = ?stuck_tools,
+                    "per-tool failure limit reached — removing stuck tools from definitions"
+                );
+
+                // Remove stuck tools from the active tool definitions
+                tools.retain(|t| !stuck_tools.contains(&t.function.name));
+
+                // Clear the counters for removed tools so we don't re-trigger
+                for name in &stuck_tools {
+                    tool_failure_counts.remove(name);
+                }
+
+                // Inject the hint as a user message
+                {
+                    let mgr = state
+                        .lock()
+                        .map_err(|e| format!("Lock error: {e}"))?;
+                    mgr.add_user_message(&session_id, &hint)
+                        .map_err(|e| format!("Failed to save stuck-tool hint: {e}"))?;
+                }
+            }
+
+            // ── Duplicate tool call detection ─────────────────────────────
+            // If the model is calling the same tool with the same arguments
+            // repeatedly (e.g., list_directory("~/Downloads") 3× in a row),
+            // the results won't change. Break to prevent wasting rounds.
+            let dup_count = consecutive_duplicate_count(&tool_call_signatures);
+            if dup_count >= MAX_DUPLICATE_TOOL_CALLS {
+                tracing::warn!(
+                    session_id = %session_id,
+                    round = round,
+                    duplicate_count = dup_count,
+                    tool = %tool_call_signatures.last().map(|(n, _)| n.as_str()).unwrap_or("?"),
+                    "duplicate tool call detected — model is stuck, breaking loop"
+                );
+                break;
+            }
+
+            // ── Mid-loop eviction ─────────────────────────────────────────
+            // After persisting tool results, check if context window needs
+            // eviction before the next round. This prevents unbounded growth
+            // during long multi-step workflows.
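+            // Hedged example: with the window of 4 passed to
+            // build_windowed_chat_messages below, the rebuilt history keeps
+            // the most recent tool results verbatim (assuming the window
+            // argument bounds how many stay uncompressed) and compresses
+            // older ones.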
+            {
+                let mgr =
+                    state.lock().map_err(|e| format!("Lock error: {e}"))?;
+                let evicted = mgr
+                    .evict_if_needed(&session_id)
+                    .map_err(|e| format!("Eviction error: {e}"))?;
+                if evicted > 0 {
+                    tracing::info!(
+                        round = round,
+                        evicted_tokens = evicted,
+                        "mid-loop eviction"
+                    );
+                }
+            }
+
+            // Rebuild messages (windowed — compress old tool results to save tokens)
+            messages = {
+                let mgr =
+                    state.lock().map_err(|e| format!("Lock error: {e}"))?;
+                mgr.build_windowed_chat_messages(&session_id, 4)
+                    .map_err(|e| format!("Failed to build messages: {e}"))?
+            };
+        }
+
+    } // end if full_response.is_empty() (skip agent loop when orchestrator succeeded)
+
+    // 4. If the agent loop finished without generating text, force a
+    //    summary. This can happen when:
+    //    - All rounds were used on tool calls (normal for large batches)
+    //    - Model returned empty responses (timeout / context overflow)
+    //    - Streaming errors caused early exit
+    //
+    //    Strategy: inject a short, explicit "summarize now" user message
+    //    and call the model WITHOUT tools, so it MUST produce text.
+    if full_response.is_empty() {
+        tracing::info!(
+            session_id = %session_id,
+            rounds_used = empty_response_count,
+            tool_calls_total = tool_call_history.len(),
+            "forcing summary — injecting summarize prompt"
+        );
+
+        // Inject a constrained summary instruction that prevents confabulation.
+        // The model MUST only report results it actually received from tools.
+        let summary_instruction = crate::inference::types::ChatMessage {
+            role: crate::inference::types::Role::User,
+            content: Some(
+                "Based on the tool results above, provide a concise summary.\n\
+                CRITICAL RULES:\n\
+                - ONLY report results you actually received from tool calls above.\n\
+                - If a file was not processed, say 'not processed' — do NOT guess or invent results.\n\
+                - If no tool results are visible, say 'I was unable to complete the task.'\n\
+                Do NOT call any more tools."
+                    .to_string(),
+            ),
+            tool_call_id: None,
+            tool_calls: None,
+        };
+        messages.push(summary_instruction);
+
+        match client
+            .chat_completion_stream(messages, None, Some(conversational_sampling)) // No tools → model MUST produce text
+            .await
+        {
+            Ok(stream) => {
+                futures::pin_mut!(stream);
+                while let Some(chunk_result) = stream.next().await {
+                    if let Ok(chunk) = chunk_result {
+                        if let Some(token) = &chunk.token {
+                            full_response.push_str(token);
+                            let _ = app_handle.emit("stream-token", token.clone());
+                        }
+                    }
+                }
+            }
+            Err(e) => {
+                tracing::warn!(error = %e, "summary call failed");
+            }
+        }
+
+        // If even the summary call returned nothing, use a static fallback
+        if full_response.is_empty() {
+            tracing::warn!("summary call also returned empty — using static fallback text");
+            full_response = "I processed the requested files using the tools above. \
+                You can see the individual results in the tool trace. \
+                Please ask a follow-up question if you'd like me to continue."
+                .to_string();
+            let _ = app_handle.emit("stream-token", full_response.clone());
+        }
+    }
+
+    // 5. Persist final assistant text response
+    //    (skip if the orchestrator already persisted it)
+    {
+        let mgr = state.lock().map_err(|e| format!("Lock error: {e}"))?;
+
+        if !full_response.is_empty() && !already_persisted {
+            mgr.add_assistant_message(&session_id, &full_response)
+                .map_err(|e| format!("Failed to save assistant message: {e}"))?;
+        }
+
+        emit_context_budget(&app_handle, &mgr, &session_id);
+    }
+
+    // 6.
Emit the complete message + let message = serde_json::json!({ + "id": chrono::Utc::now().timestamp_millis(), + "sessionId": session_id, + "timestamp": chrono::Utc::now().to_rfc3339(), + "role": "assistant", + "content": full_response, + "tokenCount": full_response.len() / 4, + }); + + let _ = app_handle.emit("stream-complete", message); + + // Release in-flight lock + { + let mut in_flight_guard = in_flight.lock().await; + in_flight_guard.insert(session_id.clone(), false); + } + + Ok(()) +} + +/// Respond to a confirmation request from the agent loop. +/// +/// The frontend calls this when the user clicks Confirm/Cancel on a +/// confirmation dialog. The response is forwarded to the agent loop +/// via the pending oneshot channel. +#[tauri::command] +pub async fn respond_to_confirmation( + request_id: String, + response: serde_json::Value, + pending: tauri::State<'_, PendingConfirmation>, +) -> Result<(), String> { + tracing::info!( + request_id = %request_id, + response = %response, + "confirmation response received" + ); + + let parsed: ConfirmationResponse = serde_json::from_value(response) + .map_err(|e| format!("Invalid confirmation response: {e}"))?; + + let mut lock = pending.lock().await; + if let Some(tx) = lock.take() { + // oneshot::Sender::send returns Err if receiver was dropped + tx.send(parsed).map_err(|_| { + "Confirmation channel closed — agent loop may have timed out".to_string() + })?; + } else { + tracing::warn!( + request_id = %request_id, + "no pending confirmation — response ignored" + ); + } + + Ok(()) +} + +// ─── Tests ─────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::agent_core::response_analysis::is_incomplete_response; + + #[test] + fn test_unwrap_tool_result_json_extracts_text() { + // Simulates what Python MCP servers send: json.dumps(result.model_dump()) + let raw = r#"{"text": "LocalCowork OCR Test\nInvoice #12345", "confidence": 0.9, "engine": "lfm_vision"}"#; + let result = unwrap_tool_result_json(raw); + assert!(result.starts_with("LocalCowork OCR Test")); + assert!(result.contains("[engine: lfm_vision]")); + assert!(result.contains("[confidence: 0.9]")); + } + + #[test] + fn test_unwrap_tool_result_json_plain_text() { + let raw = "Just a plain text result"; + let result = unwrap_tool_result_json(raw); + assert_eq!(result, "Just a plain text result"); + } + + #[test] + fn test_unwrap_tool_result_json_no_text_field() { + let raw = r#"{"headers": ["col1", "col2"], "rows": [["a", "b"]]}"#; + let result = unwrap_tool_result_json(raw); + // No recognized text field, should return raw JSON + assert_eq!(result, raw); + } + + #[test] + fn test_extract_mcp_result_text_with_content_array() { + let value = serde_json::json!({ + "content": [{"type": "text", "text": "{\"text\": \"hello\", \"engine\": \"tesseract\"}"}] + }); + let result = extract_mcp_result_text(&Some(value)); + assert!(result.starts_with("hello")); + assert!(result.contains("[engine: tesseract]")); + } + + #[test] + fn test_extract_mcp_result_text_none() { + let result = extract_mcp_result_text(&None); + assert_eq!(result, "No result returned."); + } + + #[test] + fn test_truncate_tool_result_short() { + let result = truncate_tool_result("short result", "test_tool"); + assert_eq!(result, "short result"); + } + + #[test] + fn test_truncate_tool_result_long() { + let long = "x".repeat(10_000); + let result = truncate_tool_result(&long, "test_tool"); + assert!(result.len() < long.len()); + assert!(result.contains("[... 
truncated: showing first 6000 of 10000 chars]"));
+    }
+
+    #[test]
+    fn test_is_incomplete_response_remaining() {
+        assert!(is_incomplete_response(
+            "I've processed 3 files. There are 4 remaining screenshots to rename."
+        ));
+    }
+
+    #[test]
+    fn test_is_incomplete_response_next_file() {
+        assert!(is_incomplete_response(
+            "Renamed screenshot 1. Moving on to the next file."
+        ));
+    }
+
+    #[test]
+    fn test_is_incomplete_response_complete() {
+        assert!(!is_incomplete_response(
+            "All screenshots have been renamed successfully."
+        ));
+    }
+
+    #[test]
+    fn test_is_incomplete_response_no_signals() {
+        // No incomplete or complete signals — defaults to false (task done)
+        assert!(!is_incomplete_response(
+            "Here is the result of your request."
+        ));
+    }
+
+    /// Helper to create an McpClient with registered tools for testing.
+    fn mcp_client_with_tools(tools: Vec<(&str, &str)>) -> McpClient {
+        use crate::mcp_client::types::{McpServersConfig, McpToolDefinition};
+
+        let config = McpServersConfig {
+            servers: std::collections::HashMap::new(),
+        };
+        let mut client = McpClient::new(config, None);
+
+        // Group tools by server name and register them
+        let mut server_tools: std::collections::HashMap<&str, Vec<McpToolDefinition>> =
+            std::collections::HashMap::new();
+        for (server, tool) in tools {
+            server_tools
+                .entry(server)
+                .or_default()
+                .push(McpToolDefinition {
+                    name: tool.to_string(),
+                    description: format!("Test tool: {tool}"),
+                    params_schema: serde_json::json!({"type": "object", "properties": {}}),
+                    returns_schema: serde_json::json!({}),
+                    confirmation_required: false,
+                    undo_supported: false,
+                });
+        }
+        for (server, defs) in server_tools {
+            client.registry.register_server_tools(server, defs);
+        }
+
+        client
+    }
+
+    #[test]
+    fn test_resolve_exact_match() {
+        let client = mcp_client_with_tools(vec![("filesystem", "move_file")]);
+        let resolution = client.registry.resolve("filesystem.move_file", 0.5);
+        assert!(matches!(resolution, ToolResolution::Exact(_)));
+        assert_eq!(resolution.resolved_name(), Some("filesystem.move_file"));
+    }
+
+    #[test]
+    fn test_resolve_unprefixed() {
+        let client = mcp_client_with_tools(vec![
+            ("filesystem", "move_file"),
+            ("filesystem", "copy_file"),
+            ("ocr", "extract_text_from_image"),
+        ]);
+        let resolution = client.registry.resolve("move_file", 0.5);
+        assert!(matches!(resolution, ToolResolution::Unprefixed { .. }));
+        assert_eq!(resolution.resolved_name(), Some("filesystem.move_file"));
+    }
+
+    #[test]
+    fn test_resolve_unknown_unprefixed() {
+        let client = mcp_client_with_tools(vec![("filesystem", "move_file")]);
+        let resolution = client.registry.resolve("nonexistent_tool", 0.5);
+        assert!(matches!(resolution, ToolResolution::NotFound { .. }));
+        assert_eq!(resolution.resolved_name(), None);
+    }
+
+    #[test]
+    fn test_resolve_wrong_server_prefix() {
+        let client = mcp_client_with_tools(vec![("filesystem", "move_file")]);
+        // "wrong_server" doesn't exist — no same-server tools to match against
+        let resolution = client.registry.resolve("wrong_server.move_file", 0.5);
+        assert!(matches!(resolution, ToolResolution::NotFound { .. }));
+    }
+
+    #[test]
+    fn test_resolve_ambiguous_unprefixed() {
+        let client = mcp_client_with_tools(vec![
+            ("ocr", "process"),
+            ("document", "process"),
+        ]);
+        // Ambiguous — two servers have "process"
+        let resolution = client.registry.resolve("process", 0.5);
+        assert!(matches!(resolution, ToolResolution::NotFound { .. }));
+    }
+
+    #[test]
+    fn test_build_system_prompt_includes_server_names() {
+        use crate::mcp_client::registry::ToolRegistry;
+        use crate::mcp_client::types::McpToolDefinition;
+
+        let mut registry = ToolRegistry::new();
+        registry.register_server_tools(
+            "filesystem",
+            vec![McpToolDefinition {
+                name: "list_dir".to_string(),
+                description: "List directory".to_string(),
+                params_schema: serde_json::json!({"type": "object"}),
+                returns_schema: serde_json::json!({}),
+                confirmation_required: false,
+                undo_supported: false,
+            }],
+        );
+        registry.register_server_tools(
+            "email",
+            vec![McpToolDefinition {
+                name: "send_draft".to_string(),
+                description: "Send draft".to_string(),
+                params_schema: serde_json::json!({"type": "object"}),
+                returns_schema: serde_json::json!({}),
+                confirmation_required: true,
+                undo_supported: false,
+            }],
+        );
+
+        let prompt = build_system_prompt(&registry, false);
+        assert!(prompt.contains("filesystem (1)"));
+        assert!(prompt.contains("email (1)"));
+        assert!(prompt.contains("2 tools across 2 servers"));
+        assert!(prompt.contains("LocalCowork"));
+        // Should include XML-tagged rules section
+        assert!(prompt.contains("<rules>"));
+        assert!(prompt.contains("fully-qualified tool names"));
+    }
+
+    #[test]
+    fn test_build_system_prompt_empty_registry() {
+        use crate::mcp_client::registry::ToolRegistry;
+
+        let registry = ToolRegistry::new();
+        let prompt = build_system_prompt(&registry, false);
+        assert!(prompt.contains("No MCP tools currently available"));
+        // Should still include the rules and examples sections
+        assert!(prompt.contains("<rules>"));
+        assert!(prompt.contains("<examples>"));
+        assert!(prompt.contains("filesystem.list_dir"));
+    }
+
+    #[test]
+    fn test_build_system_prompt_with_two_pass() {
+        use crate::mcp_client::registry::ToolRegistry;
+
+        let registry = ToolRegistry::new();
+        let prompt = build_system_prompt(&registry, true);
+        assert!(prompt.contains("category-level tools"));
+        assert!(prompt.contains("file_browse"));
+    }
+
+    #[test]
+    fn test_build_system_prompt_has_precomputed_dates() {
+        use crate::mcp_client::registry::ToolRegistry;
+
+        let registry = ToolRegistry::new();
+        let prompt = build_system_prompt(&registry, false);
+        // Must contain the <dates> block with pre-computed dates
+        assert!(prompt.contains("<dates>"));
+        assert!(prompt.contains("</dates>"));
+        assert!(prompt.contains("today ="));
+        assert!(prompt.contains("tomorrow ="));
+        assert!(prompt.contains("this_week ="));
+        assert!(prompt.contains("NEVER ask the user for a date"));
+    }
+
+    // ── has_unverified_completion tests ──────────────────────────────
+
+    #[test]
+    fn test_unverified_completion_claims_done_no_mutable_calls() {
+        // Model says "all files renamed" but history has no move_file
+        let history = vec![
+            "filesystem.list_dir".to_string(),
+            "ocr.extract_text_from_image".to_string(),
+            "ocr.extract_text_from_image".to_string(),
+        ];
+        assert!(has_unverified_completion(
+            "I've successfully renamed all 9 files.",
+            &history,
+        ));
+    }
+
+    #[test]
+    fn test_unverified_completion_claims_done_with_mutable_calls() {
+        // Model says "all files renamed" AND move_file is in history — genuine
+        let history = vec![
+            "filesystem.list_dir".to_string(),
+            "ocr.extract_text_from_image".to_string(),
+            "filesystem.move_file".to_string(),
+        ];
+        assert!(!has_unverified_completion(
+            "I've successfully renamed all 9 files.",
+            &history,
+        ));
+    }
+
+    #[test]
+    fn test_unverified_completion_no_completion_claim() {
+        // Model doesn't claim completion — no confabulation check needed
+        let history = vec!["filesystem.list_dir".to_string()];
+
assert!(!has_unverified_completion( + "Here are the files I found on your desktop.", + &history, + )); + } + + #[test] + fn test_unverified_completion_empty_history() { + // Empty tool history + completion claim = confabulation + assert!(has_unverified_completion( + "All done! Finished processing everything.", + &[], + )); + } + + #[test] + fn test_unverified_completion_write_file_counts_as_mutable() { + // write_file is a mutable operation — should count + let history = vec!["filesystem.write_file".to_string()]; + assert!(!has_unverified_completion( + "Task complete. All files processed.", + &history, + )); + } + + #[test] + fn test_unverified_completion_create_task_counts_as_mutable() { + // create_task should now be recognized as mutable + let history = vec![ + "filesystem.read_file".to_string(), + "task.create_task".to_string(), + ]; + assert!(!has_unverified_completion( + "Successfully created the task.", + &history, + )); + } + + #[test] + fn test_unverified_completion_read_only_generic_done() { + // Read-only task (list files) saying "all done" — NOT confabulation. + // The model legitimately completed a read-only request. + let history = vec![ + "filesystem.list_dir".to_string(), + ]; + assert!(!has_unverified_completion( + "All done! Here are the files in your Downloads folder.", + &history, + )); + } + + #[test] + fn test_unverified_completion_read_only_claims_rename() { + // Read-only tools but claims "renamed" → confabulation + let history = vec![ + "filesystem.list_dir".to_string(), + "ocr.extract_text_from_image".to_string(), + ]; + assert!(has_unverified_completion( + "I've successfully renamed all 9 files.", + &history, + )); + } + + #[test] + fn test_unverified_completion_scan_then_complete() { + // Security scan (read-only) followed by "completed" → not confabulation + // (it's a genuinely complete read-only scan task) + let history = vec![ + "security.scan_for_pii".to_string(), + "security.scan_for_secrets".to_string(), + ]; + assert!(!has_unverified_completion( + "All done! 
Here's what I found in the scan.", + &history, + )); + } + + // ── consecutive_duplicate_count tests ──────────────────────────── + + #[test] + fn test_duplicate_count_empty() { + let history: Vec<(String, String)> = vec![]; + assert_eq!(consecutive_duplicate_count(&history), 0); + } + + #[test] + fn test_duplicate_count_single() { + let history = vec![("list_dir".into(), r#"{"path":"~/Downloads"}"#.into())]; + assert_eq!(consecutive_duplicate_count(&history), 1); + } + + #[test] + fn test_duplicate_count_three_identical() { + let history = vec![ + ("list_dir".into(), r#"{"path":"~/Downloads"}"#.into()), + ("list_dir".into(), r#"{"path":"~/Downloads"}"#.into()), + ("list_dir".into(), r#"{"path":"~/Downloads"}"#.into()), + ]; + assert_eq!(consecutive_duplicate_count(&history), 3); + } + + #[test] + fn test_duplicate_count_different_args() { + let history = vec![ + ("list_dir".into(), r#"{"path":"~/Downloads"}"#.into()), + ("list_dir".into(), r#"{"path":"~/Documents"}"#.into()), + ]; + assert_eq!(consecutive_duplicate_count(&history), 1); + } + + #[test] + fn test_duplicate_count_interrupted_by_different_tool() { + let history = vec![ + ("list_dir".into(), r#"{"path":"~/Downloads"}"#.into()), + ("read_file".into(), r#"{"path":"file.txt"}"#.into()), + ("list_dir".into(), r#"{"path":"~/Downloads"}"#.into()), + ]; + // Only the last consecutive run counts (just 1) + assert_eq!(consecutive_duplicate_count(&history), 1); + } + + // ── expand_tilde_in_arguments tests ────────────────────────────── + + #[test] + fn test_expand_tilde_simple_path() { + let args = serde_json::json!({"path": "~/Documents/file.txt"}); + let expanded = expand_tilde_in_arguments(&args); + let path = expanded["path"].as_str().unwrap(); + assert!(!path.starts_with('~'), "tilde should be expanded: {path}"); + assert!(path.ends_with("/Documents/file.txt")); + } + + #[test] + fn test_expand_tilde_bare() { + let args = serde_json::json!({"path": "~"}); + let expanded = expand_tilde_in_arguments(&args); + let path = expanded["path"].as_str().unwrap(); + assert!(!path.starts_with('~')); + assert!(!path.is_empty()); + } + + #[test] + fn test_expand_tilde_leaves_absolute_paths() { + let args = serde_json::json!({"path": "/Users/chintan/Documents/file.txt"}); + let expanded = expand_tilde_in_arguments(&args); + assert_eq!( + expanded["path"].as_str().unwrap(), + "/Users/chintan/Documents/file.txt" + ); + } + + #[test] + fn test_expand_tilde_leaves_other_user() { + // ~other_user/... 
should NOT be expanded
+        let args = serde_json::json!({"path": "~other_user/file.txt"});
+        let expanded = expand_tilde_in_arguments(&args);
+        assert_eq!(expanded["path"].as_str().unwrap(), "~other_user/file.txt");
+    }
+
+    #[test]
+    fn test_expand_tilde_nested_object() {
+        let args = serde_json::json!({
+            "source": "~/Desktop/a.png",
+            "destination": "/tmp/b.png",
+            "options": {"backup": "~/backup/"}
+        });
+        let expanded = expand_tilde_in_arguments(&args);
+        assert!(!expanded["source"].as_str().unwrap().starts_with('~'));
+        assert_eq!(expanded["destination"].as_str().unwrap(), "/tmp/b.png");
+        assert!(!expanded["options"]["backup"].as_str().unwrap().starts_with('~'));
+    }
+
+    #[test]
+    fn test_expand_tilde_non_string_values() {
+        let args = serde_json::json!({"count": 42, "flag": true, "path": "~/file"});
+        let expanded = expand_tilde_in_arguments(&args);
+        assert_eq!(expanded["count"], 42);
+        assert_eq!(expanded["flag"], true);
+        assert!(!expanded["path"].as_str().unwrap().starts_with('~'));
+    }
+
+    #[test]
+    fn test_expand_tilde_array_values() {
+        let args = serde_json::json!({"paths": ["~/a.txt", "/b.txt", "~/c.txt"]});
+        let expanded = expand_tilde_in_arguments(&args);
+        let paths = expanded["paths"].as_array().unwrap();
+        assert!(!paths[0].as_str().unwrap().starts_with('~'));
+        assert_eq!(paths[1].as_str().unwrap(), "/b.txt");
+        assert!(!paths[2].as_str().unwrap().starts_with('~'));
+    }
+
+    // ── fix_path_string: cross-platform path correction tests ───────
+
+    /// Helper: build the expected path using Path::join (platform-correct).
+    fn expected_home_join(suffix: &str) -> String {
+        dirs::home_dir()
+            .unwrap()
+            .join(suffix)
+            .to_string_lossy()
+            .into_owned()
+    }
+
+    #[cfg(target_os = "macos")]
+    #[test]
+    fn test_fix_foreign_os_prefix() {
+        // On macOS, /home/ is foreign — any username is hallucinated
+        let args = serde_json::json!({"path": "/home/chintan/Downloads"});
+        let expanded = expand_tilde_in_arguments(&args);
+        let path = expanded["path"].as_str().unwrap();
+        assert_eq!(path, expected_home_join("Downloads"));
+    }
+
+    #[cfg(target_os = "macos")]
+    #[test]
+    fn test_native_prefix_real_username_not_rewritten() {
+        // On macOS, /Users/<real_username>/... should NOT be rewritten
+        // (could be a legitimate multi-user path)
+        let args = serde_json::json!({"path": "/Users/admin/shared/notes.txt"});
+        let expanded = expand_tilde_in_arguments(&args);
+        let path = expanded["path"].as_str().unwrap();
+        assert_eq!(path, "/Users/admin/shared/notes.txt", "Real username should not be rewritten");
+    }
+
+    #[cfg(target_os = "macos")]
+    #[test]
+    fn test_native_prefix_template_user() {
+        // /Users/{user}/Downloads on macOS — template on native prefix
+        let args = serde_json::json!({"path": "/Users/{user}/Downloads"});
+        let expanded = expand_tilde_in_arguments(&args);
+        let path = expanded["path"].as_str().unwrap();
+        assert!(
+            !path.contains("{user}"),
+            "Placeholder should be replaced: {path}"
+        );
+        assert_eq!(path, expected_home_join("Downloads"));
+    }
+
+    #[cfg(target_os = "macos")]
+    #[test]
+    fn test_native_prefix_template_username() {
+        // /Users/{username}/Documents on macOS
+        let args = serde_json::json!({"path": "/Users/{username}/Documents"});
+        let expanded = expand_tilde_in_arguments(&args);
+        let path = expanded["path"].as_str().unwrap();
+        assert!(
+            !path.contains("{username}"),
+            "Placeholder should be replaced: {path}"
+        );
+        assert_eq!(path, expected_home_join("Documents"));
+    }
+
+    #[cfg(target_os = "macos")]
+    #[test]
+    fn test_native_prefix_angle_bracket() {
+        // /Users/<username>/Downloads on macOS
+        let args = serde_json::json!({"path": "/Users/<username>/Downloads"});
+        let expanded = expand_tilde_in_arguments(&args);
+        let path = expanded["path"].as_str().unwrap();
+        assert!(
+            !path.contains("<username>"),
+            "Angle-bracket placeholder should be replaced: {path}"
+        );
+        assert_eq!(path, expected_home_join("Downloads"));
+    }
+
+    #[cfg(target_os = "macos")]
+    #[test]
+    fn test_native_prefix_square_bracket() {
+        // /Users/[USER]/Documents/Projects on macOS
+        let args = serde_json::json!({"path": "/Users/[USER]/Documents/Projects"});
+        let expanded = expand_tilde_in_arguments(&args);
+        let path = expanded["path"].as_str().unwrap();
+        assert!(
+            !path.contains("[USER]"),
+            "Square-bracket placeholder should be replaced: {path}"
+        );
+        assert_eq!(path, expected_home_join("Documents/Projects"));
+    }
+
+    #[cfg(target_os = "macos")]
+    #[test]
+    fn test_native_prefix_known_placeholder_word() {
+        // /Users/user/Documents on macOS — "user" is a known placeholder
+        let args = serde_json::json!({"path": "/Users/user/Documents"});
+        let expanded = expand_tilde_in_arguments(&args);
+        let path = expanded["path"].as_str().unwrap();
+        assert_eq!(path, expected_home_join("Documents"));
+    }
+
+    #[test]
+    fn test_fix_bare_relative_path() {
+        // Model generates just "Projects" instead of an absolute path
+        let args = serde_json::json!({"path": "Projects"});
+        let expanded = expand_tilde_in_arguments(&args);
+        let path = expanded["path"].as_str().unwrap();
+        assert_eq!(path, expected_home_join("Projects"));
+    }
+
+    #[test]
+    fn test_fix_bare_downloads_relative_path() {
+        // Model generates "Downloads"
+        let args = serde_json::json!({"path": "Downloads"});
+        let expanded = expand_tilde_in_arguments(&args);
+        let path = expanded["path"].as_str().unwrap();
+        assert_eq!(path, expected_home_join("Downloads"));
+    }
+
+    #[test]
+    fn test_tilde_expansion() {
+        let args = serde_json::json!({"path": "~/Documents/file.txt"});
+        let expanded = expand_tilde_in_arguments(&args);
+        let path = expanded["path"].as_str().unwrap();
+        assert_eq!(path, expected_home_join("Documents/file.txt"));
+    }
+
+    #[test]
+    fn test_no_fix_for_correct_path() {
+        // Already-correct absolute path should not be
modified + let home = dirs::home_dir().unwrap(); + let correct = home.join("Documents").join("test.txt"); + let correct_str = correct.to_string_lossy().into_owned(); + let args = serde_json::json!({"path": correct_str}); + let expanded = expand_tilde_in_arguments(&args); + assert_eq!(expanded["path"].as_str().unwrap(), correct_str); + } + + #[test] + fn test_no_fix_for_urls() { + // URL-like strings should not be modified + let args = serde_json::json!({"url": "https://example.com/Documents/file"}); + let expanded = expand_tilde_in_arguments(&args); + assert_eq!( + expanded["url"].as_str().unwrap(), + "https://example.com/Documents/file" + ); + } + + // ─── Tool Result Compression Tests (PR #59) ───────────────────────────── + + #[test] + fn test_truncate_tool_result_small() { + let result = "short result"; + let truncated = truncate_tool_result(result, "test_tool"); + assert_eq!(truncated, result); + } + + #[test] + fn test_truncate_tool_result_large_no_compression() { + // Large result that doesn't match compression patterns + let result = "x".repeat(10000); + let truncated = truncate_tool_result(&result, "unknown_tool"); + assert!(truncated.contains("truncated")); + assert!(truncated.len() < result.len()); + } + + #[test] + fn test_compress_directory_listing() { + let listing = r#"📁 src/ +📁 tests/ +📄 file1.txt (1 KB) +📄 file2.txt (2 KB) +📁 subdir/ +📄 long_file_name.txt (100 KB)"#; + let compressed = compress_directory_listing(listing); + assert!(compressed.is_some()); + let summary = compressed.unwrap(); + assert!(summary.contains("Total:")); + assert!(summary.contains("Directories:")); + assert!(summary.contains("Files:")); + } + + #[test] + fn test_compress_directory_listing_empty() { + let compressed = compress_directory_listing(""); + assert!(compressed.is_none()); + } + + #[test] + fn test_compress_search_results_with_count() { + let result = "Scanning files...\nFound 42 matches\n\n/path/to/file1.txt: line 10\n/path/to/file2.txt: line 20"; + let compressed = compress_search_results(result); + assert!(compressed.is_some()); + let summary = compressed.unwrap(); + assert!(summary.contains("42")); + assert!(summary.contains("Key findings")); + } + + #[test] + fn test_compress_search_results_no_count() { + // Search result without clear count pattern + let result = "File A\nFile B\nFile C\nFile D\nFile E"; + let compressed = compress_search_results(result); + assert!(compressed.is_some()); + } + + #[test] + fn test_compress_json_result_array() { + let json = r#"[ + {"name": "item1", "size": 100}, + {"name": "item2", "size": 200}, + {"name": "item3", "size": 300} + ]"#; + let compressed = compress_json_result(json); + assert!(compressed.is_some()); + let summary = compressed.unwrap(); + assert!(summary.contains("3 items")); + assert!(summary.contains("item1")); + } + + #[test] + fn test_compress_json_result_object() { + let json = r#"{"name": "test", "content": "hello world", "count": 42}"#; + let compressed = compress_json_result(json); + assert!(compressed.is_some()); + let summary = compressed.unwrap(); + assert!(summary.contains("name:")); + assert!(summary.contains("test")); + } + + #[test] + fn test_compress_json_result_empty_array() { + let json = "[]"; + let compressed = compress_json_result(json); + assert!(compressed.is_some()); + assert!(compressed.unwrap().contains("empty")); + } + + #[test] + fn test_compress_tool_result_skips_small() { + // Small results should not be compressed + let result = "short"; + let compressed = compress_tool_result(result, "list_dir"); + 
assert!(compressed.is_none()); + } + + #[test] + fn test_compress_tool_result_non_compressible_tool() { + // Non-compressible tools should return None even for large results + let result = "x".repeat(10000); + let compressed = compress_tool_result(&result, "write_file"); + assert!(compressed.is_none()); + } + + #[test] + fn test_compress_tool_result_compressible_tool() { + // Compressible tools should attempt compression (need 3000+ chars) + let result = "📁 folder1\n📁 folder2\n📄 file1.txt (1 KB)\n📄 file2.txt (2 KB)\n".repeat(500); + let compressed = compress_tool_result(&result, "list_dir"); + assert!(compressed.is_some()); + } + + #[test] + fn test_builtin_tool_definitions() { + let tools = builtin_tool_definitions(); + assert!(!tools.is_empty()); + + let names: Vec = tools.iter() + .map(|t| t.function.name.clone()) + .collect(); + + assert!(names.contains(&"list_directory".to_string())); + assert!(names.contains(&"read_file".to_string())); + } + + #[test] + fn test_builtin_tool_definitions_have_descriptions() { + let tools = builtin_tool_definitions(); + for tool in &tools { + assert!(!tool.function.description.is_empty()); + assert!(!tool.function.parameters.is_null()); + } + } +} diff --git a/src-tauri/src/commands/settings.rs b/src-tauri/src/commands/settings.rs new file mode 100644 index 0000000..6c8d20f --- /dev/null +++ b/src-tauri/src/commands/settings.rs @@ -0,0 +1,726 @@ +//! Tauri IPC commands for the Settings panel. +//! +//! Reads model configuration from `_models/config.yaml` (the same source +//! of truth used by the inference client at runtime) and provides live +//! MCP server status from the running McpClient. + +use std::path::PathBuf; +use std::sync::atomic::{AtomicBool, Ordering}; + +use serde::{Deserialize, Serialize}; + +static SETTINGS_CHANGED: AtomicBool = AtomicBool::new(false); + +pub fn settings_changed() { + SETTINGS_CHANGED.store(true, Ordering::SeqCst); +} + +pub fn has_settings_changed() -> bool { + SETTINGS_CHANGED.swap(false, Ordering::SeqCst) +} + +/// Unified app settings that persist across restarts. 
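+// NOTE (illustrative sketch, not part of the existing API): both `AppSettings::save`
+// and `SamplingConfig::save` below follow the same write-temp-then-rename pattern so
+// a crash mid-write never leaves a truncated settings file. A shared helper could
+// look like this; the name `write_json_atomically` is hypothetical.
+#[allow(dead_code)]
+fn write_json_atomically<T: Serialize>(path: &std::path::Path, value: &T) -> std::io::Result<()> {
+    let content = serde_json::to_string_pretty(value)
+        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
+    if let Some(parent) = path.parent() {
+        std::fs::create_dir_all(parent)?;
+    }
+    let tmp = path.with_extension("json.tmp");
+    std::fs::write(&tmp, &content)?;
+    // rename is atomic on the same filesystem: readers see old or new, never partial
+    std::fs::rename(&tmp, path)
+}
+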
+/// Unified app settings that persist across restarts.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct AppSettings {
+    /// Currently active model key from _models/config.yaml
+    pub active_model_key: Option<String>,
+    /// Allowed filesystem paths for sandboxed operations
+    pub allowed_paths: Vec<String>,
+    /// UI theme preference
+    pub theme: String,
+    /// Whether to show tool traces
+    pub show_tool_traces: bool,
+    /// Sampling config (integrated from existing system)
+    pub sampling: SamplingConfig,
+}
+
+impl Default for AppSettings {
+    fn default() -> Self {
+        Self {
+            active_model_key: None,
+            allowed_paths: Vec::new(),
+            theme: "system".to_string(),
+            show_tool_traces: true,
+            sampling: SamplingConfig::default(),
+        }
+    }
+}
+
+impl AppSettings {
+    const FILE_NAME: &'static str = "settings.json";
+
+    fn persist_path() -> PathBuf {
+        crate::data_dir().join(Self::FILE_NAME)
+    }
+
+    pub fn load_or_default() -> Self {
+        let path = Self::persist_path();
+        if !path.exists() {
+            return Self::default();
+        }
+        match std::fs::read_to_string(&path) {
+            Ok(content) => match serde_json::from_str::<AppSettings>(&content) {
+                Ok(settings) => {
+                    tracing::info!(path = %path.display(), "loaded app settings");
+                    settings
+                }
+                Err(e) => {
+                    tracing::warn!(error = %e, "failed to parse settings, using defaults");
+                    Self::default()
+                }
+            },
+            Err(e) => {
+                tracing::warn!(error = %e, "failed to read settings, using defaults");
+                Self::default()
+            }
+        }
+    }
+
+    pub fn save(&self) {
+        let path = Self::persist_path();
+        let content = match serde_json::to_string_pretty(self) {
+            Ok(c) => c,
+            Err(e) => {
+                tracing::error!(error = %e, "failed to serialize settings");
+                return;
+            }
+        };
+        if let Some(parent) = path.parent() {
+            let _ = std::fs::create_dir_all(parent);
+        }
+        let tmp_path = path.with_extension("json.tmp");
+        if let Err(e) = std::fs::write(&tmp_path, &content) {
+            tracing::error!(error = %e, "failed to write settings temp file");
+            return;
+        }
+        if let Err(e) = std::fs::rename(&tmp_path, &path) {
+            tracing::error!(error = %e, "failed to rename settings file");
+            return;
+        }
+        settings_changed();
+        tracing::debug!("saved app settings");
+    }
+
+    pub fn export_to_json(&self) -> Result<String, String> {
+        serde_json::to_string_pretty(self).map_err(|e| format!("export failed: {}", e))
+    }
+
+    pub fn import_from_json(json: &str) -> Result<Self, String> {
+        let settings: Self =
+            serde_json::from_str(json).map_err(|e| format!("invalid settings JSON: {}", e))?;
+        settings.sampling.save();
+        settings.save();
+        Ok(settings)
+    }
+}
+
+/// Model configuration exposed to the frontend.
+#[derive(Debug, Clone, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ModelConfigInfo {
+    pub key: String,
+    pub display_name: String,
+    pub runtime: String,
+    pub base_url: String,
+    pub context_window: u32,
+    pub temperature: f64,
+    pub max_tokens: u32,
+    pub estimated_vram_gb: Option<f64>,
+    pub capabilities: Vec<String>,
+    pub tool_call_format: String,
+}
+
+/// Models overview returned to the frontend.
+#[derive(Debug, Clone, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ModelsOverviewInfo {
+    pub active_model: String,
+    pub models: Vec<ModelConfigInfo>,
+    pub fallback_chain: Vec<String>,
+}
+
+/// MCP server status.
+#[derive(Debug, Clone, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct McpServerStatusInfo {
+    pub name: String,
+    pub status: String,
+    pub tool_count: u32,
+    pub tool_names: Vec<String>,
+    pub last_check: String,
+    pub error: Option<String>,
+}
+
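+// Added example (hedged, not in the original test set): these structs reach the
+// React Settings panel as camelCase JSON via `invoke()`; this small test pins the
+// serde rename so a field rename doesn't silently break the frontend.
+#[cfg(test)]
+mod status_shape_tests {
+    use super::*;
+
+    #[test]
+    fn test_mcp_server_status_serializes_camel_case() {
+        let status = McpServerStatusInfo {
+            name: "filesystem".to_string(),
+            status: "initialized".to_string(),
+            tool_count: 2,
+            tool_names: vec!["list_dir".to_string(), "read_file".to_string()],
+            last_check: "2024-01-01T00:00:00Z".to_string(),
+            error: None,
+        };
+        let json = serde_json::to_string(&status).unwrap();
+        assert!(json.contains("\"toolCount\":2"));
+        assert!(json.contains("\"lastCheck\""));
+        assert!(!json.contains("tool_count"), "snake_case must not leak: {json}");
+    }
+}
+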
+/// Get the models configuration overview.
+///
+/// Reads from `_models/config.yaml` using the same config loader
+/// that the inference client uses at runtime.
+#[tauri::command]
+pub fn get_models_config() -> Result<ModelsOverviewInfo, String> {
+    let cwd = std::env::current_dir().unwrap_or_default();
+    let config_path = crate::inference::config::find_config_path(&cwd)
+        .map_err(|e| format!("Config not found: {e}"))?;
+    let config = crate::inference::config::load_models_config(&config_path)
+        .map_err(|e| format!("Config load error: {e}"))?;
+
+    let models: Vec<ModelConfigInfo> = config
+        .models
+        .iter()
+        .map(|(key, m)| ModelConfigInfo {
+            key: key.clone(),
+            display_name: m.display_name.clone(),
+            runtime: m.runtime.clone(),
+            base_url: m.base_url.clone(),
+            context_window: m.context_window,
+            temperature: f64::from(m.temperature),
+            max_tokens: m.max_tokens,
+            estimated_vram_gb: m.estimated_vram_gb.map(f64::from),
+            capabilities: m.capabilities.clone(),
+            tool_call_format: format!("{:?}", m.tool_call_format),
+        })
+        .collect();
+
+    Ok(ModelsOverviewInfo {
+        active_model: config.active_model.clone(),
+        models,
+        fallback_chain: config.fallback_chain.clone(),
+    })
+}
+
+/// Get the status of all MCP servers from the running McpClient.
+///
+/// Queries actual server state — no hardcoded stubs. Returns configured
+/// servers with their running status and tool count.
+#[tauri::command]
+pub async fn get_mcp_servers_status(
+    mcp_state: tauri::State<'_, crate::TokioMutex<crate::mcp_client::McpClient>>,
+) -> Result<Vec<McpServerStatusInfo>, String> {
+    let mcp = mcp_state.lock().await;
+    let now = chrono::Utc::now().to_rfc3339();
+
+    let configured = mcp.configured_servers();
+    let mut statuses: Vec<McpServerStatusInfo> = configured
+        .into_iter()
+        .map(|name| {
+            let is_running = mcp.is_server_running(&name);
+            let tool_count = mcp.registry.tools_for_server(&name) as u32;
+            let tool_names = mcp.registry.tool_names_for_server(&name);
+
+            McpServerStatusInfo {
+                status: if is_running {
+                    "initialized".to_string()
+                } else {
+                    "failed".to_string()
+                },
+                tool_count,
+                tool_names,
+                last_check: now.clone(),
+                error: if is_running {
+                    None
+                } else {
+                    Some("Server not running".to_string())
+                },
+                name,
+            }
+        })
+        .collect();
+
+    statuses.sort_by(|a, b| a.name.cmp(&b.name));
+    Ok(statuses)
+}
+
+// ─── Permission Grant Management ────────────────────────────────────────────
+
+/// A permission grant exposed to the frontend.
+#[derive(Debug, Clone, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PermissionGrantInfo {
+    pub tool_name: String,
+    pub scope: String,
+    pub granted_at: String,
+}
+
+/// List all persistent permission grants.
+///
+/// Reads from the PermissionStore in Tauri state.
+#[tauri::command]
+pub async fn list_permission_grants(
+    perms: tauri::State<'_, crate::TokioMutex<crate::agent_core::PermissionStore>>,
+) -> Result<Vec<PermissionGrantInfo>, String> {
+    let store = perms.lock().await;
+    let grants = store
+        .list_persistent()
+        .into_iter()
+        .map(|g| PermissionGrantInfo {
+            tool_name: g.tool_name.clone(),
+            scope: format!("{:?}", g.scope).to_lowercase(),
+            granted_at: g.granted_at.clone(),
+        })
+        .collect();
+    Ok(grants)
+}
+
+/// Revoke a persistent permission grant by tool name.
+///
+/// Removes the grant from the PermissionStore and persists the change to disk.
+#[tauri::command]
+pub async fn revoke_permission(
+    tool_name: String,
+    perms: tauri::State<'_, crate::TokioMutex<crate::agent_core::PermissionStore>>,
+) -> Result<bool, String> {
+    let mut store = perms.lock().await;
+    let removed = store.revoke(&tool_name);
+    tracing::info!(tool = %tool_name, removed, "revoke_permission");
+    Ok(removed)
+}
+
+// ─── Sampling Configuration ──────────────────────────────────────────────────
+
+/// Runtime sampling hyperparameters exposed to the frontend.
+///
+/// Persisted to `sampling_config.json` in the app data directory.
+/// The agent loop reads these at the start of each `send_message` call
+/// instead of using hardcoded constants.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SamplingConfig {
+    pub tool_temperature: f32,
+    pub tool_top_p: f32,
+    pub conversational_temperature: f32,
+    pub conversational_top_p: f32,
+}
+
+impl Default for SamplingConfig {
+    fn default() -> Self {
+        Self {
+            tool_temperature: 0.1,
+            tool_top_p: 0.2,
+            conversational_temperature: 0.7,
+            conversational_top_p: 0.9,
+        }
+    }
+}
+
+impl SamplingConfig {
+    /// Load from disk or return defaults.
+    pub fn load_or_default() -> Self {
+        let path = Self::persist_path();
+        if !path.exists() {
+            return Self::default();
+        }
+        match std::fs::read_to_string(&path) {
+            Ok(content) => match serde_json::from_str::<SamplingConfig>(&content) {
+                Ok(cfg) => {
+                    tracing::info!(path = %path.display(), "loaded sampling config");
+                    cfg
+                }
+                Err(e) => {
+                    tracing::warn!(error = %e, "failed to parse sampling config, using defaults");
+                    Self::default()
+                }
+            },
+            Err(e) => {
+                tracing::warn!(error = %e, "failed to read sampling config, using defaults");
+                Self::default()
+            }
+        }
+    }
+
+    /// Save to disk (atomic write).
+    pub fn save(&self) {
+        let path = Self::persist_path();
+        let content = match serde_json::to_string_pretty(self) {
+            Ok(c) => c,
+            Err(e) => {
+                tracing::error!(error = %e, "failed to serialize sampling config");
+                return;
+            }
+        };
+        if let Some(parent) = path.parent() {
+            let _ = std::fs::create_dir_all(parent);
+        }
+        let tmp_path = path.with_extension("json.tmp");
+        if let Err(e) = std::fs::write(&tmp_path, &content) {
+            tracing::error!(error = %e, "failed to write sampling config temp file");
+            return;
+        }
+        if let Err(e) = std::fs::rename(&tmp_path, &path) {
+            tracing::error!(error = %e, "failed to rename sampling config file");
+            return;
+        }
+        tracing::debug!("saved sampling config");
+    }
+
+    fn persist_path() -> PathBuf {
+        crate::data_dir().join("sampling_config.json")
+    }
+}
+
+/// Get the current sampling configuration.
+#[tauri::command]
+pub async fn get_sampling_config(
+    state: tauri::State<'_, crate::TokioMutex<SamplingConfig>>,
+) -> Result<SamplingConfig, String> {
+    let cfg = state.lock().await;
+    Ok(cfg.clone())
+}
+
+/// Update the sampling configuration and persist to disk.
+#[tauri::command]
+pub async fn update_sampling_config(
+    config: SamplingConfig,
+    state: tauri::State<'_, crate::TokioMutex<SamplingConfig>>,
+) -> Result<SamplingConfig, String> {
+    let mut cfg = state.lock().await;
+    *cfg = config;
+    cfg.save();
+    tracing::info!(
+        tool_temp = cfg.tool_temperature,
+        tool_top_p = cfg.tool_top_p,
+        conv_temp = cfg.conversational_temperature,
+        conv_top_p = cfg.conversational_top_p,
+        "sampling config updated"
+    );
+    Ok(cfg.clone())
+}
+
+/// Reset the sampling configuration to defaults and persist.
+#[tauri::command]
+pub async fn reset_sampling_config(
+    state: tauri::State<'_, crate::TokioMutex<SamplingConfig>>,
+) -> Result<SamplingConfig, String> {
+    let mut cfg = state.lock().await;
+    *cfg = SamplingConfig::default();
+    cfg.save();
+    tracing::info!("sampling config reset to defaults");
+    Ok(cfg.clone())
+}
+
+// ─── Unified App Settings ────────────────────────────────────────────────────
+
+/// Get the current app settings.
+#[tauri::command]
+pub fn get_app_settings() -> AppSettings {
+    AppSettings::load_or_default()
+}
+
+/// Update app settings and persist to disk.
+#[tauri::command]
+pub fn update_app_settings(settings: AppSettings) -> AppSettings {
+    settings.save();
+    tracing::info!(
+        active_model = ?settings.active_model_key,
+        theme = %settings.theme,
+        allowed_paths = settings.allowed_paths.len(),
+        "app settings updated"
+    );
+    settings
+}
+
+/// Add an allowed path to settings.
+#[tauri::command]
+pub fn add_allowed_path(path: String) -> AppSettings {
+    let mut settings = AppSettings::load_or_default();
+    if !settings.allowed_paths.contains(&path) {
+        settings.allowed_paths.push(path.clone());
+        settings.save();
+        tracing::info!(path = %path, "allowed path added");
+    }
+    settings
+}
+
+/// Remove an allowed path from settings.
+#[tauri::command]
+pub fn remove_allowed_path(path: String) -> AppSettings {
+    let mut settings = AppSettings::load_or_default();
+    settings.allowed_paths.retain(|p| p != &path);
+    settings.save();
+    tracing::info!(path = %path, "allowed path removed");
+    settings
+}
+
+/// Export settings to JSON string.
+#[tauri::command]
+pub fn export_settings() -> Result<String, String> {
+    let settings = AppSettings::load_or_default();
+    settings.export_to_json()
+}
+
+/// Import settings from JSON string.
+#[tauri::command]
+pub fn import_settings(json: String) -> Result<AppSettings, String> {
+    AppSettings::import_from_json(&json)
+}
+
+/// Check if settings have changed since last check (for file watching).
+#[tauri::command]
+pub fn poll_settings_changed() -> bool {
+    has_settings_changed()
+}
+
+// ─── Config Hot Reload ───────────────────────────────────────────────────────
+
+use std::sync::atomic::AtomicU64;
+use std::time::SystemTime;
+
+static CONFIG_LAST_MODIFIED: AtomicU64 = AtomicU64::new(0);
+
+/// Check if config file has been modified since last check.
+#[tauri::command]
+pub fn check_config_reload() -> Result<bool, String> {
+    let cwd = std::env::current_dir().unwrap_or_default();
+    let config_path = crate::inference::config::find_config_path(&cwd)
+        .map_err(|e| format!("Config not found: {e}"))?;
+
+    let metadata = std::fs::metadata(&config_path)
+        .map_err(|e| format!("Failed to read config metadata: {}", e))?;
+
+    let modified = metadata.modified()
+        .map_err(|e| format!("Failed to get modification time: {}", e))?;
+
+    let modified_secs = modified
+        .duration_since(SystemTime::UNIX_EPOCH)
+        .map_err(|e| format!("Time error: {}", e))?
+        .as_secs();
+
+    let last_modified = CONFIG_LAST_MODIFIED.load(Ordering::SeqCst);
+
+    if modified_secs > last_modified {
+        CONFIG_LAST_MODIFIED.store(modified_secs, Ordering::SeqCst);
+        Ok(true)
+    } else {
+        Ok(false)
+    }
+}
+
+/// Force reload the model config (for manual refresh).
+#[tauri::command] +pub fn reload_model_config() -> Result { + get_models_config() +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::atomic::Ordering; + + #[test] + fn test_sampling_config_default() { + let cfg = SamplingConfig::default(); + assert_eq!(cfg.tool_temperature, 0.1); + assert_eq!(cfg.tool_top_p, 0.2); + assert_eq!(cfg.conversational_temperature, 0.7); + assert_eq!(cfg.conversational_top_p, 0.9); + } + + #[test] + fn test_sampling_config_serialization() { + let cfg = SamplingConfig { + tool_temperature: 0.5, + tool_top_p: 0.3, + conversational_temperature: 0.8, + conversational_top_p: 0.95, + }; + let json = serde_json::to_string(&cfg).unwrap(); + assert!(json.contains("0.5")); + assert!(json.contains("0.3")); + assert!(json.contains("0.8")); + assert!(json.contains("0.95")); + } + + #[test] + fn test_sampling_config_deserialization() { + let json = r#"{ + "toolTemperature": 0.3, + "toolTopP": 0.4, + "conversationalTemperature": 0.6, + "conversationalTopP": 0.8 + }"#; + let cfg: SamplingConfig = serde_json::from_str(json).unwrap(); + assert_eq!(cfg.tool_temperature, 0.3); + assert_eq!(cfg.tool_top_p, 0.4); + assert_eq!(cfg.conversational_temperature, 0.6); + assert_eq!(cfg.conversational_top_p, 0.8); + } + + #[test] + fn test_app_settings_default() { + let settings = AppSettings::default(); + assert_eq!(settings.active_model_key, None); + assert!(settings.allowed_paths.is_empty()); + assert_eq!(settings.theme, "system"); + assert!(settings.show_tool_traces); + // Sampling should be default + assert_eq!(settings.sampling.tool_temperature, 0.1); + } + + #[test] + fn test_app_settings_serialization() { + let mut settings = AppSettings::default(); + settings.active_model_key = Some("test-model".to_string()); + settings.allowed_paths = vec!["/home/user/docs".to_string()]; + settings.theme = "dark".to_string(); + settings.show_tool_traces = false; + + let json = serde_json::to_string(&settings).unwrap(); + assert!(json.contains("test-model")); + assert!(json.contains("dark")); + assert!(json.contains("docs")); + } + + #[test] + fn test_app_settings_deserialization() { + let json = r#"{ + "activeModelKey": "lm-studio-model", + "allowedPaths": ["/tmp", "/var"], + "theme": "light", + "showToolTraces": false, + "sampling": { + "toolTemperature": 0.2, + "toolTopP": 0.3, + "conversationalTemperature": 0.8, + "conversationalTopP": 0.9 + } + }"#; + let settings: AppSettings = serde_json::from_str(json).unwrap(); + assert_eq!(settings.active_model_key, Some("lm-studio-model".to_string())); + assert_eq!(settings.allowed_paths.len(), 2); + assert_eq!(settings.theme, "light"); + assert!(!settings.show_tool_traces); + } + + #[test] + fn test_config_last_modified_atomic() { + // Test that CONFIG_LAST_MODIFIED is properly initialized + let initial = CONFIG_LAST_MODIFIED.load(Ordering::SeqCst); + assert_eq!(initial, 0); + + // Store a value and verify + CONFIG_LAST_MODIFIED.store(12345, Ordering::SeqCst); + let after = CONFIG_LAST_MODIFIED.load(Ordering::SeqCst); + assert_eq!(after, 12345); + + // Reset + CONFIG_LAST_MODIFIED.store(0, Ordering::SeqCst); + } + + #[test] + fn test_settings_changed_atomic() { + // Test the SETTINGS_CHANGED flag + settings_changed(); + assert!(has_settings_changed()); + assert!(!has_settings_changed()); // Should clear after check + + // Setting it again should work + settings_changed(); + assert!(has_settings_changed()); + } + + #[test] + fn test_model_config_info_fields() { + let info = ModelConfigInfo { + key: "test-key".to_string(), + display_name: 
"Test Model".to_string(), + runtime: "lm-studio".to_string(), + base_url: "http://localhost:1234/v1".to_string(), + context_window: 32768, + temperature: 0.7, + max_tokens: 4096, + estimated_vram_gb: Some(24.0), + capabilities: vec!["chat".to_string(), "tools".to_string()], + tool_call_format: "json".to_string(), + }; + + assert_eq!(info.key, "test-key"); + assert_eq!(info.runtime, "lm-studio"); + assert_eq!(info.context_window, 32768); + } + + #[test] + fn test_models_overview_info_serialization() { + let overview = ModelsOverviewInfo { + active_model: "qwen2.5".to_string(), + models: vec![ + ModelConfigInfo { + key: "qwen2.5".to_string(), + display_name: "Qwen 2.5".to_string(), + runtime: "ollama".to_string(), + base_url: "http://localhost:11434/v1".to_string(), + context_window: 32768, + temperature: 0.7, + max_tokens: 4096, + estimated_vram_gb: Some(20.0), + capabilities: vec!["chat".to_string()], + tool_call_format: "json".to_string(), + } + ], + fallback_chain: vec!["gpt-oss".to_string()], + }; + + let json = serde_json::to_string(&overview).unwrap(); + assert!(json.contains("qwen2.5")); + assert!(json.contains("ollama")); + } + + #[test] + fn test_mcp_server_status_info() { + let status = McpServerStatusInfo { + name: "filesystem".to_string(), + status: "initialized".to_string(), + tool_count: 10, + tool_names: vec!["list_dir".to_string(), "read_file".to_string()], + last_check: "2024-01-01T00:00:00Z".to_string(), + error: None, + }; + + assert_eq!(status.name, "filesystem"); + assert_eq!(status.status, "initialized"); + assert_eq!(status.tool_count, 10); + + // Test with error + let status_with_error = McpServerStatusInfo { + error: Some("Connection refused".to_string()), + ..status + }; + assert!(status_with_error.error.is_some()); + } + + #[test] + fn test_permission_grant_info() { + let grant = PermissionGrantInfo { + tool_name: "filesystem.write_file".to_string(), + scope: "session".to_string(), + granted_at: "2024-01-01T12:00:00Z".to_string(), + }; + + assert_eq!(grant.tool_name, "filesystem.write_file"); + assert_eq!(grant.scope, "session"); + } + + #[test] + fn test_app_settings_export_import_roundtrip() { + let original = AppSettings { + active_model_key: Some("test-model".to_string()), + allowed_paths: vec!["/home/user".to_string()], + theme: "dark".to_string(), + show_tool_traces: true, + sampling: SamplingConfig { + tool_temperature: 0.15, + tool_top_p: 0.25, + conversational_temperature: 0.75, + conversational_top_p: 0.85, + }, + }; + + let json = original.export_to_json().unwrap(); + let imported = AppSettings::import_from_json(&json).unwrap(); + + assert_eq!(imported.active_model_key, original.active_model_key); + assert_eq!(imported.allowed_paths, original.allowed_paths); + assert_eq!(imported.theme, original.theme); + assert_eq!(imported.sampling.tool_temperature, original.sampling.tool_temperature); + } +} diff --git a/src-tauri/src/inference/client.rs b/src-tauri/src/inference/client.rs new file mode 100644 index 0000000..b1ccce6 --- /dev/null +++ b/src-tauri/src/inference/client.rs @@ -0,0 +1,862 @@ +//! OpenAI-compatible inference client. +//! +//! Sends chat completion requests to a local LLM endpoint and streams back +//! tokens and tool calls. Handles the fallback chain when the primary model +//! is unavailable. 
+
+use std::time::Duration;
+
+use futures::future::Either;
+use futures::Stream;
+use reqwest::Client as HttpClient;
+use uuid::Uuid;
+
+use super::config::{ModelConfig, ModelsConfig, ToolCallFormat};
+use super::errors::InferenceError;
+use super::streaming::{parse_non_streaming_response, parse_sse_stream};
+use super::tool_call_parser::{extract_tool_call_from_error, repair_malformed_tool_call_json};
+use super::types::{
+    ChatCompletionRequest, ChatMessage, SamplingOverrides, StreamChunk, ToolCall, ToolDefinition,
+};
+
+// ─── Constants ───────────────────────────────────────────────────────────────
+
+/// TCP connection timeout.
+const CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
+
+/// Total request timeout for non-streaming calls.
+const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
+
+/// Total request timeout for streaming calls.
+///
+/// Streaming responses from local models can take a long time, especially
+/// when the context window is large (18+ messages). The model needs time
+/// to process the full context before emitting the first token. A 30s
+/// timeout causes silent stream termination that looks like "empty response"
+/// to the agent loop.
+const STREAM_REQUEST_TIMEOUT: Duration = Duration::from_secs(180);
+
+// ─── InferenceClient ─────────────────────────────────────────────────────────
+
+/// Client for the local LLM inference endpoint.
+///
+/// Created from `ModelsConfig` and holds the current model configuration.
+/// Provides streaming and non-streaming chat completion methods.
+pub struct InferenceClient {
+    /// HTTP client for non-streaming requests (30s timeout).
+    http: HttpClient,
+    /// HTTP client for streaming requests (180s timeout).
+    http_stream: HttpClient,
+    /// The full models configuration (for fallback chain).
+    config: ModelsConfig,
+    /// The current model key (e.g., "qwen25-32b").
+    current_model_key: String,
+    /// The current model configuration.
+    current_model: ModelConfig,
+    /// Models that have already been tried and failed.
+    exhausted_models: Vec<String>,
+}
+
+impl InferenceClient {
+    /// Create a new inference client from the models configuration.
+    ///
+    /// Resolves the active model from config. Does NOT check connectivity —
+    /// that happens on the first request.
+    pub fn from_config(config: ModelsConfig) -> Result<Self, InferenceError> {
+        let (key, model) = super::config::resolve_active_model(&config)?;
+
+        let http = HttpClient::builder()
+            .connect_timeout(CONNECT_TIMEOUT)
+            .timeout(REQUEST_TIMEOUT)
+            .build()
+            .map_err(|e| InferenceError::ConnectionFailed {
+                endpoint: model.base_url.clone(),
+                reason: format!("failed to build HTTP client: {e}"),
+            })?;
+
+        let http_stream = HttpClient::builder()
+            .connect_timeout(CONNECT_TIMEOUT)
+            .timeout(STREAM_REQUEST_TIMEOUT)
+            .build()
+            .map_err(|e| InferenceError::ConnectionFailed {
+                endpoint: model.base_url.clone(),
+                reason: format!("failed to build streaming HTTP client: {e}"),
+            })?;
+
+        Ok(Self {
+            http,
+            http_stream,
+            config,
+            current_model_key: key,
+            current_model: model,
+            exhausted_models: Vec::new(),
+        })
+    }
+
+    /// Create an inference client targeting a specific model by key.
+    ///
+    /// Unlike [`from_config`] which resolves the active model + fallback chain,
+    /// this constructor pins the client to a specific model. Used by the
+    /// orchestrator (ADR-009) to create separate planner and router clients.
+    pub fn from_config_with_model(
+        config: ModelsConfig,
+        model_key: &str,
+    ) -> Result<Self, InferenceError> {
+        let model = config
+            .models
+            .get(model_key)
+            .ok_or_else(|| InferenceError::ConfigError {
+                reason: format!("model '{model_key}' not found in config"),
+            })?
+            .clone();
+
+        let http = HttpClient::builder()
+            .connect_timeout(CONNECT_TIMEOUT)
+            .timeout(REQUEST_TIMEOUT)
+            .build()
+            .map_err(|e| InferenceError::ConnectionFailed {
+                endpoint: model.base_url.clone(),
+                reason: format!("failed to build HTTP client: {e}"),
+            })?;
+
+        let http_stream = HttpClient::builder()
+            .connect_timeout(CONNECT_TIMEOUT)
+            .timeout(STREAM_REQUEST_TIMEOUT)
+            .build()
+            .map_err(|e| InferenceError::ConnectionFailed {
+                endpoint: model.base_url.clone(),
+                reason: format!("failed to build streaming HTTP client: {e}"),
+            })?;
+
+        Ok(Self {
+            http,
+            http_stream,
+            config,
+            current_model_key: model_key.to_string(),
+            current_model: model,
+            exhausted_models: Vec::new(),
+        })
+    }
+
+    /// The base URL of the current model's endpoint.
+    pub fn current_base_url(&self) -> &str {
+        &self.current_model.base_url
+    }
+
+    /// The name of the currently selected model.
+    pub fn current_model_name(&self) -> &str {
+        &self.current_model.display_name
+    }
+
+    /// The tool call format of the current model.
+    pub fn tool_call_format(&self) -> ToolCallFormat {
+        self.current_model.tool_call_format
+    }
+
+    /// The context window size of the current model.
+    pub fn context_window(&self) -> u32 {
+        self.current_model.context_window
+    }
+
+    // ─── Chat Completion (streaming) ─────────────────────────────────────
+
+    /// Send a streaming chat completion request.
+    ///
+    /// Returns a `Stream` of `StreamChunk`s. Each chunk contains either a
+    /// text token, tool calls, or both.
+    ///
+    /// If the current model is unavailable, automatically tries the fallback
+    /// chain before returning an error. When Ollama returns HTTP 500 due to
+    /// malformed JSON in tool call arguments, attempts client-side repair
+    /// before triggering the fallback chain.
+    pub async fn chat_completion_stream(
+        &mut self,
+        messages: Vec<ChatMessage>,
+        tools: Option<Vec<ToolDefinition>>,
+        sampling: Option<SamplingOverrides>,
+    ) -> Result<
+        impl Stream<Item = Result<StreamChunk, InferenceError>>,
+        InferenceError,
+    > {
+        let mut last_error: Option<InferenceError> = None;
+
+        for _attempt in 0..=self.remaining_fallbacks() {
+            match self.try_stream_request(&messages, &tools, sampling.as_ref()).await {
+                Ok(stream) => return Ok(Either::Left(stream)),
+                Err(e) if e.is_tool_call_parse_error() => {
+                    // Ollama returned HTTP 500 because the model generated
+                    // malformed JSON in tool call arguments. Try to repair the
+                    // JSON client-side before falling back to the next model.
+                    if let Some(repaired) = Self::try_repair_from_error(&e) {
+                        tracing::info!(
+                            tool = %repaired.tool_calls.as_ref()
+                                .and_then(|tc| tc.first())
+                                .map(|tc| tc.name.as_str())
+                                .unwrap_or("unknown"),
+                            "repaired malformed JSON tool call"
+                        );
+                        return Ok(Either::Right(futures::stream::once(
+                            async { Ok(repaired) },
+                        )));
+                    }
+                    // Repair failed — continue to fallback chain
+                    tracing::warn!("tool call JSON repair failed, falling back");
+                    last_error = Some(e);
+                    if self.try_next_fallback().is_err() {
+                        break;
+                    }
+                }
+                Err(e) if Self::is_retriable(&e) => {
+                    last_error = Some(e);
+                    if self.try_next_fallback().is_err() {
+                        break; // No more fallbacks
+                    }
+                }
+                Err(e) => return Err(e), // Non-retriable error
+            }
+        }
+
+        Err(last_error.unwrap_or(InferenceError::AllModelsUnavailable {
+            attempted: self.exhausted_models.clone(),
+        }))
+    }
+
+    /// Attempt a single streaming request to the current model.
+    async fn try_stream_request(
+        &self,
+        messages: &[ChatMessage],
+        tools: &Option<Vec<ToolDefinition>>,
+        sampling: Option<&SamplingOverrides>,
+    ) -> Result<impl Stream<Item = Result<StreamChunk, InferenceError>>, InferenceError> {
+        let url = format!("{}/chat/completions", self.current_model.base_url);
+        let model_name = self
+            .current_model
+            .model_name
+            .clone()
+            .unwrap_or_else(|| self.current_model_key.clone());
+
+        let temperature = sampling
+            .and_then(|s| s.temperature)
+            .unwrap_or(self.current_model.temperature);
+        let top_p = sampling.and_then(|s| s.top_p);
+
+        // Enable JSON response format when the model config opts in AND
+        // tools are present. This sends `response_format: {"type":"json_object"}`
+        // which triggers Ollama's GBNF grammar enforcement for valid JSON output.
+        let response_format = if self.current_model.force_json_response && tools.is_some() {
+            Some(super::types::ResponseFormat {
+                r#type: "json_object".to_string(),
+            })
+        } else {
+            None
+        };
+
+        let body = ChatCompletionRequest {
+            model: model_name,
+            messages: messages.to_vec(),
+            tools: tools.clone(),
+            tool_choice: tools.as_ref().map(|_| "auto".to_string()),
+            temperature,
+            top_p,
+            max_tokens: self.current_model.max_tokens,
+            stream: true,
+            response_format,
+        };
+
+        // Log the request metadata (not the full body — it can be huge)
+        tracing::info!(
+            url = %url,
+            model = %body.model,
+            message_count = body.messages.len(),
+            has_tools = body.tools.is_some(),
+            tool_count = body.tools.as_ref().map(|t| t.len()).unwrap_or(0),
+            max_tokens = body.max_tokens,
+            stream = body.stream,
+            "=== LLM REQUEST ==="
+        );
+
+        let response = self
+            .http_stream
+            .post(&url)
+            .json(&body)
+            .header("Accept", "text/event-stream")
+            .send()
+            .await
+            .map_err(|e| {
+                if e.is_connect() {
+                    InferenceError::ConnectionFailed {
+                        endpoint: url.clone(),
+                        reason: e.to_string(),
+                    }
+                } else if e.is_timeout() {
+                    InferenceError::Timeout { duration_secs: 5 }
+                } else {
+                    InferenceError::ConnectionFailed {
+                        endpoint: url.clone(),
+                        reason: e.to_string(),
+                    }
+                }
+            })?;
+
+        let status = response.status();
+        if !status.is_success() {
+            let body_text = response.text().await.unwrap_or_default();
+            return Err(InferenceError::HttpError {
+                status: status.as_u16(),
+                body: body_text,
+            });
+        }
+
+        Ok(parse_sse_stream(response, self.current_model.tool_call_format))
+    }
+
+    // ─── Tool Call Repair ────────────────────────────────────────────────
+
+    /// Attempt to repair a malformed tool call from an Ollama HTTP 500 error.
+    ///
+    /// Extracts the raw JSON from the error body, applies repair heuristics,
+    /// and builds a synthetic `StreamChunk` with the repaired tool call.
+    /// Returns `None` if the error body doesn't match or repair fails.
+    fn try_repair_from_error(err: &InferenceError) -> Option<StreamChunk> {
+        let body = err.error_body()?;
+        let (tool_name, raw_args) = extract_tool_call_from_error(body)?;
+        let repaired_args = repair_malformed_tool_call_json(&raw_args)?;
+
+        // Build a synthetic tool call. The extracted tool name may be empty
+        // when Ollama's error body doesn't include it; in that case the agent
+        // loop resolves the name from the conversation context (the model
+        // declared intent before Ollama attempted to parse the arguments).
+        let tool_call = ToolCall {
+            id: format!("call_{}", Uuid::new_v4()),
+            name: tool_name,
+            arguments: repaired_args,
+        };
+
+        Some(StreamChunk {
+            token: None,
+            tool_calls: Some(vec![tool_call]),
+            finish_reason: Some("tool_calls".to_string()),
+        })
+    }
+
+    // ─── Chat Completion (non-streaming) ─────────────────────────────────
+
+    /// Send a non-streaming chat completion request.
+    ///
+    /// Returns a single `StreamChunk` with the complete response.
+    pub async fn chat_completion(
+        &mut self,
+        messages: Vec<ChatMessage>,
+        tools: Option<Vec<ToolDefinition>>,
+        sampling: Option<SamplingOverrides>,
+    ) -> Result<StreamChunk, InferenceError> {
+        let url = format!("{}/chat/completions", self.current_model.base_url);
+        let model_name = self
+            .current_model
+            .model_name
+            .clone()
+            .unwrap_or_else(|| self.current_model_key.clone());
+
+        let temperature = sampling
+            .as_ref()
+            .and_then(|s| s.temperature)
+            .unwrap_or(self.current_model.temperature);
+        let top_p = sampling.as_ref().and_then(|s| s.top_p);
+
+        let response_format = if self.current_model.force_json_response && tools.is_some() {
+            Some(super::types::ResponseFormat {
+                r#type: "json_object".to_string(),
+            })
+        } else {
+            None
+        };
+
+        let body = ChatCompletionRequest {
+            model: model_name,
+            messages,
+            tools: tools.clone(),
+            tool_choice: tools.as_ref().map(|_| "auto".to_string()),
+            temperature,
+            top_p,
+            max_tokens: self.current_model.max_tokens,
+            stream: false,
+            response_format,
+        };
+
+        let response = self
+            .http
+            .post(&url)
+            .json(&body)
+            .send()
+            .await
+            .map_err(|e| InferenceError::ConnectionFailed {
+                endpoint: url.clone(),
+                reason: e.to_string(),
+            })?;
+
+        let status = response.status();
+        if !status.is_success() {
+            let body_text = response.text().await.unwrap_or_default();
+            return Err(InferenceError::HttpError {
+                status: status.as_u16(),
+                body: body_text,
+            });
+        }
+
+        let body_text = response.text().await.map_err(|e| InferenceError::StreamError {
+            reason: format!("failed to read response body: {e}"),
+        })?;
+
+        parse_non_streaming_response(&body_text, self.current_model.tool_call_format)
+    }
+
+    // ─── Health Check ────────────────────────────────────────────────────
+
+    /// Check if the current model endpoint is reachable.
+    ///
+    /// Sends a lightweight request to verify connectivity. Does not consume
+    /// inference tokens.
+    pub async fn health_check(&self) -> Result<bool, InferenceError> {
+        let url = format!("{}/models", self.current_model.base_url);
+
+        match self.http.get(&url).timeout(CONNECT_TIMEOUT).send().await {
+            Ok(resp) => Ok(resp.status().is_success()),
+            Err(_) => Ok(false),
+        }
+    }
+
+    /// Get detailed model status including endpoint info.
+ pub async fn get_status(&self) -> super::types::ModelStatus { + let url = format!("{}/models", self.current_model.base_url); + match self.http.get(&url).timeout(CONNECT_TIMEOUT).send().await { + Ok(resp) if resp.status().is_success() => super::types::ModelStatus { + key: self.current_model_key.clone(), + display_name: self.current_model.display_name.clone(), + base_url: self.current_model.base_url.clone(), + healthy: true, + model_name: self.current_model.model_name.clone().or_else(|| Some(self.current_model_key.clone())), + error: None, + }, + Ok(resp) => super::types::ModelStatus { + key: self.current_model_key.clone(), + display_name: self.current_model.display_name.clone(), + base_url: self.current_model.base_url.clone(), + healthy: false, + model_name: None, + error: Some(format!("HTTP {}", resp.status())), + }, + Err(e) => super::types::ModelStatus { + key: self.current_model_key.clone(), + display_name: self.current_model.display_name.clone(), + base_url: self.current_model.base_url.clone(), + healthy: false, + model_name: None, + error: Some(e.to_string()), + }, + } + } + + // ─── Fallback Chain ─────────────────────────────────────────────────────── + + /// Move to the next model in the fallback chain. + /// + /// Returns `Err` if no more fallbacks are available. + pub fn try_next_fallback(&mut self) -> Result<(), InferenceError> { + self.exhausted_models.push(self.current_model_key.clone()); + + for key in &self.config.fallback_chain { + if self.exhausted_models.contains(key) || key == "static_response" { + continue; + } + if let Some(model) = self.config.models.get(key) { + self.current_model_key = key.clone(); + self.current_model = model.clone(); + return Ok(()); + } + } + + Err(InferenceError::AllModelsUnavailable { + attempted: self.exhausted_models.clone(), + }) + } + + /// Number of remaining fallback models. + fn remaining_fallbacks(&self) -> usize { + self.config + .fallback_chain + .iter() + .filter(|k| !self.exhausted_models.contains(k) && k.as_str() != "static_response") + .count() + } + + /// Whether an error should trigger a fallback attempt. + /// + /// HTTP 404 is included because Ollama returns 404 when a model isn't + /// pulled/installed — the next model in the chain may still be available. + /// + /// HTTP 500 is included because local model servers (Ollama, llama.cpp) + /// return 500 when the model generates malformed JSON in tool call + /// arguments — this is a transient model error, not a permanent server + /// failure. Retrying (or falling back) is the correct behavior. + fn is_retriable(err: &InferenceError) -> bool { + matches!( + err, + InferenceError::ConnectionFailed { .. } + | InferenceError::Timeout { .. } + | InferenceError::HttpError { status: 404, .. } + | InferenceError::HttpError { status: 500, .. } + | InferenceError::HttpError { status: 502..=504, .. } + ) + } +} + +// ─── Static Response Fallback ──────────────────────────────────────────────── + +/// Generate the static response used when all models are unavailable. +pub fn static_fallback_response() -> StreamChunk { + StreamChunk { + token: Some( + "The model server is not running. 
\ + Start it with: ./scripts/start-model.sh\n\n\ + If using Ollama instead, run: ollama serve" + .to_string(), + ), + tool_calls: None, + finish_reason: Some("stop".to_string()), + } +} + +// ─── Tests ─────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashMap; + + fn test_config() -> ModelsConfig { + let mut models = HashMap::new(); + models.insert( + "model-a".to_string(), + ModelConfig { + display_name: "Model A".to_string(), + runtime: "ollama".to_string(), + model_name: Some("model-a:latest".to_string()), + model_path: None, + base_url: "http://localhost:11111/v1".to_string(), + context_window: 4096, + tool_call_format: ToolCallFormat::NativeJson, + temperature: 0.7, + max_tokens: 1024, + estimated_vram_gb: None, + capabilities: vec!["text".to_string()], + force_json_response: false, + role: None, + }, + ); + models.insert( + "model-b".to_string(), + ModelConfig { + display_name: "Model B".to_string(), + runtime: "ollama".to_string(), + model_name: Some("model-b:latest".to_string()), + model_path: None, + base_url: "http://localhost:22222/v1".to_string(), + context_window: 8192, + tool_call_format: ToolCallFormat::Pythonic, + temperature: 0.5, + max_tokens: 2048, + estimated_vram_gb: None, + capabilities: vec!["text".to_string()], + force_json_response: false, + role: None, + }, + ); + models.insert( + "lmstudio-model".to_string(), + ModelConfig { + display_name: "LM Studio Model".to_string(), + runtime: "lmstudio".to_string(), + model_name: Some("lmstudio/default".to_string()), + model_path: None, + base_url: "http://localhost:1234/v1".to_string(), + context_window: 32768, + tool_call_format: ToolCallFormat::NativeJson, + temperature: 0.7, + max_tokens: 4096, + estimated_vram_gb: Some(8.0), + capabilities: vec!["text".to_string(), "tool_calling".to_string()], + force_json_response: false, + role: None, + }, + ); + + ModelsConfig { + active_model: "model-a".to_string(), + models_dir: None, + models, + fallback_chain: vec![ + "model-a".to_string(), + "model-b".to_string(), + "static_response".to_string(), + ], + orchestrator: None, + two_pass_tool_selection: None, + enabled_servers: None, + enabled_tools: None, + } + } + + #[test] + fn test_from_config_selects_active_model() { + let client = InferenceClient::from_config(test_config()).unwrap(); + assert_eq!(client.current_model_key, "model-a"); + assert_eq!(client.current_model_name(), "Model A"); + } + + #[test] + fn test_fallback_chain() { + let mut client = InferenceClient::from_config(test_config()).unwrap(); + assert_eq!(client.current_model_key, "model-a"); + + // Fallback to model-b + client.try_next_fallback().unwrap(); + assert_eq!(client.current_model_key, "model-b"); + assert_eq!(client.tool_call_format(), ToolCallFormat::Pythonic); + + // No more fallbacks + let result = client.try_next_fallback(); + assert!(result.is_err()); + } + + #[test] + fn test_lmstudio_model_config() { + let config = test_config(); + // Create client targeting LM Studio model directly + let client = InferenceClient::from_config_with_model(config, "lmstudio-model").unwrap(); + assert_eq!(client.current_model_key, "lmstudio-model"); + assert_eq!(client.current_model_name(), "LM Studio Model"); + assert_eq!(client.current_base_url(), "http://localhost:1234/v1"); + } + + #[test] + fn test_remaining_fallbacks() { + let client = InferenceClient::from_config(test_config()).unwrap(); + // model-a (current, in chain) + model-b = 2 remaining + 
assert_eq!(client.remaining_fallbacks(), 2); + } + + #[test] + fn test_is_retriable() { + assert!(InferenceClient::is_retriable( + &InferenceError::ConnectionFailed { + endpoint: "".into(), + reason: "".into() + } + )); + assert!(InferenceClient::is_retriable(&InferenceError::Timeout { + duration_secs: 5 + })); + assert!(InferenceClient::is_retriable(&InferenceError::HttpError { + status: 404, + body: "model not found".into() + })); + assert!(InferenceClient::is_retriable(&InferenceError::HttpError { + status: 500, + body: "malformed JSON".into() + })); + assert!(InferenceClient::is_retriable(&InferenceError::HttpError { + status: 503, + body: "".into() + })); + assert!(!InferenceClient::is_retriable( + &InferenceError::HttpError { + status: 400, + body: "".into() + } + )); + assert!(!InferenceClient::is_retriable( + &InferenceError::ToolCallParseError { + raw_response: "".into(), + reason: "".into() + } + )); + } + + #[test] + fn test_is_retriable_connection_failed() { + assert!(InferenceClient::is_retriable( + &InferenceError::ConnectionFailed { + endpoint: "localhost".into(), + reason: "connection refused".into() + } + )); + } + + #[test] + fn test_is_retriable_timeout() { + assert!(InferenceClient::is_retriable(&InferenceError::Timeout { duration_secs: 5 })); + } + + #[test] + fn test_is_retriable_404() { + assert!(InferenceClient::is_retriable(&InferenceError::HttpError { + status: 404, + body: "not found".into() + })); + } + + #[test] + fn test_is_retriable_500() { + assert!(InferenceClient::is_retriable(&InferenceError::HttpError { + status: 500, + body: "internal error".into() + })); + } + + #[test] + fn test_is_retriable_502() { + assert!(InferenceClient::is_retriable(&InferenceError::HttpError { + status: 502, + body: "bad gateway".into() + })); + } + + #[test] + fn test_is_retriable_503() { + assert!(InferenceClient::is_retriable(&InferenceError::HttpError { + status: 503, + body: "service unavailable".into() + })); + } + + #[test] + fn test_is_retriable_400_not_retriable() { + // HTTP 400 should NOT be retriable + assert!(!InferenceClient::is_retriable(&InferenceError::HttpError { + status: 400, + body: "bad request".into() + })); + } + + #[test] + fn test_is_retriable_401_not_retriable() { + // HTTP 401 should NOT be retriable + assert!(!InferenceClient::is_retriable(&InferenceError::HttpError { + status: 401, + body: "unauthorized".into() + })); + } + + #[test] + fn test_is_retriable_403_not_retriable() { + // HTTP 403 should NOT be retriable + assert!(!InferenceClient::is_retriable(&InferenceError::HttpError { + status: 403, + body: "forbidden".into() + })); + } + + #[test] + fn test_is_retriable_tool_call_error_not_retriable() { + // Tool call parse error should NOT be retriable (it's a model issue) + assert!(!InferenceClient::is_retriable(&InferenceError::ToolCallParseError { + raw_response: "invalid".into(), + reason: "bad json".into() + })); + } + + #[test] + fn test_static_fallback_response() { + let chunk = static_fallback_response(); + assert!(chunk.token.is_some()); + assert!(chunk.tool_calls.is_none()); + assert_eq!(chunk.finish_reason.as_deref(), Some("stop")); + } + + #[test] + fn test_try_repair_from_error_success() { + // Simulate the exact Ollama HTTP 500 error with malformed JSON + let err = InferenceError::HttpError { + status: 500, + body: r#"{"error":{"message":"error parsing tool call: raw='{\"create_dirs\":true,\"destination\":\"\"/Users/chintan/Desktop/file.png\",\"source\":\"/tmp/file.png\"}', err=invalid character '/' after object key:value 
pair"}}"#.to_string(), + }; + + let result = InferenceClient::try_repair_from_error(&err); + assert!(result.is_some(), "should repair the malformed JSON"); + + let chunk = result.unwrap(); + assert!(chunk.token.is_none()); + assert!(chunk.tool_calls.is_some()); + assert_eq!(chunk.finish_reason.as_deref(), Some("tool_calls")); + + let calls = chunk.tool_calls.unwrap(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].arguments["destination"], "/Users/chintan/Desktop/file.png"); + assert_eq!(calls[0].arguments["create_dirs"], true); + } + + #[test] + fn test_try_repair_from_error_non_tool_call_error() { + // A regular HTTP 500 that isn't a tool call parse error + let err = InferenceError::HttpError { + status: 500, + body: "internal server error".to_string(), + }; + assert!(InferenceClient::try_repair_from_error(&err).is_none()); + } + + #[test] + fn test_try_repair_from_error_non_http_error() { + let err = InferenceError::Timeout { duration_secs: 30 }; + assert!(InferenceClient::try_repair_from_error(&err).is_none()); + } + + #[test] + fn test_lmstudio_base_url_construction() { + // Test LM Studio URL construction with different ports + let client = InferenceClient::from_config_with_model(test_config(), "lmstudio-model").unwrap(); + assert_eq!(client.current_base_url(), "http://localhost:1234/v1"); + } + + #[test] + fn test_fallback_chain_exhausted_error() { + let mut client = InferenceClient::from_config(test_config()).unwrap(); + + // Exhaust all fallbacks + client.try_next_fallback().unwrap(); // model-b + let result = client.try_next_fallback(); + + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(matches!(err, InferenceError::AllModelsUnavailable { .. })); + } + + #[test] + fn test_current_model_name_for_lmstudio() { + let client = InferenceClient::from_config_with_model(test_config(), "lmstudio-model").unwrap(); + assert_eq!(client.current_model_name(), "LM Studio Model"); + } + + #[test] + fn test_current_model_name_for_ollama() { + let client = InferenceClient::from_config(test_config()).unwrap(); + assert_eq!(client.current_model_name(), "Model A"); + } + + #[test] + fn test_tool_call_format_json() { + // Test that LM Studio uses NativeJson format + let config = test_config(); + let client = InferenceClient::from_config_with_model(config, "lmstudio-model").unwrap(); + assert_eq!(client.tool_call_format(), ToolCallFormat::NativeJson); + } + + #[test] + fn test_tool_call_format_pythonic() { + // Test that model-b uses Pythonic format + let config = test_config(); + let mut client = InferenceClient::from_config(config).unwrap(); + client.try_next_fallback().unwrap(); // model-b + assert_eq!(client.tool_call_format(), ToolCallFormat::Pythonic); + } +} diff --git a/src-tauri/src/inference/types.rs b/src-tauri/src/inference/types.rs new file mode 100644 index 0000000..ea42915 --- /dev/null +++ b/src-tauri/src/inference/types.rs @@ -0,0 +1,438 @@ +//! Shared types for the inference client. +//! +//! These mirror the OpenAI Chat Completions API types, used for both +//! request building and response parsing. + +use serde::{Deserialize, Serialize}; + +// ─── Request Types ─────────────────────────────────────────────────────────── + +/// A single message in the conversation. +/// +/// Serialization notes for OpenAI-compatible local models: +/// - `content` must be `""` (not `null`) for assistant messages with tool calls. +/// Many local models (Ollama, llama.cpp) misinterpret `null` content and fail +/// to recognize the tool call round-trip pattern. 
+/// - `tool_call_id` and `tool_calls` are skipped when `None`.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ChatMessage {
+    pub role: Role,
+    #[serde(serialize_with = "serialize_content")]
+    pub content: Option<String>,
+    /// Tool call results are sent back as `tool` role messages.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_call_id: Option<String>,
+    /// Assistant messages may contain tool calls.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_calls: Option<Vec<ToolCallResponse>>,
+}
+
+/// Custom serializer for `content`: emit `""` instead of `null` when `None`.
+///
+/// OpenAI's API accepts `null` content, but many local LLM runtimes
+/// (Ollama, llama.cpp, vLLM) reject or mishandle `null` content fields.
+/// Using `""` (empty string) is universally safe.
+fn serialize_content<S>(value: &Option<String>, serializer: S) -> Result<S::Ok, S::Error>
+where
+    S: serde::Serializer,
+{
+    match value {
+        Some(s) => serializer.serialize_str(s),
+        None => serializer.serialize_str(""),
+    }
+}
+
+/// Message role.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+    System,
+    User,
+    Assistant,
+    Tool,
+}
+
+/// Tool definition sent in the request.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ToolDefinition {
+    pub r#type: String,
+    pub function: FunctionDefinition,
+}
+
+/// Function definition within a tool.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct FunctionDefinition {
+    pub name: String,
+    pub description: String,
+    pub parameters: serde_json::Value,
+}
+
+/// Structured output format hint for the model.
+///
+/// When set to `json_object`, instructs Ollama to use GBNF grammar
+/// enforcement to guarantee valid JSON output. This is opt-in and
+/// experimental — only enable after live testing with the target model.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ResponseFormat {
+    /// The format type. Currently only `"json_object"` is supported.
+    pub r#type: String,
+}
+
+/// Request body for `POST /v1/chat/completions`.
+#[derive(Debug, Clone, Serialize)]
+pub struct ChatCompletionRequest {
+    pub model: String,
+    pub messages: Vec<ChatMessage>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tools: Option<Vec<ToolDefinition>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_choice: Option<String>,
+    pub temperature: f32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub top_p: Option<f32>,
+    pub max_tokens: u32,
+    pub stream: bool,
+    /// Optional structured output format. When set, the model backend
+    /// (Ollama/llama.cpp) uses grammar constraints to enforce valid output.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub response_format: Option<ResponseFormat>,
+}
+
+/// Optional sampling parameter overrides for a single inference call.
+///
+/// When provided, these override the model config defaults.
+/// Used to lower temperature/top_p for tool-calling turns (more deterministic)
+/// and raise them for conversational turns (more creative).
+#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SamplingOverrides {
+    /// Override temperature (0.0 = deterministic, 1.0 = creative).
+    pub temperature: Option<f32>,
+    /// Override top_p (nucleus sampling threshold).
+    pub top_p: Option<f32>,
+}
+
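+// Illustrative only (hypothetical helper, mirroring the `SamplingConfig` defaults
+// in `commands/settings.rs`): tool-calling turns want near-deterministic sampling,
+// conversational turns want more creativity. Not part of this module's API.
+#[allow(dead_code)]
+fn overrides_for_turn(tool_turn: bool) -> SamplingOverrides {
+    if tool_turn {
+        SamplingOverrides { temperature: Some(0.1), top_p: Some(0.2) }
+    } else {
+        SamplingOverrides { temperature: Some(0.7), top_p: Some(0.9) }
+    }
+}
+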
+// ─── Response Types ──────────────────────────────────────────────────────────
+
+/// A parsed tool call extracted from the model's response.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ToolCall {
+    /// Unique ID for this tool call (generated if the model doesn't provide one).
+    pub id: String,
+    /// Fully qualified tool name, e.g. `"filesystem.list_dir"`.
+    pub name: String,
+    /// Validated JSON arguments.
+    pub arguments: serde_json::Value,
+}
+
+/// Tool call as returned in the OpenAI response format.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ToolCallResponse {
+    pub id: String,
+    pub r#type: String,
+    pub function: FunctionCallResponse,
+}
+
+/// Function call details in a response.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct FunctionCallResponse {
+    pub name: String,
+    pub arguments: String,
+}
+
+/// A single chunk from the streaming response.
+#[derive(Debug, Clone)]
+pub struct StreamChunk {
+    /// Incremental text token (if this chunk carries text).
+    pub token: Option<String>,
+    /// Tool calls detected in this chunk (accumulated).
+    pub tool_calls: Option<Vec<ToolCall>>,
+    /// Why the model stopped: `"stop"`, `"tool_calls"`, or `None` (still going).
+    pub finish_reason: Option<String>,
+}
+
+/// Raw SSE chunk from the OpenAI API.
+#[derive(Debug, Clone, Deserialize)]
+pub struct ChatCompletionChunk {
+    #[allow(dead_code)]
+    pub id: Option<String>,
+    pub choices: Vec<ChunkChoice>,
+}
+
+/// A single choice within a streaming chunk.
+#[derive(Debug, Clone, Deserialize)]
+pub struct ChunkChoice {
+    pub delta: ChunkDelta,
+    pub finish_reason: Option<String>,
+}
+
+/// The delta (incremental update) within a chunk choice.
+#[derive(Debug, Clone, Deserialize)]
+pub struct ChunkDelta {
+    #[serde(default)]
+    pub content: Option<String>,
+    /// Reasoning/thinking content from models like Qwen3 and GPT-OSS.
+    /// Deserialized to prevent serde unknown-field errors, but not used for
+    /// streaming output — `content` holds the actual answer after reasoning
+    /// completes. Reasoning tokens are silently discarded.
+    #[serde(default)]
+    #[allow(dead_code)]
+    pub reasoning: Option<String>,
+    #[serde(default)]
+    pub tool_calls: Option<Vec<ChunkToolCall>>,
+}
+
+/// A tool call fragment within a streaming delta.
+#[derive(Debug, Clone, Deserialize)]
+pub struct ChunkToolCall {
+    pub index: Option<u32>,
+    pub id: Option<String>,
+    pub function: Option<ChunkFunction>,
+}
+
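+// Hedged sketch of how these fragments are typically reassembled (the real logic
+// lives in `streaming.rs`, not shown here): fragments sharing an `index` belong to
+// one tool call, `name` arrives once, and `arguments` accumulates as a string that
+// is JSON-parsed when the stream finishes. `merge_fragment` is illustrative only.
+#[allow(dead_code)]
+fn merge_fragment(
+    calls: &mut std::collections::BTreeMap<u32, (String, String)>,
+    fragment: &ChunkToolCall,
+) {
+    let entry = calls.entry(fragment.index.unwrap_or(0)).or_default();
+    if let Some(function) = &fragment.function {
+        if let Some(name) = &function.name {
+            entry.0 = name.clone(); // name is sent once, on the first fragment
+        }
+        if let Some(args) = &function.arguments {
+            entry.1.push_str(args); // arguments stream in as partial JSON text
+        }
+    }
+}
+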
+/// A function call fragment within a streaming tool call.
+#[derive(Debug, Clone, Deserialize)]
+pub struct ChunkFunction {
+    pub name: Option<String>,
+    pub arguments: Option<String>,
+}
+
+// ─── Tests ───────────────────────────────────────────────────────────────────
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_top_p_omitted_when_none() {
+        let req = ChatCompletionRequest {
+            model: "test".to_string(),
+            messages: vec![],
+            tools: None,
+            tool_choice: None,
+            temperature: 0.7,
+            top_p: None,
+            max_tokens: 1024,
+            stream: false,
+            response_format: None,
+        };
+        let json = serde_json::to_string(&req).unwrap();
+        assert!(!json.contains("top_p"), "top_p should be omitted when None");
+    }
+
+    #[test]
+    fn test_top_p_included_when_some() {
+        let req = ChatCompletionRequest {
+            model: "test".to_string(),
+            messages: vec![],
+            tools: None,
+            tool_choice: None,
+            temperature: 0.1,
+            top_p: Some(0.2),
+            max_tokens: 1024,
+            stream: false,
+            response_format: None,
+        };
+        let json = serde_json::to_string(&req).unwrap();
+        assert!(
+            json.contains("\"top_p\":0.2"),
+            "top_p should appear in JSON when Some"
+        );
+    }
+
+    #[test]
+    fn test_response_format_omitted_when_none() {
+        let req = ChatCompletionRequest {
+            model: "test".to_string(),
+            messages: vec![],
+            tools: None,
+            tool_choice: None,
+            temperature: 0.7,
+            top_p: None,
+            max_tokens: 1024,
+            stream: false,
+            response_format: None,
+        };
+        let json = serde_json::to_string(&req).unwrap();
+        assert!(
+            !json.contains("response_format"),
+            "response_format should be omitted when None"
+        );
+    }
+
+    #[test]
+    fn test_response_format_included_when_set() {
+        let req = ChatCompletionRequest {
+            model: "test".to_string(),
+            messages: vec![],
+            tools: None,
+            tool_choice: None,
+            temperature: 0.7,
+            top_p: None,
+            max_tokens: 1024,
+            stream: false,
+            response_format: Some(ResponseFormat {
+                r#type: "json_object".to_string(),
+            }),
+        };
+        let json = serde_json::to_string(&req).unwrap();
+        assert!(
+            json.contains("\"response_format\""),
+            "response_format should appear in JSON when Some"
+        );
+        assert!(
+            json.contains("\"json_object\""),
+            "type should be json_object"
+        );
+    }
+
+    #[test]
+    fn test_sampling_overrides_default() {
+        let overrides = SamplingOverrides::default();
+        assert!(overrides.temperature.is_none());
+        assert!(overrides.top_p.is_none());
+    }
+
+    #[test]
+    fn test_sampling_overrides_with_values() {
+        let overrides = SamplingOverrides {
+            temperature: Some(0.5),
+            top_p: Some(0.9),
+        };
+        assert!(overrides.temperature.is_some());
+        assert_eq!(overrides.temperature.unwrap(), 0.5);
+    }
+
+    #[test]
+    fn test_sampling_overrides_serialization() {
+        let overrides = SamplingOverrides {
+            temperature: Some(0.3),
+            top_p: Some(0.7),
+        };
+        let json = serde_json::to_string(&overrides).unwrap();
+        assert!(json.contains("0.3"));
+        assert!(json.contains("0.7"));
+    }
+
+    #[test]
+    fn test_sampling_overrides_deserialization() {
+        let json = r#"{"temperature": 0.4, "topP": 0.8}"#;
+        let overrides: SamplingOverrides = serde_json::from_str(json).unwrap();
+        assert_eq!(overrides.temperature, Some(0.4));
+        assert_eq!(overrides.top_p, Some(0.8));
+    }
+}
+
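+// Added example (not in the original test set): pins the `serialize_content`
+// behavior documented above. `None` content becomes `""`, never `null`, and
+// `tool_call_id`/`tool_calls` are omitted entirely when `None`.
+#[cfg(test)]
+mod content_serialization_tests {
+    use super::*;
+
+    #[test]
+    fn test_none_content_serializes_as_empty_string() {
+        let msg = ChatMessage {
+            role: Role::Assistant,
+            content: None,
+            tool_call_id: None,
+            tool_calls: None,
+        };
+        let json = serde_json::to_string(&msg).unwrap();
+        assert!(json.contains(r#""content":"""#), "None must serialize as empty string: {json}");
+        assert!(!json.contains("null"), "no null fields expected: {json}");
+        assert!(!json.contains("tool_call_id"), "tool_call_id skipped when None");
+    }
+}
+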
+/// Model status information for health monitoring.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ModelStatus {
+    pub key: String,
+    pub display_name: String,
+    pub base_url: String,
+    pub healthy: bool,
+    pub model_name: Option<String>,
+    pub error: Option<String>,
+}
+
+#[cfg(test)]
+mod model_status_tests {
+    use super::*;
+
+    #[test]
+    fn test_model_status_healthy() {
+        let status = ModelStatus {
+            key: "model-a".to_string(),
+            display_name: "Model A".to_string(),
+            base_url: "http://localhost:11434".to_string(),
+            healthy: true,
+            model_name: Some("model-a:latest".to_string()),
+            error: None,
+        };
+
+        assert!(status.healthy);
+        assert!(status.error.is_none());
+        assert!(status.model_name.is_some());
+    }
+
+    #[test]
+    fn test_model_status_unhealthy() {
+        let status = ModelStatus {
+            key: "model-a".to_string(),
+            display_name: "Model A".to_string(),
+            base_url: "http://localhost:11434".to_string(),
+            healthy: false,
+            model_name: None,
+            error: Some("connection refused".to_string()),
+        };
+
+        assert!(!status.healthy);
+        assert!(status.error.is_some());
+        assert!(status.model_name.is_none());
+    }
+
+    #[test]
+    fn test_model_status_serialization() {
+        let status = ModelStatus {
+            key: "qwen".to_string(),
+            display_name: "Qwen 2.5".to_string(),
+            base_url: "http://localhost:1234/v1".to_string(),
+            healthy: true,
+            model_name: Some("qwen2.5:14b".to_string()),
+            error: None,
+        };
+
+        let json = serde_json::to_string(&status).unwrap();
+        assert!(json.contains("qwen"));
+        assert!(json.contains("healthy"));
+        assert!(json.contains("true"));
+    }
+
+    #[test]
+    fn test_model_status_deserialization() {
+        let json = r#"{
+            "key": "lm-studio-model",
+            "displayName": "LM Studio Model",
+            "baseUrl": "http://localhost:1234/v1",
+            "healthy": true,
+            "modelName": "model-name",
+            "error": null
+        }"#;
+
+        let status: ModelStatus = serde_json::from_str(json).unwrap();
+        assert_eq!(status.key, "lm-studio-model");
+        assert!(status.healthy);
+        assert_eq!(status.model_name, Some("model-name".to_string()));
+    }
+
+    #[test]
+    fn test_model_status_with_error_deserialization() {
+        let json = r#"{
+            "key": "ollama-model",
+            "displayName": "Ollama Model",
+            "baseUrl": "http://localhost:11434",
+            "healthy": false,
+            "modelName": null,
+            "error": "timeout after 30s"
+        }"#;
+
+        let status: ModelStatus = serde_json::from_str(json).unwrap();
+        assert!(!status.healthy);
+        assert_eq!(status.error, Some("timeout after 30s".to_string()));
+    }
+
+    #[test]
+    fn test_model_status_debug_output() {
+        let status = ModelStatus {
+            key: "test".to_string(),
+            display_name: "Test Model".to_string(),
+            base_url: "http://localhost:8080".to_string(),
+            healthy: true,
+            model_name: None,
+            error: None,
+        };
+
+        let debug_str = format!("{:?}", status);
+        assert!(debug_str.contains("Test Model"));
+    }
+}
diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs
new file mode 100644
index 0000000..19ef34e
--- /dev/null
+++ b/src-tauri/src/lib.rs
@@ -0,0 +1,897 @@
+pub mod agent_core;
+pub mod commands;
+pub mod inference;
+pub mod mcp_client;
+
+use std::collections::HashMap;
+use std::sync::Mutex;
+
+use agent_core::{AgentDatabase, ConversationManager, ConfirmationResponse, PermissionStore};
+use commands::settings::SamplingConfig;
+use mcp_client::McpClient;
+use tauri::Manager;
+
+/// Pending confirmation channel — holds a oneshot sender while the agent loop
+/// awaits a user response via the ConfirmationDialog.
+pub type PendingConfirmation =
+    TokioMutex<Option<tokio::sync::oneshot::Sender<ConfirmationResponse>>>;
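+
+// Rough flow (a sketch; the authoritative sequence lives in commands::chat):
+// before executing a destructive tool call, the agent loop creates a oneshot
+// channel, parks the Sender in this slot, emits a ConfirmationRequest to the
+// frontend, and awaits the Receiver. respond_to_confirmation later takes the
+// Sender out of the slot and fires it with the user's ConfirmationResponse,
+// unblocking the loop.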
+
+/// In-flight request tracker — prevents duplicate requests for the same session.
+/// Maps session ID to the ID of the request currently being processed.
+pub type InFlightRequests = TokioMutex<HashMap<String, String>>;
+
+/// Async mutex for types that require `.await` inside their methods.
+pub type TokioMutex<T> = tokio::sync::Mutex<T>;
+
+/// Return the platform-standard data directory for LocalCowork.
+///
+/// - macOS: `~/Library/Application Support/com.localcowork.app/`
+/// - Windows: `{FOLDERID_RoamingAppData}\localcowork\`
+/// - Linux: `$XDG_DATA_HOME/com.localcowork.app/` (fallback `~/.local/share/...`)
+///
+/// Falls back to `~/.localcowork/` only if none of the above can be resolved.
+pub(crate) fn data_dir() -> std::path::PathBuf {
+    if let Some(dir) = dirs::data_dir() {
+        return dir.join("com.localcowork.app");
+    }
+    dirs::home_dir()
+        .unwrap_or_else(|| std::path::PathBuf::from("."))
+        .join(".localcowork")
+}
+
+/// Returns the cache directory for the app (embedding indexes, etc.).
+pub(crate) fn cache_dir() -> std::path::PathBuf {
+    data_dir().join("cache")
+}
+
+/// Initialize the tracing subscriber — writes structured logs to the app data directory.
+///
+/// On each app startup:
+/// 1. Rotates existing logs (agent.log → agent.log.1 → .2 → .3, keeps last 3).
+/// 2. Opens a fresh agent.log with a line-flushing writer for crash resilience.
+/// 3. Logs a startup banner with the data directory path for discoverability.
+fn init_tracing() {
+    use tracing_subscriber::fmt;
+    use tracing_subscriber::EnvFilter;
+
+    let log_dir = data_dir();
+    let _ = std::fs::create_dir_all(&log_dir);
+
+    let log_path = log_dir.join("agent.log");
+
+    // Rotate: agent.log.2 → .3, .1 → .2, agent.log → .1
+    rotate_log_file(&log_path, 3);
+
+    let log_file = std::fs::OpenOptions::new()
+        .create(true)
+        .append(true)
+        .open(&log_path)
+        .expect("failed to open agent.log");
+
+    let flushing_writer = FlushingWriter::new(log_file);
+
+    let filter = EnvFilter::try_from_default_env()
+        .unwrap_or_else(|_| EnvFilter::new("localcowork=info,warn"));
+
+    fmt::fmt()
+        .with_env_filter(filter)
+        .with_writer(flushing_writer)
+        .with_ansi(false)
+        .with_target(true)
+        .with_thread_ids(false)
+        .init();
+
+    // Startup banner — makes it easy to find the right log file
+    tracing::info!(
+        version = env!("CARGO_PKG_VERSION"),
+        data_dir = %log_dir.display(),
+        log_file = %log_path.display(),
+        pid = std::process::id(),
+        "=== LocalCowork starting ==="
+    );
+}
+
+/// Rotate log files: `agent.log` → `agent.log.1` → `.2` → … → `.{keep}`.
+///
+/// Oldest file beyond `keep` is deleted. Missing files in the chain are skipped.
+fn rotate_log_file(base_path: &std::path::Path, keep: u32) {
+    // Delete the oldest
+    let oldest = format!("{}.{keep}", base_path.display());
+    let _ = std::fs::remove_file(&oldest);
+
+    // Shift: .{n-1} → .{n}
+    for i in (1..keep).rev() {
+        let from = format!("{}.{i}", base_path.display());
+        let to = format!("{}.{}", base_path.display(), i + 1);
+        let _ = std::fs::rename(&from, &to);
+    }
+
+    // Current → .1
+    if base_path.exists() {
+        let to = format!("{}.1", base_path.display());
+        let _ = std::fs::rename(base_path, &to);
+    }
+}
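+
+// For concreteness, after a few startups with keep = 3 the data directory
+// settles into this shape (file names follow the scheme above):
+//   agent.log      (current run)
+//   agent.log.1    (previous run)
+//   agent.log.2    (two runs ago)
+//   agent.log.3    (three runs ago; deleted on the next rotation)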
+
+/// A writer that wraps `std::fs::File` and flushes after every write.
+///
+/// `tracing-subscriber` buffers log output internally. Without explicit
+/// flushing, log entries may sit in OS buffers and be lost on crash.
+/// This wrapper ensures each log line is on disk immediately.
+///
+/// Performance impact is minimal for a desktop app (~100 log lines/minute).
+#[derive(Clone)]
+struct FlushingWriter {
+    file: std::sync::Arc<std::sync::Mutex<std::fs::File>>,
+}
+
+impl FlushingWriter {
+    fn new(file: std::fs::File) -> Self {
+        Self {
+            file: std::sync::Arc::new(std::sync::Mutex::new(file)),
+        }
+    }
+}
+
+impl std::io::Write for FlushingWriter {
+    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
+        let mut f = self.file.lock().map_err(|e| {
+            std::io::Error::other(format!("lock poisoned: {e}"))
+        })?;
+        let n = std::io::Write::write(&mut *f, buf)?;
+        std::io::Write::flush(&mut *f)?;
+        Ok(n)
+    }
+
+    fn flush(&mut self) -> std::io::Result<()> {
+        let mut f = self.file.lock().map_err(|e| {
+            std::io::Error::other(format!("lock poisoned: {e}"))
+        })?;
+        std::io::Write::flush(&mut *f)
+    }
+}
+
+impl<'a> tracing_subscriber::fmt::MakeWriter<'a> for FlushingWriter {
+    type Writer = FlushingWriter;
+
+    fn make_writer(&'a self) -> Self::Writer {
+        self.clone()
+    }
+}
+
+/// Resolve the path for the agent SQLite database.
+///
+/// Uses the platform-standard data directory (creates it if needed).
+fn resolve_db_path() -> String {
+    let dir = data_dir();
+    if !dir.exists() {
+        let _ = std::fs::create_dir_all(&dir);
+    }
+    dir.join("agent.db").to_string_lossy().into_owned()
+}
+
+/// Resolve the MCP servers configuration using auto-discovery + optional overrides.
+///
+/// 1. Auto-discovers servers by scanning `mcp-servers/` for `package.json` (TS)
+///    or `pyproject.toml` (Python) markers.
+/// 2. Loads `mcp-servers.json` as optional overrides (missing file is fine).
+/// 3. Merges: override entries fully replace discovered entries.
+/// 4. Resolves relative paths, venvs, and injects vision model env vars.
+fn resolve_mcp_config() -> mcp_client::types::McpServersConfig {
+    let project_root = resolve_project_root();
+
+    // 1. Auto-discover servers from mcp-servers/ directory
+    let mcp_servers_dir = project_root.join("mcp-servers");
+    let discovered = mcp_client::discovery::discover_servers(&mcp_servers_dir);
+    tracing::info!(
+        discovered = discovered.len(),
+        servers = ?discovered.keys().collect::<Vec<_>>(),
+        "auto-discovered MCP servers"
+    );
+
+    // 2. Load optional override file
+    let overrides = load_override_file(&project_root);
+
+    // 3. Merge: overrides win
+    let mut merged = mcp_client::discovery::merge_configs(discovered, overrides);
+
+    // 4. Filter by enabled_servers allowlist from _models/config.yaml (if set)
+    filter_by_enabled_servers(&mut merged, &project_root);
+
+    let mut config = mcp_client::types::McpServersConfig { servers: merged };
+
+    // 5. Post-process: resolve paths, venvs, inject vision env vars
+    resolve_paths_and_env(&mut config, &project_root);
+
+    tracing::info!(
+        server_count = config.servers.len(),
+        servers = ?config.servers.keys().collect::<Vec<_>>(),
+        "final MCP server config"
+    );
+
+    config
+}
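+
+// Hedged sketch of an mcp-servers.json override entry (step 2 above), placed
+// at the project root or in src-tauri/. The schema mirrors ServerConfig; the
+// server name and paths are illustrative:
+//
+//   {
+//     "servers": {
+//       "filesystem": {
+//         "command": "node",
+//         "args": ["dist/index.js"],
+//         "cwd": "mcp-servers/filesystem"
+//       }
+//     }
+//   }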
+
+/// Filter discovered servers by the `enabled_servers` allowlist in `_models/config.yaml`.
+///
+/// When `enabled_servers` is set, only servers whose names appear in the list
+/// are kept. All others are removed. When absent or empty, all servers pass through.
+fn filter_by_enabled_servers(
+    servers: &mut std::collections::HashMap<String, mcp_client::ServerConfig>,
+    project_root: &std::path::Path,
+) {
+    let config_path = project_root.join("_models/config.yaml");
+    let content = match std::fs::read_to_string(&config_path) {
+        Ok(c) => c,
+        Err(_) => return, // No config file — skip filtering
+    };
+
+    // Parse just enough YAML to extract enabled_servers without requiring
+    // the full ModelsConfig (which needs model configs to be valid).
+    let yaml: serde_json::Value = match serde_yaml::from_str(&content) {
+        Ok(v) => v,
+        Err(_) => return,
+    };
+
+    let enabled = match yaml.get("enabled_servers").and_then(|v| v.as_array()) {
+        Some(arr) => arr,
+        None => return, // Field absent — no filtering
+    };
+
+    let allowlist: std::collections::HashSet<String> = enabled
+        .iter()
+        .filter_map(|v| v.as_str().map(String::from))
+        .collect();
+
+    if allowlist.is_empty() {
+        return;
+    }
+
+    let before = servers.len();
+    servers.retain(|name, _| allowlist.contains(name));
+    let after = servers.len();
+
+    tracing::info!(
+        before,
+        after,
+        enabled = ?allowlist,
+        "filtered MCP servers by enabled_servers allowlist"
+    );
+}
+
+/// Filter tools by the `enabled_tools` allowlist in `_models/config.yaml`.
+///
+/// When `enabled_tools` is set, only tools whose fully-qualified names appear
+/// in the list are kept in the registry. All others are removed. This allows
+/// curating a tight tool surface for specific demos or deployments.
+///
+/// Must be called AFTER `McpClient::start_all()` has populated the registry.
+fn filter_tools_by_allowlist(mcp_client: &mut McpClient, project_root: &std::path::Path) {
+    let config_path = project_root.join("_models/config.yaml");
+    let content = match std::fs::read_to_string(&config_path) {
+        Ok(c) => c,
+        Err(_) => return, // No config file — skip filtering
+    };
+
+    let yaml: serde_json::Value = match serde_yaml::from_str(&content) {
+        Ok(v) => v,
+        Err(_) => return,
+    };
+
+    let enabled = match yaml.get("enabled_tools").and_then(|v| v.as_array()) {
+        Some(arr) => arr,
+        None => return, // Field absent — no filtering
+    };
+
+    let allowlist: std::collections::HashSet<String> = enabled
+        .iter()
+        .filter_map(|v| v.as_str().map(String::from))
+        .collect();
+
+    if allowlist.is_empty() {
+        return;
+    }
+
+    mcp_client.registry.retain_tools(&allowlist);
+}
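+
+// For reference, a hedged sketch of the _models/config.yaml fragment the two
+// filters above consume. The keys match the parsing code; the server and
+// tool names are illustrative:
+//
+//   enabled_servers:
+//     - filesystem
+//     - task
+//
+//   enabled_tools:
+//     - filesystem.list_dir
+//     - filesystem.read_file
+//     - task.create_task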
+
+/// Determine the project root directory.
+///
+/// Resolution order:
+/// 1. `mcp-servers/` relative to cwd (dev mode, running from project root).
+/// 2. `../mcp-servers/` relative to cwd (dev mode, running from `src-tauri/`).
+/// 3. `mcp-servers/` relative to the executable (packaged app).
+/// 4. Fallback: cwd parent directory.
+pub(crate) fn resolve_project_root() -> std::path::PathBuf {
+    let cwd = std::env::current_dir().unwrap_or_default();
+
+    // Dev mode: cwd is the project root
+    if cwd.join("mcp-servers").is_dir() {
+        return cwd;
+    }
+
+    // Dev mode: cwd is src-tauri/
+    if cwd.join("..").join("mcp-servers").is_dir() {
+        return cwd.join("..").canonicalize().unwrap_or(cwd);
+    }
+
+    // Packaged app: check relative to the executable location.
+    // macOS: .app/Contents/MacOS/localcowork → .app/Contents/Resources/
+    // Windows: install_dir/localcowork.exe → install_dir/
+    // Linux: install_dir/localcowork → install_dir/
+    if let Ok(exe) = std::env::current_exe() {
+        if let Some(exe_dir) = exe.parent() {
+            // macOS .app bundle: Resources/ is a sibling of MacOS/
+            let macos_resources = exe_dir.join("../Resources");
+            if macos_resources.join("mcp-servers").is_dir() {
+                if let Ok(resolved) = macos_resources.canonicalize() {
+                    return resolved;
+                }
+            }
+            // Flat layout (Windows/Linux or dev binary)
+            if exe_dir.join("mcp-servers").is_dir() {
+                return exe_dir.to_path_buf();
+            }
+        }
+    }
+
+    // Last resort: cwd parent
+    cwd.parent()
+        .unwrap_or(std::path::Path::new("."))
+        .to_path_buf()
+}
+
+/// Load the optional `mcp-servers.json` override file.
+///
+/// Returns an empty map if the file doesn't exist or can't be parsed.
+fn load_override_file(
+    project_root: &std::path::Path,
+) -> std::collections::HashMap<String, mcp_client::ServerConfig> {
+    let candidates = [
+        project_root.join("src-tauri/mcp-servers.json"),
+        project_root.join("mcp-servers.json"),
+    ];
+
+    for path in &candidates {
+        if let Ok(content) = std::fs::read_to_string(path) {
+            match serde_json::from_str::<mcp_client::types::McpServersConfig>(&content) {
+                Ok(cfg) => {
+                    tracing::info!(
+                        path = %path.display(),
+                        count = cfg.servers.len(),
+                        "loaded MCP override config"
+                    );
+                    return cfg.servers;
+                }
+                Err(e) => {
+                    tracing::warn!(
+                        path = %path.display(),
+                        error = %e,
+                        "failed to parse MCP override config"
+                    );
+                }
+            }
+        }
+    }
+
+    std::collections::HashMap::new()
+}
+
+/// Resolve relative paths, venvs, and inject vision env vars into all server configs.
+fn resolve_paths_and_env(
+    config: &mut mcp_client::types::McpServersConfig,
+    project_root: &std::path::Path,
+) {
+    for server_config in config.servers.values_mut() {
+        // Resolve relative cwd to absolute
+        if let Some(ref cwd) = server_config.cwd {
+            if !std::path::Path::new(cwd).is_absolute() {
+                let abs_cwd = project_root.join(cwd);
+                server_config.cwd = Some(abs_cwd.to_string_lossy().into_owned());
+            }
+        }
+
+        // Resolve venv: rewrite command to venv binary and inject env vars
+        if let Some(ref venv) = server_config.venv {
+            let base_dir = server_config
+                .cwd
+                .as_ref()
+                .map(std::path::PathBuf::from)
+                .unwrap_or_else(|| project_root.to_path_buf());
+
+            let abs_venv = if std::path::Path::new(venv).is_absolute() {
+                std::path::PathBuf::from(venv)
+            } else {
+                base_dir.join(venv)
+            };
+            // Windows venvs use Scripts\ instead of bin/
+            let venv_bin = if cfg!(target_os = "windows") {
+                abs_venv.join("Scripts")
+            } else {
+                abs_venv.join("bin")
+            };
+            let venv_command = venv_bin.join(&server_config.command);
+
+            if venv_command.exists() {
+                server_config.command = venv_command.to_string_lossy().into_owned();
+                server_config.env.insert(
+                    "VIRTUAL_ENV".to_string(),
+                    abs_venv.to_string_lossy().into_owned(),
+                );
+                let system_path = std::env::var("PATH").unwrap_or_default();
+                server_config.env.insert(
+                    "PATH".to_string(),
+                    if cfg!(target_os = "windows") {
+                        format!("{};{system_path}", venv_bin.to_string_lossy())
+                    } else {
+                        format!("{}:{system_path}", venv_bin.to_string_lossy())
+                    },
+                );
+                tracing::info!(
+                    venv = %abs_venv.display(),
+                    command = %server_config.command,
+                    "resolved venv for MCP server"
+                );
+            } else {
+                tracing::warn!(
+                    venv = %abs_venv.display(),
+                    command = %server_config.command,
+                    "venv binary not found, using command as-is"
+                );
+            }
+
+            server_config.venv = Some(abs_venv.to_string_lossy().into_owned());
+        }
+    }
+
+    // Inject LOCALCOWORK_DATA_DIR so MCP servers use platform-standard paths
+    let app_data = data_dir().to_string_lossy().into_owned();
+    for server_config in config.servers.values_mut() {
+        server_config
+            .env
+            .entry("LOCALCOWORK_DATA_DIR".to_string())
+            .or_insert_with(|| app_data.clone());
+    }
+
+    // Inject vision model endpoint env vars
+    if let Some((vision_endpoint, vision_model)) = resolve_vision_model(project_root) {
+        for server_config in config.servers.values_mut() {
+            server_config
+                .env
+                .entry("LOCALCOWORK_VISION_ENDPOINT".to_string())
+                .or_insert_with(|| vision_endpoint.clone());
+            server_config
+                .env
+                .entry("LOCALCOWORK_VISION_MODEL".to_string())
+                .or_insert_with(|| vision_model.clone());
+        }
+        tracing::info!(
+            endpoint = %vision_endpoint,
+            model = %vision_model,
+            "injected vision model env vars into MCP servers"
+        );
+    }
+}
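+
+// Hedged example of a vision-capable model entry that the lookup below
+// matches. The field names mirror the YAML probing in resolve_vision_model;
+// the endpoint and model name are placeholders:
+//
+//   models:
+//     qwen-vl:
+//       base_url: "http://localhost:8081/v1"
+//       model_name: "qwen2.5-vl-7b"
+//       capabilities: [vision, text]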
+
+/// Find the first vision-capable model from `_models/config.yaml`.
+///
+/// Returns `(base_url, model_name)` if a model with the "vision" capability is found.
+/// Checks: (1) active model, (2) fallback chain, (3) any model in the config.
+fn resolve_vision_model(project_root: &std::path::Path) -> Option<(String, String)> {
+    let config_path = project_root.join("_models/config.yaml");
+    let content = std::fs::read_to_string(&config_path).ok()?;
+    let yaml: serde_json::Value = serde_yaml::from_str(&content).ok()?;
+
+    let models = yaml.get("models")?.as_object()?;
+    let active = yaml.get("active_model")?.as_str()?;
+
+    // Helper: check if a model has vision capability
+    let has_vision = |key: &str| -> Option<(String, String)> {
+        let model = models.get(key)?;
+        let caps = model.get("capabilities")?.as_array()?;
+        let is_vision = caps.iter().any(|c| c.as_str() == Some("vision"));
+        if !is_vision {
+            return None;
+        }
+        let base_url = model.get("base_url")?.as_str()?.to_string();
+        let model_name = model
+            .get("model_name")
+            .and_then(|v| v.as_str())
+            .unwrap_or(key)
+            .to_string();
+        Some((base_url, model_name))
+    };
+
+    // 1. Check active model first
+    if let Some(result) = has_vision(active) {
+        return Some(result);
+    }
+
+    // 2. Check fallback chain
+    if let Some(chain) = yaml.get("fallback_chain").and_then(|c| c.as_array()) {
+        for entry in chain {
+            if let Some(key) = entry.as_str() {
+                if let Some(result) = has_vision(key) {
+                    return Some(result);
+                }
+            }
+        }
+    }
+
+    // 3. Scan all models for any with vision capability (e.g., a dedicated VL model)
+    for key in models.keys() {
+        if let Some(result) = has_vision(key) {
+            return Some(result);
+        }
+    }
+
+    None
+}
+
+/// Run the Tauri application.
+pub fn run() {
+    // Initialize tracing FIRST — before any tracing::info!() calls
+    init_tracing();
+
+    // Initialize the SQLite-backed ConversationManager
+    let db_path = resolve_db_path();
+    let db = AgentDatabase::open(&db_path).expect("failed to open agent database");
+    let conversation_manager = ConversationManager::new(db);
+
+    tracing::info!(db_path = %db_path, "agent database initialized");
+
+    // Register an empty MCP client synchronously so that TokioMutex<McpClient>
+    // is always available in Tauri state. The async setup task will replace the
+    // empty client with a fully initialized one once servers are started.
+    // This prevents panics if start_session is called before MCP init completes.
+    let empty_mcp_config = mcp_client::types::McpServersConfig {
+        servers: std::collections::HashMap::new(),
+    };
+
+    tauri::Builder::default()
+        .plugin(tauri_plugin_shell::init())
+        .plugin(tauri_plugin_dialog::init())
+        .manage(Mutex::new(conversation_manager))
+        .manage(TokioMutex::new(McpClient::new(empty_mcp_config, None)))
+        .manage(TokioMutex::new(PermissionStore::new()))
+        .manage(TokioMutex::new(SamplingConfig::load_or_default()))
+        // PendingConfirmation: empty slot; the turbofish pins the alias's inner type.
+        .manage(TokioMutex::new(
+            None::<tokio::sync::oneshot::Sender<ConfirmationResponse>>,
+        ))
+        // InFlightRequests: empty map of session ID to in-flight request ID.
+        .manage(TokioMutex::new(HashMap::<String, String>::new()))
+        .setup(|app| {
+            // Initialize MCP client asynchronously during app setup.
+            // Once servers are started, replace the empty client via lock.
+            let handle = app.handle().clone();
+            tauri::async_runtime::spawn(async move {
+                // Provision missing Python venvs BEFORE resolving MCP config,
+                // so that discovery picks up the newly created .venv directories.
+                let project_root = resolve_project_root();
+                commands::python_env_startup::provision_missing_venvs(&project_root).await;
+
+                let config = resolve_mcp_config();
+                let mut mcp_client = McpClient::new(config, None);
+
+                let errors = mcp_client.start_all().await;
+                for (name, err) in &errors {
+                    tracing::warn!(
+                        server = %name,
+                        error = %err,
+                        "MCP server failed to start (non-fatal)"
+                    );
+                }
+
+                // Filter tools by enabled_tools allowlist (if configured)
+                filter_tools_by_allowlist(&mut mcp_client, &project_root);
+
+                let running = mcp_client.running_server_count();
+                let tools = mcp_client.tool_count();
+                tracing::info!(
+                    running_servers = running,
+                    total_tools = tools,
+                    "MCP client initialized"
+                );
+
+                // Replace the empty placeholder with the fully initialized client
+                let state: tauri::State<'_, TokioMutex<McpClient>> = handle.state();
+                let mut lock = state.lock().await;
+                *lock = mcp_client;
+            });
+
+            Ok(())
+        })
+        .invoke_handler(tauri::generate_handler![
+            commands::greet,
+            commands::chat::start_session,
+            commands::chat::send_message,
+            commands::chat::respond_to_confirmation,
+            commands::session::list_sessions,
+            commands::session::load_session,
+            commands::session::delete_session,
+            commands::session::get_context_budget,
+            commands::session::cleanup_empty_sessions,
+            commands::filesystem::list_directory,
+            commands::filesystem::get_home_dir,
+            commands::settings::get_models_config,
+            commands::settings::get_mcp_servers_status,
+            commands::settings::list_permission_grants,
+            commands::settings::revoke_permission,
+            commands::settings::get_sampling_config,
+            commands::settings::update_sampling_config,
+            commands::settings::reset_sampling_config,
+            commands::settings::get_app_settings,
+            commands::settings::update_app_settings,
+            commands::settings::add_allowed_path,
+            commands::settings::remove_allowed_path,
+            commands::settings::export_settings,
+            commands::settings::import_settings,
+            commands::settings::poll_settings_changed,
+            commands::settings::check_config_reload,
+            commands::settings::reload_model_config,
+            commands::hardware::detect_hardware,
+            commands::model_download::download_model,
+            commands::model_download::verify_model,
+            commands::model_download::get_model_dir,
+            commands::ollama::check_llama_server_status,
+            commands::ollama::check_ollama_status,
+            commands::ollama::list_ollama_models,
+            commands::ollama::pull_ollama_model,
+            commands::python_env::ensure_python_server_env,
+            commands::python_env::ensure_all_python_envs,
+        ])
+        .run(tauri::generate_context!())
+        .expect("error while running tauri application");
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use mcp_client::ServerConfig;
+    use std::collections::HashMap;
+    use tempfile::TempDir;
+
+    #[test]
+    fn test_data_dir_returns_valid_path() {
+        let dir = data_dir();
+        assert!(dir.is_absolute());
+        assert!(dir.to_string_lossy().contains("com.localcowork.app"));
+    }
+
+    #[test]
+    fn test_cache_dir_is_subdirectory_of_data_dir() {
+        let data = data_dir();
+        let cache = cache_dir();
+        assert!(cache.starts_with(&data));
+        assert!(cache.to_string_lossy().contains("cache"));
+    }
+
+    #[test]
+    fn test_rotate_log_file_creates_rotated_copies() {
+        let temp_dir = TempDir::new().unwrap();
+        let log_path = temp_dir.path().join("test.log");
+
+        // Create original file
+        std::fs::write(&log_path, "original content").unwrap();
+
+        // Rotate
+        rotate_log_file(&log_path, 3);
+
+        // Original should be moved to .1
+        let rotated = log_path.with_extension("log.1");
+        assert!(rotated.exists());
+
+        let content = std::fs::read_to_string(&rotated).unwrap();
+        assert_eq!(content, "original content");
+    }
+
+    #[test]
+    fn test_rotate_log_file_handles_missing_file() {
+        let temp_dir = TempDir::new().unwrap();
+        let log_path = temp_dir.path().join("nonexistent.log");
+
+        // Should not panic
+        rotate_log_file(&log_path, 3);
+    }
+
+    #[test]
+    fn test_rotate_log_file_multiple_rotations() {
+        let temp_dir = TempDir::new().unwrap();
+        let log_path = temp_dir.path().join("test.log");
+
+        // Create and rotate multiple times
+        std::fs::write(&log_path, "v1").unwrap();
+        rotate_log_file(&log_path, 3);
+
+        std::fs::write(&log_path, "v2").unwrap();
+        rotate_log_file(&log_path, 3);
+
+        std::fs::write(&log_path, "v3").unwrap();
+        rotate_log_file(&log_path, 3);
+
+        // Check all versions exist
+        assert!(log_path.with_extension("log.1").exists());
+        assert!(log_path.with_extension("log.2").exists());
+        assert!(log_path.with_extension("log.3").exists());
+
+        // Oldest should be v1
+        let v1 = std::fs::read_to_string(log_path.with_extension("log.3")).unwrap();
+        assert_eq!(v1, "v1");
+    }
+
+    #[test]
+    fn test_resolve_db_path_returns_sqlite_path() {
+        let path = resolve_db_path();
+        // Should be an absolute path on every platform
+        assert!(std::path::Path::new(&path).is_absolute());
+        assert!(path.ends_with(".db"));
+    }
+
+    #[test]
+    fn test_resolve_project_root_finds_mcp_servers() {
+        // This test verifies the function returns a valid path
+        let root = resolve_project_root();
+        assert!(root.is_absolute());
+    }
+
+    #[test]
+    fn test_filter_by_enabled_servers_filters_correctly() {
+        let temp_dir = TempDir::new().unwrap();
+        let project_root = temp_dir.path();
+
+        // Create config with enabled_servers (_models/ must exist before the write)
+        std::fs::create_dir_all(project_root.join("_models")).unwrap();
+        let config_content = r#"
+enabled_servers:
+  - filesystem
+  - task
+"#;
+        std::fs::write(project_root.join("_models/config.yaml"), config_content).unwrap();
+
+        // Create servers
+        let mut servers = HashMap::new();
+        servers.insert("filesystem".to_string(), ServerConfig {
+            command: "node".to_string(),
+            args: vec![],
+            env: HashMap::new(),
+            cwd: None,
+            venv: None,
+        });
+        servers.insert("task".to_string(), ServerConfig {
+            command: "node".to_string(),
+            args: vec![],
+            env: HashMap::new(),
+            cwd: None,
+            venv: None,
+        });
+        servers.insert("calendar".to_string(), ServerConfig {
+            command: "node".to_string(),
+            args: vec![],
+            env: HashMap::new(),
+            cwd: None,
+            venv: None,
+        }); // Should be removed
+        servers.insert("email".to_string(), ServerConfig {
+            command: "node".to_string(),
+            args: vec![],
+            env: HashMap::new(),
+            cwd: None,
+            venv: None,
+        }); // Should be removed
+
+        let before = servers.len();
+        filter_by_enabled_servers(&mut servers, project_root);
+        let after = servers.len();
+
+        assert_eq!(before, 4);
+        assert_eq!(after, 2);
+        assert!(servers.contains_key("filesystem"));
+        assert!(servers.contains_key("task"));
+        assert!(!servers.contains_key("calendar"));
+        assert!(!servers.contains_key("email"));
+    }
+
+    #[test]
+    fn test_filter_by_enabled_servers_handles_missing_config() {
+        let temp_dir = TempDir::new().unwrap();
+        let project_root = temp_dir.path();
+        // No config file at all
+
+        let mut servers = HashMap::new();
+        servers.insert("a".to_string(), ServerConfig {
+            command: "node".to_string(),
+            args: vec![],
+            env: HashMap::new(),
+            cwd: None,
+            venv: None,
+        });
+        servers.insert("b".to_string(), ServerConfig {
+            command: "node".to_string(),
+            args: vec![],
+            env: HashMap::new(),
+            cwd: None,
+            venv: None,
+        });
+
+        let before = servers.len();
+        filter_by_enabled_servers(&mut servers, project_root);
+
+        // Should keep all since no config
+        assert_eq!(servers.len(), before);
+    }
+
+    #[test]
+    fn test_load_override_file_returns_empty_for_missing() {
+        let temp_dir = TempDir::new().unwrap();
+        let project_root = temp_dir.path();
+
+        let result = load_override_file(project_root);
+        assert!(result.is_empty());
+    }
+
+    #[test]
+    fn test_load_override_file_parses_valid_config() {
+        let temp_dir = TempDir::new().unwrap();
+        let project_root = temp_dir.path();
+
+        let config_content = r#"{
+            "servers": {
+                "test-server": {
+                    "command": "node",
+                    "args": ["test.js"]
+                }
+            }
+        }"#;
+        std::fs::write(project_root.join("mcp-servers.json"), config_content).unwrap();
+
+        let result = load_override_file(project_root);
+        assert!(result.contains_key("test-server"));
+    }
+
+    #[test]
+    fn test_resolve_vision_model_returns_none_without_config() {
+        let temp_dir = TempDir::new().unwrap();
+
+        let result = resolve_vision_model(temp_dir.path());
+        assert!(result.is_none());
+    }
+
+    #[test]
+    fn test_filter_tools_by_allowlist_works_without_config() {
+        // Test that filter_tools_by_allowlist doesn't panic without config
+        let temp_dir = TempDir::new().unwrap();
+        let project_root = temp_dir.path();
+
+        let mut mcp_client = McpClient::new(
+            mcp_client::types::McpServersConfig { servers: HashMap::new() },
+            None,
+        );
+
+        // Should not panic
+        filter_tools_by_allowlist(&mut mcp_client, project_root);
+    }
+}
diff --git a/src/App.tsx b/src/App.tsx
new file mode 100644
index 0000000..b4ece1a
--- /dev/null
+++ b/src/App.tsx
@@ -0,0 +1,92 @@
+import { useEffect } from "react";
+
+import { ChatPanel } from "./components/Chat";
+import { FileBrowser } from "./components/FileBrowser";
+import { OnboardingWizard } from "./components/Onboarding";
+import { SettingsPanel } from "./components/Settings";
+import { useOnboardingStore } from "./stores/onboardingStore";
+import { useSettingsStore } from "./stores/settingsStore";
+
+/**
+ * Root application component.
+ *
+ * Shows the OnboardingWizard on first run, then the main app layout.
+ */
+export function App(): React.JSX.Element {
+  const toggleSettings = useSettingsStore((s) => s.togglePanel);
+  const isSettingsOpen = useSettingsStore((s) => s.isOpen);
+  const startConfigWatch = useSettingsStore((s) => s.startConfigWatch);
+  const stopConfigWatch = useSettingsStore((s) => s.stopConfigWatch);
+  const configReloadNotification = useSettingsStore(
+    (s) => s.configReloadNotification,
+  );
+  const clearConfigReloadNotification = useSettingsStore(
+    (s) => s.clearConfigReloadNotification,
+  );
+  const isOnboardingComplete = useOnboardingStore((s) => s.isComplete);
+
+  // Start/stop config file watching based on settings panel state
+  useEffect(() => {
+    if (isSettingsOpen) {
+      startConfigWatch();
+    } else {
+      stopConfigWatch();
+    }
+    return () => stopConfigWatch();
+  }, [isSettingsOpen, startConfigWatch, stopConfigWatch]);
+
+  if (!isOnboardingComplete) {
+    return <OnboardingWizard />;
+  }
+
+  return (
+    <div className="app">
+      {/* Config reload toast notification */}
+      {configReloadNotification && (
+        <div className="config-reload-toast">
+          <span className="toast-icon">🔄</span>
+          <span className="toast-message">{configReloadNotification}</span>
+          <button type="button" onClick={clearConfigReloadNotification}>
+            Dismiss
+          </button>
+        </div>
+      )}
+
+      <header className="app-header">
+        <div className="app-brand">
+          <h1 className="app-title">LocalCowork</h1>
+          <span className="app-badge">on-device</span>
+          <span className="app-tagline">
+            powered by LFM2-24B-A2B from Liquid AI
+          </span>
+        </div>
+        <button type="button" onClick={toggleSettings}>
+          Settings
+        </button>
+      </header>
+
+      <main className="app-main">
+        <FileBrowser />
+        <ChatPanel />
+      </main>
+
+      <footer className="app-footer">
+        v0.1.0 — Agent Core
+      </footer>
+
+      {isSettingsOpen && <SettingsPanel />}
+    </div>
+  );
+}
diff --git a/src/components/Chat/MessageInput.tsx b/src/components/Chat/MessageInput.tsx
new file mode 100644
index 0000000..fd401b6
--- /dev/null
+++ b/src/components/Chat/MessageInput.tsx
@@ -0,0 +1,118 @@
+/**
+ * MessageInput — text input area for sending messages.
+ *
+ * Supports Enter to send (Shift+Enter for newline) and disables
+ * input while the assistant is generating. Includes an InputToolbar
+ * below the textarea for folder context (Cowork-style "Work in a folder").
+ * Implements debouncing to prevent duplicate sends.
+ */
+
+import { useCallback, useRef, useState } from "react";
+
+import { InputToolbar } from "./InputToolbar";
+
+interface MessageInputProps {
+  readonly onSend: (content: string) => void;
+  readonly disabled: boolean;
+}
+
+/** Minimum time between send requests to prevent duplicates (500ms) */
+const SEND_DEBOUNCE_MS = 500;
+
+export function MessageInput({
+  onSend,
+  disabled,
+}: MessageInputProps): React.JSX.Element {
+  const [value, setValue] = useState("");
+  const textareaRef = useRef<HTMLTextAreaElement>(null);
+  const lastSendTimeRef = useRef(0);
+  const [isDebouncing, setIsDebouncing] = useState(false);
+
+  const handleSend = useCallback(() => {
+    const trimmed = value.trim();
+    if (!trimmed || disabled) return;
+
+    // Debounce: ignore clicks within 500ms
+    const now = Date.now();
+    if (now - lastSendTimeRef.current < SEND_DEBOUNCE_MS) {
+      setIsDebouncing(true);
+      setTimeout(() => setIsDebouncing(false), SEND_DEBOUNCE_MS);
+      return;
+    }
+    lastSendTimeRef.current = now;
+
+    onSend(trimmed);
+    setValue("");
+
+    // Reset textarea height
+    if (textareaRef.current) {
+      textareaRef.current.style.height = "auto";
+    }
+  }, [value, disabled, onSend]);
+
+  const handleKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>): void => {
+    if (e.key === "Enter" && !e.shiftKey) {
+      e.preventDefault();
+      handleSend();
+    }
+  };
+
+  const handleInput = (e: React.ChangeEvent<HTMLTextAreaElement>): void => {
+    setValue(e.target.value);
+
+    // Auto-resize textarea
+    const textarea = e.target;
+    textarea.style.height = "auto";
+    textarea.style.height = `${Math.min(textarea.scrollHeight, 200)}px`;
+  };
+
+  const isLoading = disabled || isDebouncing;
+
+  return (
+    <div className="message-input">
+      <div className="message-input-row">