This commit is contained in:
Niko Matsakis 2024-07-11 07:30:26 -04:00
parent 15d5f213c5
commit 2cfb75837b
10 changed files with 338 additions and 106 deletions

View file

@ -11,6 +11,7 @@ description = "Procedural macros for the salsa crate"
proc-macro = true
[dependencies]
heck = "0.5.0"
proc-macro2 = "1.0"
quote = "1.0"
syn = { version = "2.0.64", features = ["full", "visit-mut"] }

View file

@ -1,5 +1,7 @@
use proc_macro2::{Literal, TokenStream};
use syn::{parse::Nothing, spanned::Spanned, Token};
use proc_macro2::TokenStream;
use syn::{parse::Nothing, ItemStruct};
use crate::hygiene::Hygiene;
// Source:
//
@ -12,37 +14,35 @@ pub(crate) fn db(
args: proc_macro::TokenStream,
input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let args = syn::parse_macro_input!(args as Args);
let input = syn::parse_macro_input!(input as syn::ItemStruct);
match args.try_db(&input) {
Ok(v) => quote! { #input #v }.into(),
Err(e) => {
let error = e.to_compile_error();
quote! { #input #error }.into()
}
let _nothing = syn::parse_macro_input!(args as Nothing);
let db_macro = DbMacro {
hygiene: Hygiene::from(&input),
input: syn::parse_macro_input!(input as syn::ItemStruct),
};
match db_macro.try_db() {
Ok(v) => v.into(),
Err(e) => e.to_compile_error().into(),
}
}
pub struct Args {}
impl syn::parse::Parse for Args {
fn parse(_input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
Ok(Args {})
}
struct DbMacro {
hygiene: Hygiene,
input: ItemStruct,
}
impl Args {
fn try_db(self, input: &syn::ItemStruct) -> syn::Result<TokenStream> {
let storage = self.find_storage_field(input)?;
impl DbMacro {
fn try_db(self) -> syn::Result<TokenStream> {
let has_storage_impl = self.has_storage_impl()?;
let input = self.input;
Ok(quote! {
#has_storage_impl
#input
})
}
fn find_storage_field(&self, input: &syn::ItemStruct) -> syn::Result<syn::Ident> {
fn find_storage_field(&self) -> syn::Result<syn::Ident> {
let storage = "storage";
for field in input.fields.iter() {
for field in self.input.fields.iter() {
if let Some(i) = &field.ident {
if i == storage {
return Ok(i.clone());
@ -56,8 +56,34 @@ impl Args {
}
return Err(syn::Error::new_spanned(
&input.ident,
&self.input.ident,
"database struct must be a braced struct (`{}`) with a field named `storage`",
));
}
#[allow(non_snake_case)]
fn has_storage_impl(&self) -> syn::Result<TokenStream> {
let storage = self.find_storage_field()?;
let db = &self.input.ident;
let SalsaHasStorage = self.hygiene.ident("SalsaHasStorage");
let SalsaStorage = self.hygiene.ident("SalsaStorage");
Ok(quote! {
const _: () = {
use salsa::storage::HasStorage as #SalsaHasStorage;
use salsa::storage::Storage as #SalsaStorage;
unsafe impl #SalsaHasStorage for #db {
fn storage(&self) -> &#SalsaStorage<Self> {
&self.#storage
}
fn storage_mut(&mut self) -> &mut #SalsaStorage<Self> {
&mut self.#storage
}
}
};
})
}
}

View file

@ -1,66 +1,153 @@
use proc_macro2::{Literal, TokenStream};
use syn::{spanned::Spanned, Token};
use std::fmt::Display;
use heck::ToSnakeCase;
use proc_macro2::{Span, TokenStream};
use crate::hygiene::Hygiene;
// Source:
//
// ```
// #[salsa::db_view]
// pub trait Db: salsa::DatabaseView<dyn Db> + ... {
// pub trait $Db: ... {
// ...
// }
// ```
//
// becomes
//
// ```
// pub trait $Db: __SalsaViewAs$Db__ {
// ...
// }
//
// pub trait __SalsaViewAs$Db__ {
// fn __salsa_add_view_for_$db__(&self);
// }
//
// impl<T: Db> __SalsaViewAs$Db__ for T {
// ...
// }
// ```
pub(crate) fn db_view(
args: proc_macro::TokenStream,
input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let args: TokenStream = args.into();
let input = syn::parse_macro_input!(input as syn::ItemTrait);
match try_db_view(args, input) {
syn::parse_macro_input!(args as syn::parse::Nothing);
let db_view_macro = DbViewMacro::new(
Hygiene::from(&input),
syn::parse_macro_input!(input as syn::ItemTrait),
);
match db_view_macro.expand() {
Ok(tokens) => tokens.into(),
Err(err) => err.to_compile_error().into(),
}
}
fn try_db_view(args: TokenStream, input: syn::ItemTrait) -> syn::Result<TokenStream> {
if let Some(token) = args.into_iter().next() {
return Err(syn::Error::new_spanned(token, "unexpected token"));
}
// FIXME: check for `salsa::DataviewView<dyn Db>` supertrait?
let view_impl = view_impl(&input);
Ok(quote! {
#input
#view_impl
})
#[allow(non_snake_case)]
pub(crate) struct DbViewMacro {
hygiene: Hygiene,
input: syn::ItemTrait,
DbViewTrait: syn::Ident,
db_view_method: syn::Ident,
}
#[allow(non_snake_case)]
fn view_impl(input: &syn::ItemTrait) -> syn::Item {
let DB = syn::Ident::new("_DB", proc_macro2::Span::call_site());
let Database = syn::Ident::new("_Database", proc_macro2::Span::call_site());
let DatabaseView = syn::Ident::new("_DatabaseView", proc_macro2::Span::call_site());
let upcasts = syn::Ident::new("_upcasts", proc_macro2::Span::call_site());
let UserTrait = &input.ident;
impl DbViewMacro {
// This is a case where our hygiene mechanism is inadequate.
//
// We cannot know whether `DbViewTrait` is defined elsewhere
// in the module.
//
// Therefore we give it a dorky name.
parse_quote! {
const _: () = {
use salsa::DatabaseView as #DatabaseView;
use salsa::Database as #Database;
pub(crate) fn db_view_trait_name(input: &impl Display) -> syn::Ident {
syn::Ident::new(&format!("__SalsaAddView{}__", input), Span::call_site())
}
impl<#DB: #Database> #DatabaseView<dyn #UserTrait> for #DB {
fn add_view_to_db(&self) {
let #upcasts = self.upcasts_for_self();
#upcasts.add::<dyn #UserTrait>(|t| t, |t| t);
}
pub(crate) fn db_view_method_name(input: &impl Display) -> syn::Ident {
syn::Ident::new(
&format!("__salsa_add_view_{}__", input.to_string().to_snake_case()),
Span::call_site(),
)
}
fn new(hygiene: Hygiene, input: syn::ItemTrait) -> Self {
Self {
DbViewTrait: Self::db_view_trait_name(&input.ident),
db_view_method: Self::db_view_method_name(&input.ident),
hygiene,
input,
}
}
fn expand(mut self) -> syn::Result<TokenStream> {
self.add_supertrait();
let view_impl = self.view_impl();
let view_trait = self.view_trait();
let input = self.input;
Ok(quote! {
#input
#view_trait
#view_impl
})
}
fn add_supertrait(&mut self) {
let Self { DbViewTrait, .. } = self;
self.input.supertraits.push(parse_quote! { #DbViewTrait })
}
fn view_trait(&self) -> syn::ItemTrait {
let Self {
DbViewTrait,
db_view_method,
..
} = self;
let vis = &self.input.vis;
parse_quote! {
/// Internal salsa method generated by the `salsa::db_view` macro
/// that registers this database view trait with the salsa database.
///
/// Nothing to see here.
#[doc(hidden)]
#vis trait #DbViewTrait {
fn #db_view_method(&self);
}
};
}
}
pub struct Args {}
impl syn::parse::Parse for Args {
fn parse(_input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
Ok(Self {})
}
}
fn view_impl(&self) -> syn::Item {
let Self {
DbViewTrait,
db_view_method,
..
} = self;
let DB = self.hygiene.ident("DB");
let Database = self.hygiene.ident("Database");
let views = self.hygiene.ident("views");
let UserTrait = &self.input.ident;
parse_quote! {
const _: () = {
use salsa::Database as #Database;
#[doc(hidden)]
impl<#DB: #Database> #DbViewTrait for #DB {
/// Internal salsa method generated by the `salsa::db_view` macro
/// that registers this database view trait with the salsa database.
///
/// Nothing to see here.
fn #db_view_method(&self) {
let #views = self.views_of_self();
#views.add::<dyn #UserTrait>(|t| t, |t| t);
}
}
};
}
}
}

View file

@ -0,0 +1,41 @@
use std::collections::HashSet;
pub struct Hygiene {
user_tokens: HashSet<String>,
}
impl From<&proc_macro::TokenStream> for Hygiene {
fn from(input: &proc_macro::TokenStream) -> Self {
let mut user_tokens = HashSet::new();
push_idents(input.clone(), &mut user_tokens);
Self { user_tokens }
}
}
fn push_idents(input: proc_macro::TokenStream, user_tokens: &mut HashSet<String>) {
input.into_iter().for_each(|token| match token {
proc_macro::TokenTree::Group(g) => {
push_idents(g.stream(), user_tokens);
}
proc_macro::TokenTree::Ident(ident) => {
user_tokens.insert(ident.to_string());
}
proc_macro::TokenTree::Punct(_) => (),
proc_macro::TokenTree::Literal(_) => (),
})
}
impl Hygiene {
/// Generates an identifier similar to `text` but
/// distinct from any identifiers that appear in the user's
/// code.
pub(crate) fn ident(&self, text: &str) -> syn::Ident {
let mut buffer = String::from(text);
while self.user_tokens.contains(&buffer) {
buffer.push('_');
}
syn::Ident::new(&buffer, proc_macro2::Span::call_site())
}
}

View file

@ -7,6 +7,8 @@ extern crate proc_macro2;
#[macro_use]
extern crate quote;
mod hygiene;
use proc_macro::TokenStream;
macro_rules! parse_quote {

View file

@ -47,7 +47,7 @@ pub trait DatabaseView<Dyn: ?Sized + Any>: Database {
impl<Db: Database> DatabaseView<dyn Database> for Db {
fn add_view_to_db(&self) {
let upcasts = self.upcasts_for_self();
let upcasts = self.views_of_self();
upcasts.add::<dyn Database>(|t| t, |t| t);
}
}

View file

@ -3,7 +3,6 @@ mod alloc;
pub mod cancelled;
pub mod cycle;
pub mod database;
mod downcast;
pub mod durability;
pub mod event;
pub mod function;
@ -24,6 +23,7 @@ pub mod setter;
pub mod storage;
pub mod tracked_struct;
pub mod update;
mod views;
pub use self::cancelled::Cancelled;
pub use self::cycle::Cycle;
@ -41,6 +41,7 @@ pub use self::runtime::Runtime;
pub use self::storage::Storage;
pub use salsa_macros::accumulator;
pub use salsa_macros::db;
pub use salsa_macros::db_view;
pub use salsa_macros::input;
pub use salsa_macros::interned;
pub use salsa_macros::jar;

View file

@ -6,41 +6,52 @@ use parking_lot::{Condvar, Mutex};
use rustc_hash::FxHashMap;
use crate::cycle::CycleRecoveryStrategy;
use crate::downcast::{DynDowncasts, DynDowncastsFor};
use crate::ingredient::{Ingredient, Jar};
use crate::nonce::{Nonce, NonceGenerator};
use crate::runtime::Runtime;
use crate::views::{Views, ViewsOf};
use crate::Database;
use super::ParallelDatabase;
/// Salsa database methods that are generated by the `#[salsa::database]` procedural macro.
/// Salsa database methods whose implementation is generated by
/// the `#[salsa::database]` procedural macro.
///
/// # Safety
///
/// This trait is meant to be implemented by our procedural macro.
/// We need to document any non-obvious conditions that it satisfies.
pub unsafe trait DatabaseGen: Any + Send + Sync {
/// Upcast to a `dyn Database`.
///
/// Only required because upcasts not yet stabilized (*grr*).
///
/// # Safety
///
/// Returns the same data pointer as `self`.
fn as_salsa_database(&self) -> &dyn Database;
/// Returns a reference to the underlying "dyn-upcasts"
fn upcasts(&self) -> &DynDowncasts;
/// Upcast to a `dyn DatabaseGen`.
///
/// Only required because upcasts not yet stabilized (*grr*).
///
/// # Ensures
/// # Safety
///
/// Returns the same data pointer as `self`.
fn upcast_to_dyn_database_gen(&self) -> &dyn DatabaseGen;
fn as_salsa_database_gen(&self) -> &dyn DatabaseGen;
/// Returns a reference to the underlying.
fn views(&self) -> &Views;
/// Returns the upcasts database, tied to the type of `Self`; cannot be used from `dyn DatabaseGen` objects.
fn upcasts_for_self(&self) -> &DynDowncastsFor<Self>
fn views_of_self(&self) -> &ViewsOf<Self>
where
Self: Sized + Database;
/// Returns the nonce for the underyling storage.
///
/// # Safety
///
/// This nonce is guaranteed to be unique for the database and never to be reused.
fn nonce(&self) -> Nonce<StorageNonce>;
@ -61,6 +72,65 @@ pub unsafe trait DatabaseGen: Any + Send + Sync {
fn runtime_mut(&mut self) -> &mut Runtime;
}
/// This is the *actual* trait that the macro generates.
/// It simply gives access to the internal storage.
/// Note that it is NOT a supertrait of `Database`
/// because it is not `dyn`-safe.
///
/// # Safety
///
/// The `storage` field must be an owned field of
/// the implementing struct.
pub unsafe trait HasStorage: Database + Sized + Any + Send + Sync {
fn storage(&self) -> &Storage<Self>;
fn storage_mut(&self) -> &mut Storage<Self>;
}
unsafe impl<T: HasStorage> DatabaseGen for T {
fn as_salsa_database(&self) -> &dyn Database {
self
}
fn as_salsa_database_gen(&self) -> &dyn DatabaseGen {
self
}
fn views(&self) -> &Views {
&self.storage().shared.upcasts
}
fn views_of_self(&self) -> &ViewsOf<Self>
where
Self: Sized + Database,
{
&self.storage().shared.upcasts
}
fn nonce(&self) -> Nonce<StorageNonce> {
self.storage().shared.nonce
}
fn lookup_jar_by_type(&self, jar: &dyn Jar) -> Option<IngredientIndex> {
self.storage().lookup_jar_by_type(jar)
}
fn add_or_lookup_jar_by_type(&self, jar: &dyn Jar) -> IngredientIndex {
self.storage().add_or_lookup_jar_by_type(jar)
}
fn lookup_ingredient(&self, index: IngredientIndex) -> &dyn Ingredient {
self.storage().lookup_ingredient(index)
}
fn runtime(&self) -> &Runtime {
&self.storage().runtime
}
fn runtime_mut(&mut self) -> &mut Runtime {
&mut self.storage_mut().runtime
}
}
impl dyn Database {
/// Upcasts `self` to the given view.
///
@ -68,7 +138,7 @@ impl dyn Database {
///
/// If the view has not been added to the database (see [`DatabaseView`][])
pub fn as_view<DbView: ?Sized + Database>(&self) -> &DbView {
self.upcasts().try_cast(self).unwrap()
self.views().try_view_as(self).unwrap()
}
/// Upcasts `self` to the given view.
@ -78,8 +148,8 @@ impl dyn Database {
/// If the view has not been added to the database (see [`DatabaseView`][])
pub fn as_view_mut<DbView: ?Sized + Database>(&mut self) -> &mut DbView {
// Avoid a borrow check error by cloning. This is the "uncommon" path so it seems fine.
let upcasts = self.upcasts().clone();
upcasts.try_cast_mut(self).unwrap()
let upcasts = self.views().clone();
upcasts.try_view_as_mut(self).unwrap()
}
}
@ -138,7 +208,7 @@ pub struct Storage<Db: Database> {
/// This is where the actual data for tracked functions, structs, inputs, etc lives,
/// along with some coordination variables between treads.
struct Shared<Db: Database> {
upcasts: DynDowncastsFor<Db>,
upcasts: ViewsOf<Db>,
nonce: Nonce<StorageNonce>,
@ -198,7 +268,7 @@ impl<Db: Database> Storage<Db> {
/// Adds the ingredients in `jar` to the database if not already present.
/// If a jar of this type is already present, returns the index.
fn add_or_lookup_adapted_jar_by_type(&self, jar: &dyn Jar) -> IngredientIndex {
fn add_or_lookup_jar_by_type(&self, jar: &dyn Jar) -> IngredientIndex {
let jar_type_id = jar.type_id();
let mut jar_map = self.shared.jar_map.lock();
*jar_map
@ -225,8 +295,8 @@ impl<Db: Database> Storage<Db> {
}
/// Return the index of the 1st ingredient from the given jar.
pub fn lookup_jar_by_type(&self, jar_type_id: TypeId) -> Option<IngredientIndex> {
self.shared.jar_map.lock().get(&jar_type_id).copied()
pub fn lookup_jar_by_type(&self, jar: &dyn Jar) -> Option<IngredientIndex> {
self.shared.jar_map.lock().get(&jar.type_id()).copied()
}
pub fn lookup_ingredient(&self, index: IngredientIndex) -> &dyn Ingredient {

View file

@ -9,18 +9,18 @@ use orx_concurrent_vec::ConcurrentVec;
use crate::Database;
pub struct DynDowncastsFor<Db: Database> {
upcasts: DynDowncasts,
pub struct ViewsOf<Db: Database> {
upcasts: Views,
phantom: PhantomData<Db>,
}
#[derive(Clone)]
pub struct DynDowncasts {
pub struct Views {
source_type_id: TypeId,
vec: Arc<ConcurrentVec<Caster>>,
view_casters: Arc<ConcurrentVec<ViewCaster>>,
}
struct Caster {
struct ViewCaster {
target_type_id: TypeId,
type_name: &'static str,
func: fn(&Dummy) -> &Dummy,
@ -30,16 +30,16 @@ struct Caster {
#[allow(dead_code)]
enum Dummy {}
impl<Db: Database> Default for DynDowncastsFor<Db> {
impl<Db: Database> Default for ViewsOf<Db> {
fn default() -> Self {
Self {
upcasts: DynDowncasts::new::<Db>(),
upcasts: Views::new::<Db>(),
phantom: Default::default(),
}
}
}
impl<Db: Database> DynDowncastsFor<Db> {
impl<Db: Database> ViewsOf<Db> {
/// Add a new upcast from `Db` to `T`, given the upcasting function `func`.
pub fn add<DbView: ?Sized + Any>(
&self,
@ -50,20 +50,20 @@ impl<Db: Database> DynDowncastsFor<Db> {
}
}
impl<Db: Database> Deref for DynDowncastsFor<Db> {
type Target = DynDowncasts;
impl<Db: Database> Deref for ViewsOf<Db> {
type Target = Views;
fn deref(&self) -> &Self::Target {
&self.upcasts
}
}
impl DynDowncasts {
impl Views {
fn new<Db: Database>() -> Self {
let source_type_id = TypeId::of::<Db>();
Self {
source_type_id,
vec: Default::default(),
view_casters: Default::default(),
}
}
@ -77,11 +77,15 @@ impl DynDowncasts {
let target_type_id = TypeId::of::<DbView>();
if self.vec.iter().any(|u| u.target_type_id == target_type_id) {
if self
.view_casters
.iter()
.any(|u| u.target_type_id == target_type_id)
{
return;
}
self.vec.push(Caster {
self.view_casters.push(ViewCaster {
target_type_id,
type_name: std::any::type_name::<DbView>(),
func: unsafe { std::mem::transmute(func) },
@ -94,7 +98,7 @@ impl DynDowncasts {
/// # Panics
///
/// If the underlying type of `db` is not the same as the database type this upcasts was created for.
pub fn try_cast<'db, DbView: ?Sized + Any>(
pub fn try_view_as<'db, DbView: ?Sized + Any>(
&self,
db: &'db dyn Database,
) -> Option<&'db DbView> {
@ -102,7 +106,7 @@ impl DynDowncasts {
assert_eq!(self.source_type_id, db_type_id, "database type mismatch");
let view_type_id = TypeId::of::<DbView>();
for caster in self.vec.iter() {
for caster in self.view_casters.iter() {
if caster.target_type_id == view_type_id {
// SAFETY: We have some function that takes a thin reference to the underlying
// database type `X` and returns a (potentially wide) reference to `View`.
@ -123,7 +127,7 @@ impl DynDowncasts {
/// # Panics
///
/// If the underlying type of `db` is not the same as the database type this upcasts was created for.
pub fn try_cast_mut<'db, View: ?Sized + Any>(
pub fn try_view_as_mut<'db, View: ?Sized + Any>(
&self,
db: &'db mut dyn Database,
) -> Option<&'db mut View> {
@ -131,7 +135,7 @@ impl DynDowncasts {
assert_eq!(self.source_type_id, db_type_id, "database type mismatch");
let view_type_id = TypeId::of::<View>();
for caster in self.vec.iter() {
for caster in self.view_casters.iter() {
if caster.target_type_id == view_type_id {
// SAFETY: We have some function that takes a thin reference to the underlying
// database type `X` and returns a (potentially wide) reference to `View`.
@ -149,15 +153,15 @@ impl DynDowncasts {
}
}
impl std::fmt::Debug for DynDowncasts {
impl std::fmt::Debug for Views {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("DynDowncasts")
.field("vec", &self.vec)
.field("vec", &self.view_casters)
.finish()
}
}
impl std::fmt::Debug for Caster {
impl std::fmt::Debug for ViewCaster {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("DynDowncast").field(&self.type_name).finish()
}
@ -179,7 +183,7 @@ fn data_ptr_mut<T: ?Sized>(t: &mut T) -> &mut () {
unsafe { &mut *u }
}
impl<Db: Database> Clone for DynDowncastsFor<Db> {
impl<Db: Database> Clone for ViewsOf<Db> {
fn clone(&self) -> Self {
Self {
upcasts: self.upcasts.clone(),

View file

@ -8,7 +8,7 @@ use expect_test::expect;
use test_log::test;
#[salsa::db_view]
trait Db: DatabaseView<dyn Db> + HasLogger {}
trait Db: HasLogger {}
#[salsa::input]
struct MyInput {