From fdfcffb1213a60c4136404f2fde3c57482b04ed6 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Wed, 13 May 2026 17:46:06 -0700 Subject: [PATCH 1/2] feat(acp): paginate session list add keyset pagination for acp session/list using a fixed server page size and opaque cursors. filter cwd and nonempty sessions before paging so nextcursor reflects the visible result set. Signed-off-by: Kalvin Chau --- crates/goose/src/acp/server.rs | 101 ++++++++- crates/goose/src/acp/server/dispatch.rs | 7 +- crates/goose/src/session/session_manager.rs | 237 ++++++++++++++++++++ crates/goose/tests/acp_server_test.rs | 127 ++++++++++- 4 files changed, 461 insertions(+), 11 deletions(-) diff --git a/crates/goose/src/acp/server.rs b/crates/goose/src/acp/server.rs index 6cda8fedf979..d8a476cda8fb 100644 --- a/crates/goose/src/acp/server.rs +++ b/crates/goose/src/acp/server.rs @@ -21,7 +21,7 @@ use crate::providers::inventory::{ InventoryIdentity, ProviderInventoryEntry, ProviderInventoryService, RefreshJobPlan, RefreshPlan, RefreshSkipReason, }; -use crate::session::session_manager::SessionType; +use crate::session::session_manager::{SessionListCursor, SessionType}; use crate::session::{EnabledExtensionsState, Session, SessionManager}; use crate::source_roots::SourceRoot; use crate::utils::sanitize_unicode_tags; @@ -51,6 +51,7 @@ use agent_client_protocol::{ Responder, }; use anyhow::Result; +use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _}; use fs_err as fs; use futures::future::{BoxFuture, Either}; use futures::stream::{self, StreamExt}; @@ -58,7 +59,7 @@ use futures::FutureExt; use rmcp::model::{ AnnotateAble, CallToolResult, RawContent, RawTextContent, ResourceContents, Role, }; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; use std::panic::AssertUnwindSafe; use std::sync::Arc; @@ -91,6 +92,8 @@ pub type AcpProviderFactory = Arc< + Sync, >; +const SESSION_LIST_PAGE_SIZE: usize = 50; + /// Convenience conversions from any `Display` error into an `agent_client_protocol::Error`. /// /// Replaces the repetitive `.internal_err()` @@ -253,6 +256,65 @@ fn sid_short(id: &str) -> String { id.chars().take(8).collect() } +#[derive(Debug, Serialize, Deserialize)] +struct SessionListCursorToken { + updated_at: chrono::DateTime, + // Goose stores updated_at with second precision in common write paths, so the + // cursor needs the full (updated_at, id) sort key to avoid skipping tied rows. + session_id: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + cwd: Option, +} + +fn invalid_session_list_cursor(message: &'static str) -> agent_client_protocol::Error { + agent_client_protocol::Error::invalid_params().data(message) +} + +fn decode_session_list_cursor( + cursor: Option<&str>, + cwd: Option<&std::path::Path>, +) -> Result, agent_client_protocol::Error> { + let Some(cursor) = cursor else { + return Ok(None); + }; + + let bytes = URL_SAFE_NO_PAD + .decode(cursor) + .map_err(|_| invalid_session_list_cursor("malformed session list cursor"))?; + let token: SessionListCursorToken = serde_json::from_slice(&bytes) + .map_err(|_| invalid_session_list_cursor("malformed session list cursor"))?; + + if token.session_id.is_empty() { + return Err(invalid_session_list_cursor("malformed session list cursor")); + } + + let requested_cwd = cwd.map(std::path::Path::to_path_buf); + if token.cwd != requested_cwd { + return Err(invalid_session_list_cursor( + "session list cursor does not match cwd", + )); + } + + Ok(Some(SessionListCursor { + updated_at: token.updated_at, + session_id: token.session_id, + })) +} + +fn encode_session_list_cursor( + cursor: &SessionListCursor, + cwd: Option<&std::path::Path>, +) -> Result { + let token = SessionListCursorToken { + updated_at: cursor.updated_at, + session_id: cursor.session_id.clone(), + cwd: cwd.map(std::path::Path::to_path_buf), + }; + let bytes = + serde_json::to_vec(&token).internal_err_ctx("Failed to encode session list cursor")?; + Ok(URL_SAFE_NO_PAD.encode(bytes)) +} + fn session_meta(session: &Session) -> serde_json::Map { let mut meta = serde_json::Map::new(); meta.insert( @@ -3306,16 +3368,34 @@ impl GooseAcpAgent { Ok(()) } - async fn on_list_sessions(&self) -> Result { + async fn on_list_sessions( + &self, + req: ListSessionsRequest, + ) -> Result { + if let Some(cwd) = req.cwd.as_deref() { + if !cwd.is_absolute() { + return Err(agent_client_protocol::Error::invalid_params() + .data("cwd must be an absolute path")); + } + } + + let cwd = req.cwd.as_deref(); + let cursor = decode_session_list_cursor(req.cursor.as_deref(), cwd)?; + // ACP clients see their own (Acp) sessions plus legacy User/Scheduled ones. - let sessions = self + let page = self .session_manager - .list_sessions_by_types(&[SessionType::User, SessionType::Scheduled, SessionType::Acp]) + .list_nonempty_sessions_by_types_paged( + &[SessionType::User, SessionType::Scheduled, SessionType::Acp], + cwd, + cursor.as_ref(), + SESSION_LIST_PAGE_SIZE, + ) .await .internal_err()?; - let session_infos: Vec = sessions + let session_infos: Vec = page + .sessions .into_iter() - .filter(|s| s.message_count > 0) .map(|s| { let meta = session_meta(&s); SessionInfo::new(SessionId::new(s.id), s.working_dir) @@ -3324,7 +3404,12 @@ impl GooseAcpAgent { .meta(meta) }) .collect(); - Ok(ListSessionsResponse::new(session_infos)) + let next_cursor = page + .next_cursor + .as_ref() + .map(|cursor| encode_session_list_cursor(cursor, cwd)) + .transpose()?; + Ok(ListSessionsResponse::new(session_infos).next_cursor(next_cursor)) } async fn on_fork_session( diff --git a/crates/goose/src/acp/server/dispatch.rs b/crates/goose/src/acp/server/dispatch.rs index 3a3694c4e426..7e23333e2545 100644 --- a/crates/goose/src/acp/server/dispatch.rs +++ b/crates/goose/src/acp/server/dispatch.rs @@ -330,9 +330,12 @@ impl HandleDispatchFrom for GooseAcpHandler { .if_request({ let agent = agent.clone(); let cx = cx.clone(); - |_req: ListSessionsRequest, responder: Responder| async move { + |req: ListSessionsRequest, responder: Responder| async move { cx.spawn(async move { - responder.respond(agent.on_list_sessions().await?)?; + match agent.on_list_sessions(req).await { + Ok(response) => responder.respond(response)?, + Err(e) => responder.respond_with_error(e)?, + } Ok(()) })?; Ok(()) diff --git a/crates/goose/src/session/session_manager.rs b/crates/goose/src/session/session_manager.rs index 37b27266469a..67229d51cfb5 100644 --- a/crates/goose/src/session/session_manager.rs +++ b/crates/goose/src/session/session_manager.rs @@ -266,6 +266,18 @@ pub struct SessionManager { storage: Arc, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct SessionListCursor { + pub(crate) updated_at: DateTime, + pub(crate) session_id: String, +} + +#[derive(Debug, Clone)] +pub(crate) struct SessionListPage { + pub(crate) sessions: Vec, + pub(crate) next_cursor: Option, +} + #[derive(Debug, Clone)] pub struct SessionNameUpdate { pub session_id: String, @@ -332,6 +344,18 @@ impl SessionManager { self.storage.list_sessions_by_types(Some(types)).await } + pub(crate) async fn list_nonempty_sessions_by_types_paged( + &self, + types: &[SessionType], + working_dir: Option<&Path>, + cursor: Option<&SessionListCursor>, + page_size: usize, + ) -> Result { + self.storage + .list_nonempty_sessions_by_types_paged(types, working_dir, cursor, page_size) + .await + } + pub async fn list_all_sessions(&self) -> Result> { self.storage.list_sessions_by_types(None).await } @@ -1467,6 +1491,91 @@ impl SessionStorage { q.fetch_all(pool).await.map_err(Into::into) } + async fn list_nonempty_sessions_by_types_paged( + &self, + types: &[SessionType], + working_dir: Option<&Path>, + cursor: Option<&SessionListCursor>, + page_size: usize, + ) -> Result { + if types.is_empty() || page_size == 0 { + return Ok(SessionListPage { + sessions: Vec::new(), + next_cursor: None, + }); + } + + let type_placeholders = types.iter().map(|_| "?").collect::>().join(", "); + let mut where_clause = format!("s.session_type IN ({})", type_placeholders); + if working_dir.is_some() { + where_clause.push_str(" AND s.working_dir = ?"); + } + if cursor.is_some() { + where_clause.push_str( + " AND (datetime(s.updated_at) < datetime(?) \ + OR (datetime(s.updated_at) = datetime(?) AND s.id < ?))", + ); + } + + let query = format!( + r#" + SELECT s.id, s.working_dir, s.name, s.description, s.user_set_name, s.session_type, s.created_at, s.updated_at, s.extension_data, + s.total_tokens, s.input_tokens, s.output_tokens, + s.accumulated_total_tokens, s.accumulated_input_tokens, s.accumulated_output_tokens, + s.schedule_id, s.recipe_json, s.user_recipe_values_json, + s.provider_name, s.model_config_json, s.goose_mode, + s.archived_at, s.project_id, + COUNT(m.id) as message_count + FROM sessions s + JOIN messages m ON s.id = m.session_id + WHERE {} + GROUP BY s.id + ORDER BY datetime(s.updated_at) DESC, s.id DESC + LIMIT ? + "#, + where_clause + ); + + let mut q = sqlx::query_as::<_, Session>(&query); + for session_type in types { + q = q.bind(session_type.to_string()); + } + + if let Some(working_dir) = working_dir { + q = q.bind(working_dir.to_string_lossy().to_string()); + } + + if let Some(cursor) = cursor { + let updated_at = cursor.updated_at.to_rfc3339(); + // Normalize mixed SQLite CURRENT_TIMESTAMP and RFC3339 stored values. + q = q.bind(updated_at.clone()); + q = q.bind(updated_at); + q = q.bind(&cursor.session_id); + } + q = q.bind((page_size + 1) as i64); + + let pool = self.pool().await?; + let mut sessions = q.fetch_all(pool).await?; + let has_next_page = sessions.len() > page_size; + let next_cursor = if has_next_page { + let anchor = &sessions[page_size - 1]; + Some(SessionListCursor { + updated_at: anchor.updated_at, + session_id: anchor.id.clone(), + }) + } else { + None + }; + if has_next_page { + sessions.truncate(page_size); + } + + Ok(SessionListPage { + sessions, + next_cursor, + }) + } + async fn list_sessions(&self) -> Result> { self.list_sessions_by_types(Some(&[SessionType::User, SessionType::Scheduled])) .await @@ -1790,6 +1899,74 @@ mod tests { const NUM_CONCURRENT_SESSIONS: i32 = 10; + async fn insert_session_for_list( + sm: &SessionManager, + id: &str, + working_dir: &str, + updated_at: &str, + message_count: usize, + ) { + let pool = sm.storage().pool().await.unwrap(); + let updated_at = chrono::DateTime::parse_from_rfc3339(updated_at).unwrap(); + let timestamp = updated_at.format("%Y-%m-%d %H:%M:%S").to_string(); + + sqlx::query( + "INSERT INTO sessions ( + id, name, user_set_name, session_type, working_dir, created_at, updated_at, extension_data, goose_mode + ) VALUES (?, ?, FALSE, ?, ?, ?, ?, '{}', ?)", + ) + .bind(id) + .bind(format!("Session {id}")) + .bind(SessionType::User.to_string()) + .bind(working_dir) + .bind(×tamp) + .bind(×tamp) + .bind(GooseMode::default().to_string()) + .execute(pool) + .await + .unwrap(); + + for index in 0..message_count { + sqlx::query( + "INSERT INTO messages (message_id, session_id, role, content_json, created_timestamp, metadata_json) + VALUES (?, ?, 'user', '[]', ?, '{}')", + ) + .bind(format!("{id}_{index}")) + .bind(id) + .bind(index as i64) + .execute(pool) + .await + .unwrap(); + } + } + + async fn assert_session_list_page( + sm: &SessionManager, + cursor: Option<&SessionListCursor>, + working_dir: Option<&str>, + page_size: usize, + expected_ids: &[&str], + expected_next_cursor: bool, + ) -> Option { + let page = sm + .list_nonempty_sessions_by_types_paged( + &[SessionType::User], + working_dir.map(Path::new), + cursor, + page_size, + ) + .await + .unwrap(); + let ids: Vec<_> = page + .sessions + .iter() + .map(|session| session.id.as_str()) + .collect(); + assert_eq!(ids, expected_ids); + assert_eq!(page.next_cursor.is_some(), expected_next_cursor); + page.next_cursor + } + async fn run_lock_upgrade_attempt( pool: Pool, session_id: String, @@ -1882,6 +2059,66 @@ mod tests { ); } + #[tokio::test] + async fn test_session_list_paged_first_second_and_final_page() { + let temp_dir = TempDir::new().unwrap(); + let sm = SessionManager::new(temp_dir.path().to_path_buf()); + for index in 1..=5 { + insert_session_for_list( + &sm, + &format!("s{index:03}"), + "/tmp/session-list", + &format!("2024-01-01T00:00:0{index}Z"), + 1, + ) + .await; + } + + let cursor = assert_session_list_page(&sm, None, None, 2, &["s005", "s004"], true).await; + let cursor = + assert_session_list_page(&sm, cursor.as_ref(), None, 2, &["s003", "s002"], true).await; + assert_session_list_page(&sm, cursor.as_ref(), None, 2, &["s001"], false).await; + } + + #[tokio::test] + async fn test_session_list_paged_uses_id_tiebreaker_for_duplicate_updated_at() { + let temp_dir = TempDir::new().unwrap(); + let sm = SessionManager::new(temp_dir.path().to_path_buf()); + for id in ["s001", "s002", "s003"] { + insert_session_for_list(&sm, id, "/tmp/session-list", "2024-01-01T00:00:00Z", 1).await; + } + + let cursor = assert_session_list_page(&sm, None, None, 2, &["s003", "s002"], true).await; + assert_session_list_page(&sm, cursor.as_ref(), None, 2, &["s001"], false).await; + } + + #[tokio::test] + async fn test_session_list_paged_filters_empty_and_cwd_before_pagination() { + let temp_dir = TempDir::new().unwrap(); + let sm = SessionManager::new(temp_dir.path().to_path_buf()); + for (id, dir, updated_at, messages) in [ + ("s004", "/tmp/session-list/a", "2024-01-01T00:00:04Z", 1), + ("s003", "/tmp/session-list/a", "2024-01-01T00:00:03Z", 0), + ("s002", "/tmp/session-list/b", "2024-01-01T00:00:02Z", 1), + ("s001", "/tmp/session-list/a", "2024-01-01T00:00:01Z", 1), + ] { + insert_session_for_list(&sm, id, dir, updated_at, messages).await; + } + + let cursor = + assert_session_list_page(&sm, None, Some("/tmp/session-list/a"), 1, &["s004"], true) + .await; + assert_session_list_page( + &sm, + cursor.as_ref(), + Some("/tmp/session-list/a"), + 1, + &["s001"], + false, + ) + .await; + } + #[tokio::test] async fn test_concurrent_session_creation() { let temp_dir = TempDir::new().unwrap(); diff --git a/crates/goose/tests/acp_server_test.rs b/crates/goose/tests/acp_server_test.rs index 72dab982fa10..7d162fea3125 100644 --- a/crates/goose/tests/acp_server_test.rs +++ b/crates/goose/tests/acp_server_test.rs @@ -1,8 +1,10 @@ #[allow(dead_code)] #[path = "acp_common_tests/mod.rs"] mod common_tests; -use common_tests::fixtures::run_test; +use agent_client_protocol::schema::{ListSessionsRequest, ListSessionsResponse}; +use agent_client_protocol::ErrorCode; use common_tests::fixtures::server::AcpServerConnection; +use common_tests::fixtures::{run_test, Connection, OpenAiFixture, TestConnectionConfig}; use common_tests::{ run_close_session, run_config_mcp, run_config_option_mode_set, run_config_option_model_set, run_delete_session, run_fs_read_text_file_true, run_fs_write_text_file_false, @@ -14,10 +16,65 @@ use common_tests::{ run_prompt_mcp, run_prompt_model_mismatch, run_prompt_skill, run_session_name_update_notification, run_shell_terminal_false, run_shell_terminal_true, }; +use goose::config::GooseMode; +use goose::conversation::message::Message; +use goose::session::{SessionManager, SessionType}; +use std::path::Path; tests_config_option_set_error!(AcpServerConnection); tests_mode_set_error!(AcpServerConnection); +async fn seed_list_sessions(data_root: &Path, working_dir: &Path, count: usize) { + let session_manager = SessionManager::new(data_root.to_path_buf()); + for index in 0..count { + let session = session_manager + .create_session( + working_dir.to_path_buf(), + format!("Seed session {index}"), + SessionType::Acp, + GooseMode::default(), + ) + .await + .unwrap(); + session_manager + .add_message(&session.id, &Message::user().with_text("hello")) + .await + .unwrap(); + } +} + +async fn new_connection(data_root: &Path) -> AcpServerConnection { + let openai = OpenAiFixture::new( + vec![], + ::expected_session_id(), + ) + .await; + ::new( + TestConnectionConfig { + data_root: data_root.to_path_buf(), + ..Default::default() + }, + openai, + ) + .await +} + +async fn list_sessions_request( + conn: &AcpServerConnection, + request: ListSessionsRequest, +) -> anyhow::Result { + conn.cx() + .send_request(request) + .block_task() + .await + .map_err(Into::into) +} + +fn assert_invalid_params(error: anyhow::Error) { + let acp_error = error.downcast::().unwrap(); + assert_eq!(acp_error.code, ErrorCode::InvalidParams); +} + #[test] fn test_config_mcp() { run_test(async { run_config_mcp::().await }); @@ -33,6 +90,74 @@ fn test_list_sessions() { run_test(async { run_list_sessions::().await }); } +#[test] +fn test_list_sessions_pagination() { + run_test(async { + let data_root = tempfile::tempdir().unwrap(); + seed_list_sessions(data_root.path(), Path::new("/tmp/acp-session-list"), 51).await; + let conn = new_connection(data_root.path()).await; + + let first = list_sessions_request(&conn, ListSessionsRequest::new()) + .await + .unwrap(); + assert_eq!(first.sessions.len(), 50); + + let second = list_sessions_request( + &conn, + ListSessionsRequest::new().cursor(first.next_cursor.clone().unwrap()), + ) + .await + .unwrap(); + assert_eq!(second.sessions.len(), 1); + assert!(second.next_cursor.is_none()); + + let second_id = &second.sessions[0].session_id; + assert!(first + .sessions + .iter() + .all(|session| session.session_id != *second_id)); + }); +} + +#[test] +fn test_list_sessions_invalid_params() { + run_test(async { + let data_root = tempfile::tempdir().unwrap(); + let cwd = tempfile::tempdir().unwrap(); + let other_cwd = tempfile::tempdir().unwrap(); + seed_list_sessions(data_root.path(), cwd.path(), 51).await; + let conn = new_connection(data_root.path()).await; + + let error = + list_sessions_request(&conn, ListSessionsRequest::new().cursor("*".to_string())) + .await + .unwrap_err(); + assert_invalid_params(error); + + let error = list_sessions_request( + &conn, + ListSessionsRequest::new().cwd(std::path::PathBuf::from("relative/path")), + ) + .await + .unwrap_err(); + assert_invalid_params(error); + + let first = list_sessions_request(&conn, ListSessionsRequest::new().cwd(cwd.path())) + .await + .unwrap(); + + let error = list_sessions_request( + &conn, + ListSessionsRequest::new() + .cwd(other_cwd.path()) + .cursor(first.next_cursor.unwrap()), + ) + .await + .unwrap_err(); + assert_invalid_params(error); + }); +} + #[test] fn test_session_name_update_notification() { run_test(async { run_session_name_update_notification::().await }); From 02d338a3d65f26f394b15ea67f84710d9279e484 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Mon, 18 May 2026 09:20:44 -0700 Subject: [PATCH 2/2] fix(acp): address session list review share session list query construction across plain and paged list paths. bind pagination cursors to the effective list filters and update the pagination tests to create sessions through session manager APIs. Signed-off-by: Kalvin Chau --- crates/goose/src/acp/server.rs | 53 +++- crates/goose/src/session/session_manager.rs | 299 +++++++++++--------- 2 files changed, 212 insertions(+), 140 deletions(-) diff --git a/crates/goose/src/acp/server.rs b/crates/goose/src/acp/server.rs index d8a476cda8fb..506d706113d0 100644 --- a/crates/goose/src/acp/server.rs +++ b/crates/goose/src/acp/server.rs @@ -60,6 +60,7 @@ use rmcp::model::{ AnnotateAble, CallToolResult, RawContent, RawTextContent, ResourceContents, Role, }; use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; use std::collections::{HashMap, HashSet}; use std::panic::AssertUnwindSafe; use std::sync::Arc; @@ -93,6 +94,8 @@ pub type AcpProviderFactory = Arc< >; const SESSION_LIST_PAGE_SIZE: usize = 50; +const ACP_SESSION_LIST_TYPES: [SessionType; 3] = + [SessionType::User, SessionType::Scheduled, SessionType::Acp]; /// Convenience conversions from any `Display` error into an `agent_client_protocol::Error`. /// @@ -262,17 +265,45 @@ struct SessionListCursorToken { // Goose stores updated_at with second precision in common write paths, so the // cursor needs the full (updated_at, id) sort key to avoid skipping tied rows. session_id: String, - #[serde(default, skip_serializing_if = "Option::is_none")] - cwd: Option, + filter_hash: String, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +struct SessionListCursorFilters { + cwd: Option, + session_types: Vec, + non_empty: bool, } fn invalid_session_list_cursor(message: &'static str) -> agent_client_protocol::Error { agent_client_protocol::Error::invalid_params().data(message) } +// bind cursors to the effective filters so they cannot be reused for a different list. +fn session_list_filter_hash( + cwd: Option<&std::path::Path>, + session_types: &[SessionType], +) -> Result { + let mut session_type_names = session_types + .iter() + .map(ToString::to_string) + .collect::>(); + session_type_names.sort(); + let filters = SessionListCursorFilters { + cwd: cwd.map(|path| path.to_string_lossy().to_string()), + session_types: session_type_names, + non_empty: true, + }; + let bytes = + serde_json::to_vec(&filters).internal_err_ctx("Failed to encode session list filters")?; + Ok(URL_SAFE_NO_PAD.encode(Sha256::digest(bytes))) +} + fn decode_session_list_cursor( cursor: Option<&str>, cwd: Option<&std::path::Path>, + session_types: &[SessionType], ) -> Result, agent_client_protocol::Error> { let Some(cursor) = cursor else { return Ok(None); @@ -284,14 +315,14 @@ fn decode_session_list_cursor( let token: SessionListCursorToken = serde_json::from_slice(&bytes) .map_err(|_| invalid_session_list_cursor("malformed session list cursor"))?; - if token.session_id.is_empty() { + if token.session_id.is_empty() || token.filter_hash.is_empty() { return Err(invalid_session_list_cursor("malformed session list cursor")); } - let requested_cwd = cwd.map(std::path::Path::to_path_buf); - if token.cwd != requested_cwd { + let expected_filter_hash = session_list_filter_hash(cwd, session_types)?; + if token.filter_hash != expected_filter_hash { return Err(invalid_session_list_cursor( - "session list cursor does not match cwd", + "session list cursor does not match filters", )); } @@ -304,11 +335,12 @@ fn decode_session_list_cursor( fn encode_session_list_cursor( cursor: &SessionListCursor, cwd: Option<&std::path::Path>, + session_types: &[SessionType], ) -> Result { let token = SessionListCursorToken { updated_at: cursor.updated_at, session_id: cursor.session_id.clone(), - cwd: cwd.map(std::path::Path::to_path_buf), + filter_hash: session_list_filter_hash(cwd, session_types)?, }; let bytes = serde_json::to_vec(&token).internal_err_ctx("Failed to encode session list cursor")?; @@ -3380,13 +3412,14 @@ impl GooseAcpAgent { } let cwd = req.cwd.as_deref(); - let cursor = decode_session_list_cursor(req.cursor.as_deref(), cwd)?; + let cursor = + decode_session_list_cursor(req.cursor.as_deref(), cwd, &ACP_SESSION_LIST_TYPES)?; // ACP clients see their own (Acp) sessions plus legacy User/Scheduled ones. let page = self .session_manager .list_nonempty_sessions_by_types_paged( - &[SessionType::User, SessionType::Scheduled, SessionType::Acp], + &ACP_SESSION_LIST_TYPES, cwd, cursor.as_ref(), SESSION_LIST_PAGE_SIZE, @@ -3407,7 +3440,7 @@ impl GooseAcpAgent { let next_cursor = page .next_cursor .as_ref() - .map(|cursor| encode_session_list_cursor(cursor, cwd)) + .map(|cursor| encode_session_list_cursor(cursor, cwd, &ACP_SESSION_LIST_TYPES)) .transpose()?; Ok(ListSessionsResponse::new(session_infos).next_cursor(next_cursor)) } diff --git a/crates/goose/src/session/session_manager.rs b/crates/goose/src/session/session_manager.rs index 67229d51cfb5..6090a3041283 100644 --- a/crates/goose/src/session/session_manager.rs +++ b/crates/goose/src/session/session_manager.rs @@ -278,6 +278,15 @@ pub(crate) struct SessionListPage { pub(crate) next_cursor: Option, } +#[derive(Debug, Default)] +struct SessionListQuery<'a> { + types: Option<&'a [SessionType]>, + working_dir: Option<&'a Path>, + cursor: Option<&'a SessionListCursor>, + limit: Option, + require_messages: bool, +} + #[derive(Debug, Clone)] pub struct SessionNameUpdate { pub session_id: String, @@ -1451,17 +1460,46 @@ impl SessionStorage { Self::replace_conversation_inner(pool, session_id, conversation).await } - async fn list_sessions_by_types(&self, types: Option<&[SessionType]>) -> Result> { - let (where_clause, binds): (String, Vec) = match types { - Some(t) if !t.is_empty() => { - let placeholders: String = t.iter().map(|_| "?").collect::>().join(", "); - ( - format!("WHERE s.session_type IN ({})", placeholders), - t.iter().map(|t| t.to_string()).collect(), - ) - } - Some(_) => return Ok(Vec::new()), - None => (String::new(), Vec::new()), + async fn list_sessions_matching(&self, options: SessionListQuery<'_>) -> Result> { + if matches!(options.types, Some(types) if types.is_empty()) { + return Ok(Vec::new()); + } + + let mut where_clauses = Vec::new(); + if let Some(types) = options.types { + let placeholders = types.iter().map(|_| "?").collect::>().join(", "); + where_clauses.push(format!("s.session_type IN ({})", placeholders)); + } + if options.working_dir.is_some() { + where_clauses.push("s.working_dir = ?".to_string()); + } + if options.cursor.is_some() { + where_clauses.push( + "(datetime(s.updated_at) < datetime(?) \ + OR (datetime(s.updated_at) = datetime(?) AND s.id < ?))" + .to_string(), + ); + } + + let where_clause = if where_clauses.is_empty() { + String::new() + } else { + format!("WHERE {}", where_clauses.join(" AND ")) + }; + let message_join = if options.require_messages { + "JOIN messages m ON s.id = m.session_id" + } else { + "LEFT JOIN messages m ON s.id = m.session_id" + }; + let order_by = if options.cursor.is_some() || options.limit.is_some() { + "ORDER BY datetime(s.updated_at) DESC, s.id DESC" + } else { + "ORDER BY s.updated_at DESC" + }; + let limit_clause = if options.limit.is_some() { + "LIMIT ?" + } else { + "" }; let query = format!( @@ -1474,23 +1512,47 @@ impl SessionStorage { s.archived_at, s.project_id, COUNT(m.id) as message_count FROM sessions s - LEFT JOIN messages m ON s.id = m.session_id + {} {} GROUP BY s.id - ORDER BY s.updated_at DESC + {} + {} "#, - where_clause + message_join, where_clause, order_by, limit_clause ); let mut q = sqlx::query_as::<_, Session>(&query); - for b in &binds { - q = q.bind(b); + if let Some(types) = options.types { + for session_type in types { + q = q.bind(session_type.to_string()); + } + } + if let Some(working_dir) = options.working_dir { + q = q.bind(working_dir.to_string_lossy().to_string()); + } + if let Some(cursor) = options.cursor { + let updated_at = cursor.updated_at.to_rfc3339(); + // Normalize mixed SQLite CURRENT_TIMESTAMP and RFC3339 stored values. + q = q.bind(updated_at.clone()); + q = q.bind(updated_at); + q = q.bind(&cursor.session_id); + } + if let Some(limit) = options.limit { + q = q.bind(limit as i64); } let pool = self.pool().await?; q.fetch_all(pool).await.map_err(Into::into) } + async fn list_sessions_by_types(&self, types: Option<&[SessionType]>) -> Result> { + self.list_sessions_matching(SessionListQuery { + types, + ..Default::default() + }) + .await + } + async fn list_nonempty_sessions_by_types_paged( &self, types: &[SessionType], @@ -1505,57 +1567,15 @@ impl SessionStorage { }); } - let type_placeholders = types.iter().map(|_| "?").collect::>().join(", "); - let mut where_clause = format!("s.session_type IN ({})", type_placeholders); - if working_dir.is_some() { - where_clause.push_str(" AND s.working_dir = ?"); - } - if cursor.is_some() { - where_clause.push_str( - " AND (datetime(s.updated_at) < datetime(?) \ - OR (datetime(s.updated_at) = datetime(?) AND s.id < ?))", - ); - } - - let query = format!( - r#" - SELECT s.id, s.working_dir, s.name, s.description, s.user_set_name, s.session_type, s.created_at, s.updated_at, s.extension_data, - s.total_tokens, s.input_tokens, s.output_tokens, - s.accumulated_total_tokens, s.accumulated_input_tokens, s.accumulated_output_tokens, - s.schedule_id, s.recipe_json, s.user_recipe_values_json, - s.provider_name, s.model_config_json, s.goose_mode, - s.archived_at, s.project_id, - COUNT(m.id) as message_count - FROM sessions s - JOIN messages m ON s.id = m.session_id - WHERE {} - GROUP BY s.id - ORDER BY datetime(s.updated_at) DESC, s.id DESC - LIMIT ? - "#, - where_clause - ); - - let mut q = sqlx::query_as::<_, Session>(&query); - for session_type in types { - q = q.bind(session_type.to_string()); - } - - if let Some(working_dir) = working_dir { - q = q.bind(working_dir.to_string_lossy().to_string()); - } - - if let Some(cursor) = cursor { - let updated_at = cursor.updated_at.to_rfc3339(); - // Normalize mixed SQLite CURRENT_TIMESTAMP and RFC3339 stored values. - q = q.bind(updated_at.clone()); - q = q.bind(updated_at); - q = q.bind(&cursor.session_id); - } - q = q.bind((page_size + 1) as i64); - - let pool = self.pool().await?; - let mut sessions = q.fetch_all(pool).await?; + let mut sessions = self + .list_sessions_matching(SessionListQuery { + types: Some(types), + working_dir, + cursor, + limit: Some(page_size + 1), + require_messages: true, + }) + .await?; let has_next_page = sessions.len() > page_size; let next_cursor = if has_next_page { let anchor = &sessions[page_size - 1]; @@ -1899,45 +1919,60 @@ mod tests { const NUM_CONCURRENT_SESSIONS: i32 = 10; - async fn insert_session_for_list( + async fn create_session_for_list( sm: &SessionManager, - id: &str, working_dir: &str, + has_message: bool, + ) -> String { + let session = sm + .create_session( + PathBuf::from(working_dir), + format!("Session in {working_dir}"), + SessionType::User, + GooseMode::default(), + ) + .await + .unwrap(); + + if has_message { + sm.add_message(&session.id, &Message::user().with_text("message")) + .await + .unwrap(); + } + + session.id + } + + async fn set_sessions_updated_at( + sm: &SessionManager, + session_ids: &[String], updated_at: &str, - message_count: usize, ) { let pool = sm.storage().pool().await.unwrap(); let updated_at = chrono::DateTime::parse_from_rfc3339(updated_at).unwrap(); let timestamp = updated_at.format("%Y-%m-%d %H:%M:%S").to_string(); - sqlx::query( - "INSERT INTO sessions ( - id, name, user_set_name, session_type, working_dir, created_at, updated_at, extension_data, goose_mode - ) VALUES (?, ?, FALSE, ?, ?, ?, ?, '{}', ?)", - ) - .bind(id) - .bind(format!("Session {id}")) - .bind(SessionType::User.to_string()) - .bind(working_dir) - .bind(×tamp) - .bind(×tamp) - .bind(GooseMode::default().to_string()) - .execute(pool) - .await - .unwrap(); + for session_id in session_ids { + sqlx::query("UPDATE sessions SET updated_at = ? WHERE id = ?") + .bind(×tamp) + .bind(session_id) + .execute(pool) + .await + .unwrap(); + } + } - for index in 0..message_count { - sqlx::query( - "INSERT INTO messages (message_id, session_id, role, content_json, created_timestamp, metadata_json) - VALUES (?, ?, 'user', '[]', ?, '{}')", - ) - .bind(format!("{id}_{index}")) - .bind(id) - .bind(index as i64) - .execute(pool) - .await - .unwrap(); + async fn expected_session_list_ids(sm: &SessionManager, session_ids: &[String]) -> Vec { + let mut sessions = Vec::new(); + for session_id in session_ids { + sessions.push(sm.get_session(session_id, false).await.unwrap()); } + sessions.sort_by(|a, b| { + b.updated_at + .cmp(&a.updated_at) + .then_with(|| b.id.cmp(&a.id)) + }); + sessions.into_iter().map(|session| session.id).collect() } async fn assert_session_list_page( @@ -1945,7 +1980,7 @@ mod tests { cursor: Option<&SessionListCursor>, working_dir: Option<&str>, page_size: usize, - expected_ids: &[&str], + expected_ids: &[String], expected_next_cursor: bool, ) -> Option { let page = sm @@ -1957,12 +1992,12 @@ mod tests { ) .await .unwrap(); - let ids: Vec<_> = page + let ids = page .sessions .iter() - .map(|session| session.id.as_str()) - .collect(); - assert_eq!(ids, expected_ids); + .map(|session| session.id.clone()) + .collect::>(); + assert_eq!(ids.as_slice(), expected_ids); assert_eq!(page.next_cursor.is_some(), expected_next_cursor); page.next_cursor } @@ -2063,57 +2098,61 @@ mod tests { async fn test_session_list_paged_first_second_and_final_page() { let temp_dir = TempDir::new().unwrap(); let sm = SessionManager::new(temp_dir.path().to_path_buf()); - for index in 1..=5 { - insert_session_for_list( - &sm, - &format!("s{index:03}"), - "/tmp/session-list", - &format!("2024-01-01T00:00:0{index}Z"), - 1, - ) - .await; + let mut expected_ids = Vec::new(); + for _ in 0..5 { + expected_ids.push(create_session_for_list(&sm, "/tmp/session-list", true).await); } + let expected_ids = expected_session_list_ids(&sm, &expected_ids).await; - let cursor = assert_session_list_page(&sm, None, None, 2, &["s005", "s004"], true).await; + let cursor = assert_session_list_page(&sm, None, None, 2, &expected_ids[0..2], true).await; let cursor = - assert_session_list_page(&sm, cursor.as_ref(), None, 2, &["s003", "s002"], true).await; - assert_session_list_page(&sm, cursor.as_ref(), None, 2, &["s001"], false).await; + assert_session_list_page(&sm, cursor.as_ref(), None, 2, &expected_ids[2..4], true) + .await; + assert_session_list_page(&sm, cursor.as_ref(), None, 2, &expected_ids[4..5], false).await; } #[tokio::test] async fn test_session_list_paged_uses_id_tiebreaker_for_duplicate_updated_at() { let temp_dir = TempDir::new().unwrap(); let sm = SessionManager::new(temp_dir.path().to_path_buf()); - for id in ["s001", "s002", "s003"] { - insert_session_for_list(&sm, id, "/tmp/session-list", "2024-01-01T00:00:00Z", 1).await; + let mut expected_ids = Vec::new(); + for _ in 0..3 { + expected_ids.push(create_session_for_list(&sm, "/tmp/session-list", true).await); } + set_sessions_updated_at(&sm, &expected_ids, "2024-01-01T00:00:00Z").await; + let expected_ids = expected_session_list_ids(&sm, &expected_ids).await; - let cursor = assert_session_list_page(&sm, None, None, 2, &["s003", "s002"], true).await; - assert_session_list_page(&sm, cursor.as_ref(), None, 2, &["s001"], false).await; + let cursor = assert_session_list_page(&sm, None, None, 2, &expected_ids[0..2], true).await; + assert_session_list_page(&sm, cursor.as_ref(), None, 2, &expected_ids[2..3], false).await; } #[tokio::test] async fn test_session_list_paged_filters_empty_and_cwd_before_pagination() { let temp_dir = TempDir::new().unwrap(); let sm = SessionManager::new(temp_dir.path().to_path_buf()); - for (id, dir, updated_at, messages) in [ - ("s004", "/tmp/session-list/a", "2024-01-01T00:00:04Z", 1), - ("s003", "/tmp/session-list/a", "2024-01-01T00:00:03Z", 0), - ("s002", "/tmp/session-list/b", "2024-01-01T00:00:02Z", 1), - ("s001", "/tmp/session-list/a", "2024-01-01T00:00:01Z", 1), - ] { - insert_session_for_list(&sm, id, dir, updated_at, messages).await; - } - - let cursor = - assert_session_list_page(&sm, None, Some("/tmp/session-list/a"), 1, &["s004"], true) - .await; + let expected_ids = vec![ + create_session_for_list(&sm, "/tmp/session-list/a", true).await, + create_session_for_list(&sm, "/tmp/session-list/a", true).await, + ]; + create_session_for_list(&sm, "/tmp/session-list/a", false).await; + create_session_for_list(&sm, "/tmp/session-list/b", true).await; + let expected_ids = expected_session_list_ids(&sm, &expected_ids).await; + + let cursor = assert_session_list_page( + &sm, + None, + Some("/tmp/session-list/a"), + 1, + &expected_ids[0..1], + true, + ) + .await; assert_session_list_page( &sm, cursor.as_ref(), Some("/tmp/session-list/a"), 1, - &["s001"], + &expected_ids[1..2], false, ) .await;