diff --git a/crates/goose/src/acp/server.rs b/crates/goose/src/acp/server.rs index 6cda8fedf979..506d706113d0 100644 --- a/crates/goose/src/acp/server.rs +++ b/crates/goose/src/acp/server.rs @@ -21,7 +21,7 @@ use crate::providers::inventory::{ InventoryIdentity, ProviderInventoryEntry, ProviderInventoryService, RefreshJobPlan, RefreshPlan, RefreshSkipReason, }; -use crate::session::session_manager::SessionType; +use crate::session::session_manager::{SessionListCursor, SessionType}; use crate::session::{EnabledExtensionsState, Session, SessionManager}; use crate::source_roots::SourceRoot; use crate::utils::sanitize_unicode_tags; @@ -51,6 +51,7 @@ use agent_client_protocol::{ Responder, }; use anyhow::Result; +use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _}; use fs_err as fs; use futures::future::{BoxFuture, Either}; use futures::stream::{self, StreamExt}; @@ -58,7 +59,8 @@ use futures::FutureExt; use rmcp::model::{ AnnotateAble, CallToolResult, RawContent, RawTextContent, ResourceContents, Role, }; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; use std::collections::{HashMap, HashSet}; use std::panic::AssertUnwindSafe; use std::sync::Arc; @@ -91,6 +93,10 @@ pub type AcpProviderFactory = Arc< + Sync, >; +const SESSION_LIST_PAGE_SIZE: usize = 50; +const ACP_SESSION_LIST_TYPES: [SessionType; 3] = + [SessionType::User, SessionType::Scheduled, SessionType::Acp]; + /// Convenience conversions from any `Display` error into an `agent_client_protocol::Error`. /// /// Replaces the repetitive `.internal_err()` @@ -253,6 +259,94 @@ fn sid_short(id: &str) -> String { id.chars().take(8).collect() } +#[derive(Debug, Serialize, Deserialize)] +struct SessionListCursorToken { + updated_at: chrono::DateTime, + // Goose stores updated_at with second precision in common write paths, so the + // cursor needs the full (updated_at, id) sort key to avoid skipping tied rows. + session_id: String, + filter_hash: String, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +struct SessionListCursorFilters { + cwd: Option, + session_types: Vec, + non_empty: bool, +} + +fn invalid_session_list_cursor(message: &'static str) -> agent_client_protocol::Error { + agent_client_protocol::Error::invalid_params().data(message) +} + +// bind cursors to the effective filters so they cannot be reused for a different list. +fn session_list_filter_hash( + cwd: Option<&std::path::Path>, + session_types: &[SessionType], +) -> Result { + let mut session_type_names = session_types + .iter() + .map(ToString::to_string) + .collect::>(); + session_type_names.sort(); + let filters = SessionListCursorFilters { + cwd: cwd.map(|path| path.to_string_lossy().to_string()), + session_types: session_type_names, + non_empty: true, + }; + let bytes = + serde_json::to_vec(&filters).internal_err_ctx("Failed to encode session list filters")?; + Ok(URL_SAFE_NO_PAD.encode(Sha256::digest(bytes))) +} + +fn decode_session_list_cursor( + cursor: Option<&str>, + cwd: Option<&std::path::Path>, + session_types: &[SessionType], +) -> Result, agent_client_protocol::Error> { + let Some(cursor) = cursor else { + return Ok(None); + }; + + let bytes = URL_SAFE_NO_PAD + .decode(cursor) + .map_err(|_| invalid_session_list_cursor("malformed session list cursor"))?; + let token: SessionListCursorToken = serde_json::from_slice(&bytes) + .map_err(|_| invalid_session_list_cursor("malformed session list cursor"))?; + + if token.session_id.is_empty() || token.filter_hash.is_empty() { + return Err(invalid_session_list_cursor("malformed session list cursor")); + } + + let expected_filter_hash = session_list_filter_hash(cwd, session_types)?; + if token.filter_hash != expected_filter_hash { + return Err(invalid_session_list_cursor( + "session list cursor does not match filters", + )); + } + + Ok(Some(SessionListCursor { + updated_at: token.updated_at, + session_id: token.session_id, + })) +} + +fn encode_session_list_cursor( + cursor: &SessionListCursor, + cwd: Option<&std::path::Path>, + session_types: &[SessionType], +) -> Result { + let token = SessionListCursorToken { + updated_at: cursor.updated_at, + session_id: cursor.session_id.clone(), + filter_hash: session_list_filter_hash(cwd, session_types)?, + }; + let bytes = + serde_json::to_vec(&token).internal_err_ctx("Failed to encode session list cursor")?; + Ok(URL_SAFE_NO_PAD.encode(bytes)) +} + fn session_meta(session: &Session) -> serde_json::Map { let mut meta = serde_json::Map::new(); meta.insert( @@ -3306,16 +3400,35 @@ impl GooseAcpAgent { Ok(()) } - async fn on_list_sessions(&self) -> Result { + async fn on_list_sessions( + &self, + req: ListSessionsRequest, + ) -> Result { + if let Some(cwd) = req.cwd.as_deref() { + if !cwd.is_absolute() { + return Err(agent_client_protocol::Error::invalid_params() + .data("cwd must be an absolute path")); + } + } + + let cwd = req.cwd.as_deref(); + let cursor = + decode_session_list_cursor(req.cursor.as_deref(), cwd, &ACP_SESSION_LIST_TYPES)?; + // ACP clients see their own (Acp) sessions plus legacy User/Scheduled ones. - let sessions = self + let page = self .session_manager - .list_sessions_by_types(&[SessionType::User, SessionType::Scheduled, SessionType::Acp]) + .list_nonempty_sessions_by_types_paged( + &ACP_SESSION_LIST_TYPES, + cwd, + cursor.as_ref(), + SESSION_LIST_PAGE_SIZE, + ) .await .internal_err()?; - let session_infos: Vec = sessions + let session_infos: Vec = page + .sessions .into_iter() - .filter(|s| s.message_count > 0) .map(|s| { let meta = session_meta(&s); SessionInfo::new(SessionId::new(s.id), s.working_dir) @@ -3324,7 +3437,12 @@ impl GooseAcpAgent { .meta(meta) }) .collect(); - Ok(ListSessionsResponse::new(session_infos)) + let next_cursor = page + .next_cursor + .as_ref() + .map(|cursor| encode_session_list_cursor(cursor, cwd, &ACP_SESSION_LIST_TYPES)) + .transpose()?; + Ok(ListSessionsResponse::new(session_infos).next_cursor(next_cursor)) } async fn on_fork_session( diff --git a/crates/goose/src/acp/server/dispatch.rs b/crates/goose/src/acp/server/dispatch.rs index 3a3694c4e426..7e23333e2545 100644 --- a/crates/goose/src/acp/server/dispatch.rs +++ b/crates/goose/src/acp/server/dispatch.rs @@ -330,9 +330,12 @@ impl HandleDispatchFrom for GooseAcpHandler { .if_request({ let agent = agent.clone(); let cx = cx.clone(); - |_req: ListSessionsRequest, responder: Responder| async move { + |req: ListSessionsRequest, responder: Responder| async move { cx.spawn(async move { - responder.respond(agent.on_list_sessions().await?)?; + match agent.on_list_sessions(req).await { + Ok(response) => responder.respond(response)?, + Err(e) => responder.respond_with_error(e)?, + } Ok(()) })?; Ok(()) diff --git a/crates/goose/src/session/session_manager.rs b/crates/goose/src/session/session_manager.rs index 37b27266469a..6090a3041283 100644 --- a/crates/goose/src/session/session_manager.rs +++ b/crates/goose/src/session/session_manager.rs @@ -266,6 +266,27 @@ pub struct SessionManager { storage: Arc, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct SessionListCursor { + pub(crate) updated_at: DateTime, + pub(crate) session_id: String, +} + +#[derive(Debug, Clone)] +pub(crate) struct SessionListPage { + pub(crate) sessions: Vec, + pub(crate) next_cursor: Option, +} + +#[derive(Debug, Default)] +struct SessionListQuery<'a> { + types: Option<&'a [SessionType]>, + working_dir: Option<&'a Path>, + cursor: Option<&'a SessionListCursor>, + limit: Option, + require_messages: bool, +} + #[derive(Debug, Clone)] pub struct SessionNameUpdate { pub session_id: String, @@ -332,6 +353,18 @@ impl SessionManager { self.storage.list_sessions_by_types(Some(types)).await } + pub(crate) async fn list_nonempty_sessions_by_types_paged( + &self, + types: &[SessionType], + working_dir: Option<&Path>, + cursor: Option<&SessionListCursor>, + page_size: usize, + ) -> Result { + self.storage + .list_nonempty_sessions_by_types_paged(types, working_dir, cursor, page_size) + .await + } + pub async fn list_all_sessions(&self) -> Result> { self.storage.list_sessions_by_types(None).await } @@ -1427,17 +1460,46 @@ impl SessionStorage { Self::replace_conversation_inner(pool, session_id, conversation).await } - async fn list_sessions_by_types(&self, types: Option<&[SessionType]>) -> Result> { - let (where_clause, binds): (String, Vec) = match types { - Some(t) if !t.is_empty() => { - let placeholders: String = t.iter().map(|_| "?").collect::>().join(", "); - ( - format!("WHERE s.session_type IN ({})", placeholders), - t.iter().map(|t| t.to_string()).collect(), - ) - } - Some(_) => return Ok(Vec::new()), - None => (String::new(), Vec::new()), + async fn list_sessions_matching(&self, options: SessionListQuery<'_>) -> Result> { + if matches!(options.types, Some(types) if types.is_empty()) { + return Ok(Vec::new()); + } + + let mut where_clauses = Vec::new(); + if let Some(types) = options.types { + let placeholders = types.iter().map(|_| "?").collect::>().join(", "); + where_clauses.push(format!("s.session_type IN ({})", placeholders)); + } + if options.working_dir.is_some() { + where_clauses.push("s.working_dir = ?".to_string()); + } + if options.cursor.is_some() { + where_clauses.push( + "(datetime(s.updated_at) < datetime(?) \ + OR (datetime(s.updated_at) = datetime(?) AND s.id < ?))" + .to_string(), + ); + } + + let where_clause = if where_clauses.is_empty() { + String::new() + } else { + format!("WHERE {}", where_clauses.join(" AND ")) + }; + let message_join = if options.require_messages { + "JOIN messages m ON s.id = m.session_id" + } else { + "LEFT JOIN messages m ON s.id = m.session_id" + }; + let order_by = if options.cursor.is_some() || options.limit.is_some() { + "ORDER BY datetime(s.updated_at) DESC, s.id DESC" + } else { + "ORDER BY s.updated_at DESC" + }; + let limit_clause = if options.limit.is_some() { + "LIMIT ?" + } else { + "" }; let query = format!( @@ -1450,23 +1512,90 @@ impl SessionStorage { s.archived_at, s.project_id, COUNT(m.id) as message_count FROM sessions s - LEFT JOIN messages m ON s.id = m.session_id + {} {} GROUP BY s.id - ORDER BY s.updated_at DESC + {} + {} "#, - where_clause + message_join, where_clause, order_by, limit_clause ); let mut q = sqlx::query_as::<_, Session>(&query); - for b in &binds { - q = q.bind(b); + if let Some(types) = options.types { + for session_type in types { + q = q.bind(session_type.to_string()); + } + } + if let Some(working_dir) = options.working_dir { + q = q.bind(working_dir.to_string_lossy().to_string()); + } + if let Some(cursor) = options.cursor { + let updated_at = cursor.updated_at.to_rfc3339(); + // Normalize mixed SQLite CURRENT_TIMESTAMP and RFC3339 stored values. + q = q.bind(updated_at.clone()); + q = q.bind(updated_at); + q = q.bind(&cursor.session_id); + } + if let Some(limit) = options.limit { + q = q.bind(limit as i64); } let pool = self.pool().await?; q.fetch_all(pool).await.map_err(Into::into) } + async fn list_sessions_by_types(&self, types: Option<&[SessionType]>) -> Result> { + self.list_sessions_matching(SessionListQuery { + types, + ..Default::default() + }) + .await + } + + async fn list_nonempty_sessions_by_types_paged( + &self, + types: &[SessionType], + working_dir: Option<&Path>, + cursor: Option<&SessionListCursor>, + page_size: usize, + ) -> Result { + if types.is_empty() || page_size == 0 { + return Ok(SessionListPage { + sessions: Vec::new(), + next_cursor: None, + }); + } + + let mut sessions = self + .list_sessions_matching(SessionListQuery { + types: Some(types), + working_dir, + cursor, + limit: Some(page_size + 1), + require_messages: true, + }) + .await?; + let has_next_page = sessions.len() > page_size; + let next_cursor = if has_next_page { + let anchor = &sessions[page_size - 1]; + Some(SessionListCursor { + updated_at: anchor.updated_at, + session_id: anchor.id.clone(), + }) + } else { + None + }; + if has_next_page { + sessions.truncate(page_size); + } + + Ok(SessionListPage { + sessions, + next_cursor, + }) + } + async fn list_sessions(&self) -> Result> { self.list_sessions_by_types(Some(&[SessionType::User, SessionType::Scheduled])) .await @@ -1790,6 +1919,89 @@ mod tests { const NUM_CONCURRENT_SESSIONS: i32 = 10; + async fn create_session_for_list( + sm: &SessionManager, + working_dir: &str, + has_message: bool, + ) -> String { + let session = sm + .create_session( + PathBuf::from(working_dir), + format!("Session in {working_dir}"), + SessionType::User, + GooseMode::default(), + ) + .await + .unwrap(); + + if has_message { + sm.add_message(&session.id, &Message::user().with_text("message")) + .await + .unwrap(); + } + + session.id + } + + async fn set_sessions_updated_at( + sm: &SessionManager, + session_ids: &[String], + updated_at: &str, + ) { + let pool = sm.storage().pool().await.unwrap(); + let updated_at = chrono::DateTime::parse_from_rfc3339(updated_at).unwrap(); + let timestamp = updated_at.format("%Y-%m-%d %H:%M:%S").to_string(); + + for session_id in session_ids { + sqlx::query("UPDATE sessions SET updated_at = ? WHERE id = ?") + .bind(×tamp) + .bind(session_id) + .execute(pool) + .await + .unwrap(); + } + } + + async fn expected_session_list_ids(sm: &SessionManager, session_ids: &[String]) -> Vec { + let mut sessions = Vec::new(); + for session_id in session_ids { + sessions.push(sm.get_session(session_id, false).await.unwrap()); + } + sessions.sort_by(|a, b| { + b.updated_at + .cmp(&a.updated_at) + .then_with(|| b.id.cmp(&a.id)) + }); + sessions.into_iter().map(|session| session.id).collect() + } + + async fn assert_session_list_page( + sm: &SessionManager, + cursor: Option<&SessionListCursor>, + working_dir: Option<&str>, + page_size: usize, + expected_ids: &[String], + expected_next_cursor: bool, + ) -> Option { + let page = sm + .list_nonempty_sessions_by_types_paged( + &[SessionType::User], + working_dir.map(Path::new), + cursor, + page_size, + ) + .await + .unwrap(); + let ids = page + .sessions + .iter() + .map(|session| session.id.clone()) + .collect::>(); + assert_eq!(ids.as_slice(), expected_ids); + assert_eq!(page.next_cursor.is_some(), expected_next_cursor); + page.next_cursor + } + async fn run_lock_upgrade_attempt( pool: Pool, session_id: String, @@ -1882,6 +2094,70 @@ mod tests { ); } + #[tokio::test] + async fn test_session_list_paged_first_second_and_final_page() { + let temp_dir = TempDir::new().unwrap(); + let sm = SessionManager::new(temp_dir.path().to_path_buf()); + let mut expected_ids = Vec::new(); + for _ in 0..5 { + expected_ids.push(create_session_for_list(&sm, "/tmp/session-list", true).await); + } + let expected_ids = expected_session_list_ids(&sm, &expected_ids).await; + + let cursor = assert_session_list_page(&sm, None, None, 2, &expected_ids[0..2], true).await; + let cursor = + assert_session_list_page(&sm, cursor.as_ref(), None, 2, &expected_ids[2..4], true) + .await; + assert_session_list_page(&sm, cursor.as_ref(), None, 2, &expected_ids[4..5], false).await; + } + + #[tokio::test] + async fn test_session_list_paged_uses_id_tiebreaker_for_duplicate_updated_at() { + let temp_dir = TempDir::new().unwrap(); + let sm = SessionManager::new(temp_dir.path().to_path_buf()); + let mut expected_ids = Vec::new(); + for _ in 0..3 { + expected_ids.push(create_session_for_list(&sm, "/tmp/session-list", true).await); + } + set_sessions_updated_at(&sm, &expected_ids, "2024-01-01T00:00:00Z").await; + let expected_ids = expected_session_list_ids(&sm, &expected_ids).await; + + let cursor = assert_session_list_page(&sm, None, None, 2, &expected_ids[0..2], true).await; + assert_session_list_page(&sm, cursor.as_ref(), None, 2, &expected_ids[2..3], false).await; + } + + #[tokio::test] + async fn test_session_list_paged_filters_empty_and_cwd_before_pagination() { + let temp_dir = TempDir::new().unwrap(); + let sm = SessionManager::new(temp_dir.path().to_path_buf()); + let expected_ids = vec![ + create_session_for_list(&sm, "/tmp/session-list/a", true).await, + create_session_for_list(&sm, "/tmp/session-list/a", true).await, + ]; + create_session_for_list(&sm, "/tmp/session-list/a", false).await; + create_session_for_list(&sm, "/tmp/session-list/b", true).await; + let expected_ids = expected_session_list_ids(&sm, &expected_ids).await; + + let cursor = assert_session_list_page( + &sm, + None, + Some("/tmp/session-list/a"), + 1, + &expected_ids[0..1], + true, + ) + .await; + assert_session_list_page( + &sm, + cursor.as_ref(), + Some("/tmp/session-list/a"), + 1, + &expected_ids[1..2], + false, + ) + .await; + } + #[tokio::test] async fn test_concurrent_session_creation() { let temp_dir = TempDir::new().unwrap(); diff --git a/crates/goose/tests/acp_server_test.rs b/crates/goose/tests/acp_server_test.rs index 72dab982fa10..7d162fea3125 100644 --- a/crates/goose/tests/acp_server_test.rs +++ b/crates/goose/tests/acp_server_test.rs @@ -1,8 +1,10 @@ #[allow(dead_code)] #[path = "acp_common_tests/mod.rs"] mod common_tests; -use common_tests::fixtures::run_test; +use agent_client_protocol::schema::{ListSessionsRequest, ListSessionsResponse}; +use agent_client_protocol::ErrorCode; use common_tests::fixtures::server::AcpServerConnection; +use common_tests::fixtures::{run_test, Connection, OpenAiFixture, TestConnectionConfig}; use common_tests::{ run_close_session, run_config_mcp, run_config_option_mode_set, run_config_option_model_set, run_delete_session, run_fs_read_text_file_true, run_fs_write_text_file_false, @@ -14,10 +16,65 @@ use common_tests::{ run_prompt_mcp, run_prompt_model_mismatch, run_prompt_skill, run_session_name_update_notification, run_shell_terminal_false, run_shell_terminal_true, }; +use goose::config::GooseMode; +use goose::conversation::message::Message; +use goose::session::{SessionManager, SessionType}; +use std::path::Path; tests_config_option_set_error!(AcpServerConnection); tests_mode_set_error!(AcpServerConnection); +async fn seed_list_sessions(data_root: &Path, working_dir: &Path, count: usize) { + let session_manager = SessionManager::new(data_root.to_path_buf()); + for index in 0..count { + let session = session_manager + .create_session( + working_dir.to_path_buf(), + format!("Seed session {index}"), + SessionType::Acp, + GooseMode::default(), + ) + .await + .unwrap(); + session_manager + .add_message(&session.id, &Message::user().with_text("hello")) + .await + .unwrap(); + } +} + +async fn new_connection(data_root: &Path) -> AcpServerConnection { + let openai = OpenAiFixture::new( + vec![], + ::expected_session_id(), + ) + .await; + ::new( + TestConnectionConfig { + data_root: data_root.to_path_buf(), + ..Default::default() + }, + openai, + ) + .await +} + +async fn list_sessions_request( + conn: &AcpServerConnection, + request: ListSessionsRequest, +) -> anyhow::Result { + conn.cx() + .send_request(request) + .block_task() + .await + .map_err(Into::into) +} + +fn assert_invalid_params(error: anyhow::Error) { + let acp_error = error.downcast::().unwrap(); + assert_eq!(acp_error.code, ErrorCode::InvalidParams); +} + #[test] fn test_config_mcp() { run_test(async { run_config_mcp::().await }); @@ -33,6 +90,74 @@ fn test_list_sessions() { run_test(async { run_list_sessions::().await }); } +#[test] +fn test_list_sessions_pagination() { + run_test(async { + let data_root = tempfile::tempdir().unwrap(); + seed_list_sessions(data_root.path(), Path::new("/tmp/acp-session-list"), 51).await; + let conn = new_connection(data_root.path()).await; + + let first = list_sessions_request(&conn, ListSessionsRequest::new()) + .await + .unwrap(); + assert_eq!(first.sessions.len(), 50); + + let second = list_sessions_request( + &conn, + ListSessionsRequest::new().cursor(first.next_cursor.clone().unwrap()), + ) + .await + .unwrap(); + assert_eq!(second.sessions.len(), 1); + assert!(second.next_cursor.is_none()); + + let second_id = &second.sessions[0].session_id; + assert!(first + .sessions + .iter() + .all(|session| session.session_id != *second_id)); + }); +} + +#[test] +fn test_list_sessions_invalid_params() { + run_test(async { + let data_root = tempfile::tempdir().unwrap(); + let cwd = tempfile::tempdir().unwrap(); + let other_cwd = tempfile::tempdir().unwrap(); + seed_list_sessions(data_root.path(), cwd.path(), 51).await; + let conn = new_connection(data_root.path()).await; + + let error = + list_sessions_request(&conn, ListSessionsRequest::new().cursor("*".to_string())) + .await + .unwrap_err(); + assert_invalid_params(error); + + let error = list_sessions_request( + &conn, + ListSessionsRequest::new().cwd(std::path::PathBuf::from("relative/path")), + ) + .await + .unwrap_err(); + assert_invalid_params(error); + + let first = list_sessions_request(&conn, ListSessionsRequest::new().cwd(cwd.path())) + .await + .unwrap(); + + let error = list_sessions_request( + &conn, + ListSessionsRequest::new() + .cwd(other_cwd.path()) + .cursor(first.next_cursor.unwrap()), + ) + .await + .unwrap_err(); + assert_invalid_params(error); + }); +} + #[test] fn test_session_name_update_notification() { run_test(async { run_session_name_update_notification::().await });