diff --git a/Cargo.toml b/Cargo.toml index ec195ace..c391f02c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,8 @@ repository = "https://github.com/salsa-rs/salsa" description = "A generic framework for on-demand, incrementalized computation (experimental)" [dependencies] +crossbeam-utils = { version = "0.8", default-features = false } +dashmap = "4.0.2" indexmap = "1.0.1" lock_api = "0.4" log = "0.4.5" diff --git a/src/derived.rs b/src/derived.rs index 87ba5976..5228451e 100644 --- a/src/derived.rs +++ b/src/derived.rs @@ -1,6 +1,6 @@ use crate::debug::TableEntry; use crate::durability::Durability; -use crate::hash::FxIndexMap; +use crate::hash::FxDashMap; use crate::lru::Lru; use crate::plumbing::DerivedQueryStorageOps; use crate::plumbing::LruQueryStorageOps; @@ -10,9 +10,8 @@ use crate::plumbing::QueryStorageOps; use crate::runtime::StampedValue; use crate::Runtime; use crate::{Database, DatabaseKeyIndex, QueryDb, Revision}; -use parking_lot::RwLock; +use crossbeam_utils::atomic::AtomicCell; use std::borrow::Borrow; -use std::convert::TryFrom; use std::hash::Hash; use std::marker::PhantomData; use std::sync::Arc; @@ -39,10 +38,23 @@ where { group_index: u16, lru_list: Lru>, - slot_map: RwLock>>>, + indices: AtomicCell, + index_map: FxDashMap, + slot_map: FxDashMap>, policy: PhantomData, } +struct KeySlot +where + Q: QueryFunction, + MP: MemoizationPolicy, +{ + key: Q::Key, + slot: Arc>, +} + +type DerivedKeyIndex = u32; + impl std::panic::RefUnwindSafe for DerivedStorage where Q: QueryFunction, @@ -95,22 +107,52 @@ where Q: QueryFunction, MP: MemoizationPolicy, { - fn slot(&self, key: &Q::Key) -> Arc> { - if let Some(v) = self.slot_map.read().get(key) { - return v.clone(); + fn slot_for_key(&self, key: &Q::Key) -> Arc> { + // Common case: get an existing key + if let Some(v) = self.index_map.get(key) { + let index = *v; + + // release the read-write lock early, for no particular reason + // apart from it bothers me + drop(v); + + return self.slot_for_key_index(index); } - let mut write = self.slot_map.write(); - let entry = write.entry(key.clone()); - let key_index = u32::try_from(entry.index()).unwrap(); - let database_key_index = DatabaseKeyIndex { - group_index: self.group_index, - query_index: Q::QUERY_INDEX, - key_index, - }; - entry - .or_insert_with(|| Arc::new(Slot::new(key.clone(), database_key_index))) - .clone() + // Less common case: (potentially) create a new slot + match self.index_map.entry(key.clone()) { + dashmap::mapref::entry::Entry::Occupied(entry) => self.slot_for_key_index(*entry.get()), + dashmap::mapref::entry::Entry::Vacant(entry) => { + let key_index = self.indices.fetch_add(1); + let database_key_index = DatabaseKeyIndex { + group_index: self.group_index, + query_index: Q::QUERY_INDEX, + key_index, + }; + let slot = Arc::new(Slot::new(key.clone(), database_key_index)); + // Subtle: store the new slot *before* the new index, so that + // other threads only see the new index once the slot is also available. + self.slot_map.insert( + key_index, + KeySlot { + key: key.clone(), + slot: slot.clone(), + }, + ); + entry.insert(key_index); + slot + } + } + } + + fn slot_for_key_index(&self, index: DerivedKeyIndex) -> Arc> { + return self.slot_map.get(&index).unwrap().slot.clone(); + } + + fn slot_for_db_index(&self, index: DatabaseKeyIndex) -> Arc> { + assert_eq!(index.group_index, self.group_index); + assert_eq!(index.query_index, Q::QUERY_INDEX); + self.slot_for_key_index(index.key_index) } } @@ -124,9 +166,11 @@ where fn new(group_index: u16) -> Self { DerivedStorage { group_index, - slot_map: RwLock::new(FxIndexMap::default()), + index_map: Default::default(), + slot_map: Default::default(), lru_list: Default::default(), policy: PhantomData, + indices: Default::default(), } } @@ -138,9 +182,8 @@ where ) -> std::fmt::Result { assert_eq!(index.group_index, self.group_index); assert_eq!(index.query_index, Q::QUERY_INDEX); - let slot_map = self.slot_map.read(); - let key = slot_map.get_index(index.key_index as usize).unwrap().0; - write!(fmt, "{}({:?})", Q::QUERY_NAME, key) + let key_slot = self.slot_map.get(&index.key_index).unwrap(); + write!(fmt, "{}({:?})", Q::QUERY_NAME, key_slot.key) } fn maybe_changed_after( @@ -149,23 +192,15 @@ where input: DatabaseKeyIndex, revision: Revision, ) -> bool { - assert_eq!(input.group_index, self.group_index); - assert_eq!(input.query_index, Q::QUERY_INDEX); debug_assert!(revision < db.salsa_runtime().current_revision()); - let slot = self - .slot_map - .read() - .get_index(input.key_index as usize) - .unwrap() - .1 - .clone(); + let slot = self.slot_for_db_index(input); slot.maybe_changed_after(db, revision) } fn fetch(&self, db: &>::DynDb, key: &Q::Key) -> Q::Value { db.unwind_if_cancelled(); - let slot = self.slot(key); + let slot = self.slot_for_key(key); let StampedValue { value, durability, @@ -187,17 +222,16 @@ where } fn durability(&self, db: &>::DynDb, key: &Q::Key) -> Durability { - self.slot(key).durability(db) + self.slot_for_key(key).durability(db) } fn entries(&self, _db: &>::DynDb) -> C where C: std::iter::FromIterator>, { - let slot_map = self.slot_map.read(); - slot_map - .values() - .filter_map(|slot| slot.as_table_entry()) + self.slot_map + .iter() + .filter_map(|r| r.value().slot.as_table_entry()) .collect() } } @@ -209,7 +243,9 @@ where { fn purge(&self) { self.lru_list.purge(); - *self.slot_map.write() = Default::default(); + self.indices.store(0); + self.index_map.clear(); + self.slot_map.clear(); } } @@ -234,14 +270,12 @@ where Q::Key: Borrow, { runtime.with_incremented_revision(|new_revision| { - let map_read = self.slot_map.read(); - - if let Some(slot) = map_read.get(key) { + if let Some(key_index) = self.index_map.get(key) { + let slot = self.slot_for_key_index(*key_index); if let Some(durability) = slot.invalidate(new_revision) { return Some(durability); } } - None }) } diff --git a/src/hash.rs b/src/hash.rs index 3b2d7df3..4c7d2da7 100644 --- a/src/hash.rs +++ b/src/hash.rs @@ -1,3 +1,4 @@ pub(crate) type FxHasher = std::hash::BuildHasherDefault; pub(crate) type FxIndexSet = indexmap::IndexSet; pub(crate) type FxIndexMap = indexmap::IndexMap; +pub(crate) type FxDashMap = dashmap::DashMap;