use rand::Rng; use salsa::Database; use salsa::ParallelDatabase; use salsa::Snapshot; use salsa::SweepStrategy; // Number of operations a reader performs const N_MUTATOR_OPS: usize = 100; const N_READER_OPS: usize = 100; #[derive(Clone, Copy, Debug, PartialEq, Eq)] struct Canceled; type Cancelable = Result; #[salsa::query_group(Stress)] trait StressDatabase: salsa::Database { #[salsa::input] fn a(&self, key: usize) -> usize; fn b(&self, key: usize) -> Cancelable; fn c(&self, key: usize) -> Cancelable; } fn b(db: &impl StressDatabase, key: usize) -> Cancelable { if db.salsa_runtime().is_current_revision_canceled() { return Err(Canceled); } Ok(db.a(key)) } fn c(db: &impl StressDatabase, key: usize) -> Cancelable { db.b(key) } #[salsa::database(Stress)] #[derive(Default)] struct StressDatabaseImpl { runtime: salsa::Runtime, } impl salsa::Database for StressDatabaseImpl { fn salsa_runtime(&self) -> &salsa::Runtime { &self.runtime } } impl salsa::ParallelDatabase for StressDatabaseImpl { fn snapshot(&self) -> Snapshot { Snapshot::new(StressDatabaseImpl { runtime: self.runtime.snapshot(self), }) } } #[derive(Clone, Copy, Debug)] enum Query { A, B, C, } enum MutatorOp { WriteOp(WriteOp), LaunchReader { ops: Vec, check_cancellation: bool, }, } #[derive(Debug)] enum WriteOp { SetA(usize, usize), } #[derive(Debug)] enum ReadOp { Get(Query, usize), Gc(Query, SweepStrategy), GcAll(SweepStrategy), } impl rand::distributions::Distribution for rand::distributions::Standard { fn sample(&self, rng: &mut R) -> Query { *rng.choose(&[Query::A, Query::B, Query::C]).unwrap() } } impl rand::distributions::Distribution for rand::distributions::Standard { fn sample(&self, rng: &mut R) -> MutatorOp { if rng.gen_bool(0.5) { MutatorOp::WriteOp(rng.gen()) } else { MutatorOp::LaunchReader { ops: (0..N_READER_OPS).map(|_| rng.gen()).collect(), check_cancellation: rng.gen(), } } } } impl rand::distributions::Distribution for rand::distributions::Standard { fn sample(&self, rng: &mut R) -> WriteOp { let key = rng.gen::() % 10; let value = rng.gen::() % 10; return WriteOp::SetA(key, value); } } impl rand::distributions::Distribution for rand::distributions::Standard { fn sample(&self, rng: &mut R) -> ReadOp { if rng.gen_bool(0.5) { let query = rng.gen::(); let key = rng.gen::() % 10; return ReadOp::Get(query, key); } let mut strategy = SweepStrategy::default(); if rng.gen_bool(0.5) { strategy = strategy.discard_values(); } if rng.gen_bool(0.5) { ReadOp::Gc(rng.gen::(), strategy) } else { ReadOp::GcAll(strategy) } } } fn db_reader_thread(db: &StressDatabaseImpl, ops: Vec, check_cancellation: bool) { for op in ops { if check_cancellation { if db.salsa_runtime().is_current_revision_canceled() { return; } } op.execute(db); } } impl WriteOp { fn execute(self, db: &mut StressDatabaseImpl) { match self { WriteOp::SetA(key, value) => { db.query_mut(AQuery).set(key, value); } } } } impl ReadOp { fn execute(self, db: &StressDatabaseImpl) { match self { ReadOp::Get(query, key) => match query { Query::A => { db.a(key); } Query::B => { let _ = db.b(key); } Query::C => { let _ = db.c(key); } }, ReadOp::Gc(query, strategy) => match query { Query::A => { db.query(AQuery).sweep(strategy); } Query::B => { db.query(BQuery).sweep(strategy); } Query::C => { db.query(CQuery).sweep(strategy); } }, ReadOp::GcAll(strategy) => { db.sweep_all(strategy); } } } } #[test] fn stress_test() { let mut db = StressDatabaseImpl::default(); for i in 0..10 { db.query_mut(AQuery).set(i, i); } let mut rng = rand::thread_rng(); // generate the ops that the mutator thread will perform let write_ops: Vec = (0..N_MUTATOR_OPS).map(|_| rng.gen()).collect(); // execute the "main thread", which sometimes snapshots off other threads let mut all_threads = vec![]; for op in write_ops { match op { MutatorOp::WriteOp(w) => w.execute(&mut db), MutatorOp::LaunchReader { ops, check_cancellation, } => all_threads.push(std::thread::spawn({ let db = db.snapshot(); move || db_reader_thread(&db, ops, check_cancellation) })), } } for thread in all_threads { thread.join().unwrap(); } }