1 use crate::bound::{has_bound, InferredBound, Supertraits};
2 use crate::lifetime::{AddLifetimeToImplTrait, CollectLifetimes};
3 use crate::parse::Item;
4 use crate::receiver::{has_self_in_block, has_self_in_sig, mut_pat, ReplaceSelf};
5 use crate::verbatim::VerbatimFn;
6 use proc_macro2::{Span, TokenStream};
7 use quote::{format_ident, quote, quote_spanned, ToTokens};
8 use std::collections::BTreeSet as Set;
9 use std::mem;
10 use syn::punctuated::Punctuated;
11 use syn::visit_mut::{self, VisitMut};
12 use syn::{
13     parse_quote, parse_quote_spanned, Attribute, Block, FnArg, GenericArgument, GenericParam,
14     Generics, Ident, ImplItem, Lifetime, LifetimeParam, Pat, PatIdent, PathArguments, Receiver,
15     ReturnType, Signature, Token, TraitItem, Type, TypePath, WhereClause,
16 };
17 
18 impl ToTokens for Item {
to_tokens(&self, tokens: &mut TokenStream)19     fn to_tokens(&self, tokens: &mut TokenStream) {
20         match self {
21             Item::Trait(item) => item.to_tokens(tokens),
22             Item::Impl(item) => item.to_tokens(tokens),
23         }
24     }
25 }
26 
/// Borrowed view of the item (trait definition or impl block) that encloses
/// the async method currently being rewritten. It is `Copy`, so `expand`
/// passes it by value to `transform_block` / `transform_sig` for every
/// method of the item.
#[derive(Clone, Copy)]
enum Context<'a> {
    /// Rewriting a method inside a trait definition.
    Trait {
        /// Generics declared on the trait itself; `lifetimes` looks up the
        /// trait's lifetime parameters here.
        generics: &'a Generics,
        /// The trait's supertraits; `transform_sig` passes them to `has_bound`
        /// to decide whether an inferred `Send`/`Sync` bound on `Self` must be
        /// spelled out.
        supertraits: &'a Supertraits,
    },
    /// Rewriting a method inside an impl block.
    Impl {
        /// Generics declared on the impl block; `lifetimes` looks up the
        /// impl's lifetime parameters here.
        impl_generics: &'a Generics,
        /// Names of the impl's associated types whose type is `impl Trait`
        /// (collected in `expand`). A return type mentioning `Self::Name` for
        /// one of these makes `transform_block` skip the typed `__ret` binding.
        associated_type_impl_traits: &'a Set<Ident>,
    },
}
38 
39 impl Context<'_> {
lifetimes<'a>(&'a self, used: &'a [Lifetime]) -> impl Iterator<Item = &'a LifetimeParam>40     fn lifetimes<'a>(&'a self, used: &'a [Lifetime]) -> impl Iterator<Item = &'a LifetimeParam> {
41         let generics = match self {
42             Context::Trait { generics, .. } => generics,
43             Context::Impl { impl_generics, .. } => impl_generics,
44         };
45         generics.params.iter().filter_map(move |param| {
46             if let GenericParam::Lifetime(param) = param {
47                 if used.contains(&param.lifetime) {
48                     return Some(param);
49                 }
50             }
51             None
52         })
53     }
54 }
55 
expand(input: &mut Item, is_local: bool)56 pub fn expand(input: &mut Item, is_local: bool) {
57     match input {
58         Item::Trait(input) => {
59             let context = Context::Trait {
60                 generics: &input.generics,
61                 supertraits: &input.supertraits,
62             };
63             for inner in &mut input.items {
64                 if let TraitItem::Fn(method) = inner {
65                     let sig = &mut method.sig;
66                     if sig.asyncness.is_some() {
67                         let block = &mut method.default;
68                         let mut has_self = has_self_in_sig(sig);
69                         method.attrs.push(parse_quote!(#[must_use]));
70                         if let Some(block) = block {
71                             has_self |= has_self_in_block(block);
72                             transform_block(context, sig, block);
73                             method.attrs.push(lint_suppress_with_body());
74                         } else {
75                             method.attrs.push(lint_suppress_without_body());
76                         }
77                         let has_default = method.default.is_some();
78                         transform_sig(context, sig, has_self, has_default, is_local);
79                     }
80                 }
81             }
82         }
83         Item::Impl(input) => {
84             let mut associated_type_impl_traits = Set::new();
85             for inner in &input.items {
86                 if let ImplItem::Type(assoc) = inner {
87                     if let Type::ImplTrait(_) = assoc.ty {
88                         associated_type_impl_traits.insert(assoc.ident.clone());
89                     }
90                 }
91             }
92 
93             let context = Context::Impl {
94                 impl_generics: &input.generics,
95                 associated_type_impl_traits: &associated_type_impl_traits,
96             };
97             for inner in &mut input.items {
98                 match inner {
99                     ImplItem::Fn(method) if method.sig.asyncness.is_some() => {
100                         let sig = &mut method.sig;
101                         let block = &mut method.block;
102                         let has_self = has_self_in_sig(sig) || has_self_in_block(block);
103                         transform_block(context, sig, block);
104                         transform_sig(context, sig, has_self, false, is_local);
105                         method.attrs.push(lint_suppress_with_body());
106                     }
107                     ImplItem::Verbatim(tokens) => {
108                         let mut method = match syn::parse2::<VerbatimFn>(tokens.clone()) {
109                             Ok(method) if method.sig.asyncness.is_some() => method,
110                             _ => continue,
111                         };
112                         let sig = &mut method.sig;
113                         let has_self = has_self_in_sig(sig);
114                         transform_sig(context, sig, has_self, false, is_local);
115                         method.attrs.push(lint_suppress_with_body());
116                         *tokens = quote!(#method);
117                     }
118                     _ => {}
119                 }
120             }
121         }
122     }
123 }
124 
lint_suppress_with_body() -> Attribute125 fn lint_suppress_with_body() -> Attribute {
126     parse_quote! {
127         #[allow(
128             elided_named_lifetimes,
129             clippy::async_yields_async,
130             clippy::diverging_sub_expression,
131             clippy::let_unit_value,
132             clippy::needless_arbitrary_self_type,
133             clippy::no_effect_underscore_binding,
134             clippy::shadow_same,
135             clippy::type_complexity,
136             clippy::type_repetition_in_bounds,
137             clippy::used_underscore_binding
138         )]
139     }
140 }
141 
lint_suppress_without_body() -> Attribute142 fn lint_suppress_without_body() -> Attribute {
143     parse_quote! {
144         #[allow(
145             elided_named_lifetimes,
146             clippy::type_complexity,
147             clippy::type_repetition_in_bounds
148         )]
149     }
150 }
151 
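// Input:
//     async fn f<T>(&self, x: &T) -> Ret;
//
// Output:
//     fn f<'life0, 'life1, 'async_trait, T>(
//         &'life0 self,
//         x: &'life1 T,
//     ) -> Pin<Box<dyn Future<Output = Ret> + Send + 'async_trait>>
//     where
//         T: 'async_trait,
//         'life0: 'async_trait,
//         'life1: 'async_trait,
//         Self: 'async_trait;
//
// The bound on `Self` can also gain `Sync` (for `&self`), `Sync + Send` (for
// `self: Arc<Self>`), or `Send` (otherwise). This happens only for trait
// methods with a default body whose supertraits do not already provide that
// bound. With `is_local`, neither `Self` nor the returned future gets
// `Send`/`Sync` bounds.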
transform_sig( context: Context, sig: &mut Signature, has_self: bool, has_default: bool, is_local: bool, )165 fn transform_sig(
166     context: Context,
167     sig: &mut Signature,
168     has_self: bool,
169     has_default: bool,
170     is_local: bool,
171 ) {
172     sig.fn_token.span = sig.asyncness.take().unwrap().span;
173 
174     let (ret_arrow, ret) = match &sig.output {
175         ReturnType::Default => (quote!(->), quote!(())),
176         ReturnType::Type(arrow, ret) => (quote!(#arrow), quote!(#ret)),
177     };
178 
179     let mut lifetimes = CollectLifetimes::new();
180     for arg in &mut sig.inputs {
181         match arg {
182             FnArg::Receiver(arg) => lifetimes.visit_receiver_mut(arg),
183             FnArg::Typed(arg) => lifetimes.visit_type_mut(&mut arg.ty),
184         }
185     }
186 
187     for param in &mut sig.generics.params {
188         match param {
189             GenericParam::Type(param) => {
190                 let param_name = &param.ident;
191                 let span = match param.colon_token.take() {
192                     Some(colon_token) => colon_token.span,
193                     None => param_name.span(),
194                 };
195                 let bounds = mem::replace(&mut param.bounds, Punctuated::new());
196                 where_clause_or_default(&mut sig.generics.where_clause)
197                     .predicates
198                     .push(parse_quote_spanned!(span=> #param_name: 'async_trait + #bounds));
199             }
200             GenericParam::Lifetime(param) => {
201                 let param_name = &param.lifetime;
202                 let span = match param.colon_token.take() {
203                     Some(colon_token) => colon_token.span,
204                     None => param_name.span(),
205                 };
206                 let bounds = mem::replace(&mut param.bounds, Punctuated::new());
207                 where_clause_or_default(&mut sig.generics.where_clause)
208                     .predicates
209                     .push(parse_quote_spanned!(span=> #param: 'async_trait + #bounds));
210             }
211             GenericParam::Const(_) => {}
212         }
213     }
214 
215     for param in context.lifetimes(&lifetimes.explicit) {
216         let param = &param.lifetime;
217         let span = param.span();
218         where_clause_or_default(&mut sig.generics.where_clause)
219             .predicates
220             .push(parse_quote_spanned!(span=> #param: 'async_trait));
221     }
222 
223     if sig.generics.lt_token.is_none() {
224         sig.generics.lt_token = Some(Token![<](sig.ident.span()));
225     }
226     if sig.generics.gt_token.is_none() {
227         sig.generics.gt_token = Some(Token![>](sig.paren_token.span.join()));
228     }
229 
230     for elided in lifetimes.elided {
231         sig.generics.params.push(parse_quote!(#elided));
232         where_clause_or_default(&mut sig.generics.where_clause)
233             .predicates
234             .push(parse_quote_spanned!(elided.span()=> #elided: 'async_trait));
235     }
236 
237     sig.generics.params.push(parse_quote!('async_trait));
238 
239     if has_self {
240         let bounds: &[InferredBound] = if is_local {
241             &[]
242         } else if let Some(receiver) = sig.receiver() {
243             match receiver.ty.as_ref() {
244                 // self: &Self
245                 Type::Reference(ty) if ty.mutability.is_none() => &[InferredBound::Sync],
246                 // self: Arc<Self>
247                 Type::Path(ty)
248                     if {
249                         let segment = ty.path.segments.last().unwrap();
250                         segment.ident == "Arc"
251                             && match &segment.arguments {
252                                 PathArguments::AngleBracketed(arguments) => {
253                                     arguments.args.len() == 1
254                                         && match &arguments.args[0] {
255                                             GenericArgument::Type(Type::Path(arg)) => {
256                                                 arg.path.is_ident("Self")
257                                             }
258                                             _ => false,
259                                         }
260                                 }
261                                 _ => false,
262                             }
263                     } =>
264                 {
265                     &[InferredBound::Sync, InferredBound::Send]
266                 }
267                 _ => &[InferredBound::Send],
268             }
269         } else {
270             &[InferredBound::Send]
271         };
272 
273         let bounds = bounds.iter().filter(|bound| match context {
274             Context::Trait { supertraits, .. } => has_default && !has_bound(supertraits, bound),
275             Context::Impl { .. } => false,
276         });
277 
278         where_clause_or_default(&mut sig.generics.where_clause)
279             .predicates
280             .push(parse_quote! {
281                 Self: #(#bounds +)* 'async_trait
282             });
283     }
284 
285     for (i, arg) in sig.inputs.iter_mut().enumerate() {
286         match arg {
287             FnArg::Receiver(receiver) => {
288                 if receiver.reference.is_none() {
289                     receiver.mutability = None;
290                 }
291             }
292             FnArg::Typed(arg) => {
293                 if match *arg.ty {
294                     Type::Reference(_) => false,
295                     _ => true,
296                 } {
297                     if let Pat::Ident(pat) = &mut *arg.pat {
298                         pat.by_ref = None;
299                         pat.mutability = None;
300                     } else {
301                         let positional = positional_arg(i, &arg.pat);
302                         let m = mut_pat(&mut arg.pat);
303                         arg.pat = parse_quote!(#m #positional);
304                     }
305                 }
306                 AddLifetimeToImplTrait.visit_type_mut(&mut arg.ty);
307             }
308         }
309     }
310 
311     let bounds = if is_local {
312         quote!('async_trait)
313     } else {
314         quote!(::core::marker::Send + 'async_trait)
315     };
316     sig.output = parse_quote! {
317         #ret_arrow ::core::pin::Pin<Box<
318             dyn ::core::future::Future<Output = #ret> + #bounds
319         >>
320     };
321 }
322 
// Input:
//     async fn f<T>(&self, x: &T, (a, b): (A, B)) -> Ret {
//         self + x + a + b
//     }
//
// Output:
//     Box::pin(async move {
//         if let Some(__ret) = None::<Ret> {
//             return __ret;
//         }
//         let __self = self;
//         let (a, b) = {
//             let __arg2 = __arg2;
//             __arg2
//         };
//         let __ret: Ret = {
//             __self + x + a + b
//         };
//
//         __ret
//     })
//
// Reference-typed arguments such as `x` are not re-bound. Positional names
// such as `__arg2` are indexed over all inputs, including the receiver.
transform_block(context: Context, sig: &mut Signature, block: &mut Block)340 fn transform_block(context: Context, sig: &mut Signature, block: &mut Block) {
341     let mut replace_self = false;
342     let decls = sig
343         .inputs
344         .iter()
345         .enumerate()
346         .map(|(i, arg)| match arg {
347             FnArg::Receiver(Receiver {
348                 self_token,
349                 mutability,
350                 ..
351             }) => {
352                 replace_self = true;
353                 let ident = Ident::new("__self", self_token.span);
354                 quote!(let #mutability #ident = #self_token;)
355             }
356             FnArg::Typed(arg) => {
357                 // If there is a #[cfg(...)] attribute that selectively enables
358                 // the parameter, forward it to the variable.
359                 //
360                 // This is currently not applied to the `self` parameter.
361                 let attrs = arg.attrs.iter().filter(|attr| attr.path().is_ident("cfg"));
362 
363                 if let Type::Reference(_) = *arg.ty {
364                     quote!()
365                 } else if let Pat::Ident(PatIdent {
366                     ident, mutability, ..
367                 }) = &*arg.pat
368                 {
369                     quote! {
370                         #(#attrs)*
371                         let #mutability #ident = #ident;
372                     }
373                 } else {
374                     let pat = &arg.pat;
375                     let ident = positional_arg(i, pat);
376                     if let Pat::Wild(_) = **pat {
377                         quote! {
378                             #(#attrs)*
379                             let #ident = #ident;
380                         }
381                     } else {
382                         quote! {
383                             #(#attrs)*
384                             let #pat = {
385                                 let #ident = #ident;
386                                 #ident
387                             };
388                         }
389                     }
390                 }
391             }
392         })
393         .collect::<Vec<_>>();
394 
395     if replace_self {
396         ReplaceSelf.visit_block_mut(block);
397     }
398 
399     let stmts = &block.stmts;
400     let let_ret = match &mut sig.output {
401         ReturnType::Default => quote_spanned! {block.brace_token.span=>
402             #(#decls)*
403             let () = { #(#stmts)* };
404         },
405         ReturnType::Type(_, ret) => {
406             if contains_associated_type_impl_trait(context, ret) {
407                 if decls.is_empty() {
408                     quote!(#(#stmts)*)
409                 } else {
410                     quote!(#(#decls)* { #(#stmts)* })
411                 }
412             } else {
413                 quote! {
414                     if let ::core::option::Option::Some(__ret) = ::core::option::Option::None::<#ret> {
415                         #[allow(unreachable_code)]
416                         return __ret;
417                     }
418                     #(#decls)*
419                     let __ret: #ret = { #(#stmts)* };
420                     #[allow(unreachable_code)]
421                     __ret
422                 }
423             }
424         }
425     };
426     let box_pin = quote_spanned!(block.brace_token.span=>
427         Box::pin(async move { #let_ret })
428     );
429     block.stmts = parse_quote!(#box_pin);
430 }
431 
positional_arg(i: usize, pat: &Pat) -> Ident432 fn positional_arg(i: usize, pat: &Pat) -> Ident {
433     let span = syn::spanned::Spanned::span(pat).resolved_at(Span::mixed_site());
434     format_ident!("__arg{}", i, span = span)
435 }
436 
contains_associated_type_impl_trait(context: Context, ret: &mut Type) -> bool437 fn contains_associated_type_impl_trait(context: Context, ret: &mut Type) -> bool {
438     struct AssociatedTypeImplTraits<'a> {
439         set: &'a Set<Ident>,
440         contains: bool,
441     }
442 
443     impl<'a> VisitMut for AssociatedTypeImplTraits<'a> {
444         fn visit_type_path_mut(&mut self, ty: &mut TypePath) {
445             if ty.qself.is_none()
446                 && ty.path.segments.len() == 2
447                 && ty.path.segments[0].ident == "Self"
448                 && self.set.contains(&ty.path.segments[1].ident)
449             {
450                 self.contains = true;
451             }
452             visit_mut::visit_type_path_mut(self, ty);
453         }
454     }
455 
456     match context {
457         Context::Trait { .. } => false,
458         Context::Impl {
459             associated_type_impl_traits,
460             ..
461         } => {
462             let mut visit = AssociatedTypeImplTraits {
463                 set: associated_type_impl_traits,
464                 contains: false,
465             };
466             visit.visit_type_mut(ret);
467             visit.contains
468         }
469     }
470 }
471 
where_clause_or_default(clause: &mut Option<WhereClause>) -> &mut WhereClause472 fn where_clause_or_default(clause: &mut Option<WhereClause>) -> &mut WhereClause {
473     clause.get_or_insert_with(|| WhereClause {
474         where_token: Default::default(),
475         predicates: Punctuated::new(),
476     })
477 }
478