Merge pull request #165 from matklad/requires

allow private requirements in query groups
This commit is contained in:
Niko Matsakis 2019-05-30 11:49:41 +02:00 committed by GitHub
commit e9c787e2b6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 140 additions and 26 deletions

View file

@ -1,6 +1,6 @@
//! This crate provides salsa's macros and attributes.
#![recursion_limit = "128"]
#![recursion_limit = "256"]
extern crate proc_macro;
extern crate proc_macro2;

View file

@ -1,17 +1,34 @@
use std::convert::TryFrom;
use crate::parenthesized::Parenthesized;
use heck::CamelCase;
use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::ToTokens;
use syn::{parse_macro_input, parse_quote, FnArg, Ident, ItemTrait, ReturnType, TraitItem, Type};
use syn::punctuated::Punctuated;
use syn::{
parse_macro_input, parse_quote, Attribute, FnArg, Ident, ItemTrait, Path,
ReturnType, Token, TraitBound, TraitBoundModifier, TraitItem, Type, TypeParamBound,
};
/// Implementation for `[salsa::query_group]` decorator.
pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream {
let group_struct: Ident = parse_macro_input!(args as Ident);
let group_struct = parse_macro_input!(args as Ident);
let input: ItemTrait = parse_macro_input!(input as ItemTrait);
// println!("args: {:#?}", args);
// println!("input: {:#?}", input);
let (trait_attrs, salsa_attrs) = filter_attrs(input.attrs);
let mut requires: Punctuated<Path, Token![+]> = Punctuated::new();
for SalsaAttr { name, tts } in salsa_attrs {
match name.as_str() {
"requires" => {
requires.push(parse_macro_input!(tts as Parenthesized<syn::Path>).0);
}
_ => panic!("unknown salsa attribute `{}`", name),
}
}
let trait_vis = input.vis;
let trait_name = input.ident;
let _generics = input.generics.clone();
@ -30,19 +47,8 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
let mut num_storages = 0;
// Extract attributes.
let mut attrs = vec![];
for attr in method.attrs {
// Leave non-salsa attributes untouched. These are
// attributes that don't start with `salsa::` or don't have
// exactly two segments in their path.
if is_not_salsa_attr_path(&attr.path) {
attrs.push(attr);
continue;
}
// Keep the salsa attributes around.
let name = attr.path.segments[1].ident.to_string();
let tts = attr.tts.into();
let (attrs, salsa_attrs) = filter_attrs(method.attrs);
for SalsaAttr { name, tts } in salsa_attrs {
match name.as_str() {
"memoized" => {
storage = QueryStorage::Memoized;
@ -297,10 +303,9 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
// Emit the trait itself.
let mut output = {
let attrs = &input.attrs;
let bounds = &input.supertraits;
quote! {
#(#attrs)*
#(#trait_attrs)*
#trait_vis trait #trait_name : #bounds {
#query_fn_declarations
}
@ -314,7 +319,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
impl<DB__> salsa::plumbing::QueryGroup<DB__> for #group_struct
where
DB__: #trait_name,
DB__: #trait_name + #requires,
DB__: salsa::plumbing::HasQueryGroup<#group_struct>,
DB__: salsa::Database,
{
@ -325,7 +330,15 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
// Emit an impl of the trait
output.extend({
let bounds = &input.supertraits;
let mut bounds = input.supertraits.clone();
for path in requires.clone() {
bounds.push(TypeParamBound::Trait(TraitBound {
paren_token: None,
modifier: TraitBoundModifier::None,
lifetimes: None,
path,
}));
}
quote! {
impl<T> #trait_name for T
where
@ -365,7 +378,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
impl<#db> salsa::Query<#db> for #qt
where
DB: #trait_name,
DB: #trait_name + #requires,
DB: salsa::plumbing::HasQueryGroup<#group_struct>,
DB: salsa::Database,
{
@ -401,7 +414,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
output.extend(quote_spanned! {span=>
impl<DB> salsa::plumbing::QueryFunction<DB> for #qt
where
DB: #trait_name,
DB: #trait_name + #requires,
DB: salsa::plumbing::HasQueryGroup<#group_struct>,
DB: salsa::Database,
{
@ -430,7 +443,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
revision: salsa::plumbing::Revision,
) -> bool
where
DB__: #trait_name,
DB__: #trait_name + #requires,
DB__: salsa::plumbing::HasQueryGroup<#group_struct>,
{
match self {
@ -456,7 +469,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
output.extend(quote! {
#trait_vis struct #group_storage<DB__>
where
DB__: #trait_name,
DB__: #trait_name + #requires,
DB__: salsa::plumbing::HasQueryGroup<#group_struct>,
DB__: salsa::Database,
{
@ -465,7 +478,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
impl<DB__> Default for #group_storage<DB__>
where
DB__: #trait_name,
DB__: #trait_name + #requires,
DB__: salsa::plumbing::HasQueryGroup<#group_struct>,
DB__: salsa::Database,
{
@ -479,7 +492,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
impl<DB__> #group_storage<DB__>
where
DB__: #trait_name,
DB__: #trait_name + #requires,
DB__: salsa::plumbing::HasQueryGroup<#group_struct>,
{
#trait_vis fn for_each_query(
@ -501,6 +514,24 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
output.into()
}
struct SalsaAttr {
name: String,
tts: TokenStream,
}
impl TryFrom<syn::Attribute> for SalsaAttr {
type Error = syn::Attribute;
fn try_from(attr: syn::Attribute) -> Result<SalsaAttr, syn::Attribute> {
if is_not_salsa_attr_path(&attr.path) {
return Err(attr);
}
let name = attr.path.segments[1].ident.to_string();
let tts = attr.tts.into();
Ok(SalsaAttr { name, tts })
}
}
fn is_not_salsa_attr_path(path: &syn::Path) -> bool {
path.segments
.first()
@ -509,6 +540,22 @@ fn is_not_salsa_attr_path(path: &syn::Path) -> bool {
|| path.segments.len() != 2
}
fn filter_attrs(attrs: Vec<Attribute>) -> (Vec<Attribute>, Vec<SalsaAttr>) {
let mut other = vec![];
let mut salsa = vec![];
// Leave non-salsa attributes untouched. These are
// attributes that don't start with `salsa::` or don't have
// exactly two segments in their path.
// Keep the salsa attributes around.
for attr in attrs {
match SalsaAttr::try_from(attr) {
Ok(it) => salsa.push(it),
Err(it) => other.push(it),
}
}
(other, salsa)
}
#[derive(Debug)]
struct Query {
fn_name: Ident,

67
tests/requires.rs Normal file
View file

@ -0,0 +1,67 @@
//! Test `salsa::requires` attribute for private query dependencies
//! https://github.com/salsa-rs/salsa-rfcs/pull/3
mod queries {
#[salsa::query_group(InputGroupStorage)]
pub trait InputGroup {
#[salsa::input]
fn input(&self, x: u32) -> u32;
}
#[salsa::query_group(PrivGroupAStorage)]
pub trait PrivGroupA: InputGroup {
fn private_a(&self, x: u32) -> u32;
}
fn private_a(db: &impl PrivGroupA, x: u32) -> u32{
db.input(x)
}
#[salsa::query_group(PrivGroupBStorage)]
pub trait PrivGroupB: InputGroup {
fn private_b(&self, x: u32) -> u32;
}
fn private_b(db: &impl PrivGroupB, x: u32) -> u32{
db.input(x)
}
#[salsa::query_group(PubGroupStorage)]
#[salsa::requires(PrivGroupA)]
#[salsa::requires(PrivGroupB)]
pub trait PubGroup: InputGroup {
fn public(&self, x: u32) -> u32;
}
fn public(db: &(impl PubGroup + PrivGroupA + PrivGroupB), x: u32) -> u32 {
db.private_a(x) + db.private_b(x)
}
}
#[salsa::database(
queries::InputGroupStorage,
queries::PrivGroupAStorage,
queries::PrivGroupBStorage,
queries::PubGroupStorage,
)]
#[derive(Default)]
struct Database {
runtime: salsa::Runtime<Database>,
}
impl salsa::Database for Database {
fn salsa_runtime(&self) -> &salsa::Runtime<Database> {
&self.runtime
}
}
#[test]
fn require_clauses_work() {
use queries::{InputGroup, PubGroup};
let mut db = Database::default();
db.set_input(1, 10);
assert_eq!(db.public(1), 20);
}