Fix accumulator only accumulating direct children

This commit is contained in:
Phoebe Szmucer 2024-07-20 19:55:20 +01:00
parent 1c69d3ba7b
commit 148e1cfecf
2 changed files with 92 additions and 11 deletions

View file

@ -1,4 +1,6 @@
use crate::{accumulator, storage::DatabaseGen, Id};
use std::collections::HashSet;
use crate::{accumulator, storage::DatabaseGen, DatabaseKeyIndex, Id};
use super::{Configuration, IngredientImpl};
@ -21,17 +23,42 @@ where
// First ensure the result is up to date
self.fetch(db, key);
let database_key_index = self.database_key_index(key);
accumulator.produced_by(runtime, database_key_index, &mut output);
if let Some(origin) = self.origin(key) {
for input in origin.inputs() {
if let Ok(input) = input.try_into() {
accumulator.produced_by(runtime, input, &mut output);
}
}
}
// Recursively accumulate outputs from children
self.database_key_index(key).traverse_children::<C, _>(
db,
&mut |query| accumulator.produced_by(runtime, query, &mut output),
&mut HashSet::new(),
);
output
}
}
impl DatabaseKeyIndex {
pub fn traverse_children<C, F>(
&self,
db: &C::DbView,
handler: &mut F,
visited: &mut HashSet<DatabaseKeyIndex>,
) where
C: Configuration,
F: (FnMut(DatabaseKeyIndex)),
{
handler(*self);
visited.insert(*self);
let origin = db
.lookup_ingredient(self.ingredient_index)
.origin(self.key_index);
if let Some(origin) = origin {
for input in origin.inputs() {
if let Ok(input) = TryInto::<DatabaseKeyIndex>::try_into(input) {
if !visited.contains(&input) {
input.traverse_children::<C, F>(db, handler, visited);
}
}
}
}
}
}

54
tests/accumulate-chain.rs Normal file
View file

@ -0,0 +1,54 @@
mod common;
use expect_test::expect;
use salsa::{Accumulator, Database};
use test_log::test;
#[salsa::accumulator]
struct Log(#[allow(dead_code)] String);
#[salsa::tracked]
fn push_logs(db: &dyn Database) {
push_a_logs(db);
}
#[salsa::tracked]
fn push_a_logs(db: &dyn Database) {
Log("log a".to_string()).accumulate(db);
push_b_logs(db);
}
#[salsa::tracked]
fn push_b_logs(db: &dyn Database) {
// No logs
push_c_logs(db);
}
#[salsa::tracked]
fn push_c_logs(db: &dyn Database) {
// No logs
push_d_logs(db);
}
#[salsa::tracked]
fn push_d_logs(db: &dyn Database) {
Log("log d".to_string()).accumulate(db);
}
#[test]
fn accumulate_chain() {
salsa::default_database().attach(|db| {
let logs = push_logs::accumulated::<Log>(db);
// Check that we don't see logs from `a` appearing twice in the input.
expect![[r#"
[
Log(
"log a",
),
Log(
"log d",
),
]"#]]
.assert_eq(&format!("{:#?}", logs));
})
}