From e3f5eb6ee832d729f74c0178235e5389d3104413 Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Sun, 3 Feb 2019 10:47:18 -0500 Subject: [PATCH] implement `#[salsa::interned]` query storage --- components/salsa-macros/src/query_group.rs | 34 +- src/interned.rs | 448 +++++++++++++++++++++ src/lib.rs | 13 + src/plumbing.rs | 11 + tests/interned.rs | 88 ++++ 5 files changed, 593 insertions(+), 1 deletion(-) create mode 100644 src/interned.rs create mode 100644 tests/interned.rs diff --git a/components/salsa-macros/src/query_group.rs b/components/salsa-macros/src/query_group.rs index 8d8e0baf..3cbe5873 100644 --- a/components/salsa-macros/src/query_group.rs +++ b/components/salsa-macros/src/query_group.rs @@ -60,6 +60,10 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream storage = QueryStorage::Input; num_storages += 1; } + "interned" => { + storage = QueryStorage::Interned; + num_storages += 1; + } "invoke" => { invoke = Some(parse_macro_input!(tts as Parenthesized).0); } @@ -195,6 +199,23 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream }); } + // For interned queries, we need `lookup_foo` + if let QueryStorage::Interned = query.storage { + let lookup_fn_name = Ident::new(&format!("lookup_{}", fn_name), fn_name.span()); + + query_fn_declarations.extend(quote! { + /// Lookup the value(s) interned with a specific key. + fn #lookup_fn_name(&mut self, value: #value) -> (#(#keys),*); + }); + + query_fn_definitions.extend(quote! { + fn #lookup_fn_name(&mut self, value: #value) -> (#(#keys),*) { + >::get_query_table(self) + .lookup(value) + } + }); + } + // A variant for the group descriptor below query_descriptor_variants.extend(quote! 
{ #fn_name((#(#keys),*)), @@ -276,6 +297,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream QueryStorage::Volatile => "VolatileStorage", QueryStorage::Dependencies => "DependencyStorage", QueryStorage::Input => "InputStorage", + QueryStorage::Interned => "InternedStorage", }, Span::call_site(), ); @@ -310,7 +332,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream }); // Implement the QueryFunction trait for all queries except inputs. - if query.storage != QueryStorage::Input { + if query.storage.needs_query_function() { let span = query.fn_name.span(); let key_names: &Vec<_> = &(0..query.keys.len()) .map(|i| Ident::new(&format!("key{}", i), Span::call_site())) @@ -446,4 +468,14 @@ enum QueryStorage { Volatile, Dependencies, Input, + Interned, +} + +impl QueryStorage { + fn needs_query_function(self) -> bool { + match self { + QueryStorage::Input | QueryStorage::Interned => false, + QueryStorage::Memoized | QueryStorage::Volatile | QueryStorage::Dependencies => true, + } + } } diff --git a/src/interned.rs b/src/interned.rs new file mode 100644 index 00000000..e4223d17 --- /dev/null +++ b/src/interned.rs @@ -0,0 +1,448 @@ +use crate::debug::TableEntry; +use crate::plumbing::CycleDetected; +use crate::plumbing::InternedQueryStorageOps; +use crate::plumbing::QueryStorageMassOps; +use crate::plumbing::QueryStorageOps; +use crate::runtime::ChangedAt; +use crate::runtime::Revision; +use crate::runtime::StampedValue; +use crate::Query; +use crate::{Database, DiscardIf, SweepStrategy}; +use parking_lot::RwLock; +use rustc_hash::FxHashMap; +use std::collections::hash_map::Entry; +use std::convert::From; +use std::hash::Hash; + +/// Handles storage for interned queries: each distinct key is assigned a +/// small integer index, and the key can be recovered from that index.
+pub struct InternedStorage +where + Q: Query, + Q::Value: From, + Q::Value: Into, + DB: Database, +{ + tables: RwLock>, +} + +struct InternTables { + /// Map from the key to the corresponding intern-index. + map: FxHashMap, + + /// For each valid intern-index, stores the interned value. When + /// an interned value is GC'd, the entry is set to + /// `InternValue::Free` with the next free item. + values: Vec>, + + /// Index of the first free intern-index, if any. + first_free: Option, +} + +/// Newtype indicating an index into the intern table. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +struct InternIndex { + index: u32, +} + +impl InternIndex { + fn index(self) -> usize { + self.index as usize + } +} + +impl From for InternIndex { + fn from(v: usize) -> Self { + assert!(v < (std::u32::MAX as usize)); + InternIndex { index: v as u32 } + } +} + +enum InternValue { + /// The value has not been gc'd. + Present { + value: K, + + /// When was this intern'd? + /// + /// (This informs the "changed-at" result) + interned_at: Revision, + + /// When was it accessed? 
+ /// + /// (This informs the garbage collector) + accessed_at: Revision, + }, + + /// Free-list entry -- `next` is the index of the next free slot, if any + Free { next: Option }, +} + +impl std::panic::RefUnwindSafe for InternedStorage +where + Q: Query, + DB: Database, + Q::Key: std::panic::RefUnwindSafe, + Q::Value: From, + Q::Value: Into, + Q::Value: std::panic::RefUnwindSafe, +{ +} + +impl Default for InternedStorage +where + Q: Query, + Q::Key: Eq + Hash, + Q::Value: From, + Q::Value: Into, + DB: Database, +{ + fn default() -> Self { + InternedStorage { + tables: RwLock::new(InternTables::default()), + } + } +} + +impl Default for InternTables +where + K: Eq + Hash, +{ + fn default() -> Self { + Self { + map: Default::default(), + values: Default::default(), + first_free: Default::default(), + } + } +} + +impl InternedStorage +where + Q: Query, + Q::Key: Eq + Hash + Clone, + Q::Value: From, + Q::Value: Into, + DB: Database, +{ + fn intern_index(&self, db: &DB, key: &Q::Key) -> StampedValue { + if let Some(i) = self.intern_check(db, key) { + return i; + } + + let owned_key1 = key.to_owned(); + let owned_key2 = owned_key1.clone(); + let revision_now = db.salsa_runtime().current_revision(); + + let mut tables = self.tables.write(); + let tables = &mut *tables; + let entry = match tables.map.entry(owned_key1) { + Entry::Vacant(entry) => entry, + Entry::Occupied(entry) => { + // Somebody inserted this key while we were waiting + // for the write lock. + let index = *entry.get(); + match &tables.values[index.index()] { + InternValue::Present { + value, + interned_at, + accessed_at, + } => { + debug_assert_eq!(owned_key2, *value); + debug_assert_eq!(*accessed_at, revision_now); + return StampedValue { + value: index, + changed_at: ChangedAt { + is_constant: false, + revision: *interned_at, + }, + }; + } + + InternValue::Free { .. 
} => { + panic!("key {:?} should be present but is not", key,); + } + } + } + }; + + let index = match tables.first_free { + None => { + let index = InternIndex::from(tables.values.len()); + tables.values.push(InternValue::Present { + value: owned_key2, + interned_at: revision_now, + accessed_at: revision_now, + }); + index + } + + Some(i) => { + let next_free = match &tables.values[i.index()] { + InternValue::Free { next } => *next, + InternValue::Present { value, .. } => { + panic!( + "index {:?} was supposed to be free but contains {:?}", + i, value + ); + } + }; + + tables.values[i.index()] = InternValue::Present { + value: owned_key2, + interned_at: revision_now, + accessed_at: revision_now, + }; + tables.first_free = next_free; + i + } + }; + + entry.insert(index); + + StampedValue { + value: index, + changed_at: ChangedAt { + is_constant: false, + revision: revision_now, + }, + } + } + + fn intern_check(&self, db: &DB, key: &Q::Key) -> Option> { + let revision_now = db.salsa_runtime().current_revision(); + + // First, + { + let tables = self.tables.read(); + let &index = tables.map.get(key)?; + match &tables.values[index.index()] { + InternValue::Present { + interned_at, + accessed_at, + .. + } => { + if *accessed_at == revision_now { + return Some(StampedValue { + value: index, + changed_at: ChangedAt { + is_constant: false, + revision: *interned_at, + }, + }); + } + } + + InternValue::Free { .. } => { + panic!( + "key {:?} maps to index {:?} is free but should not be", + key, index + ); + } + } + } + + // Next, + let mut tables = self.tables.write(); + let &index = tables.map.get(key)?; + match &mut tables.values[index.index()] { + InternValue::Present { + interned_at, + accessed_at, + .. + } => { + *accessed_at = revision_now; + + return Some(StampedValue { + value: index, + changed_at: ChangedAt { + is_constant: false, + revision: *interned_at, + }, + }); + } + + InternValue::Free { .. 
} => { + panic!( + "key {:?} maps to index {:?} is free but should not be", + key, index + ); + } + } + } + + /// Given an index, lookup and clone its value, updating the + /// `accessed_at` time if necessary. + fn lookup_value(&self, db: &DB, index: u32) -> StampedValue { + let index = index as usize; + let revision_now = db.salsa_runtime().current_revision(); + + { + let tables = self.tables.read(); + match &tables.values[index] { + InternValue::Present { + accessed_at, + interned_at, + value, + } => { + if *accessed_at == revision_now { + return StampedValue { + value: value.clone(), + changed_at: ChangedAt { + is_constant: false, + revision: *interned_at, + }, + }; + } + } + + InternValue::Free { .. } => panic!("lookup of index {:?} found a free slot", index), + } + } + + let mut tables = self.tables.write(); + match &mut tables.values[index] { + InternValue::Present { + accessed_at, + interned_at, + value, + } => { + *accessed_at = revision_now; + + return StampedValue { + value: value.clone(), + changed_at: ChangedAt { + is_constant: false, + revision: *interned_at, + }, + }; + } + + InternValue::Free { .. 
} => panic!("lookup of index {:?} found a free slot", index), + } + } +} + +impl QueryStorageOps for InternedStorage +where + Q: Query, + Q::Key: ToOwned, + ::Owned: Eq + Hash + Clone, + Q::Value: From, + Q::Value: Into, + DB: Database, +{ + fn try_fetch( + &self, + db: &DB, + key: &Q::Key, + database_key: &DB::DatabaseKey, + ) -> Result { + let StampedValue { value, changed_at } = self.intern_index(db, key); + + db.salsa_runtime() + .report_query_read(database_key, changed_at); + + Ok(::from(value.index)) + } + + fn maybe_changed_since( + &self, + db: &DB, + revision: Revision, + key: &Q::Key, + _database_key: &DB::DatabaseKey, + ) -> bool { + match self.intern_check(db, key) { + Some(StampedValue { + value: _, + changed_at, + }) => changed_at.changed_since(revision), + None => true, + } + } + + fn is_constant(&self, _db: &DB, _key: &Q::Key) -> bool { + false + } + + fn entries(&self, _db: &DB) -> C + where + C: std::iter::FromIterator>, + { + let tables = self.tables.read(); + tables + .map + .iter() + .map(|(key, index)| TableEntry::new(key.clone(), Some(::from(index.index)))) + .collect() + } +} + +impl InternedQueryStorageOps for InternedStorage +where + Q: Query, + Q::Key: ToOwned, + ::Owned: Eq + Hash + Clone, + Q::Value: From, + Q::Value: Into, + DB: Database, +{ + fn lookup(&self, db: &DB, value: Q::Value) -> Q::Key { + let index: u32 = value.into(); + let StampedValue { + value, + changed_at: _, + } = self.lookup_value(db, index); + + // XXX -- this setup is wrong, we can't report the read. We + // should create a *second* query that is linked to this query + // somehow. Or, at least, we need a distinct entry in the + // group key so that we can implement the "maybe changed" and + // all that stuff. 
+ + value + } +} + +impl QueryStorageMassOps for InternedStorage +where + Q: Query, + Q::Key: ToOwned, + Q::Value: From, + Q::Value: Into, + DB: Database, +{ + fn sweep(&self, db: &DB, strategy: SweepStrategy) { + let mut tables = self.tables.write(); + let revision_now = db.salsa_runtime().current_revision(); + let InternTables { + map, + values, + first_free, + } = &mut *tables; + map.retain(|key, intern_index| { + let discard = match strategy.discard_if { + DiscardIf::Never => false, + DiscardIf::Outdated => match values[intern_index.index()] { + InternValue::Present { accessed_at, .. } => accessed_at < revision_now, + + InternValue::Free { .. } => { + panic!( + "key {:?} maps to index {:?} which is free", + key, intern_index + ); + } + }, + DiscardIf::Always => false, + }; + + if discard { + values[intern_index.index()] = InternValue::Free { next: *first_free }; + *first_free = Some(*intern_index); + } + + !discard + }); + } +} diff --git a/src/lib.rs b/src/lib.rs index dab0200f..7bf2204e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,6 +10,7 @@ mod derived; mod input; +mod interned; mod runtime; pub mod debug; @@ -20,6 +21,7 @@ pub mod plumbing; use crate::plumbing::CycleDetected; use crate::plumbing::InputQueryStorageOps; +use crate::plumbing::InternedQueryStorageOps; use crate::plumbing::QueryStorageMassOps; use crate::plumbing::QueryStorageOps; use derive_new::new; @@ -462,6 +464,17 @@ where }) } + /// For `#[salsa::interned]` queries only, does a reverse lookup + /// from the interned value (which must be some newtype'd integer) + /// to get the key that was interned. + pub fn lookup(&self, value: Q::Value) -> Q::Key + where + Q::Storage: plumbing::InternedQueryStorageOps, + Q::Value: Into, + { + self.storage.lookup(self.db, value) + } + /// Remove all values for this query that have not been used in /// the most recent revision. 
pub fn sweep(&self, strategy: SweepStrategy) diff --git a/src/plumbing.rs b/src/plumbing.rs index 01a909a9..d668cba9 100644 --- a/src/plumbing.rs +++ b/src/plumbing.rs @@ -13,6 +13,7 @@ pub use crate::derived::DependencyStorage; pub use crate::derived::MemoizedStorage; pub use crate::derived::VolatileStorage; pub use crate::input::InputStorage; +pub use crate::interned::InternedStorage; pub use crate::runtime::Revision; pub struct CycleDetected; @@ -183,6 +184,16 @@ where C: std::iter::FromIterator>; } +/// An optional trait that is implemented for "interned" storage. +pub trait InternedQueryStorageOps: Default +where + DB: Database, + Q: Query, + Q::Value: Into, +{ + fn lookup(&self, db: &DB, value: Q::Value) -> Q::Key; +} + /// An optional trait that is implemented for "user mutable" storage: /// that is, storage whose value is not derived from other storage but /// is set independently. diff --git a/tests/interned.rs b/tests/interned.rs new file mode 100644 index 00000000..ad5e175b --- /dev/null +++ b/tests/interned.rs @@ -0,0 +1,88 @@ +//! Test that `#[salsa::interned]` queries intern their keys and support reverse lookup.
+ +#[salsa::database(InternStorage)] +#[derive(Default)] +struct Database { + runtime: salsa::Runtime, +} + +impl salsa::Database for Database { + fn salsa_runtime(&self) -> &salsa::Runtime { + &self.runtime + } +} + +#[salsa::query_group(InternStorage)] +trait Intern { + #[salsa::interned] + fn intern1(&self, x: String) -> u32; + + #[salsa::interned] + fn intern2(&self, x: String, y: String) -> u32; + + #[salsa::interned] + fn intern_key(&self, x: String) -> InternKey; +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub struct InternKey(u32); + +impl From for InternKey { + fn from(v: u32) -> Self { + InternKey(v) + } +} + +impl Into for InternKey { + fn into(self) -> u32 { + self.0 + } +} + +#[test] +fn test_intern1() { + let mut db = Database::default(); + let foo0 = db.intern1(format!("foo")); + let bar0 = db.intern1(format!("bar")); + let foo1 = db.intern1(format!("foo")); + let bar1 = db.intern1(format!("bar")); + + assert_eq!(foo0, foo1); + assert_eq!(bar0, bar1); + assert_ne!(foo0, bar0); + + assert_eq!(format!("foo"), db.lookup_intern1(foo0)); + assert_eq!(format!("bar"), db.lookup_intern1(bar0)); +} + +#[test] +fn test_intern2() { + let mut db = Database::default(); + let foo0 = db.intern2(format!("x"), format!("foo")); + let bar0 = db.intern2(format!("x"), format!("bar")); + let foo1 = db.intern2(format!("x"), format!("foo")); + let bar1 = db.intern2(format!("x"), format!("bar")); + + assert_eq!(foo0, foo1); + assert_eq!(bar0, bar1); + assert_ne!(foo0, bar0); + + assert_eq!((format!("x"), format!("foo")), db.lookup_intern2(foo0)); + assert_eq!((format!("x"), format!("bar")), db.lookup_intern2(bar0)); +} + +#[test] +fn test_intern_key() { + let mut db = Database::default(); + let foo0 = db.intern_key(format!("foo")); + let bar0 = db.intern_key(format!("bar")); + let foo1 = db.intern_key(format!("foo")); + let bar1 = db.intern_key(format!("bar")); + + assert_eq!(foo0, foo1); + assert_eq!(bar0, bar1); + assert_ne!(foo0, bar0); + + 
assert_eq!(format!("foo"), db.lookup_intern_key(foo0)); + assert_eq!(format!("bar"), db.lookup_intern_key(bar0)); +}