diff --git a/tests/parallel/cancellation.rs b/tests/parallel/cancellation.rs index 2ca97894..a60e663e 100644 --- a/tests/parallel/cancellation.rs +++ b/tests/parallel/cancellation.rs @@ -16,8 +16,9 @@ fn in_par_get_set_cancellation() { let thread1 = std::thread::spawn({ let db = db.fork(); move || { - let v1 = db.sum_signal_on_entry().with_value(1, || { - db.sum_await_cancellation() + let v1 = db.knobs().sum_signal_on_entry.with_value(1, || { + db.knobs() + .sum_await_cancellation .with_value(true, || db.sum("abc")) }); @@ -26,7 +27,7 @@ fn in_par_get_set_cancellation() { // at this point, we have observed cancellation, so let's // wait until the `set` is known to have occurred. - db.signal().await(2); + db.await(2); // Now when we read we should get the correct sums. Note // in particular that we re-compute the sum of `"abc"` @@ -40,12 +41,12 @@ fn in_par_get_set_cancellation() { let db = db.fork(); move || { // Wait until we have entered `sum` in the other thread. - db.signal().await(1); + db.await(1); db.query(Input).set('d', 1000); // Signal that we have *set* `d` - db.signal().signal(2); + db.signal(2); db.sum("d") } diff --git a/tests/parallel/setup.rs b/tests/parallel/setup.rs index 67f6413e..cf4927e6 100644 --- a/tests/parallel/setup.rs +++ b/tests/parallel/setup.rs @@ -21,14 +21,11 @@ salsa::query_group! { /// Various "knobs" and utilities used by tests to force /// a certain behavior. pub(crate) trait Knobs { - fn signal(&self) -> &Signal; + fn knobs(&self) -> &KnobsStruct; - /// Invocations of `sum` will signal `stage` this stage on entry. - fn sum_signal_on_entry(&self) -> &Cell; + fn signal(&self, stage: usize); - /// If set to true, invocations of `sum` will await cancellation - /// before they exit. - fn sum_await_cancellation(&self) -> &Cell; + fn await(&self, stage: usize); } pub(crate) trait WithValue { @@ -47,11 +44,30 @@ impl WithValue for Cell { } } +/// Various "knobs" that can be used to customize how the queries +/// behave on one specific thread. Note that this state is +/// intentionally thread-local (apart from `signal`). #[derive(Clone, Default)] -struct KnobsStruct { - signal: Arc, - sum_signal_on_entry: Cell, - sum_await_cancellation: Cell, +pub(crate) struct KnobsStruct { + /// A kind of flexible barrier used to coordinate execution across + /// threads to ensure we reach various weird states. + pub(crate) signal: Arc, + + /// Invocations of `sum` will signal this stage on entry. + pub(crate) sum_signal_on_entry: Cell, + + /// Invocations of `sum` will await this stage on entry. + pub(crate) sum_await_on_entry: Cell, + + /// If true, invocations of `sum` will await cancellation before + /// they exit. + pub(crate) sum_await_cancellation: Cell, + + /// Invocations of `sum` will await this stage prior to exiting. + pub(crate) sum_await_on_exit: Cell, + + /// Invocations of `sum` will signal this stage prior to exiting. + pub(crate) sum_signal_on_exit: Cell, } #[derive(Default)] @@ -63,10 +79,18 @@ pub(crate) struct Signal { impl Signal { pub(crate) fn signal(&self, stage: usize) { log::debug!("signal({})", stage); - let mut v = self.value.lock(); - if stage > *v { - *v = stage; - self.cond_var.notify_all(); + + // This check avoids acquiring the lock for things that will + // clearly be a no-op. Not *necessary* but helps to ensure we + // are more likely to encounter weird race conditions; + // otherwise calls to `sum` will tend to be unnecessarily + // synchronous. + if stage > 0 { + let mut v = self.value.lock(); + if stage > *v { + *v = stage; + self.cond_var.notify_all(); + } } } @@ -74,9 +98,13 @@ impl Signal { /// with the current stage. pub(crate) fn await(&self, stage: usize) { log::debug!("await({})", stage); - let mut v = self.value.lock(); - while *v < stage { - self.cond_var.wait(&mut v); + + // As above, avoid lock if clearly a no-op. + if stage > 0 { + let mut v = self.value.lock(); + while *v < stage { + self.cond_var.wait(&mut v); + } } } } @@ -84,14 +112,15 @@ impl Signal { fn sum(db: &impl ParDatabase, key: &'static str) -> usize { let mut sum = 0; - let stage = db.sum_signal_on_entry().get(); - db.signal().signal(stage); + db.signal(db.knobs().sum_signal_on_entry.get()); + + db.await(db.knobs().sum_await_on_entry.get()); for ch in key.chars() { sum += db.input(ch); } - if db.sum_await_cancellation().get() { + if db.knobs().sum_await_cancellation.get() { log::debug!("awaiting cancellation"); while !db.salsa_runtime().is_current_revision_canceled() { std::thread::yield_now(); @@ -100,6 +129,10 @@ fn sum(db: &impl ParDatabase, key: &'static str) -> usize { return std::usize::MAX; // when we are cancelled, we return usize::MAX. } + db.await(db.knobs().sum_await_on_exit.get()); + + db.signal(db.knobs().sum_signal_on_exit.get()); + sum } @@ -125,16 +158,16 @@ impl ParallelDatabase for ParDatabaseImpl { } impl Knobs for ParDatabaseImpl { - fn signal(&self) -> &Signal { - &self.knobs.signal + fn knobs(&self) -> &KnobsStruct { + &self.knobs } - fn sum_signal_on_entry(&self) -> &Cell { - &self.knobs.sum_signal_on_entry + fn signal(&self, stage: usize) { + self.knobs.signal.signal(stage); } - fn sum_await_cancellation(&self) -> &Cell { - &self.knobs.sum_await_cancellation + fn await(&self, stage: usize) { + self.knobs.signal.await(stage); } }