mirror of
https://github.com/salsa-rs/salsa.git
synced 2025-01-22 21:05:11 +00:00
Fix accumulator only accumulating direct children
This commit is contained in:
parent
1c69d3ba7b
commit
148e1cfecf
2 changed files with 92 additions and 11 deletions
|
@ -1,4 +1,6 @@
|
|||
use crate::{accumulator, storage::DatabaseGen, Id};
|
||||
use std::collections::HashSet;
|
||||
|
||||
use crate::{accumulator, storage::DatabaseGen, DatabaseKeyIndex, Id};
|
||||
|
||||
use super::{Configuration, IngredientImpl};
|
||||
|
||||
|
@ -21,17 +23,42 @@ where
|
|||
// First ensure the result is up to date
|
||||
self.fetch(db, key);
|
||||
|
||||
let database_key_index = self.database_key_index(key);
|
||||
accumulator.produced_by(runtime, database_key_index, &mut output);
|
||||
|
||||
if let Some(origin) = self.origin(key) {
|
||||
for input in origin.inputs() {
|
||||
if let Ok(input) = input.try_into() {
|
||||
accumulator.produced_by(runtime, input, &mut output);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Recursively accumulate outputs from children
|
||||
self.database_key_index(key).traverse_children::<C, _>(
|
||||
db,
|
||||
&mut |query| accumulator.produced_by(runtime, query, &mut output),
|
||||
&mut HashSet::new(),
|
||||
);
|
||||
|
||||
output
|
||||
}
|
||||
}
|
||||
|
||||
impl DatabaseKeyIndex {
|
||||
pub fn traverse_children<C, F>(
|
||||
&self,
|
||||
db: &C::DbView,
|
||||
handler: &mut F,
|
||||
visited: &mut HashSet<DatabaseKeyIndex>,
|
||||
) where
|
||||
C: Configuration,
|
||||
F: (FnMut(DatabaseKeyIndex)),
|
||||
{
|
||||
handler(*self);
|
||||
visited.insert(*self);
|
||||
|
||||
let origin = db
|
||||
.lookup_ingredient(self.ingredient_index)
|
||||
.origin(self.key_index);
|
||||
|
||||
if let Some(origin) = origin {
|
||||
for input in origin.inputs() {
|
||||
if let Ok(input) = TryInto::<DatabaseKeyIndex>::try_into(input) {
|
||||
if !visited.contains(&input) {
|
||||
input.traverse_children::<C, F>(db, handler, visited);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
54
tests/accumulate-chain.rs
Normal file
54
tests/accumulate-chain.rs
Normal file
|
@ -0,0 +1,54 @@
|
|||
mod common;
|
||||
|
||||
use expect_test::expect;
|
||||
use salsa::{Accumulator, Database};
|
||||
use test_log::test;
|
||||
|
||||
#[salsa::accumulator]
|
||||
struct Log(#[allow(dead_code)] String);
|
||||
|
||||
#[salsa::tracked]
|
||||
fn push_logs(db: &dyn Database) {
|
||||
push_a_logs(db);
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
fn push_a_logs(db: &dyn Database) {
|
||||
Log("log a".to_string()).accumulate(db);
|
||||
push_b_logs(db);
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
fn push_b_logs(db: &dyn Database) {
|
||||
// No logs
|
||||
push_c_logs(db);
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
fn push_c_logs(db: &dyn Database) {
|
||||
// No logs
|
||||
push_d_logs(db);
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
fn push_d_logs(db: &dyn Database) {
|
||||
Log("log d".to_string()).accumulate(db);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accumulate_chain() {
|
||||
salsa::default_database().attach(|db| {
|
||||
let logs = push_logs::accumulated::<Log>(db);
|
||||
// Check that we don't see logs from `a` appearing twice in the input.
|
||||
expect![[r#"
|
||||
[
|
||||
Log(
|
||||
"log a",
|
||||
),
|
||||
Log(
|
||||
"log d",
|
||||
),
|
||||
]"#]]
|
||||
.assert_eq(&format!("{:#?}", logs));
|
||||
})
|
||||
}
|
Loading…
Reference in a new issue