diff --git a/Cargo.lock b/Cargo.lock index 5162512..ab62f83 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3618,7 +3618,7 @@ dependencies = [ [[package]] name = "tide-disco" -version = "0.9.6" +version = "0.9.7" dependencies = [ "anyhow", "ark-serialize", diff --git a/Cargo.toml b/Cargo.toml index 34d8920..fb848fd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tide-disco" -version = "0.9.6" +version = "0.9.7" edition = "2021" authors = ["Espresso Systems "] description = "Discoverability for Tide" diff --git a/src/listener.rs b/src/listener.rs index da63e6b..ab97e10 100644 --- a/src/listener.rs +++ b/src/listener.rs @@ -7,13 +7,13 @@ use crate::StatusCode; use async_lock::Semaphore; use async_std::{ - net::TcpListener, + net::{TcpListener, TcpStream}, sync::Arc, task::{sleep, spawn}, }; use async_trait::async_trait; use derivative::Derivative; -use futures::stream::StreamExt; +use futures::{select, stream::StreamExt, FutureExt}; use std::{ fmt::{self, Display, Formatter}, io::{self, ErrorKind}, @@ -107,17 +107,26 @@ where spawn(async move { let local_addr = stream.local_addr().ok(); let peer_addr = stream.peer_addr().ok(); + let detect_stream = stream.clone(); - let fut = async_h1::accept(stream, |mut req| async { - // Handle the request if we can get a permit. - if let Some(_guard) = permit.try_acquire() { - req.set_local_addr(local_addr); - req.set_peer_addr(peer_addr); - app.respond(req).await - } else { - // Otherwise, we are rate limited. Respond immediately with an - // error. - Ok(http::Response::new(StatusCode::TOO_MANY_REQUESTS)) + let fut = async_h1::accept(stream, |mut req| { + let detect_stream = detect_stream.clone(); + let permit = permit.clone(); + let app = app.clone(); + async move { + if let Some(_guard) = permit.try_acquire() { + req.set_local_addr(local_addr); + req.set_peer_addr(peer_addr); + select! { + result = app.respond(req).fuse() => result, + _ = wait_for_disconnect(detect_stream).fuse() => { + tracing::debug!("handler cancelled due to client disconnect"); + Err(tide::Error::from(io::Error::from(ErrorKind::ConnectionAborted))) + } + } + } else { + Ok(http::Response::new(StatusCode::TOO_MANY_REQUESTS)) + } } }); @@ -139,6 +148,24 @@ where } } +/// Resolves when the client disconnects by polling the stream with `peek`. +async fn wait_for_disconnect(stream: TcpStream) { + let mut buf = [0u8; 1]; + loop { + match stream.peek(&mut buf).await { + Ok(0) => { + tracing::debug!("client disconnected (EOF on peek)"); + return; + } + Err(e) => { + tracing::debug!(%e, "client disconnected (error on peek)"); + return; + } + Ok(_) => sleep(Duration::from_millis(100)).await, + } + } +} + impl ToListener for RateLimitListener where State: Clone + Send + Sync + 'static, @@ -179,6 +206,10 @@ mod test { }; use futures::future::{try_join_all, FutureExt}; use portpicker::pick_unused_port; + use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }; use toml::toml; use vbs::version::{StaticVersion, StaticVersionType}; @@ -234,4 +265,57 @@ mod test { assert_eq!(StatusCode::OK, res.status()); } } + + /// When a client disconnects , the handler future should be dropped + /// rather than running to completion. + #[async_std::test] + async fn test_handler_dropped_on_client_disconnect() { + let handler_completed = Arc::new(AtomicBool::new(false)); + + let mut app = App::<_, ServerError>::with_state(()); + let api_toml = toml! { + [route.slow] + PATH = ["/slow"] + METHOD = "GET" + }; + { + let flag = handler_completed.clone(); + let mut api = app + .module::("mod", api_toml) + .unwrap(); + api.get("slow", move |_req, _state| { + let flag = flag.clone(); + async move { + sleep(Duration::from_secs(5)).await; + flag.store(true, Ordering::SeqCst); + Ok(()) + } + .boxed() + }) + .unwrap(); + } + + let port = pick_unused_port().unwrap(); + spawn(app.serve( + RateLimitListener::with_port(port, 10), + StaticVer01::instance(), + )); + + sleep(Duration::from_secs(1)).await; + + let req_task = spawn(async move { + reqwest::Client::new() + .get(format!("http://localhost:{port}/mod/slow")) + .send() + .await + }); + + sleep(Duration::from_millis(200)).await; + + req_task.cancel().await; + + sleep(Duration::from_secs(6)).await; + + assert!(!handler_completed.load(Ordering::SeqCst),); + } }