Merge pull request #105 from nikomatsakis/issue-66-transitive-cancelation

Issue 66 transitive cancelation
This commit is contained in:
Niko Matsakis 2019-01-04 13:57:09 -05:00 committed by GitHub
commit aaa50e01d6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 287 additions and 40 deletions

View file

@ -281,6 +281,13 @@ where
) -> Result<StampedValue<Q::Value>, CycleDetected> {
let runtime = db.salsa_runtime();
debug!(
"{:?}({:?}): read_upgrade(revision_now={:?})",
Q::default(),
key,
revision_now,
);
// Check with an upgradable read to see if there is a value
// already. (This permits other readers but prevents anyone
// else from running `read_upgrade` at the same time.)
@ -355,6 +362,13 @@ where
if let Some(old_memo) = &old_memo {
if let Some(old_value) = &old_memo.value {
if MP::memoized_value_eq(&old_value, &result.value) {
debug!(
"read_upgrade({:?}({:?})): value is equal, back-dating to {:?}",
Q::default(),
key,
old_memo.changed_at,
);
assert!(old_memo.changed_at <= result.changed_at.revision);
result.changed_at.revision = old_memo.changed_at;
}
@ -372,6 +386,14 @@ where
None
};
debug!(
"read_upgrade({:?}({:?})): result.changed_at={:?}, result.subqueries = {:#?}",
Q::default(),
key,
result.changed_at,
result.subqueries,
);
let inputs = match result.subqueries {
None => MemoInputs::Untracked,
@ -666,7 +688,7 @@ where
let revision_now = runtime.current_revision();
debug!(
"{:?}({:?})::maybe_changed_since(revision={:?}, revision_now={:?})",
"maybe_changed_since({:?}({:?})) called with revision={:?}, revision_now={:?}",
Q::default(),
key,
revision,
@ -682,13 +704,26 @@ where
// If somebody depends on us, but we have no map
// entry, that must mean that it was found to be out
// of date and removed.
None => return true,
None => {
debug!(
"maybe_changed_since({:?}({:?}): no value",
Q::default(),
key,
);
return true;
}
// This value is being actively recomputed. Wait for
// that thread to finish (assuming it's not dependent
// on us...) and check its associated revision.
Some(QueryState::InProgress { id, waiting }) => {
let other_id = *id;
debug!(
"maybe_changed_since({:?}({:?}): blocking on thread `{:?}`",
Q::default(),
key,
other_id,
);
match self.register_with_in_progress_thread(runtime, descriptor, other_id, waiting)
{
Ok(rx) => {
@ -709,6 +744,13 @@ where
};
if memo.verified_at == revision_now {
debug!(
"maybe_changed_since({:?}({:?}): {:?} since up-to-date memo that changed at {:?}",
Q::default(),
key,
memo.changed_at > revision,
memo.changed_at,
);
return memo.changed_at > revision;
}
@ -718,6 +760,11 @@ where
// inputs, so if there is a new
// revision, we must assume it is
// dirty
debug!(
"maybe_changed_since({:?}({:?}): true since untracked inputs",
Q::default(),
key,
);
return true;
}
@ -734,7 +781,16 @@ where
if memo.value.is_some() {
std::mem::drop(map);
return match self.read_upgrade(db, key, descriptor, revision_now) {
Ok(v) => v.changed_at.changed_since(revision),
Ok(v) => {
debug!(
"maybe_changed_since({:?}({:?}): {:?} since (recomputed) value changed at {:?}",
Q::default(),
key,
v.changed_at.changed_since(revision),
v.changed_at,
);
v.changed_at.changed_since(revision)
}
Err(CycleDetected) => true,
};
}
@ -927,6 +983,12 @@ where
assert!(self.verified_at != revision_now);
let verified_at = self.verified_at;
debug!(
"validate_memoized_value({:?}): verified_at={:#?}",
Q::default(),
self.inputs,
);
let is_constant = match &mut self.inputs {
// We can't validate values that had untracked inputs; just have to
// re-execute.

View file

@ -159,6 +159,14 @@ where
}
}
/// Read current value of the revision counter.
#[inline]
fn pending_revision(&self) -> Revision {
Revision {
generation: self.shared_state.pending_revision.load(Ordering::SeqCst) as u64,
}
}
/// Check if the current revision is canceled. If this method ever
/// returns true, the currently executing query is also marked as
/// having an *untracked read* -- this means that, in the next
@ -169,14 +177,48 @@ where
/// whatever it likes.
#[inline]
pub fn is_current_revision_canceled(&self) -> bool {
let pending_revision_increments = self
.shared_state
.pending_revision_increments
.load(Ordering::SeqCst);
if pending_revision_increments > 0 {
let current_revision = self.current_revision();
let pending_revision = self.pending_revision();
debug!(
"is_current_revision_canceled: current_revision={:?}, pending_revision={:?}",
current_revision, pending_revision
);
if pending_revision > current_revision {
self.report_untracked_read();
true
} else {
// Subtle: If the current revision is not canceled, we
// still report an **anonymous** read, which will bump up
// the revision number to be at least the last
// non-canceled revision. This is needed to ensure
// deterministic reads and avoid salsa-rs/salsa#66. The
// specific scenario we are trying to avoid is tested by
// `no_back_dating_in_cancellation`; it works like
// this. Imagine we have 3 queries, where Query3 invokes
// Query2 which invokes Query1. Then:
//
// - In Revision R1:
// - Query1: Observes cancelation and returns sentinel S.
// - Recorded inputs: Untracked, because we observed cancelation.
// - Query2: Reads Query1 and propagates sentinel S.
// - Recorded inputs: Query1, changed-at=R1
// - Query3: Reads Query2 and propagates sentinel S. (Inputs = Query2, ChangedAt R1)
// - Recorded inputs: Query2, changed-at=R1
// - In Revision R2:
// - Query1: Observes no cancelation. All of its inputs last changed in R0,
// so it returns a valid value with "changed at" of R0.
// - Recorded inputs: ..., changed-at=R0
// - Query2: Recomputes its value and returns correct result.
// - Recorded inputs: Query1, changed-at=R0 <-- key problem!
// - Query3: sees that Query2's result last changed in R0, so it thinks it
// can re-use its value from R1 (which is the sentinel value).
//
// The anonymous read here prevents that scenario: Query1
// winds up with a changed-at setting of R2, which is the
// "pending revision", and hence Query2 and Query3
// are recomputed.
assert_eq!(pending_revision, current_revision);
self.report_anon_read(pending_revision);
false
}
}
@ -190,6 +232,10 @@ where
/// method will also increment `pending_revision_increments`, thus
/// signalling to queries that their results are "canceled" and
/// they should abort as expeditiously as possible.
///
/// Note that, given our writer model, we can assume that only one
/// thread is attempting to increment the global revision at a
/// time.
pub(crate) fn with_incremented_revision<R>(&self, op: impl FnOnce(Revision) -> R) -> R {
log::debug!("increment_revision()");
@ -197,36 +243,24 @@ where
panic!("increment_revision invoked during a query computation");
}
// Signal that we have a pending increment so that workers can
// start to cancel work.
let old_pending_revision_increments = self
// Set the `pending_revision` field so that people
// know current revision is canceled.
let current_revision = self
.shared_state
.pending_revision_increments
.pending_revision
.fetch_add(1, Ordering::SeqCst);
assert!(
old_pending_revision_increments != usize::max_value(),
"pending increment overflow"
);
assert!(current_revision != usize::max_value(), "revision overflow");
// To modify the revision, we need the lock.
let _lock = self.shared_state.query_lock.write();
// *Before* updating the revision number, decrement the
// `pending_revision_increments` counter. This way, if anybody
// should happen to invoke `is_current_revision_canceled`
// before we update the number, and they read 0, they don't
// get an incorrect result -- once they acquire the query
// lock, we'll be in the new revision.
self.shared_state
.pending_revision_increments
.fetch_sub(1, Ordering::SeqCst);
let old_revision = self.shared_state.revision.fetch_add(1, Ordering::SeqCst);
assert!(old_revision != usize::max_value(), "revision overflow");
assert_eq!(current_revision, old_revision);
let new_revision = Revision {
generation: 1 + old_revision as u64,
generation: (current_revision + 1) as u64,
};
debug!("increment_revision: incremented to {:?}", new_revision);
op(new_revision)
@ -308,6 +342,20 @@ where
}
}
/// An "anonymous" read is a read that doesn't come from executing
/// a query, but from some other internal operation. It just
/// modifies the "changed at" to be at least the given revision.
/// (It also does not disqualify a query from being considered
/// constant, since it is used for queries that don't give back
/// actual *data*.)
///
/// This is used when queries check if they have been canceled.
fn report_anon_read(&self, revision: Revision) {
if let Some(top_query) = self.local_state.borrow_mut().query_stack.last_mut() {
top_query.add_anon_read(revision);
}
}
/// Obviously, this should be user configurable at some point.
pub(crate) fn report_unexpected_cycle(&self, descriptor: DB::QueryDescriptor) -> ! {
debug!("report_unexpected_cycle(descriptor={:?})", descriptor);
@ -374,12 +422,10 @@ struct SharedState<DB: Database> {
/// (Ideally, this should be `AtomicU64`, but that is currently unstable.)
revision: AtomicUsize,
/// Counts the number of pending increments to the revision
/// counter. If this is non-zero, it means that the current
/// revision is out of date, and hence queries are free to
/// "short-circuit" their results if they learn that. See
/// `is_current_revision_canceled` for more information.
pending_revision_increments: AtomicUsize,
/// This is typically equal to `revision` -- set to `revision+1`
/// when a new revision is pending (which implies that the current
/// revision is canceled).
pending_revision: AtomicUsize,
/// The dependency graph tracks which runtimes are blocked on one
/// another, waiting for queries to terminate.
@ -393,8 +439,8 @@ impl<DB: Database> Default for SharedState<DB> {
storage: Default::default(),
query_lock: Default::default(),
revision: Default::default(),
pending_revision: Default::default(),
dependency_graph: Default::default(),
pending_revision_increments: Default::default(),
}
}
}
@ -414,7 +460,7 @@ where
fmt.debug_struct("SharedState")
.field("query_lock", &query_lock)
.field("revision", &self.revision)
.field("pending_revision_increments", &self.pending_revision_increments)
.field("pending_revision", &self.pending_revision)
.finish()
}
}
@ -501,6 +547,10 @@ impl<DB: Database> ActiveQuery<DB> {
self.changed_at.is_constant = false;
self.changed_at.revision = changed_at;
}
fn add_anon_read(&mut self, changed_at: Revision) {
self.changed_at.revision = self.changed_at.revision.max(changed_at);
}
}
/// A unique identifier for a particular runtime. Each time you create
@ -523,6 +573,17 @@ pub struct Revision {
impl Revision {
pub(crate) const ZERO: Self = Revision { generation: 0 };
fn next(self) -> Revision {
Revision {
generation: self.generation + 1,
}
}
fn as_usize(self) -> usize {
assert!(self.generation < (std::usize::MAX as u64));
self.generation as usize
}
}
impl std::fmt::Debug for Revision {

View file

@ -66,10 +66,8 @@ fn revalidate() {
// Second generation: volatile will change (to 1) but memoized1
// will not (still 0, as 1/2 = 0)
query.salsa_runtime().next_revision();
query.memoized2();
query.assert_log(&["Memoized1 invoked", "Volatile invoked"]);
query.memoized2();
query.assert_log(&[]);

View file

@ -83,3 +83,75 @@ fn in_par_get_set_cancellation_transitive() {
assert_eq!(thread1.join().unwrap(), std::usize::MAX);
assert_eq!(thread2.join().unwrap(), 111);
}
/// https://github.com/salsa-rs/salsa/issues/66
#[test]
fn no_back_dating_in_cancellation() {
let mut db = ParDatabaseImpl::default();
db.query_mut(Input).set('a', 1);
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || {
// Here we compute a long-chain of queries,
// but the last one gets cancelled.
db.knobs().sum_signal_on_entry.with_value(1, || {
db.knobs()
.sum_wait_for_cancellation
.with_value(true, || db.sum3("a"))
})
}
});
db.wait_for(1);
// Set unrelated input to bump revision
db.query_mut(Input).set('b', 2);
// Here we should recompuet the whole chain again, clearing the cancellation
// state. If we get `usize::max()` here, it is a bug!
assert_eq!(db.sum3("a"), 1);
assert_eq!(thread1.join().unwrap(), std::usize::MAX);
db.query_mut(Input).set('a', 3);
db.query_mut(Input).set('a', 4);
assert_eq!(db.sum3("ab"), 6);
}
/// Here, we compute `sum3_drop_sum` and -- in the process -- observe
/// a cancellation. As a result, we have to recompute `sum` when we
/// reinvoke `sum3_drop_sum` and we have to re-execute
/// `sum2_drop_sum`. But the result of `sum2_drop_sum` doesn't
/// change, so we don't have to re-execute `sum3_drop_sum`.
#[test]
fn transitive_cancellation() {
let mut db = ParDatabaseImpl::default();
db.query_mut(Input).set('a', 1);
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || {
// Here we compute a long-chain of queries,
// but the last one gets cancelled.
db.knobs().sum_signal_on_entry.with_value(1, || {
db.knobs()
.sum_wait_for_cancellation
.with_value(true, || db.sum3_drop_sum("a"))
})
}
});
db.wait_for(1);
db.query_mut(Input).set('b', 2);
// Check that when we call `sum3_drop_sum` we don't wind up having
// to actually re-execute it, because the result of `sum2` winds
// up not changing.
db.knobs().sum3_drop_sum_should_panic.with_value(true, || {
assert_eq!(db.sum3_drop_sum("a"), 22);
});
assert_eq!(thread1.join().unwrap(), 22);
}

