diff --git a/src/function/accumulated.rs b/src/function/accumulated.rs index c5ab7321..ea457e23 100644 --- a/src/function/accumulated.rs +++ b/src/function/accumulated.rs @@ -1,4 +1,4 @@ -use crate::{accumulator, storage::DatabaseGen, Id}; +use crate::{accumulator, hash::FxHashSet, storage::DatabaseGen, DatabaseKeyIndex, Id}; use super::{Configuration, IngredientImpl}; @@ -21,14 +21,24 @@ where // First ensure the result is up to date self.fetch(db, key); - let database_key_index = self.database_key_index(key); - accumulator.produced_by(runtime, database_key_index, &mut output); + let db_key = self.database_key_index(key); + let mut visited: FxHashSet<DatabaseKeyIndex> = FxHashSet::default(); + let mut stack: Vec<DatabaseKeyIndex> = vec![db_key]; - if let Some(origin) = self.origin(key) { - for input in origin.inputs() { - if let Ok(input) = input.try_into() { - accumulator.produced_by(runtime, input, &mut output); - } + while let Some(k) = stack.pop() { + if visited.insert(k) { + accumulator.produced_by(runtime, k, &mut output); + + let origin = db.lookup_ingredient(k.ingredient_index).origin(k.key_index); + let inputs = origin.iter().flat_map(|origin| origin.inputs()); + // Careful: we want to push in execution order, so reverse order to + // ensure the first child that was executed will be the first child popped + // from the stack. 
+ stack.extend( + inputs + .flat_map(|input| TryInto::<DatabaseKeyIndex>::try_into(input).into_iter()) + .rev(), + ); } } diff --git a/src/runtime/local_state.rs b/src/runtime/local_state.rs index 2213bf1f..2619d85a 100644 --- a/src/runtime/local_state.rs +++ b/src/runtime/local_state.rs @@ -77,7 +77,7 @@ pub enum QueryOrigin { impl QueryOrigin { /// Indices for queries *read* by this query - pub(crate) fn inputs(&self) -> impl Iterator<Item = DependencyIndex> + '_ { + pub(crate) fn inputs(&self) -> impl DoubleEndedIterator<Item = DependencyIndex> + '_ { let opt_edges = match self { QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges), QueryOrigin::Assigned(_) | QueryOrigin::BaseInput => None, @@ -86,7 +86,7 @@ } /// Indices for queries *written* by this query (if any) - pub(crate) fn outputs(&self) -> impl Iterator<Item = DependencyIndex> + '_ { + pub(crate) fn outputs(&self) -> impl DoubleEndedIterator<Item = DependencyIndex> + '_ { let opt_edges = match self { QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges), QueryOrigin::Assigned(_) | QueryOrigin::BaseInput => None, @@ -127,7 +127,7 @@ impl QueryEdges { /// Returns the (tracked) inputs that were executed in computing this memoized value. /// /// These will always be in execution order. - pub(crate) fn inputs(&self) -> impl Iterator<Item = DependencyIndex> + '_ { + pub(crate) fn inputs(&self) -> impl DoubleEndedIterator<Item = DependencyIndex> + '_ { self.input_outputs .iter() .filter(|(edge_kind, _)| *edge_kind == EdgeKind::Input) @@ -137,7 +137,7 @@ impl QueryEdges { /// Returns the (tracked) outputs that were executed in computing this memoized value. /// /// These will always be in execution order. - pub(crate) fn outputs(&self) -> impl Iterator<Item = DependencyIndex> + '_ { + pub(crate) fn outputs(&self) -> impl DoubleEndedIterator<Item = DependencyIndex> + '_ { self.input_outputs .iter() .filter(|(edge_kind, _)| *edge_kind == EdgeKind::Output) diff --git a/tests/accumulate-chain.rs b/tests/accumulate-chain.rs new file mode 100644 index 00000000..bf19bc29 --- /dev/null +++ b/tests/accumulate-chain.rs @@ -0,0 +1,57 @@ +//! 
Test that when having nested tracked functions +//! we don't drop any values when accumulating. + +mod common; + +use expect_test::expect; +use salsa::{Accumulator, Database}; +use test_log::test; + +#[salsa::accumulator] +struct Log(#[allow(dead_code)] String); + +#[salsa::tracked] +fn push_logs(db: &dyn Database) { + push_a_logs(db); +} + +#[salsa::tracked] +fn push_a_logs(db: &dyn Database) { + Log("log a".to_string()).accumulate(db); + push_b_logs(db); +} + +#[salsa::tracked] +fn push_b_logs(db: &dyn Database) { + // No logs + push_c_logs(db); +} + +#[salsa::tracked] +fn push_c_logs(db: &dyn Database) { + // No logs + push_d_logs(db); +} + +#[salsa::tracked] +fn push_d_logs(db: &dyn Database) { + Log("log d".to_string()).accumulate(db); +} + +#[test] +fn accumulate_chain() { + salsa::default_database().attach(|db| { + let logs = push_logs::accumulated::<Log>(db); + // Check that we get all the logs. + expect![[r#" + [ + Log( + "log a", + ), + Log( + "log d", + ), + ]"#]] + .assert_eq(&format!("{:#?}", logs)); + }) +} diff --git a/tests/accumulate-execution-order.rs b/tests/accumulate-execution-order.rs new file mode 100644 index 00000000..c1fa0481 --- /dev/null +++ b/tests/accumulate-execution-order.rs @@ -0,0 +1,64 @@ +//! Demonstrates that accumulation is done in the order +//! in which things were originally executed. 
+ +mod common; + +use expect_test::expect; +use salsa::{Accumulator, Database}; +use test_log::test; + +#[salsa::accumulator] +struct Log(#[allow(dead_code)] String); + +#[salsa::tracked] +fn push_logs(db: &dyn Database) { + push_a_logs(db); +} + +#[salsa::tracked] +fn push_a_logs(db: &dyn Database) { + Log("log a".to_string()).accumulate(db); + push_b_logs(db); + push_c_logs(db); + push_d_logs(db); +} + +#[salsa::tracked] +fn push_b_logs(db: &dyn Database) { + Log("log b".to_string()).accumulate(db); + push_d_logs(db); +} + +#[salsa::tracked] +fn push_c_logs(db: &dyn Database) { + Log("log c".to_string()).accumulate(db); +} + +#[salsa::tracked] +fn push_d_logs(db: &dyn Database) { + Log("log d".to_string()).accumulate(db); +} + +#[test] +fn accumulate_execution_order() { + salsa::default_database().attach(|db| { + let logs = push_logs::accumulated::<Log>(db); + // Check that we get logs in execution order + expect![[r#" + [ + Log( + "log a", + ), + Log( + "log b", + ), + Log( + "log d", + ), + Log( + "log c", + ), + ]"#]] + .assert_eq(&format!("{:#?}", logs)); + }) +} diff --git a/tests/accumulate-no-duplicates.rs b/tests/accumulate-no-duplicates.rs new file mode 100644 index 00000000..10d47baa --- /dev/null +++ b/tests/accumulate-no-duplicates.rs @@ -0,0 +1,104 @@ +//! 
Test that we don't get duplicate accumulated values + +mod common; + +use expect_test::expect; +use salsa::{Accumulator, Database}; +use test_log::test; + +// A(1) { +// B +// B +// C { +// D { +// A(2) { +// B +// } +// B +// } +// E +// } +// B +// } + +#[salsa::accumulator] +struct Log(#[allow(dead_code)] String); + +#[salsa::input] +struct MyInput { + n: u32, +} + +#[salsa::tracked] +fn push_logs(db: &dyn Database) { + push_a_logs(db, MyInput::new(db, 1)); +} + +#[salsa::tracked] +fn push_a_logs(db: &dyn Database, input: MyInput) { + Log("log a".to_string()).accumulate(db); + if input.n(db) == 1 { + push_b_logs(db); + push_b_logs(db); + push_c_logs(db); + push_b_logs(db); + } else { + push_b_logs(db); + } +} + +#[salsa::tracked] +fn push_b_logs(db: &dyn Database) { + Log("log b".to_string()).accumulate(db); +} + +#[salsa::tracked] +fn push_c_logs(db: &dyn Database) { + Log("log c".to_string()).accumulate(db); + push_d_logs(db); + push_e_logs(db); +} + +// Note this isn't tracked +fn push_d_logs(db: &dyn Database) { + Log("log d".to_string()).accumulate(db); + push_a_logs(db, MyInput::new(db, 2)); + push_b_logs(db); +} + +#[salsa::tracked] +fn push_e_logs(db: &dyn Database) { + Log("log e".to_string()).accumulate(db); +} + +#[test] +fn accumulate_no_duplicates() { + salsa::default_database().attach(|db| { + let logs = push_logs::accumulated::<Log>(db); + // Test that there aren't duplicate B logs. + // Note that log A appears twice, because they both come + // from different inputs. + expect![[r#" + [ + Log( + "log a", + ), + Log( + "log b", + ), + Log( + "log c", + ), + Log( + "log d", + ), + Log( + "log a", + ), + Log( + "log e", + ), + ]"#]] + .assert_eq(&format!("{:#?}", logs)); + }) +}