342: Make input setters return old value r=nikomatsakis a=MihailMihov

Added an `InputFieldIngredient` which replaces the `FunctionIngredient` for inputs. The end goal is for the inputs to use an ingredient simpler than `FunctionIngredient` and that would allow us to return the old values when updating an input.

Co-authored-by: Mihail Mihov <mmihov.personal@gmail.com>
This commit is contained in:
bors[bot] 2022-08-16 19:26:02 +00:00 committed by GitHub
commit a866e71266
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 163 additions and 9 deletions

View file

@ -35,7 +35,7 @@ impl InputStruct {
let id_struct = self.id_struct();
let inherent_impl = self.input_inherent_impl();
let ingredients_for_impl = self.input_ingredients(&config_structs);
let ingredients_for_impl = self.input_ingredients();
let as_id_impl = self.as_id_impl();
let salsa_struct_in_db_impl = self.salsa_struct_in_db_impl();
@ -74,7 +74,7 @@ impl InputStruct {
{
let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db);
let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #ident >>::ingredient(__jar);
__ingredients.#field_index.fetch(__db, self)
__ingredients.#field_index.fetch(__runtime, self)
}
}
} else {
@ -83,7 +83,7 @@ impl InputStruct {
{
let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db);
let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #ident >>::ingredient(__jar);
__ingredients.#field_index.fetch(__db, self).clone()
__ingredients.#field_index.fetch(__runtime, self).clone()
}
}
}
@ -93,11 +93,11 @@ impl InputStruct {
let field_setters: Vec<syn::ImplItemMethod> = field_indices.iter().zip(&field_names).zip(&field_tys).map(|((field_index, field_name), field_ty)| {
let set_field_name = syn::Ident::new(&format!("set_{}", field_name), field_name.span());
parse_quote! {
pub fn #set_field_name<'db>(self, __db: &'db mut #db_dyn_ty, __value: #field_ty)
pub fn #set_field_name<'db>(self, __db: &'db mut #db_dyn_ty, __value: #field_ty) -> #field_ty
{
let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar_mut(__db);
let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #ident >>::ingredient_mut(__jar);
__ingredients.#field_index.store(__runtime, self, __value, salsa::Durability::LOW);
__ingredients.#field_index.store(__runtime, self, __value, salsa::Durability::LOW).unwrap()
}
}
})
@ -127,19 +127,19 @@ impl InputStruct {
///
/// The entity's ingredients include both the main entity ingredient along with a
/// function ingredient for each of the value fields.
fn input_ingredients(&self, config_structs: &[syn::ItemStruct]) -> syn::ItemImpl {
fn input_ingredients(&self) -> syn::ItemImpl {
let ident = self.id_ident();
let field_ty = self.all_field_tys();
let jar_ty = self.jar_ty();
let all_field_indices: Vec<Literal> = self.all_field_indices();
let input_index: Literal = self.input_index();
let config_struct_names = config_structs.iter().map(|s| &s.ident);
parse_quote! {
impl salsa::storage::IngredientsFor for #ident {
type Jar = #jar_ty;
type Ingredients = (
#(
salsa::function::FunctionIngredient<#config_struct_names>,
salsa::input_field::InputFieldIngredient<#ident, #field_ty>,
)*
salsa::input::InputIngredient<#ident>,
);
@ -160,7 +160,7 @@ impl InputStruct {
&ingredients.#all_field_indices
},
);
salsa::function::FunctionIngredient::new(index)
salsa::input_field::InputFieldIngredient::new(index)
},
)*
{

View file

@ -0,0 +1,96 @@
use crate::cycle::CycleRecoveryStrategy;
use crate::ingredient::Ingredient;
use crate::key::DependencyIndex;
use crate::runtime::local_state::QueryOrigin;
use crate::runtime::StampedValue;
use crate::{AsId, DatabaseKeyIndex, Durability, Id, IngredientIndex, Revision, Runtime};
use rustc_hash::FxHashMap;
use std::hash::Hash;
/// Ingredient used to represent the fields of a `#[salsa::input]`.
/// These fields can only be mutated by an explicit call to a setter
/// with an `&mut` reference to the database,
/// and therefore cannot be mutated during a tracked function or in parallel.
/// This makes the implementation considerably simpler.
pub struct InputFieldIngredient<K, F> {
index: IngredientIndex,
map: FxHashMap<K, StampedValue<F>>,
}
impl<K, F> InputFieldIngredient<K, F>
where
K: Eq + Hash + AsId,
{
pub fn new(index: IngredientIndex) -> Self {
Self {
index,
map: Default::default(),
}
}
pub fn store(
&mut self,
runtime: &mut Runtime,
key: K,
value: F,
durability: Durability,
) -> Option<F> {
let revision = runtime.current_revision();
let stamped_value = StampedValue {
value,
durability,
changed_at: revision,
};
if let Some(old_value) = self.map.insert(key, stamped_value) {
Some(old_value.value)
} else {
None
}
}
pub fn fetch(&self, runtime: &Runtime, key: K) -> &F {
let StampedValue {
value,
durability,
changed_at,
} = self.map.get(&key).unwrap();
runtime.report_tracked_read(
self.database_key_index(key).into(),
*durability,
*changed_at,
);
value
}
fn database_key_index(&self, key: K) -> DatabaseKeyIndex {
DatabaseKeyIndex {
ingredient_index: self.index,
key_index: key.as_id(),
}
}
}
impl<DB: ?Sized, K, F> Ingredient<DB> for InputFieldIngredient<K, F>
where
K: AsId,
{
fn cycle_recovery_strategy(&self) -> CycleRecoveryStrategy {
CycleRecoveryStrategy::Panic
}
fn maybe_changed_after(&self, _db: &DB, input: DependencyIndex, revision: Revision) -> bool {
let key = K::from_id(input.key_index.unwrap());
self.map.get(&key).unwrap().changed_at > revision
}
fn origin(&self, _key_index: Id) -> Option<QueryOrigin> {
None
}
fn mark_validated_output(&self, _db: &DB, _executor: DatabaseKeyIndex, _output_key: Id) {}
fn remove_stale_output(&self, _db: &DB, _executor: DatabaseKeyIndex, _stale_output_key: Id) {}
}

View file

@ -10,6 +10,7 @@ pub mod hash;
pub mod id;
pub mod ingredient;
pub mod input;
pub mod input_field;
pub mod interned;
pub mod jar;
pub mod key;

View file

@ -0,0 +1,57 @@
//! Test that a setting a field on a `#[salsa::input]`
//! overwrites and returns the old value.
use salsa_2022_tests::{HasLogger, Logger};
use expect_test::expect;
use test_log::test;
#[salsa::jar(db = Db)]
struct Jar(MyInput);
trait Db: salsa::DbWithJar<Jar> + HasLogger {}
#[salsa::input(jar = Jar)]
struct MyInput {
field: String,
}
#[salsa::db(Jar)]
#[derive(Default)]
struct Database {
storage: salsa::Storage<Self>,
logger: Logger,
}
impl salsa::Database for Database {
fn salsa_runtime(&self) -> &salsa::Runtime {
self.storage.runtime()
}
}
impl Db for Database {}
impl HasLogger for Database {
fn logger(&self) -> &Logger {
&self.logger
}
}
#[test]
fn execute() {
let mut db = Database::default();
let input = MyInput::new(&mut db, "Hello".to_string());
// Overwrite field with an empty String
// and store the old value in my_string
let mut my_string = input.set_field(&mut db, String::new());
my_string.push_str(" World!");
// Set the field back to out initial String,
// expecting to get the empty one back
assert_eq!(input.set_field(&mut db, my_string), "");
// Check if the stored String is the one we expected
assert_eq!(input.field(&db), "Hello World!");
}