From b66eb8131191a12ba9d2a0b14760d8f0b60bdffd Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Mon, 6 Jul 2020 00:55:01 +0000 Subject: [PATCH] experiment: extract some Memo code to be independent from Q This should enable more sharing and less monomorphization. There is probably room for more radical restructing in this vein. --- .../salsa-macros/src/database_storage.rs | 4 + src/derived/slot.rs | 148 +++++++++--------- src/plumbing.rs | 3 + 3 files changed, 84 insertions(+), 71 deletions(-) diff --git a/components/salsa-macros/src/database_storage.rs b/components/salsa-macros/src/database_storage.rs index 91d22718..e2909cb3 100644 --- a/components/salsa-macros/src/database_storage.rs +++ b/components/salsa-macros/src/database_storage.rs @@ -134,6 +134,10 @@ pub(crate) fn database(args: TokenStream, input: TokenStream) -> TokenStream { } output.extend(quote! { impl salsa::plumbing::DatabaseOps for #database_name { + fn ops_database(&self) -> &dyn salsa::Database { + self + } + fn ops_salsa_runtime(&self) -> &salsa::Runtime { self.#db_storage_field.salsa_runtime() } diff --git a/src/derived/slot.rs b/src/derived/slot.rs index 8b7a4edd..ac750a76 100644 --- a/src/derived/slot.rs +++ b/src/derived/slot.rs @@ -65,6 +65,11 @@ where /// The result of the query, if we decide to memoize it. value: Option, + /// Revision information + revisions: MemoRevisions, +} + +struct MemoRevisions { /// Last revision when this memo was verified (if there are /// untracked inputs, this will also be when the memo was /// created). @@ -246,16 +251,16 @@ where // used to be, that is a "breaking change" that our // consumers must be aware of. Becoming *more* durable // is not. See the test `constant_to_non_constant`. - if result.durability >= old_memo.durability + if result.durability >= old_memo.revisions.durability && MP::memoized_value_eq(&old_value, &result.value) { debug!( "read_upgrade({:?}): value is equal, back-dating to {:?}", - self, old_memo.changed_at, + self, old_memo.revisions.changed_at, ); - assert!(old_memo.changed_at <= result.changed_at); - result.changed_at = old_memo.changed_at; + assert!(old_memo.revisions.changed_at <= result.changed_at); + result.changed_at = old_memo.revisions.changed_at; } } } @@ -295,10 +300,12 @@ where panic_guard.memo = Some(Memo { value, - changed_at: result.changed_at, - verified_at: revision_now, - inputs, - durability: result.durability, + revisions: MemoRevisions { + changed_at: result.changed_at, + verified_at: revision_now, + inputs, + durability: result.durability, + }, }); panic_guard.proceed(&new_value, result.cycle); @@ -387,14 +394,14 @@ where QueryState::Memoized(memo) => { debug!( "{:?}: found memoized value, verified_at={:?}, changed_at={:?}", - self, memo.verified_at, memo.changed_at, + self, memo.revisions.verified_at, memo.revisions.changed_at, ); if let Some(value) = &memo.value { - if memo.verified_at == revision_now { + if memo.revisions.verified_at == revision_now { let value = StampedValue { - durability: memo.durability, - changed_at: memo.changed_at, + durability: memo.revisions.durability, + changed_at: memo.revisions.changed_at, value: value.clone(), }; @@ -417,8 +424,8 @@ where QueryState::NotComputed => Durability::LOW, QueryState::InProgress { .. } => panic!("query in progress"), QueryState::Memoized(memo) => { - if memo.check_durability(db) { - memo.durability + if memo.revisions.check_durability(db.salsa_runtime()) { + memo.revisions.durability } else { Durability::LOW } @@ -443,7 +450,7 @@ where // lead to inconsistencies. Note that we can't check // `has_untracked_input` when we add the value to the cache, // because inputs can become untracked in the next revision. - if memo.has_untracked_input() { + if memo.revisions.has_untracked_input() { return; } memo.value = None; @@ -467,7 +474,7 @@ where QueryState::Memoized(memo) => { debug!( "sweep({:?}): last verified at {:?}, current revision {:?}", - self, memo.verified_at, revision_now + self, memo.revisions.verified_at, revision_now ); // Check if this memo read something "untracked" @@ -478,7 +485,7 @@ where // revision, we might wind up re-executing the // query later in the revision and getting a // distinct result. - let has_untracked_input = memo.has_untracked_input(); + let has_untracked_input = memo.revisions.has_untracked_input(); // Since we don't acquire a query lock in this // method, it *is* possible for the revision to @@ -487,19 +494,19 @@ where // written into this table that reflect the new // revision, since we are holding the write lock // when we read `revision_now`. - assert!(memo.verified_at <= revision_now); + assert!(memo.revisions.verified_at <= revision_now); match strategy.discard_if { DiscardIf::Never => unreachable!(), // If we are only discarding outdated things, // and this is not outdated, keep it. - DiscardIf::Outdated if memo.verified_at == revision_now => (), + DiscardIf::Outdated if memo.revisions.verified_at == revision_now => (), // As explained on the `has_untracked_input` variable // definition, if this is a volatile entry, we // can't discard it unless it is outdated. DiscardIf::Always - if has_untracked_input && memo.verified_at == revision_now => {} + if has_untracked_input && memo.revisions.verified_at == revision_now => {} // Otherwise, we can discard -- discard whatever the user requested. DiscardIf::Outdated | DiscardIf::Always => match strategy.discard_what { @@ -518,8 +525,8 @@ where pub(super) fn invalidate(&self) -> Option { if let QueryState::Memoized(memo) = &mut *self.state.write() { - memo.inputs = MemoInputs::Untracked; - Some(memo.durability) + memo.revisions.inputs = MemoInputs::Untracked; + Some(memo.revisions.durability) } else { None } @@ -574,14 +581,14 @@ where QueryState::Memoized(memo) => memo, }; - if memo.verified_at == revision_now { + if memo.revisions.verified_at == revision_now { debug!( "maybe_changed_since({:?}: {:?} since up-to-date memo that changed at {:?}", self, - memo.changed_at > revision, - memo.changed_at, + memo.revisions.changed_at > revision, + memo.revisions.changed_at, ); - return memo.changed_at > revision; + return memo.revisions.changed_at > revision; } let maybe_changed; @@ -589,11 +596,11 @@ where // If we only depended on constants, and no constant has been // modified since then, we cannot have changed; no need to // trace our inputs. - if memo.check_durability(db) { + if memo.revisions.check_durability(runtime) { std::mem::drop(state); maybe_changed = false; } else { - match &memo.inputs { + match &memo.revisions.inputs { MemoInputs::Untracked => { // we don't know the full set of // inputs, so if there is a new @@ -660,7 +667,7 @@ where let mut state = self.state.write(); match &mut *state { QueryState::Memoized(memo) => { - if memo.verified_at == revision_now { + if memo.revisions.verified_at == revision_now { // Since we started verifying inputs, somebody // else has come along and updated this value // (they may even have recomputed @@ -683,7 +690,7 @@ where // We found this entry is valid. Update the // `verified_at` to reflect the current // revision. - memo.verified_at = revision_now; + memo.revisions.verified_at = revision_now; } } @@ -877,46 +884,46 @@ impl Memo where Q: QueryFunction, { - /// True if this memo is known not to have changed based on its durability. - fn check_durability(&self, db: &Q::DynDb) -> bool { - let last_changed = db.salsa_runtime().last_changed_revision(self.durability); - debug!( - "check_durability(last_changed={:?} <= verified_at={:?}) = {:?}", - last_changed, - self.verified_at, - last_changed <= self.verified_at, - ); - last_changed <= self.verified_at - } - fn validate_memoized_value( &mut self, db: &Q::DynDb, revision_now: Revision, ) -> Option> { // If we don't have a memoized value, nothing to validate. - if self.value.is_none() { - return None; - } + let value = match &self.value { + None => return None, + Some(v) => v, + }; + let dyn_db = db.ops_database(); + if self.revisions.validate_memoized_value(dyn_db, revision_now) { + Some(StampedValue { + durability: self.revisions.durability, + changed_at: self.revisions.changed_at, + value: value.clone(), + }) + } else { + None + } + } +} + +impl MemoRevisions { + fn validate_memoized_value(&mut self, db: &dyn Database, revision_now: Revision) -> bool { assert!(self.verified_at != revision_now); let verified_at = self.verified_at; - debug!( - "validate_memoized_value({:?}): verified_at={:#?}", - Q::default(), - self.inputs, - ); + debug!("validate_memoized_value: verified_at={:#?}", self.inputs,); - if self.check_durability(db) { - return Some(self.mark_value_as_verified(revision_now)); + if self.check_durability(db.salsa_runtime()) { + return self.mark_value_as_verified(revision_now); } match &self.inputs { // We can't validate values that had untracked inputs; just have to // re-execute. MemoInputs::Untracked { .. } => { - return None; + return false; } MemoInputs::NoInputs => {} @@ -937,32 +944,31 @@ where .next(); if let Some(input) = changed_input { - debug!( - "{:?}::validate_memoized_value: `{:?}` may have changed", - Q::default(), - input - ); + debug!("validate_memoized_value: `{:?}` may have changed", input); - return None; + return false; } } }; - Some(self.mark_value_as_verified(revision_now)) + self.mark_value_as_verified(revision_now) } - fn mark_value_as_verified(&mut self, revision_now: Revision) -> StampedValue { - let value = match &self.value { - Some(v) => v.clone(), - None => panic!("invoked `verify_value` without a value!"), - }; - self.verified_at = revision_now; + /// True if this memo is known not to have changed based on its durability. + fn check_durability(&self, runtime: &Runtime) -> bool { + let last_changed = runtime.last_changed_revision(self.durability); + debug!( + "check_durability(last_changed={:?} <= verified_at={:?}) = {:?}", + last_changed, + self.verified_at, + last_changed <= self.verified_at, + ); + last_changed <= self.verified_at + } - StampedValue { - durability: self.durability, - changed_at: self.changed_at, - value, - } + fn mark_value_as_verified(&mut self, revision_now: Revision) -> bool { + self.verified_at = revision_now; + true } fn has_untracked_input(&self) -> bool { diff --git a/src/plumbing.rs b/src/plumbing.rs index 254d9e27..036274b4 100644 --- a/src/plumbing.rs +++ b/src/plumbing.rs @@ -37,6 +37,9 @@ pub trait DatabaseStorageTypes: Database { /// Internal operations that the runtime uses to operate on the database. pub trait DatabaseOps { + /// Upcast this type to a `dyn Database`. + fn ops_database(&self) -> &dyn Database; + /// Gives access to the underlying salsa runtime. fn ops_salsa_runtime(&self) -> &Runtime;