This commit is contained in:
Niko Matsakis 2024-07-17 08:42:06 -04:00
parent daba89c278
commit 2213729c4e
7 changed files with 42 additions and 45 deletions

View file

@ -111,6 +111,17 @@ macro_rules! setup_interned_struct {
}
}
unsafe impl $zalsa::Update for $Struct<'_> {
unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool {
if unsafe { *old_pointer } != new_value {
unsafe { *old_pointer = new_value };
true
} else {
false
}
}
}
impl<$db_lt> $Struct<$db_lt> {
pub fn $new_fn<$Db>(db: &$db_lt $Db, $($field_id: $field_ty),*) -> Self
where

View file

@ -11,7 +11,7 @@ macro_rules! setup_tracked_struct {
// Name of the struct
Struct: $Struct:ident,
// Name of the `'db` lifetime that the user gave
// Name of the `$db_lt` lifetime that the user gave
db_lt: $db_lt:lifetime,
// Name user gave for `new`
@ -88,7 +88,7 @@ macro_rules! setup_tracked_struct {
type Struct<$db_lt> = $Struct<$db_lt>;
unsafe fn struct_from_raw<'db>(ptr: $NonNull<$zalsa_struct::Value<Self>>) -> Self::Struct<'db> {
unsafe fn struct_from_raw<$db_lt>(ptr: $NonNull<$zalsa_struct::Value<Self>>) -> Self::Struct<$db_lt> {
$Struct(ptr, std::marker::PhantomData)
}
@ -104,11 +104,11 @@ macro_rules! setup_tracked_struct {
$zalsa::Array::new([current_revision; $N])
}
unsafe fn update_fields<'db>(
unsafe fn update_fields<$db_lt>(
current_revision: $Revision,
revisions: &mut Self::Revisions,
old_fields: *mut Self::Fields<'db>,
new_fields: Self::Fields<'db>,
old_fields: *mut Self::Fields<$db_lt>,
new_fields: Self::Fields<$db_lt>,
) {
use $zalsa::UpdateFallback as _;
unsafe {
@ -137,8 +137,8 @@ macro_rules! setup_tracked_struct {
}
}
impl<'db> $zalsa::LookupId<'db> for $Struct<$db_lt> {
fn lookup_id(id: salsa::Id, db: &'db dyn $zalsa::Database) -> Self {
impl<$db_lt> $zalsa::LookupId<$db_lt> for $Struct<$db_lt> {
fn lookup_id(id: salsa::Id, db: &$db_lt dyn $zalsa::Database) -> Self {
$Configuration::ingredient(db).lookup_struct(db.runtime(), id)
}
}
@ -173,6 +173,17 @@ macro_rules! setup_tracked_struct {
}
}
unsafe impl $zalsa::Update for $Struct<'_> {
unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool {
if unsafe { *old_pointer } != new_value {
unsafe { *old_pointer = new_value };
true
} else {
false
}
}
}
impl<$db_lt> $Struct<$db_lt> {
pub fn $new_fn<$Db>(db: &$db_lt $Db, $($field_id: $field_ty),*) -> Self
where

View file

@ -76,10 +76,10 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream>
let ident = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let tokens = quote! {
unsafe impl #impl_generics salsa::update::Update for #ident #ty_generics #where_clause {
unsafe impl #impl_generics salsa::Update for #ident #ty_generics #where_clause {
unsafe fn maybe_update(#old_pointer: *mut Self, #new_value: Self) -> bool {
use ::salsa::plumbing::UpdateFallback as _;
let old_pointer = unsafe { &mut *#old_pointer };
let #old_pointer = unsafe { &mut *#old_pointer };
match #old_pointer {
#fields
}

View file

@ -38,6 +38,7 @@ pub use self::key::DatabaseKeyIndex;
pub use self::revision::Revision;
pub use self::runtime::Runtime;
pub use self::storage::Storage;
pub use self::update::Update;
pub use salsa_macros::accumulator;
pub use salsa_macros::db;
pub use salsa_macros::input;
@ -88,6 +89,7 @@ pub mod plumbing {
pub use crate::update::always_update;
pub use crate::update::helper::Dispatch as UpdateDispatch;
pub use crate::update::helper::Fallback as UpdateFallback;
pub use crate::update::Update;
pub use salsa_macro_rules::macro_if;
pub use salsa_macro_rules::maybe_backdate;

View file

@ -2,15 +2,9 @@
//! compile successfully.
mod common;
use common::{HasLogger, Logger};
use test_log::test;
#[salsa::jar(db = Db)]
struct Jar(MyInput, MyTracked1<'_>, MyTracked2<'_>);
trait Db: salsa::DbWithJar<Jar> + HasLogger {}
#[salsa::input]
struct MyInput {
field: u32,
@ -26,22 +20,14 @@ struct MyTracked2<'db2> {
field: u32,
}
#[salsa::db(Jar)]
#[salsa::db]
#[derive(Default)]
struct Database {
storage: salsa::Storage<Self>,
logger: Logger,
}
#[salsa::db]
impl salsa::Database for Database {}
impl Db for Database {}
impl HasLogger for Database {
fn logger(&self) -> &Logger {
&self.logger
}
}
#[test]
fn create_db() {}

View file

@ -3,21 +3,15 @@
use test_log::test;
#[salsa::jar(db = Db)]
struct Jar(MyInput, MyTracked<'_>, MyInterned<'_>);
trait Db: salsa::DbWithJar<Jar> {}
#[salsa::db(Jar)]
#[salsa::db]
#[derive(Default)]
struct Database {
storage: salsa::Storage<Self>,
}
#[salsa::db]
impl salsa::Database for Database {}
impl Db for Database {}
#[salsa::input]
struct MyInput {
field: String,

View file

@ -1,24 +1,17 @@
//! Test that a setting a field on a `#[salsa::input]`
//! overwrites and returns the old value.
use salsa::DebugWithDb;
use test_log::test;
#[salsa::jar(db = Db)]
struct Jar(MyInput, MyTracked<'_>, create_tracked_list);
trait Db: salsa::DbWithJar<Jar> {}
#[salsa::db(Jar)]
#[salsa::db]
#[derive(Default)]
struct Database {
storage: salsa::Storage<Self>,
}
#[salsa::db]
impl salsa::Database for Database {}
impl Db for Database {}
#[salsa::input]
struct MyInput {
field: String,
@ -30,14 +23,14 @@ struct MyTracked<'db> {
next: MyList<'db>,
}
#[derive(PartialEq, Eq, Clone, Debug, salsa::Update, salsa::DebugWithDb)]
#[derive(PartialEq, Eq, Clone, Debug, salsa::Update)]
enum MyList<'db> {
None,
Next(MyTracked<'db>),
}
#[salsa::tracked]
fn create_tracked_list<'db>(db: &'db dyn Db, input: MyInput) -> MyTracked<'db> {
fn create_tracked_list<'db>(db: &'db dyn salsa::Database, input: MyInput) -> MyTracked<'db> {
let t0 = MyTracked::new(db, input, MyList::None);
let t1 = MyTracked::new(db, input, MyList::Next(t0));
t1
@ -68,6 +61,6 @@ fn execute() {
),
}
"#]]
.assert_debug_eq(&t0.debug(&db));
.assert_debug_eq(&t0);
assert_eq!(t0, t1);
}