Add more tests for tracked methods

This commit is contained in:
Jack Rickard 2022-09-03 15:44:56 +01:00
parent 2df88d2c33
commit bac4c668cf
No known key found for this signature in database
GPG key ID: 88084D7D08A72C8A
4 changed files with 71 additions and 13 deletions

View file

@ -94,6 +94,15 @@ pub(crate) fn tracked_impl(
)) ))
} }
}; };
let self_type_name = &self_type.path.segments.last().unwrap().ident;
let name_prefix = match &item_impl.trait_ {
Some((_, trait_name, _)) => format!(
"{}_{}",
self_type_name,
trait_name.segments.last().unwrap().ident
),
None => format!("{}", self_type_name),
};
let extra_impls = item_impl let extra_impls = item_impl
.items .items
.iter_mut() .iter_mut()
@ -116,10 +125,18 @@ pub(crate) fn tracked_impl(
} else { } else {
Ok(FnArgs::default()) Ok(FnArgs::default())
}; };
Some(match inner_args { let inner_args = match inner_args {
Ok(inner_args) => tracked_method(&args, inner_args, item_method, self_type), Ok(inner_args) => inner_args,
Err(e) => Err(e), Err(err) => return Some(Err(err)),
}) };
let name = format!("{}_{}", name_prefix, item_method.sig.ident);
Some(tracked_method(
&args,
inner_args,
item_method,
self_type,
&name,
))
}) })
// Collate all the errors so we can display them all at once // Collate all the errors so we can display them all at once
.fold(Ok(Vec::new()), |mut acc, res| { .fold(Ok(Vec::new()), |mut acc, res| {
@ -166,6 +183,7 @@ fn tracked_method(
mut args: FnArgs, mut args: FnArgs,
item_method: &mut syn::ImplItemMethod, item_method: &mut syn::ImplItemMethod,
self_type: &syn::TypePath, self_type: &syn::TypePath,
name: &str,
) -> syn::Result<TokenStream> { ) -> syn::Result<TokenStream> {
args.jar_ty = args.jar_ty.or_else(|| outer_args.jar_ty.clone()); args.jar_ty = args.jar_ty.or_else(|| outer_args.jar_ty.clone());
@ -182,14 +200,7 @@ fn tracked_method(
sig: item_method.sig.clone(), sig: item_method.sig.clone(),
block: Box::new(rename_self_in_block(item_method.block.clone())?), block: Box::new(rename_self_in_block(item_method.block.clone())?),
}; };
item_fn.sig.ident = syn::Ident::new( item_fn.sig.ident = syn::Ident::new(name, item_fn.sig.ident.span());
&format!(
"{}_{}",
self_type.path.segments.last().unwrap().ident,
item_fn.sig.ident
),
item_fn.sig.ident.span(),
);
// Flip the first and second arguments as the rest of the code expects the // Flip the first and second arguments as the rest of the code expects the
// database to come first and the struct to come second. We also need to // database to come first and the struct to come second. We also need to
// change the self argument to a normal typed argument called __salsa_self. // change the self argument to a normal typed argument called __salsa_self.

View file

@ -0,0 +1,18 @@
#[salsa::jar(db = Db)]
struct Jar(MyInput, tracked_method_on_untracked_impl);
trait Db: salsa::DbWithJar<Jar> {}
#[salsa::input]
struct MyInput {
field: u32,
}
impl MyInput {
#[salsa::tracked]
fn tracked_method_on_untracked_impl(self, db: &dyn Db) -> u32 {
input.field(db)
}
}
fn main() {}

View file

@ -0,0 +1,11 @@
error: #[salsa::tracked] must also be applied to the impl block for tracked methods
--> tests/compile-fail/tracked_method_on_untracked_impl.rs:13:41
|
13 | fn tracked_method_on_untracked_impl(self, db: &dyn Db) -> u32 {
| ^^^^
error[E0412]: cannot find type `tracked_method_on_untracked_impl` in this scope
--> tests/compile-fail/tracked_method_on_untracked_impl.rs:2:21
|
2 | struct Jar(MyInput, tracked_method_on_untracked_impl);
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ not found in this scope

View file

@ -3,10 +3,19 @@
#![allow(warnings)] #![allow(warnings)]
#[salsa::jar(db = Db)] #[salsa::jar(db = Db)]
struct Jar(MyInput, MyInput_tracked_fn, MyInput_tracked_fn_ref); struct Jar(
MyInput,
MyInput_tracked_fn,
MyInput_tracked_fn_ref,
MyInput_TrackedTrait_tracked_trait_fn,
);
trait Db: salsa::DbWithJar<Jar> {} trait Db: salsa::DbWithJar<Jar> {}
trait TrackedTrait {
fn tracked_trait_fn(self, db: &dyn Db) -> u32;
}
#[salsa::input] #[salsa::input]
struct MyInput { struct MyInput {
field: u32, field: u32,
@ -25,6 +34,14 @@ impl MyInput {
} }
} }
#[salsa::tracked]
impl TrackedTrait for MyInput {
#[salsa::tracked]
fn tracked_trait_fn(self, db: &dyn Db) -> u32 {
self.field(db) * 4
}
}
#[test] #[test]
fn execute() { fn execute() {
#[salsa::db(Jar)] #[salsa::db(Jar)]
@ -41,4 +58,5 @@ fn execute() {
let object = MyInput::new(&mut db, 22); let object = MyInput::new(&mut db, 22);
assert_eq!(object.tracked_fn(&db), 44); assert_eq!(object.tracked_fn(&db), 44);
assert_eq!(*object.tracked_fn_ref(&db), 66); assert_eq!(*object.tracked_fn_ref(&db), 66);
assert_eq!(object.tracked_trait_fn(&db), 88);
} }