mirror of
https://github.com/salsa-rs/salsa.git
synced 2025-01-22 21:05:11 +00:00
create a struct_map
that encapsulates access
The internal API is now based around providing references to the `TrackedStructValue`. Documenting the invariants led to one interesting case, which is that we sometimes verify a tracked struct as not having changed (and even create `&`-ref to it!) but then re-execute the function around it. We now guarantee that, in this case, the data does not change, even if it has leaked values. This is required to ensure soundness. Add a test case about it.
This commit is contained in:
parent
20cb307301
commit
ea1d452143
4 changed files with 449 additions and 116 deletions
|
@ -1,23 +1,21 @@
|
|||
use std::{fmt, hash::Hash, sync::Arc};
|
||||
|
||||
use crossbeam::queue::SegQueue;
|
||||
use std::{fmt, hash::Hash};
|
||||
|
||||
use crate::{
|
||||
cycle::CycleRecoveryStrategy,
|
||||
hash::FxDashMap,
|
||||
id::AsId,
|
||||
ingredient::{fmt_index, Ingredient, IngredientRequiresReset},
|
||||
ingredient_list::IngredientList,
|
||||
interned::{InternedId, InternedIngredient},
|
||||
key::{DatabaseKeyIndex, DependencyIndex},
|
||||
plumbing::transmute_lifetime,
|
||||
runtime::{local_state::QueryOrigin, Runtime},
|
||||
salsa_struct::SalsaStructInDb,
|
||||
Database, Durability, Event, IngredientIndex, Revision,
|
||||
};
|
||||
|
||||
use self::struct_map::{StructMap, Update};
|
||||
pub use self::tracked_field::TrackedFieldIngredient;
|
||||
|
||||
mod struct_map;
|
||||
mod tracked_field;
|
||||
|
||||
// ANCHOR: Configuration
|
||||
|
@ -97,7 +95,7 @@ where
|
|||
{
|
||||
interned: InternedIngredient<C::Id, TrackedStructKey>,
|
||||
|
||||
entity_data: Arc<FxDashMap<C::Id, Box<TrackedStructValue<C>>>>,
|
||||
struct_map: struct_map::StructMap<C>,
|
||||
|
||||
/// A list of each tracked function `f` whose key is this
|
||||
/// tracked struct.
|
||||
|
@ -107,13 +105,6 @@ where
|
|||
/// so they can remove any data tied to that instance.
|
||||
dependent_fns: IngredientList,
|
||||
|
||||
/// When specific entities are deleted, their data is added
|
||||
/// to this vector rather than being immediately freed. This is because we may` have
|
||||
/// references to that data floating about that are tied to the lifetime of some
|
||||
/// `&db` reference. This queue itself is not freed until we have an `&mut db` reference,
|
||||
/// guaranteeing that there are no more references to it.
|
||||
deleted_entries: SegQueue<Box<TrackedStructValue<C>>>,
|
||||
|
||||
debug_name: &'static str,
|
||||
}
|
||||
|
||||
|
@ -159,16 +150,6 @@ where
|
|||
}
|
||||
// ANCHOR_END: TrackedStructValue
|
||||
|
||||
impl<C> TrackedStructValue<C>
|
||||
where
|
||||
C: Configuration,
|
||||
{
|
||||
/// The id of this struct in the ingredient.
|
||||
pub fn id(&self) -> C::Id {
|
||||
self.id
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Copy, Clone)]
|
||||
pub struct Disambiguator(pub u32);
|
||||
|
||||
|
@ -179,9 +160,8 @@ where
|
|||
pub fn new(index: IngredientIndex, debug_name: &'static str) -> Self {
|
||||
Self {
|
||||
interned: InternedIngredient::new(index, debug_name),
|
||||
entity_data: Default::default(),
|
||||
struct_map: StructMap::new(),
|
||||
dependent_fns: IngredientList::new(),
|
||||
deleted_entries: SegQueue::default(),
|
||||
debug_name,
|
||||
}
|
||||
}
|
||||
|
@ -195,7 +175,7 @@ where
|
|||
TrackedFieldIngredient {
|
||||
ingredient_index: field_ingredient_index,
|
||||
field_index,
|
||||
entity_data: self.entity_data.clone(),
|
||||
struct_map: self.struct_map.view(),
|
||||
struct_debug_name: self.debug_name,
|
||||
field_debug_name,
|
||||
}
|
||||
|
@ -208,7 +188,11 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
pub fn new_struct(&self, runtime: &Runtime, fields: C::Fields) -> &TrackedStructValue<C> {
|
||||
pub fn new_struct<'db>(
|
||||
&'db self,
|
||||
runtime: &'db Runtime,
|
||||
fields: C::Fields,
|
||||
) -> &'db TrackedStructValue<C> {
|
||||
let data_hash = crate::hash::hash(&C::id_fields(&fields));
|
||||
|
||||
let (query_key, current_deps, disambiguator) = runtime.disambiguate_entity(
|
||||
|
@ -225,56 +209,56 @@ where
|
|||
let (id, new_id) = self.interned.intern_full(runtime, entity_key);
|
||||
runtime.add_output(self.database_key_index(id).into());
|
||||
|
||||
let pointer: *const TrackedStructValue<C>;
|
||||
let current_revision = runtime.current_revision();
|
||||
if new_id {
|
||||
let data = Box::new(TrackedStructValue {
|
||||
id,
|
||||
created_at: current_revision,
|
||||
durability: current_deps.durability,
|
||||
fields,
|
||||
revisions: C::new_revisions(current_deps.changed_at),
|
||||
});
|
||||
|
||||
// Keep a pointer into the box for later
|
||||
pointer = &*data;
|
||||
|
||||
let old_value = self.entity_data.insert(id, data);
|
||||
assert!(old_value.is_none());
|
||||
} else {
|
||||
let mut data = self.entity_data.get_mut(&id).unwrap();
|
||||
let data = &mut *data;
|
||||
|
||||
// Keep a pointer into the box for later
|
||||
pointer = &**data;
|
||||
|
||||
// SAFETY: We assert that the pointer to `data.revisions`
|
||||
// is a pointer into the database referencing a value
|
||||
// from a previous revision. As such, it continues to meet
|
||||
// its validity invariant and any owned content also continues
|
||||
// to meet its safety invariant.
|
||||
unsafe {
|
||||
C::update_fields(
|
||||
current_revision,
|
||||
&mut data.revisions,
|
||||
std::ptr::addr_of_mut!(data.fields),
|
||||
self.struct_map.insert(
|
||||
runtime,
|
||||
TrackedStructValue {
|
||||
id,
|
||||
created_at: current_revision,
|
||||
durability: current_deps.durability,
|
||||
fields,
|
||||
);
|
||||
}
|
||||
if current_deps.durability < data.durability {
|
||||
data.revisions = C::new_revisions(current_revision);
|
||||
}
|
||||
data.created_at = current_revision;
|
||||
data.durability = current_deps.durability;
|
||||
}
|
||||
revisions: C::new_revisions(current_deps.changed_at),
|
||||
},
|
||||
)
|
||||
} else {
|
||||
match self.struct_map.update(runtime, id) {
|
||||
Update::Current(r) => {
|
||||
// All inputs up to this point were previously
|
||||
// observed to be green and this struct was already
|
||||
// verified. Therefore, the durability ought not to have
|
||||
// changed (nor the field values, but the user could've
|
||||
// done something stupid, so we can't *assert* this is true).
|
||||
assert!(r.durability == current_deps.durability);
|
||||
|
||||
// Unsafety clause:
|
||||
//
|
||||
// * The box is owned by self and, although the box has been moved,
|
||||
// the pointer is to the contents of the box, which have a stable
|
||||
// address.
|
||||
// * Values are only removed or altered when we have `&mut self`.
|
||||
unsafe { transmute_lifetime(self, &*pointer) }
|
||||
r
|
||||
}
|
||||
Update::Outdated(mut data_ref) => {
|
||||
let data = &mut *data_ref;
|
||||
|
||||
// SAFETY: We assert that the pointer to `data.revisions`
|
||||
// is a pointer into the database referencing a value
|
||||
// from a previous revision. As such, it continues to meet
|
||||
// its validity invariant and any owned content also continues
|
||||
// to meet its safety invariant.
|
||||
unsafe {
|
||||
C::update_fields(
|
||||
current_revision,
|
||||
&mut data.revisions,
|
||||
std::ptr::addr_of_mut!(data.fields),
|
||||
fields,
|
||||
);
|
||||
}
|
||||
if current_deps.durability < data.durability {
|
||||
data.revisions = C::new_revisions(current_revision);
|
||||
}
|
||||
data.created_at = current_revision;
|
||||
data.durability = current_deps.durability;
|
||||
|
||||
data_ref.freeze()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Deletes the given entities. This is used after a query `Q` executes and we can compare
|
||||
|
@ -296,9 +280,7 @@ where
|
|||
});
|
||||
|
||||
self.interned.delete_index(id);
|
||||
if let Some((_, data)) = self.entity_data.remove(&id) {
|
||||
self.deleted_entries.push(data);
|
||||
}
|
||||
self.struct_map.delete(id);
|
||||
|
||||
for dependent_fn in self.dependent_fns.iter() {
|
||||
db.salsa_struct_deleted(dependent_fn, id.as_id());
|
||||
|
@ -334,16 +316,16 @@ where
|
|||
None
|
||||
}
|
||||
|
||||
fn mark_validated_output(
|
||||
&self,
|
||||
db: &DB,
|
||||
fn mark_validated_output<'db>(
|
||||
&'db self,
|
||||
db: &'db DB,
|
||||
_executor: DatabaseKeyIndex,
|
||||
output_key: Option<crate::Id>,
|
||||
) {
|
||||
let runtime = db.runtime();
|
||||
let output_key = output_key.unwrap();
|
||||
let output_key: C::Id = <C::Id>::from_id(output_key);
|
||||
let mut entity = self.entity_data.get_mut(&output_key).unwrap();
|
||||
entity.created_at = db.runtime().current_revision();
|
||||
self.struct_map.validate(runtime, output_key);
|
||||
}
|
||||
|
||||
fn remove_stale_output(
|
||||
|
@ -362,7 +344,7 @@ where
|
|||
|
||||
fn reset_for_new_revision(&mut self) {
|
||||
self.interned.clear_deleted_indices();
|
||||
std::mem::take(&mut self.deleted_entries);
|
||||
self.struct_map.drop_deleted_entries();
|
||||
}
|
||||
|
||||
fn salsa_struct_deleted(&self, _db: &DB, _id: crate::Id) {
|
||||
|
@ -380,3 +362,13 @@ where
|
|||
{
|
||||
const RESET_ON_NEW_REVISION: bool = true;
|
||||
}
|
||||
|
||||
impl<C> TrackedStructValue<C>
|
||||
where
|
||||
C: Configuration,
|
||||
{
|
||||
/// The id of this struct in the ingredient.
|
||||
pub fn id(&self) -> C::Id {
|
||||
self.id
|
||||
}
|
||||
}
|
||||
|
|
259
components/salsa-2022/src/tracked_struct/struct_map.rs
Normal file
259
components/salsa-2022/src/tracked_struct/struct_map.rs
Normal file
|
@ -0,0 +1,259 @@
|
|||
use std::{
|
||||
ops::{Deref, DerefMut},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use crossbeam::queue::SegQueue;
|
||||
use dashmap::mapref::one::RefMut;
|
||||
|
||||
use crate::{
|
||||
hash::{FxDashMap, FxHasher},
|
||||
plumbing::transmute_lifetime,
|
||||
Runtime,
|
||||
};
|
||||
|
||||
use super::{Configuration, TrackedStructValue};
|
||||
|
||||
pub(crate) struct StructMap<C>
|
||||
where
|
||||
C: Configuration,
|
||||
{
|
||||
map: Arc<FxDashMap<C::Id, Box<TrackedStructValue<C>>>>,
|
||||
|
||||
/// When specific entities are deleted, their data is added
|
||||
/// to this vector rather than being immediately freed. This is because we may` have
|
||||
/// references to that data floating about that are tied to the lifetime of some
|
||||
/// `&db` reference. This queue itself is not freed until we have an `&mut db` reference,
|
||||
/// guaranteeing that there are no more references to it.
|
||||
deleted_entries: SegQueue<Box<TrackedStructValue<C>>>,
|
||||
}
|
||||
|
||||
pub(crate) struct StructMapView<C>
|
||||
where
|
||||
C: Configuration,
|
||||
{
|
||||
map: Arc<FxDashMap<C::Id, Box<TrackedStructValue<C>>>>,
|
||||
}
|
||||
|
||||
/// Return value for [`StructMap`][]'s `update` method.
|
||||
pub(crate) enum Update<'db, C>
|
||||
where
|
||||
C: Configuration,
|
||||
{
|
||||
/// Indicates that the given struct has not yet been verified in this revision.
|
||||
/// The [`UpdateRef`][] gives mutable access to the field contents so that
|
||||
/// its fields can be compared and updated.
|
||||
Outdated(UpdateRef<'db, C>),
|
||||
|
||||
/// Indicates that we have already verified that all the inputs accessed prior
|
||||
/// to this struct creation were up-to-date, and therefore the field contents
|
||||
/// ought not to have changed (barring user error). Returns a shared reference
|
||||
/// because caller cannot safely modify fields at this point.
|
||||
Current(&'db TrackedStructValue<C>),
|
||||
}
|
||||
|
||||
impl<C> StructMap<C>
|
||||
where
|
||||
C: Configuration,
|
||||
{
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
map: Arc::new(FxDashMap::default()),
|
||||
deleted_entries: SegQueue::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a secondary view onto this struct-map that can be used to fetch entries.
|
||||
pub fn view(&self) -> StructMapView<C> {
|
||||
StructMapView {
|
||||
map: self.map.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Insert the given tracked struct value into the map.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// * If value with same `value.id` is already present in the map.
|
||||
/// * If value not created in current revision.
|
||||
pub fn insert<'db>(
|
||||
&'db self,
|
||||
runtime: &'db Runtime,
|
||||
value: TrackedStructValue<C>,
|
||||
) -> &TrackedStructValue<C> {
|
||||
assert_eq!(value.created_at, runtime.current_revision());
|
||||
|
||||
let boxed_value = Box::new(value);
|
||||
let pointer = std::ptr::addr_of!(*boxed_value);
|
||||
|
||||
let old_value = self.map.insert(boxed_value.id, boxed_value);
|
||||
assert!(old_value.is_none()); // ...strictly speaking we probably need to abort here
|
||||
|
||||
// Unsafety clause:
|
||||
//
|
||||
// * The box is owned by self and, although the box has been moved,
|
||||
// the pointer is to the contents of the box, which have a stable
|
||||
// address.
|
||||
// * Values are only removed or altered when we have `&mut self`.
|
||||
unsafe { transmute_lifetime(self, &*pointer) }
|
||||
}
|
||||
|
||||
pub fn validate<'db>(&'db self, runtime: &'db Runtime, id: C::Id) {
|
||||
let mut data = self.map.get_mut(&id).unwrap();
|
||||
|
||||
// Never update a struct twice in the same revision.
|
||||
let current_revision = runtime.current_revision();
|
||||
assert!(data.created_at < current_revision);
|
||||
data.created_at = current_revision;
|
||||
}
|
||||
|
||||
/// Get mutable access to the data for `id` -- this holds a write lock for the duration
|
||||
/// of the returned value.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// * If the value is not present in the map.
|
||||
/// * If the value is already updated in this revision.
|
||||
pub fn update<'db>(&'db self, runtime: &'db Runtime, id: C::Id) -> Update<'db, C> {
|
||||
let mut data = self.map.get_mut(&id).unwrap();
|
||||
|
||||
// Never update a struct twice in the same revision.
|
||||
let current_revision = runtime.current_revision();
|
||||
|
||||
// Subtle: it's possible that this struct was already validated
|
||||
// in this revision. What can happen (e.g., in the test
|
||||
// `test_run_5_then_20` in `specify_tracked_fn_in_rev_1_but_not_2.rs`)
|
||||
// is that
|
||||
//
|
||||
// * Revision 1:
|
||||
// * Tracked function F creates tracked struct S
|
||||
// * F reads input I
|
||||
//
|
||||
// In Revision 2, I is changed, and F is re-executed.
|
||||
// We try to validate F's inputs/outputs, which is the list [output: S, input: I].
|
||||
// As no inputs have changed by the time we reach S, we mark it as verified.
|
||||
// But then input I is seen to hvae changed, and so we re-execute F.
|
||||
// Note that we *know* that S will have the same value (barring program bugs).
|
||||
//
|
||||
// Further complicating things: it is possible that F calls F2
|
||||
// and gives it (e.g.) S as one of its arguments. Validating F2 may cause F2 to
|
||||
// re-execute which means that it may indeed have read from S's fields
|
||||
// during the current revision and thus obtained an `&` reference to those fields
|
||||
// that is still live.
|
||||
//
|
||||
// For this reason, we just return `None` in this case, ensuring that the calling
|
||||
// code cannot violate that `&`-reference.
|
||||
if data.created_at == current_revision {
|
||||
drop(data);
|
||||
return Update::Current(&Self::get_from_map(&self.map, runtime, id));
|
||||
}
|
||||
|
||||
data.created_at = current_revision;
|
||||
Update::Outdated(UpdateRef { guard: data })
|
||||
}
|
||||
|
||||
/// Helper function, provides shared functionality for [`StructMapView`][]
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// * If the value is not present in the map.
|
||||
/// * If the value has not been updated in this revision.
|
||||
fn get_from_map<'db>(
|
||||
map: &'db FxDashMap<C::Id, Box<TrackedStructValue<C>>>,
|
||||
runtime: &'db Runtime,
|
||||
id: C::Id,
|
||||
) -> &'db TrackedStructValue<C> {
|
||||
let data = map.get(&id).unwrap();
|
||||
let data: &TrackedStructValue<C> = &**data;
|
||||
|
||||
// Before we drop the lock, check that the value has
|
||||
// been updated in this revision. This is what allows us to return a ``
|
||||
let current_revision = runtime.current_revision();
|
||||
let created_at = data.created_at;
|
||||
assert!(
|
||||
created_at == current_revision,
|
||||
"access to tracked struct from previous revision"
|
||||
);
|
||||
|
||||
// Unsafety clause:
|
||||
//
|
||||
// * Value will not be updated again in this revision,
|
||||
// and revision will not change so long as runtime is shared
|
||||
// * We only remove values from the map when we have `&mut self`
|
||||
unsafe { transmute_lifetime(map, data) }
|
||||
}
|
||||
|
||||
/// Remove the entry for `id` from the map.
|
||||
///
|
||||
/// NB. the data won't actually be freed until `drop_deleted_entries` is called.
|
||||
pub fn delete(&self, id: C::Id) {
|
||||
if let Some((_, data)) = self.map.remove(&id) {
|
||||
self.deleted_entries.push(data);
|
||||
}
|
||||
}
|
||||
|
||||
/// Drop all entries deleted until now.
|
||||
pub fn drop_deleted_entries(&mut self) {
|
||||
std::mem::take(&mut self.deleted_entries);
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> StructMapView<C>
|
||||
where
|
||||
C: Configuration,
|
||||
{
|
||||
/// Get a pointer to the data for the given `id`.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// * If the value is not present in the map.
|
||||
/// * If the value has not been updated in this revision.
|
||||
pub fn get<'db>(&'db self, runtime: &'db Runtime, id: C::Id) -> &'db TrackedStructValue<C> {
|
||||
StructMap::get_from_map(&self.map, runtime, id)
|
||||
}
|
||||
}
|
||||
|
||||
/// A mutable reference to the data for a single struct.
|
||||
/// Can be "frozen" to yield an `&` that will remain valid
|
||||
/// until the end of the revision.
|
||||
pub(crate) struct UpdateRef<'db, C>
|
||||
where
|
||||
C: Configuration,
|
||||
{
|
||||
guard: RefMut<'db, C::Id, Box<TrackedStructValue<C>>, FxHasher>,
|
||||
}
|
||||
|
||||
impl<'db, C> UpdateRef<'db, C>
|
||||
where
|
||||
C: Configuration,
|
||||
{
|
||||
/// Finalize this update, freezing the value for the rest of the revision.
|
||||
pub fn freeze(self) -> &'db TrackedStructValue<C> {
|
||||
// Unsafety clause:
|
||||
//
|
||||
// see `get` above
|
||||
let data: &TrackedStructValue<C> = &*self.guard;
|
||||
let dummy: &'db () = &();
|
||||
unsafe { transmute_lifetime(dummy, data) }
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> Deref for UpdateRef<'_, C>
|
||||
where
|
||||
C: Configuration,
|
||||
{
|
||||
type Target = TrackedStructValue<C>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.guard
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> DerefMut for UpdateRef<'_, C>
|
||||
where
|
||||
C: Configuration,
|
||||
{
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.guard
|
||||
}
|
||||
}
|
|
@ -1,16 +1,11 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
hash::FxDashMap,
|
||||
id::AsId,
|
||||
ingredient::{Ingredient, IngredientRequiresReset},
|
||||
key::DependencyIndex,
|
||||
plumbing::transmute_lifetime,
|
||||
tracked_struct::TrackedStructValue,
|
||||
IngredientIndex, Runtime,
|
||||
Database, IngredientIndex, Runtime,
|
||||
};
|
||||
|
||||
use super::Configuration;
|
||||
use super::{struct_map::StructMapView, Configuration};
|
||||
|
||||
/// Created for each tracked struct.
|
||||
/// This ingredient only stores the "id" fields.
|
||||
|
@ -27,7 +22,7 @@ where
|
|||
/// Index of this ingredient in the database (used to construct database-ids, etc).
|
||||
pub(super) ingredient_index: IngredientIndex,
|
||||
pub(super) field_index: u32,
|
||||
pub(super) entity_data: Arc<FxDashMap<C::Id, Box<TrackedStructValue<C>>>>,
|
||||
pub(super) struct_map: StructMapView<C>,
|
||||
pub(super) struct_debug_name: &'static str,
|
||||
pub(super) field_debug_name: &'static str,
|
||||
}
|
||||
|
@ -40,16 +35,7 @@ where
|
|||
/// Note that this function returns the entire tuple of value fields.
|
||||
/// The caller is responible for selecting the appropriate element.
|
||||
pub fn field<'db>(&'db self, runtime: &'db Runtime, id: C::Id) -> &'db C::Fields {
|
||||
let Some(data) = self.entity_data.get(&id) else {
|
||||
panic!("no data found for entity id {id:?}");
|
||||
};
|
||||
|
||||
let current_revision = runtime.current_revision();
|
||||
let created_at = data.created_at;
|
||||
assert!(
|
||||
created_at == current_revision,
|
||||
"access to tracked struct from previous revision"
|
||||
);
|
||||
let data = self.struct_map.get(runtime, id);
|
||||
|
||||
let changed_at = C::revision(&data.revisions, self.field_index);
|
||||
|
||||
|
@ -62,15 +48,13 @@ where
|
|||
changed_at,
|
||||
);
|
||||
|
||||
// Unsafety clause:
|
||||
//
|
||||
// * Values are only removed or altered when we have `&mut self`
|
||||
unsafe { transmute_lifetime(self, &data.fields) }
|
||||
&data.fields
|
||||
}
|
||||
}
|
||||
|
||||
impl<DB: ?Sized, C> Ingredient<DB> for TrackedFieldIngredient<C>
|
||||
where
|
||||
DB: Database,
|
||||
C: Configuration,
|
||||
{
|
||||
fn ingredient_index(&self) -> IngredientIndex {
|
||||
|
@ -81,22 +65,17 @@ where
|
|||
crate::cycle::CycleRecoveryStrategy::Panic
|
||||
}
|
||||
|
||||
fn maybe_changed_after(
|
||||
&self,
|
||||
_db: &DB,
|
||||
fn maybe_changed_after<'db>(
|
||||
&'db self,
|
||||
db: &'db DB,
|
||||
input: crate::key::DependencyIndex,
|
||||
revision: crate::Revision,
|
||||
) -> bool {
|
||||
let runtime = db.runtime();
|
||||
let id = <C::Id>::from_id(input.key_index.unwrap());
|
||||
match self.entity_data.get(&id) {
|
||||
Some(data) => {
|
||||
let field_changed_at = C::revision(&data.revisions, self.field_index);
|
||||
field_changed_at > revision
|
||||
}
|
||||
None => {
|
||||
panic!("no data found for field `{id:?}`");
|
||||
}
|
||||
}
|
||||
let data = self.struct_map.get(runtime, id);
|
||||
let field_changed_at = C::revision(&data.revisions, self.field_index);
|
||||
field_changed_at > revision
|
||||
}
|
||||
|
||||
fn origin(&self, _key_index: crate::Id) -> Option<crate::runtime::local_state::QueryOrigin> {
|
||||
|
|
103
salsa-2022-tests/tests/preverify-struct-with-leaked-data.rs
Normal file
103
salsa-2022-tests/tests/preverify-struct-with-leaked-data.rs
Normal file
|
@ -0,0 +1,103 @@
|
|||
//! Test that a `tracked` fn on a `salsa::input`
|
||||
//! compiles and executes successfully.
|
||||
|
||||
use std::cell::Cell;
|
||||
|
||||
use expect_test::expect;
|
||||
use salsa::DebugWithDb;
|
||||
use salsa_2022_tests::{HasLogger, Logger};
|
||||
use test_log::test;
|
||||
|
||||
thread_local! {
|
||||
static COUNTER: Cell<usize> = Cell::new(0);
|
||||
}
|
||||
|
||||
#[salsa::jar(db = Db)]
|
||||
struct Jar(MyInput, MyTracked, function);
|
||||
|
||||
trait Db: salsa::DbWithJar<Jar> + HasLogger {}
|
||||
|
||||
#[salsa::db(Jar)]
|
||||
#[derive(Default)]
|
||||
struct Database {
|
||||
storage: salsa::Storage<Self>,
|
||||
logger: Logger,
|
||||
}
|
||||
|
||||
impl salsa::Database for Database {
|
||||
fn salsa_event(&self, event: salsa::Event) {
|
||||
self.push_log(format!("{:?}", event.debug(self)));
|
||||
}
|
||||
}
|
||||
|
||||
impl Db for Database {}
|
||||
|
||||
impl HasLogger for Database {
|
||||
fn logger(&self) -> &Logger {
|
||||
&self.logger
|
||||
}
|
||||
}
|
||||
|
||||
#[salsa::input]
|
||||
struct MyInput {
|
||||
field1: u32,
|
||||
field2: u32,
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
struct MyTracked {
|
||||
counter: usize,
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
fn function(db: &dyn Db, input: MyInput) -> usize {
|
||||
// Read input 1
|
||||
let _field1 = input.field1(db);
|
||||
|
||||
// **BAD:** Leak in the value of the counter non-deterministically
|
||||
let counter = COUNTER.with(|c| c.get());
|
||||
|
||||
// Create the tracked struct, which (from salsa's POV), only depends on field1;
|
||||
// but which actually depends on the leaked value.
|
||||
let tracked = MyTracked::new(db, counter);
|
||||
|
||||
// Read input 2. This will cause us to re-execute on revision 2.
|
||||
let _field2 = input.field2(db);
|
||||
|
||||
tracked.counter(db)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_leaked_inputs_ignored() {
|
||||
let mut db = Database::default();
|
||||
|
||||
let input = MyInput::new(&db, 10, 20);
|
||||
let result_in_rev_1 = function(&db, input);
|
||||
db.assert_logs(expect![[r#"
|
||||
[
|
||||
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }",
|
||||
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: function(0) } }",
|
||||
]"#]]);
|
||||
|
||||
assert_eq!(result_in_rev_1, 0);
|
||||
|
||||
// Modify field2 so that `function` is seen to have changed --
|
||||
// but only *after* the tracked struct is created.
|
||||
input.set_field2(&mut db).to(30);
|
||||
|
||||
// Also modify the thread-local counter
|
||||
COUNTER.with(|c| c.set(100));
|
||||
|
||||
let result_in_rev_2 = function(&db, input);
|
||||
db.assert_logs(expect![[r#"
|
||||
[
|
||||
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }",
|
||||
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: function(0) } }",
|
||||
]"#]]);
|
||||
|
||||
// Because salsa did not see any way for the tracked
|
||||
// struct to have changed, its field values will not have
|
||||
// been updated, even though in theory they would have
|
||||
// the leaked value from the counter.
|
||||
assert_eq!(result_in_rev_2, 0);
|
||||
}
|
Loading…
Reference in a new issue