diff --git a/tests/parallel/cancellation.rs b/tests/parallel/cancellation.rs index 3b73bb17..73afbad1 100644 --- a/tests/parallel/cancellation.rs +++ b/tests/parallel/cancellation.rs @@ -118,3 +118,40 @@ fn no_back_dating_in_cancellation() { db.query_mut(Input).set('a', 4); assert_eq!(db.sum3("ab"), 6); } + +/// Here, we compute `sum3_drop_sum` and -- in the process -- observe +/// a cancellation. As a result, we have to recompute `sum` when we +/// reinvoke `sum3_drop_sum` and we have to re-execute +/// `sum2_drop_sum`. But the result of `sum2_drop_sum` doesn't +/// change, so we don't have to re-execute `sum3_drop_sum`. +#[test] +fn transitive_cancellation() { + let mut db = ParDatabaseImpl::default(); + + db.query_mut(Input).set('a', 1); + let thread1 = std::thread::spawn({ + let db = db.snapshot(); + move || { + // Here we compute a long-chain of queries, + // but the last one gets cancelled. + db.knobs().sum_signal_on_entry.with_value(1, || { + db.knobs() + .sum_wait_for_cancellation + .with_value(true, || db.sum3_drop_sum("a")) + }) + } + }); + + db.wait_for(1); + + db.query_mut(Input).set('b', 2); + + // Check that when we call `sum3_drop_sum` we don't wind up having + // to actually re-execute it, because the result of `sum2` winds + // up not changing. + db.knobs().sum3_drop_sum_should_panic.with_value(true, || { + assert_eq!(db.sum3_drop_sum("a"), 22); + }); + + assert_eq!(thread1.join().unwrap(), 22); +} diff --git a/tests/parallel/setup.rs b/tests/parallel/setup.rs index f66986f1..100c4e05 100644 --- a/tests/parallel/setup.rs +++ b/tests/parallel/setup.rs @@ -16,14 +16,26 @@ salsa::query_group! { type Sum; } + /// Invokes `sum` fn sum2(key: &'static str) -> usize { type Sum2; } + /// Invokes `sum` but doesn't really care about the result. + fn sum2_drop_sum(key: &'static str) -> usize { + type Sum2Drop; + } + + /// Invokes `sum2` fn sum3(key: &'static str) -> usize { type Sum3; } + /// Invokes `sum2_drop_sum` + fn sum3_drop_sum(key: &'static str) -> usize { + type Sum3Drop; + } + fn snapshot_me() -> () { type SnapshotMe; } @@ -86,6 +98,9 @@ pub(crate) struct KnobsStruct { /// Invocations of `sum` will signal this stage prior to exiting. pub(crate) sum_signal_on_exit: Cell, + + /// Invocations of `sum3_drop_sum` will panic unconditionally + pub(crate) sum3_drop_sum_should_panic: Cell, } fn sum(db: &impl ParDatabase, key: &'static str) -> usize { @@ -134,10 +149,22 @@ fn sum2(db: &impl ParDatabase, key: &'static str) -> usize { db.sum(key) } +fn sum2_drop_sum(db: &impl ParDatabase, key: &'static str) -> usize { + let _ = db.sum(key); + 22 +} + fn sum3(db: &impl ParDatabase, key: &'static str) -> usize { db.sum2(key) } +fn sum3_drop_sum(db: &impl ParDatabase, key: &'static str) -> usize { + if db.knobs().sum3_drop_sum_should_panic.get() { + panic!("sum3_drop_sum executed") + } + db.sum2_drop_sum(key) +} + fn snapshot_me(db: &impl ParDatabase) { // this should panic db.snapshot(); @@ -195,7 +222,9 @@ salsa::database_storage! { fn input() for Input; fn sum() for Sum; fn sum2() for Sum2; + fn sum2_drop_sum() for Sum2Drop; fn sum3() for Sum3; + fn sum3_drop_sum() for Sum3Drop; fn snapshot_me() for SnapshotMe; } }