diff --git a/components/salsa-macros/src/database_storage.rs b/components/salsa-macros/src/database_storage.rs index 523b4dfd..d6ba38b4 100644 --- a/components/salsa-macros/src/database_storage.rs +++ b/components/salsa-macros/src/database_storage.rs @@ -1,5 +1,6 @@ use crate::parenthesized::Parenthesized; use proc_macro::TokenStream; +use proc_macro2::Span; use syn::parse::{Parse, ParseStream, Peek}; use syn::{Attribute, Ident, Path, Token, Visibility}; @@ -21,33 +22,34 @@ use syn::{Attribute, Ident, Path, Token, Visibility}; /// impl Database { pub(crate) fn database_storage(input: TokenStream) -> TokenStream { let DatabaseStorage { - storage_struct_name, database_name, query_groups, attributes, visibility, } = syn::parse_macro_input!(input as DatabaseStorage); + let mut output = proc_macro2::TokenStream::new(); let each_query = || { query_groups .iter() - .flat_map(|query_group| &query_group.queries) + .enumerate() + .flat_map(|(index, query_group)| query_group.queries.iter().map(move |q| (index, q))) }; - // For each query `fn foo() for FooType` create - // - // ``` - // foo: >::Storage, - // ``` - let mut fields = proc_macro2::TokenStream::new(); - for Query { - query_name, - query_type, - } in each_query() - { - fields.extend(quote! { - #query_name: <#query_type as ::salsa::Query<#database_name>>::Storage, - }); + // For each query group `foo::MyGroup` create a link to its + // `foo::MyGroupGroupStorage` + let mut storage_tuple_elements = proc_macro2::TokenStream::new(); + for query_group in &query_groups { + // rewrite the last identifier (`MyGroup`, above) to + // (e.g.) `MyGroupGroupStorage`. + let mut group_storage = query_group.query_group.clone(); + let last_ident = &group_storage.segments.last().unwrap().value().ident; + let storage_ident = Ident::new( + &format!("{}GroupStorage", last_ident.to_string()), + Span::call_site(), + ); + group_storage.segments.last_mut().unwrap().value_mut().ident = storage_ident; + storage_tuple_elements.extend(quote! { #group_storage, }); } let mut attrs = proc_macro2::TokenStream::new(); @@ -55,15 +57,6 @@ pub(crate) fn database_storage(input: TokenStream) -> TokenStream { attrs.extend(quote! { #attr }); } - // Create the storage struct defintion - let mut output = quote! { - #[derive(Default)] - #attrs - #visibility struct #storage_struct_name { - #fields - } - }; - // create query descriptor wrapper struct output.extend(quote! { #[derive(Clone, Debug, PartialEq, Eq, Hash)] @@ -79,10 +72,13 @@ pub(crate) fn database_storage(input: TokenStream) -> TokenStream { // foo(>::Key), // ``` let mut variants = proc_macro2::TokenStream::new(); - for Query { - query_name, - query_type, - } in each_query() + for ( + _, + Query { + query_name, + query_type, + }, + ) in each_query() { variants.extend(quote!( #query_name(<#query_type as ::salsa::Query<#database_name>>::Key), @@ -99,15 +95,15 @@ pub(crate) fn database_storage(input: TokenStream) -> TokenStream { output.extend(quote! { impl ::salsa::plumbing::DatabaseStorageTypes for #database_name { type QueryDescriptor = __SalsaQueryDescriptor; - type DatabaseStorage = #storage_struct_name; + type DatabaseStorage = (#storage_tuple_elements); } }); // let mut for_each_ops = proc_macro2::TokenStream::new(); - for Query { query_name, .. } in each_query() { + for (group_index, Query { query_name, .. }) in each_query() { for_each_ops.extend(quote! { - op(&::salsa::Database::salsa_runtime(self).storage().#query_name); + op(&::salsa::Database::salsa_runtime(self).storage().#group_index.#query_name); }); } output.extend(quote! { @@ -122,15 +118,18 @@ pub(crate) fn database_storage(input: TokenStream) -> TokenStream { }); let mut for_each_query_desc = proc_macro2::TokenStream::new(); - for Query { - query_name, - query_type, - } in each_query() + for ( + group_index, + Query { + query_name, + query_type, + }, + ) in each_query() { for_each_query_desc.extend(quote! { __SalsaQueryDescriptorKind::#query_name(key) => { let runtime = ::salsa::Database::salsa_runtime(db); - let storage = &runtime.storage().#query_name; + let storage = &runtime.storage().#group_index.#query_name; <_ as ::salsa::plumbing::QueryStorageOps<#database_name, #query_type>>::maybe_changed_since( storage, db, @@ -157,10 +156,13 @@ pub(crate) fn database_storage(input: TokenStream) -> TokenStream { }); let mut for_each_query_table = proc_macro2::TokenStream::new(); - for Query { - query_name, - query_type, - } in each_query() + for ( + group_index, + Query { + query_name, + query_type, + }, + ) in each_query() { for_each_query_table.extend(quote! { impl ::salsa::plumbing::GetQueryTable<#query_type> for #database_name { @@ -171,6 +173,7 @@ pub(crate) fn database_storage(input: TokenStream) -> TokenStream { db, &::salsa::Database::salsa_runtime(db) .storage() + .#group_index .#query_name, ) } @@ -183,6 +186,7 @@ pub(crate) fn database_storage(input: TokenStream) -> TokenStream { db, &::salsa::Database::salsa_runtime(db) .storage() + .#group_index .#query_name, ) } @@ -205,7 +209,6 @@ pub(crate) fn database_storage(input: TokenStream) -> TokenStream { } struct DatabaseStorage { - storage_struct_name: Ident, database_name: Path, query_groups: Vec, attributes: Vec, @@ -213,7 +216,7 @@ struct DatabaseStorage { } struct QueryGroup { - _query_group: Path, + query_group: Path, queries: Vec, } @@ -227,7 +230,7 @@ impl Parse for DatabaseStorage { let attributes = input.call(Attribute::parse_outer)?; let visibility = input.parse()?; let _struct_token: Token![struct ] = input.parse()?; - let storage_struct_name: Ident = input.parse()?; + let _storage_struct_name: Ident = input.parse()?; let _for_token: Token![for ] = input.parse()?; let database_name: Path = input.parse()?; let content; @@ -236,7 +239,6 @@ impl Parse for DatabaseStorage { Ok(DatabaseStorage { attributes, visibility, - storage_struct_name, database_name, query_groups, }) @@ -257,7 +259,7 @@ impl Parse for QueryGroup { syn::braced!(content in input); let queries: Vec = parse_while(Token![fn ], &content)?; Ok(QueryGroup { - _query_group: query_group, + query_group, queries, }) } diff --git a/components/salsa-macros/src/query_group.rs b/components/salsa-macros/src/query_group.rs index 62b4b237..33dd3707 100644 --- a/components/salsa-macros/src/query_group.rs +++ b/components/salsa-macros/src/query_group.rs @@ -122,6 +122,8 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream let mut query_fn_declarations = proc_macro2::TokenStream::new(); let mut query_fn_definitions = proc_macro2::TokenStream::new(); + let mut query_descriptor_variants = proc_macro2::TokenStream::new(); + let mut storage_fields = proc_macro2::TokenStream::new(); for query in &queries { let key_names: &Vec<_> = &(0..query.keys.len()) .map(|i| Ident::new(&format!("key{}", i), Span::call_site())) @@ -139,9 +141,21 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream query_fn_definitions.extend(quote! { fn #fn_name(&self, #(#key_names: #keys),*) -> #value { - >::get_query_table(self).get((#(#key_names),*)) + >::get_query_table(self).get((#(#key_names),*)) } }); + + // A variant for the group descriptor below + query_descriptor_variants.extend(quote! { + #qt(<#qt as ::salsa::Query<__DB>>::Key), + }); + + // A field for the storage struct + // + // FIXME(#120): the pub should not be necessary once we complete the transition + storage_fields.extend(quote! { + pub #fn_name: <#qt as ::salsa::Query<__DB>>::Storage, + }); } // Emit the trait itself. @@ -230,6 +244,29 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream } } + // Emit query group descriptor + //let group_descriptor = Ident::new( + // &format!("{}GroupDescriptor", trait_name.to_string()), + // Span::call_site(), + //); + //output.extend(quote! { + // #trait_vis enum #group_descriptor<__DB: #trait_name> { + // #query_descriptor_variants + // } + //}); + + // Emit query group storage struct + let group_storage = Ident::new( + &format!("{}GroupStorage", trait_name.to_string()), + Span::call_site(), + ); + output.extend(quote! { + #[derive(Default)] + #trait_vis struct #group_storage<__DB: #trait_name> { + #storage_fields + } + }); + output.into() } diff --git a/tests/cycles.rs b/tests/cycles.rs index 947a80ea..b63694a2 100644 --- a/tests/cycles.rs +++ b/tests/cycles.rs @@ -1,5 +1,5 @@ #[derive(Default)] -pub struct DatabaseImpl { +struct DatabaseImpl { runtime: salsa::Runtime, } @@ -10,7 +10,7 @@ impl salsa::Database for DatabaseImpl { } salsa::database_storage! { - pub struct DatabaseImplStorage for DatabaseImpl { + struct DatabaseImplStorage for DatabaseImpl { impl Database { fn memoized_a() for MemoizedAQuery; fn memoized_b() for MemoizedBQuery; diff --git a/tests/gc/db.rs b/tests/gc/db.rs index 83d4d959..d5528066 100644 --- a/tests/gc/db.rs +++ b/tests/gc/db.rs @@ -2,7 +2,7 @@ use crate::group; use crate::log::{HasLog, Log}; #[derive(Default)] -pub struct DatabaseImpl { +pub(crate) struct DatabaseImpl { runtime: salsa::Runtime, log: Log, } @@ -14,7 +14,7 @@ impl salsa::Database for DatabaseImpl { } salsa::database_storage! { - pub struct DatabaseImplStorage for DatabaseImpl { + pub(crate) struct DatabaseImplStorage for DatabaseImpl { impl group::GcDatabase { fn min() for group::MinQuery; fn max() for group::MaxQuery; diff --git a/tests/incremental/implementation.rs b/tests/incremental/implementation.rs index d9dad5d6..9c2bbb80 100644 --- a/tests/incremental/implementation.rs +++ b/tests/incremental/implementation.rs @@ -42,7 +42,7 @@ salsa::database_storage! { pub(crate) struct TestContextImplStorage for TestContextImpl { impl constants::ConstantsDatabase { fn constants_input() for constants::ConstantsInputQuery; - fn constants_derived() for constants::ConstantsAddQuery; + fn constants_add() for constants::ConstantsAddQuery; } impl memoized_dep_inputs::MemoizedDepInputsContext { diff --git a/tests/parallel/setup.rs b/tests/parallel/setup.rs index 99549034..8685948a 100644 --- a/tests/parallel/setup.rs +++ b/tests/parallel/setup.rs @@ -185,7 +185,7 @@ fn snapshot_me(db: &impl ParDatabase) { } #[derive(Default)] -pub struct ParDatabaseImpl { +pub(crate) struct ParDatabaseImpl { runtime: salsa::Runtime, knobs: KnobsStruct, } @@ -235,7 +235,7 @@ impl Knobs for ParDatabaseImpl { } salsa::database_storage! { - pub struct DatabaseImplStorage for ParDatabaseImpl { + pub(crate) struct DatabaseImplStorage for ParDatabaseImpl { impl ParDatabase { fn input() for InputQuery; fn sum() for SumQuery; diff --git a/tests/storage_varieties/implementation.rs b/tests/storage_varieties/implementation.rs index a6e56ab8..cf74bc36 100644 --- a/tests/storage_varieties/implementation.rs +++ b/tests/storage_varieties/implementation.rs @@ -2,13 +2,13 @@ use crate::queries; use std::cell::Cell; #[derive(Default)] -pub struct DatabaseImpl { +pub(crate) struct DatabaseImpl { runtime: salsa::Runtime, counter: Cell, } salsa::database_storage! { - pub struct DatabaseImplStorage for DatabaseImpl { + pub(crate) struct DatabaseImplStorage for DatabaseImpl { impl queries::Database { fn memoized() for queries::MemoizedQuery; fn volatile() for queries::VolatileQuery;