mirror of
https://github.com/salsa-rs/salsa.git
synced 2025-01-22 21:05:11 +00:00
Expose the ability to remove the value from an input query, taking ownership of it
Co-authored-by: Tim Robinson <tim.g.robinson@gmail.com>
This commit is contained in:
parent
887f24c06d
commit
1e3c2f22aa
11 changed files with 101 additions and 32 deletions
|
@ -253,6 +253,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
|
|||
if let QueryStorage::Input = query.storage {
|
||||
let set_fn_name = format_ident!("set_{}", fn_name);
|
||||
let set_with_durability_fn_name = format_ident!("set_{}_with_durability", fn_name);
|
||||
let remove_fn_name = format_ident!("remove_{}", fn_name);
|
||||
|
||||
let set_fn_docs = format!(
|
||||
"
|
||||
|
@ -283,13 +284,30 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
|
|||
fn_name = fn_name
|
||||
);
|
||||
|
||||
let remove_fn_docs = format!(
|
||||
"
|
||||
Remove the value from the `{fn_name}` input.
|
||||
|
||||
See `{fn_name}` for details. Panics if a value has
|
||||
not previously been set using `set_{fn_name}` or
|
||||
`set_{fn_name}_with_durability`.
|
||||
|
||||
*Note:* Setting values will trigger cancellation
|
||||
of any ongoing queries; this method blocks until
|
||||
those queries have been cancelled.
|
||||
",
|
||||
fn_name = fn_name
|
||||
);
|
||||
|
||||
query_fn_declarations.extend(quote! {
|
||||
# [doc = #set_fn_docs]
|
||||
fn #set_fn_name(&mut self, #(#key_names: #keys,)* value__: #value);
|
||||
|
||||
|
||||
# [doc = #set_constant_fn_docs]
|
||||
fn #set_with_durability_fn_name(&mut self, #(#key_names: #keys,)* value__: #value, durability__: salsa::Durability);
|
||||
|
||||
# [doc = #remove_fn_docs]
|
||||
fn #remove_fn_name(&mut self, #(#key_names: #keys,)*) -> #value;
|
||||
});
|
||||
|
||||
query_fn_definitions.extend(quote! {
|
||||
|
@ -306,6 +324,13 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
|
|||
}
|
||||
__shim(self, #(#key_names,)* value__ ,durability__)
|
||||
}
|
||||
|
||||
fn #remove_fn_name(&mut self, #(#key_names: #keys,)*) -> #value {
|
||||
fn __shim(db: &mut dyn #trait_name, #(#key_names: #keys,)*) -> #value {
|
||||
salsa::plumbing::get_query_table_mut::<#qt>(db).remove((#(#key_names),*))
|
||||
}
|
||||
__shim(self, #(#key_names,)*)
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ use arc_swap::Guard;
|
|||
|
||||
use crate::{
|
||||
plumbing::{DatabaseOps, QueryFunction},
|
||||
runtime::{local_state::QueryInputs, StampedValue},
|
||||
runtime::{StampedValue},
|
||||
Database, QueryDb,
|
||||
};
|
||||
|
||||
|
|
46
src/input.rs
46
src/input.rs
|
@ -15,7 +15,6 @@ use indexmap::map::Entry;
|
|||
use log::debug;
|
||||
use parking_lot::RwLock;
|
||||
use std::convert::TryFrom;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Input queries store the result plus a list of the other queries
|
||||
/// that they invoked. This means we can avoid recomputing them when
|
||||
|
@ -25,7 +24,7 @@ where
|
|||
Q: Query,
|
||||
{
|
||||
group_index: u16,
|
||||
slots: RwLock<FxIndexMap<Q::Key, Arc<Slot<Q>>>>,
|
||||
slots: RwLock<FxIndexMap<Q::Key, Slot<Q>>>,
|
||||
}
|
||||
|
||||
struct Slot<Q>
|
||||
|
@ -45,15 +44,6 @@ where
|
|||
{
|
||||
}
|
||||
|
||||
impl<Q> InputStorage<Q>
|
||||
where
|
||||
Q: Query,
|
||||
{
|
||||
fn slot(&self, key: &Q::Key) -> Option<Arc<Slot<Q>>> {
|
||||
self.slots.read().get(key).cloned()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Q> QueryStorageOps<Q> for InputStorage<Q>
|
||||
where
|
||||
Q: Query,
|
||||
|
@ -89,21 +79,17 @@ where
|
|||
assert_eq!(input.group_index, self.group_index);
|
||||
assert_eq!(input.query_index, Q::QUERY_INDEX);
|
||||
debug_assert!(revision < db.salsa_runtime().current_revision());
|
||||
let slot = self
|
||||
.slots
|
||||
.read()
|
||||
.get_index(input.key_index as usize)
|
||||
.unwrap()
|
||||
.1
|
||||
.clone();
|
||||
let slots = self.slots.read();
|
||||
let (_, slot) = slots.get_index(input.key_index as usize).unwrap();
|
||||
slot.maybe_changed_after(db, revision)
|
||||
}
|
||||
|
||||
fn fetch(&self, db: &<Q as QueryDb<'_>>::DynDb, key: &Q::Key) -> Q::Value {
|
||||
db.unwind_if_cancelled();
|
||||
|
||||
let slot = self
|
||||
.slot(key)
|
||||
let slots = self.slots.read();
|
||||
let slot = slots
|
||||
.get(key)
|
||||
.unwrap_or_else(|| panic!("no value set for {:?}({:?})", Q::default(), key));
|
||||
|
||||
let StampedValue {
|
||||
|
@ -123,7 +109,8 @@ where
|
|||
}
|
||||
|
||||
fn durability(&self, _db: &<Q as QueryDb<'_>>::DynDb, key: &Q::Key) -> Durability {
|
||||
match self.slot(key) {
|
||||
let slots = self.slots.read();
|
||||
match slots.get(key) {
|
||||
Some(slot) => slot.stamped_value.read().durability,
|
||||
None => panic!("no value set for {:?}({:?})", Q::default(), key),
|
||||
}
|
||||
|
@ -229,16 +216,29 @@ where
|
|||
query_index: Q::QUERY_INDEX,
|
||||
key_index,
|
||||
};
|
||||
entry.insert(Arc::new(Slot {
|
||||
entry.insert(Slot {
|
||||
key: key.clone(),
|
||||
database_key_index,
|
||||
stamped_value: RwLock::new(stamped_value),
|
||||
}));
|
||||
});
|
||||
None
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn remove(&self, runtime: &mut Runtime, key: &<Q as Query>::Key) -> <Q as Query>::Value {
|
||||
let mut value = None;
|
||||
runtime.with_incremented_revision(&mut |_| {
|
||||
let mut slots = self.slots.write();
|
||||
let slot = slots.remove(key)?;
|
||||
let slot_stamped_value = slot.stamped_value.into_inner();
|
||||
value = Some(slot_stamped_value.value);
|
||||
Some(slot_stamped_value.durability)
|
||||
});
|
||||
|
||||
value.unwrap_or_else(|| panic!("no value set for {:?}({:?})", Q::default(), key))
|
||||
}
|
||||
}
|
||||
|
||||
/// Check that `Slot<Q, MP>: Send + Sync` as long as
|
||||
|
|
18
src/lib.rs
18
src/lib.rs
|
@ -569,6 +569,24 @@ where
|
|||
self.storage.set(self.runtime, &key, value, durability);
|
||||
}
|
||||
|
||||
/// Removes a value from an "input query". Must be used outside of
|
||||
/// an active query computation.
|
||||
///
|
||||
/// If you are using `snapshot`, see the notes on blocking
|
||||
/// and cancellation on [the `query_mut` method].
|
||||
///
|
||||
/// # Panics
|
||||
/// Panics if the value was not previously set by `set` or
|
||||
/// `set_with_durability`.
|
||||
///
|
||||
/// [the `query_mut` method]: trait.Database.html#method.query_mut
|
||||
pub fn remove(&mut self, key: Q::Key) -> Q::Value
|
||||
where
|
||||
Q::Storage: plumbing::InputQueryStorageOps<Q>,
|
||||
{
|
||||
self.storage.remove(self.runtime, &key)
|
||||
}
|
||||
|
||||
/// Sets the size of LRU cache of values for this query table.
|
||||
///
|
||||
/// That is, at most `cap` values will be preset in the table at the same
|
||||
|
|
|
@ -218,6 +218,8 @@ where
|
|||
Q: Query,
|
||||
{
|
||||
fn set(&self, runtime: &mut Runtime, key: &Q::Key, new_value: Q::Value, durability: Durability);
|
||||
|
||||
fn remove(&self, runtime: &mut Runtime, key: &Q::Key) -> Q::Value;
|
||||
}
|
||||
|
||||
/// An optional trait that is implemented for "user mutable" storage:
|
||||
|
|
|
@ -47,7 +47,8 @@ fn revalidate() {
|
|||
db.assert_log(&["Memoized1 invoked", "Derived1 invoked", "Memoized2 invoked"]);
|
||||
|
||||
// Here validation of Memoized1 succeeds so Memoized2 never runs.
|
||||
db.set_dep_input1(45);
|
||||
let value = db.remove_dep_input1() + 1;
|
||||
db.set_dep_input1(value);
|
||||
let v = db.dep_memoized2();
|
||||
assert_eq!(v, 44);
|
||||
db.assert_log(&["Memoized1 invoked", "Derived1 invoked"]);
|
||||
|
|
|
@ -47,12 +47,22 @@ fn revalidate() {
|
|||
db.set_input1(64);
|
||||
db.assert_log(&[]);
|
||||
|
||||
let value = db.remove_input1() + 1;
|
||||
db.set_input1(value);
|
||||
db.assert_log(&[]);
|
||||
let value = db.remove_input2() + 1;
|
||||
db.set_input2(value);
|
||||
db.assert_log(&[]);
|
||||
let value = db.remove_input1() + 1;
|
||||
db.set_input1(value);
|
||||
db.assert_log(&[]);
|
||||
|
||||
let v = db.max();
|
||||
assert_eq!(v, 66);
|
||||
assert_eq!(v, 67);
|
||||
db.assert_log(&["Max invoked"]);
|
||||
|
||||
let v = db.max();
|
||||
assert_eq!(v, 66);
|
||||
assert_eq!(v, 67);
|
||||
db.assert_log(&[]);
|
||||
}
|
||||
|
||||
|
|
|
@ -44,7 +44,8 @@ fn in_par_get_set_cancellation() {
|
|||
signal.wait_for(1);
|
||||
|
||||
// This will block until thread1 drops the revision lock.
|
||||
db.set_input('a', 2);
|
||||
let value = db.remove_input('a') + 1;
|
||||
db.set_input('a', value);
|
||||
|
||||
db.input('a')
|
||||
}
|
||||
|
|
|
@ -19,7 +19,8 @@ fn in_par_get_set_race() {
|
|||
});
|
||||
|
||||
let thread2 = std::thread::spawn(move || {
|
||||
db.set_input('a', 1000);
|
||||
let value = db.remove_input('a') * 10;
|
||||
db.set_input('a', value);
|
||||
db.sum("a")
|
||||
});
|
||||
|
||||
|
|
|
@ -61,6 +61,7 @@ enum MutatorOp {
|
|||
|
||||
#[derive(Debug)]
|
||||
enum WriteOp {
|
||||
AddA(usize, isize),
|
||||
SetA(usize, usize),
|
||||
}
|
||||
|
||||
|
@ -92,7 +93,11 @@ impl rand::distributions::Distribution<WriteOp> for rand::distributions::Standar
|
|||
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> WriteOp {
|
||||
let key = rng.gen::<usize>() % 10;
|
||||
let value = rng.gen::<usize>() % 10;
|
||||
WriteOp::SetA(key, value)
|
||||
if rng.gen_bool(0.5) {
|
||||
WriteOp::AddA(key, value as isize - 5)
|
||||
} else {
|
||||
WriteOp::SetA(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -116,6 +121,11 @@ fn db_reader_thread(db: &StressDatabaseImpl, ops: Vec<ReadOp>, check_cancellatio
|
|||
impl WriteOp {
|
||||
fn execute(self, db: &mut StressDatabaseImpl) {
|
||||
match self {
|
||||
WriteOp::AddA(key, value_delta) => {
|
||||
let value = db.remove_a(key);
|
||||
let value = (value as isize + value_delta) as usize;
|
||||
db.set_a(key, value);
|
||||
}
|
||||
WriteOp::SetA(key, value) => {
|
||||
db.set_a(key, value);
|
||||
}
|
||||
|
|
|
@ -33,7 +33,8 @@ fn transparent_queries_work() {
|
|||
assert_eq!(db.get(1), 10);
|
||||
assert_eq!(db.get(1), 10);
|
||||
|
||||
db.set_input(1, 92);
|
||||
let value = db.remove_input(1) + 82;
|
||||
db.set_input(1, value);
|
||||
assert_eq!(db.get(1), 92);
|
||||
assert_eq!(db.get(1), 92);
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue