mirror of
https://github.com/salsa-rs/salsa.git
synced 2025-02-02 09:46:06 +00:00
Merge pull request #114 from nikomatsakis/clear-query-stack-on-panic
Clear query stack on panic
This commit is contained in:
commit
827828d6b5
3 changed files with 148 additions and 56 deletions
|
@ -4,7 +4,6 @@ use log::debug;
|
|||
use parking_lot::{Mutex, RwLock};
|
||||
use rustc_hash::{FxHashMap, FxHasher};
|
||||
use smallvec::SmallVec;
|
||||
use std::cell::RefCell;
|
||||
use std::fmt::Write;
|
||||
use std::hash::BuildHasherDefault;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
@ -12,6 +11,9 @@ use std::sync::Arc;
|
|||
|
||||
pub(crate) type FxIndexSet<K> = indexmap::IndexSet<K, BuildHasherDefault<FxHasher>>;
|
||||
|
||||
mod local_state;
|
||||
use local_state::LocalState;
|
||||
|
||||
/// The salsa runtime stores the storage for all queries as well as
|
||||
/// tracking the query stack and dependencies between cycles.
|
||||
///
|
||||
|
@ -29,7 +31,7 @@ pub struct Runtime<DB: Database> {
|
|||
revision_guard: Option<RevisionGuard<DB>>,
|
||||
|
||||
/// Local state that is specific to this runtime (thread).
|
||||
local_state: RefCell<LocalState<DB>>,
|
||||
local_state: LocalState<DB>,
|
||||
|
||||
/// Shared state that is accessible via all runtimes.
|
||||
shared_state: Arc<SharedState<DB>>,
|
||||
|
@ -99,7 +101,7 @@ where
|
|||
"invoked `snapshot` with a non-matching database"
|
||||
);
|
||||
|
||||
if self.local_state.borrow().query_in_progress() {
|
||||
if self.local_state.query_in_progress() {
|
||||
panic!("it is not legal to `snapshot` during a query (see salsa-rs/salsa#80)");
|
||||
}
|
||||
|
||||
|
@ -151,11 +153,7 @@ where
|
|||
/// Returns the descriptor for the query that this thread is
|
||||
/// actively executing (if any).
|
||||
pub fn active_query(&self) -> Option<DB::QueryDescriptor> {
|
||||
self.local_state
|
||||
.borrow()
|
||||
.query_stack
|
||||
.last()
|
||||
.map(|active_query| active_query.descriptor.clone())
|
||||
self.local_state.active_query()
|
||||
}
|
||||
|
||||
/// Read current value of the revision counter.
|
||||
|
@ -303,7 +301,7 @@ where
|
|||
}
|
||||
|
||||
pub(crate) fn permits_increment(&self) -> bool {
|
||||
self.revision_guard.is_none() && !self.local_state.borrow().query_in_progress()
|
||||
self.revision_guard.is_none() && !self.local_state.query_in_progress()
|
||||
}
|
||||
|
||||
pub(crate) fn execute_query_implementation<V>(
|
||||
|
@ -322,13 +320,7 @@ where
|
|||
});
|
||||
|
||||
// Push the active query onto the stack.
|
||||
let push_len = {
|
||||
let mut local_state = self.local_state.borrow_mut();
|
||||
local_state
|
||||
.query_stack
|
||||
.push(ActiveQuery::new(descriptor.clone()));
|
||||
local_state.query_stack.len()
|
||||
};
|
||||
let active_query = self.local_state.push_query(descriptor);
|
||||
|
||||
// Execute user's code, accumulating inputs etc.
|
||||
let value = execute();
|
||||
|
@ -338,14 +330,7 @@ where
|
|||
subqueries,
|
||||
changed_at,
|
||||
..
|
||||
} = {
|
||||
let mut local_state = self.local_state.borrow_mut();
|
||||
|
||||
// Sanity check: pushes and pops should be balanced.
|
||||
assert_eq!(local_state.query_stack.len(), push_len);
|
||||
|
||||
local_state.query_stack.pop().unwrap()
|
||||
};
|
||||
} = active_query.complete();
|
||||
|
||||
ComputedQueryResult {
|
||||
value,
|
||||
|
@ -367,15 +352,12 @@ where
|
|||
descriptor: &DB::QueryDescriptor,
|
||||
changed_at: ChangedAt,
|
||||
) {
|
||||
if let Some(top_query) = self.local_state.borrow_mut().query_stack.last_mut() {
|
||||
top_query.add_read(descriptor, changed_at);
|
||||
}
|
||||
self.local_state.report_query_read(descriptor, changed_at);
|
||||
}
|
||||
|
||||
pub(crate) fn report_untracked_read(&self) {
|
||||
if let Some(top_query) = self.local_state.borrow_mut().query_stack.last_mut() {
|
||||
top_query.add_untracked_read(self.current_revision());
|
||||
}
|
||||
self.local_state
|
||||
.report_untracked_read(self.current_revision());
|
||||
}
|
||||
|
||||
/// An "anonymous" read is a read that doesn't come from executing
|
||||
|
@ -387,18 +369,14 @@ where
|
|||
///
|
||||
/// This is used when queries check if they have been canceled.
|
||||
fn report_anon_read(&self, revision: Revision) {
|
||||
if let Some(top_query) = self.local_state.borrow_mut().query_stack.last_mut() {
|
||||
top_query.add_anon_read(revision);
|
||||
}
|
||||
self.local_state.report_anon_read(revision)
|
||||
}
|
||||
|
||||
/// Obviously, this should be user configurable at some point.
|
||||
pub(crate) fn report_unexpected_cycle(&self, descriptor: DB::QueryDescriptor) -> ! {
|
||||
debug!("report_unexpected_cycle(descriptor={:?})", descriptor);
|
||||
|
||||
let local_state = self.local_state.borrow();
|
||||
let LocalState { query_stack, .. } = &*local_state;
|
||||
|
||||
let query_stack = self.local_state.borrow_query_stack();
|
||||
let start_index = (0..query_stack.len())
|
||||
.rev()
|
||||
.filter(|&i| query_stack[i].descriptor == descriptor)
|
||||
|
@ -501,26 +479,6 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
/// State that will be specific to a single execution threads (when we
|
||||
/// support multiple threads)
|
||||
struct LocalState<DB: Database> {
|
||||
query_stack: Vec<ActiveQuery<DB>>,
|
||||
}
|
||||
|
||||
impl<DB: Database> Default for LocalState<DB> {
|
||||
fn default() -> Self {
|
||||
LocalState {
|
||||
query_stack: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<DB: Database> LocalState<DB> {
|
||||
fn query_in_progress(&self) -> bool {
|
||||
!self.query_stack.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
struct ActiveQuery<DB: Database> {
|
||||
/// What query is executing
|
||||
descriptor: DB::QueryDescriptor,
|
||||
|
|
120
src/runtime/local_state.rs
Normal file
120
src/runtime/local_state.rs
Normal file
|
@ -0,0 +1,120 @@
|
|||
use crate::runtime::ActiveQuery;
|
||||
use crate::runtime::ChangedAt;
|
||||
use crate::runtime::Revision;
|
||||
use crate::Database;
|
||||
use std::cell::Ref;
|
||||
use std::cell::RefCell;
|
||||
|
||||
/// State that is specific to a single execution thread.
|
||||
///
|
||||
/// Internally, this type uses ref-cells.
|
||||
///
|
||||
/// **Note also that all mutations to the database handle (and hence
|
||||
/// to the local-state) must be undone during unwinding.**
|
||||
pub(super) struct LocalState<DB: Database> {
|
||||
/// Vector of active queries.
|
||||
///
|
||||
/// Unwinding note: pushes onto this vector must be popped -- even
|
||||
/// during unwinding.
|
||||
query_stack: RefCell<Vec<ActiveQuery<DB>>>,
|
||||
}
|
||||
|
||||
impl<DB: Database> Default for LocalState<DB> {
|
||||
fn default() -> Self {
|
||||
LocalState {
|
||||
query_stack: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<DB: Database> LocalState<DB> {
|
||||
pub(super) fn push_query(&self, descriptor: &DB::QueryDescriptor) -> ActiveQueryGuard<'_, DB> {
|
||||
let mut query_stack = self.query_stack.borrow_mut();
|
||||
query_stack.push(ActiveQuery::new(descriptor.clone()));
|
||||
ActiveQueryGuard {
|
||||
local_state: self,
|
||||
push_len: query_stack.len(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a reference to the active query stack.
|
||||
///
|
||||
/// **Warning:** Because this reference holds the ref-cell lock,
|
||||
/// you should not use any mutating methods of `LocalState` while
|
||||
/// reading from it.
|
||||
pub(super) fn borrow_query_stack(&self) -> Ref<'_, Vec<ActiveQuery<DB>>> {
|
||||
self.query_stack.borrow()
|
||||
}
|
||||
|
||||
pub(super) fn query_in_progress(&self) -> bool {
|
||||
!self.query_stack.borrow().is_empty()
|
||||
}
|
||||
|
||||
pub(super) fn active_query(&self) -> Option<DB::QueryDescriptor> {
|
||||
self.query_stack
|
||||
.borrow()
|
||||
.last()
|
||||
.map(|active_query| active_query.descriptor.clone())
|
||||
}
|
||||
|
||||
pub(super) fn report_query_read(
|
||||
&self,
|
||||
descriptor: &DB::QueryDescriptor,
|
||||
changed_at: ChangedAt,
|
||||
) {
|
||||
if let Some(top_query) = self.query_stack.borrow_mut().last_mut() {
|
||||
top_query.add_read(descriptor, changed_at);
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn report_untracked_read(&self, current_revision: Revision) {
|
||||
if let Some(top_query) = self.query_stack.borrow_mut().last_mut() {
|
||||
top_query.add_untracked_read(current_revision);
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn report_anon_read(&self, revision: Revision) {
|
||||
if let Some(top_query) = self.query_stack.borrow_mut().last_mut() {
|
||||
top_query.add_anon_read(revision);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// When a query is pushed onto the `active_query` stack, this guard
|
||||
/// is returned to represent its slot. The guard can be used to pop
|
||||
/// the query from the stack -- in the case of unwinding, the guard's
|
||||
/// destructor will also remove the query.
|
||||
pub(super) struct ActiveQueryGuard<'me, DB: Database> {
|
||||
local_state: &'me LocalState<DB>,
|
||||
push_len: usize,
|
||||
}
|
||||
|
||||
impl<'me, DB> ActiveQueryGuard<'me, DB>
|
||||
where
|
||||
DB: Database,
|
||||
{
|
||||
fn pop_helper(&self) -> ActiveQuery<DB> {
|
||||
let mut query_stack = self.local_state.query_stack.borrow_mut();
|
||||
|
||||
// Sanity check: pushes and pops should be balanced.
|
||||
assert_eq!(query_stack.len(), self.push_len);
|
||||
|
||||
query_stack.pop().unwrap()
|
||||
}
|
||||
|
||||
/// Invoked when the query has successfully completed execution.
|
||||
pub(super) fn complete(self) -> ActiveQuery<DB> {
|
||||
let query = self.pop_helper();
|
||||
std::mem::forget(self);
|
||||
query
|
||||
}
|
||||
}
|
||||
|
||||
impl<'me, DB> Drop for ActiveQueryGuard<'me, DB>
|
||||
where
|
||||
DB: Database,
|
||||
{
|
||||
fn drop(&mut self) {
|
||||
self.pop_helper();
|
||||
}
|
||||
}
|
|
@ -64,3 +64,17 @@ fn storages_are_unwind_safe() {
|
|||
fn check_unwind_safe<T: std::panic::UnwindSafe>() {}
|
||||
check_unwind_safe::<&DatabaseStruct>();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn panics_clear_query_stack() {
|
||||
let db = DatabaseStruct::default();
|
||||
|
||||
// Invoke `db.panic_if_not_one() without having set `db.input`. `db.input`
|
||||
// will default to 0 and we should catch the panic.
|
||||
let result = panic::catch_unwind(AssertUnwindSafe(|| db.panic_safely()));
|
||||
assert!(result.is_err());
|
||||
|
||||
// The database has been poisoned and any attempt to increment the
|
||||
// revision should panic.
|
||||
assert_eq!(db.salsa_runtime().active_query(), None);
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue