extract register_with_in_progress_thread helper

This commit is contained in:
Niko Matsakis 2018-10-15 11:48:21 -04:00
parent 29831a7430
commit cf72c98946

View file

@ -17,7 +17,7 @@ use rustc_hash::FxHashMap;
use smallvec::SmallVec;
use std::marker::PhantomData;
use std::ops::Deref;
use std::sync::mpsc::{self, Sender};
use std::sync::mpsc::{self, Receiver, Sender};
/// Memoized queries store the result plus a list of the other queries
/// that they invoked. This means we can avoid recomputing them when
@ -358,26 +358,20 @@ where
match map.get(key) {
Some(QueryState::InProgress { id, waiting }) => {
let other_id = *id;
if other_id == runtime.id() {
return ProbeState::UpToDate(Err(CycleDetected));
} else {
if !runtime.try_block_on(descriptor, other_id) {
return ProbeState::UpToDate(Err(CycleDetected));
return match self
.register_with_in_progress_thread(runtime, descriptor, other_id, waiting)
{
Ok(rx) => {
// Release our lock on `self.map`, so other thread
// can complete.
std::mem::drop(map);
let value = rx.recv().unwrap();
ProbeState::UpToDate(Ok(value))
}
let (tx, rx) = mpsc::channel();
// The reader of this will have to acquire map
// lock, we don't need any particular ordering.
waiting.lock().push(tx);
// Release our lock on `self.map`, so other thread
// can complete.
std::mem::drop(map);
let value = rx.recv().unwrap();
return ProbeState::UpToDate(Ok(value));
}
Err(CycleDetected) => ProbeState::UpToDate(Err(CycleDetected)),
};
}
Some(QueryState::Memoized(memo)) => {
@ -413,6 +407,37 @@ where
ProbeState::StaleOrAbsent(map)
}
/// Helper:
///
/// When we encounter an `InProgress` indicator, we need to either
/// report a cycle or else register ourselves to be notified when
/// that work completes. This helper does that; it returns a port
/// where you can wait for the final value that wound up being
/// computed (but first drop the lock on the map).
fn register_with_in_progress_thread(
&self,
runtime: &Runtime<DB>,
descriptor: &DB::QueryDescriptor,
other_id: RuntimeId,
waiting: &Mutex<SmallVec<[Sender<StampedValue<Q::Value>>; 2]>>,
) -> Result<Receiver<StampedValue<Q::Value>>, CycleDetected> {
if other_id == runtime.id() {
return Err(CycleDetected);
} else {
if !runtime.try_block_on(descriptor, other_id) {
return Err(CycleDetected);
}
let (tx, rx) = mpsc::channel();
// The reader of this will have to acquire map
// lock, we don't need any particular ordering.
waiting.lock().push(tx);
Ok(rx)
}
}
/// Overwrites the `InProgress` placeholder for `key` that we
/// inserted; if others were blocked, waiting for us to finish,
/// the notify them.