This commit is contained in:
Niko Matsakis 2024-07-18 06:32:18 -04:00
parent 82872192b4
commit efbf3249ef
2 changed files with 69 additions and 1 deletions

View file

@ -6,7 +6,7 @@ use crate::{
db_lifetime,
hygiene::Hygiene,
options::{AllowedOptions, Options},
tracked_fn::{check_db_argument, TrackedFn},
tracked_fn::{check_db_argument, FnArgs, TrackedFn},
};
pub(crate) fn tracked_impl(
@ -64,6 +64,7 @@ impl Macro {
};
let salsa_tracked_attr = fn_item.attrs.remove(tracked_attr_index);
let args: FnArgs = salsa_tracked_attr.parse_args()?;
let InnerTrait = self.hygiene.ident("InnerTrait");
let inner_fn_name = self.hygiene.ident("inner_fn_name");
@ -82,6 +83,8 @@ impl Macro {
inner_fn.vis = syn::Visibility::Inherited;
inner_fn.sig.ident = inner_fn_name.clone();
// Construct the body of the method
let block = parse_quote!({
salsa::plumbing::setup_method_body! {
salsa_tracked_attr: #salsa_tracked_attr,
@ -105,7 +108,12 @@ impl Macro {
}
});
// Update the method that will actually appear in the impl to have the new body
// and its true return type
let db_lt = db_lt.cloned();
self.update_return_type(&mut fn_item.sig, &args, &db_lt)?;
fn_item.block = block;
Ok(())
}
@ -259,4 +267,23 @@ impl Macro {
Ok((db_ident, db_ty))
}
fn update_return_type(
&self,
sig: &mut syn::Signature,
args: &FnArgs,
db_lt: &Option<syn::Lifetime>,
) -> syn::Result<()> {
if let Some(return_ref) = &args.return_ref {
if let syn::ReturnType::Type(_, t) = &mut sig.output {
**t = parse_quote!(& #db_lt #t)
} else {
return Err(syn::Error::new_spanned(
return_ref,
"return_ref attribute requires explicit return type",
));
};
}
Ok(())
}
}

View file

@ -0,0 +1,41 @@
use salsa::Database as _;
#[salsa::input]
struct Input {
number: usize,
}
#[salsa::tracked]
impl Input {
#[salsa::tracked(return_ref)]
fn test(self, db: &dyn salsa::Database) -> Vec<String> {
(0..self.number(db))
.map(|i| format!("test {}", i))
.collect()
}
}
#[salsa::db]
#[derive(Default)]
struct Database {
storage: salsa::Storage<Self>,
}
#[salsa::db]
impl salsa::Database for Database {}
#[test]
fn invoke() {
Database::default().attach(|db| {
let input = Input::new(db, 3);
let x: &Vec<String> = input.test(db);
expect_test::expect![[r#"
[
"test 0",
"test 1",
"test 2",
]
"#]]
.assert_debug_eq(x);
})
}