352: Add options to tracked funcitons for lru capacity  r=nikomatsakis a=XFFXFF

fixes #344 

Now we can write something like the following to set the lru capacity of tracked functions  
```rust
#[salsa::tracked(lru=32)]
fn my_tracked_fn(db: &dyn crate::Db, ...) { }
```

some details:  
* lru should not be combined with specify. We will report an error if people do #[salsa::tracked(lru = 32, specify)]
* set 0 as default capacity to disable LRU (Because I think doing this would make the code simpler when implementing `create_ingredients` of tracked functions).
* old salsa support to change lru capacity at runtime, [as noted here](https://salsa-rs.github.io/salsa/rfcs/RFC0004-LRU.html?highlight=change#reference-guide), but we do not support this now

Co-authored-by: XFFXFF <1247714429@qq.com>
This commit is contained in:
bors[bot] 2022-08-18 10:37:38 +00:00 committed by GitHub
commit eca8bad6e9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 148 additions and 3 deletions

View file

@ -32,6 +32,8 @@ impl crate::options::AllowedOptions for Accumulator {
const DB: bool = false;
const RECOVERY_FN: bool = false;
const LRU: bool = false;
}
fn accumulator_contents(

View file

@ -41,6 +41,8 @@ impl crate::options::AllowedOptions for Jar {
const DB: bool = true;
const RECOVERY_FN: bool = false;
const LRU: bool = false;
}
pub(crate) fn jar_struct_and_friends(

View file

@ -46,6 +46,11 @@ pub(crate) struct Options<A: AllowedOptions> {
/// If this is `Some`, the value is the `<ident>`.
pub data: Option<syn::Ident>,
/// The `lru = <usize>` option is used to set the lru capacity for a tracked function.
///
/// If this is `Some`, the value is the `<usize>`.
pub lru: Option<usize>,
/// Remember the `A` parameter, which plays no role after parsing.
phantom: PhantomData<A>,
}
@ -61,6 +66,7 @@ impl<A: AllowedOptions> Default for Options<A> {
recovery_fn: Default::default(),
data: Default::default(),
phantom: Default::default(),
lru: Default::default(),
}
}
}
@ -74,6 +80,7 @@ pub(crate) trait AllowedOptions {
const DATA: bool;
const DB: bool;
const RECOVERY_FN: bool;
const LRU: bool;
}
type Equals = syn::Token![=];
@ -195,6 +202,20 @@ impl<A: AllowedOptions> syn::parse::Parse for Options<A> {
"`data` option not allowed here",
));
}
} else if ident == "lru" {
if A::LRU {
let _eq = Equals::parse(input)?;
let lit = syn::LitInt::parse(input)?;
let value = lit.base10_parse::<usize>()?;
if let Some(old) = std::mem::replace(&mut options.lru, Some(value)) {
return Err(syn::Error::new(old.span(), "option `lru` provided twice"));
}
} else {
return Err(syn::Error::new(
ident.span(),
"`lru` option not allowed here",
));
}
} else {
return Err(syn::Error::new(
ident.span(),

View file

@ -50,6 +50,8 @@ impl crate::options::AllowedOptions for SalsaStruct {
const DB: bool = false;
const RECOVERY_FN: bool = false;
const LRU: bool = false;
}
const BANNED_FIELD_NAMES: &[&str] = &["from", "new"];

View file

@ -31,6 +31,13 @@ fn tracked_fn(args: Args, item_fn: syn::ItemFn) -> syn::Result<TokenStream> {
"tracked functon takes too many argments to have its value set with `specify`",
));
}
if args.lru.is_some() {
return Err(syn::Error::new(
s.span(),
"`specify` and `lru` cannot be used together",
));
}
}
let struct_item = configuration_struct(&item_fn);
@ -72,6 +79,8 @@ impl crate::options::AllowedOptions for TrackedFn {
const DB: bool = false;
const RECOVERY_FN: bool = true;
const LRU: bool = true;
}
/// Returns the key type for this tracked function.
@ -228,6 +237,9 @@ fn ingredients_for_impl(
}
};
// set 0 as default to disable LRU
let lru = args.lru.unwrap_or(0);
parse_quote! {
impl salsa::storage::IngredientsFor for #config_ty {
type Ingredients = Self;
@ -254,8 +266,10 @@ fn ingredients_for_impl(
<_ as salsa::storage::HasIngredientsFor<Self::Ingredients>>::ingredient_mut(jar);
&mut ingredients.function
});
salsa::function::FunctionIngredient::new(index)
},
let ingredient = salsa::function::FunctionIngredient::new(index);
ingredient.set_capacity(#lru);
ingredient
}
}
}
}

