265: Implement "opinionated cancellation" r=nikomatsakis a=jonas-schievink

This implements the design described in RFC https://github.com/salsa-rs/salsa/pull/262.

Currently, the `in_par_get_set_cancellation` test is failing. I'm not completely sure what it is trying to test (the comment on the test does not match its behavior), so I couldn't fix it.

Co-authored-by: Jonas Schievink <jonasschievink@gmail.com>
This commit is contained in:
bors[bot] 2021-05-27 12:30:18 +00:00 committed by GitHub
commit e17e77c4e9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 244 additions and 346 deletions

View file

@ -447,14 +447,18 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
/// deadlock.
///
/// Before blocking, the thread that is attempting to `set` will
/// also set a cancellation flag. In the threads operating on
/// snapshots, you can use the [`is_current_revision_canceled`]
/// method to check for this flag and bring those operations to a
/// close, thus allowing the `set` to succeed. Ignoring this flag
/// may lead to "starvation", meaning that the thread attempting
/// to `set` has to wait a long, long time. =)
/// also set a cancellation flag. This will cause any query
/// invocations in other threads to unwind with a `Cancelled`
/// sentinel value and eventually let the `set` succeed once all
/// threads have unwound past the salsa invocation.
///
/// [`is_current_revision_canceled`]: struct.Runtime.html#method.is_current_revision_canceled
/// If your query implementations are performing expensive
/// operations without invoking another query, you can also use
/// the `Runtime::unwind_if_cancelled` method to check for an
/// ongoing cancellation and bring those operations to a close,
/// thus allowing the `set` to succeed. Otherwise, long-running
/// computations may lead to "starvation", meaning that the
/// thread attempting to `set` has to wait a long, long time. =)
#trait_vis fn in_db_mut(self, db: &mut #dyn_db) -> salsa::QueryTableMut<'_, Self>
{
salsa::plumbing::get_query_table_mut::<#qt>(db)

View file

@ -162,6 +162,8 @@ where
db: &<Q as QueryDb<'_>>::DynDb,
key: &Q::Key,
) -> Result<Q::Value, CycleError<DatabaseKeyIndex>> {
db.unwind_if_cancelled();
let slot = self.slot(key);
let StampedValue {
value,

View file

@ -10,6 +10,7 @@ use crate::revision::Revision;
use crate::runtime::Runtime;
use crate::runtime::RuntimeId;
use crate::runtime::StampedValue;
use crate::Cancelled;
use crate::{
CycleError, Database, DatabaseKeyIndex, DiscardIf, DiscardWhat, Event, EventKind, QueryDb,
SweepStrategy,
@ -353,7 +354,12 @@ where
},
});
let result = future.wait().unwrap_or_else(|| db.on_propagated_panic());
let result = future.wait().unwrap_or_else(|| {
// If the other thread panics, we treat this as cancellation: there is no
// need to panic ourselves, since the original panic will already invoke
// the panic hook and bubble up to the thread boundary (or be caught).
Cancelled::throw()
});
ProbeState::UpToDate(if result.cycle.is_empty() {
Ok(result.value)
} else {
@ -541,6 +547,8 @@ where
let runtime = db.salsa_runtime();
let revision_now = runtime.current_revision();
db.unwind_if_cancelled();
debug!(
"maybe_changed_since({:?}) called with revision={:?}, revision_now={:?}",
self, revision, revision_now,
@ -574,7 +582,7 @@ where
// Release our lock on `self.state`, so other thread can complete.
std::mem::drop(state);
let result = future.wait().unwrap_or_else(|| db.on_propagated_panic());
let result = future.wait().unwrap_or_else(|| Cancelled::throw());
return !result.cycle.is_empty() || result.value.changed_at > revision;
}

View file

@ -99,6 +99,8 @@ where
db: &<Q as QueryDb<'_>>::DynDb,
key: &Q::Key,
) -> Result<Q::Value, CycleError<DatabaseKeyIndex>> {
db.unwind_if_cancelled();
let slot = self
.slot(key)
.unwrap_or_else(|| panic!("no value set for {:?}({:?})", Q::default(), key));

View file

@ -322,6 +322,8 @@ where
db: &<Q as QueryDb<'_>>::DynDb,
key: &Q::Key,
) -> Result<Q::Value, CycleError<DatabaseKeyIndex>> {
db.unwind_if_cancelled();
let slot = self.intern_index(db, key);
let changed_at = slot.interned_at;
let index = slot.index;

View file

@ -34,6 +34,7 @@ use crate::plumbing::QueryStorageOps;
pub use crate::revision::Revision;
use std::fmt::{self, Debug};
use std::hash::Hash;
use std::panic::{self, UnwindSafe};
use std::sync::Arc;
pub use crate::durability::Durability;
@ -54,6 +55,8 @@ pub trait Database: plumbing::DatabaseOps {
/// consume are marked as used. You then invoke this method to
/// remove other values that were not needed for your main query
/// results.
///
/// This method should not be overridden by `Database` implementors.
fn sweep_all(&self, strategy: SweepStrategy) {
// Note that we do not acquire the query lock (or any locks)
// here. Each table is capable of sweeping itself atomically
@ -71,18 +74,48 @@ pub trait Database: plumbing::DatabaseOps {
#![allow(unused_variables)]
}
/// This function is invoked when a dependent query is being computed by the
/// other thread, and that thread panics.
fn on_propagated_panic(&self) -> ! {
panic!("concurrent salsa query panicked")
/// Starts unwinding the stack if the current revision is cancelled.
///
/// This method can be called by query implementations that perform
/// potentially expensive computations, in order to speed up propagation of
/// cancellation.
///
/// Cancellation will automatically be triggered by salsa on any query
/// invocation.
///
/// This method should not be overridden by `Database` implementors. A
/// `salsa_event` is emitted when this method is called, so that should be
/// used instead.
#[inline]
fn unwind_if_cancelled(&self) {
let runtime = self.salsa_runtime();
self.salsa_event(Event {
runtime_id: runtime.id(),
kind: EventKind::WillCheckCancellation,
});
let current_revision = runtime.current_revision();
let pending_revision = runtime.pending_revision();
log::debug!(
"unwind_if_cancelled: current_revision={:?}, pending_revision={:?}",
current_revision,
pending_revision
);
if pending_revision > current_revision {
runtime.unwind_cancelled();
}
}
/// Gives access to the underlying salsa runtime.
///
/// This method should not be overridden by `Database` implementors.
fn salsa_runtime(&self) -> &Runtime {
self.ops_salsa_runtime()
}
/// Gives access to the underlying salsa runtime.
///
/// This method should not be overridden by `Database` implementors.
fn salsa_runtime_mut(&mut self) -> &mut Runtime {
self.ops_salsa_runtime_mut()
}
@ -145,6 +178,10 @@ pub enum EventKind {
/// The database-key for the affected value. Implements `Debug`.
database_key: DatabaseKeyIndex,
},
/// Indicates that `unwind_if_cancelled` was called and salsa will check if
/// the current revision has been cancelled.
WillCheckCancellation,
}
impl fmt::Debug for EventKind {
@ -166,6 +203,7 @@ impl fmt::Debug for EventKind {
.debug_struct("WillExecute")
.field("database_key", database_key)
.finish(),
EventKind::WillCheckCancellation => fmt.debug_struct("WillCheckCancellation").finish(),
}
}
}
@ -277,11 +315,9 @@ pub trait ParallelDatabase: Database + Send {
/// series of queries in parallel and arranging the results. Using
/// this method for that purpose ensures that those queries will
/// see a consistent view of the database (it is also advisable
/// for those queries to use the [`is_current_revision_canceled`]
/// for those queries to use the [`Runtime::unwind_if_cancelled`]
/// method to check for cancellation).
///
/// [`is_current_revision_canceled`]: struct.Runtime.html#method.is_current_revision_canceled
///
/// # Panics
///
/// It is not permitted to create a snapshot from inside of a
@ -644,6 +680,41 @@ where
}
}
/// A panic payload indicating that a salsa revision was cancelled.
#[derive(Debug)]
#[non_exhaustive]
pub struct Cancelled;
impl Cancelled {
fn throw() -> ! {
// We use resume and not panic here to avoid running the panic
// hook (that is, to avoid collecting and printing backtrace).
std::panic::resume_unwind(Box::new(Self));
}
/// Runs `f`, and catches any salsa cancellation.
pub fn catch<F, T>(f: F) -> Result<T, Cancelled>
where
F: FnOnce() -> T + UnwindSafe,
{
match panic::catch_unwind(f) {
Ok(t) => Ok(t),
Err(payload) => match payload.downcast() {
Ok(cancelled) => Err(*cancelled),
Err(payload) => panic::resume_unwind(payload),
},
}
}
}
impl std::fmt::Display for Cancelled {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("cancelled")
}
}
impl std::error::Error for Cancelled {}
// Re-export the procedural macros.
#[allow(unused_imports)]
#[macro_use]

View file

@ -1,6 +1,6 @@
use crate::durability::Durability;
use crate::plumbing::CycleDetected;
use crate::revision::{AtomicRevision, Revision};
use crate::{durability::Durability, Cancelled};
use crate::{CycleError, Database, DatabaseKeyIndex, Event, EventKind};
use log::debug;
use parking_lot::lock_api::{RawRwLock, RawRwLockRecursive};
@ -152,93 +152,14 @@ impl Runtime {
/// Read current value of the revision counter.
#[inline]
fn pending_revision(&self) -> Revision {
pub(crate) fn pending_revision(&self) -> Revision {
self.shared_state.pending_revision.load()
}
/// Check if the current revision is canceled. If this method ever
/// returns true, the currently executing query is also marked as
/// having an *untracked read* -- this means that, in the next
/// revision, we will always recompute its value "as if" some
/// input had changed. This means that, if your revision is
/// canceled (which indicates that current query results will be
/// ignored) your query is free to shortcircuit and return
/// whatever it likes.
///
/// This method is useful for implementing cancellation of queries.
/// You can do it in one of two ways, via `Result`s or via unwinding.
///
/// The `Result` approach looks like this:
///
/// * Some queries invoke `is_current_revision_canceled` and
/// return a special value, like `Err(Canceled)`, if it returns
/// `true`.
/// * Other queries propagate the special value using `?` operator.
/// * API around top-level queries checks if the result is `Ok` or
/// `Err(Canceled)`.
///
/// The `panic` approach works in a similar way:
///
/// * Some queries invoke `is_current_revision_canceled` and
/// panic with a special value, like `Canceled`, if it returns
/// true.
/// * The implementation of `Database` trait overrides
/// `on_propagated_panic` to throw this special value as well.
/// This way, panic gets propagated naturally through dependant
/// queries, even across the threads.
/// * API around top-level queries converts a `panic` into `Result` by
/// catching the panic (using either `std::panic::catch_unwind` or
/// threads) and downcasting the payload to `Canceled` (re-raising
/// panic if downcast fails).
///
/// Note that salsa is explicitly designed to be panic-safe, so cancellation
/// via unwinding is 100% valid approach to cancellation.
#[inline]
pub fn is_current_revision_canceled(&self) -> bool {
let current_revision = self.current_revision();
let pending_revision = self.pending_revision();
debug!(
"is_current_revision_canceled: current_revision={:?}, pending_revision={:?}",
current_revision, pending_revision
);
if pending_revision > current_revision {
self.report_untracked_read();
true
} else {
// Subtle: If the current revision is not canceled, we
// still report an **anonymous** read, which will bump up
// the revision number to be at least the last
// non-canceled revision. This is needed to ensure
// deterministic reads and avoid salsa-rs/salsa#66. The
// specific scenario we are trying to avoid is tested by
// `no_back_dating_in_cancellation`; it works like
// this. Imagine we have 3 queries, where Query3 invokes
// Query2 which invokes Query1. Then:
//
// - In Revision R1:
// - Query1: Observes cancelation and returns sentinel S.
// - Recorded inputs: Untracked, because we observed cancelation.
// - Query2: Reads Query1 and propagates sentinel S.
// - Recorded inputs: Query1, changed-at=R1
// - Query3: Reads Query2 and propagates sentinel S. (Inputs = Query2, ChangedAt R1)
// - Recorded inputs: Query2, changed-at=R1
// - In Revision R2:
// - Query1: Observes no cancelation. All of its inputs last changed in R0,
// so it returns a valid value with "changed at" of R0.
// - Recorded inputs: ..., changed-at=R0
// - Query2: Recomputes its value and returns correct result.
// - Recorded inputs: Query1, changed-at=R0 <-- key problem!
// - Query3: sees that Query2's result last changed in R0, so it thinks it
// can re-use its value from R1 (which is the sentinel value).
//
// The anonymous read here prevents that scenario: Query1
// winds up with a changed-at setting of R2, which is the
// "pending revision", and hence Query2 and Query3
// are recomputed.
assert_eq!(pending_revision, current_revision);
self.report_anon_read(pending_revision);
false
}
#[cold]
pub(crate) fn unwind_cancelled(&self) {
self.report_untracked_read();
Cancelled::throw();
}
/// Acquires the **global query write lock** (ensuring that no queries are
@ -247,7 +168,7 @@ impl Runtime {
///
/// While we wait to acquire the global query write lock, this method will
/// also increment `pending_revision_increments`, thus signalling to queries
/// that their results are "canceled" and they should abort as expeditiously
/// that their results are "cancelled" and they should abort as expeditiously
/// as possible.
///
/// The `op` closure should actually perform the writes needed. It is given
@ -274,7 +195,7 @@ impl Runtime {
}
// Set the `pending_revision` field so that people
// know current revision is canceled.
// know current revision is cancelled.
let current_revision = self.shared_state.pending_revision.fetch_then_increment();
// To modify the revision, we need the lock.
@ -381,18 +302,6 @@ impl Runtime {
self.local_state.report_synthetic_read(durability);
}
/// An "anonymous" read is a read that doesn't come from executing
/// a query, but from some other internal operation. It just
/// modifies the "changed at" to be at least the given revision.
/// (It also does not disqualify a query from being considered
/// constant, since it is used for queries that don't give back
/// actual *data*.)
///
/// This is used when queries check if they have been canceled.
fn report_anon_read(&self, revision: Revision) {
self.local_state.report_anon_read(revision)
}
/// Obviously, this should be user configurable at some point.
pub(crate) fn report_unexpected_cycle(
&self,
@ -516,7 +425,7 @@ struct SharedState {
/// This is typically equal to `revision` -- set to `revision+1`
/// when a new revision is pending (which implies that the current
/// revision is canceled).
/// revision is cancelled).
pending_revision: AtomicRevision,
/// Stores the "last change" revision for values of each duration.
@ -639,10 +548,6 @@ impl ActiveQuery {
fn add_synthetic_read(&mut self, durability: Durability) {
self.durability = self.durability.min(durability);
}
fn add_anon_read(&mut self, changed_at: Revision) {
self.changed_at = self.changed_at.max(changed_at);
}
}
/// A unique identifier for a particular runtime. Each time you create

View file

@ -86,12 +86,6 @@ impl LocalState {
top_query.add_synthetic_read(durability);
}
}
pub(super) fn report_anon_read(&self, revision: Revision) {
if let Some(top_query) = self.query_stack.borrow_mut().last_mut() {
top_query.add_anon_read(revision);
}
}
}
impl std::panic::RefUnwindSafe for LocalState {}

View file

@ -1,161 +1,104 @@
use crate::setup::{CancelationFlag, Canceled, Knobs, ParDatabase, ParDatabaseImpl, WithValue};
use salsa::ParallelDatabase;
use crate::setup::{CancellationFlag, Knobs, ParDatabase, ParDatabaseImpl, WithValue};
use salsa::{Cancelled, ParallelDatabase};
macro_rules! assert_canceled {
($flag:expr, $thread:expr) => {
if $flag == CancelationFlag::Panic {
match $thread.join() {
Ok(value) => panic!("expected cancelation, got {:?}", value),
Err(payload) => match payload.downcast::<Canceled>() {
Ok(_) => {}
Err(payload) => ::std::panic::resume_unwind(payload),
},
}
} else {
assert_eq!($thread.join().unwrap(), usize::max_value());
macro_rules! assert_cancelled {
($thread:expr) => {
match $thread.join() {
Ok(value) => panic!("expected cancellation, got {:?}", value),
Err(payload) => match payload.downcast::<Cancelled>() {
Ok(_) => {}
Err(payload) => ::std::panic::resume_unwind(payload),
},
}
};
}
/// We have two flavors of cancellation: based on unwinding and based on anon
/// reads. This checks both.
fn check_cancelation(f: impl Fn(CancelationFlag)) {
f(CancelationFlag::Panic);
f(CancelationFlag::SpecialValue);
}
/// Add test where a call to `sum` is cancelled by a simultaneous
/// write. Check that we recompute the result in next revision, even
/// though none of the inputs have changed.
#[test]
fn in_par_get_set_cancellation_immediate() {
check_cancelation(|flag| {
let mut db = ParDatabaseImpl::default();
let mut db = ParDatabaseImpl::default();
db.set_input('a', 100);
db.set_input('b', 010);
db.set_input('c', 001);
db.set_input('d', 0);
db.set_input('a', 100);
db.set_input('b', 010);
db.set_input('c', 001);
db.set_input('d', 0);
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || {
// This will not return until it sees cancellation is
// signaled.
db.knobs().sum_signal_on_entry.with_value(1, || {
db.knobs()
.sum_wait_for_cancellation
.with_value(flag, || db.sum("abc"))
})
}
});
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || {
// This will not return until it sees cancellation is
// signaled.
db.knobs().sum_signal_on_entry.with_value(1, || {
db.knobs()
.sum_wait_for_cancellation
.with_value(CancellationFlag::Panic, || db.sum("abc"))
})
}
});
// Wait until we have entered `sum` in the other thread.
db.wait_for(1);
// Wait until we have entered `sum` in the other thread.
db.wait_for(1);
// Try to set the input. This will signal cancellation.
db.set_input('d', 1000);
// Try to set the input. This will signal cancellation.
db.set_input('d', 1000);
// This should re-compute the value (even though no input has changed).
let thread2 = std::thread::spawn({
let db = db.snapshot();
move || db.sum("abc")
});
// This should re-compute the value (even though no input has changed).
let thread2 = std::thread::spawn({
let db = db.snapshot();
move || db.sum("abc")
});
assert_eq!(db.sum("d"), 1000);
assert_canceled!(flag, thread1);
assert_eq!(thread2.join().unwrap(), 111);
})
assert_eq!(db.sum("d"), 1000);
assert_cancelled!(thread1);
assert_eq!(thread2.join().unwrap(), 111);
}
/// Here, we check that `sum`'s cancellation is propagated
/// to `sum2` properly.
#[test]
fn in_par_get_set_cancellation_transitive() {
check_cancelation(|flag| {
let mut db = ParDatabaseImpl::default();
let mut db = ParDatabaseImpl::default();
db.set_input('a', 100);
db.set_input('b', 010);
db.set_input('c', 001);
db.set_input('d', 0);
db.set_input('a', 100);
db.set_input('b', 010);
db.set_input('c', 001);
db.set_input('d', 0);
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || {
// This will not return until it sees cancellation is
// signaled.
db.knobs().sum_signal_on_entry.with_value(1, || {
db.knobs()
.sum_wait_for_cancellation
.with_value(flag, || db.sum2("abc"))
})
}
});
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || {
// This will not return until it sees cancellation is
// signaled.
db.knobs().sum_signal_on_entry.with_value(1, || {
db.knobs()
.sum_wait_for_cancellation
.with_value(CancellationFlag::Panic, || db.sum2("abc"))
})
}
});
// Wait until we have entered `sum` in the other thread.
db.wait_for(1);
// Wait until we have entered `sum` in the other thread.
db.wait_for(1);
// Try to set the input. This will signal cancellation.
db.set_input('d', 1000);
// Try to set the input. This will signal cancellation.
db.set_input('d', 1000);
// This should re-compute the value (even though no input has changed).
let thread2 = std::thread::spawn({
let db = db.snapshot();
move || db.sum2("abc")
});
// This should re-compute the value (even though no input has changed).
let thread2 = std::thread::spawn({
let db = db.snapshot();
move || db.sum2("abc")
});
assert_eq!(db.sum2("d"), 1000);
assert_canceled!(flag, thread1);
assert_eq!(thread2.join().unwrap(), 111);
})
assert_eq!(db.sum2("d"), 1000);
assert_cancelled!(thread1);
assert_eq!(thread2.join().unwrap(), 111);
}
/// https://github.com/salsa-rs/salsa/issues/66
#[test]
fn no_back_dating_in_cancellation() {
check_cancelation(|flag| {
let mut db = ParDatabaseImpl::default();
db.set_input('a', 1);
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || {
// Here we compute a long-chain of queries,
// but the last one gets cancelled.
db.knobs().sum_signal_on_entry.with_value(1, || {
db.knobs()
.sum_wait_for_cancellation
.with_value(flag, || db.sum3("a"))
})
}
});
db.wait_for(1);
// Set unrelated input to bump revision
db.set_input('b', 2);
// Here we should recompute the whole chain again, clearing the cancellation
// state. If we get `usize::max()` here, it is a bug!
assert_eq!(db.sum3("a"), 1);
assert_canceled!(flag, thread1);
db.set_input('a', 3);
db.set_input('a', 4);
assert_eq!(db.sum3("ab"), 6);
})
}
/// Here, we compute `sum3_drop_sum` and -- in the process -- observe
/// a cancellation. As a result, we have to recompute `sum` when we
/// reinvoke `sum3_drop_sum` and we have to re-execute
/// `sum2_drop_sum`. But the result of `sum2_drop_sum` doesn't
/// change, so we don't have to re-execute `sum3_drop_sum`.
/// This only works with SpecialValue cancellation strategy.
#[test]
fn transitive_cancellation() {
let mut db = ParDatabaseImpl::default();
db.set_input('a', 1);
@ -167,21 +110,23 @@ fn transitive_cancellation() {
db.knobs().sum_signal_on_entry.with_value(1, || {
db.knobs()
.sum_wait_for_cancellation
.with_value(CancelationFlag::SpecialValue, || db.sum3_drop_sum("a"))
.with_value(CancellationFlag::Panic, || db.sum3("a"))
})
}
});
db.wait_for(1);
// Set unrelated input to bump revision
db.set_input('b', 2);
// Check that when we call `sum3_drop_sum` we don't wind up having
// to actually re-execute it, because the result of `sum2` winds
// up not changing.
db.knobs().sum3_drop_sum_should_panic.with_value(true, || {
assert_eq!(db.sum3_drop_sum("a"), 22);
});
// Here we should recompute the whole chain again, clearing the cancellation
// state. If we get `usize::max()` here, it is a bug!
assert_eq!(db.sum3("a"), 1);
assert_eq!(thread1.join().unwrap(), 22);
assert_cancelled!(thread1);
db.set_input('a', 3);
db.set_input('a', 4);
assert_eq!(db.sum3("ab"), 6);
}

