From 148e1cfecfbc5a8941535f35c44e20494ba05a99 Mon Sep 17 00:00:00 2001 From: Phoebe Szmucer Date: Sat, 20 Jul 2024 19:55:20 +0100 Subject: [PATCH 1/9] Fix accumulator only accumulating direct children --- src/function/accumulated.rs | 49 +++++++++++++++++++++++++-------- tests/accumulate-chain.rs | 54 +++++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+), 11 deletions(-) create mode 100644 tests/accumulate-chain.rs diff --git a/src/function/accumulated.rs b/src/function/accumulated.rs index c5ab7321..eff5734e 100644 --- a/src/function/accumulated.rs +++ b/src/function/accumulated.rs @@ -1,4 +1,6 @@ -use crate::{accumulator, storage::DatabaseGen, Id}; +use std::collections::HashSet; + +use crate::{accumulator, storage::DatabaseGen, DatabaseKeyIndex, Id}; use super::{Configuration, IngredientImpl}; @@ -21,17 +23,42 @@ where // First ensure the result is up to date self.fetch(db, key); - let database_key_index = self.database_key_index(key); - accumulator.produced_by(runtime, database_key_index, &mut output); - - if let Some(origin) = self.origin(key) { - for input in origin.inputs() { - if let Ok(input) = input.try_into() { - accumulator.produced_by(runtime, input, &mut output); - } - } - } + // Recursively accumulate outputs from children + self.database_key_index(key).traverse_children::( + db, + &mut |query| accumulator.produced_by(runtime, query, &mut output), + &mut HashSet::new(), + ); output } } + +impl DatabaseKeyIndex { + pub fn traverse_children( + &self, + db: &C::DbView, + handler: &mut F, + visited: &mut HashSet, + ) where + C: Configuration, + F: (FnMut(DatabaseKeyIndex)), + { + handler(*self); + visited.insert(*self); + + let origin = db + .lookup_ingredient(self.ingredient_index) + .origin(self.key_index); + + if let Some(origin) = origin { + for input in origin.inputs() { + if let Ok(input) = TryInto::::try_into(input) { + if !visited.contains(&input) { + input.traverse_children::(db, handler, visited); + } + } + } + } + } +} diff --git a/tests/accumulate-chain.rs b/tests/accumulate-chain.rs new file mode 100644 index 00000000..aa3a8ea8 --- /dev/null +++ b/tests/accumulate-chain.rs @@ -0,0 +1,54 @@ +mod common; + +use expect_test::expect; +use salsa::{Accumulator, Database}; +use test_log::test; + +#[salsa::accumulator] +struct Log(#[allow(dead_code)] String); + +#[salsa::tracked] +fn push_logs(db: &dyn Database) { + push_a_logs(db); +} + +#[salsa::tracked] +fn push_a_logs(db: &dyn Database) { + Log("log a".to_string()).accumulate(db); + push_b_logs(db); +} + +#[salsa::tracked] +fn push_b_logs(db: &dyn Database) { + // No logs + push_c_logs(db); +} + +#[salsa::tracked] +fn push_c_logs(db: &dyn Database) { + // No logs + push_d_logs(db); +} + +#[salsa::tracked] +fn push_d_logs(db: &dyn Database) { + Log("log d".to_string()).accumulate(db); +} + +#[test] +fn accumulate_chain() { + salsa::default_database().attach(|db| { + let logs = push_logs::accumulated::(db); + // Check that we don't see logs from `a` appearing twice in the input. + expect![[r#" + [ + Log( + "log a", + ), + Log( + "log d", + ), + ]"#]] + .assert_eq(&format!("{:#?}", logs)); + }) +} From beedbd18e50cc18068a83ef0cac403918399b14b Mon Sep 17 00:00:00 2001 From: Phoebe Szmucer Date: Sun, 21 Jul 2024 18:18:02 +0100 Subject: [PATCH 2/9] Refactor --- src/function/accumulated.rs | 48 +++++++++++++------------------------ 1 file changed, 16 insertions(+), 32 deletions(-) diff --git a/src/function/accumulated.rs b/src/function/accumulated.rs index eff5734e..2bfcbf9d 100644 --- a/src/function/accumulated.rs +++ b/src/function/accumulated.rs @@ -1,6 +1,6 @@ -use std::collections::HashSet; +use std::collections::VecDeque; -use crate::{accumulator, storage::DatabaseGen, DatabaseKeyIndex, Id}; +use crate::{accumulator, hash::FxHashSet, storage::DatabaseGen, DatabaseKeyIndex, Id}; use super::{Configuration, IngredientImpl}; @@ -23,42 +23,26 @@ where // First ensure the result is up to date self.fetch(db, key); - // Recursively accumulate outputs from children - self.database_key_index(key).traverse_children::( - db, - &mut |query| accumulator.produced_by(runtime, query, &mut output), - &mut HashSet::new(), - ); + let db_key = self.database_key_index(key); + let mut visited: FxHashSet = std::iter::once(db_key).collect(); + let mut stack = VecDeque::new(); + stack.push_front(db_key); - output - } -} + while let Some(k) = stack.pop_front() { + accumulator.produced_by(runtime, k, &mut output); -impl DatabaseKeyIndex { - pub fn traverse_children( - &self, - db: &C::DbView, - handler: &mut F, - visited: &mut HashSet, - ) where - C: Configuration, - F: (FnMut(DatabaseKeyIndex)), - { - handler(*self); - visited.insert(*self); + let origin = db.lookup_ingredient(k.ingredient_index).origin(k.key_index); + let inputs = origin.iter().flat_map(|origin| origin.inputs()); - let origin = db - .lookup_ingredient(self.ingredient_index) - .origin(self.key_index); - - if let Some(origin) = origin { - for input in origin.inputs() { - if let Ok(input) = TryInto::::try_into(input) { - if !visited.contains(&input) { - input.traverse_children::(db, handler, visited); + for input in inputs { + if let Ok(input) = input.try_into() { + if visited.insert(input) { + stack.push_back(input); } } } } + + output } } From a20c6341ec762121c4d2f59ebe8a278ddeb33713 Mon Sep 17 00:00:00 2001 From: Phoebe Szmucer Date: Sun, 21 Jul 2024 18:20:26 +0100 Subject: [PATCH 3/9] Use Vec --- src/function/accumulated.rs | 9 +++----- tests/accumulate-dag.rs | 12 +++++------ tests/accumulate.rs | 42 ++++++++++++++++++------------------- 3 files changed, 30 insertions(+), 33 deletions(-) diff --git a/src/function/accumulated.rs b/src/function/accumulated.rs index 2bfcbf9d..51b16fa3 100644 --- a/src/function/accumulated.rs +++ b/src/function/accumulated.rs @@ -1,5 +1,3 @@ -use std::collections::VecDeque; - use crate::{accumulator, hash::FxHashSet, storage::DatabaseGen, DatabaseKeyIndex, Id}; use super::{Configuration, IngredientImpl}; @@ -25,10 +23,9 @@ where let db_key = self.database_key_index(key); let mut visited: FxHashSet = std::iter::once(db_key).collect(); - let mut stack = VecDeque::new(); - stack.push_front(db_key); + let mut stack = vec![db_key]; - while let Some(k) = stack.pop_front() { + while let Some(k) = stack.pop() { accumulator.produced_by(runtime, k, &mut output); let origin = db.lookup_ingredient(k.ingredient_index).origin(k.key_index); @@ -37,7 +34,7 @@ where for input in inputs { if let Ok(input) = input.try_into() { if visited.insert(input) { - stack.push_back(input); + stack.push(input); } } } diff --git a/tests/accumulate-dag.rs b/tests/accumulate-dag.rs index d0c0cfeb..d7decab7 100644 --- a/tests/accumulate-dag.rs +++ b/tests/accumulate-dag.rs @@ -45,12 +45,6 @@ fn accumulate_a_called_twice() { // Check that we don't see logs from `a` appearing twice in the input. expect![[r#" [ - Log( - "log_a(0 of 2)", - ), - Log( - "log_a(1 of 2)", - ), Log( "log_b(0 of 3)", ), @@ -60,6 +54,12 @@ fn accumulate_a_called_twice() { Log( "log_b(2 of 3)", ), + Log( + "log_a(0 of 2)", + ), + Log( + "log_a(1 of 2)", + ), ]"#]] .assert_eq(&format!("{:#?}", logs)); }) diff --git a/tests/accumulate.rs b/tests/accumulate.rs index 3b6ce192..0387bac3 100644 --- a/tests/accumulate.rs +++ b/tests/accumulate.rs @@ -94,12 +94,6 @@ fn accumulate_once() { // (execution order). expect![[r#" [ - Log( - "log_a(0 of 2)", - ), - Log( - "log_a(1 of 2)", - ), Log( "log_b(0 of 3)", ), @@ -109,6 +103,12 @@ fn accumulate_once() { Log( "log_b(2 of 3)", ), + Log( + "log_a(0 of 2)", + ), + Log( + "log_a(1 of 2)", + ), ]"#]] .assert_eq(&format!("{:#?}", logs)); } @@ -122,12 +122,6 @@ fn change_a_from_2_to_0() { let logs = push_logs::accumulated::(&db, input); expect![[r#" [ - Log( - "log_a(0 of 2)", - ), - Log( - "log_a(1 of 2)", - ), Log( "log_b(0 of 3)", ), @@ -137,6 +131,12 @@ fn change_a_from_2_to_0() { Log( "log_b(2 of 3)", ), + Log( + "log_a(0 of 2)", + ), + Log( + "log_a(1 of 2)", + ), ]"#]] .assert_eq(&format!("{:#?}", logs)); db.assert_logs(expect![[r#" @@ -177,12 +177,6 @@ fn change_a_from_2_to_1() { let logs = push_logs::accumulated::(&db, input); expect![[r#" [ - Log( - "log_a(0 of 2)", - ), - Log( - "log_a(1 of 2)", - ), Log( "log_b(0 of 3)", ), @@ -192,6 +186,12 @@ fn change_a_from_2_to_1() { Log( "log_b(2 of 3)", ), + Log( + "log_a(0 of 2)", + ), + Log( + "log_a(1 of 2)", + ), ]"#]] .assert_eq(&format!("{:#?}", logs)); db.assert_logs(expect![[r#" @@ -206,9 +206,6 @@ fn change_a_from_2_to_1() { let logs = push_logs::accumulated::(&db, input); expect![[r#" [ - Log( - "log_a(0 of 1)", - ), Log( "log_b(0 of 3)", ), @@ -218,6 +215,9 @@ fn change_a_from_2_to_1() { Log( "log_b(2 of 3)", ), + Log( + "log_a(0 of 1)", + ), ]"#]] .assert_eq(&format!("{:#?}", logs)); db.assert_logs(expect![[r#" From 4543063f5a160f0aecac8a7232483fa7404d8da7 Mon Sep 17 00:00:00 2001 From: Phoebe Szmucer Date: Mon, 22 Jul 2024 10:19:36 +0100 Subject: [PATCH 4/9] use a queue --- src/function/accumulated.rs | 7 ++++--- tests/accumulate-dag.rs | 12 +++++------ tests/accumulate.rs | 42 ++++++++++++++++++------------------- 3 files changed, 31 insertions(+), 30 deletions(-) diff --git a/src/function/accumulated.rs b/src/function/accumulated.rs index 51b16fa3..5cd3ef92 100644 --- a/src/function/accumulated.rs +++ b/src/function/accumulated.rs @@ -1,4 +1,5 @@ use crate::{accumulator, hash::FxHashSet, storage::DatabaseGen, DatabaseKeyIndex, Id}; +use std::collections::VecDeque; use super::{Configuration, IngredientImpl}; @@ -23,9 +24,9 @@ where let db_key = self.database_key_index(key); let mut visited: FxHashSet = std::iter::once(db_key).collect(); - let mut stack = vec![db_key]; + let mut queue: VecDeque = std::iter::once(db_key).collect(); - while let Some(k) = stack.pop() { + while let Some(k) = queue.pop_front() { accumulator.produced_by(runtime, k, &mut output); let origin = db.lookup_ingredient(k.ingredient_index).origin(k.key_index); @@ -34,7 +35,7 @@ where for input in inputs { if let Ok(input) = input.try_into() { if visited.insert(input) { - stack.push(input); + queue.push_back(input); } } } diff --git a/tests/accumulate-dag.rs b/tests/accumulate-dag.rs index d7decab7..d0c0cfeb 100644 --- a/tests/accumulate-dag.rs +++ b/tests/accumulate-dag.rs @@ -45,6 +45,12 @@ fn accumulate_a_called_twice() { // Check that we don't see logs from `a` appearing twice in the input. expect![[r#" [ + Log( + "log_a(0 of 2)", + ), + Log( + "log_a(1 of 2)", + ), Log( "log_b(0 of 3)", ), @@ -54,12 +60,6 @@ fn accumulate_a_called_twice() { Log( "log_b(2 of 3)", ), - Log( - "log_a(0 of 2)", - ), - Log( - "log_a(1 of 2)", - ), ]"#]] .assert_eq(&format!("{:#?}", logs)); }) diff --git a/tests/accumulate.rs b/tests/accumulate.rs index 0387bac3..3b6ce192 100644 --- a/tests/accumulate.rs +++ b/tests/accumulate.rs @@ -94,6 +94,12 @@ fn accumulate_once() { // (execution order). expect![[r#" [ + Log( + "log_a(0 of 2)", + ), + Log( + "log_a(1 of 2)", + ), Log( "log_b(0 of 3)", ), @@ -103,12 +109,6 @@ fn accumulate_once() { Log( "log_b(2 of 3)", ), - Log( - "log_a(0 of 2)", - ), - Log( - "log_a(1 of 2)", - ), ]"#]] .assert_eq(&format!("{:#?}", logs)); } @@ -122,6 +122,12 @@ fn change_a_from_2_to_0() { let logs = push_logs::accumulated::(&db, input); expect![[r#" [ + Log( + "log_a(0 of 2)", + ), + Log( + "log_a(1 of 2)", + ), Log( "log_b(0 of 3)", ), @@ -131,12 +137,6 @@ fn change_a_from_2_to_0() { Log( "log_b(2 of 3)", ), - Log( - "log_a(0 of 2)", - ), - Log( - "log_a(1 of 2)", - ), ]"#]] .assert_eq(&format!("{:#?}", logs)); db.assert_logs(expect![[r#" @@ -177,6 +177,12 @@ fn change_a_from_2_to_1() { let logs = push_logs::accumulated::(&db, input); expect![[r#" [ + Log( + "log_a(0 of 2)", + ), + Log( + "log_a(1 of 2)", + ), Log( "log_b(0 of 3)", ), @@ -186,12 +192,6 @@ fn change_a_from_2_to_1() { Log( "log_b(2 of 3)", ), - Log( - "log_a(0 of 2)", - ), - Log( - "log_a(1 of 2)", - ), ]"#]] .assert_eq(&format!("{:#?}", logs)); db.assert_logs(expect![[r#" @@ -206,6 +206,9 @@ fn change_a_from_2_to_1() { let logs = push_logs::accumulated::(&db, input); expect![[r#" [ + Log( + "log_a(0 of 1)", + ), Log( "log_b(0 of 3)", ), @@ -215,9 +218,6 @@ fn change_a_from_2_to_1() { Log( "log_b(2 of 3)", ), - Log( - "log_a(0 of 1)", - ), ]"#]] .assert_eq(&format!("{:#?}", logs)); db.assert_logs(expect![[r#" From 2d490e245a742a9fb0e87e296c0724b25769e541 Mon Sep 17 00:00:00 2001 From: Phoebe Szmucer Date: Mon, 22 Jul 2024 10:21:16 +0100 Subject: [PATCH 5/9] newline --- src/function/accumulated.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/function/accumulated.rs b/src/function/accumulated.rs index 5cd3ef92..5a269c7f 100644 --- a/src/function/accumulated.rs +++ b/src/function/accumulated.rs @@ -1,6 +1,7 @@ -use crate::{accumulator, hash::FxHashSet, storage::DatabaseGen, DatabaseKeyIndex, Id}; use std::collections::VecDeque; +use crate::{accumulator, hash::FxHashSet, storage::DatabaseGen, DatabaseKeyIndex, Id}; + use super::{Configuration, IngredientImpl}; impl IngredientImpl From c9f22f108acbbd7d12b537ac95404590b2081960 Mon Sep 17 00:00:00 2001 From: Phoebe Szmucer Date: Mon, 22 Jul 2024 11:28:54 +0100 Subject: [PATCH 6/9] Use a stack and push to it in reverse order --- src/function/accumulated.rs | 15 +++++++-------- src/runtime/local_state.rs | 8 ++++---- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/src/function/accumulated.rs b/src/function/accumulated.rs index 5a269c7f..7d54a9f8 100644 --- a/src/function/accumulated.rs +++ b/src/function/accumulated.rs @@ -1,5 +1,3 @@ -use std::collections::VecDeque; - use crate::{accumulator, hash::FxHashSet, storage::DatabaseGen, DatabaseKeyIndex, Id}; use super::{Configuration, IngredientImpl}; @@ -24,19 +22,20 @@ where self.fetch(db, key); let db_key = self.database_key_index(key); - let mut visited: FxHashSet = std::iter::once(db_key).collect(); - let mut queue: VecDeque = std::iter::once(db_key).collect(); + let mut visited: FxHashSet = FxHashSet::default(); + let mut stack: Vec = vec![db_key]; - while let Some(k) = queue.pop_front() { + while let Some(k) = stack.pop() { + visited.insert(k); accumulator.produced_by(runtime, k, &mut output); let origin = db.lookup_ingredient(k.ingredient_index).origin(k.key_index); let inputs = origin.iter().flat_map(|origin| origin.inputs()); - for input in inputs { + for input in inputs.rev() { if let Ok(input) = input.try_into() { - if visited.insert(input) { - queue.push_back(input); + if !visited.contains(&input) { + stack.push(input); } } } diff --git a/src/runtime/local_state.rs b/src/runtime/local_state.rs index 2213bf1f..2619d85a 100644 --- a/src/runtime/local_state.rs +++ b/src/runtime/local_state.rs @@ -77,7 +77,7 @@ pub enum QueryOrigin { impl QueryOrigin { /// Indices for queries *read* by this query - pub(crate) fn inputs(&self) -> impl Iterator + '_ { + pub(crate) fn inputs(&self) -> impl DoubleEndedIterator + '_ { let opt_edges = match self { QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges), QueryOrigin::Assigned(_) | QueryOrigin::BaseInput => None, @@ -86,7 +86,7 @@ impl QueryOrigin { } /// Indices for queries *written* by this query (if any) - pub(crate) fn outputs(&self) -> impl Iterator + '_ { + pub(crate) fn outputs(&self) -> impl DoubleEndedIterator + '_ { let opt_edges = match self { QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges), QueryOrigin::Assigned(_) | QueryOrigin::BaseInput => None, @@ -127,7 +127,7 @@ impl QueryEdges { /// Returns the (tracked) inputs that were executed in computing this memoized value. /// /// These will always be in execution order. - pub(crate) fn inputs(&self) -> impl Iterator + '_ { + pub(crate) fn inputs(&self) -> impl DoubleEndedIterator + '_ { self.input_outputs .iter() .filter(|(edge_kind, _)| *edge_kind == EdgeKind::Input) @@ -137,7 +137,7 @@ impl QueryEdges { /// Returns the (tracked) outputs that were executed in computing this memoized value. /// /// These will always be in execution order. - pub(crate) fn outputs(&self) -> impl Iterator + '_ { + pub(crate) fn outputs(&self) -> impl DoubleEndedIterator + '_ { self.input_outputs .iter() .filter(|(edge_kind, _)| *edge_kind == EdgeKind::Output) From 02008d51a76237b83aa105cbaed4e008bad5817f Mon Sep 17 00:00:00 2001 From: Phoebe Szmucer Date: Mon, 22 Jul 2024 11:52:55 +0100 Subject: [PATCH 7/9] Add a test, fix a bug, refactor --- src/function/accumulated.rs | 20 +++++----- tests/accumulate-chain.rs | 2 +- tests/accumulate-execution-order.rs | 61 +++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 12 deletions(-) create mode 100644 tests/accumulate-execution-order.rs diff --git a/src/function/accumulated.rs b/src/function/accumulated.rs index 7d54a9f8..f0f04500 100644 --- a/src/function/accumulated.rs +++ b/src/function/accumulated.rs @@ -26,18 +26,16 @@ where let mut stack: Vec = vec![db_key]; while let Some(k) = stack.pop() { - visited.insert(k); - accumulator.produced_by(runtime, k, &mut output); + if visited.insert(k) { + accumulator.produced_by(runtime, k, &mut output); - let origin = db.lookup_ingredient(k.ingredient_index).origin(k.key_index); - let inputs = origin.iter().flat_map(|origin| origin.inputs()); - - for input in inputs.rev() { - if let Ok(input) = input.try_into() { - if !visited.contains(&input) { - stack.push(input); - } - } + let origin = db.lookup_ingredient(k.ingredient_index).origin(k.key_index); + let inputs = origin.iter().flat_map(|origin| origin.inputs()); + stack.extend( + inputs + .flat_map(|input| TryInto::::try_into(input).into_iter()) + .rev(), + ); } } diff --git a/tests/accumulate-chain.rs b/tests/accumulate-chain.rs index aa3a8ea8..7cf3d3b3 100644 --- a/tests/accumulate-chain.rs +++ b/tests/accumulate-chain.rs @@ -39,7 +39,7 @@ fn push_d_logs(db: &dyn Database) { fn accumulate_chain() { salsa::default_database().attach(|db| { let logs = push_logs::accumulated::(db); - // Check that we don't see logs from `a` appearing twice in the input. + // Check that we get all the logs. expect![[r#" [ Log( diff --git a/tests/accumulate-execution-order.rs b/tests/accumulate-execution-order.rs new file mode 100644 index 00000000..ddc2e023 --- /dev/null +++ b/tests/accumulate-execution-order.rs @@ -0,0 +1,61 @@ +mod common; + +use expect_test::expect; +use salsa::{Accumulator, Database}; +use test_log::test; + +#[salsa::accumulator] +struct Log(#[allow(dead_code)] String); + +#[salsa::tracked] +fn push_logs(db: &dyn Database) { + push_a_logs(db); +} + +#[salsa::tracked] +fn push_a_logs(db: &dyn Database) { + Log("log a".to_string()).accumulate(db); + push_b_logs(db); + push_c_logs(db); + push_d_logs(db); +} + +#[salsa::tracked] +fn push_b_logs(db: &dyn Database) { + Log("log b".to_string()).accumulate(db); + push_d_logs(db); +} + +#[salsa::tracked] +fn push_c_logs(db: &dyn Database) { + Log("log c".to_string()).accumulate(db); +} + +#[salsa::tracked] +fn push_d_logs(db: &dyn Database) { + Log("log d".to_string()).accumulate(db); +} + +#[test] +fn accumulate_chain() { + salsa::default_database().attach(|db| { + let logs = push_logs::accumulated::(db); + // Check that we get logs in execution order + expect![[r#" + [ + Log( + "log a", + ), + Log( + "log b", + ), + Log( + "log d", + ), + Log( + "log c", + ), + ]"#]] + .assert_eq(&format!("{:#?}", logs)); + }) +} From 49a147a6279d65762d14cc5675f29408e1160d40 Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Mon, 22 Jul 2024 06:59:39 -0400 Subject: [PATCH 8/9] Update src/function/accumulated.rs --- src/function/accumulated.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/function/accumulated.rs b/src/function/accumulated.rs index f0f04500..ea457e23 100644 --- a/src/function/accumulated.rs +++ b/src/function/accumulated.rs @@ -31,6 +31,9 @@ where let origin = db.lookup_ingredient(k.ingredient_index).origin(k.key_index); let inputs = origin.iter().flat_map(|origin| origin.inputs()); + // Careful: we want to push in execution order, so reverse order to + // ensure the first child that was executed will be the first child popped + // from the stack. stack.extend( inputs .flat_map(|input| TryInto::::try_into(input).into_iter()) From a85ac260d3c1587fc5acd0e128a3691d5f940fe2 Mon Sep 17 00:00:00 2001 From: Phoebe Szmucer Date: Mon, 22 Jul 2024 20:12:54 +0100 Subject: [PATCH 9/9] Add a more complex case --- tests/accumulate-chain.rs | 3 + tests/accumulate-execution-order.rs | 5 +- tests/accumulate-no-duplicates.rs | 104 ++++++++++++++++++++++++++++ 3 files changed, 111 insertions(+), 1 deletion(-) create mode 100644 tests/accumulate-no-duplicates.rs diff --git a/tests/accumulate-chain.rs b/tests/accumulate-chain.rs index 7cf3d3b3..bf19bc29 100644 --- a/tests/accumulate-chain.rs +++ b/tests/accumulate-chain.rs @@ -1,3 +1,6 @@ +//! Test that when having nested tracked functions +//! we don't drop any values when accumulating. + mod common; use expect_test::expect; diff --git a/tests/accumulate-execution-order.rs b/tests/accumulate-execution-order.rs index ddc2e023..c1fa0481 100644 --- a/tests/accumulate-execution-order.rs +++ b/tests/accumulate-execution-order.rs @@ -1,3 +1,6 @@ +//! Demonstrates that accumulation is done in the order +//! in which things were originally executed. + mod common; use expect_test::expect; @@ -37,7 +40,7 @@ fn push_d_logs(db: &dyn Database) { } #[test] -fn accumulate_chain() { +fn accumulate_execution_order() { salsa::default_database().attach(|db| { let logs = push_logs::accumulated::(db); // Check that we get logs in execution order diff --git a/tests/accumulate-no-duplicates.rs b/tests/accumulate-no-duplicates.rs new file mode 100644 index 00000000..10d47baa --- /dev/null +++ b/tests/accumulate-no-duplicates.rs @@ -0,0 +1,104 @@ +//! Test that we don't get duplicate accumulated values + +mod common; + +use expect_test::expect; +use salsa::{Accumulator, Database}; +use test_log::test; + +// A(1) { +// B +// B +// C { +// D { +// A(2) { +// B +// } +// B +// } +// E +// } +// B +// } + +#[salsa::accumulator] +struct Log(#[allow(dead_code)] String); + +#[salsa::input] +struct MyInput { + n: u32, +} + +#[salsa::tracked] +fn push_logs(db: &dyn Database) { + push_a_logs(db, MyInput::new(db, 1)); +} + +#[salsa::tracked] +fn push_a_logs(db: &dyn Database, input: MyInput) { + Log("log a".to_string()).accumulate(db); + if input.n(db) == 1 { + push_b_logs(db); + push_b_logs(db); + push_c_logs(db); + push_b_logs(db); + } else { + push_b_logs(db); + } +} + +#[salsa::tracked] +fn push_b_logs(db: &dyn Database) { + Log("log b".to_string()).accumulate(db); +} + +#[salsa::tracked] +fn push_c_logs(db: &dyn Database) { + Log("log c".to_string()).accumulate(db); + push_d_logs(db); + push_e_logs(db); +} + +// Note this isn't tracked +fn push_d_logs(db: &dyn Database) { + Log("log d".to_string()).accumulate(db); + push_a_logs(db, MyInput::new(db, 2)); + push_b_logs(db); +} + +#[salsa::tracked] +fn push_e_logs(db: &dyn Database) { + Log("log e".to_string()).accumulate(db); +} + +#[test] +fn accumulate_no_duplicates() { + salsa::default_database().attach(|db| { + let logs = push_logs::accumulated::(db); + // Test that there aren't duplicate B logs. + // Note that log A appears twice, because they both come + // from different inputs. + expect![[r#" + [ + Log( + "log a", + ), + Log( + "log b", + ), + Log( + "log c", + ), + Log( + "log d", + ), + Log( + "log a", + ), + Log( + "log e", + ), + ]"#]] + .assert_eq(&format!("{:#?}", logs)); + }) +}