1  /// SWAR: SIMD Within A Register
2  /// SIMD validator backend that validates register-sized chunks of data at a time.
3  use crate::{is_header_name_token, is_header_value_token, is_uri_token, Bytes};
4  
// Adapt block-size to match native register size, i.e: 32bit => 4, 64bit => 8
const BLOCK_SIZE: usize = core::mem::size_of::<usize>();
// One register-sized chunk of input bytes, processed as a unit by the SWAR matchers.
type ByteBlock = [u8; BLOCK_SIZE];
8  
9  #[inline]
match_uri_vectored(bytes: &mut Bytes)10  pub fn match_uri_vectored(bytes: &mut Bytes) {
11      loop {
12          if let Some(bytes8) = bytes.peek_n::<ByteBlock>(BLOCK_SIZE) {
13              let n = match_uri_char_8_swar(bytes8);
14              // SAFETY: using peek_n to retrieve the bytes ensures that there are at least n more bytes
15              // in `bytes`, so calling `advance(n)` is safe.
16              unsafe {
17                  bytes.advance(n);
18              }
19              if n == BLOCK_SIZE {
20                  continue;
21              }
22          }
23          if let Some(b) = bytes.peek() {
24              if is_uri_token(b) {
25                  // SAFETY: using peek to retrieve the byte ensures that there is at least 1 more byte
26                  // in bytes, so calling advance is safe.
27                  unsafe {
28                      bytes.advance(1);
29                  }
30                  continue;
31              }
32          }
33          break;
34      }
35  }
36  
37  #[inline]
match_header_value_vectored(bytes: &mut Bytes)38  pub fn match_header_value_vectored(bytes: &mut Bytes) {
39      loop {
40          if let Some(bytes8) = bytes.peek_n::<ByteBlock>(BLOCK_SIZE) {
41              let n = match_header_value_char_8_swar(bytes8);
42              // SAFETY: using peek_n to retrieve the bytes ensures that there are at least n more bytes
43              // in `bytes`, so calling `advance(n)` is safe.
44              unsafe {
45                  bytes.advance(n);
46              }
47              if n == BLOCK_SIZE {
48                  continue;
49              }
50          }
51          if let Some(b) = bytes.peek() {
52              if is_header_value_token(b) {
53                  // SAFETY: using peek to retrieve the byte ensures that there is at least 1 more byte
54                  // in bytes, so calling advance is safe.
55                  unsafe {
56                      bytes.advance(1);
57                  }
58                  continue;
59              }
60          }
61          break;
62      }
63  }
64  
65  #[inline]
match_header_name_vectored(bytes: &mut Bytes)66  pub fn match_header_name_vectored(bytes: &mut Bytes) {
67      while let Some(block) = bytes.peek_n::<ByteBlock>(BLOCK_SIZE) {
68          let n = match_block(is_header_name_token, block);
69          // SAFETY: using peek_n to retrieve the bytes ensures that there are at least n more bytes
70          // in `bytes`, so calling `advance(n)` is safe.
71          unsafe {
72              bytes.advance(n);
73          }
74          if n != BLOCK_SIZE {
75              return;
76          }
77      }
78      // SAFETY: match_tail processes at most the remaining data in `bytes`. advances `bytes` to the
79      // end, but no further.
80      unsafe { bytes.advance(match_tail(is_header_name_token, bytes.as_ref())) };
81  }
82  
// Matches the "tail", i.e: when fewer than BLOCK_SIZE bytes remain in the
// buffer; expected to be uncommon, hence #[cold].
#[cold]
#[inline]
fn match_tail(f: impl Fn(u8) -> bool, bytes: &[u8]) -> usize {
    // Index of the first byte rejected by `f`; the full length when all match.
    bytes.iter().position(|&b| !f(b)).unwrap_or(bytes.len())
}
94  
95  // Naive fallback block matcher
96  #[inline(always)]
match_block(f: impl Fn(u8) -> bool, block: ByteBlock) -> usize97  fn match_block(f: impl Fn(u8) -> bool, block: ByteBlock) -> usize {
98      for (i, &b) in block.iter().enumerate() {
99          if !f(b) {
100              return i;
101          }
102      }
103      BLOCK_SIZE
104  }
105  
// A const alternative to u64::from_ne_bytes to avoid bumping MSRV (1.36 => 1.44)
// creates a u64 whose bytes are each equal to b
// Multiplying by 0x0101_..._01 broadcasts `b` into every byte lane; this cannot
// overflow u64 since 0xFF * 0x01_01_01_01_01_01_01_01 == u64::MAX.
// NOTE: on 32-bit targets the `as usize` cast keeps only the low 4 lanes,
// which are still all equal to `b`, so the result remains uniform.
const fn uniform_block(b: u8) -> usize {
    (b as u64 *  0x01_01_01_01_01_01_01_01 /* [1_u8; 8] */) as usize
}
111  
// A byte-wise range-check on an entire word/block,
// ensuring all bytes in the word satisfy
// `33 <= x <= 126 && x != '>' && x != '<'`
// IMPORTANT: it false negatives if the block contains '?'
// Returns the offset of the first byte failing the check,
// or BLOCK_SIZE if every byte passes.
#[inline]
fn match_uri_char_8_swar(block: ByteBlock) -> usize {
    // 33 <= x <= 126
    const M: u8 = 0x21;
    const N: u8 = 0x7E;
    const BM: usize = uniform_block(M);
    const BN: usize = uniform_block(127 - N);
    const M128: usize = uniform_block(128);

    let x = usize::from_ne_bytes(block); // Really just a transmute
    let lt = x.wrapping_sub(BM) & !x; // <= m  (high bit flags out-of-range lanes)
    let gt = x.wrapping_add(BN) | x; // >= n  (high bit flags out-of-range lanes)

    // XOR checks to catch '<' & '>' for correctness
    //
    // XOR can be thought of as a "distance function"
    // (somewhat extrapolating from the `xor(x, x) = 0` identity and
    // `∀ x != y: xor(x, y) != 0`
    // (each u8 "xor key" providing a unique total ordering of u8)
    // '<' and '>' have a "xor distance" of 2 (`xor('<', '>') = 2`)
    // xor(x, '>') <= 2 => {'>', '?', '<'}
    // xor(x, '<') <= 2 => {'<', '=', '>'}
    //
    // We assume P('=') > P('?'),
    // given well/commonly-formatted URLs with querystrings contain
    // a single '?' but possibly many '='
    //
    // Thus it's preferable/near-optimal to "xor distance" on '>',
    // since we'll slowpath at most one block per URL
    //
    // Some rust code to sanity check this yourself:
    // ```rs
    // fn xordist(x: u8, n: u8) -> Vec<(char, u8)> {
    //     (0..=255).into_iter().map(|c| (c as char, c ^ x)).filter(|(_c, y)| *y <= n).collect()
    // }
    // (xordist(b'<', 2), xordist(b'>', 2))
    // ```
    const B3: usize = uniform_block(3); // (dist <= 2) + 1 to wrap
    const BGT: usize = uniform_block(b'>');

    let xgt = x ^ BGT;
    let ltgtq = xgt.wrapping_sub(B3) & !xgt;

    // Combine the range check with the '<'/'>'/'?' distance check, keep only
    // the per-lane high bits, and locate the first offending lane (if any).
    offsetnz((ltgtq | lt | gt) & M128)
}
160  
// A byte-wise range-check on an entire word/block,
// ensuring all bytes in the word satisfy `32 <= x <= 126`
// IMPORTANT: false negatives if obs-text is present (0x80..=0xFF)
// Returns the offset of the first byte failing the check,
// or BLOCK_SIZE if every byte passes.
#[inline]
fn match_header_value_char_8_swar(block: ByteBlock) -> usize {
    // 32 <= x <= 126
    const M: u8 = 0x20;
    const N: u8 = 0x7E;
    const BM: usize = uniform_block(M);
    const BN: usize = uniform_block(127 - N);
    const M128: usize = uniform_block(128);

    let x = usize::from_ne_bytes(block); // Really just a transmute
    let lt = x.wrapping_sub(BM) & !x; // <= m  (high bit flags out-of-range lanes)
    let gt = x.wrapping_add(BN) | x; // >= n  (high bit flags out-of-range lanes)
    // Keep only per-lane high bits and locate the first offending lane (if any).
    offsetnz((lt | gt) & M128)
}
178  
179  /// Check block to find offset of first non-zero byte
180  // NOTE: Curiously `block.trailing_zeros() >> 3` appears to be slower, maybe revisit
181  #[inline]
offsetnz(block: usize) -> usize182  fn offsetnz(block: usize) -> usize {
183      // fast path optimistic case (common for long valid sequences)
184      if block == 0 {
185          return BLOCK_SIZE;
186      }
187  
188      // perf: rust will unroll this loop
189      for (i, b) in block.to_ne_bytes().iter().copied().enumerate() {
190          if b != 0 {
191              return i;
192          }
193      }
194      unreachable!()
195  }
196  
#[test]
fn test_is_header_value_block() {
    let is_header_value_block = |b| match_header_value_char_8_swar(b) == BLOCK_SIZE;

    // Exhaustive scan over uniform blocks: valid iff 32 <= b < 127.
    for b in 0..=255_u8 {
        let expected = (32..127).contains(&b);
        assert_eq!(is_header_value_block([b; BLOCK_SIZE]), expected, "b={}", b);
    }

    #[cfg(target_pointer_width = "64")]
    {
        // A few sanity checks on non-uniform bytes for safe-measure
        assert!(!is_header_value_block(*b"foo.com\n"));
        assert!(!is_header_value_block(*b"o.com\r\nU"));
    }
}
222  
#[test]
fn test_is_uri_block() {
    let is_uri_block = |b| match_uri_char_8_swar(b) == BLOCK_SIZE;

    // Exhaustive scan over uniform blocks:
    // valid iff 33 <= b < 127 and b is not one of '<', '?', '>'.
    for b in 0..=255_u8 {
        let expected = (33..127).contains(&b) && !b"<?>".contains(&b);
        assert_eq!(is_uri_block([b; BLOCK_SIZE]), expected, "b={}", b);
    }
}
241  
#[test]
fn test_offsetnz() {
    // Place a single non-zero byte at each position and expect its index back.
    for idx in 0..BLOCK_SIZE {
        let mut bytes = [0_u8; BLOCK_SIZE];
        bytes[idx] = 1;
        assert_eq!(offsetnz(usize::from_ne_bytes(bytes)), idx);
    }
}
252