adopt the Salsa 3.0 Update` trait

Right now, this doesn't change much except the
behavior in the event that `Eq` is not properly
implemented. In the future, it will enable
the use of references and slices and things.
This commit is contained in:
Niko Matsakis 2024-04-10 06:34:36 -04:00
parent 225a81ae8f
commit 4533cd9e4b
7 changed files with 423 additions and 28 deletions

View file

@ -100,24 +100,33 @@ impl TrackedStruct {
// Create the function body that will update the revisions for each field.
// If a field is a "backdate field" (the default), then we first check if
// the new value is `==` to the old value. If so, we leave the revision unchanged.
let old_value = syn::Ident::new("old_value_", Span::call_site());
let new_value = syn::Ident::new("new_value_", Span::call_site());
let old_fields = syn::Ident::new("old_fields_", Span::call_site());
let new_fields = syn::Ident::new("new_fields_", Span::call_site());
let revisions = syn::Ident::new("revisions_", Span::call_site());
let current_revision = syn::Ident::new("current_revision_", Span::call_site());
let update_revisions: TokenStream = self
let update_fields: TokenStream = self
.all_fields()
.zip(0..)
.map(|(field, i)| {
let field_ty = field.ty();
let field_index = Literal::u32_unsuffixed(i);
if field.is_backdate_field() {
quote_spanned! { field.span() =>
if #old_value.#field_index != #new_value.#field_index {
if salsa::update::helper::Dispatch::<#field_ty>::maybe_update(
std::ptr::addr_of_mut!((*#old_fields).#field_index),
#new_fields.#field_index,
) {
#revisions[#field_index] = #current_revision;
}
}
} else {
quote_spanned! { field.span() =>
#revisions[#field_index] = #current_revision;
salsa::update::always_update(
&mut #revisions[#field_index],
#current_revision,
unsafe { &mut (*#old_fields).#field_index },
#new_fields.#field_index,
);
}
}
})
@ -142,13 +151,14 @@ impl TrackedStruct {
[current_revision; #arity]
}
fn update_revisions(
unsafe fn update_fields(
#current_revision: salsa::Revision,
#old_value: &Self::Fields,
#new_value: &Self::Fields,
#revisions: &mut Self::Revisions,
#old_fields: *mut Self::Fields,
#new_fields: Self::Fields,
) {
#update_revisions
use salsa::update::helper::Fallback as _;
#update_fields
}
}
}

View file

@ -23,6 +23,7 @@ pub mod salsa_struct;
pub mod setter;
pub mod storage;
pub mod tracked_struct;
pub mod update;
pub use self::cancelled::Cancelled;
pub use self::cycle::Cycle;

View file

@ -47,16 +47,32 @@ pub trait Configuration {
/// Create a new value revision array where each element is set to `current_revision`.
fn new_revisions(current_revision: Revision) -> Self::Revisions;
/// Update an existing value revision array `revisions`,
/// given the tuple of the old values (`old_value`)
/// and the tuple of the values (`new_value`).
/// If a value has changed, then its element is
/// updated to `current_revision`.
fn update_revisions(
/// Update the field data and, if the value has changed,
/// the appropriate entry in the `revisions` array.
///
/// # Safety requirements and conditions
///
/// Requires the same conditions as the `maybe_update`
/// method on [the `Update` trait](`crate::update::Update`).
///
/// In short, requires that `old_fields` be a pointer into
/// storage from a previous revision.
/// It must meet its validity invariant.
/// Owned content must meet safety invariant.
/// `*mut` here is not strictly needed;
/// it is used to signal that the content
/// is not guaranteed to recursively meet
/// its safety invariant and
/// hence this must be dereferenced with caution.
///
/// Ensures that `old_fields` is fully updated and valid
/// after it returns and that `revisions` has been updated
/// for any field that changed.
unsafe fn update_fields(
current_revision: Revision,
old_value: &Self::Fields,
new_value: &Self::Fields,
revisions: &mut Self::Revisions,
old_fields: *mut Self::Fields,
new_fields: Self::Fields,
);
}
// ANCHOR_END: Configuration
@ -210,19 +226,25 @@ where
} else {
let mut data = self.entity_data.get_mut(&id).unwrap();
let data = &mut *data;
// SAFETY: We assert that the pointer to `data.revisions`
// is a pointer into the database referencing a value
// from a previous revision. As such, it continues to meet
// its validity invariant and any owned content also continues
// to meet its safety invariant.
unsafe {
C::update_fields(
current_revision,
&mut data.revisions,
std::ptr::addr_of_mut!(data.fields),
fields,
);
}
if current_deps.durability < data.durability {
data.revisions = C::new_revisions(current_revision);
} else {
C::update_revisions(current_revision, &data.fields, &fields, &mut data.revisions);
}
data.created_at = current_revision;
data.durability = current_deps.durability;
// Subtle but important: we *always* update the values of the fields,
// even if they are `==` to the old values. This is because the `==`
// operation might not mean tha tthe fields are bitwise equal, and we
// want to take the new value.
data.fields = fields;
}
id

View file

@ -0,0 +1,242 @@
use std::path::PathBuf;
use crate::Revision;
/// This is used by the macro generated code.
/// If possible, uses `Update` trait, but else requires `'static`.
///
/// To use:
///
/// ```rust,ignore
/// use crate::update::helper::Fallback;
/// update::helper::Dispatch::<$ty>::maybe_update(pointer, new_value);
/// ```
///
/// It is important that you specify the `$ty` explicitly.
///
/// This uses the ["method dispatch hack"](https://github.com/nvzqz/impls#how-it-works)
/// to use the `Update` trait if it is available and else fallback to `'static`.
pub mod helper {
use std::marker::PhantomData;
use super::{update_fallback, Update};
pub struct Dispatch<D>(PhantomData<D>);
impl<D> Dispatch<D> {
pub fn new() -> Self {
Dispatch(PhantomData)
}
}
impl<D> Dispatch<D>
where
D: Update,
{
pub unsafe fn maybe_update(old_pointer: *mut D, new_value: D) -> bool {
unsafe { D::maybe_update(old_pointer, new_value) }
}
}
pub unsafe trait Fallback<T> {
/// Same safety conditions as `Update::maybe_update`
unsafe fn maybe_update(old_pointer: *mut T, new_value: T) -> bool;
}
unsafe impl<T: 'static + PartialEq> Fallback<T> for Dispatch<T> {
unsafe fn maybe_update(old_pointer: *mut T, new_value: T) -> bool {
unsafe { update_fallback(old_pointer, new_value) }
}
}
}
/// "Fallback" for maybe-update that is suitable for fully owned T
/// that implement `Eq`. In this version, we update only if the new value
/// is not `Eq` to the old one. Note that given `Eq` impls that are not just
/// structurally comparing fields, this may cause us not to update even if
/// the value has changed (presumably because this change is not semantically
/// significant).
///
/// # Safety
///
/// See `Update::maybe_update`
pub unsafe fn update_fallback<T>(old_pointer: *mut T, new_value: T) -> bool
where
T: 'static + PartialEq,
{
// Because everything is owned, this ref is simply a valid `&mut`
let old_ref: &mut T = unsafe { &mut *old_pointer };
if *old_ref != new_value {
*old_ref = new_value;
true
} else {
// Subtle but important: Eq impls can be buggy or define equality
// in surprising ways. If it says that the value has not changed,
// we do not modify the existing value, and thus do not have to
// update the revision, as downstream code will not see the new value.
false
}
}
/// Helper for generated code. Updates `*old_pointer` with `new_value`
/// and updates `*old_revision` with `new_revision.` Used for fields
/// tagged with `#[no_eq]`
pub fn always_update<T>(
old_revision: &mut Revision,
new_revision: Revision,
old_pointer: &mut T,
new_value: T,
) where
T: 'static,
{
*old_revision = new_revision;
*old_pointer = new_value;
}
/// The `unsafe` on the trait is to assert that `maybe_update` ensures
/// the properties it is intended to ensure.
pub unsafe trait Update {
/// # Returns
///
/// True if the value should be considered to have changed in the new revision.
///
/// # Unsafe contract
///
/// ## Requires
///
/// Informally, requires that `old_value` points to a value in the
/// database that is potentially from a previous revision and `new_value`
/// points to a value produced in this revision.
///
/// More formally, requires that
///
/// * all parameters meet the [validity and safety invariants][i] for their type
/// * `old_value` further points to allocated memory that meets the [validity invariant][i] for `Self`
/// * all data *owned* by `old_value` further meets its safety invariant
/// * not that borrowed data in `old_value` only meets its validity invariant
/// and hence cannot be dereferenced; essentially, a `&T` may point to memory
/// in the database which has been modified or even freed in the newer revision.
///
/// [i]: https://www.ralfj.de/blog/2018/08/22/two-kinds-of-invariants.html
///
/// ## Ensures
///
/// That `old_value` is updated with
unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool;
}
unsafe impl<T> Update for &T {
unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool {
let old_value: *const T = unsafe { *old_pointer };
if old_value != (new_value as *const T) {
unsafe {
*old_pointer = new_value;
}
true
} else {
false
}
}
}
unsafe impl<T> Update for Vec<T>
where
T: Update,
{
unsafe fn maybe_update(old_pointer: *mut Self, new_vec: Self) -> bool {
let old_vec: &mut Vec<T> = unsafe { &mut *old_pointer };
if old_vec.len() != new_vec.len() {
old_vec.clear();
old_vec.extend(new_vec);
return true;
}
let mut changed = false;
for (old_element, new_element) in old_vec.iter_mut().zip(new_vec) {
changed |= T::maybe_update(old_element, new_element);
}
changed
}
}
unsafe impl<T, const N: usize> Update for [T; N]
where
T: Update,
{
unsafe fn maybe_update(old_pointer: *mut Self, new_vec: Self) -> bool {
let old_pointer: *mut T = std::ptr::addr_of_mut!((*old_pointer)[0]);
let mut changed = false;
for (new_element, i) in new_vec.into_iter().zip(0..) {
changed |= T::maybe_update(old_pointer.add(i), new_element);
}
changed
}
}
macro_rules! fallback_impl {
($($t:ty,)*) => {
$(
unsafe impl Update for $t {
unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool {
update_fallback(old_pointer, new_value)
}
}
)*
}
}
fallback_impl! {
String,
i64,
u64,
i32,
u32,
i16,
u16,
i8,
u8,
bool,
f32,
f64,
usize,
isize,
PathBuf,
}
macro_rules! tuple_impl {
($($t:ident),*; $($u:ident),*) => {
unsafe impl<$($t),*> Update for ($($t,)*)
where
$($t: Update,)*
{
#[allow(non_snake_case)]
unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool {
let ($($t,)*) = new_value;
let ($($u,)*) = unsafe { &mut *old_pointer };
let mut changed = false;
$(
unsafe { changed |= Update::maybe_update($u, $t); }
)*
changed
}
}
}
}
// Create implementations for tuples up to arity 12
tuple_impl!(A; a);
tuple_impl!(A, B; a, b);
tuple_impl!(A, B, C; a, b, c);
tuple_impl!(A, B, C, D; a, b, c, d);
tuple_impl!(A, B, C, D, E; a, b, c, d, e);
tuple_impl!(A, B, C, D, E, F; a, b, c, d, e, f);
tuple_impl!(A, B, C, D, E, F, G; a, b, c, d, e, f, g);
tuple_impl!(A, B, C, D, E, F, G, H; a, b, c, d, e, f, g, h);
tuple_impl!(A, B, C, D, E, F, G, H, I; a, b, c, d, e, f, g, h, i);
tuple_impl!(A, B, C, D, E, F, G, H, I, J; a, b, c, d, e, f, g, h, i, j);
tuple_impl!(A, B, C, D, E, F, G, H, I, J, K; a, b, c, d, e, f, g, h, i, j, k);
tuple_impl!(A, B, C, D, E, F, G, H, I, J, K, L; a, b, c, d, e, f, g, h, i, j, k, l);

