Expose the ability to remove the value from an input query, taking ownership of it

Co-authored-by: Tim Robinson <tim.g.robinson@gmail.com>
This commit is contained in:
Niko Matsakis 2022-06-03 05:49:23 -04:00
parent 887f24c06d
commit 1e3c2f22aa
11 changed files with 101 additions and 32 deletions

View file

@ -253,6 +253,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
if let QueryStorage::Input = query.storage {
let set_fn_name = format_ident!("set_{}", fn_name);
let set_with_durability_fn_name = format_ident!("set_{}_with_durability", fn_name);
let remove_fn_name = format_ident!("remove_{}", fn_name);
let set_fn_docs = format!(
"
@ -283,13 +284,30 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
fn_name = fn_name
);
let remove_fn_docs = format!(
"
Remove the value from the `{fn_name}` input.
See `{fn_name}` for details. Panics if a value has
not previously been set using `set_{fn_name}` or
`set_{fn_name}_with_durability`.
*Note:* Setting values will trigger cancellation
of any ongoing queries; this method blocks until
those queries have been cancelled.
",
fn_name = fn_name
);
query_fn_declarations.extend(quote! {
# [doc = #set_fn_docs]
fn #set_fn_name(&mut self, #(#key_names: #keys,)* value__: #value);
# [doc = #set_constant_fn_docs]
fn #set_with_durability_fn_name(&mut self, #(#key_names: #keys,)* value__: #value, durability__: salsa::Durability);
# [doc = #remove_fn_docs]
fn #remove_fn_name(&mut self, #(#key_names: #keys,)*) -> #value;
});
query_fn_definitions.extend(quote! {
@ -306,6 +324,13 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
}
__shim(self, #(#key_names,)* value__ ,durability__)
}
fn #remove_fn_name(&mut self, #(#key_names: #keys,)*) -> #value {
fn __shim(db: &mut dyn #trait_name, #(#key_names: #keys,)*) -> #value {
salsa::plumbing::get_query_table_mut::<#qt>(db).remove((#(#key_names),*))
}
__shim(self, #(#key_names,)*)
}
});
}

View file

@ -2,7 +2,7 @@ use arc_swap::Guard;
use crate::{
plumbing::{DatabaseOps, QueryFunction},
runtime::{local_state::QueryInputs, StampedValue},
runtime::{StampedValue},
Database, QueryDb,
};

View file

@ -15,7 +15,6 @@ use indexmap::map::Entry;
use log::debug;
use parking_lot::RwLock;
use std::convert::TryFrom;
use std::sync::Arc;
/// Input queries store the result plus a list of the other queries
/// that they invoked. This means we can avoid recomputing them when
@ -25,7 +24,7 @@ where
Q: Query,
{
group_index: u16,
slots: RwLock<FxIndexMap<Q::Key, Arc<Slot<Q>>>>,
slots: RwLock<FxIndexMap<Q::Key, Slot<Q>>>,
}
struct Slot<Q>
@ -45,15 +44,6 @@ where
{
}
impl<Q> InputStorage<Q>
where
Q: Query,
{
fn slot(&self, key: &Q::Key) -> Option<Arc<Slot<Q>>> {
self.slots.read().get(key).cloned()
}
}
impl<Q> QueryStorageOps<Q> for InputStorage<Q>
where
Q: Query,
@ -89,21 +79,17 @@ where
assert_eq!(input.group_index, self.group_index);
assert_eq!(input.query_index, Q::QUERY_INDEX);
debug_assert!(revision < db.salsa_runtime().current_revision());
let slot = self
.slots
.read()
.get_index(input.key_index as usize)
.unwrap()
.1
.clone();
let slots = self.slots.read();
let (_, slot) = slots.get_index(input.key_index as usize).unwrap();
slot.maybe_changed_after(db, revision)
}
fn fetch(&self, db: &<Q as QueryDb<'_>>::DynDb, key: &Q::Key) -> Q::Value {
db.unwind_if_cancelled();
let slot = self
.slot(key)
let slots = self.slots.read();
let slot = slots
.get(key)
.unwrap_or_else(|| panic!("no value set for {:?}({:?})", Q::default(), key));
let StampedValue {
@ -123,7 +109,8 @@ where
}
fn durability(&self, _db: &<Q as QueryDb<'_>>::DynDb, key: &Q::Key) -> Durability {
match self.slot(key) {
let slots = self.slots.read();
match slots.get(key) {
Some(slot) => slot.stamped_value.read().durability,
None => panic!("no value set for {:?}({:?})", Q::default(), key),
}
@ -229,16 +216,29 @@ where
query_index: Q::QUERY_INDEX,
key_index,
};
entry.insert(Arc::new(Slot {
entry.insert(Slot {
key: key.clone(),
database_key_index,
stamped_value: RwLock::new(stamped_value),
}));
});
None
}
}
});
}
fn remove(&self, runtime: &mut Runtime, key: &<Q as Query>::Key) -> <Q as Query>::Value {
let mut value = None;
runtime.with_incremented_revision(&mut |_| {
let mut slots = self.slots.write();
let slot = slots.remove(key)?;
let slot_stamped_value = slot.stamped_value.into_inner();
value = Some(slot_stamped_value.value);
Some(slot_stamped_value.durability)
});
value.unwrap_or_else(|| panic!("no value set for {:?}({:?})", Q::default(), key))
}
}
/// Check that `Slot<Q, MP>: Send + Sync` as long as

