Merge pull request #615 from MichaReiser/micha/faster-accumulators

Faster accumulators
Micha Reiser 2024-12-03 14:59:51 +00:00 committed by GitHub
commit e68679b3a9
10 changed files with 146 additions and 6 deletions
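
In short: every query now tracks an `InputAccumulatedValues` flag recording whether any of its direct or indirect inputs accumulated values. The flag lets `accumulated()` skip entire accumulation-free subtrees of the dependency graph instead of visiting every input, and a new benchmark measures exactly this collection path.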

Cargo.toml

@@ -6,6 +6,7 @@ edition = "2021"
license = "Apache-2.0 OR MIT"
repository = "https://github.com/salsa-rs/salsa"
description = "A generic framework for on-demand, incrementalized computation (experimental)"
rust-version = "1.76"
[dependencies]
arc-swap = "1"
@@ -44,5 +45,9 @@ harness = false
name = "incremental"
harness = false
[[bench]]
name = "accumulator"
harness = false
[workspace]
members = ["components/salsa-macro-rules", "components/salsa-macros"]
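
The new `[[bench]]` target registers benches/accumulator.rs (shown below); `harness = false` disables the default libtest harness so that `criterion_main!` can supply `main`. The benchmark can then be run with `cargo bench --bench accumulator`.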

benches/accumulator.rs (new file, 64 lines)

@@ -0,0 +1,64 @@
use codspeed_criterion_compat::{criterion_group, criterion_main, BatchSize, Criterion};
use salsa::Accumulator;
#[salsa::input]
struct Input {
expressions: usize,
}
#[allow(dead_code)]
#[salsa::accumulator]
struct Diagnostic(String);
#[salsa::interned]
struct Expression<'db> {
number: usize,
}
#[salsa::tracked]
fn root<'db>(db: &'db dyn salsa::Database, input: Input) -> Vec<usize> {
(0..input.expressions(db))
.map(|i| infer_expression(db, Expression::new(db, i)))
.collect()
}
#[salsa::tracked]
fn infer_expression<'db>(db: &'db dyn salsa::Database, expression: Expression<'db>) -> usize {
let number = expression.number(db);
if number % 10 == 0 {
Diagnostic(format!("Number is {number}")).accumulate(db);
}
if number != 0 && number % 2 == 0 {
let sub_expression = Expression::new(db, number / 2);
let _ = infer_expression(db, sub_expression);
}
number
}
fn accumulator(criterion: &mut Criterion) {
criterion.bench_function("accumulator", |b| {
b.iter_batched_ref(
|| {
let db = salsa::DatabaseImpl::new();
let input = Input::new(&db, 10_000);
// Pre-warm
let _ = root(&db, input);
(db, input)
},
|(db, input)| {
// Measure only the cost of collecting the accumulated values,
// not the cost of running the query itself.
let diagnostics = root::accumulated::<Diagnostic>(db, *input);
assert_eq!(diagnostics.len(), 1000);
},
BatchSize::SmallInput,
);
});
}
criterion_group!(benches, accumulator);
criterion_main!(benches);
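
Note the shape of the benchmark: the setup closure passed to `iter_batched_ref` runs `root` once to pre-warm the memos, so the timed closure pays only for `root::accumulated::<Diagnostic>`, i.e. for walking the dependency graph and collecting the 1,000 diagnostics (one per multiple of 10 below 10,000). This is precisely the walk the new `InputAccumulatedValues` flag below speeds up. (`codspeed_criterion_compat` is a drop-in wrapper around the criterion crate that also supports CodSpeed's CI runner.)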

src/accumulator/accumulated_map.rs

@@ -7,6 +7,10 @@ use super::{accumulated::Accumulated, Accumulator, AnyAccumulated};
#[derive(Default, Debug)]
pub struct AccumulatedMap {
map: FxHashMap<IngredientIndex, Box<dyn AnyAccumulated>>,
/// [`InputAccumulatedValues::Any`] if any input read during the query's execution
/// has any direct or indirect accumulated values, and [`InputAccumulatedValues::Empty`] otherwise.
inputs: InputAccumulatedValues,
}
impl AccumulatedMap {
@@ -17,6 +21,21 @@ impl AccumulatedMap {
.accumulate(value);
}
/// Adds the accumulated state of an input to this accumulated map.
pub(crate) fn add_input(&mut self, input: InputAccumulatedValues) {
if input.is_any() {
self.inputs = InputAccumulatedValues::Any;
}
}
/// Returns whether an input of the associated query has any accumulated values.
///
/// Note: Use [`InputAccumulatedValues::from_map`] to check if the associated query itself
/// or any of its inputs has accumulated values.
pub(crate) fn inputs(&self) -> InputAccumulatedValues {
self.inputs
}
pub fn extend_with_accumulated<A: Accumulator>(
&self,
index: IngredientIndex,
@@ -41,6 +60,39 @@ impl Clone for AccumulatedMap {
.iter()
.map(|(&key, value)| (key, value.cloned()))
.collect(),
inputs: self.inputs,
}
}
}
/// Tracks whether any input read during a query's execution has any accumulated values.
///
/// Knowing whether any input has accumulated values makes aggregating the accumulated values
/// cheaper because we can skip over entire subtrees.
#[derive(Copy, Clone, Debug, Default)]
pub(crate) enum InputAccumulatedValues {
/// Neither the query nor any of its inputs has any accumulated values.
#[default]
Empty,
/// The query or at least one of its inputs has accumulated values.
Any,
}
impl InputAccumulatedValues {
pub(crate) fn from_map(accumulated: &AccumulatedMap) -> Self {
if accumulated.map.is_empty() {
accumulated.inputs
} else {
Self::Any
}
}
pub(crate) const fn is_any(self) -> bool {
matches!(self, Self::Any)
}
pub(crate) const fn is_empty(self) -> bool {
matches!(self, Self::Empty)
}
}
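
To make the flag concrete, here is a small self-contained sketch (an editor's illustration with made-up names, not salsa's actual API) of the monotone Empty/Any summary that `AccumulatedMap` maintains: `add_input` can only ever move the flag from `Empty` to `Any`, and the equivalent of `from_map` reports `Any` as soon as the query itself accumulated anything.

#[derive(Copy, Clone, Debug, Default, PartialEq)]
enum Summary {
    #[default]
    Empty, // nothing accumulated by this query or anything below it
    Any,   // this query or some transitive input accumulated a value
}

#[derive(Default)]
struct Node {
    own_values: Vec<String>, // stands in for `AccumulatedMap::map`
    inputs: Summary,         // stands in for `AccumulatedMap::inputs`
}

impl Node {
    // Mirrors `AccumulatedMap::add_input`: the flag saturates at `Any`.
    fn add_input(&mut self, input: Summary) {
        if input == Summary::Any {
            self.inputs = Summary::Any;
        }
    }

    // Mirrors `InputAccumulatedValues::from_map`: own values win,
    // otherwise fall back to the summary of the inputs.
    fn summary(&self) -> Summary {
        if self.own_values.is_empty() {
            self.inputs
        } else {
            Summary::Any
        }
    }
}

fn main() {
    let mut node = Node::default();
    node.add_input(Summary::Empty);
    assert_eq!(node.summary(), Summary::Empty); // whole subtree can be skipped
    node.add_input(Summary::Any);
    assert_eq!(node.summary(), Summary::Any); // never resets to Empty
}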

src/active_query.rs

@@ -3,7 +3,7 @@ use rustc_hash::FxHashMap;
use super::zalsa_local::{EdgeKind, QueryEdges, QueryOrigin, QueryRevisions};
use crate::tracked_struct::IdentityHash;
use crate::{
accumulator::accumulated_map::AccumulatedMap,
accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues},
durability::Durability,
hash::FxIndexSet,
key::{DatabaseKeyIndex, DependencyIndex},
@@ -76,10 +76,12 @@ impl ActiveQuery {
input: DependencyIndex,
durability: Durability,
revision: Revision,
accumulated: InputAccumulatedValues,
) {
self.input_outputs.insert((EdgeKind::Input, input));
self.durability = self.durability.min(durability);
self.changed_at = self.changed_at.max(revision);
self.accumulated.add_input(accumulated);
}
pub(super) fn add_untracked_read(&mut self, changed_at: Revision) {

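Note that `add_read` now folds the callee's summary into the active query via `self.accumulated.add_input(accumulated)`; `add_untracked_read` keeps its old signature, since an untracked read has no memo whose accumulated values could be summarized.
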
src/function/accumulated.rs

@@ -56,6 +56,11 @@ where
// Extend `output` with any values accumulated by `k`.
if let Some(accumulated_map) = k.accumulated(db) {
accumulated_map.extend_with_accumulated(accumulator.index(), &mut output);
// Skip over the inputs because we know that the entire subtree below `k` has no accumulated values.
if accumulated_map.inputs().is_empty() {
continue;
}
}
// Find the inputs of `k` and push them onto the stack.
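
This hunk is where the flag pays off. As a hedged sketch (an editor's illustration using plain hash maps rather than salsa's real `DatabaseKeyIndex` and ingredient machinery), the collection is a depth-first walk that harvests each query's own accumulated values and then prunes: if a query's input summary is `Empty`, none of its transitive inputs can contribute anything, so they are never pushed onto the stack.

use std::collections::{HashMap, HashSet};

type QueryId = u32;

struct Accumulated {
    values: Vec<String>, // values accumulated by this query itself
    inputs_empty: bool,  // the new summary: no input has accumulated values
}

fn collect(
    root: QueryId,
    accumulated: &HashMap<QueryId, Accumulated>,
    inputs: &HashMap<QueryId, Vec<QueryId>>,
) -> Vec<String> {
    let mut output = Vec::new();
    let mut visited = HashSet::new();
    let mut stack = vec![root];
    while let Some(k) = stack.pop() {
        if !visited.insert(k) {
            continue;
        }
        if let Some(acc) = accumulated.get(&k) {
            // Extend `output` with any values accumulated by `k` itself.
            output.extend(acc.values.iter().cloned());
            // The optimization from this commit: skip `k`'s entire
            // subtree when no input of `k` has accumulated values.
            if acc.inputs_empty {
                continue;
            }
        }
        // Otherwise, find the inputs of `k` and push them onto the stack.
        if let Some(deps) = inputs.get(&k) {
            stack.extend(deps.iter().copied());
        }
    }
    output
}

fn main() {
    let accumulated = HashMap::from([
        (1, Accumulated { values: vec!["root diag".into()], inputs_empty: false }),
        (2, Accumulated { values: vec![], inputs_empty: true }), // pruned subtree
    ]);
    let inputs = HashMap::from([(1, vec![2, 3]), (2, vec![4])]); // 4 is never visited
    assert_eq!(collect(1, &accumulated, &inputs), vec!["root diag".to_string()]);
}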

src/function/fetch.rs

@@ -1,6 +1,6 @@
use crate::{runtime::StampedValue, zalsa::ZalsaDatabase, AsDynDatabase as _, Id};
use super::{memo::Memo, Configuration, IngredientImpl};
use crate::accumulator::accumulated_map::InputAccumulatedValues;
use crate::{runtime::StampedValue, zalsa::ZalsaDatabase, AsDynDatabase as _, Id};
impl<C> IngredientImpl<C>
where
@@ -21,7 +21,12 @@ where
self.evict_value_from_memo_for(zalsa, evicted);
}
zalsa_local.report_tracked_read(self.database_key_index(id).into(), durability, changed_at);
zalsa_local.report_tracked_read(
self.database_key_index(id).into(),
durability,
changed_at,
InputAccumulatedValues::from_map(&memo.revisions.accumulated),
);
value
}
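
This is the propagation step: every read of a memoized value now reports `InputAccumulatedValues::from_map(&memo.revisions.accumulated)` alongside durability and revision, so the calling query's `ActiveQuery::add_read` (see above) can fold the callee's summary into its own `AccumulatedMap`.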

src/input.rs

@@ -8,6 +8,7 @@ use input_field::FieldIngredientImpl;
use parking_lot::Mutex;
use crate::{
accumulator::accumulated_map::InputAccumulatedValues,
cycle::CycleRecoveryStrategy,
id::{AsId, FromId},
ingredient::{fmt_index, Ingredient},
@@ -188,6 +189,7 @@ impl<C: Configuration> IngredientImpl<C> {
},
stamp.durability,
stamp.changed_at,
InputAccumulatedValues::Empty,
);
&value.fields
}

src/interned.rs

@@ -1,3 +1,4 @@
use crate::accumulator::accumulated_map::InputAccumulatedValues;
use crate::durability::Durability;
use crate::id::AsId;
use crate::ingredient::fmt_index;
@@ -133,6 +134,7 @@
DependencyIndex::for_table(self.ingredient_index),
Durability::MAX,
self.reset_at,
InputAccumulatedValues::Empty,
);
// Optimisation to only get read lock on the map if the data has already

src/tracked_struct.rs

@@ -4,6 +4,7 @@ use crossbeam::{atomic::AtomicCell, queue::SegQueue};
use tracked_field::FieldIngredientImpl;
use crate::{
accumulator::accumulated_map::InputAccumulatedValues,
cycle::CycleRecoveryStrategy,
ingredient::{fmt_index, Ingredient, Jar, JarAux},
key::{DatabaseKeyIndex, DependencyIndex},
@@ -561,6 +562,7 @@
},
data.durability,
field_changed_at,
InputAccumulatedValues::Empty,
);
unsafe { self.to_self_ref(&data.fields) }
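
Here, as in the input and interned ingredients above, the read reports `InputAccumulatedValues::Empty`: reading an input field, an interned value, or a tracked-struct field cannot itself produce accumulated values, so these edges never force the aggregation walk into a subtree.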

src/zalsa_local.rs

@@ -1,7 +1,7 @@
use rustc_hash::FxHashMap;
use tracing::debug;
use crate::accumulator::accumulated_map::AccumulatedMap;
use crate::accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues};
use crate::active_query::ActiveQuery;
use crate::durability::Durability;
use crate::key::DatabaseKeyIndex;
@@ -170,6 +170,7 @@ impl ZalsaLocal {
input: DependencyIndex,
durability: Durability,
changed_at: Revision,
accumulated: InputAccumulatedValues,
) {
debug!(
"report_tracked_read(input={:?}, durability={:?}, changed_at={:?})",
@@ -177,7 +178,7 @@
);
self.with_query_stack(|stack| {
if let Some(top_query) = stack.last_mut() {
top_query.add_read(input, durability, changed_at);
top_query.add_read(input, durability, changed_at, accumulated);
// We are a cycle participant:
//