From 9607638d5d805a212c4c42b7fb9b8c31c0291a31 Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Thu, 16 May 2024 05:18:56 -0400 Subject: [PATCH] permit interned structs with lifetimes --- components/salsa-2022-macros/src/interned.rs | 125 ++++++++++++++++-- components/salsa-2022/src/interned.rs | 14 +- .../tests/interned-struct-with-lifetime.rs | 57 ++++++++ 3 files changed, 178 insertions(+), 18 deletions(-) create mode 100644 salsa-2022-tests/tests/interned-struct-with-lifetime.rs diff --git a/components/salsa-2022-macros/src/interned.rs b/components/salsa-2022-macros/src/interned.rs index 181108dc..663fc089 100644 --- a/components/salsa-2022-macros/src/interned.rs +++ b/components/salsa-2022-macros/src/interned.rs @@ -1,4 +1,4 @@ -use crate::salsa_struct::SalsaStruct; +use crate::salsa_struct::{SalsaStruct, TheStructKind}; use proc_macro2::TokenStream; // #[salsa::interned(jar = Jar0, data = TyData0)] @@ -84,7 +84,7 @@ impl InternedStruct { fn validate_interned(&self) -> syn::Result<()> { self.disallow_id_fields("interned")?; - self.require_no_generics()?; + self.require_db_lifetime()?; Ok(()) } @@ -113,16 +113,34 @@ impl InternedStruct { let visibility = self.visibility(); let all_field_names = self.all_field_names(); let all_field_tys = self.all_field_tys(); - parse_quote_spanned! { data_ident.span() => - /// Internal struct used for interned item - #[derive(Eq, PartialEq, Hash, Clone)] - #visibility struct #data_ident #impl_generics - where - #where_clause - { - #( - #all_field_names: #all_field_tys, - )* + + match self.the_struct_kind() { + TheStructKind::Id => { + parse_quote_spanned! { data_ident.span() => + #[derive(Eq, PartialEq, Hash, Clone)] + #visibility struct #data_ident #impl_generics + where + #where_clause + { + #( + #all_field_names: #all_field_tys, + )* + } + } + } + TheStructKind::Pointer(db_lt) => { + parse_quote_spanned! { data_ident.span() => + #[derive(Eq, PartialEq, Hash, Clone)] + #visibility struct #data_ident #impl_generics + where + #where_clause + { + #( + #all_field_names: #all_field_tys, + )* + __phantom: std::marker::PhantomData<& #db_lt ()>, + } + } } } } @@ -146,6 +164,89 @@ impl InternedStruct { /// If this is an interned struct, then generate methods to access each field, /// as well as a `new` method. fn inherent_impl_for_named_fields(&self) -> syn::ItemImpl { + match self.the_struct_kind() { + TheStructKind::Id => self.inherent_impl_for_named_fields_id(), + TheStructKind::Pointer(db_lt) => self.inherent_impl_for_named_fields_lt(&db_lt), + } + } + + /// If this is an interned struct, then generate methods to access each field, + /// as well as a `new` method. + fn inherent_impl_for_named_fields_lt(&self, db_lt: &syn::Lifetime) -> syn::ItemImpl { + let vis: &syn::Visibility = self.visibility(); + let (the_ident, _, impl_generics, type_generics, where_clause) = + self.the_ident_and_generics(); + let db_dyn_ty = self.db_dyn_ty(); + let jar_ty = self.jar_ty(); + + let field_getters: Vec = self + .all_fields() + .map(|field| { + let field_name = field.name(); + let field_ty = field.ty(); + let field_vis = field.vis(); + let field_get_name = field.get_name(); + if field.is_clone_field() { + parse_quote_spanned! { field_get_name.span() => + #field_vis fn #field_get_name(self, _db: & #db_lt #db_dyn_ty) -> #field_ty { + std::clone::Clone::clone(&unsafe { &*self.0 }.data().#field_name) + } + } + } else { + parse_quote_spanned! { field_get_name.span() => + #field_vis fn #field_get_name(self, _db: & #db_lt #db_dyn_ty) -> & #db_lt #field_ty { + &unsafe { &*self.0 }.data().#field_name + } + } + } + }) + .collect(); + + let field_names = self.all_field_names(); + let field_tys = self.all_field_tys(); + let data_ident = self.data_ident(); + let constructor_name = self.constructor_name(); + let new_method: syn::ImplItemMethod = parse_quote_spanned! { constructor_name.span() => + #vis fn #constructor_name( + db: &#db_dyn_ty, + #(#field_names: #field_tys,)* + ) -> Self { + let (jar, runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(db); + let ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #the_ident #type_generics >>::ingredient(jar); + Self( + ingredients.intern(runtime, #data_ident { + #(#field_names,)* + __phantom: std::marker::PhantomData, + }), + std::marker::PhantomData, + ) + } + }; + + let salsa_id = quote!( + pub fn salsa_id(&self) -> salsa::Id { + unsafe { &*self.0 }.salsa_id() + } + ); + + parse_quote! { + #[allow(dead_code, clippy::pedantic, clippy::complexity, clippy::style)] + impl #impl_generics #the_ident #type_generics + where + #where_clause + { + #(#field_getters)* + + #new_method + + #salsa_id + } + } + } + + /// If this is an interned struct, then generate methods to access each field, + /// as well as a `new` method. + fn inherent_impl_for_named_fields_id(&self) -> syn::ItemImpl { let vis: &syn::Visibility = self.visibility(); let (the_ident, _, impl_generics, type_generics, where_clause) = self.the_ident_and_generics(); diff --git a/components/salsa-2022/src/interned.rs b/components/salsa-2022/src/interned.rs index 71f9417b..6c276145 100644 --- a/components/salsa-2022/src/interned.rs +++ b/components/salsa-2022/src/interned.rs @@ -39,7 +39,7 @@ pub struct InternedIngredient { /// Maps from an interned id to its data. /// /// Deadlock requirement: We access `value_map` while holding lock on `key_map`, but not vice versa. - value_map: FxDashMap>>, + value_map: FxDashMap>>, /// counter for the next id. counter: AtomicCell, @@ -54,7 +54,7 @@ pub struct InternedIngredient { } /// Struct storing the interned fields. -pub struct InternedValue +pub struct ValueStruct where C: Configuration, { @@ -90,7 +90,7 @@ where &'db self, runtime: &'db Runtime, data: C::Data<'db>, - ) -> &'db InternedValue { + ) -> &'db ValueStruct { runtime.report_tracked_read( DependencyIndex::for_table(self.ingredient_index), Durability::MAX, @@ -121,7 +121,7 @@ where let value = self .value_map .entry(next_id) - .or_insert(Box::new(InternedValue { + .or_insert(Box::new(ValueStruct { id: next_id, fields: internal_data, })); @@ -134,7 +134,7 @@ where } } - pub fn interned_value<'db>(&'db self, id: Id) -> &'db InternedValue { + pub fn interned_value<'db>(&'db self, id: Id) -> &'db ValueStruct { let r = self.value_map.get(&id).unwrap(); // SAFETY: Items are only removed from the `value_map` with an `&mut self` reference. @@ -247,7 +247,7 @@ where } } -impl InternedValue +impl ValueStruct where C: Configuration, { @@ -256,6 +256,8 @@ where } pub fn data<'db>(&'db self) -> &'db C::Data<'db> { + // SAFETY: The lifetime of `self` is tied to the interning ingredient; + // we never remove data without an `&mut self` access to the interning ingredient. unsafe { self.to_self_ref(&self.fields) } } diff --git a/salsa-2022-tests/tests/interned-struct-with-lifetime.rs b/salsa-2022-tests/tests/interned-struct-with-lifetime.rs new file mode 100644 index 00000000..a8fa3e4c --- /dev/null +++ b/salsa-2022-tests/tests/interned-struct-with-lifetime.rs @@ -0,0 +1,57 @@ +//! Test that a `tracked` fn on a `salsa::input` +//! compiles and executes successfully. +use salsa::DebugWithDb; +use salsa_2022_tests::{HasLogger, Logger}; + +use expect_test::expect; +use test_log::test; + +#[salsa::jar(db = Db)] +struct Jar(InternedString<'_>, InternedPair<'_>, intern_stuff); + +trait Db: salsa::DbWithJar + HasLogger {} + +#[salsa::interned] +struct InternedString<'db> { + data: String, +} + +#[salsa::interned] +struct InternedPair<'db> { + data: (InternedString<'db>, InternedString<'db>), +} + +#[salsa::tracked] +fn intern_stuff(db: &dyn Db) -> String { + let s1 = InternedString::new(db, format!("Hello, ")); + let s2 = InternedString::new(db, format!("World, ")); + let s3 = InternedPair::new(db, (s1, s2)); + format!("{:?}", s3.debug(db)) +} + +#[salsa::db(Jar)] +#[derive(Default)] +struct Database { + storage: salsa::Storage, + logger: Logger, +} + +impl salsa::Database for Database {} + +impl Db for Database {} + +impl HasLogger for Database { + fn logger(&self) -> &Logger { + &self.logger + } +} + +#[test] +fn execute() { + let mut db = Database::default(); + + expect![[r#" + "InternedPair { [salsa id]: 0, data: (InternedString { [salsa id]: 0, data: \"Hello, \" }, InternedString { [salsa id]: 1, data: \"World, \" }) }" + "#]].assert_debug_eq(&intern_stuff(&db)); + db.assert_logs(expect!["[]"]); +}