From 2213729c4e44ce16d455a6d50c96cfb4c64da855 Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Wed, 17 Jul 2024 08:42:06 -0400 Subject: [PATCH] wip --- .../src/setup_interned_struct.rs | 11 ++++++++ .../src/setup_tracked_struct.rs | 25 +++++++++++++------ components/salsa-macros/src/update.rs | 4 +-- src/lib.rs | 2 ++ tests/tracked_struct_db1_lt.rs | 18 ++----------- tests/tracked_with_intern.rs | 10 ++------ tests/tracked_with_struct_db.rs | 17 ++++--------- 7 files changed, 42 insertions(+), 45 deletions(-) diff --git a/components/salsa-macro-rules/src/setup_interned_struct.rs b/components/salsa-macro-rules/src/setup_interned_struct.rs index e69e8055..deae9d76 100644 --- a/components/salsa-macro-rules/src/setup_interned_struct.rs +++ b/components/salsa-macro-rules/src/setup_interned_struct.rs @@ -111,6 +111,17 @@ macro_rules! setup_interned_struct { } } + unsafe impl $zalsa::Update for $Struct<'_> { + unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool { + if unsafe { *old_pointer } != new_value { + unsafe { *old_pointer = new_value }; + true + } else { + false + } + } + } + impl<$db_lt> $Struct<$db_lt> { pub fn $new_fn<$Db>(db: &$db_lt $Db, $($field_id: $field_ty),*) -> Self where diff --git a/components/salsa-macro-rules/src/setup_tracked_struct.rs b/components/salsa-macro-rules/src/setup_tracked_struct.rs index 8e8619a7..937fca52 100644 --- a/components/salsa-macro-rules/src/setup_tracked_struct.rs +++ b/components/salsa-macro-rules/src/setup_tracked_struct.rs @@ -11,7 +11,7 @@ macro_rules! setup_tracked_struct { // Name of the struct Struct: $Struct:ident, - // Name of the `'db` lifetime that the user gave + // Name of the `$db_lt` lifetime that the user gave db_lt: $db_lt:lifetime, // Name user gave for `new` @@ -88,7 +88,7 @@ macro_rules! setup_tracked_struct { type Struct<$db_lt> = $Struct<$db_lt>; - unsafe fn struct_from_raw<'db>(ptr: $NonNull<$zalsa_struct::Value>) -> Self::Struct<'db> { + unsafe fn struct_from_raw<$db_lt>(ptr: $NonNull<$zalsa_struct::Value>) -> Self::Struct<$db_lt> { $Struct(ptr, std::marker::PhantomData) } @@ -104,11 +104,11 @@ macro_rules! setup_tracked_struct { $zalsa::Array::new([current_revision; $N]) } - unsafe fn update_fields<'db>( + unsafe fn update_fields<$db_lt>( current_revision: $Revision, revisions: &mut Self::Revisions, - old_fields: *mut Self::Fields<'db>, - new_fields: Self::Fields<'db>, + old_fields: *mut Self::Fields<$db_lt>, + new_fields: Self::Fields<$db_lt>, ) { use $zalsa::UpdateFallback as _; unsafe { @@ -137,8 +137,8 @@ macro_rules! setup_tracked_struct { } } - impl<'db> $zalsa::LookupId<'db> for $Struct<$db_lt> { - fn lookup_id(id: salsa::Id, db: &'db dyn $zalsa::Database) -> Self { + impl<$db_lt> $zalsa::LookupId<$db_lt> for $Struct<$db_lt> { + fn lookup_id(id: salsa::Id, db: &$db_lt dyn $zalsa::Database) -> Self { $Configuration::ingredient(db).lookup_struct(db.runtime(), id) } } @@ -173,6 +173,17 @@ macro_rules! setup_tracked_struct { } } + unsafe impl $zalsa::Update for $Struct<'_> { + unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool { + if unsafe { *old_pointer } != new_value { + unsafe { *old_pointer = new_value }; + true + } else { + false + } + } + } + impl<$db_lt> $Struct<$db_lt> { pub fn $new_fn<$Db>(db: &$db_lt $Db, $($field_id: $field_ty),*) -> Self where diff --git a/components/salsa-macros/src/update.rs b/components/salsa-macros/src/update.rs index 4bb91035..19112a9e 100644 --- a/components/salsa-macros/src/update.rs +++ b/components/salsa-macros/src/update.rs @@ -76,10 +76,10 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result let ident = &input.ident; let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); let tokens = quote! { - unsafe impl #impl_generics salsa::update::Update for #ident #ty_generics #where_clause { + unsafe impl #impl_generics salsa::Update for #ident #ty_generics #where_clause { unsafe fn maybe_update(#old_pointer: *mut Self, #new_value: Self) -> bool { use ::salsa::plumbing::UpdateFallback as _; - let old_pointer = unsafe { &mut *#old_pointer }; + let #old_pointer = unsafe { &mut *#old_pointer }; match #old_pointer { #fields } diff --git a/src/lib.rs b/src/lib.rs index 98fd6515..614223bc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -38,6 +38,7 @@ pub use self::key::DatabaseKeyIndex; pub use self::revision::Revision; pub use self::runtime::Runtime; pub use self::storage::Storage; +pub use self::update::Update; pub use salsa_macros::accumulator; pub use salsa_macros::db; pub use salsa_macros::input; @@ -88,6 +89,7 @@ pub mod plumbing { pub use crate::update::always_update; pub use crate::update::helper::Dispatch as UpdateDispatch; pub use crate::update::helper::Fallback as UpdateFallback; + pub use crate::update::Update; pub use salsa_macro_rules::macro_if; pub use salsa_macro_rules::maybe_backdate; diff --git a/tests/tracked_struct_db1_lt.rs b/tests/tracked_struct_db1_lt.rs index 97f3b3c3..6931ea7e 100644 --- a/tests/tracked_struct_db1_lt.rs +++ b/tests/tracked_struct_db1_lt.rs @@ -2,15 +2,9 @@ //! compile successfully. mod common; -use common::{HasLogger, Logger}; use test_log::test; -#[salsa::jar(db = Db)] -struct Jar(MyInput, MyTracked1<'_>, MyTracked2<'_>); - -trait Db: salsa::DbWithJar + HasLogger {} - #[salsa::input] struct MyInput { field: u32, @@ -26,22 +20,14 @@ struct MyTracked2<'db2> { field: u32, } -#[salsa::db(Jar)] +#[salsa::db] #[derive(Default)] struct Database { storage: salsa::Storage, - logger: Logger, } +#[salsa::db] impl salsa::Database for Database {} -impl Db for Database {} - -impl HasLogger for Database { - fn logger(&self) -> &Logger { - &self.logger - } -} - #[test] fn create_db() {} diff --git a/tests/tracked_with_intern.rs b/tests/tracked_with_intern.rs index d344ee3e..508a93c1 100644 --- a/tests/tracked_with_intern.rs +++ b/tests/tracked_with_intern.rs @@ -3,21 +3,15 @@ use test_log::test; -#[salsa::jar(db = Db)] -struct Jar(MyInput, MyTracked<'_>, MyInterned<'_>); - -trait Db: salsa::DbWithJar {} - -#[salsa::db(Jar)] +#[salsa::db] #[derive(Default)] struct Database { storage: salsa::Storage, } +#[salsa::db] impl salsa::Database for Database {} -impl Db for Database {} - #[salsa::input] struct MyInput { field: String, diff --git a/tests/tracked_with_struct_db.rs b/tests/tracked_with_struct_db.rs index 737aef25..5a044443 100644 --- a/tests/tracked_with_struct_db.rs +++ b/tests/tracked_with_struct_db.rs @@ -1,24 +1,17 @@ //! Test that a setting a field on a `#[salsa::input]` //! overwrites and returns the old value. -use salsa::DebugWithDb; use test_log::test; -#[salsa::jar(db = Db)] -struct Jar(MyInput, MyTracked<'_>, create_tracked_list); - -trait Db: salsa::DbWithJar {} - -#[salsa::db(Jar)] +#[salsa::db] #[derive(Default)] struct Database { storage: salsa::Storage, } +#[salsa::db] impl salsa::Database for Database {} -impl Db for Database {} - #[salsa::input] struct MyInput { field: String, @@ -30,14 +23,14 @@ struct MyTracked<'db> { next: MyList<'db>, } -#[derive(PartialEq, Eq, Clone, Debug, salsa::Update, salsa::DebugWithDb)] +#[derive(PartialEq, Eq, Clone, Debug, salsa::Update)] enum MyList<'db> { None, Next(MyTracked<'db>), } #[salsa::tracked] -fn create_tracked_list<'db>(db: &'db dyn Db, input: MyInput) -> MyTracked<'db> { +fn create_tracked_list<'db>(db: &'db dyn salsa::Database, input: MyInput) -> MyTracked<'db> { let t0 = MyTracked::new(db, input, MyList::None); let t1 = MyTracked::new(db, input, MyList::Next(t0)); t1 @@ -68,6 +61,6 @@ fn execute() { ), } "#]] - .assert_debug_eq(&t0.debug(&db)); + .assert_debug_eq(&t0); assert_eq!(t0, t1); }