xref: /aosp_15_r20/external/gsc-utils/rust/enum_utils/src/lib.rs (revision 4f2df630800bdcf1d4f0decf95d8a1cb87344f5f)
1  // Copyright 2023 The ChromiumOS Authors
2  // Use of this source code is governed by a BSD-style license that can be
3  // found in the LICENSE file.
4  use std::cmp::Ordering;
5  
6  use proc_macro::*;
7  
/// Locates the first `enum <Name> ... { body }` sequence in `stream`.
///
/// Tokens are scanned left to right. The identifier directly following an `enum` keyword is
/// remembered as the name, and the next group encountered after that is taken as the body.
/// Groups seen before any name is known are searched recursively, so an enum nested inside
/// another group is still found.
///
/// Returns the name token and the body's inner token stream, or `None` if nothing matches.
fn get_enum_and_stream(stream: TokenStream) -> Option<(TokenTree, TokenStream)> {
    let mut name: Option<TokenTree> = None;
    let mut awaiting_name = false;

    for token in stream {
        match token {
            TokenTree::Ident(ident) => {
                if ident.to_string() == "enum" {
                    awaiting_name = true;
                } else if awaiting_name {
                    awaiting_name = false;
                    name = Some(TokenTree::Ident(ident));
                }
            }
            TokenTree::Group(group) => {
                if let Some(enum_name) = name {
                    return Some((enum_name, group.stream()));
                }
                // No name yet: the enum may live inside this group.
                if let Some(hit) = get_enum_and_stream(group.stream()) {
                    return Some(hit);
                }
            }
            _ => (),
        }
    }

    None
}
33  
generate_enum_array(name: TokenStream, mut item: TokenStream, header: &str) -> TokenStream34  fn generate_enum_array(name: TokenStream, mut item: TokenStream, header: &str) -> TokenStream {
35      let (enum_name, enum_stream) = get_enum_and_stream(item.clone()).expect("Must use on enum");
36      // Check that enums do not have associated fields, but see still want to all doc comments.
37      for tree in enum_stream.clone() {
38          match tree {
39              TokenTree::Group(group) => match group.stream().into_iter().next() {
40                  Some(TokenTree::Ident(ident)) if ident.to_string() == "doc" => {
41                      // This is a doc comment; this is the only allowed group right now since we
42                      // do not support associated fields on any enum value
43                  }
44                  _ => panic!("Only enums without associated fields are supported"),
45              },
46              _ => {
47                  // If no groups, then this enum should just be a normal "C-Style" enum.
48              }
49          }
50      }
51  
52      let qualified_list = enum_stream
53          .into_iter()
54          .map(|tree| {
55              if let TokenTree::Ident(det_name) = tree {
56                  format!("{}::{},", enum_name, det_name).parse().unwrap()
57              } else {
58                  TokenStream::new()
59              }
60          })
61          .collect::<TokenStream>();
62      let array_stream: TokenStream = format!(
63          r#"{header}
64          pub const {name}: &[{enum_name}] = &[{qualified_list}];"#,
65          header = header,
66          name = name,
67          enum_name = enum_name,
68          qualified_list = qualified_list,
69      )
70      .parse()
71      .unwrap();
72  
73      // Emitted exactly what was written, then add the test array
74      item.extend(array_stream);
75      item
76  }
77  
78  /// Generates an test cfg array with the specified name containing all enum values. This is only
79  /// valid for enums where all the variants do not have fields associated with them.
80  /// Access to array is possible only in test builds
81  ///
82  /// # Example
83  ///
84  /// ```
85  /// # #[macro_use] extern crate enum_utils;
86  /// #[gen_test_enum_array(MyTestArrayName)]
87  /// enum MyEnum {
88  ///     ValueOne,
89  ///     ValueTwo,
90  /// }
91  ///
92  /// #[cfg(test)]
93  /// mod tests {
94  ///     #[test]
95  ///     fn test_two_values() {
96  ///         assert_eq!(MyTestArrayName.len(), 2);
97  ///     }
98  /// }
99  /// ```
100  #[proc_macro_attribute]
gen_test_enum_array(name: TokenStream, item: TokenStream) -> TokenStream101  pub fn gen_test_enum_array(name: TokenStream, item: TokenStream) -> TokenStream {
102      generate_enum_array(name, item, "#[cfg(test)]")
103  }
104  
105  /// Generates an array with the specified name containing all enum values. This is only valid for
106  /// enums where all the variants do not have fields associated with them.
107  ///
108  /// # Example
109  ///
110  /// ```
111  /// # #[macro_use] extern crate enum_utils;
112  /// #[gen_enum_array(MyArrayName)]
113  /// pub enum MyEnum {
114  ///     ValueOne,
115  ///     ValueTwo,
116  /// }
117  ///
118  /// fn check() {
119  ///         assert_eq!(MyArrayName.len(), 2);
120  /// }
121  /// ```
122  #[proc_macro_attribute]
gen_enum_array(name: TokenStream, item: TokenStream) -> TokenStream123  pub fn gen_enum_array(name: TokenStream, item: TokenStream) -> TokenStream {
124      generate_enum_array(name, item, "")
125  }
126  
127  /// Generates an impl to_string for enum which returns &str.
128  /// Method is implemended using match for every enum variant.
129  /// Valid for enums where all the variants do not have fields associated with them.
130  ///
131  /// # Example
132  ///
133  /// ```
134  /// # #[macro_use] extern crate enum_utils;
135  /// #[gen_to_string]
136  /// enum MyEnum {
137  ///     ValueOne,
138  ///     ValueTwo,
139  /// }
140  ///
141  /// fn main() {
142  ///     let e = MyEnum::ValueOne;
143  ///     println!("{}", e.to_string());
144  /// }
145  ///
146  /// ```
147  #[proc_macro_attribute]
gen_to_string(_attr: TokenStream, mut input: TokenStream) -> TokenStream148  pub fn gen_to_string(_attr: TokenStream, mut input: TokenStream) -> TokenStream {
149      let (enum_name, enum_stream) = get_enum_and_stream(input.clone()).expect("Must use on enum");
150  
151      let mut match_arms = TokenStream::new();
152      let enums_items = enum_stream.into_iter().filter_map(|tt| match tt {
153          TokenTree::Ident(id) => Some(id.to_string()),
154          _ => None,
155      });
156      for item in enums_items {
157          let arm: TokenStream = format!(
158              r#"{enum_name}::{item} => "{item}","#,
159              enum_name = enum_name,
160              item = item,
161          )
162          .parse()
163          .unwrap();
164          match_arms.extend(arm);
165      }
166  
167      let implementation: TokenStream = format!(
168          r#"impl {enum_name} {{
169          pub fn to_string(&self) -> &'static str {{
170              match *self {{
171                  {match_arms}
172              }}
173          }}
174      }}"#,
175          enum_name = enum_name,
176          match_arms = match_arms
177      )
178      .parse()
179      .unwrap();
180  
181      // Emit input as is and add the to_string implementation
182      input.extend(implementation);
183      input
184  }
185  
/// Parses the text of an integer literal token into an `i128`.
///
/// The result is negated when `negative` is set, because the `-` of a negative discriminant
/// arrives as a separate token. Accepts decimal, `0x` hex, `0o` octal and `0b` binary
/// literals, `_` digit separators, and an optional integer type suffix such as `u8` or `i64`.
///
/// # Panics
///
/// Panics with `Invalid number <val>` if the text is not an integer literal, or if the value
/// does not fit in an `i128`.
fn parse_to_i128(val: &str, negative: bool) -> i128 {
    let (digits, base) = if let Some(rest) = val.strip_prefix("0x") {
        (rest, 16)
    } else if let Some(rest) = val.strip_prefix("0o") {
        (rest, 8)
    } else if let Some(rest) = val.strip_prefix("0b") {
        (rest, 2)
    } else {
        (val, 10)
    };

    // An integer type suffix always starts with `u` or `i`. Neither is a digit in any of the
    // bases above, so everything from the first such character onwards is the suffix.
    let digits = match digits.find(|c: char| c == 'u' || c == 'i') {
        Some(suffix_start) => &digits[..suffix_start],
        None => digits,
    };

    // Apply the sign before parsing so that `i128::MIN` (whose magnitude exceeds `i128::MAX`)
    // stays representable. Drop the `_` separators, which `from_str_radix` does not accept.
    let mut normalized = String::with_capacity(digits.len() + 1);
    if negative {
        normalized.push('-');
    }
    normalized.extend(digits.chars().filter(|c| *c != '_'));

    i128::from_str_radix(&normalized, base).unwrap_or_else(|_| panic!("Invalid number {}", val))
}
206  
207  /// Generates a exclusive `END` const and `from_<repr>` function for an enum
208  ///
209  /// # Example
210  ///
211  /// ```
212  /// # #[macro_use] extern crate enum_utils;
213  /// #[enum_as(u8)]
214  /// enum MyEnum {
215  ///     ValueZero,
216  ///     ValueOne,
217  /// }
218  ///
219  /// # fn main() {
220  /// assert!(matches!(MyEnum::from_u8(1), Some(MyEnum::ValueOne)));
221  /// assert_eq!(MyEnum::END, 2);
222  /// # }
223  ///
224  /// ```
225  #[proc_macro_attribute]
enum_as(repr: TokenStream, input: TokenStream) -> TokenStream226  pub fn enum_as(repr: TokenStream, input: TokenStream) -> TokenStream {
227      let repr = repr.to_string();
228      let (enum_name, enum_stream) = get_enum_and_stream(input.clone()).expect("Must use on enum");
229  
230      #[derive(Debug, Copy, Clone)]
231      enum WantState {
232          Determinant,
233          Punc,
234          NegativeVal,
235          Val,
236          CommentBlock,
237      }
238      let mut state = WantState::Determinant;
239      let mut skipped_ranges = vec![];
240      let mut start: Option<i128> = None;
241      let mut end: Option<i128> = None;
242      for tt in enum_stream {
243          use WantState::*;
244          state = match (tt, state) {
245              (TokenTree::Ident(_), Determinant) => Punc,
246              // If we are expecting a Determinant, but get a # instead, it must be a comment
247              (TokenTree::Punct(p), Determinant) if p.as_char() == '#' => CommentBlock,
248              (TokenTree::Punct(p), Val) if p.as_char() == '-' => NegativeVal,
249              (TokenTree::Group(_), CommentBlock) => Determinant,
250              (TokenTree::Punct(p), Punc) => match p.as_char() {
251                  ',' => {
252                      start = Some(start.unwrap_or(0));
253                      end = Some(end.unwrap_or_else(|| start.unwrap()) + 1);
254                      Determinant
255                  }
256                  '=' => Val,
257                  other => panic!("Unexpected punctuation '{}'", other),
258              },
259              (TokenTree::Literal(l), Val | NegativeVal) => {
260                  let val = parse_to_i128(&l.to_string(), matches!(state, NegativeVal));
261                  start = Some(start.unwrap_or(val));
262                  let expected = end.unwrap_or_else(|| start.unwrap());
263                  match val.cmp(&expected) {
264                      Ordering::Greater => {
265                          skipped_ranges.push((expected, val));
266                          end = Some(val);
267                      }
268                      Ordering::Less => panic!("Discriminants must increase in value"),
269                      Ordering::Equal => (),
270                  }
271                  Punc
272              }
273              (tt, want) => {
274                  panic!("Want {:?} but got {:?}", want, tt)
275              }
276          };
277      }
278  
279      let skipped_ranges = if skipped_ranges.is_empty() {
280          "".to_string()
281      } else {
282          format!(
283              r#"for r in &{ranges:?} {{
284                  if (r.0..r.1).contains(&val) {{
285                      return None;
286                  }}
287              }}"#,
288              ranges = skipped_ranges
289          )
290      };
291  
292      // Ensure that there is at least one discriminant
293      let start = start.expect("Enum needs at least one discriminant");
294      let end = end.expect("Enum needs at least one discriminant");
295  
296      // Ensure that END will fit into usize or u64
297      u64::try_from(end).unwrap_or_else(|_| panic!("Value after last discriminant must be unsigned"));
298  
299      let implementation: TokenStream = format!(
300          r#"impl {enum_name} {{
301              pub const END: {end_type} = {end};
302  
303              pub fn from_{repr}(val: {repr}) -> Option<Self> {{
304                  if val < {start} {{
305                      return None;
306                  }}
307                  if val > ((Self::END - 1) as {repr}) {{
308                      return None;
309                  }}
310                  {skipped_ranges}
311                  Some( unsafe {{ core::mem::transmute(val) }})
312              }}
313          }}
314  
315          impl PartialEq for {enum_name} {{
316              fn eq(&self, other: &Self) -> bool {{
317                  *self as {repr} == *other as {repr}
318              }}
319          }}
320  
321          impl Eq for {enum_name} {{ }}
322  
323          #[cfg(not(target_arch = "riscv32"))]
324          impl core::hash::Hash for {enum_name} {{
325              fn hash<H: core::hash::Hasher>(&self, state: &mut H) {{
326                  (*self as {repr}).hash(state);
327              }}
328           }}"#,
329          enum_name = enum_name,
330          start = start,
331          end = end,
332          end_type = if repr == "u64" { "u64" } else { "usize " },
333          repr = repr,
334          skipped_ranges = skipped_ranges,
335      )
336      .parse()
337      .unwrap();
338  
339      // Attribute input with a repr, Clone, Clone, and allow(dead_code), then add the custom
340      // implementation. We allow dead code since the these enums are typically interface enums that
341      // define an API boundary.
342      let mut res: TokenStream = format!(
343          "#[repr({})]\n#[derive(Copy, Clone)]\n#[allow(dead_code)]",
344          repr
345      )
346      .parse()
347      .unwrap();
348      res.extend(input);
349      res.extend(implementation);
350      res
351  }
352  
/// Returns the comma-separated names of the parameters that follow the `self` receiver.
///
/// For example, `&mut self, a: u8, mut b: HashMap<K, V>` yields `a,b`.
///
/// * Only simple identifier bindings are supported; a leading `mut` or `ref` is skipped.
/// * Everything up to and including the receiver (e.g. `self: &Self`) is ignored.
/// * Each parameter's type is skipped up to the next top-level comma, so commas inside generic
///   arguments do not start a new parameter.
/// * A trailing comma in the list is preserved.
/// * If there is no `self` parameter, the result is empty.
fn get_param_list_without_self(params: TokenStream) -> String {
    #[derive(Debug, Copy, Clone)]
    enum WantState {
        // Looking for the `self` receiver; `&`, `mut` and lifetimes before it are skipped.
        SelfParam,
        // Skipping the rest of the receiver (e.g. `: &Self`) up to the next top-level comma.
        SelfRest,
        // Looking for the binding name of the next parameter.
        FirstIdentifier,
        // Skipping the current parameter's type up to the next top-level comma.
        Comma,
    }

    let mut result = String::new();
    let mut state = WantState::SelfParam;
    // Depth of `<...>` nesting. Commas inside generic arguments do not separate parameters.
    // Commas inside `(...)`/`[...]` are already hidden inside groups.
    let mut angle_depth = 0usize;
    // Whether the previous token was a `-` joined to the next token, i.e. the start of `->`.
    let mut after_joint_minus = false;

    for tt in params {
        use WantState::*;

        let mut top_level_comma = false;
        if let TokenTree::Punct(p) = &tt {
            match p.as_char() {
                '<' => angle_depth += 1,
                // The `>` of an `->` arrow (e.g. in `Box<dyn Fn(u8) -> u8>`) closes nothing.
                '>' if !after_joint_minus => angle_depth = angle_depth.saturating_sub(1),
                ',' => top_level_comma = angle_depth == 0,
                _ => {}
            }
        }
        after_joint_minus = matches!(
            &tt,
            TokenTree::Punct(p) if p.as_char() == '-' && p.spacing() == Spacing::Joint
        );

        state = match (tt, state) {
            (TokenTree::Ident(ident), SelfParam) if ident.to_string() == "self" => SelfRest,
            (_, SelfRest) if top_level_comma => FirstIdentifier,
            (TokenTree::Ident(ident), FirstIdentifier) => {
                let ident = ident.to_string();
                if ident == "mut" || ident == "ref" {
                    // Binding modifiers are not part of the argument name.
                    FirstIdentifier
                } else {
                    result.push_str(&ident);
                    Comma
                }
            }
            (_, Comma) if top_level_comma => {
                result.push(',');
                FirstIdentifier
            }
            (_, other) => {
                // Do nothing; keep watching for what we are looking for.
                other
            }
        };
    }

    result
}
384  
385  /// Replaces the function body with a single statement that pass all parameters thru to same
386  /// function name on the variable specified in the passthru_to macro
387  ///
388  /// # Example
389  ///
390  /// ```
391  /// # #[macro_use] extern crate enum_utils;
392  /// pub struct Inner(usize);
393  ///
394  /// impl Inner {
395  ///     pub fn read_plus(&self, plus: usize) -> usize {
396  ///         self.0 + plus
397  ///     }
398  /// }
399  ///
400  /// pub struct Outer {
401  ///     pub my_inner: Inner,
402  /// }
403  ///
404  /// impl Outer {
405  ///     #[passthru_to(my_inner)]
406  ///     pub fn read_plus(&self, plus: usize) -> usize {}
407  /// }
408  ///
409  /// ```
410  #[proc_macro_attribute]
passthru_to(passthru_var: TokenStream, input: TokenStream) -> TokenStream411  pub fn passthru_to(passthru_var: TokenStream, input: TokenStream) -> TokenStream {
412      #[derive(Debug, Copy, Clone)]
413      enum WantState {
414          FunctionKeyword,
415          FunctionName,
416          Parameters,
417          Body,
418          End,
419      }
420      let mut state = WantState::FunctionKeyword;
421      let mut name = None;
422      let mut params = None;
423      let mut body_num = None;
424      for (i, tt) in input.clone().into_iter().enumerate() {
425          use WantState::*;
426          state = match (tt, state) {
427              (TokenTree::Ident(ident), FunctionKeyword) if ident.to_string() == "fn" => FunctionName,
428              (_, FunctionKeyword) => FunctionKeyword,
429              (TokenTree::Ident(ident), FunctionName) => {
430                  name = Some(ident.to_string());
431                  Parameters
432              }
433              (TokenTree::Group(group), Parameters) => {
434                  params = Some(get_param_list_without_self(group.stream()));
435                  Body
436              }
437              (TokenTree::Group(group), Body) if group.delimiter() == Delimiter::Brace => {
438                  body_num = Some(i);
439                  End
440              }
441              (_, Body) => Body,
442              (tt, want) => {
443                  panic!("Want {:?} but got {:?}", want, tt)
444              }
445          };
446      }
447  
448      let implementation: TokenStream = format!(
449          r#"{{ self.{passthru_var}.{name}({params}) }}"#,
450          passthru_var = passthru_var,
451          name = name.expect("Cannot find function name"),
452          params = params.expect("Could find parameters"),
453      )
454      .parse()
455      .unwrap();
456  
457      // Take everything up to body, then replace body with the passthru implementation
458      input
459          .into_iter()
460          .take(body_num.expect("Cannot find body"))
461          .chain(implementation)
462          .collect()
463  }
464