1  // Copyright 2018 The ChromiumOS Authors
2  // Use of this source code is governed by a BSD-style license that can be
3  // found in the LICENSE file.
4  
5  //! Derives a 9P wire format encoding for a struct by recursively calling
6  //! `WireFormat::encode` or `WireFormat::decode` on the fields of the struct.
7  //! This is only intended to be used from within the `p9` crate.
8  
9  #![recursion_limit = "256"]
10  
11  extern crate proc_macro;
12  extern crate proc_macro2;
13  
14  #[macro_use]
15  extern crate quote;
16  
17  #[macro_use]
18  extern crate syn;
19  
20  use proc_macro2::Span;
21  use proc_macro2::TokenStream;
22  use syn::spanned::Spanned;
23  use syn::Data;
24  use syn::DeriveInput;
25  use syn::Fields;
26  use syn::Ident;
27  
28  /// The function that derives the actual implementation.
29  #[proc_macro_derive(P9WireFormat)]
p9_wire_format(input: proc_macro::TokenStream) -> proc_macro::TokenStream30  pub fn p9_wire_format(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
31      let input = parse_macro_input!(input as DeriveInput);
32      p9_wire_format_inner(input).into()
33  }
34  
p9_wire_format_inner(input: DeriveInput) -> TokenStream35  fn p9_wire_format_inner(input: DeriveInput) -> TokenStream {
36      if !input.generics.params.is_empty() {
37          return quote! {
38              compile_error!("derive(P9WireFormat) does not support generic parameters");
39          };
40      }
41  
42      let container = input.ident;
43  
44      let byte_size_impl = byte_size_sum(&input.data);
45      let encode_impl = encode_wire_format(&input.data);
46      let decode_impl = decode_wire_format(&input.data, &container);
47  
48      let scope = format!("wire_format_{}", container).to_lowercase();
49      let scope = Ident::new(&scope, Span::call_site());
50      quote! {
51          mod #scope {
52              extern crate std;
53              use self::std::io;
54              use self::std::result::Result::Ok;
55  
56              use super::#container;
57  
58              use protocol::WireFormat;
59  
60              impl WireFormat for #container {
61                  fn byte_size(&self) -> u32 {
62                      #byte_size_impl
63                  }
64  
65                  fn encode<W: io::Write>(&self, _writer: &mut W) -> io::Result<()> {
66                      #encode_impl
67                  }
68  
69                  fn decode<R: io::Read>(_reader: &mut R) -> io::Result<Self> {
70                      #decode_impl
71                  }
72              }
73          }
74      }
75  }
76  
77  // Generate code that recursively calls byte_size on every field in the struct.
byte_size_sum(data: &Data) -> TokenStream78  fn byte_size_sum(data: &Data) -> TokenStream {
79      if let Data::Struct(ref data) = *data {
80          if let Fields::Named(ref fields) = data.fields {
81              let fields = fields.named.iter().map(|f| {
82                  let field = &f.ident;
83                  let span = field.span();
84                  quote_spanned! {span=>
85                      WireFormat::byte_size(&self.#field)
86                  }
87              });
88  
89              quote! {
90                  0 #(+ #fields)*
91              }
92          } else {
93              unimplemented!();
94          }
95      } else {
96          unimplemented!();
97      }
98  }
99  
100  // Generate code that recursively calls encode on every field in the struct.
encode_wire_format(data: &Data) -> TokenStream101  fn encode_wire_format(data: &Data) -> TokenStream {
102      if let Data::Struct(ref data) = *data {
103          if let Fields::Named(ref fields) = data.fields {
104              let fields = fields.named.iter().map(|f| {
105                  let field = &f.ident;
106                  let span = field.span();
107                  quote_spanned! {span=>
108                      WireFormat::encode(&self.#field, _writer)?;
109                  }
110              });
111  
112              quote! {
113                  #(#fields)*
114  
115                  Ok(())
116              }
117          } else {
118              unimplemented!();
119          }
120      } else {
121          unimplemented!();
122      }
123  }
124  
125  // Generate code that recursively calls decode on every field in the struct.
decode_wire_format(data: &Data, container: &Ident) -> TokenStream126  fn decode_wire_format(data: &Data, container: &Ident) -> TokenStream {
127      if let Data::Struct(ref data) = *data {
128          if let Fields::Named(ref fields) = data.fields {
129              let values = fields.named.iter().map(|f| {
130                  let field = &f.ident;
131                  let span = field.span();
132                  quote_spanned! {span=>
133                      let #field = WireFormat::decode(_reader)?;
134                  }
135              });
136  
137              let members = fields.named.iter().map(|f| {
138                  let field = &f.ident;
139                  quote! {
140                      #field: #field,
141                  }
142              });
143  
144              quote! {
145                  #(#values)*
146  
147                  Ok(#container {
148                      #(#members)*
149                  })
150              }
151          } else {
152              unimplemented!();
153          }
154      } else {
155          unimplemented!();
156      }
157  }
158  
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample input shared by the per-method tests. It mixes several field
    /// types and includes a field name containing underscores.
    fn sample_item() -> DeriveInput {
        parse_quote! {
            struct Item {
                ident: u32,
                with_underscores: String,
                other: u8,
            }
        }
    }

    /// Compares generated and expected code by their rendered token text.
    fn assert_same_tokens(actual: TokenStream, expected: TokenStream) {
        assert_eq!(actual.to_string(), expected.to_string());
    }

    #[test]
    fn byte_size() {
        let item = sample_item();

        assert_same_tokens(
            byte_size_sum(&item.data),
            quote! {
                0
                    + WireFormat::byte_size(&self.ident)
                    + WireFormat::byte_size(&self.with_underscores)
                    + WireFormat::byte_size(&self.other)
            },
        );
    }

    #[test]
    fn encode() {
        let item = sample_item();

        assert_same_tokens(
            encode_wire_format(&item.data),
            quote! {
                WireFormat::encode(&self.ident, _writer)?;
                WireFormat::encode(&self.with_underscores, _writer)?;
                WireFormat::encode(&self.other, _writer)?;
                Ok(())
            },
        );
    }

    #[test]
    fn decode() {
        let item = sample_item();
        let container = Ident::new("Item", Span::call_site());

        assert_same_tokens(
            decode_wire_format(&item.data, &container),
            quote! {
                let ident = WireFormat::decode(_reader)?;
                let with_underscores = WireFormat::decode(_reader)?;
                let other = WireFormat::decode(_reader)?;
                Ok(Item {
                    ident: ident,
                    with_underscores: with_underscores,
                    other: other,
                })
            },
        );
    }

    #[test]
    fn end_to_end() {
        // A non-ASCII type name exercises the lowercased module-name generation.
        let input: DeriveInput = parse_quote! {
            struct Niijima_先輩 {
                a: u8,
                b: u16,
                c: u32,
                d: u64,
                e: String,
                f: Vec<String>,
                g: Nested,
            }
        };

        let expected = quote! {
            mod wire_format_niijima_先輩 {
                extern crate std;
                use self::std::io;
                use self::std::result::Result::Ok;

                use super::Niijima_先輩;

                use protocol::WireFormat;

                impl WireFormat for Niijima_先輩 {
                    fn byte_size(&self) -> u32 {
                        0
                        + WireFormat::byte_size(&self.a)
                        + WireFormat::byte_size(&self.b)
                        + WireFormat::byte_size(&self.c)
                        + WireFormat::byte_size(&self.d)
                        + WireFormat::byte_size(&self.e)
                        + WireFormat::byte_size(&self.f)
                        + WireFormat::byte_size(&self.g)
                    }

                    fn encode<W: io::Write>(&self, _writer: &mut W) -> io::Result<()> {
                        WireFormat::encode(&self.a, _writer)?;
                        WireFormat::encode(&self.b, _writer)?;
                        WireFormat::encode(&self.c, _writer)?;
                        WireFormat::encode(&self.d, _writer)?;
                        WireFormat::encode(&self.e, _writer)?;
                        WireFormat::encode(&self.f, _writer)?;
                        WireFormat::encode(&self.g, _writer)?;
                        Ok(())
                    }
                    fn decode<R: io::Read>(_reader: &mut R) -> io::Result<Self> {
                        let a = WireFormat::decode(_reader)?;
                        let b = WireFormat::decode(_reader)?;
                        let c = WireFormat::decode(_reader)?;
                        let d = WireFormat::decode(_reader)?;
                        let e = WireFormat::decode(_reader)?;
                        let f = WireFormat::decode(_reader)?;
                        let g = WireFormat::decode(_reader)?;
                        Ok(Niijima_先輩 {
                            a: a,
                            b: b,
                            c: c,
                            d: d,
                            e: e,
                            f: f,
                            g: g,
                        })
                    }
                }
            }
        };

        assert_same_tokens(p9_wire_format_inner(input), expected);
    }
}
308