mirror of
https://github.com/salsa-rs/salsa.git
synced 2025-02-02 09:46:06 +00:00
allow private requirements in query groups
This commit is contained in:
parent
8aa01bcccb
commit
940eed92a6
3 changed files with 122 additions and 11 deletions
|
@ -1,6 +1,6 @@
|
|||
//! This crate provides salsa's macros and attributes.
|
||||
|
||||
#![recursion_limit = "128"]
|
||||
#![recursion_limit = "256"]
|
||||
|
||||
extern crate proc_macro;
|
||||
extern crate proc_macro2;
|
||||
|
|
|
@ -3,11 +3,19 @@ use heck::CamelCase;
|
|||
use proc_macro::TokenStream;
|
||||
use proc_macro2::Span;
|
||||
use quote::ToTokens;
|
||||
use syn::{parse_macro_input, parse_quote, FnArg, Ident, ItemTrait, ReturnType, TraitItem, Type};
|
||||
use syn::parse::{Parse, ParseStream};
|
||||
use syn::punctuated::Punctuated;
|
||||
use syn::{
|
||||
parse_macro_input, parse_quote, FnArg, Ident, ItemTrait, Lit, MetaNameValue, Path, TypeParamBound,
|
||||
ReturnType, Token, TraitItem, Type, TraitBound, TraitBoundModifier
|
||||
};
|
||||
|
||||
/// Implementation for `[salsa::query_group]` decorator.
|
||||
pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream {
|
||||
let group_struct: Ident = parse_macro_input!(args as Ident);
|
||||
let GroupDef {
|
||||
group_struct,
|
||||
requires,
|
||||
} = parse_macro_input!(args as GroupDef);
|
||||
let input: ItemTrait = parse_macro_input!(input as ItemTrait);
|
||||
// println!("args: {:#?}", args);
|
||||
// println!("input: {:#?}", input);
|
||||
|
@ -314,7 +322,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
|
|||
|
||||
impl<DB__> salsa::plumbing::QueryGroup<DB__> for #group_struct
|
||||
where
|
||||
DB__: #trait_name,
|
||||
DB__: #trait_name + #requires,
|
||||
DB__: salsa::plumbing::HasQueryGroup<#group_struct>,
|
||||
DB__: salsa::Database,
|
||||
{
|
||||
|
@ -325,7 +333,15 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
|
|||
|
||||
// Emit an impl of the trait
|
||||
output.extend({
|
||||
let bounds = &input.supertraits;
|
||||
let mut bounds = input.supertraits.clone();
|
||||
for path in requires.clone() {
|
||||
bounds.push(TypeParamBound::Trait(TraitBound {
|
||||
paren_token: None,
|
||||
modifier: TraitBoundModifier::None,
|
||||
lifetimes: None,
|
||||
path,
|
||||
}));
|
||||
}
|
||||
quote! {
|
||||
impl<T> #trait_name for T
|
||||
where
|
||||
|
@ -365,7 +381,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
|
|||
|
||||
impl<#db> salsa::Query<#db> for #qt
|
||||
where
|
||||
DB: #trait_name,
|
||||
DB: #trait_name + #requires,
|
||||
DB: salsa::plumbing::HasQueryGroup<#group_struct>,
|
||||
DB: salsa::Database,
|
||||
{
|
||||
|
@ -401,7 +417,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
|
|||
output.extend(quote_spanned! {span=>
|
||||
impl<DB> salsa::plumbing::QueryFunction<DB> for #qt
|
||||
where
|
||||
DB: #trait_name,
|
||||
DB: #trait_name + #requires,
|
||||
DB: salsa::plumbing::HasQueryGroup<#group_struct>,
|
||||
DB: salsa::Database,
|
||||
{
|
||||
|
@ -430,7 +446,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
|
|||
revision: salsa::plumbing::Revision,
|
||||
) -> bool
|
||||
where
|
||||
DB__: #trait_name,
|
||||
DB__: #trait_name + #requires,
|
||||
DB__: salsa::plumbing::HasQueryGroup<#group_struct>,
|
||||
{
|
||||
match self {
|
||||
|
@ -456,7 +472,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
|
|||
output.extend(quote! {
|
||||
#trait_vis struct #group_storage<DB__>
|
||||
where
|
||||
DB__: #trait_name,
|
||||
DB__: #trait_name + #requires,
|
||||
DB__: salsa::plumbing::HasQueryGroup<#group_struct>,
|
||||
DB__: salsa::Database,
|
||||
{
|
||||
|
@ -465,7 +481,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
|
|||
|
||||
impl<DB__> Default for #group_storage<DB__>
|
||||
where
|
||||
DB__: #trait_name,
|
||||
DB__: #trait_name + #requires,
|
||||
DB__: salsa::plumbing::HasQueryGroup<#group_struct>,
|
||||
DB__: salsa::Database,
|
||||
{
|
||||
|
@ -479,7 +495,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
|
|||
|
||||
impl<DB__> #group_storage<DB__>
|
||||
where
|
||||
DB__: #trait_name,
|
||||
DB__: #trait_name + #requires,
|
||||
DB__: salsa::plumbing::HasQueryGroup<#group_struct>,
|
||||
{
|
||||
#trait_vis fn for_each_query(
|
||||
|
@ -509,6 +525,37 @@ fn is_not_salsa_attr_path(path: &syn::Path) -> bool {
|
|||
|| path.segments.len() != 2
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct GroupDef {
|
||||
group_struct: Ident,
|
||||
requires: Punctuated<Path, Token![+]>,
|
||||
}
|
||||
|
||||
impl Parse for GroupDef {
|
||||
fn parse(input: ParseStream) -> syn::Result<GroupDef> {
|
||||
let res = GroupDef {
|
||||
group_struct: input.parse()?,
|
||||
requires: {
|
||||
if input.lookahead1().peek(Token![,]) {
|
||||
input.parse::<Token![,]>()?;
|
||||
let name_value: MetaNameValue = input.parse()?;
|
||||
if name_value.ident != "requires" {
|
||||
return Err(syn::Error::new_spanned(name_value, "invalid attribute"));
|
||||
}
|
||||
let str_lit = match name_value.lit {
|
||||
Lit::Str(it) => it,
|
||||
_ => return Err(syn::Error::new_spanned(name_value, "invalid attribute")),
|
||||
};
|
||||
str_lit.parse_with(Punctuated::<Path, Token![+]>::parse_separated_nonempty)?
|
||||
} else {
|
||||
Punctuated::new()
|
||||
}
|
||||
},
|
||||
};
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Query {
|
||||
fn_name: Ident,
|
||||
|
|
64
tests/requires.rs
Normal file
64
tests/requires.rs
Normal file
|
@ -0,0 +1,64 @@
|
|||
//! Test that transparent (uncached) queries work
|
||||
|
||||
|
||||
mod queries {
|
||||
#[salsa::query_group(InputGroupStorage)]
|
||||
pub trait InputGroup {
|
||||
#[salsa::input]
|
||||
fn input(&self, x: u32) -> u32;
|
||||
}
|
||||
|
||||
#[salsa::query_group(PrivGroupAStorage)]
|
||||
pub trait PrivGroupA: InputGroup {
|
||||
fn private_a(&self, x: u32) -> u32;
|
||||
}
|
||||
|
||||
fn private_a(db: &impl PrivGroupA, x: u32) -> u32{
|
||||
db.input(x)
|
||||
}
|
||||
|
||||
#[salsa::query_group(PrivGroupBStorage)]
|
||||
pub trait PrivGroupB: InputGroup {
|
||||
fn private_b(&self, x: u32) -> u32;
|
||||
}
|
||||
|
||||
fn private_b(db: &impl PrivGroupB, x: u32) -> u32{
|
||||
db.input(x)
|
||||
}
|
||||
|
||||
#[salsa::query_group(PubGroupStorage, requires = "PrivGroupA + PrivGroupB")]
|
||||
pub trait PubGroup: InputGroup {
|
||||
fn public(&self, x: u32) -> u32;
|
||||
}
|
||||
|
||||
|
||||
fn public(db: &(impl PubGroup + PrivGroupA + PrivGroupB), x: u32) -> u32 {
|
||||
db.private_a(x) + db.private_b(x)
|
||||
}
|
||||
}
|
||||
|
||||
#[salsa::database(
|
||||
queries::InputGroupStorage,
|
||||
queries::PrivGroupAStorage,
|
||||
queries::PrivGroupBStorage,
|
||||
queries::PubGroupStorage,
|
||||
)]
|
||||
#[derive(Default)]
|
||||
struct Database {
|
||||
runtime: salsa::Runtime<Database>,
|
||||
}
|
||||
|
||||
impl salsa::Database for Database {
|
||||
fn salsa_runtime(&self) -> &salsa::Runtime<Database> {
|
||||
&self.runtime
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn require_clauses_work() {
|
||||
use queries::{InputGroup, PubGroup};
|
||||
let mut db = Database::default();
|
||||
|
||||
db.set_input(1, 10);
|
||||
assert_eq!(db.public(1), 20);
|
||||
}
|
Loading…
Reference in a new issue