From 917ca42f04c04fed8afd900b54a940a5d257ac23 Mon Sep 17 00:00:00 2001 From: Kevin Leimkuhler Date: Mon, 22 Oct 2018 21:59:12 -0700 Subject: [PATCH] Panic safely in a single threaded context --- src/derived.rs | 53 ++++++++++++++++++++++++++++++++++++++---- tests/panic_safely.rs | 54 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 5 deletions(-) create mode 100644 tests/panic_safely.rs diff --git a/src/derived.rs b/src/derived.rs index 58fb94a4..127f65b4 100644 --- a/src/derived.rs +++ b/src/derived.rs @@ -263,6 +263,8 @@ where } }; + let panic_guard = PanicGuard::new(&self.map, key); + // If we have an old-value, it *may* now be stale, since there // has been a new revision since the last time we checked. So, // first things first, let's walk over each of our previous @@ -277,7 +279,14 @@ where let changed_at = memo.changed_at; let new_value = StampedValue { value, changed_at }; - self.overwrite_placeholder(runtime, descriptor, key, old_memo.unwrap(), &new_value); + self.overwrite_placeholder( + runtime, + descriptor, + key, + old_memo.unwrap(), + &new_value, + panic_guard, + ); return Ok(new_value); } } @@ -333,6 +342,7 @@ where verified_at: revision_now, }, &stamped_value, + panic_guard, ); } @@ -463,7 +473,11 @@ where key: &Q::Key, memo: Memo, new_value: &StampedValue, + panic_guard: PanicGuard<'_, DB, Q>, ) { + // No panic occurred, do not run the panic-guard destructor: + std::mem::forget(panic_guard); + // Overwrite the value, releasing the lock afterwards: let waiting = { let mut write = self.map.write(); @@ -501,6 +515,37 @@ where } } +struct PanicGuard<'db, DB, Q> +where + DB: Database + 'db, + Q: QueryFunction, +{ + map: &'db RwLock>>, + key: &'db Q::Key, +} + +impl<'db, DB, Q> PanicGuard<'db, DB, Q> +where + DB: Database + 'db, + Q: QueryFunction, +{ + fn new(map: &'db RwLock>>, key: &'db Q::Key) -> Self { + Self { map, key } + } +} + +impl<'db, DB, Q> Drop for PanicGuard<'db, DB, Q> +where + DB: Database + 'db, + Q: QueryFunction, +{ + fn drop(&mut self) { + let map = self.map.upgradable_read(); + let mut map = RwLockUpgradableReadGuard::upgrade(map); + let _ = map.remove(self.key); + } +} + impl QueryStorageOps for DerivedStorage where Q: QueryFunction, @@ -622,8 +667,7 @@ where key, old_input ) - }) - .next() + }).next() .is_some(); // Either way, we have to update our entry. @@ -740,8 +784,7 @@ where Q::default(), old_input ) - }) - .next(); + }).next(); changed_input.is_none() } diff --git a/tests/panic_safely.rs b/tests/panic_safely.rs new file mode 100644 index 00000000..e354aa7f --- /dev/null +++ b/tests/panic_safely.rs @@ -0,0 +1,54 @@ +use salsa::Database; +use std::panic::{self, AssertUnwindSafe}; + +salsa::query_group! { + trait PanicSafelyDatabase: salsa::Database { + fn one() -> usize { + type One; + storage input; + } + + fn panic_safely() -> () { + type PanicSafely; + } + } +} + +fn panic_safely(db: &impl PanicSafelyDatabase) -> () { + assert_eq!(db.one(), 1); +} + +#[derive(Default)] +struct DatabaseStruct { + runtime: salsa::Runtime, +} + +impl salsa::Database for DatabaseStruct { + fn salsa_runtime(&self) -> &salsa::Runtime { + &self.runtime + } +} + +salsa::database_storage! { + struct DatabaseStorage for DatabaseStruct { + impl PanicSafelyDatabase { + fn one() for One; + fn panic_safely() for PanicSafely; + } + } +} + +#[test] +fn should_panic_safely() { + let db = DatabaseStruct::default(); + + // Invoke `db.panic_safely() without having set `db.one`. `db.one` will + // default to 0 and we should catch the panic. + let result = panic::catch_unwind(AssertUnwindSafe(|| db.panic_safely())); + assert!(result.is_err()); + + // Set `db.one` to 1 and assert ok + db.query(One).set((), 1); + let result = panic::catch_unwind(AssertUnwindSafe(|| db.panic_safely())); + assert!(result.is_ok()) +}