support customizing the DebugWithDb impl

This commit is contained in:
Niko Matsakis 2024-04-03 06:23:43 -04:00
parent 389aa66bcf
commit fd15c3a600
2 changed files with 86 additions and 7 deletions

View file

@ -32,9 +32,15 @@ use syn::spanned::Spanned;
pub(crate) struct SalsaStruct<A: AllowedOptions> {
args: Options<A>,
struct_item: syn::ItemStruct,
customizations: Vec<Customization>,
fields: Vec<SalsaField>,
}
#[derive(PartialEq, Eq, Debug, Copy, Clone)]
pub enum Customization {
DebugWithDb,
}
const BANNED_FIELD_NAMES: &[&str] = &["from", "new"];
impl<A: AllowedOptions> SalsaStruct<A> {
@ -51,18 +57,42 @@ impl<A: AllowedOptions> SalsaStruct<A> {
struct_item: syn::ItemStruct,
) -> syn::Result<Self> {
let args: Options<A> = syn::parse(args)?;
let fields = Self::extract_options(&struct_item)?;
let customizations = Self::extract_customizations(&struct_item)?;
let fields = Self::extract_fields(&struct_item)?;
Ok(Self {
args,
struct_item,
customizations,
fields,
})
}
fn extract_customizations(struct_item: &syn::ItemStruct) -> syn::Result<Vec<Customization>> {
Ok(struct_item
.attrs
.iter()
.map(|attr| {
if attr.path.is_ident("customize") {
let args: syn::Ident = attr.parse_args()?;
if args.to_string() == "DebugWithDb" {
Ok(vec![Customization::DebugWithDb])
} else {
Err(syn::Error::new_spanned(args, "unrecognized customization"))
}
} else {
Ok(vec![])
}
})
.collect::<Result<Vec<Vec<_>>, _>>()?
.into_iter()
.flatten()
.collect())
}
/// Extract out the fields and their options:
/// If this is a struct, it must use named fields, so we can define field accessors.
/// If it is an enum, then this is not necessary.
pub(crate) fn extract_options(struct_item: &syn::ItemStruct) -> syn::Result<Vec<SalsaField>> {
fn extract_fields(struct_item: &syn::ItemStruct) -> syn::Result<Vec<SalsaField>> {
match &struct_item.fields {
syn::Fields::Named(n) => Ok(n
.named
@ -146,12 +176,14 @@ impl<A: AllowedOptions> SalsaStruct<A> {
let ident = self.id_ident();
let visibility = &self.struct_item.vis;
// Extract the attributes the user gave, but screen out derive, since we are adding our own.
// Extract the attributes the user gave, but screen out derive, since we are adding our own,
// and the customize attribute that we use for our own purposes.
let attrs: Vec<_> = self
.struct_item
.attrs
.iter()
.filter(|attr| !attr.path.is_ident("derive"))
.filter(|attr| !attr.path.is_ident("customize"))
.collect();
parse_quote! {
@ -214,7 +246,11 @@ impl<A: AllowedOptions> SalsaStruct<A> {
}
/// Generate `impl salsa::DebugWithDb for Foo`
pub(crate) fn as_debug_with_db_impl(&self) -> syn::ItemImpl {
pub(crate) fn as_debug_with_db_impl(&self) -> Option<syn::ItemImpl> {
if self.customizations.contains(&Customization::DebugWithDb) {
return None;
}
let ident = self.id_ident();
let db_type = self.db_dyn_ty();
@ -242,7 +278,7 @@ impl<A: AllowedOptions> SalsaStruct<A> {
.collect::<TokenStream>();
// `use ::salsa::debug::helper::Fallback` is needed for the fallback to `Debug` impl
parse_quote_spanned! {ident.span()=>
Some(parse_quote_spanned! {ident.span()=>
impl ::salsa::DebugWithDb<#db_type> for #ident {
fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>, _db: &#db_type) -> ::std::fmt::Result {
#[allow(unused_imports)]
@ -253,7 +289,7 @@ impl<A: AllowedOptions> SalsaStruct<A> {
debug_struct.finish()
}
}
}
})
}
/// Disallow `#[id]` attributes on the fields of this struct.

View file

@ -4,7 +4,13 @@ use expect_test::expect;
use salsa::DebugWithDb;
#[salsa::jar(db = Db)]
struct Jar(MyInput, ComplexStruct, leak_debug_string);
struct Jar(
MyInput,
ComplexStruct,
leak_debug_string,
DerivedCustom,
leak_derived_custom,
);
trait Db: salsa::DbWithJar<Jar> {}
@ -79,3 +85,40 @@ fn untracked_dependencies() {
let s = leak_debug_string(&db, input);
assert!(s.contains(", field: 22 }"));
}
#[salsa::tracked]
#[customize(DebugWithDb)]
struct DerivedCustom {
my_input: MyInput,
value: u32,
}
impl DebugWithDb<dyn Db + '_> for DerivedCustom {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &dyn Db) -> std::fmt::Result {
write!(
f,
"{:?} / {:?}",
self.my_input(db).debug(db),
self.value(db)
)
}
}
#[salsa::tracked]
fn leak_derived_custom(db: &dyn Db, input: MyInput, value: u32) -> String {
let c = DerivedCustom::new(db, input, value);
format!("{:?}", c.debug(db))
}
#[test]
fn custom_debug_impl() {
let db = Database::default();
let input = MyInput::new(&db, 22);
let s = leak_derived_custom(&db, input, 23);
expect![[r#"
"MyInput { [salsa id]: 0, field: 22 } / 23"
"#]]
.assert_debug_eq(&s);
}