mirror of
https://github.com/salsa-rs/salsa.git
synced 2024-11-24 04:09:36 +00:00
fix: Replace SelfTy
with actual type in tracked methods
This commit is contained in:
parent
af2ec49d80
commit
ad1f84d80f
3 changed files with 184 additions and 4 deletions
|
@ -1,8 +1,10 @@
|
|||
use std::collections::HashSet;
|
||||
|
||||
use proc_macro2::TokenStream;
|
||||
use quote::ToTokens;
|
||||
use syn::parse::Nothing;
|
||||
use syn::{parse::Nothing, visit_mut::VisitMut};
|
||||
|
||||
use crate::{hygiene::Hygiene, tracked_fn::FnArgs};
|
||||
use crate::{hygiene::Hygiene, tracked_fn::FnArgs, xform::ChangeSelfPath};
|
||||
|
||||
pub(crate) fn tracked_impl(
|
||||
args: proc_macro::TokenStream,
|
||||
|
@ -32,8 +34,19 @@ struct MethodArguments<'syn> {
|
|||
impl Macro {
|
||||
fn try_generate(&self, mut impl_item: syn::ItemImpl) -> syn::Result<TokenStream> {
|
||||
let mut member_items = std::mem::take(&mut impl_item.items);
|
||||
let member_idents: HashSet<_> = member_items
|
||||
.iter()
|
||||
.filter_map(|item| match item {
|
||||
syn::ImplItem::Const(it) => Some(it.ident.clone()),
|
||||
syn::ImplItem::Fn(it) => Some(it.sig.ident.clone()),
|
||||
syn::ImplItem::Type(it) => Some(it.ident.clone()),
|
||||
syn::ImplItem::Macro(_) => None,
|
||||
syn::ImplItem::Verbatim(_) => None,
|
||||
_ => None,
|
||||
})
|
||||
.collect();
|
||||
for member_item in &mut member_items {
|
||||
self.modify_member(&impl_item, member_item)?;
|
||||
self.modify_member(&impl_item, member_item, &member_idents)?;
|
||||
}
|
||||
impl_item.items = member_items;
|
||||
Ok(crate::debug::dump_tokens(
|
||||
|
@ -47,6 +60,7 @@ impl Macro {
|
|||
&self,
|
||||
impl_item: &syn::ItemImpl,
|
||||
member_item: &mut syn::ImplItem,
|
||||
member_idents: &HashSet<syn::Ident>,
|
||||
) -> syn::Result<()> {
|
||||
let syn::ImplItem::Fn(fn_item) = member_item else {
|
||||
return Ok(());
|
||||
|
@ -59,6 +73,13 @@ impl Macro {
|
|||
return Ok(());
|
||||
};
|
||||
|
||||
let trait_ = match &impl_item.trait_ {
|
||||
Some((None, path, _)) => Some((path, member_idents)),
|
||||
_ => None,
|
||||
};
|
||||
let mut change = ChangeSelfPath::new(self_ty, trait_);
|
||||
change.visit_impl_item_fn_mut(fn_item);
|
||||
|
||||
let salsa_tracked_attr = fn_item.attrs.remove(tracked_attr_index);
|
||||
let args: FnArgs = match &salsa_tracked_attr.meta {
|
||||
syn::Meta::Path(..) => Default::default(),
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
use syn::visit_mut::VisitMut;
|
||||
use std::collections::HashSet;
|
||||
|
||||
use quote::ToTokens;
|
||||
use syn::{punctuated::Punctuated, spanned::Spanned, visit_mut::VisitMut};
|
||||
|
||||
pub(crate) struct ChangeLt<'a> {
|
||||
from: Option<&'a str>,
|
||||
|
@ -12,6 +15,7 @@ impl ChangeLt<'_> {
|
|||
to: db_lt.ident.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn in_type(mut self, ty: &syn::Type) -> syn::Type {
|
||||
let mut ty = ty.clone();
|
||||
self.visit_type_mut(&mut ty);
|
||||
|
@ -26,3 +30,114 @@ impl syn::visit_mut::VisitMut for ChangeLt<'_> {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct ChangeSelfPath<'a> {
|
||||
self_ty: &'a syn::Type,
|
||||
trait_: Option<(&'a syn::Path, &'a HashSet<syn::Ident>)>,
|
||||
}
|
||||
|
||||
impl ChangeSelfPath<'_> {
|
||||
pub fn new<'a>(
|
||||
self_ty: &'a syn::Type,
|
||||
trait_: Option<(&'a syn::Path, &'a HashSet<syn::Ident>)>,
|
||||
) -> ChangeSelfPath<'a> {
|
||||
ChangeSelfPath { self_ty, trait_ }
|
||||
}
|
||||
}
|
||||
|
||||
impl syn::visit_mut::VisitMut for ChangeSelfPath<'_> {
|
||||
fn visit_type_mut(&mut self, i: &mut syn::Type) {
|
||||
if let syn::Type::Path(syn::TypePath { qself: None, path }) = i {
|
||||
if path.segments.len() == 1 && path.segments.first().is_some_and(|s| s.ident == "Self")
|
||||
{
|
||||
let span = path.segments.first().unwrap().span();
|
||||
*i = respan(self.self_ty, span);
|
||||
}
|
||||
}
|
||||
syn::visit_mut::visit_type_mut(self, i);
|
||||
}
|
||||
|
||||
fn visit_type_path_mut(&mut self, i: &mut syn::TypePath) {
|
||||
// `<Self as ..>` cases are handled in `visit_type_mut`
|
||||
if i.qself.is_some() {
|
||||
syn::visit_mut::visit_type_path_mut(self, i);
|
||||
return;
|
||||
}
|
||||
|
||||
// A single path `Self` case is handled in `visit_type_mut`
|
||||
if i.path.segments.first().is_some_and(|s| s.ident == "Self") && i.path.segments.len() > 1 {
|
||||
let span = i.path.segments.first().unwrap().span();
|
||||
let ty = Box::new(respan::<syn::Type>(self.self_ty, span));
|
||||
let lt_token = syn::Token![<](span);
|
||||
let gt_token = syn::Token![>](span);
|
||||
match self.trait_ {
|
||||
// If the next segment's ident is a trait member, replace `Self::` with
|
||||
// `<ActualTy as Trait>::`
|
||||
Some((trait_, member_idents))
|
||||
if member_idents.contains(&i.path.segments.iter().nth(1).unwrap().ident) =>
|
||||
{
|
||||
let qself = syn::QSelf {
|
||||
lt_token,
|
||||
ty,
|
||||
position: trait_.segments.len(),
|
||||
as_token: Some(syn::Token![as](span)),
|
||||
gt_token,
|
||||
};
|
||||
i.qself = Some(qself);
|
||||
i.path.segments = Punctuated::from_iter(
|
||||
trait_
|
||||
.segments
|
||||
.iter()
|
||||
.chain(i.path.segments.iter().skip(1))
|
||||
.cloned(),
|
||||
);
|
||||
}
|
||||
// Replace `Self::` with `<ActualTy>::` otherwise
|
||||
_ => {
|
||||
let qself = syn::QSelf {
|
||||
lt_token,
|
||||
ty,
|
||||
position: 0,
|
||||
as_token: None,
|
||||
gt_token,
|
||||
};
|
||||
i.qself = Some(qself);
|
||||
i.path.segments =
|
||||
Punctuated::from_iter(i.path.segments.iter().skip(1).cloned());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
syn::visit_mut::visit_type_path_mut(self, i);
|
||||
}
|
||||
}
|
||||
|
||||
fn respan<T>(t: &T, span: proc_macro2::Span) -> T
|
||||
where
|
||||
T: ToTokens + Spanned + syn::parse::Parse,
|
||||
{
|
||||
let tokens = t.to_token_stream();
|
||||
let respanned = respan_tokenstream(tokens, span);
|
||||
syn::parse2(respanned).unwrap()
|
||||
}
|
||||
|
||||
fn respan_tokenstream(
|
||||
stream: proc_macro2::TokenStream,
|
||||
span: proc_macro2::Span,
|
||||
) -> proc_macro2::TokenStream {
|
||||
stream
|
||||
.into_iter()
|
||||
.map(|token| respan_token(token, span))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn respan_token(
|
||||
mut token: proc_macro2::TokenTree,
|
||||
span: proc_macro2::Span,
|
||||
) -> proc_macro2::TokenTree {
|
||||
if let proc_macro2::TokenTree::Group(g) = &mut token {
|
||||
*g = proc_macro2::Group::new(g.delimiter(), respan_tokenstream(g.stream(), span));
|
||||
}
|
||||
token.set_span(span);
|
||||
token
|
||||
}
|
||||
|
|
44
tests/tracked_method_with_self_ty.rs
Normal file
44
tests/tracked_method_with_self_ty.rs
Normal file
|
@ -0,0 +1,44 @@
|
|||
//! Test that a `tracked` fn with `Self` in its signature or body on a `salsa::input`
|
||||
//! compiles and executes successfully.
|
||||
#![allow(warnings)]
|
||||
|
||||
trait TrackedTrait {
|
||||
type Type;
|
||||
|
||||
fn tracked_trait_fn(self, db: &dyn salsa::Database, ty: Self::Type) -> Self::Type;
|
||||
|
||||
fn untracked_trait_fn();
|
||||
}
|
||||
|
||||
#[salsa::input]
|
||||
struct MyInput {
|
||||
field: u32,
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
impl MyInput {
|
||||
#[salsa::tracked]
|
||||
fn tracked_fn(self, db: &dyn salsa::Database, other: Self) -> u32 {
|
||||
self.field(db) + other.field(db)
|
||||
}
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
impl TrackedTrait for MyInput {
|
||||
type Type = u32;
|
||||
|
||||
#[salsa::tracked]
|
||||
fn tracked_trait_fn(self, db: &dyn salsa::Database, ty: Self::Type) -> Self::Type {
|
||||
Self::untracked_trait_fn();
|
||||
Self::tracked_fn(self, db, self) + ty
|
||||
}
|
||||
|
||||
fn untracked_trait_fn() {}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn execute() {
|
||||
let mut db = salsa::DatabaseImpl::new();
|
||||
let object = MyInput::new(&mut db, 10);
|
||||
assert_eq!(object.tracked_trait_fn(&db, 1), 21);
|
||||
}
|
Loading…
Reference in a new issue