Skip to content
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 105 additions & 12 deletions src/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
use crate::StatusCode;
use async_lock::Semaphore;
use async_std::{
net::TcpListener,
net::{TcpListener, TcpStream},
sync::Arc,
task::{sleep, spawn},
};
use async_trait::async_trait;
use derivative::Derivative;
use futures::stream::StreamExt;
use futures::{select, stream::StreamExt, FutureExt};
use std::{
fmt::{self, Display, Formatter},
io::{self, ErrorKind},
Expand All @@ -26,6 +26,8 @@ use tide::{
Server,
};



/// TCP listener which accepts only a limited number of connections at a time.
///
/// This listener is based on `tide::listener::TcpListener` and should match the semantics of that
Expand Down Expand Up @@ -107,17 +109,26 @@ where
spawn(async move {
let local_addr = stream.local_addr().ok();
let peer_addr = stream.peer_addr().ok();
let detect_stream = stream.clone();

let fut = async_h1::accept(stream, |mut req| async {
// Handle the request if we can get a permit.
if let Some(_guard) = permit.try_acquire() {
req.set_local_addr(local_addr);
req.set_peer_addr(peer_addr);
app.respond(req).await
} else {
// Otherwise, we are rate limited. Respond immediately with an
// error.
Ok(http::Response::new(StatusCode::TOO_MANY_REQUESTS))
let fut = async_h1::accept(stream, |mut req| {
let detect_stream = detect_stream.clone();
let permit = permit.clone();
let app = app.clone();
async move {
if let Some(_guard) = permit.try_acquire() {
req.set_local_addr(local_addr);
req.set_peer_addr(peer_addr);
select! {
result = app.respond(req).fuse() => result,
_ = wait_for_disconnect(detect_stream).fuse() => {
tracing::debug!("handler cancelled due to client disconnect");
Err(tide::Error::from(io::Error::from(ErrorKind::ConnectionAborted)))
}
}
} else {
Ok(http::Response::new(StatusCode::TOO_MANY_REQUESTS))
}
}
});

Expand All @@ -139,6 +150,24 @@ where
}
}

/// Resolves when the client disconnects by polling the stream with `peek`.
async fn wait_for_disconnect(stream: TcpStream) {
let mut buf = [0u8; 1];
loop {
match stream.peek(&mut buf).await {
Ok(0) => {
tracing::debug!("client disconnected (EOF on peek)");
return;
}
Err(e) => {
tracing::debug!(%e, "client disconnected (error on peek)");
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does an error in peek necessarily mean a client disconnect? If it's possible that future reads from the socket would succeed even after this failure (ie a transient failure), then we should not disconnect here

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure, but I can match the exact error types

return;
}
Ok(_) => sleep(Duration::from_millis(100)).await,
Comment thread
imabdulbasit marked this conversation as resolved.
}
}
}

impl<State> ToListener<State> for RateLimitListener<State>
where
State: Clone + Send + Sync + 'static,
Expand Down Expand Up @@ -179,6 +208,10 @@ mod test {
};
use futures::future::{try_join_all, FutureExt};
use portpicker::pick_unused_port;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use toml::toml;
use vbs::version::{StaticVersion, StaticVersionType};

Expand Down Expand Up @@ -234,4 +267,64 @@ mod test {
assert_eq!(StatusCode::OK, res.status());
}
}

/// When a client disconnects , the handler future should be dropped
/// rather than running to completion.
#[async_std::test]
async fn test_handler_dropped_on_client_disconnect() {

let handler_completed = Arc::new(AtomicBool::new(false));

let mut app = App::<_, ServerError>::with_state(());
let api_toml = toml! {
[route.slow]
PATH = ["/slow"]
METHOD = "GET"
};
{
let flag = handler_completed.clone();
let mut api = app
.module::<ServerError, StaticVer01>("mod", api_toml)
.unwrap();
api.get("slow", move |_req, _state| {
let flag = flag.clone();
async move {
sleep(Duration::from_secs(5)).await;
flag.store(true, Ordering::SeqCst);
Ok(())
}
.boxed()
})
.unwrap();
}

let port = pick_unused_port().unwrap();
spawn(app.serve(
RateLimitListener::with_port(port, 10),
StaticVer01::instance(),
));


sleep(Duration::from_secs(1)).await;


let req_task = spawn(async move {
reqwest::Client::new()
.get(format!("http://localhost:{port}/mod/slow"))
.send()
.await
});


sleep(Duration::from_millis(200)).await;


req_task.cancel().await;

sleep(Duration::from_secs(6)).await;

assert!(
!handler_completed.load(Ordering::SeqCst),
);
Comment thread
imabdulbasit marked this conversation as resolved.
Outdated
}
}
Loading