mirror of
https://github.com/salsa-rs/salsa.git
synced 2025-01-23 05:07:27 +00:00
Merge pull request #62 from kleimkuhler/issue-24-ensure-panic-safety
Panic safely in a single threaded context
This commit is contained in:
commit
3bc5b78284
2 changed files with 102 additions and 5 deletions
|
@ -263,6 +263,8 @@ where
|
|||
}
|
||||
};
|
||||
|
||||
let panic_guard = PanicGuard::new(&self.map, key);
|
||||
|
||||
// If we have an old-value, it *may* now be stale, since there
|
||||
// has been a new revision since the last time we checked. So,
|
||||
// first things first, let's walk over each of our previous
|
||||
|
@ -277,7 +279,14 @@ where
|
|||
let changed_at = memo.changed_at;
|
||||
|
||||
let new_value = StampedValue { value, changed_at };
|
||||
self.overwrite_placeholder(runtime, descriptor, key, old_memo.unwrap(), &new_value);
|
||||
self.overwrite_placeholder(
|
||||
runtime,
|
||||
descriptor,
|
||||
key,
|
||||
old_memo.unwrap(),
|
||||
&new_value,
|
||||
panic_guard,
|
||||
);
|
||||
return Ok(new_value);
|
||||
}
|
||||
}
|
||||
|
@ -333,6 +342,7 @@ where
|
|||
verified_at: revision_now,
|
||||
},
|
||||
&stamped_value,
|
||||
panic_guard,
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -463,7 +473,11 @@ where
|
|||
key: &Q::Key,
|
||||
memo: Memo<DB, Q>,
|
||||
new_value: &StampedValue<Q::Value>,
|
||||
panic_guard: PanicGuard<'_, DB, Q>,
|
||||
) {
|
||||
// No panic occurred, do not run the panic-guard destructor:
|
||||
std::mem::forget(panic_guard);
|
||||
|
||||
// Overwrite the value, releasing the lock afterwards:
|
||||
let waiting = {
|
||||
let mut write = self.map.write();
|
||||
|
@ -501,6 +515,37 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
struct PanicGuard<'db, DB, Q>
|
||||
where
|
||||
DB: Database + 'db,
|
||||
Q: QueryFunction<DB>,
|
||||
{
|
||||
map: &'db RwLock<FxHashMap<Q::Key, QueryState<DB, Q>>>,
|
||||
key: &'db Q::Key,
|
||||
}
|
||||
|
||||
impl<'db, DB, Q> PanicGuard<'db, DB, Q>
|
||||
where
|
||||
DB: Database + 'db,
|
||||
Q: QueryFunction<DB>,
|
||||
{
|
||||
fn new(map: &'db RwLock<FxHashMap<Q::Key, QueryState<DB, Q>>>, key: &'db Q::Key) -> Self {
|
||||
Self { map, key }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'db, DB, Q> Drop for PanicGuard<'db, DB, Q>
|
||||
where
|
||||
DB: Database + 'db,
|
||||
Q: QueryFunction<DB>,
|
||||
{
|
||||
// FIXME(#24) -- handle parallel case
|
||||
fn drop(&mut self) {
|
||||
let mut map = self.map.write();
|
||||
let _ = map.remove(self.key);
|
||||
}
|
||||
}
|
||||
|
||||
impl<DB, Q, MP> QueryStorageOps<DB, Q> for DerivedStorage<DB, Q, MP>
|
||||
where
|
||||
Q: QueryFunction<DB>,
|
||||
|
@ -622,8 +667,7 @@ where
|
|||
key,
|
||||
old_input
|
||||
)
|
||||
})
|
||||
.next()
|
||||
}).next()
|
||||
.is_some();
|
||||
|
||||
// Either way, we have to update our entry.
|
||||
|
@ -740,8 +784,7 @@ where
|
|||
Q::default(),
|
||||
old_input
|
||||
)
|
||||
})
|
||||
.next();
|
||||
}).next();
|
||||
|
||||
changed_input.is_none()
|
||||
}
|
||||
|
|
54
tests/panic_safely.rs
Normal file
54
tests/panic_safely.rs
Normal file
|
@ -0,0 +1,54 @@
|
|||
use salsa::Database;
|
||||
use std::panic::{self, AssertUnwindSafe};
|
||||
|
||||
salsa::query_group! {
|
||||
trait PanicSafelyDatabase: salsa::Database {
|
||||
fn one() -> usize {
|
||||
type One;
|
||||
storage input;
|
||||
}
|
||||
|
||||
fn panic_safely() -> () {
|
||||
type PanicSafely;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn panic_safely(db: &impl PanicSafelyDatabase) -> () {
|
||||
assert_eq!(db.one(), 1);
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct DatabaseStruct {
|
||||
runtime: salsa::Runtime<DatabaseStruct>,
|
||||
}
|
||||
|
||||
impl salsa::Database for DatabaseStruct {
|
||||
fn salsa_runtime(&self) -> &salsa::Runtime<DatabaseStruct> {
|
||||
&self.runtime
|
||||
}
|
||||
}
|
||||
|
||||
salsa::database_storage! {
|
||||
struct DatabaseStorage for DatabaseStruct {
|
||||
impl PanicSafelyDatabase {
|
||||
fn one() for One;
|
||||
fn panic_safely() for PanicSafely;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_panic_safely() {
|
||||
let db = DatabaseStruct::default();
|
||||
|
||||
// Invoke `db.panic_safely() without having set `db.one`. `db.one` will
|
||||
// default to 0 and we should catch the panic.
|
||||
let result = panic::catch_unwind(AssertUnwindSafe(|| db.panic_safely()));
|
||||
assert!(result.is_err());
|
||||
|
||||
// Set `db.one` to 1 and assert ok
|
||||
db.query(One).set((), 1);
|
||||
let result = panic::catch_unwind(AssertUnwindSafe(|| db.panic_safely()));
|
||||
assert!(result.is_ok())
|
||||
}
|
Loading…
Reference in a new issue