add a test for tracked functions

This commit is contained in:
Niko Matsakis 2022-08-05 00:39:00 -04:00
parent b4053ad76b
commit 627eddd428
5 changed files with 83 additions and 2 deletions

View file

@ -19,7 +19,6 @@ parking_lot = "0.12.1"
rustc-hash = "1.0"
smallvec = "1.0.0"
oorandom = "11"
salsa-macros = { version = "0.17.0-pre.2", path = "components/salsa-macros" }
[dev-dependencies]
@ -37,4 +36,5 @@ members = [
"components/salsa-entity-mock",
"components/salsa-entity-macros",
"calc-example/calc",
"salsa-2022-tests"
]

View file

@ -66,9 +66,10 @@ fn key_ty(item_fn: &syn::ItemFn) -> syn::Type {
fn configuration_struct(item_fn: &syn::ItemFn) -> syn::ItemStruct {
let fn_name = item_fn.sig.ident.clone();
let key_tuple_ty = key_ty(item_fn);
let visibility = &item_fn.vis;
parse_quote! {
#[allow(non_camel_case_types)]
pub struct #fn_name {
#visibility struct #fn_name {
intern_map: salsa::interned::InternedIngredient<salsa::Id, #key_tuple_ty>,
function: salsa::function::FunctionIngredient<Self>,
}
@ -187,6 +188,7 @@ fn wrapper_fns(
Ok((getter_fn, setter_impl))
}
/// Creates the `get` associated function.
fn getter_fn(
args: &Args,
item_fn: &syn::ItemFn,
@ -224,6 +226,10 @@ fn getter_fn(
Ok(getter_fn)
}
/// Creates a `get` associated function that returns `&Value`
/// (to be used when `return_ref` is specified).
///
/// (Helper for `getter_fn`)
fn ref_getter_fn(
args: &Args,
item_fn: &syn::ItemFn,
@ -247,6 +253,8 @@ fn ref_getter_fn(
Ok(ref_getter_fn)
}
/// Creates a `set` associated function that can be used to set (given an `&mut db`)
/// the value for this function for some inputs.
fn setter_fn(
args: &Args,
item_fn: &syn::ItemFn,
@ -288,6 +296,9 @@ fn setter_fn(
})
}
/// Given a function def tagged with `#[return_ref]`, modifies `ref_getter_fn`
/// so that it returns an `&Value` instead of `Value`. May introduce a name for the
/// database lifetime if required.
fn make_fn_return_ref(mut ref_getter_fn: syn::ItemFn) -> syn::Result<syn::ItemFn> {
// The 0th input should be a `&dyn Foo`. We need to ensure
// it has a named lifetime parameter.
@ -313,6 +324,9 @@ fn make_fn_return_ref(mut ref_getter_fn: syn::ItemFn) -> syn::Result<syn::ItemFn
Ok(ref_getter_fn)
}
/// Given an item function, identifies the name given to the `&dyn Db` reference and returns it,
/// along with the type of the database. If the database lifetime did not have a name,
/// then modifies the item function so that it is called `'__db` and returns that.
fn db_lifetime_and_ty(func: &mut syn::ItemFn) -> syn::Result<(syn::Lifetime, &syn::Type)> {
match &mut func.sig.inputs[0] {
syn::FnArg::Receiver(r) => {
@ -355,6 +369,9 @@ fn db_lifetime_and_ty(func: &mut syn::ItemFn) -> syn::Result<(syn::Lifetime, &sy
}
}
/// Generates the `accumulated` function, which invokes `accumulated`
/// on the function ingredient to extract the values pushed (transitively)
/// into an accumulator.
fn accumulated_fn(
args: &Args,
item_fn: &syn::ItemFn,
@ -393,6 +410,10 @@ fn accumulated_fn(
Ok(accumulated_fn)
}
/// Examines the function arguments and returns a tuple of:
///
/// * the name of the database argument
/// * the name(s) of the key arguments
fn fn_args(item_fn: &syn::ItemFn) -> syn::Result<(proc_macro2::Ident, Vec<proc_macro2::Ident>)> {
// Check that we have no receiver and that all argments have names
if item_fn.sig.inputs.len() == 0 {

View file

@ -0,0 +1,16 @@
[package]
name = "salsa-2022-tests"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
salsa = { path = "../components/salsa-entity-mock", package = "salsa-entity-mock" }
[dev-dependencies]
expect-test = "1.4.0"
[[bin]]
name = "salsa-2022-tests"
path = "main.rs"

6
salsa-2022-tests/main.rs Normal file
View file

@ -0,0 +1,6 @@
//! This crate has the beginning of various unit tests on salsa 2022
//! code.
mod tracked_fn_on_input;
fn main() {}

View file

@ -0,0 +1,38 @@
//! Test that a `tracked` fn on a `salsa::input`
//! compiles and executes successfully.
#[salsa::jar(db = Db)]
struct Jar(MyInput, tracked_fn);
trait Db: salsa::DbWithJar<Jar> {}
#[salsa::input(jar = Jar)]
struct MyInput {
field: u32,
}
#[salsa::tracked(jar = Jar)]
fn tracked_fn(db: &dyn Db, input: MyInput) -> u32 {
input.field(db) * 2
}
#[test]
fn execute() {
#[salsa::db(Jar)]
#[derive(Default)]
struct Database {
storage: salsa::Storage<Self>,
}
impl salsa::Database for Database {
fn salsa_runtime(&self) -> &salsa::Runtime {
self.storage.runtime()
}
}
impl Db for Database {}
let mut db = Database::default();
let input = MyInput::new(&mut db, 22);
assert_eq!(tracked_fn(&db, input), 44);
}