return a NonNull instead of a &'db

In old code, we converted to a `&'db` when
creating a new tracked struct or interning,
but this value in fact persisted beyond the end
of `'db` (i.e., into the new revision).

We now refactor so that we create the `Foo<'db>`
from a `NonNull<T>` instead of a `&'db T`, and
then only create safe references when users
access fields.

This makes miri happy.
This commit is contained in:
Niko Matsakis 2024-05-30 01:55:13 -04:00
parent 8c51f37292
commit 07d0ead9f4
9 changed files with 158 additions and 85 deletions

View file

@ -160,6 +160,7 @@ impl InternedStruct {
data_ident: &syn::Ident,
config_ident: &syn::Ident,
) -> syn::ItemImpl {
let the_ident = self.the_ident();
let lt_db = &self.named_db_lifetime();
let (_, _, _, type_generics, _) = self.the_ident_and_generics();
parse_quote_spanned!(
@ -167,6 +168,16 @@ impl InternedStruct {
impl salsa::interned::Configuration for #config_ident {
type Data<#lt_db> = #data_ident #type_generics;
type Struct<#lt_db> = #the_ident < #lt_db >;
unsafe fn struct_from_raw<'db>(ptr: std::ptr::NonNull<salsa::interned::ValueStruct<Self>>) -> Self::Struct<'db> {
#the_ident(ptr, std::marker::PhantomData)
}
fn deref_struct<'db>(s: Self::Struct<'db>) -> &'db salsa::interned::ValueStruct<Self> {
unsafe { s.0.as_ref() }
}
}
)
}
@ -191,7 +202,7 @@ impl InternedStruct {
let field_getters: Vec<syn::ImplItemFn> = self
.all_fields()
.map(|field| {
.map(|field: &crate::salsa_struct::SalsaField| {
let field_name = field.name();
let field_ty = field.ty();
let field_vis = field.vis();
@ -199,13 +210,13 @@ impl InternedStruct {
if field.is_clone_field() {
parse_quote_spanned! { field_get_name.span() =>
#field_vis fn #field_get_name(self, _db: & #db_lt #db_dyn_ty) -> #field_ty {
std::clone::Clone::clone(&unsafe { &*self.0 }.data().#field_name)
std::clone::Clone::clone(&unsafe { self.0.as_ref() }.data().#field_name)
}
}
} else {
parse_quote_spanned! { field_get_name.span() =>
#field_vis fn #field_get_name(self, _db: & #db_lt #db_dyn_ty) -> & #db_lt #field_ty {
&unsafe { &*self.0 }.data().#field_name
&unsafe { self.0.as_ref() }.data().#field_name
}
}
}
@ -218,18 +229,15 @@ impl InternedStruct {
let constructor_name = self.constructor_name();
let new_method: syn::ImplItemFn = parse_quote_spanned! { constructor_name.span() =>
#vis fn #constructor_name(
db: &#db_dyn_ty,
db: &#db_lt #db_dyn_ty,
#(#field_names: #field_tys,)*
) -> Self {
let (jar, runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(db);
let ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #the_ident #type_generics >>::ingredient(jar);
Self(
ingredients.intern(runtime, #data_ident {
#(#field_names,)*
__phantom: std::marker::PhantomData,
}),
std::marker::PhantomData,
)
ingredients.intern(runtime, #data_ident {
#(#field_names,)*
__phantom: std::marker::PhantomData,
})
}
};
@ -262,6 +270,7 @@ impl InternedStruct {
self.the_ident_and_generics();
let db_dyn_ty = self.db_dyn_ty();
let jar_ty = self.jar_ty();
let db_lt = self.named_db_lifetime();
let field_getters: Vec<syn::ImplItemFn> = self
.all_fields()
@ -296,7 +305,7 @@ impl InternedStruct {
let constructor_name = self.constructor_name();
let new_method: syn::ImplItemFn = parse_quote_spanned! { constructor_name.span() =>
#vis fn #constructor_name(
db: &#db_dyn_ty,
db: & #db_lt #db_lt #db_dyn_ty,
#(#field_names: #field_tys,)*
) -> Self {
let (jar, runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(db);
@ -384,7 +393,7 @@ impl InternedStruct {
fn lookup_id(id: salsa::Id, db: & #db_lt DB) -> Self {
let (jar, _) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(db);
let ingredients = <#jar_ty as salsa::storage::HasIngredientsFor<#ident #type_generics>>::ingredient(jar);
Self(ingredients.interned_value(id), std::marker::PhantomData)
ingredients.interned_value(id)
}
}
})

View file

@ -349,7 +349,7 @@ impl<A: AllowedOptions> SalsaStruct<A> {
#(#attrs)*
#[derive(Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#visibility struct #ident #generics (
*const salsa::#module::ValueStruct < #config_ident >,
std::ptr::NonNull<salsa::#module::ValueStruct < #config_ident >>,
std::marker::PhantomData < & #lifetime salsa::#module::ValueStruct < #config_ident > >
);
})
@ -360,7 +360,9 @@ impl<A: AllowedOptions> SalsaStruct<A> {
pub(crate) fn access_salsa_id_from_self(&self) -> syn::Expr {
match self.the_struct_kind() {
TheStructKind::Id => parse_quote!(self.0),
TheStructKind::Pointer(_) => parse_quote!(salsa::id::AsId::as_id(unsafe { &*self.0 })),
TheStructKind::Pointer(_) => {
parse_quote!(salsa::id::AsId::as_id(unsafe { self.0.as_ref() }))
}
}
}
@ -434,7 +436,7 @@ impl<A: AllowedOptions> SalsaStruct<A> {
#where_clause
{
fn as_id(&self) -> salsa::Id {
salsa::id::AsId::as_id(unsafe { &*self.0 })
salsa::id::AsId::as_id(unsafe { self.0.as_ref() })
}
}

View file

@ -343,6 +343,16 @@ fn interned_configuration_impl(
parse_quote!(
impl salsa::interned::Configuration for #config_ty {
type Data<#db_lt> = #intern_data_ty;
type Struct<#db_lt> = & #db_lt salsa::interned::ValueStruct<Self>;
unsafe fn struct_from_raw<'db>(ptr: std::ptr::NonNull<salsa::interned::ValueStruct<Self>>) -> Self::Struct<'db> {
unsafe { ptr.as_ref() }
}
fn deref_struct<'db>(s: Self::Struct<'db>) -> &'db salsa::interned::ValueStruct<Self> {
s
}
}
)
}

View file

@ -92,6 +92,7 @@ impl TrackedStruct {
let field_tys: Vec<_> = self.all_fields().map(SalsaField::ty).collect();
let id_field_indices = self.id_field_indices();
let arity = self.all_field_count();
let the_ident = self.the_ident();
let lt_db = &self.named_db_lifetime();
// Create the function body that will update the revisions for each field.
@ -132,8 +133,19 @@ impl TrackedStruct {
parse_quote! {
impl salsa::tracked_struct::Configuration for #config_ident {
type Fields<#lt_db> = ( #(#field_tys,)* );
type Struct<#lt_db> = #the_ident<#lt_db>;
type Revisions = [salsa::Revision; #arity];
unsafe fn struct_from_raw<'db>(ptr: std::ptr::NonNull<salsa::tracked_struct::ValueStruct<Self>>) -> Self::Struct<'db> {
#the_ident(ptr, std::marker::PhantomData)
}
fn deref_struct<'db>(s: Self::Struct<'db>) -> &'db salsa::tracked_struct::ValueStruct<Self> {
unsafe { s.0.as_ref() }
}
#[allow(clippy::unused_unit)]
fn id_fields(fields: &Self::Fields<'_>) -> impl std::hash::Hash {
( #( &fields.#id_field_indices ),* )
@ -205,7 +217,7 @@ impl TrackedStruct {
#field_vis fn #field_get_name(self, __db: & #lt_db #db_dyn_ty) -> & #lt_db #field_ty
{
let (_, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db);
let fields = unsafe { &*self.0 }.field(__runtime, #field_index);
let fields = unsafe { self.0.as_ref() }.field(__runtime, #field_index);
&fields.#field_index
}
}
@ -214,7 +226,7 @@ impl TrackedStruct {
#field_vis fn #field_get_name(self, __db: & #lt_db #db_dyn_ty) -> #field_ty
{
let (_, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db);
let fields = unsafe { &*self.0 }.field(__runtime, #field_index);
let fields = unsafe { self.0.as_ref() }.field(__runtime, #field_index);
fields.#field_index.clone()
}
}
@ -232,11 +244,6 @@ impl TrackedStruct {
let salsa_id = self.access_salsa_id_from_self();
let ctor = match the_kind {
TheStructKind::Id => quote!(salsa::id::FromId::from_as_id(#data)),
TheStructKind::Pointer(_) => quote!(Self(#data, std::marker::PhantomData)),
};
let lt_db = self.maybe_elided_db_lifetime();
parse_quote! {
#[allow(dead_code, clippy::pedantic, clippy::complexity, clippy::style)]
@ -246,11 +253,10 @@ impl TrackedStruct {
{
let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db);
let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< Self >>::ingredient(__jar);
let #data = __ingredients.0.new_struct(
__ingredients.0.new_struct(
__runtime,
(#(#field_names,)*),
);
#ctor
)
}
pub fn salsa_id(&self) -> salsa::Id {
@ -354,7 +360,7 @@ impl TrackedStruct {
fn lookup_id(id: salsa::Id, db: & #db_lt DB) -> Self {
let (jar, runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(db);
let ingredients = <#jar_ty as salsa::storage::HasIngredientsFor<#ident #type_generics>>::ingredient(jar);
Self(ingredients.#tracked_struct_ingredient.lookup_struct(runtime, id), std::marker::PhantomData)
ingredients.#tracked_struct_ingredient.lookup_struct(runtime, id)
}
}
})

View file

@ -13,6 +13,18 @@ impl<T> Alloc<T> {
data: unsafe { NonNull::new_unchecked(data) },
}
}
pub fn as_raw(&self) -> NonNull<T> {
self.data
}
pub unsafe fn as_ref(&self) -> &T {
unsafe { self.data.as_ref() }
}
pub unsafe fn as_mut(&mut self) -> &mut T {
unsafe { self.data.as_mut() }
}
}
impl<T> Drop for Alloc<T> {
@ -23,20 +35,6 @@ impl<T> Drop for Alloc<T> {
}
}
impl<T> std::ops::Deref for Alloc<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
unsafe { self.data.as_ref() }
}
}
impl<T> std::ops::DerefMut for Alloc<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
unsafe { self.data.as_mut() }
}
}
unsafe impl<T> Send for Alloc<T> where T: Send {}
unsafe impl<T> Sync for Alloc<T> where T: Sync {}

View file

@ -2,6 +2,7 @@ use crossbeam::atomic::AtomicCell;
use std::fmt;
use std::hash::Hash;
use std::marker::PhantomData;
use std::ptr::NonNull;
use crate::alloc::Alloc;
use crate::durability::Durability;
@ -18,8 +19,28 @@ use super::ingredient::Ingredient;
use super::routes::IngredientIndex;
use super::Revision;
pub trait Configuration {
pub trait Configuration: Sized {
type Data<'db>: InternedData;
type Struct<'db>: Copy;
/// Create an end-user struct from the underlying raw pointer.
///
/// This call is an "end-step" to the tracked struct lookup/creation
/// process in a given revision: it occurs only when the struct is newly
/// created or, if a struct is being reused, after we have updated its
/// fields (or confirmed it is green and no updates are required).
///
/// # Unsafety
///
/// Requires that `ptr` represents a "confirmed" value in this revision,
/// which means that it will remain valid and immutable for the remainder of this
/// revision, represented by the lifetime `'db`.
unsafe fn struct_from_raw<'db>(ptr: NonNull<ValueStruct<Self>>) -> Self::Struct<'db>;
/// Deref the struct to yield the underlying value struct.
/// Since we are still part of the `'db` lifetime in which the struct was created,
/// this deref is safe, and the value-struct fields are immutable and verified.
fn deref_struct<'db>(s: Self::Struct<'db>) -> &'db ValueStruct<Self>;
}
pub trait InternedData: Sized + Eq + Hash + Clone {}
@ -83,15 +104,11 @@ where
}
pub fn intern_id<'db>(&'db self, runtime: &'db Runtime, data: C::Data<'db>) -> crate::Id {
self.intern(runtime, data).as_id()
C::deref_struct(self.intern(runtime, data)).as_id()
}
/// Intern data to a unique reference.
pub fn intern<'db>(
&'db self,
runtime: &'db Runtime,
data: C::Data<'db>,
) -> &'db ValueStruct<C> {
pub fn intern<'db>(&'db self, runtime: &'db Runtime, data: C::Data<'db>) -> C::Struct<'db> {
runtime.report_tracked_read(
DependencyIndex::for_table(self.ingredient_index),
Durability::MAX,
@ -126,27 +143,27 @@ where
id: next_id,
fields: internal_data,
}));
// SAFETY: Items are only removed from the `value_map` with an `&mut self` reference.
let value_ref = unsafe { transmute_lifetime(self, &**value) };
let value_raw = value.as_raw();
drop(value);
entry.insert(next_id);
value_ref
// SAFETY: Items are only removed from the `value_map` with an `&mut self` reference.
unsafe { C::struct_from_raw(value_raw) }
}
}
}
pub fn interned_value<'db>(&'db self, id: Id) -> &'db ValueStruct<C> {
pub fn interned_value<'db>(&'db self, id: Id) -> C::Struct<'db> {
let r = self.value_map.get(&id).unwrap();
// SAFETY: Items are only removed from the `value_map` with an `&mut self` reference.
unsafe { transmute_lifetime(self, &**r) }
unsafe { C::struct_from_raw(r.as_raw()) }
}
/// Lookup the data for an interned value based on its id.
/// Rarely used since end-users generally carry a struct with a pointer directly
/// to the interned item.
pub fn data<'db>(&'db self, id: Id) -> &'db C::Data<'db> {
self.interned_value(id).data()
C::deref_struct(self.interned_value(id)).data()
}
/// Variant of `data` that takes a (unnecessary) database argument.

