mirror of
https://github.com/salsa-rs/salsa.git
synced 2025-02-02 09:46:06 +00:00
create a struct_map
that encapsulates access
The internal API is now based around providing references to the `TrackedStructValue`. Documenting the invariants led to one interesting case, which is that we sometimes verify a tracked struct as not having changed (and even create `&`-ref to it!) but then re-execute the function around it. We now guarantee that, in this case, the data does not change, even if it has leaked values. This is required to ensure soundness. Add a test case about it.
This commit is contained in:
parent
20cb307301
commit
ea1d452143
4 changed files with 449 additions and 116 deletions
|
@ -1,23 +1,21 @@
|
||||||
use std::{fmt, hash::Hash, sync::Arc};
|
use std::{fmt, hash::Hash};
|
||||||
|
|
||||||
use crossbeam::queue::SegQueue;
|
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
cycle::CycleRecoveryStrategy,
|
cycle::CycleRecoveryStrategy,
|
||||||
hash::FxDashMap,
|
|
||||||
id::AsId,
|
id::AsId,
|
||||||
ingredient::{fmt_index, Ingredient, IngredientRequiresReset},
|
ingredient::{fmt_index, Ingredient, IngredientRequiresReset},
|
||||||
ingredient_list::IngredientList,
|
ingredient_list::IngredientList,
|
||||||
interned::{InternedId, InternedIngredient},
|
interned::{InternedId, InternedIngredient},
|
||||||
key::{DatabaseKeyIndex, DependencyIndex},
|
key::{DatabaseKeyIndex, DependencyIndex},
|
||||||
plumbing::transmute_lifetime,
|
|
||||||
runtime::{local_state::QueryOrigin, Runtime},
|
runtime::{local_state::QueryOrigin, Runtime},
|
||||||
salsa_struct::SalsaStructInDb,
|
salsa_struct::SalsaStructInDb,
|
||||||
Database, Durability, Event, IngredientIndex, Revision,
|
Database, Durability, Event, IngredientIndex, Revision,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use self::struct_map::{StructMap, Update};
|
||||||
pub use self::tracked_field::TrackedFieldIngredient;
|
pub use self::tracked_field::TrackedFieldIngredient;
|
||||||
|
|
||||||
|
mod struct_map;
|
||||||
mod tracked_field;
|
mod tracked_field;
|
||||||
|
|
||||||
// ANCHOR: Configuration
|
// ANCHOR: Configuration
|
||||||
|
@ -97,7 +95,7 @@ where
|
||||||
{
|
{
|
||||||
interned: InternedIngredient<C::Id, TrackedStructKey>,
|
interned: InternedIngredient<C::Id, TrackedStructKey>,
|
||||||
|
|
||||||
entity_data: Arc<FxDashMap<C::Id, Box<TrackedStructValue<C>>>>,
|
struct_map: struct_map::StructMap<C>,
|
||||||
|
|
||||||
/// A list of each tracked function `f` whose key is this
|
/// A list of each tracked function `f` whose key is this
|
||||||
/// tracked struct.
|
/// tracked struct.
|
||||||
|
@ -107,13 +105,6 @@ where
|
||||||
/// so they can remove any data tied to that instance.
|
/// so they can remove any data tied to that instance.
|
||||||
dependent_fns: IngredientList,
|
dependent_fns: IngredientList,
|
||||||
|
|
||||||
/// When specific entities are deleted, their data is added
|
|
||||||
/// to this vector rather than being immediately freed. This is because we may` have
|
|
||||||
/// references to that data floating about that are tied to the lifetime of some
|
|
||||||
/// `&db` reference. This queue itself is not freed until we have an `&mut db` reference,
|
|
||||||
/// guaranteeing that there are no more references to it.
|
|
||||||
deleted_entries: SegQueue<Box<TrackedStructValue<C>>>,
|
|
||||||
|
|
||||||
debug_name: &'static str,
|
debug_name: &'static str,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -159,16 +150,6 @@ where
|
||||||
}
|
}
|
||||||
// ANCHOR_END: TrackedStructValue
|
// ANCHOR_END: TrackedStructValue
|
||||||
|
|
||||||
impl<C> TrackedStructValue<C>
|
|
||||||
where
|
|
||||||
C: Configuration,
|
|
||||||
{
|
|
||||||
/// The id of this struct in the ingredient.
|
|
||||||
pub fn id(&self) -> C::Id {
|
|
||||||
self.id
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Copy, Clone)]
|
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Copy, Clone)]
|
||||||
pub struct Disambiguator(pub u32);
|
pub struct Disambiguator(pub u32);
|
||||||
|
|
||||||
|
@ -179,9 +160,8 @@ where
|
||||||
pub fn new(index: IngredientIndex, debug_name: &'static str) -> Self {
|
pub fn new(index: IngredientIndex, debug_name: &'static str) -> Self {
|
||||||
Self {
|
Self {
|
||||||
interned: InternedIngredient::new(index, debug_name),
|
interned: InternedIngredient::new(index, debug_name),
|
||||||
entity_data: Default::default(),
|
struct_map: StructMap::new(),
|
||||||
dependent_fns: IngredientList::new(),
|
dependent_fns: IngredientList::new(),
|
||||||
deleted_entries: SegQueue::default(),
|
|
||||||
debug_name,
|
debug_name,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -195,7 +175,7 @@ where
|
||||||
TrackedFieldIngredient {
|
TrackedFieldIngredient {
|
||||||
ingredient_index: field_ingredient_index,
|
ingredient_index: field_ingredient_index,
|
||||||
field_index,
|
field_index,
|
||||||
entity_data: self.entity_data.clone(),
|
struct_map: self.struct_map.view(),
|
||||||
struct_debug_name: self.debug_name,
|
struct_debug_name: self.debug_name,
|
||||||
field_debug_name,
|
field_debug_name,
|
||||||
}
|
}
|
||||||
|
@ -208,7 +188,11 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new_struct(&self, runtime: &Runtime, fields: C::Fields) -> &TrackedStructValue<C> {
|
pub fn new_struct<'db>(
|
||||||
|
&'db self,
|
||||||
|
runtime: &'db Runtime,
|
||||||
|
fields: C::Fields,
|
||||||
|
) -> &'db TrackedStructValue<C> {
|
||||||
let data_hash = crate::hash::hash(&C::id_fields(&fields));
|
let data_hash = crate::hash::hash(&C::id_fields(&fields));
|
||||||
|
|
||||||
let (query_key, current_deps, disambiguator) = runtime.disambiguate_entity(
|
let (query_key, current_deps, disambiguator) = runtime.disambiguate_entity(
|
||||||
|
@ -225,28 +209,32 @@ where
|
||||||
let (id, new_id) = self.interned.intern_full(runtime, entity_key);
|
let (id, new_id) = self.interned.intern_full(runtime, entity_key);
|
||||||
runtime.add_output(self.database_key_index(id).into());
|
runtime.add_output(self.database_key_index(id).into());
|
||||||
|
|
||||||
let pointer: *const TrackedStructValue<C>;
|
|
||||||
let current_revision = runtime.current_revision();
|
let current_revision = runtime.current_revision();
|
||||||
if new_id {
|
if new_id {
|
||||||
let data = Box::new(TrackedStructValue {
|
self.struct_map.insert(
|
||||||
|
runtime,
|
||||||
|
TrackedStructValue {
|
||||||
id,
|
id,
|
||||||
created_at: current_revision,
|
created_at: current_revision,
|
||||||
durability: current_deps.durability,
|
durability: current_deps.durability,
|
||||||
fields,
|
fields,
|
||||||
revisions: C::new_revisions(current_deps.changed_at),
|
revisions: C::new_revisions(current_deps.changed_at),
|
||||||
});
|
},
|
||||||
|
)
|
||||||
// Keep a pointer into the box for later
|
|
||||||
pointer = &*data;
|
|
||||||
|
|
||||||
let old_value = self.entity_data.insert(id, data);
|
|
||||||
assert!(old_value.is_none());
|
|
||||||
} else {
|
} else {
|
||||||
let mut data = self.entity_data.get_mut(&id).unwrap();
|
match self.struct_map.update(runtime, id) {
|
||||||
let data = &mut *data;
|
Update::Current(r) => {
|
||||||
|
// All inputs up to this point were previously
|
||||||
|
// observed to be green and this struct was already
|
||||||
|
// verified. Therefore, the durability ought not to have
|
||||||
|
// changed (nor the field values, but the user could've
|
||||||
|
// done something stupid, so we can't *assert* this is true).
|
||||||
|
assert!(r.durability == current_deps.durability);
|
||||||
|
|
||||||
// Keep a pointer into the box for later
|
r
|
||||||
pointer = &**data;
|
}
|
||||||
|
Update::Outdated(mut data_ref) => {
|
||||||
|
let data = &mut *data_ref;
|
||||||
|
|
||||||
// SAFETY: We assert that the pointer to `data.revisions`
|
// SAFETY: We assert that the pointer to `data.revisions`
|
||||||
// is a pointer into the database referencing a value
|
// is a pointer into the database referencing a value
|
||||||
|
@ -266,15 +254,11 @@ where
|
||||||
}
|
}
|
||||||
data.created_at = current_revision;
|
data.created_at = current_revision;
|
||||||
data.durability = current_deps.durability;
|
data.durability = current_deps.durability;
|
||||||
}
|
|
||||||
|
|
||||||
// Unsafety clause:
|
data_ref.freeze()
|
||||||
//
|
}
|
||||||
// * The box is owned by self and, although the box has been moved,
|
}
|
||||||
// the pointer is to the contents of the box, which have a stable
|
}
|
||||||
// address.
|
|
||||||
// * Values are only removed or altered when we have `&mut self`.
|
|
||||||
unsafe { transmute_lifetime(self, &*pointer) }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Deletes the given entities. This is used after a query `Q` executes and we can compare
|
/// Deletes the given entities. This is used after a query `Q` executes and we can compare
|
||||||
|
@ -296,9 +280,7 @@ where
|
||||||
});
|
});
|
||||||
|
|
||||||
self.interned.delete_index(id);
|
self.interned.delete_index(id);
|
||||||
if let Some((_, data)) = self.entity_data.remove(&id) {
|
self.struct_map.delete(id);
|
||||||
self.deleted_entries.push(data);
|
|
||||||
}
|
|
||||||
|
|
||||||
for dependent_fn in self.dependent_fns.iter() {
|
for dependent_fn in self.dependent_fns.iter() {
|
||||||
db.salsa_struct_deleted(dependent_fn, id.as_id());
|
db.salsa_struct_deleted(dependent_fn, id.as_id());
|
||||||
|
@ -334,16 +316,16 @@ where
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
fn mark_validated_output(
|
fn mark_validated_output<'db>(
|
||||||
&self,
|
&'db self,
|
||||||
db: &DB,
|
db: &'db DB,
|
||||||
_executor: DatabaseKeyIndex,
|
_executor: DatabaseKeyIndex,
|
||||||
output_key: Option<crate::Id>,
|
output_key: Option<crate::Id>,
|
||||||
) {
|
) {
|
||||||
|
let runtime = db.runtime();
|
||||||
let output_key = output_key.unwrap();
|
let output_key = output_key.unwrap();
|
||||||
let output_key: C::Id = <C::Id>::from_id(output_key);
|
let output_key: C::Id = <C::Id>::from_id(output_key);
|
||||||
let mut entity = self.entity_data.get_mut(&output_key).unwrap();
|
self.struct_map.validate(runtime, output_key);
|
||||||
entity.created_at = db.runtime().current_revision();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn remove_stale_output(
|
fn remove_stale_output(
|
||||||
|
@ -362,7 +344,7 @@ where
|
||||||
|
|
||||||
fn reset_for_new_revision(&mut self) {
|
fn reset_for_new_revision(&mut self) {
|
||||||
self.interned.clear_deleted_indices();
|
self.interned.clear_deleted_indices();
|
||||||
std::mem::take(&mut self.deleted_entries);
|
self.struct_map.drop_deleted_entries();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn salsa_struct_deleted(&self, _db: &DB, _id: crate::Id) {
|
fn salsa_struct_deleted(&self, _db: &DB, _id: crate::Id) {
|
||||||
|
@ -380,3 +362,13 @@ where
|
||||||
{
|
{
|
||||||
const RESET_ON_NEW_REVISION: bool = true;
|
const RESET_ON_NEW_REVISION: bool = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<C> TrackedStructValue<C>
|
||||||
|
where
|
||||||
|
C: Configuration,
|
||||||
|
{
|
||||||
|
/// The id of this struct in the ingredient.
|
||||||
|
pub fn id(&self) -> C::Id {
|
||||||
|
self.id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
259
components/salsa-2022/src/tracked_struct/struct_map.rs
Normal file
259
components/salsa-2022/src/tracked_struct/struct_map.rs
Normal file
|
@ -0,0 +1,259 @@
|
||||||
|
use std::{
|
||||||
|
ops::{Deref, DerefMut},
|
||||||
|
sync::Arc,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crossbeam::queue::SegQueue;
|
||||||
|
use dashmap::mapref::one::RefMut;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
hash::{FxDashMap, FxHasher},
|
||||||
|
plumbing::transmute_lifetime,
|
||||||
|
Runtime,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{Configuration, TrackedStructValue};
|
||||||
|
|
||||||
|
pub(crate) struct StructMap<C>
|
||||||
|
where
|
||||||
|
C: Configuration,
|
||||||
|
{
|
||||||
|
map: Arc<FxDashMap<C::Id, Box<TrackedStructValue<C>>>>,
|
||||||
|
|
||||||
|
/// When specific entities are deleted, their data is added
|
||||||
|
/// to this vector rather than being immediately freed. This is because we may` have
|
||||||
|
/// references to that data floating about that are tied to the lifetime of some
|
||||||
|
/// `&db` reference. This queue itself is not freed until we have an `&mut db` reference,
|
||||||
|
/// guaranteeing that there are no more references to it.
|
||||||
|
deleted_entries: SegQueue<Box<TrackedStructValue<C>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) struct StructMapView<C>
|
||||||
|
where
|
||||||
|
C: Configuration,
|
||||||
|
{
|
||||||
|
map: Arc<FxDashMap<C::Id, Box<TrackedStructValue<C>>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return value for [`StructMap`][]'s `update` method.
|
||||||
|
pub(crate) enum Update<'db, C>
|
||||||
|
where
|
||||||
|
C: Configuration,
|
||||||
|
{
|
||||||
|
/// Indicates that the given struct has not yet been verified in this revision.
|
||||||
|
/// The [`UpdateRef`][] gives mutable access to the field contents so that
|
||||||
|
/// its fields can be compared and updated.
|
||||||
|
Outdated(UpdateRef<'db, C>),
|
||||||
|
|
||||||
|
/// Indicates that we have already verified that all the inputs accessed prior
|
||||||
|
/// to this struct creation were up-to-date, and therefore the field contents
|
||||||
|
/// ought not to have changed (barring user error). Returns a shared reference
|
||||||
|
/// because caller cannot safely modify fields at this point.
|
||||||
|
Current(&'db TrackedStructValue<C>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<C> StructMap<C>
|
||||||
|
where
|
||||||
|
C: Configuration,
|
||||||
|
{
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
map: Arc::new(FxDashMap::default()),
|
||||||
|
deleted_entries: SegQueue::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a secondary view onto this struct-map that can be used to fetch entries.
|
||||||
|
pub fn view(&self) -> StructMapView<C> {
|
||||||
|
StructMapView {
|
||||||
|
map: self.map.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Insert the given tracked struct value into the map.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// * If value with same `value.id` is already present in the map.
|
||||||
|
/// * If value not created in current revision.
|
||||||
|
pub fn insert<'db>(
|
||||||
|
&'db self,
|
||||||
|
runtime: &'db Runtime,
|
||||||
|
value: TrackedStructValue<C>,
|
||||||
|
) -> &TrackedStructValue<C> {
|
||||||
|
assert_eq!(value.created_at, runtime.current_revision());
|
||||||
|
|
||||||
|
let boxed_value = Box::new(value);
|
||||||
|
let pointer = std::ptr::addr_of!(*boxed_value);
|
||||||
|
|
||||||
|
let old_value = self.map.insert(boxed_value.id, boxed_value);
|
||||||
|
assert!(old_value.is_none()); // ...strictly speaking we probably need to abort here
|
||||||
|
|
||||||
|
// Unsafety clause:
|
||||||
|
//
|
||||||
|
// * The box is owned by self and, although the box has been moved,
|
||||||
|
// the pointer is to the contents of the box, which have a stable
|
||||||
|
// address.
|
||||||
|
// * Values are only removed or altered when we have `&mut self`.
|
||||||
|
unsafe { transmute_lifetime(self, &*pointer) }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn validate<'db>(&'db self, runtime: &'db Runtime, id: C::Id) {
|
||||||
|
let mut data = self.map.get_mut(&id).unwrap();
|
||||||
|
|
||||||
|
// Never update a struct twice in the same revision.
|
||||||
|
let current_revision = runtime.current_revision();
|
||||||
|
assert!(data.created_at < current_revision);
|
||||||
|
data.created_at = current_revision;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get mutable access to the data for `id` -- this holds a write lock for the duration
|
||||||
|
/// of the returned value.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// * If the value is not present in the map.
|
||||||
|
/// * If the value is already updated in this revision.
|
||||||
|
pub fn update<'db>(&'db self, runtime: &'db Runtime, id: C::Id) -> Update<'db, C> {
|
||||||
|
let mut data = self.map.get_mut(&id).unwrap();
|
||||||
|
|
||||||
|
// Never update a struct twice in the same revision.
|
||||||
|
let current_revision = runtime.current_revision();
|
||||||
|
|
||||||
|
// Subtle: it's possible that this struct was already validated
|
||||||
|
// in this revision. What can happen (e.g., in the test
|
||||||
|
// `test_run_5_then_20` in `specify_tracked_fn_in_rev_1_but_not_2.rs`)
|
||||||
|
// is that
|
||||||
|
//
|
||||||
|
// * Revision 1:
|
||||||
|
// * Tracked function F creates tracked struct S
|
||||||
|
// * F reads input I
|
||||||
|
//
|
||||||
|
// In Revision 2, I is changed, and F is re-executed.
|
||||||
|
// We try to validate F's inputs/outputs, which is the list [output: S, input: I].
|
||||||
|
// As no inputs have changed by the time we reach S, we mark it as verified.
|
||||||
|
// But then input I is seen to hvae changed, and so we re-execute F.
|
||||||
|
// Note that we *know* that S will have the same value (barring program bugs).
|
||||||
|
//
|
||||||
|
// Further complicating things: it is possible that F calls F2
|
||||||
|
// and gives it (e.g.) S as one of its arguments. Validating F2 may cause F2 to
|
||||||
|
// re-execute which means that it may indeed have read from S's fields
|
||||||
|
// during the current revision and thus obtained an `&` reference to those fields
|
||||||
|
// that is still live.
|
||||||
|
//
|
||||||
|
// For this reason, we just return `None` in this case, ensuring that the calling
|
||||||
|
// code cannot violate that `&`-reference.
|
||||||
|
if data.created_at == current_revision {
|
||||||
|
drop(data);
|
||||||
|
return Update::Current(&Self::get_from_map(&self.map, runtime, id));
|
||||||
|
}
|
||||||
|
|
||||||
|
data.created_at = current_revision;
|
||||||
|
Update::Outdated(UpdateRef { guard: data })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper function, provides shared functionality for [`StructMapView`][]
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// * If the value is not present in the map.
|
||||||
|
/// * If the value has not been updated in this revision.
|
||||||
|
fn get_from_map<'db>(
|
||||||
|
map: &'db FxDashMap<C::Id, Box<TrackedStructValue<C>>>,
|
||||||
|
runtime: &'db Runtime,
|
||||||
|
id: C::Id,
|
||||||
|
) -> &'db TrackedStructValue<C> {
|
||||||
|
let data = map.get(&id).unwrap();
|
||||||
|
let data: &TrackedStructValue<C> = &**data;
|
||||||
|
|
||||||
|
// Before we drop the lock, check that the value has
|
||||||
|
// been updated in this revision. This is what allows us to return a ``
|
||||||
|
let current_revision = runtime.current_revision();
|
||||||
|
let created_at = data.created_at;
|
||||||
|
assert!(
|
||||||
|
created_at == current_revision,
|
||||||
|
"access to tracked struct from previous revision"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Unsafety clause:
|
||||||
|
//
|
||||||
|
// * Value will not be updated again in this revision,
|
||||||
|
// and revision will not change so long as runtime is shared
|
||||||
|
// * We only remove values from the map when we have `&mut self`
|
||||||
|
unsafe { transmute_lifetime(map, data) }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Remove the entry for `id` from the map.
|
||||||
|
///
|
||||||
|
/// NB. the data won't actually be freed until `drop_deleted_entries` is called.
|
||||||
|
pub fn delete(&self, id: C::Id) {
|
||||||
|
if let Some((_, data)) = self.map.remove(&id) {
|
||||||
|
self.deleted_entries.push(data);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Drop all entries deleted until now.
|
||||||
|
pub fn drop_deleted_entries(&mut self) {
|
||||||
|
std::mem::take(&mut self.deleted_entries);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<C> StructMapView<C>
|
||||||
|
where
|
||||||
|
C: Configuration,
|
||||||
|
{
|
||||||
|
/// Get a pointer to the data for the given `id`.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// * If the value is not present in the map.
|
||||||
|
/// * If the value has not been updated in this revision.
|
||||||
|
pub fn get<'db>(&'db self, runtime: &'db Runtime, id: C::Id) -> &'db TrackedStructValue<C> {
|
||||||
|
StructMap::get_from_map(&self.map, runtime, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A mutable reference to the data for a single struct.
|
||||||
|
/// Can be "frozen" to yield an `&` that will remain valid
|
||||||
|
/// until the end of the revision.
|
||||||
|
pub(crate) struct UpdateRef<'db, C>
|
||||||
|
where
|
||||||
|
C: Configuration,
|
||||||
|
{
|
||||||
|
guard: RefMut<'db, C::Id, Box<TrackedStructValue<C>>, FxHasher>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'db, C> UpdateRef<'db, C>
|
||||||
|
where
|
||||||
|
C: Configuration,
|
||||||
|
{
|
||||||
|
/// Finalize this update, freezing the value for the rest of the revision.
|
||||||
|
pub fn freeze(self) -> &'db TrackedStructValue<C> {
|
||||||
|
// Unsafety clause:
|
||||||
|
//
|
||||||
|
// see `get` above
|
||||||
|
let data: &TrackedStructValue<C> = &*self.guard;
|
||||||
|
let dummy: &'db () = &();
|
||||||
|
unsafe { transmute_lifetime(dummy, data) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<C> Deref for UpdateRef<'_, C>
|
||||||
|
where
|
||||||
|
C: Configuration,
|
||||||
|
{
|
||||||
|
type Target = TrackedStructValue<C>;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
&self.guard
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<C> DerefMut for UpdateRef<'_, C>
|
||||||
|
where
|
||||||
|
C: Configuration,
|
||||||
|
{
|
||||||
|
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||||
|
&mut self.guard
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,16 +1,11 @@
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
hash::FxDashMap,
|
|
||||||
id::AsId,
|
id::AsId,
|
||||||
ingredient::{Ingredient, IngredientRequiresReset},
|
ingredient::{Ingredient, IngredientRequiresReset},
|
||||||
key::DependencyIndex,
|
key::DependencyIndex,
|
||||||
plumbing::transmute_lifetime,
|
Database, IngredientIndex, Runtime,
|
||||||
tracked_struct::TrackedStructValue,
|
|
||||||
IngredientIndex, Runtime,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::Configuration;
|
use super::{struct_map::StructMapView, Configuration};
|
||||||
|
|
||||||
/// Created for each tracked struct.
|
/// Created for each tracked struct.
|
||||||
/// This ingredient only stores the "id" fields.
|
/// This ingredient only stores the "id" fields.
|
||||||
|
@ -27,7 +22,7 @@ where
|
||||||
/// Index of this ingredient in the database (used to construct database-ids, etc).
|
/// Index of this ingredient in the database (used to construct database-ids, etc).
|
||||||
pub(super) ingredient_index: IngredientIndex,
|
pub(super) ingredient_index: IngredientIndex,
|
||||||
pub(super) field_index: u32,
|
pub(super) field_index: u32,
|
||||||
pub(super) entity_data: Arc<FxDashMap<C::Id, Box<TrackedStructValue<C>>>>,
|
pub(super) struct_map: StructMapView<C>,
|
||||||
pub(super) struct_debug_name: &'static str,
|
pub(super) struct_debug_name: &'static str,
|
||||||
pub(super) field_debug_name: &'static str,
|
pub(super) field_debug_name: &'static str,
|
||||||
}
|
}
|
||||||
|
@ -40,16 +35,7 @@ where
|
||||||
/// Note that this function returns the entire tuple of value fields.
|
/// Note that this function returns the entire tuple of value fields.
|
||||||
/// The caller is responible for selecting the appropriate element.
|
/// The caller is responible for selecting the appropriate element.
|
||||||
pub fn field<'db>(&'db self, runtime: &'db Runtime, id: C::Id) -> &'db C::Fields {
|
pub fn field<'db>(&'db self, runtime: &'db Runtime, id: C::Id) -> &'db C::Fields {
|
||||||
let Some(data) = self.entity_data.get(&id) else {
|
let data = self.struct_map.get(runtime, id);
|
||||||
panic!("no data found for entity id {id:?}");
|
|
||||||
};
|
|
||||||
|
|
||||||
let current_revision = runtime.current_revision();
|
|
||||||
let created_at = data.created_at;
|
|
||||||
assert!(
|
|
||||||
created_at == current_revision,
|
|
||||||
"access to tracked struct from previous revision"
|
|
||||||
);
|
|
||||||
|
|
||||||
let changed_at = C::revision(&data.revisions, self.field_index);
|
let changed_at = C::revision(&data.revisions, self.field_index);
|
||||||
|
|
||||||
|
@ -62,15 +48,13 @@ where
|
||||||
changed_at,
|
changed_at,
|
||||||
);
|
);
|
||||||
|
|
||||||
// Unsafety clause:
|
&data.fields
|
||||||
//
|
|
||||||
// * Values are only removed or altered when we have `&mut self`
|
|
||||||
unsafe { transmute_lifetime(self, &data.fields) }
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<DB: ?Sized, C> Ingredient<DB> for TrackedFieldIngredient<C>
|
impl<DB: ?Sized, C> Ingredient<DB> for TrackedFieldIngredient<C>
|
||||||
where
|
where
|
||||||
|
DB: Database,
|
||||||
C: Configuration,
|
C: Configuration,
|
||||||
{
|
{
|
||||||
fn ingredient_index(&self) -> IngredientIndex {
|
fn ingredient_index(&self) -> IngredientIndex {
|
||||||
|
@ -81,23 +65,18 @@ where
|
||||||
crate::cycle::CycleRecoveryStrategy::Panic
|
crate::cycle::CycleRecoveryStrategy::Panic
|
||||||
}
|
}
|
||||||
|
|
||||||
fn maybe_changed_after(
|
fn maybe_changed_after<'db>(
|
||||||
&self,
|
&'db self,
|
||||||
_db: &DB,
|
db: &'db DB,
|
||||||
input: crate::key::DependencyIndex,
|
input: crate::key::DependencyIndex,
|
||||||
revision: crate::Revision,
|
revision: crate::Revision,
|
||||||
) -> bool {
|
) -> bool {
|
||||||
|
let runtime = db.runtime();
|
||||||
let id = <C::Id>::from_id(input.key_index.unwrap());
|
let id = <C::Id>::from_id(input.key_index.unwrap());
|
||||||
match self.entity_data.get(&id) {
|
let data = self.struct_map.get(runtime, id);
|
||||||
Some(data) => {
|
|
||||||
let field_changed_at = C::revision(&data.revisions, self.field_index);
|
let field_changed_at = C::revision(&data.revisions, self.field_index);
|
||||||
field_changed_at > revision
|
field_changed_at > revision
|
||||||
}
|
}
|
||||||
None => {
|
|
||||||
panic!("no data found for field `{id:?}`");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn origin(&self, _key_index: crate::Id) -> Option<crate::runtime::local_state::QueryOrigin> {
|
fn origin(&self, _key_index: crate::Id) -> Option<crate::runtime::local_state::QueryOrigin> {
|
||||||
None
|
None
|
||||||
|
|
103
salsa-2022-tests/tests/preverify-struct-with-leaked-data.rs
Normal file
103
salsa-2022-tests/tests/preverify-struct-with-leaked-data.rs
Normal file
|
@ -0,0 +1,103 @@
|
||||||
|
//! Test that a `tracked` fn on a `salsa::input`
|
||||||
|
//! compiles and executes successfully.
|
||||||
|
|
||||||
|
use std::cell::Cell;
|
||||||
|
|
||||||
|
use expect_test::expect;
|
||||||
|
use salsa::DebugWithDb;
|
||||||
|
use salsa_2022_tests::{HasLogger, Logger};
|
||||||
|
use test_log::test;
|
||||||
|
|
||||||
|
thread_local! {
|
||||||
|
static COUNTER: Cell<usize> = Cell::new(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[salsa::jar(db = Db)]
|
||||||
|
struct Jar(MyInput, MyTracked, function);
|
||||||
|
|
||||||
|
trait Db: salsa::DbWithJar<Jar> + HasLogger {}
|
||||||
|
|
||||||
|
#[salsa::db(Jar)]
|
||||||
|
#[derive(Default)]
|
||||||
|
struct Database {
|
||||||
|
storage: salsa::Storage<Self>,
|
||||||
|
logger: Logger,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl salsa::Database for Database {
|
||||||
|
fn salsa_event(&self, event: salsa::Event) {
|
||||||
|
self.push_log(format!("{:?}", event.debug(self)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Db for Database {}
|
||||||
|
|
||||||
|
impl HasLogger for Database {
|
||||||
|
fn logger(&self) -> &Logger {
|
||||||
|
&self.logger
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[salsa::input]
|
||||||
|
struct MyInput {
|
||||||
|
field1: u32,
|
||||||
|
field2: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[salsa::tracked]
|
||||||
|
struct MyTracked {
|
||||||
|
counter: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[salsa::tracked]
|
||||||
|
fn function(db: &dyn Db, input: MyInput) -> usize {
|
||||||
|
// Read input 1
|
||||||
|
let _field1 = input.field1(db);
|
||||||
|
|
||||||
|
// **BAD:** Leak in the value of the counter non-deterministically
|
||||||
|
let counter = COUNTER.with(|c| c.get());
|
||||||
|
|
||||||
|
// Create the tracked struct, which (from salsa's POV), only depends on field1;
|
||||||
|
// but which actually depends on the leaked value.
|
||||||
|
let tracked = MyTracked::new(db, counter);
|
||||||
|
|
||||||
|
// Read input 2. This will cause us to re-execute on revision 2.
|
||||||
|
let _field2 = input.field2(db);
|
||||||
|
|
||||||
|
tracked.counter(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_leaked_inputs_ignored() {
|
||||||
|
let mut db = Database::default();
|
||||||
|
|
||||||
|
let input = MyInput::new(&db, 10, 20);
|
||||||
|
let result_in_rev_1 = function(&db, input);
|
||||||
|
db.assert_logs(expect![[r#"
|
||||||
|
[
|
||||||
|
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }",
|
||||||
|
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: function(0) } }",
|
||||||
|
]"#]]);
|
||||||
|
|
||||||
|
assert_eq!(result_in_rev_1, 0);
|
||||||
|
|
||||||
|
// Modify field2 so that `function` is seen to have changed --
|
||||||
|
// but only *after* the tracked struct is created.
|
||||||
|
input.set_field2(&mut db).to(30);
|
||||||
|
|
||||||
|
// Also modify the thread-local counter
|
||||||
|
COUNTER.with(|c| c.set(100));
|
||||||
|
|
||||||
|
let result_in_rev_2 = function(&db, input);
|
||||||
|
db.assert_logs(expect![[r#"
|
||||||
|
[
|
||||||
|
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }",
|
||||||
|
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: function(0) } }",
|
||||||
|
]"#]]);
|
||||||
|
|
||||||
|
// Because salsa did not see any way for the tracked
|
||||||
|
// struct to have changed, its field values will not have
|
||||||
|
// been updated, even though in theory they would have
|
||||||
|
// the leaked value from the counter.
|
||||||
|
assert_eq!(result_in_rev_2, 0);
|
||||||
|
}
|
Loading…
Reference in a new issue