diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index e1165516..58d88875 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -265,24 +265,26 @@ macro_rules! setup_tracked_fn { } } } - let result = $zalsa::macro_if! { - if $needs_interner { - { - let key = $Configuration::intern_ingredient($db).intern_id($db.as_dyn_database(), ($($input_id),*)); - $Configuration::fn_ingredient($db).fetch($db, key) + $zalsa::attach($db, || { + let result = $zalsa::macro_if! { + if $needs_interner { + { + let key = $Configuration::intern_ingredient($db).intern_id($db.as_dyn_database(), ($($input_id),*)); + $Configuration::fn_ingredient($db).fetch($db, key) + } + } else { + $Configuration::fn_ingredient($db).fetch($db, $zalsa::AsId::as_id(&($($input_id),*))) } - } else { - $Configuration::fn_ingredient($db).fetch($db, $zalsa::AsId::as_id(&($($input_id),*))) - } - }; + }; - $zalsa::macro_if! { - if $return_ref { - result - } else { - <$output_ty as std::clone::Clone>::clone(result) + $zalsa::macro_if! { + if $return_ref { + result + } else { + <$output_ty as std::clone::Clone>::clone(result) + } } - } + }) } }; } diff --git a/src/accumulator.rs b/src/accumulator.rs index dc114bb2..47133790 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -10,7 +10,7 @@ use crate::{ hash::FxDashMap, ingredient::{fmt_index, Ingredient, Jar}, key::DependencyIndex, - local_state::{self, LocalState, QueryOrigin}, + local_state::{LocalState, QueryOrigin}, storage::IngredientIndex, Database, DatabaseKeyIndex, Event, EventKind, Id, Revision, }; @@ -80,32 +80,30 @@ impl IngredientImpl { } pub fn push(&self, db: &dyn crate::Database, value: A) { - local_state::attach(db, |state| { - let current_revision = db.zalsa().current_revision(); - let (active_query, _) = match state.active_query() { - Some(pair) => pair, - None => { - panic!("cannot accumulate values outside of an active query") - } - }; - - let mut accumulated_values = - self.map.entry(active_query).or_insert(AccumulatedValues { - values: vec![], - produced_at: current_revision, - }); - - // When we call `push' in a query, we will add the accumulator to the output of the query. - // If we find here that this accumulator is not the output of the query, - // we can say that the accumulated values we stored for this query is out of date. - if !state.is_output_of_active_query(self.dependency_index()) { - accumulated_values.values.truncate(0); - accumulated_values.produced_at = current_revision; + let state = db.zalsa_local(); + let current_revision = db.zalsa().current_revision(); + let (active_query, _) = match state.active_query() { + Some(pair) => pair, + None => { + panic!("cannot accumulate values outside of an active query") } + }; - state.add_output(self.dependency_index()); - accumulated_values.values.push(value); - }) + let mut accumulated_values = self.map.entry(active_query).or_insert(AccumulatedValues { + values: vec![], + produced_at: current_revision, + }); + + // When we call `push' in a query, we will add the accumulator to the output of the query. + // If we find here that this accumulator is not the output of the query, + // we can say that the accumulated values we stored for this query is out of date. + if !state.is_output_of_active_query(self.dependency_index()) { + accumulated_values.values.truncate(0); + accumulated_values.produced_at = current_revision; + } + + state.add_output(self.dependency_index()); + accumulated_values.values.push(value); } pub(crate) fn produced_by( diff --git a/src/attach.rs b/src/attach.rs new file mode 100644 index 00000000..3dcf9d11 --- /dev/null +++ b/src/attach.rs @@ -0,0 +1,100 @@ +use std::{cell::Cell, ptr::NonNull}; + +use crate::Database; + +thread_local! { + /// The thread-local state salsa requires for a given thread + static ATTACHED: Attached = const { Attached::new() } +} + +/// State that is specific to a single execution thread. +/// +/// Internally, this type uses ref-cells. +/// +/// **Note also that all mutations to the database handle (and hence +/// to the local-state) must be undone during unwinding.** +struct Attached { + /// Pointer to the currently attached database. + database: Cell>>, +} + +impl Attached { + const fn new() -> Self { + Self { + database: Cell::new(None), + } + } + + fn attach(&self, db: &Db, op: impl FnOnce() -> R) -> R + where + Db: ?Sized + Database, + { + struct DbGuard<'s> { + state: Option<&'s Attached>, + } + + impl<'s> DbGuard<'s> { + fn new(attached: &'s Attached, db: &dyn Database) -> Self { + if let Some(current_db) = attached.database.get() { + let new_db = NonNull::from(db); + + // Already attached? Assert that the database has not changed. + // NOTE: It's important to use `addr_eq` here because `NonNull::eq` + // not only compares the address but also the type's metadata. + if !std::ptr::addr_eq(current_db.as_ptr(), new_db.as_ptr()) { + panic!( + "Cannot change database mid-query. current: {current_db:?}, new: {new_db:?}", + ); + } + + Self { state: None } + } else { + // Otherwise, set the database. + attached.database.set(Some(NonNull::from(db))); + Self { + state: Some(attached), + } + } + } + } + + impl Drop for DbGuard<'_> { + fn drop(&mut self) { + // Reset database to null if we did anything in `DbGuard::new`. + if let Some(attached) = self.state { + attached.database.set(None); + } + } + } + + let _guard = DbGuard::new(self, db.as_dyn_database()); + op() + } + + /// Access the "attached" database. Returns `None` if no database is attached. + /// Databases are attached with `attach_database`. + fn with(&self, op: impl FnOnce(&dyn Database) -> R) -> Option { + if let Some(db) = self.database.get() { + // SAFETY: We always attach the database in for the entire duration of a function, + // so it cannot become "unattached" while this function is running. + Some(op(unsafe { db.as_ref() })) + } else { + None + } + } +} + +/// Attach the database to the current thread and execute `op`. +/// Panics if a different database has already been attached. +pub fn attach(db: &Db, op: impl FnOnce() -> R) -> R +where + Db: ?Sized + Database, +{ + ATTACHED.with(|a| a.attach(db, op)) +} + +/// Access the "attached" database. Returns `None` if no database is attached. +/// Databases are attached with `attach_database`. +pub fn with_attached_database(op: impl FnOnce(&dyn Database) -> R) -> Option { + ATTACHED.with(|a| a.with(op)) +} diff --git a/src/cycle.rs b/src/cycle.rs index 4a8a56f4..44558b4a 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -1,4 +1,4 @@ -use crate::{key::DatabaseKeyIndex, local_state, Database}; +use crate::{key::DatabaseKeyIndex, Database}; use std::{panic::AssertUnwindSafe, sync::Arc}; /// Captures the participants of a cycle that occurred when executing a query. @@ -74,7 +74,7 @@ impl Cycle { impl std::fmt::Debug for Cycle { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - local_state::with_attached_database(|db| { + crate::attach::with_attached_database(|db| { f.debug_struct("UnexpectedCycle") .field("all_participants", &self.all_participants(db)) .field("unexpected_participants", &self.unexpected_participants(db)) diff --git a/src/database.rs b/src/database.rs index e79e94ea..3643c6dd 100644 --- a/src/database.rs +++ b/src/database.rs @@ -3,7 +3,8 @@ use std::{any::Any, panic::RefUnwindSafe, sync::Arc}; use parking_lot::{Condvar, Mutex}; use crate::{ - self as salsa, local_state, + self as salsa, + local_state::{self, LocalState}, storage::{Zalsa, ZalsaImpl}, Durability, Event, EventKind, Revision, }; @@ -16,7 +17,7 @@ use crate::{ /// This trait can only safely be implemented by Salsa's [`DatabaseImpl`][] type. /// FIXME: Document better the unsafety conditions we guarantee. #[salsa_macros::db] -pub unsafe trait Database: AsDynDatabase + Any { +pub unsafe trait Database: Send + AsDynDatabase + Any { /// This function is invoked by the salsa runtime at various points during execution. /// You can customize what happens by implementing the [`UserData`][] trait. /// By default, the event is logged at level debug using tracing facade. @@ -45,9 +46,8 @@ pub unsafe trait Database: AsDynDatabase + Any { /// revision. fn report_untracked_read(&self) { let db = self.as_dyn_database(); - local_state::attach(db, |state| { - state.report_untracked_read(db.zalsa().current_revision()) - }) + let zalsa_local = db.zalsa_local(); + zalsa_local.report_untracked_read(db.zalsa().current_revision()) } /// Execute `op` with the database in thread-local storage for debug print-outs. @@ -55,7 +55,7 @@ pub unsafe trait Database: AsDynDatabase + Any { where Self: Sized, { - local_state::attach(self, |_state| op(self)) + crate::attach::attach(self, || op(self)) } /// Plumbing method: Access the internal salsa methods. @@ -68,6 +68,10 @@ pub unsafe trait Database: AsDynDatabase + Any { /// This can lead to deadlock! #[doc(hidden)] fn zalsa_mut(&mut self) -> &mut dyn Zalsa; + + /// Access the thread-local state associated with this database + #[doc(hidden)] + fn zalsa_local(&self) -> &LocalState; } /// Upcast to a `dyn Database`. @@ -113,6 +117,9 @@ pub struct DatabaseImpl { /// Coordination data for cancellation of other handles when `zalsa_mut` is called. /// This could be stored in ZalsaImpl but it makes things marginally cleaner to keep it separate. coordinate: Arc, + + /// Per-thread state + zalsa_local: local_state::LocalState, } impl Default for DatabaseImpl { @@ -141,6 +148,7 @@ impl DatabaseImpl { clones: Mutex::new(1), cvar: Default::default(), }), + zalsa_local: LocalState::new(), } } @@ -201,6 +209,10 @@ unsafe impl Database for DatabaseImpl { zalsa_mut } + fn zalsa_local(&self) -> &LocalState { + &self.zalsa_local + } + // Report a salsa event. fn salsa_event(&self, event: &dyn Fn() -> Event) { U::salsa_event(self, event) @@ -214,6 +226,7 @@ impl Clone for DatabaseImpl { Self { zalsa_impl: self.zalsa_impl.clone(), coordinate: Arc::clone(&self.coordinate), + zalsa_local: LocalState::new(), } } } @@ -229,7 +242,7 @@ impl Drop for DatabaseImpl { } } -pub trait UserData: Any + Sized { +pub trait UserData: Any + Sized + Send + Sync { /// Callback invoked by the [`Database`][] at key points during salsa execution. /// By overriding this method, you can inject logging or other custom behavior. /// diff --git a/src/function/accumulated.rs b/src/function/accumulated.rs index 98d11eaa..62c85930 100644 --- a/src/function/accumulated.rs +++ b/src/function/accumulated.rs @@ -1,4 +1,4 @@ -use crate::{accumulator, hash::FxHashSet, local_state, Database, DatabaseKeyIndex, Id}; +use crate::{accumulator, hash::FxHashSet, Database, DatabaseKeyIndex, Id}; use super::{Configuration, IngredientImpl}; @@ -12,44 +12,41 @@ where where A: accumulator::Accumulator, { - local_state::attach(db, |local_state| { - let zalsa = db.zalsa(); - let current_revision = zalsa.current_revision(); + let zalsa = db.zalsa(); + let zalsa_local = db.zalsa_local(); + let current_revision = zalsa.current_revision(); - let Some(accumulator) = >::from_db(db) else { - return vec![]; - }; - let mut output = vec![]; + let Some(accumulator) = >::from_db(db) else { + return vec![]; + }; + let mut output = vec![]; - // First ensure the result is up to date - self.fetch(db, key); + // First ensure the result is up to date + self.fetch(db, key); - let db_key = self.database_key_index(key); - let mut visited: FxHashSet = FxHashSet::default(); - let mut stack: Vec = vec![db_key]; + let db_key = self.database_key_index(key); + let mut visited: FxHashSet = FxHashSet::default(); + let mut stack: Vec = vec![db_key]; - while let Some(k) = stack.pop() { - if visited.insert(k) { - accumulator.produced_by(current_revision, local_state, k, &mut output); + while let Some(k) = stack.pop() { + if visited.insert(k) { + accumulator.produced_by(current_revision, zalsa_local, k, &mut output); - let origin = zalsa - .lookup_ingredient(k.ingredient_index) - .origin(k.key_index); - let inputs = origin.iter().flat_map(|origin| origin.inputs()); - // Careful: we want to push in execution order, so reverse order to - // ensure the first child that was executed will be the first child popped - // from the stack. - stack.extend( - inputs - .flat_map(|input| { - TryInto::::try_into(input).into_iter() - }) - .rev(), - ); - } + let origin = zalsa + .lookup_ingredient(k.ingredient_index) + .origin(k.key_index); + let inputs = origin.iter().flat_map(|origin| origin.inputs()); + // Careful: we want to push in execution order, so reverse order to + // ensure the first child that was executed will be the first child popped + // from the stack. + stack.extend( + inputs + .flat_map(|input| TryInto::::try_into(input).into_iter()) + .rev(), + ); } + } - output - }) + output } } diff --git a/src/function/fetch.rs b/src/function/fetch.rs index f204145f..4e3018b9 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -1,9 +1,7 @@ use arc_swap::Guard; use crate::{ - local_state::{self, LocalState}, - runtime::StampedValue, - AsDynDatabase as _, Database as _, Id, + local_state::LocalState, runtime::StampedValue, AsDynDatabase as _, Database as _, Id, }; use super::{Configuration, IngredientImpl}; @@ -13,27 +11,26 @@ where C: Configuration, { pub fn fetch<'db>(&'db self, db: &'db C::DbView, key: Id) -> &C::Output<'db> { - local_state::attach(db.as_dyn_database(), |local_state| { - local_state.unwind_if_revision_cancelled(db.as_dyn_database()); + let zalsa_local = db.zalsa_local(); + zalsa_local.unwind_if_revision_cancelled(db.as_dyn_database()); - let StampedValue { - value, - durability, - changed_at, - } = self.compute_value(db, local_state, key); + let StampedValue { + value, + durability, + changed_at, + } = self.compute_value(db, zalsa_local, key); - if let Some(evicted) = self.lru.record_use(key) { - self.evict(evicted); - } + if let Some(evicted) = self.lru.record_use(key) { + self.evict(evicted); + } - local_state.report_tracked_read( - self.database_key_index(key).into(), - durability, - changed_at, - ); + zalsa_local.report_tracked_read( + self.database_key_index(key).into(), + durability, + changed_at, + ); - value - }) + value } #[inline] diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 15a677d5..0d3fc3d4 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -2,7 +2,7 @@ use arc_swap::Guard; use crate::{ key::DatabaseKeyIndex, - local_state::{self, ActiveQueryGuard, EdgeKind, LocalState, QueryOrigin}, + local_state::{ActiveQueryGuard, EdgeKind, LocalState, QueryOrigin}, runtime::StampedValue, storage::Zalsa, AsDynDatabase as _, Database, Id, Revision, @@ -20,36 +20,32 @@ where key: Id, revision: Revision, ) -> bool { - local_state::attach(db.as_dyn_database(), |local_state| { - let zalsa = db.zalsa(); - local_state.unwind_if_revision_cancelled(db.as_dyn_database()); + let zalsa_local = db.zalsa_local(); + let zalsa = db.zalsa(); + zalsa_local.unwind_if_revision_cancelled(db.as_dyn_database()); - loop { - let database_key_index = self.database_key_index(key); + loop { + let database_key_index = self.database_key_index(key); - tracing::debug!( - "{database_key_index:?}: maybe_changed_after(revision = {revision:?})" - ); + tracing::debug!("{database_key_index:?}: maybe_changed_after(revision = {revision:?})"); - // Check if we have a verified version: this is the hot path. - let memo_guard = self.memo_map.get(key); - if let Some(memo) = &memo_guard { - if self.shallow_verify_memo(db, zalsa, database_key_index, memo) { - return memo.revisions.changed_at > revision; - } - drop(memo_guard); // release the arc-swap guard before cold path - if let Some(mcs) = self.maybe_changed_after_cold(db, local_state, key, revision) - { - return mcs; - } else { - // We failed to claim, have to retry. - } - } else { - // No memo? Assume has changed. - return true; + // Check if we have a verified version: this is the hot path. + let memo_guard = self.memo_map.get(key); + if let Some(memo) = &memo_guard { + if self.shallow_verify_memo(db, zalsa, database_key_index, memo) { + return memo.revisions.changed_at > revision; } + drop(memo_guard); // release the arc-swap guard before cold path + if let Some(mcs) = self.maybe_changed_after_cold(db, zalsa_local, key, revision) { + return mcs; + } else { + // We failed to claim, have to retry. + } + } else { + // No memo? Assume has changed. + return true; } - }) + } } fn maybe_changed_after_cold<'db>( diff --git a/src/function/specify.rs b/src/function/specify.rs index d8d5dea8..98945dc5 100644 --- a/src/function/specify.rs +++ b/src/function/specify.rs @@ -1,7 +1,7 @@ use crossbeam::atomic::AtomicCell; use crate::{ - local_state::{self, QueryOrigin, QueryRevisions}, + local_state::{QueryOrigin, QueryRevisions}, tracked_struct::TrackedStructInDb, AsDynDatabase as _, Database, DatabaseKeyIndex, Id, }; @@ -18,76 +18,74 @@ where where C::Input<'db>: TrackedStructInDb, { - local_state::attach(db.as_dyn_database(), |state| { - let (active_query_key, current_deps) = match state.active_query() { - Some(v) => v, - None => panic!("can only use `specify` inside a tracked function"), - }; + let zalsa_local = db.zalsa_local(); - // `specify` only works if the key is a tracked struct created in the current query. - // - // The reason is this. We want to ensure that the same result is reached regardless of - // the "path" that the user takes through the execution graph. - // If you permit values to be specified from other queries, you can have a situation like this: - // * Q0 creates the tracked struct T0 - // * Q1 specifies the value for F(T0) - // * Q2 invokes F(T0) - // * Q3 invokes Q1 and then Q2 - // * Q4 invokes Q2 and then Q1 - // - // Now, if We invoke Q3 first, We get one result for Q2, but if We invoke Q4 first, We get a different value. That's no good. - let database_key_index = >::database_key_index(db.as_dyn_database(), key); - let dependency_index = database_key_index.into(); - if !state.is_output_of_active_query(dependency_index) { - panic!( - "can only use `specify` on salsa structs created during the current tracked fn" - ); - } + let (active_query_key, current_deps) = match zalsa_local.active_query() { + Some(v) => v, + None => panic!("can only use `specify` inside a tracked function"), + }; - // Subtle: we treat the "input" to a set query as if it were - // volatile. - // - // The idea is this. You have the current query C that - // created the entity E, and it is setting the value F(E) of the function F. - // When some other query R reads the field F(E), in order to have obtained - // the entity E, it has to have executed the query C. - // - // This will have forced C to either: - // - // - not create E this time, in which case R shouldn't have it (some kind of leak has occurred) - // - assign a value to F(E), in which case `verified_at` will be the current revision and `changed_at` will be updated appropriately - // - NOT assign a value to F(E), in which case we need to re-execute the function (which typically panics). - // - // So, ruling out the case of a leak having occurred, that means that the reader R will either see: - // - // - a result that is verified in the current revision, because it was set, which will use the set value - // - a result that is NOT verified and has untracked inputs, which will re-execute (and likely panic) + // `specify` only works if the key is a tracked struct created in the current query. + // + // The reason is this. We want to ensure that the same result is reached regardless of + // the "path" that the user takes through the execution graph. + // If you permit values to be specified from other queries, you can have a situation like this: + // * Q0 creates the tracked struct T0 + // * Q1 specifies the value for F(T0) + // * Q2 invokes F(T0) + // * Q3 invokes Q1 and then Q2 + // * Q4 invokes Q2 and then Q1 + // + // Now, if We invoke Q3 first, We get one result for Q2, but if We invoke Q4 first, We get a different value. That's no good. + let database_key_index = >::database_key_index(db.as_dyn_database(), key); + let dependency_index = database_key_index.into(); + if !zalsa_local.is_output_of_active_query(dependency_index) { + panic!("can only use `specify` on salsa structs created during the current tracked fn"); + } - let revision = db.zalsa().current_revision(); - let mut revisions = QueryRevisions { - changed_at: current_deps.changed_at, - durability: current_deps.durability, - origin: QueryOrigin::Assigned(active_query_key), - }; + // Subtle: we treat the "input" to a set query as if it were + // volatile. + // + // The idea is this. You have the current query C that + // created the entity E, and it is setting the value F(E) of the function F. + // When some other query R reads the field F(E), in order to have obtained + // the entity E, it has to have executed the query C. + // + // This will have forced C to either: + // + // - not create E this time, in which case R shouldn't have it (some kind of leak has occurred) + // - assign a value to F(E), in which case `verified_at` will be the current revision and `changed_at` will be updated appropriately + // - NOT assign a value to F(E), in which case we need to re-execute the function (which typically panics). + // + // So, ruling out the case of a leak having occurred, that means that the reader R will either see: + // + // - a result that is verified in the current revision, because it was set, which will use the set value + // - a result that is NOT verified and has untracked inputs, which will re-execute (and likely panic) - if let Some(old_memo) = self.memo_map.get(key) { - self.backdate_if_appropriate(&old_memo, &mut revisions, &value); - self.diff_outputs(db, database_key_index, &old_memo, &revisions); - } + let revision = db.zalsa().current_revision(); + let mut revisions = QueryRevisions { + changed_at: current_deps.changed_at, + durability: current_deps.durability, + origin: QueryOrigin::Assigned(active_query_key), + }; - let memo = Memo { - value: Some(value), - verified_at: AtomicCell::new(revision), - revisions, - }; + if let Some(old_memo) = self.memo_map.get(key) { + self.backdate_if_appropriate(&old_memo, &mut revisions, &value); + self.diff_outputs(db, database_key_index, &old_memo, &revisions); + } - tracing::debug!("specify: about to add memo {:#?} for key {:?}", memo, key); - self.insert_memo(db, key, memo); + let memo = Memo { + value: Some(value), + verified_at: AtomicCell::new(revision), + revisions, + }; - // Record that the current query *specified* a value for this cell. - let database_key_index = self.database_key_index(key); - state.add_output(database_key_index.into()); - }) + tracing::debug!("specify: about to add memo {:#?} for key {:?}", memo, key); + self.insert_memo(db, key, memo); + + // Record that the current query *specified* a value for this cell. + let database_key_index = self.database_key_index(key); + zalsa_local.add_output(database_key_index.into()); } /// Invoked when the query `executor` has been validated as having green inputs diff --git a/src/input.rs b/src/input.rs index 6fa683e6..62fe321c 100644 --- a/src/input.rs +++ b/src/input.rs @@ -17,7 +17,7 @@ use crate::{ id::{AsId, FromId}, ingredient::{fmt_index, Ingredient}, key::{DatabaseKeyIndex, DependencyIndex}, - local_state::{self, QueryOrigin}, + local_state::QueryOrigin, plumbing::{Jar, Stamp}, storage::IngredientIndex, Database, Durability, Id, Revision, @@ -152,21 +152,20 @@ impl IngredientImpl { id: C::Struct, field_index: usize, ) -> &'db C::Fields { - local_state::attach(db, |state| { - let field_ingredient_index = self.ingredient_index.successor(field_index); - let id = id.as_id(); - let value = self.struct_map.get(id); - let stamp = &value.stamps[field_index]; - state.report_tracked_read( - DependencyIndex { - ingredient_index: field_ingredient_index, - key_index: Some(id), - }, - stamp.durability, - stamp.changed_at, - ); - &value.fields - }) + let zalsa_local = db.zalsa_local(); + let field_ingredient_index = self.ingredient_index.successor(field_index); + let id = id.as_id(); + let value = self.struct_map.get(id); + let stamp = &value.stamps[field_index]; + zalsa_local.report_tracked_read( + DependencyIndex { + ingredient_index: field_ingredient_index, + key_index: Some(id), + }, + stamp.durability, + stamp.changed_at, + ); + &value.fields } /// Peek at the field values without recording any read dependency. diff --git a/src/interned.rs b/src/interned.rs index d3da16a0..55d9fc3b 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -9,7 +9,7 @@ use crate::durability::Durability; use crate::id::AsId; use crate::ingredient::fmt_index; use crate::key::DependencyIndex; -use crate::local_state::{self, QueryOrigin}; +use crate::local_state::QueryOrigin; use crate::plumbing::Jar; use crate::storage::IngredientIndex; use crate::{Database, DatabaseKeyIndex, Id}; @@ -136,46 +136,45 @@ where db: &'db dyn crate::Database, data: C::Data<'db>, ) -> C::Struct<'db> { - local_state::attach(db, |state| { - state.report_tracked_read( - DependencyIndex::for_table(self.ingredient_index), - Durability::MAX, - self.reset_at, - ); + let zalsa_local = db.zalsa_local(); + zalsa_local.report_tracked_read( + DependencyIndex::for_table(self.ingredient_index), + Durability::MAX, + self.reset_at, + ); - // Optimisation to only get read lock on the map if the data has already - // been interned. - let internal_data = unsafe { self.to_internal_data(data) }; - if let Some(guard) = self.key_map.get(&internal_data) { - let id = *guard; - drop(guard); - return self.interned_value(id); + // Optimisation to only get read lock on the map if the data has already + // been interned. + let internal_data = unsafe { self.to_internal_data(data) }; + if let Some(guard) = self.key_map.get(&internal_data) { + let id = *guard; + drop(guard); + return self.interned_value(id); + } + + match self.key_map.entry(internal_data.clone()) { + // Data has been interned by a racing call, use that ID instead + dashmap::mapref::entry::Entry::Occupied(entry) => { + let id = *entry.get(); + drop(entry); + self.interned_value(id) } - match self.key_map.entry(internal_data.clone()) { - // Data has been interned by a racing call, use that ID instead - dashmap::mapref::entry::Entry::Occupied(entry) => { - let id = *entry.get(); - drop(entry); - self.interned_value(id) - } - - // We won any races so should intern the data - dashmap::mapref::entry::Entry::Vacant(entry) => { - let next_id = self.counter.fetch_add(1); - let next_id = crate::id::Id::from_u32(next_id); - let value = self.value_map.entry(next_id).or_insert(Alloc::new(Value { - id: next_id, - fields: internal_data, - })); - let value_raw = value.as_raw(); - drop(value); - entry.insert(next_id); - // SAFETY: Items are only removed from the `value_map` with an `&mut self` reference. - unsafe { C::struct_from_raw(value_raw) } - } + // We won any races so should intern the data + dashmap::mapref::entry::Entry::Vacant(entry) => { + let next_id = self.counter.fetch_add(1); + let next_id = crate::id::Id::from_u32(next_id); + let value = self.value_map.entry(next_id).or_insert(Alloc::new(Value { + id: next_id, + fields: internal_data, + })); + let value_raw = value.as_raw(); + drop(value); + entry.insert(next_id); + // SAFETY: Items are only removed from the `value_map` with an `&mut self` reference. + unsafe { C::struct_from_raw(value_raw) } } - }) + } } pub fn interned_value(&self, id: Id) -> C::Struct<'_> { diff --git a/src/key.rs b/src/key.rs index b2b70292..4cb5dd4c 100644 --- a/src/key.rs +++ b/src/key.rs @@ -1,4 +1,4 @@ -use crate::{cycle::CycleRecoveryStrategy, local_state, storage::IngredientIndex, Database, Id}; +use crate::{cycle::CycleRecoveryStrategy, storage::IngredientIndex, Database, Id}; /// An integer that uniquely identifies a particular query instance within the /// database. Used to track dependencies between queries. Fully ordered and @@ -60,7 +60,7 @@ impl DependencyIndex { impl std::fmt::Debug for DependencyIndex { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - local_state::with_attached_database(|db| { + crate::attach::with_attached_database(|db| { let ingredient = db.zalsa().lookup_ingredient(self.ingredient_index); ingredient.fmt_index(self.key_index, f) }) diff --git a/src/lib.rs b/src/lib.rs index 8aa9685e..42e5ad00 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,6 +2,7 @@ mod accumulator; mod active_query; mod alloc; mod array; +mod attach; mod cancelled; mod cycle; mod database; @@ -41,7 +42,7 @@ pub use self::key::DatabaseKeyIndex; pub use self::revision::Revision; pub use self::runtime::Runtime; pub use self::update::Update; -pub use crate::local_state::with_attached_database; +pub use crate::attach::with_attached_database; pub use salsa_macros::accumulator; pub use salsa_macros::db; pub use salsa_macros::input; @@ -63,6 +64,8 @@ pub mod prelude { pub mod plumbing { pub use crate::accumulator::Accumulator; pub use crate::array::Array; + pub use crate::attach::attach; + pub use crate::attach::with_attached_database; pub use crate::cycle::Cycle; pub use crate::cycle::CycleRecoveryStrategy; pub use crate::database::current_revision; @@ -76,7 +79,6 @@ pub mod plumbing { pub use crate::ingredient::Ingredient; pub use crate::ingredient::Jar; pub use crate::key::DatabaseKeyIndex; - pub use crate::local_state::with_attached_database; pub use crate::revision::Revision; pub use crate::runtime::stamp; pub use crate::runtime::Runtime; diff --git a/src/local_state.rs b/src/local_state.rs index b9a16c4c..606415a2 100644 --- a/src/local_state.rs +++ b/src/local_state.rs @@ -13,49 +13,16 @@ use crate::Database; use crate::Event; use crate::EventKind; use crate::Revision; -use std::cell::Cell; use std::cell::RefCell; -use std::ptr::NonNull; use std::sync::Arc; -thread_local! { - /// The thread-local state salsa requires for a given thread - static LOCAL_STATE: LocalState = const { LocalState::new() } -} - -/// Attach the database to the current thread and execute `op`. -/// Panics if a different database has already been attached. -pub(crate) fn attach(db: &DB, op: impl FnOnce(&LocalState) -> R) -> R -where - DB: ?Sized + Database, -{ - LOCAL_STATE.with(|state| state.attach(db.as_dyn_database(), || op(state))) -} - -/// Access the "attached" database. Returns `None` if no database is attached. -/// Databases are attached with `attach_database`. -pub fn with_attached_database(op: impl FnOnce(&dyn Database) -> R) -> Option { - LOCAL_STATE.with(|state| { - if let Some(db) = state.database.get() { - // SAFETY: We always attach the database in for the entire duration of a function, - // so it cannot become "unattached" while this function is running. - Some(op(unsafe { db.as_ref() })) - } else { - None - } - }) -} - /// State that is specific to a single execution thread. /// /// Internally, this type uses ref-cells. /// /// **Note also that all mutations to the database handle (and hence /// to the local-state) must be undone during unwinding.** -pub(crate) struct LocalState { - /// Pointer to the currently attached database. - database: Cell>>, - +pub struct LocalState { /// Vector of active queries. /// /// This is normally `Some`, but it is set to `None` @@ -67,56 +34,12 @@ pub(crate) struct LocalState { } impl LocalState { - const fn new() -> Self { + pub(crate) fn new() -> Self { LocalState { - database: Cell::new(None), query_stack: RefCell::new(Some(vec![])), } } - fn attach(&self, db: &dyn Database, op: impl FnOnce() -> R) -> R { - struct DbGuard<'s> { - state: Option<&'s LocalState>, - } - - impl<'s> DbGuard<'s> { - fn new(state: &'s LocalState, db: &dyn Database) -> Self { - if let Some(current_db) = state.database.get() { - let new_db = NonNull::from(db); - - // Already attached? Assert that the database has not changed. - // NOTE: It's important to use `addr_eq` here because `NonNull::eq` not only compares the address but also the type's metadata. - if !std::ptr::addr_eq(current_db.as_ptr(), new_db.as_ptr()) { - panic!( - "Cannot change database mid-query. current: {current_db:?}, new: {new_db:?}", - ); - } - - Self { state: None } - } else { - // Otherwise, set the database. - state.database.set(Some(NonNull::from(db))); - Self { state: Some(state) } - } - } - } - - impl Drop for DbGuard<'_> { - fn drop(&mut self) { - // Reset database to null if we did anything in `DbGuard::new`. - if let Some(state) = self.state { - state.database.set(None); - - // All stack frames should have been popped from the local stack. - assert!(state.query_stack.borrow().as_ref().unwrap().is_empty()); - } - } - } - - let _guard = DbGuard::new(self, db); - op() - } - #[inline] pub(crate) fn push_query(&self, database_key_index: DatabaseKeyIndex) -> ActiveQueryGuard<'_> { let mut query_stack = self.query_stack.borrow_mut(); diff --git a/src/storage.rs b/src/storage.rs index 77fc6210..d37ada2c 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -5,12 +5,12 @@ use parking_lot::Mutex; use rustc_hash::FxHashMap; use crate::cycle::CycleRecoveryStrategy; -use crate::database::{DatabaseImpl, UserData}; +use crate::database::UserData; use crate::ingredient::{Ingredient, Jar}; use crate::nonce::{Nonce, NonceGenerator}; use crate::runtime::Runtime; -use crate::views::{Views, ViewsOf}; -use crate::{Database, Durability, Revision}; +use crate::views::Views; +use crate::{Database, DatabaseImpl, Durability, Revision}; pub fn views(db: &Db) -> &Views { db.zalsa().views() @@ -191,7 +191,7 @@ impl IngredientIndex { pub(crate) struct ZalsaImpl { user_data: U, - views_of: ViewsOf>, + views_of: Views, nonce: Nonce, @@ -227,7 +227,7 @@ impl Default for ZalsaImpl { impl ZalsaImpl { pub(crate) fn with(user_data: U) -> Self { Self { - views_of: Default::default(), + views_of: Views::new::>(), nonce: NONCE.nonce(), jar_map: Default::default(), ingredients_vec: Default::default(), diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 45136529..b8460c34 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -11,7 +11,7 @@ use crate::{ ingredient::{fmt_index, Ingredient, Jar}, ingredient_list::IngredientList, key::{DatabaseKeyIndex, DependencyIndex}, - local_state::{self, QueryOrigin}, + local_state::QueryOrigin, salsa_struct::SalsaStructInDb, storage::IngredientIndex, Database, Durability, Event, Id, Revision, @@ -290,86 +290,85 @@ where db: &'db dyn Database, fields: C::Fields<'db>, ) -> C::Struct<'db> { - local_state::attach(db, |local_state| { - let zalsa = db.zalsa(); + let zalsa = db.zalsa(); + let zalsa_local = db.zalsa_local(); - let data_hash = crate::hash::hash(&C::id_fields(&fields)); + let data_hash = crate::hash::hash(&C::id_fields(&fields)); - let (query_key, current_deps, disambiguator) = - local_state.disambiguate(self.ingredient_index, Revision::start(), data_hash); + let (query_key, current_deps, disambiguator) = + zalsa_local.disambiguate(self.ingredient_index, Revision::start(), data_hash); - let entity_key = KeyStruct { - query_key, - disambiguator, - data_hash, - }; + let entity_key = KeyStruct { + query_key, + disambiguator, + data_hash, + }; - let (id, new_id) = self.intern(entity_key); - local_state.add_output(self.database_key_index(id).into()); + let (id, new_id) = self.intern(entity_key); + zalsa_local.add_output(self.database_key_index(id).into()); - let current_revision = zalsa.current_revision(); - if new_id { - // This is a new tracked struct, so create an entry in the struct map. + let current_revision = zalsa.current_revision(); + if new_id { + // This is a new tracked struct, so create an entry in the struct map. - self.struct_map.insert( - current_revision, - Value { - id, - key: entity_key, - struct_ingredient_index: self.ingredient_index, - created_at: current_revision, - durability: current_deps.durability, - fields: unsafe { self.to_static(fields) }, - revisions: C::new_revisions(current_deps.changed_at), - }, - ) - } else { - // The struct already exists in the intern map. - // Note that we assume there is at most one executing copy of - // the current query at a time, which implies that the - // struct must exist in `self.struct_map` already - // (if the same query could execute twice in parallel, - // then it would potentially create the same struct twice in parallel, - // which means the interned key could exist but `struct_map` not yet have - // been updated). + self.struct_map.insert( + current_revision, + Value { + id, + key: entity_key, + struct_ingredient_index: self.ingredient_index, + created_at: current_revision, + durability: current_deps.durability, + fields: unsafe { self.to_static(fields) }, + revisions: C::new_revisions(current_deps.changed_at), + }, + ) + } else { + // The struct already exists in the intern map. + // Note that we assume there is at most one executing copy of + // the current query at a time, which implies that the + // struct must exist in `self.struct_map` already + // (if the same query could execute twice in parallel, + // then it would potentially create the same struct twice in parallel, + // which means the interned key could exist but `struct_map` not yet have + // been updated). - match self.struct_map.update(current_revision, id) { - Update::Current(r) => { - // All inputs up to this point were previously - // observed to be green and this struct was already - // verified. Therefore, the durability ought not to have - // changed (nor the field values, but the user could've - // done something stupid, so we can't *assert* this is true). - assert!(C::deref_struct(r).durability == current_deps.durability); + match self.struct_map.update(current_revision, id) { + Update::Current(r) => { + // All inputs up to this point were previously + // observed to be green and this struct was already + // verified. Therefore, the durability ought not to have + // changed (nor the field values, but the user could've + // done something stupid, so we can't *assert* this is true). + assert!(C::deref_struct(r).durability == current_deps.durability); - r + r + } + Update::Outdated(mut data_ref) => { + let data = &mut *data_ref; + + // SAFETY: We assert that the pointer to `data.revisions` + // is a pointer into the database referencing a value + // from a previous revision. As such, it continues to meet + // its validity invariant and any owned content also continues + // to meet its safety invariant. + unsafe { + C::update_fields( + current_revision, + &mut data.revisions, + self.to_self_ptr(std::ptr::addr_of_mut!(data.fields)), + fields, + ); } - Update::Outdated(mut data_ref) => { - let data = &mut *data_ref; - - // SAFETY: We assert that the pointer to `data.revisions` - // is a pointer into the database referencing a value - // from a previous revision. As such, it continues to meet - // its validity invariant and any owned content also continues - // to meet its safety invariant. - unsafe { - C::update_fields( - current_revision, - &mut data.revisions, - self.to_self_ptr(std::ptr::addr_of_mut!(data.fields)), - fields, - ); - } - if current_deps.durability < data.durability { - data.revisions = C::new_revisions(current_revision); - } - data.durability = current_deps.durability; - data.created_at = current_revision; - data_ref.freeze() + if current_deps.durability < data.durability { + data.revisions = C::new_revisions(current_revision); } + data.durability = current_deps.durability; + data.created_at = current_revision; + data_ref.freeze() } } - }) + } } /// Given the id of a tracked struct created in this revision, @@ -520,21 +519,20 @@ where db: &dyn crate::Database, field_index: usize, ) -> &'db C::Fields<'db> { - local_state::attach(db, |local_state| { - let field_ingredient_index = self.struct_ingredient_index.successor(field_index); - let changed_at = self.revisions[field_index]; + let zalsa_local = db.zalsa_local(); + let field_ingredient_index = self.struct_ingredient_index.successor(field_index); + let changed_at = self.revisions[field_index]; - local_state.report_tracked_read( - DependencyIndex { - ingredient_index: field_ingredient_index, - key_index: Some(self.id.as_id()), - }, - self.durability, - changed_at, - ); + zalsa_local.report_tracked_read( + DependencyIndex { + ingredient_index: field_ingredient_index, + key_index: Some(self.id.as_id()), + }, + self.durability, + changed_at, + ); - unsafe { self.to_self_ref(&self.fields) } - }) + unsafe { self.to_self_ref(&self.fields) } } unsafe fn to_self_ref<'db>(&'db self, fields: &'db C::Fields<'static>) -> &'db C::Fields<'db> { diff --git a/src/tracked_struct/tracked_field.rs b/src/tracked_struct/tracked_field.rs index 7a6b7b42..2061d1f0 100644 --- a/src/tracked_struct/tracked_field.rs +++ b/src/tracked_struct/tracked_field.rs @@ -1,6 +1,5 @@ use crate::{ - id::AsId, ingredient::Ingredient, key::DependencyIndex, local_state, storage::IngredientIndex, - Database, Id, + id::AsId, ingredient::Ingredient, key::DependencyIndex, storage::IngredientIndex, Database, Id, }; use super::{struct_map::StructMapView, Configuration}; @@ -47,23 +46,22 @@ where /// Note that this function returns the entire tuple of value fields. /// The caller is responible for selecting the appropriate element. pub fn field<'db>(&'db self, db: &'db dyn Database, id: Id) -> &'db C::Fields<'db> { - local_state::attach(db, |local_state| { - let current_revision = db.zalsa().current_revision(); - let data = self.struct_map.get(current_revision, id); - let data = C::deref_struct(data); - let changed_at = data.revisions[self.field_index]; + let zalsa_local = db.zalsa_local(); + let current_revision = db.zalsa().current_revision(); + let data = self.struct_map.get(current_revision, id); + let data = C::deref_struct(data); + let changed_at = data.revisions[self.field_index]; - local_state.report_tracked_read( - DependencyIndex { - ingredient_index: self.ingredient_index, - key_index: Some(id.as_id()), - }, - data.durability, - changed_at, - ); + zalsa_local.report_tracked_read( + DependencyIndex { + ingredient_index: self.ingredient_index, + key_index: Some(id.as_id()), + }, + data.durability, + changed_at, + ); - unsafe { self.to_self_ref(&data.fields) } - }) + unsafe { self.to_self_ref(&data.fields) } } } diff --git a/src/views.rs b/src/views.rs index 75369e75..19798737 100644 --- a/src/views.rs +++ b/src/views.rs @@ -1,7 +1,5 @@ use std::{ any::{Any, TypeId}, - marker::PhantomData, - ops::Deref, sync::Arc, }; @@ -9,11 +7,6 @@ use orx_concurrent_vec::ConcurrentVec; use crate::Database; -pub struct ViewsOf { - upcasts: Views, - phantom: PhantomData, -} - #[derive(Clone)] pub struct Views { source_type_id: TypeId, @@ -29,25 +22,8 @@ struct ViewCaster { #[allow(dead_code)] enum Dummy {} -impl Default for ViewsOf { - fn default() -> Self { - Self { - upcasts: Views::new::(), - phantom: Default::default(), - } - } -} - -impl Deref for ViewsOf { - type Target = Views; - - fn deref(&self) -> &Self::Target { - &self.upcasts - } -} - impl Views { - fn new() -> Self { + pub(crate) fn new() -> Self { let source_type_id = TypeId::of::(); Self { source_type_id, @@ -127,12 +103,3 @@ fn data_ptr(t: &T) -> &() { let u: *const () = t as *const (); unsafe { &*u } } - -impl Clone for ViewsOf { - fn clone(&self) -> Self { - Self { - upcasts: self.upcasts.clone(), - phantom: self.phantom, - } - } -} diff --git a/tests/accumulate.rs b/tests/accumulate.rs index e20cc05f..ea16666d 100644 --- a/tests/accumulate.rs +++ b/tests/accumulate.rs @@ -55,7 +55,7 @@ fn push_b_logs(db: &dyn LogDatabase, input: MyInput) { #[test] fn accumulate_once() { - let mut db = salsa::DatabaseImpl::with(Logger::default()); + let db = salsa::DatabaseImpl::with(Logger::default()); // Just call accumulate on a base input to see what happens. let input = MyInput::new(&db, 2, 3); diff --git a/tests/common/mod.rs b/tests/common/mod.rs index cb741f37..752d9e4c 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -25,7 +25,7 @@ pub trait LogDatabase: HasLogger + salsa::Database { /// Asserts what the (formatted) logs should look like, /// clearing the logged events. This takes `&mut self` because /// it is meant to be run from outside any tracked functions. - fn assert_logs(&mut self, expected: expect_test::Expect) { + fn assert_logs(&self, expected: expect_test::Expect) { let logs = std::mem::take(&mut *self.logger().logs.lock().unwrap()); expected.assert_eq(&format!("{:#?}", logs)); } @@ -33,7 +33,7 @@ pub trait LogDatabase: HasLogger + salsa::Database { /// Asserts the length of the logs, /// clearing the logged events. This takes `&mut self` because /// it is meant to be run from outside any tracked functions. - fn assert_logs_len(&mut self, expected: usize) { + fn assert_logs_len(&self, expected: usize) { let logs = std::mem::take(&mut *self.logger().logs.lock().unwrap()); assert_eq!(logs.len(), expected); }