diff --git a/components/salsa-2022-macros/src/tracked_struct.rs b/components/salsa-2022-macros/src/tracked_struct.rs index 85afbbe0..5ec21af9 100644 --- a/components/salsa-2022-macros/src/tracked_struct.rs +++ b/components/salsa-2022-macros/src/tracked_struct.rs @@ -100,24 +100,33 @@ impl TrackedStruct { // Create the function body that will update the revisions for each field. // If a field is a "backdate field" (the default), then we first check if // the new value is `==` to the old value. If so, we leave the revision unchanged. - let old_value = syn::Ident::new("old_value_", Span::call_site()); - let new_value = syn::Ident::new("new_value_", Span::call_site()); + let old_fields = syn::Ident::new("old_fields_", Span::call_site()); + let new_fields = syn::Ident::new("new_fields_", Span::call_site()); let revisions = syn::Ident::new("revisions_", Span::call_site()); let current_revision = syn::Ident::new("current_revision_", Span::call_site()); - let update_revisions: TokenStream = self + let update_fields: TokenStream = self .all_fields() .zip(0..) .map(|(field, i)| { + let field_ty = field.ty(); let field_index = Literal::u32_unsuffixed(i); if field.is_backdate_field() { quote_spanned! { field.span() => - if #old_value.#field_index != #new_value.#field_index { + if salsa::update::helper::Dispatch::<#field_ty>::maybe_update( + std::ptr::addr_of_mut!((*#old_fields).#field_index), + #new_fields.#field_index, + ) { #revisions[#field_index] = #current_revision; } } } else { quote_spanned! 
{ field.span() => - #revisions[#field_index] = #current_revision; + salsa::update::always_update( + &mut #revisions[#field_index], + #current_revision, + unsafe { &mut (*#old_fields).#field_index }, + #new_fields.#field_index, + ); } } }) @@ -142,13 +151,14 @@ impl TrackedStruct { [current_revision; #arity] } - fn update_revisions( + unsafe fn update_fields( #current_revision: salsa::Revision, - #old_value: &Self::Fields, - #new_value: &Self::Fields, #revisions: &mut Self::Revisions, + #old_fields: *mut Self::Fields, + #new_fields: Self::Fields, ) { - #update_revisions + use salsa::update::helper::Fallback as _; + #update_fields } } } diff --git a/components/salsa-2022/src/lib.rs b/components/salsa-2022/src/lib.rs index 2a96a9c8..b21b47d2 100644 --- a/components/salsa-2022/src/lib.rs +++ b/components/salsa-2022/src/lib.rs @@ -23,6 +23,7 @@ pub mod salsa_struct; pub mod setter; pub mod storage; pub mod tracked_struct; +pub mod update; pub use self::cancelled::Cancelled; pub use self::cycle::Cycle; diff --git a/components/salsa-2022/src/tracked_struct.rs b/components/salsa-2022/src/tracked_struct.rs index 189e5c5c..012dd11d 100644 --- a/components/salsa-2022/src/tracked_struct.rs +++ b/components/salsa-2022/src/tracked_struct.rs @@ -47,16 +47,32 @@ pub trait Configuration { /// Create a new value revision array where each element is set to `current_revision`. fn new_revisions(current_revision: Revision) -> Self::Revisions; - /// Update an existing value revision array `revisions`, - /// given the tuple of the old values (`old_value`) - /// and the tuple of the values (`new_value`). - /// If a value has changed, then its element is - /// updated to `current_revision`. - fn update_revisions( + /// Update the field data and, if the value has changed, + /// the appropriate entry in the `revisions` array. 
+ /// + /// # Safety requirements and conditions + /// + /// Requires the same conditions as the `maybe_update` + /// method on [the `Update` trait](`crate::update::Update`). + /// + /// In short, requires that `old_fields` be a pointer into + /// storage from a previous revision. + /// It must meet its validity invariant. + /// Owned content must meet its safety invariant. + /// `*mut` here is not strictly needed; + /// it is used to signal that the content + /// is not guaranteed to recursively meet + /// its safety invariant and + /// hence this must be dereferenced with caution. + /// + /// Ensures that `old_fields` is fully updated and valid + /// after it returns and that `revisions` has been updated + /// for any field that changed. + unsafe fn update_fields( current_revision: Revision, - old_value: &Self::Fields, - new_value: &Self::Fields, revisions: &mut Self::Revisions, + old_fields: *mut Self::Fields, + new_fields: Self::Fields, ); } // ANCHOR_END: Configuration @@ -210,19 +226,25 @@ where } else { let mut data = self.entity_data.get_mut(&id).unwrap(); let data = &mut *data; + + // SAFETY: We assert that the pointer to `data.fields` + // is a pointer into the database referencing a value + // from a previous revision. As such, it continues to meet + // its validity invariant and any owned content also continues + // to meet its safety invariant. + unsafe { + C::update_fields( + current_revision, + &mut data.revisions, + std::ptr::addr_of_mut!(data.fields), + fields, + ); + } if current_deps.durability < data.durability { data.revisions = C::new_revisions(current_revision); - } else { - C::update_revisions(current_revision, &data.fields, &fields, &mut data.revisions); } data.created_at = current_revision; data.durability = current_deps.durability; - - // Subtle but important: we *always* update the values of the fields, - // even if they are `==` to the old values. 
This is because the `==` - // operation might not mean tha tthe fields are bitwise equal, and we - // want to take the new value. - data.fields = fields; } id diff --git a/components/salsa-2022/src/update.rs b/components/salsa-2022/src/update.rs new file mode 100644 index 00000000..ad5b6d3e --- /dev/null +++ b/components/salsa-2022/src/update.rs @@ -0,0 +1,242 @@ +use std::path::PathBuf; + +use crate::Revision; + +/// This is used by the macro generated code. +/// If possible, uses `Update` trait, but else requires `'static`. +/// +/// To use: +/// +/// ```rust,ignore +/// use crate::update::helper::Fallback; +/// update::helper::Dispatch::<$ty>::maybe_update(pointer, new_value); +/// ``` +/// +/// It is important that you specify the `$ty` explicitly. +/// +/// This uses the ["method dispatch hack"](https://github.com/nvzqz/impls#how-it-works) +/// to use the `Update` trait if it is available and else fallback to `'static`. +pub mod helper { + use std::marker::PhantomData; + + use super::{update_fallback, Update}; + + pub struct Dispatch(PhantomData); + + impl Dispatch { + pub fn new() -> Self { + Dispatch(PhantomData) + } + } + + impl Dispatch + where + D: Update, + { + pub unsafe fn maybe_update(old_pointer: *mut D, new_value: D) -> bool { + unsafe { D::maybe_update(old_pointer, new_value) } + } + } + + pub unsafe trait Fallback { + /// Same safety conditions as `Update::maybe_update` + unsafe fn maybe_update(old_pointer: *mut T, new_value: T) -> bool; + } + + unsafe impl Fallback for Dispatch { + unsafe fn maybe_update(old_pointer: *mut T, new_value: T) -> bool { + unsafe { update_fallback(old_pointer, new_value) } + } + } +} + +/// "Fallback" for maybe-update that is suitable for fully owned T +/// that implement `Eq`. In this version, we update only if the new value +/// is not `Eq` to the old one. 
Note that given `Eq` impls that are not just +/// structurally comparing fields, this may cause us not to update even if +/// the value has changed (presumably because this change is not semantically +/// significant). +/// +/// # Safety +/// +/// See `Update::maybe_update` +pub unsafe fn update_fallback(old_pointer: *mut T, new_value: T) -> bool +where + T: 'static + PartialEq, +{ + // Because everything is owned, this ref is simply a valid `&mut` + let old_ref: &mut T = unsafe { &mut *old_pointer }; + + if *old_ref != new_value { + *old_ref = new_value; + true + } else { + // Subtle but important: Eq impls can be buggy or define equality + // in surprising ways. If it says that the value has not changed, + // we do not modify the existing value, and thus do not have to + // update the revision, as downstream code will not see the new value. + false + } +} + +/// Helper for generated code. Updates `*old_pointer` with `new_value` +/// and updates `*old_revision` with `new_revision`. Used for fields +/// tagged with `#[no_eq]`. +pub fn always_update( + old_revision: &mut Revision, + new_revision: Revision, + old_pointer: &mut T, + new_value: T, +) where + T: 'static, +{ + *old_revision = new_revision; + *old_pointer = new_value; +} + +/// The `unsafe` on the trait is to assert that `maybe_update` ensures +/// the properties it is intended to ensure. +pub unsafe trait Update { + /// # Returns + /// + /// True if the value should be considered to have changed in the new revision. + /// + /// # Unsafe contract + /// + /// ## Requires + /// + /// Informally, requires that `old_value` points to a value in the + /// database that is potentially from a previous revision and `new_value` + /// points to a value produced in this revision. 
+ /// + /// More formally, requires that + /// + /// * all parameters meet the [validity and safety invariants][i] for their type + /// * `old_value` further points to allocated memory that meets the [validity invariant][i] for `Self` + /// * all data *owned* by `old_value` further meets its safety invariant + /// * note that borrowed data in `old_value` only meets its validity invariant + /// and hence cannot be dereferenced; essentially, a `&T` may point to memory + /// in the database which has been modified or even freed in the newer revision. + /// + /// [i]: https://www.ralfj.de/blog/2018/08/22/two-kinds-of-invariants.html + /// + /// ## Ensures + /// + /// That `old_value` is updated to reflect `new_value` whenever `true` is returned; when `false` is returned it may have been left untouched (the impls in this file only write when they detect a difference). + unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool; +} + +unsafe impl Update for &T { + unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool { + let old_value: *const T = unsafe { *old_pointer }; + if old_value != (new_value as *const T) { // address comparison only; the old pointee is never read here + unsafe { + *old_pointer = new_value; + } + true + } else { + false + } + } +} + +unsafe impl Update for Vec +where + T: Update, +{ + unsafe fn maybe_update(old_pointer: *mut Self, new_vec: Self) -> bool { + let old_vec: &mut Vec = unsafe { &mut *old_pointer }; + + if old_vec.len() != new_vec.len() { + old_vec.clear(); + old_vec.extend(new_vec); + return true; + } + + let mut changed = false; + for (old_element, new_element) in old_vec.iter_mut().zip(new_vec) { + changed |= T::maybe_update(old_element, new_element); + } + + changed + } +} + +unsafe impl Update for [T; N] +where + T: Update, +{ + unsafe fn maybe_update(old_pointer: *mut Self, new_vec: Self) -> bool { + let old_pointer: *mut T = std::ptr::addr_of_mut!((*old_pointer)[0]); // NOTE(review): `[0]` is a checked index, so this presumably panics when `N == 0` -- confirm + let mut changed = false; + for (new_element, i) in new_vec.into_iter().zip(0..) { + changed |= T::maybe_update(old_pointer.add(i), new_element); + } + changed + } +} + +macro_rules! 
fallback_impl { + ($($t:ty,)*) => { + $( + unsafe impl Update for $t { + unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool { + update_fallback(old_pointer, new_value) + } + } + )* + } +} + +fallback_impl! { + String, + i64, + u64, + i32, + u32, + i16, + u16, + i8, + u8, + bool, + f32, + f64, + usize, + isize, + PathBuf, +} + +macro_rules! tuple_impl { + ($($t:ident),*; $($u:ident),*) => { + unsafe impl<$($t),*> Update for ($($t,)*) + where + $($t: Update,)* + { + #[allow(non_snake_case)] + unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool { + let ($($t,)*) = new_value; + let ($($u,)*) = unsafe { &mut *old_pointer }; + + let mut changed = false; + $( + unsafe { changed |= Update::maybe_update($u, $t); } + )* + changed + } + } + } +} + +// Create implementations for tuples up to arity 12 +tuple_impl!(A; a); +tuple_impl!(A, B; a, b); +tuple_impl!(A, B, C; a, b, c); +tuple_impl!(A, B, C, D; a, b, c, d); +tuple_impl!(A, B, C, D, E; a, b, c, d, e); +tuple_impl!(A, B, C, D, E, F; a, b, c, d, e, f); +tuple_impl!(A, B, C, D, E, F, G; a, b, c, d, e, f, g); +tuple_impl!(A, B, C, D, E, F, G, H; a, b, c, d, e, f, g, h); +tuple_impl!(A, B, C, D, E, F, G, H, I; a, b, c, d, e, f, g, h, i); +tuple_impl!(A, B, C, D, E, F, G, H, I, J; a, b, c, d, e, f, g, h, i, j); +tuple_impl!(A, B, C, D, E, F, G, H, I, J, K; a, b, c, d, e, f, g, h, i, j, k); +tuple_impl!(A, B, C, D, E, F, G, H, I, J, K, L; a, b, c, d, e, f, g, h, i, j, k, l); diff --git a/salsa-2022-tests/tests/tracked-struct-field-bad-eq.rs b/salsa-2022-tests/tests/tracked-struct-id-field-bad-eq.rs similarity index 87% rename from salsa-2022-tests/tests/tracked-struct-field-bad-eq.rs rename to salsa-2022-tests/tests/tracked-struct-id-field-bad-eq.rs index bd487a8f..fd9c38a2 100644 --- a/salsa-2022-tests/tests/tracked-struct-field-bad-eq.rs +++ b/salsa-2022-tests/tests/tracked-struct-id-field-bad-eq.rs @@ -1,6 +1,4 @@ -//! Test a field whose `PartialEq` impl is always true. -//! 
This can our "last changed" data to be wrong -//! but we *should* always reflect the final values. +//! Test an id field whose `PartialEq` impl is always true. use test_log::test; diff --git a/salsa-2022-tests/tests/tracked-struct-value-field-bad-eq.rs b/salsa-2022-tests/tests/tracked-struct-value-field-bad-eq.rs new file mode 100644 index 00000000..94651b0f --- /dev/null +++ b/salsa-2022-tests/tests/tracked-struct-value-field-bad-eq.rs @@ -0,0 +1,122 @@ +//! Test a field whose `PartialEq` impl is always true. +//! This can result in us getting different results than +//! if we were to execute from scratch. + +use expect_test::expect; +use salsa::DebugWithDb; +use salsa_2022_tests::{HasLogger, Logger}; +use test_log::test; + +#[salsa::jar(db = Db)] +struct Jar( + MyInput, + MyTracked, + the_fn, + make_tracked_struct, + read_tracked_struct, +); + +trait Db: salsa::DbWithJar {} + +#[salsa::input] +struct MyInput { + field: bool, +} + +#[allow(clippy::derived_hash_with_manual_eq)] +#[derive(Eq, Hash, Debug, Clone)] +struct BadEq { + field: bool, +} + +impl PartialEq for BadEq { + fn eq(&self, _other: &Self) -> bool { + true + } +} + +impl From for BadEq { + fn from(value: bool) -> Self { + Self { field: value } + } +} + +#[salsa::tracked] +struct MyTracked { + field: BadEq, +} + +#[salsa::tracked] +fn the_fn(db: &dyn Db, input: MyInput) -> bool { + let tracked = make_tracked_struct(db, input); + read_tracked_struct(db, tracked) +} + +#[salsa::tracked] +fn make_tracked_struct(db: &dyn Db, input: MyInput) -> MyTracked { + MyTracked::new(db, BadEq::from(input.field(db))) +} + +#[salsa::tracked] +fn read_tracked_struct(db: &dyn Db, tracked: MyTracked) -> bool { + tracked.field(db).field +} + +#[salsa::db(Jar)] +#[derive(Default)] +struct Database { + storage: salsa::Storage, + logger: Logger, +} + +impl salsa::Database for Database { + fn salsa_event(&self, event: salsa::Event) { + match event.kind { + salsa::EventKind::WillExecute { .. 
} + | salsa::EventKind::DidValidateMemoizedValue { .. } => { + self.push_log(format!("salsa_event({:?})", event.kind.debug(self))); + } + _ => {} + } + } +} + +impl Db for Database {} + +impl HasLogger for Database { + fn logger(&self) -> &Logger { + &self.logger + } +} + +#[test] +fn execute() { + let mut db = Database::default(); + + let input = MyInput::new(&db, true); + let result = the_fn(&db, input); + assert!(result); + + db.assert_logs(expect![[r#" + [ + "salsa_event(WillExecute { database_key: the_fn(0) })", + "salsa_event(WillExecute { database_key: make_tracked_struct(0) })", + "salsa_event(WillExecute { database_key: read_tracked_struct(0) })", + ]"#]]); + + // Update the input to `false` and re-execute. + input.set_field(&mut db).to(false); + let result = the_fn(&db, input); + + // If the `Eq` impl were working properly, we would + // now return `false`. But because the `Eq` is considered + // equal we re-use memoized results and so we get true. + assert!(result); + + db.assert_logs(expect![[r#" + [ + "salsa_event(WillExecute { database_key: make_tracked_struct(0) })", + "salsa_event(DidValidateMemoizedValue { database_key: read_tracked_struct(0) })", + "salsa_event(DidValidateMemoizedValue { database_key: the_fn(0) })", + ]"#]]); +} diff --git a/salsa-2022-tests/tests/tracked-struct-field-not-eq.rs b/salsa-2022-tests/tests/tracked-struct-value-field-not-eq.rs similarity index 100% rename from salsa-2022-tests/tests/tracked-struct-field-not-eq.rs rename to salsa-2022-tests/tests/tracked-struct-value-field-not-eq.rs