make fn input/value a GAT

This commit is contained in:
Niko Matsakis 2024-04-27 10:33:54 -04:00
parent b050bd874a
commit 44a8a2f41c
16 changed files with 161 additions and 81 deletions

View file

@ -25,8 +25,8 @@ impl Configuration {
impl salsa::function::Configuration for #self_ty {
type Jar = #jar_ty;
type SalsaStruct = #salsa_struct_ty;
type Input = #input_ty;
type Value = #value_ty;
type Input<'db> = #input_ty;
type Value<'db> = #value_ty;
const CYCLE_STRATEGY: salsa::cycle::CycleRecoveryStrategy = #cycle_strategy;
#backdate_fn
#execute_fn
@ -59,13 +59,13 @@ impl quote::ToTokens for CycleRecoveryStrategy {
pub(crate) fn should_backdate_value_fn(should_backdate: bool) -> syn::ImplItemMethod {
if should_backdate {
parse_quote! {
fn should_backdate_value(v1: &Self::Value, v2: &Self::Value) -> bool {
fn should_backdate_value(v1: &Self::Value<'_>, v2: &Self::Value<'_>) -> bool {
salsa::function::should_backdate_value(v1, v2)
}
}
} else {
parse_quote! {
fn should_backdate_value(_v1: &Self::Value, _v2: &Self::Value) -> bool {
fn should_backdate_value(_v1: &Self::Value<'_>, _v2: &Self::Value<'_>) -> bool {
false
}
}
@ -76,11 +76,11 @@ pub(crate) fn should_backdate_value_fn(should_backdate: bool) -> syn::ImplItemMe
/// the cycle recovery is panic.
pub(crate) fn panic_cycle_recovery_fn() -> syn::ImplItemMethod {
parse_quote! {
fn recover_from_cycle(
_db: &salsa::function::DynDb<Self>,
fn recover_from_cycle<'db>(
_db: &'db salsa::function::DynDb<'db, Self>,
_cycle: &salsa::Cycle,
_key: salsa::Id,
) -> Self::Value {
) -> Self::Value<'db> {
panic!()
}
}

View file

@ -11,17 +11,23 @@ use proc_macro::TokenStream;
macro_rules! parse_quote {
($($inp:tt)*) => {
syn::parse2(quote!{$($inp)*}).unwrap_or_else(|err| {
panic!("failed to parse at {}:{}:{}: {}", file!(), line!(), column!(), err)
})
{
let tt = quote!{$($inp)*};
syn::parse2(tt.clone()).unwrap_or_else(|err| {
panic!("failed to parse `{}` at {}:{}:{}: {}", tt, file!(), line!(), column!(), err)
})
}
}
}
macro_rules! parse_quote_spanned {
($($inp:tt)*) => {
syn::parse2(quote_spanned!{$($inp)*}).unwrap_or_else(|err| {
panic!("failed to parse at {}:{}:{}: {}", file!(), line!(), column!(), err)
})
{
let tt = quote_spanned!{$($inp)*};
syn::parse2(tt.clone()).unwrap_or_else(|err| {
panic!("failed to parse `{}` at {}:{}:{}: {}", tt, file!(), line!(), column!(), err)
})
}
}
}

View file

@ -391,7 +391,7 @@ fn fn_configuration(args: &FnArgs, item_fn: &syn::ItemFn) -> Configuration {
let cycle_strategy = CycleRecoveryStrategy::Fallback;
let cycle_fullback = parse_quote! {
fn recover_from_cycle(__db: &salsa::function::DynDb<Self>, __cycle: &salsa::Cycle, __id: salsa::Id) -> Self::Value {
fn recover_from_cycle<'db>(__db: &'db salsa::function::DynDb<'db, Self>, __cycle: &salsa::Cycle, __id: salsa::Id) -> Self::Value<'db> {
let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db);
let __ingredients =
<_ as salsa::storage::HasIngredientsFor<#fn_ty>>::ingredient(__jar);
@ -422,7 +422,7 @@ fn fn_configuration(args: &FnArgs, item_fn: &syn::ItemFn) -> Configuration {
// keys and then (b) invokes the function itself (which we embed within).
let indices = (0..item_fn.sig.inputs.len() - 1).map(Literal::usize_unsuffixed);
let execute_fn = parse_quote! {
fn execute(__db: &salsa::function::DynDb<Self>, __id: salsa::Id) -> Self::Value {
fn execute<'db>(__db: &'db salsa::function::DynDb<'db, Self>, __id: salsa::Id) -> Self::Value<'db> {
#inner_fn
let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db);

View file

@ -1,7 +1,6 @@
use std::{fmt, sync::Arc};
use arc_swap::ArcSwap;
use crossbeam::{atomic::AtomicCell, queue::SegQueue};
use crossbeam::atomic::AtomicCell;
use crate::{
cycle::CycleRecoveryStrategy,
@ -13,6 +12,8 @@ use crate::{
Cycle, DbWithJar, Event, EventKind, Id, Revision,
};
use self::delete::DeletedEntries;
use super::{ingredient::Ingredient, routes::IngredientIndex, AsId};
mod accumulated;
@ -45,7 +46,7 @@ pub struct FunctionIngredient<C: Configuration> {
index: IngredientIndex,
/// Tracks the keys for which we have memoized values.
memo_map: memo::MemoMap<Id, C::Value>,
memo_map: memo::MemoMap<C>,
/// Tracks the keys that are currently being processed; used to coordinate between
/// worker threads.
@ -66,7 +67,7 @@ pub struct FunctionIngredient<C: Configuration> {
/// current revision: you would be right, but we are being defensive, because
/// we don't know that we can trust the database to give us the same runtime
/// everytime and so forth.
deleted_entries: SegQueue<ArcSwap<memo::Memo<C::Value>>>,
deleted_entries: DeletedEntries<C>,
/// Set to true once we invoke `register_dependent_fn` for `C::SalsaStruct`.
/// Prevents us from registering more than once.
@ -84,10 +85,10 @@ pub trait Configuration {
type SalsaStruct: for<'db> SalsaStructInDb<DynDb<'db, Self>>;
/// The input to the function
type Input;
type Input<'db>;
/// The value computed by the function.
type Value: fmt::Debug;
type Value<'db>: fmt::Debug;
/// Determines whether this function can recover from being a participant in a cycle
/// (and, if so, how).
@ -99,19 +100,19 @@ pub trait Configuration {
/// even though it was recomputed).
///
/// This invokes user's code in form of the `Eq` impl.
fn should_backdate_value(old_value: &Self::Value, new_value: &Self::Value) -> bool;
fn should_backdate_value(old_value: &Self::Value<'_>, new_value: &Self::Value<'_>) -> bool;
/// Invoked when we need to compute the value for the given key, either because we've never
/// computed it before or because the old one relied on inputs that have changed.
///
/// This invokes the function the user wrote.
fn execute(db: &DynDb<Self>, key: Id) -> Self::Value;
fn execute<'db>(db: &'db DynDb<Self>, key: Id) -> Self::Value<'db>;
/// If the cycle strategy is `Recover`, then invoked when `key` is a participant
/// in a cycle to find out what value it should have.
///
/// This invokes the recovery function given by the user.
fn recover_from_cycle(db: &DynDb<Self>, cycle: &Cycle, key: Id) -> Self::Value;
fn recover_from_cycle<'db>(db: &'db DynDb<Self>, cycle: &Cycle, key: Id) -> Self::Value<'db>;
}
/// True if `old_value == new_value`. Invoked by the generated
@ -163,18 +164,18 @@ where
/// only cleared with `&mut self`.
unsafe fn extend_memo_lifetime<'this, 'memo>(
&'this self,
memo: &'memo memo::Memo<C::Value>,
) -> Option<&'this C::Value> {
let memo_value: Option<&'memo C::Value> = memo.value.as_ref();
memo: &'memo memo::Memo<C::Value<'this>>,
) -> Option<&'this C::Value<'this>> {
let memo_value: Option<&'memo C::Value<'this>> = memo.value.as_ref();
std::mem::transmute(memo_value)
}
fn insert_memo(
&self,
db: &DynDb<'_, C>,
fn insert_memo<'db>(
&'db self,
db: &'db DynDb<'db, C>,
key: Id,
memo: memo::Memo<C::Value>,
) -> Option<&C::Value> {
memo: memo::Memo<C::Value<'db>>,
) -> Option<&C::Value<'db>> {
self.register(db);
let memo = Arc::new(memo);
let value = unsafe {

View file

@ -14,7 +14,7 @@ where
{
/// Returns all the values accumulated into `accumulator` by this query and its
/// transitive inputs.
pub fn accumulated<'db, A>(&self, db: &DynDb<'db, C>, key: Id) -> Vec<A::Data>
pub fn accumulated<'db, A>(&'db self, db: &'db DynDb<'db, C>, key: Id) -> Vec<A::Data>
where
DynDb<'db, C>: HasJar<A::Jar>,
A: Accumulator,

View file

@ -11,9 +11,9 @@ where
/// on an old memo when a new memo has been produced to check whether there have been changed.
pub(super) fn backdate_if_appropriate(
&self,
old_memo: &Memo<C::Value>,
old_memo: &Memo<C::Value<'_>>,
revisions: &mut QueryRevisions,
value: &C::Value,
value: &C::Value<'_>,
) {
if let Some(old_value) = &old_memo.value {
// Careful: if the value became less durable than it

View file

@ -1,6 +1,9 @@
use arc_swap::ArcSwap;
use crossbeam::queue::SegQueue;
use crate::{runtime::local_state::QueryOrigin, Id};
use super::{Configuration, FunctionIngredient};
use super::{memo, Configuration, FunctionIngredient};
impl<C> FunctionIngredient<C>
where
@ -18,3 +21,25 @@ where
}
}
}
/// Stores the list of memos that have been deleted so they can be freed
/// once the next revision starts. See the comment on the field
/// `deleted_entries` of [`FunctionIngredient`][] for more details.
pub(super) struct DeletedEntries<C: Configuration> {
seg_queue: SegQueue<ArcSwap<memo::Memo<C::Value<'static>>>>,
}
impl<C: Configuration> Default for DeletedEntries<C> {
fn default() -> Self {
Self {
seg_queue: Default::default(),
}
}
}
impl<C: Configuration> DeletedEntries<C> {
pub(super) fn push<'db>(&'db self, memo: ArcSwap<memo::Memo<C::Value<'db>>>) {
let memo = unsafe { std::mem::transmute(memo) };
self.seg_queue.push(memo);
}
}

View file

@ -15,7 +15,7 @@ where
&self,
db: &DynDb<'_, C>,
key: DatabaseKeyIndex,
old_memo: &Memo<C::Value>,
old_memo: &Memo<C::Value<'_>>,
revisions: &QueryRevisions,
) {
// Iterate over the outputs of the `old_memo` and put them into a hashset

View file

@ -22,12 +22,12 @@ where
/// * `db`, the database.
/// * `active_query`, the active stack frame for the query to execute.
/// * `opt_old_memo`, the older memo, if any existed. Used for backdated.
pub(super) fn execute(
&self,
db: &DynDb<C>,
pub(super) fn execute<'db>(
&'db self,
db: &'db DynDb<'db, C>,
active_query: ActiveQueryGuard<'_>,
opt_old_memo: Option<Arc<Memo<C::Value>>>,
) -> StampedValue<&C::Value> {
opt_old_memo: Option<Arc<Memo<C::Value<'_>>>>,
) -> StampedValue<&C::Value<'db>> {
let runtime = db.runtime();
let revision_now = runtime.current_revision();
let database_key_index = active_query.database_key_index;

View file

@ -8,7 +8,7 @@ impl<C> FunctionIngredient<C>
where
C: Configuration,
{
pub fn fetch(&self, db: &DynDb<C>, key: Id) -> &C::Value {
pub fn fetch<'db>(&'db self, db: &'db DynDb<'db, C>, key: Id) -> &C::Value<'db> {
let runtime = db.runtime();
runtime.unwind_if_revision_cancelled(db);
@ -33,7 +33,11 @@ where
}
#[inline]
fn compute_value(&self, db: &DynDb<C>, key: Id) -> StampedValue<&C::Value> {
fn compute_value<'db>(
&'db self,
db: &'db DynDb<'db, C>,
key: Id,
) -> StampedValue<&'db C::Value<'db>> {
loop {
if let Some(value) = self.fetch_hot(db, key).or_else(|| self.fetch_cold(db, key)) {
return value;
@ -42,7 +46,11 @@ where
}
#[inline]
fn fetch_hot(&self, db: &DynDb<C>, key: Id) -> Option<StampedValue<&C::Value>> {
fn fetch_hot<'db>(
&'db self,
db: &'db DynDb<'db, C>,
key: Id,
) -> Option<StampedValue<&'db C::Value<'db>>> {
let memo_guard = self.memo_map.get(key);
if let Some(memo) = &memo_guard {
if memo.value.is_some() {
@ -59,7 +67,11 @@ where
None
}
fn fetch_cold(&self, db: &DynDb<C>, key: Id) -> Option<StampedValue<&C::Value>> {
fn fetch_cold<'db>(
&'db self,
db: &'db DynDb<'db, C>,
key: Id,
) -> Option<StampedValue<&'db C::Value<'db>>> {
let runtime = db.runtime();
let database_key_index = self.database_key_index(key);

View file

@ -18,7 +18,12 @@ impl<C> FunctionIngredient<C>
where
C: Configuration,
{
pub(super) fn maybe_changed_after(&self, db: &DynDb<C>, key: Id, revision: Revision) -> bool {
pub(super) fn maybe_changed_after<'db>(
&'db self,
db: &'db DynDb<'db, C>,
key: Id,
revision: Revision,
) -> bool {
let runtime = db.runtime();
runtime.unwind_if_revision_cancelled(db);
@ -50,9 +55,9 @@ where
}
}
fn maybe_changed_after_cold(
&self,
db: &DynDb<C>,
fn maybe_changed_after_cold<'db>(
&'db self,
db: &'db DynDb<'db, C>,
key_index: Id,
revision: Revision,
) -> Option<bool> {
@ -104,7 +109,7 @@ where
db: &DynDb<C>,
runtime: &Runtime,
database_key_index: DatabaseKeyIndex,
memo: &Memo<C::Value>,
memo: &Memo<C::Value<'_>>,
) -> bool {
let verified_at = memo.verified_at.load();
let revision_now = runtime.current_revision();
@ -142,7 +147,7 @@ where
pub(super) fn deep_verify_memo(
&self,
db: &DynDb<C>,
old_memo: &Memo<C::Value>,
old_memo: &Memo<C::Value<'_>>,
active_query: &ActiveQueryGuard<'_>,
) -> bool {
let runtime = db.runtime();

View file

@ -4,18 +4,23 @@ use arc_swap::{ArcSwap, Guard};
use crossbeam_utils::atomic::AtomicCell;
use crate::{
hash::FxDashMap, key::DatabaseKeyIndex, runtime::local_state::QueryRevisions, AsId, Event,
EventKind, Revision, Runtime,
hash::FxDashMap, key::DatabaseKeyIndex, runtime::local_state::QueryRevisions, Event, EventKind,
Id, Revision, Runtime,
};
use super::Configuration;
/// The memo map maps from a key of type `K` to the memoized value for that `K`.
/// The memoized value is a `Memo<V>` which contains, in addition to the value `V`,
/// dependency information.
pub(super) struct MemoMap<K: AsId, V> {
map: FxDashMap<K, ArcSwap<Memo<V>>>,
pub(super) struct MemoMap<C: Configuration> {
map: FxDashMap<Id, ArcMemo<'static, C>>,
}
impl<K: AsId, V> Default for MemoMap<K, V> {
#[allow(type_alias_bounds)]
type ArcMemo<'lt, C: Configuration> = ArcSwap<Memo<<C as Configuration>::Value<'lt>>>;
impl<C: Configuration> Default for MemoMap<C> {
fn default() -> Self {
Self {
map: Default::default(),
@ -23,30 +28,56 @@ impl<K: AsId, V> Default for MemoMap<K, V> {
}
}
impl<K: AsId, V> MemoMap<K, V> {
impl<C: Configuration> MemoMap<C> {
/// Memos have to be stored internally using `'static` as the database lifetime.
/// This (unsafe) function call converts from something tied to self to static.
/// Values transmuted this way have to be transmuted back to being tied to self
/// when they are returned to the user.
unsafe fn to_static<'db>(&'db self, value: ArcMemo<'db, C>) -> ArcMemo<'static, C> {
unsafe { std::mem::transmute(value) }
}
/// Convert from an internal memo (which uses statis) to one tied to self
/// so it can be publicly released.
unsafe fn to_self<'db>(&'db self, value: ArcMemo<'static, C>) -> ArcMemo<'db, C> {
unsafe { std::mem::transmute(value) }
}
/// Inserts the memo for the given key; (atomically) overwrites any previously existing memo.-
#[must_use]
pub(super) fn insert(&self, key: K, memo: Arc<Memo<V>>) -> Option<ArcSwap<Memo<V>>> {
self.map.insert(key, ArcSwap::from(memo))
pub(super) fn insert<'db>(
&'db self,
key: Id,
memo: Arc<Memo<C::Value<'db>>>,
) -> Option<ArcSwap<Memo<C::Value<'db>>>> {
unsafe {
let value = ArcSwap::from(memo);
let old_value = self.map.insert(key, self.to_static(value))?;
Some(self.to_self(old_value))
}
}
/// Removes any existing memo for the given key.
#[must_use]
pub(super) fn remove(&self, key: K) -> Option<ArcSwap<Memo<V>>> {
self.map.remove(&key).map(|o| o.1)
pub(super) fn remove<'db>(&'db self, key: Id) -> Option<ArcSwap<Memo<C::Value<'db>>>> {
unsafe { self.map.remove(&key).map(|o| self.to_self(o.1)) }
}
/// Loads the current memo for `key_index`. This does not hold any sort of
/// lock on the `memo_map` once it returns, so this memo could immediately
/// become outdated if other threads store into the `memo_map`.
pub(super) fn get(&self, key: K) -> Option<Guard<Arc<Memo<V>>>> {
self.map.get(&key).map(|v| v.load())
pub(super) fn get<'db>(&self, key: Id) -> Option<Guard<Arc<Memo<C::Value<'db>>>>> {
self.map.get(&key).map(|v| unsafe {
std::mem::transmute::<
Guard<Arc<Memo<C::Value<'static>>>>,
Guard<Arc<Memo<C::Value<'db>>>>,
>(v.load())
})
}
/// Evicts the existing memo for the given key, replacing it
/// with an equivalent memo that has no value. If the memo is untracked, BaseInput,
/// or has values assigned as output of another query, this has no effect.
pub(super) fn evict(&self, key: K) {
pub(super) fn evict(&self, key: Id) {
use crate::runtime::local_state::QueryOrigin;
use dashmap::mapref::entry::Entry::*;
@ -64,7 +95,7 @@ impl<K: AsId, V> MemoMap<K, V> {
QueryOrigin::Derived(_) => {
let memo_evicted = Arc::new(Memo::new(
None::<V>,
None::<C::Value<'_>>,
memo.verified_at.load(),
memo.revisions.clone(),
));

View file

@ -18,13 +18,13 @@ where
/// This is a way to imperatively set the value of a function.
/// It only works if the key is a tracked struct created in the current query.
fn specify<'db>(
&self,
&'db self,
db: &'db DynDb<'db, C>,
key: Id,
value: C::Value,
value: C::Value<'db>,
origin: impl Fn(DatabaseKeyIndex) -> QueryOrigin,
) where
C::Input: TrackedStructInDb<DynDb<'db, C>>,
C::Input<'db>: TrackedStructInDb<DynDb<'db, C>>,
{
let runtime = db.runtime();
@ -45,7 +45,7 @@ where
// * Q4 invokes Q2 and then Q1
//
// Now, if We invoke Q3 first, We get one result for Q2, but if We invoke Q4 first, We get a different value. That's no good.
let database_key_index = <C::Input>::database_key_index(db, key);
let database_key_index = <C::Input<'db>>::database_key_index(db, key);
let dependency_index = database_key_index.into();
if !runtime.is_output_of_active_query(dependency_index) {
panic!("can only use `specfiy` on entities created during current query");
@ -94,9 +94,9 @@ where
/// Specify the value for `key` *and* record that we did so.
/// Used for explicit calls to `specify`, but not needed for pre-declared tracked struct fields.
pub fn specify_and_record<'db>(&self, db: &'db DynDb<'db, C>, key: Id, value: C::Value)
pub fn specify_and_record<'db>(&'db self, db: &'db DynDb<'db, C>, key: Id, value: C::Value<'db>)
where
C::Input: TrackedStructInDb<DynDb<'db, C>>,
C::Input<'db>: TrackedStructInDb<DynDb<'db, C>>,
{
self.specify(db, key, value, |database_key_index| {
QueryOrigin::Assigned(database_key_index)

View file

@ -14,11 +14,11 @@ impl<C> FunctionIngredient<C>
where
C: Configuration,
{
pub fn store(
&mut self,
pub fn store<'db>(
&'db mut self,
runtime: &mut Runtime,
key: Id,
value: C::Value,
value: C::Value<'db>,
durability: Durability,
) {
let revision = runtime.current_revision();

View file

@ -8,9 +8,9 @@ error[E0277]: the trait bound `MyInput: TrackedStructInDb<dyn Db>` is not satisf
note: required by a bound in `function::specify::<impl FunctionIngredient<C>>::specify_and_record`
--> $WORKSPACE/components/salsa-2022/src/function/specify.rs
|
| pub fn specify_and_record<'db>(&self, db: &'db DynDb<'db, C>, key: Id, value: C::Value)
| pub fn specify_and_record<'db>(&'db self, db: &'db DynDb<'db, C>, key: Id, value: C::Value<'db>)
| ------------------ required by a bound in this associated function
| where
| C::Input: TrackedStructInDb<DynDb<'db, C>>,
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `function::specify::<impl FunctionIngredient<C>>::specify_and_record`
| C::Input<'db>: TrackedStructInDb<DynDb<'db, C>>,
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `function::specify::<impl FunctionIngredient<C>>::specify_and_record`
= note: this error originates in the attribute macro `salsa::tracked` (in Nightly builds, run with -Z macro-backtrace for more info)

View file

@ -8,9 +8,9 @@ error[E0277]: the trait bound `MyInterned: TrackedStructInDb<dyn Db>` is not sat
note: required by a bound in `function::specify::<impl FunctionIngredient<C>>::specify_and_record`
--> $WORKSPACE/components/salsa-2022/src/function/specify.rs
|
| pub fn specify_and_record<'db>(&self, db: &'db DynDb<'db, C>, key: Id, value: C::Value)
| pub fn specify_and_record<'db>(&'db self, db: &'db DynDb<'db, C>, key: Id, value: C::Value<'db>)
| ------------------ required by a bound in this associated function
| where
| C::Input: TrackedStructInDb<DynDb<'db, C>>,
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `function::specify::<impl FunctionIngredient<C>>::specify_and_record`
| C::Input<'db>: TrackedStructInDb<DynDb<'db, C>>,
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `function::specify::<impl FunctionIngredient<C>>::specify_and_record`
= note: this error originates in the attribute macro `salsa::tracked` (in Nightly builds, run with -Z macro-backtrace for more info)