Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 162 additions & 2 deletions crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,36 @@ const COMPACTION_THINKING_TEXT: &str = "goose is compacting the conversation..."
const DEFAULT_FRONTEND_INSTRUCTIONS: &str =
"The following tools are provided directly by the frontend and will be executed by the frontend when called.";

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ToolCategory {
Shell,
Read,
Write,
Other,
}

fn categorize_tool(tool_name: &str) -> ToolCategory {
let local = tool_name.rsplit("__").next().unwrap_or(tool_name);
match local {
"shell" | "bash" | "exec" | "run" => ToolCategory::Shell,
"read" | "view" | "cat" | "read_file" => ToolCategory::Read,
"write" | "edit" | "patch" | "write_file" | "edit_file" => ToolCategory::Write,
_ => ToolCategory::Other,
}
}

fn extract_string_arg(input: &Value, keys: &[&str]) -> Option<String> {
let obj = input.as_object()?;
for k in keys {
if let Some(s) = obj.get(*k).and_then(|v| v.as_str()) {
if !s.is_empty() {
return Some(s.to_string());
}
}
}
None
}

/// Context needed for the reply function
pub struct ReplyContext {
pub conversation: Conversation,
Expand Down Expand Up @@ -304,6 +334,65 @@ impl Agent {
.await;
}

async fn emit_pre_tool_extended_hooks(
&self,
tool_name: &str,
tool_input: Option<&Value>,
session: &Session,
) {
let working_dir = session.working_dir.to_string_lossy().to_string();
match categorize_tool(tool_name) {
ToolCategory::Shell => {
if let Some(cmd) = tool_input.and_then(|v| extract_string_arg(v, &["command"])) {
self.emit_with_matcher(
crate::hooks::HookEvent::BeforeShellExecution,
&session.id,
&cmd,
tool_name,
tool_input.cloned(),
&working_dir,
)
.await;
}
}
ToolCategory::Read => {
if let Some(path) =
tool_input.and_then(|v| extract_string_arg(v, &["path", "file", "file_path"]))
{
self.emit_with_matcher(
crate::hooks::HookEvent::BeforeReadFile,
&session.id,
&path,
tool_name,
tool_input.cloned(),
&working_dir,
)
.await;
}
}
ToolCategory::Write | ToolCategory::Other => {}
}
}

async fn emit_with_matcher(
&self,
event: crate::hooks::HookEvent,
session_id: &str,
matcher_context: &str,
tool_name: &str,
tool_input: Option<Value>,
working_dir: &str,
) {
if !self.hook_manager.has_hooks(event) {
return;
}
let mut ctx = crate::hooks::HookContext::new(event, session_id)
.with_tool(tool_name.to_string(), tool_input)
.with_working_dir(working_dir.to_string());
ctx.matcher_context = Some(matcher_context.to_string());
self.hook_manager.emit(event, ctx).await;
}

fn with_post_tool_hook(
&self,
result: ToolCallResult,
Expand All @@ -318,6 +407,7 @@ impl Agent {
.arguments
.as_ref()
.map(|a| serde_json::Value::Object(a.clone()));
let category = categorize_tool(&tool_name);

let fut = async move {
let processed_result =
Expand All @@ -331,11 +421,38 @@ impl Agent {

if hook_manager.has_hooks(event) {
let ctx = crate::hooks::HookContext::new(event, &session_id)
.with_tool(tool_name, tool_input)
.with_working_dir(working_dir);
.with_tool(tool_name.clone(), tool_input.clone())
.with_working_dir(working_dir.clone());
hook_manager.emit(event, ctx).await;
}

if event == crate::hooks::HookEvent::PostToolUse {
let extended = match category {
ToolCategory::Shell => Some((
crate::hooks::HookEvent::AfterShellExecution,
tool_input
.as_ref()
.and_then(|v| extract_string_arg(v, &["command"])),
)),
ToolCategory::Write => Some((
crate::hooks::HookEvent::AfterFileEdit,
tool_input
.as_ref()
.and_then(|v| extract_string_arg(v, &["path", "file", "file_path"])),
)),
_ => None,
};
if let Some((ext_event, Some(matcher))) = extended {
if hook_manager.has_hooks(ext_event) {
let mut ctx = crate::hooks::HookContext::new(ext_event, &session_id)
.with_tool(tool_name, tool_input)
.with_working_dir(working_dir);
ctx.matcher_context = Some(matcher);
hook_manager.emit(ext_event, ctx).await;
}
}
}

processed_result
};

Expand Down Expand Up @@ -796,6 +913,17 @@ impl Agent {
.await;
}

let tool_input_for_extended = tool_call
.arguments
.as_ref()
.map(|a| serde_json::Value::Object(a.clone()));
self.emit_pre_tool_extended_hooks(
&tool_call.name,
tool_input_for_extended.as_ref(),
session,
)
.await;

if tool_call.name == PLATFORM_MANAGE_SCHEDULE_TOOL_NAME {
let arguments = tool_call
.arguments
Expand Down Expand Up @@ -2092,6 +2220,8 @@ impl Agent {
if !last_assistant_text.is_empty() {
tracing::info!(target: "goose::agents::agent", trace_output = last_assistant_text.as_str());
}

self.emit_hook(crate::hooks::HookEvent::Stop, &session_config.id).await;
}.instrument(reply_stream_span));
Ok(inner)
}
Expand Down Expand Up @@ -2754,4 +2884,34 @@ mod tests {

Ok(())
}

#[test]
fn categorize_tool_recognizes_conventional_names() {
assert_eq!(categorize_tool("developer__shell"), ToolCategory::Shell);
assert_eq!(categorize_tool("filesystem__write"), ToolCategory::Write);
assert_eq!(categorize_tool("filesystem__edit"), ToolCategory::Write);
assert_eq!(categorize_tool("filesystem__read"), ToolCategory::Read);
assert_eq!(categorize_tool("filesystem__view"), ToolCategory::Read);
assert_eq!(categorize_tool("filesystem__cat"), ToolCategory::Read);
assert_eq!(categorize_tool("scheduler__list"), ToolCategory::Other);
assert_eq!(categorize_tool("shell"), ToolCategory::Shell);
}

#[test]
fn extract_string_arg_picks_first_present_key() {
let input = serde_json::json!({ "file_path": "/tmp/a.txt", "path": "/tmp/b.txt" });
assert_eq!(
extract_string_arg(&input, &["path", "file", "file_path"]).as_deref(),
Some("/tmp/b.txt")
);
let input = serde_json::json!({ "file_path": "/tmp/a.txt" });
assert_eq!(
extract_string_arg(&input, &["path", "file", "file_path"]).as_deref(),
Some("/tmp/a.txt")
);
let input = serde_json::json!({ "other": 1 });
assert!(extract_string_arg(&input, &["path"]).is_none());
let input = serde_json::json!({ "path": "" });
assert!(extract_string_arg(&input, &["path"]).is_none());
}
}
Loading
Loading