Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 126 additions & 8 deletions crates/goose/src/acp/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::providers::inventory::{
InventoryIdentity, ProviderInventoryEntry, ProviderInventoryService, RefreshJobPlan,
RefreshPlan, RefreshSkipReason,
};
use crate::session::session_manager::SessionType;
use crate::session::session_manager::{SessionListCursor, SessionType};
use crate::session::{EnabledExtensionsState, Session, SessionManager};
use crate::source_roots::SourceRoot;
use crate::utils::sanitize_unicode_tags;
Expand Down Expand Up @@ -51,14 +51,16 @@ use agent_client_protocol::{
Responder,
};
use anyhow::Result;
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
use fs_err as fs;
use futures::future::{BoxFuture, Either};
use futures::stream::{self, StreamExt};
use futures::FutureExt;
use rmcp::model::{
AnnotateAble, CallToolResult, RawContent, RawTextContent, ResourceContents, Role,
};
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::{HashMap, HashSet};
use std::panic::AssertUnwindSafe;
use std::sync::Arc;
Expand Down Expand Up @@ -91,6 +93,10 @@ pub type AcpProviderFactory = Arc<
+ Sync,
>;

const SESSION_LIST_PAGE_SIZE: usize = 50;
const ACP_SESSION_LIST_TYPES: [SessionType; 3] =
[SessionType::User, SessionType::Scheduled, SessionType::Acp];

/// Convenience conversions from any `Display` error into an `agent_client_protocol::Error`.
///
/// Replaces the repetitive `.internal_err()`
Expand Down Expand Up @@ -253,6 +259,94 @@ fn sid_short(id: &str) -> String {
id.chars().take(8).collect()
}

#[derive(Debug, Serialize, Deserialize)]
struct SessionListCursorToken {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably define these in crates/goose-sdk so the client gets things typegened for it on the TS side as well

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

chatted with @alexhancock - this is the internal server cursor token, which we send back up to the client as an opaque token string, so we don't need/want the clients to use this

updated_at: chrono::DateTime<chrono::Utc>,
// Goose stores updated_at with second precision in common write paths, so the
// cursor needs the full (updated_at, id) sort key to avoid skipping tied rows.
session_id: String,
filter_hash: String,
}

#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct SessionListCursorFilters {
cwd: Option<String>,
session_types: Vec<String>,
non_empty: bool,
}

fn invalid_session_list_cursor(message: &'static str) -> agent_client_protocol::Error {
agent_client_protocol::Error::invalid_params().data(message)
}

// bind cursors to the effective filters so they cannot be reused for a different list.
fn session_list_filter_hash(
cwd: Option<&std::path::Path>,
session_types: &[SessionType],
) -> Result<String, agent_client_protocol::Error> {
let mut session_type_names = session_types
.iter()
.map(ToString::to_string)
.collect::<Vec<_>>();
session_type_names.sort();
let filters = SessionListCursorFilters {
cwd: cwd.map(|path| path.to_string_lossy().to_string()),
session_types: session_type_names,
non_empty: true,
};
let bytes =
serde_json::to_vec(&filters).internal_err_ctx("Failed to encode session list filters")?;
Ok(URL_SAFE_NO_PAD.encode(Sha256::digest(bytes)))
}

fn decode_session_list_cursor(
cursor: Option<&str>,
cwd: Option<&std::path::Path>,
session_types: &[SessionType],
) -> Result<Option<SessionListCursor>, agent_client_protocol::Error> {
let Some(cursor) = cursor else {
return Ok(None);
};

let bytes = URL_SAFE_NO_PAD
.decode(cursor)
.map_err(|_| invalid_session_list_cursor("malformed session list cursor"))?;
let token: SessionListCursorToken = serde_json::from_slice(&bytes)
.map_err(|_| invalid_session_list_cursor("malformed session list cursor"))?;

if token.session_id.is_empty() || token.filter_hash.is_empty() {
return Err(invalid_session_list_cursor("malformed session list cursor"));
}

let expected_filter_hash = session_list_filter_hash(cwd, session_types)?;
if token.filter_hash != expected_filter_hash {
return Err(invalid_session_list_cursor(
"session list cursor does not match filters",
));
}

Ok(Some(SessionListCursor {
updated_at: token.updated_at,
session_id: token.session_id,
}))
}

