diff --git a/src/function/accumulated.rs b/src/function/accumulated.rs index c5ab7321..eff5734e 100644 --- a/src/function/accumulated.rs +++ b/src/function/accumulated.rs @@ -1,4 +1,6 @@ -use crate::{accumulator, storage::DatabaseGen, Id}; +use std::collections::HashSet; + +use crate::{accumulator, storage::DatabaseGen, DatabaseKeyIndex, Id}; use super::{Configuration, IngredientImpl}; @@ -21,17 +23,42 @@ where // First ensure the result is up to date self.fetch(db, key); - let database_key_index = self.database_key_index(key); - accumulator.produced_by(runtime, database_key_index, &mut output); - - if let Some(origin) = self.origin(key) { - for input in origin.inputs() { - if let Ok(input) = input.try_into() { - accumulator.produced_by(runtime, input, &mut output); - } - } - } + // Recursively accumulate outputs from children + self.database_key_index(key).traverse_children::( + db, + &mut |query| accumulator.produced_by(runtime, query, &mut output), + &mut HashSet::new(), + ); output } } + +impl DatabaseKeyIndex { + pub fn traverse_children( + &self, + db: &C::DbView, + handler: &mut F, + visited: &mut HashSet, + ) where + C: Configuration, + F: (FnMut(DatabaseKeyIndex)), + { + handler(*self); + visited.insert(*self); + + let origin = db + .lookup_ingredient(self.ingredient_index) + .origin(self.key_index); + + if let Some(origin) = origin { + for input in origin.inputs() { + if let Ok(input) = TryInto::::try_into(input) { + if !visited.contains(&input) { + input.traverse_children::(db, handler, visited); + } + } + } + } + } +} diff --git a/tests/accumulate-chain.rs b/tests/accumulate-chain.rs new file mode 100644 index 00000000..aa3a8ea8 --- /dev/null +++ b/tests/accumulate-chain.rs @@ -0,0 +1,54 @@ +mod common; + +use expect_test::expect; +use salsa::{Accumulator, Database}; +use test_log::test; + +#[salsa::accumulator] +struct Log(#[allow(dead_code)] String); + +#[salsa::tracked] +fn push_logs(db: &dyn Database) { + push_a_logs(db); +} + +#[salsa::tracked] +fn push_a_logs(db: &dyn Database) { + Log("log a".to_string()).accumulate(db); + push_b_logs(db); +} + +#[salsa::tracked] +fn push_b_logs(db: &dyn Database) { + // No logs + push_c_logs(db); +} + +#[salsa::tracked] +fn push_c_logs(db: &dyn Database) { + // No logs + push_d_logs(db); +} + +#[salsa::tracked] +fn push_d_logs(db: &dyn Database) { + Log("log d".to_string()).accumulate(db); +} + +#[test] +fn accumulate_chain() { + salsa::default_database().attach(|db| { + let logs = push_logs::accumulated::(db); + // Check that we don't see logs from `a` appearing twice in the input. + expect![[r#" + [ + Log( + "log a", + ), + Log( + "log d", + ), + ]"#]] + .assert_eq(&format!("{:#?}", logs)); + }) +}