View file

@ -0,0 +1,105 @@
//! Test that a `tracked` fn with lru options
//! compiles and executes successfully.
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
use test_log::test;
#[salsa::jar(db = Db)]
struct Jar(MyInput, get_hot_potato, get_volatile);
trait Db: salsa::DbWithJar<Jar> {}
#[derive(Debug, PartialEq, Eq)]
struct HotPotato(u32);
thread_local! {
static N_POTATOES: AtomicUsize = AtomicUsize::new(0)
}
impl HotPotato {
fn new(id: u32) -> HotPotato {
N_POTATOES.with(|n| n.fetch_add(1, Ordering::SeqCst));
HotPotato(id)
}
}
impl Drop for HotPotato {
fn drop(&mut self) {
N_POTATOES.with(|n| n.fetch_sub(1, Ordering::SeqCst));
}
}
#[salsa::input(jar = Jar)]
struct MyInput {
field: u32,
}
#[salsa::tracked(jar = Jar, lru = 32)]
fn get_hot_potato(db: &dyn Db, input: MyInput) -> Arc<HotPotato> {
Arc::new(HotPotato::new(input.field(db)))
}
#[salsa::tracked(jar = Jar, lru = 32)]
fn get_volatile(db: &dyn Db, _input: MyInput) -> usize {
static COUNTER: AtomicUsize = AtomicUsize::new(0);
db.salsa_runtime().report_untracked_read();
COUNTER.fetch_add(1, Ordering::SeqCst)
}
#[salsa::db(Jar)]
#[derive(Default)]
struct Database {
storage: salsa::Storage<Self>,
}
impl salsa::Database for Database {
fn salsa_runtime(&self) -> &salsa::Runtime {
self.storage.runtime()
}
}
impl Db for Database {}
fn load_n_potatoes() -> usize {
N_POTATOES.with(|n| n.load(Ordering::SeqCst))
}
#[test]
fn lru_works() {
let mut db = Database::default();
assert_eq!(load_n_potatoes(), 0);
for i in 0..128u32 {
let input = MyInput::new(&mut db, i);
let p = get_hot_potato(&db, input);
assert_eq!(p.0, i)
}
// Create a new input to change the revision, and trigger the GC
MyInput::new(&mut db, 0);
assert_eq!(load_n_potatoes(), 32);
}
#[test]
fn lru_doesnt_break_volatile_queries() {
let mut db = Database::default();
// Create all inputs first, so that there are no revision changes among calls to `get_volatile`
let inputs: Vec<MyInput> = (0..128usize)
.map(|i| MyInput::new(&mut db, i as u32))
.collect();
// Here, we check that we execute each volatile query at most once, despite
// LRU. That does mean that we have more values in DB than the LRU capacity,
// but it's much better than inconsistent results from volatile queries!
for _ in 0..3 {
for (i, input) in inputs.iter().enumerate() {
let x = get_volatile(&db, *input);
assert_eq!(x, i);
}
}
}

View file

@ -3,7 +3,6 @@
use salsa_2022_tests::{HasLogger, Logger};
use expect_test::expect;
use test_log::test;
#[salsa::jar(db = Db)]