diff --git a/src/input.rs b/src/input.rs index 6d1d4455..bdce2d19 100644 --- a/src/input.rs +++ b/src/input.rs @@ -1,6 +1,7 @@ use crate::runtime::QueryDescriptorSet; use crate::runtime::Revision; use crate::CycleDetected; +use crate::MutQueryStorageOps; use crate::Query; use crate::QueryContext; use crate::QueryDescriptor; @@ -116,6 +117,27 @@ where } } +impl MutQueryStorageOps for InputStorage +where + Q: Query, + QC: QueryContext, + Q::Value: Default, +{ + fn set(&self, query: &QC, key: &Q::Key, value: Q::Value) { + let key = key.clone(); + + let mut map_write = self.map.write(); + + // Do this *after* we acquire the lock, so that we are not + // racing with somebody else to modify this same cell. + // (Otherwise, someone else might write a *newer* revision + // into the same cell while we block on the lock.) + let changed_at = query.salsa_runtime().increment_revision(); + + map_write.insert(key, StampedValue { value, changed_at }); + } +} + #[derive(Clone)] struct StampedValue { value: V, diff --git a/src/lib.rs b/src/lib.rs index 78abf5ee..edaafa4f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -147,11 +147,20 @@ where }) } + /// Equivalent to `of(DefaultKey::default_key())` + pub fn read(&self) -> Q::Value + where + Q::Key: DefaultKey, + { + self.of(DefaultKey::default_key()) + } + + /// Assign a value to an "input queries". Must be used outside of + /// an active query computation. pub fn set(&self, key: Q::Key, value: Q::Value) where Q::Storage: MutQueryStorageOps, { - self.query.salsa_runtime().next_revision(); self.storage.set(self.query, &key, value); } @@ -160,6 +169,20 @@ where } } +/// A variant of the `Default` trait used for query keys that are +/// either singletons (e.g., `()`) or have some overwhelming default. +/// In this case, you can write `query.my_query().read()` as a +/// convenient shorthand. +pub trait DefaultKey { + fn default_key() -> Self; +} + +impl DefaultKey for () { + fn default_key() -> Self { + () + } +} + /// A macro that helps in defining the "context trait" of a given /// module. This is a trait that defines everything that a block of /// queries need to execute, as well as defining the queries @@ -360,7 +383,7 @@ macro_rules! query_definition { ( @storage_ty[$QC:ident, $Self:ident, input] ) => { - $crate::volatile::InputStorage<$QC, $Self> + $crate::input::InputStorage<$QC, $Self> }; // Various legal start states: diff --git a/tests/incremental/implementation.rs b/tests/incremental/implementation.rs index 41ea5be0..10cbc0bd 100644 --- a/tests/incremental/implementation.rs +++ b/tests/incremental/implementation.rs @@ -1,5 +1,6 @@ use crate::counter::Counter; use crate::log::Log; +use crate::memoized_inputs; use crate::memoized_volatile; crate trait TestContext: salsa::QueryContext { @@ -46,6 +47,12 @@ salsa::query_context_storage! { fn memoized1() for memoized_volatile::Memoized1; fn volatile() for memoized_volatile::Volatile; } + + impl memoized_inputs::MemoizedInputsContext { + fn max() for memoized_inputs::Max; + fn input1() for memoized_inputs::Input1; + fn input2() for memoized_inputs::Input2; + } } } diff --git a/tests/incremental/main.rs b/tests/incremental/main.rs index 27092b9f..25b9a967 100644 --- a/tests/incremental/main.rs +++ b/tests/incremental/main.rs @@ -4,6 +4,7 @@ mod counter; mod implementation; mod log; +mod memoized_inputs; mod memoized_volatile; fn main() {} diff --git a/tests/incremental/memoized_inputs.rs b/tests/incremental/memoized_inputs.rs new file mode 100644 index 00000000..dce500ad --- /dev/null +++ b/tests/incremental/memoized_inputs.rs @@ -0,0 +1,72 @@ +use crate::implementation::{TestContext, TestContextImpl}; + +crate trait MemoizedInputsContext: TestContext { + salsa::query_prototype! { + fn max() for Max; + fn input1() for Input1; + fn input2() for Input2; + } +} + +salsa::query_definition! { + crate Max(query: &impl MemoizedInputsContext, (): ()) -> usize { + query.log().add("Max invoked"); + std::cmp::max( + query.input1().read(), + query.input2().read(), + ) + } +} + +salsa::query_definition! { + #[storage(input)] + crate Input1(_query: &impl MemoizedInputsContext, _value: ()) -> usize { + panic!("silly") + } +} + +salsa::query_definition! { + #[storage(input)] + crate Input2(_query: &impl MemoizedInputsContext, _value: ()) -> usize { + panic!("silly") + } +} + +#[test] +fn revalidate() { + let query = TestContextImpl::default(); + + let v = query.max().of(()); + assert_eq!(v, 0); + query.assert_log(&["Max invoked"]); + + let v = query.max().of(()); + assert_eq!(v, 0); + query.assert_log(&[]); + + query.input1().set((), 44); + query.assert_log(&[]); + + let v = query.max().of(()); + assert_eq!(v, 44); + query.assert_log(&["Max invoked"]); + + let v = query.max().of(()); + assert_eq!(v, 44); + query.assert_log(&[]); + + query.input1().set((), 44); + query.assert_log(&[]); + query.input2().set((), 66); + query.assert_log(&[]); + query.input1().set((), 64); + query.assert_log(&[]); + + let v = query.max().of(()); + assert_eq!(v, 66); + query.assert_log(&["Max invoked"]); + + let v = query.max().of(()); + assert_eq!(v, 66); + query.assert_log(&[]); +} diff --git a/tests/incremental/memoized_volatile.rs b/tests/incremental/memoized_volatile.rs index cddd4a17..14c113ca 100644 --- a/tests/incremental/memoized_volatile.rs +++ b/tests/incremental/memoized_volatile.rs @@ -54,8 +54,6 @@ fn volatile_x2() { /// - On the first run of R2, we recompute everything (since Memoized1 result *did* change). #[test] fn revalidate() { - env_logger::init(); - let query = TestContextImpl::default(); query.memoized2().of(());