1 // Copyright 2023 The ChromiumOS Authors 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 use std::cmp::Ordering; 5 6 use proc_macro::*; 7 get_enum_and_stream(stream: TokenStream) -> Option<(TokenTree, TokenStream)>8 fn get_enum_and_stream(stream: TokenStream) -> Option<(TokenTree, TokenStream)> { 9 let mut enum_name = None; 10 let mut cap_next = false; 11 for tree in stream { 12 match tree { 13 TokenTree::Ident(i) if i.to_string() == "enum" => cap_next = true, 14 TokenTree::Ident(i) if cap_next => { 15 enum_name = Some(TokenTree::Ident(i)); 16 cap_next = false; 17 } 18 TokenTree::Group(g) => { 19 if let Some(name) = enum_name { 20 return Some((name, g.stream())); 21 } else { 22 let result = get_enum_and_stream(g.stream()); 23 if result.is_some() { 24 return result; 25 } 26 }; 27 } 28 _ => {} 29 } 30 } 31 None 32 } 33 generate_enum_array(name: TokenStream, mut item: TokenStream, header: &str) -> TokenStream34 fn generate_enum_array(name: TokenStream, mut item: TokenStream, header: &str) -> TokenStream { 35 let (enum_name, enum_stream) = get_enum_and_stream(item.clone()).expect("Must use on enum"); 36 // Check that enums do not have associated fields, but see still want to all doc comments. 37 for tree in enum_stream.clone() { 38 match tree { 39 TokenTree::Group(group) => match group.stream().into_iter().next() { 40 Some(TokenTree::Ident(ident)) if ident.to_string() == "doc" => { 41 // This is a doc comment; this is the only allowed group right now since we 42 // do not support associated fields on any enum value 43 } 44 _ => panic!("Only enums without associated fields are supported"), 45 }, 46 _ => { 47 // If no groups, then this enum should just be a normal "C-Style" enum. 48 } 49 } 50 } 51 52 let qualified_list = enum_stream 53 .into_iter() 54 .map(|tree| { 55 if let TokenTree::Ident(det_name) = tree { 56 format!("{}::{},", enum_name, det_name).parse().unwrap() 57 } else { 58 TokenStream::new() 59 } 60 }) 61 .collect::<TokenStream>(); 62 let array_stream: TokenStream = format!( 63 r#"{header} 64 pub const {name}: &[{enum_name}] = &[{qualified_list}];"#, 65 header = header, 66 name = name, 67 enum_name = enum_name, 68 qualified_list = qualified_list, 69 ) 70 .parse() 71 .unwrap(); 72 73 // Emitted exactly what was written, then add the test array 74 item.extend(array_stream); 75 item 76 } 77 78 /// Generates an test cfg array with the specified name containing all enum values. This is only 79 /// valid for enums where all the variants do not have fields associated with them. 80 /// Access to array is possible only in test builds 81 /// 82 /// # Example 83 /// 84 /// ``` 85 /// # #[macro_use] extern crate enum_utils; 86 /// #[gen_test_enum_array(MyTestArrayName)] 87 /// enum MyEnum { 88 /// ValueOne, 89 /// ValueTwo, 90 /// } 91 /// 92 /// #[cfg(test)] 93 /// mod tests { 94 /// #[test] 95 /// fn test_two_values() { 96 /// assert_eq!(MyTestArrayName.len(), 2); 97 /// } 98 /// } 99 /// ``` 100 #[proc_macro_attribute] gen_test_enum_array(name: TokenStream, item: TokenStream) -> TokenStream101 pub fn gen_test_enum_array(name: TokenStream, item: TokenStream) -> TokenStream { 102 generate_enum_array(name, item, "#[cfg(test)]") 103 } 104 105 /// Generates an array with the specified name containing all enum values. This is only valid for 106 /// enums where all the variants do not have fields associated with them. 107 /// 108 /// # Example 109 /// 110 /// ``` 111 /// # #[macro_use] extern crate enum_utils; 112 /// #[gen_enum_array(MyArrayName)] 113 /// pub enum MyEnum { 114 /// ValueOne, 115 /// ValueTwo, 116 /// } 117 /// 118 /// fn check() { 119 /// assert_eq!(MyArrayName.len(), 2); 120 /// } 121 /// ``` 122 #[proc_macro_attribute] gen_enum_array(name: TokenStream, item: TokenStream) -> TokenStream123 pub fn gen_enum_array(name: TokenStream, item: TokenStream) -> TokenStream { 124 generate_enum_array(name, item, "") 125 } 126 127 /// Generates an impl to_string for enum which returns &str. 128 /// Method is implemended using match for every enum variant. 129 /// Valid for enums where all the variants do not have fields associated with them. 130 /// 131 /// # Example 132 /// 133 /// ``` 134 /// # #[macro_use] extern crate enum_utils; 135 /// #[gen_to_string] 136 /// enum MyEnum { 137 /// ValueOne, 138 /// ValueTwo, 139 /// } 140 /// 141 /// fn main() { 142 /// let e = MyEnum::ValueOne; 143 /// println!("{}", e.to_string()); 144 /// } 145 /// 146 /// ``` 147 #[proc_macro_attribute] gen_to_string(_attr: TokenStream, mut input: TokenStream) -> TokenStream148 pub fn gen_to_string(_attr: TokenStream, mut input: TokenStream) -> TokenStream { 149 let (enum_name, enum_stream) = get_enum_and_stream(input.clone()).expect("Must use on enum"); 150 151 let mut match_arms = TokenStream::new(); 152 let enums_items = enum_stream.into_iter().filter_map(|tt| match tt { 153 TokenTree::Ident(id) => Some(id.to_string()), 154 _ => None, 155 }); 156 for item in enums_items { 157 let arm: TokenStream = format!( 158 r#"{enum_name}::{item} => "{item}","#, 159 enum_name = enum_name, 160 item = item, 161 ) 162 .parse() 163 .unwrap(); 164 match_arms.extend(arm); 165 } 166 167 let implementation: TokenStream = format!( 168 r#"impl {enum_name} {{ 169 pub fn to_string(&self) -> &'static str {{ 170 match *self {{ 171 {match_arms} 172 }} 173 }} 174 }}"#, 175 enum_name = enum_name, 176 match_arms = match_arms 177 ) 178 .parse() 179 .unwrap(); 180 181 // Emit input as is and add the to_string implementation 182 input.extend(implementation); 183 input 184 } 185 parse_to_i128(val: &str, negative: bool) -> i128186 fn parse_to_i128(val: &str, negative: bool) -> i128 { 187 let (first_pos, base) = match val.get(0..2) { 188 Some("0x") => (2, 16), 189 Some("0o") => (2, 8), 190 Some("0b") => (2, 2), 191 _ => (0, 10), 192 }; 193 194 let sign = if negative { -1 } else { 1 }; 195 196 // Remove any helper _ in the string literal, then convert from base 197 sign * i128::from_str_radix( 198 &val[first_pos..] 199 .chars() 200 .filter(|c| *c != '_') 201 .collect::<String>(), 202 base, 203 ) 204 .unwrap_or_else(|_| panic!("Invalid number {}", val)) 205 } 206 207 /// Generates a exclusive `END` const and `from_<repr>` function for an enum 208 /// 209 /// # Example 210 /// 211 /// ``` 212 /// # #[macro_use] extern crate enum_utils; 213 /// #[enum_as(u8)] 214 /// enum MyEnum { 215 /// ValueZero, 216 /// ValueOne, 217 /// } 218 /// 219 /// # fn main() { 220 /// assert!(matches!(MyEnum::from_u8(1), Some(MyEnum::ValueOne))); 221 /// assert_eq!(MyEnum::END, 2); 222 /// # } 223 /// 224 /// ``` 225 #[proc_macro_attribute] enum_as(repr: TokenStream, input: TokenStream) -> TokenStream226 pub fn enum_as(repr: TokenStream, input: TokenStream) -> TokenStream { 227 let repr = repr.to_string(); 228 let (enum_name, enum_stream) = get_enum_and_stream(input.clone()).expect("Must use on enum"); 229 230 #[derive(Debug, Copy, Clone)] 231 enum WantState { 232 Determinant, 233 Punc, 234 NegativeVal, 235 Val, 236 CommentBlock, 237 } 238 let mut state = WantState::Determinant; 239 let mut skipped_ranges = vec![]; 240 let mut start: Option<i128> = None; 241 let mut end: Option<i128> = None; 242 for tt in enum_stream { 243 use WantState::*; 244 state = match (tt, state) { 245 (TokenTree::Ident(_), Determinant) => Punc, 246 // If we are expecting a Determinant, but get a # instead, it must be a comment 247 (TokenTree::Punct(p), Determinant) if p.as_char() == '#' => CommentBlock, 248 (TokenTree::Punct(p), Val) if p.as_char() == '-' => NegativeVal, 249 (TokenTree::Group(_), CommentBlock) => Determinant, 250 (TokenTree::Punct(p), Punc) => match p.as_char() { 251 ',' => { 252 start = Some(start.unwrap_or(0)); 253 end = Some(end.unwrap_or_else(|| start.unwrap()) + 1); 254 Determinant 255 } 256 '=' => Val, 257 other => panic!("Unexpected punctuation '{}'", other), 258 }, 259 (TokenTree::Literal(l), Val | NegativeVal) => { 260 let val = parse_to_i128(&l.to_string(), matches!(state, NegativeVal)); 261 start = Some(start.unwrap_or(val)); 262 let expected = end.unwrap_or_else(|| start.unwrap()); 263 match val.cmp(&expected) { 264 Ordering::Greater => { 265 skipped_ranges.push((expected, val)); 266 end = Some(val); 267 } 268 Ordering::Less => panic!("Discriminants must increase in value"), 269 Ordering::Equal => (), 270 } 271 Punc 272 } 273 (tt, want) => { 274 panic!("Want {:?} but got {:?}", want, tt) 275 } 276 }; 277 } 278 279 let skipped_ranges = if skipped_ranges.is_empty() { 280 "".to_string() 281 } else { 282 format!( 283 r#"for r in &{ranges:?} {{ 284 if (r.0..r.1).contains(&val) {{ 285 return None; 286 }} 287 }}"#, 288 ranges = skipped_ranges 289 ) 290 }; 291 292 // Ensure that there is at least one discriminant 293 let start = start.expect("Enum needs at least one discriminant"); 294 let end = end.expect("Enum needs at least one discriminant"); 295 296 // Ensure that END will fit into usize or u64 297 u64::try_from(end).unwrap_or_else(|_| panic!("Value after last discriminant must be unsigned")); 298 299 let implementation: TokenStream = format!( 300 r#"impl {enum_name} {{ 301 pub const END: {end_type} = {end}; 302 303 pub fn from_{repr}(val: {repr}) -> Option<Self> {{ 304 if val < {start} {{ 305 return None; 306 }} 307 if val > ((Self::END - 1) as {repr}) {{ 308 return None; 309 }} 310 {skipped_ranges} 311 Some( unsafe {{ core::mem::transmute(val) }}) 312 }} 313 }} 314 315 impl PartialEq for {enum_name} {{ 316 fn eq(&self, other: &Self) -> bool {{ 317 *self as {repr} == *other as {repr} 318 }} 319 }} 320 321 impl Eq for {enum_name} {{ }} 322 323 #[cfg(not(target_arch = "riscv32"))] 324 impl core::hash::Hash for {enum_name} {{ 325 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {{ 326 (*self as {repr}).hash(state); 327 }} 328 }}"#, 329 enum_name = enum_name, 330 start = start, 331 end = end, 332 end_type = if repr == "u64" { "u64" } else { "usize " }, 333 repr = repr, 334 skipped_ranges = skipped_ranges, 335 ) 336 .parse() 337 .unwrap(); 338 339 // Attribute input with a repr, Clone, Clone, and allow(dead_code), then add the custom 340 // implementation. We allow dead code since the these enums are typically interface enums that 341 // define an API boundary. 342 let mut res: TokenStream = format!( 343 "#[repr({})]\n#[derive(Copy, Clone)]\n#[allow(dead_code)]", 344 repr 345 ) 346 .parse() 347 .unwrap(); 348 res.extend(input); 349 res.extend(implementation); 350 res 351 } 352 get_param_list_without_self(params: TokenStream) -> String353 fn get_param_list_without_self(params: TokenStream) -> String { 354 let mut result = String::new(); 355 356 #[derive(Debug, Copy, Clone)] 357 enum WantState { 358 SelfParam, 359 FirstIdentifier, 360 Comma, 361 } 362 let mut state = WantState::SelfParam; 363 for tt in params { 364 use WantState::*; 365 state = match (tt, state) { 366 (TokenTree::Ident(ident), SelfParam) if ident.to_string() == "self" => FirstIdentifier, 367 (TokenTree::Ident(ident), FirstIdentifier) => { 368 result.push_str(&ident.to_string()); 369 Comma 370 } 371 (TokenTree::Punct(p), Comma) if p.to_string() == "," => { 372 result.push(','); 373 FirstIdentifier 374 } 375 (_, other) => { 376 // Do nothing; keep watching for what we are looking for 377 other 378 } 379 }; 380 } 381 382 result 383 } 384 385 /// Replaces the function body with a single statement that pass all parameters thru to same 386 /// function name on the variable specified in the passthru_to macro 387 /// 388 /// # Example 389 /// 390 /// ``` 391 /// # #[macro_use] extern crate enum_utils; 392 /// pub struct Inner(usize); 393 /// 394 /// impl Inner { 395 /// pub fn read_plus(&self, plus: usize) -> usize { 396 /// self.0 + plus 397 /// } 398 /// } 399 /// 400 /// pub struct Outer { 401 /// pub my_inner: Inner, 402 /// } 403 /// 404 /// impl Outer { 405 /// #[passthru_to(my_inner)] 406 /// pub fn read_plus(&self, plus: usize) -> usize {} 407 /// } 408 /// 409 /// ``` 410 #[proc_macro_attribute] passthru_to(passthru_var: TokenStream, input: TokenStream) -> TokenStream411 pub fn passthru_to(passthru_var: TokenStream, input: TokenStream) -> TokenStream { 412 #[derive(Debug, Copy, Clone)] 413 enum WantState { 414 FunctionKeyword, 415 FunctionName, 416 Parameters, 417 Body, 418 End, 419 } 420 let mut state = WantState::FunctionKeyword; 421 let mut name = None; 422 let mut params = None; 423 let mut body_num = None; 424 for (i, tt) in input.clone().into_iter().enumerate() { 425 use WantState::*; 426 state = match (tt, state) { 427 (TokenTree::Ident(ident), FunctionKeyword) if ident.to_string() == "fn" => FunctionName, 428 (_, FunctionKeyword) => FunctionKeyword, 429 (TokenTree::Ident(ident), FunctionName) => { 430 name = Some(ident.to_string()); 431 Parameters 432 } 433 (TokenTree::Group(group), Parameters) => { 434 params = Some(get_param_list_without_self(group.stream())); 435 Body 436 } 437 (TokenTree::Group(group), Body) if group.delimiter() == Delimiter::Brace => { 438 body_num = Some(i); 439 End 440 } 441 (_, Body) => Body, 442 (tt, want) => { 443 panic!("Want {:?} but got {:?}", want, tt) 444 } 445 }; 446 } 447 448 let implementation: TokenStream = format!( 449 r#"{{ self.{passthru_var}.{name}({params}) }}"#, 450 passthru_var = passthru_var, 451 name = name.expect("Cannot find function name"), 452 params = params.expect("Could find parameters"), 453 ) 454 .parse() 455 .unwrap(); 456 457 // Take everything up to body, then replace body with the passthru implementation 458 input 459 .into_iter() 460 .take(body_num.expect("Cannot find body")) 461 .chain(implementation) 462 .collect() 463 } 464