View file

@ -1,6 +1,4 @@
//! Test a field whose `PartialEq` impl is always true.
//! This can our "last changed" data to be wrong
//! but we *should* always reflect the final values.
//! Test an id field whose `PartialEq` impl is always true.
use test_log::test;

View file

@ -0,0 +1,122 @@
//! Test a field whose `PartialEq` impl is always true.
//! This can result in us getting different results than
//! if we were to execute from scratch.
use expect_test::expect;
use salsa::DebugWithDb;
use salsa_2022_tests::{HasLogger, Logger};
use test_log::test;
#[salsa::jar(db = Db)]
struct Jar(
MyInput,
MyTracked,
the_fn,
make_tracked_struct,
read_tracked_struct,
);
trait Db: salsa::DbWithJar<Jar> {}
#[salsa::input]
struct MyInput {
field: bool,
}
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Eq, Hash, Debug, Clone)]
struct BadEq {
field: bool,
}
impl PartialEq for BadEq {
fn eq(&self, _other: &Self) -> bool {
true
}
}
impl From<bool> for BadEq {
fn from(value: bool) -> Self {
Self { field: value }
}
}
#[salsa::tracked]
struct MyTracked {
field: BadEq,
}
#[salsa::tracked]
fn the_fn(db: &dyn Db, input: MyInput) -> bool {
let tracked = make_tracked_struct(db, input);
read_tracked_struct(db, tracked)
}
#[salsa::tracked]
fn make_tracked_struct(db: &dyn Db, input: MyInput) -> MyTracked {
MyTracked::new(db, BadEq::from(input.field(db)))
}
#[salsa::tracked]
fn read_tracked_struct(db: &dyn Db, tracked: MyTracked) -> bool {
tracked.field(db).field
}
#[salsa::db(Jar)]
#[derive(Default)]
struct Database {
storage: salsa::Storage<Self>,
logger: Logger,
}
impl salsa::Database for Database {
fn salsa_event(&self, event: salsa::Event) {
match event.kind {
salsa::EventKind::WillExecute { .. }
| salsa::EventKind::DidValidateMemoizedValue { .. } => {
self.push_log(format!("salsa_event({:?})", event.kind.debug(self)));
}
_ => {}
}
}
}
impl Db for Database {}
impl HasLogger for Database {
fn logger(&self) -> &Logger {
&self.logger
}
}
#[test]
fn execute() {
let mut db = Database::default();
let input = MyInput::new(&db, true);
let result = the_fn(&db, input);
assert!(result);
db.assert_logs(expect![[r#"
[
"salsa_event(WillExecute { database_key: the_fn(0) })",
"salsa_event(WillExecute { database_key: make_tracked_struct(0) })",
"salsa_event(WillExecute { database_key: read_tracked_struct(0) })",
]"#]]);
// Update the input to `false` and re-execute.
input.set_field(&mut db).to(false);
let result = the_fn(&db, input);
// If the `Eq` impl were working properly, we would
// now return `false`. But because the `Eq` is considered
// equal we re-use memoized results and so we get true.
assert!(result);
db.assert_logs(expect![[r#"
[
"salsa_event(WillExecute { database_key: make_tracked_struct(0) })",
"salsa_event(DidValidateMemoizedValue { database_key: read_tracked_struct(0) })",
"salsa_event(DidValidateMemoizedValue { database_key: the_fn(0) })",
]"#]]);
}