diff --git a/src/memoized.rs b/src/derived.rs similarity index 72% rename from src/memoized.rs rename to src/derived.rs index fbcb15fb..83090539 100644 --- a/src/memoized.rs +++ b/src/derived.rs @@ -1,3 +1,4 @@ +use crate::runtime::ChangedAt; use crate::runtime::QueryDescriptorSet; use crate::runtime::Revision; use crate::runtime::StampedValue; @@ -23,14 +24,21 @@ use std::marker::PhantomData; /// Memoized queries store the result plus a list of the other queries /// that they invoked. This means we can avoid recomputing them when /// none of those inputs have changed. -pub type MemoizedStorage = WeakMemoizedStorage; +pub type MemoizedStorage = DerivedStorage; /// "Dependency" queries just track their dependencies and not the /// actual value (which they produce on demand). This lessens the /// storage requirements. -pub type DependencyStorage = WeakMemoizedStorage; +pub type DependencyStorage = DerivedStorage; -pub struct WeakMemoizedStorage +/// "Dependency" queries just track their dependencies and not the +/// actual value (which they produce on demand). This lessens the +/// storage requirements. +pub type VolatileStorage = DerivedStorage; + +/// Handles storage where the value is 'derived' by executing a +/// function (in contrast to "inputs"). +pub struct DerivedStorage where Q: QueryFunction, DB: Database, @@ -46,6 +54,8 @@ where DB: Database, { fn should_memoize_value(key: &Q::Key) -> bool; + + fn should_track_inputs(key: &Q::Key) -> bool; } pub enum AlwaysMemoizeValue {} @@ -57,6 +67,10 @@ where fn should_memoize_value(_key: &Q::Key) -> bool { true } + + fn should_track_inputs(_key: &Q::Key) -> bool { + true + } } pub enum NeverMemoizeValue {} @@ -68,6 +82,30 @@ where fn should_memoize_value(_key: &Q::Key) -> bool { false } + + fn should_track_inputs(_key: &Q::Key) -> bool { + true + } +} + +pub enum VolatileValue {} +impl MemoizationPolicy for VolatileValue +where + Q: QueryFunction, + DB: Database, +{ + fn should_memoize_value(_key: &Q::Key) -> bool { + // Why memoize? Well, if the "volatile" value really is + // constantly changing, we still want to capture its value + // until the next revision is triggered and ensure it doesn't + // change -- otherwise the system gets into an inconsistent + // state where the same query reports back different values. + true + } + + fn should_track_inputs(_key: &Q::Key) -> bool { + false + } } /// Defines the "current state" of query's memoized results. @@ -91,11 +129,12 @@ where { /// Last time the value has actually changed. /// changed_at can be less than verified_at. - changed_at: Revision, + changed_at: ChangedAt, /// The result of the query, if we decide to memoize it. value: Option, + /// The inputs that went into our query, if we are tracking them. inputs: QueryDescriptorSet, /// Last time that we checked our inputs to see if they have @@ -106,21 +145,21 @@ where verified_at: Revision, } -impl Default for WeakMemoizedStorage +impl Default for DerivedStorage where Q: QueryFunction, DB: Database, MP: MemoizationPolicy, { fn default() -> Self { - WeakMemoizedStorage { + DerivedStorage { map: RwLock::new(FxHashMap::default()), policy: PhantomData, } } } -impl WeakMemoizedStorage +impl DerivedStorage where Q: QueryFunction, DB: Database, @@ -184,32 +223,32 @@ where // first things first, let's walk over each of our previous // inputs and check whether they are out of date. if let Some(QueryState::Memoized(old_memo)) = &mut old_value { - if old_memo.value.is_some() { - if old_memo - .inputs - .iter() - .all(|old_input| !old_input.maybe_changed_since(db, old_memo.changed_at)) - { - debug!("{:?}({:?}): inputs still valid", Q::default(), key); - // If none of out inputs have changed since the last time we refreshed - // our value, then our value must still be good. We'll just patch - // the verified-at date and re-use it. - old_memo.verified_at = revision_now; - let value = old_memo.value.clone().unwrap(); - let changed_at = old_memo.changed_at; + if let Some(value) = old_memo.verify_memoized_value(db) { + debug!("{:?}({:?}): inputs still valid", Q::default(), key); + // If none of out inputs have changed since the last time we refreshed + // our value, then our value must still be good. We'll just patch + // the verified-at date and re-use it. + old_memo.verified_at = revision_now; + let changed_at = old_memo.changed_at; - let mut map_write = self.map.write(); - self.overwrite_placeholder(&mut map_write, key, old_value.unwrap()); - return Ok(StampedValue { value, changed_at }); - } + let mut map_write = self.map.write(); + self.overwrite_placeholder(&mut map_write, key, old_value.unwrap()); + return Ok(StampedValue { value, changed_at }); } } // Query was not previously executed, or value is potentially // stale, or value is absent. Let's execute! - let (mut stamped_value, inputs) = db - .salsa_runtime() - .execute_query_implementation::(db, descriptor, key); + let runtime = db.salsa_runtime(); + let (mut stamped_value, inputs) = runtime.execute_query_implementation(descriptor, || { + debug!("{:?}({:?}): executing query", Q::default(), key); + + if !self.should_track_inputs(key) { + runtime.report_untracked_read(); + } + + Q::execute(db, key.clone()) + }); // We assume that query is side-effect free -- that is, does // not mutate the "inputs" to the query system. Sanity check @@ -272,9 +311,13 @@ where fn should_memoize_value(&self, key: &Q::Key) -> bool { MP::should_memoize_value(key) } + + fn should_track_inputs(&self, key: &Q::Key) -> bool { + MP::should_track_inputs(key) + } } -impl QueryStorageOps for WeakMemoizedStorage +impl QueryStorageOps for DerivedStorage where Q: QueryFunction, DB: Database, @@ -318,14 +361,14 @@ where // If our memo is still up to date, then check if we've // changed since the revision. if memo.verified_at == revision_now { - return memo.changed_at > revision; + return memo.changed_at.changed_since(revision); } if memo.value.is_some() { // Otherwise, if we cache values, fall back to the full read to compute the result. drop(memo); drop(map_read); return match self.read(db, key, descriptor) { - Ok(v) => v.changed_at > revision, + Ok(v) => v.changed_at.changed_since(revision), Err(CycleDetected) => true, }; } @@ -342,11 +385,7 @@ where _ => unreachable!(), }; - if memo - .inputs - .iter() - .all(|old_input| !old_input.maybe_changed_since(db, memo.verified_at)) - { + if memo.verify_inputs(db) { memo.verified_at = revision_now; self.overwrite_placeholder(&mut self.map.write(), key, QueryState::Memoized(memo)); return false; @@ -359,7 +398,7 @@ where } } -impl UncheckedMutQueryStorageOps for WeakMemoizedStorage +impl UncheckedMutQueryStorageOps for DerivedStorage where Q: QueryFunction, DB: Database, @@ -370,16 +409,47 @@ where let mut map_write = self.map.write(); - let changed_at = db.salsa_runtime().current_revision(); + let current_revision = db.salsa_runtime().current_revision(); + let changed_at = ChangedAt::Revision(current_revision); map_write.insert( key, QueryState::Memoized(Memo { value: Some(value), changed_at, - inputs: QueryDescriptorSet::new(), - verified_at: changed_at, + inputs: QueryDescriptorSet::default(), + verified_at: current_revision, }), ); } } + +impl Memo +where + Q: QueryFunction, + DB: Database, +{ + fn verify_memoized_value(&self, db: &DB) -> Option { + // If we don't have a memoized value, nothing to validate. + if let Some(v) = &self.value { + // If inputs are still valid. + if self.verify_inputs(db) { + return Some(v.clone()); + } + } + + None + } + + fn verify_inputs(&self, db: &DB) -> bool { + match self.changed_at { + ChangedAt::Revision(revision) => match &self.inputs { + QueryDescriptorSet::Tracked(inputs) => inputs + .iter() + .all(|old_input| !old_input.maybe_changed_since(db, revision)), + + QueryDescriptorSet::Untracked => false, + }, + } + } +} diff --git a/src/input.rs b/src/input.rs index b957d98a..a3d95323 100644 --- a/src/input.rs +++ b/src/input.rs @@ -1,3 +1,4 @@ +use crate::runtime::ChangedAt; use crate::runtime::QueryDescriptorSet; use crate::runtime::Revision; use crate::runtime::StampedValue; @@ -66,7 +67,7 @@ where Ok(StampedValue { value: ::default(), - changed_at: Revision::ZERO, + changed_at: ChangedAt::Revision(Revision::ZERO), }) } } @@ -109,10 +110,10 @@ where map_read .get(key) .map(|v| v.changed_at) - .unwrap_or(Revision::ZERO) + .unwrap_or(ChangedAt::Revision(Revision::ZERO)) }; - changed_at > revision + changed_at.changed_since(revision) } } @@ -131,7 +132,7 @@ where // racing with somebody else to modify this same cell. // (Otherwise, someone else might write a *newer* revision // into the same cell while we block on the lock.) - let changed_at = db.salsa_runtime().increment_revision(); + let changed_at = ChangedAt::Revision(db.salsa_runtime().increment_revision()); map_write.insert(key, StampedValue { value, changed_at }); } @@ -150,7 +151,7 @@ where // Unlike with `set`, here we use the **current revision** and // do not create a new one. - let changed_at = db.salsa_runtime().current_revision(); + let changed_at = ChangedAt::Revision(db.salsa_runtime().current_revision()); map_write.insert(key, StampedValue { value, changed_at }); } diff --git a/src/lib.rs b/src/lib.rs index dd3e2b3f..a5317cbf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,10 +16,9 @@ use std::fmt::Display; use std::fmt::Write; use std::hash::Hash; +pub mod derived; pub mod input; -pub mod memoized; pub mod runtime; -pub mod volatile; pub use crate::runtime::Runtime; @@ -402,19 +401,19 @@ macro_rules! query_group { ( @storage_ty[$DB:ident, $Self:ident, memoized] ) => { - $crate::memoized::MemoizedStorage<$DB, $Self> + $crate::derived::MemoizedStorage<$DB, $Self> }; ( @storage_ty[$DB:ident, $Self:ident, volatile] ) => { - $crate::volatile::VolatileStorage<$DB, $Self> + $crate::derived::VolatileStorage<$DB, $Self> }; ( @storage_ty[$DB:ident, $Self:ident, dependencies] ) => { - $crate::memoized::DependencyStorage<$DB, $Self> + $crate::derived::DependencyStorage<$DB, $Self> }; ( diff --git a/src/runtime.rs b/src/runtime.rs index f92456ea..604cec31 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -88,16 +88,12 @@ where result } - crate fn execute_query_implementation( + crate fn execute_query_implementation( &self, - db: &DB, descriptor: &DB::QueryDescriptor, - key: &Q::Key, - ) -> (StampedValue, QueryDescriptorSet) - where - Q: QueryFunction, - { - debug!("{:?}({:?}): executing query", Q::default(), key); + execute: impl FnOnce() -> V, + ) -> (StampedValue, QueryDescriptorSet) { + debug!("{:?}: execute_query_implementation invoked", descriptor); // Push the active query onto the stack. let push_len = { @@ -109,7 +105,7 @@ where }; // Execute user's code, accumulating inputs etc. - let value = Q::execute(db, key.clone()); + let value = execute(); // Extract accumulated inputs. let ActiveQuery { @@ -136,12 +132,19 @@ where /// - `descriptor`: the query whose result was read /// - `changed_revision`: the last revision in which the result of that /// query had changed - crate fn report_query_read(&self, descriptor: &DB::QueryDescriptor, changed_at: Revision) { + crate fn report_query_read(&self, descriptor: &DB::QueryDescriptor, changed_at: ChangedAt) { if let Some(top_query) = self.local_state.borrow_mut().query_stack.last_mut() { top_query.add_read(descriptor, changed_at); } } + crate fn report_untracked_read(&self) { + if let Some(top_query) = self.local_state.borrow_mut().query_stack.last_mut() { + let changed_at = ChangedAt::Revision(self.current_revision()); + top_query.add_untracked_read(changed_at); + } + } + /// Obviously, this should be user configurable at some point. crate fn report_unexpected_cycle(&self, descriptor: DB::QueryDescriptor) -> ! { let local_state = self.local_state.borrow(); @@ -178,7 +181,7 @@ struct ActiveQuery { descriptor: DB::QueryDescriptor, /// Records the maximum revision where any subquery changed - changed_at: Revision, + changed_at: ChangedAt, /// Each subquery subqueries: QueryDescriptorSet, @@ -188,15 +191,20 @@ impl ActiveQuery { fn new(descriptor: DB::QueryDescriptor) -> Self { ActiveQuery { descriptor, - changed_at: Revision::ZERO, - subqueries: QueryDescriptorSet::new(), + changed_at: ChangedAt::Revision(Revision::ZERO), + subqueries: QueryDescriptorSet::default(), } } - fn add_read(&mut self, subquery: &DB::QueryDescriptor, changed_at: Revision) { + fn add_read(&mut self, subquery: &DB::QueryDescriptor, changed_at: ChangedAt) { self.subqueries.insert(subquery.clone()); self.changed_at = self.changed_at.max(changed_at); } + + fn add_untracked_read(&mut self, changed_at: ChangedAt) { + self.subqueries.insert_untracked(); + self.changed_at = self.changed_at.max(changed_at); + } } #[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] @@ -214,40 +222,70 @@ impl std::fmt::Debug for Revision { } } +/// Records when a stamped value changed. +/// +/// Note: the order of variants is significant. We sometimes use `max` +/// for example to find the "most recent revision" when something +/// changed. +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub enum ChangedAt { + Revision(Revision), +} + +impl ChangedAt { + /// True if this value has changed after `revision`. + pub fn changed_since(self, revision: Revision) -> bool { + match self { + ChangedAt::Revision(r) => r > revision, + } + } +} + /// An insertion-order-preserving set of queries. Used to track the /// inputs accessed during query execution. -crate struct QueryDescriptorSet { - set: FxIndexSet, +crate enum QueryDescriptorSet { + /// All reads were to tracked things: + Tracked(FxIndexSet), + + /// Some reads to an untracked thing: + Untracked, } impl std::fmt::Debug for QueryDescriptorSet { fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - std::fmt::Debug::fmt(&self.set, fmt) + match self { + QueryDescriptorSet::Tracked(set) => std::fmt::Debug::fmt(set, fmt), + QueryDescriptorSet::Untracked => write!(fmt, "Untracked"), + } + } +} + +impl Default for QueryDescriptorSet { + fn default() -> Self { + QueryDescriptorSet::Tracked(FxIndexSet::default()) } } impl QueryDescriptorSet { - crate fn new() -> Self { - QueryDescriptorSet { - set: FxIndexSet::default(), + /// Add `descriptor` to the set. Returns true if `descriptor` is + /// newly added and false if `descriptor` was already a member. + fn insert(&mut self, descriptor: DB::QueryDescriptor) { + match self { + QueryDescriptorSet::Tracked(set) => { + set.insert(descriptor); + } + + QueryDescriptorSet::Untracked => {} } } - /// Add `descriptor` to the set. Returns true if `descriptor` is - /// newly added and false if `descriptor` was already a member. - fn insert(&mut self, descriptor: DB::QueryDescriptor) -> bool { - self.set.insert(descriptor) - } - - /// Iterate over all queries in the set, in the order of their - /// first insertion. - pub fn iter(&self) -> impl Iterator { - self.set.iter() + fn insert_untracked(&mut self) { + *self = QueryDescriptorSet::Untracked; } } #[derive(Clone, Debug)] crate struct StampedValue { crate value: V, - crate changed_at: Revision, + crate changed_at: ChangedAt, } diff --git a/src/volatile.rs b/src/volatile.rs deleted file mode 100644 index 6fa61729..00000000 --- a/src/volatile.rs +++ /dev/null @@ -1,95 +0,0 @@ -use crate::runtime::Revision; -use crate::runtime::StampedValue; -use crate::CycleDetected; -use crate::Database; -use crate::QueryFunction; -use crate::QueryStorageOps; -use crate::QueryTable; -use log::debug; -use parking_lot::Mutex; -use rustc_hash::FxHashSet; -use std::any::Any; -use std::cell::RefCell; -use std::collections::hash_map::Entry; -use std::fmt::Debug; -use std::fmt::Display; -use std::fmt::Write; -use std::hash::Hash; - -/// Volatile Storage is just **always** considered dirty. Any time you -/// ask for the result of such a query, it is recomputed. -pub struct VolatileStorage -where - Q: QueryFunction, - DB: Database, -{ - /// We don't store the results of volatile queries, - /// but we track in-progress set to detect cycles. - in_progress: Mutex>, -} - -impl Default for VolatileStorage -where - Q: QueryFunction, - DB: Database, -{ - fn default() -> Self { - VolatileStorage { - in_progress: Mutex::new(FxHashSet::default()), - } - } -} - -impl QueryStorageOps for VolatileStorage -where - Q: QueryFunction, - DB: Database, -{ - fn try_fetch<'q>( - &self, - db: &'q DB, - key: &Q::Key, - descriptor: &DB::QueryDescriptor, - ) -> Result { - if !self.in_progress.lock().insert(key.clone()) { - return Err(CycleDetected); - } - - let ( - StampedValue { - value, - changed_at: _, - }, - _inputs, - ) = db - .salsa_runtime() - .execute_query_implementation::(db, descriptor, key); - - let was_in_progress = self.in_progress.lock().remove(key); - assert!(was_in_progress); - - let revision_now = db.salsa_runtime().current_revision(); - - db.salsa_runtime() - .report_query_read(descriptor, revision_now); - - Ok(value) - } - - fn maybe_changed_since( - &self, - _db: &'q DB, - revision: Revision, - key: &Q::Key, - _descriptor: &DB::QueryDescriptor, - ) -> bool { - debug!( - "{:?}({:?})::maybe_changed_since(revision={:?}) ==> true (volatile)", - Q::default(), - key, - revision, - ); - - true - } -} diff --git a/tests/incremental/memoized_volatile.rs b/tests/incremental/memoized_volatile.rs index 2864ddf3..36bd3593 100644 --- a/tests/incremental/memoized_volatile.rs +++ b/tests/incremental/memoized_volatile.rs @@ -38,10 +38,11 @@ fn volatile(db: &impl MemoizedVolatileContext, (): ()) -> usize { fn volatile_x2() { let query = TestContextImpl::default(); - // Invoking volatile twice will simply execute twice. + // Invoking volatile twice doesn't execute twice, because volatile + // queries are memoized by default. query.volatile(()); query.volatile(()); - query.assert_log(&["Volatile invoked", "Volatile invoked"]); + query.assert_log(&["Volatile invoked"]); } /// Test that: @@ -67,7 +68,7 @@ fn revalidate() { query.salsa_runtime().next_revision(); query.memoized2(()); - query.assert_log(&["Memoized1 invoked", "Volatile invoked"]); + query.assert_log(&["Volatile invoked", "Memoized1 invoked"]); query.memoized2(()); query.assert_log(&[]); @@ -78,7 +79,7 @@ fn revalidate() { query.salsa_runtime().next_revision(); query.memoized2(()); - query.assert_log(&["Memoized1 invoked", "Volatile invoked", "Memoized2 invoked"]); + query.assert_log(&["Volatile invoked", "Memoized1 invoked", "Memoized2 invoked"]); query.memoized2(()); query.assert_log(&[]); diff --git a/tests/storage_varieties/main.rs b/tests/storage_varieties/main.rs index c10de188..0e90f20e 100644 --- a/tests/storage_varieties/main.rs +++ b/tests/storage_varieties/main.rs @@ -1,4 +1,5 @@ #![feature(crate_visibility_modifier)] +#![feature(underscore_imports)] mod implementation; mod queries; diff --git a/tests/storage_varieties/queries.rs b/tests/storage_varieties/queries.rs index d0af5e2b..b66dc53e 100644 --- a/tests/storage_varieties/queries.rs +++ b/tests/storage_varieties/queries.rs @@ -17,7 +17,7 @@ salsa::query_group! { /// Because this query is memoized, we only increment the counter /// the first time it is invoked. fn memoized(db: &impl Database, (): ()) -> usize { - db.increment() + db.volatile(()) } /// Because this query is volatile, each time it is invoked, diff --git a/tests/storage_varieties/tests.rs b/tests/storage_varieties/tests.rs index 5ad5ffc5..bfd82143 100644 --- a/tests/storage_varieties/tests.rs +++ b/tests/storage_varieties/tests.rs @@ -2,32 +2,47 @@ use crate::implementation::DatabaseImpl; use crate::queries::Database; +use salsa::Database as _; #[test] fn memoized_twice() { - let query = DatabaseImpl::default(); - let v1 = query.memoized(()); - let v2 = query.memoized(()); + let db = DatabaseImpl::default(); + let v1 = db.memoized(()); + let v2 = db.memoized(()); assert_eq!(v1, v2); } #[test] fn volatile_twice() { - let query = DatabaseImpl::default(); - let v1 = query.volatile(()); - let v2 = query.volatile(()); - assert_eq!(v1 + 1, v2); + let db = DatabaseImpl::default(); + let v1 = db.volatile(()); + let v2 = db.volatile(()); // volatiles are cached, so 2nd read returns the same + assert_eq!(v1, v2); + + db.salsa_runtime().next_revision(); // clears volatile caches + + let v3 = db.volatile(()); // will re-increment the counter + let v4 = db.volatile(()); // second call will be cached + assert_eq!(v1 + 1, v3); + assert_eq!(v3, v4); } #[test] fn intermingled() { - let query = DatabaseImpl::default(); - let v1 = query.volatile(()); - let v2 = query.memoized(()); - let v3 = query.volatile(()); - let v4 = query.memoized(()); + let db = DatabaseImpl::default(); + let v1 = db.volatile(()); + let v2 = db.memoized(()); + let v3 = db.volatile(()); // cached + let v4 = db.memoized(()); // cached - assert_eq!(v1 + 1, v2); - assert_eq!(v2 + 1, v3); + assert_eq!(v1, v2); + assert_eq!(v1, v3); assert_eq!(v2, v4); + + db.salsa_runtime().next_revision(); // clears volatile caches + + let v5 = db.memoized(()); // re-executes volatile, caches new result + let v6 = db.memoized(()); // re-use cached result + assert_eq!(v4 + 1, v5); + assert_eq!(v5, v6); }