fn encode_session_list_cursor(
cursor: &SessionListCursor,
cwd: Option<&std::path::Path>,
session_types: &[SessionType],
) -> Result<String, agent_client_protocol::Error> {
let token = SessionListCursorToken {
updated_at: cursor.updated_at,
session_id: cursor.session_id.clone(),
filter_hash: session_list_filter_hash(cwd, session_types)?,
};
let bytes =
serde_json::to_vec(&token).internal_err_ctx("Failed to encode session list cursor")?;
Ok(URL_SAFE_NO_PAD.encode(bytes))
}

fn session_meta(session: &Session) -> serde_json::Map<String, serde_json::Value> {
let mut meta = serde_json::Map::new();
meta.insert(
Expand Down Expand Up @@ -3306,16 +3400,35 @@ impl GooseAcpAgent {
Ok(())
}

async fn on_list_sessions(&self) -> Result<ListSessionsResponse, agent_client_protocol::Error> {
async fn on_list_sessions(
&self,
req: ListSessionsRequest,
) -> Result<ListSessionsResponse, agent_client_protocol::Error> {
if let Some(cwd) = req.cwd.as_deref() {
if !cwd.is_absolute() {
return Err(agent_client_protocol::Error::invalid_params()
.data("cwd must be an absolute path"));
}
}

let cwd = req.cwd.as_deref();
let cursor =
decode_session_list_cursor(req.cursor.as_deref(), cwd, &ACP_SESSION_LIST_TYPES)?;

// ACP clients see their own (Acp) sessions plus legacy User/Scheduled ones.
let sessions = self
let page = self
.session_manager
.list_sessions_by_types(&[SessionType::User, SessionType::Scheduled, SessionType::Acp])
.list_nonempty_sessions_by_types_paged(
&ACP_SESSION_LIST_TYPES,
cwd,
cursor.as_ref(),
SESSION_LIST_PAGE_SIZE,
)
.await
.internal_err()?;
let session_infos: Vec<SessionInfo> = sessions
let session_infos: Vec<SessionInfo> = page
.sessions
.into_iter()
.filter(|s| s.message_count > 0)
.map(|s| {
let meta = session_meta(&s);
SessionInfo::new(SessionId::new(s.id), s.working_dir)
Expand All @@ -3324,7 +3437,12 @@ impl GooseAcpAgent {
.meta(meta)
})
.collect();
Ok(ListSessionsResponse::new(session_infos))
let next_cursor = page
.next_cursor
.as_ref()
.map(|cursor| encode_session_list_cursor(cursor, cwd, &ACP_SESSION_LIST_TYPES))
.transpose()?;
Ok(ListSessionsResponse::new(session_infos).next_cursor(next_cursor))
}

async fn on_fork_session(
Expand Down
7 changes: 5 additions & 2 deletions crates/goose/src/acp/server/dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,9 +330,12 @@ impl HandleDispatchFrom<Client> for GooseAcpHandler {
.if_request({
let agent = agent.clone();
let cx = cx.clone();
|_req: ListSessionsRequest, responder: Responder<ListSessionsResponse>| async move {
|req: ListSessionsRequest, responder: Responder<ListSessionsResponse>| async move {
cx.spawn(async move {
responder.respond(agent.on_list_sessions().await?)?;
match agent.on_list_sessions(req).await {
Ok(response) => responder.respond(response)?,
Err(e) => responder.respond_with_error(e)?,
}
Ok(())
})?;
Ok(())
Expand Down
Loading
Loading