mirror of
https://github.com/salsa-rs/salsa.git
synced 2025-02-02 17:52:19 +00:00
Merge pull request #114 from nikomatsakis/clear-query-stack-on-panic
Clear query stack on panic
This commit is contained in:
commit
827828d6b5
3 changed files with 148 additions and 56 deletions
|
@ -4,7 +4,6 @@ use log::debug;
|
||||||
use parking_lot::{Mutex, RwLock};
|
use parking_lot::{Mutex, RwLock};
|
||||||
use rustc_hash::{FxHashMap, FxHasher};
|
use rustc_hash::{FxHashMap, FxHasher};
|
||||||
use smallvec::SmallVec;
|
use smallvec::SmallVec;
|
||||||
use std::cell::RefCell;
|
|
||||||
use std::fmt::Write;
|
use std::fmt::Write;
|
||||||
use std::hash::BuildHasherDefault;
|
use std::hash::BuildHasherDefault;
|
||||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
|
@ -12,6 +11,9 @@ use std::sync::Arc;
|
||||||
|
|
||||||
pub(crate) type FxIndexSet<K> = indexmap::IndexSet<K, BuildHasherDefault<FxHasher>>;
|
pub(crate) type FxIndexSet<K> = indexmap::IndexSet<K, BuildHasherDefault<FxHasher>>;
|
||||||
|
|
||||||
|
mod local_state;
|
||||||
|
use local_state::LocalState;
|
||||||
|
|
||||||
/// The salsa runtime stores the storage for all queries as well as
|
/// The salsa runtime stores the storage for all queries as well as
|
||||||
/// tracking the query stack and dependencies between cycles.
|
/// tracking the query stack and dependencies between cycles.
|
||||||
///
|
///
|
||||||
|
@ -29,7 +31,7 @@ pub struct Runtime<DB: Database> {
|
||||||
revision_guard: Option<RevisionGuard<DB>>,
|
revision_guard: Option<RevisionGuard<DB>>,
|
||||||
|
|
||||||
/// Local state that is specific to this runtime (thread).
|
/// Local state that is specific to this runtime (thread).
|
||||||
local_state: RefCell<LocalState<DB>>,
|
local_state: LocalState<DB>,
|
||||||
|
|
||||||
/// Shared state that is accessible via all runtimes.
|
/// Shared state that is accessible via all runtimes.
|
||||||
shared_state: Arc<SharedState<DB>>,
|
shared_state: Arc<SharedState<DB>>,
|
||||||
|
@ -99,7 +101,7 @@ where
|
||||||
"invoked `snapshot` with a non-matching database"
|
"invoked `snapshot` with a non-matching database"
|
||||||
);
|
);
|
||||||
|
|
||||||
if self.local_state.borrow().query_in_progress() {
|
if self.local_state.query_in_progress() {
|
||||||
panic!("it is not legal to `snapshot` during a query (see salsa-rs/salsa#80)");
|
panic!("it is not legal to `snapshot` during a query (see salsa-rs/salsa#80)");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -151,11 +153,7 @@ where
|
||||||
/// Returns the descriptor for the query that this thread is
|
/// Returns the descriptor for the query that this thread is
|
||||||
/// actively executing (if any).
|
/// actively executing (if any).
|
||||||
pub fn active_query(&self) -> Option<DB::QueryDescriptor> {
|
pub fn active_query(&self) -> Option<DB::QueryDescriptor> {
|
||||||
self.local_state
|
self.local_state.active_query()
|
||||||
.borrow()
|
|
||||||
.query_stack
|
|
||||||
.last()
|
|
||||||
.map(|active_query| active_query.descriptor.clone())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Read current value of the revision counter.
|
/// Read current value of the revision counter.
|
||||||
|
@ -303,7 +301,7 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn permits_increment(&self) -> bool {
|
pub(crate) fn permits_increment(&self) -> bool {
|
||||||
self.revision_guard.is_none() && !self.local_state.borrow().query_in_progress()
|
self.revision_guard.is_none() && !self.local_state.query_in_progress()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn execute_query_implementation<V>(
|
pub(crate) fn execute_query_implementation<V>(
|
||||||
|
@ -322,13 +320,7 @@ where
|
||||||
});
|
});
|
||||||
|
|
||||||
// Push the active query onto the stack.
|
// Push the active query onto the stack.
|
||||||
let push_len = {
|
let active_query = self.local_state.push_query(descriptor);
|
||||||
let mut local_state = self.local_state.borrow_mut();
|
|
||||||
local_state
|
|
||||||
.query_stack
|
|
||||||
.push(ActiveQuery::new(descriptor.clone()));
|
|
||||||
local_state.query_stack.len()
|
|
||||||
};
|
|
||||||
|
|
||||||
// Execute user's code, accumulating inputs etc.
|
// Execute user's code, accumulating inputs etc.
|
||||||
let value = execute();
|
let value = execute();
|
||||||
|
@ -338,14 +330,7 @@ where
|
||||||
subqueries,
|
subqueries,
|
||||||
changed_at,
|
changed_at,
|
||||||
..
|
..
|
||||||
} = {
|
} = active_query.complete();
|
||||||
let mut local_state = self.local_state.borrow_mut();
|
|
||||||
|
|
||||||
// Sanity check: pushes and pops should be balanced.
|
|
||||||
assert_eq!(local_state.query_stack.len(), push_len);
|
|
||||||
|
|
||||||
local_state.query_stack.pop().unwrap()
|
|
||||||
};
|
|
||||||
|
|
||||||
ComputedQueryResult {
|
ComputedQueryResult {
|
||||||
value,
|
value,
|
||||||
|
@ -367,15 +352,12 @@ where
|
||||||
descriptor: &DB::QueryDescriptor,
|
descriptor: &DB::QueryDescriptor,
|
||||||
changed_at: ChangedAt,
|
changed_at: ChangedAt,
|
||||||
) {
|
) {
|
||||||
if let Some(top_query) = self.local_state.borrow_mut().query_stack.last_mut() {
|
self.local_state.report_query_read(descriptor, changed_at);
|
||||||
top_query.add_read(descriptor, changed_at);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn report_untracked_read(&self) {
|
pub(crate) fn report_untracked_read(&self) {
|
||||||
if let Some(top_query) = self.local_state.borrow_mut().query_stack.last_mut() {
|
self.local_state
|
||||||
top_query.add_untracked_read(self.current_revision());
|
.report_untracked_read(self.current_revision());
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// An "anonymous" read is a read that doesn't come from executing
|
/// An "anonymous" read is a read that doesn't come from executing
|
||||||
|
@ -387,18 +369,14 @@ where
|
||||||
///
|
///
|
||||||
/// This is used when queries check if they have been canceled.
|
/// This is used when queries check if they have been canceled.
|
||||||
fn report_anon_read(&self, revision: Revision) {
|
fn report_anon_read(&self, revision: Revision) {
|
||||||
if let Some(top_query) = self.local_state.borrow_mut().query_stack.last_mut() {
|
self.local_state.report_anon_read(revision)
|
||||||
top_query.add_anon_read(revision);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Obviously, this should be user configurable at some point.
|
/// Obviously, this should be user configurable at some point.
|
||||||
pub(crate) fn report_unexpected_cycle(&self, descriptor: DB::QueryDescriptor) -> ! {
|
pub(crate) fn report_unexpected_cycle(&self, descriptor: DB::QueryDescriptor) -> ! {
|
||||||
debug!("report_unexpected_cycle(descriptor={:?})", descriptor);
|
debug!("report_unexpected_cycle(descriptor={:?})", descriptor);
|
||||||
|
|
||||||
let local_state = self.local_state.borrow();
|
let query_stack = self.local_state.borrow_query_stack();
|
||||||
let LocalState { query_stack, .. } = &*local_state;
|
|
||||||
|
|
||||||
let start_index = (0..query_stack.len())
|
let start_index = (0..query_stack.len())
|
||||||
.rev()
|
.rev()
|
||||||
.filter(|&i| query_stack[i].descriptor == descriptor)
|
.filter(|&i| query_stack[i].descriptor == descriptor)
|
||||||
|
@ -501,26 +479,6 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// State that will be specific to a single execution threads (when we
|
|
||||||
/// support multiple threads)
|
|
||||||
struct LocalState<DB: Database> {
|
|
||||||
query_stack: Vec<ActiveQuery<DB>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<DB: Database> Default for LocalState<DB> {
|
|
||||||
fn default() -> Self {
|
|
||||||
LocalState {
|
|
||||||
query_stack: Default::default(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<DB: Database> LocalState<DB> {
|
|
||||||
fn query_in_progress(&self) -> bool {
|
|
||||||
!self.query_stack.is_empty()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ActiveQuery<DB: Database> {
|
struct ActiveQuery<DB: Database> {
|
||||||
/// What query is executing
|
/// What query is executing
|
||||||
descriptor: DB::QueryDescriptor,
|
descriptor: DB::QueryDescriptor,
|
||||||
|
|
120
src/runtime/local_state.rs
Normal file
120
src/runtime/local_state.rs
Normal file
|
@ -0,0 +1,120 @@
|
||||||
|
use crate::runtime::ActiveQuery;
|
||||||
|
use crate::runtime::ChangedAt;
|
||||||
|
use crate::runtime::Revision;
|
||||||
|
use crate::Database;
|
||||||
|
use std::cell::Ref;
|
||||||
|
use std::cell::RefCell;
|
||||||
|
|
||||||
|
/// State that is specific to a single execution thread.
|
||||||
|
///
|
||||||
|
/// Internally, this type uses ref-cells.
|
||||||
|
///
|
||||||
|
/// **Note also that all mutations to the database handle (and hence
|
||||||
|
/// to the local-state) must be undone during unwinding.**
|
||||||
|
pub(super) struct LocalState<DB: Database> {
|
||||||
|
/// Vector of active queries.
|
||||||
|
///
|
||||||
|
/// Unwinding note: pushes onto this vector must be popped -- even
|
||||||
|
/// during unwinding.
|
||||||
|
query_stack: RefCell<Vec<ActiveQuery<DB>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<DB: Database> Default for LocalState<DB> {
|
||||||
|
fn default() -> Self {
|
||||||
|
LocalState {
|
||||||
|
query_stack: Default::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<DB: Database> LocalState<DB> {
|
||||||
|
pub(super) fn push_query(&self, descriptor: &DB::QueryDescriptor) -> ActiveQueryGuard<'_, DB> {
|
||||||
|
let mut query_stack = self.query_stack.borrow_mut();
|
||||||
|
query_stack.push(ActiveQuery::new(descriptor.clone()));
|
||||||
|
ActiveQueryGuard {
|
||||||
|
local_state: self,
|
||||||
|
push_len: query_stack.len(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a reference to the active query stack.
|
||||||
|
///
|
||||||
|
/// **Warning:** Because this reference holds the ref-cell lock,
|
||||||
|
/// you should not use any mutating methods of `LocalState` while
|
||||||
|
/// reading from it.
|
||||||
|
pub(super) fn borrow_query_stack(&self) -> Ref<'_, Vec<ActiveQuery<DB>>> {
|
||||||
|
self.query_stack.borrow()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) fn query_in_progress(&self) -> bool {
|
||||||
|
!self.query_stack.borrow().is_empty()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) fn active_query(&self) -> Option<DB::QueryDescriptor> {
|
||||||
|
self.query_stack
|
||||||
|
.borrow()
|
||||||
|
.last()
|
||||||
|
.map(|active_query| active_query.descriptor.clone())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) fn report_query_read(
|
||||||
|
&self,
|
||||||
|
descriptor: &DB::QueryDescriptor,
|
||||||
|
changed_at: ChangedAt,
|
||||||
|
) {
|
||||||
|
if let Some(top_query) = self.query_stack.borrow_mut().last_mut() {
|
||||||
|
top_query.add_read(descriptor, changed_at);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) fn report_untracked_read(&self, current_revision: Revision) {
|
||||||
|
if let Some(top_query) = self.query_stack.borrow_mut().last_mut() {
|
||||||
|
top_query.add_untracked_read(current_revision);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) fn report_anon_read(&self, revision: Revision) {
|
||||||
|
if let Some(top_query) = self.query_stack.borrow_mut().last_mut() {
|
||||||
|
top_query.add_anon_read(revision);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// When a query is pushed onto the `active_query` stack, this guard
|
||||||
|
/// is returned to represent its slot. The guard can be used to pop
|
||||||
|
/// the query from the stack -- in the case of unwinding, the guard's
|
||||||
|
/// destructor will also remove the query.
|
||||||
|
pub(super) struct ActiveQueryGuard<'me, DB: Database> {
|
||||||
|
local_state: &'me LocalState<DB>,
|
||||||
|
push_len: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'me, DB> ActiveQueryGuard<'me, DB>
|
||||||
|
where
|
||||||
|
DB: Database,
|
||||||
|
{
|
||||||
|
fn pop_helper(&self) -> ActiveQuery<DB> {
|
||||||
|
let mut query_stack = self.local_state.query_stack.borrow_mut();
|
||||||
|
|
||||||
|
// Sanity check: pushes and pops should be balanced.
|
||||||
|
assert_eq!(query_stack.len(), self.push_len);
|
||||||
|
|
||||||
|
query_stack.pop().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Invoked when the query has successfully completed execution.
|
||||||
|
pub(super) fn complete(self) -> ActiveQuery<DB> {
|
||||||
|
let query = self.pop_helper();
|
||||||
|
std::mem::forget(self);
|
||||||
|
query
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'me, DB> Drop for ActiveQueryGuard<'me, DB>
|
||||||
|
where
|
||||||
|
DB: Database,
|
||||||
|
{
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.pop_helper();
|
||||||
|
}
|
||||||
|
}
|
|
@ -64,3 +64,17 @@ fn storages_are_unwind_safe() {
|
||||||
fn check_unwind_safe<T: std::panic::UnwindSafe>() {}
|
fn check_unwind_safe<T: std::panic::UnwindSafe>() {}
|
||||||
check_unwind_safe::<&DatabaseStruct>();
|
check_unwind_safe::<&DatabaseStruct>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn panics_clear_query_stack() {
|
||||||
|
let db = DatabaseStruct::default();
|
||||||
|
|
||||||
|
// Invoke `db.panic_if_not_one() without having set `db.input`. `db.input`
|
||||||
|
// will default to 0 and we should catch the panic.
|
||||||
|
let result = panic::catch_unwind(AssertUnwindSafe(|| db.panic_safely()));
|
||||||
|
assert!(result.is_err());
|
||||||
|
|
||||||
|
// The database has been poisoned and any attempt to increment the
|
||||||
|
// revision should panic.
|
||||||
|
assert_eq!(db.salsa_runtime().active_query(), None);
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue