mirror of
https://github.com/salsa-rs/salsa.git
synced 2025-02-02 09:46:06 +00:00
Fix accumulator only accumulating direct children
This commit is contained in:
parent
1c69d3ba7b
commit
148e1cfecf
2 changed files with 92 additions and 11 deletions
|
@ -1,4 +1,6 @@
|
||||||
use crate::{accumulator, storage::DatabaseGen, Id};
|
use std::collections::HashSet;
|
||||||
|
|
||||||
|
use crate::{accumulator, storage::DatabaseGen, DatabaseKeyIndex, Id};
|
||||||
|
|
||||||
use super::{Configuration, IngredientImpl};
|
use super::{Configuration, IngredientImpl};
|
||||||
|
|
||||||
|
@ -21,17 +23,42 @@ where
|
||||||
// First ensure the result is up to date
|
// First ensure the result is up to date
|
||||||
self.fetch(db, key);
|
self.fetch(db, key);
|
||||||
|
|
||||||
let database_key_index = self.database_key_index(key);
|
// Recursively accumulate outputs from children
|
||||||
accumulator.produced_by(runtime, database_key_index, &mut output);
|
self.database_key_index(key).traverse_children::<C, _>(
|
||||||
|
db,
|
||||||
if let Some(origin) = self.origin(key) {
|
&mut |query| accumulator.produced_by(runtime, query, &mut output),
|
||||||
for input in origin.inputs() {
|
&mut HashSet::new(),
|
||||||
if let Ok(input) = input.try_into() {
|
);
|
||||||
accumulator.produced_by(runtime, input, &mut output);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
output
|
output
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl DatabaseKeyIndex {
|
||||||
|
pub fn traverse_children<C, F>(
|
||||||
|
&self,
|
||||||
|
db: &C::DbView,
|
||||||
|
handler: &mut F,
|
||||||
|
visited: &mut HashSet<DatabaseKeyIndex>,
|
||||||
|
) where
|
||||||
|
C: Configuration,
|
||||||
|
F: (FnMut(DatabaseKeyIndex)),
|
||||||
|
{
|
||||||
|
handler(*self);
|
||||||
|
visited.insert(*self);
|
||||||
|
|
||||||
|
let origin = db
|
||||||
|
.lookup_ingredient(self.ingredient_index)
|
||||||
|
.origin(self.key_index);
|
||||||
|
|
||||||
|
if let Some(origin) = origin {
|
||||||
|
for input in origin.inputs() {
|
||||||
|
if let Ok(input) = TryInto::<DatabaseKeyIndex>::try_into(input) {
|
||||||
|
if !visited.contains(&input) {
|
||||||
|
input.traverse_children::<C, F>(db, handler, visited);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
54
tests/accumulate-chain.rs
Normal file
54
tests/accumulate-chain.rs
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
mod common;
|
||||||
|
|
||||||
|
use expect_test::expect;
|
||||||
|
use salsa::{Accumulator, Database};
|
||||||
|
use test_log::test;
|
||||||
|
|
||||||
|
#[salsa::accumulator]
|
||||||
|
struct Log(#[allow(dead_code)] String);
|
||||||
|
|
||||||
|
#[salsa::tracked]
|
||||||
|
fn push_logs(db: &dyn Database) {
|
||||||
|
push_a_logs(db);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[salsa::tracked]
|
||||||
|
fn push_a_logs(db: &dyn Database) {
|
||||||
|
Log("log a".to_string()).accumulate(db);
|
||||||
|
push_b_logs(db);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[salsa::tracked]
|
||||||
|
fn push_b_logs(db: &dyn Database) {
|
||||||
|
// No logs
|
||||||
|
push_c_logs(db);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[salsa::tracked]
|
||||||
|
fn push_c_logs(db: &dyn Database) {
|
||||||
|
// No logs
|
||||||
|
push_d_logs(db);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[salsa::tracked]
|
||||||
|
fn push_d_logs(db: &dyn Database) {
|
||||||
|
Log("log d".to_string()).accumulate(db);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn accumulate_chain() {
|
||||||
|
salsa::default_database().attach(|db| {
|
||||||
|
let logs = push_logs::accumulated::<Log>(db);
|
||||||
|
// Check that we don't see logs from `a` appearing twice in the input.
|
||||||
|
expect![[r#"
|
||||||
|
[
|
||||||
|
Log(
|
||||||
|
"log a",
|
||||||
|
),
|
||||||
|
Log(
|
||||||
|
"log d",
|
||||||
|
),
|
||||||
|
]"#]]
|
||||||
|
.assert_eq(&format!("{:#?}", logs));
|
||||||
|
})
|
||||||
|
}
|
Loading…
Reference in a new issue