create a struct_map that encapsulates access

The internal API is now based around providing
references to the `TrackedStructValue`.

Documenting the invariants led to one interesting
case, which is that we sometimes verify a tracked
struct as not having changed (and even create an
`&`-ref to it!) but then re-execute the function
around it.

We now guarantee that, in this case, the data
does not change, even if it has leaked values.
This is required to ensure soundness.
Add a test case about it.
This commit is contained in:
Niko Matsakis 2024-04-13 07:34:25 -04:00
parent 20cb307301
commit ea1d452143
4 changed files with 449 additions and 116 deletions

View file

@ -1,23 +1,21 @@
use std::{fmt, hash::Hash, sync::Arc};
use crossbeam::queue::SegQueue;
use std::{fmt, hash::Hash};
use crate::{
cycle::CycleRecoveryStrategy,
hash::FxDashMap,
id::AsId,
ingredient::{fmt_index, Ingredient, IngredientRequiresReset},
ingredient_list::IngredientList,
interned::{InternedId, InternedIngredient},
key::{DatabaseKeyIndex, DependencyIndex},
plumbing::transmute_lifetime,
runtime::{local_state::QueryOrigin, Runtime},
salsa_struct::SalsaStructInDb,
Database, Durability, Event, IngredientIndex, Revision,
};
use self::struct_map::{StructMap, Update};
pub use self::tracked_field::TrackedFieldIngredient;
mod struct_map;
mod tracked_field;
// ANCHOR: Configuration
@ -97,7 +95,7 @@ where
{
interned: InternedIngredient<C::Id, TrackedStructKey>,
entity_data: Arc<FxDashMap<C::Id, Box<TrackedStructValue<C>>>>,
struct_map: struct_map::StructMap<C>,
/// A list of each tracked function `f` whose key is this
/// tracked struct.
@ -107,13 +105,6 @@ where
/// so they can remove any data tied to that instance.
dependent_fns: IngredientList,
/// When specific entities are deleted, their data is added
/// to this queue rather than being immediately freed. This is because we may have
/// references to that data floating about that are tied to the lifetime of some
/// `&db` reference. This queue itself is not freed until we have an `&mut db` reference,
/// guaranteeing that there are no more references to it.
deleted_entries: SegQueue<Box<TrackedStructValue<C>>>,
debug_name: &'static str,
}
@ -159,16 +150,6 @@ where
}
// ANCHOR_END: TrackedStructValue
impl<C> TrackedStructValue<C>
where
C: Configuration,
{
/// The id of this struct in the ingredient.
pub fn id(&self) -> C::Id {
self.id
}
}
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Copy, Clone)]
pub struct Disambiguator(pub u32);
@ -179,9 +160,8 @@ where
pub fn new(index: IngredientIndex, debug_name: &'static str) -> Self {
Self {
interned: InternedIngredient::new(index, debug_name),
entity_data: Default::default(),
struct_map: StructMap::new(),
dependent_fns: IngredientList::new(),
deleted_entries: SegQueue::default(),
debug_name,
}
}
@ -195,7 +175,7 @@ where
TrackedFieldIngredient {
ingredient_index: field_ingredient_index,
field_index,
entity_data: self.entity_data.clone(),
struct_map: self.struct_map.view(),
struct_debug_name: self.debug_name,
field_debug_name,
}
@ -208,7 +188,11 @@ where
}
}
pub fn new_struct(&self, runtime: &Runtime, fields: C::Fields) -> &TrackedStructValue<C> {
pub fn new_struct<'db>(
&'db self,
runtime: &'db Runtime,
fields: C::Fields,
) -> &'db TrackedStructValue<C> {
let data_hash = crate::hash::hash(&C::id_fields(&fields));
let (query_key, current_deps, disambiguator) = runtime.disambiguate_entity(
@ -225,56 +209,56 @@ where
let (id, new_id) = self.interned.intern_full(runtime, entity_key);
runtime.add_output(self.database_key_index(id).into());
let pointer: *const TrackedStructValue<C>;
let current_revision = runtime.current_revision();
if new_id {
let data = Box::new(TrackedStructValue {
id,
created_at: current_revision,
durability: current_deps.durability,
fields,
revisions: C::new_revisions(current_deps.changed_at),
});
// Keep a pointer into the box for later
pointer = &*data;
let old_value = self.entity_data.insert(id, data);
assert!(old_value.is_none());
} else {
let mut data = self.entity_data.get_mut(&id).unwrap();
let data = &mut *data;
// Keep a pointer into the box for later
pointer = &**data;
// SAFETY: We assert that the pointer to `data.revisions`
// is a pointer into the database referencing a value
// from a previous revision. As such, it continues to meet
// its validity invariant and any owned content also continues
// to meet its safety invariant.
unsafe {
C::update_fields(
current_revision,
&mut data.revisions,
std::ptr::addr_of_mut!(data.fields),
self.struct_map.insert(
runtime,
TrackedStructValue {
id,
created_at: current_revision,
durability: current_deps.durability,
fields,
);
}
if current_deps.durability < data.durability {
data.revisions = C::new_revisions(current_revision);
}
data.created_at = current_revision;
data.durability = current_deps.durability;
}
revisions: C::new_revisions(current_deps.changed_at),
},
)
} else {
match self.struct_map.update(runtime, id) {
Update::Current(r) => {
// All inputs up to this point were previously
// observed to be green and this struct was already
// verified. Therefore, the durability ought not to have
// changed (nor the field values, but the user could've
// done something stupid, so we can't *assert* this is true).
assert!(r.durability == current_deps.durability);
// Unsafety clause:
//
// * The box is owned by self and, although the box has been moved,
// the pointer is to the contents of the box, which have a stable
// address.
// * Values are only removed or altered when we have `&mut self`.
unsafe { transmute_lifetime(self, &*pointer) }
r
}
Update::Outdated(mut data_ref) => {
let data = &mut *data_ref;
// SAFETY: We assert that the pointer to `data.revisions`
// is a pointer into the database referencing a value
// from a previous revision. As such, it continues to meet
// its validity invariant and any owned content also continues
// to meet its safety invariant.
unsafe {
C::update_fields(
current_revision,
&mut data.revisions,
std::ptr::addr_of_mut!(data.fields),
fields,
);
}
if current_deps.durability < data.durability {
data.revisions = C::new_revisions(current_revision);
}
data.created_at = current_revision;
data.durability = current_deps.durability;
data_ref.freeze()
}
}
}
}
/// Deletes the given entities. This is used after a query `Q` executes and we can compare
@ -296,9 +280,7 @@ where
});
self.interned.delete_index(id);
if let Some((_, data)) = self.entity_data.remove(&id) {
self.deleted_entries.push(data);
}
self.struct_map.delete(id);
for dependent_fn in self.dependent_fns.iter() {
db.salsa_struct_deleted(dependent_fn, id.as_id());
@ -334,16 +316,16 @@ where
None
}
fn mark_validated_output(
&self,
db: &DB,
fn mark_validated_output<'db>(
&'db self,
db: &'db DB,
_executor: DatabaseKeyIndex,
output_key: Option<crate::Id>,
) {
let runtime = db.runtime();
let output_key = output_key.unwrap();
let output_key: C::Id = <C::Id>::from_id(output_key);
let mut entity = self.entity_data.get_mut(&output_key).unwrap();
entity.created_at = db.runtime().current_revision();
self.struct_map.validate(runtime, output_key);
}
fn remove_stale_output(
@ -362,7 +344,7 @@ where
fn reset_for_new_revision(&mut self) {
self.interned.clear_deleted_indices();
std::mem::take(&mut self.deleted_entries);
self.struct_map.drop_deleted_entries();
}
fn salsa_struct_deleted(&self, _db: &DB, _id: crate::Id) {
@ -380,3 +362,13 @@ where
{
const RESET_ON_NEW_REVISION: bool = true;
}
impl<C> TrackedStructValue<C>
where
C: Configuration,
{
/// The id of this struct in the ingredient.
pub fn id(&self) -> C::Id {
self.id
}
}

View file

@ -0,0 +1,259 @@
use std::{
ops::{Deref, DerefMut},
sync::Arc,
};
use crossbeam::queue::SegQueue;
use dashmap::mapref::one::RefMut;
use crate::{
hash::{FxDashMap, FxHasher},
plumbing::transmute_lifetime,
Runtime,
};
use super::{Configuration, TrackedStructValue};
/// Storage for the tracked struct values of configuration `C`, keyed by `C::Id`.
///
/// Each value lives in its own `Box`; `insert` relies on the box contents keeping
/// a stable address when the box itself is moved into the map. The map sits behind
/// an `Arc` so that `view` can hand out a [`StructMapView`] sharing the same storage.
pub(crate) struct StructMap<C>
where
C: Configuration,
{
/// Live entries, shared with every [`StructMapView`] created by `view`.
map: Arc<FxDashMap<C::Id, Box<TrackedStructValue<C>>>>,
/// When specific entities are deleted, their data is added
/// to this queue rather than being immediately freed. This is because we may have
/// references to that data floating about that are tied to the lifetime of some
/// `&db` reference. This queue itself is not freed until we have an `&mut db` reference,
/// guaranteeing that there are no more references to it.
deleted_entries: SegQueue<Box<TrackedStructValue<C>>>,
}
/// A secondary handle onto the map owned by a [`StructMap`], created by
/// `StructMap::view`.
///
/// It holds a clone of the same `Arc`, so it observes the same entries, but it
/// carries no `deleted_entries` queue and only exposes `get`.
pub(crate) struct StructMapView<C>
where
C: Configuration,
{
/// The same underlying map as the originating [`StructMap`].
map: Arc<FxDashMap<C::Id, Box<TrackedStructValue<C>>>>,
}
/// Return value for [`StructMap`][]'s `update` method.
///
/// Which variant is returned depends on whether the entry's `created_at`
/// already equals the runtime's current revision when `update` is called.
pub(crate) enum Update<'db, C>
where
C: Configuration,
{
/// Indicates that the given struct has not yet been verified in this revision.
/// The [`UpdateRef`][] gives mutable access to the field contents so that
/// its fields can be compared and updated.
///
/// `update` has already stamped `created_at` with the current revision by the
/// time this variant is returned.
Outdated(UpdateRef<'db, C>),
/// Indicates that we have already verified that all the inputs accessed prior
/// to this struct creation were up-to-date, and therefore the field contents
/// ought not to have changed (barring user error). Returns a shared reference
/// because caller cannot safely modify fields at this point.
Current(&'db TrackedStructValue<C>),
}
impl<C> StructMap<C>
where
C: Configuration,
{
/// Create an empty map with an empty queue of deleted entries.
pub fn new() -> Self {
Self {
map: Arc::new(FxDashMap::default()),
deleted_entries: SegQueue::new(),
}
}
/// Get a secondary view onto this struct-map that can be used to fetch entries.
///
/// The view clones the `Arc`, not the map contents, so both handles see the same entries.
pub fn view(&self) -> StructMapView<C> {
StructMapView {
map: self.map.clone(),
}
}
/// Insert the given tracked struct value into the map and return a shared
/// reference to the stored value.
///
/// # Panics
///
/// * If value with same `value.id` is already present in the map.
/// * If value not created in current revision.
pub fn insert<'db>(
&'db self,
runtime: &'db Runtime,
value: TrackedStructValue<C>,
) -> &TrackedStructValue<C> {
assert_eq!(value.created_at, runtime.current_revision());
let boxed_value = Box::new(value);
// Take the address of the box *contents* before the box is moved into the map.
let pointer = std::ptr::addr_of!(*boxed_value);
let old_value = self.map.insert(boxed_value.id, boxed_value);
// NOTE: by the time this assertion runs, the previous entry has already been replaced.
assert!(old_value.is_none()); // ...strictly speaking we probably need to abort here
// Unsafety clause:
//
// * The box is owned by self and, although the box has been moved,
// the pointer is to the contents of the box, which have a stable
// address.
// * Values are only removed or altered when we have `&mut self`.
//
// NOTE(review): `delete` and `update` take `&self`. `delete` only moves the box
// into `deleted_entries` (the allocation is dropped in `drop_deleted_entries`,
// which takes `&mut self`), and `update` hands out mutable access only when
// `created_at` is older than the current revision -- confirm that this is the
// intended reading of the second bullet.
unsafe { transmute_lifetime(self, &*pointer) }
}
/// Mark the entry for `id` as valid in the current revision by setting its
/// `created_at` to the runtime's current revision. The fields are not touched.
///
/// # Panics
///
/// * If the value is not present in the map.
/// * If `created_at` is not strictly older than the current revision.
pub fn validate<'db>(&'db self, runtime: &'db Runtime, id: C::Id) {
let mut data = self.map.get_mut(&id).unwrap();
// Never update a struct twice in the same revision.
let current_revision = runtime.current_revision();
assert!(data.created_at < current_revision);
data.created_at = current_revision;
}
/// Prepare the entry for `id` to be updated in the current revision.
///
/// * If the entry's `created_at` is not the current revision, it is stamped with the
/// current revision and [`Update::Outdated`] is returned. That variant holds the
/// guard obtained from `get_mut`, i.e. a write lock for the duration of the
/// returned value.
/// * If the entry's `created_at` already equals the current revision,
/// [`Update::Current`] is returned with a shared reference and nothing is modified.
///
/// # Panics
///
/// * If the value is not present in the map.
pub fn update<'db>(&'db self, runtime: &'db Runtime, id: C::Id) -> Update<'db, C> {
let mut data = self.map.get_mut(&id).unwrap();
// Never update a struct twice in the same revision.
let current_revision = runtime.current_revision();
// Subtle: it's possible that this struct was already validated
// in this revision. What can happen (e.g., in the test
// `test_run_5_then_20` in `specify_tracked_fn_in_rev_1_but_not_2.rs`)
// is that
//
// * Revision 1:
// * Tracked function F creates tracked struct S
// * F reads input I
//
// In Revision 2, I is changed, and F is re-executed.
// We try to validate F's inputs/outputs, which is the list [output: S, input: I].
// As no inputs have changed by the time we reach S, we mark it as verified.
// But then input I is seen to have changed, and so we re-execute F.
// Note that we *know* that S will have the same value (barring program bugs).
//
// Further complicating things: it is possible that F calls F2
// and gives it (e.g.) S as one of its arguments. Validating F2 may cause F2 to
// re-execute which means that it may indeed have read from S's fields
// during the current revision and thus obtained an `&` reference to those fields
// that is still live.
//
// For this reason, we return `Update::Current` with a shared reference in this
// case (never mutable access), ensuring that the calling code cannot violate
// that `&`-reference.
if data.created_at == current_revision {
// Release the `get_mut` guard before `get_from_map` looks the entry up again.
drop(data);
return Update::Current(&Self::get_from_map(&self.map, runtime, id));
}
data.created_at = current_revision;
Update::Outdated(UpdateRef { guard: data })
}
/// Helper function, provides shared functionality for [`StructMapView`][]
/// (it is also used by `update` for the already-current case).
///
/// # Panics
///
/// * If the value is not present in the map.
/// * If the value has not been updated in this revision.
fn get_from_map<'db>(
map: &'db FxDashMap<C::Id, Box<TrackedStructValue<C>>>,
runtime: &'db Runtime,
id: C::Id,
) -> &'db TrackedStructValue<C> {
let data = map.get(&id).unwrap();
// Shadowing keeps the guard above alive until the end of this function.
let data: &TrackedStructValue<C> = &**data;
// Before we drop the lock, check that the value has
// been updated in this revision. This is what allows us to return a
// `&'db` reference that outlives the lock guard.
let current_revision = runtime.current_revision();
let created_at = data.created_at;
assert!(
created_at == current_revision,
"access to tracked struct from previous revision"
);
// Unsafety clause:
//
// * Value will not be updated again in this revision,
// and revision will not change so long as runtime is shared
// * We only remove values from the map when we have `&mut self`
//
// NOTE(review): `delete` removes the map entry through `&self`, but it moves
// the box into `deleted_entries` instead of dropping it -- confirm that this is
// what the second bullet relies on.
unsafe { transmute_lifetime(map, data) }
}
/// Remove the entry for `id` from the map. Does nothing if there is no such entry.
///
/// NB. the data won't actually be freed until `drop_deleted_entries` is called.
pub fn delete(&self, id: C::Id) {
if let Some((_, data)) = self.map.remove(&id) {
// Park the box instead of dropping it; see the `deleted_entries` field.
self.deleted_entries.push(data);
}
}
/// Drop all entries deleted until now.
pub fn drop_deleted_entries(&mut self) {
// Swap in a fresh default queue; the old queue is returned and dropped right here.
std::mem::take(&mut self.deleted_entries);
}
}
impl<C> StructMapView<C>
where
C: Configuration,
{
/// Get a pointer to the data for the given `id`.
///
/// # Panics
///
/// * If the value is not present in the map.
/// * If the value has not been updated in this revision.
pub fn get<'db>(&'db self, runtime: &'db Runtime, id: C::Id) -> &'db TrackedStructValue<C> {
StructMap::get_from_map(&self.map, runtime, id)
}
}
/// A mutable reference to the data for a single struct.
/// Can be "frozen" to yield an `&` that will remain valid
/// until the end of the revision.
///
/// Constructed by `StructMap::update` for the `Update::Outdated` case.
pub(crate) struct UpdateRef<'db, C>
where
C: Configuration,
{
/// The guard returned by the map's `get_mut`; held for as long as this value lives.
guard: RefMut<'db, C::Id, Box<TrackedStructValue<C>>, FxHasher>,
}
impl<'db, C> UpdateRef<'db, C>
where
C: Configuration,
{
/// Finalize this update, freezing the value for the rest of the revision.
///
/// Consumes `self`, so the guard is dropped when this function returns; the
/// returned reference deliberately outlives it.
pub fn freeze(self) -> &'db TrackedStructValue<C> {
// Unsafety clause:
//
// Same reasoning as in `StructMap::get_from_map` (which backs
// `StructMapView::get`): the value is not updated again in this revision,
// and entries are not freed while shared references may exist.
let data: &TrackedStructValue<C> = &*self.guard;
// `dummy` exists only to name the `'db` lifetime for `transmute_lifetime`
// (presumably it re-types `data` with the lifetime of its first argument,
// as in `StructMap::insert` -- its definition is not in this file).
let dummy: &'db () = &();
unsafe { transmute_lifetime(dummy, data) }
}
}
impl<C: Configuration> Deref for UpdateRef<'_, C> {
    type Target = TrackedStructValue<C>;

    /// Shared access to the struct value behind the guard.
    fn deref(&self) -> &Self::Target {
        // First `*` goes through the guard to the `Box`, second through the box.
        &**self.guard
    }
}
impl<C: Configuration> DerefMut for UpdateRef<'_, C> {
    /// Mutable access to the struct value behind the guard.
    fn deref_mut(&mut self) -> &mut Self::Target {
        // First `*` goes through the guard to the `Box`, second through the box.
        &mut **self.guard
    }
}

View file

@ -1,16 +1,11 @@
use std::sync::Arc;
use crate::{
hash::FxDashMap,
id::AsId,
ingredient::{Ingredient, IngredientRequiresReset},
key::DependencyIndex,
plumbing::transmute_lifetime,
tracked_struct::TrackedStructValue,
IngredientIndex, Runtime,
Database, IngredientIndex, Runtime,
};
use super::Configuration;
use super::{struct_map::StructMapView, Configuration};
/// Created for each tracked struct.
/// This ingredient only stores the "id" fields.
@ -27,7 +22,7 @@ where
/// Index of this ingredient in the database (used to construct database-ids, etc).
pub(super) ingredient_index: IngredientIndex,
pub(super) field_index: u32,
pub(super) entity_data: Arc<FxDashMap<C::Id, Box<TrackedStructValue<C>>>>,
pub(super) struct_map: StructMapView<C>,
pub(super) struct_debug_name: &'static str,
pub(super) field_debug_name: &'static str,
}
@ -40,16 +35,7 @@ where
/// Note that this function returns the entire tuple of value fields.
/// The caller is responsible for selecting the appropriate element.
pub fn field<'db>(&'db self, runtime: &'db Runtime, id: C::Id) -> &'db C::Fields {
let Some(data) = self.entity_data.get(&id) else {
panic!("no data found for entity id {id:?}");
};
let current_revision = runtime.current_revision();
let created_at = data.created_at;
assert!(
created_at == current_revision,
"access to tracked struct from previous revision"
);
let data = self.struct_map.get(runtime, id);
let changed_at = C::revision(&data.revisions, self.field_index);
@ -62,15 +48,13 @@ where
changed_at,
);
// Unsafety clause:
//
// * Values are only removed or altered when we have `&mut self`
unsafe { transmute_lifetime(self, &data.fields) }
&data.fields
}
}
impl<DB: ?Sized, C> Ingredient<DB> for TrackedFieldIngredient<C>
where
DB: Database,
C: Configuration,
{
fn ingredient_index(&self) -> IngredientIndex {
@ -81,22 +65,17 @@ where
crate::cycle::CycleRecoveryStrategy::Panic
}
fn maybe_changed_after(
&self,
_db: &DB,
fn maybe_changed_after<'db>(
&'db self,
db: &'db DB,
input: crate::key::DependencyIndex,
revision: crate::Revision,
) -> bool {
let runtime = db.runtime();
let id = <C::Id>::from_id(input.key_index.unwrap());
match self.entity_data.get(&id) {
Some(data) => {
let field_changed_at = C::revision(&data.revisions, self.field_index);
field_changed_at > revision
}
None => {
panic!("no data found for field `{id:?}`");
}
}
let data = self.struct_map.get(runtime, id);
let field_changed_at = C::revision(&data.revisions, self.field_index);
field_changed_at > revision
}
fn origin(&self, _key_index: crate::Id) -> Option<crate::runtime::local_state::QueryOrigin> {

View file

@ -0,0 +1,103 @@
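//! Test that a tracked struct which salsa has already verified in the current
//! revision keeps its old field values when the creating function re-executes,
//! even if the function leaks a different (untracked) value into it.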
use std::cell::Cell;
use expect_test::expect;
use salsa::DebugWithDb;
use salsa_2022_tests::{HasLogger, Logger};
use test_log::test;
thread_local! {
// Untracked state read by `function` below; salsa has no way to see changes to it.
static COUNTER: Cell<usize> = Cell::new(0);
}
// Jar listing every salsa item defined in this test file.
#[salsa::jar(db = Db)]
struct Jar(MyInput, MyTracked, function);
// Database trait used by the tracked function: jar access plus the test logger.
trait Db: salsa::DbWithJar<Jar> + HasLogger {}
// Concrete database for the test: salsa storage plus a logger that records events.
#[salsa::db(Jar)]
#[derive(Default)]
struct Database {
storage: salsa::Storage<Self>,
logger: Logger,
}
impl salsa::Database for Database {
// Record every salsa event in the log so the test can assert on what ran.
fn salsa_event(&self, event: salsa::Event) {
self.push_log(format!("{:?}", event.debug(self)));
}
}
// `Db` has no methods of its own; this just opts `Database` into the trait.
impl Db for Database {}
impl HasLogger for Database {
// Expose the logger field to the `HasLogger` helpers.
fn logger(&self) -> &Logger {
&self.logger
}
}
// Input with two fields: `function` reads `field1` before creating the tracked
// struct and `field2` after it.
#[salsa::input]
struct MyInput {
field1: u32,
field2: u32,
}
// Tracked struct whose single field is filled from the untracked `COUNTER`.
#[salsa::tracked]
struct MyTracked {
counter: usize,
}
// Deliberately misbehaving tracked function. The order of operations matters:
// `field1` is read, then the tracked struct is created from the untracked counter,
// and only then is `field2` read. Returns the `counter` field of the tracked struct.
#[salsa::tracked]
fn function(db: &dyn Db, input: MyInput) -> usize {
// Read input 1
let _field1 = input.field1(db);
// **BAD:** Leak in the value of the counter non-deterministically
let counter = COUNTER.with(|c| c.get());
// Create the tracked struct, which (from salsa's POV), only depends on field1;
// but which actually depends on the leaked value.
let tracked = MyTracked::new(db, counter);
// Read input 2. This will cause us to re-execute on revision 2.
let _field2 = input.field2(db);
tracked.counter(db)
}
// Runs `function` in two revisions. Between them, `field2` and the untracked
// `COUNTER` are both changed. The function re-executes in revision 2 (see the
// `WillExecute` log entry), yet the result must still be the revision-1 value.
#[test]
fn test_leaked_inputs_ignored() {
let mut db = Database::default();
let input = MyInput::new(&db, 10, 20);
let result_in_rev_1 = function(&db, input);
db.assert_logs(expect![[r#"
[
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }",
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: function(0) } }",
]"#]]);
// `COUNTER` starts at 0, so that is what the tracked struct holds in revision 1.
assert_eq!(result_in_rev_1, 0);
// Modify field2 so that `function` is seen to have changed --
// but only *after* the tracked struct is created.
input.set_field2(&mut db).to(30);
// Also modify the thread-local counter
COUNTER.with(|c| c.set(100));
let result_in_rev_2 = function(&db, input);
db.assert_logs(expect![[r#"
[
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillCheckCancellation }",
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: function(0) } }",
]"#]]);
// Because salsa did not see any way for the tracked
// struct to have changed, its field values will not have
// been updated, even though in theory they would have
// the leaked value from the counter.
assert_eq!(result_in_rev_2, 0);
}