diff --git a/src/runtime.rs b/src/runtime.rs index 22f84c9c..a1ab94cb 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -1,7 +1,7 @@ use crate::{Database, Event, EventKind, SweepStrategy}; use lock_api::{RawRwLock, RawRwLockRecursive}; use log::debug; -use parking_lot::{Mutex, RwLock, RwLockReadGuard}; +use parking_lot::{Mutex, RwLock}; use rustc_hash::{FxHashMap, FxHasher}; use smallvec::SmallVec; use std::cell::RefCell; @@ -95,14 +95,11 @@ where counter: self.shared_state.next_id.fetch_add(1, Ordering::SeqCst), }; - let mut local_state = LocalState::default(); - local_state.query_in_progress = true; - Runtime { id, revision_guard: Some(revision_guard), shared_state: self.shared_state.clone(), - local_state: RefCell::new(local_state), + local_state: Default::default(), } } @@ -189,7 +186,7 @@ where pub(crate) fn with_incremented_revision(&self, op: impl FnOnce(Revision) -> R) -> R { log::debug!("increment_revision()"); - if self.query_in_progress() { + if !self.permits_increment() { panic!("increment_revision invoked during a query computation"); } @@ -228,8 +225,8 @@ where op(new_revision) } - pub(crate) fn query_in_progress(&self) -> bool { - self.local_state.borrow().query_in_progress + pub(crate) fn permits_increment(&self) -> bool { + self.revision_guard.is_none() && self.local_state.borrow().query_stack.is_empty() } pub(crate) fn execute_query_implementation( @@ -398,38 +395,17 @@ impl Default for SharedState { /// State that will be specific to a single execution threads (when we /// support multiple threads) struct LocalState { - query_in_progress: bool, query_stack: Vec>, } impl Default for LocalState { fn default() -> Self { LocalState { - query_in_progress: false, query_stack: Default::default(), } } } -pub(crate) struct QueryGuard<'db, DB: Database + 'db> { - db: &'db Runtime, - lock: RwLockReadGuard<'db, ()>, -} - -impl<'db, DB: Database> QueryGuard<'db, DB> { - fn new(db: &'db Runtime, lock: RwLockReadGuard<'db, ()>) -> Self { - Self { db, lock } - } -} - -impl<'db, DB: Database> Drop for QueryGuard<'db, DB> { - fn drop(&mut self) { - let mut local_state = self.db.local_state.borrow_mut(); - assert!(local_state.query_in_progress); - local_state.query_in_progress = false; - } -} - struct ActiveQuery { /// What query is executing descriptor: DB::QueryDescriptor, diff --git a/tests/panic_safely.rs b/tests/panic_safely.rs index e354aa7f..1f06e9d9 100644 --- a/tests/panic_safely.rs +++ b/tests/panic_safely.rs @@ -1,4 +1,4 @@ -use salsa::Database; +use salsa::{Database, Frozen, ParallelDatabase}; use std::panic::{self, AssertUnwindSafe}; salsa::query_group! { @@ -29,6 +29,14 @@ impl salsa::Database for DatabaseStruct { } } +impl salsa::ParallelDatabase for DatabaseStruct { + fn fork(&self) -> Frozen { + Frozen::new(DatabaseStruct { + runtime: self.runtime.fork(self), + }) + } +} + salsa::database_storage! { struct DatabaseStorage for DatabaseStruct { impl PanicSafelyDatabase { @@ -44,7 +52,10 @@ fn should_panic_safely() { // Invoke `db.panic_safely() without having set `db.one`. `db.one` will // default to 0 and we should catch the panic. - let result = panic::catch_unwind(AssertUnwindSafe(|| db.panic_safely())); + let result = panic::catch_unwind(AssertUnwindSafe({ + let db = db.fork(); + move || db.panic_safely() + })); assert!(result.is_err()); // Set `db.one` to 1 and assert ok