fix race condition around dropping arc handle

Sigh, I always make this mistake.
This commit is contained in:
Niko Matsakis 2024-07-25 09:31:06 +00:00
parent 8a3cc6e404
commit 246dcab977

View file

@ -9,7 +9,10 @@ use crate::{storage::HasStorage, Event, EventKind};
/// When you attempt to modify the database, you call `get_mut`, which will set the cancellation flag, /// When you attempt to modify the database, you call `get_mut`, which will set the cancellation flag,
/// causing other handles to get panics. Once all other handles are dropped, you can proceed. /// causing other handles to get panics. Once all other handles are dropped, you can proceed.
pub struct Handle<Db: HasStorage> { pub struct Handle<Db: HasStorage> {
db: Arc<Db>, /// Reference to the database. This is always `Some` except during destruction.
db: Option<Arc<Db>>,
/// Coordination data.
coordinate: Arc<Coordinate>, coordinate: Arc<Coordinate>,
} }
@ -23,7 +26,7 @@ struct Coordinate {
impl<Db: HasStorage> Handle<Db> { impl<Db: HasStorage> Handle<Db> {
pub fn new(db: Db) -> Self { pub fn new(db: Db) -> Self {
Self { Self {
db: Arc::new(db), db: Some(Arc::new(db)),
coordinate: Arc::new(Coordinate { coordinate: Arc::new(Coordinate {
clones: Mutex::new(1), clones: Mutex::new(1),
cvar: Default::default(), cvar: Default::default(),
@ -31,12 +34,29 @@ impl<Db: HasStorage> Handle<Db> {
} }
} }
fn db(&self) -> &Arc<Db> {
self.db.as_ref().unwrap()
}
fn db_mut(&mut self) -> &mut Arc<Db> {
self.db.as_mut().unwrap()
}
/// Returns a mutable reference to the inner database. /// Returns a mutable reference to the inner database.
/// If other handles are active, this method sets the cancellation flag /// If other handles are active, this method sets the cancellation flag
/// and blocks until they are dropped. /// and blocks until they are dropped.
pub fn get_mut(&mut self) -> &mut Db { pub fn get_mut(&mut self) -> &mut Db {
self.cancel_others(); self.cancel_others();
Arc::get_mut(&mut self.db).expect("no other handles")
// Once cancellation above completes, the other handles are being dropped.
// However, because the signal is sent before the destructor completes, it's
// possible that they have not *yet* dropped.
//
// Therefore, we may have to do a (short) bit of
// spinning before we observe the thread-count reducing to 0.
//
// An alternative would be to
Arc::get_mut(self.db_mut()).expect("other threads remain active despite cancellation")
} }
// ANCHOR: cancel_other_workers // ANCHOR: cancel_other_workers
@ -46,10 +66,10 @@ impl<Db: HasStorage> Handle<Db> {
/// This could deadlock if there is a single worker with two handles to the /// This could deadlock if there is a single worker with two handles to the
/// same database! /// same database!
fn cancel_others(&mut self) { fn cancel_others(&mut self) {
let storage = self.db.storage(); let storage = self.db().storage();
storage.runtime().set_cancellation_flag(); storage.runtime().set_cancellation_flag();
self.db.salsa_event(Event { self.db().salsa_event(Event {
thread_id: std::thread::current().id(), thread_id: std::thread::current().id(),
kind: EventKind::DidSetCancellationFlag, kind: EventKind::DidSetCancellationFlag,
@ -65,6 +85,10 @@ impl<Db: HasStorage> Handle<Db> {
impl<Db: HasStorage> Drop for Handle<Db> { impl<Db: HasStorage> Drop for Handle<Db> {
fn drop(&mut self) { fn drop(&mut self) {
// Drop the database handle *first*
self.db.take();
// *Now* decrement the number of clones and notify once we have completed
*self.coordinate.clones.lock() -= 1; *self.coordinate.clones.lock() -= 1;
self.coordinate.cvar.notify_all(); self.coordinate.cvar.notify_all();
} }
@ -74,7 +98,7 @@ impl<Db: HasStorage> std::ops::Deref for Handle<Db> {
type Target = Db; type Target = Db;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
&self.db self.db()
} }
} }
@ -83,7 +107,7 @@ impl<Db: HasStorage> Clone for Handle<Db> {
*self.coordinate.clones.lock() += 1; *self.coordinate.clones.lock() += 1;
Self { Self {
db: Arc::clone(&self.db), db: Some(Arc::clone(self.db())),
coordinate: Arc::clone(&self.coordinate), coordinate: Arc::clone(&self.coordinate),
} }
} }