From e24ace24eb339d0d8c1597ac0cf9f7e4973374fe Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Fri, 12 Apr 2024 05:12:38 -0400 Subject: [PATCH] return `&TrackedStructValue` from `new_struct` This is a step towards the goal of keep a pointer in the structs themselves. --- .../salsa-2022-macros/src/tracked_struct.rs | 22 ++++++++- components/salsa-2022/src/tracked_struct.rs | 49 ++++++++++++++----- components/salsa-2022/src/update.rs | 14 ------ 3 files changed, 58 insertions(+), 27 deletions(-) diff --git a/components/salsa-2022-macros/src/tracked_struct.rs b/components/salsa-2022-macros/src/tracked_struct.rs index 5ec21af9..48022c9a 100644 --- a/components/salsa-2022-macros/src/tracked_struct.rs +++ b/components/salsa-2022-macros/src/tracked_struct.rs @@ -57,6 +57,7 @@ impl TrackedStruct { let ingredients_for_impl = self.tracked_struct_ingredients(&config_struct); let salsa_struct_in_db_impl = self.salsa_struct_in_db_impl(); let tracked_struct_in_db_impl = self.tracked_struct_in_db_impl(); + let update_impl = self.update_impl(); let as_id_impl = self.as_id_impl(); let as_debug_with_db_impl = self.as_debug_with_db_impl(); Ok(quote! { @@ -67,6 +68,7 @@ impl TrackedStruct { #ingredients_for_impl #salsa_struct_in_db_impl #tracked_struct_in_db_impl + #update_impl #as_id_impl #as_debug_with_db_impl }) @@ -210,11 +212,11 @@ impl TrackedStruct { { let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db); let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #ident >>::ingredient(__jar); - let __id = __ingredients.0.new_struct( + let __data = __ingredients.0.new_struct( __runtime, (#(#field_names,)*), ); - __id + __data.id() } #(#field_getters)* @@ -332,6 +334,22 @@ impl TrackedStruct { } } + /// Implementation of `Update`. + fn update_impl(&self) -> syn::ItemImpl { + let ident = self.id_ident(); + parse_quote! { + unsafe impl salsa::update::Update for #ident { + unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool { + if unsafe { *old_pointer } != new_value { + unsafe { *old_pointer = new_value }; + true + } else { + false + } + } + } + } + } /// The index of the tracked struct ingredient in the ingredient tuple. fn tracked_struct_ingredient_index(&self) -> Literal { Literal::usize_unsuffixed(0) diff --git a/components/salsa-2022/src/tracked_struct.rs b/components/salsa-2022/src/tracked_struct.rs index 012dd11d..32214716 100644 --- a/components/salsa-2022/src/tracked_struct.rs +++ b/components/salsa-2022/src/tracked_struct.rs @@ -10,6 +10,7 @@ use crate::{ ingredient_list::IngredientList, interned::{InternedId, InternedIngredient}, key::{DatabaseKeyIndex, DependencyIndex}, + plumbing::transmute_lifetime, runtime::{local_state::QueryOrigin, Runtime}, salsa_struct::SalsaStructInDb, Database, Durability, Event, IngredientIndex, Revision, @@ -125,10 +126,13 @@ struct TrackedStructKey { // ANCHOR: TrackedStructValue #[derive(Debug)] -struct TrackedStructValue +pub struct TrackedStructValue where C: Configuration, { + /// The id of this struct in the ingredient. + id: C::Id, + /// The durability minimum durability of all inputs consumed /// by the creator query prior to creating this tracked struct. /// If any of those inputs changes, then the creator query may @@ -155,6 +159,16 @@ where } // ANCHOR_END: TrackedStructValue +impl TrackedStructValue +where + C: Configuration, +{ + /// The id of this struct in the ingredient. + pub fn id(&self) -> C::Id { + self.id + } +} + #[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Copy, Clone)] pub struct Disambiguator(pub u32); @@ -194,7 +208,7 @@ where } } - pub fn new_struct(&self, runtime: &Runtime, fields: C::Fields) -> C::Id { + pub fn new_struct(&self, runtime: &Runtime, fields: C::Fields) -> &TrackedStructValue { let data_hash = crate::hash::hash(&C::id_fields(&fields)); let (query_key, current_deps, disambiguator) = runtime.disambiguate_entity( @@ -211,22 +225,29 @@ where let (id, new_id) = self.interned.intern_full(runtime, entity_key); runtime.add_output(self.database_key_index(id).into()); + let pointer: *const TrackedStructValue; let current_revision = runtime.current_revision(); if new_id { - let old_value = self.entity_data.insert( + let data = Box::new(TrackedStructValue { id, - Box::new(TrackedStructValue { - created_at: current_revision, - durability: current_deps.durability, - fields, - revisions: C::new_revisions(current_deps.changed_at), - }), - ); + created_at: current_revision, + durability: current_deps.durability, + fields, + revisions: C::new_revisions(current_deps.changed_at), + }); + + // Keep a pointer into the box for later + pointer = &*data; + + let old_value = self.entity_data.insert(id, data); assert!(old_value.is_none()); } else { let mut data = self.entity_data.get_mut(&id).unwrap(); let data = &mut *data; + // Keep a pointer into the box for later + pointer = &**data; + // SAFETY: We assert that the pointer to `data.revisions` // is a pointer into the database referencing a value // from a previous revision. As such, it continues to meet @@ -247,7 +268,13 @@ where data.durability = current_deps.durability; } - id + // Unsafety clause: + // + // * The box is owned by self and, although the box has been moved, + // the pointer is to the contents of the box, which have a stable + // address. + // * Values are only removed or altered when we have `&mut self`. + unsafe { transmute_lifetime(self, &*pointer) } } /// Deletes the given entities. This is used after a query `Q` executes and we can compare diff --git a/components/salsa-2022/src/update.rs b/components/salsa-2022/src/update.rs index ad5b6d3e..4a3d369a 100644 --- a/components/salsa-2022/src/update.rs +++ b/components/salsa-2022/src/update.rs @@ -126,20 +126,6 @@ pub unsafe trait Update { unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool; } -unsafe impl Update for &T { - unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool { - let old_value: *const T = unsafe { *old_pointer }; - if old_value != (new_value as *const T) { - unsafe { - *old_pointer = new_value; - } - true - } else { - false - } - } -} - unsafe impl Update for Vec where T: Update,