mirror of
https://github.com/salsa-rs/salsa.git
synced 2025-01-22 21:05:11 +00:00
Merge pull request #524 from PhoebeSzmucer/ps/accumulate-chain
Fix accumulator only accumulating direct children
This commit is contained in:
commit
c8234e4fbf
5 changed files with 247 additions and 12 deletions
|
@ -1,4 +1,4 @@
|
|||
use crate::{accumulator, storage::DatabaseGen, Id};
|
||||
use crate::{accumulator, hash::FxHashSet, storage::DatabaseGen, DatabaseKeyIndex, Id};
|
||||
|
||||
use super::{Configuration, IngredientImpl};
|
||||
|
||||
|
@ -21,14 +21,24 @@ where
|
|||
// First ensure the result is up to date
|
||||
self.fetch(db, key);
|
||||
|
||||
let database_key_index = self.database_key_index(key);
|
||||
accumulator.produced_by(runtime, database_key_index, &mut output);
|
||||
let db_key = self.database_key_index(key);
|
||||
let mut visited: FxHashSet<DatabaseKeyIndex> = FxHashSet::default();
|
||||
let mut stack: Vec<DatabaseKeyIndex> = vec![db_key];
|
||||
|
||||
if let Some(origin) = self.origin(key) {
|
||||
for input in origin.inputs() {
|
||||
if let Ok(input) = input.try_into() {
|
||||
accumulator.produced_by(runtime, input, &mut output);
|
||||
}
|
||||
while let Some(k) = stack.pop() {
|
||||
if visited.insert(k) {
|
||||
accumulator.produced_by(runtime, k, &mut output);
|
||||
|
||||
let origin = db.lookup_ingredient(k.ingredient_index).origin(k.key_index);
|
||||
let inputs = origin.iter().flat_map(|origin| origin.inputs());
|
||||
// Careful: we want to push in execution order, so reverse order to
|
||||
// ensure the first child that was executed will be the first child popped
|
||||
// from the stack.
|
||||
stack.extend(
|
||||
inputs
|
||||
.flat_map(|input| TryInto::<DatabaseKeyIndex>::try_into(input).into_iter())
|
||||
.rev(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -77,7 +77,7 @@ pub enum QueryOrigin {
|
|||
|
||||
impl QueryOrigin {
|
||||
/// Indices for queries *read* by this query
|
||||
pub(crate) fn inputs(&self) -> impl Iterator<Item = DependencyIndex> + '_ {
|
||||
pub(crate) fn inputs(&self) -> impl DoubleEndedIterator<Item = DependencyIndex> + '_ {
|
||||
let opt_edges = match self {
|
||||
QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges),
|
||||
QueryOrigin::Assigned(_) | QueryOrigin::BaseInput => None,
|
||||
|
@ -86,7 +86,7 @@ impl QueryOrigin {
|
|||
}
|
||||
|
||||
/// Indices for queries *written* by this query (if any)
|
||||
pub(crate) fn outputs(&self) -> impl Iterator<Item = DependencyIndex> + '_ {
|
||||
pub(crate) fn outputs(&self) -> impl DoubleEndedIterator<Item = DependencyIndex> + '_ {
|
||||
let opt_edges = match self {
|
||||
QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges),
|
||||
QueryOrigin::Assigned(_) | QueryOrigin::BaseInput => None,
|
||||
|
@ -127,7 +127,7 @@ impl QueryEdges {
|
|||
/// Returns the (tracked) inputs that were executed in computing this memoized value.
|
||||
///
|
||||
/// These will always be in execution order.
|
||||
pub(crate) fn inputs(&self) -> impl Iterator<Item = DependencyIndex> + '_ {
|
||||
pub(crate) fn inputs(&self) -> impl DoubleEndedIterator<Item = DependencyIndex> + '_ {
|
||||
self.input_outputs
|
||||
.iter()
|
||||
.filter(|(edge_kind, _)| *edge_kind == EdgeKind::Input)
|
||||
|
@ -137,7 +137,7 @@ impl QueryEdges {
|
|||
/// Returns the (tracked) outputs that were executed in computing this memoized value.
|
||||
///
|
||||
/// These will always be in execution order.
|
||||
pub(crate) fn outputs(&self) -> impl Iterator<Item = DependencyIndex> + '_ {
|
||||
pub(crate) fn outputs(&self) -> impl DoubleEndedIterator<Item = DependencyIndex> + '_ {
|
||||
self.input_outputs
|
||||
.iter()
|
||||
.filter(|(edge_kind, _)| *edge_kind == EdgeKind::Output)
|
||||
|
|
57
tests/accumulate-chain.rs
Normal file
57
tests/accumulate-chain.rs
Normal file
|
@ -0,0 +1,57 @@
|
|||
//! Test that when having nested tracked functions
|
||||
//! we don't drop any values when accumulating.
|
||||
|
||||
mod common;
|
||||
|
||||
use expect_test::expect;
|
||||
use salsa::{Accumulator, Database};
|
||||
use test_log::test;
|
||||
|
||||
#[salsa::accumulator]
|
||||
struct Log(#[allow(dead_code)] String);
|
||||
|
||||
#[salsa::tracked]
|
||||
fn push_logs(db: &dyn Database) {
|
||||
push_a_logs(db);
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
fn push_a_logs(db: &dyn Database) {
|
||||
Log("log a".to_string()).accumulate(db);
|
||||
push_b_logs(db);
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
fn push_b_logs(db: &dyn Database) {
|
||||
// No logs
|
||||
push_c_logs(db);
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
fn push_c_logs(db: &dyn Database) {
|
||||
// No logs
|
||||
push_d_logs(db);
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
fn push_d_logs(db: &dyn Database) {
|
||||
Log("log d".to_string()).accumulate(db);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accumulate_chain() {
|
||||
salsa::default_database().attach(|db| {
|
||||
let logs = push_logs::accumulated::<Log>(db);
|
||||
// Check that we get all the logs.
|
||||
expect![[r#"
|
||||
[
|
||||
Log(
|
||||
"log a",
|
||||
),
|
||||
Log(
|
||||
"log d",
|
||||
),
|
||||
]"#]]
|
||||
.assert_eq(&format!("{:#?}", logs));
|
||||
})
|
||||
}
|
64
tests/accumulate-execution-order.rs
Normal file
64
tests/accumulate-execution-order.rs
Normal file
|
@ -0,0 +1,64 @@
|
|||
//! Demonstrates that accumulation is done in the order
|
||||
//! in which things were originally executed.
|
||||
|
||||
mod common;
|
||||
|
||||
use expect_test::expect;
|
||||
use salsa::{Accumulator, Database};
|
||||
use test_log::test;
|
||||
|
||||
#[salsa::accumulator]
|
||||
struct Log(#[allow(dead_code)] String);
|
||||
|
||||
#[salsa::tracked]
|
||||
fn push_logs(db: &dyn Database) {
|
||||
push_a_logs(db);
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
fn push_a_logs(db: &dyn Database) {
|
||||
Log("log a".to_string()).accumulate(db);
|
||||
push_b_logs(db);
|
||||
push_c_logs(db);
|
||||
push_d_logs(db);
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
fn push_b_logs(db: &dyn Database) {
|
||||
Log("log b".to_string()).accumulate(db);
|
||||
push_d_logs(db);
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
fn push_c_logs(db: &dyn Database) {
|
||||
Log("log c".to_string()).accumulate(db);
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
fn push_d_logs(db: &dyn Database) {
|
||||
Log("log d".to_string()).accumulate(db);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accumulate_execution_order() {
|
||||
salsa::default_database().attach(|db| {
|
||||
let logs = push_logs::accumulated::<Log>(db);
|
||||
// Check that we get logs in execution order
|
||||
expect![[r#"
|
||||
[
|
||||
Log(
|
||||
"log a",
|
||||
),
|
||||
Log(
|
||||
"log b",
|
||||
),
|
||||
Log(
|
||||
"log d",
|
||||
),
|
||||
Log(
|
||||
"log c",
|
||||
),
|
||||
]"#]]
|
||||
.assert_eq(&format!("{:#?}", logs));
|
||||
})
|
||||
}
|
104
tests/accumulate-no-duplicates.rs
Normal file
104
tests/accumulate-no-duplicates.rs
Normal file
|
@ -0,0 +1,104 @@
|
|||
//! Test that we don't get duplicate accumulated values
|
||||
|
||||
mod common;
|
||||
|
||||
use expect_test::expect;
|
||||
use salsa::{Accumulator, Database};
|
||||
use test_log::test;
|
||||
|
||||
// A(1) {
|
||||
// B
|
||||
// B
|
||||
// C {
|
||||
// D {
|
||||
// A(2) {
|
||||
// B
|
||||
// }
|
||||
// B
|
||||
// }
|
||||
// E
|
||||
// }
|
||||
// B
|
||||
// }
|
||||
|
||||
#[salsa::accumulator]
|
||||
struct Log(#[allow(dead_code)] String);
|
||||
|
||||
#[salsa::input]
|
||||
struct MyInput {
|
||||
n: u32,
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
fn push_logs(db: &dyn Database) {
|
||||
push_a_logs(db, MyInput::new(db, 1));
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
fn push_a_logs(db: &dyn Database, input: MyInput) {
|
||||
Log("log a".to_string()).accumulate(db);
|
||||
if input.n(db) == 1 {
|
||||
push_b_logs(db);
|
||||
push_b_logs(db);
|
||||
push_c_logs(db);
|
||||
push_b_logs(db);
|
||||
} else {
|
||||
push_b_logs(db);
|
||||
}
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
fn push_b_logs(db: &dyn Database) {
|
||||
Log("log b".to_string()).accumulate(db);
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
fn push_c_logs(db: &dyn Database) {
|
||||
Log("log c".to_string()).accumulate(db);
|
||||
push_d_logs(db);
|
||||
push_e_logs(db);
|
||||
}
|
||||
|
||||
// Note this isn't tracked
|
||||
fn push_d_logs(db: &dyn Database) {
|
||||
Log("log d".to_string()).accumulate(db);
|
||||
push_a_logs(db, MyInput::new(db, 2));
|
||||
push_b_logs(db);
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
fn push_e_logs(db: &dyn Database) {
|
||||
Log("log e".to_string()).accumulate(db);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accumulate_no_duplicates() {
|
||||
salsa::default_database().attach(|db| {
|
||||
let logs = push_logs::accumulated::<Log>(db);
|
||||
// Test that there aren't duplicate B logs.
|
||||
// Note that log A appears twice, because they both come
|
||||
// from different inputs.
|
||||
expect![[r#"
|
||||
[
|
||||
Log(
|
||||
"log a",
|
||||
),
|
||||
Log(
|
||||
"log b",
|
||||
),
|
||||
Log(
|
||||
"log c",
|
||||
),
|
||||
Log(
|
||||
"log d",
|
||||
),
|
||||
Log(
|
||||
"log a",
|
||||
),
|
||||
Log(
|
||||
"log e",
|
||||
),
|
||||
]"#]]
|
||||
.assert_eq(&format!("{:#?}", logs));
|
||||
})
|
||||
}
|
Loading…
Reference in a new issue