Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions src/active_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::accumulator::{
};
use crate::hash::FxIndexSet;
use crate::key::DatabaseKeyIndex;
use crate::runtime::Stamp;
use crate::runtime::{CancellationCount, Stamp};
use crate::sync::atomic::AtomicBool;
use crate::tracked_struct::{Disambiguator, DisambiguatorMap, IdentityHash, IdentityMap};
use crate::zalsa_local::{
Expand Down Expand Up @@ -70,6 +70,9 @@ pub(crate) struct ActiveQuery {

/// If this query is a cycle head, iteration count of that cycle.
iteration_count: IterationCount,

/// Cancellation count captured when this query run started.
pub(crate) cancellation_count: CancellationCount,
}

impl ActiveQuery {
Expand Down Expand Up @@ -179,7 +182,11 @@ impl ActiveQuery {
}

impl ActiveQuery {
fn new(database_key_index: DatabaseKeyIndex, iteration_count: IterationCount) -> Self {
fn new(
database_key_index: DatabaseKeyIndex,
iteration_count: IterationCount,
cancellation_count: CancellationCount,
) -> Self {
ActiveQuery {
database_key_index,
durability: Durability::MAX,
Expand All @@ -190,6 +197,7 @@ impl ActiveQuery {
tracked_struct_ids: Default::default(),
cycle_heads: Default::default(),
iteration_count,
cancellation_count,
#[cfg(feature = "accumulator")]
accumulated: Default::default(),
#[cfg(feature = "accumulator")]
Expand All @@ -208,6 +216,7 @@ impl ActiveQuery {
ref mut tracked_struct_ids,
ref mut cycle_heads,
iteration_count,
cancellation_count,
#[cfg(feature = "accumulator")]
ref mut accumulated,
#[cfg(feature = "accumulator")]
Expand Down Expand Up @@ -241,6 +250,7 @@ impl ActiveQuery {
#[cfg(feature = "accumulator")]
accumulated_inputs,
verified_final: AtomicBool::new(verified_final),
cancellation_count,
extra,
};

Expand All @@ -261,6 +271,7 @@ impl ActiveQuery {
tracked_struct_ids,
cycle_heads,
iteration_count,
cancellation_count: _,
#[cfg(feature = "accumulator")]
accumulated,
#[cfg(feature = "accumulator")]
Expand All @@ -279,6 +290,7 @@ impl ActiveQuery {
&mut self,
new_database_key_index: DatabaseKeyIndex,
new_iteration_count: IterationCount,
new_cancellation_count: CancellationCount,
) {
let Self {
database_key_index,
Expand All @@ -290,6 +302,7 @@ impl ActiveQuery {
tracked_struct_ids,
cycle_heads,
iteration_count,
cancellation_count,
#[cfg(feature = "accumulator")]
accumulated,
#[cfg(feature = "accumulator")]
Expand All @@ -300,6 +313,7 @@ impl ActiveQuery {
*changed_at = Revision::start();
*untracked_read = false;
*iteration_count = new_iteration_count;
*cancellation_count = new_cancellation_count;
debug_assert!(
input_outputs.is_empty(),
"`ActiveQuery::clear` or `ActiveQuery::into_revisions` should've been called"
Expand Down Expand Up @@ -369,12 +383,16 @@ impl QueryStack {
&mut self,
database_key_index: DatabaseKeyIndex,
iteration_count: IterationCount,
cancellation_count: CancellationCount,
) {
if self.len < self.stack.len() {
self.stack[self.len].reset_for(database_key_index, iteration_count);
self.stack[self.len].reset_for(database_key_index, iteration_count, cancellation_count);
} else {
self.stack
.push(ActiveQuery::new(database_key_index, iteration_count));
self.stack.push(ActiveQuery::new(
database_key_index,
iteration_count,
cancellation_count,
));
}
self.len += 1;
}
Expand Down
3 changes: 3 additions & 0 deletions src/cycle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ use std::iter::FusedIterator;
use thin_vec::{ThinVec, thin_vec};

use crate::key::DatabaseKeyIndex;
use crate::runtime::CancellationCount;
use crate::sync::OnceLock;
use crate::sync::atomic::{AtomicBool, AtomicU8, Ordering};
use crate::{Id, Revision};
Expand Down Expand Up @@ -499,11 +500,13 @@ pub enum ProvisionalStatus<'db> {
Provisional {
iteration: IterationCount,
verified_at: Revision,
cancellation_count: CancellationCount,
cycle_heads: &'db CycleHeads,
},
Final {
iteration: IterationCount,
verified_at: Revision,
cancellation_count: CancellationCount,
},
}

Expand Down
2 changes: 2 additions & 0 deletions src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -388,11 +388,13 @@ where
ProvisionalStatus::Final {
iteration,
verified_at: memo.verified_at.load(),
cancellation_count: memo.revisions.cancellation_count,
}
} else {
ProvisionalStatus::Provisional {
iteration,
verified_at: memo.verified_at.load(),
cancellation_count: memo.revisions.cancellation_count,
cycle_heads: memo.cycle_heads(),
}
})
Expand Down
19 changes: 16 additions & 3 deletions src/function/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ where

let database_key_index = claim_guard.database_key_index();
let zalsa = claim_guard.zalsa();
let cancellation_count = zalsa.runtime().cancellation_count();

let id = database_key_index.key_index();

Expand All @@ -136,7 +137,9 @@ where
let mut iteration_count = IterationCount::initial();

if let Some(old_memo) = opt_old_memo {
if old_memo.verified_at.load() == zalsa.current_revision() {
if old_memo.verified_at.load() == zalsa.current_revision()
&& old_memo.revisions.cancellation_count == cancellation_count
{
// The `DependencyGraph` locking propagates panics when another thread is blocked on a panicking query.
// However, the locking doesn't handle the case where a thread fetches the result of a panicking
// cycle head query **after** all locks were released. That's what we do here.
Expand All @@ -162,8 +165,13 @@ where
}
}

let _poison_guard =
PoisonProvisionalIfPanicking::new(self, zalsa, id, memo_ingredient_index);
let _poison_guard = PoisonProvisionalIfPanicking::new(
self,
zalsa,
id,
memo_ingredient_index,
cancellation_count,
);

let (new_value, completed_query) = loop {
let active_query = claim_guard
Expand Down Expand Up @@ -386,6 +394,7 @@ where
// * ensure the final returned memo depends on all inputs from all iterations.
if old_memo.may_be_provisional()
&& old_memo.verified_at.load() == zalsa.current_revision()
&& old_memo.revisions.cancellation_count == active_query.cancellation_count()
{
active_query.seed_iteration(&old_memo.revisions);
}
Expand Down Expand Up @@ -423,6 +432,7 @@ struct PoisonProvisionalIfPanicking<'a, C: Configuration> {
zalsa: &'a Zalsa,
id: Id,
memo_ingredient_index: MemoIngredientIndex,
cancellation_count: crate::runtime::CancellationCount,
}

impl<'a, C: Configuration> PoisonProvisionalIfPanicking<'a, C> {
Expand All @@ -431,12 +441,14 @@ impl<'a, C: Configuration> PoisonProvisionalIfPanicking<'a, C> {
zalsa: &'a Zalsa,
id: Id,
memo_ingredient_index: MemoIngredientIndex,
cancellation_count: crate::runtime::CancellationCount,
) -> Self {
Self {
ingredient,
zalsa,
id,
memo_ingredient_index,
cancellation_count,
}
}
}
Expand All @@ -447,6 +459,7 @@ impl<C: Configuration> Drop for PoisonProvisionalIfPanicking<'_, C> {
let revisions = QueryRevisions::fixpoint_initial(
self.ingredient.database_key_index(self.id),
IterationCount::initial(),
self.cancellation_count,
);

let memo = Memo::new(None, self.zalsa.current_revision(), revisions);
Expand Down
12 changes: 11 additions & 1 deletion src/function/fetch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ where
})
},
CycleRecoveryStrategy::Fixpoint | CycleRecoveryStrategy::FallbackImmediate => {
let cancellation_count = zalsa_local
.active_query_cancellation_count()
.unwrap_or_else(|| zalsa_local.cancellation_count());

// check if there's a provisional value for this query
// Note we don't `validate_may_be_provisional` the memo here as we want to reuse an
// existing provisional memo if it exists
Expand All @@ -193,6 +197,7 @@ where
// on the value OR a concurrent `Vec` for cycle heads.
if memo.verified_at.load() == zalsa.current_revision()
&& memo.value.is_some()
&& memo.revisions.cancellation_count == cancellation_count
&& memo.revisions.cycle_heads().contains(&database_key_index)
{
memo.revisions
Expand All @@ -219,14 +224,19 @@ where
.and_then(|old_memo| {
if old_memo.verified_at.load() == zalsa.current_revision()
&& old_memo.value.is_some()
&& old_memo.revisions.cancellation_count == cancellation_count
{
Some(old_memo.revisions.iteration())
} else {
None
}
})
.unwrap_or(IterationCount::initial());
let revisions = QueryRevisions::fixpoint_initial(database_key_index, iteration);
let revisions = QueryRevisions::fixpoint_initial(
database_key_index,
iteration,
cancellation_count,
);

let initial_value = C::cycle_initial(db, id, C::id_to_input(zalsa, id));
self.insert_memo(
Expand Down
27 changes: 21 additions & 6 deletions src/function/maybe_changed_after.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::function::{Configuration, IngredientImpl, Reentrancy};
use std::sync::atomic::Ordering;

use crate::key::DatabaseKeyIndex;
use crate::runtime::CancellationCount;
use crate::zalsa::{MemoIngredientIndex, Zalsa, ZalsaDatabase};
use crate::zalsa_local::{QueryEdge, QueryEdgeKind, QueryOriginRef, QueryRevisions, ZalsaLocal};
use crate::{Id, Revision};
Expand Down Expand Up @@ -308,18 +309,21 @@ where
);

let verified_at = memo.verified_at.load();
let cancellation_count = memo.revisions.cancellation_count;

validate_provisional(
zalsa,
database_key_index,
&memo.revisions,
verified_at,
cancellation_count,
cycle_heads,
) || validate_same_iteration(
zalsa,
zalsa_local,
database_key_index,
verified_at,
cancellation_count,
cycle_heads,
)
}
Expand Down Expand Up @@ -519,6 +523,7 @@ fn validate_provisional(
database_key_index: DatabaseKeyIndex,
memo_revisions: &QueryRevisions,
memo_verified_at: Revision,
memo_cancellation_count: CancellationCount,
cycle_heads: &CycleHeads,
) -> bool {
crate::tracing::trace!("{database_key_index:?}: validate_provisional({database_key_index:?})",);
Expand All @@ -537,10 +542,12 @@ fn validate_provisional(
ProvisionalStatus::Final {
iteration,
verified_at,
cancellation_count,
..
} => {
// Only consider the cycle head if it is from the same revision as the memo
if verified_at != memo_verified_at {
if verified_at != memo_verified_at || cancellation_count != memo_cancellation_count
{
return false;
}

Expand Down Expand Up @@ -571,6 +578,7 @@ fn validate_same_iteration(
zalsa_local: &ZalsaLocal,
memo_database_key_index: DatabaseKeyIndex,
memo_verified_at: Revision,
memo_cancellation_count: CancellationCount,
cycle_heads: &CycleHeads,
) -> bool {
crate::tracing::trace!("validate_same_iteration({memo_database_key_index:?})",);
Expand All @@ -582,6 +590,10 @@ fn validate_same_iteration(
return false;
}

if zalsa_local.active_query_cancellation_count() != Some(memo_cancellation_count) {
return false;
}

// Always return `false` for cycle initial values "unless" they are running in the same thread.
if cycle_heads
.iter_not_eq(memo_database_key_index)
Expand All @@ -591,10 +603,10 @@ fn validate_same_iteration(
// SAFETY: We do not access the query stack reentrantly.
let on_stack = unsafe {
zalsa_local.with_query_stack_unchecked(|stack| {
stack
.iter()
.rev()
.any(|query| query.database_key_index == memo_database_key_index)
stack.iter().rev().any(|query| {
query.database_key_index == memo_database_key_index
&& query.cancellation_count == memo_cancellation_count
})
})
};

Expand All @@ -609,8 +621,11 @@ fn validate_same_iteration(
head_iteration_count,
memo_iteration_count: current_iteration_count,
verified_at: head_verified_at,
cancellation_count: head_cancellation_count,
} => {
if head_verified_at != memo_verified_at {
if head_verified_at != memo_verified_at
|| head_cancellation_count != memo_cancellation_count
{
return false;
}

Expand Down
Loading
Loading