diff --git a/components/salsa-2022-macros/src/accumulator.rs b/components/salsa-2022-macros/src/accumulator.rs index cfc32c9e..26276398 100644 --- a/components/salsa-2022-macros/src/accumulator.rs +++ b/components/salsa-2022-macros/src/accumulator.rs @@ -32,6 +32,8 @@ impl crate::options::AllowedOptions for Accumulator { const DB: bool = false; const RECOVERY_FN: bool = false; + + const LRU: bool = false; } fn accumulator_contents( diff --git a/components/salsa-2022-macros/src/jar.rs b/components/salsa-2022-macros/src/jar.rs index 7a6498d4..5c582abd 100644 --- a/components/salsa-2022-macros/src/jar.rs +++ b/components/salsa-2022-macros/src/jar.rs @@ -41,6 +41,8 @@ impl crate::options::AllowedOptions for Jar { const DB: bool = true; const RECOVERY_FN: bool = false; + + const LRU: bool = false; } pub(crate) fn jar_struct_and_friends( diff --git a/components/salsa-2022-macros/src/options.rs b/components/salsa-2022-macros/src/options.rs index eb8bd937..185bde14 100644 --- a/components/salsa-2022-macros/src/options.rs +++ b/components/salsa-2022-macros/src/options.rs @@ -46,6 +46,11 @@ pub(crate) struct Options { /// If this is `Some`, the value is the ``. pub data: Option, + /// The `lru = ` option is used to set the lru capacity for a tracked function. + /// + /// If this is `Some`, the value is the ``. + pub lru: Option, + /// Remember the `A` parameter, which plays no role after parsing. phantom: PhantomData, } @@ -61,6 +66,7 @@ impl Default for Options { recovery_fn: Default::default(), data: Default::default(), phantom: Default::default(), + lru: Default::default(), } } } @@ -74,6 +80,7 @@ pub(crate) trait AllowedOptions { const DATA: bool; const DB: bool; const RECOVERY_FN: bool; + const LRU: bool; } type Equals = syn::Token![=]; @@ -195,6 +202,20 @@ impl syn::parse::Parse for Options { "`data` option not allowed here", )); } + } else if ident == "lru" { + if A::LRU { + let _eq = Equals::parse(input)?; + let lit = syn::LitInt::parse(input)?; + let value = lit.base10_parse::()?; + if let Some(old) = std::mem::replace(&mut options.lru, Some(value)) { + return Err(syn::Error::new(old.span(), "option `lru` provided twice")); + } + } else { + return Err(syn::Error::new( + ident.span(), + "`lru` option not allowed here", + )); + } } else { return Err(syn::Error::new( ident.span(), diff --git a/components/salsa-2022-macros/src/salsa_struct.rs b/components/salsa-2022-macros/src/salsa_struct.rs index beb7f0d2..36b01df9 100644 --- a/components/salsa-2022-macros/src/salsa_struct.rs +++ b/components/salsa-2022-macros/src/salsa_struct.rs @@ -50,6 +50,8 @@ impl crate::options::AllowedOptions for SalsaStruct { const DB: bool = false; const RECOVERY_FN: bool = false; + + const LRU: bool = false; } const BANNED_FIELD_NAMES: &[&str] = &["from", "new"]; diff --git a/components/salsa-2022-macros/src/tracked_fn.rs b/components/salsa-2022-macros/src/tracked_fn.rs index 8c0884e1..4ce59b3d 100644 --- a/components/salsa-2022-macros/src/tracked_fn.rs +++ b/components/salsa-2022-macros/src/tracked_fn.rs @@ -31,6 +31,13 @@ fn tracked_fn(args: Args, item_fn: syn::ItemFn) -> syn::Result { "tracked functon takes too many argments to have its value set with `specify`", )); } + + if args.lru.is_some() { + return Err(syn::Error::new( + s.span(), + "`specify` and `lru` cannot be used together", + )); + } } let struct_item = configuration_struct(&item_fn); @@ -72,6 +79,8 @@ impl crate::options::AllowedOptions for TrackedFn { const DB: bool = false; const RECOVERY_FN: bool = true; + + const LRU: bool = true; } /// Returns the key type for this tracked function. @@ -228,6 +237,9 @@ fn ingredients_for_impl( } }; + // set 0 as default to disable LRU + let lru = args.lru.unwrap_or(0); + parse_quote! { impl salsa::storage::IngredientsFor for #config_ty { type Ingredients = Self; @@ -254,8 +266,10 @@ fn ingredients_for_impl( <_ as salsa::storage::HasIngredientsFor>::ingredient_mut(jar); &mut ingredients.function }); - salsa::function::FunctionIngredient::new(index) - }, + let ingredient = salsa::function::FunctionIngredient::new(index); + ingredient.set_capacity(#lru); + ingredient + } } } } diff --git a/salsa-2022-tests/tests/lru.rs b/salsa-2022-tests/tests/lru.rs new file mode 100644 index 00000000..21d2ecb4 --- /dev/null +++ b/salsa-2022-tests/tests/lru.rs @@ -0,0 +1,105 @@ +//! Test that a `tracked` fn with lru options +//! compiles and executes successfully. + +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; + +use test_log::test; + +#[salsa::jar(db = Db)] +struct Jar(MyInput, get_hot_potato, get_volatile); + +trait Db: salsa::DbWithJar {} + +#[derive(Debug, PartialEq, Eq)] +struct HotPotato(u32); + +thread_local! { + static N_POTATOES: AtomicUsize = AtomicUsize::new(0) +} + +impl HotPotato { + fn new(id: u32) -> HotPotato { + N_POTATOES.with(|n| n.fetch_add(1, Ordering::SeqCst)); + HotPotato(id) + } +} + +impl Drop for HotPotato { + fn drop(&mut self) { + N_POTATOES.with(|n| n.fetch_sub(1, Ordering::SeqCst)); + } +} + +#[salsa::input(jar = Jar)] +struct MyInput { + field: u32, +} + +#[salsa::tracked(jar = Jar, lru = 32)] +fn get_hot_potato(db: &dyn Db, input: MyInput) -> Arc { + Arc::new(HotPotato::new(input.field(db))) +} + +#[salsa::tracked(jar = Jar, lru = 32)] +fn get_volatile(db: &dyn Db, _input: MyInput) -> usize { + static COUNTER: AtomicUsize = AtomicUsize::new(0); + db.salsa_runtime().report_untracked_read(); + COUNTER.fetch_add(1, Ordering::SeqCst) +} + +#[salsa::db(Jar)] +#[derive(Default)] +struct Database { + storage: salsa::Storage, +} + +impl salsa::Database for Database { + fn salsa_runtime(&self) -> &salsa::Runtime { + self.storage.runtime() + } +} + +impl Db for Database {} + +fn load_n_potatoes() -> usize { + N_POTATOES.with(|n| n.load(Ordering::SeqCst)) +} + +#[test] +fn lru_works() { + let mut db = Database::default(); + assert_eq!(load_n_potatoes(), 0); + + for i in 0..128u32 { + let input = MyInput::new(&mut db, i); + let p = get_hot_potato(&db, input); + assert_eq!(p.0, i) + } + + // Create a new input to change the revision, and trigger the GC + MyInput::new(&mut db, 0); + assert_eq!(load_n_potatoes(), 32); +} + +#[test] +fn lru_doesnt_break_volatile_queries() { + let mut db = Database::default(); + + // Create all inputs first, so that there are no revision changes among calls to `get_volatile` + let inputs: Vec = (0..128usize) + .map(|i| MyInput::new(&mut db, i as u32)) + .collect(); + + // Here, we check that we execute each volatile query at most once, despite + // LRU. That does mean that we have more values in DB than the LRU capacity, + // but it's much better than inconsistent results from volatile queries! + for _ in 0..3 { + for (i, input) in inputs.iter().enumerate() { + let x = get_volatile(&db, *input); + assert_eq!(x, i); + } + } +} diff --git a/salsa-2022-tests/tests/mutate_in_place.rs b/salsa-2022-tests/tests/mutate_in_place.rs index 4d6c1a5c..83166c96 100644 --- a/salsa-2022-tests/tests/mutate_in_place.rs +++ b/salsa-2022-tests/tests/mutate_in_place.rs @@ -3,7 +3,6 @@ use salsa_2022_tests::{HasLogger, Logger}; -use expect_test::expect; use test_log::test; #[salsa::jar(db = Db)]