1 use proc_macro2::TokenStream;
2 use quote::{format_ident, quote};
3 use syn::{
4     parenthesized,
5     parse::{Parse, ParseStream},
6     spanned::Spanned,
7 };
8 
/// Identifies the single field that a generated method call is forwarded to.
pub enum Forward {
    /// A positional (tuple) field, identified by its zero-based index.
    Unnamed(usize),
    /// A named field, identified by its identifier.
    Named(syn::Ident),
}
13 
14 impl Parse for Forward {
parse(input: ParseStream) -> syn::Result<Self>15     fn parse(input: ParseStream) -> syn::Result<Self> {
16         let forward = input.parse::<syn::Ident>()?;
17         if forward != "forward" {
18             return Err(syn::Error::new(forward.span(), "msg"));
19         }
20         let content;
21         parenthesized!(content in input);
22         let looky = content.lookahead1();
23         if looky.peek(syn::LitInt) {
24             let int: syn::LitInt = content.parse()?;
25             let index = int.base10_parse()?;
26             return Ok(Forward::Unnamed(index));
27         }
28         Ok(Forward::Named(content.parse()?))
29     }
30 }
31 
/// Selects which generated method is being emitted; the method name and
/// signature for each variant are produced by `method_call` and `signature`.
#[derive(Copy, Clone)]
pub enum WhichFn {
    /// The `code()` method.
    Code,
    /// The `help()` method.
    Help,
    /// The `url()` method.
    Url,
    /// The `severity()` method.
    Severity,
    /// The `labels()` method.
    Labels,
    /// The `source_code()` method.
    SourceCode,
    /// The `related()` method.
    Related,
    /// The `diagnostic_source()` method.
    DiagnosticSource,
}
43 
44 impl WhichFn {
method_call(&self) -> TokenStream45     pub fn method_call(&self) -> TokenStream {
46         match self {
47             Self::Code => quote! { code() },
48             Self::Help => quote! { help() },
49             Self::Url => quote! { url() },
50             Self::Severity => quote! { severity() },
51             Self::Labels => quote! { labels() },
52             Self::SourceCode => quote! { source_code() },
53             Self::Related => quote! { related() },
54             Self::DiagnosticSource => quote! { diagnostic_source() },
55         }
56     }
57 
signature(&self) -> TokenStream58     pub fn signature(&self) -> TokenStream {
59         match self {
60             Self::Code => quote! {
61                 fn code(& self) -> std::option::Option<std::boxed::Box<dyn std::fmt::Display + '_>>
62             },
63             Self::Help => quote! {
64                 fn help(& self) -> std::option::Option<std::boxed::Box<dyn std::fmt::Display + '_>>
65             },
66             Self::Url => quote! {
67                 fn url(& self) -> std::option::Option<std::boxed::Box<dyn std::fmt::Display + '_>>
68             },
69             Self::Severity => quote! {
70                 fn severity(&self) -> std::option::Option<miette::Severity>
71             },
72             Self::Related => quote! {
73                 fn related(&self) -> std::option::Option<std::boxed::Box<dyn std::iter::Iterator<Item = &dyn miette::Diagnostic> + '_>>
74             },
75             Self::Labels => quote! {
76                 fn labels(&self) -> std::option::Option<std::boxed::Box<dyn std::iter::Iterator<Item = miette::LabeledSpan> + '_>>
77             },
78             Self::SourceCode => quote! {
79                 fn source_code(&self) -> std::option::Option<&dyn miette::SourceCode>
80             },
81             Self::DiagnosticSource => quote! {
82                 fn diagnostic_source(&self) -> std::option::Option<&dyn miette::Diagnostic>
83             },
84         }
85     }
86 
catchall_arm(&self) -> TokenStream87     pub fn catchall_arm(&self) -> TokenStream {
88         quote! { _ => std::option::Option::None }
89     }
90 }
91 
92 impl Forward {
for_transparent_field(fields: &syn::Fields) -> syn::Result<Self>93     pub fn for_transparent_field(fields: &syn::Fields) -> syn::Result<Self> {
94         let make_err = || {
95             syn::Error::new(
96                 fields.span(),
97                 "you can only use #[diagnostic(transparent)] with exactly one field",
98             )
99         };
100         match fields {
101             syn::Fields::Named(named) => {
102                 let mut iter = named.named.iter();
103                 let field = iter.next().ok_or_else(make_err)?;
104                 if iter.next().is_some() {
105                     return Err(make_err());
106                 }
107                 let field_name = field
108                     .ident
109                     .clone()
110                     .unwrap_or_else(|| format_ident!("unnamed"));
111                 Ok(Self::Named(field_name))
112             }
113             syn::Fields::Unnamed(unnamed) => {
114                 if unnamed.unnamed.iter().len() != 1 {
115                     return Err(make_err());
116                 }
117                 Ok(Self::Unnamed(0))
118             }
119             _ => Err(syn::Error::new(
120                 fields.span(),
121                 "you cannot use #[diagnostic(transparent)] with a unit struct or a unit variant",
122             )),
123         }
124     }
125 
gen_struct_method(&self, which_fn: WhichFn) -> TokenStream126     pub fn gen_struct_method(&self, which_fn: WhichFn) -> TokenStream {
127         let signature = which_fn.signature();
128         let method_call = which_fn.method_call();
129 
130         let field_name = match self {
131             Forward::Named(field_name) => quote!(#field_name),
132             Forward::Unnamed(index) => {
133                 let index = syn::Index::from(*index);
134                 quote!(#index)
135             }
136         };
137 
138         quote! {
139             #[inline]
140             #signature {
141                 self.#field_name.#method_call
142             }
143         }
144     }
145 
gen_enum_match_arm(&self, variant: &syn::Ident, which_fn: WhichFn) -> TokenStream146     pub fn gen_enum_match_arm(&self, variant: &syn::Ident, which_fn: WhichFn) -> TokenStream {
147         let method_call = which_fn.method_call();
148         match self {
149             Forward::Named(field_name) => quote! {
150                 Self::#variant { #field_name, .. } => #field_name.#method_call,
151             },
152             Forward::Unnamed(index) => {
153                 let underscores: Vec<_> = core::iter::repeat(quote! { _, }).take(*index).collect();
154                 let unnamed = format_ident!("unnamed");
155                 quote! {
156                     Self::#variant ( #(#underscores)* #unnamed, .. ) => #unnamed.#method_call,
157                 }
158             }
159         }
160     }
161 }
162