From 53a3235a607c6a7b3e998e78408e5b3857e49024 Mon Sep 17 00:00:00 2001 From: puuuuh Date: Sat, 12 Oct 2024 18:57:30 +0300 Subject: [PATCH] Add IngredientIndex to KeyStruct --- src/active_query.rs | 12 ++++++++---- src/tracked_struct.rs | 9 +++++++-- src/zalsa_local.rs | 8 ++++++-- tests/hash_collision.rs | 32 ++++++++++++++++++++++++++++++++ 4 files changed, 53 insertions(+), 8 deletions(-) create mode 100644 tests/hash_collision.rs diff --git a/src/active_query.rs b/src/active_query.rs index 6d524b6..88942ba 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -7,7 +7,7 @@ use crate::{ key::{DatabaseKeyIndex, DependencyIndex}, tracked_struct::{Disambiguator, KeyStruct}, zalsa_local::EMPTY_DEPENDENCIES, - Cycle, Revision, + Cycle, IngredientIndex, Revision, }; use super::zalsa_local::{EdgeKind, QueryEdges, QueryOrigin, QueryRevisions}; @@ -45,7 +45,7 @@ pub(crate) struct ActiveQuery { /// This table starts empty as the query begins and is gradually populated. /// Note that if a query executes in 2 different revisions but creates the same /// set of tracked structs, they will get the same disambiguator values. - disambiguator_map: FxHashMap, + disambiguator_map: FxHashMap<(IngredientIndex, u64), Disambiguator>, /// Map from tracked struct keys (which include the hash + disambiguator) to their /// final id. @@ -155,10 +155,14 @@ impl ActiveQuery { self.input_outputs.clone_from(&cycle_query.input_outputs); } - pub(super) fn disambiguate(&mut self, hash: u64) -> Disambiguator { + pub(super) fn disambiguate( + &mut self, + ingredient_index: IngredientIndex, + hash: u64, + ) -> Disambiguator { let disambiguator = self .disambiguator_map - .entry(hash) + .entry((ingredient_index, hash)) .or_insert(Disambiguator(0)); let result = *disambiguator; disambiguator.0 += 1; diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index be48489..27f27c5 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -148,6 +148,9 @@ where /// struct and later moved to the [`Memo`](`crate::function::memo::Memo`). #[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Copy, Clone)] pub(crate) struct KeyStruct { + /// IngredientIndex of the tracked struct + ingredient_index: IngredientIndex, + /// The hash of the `#[id]` fields of this struct. /// Note that multiple structs may share the same hash. data_hash: u64, @@ -255,11 +258,13 @@ where ) -> C::Struct<'db> { let (zalsa, zalsa_local) = db.zalsas(); - let data_hash = crate::hash::hash(&C::id_fields(&fields)); + let data_hash = crate::hash::hash(&(C::id_fields(&fields))); - let (current_deps, disambiguator) = zalsa_local.disambiguate(data_hash); + let (current_deps, disambiguator) = + zalsa_local.disambiguate(self.ingredient_index, data_hash); let key_struct = KeyStruct { + ingredient_index: self.ingredient_index, disambiguator, data_hash, }; diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index d5e5083..33a9097 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -262,7 +262,11 @@ impl ZalsaLocal { /// * the current dependencies (durability, changed_at) of current query /// * the disambiguator index #[track_caller] - pub(crate) fn disambiguate(&self, data_hash: u64) -> (StampedValue<()>, Disambiguator) { + pub(crate) fn disambiguate( + &self, + ingredient_index: IngredientIndex, + data_hash: u64, + ) -> (StampedValue<()>, Disambiguator) { assert!( self.query_in_progress(), "cannot create a tracked struct disambiguator outside of a tracked function" @@ -270,7 +274,7 @@ impl ZalsaLocal { self.with_query_stack(|stack| { let top_query = stack.last_mut().unwrap(); - let disambiguator = top_query.disambiguate(data_hash); + let disambiguator = top_query.disambiguate(ingredient_index, data_hash); ( StampedValue { value: (), diff --git a/tests/hash_collision.rs b/tests/hash_collision.rs new file mode 100644 index 0000000..4efadfa --- /dev/null +++ b/tests/hash_collision.rs @@ -0,0 +1,32 @@ +use std::hash::Hash; + +#[test] +fn hello() { + use salsa::{Database, DatabaseImpl, Setter}; + + #[salsa::input] + struct Bool { + value: bool, + } + + #[salsa::tracked] + struct True<'db> {} + + #[salsa::tracked] + struct False<'db> {} + + #[salsa::tracked] + fn hello(db: &dyn Database, bool: Bool) { + if bool.value(db) { + True::new(db); + } else { + False::new(db); + } + } + + let mut db = DatabaseImpl::new(); + let input = Bool::new(&db, false); + hello(&db, input); + input.set_value(&mut db).to(true); + hello(&db, input); +}