diff --git a/tests/parallel/cancellation.rs b/tests/parallel/cancellation.rs index b632989b..76e883f4 100644 --- a/tests/parallel/cancellation.rs +++ b/tests/parallel/cancellation.rs @@ -55,3 +55,57 @@ fn in_par_get_set_cancellation() { assert_eq!(thread1.join().unwrap(), (std::usize::MAX, 111)); assert_eq!(thread2.join().unwrap(), 1000); } + +/// Here, we check that `sum`'s cancellation is propagated +/// to `sum2` properly. +#[test] +fn in_par_get_set_transitive_cancellation() { + let db = ParDatabaseImpl::default(); + + db.query(Input).set('a', 100); + db.query(Input).set('b', 010); + db.query(Input).set('c', 001); + db.query(Input).set('d', 0); + + let thread1 = std::thread::spawn({ + let db = db.fork(); + move || { + let v1 = db.knobs().sum_signal_on_entry.with_value(1, || { + db.knobs() + .sum_wait_for_cancellation + .with_value(true, || db.sum2("abc")) + }); + + // check that we observed cancellation + assert_eq!(v1, std::usize::MAX); + + // at this point, we have observed cancellation, so let's + // wait until the `set` is known to have occurred. + db.wait_for(2); + + // Now when we read we should get the correct sums. Note + // in particular that we re-compute the sum of `"abc"` + // even though none of our inputs have changed. + let v2 = db.sum2("abc"); + (v1, v2) + } + }); + + let thread2 = std::thread::spawn({ + let db = db.fork(); + move || { + // Wait until we have entered `sum` in the other thread. + db.wait_for(1); + + db.query(Input).set('d', 1000); + + // Signal that we have *set* `d` + db.signal(2); + + db.sum("d") + } + }); + + assert_eq!(thread1.join().unwrap(), (std::usize::MAX, 111)); + assert_eq!(thread2.join().unwrap(), 1000); +} diff --git a/tests/parallel/setup.rs b/tests/parallel/setup.rs index 9aa7fe7a..be89f720 100644 --- a/tests/parallel/setup.rs +++ b/tests/parallel/setup.rs @@ -13,7 +13,10 @@ salsa::query_group! { fn sum(key: &'static str) -> usize { type Sum; - use fn sum; + } + + fn sum2(key: &'static str) -> usize { + type Sum2; } } } @@ -97,6 +100,10 @@ fn sum(db: &impl ParDatabase, key: &'static str) -> usize { sum } +fn sum2(db: &impl ParDatabase, key: &'static str) -> usize { + sum(db, key) +} + #[derive(Default)] pub struct ParDatabaseImpl { runtime: salsa::Runtime, @@ -137,6 +144,7 @@ salsa::database_storage! { impl ParDatabase { fn input() for Input; fn sum() for Sum; + fn sum2() for Sum2; } } }