From 458266e1cd5830a4a164ef988015161f9d25c3e2 Mon Sep 17 00:00:00 2001 From: Jonas Schievink Date: Tue, 25 May 2021 15:34:57 +0200 Subject: [PATCH] Move `unwind_if_cancelled` to `Database` --- src/derived.rs | 2 +- src/derived/slot.rs | 2 +- src/input.rs | 2 +- src/interned.rs | 2 +- src/lib.rs | 41 ++++++++++++++++++++++++++++++++++++++ src/runtime.rs | 43 ++-------------------------------------- tests/parallel/frozen.rs | 7 ++----- tests/parallel/setup.rs | 2 +- tests/parallel/stress.rs | 4 ++-- 9 files changed, 52 insertions(+), 53 deletions(-) diff --git a/src/derived.rs b/src/derived.rs index 14b7aa69..64437b5a 100644 --- a/src/derived.rs +++ b/src/derived.rs @@ -162,7 +162,7 @@ where db: &>::DynDb, key: &Q::Key, ) -> Result> { - db.salsa_runtime().unwind_if_cancelled(); + db.unwind_if_cancelled(); let slot = self.slot(key); let StampedValue { diff --git a/src/derived/slot.rs b/src/derived/slot.rs index b080e927..f0f153f0 100644 --- a/src/derived/slot.rs +++ b/src/derived/slot.rs @@ -547,7 +547,7 @@ where let runtime = db.salsa_runtime(); let revision_now = runtime.current_revision(); - runtime.unwind_if_cancelled(); + db.unwind_if_cancelled(); debug!( "maybe_changed_since({:?}) called with revision={:?}, revision_now={:?}", diff --git a/src/input.rs b/src/input.rs index 51995709..67c2a21a 100644 --- a/src/input.rs +++ b/src/input.rs @@ -99,7 +99,7 @@ where db: &>::DynDb, key: &Q::Key, ) -> Result> { - db.salsa_runtime().unwind_if_cancelled(); + db.unwind_if_cancelled(); let slot = self .slot(key) diff --git a/src/interned.rs b/src/interned.rs index 073ec72f..c8eccabd 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -322,7 +322,7 @@ where db: &>::DynDb, key: &Q::Key, ) -> Result> { - db.salsa_runtime().unwind_if_cancelled(); + db.unwind_if_cancelled(); let slot = self.intern_index(db, key); let changed_at = slot.interned_at; diff --git a/src/lib.rs b/src/lib.rs index bd5cde47..f3aeca7c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -55,6 +55,8 @@ pub trait Database: plumbing::DatabaseOps { /// consume are marked as used. You then invoke this method to /// remove other values that were not needed for your main query /// results. + /// + /// This method should not be overridden by `Database` implementors. fn sweep_all(&self, strategy: SweepStrategy) { // Note that we do not acquire the query lock (or any locks) // here. Each table is capable of sweeping itself atomically @@ -72,12 +74,46 @@ pub trait Database: plumbing::DatabaseOps { #![allow(unused_variables)] } + /// Starts unwinding the stack if the current revision is cancelled. + /// + /// This method can be called by query implementations that perform + /// potentially expensive computations, in order to speed up propagation of + /// cancellation. + /// + /// Cancellation will automatically be triggered by salsa on any query + /// invocation. + /// + /// This method should not be overridden by `Database` implementors. + #[inline] + fn unwind_if_cancelled(&self) { + let runtime = self.salsa_runtime(); + self.salsa_event(Event { + runtime_id: runtime.id(), + kind: EventKind::WillCheckCancellation, + }); + + let current_revision = runtime.current_revision(); + let pending_revision = runtime.pending_revision(); + log::debug!( + "unwind_if_cancelled: current_revision={:?}, pending_revision={:?}", + current_revision, + pending_revision + ); + if pending_revision > current_revision { + runtime.unwind_cancelled(); + } + } + /// Gives access to the underlying salsa runtime. + /// + /// This method should not be overridden by `Database` implementors. fn salsa_runtime(&self) -> &Runtime { self.ops_salsa_runtime() } /// Gives access to the underlying salsa runtime. + /// + /// This method should not be overridden by `Database` implementors. fn salsa_runtime_mut(&mut self) -> &mut Runtime { self.ops_salsa_runtime_mut() } @@ -140,6 +176,10 @@ pub enum EventKind { /// The database-key for the affected value. Implements `Debug`. database_key: DatabaseKeyIndex, }, + + /// Indicates that `unwind_if_cancelled` was called and salsa will check if + /// the current revision has been cancelled. + WillCheckCancellation, } impl fmt::Debug for EventKind { @@ -161,6 +201,7 @@ impl fmt::Debug for EventKind { .debug_struct("WillExecute") .field("database_key", database_key) .finish(), + EventKind::WillCheckCancellation => fmt.debug_struct("WillCheckCancellation").finish(), } } } diff --git a/src/runtime.rs b/src/runtime.rs index 23b14622..f9d6e4ab 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -8,7 +8,6 @@ use parking_lot::{Mutex, RwLock}; use rustc_hash::{FxHashMap, FxHasher}; use smallvec::SmallVec; use std::hash::{BuildHasherDefault, Hash}; -use std::panic::RefUnwindSafe; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; @@ -39,8 +38,6 @@ pub struct Runtime { /// Shared state that is accessible via all runtimes. shared_state: Arc, - - on_cancellation_check: Option>, } impl Default for Runtime { @@ -50,7 +47,6 @@ impl Default for Runtime { revision_guard: None, shared_state: Default::default(), local_state: Default::default(), - on_cancellation_check: None, } } } @@ -89,7 +85,6 @@ impl Runtime { revision_guard: Some(revision_guard), shared_state: self.shared_state.clone(), local_state: Default::default(), - on_cancellation_check: None, } } @@ -157,50 +152,16 @@ impl Runtime { /// Read current value of the revision counter. #[inline] - fn pending_revision(&self) -> Revision { + pub(crate) fn pending_revision(&self) -> Revision { self.shared_state.pending_revision.load() } - /// Starts unwinding the stack if the current revision is cancelled. - /// - /// This method can be called by query implementations that perform - /// potentially expensive computations, in order to speed up propagation of - /// cancellation. - /// - /// Cancellation will automatically be triggered by salsa on any query - /// invocation. - #[inline] - pub fn unwind_if_cancelled(&self) { - if let Some(callback) = &self.on_cancellation_check { - callback(); - } - - let current_revision = self.current_revision(); - let pending_revision = self.pending_revision(); - debug!( - "unwind_if_cancelled: current_revision={:?}, pending_revision={:?}", - current_revision, pending_revision - ); - if pending_revision > current_revision { - self.unwind_cancelled(); - } - } - #[cold] - fn unwind_cancelled(&self) { + pub(crate) fn unwind_cancelled(&self) { self.report_untracked_read(); Cancelled::throw(); } - /// Registers a callback to be invoked every time [`Runtime::unwind_if_cancelled`] is called - /// (either automatically by salsa, or manually by user code). - pub fn set_cancellation_check_callback(&mut self, callback: F) - where - F: Fn() + Send + RefUnwindSafe + 'static, - { - self.on_cancellation_check = Some(Box::new(callback)); - } - /// Acquires the **global query write lock** (ensuring that no queries are /// executing) and then increments the current revision counter; invokes /// `op` with the global query write lock still held. diff --git a/tests/parallel/frozen.rs b/tests/parallel/frozen.rs index 677ca835..2e341a90 100644 --- a/tests/parallel/frozen.rs +++ b/tests/parallel/frozen.rs @@ -23,17 +23,14 @@ fn in_par_get_set_cancellation() { move || { // Check that cancellation flag is not yet set, because // `set` cannot have been called yet. - catch_unwind(AssertUnwindSafe(|| { - db.salsa_runtime().unwind_if_cancelled() - })) - .unwrap(); + catch_unwind(AssertUnwindSafe(|| db.unwind_if_cancelled())).unwrap(); // Signal other thread to proceed. signal.signal(1); // Wait for other thread to signal cancellation catch_unwind(AssertUnwindSafe(|| loop { - db.salsa_runtime().unwind_if_cancelled(); + db.unwind_if_cancelled(); std::thread::yield_now(); })) .unwrap_err(); diff --git a/tests/parallel/setup.rs b/tests/parallel/setup.rs index 78ef1ff0..fc39b77a 100644 --- a/tests/parallel/setup.rs +++ b/tests/parallel/setup.rs @@ -124,7 +124,7 @@ fn sum(db: &dyn ParDatabase, key: &'static str) -> usize { CancellationFlag::Panic => { log::debug!("waiting for cancellation"); loop { - db.salsa_runtime().unwind_if_cancelled(); + db.unwind_if_cancelled(); std::thread::yield_now(); } } diff --git a/tests/parallel/stress.rs b/tests/parallel/stress.rs index fea9e3a6..49d74f79 100644 --- a/tests/parallel/stress.rs +++ b/tests/parallel/stress.rs @@ -21,7 +21,7 @@ trait StressDatabase: salsa::Database { } fn b(db: &dyn StressDatabase, key: usize) -> usize { - db.salsa_runtime().unwind_if_cancelled(); + db.unwind_if_cancelled(); db.a(key) } @@ -121,7 +121,7 @@ impl rand::distributions::Distribution for rand::distributions::Standard fn db_reader_thread(db: &StressDatabaseImpl, ops: Vec, check_cancellation: bool) { for op in ops { if check_cancellation { - db.salsa_runtime().unwind_if_cancelled(); + db.unwind_if_cancelled(); } op.execute(db); }