mirror of
https://github.com/salsa-rs/salsa.git
synced 2025-02-02 09:46:06 +00:00
Merge pull request #524 from PhoebeSzmucer/ps/accumulate-chain
Fix accumulator only accumulating direct children
This commit is contained in:
commit
c8234e4fbf
5 changed files with 247 additions and 12 deletions
|
@ -1,4 +1,4 @@
|
||||||
use crate::{accumulator, storage::DatabaseGen, Id};
|
use crate::{accumulator, hash::FxHashSet, storage::DatabaseGen, DatabaseKeyIndex, Id};
|
||||||
|
|
||||||
use super::{Configuration, IngredientImpl};
|
use super::{Configuration, IngredientImpl};
|
||||||
|
|
||||||
|
@ -21,14 +21,24 @@ where
|
||||||
// First ensure the result is up to date
|
// First ensure the result is up to date
|
||||||
self.fetch(db, key);
|
self.fetch(db, key);
|
||||||
|
|
||||||
let database_key_index = self.database_key_index(key);
|
let db_key = self.database_key_index(key);
|
||||||
accumulator.produced_by(runtime, database_key_index, &mut output);
|
let mut visited: FxHashSet<DatabaseKeyIndex> = FxHashSet::default();
|
||||||
|
let mut stack: Vec<DatabaseKeyIndex> = vec![db_key];
|
||||||
|
|
||||||
if let Some(origin) = self.origin(key) {
|
while let Some(k) = stack.pop() {
|
||||||
for input in origin.inputs() {
|
if visited.insert(k) {
|
||||||
if let Ok(input) = input.try_into() {
|
accumulator.produced_by(runtime, k, &mut output);
|
||||||
accumulator.produced_by(runtime, input, &mut output);
|
|
||||||
}
|
let origin = db.lookup_ingredient(k.ingredient_index).origin(k.key_index);
|
||||||
|
let inputs = origin.iter().flat_map(|origin| origin.inputs());
|
||||||
|
// Careful: we want to push in execution order, so reverse order to
|
||||||
|
// ensure the first child that was executed will be the first child popped
|
||||||
|
// from the stack.
|
||||||
|
stack.extend(
|
||||||
|
inputs
|
||||||
|
.flat_map(|input| TryInto::<DatabaseKeyIndex>::try_into(input).into_iter())
|
||||||
|
.rev(),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -77,7 +77,7 @@ pub enum QueryOrigin {
|
||||||
|
|
||||||
impl QueryOrigin {
|
impl QueryOrigin {
|
||||||
/// Indices for queries *read* by this query
|
/// Indices for queries *read* by this query
|
||||||
pub(crate) fn inputs(&self) -> impl Iterator<Item = DependencyIndex> + '_ {
|
pub(crate) fn inputs(&self) -> impl DoubleEndedIterator<Item = DependencyIndex> + '_ {
|
||||||
let opt_edges = match self {
|
let opt_edges = match self {
|
||||||
QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges),
|
QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges),
|
||||||
QueryOrigin::Assigned(_) | QueryOrigin::BaseInput => None,
|
QueryOrigin::Assigned(_) | QueryOrigin::BaseInput => None,
|
||||||
|
@ -86,7 +86,7 @@ impl QueryOrigin {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Indices for queries *written* by this query (if any)
|
/// Indices for queries *written* by this query (if any)
|
||||||
pub(crate) fn outputs(&self) -> impl Iterator<Item = DependencyIndex> + '_ {
|
pub(crate) fn outputs(&self) -> impl DoubleEndedIterator<Item = DependencyIndex> + '_ {
|
||||||
let opt_edges = match self {
|
let opt_edges = match self {
|
||||||
QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges),
|
QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges),
|
||||||
QueryOrigin::Assigned(_) | QueryOrigin::BaseInput => None,
|
QueryOrigin::Assigned(_) | QueryOrigin::BaseInput => None,
|
||||||
|
@ -127,7 +127,7 @@ impl QueryEdges {
|
||||||
/// Returns the (tracked) inputs that were executed in computing this memoized value.
|
/// Returns the (tracked) inputs that were executed in computing this memoized value.
|
||||||
///
|
///
|
||||||
/// These will always be in execution order.
|
/// These will always be in execution order.
|
||||||
pub(crate) fn inputs(&self) -> impl Iterator<Item = DependencyIndex> + '_ {
|
pub(crate) fn inputs(&self) -> impl DoubleEndedIterator<Item = DependencyIndex> + '_ {
|
||||||
self.input_outputs
|
self.input_outputs
|
||||||
.iter()
|
.iter()
|
||||||
.filter(|(edge_kind, _)| *edge_kind == EdgeKind::Input)
|
.filter(|(edge_kind, _)| *edge_kind == EdgeKind::Input)
|
||||||
|
@ -137,7 +137,7 @@ impl QueryEdges {
|
||||||
/// Returns the (tracked) outputs that were executed in computing this memoized value.
|
/// Returns the (tracked) outputs that were executed in computing this memoized value.
|
||||||
///
|
///
|
||||||
/// These will always be in execution order.
|
/// These will always be in execution order.
|
||||||
pub(crate) fn outputs(&self) -> impl Iterator<Item = DependencyIndex> + '_ {
|
pub(crate) fn outputs(&self) -> impl DoubleEndedIterator<Item = DependencyIndex> + '_ {
|
||||||
self.input_outputs
|
self.input_outputs
|
||||||
.iter()
|
.iter()
|
||||||
.filter(|(edge_kind, _)| *edge_kind == EdgeKind::Output)
|
.filter(|(edge_kind, _)| *edge_kind == EdgeKind::Output)
|
||||||
|
|
57
tests/accumulate-chain.rs
Normal file
57
tests/accumulate-chain.rs
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
//! Test that when having nested tracked functions
|
||||||
|
//! we don't drop any values when accumulating.
|
||||||
|
|
||||||
|
mod common;
|
||||||
|
|
||||||
|
use expect_test::expect;
|
||||||
|
use salsa::{Accumulator, Database};
|
||||||
|
use test_log::test;
|
||||||
|
|
||||||
|
#[salsa::accumulator]
|
||||||
|
struct Log(#[allow(dead_code)] String);
|
||||||
|
|
||||||
|
#[salsa::tracked]
|
||||||
|
fn push_logs(db: &dyn Database) {
|
||||||
|
push_a_logs(db);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[salsa::tracked]
|
||||||
|
fn push_a_logs(db: &dyn Database) {
|
||||||
|
Log("log a".to_string()).accumulate(db);
|
||||||
|
push_b_logs(db);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[salsa::tracked]
|
||||||
|
fn push_b_logs(db: &dyn Database) {
|
||||||
|
// No logs
|
||||||
|
push_c_logs(db);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[salsa::tracked]
|
||||||
|
fn push_c_logs(db: &dyn Database) {
|
||||||
|
// No logs
|
||||||
|
push_d_logs(db);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[salsa::tracked]
|
||||||
|
fn push_d_logs(db: &dyn Database) {
|
||||||
|
Log("log d".to_string()).accumulate(db);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn accumulate_chain() {
|
||||||
|
salsa::default_database().attach(|db| {
|
||||||
|
let logs = push_logs::accumulated::<Log>(db);
|
||||||
|
// Check that we get all the logs.
|
||||||
|
expect![[r#"
|
||||||
|
[
|
||||||
|
Log(
|
||||||
|
"log a",
|
||||||
|
),
|
||||||
|
Log(
|
||||||
|
"log d",
|
||||||
|
),
|
||||||
|
]"#]]
|
||||||
|
.assert_eq(&format!("{:#?}", logs));
|
||||||
|
})
|
||||||
|
}
|
64
tests/accumulate-execution-order.rs
Normal file
64
tests/accumulate-execution-order.rs
Normal file
|
@ -0,0 +1,64 @@
|
||||||
|
//! Demonstrates that accumulation is done in the order
|
||||||
|
//! in which things were originally executed.
|
||||||
|
|
||||||
|
mod common;
|
||||||
|
|
||||||
|
use expect_test::expect;
|
||||||
|
use salsa::{Accumulator, Database};
|
||||||
|
use test_log::test;
|
||||||
|
|
||||||
|
#[salsa::accumulator]
|
||||||
|
struct Log(#[allow(dead_code)] String);
|
||||||
|
|
||||||
|
#[salsa::tracked]
|
||||||
|
fn push_logs(db: &dyn Database) {
|
||||||
|
push_a_logs(db);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[salsa::tracked]
|
||||||
|
fn push_a_logs(db: &dyn Database) {
|
||||||
|
Log("log a".to_string()).accumulate(db);
|
||||||
|
push_b_logs(db);
|
||||||
|
push_c_logs(db);
|
||||||
|
push_d_logs(db);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[salsa::tracked]
|
||||||
|
fn push_b_logs(db: &dyn Database) {
|
||||||
|
Log("log b".to_string()).accumulate(db);
|
||||||
|
push_d_logs(db);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[salsa::tracked]
|
||||||
|
fn push_c_logs(db: &dyn Database) {
|
||||||
|
Log("log c".to_string()).accumulate(db);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[salsa::tracked]
|
||||||
|
fn push_d_logs(db: &dyn Database) {
|
||||||
|
Log("log d".to_string()).accumulate(db);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn accumulate_execution_order() {
|
||||||
|
salsa::default_database().attach(|db| {
|
||||||
|
let logs = push_logs::accumulated::<Log>(db);
|
||||||
|
// Check that we get logs in execution order
|
||||||
|
expect![[r#"
|
||||||
|
[
|
||||||
|
Log(
|
||||||
|
"log a",
|
||||||
|
),
|
||||||
|
Log(
|
||||||
|
"log b",
|
||||||
|
),
|
||||||
|
Log(
|
||||||
|
"log d",
|
||||||
|
),
|
||||||
|
Log(
|
||||||
|
"log c",
|
||||||
|
),
|
||||||
|
]"#]]
|
||||||
|
.assert_eq(&format!("{:#?}", logs));
|
||||||
|
})
|
||||||
|
}
|
104
tests/accumulate-no-duplicates.rs
Normal file
104
tests/accumulate-no-duplicates.rs
Normal file
|
@ -0,0 +1,104 @@
|
||||||
|
//! Test that we don't get duplicate accumulated values
|
||||||
|
|
||||||
|
mod common;
|
||||||
|
|
||||||
|
use expect_test::expect;
|
||||||
|
use salsa::{Accumulator, Database};
|
||||||
|
use test_log::test;
|
||||||
|
|
||||||
|
// A(1) {
|
||||||
|
// B
|
||||||
|
// B
|
||||||
|
// C {
|
||||||
|
// D {
|
||||||
|
// A(2) {
|
||||||
|
// B
|
||||||
|
// }
|
||||||
|
// B
|
||||||
|
// }
|
||||||
|
// E
|
||||||
|
// }
|
||||||
|
// B
|
||||||
|
// }
|
||||||
|
|
||||||
|
#[salsa::accumulator]
|
||||||
|
struct Log(#[allow(dead_code)] String);
|
||||||
|
|
||||||
|
#[salsa::input]
|
||||||
|
struct MyInput {
|
||||||
|
n: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[salsa::tracked]
|
||||||
|
fn push_logs(db: &dyn Database) {
|
||||||
|
push_a_logs(db, MyInput::new(db, 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[salsa::tracked]
|
||||||
|
fn push_a_logs(db: &dyn Database, input: MyInput) {
|
||||||
|
Log("log a".to_string()).accumulate(db);
|
||||||
|
if input.n(db) == 1 {
|
||||||
|
push_b_logs(db);
|
||||||
|
push_b_logs(db);
|
||||||
|
push_c_logs(db);
|
||||||
|
push_b_logs(db);
|
||||||
|
} else {
|
||||||
|
push_b_logs(db);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[salsa::tracked]
|
||||||
|
fn push_b_logs(db: &dyn Database) {
|
||||||
|
Log("log b".to_string()).accumulate(db);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[salsa::tracked]
|
||||||
|
fn push_c_logs(db: &dyn Database) {
|
||||||
|
Log("log c".to_string()).accumulate(db);
|
||||||
|
push_d_logs(db);
|
||||||
|
push_e_logs(db);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note this isn't tracked
|
||||||
|
fn push_d_logs(db: &dyn Database) {
|
||||||
|
Log("log d".to_string()).accumulate(db);
|
||||||
|
push_a_logs(db, MyInput::new(db, 2));
|
||||||
|
push_b_logs(db);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[salsa::tracked]
|
||||||
|
fn push_e_logs(db: &dyn Database) {
|
||||||
|
Log("log e".to_string()).accumulate(db);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn accumulate_no_duplicates() {
|
||||||
|
salsa::default_database().attach(|db| {
|
||||||
|
let logs = push_logs::accumulated::<Log>(db);
|
||||||
|
// Test that there aren't duplicate B logs.
|
||||||
|
// Note that log A appears twice, because they both come
|
||||||
|
// from different inputs.
|
||||||
|
expect![[r#"
|
||||||
|
[
|
||||||
|
Log(
|
||||||
|
"log a",
|
||||||
|
),
|
||||||
|
Log(
|
||||||
|
"log b",
|
||||||
|
),
|
||||||
|
Log(
|
||||||
|
"log c",
|
||||||
|
),
|
||||||
|
Log(
|
||||||
|
"log d",
|
||||||
|
),
|
||||||
|
Log(
|
||||||
|
"log a",
|
||||||
|
),
|
||||||
|
Log(
|
||||||
|
"log e",
|
||||||
|
),
|
||||||
|
]"#]]
|
||||||
|
.assert_eq(&format!("{:#?}", logs));
|
||||||
|
})
|
||||||
|
}
|
Loading…
Reference in a new issue