Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 71 additions & 8 deletions crates/goose/src/acp/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ use rmcp::model::{
use serde::Deserialize;
use std::collections::{HashMap, HashSet};
use std::panic::AssertUnwindSafe;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use strum::{EnumMessage, VariantNames};
use tokio::sync::{Mutex, OnceCell};
Expand All @@ -86,6 +87,7 @@ pub type AcpProviderFactory = Arc<
String,
crate::model::ModelConfig,
Vec<ExtensionConfig>,
Option<PathBuf>,
) -> BoxFuture<'static, Result<Arc<dyn Provider>>>
+ Send
+ Sync,
Expand Down Expand Up @@ -1044,6 +1046,20 @@ fn build_usage_update(session: &Session, context_limit: usize) -> UsageUpdate {
UsageUpdate::new(used, context_limit as u64)
}

fn validate_absolute_cwd(cwd: &Path) -> Result<(), agent_client_protocol::Error> {
if !cwd.is_absolute() {
return Err(
agent_client_protocol::Error::invalid_params().data("cwd must be an absolute path")
);
}

if !cwd.exists() || !cwd.is_dir() {
return Err(agent_client_protocol::Error::invalid_params().data("invalid directory path"));
}

Ok(())
}

impl GooseAcpAgent {
pub fn permission_manager(&self) -> Arc<PermissionManager> {
Arc::clone(&self.permission_manager)
Expand Down Expand Up @@ -1093,8 +1109,15 @@ impl GooseAcpAgent {
provider_name: &str,
model_config: crate::model::ModelConfig,
extensions: Vec<ExtensionConfig>,
working_dir: Option<PathBuf>,
) -> Result<Arc<dyn Provider>> {
(self.provider_factory)(provider_name.to_string(), model_config, extensions).await
(self.provider_factory)(
provider_name.to_string(),
model_config,
extensions,
working_dir,
)
.await
}

async fn prepare_session_init_config(
Expand Down Expand Up @@ -1131,7 +1154,12 @@ impl GooseAcpAgent {
);
Config::global().invalidate_secrets_cache();
match self
.create_provider(provider_name, model_config.clone(), ext_state)
.create_provider(
provider_name,
model_config.clone(),
ext_state,
Some(goose_session.working_dir.clone()),
)
.await
{
Ok(provider) => {
Expand Down Expand Up @@ -1348,9 +1376,14 @@ impl GooseAcpAgent {
);
let provider = match prebuilt_provider {
Some(provider) => provider,
None => provider_factory(provider_name.to_string(), model_config, ext_state)
.await
.map_err(|e| e.to_string())?,
None => provider_factory(
provider_name.to_string(),
model_config,
ext_state,
Some(goose_session.working_dir.clone()),
)
.await
.map_err(|e| e.to_string())?,
};
agent
.update_provider(provider.clone(), &goose_session.id)
Expand Down Expand Up @@ -1440,15 +1473,17 @@ impl GooseAcpAgent {
}

let ext_manager = &agent.extension_manager;
let working_dir = goose_session.working_dir.clone();
let extension_futures = extensions
.into_iter()
.map(|ext| {
let ext_manager = Arc::clone(ext_manager);
let sid_inner = sid_str.clone();
let working_dir = working_dir.clone();
async move {
let name = ext.name().to_string();
if let Err(e) = ext_manager
.add_extension(ext, None, None, sid_inner.as_deref())
.add_extension(ext, Some(working_dir), None, sid_inner.as_deref())
.await
{
warn!(extension = %name, error = %e, "extension load failed");
Expand Down Expand Up @@ -2412,6 +2447,7 @@ impl GooseAcpAgent {
) -> Result<NewSessionResponse, agent_client_protocol::Error> {
debug!(?args, "new session request");
let t_start = std::time::Instant::now();
validate_absolute_cwd(&args.cwd)?;

let requested_provider = args
.meta
Expand Down Expand Up @@ -2664,6 +2700,7 @@ impl GooseAcpAgent {
args: LoadSessionRequest,
) -> Result<LoadSessionResponse, agent_client_protocol::Error> {
debug!(?args, "load session request");
validate_absolute_cwd(&args.cwd)?;

let session_id = args.session_id.0.to_string();
let sid = sid_short(&session_id);
Expand Down Expand Up @@ -2835,6 +2872,11 @@ impl GooseAcpAgent {
.apply()
.await
.internal_err_ctx("Failed to update session working directory")?;
let goose_session = self
.session_manager
.get_session(&session_id, false)
.await
.internal_err_ctx("Failed to reload session")?;

// Register the session with a Loading handle.
let (agent_tx, agent_rx) = tokio::sync::watch::channel::<AgentSetupSignal>(None);
Expand Down Expand Up @@ -3137,8 +3179,18 @@ impl GooseAcpAgent {
let model_config = crate::model::ModelConfig::new(model_id)
.invalid_params_err_ctx("Invalid model config")?
.with_canonical_limits(&provider_name);
let session = self
.session_manager
.get_session(session_id, false)
.await
.internal_err_ctx("Failed to get session")?;
let provider = self
.create_provider(&provider_name, model_config, extensions)
.create_provider(
&provider_name,
model_config,
extensions,
Some(session.working_dir),
)
.await
.internal_err_ctx("Failed to create provider")?;
agent
Expand Down Expand Up @@ -3264,8 +3316,18 @@ impl GooseAcpAgent {

let extensions =
EnabledExtensionsState::for_session(&self.session_manager, session_id, &config).await;
let session = self
.session_manager
.get_session(session_id, false)
.await
.internal_err_ctx("Failed to get session")?;
let new_provider = self
.create_provider(&resolved_provider_name, model_config, extensions)
.create_provider(
&resolved_provider_name,
model_config,
extensions,
Some(session.working_dir),
)
.await
.internal_err_ctx("Failed to create provider")?;
agent
Expand Down Expand Up @@ -3332,6 +3394,7 @@ impl GooseAcpAgent {
cx: &ConnectionTo<Client>,
args: ForkSessionRequest,
) -> Result<ForkSessionResponse, agent_client_protocol::Error> {
validate_absolute_cwd(&args.cwd)?;
let source_session_id = &*args.session_id.0;

let new_session = self
Expand Down
2 changes: 1 addition & 1 deletion crates/goose/src/acp/server/providers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,7 @@ impl GooseAcpAgent {
let model_config =
crate::model::ModelConfig::new(&metadata.metadata().default_model)?
.with_canonical_limits(&provider_id);
provider_factory(provider_id.clone(), model_config, Vec::new()).await
provider_factory(provider_id.clone(), model_config, Vec::new(), None).await
})
.catch_unwind()
.await;
Expand Down
6 changes: 1 addition & 5 deletions crates/goose/src/acp/server/sessions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,7 @@ impl GooseAcpAgent {
.data("working directory cannot be empty"));
}
let path = std::path::PathBuf::from(&working_dir);
if !path.exists() || !path.is_dir() {
return Err(
agent_client_protocol::Error::invalid_params().data("invalid directory path")
);
}
validate_absolute_cwd(&path)?;
let session_id = &req.session_id;
self.session_manager
.update(session_id)
Expand Down
22 changes: 18 additions & 4 deletions crates/goose/src/acp/server_factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,26 @@ impl AcpServer {
.unwrap_or(crate::config::GooseMode::Auto);
let disable_session_naming = config.get_goose_disable_session_naming().unwrap_or(false);

let provider_factory: AcpProviderFactory =
Arc::new(move |provider_name, model_config, extensions| {
let provider_factory: AcpProviderFactory = Arc::new(
move |provider_name, model_config, extensions, working_dir| {
Box::pin(async move {
crate::providers::create(&provider_name, model_config, extensions).await
match working_dir {
Some(working_dir) => {
crate::providers::create_with_working_dir(
&provider_name,
model_config,
extensions,
working_dir,
)
.await
}
None => {
crate::providers::create(&provider_name, model_config, extensions).await
}
}
})
});
},
);

let agent = GooseAcpAgent::new(GooseAcpAgentOptions {
provider_factory,
Expand Down
12 changes: 10 additions & 2 deletions crates/goose/src/providers/amp_acp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::config::search_path::SearchPaths;
use crate::config::{Config, GooseMode};
use crate::model::ModelConfig;
use crate::providers::acp_tooling::{acp_adapter_installed, acp_inventory_identity};
use crate::providers::base::{ProviderDef, ProviderMetadata};
use crate::providers::base::{current_working_dir, ProviderDef, ProviderMetadata};
use crate::providers::inventory::InventoryIdentityInput;

const AMP_ACP_PROVIDER_NAME: &str = "amp-acp";
Expand Down Expand Up @@ -45,6 +45,14 @@ impl ProviderDef for AmpAcpProvider {
fn from_env(
model: ModelConfig,
extensions: Vec<crate::config::ExtensionConfig>,
) -> BoxFuture<'static, Result<AcpProvider>> {
Self::from_env_with_working_dir(model, extensions, current_working_dir())
}

fn from_env_with_working_dir(
model: ModelConfig,
extensions: Vec<crate::config::ExtensionConfig>,
working_dir: PathBuf,
) -> BoxFuture<'static, Result<AcpProvider>> {
Box::pin(async move {
let config = Config::global();
Expand All @@ -65,7 +73,7 @@ impl ProviderDef for AmpAcpProvider {
args: vec![],
env: vec![],
env_remove: vec![],
work_dir: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
work_dir: working_dir,
mcp_servers: extension_configs_to_mcp_servers(&extensions),
session_mode_id: Some(mode_mapping[&goose_mode].clone()),
mode_mapping,
Expand Down
18 changes: 18 additions & 0 deletions crates/goose/src/providers/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use utoipa::ToSchema;
use once_cell::sync::Lazy;
use regex::Regex;
use std::ops::{Add, AddAssign};
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::LazyLock;
use std::sync::Mutex;
Expand Down Expand Up @@ -766,6 +767,10 @@ impl Usage {
}
}

pub(crate) fn current_working_dir() -> PathBuf {
std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."))
}

pub trait ProviderDef: Send + Sync {
type Provider: Provider + 'static;

Expand All @@ -780,6 +785,19 @@ pub trait ProviderDef: Send + Sync {
where
Self: Sized;

fn from_env_with_working_dir(
model: ModelConfig,
extensions: Vec<ExtensionConfig>,
_working_dir: PathBuf,
) -> BoxFuture<'static, Result<Self::Provider>>
where
Self: Sized,
{
// ACP subprocess providers must override this so session cwd is preserved.
// Non-subprocess providers can rely on the default because cwd is irrelevant.
Self::from_env(model, extensions)
}

fn supports_inventory_refresh() -> bool
where
Self: Sized,
Expand Down
12 changes: 10 additions & 2 deletions crates/goose/src/providers/claude_acp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::config::search_path::SearchPaths;
use crate::config::{Config, GooseMode};
use crate::model::ModelConfig;
use crate::providers::acp_tooling::{acp_adapter_installed, acp_inventory_identity};
use crate::providers::base::{ProviderDef, ProviderMetadata};
use crate::providers::base::{current_working_dir, ProviderDef, ProviderMetadata};
use crate::providers::inventory::InventoryIdentityInput;

const CLAUDE_ACP_PROVIDER_NAME: &str = "claude-acp";
Expand Down Expand Up @@ -43,6 +43,14 @@ impl ProviderDef for ClaudeAcpProvider {
fn from_env(
model: ModelConfig,
extensions: Vec<crate::config::ExtensionConfig>,
) -> BoxFuture<'static, Result<AcpProvider>> {
Self::from_env_with_working_dir(model, extensions, current_working_dir())
}

fn from_env_with_working_dir(
model: ModelConfig,
extensions: Vec<crate::config::ExtensionConfig>,
working_dir: PathBuf,
) -> BoxFuture<'static, Result<AcpProvider>> {
Box::pin(async move {
let config = Config::global();
Expand All @@ -69,7 +77,7 @@ impl ProviderDef for ClaudeAcpProvider {
env: vec![],
// Prevent nested-session detection in claude-agent-acp (wraps Claude Code)
env_remove: vec!["CLAUDECODE".to_string()],
work_dir: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
work_dir: working_dir,
mcp_servers: extension_configs_to_mcp_servers(&extensions),
session_mode_id: Some(mode_mapping[&goose_mode].clone()),
mode_mapping,
Expand Down
13 changes: 10 additions & 3 deletions crates/goose/src/providers/codex_acp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::config::search_path::SearchPaths;
use crate::config::{Config, GooseMode};
use crate::model::ModelConfig;
use crate::providers::acp_tooling::{acp_adapter_installed, acp_inventory_identity};
use crate::providers::base::{ProviderDef, ProviderMetadata};
use crate::providers::base::{current_working_dir, ProviderDef, ProviderMetadata};
use crate::providers::inventory::InventoryIdentityInput;

const CODEX_ACP_PROVIDER_NAME: &str = "codex-acp";
Expand Down Expand Up @@ -42,14 +42,21 @@ impl ProviderDef for CodexAcpProvider {
fn from_env(
model: ModelConfig,
extensions: Vec<crate::config::ExtensionConfig>,
) -> BoxFuture<'static, Result<AcpProvider>> {
Self::from_env_with_working_dir(model, extensions, current_working_dir())
}

fn from_env_with_working_dir(
model: ModelConfig,
extensions: Vec<crate::config::ExtensionConfig>,
working_dir: PathBuf,
) -> BoxFuture<'static, Result<AcpProvider>> {
Box::pin(async move {
let config = Config::global();
// with_npm() includes npm global bin dir (desktop app PATH may not)
let resolved_command = SearchPaths::builder()
.with_npm()
.resolve(CODEX_ACP_PROVIDER_NAME)?;
let work_dir = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
let env = vec![];
let goose_mode = config.get_goose_mode().unwrap_or(GooseMode::Auto);
let mcp_servers = extension_configs_to_mcp_servers(&extensions);
Expand Down Expand Up @@ -88,7 +95,7 @@ impl ProviderDef for CodexAcpProvider {
args,
env,
env_remove: vec![],
work_dir,
work_dir: working_dir,
mcp_servers,
// Disabled until https://github.com/zed-industries/codex-acp/issues/179 is fixed.
session_mode_id: None,
Expand Down
Loading