View file

@ -25,9 +25,14 @@ fn in_par_get_set_race() {
});
// If the 1st thread runs first, you get 111, otherwise you get
// 1011.
// 1011; if they run concurrently and the 1st thread observes the
// cancelation, you get back usize::max.
let value1 = thread1.join().unwrap();
assert!(value1 == 111 || value1 == 1011, "illegal result {}", value1);
assert!(
value1 == 111 || value1 == 1011 || value1 == std::usize::MAX,
"illegal result {}",
value1
);
assert_eq!(thread2.join().unwrap(), 1000);
}

View file

@ -16,10 +16,26 @@ salsa::query_group! {
type Sum;
}
/// Invokes `sum`
fn sum2(key: &'static str) -> usize {
type Sum2;
}
/// Invokes `sum` but doesn't really care about the result.
fn sum2_drop_sum(key: &'static str) -> usize {
type Sum2Drop;
}
/// Invokes `sum2`
fn sum3(key: &'static str) -> usize {
type Sum3;
}
/// Invokes `sum2_drop_sum`
fn sum3_drop_sum(key: &'static str) -> usize {
type Sum3Drop;
}
fn snapshot_me() -> () {
type SnapshotMe;
}
@ -82,6 +98,9 @@ pub(crate) struct KnobsStruct {
/// Invocations of `sum` will signal this stage prior to exiting.
pub(crate) sum_signal_on_exit: Cell<usize>,
/// Invocations of `sum3_drop_sum` will panic unconditionally
pub(crate) sum3_drop_sum_should_panic: Cell<bool>,
}
fn sum(db: &impl ParDatabase, key: &'static str) -> usize {
@ -105,6 +124,17 @@ fn sum(db: &impl ParDatabase, key: &'static str) -> usize {
std::thread::yield_now();
}
log::debug!("cancellation observed");
}
// Check for cancelation and return MAX if so. Note that we check
// for cancelation *deterministically* -- but if
// `sum_wait_for_cancellation` is set, we will block
// beforehand. Deterministic execution is a requirement for valid
// salsa user code. It's also important to some tests that `sum`
// *attempts* to invoke `is_current_revision_canceled` even if we
// know it will not be canceled, because that helps us keep the
// accounting up to date.
if db.salsa_runtime().is_current_revision_canceled() {
return std::usize::MAX; // when we are cancelled, we return usize::MAX.
}
@ -119,6 +149,22 @@ fn sum2(db: &impl ParDatabase, key: &'static str) -> usize {
db.sum(key)
}
fn sum2_drop_sum(db: &impl ParDatabase, key: &'static str) -> usize {
let _ = db.sum(key);
22
}
fn sum3(db: &impl ParDatabase, key: &'static str) -> usize {
db.sum2(key)
}
fn sum3_drop_sum(db: &impl ParDatabase, key: &'static str) -> usize {
if db.knobs().sum3_drop_sum_should_panic.get() {
panic!("sum3_drop_sum executed")
}
db.sum2_drop_sum(key)
}
fn snapshot_me(db: &impl ParDatabase) {
// this should panic
db.snapshot();
@ -176,6 +222,9 @@ salsa::database_storage! {
fn input() for Input;
fn sum() for Sum;
fn sum2() for Sum2;
fn sum2_drop_sum() for Sum2Drop;
fn sum3() for Sum3;
fn sum3_drop_sum() for Sum3Drop;
fn snapshot_me() for SnapshotMe;
}
}