diff --git a/Cargo.lock b/Cargo.lock index e3716f8..2baf2f3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -564,6 +564,7 @@ dependencies = [ "libc", "predicates", "serde_json", + "shell-words", "tempfile", "tokio", "tracing", diff --git a/Cargo.toml b/Cargo.toml index dda4866..26cf556 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,7 +35,7 @@ serde_json = "1" shell-words = "1" tempfile = "3" thiserror = "2" -tokio = { version = "1", features = ["fs", "io-util", "macros", "net", "process", "rt-multi-thread", "signal", "sync", "time"] } +tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "signal", "sync", "time"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] } unicode-segmentation = "1.12" diff --git a/crates/embers-cli/Cargo.toml b/crates/embers-cli/Cargo.toml index 3a7fa1b..0a762e0 100644 --- a/crates/embers-cli/Cargo.toml +++ b/crates/embers-cli/Cargo.toml @@ -25,6 +25,7 @@ embers-core = { path = "../embers-core" } embers-protocol = { path = "../embers-protocol" } embers-server = { path = "../embers-server" } libc.workspace = true +shell-words.workspace = true serde_json.workspace = true tokio.workspace = true tracing.workspace = true diff --git a/crates/embers-cli/src/automation.rs b/crates/embers-cli/src/automation.rs new file mode 100644 index 0000000..12c9419 --- /dev/null +++ b/crates/embers-cli/src/automation.rs @@ -0,0 +1,344 @@ +use std::io::Write as _; +use std::path::PathBuf; + +use clap::Parser; +use embers_core::{MuxError, Result, SessionId, new_request_id}; +use embers_protocol::{ + BufferRecord, ClientMessage, ProtocolClient, ServerEnvelope, ServerEvent, SubscribeRequest, + SubscriptionAckResponse, +}; +use serde_json::{Value, json}; +use tokio::io::{AsyncBufReadExt, BufReader}; + +use crate::{Cli, CliConnection, execute_command}; + +pub async fn run(socket: PathBuf, target: Option, all_sessions: bool) -> Result<()> { + let mut request_connection = CliConnection::connect(&socket).await?; + let subscription_session = if all_sessions { + None + } else if let Some(target) = target.as_deref() { + Some( + request_connection + .resolve_session_record(Some(target)) + .await? + .id, + ) + } else { + None + }; + + let mut event_client = if all_sessions || subscription_session.is_some() { + Some( + ProtocolClient::connect(&socket) + .await + .map_err(|error| MuxError::transport(error.to_string()))?, + ) + } else { + None + }; + let subscription_id = if let Some(event_client) = event_client.as_mut() { + Some( + subscribe(event_client, subscription_session) + .await? + .subscription_id, + ) + } else { + None + }; + + emit_record(&json!({ + "kind": "hello", + "mode": "automation", + "subscription": { + "all_sessions": all_sessions, + "session_id": subscription_session.map(u64::from), + "subscription_id": subscription_id, + }, + }))?; + + let mut lines = BufReader::new(tokio::io::stdin()).lines(); + let mut sequence = 0_u64; + + if let Some(event_client) = event_client.as_mut() { + loop { + tokio::select! { + line = lines.next_line() => { + let Some(line) = line? else { + break; + }; + if let Some(record) = handle_command_line(&mut request_connection, &line, &mut sequence).await { + emit_record(&record)?; + } + } + envelope = event_client.recv() => { + match envelope.map_err(|error| MuxError::transport(error.to_string()))? { + Some(ServerEnvelope::Event(event)) => emit_record(&event_record(&event))?, + Some(ServerEnvelope::Response(response)) => { + emit_record(&json!({ + "kind": "protocol_response", + "response": format!("{response:?}"), + }))?; + } + None => break, + } + } + } + } + } else { + while let Some(line) = lines.next_line().await? { + if let Some(record) = + handle_command_line(&mut request_connection, &line, &mut sequence).await + { + emit_record(&record)?; + } + } + } + + Ok(()) +} + +async fn subscribe( + client: &mut ProtocolClient, + session_id: Option, +) -> Result { + let response = client + .request(&ClientMessage::Subscribe(SubscribeRequest { + request_id: new_request_id(), + session_id, + })) + .await + .map_err(|error| MuxError::transport(error.to_string()))?; + match response { + embers_protocol::ServerResponse::SubscriptionAck(response) => Ok(response), + other => Err(MuxError::protocol(format!( + "unexpected response to automation subscribe: {other:?}" + ))), + } +} + +async fn handle_command_line( + connection: &mut CliConnection, + line: &str, + sequence: &mut u64, +) -> Option { + let trimmed = line.trim(); + if trimmed.is_empty() || trimmed.starts_with('#') { + return None; + } + + *sequence += 1; + let seq = *sequence; + let argv = match shell_words::split(trimmed) { + Ok(argv) => argv, + Err(error) => { + return Some(error_record( + seq, + trimmed, + MuxError::invalid_input(error.to_string()), + )); + } + }; + if argv.is_empty() { + return None; + } + + let cli = + match Cli::try_parse_from(std::iter::once("embers").chain(argv.iter().map(String::as_str))) + { + Ok(cli) => cli, + Err(error) => { + return Some(error_record( + seq, + trimmed, + MuxError::invalid_input(error.to_string()), + )); + } + }; + if cli.socket.is_some() || cli.config.is_some() || cli.log.is_some() || cli.verbose != 0 { + return Some(error_record( + seq, + trimmed, + MuxError::invalid_input("automation commands cannot override global CLI flags"), + )); + } + let Some(command) = cli.command else { + return Some(error_record( + seq, + trimmed, + MuxError::invalid_input("automation input requires a subcommand"), + )); + }; + + Some(match execute_command(connection, command).await { + Ok(stdout) => json!({ + "kind": "response", + "seq": seq, + "command": trimmed, + "ok": true, + "stdout": stdout, + }), + Err(error) => error_record(seq, trimmed, error), + }) +} + +fn emit_record(value: &Value) -> Result<()> { + let line = + serde_json::to_string(value).map_err(|error| MuxError::internal(error.to_string()))?; + println!("{line}"); + std::io::stdout().flush()?; + Ok(()) +} + +fn error_record(seq: u64, command: &str, error: MuxError) -> Value { + json!({ + "kind": "response", + "seq": seq, + "command": command, + "ok": false, + "error": { + "code": error_code(&error), + "message": error.to_string(), + }, + }) +} + +fn error_code(error: &MuxError) -> &'static str { + match error { + MuxError::Wire(error) => match error.code { + embers_core::ErrorCode::Unknown => "unknown", + embers_core::ErrorCode::InvalidRequest => "invalid_request", + embers_core::ErrorCode::ProtocolViolation => "protocol_violation", + embers_core::ErrorCode::Transport => "transport", + embers_core::ErrorCode::NotFound => "not_found", + embers_core::ErrorCode::Conflict => "conflict", + embers_core::ErrorCode::Unsupported => "unsupported", + embers_core::ErrorCode::Timeout => "timeout", + embers_core::ErrorCode::Internal => "internal", + }, + MuxError::Io(_) | MuxError::Transport(_) | MuxError::Pty(_) => "transport", + MuxError::Protocol(_) => "protocol_violation", + MuxError::InvalidInput(_) => "invalid_request", + MuxError::NotFound(_) => "not_found", + MuxError::Conflict(_) => "conflict", + MuxError::Unsupported(_) => "unsupported", + MuxError::Timeout(_) => "timeout", + MuxError::Internal(_) => "internal", + } +} + +fn event_record(event: &ServerEvent) -> Value { + json!({ + "kind": "event", + "event": event_value(event), + }) +} + +fn event_value(event: &ServerEvent) -> Value { + match event { + ServerEvent::SessionCreated(event) => json!({ + "type": "session_created", + "session": session_value(&event.session), + }), + ServerEvent::SessionClosed(event) => json!({ + "type": "session_closed", + "session_id": u64::from(event.session_id), + }), + ServerEvent::SessionRenamed(event) => json!({ + "type": "session_renamed", + "session_id": u64::from(event.session_id), + "name": event.name, + }), + ServerEvent::BufferCreated(event) => json!({ + "type": "buffer_created", + "buffer": buffer_value(&event.buffer), + }), + ServerEvent::BufferPipeChanged(event) => json!({ + "type": "buffer_pipe_changed", + "session_id": event.session_id.map(u64::from), + "buffer": buffer_value(&event.buffer), + }), + ServerEvent::BufferDetached(event) => json!({ + "type": "buffer_detached", + "buffer_id": u64::from(event.buffer_id), + }), + ServerEvent::NodeChanged(event) => json!({ + "type": "node_changed", + "session_id": u64::from(event.session_id), + }), + ServerEvent::FloatingChanged(event) => json!({ + "type": "floating_changed", + "session_id": u64::from(event.session_id), + "floating_id": event.floating_id.map(u64::from), + }), + ServerEvent::FocusChanged(event) => json!({ + "type": "focus_changed", + "session_id": u64::from(event.session_id), + "focused_leaf_id": event.focused_leaf_id.map(u64::from), + "focused_floating_id": event.focused_floating_id.map(u64::from), + }), + ServerEvent::RenderInvalidated(event) => json!({ + "type": "render_invalidated", + "buffer_id": u64::from(event.buffer_id), + }), + ServerEvent::ClientChanged(event) => json!({ + "type": "client_changed", + "client": client_value(&event.client), + "previous_session_id": event.previous_session_id.map(u64::from), + }), + } +} + +fn session_value(session: &embers_protocol::SessionRecord) -> Value { + json!({ + "id": u64::from(session.id), + "name": session.name, + "root_node_id": u64::from(session.root_node_id), + "floating_ids": session.floating_ids.iter().copied().map(u64::from).collect::>(), + "focused_leaf_id": session.focused_leaf_id.map(u64::from), + "focused_floating_id": session.focused_floating_id.map(u64::from), + "zoomed_node_id": session.zoomed_node_id.map(u64::from), + }) +} + +fn client_value(client: &embers_protocol::ClientRecord) -> Value { + json!({ + "id": client.id, + "current_session_id": client.current_session_id.map(u64::from), + "subscribed_all_sessions": client.subscribed_all_sessions, + "subscribed_session_ids": client.subscribed_session_ids.iter().copied().map(u64::from).collect::>(), + }) +} + +fn buffer_value(buffer: &BufferRecord) -> Value { + json!({ + "id": u64::from(buffer.id), + "title": buffer.title, + "command": buffer.command, + "cwd": buffer.cwd, + "kind": crate::buffer_kind_label(buffer.kind), + "state": crate::buffer_state_label(buffer.state), + "pid": buffer.pid, + "attachment_node_id": buffer.attachment_node_id.map(u64::from), + "read_only": buffer.read_only, + "helper_source_buffer_id": buffer.helper_source_buffer_id.map(u64::from), + "helper_scope": buffer.helper_scope.map(crate::history_scope_label), + "pty_size": { + "cols": buffer.pty_size.cols, + "rows": buffer.pty_size.rows, + }, + "activity": format!("{:?}", buffer.activity).to_lowercase(), + "last_snapshot_seq": buffer.last_snapshot_seq, + "exit_code": buffer.exit_code, + "pipe": buffer.pipe.as_ref().map(buffer_pipe_value), + }) +} + +fn buffer_pipe_value(pipe: &embers_protocol::BufferPipeRecord) -> Value { + json!({ + "command": pipe.command, + "state": crate::buffer_pipe_state_label(pipe.state), + "pid": pipe.pid, + "exit_code": pipe.exit_code, + "stop_reason": pipe.stop_reason.map(crate::buffer_pipe_stop_reason_label), + }) +} diff --git a/crates/embers-cli/src/lib.rs b/crates/embers-cli/src/lib.rs index 69de9ce..bd878d8 100644 --- a/crates/embers-cli/src/lib.rs +++ b/crates/embers-cli/src/lib.rs @@ -1,3 +1,4 @@ +mod automation; mod interactive; use std::ffi::OsString; @@ -21,10 +22,11 @@ use embers_core::{ }; use embers_protocol::{ BufferHistoryPlacement, BufferHistoryScope, BufferLocation, BufferLocationAttachment, - BufferLocationResponse, BufferRequest, BufferResponse, ClientMessage, ClientRecord, - ClientRequest, FloatingRecord, FloatingRequest, FloatingResponse, NodeBreakDestination, - NodeJoinPlacement, NodeRequest, PingRequest, ProtocolClient, ServerResponse, SessionRecord, - SessionRequest, SessionSnapshot, SnapshotResponse, + BufferLocationResponse, BufferPipeRecord, BufferPipeState, BufferPipeStopReason, BufferRequest, + BufferResponse, ClientMessage, ClientRecord, ClientRequest, FloatingRecord, FloatingRequest, + FloatingResponse, NodeBreakDestination, NodeJoinPlacement, NodeRequest, PingRequest, + ProtocolClient, ServerResponse, SessionRecord, SessionRequest, SessionSnapshot, + SnapshotResponse, }; use embers_server::{SOCKET_ENV_VAR, Server, ServerConfig}; use tokio::time::{Duration, sleep}; @@ -94,6 +96,12 @@ pub enum Command { #[arg(last = true)] command: Vec, }, + Automation { + #[arg(short = 't', long = "target", conflicts_with = "all_sessions")] + target: Option, + #[arg(long, conflicts_with = "target")] + all_sessions: bool, + }, Ping { #[arg(default_value = "phase0")] payload: String, @@ -258,6 +266,10 @@ pub enum BufferCommand { Show { buffer_id: u64, }, + Pipe { + #[command(subcommand)] + command: BufferPipeCommand, + }, Reveal { buffer_id: u64, #[arg(long)] @@ -274,6 +286,25 @@ pub enum BufferCommand { }, } +#[derive(Debug, Subcommand)] +pub enum BufferPipeCommand { + Start { + buffer_id: u64, + #[arg(long)] + cwd: Option, + #[arg(long = "env", value_parser = parse_string_env_arg)] + env: Vec<(String, String)>, + #[arg(last = true)] + command: Vec, + }, + Stop { + buffer_id: u64, + }, + Show { + buffer_id: u64, + }, +} + #[derive(Debug, Subcommand)] pub enum NodeCommand { Zoom { @@ -343,11 +374,17 @@ pub enum JoinPlacementArg { async fn execute(socket: &Path, command: Command) -> Result { let mut connection = CliConnection::connect(socket).await?; + execute_command(&mut connection, command).await +} +async fn execute_command(connection: &mut CliConnection, command: Command) -> Result { match command { - Command::Attach { .. } | Command::Serve | Command::RuntimeKeeper { .. } => Err( - MuxError::internal("interactive commands must be dispatched through run()"), - ), + Command::Attach { .. } + | Command::Serve + | Command::RuntimeKeeper { .. } + | Command::Automation { .. } => Err(MuxError::invalid_input( + "interactive commands must be dispatched through run()", + )), Command::Ping { payload } => { let response = connection .request(ClientMessage::Ping(PingRequest { @@ -488,6 +525,96 @@ async fn execute(socket: &Path, command: Command) -> Result { ensure_matching_buffer_id("buffer show", requested_buffer_id, buffer.id)?; Ok(format_buffer_details(&buffer, &location)) } + BufferCommand::Pipe { command } => match command { + BufferPipeCommand::Start { + buffer_id, + cwd, + env, + command, + } => { + if command.is_empty() { + return Err(MuxError::invalid_input( + "buffer pipe command must not be empty", + )); + } + let expected = BufferId(buffer_id); + let response = connection + .request(ClientMessage::Buffer(BufferRequest::StartPipe { + request_id: new_request_id(), + buffer_id: expected, + command, + cwd, + env: env.into_iter().collect(), + })) + .await?; + match response { + ServerResponse::Buffer(response) => { + ensure_matching_buffer_id( + "buffer pipe start", + expected, + response.buffer.id, + )?; + format_buffer_pipe_details( + response.buffer.id, + response.buffer.pipe.as_ref(), + ) + } + other => Err(MuxError::protocol(format!( + "unexpected response to buffer pipe start: {other:?}" + ))), + } + } + BufferPipeCommand::Stop { buffer_id } => { + let expected = BufferId(buffer_id); + let response = connection + .request(ClientMessage::Buffer(BufferRequest::StopPipe { + request_id: new_request_id(), + buffer_id: expected, + })) + .await?; + match response { + ServerResponse::Buffer(response) => { + ensure_matching_buffer_id( + "buffer pipe stop", + expected, + response.buffer.id, + )?; + format_buffer_pipe_details( + response.buffer.id, + response.buffer.pipe.as_ref(), + ) + } + other => Err(MuxError::protocol(format!( + "unexpected response to buffer pipe stop: {other:?}" + ))), + } + } + BufferPipeCommand::Show { buffer_id } => { + let expected = BufferId(buffer_id); + let response = connection + .request(ClientMessage::Buffer(BufferRequest::Get { + request_id: new_request_id(), + buffer_id: expected, + })) + .await?; + match response { + ServerResponse::Buffer(response) => { + ensure_matching_buffer_id( + "buffer pipe show", + expected, + response.buffer.id, + )?; + format_buffer_pipe_details( + response.buffer.id, + response.buffer.pipe.as_ref(), + ) + } + other => Err(MuxError::protocol(format!( + "unexpected response to buffer pipe show: {other:?}" + ))), + } + } + }, BufferCommand::Reveal { buffer_id, client } => { let requested_buffer_id = BufferId(buffer_id); let response = connection @@ -667,14 +794,14 @@ async fn execute(socket: &Path, command: Command) -> Result { })) .await; let response = rollback_created_buffer_on_error( - &mut connection, + connection, buffer.buffer.id, "new-window", response, ) .await?; let snapshot = rollback_created_buffer_on_error( - &mut connection, + connection, buffer.buffer.id, "new-window", expect_session_snapshot(response, "new-window"), @@ -748,14 +875,14 @@ async fn execute(socket: &Path, command: Command) -> Result { })) .await; let response = rollback_created_buffer_on_error( - &mut connection, + connection, buffer.buffer.id, "split-window", response, ) .await?; let snapshot = rollback_created_buffer_on_error( - &mut connection, + connection, buffer.buffer.id, "split-window", expect_session_snapshot(response, "split-window"), @@ -882,14 +1009,14 @@ async fn execute(socket: &Path, command: Command) -> Result { })) .await; let response = rollback_created_buffer_on_error( - &mut connection, + connection, buffer.buffer.id, "display-popup", response, ) .await?; let popup = rollback_created_buffer_on_error( - &mut connection, + connection, buffer.buffer.id, "display-popup", expect_floating(response, "display-popup"), @@ -951,6 +1078,13 @@ pub async fn run(cli: Cli) -> Result<()> { } interactive::run(socket, target, config).await } + Some(Command::Automation { + target, + all_sessions, + }) => { + ensure_server_process(&socket).await?; + automation::run(socket, target, all_sessions).await + } Some(Command::Serve) => run_server(socket).await, Some(command) => { ensure_server_process(&socket).await?; @@ -1001,6 +1135,16 @@ fn parse_env_arg(value: &str) -> std::result::Result<(String, OsString), String> Ok((key.to_owned(), decode_runtime_keeper_env_value(env_value)?)) } +fn parse_string_env_arg(value: &str) -> std::result::Result<(String, String), String> { + let Some((key, env_value)) = value.split_once('=') else { + return Err("expected KEY=VALUE".to_owned()); + }; + if key.is_empty() { + return Err("environment key must not be empty".to_owned()); + } + Ok((key.to_owned(), env_value.to_owned())) +} + fn decode_runtime_keeper_env_value(value: &str) -> std::result::Result { let Some(encoded) = value.strip_prefix("base64:") else { return Ok(OsString::from(value)); @@ -1734,9 +1878,62 @@ fn format_buffer_details( serde_json::to_string(cwd).expect("buffer working directories serialize to JSON"); lines.push(format!("cwd\t{serialized_cwd}")); } + if let Some(pipe) = &buffer.pipe { + lines.push(format!( + "pipe_state\t{}", + buffer_pipe_state_label(pipe.state) + )); + let serialized_command = + serde_json::to_string(&pipe.command).expect("buffer pipe commands serialize to JSON"); + lines.push(format!("pipe_command\t{serialized_command}")); + if let Some(pid) = pipe.pid { + lines.push(format!("pipe_pid\t{pid}")); + } + if let Some(exit_code) = pipe.exit_code { + lines.push(format!("pipe_exit_code\t{exit_code}")); + } + if let Some(reason) = pipe.stop_reason { + lines.push(format!( + "pipe_stop_reason\t{}", + buffer_pipe_stop_reason_label(reason) + )); + } + } lines.join("\n") } +fn format_buffer_pipe_details( + buffer_id: BufferId, + pipe: Option<&BufferPipeRecord>, +) -> Result { + let Some(pipe) = pipe else { + return Err(MuxError::not_found(format!( + "buffer {buffer_id} has no pipe state" + ))); + }; + let mut lines = vec![ + format!("buffer_id\t{buffer_id}"), + format!("state\t{}", buffer_pipe_state_label(pipe.state)), + format!( + "command\t{}", + serde_json::to_string(&pipe.command).expect("buffer pipe commands serialize to JSON") + ), + ]; + if let Some(pid) = pipe.pid { + lines.push(format!("pid\t{pid}")); + } + if let Some(exit_code) = pipe.exit_code { + lines.push(format!("exit_code\t{exit_code}")); + } + if let Some(reason) = pipe.stop_reason { + lines.push(format!( + "stop_reason\t{}", + buffer_pipe_stop_reason_label(reason) + )); + } + Ok(lines.join("\n")) +} + fn format_buffer_location_line(location: &BufferLocation) -> String { format!( "{}\t{}", @@ -2035,6 +2232,23 @@ fn buffer_kind_label(kind: embers_protocol::BufferRecordKind) -> &'static str { } } +fn buffer_pipe_state_label(state: BufferPipeState) -> &'static str { + match state { + BufferPipeState::Running => "running", + BufferPipeState::Stopped => "stopped", + } +} + +fn buffer_pipe_stop_reason_label(reason: BufferPipeStopReason) -> &'static str { + match reason { + BufferPipeStopReason::Requested => "requested", + BufferPipeStopReason::PipeExited => "pipe_exited", + BufferPipeStopReason::WriteFailed => "write_failed", + BufferPipeStopReason::BufferExited => "buffer_exited", + BufferPipeStopReason::RuntimeInterrupted => "runtime_interrupted", + } +} + fn history_scope_label(scope: BufferHistoryScope) -> &'static str { match scope { BufferHistoryScope::Full => "full", @@ -2359,6 +2573,7 @@ mod tests { activity: ActivityState::Idle, last_snapshot_seq: 0, exit_code: None, + pipe: None, }, &BufferLocation::session(BufferId(7), SessionId(1), NodeId(3)), ); @@ -2394,6 +2609,7 @@ mod tests { activity: ActivityState::Idle, last_snapshot_seq: 0, exit_code: None, + pipe: None, }, &BufferLocation::session(BufferId(8), SessionId(1), NodeId(4)), ); @@ -2437,6 +2653,7 @@ mod tests { activity: ActivityState::Idle, last_snapshot_seq: 0, exit_code: None, + pipe: None, }, &BufferLocation::session(BufferId(9), SessionId(1), NodeId(5)), ); diff --git a/crates/embers-cli/tests/automation.rs b/crates/embers-cli/tests/automation.rs new file mode 100644 index 0000000..e3b3585 --- /dev/null +++ b/crates/embers-cli/tests/automation.rs @@ -0,0 +1,102 @@ +use std::process::Stdio; +use std::time::Duration; + +use embers_test_support::{TestServer, acquire_test_lock, cargo_bin_path}; +use serde_json::Value; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use tokio::process::{ChildStdout, Command}; +use tokio::time::timeout; + +async fn read_record(lines: &mut tokio::io::Lines>) -> Value { + let line = timeout(Duration::from_secs(2), lines.next_line()) + .await + .expect("automation output arrives before timeout") + .expect("automation stdout read succeeds") + .expect("automation output line"); + serde_json::from_str(&line).expect("automation output is valid JSON") +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn automation_mode_emits_hello_response_and_event_records() { + let _guard = acquire_test_lock().await.expect("acquire test lock"); + let server = TestServer::start().await.expect("start server"); + + let mut child = Command::new(cargo_bin_path("embers")) + .arg("--socket") + .arg(server.socket_path()) + .arg("automation") + .arg("--all-sessions") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .expect("spawn automation mode"); + + let mut stdin = child.stdin.take().expect("automation stdin"); + let stdout = child.stdout.take().expect("automation stdout"); + let mut lines = BufReader::new(stdout).lines(); + + let hello = read_record(&mut lines).await; + assert_eq!(hello["kind"], "hello"); + assert_eq!(hello["mode"], "automation"); + assert_eq!(hello["subscription"]["all_sessions"], true); + assert!(hello["subscription"]["subscription_id"].as_u64().is_some()); + + stdin + .write_all(b"new-session alpha\n") + .await + .expect("write automation command"); + stdin.flush().await.expect("flush automation stdin"); + + let mut saw_response = false; + let mut saw_event = false; + loop { + let record = read_record(&mut lines).await; + match record["kind"].as_str() { + Some("response") => { + assert_eq!(record["seq"], 1); + assert_eq!(record["command"], "new-session alpha"); + assert_eq!(record["ok"], true); + assert!( + record["stdout"] + .as_str() + .is_some_and(|stdout| stdout.contains("alpha")), + "response stdout should mention the created session: {record:?}" + ); + saw_response = true; + } + Some("event") => { + if record["event"]["type"] == "session_created" { + assert_eq!(record["event"]["session"]["name"], "alpha"); + saw_event = true; + } + } + other => panic!("unexpected automation record kind: {other:?}"), + } + + if saw_response && saw_event { + break; + } + } + + assert!( + saw_response, + "automation mode did not emit a response record" + ); + assert!( + saw_event, + "automation mode did not emit a session_created event" + ); + + drop(stdin); + let status = timeout(Duration::from_secs(2), child.wait()) + .await + .expect("automation process exits before timeout") + .expect("automation process wait succeeds"); + assert!( + status.success(), + "automation mode exited unsuccessfully: {status}" + ); + + server.shutdown().await.expect("shutdown server"); +} diff --git a/crates/embers-cli/tests/integration.rs b/crates/embers-cli/tests/integration.rs index 0200199..061c373 100644 --- a/crates/embers-cli/tests/integration.rs +++ b/crates/embers-cli/tests/integration.rs @@ -1,3 +1,4 @@ +mod automation; mod interactive; mod panes; mod ping; diff --git a/crates/embers-cli/tests/panes.rs b/crates/embers-cli/tests/panes.rs index 24188e5..d925d61 100644 --- a/crates/embers-cli/tests/panes.rs +++ b/crates/embers-cli/tests/panes.rs @@ -1,11 +1,31 @@ +use std::fs; use std::time::Duration; use embers_core::{ErrorCode, RequestId}; use embers_protocol::{BufferRequest, ClientMessage, InputRequest, ServerResponse}; use embers_test_support::{TestConnection, TestServer, acquire_test_lock}; +use predicates::prelude::*; +use tempfile::tempdir; use tokio::time::sleep; -use crate::support::{run_cli, session_snapshot_by_name, stdout}; +use crate::support::{cli_command, run_cli, session_snapshot_by_name, stdout}; + +async fn wait_for_file_contains(path: &std::path::Path, needle: &str) { + let deadline = tokio::time::Instant::now() + Duration::from_secs(3); + loop { + if let Ok(contents) = fs::read_to_string(path) + && contents.contains(needle) + { + return; + } + assert!( + tokio::time::Instant::now() < deadline, + "timed out waiting for file {:?} to contain {needle:?}", + path + ); + sleep(Duration::from_millis(25)).await; + } +} #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn pane_commands_round_trip_through_cli() { @@ -379,6 +399,118 @@ async fn buffer_show_and_history_open_helper_buffers() { server.shutdown().await.expect("shutdown server"); } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn buffer_pipe_start_show_stop_round_trips_and_rejects_helper_buffers() { + let _guard = acquire_test_lock().await.expect("acquire test lock"); + let server = TestServer::start().await.expect("start server"); + let tempdir = tempdir().expect("tempdir"); + let pipe_path = tempdir.path().join("pipe-output.txt"); + + run_cli(&server, ["new-session", "alpha"]); + run_cli( + &server, + [ + "new-window", + "-t", + "alpha", + "--title", + "work", + "--", + "/bin/sh", + ], + ); + + let mut connection = TestConnection::connect(server.socket_path()) + .await + .expect("connect protocol client"); + let snapshot = session_snapshot_by_name(&mut connection, "alpha").await; + let leaf = snapshot + .session + .focused_leaf_id + .expect("focused pane exists"); + let buffer_id = snapshot + .nodes + .iter() + .find(|node| node.id == leaf) + .and_then(|node| node.buffer_view.as_ref()) + .map(|view| view.buffer_id) + .expect("focused pane buffer exists"); + + let shell_command = format!("cat >> '{}'", pipe_path.display()); + let started = run_cli( + &server, + [ + "buffer", + "pipe", + "start", + &buffer_id.to_string(), + "--", + "/bin/sh", + "-lc", + &shell_command, + ], + ); + let started_stdout = stdout(&started); + assert!(started_stdout.contains(&format!("buffer_id\t{buffer_id}"))); + assert!(started_stdout.contains("state\trunning")); + + run_cli( + &server, + [ + "send-keys", + "-t", + &leaf.to_string(), + "--enter", + "printf", + "pipe-cli-marker\\n", + ], + ); + wait_for_file_contains(&pipe_path, "pipe-cli-marker").await; + + let shown = run_cli(&server, ["buffer", "pipe", "show", &buffer_id.to_string()]); + let shown_stdout = stdout(&shown); + assert!(shown_stdout.contains("state\trunning")); + assert!(shown_stdout.contains("command\t[")); + + let stopped = run_cli(&server, ["buffer", "pipe", "stop", &buffer_id.to_string()]); + let stopped_stdout = stdout(&stopped); + assert!(stopped_stdout.contains("state\tstopped")); + assert!(stopped_stdout.contains("stop_reason\trequested")); + + let opened = run_cli( + &server, + [ + "buffer", + "history", + &buffer_id.to_string(), + "--scope", + "visible", + ], + ); + let helper_buffer_id = stdout(&opened) + .trim() + .split('\t') + .next() + .expect("helper buffer id column") + .parse::() + .expect("helper buffer id parses"); + + cli_command(&server) + .args([ + "buffer", + "pipe", + "start", + &helper_buffer_id.to_string(), + "--", + "/bin/cat", + ]) + .assert() + .failure() + .stderr(predicate::str::contains("read-only")); + + server.shutdown().await.expect("shutdown server"); +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn node_commands_cover_zoom_swap_break_join_and_reorder() { let _guard = acquire_test_lock().await.expect("acquire test lock"); diff --git a/crates/embers-client/src/client.rs b/crates/embers-client/src/client.rs index 06128ca..d46d290 100644 --- a/crates/embers-client/src/client.rs +++ b/crates/embers-client/src/client.rs @@ -321,6 +321,7 @@ where ServerEvent::RenderInvalidated(event) => { self.refresh_buffer_record(event.buffer_id).await } + ServerEvent::BufferPipeChanged(_) => Ok(()), ServerEvent::BufferCreated(_) | ServerEvent::BufferDetached(_) | ServerEvent::FocusChanged(_) => Ok(()), diff --git a/crates/embers-client/src/configured_client.rs b/crates/embers-client/src/configured_client.rs index d3191c5..08f1684 100644 --- a/crates/embers-client/src/configured_client.rs +++ b/crates/embers-client/src/configured_client.rs @@ -1597,6 +1597,9 @@ where fn event_session_id(&self, event: &ServerEvent) -> Option { event.session_id().or_else(|| match event { ServerEvent::BufferCreated(event) => self.session_id_for_buffer_record(&event.buffer), + ServerEvent::BufferPipeChanged(event) => { + self.session_id_for_buffer_record(&event.buffer) + } ServerEvent::BufferDetached(event) => self.session_id_for_buffer(event.buffer_id), ServerEvent::RenderInvalidated(event) => self.session_id_for_buffer(event.buffer_id), ServerEvent::SessionCreated(_) @@ -2764,6 +2767,7 @@ fn event_name(event: &ServerEvent) -> &'static str { ServerEvent::SessionClosed(_) => "session_closed", ServerEvent::SessionRenamed(_) => "session_renamed", ServerEvent::BufferCreated(_) => "buffer_created", + ServerEvent::BufferPipeChanged(_) => "buffer_pipe_changed", ServerEvent::BufferDetached(_) => "buffer_detached", ServerEvent::NodeChanged(_) => "node_changed", ServerEvent::FloatingChanged(_) => "floating_changed", @@ -2787,6 +2791,11 @@ fn event_info( info.buffer_id = Some(event.buffer.id); info.node_id = event.buffer.attachment_node_id; } + ServerEvent::BufferPipeChanged(event) => { + info.session_id = event.session_id; + info.buffer_id = Some(event.buffer.id); + info.node_id = event.buffer.attachment_node_id; + } ServerEvent::BufferDetached(event) => info.buffer_id = Some(event.buffer_id), ServerEvent::NodeChanged(event) => info.session_id = Some(event.session_id), ServerEvent::FloatingChanged(event) => { @@ -2924,6 +2933,7 @@ mod tests { activity: ActivityState::Idle, last_snapshot_seq: 0, exit_code: None, + pipe: None, env: Default::default(), }, location, diff --git a/crates/embers-client/src/presentation.rs b/crates/embers-client/src/presentation.rs index 01dcb85..7b67608 100644 --- a/crates/embers-client/src/presentation.rs +++ b/crates/embers-client/src/presentation.rs @@ -748,6 +748,7 @@ mod zoom_tests { activity: ActivityState::Idle, last_snapshot_seq: 0, exit_code: None, + pipe: None, env: Default::default(), }, ); @@ -837,6 +838,7 @@ mod zoom_tests { activity: ActivityState::Idle, last_snapshot_seq: 0, exit_code: None, + pipe: None, env: Default::default(), }, ); @@ -912,6 +914,7 @@ mod zoom_tests { activity: ActivityState::Idle, last_snapshot_seq: 0, exit_code: None, + pipe: None, env: Default::default(), }, ); @@ -1034,6 +1037,7 @@ mod zoom_tests { activity: ActivityState::Idle, last_snapshot_seq: 0, exit_code: None, + pipe: None, env: Default::default(), }, ); @@ -1056,6 +1060,7 @@ mod zoom_tests { activity: ActivityState::Idle, last_snapshot_seq: 0, exit_code: None, + pipe: None, env: Default::default(), }, ); @@ -1160,6 +1165,7 @@ mod zoom_tests { activity: ActivityState::Idle, last_snapshot_seq: 0, exit_code: None, + pipe: None, env: Default::default(), }, ); diff --git a/crates/embers-client/src/state.rs b/crates/embers-client/src/state.rs index 7ec1334..b0864c6 100644 --- a/crates/embers-client/src/state.rs +++ b/crates/embers-client/src/state.rs @@ -204,6 +204,9 @@ impl ClientState { ServerEvent::BufferCreated(event) => { self.buffers.insert(event.buffer.id, event.buffer.clone()); } + ServerEvent::BufferPipeChanged(event) => { + self.buffers.insert(event.buffer.id, event.buffer.clone()); + } ServerEvent::BufferDetached(event) => { if let Some(buffer) = self.buffers.get_mut(&event.buffer_id) { buffer.attachment_node_id = None; diff --git a/crates/embers-client/tests/configured_client.rs b/crates/embers-client/tests/configured_client.rs index afc8387..0e8d150 100644 --- a/crates/embers-client/tests/configured_client.rs +++ b/crates/embers-client/tests/configured_client.rs @@ -216,6 +216,7 @@ fn second_session_state() -> embers_client::ClientState { activity: ActivityState::Idle, last_snapshot_seq: 1, exit_code: None, + pipe: None, env: BTreeMap::new(), }, ); @@ -1441,6 +1442,7 @@ async fn detached_buffer_events_do_not_fall_back_to_the_active_session() { activity: ActivityState::Idle, last_snapshot_seq: 0, exit_code: None, + pipe: None, env: BTreeMap::new(), }, })); diff --git a/crates/embers-client/tests/presentation.rs b/crates/embers-client/tests/presentation.rs index bb1758e..93d7401 100644 --- a/crates/embers-client/tests/presentation.rs +++ b/crates/embers-client/tests/presentation.rs @@ -290,6 +290,7 @@ fn foreign_session_zoom_targets_are_ignored() { activity: ActivityState::Idle, last_snapshot_seq: 1, exit_code: None, + pipe: None, }, ); diff --git a/crates/embers-client/tests/reducer.rs b/crates/embers-client/tests/reducer.rs index 5ae7209..e8b7f52 100644 --- a/crates/embers-client/tests/reducer.rs +++ b/crates/embers-client/tests/reducer.rs @@ -32,6 +32,7 @@ fn buffer(id: u64, attachment_node_id: Option, title: &str) -> BufferRecord activity: ActivityState::Idle, last_snapshot_seq: 0, exit_code: None, + pipe: None, } } diff --git a/crates/embers-client/tests/support/mod.rs b/crates/embers-client/tests/support/mod.rs index d08d39b..b7d89d5 100644 --- a/crates/embers-client/tests/support/mod.rs +++ b/crates/embers-client/tests/support/mod.rs @@ -322,6 +322,7 @@ fn buffer( activity, last_snapshot_seq: 0, exit_code: None, + pipe: None, } } diff --git a/crates/embers-protocol/schema/embers.fbs b/crates/embers-protocol/schema/embers.fbs index 8293754..835ed79 100644 --- a/crates/embers-protocol/schema/embers.fbs +++ b/crates/embers-protocol/schema/embers.fbs @@ -40,6 +40,7 @@ enum MessageKind : ubyte { RenderInvalidatedEvent = 47, SessionRenamedEvent = 48, ClientChangedEvent = 49, + BufferPipeChangedEvent = 50, } enum ErrorCodeWire : ubyte { @@ -79,6 +80,8 @@ enum BufferOp : ubyte { Reveal = 9, OpenHistory = 10, Inspect = 11, + StartPipe = 12, + StopPipe = 13, } enum NodeOp : ubyte { @@ -158,6 +161,19 @@ enum BufferHistoryPlacementWire : ubyte { Floating = 1, } +enum BufferPipeStateWire : ubyte { + Running = 0, + Stopped = 1, +} + +enum BufferPipeStopReasonWire : ubyte { + Requested = 0, + PipeExited = 1, + WriteFailed = 2, + BufferExited = 3, + RuntimeInterrupted = 4, +} + enum NodeBreakDestinationWire : ubyte { Tab = 0, Floating = 1, @@ -323,6 +339,18 @@ table BufferRecord { helper_source_buffer_id:ulong = 0 (id: 18); helper_scope:BufferHistoryScopeWire = Full (id: 19); has_helper_scope:bool = false (id: 20); + pipe:BufferPipeRecord (id: 21); +} + +table BufferPipeRecord { + command:[string] (id: 0); + state:BufferPipeStateWire = Running (id: 1); + pid:uint = 0 (id: 2); + has_pid:bool = false (id: 3); + exit_code:int = 0 (id: 4); + has_exit_code:bool = false (id: 5); + stop_reason:BufferPipeStopReasonWire = Requested (id: 6); + has_stop_reason:bool = false (id: 7); } table BufferViewRecord { @@ -493,6 +521,11 @@ table BufferCreatedEvent { buffer:BufferRecord; } +table BufferPipeChangedEvent { + session_id:ulong = 0; + buffer:BufferRecord; +} + table BufferDetachedEvent { buffer_id:ulong; } @@ -567,6 +600,7 @@ table Envelope { client_changed_event:ClientChangedEvent (id: 35); buffer_location_response:BufferLocationResponse (id: 36); buffer_with_location_response:BufferWithLocationResponse (id: 37); + buffer_pipe_changed_event:BufferPipeChangedEvent (id: 38); } root_type Envelope; diff --git a/crates/embers-protocol/src/codec.rs b/crates/embers-protocol/src/codec.rs index 37fa943..181f4d3 100644 --- a/crates/embers-protocol/src/codec.rs +++ b/crates/embers-protocol/src/codec.rs @@ -194,6 +194,23 @@ fn validate_required_node_ids( Ok(()) } +fn validate_non_empty_command( + command: &[String], + field: &'static str, +) -> Result<(), ProtocolError> { + if command.is_empty() { + return Err(ProtocolError::InvalidMessageOwned(format!( + "{field} must not be empty" + ))); + } + if command[0].is_empty() { + return Err(ProtocolError::InvalidMessageOwned(format!( + "{field} first segment must not be empty" + ))); + } + Ok(()) +} + fn validate_session_request(req: &SessionRequest) -> Result<(), ProtocolError> { match req { SessionRequest::Create { .. } | SessionRequest::List { .. } => Ok(()), @@ -224,6 +241,12 @@ fn validate_buffer_request(req: &BufferRequest) -> Result<(), ProtocolError> { BufferRequest::List { session_id, .. } => { validate_optional_session_id(*session_id, "buffer_request.session_id") } + BufferRequest::StartPipe { + buffer_id, command, .. + } => { + validate_required_buffer_id(*buffer_id, "buffer_request.buffer_id")?; + validate_non_empty_command(command, "buffer_request.command") + } BufferRequest::Get { buffer_id, .. } | BufferRequest::Inspect { buffer_id, .. } | BufferRequest::Detach { buffer_id, .. } @@ -233,7 +256,8 @@ fn validate_buffer_request(req: &BufferRequest) -> Result<(), ProtocolError> { | BufferRequest::ScrollbackSlice { buffer_id, .. } | BufferRequest::GetLocation { buffer_id, .. } | BufferRequest::Reveal { buffer_id, .. } - | BufferRequest::OpenHistory { buffer_id, .. } => { + | BufferRequest::OpenHistory { buffer_id, .. } + | BufferRequest::StopPipe { buffer_id, .. } => { validate_required_buffer_id(*buffer_id, "buffer_request.buffer_id") } } @@ -479,6 +503,35 @@ fn decode_buffer_history_placement( } } +fn encode_buffer_pipe_stop_reason(reason: BufferPipeStopReason) -> fb::BufferPipeStopReasonWire { + match reason { + BufferPipeStopReason::Requested => fb::BufferPipeStopReasonWire::Requested, + BufferPipeStopReason::PipeExited => fb::BufferPipeStopReasonWire::PipeExited, + BufferPipeStopReason::WriteFailed => fb::BufferPipeStopReasonWire::WriteFailed, + BufferPipeStopReason::BufferExited => fb::BufferPipeStopReasonWire::BufferExited, + BufferPipeStopReason::RuntimeInterrupted => { + fb::BufferPipeStopReasonWire::RuntimeInterrupted + } + } +} + +fn decode_buffer_pipe_stop_reason( + reason: fb::BufferPipeStopReasonWire, +) -> Result { + match reason { + fb::BufferPipeStopReasonWire::Requested => Ok(BufferPipeStopReason::Requested), + fb::BufferPipeStopReasonWire::PipeExited => Ok(BufferPipeStopReason::PipeExited), + fb::BufferPipeStopReasonWire::WriteFailed => Ok(BufferPipeStopReason::WriteFailed), + fb::BufferPipeStopReasonWire::BufferExited => Ok(BufferPipeStopReason::BufferExited), + fb::BufferPipeStopReasonWire::RuntimeInterrupted => { + Ok(BufferPipeStopReason::RuntimeInterrupted) + } + _ => Err(ProtocolError::InvalidMessage( + "unknown buffer pipe stop reason", + )), + } +} + fn encode_node_break_destination( destination: NodeBreakDestination, ) -> fb::NodeBreakDestinationWire { @@ -1021,6 +1074,46 @@ fn encode_buffer_request<'a>( None, None, ), + BufferRequest::StartPipe { + buffer_id, + command, + cwd, + env, + .. + } => ( + fb::BufferOp::StartPipe, + (*buffer_id).into(), + 0, + 0, + false, + false, + false, + 0, + 0, + fb::BufferHistoryScopeWire::Full, + fb::BufferHistoryPlacementWire::Tab, + None, + Some(command), + cwd.as_deref(), + Some(env), + ), + BufferRequest::StopPipe { buffer_id, .. } => ( + fb::BufferOp::StopPipe, + (*buffer_id).into(), + 0, + 0, + false, + false, + false, + 0, + 0, + fb::BufferHistoryScopeWire::Full, + fb::BufferHistoryPlacementWire::Tab, + None, + None, + None, + None, + ), }; let title = title_str.map(|s| builder.create_string(s)); @@ -1168,11 +1261,23 @@ fn validate_buffer_record(record: &BufferRecord) -> Result<(), ProtocolError> { record.helper_source_buffer_id, "buffer_record.helper_source_buffer_id", )?; + if let Some(pipe) = &record.pipe { + validate_buffer_pipe_record(pipe, "buffer_record.pipe")?; + } record .validate() .map_err(ProtocolError::InvalidMessageOwned) } +fn validate_buffer_pipe_record( + record: &BufferPipeRecord, + field: &'static str, +) -> Result<(), ProtocolError> { + record + .validate(field) + .map_err(ProtocolError::InvalidMessageOwned) +} + fn validate_node_record(record: &NodeRecord) -> Result<(), ProtocolError> { validate_required_node_id(record.id, "node_record.id")?; validate_required_session_id(record.session_id, "node_record.session_id")?; @@ -1283,6 +1388,10 @@ fn validate_server_event(event: &ServerEvent) -> Result<(), ProtocolError> { validate_required_session_id(event.session_id, "session_renamed_event.session_id") } ServerEvent::BufferCreated(event) => validate_buffer_record(&event.buffer), + ServerEvent::BufferPipeChanged(event) => { + validate_optional_session_id(event.session_id, "buffer_pipe_changed_event.session_id")?; + validate_buffer_record(&event.buffer) + } ServerEvent::BufferDetached(event) => { validate_required_buffer_id(event.buffer_id, "buffer_detached_event.buffer_id") } @@ -2754,6 +2863,25 @@ fn encode_server_event<'a>( }, ) } + ServerEvent::BufferPipeChanged(e) => { + let buffer = encode_buffer_record(builder, &e.buffer); + let event = fb::BufferPipeChangedEvent::create( + builder, + &fb::BufferPipeChangedEventArgs { + session_id: e.session_id.map(|id| id.into()).unwrap_or(0), + buffer: Some(buffer), + }, + ); + fb::Envelope::create( + builder, + &fb::EnvelopeArgs { + request_id: 0, + kind: fb::MessageKind::BufferPipeChangedEvent, + buffer_pipe_changed_event: Some(event), + ..Default::default() + }, + ) + } ServerEvent::BufferDetached(e) => { let event = fb::BufferDetachedEvent::create( builder, @@ -2925,6 +3053,10 @@ fn encode_buffer_record<'a>( .helper_scope .map(encode_buffer_history_scope) .unwrap_or(fb::BufferHistoryScopeWire::Full); + let pipe = record + .pipe + .as_ref() + .map(|pipe| encode_buffer_pipe_record(builder, pipe)); fb::BufferRecord::create( builder, @@ -2933,6 +3065,7 @@ fn encode_buffer_record<'a>( title: Some(title), command: Some(command), cwd, + pipe, kind, state, pid: record.pid.unwrap_or(0), @@ -2957,6 +3090,39 @@ fn encode_buffer_record<'a>( ) } +fn encode_buffer_pipe_record<'a>( + builder: &mut FlatBufferBuilder<'a>, + record: &BufferPipeRecord, +) -> flatbuffers::WIPOffset> { + let command_vec: Vec<_> = record + .command + .iter() + .map(|segment| builder.create_string(segment)) + .collect(); + let command = builder.create_vector(&command_vec); + let state = match record.state { + BufferPipeState::Running => fb::BufferPipeStateWire::Running, + BufferPipeState::Stopped => fb::BufferPipeStateWire::Stopped, + }; + let stop_reason = record + .stop_reason + .map(encode_buffer_pipe_stop_reason) + .unwrap_or(fb::BufferPipeStopReasonWire::Requested); + fb::BufferPipeRecord::create( + builder, + &fb::BufferPipeRecordArgs { + command: Some(command), + state, + pid: record.pid.unwrap_or(0), + has_pid: record.pid.is_some(), + exit_code: record.exit_code.unwrap_or(0), + has_exit_code: record.exit_code.is_some(), + stop_reason, + has_stop_reason: record.stop_reason.is_some(), + }, + ) +} + fn encode_node_record<'a>( builder: &mut FlatBufferBuilder<'a>, record: &NodeRecord, @@ -3312,8 +3478,29 @@ pub fn decode_client_message(bytes: &[u8]) -> Result BufferRequest::StartPipe { + request_id, + buffer_id: decode_required_buffer_id( + req.buffer_id(), + "buffer_request.buffer_id", + )?, + command: required(req.command(), "buffer_request.command")? + .iter() + .map(|segment| segment.to_owned()) + .collect(), + cwd: req.cwd().map(|cwd| cwd.to_owned()), + env: decode_string_map(req.env_keys(), req.env_values(), "buffer_request.env")?, + }, + fb::BufferOp::StopPipe => BufferRequest::StopPipe { + request_id, + buffer_id: decode_required_buffer_id( + req.buffer_id(), + "buffer_request.buffer_id", + )?, + }, _ => return Err(ProtocolError::InvalidMessage("unknown buffer op")), }; + validate_buffer_request(&buffer_request)?; Ok(ClientMessage::Buffer(buffer_request)) } fb::MessageKind::NodeRequest => { @@ -3981,6 +4168,23 @@ pub fn decode_server_envelope(bytes: &[u8]) -> Result { + let event = required( + envelope.buffer_pipe_changed_event(), + "buffer_pipe_changed_event", + )?; + let buffer = required(event.buffer(), "buffer_pipe_changed_event.buffer")?; + Ok(ServerEnvelope::Event(ServerEvent::BufferPipeChanged( + BufferPipeChangedEvent { + session_id: if event.session_id() == 0 { + None + } else { + Some(SessionId(event.session_id())) + }, + buffer: decode_buffer_record(buffer)?, + }, + ))) + } fb::MessageKind::BufferDetachedEvent => { let event = required(envelope.buffer_detached_event(), "buffer_detached_event")?; Ok(ServerEnvelope::Event(ServerEvent::BufferDetached( @@ -4120,12 +4324,17 @@ fn decode_buffer_record(record: fb::BufferRecord) -> Result Result Result { + let command_fb = required(record.command(), "buffer_pipe_record.command")?; + let command: Vec = command_fb + .iter() + .map(|segment| segment.to_owned()) + .collect(); + let state = match record.state() { + fb::BufferPipeStateWire::Running => BufferPipeState::Running, + fb::BufferPipeStateWire::Stopped => BufferPipeState::Stopped, + _ => return Err(ProtocolError::InvalidMessage("unknown buffer pipe state")), + }; + let decoded = BufferPipeRecord { + command, + state, + pid: record.has_pid().then(|| record.pid()), + exit_code: record.has_exit_code().then(|| record.exit_code()), + stop_reason: if record.has_stop_reason() { + Some(decode_buffer_pipe_stop_reason(record.stop_reason())?) + } else { + None + }, + }; + validate_buffer_pipe_record(&decoded, "buffer_pipe_record")?; + Ok(decoded) +} + fn decode_node_record(record: fb::NodeRecord) -> Result { let (kind, buffer_view, split, tabs) = match record.kind() { fb::NodeRecordKindWire::BufferView => { @@ -4348,6 +4585,7 @@ fn decode_error_code(code: fb::ErrorCodeWire) -> ErrorCode { #[cfg(test)] mod tests { + use std::collections::BTreeMap; use std::num::NonZeroU64; use flatbuffers::FlatBufferBuilder; @@ -4378,6 +4616,102 @@ mod tests { )); } + #[test] + fn validate_buffer_pipe_record_rejects_stopped_without_stop_reason() { + let record = BufferPipeRecord { + command: vec!["tee".to_owned()], + state: BufferPipeState::Stopped, + pid: None, + exit_code: Some(0), + stop_reason: None, + }; + + let error = validate_buffer_pipe_record(&record, "buffer_record.pipe") + .expect_err("stopped pipe requires stop reason"); + + assert!(matches!( + error, + ProtocolError::InvalidMessageOwned(message) + if message == "buffer_record.pipe.stop_reason required for stopped pipe" + )); + } + + #[test] + fn buffer_record_validate_rejects_running_pipe_with_exit_code() { + let record = BufferRecord { + id: BufferId(1), + title: "shell".to_owned(), + command: vec!["sh".to_owned()], + cwd: None, + pipe: Some(BufferPipeRecord { + command: vec!["tee".to_owned()], + state: BufferPipeState::Running, + pid: Some(123), + exit_code: Some(0), + stop_reason: None, + }), + kind: BufferRecordKind::Pty, + state: BufferRecordState::Running, + pid: Some(456), + attachment_node_id: None, + read_only: false, + helper_source_buffer_id: None, + helper_scope: None, + pty_size: PtySize::new(80, 24), + activity: ActivityState::Idle, + last_snapshot_seq: 1, + exit_code: None, + env: BTreeMap::new(), + }; + + let error = record + .validate() + .expect_err("running pipe cannot set exit code"); + + assert_eq!( + error, + "buffer_record.pipe.exit_code requires a stopped pipe" + ); + } + + #[test] + fn buffer_record_validate_rejects_empty_pipe_program() { + let record = BufferRecord { + id: BufferId(1), + title: "shell".to_owned(), + command: vec!["sh".to_owned()], + cwd: None, + pipe: Some(BufferPipeRecord { + command: vec![String::new()], + state: BufferPipeState::Running, + pid: Some(123), + exit_code: None, + stop_reason: None, + }), + kind: BufferRecordKind::Pty, + state: BufferRecordState::Running, + pid: Some(456), + attachment_node_id: None, + read_only: false, + helper_source_buffer_id: None, + helper_scope: None, + pty_size: PtySize::new(80, 24), + activity: ActivityState::Idle, + last_snapshot_seq: 1, + exit_code: None, + env: BTreeMap::new(), + }; + + let error = record + .validate() + .expect_err("pipe command program cannot be empty"); + + assert_eq!( + error, + "buffer_record.pipe.command first segment must not be empty" + ); + } + #[test] fn decode_node_record_rejects_split_without_children() { let mut builder = FlatBufferBuilder::new(); @@ -4730,6 +5064,7 @@ mod tests { title: "helper".to_owned(), command: vec!["echo".to_owned()], cwd: None, + pipe: None, kind: BufferRecordKind::Pty, state: BufferRecordState::Running, pid: None, @@ -4766,6 +5101,7 @@ mod tests { title: "helper".to_owned(), command: vec!["echo".to_owned()], cwd: None, + pipe: None, kind: BufferRecordKind::Pty, state: BufferRecordState::Running, pid: None, @@ -4803,6 +5139,7 @@ mod tests { title: "helper".to_owned(), command: vec!["echo".to_owned()], cwd: None, + pipe: None, kind: BufferRecordKind::Pty, state: BufferRecordState::Running, pid: None, @@ -4830,6 +5167,7 @@ mod tests { title: "helper".to_owned(), command: vec!["echo".to_owned()], cwd: None, + pipe: None, kind: BufferRecordKind::Pty, state: BufferRecordState::Running, pid: None, @@ -4856,6 +5194,7 @@ mod tests { title: "helper".to_owned(), command: vec!["echo".to_owned()], cwd: None, + pipe: None, kind: BufferRecordKind::Pty, state: BufferRecordState::Running, pid: None, @@ -4902,6 +5241,7 @@ mod tests { title: "helper".to_owned(), command: vec!["echo".to_owned()], cwd: None, + pipe: None, kind: BufferRecordKind::Pty, state: BufferRecordState::Running, pid: None, @@ -4943,6 +5283,7 @@ mod tests { title: "helper".to_owned(), command: vec!["echo".to_owned()], cwd: None, + pipe: None, kind: BufferRecordKind::Pty, state: BufferRecordState::Running, pid: None, @@ -4980,6 +5321,7 @@ mod tests { title: "helper".to_owned(), command: Vec::new(), cwd: None, + pipe: None, kind: BufferRecordKind::Helper, state: BufferRecordState::Created, pid: None, @@ -5014,6 +5356,7 @@ mod tests { title: "shell".to_owned(), command: vec!["echo".to_owned()], cwd: None, + pipe: None, kind: BufferRecordKind::Pty, state: BufferRecordState::Running, pid: None, @@ -5038,6 +5381,7 @@ mod tests { title: "shell".to_owned(), command: vec!["echo".to_owned()], cwd: None, + pipe: None, kind: BufferRecordKind::Pty, state: BufferRecordState::Running, pid: None, @@ -5077,6 +5421,7 @@ mod tests { title: "helper".to_owned(), command: vec!["echo".to_owned()], cwd: None, + pipe: None, kind: BufferRecordKind::Helper, state: BufferRecordState::Running, pid: None, @@ -5101,6 +5446,7 @@ mod tests { title: "helper".to_owned(), command: vec!["echo".to_owned()], cwd: None, + pipe: None, kind: BufferRecordKind::Helper, state: BufferRecordState::Running, pid: None, @@ -5197,6 +5543,7 @@ mod tests { title: "helper".to_owned(), command: vec!["echo".to_owned()], cwd: None, + pipe: None, kind: BufferRecordKind::Pty, state: BufferRecordState::Running, pid: None, @@ -5280,6 +5627,42 @@ mod tests { )); } + #[test] + fn encode_buffer_request_rejects_empty_start_pipe_command() { + let error = encode_client_message(&ClientMessage::Buffer(BufferRequest::StartPipe { + request_id: RequestId(1), + buffer_id: BufferId(7), + command: Vec::new(), + cwd: None, + env: BTreeMap::new(), + })) + .expect_err("empty start pipe command should be rejected"); + + assert!(matches!( + error, + ProtocolError::InvalidMessageOwned(message) + if message == "buffer_request.command must not be empty" + )); + } + + #[test] + fn encode_buffer_request_rejects_empty_start_pipe_program() { + let error = encode_client_message(&ClientMessage::Buffer(BufferRequest::StartPipe { + request_id: RequestId(1), + buffer_id: BufferId(7), + command: vec![String::new()], + cwd: None, + env: BTreeMap::new(), + })) + .expect_err("empty start pipe program should be rejected"); + + assert!(matches!( + error, + ProtocolError::InvalidMessageOwned(message) + if message == "buffer_request.command first segment must not be empty" + )); + } + #[test] fn encode_session_request_rejects_zero_required_session_id() { let error = encode_client_message(&ClientMessage::Session(SessionRequest::Get { @@ -5501,6 +5884,75 @@ mod tests { )); } + #[test] + fn decode_buffer_request_rejects_empty_start_pipe_command() { + let mut builder = FlatBufferBuilder::new(); + let empty_command = builder.create_vector::>(&[]); + let request = fb::BufferRequest::create( + &mut builder, + &fb::BufferRequestArgs { + op: fb::BufferOp::StartPipe, + buffer_id: 7, + command: Some(empty_command), + ..Default::default() + }, + ); + let envelope = fb::Envelope::create( + &mut builder, + &fb::EnvelopeArgs { + request_id: 8, + kind: fb::MessageKind::BufferRequest, + buffer_request: Some(request), + ..Default::default() + }, + ); + builder.finish(envelope, Some("EMBR")); + + let error = decode_client_message(builder.finished_data()) + .expect_err("empty start pipe command should be rejected"); + + assert!(matches!( + error, + ProtocolError::InvalidMessageOwned(message) + if message == "buffer_request.command must not be empty" + )); + } + + #[test] + fn decode_buffer_request_rejects_empty_start_pipe_program() { + let mut builder = FlatBufferBuilder::new(); + let empty_program = builder.create_string(""); + let command = builder.create_vector(&[empty_program]); + let request = fb::BufferRequest::create( + &mut builder, + &fb::BufferRequestArgs { + op: fb::BufferOp::StartPipe, + buffer_id: 7, + command: Some(command), + ..Default::default() + }, + ); + let envelope = fb::Envelope::create( + &mut builder, + &fb::EnvelopeArgs { + request_id: 8, + kind: fb::MessageKind::BufferRequest, + buffer_request: Some(request), + ..Default::default() + }, + ); + builder.finish(envelope, Some("EMBR")); + + let error = decode_client_message(builder.finished_data()) + .expect_err("empty start pipe program should be rejected"); + + assert!(matches!( + error, + ProtocolError::InvalidMessageOwned(message) + if message == "buffer_request.command first segment must not be empty" + )); + } + #[test] fn decode_session_request_rejects_zero_required_session_id() { let mut builder = FlatBufferBuilder::new(); diff --git a/crates/embers-protocol/src/lib.rs b/crates/embers-protocol/src/lib.rs index 65858f4..cd050b5 100644 --- a/crates/embers-protocol/src/lib.rs +++ b/crates/embers-protocol/src/lib.rs @@ -25,13 +25,14 @@ pub use framing::{ }; pub use types::{ BufferCreatedEvent, BufferDetachedEvent, BufferHistoryPlacement, BufferHistoryScope, - BufferLocation, BufferLocationAttachment, BufferLocationResponse, BufferRecord, - BufferRecordKind, BufferRecordState, BufferRequest, BufferResponse, BufferViewRecord, - BufferWithLocationResponse, BuffersResponse, ClientChangedEvent, ClientMessage, ClientRecord, - ClientRequest, ClientResponse, ClientsResponse, ErrorResponse, FloatingChangedEvent, - FloatingListResponse, FloatingRecord, FloatingRequest, FloatingResponse, FocusChangedEvent, - InputRequest, NodeBreakDestination, NodeChangedEvent, NodeJoinPlacement, NodeRecord, - NodeRecordKind, NodeRequest, OkResponse, PingRequest, PingResponse, RenderInvalidatedEvent, + BufferLocation, BufferLocationAttachment, BufferLocationResponse, BufferPipeChangedEvent, + BufferPipeRecord, BufferPipeState, BufferPipeStopReason, BufferRecord, BufferRecordKind, + BufferRecordState, BufferRequest, BufferResponse, BufferViewRecord, BufferWithLocationResponse, + BuffersResponse, ClientChangedEvent, ClientMessage, ClientRecord, ClientRequest, + ClientResponse, ClientsResponse, ErrorResponse, FloatingChangedEvent, FloatingListResponse, + FloatingRecord, FloatingRequest, FloatingResponse, FocusChangedEvent, InputRequest, + NodeBreakDestination, NodeChangedEvent, NodeJoinPlacement, NodeRecord, NodeRecordKind, + NodeRequest, OkResponse, PingRequest, PingResponse, RenderInvalidatedEvent, ScrollbackSliceResponse, ServerEnvelope, ServerEvent, ServerResponse, SessionClosedEvent, SessionCreatedEvent, SessionRecord, SessionRenamedEvent, SessionRequest, SessionSnapshot, SessionSnapshotResponse, SessionsResponse, SnapshotResponse, SplitRecord, SubscribeRequest, diff --git a/crates/embers-protocol/src/types.rs b/crates/embers-protocol/src/types.rs index f888d71..efeb61a 100644 --- a/crates/embers-protocol/src/types.rs +++ b/crates/embers-protocol/src/types.rs @@ -144,6 +144,17 @@ pub enum BufferRequest { placement: BufferHistoryPlacement, client_id: Option, }, + StartPipe { + request_id: RequestId, + buffer_id: BufferId, + command: Vec, + cwd: Option, + env: BTreeMap, + }, + StopPipe { + request_id: RequestId, + buffer_id: BufferId, + }, } impl BufferRequest { @@ -160,7 +171,9 @@ impl BufferRequest { | Self::ScrollbackSlice { request_id, .. } | Self::GetLocation { request_id, .. } | Self::Reveal { request_id, .. } - | Self::OpenHistory { request_id, .. } => *request_id, + | Self::OpenHistory { request_id, .. } + | Self::StartPipe { request_id, .. } + | Self::StopPipe { request_id, .. } => *request_id, } } } @@ -177,6 +190,57 @@ pub enum BufferHistoryPlacement { Floating, } +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum BufferPipeState { + Running, + Stopped, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum BufferPipeStopReason { + Requested, + PipeExited, + WriteFailed, + BufferExited, + RuntimeInterrupted, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct BufferPipeRecord { + pub command: Vec, + pub state: BufferPipeState, + pub pid: Option, + pub exit_code: Option, + pub stop_reason: Option, +} + +impl BufferPipeRecord { + pub fn validate(&self, field: &str) -> std::result::Result<(), String> { + if self.command.is_empty() { + return Err(format!("{field}.command must not be empty")); + } + if self.command[0].is_empty() { + return Err(format!("{field}.command first segment must not be empty")); + } + match self.state { + BufferPipeState::Running => { + if self.stop_reason.is_some() { + return Err(format!("{field}.stop_reason requires a stopped pipe")); + } + if self.exit_code.is_some() { + return Err(format!("{field}.exit_code requires a stopped pipe")); + } + } + BufferPipeState::Stopped => { + if self.stop_reason.is_none() { + return Err(format!("{field}.stop_reason required for stopped pipe")); + } + } + } + Ok(()) + } +} + #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum BufferLocationAttachment { Detached, @@ -592,6 +656,7 @@ pub struct BufferRecord { pub title: String, pub command: Vec, pub cwd: Option, + pub pipe: Option, pub kind: BufferRecordKind, pub state: BufferRecordState, pub pid: Option, @@ -608,6 +673,9 @@ pub struct BufferRecord { impl BufferRecord { pub fn validate(&self) -> std::result::Result<(), String> { + if let Some(pipe) = &self.pipe { + pipe.validate("buffer_record.pipe")?; + } match self.kind { BufferRecordKind::Pty => { if self.helper_source_buffer_id.is_some() { @@ -620,6 +688,9 @@ impl BufferRecord { } } BufferRecordKind::Helper => { + if self.pipe.is_some() { + return Err("buffer_record.kind=helper cannot set pipe".to_owned()); + } if self.helper_source_buffer_id.is_some() ^ self.helper_scope.is_some() { return Err( "buffer_record.kind=helper must set helper_source_buffer_id and helper_scope together" @@ -960,6 +1031,12 @@ pub struct BufferCreatedEvent { pub buffer: BufferRecord, } +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct BufferPipeChangedEvent { + pub session_id: Option, + pub buffer: BufferRecord, +} + #[derive(Clone, Debug, PartialEq, Eq)] pub struct BufferDetachedEvent { pub buffer_id: BufferId, @@ -1006,6 +1083,7 @@ pub enum ServerEvent { SessionClosed(SessionClosedEvent), SessionRenamed(SessionRenamedEvent), BufferCreated(BufferCreatedEvent), + BufferPipeChanged(BufferPipeChangedEvent), BufferDetached(BufferDetachedEvent), NodeChanged(NodeChangedEvent), FloatingChanged(FloatingChangedEvent), @@ -1021,6 +1099,7 @@ impl ServerEvent { Self::SessionClosed(event) => Some(event.session_id), Self::SessionRenamed(event) => Some(event.session_id), Self::BufferCreated(_) => None, + Self::BufferPipeChanged(event) => event.session_id, Self::BufferDetached(_) => None, Self::NodeChanged(event) => Some(event.session_id), Self::FloatingChanged(event) => Some(event.session_id), diff --git a/crates/embers-protocol/tests/family_round_trip.rs b/crates/embers-protocol/tests/family_round_trip.rs index 6de6fa3..c34b43d 100644 --- a/crates/embers-protocol/tests/family_round_trip.rs +++ b/crates/embers-protocol/tests/family_round_trip.rs @@ -124,6 +124,21 @@ fn client_message_families_round_trip() { placement: BufferHistoryPlacement::Tab, client_id: Some(NonZeroU64::new(8).expect("non-zero client id")), }), + ClientMessage::Buffer(BufferRequest::StartPipe { + request_id: RequestId(158), + buffer_id: BufferId(20), + command: vec![ + "tee".to_owned(), + "-a".to_owned(), + "/tmp/embers.log".to_owned(), + ], + cwd: Some("/tmp".to_owned()), + env: std::collections::BTreeMap::from([("PIPE_MODE".to_owned(), "append".to_owned())]), + }), + ClientMessage::Buffer(BufferRequest::StopPipe { + request_id: RequestId(159), + buffer_id: BufferId(20), + }), ClientMessage::Node(NodeRequest::GetTree { request_id: RequestId(16), session_id: SessionId(10), @@ -327,6 +342,20 @@ fn server_envelope_families_round_trip() { let snapshot_without_zoom = sample_snapshot_without_zoom(); let session = snapshot.session.clone(); let buffers = snapshot.buffers.clone(); + let buffer_with_pipe = BufferRecord { + pipe: Some(BufferPipeRecord { + command: vec![ + "tee".to_owned(), + "-a".to_owned(), + "/tmp/embers.log".to_owned(), + ], + state: BufferPipeState::Running, + pid: Some(5001), + exit_code: None, + stop_reason: None, + }), + ..buffers[0].clone() + }; let detached_buffer = BufferRecord { attachment_node_id: None, ..buffers[2].clone() @@ -367,7 +396,7 @@ fn server_envelope_families_round_trip() { })), ServerEnvelope::Response(ServerResponse::Buffer(BufferResponse { request_id: RequestId(36), - buffer: buffers[0].clone(), + buffer: buffer_with_pipe.clone(), })), ServerEnvelope::Response(ServerResponse::BufferWithLocation( BufferWithLocationResponse::new( @@ -460,6 +489,10 @@ fn server_envelope_families_round_trip() { ServerEnvelope::Event(ServerEvent::BufferCreated(BufferCreatedEvent { buffer: buffers[0].clone(), })), + ServerEnvelope::Event(ServerEvent::BufferPipeChanged(BufferPipeChangedEvent { + session_id: Some(SessionId(10)), + buffer: buffer_with_pipe, + })), ServerEnvelope::Event(ServerEvent::BufferDetached(BufferDetachedEvent { buffer_id: BufferId(11), })), @@ -652,6 +685,7 @@ fn sample_buffer_record( title: format!("buffer-{id}"), command: vec!["bash".to_owned(), "-lc".to_owned(), "echo mux".to_owned()], cwd: Some("/tmp".to_owned()), + pipe: None, kind, state, pid: Some(4242), diff --git a/crates/embers-server/src/buffer_runtime.rs b/crates/embers-server/src/buffer_runtime.rs index e901a29..2eb2dd2 100644 --- a/crates/embers-server/src/buffer_runtime.rs +++ b/crates/embers-server/src/buffer_runtime.rs @@ -11,8 +11,9 @@ use std::os::unix::net::{UnixListener, UnixStream}; #[cfg(windows)] use std::os::windows::ffi::OsStrExt; use std::path::{Path, PathBuf}; -use std::process::{Command as ProcessCommand, Stdio}; +use std::process::{ChildStdin, Command as ProcessCommand, Stdio}; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::mpsc::{SyncSender, TrySendError, sync_channel}; use std::sync::{Arc, Mutex}; use std::thread; use std::time::Duration; @@ -32,12 +33,14 @@ const CONNECT_RETRY_DELAY: Duration = Duration::from_millis(25); const CONNECT_RETRY_ATTEMPTS: usize = 1200; const STATUS_POLL_INTERVAL: Duration = Duration::from_millis(50); const MAX_FRAME_SIZE: usize = 16 * 1024 * 1024; +const KEEPER_PIPE_WRITE_QUEUE_CAPACITY: usize = 64; #[derive(Clone, Debug)] pub struct BufferRuntimeUpdate { pub sequence: u64, pub activity: ActivityState, pub title: Option>, + pub pipe: Option>, } #[derive(Clone, Debug)] @@ -48,6 +51,24 @@ pub struct BufferRuntimeStatus { pub title: Option, pub running: bool, pub exit_code: Option, + pub pipe: Option, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub enum BufferRuntimePipeStopReason { + Requested, + PipeExited, + WriteFailed, + BufferExited, +} + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct BufferRuntimePipeStatus { + pub command: Vec, + pub running: bool, + pub pid: Option, + pub exit_code: Option, + pub stop_reason: Option, } #[derive(Clone)] @@ -91,11 +112,29 @@ struct KeeperConnection { #[derive(Serialize, Deserialize)] enum KeeperRequest { Status, - Write { bytes: Vec }, - Resize { size: PtySize }, - Snapshot { cwd: Option }, - VisibleSnapshot { cwd: Option }, - ScrollbackSlice { start_line: u64, line_count: u32 }, + Write { + bytes: Vec, + }, + Resize { + size: PtySize, + }, + Snapshot { + cwd: Option, + }, + VisibleSnapshot { + cwd: Option, + }, + ScrollbackSlice { + start_line: u64, + line_count: u32, + }, + StartPipe { + command: Vec, + cwd: Option, + env: BTreeMap, + }, + StopPipe, + StopPipeAfterExit, Kill, } @@ -105,6 +144,7 @@ enum KeeperResponse { Snapshot(KeeperSnapshot), VisibleSnapshot(TerminalSnapshot), ScrollbackSlice(KeeperScrollbackSlice), + PipeStatus(BufferRuntimePipeStatus), Ok, Error { message: String }, } @@ -117,6 +157,7 @@ struct KeeperStatus { title: Option, running: bool, exit_code: Option, + pipe: Option, } #[derive(Clone, Serialize, Deserialize)] @@ -140,6 +181,7 @@ struct KeeperRuntime { master: Mutex>, writer: Mutex>, killer: Mutex>, + pipe: Mutex>, sequence: AtomicU64, activity: Mutex, exit_code: Mutex>>, @@ -152,6 +194,23 @@ struct KeeperSurface { size: PtySize, } +struct KeeperPipe { + command: Vec, + child: ProcessCommandChild, + writer_tx: Option>>, + writer_state: Arc>, + writer_thread: Option>, + stop_reason: Option, + exit_code: Option, +} + +type ProcessCommandChild = std::process::Child; + +#[derive(Default)] +struct KeeperPipeWriterState { + failed: bool, +} + impl std::fmt::Debug for BufferRuntimeHandle { fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { formatter @@ -356,6 +415,51 @@ impl BufferRuntimeHandle { .map_err(|error| MuxError::internal(error.to_string()))? } + pub async fn start_pipe( + &self, + command: Vec, + cwd: Option, + env: BTreeMap, + ) -> Result { + validate_pipe_command(&command)?; + let inner = self.inner.clone(); + tokio::task::spawn_blocking(move || { + let mut connection = inner + .connection + .lock() + .map_err(|_| MuxError::internal("buffer runtime connection lock poisoned"))?; + connection.start_pipe(command, cwd, env) + }) + .await + .map_err(|error| MuxError::internal(error.to_string()))? + } + + pub async fn stop_pipe(&self) -> Result { + let inner = self.inner.clone(); + tokio::task::spawn_blocking(move || { + let mut connection = inner + .connection + .lock() + .map_err(|_| MuxError::internal("buffer runtime connection lock poisoned"))?; + connection.stop_pipe() + }) + .await + .map_err(|error| MuxError::internal(error.to_string()))? + } + + pub(crate) async fn stop_pipe_after_exit(&self) -> Result { + let inner = self.inner.clone(); + tokio::task::spawn_blocking(move || { + let mut connection = inner + .connection + .lock() + .map_err(|_| MuxError::internal("buffer runtime connection lock poisoned"))?; + connection.stop_pipe_after_exit() + }) + .await + .map_err(|error| MuxError::internal(error.to_string()))? + } + pub async fn join_threads(&self) -> Result<()> { let inner = self.inner.clone(); tokio::task::spawn_blocking(move || inner.join_threads_blocking()) @@ -410,6 +514,7 @@ impl KeeperConnection { title: status.title, running: status.running, exit_code: status.exit_code, + pipe: status.pipe, }), other => Err(MuxError::protocol(format!( "unexpected runtime keeper status response: {other_kind}", @@ -484,6 +589,42 @@ impl KeeperConnection { ))), } } + + fn start_pipe( + &mut self, + command: Vec, + cwd: Option, + env: BTreeMap, + ) -> Result { + validate_pipe_command(&command)?; + match self.request(KeeperRequest::StartPipe { command, cwd, env })? { + KeeperResponse::PipeStatus(status) => Ok(status), + other => Err(MuxError::protocol(format!( + "unexpected runtime keeper start pipe response: {other_kind}", + other_kind = keeper_response_kind(&other) + ))), + } + } + + fn stop_pipe(&mut self) -> Result { + match self.request(KeeperRequest::StopPipe)? { + KeeperResponse::PipeStatus(status) => Ok(status), + other => Err(MuxError::protocol(format!( + "unexpected runtime keeper stop pipe response: {other_kind}", + other_kind = keeper_response_kind(&other) + ))), + } + } + + fn stop_pipe_after_exit(&mut self) -> Result { + match self.request(KeeperRequest::StopPipeAfterExit)? { + KeeperResponse::PipeStatus(status) => Ok(status), + other => Err(MuxError::protocol(format!( + "unexpected runtime keeper stop pipe after exit response: {other_kind}", + other_kind = keeper_response_kind(&other) + ))), + } + } } impl KeeperSurface { @@ -525,6 +666,161 @@ impl KeeperSurface { } } +impl KeeperPipe { + fn spawn( + command: Vec, + cwd: Option, + env: BTreeMap, + ) -> Result { + let Some(program) = command.first() else { + return Err(MuxError::invalid_input( + "buffer pipe command must not be empty", + )); + }; + let mut child = ProcessCommand::new(program); + child.args(&command[1..]); + if let Some(cwd) = cwd { + child.current_dir(cwd); + } + child.envs(env); + child.stdin(Stdio::piped()); + child.stdout(Stdio::null()); + child.stderr(Stdio::null()); + let mut child = child.spawn()?; + let stdin = child + .stdin + .take() + .ok_or_else(|| MuxError::internal("buffer pipe stdin was not piped"))?; + let (writer_tx, writer_rx) = sync_channel(KEEPER_PIPE_WRITE_QUEUE_CAPACITY); + let writer_state = Arc::new(Mutex::new(KeeperPipeWriterState::default())); + let writer_thread_state = writer_state.clone(); + let writer_pid = child.id(); + let writer_thread = thread::Builder::new() + .name(format!("keeper-pipe-writer-{writer_pid}")) + .spawn(move || keeper_pipe_write_loop(stdin, writer_rx, writer_thread_state)) + .map_err(|error| MuxError::internal(error.to_string()))?; + Ok(Self { + command, + child, + writer_tx: Some(writer_tx), + writer_state, + writer_thread: Some(writer_thread), + stop_reason: None, + exit_code: None, + }) + } + + fn status(&mut self) -> BufferRuntimePipeStatus { + self.refresh(); + BufferRuntimePipeStatus { + command: self.command.clone(), + running: self.exit_code.is_none() && self.stop_reason.is_none(), + pid: (self.exit_code.is_none() && self.stop_reason.is_none()).then(|| self.child.id()), + exit_code: self.exit_code, + stop_reason: self.stop_reason, + } + } + + fn refresh(&mut self) { + if self.exit_code.is_some() || self.stop_reason.is_some() { + return; + } + if self.writer_failed() { + let _ = self.terminate(BufferRuntimePipeStopReason::WriteFailed); + return; + } + if let Ok(Some(status)) = self.child.try_wait() { + self.exit_code = exit_status_code(status.into()); + self.stop_reason + .get_or_insert(BufferRuntimePipeStopReason::PipeExited); + self.close_writer(); + self.join_writer(); + } + } + + fn write(&mut self, bytes: &[u8]) { + if self.exit_code.is_some() || self.stop_reason.is_some() { + return; + } + if let Ok(Some(status)) = self.child.try_wait() { + self.exit_code = exit_status_code(status.into()); + self.stop_reason + .get_or_insert(BufferRuntimePipeStopReason::PipeExited); + self.close_writer(); + self.join_writer(); + return; + } + let Some(writer_tx) = self.writer_tx.as_ref() else { + return; + }; + match writer_tx.try_send(bytes.to_vec()) { + Ok(()) => {} + Err(TrySendError::Full(_)) | Err(TrySendError::Disconnected(_)) => { + self.record_write_failure(); + } + } + } + + fn stop(&mut self, reason: BufferRuntimePipeStopReason) -> Result { + self.refresh(); + if self.exit_code.is_some() || self.stop_reason.is_some() { + return Err(MuxError::conflict("buffer pipe is not running")); + } + self.terminate(reason)?; + Ok(self.status()) + } + + fn writer_failed(&self) -> bool { + self.writer_state + .lock() + .map(|state| state.failed) + .unwrap_or(true) + } + + fn record_write_failure(&mut self) { + if let Ok(mut state) = self.writer_state.lock() { + state.failed = true; + } + self.close_writer(); + } + + fn close_writer(&mut self) { + let _ = self.writer_tx.take(); + } + + fn join_writer(&mut self) { + if let Some(writer_thread) = self.writer_thread.take() { + let _ = writer_thread.join(); + } + } + + fn terminate(&mut self, reason: BufferRuntimePipeStopReason) -> Result<()> { + self.stop_reason = Some(reason); + self.close_writer(); + if let Err(error) = self.child.kill() + && error.kind() != std::io::ErrorKind::InvalidInput + { + return Err(error.into()); + } + if let Ok(status) = self.child.wait() { + self.exit_code = exit_status_code(status.into()); + } + self.join_writer(); + Ok(()) + } +} + +impl Drop for KeeperPipe { + fn drop(&mut self) { + if self.exit_code.is_none() && self.stop_reason.is_none() { + self.close_writer(); + let _ = self.child.kill(); + let _ = self.child.wait(); + } + self.join_writer(); + } +} + /// Maximum retries for PTY allocation in runtime keeper const KEEPER_PTY_MAX_RETRIES: usize = 3; @@ -608,6 +904,7 @@ pub fn run_runtime_keeper(cli: RuntimeKeeperCli) -> Result<()> { master: Mutex::new(pair.master), writer: Mutex::new(writer), killer: Mutex::new(killer), + pipe: Mutex::new(None), sequence: AtomicU64::new(0), activity: Mutex::new(ActivityState::Idle), exit_code: Mutex::new(None), @@ -699,6 +996,15 @@ fn handle_keeper_request( KeeperResponse::ScrollbackSlice(runtime.scrollback_slice(start_line, line_count)?), false, )), + KeeperRequest::StartPipe { command, cwd, env } => Ok(( + KeeperResponse::PipeStatus(runtime.start_pipe(command, cwd, env)?), + false, + )), + KeeperRequest::StopPipe => Ok((KeeperResponse::PipeStatus(runtime.stop_pipe()?), false)), + KeeperRequest::StopPipeAfterExit => Ok(( + KeeperResponse::PipeStatus(runtime.stop_pipe_after_exit()?), + false, + )), KeeperRequest::Kill => { runtime.kill()?; Ok((KeeperResponse::Ok, false)) @@ -734,6 +1040,12 @@ impl KeeperRuntime { .map_err(|_| MuxError::internal("runtime keeper activity lock poisoned"))?; let sequence = self.sequence.load(Ordering::Relaxed); let title = surface.backend.metadata().title.clone(); + let pipe = self + .pipe + .lock() + .map_err(|_| MuxError::internal("runtime keeper pipe lock poisoned"))? + .as_mut() + .map(KeeperPipe::status); Ok(KeeperStatus { pid: self.pid, sequence, @@ -741,6 +1053,7 @@ impl KeeperRuntime { title, running: exit_code.is_none(), exit_code: exit_code.flatten(), + pipe, }) } @@ -811,6 +1124,75 @@ impl KeeperRuntime { .kill() .map_err(|error| MuxError::pty(error.to_string())) } + + fn start_pipe( + &self, + command: Vec, + cwd: Option, + env: BTreeMap, + ) -> Result { + validate_pipe_command(&command)?; + self.ensure_running()?; + let mut pipe = self + .pipe + .lock() + .map_err(|_| MuxError::internal("runtime keeper pipe lock poisoned"))?; + if let Some(existing) = pipe.as_mut() + && existing.status().running + { + return Err(MuxError::conflict("buffer pipe is already running")); + } + *pipe = Some(KeeperPipe::spawn(command, cwd, env)?); + Ok(pipe + .as_mut() + .expect("pipe slot populated after spawn") + .status()) + } + + fn stop_pipe(&self) -> Result { + self.stop_pipe_with_reason(BufferRuntimePipeStopReason::Requested, true) + } + + fn stop_pipe_after_exit(&self) -> Result { + self.stop_pipe_with_reason(BufferRuntimePipeStopReason::BufferExited, false) + } + + fn stop_pipe_with_reason( + &self, + reason: BufferRuntimePipeStopReason, + require_runtime_running: bool, + ) -> Result { + if require_runtime_running { + self.ensure_running()?; + } + let mut pipe = self + .pipe + .lock() + .map_err(|_| MuxError::internal("runtime keeper pipe lock poisoned"))?; + let Some(pipe) = pipe.as_mut() else { + return Err(MuxError::conflict("buffer pipe is not running")); + }; + if !pipe.status().running { + return Err(MuxError::conflict("buffer pipe is not running")); + } + pipe.stop(reason) + } +} + +fn keeper_pipe_write_loop( + mut stdin: ChildStdin, + writer_rx: std::sync::mpsc::Receiver>, + writer_state: Arc>, +) { + for bytes in writer_rx { + if stdin.write_all(&bytes).and_then(|()| stdin.flush()).is_ok() { + continue; + } + if let Ok(mut state) = writer_state.lock() { + state.failed = true; + } + break; + } } fn keeper_read_loop(runtime: Arc, mut reader: Box) { @@ -823,7 +1205,13 @@ fn keeper_read_loop(runtime: Arc, mut reader: Box surface, Err(_) => break, }; - let activity = surface.route_output(&buffer[..read]); + let bytes = &buffer[..read]; + let activity = surface.route_output(bytes); + if let Ok(mut pipe) = runtime.pipe.lock() + && let Some(pipe) = pipe.as_mut() + { + pipe.write(bytes); + } runtime.sequence.fetch_add(1, Ordering::Relaxed); if let Ok(mut state) = runtime.activity.lock() { *state = activity; @@ -840,6 +1228,45 @@ fn keeper_wait_loop(runtime: Arc, mut child: Box Result<()> { + if command.is_empty() { + return Err(MuxError::invalid_input( + "buffer pipe command must not be empty", + )); + } + if command[0].is_empty() { + return Err(MuxError::invalid_input( + "buffer pipe command first segment must not be empty", + )); + } + Ok(()) +} + +fn notify_pipe_removed( + callbacks: &BufferRuntimeCallbacks, + buffer_id: BufferId, + last_sequence: &mut u64, + activity: ActivityState, + last_pipe: &mut Option, + saw_exit: bool, +) { + if last_pipe.is_none() && saw_exit { + return; + } + *last_sequence = last_sequence.saturating_add(1); + (callbacks.on_output)( + buffer_id, + BufferRuntimeUpdate { + sequence: *last_sequence, + activity, + title: None, + pipe: Some(None), + }, + ); + *last_pipe = None; } fn spawn_status_poller( @@ -853,6 +1280,7 @@ fn spawn_status_poller( let mut last_sequence = initial.sequence; let mut last_title = initial.title.clone(); let mut last_activity = initial.activity; + let mut last_pipe = initial.pipe.clone(); let mut saw_exit = !initial.running; while !inner.stop.load(Ordering::Relaxed) { @@ -865,6 +1293,14 @@ fn spawn_status_poller( Ok(status) => status, Err(error) => { error!(%error, %inner.buffer_id, "status poll failed"); + notify_pipe_removed( + &callbacks, + inner.buffer_id, + &mut last_sequence, + last_activity, + &mut last_pipe, + saw_exit, + ); (callbacks.on_exit)(inner.buffer_id, None); break; } @@ -874,19 +1310,23 @@ fn spawn_status_poller( if status.sequence != last_sequence || status.title != last_title || status.activity != last_activity + || status.pipe != last_pipe { let title = (status.title != last_title).then(|| status.title.clone()); + let pipe = (status.pipe != last_pipe).then(|| status.pipe.clone()); (callbacks.on_output)( inner.buffer_id, BufferRuntimeUpdate { sequence: status.sequence, activity: status.activity, title, + pipe, }, ); last_sequence = status.sequence; last_title = status.title.clone(); last_activity = status.activity; + last_pipe = status.pipe.clone(); } if !saw_exit && !status.running { @@ -1099,6 +1539,7 @@ fn keeper_response_kind(response: &KeeperResponse) -> &'static str { KeeperResponse::Snapshot(_) => "snapshot", KeeperResponse::VisibleSnapshot(_) => "visible_snapshot", KeeperResponse::ScrollbackSlice(_) => "scrollback_slice", + KeeperResponse::PipeStatus(_) => "pipe_status", KeeperResponse::Ok => "ok", KeeperResponse::Error { .. } => "error", } @@ -1139,6 +1580,7 @@ impl Drop for SocketCleanup { #[cfg(test)] mod tests { + use std::collections::BTreeMap; use std::io::Write; use std::os::unix::net::UnixStream; use std::sync::Arc; @@ -1149,8 +1591,10 @@ mod tests { use embers_core::{ActivityState, BufferId, MuxError}; use super::{ - BufferRuntimeCallbacks, BufferRuntimeInner, BufferRuntimeStatus, KeeperConnection, - MAX_FRAME_SIZE, RuntimeThreads, read_message, spawn_status_poller, + BufferRuntimeCallbacks, BufferRuntimeHandle, BufferRuntimeInner, BufferRuntimePipeStatus, + BufferRuntimePipeStopReason, BufferRuntimeStatus, KEEPER_PIPE_WRITE_QUEUE_CAPACITY, + KeeperConnection, KeeperPipe, MAX_FRAME_SIZE, RuntimeThreads, read_message, + spawn_status_poller, }; #[test] @@ -1220,6 +1664,33 @@ mod tests { assert!(error.to_string().contains("out of range")); } + #[tokio::test] + async fn start_pipe_rejects_empty_program_before_runtime_request() { + let (stream, peer) = UnixStream::pair().expect("create socket pair"); + drop(peer); + let handle = BufferRuntimeHandle { + inner: Arc::new(BufferRuntimeInner { + buffer_id: BufferId(1), + pid: None, + socket_path: "/tmp/test-buffer.sock".into(), + connection: std::sync::Mutex::new(KeeperConnection { stream }), + stop: std::sync::atomic::AtomicBool::new(false), + threads: std::sync::Mutex::new(RuntimeThreads::default()), + }), + }; + + let error = handle + .start_pipe(vec![String::new()], None, BTreeMap::new()) + .await + .expect_err("empty pipe program should reject before keeper request"); + + assert!(matches!(error, MuxError::InvalidInput(_))); + assert_eq!( + error.to_string(), + "invalid input: buffer pipe command first segment must not be empty" + ); + } + #[test] fn status_poller_exits_on_status_error() { let (stream, peer) = UnixStream::pair().expect("create socket pair"); @@ -1237,10 +1708,10 @@ mod tests { let poller = spawn_status_poller( inner, BufferRuntimeCallbacks { - on_output: Arc::new(move |buffer_id, _| { + on_output: Arc::new(move |buffer_id, update| { output_tx - .send(buffer_id) - .expect("send unexpected output notification"); + .send((buffer_id, update)) + .expect("send output notification"); }), on_exit: Arc::new(move |buffer_id, exit_code| { exit_tx @@ -1255,6 +1726,7 @@ mod tests { title: None, running: true, exit_code: None, + pipe: None, }, ) .expect("spawn poller"); @@ -1267,6 +1739,104 @@ mod tests { .expect("poller should report exit"), (BufferId(1), None) ); - assert!(output_rx.try_recv().is_err()); + let (buffer_id, update) = output_rx + .recv_timeout(Duration::from_secs(1)) + .expect("poller should report final pipe removal"); + assert_eq!(buffer_id, BufferId(1)); + assert_eq!(update.pipe, Some(None)); + } + + #[test] + fn status_poller_reports_pipe_removal_before_status_error_exit() { + let (stream, peer) = UnixStream::pair().expect("create socket pair"); + drop(peer); + let inner = Arc::new(BufferRuntimeInner { + buffer_id: BufferId(2), + pid: None, + socket_path: "/tmp/test-buffer.sock".into(), + connection: std::sync::Mutex::new(KeeperConnection { stream }), + stop: std::sync::atomic::AtomicBool::new(false), + threads: std::sync::Mutex::new(RuntimeThreads::default()), + }); + let (exit_tx, exit_rx) = mpsc::channel(); + let (output_tx, output_rx) = mpsc::channel(); + let poller = spawn_status_poller( + inner, + BufferRuntimeCallbacks { + on_output: Arc::new(move |buffer_id, update| { + output_tx + .send((buffer_id, update)) + .expect("send output notification"); + }), + on_exit: Arc::new(move |buffer_id, exit_code| { + exit_tx + .send((buffer_id, exit_code)) + .expect("send exit notification"); + }), + }, + BufferRuntimeStatus { + pid: None, + sequence: 7, + activity: ActivityState::Idle, + title: None, + running: true, + exit_code: None, + pipe: Some(BufferRuntimePipeStatus { + command: vec!["tee".to_owned()], + running: true, + pid: Some(123), + exit_code: None, + stop_reason: None, + }), + }, + ) + .expect("spawn poller"); + + poller.join().expect("poller exits cleanly"); + + let (buffer_id, update) = output_rx + .recv_timeout(Duration::from_secs(1)) + .expect("poller should report pipe removal"); + assert_eq!(buffer_id, BufferId(2)); + assert_eq!(update.sequence, 8); + assert_eq!(update.pipe, Some(None)); + assert_eq!( + exit_rx + .recv_timeout(Duration::from_secs(1)) + .expect("poller should report exit"), + (BufferId(2), None) + ); + } + + #[test] + fn keeper_pipe_write_uses_bounded_queue() { + let mut pipe = KeeperPipe::spawn( + vec![ + "/bin/sh".to_owned(), + "-lc".to_owned(), + "sleep 30".to_owned(), + ], + None, + std::collections::BTreeMap::new(), + ) + .expect("spawn pipe"); + let payload = vec![b'x'; 128 * 1024]; + + let started = Instant::now(); + for _ in 0..(KEEPER_PIPE_WRITE_QUEUE_CAPACITY + 2) { + pipe.write(&payload); + } + + assert!( + started.elapsed() < Duration::from_secs(1), + "pipe writes should not block when the consumer is slow" + ); + + let status = pipe.status(); + assert!(!status.running); + assert_eq!( + status.stop_reason, + Some(BufferRuntimePipeStopReason::WriteFailed) + ); } } diff --git a/crates/embers-server/src/lib.rs b/crates/embers-server/src/lib.rs index b467d7c..8b3c5c3 100644 --- a/crates/embers-server/src/lib.rs +++ b/crates/embers-server/src/lib.rs @@ -9,13 +9,15 @@ mod server; mod terminal_backend; pub use buffer_runtime::{ - BufferRuntimeCallbacks, BufferRuntimeHandle, BufferRuntimeStatus, BufferRuntimeUpdate, - RuntimeKeeperCli, run_runtime_keeper, + BufferRuntimeCallbacks, BufferRuntimeHandle, BufferRuntimePipeStatus, + BufferRuntimePipeStopReason, BufferRuntimeStatus, BufferRuntimeUpdate, RuntimeKeeperCli, + run_runtime_keeper, }; pub use config::{SOCKET_ENV_VAR, ServerConfig}; pub use model::{ - Buffer, BufferAttachment, BufferState, BufferViewNode, BufferViewState, ExitedBuffer, - FloatingWindow, InterruptedBuffer, Node, RunningBuffer, Session, SplitNode, TabEntry, TabsNode, + Buffer, BufferAttachment, BufferPipe, BufferPipeState, BufferPipeStopReason, BufferState, + BufferViewNode, BufferViewState, ExitedBuffer, FloatingWindow, InterruptedBuffer, Node, + RunningBuffer, Session, SplitNode, TabEntry, TabsNode, }; pub use server::{Server, ServerHandle}; pub use state::ServerState; diff --git a/crates/embers-server/src/model.rs b/crates/embers-server/src/model.rs index 56dc91b..3e41f55 100644 --- a/crates/embers-server/src/model.rs +++ b/crates/embers-server/src/model.rs @@ -39,6 +39,32 @@ pub enum BufferKind { Helper(HelperBuffer), } +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct BufferPipe { + pub command: Vec, + pub state: BufferPipeState, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum BufferPipeState { + Running { + pid: Option, + }, + Stopped { + exit_code: Option, + reason: BufferPipeStopReason, + }, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum BufferPipeStopReason { + Requested, + PipeExited, + WriteFailed, + BufferExited, + RuntimeInterrupted, +} + #[derive(Clone)] pub struct Buffer { pub id: BufferId, @@ -53,6 +79,7 @@ pub struct Buffer { pub activity: ActivityState, pub last_snapshot_seq: u64, pub kind: BufferKind, + pub pipe: Option, pub created_at: Timestamp, } @@ -77,6 +104,7 @@ impl Buffer { activity: ActivityState::Idle, last_snapshot_seq: 0, kind: BufferKind::Pty, + pipe: None, created_at: Timestamp::now(), } } @@ -105,6 +133,7 @@ impl fmt::Debug for Buffer { .field("activity", &self.activity) .field("last_snapshot_seq", &self.last_snapshot_seq) .field("kind", &self.kind) + .field("pipe", &self.pipe) .field("created_at", &self.created_at) .finish() } @@ -123,6 +152,7 @@ impl PartialEq for Buffer { && self.activity == other.activity && self.last_snapshot_seq == other.last_snapshot_seq && self.kind == other.kind + && self.pipe == other.pipe && self.created_at == other.created_at } } diff --git a/crates/embers-server/src/protocol.rs b/crates/embers-server/src/protocol.rs index b2916a0..20e30bb 100644 --- a/crates/embers-server/src/protocol.rs +++ b/crates/embers-server/src/protocol.rs @@ -1,13 +1,15 @@ use embers_core::{MuxError, Result}; use embers_protocol::{ - BufferHistoryScope, BufferLocation, BufferRecord, BufferRecordKind, BufferRecordState, - BufferViewRecord, FloatingRecord, NodeRecord, NodeRecordKind, SessionRecord, SessionSnapshot, - SplitRecord, TabRecord, TabsRecord, + BufferHistoryScope, BufferLocation, BufferPipeRecord, BufferPipeState, + BufferPipeStopReason as ProtocolBufferPipeStopReason, BufferRecord, BufferRecordKind, + BufferRecordState, BufferViewRecord, FloatingRecord, NodeRecord, NodeRecordKind, SessionRecord, + SessionSnapshot, SplitRecord, TabRecord, TabsRecord, }; use crate::model::{ - Buffer, BufferAttachment, BufferKind, BufferState, FloatingWindow, HelperBufferScope, Node, - Session, + Buffer, BufferAttachment, BufferKind, BufferPipeState as ModelBufferPipeState, + BufferPipeStopReason as ModelBufferPipeStopReason, BufferState, FloatingWindow, + HelperBufferScope, Node, Session, }; use crate::state::ServerState; @@ -55,6 +57,39 @@ pub fn buffer_record(buffer: &Buffer) -> BufferRecord { .cwd .as_ref() .map(|path| path.to_string_lossy().into_owned()), + pipe: buffer.pipe.as_ref().map(|pipe| BufferPipeRecord { + command: pipe.command.clone(), + state: match pipe.state { + ModelBufferPipeState::Running { .. } => BufferPipeState::Running, + ModelBufferPipeState::Stopped { .. } => BufferPipeState::Stopped, + }, + pid: match pipe.state { + ModelBufferPipeState::Running { pid } => pid, + ModelBufferPipeState::Stopped { .. } => None, + }, + exit_code: match pipe.state { + ModelBufferPipeState::Running { .. } => None, + ModelBufferPipeState::Stopped { exit_code, .. } => exit_code, + }, + stop_reason: match pipe.state { + ModelBufferPipeState::Running { .. } => None, + ModelBufferPipeState::Stopped { reason, .. } => Some(match reason { + ModelBufferPipeStopReason::Requested => ProtocolBufferPipeStopReason::Requested, + ModelBufferPipeStopReason::PipeExited => { + ProtocolBufferPipeStopReason::PipeExited + } + ModelBufferPipeStopReason::WriteFailed => { + ProtocolBufferPipeStopReason::WriteFailed + } + ModelBufferPipeStopReason::BufferExited => { + ProtocolBufferPipeStopReason::BufferExited + } + ModelBufferPipeStopReason::RuntimeInterrupted => { + ProtocolBufferPipeStopReason::RuntimeInterrupted + } + }), + }, + }), kind, state, pid, diff --git a/crates/embers-server/src/server.rs b/crates/embers-server/src/server.rs index 0f930f7..05d012d 100644 --- a/crates/embers-server/src/server.rs +++ b/crates/embers-server/src/server.rs @@ -14,9 +14,9 @@ use embers_core::{ }; use embers_protocol::{ BufferCreatedEvent, BufferDetachedEvent, BufferHistoryPlacement, BufferHistoryScope, - BufferLocation, BufferLocationAttachment, BufferLocationResponse, BufferRequest, - BufferResponse, BufferWithLocationResponse, BuffersResponse, ClientChangedEvent, ClientMessage, - ClientRecord, ClientRequest, ClientResponse, ClientsResponse, ErrorResponse, + BufferLocation, BufferLocationAttachment, BufferLocationResponse, BufferPipeChangedEvent, + BufferRequest, BufferResponse, BufferWithLocationResponse, BuffersResponse, ClientChangedEvent, + ClientMessage, ClientRecord, ClientRequest, ClientResponse, ClientsResponse, ErrorResponse, FloatingChangedEvent, FloatingRequest, FloatingResponse, FocusChangedEvent, FrameType, InputRequest, NodeChangedEvent, OkResponse, PingResponse, ProtocolError, RawFrame, RenderInvalidatedEvent, ScrollbackSliceResponse, ServerEnvelope, ServerEvent, ServerResponse, @@ -31,14 +31,17 @@ use tokio::sync::{Mutex, Notify, mpsc, oneshot, watch}; use tokio::task::JoinHandle; use tracing::{Instrument, debug, error, info}; -use crate::model::{BufferKind, HelperBufferScope, Node}; +use crate::model::{ + BufferKind, BufferPipe, BufferPipeState, BufferPipeStopReason, HelperBufferScope, Node, +}; use crate::persist::{load_workspace, save_workspace}; use crate::protocol::{ buffer_location, buffer_record, floating_record, session_record, session_snapshot, }; use crate::{ - BufferAttachment, BufferRuntimeCallbacks, BufferRuntimeHandle, BufferRuntimeStatus, - BufferRuntimeUpdate, BufferState, ServerConfig, ServerState, TabEntry, + BufferAttachment, BufferRuntimeCallbacks, BufferRuntimeHandle, BufferRuntimePipeStatus, + BufferRuntimePipeStopReason, BufferRuntimeStatus, BufferRuntimeUpdate, BufferState, + ServerConfig, ServerState, TabEntry, }; #[derive(Debug)] @@ -470,7 +473,29 @@ impl Runtime { .attach_buffer_runtime(buffer.id, socket_path.clone()) .await { - Ok((runtime, status)) => { + Ok((runtime, mut status)) => { + if status.pipe.as_ref().is_some_and(|pipe| pipe.running) { + match runtime.stop_pipe().await { + Ok(pipe) => status.pipe = Some(pipe), + Err(error) => { + debug!( + %buffer.id, + socket_path = %socket_path.display(), + %error, + "failed to stop restored buffer pipe" + ); + let _ = runtime.kill().await; + status.pipe = None; + let mut state = self.state.lock().await; + let _ = state.set_buffer_runtime_socket_path(buffer.id, None); + let _ = state.mark_buffer_interrupted( + buffer.id, + buffer_pid_hint(&buffer.state), + ); + continue; + } + } + } let mut state = self.state.lock().await; let _ = state.set_buffer_runtime_socket_path(buffer.id, Some(socket_path.clone())); @@ -1197,6 +1222,29 @@ impl Runtime { } Err(error) => (mux_error_response(Some(request_id), error), Vec::new()), }, + BufferRequest::StartPipe { + request_id, + buffer_id, + command, + cwd, + env, + } => match self.start_buffer_pipe(buffer_id, command, cwd, env).await { + Ok((buffer, events)) => ( + ServerResponse::Buffer(BufferResponse { request_id, buffer }), + events, + ), + Err(error) => (mux_error_response(Some(request_id), error), Vec::new()), + }, + BufferRequest::StopPipe { + request_id, + buffer_id, + } => match self.stop_buffer_pipe(buffer_id).await { + Ok((buffer, events)) => ( + ServerResponse::Buffer(BufferResponse { request_id, buffer }), + events, + ), + Err(error) => (mux_error_response(Some(request_id), error), Vec::new()), + }, } } @@ -2008,6 +2056,100 @@ impl Runtime { Ok(()) } + async fn ensure_buffer_accepts_pipe(&self, buffer_id: BufferId) -> Result<()> { + let state = self.state.lock().await; + let buffer = state.buffer(buffer_id)?; + if matches!(&buffer.kind, BufferKind::Helper(_)) { + return Err(MuxError::conflict(format!( + "buffer {buffer_id} is read-only" + ))); + } + if matches!(buffer.state, BufferState::Exited(_)) { + return Err(MuxError::conflict(format!( + "buffer {buffer_id} has already exited" + ))); + } + Ok(()) + } + + async fn start_buffer_pipe( + &self, + buffer_id: BufferId, + command: Vec, + cwd: Option, + env: BTreeMap, + ) -> Result<(embers_protocol::BufferRecord, Vec)> { + if command.is_empty() { + return Err(MuxError::invalid_input( + "buffer pipe command must not be empty", + )); + } + self.ensure_buffer_accepts_pipe(buffer_id).await?; + let runtime = self.buffer_runtime(buffer_id).await?; + let pipe = runtime + .start_pipe(command, cwd.map(Into::into), env) + .await?; + let (buffer, session_id) = { + let mut state = self.state.lock().await; + let attachment = state.buffer(buffer_id)?.attachment.clone(); + { + let Some(record) = state.buffers.get_mut(&buffer_id) else { + return Err(MuxError::not_found(format!( + "buffer {buffer_id} was not found" + ))); + }; + record.pipe = Some(model_pipe_from_runtime(&pipe)); + } + let session_id = match attachment { + BufferAttachment::Attached(node_id) => Some(state.node(node_id)?.session_id()), + BufferAttachment::Detached => None, + }; + let buffer = buffer_record(state.buffer(buffer_id)?); + (buffer, session_id) + }; + Ok(( + buffer.clone(), + vec![ServerEvent::BufferPipeChanged(BufferPipeChangedEvent { + session_id, + buffer, + })], + )) + } + + async fn stop_buffer_pipe( + &self, + buffer_id: BufferId, + ) -> Result<(embers_protocol::BufferRecord, Vec)> { + self.ensure_buffer_accepts_pipe(buffer_id).await?; + let runtime = self.buffer_runtime(buffer_id).await?; + let pipe = runtime.stop_pipe().await?; + let (buffer, session_id) = { + let mut state = self.state.lock().await; + let attachment = state.buffer(buffer_id)?.attachment.clone(); + { + let Some(record) = state.buffers.get_mut(&buffer_id) else { + return Err(MuxError::not_found(format!( + "buffer {buffer_id} was not found" + ))); + }; + record.pipe = Some(model_pipe_from_runtime(&pipe)); + } + let session_id = match attachment { + BufferAttachment::Attached(node_id) => Some(state.node(node_id)?.session_id()), + BufferAttachment::Detached => None, + }; + let buffer = buffer_record(state.buffer(buffer_id)?); + (buffer, session_id) + }; + Ok(( + buffer.clone(), + vec![ServerEvent::BufferPipeChanged(BufferPipeChangedEvent { + session_id, + buffer, + })], + )) + } + async fn resolve_reveal_client_id( &self, connection_id: u64, @@ -2531,34 +2673,62 @@ impl Runtime { } async fn record_buffer_update(&self, buffer_id: BufferId, update: BufferRuntimeUpdate) { - let updated = { + let (render_invalidated, pipe_event) = { let mut state = self.state.lock().await; - let Some(buffer) = state.buffers.get_mut(&buffer_id) else { + let Some(existing) = state.buffers.get(&buffer_id) else { return; }; - if update.sequence <= buffer.last_snapshot_seq { - false - } else { - buffer.last_snapshot_seq = update.sequence; - buffer.activity = max_activity(buffer.activity, update.activity); - if let Some(title) = update.title { - match title { - Some(title) => buffer.title = title, - None => buffer.title.clear(), + let previous_pipe = existing.pipe.clone(); + let mut render_invalidated = false; + { + let buffer = state + .buffers + .get_mut(&buffer_id) + .expect("buffer still exists while update is applied"); + let sequence_advanced = update.sequence > buffer.last_snapshot_seq; + let sequence_current = update.sequence >= buffer.last_snapshot_seq; + if sequence_advanced { + buffer.last_snapshot_seq = update.sequence; + let next_activity = max_activity(buffer.activity, update.activity); + if next_activity != buffer.activity { + buffer.activity = next_activity; + } + if let Some(title) = update.title { + let next_title = title.unwrap_or_default(); + if next_title != buffer.title { + buffer.title = next_title; + } + } + render_invalidated = true; + } + if sequence_current && let Some(pipe) = update.pipe { + let next_pipe = pipe.as_ref().map(model_pipe_from_runtime); + if next_pipe != buffer.pipe { + buffer.pipe = next_pipe; + render_invalidated = true; } } - true } + let pipe_event = match state.buffer(buffer_id) { + Ok(buffer) if buffer.pipe != previous_pipe => { + buffer_pipe_changed_event(&state, buffer_id).ok() + } + _ => None, + }; + (render_invalidated, pipe_event) }; - if updated { - self.broadcast( - vec![ServerEvent::RenderInvalidated(RenderInvalidatedEvent { - buffer_id, - })], - &[], - ) - .await; + let mut events = Vec::new(); + if render_invalidated { + events.push(ServerEvent::RenderInvalidated(RenderInvalidatedEvent { + buffer_id, + })); + } + if let Some(pipe_event) = pipe_event { + events.push(ServerEvent::BufferPipeChanged(pipe_event)); + } + if !events.is_empty() { + self.broadcast(events, &[]).await; } } @@ -2568,7 +2738,12 @@ impl Runtime { let runtime = self.buffer_runtimes.lock().await.remove(&buffer_id); drop(runtime); } - let updated = { + let runtime = if should_interrupt { + None + } else { + self.buffer_runtimes.lock().await.get(&buffer_id).cloned() + }; + let (updated, pipe_changed) = { let mut state = self.state.lock().await; let result = if should_interrupt { let pid = state @@ -2580,22 +2755,59 @@ impl Runtime { state.mark_buffer_exited(buffer_id, exit_code) }; match result { - Ok(()) => true, + Ok(()) => { + let pipe_reason = if should_interrupt { + BufferPipeStopReason::RuntimeInterrupted + } else { + BufferPipeStopReason::BufferExited + }; + let pipe_changed = state + .buffers + .get_mut(&buffer_id) + .is_some_and(|buffer| stop_buffer_pipe(buffer, pipe_reason, exit_code)); + (true, pipe_changed) + } Err(error) => { debug!(%buffer_id, %error, "buffer exited after state cleanup"); - false + (false, false) } } }; + if updated && pipe_changed && !should_interrupt { + match runtime { + Some(runtime) => match runtime.stop_pipe_after_exit().await { + Ok(pipe) => { + let mut state = self.state.lock().await; + if let Some(buffer) = state.buffers.get_mut(&buffer_id) { + buffer.pipe = Some(model_pipe_from_runtime(&pipe)); + } + } + Err(error) => { + debug!(%buffer_id, %error, "failed to stop buffer pipe after buffer exit"); + } + }, + None => { + debug!(%buffer_id, "buffer exited with a running pipe but no runtime handle"); + } + } + } + + let pipe_event = if updated && pipe_changed { + let state = self.state.lock().await; + buffer_pipe_changed_event(&state, buffer_id).ok() + } else { + None + }; + if updated { - self.broadcast( - vec![ServerEvent::RenderInvalidated(RenderInvalidatedEvent { - buffer_id, - })], - &[], - ) - .await; + let mut events = vec![ServerEvent::RenderInvalidated(RenderInvalidatedEvent { + buffer_id, + })]; + if let Some(pipe_event) = pipe_event { + events.push(ServerEvent::BufferPipeChanged(pipe_event)); + } + self.broadcast(events, &[]).await; } } @@ -2611,6 +2823,7 @@ impl Runtime { sequence: status.sequence, activity: status.activity, title: Some(status.title.clone()), + pipe: Some(status.pipe.clone()), }, ) .await; @@ -3128,6 +3341,58 @@ fn protocol_error_to_mux(error: ProtocolError) -> MuxError { MuxError::protocol(error.to_string()) } +fn model_pipe_from_runtime(status: &BufferRuntimePipeStatus) -> BufferPipe { + BufferPipe { + command: status.command.clone(), + state: if status.running { + BufferPipeState::Running { pid: status.pid } + } else { + BufferPipeState::Stopped { + exit_code: status.exit_code, + reason: match status + .stop_reason + .unwrap_or(BufferRuntimePipeStopReason::PipeExited) + { + BufferRuntimePipeStopReason::Requested => BufferPipeStopReason::Requested, + BufferRuntimePipeStopReason::PipeExited => BufferPipeStopReason::PipeExited, + BufferRuntimePipeStopReason::WriteFailed => BufferPipeStopReason::WriteFailed, + BufferRuntimePipeStopReason::BufferExited => BufferPipeStopReason::BufferExited, + }, + } + }, + } +} + +fn stop_buffer_pipe( + buffer: &mut crate::Buffer, + reason: BufferPipeStopReason, + exit_code: Option, +) -> bool { + let Some(pipe) = buffer.pipe.as_mut() else { + return false; + }; + if matches!(pipe.state, BufferPipeState::Stopped { .. }) { + return false; + } + pipe.state = BufferPipeState::Stopped { exit_code, reason }; + true +} + +fn buffer_pipe_changed_event( + state: &ServerState, + buffer_id: BufferId, +) -> Result { + let buffer = state.buffer(buffer_id)?; + let session_id = match buffer.attachment { + BufferAttachment::Attached(node_id) => Some(state.node(node_id)?.session_id()), + BufferAttachment::Detached => None, + }; + Ok(BufferPipeChangedEvent { + session_id, + buffer: buffer_record(buffer), + }) +} + fn apply_runtime_status( state: &mut ServerState, buffer_id: BufferId, @@ -3135,6 +3400,7 @@ fn apply_runtime_status( ) { if let Some(buffer) = state.buffers.get_mut(&buffer_id) { buffer.last_snapshot_seq = status.sequence; + buffer.pipe = status.pipe.as_ref().map(model_pipe_from_runtime); } if let Some(title) = &status.title { let _ = state.set_buffer_title(buffer_id, title.clone()); @@ -3190,7 +3456,10 @@ mod tests { use super::{Runtime, ShutdownSignal, Subscription, wait_for_shutdown}; use crate::model::HelperBufferScope; - use crate::{BufferRuntimeUpdate, BufferState, ServerState}; + use crate::{ + BufferRuntimePipeStatus, BufferRuntimePipeStopReason, BufferRuntimeUpdate, BufferState, + ServerState, + }; use tokio::time::{Duration, timeout}; @@ -3383,6 +3652,7 @@ mod tests { sequence: 5, activity: ActivityState::Bell, title: Some(Some("stale-title".to_owned())), + pipe: None, }, ) .await; @@ -3406,6 +3676,7 @@ mod tests { sequence: 6, activity: ActivityState::Bell, title: Some(Some("fresh-title".to_owned())), + pipe: None, }, ) .await; @@ -3427,6 +3698,131 @@ mod tests { )); } + #[tokio::test] + async fn record_buffer_update_ignores_stale_pipe_updates() { + let runtime = Runtime::new( + ServerState::new(), + PathBuf::from("server.sock"), + PathBuf::from("workspace"), + PathBuf::from("runtime"), + BTreeMap::new(), + ); + let buffer_id = { + let mut state = runtime.state.lock().await; + let buffer_id = state.create_buffer("current-title", vec!["/bin/sh".to_owned()], None); + let buffer = state + .buffers + .get_mut(&buffer_id) + .expect("buffer is created"); + buffer.last_snapshot_seq = 5; + buffer_id + }; + let (sender, mut receiver) = mpsc::unbounded_channel(); + runtime.subscriptions.lock().await.insert( + 1, + Subscription { + connection_id: 1, + session_id: None, + sender, + }, + ); + + runtime + .record_buffer_update( + buffer_id, + BufferRuntimeUpdate { + sequence: 4, + activity: ActivityState::Idle, + title: None, + pipe: Some(Some(BufferRuntimePipeStatus { + command: vec!["tee".to_owned()], + running: false, + pid: None, + exit_code: Some(0), + stop_reason: Some(BufferRuntimePipeStopReason::PipeExited), + })), + }, + ) + .await; + + let buffer = runtime + .state + .lock() + .await + .buffer(buffer_id) + .expect("buffer exists") + .clone(); + assert!(buffer.pipe.is_none()); + assert!(receiver.try_recv().is_err()); + } + + #[tokio::test] + async fn record_buffer_update_applies_pipe_only_updates() { + let runtime = Runtime::new( + ServerState::new(), + PathBuf::from("server.sock"), + PathBuf::from("workspace"), + PathBuf::from("runtime"), + BTreeMap::new(), + ); + let buffer_id = { + let mut state = runtime.state.lock().await; + let buffer_id = state.create_buffer("current-title", vec!["/bin/sh".to_owned()], None); + let buffer = state + .buffers + .get_mut(&buffer_id) + .expect("buffer is created"); + buffer.last_snapshot_seq = 5; + buffer_id + }; + let (sender, mut receiver) = mpsc::unbounded_channel(); + runtime.subscriptions.lock().await.insert( + 1, + Subscription { + connection_id: 1, + session_id: None, + sender, + }, + ); + + runtime + .record_buffer_update( + buffer_id, + BufferRuntimeUpdate { + sequence: 5, + activity: ActivityState::Idle, + title: None, + pipe: Some(Some(BufferRuntimePipeStatus { + command: vec!["tee".to_owned()], + running: false, + pid: None, + exit_code: Some(0), + stop_reason: Some(BufferRuntimePipeStopReason::PipeExited), + })), + }, + ) + .await; + + let buffer = runtime + .state + .lock() + .await + .buffer(buffer_id) + .expect("buffer exists") + .clone(); + assert!(buffer.pipe.is_some()); + assert!(matches!( + receiver.try_recv(), + Ok(ServerEnvelope::Event(ServerEvent::RenderInvalidated(event))) + if event.buffer_id == buffer_id + )); + assert!(matches!( + receiver.try_recv(), + Ok(ServerEnvelope::Event(ServerEvent::BufferPipeChanged(event))) + if event.buffer.id == buffer_id + )); + } + #[tokio::test] async fn record_buffer_update_clears_title() { let runtime = Runtime::new( @@ -3454,6 +3850,7 @@ mod tests { sequence: 6, activity: ActivityState::Idle, title: Some(None), + pipe: None, }, ) .await; diff --git a/crates/embers-server/tests/client_sessions.rs b/crates/embers-server/tests/client_sessions.rs index 07374ee..333748a 100644 --- a/crates/embers-server/tests/client_sessions.rs +++ b/crates/embers-server/tests/client_sessions.rs @@ -1,9 +1,11 @@ use std::num::NonZeroU64; +use std::time::Instant; -use embers_core::{RequestId, init_test_tracing}; +use embers_core::{BufferId, RequestId, init_test_tracing}; use embers_protocol::{ - ClientMessage, ClientRequest, ClientResponse, ClientsResponse, ProtocolClient, ServerEnvelope, - ServerEvent, ServerResponse, SessionRequest, SessionSnapshotResponse, SubscribeRequest, + BufferRequest, BufferResponse, ClientMessage, ClientRequest, ClientResponse, ClientsResponse, + InputRequest, ProtocolClient, ServerEnvelope, ServerEvent, ServerResponse, SessionRequest, + SessionSnapshotResponse, SnapshotResponse, SubscribeRequest, }; use embers_server::{Server, ServerConfig}; use tempfile::tempdir; @@ -45,6 +47,58 @@ async fn request_clients(client: &mut ProtocolClient, request: ClientRequest) -> } } +async fn request_buffer(client: &mut ProtocolClient, request: BufferRequest) -> BufferResponse { + match client + .request(&ClientMessage::Buffer(request)) + .await + .expect("buffer request succeeds") + { + ServerResponse::Buffer(response) => response, + other => panic!("expected buffer response, got {other:?}"), + } +} + +async fn capture_buffer( + client: &mut ProtocolClient, + request_id: RequestId, + buffer_id: BufferId, +) -> SnapshotResponse { + match client + .request(&ClientMessage::Buffer(BufferRequest::Capture { + request_id, + buffer_id, + })) + .await + .expect("capture request succeeds") + { + ServerResponse::Snapshot(response) => response, + other => panic!("expected snapshot response, got {other:?}"), + } +} + +async fn wait_for_snapshot_line( + client: &mut ProtocolClient, + request_id: RequestId, + buffer_id: BufferId, + expected: &str, +) -> SnapshotResponse { + let deadline = Instant::now() + Duration::from_secs(2); + loop { + let snapshot = capture_buffer(client, request_id, buffer_id).await; + if snapshot.lines.iter().any(|line| line.contains(expected)) { + return snapshot; + } + if Instant::now() >= deadline { + break; + } + tokio::time::sleep(Duration::from_millis(25)).await; + } + + panic!( + "capture for buffer {buffer_id} did not contain expected line '{expected}' before timeout" + ); +} + async fn recv_event(client: &mut ProtocolClient) -> ServerEvent { let envelope = timeout(Duration::from_secs(2), client.recv()) .await @@ -210,3 +264,83 @@ async fn closing_session_clears_client_binding_and_retires_session_subscriptions handle.shutdown().await.expect("shutdown server"); } + +#[tokio::test] +async fn concurrent_input_from_multiple_clients_reaches_shared_buffer() { + init_test_tracing(); + + let tempdir = tempdir().expect("tempdir"); + let socket_path = tempdir.path().join("mux.sock"); + let handle = Server::new(ServerConfig::new(socket_path.clone())) + .start() + .await + .expect("start server"); + + let mut client_a = ProtocolClient::connect(&socket_path) + .await + .expect("connect client A"); + let mut client_b = ProtocolClient::connect(&socket_path) + .await + .expect("connect client B"); + + let buffer = request_buffer( + &mut client_a, + BufferRequest::Create { + request_id: RequestId(21), + title: Some("shared".to_owned()), + command: vec![ + "/bin/sh".to_owned(), + "-lc".to_owned(), + "printf 'ready\\n'; while IFS= read -r line; do printf 'seen:%s\\n' \"$line\"; done" + .to_owned(), + ], + cwd: None, + env: Default::default(), + }, + ) + .await + .buffer; + let buffer_id = buffer.id; + + let _ = wait_for_snapshot_line(&mut client_a, RequestId(22), buffer_id, "ready").await; + + let send_a_message = ClientMessage::Input(InputRequest::Send { + request_id: RequestId(23), + buffer_id, + bytes: b"from-a\n".to_vec(), + }); + let send_b_message = ClientMessage::Input(InputRequest::Send { + request_id: RequestId(24), + buffer_id, + bytes: b"from-b\n".to_vec(), + }); + let send_a = client_a.request(&send_a_message); + let send_b = client_b.request(&send_b_message); + let (response_a, response_b) = tokio::join!(send_a, send_b); + assert!(matches!( + response_a.expect("client A input succeeds"), + ServerResponse::Ok(_) + )); + assert!(matches!( + response_b.expect("client B input succeeds"), + ServerResponse::Ok(_) + )); + + let capture = + wait_for_snapshot_line(&mut client_a, RequestId(25), buffer_id, "seen:from-a").await; + let capture_text = capture.lines.join("\n"); + if capture_text.contains("seen:from-b") { + handle.shutdown().await.expect("shutdown server"); + return; + } + + let capture = + wait_for_snapshot_line(&mut client_a, RequestId(26), buffer_id, "seen:from-b").await; + let capture_text = capture.lines.join("\n"); + assert!( + capture_text.contains("seen:from-a"), + "capture should retain both client inputs, got {capture_text:?}" + ); + + handle.shutdown().await.expect("shutdown server"); +} diff --git a/crates/embers-server/tests/persistence.rs b/crates/embers-server/tests/persistence.rs index dc1d7ff..0b5e5a5 100644 --- a/crates/embers-server/tests/persistence.rs +++ b/crates/embers-server/tests/persistence.rs @@ -73,6 +73,33 @@ async fn wait_for_snapshot_line( ); } +async fn wait_for_running_buffer( + client: &mut ProtocolClient, + request_id: RequestId, + buffer_id: BufferId, +) -> embers_protocol::BufferRecord { + let deadline = Instant::now() + Duration::from_secs(2); + loop { + let response = request_buffer( + client, + BufferRequest::Get { + request_id, + buffer_id, + }, + ) + .await; + if response.buffer.state == BufferRecordState::Running { + return response.buffer; + } + if Instant::now() >= deadline { + break; + } + sleep(Duration::from_millis(25)).await; + } + + panic!("buffer {buffer_id} did not reach Running before timeout"); +} + #[tokio::test] async fn clean_restart_restores_workspace_and_keeps_live_buffers_running() { init_test_tracing(); @@ -248,3 +275,85 @@ async fn clean_restart_restores_workspace_and_keeps_live_buffers_running() { handle.shutdown().await.expect("shutdown restarted server"); } + +#[tokio::test] +async fn clean_restart_stops_ephemeral_buffer_pipes() { + init_test_tracing(); + + let tempdir = tempdir().expect("tempdir"); + let socket_path = tempdir.path().join("mux.sock"); + let config = ServerConfig::new(socket_path.clone()); + + let handle = Server::new(config.clone()) + .start() + .await + .expect("start server"); + let mut client = ProtocolClient::connect(&socket_path) + .await + .expect("connect client"); + + let buffer = request_buffer( + &mut client, + BufferRequest::Create { + request_id: RequestId(101), + title: Some("piped".to_owned()), + command: vec!["/bin/sh".to_owned()], + cwd: None, + env: Default::default(), + }, + ) + .await + .buffer; + let buffer_id = buffer.id; + let _ = wait_for_running_buffer(&mut client, RequestId(102), buffer_id).await; + + let started = request_buffer( + &mut client, + BufferRequest::StartPipe { + request_id: RequestId(103), + buffer_id, + command: vec!["/bin/cat".to_owned()], + cwd: None, + env: Default::default(), + }, + ) + .await + .buffer; + assert!( + started + .pipe + .as_ref() + .is_some_and(|pipe| pipe.state == embers_protocol::BufferPipeState::Running) + ); + + handle.shutdown().await.expect("shutdown server"); + + let handle = Server::new(config).start().await.expect("restart server"); + let mut client = ProtocolClient::connect(&socket_path) + .await + .expect("reconnect client"); + + let restored = request_buffer( + &mut client, + BufferRequest::Get { + request_id: RequestId(104), + buffer_id, + }, + ) + .await + .buffer; + assert_eq!(restored.state, BufferRecordState::Running); + let restored_pipe = restored + .pipe + .expect("restored buffer keeps stopped pipe metadata"); + assert_eq!( + restored_pipe.state, + embers_protocol::BufferPipeState::Stopped + ); + assert_eq!( + restored_pipe.stop_reason, + Some(embers_protocol::BufferPipeStopReason::Requested) + ); + + handle.shutdown().await.expect("shutdown restarted server"); +}