diff --git a/Cargo.toml b/Cargo.toml index c7b4947..044733a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,3 +11,4 @@ readme = "README.md" [dependencies] derive-new = "0.5.5" rustc-hash = "1.0" +parking_lot = "0.6.4" \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 7871c7d..ee98fae 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -48,7 +48,7 @@ pub trait BaseQueryContext: Sized { pub trait Query: Debug + Default + Sized + 'static { type Key: Clone + Debug + Hash + Eq + Send; type Value: Clone + Debug + Hash + Eq + Send; - type Storage: QueryStorageOps + Send; + type Storage: QueryStorageOps + Send + Sync; fn execute(query: &QC, key: Self::Key) -> Self::Value; } diff --git a/src/storage.rs b/src/storage.rs index 0b85f30..ce1f2e3 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -3,6 +3,7 @@ use crate::CycleDetected; use crate::Query; use crate::QueryStorageOps; use crate::QueryTable; +use parking_lot::{RwLock, RwLockUpgradableReadGuard}; use rustc_hash::FxHashMap; use std::any::Any; use std::cell::RefCell; @@ -21,7 +22,7 @@ where Q: Query, QC: BaseQueryContext, { - map: RefCell>>, + map: RwLock>>, } /// Defines the "current state" of query's memoized results. @@ -42,7 +43,7 @@ where { fn default() -> Self { MemoizedStorage { - map: RefCell::new(FxHashMap::default()), + map: RwLock::new(FxHashMap::default()), } } } @@ -59,18 +60,16 @@ where descriptor: impl FnOnce() -> QC::QueryDescriptor, ) -> Result { { - let mut map = self.map.borrow_mut(); - match map.entry(key.clone()) { - Entry::Occupied(entry) => { - return match entry.get() { - QueryState::InProgress => Err(CycleDetected), - QueryState::Memoized(value) => Ok(value.clone()), - }; - } - Entry::Vacant(entry) => { - entry.insert(QueryState::InProgress); - } + let map_read = self.map.upgradable_read(); + if let Some(value) = map_read.get(key) { + return match value { + QueryState::InProgress => Err(CycleDetected), + QueryState::Memoized(value) => Ok(value.clone()), + }; } + + let mut map_write = RwLockUpgradableReadGuard::upgrade(map_read); + map_write.insert(key.clone(), QueryState::InProgress); } // If we get here, the query is in progress, and we are the @@ -79,8 +78,8 @@ where let value = query.execute_query_implementation::(descriptor, key); { - let mut map = self.map.borrow_mut(); - let old_value = map.insert(key.clone(), QueryState::Memoized(value.clone())); + let mut map_write = self.map.write(); + let old_value = map_write.insert(key.clone(), QueryState::Memoized(value.clone())); assert!( match old_value { Some(QueryState::InProgress) => true,