diff --git a/components/salsa-2022/src/tracked_struct.rs b/components/salsa-2022/src/tracked_struct.rs index 32214716..0a548060 100644 --- a/components/salsa-2022/src/tracked_struct.rs +++ b/components/salsa-2022/src/tracked_struct.rs @@ -1,23 +1,21 @@ -use std::{fmt, hash::Hash, sync::Arc}; - -use crossbeam::queue::SegQueue; +use std::{fmt, hash::Hash}; use crate::{ cycle::CycleRecoveryStrategy, - hash::FxDashMap, id::AsId, ingredient::{fmt_index, Ingredient, IngredientRequiresReset}, ingredient_list::IngredientList, interned::{InternedId, InternedIngredient}, key::{DatabaseKeyIndex, DependencyIndex}, - plumbing::transmute_lifetime, runtime::{local_state::QueryOrigin, Runtime}, salsa_struct::SalsaStructInDb, Database, Durability, Event, IngredientIndex, Revision, }; +use self::struct_map::{StructMap, Update}; pub use self::tracked_field::TrackedFieldIngredient; +mod struct_map; mod tracked_field; // ANCHOR: Configuration @@ -97,7 +95,7 @@ where { interned: InternedIngredient, - entity_data: Arc>>>, + struct_map: struct_map::StructMap, /// A list of each tracked function `f` whose key is this /// tracked struct. @@ -107,13 +105,6 @@ where /// so they can remove any data tied to that instance. dependent_fns: IngredientList, - /// When specific entities are deleted, their data is added - /// to this vector rather than being immediately freed. This is because we may` have - /// references to that data floating about that are tied to the lifetime of some - /// `&db` reference. This queue itself is not freed until we have an `&mut db` reference, - /// guaranteeing that there are no more references to it. - deleted_entries: SegQueue>>, - debug_name: &'static str, } @@ -159,16 +150,6 @@ where } // ANCHOR_END: TrackedStructValue -impl TrackedStructValue -where - C: Configuration, -{ - /// The id of this struct in the ingredient. 
- pub fn id(&self) -> C::Id { - self.id - } -} - #[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Copy, Clone)] pub struct Disambiguator(pub u32); @@ -179,9 +160,8 @@ where pub fn new(index: IngredientIndex, debug_name: &'static str) -> Self { Self { interned: InternedIngredient::new(index, debug_name), - entity_data: Default::default(), + struct_map: StructMap::new(), dependent_fns: IngredientList::new(), - deleted_entries: SegQueue::default(), debug_name, } } @@ -195,7 +175,7 @@ where TrackedFieldIngredient { ingredient_index: field_ingredient_index, field_index, - entity_data: self.entity_data.clone(), + struct_map: self.struct_map.view(), struct_debug_name: self.debug_name, field_debug_name, } @@ -208,7 +188,11 @@ where } } - pub fn new_struct(&self, runtime: &Runtime, fields: C::Fields) -> &TrackedStructValue { + pub fn new_struct<'db>( + &'db self, + runtime: &'db Runtime, + fields: C::Fields, + ) -> &'db TrackedStructValue { let data_hash = crate::hash::hash(&C::id_fields(&fields)); let (query_key, current_deps, disambiguator) = runtime.disambiguate_entity( @@ -225,56 +209,56 @@ where let (id, new_id) = self.interned.intern_full(runtime, entity_key); runtime.add_output(self.database_key_index(id).into()); - let pointer: *const TrackedStructValue; let current_revision = runtime.current_revision(); if new_id { - let data = Box::new(TrackedStructValue { - id, - created_at: current_revision, - durability: current_deps.durability, - fields, - revisions: C::new_revisions(current_deps.changed_at), - }); - - // Keep a pointer into the box for later - pointer = &*data; - - let old_value = self.entity_data.insert(id, data); - assert!(old_value.is_none()); - } else { - let mut data = self.entity_data.get_mut(&id).unwrap(); - let data = &mut *data; - - // Keep a pointer into the box for later - pointer = &**data; - - // SAFETY: We assert that the pointer to `data.revisions` - // is a pointer into the database referencing a value - // from a previous revision. 
As such, it continues to meet - // its validity invariant and any owned content also continues - // to meet its safety invariant. - unsafe { - C::update_fields( - current_revision, - &mut data.revisions, - std::ptr::addr_of_mut!(data.fields), + self.struct_map.insert( + runtime, + TrackedStructValue { + id, + created_at: current_revision, + durability: current_deps.durability, fields, - ); - } - if current_deps.durability < data.durability { - data.revisions = C::new_revisions(current_revision); - } - data.created_at = current_revision; - data.durability = current_deps.durability; - } + revisions: C::new_revisions(current_deps.changed_at), + }, + ) + } else { + match self.struct_map.update(runtime, id) { + Update::Current(r) => { + // All inputs up to this point were previously + // observed to be green and this struct was already + // verified. Therefore, the durability ought not to have + // changed (nor the field values, but the user could've + // done something stupid, so we can't *assert* this is true). + assert!(r.durability == current_deps.durability); - // Unsafety clause: - // - // * The box is owned by self and, although the box has been moved, - // the pointer is to the contents of the box, which have a stable - // address. - // * Values are only removed or altered when we have `&mut self`. - unsafe { transmute_lifetime(self, &*pointer) } + r + } + Update::Outdated(mut data_ref) => { + let data = &mut *data_ref; + + // SAFETY: We assert that the pointer to `data.revisions` + // is a pointer into the database referencing a value + // from a previous revision. As such, it continues to meet + // its validity invariant and any owned content also continues + // to meet its safety invariant. 
+ unsafe { + C::update_fields( + current_revision, + &mut data.revisions, + std::ptr::addr_of_mut!(data.fields), + fields, + ); + } + if current_deps.durability < data.durability { + data.revisions = C::new_revisions(current_revision); + } + data.created_at = current_revision; + data.durability = current_deps.durability; + + data_ref.freeze() + } + } + } } /// Deletes the given entities. This is used after a query `Q` executes and we can compare @@ -296,9 +280,7 @@ where }); self.interned.delete_index(id); - if let Some((_, data)) = self.entity_data.remove(&id) { - self.deleted_entries.push(data); - } + self.struct_map.delete(id); for dependent_fn in self.dependent_fns.iter() { db.salsa_struct_deleted(dependent_fn, id.as_id()); @@ -334,16 +316,16 @@ where None } - fn mark_validated_output( - &self, - db: &DB, + fn mark_validated_output<'db>( + &'db self, + db: &'db DB, _executor: DatabaseKeyIndex, output_key: Option, ) { + let runtime = db.runtime(); let output_key = output_key.unwrap(); let output_key: C::Id = ::from_id(output_key); - let mut entity = self.entity_data.get_mut(&output_key).unwrap(); - entity.created_at = db.runtime().current_revision(); + self.struct_map.validate(runtime, output_key); } fn remove_stale_output( @@ -362,7 +344,7 @@ where fn reset_for_new_revision(&mut self) { self.interned.clear_deleted_indices(); - std::mem::take(&mut self.deleted_entries); + self.struct_map.drop_deleted_entries(); } fn salsa_struct_deleted(&self, _db: &DB, _id: crate::Id) { @@ -380,3 +362,13 @@ where { const RESET_ON_NEW_REVISION: bool = true; } + +impl TrackedStructValue +where + C: Configuration, +{ + /// The id of this struct in the ingredient. 
+ pub fn id(&self) -> C::Id { + self.id + } +} diff --git a/components/salsa-2022/src/tracked_struct/struct_map.rs b/components/salsa-2022/src/tracked_struct/struct_map.rs new file mode 100644 index 00000000..9c26a460 --- /dev/null +++ b/components/salsa-2022/src/tracked_struct/struct_map.rs @@ -0,0 +1,259 @@ +use std::{ + ops::{Deref, DerefMut}, + sync::Arc, +}; + +use crossbeam::queue::SegQueue; +use dashmap::mapref::one::RefMut; + +use crate::{ + hash::{FxDashMap, FxHasher}, + plumbing::transmute_lifetime, + Runtime, +}; + +use super::{Configuration, TrackedStructValue}; + +pub(crate) struct StructMap +where + C: Configuration, +{ + map: Arc>>>, + + /// When specific entities are deleted, their data is added + /// to this queue rather than being immediately freed. This is because we may have + /// references to that data floating about that are tied to the lifetime of some + /// `&db` reference. This queue itself is not freed until we have an `&mut db` reference, + /// guaranteeing that there are no more references to it. + deleted_entries: SegQueue>>, +} + +pub(crate) struct StructMapView +where + C: Configuration, +{ + map: Arc>>>, +} + +/// Return value for [`StructMap`][]'s `update` method. +pub(crate) enum Update<'db, C> +where + C: Configuration, +{ + /// Indicates that the given struct has not yet been verified in this revision. + /// The [`UpdateRef`][] gives mutable access to the field contents so that + /// its fields can be compared and updated. + Outdated(UpdateRef<'db, C>), + + /// Indicates that we have already verified that all the inputs accessed prior + /// to this struct creation were up-to-date, and therefore the field contents + /// ought not to have changed (barring user error). Returns a shared reference + /// because caller cannot safely modify fields at this point.
+ Current(&'db TrackedStructValue), +} + +impl StructMap +where + C: Configuration, +{ + pub fn new() -> Self { + Self { + map: Arc::new(FxDashMap::default()), + deleted_entries: SegQueue::new(), + } + } + + /// Get a secondary view onto this struct-map that can be used to fetch entries. + pub fn view(&self) -> StructMapView { + StructMapView { + map: self.map.clone(), + } + } + + /// Insert the given tracked struct value into the map. + /// + /// # Panics + /// + /// * If a value with the same `value.id` is already present in the map. + /// * If the value was not created in the current revision. + pub fn insert<'db>( + &'db self, + runtime: &'db Runtime, + value: TrackedStructValue, + ) -> &TrackedStructValue { + assert_eq!(value.created_at, runtime.current_revision()); + + let boxed_value = Box::new(value); + let pointer = std::ptr::addr_of!(*boxed_value); + + let old_value = self.map.insert(boxed_value.id, boxed_value); + assert!(old_value.is_none()); // ...strictly speaking we probably need to abort here + + // Unsafety clause: + // + // * The box is owned by self and, although the box has been moved, + // the pointer is to the contents of the box, which have a stable + // address. + // * Values are only removed or altered when we have `&mut self`. + unsafe { transmute_lifetime(self, &*pointer) } + } + + pub fn validate<'db>(&'db self, runtime: &'db Runtime, id: C::Id) { + let mut data = self.map.get_mut(&id).unwrap(); + + // Never update a struct twice in the same revision. + let current_revision = runtime.current_revision(); + assert!(data.created_at < current_revision); + data.created_at = current_revision; + } + + /// Get mutable access to the data for `id` -- this holds a write lock for the duration + /// of the returned value. If the value was already updated in this revision, + /// `Update::Current` is returned instead, carrying only a shared reference. + /// + /// # Panics + /// + /// * If the value is not present in the map.
+ pub fn update<'db>(&'db self, runtime: &'db Runtime, id: C::Id) -> Update<'db, C> { + let mut data = self.map.get_mut(&id).unwrap(); + + // Never update a struct twice in the same revision. + let current_revision = runtime.current_revision(); + + // Subtle: it's possible that this struct was already validated + // in this revision. What can happen (e.g., in the test + // `test_run_5_then_20` in `specify_tracked_fn_in_rev_1_but_not_2.rs`) + // is that + // + // * Revision 1: + // * Tracked function F creates tracked struct S + // * F reads input I + // + // In Revision 2, I is changed, and F is re-executed. + // We try to validate F's inputs/outputs, which is the list [output: S, input: I]. + // As no inputs have changed by the time we reach S, we mark it as verified. + // But then input I is seen to have changed, and so we re-execute F. + // Note that we *know* that S will have the same value (barring program bugs). + // + // Further complicating things: it is possible that F calls F2 + // and gives it (e.g.) S as one of its arguments. Validating F2 may cause F2 to + // re-execute which means that it may indeed have read from S's fields + // during the current revision and thus obtained an `&` reference to those fields + // that is still live. + // + // For this reason, we just return `Update::Current` in this case, ensuring that the calling + // code cannot violate that `&`-reference. + if data.created_at == current_revision { + drop(data); + return Update::Current(&Self::get_from_map(&self.map, runtime, id)); + } + + data.created_at = current_revision; + Update::Outdated(UpdateRef { guard: data }) + } + + /// Helper function, provides shared functionality for [`StructMapView`][] + /// + /// # Panics + /// + /// * If the value is not present in the map. + /// * If the value has not been updated in this revision.
+ fn get_from_map<'db>( + map: &'db FxDashMap>>, + runtime: &'db Runtime, + id: C::Id, + ) -> &'db TrackedStructValue { + let data = map.get(&id).unwrap(); + let data: &TrackedStructValue = &**data; + + // Before we drop the lock, check that the value has + // been updated in this revision. This is what allows us to return a `&'db` reference. + let current_revision = runtime.current_revision(); + let created_at = data.created_at; + assert!( + created_at == current_revision, + "access to tracked struct from previous revision" + ); + + // Unsafety clause: + // + // * Value will not be updated again in this revision, + // and revision will not change so long as runtime is shared + // * Entries removed via `delete` are parked in `deleted_entries` and only freed in `drop_deleted_entries` (`&mut self`) + unsafe { transmute_lifetime(map, data) } + } + + /// Remove the entry for `id` from the map. + /// + /// NB. the data won't actually be freed until `drop_deleted_entries` is called. + pub fn delete(&self, id: C::Id) { + if let Some((_, data)) = self.map.remove(&id) { + self.deleted_entries.push(data); + } + } + + /// Drop all entries deleted until now. + pub fn drop_deleted_entries(&mut self) { + std::mem::take(&mut self.deleted_entries); + } +} + +impl StructMapView +where + C: Configuration, +{ + /// Get a shared reference to the data for the given `id`. + /// + /// # Panics + /// + /// * If the value is not present in the map. + /// * If the value has not been updated in this revision. + pub fn get<'db>(&'db self, runtime: &'db Runtime, id: C::Id) -> &'db TrackedStructValue { + StructMap::get_from_map(&self.map, runtime, id) + } +} + +/// A mutable reference to the data for a single struct. +/// Can be "frozen" to yield an `&` that will remain valid +/// until the end of the revision. +pub(crate) struct UpdateRef<'db, C> +where + C: Configuration, +{ + guard: RefMut<'db, C::Id, Box>, FxHasher>, +} + +impl<'db, C> UpdateRef<'db, C> +where + C: Configuration, +{ + /// Finalize this update, freezing the value for the rest of the revision.
+ pub fn freeze(self) -> &'db TrackedStructValue { + // Unsafety clause: + // + // see `get_from_map` above; `dummy` only supplies the `'db` lifetime passed to `transmute_lifetime` + let data: &TrackedStructValue = &*self.guard; + let dummy: &'db () = &(); + unsafe { transmute_lifetime(dummy, data) } + } +} + +impl Deref for UpdateRef<'_, C> +where + C: Configuration, +{ + type Target = TrackedStructValue; + + fn deref(&self) -> &Self::Target { + &self.guard + } +} + +impl DerefMut for UpdateRef<'_, C> +where + C: Configuration, +{ + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.guard + } +} diff --git a/components/salsa-2022/src/tracked_struct/tracked_field.rs b/components/salsa-2022/src/tracked_struct/tracked_field.rs index 0568e819..2ef4d8b7 100644 --- a/components/salsa-2022/src/tracked_struct/tracked_field.rs +++ b/components/salsa-2022/src/tracked_struct/tracked_field.rs @@ -1,16 +1,11 @@ -use std::sync::Arc; - use crate::{ - hash::FxDashMap, id::AsId, ingredient::{Ingredient, IngredientRequiresReset}, key::DependencyIndex, - plumbing::transmute_lifetime, - tracked_struct::TrackedStructValue, - IngredientIndex, Runtime, + Database, IngredientIndex, Runtime, }; -use super::Configuration; +use super::{struct_map::StructMapView, Configuration}; /// Created for each tracked struct. /// This ingredient only stores the "id" fields. @@ -27,7 +22,7 @@ where /// Index of this ingredient in the database (used to construct database-ids, etc). pub(super) ingredient_index: IngredientIndex, pub(super) field_index: u32, - pub(super) entity_data: Arc>>>, + pub(super) struct_map: StructMapView, pub(super) struct_debug_name: &'static str, pub(super) field_debug_name: &'static str, } @@ -40,16 +35,7 @@ where /// Note that this function returns the entire tuple of value fields. /// The caller is responible for selecting the appropriate element.
pub fn field<'db>(&'db self, runtime: &'db Runtime, id: C::Id) -> &'db C::Fields { - let Some(data) = self.entity_data.get(&id) else { - panic!("no data found for entity id {id:?}"); - }; - - let current_revision = runtime.current_revision(); - let created_at = data.created_at; - assert!( - created_at == current_revision, - "access to tracked struct from previous revision" - ); + let data = self.struct_map.get(runtime, id); let changed_at = C::revision(&data.revisions, self.field_index); @@ -62,15 +48,13 @@ where changed_at, ); - // Unsafety clause: - // - // * Values are only removed or altered when we have `&mut self` - unsafe { transmute_lifetime(self, &data.fields) } + &data.fields } } impl Ingredient for TrackedFieldIngredient where + DB: Database, C: Configuration, { fn ingredient_index(&self) -> IngredientIndex { @@ -81,22 +65,17 @@ where crate::cycle::CycleRecoveryStrategy::Panic } - fn maybe_changed_after( - &self, - _db: &DB, + fn maybe_changed_after<'db>( + &'db self, + db: &'db DB, input: crate::key::DependencyIndex, revision: crate::Revision, ) -> bool { + let runtime = db.runtime(); let id = ::from_id(input.key_index.unwrap()); - match self.entity_data.get(&id) { - Some(data) => { - let field_changed_at = C::revision(&data.revisions, self.field_index); - field_changed_at > revision - } - None => { - panic!("no data found for field `{id:?}`"); - } - } + let data = self.struct_map.get(runtime, id); + let field_changed_at = C::revision(&data.revisions, self.field_index); + field_changed_at > revision } fn origin(&self, _key_index: crate::Id) -> Option { diff --git a/salsa-2022-tests/tests/preverify-struct-with-leaked-data.rs b/salsa-2022-tests/tests/preverify-struct-with-leaked-data.rs new file mode 100644 index 00000000..1add83d7 --- /dev/null +++ b/salsa-2022-tests/tests/preverify-struct-with-leaked-data.rs @@ -0,0 +1,103 @@ +//! Test that a tracked struct pre-verified in a new revision keeps its old +//! field values, even when untracked (leaked) state feeding it has changed.
+ +use std::cell::Cell; + +use expect_test::expect; +use salsa::DebugWithDb; +use salsa_2022_tests::{HasLogger, Logger}; +use test_log::test; + +thread_local! { + static COUNTER: Cell = Cell::new(0); +} + +#[salsa::jar(db = Db)] +struct Jar(MyInput, MyTracked, function); + +trait Db: salsa::DbWithJar + HasLogger {} + +#[salsa::db(Jar)] +#[derive(Default)] +struct Database { + storage: salsa::Storage, + logger: Logger, +} + +impl salsa::Database for Database { + fn salsa_event(&self, event: salsa::Event) { + self.push_log(format!("{:?}", event.debug(self))); + } +} + +impl Db for Database {} + +impl HasLogger for Database { + fn logger(&self) -> &Logger { + &self.logger + } +} + +#[salsa::input] +struct MyInput { + field1: u32, + field2: u32, +} + +#[salsa::tracked] +struct MyTracked { + counter: usize, +} + +#[salsa::tracked] +fn function(db: &dyn Db, input: MyInput) -> usize { + // Read input 1 + let _field1 = input.field1(db); + + // **BAD:** Leak in the value of the counter non-deterministically + let counter = COUNTER.with(|c| c.get()); + + // Create the tracked struct, which (from salsa's POV), only depends on field1; + // but which actually depends on the leaked value. + let tracked = MyTracked::new(db, counter); + + // Read input 2. This will cause us to re-execute on revision 2. + let _field2 = input.field2(db); + + tracked.counter(db) +} + +#[test] +fn test_leaked_inputs_ignored() { + let mut db = Database::default(); + + let input = MyInput::new(&db, 10, 20); + let result_in_rev_1 = function(&db, input); + db.assert_logs(expect![[r#" + [ + "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", + "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: function(0) } }", + ]"#]]); + + assert_eq!(result_in_rev_1, 0); + + // Modify field2 so that `function` is seen to have changed -- + // but only *after* the tracked struct is created. 
+ input.set_field2(&mut db).to(30); + + // Also modify the thread-local counter + COUNTER.with(|c| c.set(100)); + + let result_in_rev_2 = function(&db, input); + db.assert_logs(expect![[r#" + [ + "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }", + "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: function(0) } }", + ]"#]]); + + // Because salsa did not see any way for the tracked + // struct to have changed, its field values will not have + // been updated, even though in theory they would have + // the leaked value from the counter. + assert_eq!(result_in_rev_2, 0); +}