diff --git a/src/input.rs b/src/input.rs index b957d98..a3d9532 100644 --- a/src/input.rs +++ b/src/input.rs @@ -1,3 +1,4 @@ +use crate::runtime::ChangedAt; use crate::runtime::QueryDescriptorSet; use crate::runtime::Revision; use crate::runtime::StampedValue; @@ -66,7 +67,7 @@ where Ok(StampedValue { value: ::default(), - changed_at: Revision::ZERO, + changed_at: ChangedAt::Revision(Revision::ZERO), }) } } @@ -109,10 +110,10 @@ where map_read .get(key) .map(|v| v.changed_at) - .unwrap_or(Revision::ZERO) + .unwrap_or(ChangedAt::Revision(Revision::ZERO)) }; - changed_at > revision + changed_at.changed_since(revision) } } @@ -131,7 +132,7 @@ where // racing with somebody else to modify this same cell. // (Otherwise, someone else might write a *newer* revision // into the same cell while we block on the lock.) - let changed_at = db.salsa_runtime().increment_revision(); + let changed_at = ChangedAt::Revision(db.salsa_runtime().increment_revision()); map_write.insert(key, StampedValue { value, changed_at }); } @@ -150,7 +151,7 @@ where // Unlike with `set`, here we use the **current revision** and // do not create a new one. - let changed_at = db.salsa_runtime().current_revision(); + let changed_at = ChangedAt::Revision(db.salsa_runtime().current_revision()); map_write.insert(key, StampedValue { value, changed_at }); } diff --git a/src/memoized.rs b/src/memoized.rs index fbcb15f..fb8339b 100644 --- a/src/memoized.rs +++ b/src/memoized.rs @@ -1,3 +1,4 @@ +use crate::runtime::ChangedAt; use crate::runtime::QueryDescriptorSet; use crate::runtime::Revision; use crate::runtime::StampedValue; @@ -91,7 +92,7 @@ where { /// Last time the value has actually changed. /// changed_at can be less than verified_at. - changed_at: Revision, + changed_at: ChangedAt, /// The result of the query, if we decide to memoize it. value: Option, @@ -184,24 +185,18 @@ where // first things first, let's walk over each of our previous // inputs and check whether they are out of date. if let Some(QueryState::Memoized(old_memo)) = &mut old_value { - if old_memo.value.is_some() { - if old_memo - .inputs - .iter() - .all(|old_input| !old_input.maybe_changed_since(db, old_memo.changed_at)) - { - debug!("{:?}({:?}): inputs still valid", Q::default(), key); - // If none of out inputs have changed since the last time we refreshed - // our value, then our value must still be good. We'll just patch - // the verified-at date and re-use it. - old_memo.verified_at = revision_now; - let value = old_memo.value.clone().unwrap(); - let changed_at = old_memo.changed_at; + if old_memo.validate_memoized_value(db) { + debug!("{:?}({:?}): inputs still valid", Q::default(), key); + // If none of out inputs have changed since the last time we refreshed + // our value, then our value must still be good. We'll just patch + // the verified-at date and re-use it. + old_memo.verified_at = revision_now; + let value = old_memo.value.clone().unwrap(); + let changed_at = old_memo.changed_at; - let mut map_write = self.map.write(); - self.overwrite_placeholder(&mut map_write, key, old_value.unwrap()); - return Ok(StampedValue { value, changed_at }); - } + let mut map_write = self.map.write(); + self.overwrite_placeholder(&mut map_write, key, old_value.unwrap()); + return Ok(StampedValue { value, changed_at }); } } @@ -318,14 +313,14 @@ where // If our memo is still up to date, then check if we've // changed since the revision. if memo.verified_at == revision_now { - return memo.changed_at > revision; + return memo.changed_at.changed_since(revision); } if memo.value.is_some() { // Otherwise, if we cache values, fall back to the full read to compute the result. drop(memo); drop(map_read); return match self.read(db, key, descriptor) { - Ok(v) => v.changed_at > revision, + Ok(v) => v.changed_at.changed_since(revision), Err(CycleDetected) => true, }; } @@ -370,7 +365,8 @@ where let mut map_write = self.map.write(); - let changed_at = db.salsa_runtime().current_revision(); + let current_revision = db.salsa_runtime().current_revision(); + let changed_at = ChangedAt::Revision(current_revision); map_write.insert( key, @@ -378,8 +374,28 @@ where value: Some(value), changed_at, inputs: QueryDescriptorSet::new(), - verified_at: changed_at, + verified_at: current_revision, }), ); } } + +impl Memo +where + Q: QueryFunction, + DB: Database, +{ + fn validate_memoized_value(&self, db: &DB) -> bool { + // If we don't have a memoized value, nothing to validate. + if !self.value.is_some() { + return false; + } + + match self.changed_at { + ChangedAt::Revision(revision) => self + .inputs + .iter() + .all(|old_input| !old_input.maybe_changed_since(db, revision)), + } + } +} diff --git a/src/runtime.rs b/src/runtime.rs index f92456e..4423872 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -136,7 +136,7 @@ where /// - `descriptor`: the query whose result was read /// - `changed_revision`: the last revision in which the result of that /// query had changed - crate fn report_query_read(&self, descriptor: &DB::QueryDescriptor, changed_at: Revision) { + crate fn report_query_read(&self, descriptor: &DB::QueryDescriptor, changed_at: ChangedAt) { if let Some(top_query) = self.local_state.borrow_mut().query_stack.last_mut() { top_query.add_read(descriptor, changed_at); } @@ -178,7 +178,7 @@ struct ActiveQuery { descriptor: DB::QueryDescriptor, /// Records the maximum revision where any subquery changed - changed_at: Revision, + changed_at: ChangedAt, /// Each subquery subqueries: QueryDescriptorSet, @@ -188,12 +188,12 @@ impl ActiveQuery { fn new(descriptor: DB::QueryDescriptor) -> Self { ActiveQuery { descriptor, - changed_at: Revision::ZERO, + changed_at: ChangedAt::Revision(Revision::ZERO), subqueries: QueryDescriptorSet::new(), } } - fn add_read(&mut self, subquery: &DB::QueryDescriptor, changed_at: Revision) { + fn add_read(&mut self, subquery: &DB::QueryDescriptor, changed_at: ChangedAt) { self.subqueries.insert(subquery.clone()); self.changed_at = self.changed_at.max(changed_at); } @@ -214,6 +214,25 @@ impl std::fmt::Debug for Revision { } } +/// Records when a stamped value changed. +/// +/// Note: the order of variants is significant. We sometimes use `max` +/// for example to find the "most recent revision" when something +/// changed. +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub enum ChangedAt { + Revision(Revision), +} + +impl ChangedAt { + /// True if this value has changed after `revision`. + pub fn changed_since(self, revision: Revision) -> bool { + match self { + ChangedAt::Revision(r) => r > revision, + } + } +} + /// An insertion-order-preserving set of queries. Used to track the /// inputs accessed during query execution. crate struct QueryDescriptorSet { @@ -249,5 +268,5 @@ impl QueryDescriptorSet { #[derive(Clone, Debug)] crate struct StampedValue { crate value: V, - crate changed_at: Revision, + crate changed_at: ChangedAt, } diff --git a/src/volatile.rs b/src/volatile.rs index 6fa6172..be70e38 100644 --- a/src/volatile.rs +++ b/src/volatile.rs @@ -1,3 +1,4 @@ +use crate::runtime::ChangedAt; use crate::runtime::Revision; use crate::runtime::StampedValue; use crate::CycleDetected; @@ -68,10 +69,9 @@ where let was_in_progress = self.in_progress.lock().remove(key); assert!(was_in_progress); - let revision_now = db.salsa_runtime().current_revision(); + let changed_at = ChangedAt::Revision(db.salsa_runtime().current_revision()); - db.salsa_runtime() - .report_query_read(descriptor, revision_now); + db.salsa_runtime().report_query_read(descriptor, changed_at); Ok(value) }