mirror of
https://github.com/salsa-rs/salsa.git
synced 2025-01-23 13:10:19 +00:00
Merge pull request #105 from nikomatsakis/issue-66-transitive-cancelation
Issue 66 transitive cancelation
This commit is contained in:
commit
aaa50e01d6
6 changed files with 287 additions and 40 deletions
|
@ -281,6 +281,13 @@ where
|
|||
) -> Result<StampedValue<Q::Value>, CycleDetected> {
|
||||
let runtime = db.salsa_runtime();
|
||||
|
||||
debug!(
|
||||
"{:?}({:?}): read_upgrade(revision_now={:?})",
|
||||
Q::default(),
|
||||
key,
|
||||
revision_now,
|
||||
);
|
||||
|
||||
// Check with an upgradable read to see if there is a value
|
||||
// already. (This permits other readers but prevents anyone
|
||||
// else from running `read_upgrade` at the same time.)
|
||||
|
@ -355,6 +362,13 @@ where
|
|||
if let Some(old_memo) = &old_memo {
|
||||
if let Some(old_value) = &old_memo.value {
|
||||
if MP::memoized_value_eq(&old_value, &result.value) {
|
||||
debug!(
|
||||
"read_upgrade({:?}({:?})): value is equal, back-dating to {:?}",
|
||||
Q::default(),
|
||||
key,
|
||||
old_memo.changed_at,
|
||||
);
|
||||
|
||||
assert!(old_memo.changed_at <= result.changed_at.revision);
|
||||
result.changed_at.revision = old_memo.changed_at;
|
||||
}
|
||||
|
@ -372,6 +386,14 @@ where
|
|||
None
|
||||
};
|
||||
|
||||
debug!(
|
||||
"read_upgrade({:?}({:?})): result.changed_at={:?}, result.subqueries = {:#?}",
|
||||
Q::default(),
|
||||
key,
|
||||
result.changed_at,
|
||||
result.subqueries,
|
||||
);
|
||||
|
||||
let inputs = match result.subqueries {
|
||||
None => MemoInputs::Untracked,
|
||||
|
||||
|
@ -666,7 +688,7 @@ where
|
|||
let revision_now = runtime.current_revision();
|
||||
|
||||
debug!(
|
||||
"{:?}({:?})::maybe_changed_since(revision={:?}, revision_now={:?})",
|
||||
"maybe_changed_since({:?}({:?})) called with revision={:?}, revision_now={:?}",
|
||||
Q::default(),
|
||||
key,
|
||||
revision,
|
||||
|
@ -682,13 +704,26 @@ where
|
|||
// If somebody depends on us, but we have no map
|
||||
// entry, that must mean that it was found to be out
|
||||
// of date and removed.
|
||||
None => return true,
|
||||
None => {
|
||||
debug!(
|
||||
"maybe_changed_since({:?}({:?}): no value",
|
||||
Q::default(),
|
||||
key,
|
||||
);
|
||||
return true;
|
||||
}
|
||||
|
||||
// This value is being actively recomputed. Wait for
|
||||
// that thread to finish (assuming it's not dependent
|
||||
// on us...) and check its associated revision.
|
||||
Some(QueryState::InProgress { id, waiting }) => {
|
||||
let other_id = *id;
|
||||
debug!(
|
||||
"maybe_changed_since({:?}({:?}): blocking on thread `{:?}`",
|
||||
Q::default(),
|
||||
key,
|
||||
other_id,
|
||||
);
|
||||
match self.register_with_in_progress_thread(runtime, descriptor, other_id, waiting)
|
||||
{
|
||||
Ok(rx) => {
|
||||
|
@ -709,6 +744,13 @@ where
|
|||
};
|
||||
|
||||
if memo.verified_at == revision_now {
|
||||
debug!(
|
||||
"maybe_changed_since({:?}({:?}): {:?} since up-to-date memo that changed at {:?}",
|
||||
Q::default(),
|
||||
key,
|
||||
memo.changed_at > revision,
|
||||
memo.changed_at,
|
||||
);
|
||||
return memo.changed_at > revision;
|
||||
}
|
||||
|
||||
|
@ -718,6 +760,11 @@ where
|
|||
// inputs, so if there is a new
|
||||
// revision, we must assume it is
|
||||
// dirty
|
||||
debug!(
|
||||
"maybe_changed_since({:?}({:?}): true since untracked inputs",
|
||||
Q::default(),
|
||||
key,
|
||||
);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -734,7 +781,16 @@ where
|
|||
if memo.value.is_some() {
|
||||
std::mem::drop(map);
|
||||
return match self.read_upgrade(db, key, descriptor, revision_now) {
|
||||
Ok(v) => v.changed_at.changed_since(revision),
|
||||
Ok(v) => {
|
||||
debug!(
|
||||
"maybe_changed_since({:?}({:?}): {:?} since (recomputed) value changed at {:?}",
|
||||
Q::default(),
|
||||
key,
|
||||
v.changed_at.changed_since(revision),
|
||||
v.changed_at,
|
||||
);
|
||||
v.changed_at.changed_since(revision)
|
||||
}
|
||||
Err(CycleDetected) => true,
|
||||
};
|
||||
}
|
||||
|
@ -927,6 +983,12 @@ where
|
|||
assert!(self.verified_at != revision_now);
|
||||
let verified_at = self.verified_at;
|
||||
|
||||
debug!(
|
||||
"validate_memoized_value({:?}): verified_at={:#?}",
|
||||
Q::default(),
|
||||
self.inputs,
|
||||
);
|
||||
|
||||
let is_constant = match &mut self.inputs {
|
||||
// We can't validate values that had untracked inputs; just have to
|
||||
// re-execute.
|
||||
|
|
127
src/runtime.rs
127
src/runtime.rs
|
@ -159,6 +159,14 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
/// Read current value of the revision counter.
|
||||
#[inline]
|
||||
fn pending_revision(&self) -> Revision {
|
||||
Revision {
|
||||
generation: self.shared_state.pending_revision.load(Ordering::SeqCst) as u64,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the current revision is canceled. If this method ever
|
||||
/// returns true, the currently executing query is also marked as
|
||||
/// having an *untracked read* -- this means that, in the next
|
||||
|
@ -169,14 +177,48 @@ where
|
|||
/// whatever it likes.
|
||||
#[inline]
|
||||
pub fn is_current_revision_canceled(&self) -> bool {
|
||||
let pending_revision_increments = self
|
||||
.shared_state
|
||||
.pending_revision_increments
|
||||
.load(Ordering::SeqCst);
|
||||
if pending_revision_increments > 0 {
|
||||
let current_revision = self.current_revision();
|
||||
let pending_revision = self.pending_revision();
|
||||
debug!(
|
||||
"is_current_revision_canceled: current_revision={:?}, pending_revision={:?}",
|
||||
current_revision, pending_revision
|
||||
);
|
||||
if pending_revision > current_revision {
|
||||
self.report_untracked_read();
|
||||
true
|
||||
} else {
|
||||
// Subtle: If the current revision is not canceled, we
|
||||
// still report an **anonymous** read, which will bump up
|
||||
// the revision number to be at least the last
|
||||
// non-canceled revision. This is needed to ensure
|
||||
// deterministic reads and avoid salsa-rs/salsa#66. The
|
||||
// specific scenario we are trying to avoid is tested by
|
||||
// `no_back_dating_in_cancellation`; it works like
|
||||
// this. Imagine we have 3 queries, where Query3 invokes
|
||||
// Query2 which invokes Query1. Then:
|
||||
//
|
||||
// - In Revision R1:
|
||||
// - Query1: Observes cancelation and returns sentinel S.
|
||||
// - Recorded inputs: Untracked, because we observed cancelation.
|
||||
// - Query2: Reads Query1 and propagates sentinel S.
|
||||
// - Recorded inputs: Query1, changed-at=R1
|
||||
// - Query3: Reads Query2 and propagates sentinel S. (Inputs = Query2, ChangedAt R1)
|
||||
// - Recorded inputs: Query2, changed-at=R1
|
||||
// - In Revision R2:
|
||||
// - Query1: Observes no cancelation. All of its inputs last changed in R0,
|
||||
// so it returns a valid value with "changed at" of R0.
|
||||
// - Recorded inputs: ..., changed-at=R0
|
||||
// - Query2: Recomputes its value and returns correct result.
|
||||
// - Recorded inputs: Query1, changed-at=R0 <-- key problem!
|
||||
// - Query3: sees that Query2's result last changed in R0, so it thinks it
|
||||
// can re-use its value from R1 (which is the sentinel value).
|
||||
//
|
||||
// The anonymous read here prevents that scenario: Query1
|
||||
// winds up with a changed-at setting of R2, which is the
|
||||
// "pending revision", and hence Query2 and Query3
|
||||
// are recomputed.
|
||||
assert_eq!(pending_revision, current_revision);
|
||||
self.report_anon_read(pending_revision);
|
||||
false
|
||||
}
|
||||
}
|
||||
|
@ -190,6 +232,10 @@ where
|
|||
/// method will also increment `pending_revision_increments`, thus
|
||||
/// signalling to queries that their results are "canceled" and
|
||||
/// they should abort as expeditiously as possible.
|
||||
///
|
||||
/// Note that, given our writer model, we can assume that only one
|
||||
/// thread is attempting to increment the global revision at a
|
||||
/// time.
|
||||
pub(crate) fn with_incremented_revision<R>(&self, op: impl FnOnce(Revision) -> R) -> R {
|
||||
log::debug!("increment_revision()");
|
||||
|
||||
|
@ -197,36 +243,24 @@ where
|
|||
panic!("increment_revision invoked during a query computation");
|
||||
}
|
||||
|
||||
// Signal that we have a pending increment so that workers can
|
||||
// start to cancel work.
|
||||
let old_pending_revision_increments = self
|
||||
// Set the `pending_revision` field so that people
|
||||
// know current revision is canceled.
|
||||
let current_revision = self
|
||||
.shared_state
|
||||
.pending_revision_increments
|
||||
.pending_revision
|
||||
.fetch_add(1, Ordering::SeqCst);
|
||||
assert!(
|
||||
old_pending_revision_increments != usize::max_value(),
|
||||
"pending increment overflow"
|
||||
);
|
||||
assert!(current_revision != usize::max_value(), "revision overflow");
|
||||
|
||||
// To modify the revision, we need the lock.
|
||||
let _lock = self.shared_state.query_lock.write();
|
||||
|
||||
// *Before* updating the revision number, decrement the
|
||||
// `pending_revision_increments` counter. This way, if anybody
|
||||
// should happen to invoke `is_current_revision_canceled`
|
||||
// before we update the number, and they read 0, they don't
|
||||
// get an incorrect result -- once they acquire the query
|
||||
// lock, we'll be in the new revision.
|
||||
self.shared_state
|
||||
.pending_revision_increments
|
||||
.fetch_sub(1, Ordering::SeqCst);
|
||||
|
||||
let old_revision = self.shared_state.revision.fetch_add(1, Ordering::SeqCst);
|
||||
assert!(old_revision != usize::max_value(), "revision overflow");
|
||||
assert_eq!(current_revision, old_revision);
|
||||
|
||||
let new_revision = Revision {
|
||||
generation: 1 + old_revision as u64,
|
||||
generation: (current_revision + 1) as u64,
|
||||
};
|
||||
|
||||
debug!("increment_revision: incremented to {:?}", new_revision);
|
||||
|
||||
op(new_revision)
|
||||
|
@ -308,6 +342,20 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
/// An "anonymous" read is a read that doesn't come from executing
|
||||
/// a query, but from some other internal operation. It just
|
||||
/// modifies the "changed at" to be at least the given revision.
|
||||
/// (It also does not disqualify a query from being considered
|
||||
/// constant, since it is used for queries that don't give back
|
||||
/// actual *data*.)
|
||||
///
|
||||
/// This is used when queries check if they have been canceled.
|
||||
fn report_anon_read(&self, revision: Revision) {
|
||||
if let Some(top_query) = self.local_state.borrow_mut().query_stack.last_mut() {
|
||||
top_query.add_anon_read(revision);
|
||||
}
|
||||
}
|
||||
|
||||
/// Obviously, this should be user configurable at some point.
|
||||
pub(crate) fn report_unexpected_cycle(&self, descriptor: DB::QueryDescriptor) -> ! {
|
||||
debug!("report_unexpected_cycle(descriptor={:?})", descriptor);
|
||||
|
@ -374,12 +422,10 @@ struct SharedState<DB: Database> {
|
|||
/// (Ideally, this should be `AtomicU64`, but that is currently unstable.)
|
||||
revision: AtomicUsize,
|
||||
|
||||
/// Counts the number of pending increments to the revision
|
||||
/// counter. If this is non-zero, it means that the current
|
||||
/// revision is out of date, and hence queries are free to
|
||||
/// "short-circuit" their results if they learn that. See
|
||||
/// `is_current_revision_canceled` for more information.
|
||||
pending_revision_increments: AtomicUsize,
|
||||
/// This is typically equal to `revision` -- set to `revision+1`
|
||||
/// when a new revision is pending (which implies that the current
|
||||
/// revision is canceled).
|
||||
pending_revision: AtomicUsize,
|
||||
|
||||
/// The dependency graph tracks which runtimes are blocked on one
|
||||
/// another, waiting for queries to terminate.
|
||||
|
@ -393,8 +439,8 @@ impl<DB: Database> Default for SharedState<DB> {
|
|||
storage: Default::default(),
|
||||
query_lock: Default::default(),
|
||||
revision: Default::default(),
|
||||
pending_revision: Default::default(),
|
||||
dependency_graph: Default::default(),
|
||||
pending_revision_increments: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -414,7 +460,7 @@ where
|
|||
fmt.debug_struct("SharedState")
|
||||
.field("query_lock", &query_lock)
|
||||
.field("revision", &self.revision)
|
||||
.field("pending_revision_increments", &self.pending_revision_increments)
|
||||
.field("pending_revision", &self.pending_revision)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
@ -501,6 +547,10 @@ impl<DB: Database> ActiveQuery<DB> {
|
|||
self.changed_at.is_constant = false;
|
||||
self.changed_at.revision = changed_at;
|
||||
}
|
||||
|
||||
fn add_anon_read(&mut self, changed_at: Revision) {
|
||||
self.changed_at.revision = self.changed_at.revision.max(changed_at);
|
||||
}
|
||||
}
|
||||
|
||||
/// A unique identifier for a particular runtime. Each time you create
|
||||
|
@ -523,6 +573,17 @@ pub struct Revision {
|
|||
|
||||
impl Revision {
|
||||
pub(crate) const ZERO: Self = Revision { generation: 0 };
|
||||
|
||||
fn next(self) -> Revision {
|
||||
Revision {
|
||||
generation: self.generation + 1,
|
||||
}
|
||||
}
|
||||
|
||||
fn as_usize(self) -> usize {
|
||||
assert!(self.generation < (std::usize::MAX as u64));
|
||||
self.generation as usize
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for Revision {
|
||||
|
|
|
@ -66,10 +66,8 @@ fn revalidate() {
|
|||
// Second generation: volatile will change (to 1) but memoized1
|
||||
// will not (still 0, as 1/2 = 0)
|
||||
query.salsa_runtime().next_revision();
|
||||
|
||||
query.memoized2();
|
||||
query.assert_log(&["Memoized1 invoked", "Volatile invoked"]);
|
||||
|
||||
query.memoized2();
|
||||
query.assert_log(&[]);
|
||||
|
||||
|
|
|
@ -83,3 +83,75 @@ fn in_par_get_set_cancellation_transitive() {
|
|||
assert_eq!(thread1.join().unwrap(), std::usize::MAX);
|
||||
assert_eq!(thread2.join().unwrap(), 111);
|
||||
}
|
||||
|
||||
/// https://github.com/salsa-rs/salsa/issues/66
|
||||
#[test]
|
||||
fn no_back_dating_in_cancellation() {
|
||||
let mut db = ParDatabaseImpl::default();
|
||||
|
||||
db.query_mut(Input).set('a', 1);
|
||||
let thread1 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || {
|
||||
// Here we compute a long-chain of queries,
|
||||
// but the last one gets cancelled.
|
||||
db.knobs().sum_signal_on_entry.with_value(1, || {
|
||||
db.knobs()
|
||||
.sum_wait_for_cancellation
|
||||
.with_value(true, || db.sum3("a"))
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
db.wait_for(1);
|
||||
|
||||
// Set unrelated input to bump revision
|
||||
db.query_mut(Input).set('b', 2);
|
||||
|
||||
// Here we should recompuet the whole chain again, clearing the cancellation
|
||||
// state. If we get `usize::max()` here, it is a bug!
|
||||
assert_eq!(db.sum3("a"), 1);
|
||||
|
||||
assert_eq!(thread1.join().unwrap(), std::usize::MAX);
|
||||
|
||||
db.query_mut(Input).set('a', 3);
|
||||
db.query_mut(Input).set('a', 4);
|
||||
assert_eq!(db.sum3("ab"), 6);
|
||||
}
|
||||
|
||||
/// Here, we compute `sum3_drop_sum` and -- in the process -- observe
|
||||
/// a cancellation. As a result, we have to recompute `sum` when we
|
||||
/// reinvoke `sum3_drop_sum` and we have to re-execute
|
||||
/// `sum2_drop_sum`. But the result of `sum2_drop_sum` doesn't
|
||||
/// change, so we don't have to re-execute `sum3_drop_sum`.
|
||||
#[test]
|
||||
fn transitive_cancellation() {
|
||||
let mut db = ParDatabaseImpl::default();
|
||||
|
||||
db.query_mut(Input).set('a', 1);
|
||||
let thread1 = std::thread::spawn({
|
||||
let db = db.snapshot();
|
||||
move || {
|
||||
// Here we compute a long-chain of queries,
|
||||
// but the last one gets cancelled.
|
||||
db.knobs().sum_signal_on_entry.with_value(1, || {
|
||||
db.knobs()
|
||||
.sum_wait_for_cancellation
|
||||
.with_value(true, || db.sum3_drop_sum("a"))
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
db.wait_for(1);
|
||||
|
||||
db.query_mut(Input).set('b', 2);
|
||||
|
||||
// Check that when we call `sum3_drop_sum` we don't wind up having
|
||||
// to actually re-execute it, because the result of `sum2` winds
|
||||
// up not changing.
|
||||
db.knobs().sum3_drop_sum_should_panic.with_value(true, || {
|
||||
assert_eq!(db.sum3_drop_sum("a"), 22);
|
||||
});
|
||||
|
||||
assert_eq!(thread1.join().unwrap(), 22);
|
||||
}
|
||||
|
|
|
@ -25,9 +25,14 @@ fn in_par_get_set_race() {
|
|||
});
|
||||
|
||||
// If the 1st thread runs first, you get 111, otherwise you get
|
||||
// 1011.
|
||||
// 1011; if they run concurrently and the 1st thread observes the
|
||||
// cancelation, you get back usize::max.
|
||||
let value1 = thread1.join().unwrap();
|
||||
assert!(value1 == 111 || value1 == 1011, "illegal result {}", value1);
|
||||
assert!(
|
||||
value1 == 111 || value1 == 1011 || value1 == std::usize::MAX,
|
||||
"illegal result {}",
|
||||
value1
|
||||
);
|
||||
|
||||
assert_eq!(thread2.join().unwrap(), 1000);
|
||||
}
|
||||
|
|
|
@ -16,10 +16,26 @@ salsa::query_group! {
|
|||
type Sum;
|
||||
}
|
||||
|
||||
/// Invokes `sum`
|
||||
fn sum2(key: &'static str) -> usize {
|
||||
type Sum2;
|
||||
}
|
||||
|
||||
/// Invokes `sum` but doesn't really care about the result.
|
||||
fn sum2_drop_sum(key: &'static str) -> usize {
|
||||
type Sum2Drop;
|
||||
}
|
||||
|
||||
/// Invokes `sum2`
|
||||
fn sum3(key: &'static str) -> usize {
|
||||
type Sum3;
|
||||
}
|
||||
|
||||
/// Invokes `sum2_drop_sum`
|
||||
fn sum3_drop_sum(key: &'static str) -> usize {
|
||||
type Sum3Drop;
|
||||
}
|
||||
|
||||
fn snapshot_me() -> () {
|
||||
type SnapshotMe;
|
||||
}
|
||||
|
@ -82,6 +98,9 @@ pub(crate) struct KnobsStruct {
|
|||
|
||||
/// Invocations of `sum` will signal this stage prior to exiting.
|
||||
pub(crate) sum_signal_on_exit: Cell<usize>,
|
||||
|
||||
/// Invocations of `sum3_drop_sum` will panic unconditionally
|
||||
pub(crate) sum3_drop_sum_should_panic: Cell<bool>,
|
||||
}
|
||||
|
||||
fn sum(db: &impl ParDatabase, key: &'static str) -> usize {
|
||||
|
@ -105,6 +124,17 @@ fn sum(db: &impl ParDatabase, key: &'static str) -> usize {
|
|||
std::thread::yield_now();
|
||||
}
|
||||
log::debug!("cancellation observed");
|
||||
}
|
||||
|
||||
// Check for cancelation and return MAX if so. Note that we check
|
||||
// for cancelation *deterministically* -- but if
|
||||
// `sum_wait_for_cancellation` is set, we will block
|
||||
// beforehand. Deterministic execution is a requirement for valid
|
||||
// salsa user code. It's also important to some tests that `sum`
|
||||
// *attempts* to invoke `is_current_revision_canceled` even if we
|
||||
// know it will not be canceled, because that helps us keep the
|
||||
// accounting up to date.
|
||||
if db.salsa_runtime().is_current_revision_canceled() {
|
||||
return std::usize::MAX; // when we are cancelled, we return usize::MAX.
|
||||
}
|
||||
|
||||
|
@ -119,6 +149,22 @@ fn sum2(db: &impl ParDatabase, key: &'static str) -> usize {
|
|||
db.sum(key)
|
||||
}
|
||||
|
||||
fn sum2_drop_sum(db: &impl ParDatabase, key: &'static str) -> usize {
|
||||
let _ = db.sum(key);
|
||||
22
|
||||
}
|
||||
|
||||
fn sum3(db: &impl ParDatabase, key: &'static str) -> usize {
|
||||
db.sum2(key)
|
||||
}
|
||||
|
||||
fn sum3_drop_sum(db: &impl ParDatabase, key: &'static str) -> usize {
|
||||
if db.knobs().sum3_drop_sum_should_panic.get() {
|
||||
panic!("sum3_drop_sum executed")
|
||||
}
|
||||
db.sum2_drop_sum(key)
|
||||
}
|
||||
|
||||
fn snapshot_me(db: &impl ParDatabase) {
|
||||
// this should panic
|
||||
db.snapshot();
|
||||
|
@ -176,6 +222,9 @@ salsa::database_storage! {
|
|||
fn input() for Input;
|
||||
fn sum() for Sum;
|
||||
fn sum2() for Sum2;
|
||||
fn sum2_drop_sum() for Sum2Drop;
|
||||
fn sum3() for Sum3;
|
||||
fn sum3_drop_sum() for Sum3Drop;
|
||||
fn snapshot_me() for SnapshotMe;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue