Merge pull request #62 from kleimkuhler/issue-24-ensure-panic-safety

Panic safely in a single threaded context
This commit is contained in:
Niko Matsakis 2018-10-24 14:14:03 -04:00 committed by GitHub
commit 3bc5b78284
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 102 additions and 5 deletions

View file

@ -263,6 +263,8 @@ where
}
};
let panic_guard = PanicGuard::new(&self.map, key);
// If we have an old-value, it *may* now be stale, since there
// has been a new revision since the last time we checked. So,
// first things first, let's walk over each of our previous
@ -277,7 +279,14 @@ where
let changed_at = memo.changed_at;
let new_value = StampedValue { value, changed_at };
self.overwrite_placeholder(runtime, descriptor, key, old_memo.unwrap(), &new_value);
self.overwrite_placeholder(
runtime,
descriptor,
key,
old_memo.unwrap(),
&new_value,
panic_guard,
);
return Ok(new_value);
}
}
@ -333,6 +342,7 @@ where
verified_at: revision_now,
},
&stamped_value,
panic_guard,
);
}
@ -463,7 +473,11 @@ where
key: &Q::Key,
memo: Memo<DB, Q>,
new_value: &StampedValue<Q::Value>,
panic_guard: PanicGuard<'_, DB, Q>,
) {
// No panic occurred, do not run the panic-guard destructor:
std::mem::forget(panic_guard);
// Overwrite the value, releasing the lock afterwards:
let waiting = {
let mut write = self.map.write();
@ -501,6 +515,37 @@ where
}
}
struct PanicGuard<'db, DB, Q>
where
DB: Database + 'db,
Q: QueryFunction<DB>,
{
map: &'db RwLock<FxHashMap<Q::Key, QueryState<DB, Q>>>,
key: &'db Q::Key,
}
impl<'db, DB, Q> PanicGuard<'db, DB, Q>
where
DB: Database + 'db,
Q: QueryFunction<DB>,
{
fn new(map: &'db RwLock<FxHashMap<Q::Key, QueryState<DB, Q>>>, key: &'db Q::Key) -> Self {
Self { map, key }
}
}
impl<'db, DB, Q> Drop for PanicGuard<'db, DB, Q>
where
DB: Database + 'db,
Q: QueryFunction<DB>,
{
// FIXME(#24) -- handle parallel case
fn drop(&mut self) {
let mut map = self.map.write();
let _ = map.remove(self.key);
}
}
impl<DB, Q, MP> QueryStorageOps<DB, Q> for DerivedStorage<DB, Q, MP>
where
Q: QueryFunction<DB>,
@ -622,8 +667,7 @@ where
key,
old_input
)
})
.next()
}).next()
.is_some();
// Either way, we have to update our entry.
@ -740,8 +784,7 @@ where
Q::default(),
old_input
)
})
.next();
}).next();
changed_input.is_none()
}

54
tests/panic_safely.rs Normal file
View file

@ -0,0 +1,54 @@
use salsa::Database;
use std::panic::{self, AssertUnwindSafe};
salsa::query_group! {
trait PanicSafelyDatabase: salsa::Database {
fn one() -> usize {
type One;
storage input;
}
fn panic_safely() -> () {
type PanicSafely;
}
}
}
fn panic_safely(db: &impl PanicSafelyDatabase) -> () {
assert_eq!(db.one(), 1);
}
#[derive(Default)]
struct DatabaseStruct {
runtime: salsa::Runtime<DatabaseStruct>,
}
impl salsa::Database for DatabaseStruct {
fn salsa_runtime(&self) -> &salsa::Runtime<DatabaseStruct> {
&self.runtime
}
}
salsa::database_storage! {
struct DatabaseStorage for DatabaseStruct {
impl PanicSafelyDatabase {
fn one() for One;
fn panic_safely() for PanicSafely;
}
}
}
#[test]
fn should_panic_safely() {
let db = DatabaseStruct::default();
// Invoke `db.panic_safely() without having set `db.one`. `db.one` will
// default to 0 and we should catch the panic.
let result = panic::catch_unwind(AssertUnwindSafe(|| db.panic_safely()));
assert!(result.is_err());
// Set `db.one` to 1 and assert ok
db.query(One).set((), 1);
let result = panic::catch_unwind(AssertUnwindSafe(|| db.panic_safely()));
assert!(result.is_ok())
}