Skip to content
172 changes: 163 additions & 9 deletions crates/request-sharing/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,29 +88,36 @@ where
}

fn spawn_gc(cache: Cache<Request, Fut>, label: String) {
let weak = Arc::downgrade(&cache);
tokio::task::spawn(async move {
loop {
Self::collect_garbage(&cache, &label);
tokio::time::sleep(Duration::from_millis(500)).await;
if let Some(cache) = weak.upgrade() {
Self::collect_garbage(&cache, &label);
} else {
return;
}
Comment thread
metalurgical marked this conversation as resolved.
}
});
}
}

impl<A, B: Future> Drop for RequestSharing<A, B> {
fn drop(&mut self) {
Metrics::get()
.request_sharing_cached_items
.with_label_values(&[&self.request_label])
.set(0);
if Arc::strong_count(&self.in_flight) == 1 {
Metrics::get()
.request_sharing_cached_items
.with_label_values(&[&self.request_label])
.set(0);
}
}
}

/// Returns a shallow copy (without any pending requests)
/// Returns a shallow copy sharing the same in-flight request cache.
impl<Request, Fut: Future> Clone for RequestSharing<Request, Fut> {
fn clone(&self) -> Self {
Self {
in_flight: Default::default(),
in_flight: self.in_flight.clone(),
request_label: self.request_label.clone(),
}
}
Expand Down Expand Up @@ -182,11 +189,11 @@ impl Metrics {

#[cfg(test)]
mod tests {
use super::*;
use {super::*, tokio::runtime::Handle};

#[tokio::test]
async fn shares_request() {
// Manually create [`RequestSharing`] so we can have fine grained control
// Manually create [`RequestSharing`] so we can have fine-grain control
// over the garbage collection.
let cache: Cache<u64, BoxFuture<u64>> = Default::default();
let label = "test".to_string();
Expand Down Expand Up @@ -217,4 +224,151 @@ mod tests {
// GC deleted all now unused futures.
assert!(sharing.in_flight.lock().unwrap().is_empty());
}

#[tokio::test]
async fn in_flight_futures_cache_is_shared_from_origin() {
let cache: Cache<u64, BoxFuture<u64>> = Default::default();
let label = "future sharing".to_string();
let original = RequestSharing {
in_flight: cache,
request_label: label.clone(),
};

// Create the origin future
let origin_future = original.shared_or_else(0, |_| futures::future::ready(1u64).boxed());
assert!(!origin_future.is_shared);

// The clone should use the original request future, instead of new assignment
let cloned = original.clone();
let shared_future = cloned.shared_or_else(0, |_| {
async { panic!("future cache is not shared") }.boxed()
});

// Check origin is reused in shared
assert!(shared_future.is_shared);

// Check same value is reached
assert_eq!(origin_future.await, 1);
assert_eq!(shared_future.await, 1);
}

#[tokio::test]
async fn in_flight_futures_cache_is_shared_from_clone() {
let cache: Cache<u64, BoxFuture<u64>> = Default::default();
let label = "future sharing".to_string();
let original = RequestSharing {
in_flight: cache,
request_label: label.clone(),
};
let cloned = original.clone();

// Create the future on clone
let cloned_future = cloned.shared_or_else(0, |_| futures::future::ready(1u64).boxed());
assert!(!cloned_future.is_shared);

// Origin should use the cloned request future, instead of new assignment
let origin_future = original.shared_or_else(0, |_| {
async { panic!("future cache is not shared") }.boxed()
});
assert!(origin_future.is_shared);

// Check same value is yielded
assert_eq!(cloned_future.await, 1);
assert_eq!(origin_future.await, 1);
}

#[tokio::test]
async fn gc_cleans_entries_on_clones() {
let cache: Cache<u64, BoxFuture<u64>> = Default::default();
let label = "gc shared".to_string();
let original = RequestSharing {
in_flight: cache,
request_label: label.clone(),
};
let cloned = original.clone();

// Create future via the clone and immediately await it.
let _pending1 = original
.shared_or_else(0, |_| futures::future::ready(0u64).boxed())
.await;

// Create a second future and don't await it, later assertion requires it to be
// unpolled to survive GC.
let _pending2 = original.shared_or_else(1, |_| futures::future::ready(0u64).boxed());

// Run GC
RequestSharing::collect_garbage(&original.in_flight, &label);

//Check GC
assert_eq!(cloned.in_flight.lock().unwrap().len(), 1);
assert!(!cloned.in_flight.lock().unwrap().contains_key(&0u64));
assert!(cloned.in_flight.lock().unwrap().contains_key(&1u64));
}

#[tokio::test]
async fn drop_does_not_corrupt_existing_entries() {
let cache: Cache<u64, BoxFuture<u64>> = Default::default();
let label = "drop".to_string();
let original = RequestSharing {
in_flight: cache,
request_label: label.clone(),
};
let pending = original.shared_or_else(0, |_| futures::future::ready(1u64).boxed());
{
let cloned = original.clone();
let cloned_future = cloned.shared_or_else(0, |_| {
async { panic!("future cache is not shared") }.boxed()
});
// Check cloned_future is shared
assert!(cloned_future.is_shared);
} // drop occurs here

// Check future in cache and that future still yields value
assert_eq!(original.in_flight.lock().unwrap().len(), 1);
assert_eq!(pending.await, 1);
}

#[tokio::test]
async fn gc_task_exits_when_all_handles_dropped() {
let initial_task_count = Handle::current().metrics().num_alive_tasks();

{
let _sharing = RequestSharing::<u64, BoxFuture<u64>>::labelled("gc finish".to_string());
// Yield to let the spawned GC task register.
tokio::task::yield_now().await;
assert_eq!(
Handle::current().metrics().num_alive_tasks(),
initial_task_count + 1
);
} // drop occurs here

tokio::time::sleep(Duration::from_millis(600)).await;

let final_task_count = Handle::current().metrics().num_alive_tasks();
assert_eq!(initial_task_count, final_task_count);
}

#[tokio::test]
async fn gauge_on_clone_drop_not_zeroed() {
let cache: Cache<u64, BoxFuture<u64>> = Default::default();
let label = "gauge".to_string();
let original = RequestSharing {
in_flight: cache,
request_label: label.clone(),
};

let _pending = original.shared_or_else(0, |_| futures::future::ready(1u64).boxed());
assert_eq!(Arc::strong_count(&original.in_flight), 1);

{
let _cloned = original.clone();
assert_eq!(Arc::strong_count(&original.in_flight), 2);
} // drop occurs here

assert_eq!(Arc::strong_count(&original.in_flight), 1);

// Since _pending remains, will still have an entry
assert_eq!(original.in_flight.lock().unwrap().len(), 1);
drop(original); // will now zero, exact value only testable in integration test
}
}
Loading