1 // Copyright 2018 The ChromiumOS Authors 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 //! Derives a 9P wire format encoding for a struct by recursively calling 6 //! `WireFormat::encode` or `WireFormat::decode` on the fields of the struct. 7 //! This is only intended to be used from within the `p9` crate. 8 9 #![recursion_limit = "256"] 10 11 extern crate proc_macro; 12 extern crate proc_macro2; 13 14 #[macro_use] 15 extern crate quote; 16 17 #[macro_use] 18 extern crate syn; 19 20 use proc_macro2::Span; 21 use proc_macro2::TokenStream; 22 use syn::spanned::Spanned; 23 use syn::Data; 24 use syn::DeriveInput; 25 use syn::Fields; 26 use syn::Ident; 27 28 /// The function that derives the actual implementation. 29 #[proc_macro_derive(P9WireFormat)] p9_wire_format(input: proc_macro::TokenStream) -> proc_macro::TokenStream30 pub fn p9_wire_format(input: proc_macro::TokenStream) -> proc_macro::TokenStream { 31 let input = parse_macro_input!(input as DeriveInput); 32 p9_wire_format_inner(input).into() 33 } 34 p9_wire_format_inner(input: DeriveInput) -> TokenStream35 fn p9_wire_format_inner(input: DeriveInput) -> TokenStream { 36 if !input.generics.params.is_empty() { 37 return quote! { 38 compile_error!("derive(P9WireFormat) does not support generic parameters"); 39 }; 40 } 41 42 let container = input.ident; 43 44 let byte_size_impl = byte_size_sum(&input.data); 45 let encode_impl = encode_wire_format(&input.data); 46 let decode_impl = decode_wire_format(&input.data, &container); 47 48 let scope = format!("wire_format_{}", container).to_lowercase(); 49 let scope = Ident::new(&scope, Span::call_site()); 50 quote! { 51 mod #scope { 52 extern crate std; 53 use self::std::io; 54 use self::std::result::Result::Ok; 55 56 use super::#container; 57 58 use protocol::WireFormat; 59 60 impl WireFormat for #container { 61 fn byte_size(&self) -> u32 { 62 #byte_size_impl 63 } 64 65 fn encode<W: io::Write>(&self, _writer: &mut W) -> io::Result<()> { 66 #encode_impl 67 } 68 69 fn decode<R: io::Read>(_reader: &mut R) -> io::Result<Self> { 70 #decode_impl 71 } 72 } 73 } 74 } 75 } 76 77 // Generate code that recursively calls byte_size on every field in the struct. byte_size_sum(data: &Data) -> TokenStream78 fn byte_size_sum(data: &Data) -> TokenStream { 79 if let Data::Struct(ref data) = *data { 80 if let Fields::Named(ref fields) = data.fields { 81 let fields = fields.named.iter().map(|f| { 82 let field = &f.ident; 83 let span = field.span(); 84 quote_spanned! {span=> 85 WireFormat::byte_size(&self.#field) 86 } 87 }); 88 89 quote! { 90 0 #(+ #fields)* 91 } 92 } else { 93 unimplemented!(); 94 } 95 } else { 96 unimplemented!(); 97 } 98 } 99 100 // Generate code that recursively calls encode on every field in the struct. encode_wire_format(data: &Data) -> TokenStream101 fn encode_wire_format(data: &Data) -> TokenStream { 102 if let Data::Struct(ref data) = *data { 103 if let Fields::Named(ref fields) = data.fields { 104 let fields = fields.named.iter().map(|f| { 105 let field = &f.ident; 106 let span = field.span(); 107 quote_spanned! {span=> 108 WireFormat::encode(&self.#field, _writer)?; 109 } 110 }); 111 112 quote! { 113 #(#fields)* 114 115 Ok(()) 116 } 117 } else { 118 unimplemented!(); 119 } 120 } else { 121 unimplemented!(); 122 } 123 } 124 125 // Generate code that recursively calls decode on every field in the struct. decode_wire_format(data: &Data, container: &Ident) -> TokenStream126 fn decode_wire_format(data: &Data, container: &Ident) -> TokenStream { 127 if let Data::Struct(ref data) = *data { 128 if let Fields::Named(ref fields) = data.fields { 129 let values = fields.named.iter().map(|f| { 130 let field = &f.ident; 131 let span = field.span(); 132 quote_spanned! {span=> 133 let #field = WireFormat::decode(_reader)?; 134 } 135 }); 136 137 let members = fields.named.iter().map(|f| { 138 let field = &f.ident; 139 quote! { 140 #field: #field, 141 } 142 }); 143 144 quote! { 145 #(#values)* 146 147 Ok(#container { 148 #(#members)* 149 }) 150 } 151 } else { 152 unimplemented!(); 153 } 154 } else { 155 unimplemented!(); 156 } 157 } 158 159 #[cfg(test)] 160 mod tests { 161 use super::*; 162 163 #[test] byte_size()164 fn byte_size() { 165 let input: DeriveInput = parse_quote! { 166 struct Item { 167 ident: u32, 168 with_underscores: String, 169 other: u8, 170 } 171 }; 172 173 let expected = quote! { 174 0 175 + WireFormat::byte_size(&self.ident) 176 + WireFormat::byte_size(&self.with_underscores) 177 + WireFormat::byte_size(&self.other) 178 }; 179 180 assert_eq!(byte_size_sum(&input.data).to_string(), expected.to_string()); 181 } 182 183 #[test] encode()184 fn encode() { 185 let input: DeriveInput = parse_quote! { 186 struct Item { 187 ident: u32, 188 with_underscores: String, 189 other: u8, 190 } 191 }; 192 193 let expected = quote! { 194 WireFormat::encode(&self.ident, _writer)?; 195 WireFormat::encode(&self.with_underscores, _writer)?; 196 WireFormat::encode(&self.other, _writer)?; 197 Ok(()) 198 }; 199 200 assert_eq!( 201 encode_wire_format(&input.data).to_string(), 202 expected.to_string(), 203 ); 204 } 205 206 #[test] decode()207 fn decode() { 208 let input: DeriveInput = parse_quote! { 209 struct Item { 210 ident: u32, 211 with_underscores: String, 212 other: u8, 213 } 214 }; 215 216 let container = Ident::new("Item", Span::call_site()); 217 let expected = quote! { 218 let ident = WireFormat::decode(_reader)?; 219 let with_underscores = WireFormat::decode(_reader)?; 220 let other = WireFormat::decode(_reader)?; 221 Ok(Item { 222 ident: ident, 223 with_underscores: with_underscores, 224 other: other, 225 }) 226 }; 227 228 assert_eq!( 229 decode_wire_format(&input.data, &container).to_string(), 230 expected.to_string(), 231 ); 232 } 233 234 #[test] end_to_end()235 fn end_to_end() { 236 let input: DeriveInput = parse_quote! { 237 struct Niijima_先輩 { 238 a: u8, 239 b: u16, 240 c: u32, 241 d: u64, 242 e: String, 243 f: Vec<String>, 244 g: Nested, 245 } 246 }; 247 248 let expected = quote! { 249 mod wire_format_niijima_先輩 { 250 extern crate std; 251 use self::std::io; 252 use self::std::result::Result::Ok; 253 254 use super::Niijima_先輩; 255 256 use protocol::WireFormat; 257 258 impl WireFormat for Niijima_先輩 { 259 fn byte_size(&self) -> u32 { 260 0 261 + WireFormat::byte_size(&self.a) 262 + WireFormat::byte_size(&self.b) 263 + WireFormat::byte_size(&self.c) 264 + WireFormat::byte_size(&self.d) 265 + WireFormat::byte_size(&self.e) 266 + WireFormat::byte_size(&self.f) 267 + WireFormat::byte_size(&self.g) 268 } 269 270 fn encode<W: io::Write>(&self, _writer: &mut W) -> io::Result<()> { 271 WireFormat::encode(&self.a, _writer)?; 272 WireFormat::encode(&self.b, _writer)?; 273 WireFormat::encode(&self.c, _writer)?; 274 WireFormat::encode(&self.d, _writer)?; 275 WireFormat::encode(&self.e, _writer)?; 276 WireFormat::encode(&self.f, _writer)?; 277 WireFormat::encode(&self.g, _writer)?; 278 Ok(()) 279 } 280 fn decode<R: io::Read>(_reader: &mut R) -> io::Result<Self> { 281 let a = WireFormat::decode(_reader)?; 282 let b = WireFormat::decode(_reader)?; 283 let c = WireFormat::decode(_reader)?; 284 let d = WireFormat::decode(_reader)?; 285 let e = WireFormat::decode(_reader)?; 286 let f = WireFormat::decode(_reader)?; 287 let g = WireFormat::decode(_reader)?; 288 Ok(Niijima_先輩 { 289 a: a, 290 b: b, 291 c: c, 292 d: d, 293 e: e, 294 f: f, 295 g: g, 296 }) 297 } 298 } 299 } 300 }; 301 302 assert_eq!( 303 p9_wire_format_inner(input).to_string(), 304 expected.to_string(), 305 ); 306 } 307 } 308