Move unwind_if_cancelled to Database

This commit is contained in:
Jonas Schievink 2021-05-25 15:34:57 +02:00
parent 1fb660c33e
commit 458266e1cd
9 changed files with 52 additions and 53 deletions

View file

@ -162,7 +162,7 @@ where
db: &<Q as QueryDb<'_>>::DynDb,
key: &Q::Key,
) -> Result<Q::Value, CycleError<DatabaseKeyIndex>> {
db.salsa_runtime().unwind_if_cancelled();
db.unwind_if_cancelled();
let slot = self.slot(key);
let StampedValue {

View file

@ -547,7 +547,7 @@ where
let runtime = db.salsa_runtime();
let revision_now = runtime.current_revision();
runtime.unwind_if_cancelled();
db.unwind_if_cancelled();
debug!(
"maybe_changed_since({:?}) called with revision={:?}, revision_now={:?}",

View file

@ -99,7 +99,7 @@ where
db: &<Q as QueryDb<'_>>::DynDb,
key: &Q::Key,
) -> Result<Q::Value, CycleError<DatabaseKeyIndex>> {
db.salsa_runtime().unwind_if_cancelled();
db.unwind_if_cancelled();
let slot = self
.slot(key)

View file

@ -322,7 +322,7 @@ where
db: &<Q as QueryDb<'_>>::DynDb,
key: &Q::Key,
) -> Result<Q::Value, CycleError<DatabaseKeyIndex>> {
db.salsa_runtime().unwind_if_cancelled();
db.unwind_if_cancelled();
let slot = self.intern_index(db, key);
let changed_at = slot.interned_at;

View file

@ -55,6 +55,8 @@ pub trait Database: plumbing::DatabaseOps {
/// consume are marked as used. You then invoke this method to
/// remove other values that were not needed for your main query
/// results.
///
/// This method should not be overridden by `Database` implementors.
fn sweep_all(&self, strategy: SweepStrategy) {
// Note that we do not acquire the query lock (or any locks)
// here. Each table is capable of sweeping itself atomically
@ -72,12 +74,46 @@ pub trait Database: plumbing::DatabaseOps {
#![allow(unused_variables)]
}
/// Starts unwinding the stack if the current revision is cancelled.
///
/// This method can be called by query implementations that perform
/// potentially expensive computations, in order to speed up propagation of
/// cancellation.
///
/// Cancellation will automatically be triggered by salsa on any query
/// invocation.
///
/// This method should not be overridden by `Database` implementors.
#[inline]
fn unwind_if_cancelled(&self) {
let runtime = self.salsa_runtime();
self.salsa_event(Event {
runtime_id: runtime.id(),
kind: EventKind::WillCheckCancellation,
});
let current_revision = runtime.current_revision();
let pending_revision = runtime.pending_revision();
log::debug!(
"unwind_if_cancelled: current_revision={:?}, pending_revision={:?}",
current_revision,
pending_revision
);
if pending_revision > current_revision {
runtime.unwind_cancelled();
}
}
/// Gives access to the underlying salsa runtime.
///
/// This method should not be overridden by `Database` implementors.
fn salsa_runtime(&self) -> &Runtime {
self.ops_salsa_runtime()
}
/// Gives access to the underlying salsa runtime.
///
/// This method should not be overridden by `Database` implementors.
fn salsa_runtime_mut(&mut self) -> &mut Runtime {
self.ops_salsa_runtime_mut()
}
@ -140,6 +176,10 @@ pub enum EventKind {
/// The database-key for the affected value. Implements `Debug`.
database_key: DatabaseKeyIndex,
},
/// Indicates that `unwind_if_cancelled` was called and salsa will check if
/// the current revision has been cancelled.
WillCheckCancellation,
}
impl fmt::Debug for EventKind {
@ -161,6 +201,7 @@ impl fmt::Debug for EventKind {
.debug_struct("WillExecute")
.field("database_key", database_key)
.finish(),
EventKind::WillCheckCancellation => fmt.debug_struct("WillCheckCancellation").finish(),
}
}
}

View file

@ -8,7 +8,6 @@ use parking_lot::{Mutex, RwLock};
use rustc_hash::{FxHashMap, FxHasher};
use smallvec::SmallVec;
use std::hash::{BuildHasherDefault, Hash};
use std::panic::RefUnwindSafe;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
@ -39,8 +38,6 @@ pub struct Runtime {
/// Shared state that is accessible via all runtimes.
shared_state: Arc<SharedState>,
on_cancellation_check: Option<Box<dyn Fn() + RefUnwindSafe + Send>>,
}
impl Default for Runtime {
@ -50,7 +47,6 @@ impl Default for Runtime {
revision_guard: None,
shared_state: Default::default(),
local_state: Default::default(),
on_cancellation_check: None,
}
}
}
@ -89,7 +85,6 @@ impl Runtime {
revision_guard: Some(revision_guard),
shared_state: self.shared_state.clone(),
local_state: Default::default(),
on_cancellation_check: None,
}
}
@ -157,50 +152,16 @@ impl Runtime {
/// Read current value of the revision counter.
#[inline]
fn pending_revision(&self) -> Revision {
pub(crate) fn pending_revision(&self) -> Revision {
self.shared_state.pending_revision.load()
}
/// Starts unwinding the stack if the current revision is cancelled.
///
/// This method can be called by query implementations that perform
/// potentially expensive computations, in order to speed up propagation of
/// cancellation.
///
/// Cancellation will automatically be triggered by salsa on any query
/// invocation.
#[inline]
pub fn unwind_if_cancelled(&self) {
if let Some(callback) = &self.on_cancellation_check {
callback();
}
let current_revision = self.current_revision();
let pending_revision = self.pending_revision();
debug!(
"unwind_if_cancelled: current_revision={:?}, pending_revision={:?}",
current_revision, pending_revision
);
if pending_revision > current_revision {
self.unwind_cancelled();
}
}
#[cold]
fn unwind_cancelled(&self) {
pub(crate) fn unwind_cancelled(&self) {
self.report_untracked_read();
Cancelled::throw();
}
/// Registers a callback to be invoked every time [`Runtime::unwind_if_cancelled`] is called
/// (either automatically by salsa, or manually by user code).
pub fn set_cancellation_check_callback<F>(&mut self, callback: F)
where
F: Fn() + Send + RefUnwindSafe + 'static,
{
self.on_cancellation_check = Some(Box::new(callback));
}
/// Acquires the **global query write lock** (ensuring that no queries are
/// executing) and then increments the current revision counter; invokes
/// `op` with the global query write lock still held.

View file

@ -23,17 +23,14 @@ fn in_par_get_set_cancellation() {
move || {
// Check that cancellation flag is not yet set, because
// `set` cannot have been called yet.
catch_unwind(AssertUnwindSafe(|| {
db.salsa_runtime().unwind_if_cancelled()
}))
.unwrap();
catch_unwind(AssertUnwindSafe(|| db.unwind_if_cancelled())).unwrap();
// Signal other thread to proceed.
signal.signal(1);
// Wait for other thread to signal cancellation
catch_unwind(AssertUnwindSafe(|| loop {
db.salsa_runtime().unwind_if_cancelled();
db.unwind_if_cancelled();
std::thread::yield_now();
}))
.unwrap_err();

View file

@ -124,7 +124,7 @@ fn sum(db: &dyn ParDatabase, key: &'static str) -> usize {
CancellationFlag::Panic => {
log::debug!("waiting for cancellation");
loop {
db.salsa_runtime().unwind_if_cancelled();
db.unwind_if_cancelled();
std::thread::yield_now();
}
}

View file

@ -21,7 +21,7 @@ trait StressDatabase: salsa::Database {
}
fn b(db: &dyn StressDatabase, key: usize) -> usize {
db.salsa_runtime().unwind_if_cancelled();
db.unwind_if_cancelled();
db.a(key)
}
@ -121,7 +121,7 @@ impl rand::distributions::Distribution<ReadOp> for rand::distributions::Standard
fn db_reader_thread(db: &StressDatabaseImpl, ops: Vec<ReadOp>, check_cancellation: bool) {
for op in ops {
if check_cancellation {
db.salsa_runtime().unwind_if_cancelled();
db.unwind_if_cancelled();
}
op.execute(db);
}