From e7d704dd8bcb133f8af164a6c786d3577f3fb52a Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Thu, 20 Jun 2019 20:46:51 -0700 Subject: [PATCH] convert `DatabaseSlot` to unsafe trait The unsafe impl now asserts that the `DatabaseSlot` implementor type is indeed `Send+Sync` if `DB::DatabaseData` is `Send+Sync`. Since our query keys/values are a part of database-data, this means that `Slot` must be `Send+Sync` if the key/value are `Send+Sync`. We test this with a function that will cause compliation to fail if we accidentally introduce an `Rc` etc. --- components/salsa-macros/src/query_group.rs | 4 +- src/dependency.rs | 13 ++- src/derived.rs | 2 +- src/derived/slot.rs | 41 ++++++- src/doctest.rs | 121 +++++++++++++++++++++ src/input.rs | 39 ++++++- src/interned.rs | 30 ++++- src/lib.rs | 9 +- 8 files changed, 250 insertions(+), 9 deletions(-) create mode 100644 src/doctest.rs diff --git a/components/salsa-macros/src/query_group.rs b/components/salsa-macros/src/query_group.rs index 6349bf9..ec36306 100644 --- a/components/salsa-macros/src/query_group.rs +++ b/components/salsa-macros/src/query_group.rs @@ -367,7 +367,9 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream #[derive(Default, Debug)] #trait_vis struct #qt; - impl<#db> salsa::Query<#db> for #qt + // Unsafe proof obligation: that our key/value are a part + // of the `GroupData`. + unsafe impl<#db> salsa::Query<#db> for #qt where DB: #trait_name + #requires, DB: salsa::plumbing::HasQueryGroup<#group_struct>, diff --git a/src/dependency.rs b/src/dependency.rs index af61b62..d4cf6be 100644 --- a/src/dependency.rs +++ b/src/dependency.rs @@ -4,8 +4,11 @@ use std::fmt::Debug; use std::hash::Hasher; use std::sync::Arc; -/// Each kind of query exports a "slot". -pub(crate) trait DatabaseSlot: Debug { +/// Unsafe proof obligations: +/// +/// - If `DB::DatabaseData: Send + Sync`, then `Self: Send + Sync` +/// - If `DB: 'static` and `DB::DatabaseData: 'static`, then `Self: 'static` +pub(crate) unsafe trait DatabaseSlot: Debug { /// Returns true if the value of this query may have changed since /// the given revision. fn maybe_changed_since(&self, db: &DB, revision: Revision) -> bool; @@ -13,6 +16,7 @@ pub(crate) trait DatabaseSlot: Debug { pub(crate) struct Dependency { slot: Arc + Send + Sync>, + phantom: std::marker::PhantomData>, } impl Dependency { @@ -22,7 +26,10 @@ impl Dependency { // Hiding these bounds behind a trait object is a total hack // but I just want to see how well this works. -nikomatsakis let slot: Arc + Send + Sync> = unsafe { std::mem::transmute(slot) }; - Self { slot } + Self { + slot, + phantom: std::marker::PhantomData, + } } fn raw_slot(&self) -> *const dyn DatabaseSlot { diff --git a/src/derived.rs b/src/derived.rs index 2a8c2e0..2bf47eb 100644 --- a/src/derived.rs +++ b/src/derived.rs @@ -49,7 +49,7 @@ where { } -pub trait MemoizationPolicy +pub trait MemoizationPolicy: Send + Sync + 'static where Q: QueryFunction, DB: Database, diff --git a/src/derived/slot.rs b/src/derived/slot.rs index 239cd05..b055044 100644 --- a/src/derived/slot.rs +++ b/src/derived/slot.rs @@ -755,7 +755,11 @@ where } } -impl DatabaseSlot for Slot +// The unsafe obligation here is for us to assert that `Slot` is `Send + Sync + 'static`, assuming `Q::Key` and `Q::Value` +// are. We assert this with the `check_send_sync` and `check_static` +// functions below. +unsafe impl DatabaseSlot for Slot where Q: QueryFunction, DB: Database + HasQueryGroup, @@ -933,3 +937,38 @@ where maybe_changed } } + +/// Check that `Slot: Send + Sync` as long as +/// `DB::DatabaseData: Send + Sync`, which in turn implies that +/// `Q::Key: Send + Sync`, `Q::Value: Send + Sync`. +#[allow(dead_code)] +fn check_send_sync() +where + Q: QueryFunction, + DB: Database + HasQueryGroup, + MP: MemoizationPolicy, + DB::DatabaseData: Send + Sync, + Q::Key: Send + Sync, + Q::Value: Send + Sync, +{ + fn is_send_sync() {} + is_send_sync::>(); +} + +/// Check that `Slot: 'static` as long as +/// `DB::DatabaseData: 'static`, which in turn implies that +/// `Q::Key: 'static`, `Q::Value: 'static`. +#[allow(dead_code)] +fn check_static() +where + Q: QueryFunction, + DB: Database + HasQueryGroup, + MP: MemoizationPolicy, + DB: 'static, + DB::DatabaseData: 'static, + Q::Key: 'static, + Q::Value: 'static, +{ + fn is_static() {} + is_static::>(); +} diff --git a/src/doctest.rs b/src/doctest.rs new file mode 100644 index 0000000..ff8de9a --- /dev/null +++ b/src/doctest.rs @@ -0,0 +1,121 @@ +#![allow(dead_code)] + +/// Test that a database with a key/value that is not `Send` will, +/// indeed, not be `Send`. +/// +/// ```compile_fail,E0277 +/// use std::rc::Rc; +/// +/// #[salsa::query_group(NoSendSyncStorage)] +/// trait NoSendSyncDatabase: salsa::Database { +/// fn no_send_sync_value(&self, key: bool) -> Rc; +/// fn no_send_sync_key(&self, key: Rc) -> bool; +/// } +/// +/// fn no_send_sync_value(_db: &impl NoSendSyncDatabase, key: bool) -> Rc { +/// Rc::new(key) +/// } +/// +/// fn no_send_sync_key(_db: &impl NoSendSyncDatabase, key: Rc) -> bool { +/// *key +/// } +/// +/// #[salsa::database(NoSendSyncStorage)] +/// #[derive(Default)] +/// struct DatabaseImpl { +/// runtime: salsa::Runtime, +/// } +/// +/// impl salsa::Database for DatabaseImpl { +/// fn salsa_runtime(&self) -> &salsa::Runtime { +/// &self.runtime +/// } +/// } +/// +/// fn is_send(_: T) { } +/// +/// fn assert_send() { +/// is_send(DatabaseImpl::default()); +/// } +/// ``` +fn test_key_not_send_db_not_send() {} + +/// Test that a database with a key/value that is not `Sync` will not +/// be `Send`. +/// +/// ```compile_fail,E0277 +/// use std::rc::Rc; +/// +/// #[salsa::query_group(NoSendSyncStorage)] +/// trait NoSendSyncDatabase: salsa::Database { +/// fn no_send_sync_value(&self, key: bool) -> Cell; +/// fn no_send_sync_key(&self, key: Cell) -> bool; +/// } +/// +/// fn no_send_sync_value(_db: &impl NoSendSyncDatabase, key: bool) -> Cell { +/// Cell::new(key) +/// } +/// +/// fn no_send_sync_key(_db: &impl NoSendSyncDatabase, key: Cell) -> bool { +/// *key +/// } +/// +/// #[salsa::database(NoSendSyncStorage)] +/// #[derive(Default)] +/// struct DatabaseImpl { +/// runtime: salsa::Runtime, +/// } +/// +/// impl salsa::Database for DatabaseImpl { +/// fn salsa_runtime(&self) -> &salsa::Runtime { +/// &self.runtime +/// } +/// } +/// +/// fn is_send(_: T) { } +/// +/// fn assert_send() { +/// is_send(DatabaseImpl::default()); +/// } +/// ``` +fn test_key_not_sync_db_not_send() {} + +/// Test that a database with a key/value that is not `Sync` will +/// not be `Sync`. +/// +/// ```compile_fail,E0277 +/// use std::rc::Rc; +/// +/// #[salsa::query_group(NoSendSyncStorage)] +/// trait NoSendSyncDatabase: salsa::Database { +/// fn no_send_sync_value(&self, key: bool) -> Cell; +/// fn no_send_sync_key(&self, key: Cell) -> bool; +/// } +/// +/// fn no_send_sync_value(_db: &impl NoSendSyncDatabase, key: bool) -> Cell { +/// Cell::new(key) +/// } +/// +/// fn no_send_sync_key(_db: &impl NoSendSyncDatabase, key: Cell) -> bool { +/// *key +/// } +/// +/// #[salsa::database(NoSendSyncStorage)] +/// #[derive(Default)] +/// struct DatabaseImpl { +/// runtime: salsa::Runtime, +/// } +/// +/// impl salsa::Database for DatabaseImpl { +/// fn salsa_runtime(&self) -> &salsa::Runtime { +/// &self.runtime +/// } +/// } +/// +/// fn is_sync(_: T) { } +/// +/// fn assert_send() { +/// is_sync(DatabaseImpl::default()); +/// } +/// ``` +fn test_key_not_sync_db_not_sync() {} diff --git a/src/input.rs b/src/input.rs index 172e7ad..00c1db0 100644 --- a/src/input.rs +++ b/src/input.rs @@ -201,7 +201,11 @@ where } } -impl DatabaseSlot for Slot +// Unsafe proof obligation: `Slot` is Send + Sync if the query +// key/value is Send + Sync (also, that we introduce no +// references). These are tested by the `check_send_sync` and +// `check_static` helpers below. +unsafe impl DatabaseSlot for Slot where Q: Query, DB: Database, @@ -220,6 +224,39 @@ where } } +/// Check that `Slot: Send + Sync` as long as +/// `DB::DatabaseData: Send + Sync`, which in turn implies that +/// `Q::Key: Send + Sync`, `Q::Value: Send + Sync`. +#[allow(dead_code)] +fn check_send_sync() +where + Q: Query, + DB: Database, + DB::DatabaseData: Send + Sync, + Q::Key: Send + Sync, + Q::Value: Send + Sync, +{ + fn is_send_sync() {} + is_send_sync::>(); +} + +/// Check that `Slot: 'static` as long as +/// `DB::DatabaseData: 'static`, which in turn implies that +/// `Q::Key: 'static`, `Q::Value: 'static`. +#[allow(dead_code)] +fn check_static() +where + Q: Query, + DB: Database, + DB: 'static, + DB::DatabaseData: 'static, + Q::Key: 'static, + Q::Value: 'static, +{ + fn is_static() {} + is_static::>(); +} + impl std::fmt::Debug for Slot where Q: Query, diff --git a/src/interned.rs b/src/interned.rs index 8415a6a..d558046 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -526,7 +526,11 @@ impl Slot { } } -impl DatabaseSlot for Slot +// Unsafe proof obligation: `Slot` is Send + Sync if the query +// key/value is Send + Sync (also, that we introduce no +// references). These are tested by the `check_send_sync` and +// `check_static` helpers below. +unsafe impl DatabaseSlot for Slot where DB: Database, K: Debug, @@ -542,3 +546,27 @@ where } } } + +/// Check that `Slot: Send + Sync` as long as +/// `DB::DatabaseData: Send + Sync`, which in turn implies that +/// `Q::Key: Send + Sync`, `Q::Value: Send + Sync`. +#[allow(dead_code)] +fn check_send_sync() +where + K: Send + Sync, +{ + fn is_send_sync() {} + is_send_sync::>(); +} + +/// Check that `Slot: 'static` as long as +/// `DB::DatabaseData: 'static`, which in turn implies that +/// `Q::Key: 'static`, `Q::Value: 'static`. +#[allow(dead_code)] +fn check_static() +where + K: 'static, +{ + fn is_static() {} + is_static::>(); +} diff --git a/src/lib.rs b/src/lib.rs index 6aa220d..7ea49b1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,6 +10,7 @@ mod dependency; mod derived; +mod doctest; mod input; mod intern_id; mod interned; @@ -404,7 +405,13 @@ where /// Trait implements by all of the "special types" associated with /// each of your queries. -pub trait Query: Debug + Default + Sized + 'static { +/// +/// Unsafe trait obligation: Asserts that the Key/Value associated +/// types for this trait are a part of the `Group::GroupData` type. +/// In particular, `Group::GroupData: Send + Sync` must imply that +/// `Key: Send + Sync` and `Value: Send + Sync`. This is relied upon +/// by the dependency tracking logic. +pub unsafe trait Query: Debug + Default + Sized + 'static { /// Type that you you give as a parameter -- for queries with zero /// or more than one input, this will be a tuple. type Key: Clone + Debug + Hash + Eq;