diff --git a/src/derived.rs b/src/derived.rs
index 0475b5e0..e5f7b81e 100644
--- a/src/derived.rs
+++ b/src/derived.rs
@@ -281,6 +281,13 @@ where
     ) -> Result<StampedValue<Q::Value>, CycleDetected> {
         let runtime = db.salsa_runtime();

+        debug!(
+            "{:?}({:?}): read_upgrade(revision_now={:?})",
+            Q::default(),
+            key,
+            revision_now,
+        );
+
         // Check with an upgradable read to see if there is a value
         // already. (This permits other readers but prevents anyone
         // else from running `read_upgrade` at the same time.)
@@ -355,6 +362,13 @@ where
         if let Some(old_memo) = &old_memo {
             if let Some(old_value) = &old_memo.value {
                 if MP::memoized_value_eq(&old_value, &result.value) {
+                    debug!(
+                        "read_upgrade({:?}({:?})): value is equal, back-dating to {:?}",
+                        Q::default(),
+                        key,
+                        old_memo.changed_at,
+                    );
+
                     assert!(old_memo.changed_at <= result.changed_at.revision);
                     result.changed_at.revision = old_memo.changed_at;
                 }
@@ -372,6 +386,14 @@ where
             None
         };

+        debug!(
+            "read_upgrade({:?}({:?})): result.changed_at={:?}, result.subqueries = {:#?}",
+            Q::default(),
+            key,
+            result.changed_at,
+            result.subqueries,
+        );
+
         let inputs = match result.subqueries {
             None => MemoInputs::Untracked,

@@ -666,7 +688,7 @@ where
         let revision_now = runtime.current_revision();
         debug!(
-            "{:?}({:?})::maybe_changed_since(revision={:?}, revision_now={:?})",
+            "maybe_changed_since({:?}({:?})) called with revision={:?}, revision_now={:?}",
             Q::default(),
             key,
             revision,
             revision_now,
@@ -682,13 +704,26 @@ where
                 // If somebody depends on us, but we have no map
                 // entry, that must mean that it was found to be out
                 // of date and removed.
-                None => return true,
+                None => {
+                    debug!(
+                        "maybe_changed_since({:?}({:?})): no value",
+                        Q::default(),
+                        key,
+                    );
+                    return true;
+                }

                 // This value is being actively recomputed. Wait for
                 // that thread to finish (assuming it's not dependent
                 // on us...) and check its associated revision.
                 Some(QueryState::InProgress { id, waiting }) => {
                     let other_id = *id;
+                    debug!(
+                        "maybe_changed_since({:?}({:?})): blocking on thread `{:?}`",
+                        Q::default(),
+                        key,
+                        other_id,
+                    );
                     match self.register_with_in_progress_thread(runtime, descriptor, other_id, waiting) {
                         Ok(rx) => {
@@ -709,6 +744,13 @@ where
         };

         if memo.verified_at == revision_now {
+            debug!(
+                "maybe_changed_since({:?}({:?})): {:?} since up-to-date memo that changed at {:?}",
+                Q::default(),
+                key,
+                memo.changed_at > revision,
+                memo.changed_at,
+            );
             return memo.changed_at > revision;
         }

@@ -718,6 +760,11 @@ where
                 // inputs, so if there is a new
                 // revision, we must assume it is
                 // dirty
+                debug!(
+                    "maybe_changed_since({:?}({:?})): true since untracked inputs",
+                    Q::default(),
+                    key,
+                );
                 return true;
             }

@@ -734,7 +781,16 @@ where
         if memo.value.is_some() {
             std::mem::drop(map);
             return match self.read_upgrade(db, key, descriptor, revision_now) {
-                Ok(v) => v.changed_at.changed_since(revision),
+                Ok(v) => {
+                    debug!(
+                        "maybe_changed_since({:?}({:?})): {:?} since (recomputed) value changed at {:?}",
+                        Q::default(),
+                        key,
+                        v.changed_at.changed_since(revision),
+                        v.changed_at,
+                    );
+                    v.changed_at.changed_since(revision)
+                }
                 Err(CycleDetected) => true,
             };
         }
@@ -927,6 +983,12 @@ where
         assert!(self.verified_at != revision_now);
         let verified_at = self.verified_at;

+        debug!(
+            "validate_memoized_value({:?}): inputs={:#?}",
+            Q::default(),
+            self.inputs,
+        );
+
         let is_constant = match &mut self.inputs {
             // We can't validate values that had untracked inputs; just have to
             // re-execute.
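The derived.rs tracing above is largely about making the back-dating step visible: when `memoized_value_eq` reports that the recomputed value is unchanged, the memo keeps its old `changed_at`, so queries that depend on it see no new revision. The sketch below is not salsa's code -- `Revision`, `Memo`, and `revalidate` are simplified stand-ins invented for illustration -- but it shows the idea that the "value is equal, back-dating to ..." log line refers to.

// Simplified model of back-dating; not salsa's real types.
#[derive(Copy, Clone, Debug, PartialEq)]
struct Revision(u64);

struct Memo<V> {
    value: V,
    changed_at: Revision,
    verified_at: Revision,
}

// Recompute a memo in `revision_now`. If the new value equals the old one,
// keep ("back-date") the old `changed_at` so that dependents do not appear
// to have changed inputs.
fn revalidate<V: Eq>(old: &mut Memo<V>, new_value: V, revision_now: Revision) {
    if old.value == new_value {
        // Value unchanged: re-verified now, but `changed_at` stays put.
        old.verified_at = revision_now;
    } else {
        old.value = new_value;
        old.changed_at = revision_now;
        old.verified_at = revision_now;
    }
}

fn main() {
    let mut memo = Memo {
        value: 22,
        changed_at: Revision(1),
        verified_at: Revision(1),
    };

    // Revision 2 recomputes the same value: still "changed at" R1.
    revalidate(&mut memo, 22, Revision(2));
    assert_eq!(memo.changed_at, Revision(1));
    assert_eq!(memo.verified_at, Revision(2));

    // Revision 3 produces a different value: `changed_at` moves forward.
    revalidate(&mut memo, 44, Revision(3));
    assert_eq!(memo.changed_at, Revision(3));
}

Because `changed_at` stays at R1 after the equal recomputation in R2, a dependent whose own memo was verified against R1 can be reused without re-executing.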
diff --git a/src/runtime.rs b/src/runtime.rs
index f1b41f36..4f0e6d08 100644
--- a/src/runtime.rs
+++ b/src/runtime.rs
@@ -159,6 +159,14 @@ where
         }
     }

+    /// Read current value of the revision counter.
+    #[inline]
+    fn pending_revision(&self) -> Revision {
+        Revision {
+            generation: self.shared_state.pending_revision.load(Ordering::SeqCst) as u64,
+        }
+    }
+
     /// Check if the current revision is canceled. If this method ever
     /// returns true, the currently executing query is also marked as
     /// having an *untracked read* -- this means that, in the next
@@ -169,14 +177,48 @@
     /// whatever it likes.
     #[inline]
     pub fn is_current_revision_canceled(&self) -> bool {
-        let pending_revision_increments = self
-            .shared_state
-            .pending_revision_increments
-            .load(Ordering::SeqCst);
-        if pending_revision_increments > 0 {
+        let current_revision = self.current_revision();
+        let pending_revision = self.pending_revision();
+        debug!(
+            "is_current_revision_canceled: current_revision={:?}, pending_revision={:?}",
+            current_revision, pending_revision
+        );
+        if pending_revision > current_revision {
             self.report_untracked_read();
             true
         } else {
+            // Subtle: If the current revision is not canceled, we
+            // still report an **anonymous** read, which will bump up
+            // the revision number to be at least the last
+            // non-canceled revision. This is needed to ensure
+            // deterministic reads and avoid salsa-rs/salsa#66. The
+            // specific scenario we are trying to avoid is tested by
+            // `no_back_dating_in_cancellation`; it works like
+            // this. Imagine we have 3 queries, where Query3 invokes
+            // Query2 which invokes Query1. Then:
+            //
+            // - In Revision R1:
+            //   - Query1: Observes cancelation and returns sentinel S.
+            //     - Recorded inputs: Untracked, because we observed cancelation.
+            //   - Query2: Reads Query1 and propagates sentinel S.
+            //     - Recorded inputs: Query1, changed-at=R1
+            //   - Query3: Reads Query2 and propagates sentinel S.
+            //     - Recorded inputs: Query2, changed-at=R1
+            // - In Revision R2:
+            //   - Query1: Observes no cancelation. All of its inputs last changed in R0,
+            //     so it returns a valid value with "changed at" of R0.
+            //     - Recorded inputs: ..., changed-at=R0
+            //   - Query2: Recomputes its value and returns correct result.
+            //     - Recorded inputs: Query1, changed-at=R0 <-- key problem!
+            //   - Query3: sees that Query2's result last changed in R0, so it thinks it
+            //     can re-use its value from R1 (which is the sentinel value).
+            //
+            // The anonymous read here prevents that scenario: Query1
+            // winds up with a changed-at setting of R2, which is the
+            // "pending revision", and hence Query2 and Query3
+            // are recomputed.
+            assert_eq!(pending_revision, current_revision);
+            self.report_anon_read(pending_revision);
             false
         }
     }
@@ -190,6 +232,10 @@
     /// method will also increment `pending_revision_increments`, thus
     /// signalling to queries that their results are "canceled" and
    /// they should abort as expeditiously as possible.
+    ///
+    /// Note that, given our writer model, we can assume that only one
+    /// thread is attempting to increment the global revision at a
+    /// time.
    pub(crate) fn with_incremented_revision<R>(&self, op: impl FnOnce(Revision) -> R) -> R {
        log::debug!("increment_revision()");

@@ -197,36 +243,24 @@
        if !self.permits_increment() {
            panic!("increment_revision invoked during a query computation");
        }

-        // Signal that we have a pending increment so that workers can
-        // start to cancel work.
-        let old_pending_revision_increments = self
+        // Set the `pending_revision` field so that people
+        // know the current revision is canceled.
+        let current_revision = self
             .shared_state
-            .pending_revision_increments
+            .pending_revision
             .fetch_add(1, Ordering::SeqCst);
-        assert!(
-            old_pending_revision_increments != usize::max_value(),
-            "pending increment overflow"
-        );
+        assert!(current_revision != usize::max_value(), "revision overflow");

         // To modify the revision, we need the lock.
         let _lock = self.shared_state.query_lock.write();

-        // *Before* updating the revision number, decrement the
-        // `pending_revision_increments` counter. This way, if anybody
-        // should happen to invoke `is_current_revision_canceled`
-        // before we update the number, and they read 0, they don't
-        // get an incorrect result -- once they acquire the query
-        // lock, we'll be in the new revision.
-        self.shared_state
-            .pending_revision_increments
-            .fetch_sub(1, Ordering::SeqCst);
-
         let old_revision = self.shared_state.revision.fetch_add(1, Ordering::SeqCst);
-        assert!(old_revision != usize::max_value(), "revision overflow");
+        assert_eq!(current_revision, old_revision);

         let new_revision = Revision {
-            generation: 1 + old_revision as u64,
+            generation: (current_revision + 1) as u64,
         };
+        debug!("increment_revision: incremented to {:?}", new_revision);

         op(new_revision)
@@ -308,6 +342,20 @@ where
         }
     }

+    /// An "anonymous" read is a read that doesn't come from executing
+    /// a query, but from some other internal operation. It just
+    /// modifies the "changed at" to be at least the given revision.
+    /// (It also does not disqualify a query from being considered
+    /// constant, since it is used for queries that don't give back
+    /// actual *data*.)
+    ///
+    /// This is used when queries check if they have been canceled.
+    fn report_anon_read(&self, revision: Revision) {
+        if let Some(top_query) = self.local_state.borrow_mut().query_stack.last_mut() {
+            top_query.add_anon_read(revision);
+        }
+    }
+
     /// Obviously, this should be user configurable at some point.
     pub(crate) fn report_unexpected_cycle(&self, descriptor: DB::QueryDescriptor) -> ! {
         debug!("report_unexpected_cycle(descriptor={:?})", descriptor);
@@ -374,12 +422,10 @@ struct SharedState<DB: Database> {
     /// (Ideally, this should be `AtomicU64`, but that is currently unstable.)
     revision: AtomicUsize,

-    /// Counts the number of pending increments to the revision
-    /// counter. If this is non-zero, it means that the current
-    /// revision is out of date, and hence queries are free to
-    /// "short-circuit" their results if they learn that. See
-    /// `is_current_revision_canceled` for more information.
-    pending_revision_increments: AtomicUsize,
+    /// This is typically equal to `revision` -- set to `revision+1`
+    /// when a new revision is pending (which implies that the current
+    /// revision is canceled).
+    pending_revision: AtomicUsize,

     /// The dependency graph tracks which runtimes are blocked on one
     /// another, waiting for queries to terminate.
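In runtime.rs, the single `pending_revision_increments` counter is replaced by a `revision`/`pending_revision` pair: a writer first bumps `pending_revision` (at which point in-flight queries observe cancellation), then takes the query write lock and brings `revision` up to match. Below is a rough, standalone sketch of just those two counters -- `MiniRuntime` is a made-up name, and the query lock and the anonymous read are deliberately left out:

use std::sync::atomic::{AtomicUsize, Ordering};

#[derive(Default)]
struct MiniRuntime {
    revision: AtomicUsize,
    pending_revision: AtomicUsize,
}

impl MiniRuntime {
    // A query is "canceled" as soon as a writer has announced that the
    // next revision is pending.
    fn is_current_revision_canceled(&self) -> bool {
        self.pending_revision.load(Ordering::SeqCst) > self.revision.load(Ordering::SeqCst)
    }

    // Writer side: announce the new revision first so in-flight queries
    // can bail out, then (once exclusive access would be held) publish it.
    fn with_incremented_revision(&self, op: impl FnOnce(usize)) {
        let prev = self.pending_revision.fetch_add(1, Ordering::SeqCst);
        // ... the real runtime acquires the query write lock here ...
        let old = self.revision.fetch_add(1, Ordering::SeqCst);
        assert_eq!(prev, old);
        op(old + 1);
    }
}

fn main() {
    let rt = MiniRuntime::default();
    assert!(!rt.is_current_revision_canceled());
    rt.with_incremented_revision(|new| println!("now in revision R{}", new));
    assert!(!rt.is_current_revision_canceled());
}

While only `pending_revision` has been bumped, readers see `pending_revision > revision` and can bail out; once `revision` catches up, the two are equal again, which is the invariant that `assert_eq!(current_revision, old_revision)` checks in the real `with_incremented_revision`.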
@@ -393,8 +439,8 @@ impl<DB: Database> Default for SharedState<DB> {
             storage: Default::default(),
             query_lock: Default::default(),
             revision: Default::default(),
+            pending_revision: Default::default(),
             dependency_graph: Default::default(),
-            pending_revision_increments: Default::default(),
         }
     }
 }
@@ -414,7 +460,7 @@ where
         fmt.debug_struct("SharedState")
             .field("query_lock", &query_lock)
             .field("revision", &self.revision)
-            .field("pending_revision_increments", &self.pending_revision_increments)
+            .field("pending_revision", &self.pending_revision)
             .finish()
     }
 }
@@ -501,6 +547,10 @@ impl<DB: Database> ActiveQuery<DB> {
         self.changed_at.is_constant = false;
         self.changed_at.revision = changed_at;
     }
+
+    fn add_anon_read(&mut self, changed_at: Revision) {
+        self.changed_at.revision = self.changed_at.revision.max(changed_at);
+    }
 }

 /// A unique identifier for a particular runtime. Each time you create
@@ -523,6 +573,17 @@ pub struct Revision {

 impl Revision {
     pub(crate) const ZERO: Self = Revision { generation: 0 };
+
+    fn next(self) -> Revision {
+        Revision {
+            generation: self.generation + 1,
+        }
+    }
+
+    fn as_usize(self) -> usize {
+        assert!(self.generation < (std::usize::MAX as u64));
+        self.generation as usize
+    }
 }

 impl std::fmt::Debug for Revision {
diff --git a/tests/incremental/memoized_volatile.rs b/tests/incremental/memoized_volatile.rs
index 80710a4c..0385dfc4 100644
--- a/tests/incremental/memoized_volatile.rs
+++ b/tests/incremental/memoized_volatile.rs
@@ -66,10 +66,8 @@ fn revalidate() {
     // Second generation: volatile will change (to 1) but memoized1
     // will not (still 0, as 1/2 = 0)
     query.salsa_runtime().next_revision();
-
     query.memoized2();
     query.assert_log(&["Memoized1 invoked", "Volatile invoked"]);
-
     query.memoized2();
     query.assert_log(&[]);

diff --git a/tests/parallel/cancellation.rs b/tests/parallel/cancellation.rs
index 27cde460..73afbad1 100644
--- a/tests/parallel/cancellation.rs
+++ b/tests/parallel/cancellation.rs
@@ -83,3 +83,75 @@ fn in_par_get_set_cancellation_transitive() {
     assert_eq!(thread1.join().unwrap(), std::usize::MAX);
     assert_eq!(thread2.join().unwrap(), 111);
 }
+
+/// https://github.com/salsa-rs/salsa/issues/66
+#[test]
+fn no_back_dating_in_cancellation() {
+    let mut db = ParDatabaseImpl::default();
+
+    db.query_mut(Input).set('a', 1);
+    let thread1 = std::thread::spawn({
+        let db = db.snapshot();
+        move || {
+            // Here we compute a long chain of queries,
+            // but the last one gets cancelled.
+            db.knobs().sum_signal_on_entry.with_value(1, || {
+                db.knobs()
+                    .sum_wait_for_cancellation
+                    .with_value(true, || db.sum3("a"))
+            })
+        }
+    });
+
+    db.wait_for(1);
+
+    // Set unrelated input to bump revision
+    db.query_mut(Input).set('b', 2);
+
+    // Here we should recompute the whole chain again, clearing the cancellation
+    // state. If we get `std::usize::MAX` here, it is a bug!
+    assert_eq!(db.sum3("a"), 1);
+
+    assert_eq!(thread1.join().unwrap(), std::usize::MAX);
+
+    db.query_mut(Input).set('a', 3);
+    db.query_mut(Input).set('a', 4);
+    assert_eq!(db.sum3("ab"), 6);
+}
+
+/// Here, we compute `sum3_drop_sum` and -- in the process -- observe
+/// a cancellation. As a result, we have to recompute `sum` when we
+/// reinvoke `sum3_drop_sum` and we have to re-execute
+/// `sum2_drop_sum`. But the result of `sum2_drop_sum` doesn't
+/// change, so we don't have to re-execute `sum3_drop_sum`.
+#[test]
+fn transitive_cancellation() {
+    let mut db = ParDatabaseImpl::default();
+
+    db.query_mut(Input).set('a', 1);
+    let thread1 = std::thread::spawn({
+        let db = db.snapshot();
+        move || {
+            // Here we compute a long chain of queries,
+            // but the last one gets cancelled.
+            db.knobs().sum_signal_on_entry.with_value(1, || {
+                db.knobs()
+                    .sum_wait_for_cancellation
+                    .with_value(true, || db.sum3_drop_sum("a"))
+            })
+        }
+    });
+
+    db.wait_for(1);
+
+    db.query_mut(Input).set('b', 2);
+
+    // Check that when we call `sum3_drop_sum` we don't wind up having
+    // to actually re-execute it, because the result of `sum2` winds
+    // up not changing.
+    db.knobs().sum3_drop_sum_should_panic.with_value(true, || {
+        assert_eq!(db.sum3_drop_sum("a"), 22);
+    });
+
+    assert_eq!(thread1.join().unwrap(), 22);
+}
diff --git a/tests/parallel/race.rs b/tests/parallel/race.rs
index b862096c..234aea8e 100644
--- a/tests/parallel/race.rs
+++ b/tests/parallel/race.rs
@@ -25,9 +25,14 @@ fn in_par_get_set_race() {
     });

     // If the 1st thread runs first, you get 111, otherwise you get
-    // 1011.
+    // 1011; if they run concurrently and the 1st thread observes the
+    // cancelation, you get back std::usize::MAX.
     let value1 = thread1.join().unwrap();
-    assert!(value1 == 111 || value1 == 1011, "illegal result {}", value1);
+    assert!(
+        value1 == 111 || value1 == 1011 || value1 == std::usize::MAX,
+        "illegal result {}",
+        value1
+    );

     assert_eq!(thread2.join().unwrap(), 1000);
 }
diff --git a/tests/parallel/setup.rs b/tests/parallel/setup.rs
index 89929864..100c4e05 100644
--- a/tests/parallel/setup.rs
+++ b/tests/parallel/setup.rs
@@ -16,10 +16,26 @@ salsa::query_group! {
             type Sum;
         }

+        /// Invokes `sum`
         fn sum2(key: &'static str) -> usize {
             type Sum2;
         }

+        /// Invokes `sum` but doesn't really care about the result.
+        fn sum2_drop_sum(key: &'static str) -> usize {
+            type Sum2Drop;
+        }
+
+        /// Invokes `sum2`
+        fn sum3(key: &'static str) -> usize {
+            type Sum3;
+        }
+
+        /// Invokes `sum2_drop_sum`
+        fn sum3_drop_sum(key: &'static str) -> usize {
+            type Sum3Drop;
+        }
+
         fn snapshot_me() -> () {
             type SnapshotMe;
         }
@@ -82,6 +98,9 @@ pub(crate) struct KnobsStruct {

     /// Invocations of `sum` will signal this stage prior to exiting.
     pub(crate) sum_signal_on_exit: Cell<usize>,
+
+    /// Invocations of `sum3_drop_sum` will panic unconditionally.
+    pub(crate) sum3_drop_sum_should_panic: Cell<bool>,
 }

 fn sum(db: &impl ParDatabase, key: &'static str) -> usize {
@@ -105,6 +124,17 @@ fn sum(db: &impl ParDatabase, key: &'static str) -> usize {
             std::thread::yield_now();
         }
         log::debug!("cancellation observed");
+    }
+
+    // Check for cancelation and return MAX if so. Note that we check
+    // for cancelation *deterministically* -- but if
+    // `sum_wait_for_cancellation` is set, we will block
+    // beforehand. Deterministic execution is a requirement for valid
+    // salsa user code. It's also important to some tests that `sum`
+    // *attempts* to invoke `is_current_revision_canceled` even if we
+    // know it will not be canceled, because that helps us keep the
+    // accounting up to date.
+    if db.salsa_runtime().is_current_revision_canceled() {
         return std::usize::MAX; // when we are cancelled, we return usize::MAX.
     }

@@ -119,6 +149,22 @@ fn sum2(db: &impl ParDatabase, key: &'static str) -> usize {
     db.sum(key)
 }

+fn sum2_drop_sum(db: &impl ParDatabase, key: &'static str) -> usize {
+    let _ = db.sum(key);
+    22
+}
+
+fn sum3(db: &impl ParDatabase, key: &'static str) -> usize {
+    db.sum2(key)
+}
+
+fn sum3_drop_sum(db: &impl ParDatabase, key: &'static str) -> usize {
+    if db.knobs().sum3_drop_sum_should_panic.get() {
+        panic!("sum3_drop_sum executed")
+    }
+    db.sum2_drop_sum(key)
+}
+
 fn snapshot_me(db: &impl ParDatabase) {
     // this should panic
     db.snapshot();
@@ -176,6 +222,9 @@ salsa::database_storage! {
             fn input() for Input;
             fn sum() for Sum;
             fn sum2() for Sum2;
+            fn sum2_drop_sum() for Sum2Drop;
+            fn sum3() for Sum3;
+            fn sum3_drop_sum() for Sum3Drop;
             fn snapshot_me() for SnapshotMe;
         }
 }
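The new tests steer `sum` and `sum3_drop_sum` through `KnobsStruct` fields (`sum_wait_for_cancellation`, `sum3_drop_sum_should_panic`) that are temporarily set via `with_value`. The `Cell` type those fields use belongs to the existing test harness and is not shown in this diff; the snippet below is only a guess at the general pattern -- a hypothetical `ScopedCell` -- to illustrate overriding a knob for the duration of one closure and then restoring it.

use std::cell::Cell;

// Hypothetical stand-in for the test harness's knob cell; not salsa's code.
#[derive(Default)]
struct ScopedCell<T: Copy + Default> {
    value: Cell<T>,
}

impl<T: Copy + Default> ScopedCell<T> {
    fn get(&self) -> T {
        self.value.get()
    }

    // Install `value` for the duration of `op`, then restore the old value.
    fn with_value<R>(&self, value: T, op: impl FnOnce() -> R) -> R {
        let old = self.value.replace(value);
        let result = op();
        self.value.set(old);
        result
    }
}

fn main() {
    let should_panic = ScopedCell::<bool>::default();
    assert!(!should_panic.get());
    should_panic.with_value(true, || assert!(should_panic.get()));
    assert!(!should_panic.get());
}

A real helper would likely restore the old value in a drop guard (so that a panic inside the closure, as `sum3_drop_sum_should_panic` provokes, still resets the knob), and it also has to be usable from the snapshot threads.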