store the runtime-id in the InProgress indicator

This commit is contained in:
Niko Matsakis 2018-10-12 12:10:25 -04:00
parent 36f72c0b58
commit ca329ddd10
2 changed files with 38 additions and 13 deletions

View file

@ -1,6 +1,8 @@
use crate::runtime::ChangedAt; use crate::runtime::ChangedAt;
use crate::runtime::QueryDescriptorSet; use crate::runtime::QueryDescriptorSet;
use crate::runtime::Revision; use crate::runtime::Revision;
use crate::runtime::Runtime;
use crate::runtime::RuntimeId;
use crate::runtime::StampedValue; use crate::runtime::StampedValue;
use crate::CycleDetected; use crate::CycleDetected;
use crate::Database; use crate::Database;
@ -106,9 +108,10 @@ where
Q: QueryFunction<DB>, Q: QueryFunction<DB>,
DB: Database, DB: Database,
{ {
/// We are currently computing the result of this query; if we see /// The runtime with the given id is currently computing the
/// this value in the table, it indeeds a cycle. /// result of this query; if we see this value in the table, it
InProgress, /// indeeds a cycle.
InProgress(RuntimeId),
/// We have computed the query already, and here is the result. /// We have computed the query already, and here is the result.
Memoized(Memo<DB, Q>), Memoized(Memo<DB, Q>),
@ -180,7 +183,13 @@ where
let map_read = self.map.upgradable_read(); let map_read = self.map.upgradable_read();
if let Some(value) = map_read.get(key) { if let Some(value) = map_read.get(key) {
match value { match value {
QueryState::InProgress => return Err(CycleDetected), QueryState::InProgress(id) => {
if *id == runtime.id() {
return Err(CycleDetected);
} else {
unimplemented!();
}
}
QueryState::Memoized(m) => { QueryState::Memoized(m) => {
debug!( debug!(
"{:?}({:?}): found memoized value verified_at={:?}", "{:?}({:?}): found memoized value verified_at={:?}",
@ -211,7 +220,7 @@ where
} }
let mut map_write = RwLockUpgradableReadGuard::upgrade(map_read); let mut map_write = RwLockUpgradableReadGuard::upgrade(map_read);
map_write.insert(key.clone(), QueryState::InProgress) map_write.insert(key.clone(), QueryState::InProgress(runtime.id()))
}; };
// If we have an old-value, it *may* now be stale, since there // If we have an old-value, it *may* now be stale, since there
@ -228,7 +237,7 @@ where
let changed_at = old_memo.changed_at; let changed_at = old_memo.changed_at;
let mut map_write = self.map.write(); let mut map_write = self.map.write();
self.overwrite_placeholder(&mut map_write, key, old_value.unwrap()); self.overwrite_placeholder(runtime, &mut map_write, key, old_value.unwrap());
return Ok(StampedValue { value, changed_at }); return Ok(StampedValue { value, changed_at });
} }
} }
@ -273,6 +282,7 @@ where
}; };
let mut map_write = self.map.write(); let mut map_write = self.map.write();
self.overwrite_placeholder( self.overwrite_placeholder(
runtime,
&mut map_write, &mut map_write,
key, key,
QueryState::Memoized(Memo { QueryState::Memoized(Memo {
@ -289,6 +299,7 @@ where
fn overwrite_placeholder( fn overwrite_placeholder(
&self, &self,
runtime: &Runtime<DB>,
map_write: &mut FxHashMap<Q::Key, QueryState<DB, Q>>, map_write: &mut FxHashMap<Q::Key, QueryState<DB, Q>>,
key: &Q::Key, key: &Q::Key,
value: QueryState<DB, Q>, value: QueryState<DB, Q>,
@ -296,7 +307,7 @@ where
let old_value = map_write.insert(key.clone(), value); let old_value = map_write.insert(key.clone(), value);
assert!( assert!(
match old_value { match old_value {
Some(QueryState::InProgress) => true, Some(QueryState::InProgress(id)) => id == runtime.id(),
_ => false, _ => false,
}, },
"expected in-progress state", "expected in-progress state",
@ -338,7 +349,8 @@ where
key: &Q::Key, key: &Q::Key,
descriptor: &DB::QueryDescriptor, descriptor: &DB::QueryDescriptor,
) -> bool { ) -> bool {
let revision_now = db.salsa_runtime().current_revision(); let runtime = db.salsa_runtime();
let revision_now = runtime.current_revision();
debug!( debug!(
"{:?}({:?})::maybe_changed_since(revision={:?}, revision_now={:?})", "{:?}({:?})::maybe_changed_since(revision={:?}, revision_now={:?})",
@ -351,7 +363,7 @@ where
let value = { let value = {
let map_read = self.map.upgradable_read(); let map_read = self.map.upgradable_read();
match map_read.get(key) { match map_read.get(key) {
None | Some(QueryState::InProgress) => return true, None | Some(QueryState::InProgress(_)) => return true,
Some(QueryState::Memoized(memo)) => { Some(QueryState::Memoized(memo)) => {
// If our memo is still up to date, then check if we've // If our memo is still up to date, then check if we've
// changed since the revision. // changed since the revision.
@ -372,7 +384,7 @@ where
// If, however, we don't cache values, then optimistically // If, however, we don't cache values, then optimistically
// try to advance `verified_at` by walking the inputs. // try to advance `verified_at` by walking the inputs.
let mut map_write = RwLockUpgradableReadGuard::upgrade(map_read); let mut map_write = RwLockUpgradableReadGuard::upgrade(map_read);
map_write.insert(key.clone(), QueryState::InProgress) map_write.insert(key.clone(), QueryState::InProgress(runtime.id()))
}; };
let mut memo = match value { let mut memo = match value {
@ -382,7 +394,12 @@ where
if memo.verify_inputs(db) { if memo.verify_inputs(db) {
memo.verified_at = revision_now; memo.verified_at = revision_now;
self.overwrite_placeholder(&mut self.map.write(), key, QueryState::Memoized(memo)); self.overwrite_placeholder(
runtime,
&mut self.map.write(),
key,
QueryState::Memoized(memo),
);
return false; return false;
} }
@ -396,7 +413,7 @@ where
let map_read = self.map.read(); let map_read = self.map.read();
match map_read.get(key) { match map_read.get(key) {
None => false, None => false,
Some(QueryState::InProgress) => panic!("query in progress"), Some(QueryState::InProgress(_)) => panic!("query in progress"),
Some(QueryState::Memoized(memo)) => memo.changed_at.is_constant(), Some(QueryState::Memoized(memo)) => memo.changed_at.is_constant(),
} }
} }

View file

@ -45,6 +45,7 @@ where
Self::default() Self::default()
} }
/// Returns the underlying storage, where the keys/values for all queries are kept.
pub fn storage(&self) -> &DB::DatabaseStorage { pub fn storage(&self) -> &DB::DatabaseStorage {
&self.shared_state.storage &self.shared_state.storage
} }
@ -96,7 +97,13 @@ where
} }
} }
#[inline]
pub(crate) fn id(&self) -> RuntimeId {
self.id
}
/// Read current value of the revision counter. /// Read current value of the revision counter.
#[inline]
pub(crate) fn current_revision(&self) -> Revision { pub(crate) fn current_revision(&self) -> Revision {
Revision { Revision {
generation: self.shared_state.revision.load(Ordering::SeqCst) as u64, generation: self.shared_state.revision.load(Ordering::SeqCst) as u64,
@ -111,6 +118,7 @@ where
/// canceled (which indicates that current query results will be /// canceled (which indicates that current query results will be
/// ignored) your query is free to shortcircuit and return /// ignored) your query is free to shortcircuit and return
/// whatever it likes. /// whatever it likes.
#[inline]
pub fn is_current_revision_canceled(&self) -> bool { pub fn is_current_revision_canceled(&self) -> bool {
let pending_revision_increments = self let pending_revision_increments = self
.shared_state .shared_state
@ -354,7 +362,7 @@ impl<DB: Database> ActiveQuery<DB> {
} }
#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] #[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct RuntimeId { pub(crate) struct RuntimeId {
counter: usize, counter: usize,
} }