diff --git a/src/debug.rs b/src/debug.rs new file mode 100644 index 00000000..491ffdf1 --- /dev/null +++ b/src/debug.rs @@ -0,0 +1,28 @@ +//! Debugging APIs: these are meant for use when unit-testing or +//! debugging your application but aren't ordinarily needed. + +use crate::Database; +use crate::Query; +use crate::QueryStorageOps; +use crate::QueryTable; + +pub trait DebugQueryTable { + type Key; + + /// True if salsa thinks that the value for `key` is a + /// **constant**, meaning that it can never change, no matter what + /// values the inputs take on from this point. + fn is_constant(&self, key: Self::Key) -> bool; +} + +impl DebugQueryTable for QueryTable<'_, DB, Q> +where + DB: Database, + Q: Query, +{ + type Key = Q::Key; + + fn is_constant(&self, key: Q::Key) -> bool { + self.storage.is_constant(self.db, &key) + } +} diff --git a/src/derived.rs b/src/derived.rs index 49ca1e5d..e827f71b 100644 --- a/src/derived.rs +++ b/src/derived.rs @@ -396,6 +396,15 @@ where true } + + fn is_constant(&self, _db: &DB, key: &Q::Key) -> bool { + let map_read = self.map.read(); + match map_read.get(key) { + None => false, + Some(QueryState::InProgress) => panic!("query in progress"), + Some(QueryState::Memoized(memo)) => memo.changed_at.is_constant(), + } + } } impl UncheckedMutQueryStorageOps for DerivedStorage @@ -443,6 +452,19 @@ where fn verify_inputs(&self, db: &DB) -> bool { match self.changed_at { + ChangedAt::Constant(_) => { + // If we know that the value is constant, it had + // better not change, but in that case, we ought not + // to have any inputs. Using `debug_assert` because + // this is on the fast path. + debug_assert!(match &self.inputs { + QueryDescriptorSet::Tracked(inputs) => inputs.is_empty(), + QueryDescriptorSet::Untracked => false, + }); + + true + } + ChangedAt::Revision(revision) => match &self.inputs { QueryDescriptorSet::Tracked(inputs) => inputs .iter() diff --git a/src/input.rs b/src/input.rs index 80678bfa..a90c92ea 100644 --- a/src/input.rs +++ b/src/input.rs @@ -4,7 +4,7 @@ use crate::runtime::Revision; use crate::runtime::StampedValue; use crate::CycleDetected; use crate::Database; -use crate::MutQueryStorageOps; +use crate::InputQueryStorageOps; use crate::Query; use crate::QueryDescriptor; use crate::QueryStorageOps; @@ -46,6 +46,8 @@ where } } +struct IsConstant(bool); + impl InputStorage where Q: Query, @@ -70,6 +72,67 @@ where changed_at: ChangedAt::Revision(Revision::ZERO), }) } + + fn set_common(&self, db: &DB, key: &Q::Key, value: Q::Value, is_constant: IsConstant) { + let mut map = self.map.write(); + + // If this value was previously stored, check if this is an + // *actual change* before we do anything. + if let Some(old_value) = map.get_mut(key) { + if old_value.value == value { + // If the value did not change, but it is now + // considered constant, we can just update + // `changed_at`. We don't have to trigger a new + // revision for this case: all the derived values are + // still intact, they just have conservative + // dependencies. The next revision, they may wind up + // with something more precise. + if is_constant.0 && !old_value.changed_at.is_constant() { + old_value.changed_at = + ChangedAt::Constant(db.salsa_runtime().current_revision()); + } + + return; + } + } + + let key = key.clone(); + + // The value is changing, so even if we are setting this to a + // constant, we still need a new revision. + let next_revision = db.salsa_runtime().increment_revision(); + + // Do this *after* we acquire the lock, so that we are not + // racing with somebody else to modify this same cell. + // (Otherwise, someone else might write a *newer* revision + // into the same cell while we block on the lock.) + let changed_at = if is_constant.0 { + ChangedAt::Constant(next_revision) + } else { + ChangedAt::Revision(next_revision) + }; + + let stamped_value = StampedValue { value, changed_at }; + + match map.entry(key) { + Entry::Occupied(mut entry) => { + assert!( + !entry.get().changed_at.is_constant(), + "modifying `{:?}({:?})`, which was previously marked as constant (old value `{:?}`, new value `{:?}`)", + Q::default(), + entry.key(), + entry.get().value, + stamped_value.value, + ); + + entry.insert(stamped_value); + } + + Entry::Vacant(entry) => { + entry.insert(stamped_value); + } + } + } } impl QueryStorageOps for InputStorage @@ -115,26 +178,28 @@ where changed_at.changed_since(revision) } + + fn is_constant(&self, _db: &DB, key: &Q::Key) -> bool { + let map_read = self.map.read(); + map_read + .get(key) + .map(|v| v.changed_at.is_constant()) + .unwrap_or(false) + } } -impl MutQueryStorageOps for InputStorage +impl InputQueryStorageOps for InputStorage where Q: Query, DB: Database, Q::Value: Default, { fn set(&self, db: &DB, key: &Q::Key, value: Q::Value) { - let key = key.clone(); + self.set_common(db, key, value, IsConstant(false)) + } - let mut map_write = self.map.write(); - - // Do this *after* we acquire the lock, so that we are not - // racing with somebody else to modify this same cell. - // (Otherwise, someone else might write a *newer* revision - // into the same cell while we block on the lock.) - let changed_at = ChangedAt::Revision(db.salsa_runtime().increment_revision()); - - map_write.insert(key, StampedValue { value, changed_at }); + fn set_constant(&self, db: &DB, key: &Q::Key, value: Q::Value) { + self.set_common(db, key, value, IsConstant(true)) } } diff --git a/src/lib.rs b/src/lib.rs index 56b6faca..80653de8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,6 +12,7 @@ use std::fmt::Display; use std::fmt::Write; use std::hash::Hash; +pub mod debug; pub mod derived; pub mod input; pub mod runtime; @@ -119,17 +120,22 @@ where key: &Q::Key, descriptor: &DB::QueryDescriptor, ) -> bool; + + /// Check if `key` is (currently) believed to be a constant. + fn is_constant(&self, db: &DB, key: &Q::Key) -> bool; } /// An optional trait that is implemented for "user mutable" storage: /// that is, storage whose value is not derived from other storage but /// is set independently. -pub trait MutQueryStorageOps: Default +pub trait InputQueryStorageOps: Default where DB: Database, Q: Query, { fn set(&self, db: &DB, key: &Q::Key, new_value: Q::Value); + + fn set_constant(&self, db: &DB, key: &Q::Key, new_value: Q::Value); } /// An optional trait that is implemented for "user mutable" storage: @@ -146,8 +152,8 @@ where #[derive(new)] pub struct QueryTable<'me, DB, Q> where - DB: Database, - Q: Query, + DB: Database + 'me, + Q: Query + 'me, { db: &'me DB, storage: &'me Q::Storage, @@ -170,15 +176,25 @@ where }) } - /// Assign a value to an "input queries". Must be used outside of + /// Assign a value to an "input query". Must be used outside of /// an active query computation. pub fn set(&self, key: Q::Key, value: Q::Value) where - Q::Storage: MutQueryStorageOps, + Q::Storage: InputQueryStorageOps, { self.storage.set(self.db, &key, value); } + /// Assign a value to an "input query", with the additional + /// promise that this value will **never change**. Must be used + /// outside of an active query computation. + pub fn set_constant(&self, key: Q::Key, value: Q::Value) + where + Q::Storage: InputQueryStorageOps, + { + self.storage.set_constant(self.db, &key, value); + } + /// Assigns a value to the query **bypassing the normal /// incremental checking** -- this value becomes the value for the /// query in the current revision. This can even be used on diff --git a/src/runtime.rs b/src/runtime.rs index e032453d..41c6cba7 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -134,7 +134,11 @@ where /// - `descriptor`: the query whose result was read /// - `changed_revision`: the last revision in which the result of that /// query had changed - pub(crate) fn report_query_read(&self, descriptor: &DB::QueryDescriptor, changed_at: ChangedAt) { + pub(crate) fn report_query_read( + &self, + descriptor: &DB::QueryDescriptor, + changed_at: ChangedAt, + ) { if let Some(top_query) = self.local_state.borrow_mut().query_stack.last_mut() { top_query.add_read(descriptor, changed_at); } @@ -194,14 +198,22 @@ impl ActiveQuery { fn new(descriptor: DB::QueryDescriptor) -> Self { ActiveQuery { descriptor, - changed_at: ChangedAt::Revision(Revision::ZERO), + changed_at: ChangedAt::Constant(Revision::ZERO), subqueries: QueryDescriptorSet::default(), } } fn add_read(&mut self, subquery: &DB::QueryDescriptor, changed_at: ChangedAt) { - self.subqueries.insert(subquery.clone()); - self.changed_at = self.changed_at.max(changed_at); + match changed_at { + ChangedAt::Constant(_) => { + // When we read constant values, we don't need to + // track the source of the value. + } + ChangedAt::Revision(_) => { + self.subqueries.insert(subquery.clone()); + self.changed_at = self.changed_at.max(changed_at); + } + } } fn add_untracked_read(&mut self, changed_at: ChangedAt) { @@ -232,14 +244,28 @@ impl std::fmt::Debug for Revision { /// changed. #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] pub enum ChangedAt { + /// Will never change again (and the revision in which we became a + /// constant). + Constant(Revision), + + /// Last changed in the given revision. May change in the future. Revision(Revision), } impl ChangedAt { - /// True if this value has changed after `revision`. + pub fn is_constant(self) -> bool { + match self { + ChangedAt::Constant(_) => true, + ChangedAt::Revision(_) => false, + } + } + + /// True if a value is stored with this `ChangedAt` value has + /// changed after `revision`. This is invoked by query storage + /// when their dependents are asking them if they have changed. pub fn changed_since(self, revision: Revision) -> bool { match self { - ChangedAt::Revision(r) => r > revision, + ChangedAt::Constant(r) | ChangedAt::Revision(r) => r > revision, } } } diff --git a/tests/incremental/constants.rs b/tests/incremental/constants.rs new file mode 100644 index 00000000..a7391bda --- /dev/null +++ b/tests/incremental/constants.rs @@ -0,0 +1,127 @@ +use crate::implementation::{TestContext, TestContextImpl}; +use salsa::debug::DebugQueryTable; +use salsa::Database; + +salsa::query_group! { + pub(crate) trait ConstantsDatabase: TestContext { + fn constants_input(key: char) -> usize { + type ConstantsInput; + storage input; + } + + fn constants_add(keys: (char, char)) -> usize { + type ConstantsAdd; + } + } +} + +fn constants_add(db: &impl ConstantsDatabase, (key1, key2): (char, char)) -> usize { + db.log().add(format!("add({}, {})", key1, key2)); + db.constants_input(key1) + db.constants_input(key2) +} + +#[test] +#[should_panic] +fn invalidate_constant() { + let db = &TestContextImpl::default(); + db.query(ConstantsInput).set_constant('a', 44); + db.query(ConstantsInput).set_constant('a', 66); +} + +#[test] +#[should_panic] +fn invalidate_constant_1() { + let db = &TestContextImpl::default(); + + // Not constant: + db.query(ConstantsInput).set('a', 44); + + // Becomes constant: + db.query(ConstantsInput).set_constant('a', 44); + + // Invalidates: + db.query(ConstantsInput).set_constant('a', 66); +} + +/// Test that use can still `set` an input that is constant, so long +/// as you don't change the value. +#[test] +fn set_after_constant_same_value() { + let db = &TestContextImpl::default(); + db.query(ConstantsInput).set_constant('a', 44); + db.query(ConstantsInput).set('a', 44); +} + +#[test] +fn not_constant() { + let db = &TestContextImpl::default(); + + db.query(ConstantsInput).set('a', 22); + db.query(ConstantsInput).set('b', 44); + assert_eq!(db.constants_add(('a', 'b')), 66); + assert!(!db.query(ConstantsAdd).is_constant(('a', 'b'))); +} + +#[test] +fn is_constant() { + let db = &TestContextImpl::default(); + + db.query(ConstantsInput).set_constant('a', 22); + db.query(ConstantsInput).set_constant('b', 44); + assert_eq!(db.constants_add(('a', 'b')), 66); + assert!(db.query(ConstantsAdd).is_constant(('a', 'b'))); +} + +#[test] +fn mixed_constant() { + let db = &TestContextImpl::default(); + + db.query(ConstantsInput).set_constant('a', 22); + db.query(ConstantsInput).set('b', 44); + assert_eq!(db.constants_add(('a', 'b')), 66); + assert!(!db.query(ConstantsAdd).is_constant(('a', 'b'))); +} + +#[test] +fn becomes_constant() { + let db = &TestContextImpl::default(); + + db.query(ConstantsInput).set('a', 22); + db.query(ConstantsInput).set('b', 44); + assert_eq!(db.constants_add(('a', 'b')), 66); + assert!(!db.query(ConstantsAdd).is_constant(('a', 'b'))); + + db.query(ConstantsInput).set_constant('a', 23); + assert_eq!(db.constants_add(('a', 'b')), 67); + assert!(!db.query(ConstantsAdd).is_constant(('a', 'b'))); + + db.query(ConstantsInput).set_constant('b', 45); + assert_eq!(db.constants_add(('a', 'b')), 68); + assert!(db.query(ConstantsAdd).is_constant(('a', 'b'))); +} + +#[test] +fn becomes_constant_no_change() { + let db = &TestContextImpl::default(); + + db.query(ConstantsInput).set('a', 22); + db.query(ConstantsInput).set('b', 44); + assert_eq!(db.constants_add(('a', 'b')), 66); + assert!(!db.query(ConstantsAdd).is_constant(('a', 'b'))); + db.assert_log(&["add(a, b)"]); + + // 'a' is now constant, but the value did not change; this + // should not in and of itself trigger a new revision. + db.query(ConstantsInput).set_constant('a', 22); + assert_eq!(db.constants_add(('a', 'b')), 66); + assert!(!db.query(ConstantsAdd).is_constant(('a', 'b'))); + db.assert_log(&[]); // no new revision, no new log entries + + // 'b' is now constant, and its value DID change. This triggers a + // new revision, and at that point we figure out that we are + // constant. + db.query(ConstantsInput).set_constant('b', 45); + assert_eq!(db.constants_add(('a', 'b')), 67); + assert!(db.query(ConstantsAdd).is_constant(('a', 'b'))); + db.assert_log(&["add(a, b)"]); +} diff --git a/tests/incremental/implementation.rs b/tests/incremental/implementation.rs index 868e49c9..36809bef 100644 --- a/tests/incremental/implementation.rs +++ b/tests/incremental/implementation.rs @@ -1,3 +1,4 @@ +use crate::constants; use crate::counter::Counter; use crate::log::Log; use crate::memoized_dep_inputs; @@ -43,6 +44,11 @@ impl TestContextImpl { salsa::database_storage! { pub(crate) struct TestContextImplStorage for TestContextImpl { + impl constants::ConstantsDatabase { + fn constants_input() for constants::ConstantsInput; + fn constants_derived() for constants::ConstantsAdd; + } + impl memoized_dep_inputs::MemoizedDepInputsContext { fn dep_memoized2() for memoized_dep_inputs::Memoized2; fn dep_memoized1() for memoized_dep_inputs::Memoized1; diff --git a/tests/incremental/main.rs b/tests/incremental/main.rs index 33a623cd..bcd13c75 100644 --- a/tests/incremental/main.rs +++ b/tests/incremental/main.rs @@ -1,3 +1,4 @@ +mod constants; mod counter; mod implementation; mod log; diff --git a/tests/incremental/memoized_inputs.rs b/tests/incremental/memoized_inputs.rs index 0fdc9b79..23e690f3 100644 --- a/tests/incremental/memoized_inputs.rs +++ b/tests/incremental/memoized_inputs.rs @@ -60,3 +60,20 @@ fn revalidate() { assert_eq!(v, 66); db.assert_log(&[]); } + +/// Test that invoking `set` on an input with the same value does not +/// trigger a new revision. +#[test] +fn set_after_no_change() { + let db = &TestContextImpl::default(); + + db.query(Input1).set((), 44); + let v = db.max(()); + assert_eq!(v, 44); + db.assert_log(&["Max invoked"]); + + db.query(Input1).set((), 44); + let v = db.max(()); + assert_eq!(v, 44); + db.assert_log(&[]); +}