From 1e3c2f22aa8b3a1515cfe2e8ca6574713fe3abee Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Fri, 3 Jun 2022 05:49:23 -0400 Subject: [PATCH] Expose the ability to remove the value from an input query, taking ownership of it Co-authored-by: Tim Robinson --- components/salsa-macros/src/query_group.rs | 27 ++++++++++++- src/derived/fetch.rs | 2 +- src/input.rs | 46 +++++++++++----------- src/lib.rs | 18 +++++++++ src/plumbing.rs | 2 + tests/incremental/memoized_dep_inputs.rs | 3 +- tests/incremental/memoized_inputs.rs | 14 ++++++- tests/parallel/frozen.rs | 3 +- tests/parallel/race.rs | 3 +- tests/parallel/stress.rs | 12 +++++- tests/transparent.rs | 3 +- 11 files changed, 101 insertions(+), 32 deletions(-) diff --git a/components/salsa-macros/src/query_group.rs b/components/salsa-macros/src/query_group.rs index 45234b2f..6f3d1917 100644 --- a/components/salsa-macros/src/query_group.rs +++ b/components/salsa-macros/src/query_group.rs @@ -253,6 +253,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream if let QueryStorage::Input = query.storage { let set_fn_name = format_ident!("set_{}", fn_name); let set_with_durability_fn_name = format_ident!("set_{}_with_durability", fn_name); + let remove_fn_name = format_ident!("remove_{}", fn_name); let set_fn_docs = format!( " @@ -283,13 +284,30 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream fn_name = fn_name ); + let remove_fn_docs = format!( + " + Remove the value from the `{fn_name}` input. + + See `{fn_name}` for details. Panics if a value has + not previously been set using `set_{fn_name}` or + `set_{fn_name}_with_durability`. + + *Note:* Setting values will trigger cancellation + of any ongoing queries; this method blocks until + those queries have been cancelled. + ", + fn_name = fn_name + ); + query_fn_declarations.extend(quote! { # [doc = #set_fn_docs] fn #set_fn_name(&mut self, #(#key_names: #keys,)* value__: #value); - # [doc = #set_constant_fn_docs] fn #set_with_durability_fn_name(&mut self, #(#key_names: #keys,)* value__: #value, durability__: salsa::Durability); + + # [doc = #remove_fn_docs] + fn #remove_fn_name(&mut self, #(#key_names: #keys,)*) -> #value; }); query_fn_definitions.extend(quote! { @@ -306,6 +324,13 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream } __shim(self, #(#key_names,)* value__ ,durability__) } + + fn #remove_fn_name(&mut self, #(#key_names: #keys,)*) -> #value { + fn __shim(db: &mut dyn #trait_name, #(#key_names: #keys,)*) -> #value { + salsa::plumbing::get_query_table_mut::<#qt>(db).remove((#(#key_names),*)) + } + __shim(self, #(#key_names,)*) + } }); } diff --git a/src/derived/fetch.rs b/src/derived/fetch.rs index 63306b9c..16413a8f 100644 --- a/src/derived/fetch.rs +++ b/src/derived/fetch.rs @@ -2,7 +2,7 @@ use arc_swap::Guard; use crate::{ plumbing::{DatabaseOps, QueryFunction}, - runtime::{local_state::QueryInputs, StampedValue}, + runtime::{StampedValue}, Database, QueryDb, }; diff --git a/src/input.rs b/src/input.rs index 85534cfb..f8a7bc85 100644 --- a/src/input.rs +++ b/src/input.rs @@ -15,7 +15,6 @@ use indexmap::map::Entry; use log::debug; use parking_lot::RwLock; use std::convert::TryFrom; -use std::sync::Arc; /// Input queries store the result plus a list of the other queries /// that they invoked. This means we can avoid recomputing them when @@ -25,7 +24,7 @@ where Q: Query, { group_index: u16, - slots: RwLock>>>, + slots: RwLock>>, } struct Slot @@ -45,15 +44,6 @@ where { } -impl InputStorage -where - Q: Query, -{ - fn slot(&self, key: &Q::Key) -> Option>> { - self.slots.read().get(key).cloned() - } -} - impl QueryStorageOps for InputStorage where Q: Query, @@ -89,21 +79,17 @@ where assert_eq!(input.group_index, self.group_index); assert_eq!(input.query_index, Q::QUERY_INDEX); debug_assert!(revision < db.salsa_runtime().current_revision()); - let slot = self - .slots - .read() - .get_index(input.key_index as usize) - .unwrap() - .1 - .clone(); + let slots = self.slots.read(); + let (_, slot) = slots.get_index(input.key_index as usize).unwrap(); slot.maybe_changed_after(db, revision) } fn fetch(&self, db: &>::DynDb, key: &Q::Key) -> Q::Value { db.unwind_if_cancelled(); - let slot = self - .slot(key) + let slots = self.slots.read(); + let slot = slots + .get(key) .unwrap_or_else(|| panic!("no value set for {:?}({:?})", Q::default(), key)); let StampedValue { @@ -123,7 +109,8 @@ where } fn durability(&self, _db: &>::DynDb, key: &Q::Key) -> Durability { - match self.slot(key) { + let slots = self.slots.read(); + match slots.get(key) { Some(slot) => slot.stamped_value.read().durability, None => panic!("no value set for {:?}({:?})", Q::default(), key), } @@ -229,16 +216,29 @@ where query_index: Q::QUERY_INDEX, key_index, }; - entry.insert(Arc::new(Slot { + entry.insert(Slot { key: key.clone(), database_key_index, stamped_value: RwLock::new(stamped_value), - })); + }); None } } }); } + + fn remove(&self, runtime: &mut Runtime, key: &::Key) -> ::Value { + let mut value = None; + runtime.with_incremented_revision(&mut |_| { + let mut slots = self.slots.write(); + let slot = slots.remove(key)?; + let slot_stamped_value = slot.stamped_value.into_inner(); + value = Some(slot_stamped_value.value); + Some(slot_stamped_value.durability) + }); + + value.unwrap_or_else(|| panic!("no value set for {:?}({:?})", Q::default(), key)) + } } /// Check that `Slot: Send + Sync` as long as diff --git a/src/lib.rs b/src/lib.rs index 0496b1f7..4d8fcaf3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -569,6 +569,24 @@ where self.storage.set(self.runtime, &key, value, durability); } + /// Removes a value from an "input query". Must be used outside of + /// an active query computation. + /// + /// If you are using `snapshot`, see the notes on blocking + /// and cancellation on [the `query_mut` method]. + /// + /// # Panics + /// Panics if the value was not previously set by `set` or + /// `set_with_durability`. + /// + /// [the `query_mut` method]: trait.Database.html#method.query_mut + pub fn remove(&mut self, key: Q::Key) -> Q::Value + where + Q::Storage: plumbing::InputQueryStorageOps, + { + self.storage.remove(self.runtime, &key) + } + /// Sets the size of LRU cache of values for this query table. /// /// That is, at most `cap` values will be preset in the table at the same diff --git a/src/plumbing.rs b/src/plumbing.rs index 53af6c27..16d00304 100644 --- a/src/plumbing.rs +++ b/src/plumbing.rs @@ -218,6 +218,8 @@ where Q: Query, { fn set(&self, runtime: &mut Runtime, key: &Q::Key, new_value: Q::Value, durability: Durability); + + fn remove(&self, runtime: &mut Runtime, key: &Q::Key) -> Q::Value; } /// An optional trait that is implemented for "user mutable" storage: diff --git a/tests/incremental/memoized_dep_inputs.rs b/tests/incremental/memoized_dep_inputs.rs index 4ea33e0c..d76cbfb1 100644 --- a/tests/incremental/memoized_dep_inputs.rs +++ b/tests/incremental/memoized_dep_inputs.rs @@ -47,7 +47,8 @@ fn revalidate() { db.assert_log(&["Memoized1 invoked", "Derived1 invoked", "Memoized2 invoked"]); // Here validation of Memoized1 succeeds so Memoized2 never runs. - db.set_dep_input1(45); + let value = db.remove_dep_input1() + 1; + db.set_dep_input1(value); let v = db.dep_memoized2(); assert_eq!(v, 44); db.assert_log(&["Memoized1 invoked", "Derived1 invoked"]); diff --git a/tests/incremental/memoized_inputs.rs b/tests/incremental/memoized_inputs.rs index 53d2ace8..ae5bf763 100644 --- a/tests/incremental/memoized_inputs.rs +++ b/tests/incremental/memoized_inputs.rs @@ -47,12 +47,22 @@ fn revalidate() { db.set_input1(64); db.assert_log(&[]); + let value = db.remove_input1() + 1; + db.set_input1(value); + db.assert_log(&[]); + let value = db.remove_input2() + 1; + db.set_input2(value); + db.assert_log(&[]); + let value = db.remove_input1() + 1; + db.set_input1(value); + db.assert_log(&[]); + let v = db.max(); - assert_eq!(v, 66); + assert_eq!(v, 67); db.assert_log(&["Max invoked"]); let v = db.max(); - assert_eq!(v, 66); + assert_eq!(v, 67); db.assert_log(&[]); } diff --git a/tests/parallel/frozen.rs b/tests/parallel/frozen.rs index 5359a882..41c0b8e9 100644 --- a/tests/parallel/frozen.rs +++ b/tests/parallel/frozen.rs @@ -44,7 +44,8 @@ fn in_par_get_set_cancellation() { signal.wait_for(1); // This will block until thread1 drops the revision lock. - db.set_input('a', 2); + let value = db.remove_input('a') + 1; + db.set_input('a', value); db.input('a') } diff --git a/tests/parallel/race.rs b/tests/parallel/race.rs index e875de99..0d89f70f 100644 --- a/tests/parallel/race.rs +++ b/tests/parallel/race.rs @@ -19,7 +19,8 @@ fn in_par_get_set_race() { }); let thread2 = std::thread::spawn(move || { - db.set_input('a', 1000); + let value = db.remove_input('a') * 10; + db.set_input('a', value); db.sum("a") }); diff --git a/tests/parallel/stress.rs b/tests/parallel/stress.rs index 16a1b790..ac8c3092 100644 --- a/tests/parallel/stress.rs +++ b/tests/parallel/stress.rs @@ -61,6 +61,7 @@ enum MutatorOp { #[derive(Debug)] enum WriteOp { + AddA(usize, isize), SetA(usize, usize), } @@ -92,7 +93,11 @@ impl rand::distributions::Distribution for rand::distributions::Standar fn sample(&self, rng: &mut R) -> WriteOp { let key = rng.gen::() % 10; let value = rng.gen::() % 10; - WriteOp::SetA(key, value) + if rng.gen_bool(0.5) { + WriteOp::AddA(key, value as isize - 5) + } else { + WriteOp::SetA(key, value) + } } } @@ -116,6 +121,11 @@ fn db_reader_thread(db: &StressDatabaseImpl, ops: Vec, check_cancellatio impl WriteOp { fn execute(self, db: &mut StressDatabaseImpl) { match self { + WriteOp::AddA(key, value_delta) => { + let value = db.remove_a(key); + let value = (value as isize + value_delta) as usize; + db.set_a(key, value); + } WriteOp::SetA(key, value) => { db.set_a(key, value); } diff --git a/tests/transparent.rs b/tests/transparent.rs index 2e6dd426..fce5c9c3 100644 --- a/tests/transparent.rs +++ b/tests/transparent.rs @@ -33,7 +33,8 @@ fn transparent_queries_work() { assert_eq!(db.get(1), 10); assert_eq!(db.get(1), 10); - db.set_input(1, 92); + let value = db.remove_input(1) + 82; + db.set_input(1, value); assert_eq!(db.get(1), 92); assert_eq!(db.get(1), 92); }