From a16cc21c73229a57e65106412634fb831c1df38c Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Thu, 21 May 2026 13:01:21 +0100 Subject: [PATCH] Prototype --- src/active_query.rs | 36 +++++- src/function.rs | 44 ++++++- src/function/execute.rs | 143 ++++++++++----------- src/function/fetch.rs | 8 ++ src/function/specify.rs | 2 +- src/ingredient.rs | 5 + src/zalsa_local.rs | 86 +++++++++++-- tests/parallel/cancellation_in_fixpoint.rs | 92 +++++++++++++ tests/parallel/main.rs | 1 + 9 files changed, 326 insertions(+), 91 deletions(-) create mode 100644 tests/parallel/cancellation_in_fixpoint.rs diff --git a/src/active_query.rs b/src/active_query.rs index 9eaf82160..90f973998 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -68,6 +68,9 @@ pub(crate) struct ActiveQuery { /// Provisional cycle results that this query depends on. cycle_heads: CycleHeads, + /// Provisional memos that this query depends on. + provisional_memos: FxIndexSet, + /// If this query is a cycle head, iteration count of that cycle. iteration_count: IterationCount, } @@ -80,6 +83,7 @@ impl ActiveQuery { edges: &[QueryEdge], untracked_read: bool, active_tracked_ids: &[(Identity, Id)], + provisional_memos: &[DatabaseKeyIndex], ) { assert!(self.input_outputs.is_empty()); @@ -102,6 +106,8 @@ impl ActiveQuery { .mark_all_active(active_tracked_ids.iter().copied()); self.disambiguator_map .seed(active_tracked_ids.iter().map(|(id, _)| id)); + self.provisional_memos + .extend(provisional_memos.iter().copied()); } pub(super) fn take_cycle_heads(&mut self) -> CycleHeads { @@ -114,6 +120,7 @@ impl ActiveQuery { durability: Durability, changed_at: Revision, cycle_heads: &CycleHeads, + provisional_memos: &[DatabaseKeyIndex], #[cfg(feature = "accumulator")] has_accumulated: bool, #[cfg(feature = "accumulator")] accumulated_inputs: &AtomicInputAccumulatedValues, ) { @@ -121,6 +128,11 @@ impl ActiveQuery { self.changed_at = self.changed_at.max(changed_at); self.input_outputs.insert(QueryEdge::input(input)); self.cycle_heads.extend(cycle_heads); + if !cycle_heads.is_empty() { + self.provisional_memos.insert(input); + } + self.provisional_memos + .extend(provisional_memos.iter().copied()); #[cfg(feature = "accumulator")] { self.accumulated_inputs = self.accumulated_inputs.or_else(|| match has_accumulated { @@ -157,6 +169,10 @@ impl ActiveQuery { self.input_outputs.insert(QueryEdge::output(key)); } + pub(super) fn add_provisional_memo(&mut self, key: DatabaseKeyIndex) { + self.provisional_memos.insert(key); + } + /// True if the given key was output by this query. pub(super) fn disambiguate(&mut self, key: IdentityHash) -> Disambiguator { self.disambiguator_map.disambiguate(key) @@ -189,6 +205,7 @@ impl ActiveQuery { disambiguator_map: Default::default(), tracked_struct_ids: Default::default(), cycle_heads: Default::default(), + provisional_memos: FxIndexSet::default(), iteration_count, #[cfg(feature = "accumulator")] accumulated: Default::default(), @@ -207,6 +224,7 @@ impl ActiveQuery { ref mut disambiguator_map, ref mut tracked_struct_ids, ref mut cycle_heads, + ref mut provisional_memos, iteration_count, #[cfg(feature = "accumulator")] ref mut accumulated, @@ -232,6 +250,7 @@ impl ActiveQuery { active_tracked_structs, mem::take(cycle_heads), iteration_count, + provisional_memos.drain(..).collect(), ); let revisions = QueryRevisions { @@ -250,7 +269,7 @@ impl ActiveQuery { } } - fn clear(&mut self) { + fn clear(&mut self) -> FxIndexSet { let Self { database_key_index: _, durability: _, @@ -260,6 +279,7 @@ impl ActiveQuery { disambiguator_map, tracked_struct_ids, cycle_heads, + provisional_memos, iteration_count, #[cfg(feature = "accumulator")] accumulated, @@ -270,9 +290,12 @@ impl ActiveQuery { disambiguator_map.clear(); tracked_struct_ids.clear(); *cycle_heads = Default::default(); + let provisional_memos = mem::take(provisional_memos); *iteration_count = IterationCount::initial(); #[cfg(feature = "accumulator")] accumulated.clear(); + + provisional_memos } fn reset_for( @@ -289,6 +312,7 @@ impl ActiveQuery { disambiguator_map, tracked_struct_ids, cycle_heads, + provisional_memos, iteration_count, #[cfg(feature = "accumulator")] accumulated, @@ -316,6 +340,10 @@ impl ActiveQuery { cycle_heads.is_empty(), "`ActiveQuery::clear` or `ActiveQuery::into_revisions` should've been called" ); + debug_assert!( + provisional_memos.is_empty(), + "`ActiveQuery::clear` or `ActiveQuery::into_revisions` should've been called" + ); #[cfg(feature = "accumulator")] { *accumulated_inputs = Default::default(); @@ -400,7 +428,11 @@ impl QueryStack { self.stack[self.len].top_into_revisions() } - pub(crate) fn pop(&mut self, key: DatabaseKeyIndex, #[cfg(debug_assertions)] push_len: usize) { + pub(crate) fn pop( + &mut self, + key: DatabaseKeyIndex, + #[cfg(debug_assertions)] push_len: usize, + ) -> FxIndexSet { #[cfg(debug_assertions)] assert_eq!(push_len, self.len(), "unbalanced push/pop"); debug_assert_ne!(self.len, 0, "too many pops"); diff --git a/src/function.rs b/src/function.rs index 2c04f1aaa..cb09a6e95 100644 --- a/src/function.rs +++ b/src/function.rs @@ -20,7 +20,7 @@ use crate::table::Table; use crate::table::memo::MemoTableTypes; use crate::views::DatabaseDownCaster; use crate::zalsa::{IngredientIndex, JarKind, MemoIngredientIndex, Zalsa}; -use crate::zalsa_local::{QueryEdge, QueryOriginRef}; +use crate::zalsa_local::{QueryEdge, QueryOriginRef, QueryRevisions, ZalsaLocal}; use crate::{Cycle, Id, Revision}; #[cfg(feature = "accumulator")] @@ -265,10 +265,17 @@ where fn insert_memo<'db>( &'db self, zalsa: &'db Zalsa, + zalsa_local: Option<&ZalsaLocal>, id: Id, mut memo: memo::Memo<'db, C>, memo_ingredient_index: MemoIngredientIndex, ) -> &'db memo::Memo<'db, C> { + let may_be_provisional = memo.may_be_provisional(); + if may_be_provisional { + memo.revisions + .add_provisional_memo(self.database_key_index(id)); + } + if let Some(tracked_struct_ids) = memo.revisions.tracked_struct_ids_mut() { tracked_struct_ids.shrink_to_fit(); } @@ -289,7 +296,15 @@ where unsafe { self.deleted_entries.push(old_value) }; } // SAFETY: memo has been inserted into the table - unsafe { self.extend_memo_lifetime(memo.as_ref()) } + let memo = unsafe { self.extend_memo_lifetime(memo.as_ref()) }; + + if may_be_provisional { + if let Some(zalsa_local) = zalsa_local { + zalsa_local.record_provisional_memo(self.database_key_index(id)); + } + } + + memo } #[inline] @@ -419,6 +434,31 @@ where memo.revisions.verified_final.store(true, Ordering::Release); } + fn invalidate_provisional_memo(&self, zalsa: &Zalsa, input: Id) { + let memo_ingredient_index = self.memo_ingredient_index(zalsa, input); + let Some(memo) = self.get_memo_from_table_for(zalsa, input, memo_ingredient_index) else { + return; + }; + + if !memo.may_be_provisional() || memo.value.is_none() { + return; + } + + let database_key_index = self.database_key_index(input); + let revisions = + QueryRevisions::fixpoint_initial(database_key_index, IterationCount::initial()); + + let memo = memo::Memo::new( + None, + // Keep this distinct from the current revision so a retry treats + // the memo as stale instead of same-revision provisional state. + Revision::max(), + revisions, + ); + + self.insert_memo(zalsa, None, input, memo, memo_ingredient_index); + } + fn flatten_cycle_head_dependencies( &self, zalsa: &Zalsa, diff --git a/src/function/execute.rs b/src/function/execute.rs index 1cd7fa79f..913929d77 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -8,11 +8,12 @@ use crate::function::{ClaimGuard, Configuration, IngredientImpl}; use crate::hash::{FxHashSet, FxIndexSet}; use crate::ingredient::WaitForResult; use crate::plumbing::ZalsaLocal; -use crate::sync::thread; use crate::tracked_struct::Identity; use crate::zalsa::{MemoIngredientIndex, Zalsa}; -use crate::zalsa_local::{ActiveQueryGuard, QueryEdge, QueryEdgeKind, QueryRevisions}; -use crate::{Cancelled, Cycle, tracing}; +use crate::zalsa_local::{ + ActiveQueryGuard, QueryEdge, QueryEdgeKind, QueryRevisions, invalidate_provisional_memos, +}; +use crate::{Cycle, tracing}; use crate::{DatabaseKeyIndex, Event, EventKind, Id}; impl IngredientImpl @@ -59,27 +60,25 @@ where let (new_value, active_query) = Self::execute_query( db, zalsa, - claim_guard - .zalsa_local() - .push_query(database_key_index, IterationCount::initial()), + claim_guard.zalsa_local().push_query( + zalsa, + database_key_index, + IterationCount::initial(), + ), opt_old_memo, ); (new_value, active_query.pop()) } CycleRecoveryStrategy::FallbackImmediate | CycleRecoveryStrategy::Fixpoint => { let zalsa_local = claim_guard.zalsa_local(); - let was_disabled = zalsa_local.set_cancellation_disabled(true); + let _cancellation_guard = CancellationDisabledGuard::new(zalsa_local); - let res = self.execute_maybe_iterate( + self.execute_maybe_iterate( db, opt_old_memo, &mut claim_guard, memo_ingredient_index, - ); - - zalsa_local.set_cancellation_disabled(was_disabled); - - res + ) } }; @@ -102,6 +101,7 @@ where let memo = self.insert_memo( zalsa, + Some(claim_guard.zalsa_local()), id, Memo::new( Some(new_value), @@ -111,7 +111,12 @@ where memo_ingredient_index, ); - if claim_guard.drop() { None } else { Some(memo) } + let mut provisional_guard = + ProvisionalMemoInvalidationGuard::new(zalsa, memo.revisions.provisional_memos()); + let refetch = claim_guard.drop(); + provisional_guard.defuse(); + + if refetch { None } else { Some(memo) } } fn execute_maybe_iterate<'db>( @@ -136,22 +141,7 @@ where let mut iteration_count = IterationCount::initial(); if let Some(old_memo) = opt_old_memo { - if old_memo.verified_at.load() == zalsa.current_revision() { - // The `DependencyGraph` locking propagates panics when another thread is blocked on a panicking query. - // However, the locking doesn't handle the case where a thread fetches the result of a panicking - // cycle head query **after** all locks were released. That's what we do here. - // We could consider re-executing the entire cycle but: - // a) It's tricky to ensure that all queries participating in the cycle will re-execute - // (we can't rely on `iteration_count` being updated for nested cycles because the nested cycles may have completed successfully). - // b) It's guaranteed that this query will panic again anyway. - // That's why we simply propagate the panic here. It simplifies our lives and it also avoids duplicate panic messages. - if old_memo.value.is_none() { - tracing::warn!( - "Propagating panic for cycle head that panicked in an earlier execution in that revision" - ); - Cancelled::PropagatedPanic.throw(); - } - + if old_memo.value.is_some() && old_memo.verified_at.load() == zalsa.current_revision() { // Only use the last provisional memo if it was a cycle head in the last iteration. This is to // force at least two executions. if old_memo.cycle_heads().contains(&database_key_index) { @@ -162,13 +152,11 @@ where } } - let _poison_guard = - PoisonProvisionalIfPanicking::new(self, zalsa, id, memo_ingredient_index); - let (new_value, completed_query) = loop { - let active_query = claim_guard - .zalsa_local() - .push_query(database_key_index, iteration_count); + let active_query = + claim_guard + .zalsa_local() + .push_query(zalsa, database_key_index, iteration_count); // Tracked struct ids that existed in the previous revision // but weren't recreated in the last iteration. It's important that we seed the next @@ -343,6 +331,7 @@ where let new_memo = self.insert_memo( zalsa, + Some(claim_guard.zalsa_local()), id, Memo::new( Some(new_value), @@ -384,7 +373,8 @@ where // * ensure that tracked struct created during the previous iteration // (and are owned by the query) are alive even if the query in this iteration no longer creates them. // * ensure the final returned memo depends on all inputs from all iterations. - if old_memo.may_be_provisional() + if old_memo.value.is_some() + && old_memo.may_be_provisional() && old_memo.verified_at.load() == zalsa.current_revision() { active_query.seed_iteration(&old_memo.revisions); @@ -402,56 +392,52 @@ where } } -/// Replaces any inserted memo with a fixpoint initial memo without a value if the current thread panics. -/// -/// A regular query doesn't insert any memo if it panics and the query -/// simply gets re-executed if any later called query depends on the panicked query (and will panic again unless the query isn't deterministic). -/// -/// Unfortunately, this isn't the case for cycle heads because Salsa first inserts the fixpoint initial memo and later inserts -/// provisional memos for every iteration. Detecting whether a query has previously panicked -/// in `fetch` (e.g., `validate_same_iteration`) and requires re-execution is probably possible but not very straightforward -/// and it's easy to get it wrong, which results in infinite loops where `Memo::provisional_retry` keeps retrying to get the latest `Memo` -/// but `fetch` doesn't re-execute the query for reasons. -/// -/// Specifically, a Memo can linger after a panic, which is then incorrectly returned -/// by `fetch_cold_cycle` because it passes the `shallow_verified_memo` check instead of inserting -/// a new fix point initial value if that happens. -/// -/// We could insert a fixpoint initial value here, but it seems unnecessary. -struct PoisonProvisionalIfPanicking<'a, C: Configuration> { - ingredient: &'a IngredientImpl, - zalsa: &'a Zalsa, - id: Id, - memo_ingredient_index: MemoIngredientIndex, +struct CancellationDisabledGuard<'db> { + zalsa_local: &'db ZalsaLocal, + was_disabled: bool, } -impl<'a, C: Configuration> PoisonProvisionalIfPanicking<'a, C> { - fn new( - ingredient: &'a IngredientImpl, - zalsa: &'a Zalsa, - id: Id, - memo_ingredient_index: MemoIngredientIndex, - ) -> Self { +impl<'db> CancellationDisabledGuard<'db> { + fn new(zalsa_local: &'db ZalsaLocal) -> Self { + let was_disabled = zalsa_local.set_cancellation_disabled(true); Self { - ingredient, - zalsa, - id, - memo_ingredient_index, + zalsa_local, + was_disabled, } } } -impl Drop for PoisonProvisionalIfPanicking<'_, C> { +impl Drop for CancellationDisabledGuard<'_> { fn drop(&mut self) { - if thread::panicking() { - let revisions = QueryRevisions::fixpoint_initial( - self.ingredient.database_key_index(self.id), - IterationCount::initial(), - ); + self.zalsa_local + .set_cancellation_disabled(self.was_disabled); + } +} - let memo = Memo::new(None, self.zalsa.current_revision(), revisions); - self.ingredient - .insert_memo(self.zalsa, self.id, memo, self.memo_ingredient_index); +struct ProvisionalMemoInvalidationGuard<'db> { + zalsa: &'db Zalsa, + memos: Vec, + defused: bool, +} + +impl<'db> ProvisionalMemoInvalidationGuard<'db> { + fn new(zalsa: &'db Zalsa, memos: &[DatabaseKeyIndex]) -> Self { + Self { + zalsa, + memos: memos.to_vec(), + defused: false, + } + } + + fn defuse(&mut self) { + self.defused = true; + } +} + +impl Drop for ProvisionalMemoInvalidationGuard<'_> { + fn drop(&mut self) { + if !self.defused { + invalidate_provisional_memos(self.zalsa, self.memos.drain(..)); } } } @@ -721,6 +707,7 @@ fn try_complete_cycle_head( } *completed_query.revisions.verified_final.get_mut() = true; + completed_query.revisions.clear_provisional_memos(); zalsa.event(&|| { Event::new(EventKind::DidFinalizeCycle { diff --git a/src/function/fetch.rs b/src/function/fetch.rs index ba70acf46..6f1190c9c 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -33,11 +33,18 @@ where self.eviction.record_use(id); + let provisional_memos = if memo.may_be_provisional() { + memo.revisions.provisional_memos() + } else { + &[] + }; + zalsa_local.report_tracked_read( database_key_index, memo.revisions.durability, memo.revisions.changed_at, memo.cycle_heads(), + provisional_memos, #[cfg(feature = "accumulator")] memo.revisions.accumulated().is_some(), #[cfg(feature = "accumulator")] @@ -231,6 +238,7 @@ where let initial_value = C::cycle_initial(db, id, C::id_to_input(zalsa, id)); self.insert_memo( zalsa, + Some(zalsa_local), id, Memo::new(Some(initial_value), zalsa.current_revision(), revisions), memo_ingredient_index, diff --git a/src/function/specify.rs b/src/function/specify.rs index 99539d375..ba77b12c2 100644 --- a/src/function/specify.rs +++ b/src/function/specify.rs @@ -99,7 +99,7 @@ where memo.tracing_debug(), key ); - self.insert_memo(zalsa, key, memo, memo_ingredient_index); + self.insert_memo(zalsa, Some(zalsa_local), key, memo, memo_ingredient_index); // Record that the current query *specified* a value for this cell. let database_key_index = self.database_key_index(key); diff --git a/src/ingredient.rs b/src/ingredient.rs index c792dc9d9..d2e7206d3 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -190,6 +190,11 @@ pub trait Ingredient: Any + fmt::Debug + Send + Sync { ); } + /// Invalidates a provisional memo left behind by an unwinding fixpoint query. + fn invalidate_provisional_memo(&self, _zalsa: &Zalsa, _input: Id) { + unreachable!("invalidate_provisional_memo should only be called on functions"); + } + /// Flattens the dependencies of a query with cycle handling that participates in a cycle. /// /// This query recursively walks the dependency graph of `id` and flattens input dependencies diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 466df5881..db94be098 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -176,11 +176,12 @@ impl ZalsaLocal { } #[inline] - pub(crate) fn push_query( - &self, + pub(crate) fn push_query<'db>( + &'db self, + zalsa: &'db Zalsa, database_key_index: DatabaseKeyIndex, iteration_count: IterationCount, - ) -> ActiveQueryGuard<'_> { + ) -> ActiveQueryGuard<'db> { // SAFETY: We do not access the query stack reentrantly. unsafe { self.with_query_stack_unchecked_mut(|stack| { @@ -188,6 +189,7 @@ impl ZalsaLocal { ActiveQueryGuard { local_state: self, + zalsa, database_key_index, #[cfg(debug_assertions)] push_len: stack.len(), @@ -307,6 +309,7 @@ impl ZalsaLocal { durability: Durability, changed_at: Revision, cycle_heads: &CycleHeads, + provisional_memos: &[DatabaseKeyIndex], #[cfg(feature = "accumulator")] has_accumulated: bool, #[cfg(feature = "accumulator")] accumulated_inputs: &AtomicInputAccumulatedValues, ) { @@ -326,6 +329,7 @@ impl ZalsaLocal { durability, changed_at, cycle_heads, + provisional_memos, #[cfg(feature = "accumulator")] has_accumulated, #[cfg(feature = "accumulator")] @@ -457,6 +461,27 @@ impl ZalsaLocal { pub(crate) fn set_cancellation_disabled(&self, was_disabled: bool) -> bool { self.cancelled.set_cancellation_disabled(was_disabled) } + + pub(crate) fn record_provisional_memo(&self, key: DatabaseKeyIndex) { + // SAFETY: We do not access the query stack reentrantly. + unsafe { + self.with_query_stack_unchecked_mut(|stack| { + if let Some(top_query) = stack.last_mut() { + top_query.add_provisional_memo(key); + } + }) + } + } +} + +pub(crate) fn invalidate_provisional_memos( + zalsa: &Zalsa, + memos: impl IntoIterator, +) { + for database_key_index in memos { + let ingredient = zalsa.lookup_ingredient(database_key_index.ingredient_index()); + ingredient.invalidate_provisional_memo(zalsa, database_key_index.key_index()); + } } // Okay to implement as `ZalsaLocal`` is !Sync @@ -546,6 +571,7 @@ impl QueryRevisionsExtra { mut tracked_struct_ids: ThinVec<(Identity, Id)>, cycle_heads: CycleHeads, iteration: IterationCount, + mut provisional_memos: ThinVec, ) -> Self { #[cfg(feature = "accumulator")] let acc = accumulated.is_empty(); @@ -555,16 +581,19 @@ impl QueryRevisionsExtra { && tracked_struct_ids.is_empty() && cycle_heads.is_empty() && iteration.is_initial() + && provisional_memos.is_empty() { None } else { tracked_struct_ids.shrink_to_fit(); + provisional_memos.shrink_to_fit(); Some(Box::new(QueryRevisionsExtraInner { #[cfg(feature = "accumulator")] accumulated, cycle_heads, tracked_struct_ids, + provisional_memos, iteration: iteration.into(), cycle_converged: false, })) @@ -603,6 +632,11 @@ struct QueryRevisionsExtraInner { // be created with new IDs anyways. tracked_struct_ids: ThinVec<(Identity, Id)>, + /// Provisional memos that contributed to this result and should be + /// invalidated if the surrounding fixpoint query unwinds. + #[cfg_attr(feature = "persistence", serde(skip))] + provisional_memos: ThinVec, + /// This result was computed based on provisional values from /// these cycle heads. The "cycle head" is the query responsible /// for managing a fixpoint iteration. In a cycle like @@ -628,6 +662,7 @@ impl QueryRevisionsExtraInner { #[cfg(feature = "accumulator")] accumulated, tracked_struct_ids, + provisional_memos, cycle_heads, iteration: _, cycle_converged: _, @@ -637,7 +672,9 @@ impl QueryRevisionsExtraInner { let b = accumulated.allocation_size(); #[cfg(not(feature = "accumulator"))] let b = 0; - b + cycle_heads.allocation_size() + std::mem::size_of_val(tracked_struct_ids.as_slice()) + b + cycle_heads.allocation_size() + + std::mem::size_of_val(tracked_struct_ids.as_slice()) + + std::mem::size_of_val(provisional_memos.as_slice()) } } @@ -686,7 +723,7 @@ const _: [(); std::mem::size_of::()] = [(); std::mem::size_of::< #[cfg(not(feature = "shuttle"))] #[cfg(target_pointer_width = "64")] const _: [(); std::mem::size_of::()] = - [(); std::mem::size_of::<[usize; if cfg!(feature = "accumulator") { 7 } else { 3 }]>()]; + [(); std::mem::size_of::<[usize; if cfg!(feature = "accumulator") { 8 } else { 4 }]>()]; impl QueryRevisions { pub(crate) fn fixpoint_initial(query: DatabaseKeyIndex, iteration: IterationCount) -> Self { @@ -703,6 +740,7 @@ impl QueryRevisions { ThinVec::default(), CycleHeads::initial(query, iteration), iteration, + std::iter::once(query).collect(), ), } } @@ -725,6 +763,25 @@ impl QueryRevisions { } } + pub(crate) fn provisional_memos(&self) -> &[DatabaseKeyIndex] { + self.extra() + .map(|extra| &*extra.provisional_memos) + .unwrap_or_default() + } + + pub(crate) fn add_provisional_memo(&mut self, key: DatabaseKeyIndex) { + let extra = self.get_or_insert_extra(); + if !extra.provisional_memos.contains(&key) { + extra.provisional_memos.push(key); + } + } + + pub(crate) fn clear_provisional_memos(&mut self) { + if let Some(extra) = &mut self.extra.0 { + extra.provisional_memos.clear(); + } + } + /// Sets the `CycleHeads` for this query. pub(crate) fn set_cycle_heads(&mut self, cycle_heads: CycleHeads) { match &mut self.extra.0 { @@ -736,6 +793,7 @@ impl QueryRevisions { ThinVec::default(), cycle_heads, IterationCount::default(), + ThinVec::default(), ); } }; @@ -784,6 +842,7 @@ impl QueryRevisions { #[cfg(feature = "accumulator")] accumulated: AccumulatedMap::default(), tracked_struct_ids: ThinVec::default(), + provisional_memos: ThinVec::default(), cycle_heads: empty_cycle_heads().clone(), iteration: IterationCount::default().into(), cycle_converged: false, @@ -1244,6 +1303,7 @@ pub(crate) fn output_edges( /// destructor will also remove the query. pub(crate) struct ActiveQueryGuard<'me> { local_state: &'me ZalsaLocal, + zalsa: &'me Zalsa, #[cfg(debug_assertions)] push_len: usize, pub(crate) database_key_index: DatabaseKeyIndex, @@ -1274,6 +1334,7 @@ impl ActiveQueryGuard<'_> { ); let tracked_ids = previous.tracked_struct_ids(); + let provisional_memos = previous.provisional_memos(); // SAFETY: We do not access the query stack reentrantly. unsafe { @@ -1281,7 +1342,14 @@ impl ActiveQueryGuard<'_> { #[cfg(debug_assertions)] assert_eq!(stack.len(), self.push_len, "mismatched push and pop"); let frame = stack.last_mut().unwrap(); - frame.seed_iteration(durability, changed_at, edges, untracked_read, tracked_ids); + frame.seed_iteration( + durability, + changed_at, + edges, + untracked_read, + tracked_ids, + provisional_memos, + ); }) } } @@ -1326,15 +1394,17 @@ impl ActiveQueryGuard<'_> { impl Drop for ActiveQueryGuard<'_> { fn drop(&mut self) { // SAFETY: We do not access the query stack reentrantly. - unsafe { + let provisional_memos = unsafe { self.local_state.with_query_stack_unchecked_mut(|stack| { stack.pop( self.database_key_index, #[cfg(debug_assertions)] self.push_len, - ); + ) }) }; + + invalidate_provisional_memos(self.zalsa, provisional_memos); } } diff --git a/tests/parallel/cancellation_in_fixpoint.rs b/tests/parallel/cancellation_in_fixpoint.rs new file mode 100644 index 000000000..88b409a53 --- /dev/null +++ b/tests/parallel/cancellation_in_fixpoint.rs @@ -0,0 +1,92 @@ +// Shuttle doesn't like panics inside of its runtime. +#![cfg(not(feature = "shuttle"))] + +use std::panic::catch_unwind; + +use salsa::{Cancelled, Database}; + +use crate::setup::{Knobs, KnobsDatabase}; +use crate::sync::thread; + +#[salsa::input] +struct Input { + value: u32, +} + +#[salsa::tracked(cycle_fn = cycle_fn, cycle_initial = cycle_initial)] +fn cycle_a(db: &dyn KnobsDatabase, input: Input) -> u32 { + let value = cycle_b(db, input); + + db.signal(1); + db.wait_for(2); + cancellation_point(db, input); + + value +} + +#[salsa::tracked(cycle_fn = cycle_fn, cycle_initial = cycle_initial)] +fn cycle_b(db: &dyn KnobsDatabase, input: Input) -> u32 { + let value = cycle_a(db, input); + value.saturating_add(1).min(input.value(db)) +} + +#[salsa::tracked] +fn cancellation_point(db: &dyn KnobsDatabase, input: Input) { + input.value(db); +} + +fn cycle_initial(_db: &dyn KnobsDatabase, _id: salsa::Id, _input: Input) -> u32 { + 0 +} + +fn cycle_fn( + _db: &dyn KnobsDatabase, + _cycle: &salsa::Cycle, + _last_provisional_value: &u32, + value: u32, + _input: Input, +) -> u32 { + value +} + +#[test] +fn pending_write_cancellation_invalidates_provisional_memos() { + let db = Knobs::default(); + let db_writer = db.clone(); + let db_t1 = db.clone(); + let db_waiter = db.clone(); + let input = Input::new(&db, 3); + + db.signal_on_did_cancel(2); + + drop(db); + + let t1 = thread::spawn(move || catch_unwind(|| cycle_a(&db_t1, input))); + + db_waiter.wait_for(1); + drop(db_waiter); + + let t2 = thread::spawn({ + let mut db_writer = db_writer; + move || { + db_writer.trigger_lru_eviction(); + db_writer + } + }); + + let result = t1.join().unwrap(); + let Err(payload) = result else { + panic!("expected the fixpoint query to be cancelled"); + }; + assert!( + payload + .downcast_ref::() + .is_some_and(|cancelled| matches!(cancelled, Cancelled::PendingWrite)), + "expected pending-write cancellation, got {payload:?}", + ); + + let db_after = t2.join().unwrap(); + + let result = catch_unwind(|| cycle_a(&db_after, input)); + assert_eq!(result.unwrap(), 3); +} diff --git a/tests/parallel/main.rs b/tests/parallel/main.rs index 399eaa7da..54e5fbb5f 100644 --- a/tests/parallel/main.rs +++ b/tests/parallel/main.rs @@ -3,6 +3,7 @@ mod setup; mod signal; +mod cancellation_in_fixpoint; mod cancellation_token_cycle_nested; mod cancellation_token_multi_blocked; mod cancellation_token_recomputes;