diff --git a/examples/hello_world/README.md b/examples/hello_world/README.md index 57450f21..506182cf 100644 --- a/examples/hello_world/README.md +++ b/examples/hello_world/README.md @@ -44,7 +44,7 @@ need. The `HelloWorldDatabase` trait has one supertrait: `salsa::Database`. If we were defining more query groups in our application, and we wanted to invoke some of those queries from within -this query group, we might list those query groupes here. You can also +this query group, we might list those query groups here. You can also list any other traits you want, so long as your final database type implements them (this lets you add custom state and so forth to your database). diff --git a/src/derived.rs b/src/derived.rs index 5cca054b..95ff907d 100644 --- a/src/derived.rs +++ b/src/derived.rs @@ -53,6 +53,8 @@ where { fn should_memoize_value(key: &Q::Key) -> bool; + fn memoized_value_eq(old_value: &Q::Value, new_value: &Q::Value) -> bool; + fn should_track_inputs(key: &Q::Key) -> bool; } @@ -60,12 +62,17 @@ pub enum AlwaysMemoizeValue {} impl MemoizationPolicy for AlwaysMemoizeValue where Q: QueryFunction, + Q::Value: Eq, DB: Database, { fn should_memoize_value(_key: &Q::Key) -> bool { true } + fn memoized_value_eq(old_value: &Q::Value, new_value: &Q::Value) -> bool { + old_value == new_value + } + fn should_track_inputs(_key: &Q::Key) -> bool { true } @@ -81,6 +88,10 @@ where false } + fn memoized_value_eq(_old_value: &Q::Value, _new_value: &Q::Value) -> bool { + panic!("cannot reach since we never memoize") + } + fn should_track_inputs(_key: &Q::Key) -> bool { true } @@ -101,6 +112,10 @@ where true } + fn memoized_value_eq(_old_value: &Q::Value, _new_value: &Q::Value) -> bool { + false + } + fn should_track_inputs(_key: &Q::Key) -> bool { false } @@ -243,7 +258,7 @@ where match map.insert(key.clone(), QueryState::in_progress(runtime.id())) { Some(QueryState::Memoized(old_memo)) => Some(old_memo), Some(QueryState::InProgress { .. }) => unreachable!(), - None => None + None => None, } } }; @@ -262,13 +277,7 @@ where let changed_at = memo.changed_at; let new_value = StampedValue { value, changed_at }; - self.overwrite_placeholder( - runtime, - descriptor, - key, - old_memo.unwrap(), - &new_value, - ); + self.overwrite_placeholder(runtime, descriptor, key, old_memo.unwrap(), &new_value); return Ok(new_value); } } @@ -299,9 +308,11 @@ where // "backdate" its `changed_at` revision to be the same as the // old value. if let Some(old_memo) = &old_memo { - if old_memo.value.as_ref() == Some(&stamped_value.value) { - assert!(old_memo.changed_at <= stamped_value.changed_at); - stamped_value.changed_at = old_memo.changed_at; + if let Some(old_value) = &old_memo.value { + if MP::memoized_value_eq(&old_value, &stamped_value.value) { + assert!(old_memo.changed_at <= stamped_value.changed_at); + stamped_value.changed_at = old_memo.changed_at; + } } } diff --git a/src/input.rs b/src/input.rs index 81fff76e..437eca25 100644 --- a/src/input.rs +++ b/src/input.rs @@ -42,6 +42,7 @@ struct IsConstant(bool); impl InputStorage where Q: Query, + Q::Value: Eq, DB: Database, Q::Value: Default, { @@ -137,6 +138,7 @@ where impl QueryStorageOps for InputStorage where Q: Query, + Q::Value: Eq, DB: Database, Q::Value: Default, { @@ -197,6 +199,7 @@ where impl InputQueryStorageOps for InputStorage where Q: Query, + Q::Value: Eq, DB: Database, Q::Value: Default, { diff --git a/src/lib.rs b/src/lib.rs index 1fab6bf0..88433ecf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -56,7 +56,7 @@ pub trait ParallelDatabase: Database + Send { pub trait Query: Debug + Default + Sized + 'static { type Key: Clone + Debug + Hash + Eq; - type Value: Clone + Debug + Hash + Eq; + type Value: Clone + Debug; type Storage: plumbing::QueryStorageOps + Send + Sync; }