/// SWAR: SIMD Within A Register
/// SIMD validator backend that validates register-sized chunks of data at a time.
use crate::{is_header_name_token, is_header_value_token, is_uri_token, Bytes};

// Adapt block-size to match native register size, i.e.: 32-bit => 4, 64-bit => 8
const BLOCK_SIZE: usize = core::mem::size_of::<usize>();
type ByteBlock = [u8; BLOCK_SIZE];

#[inline]
pub fn match_uri_vectored(bytes: &mut Bytes) {
    loop {
        if let Some(bytes8) = bytes.peek_n::<ByteBlock>(BLOCK_SIZE) {
            let n = match_uri_char_8_swar(bytes8);
            // SAFETY: using peek_n to retrieve the bytes ensures that there are at least n more bytes
            // in `bytes`, so calling `advance(n)` is safe.
            unsafe {
                bytes.advance(n);
            }
            if n == BLOCK_SIZE {
                continue;
            }
        }
        if let Some(b) = bytes.peek() {
            if is_uri_token(b) {
                // SAFETY: using peek to retrieve the byte ensures that there is at least 1 more byte
                // in bytes, so calling advance is safe.
                unsafe {
                    bytes.advance(1);
                }
                continue;
            }
        }
        break;
    }
}

#[inline]
pub fn match_header_value_vectored(bytes: &mut Bytes) {
    loop {
        if let Some(bytes8) = bytes.peek_n::<ByteBlock>(BLOCK_SIZE) {
            let n = match_header_value_char_8_swar(bytes8);
            // SAFETY: using peek_n to retrieve the bytes ensures that there are at least n more bytes
            // in `bytes`, so calling `advance(n)` is safe.
            unsafe {
                bytes.advance(n);
            }
            if n == BLOCK_SIZE {
                continue;
            }
        }
        if let Some(b) = bytes.peek() {
            if is_header_value_token(b) {
                // SAFETY: using peek to retrieve the byte ensures that there is at least 1 more byte
                // in bytes, so calling advance is safe.
                unsafe {
                    bytes.advance(1);
                }
                continue;
            }
        }
        break;
    }
}

#[inline]
pub fn match_header_name_vectored(bytes: &mut Bytes) {
    while let Some(block) = bytes.peek_n::<ByteBlock>(BLOCK_SIZE) {
        let n = match_block(is_header_name_token, block);
        // SAFETY: using peek_n to retrieve the bytes ensures that there are at least n more bytes
        // in `bytes`, so calling `advance(n)` is safe.
        unsafe {
            bytes.advance(n);
        }
        if n != BLOCK_SIZE {
            return;
        }
    }
    // SAFETY: match_tail returns at most the number of bytes remaining in `bytes`,
    // so this advances `bytes` to the end, but no further.
    unsafe { bytes.advance(match_tail(is_header_name_token, bytes.as_ref())) };
}
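
// Note: conceptually, each matcher above is equivalent to a plain byte-wise scan
// like the sketch below (URI case shown); the SWAR block checks only speed up the
// common all-valid prefix. Any block they cannot fully accept (e.g. one containing
// '?' or obs-text) falls through to the byte-wise check, which has the final say.
//
//     while bytes.peek().map_or(false, is_uri_token) {
//         // SAFETY: peek returned Some, so at least one byte remains
//         unsafe { bytes.advance(1) };
//     }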

// Matches the "tail", i.e. when fewer than BLOCK_SIZE bytes remain in the buffer; should be uncommon
#[cold]
#[inline]
fn match_tail(f: impl Fn(u8) -> bool, bytes: &[u8]) -> usize {
    for (i, &b) in bytes.iter().enumerate() {
        if !f(b) {
            return i;
        }
    }
    bytes.len()
}

// Naive fallback block matcher
#[inline(always)]
fn match_block(f: impl Fn(u8) -> bool, block: ByteBlock) -> usize {
    for (i, &b) in block.iter().enumerate() {
        if !f(b) {
            return i;
        }
    }
    BLOCK_SIZE
}

// A const alternative to u64::from_ne_bytes to avoid bumping MSRV (1.36 => 1.44)
// creates a word whose bytes are each equal to `b`
// (computed as a u64; on 32-bit targets the cast to usize keeps the low 4 bytes, which are also all `b`)
const fn uniform_block(b: u8) -> usize {
    (b as u64 * 0x01_01_01_01_01_01_01_01 /* [1_u8; 8] */) as usize
}
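
// How the `lt`/`gt` masks in the two SWAR matchers below work (sketch, ignoring
// cross-byte carries/borrows, which at worst shrink the reported offset and are
// then covered by the callers' byte-wise fallback):
// for a byte `c` with its top bit clear, `c.wrapping_sub(M)` has the top bit set
// iff `c < M`, and AND-ing with `!c` discards bytes whose top bit was already set,
// so `lt` marks bytes below the lower bound. Likewise `c.wrapping_add(127 - N)`
// carries into the top bit iff `c > N`, and OR-ing with `c` also marks bytes
// >= 0x80, so `gt` marks bytes above the upper bound. Every marked byte leaves
// an 0x80 flag for `offsetnz` to find.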

// A byte-wise range-check on an entire word/block,
// ensuring all bytes in the word satisfy
// `33 <= x <= 126 && x != '>' && x != '<'`
// IMPORTANT: it false negatives if the block contains '?'
#[inline]
fn match_uri_char_8_swar(block: ByteBlock) -> usize {
    // 33 <= x <= 126
    const M: u8 = 0x21;
    const N: u8 = 0x7E;
    const BM: usize = uniform_block(M);
    const BN: usize = uniform_block(127 - N);
    const M128: usize = uniform_block(128);

    let x = usize::from_ne_bytes(block); // Really just a transmute
    let lt = x.wrapping_sub(BM) & !x; // <= m
    let gt = x.wrapping_add(BN) | x; // >= n

    // XOR checks to catch '<' & '>' for correctness
    //
    // XOR can be thought of as a "distance function"
    // (somewhat extrapolating from the `xor(x, x) = 0` identity and `∀ x != y: xor(x, y) != 0`,
    // each u8 "xor key" providing a unique total ordering of u8)
    // '<' and '>' have a "xor distance" of 2 (`xor('<', '>') = 2`)
    // xor(x, '>') <= 2 => x in {'>', '?', '<'}
    // xor(x, '<') <= 2 => x in {'<', '=', '>'}
    //
    // We assume P('=') > P('?'),
    // given well/commonly-formatted URLs with querystrings contain
    // a single '?' but possibly many '='
    //
    // Thus it's preferable/near-optimal to take the "xor distance" against '>',
    // since we'll slowpath at most one block per URL
    //
    // Some rust code to sanity check this yourself:
    // ```rs
    // fn xordist(x: u8, n: u8) -> Vec<(char, u8)> {
    //     (0..=255).into_iter().map(|c| (c as char, c ^ x)).filter(|(_c, y)| *y <= n).collect()
    // }
    // (xordist(b'<', 2), xordist(b'>', 2))
    // ```
    const B3: usize = uniform_block(3); // (dist <= 2) + 1 to wrap
    const BGT: usize = uniform_block(b'>');

    let xgt = x ^ BGT;
    let ltgtq = xgt.wrapping_sub(B3) & !xgt;

    offsetnz((ltgtq | lt | gt) & M128)
}

// A byte-wise range-check on an entire word/block,
// ensuring all bytes in the word satisfy `32 <= x <= 126`
// IMPORTANT: false negatives if obs-text is present (0x80..=0xFF)
#[inline]
fn match_header_value_char_8_swar(block: ByteBlock) -> usize {
    // 32 <= x <= 126
    const M: u8 = 0x20;
    const N: u8 = 0x7E;
    const BM: usize = uniform_block(M);
    const BN: usize = uniform_block(127 - N);
    const M128: usize = uniform_block(128);

    let x = usize::from_ne_bytes(block); // Really just a transmute
    let lt = x.wrapping_sub(BM) & !x; // <= m
    let gt = x.wrapping_add(BN) | x; // >= n
    offsetnz((lt | gt) & M128)
}

/// Check block to find offset of first non-zero byte
// NOTE: Curiously `block.trailing_zeros() >> 3` appears to be slower, maybe revisit
#[inline]
fn offsetnz(block: usize) -> usize {
    // fast path optimistic case (common for long valid sequences)
    if block == 0 {
        return BLOCK_SIZE;
    }

    // perf: rust will unroll this loop
    for (i, b) in block.to_ne_bytes().iter().copied().enumerate() {
        if b != 0 {
            return i;
        }
    }
    unreachable!()
}

#[test]
fn test_is_header_value_block() {
    let is_header_value_block = |b| match_header_value_char_8_swar(b) == BLOCK_SIZE;

    // 0..32 => false
    for b in 0..32_u8 {
        assert!(!is_header_value_block([b; BLOCK_SIZE]), "b={}", b);
    }
    // 32..127 => true
    for b in 32..127_u8 {
        assert!(is_header_value_block([b; BLOCK_SIZE]), "b={}", b);
    }
    // 127..=255 => false
    for b in 127..=255_u8 {
        assert!(!is_header_value_block([b; BLOCK_SIZE]), "b={}", b);
    }

    #[cfg(target_pointer_width = "64")]
    {
        // A few sanity checks on non-uniform bytes for good measure
        assert!(!is_header_value_block(*b"foo.com\n"));
        assert!(!is_header_value_block(*b"o.com\r\nU"));
    }
}

#[test]
fn test_is_uri_block() {
    let is_uri_block = |b| match_uri_char_8_swar(b) == BLOCK_SIZE;

    // 0..33 => false
    for b in 0..33_u8 {
        assert!(!is_uri_block([b; BLOCK_SIZE]), "b={}", b);
    }
    // 33..127 => true if b not in { '<', '?', '>' }
    let falsy = |b| b"<?>".contains(&b);
    for b in 33..127_u8 {
        assert_eq!(is_uri_block([b; BLOCK_SIZE]), !falsy(b), "b={}", b);
    }
    // 127..=255 => false
    for b in 127..=255_u8 {
        assert!(!is_uri_block([b; BLOCK_SIZE]), "b={}", b);
    }
}

#[test]
fn test_offsetnz() {
    let seq = [0_u8; BLOCK_SIZE];
    for i in 0..BLOCK_SIZE {
        let mut seq = seq;
        seq[i] = 1;
        let x = usize::from_ne_bytes(seq);
        assert_eq!(offsetnz(x), i);
    }
}
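
// The tests below are additional illustrative sanity checks for the helpers above;
// they only restate properties already documented in the comments.

#[test]
fn test_uniform_block() {
    // `uniform_block(b)` should equal a native-endian word whose bytes are all `b`,
    // on both 32-bit and 64-bit targets (the u64 product is truncated to usize).
    for &b in &[0x00_u8, 0x01, 0x21, 0x7E, 0x80, 0xFF] {
        assert_eq!(uniform_block(b), usize::from_ne_bytes([b; BLOCK_SIZE]), "b={}", b);
    }
}

#[test]
fn test_xor_distance_of_gt() {
    // The "xor distance" argument in `match_uri_char_8_swar`: the only bytes
    // within xor-distance 2 of '>' are '<', '>' and '?' ('=' is at distance 3).
    for c in 0..=255_u8 {
        let within = (c ^ b'>') <= 2;
        assert_eq!(within, c == b'<' || c == b'>' || c == b'?', "c={}", c);
    }
}

#[test]
fn test_uri_block_question_mark_false_negative() {
    // The documented '?' false negative: a block containing '?' is cut short at
    // the '?' even though '?' is a valid URI token; the byte-wise fallback in
    // `match_uri_vectored` re-checks and accepts it.
    #[cfg(target_pointer_width = "64")]
    {
        assert_eq!(match_uri_char_8_swar(*b"/index?a"), 6);
    }
}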