Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions src/active_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ pub(crate) struct ActiveQuery {
/// Provisional cycle results that this query depends on.
cycle_heads: CycleHeads,

/// Provisional memos that this query depends on.
provisional_memos: FxIndexSet<DatabaseKeyIndex>,

/// If this query is a cycle head, iteration count of that cycle.
iteration_count: IterationCount,
}
Expand All @@ -80,6 +83,7 @@ impl ActiveQuery {
edges: &[QueryEdge],
untracked_read: bool,
active_tracked_ids: &[(Identity, Id)],
provisional_memos: &[DatabaseKeyIndex],
) {
assert!(self.input_outputs.is_empty());

Expand All @@ -102,6 +106,8 @@ impl ActiveQuery {
.mark_all_active(active_tracked_ids.iter().copied());
self.disambiguator_map
.seed(active_tracked_ids.iter().map(|(id, _)| id));
self.provisional_memos
.extend(provisional_memos.iter().copied());
}

pub(super) fn take_cycle_heads(&mut self) -> CycleHeads {
Expand All @@ -114,13 +120,19 @@ impl ActiveQuery {
durability: Durability,
changed_at: Revision,
cycle_heads: &CycleHeads,
provisional_memos: &[DatabaseKeyIndex],
#[cfg(feature = "accumulator")] has_accumulated: bool,
#[cfg(feature = "accumulator")] accumulated_inputs: &AtomicInputAccumulatedValues,
) {
self.durability = self.durability.min(durability);
self.changed_at = self.changed_at.max(changed_at);
self.input_outputs.insert(QueryEdge::input(input));
self.cycle_heads.extend(cycle_heads);
if !cycle_heads.is_empty() {
self.provisional_memos.insert(input);
}
self.provisional_memos
.extend(provisional_memos.iter().copied());
#[cfg(feature = "accumulator")]
{
self.accumulated_inputs = self.accumulated_inputs.or_else(|| match has_accumulated {
Expand Down Expand Up @@ -157,6 +169,10 @@ impl ActiveQuery {
self.input_outputs.insert(QueryEdge::output(key));
}

pub(super) fn add_provisional_memo(&mut self, key: DatabaseKeyIndex) {
self.provisional_memos.insert(key);
}

/// True if the given key was output by this query.
pub(super) fn disambiguate(&mut self, key: IdentityHash) -> Disambiguator {
self.disambiguator_map.disambiguate(key)
Expand Down Expand Up @@ -189,6 +205,7 @@ impl ActiveQuery {
disambiguator_map: Default::default(),
tracked_struct_ids: Default::default(),
cycle_heads: Default::default(),
provisional_memos: FxIndexSet::default(),
iteration_count,
#[cfg(feature = "accumulator")]
accumulated: Default::default(),
Expand All @@ -207,6 +224,7 @@ impl ActiveQuery {
ref mut disambiguator_map,
ref mut tracked_struct_ids,
ref mut cycle_heads,
ref mut provisional_memos,
iteration_count,
#[cfg(feature = "accumulator")]
ref mut accumulated,
Expand All @@ -232,6 +250,7 @@ impl ActiveQuery {
active_tracked_structs,
mem::take(cycle_heads),
iteration_count,
provisional_memos.drain(..).collect(),
);

let revisions = QueryRevisions {
Expand All @@ -250,7 +269,7 @@ impl ActiveQuery {
}
}

fn clear(&mut self) {
fn clear(&mut self) -> FxIndexSet<DatabaseKeyIndex> {
let Self {
database_key_index: _,
durability: _,
Expand All @@ -260,6 +279,7 @@ impl ActiveQuery {
disambiguator_map,
tracked_struct_ids,
cycle_heads,
provisional_memos,
iteration_count,
#[cfg(feature = "accumulator")]
accumulated,
Expand All @@ -270,9 +290,12 @@ impl ActiveQuery {
disambiguator_map.clear();
tracked_struct_ids.clear();
*cycle_heads = Default::default();
let provisional_memos = mem::take(provisional_memos);
*iteration_count = IterationCount::initial();
#[cfg(feature = "accumulator")]
accumulated.clear();

provisional_memos
}

fn reset_for(
Expand All @@ -289,6 +312,7 @@ impl ActiveQuery {
disambiguator_map,
tracked_struct_ids,
cycle_heads,
provisional_memos,
iteration_count,
#[cfg(feature = "accumulator")]
accumulated,
Expand Down Expand Up @@ -316,6 +340,10 @@ impl ActiveQuery {
cycle_heads.is_empty(),
"`ActiveQuery::clear` or `ActiveQuery::into_revisions` should've been called"
);
debug_assert!(
provisional_memos.is_empty(),
"`ActiveQuery::clear` or `ActiveQuery::into_revisions` should've been called"
);
#[cfg(feature = "accumulator")]
{
*accumulated_inputs = Default::default();
Expand Down Expand Up @@ -400,7 +428,11 @@ impl QueryStack {
self.stack[self.len].top_into_revisions()
}

pub(crate) fn pop(&mut self, key: DatabaseKeyIndex, #[cfg(debug_assertions)] push_len: usize) {
pub(crate) fn pop(
&mut self,
key: DatabaseKeyIndex,
#[cfg(debug_assertions)] push_len: usize,
) -> FxIndexSet<DatabaseKeyIndex> {
#[cfg(debug_assertions)]
assert_eq!(push_len, self.len(), "unbalanced push/pop");
debug_assert_ne!(self.len, 0, "too many pops");
Expand Down
44 changes: 42 additions & 2 deletions src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use crate::table::Table;
use crate::table::memo::MemoTableTypes;
use crate::views::DatabaseDownCaster;
use crate::zalsa::{IngredientIndex, JarKind, MemoIngredientIndex, Zalsa};
use crate::zalsa_local::{QueryEdge, QueryOriginRef};
use crate::zalsa_local::{QueryEdge, QueryOriginRef, QueryRevisions, ZalsaLocal};
use crate::{Cycle, Id, Revision};

#[cfg(feature = "accumulator")]
Expand Down Expand Up @@ -265,10 +265,17 @@ where
fn insert_memo<'db>(
&'db self,
zalsa: &'db Zalsa,
zalsa_local: Option<&ZalsaLocal>,
id: Id,
mut memo: memo::Memo<'db, C>,
memo_ingredient_index: MemoIngredientIndex,
) -> &'db memo::Memo<'db, C> {
let may_be_provisional = memo.may_be_provisional();
if may_be_provisional {
memo.revisions
.add_provisional_memo(self.database_key_index(id));
}

if let Some(tracked_struct_ids) = memo.revisions.tracked_struct_ids_mut() {
tracked_struct_ids.shrink_to_fit();
}
Expand All @@ -289,7 +296,15 @@ where
unsafe { self.deleted_entries.push(old_value) };
}
// SAFETY: memo has been inserted into the table
unsafe { self.extend_memo_lifetime(memo.as_ref()) }
let memo = unsafe { self.extend_memo_lifetime(memo.as_ref()) };

if may_be_provisional {
if let Some(zalsa_local) = zalsa_local {
zalsa_local.record_provisional_memo(self.database_key_index(id));
}
}

memo
}

#[inline]
Expand Down Expand Up @@ -419,6 +434,31 @@ where
memo.revisions.verified_final.store(true, Ordering::Release);
}

fn invalidate_provisional_memo(&self, zalsa: &Zalsa, input: Id) {
let memo_ingredient_index = self.memo_ingredient_index(zalsa, input);
let Some(memo) = self.get_memo_from_table_for(zalsa, input, memo_ingredient_index) else {
return;
};

if !memo.may_be_provisional() || memo.value.is_none() {
return;
}

let database_key_index = self.database_key_index(input);
let revisions =
QueryRevisions::fixpoint_initial(database_key_index, IterationCount::initial());

let memo = memo::Memo::new(
None,
// Keep this distinct from the current revision so a retry treats
// the memo as stale instead of same-revision provisional state.
Revision::max(),
revisions,
);

self.insert_memo(zalsa, None, input, memo, memo_ingredient_index);
}

fn flatten_cycle_head_dependencies(
&self,
zalsa: &Zalsa,
Expand Down
Loading
Loading