WIP checkpoint

This commit is contained in:
Niko Matsakis 2024-07-12 07:51:28 -04:00
parent 15106ff8ea
commit 73b8134345
12 changed files with 132 additions and 36 deletions

View file

@ -1,4 +1,4 @@
use std::{fmt, sync::Arc}; use std::{any::Any, fmt, sync::Arc};
use crossbeam::atomic::AtomicCell; use crossbeam::atomic::AtomicCell;
@ -30,7 +30,9 @@ mod specify;
mod store; mod store;
mod sync; mod sync;
pub trait Configuration: 'static { pub mod interned;
pub trait Configuration: Any {
const DEBUG_NAME: &'static str; const DEBUG_NAME: &'static str;
/// The database that this function is associated with. /// The database that this function is associated with.
@ -45,7 +47,7 @@ pub trait Configuration: 'static {
type Input<'db>: Send + Sync; type Input<'db>: Send + Sync;
/// The value computed by the function. /// The value computed by the function.
type Value<'db>: fmt::Debug + Send + Sync; type Output<'db>: fmt::Debug + Send + Sync;
/// Determines whether this function can recover from being a participant in a cycle /// Determines whether this function can recover from being a participant in a cycle
/// (and, if so, how). /// (and, if so, how).
@ -57,19 +59,23 @@ pub trait Configuration: 'static {
/// even though it was recomputed). /// even though it was recomputed).
/// ///
/// This invokes user's code in form of the `Eq` impl. /// This invokes user's code in form of the `Eq` impl.
fn should_backdate_value(old_value: &Self::Value<'_>, new_value: &Self::Value<'_>) -> bool; fn should_backdate_value(old_value: &Self::Output<'_>, new_value: &Self::Output<'_>) -> bool;
/// Convert from the id used internally to the value that execute is expecting.
/// This is a no-op if the input to the function is a salsa struct.
fn id_to_input<'db>(db: &'db Self::DbView, key: Id) -> Self::Input<'db>;
/// Invoked when we need to compute the value for the given key, either because we've never /// Invoked when we need to compute the value for the given key, either because we've never
/// computed it before or because the old one relied on inputs that have changed. /// computed it before or because the old one relied on inputs that have changed.
/// ///
/// This invokes the function the user wrote. /// This invokes the function the user wrote.
fn execute<'db>(db: &'db Self::DbView, key: Id) -> Self::Value<'db>; fn execute<'db>(db: &'db Self::DbView, input: Self::Input<'db>) -> Self::Output<'db>;
/// If the cycle strategy is `Recover`, then invoked when `key` is a participant /// If the cycle strategy is `Recover`, then invoked when `key` is a participant
/// in a cycle to find out what value it should have. /// in a cycle to find out what value it should have.
/// ///
/// This invokes the recovery function given by the user. /// This invokes the recovery function given by the user.
fn recover_from_cycle<'db>(db: &'db Self::DbView, cycle: &Cycle, key: Id) -> Self::Value<'db>; fn recover_from_cycle<'db>(db: &'db Self::DbView, cycle: &Cycle, key: Id) -> Self::Output<'db>;
} }
/// Function ingredients are the "workhorse" of salsa. /// Function ingredients are the "workhorse" of salsa.
@ -162,9 +168,9 @@ where
/// only cleared with `&mut self`. /// only cleared with `&mut self`.
unsafe fn extend_memo_lifetime<'this, 'memo>( unsafe fn extend_memo_lifetime<'this, 'memo>(
&'this self, &'this self,
memo: &'memo memo::Memo<C::Value<'this>>, memo: &'memo memo::Memo<C::Output<'this>>,
) -> Option<&'this C::Value<'this>> { ) -> Option<&'this C::Output<'this>> {
let memo_value: Option<&'memo C::Value<'this>> = memo.value.as_ref(); let memo_value: Option<&'memo C::Output<'this>> = memo.value.as_ref();
std::mem::transmute(memo_value) std::mem::transmute(memo_value)
} }
@ -172,8 +178,8 @@ where
&'db self, &'db self,
db: &'db C::DbView, db: &'db C::DbView,
key: Id, key: Id,
memo: memo::Memo<C::Value<'db>>, memo: memo::Memo<C::Output<'db>>,
) -> Option<&C::Value<'db>> { ) -> Option<&C::Output<'db>> {
self.register(db); self.register(db);
let memo = Arc::new(memo); let memo = Arc::new(memo);
let value = unsafe { let value = unsafe {

View file

@ -11,9 +11,9 @@ where
/// on an old memo when a new memo has been produced to check whether there have been changed. /// on an old memo when a new memo has been produced to check whether there have been changed.
pub(super) fn backdate_if_appropriate( pub(super) fn backdate_if_appropriate(
&self, &self,
old_memo: &Memo<C::Value<'_>>, old_memo: &Memo<C::Output<'_>>,
revisions: &mut QueryRevisions, revisions: &mut QueryRevisions,
value: &C::Value<'_>, value: &C::Output<'_>,
) { ) {
if let Some(old_value) = &old_memo.value { if let Some(old_value) = &old_memo.value {
// Careful: if the value became less durable than it // Careful: if the value became less durable than it

View file

@ -26,7 +26,7 @@ where
/// once the next revision starts. See the comment on the field /// once the next revision starts. See the comment on the field
/// `deleted_entries` of [`FunctionIngredient`][] for more details. /// `deleted_entries` of [`FunctionIngredient`][] for more details.
pub(super) struct DeletedEntries<C: Configuration> { pub(super) struct DeletedEntries<C: Configuration> {
seg_queue: SegQueue<ArcSwap<memo::Memo<C::Value<'static>>>>, seg_queue: SegQueue<ArcSwap<memo::Memo<C::Output<'static>>>>,
} }
impl<C: Configuration> Default for DeletedEntries<C> { impl<C: Configuration> Default for DeletedEntries<C> {
@ -38,7 +38,7 @@ impl<C: Configuration> Default for DeletedEntries<C> {
} }
impl<C: Configuration> DeletedEntries<C> { impl<C: Configuration> DeletedEntries<C> {
pub(super) fn push<'db>(&'db self, memo: ArcSwap<memo::Memo<C::Value<'db>>>) { pub(super) fn push<'db>(&'db self, memo: ArcSwap<memo::Memo<C::Output<'db>>>) {
let memo = unsafe { std::mem::transmute(memo) }; let memo = unsafe { std::mem::transmute(memo) };
self.seg_queue.push(memo); self.seg_queue.push(memo);
} }

View file

@ -15,7 +15,7 @@ where
&self, &self,
db: &C::DbView, db: &C::DbView,
key: DatabaseKeyIndex, key: DatabaseKeyIndex,
old_memo: &Memo<C::Value<'_>>, old_memo: &Memo<C::Output<'_>>,
revisions: &QueryRevisions, revisions: &QueryRevisions,
) { ) {
// Iterate over the outputs of the `old_memo` and put them into a hashset // Iterate over the outputs of the `old_memo` and put them into a hashset

View file

@ -25,8 +25,8 @@ where
&'db self, &'db self,
db: &'db C::DbView, db: &'db C::DbView,
active_query: ActiveQueryGuard<'_>, active_query: ActiveQueryGuard<'_>,
opt_old_memo: Option<Arc<Memo<C::Value<'_>>>>, opt_old_memo: Option<Arc<Memo<C::Output<'_>>>>,
) -> StampedValue<&C::Value<'db>> { ) -> StampedValue<&C::Output<'db>> {
let runtime = db.runtime(); let runtime = db.runtime();
let revision_now = runtime.current_revision(); let revision_now = runtime.current_revision();
let database_key_index = active_query.database_key_index; let database_key_index = active_query.database_key_index;
@ -44,7 +44,7 @@ where
// stale, or value is absent. Let's execute! // stale, or value is absent. Let's execute!
let database_key_index = active_query.database_key_index; let database_key_index = active_query.database_key_index;
let key = database_key_index.key_index; let key = database_key_index.key_index;
let value = match Cycle::catch(|| C::execute(db, key)) { let value = match Cycle::catch(|| C::execute(db, C::id_to_input(db, key))) {
Ok(v) => v, Ok(v) => v,
Err(cycle) => { Err(cycle) => {
log::debug!( log::debug!(

View file

@ -8,7 +8,7 @@ impl<C> IngredientImpl<C>
where where
C: Configuration, C: Configuration,
{ {
pub fn fetch<'db>(&'db self, db: &'db C::DbView, key: Id) -> &C::Value<'db> { pub fn fetch<'db>(&'db self, db: &'db C::DbView, key: Id) -> &C::Output<'db> {
let runtime = db.runtime(); let runtime = db.runtime();
runtime.unwind_if_revision_cancelled(db); runtime.unwind_if_revision_cancelled(db);
@ -37,7 +37,7 @@ where
&'db self, &'db self,
db: &'db C::DbView, db: &'db C::DbView,
key: Id, key: Id,
) -> StampedValue<&'db C::Value<'db>> { ) -> StampedValue<&'db C::Output<'db>> {
loop { loop {
if let Some(value) = self.fetch_hot(db, key).or_else(|| self.fetch_cold(db, key)) { if let Some(value) = self.fetch_hot(db, key).or_else(|| self.fetch_cold(db, key)) {
return value; return value;
@ -50,7 +50,7 @@ where
&'db self, &'db self,
db: &'db C::DbView, db: &'db C::DbView,
key: Id, key: Id,
) -> Option<StampedValue<&'db C::Value<'db>>> { ) -> Option<StampedValue<&'db C::Output<'db>>> {
let memo_guard = self.memo_map.get(key); let memo_guard = self.memo_map.get(key);
if let Some(memo) = &memo_guard { if let Some(memo) = &memo_guard {
if memo.value.is_some() { if memo.value.is_some() {
@ -71,7 +71,7 @@ where
&'db self, &'db self,
db: &'db C::DbView, db: &'db C::DbView,
key: Id, key: Id,
) -> Option<StampedValue<&'db C::Value<'db>>> { ) -> Option<StampedValue<&'db C::Output<'db>>> {
let runtime = db.runtime(); let runtime = db.runtime();
let database_key_index = self.database_key_index(key); let database_key_index = self.database_key_index(key);

88
src/function/interned.rs Normal file
View file

@ -0,0 +1,88 @@
//! Helper code for tracked functions that take arbitrary arguments.
//! These arguments must be interned to create a salsa id before the
//! salsa machinery can execute.
use std::{any::Any, fmt, hash::Hash, marker::PhantomData};
use crate::{
function, interned, plumbing::CycleRecoveryStrategy, salsa_struct::SalsaStructInDb, Cycle, Id,
};
pub trait Configuration: Any + Copy {
const DEBUG_NAME: &'static str;
type DbView: ?Sized + crate::Database;
type SalsaStruct<'db>: SalsaStructInDb<Self::DbView>;
type Input<'db>: Send + Sync + Clone + Hash + Eq;
type Output<'db>: fmt::Debug + Send + Sync;
const CYCLE_STRATEGY: CycleRecoveryStrategy;
fn should_backdate_value(old_value: &Self::Output<'_>, new_value: &Self::Output<'_>) -> bool;
fn id_to_input<'db>(db: &'db Self::DbView, key: Id) -> Self::Input<'db>;
fn execute<'db>(db: &'db Self::DbView, input: Self::Input<'db>) -> Self::Output<'db>;
fn recover_from_cycle<'db>(db: &'db Self::DbView, cycle: &Cycle, key: Id) -> Self::Output<'db>;
}
pub struct InterningConfiguration<C: Configuration> {
phantom: PhantomData<C>,
}
#[derive(Copy, Clone)]
pub struct InternedData<'db, C: Configuration>(
std::ptr::NonNull<interned::ValueStruct<C>>,
std::marker::PhantomData<&'db interned::ValueStruct<C>>,
);
impl<C: Configuration> SalsaStructInDb<C::DbView> for InternedData<'_, C> {
fn register_dependent_fn(_db: &C::DbView, _index: crate::storage::IngredientIndex) {}
}
impl<C: Configuration> interned::Configuration for C {
const DEBUG_NAME: &'static str = C::DEBUG_NAME;
type Data<'db> = C::Input<'db>;
type Struct<'db> = InternedData<'db, C>;
unsafe fn struct_from_raw<'db>(
ptr: std::ptr::NonNull<interned::ValueStruct<Self>>,
) -> Self::Struct<'db> {
InternedData(ptr, std::marker::PhantomData)
}
fn deref_struct(s: Self::Struct<'_>) -> &interned::ValueStruct<Self> {
unsafe { s.0.as_ref() }
}
}
impl<C: Configuration> function::Configuration for C {
const DEBUG_NAME: &'static str = C::DEBUG_NAME;
type DbView = C::DbView;
type SalsaStruct<'db> = InternedData<'db, C>;
type Input<'db> = C::Input<'db>;
type Output<'db> = C::Output<'db>;
const CYCLE_STRATEGY: crate::plumbing::CycleRecoveryStrategy = C::CYCLE_STRATEGY;
fn should_backdate_value(old_value: &Self::Output<'_>, new_value: &Self::Output<'_>) -> bool {
C::should_backdate_value(old_value, new_value)
}
fn id_to_input<'db>(db: &'db Self::DbView, key: crate::Id) -> Self::Input<'db> {
todo!()
}
fn execute<'db>(db: &'db Self::DbView, input: Self::Input<'db>) -> Self::Output<'db> {
todo!()
}
fn recover_from_cycle<'db>(
db: &'db Self::DbView,
cycle: &crate::Cycle,
key: crate::Id,
) -> Self::Output<'db> {
todo!()
}
}

View file

@ -101,7 +101,7 @@ where
db: &C::DbView, db: &C::DbView,
runtime: &Runtime, runtime: &Runtime,
database_key_index: DatabaseKeyIndex, database_key_index: DatabaseKeyIndex,
memo: &Memo<C::Value<'_>>, memo: &Memo<C::Output<'_>>,
) -> bool { ) -> bool {
let verified_at = memo.verified_at.load(); let verified_at = memo.verified_at.load();
let revision_now = runtime.current_revision(); let revision_now = runtime.current_revision();
@ -135,7 +135,7 @@ where
pub(super) fn deep_verify_memo( pub(super) fn deep_verify_memo(
&self, &self,
db: &C::DbView, db: &C::DbView,
old_memo: &Memo<C::Value<'_>>, old_memo: &Memo<C::Output<'_>>,
active_query: &ActiveQueryGuard<'_>, active_query: &ActiveQueryGuard<'_>,
) -> bool { ) -> bool {
let runtime = db.runtime(); let runtime = db.runtime();

View file

@ -18,7 +18,7 @@ pub(super) struct MemoMap<C: Configuration> {
} }
#[allow(type_alias_bounds)] #[allow(type_alias_bounds)]
type ArcMemo<'lt, C: Configuration> = ArcSwap<Memo<<C as Configuration>::Value<'lt>>>; type ArcMemo<'lt, C: Configuration> = ArcSwap<Memo<<C as Configuration>::Output<'lt>>>;
impl<C: Configuration> Default for MemoMap<C> { impl<C: Configuration> Default for MemoMap<C> {
fn default() -> Self { fn default() -> Self {
@ -47,8 +47,8 @@ impl<C: Configuration> MemoMap<C> {
pub(super) fn insert<'db>( pub(super) fn insert<'db>(
&'db self, &'db self,
key: Id, key: Id,
memo: Arc<Memo<C::Value<'db>>>, memo: Arc<Memo<C::Output<'db>>>,
) -> Option<ArcSwap<Memo<C::Value<'db>>>> { ) -> Option<ArcSwap<Memo<C::Output<'db>>>> {
unsafe { unsafe {
let value = ArcSwap::from(memo); let value = ArcSwap::from(memo);
let old_value = self.map.insert(key, self.to_static(value))?; let old_value = self.map.insert(key, self.to_static(value))?;
@ -58,18 +58,18 @@ impl<C: Configuration> MemoMap<C> {
/// Removes any existing memo for the given key. /// Removes any existing memo for the given key.
#[must_use] #[must_use]
pub(super) fn remove(&self, key: Id) -> Option<ArcSwap<Memo<C::Value<'_>>>> { pub(super) fn remove(&self, key: Id) -> Option<ArcSwap<Memo<C::Output<'_>>>> {
unsafe { self.map.remove(&key).map(|o| self.to_self(o.1)) } unsafe { self.map.remove(&key).map(|o| self.to_self(o.1)) }
} }
/// Loads the current memo for `key_index`. This does not hold any sort of /// Loads the current memo for `key_index`. This does not hold any sort of
/// lock on the `memo_map` once it returns, so this memo could immediately /// lock on the `memo_map` once it returns, so this memo could immediately
/// become outdated if other threads store into the `memo_map`. /// become outdated if other threads store into the `memo_map`.
pub(super) fn get<'db>(&self, key: Id) -> Option<Guard<Arc<Memo<C::Value<'db>>>>> { pub(super) fn get<'db>(&self, key: Id) -> Option<Guard<Arc<Memo<C::Output<'db>>>>> {
self.map.get(&key).map(|v| unsafe { self.map.get(&key).map(|v| unsafe {
std::mem::transmute::< std::mem::transmute::<
Guard<Arc<Memo<C::Value<'static>>>>, Guard<Arc<Memo<C::Output<'static>>>>,
Guard<Arc<Memo<C::Value<'db>>>>, Guard<Arc<Memo<C::Output<'db>>>>,
>(v.load()) >(v.load())
}) })
} }
@ -95,7 +95,7 @@ impl<C: Configuration> MemoMap<C> {
QueryOrigin::Derived(_) => { QueryOrigin::Derived(_) => {
let memo_evicted = Arc::new(Memo::new( let memo_evicted = Arc::new(Memo::new(
None::<C::Value<'_>>, None::<C::Output<'_>>,
memo.verified_at.load(), memo.verified_at.load(),
memo.revisions.clone(), memo.revisions.clone(),
)); ));

View file

@ -20,7 +20,7 @@ where
&'db self, &'db self,
db: &'db C::DbView, db: &'db C::DbView,
key: Id, key: Id,
value: C::Value<'db>, value: C::Output<'db>,
origin: impl Fn(DatabaseKeyIndex) -> QueryOrigin, origin: impl Fn(DatabaseKeyIndex) -> QueryOrigin,
) where ) where
C::Input<'db>: TrackedStructInDb<C::DbView>, C::Input<'db>: TrackedStructInDb<C::DbView>,
@ -93,7 +93,7 @@ where
/// Specify the value for `key` *and* record that we did so. /// Specify the value for `key` *and* record that we did so.
/// Used for explicit calls to `specify`, but not needed for pre-declared tracked struct fields. /// Used for explicit calls to `specify`, but not needed for pre-declared tracked struct fields.
pub fn specify_and_record<'db>(&'db self, db: &'db C::DbView, key: Id, value: C::Value<'db>) pub fn specify_and_record<'db>(&'db self, db: &'db C::DbView, key: Id, value: C::Output<'db>)
where where
C::Input<'db>: TrackedStructInDb<C::DbView>, C::Input<'db>: TrackedStructInDb<C::DbView>,
{ {

View file

@ -18,7 +18,7 @@ where
&'db mut self, &'db mut self,
runtime: &mut Runtime, runtime: &mut Runtime,
key: Id, key: Id,
value: C::Value<'db>, value: C::Output<'db>,
durability: Durability, durability: Durability,
) { ) {
let revision = runtime.current_revision(); let revision = runtime.current_revision();

View file

@ -61,8 +61,10 @@ pub mod plumbing {
pub use crate::ingredient::Jar; pub use crate::ingredient::Jar;
pub use crate::salsa_struct::SalsaStructInDb; pub use crate::salsa_struct::SalsaStructInDb;
pub use crate::storage::views; pub use crate::storage::views;
pub use crate::storage::HasStorage;
pub use crate::storage::IngredientCache; pub use crate::storage::IngredientCache;
pub use crate::storage::IngredientIndex; pub use crate::storage::IngredientIndex;
pub use crate::storage::Storage;
pub mod input { pub mod input {
pub use crate::input::Configuration; pub use crate::input::Configuration;