View file

@ -1,7 +1,10 @@
use crate::setup::{ParDatabase, ParDatabaseImpl};
use crate::signal::Signal;
use salsa::{Database, ParallelDatabase};
use std::sync::Arc;
use std::{
panic::{catch_unwind, AssertUnwindSafe},
sync::Arc,
};
/// Add test where a call to `sum` is cancelled by a simultaneous
/// write. Check that we recompute the result in next revision, even
@ -20,25 +23,17 @@ fn in_par_get_set_cancellation() {
move || {
// Check that cancellation flag is not yet set, because
// `set` cannot have been called yet.
assert!(!db.salsa_runtime().is_current_revision_canceled());
catch_unwind(AssertUnwindSafe(|| db.unwind_if_cancelled())).unwrap();
// Signal other thread to proceed.
signal.signal(1);
// Wait for other thread to signal cancellation
while !db.salsa_runtime().is_current_revision_canceled() {
catch_unwind(AssertUnwindSafe(|| loop {
db.unwind_if_cancelled();
std::thread::yield_now();
}
// Since we have not yet released revision lock, we should
// see 1 here.
let v = db.input('a');
// Since this is a snapshotted database, we are in a consistent
// revision, so this must yield the same value.
let w = db.input('a');
(v, w)
}))
.unwrap_err();
}
});
@ -56,9 +51,7 @@ fn in_par_get_set_cancellation() {
}
});
let (a, b) = thread1.join().unwrap();
assert_eq!(a, 1);
assert_eq!(b, 1);
thread1.join().unwrap();
let c = thread2.join().unwrap();
assert_eq!(c, 2);