View file

@ -569,6 +569,24 @@ where
self.storage.set(self.runtime, &key, value, durability);
}
/// Removes a value from an "input query". Must be used outside of
/// an active query computation.
///
/// If you are using `snapshot`, see the notes on blocking
/// and cancellation on [the `query_mut` method].
///
/// # Panics
/// Panics if the value was not previously set by `set` or
/// `set_with_durability`.
///
/// [the `query_mut` method]: trait.Database.html#method.query_mut
pub fn remove(&mut self, key: Q::Key) -> Q::Value
where
Q::Storage: plumbing::InputQueryStorageOps<Q>,
{
self.storage.remove(self.runtime, &key)
}
/// Sets the size of LRU cache of values for this query table.
///
/// That is, at most `cap` values will be preset in the table at the same

View file

@ -218,6 +218,8 @@ where
Q: Query,
{
fn set(&self, runtime: &mut Runtime, key: &Q::Key, new_value: Q::Value, durability: Durability);
fn remove(&self, runtime: &mut Runtime, key: &Q::Key) -> Q::Value;
}
/// An optional trait that is implemented for "user mutable" storage:

View file

@ -47,7 +47,8 @@ fn revalidate() {
db.assert_log(&["Memoized1 invoked", "Derived1 invoked", "Memoized2 invoked"]);
// Here validation of Memoized1 succeeds so Memoized2 never runs.
db.set_dep_input1(45);
let value = db.remove_dep_input1() + 1;
db.set_dep_input1(value);
let v = db.dep_memoized2();
assert_eq!(v, 44);
db.assert_log(&["Memoized1 invoked", "Derived1 invoked"]);

View file

@ -47,12 +47,22 @@ fn revalidate() {
db.set_input1(64);
db.assert_log(&[]);
let value = db.remove_input1() + 1;
db.set_input1(value);
db.assert_log(&[]);
let value = db.remove_input2() + 1;
db.set_input2(value);
db.assert_log(&[]);
let value = db.remove_input1() + 1;
db.set_input1(value);
db.assert_log(&[]);
let v = db.max();
assert_eq!(v, 66);
assert_eq!(v, 67);
db.assert_log(&["Max invoked"]);
let v = db.max();
assert_eq!(v, 66);
assert_eq!(v, 67);
db.assert_log(&[]);
}

View file

@ -44,7 +44,8 @@ fn in_par_get_set_cancellation() {
signal.wait_for(1);
// This will block until thread1 drops the revision lock.
db.set_input('a', 2);
let value = db.remove_input('a') + 1;
db.set_input('a', value);
db.input('a')
}

View file

@ -19,7 +19,8 @@ fn in_par_get_set_race() {
});
let thread2 = std::thread::spawn(move || {
db.set_input('a', 1000);
let value = db.remove_input('a') * 10;
db.set_input('a', value);
db.sum("a")
});

View file

@ -61,6 +61,7 @@ enum MutatorOp {
#[derive(Debug)]
enum WriteOp {
AddA(usize, isize),
SetA(usize, usize),
}
@ -92,9 +93,13 @@ impl rand::distributions::Distribution<WriteOp> for rand::distributions::Standar
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> WriteOp {
let key = rng.gen::<usize>() % 10;
let value = rng.gen::<usize>() % 10;
if rng.gen_bool(0.5) {
WriteOp::AddA(key, value as isize - 5)
} else {
WriteOp::SetA(key, value)
}
}
}
impl rand::distributions::Distribution<ReadOp> for rand::distributions::Standard {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> ReadOp {
@ -116,6 +121,11 @@ fn db_reader_thread(db: &StressDatabaseImpl, ops: Vec<ReadOp>, check_cancellatio
impl WriteOp {
fn execute(self, db: &mut StressDatabaseImpl) {
match self {
WriteOp::AddA(key, value_delta) => {
let value = db.remove_a(key);
let value = (value as isize + value_delta) as usize;
db.set_a(key, value);
}
WriteOp::SetA(key, value) => {
db.set_a(key, value);
}

View file

@ -33,7 +33,8 @@ fn transparent_queries_work() {
assert_eq!(db.get(1), 10);
assert_eq!(db.get(1), 10);
db.set_input(1, 92);
let value = db.remove_input(1) + 82;
db.set_input(1, value);
assert_eq!(db.get(1), 92);
assert_eq!(db.get(1), 92);
}