diff --git a/src/runtime.rs b/src/runtime.rs index ab909d6b..d63ba8c9 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -4,7 +4,6 @@ use log::debug; use parking_lot::{Mutex, RwLock}; use rustc_hash::{FxHashMap, FxHasher}; use smallvec::SmallVec; -use std::cell::RefCell; use std::fmt::Write; use std::hash::BuildHasherDefault; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -12,6 +11,9 @@ use std::sync::Arc; pub(crate) type FxIndexSet = indexmap::IndexSet>; +mod local_state; +use local_state::LocalState; + /// The salsa runtime stores the storage for all queries as well as /// tracking the query stack and dependencies between cycles. /// @@ -29,7 +31,7 @@ pub struct Runtime { revision_guard: Option>, /// Local state that is specific to this runtime (thread). - local_state: RefCell>, + local_state: LocalState, /// Shared state that is accessible via all runtimes. shared_state: Arc>, @@ -99,7 +101,7 @@ where "invoked `snapshot` with a non-matching database" ); - if self.local_state.borrow().query_in_progress() { + if self.local_state.query_in_progress() { panic!("it is not legal to `snapshot` during a query (see salsa-rs/salsa#80)"); } @@ -151,11 +153,7 @@ where /// Returns the descriptor for the query that this thread is /// actively executing (if any). pub fn active_query(&self) -> Option { - self.local_state - .borrow() - .query_stack - .last() - .map(|active_query| active_query.descriptor.clone()) + self.local_state.active_query() } /// Read current value of the revision counter. @@ -303,7 +301,7 @@ where } pub(crate) fn permits_increment(&self) -> bool { - self.revision_guard.is_none() && !self.local_state.borrow().query_in_progress() + self.revision_guard.is_none() && !self.local_state.query_in_progress() } pub(crate) fn execute_query_implementation( @@ -322,13 +320,7 @@ where }); // Push the active query onto the stack. - let push_len = { - let mut local_state = self.local_state.borrow_mut(); - local_state - .query_stack - .push(ActiveQuery::new(descriptor.clone())); - local_state.query_stack.len() - }; + let active_query = self.local_state.push_query(descriptor); // Execute user's code, accumulating inputs etc. let value = execute(); @@ -338,14 +330,7 @@ where subqueries, changed_at, .. - } = { - let mut local_state = self.local_state.borrow_mut(); - - // Sanity check: pushes and pops should be balanced. - assert_eq!(local_state.query_stack.len(), push_len); - - local_state.query_stack.pop().unwrap() - }; + } = active_query.complete(); ComputedQueryResult { value, @@ -367,15 +352,12 @@ where descriptor: &DB::QueryDescriptor, changed_at: ChangedAt, ) { - if let Some(top_query) = self.local_state.borrow_mut().query_stack.last_mut() { - top_query.add_read(descriptor, changed_at); - } + self.local_state.report_query_read(descriptor, changed_at); } pub(crate) fn report_untracked_read(&self) { - if let Some(top_query) = self.local_state.borrow_mut().query_stack.last_mut() { - top_query.add_untracked_read(self.current_revision()); - } + self.local_state + .report_untracked_read(self.current_revision()); } /// An "anonymous" read is a read that doesn't come from executing @@ -387,18 +369,14 @@ where /// /// This is used when queries check if they have been canceled. fn report_anon_read(&self, revision: Revision) { - if let Some(top_query) = self.local_state.borrow_mut().query_stack.last_mut() { - top_query.add_anon_read(revision); - } + self.local_state.report_anon_read(revision) } /// Obviously, this should be user configurable at some point. pub(crate) fn report_unexpected_cycle(&self, descriptor: DB::QueryDescriptor) -> ! { debug!("report_unexpected_cycle(descriptor={:?})", descriptor); - let local_state = self.local_state.borrow(); - let LocalState { query_stack, .. } = &*local_state; - + let query_stack = self.local_state.borrow_query_stack(); let start_index = (0..query_stack.len()) .rev() .filter(|&i| query_stack[i].descriptor == descriptor) @@ -501,26 +479,6 @@ where } } -/// State that will be specific to a single execution threads (when we -/// support multiple threads) -struct LocalState { - query_stack: Vec>, -} - -impl Default for LocalState { - fn default() -> Self { - LocalState { - query_stack: Default::default(), - } - } -} - -impl LocalState { - fn query_in_progress(&self) -> bool { - !self.query_stack.is_empty() - } -} - struct ActiveQuery { /// What query is executing descriptor: DB::QueryDescriptor, diff --git a/src/runtime/local_state.rs b/src/runtime/local_state.rs new file mode 100644 index 00000000..496d3afc --- /dev/null +++ b/src/runtime/local_state.rs @@ -0,0 +1,120 @@ +use crate::runtime::ActiveQuery; +use crate::runtime::ChangedAt; +use crate::runtime::Revision; +use crate::Database; +use std::cell::Ref; +use std::cell::RefCell; + +/// State that is specific to a single execution thread. +/// +/// Internally, this type uses ref-cells. +/// +/// **Note also that all mutations to the database handle (and hence +/// to the local-state) must be undone during unwinding.** +pub(super) struct LocalState { + /// Vector of active queries. + /// + /// Unwinding note: pushes onto this vector must be popped -- even + /// during unwinding. + query_stack: RefCell>>, +} + +impl Default for LocalState { + fn default() -> Self { + LocalState { + query_stack: Default::default(), + } + } +} + +impl LocalState { + pub(super) fn push_query(&self, descriptor: &DB::QueryDescriptor) -> ActiveQueryGuard<'_, DB> { + let mut query_stack = self.query_stack.borrow_mut(); + query_stack.push(ActiveQuery::new(descriptor.clone())); + ActiveQueryGuard { + local_state: self, + push_len: query_stack.len(), + } + } + + /// Returns a reference to the active query stack. + /// + /// **Warning:** Because this reference holds the ref-cell lock, + /// you should not use any mutating methods of `LocalState` while + /// reading from it. + pub(super) fn borrow_query_stack(&self) -> Ref<'_, Vec>> { + self.query_stack.borrow() + } + + pub(super) fn query_in_progress(&self) -> bool { + !self.query_stack.borrow().is_empty() + } + + pub(super) fn active_query(&self) -> Option { + self.query_stack + .borrow() + .last() + .map(|active_query| active_query.descriptor.clone()) + } + + pub(super) fn report_query_read( + &self, + descriptor: &DB::QueryDescriptor, + changed_at: ChangedAt, + ) { + if let Some(top_query) = self.query_stack.borrow_mut().last_mut() { + top_query.add_read(descriptor, changed_at); + } + } + + pub(super) fn report_untracked_read(&self, current_revision: Revision) { + if let Some(top_query) = self.query_stack.borrow_mut().last_mut() { + top_query.add_untracked_read(current_revision); + } + } + + pub(super) fn report_anon_read(&self, revision: Revision) { + if let Some(top_query) = self.query_stack.borrow_mut().last_mut() { + top_query.add_anon_read(revision); + } + } +} + +/// When a query is pushed onto the `active_query` stack, this guard +/// is returned to represent its slot. The guard can be used to pop +/// the query from the stack -- in the case of unwinding, the guard's +/// destructor will also remove the query. +pub(super) struct ActiveQueryGuard<'me, DB: Database> { + local_state: &'me LocalState, + push_len: usize, +} + +impl<'me, DB> ActiveQueryGuard<'me, DB> +where + DB: Database, +{ + fn pop_helper(&self) -> ActiveQuery { + let mut query_stack = self.local_state.query_stack.borrow_mut(); + + // Sanity check: pushes and pops should be balanced. + assert_eq!(query_stack.len(), self.push_len); + + query_stack.pop().unwrap() + } + + /// Invoked when the query has successfully completed execution. + pub(super) fn complete(self) -> ActiveQuery { + let query = self.pop_helper(); + std::mem::forget(self); + query + } +} + +impl<'me, DB> Drop for ActiveQueryGuard<'me, DB> +where + DB: Database, +{ + fn drop(&mut self) { + self.pop_helper(); + } +} diff --git a/tests/panic_safely.rs b/tests/panic_safely.rs index 01d00d47..8f256674 100644 --- a/tests/panic_safely.rs +++ b/tests/panic_safely.rs @@ -64,3 +64,17 @@ fn storages_are_unwind_safe() { fn check_unwind_safe() {} check_unwind_safe::<&DatabaseStruct>(); } + +#[test] +fn panics_clear_query_stack() { + let db = DatabaseStruct::default(); + + // Invoke `db.panic_if_not_one() without having set `db.input`. `db.input` + // will default to 0 and we should catch the panic. + let result = panic::catch_unwind(AssertUnwindSafe(|| db.panic_safely())); + assert!(result.is_err()); + + // The database has been poisoned and any attempt to increment the + // revision should panic. + assert_eq!(db.salsa_runtime().active_query(), None); +}