mirror of
https://github.com/salsa-rs/salsa.git
synced 2025-02-02 09:46:06 +00:00
parent
f56c341730
commit
f709e64bd5
2 changed files with 63 additions and 1 deletions
|
@ -55,3 +55,57 @@ fn in_par_get_set_cancellation() {
|
||||||
assert_eq!(thread1.join().unwrap(), (std::usize::MAX, 111));
|
assert_eq!(thread1.join().unwrap(), (std::usize::MAX, 111));
|
||||||
assert_eq!(thread2.join().unwrap(), 1000);
|
assert_eq!(thread2.join().unwrap(), 1000);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Here, we check that `sum`'s cancellation is propagated
|
||||||
|
/// to `sum2` properly.
|
||||||
|
#[test]
|
||||||
|
fn in_par_get_set_transitive_cancellation() {
|
||||||
|
let db = ParDatabaseImpl::default();
|
||||||
|
|
||||||
|
db.query(Input).set('a', 100);
|
||||||
|
db.query(Input).set('b', 010);
|
||||||
|
db.query(Input).set('c', 001);
|
||||||
|
db.query(Input).set('d', 0);
|
||||||
|
|
||||||
|
let thread1 = std::thread::spawn({
|
||||||
|
let db = db.fork();
|
||||||
|
move || {
|
||||||
|
let v1 = db.knobs().sum_signal_on_entry.with_value(1, || {
|
||||||
|
db.knobs()
|
||||||
|
.sum_wait_for_cancellation
|
||||||
|
.with_value(true, || db.sum2("abc"))
|
||||||
|
});
|
||||||
|
|
||||||
|
// check that we observed cancellation
|
||||||
|
assert_eq!(v1, std::usize::MAX);
|
||||||
|
|
||||||
|
// at this point, we have observed cancellation, so let's
|
||||||
|
// wait until the `set` is known to have occurred.
|
||||||
|
db.wait_for(2);
|
||||||
|
|
||||||
|
// Now when we read we should get the correct sums. Note
|
||||||
|
// in particular that we re-compute the sum of `"abc"`
|
||||||
|
// even though none of our inputs have changed.
|
||||||
|
let v2 = db.sum2("abc");
|
||||||
|
(v1, v2)
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let thread2 = std::thread::spawn({
|
||||||
|
let db = db.fork();
|
||||||
|
move || {
|
||||||
|
// Wait until we have entered `sum` in the other thread.
|
||||||
|
db.wait_for(1);
|
||||||
|
|
||||||
|
db.query(Input).set('d', 1000);
|
||||||
|
|
||||||
|
// Signal that we have *set* `d`
|
||||||
|
db.signal(2);
|
||||||
|
|
||||||
|
db.sum("d")
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
assert_eq!(thread1.join().unwrap(), (std::usize::MAX, 111));
|
||||||
|
assert_eq!(thread2.join().unwrap(), 1000);
|
||||||
|
}
|
||||||
|
|
|
@ -13,7 +13,10 @@ salsa::query_group! {
|
||||||
|
|
||||||
fn sum(key: &'static str) -> usize {
|
fn sum(key: &'static str) -> usize {
|
||||||
type Sum;
|
type Sum;
|
||||||
use fn sum;
|
}
|
||||||
|
|
||||||
|
fn sum2(key: &'static str) -> usize {
|
||||||
|
type Sum2;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -97,6 +100,10 @@ fn sum(db: &impl ParDatabase, key: &'static str) -> usize {
|
||||||
sum
|
sum
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn sum2(db: &impl ParDatabase, key: &'static str) -> usize {
|
||||||
|
sum(db, key)
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
pub struct ParDatabaseImpl {
|
pub struct ParDatabaseImpl {
|
||||||
runtime: salsa::Runtime<ParDatabaseImpl>,
|
runtime: salsa::Runtime<ParDatabaseImpl>,
|
||||||
|
@ -137,6 +144,7 @@ salsa::database_storage! {
|
||||||
impl ParDatabase {
|
impl ParDatabase {
|
||||||
fn input() for Input;
|
fn input() for Input;
|
||||||
fn sum() for Sum;
|
fn sum() for Sum;
|
||||||
|
fn sum2() for Sum2;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue