diff --git a/src/active_query.rs b/src/active_query.rs index 9eaf82160..f9d170b3d 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -8,7 +8,7 @@ use crate::accumulator::{ }; use crate::hash::FxIndexSet; use crate::key::DatabaseKeyIndex; -use crate::runtime::Stamp; +use crate::runtime::{CancellationCount, Stamp}; use crate::sync::atomic::AtomicBool; use crate::tracked_struct::{Disambiguator, DisambiguatorMap, IdentityHash, IdentityMap}; use crate::zalsa_local::{ @@ -70,6 +70,9 @@ pub(crate) struct ActiveQuery { /// If this query is a cycle head, iteration count of that cycle. iteration_count: IterationCount, + + /// Cancellation count captured when this query run started. + pub(crate) cancellation_count: CancellationCount, } impl ActiveQuery { @@ -179,7 +182,11 @@ impl ActiveQuery { } impl ActiveQuery { - fn new(database_key_index: DatabaseKeyIndex, iteration_count: IterationCount) -> Self { + fn new( + database_key_index: DatabaseKeyIndex, + iteration_count: IterationCount, + cancellation_count: CancellationCount, + ) -> Self { ActiveQuery { database_key_index, durability: Durability::MAX, @@ -190,6 +197,7 @@ impl ActiveQuery { tracked_struct_ids: Default::default(), cycle_heads: Default::default(), iteration_count, + cancellation_count, #[cfg(feature = "accumulator")] accumulated: Default::default(), #[cfg(feature = "accumulator")] @@ -208,6 +216,7 @@ impl ActiveQuery { ref mut tracked_struct_ids, ref mut cycle_heads, iteration_count, + cancellation_count, #[cfg(feature = "accumulator")] ref mut accumulated, #[cfg(feature = "accumulator")] @@ -241,6 +250,7 @@ impl ActiveQuery { #[cfg(feature = "accumulator")] accumulated_inputs, verified_final: AtomicBool::new(verified_final), + cancellation_count, extra, }; @@ -261,6 +271,7 @@ impl ActiveQuery { tracked_struct_ids, cycle_heads, iteration_count, + cancellation_count: _, #[cfg(feature = "accumulator")] accumulated, #[cfg(feature = "accumulator")] @@ -279,6 +290,7 @@ impl ActiveQuery { &mut self, new_database_key_index: DatabaseKeyIndex, new_iteration_count: IterationCount, + new_cancellation_count: CancellationCount, ) { let Self { database_key_index, @@ -290,6 +302,7 @@ impl ActiveQuery { tracked_struct_ids, cycle_heads, iteration_count, + cancellation_count, #[cfg(feature = "accumulator")] accumulated, #[cfg(feature = "accumulator")] @@ -300,6 +313,7 @@ impl ActiveQuery { *changed_at = Revision::start(); *untracked_read = false; *iteration_count = new_iteration_count; + *cancellation_count = new_cancellation_count; debug_assert!( input_outputs.is_empty(), "`ActiveQuery::clear` or `ActiveQuery::into_revisions` should've been called" @@ -369,12 +383,16 @@ impl QueryStack { &mut self, database_key_index: DatabaseKeyIndex, iteration_count: IterationCount, + cancellation_count: CancellationCount, ) { if self.len < self.stack.len() { - self.stack[self.len].reset_for(database_key_index, iteration_count); + self.stack[self.len].reset_for(database_key_index, iteration_count, cancellation_count); } else { - self.stack - .push(ActiveQuery::new(database_key_index, iteration_count)); + self.stack.push(ActiveQuery::new( + database_key_index, + iteration_count, + cancellation_count, + )); } self.len += 1; } diff --git a/src/cycle.rs b/src/cycle.rs index d2ab35d0e..b10ce4db1 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -48,6 +48,7 @@ use std::iter::FusedIterator; use thin_vec::{ThinVec, thin_vec}; use crate::key::DatabaseKeyIndex; +use crate::runtime::CancellationCount; use crate::sync::OnceLock; use crate::sync::atomic::{AtomicBool, AtomicU8, Ordering}; use crate::{Id, Revision}; @@ -499,11 +500,13 @@ pub enum ProvisionalStatus<'db> { Provisional { iteration: IterationCount, verified_at: Revision, + cancellation_count: CancellationCount, cycle_heads: &'db CycleHeads, }, Final { iteration: IterationCount, verified_at: Revision, + cancellation_count: CancellationCount, }, } diff --git a/src/function.rs b/src/function.rs index 2c04f1aaa..0a59dba69 100644 --- a/src/function.rs +++ b/src/function.rs @@ -388,11 +388,13 @@ where ProvisionalStatus::Final { iteration, verified_at: memo.verified_at.load(), + cancellation_count: memo.revisions.cancellation_count, } } else { ProvisionalStatus::Provisional { iteration, verified_at: memo.verified_at.load(), + cancellation_count: memo.revisions.cancellation_count, cycle_heads: memo.cycle_heads(), } }) diff --git a/src/function/execute.rs b/src/function/execute.rs index 1cd7fa79f..e6684f0ba 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -125,6 +125,7 @@ where let database_key_index = claim_guard.database_key_index(); let zalsa = claim_guard.zalsa(); + let cancellation_count = zalsa.runtime().cancellation_count(); let id = database_key_index.key_index(); @@ -136,7 +137,9 @@ where let mut iteration_count = IterationCount::initial(); if let Some(old_memo) = opt_old_memo { - if old_memo.verified_at.load() == zalsa.current_revision() { + if old_memo.verified_at.load() == zalsa.current_revision() + && old_memo.revisions.cancellation_count == cancellation_count + { // The `DependencyGraph` locking propagates panics when another thread is blocked on a panicking query. // However, the locking doesn't handle the case where a thread fetches the result of a panicking // cycle head query **after** all locks were released. That's what we do here. @@ -162,8 +165,13 @@ where } } - let _poison_guard = - PoisonProvisionalIfPanicking::new(self, zalsa, id, memo_ingredient_index); + let _poison_guard = PoisonProvisionalIfPanicking::new( + self, + zalsa, + id, + memo_ingredient_index, + cancellation_count, + ); let (new_value, completed_query) = loop { let active_query = claim_guard @@ -386,6 +394,7 @@ where // * ensure the final returned memo depends on all inputs from all iterations. if old_memo.may_be_provisional() && old_memo.verified_at.load() == zalsa.current_revision() + && old_memo.revisions.cancellation_count == active_query.cancellation_count() { active_query.seed_iteration(&old_memo.revisions); } @@ -423,6 +432,7 @@ struct PoisonProvisionalIfPanicking<'a, C: Configuration> { zalsa: &'a Zalsa, id: Id, memo_ingredient_index: MemoIngredientIndex, + cancellation_count: crate::runtime::CancellationCount, } impl<'a, C: Configuration> PoisonProvisionalIfPanicking<'a, C> { @@ -431,12 +441,14 @@ impl<'a, C: Configuration> PoisonProvisionalIfPanicking<'a, C> { zalsa: &'a Zalsa, id: Id, memo_ingredient_index: MemoIngredientIndex, + cancellation_count: crate::runtime::CancellationCount, ) -> Self { Self { ingredient, zalsa, id, memo_ingredient_index, + cancellation_count, } } } @@ -447,6 +459,7 @@ impl Drop for PoisonProvisionalIfPanicking<'_, C> { let revisions = QueryRevisions::fixpoint_initial( self.ingredient.database_key_index(self.id), IterationCount::initial(), + self.cancellation_count, ); let memo = Memo::new(None, self.zalsa.current_revision(), revisions); diff --git a/src/function/fetch.rs b/src/function/fetch.rs index ba70acf46..0363ce90f 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -183,6 +183,10 @@ where }) }, CycleRecoveryStrategy::Fixpoint | CycleRecoveryStrategy::FallbackImmediate => { + let cancellation_count = zalsa_local + .active_query_cancellation_count() + .unwrap_or_else(|| zalsa_local.cancellation_count()); + // check if there's a provisional value for this query // Note we don't `validate_may_be_provisional` the memo here as we want to reuse an // existing provisional memo if it exists @@ -193,6 +197,7 @@ where // on the value OR a concurrent `Vec` for cycle heads. if memo.verified_at.load() == zalsa.current_revision() && memo.value.is_some() + && memo.revisions.cancellation_count == cancellation_count && memo.revisions.cycle_heads().contains(&database_key_index) { memo.revisions @@ -219,6 +224,7 @@ where .and_then(|old_memo| { if old_memo.verified_at.load() == zalsa.current_revision() && old_memo.value.is_some() + && old_memo.revisions.cancellation_count == cancellation_count { Some(old_memo.revisions.iteration()) } else { @@ -226,7 +232,11 @@ where } }) .unwrap_or(IterationCount::initial()); - let revisions = QueryRevisions::fixpoint_initial(database_key_index, iteration); + let revisions = QueryRevisions::fixpoint_initial( + database_key_index, + iteration, + cancellation_count, + ); let initial_value = C::cycle_initial(db, id, C::id_to_input(zalsa, id)); self.insert_memo( diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index e6ea963eb..ab02393eb 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -7,6 +7,7 @@ use crate::function::{Configuration, IngredientImpl, Reentrancy}; use std::sync::atomic::Ordering; use crate::key::DatabaseKeyIndex; +use crate::runtime::CancellationCount; use crate::zalsa::{MemoIngredientIndex, Zalsa, ZalsaDatabase}; use crate::zalsa_local::{QueryEdge, QueryEdgeKind, QueryOriginRef, QueryRevisions, ZalsaLocal}; use crate::{Id, Revision}; @@ -308,18 +309,21 @@ where ); let verified_at = memo.verified_at.load(); + let cancellation_count = memo.revisions.cancellation_count; validate_provisional( zalsa, database_key_index, &memo.revisions, verified_at, + cancellation_count, cycle_heads, ) || validate_same_iteration( zalsa, zalsa_local, database_key_index, verified_at, + cancellation_count, cycle_heads, ) } @@ -519,6 +523,7 @@ fn validate_provisional( database_key_index: DatabaseKeyIndex, memo_revisions: &QueryRevisions, memo_verified_at: Revision, + memo_cancellation_count: CancellationCount, cycle_heads: &CycleHeads, ) -> bool { crate::tracing::trace!("{database_key_index:?}: validate_provisional({database_key_index:?})",); @@ -537,10 +542,12 @@ fn validate_provisional( ProvisionalStatus::Final { iteration, verified_at, + cancellation_count, .. } => { // Only consider the cycle head if it is from the same revision as the memo - if verified_at != memo_verified_at { + if verified_at != memo_verified_at || cancellation_count != memo_cancellation_count + { return false; } @@ -571,6 +578,7 @@ fn validate_same_iteration( zalsa_local: &ZalsaLocal, memo_database_key_index: DatabaseKeyIndex, memo_verified_at: Revision, + memo_cancellation_count: CancellationCount, cycle_heads: &CycleHeads, ) -> bool { crate::tracing::trace!("validate_same_iteration({memo_database_key_index:?})",); @@ -582,6 +590,10 @@ fn validate_same_iteration( return false; } + if zalsa_local.active_query_cancellation_count() != Some(memo_cancellation_count) { + return false; + } + // Always return `false` for cycle initial values "unless" they are running in the same thread. if cycle_heads .iter_not_eq(memo_database_key_index) @@ -591,10 +603,10 @@ fn validate_same_iteration( // SAFETY: We do not access the query stack reentrantly. let on_stack = unsafe { zalsa_local.with_query_stack_unchecked(|stack| { - stack - .iter() - .rev() - .any(|query| query.database_key_index == memo_database_key_index) + stack.iter().rev().any(|query| { + query.database_key_index == memo_database_key_index + && query.cancellation_count == memo_cancellation_count + }) }) }; @@ -609,8 +621,11 @@ fn validate_same_iteration( head_iteration_count, memo_iteration_count: current_iteration_count, verified_at: head_verified_at, + cancellation_count: head_cancellation_count, } => { - if head_verified_at != memo_verified_at { + if head_verified_at != memo_verified_at + || head_cancellation_count != memo_cancellation_count + { return false; } diff --git a/src/function/memo.rs b/src/function/memo.rs index e4b243319..a26443653 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -10,6 +10,7 @@ use crate::function::{Configuration, IngredientImpl}; use crate::ingredient::WaitForResult; use crate::key::DatabaseKeyIndex; use crate::revision::AtomicRevision; +use crate::runtime::CancellationCount; use crate::sync::atomic::Ordering; use crate::table::memo::MemoTableWithTypesMut; use crate::zalsa::{MemoIngredientIndex, Zalsa}; @@ -358,6 +359,7 @@ pub(super) enum TryClaimHeadsResult { head_iteration_count: IterationCount, memo_iteration_count: IterationCount, verified_at: Revision, + cancellation_count: CancellationCount, }, /// The cycle head is not finalized, but it can be claimed. @@ -404,22 +406,26 @@ impl Iterator for TryClaimCycleHeadsIter<'_> { let provisional_status = ingredient .provisional_status(self.zalsa, head_key_index) .expect("cycle head memo to exist"); - let (current_iteration_count, verified_at) = match provisional_status { - ProvisionalStatus::Provisional { - iteration, - verified_at, - cycle_heads: _, - } => (iteration, verified_at), - ProvisionalStatus::Final { - iteration, - verified_at, - } => (iteration, verified_at), - }; + let (current_iteration_count, verified_at, cancellation_count) = + match provisional_status { + ProvisionalStatus::Provisional { + iteration, + verified_at, + cancellation_count, + cycle_heads: _, + } => (iteration, verified_at, cancellation_count), + ProvisionalStatus::Final { + iteration, + verified_at, + cancellation_count, + } => (iteration, verified_at, cancellation_count), + }; Some(TryClaimHeadsResult::Cycle { memo_iteration_count: current_iteration_count, head_iteration_count: head.iteration_count.load(), verified_at, + cancellation_count, }) } WaitForResult::Running(running) => { @@ -446,7 +452,7 @@ mod _memory_usage { // Memo's are stored a lot, make sure their size doesn't randomly increase. const _: [(); std::mem::size_of::>()] = - [(); std::mem::size_of::<[usize; 6]>()]; + [(); std::mem::size_of::<[usize; 7]>()]; struct DummyStruct; diff --git a/src/function/specify.rs b/src/function/specify.rs index 99539d375..47f2634ef 100644 --- a/src/function/specify.rs +++ b/src/function/specify.rs @@ -72,6 +72,9 @@ where #[cfg(feature = "accumulator")] accumulated_inputs: Default::default(), verified_final: AtomicBool::new(true), + cancellation_count: zalsa_local + .active_query_cancellation_count() + .expect("specify should be called from an active query"), extra: QueryRevisionsExtra::default(), }, stale_tracked_structs: Vec::new(), diff --git a/src/runtime.rs b/src/runtime.rs index 4ec8fee03..a33e4d74f 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -8,9 +8,29 @@ use crate::sync::thread::{self, ThreadId}; use crate::table::Table; use crate::zalsa::Zalsa; use crate::{Cancelled, Event, EventKind, Revision}; +use std::sync::Arc; +use std::sync::atomic::AtomicUsize; mod dependency_graph; +/// Counts cancellation requests issued while a database revision is active. +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)] +pub struct CancellationCount(usize); + +/// Shared cancellation counter for every local handle of a database. +#[derive(Debug, Default)] +pub(crate) struct AtomicCancellationCount(AtomicUsize); + +impl AtomicCancellationCount { + pub(crate) fn load(&self) -> CancellationCount { + CancellationCount(self.0.load(Ordering::Acquire)) + } + + pub(crate) fn increment(&self) { + self.0.fetch_add(1, Ordering::AcqRel); + } +} + #[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] pub struct Runtime { /// Set to true when the current revision has been cancelled. @@ -19,6 +39,10 @@ pub struct Runtime { #[cfg_attr(feature = "persistence", serde(skip))] revision_cancelled: AtomicBool, + /// Increments whenever a query run is cancelled in the current revision. + #[cfg_attr(feature = "persistence", serde(skip))] + cancellation_count: Arc, + /// Stores the "last change" revision for values of each duration. /// This vector is always of length at least 1 (for Durability 0) /// but its total length depends on the number of durations. The @@ -193,6 +217,7 @@ impl Default for Runtime { Runtime { revisions: [Revision::start(); Durability::LEN], revision_cancelled: Default::default(), + cancellation_count: Default::default(), dependency_graph: Default::default(), table: Default::default(), } @@ -204,6 +229,7 @@ impl std::fmt::Debug for Runtime { fmt.debug_struct("Runtime") .field("revisions", &self.revisions) .field("revision_cancelled", &self.revision_cancelled) + .field("cancellation_count", &self.cancellation_count) .field("dependency_graph", &self.dependency_graph) .finish() } @@ -239,8 +265,17 @@ impl Runtime { self.revision_cancelled.load(Ordering::Acquire) } + pub(crate) fn cancellation_count(&self) -> CancellationCount { + self.cancellation_count.load() + } + + pub(crate) fn cancellation_counter(&self) -> Arc { + Arc::clone(&self.cancellation_count) + } + pub(crate) fn set_cancellation_flag(&self) { crate::tracing::trace!("set_cancellation_flag"); + self.cancellation_count.increment(); self.revision_cancelled.store(true, Ordering::Release); } diff --git a/src/storage.rs b/src/storage.rs index c2a7029ee..32d003122 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -22,6 +22,12 @@ pub struct StorageHandle { phantom: PhantomData Db>, } +impl StorageHandle { + fn new_zalsa_local(&self) -> ZalsaLocal { + ZalsaLocal::new(self.zalsa_impl.runtime().cancellation_counter()) + } +} + impl Clone for StorageHandle { fn clone(&self) -> Self { *self.coordinate.clones.lock() += 1; @@ -60,9 +66,10 @@ impl StorageHandle { } pub fn into_storage(self) -> Storage { + let zalsa_local = self.new_zalsa_local(); Storage { handle: self, - zalsa_local: ZalsaLocal::new(), + zalsa_local, } } } @@ -116,9 +123,11 @@ impl Storage { /// /// The `event_callback` function is invoked by the salsa runtime at various points during execution. pub fn new(event_callback: Option>) -> Self { + let handle = StorageHandle::new(event_callback); + let zalsa_local = handle.new_zalsa_local(); Self { - handle: StorageHandle::new(event_callback), - zalsa_local: ZalsaLocal::new(), + handle, + zalsa_local, } } @@ -224,9 +233,11 @@ impl StorageBuilder { /// Construct the [`Storage`] using the provided builder options. pub fn build(self) -> Storage { + let handle = StorageHandle::with_jars(self.event_callback, self.jars); + let zalsa_local = handle.new_zalsa_local(); Storage { - handle: StorageHandle::with_jars(self.event_callback, self.jars), - zalsa_local: ZalsaLocal::new(), + handle, + zalsa_local, } } } @@ -252,7 +263,7 @@ impl Clone for Storage { fn clone(&self) -> Self { Self { handle: self.handle.clone(), - zalsa_local: ZalsaLocal::new(), + zalsa_local: self.handle.new_zalsa_local(), } } } diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 466df5881..122064f04 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -18,7 +18,7 @@ use crate::active_query::{CompletedQuery, QueryStack}; use crate::cycle::{AtomicIterationCount, CycleHeads, IterationCount, empty_cycle_heads}; use crate::durability::Durability; use crate::key::DatabaseKeyIndex; -use crate::runtime::Stamp; +use crate::runtime::{AtomicCancellationCount, CancellationCount, Stamp}; use crate::sync::atomic::AtomicBool; use crate::table::{PageIndex, Slot, Table}; use crate::tracked_struct::{Disambiguator, Identity, IdentityHash}; @@ -47,7 +47,10 @@ pub struct ZalsaLocal { /// A cancellation token that can be used to cancel a query computation for a specific local `Database`. #[derive(Default, Clone, Debug)] -pub struct CancellationToken(Arc); +pub struct CancellationToken { + state: Arc, + cancellation_count: Arc, +} impl CancellationToken { const CANCELLED_MASK: u8 = 0b01; @@ -55,39 +58,52 @@ impl CancellationToken { /// Inform the database to cancel the current query computation. pub fn cancel(&self) { - self.0.fetch_or(Self::CANCELLED_MASK, Ordering::Relaxed); + self.cancellation_count.increment(); + self.state.fetch_or(Self::CANCELLED_MASK, Ordering::Relaxed); } /// Check if the query computation has been requested to be cancelled. pub fn is_cancelled(&self) -> bool { - self.0.load(Ordering::Relaxed) & Self::CANCELLED_MASK != 0 + self.state.load(Ordering::Relaxed) & Self::CANCELLED_MASK != 0 } #[inline] fn set_cancellation_disabled(&self, disabled: bool) -> bool { let previous_disabled_bit = if disabled { - self.0.fetch_or(Self::DISABLED_MASK, Ordering::Relaxed) + self.state.fetch_or(Self::DISABLED_MASK, Ordering::Relaxed) } else { - self.0.fetch_and(!Self::DISABLED_MASK, Ordering::Relaxed) + self.state + .fetch_and(!Self::DISABLED_MASK, Ordering::Relaxed) }; previous_disabled_bit & Self::DISABLED_MASK != 0 } fn should_trigger_local_cancellation(&self) -> bool { - self.0.load(Ordering::Relaxed) == Self::CANCELLED_MASK + self.state.load(Ordering::Relaxed) == Self::CANCELLED_MASK } fn reset(&self) { - self.0.store(0, Ordering::Relaxed); + self.state.store(0, Ordering::Relaxed); + } + + fn new(cancellation_count: Arc) -> Self { + Self { + state: Default::default(), + cancellation_count, + } + } + + fn cancellation_count(&self) -> CancellationCount { + self.cancellation_count.load() } } impl ZalsaLocal { - pub(crate) fn new() -> Self { + pub(crate) fn new(cancellation_count: Arc) -> Self { ZalsaLocal { query_stack: RefCell::new(QueryStack::default()), most_recent_pages: UnsafeCell::new(FxHashMap::default()), - cancelled: CancellationToken::default(), + cancelled: CancellationToken::new(cancellation_count), } } @@ -184,7 +200,11 @@ impl ZalsaLocal { // SAFETY: We do not access the query stack reentrantly. unsafe { self.with_query_stack_unchecked_mut(|stack| { - stack.push_new_query(database_key_index, iteration_count); + stack.push_new_query( + database_key_index, + iteration_count, + self.cancellation_count(), + ); ActiveQueryGuard { local_state: self, @@ -253,6 +273,21 @@ impl ZalsaLocal { } } + pub(crate) fn cancellation_count(&self) -> CancellationCount { + self.cancelled.cancellation_count() + } + + pub(crate) fn active_query_cancellation_count(&self) -> Option { + // SAFETY: We do not access the query stack reentrantly. + unsafe { + self.with_query_stack_unchecked(|stack| { + stack + .last() + .map(|active_query| active_query.cancellation_count) + }) + } + } + /// Add an output to the current query's list of dependencies /// /// Returns `Err` if not in a query. @@ -496,6 +531,10 @@ pub(crate) struct QueryRevisions { #[cfg_attr(feature = "persistence", serde(with = "persistence::verified_final"))] pub(super) verified_final: AtomicBool, + /// Cancellation count captured when this query run started. + #[cfg_attr(feature = "persistence", serde(skip))] + pub(super) cancellation_count: CancellationCount, + /// Lazily allocated state. pub(super) extra: QueryRevisionsExtra, } @@ -507,6 +546,7 @@ impl QueryRevisions { changed_at: _, durability: _, verified_final: _, + cancellation_count: _, origin, extra, #[cfg(feature = "accumulator")] @@ -681,7 +721,7 @@ impl fmt::Debug for QueryRevisionsExtraInner { #[cfg(not(feature = "shuttle"))] #[cfg(target_pointer_width = "64")] -const _: [(); std::mem::size_of::()] = [(); std::mem::size_of::<[usize; 4]>()]; +const _: [(); std::mem::size_of::()] = [(); std::mem::size_of::<[usize; 5]>()]; #[cfg(not(feature = "shuttle"))] #[cfg(target_pointer_width = "64")] @@ -689,7 +729,11 @@ const _: [(); std::mem::size_of::()] = [(); std::mem::size_of::<[usize; if cfg!(feature = "accumulator") { 7 } else { 3 }]>()]; impl QueryRevisions { - pub(crate) fn fixpoint_initial(query: DatabaseKeyIndex, iteration: IterationCount) -> Self { + pub(crate) fn fixpoint_initial( + query: DatabaseKeyIndex, + iteration: IterationCount, + cancellation_count: CancellationCount, + ) -> Self { Self { changed_at: Revision::start(), durability: Durability::MAX, @@ -697,6 +741,7 @@ impl QueryRevisions { #[cfg(feature = "accumulator")] accumulated_inputs: Default::default(), verified_final: AtomicBool::new(false), + cancellation_count, extra: QueryRevisionsExtra::new( #[cfg(feature = "accumulator")] AccumulatedMap::default(), @@ -1250,6 +1295,17 @@ pub(crate) struct ActiveQueryGuard<'me> { } impl ActiveQueryGuard<'_> { + pub(crate) fn cancellation_count(&self) -> CancellationCount { + // SAFETY: We do not access the query stack reentrantly. + unsafe { + self.local_state.with_query_stack_unchecked(|stack| { + #[cfg(debug_assertions)] + assert_eq!(stack.len(), self.push_len, "mismatched push and pop"); + stack.last().unwrap().cancellation_count + }) + } + } + /// Initialize the tracked struct ids with the values from the prior execution. pub(crate) fn seed_tracked_struct_ids(&self, tracked_struct_ids: &[(Identity, Id)]) { // SAFETY: We do not access the query stack reentrantly. @@ -1361,6 +1417,7 @@ pub(crate) mod persistence { changed_at, durability, ref verified_final, + cancellation_count: _, ref extra, #[cfg(feature = "accumulator")] accumulated_inputs: _, // TODO: Support serializing accumulators diff --git a/tests/parallel/cancellation_in_fixpoint.rs b/tests/parallel/cancellation_in_fixpoint.rs new file mode 100644 index 000000000..a61daab0d --- /dev/null +++ b/tests/parallel/cancellation_in_fixpoint.rs @@ -0,0 +1,48 @@ +// Shuttle doesn't like panics inside of its runtime. +#![cfg(not(feature = "shuttle"))] + +use salsa::{Cancelled, Database}; + +use crate::setup::{Knobs, KnobsDatabase}; +use crate::sync::thread; + +#[salsa::tracked(cycle_initial=initial)] +fn query_a(db: &dyn KnobsDatabase) -> u32 { + let value = query_b(db); + + db.signal(1); + db.wait_for(2); + cancellation_point(db); + + value +} + +#[salsa::tracked(cycle_initial=initial)] +fn query_b(db: &dyn KnobsDatabase) -> u32 { + query_a(db) +} + +#[salsa::tracked] +fn cancellation_point(_db: &dyn KnobsDatabase) {} + +fn initial(_db: &dyn KnobsDatabase, _id: salsa::Id) -> u32 { + 0 +} + +#[test] +fn cancellation_rejects_provisional_fixpoint_state() { + let mut db = Knobs::default(); + let db_worker = db.clone(); + + db.signal_on_did_cancel(2); + + let worker = thread::spawn(move || Cancelled::catch(|| query_a(&db_worker))); + + db.wait_for(1); + db.trigger_cancellation(); + + let cancelled = worker.join().unwrap(); + assert!(matches!(cancelled, Err(Cancelled::PendingWrite))); + + assert_eq!(query_a(&db), 0); +} diff --git a/tests/parallel/main.rs b/tests/parallel/main.rs index 399eaa7da..54e5fbb5f 100644 --- a/tests/parallel/main.rs +++ b/tests/parallel/main.rs @@ -3,6 +3,7 @@ mod setup; mod signal; +mod cancellation_in_fixpoint; mod cancellation_token_cycle_nested; mod cancellation_token_multi_blocked; mod cancellation_token_recomputes;