View file

@ -1,5 +1,7 @@
use std::panic::AssertUnwindSafe;
use crate::setup::{ParDatabase, ParDatabaseImpl};
use salsa::ParallelDatabase;
use salsa::{Cancelled, ParallelDatabase};
/// Test where a read and a set are racing with one another.
/// Should be atomic.
@ -14,8 +16,10 @@ fn in_par_get_set_race() {
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || {
let v = db.sum("abc");
v
Cancelled::catch(AssertUnwindSafe(|| {
let v = db.sum("abc");
v
}))
}
});
@ -26,13 +30,13 @@ fn in_par_get_set_race() {
// If the 1st thread runs first, you get 111, otherwise you get
// 1011; if they run concurrently and the 1st thread observes the
// cancelation, you get back usize::max.
let value1 = thread1.join().unwrap();
assert!(
value1 == 111 || value1 == 1011 || value1 == std::usize::MAX,
"illegal result {}",
value1
);
// cancellation, it'll unwind.
let result1 = thread1.join().unwrap();
if let Ok(value1) = result1 {
assert!(value1 == 111 || value1 == 1011, "illegal result {}", value1);
}
// thread2 can not observe a cancellation because it performs a
// database write before running any other queries.
assert_eq!(thread2.join().unwrap(), 1000);
}

