From efbf3249ef9aeed6c0668081bf229baee478e10b Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Thu, 18 Jul 2024 06:32:18 -0400 Subject: [PATCH] wip --- components/salsa-macros/src/tracked_impl.rs | 29 ++++++++++++++- tests/tracked_method_inherent_return_ref.rs | 41 +++++++++++++++++++++ 2 files changed, 69 insertions(+), 1 deletion(-) create mode 100644 tests/tracked_method_inherent_return_ref.rs diff --git a/components/salsa-macros/src/tracked_impl.rs b/components/salsa-macros/src/tracked_impl.rs index 2e40a161..5bbbc504 100644 --- a/components/salsa-macros/src/tracked_impl.rs +++ b/components/salsa-macros/src/tracked_impl.rs @@ -6,7 +6,7 @@ use crate::{ db_lifetime, hygiene::Hygiene, options::{AllowedOptions, Options}, - tracked_fn::{check_db_argument, TrackedFn}, + tracked_fn::{check_db_argument, FnArgs, TrackedFn}, }; pub(crate) fn tracked_impl( @@ -64,6 +64,7 @@ impl Macro { }; let salsa_tracked_attr = fn_item.attrs.remove(tracked_attr_index); + let args: FnArgs = salsa_tracked_attr.parse_args()?; let InnerTrait = self.hygiene.ident("InnerTrait"); let inner_fn_name = self.hygiene.ident("inner_fn_name"); @@ -82,6 +83,8 @@ impl Macro { inner_fn.vis = syn::Visibility::Inherited; inner_fn.sig.ident = inner_fn_name.clone(); + // Construct the body of the method + let block = parse_quote!({ salsa::plumbing::setup_method_body! { salsa_tracked_attr: #salsa_tracked_attr, @@ -105,7 +108,12 @@ impl Macro { } }); + // Update the method that will actually appear in the impl to have the new body + // and its true return type + let db_lt = db_lt.cloned(); + self.update_return_type(&mut fn_item.sig, &args, &db_lt)?; fn_item.block = block; + Ok(()) } @@ -259,4 +267,23 @@ impl Macro { Ok((db_ident, db_ty)) } + + fn update_return_type( + &self, + sig: &mut syn::Signature, + args: &FnArgs, + db_lt: &Option, + ) -> syn::Result<()> { + if let Some(return_ref) = &args.return_ref { + if let syn::ReturnType::Type(_, t) = &mut sig.output { + **t = parse_quote!(& #db_lt #t) + } else { + return Err(syn::Error::new_spanned( + return_ref, + "return_ref attribute requires explicit return type", + )); + }; + } + Ok(()) + } } diff --git a/tests/tracked_method_inherent_return_ref.rs b/tests/tracked_method_inherent_return_ref.rs new file mode 100644 index 00000000..e6fa65a8 --- /dev/null +++ b/tests/tracked_method_inherent_return_ref.rs @@ -0,0 +1,41 @@ +use salsa::Database as _; + +#[salsa::input] +struct Input { + number: usize, +} + +#[salsa::tracked] +impl Input { + #[salsa::tracked(return_ref)] + fn test(self, db: &dyn salsa::Database) -> Vec { + (0..self.number(db)) + .map(|i| format!("test {}", i)) + .collect() + } +} + +#[salsa::db] +#[derive(Default)] +struct Database { + storage: salsa::Storage, +} + +#[salsa::db] +impl salsa::Database for Database {} + +#[test] +fn invoke() { + Database::default().attach(|db| { + let input = Input::new(db, 3); + let x: &Vec = input.test(db); + expect_test::expect![[r#" + [ + "test 0", + "test 1", + "test 2", + ] + "#]] + .assert_debug_eq(x); + }) +}