mirror of
https://github.com/salsa-rs/salsa.git
synced 2025-01-23 05:07:27 +00:00
Merge #265
265: Implement "opinionated cancellation" r=nikomatsakis a=jonas-schievink This implements the design described in RFC https://github.com/salsa-rs/salsa/pull/262. Currently, the `in_par_get_set_cancellation` test is failing. I'm not completely sure what it is trying to test (the comment on the test does not match its behavior), so I couldn't fix it. Co-authored-by: Jonas Schievink <jonasschievink@gmail.com>
This commit is contained in:
commit
e17e77c4e9
13 changed files with 244 additions and 346 deletions
|
@ -447,14 +447,18 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
|
|||
/// deadlock.
|
||||
///
|
||||
/// Before blocking, the thread that is attempting to `set` will
|
||||
/// also set a cancellation flag. In the threads operating on
|
||||
/// snapshots, you can use the [`is_current_revision_canceled`]
|
||||
/// method to check for this flag and bring those operations to a
|
||||
/// close, thus allowing the `set` to succeed. Ignoring this flag
|
||||
/// may lead to "starvation", meaning that the thread attempting
|
||||
/// to `set` has to wait a long, long time. =)
|
||||
/// also set a cancellation flag. This will cause any query
|
||||
/// invocations in other threads to unwind with a `Cancelled`
|
||||
/// sentinel value and eventually let the `set` succeed once all
|
||||
/// threads have unwound past the salsa invocation.
|
||||
///
|
||||
/// [`is_current_revision_canceled`]: struct.Runtime.html#method.is_current_revision_canceled
|
||||
/// If your query implementations are performing expensive
|
||||
/// operations without invoking another query, you can also use
|
||||
/// the `Runtime::unwind_if_cancelled` method to check for an
|
||||
/// ongoing cancellation and bring those operations to a close,
|
||||
/// thus allowing the `set` to succeed. Otherwise, long-running
|
||||
/// computations may lead to "starvation", meaning that the
|
||||
/// thread attempting to `set` has to wait a long, long time. =)
|
||||
#trait_vis fn in_db_mut(self, db: &mut #dyn_db) -> salsa::QueryTableMut<'_, Self>
|
||||
{
|
||||
salsa::plumbing::get_query_table_mut::<#qt>(db)
|
||||
|
|
|
@ -162,6 +162,8 @@ where
|
|||
db: &<Q as QueryDb<'_>>::DynDb,
|
||||
key: &Q::Key,
|
||||
) -> Result<Q::Value, CycleError<DatabaseKeyIndex>> {
|
||||
db.unwind_if_cancelled();
|
||||
|
||||
let slot = self.slot(key);
|
||||
let StampedValue {
|
||||
value,
|
||||
|
|
|
@ -10,6 +10,7 @@ use crate::revision::Revision;
|
|||
use crate::runtime::Runtime;
|
||||
use crate::runtime::RuntimeId;
|
||||
use crate::runtime::StampedValue;
|
||||
use crate::Cancelled;
|
||||
use crate::{
|
||||
CycleError, Database, DatabaseKeyIndex, DiscardIf, DiscardWhat, Event, EventKind, QueryDb,
|
||||
SweepStrategy,
|
||||
|
@ -353,7 +354,12 @@ where
|
|||
},
|
||||
});
|
||||
|
||||
let result = future.wait().unwrap_or_else(|| db.on_propagated_panic());
|
||||
let result = future.wait().unwrap_or_else(|| {
|
||||
// If the other thread panics, we treat this as cancellation: there is no
|
||||
// need to panic ourselves, since the original panic will already invoke
|
||||
// the panic hook and bubble up to the thread boundary (or be caught).
|
||||
Cancelled::throw()
|
||||
});
|
||||
ProbeState::UpToDate(if result.cycle.is_empty() {
|
||||
Ok(result.value)
|
||||
} else {
|
||||
|
@ -541,6 +547,8 @@ where
|
|||
let runtime = db.salsa_runtime();
|
||||
let revision_now = runtime.current_revision();
|
||||
|
||||
db.unwind_if_cancelled();
|
||||
|
||||
debug!(
|
||||
"maybe_changed_since({:?}) called with revision={:?}, revision_now={:?}",
|
||||
self, revision, revision_now,
|
||||
|
@ -574,7 +582,7 @@ where
|
|||
// Release our lock on `self.state`, so other thread can complete.
|
||||
std::mem::drop(state);
|
||||
|
||||
let result = future.wait().unwrap_or_else(|| db.on_propagated_panic());
|
||||
let result = future.wait().unwrap_or_else(|| Cancelled::throw());
|
||||
return !result.cycle.is_empty() || result.value.changed_at > revision;
|
||||
}
|
||||
|
||||
|
|
|
@ -99,6 +99,8 @@ where
|
|||
db: &<Q as QueryDb<'_>>::DynDb,
|
||||
key: &Q::Key,
|
||||
) -> Result<Q::Value, CycleError<DatabaseKeyIndex>> {
|
||||
db.unwind_if_cancelled();
|
||||
|
||||
let slot = self
|
||||
.slot(key)
|
||||
.unwrap_or_else(|| panic!("no value set for {:?}({:?})", Q::default(), key));
|
||||
|
|
|
@ -322,6 +322,8 @@ where
|
|||
db: &<Q as QueryDb<'_>>::DynDb,
|
||||
key: &Q::Key,
|
||||
) -> Result<Q::Value, CycleError<DatabaseKeyIndex>> {
|
||||
db.unwind_if_cancelled();
|
||||
|
||||
let slot = self.intern_index(db, key);
|
||||
let changed_at = slot.interned_at;
|
||||
let index = slot.index;
|
||||
|
|
85
src/lib.rs
85
src/lib.rs
|
@ -34,6 +34,7 @@ use crate::plumbing::QueryStorageOps;
|
|||
pub use crate::revision::Revision;
|
||||
use std::fmt::{self, Debug};
|
||||
use std::hash::Hash;
|
||||
use std::panic::{self, UnwindSafe};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub use crate::durability::Durability;
|
||||
|
@ -54,6 +55,8 @@ pub trait Database: plumbing::DatabaseOps {
|
|||
/// consume are marked as used. You then invoke this method to
|
||||
/// remove other values that were not needed for your main query
|
||||
/// results.
|
||||
///
|
||||
/// This method should not be overridden by `Database` implementors.
|
||||
fn sweep_all(&self, strategy: SweepStrategy) {
|
||||
// Note that we do not acquire the query lock (or any locks)
|
||||
// here. Each table is capable of sweeping itself atomically
|
||||
|
@ -71,18 +74,48 @@ pub trait Database: plumbing::DatabaseOps {
|
|||
#![allow(unused_variables)]
|
||||
}
|
||||
|
||||
/// This function is invoked when a dependent query is being computed by the
|
||||
/// other thread, and that thread panics.
|
||||
fn on_propagated_panic(&self) -> ! {
|
||||
panic!("concurrent salsa query panicked")
|
||||
/// Starts unwinding the stack if the current revision is cancelled.
|
||||
///
|
||||
/// This method can be called by query implementations that perform
|
||||
/// potentially expensive computations, in order to speed up propagation of
|
||||
/// cancellation.
|
||||
///
|
||||
/// Cancellation will automatically be triggered by salsa on any query
|
||||
/// invocation.
|
||||
///
|
||||
/// This method should not be overridden by `Database` implementors. A
|
||||
/// `salsa_event` is emitted when this method is called, so that should be
|
||||
/// used instead.
|
||||
#[inline]
|
||||
fn unwind_if_cancelled(&self) {
|
||||
let runtime = self.salsa_runtime();
|
||||
self.salsa_event(Event {
|
||||
runtime_id: runtime.id(),
|
||||
kind: EventKind::WillCheckCancellation,
|
||||
});
|
||||
|
||||
let current_revision = runtime.current_revision();
|
||||
let pending_revision = runtime.pending_revision();
|
||||
log::debug!(
|
||||
"unwind_if_cancelled: current_revision={:?}, pending_revision={:?}",
|
||||
current_revision,
|
||||
pending_revision
|
||||
);
|
||||
if pending_revision > current_revision {
|
||||
runtime.unwind_cancelled();
|
||||
}
|
||||
}
|
||||
|
||||
/// Gives access to the underlying salsa runtime.
|
||||
///
|
||||
/// This method should not be overridden by `Database` implementors.
|
||||
fn salsa_runtime(&self) -> &Runtime {
|
||||
self.ops_salsa_runtime()
|
||||
}
|
||||
|
||||
/// Gives access to the underlying salsa runtime.
|
||||
///
|
||||
/// This method should not be overridden by `Database` implementors.
|
||||
fn salsa_runtime_mut(&mut self) -> &mut Runtime {
|
||||
self.ops_salsa_runtime_mut()
|
||||
}
|
||||
|
@ -145,6 +178,10 @@ pub enum EventKind {
|
|||
/// The database-key for the affected value. Implements `Debug`.
|
||||
database_key: DatabaseKeyIndex,
|
||||
},
|
||||
|
||||
/// Indicates that `unwind_if_cancelled` was called and salsa will check if
|
||||
/// the current revision has been cancelled.
|
||||
WillCheckCancellation,
|
||||
}
|
||||
|
||||
impl fmt::Debug for EventKind {
|
||||
|
@ -166,6 +203,7 @@ impl fmt::Debug for EventKind {
|
|||
.debug_struct("WillExecute")
|
||||
.field("database_key", database_key)
|
||||
.finish(),
|
||||
EventKind::WillCheckCancellation => fmt.debug_struct("WillCheckCancellation").finish(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -277,11 +315,9 @@ pub trait ParallelDatabase: Database + Send {
|
|||
/// series of queries in parallel and arranging the results. Using
|
||||
/// this method for that purpose ensures that those queries will
|
||||
/// see a consistent view of the database (it is also advisable
|
||||
/// for those queries to use the [`is_current_revision_canceled`]
|
||||
/// for those queries to use the [`Runtime::unwind_if_cancelled`]
|
||||
/// method to check for cancellation).
|
||||
///
|
||||
/// [`is_current_revision_canceled`]: struct.Runtime.html#method.is_current_revision_canceled
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// It is not permitted to create a snapshot from inside of a
|
||||
|
@ -644,6 +680,41 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
/// A panic payload indicating that a salsa revision was cancelled.
|
||||
#[derive(Debug)]
|
||||
#[non_exhaustive]
|
||||
pub struct Cancelled;
|
||||
|
||||
impl Cancelled {
|
||||
fn throw() -> ! {
|
||||
// We use resume and not panic here to avoid running the panic
|
||||
// hook (that is, to avoid collecting and printing backtrace).
|
||||
std::panic::resume_unwind(Box::new(Self));
|
||||
}
|
||||
|
||||
/// Runs `f`, and catches any salsa cancellation.
|
||||
pub fn catch<F, T>(f: F) -> Result<T, Cancelled>
|
||||
where
|
||||
F: FnOnce() -> T + UnwindSafe,
|
||||
{
|
||||
match panic::catch_unwind(f) {
|
||||
Ok(t) => Ok(t),
|
||||
Err(payload) => match payload.downcast() {
|
||||
Ok(cancelled) => Err(*cancelled),
|
||||
Err(payload) => panic::resume_unwind(payload),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Cancelled {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str("cancelled")
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for Cancelled {}
|
||||
|
||||
// Re-export the procedural macros.
|
||||
#[allow(unused_imports)]
|
||||
#[macro_use]
|
||||
|
|
113
src/runtime.rs
113
src/runtime.rs
|
@ -1,6 +1,6 @@
|
|||
use crate::durability::Durability;
|
||||
use crate::plumbing::CycleDetected;
|
||||
use crate::revision::{AtomicRevision, Revision};
|
||||
use crate::{durability::Durability, Cancelled};
|
||||
use crate::{CycleError, Database, DatabaseKeyIndex, Event, EventKind};
|
||||
use log::debug;
|
||||
use parking_lot::lock_api::{RawRwLock, RawRwLockRecursive};
|
||||
|
@ -152,93 +152,14 @@ impl Runtime {
|
|||
|
||||
/// Read current value of the revision counter.
|
||||
#[inline]
|
||||
fn pending_revision(&self) -> Revision {
|
||||
pub(crate) fn pending_revision(&self) -> Revision {
|
||||
self.shared_state.pending_revision.load()
|
||||
}
|
||||
|
||||
/// Check if the current revision is canceled. If this method ever
|
||||
/// returns true, the currently executing query is also marked as
|
||||
/// having an *untracked read* -- this means that, in the next
|
||||
/// revision, we will always recompute its value "as if" some
|
||||
/// input had changed. This means that, if your revision is
|
||||
/// canceled (which indicates that current query results will be
|
||||
/// ignored) your query is free to shortcircuit and return
|
||||
/// whatever it likes.
|
||||
///
|
||||
/// This method is useful for implementing cancellation of queries.
|
||||
/// You can do it in one of two ways, via `Result`s or via unwinding.
|
||||
///
|
||||
/// The `Result` approach looks like this:
|
||||
///
|
||||
/// * Some queries invoke `is_current_revision_canceled` and
|
||||
/// return a special value, like `Err(Canceled)`, if it returns
|
||||
/// `true`.
|
||||
/// * Other queries propagate the special value using `?` operator.
|
||||
/// * API around top-level queries checks if the result is `Ok` or
|
||||
/// `Err(Canceled)`.
|
||||
///
|
||||
/// The `panic` approach works in a similar way:
|
||||
///
|
||||
/// * Some queries invoke `is_current_revision_canceled` and
|
||||
/// panic with a special value, like `Canceled`, if it returns
|
||||
/// true.
|
||||
/// * The implementation of `Database` trait overrides
|
||||
/// `on_propagated_panic` to throw this special value as well.
|
||||
/// This way, panic gets propagated naturally through dependant
|
||||
/// queries, even across the threads.
|
||||
/// * API around top-level queries converts a `panic` into `Result` by
|
||||
/// catching the panic (using either `std::panic::catch_unwind` or
|
||||
/// threads) and downcasting the payload to `Canceled` (re-raising
|
||||
/// panic if downcast fails).
|
||||
///
|
||||
/// Note that salsa is explicitly designed to be panic-safe, so cancellation
|
||||
/// via unwinding is 100% valid approach to cancellation.
|
||||
#[inline]
|
||||
pub fn is_current_revision_canceled(&self) -> bool {
|
||||
let current_revision = self.current_revision();
|
||||
let pending_revision = self.pending_revision();
|
||||
debug!(
|
||||
"is_current_revision_canceled: current_revision={:?}, pending_revision={:?}",
|
||||
current_revision, pending_revision
|
||||
);
|
||||
if pending_revision > current_revision {
|
||||
self.report_untracked_read();
|
||||
true
|
||||
} else {
|
||||
// Subtle: If the current revision is not canceled, we
|
||||
// still report an **anonymous** read, which will bump up
|
||||
// the revision number to be at least the last
|
||||
// non-canceled revision. This is needed to ensure
|
||||
// deterministic reads and avoid salsa-rs/salsa#66. The
|
||||
// specific scenario we are trying to avoid is tested by
|
||||
// `no_back_dating_in_cancellation`; it works like
|
||||
// this. Imagine we have 3 queries, where Query3 invokes
|
||||
// Query2 which invokes Query1. Then:
|
||||
//
|
||||
// - In Revision R1:
|
||||
// - Query1: Observes cancelation and returns sentinel S.
|
||||
// - Recorded inputs: Untracked, because we observed cancelation.
|
||||
// - Query2: Reads Query1 and propagates sentinel S.
|
||||
// - Recorded inputs: Query1, changed-at=R1
|
||||
// - Query3: Reads Query2 and propagates sentinel S. (Inputs = Query2, ChangedAt R1)
|
||||
// - Recorded inputs: Query2, changed-at=R1
|
||||
// - In Revision R2:
|
||||
// - Query1: Observes no cancelation. All of its inputs last changed in R0,
|
||||
// so it returns a valid value with "changed at" of R0.
|
||||
// - Recorded inputs: ..., changed-at=R0
|
||||
// - Query2: Recomputes its value and returns correct result.
|
||||
// - Recorded inputs: Query1, changed-at=R0 <-- key problem!
|
||||
// - Query3: sees that Query2's result last changed in R0, so it thinks it
|
||||
// can re-use its value from R1 (which is the sentinel value).
|
||||
//
|
||||
// The anonymous read here prevents that scenario: Query1
|
||||
// winds up with a changed-at setting of R2, which is the
|
||||
// "pending revision", and hence Query2 and Query3
|
||||
// are recomputed.
|
||||
assert_eq!(pending_revision, current_revision);
|
||||
self.report_anon_read(pending_revision);
|
||||
false
|
||||
}
|
||||
#[cold]
|
||||
pub(crate) fn unwind_cancelled(&self) {
|
||||
self.report_untracked_read();
|
||||
Cancelled::throw();
|
||||
}
|
||||
|
||||
/// Acquires the **global query write lock** (ensuring that no queries are
|
||||
|
@ -247,7 +168,7 @@ impl Runtime {
|
|||
///
|
||||
/// While we wait to acquire the global query write lock, this method will
|
||||
/// also increment `pending_revision_increments`, thus signalling to queries
|
||||
/// that their results are "canceled" and they should abort as expeditiously
|
||||
/// that their results are "cancelled" and they should abort as expeditiously
|
||||
/// as possible.
|
||||
///
|
||||
/// The `op` closure should actually perform the writes needed. It is given
|
||||
|
@ -274,7 +195,7 @@ impl Runtime {
|
|||
}
|
||||
|
||||
// Set the `pending_revision` field so that people
|
||||
// know current revision is canceled.
|
||||
// know current revision is cancelled.
|
||||
let current_revision = self.shared_state.pending_revision.fetch_then_increment();
|
||||
|
||||
// To modify the revision, we need the lock.
|
||||
|
@ -381,18 +302,6 @@ impl Runtime {
|
|||
self.local_state.report_synthetic_read(durability);
|
||||
}
|
||||
|
||||
/// An "anonymous" read is a read that doesn't come from executing
|
||||
/// a query, but from some other internal operation. It just
|
||||
/// modifies the "changed at" to be at least the given revision.
|
||||
/// (It also does not disqualify a query from being considered
|
||||
/// constant, since it is used for queries that don't give back
|
||||
/// actual *data*.)
|
||||
///
|
||||
/// This is used when queries check if they have been canceled.
|
||||
fn report_anon_read(&self, revision: Revision) {
|
||||
self.local_state.report_anon_read(revision)
|
||||
}
|
||||
|
||||
/// Obviously, this should be user configurable at some point.
|
||||
pub(crate) fn report_unexpected_cycle(
|
||||
&self,
|
||||
|
@ -516,7 +425,7 @@ struct SharedState {
|
|||
|
||||
/// This is typically equal to `revision` -- set to `revision+1`
|
||||
/// when a new revision is pending (which implies that the current
|
||||
/// revision is canceled).
|
||||
/// revision is cancelled).
|
||||
pending_revision: AtomicRevision,
|
||||
|
||||
/// Stores the "last change" revision for values of each duration.
|
||||
|
@ -639,10 +548,6 @@ impl ActiveQuery {
|
|||
fn add_synthetic_read(&mut self, durability: Durability) {
|
||||
self.durability = self.durability.min(durability);
|
||||
}
|
||||
|
||||
fn add_anon_read(&mut self, changed_at: Revision) {
|
||||
self.changed_at = self.changed_at.max(changed_at);
|
||||
}
|
||||
}
|
||||
|
||||
/// A unique identifier for a particular runtime. Each time you create
|
||||
|
|
|
@ -86,12 +86,6 @@ impl LocalState {
|
|||
top_query.add_synthetic_read(durability);
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn report_anon_read(&self, revision: Revision) {
|
||||
if let Some(top_query) = self.query_stack.borrow_mut().last_mut() {
|
||||
top_query.add_anon_read(revision);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::panic::RefUnwindSafe for LocalState {}
|
||||
|
|
|
@ -1,161 +1,104 @@
|
|||
use crate::setup::{CancelationFlag, Canceled, Knobs, ParDatabase, ParDatabaseImpl, WithValue};
|
||||
use salsa::ParallelDatabase;
|
||||
use crate::setup::{CancellationFlag, Knobs, ParDatabase, ParDatabaseImpl, WithValue};
|
||||
use salsa::{Cancelled, ParallelDatabase};
|
||||
|
||||
macro_rules! assert_canceled {
|
||||
($flag:expr, $thread:expr) => {
|
||||
if $flag == CancelationFlag::Panic {
|
||||
match $thread.join() {
|
||||
Ok(value) => panic!("expected cancelation, got {:?}", value),
|
||||
Err(payload) => match payload.downcast::<Canceled>() {
|
||||
Ok(_) => {}
|
||||
Err(payload) => ::std::panic::resume_unwind(payload),
|
||||
},
|
||||
}
|
||||
} else {
|
||||
assert_eq!($thread.join().unwrap(), usize::max_value());
|
||||
macro_rules! assert_cancelled {
|
||||
($thread:expr) => {
|
||||
match $thread.join() {
|
||||
Ok(value) => panic!("expected cancellation, got {:?}", value),
|
||||
Err(payload) => match payload.downcast::<Cancelled>() {
|
||||
Ok(_) => {}
|
||||
Err(payload) => ::std::panic::resume_unwind(payload),
|
||||
},
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// We have to falvors of cancellation: based on unwindig and based on anon
|
||||
/// reads. This checks both,
|
||||
fn check_cancelation(f: impl Fn(CancelationFlag)) {
|
||||
f(CancelationFlag::Panic);
|
||||
f(CancelationFlag::SpecialValue);
|
||||
}
|
||||
|
||||
/// Add test where a call to `sum` is cancelled by a simultaneous
|
||||
/// write. Check that we recompute the result in next revision, even
|
||||
/// though none of the inputs have changed.
|
||||
#[test]
|
||||
fn in_par_get_set_cancellation_immediate() {
|
||||
check_cancelation(|flag| {
|
||||
let mut db = ParDatabaseImpl::default();
|
||||
let mut db = ParDatabaseImpl::default();
|
||||
|
||||
db.set_input('a', 100);
|
||||
db.set_input('b', 010);
|
||||
db.set_input('c', 001);
|
||||
db.set_input('d', 0);
|
||||
db.set_input('a', 100);
|
||||
db.set_input('b', 010);
|
||||
db.set_input('c', 001);
|
||||
db.set_input('d', 0);
|
||||
|
||||
let thread1 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || {
|
||||
// This will not return until it sees cancellation is
|
||||
// signaled.
|
||||
db.knobs().sum_signal_on_entry.with_value(1, || {
|
||||
db.knobs()
|
||||
.sum_wait_for_cancellation
|
||||
.with_value(flag, || db.sum("abc"))
|
||||
})
|
||||
}
|
||||
});
|
||||
let thread1 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || {
|
||||
// This will not return until it sees cancellation is
|
||||
// signaled.
|
||||
db.knobs().sum_signal_on_entry.with_value(1, || {
|
||||
db.knobs()
|
||||
.sum_wait_for_cancellation
|
||||
.with_value(CancellationFlag::Panic, || db.sum("abc"))
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
// Wait until we have entered `sum` in the other thread.
|
||||
db.wait_for(1);
|
||||
// Wait until we have entered `sum` in the other thread.
|
||||
db.wait_for(1);
|
||||
|
||||
// Try to set the input. This will signal cancellation.
|
||||
db.set_input('d', 1000);
|
||||
// Try to set the input. This will signal cancellation.
|
||||
db.set_input('d', 1000);
|
||||
|
||||
// This should re-compute the value (even though no input has changed).
|
||||
let thread2 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || db.sum("abc")
|
||||
});
|
||||
// This should re-compute the value (even though no input has changed).
|
||||
let thread2 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || db.sum("abc")
|
||||
});
|
||||
|
||||
assert_eq!(db.sum("d"), 1000);
|
||||
assert_canceled!(flag, thread1);
|
||||
assert_eq!(thread2.join().unwrap(), 111);
|
||||
})
|
||||
assert_eq!(db.sum("d"), 1000);
|
||||
assert_cancelled!(thread1);
|
||||
assert_eq!(thread2.join().unwrap(), 111);
|
||||
}
|
||||
|
||||
/// Here, we check that `sum`'s cancellation is propagated
|
||||
/// to `sum2` properly.
|
||||
#[test]
|
||||
fn in_par_get_set_cancellation_transitive() {
|
||||
check_cancelation(|flag| {
|
||||
let mut db = ParDatabaseImpl::default();
|
||||
let mut db = ParDatabaseImpl::default();
|
||||
|
||||
db.set_input('a', 100);
|
||||
db.set_input('b', 010);
|
||||
db.set_input('c', 001);
|
||||
db.set_input('d', 0);
|
||||
db.set_input('a', 100);
|
||||
db.set_input('b', 010);
|
||||
db.set_input('c', 001);
|
||||
db.set_input('d', 0);
|
||||
|
||||
let thread1 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || {
|
||||
// This will not return until it sees cancellation is
|
||||
// signaled.
|
||||
db.knobs().sum_signal_on_entry.with_value(1, || {
|
||||
db.knobs()
|
||||
.sum_wait_for_cancellation
|
||||
.with_value(flag, || db.sum2("abc"))
|
||||
})
|
||||
}
|
||||
});
|
||||
let thread1 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || {
|
||||
// This will not return until it sees cancellation is
|
||||
// signaled.
|
||||
db.knobs().sum_signal_on_entry.with_value(1, || {
|
||||
db.knobs()
|
||||
.sum_wait_for_cancellation
|
||||
.with_value(CancellationFlag::Panic, || db.sum2("abc"))
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
// Wait until we have entered `sum` in the other thread.
|
||||
db.wait_for(1);
|
||||
// Wait until we have entered `sum` in the other thread.
|
||||
db.wait_for(1);
|
||||
|
||||
// Try to set the input. This will signal cancellation.
|
||||
db.set_input('d', 1000);
|
||||
// Try to set the input. This will signal cancellation.
|
||||
db.set_input('d', 1000);
|
||||
|
||||
// This should re-compute the value (even though no input has changed).
|
||||
let thread2 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || db.sum2("abc")
|
||||
});
|
||||
// This should re-compute the value (even though no input has changed).
|
||||
let thread2 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || db.sum2("abc")
|
||||
});
|
||||
|
||||
assert_eq!(db.sum2("d"), 1000);
|
||||
assert_canceled!(flag, thread1);
|
||||
assert_eq!(thread2.join().unwrap(), 111);
|
||||
})
|
||||
assert_eq!(db.sum2("d"), 1000);
|
||||
assert_cancelled!(thread1);
|
||||
assert_eq!(thread2.join().unwrap(), 111);
|
||||
}
|
||||
|
||||
/// https://github.com/salsa-rs/salsa/issues/66
|
||||
#[test]
|
||||
fn no_back_dating_in_cancellation() {
|
||||
check_cancelation(|flag| {
|
||||
let mut db = ParDatabaseImpl::default();
|
||||
|
||||
db.set_input('a', 1);
|
||||
let thread1 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || {
|
||||
// Here we compute a long-chain of queries,
|
||||
// but the last one gets cancelled.
|
||||
db.knobs().sum_signal_on_entry.with_value(1, || {
|
||||
db.knobs()
|
||||
.sum_wait_for_cancellation
|
||||
.with_value(flag, || db.sum3("a"))
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
db.wait_for(1);
|
||||
|
||||
// Set unrelated input to bump revision
|
||||
db.set_input('b', 2);
|
||||
|
||||
// Here we should recompuet the whole chain again, clearing the cancellation
|
||||
// state. If we get `usize::max()` here, it is a bug!
|
||||
assert_eq!(db.sum3("a"), 1);
|
||||
|
||||
assert_canceled!(flag, thread1);
|
||||
|
||||
db.set_input('a', 3);
|
||||
db.set_input('a', 4);
|
||||
assert_eq!(db.sum3("ab"), 6);
|
||||
})
|
||||
}
|
||||
|
||||
/// Here, we compute `sum3_drop_sum` and -- in the process -- observe
|
||||
/// a cancellation. As a result, we have to recompute `sum` when we
|
||||
/// reinvoke `sum3_drop_sum` and we have to re-execute
|
||||
/// `sum2_drop_sum`. But the result of `sum2_drop_sum` doesn't
|
||||
/// change, so we don't have to re-execute `sum3_drop_sum`.
|
||||
/// This only works with SpecialValue cancellation strategy.
|
||||
#[test]
|
||||
fn transitive_cancellation() {
|
||||
let mut db = ParDatabaseImpl::default();
|
||||
|
||||
db.set_input('a', 1);
|
||||
|
@ -167,21 +110,23 @@ fn transitive_cancellation() {
|
|||
db.knobs().sum_signal_on_entry.with_value(1, || {
|
||||
db.knobs()
|
||||
.sum_wait_for_cancellation
|
||||
.with_value(CancelationFlag::SpecialValue, || db.sum3_drop_sum("a"))
|
||||
.with_value(CancellationFlag::Panic, || db.sum3("a"))
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
db.wait_for(1);
|
||||
|
||||
// Set unrelated input to bump revision
|
||||
db.set_input('b', 2);
|
||||
|
||||
// Check that when we call `sum3_drop_sum` we don't wind up having
|
||||
// to actually re-execute it, because the result of `sum2` winds
|
||||
// up not changing.
|
||||
db.knobs().sum3_drop_sum_should_panic.with_value(true, || {
|
||||
assert_eq!(db.sum3_drop_sum("a"), 22);
|
||||
});
|
||||
// Here we should recompuet the whole chain again, clearing the cancellation
|
||||
// state. If we get `usize::max()` here, it is a bug!
|
||||
assert_eq!(db.sum3("a"), 1);
|
||||
|
||||
assert_eq!(thread1.join().unwrap(), 22);
|
||||
assert_cancelled!(thread1);
|
||||
|
||||
db.set_input('a', 3);
|
||||
db.set_input('a', 4);
|
||||
assert_eq!(db.sum3("ab"), 6);
|
||||
}
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
use crate::setup::{ParDatabase, ParDatabaseImpl};
|
||||
use crate::signal::Signal;
|
||||
use salsa::{Database, ParallelDatabase};
|
||||
use std::sync::Arc;
|
||||
use std::{
|
||||
panic::{catch_unwind, AssertUnwindSafe},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
/// Add test where a call to `sum` is cancelled by a simultaneous
|
||||
/// write. Check that we recompute the result in next revision, even
|
||||
|
@ -20,25 +23,17 @@ fn in_par_get_set_cancellation() {
|
|||
move || {
|
||||
// Check that cancellation flag is not yet set, because
|
||||
// `set` cannot have been called yet.
|
||||
assert!(!db.salsa_runtime().is_current_revision_canceled());
|
||||
catch_unwind(AssertUnwindSafe(|| db.unwind_if_cancelled())).unwrap();
|
||||
|
||||
// Signal other thread to proceed.
|
||||
signal.signal(1);
|
||||
|
||||
// Wait for other thread to signal cancellation
|
||||
while !db.salsa_runtime().is_current_revision_canceled() {
|
||||
catch_unwind(AssertUnwindSafe(|| loop {
|
||||
db.unwind_if_cancelled();
|
||||
std::thread::yield_now();
|
||||
}
|
||||
|
||||
// Since we have not yet released revision lock, we should
|
||||
// see 1 here.
|
||||
let v = db.input('a');
|
||||
|
||||
// Since this is a snapshotted database, we are in a consistent
|
||||
// revision, so this must yield the same value.
|
||||
let w = db.input('a');
|
||||
|
||||
(v, w)
|
||||
}))
|
||||
.unwrap_err();
|
||||
}
|
||||
});
|
||||
|
||||
|
@ -56,9 +51,7 @@ fn in_par_get_set_cancellation() {
|
|||
}
|
||||
});
|
||||
|
||||
let (a, b) = thread1.join().unwrap();
|
||||
assert_eq!(a, 1);
|
||||
assert_eq!(b, 1);
|
||||
thread1.join().unwrap();
|
||||
|
||||
let c = thread2.join().unwrap();
|
||||
assert_eq!(c, 2);
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
use std::panic::AssertUnwindSafe;
|
||||
|
||||
use crate::setup::{ParDatabase, ParDatabaseImpl};
|
||||
use salsa::ParallelDatabase;
|
||||
use salsa::{Cancelled, ParallelDatabase};
|
||||
|
||||
/// Test where a read and a set are racing with one another.
|
||||
/// Should be atomic.
|
||||
|
@ -14,8 +16,10 @@ fn in_par_get_set_race() {
|
|||
let thread1 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || {
|
||||
let v = db.sum("abc");
|
||||
v
|
||||
Cancelled::catch(AssertUnwindSafe(|| {
|
||||
let v = db.sum("abc");
|
||||
v
|
||||
}))
|
||||
}
|
||||
});
|
||||
|
||||
|
@ -26,13 +30,13 @@ fn in_par_get_set_race() {
|
|||
|
||||
// If the 1st thread runs first, you get 111, otherwise you get
|
||||
// 1011; if they run concurrently and the 1st thread observes the
|
||||
// cancelation, you get back usize::max.
|
||||
let value1 = thread1.join().unwrap();
|
||||
assert!(
|
||||
value1 == 111 || value1 == 1011 || value1 == std::usize::MAX,
|
||||
"illegal result {}",
|
||||
value1
|
||||
);
|
||||
// cancellation, it'll unwind.
|
||||
let result1 = thread1.join().unwrap();
|
||||
if let Ok(value1) = result1 {
|
||||
assert!(value1 == 111 || value1 == 1011, "illegal result {}", value1);
|
||||
}
|
||||
|
||||
// thread2 can not observe a cancellation because it performs a
|
||||
// database write before running any other queries.
|
||||
assert_eq!(thread2.join().unwrap(), 1000);
|
||||
}
|
||||
|
|
|
@ -2,8 +2,11 @@ use crate::signal::Signal;
|
|||
use salsa::Database;
|
||||
use salsa::ParallelDatabase;
|
||||
use salsa::Snapshot;
|
||||
use std::cell::Cell;
|
||||
use std::sync::Arc;
|
||||
use std::{
|
||||
cell::Cell,
|
||||
panic::{catch_unwind, resume_unwind, AssertUnwindSafe},
|
||||
};
|
||||
|
||||
#[salsa::query_group(Par)]
|
||||
pub(crate) trait ParDatabase: Knobs {
|
||||
|
@ -25,16 +28,6 @@ pub(crate) trait ParDatabase: Knobs {
|
|||
fn sum3_drop_sum(&self, key: &'static str) -> usize;
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub(crate) struct Canceled;
|
||||
|
||||
impl Canceled {
|
||||
fn throw() -> ! {
|
||||
// Don't print backtrace
|
||||
std::panic::resume_unwind(Box::new(Canceled));
|
||||
}
|
||||
}
|
||||
|
||||
/// Various "knobs" and utilities used by tests to force
|
||||
/// a certain behavior.
|
||||
pub(crate) trait Knobs {
|
||||
|
@ -53,24 +46,26 @@ impl<T> WithValue<T> for Cell<T> {
|
|||
fn with_value<R>(&self, value: T, closure: impl FnOnce() -> R) -> R {
|
||||
let old_value = self.replace(value);
|
||||
|
||||
let result = closure();
|
||||
let result = catch_unwind(AssertUnwindSafe(|| closure()));
|
||||
|
||||
self.set(old_value);
|
||||
|
||||
result
|
||||
match result {
|
||||
Ok(r) => r,
|
||||
Err(payload) => resume_unwind(payload),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) enum CancelationFlag {
|
||||
pub(crate) enum CancellationFlag {
|
||||
Down,
|
||||
Panic,
|
||||
SpecialValue,
|
||||
}
|
||||
|
||||
impl Default for CancelationFlag {
|
||||
fn default() -> CancelationFlag {
|
||||
CancelationFlag::Down
|
||||
impl Default for CancellationFlag {
|
||||
fn default() -> CancellationFlag {
|
||||
CancellationFlag::Down
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -97,7 +92,7 @@ pub(crate) struct KnobsStruct {
|
|||
|
||||
/// If true, invocations of `sum` will wait for cancellation before
|
||||
/// they exit.
|
||||
pub(crate) sum_wait_for_cancellation: Cell<CancelationFlag>,
|
||||
pub(crate) sum_wait_for_cancellation: Cell<CancellationFlag>,
|
||||
|
||||
/// Invocations of `sum` will wait for this stage prior to exiting.
|
||||
pub(crate) sum_wait_for_on_exit: Cell<usize>,
|
||||
|
@ -125,31 +120,16 @@ fn sum(db: &dyn ParDatabase, key: &'static str) -> usize {
|
|||
}
|
||||
|
||||
match db.knobs().sum_wait_for_cancellation.get() {
|
||||
CancelationFlag::Down => (),
|
||||
flag => {
|
||||
CancellationFlag::Down => (),
|
||||
CancellationFlag::Panic => {
|
||||
log::debug!("waiting for cancellation");
|
||||
while !db.salsa_runtime().is_current_revision_canceled() {
|
||||
loop {
|
||||
db.unwind_if_cancelled();
|
||||
std::thread::yield_now();
|
||||
}
|
||||
log::debug!("observed cancelation");
|
||||
if flag == CancelationFlag::Panic {
|
||||
Canceled::throw();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for cancelation and return MAX if so. Note that we check
|
||||
// for cancelation *deterministically* -- but if
|
||||
// `sum_wait_for_cancellation` is set, we will block
|
||||
// beforehand. Deterministic execution is a requirement for valid
|
||||
// salsa user code. It's also important to some tests that `sum`
|
||||
// *attempts* to invoke `is_current_revision_canceled` even if we
|
||||
// know it will not be canceled, because that helps us keep the
|
||||
// accounting up to date.
|
||||
if db.salsa_runtime().is_current_revision_canceled() {
|
||||
return std::usize::MAX; // when we are cancelled, we return usize::MAX.
|
||||
}
|
||||
|
||||
db.wait_for(db.knobs().sum_wait_for_on_exit.get());
|
||||
|
||||
db.signal(db.knobs().sum_signal_on_exit.get());
|
||||
|
@ -194,10 +174,6 @@ impl Database for ParDatabaseImpl {
|
|||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn on_propagated_panic(&self) -> ! {
|
||||
Canceled::throw()
|
||||
}
|
||||
}
|
||||
|
||||
impl ParallelDatabase for ParDatabaseImpl {
|
||||
|
|
|
@ -1,37 +1,31 @@
|
|||
use rand::seq::SliceRandom;
|
||||
use rand::Rng;
|
||||
|
||||
use salsa::Database;
|
||||
use salsa::ParallelDatabase;
|
||||
use salsa::Snapshot;
|
||||
use salsa::SweepStrategy;
|
||||
use salsa::{Cancelled, Database};
|
||||
|
||||
// Number of operations a reader performs
|
||||
const N_MUTATOR_OPS: usize = 100;
|
||||
const N_READER_OPS: usize = 100;
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
struct Canceled;
|
||||
type Cancelable<T> = Result<T, Canceled>;
|
||||
|
||||
#[salsa::query_group(Stress)]
|
||||
trait StressDatabase: salsa::Database {
|
||||
#[salsa::input]
|
||||
fn a(&self, key: usize) -> usize;
|
||||
|
||||
fn b(&self, key: usize) -> Cancelable<usize>;
|
||||
fn b(&self, key: usize) -> usize;
|
||||
|
||||
fn c(&self, key: usize) -> Cancelable<usize>;
|
||||
fn c(&self, key: usize) -> usize;
|
||||
}
|
||||
|
||||
fn b(db: &dyn StressDatabase, key: usize) -> Cancelable<usize> {
|
||||
if db.salsa_runtime().is_current_revision_canceled() {
|
||||
return Err(Canceled);
|
||||
}
|
||||
Ok(db.a(key))
|
||||
fn b(db: &dyn StressDatabase, key: usize) -> usize {
|
||||
db.unwind_if_cancelled();
|
||||
db.a(key)
|
||||
}
|
||||
|
||||
fn c(db: &dyn StressDatabase, key: usize) -> Cancelable<usize> {
|
||||
fn c(db: &dyn StressDatabase, key: usize) -> usize {
|
||||
db.b(key)
|
||||
}
|
||||
|
||||
|
@ -127,9 +121,7 @@ impl rand::distributions::Distribution<ReadOp> for rand::distributions::Standard
|
|||
fn db_reader_thread(db: &StressDatabaseImpl, ops: Vec<ReadOp>, check_cancellation: bool) {
|
||||
for op in ops {
|
||||
if check_cancellation {
|
||||
if db.salsa_runtime().is_current_revision_canceled() {
|
||||
return;
|
||||
}
|
||||
db.unwind_if_cancelled();
|
||||
}
|
||||
op.execute(db);
|
||||
}
|
||||
|
@ -199,12 +191,12 @@ fn stress_test() {
|
|||
check_cancellation,
|
||||
} => all_threads.push(std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || db_reader_thread(&db, ops, check_cancellation)
|
||||
move || Cancelled::catch(|| db_reader_thread(&db, ops, check_cancellation))
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
for thread in all_threads {
|
||||
thread.join().unwrap();
|
||||
thread.join().unwrap().ok();
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue