diff --git a/components/salsa-2022-macros/src/salsa_struct.rs b/components/salsa-2022-macros/src/salsa_struct.rs index 001af7ba..b57b4e1e 100644 --- a/components/salsa-2022-macros/src/salsa_struct.rs +++ b/components/salsa-2022-macros/src/salsa_struct.rs @@ -27,7 +27,10 @@ use crate::options::{AllowedOptions, Options}; use proc_macro2::{Ident, Span, TokenStream}; -use syn::spanned::Spanned; +use syn::{ + punctuated::Punctuated, spanned::Spanned, token::Comma, GenericParam, ImplGenerics, + TypeGenerics, WhereClause, +}; pub(crate) struct SalsaStruct { args: Options, @@ -144,6 +147,31 @@ impl SalsaStruct { &self.struct_item.ident } + /// Name of the struct the user gave plus: + /// + /// * its list of generic parameters + /// * the generics "split for impl". + pub(crate) fn id_ident_and_generics( + &self, + ) -> ( + &syn::Ident, + &Punctuated, + ImplGenerics<'_>, + TypeGenerics<'_>, + Option<&WhereClause>, + ) { + let ident = &self.struct_item.ident; + let (impl_generics, type_generics, where_clause) = + self.struct_item.generics.split_for_impl(); + ( + ident, + &self.struct_item.generics.params, + impl_generics, + type_generics, + where_clause, + ) + } + /// Type of the jar for this struct pub(crate) fn jar_ty(&self) -> syn::Type { self.args.jar_ty() @@ -173,7 +201,9 @@ impl SalsaStruct { } } - /// Generate `struct Foo(Id)` + /// Create a struct that wraps the id. + /// This is the struct the user will refernece, but only if there + /// are no lifetimes. pub(crate) fn id_struct(&self) -> syn::ItemStruct { let ident = self.id_ident(); let visibility = &self.struct_item.vis; @@ -195,6 +225,49 @@ impl SalsaStruct { } } + /// Create the struct that the user will reference. + /// If + pub(crate) fn id_or_ptr_struct( + &self, + config_ident: &syn::Ident, + ) -> syn::Result { + if self.struct_item.generics.params.is_empty() { + Ok(self.id_struct()) + } else { + let ident = self.id_ident(); + let visibility = &self.struct_item.vis; + + let generics = &self.struct_item.generics; + if generics.params.len() != 1 || generics.lifetimes().count() != 1 { + return Err(syn::Error::new_spanned( + &self.struct_item.generics, + "must have exactly one lifetime parameter", + )); + } + + let lifetime = generics.lifetimes().next().unwrap(); + + // Extract the attributes the user gave, but screen out derive, since we are adding our own, + // and the customize attribute that we use for our own purposes. + let attrs: Vec<_> = self + .struct_item + .attrs + .iter() + .filter(|attr| !attr.path.is_ident("derive")) + .filter(|attr| !attr.path.is_ident("customize")) + .collect(); + + Ok(parse_quote_spanned! { ident.span() => + #(#attrs)* + #[derive(Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash, Debug)] + #visibility struct #ident #generics ( + *const salsa::tracked_struct::TrackedStructValue < #config_ident >, + std::marker::PhantomData < & #lifetime salsa::tracked_struct::TrackedStructValue < #config_ident > > + ); + }) + } + } + /// Generates the `struct FooData` struct (or enum). /// This type inherits all the attributes written by the user. /// @@ -233,8 +306,12 @@ impl SalsaStruct { /// Generate `impl salsa::AsId for Foo` pub(crate) fn as_id_impl(&self) -> syn::ItemImpl { let ident = self.id_ident(); + let (impl_generics, type_generics, where_clause) = + self.struct_item.generics.split_for_impl(); parse_quote_spanned! { ident.span() => - impl salsa::AsId for #ident { + impl #impl_generics salsa::AsId for #ident #type_generics + #where_clause + { fn as_id(self) -> salsa::Id { self.0 } @@ -254,6 +331,8 @@ impl SalsaStruct { } let ident = self.id_ident(); + let (impl_generics, type_generics, where_clause) = + self.struct_item.generics.split_for_impl(); let db_type = self.db_dyn_ty(); let ident_string = ident.to_string(); @@ -281,7 +360,9 @@ impl SalsaStruct { // `use ::salsa::debug::helper::Fallback` is needed for the fallback to `Debug` impl Some(parse_quote_spanned! {ident.span()=> - impl ::salsa::DebugWithDb<#db_type> for #ident { + impl #impl_generics ::salsa::DebugWithDb<#db_type> for #ident #type_generics + #where_clause + { fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>, _db: &#db_type) -> ::std::fmt::Result { #[allow(unused_imports)] use ::salsa::debug::helper::Fallback; diff --git a/components/salsa-2022-macros/src/tracked_struct.rs b/components/salsa-2022-macros/src/tracked_struct.rs index 48022c9a..0a4bb9b9 100644 --- a/components/salsa-2022-macros/src/tracked_struct.rs +++ b/components/salsa-2022-macros/src/tracked_struct.rs @@ -50,8 +50,8 @@ impl TrackedStruct { fn generate_tracked(&self) -> syn::Result { self.validate_tracked()?; - let id_struct = self.id_struct(); let config_struct = self.config_struct(); + let the_struct = self.id_or_ptr_struct(&config_struct.ident)?; let config_impl = self.config_impl(&config_struct); let inherent_impl = self.tracked_inherent_impl(); let ingredients_for_impl = self.tracked_struct_ingredients(&config_struct); @@ -63,7 +63,7 @@ impl TrackedStruct { Ok(quote! { #config_struct #config_impl - #id_struct + #the_struct #inherent_impl #ingredients_for_impl #salsa_struct_in_db_impl @@ -168,7 +168,8 @@ impl TrackedStruct { /// Generate an inherent impl with methods on the tracked type. fn tracked_inherent_impl(&self) -> syn::ItemImpl { - let ident = self.id_ident(); + let (ident, _, impl_generics, type_generics, where_clause) = self.id_ident_and_generics(); + let jar_ty = self.jar_ty(); let db_dyn_ty = self.db_dyn_ty(); let tracked_field_ingredients: Literal = self.tracked_field_ingredients_index(); @@ -207,7 +208,8 @@ impl TrackedStruct { parse_quote! { #[allow(dead_code, clippy::pedantic, clippy::complexity, clippy::style)] - impl #ident { + impl #impl_generics #ident #type_generics + #where_clause { pub fn #constructor_name(__db: &#db_dyn_ty, #(#field_names: #field_tys,)*) -> Self { let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db); @@ -230,7 +232,7 @@ impl TrackedStruct { /// function ingredient for each of the value fields. fn tracked_struct_ingredients(&self, config_struct: &syn::ItemStruct) -> syn::ItemImpl { use crate::literal; - let ident = self.id_ident(); + let (ident, _, impl_generics, type_generics, where_clause) = self.id_ident_and_generics(); let jar_ty = self.jar_ty(); let config_struct_name = &config_struct.ident; let field_indices: Vec = self.all_field_indices(); @@ -241,7 +243,8 @@ impl TrackedStruct { let debug_name_fields: Vec<_> = self.all_field_names().into_iter().map(literal).collect(); parse_quote! { - impl salsa::storage::IngredientsFor for #ident { + impl #impl_generics salsa::storage::IngredientsFor for #ident #type_generics + #where_clause { type Jar = #jar_ty; type Ingredients = ( salsa::tracked_struct::TrackedStructIngredient<#config_struct_name>, @@ -298,15 +301,17 @@ impl TrackedStruct { /// Implementation of `SalsaStructInDb`. fn salsa_struct_in_db_impl(&self) -> syn::ItemImpl { - let ident = self.id_ident(); + let (ident, parameters, _, type_generics, where_clause) = self.id_ident_and_generics(); + let db = syn::Ident::new("DB", ident.span()); let jar_ty = self.jar_ty(); let tracked_struct_ingredient = self.tracked_struct_ingredient_index(); parse_quote! { - impl salsa::salsa_struct::SalsaStructInDb for #ident + impl<#db, #parameters> salsa::salsa_struct::SalsaStructInDb<#db> for #ident #type_generics where - DB: ?Sized + salsa::DbWithJar<#jar_ty>, + #db: ?Sized + salsa::DbWithJar<#jar_ty>, + #where_clause { - fn register_dependent_fn(db: &DB, index: salsa::routes::IngredientIndex) { + fn register_dependent_fn(db: & #db, index: salsa::routes::IngredientIndex) { let (jar, _) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(db); let ingredients = <#jar_ty as salsa::storage::HasIngredientsFor<#ident>>::ingredient(jar); ingredients.#tracked_struct_ingredient.register_dependent_fn(index) @@ -317,15 +322,17 @@ impl TrackedStruct { /// Implementation of `TrackedStructInDb`. fn tracked_struct_in_db_impl(&self) -> syn::ItemImpl { - let ident = self.id_ident(); + let (ident, parameters, _, type_generics, where_clause) = self.id_ident_and_generics(); + let db = syn::Ident::new("DB", ident.span()); let jar_ty = self.jar_ty(); let tracked_struct_ingredient = self.tracked_struct_ingredient_index(); parse_quote! { - impl salsa::tracked_struct::TrackedStructInDb for #ident + impl<#db, #parameters> salsa::tracked_struct::TrackedStructInDb<#db> for #ident #type_generics where - DB: ?Sized + salsa::DbWithJar<#jar_ty>, + #db: ?Sized + salsa::DbWithJar<#jar_ty>, + #where_clause { - fn database_key_index(self, db: &DB) -> salsa::DatabaseKeyIndex { + fn database_key_index(self, db: &#db) -> salsa::DatabaseKeyIndex { let (jar, _) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(db); let ingredients = <#jar_ty as salsa::storage::HasIngredientsFor<#ident>>::ingredient(jar); ingredients.#tracked_struct_ingredient.database_key_index(self)