View file

@ -1,4 +1,4 @@
use std::{fmt, hash::Hash};
use std::{fmt, hash::Hash, ptr::NonNull};
use crossbeam::atomic::AtomicCell;
use dashmap::mapref::entry::Entry;
@ -25,7 +25,7 @@ mod tracked_field;
/// Trait that defines the key properties of a tracked struct.
/// Implemented by the `#[salsa::tracked]` macro when applied
/// to a struct.
pub trait Configuration {
pub trait Configuration: Sized {
/// A (possibly empty) tuple of the fields for this struct.
type Fields<'db>;
@ -35,6 +35,27 @@ pub trait Configuration {
/// values have changed (or if the field is marked as `#[no_eq]`).
type Revisions;
type Struct<'db>: Copy;
/// Create an end-user struct from the underlying raw pointer.
///
/// This call is an "end-step" to the tracked struct lookup/creation
/// process in a given revision: it occurs only when the struct is newly
/// created or, if a struct is being reused, after we have updated its
/// fields (or confirmed it is green and no updates are required).
///
/// # Unsafety
///
/// Requires that `ptr` represents a "confirmed" value in this revision,
/// which means that it will remain valid and immutable for the remainder of this
/// revision, represented by the lifetime `'db`.
unsafe fn struct_from_raw<'db>(ptr: NonNull<ValueStruct<Self>>) -> Self::Struct<'db>;
/// Deref the struct to yield the underlying value struct.
/// Since we are still part of the `'db` lifetime in which the struct was created,
/// this deref is safe, and the value-struct fields are immutable and verified.
fn deref_struct<'db>(s: Self::Struct<'db>) -> &'db ValueStruct<Self>;
fn id_fields(fields: &Self::Fields<'_>) -> impl Hash;
/// Access the revision of a given value field.
@ -132,10 +153,6 @@ struct KeyStruct {
disambiguator: Disambiguator,
}
impl crate::interned::Configuration for KeyStruct {
type Data<'db> = KeyStruct;
}
// ANCHOR: ValueStruct
#[derive(Debug)]
pub struct ValueStruct<C>
@ -262,7 +279,7 @@ where
&'db self,
runtime: &'db Runtime,
fields: C::Fields<'db>,
) -> &'db ValueStruct<C> {
) -> C::Struct<'db> {
let data_hash = crate::hash::hash(&C::id_fields(&fields));
let (query_key, current_deps, disambiguator) =
@ -310,7 +327,7 @@ where
// verified. Therefore, the durability ought not to have
// changed (nor the field values, but the user could've
// done something stupid, so we can't *assert* this is true).
assert!(r.durability == current_deps.durability);
assert!(C::deref_struct(r).durability == current_deps.durability);
r
}
@ -333,8 +350,8 @@ where
if current_deps.durability < data.durability {
data.revisions = C::new_revisions(current_revision);
}
data.created_at = current_revision;
data.durability = current_deps.durability;
data.created_at = current_revision;
data_ref.freeze()
}
}
@ -347,7 +364,7 @@ where
/// # Panics
///
/// If the struct has not been created in this revision.
pub fn lookup_struct<'db>(&'db self, runtime: &'db Runtime, id: Id) -> &'db ValueStruct<C> {
pub fn lookup_struct<'db>(&'db self, runtime: &'db Runtime, id: Id) -> C::Struct<'db> {
self.struct_map.get(runtime, id)
}

View file

@ -50,7 +50,7 @@ where
/// to this struct creation were up-to-date, and therefore the field contents
/// ought not to have changed (barring user error). Returns a shared reference
/// because caller cannot safely modify fields at this point.
Current(&'db ValueStruct<C>),
Current(C::Struct<'db>),
}
impl<C> StructMap<C>
@ -77,13 +77,14 @@ where
///
/// * If value with same `value.id` is already present in the map.
/// * If value not created in current revision.
pub fn insert<'db>(&'db self, runtime: &'db Runtime, value: ValueStruct<C>) -> &ValueStruct<C> {
pub fn insert<'db>(&'db self, runtime: &'db Runtime, value: ValueStruct<C>) -> C::Struct<'db> {
assert_eq!(value.created_at, runtime.current_revision());
let id = value.id;
let boxed_value = Alloc::new(value);
let pointer = std::ptr::addr_of!(*boxed_value);
let pointer = boxed_value.as_raw();
let old_value = self.map.insert(boxed_value.id, boxed_value);
let old_value = self.map.insert(id, boxed_value);
assert!(old_value.is_none()); // ...strictly speaking we probably need to abort here
// Unsafety clause:
@ -92,12 +93,16 @@ where
// the pointer is to the contents of the box, which have a stable
// address.
// * Values are only removed or altered when we have `&mut self`.
unsafe { transmute_lifetime(self, &*pointer) }
unsafe { C::struct_from_raw(pointer) }
}
pub fn validate<'db>(&'db self, runtime: &'db Runtime, id: Id) {
let mut data = self.map.get_mut(&id).unwrap();
// UNSAFE: We never permit `&`-access in the current revision until data.created_at
// has been updated to the current revision (which we check below).
let data = unsafe { data.as_mut() };
// Never update a struct twice in the same revision.
let current_revision = runtime.current_revision();
assert!(data.created_at < current_revision);
@ -117,6 +122,10 @@ where
// Never update a struct twice in the same revision.
let current_revision = runtime.current_revision();
// UNSAFE: We never permit `&`-access in the current revision until data.created_at
// has been updated to the current revision (which we check below).
let data_ref = unsafe { data.as_mut() };
// Subtle: it's possible that this struct was already validated
// in this revision. What can happen (e.g., in the test
// `test_run_5_then_20` in `specify_tracked_fn_in_rev_1_but_not_2.rs`)
@ -140,12 +149,12 @@ where
//
// For this reason, we just return `None` in this case, ensuring that the calling
// code cannot violate that `&`-reference.
if data.created_at == current_revision {
if data_ref.created_at == current_revision {
drop(data);
return Update::Current(&Self::get_from_map(&self.map, runtime, id));
return Update::Current(Self::get_from_map(&self.map, runtime, id));
}
data.created_at = current_revision;
data_ref.created_at = current_revision;
Update::Outdated(UpdateRef { guard: data })
}
@ -155,7 +164,7 @@ where
///
/// * If the value is not present in the map.
/// * If the value has not been updated in this revision.
pub fn get<'db>(&'db self, runtime: &'db Runtime, id: Id) -> &'db ValueStruct<C> {
pub fn get<'db>(&'db self, runtime: &'db Runtime, id: Id) -> C::Struct<'db> {
Self::get_from_map(&self.map, runtime, id)
}
@ -169,14 +178,17 @@ where
map: &'db FxDashMap<Id, Alloc<ValueStruct<C>>>,
runtime: &'db Runtime,
id: Id,
) -> &'db ValueStruct<C> {
) -> C::Struct<'db> {
let data = map.get(&id).unwrap();
let data: &ValueStruct<C> = &**data;
// UNSAFE: We permit `&`-access in the current revision once data.created_at
// has been updated to the current revision (which we check below).
let data_ref: &ValueStruct<C> = unsafe { data.as_ref() };
// Before we drop the lock, check that the value has
// been updated in this revision. This is what allows us to return a ``
let current_revision = runtime.current_revision();
let created_at = data.created_at;
let created_at = data_ref.created_at;
assert!(
created_at == current_revision,
"access to tracked struct from previous revision"
@ -187,7 +199,7 @@ where
// * Value will not be updated again in this revision,
// and revision will not change so long as runtime is shared
// * We only remove values from the map when we have `&mut self`
unsafe { transmute_lifetime(map, data) }
unsafe { C::struct_from_raw(data.as_raw()) }
}
/// Remove the entry for `id` from the map.
@ -195,7 +207,8 @@ where
/// NB. the data won't actually be freed until `drop_deleted_entries` is called.
pub fn delete(&self, id: Id) -> Option<KeyStruct> {
if let Some((_, data)) = self.map.remove(&id) {
let key = data.key;
// UNSAFE: The `key` field is immutable once `ValueStruct` is created.
let key = unsafe { data.as_ref() }.key;
self.deleted_entries.push(data);
Some(key)
} else {
@ -219,7 +232,7 @@ where
///
/// * If the value is not present in the map.
/// * If the value has not been updated in this revision.
pub fn get<'db>(&'db self, runtime: &'db Runtime, id: Id) -> &'db ValueStruct<C> {
pub fn get<'db>(&'db self, runtime: &'db Runtime, id: Id) -> C::Struct<'db> {
StructMap::get_from_map(&self.map, runtime, id)
}
}
@ -239,13 +252,12 @@ where
C: Configuration,
{
/// Finalize this update, freezing the value for the rest of the revision.
pub fn freeze(self) -> &'db ValueStruct<C> {
pub fn freeze(self) -> C::Struct<'db> {
// Unsafety clause:
//
// see `get` above
let data: &ValueStruct<C> = &*self.guard;
let dummy: &'db () = &();
unsafe { transmute_lifetime(dummy, data) }
let data = self.guard.as_raw();
unsafe { C::struct_from_raw(data) }
}
}
@ -256,7 +268,7 @@ where
type Target = ValueStruct<C>;
fn deref(&self) -> &Self::Target {
&self.guard
unsafe { self.guard.as_ref() }
}
}
@ -265,6 +277,6 @@ where
C: Configuration,
{
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.guard
unsafe { self.guard.as_mut() }
}
}

View file

@ -40,6 +40,7 @@ where
/// The caller is responible for selecting the appropriate element.
pub fn field<'db>(&'db self, runtime: &'db Runtime, id: Id) -> &'db C::Fields<'db> {
let data = self.struct_map.get(runtime, id);
let data = C::deref_struct(data);
let changed_at = C::revision(&data.revisions, self.field_index);
@ -78,6 +79,7 @@ where
let runtime = db.runtime();
let id = input.key_index.unwrap();
let data = self.struct_map.get(runtime, id);
let data = C::deref_struct(data);
let field_changed_at = C::revision(&data.revisions, self.field_index);
field_changed_at > revision
}