diff --git a/s3/src/bucket.rs b/s3/src/bucket.rs index 70568b619b..87b6d82628 100644 --- a/s3/src/bucket.rs +++ b/s3/src/bucket.rs @@ -1025,17 +1025,8 @@ impl Bucket { pub async fn object_exists>(&self, path: S) -> Result { let command = Command::HeadObject; let request = RequestImpl::new(self, path.as_ref(), command).await?; - let response_data = match request.response_data(false).await { - Ok(response_data) => response_data, - Err(S3Error::HttpFailWithBody(status_code, error)) => { - if status_code == 404 { - return Ok(false); - } - return Err(S3Error::HttpFailWithBody(status_code, error)); - } - Err(e) => return Err(e), - }; - Ok(response_data.status_code() != 404) + let status_code = request.response_status().await?; + Ok(status_code != 404) } #[maybe_async::maybe_async] @@ -3113,11 +3104,75 @@ mod test { use crate::{Bucket, PostPolicy}; use http::header::{CACHE_CONTROL, HeaderMap, HeaderName, HeaderValue}; use std::env; + #[cfg(all(not(feature = "sync"), feature = "with-tokio"))] + use std::io::{Read, Write}; + #[cfg(all(not(feature = "sync"), feature = "with-tokio"))] + use std::net::TcpListener; + #[cfg(all(not(feature = "sync"), feature = "with-tokio"))] + use std::sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, + }; + #[cfg(all(not(feature = "sync"), feature = "with-tokio"))] + use std::thread; fn init() { let _ = env_logger::builder().is_test(true).try_init(); } + #[cfg(all(not(feature = "sync"), feature = "with-tokio"))] + #[tokio::test] + async fn test_object_exists_404_does_not_retry() { + init(); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let endpoint = format!("http://{}", listener.local_addr().unwrap()); + let requests = Arc::new(AtomicUsize::new(0)); + let request_count = Arc::clone(&requests); + + let server = thread::spawn(move || { + let (mut stream, _) = listener.accept().unwrap(); + request_count.fetch_add(1, Ordering::SeqCst); + + let mut buffer = [0; 2048]; + let _ = stream.read(&mut buffer).unwrap(); + stream + .write_all( + b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\nConnection: close\r\n\r\n", + ) + .unwrap(); + }); + + crate::set_retries(1); + + let credentials = Credentials::new( + Some("test_access_key"), + Some("test_secret_key"), + None, + None, + None, + ) + .unwrap(); + let bucket = Bucket::new( + "test-bucket", + Region::Custom { + region: "us-east-1".to_owned(), + endpoint, + }, + credentials, + ) + .unwrap() + .with_path_style(); + + let exists = bucket.object_exists("/missing.txt").await.unwrap(); + + crate::set_retries(1); + server.join().unwrap(); + + assert!(!exists); + assert_eq!(requests.load(Ordering::SeqCst), 1); + } + fn test_aws_credentials() -> Credentials { Credentials::new( Some(&env::var("EU_AWS_ACCESS_KEY_ID").unwrap()), diff --git a/s3/src/request/async_std_backend.rs b/s3/src/request/async_std_backend.rs index c0a345d93c..bededa3efc 100644 --- a/s3/src/request/async_std_backend.rs +++ b/s3/src/request/async_std_backend.rs @@ -84,6 +84,47 @@ impl<'a> Request for SurfRequest<'a> { Ok(response) } + async fn response_status(&self) -> Result { + crate::retry! { + async { + let headers = self.headers().await?; + + let request = match self.command.http_verb() { + HttpMethod::Get => surf::Request::builder(Method::Get, self.url()?), + HttpMethod::Delete => surf::Request::builder(Method::Delete, self.url()?), + HttpMethod::Put => surf::Request::builder(Method::Put, self.url()?), + HttpMethod::Post => surf::Request::builder(Method::Post, self.url()?), + HttpMethod::Head => surf::Request::builder(Method::Head, self.url()?), + }; + + let mut request = request.body(self.request_body()?); + + for (name, value) in headers.iter() { + request = request.header( + HeaderName::from_bytes(AsRef::<[u8]>::as_ref(&name).to_vec()) + .expect("Could not parse heaeder name"), + HeaderValue::from_bytes(AsRef::<[u8]>::as_ref(&value).to_vec()) + .expect("Could not parse header value"), + ); + } + + let response = request + .send() + .await + .map_err(|e| S3Error::Surf(e.to_string()))?; + let status = u16::from(response.status()); + + if status == 404 { + Ok(status) + } else if cfg!(feature = "fail-on-err") && !response.status().is_success() { + Err(S3Error::HttpFail) + } else { + Ok(status) + } + }.await + } + } + async fn response_data(&self, etag: bool) -> Result { let mut response = crate::retry! {self.response().await}?; let status_code = response.status(); diff --git a/s3/src/request/blocking.rs b/s3/src/request/blocking.rs index ac065fd898..78e8bcf424 100644 --- a/s3/src/request/blocking.rs +++ b/s3/src/request/blocking.rs @@ -79,6 +79,43 @@ impl<'a> Request for AttoRequest<'a> { Ok(response) } + fn response_status(&self) -> Result { + crate::retry! { + { + let headers = self.headers()?; + let mut session = attohttpc::Session::new(); + + for (name, value) in headers.iter() { + session.header(HeaderName::from_bytes(name.as_ref())?, value.to_str()?); + } + + if let Some(timeout) = self.bucket.request_timeout { + session.timeout(timeout) + } + + let request = match self.command.http_verb() { + HttpMethod::Get => session.get(self.url()?), + HttpMethod::Delete => session.delete(self.url()?), + HttpMethod::Put => session.put(self.url()?), + HttpMethod::Post => session.post(self.url()?), + HttpMethod::Head => session.head(self.url()?), + }; + + let response = request.bytes(&self.request_body()?).send()?; + let status = response.status().as_u16(); + + if status == 404 { + Ok(status) + } else if cfg!(feature = "fail-on-err") && !response.status().is_success() { + let text = response.text()?; + Err(S3Error::HttpFailWithBody(status, text)) + } else { + Ok(status) + } + } + } + } + fn response_data(&self, etag: bool) -> Result { let response = crate::retry! {self.response()}?; let status_code = response.status().as_u16(); diff --git a/s3/src/request/request_trait.rs b/s3/src/request/request_trait.rs index 5686cc8dda..29769b1e48 100644 --- a/s3/src/request/request_trait.rs +++ b/s3/src/request/request_trait.rs @@ -210,6 +210,10 @@ pub trait Request { #[cfg(any(feature = "with-async-std", feature = "with-tokio"))] async fn response_data_to_stream(&self) -> Result; async fn response_header(&self) -> Result<(Self::HeaderMap, u16), S3Error>; + async fn response_status(&self) -> Result { + let (_, status_code) = self.response_header().await?; + Ok(status_code) + } fn datetime(&self) -> OffsetDateTime; fn bucket(&self) -> Bucket; fn command(&self) -> Command<'_>; diff --git a/s3/src/request/tokio_backend.rs b/s3/src/request/tokio_backend.rs index 498f19135b..34924aa481 100644 --- a/s3/src/request/tokio_backend.rs +++ b/s3/src/request/tokio_backend.rs @@ -117,6 +117,56 @@ impl<'a> Request for ReqwestRequest<'a> { Ok(response) } + async fn response_status(&self) -> Result { + retry! { + async { + let headers = self + .headers() + .await? + .iter() + .map(|(k, v)| { + ( + reqwest::header::HeaderName::from_str(k.as_str()), + reqwest::header::HeaderValue::from_str(v.to_str().unwrap_or_default()), + ) + }) + .filter(|(k, v)| k.is_ok() && v.is_ok()) + .map(|(k, v)| (k.unwrap(), v.unwrap())) + .collect(); + + let client = self.bucket.http_client(); + + let method = match self.command.http_verb() { + HttpMethod::Delete => reqwest::Method::DELETE, + HttpMethod::Get => reqwest::Method::GET, + HttpMethod::Post => reqwest::Method::POST, + HttpMethod::Put => reqwest::Method::PUT, + HttpMethod::Head => reqwest::Method::HEAD, + }; + + let request = client + .request(method, self.url()?.as_str()) + .headers(headers) + .body(self.request_body()?); + + let request = request.build()?; + let response = client.execute(request).await?; + let status = response.status().as_u16(); + + if status == 404 { + return Ok(status); + } + + if cfg!(feature = "fail-on-err") && !response.status().is_success() { + let text = response.text().await?; + return Err(S3Error::HttpFailWithBody(status, text)); + } + + Ok(status) + }.await + } + } + async fn response_data(&self, etag: bool) -> Result { let response = retry! {self.response().await }?; let status_code = response.status().as_u16();