From 57f38aa84b03a449c348f0aaf52e664ee5f4a4a8 Mon Sep 17 00:00:00 2001 From: Shoyu Vanilla Date: Wed, 16 Oct 2024 01:43:09 +0900 Subject: [PATCH] fix: Inconsistent behaviour with lifetime elision on tracked fn --- components/salsa-macros/src/tracked_fn.rs | 33 ++++++++++++------ .../compile-fail/tracked_fn_incompatibles.rs | 34 +++++++++++++++++++ .../tracked_fn_incompatibles.stderr | 29 ++++++++++++++++ tests/tracked_fn_read_own_specify.rs | 2 +- 4 files changed, 86 insertions(+), 12 deletions(-) diff --git a/components/salsa-macros/src/tracked_fn.rs b/components/salsa-macros/src/tracked_fn.rs index 6ebac55..57023ef 100644 --- a/components/salsa-macros/src/tracked_fn.rs +++ b/components/salsa-macros/src/tracked_fn.rs @@ -1,4 +1,5 @@ use proc_macro2::{Literal, Span, TokenStream}; +use quote::ToTokens; use syn::{spanned::Spanned, ItemFn}; use crate::{db_lifetime, fn_util, hygiene::Hygiene, options::Options}; @@ -154,7 +155,8 @@ impl Macro { )); } - let (db_ident, db_path) = check_db_argument(&item.sig.inputs[0])?; + let (db_ident, db_path) = + check_db_argument(&item.sig.inputs[0], item.sig.generics.lifetimes().next())?; Ok(ValidFn { db_ident, db_path }) } @@ -202,6 +204,7 @@ fn function_type(item_fn: &syn::ItemFn) -> FunctionType { pub fn check_db_argument<'arg>( fn_arg: &'arg syn::FnArg, + explicit_lt: Option<&'arg syn::LifetimeParam>, ) -> syn::Result<(&'arg syn::Ident, &'arg syn::Path)> { match fn_arg { syn::FnArg::Receiver(_) => { @@ -256,11 +259,23 @@ pub fn check_db_argument<'arg>( )); } - let extract_db_path = || -> Result<&'arg syn::Path, Span> { - let syn::Type::Reference(ref_type) = &*typed.ty else { - return Err(typed.ty.span()); - }; + let tykind_error_msg = + "must have type `&dyn Db`, where `Db` is some Salsa Database trait"; + let syn::Type::Reference(ref_type) = &*typed.ty else { + return Err(syn::Error::new(typed.ty.span(), tykind_error_msg)); + }; + + if let Some(lt) = explicit_lt { + if ref_type.lifetime.is_none() { + return Err(syn::Error::new_spanned( + ref_type.and_token, + format!("must have a `{}` lifetime", lt.lifetime.to_token_stream()), + )); + } + } + + let extract_db_path = || -> Result<&'arg syn::Path, Span> { if let Some(m) = &ref_type.mutability { return Err(m.span()); } @@ -298,12 +313,8 @@ pub fn check_db_argument<'arg>( Ok(path) }; - let db_path = extract_db_path().map_err(|span| { - syn::Error::new( - span, - "must have type `&dyn Db`, where `Db` is some Salsa Database trait", - ) - })?; + let db_path = + extract_db_path().map_err(|span| syn::Error::new(span, tykind_error_msg))?; Ok((db_ident, db_path)) } diff --git a/tests/compile-fail/tracked_fn_incompatibles.rs b/tests/compile-fail/tracked_fn_incompatibles.rs index 309e4fb..5f587bf 100644 --- a/tests/compile-fail/tracked_fn_incompatibles.rs +++ b/tests/compile-fail/tracked_fn_incompatibles.rs @@ -34,4 +34,38 @@ fn tracked_fn_with_too_many_arguments_for_specify( ) -> u32 { } +#[salsa::interned] +struct MyInterned<'db> { + field: u32, +} + +#[salsa::tracked] +fn tracked_fn_with_lt_param_and_elided_lt_on_db_arg1<'db>( + db: &dyn Db, + interned: MyInterned<'db>, +) -> u32 { + interned.field(db) * 2 +} + +#[salsa::tracked] +fn tracked_fn_with_lt_param_and_elided_lt_on_db_arg2<'db_lifetime>( + db: &dyn Db, + interned: MyInterned<'db_lifetime>, +) -> u32 { + interned.field(db) * 2 +} + +#[salsa::tracked] +fn tracked_fn_with_lt_param_and_elided_lt_on_input<'db>( + db: &'db dyn Db, + interned: MyInterned, +) -> u32 { + interned.field(db) * 2 +} + +#[salsa::tracked] +fn tracked_fn_with_multiple_lts<'db1, 'db2>(db: &'db1 dyn Db, interned: MyInterned<'db2>) -> u32 { + interned.field(db) * 2 +} + fn main() {} diff --git a/tests/compile-fail/tracked_fn_incompatibles.stderr b/tests/compile-fail/tracked_fn_incompatibles.stderr index 5851bf7..882f920 100644 --- a/tests/compile-fail/tracked_fn_incompatibles.stderr +++ b/tests/compile-fail/tracked_fn_incompatibles.stderr @@ -28,6 +28,35 @@ error: only functions with a single salsa struct as their input can be specified 29 | #[salsa::tracked(specify)] | ^^^^^^^ +error: must have a `'db` lifetime + --> tests/compile-fail/tracked_fn_incompatibles.rs:44:9 + | +44 | db: &dyn Db, + | ^ + +error: must have a `'db_lifetime` lifetime + --> tests/compile-fail/tracked_fn_incompatibles.rs:52:9 + | +52 | db: &dyn Db, + | ^ + +error: only a single lifetime parameter is accepted + --> tests/compile-fail/tracked_fn_incompatibles.rs:67:39 + | +67 | fn tracked_fn_with_multiple_lts<'db1, 'db2>(db: &'db1 dyn Db, interned: MyInterned<'db2>) -> u32 { + | ^^^^ + +error[E0106]: missing lifetime specifier + --> tests/compile-fail/tracked_fn_incompatibles.rs:61:15 + | +61 | interned: MyInterned, + | ^^^^^^^^^^ expected named lifetime parameter + | +help: consider using the `'db` lifetime + | +61 | interned: MyInterned<'db>, + | +++++ + error[E0308]: mismatched types --> tests/compile-fail/tracked_fn_incompatibles.rs:24:46 | diff --git a/tests/tracked_fn_read_own_specify.rs b/tests/tracked_fn_read_own_specify.rs index 426d18a..c91bac6 100644 --- a/tests/tracked_fn_read_own_specify.rs +++ b/tests/tracked_fn_read_own_specify.rs @@ -22,7 +22,7 @@ fn tracked_fn(db: &dyn LogDatabase, input: MyInput) -> u32 { } #[salsa::tracked(specify)] -fn tracked_fn_extra<'db>(db: &dyn LogDatabase, input: MyTracked<'db>) -> u32 { +fn tracked_fn_extra<'db>(db: &'db dyn LogDatabase, input: MyTracked<'db>) -> u32 { db.push_log(format!("tracked_fn_extra({input:?})")); 0 }