allow (but don't test) lifetime parameters

This commit is contained in:
Niko Matsakis 2024-04-16 11:33:49 -04:00
parent ea1d452143
commit 79d24e0ad7
2 changed files with 106 additions and 18 deletions

View file

@ -27,7 +27,10 @@
use crate::options::{AllowedOptions, Options};
use proc_macro2::{Ident, Span, TokenStream};
use syn::spanned::Spanned;
use syn::{
punctuated::Punctuated, spanned::Spanned, token::Comma, GenericParam, ImplGenerics,
TypeGenerics, WhereClause,
};
pub(crate) struct SalsaStruct<A: AllowedOptions> {
args: Options<A>,
@ -144,6 +147,31 @@ impl<A: AllowedOptions> SalsaStruct<A> {
&self.struct_item.ident
}
/// Name of the struct the user gave plus:
///
/// * its list of generic parameters
/// * the generics "split for impl".
pub(crate) fn id_ident_and_generics(
&self,
) -> (
&syn::Ident,
&Punctuated<GenericParam, Comma>,
ImplGenerics<'_>,
TypeGenerics<'_>,
Option<&WhereClause>,
) {
let ident = &self.struct_item.ident;
let (impl_generics, type_generics, where_clause) =
self.struct_item.generics.split_for_impl();
(
ident,
&self.struct_item.generics.params,
impl_generics,
type_generics,
where_clause,
)
}
/// Type of the jar for this struct
pub(crate) fn jar_ty(&self) -> syn::Type {
self.args.jar_ty()
@ -173,7 +201,9 @@ impl<A: AllowedOptions> SalsaStruct<A> {
}
}
/// Generate `struct Foo(Id)`
/// Create a struct that wraps the id.
/// This is the struct the user will refernece, but only if there
/// are no lifetimes.
pub(crate) fn id_struct(&self) -> syn::ItemStruct {
let ident = self.id_ident();
let visibility = &self.struct_item.vis;
@ -195,6 +225,49 @@ impl<A: AllowedOptions> SalsaStruct<A> {
}
}
/// Create the struct that the user will reference.
/// If
pub(crate) fn id_or_ptr_struct(
&self,
config_ident: &syn::Ident,
) -> syn::Result<syn::ItemStruct> {
if self.struct_item.generics.params.is_empty() {
Ok(self.id_struct())
} else {
let ident = self.id_ident();
let visibility = &self.struct_item.vis;
let generics = &self.struct_item.generics;
if generics.params.len() != 1 || generics.lifetimes().count() != 1 {
return Err(syn::Error::new_spanned(
&self.struct_item.generics,
"must have exactly one lifetime parameter",
));
}
let lifetime = generics.lifetimes().next().unwrap();
// Extract the attributes the user gave, but screen out derive, since we are adding our own,
// and the customize attribute that we use for our own purposes.
let attrs: Vec<_> = self
.struct_item
.attrs
.iter()
.filter(|attr| !attr.path.is_ident("derive"))
.filter(|attr| !attr.path.is_ident("customize"))
.collect();
Ok(parse_quote_spanned! { ident.span() =>
#(#attrs)*
#[derive(Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash, Debug)]
#visibility struct #ident #generics (
*const salsa::tracked_struct::TrackedStructValue < #config_ident >,
std::marker::PhantomData < & #lifetime salsa::tracked_struct::TrackedStructValue < #config_ident > >
);
})
}
}
/// Generates the `struct FooData` struct (or enum).
/// This type inherits all the attributes written by the user.
///
@ -233,8 +306,12 @@ impl<A: AllowedOptions> SalsaStruct<A> {
/// Generate `impl salsa::AsId for Foo`
pub(crate) fn as_id_impl(&self) -> syn::ItemImpl {
let ident = self.id_ident();
let (impl_generics, type_generics, where_clause) =
self.struct_item.generics.split_for_impl();
parse_quote_spanned! { ident.span() =>
impl salsa::AsId for #ident {
impl #impl_generics salsa::AsId for #ident #type_generics
#where_clause
{
fn as_id(self) -> salsa::Id {
self.0
}
@ -254,6 +331,8 @@ impl<A: AllowedOptions> SalsaStruct<A> {
}
let ident = self.id_ident();
let (impl_generics, type_generics, where_clause) =
self.struct_item.generics.split_for_impl();
let db_type = self.db_dyn_ty();
let ident_string = ident.to_string();
@ -281,7 +360,9 @@ impl<A: AllowedOptions> SalsaStruct<A> {
// `use ::salsa::debug::helper::Fallback` is needed for the fallback to `Debug` impl
Some(parse_quote_spanned! {ident.span()=>
impl ::salsa::DebugWithDb<#db_type> for #ident {
impl #impl_generics ::salsa::DebugWithDb<#db_type> for #ident #type_generics
#where_clause
{
fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>, _db: &#db_type) -> ::std::fmt::Result {
#[allow(unused_imports)]
use ::salsa::debug::helper::Fallback;

View file

@ -50,8 +50,8 @@ impl TrackedStruct {
fn generate_tracked(&self) -> syn::Result<TokenStream> {
self.validate_tracked()?;
let id_struct = self.id_struct();
let config_struct = self.config_struct();
let the_struct = self.id_or_ptr_struct(&config_struct.ident)?;
let config_impl = self.config_impl(&config_struct);
let inherent_impl = self.tracked_inherent_impl();
let ingredients_for_impl = self.tracked_struct_ingredients(&config_struct);
@ -63,7 +63,7 @@ impl TrackedStruct {
Ok(quote! {
#config_struct
#config_impl
#id_struct
#the_struct
#inherent_impl
#ingredients_for_impl
#salsa_struct_in_db_impl
@ -168,7 +168,8 @@ impl TrackedStruct {
/// Generate an inherent impl with methods on the tracked type.
fn tracked_inherent_impl(&self) -> syn::ItemImpl {
let ident = self.id_ident();
let (ident, _, impl_generics, type_generics, where_clause) = self.id_ident_and_generics();
let jar_ty = self.jar_ty();
let db_dyn_ty = self.db_dyn_ty();
let tracked_field_ingredients: Literal = self.tracked_field_ingredients_index();
@ -207,7 +208,8 @@ impl TrackedStruct {
parse_quote! {
#[allow(dead_code, clippy::pedantic, clippy::complexity, clippy::style)]
impl #ident {
impl #impl_generics #ident #type_generics
#where_clause {
pub fn #constructor_name(__db: &#db_dyn_ty, #(#field_names: #field_tys,)*) -> Self
{
let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db);
@ -230,7 +232,7 @@ impl TrackedStruct {
/// function ingredient for each of the value fields.
fn tracked_struct_ingredients(&self, config_struct: &syn::ItemStruct) -> syn::ItemImpl {
use crate::literal;
let ident = self.id_ident();
let (ident, _, impl_generics, type_generics, where_clause) = self.id_ident_and_generics();
let jar_ty = self.jar_ty();
let config_struct_name = &config_struct.ident;
let field_indices: Vec<Literal> = self.all_field_indices();
@ -241,7 +243,8 @@ impl TrackedStruct {
let debug_name_fields: Vec<_> = self.all_field_names().into_iter().map(literal).collect();
parse_quote! {
impl salsa::storage::IngredientsFor for #ident {
impl #impl_generics salsa::storage::IngredientsFor for #ident #type_generics
#where_clause {
type Jar = #jar_ty;
type Ingredients = (
salsa::tracked_struct::TrackedStructIngredient<#config_struct_name>,
@ -298,15 +301,17 @@ impl TrackedStruct {
/// Implementation of `SalsaStructInDb`.
fn salsa_struct_in_db_impl(&self) -> syn::ItemImpl {
let ident = self.id_ident();
let (ident, parameters, _, type_generics, where_clause) = self.id_ident_and_generics();
let db = syn::Ident::new("DB", ident.span());
let jar_ty = self.jar_ty();
let tracked_struct_ingredient = self.tracked_struct_ingredient_index();
parse_quote! {
impl<DB> salsa::salsa_struct::SalsaStructInDb<DB> for #ident
impl<#db, #parameters> salsa::salsa_struct::SalsaStructInDb<#db> for #ident #type_generics
where
DB: ?Sized + salsa::DbWithJar<#jar_ty>,
#db: ?Sized + salsa::DbWithJar<#jar_ty>,
#where_clause
{
fn register_dependent_fn(db: &DB, index: salsa::routes::IngredientIndex) {
fn register_dependent_fn(db: & #db, index: salsa::routes::IngredientIndex) {
let (jar, _) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(db);
let ingredients = <#jar_ty as salsa::storage::HasIngredientsFor<#ident>>::ingredient(jar);
ingredients.#tracked_struct_ingredient.register_dependent_fn(index)
@ -317,15 +322,17 @@ impl TrackedStruct {
/// Implementation of `TrackedStructInDb`.
fn tracked_struct_in_db_impl(&self) -> syn::ItemImpl {
let ident = self.id_ident();
let (ident, parameters, _, type_generics, where_clause) = self.id_ident_and_generics();
let db = syn::Ident::new("DB", ident.span());
let jar_ty = self.jar_ty();
let tracked_struct_ingredient = self.tracked_struct_ingredient_index();
parse_quote! {
impl<DB> salsa::tracked_struct::TrackedStructInDb<DB> for #ident
impl<#db, #parameters> salsa::tracked_struct::TrackedStructInDb<#db> for #ident #type_generics
where
DB: ?Sized + salsa::DbWithJar<#jar_ty>,
#db: ?Sized + salsa::DbWithJar<#jar_ty>,
#where_clause
{
fn database_key_index(self, db: &DB) -> salsa::DatabaseKeyIndex {
fn database_key_index(self, db: &#db) -> salsa::DatabaseKeyIndex {
let (jar, _) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(db);
let ingredients = <#jar_ty as salsa::storage::HasIngredientsFor<#ident>>::ingredient(jar);
ingredients.#tracked_struct_ingredient.database_key_index(self)