View file

@ -2,8 +2,11 @@ use crate::signal::Signal;
use salsa::Database;
use salsa::ParallelDatabase;
use salsa::Snapshot;
use std::cell::Cell;
use std::sync::Arc;
use std::{
cell::Cell,
panic::{catch_unwind, resume_unwind, AssertUnwindSafe},
};
#[salsa::query_group(Par)]
pub(crate) trait ParDatabase: Knobs {
@ -25,16 +28,6 @@ pub(crate) trait ParDatabase: Knobs {
fn sum3_drop_sum(&self, key: &'static str) -> usize;
}
#[derive(PartialEq, Eq)]
pub(crate) struct Canceled;
impl Canceled {
fn throw() -> ! {
// Don't print backtrace
std::panic::resume_unwind(Box::new(Canceled));
}
}
/// Various "knobs" and utilities used by tests to force
/// a certain behavior.
pub(crate) trait Knobs {
@ -53,24 +46,26 @@ impl<T> WithValue<T> for Cell<T> {
fn with_value<R>(&self, value: T, closure: impl FnOnce() -> R) -> R {
let old_value = self.replace(value);
let result = closure();
let result = catch_unwind(AssertUnwindSafe(|| closure()));
self.set(old_value);
result
match result {
Ok(r) => r,
Err(payload) => resume_unwind(payload),
}
}
}
#[derive(Clone, Copy, PartialEq, Eq)]
pub(crate) enum CancelationFlag {
pub(crate) enum CancellationFlag {
Down,
Panic,
SpecialValue,
}
impl Default for CancelationFlag {
fn default() -> CancelationFlag {
CancelationFlag::Down
impl Default for CancellationFlag {
fn default() -> CancellationFlag {
CancellationFlag::Down
}
}
@ -97,7 +92,7 @@ pub(crate) struct KnobsStruct {
/// If true, invocations of `sum` will wait for cancellation before
/// they exit.
pub(crate) sum_wait_for_cancellation: Cell<CancelationFlag>,
pub(crate) sum_wait_for_cancellation: Cell<CancellationFlag>,
/// Invocations of `sum` will wait for this stage prior to exiting.
pub(crate) sum_wait_for_on_exit: Cell<usize>,
@ -125,31 +120,16 @@ fn sum(db: &dyn ParDatabase, key: &'static str) -> usize {
}
match db.knobs().sum_wait_for_cancellation.get() {
CancelationFlag::Down => (),
flag => {
CancellationFlag::Down => (),
CancellationFlag::Panic => {
log::debug!("waiting for cancellation");
while !db.salsa_runtime().is_current_revision_canceled() {
loop {
db.unwind_if_cancelled();
std::thread::yield_now();
}
log::debug!("observed cancelation");
if flag == CancelationFlag::Panic {
Canceled::throw();
}
}
}
// Check for cancelation and return MAX if so. Note that we check
// for cancelation *deterministically* -- but if
// `sum_wait_for_cancellation` is set, we will block
// beforehand. Deterministic execution is a requirement for valid
// salsa user code. It's also important to some tests that `sum`
// *attempts* to invoke `is_current_revision_canceled` even if we
// know it will not be canceled, because that helps us keep the
// accounting up to date.
if db.salsa_runtime().is_current_revision_canceled() {
return std::usize::MAX; // when we are cancelled, we return usize::MAX.
}
db.wait_for(db.knobs().sum_wait_for_on_exit.get());
db.signal(db.knobs().sum_signal_on_exit.get());
@ -194,10 +174,6 @@ impl Database for ParDatabaseImpl {
_ => {}
}
}
fn on_propagated_panic(&self) -> ! {
Canceled::throw()
}
}
impl ParallelDatabase for ParDatabaseImpl {

View file

@ -1,37 +1,31 @@
use rand::seq::SliceRandom;
use rand::Rng;
use salsa::Database;
use salsa::ParallelDatabase;
use salsa::Snapshot;
use salsa::SweepStrategy;
use salsa::{Cancelled, Database};
// Number of operations a reader performs
const N_MUTATOR_OPS: usize = 100;
const N_READER_OPS: usize = 100;
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct Canceled;
type Cancelable<T> = Result<T, Canceled>;
#[salsa::query_group(Stress)]
trait StressDatabase: salsa::Database {
#[salsa::input]
fn a(&self, key: usize) -> usize;
fn b(&self, key: usize) -> Cancelable<usize>;
fn b(&self, key: usize) -> usize;
fn c(&self, key: usize) -> Cancelable<usize>;
fn c(&self, key: usize) -> usize;
}
fn b(db: &dyn StressDatabase, key: usize) -> Cancelable<usize> {
if db.salsa_runtime().is_current_revision_canceled() {
return Err(Canceled);
}
Ok(db.a(key))
fn b(db: &dyn StressDatabase, key: usize) -> usize {
db.unwind_if_cancelled();
db.a(key)
}
fn c(db: &dyn StressDatabase, key: usize) -> Cancelable<usize> {
fn c(db: &dyn StressDatabase, key: usize) -> usize {
db.b(key)
}
@ -127,9 +121,7 @@ impl rand::distributions::Distribution<ReadOp> for rand::distributions::Standard
fn db_reader_thread(db: &StressDatabaseImpl, ops: Vec<ReadOp>, check_cancellation: bool) {
for op in ops {
if check_cancellation {
if db.salsa_runtime().is_current_revision_canceled() {
return;
}
db.unwind_if_cancelled();
}
op.execute(db);
}
@ -199,12 +191,12 @@ fn stress_test() {
check_cancellation,
} => all_threads.push(std::thread::spawn({
let db = db.snapshot();
move || db_reader_thread(&db, ops, check_cancellation)
move || Cancelled::catch(|| db_reader_thread(&db, ops, check_cancellation))
})),
}
}
for thread in all_threads {
thread.join().unwrap();
thread.join().unwrap().ok();
}
}