remove shared field and inline fields

This commit is contained in:
Niko Matsakis 2024-07-21 07:03:43 -04:00
parent 2f4f80fe23
commit 4e015483fe
3 changed files with 73 additions and 129 deletions

View file

@ -1,12 +1,16 @@
use std::{ use std::{
panic::panic_any, panic::panic_any,
sync::{atomic::Ordering, Arc}, sync::{atomic::AtomicUsize, Arc},
}; };
use crossbeam::atomic::AtomicCell;
use parking_lot::Mutex;
use crate::{ use crate::{
cycle::CycleRecoveryStrategy, cycle::CycleRecoveryStrategy,
durability::Durability, durability::Durability,
key::{DatabaseKeyIndex, DependencyIndex}, key::{DatabaseKeyIndex, DependencyIndex},
revision::AtomicRevision,
runtime::active_query::ActiveQuery, runtime::active_query::ActiveQuery,
storage::IngredientIndex, storage::IngredientIndex,
Cancelled, Cycle, Database, Event, EventKind, Revision, Cancelled, Cycle, Database, Event, EventKind, Revision,
@ -22,7 +26,6 @@ use super::tracked_struct::Disambiguator;
mod active_query; mod active_query;
mod dependency_graph; mod dependency_graph;
pub mod local_state; pub mod local_state;
mod shared_state;
pub struct Runtime { pub struct Runtime {
/// Our unique runtime id. /// Our unique runtime id.
@ -31,8 +34,31 @@ pub struct Runtime {
/// Local state that is specific to this runtime (thread). /// Local state that is specific to this runtime (thread).
local_state: local_state::LocalState, local_state: local_state::LocalState,
/// Shared state that is accessible via all runtimes. /// Stores the next id to use for a snapshotted runtime (starts at 1).
shared_state: Arc<shared_state::SharedState>, next_id: AtomicUsize,
/// Vector we can clone
empty_dependencies: Arc<[(EdgeKind, DependencyIndex)]>,
/// Set to true when the current revision has been canceled.
/// This is done when we an input is being changed. The flag
/// is set back to false once the input has been changed.
revision_canceled: AtomicCell<bool>,
/// Stores the "last change" revision for values of each duration.
/// This vector is always of length at least 1 (for Durability 0)
/// but its total length depends on the number of durations. The
/// element at index 0 is special as it represents the "current
/// revision". In general, we have the invariant that revisions
/// in here are *declining* -- that is, `revisions[i] >=
/// revisions[i + 1]`, for all `i`. This is because when you
/// modify a value with durability D, that implies that values
/// with durability less than D may have changed too.
revisions: Vec<AtomicRevision>,
/// The dependency graph tracks which runtimes are blocked on one
/// another, waiting for queries to terminate.
dependency_graph: Mutex<DependencyGraph>,
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -80,8 +106,14 @@ impl Default for Runtime {
fn default() -> Self { fn default() -> Self {
Runtime { Runtime {
id: RuntimeId { counter: 0 }, id: RuntimeId { counter: 0 },
shared_state: Default::default(),
local_state: Default::default(), local_state: Default::default(),
revisions: (0..Durability::LEN)
.map(|_| AtomicRevision::start())
.collect(),
next_id: AtomicUsize::new(1),
empty_dependencies: None.into_iter().collect(),
revision_canceled: Default::default(),
dependency_graph: Default::default(),
} }
} }
} }
@ -89,8 +121,11 @@ impl Default for Runtime {
impl std::fmt::Debug for Runtime { impl std::fmt::Debug for Runtime {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
fmt.debug_struct("Runtime") fmt.debug_struct("Runtime")
.field("id", &self.id()) .field("id", &self.id)
.field("shared_state", &self.shared_state) .field("revisions", &self.revisions)
.field("next_id", &self.next_id)
.field("revision_canceled", &self.revision_canceled)
.field("dependency_graph", &self.dependency_graph)
.finish() .finish()
} }
} }
@ -101,7 +136,7 @@ impl Runtime {
} }
pub(crate) fn current_revision(&self) -> Revision { pub(crate) fn current_revision(&self) -> Revision {
self.shared_state.revisions[0].load() self.revisions[0].load()
} }
/// Returns the index of the active query along with its *current* durability/changed-at /// Returns the index of the active query along with its *current* durability/changed-at
@ -111,7 +146,7 @@ impl Runtime {
} }
pub(crate) fn empty_dependencies(&self) -> Arc<[(EdgeKind, DependencyIndex)]> { pub(crate) fn empty_dependencies(&self) -> Arc<[(EdgeKind, DependencyIndex)]> {
self.shared_state.empty_dependencies.clone() self.empty_dependencies.clone()
} }
/// Executes `op` but ignores its effect on /// Executes `op` but ignores its effect on
@ -130,22 +165,6 @@ impl Runtime {
self.local_state.debug_probe(op) self.local_state.debug_probe(op)
} }
pub fn snapshot(&self) -> Self {
if self.local_state.query_in_progress() {
panic!("it is not legal to `snapshot` during a query (see salsa-rs/salsa#80)");
}
let id = RuntimeId {
counter: self.shared_state.next_id.fetch_add(1, Ordering::SeqCst),
};
Runtime {
id,
shared_state: self.shared_state.clone(),
local_state: Default::default(),
}
}
pub(crate) fn report_tracked_read( pub(crate) fn report_tracked_read(
&self, &self,
key_index: DependencyIndex, key_index: DependencyIndex,
@ -170,7 +189,7 @@ impl Runtime {
/// less than or equal to `durability` to the current revision. /// less than or equal to `durability` to the current revision.
pub(crate) fn report_tracked_write(&mut self, durability: Durability) { pub(crate) fn report_tracked_write(&mut self, durability: Durability) {
let new_revision = self.current_revision(); let new_revision = self.current_revision();
for rev in &self.shared_state.revisions[1..=durability.index()] { for rev in &self.revisions[1..=durability.index()] {
rev.store(new_revision); rev.store(new_revision);
} }
} }
@ -219,7 +238,7 @@ impl Runtime {
/// dependencies. /// dependencies.
#[inline] #[inline]
pub(crate) fn last_changed_revision(&self, d: Durability) -> Revision { pub(crate) fn last_changed_revision(&self, d: Durability) -> Revision {
self.shared_state.revisions[d.index()].load() self.revisions[d.index()].load()
} }
/// Starts unwinding the stack if the current revision is cancelled. /// Starts unwinding the stack if the current revision is cancelled.
@ -239,7 +258,7 @@ impl Runtime {
runtime_id: self.id(), runtime_id: self.id(),
kind: EventKind::WillCheckCancellation, kind: EventKind::WillCheckCancellation,
}); });
if self.shared_state.revision_canceled.load() { if self.revision_canceled.load() {
db.salsa_event(Event { db.salsa_event(Event {
runtime_id: self.id(), runtime_id: self.id(),
kind: EventKind::WillCheckCancellation, kind: EventKind::WillCheckCancellation,
@ -255,7 +274,7 @@ impl Runtime {
} }
pub(crate) fn set_cancellation_flag(&self) { pub(crate) fn set_cancellation_flag(&self) {
self.shared_state.revision_canceled.store(true); self.revision_canceled.store(true);
} }
/// Increments the "current revision" counter and clears /// Increments the "current revision" counter and clears
@ -265,8 +284,8 @@ impl Runtime {
pub(crate) fn new_revision(&mut self) -> Revision { pub(crate) fn new_revision(&mut self) -> Revision {
let r_old = self.current_revision(); let r_old = self.current_revision();
let r_new = r_old.next(); let r_new = r_old.next();
self.shared_state.revisions[0].store(r_new); self.revisions[0].store(r_new);
self.shared_state.revision_canceled.store(false); self.revision_canceled.store(false);
r_new r_new
} }
@ -304,7 +323,7 @@ impl Runtime {
other_id: RuntimeId, other_id: RuntimeId,
query_mutex_guard: QueryMutexGuard, query_mutex_guard: QueryMutexGuard,
) { ) {
let mut dg = self.shared_state.dependency_graph.lock(); let mut dg = self.dependency_graph.lock();
if dg.depends_on(other_id, self.id()) { if dg.depends_on(other_id, self.id()) {
self.unblock_cycle_and_maybe_throw(db, &mut dg, database_key, other_id); self.unblock_cycle_and_maybe_throw(db, &mut dg, database_key, other_id);
@ -464,8 +483,7 @@ impl Runtime {
database_key: DatabaseKeyIndex, database_key: DatabaseKeyIndex,
wait_result: WaitResult, wait_result: WaitResult,
) { ) {
self.shared_state self.dependency_graph
.dependency_graph
.lock() .lock()
.unblock_runtimes_blocked_on(database_key, wait_result); .unblock_runtimes_blocked_on(database_key, wait_result);
} }

View file

@ -1,56 +0,0 @@
use std::sync::{atomic::AtomicUsize, Arc};
use crossbeam::atomic::AtomicCell;
use parking_lot::Mutex;
use crate::{durability::Durability, key::DependencyIndex, revision::AtomicRevision};
use super::{dependency_graph::DependencyGraph, local_state::EdgeKind};
/// State that will be common to all threads (when we support multiple threads)
#[derive(Debug)]
pub(super) struct SharedState {
/// Stores the next id to use for a snapshotted runtime (starts at 1).
pub(super) next_id: AtomicUsize,
/// Vector we can clone
pub(super) empty_dependencies: Arc<[(EdgeKind, DependencyIndex)]>,
/// Set to true when the current revision has been canceled.
/// This is done when we an input is being changed. The flag
/// is set back to false once the input has been changed.
pub(super) revision_canceled: AtomicCell<bool>,
/// Stores the "last change" revision for values of each duration.
/// This vector is always of length at least 1 (for Durability 0)
/// but its total length depends on the number of durations. The
/// element at index 0 is special as it represents the "current
/// revision". In general, we have the invariant that revisions
/// in here are *declining* -- that is, `revisions[i] >=
/// revisions[i + 1]`, for all `i`. This is because when you
/// modify a value with durability D, that implies that values
/// with durability less than D may have changed too.
pub(super) revisions: Vec<AtomicRevision>,
/// The dependency graph tracks which runtimes are blocked on one
/// another, waiting for queries to terminate.
pub(super) dependency_graph: Mutex<DependencyGraph>,
}
impl Default for SharedState {
fn default() -> Self {
Self::with_durabilities(Durability::LEN)
}
}
impl SharedState {
fn with_durabilities(durabilities: usize) -> Self {
SharedState {
next_id: AtomicUsize::new(1),
empty_dependencies: None.into_iter().collect(),
revision_canceled: Default::default(),
revisions: (0..durabilities).map(|_| AtomicRevision::start()).collect(),
dependency_graph: Default::default(),
}
}
}

View file

@ -111,11 +111,11 @@ unsafe impl<T: HasStorage> DatabaseGen for T {
} }
fn views(&self) -> &Views { fn views(&self) -> &Views {
&self.storage().shared.upcasts &self.storage().upcasts
} }
fn nonce(&self) -> Nonce<StorageNonce> { fn nonce(&self) -> Nonce<StorageNonce> {
self.storage().shared.nonce self.storage().nonce
} }
fn lookup_jar_by_type(&self, jar: &dyn Jar) -> Option<IngredientIndex> { fn lookup_jar_by_type(&self, jar: &dyn Jar) -> Option<IngredientIndex> {
@ -206,20 +206,6 @@ impl IngredientIndex {
/// The "storage" struct stores all the data for the jars. /// The "storage" struct stores all the data for the jars.
/// It is shared between the main database and any active snapshots. /// It is shared between the main database and any active snapshots.
pub struct Storage<Db: Database> { pub struct Storage<Db: Database> {
/// Data shared across all databases. This contains the ingredients needed by each jar.
/// See the ["jars and ingredients" chapter](https://salsa-rs.github.io/salsa/plumbing/jars_and_ingredients.html)
/// for more detailed description.
shared: Shared<Db>,
/// The runtime for this particular salsa database handle.
/// Each handle gets its own runtime, but the runtimes have shared state between them.
runtime: Runtime,
}
/// Data shared between all threads.
/// This is where the actual data for tracked functions, structs, inputs, etc lives,
/// along with some coordination variables between treads.
struct Shared<Db: Database> {
upcasts: ViewsOf<Db>, upcasts: ViewsOf<Db>,
nonce: Nonce<StorageNonce>, nonce: Nonce<StorageNonce>,
@ -239,19 +225,21 @@ struct Shared<Db: Database> {
/// Indices of ingredients that require reset when a new revision starts. /// Indices of ingredients that require reset when a new revision starts.
ingredients_requiring_reset: ConcurrentVec<IngredientIndex>, ingredients_requiring_reset: ConcurrentVec<IngredientIndex>,
/// The runtime for this particular salsa database handle.
/// Each handle gets its own runtime, but the runtimes have shared state between them.
runtime: Runtime,
} }
// ANCHOR: default // ANCHOR: default
impl<Db: Database> Default for Storage<Db> { impl<Db: Database> Default for Storage<Db> {
fn default() -> Self { fn default() -> Self {
Self { Self {
shared: Shared { upcasts: Default::default(),
upcasts: Default::default(), nonce: NONCE.nonce(),
nonce: NONCE.nonce(), jar_map: Default::default(),
jar_map: Default::default(), ingredients_vec: Default::default(),
ingredients_vec: Default::default(), ingredients_requiring_reset: Default::default(),
ingredients_requiring_reset: Default::default(),
},
runtime: Runtime::default(), runtime: Runtime::default(),
} }
} }
@ -265,35 +253,34 @@ impl<Db: Database> Storage<Db> {
func: fn(&Db) -> &T, func: fn(&Db) -> &T,
func_mut: fn(&mut Db) -> &mut T, func_mut: fn(&mut Db) -> &mut T,
) { ) {
self.shared.upcasts.add::<T>(func, func_mut) self.upcasts.add::<T>(func, func_mut)
} }
/// Adds the ingredients in `jar` to the database if not already present. /// Adds the ingredients in `jar` to the database if not already present.
/// If a jar of this type is already present, returns the index. /// If a jar of this type is already present, returns the index.
fn add_or_lookup_jar_by_type(&self, jar: &dyn Jar) -> IngredientIndex { fn add_or_lookup_jar_by_type(&self, jar: &dyn Jar) -> IngredientIndex {
let jar_type_id = jar.type_id(); let jar_type_id = jar.type_id();
let mut jar_map = self.shared.jar_map.lock(); let mut jar_map = self.jar_map.lock();
*jar_map *jar_map
.entry(jar_type_id) .entry(jar_type_id)
.or_insert_with(|| { .or_insert_with(|| {
let index = IngredientIndex::from(self.shared.ingredients_vec.len()); let index = IngredientIndex::from(self.ingredients_vec.len());
let ingredients = jar.create_ingredients(index); let ingredients = jar.create_ingredients(index);
for ingredient in ingredients { for ingredient in ingredients {
let expected_index = ingredient.ingredient_index(); let expected_index = ingredient.ingredient_index();
if ingredient.requires_reset_for_new_revision() { if ingredient.requires_reset_for_new_revision() {
self.shared.ingredients_requiring_reset.push(expected_index); self.ingredients_requiring_reset.push(expected_index);
} }
let actual_index = self let actual_index = self
.shared
.ingredients_vec .ingredients_vec
.push(ingredient); .push(ingredient);
assert_eq!( assert_eq!(
expected_index.as_usize(), expected_index.as_usize(),
actual_index, actual_index,
"ingredient `{:?}` was predicted to have index `{:?}` but actually has index `{:?}`", "ingredient `{:?}` was predicted to have index `{:?}` but actually has index `{:?}`",
self.shared.ingredients_vec.get(actual_index).unwrap(), self.ingredients_vec.get(actual_index).unwrap(),
expected_index, expected_index,
actual_index, actual_index,
); );
@ -305,11 +292,11 @@ impl<Db: Database> Storage<Db> {
/// Return the index of the 1st ingredient from the given jar. /// Return the index of the 1st ingredient from the given jar.
pub fn lookup_jar_by_type(&self, jar: &dyn Jar) -> Option<IngredientIndex> { pub fn lookup_jar_by_type(&self, jar: &dyn Jar) -> Option<IngredientIndex> {
self.shared.jar_map.lock().get(&jar.type_id()).copied() self.jar_map.lock().get(&jar.type_id()).copied()
} }
pub fn lookup_ingredient(&self, index: IngredientIndex) -> &dyn Ingredient { pub fn lookup_ingredient(&self, index: IngredientIndex) -> &dyn Ingredient {
&**self.shared.ingredients_vec.get(index.as_usize()).unwrap() &**self.ingredients_vec.get(index.as_usize()).unwrap()
} }
fn lookup_ingredient_mut( fn lookup_ingredient_mut(
@ -318,20 +305,15 @@ impl<Db: Database> Storage<Db> {
) -> (&mut dyn Ingredient, &mut Runtime) { ) -> (&mut dyn Ingredient, &mut Runtime) {
self.runtime.new_revision(); self.runtime.new_revision();
for index in self.shared.ingredients_requiring_reset.iter() { for index in self.ingredients_requiring_reset.iter() {
self.shared self.ingredients_vec
.ingredients_vec
.get_mut(index.as_usize()) .get_mut(index.as_usize())
.unwrap() .unwrap()
.reset_for_new_revision(); .reset_for_new_revision();
} }
( (
&mut **self &mut **self.ingredients_vec.get_mut(index.as_usize()).unwrap(),
.shared
.ingredients_vec
.get_mut(index.as_usize())
.unwrap(),
&mut self.runtime, &mut self.runtime,
) )
} }