diff --git a/crates/request-sharing/src/lib.rs b/crates/request-sharing/src/lib.rs index 35083a90b0..968fd30661 100644 --- a/crates/request-sharing/src/lib.rs +++ b/crates/request-sharing/src/lib.rs @@ -88,10 +88,19 @@ where } fn spawn_gc(cache: Cache, label: String) { + let weak = Arc::downgrade(&cache); tokio::task::spawn(async move { loop { - Self::collect_garbage(&cache, &label); tokio::time::sleep(Duration::from_millis(500)).await; + if let Some(cache) = weak.upgrade() { + Self::collect_garbage(&cache, &label); + } else { + Metrics::get() + .request_sharing_cached_items + .with_label_values(&[label]) + .set(0); + return; + } } }); } @@ -99,18 +108,20 @@ where impl Drop for RequestSharing { fn drop(&mut self) { - Metrics::get() - .request_sharing_cached_items - .with_label_values(&[&self.request_label]) - .set(0); + if Arc::strong_count(&self.in_flight) == 1 { + Metrics::get() + .request_sharing_cached_items + .with_label_values(&[&self.request_label]) + .set(0); + } } } -/// Returns a shallow copy (without any pending requests) +/// Returns a shallow copy sharing the same in-flight request cache. impl Clone for RequestSharing { fn clone(&self) -> Self { Self { - in_flight: Default::default(), + in_flight: self.in_flight.clone(), request_label: self.request_label.clone(), } } @@ -182,11 +193,11 @@ impl Metrics { #[cfg(test)] mod tests { - use super::*; + use {super::*, tokio::runtime::Handle}; #[tokio::test] async fn shares_request() { - // Manually create [`RequestSharing`] so we can have fine grained control + // Manually create [`RequestSharing`] so we can have fine-grain control // over the garbage collection. let cache: Cache> = Default::default(); let label = "test".to_string(); @@ -217,4 +228,150 @@ mod tests { // GC deleted all now unused futures. assert!(sharing.in_flight.lock().unwrap().is_empty()); } + + #[tokio::test] + async fn in_flight_futures_cache_is_shared_from_origin() { + let cache: Cache> = Default::default(); + let label = "future sharing".to_string(); + let original = RequestSharing { + in_flight: cache, + request_label: label.clone(), + }; + + // Create the origin future + let origin_future = original.shared_or_else(0, |_| futures::future::ready(1u64).boxed()); + assert!(!origin_future.is_shared); + + // The clone should use the original request future, instead of new assignment + let cloned = original.clone(); + let shared_future = cloned.shared_or_else(0, |_| { + async { panic!("future cache is not shared") }.boxed() + }); + + // Check origin is reused in shared + assert!(shared_future.is_shared); + + // Check same value is reached + assert_eq!(origin_future.await, 1); + assert_eq!(shared_future.await, 1); + } + + #[tokio::test] + async fn in_flight_futures_cache_is_shared_from_clone() { + let cache: Cache> = Default::default(); + let label = "future sharing".to_string(); + let original = RequestSharing { + in_flight: cache, + request_label: label.clone(), + }; + let cloned = original.clone(); + + // Create the future on clone + let cloned_future = cloned.shared_or_else(0, |_| futures::future::ready(1u64).boxed()); + assert!(!cloned_future.is_shared); + + // Origin should use the cloned request future, instead of new assignment + let origin_future = original.shared_or_else(0, |_| { + async { panic!("future cache is not shared") }.boxed() + }); + assert!(origin_future.is_shared); + + // Check same value is yielded + assert_eq!(cloned_future.await, 1); + assert_eq!(origin_future.await, 1); + } + + #[tokio::test] + async fn gc_cleans_entries_on_clones() { + let cache: Cache> = Default::default(); + let label = "gc shared".to_string(); + let original = RequestSharing { + in_flight: cache, + request_label: label.clone(), + }; + let cloned = original.clone(); + + // Create future via the clone and immediately await it. + let _pending1 = original + .shared_or_else(0, |_| futures::future::ready(0u64).boxed()) + .await; + + // Create a second future and don't await it, later assertion requires it to be + // unpolled to survive GC. + let _pending2 = original.shared_or_else(1, |_| futures::future::ready(0u64).boxed()); + + // Run GC + RequestSharing::collect_garbage(&original.in_flight, &label); + + //Check GC + assert_eq!(cloned.in_flight.lock().unwrap().len(), 1); + assert!(!cloned.in_flight.lock().unwrap().contains_key(&0u64)); + assert!(cloned.in_flight.lock().unwrap().contains_key(&1u64)); + } + + #[tokio::test] + async fn dropping_clone_does_not_corrupt_existing_entries() { + let cache: Cache> = Default::default(); + let label = "drop".to_string(); + let original = RequestSharing { + in_flight: cache, + request_label: label.clone(), + }; + let pending = original.shared_or_else(0, |_| futures::future::ready(1u64).boxed()); + { + let cloned = original.clone(); + let cloned_future = cloned.shared_or_else(0, |_| { + async { panic!("future cache is not shared") }.boxed() + }); + // Check cloned_future is shared + assert!(cloned_future.is_shared); + } // drop occurs here + + // Check future in cache and that future still yields value + assert_eq!(original.in_flight.lock().unwrap().len(), 1); + assert_eq!(pending.await, 1); + } + + #[tokio::test] + async fn gc_task_exits_when_all_handles_dropped() { + let initial_task_count = Handle::current().metrics().num_alive_tasks(); + + { + let _sharing = RequestSharing::>::labelled("gc finish".to_string()); + // Yield to let the spawned GC task register. + tokio::task::yield_now().await; + assert_eq!( + Handle::current().metrics().num_alive_tasks(), + initial_task_count + 1 + ); + } // drop occurs here + + tokio::time::sleep(Duration::from_millis(600)).await; + + let final_task_count = Handle::current().metrics().num_alive_tasks(); + assert_eq!(initial_task_count, final_task_count); + } + + #[tokio::test] + async fn dropping_clone_does_not_affect_shared_cache() { + let cache: Cache> = Default::default(); + let label = "gauge".to_string(); + let original = RequestSharing { + in_flight: cache, + request_label: label.clone(), + }; + + let _pending = original.shared_or_else(0, |_| futures::future::ready(1u64).boxed()); + assert_eq!(Arc::strong_count(&original.in_flight), 1); + + { + let _cloned = original.clone(); + assert_eq!(Arc::strong_count(&original.in_flight), 2); + } // drop occurs here + + assert_eq!(Arc::strong_count(&original.in_flight), 1); + + // Since _pending remains, will still have an entry + assert_eq!(original.in_flight.lock().unwrap().len(), 1); + } }