From 07d0ead9f40d42c9d6c737c9ad9276a5388baa0f Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Thu, 30 May 2024 01:55:13 -0400 Subject: [PATCH] return a NonNull instead of a `&'db` In old code, we converted to a `&'db` when creating a new tracked struct or interning, but this value in fact persisted beyond the end of `'db` (i.e., into the new revision). We now refactor so that we create the `Foo<'db>` from a `NonNull` instead of a `&'db T`, and then only create safe references when users access fields. This makes miri happy. --- components/salsa-2022-macros/src/interned.rs | 35 +++++++----- .../salsa-2022-macros/src/salsa_struct.rs | 8 +-- .../salsa-2022-macros/src/tracked_fn.rs | 10 ++++ .../salsa-2022-macros/src/tracked_struct.rs | 28 ++++++---- components/salsa-2022/src/alloc.rs | 26 +++++---- components/salsa-2022/src/interned.rs | 43 ++++++++++----- components/salsa-2022/src/tracked_struct.rs | 37 +++++++++---- .../src/tracked_struct/struct_map.rs | 54 +++++++++++-------- .../src/tracked_struct/tracked_field.rs | 2 + 9 files changed, 158 insertions(+), 85 deletions(-) diff --git a/components/salsa-2022-macros/src/interned.rs b/components/salsa-2022-macros/src/interned.rs index f9b749ce..39c2d4d0 100644 --- a/components/salsa-2022-macros/src/interned.rs +++ b/components/salsa-2022-macros/src/interned.rs @@ -160,6 +160,7 @@ impl InternedStruct { data_ident: &syn::Ident, config_ident: &syn::Ident, ) -> syn::ItemImpl { + let the_ident = self.the_ident(); let lt_db = &self.named_db_lifetime(); let (_, _, _, type_generics, _) = self.the_ident_and_generics(); parse_quote_spanned!( @@ -167,6 +168,16 @@ impl InternedStruct { impl salsa::interned::Configuration for #config_ident { type Data<#lt_db> = #data_ident #type_generics; + + type Struct<#lt_db> = #the_ident < #lt_db >; + + unsafe fn struct_from_raw<'db>(ptr: std::ptr::NonNull>) -> Self::Struct<'db> { + #the_ident(ptr, std::marker::PhantomData) + } + + fn deref_struct<'db>(s: Self::Struct<'db>) -> &'db salsa::interned::ValueStruct { + unsafe { s.0.as_ref() } + } } ) } @@ -191,7 +202,7 @@ impl InternedStruct { let field_getters: Vec = self .all_fields() - .map(|field| { + .map(|field: &crate::salsa_struct::SalsaField| { let field_name = field.name(); let field_ty = field.ty(); let field_vis = field.vis(); @@ -199,13 +210,13 @@ impl InternedStruct { if field.is_clone_field() { parse_quote_spanned! { field_get_name.span() => #field_vis fn #field_get_name(self, _db: & #db_lt #db_dyn_ty) -> #field_ty { - std::clone::Clone::clone(&unsafe { &*self.0 }.data().#field_name) + std::clone::Clone::clone(&unsafe { self.0.as_ref() }.data().#field_name) } } } else { parse_quote_spanned! { field_get_name.span() => #field_vis fn #field_get_name(self, _db: & #db_lt #db_dyn_ty) -> & #db_lt #field_ty { - &unsafe { &*self.0 }.data().#field_name + &unsafe { self.0.as_ref() }.data().#field_name } } } @@ -218,18 +229,15 @@ impl InternedStruct { let constructor_name = self.constructor_name(); let new_method: syn::ImplItemFn = parse_quote_spanned! { constructor_name.span() => #vis fn #constructor_name( - db: &#db_dyn_ty, + db: &#db_lt #db_dyn_ty, #(#field_names: #field_tys,)* ) -> Self { let (jar, runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(db); let ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #the_ident #type_generics >>::ingredient(jar); - Self( - ingredients.intern(runtime, #data_ident { - #(#field_names,)* - __phantom: std::marker::PhantomData, - }), - std::marker::PhantomData, - ) + ingredients.intern(runtime, #data_ident { + #(#field_names,)* + __phantom: std::marker::PhantomData, + }) } }; @@ -262,6 +270,7 @@ impl InternedStruct { self.the_ident_and_generics(); let db_dyn_ty = self.db_dyn_ty(); let jar_ty = self.jar_ty(); + let db_lt = self.named_db_lifetime(); let field_getters: Vec = self .all_fields() @@ -296,7 +305,7 @@ impl InternedStruct { let constructor_name = self.constructor_name(); let new_method: syn::ImplItemFn = parse_quote_spanned! { constructor_name.span() => #vis fn #constructor_name( - db: &#db_dyn_ty, + db: & #db_lt #db_lt #db_dyn_ty, #(#field_names: #field_tys,)* ) -> Self { let (jar, runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(db); @@ -384,7 +393,7 @@ impl InternedStruct { fn lookup_id(id: salsa::Id, db: & #db_lt DB) -> Self { let (jar, _) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(db); let ingredients = <#jar_ty as salsa::storage::HasIngredientsFor<#ident #type_generics>>::ingredient(jar); - Self(ingredients.interned_value(id), std::marker::PhantomData) + ingredients.interned_value(id) } } }) diff --git a/components/salsa-2022-macros/src/salsa_struct.rs b/components/salsa-2022-macros/src/salsa_struct.rs index 4baccb3e..3e2f1003 100644 --- a/components/salsa-2022-macros/src/salsa_struct.rs +++ b/components/salsa-2022-macros/src/salsa_struct.rs @@ -349,7 +349,7 @@ impl SalsaStruct { #(#attrs)* #[derive(Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #visibility struct #ident #generics ( - *const salsa::#module::ValueStruct < #config_ident >, + std::ptr::NonNull>, std::marker::PhantomData < & #lifetime salsa::#module::ValueStruct < #config_ident > > ); }) @@ -360,7 +360,9 @@ impl SalsaStruct { pub(crate) fn access_salsa_id_from_self(&self) -> syn::Expr { match self.the_struct_kind() { TheStructKind::Id => parse_quote!(self.0), - TheStructKind::Pointer(_) => parse_quote!(salsa::id::AsId::as_id(unsafe { &*self.0 })), + TheStructKind::Pointer(_) => { + parse_quote!(salsa::id::AsId::as_id(unsafe { self.0.as_ref() })) + } } } @@ -434,7 +436,7 @@ impl SalsaStruct { #where_clause { fn as_id(&self) -> salsa::Id { - salsa::id::AsId::as_id(unsafe { &*self.0 }) + salsa::id::AsId::as_id(unsafe { self.0.as_ref() }) } } diff --git a/components/salsa-2022-macros/src/tracked_fn.rs b/components/salsa-2022-macros/src/tracked_fn.rs index c1733b7f..907ee377 100644 --- a/components/salsa-2022-macros/src/tracked_fn.rs +++ b/components/salsa-2022-macros/src/tracked_fn.rs @@ -343,6 +343,16 @@ fn interned_configuration_impl( parse_quote!( impl salsa::interned::Configuration for #config_ty { type Data<#db_lt> = #intern_data_ty; + + type Struct<#db_lt> = & #db_lt salsa::interned::ValueStruct; + + unsafe fn struct_from_raw<'db>(ptr: std::ptr::NonNull>) -> Self::Struct<'db> { + unsafe { ptr.as_ref() } + } + + fn deref_struct<'db>(s: Self::Struct<'db>) -> &'db salsa::interned::ValueStruct { + s + } } ) } diff --git a/components/salsa-2022-macros/src/tracked_struct.rs b/components/salsa-2022-macros/src/tracked_struct.rs index 86007b9c..aa049dae 100644 --- a/components/salsa-2022-macros/src/tracked_struct.rs +++ b/components/salsa-2022-macros/src/tracked_struct.rs @@ -92,6 +92,7 @@ impl TrackedStruct { let field_tys: Vec<_> = self.all_fields().map(SalsaField::ty).collect(); let id_field_indices = self.id_field_indices(); let arity = self.all_field_count(); + let the_ident = self.the_ident(); let lt_db = &self.named_db_lifetime(); // Create the function body that will update the revisions for each field. @@ -132,8 +133,19 @@ impl TrackedStruct { parse_quote! { impl salsa::tracked_struct::Configuration for #config_ident { type Fields<#lt_db> = ( #(#field_tys,)* ); + + type Struct<#lt_db> = #the_ident<#lt_db>; + type Revisions = [salsa::Revision; #arity]; + unsafe fn struct_from_raw<'db>(ptr: std::ptr::NonNull>) -> Self::Struct<'db> { + #the_ident(ptr, std::marker::PhantomData) + } + + fn deref_struct<'db>(s: Self::Struct<'db>) -> &'db salsa::tracked_struct::ValueStruct { + unsafe { s.0.as_ref() } + } + #[allow(clippy::unused_unit)] fn id_fields(fields: &Self::Fields<'_>) -> impl std::hash::Hash { ( #( &fields.#id_field_indices ),* ) @@ -205,7 +217,7 @@ impl TrackedStruct { #field_vis fn #field_get_name(self, __db: & #lt_db #db_dyn_ty) -> & #lt_db #field_ty { let (_, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db); - let fields = unsafe { &*self.0 }.field(__runtime, #field_index); + let fields = unsafe { self.0.as_ref() }.field(__runtime, #field_index); &fields.#field_index } } @@ -214,7 +226,7 @@ impl TrackedStruct { #field_vis fn #field_get_name(self, __db: & #lt_db #db_dyn_ty) -> #field_ty { let (_, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db); - let fields = unsafe { &*self.0 }.field(__runtime, #field_index); + let fields = unsafe { self.0.as_ref() }.field(__runtime, #field_index); fields.#field_index.clone() } } @@ -232,11 +244,6 @@ impl TrackedStruct { let salsa_id = self.access_salsa_id_from_self(); - let ctor = match the_kind { - TheStructKind::Id => quote!(salsa::id::FromId::from_as_id(#data)), - TheStructKind::Pointer(_) => quote!(Self(#data, std::marker::PhantomData)), - }; - let lt_db = self.maybe_elided_db_lifetime(); parse_quote! { #[allow(dead_code, clippy::pedantic, clippy::complexity, clippy::style)] @@ -246,11 +253,10 @@ impl TrackedStruct { { let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db); let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< Self >>::ingredient(__jar); - let #data = __ingredients.0.new_struct( + __ingredients.0.new_struct( __runtime, (#(#field_names,)*), - ); - #ctor + ) } pub fn salsa_id(&self) -> salsa::Id { @@ -354,7 +360,7 @@ impl TrackedStruct { fn lookup_id(id: salsa::Id, db: & #db_lt DB) -> Self { let (jar, runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(db); let ingredients = <#jar_ty as salsa::storage::HasIngredientsFor<#ident #type_generics>>::ingredient(jar); - Self(ingredients.#tracked_struct_ingredient.lookup_struct(runtime, id), std::marker::PhantomData) + ingredients.#tracked_struct_ingredient.lookup_struct(runtime, id) } } }) diff --git a/components/salsa-2022/src/alloc.rs b/components/salsa-2022/src/alloc.rs index 633b52a7..557f429a 100644 --- a/components/salsa-2022/src/alloc.rs +++ b/components/salsa-2022/src/alloc.rs @@ -13,6 +13,18 @@ impl Alloc { data: unsafe { NonNull::new_unchecked(data) }, } } + + pub fn as_raw(&self) -> NonNull { + self.data + } + + pub unsafe fn as_ref(&self) -> &T { + unsafe { self.data.as_ref() } + } + + pub unsafe fn as_mut(&mut self) -> &mut T { + unsafe { self.data.as_mut() } + } } impl Drop for Alloc { @@ -23,20 +35,6 @@ impl Drop for Alloc { } } -impl std::ops::Deref for Alloc { - type Target = T; - - fn deref(&self) -> &Self::Target { - unsafe { self.data.as_ref() } - } -} - -impl std::ops::DerefMut for Alloc { - fn deref_mut(&mut self) -> &mut Self::Target { - unsafe { self.data.as_mut() } - } -} - unsafe impl Send for Alloc where T: Send {} unsafe impl Sync for Alloc where T: Sync {} diff --git a/components/salsa-2022/src/interned.rs b/components/salsa-2022/src/interned.rs index 987d3180..7839b64f 100644 --- a/components/salsa-2022/src/interned.rs +++ b/components/salsa-2022/src/interned.rs @@ -2,6 +2,7 @@ use crossbeam::atomic::AtomicCell; use std::fmt; use std::hash::Hash; use std::marker::PhantomData; +use std::ptr::NonNull; use crate::alloc::Alloc; use crate::durability::Durability; @@ -18,8 +19,28 @@ use super::ingredient::Ingredient; use super::routes::IngredientIndex; use super::Revision; -pub trait Configuration { +pub trait Configuration: Sized { type Data<'db>: InternedData; + type Struct<'db>: Copy; + + /// Create an end-user struct from the underlying raw pointer. + /// + /// This call is an "end-step" to the tracked struct lookup/creation + /// process in a given revision: it occurs only when the struct is newly + /// created or, if a struct is being reused, after we have updated its + /// fields (or confirmed it is green and no updates are required). + /// + /// # Unsafety + /// + /// Requires that `ptr` represents a "confirmed" value in this revision, + /// which means that it will remain valid and immutable for the remainder of this + /// revision, represented by the lifetime `'db`. + unsafe fn struct_from_raw<'db>(ptr: NonNull>) -> Self::Struct<'db>; + + /// Deref the struct to yield the underlying value struct. + /// Since we are still part of the `'db` lifetime in which the struct was created, + /// this deref is safe, and the value-struct fields are immutable and verified. + fn deref_struct<'db>(s: Self::Struct<'db>) -> &'db ValueStruct; } pub trait InternedData: Sized + Eq + Hash + Clone {} @@ -83,15 +104,11 @@ where } pub fn intern_id<'db>(&'db self, runtime: &'db Runtime, data: C::Data<'db>) -> crate::Id { - self.intern(runtime, data).as_id() + C::deref_struct(self.intern(runtime, data)).as_id() } /// Intern data to a unique reference. - pub fn intern<'db>( - &'db self, - runtime: &'db Runtime, - data: C::Data<'db>, - ) -> &'db ValueStruct { + pub fn intern<'db>(&'db self, runtime: &'db Runtime, data: C::Data<'db>) -> C::Struct<'db> { runtime.report_tracked_read( DependencyIndex::for_table(self.ingredient_index), Durability::MAX, @@ -126,27 +143,27 @@ where id: next_id, fields: internal_data, })); - // SAFETY: Items are only removed from the `value_map` with an `&mut self` reference. - let value_ref = unsafe { transmute_lifetime(self, &**value) }; + let value_raw = value.as_raw(); drop(value); entry.insert(next_id); - value_ref + // SAFETY: Items are only removed from the `value_map` with an `&mut self` reference. + unsafe { C::struct_from_raw(value_raw) } } } } - pub fn interned_value<'db>(&'db self, id: Id) -> &'db ValueStruct { + pub fn interned_value<'db>(&'db self, id: Id) -> C::Struct<'db> { let r = self.value_map.get(&id).unwrap(); // SAFETY: Items are only removed from the `value_map` with an `&mut self` reference. - unsafe { transmute_lifetime(self, &**r) } + unsafe { C::struct_from_raw(r.as_raw()) } } /// Lookup the data for an interned value based on its id. /// Rarely used since end-users generally carry a struct with a pointer directly /// to the interned item. pub fn data<'db>(&'db self, id: Id) -> &'db C::Data<'db> { - self.interned_value(id).data() + C::deref_struct(self.interned_value(id)).data() } /// Variant of `data` that takes a (unnecessary) database argument. diff --git a/components/salsa-2022/src/tracked_struct.rs b/components/salsa-2022/src/tracked_struct.rs index b5640e41..59197899 100644 --- a/components/salsa-2022/src/tracked_struct.rs +++ b/components/salsa-2022/src/tracked_struct.rs @@ -1,4 +1,4 @@ -use std::{fmt, hash::Hash}; +use std::{fmt, hash::Hash, ptr::NonNull}; use crossbeam::atomic::AtomicCell; use dashmap::mapref::entry::Entry; @@ -25,7 +25,7 @@ mod tracked_field; /// Trait that defines the key properties of a tracked struct. /// Implemented by the `#[salsa::tracked]` macro when applied /// to a struct. -pub trait Configuration { +pub trait Configuration: Sized { /// A (possibly empty) tuple of the fields for this struct. type Fields<'db>; @@ -35,6 +35,27 @@ pub trait Configuration { /// values have changed (or if the field is marked as `#[no_eq]`). type Revisions; + type Struct<'db>: Copy; + + /// Create an end-user struct from the underlying raw pointer. + /// + /// This call is an "end-step" to the tracked struct lookup/creation + /// process in a given revision: it occurs only when the struct is newly + /// created or, if a struct is being reused, after we have updated its + /// fields (or confirmed it is green and no updates are required). + /// + /// # Unsafety + /// + /// Requires that `ptr` represents a "confirmed" value in this revision, + /// which means that it will remain valid and immutable for the remainder of this + /// revision, represented by the lifetime `'db`. + unsafe fn struct_from_raw<'db>(ptr: NonNull>) -> Self::Struct<'db>; + + /// Deref the struct to yield the underlying value struct. + /// Since we are still part of the `'db` lifetime in which the struct was created, + /// this deref is safe, and the value-struct fields are immutable and verified. + fn deref_struct<'db>(s: Self::Struct<'db>) -> &'db ValueStruct; + fn id_fields(fields: &Self::Fields<'_>) -> impl Hash; /// Access the revision of a given value field. @@ -132,10 +153,6 @@ struct KeyStruct { disambiguator: Disambiguator, } -impl crate::interned::Configuration for KeyStruct { - type Data<'db> = KeyStruct; -} - // ANCHOR: ValueStruct #[derive(Debug)] pub struct ValueStruct @@ -262,7 +279,7 @@ where &'db self, runtime: &'db Runtime, fields: C::Fields<'db>, - ) -> &'db ValueStruct { + ) -> C::Struct<'db> { let data_hash = crate::hash::hash(&C::id_fields(&fields)); let (query_key, current_deps, disambiguator) = @@ -310,7 +327,7 @@ where // verified. Therefore, the durability ought not to have // changed (nor the field values, but the user could've // done something stupid, so we can't *assert* this is true). - assert!(r.durability == current_deps.durability); + assert!(C::deref_struct(r).durability == current_deps.durability); r } @@ -333,8 +350,8 @@ where if current_deps.durability < data.durability { data.revisions = C::new_revisions(current_revision); } - data.created_at = current_revision; data.durability = current_deps.durability; + data.created_at = current_revision; data_ref.freeze() } } @@ -347,7 +364,7 @@ where /// # Panics /// /// If the struct has not been created in this revision. - pub fn lookup_struct<'db>(&'db self, runtime: &'db Runtime, id: Id) -> &'db ValueStruct { + pub fn lookup_struct<'db>(&'db self, runtime: &'db Runtime, id: Id) -> C::Struct<'db> { self.struct_map.get(runtime, id) } diff --git a/components/salsa-2022/src/tracked_struct/struct_map.rs b/components/salsa-2022/src/tracked_struct/struct_map.rs index 55b6ed20..57c90c74 100644 --- a/components/salsa-2022/src/tracked_struct/struct_map.rs +++ b/components/salsa-2022/src/tracked_struct/struct_map.rs @@ -50,7 +50,7 @@ where /// to this struct creation were up-to-date, and therefore the field contents /// ought not to have changed (barring user error). Returns a shared reference /// because caller cannot safely modify fields at this point. - Current(&'db ValueStruct), + Current(C::Struct<'db>), } impl StructMap @@ -77,13 +77,14 @@ where /// /// * If value with same `value.id` is already present in the map. /// * If value not created in current revision. - pub fn insert<'db>(&'db self, runtime: &'db Runtime, value: ValueStruct) -> &ValueStruct { + pub fn insert<'db>(&'db self, runtime: &'db Runtime, value: ValueStruct) -> C::Struct<'db> { assert_eq!(value.created_at, runtime.current_revision()); + let id = value.id; let boxed_value = Alloc::new(value); - let pointer = std::ptr::addr_of!(*boxed_value); + let pointer = boxed_value.as_raw(); - let old_value = self.map.insert(boxed_value.id, boxed_value); + let old_value = self.map.insert(id, boxed_value); assert!(old_value.is_none()); // ...strictly speaking we probably need to abort here // Unsafety clause: @@ -92,12 +93,16 @@ where // the pointer is to the contents of the box, which have a stable // address. // * Values are only removed or altered when we have `&mut self`. - unsafe { transmute_lifetime(self, &*pointer) } + unsafe { C::struct_from_raw(pointer) } } pub fn validate<'db>(&'db self, runtime: &'db Runtime, id: Id) { let mut data = self.map.get_mut(&id).unwrap(); + // UNSAFE: We never permit `&`-access in the current revision until data.created_at + // has been updated to the current revision (which we check below). + let data = unsafe { data.as_mut() }; + // Never update a struct twice in the same revision. let current_revision = runtime.current_revision(); assert!(data.created_at < current_revision); @@ -117,6 +122,10 @@ where // Never update a struct twice in the same revision. let current_revision = runtime.current_revision(); + // UNSAFE: We never permit `&`-access in the current revision until data.created_at + // has been updated to the current revision (which we check below). + let data_ref = unsafe { data.as_mut() }; + // Subtle: it's possible that this struct was already validated // in this revision. What can happen (e.g., in the test // `test_run_5_then_20` in `specify_tracked_fn_in_rev_1_but_not_2.rs`) @@ -140,12 +149,12 @@ where // // For this reason, we just return `None` in this case, ensuring that the calling // code cannot violate that `&`-reference. - if data.created_at == current_revision { + if data_ref.created_at == current_revision { drop(data); - return Update::Current(&Self::get_from_map(&self.map, runtime, id)); + return Update::Current(Self::get_from_map(&self.map, runtime, id)); } - data.created_at = current_revision; + data_ref.created_at = current_revision; Update::Outdated(UpdateRef { guard: data }) } @@ -155,7 +164,7 @@ where /// /// * If the value is not present in the map. /// * If the value has not been updated in this revision. - pub fn get<'db>(&'db self, runtime: &'db Runtime, id: Id) -> &'db ValueStruct { + pub fn get<'db>(&'db self, runtime: &'db Runtime, id: Id) -> C::Struct<'db> { Self::get_from_map(&self.map, runtime, id) } @@ -169,14 +178,17 @@ where map: &'db FxDashMap>>, runtime: &'db Runtime, id: Id, - ) -> &'db ValueStruct { + ) -> C::Struct<'db> { let data = map.get(&id).unwrap(); - let data: &ValueStruct = &**data; + + // UNSAFE: We permit `&`-access in the current revision once data.created_at + // has been updated to the current revision (which we check below). + let data_ref: &ValueStruct = unsafe { data.as_ref() }; // Before we drop the lock, check that the value has // been updated in this revision. This is what allows us to return a `` let current_revision = runtime.current_revision(); - let created_at = data.created_at; + let created_at = data_ref.created_at; assert!( created_at == current_revision, "access to tracked struct from previous revision" @@ -187,7 +199,7 @@ where // * Value will not be updated again in this revision, // and revision will not change so long as runtime is shared // * We only remove values from the map when we have `&mut self` - unsafe { transmute_lifetime(map, data) } + unsafe { C::struct_from_raw(data.as_raw()) } } /// Remove the entry for `id` from the map. @@ -195,7 +207,8 @@ where /// NB. the data won't actually be freed until `drop_deleted_entries` is called. pub fn delete(&self, id: Id) -> Option { if let Some((_, data)) = self.map.remove(&id) { - let key = data.key; + // UNSAFE: The `key` field is immutable once `ValueStruct` is created. + let key = unsafe { data.as_ref() }.key; self.deleted_entries.push(data); Some(key) } else { @@ -219,7 +232,7 @@ where /// /// * If the value is not present in the map. /// * If the value has not been updated in this revision. - pub fn get<'db>(&'db self, runtime: &'db Runtime, id: Id) -> &'db ValueStruct { + pub fn get<'db>(&'db self, runtime: &'db Runtime, id: Id) -> C::Struct<'db> { StructMap::get_from_map(&self.map, runtime, id) } } @@ -239,13 +252,12 @@ where C: Configuration, { /// Finalize this update, freezing the value for the rest of the revision. - pub fn freeze(self) -> &'db ValueStruct { + pub fn freeze(self) -> C::Struct<'db> { // Unsafety clause: // // see `get` above - let data: &ValueStruct = &*self.guard; - let dummy: &'db () = &(); - unsafe { transmute_lifetime(dummy, data) } + let data = self.guard.as_raw(); + unsafe { C::struct_from_raw(data) } } } @@ -256,7 +268,7 @@ where type Target = ValueStruct; fn deref(&self) -> &Self::Target { - &self.guard + unsafe { self.guard.as_ref() } } } @@ -265,6 +277,6 @@ where C: Configuration, { fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.guard + unsafe { self.guard.as_mut() } } } diff --git a/components/salsa-2022/src/tracked_struct/tracked_field.rs b/components/salsa-2022/src/tracked_struct/tracked_field.rs index 03451fe4..5e7bc1b4 100644 --- a/components/salsa-2022/src/tracked_struct/tracked_field.rs +++ b/components/salsa-2022/src/tracked_struct/tracked_field.rs @@ -40,6 +40,7 @@ where /// The caller is responible for selecting the appropriate element. pub fn field<'db>(&'db self, runtime: &'db Runtime, id: Id) -> &'db C::Fields<'db> { let data = self.struct_map.get(runtime, id); + let data = C::deref_struct(data); let changed_at = C::revision(&data.revisions, self.field_index); @@ -78,6 +79,7 @@ where let runtime = db.runtime(); let id = input.key_index.unwrap(); let data = self.struct_map.get(runtime, id); + let data = C::deref_struct(data); let field_changed_at = C::revision(&data.revisions, self.field_index); field_changed_at > revision }