mirror of
https://github.com/salsa-rs/salsa.git
synced 2025-02-08 21:35:47 +00:00
Panic safely in a single threaded context
This commit is contained in:
parent
bfc56d2591
commit
917ca42f04
2 changed files with 102 additions and 5 deletions
|
@ -263,6 +263,8 @@ where
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let panic_guard = PanicGuard::new(&self.map, key);
|
||||||
|
|
||||||
// If we have an old-value, it *may* now be stale, since there
|
// If we have an old-value, it *may* now be stale, since there
|
||||||
// has been a new revision since the last time we checked. So,
|
// has been a new revision since the last time we checked. So,
|
||||||
// first things first, let's walk over each of our previous
|
// first things first, let's walk over each of our previous
|
||||||
|
@ -277,7 +279,14 @@ where
|
||||||
let changed_at = memo.changed_at;
|
let changed_at = memo.changed_at;
|
||||||
|
|
||||||
let new_value = StampedValue { value, changed_at };
|
let new_value = StampedValue { value, changed_at };
|
||||||
self.overwrite_placeholder(runtime, descriptor, key, old_memo.unwrap(), &new_value);
|
self.overwrite_placeholder(
|
||||||
|
runtime,
|
||||||
|
descriptor,
|
||||||
|
key,
|
||||||
|
old_memo.unwrap(),
|
||||||
|
&new_value,
|
||||||
|
panic_guard,
|
||||||
|
);
|
||||||
return Ok(new_value);
|
return Ok(new_value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -333,6 +342,7 @@ where
|
||||||
verified_at: revision_now,
|
verified_at: revision_now,
|
||||||
},
|
},
|
||||||
&stamped_value,
|
&stamped_value,
|
||||||
|
panic_guard,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -463,7 +473,11 @@ where
|
||||||
key: &Q::Key,
|
key: &Q::Key,
|
||||||
memo: Memo<DB, Q>,
|
memo: Memo<DB, Q>,
|
||||||
new_value: &StampedValue<Q::Value>,
|
new_value: &StampedValue<Q::Value>,
|
||||||
|
panic_guard: PanicGuard<'_, DB, Q>,
|
||||||
) {
|
) {
|
||||||
|
// No panic occurred, do not run the panic-guard destructor:
|
||||||
|
std::mem::forget(panic_guard);
|
||||||
|
|
||||||
// Overwrite the value, releasing the lock afterwards:
|
// Overwrite the value, releasing the lock afterwards:
|
||||||
let waiting = {
|
let waiting = {
|
||||||
let mut write = self.map.write();
|
let mut write = self.map.write();
|
||||||
|
@ -501,6 +515,37 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct PanicGuard<'db, DB, Q>
|
||||||
|
where
|
||||||
|
DB: Database + 'db,
|
||||||
|
Q: QueryFunction<DB>,
|
||||||
|
{
|
||||||
|
map: &'db RwLock<FxHashMap<Q::Key, QueryState<DB, Q>>>,
|
||||||
|
key: &'db Q::Key,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'db, DB, Q> PanicGuard<'db, DB, Q>
|
||||||
|
where
|
||||||
|
DB: Database + 'db,
|
||||||
|
Q: QueryFunction<DB>,
|
||||||
|
{
|
||||||
|
fn new(map: &'db RwLock<FxHashMap<Q::Key, QueryState<DB, Q>>>, key: &'db Q::Key) -> Self {
|
||||||
|
Self { map, key }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'db, DB, Q> Drop for PanicGuard<'db, DB, Q>
|
||||||
|
where
|
||||||
|
DB: Database + 'db,
|
||||||
|
Q: QueryFunction<DB>,
|
||||||
|
{
|
||||||
|
fn drop(&mut self) {
|
||||||
|
let map = self.map.upgradable_read();
|
||||||
|
let mut map = RwLockUpgradableReadGuard::upgrade(map);
|
||||||
|
let _ = map.remove(self.key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<DB, Q, MP> QueryStorageOps<DB, Q> for DerivedStorage<DB, Q, MP>
|
impl<DB, Q, MP> QueryStorageOps<DB, Q> for DerivedStorage<DB, Q, MP>
|
||||||
where
|
where
|
||||||
Q: QueryFunction<DB>,
|
Q: QueryFunction<DB>,
|
||||||
|
@ -622,8 +667,7 @@ where
|
||||||
key,
|
key,
|
||||||
old_input
|
old_input
|
||||||
)
|
)
|
||||||
})
|
}).next()
|
||||||
.next()
|
|
||||||
.is_some();
|
.is_some();
|
||||||
|
|
||||||
// Either way, we have to update our entry.
|
// Either way, we have to update our entry.
|
||||||
|
@ -740,8 +784,7 @@ where
|
||||||
Q::default(),
|
Q::default(),
|
||||||
old_input
|
old_input
|
||||||
)
|
)
|
||||||
})
|
}).next();
|
||||||
.next();
|
|
||||||
|
|
||||||
changed_input.is_none()
|
changed_input.is_none()
|
||||||
}
|
}
|
||||||
|
|
54
tests/panic_safely.rs
Normal file
54
tests/panic_safely.rs
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
use salsa::Database;
|
||||||
|
use std::panic::{self, AssertUnwindSafe};
|
||||||
|
|
||||||
|
salsa::query_group! {
|
||||||
|
trait PanicSafelyDatabase: salsa::Database {
|
||||||
|
fn one() -> usize {
|
||||||
|
type One;
|
||||||
|
storage input;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn panic_safely() -> () {
|
||||||
|
type PanicSafely;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn panic_safely(db: &impl PanicSafelyDatabase) -> () {
|
||||||
|
assert_eq!(db.one(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
struct DatabaseStruct {
|
||||||
|
runtime: salsa::Runtime<DatabaseStruct>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl salsa::Database for DatabaseStruct {
|
||||||
|
fn salsa_runtime(&self) -> &salsa::Runtime<DatabaseStruct> {
|
||||||
|
&self.runtime
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
salsa::database_storage! {
|
||||||
|
struct DatabaseStorage for DatabaseStruct {
|
||||||
|
impl PanicSafelyDatabase {
|
||||||
|
fn one() for One;
|
||||||
|
fn panic_safely() for PanicSafely;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn should_panic_safely() {
|
||||||
|
let db = DatabaseStruct::default();
|
||||||
|
|
||||||
|
// Invoke `db.panic_safely() without having set `db.one`. `db.one` will
|
||||||
|
// default to 0 and we should catch the panic.
|
||||||
|
let result = panic::catch_unwind(AssertUnwindSafe(|| db.panic_safely()));
|
||||||
|
assert!(result.is_err());
|
||||||
|
|
||||||
|
// Set `db.one` to 1 and assert ok
|
||||||
|
db.query(One).set((), 1);
|
||||||
|
let result = panic::catch_unwind(AssertUnwindSafe(|| db.panic_safely()));
|
||||||
|
assert!(result.is_ok())
|
||||||
|
}
|
Loading…
Reference in a new issue