use rand::seq::SliceRandom; use rand::Rng; use salsa::ParallelDatabase; use salsa::Snapshot; use salsa::SweepStrategy; use salsa::{Canceled, Database}; // Number of operations a reader performs const N_MUTATOR_OPS: usize = 100; const N_READER_OPS: usize = 100; #[salsa::query_group(Stress)] trait StressDatabase: salsa::Database { #[salsa::input] fn a(&self, key: usize) -> usize; fn b(&self, key: usize) -> usize; fn c(&self, key: usize) -> usize; } fn b(db: &dyn StressDatabase, key: usize) -> usize { db.salsa_runtime().unwind_if_canceled(); db.a(key) } fn c(db: &dyn StressDatabase, key: usize) -> usize { db.b(key) } #[salsa::database(Stress)] #[derive(Default)] struct StressDatabaseImpl { storage: salsa::Storage, } impl salsa::Database for StressDatabaseImpl {} impl salsa::ParallelDatabase for StressDatabaseImpl { fn snapshot(&self) -> Snapshot { Snapshot::new(StressDatabaseImpl { storage: self.storage.snapshot(), }) } } #[derive(Clone, Copy, Debug)] enum Query { A, B, C, } enum MutatorOp { WriteOp(WriteOp), LaunchReader { ops: Vec, check_cancelation: bool, }, } #[derive(Debug)] enum WriteOp { SetA(usize, usize), } #[derive(Debug)] enum ReadOp { Get(Query, usize), Gc(Query, SweepStrategy), GcAll(SweepStrategy), } impl rand::distributions::Distribution for rand::distributions::Standard { fn sample(&self, rng: &mut R) -> Query { *[Query::A, Query::B, Query::C].choose(rng).unwrap() } } impl rand::distributions::Distribution for rand::distributions::Standard { fn sample(&self, rng: &mut R) -> MutatorOp { if rng.gen_bool(0.5) { MutatorOp::WriteOp(rng.gen()) } else { MutatorOp::LaunchReader { ops: (0..N_READER_OPS).map(|_| rng.gen()).collect(), check_cancelation: rng.gen(), } } } } impl rand::distributions::Distribution for rand::distributions::Standard { fn sample(&self, rng: &mut R) -> WriteOp { let key = rng.gen::() % 10; let value = rng.gen::() % 10; return WriteOp::SetA(key, value); } } impl rand::distributions::Distribution for rand::distributions::Standard { fn sample(&self, rng: &mut R) -> ReadOp { if rng.gen_bool(0.5) { let query = rng.gen::(); let key = rng.gen::() % 10; return ReadOp::Get(query, key); } let mut strategy = SweepStrategy::discard_outdated(); if rng.gen_bool(0.5) { strategy = strategy.discard_values(); } if rng.gen_bool(0.5) { ReadOp::Gc(rng.gen::(), strategy) } else { ReadOp::GcAll(strategy) } } } fn db_reader_thread(db: &StressDatabaseImpl, ops: Vec, check_cancelation: bool) { for op in ops { if check_cancelation { db.salsa_runtime().unwind_if_canceled(); } op.execute(db); } } impl WriteOp { fn execute(self, db: &mut StressDatabaseImpl) { match self { WriteOp::SetA(key, value) => { db.set_a(key, value); } } } } impl ReadOp { fn execute(self, db: &StressDatabaseImpl) { match self { ReadOp::Get(query, key) => match query { Query::A => { db.a(key); } Query::B => { let _ = db.b(key); } Query::C => { let _ = db.c(key); } }, ReadOp::Gc(query, strategy) => match query { Query::A => { AQuery.in_db(db).sweep(strategy); } Query::B => { BQuery.in_db(db).sweep(strategy); } Query::C => { CQuery.in_db(db).sweep(strategy); } }, ReadOp::GcAll(strategy) => { db.sweep_all(strategy); } } } } #[test] fn stress_test() { let mut db = StressDatabaseImpl::default(); for i in 0..10 { db.set_a(i, i); } let mut rng = rand::thread_rng(); // generate the ops that the mutator thread will perform let write_ops: Vec = (0..N_MUTATOR_OPS).map(|_| rng.gen()).collect(); // execute the "main thread", which sometimes snapshots off other threads let mut all_threads = vec![]; for op in write_ops { match op { MutatorOp::WriteOp(w) => w.execute(&mut db), MutatorOp::LaunchReader { ops, check_cancelation, } => all_threads.push(std::thread::spawn({ let db = db.snapshot(); move || Canceled::catch(|| db_reader_thread(&db, ops, check_cancelation)) })), } } for thread in all_threads { thread.join().unwrap().ok(); } }