diff --git a/components/salsa-macros/src/tracked_impl.rs b/components/salsa-macros/src/tracked_impl.rs index 073f43f..4160ceb 100644 --- a/components/salsa-macros/src/tracked_impl.rs +++ b/components/salsa-macros/src/tracked_impl.rs @@ -1,8 +1,10 @@ +use std::collections::HashSet; + use proc_macro2::TokenStream; use quote::ToTokens; -use syn::parse::Nothing; +use syn::{parse::Nothing, visit_mut::VisitMut}; -use crate::{hygiene::Hygiene, tracked_fn::FnArgs}; +use crate::{hygiene::Hygiene, tracked_fn::FnArgs, xform::ChangeSelfPath}; pub(crate) fn tracked_impl( args: proc_macro::TokenStream, @@ -32,8 +34,19 @@ struct MethodArguments<'syn> { impl Macro { fn try_generate(&self, mut impl_item: syn::ItemImpl) -> syn::Result { let mut member_items = std::mem::take(&mut impl_item.items); + let member_idents: HashSet<_> = member_items + .iter() + .filter_map(|item| match item { + syn::ImplItem::Const(it) => Some(it.ident.clone()), + syn::ImplItem::Fn(it) => Some(it.sig.ident.clone()), + syn::ImplItem::Type(it) => Some(it.ident.clone()), + syn::ImplItem::Macro(_) => None, + syn::ImplItem::Verbatim(_) => None, + _ => None, + }) + .collect(); for member_item in &mut member_items { - self.modify_member(&impl_item, member_item)?; + self.modify_member(&impl_item, member_item, &member_idents)?; } impl_item.items = member_items; Ok(crate::debug::dump_tokens( @@ -47,6 +60,7 @@ impl Macro { &self, impl_item: &syn::ItemImpl, member_item: &mut syn::ImplItem, + member_idents: &HashSet, ) -> syn::Result<()> { let syn::ImplItem::Fn(fn_item) = member_item else { return Ok(()); @@ -59,6 +73,13 @@ impl Macro { return Ok(()); }; + let trait_ = match &impl_item.trait_ { + Some((None, path, _)) => Some((path, member_idents)), + _ => None, + }; + let mut change = ChangeSelfPath::new(self_ty, trait_); + change.visit_impl_item_fn_mut(fn_item); + let salsa_tracked_attr = fn_item.attrs.remove(tracked_attr_index); let args: FnArgs = match &salsa_tracked_attr.meta { syn::Meta::Path(..) => Default::default(), diff --git a/components/salsa-macros/src/xform.rs b/components/salsa-macros/src/xform.rs index 4c719cb..3ab6bc6 100644 --- a/components/salsa-macros/src/xform.rs +++ b/components/salsa-macros/src/xform.rs @@ -1,4 +1,7 @@ -use syn::visit_mut::VisitMut; +use std::collections::HashSet; + +use quote::ToTokens; +use syn::{punctuated::Punctuated, spanned::Spanned, visit_mut::VisitMut}; pub(crate) struct ChangeLt<'a> { from: Option<&'a str>, @@ -12,6 +15,7 @@ impl ChangeLt<'_> { to: db_lt.ident.to_string(), } } + pub fn in_type(mut self, ty: &syn::Type) -> syn::Type { let mut ty = ty.clone(); self.visit_type_mut(&mut ty); @@ -26,3 +30,114 @@ impl syn::visit_mut::VisitMut for ChangeLt<'_> { } } } + +pub(crate) struct ChangeSelfPath<'a> { + self_ty: &'a syn::Type, + trait_: Option<(&'a syn::Path, &'a HashSet)>, +} + +impl ChangeSelfPath<'_> { + pub fn new<'a>( + self_ty: &'a syn::Type, + trait_: Option<(&'a syn::Path, &'a HashSet)>, + ) -> ChangeSelfPath<'a> { + ChangeSelfPath { self_ty, trait_ } + } +} + +impl syn::visit_mut::VisitMut for ChangeSelfPath<'_> { + fn visit_type_mut(&mut self, i: &mut syn::Type) { + if let syn::Type::Path(syn::TypePath { qself: None, path }) = i { + if path.segments.len() == 1 && path.segments.first().is_some_and(|s| s.ident == "Self") + { + let span = path.segments.first().unwrap().span(); + *i = respan(self.self_ty, span); + } + } + syn::visit_mut::visit_type_mut(self, i); + } + + fn visit_type_path_mut(&mut self, i: &mut syn::TypePath) { + // `` cases are handled in `visit_type_mut` + if i.qself.is_some() { + syn::visit_mut::visit_type_path_mut(self, i); + return; + } + + // A single path `Self` case is handled in `visit_type_mut` + if i.path.segments.first().is_some_and(|s| s.ident == "Self") && i.path.segments.len() > 1 { + let span = i.path.segments.first().unwrap().span(); + let ty = Box::new(respan::(self.self_ty, span)); + let lt_token = syn::Token![<](span); + let gt_token = syn::Token![>](span); + match self.trait_ { + // If the next segment's ident is a trait member, replace `Self::` with + // `::` + Some((trait_, member_idents)) + if member_idents.contains(&i.path.segments.iter().nth(1).unwrap().ident) => + { + let qself = syn::QSelf { + lt_token, + ty, + position: trait_.segments.len(), + as_token: Some(syn::Token![as](span)), + gt_token, + }; + i.qself = Some(qself); + i.path.segments = Punctuated::from_iter( + trait_ + .segments + .iter() + .chain(i.path.segments.iter().skip(1)) + .cloned(), + ); + } + // Replace `Self::` with `::` otherwise + _ => { + let qself = syn::QSelf { + lt_token, + ty, + position: 0, + as_token: None, + gt_token, + }; + i.qself = Some(qself); + i.path.segments = + Punctuated::from_iter(i.path.segments.iter().skip(1).cloned()); + } + } + } + + syn::visit_mut::visit_type_path_mut(self, i); + } +} + +fn respan(t: &T, span: proc_macro2::Span) -> T +where + T: ToTokens + Spanned + syn::parse::Parse, +{ + let tokens = t.to_token_stream(); + let respanned = respan_tokenstream(tokens, span); + syn::parse2(respanned).unwrap() +} + +fn respan_tokenstream( + stream: proc_macro2::TokenStream, + span: proc_macro2::Span, +) -> proc_macro2::TokenStream { + stream + .into_iter() + .map(|token| respan_token(token, span)) + .collect() +} + +fn respan_token( + mut token: proc_macro2::TokenTree, + span: proc_macro2::Span, +) -> proc_macro2::TokenTree { + if let proc_macro2::TokenTree::Group(g) = &mut token { + *g = proc_macro2::Group::new(g.delimiter(), respan_tokenstream(g.stream(), span)); + } + token.set_span(span); + token +} diff --git a/tests/tracked_method_with_self_ty.rs b/tests/tracked_method_with_self_ty.rs new file mode 100644 index 0000000..8f8b067 --- /dev/null +++ b/tests/tracked_method_with_self_ty.rs @@ -0,0 +1,44 @@ +//! Test that a `tracked` fn with `Self` in its signature or body on a `salsa::input` +//! compiles and executes successfully. +#![allow(warnings)] + +trait TrackedTrait { + type Type; + + fn tracked_trait_fn(self, db: &dyn salsa::Database, ty: Self::Type) -> Self::Type; + + fn untracked_trait_fn(); +} + +#[salsa::input] +struct MyInput { + field: u32, +} + +#[salsa::tracked] +impl MyInput { + #[salsa::tracked] + fn tracked_fn(self, db: &dyn salsa::Database, other: Self) -> u32 { + self.field(db) + other.field(db) + } +} + +#[salsa::tracked] +impl TrackedTrait for MyInput { + type Type = u32; + + #[salsa::tracked] + fn tracked_trait_fn(self, db: &dyn salsa::Database, ty: Self::Type) -> Self::Type { + Self::untracked_trait_fn(); + Self::tracked_fn(self, db, self) + ty + } + + fn untracked_trait_fn() {} +} + +#[test] +fn execute() { + let mut db = salsa::DatabaseImpl::new(); + let object = MyInput::new(&mut db, 10); + assert_eq!(object.tracked_trait_fn(&db, 1), 21); +}