Merge pull request #114 from nikomatsakis/clear-query-stack-on-panic

Clear query stack on panic
This commit is contained in:
Niko Matsakis 2019-01-18 07:10:31 -05:00 committed by GitHub
commit 827828d6b5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 148 additions and 56 deletions

View file

@ -4,7 +4,6 @@ use log::debug;
use parking_lot::{Mutex, RwLock};
use rustc_hash::{FxHashMap, FxHasher};
use smallvec::SmallVec;
use std::cell::RefCell;
use std::fmt::Write;
use std::hash::BuildHasherDefault;
use std::sync::atomic::{AtomicUsize, Ordering};
@ -12,6 +11,9 @@ use std::sync::Arc;
pub(crate) type FxIndexSet<K> = indexmap::IndexSet<K, BuildHasherDefault<FxHasher>>;
mod local_state;
use local_state::LocalState;
/// The salsa runtime stores the storage for all queries as well as
/// tracking the query stack and dependencies between cycles.
///
@ -29,7 +31,7 @@ pub struct Runtime<DB: Database> {
revision_guard: Option<RevisionGuard<DB>>,
/// Local state that is specific to this runtime (thread).
local_state: RefCell<LocalState<DB>>,
local_state: LocalState<DB>,
/// Shared state that is accessible via all runtimes.
shared_state: Arc<SharedState<DB>>,
@ -99,7 +101,7 @@ where
"invoked `snapshot` with a non-matching database"
);
if self.local_state.borrow().query_in_progress() {
if self.local_state.query_in_progress() {
panic!("it is not legal to `snapshot` during a query (see salsa-rs/salsa#80)");
}
@ -151,11 +153,7 @@ where
/// Returns the descriptor for the query that this thread is
/// actively executing (if any).
pub fn active_query(&self) -> Option<DB::QueryDescriptor> {
self.local_state
.borrow()
.query_stack
.last()
.map(|active_query| active_query.descriptor.clone())
self.local_state.active_query()
}
/// Read current value of the revision counter.
@ -303,7 +301,7 @@ where
}
pub(crate) fn permits_increment(&self) -> bool {
self.revision_guard.is_none() && !self.local_state.borrow().query_in_progress()
self.revision_guard.is_none() && !self.local_state.query_in_progress()
}
pub(crate) fn execute_query_implementation<V>(
@ -322,13 +320,7 @@ where
});
// Push the active query onto the stack.
let push_len = {
let mut local_state = self.local_state.borrow_mut();
local_state
.query_stack
.push(ActiveQuery::new(descriptor.clone()));
local_state.query_stack.len()
};
let active_query = self.local_state.push_query(descriptor);
// Execute user's code, accumulating inputs etc.
let value = execute();
@ -338,14 +330,7 @@ where
subqueries,
changed_at,
..
} = {
let mut local_state = self.local_state.borrow_mut();
// Sanity check: pushes and pops should be balanced.
assert_eq!(local_state.query_stack.len(), push_len);
local_state.query_stack.pop().unwrap()
};
} = active_query.complete();
ComputedQueryResult {
value,
@ -367,15 +352,12 @@ where
descriptor: &DB::QueryDescriptor,
changed_at: ChangedAt,
) {
if let Some(top_query) = self.local_state.borrow_mut().query_stack.last_mut() {
top_query.add_read(descriptor, changed_at);
}
self.local_state.report_query_read(descriptor, changed_at);
}
pub(crate) fn report_untracked_read(&self) {
if let Some(top_query) = self.local_state.borrow_mut().query_stack.last_mut() {
top_query.add_untracked_read(self.current_revision());
}
self.local_state
.report_untracked_read(self.current_revision());
}
/// An "anonymous" read is a read that doesn't come from executing
@ -387,18 +369,14 @@ where
///
/// This is used when queries check if they have been canceled.
fn report_anon_read(&self, revision: Revision) {
if let Some(top_query) = self.local_state.borrow_mut().query_stack.last_mut() {
top_query.add_anon_read(revision);
}
self.local_state.report_anon_read(revision)
}
/// Obviously, this should be user configurable at some point.
pub(crate) fn report_unexpected_cycle(&self, descriptor: DB::QueryDescriptor) -> ! {
debug!("report_unexpected_cycle(descriptor={:?})", descriptor);
let local_state = self.local_state.borrow();
let LocalState { query_stack, .. } = &*local_state;
let query_stack = self.local_state.borrow_query_stack();
let start_index = (0..query_stack.len())
.rev()
.filter(|&i| query_stack[i].descriptor == descriptor)
@ -501,26 +479,6 @@ where
}
}
/// State that will be specific to a single execution threads (when we
/// support multiple threads)
struct LocalState<DB: Database> {
query_stack: Vec<ActiveQuery<DB>>,
}
impl<DB: Database> Default for LocalState<DB> {
fn default() -> Self {
LocalState {
query_stack: Default::default(),
}
}
}
impl<DB: Database> LocalState<DB> {
fn query_in_progress(&self) -> bool {
!self.query_stack.is_empty()
}
}
struct ActiveQuery<DB: Database> {
/// What query is executing
descriptor: DB::QueryDescriptor,

120
src/runtime/local_state.rs Normal file
View file

@ -0,0 +1,120 @@
use crate::runtime::ActiveQuery;
use crate::runtime::ChangedAt;
use crate::runtime::Revision;
use crate::Database;
use std::cell::Ref;
use std::cell::RefCell;
/// State that is specific to a single execution thread.
///
/// Internally, this type uses ref-cells.
///
/// **Note also that all mutations to the database handle (and hence
/// to the local-state) must be undone during unwinding.**
pub(super) struct LocalState<DB: Database> {
/// Vector of active queries.
///
/// Unwinding note: pushes onto this vector must be popped -- even
/// during unwinding.
query_stack: RefCell<Vec<ActiveQuery<DB>>>,
}
impl<DB: Database> Default for LocalState<DB> {
fn default() -> Self {
LocalState {
query_stack: Default::default(),
}
}
}
impl<DB: Database> LocalState<DB> {
pub(super) fn push_query(&self, descriptor: &DB::QueryDescriptor) -> ActiveQueryGuard<'_, DB> {
let mut query_stack = self.query_stack.borrow_mut();
query_stack.push(ActiveQuery::new(descriptor.clone()));
ActiveQueryGuard {
local_state: self,
push_len: query_stack.len(),
}
}
/// Returns a reference to the active query stack.
///
/// **Warning:** Because this reference holds the ref-cell lock,
/// you should not use any mutating methods of `LocalState` while
/// reading from it.
pub(super) fn borrow_query_stack(&self) -> Ref<'_, Vec<ActiveQuery<DB>>> {
self.query_stack.borrow()
}
pub(super) fn query_in_progress(&self) -> bool {
!self.query_stack.borrow().is_empty()
}
pub(super) fn active_query(&self) -> Option<DB::QueryDescriptor> {
self.query_stack
.borrow()
.last()
.map(|active_query| active_query.descriptor.clone())
}
pub(super) fn report_query_read(
&self,
descriptor: &DB::QueryDescriptor,
changed_at: ChangedAt,
) {
if let Some(top_query) = self.query_stack.borrow_mut().last_mut() {
top_query.add_read(descriptor, changed_at);
}
}
pub(super) fn report_untracked_read(&self, current_revision: Revision) {
if let Some(top_query) = self.query_stack.borrow_mut().last_mut() {
top_query.add_untracked_read(current_revision);
}
}
pub(super) fn report_anon_read(&self, revision: Revision) {
if let Some(top_query) = self.query_stack.borrow_mut().last_mut() {
top_query.add_anon_read(revision);
}
}
}
/// When a query is pushed onto the `active_query` stack, this guard
/// is returned to represent its slot. The guard can be used to pop
/// the query from the stack -- in the case of unwinding, the guard's
/// destructor will also remove the query.
pub(super) struct ActiveQueryGuard<'me, DB: Database> {
local_state: &'me LocalState<DB>,
push_len: usize,
}
impl<'me, DB> ActiveQueryGuard<'me, DB>
where
DB: Database,
{
fn pop_helper(&self) -> ActiveQuery<DB> {
let mut query_stack = self.local_state.query_stack.borrow_mut();
// Sanity check: pushes and pops should be balanced.
assert_eq!(query_stack.len(), self.push_len);
query_stack.pop().unwrap()
}
/// Invoked when the query has successfully completed execution.
pub(super) fn complete(self) -> ActiveQuery<DB> {
let query = self.pop_helper();
std::mem::forget(self);
query
}
}
impl<'me, DB> Drop for ActiveQueryGuard<'me, DB>
where
DB: Database,
{
fn drop(&mut self) {
self.pop_helper();
}
}

View file

@ -64,3 +64,17 @@ fn storages_are_unwind_safe() {
fn check_unwind_safe<T: std::panic::UnwindSafe>() {}
check_unwind_safe::<&DatabaseStruct>();
}
#[test]
fn panics_clear_query_stack() {
let db = DatabaseStruct::default();
// Invoke `db.panic_if_not_one() without having set `db.input`. `db.input`
// will default to 0 and we should catch the panic.
let result = panic::catch_unwind(AssertUnwindSafe(|| db.panic_safely()));
assert!(result.is_err());
// The database has been poisoned and any attempt to increment the
// revision should panic.
assert_eq!(db.salsa_runtime().active_query(), None);
}