1 use std::fmt; 2 use std::io::IoSlice; 3 4 use bytes::buf::{Chain, Take}; 5 use bytes::Buf; 6 use tracing::trace; 7 8 use super::io::WriteBuf; 9 10 type StaticBuf = &'static [u8]; 11 12 /// Encoders to handle different Transfer-Encodings. 13 #[derive(Debug, Clone, PartialEq)] 14 pub(crate) struct Encoder { 15 kind: Kind, 16 is_last: bool, 17 } 18 19 #[derive(Debug)] 20 pub(crate) struct EncodedBuf<B> { 21 kind: BufKind<B>, 22 } 23 24 #[derive(Debug)] 25 pub(crate) struct NotEof(u64); 26 27 #[derive(Debug, PartialEq, Clone)] 28 enum Kind { 29 /// An Encoder for when Transfer-Encoding includes `chunked`. 30 Chunked, 31 /// An Encoder for when Content-Length is set. 32 /// 33 /// Enforces that the body is not longer than the Content-Length header. 34 Length(u64), 35 /// An Encoder for when neither Content-Length nor Chunked encoding is set. 36 /// 37 /// This is mostly only used with HTTP/1.0 with a length. This kind requires 38 /// the connection to be closed when the body is finished. 39 #[cfg(feature = "server")] 40 CloseDelimited, 41 } 42 43 #[derive(Debug)] 44 enum BufKind<B> { 45 Exact(B), 46 Limited(Take<B>), 47 Chunked(Chain<Chain<ChunkSize, B>, StaticBuf>), 48 ChunkedEnd(StaticBuf), 49 } 50 51 impl Encoder { new(kind: Kind) -> Encoder52 fn new(kind: Kind) -> Encoder { 53 Encoder { 54 kind, 55 is_last: false, 56 } 57 } chunked() -> Encoder58 pub(crate) fn chunked() -> Encoder { 59 Encoder::new(Kind::Chunked) 60 } 61 length(len: u64) -> Encoder62 pub(crate) fn length(len: u64) -> Encoder { 63 Encoder::new(Kind::Length(len)) 64 } 65 66 #[cfg(feature = "server")] close_delimited() -> Encoder67 pub(crate) fn close_delimited() -> Encoder { 68 Encoder::new(Kind::CloseDelimited) 69 } 70 is_eof(&self) -> bool71 pub(crate) fn is_eof(&self) -> bool { 72 matches!(self.kind, Kind::Length(0)) 73 } 74 75 #[cfg(feature = "server")] set_last(mut self, is_last: bool) -> Self76 pub(crate) fn set_last(mut self, is_last: bool) -> Self { 77 self.is_last = is_last; 78 self 79 } 80 is_last(&self) -> bool81 pub(crate) fn is_last(&self) -> bool { 82 self.is_last 83 } 84 is_close_delimited(&self) -> bool85 pub(crate) fn is_close_delimited(&self) -> bool { 86 match self.kind { 87 #[cfg(feature = "server")] 88 Kind::CloseDelimited => true, 89 _ => false, 90 } 91 } 92 end<B>(&self) -> Result<Option<EncodedBuf<B>>, NotEof>93 pub(crate) fn end<B>(&self) -> Result<Option<EncodedBuf<B>>, NotEof> { 94 match self.kind { 95 Kind::Length(0) => Ok(None), 96 Kind::Chunked => Ok(Some(EncodedBuf { 97 kind: BufKind::ChunkedEnd(b"0\r\n\r\n"), 98 })), 99 #[cfg(feature = "server")] 100 Kind::CloseDelimited => Ok(None), 101 Kind::Length(n) => Err(NotEof(n)), 102 } 103 } 104 encode<B>(&mut self, msg: B) -> EncodedBuf<B> where B: Buf,105 pub(crate) fn encode<B>(&mut self, msg: B) -> EncodedBuf<B> 106 where 107 B: Buf, 108 { 109 let len = msg.remaining(); 110 debug_assert!(len > 0, "encode() called with empty buf"); 111 112 let kind = match self.kind { 113 Kind::Chunked => { 114 trace!("encoding chunked {}B", len); 115 let buf = ChunkSize::new(len) 116 .chain(msg) 117 .chain(b"\r\n" as &'static [u8]); 118 BufKind::Chunked(buf) 119 } 120 Kind::Length(ref mut remaining) => { 121 trace!("sized write, len = {}", len); 122 if len as u64 > *remaining { 123 let limit = *remaining as usize; 124 *remaining = 0; 125 BufKind::Limited(msg.take(limit)) 126 } else { 127 *remaining -= len as u64; 128 BufKind::Exact(msg) 129 } 130 } 131 #[cfg(feature = "server")] 132 Kind::CloseDelimited => { 133 trace!("close delimited write {}B", len); 134 BufKind::Exact(msg) 135 } 136 }; 137 EncodedBuf { kind } 138 } 139 encode_and_end<B>(&self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) -> bool where B: Buf,140 pub(super) fn encode_and_end<B>(&self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) -> bool 141 where 142 B: Buf, 143 { 144 let len = msg.remaining(); 145 debug_assert!(len > 0, "encode() called with empty buf"); 146 147 match self.kind { 148 Kind::Chunked => { 149 trace!("encoding chunked {}B", len); 150 let buf = ChunkSize::new(len) 151 .chain(msg) 152 .chain(b"\r\n0\r\n\r\n" as &'static [u8]); 153 dst.buffer(buf); 154 !self.is_last 155 } 156 Kind::Length(remaining) => { 157 use std::cmp::Ordering; 158 159 trace!("sized write, len = {}", len); 160 match (len as u64).cmp(&remaining) { 161 Ordering::Equal => { 162 dst.buffer(msg); 163 !self.is_last 164 } 165 Ordering::Greater => { 166 dst.buffer(msg.take(remaining as usize)); 167 !self.is_last 168 } 169 Ordering::Less => { 170 dst.buffer(msg); 171 false 172 } 173 } 174 } 175 #[cfg(feature = "server")] 176 Kind::CloseDelimited => { 177 trace!("close delimited write {}B", len); 178 dst.buffer(msg); 179 false 180 } 181 } 182 } 183 184 /// Encodes the full body, without verifying the remaining length matches. 185 /// 186 /// This is used in conjunction with HttpBody::__hyper_full_data(), which 187 /// means we can trust that the buf has the correct size (the buf itself 188 /// was checked to make the headers). danger_full_buf<B>(self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) where B: Buf,189 pub(super) fn danger_full_buf<B>(self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) 190 where 191 B: Buf, 192 { 193 debug_assert!(msg.remaining() > 0, "encode() called with empty buf"); 194 debug_assert!( 195 match self.kind { 196 Kind::Length(len) => len == msg.remaining() as u64, 197 _ => true, 198 }, 199 "danger_full_buf length mismatches" 200 ); 201 202 match self.kind { 203 Kind::Chunked => { 204 let len = msg.remaining(); 205 trace!("encoding chunked {}B", len); 206 let buf = ChunkSize::new(len) 207 .chain(msg) 208 .chain(b"\r\n0\r\n\r\n" as &'static [u8]); 209 dst.buffer(buf); 210 } 211 _ => { 212 dst.buffer(msg); 213 } 214 } 215 } 216 } 217 218 impl<B> Buf for EncodedBuf<B> 219 where 220 B: Buf, 221 { 222 #[inline] remaining(&self) -> usize223 fn remaining(&self) -> usize { 224 match self.kind { 225 BufKind::Exact(ref b) => b.remaining(), 226 BufKind::Limited(ref b) => b.remaining(), 227 BufKind::Chunked(ref b) => b.remaining(), 228 BufKind::ChunkedEnd(ref b) => b.remaining(), 229 } 230 } 231 232 #[inline] chunk(&self) -> &[u8]233 fn chunk(&self) -> &[u8] { 234 match self.kind { 235 BufKind::Exact(ref b) => b.chunk(), 236 BufKind::Limited(ref b) => b.chunk(), 237 BufKind::Chunked(ref b) => b.chunk(), 238 BufKind::ChunkedEnd(ref b) => b.chunk(), 239 } 240 } 241 242 #[inline] advance(&mut self, cnt: usize)243 fn advance(&mut self, cnt: usize) { 244 match self.kind { 245 BufKind::Exact(ref mut b) => b.advance(cnt), 246 BufKind::Limited(ref mut b) => b.advance(cnt), 247 BufKind::Chunked(ref mut b) => b.advance(cnt), 248 BufKind::ChunkedEnd(ref mut b) => b.advance(cnt), 249 } 250 } 251 252 #[inline] chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize253 fn chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize { 254 match self.kind { 255 BufKind::Exact(ref b) => b.chunks_vectored(dst), 256 BufKind::Limited(ref b) => b.chunks_vectored(dst), 257 BufKind::Chunked(ref b) => b.chunks_vectored(dst), 258 BufKind::ChunkedEnd(ref b) => b.chunks_vectored(dst), 259 } 260 } 261 } 262 263 #[cfg(target_pointer_width = "32")] 264 const USIZE_BYTES: usize = 4; 265 266 #[cfg(target_pointer_width = "64")] 267 const USIZE_BYTES: usize = 8; 268 269 // each byte will become 2 hex 270 const CHUNK_SIZE_MAX_BYTES: usize = USIZE_BYTES * 2; 271 272 #[derive(Clone, Copy)] 273 struct ChunkSize { 274 bytes: [u8; CHUNK_SIZE_MAX_BYTES + 2], 275 pos: u8, 276 len: u8, 277 } 278 279 impl ChunkSize { new(len: usize) -> ChunkSize280 fn new(len: usize) -> ChunkSize { 281 use std::fmt::Write; 282 let mut size = ChunkSize { 283 bytes: [0; CHUNK_SIZE_MAX_BYTES + 2], 284 pos: 0, 285 len: 0, 286 }; 287 write!(&mut size, "{:X}\r\n", len).expect("CHUNK_SIZE_MAX_BYTES should fit any usize"); 288 size 289 } 290 } 291 292 impl Buf for ChunkSize { 293 #[inline] remaining(&self) -> usize294 fn remaining(&self) -> usize { 295 (self.len - self.pos).into() 296 } 297 298 #[inline] chunk(&self) -> &[u8]299 fn chunk(&self) -> &[u8] { 300 &self.bytes[self.pos.into()..self.len.into()] 301 } 302 303 #[inline] advance(&mut self, cnt: usize)304 fn advance(&mut self, cnt: usize) { 305 assert!(cnt <= self.remaining()); 306 self.pos += cnt as u8; // just asserted cnt fits in u8 307 } 308 } 309 310 impl fmt::Debug for ChunkSize { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result311 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 312 f.debug_struct("ChunkSize") 313 .field("bytes", &&self.bytes[..self.len.into()]) 314 .field("pos", &self.pos) 315 .finish() 316 } 317 } 318 319 impl fmt::Write for ChunkSize { write_str(&mut self, num: &str) -> fmt::Result320 fn write_str(&mut self, num: &str) -> fmt::Result { 321 use std::io::Write; 322 (&mut self.bytes[self.len.into()..]) 323 .write_all(num.as_bytes()) 324 .expect("&mut [u8].write() cannot error"); 325 self.len += num.len() as u8; // safe because bytes is never bigger than 256 326 Ok(()) 327 } 328 } 329 330 impl<B: Buf> From<B> for EncodedBuf<B> { from(buf: B) -> Self331 fn from(buf: B) -> Self { 332 EncodedBuf { 333 kind: BufKind::Exact(buf), 334 } 335 } 336 } 337 338 impl<B: Buf> From<Take<B>> for EncodedBuf<B> { from(buf: Take<B>) -> Self339 fn from(buf: Take<B>) -> Self { 340 EncodedBuf { 341 kind: BufKind::Limited(buf), 342 } 343 } 344 } 345 346 impl<B: Buf> From<Chain<Chain<ChunkSize, B>, StaticBuf>> for EncodedBuf<B> { from(buf: Chain<Chain<ChunkSize, B>, StaticBuf>) -> Self347 fn from(buf: Chain<Chain<ChunkSize, B>, StaticBuf>) -> Self { 348 EncodedBuf { 349 kind: BufKind::Chunked(buf), 350 } 351 } 352 } 353 354 impl fmt::Display for NotEof { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result355 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 356 write!(f, "early end, expected {} more bytes", self.0) 357 } 358 } 359 360 impl std::error::Error for NotEof {} 361 362 #[cfg(test)] 363 mod tests { 364 use bytes::BufMut; 365 366 use super::super::io::Cursor; 367 use super::Encoder; 368 369 #[test] chunked()370 fn chunked() { 371 let mut encoder = Encoder::chunked(); 372 let mut dst = Vec::new(); 373 374 let msg1 = b"foo bar".as_ref(); 375 let buf1 = encoder.encode(msg1); 376 dst.put(buf1); 377 assert_eq!(dst, b"7\r\nfoo bar\r\n"); 378 379 let msg2 = b"baz quux herp".as_ref(); 380 let buf2 = encoder.encode(msg2); 381 dst.put(buf2); 382 383 assert_eq!(dst, b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n"); 384 385 let end = encoder.end::<Cursor<Vec<u8>>>().unwrap().unwrap(); 386 dst.put(end); 387 388 assert_eq!( 389 dst, 390 b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n0\r\n\r\n".as_ref() 391 ); 392 } 393 394 #[test] length()395 fn length() { 396 let max_len = 8; 397 let mut encoder = Encoder::length(max_len as u64); 398 let mut dst = Vec::new(); 399 400 let msg1 = b"foo bar".as_ref(); 401 let buf1 = encoder.encode(msg1); 402 dst.put(buf1); 403 404 assert_eq!(dst, b"foo bar"); 405 assert!(!encoder.is_eof()); 406 encoder.end::<()>().unwrap_err(); 407 408 let msg2 = b"baz".as_ref(); 409 let buf2 = encoder.encode(msg2); 410 dst.put(buf2); 411 412 assert_eq!(dst.len(), max_len); 413 assert_eq!(dst, b"foo barb"); 414 assert!(encoder.is_eof()); 415 assert!(encoder.end::<()>().unwrap().is_none()); 416 } 417 418 #[test] eof()419 fn eof() { 420 let mut encoder = Encoder::close_delimited(); 421 let mut dst = Vec::new(); 422 423 let msg1 = b"foo bar".as_ref(); 424 let buf1 = encoder.encode(msg1); 425 dst.put(buf1); 426 427 assert_eq!(dst, b"foo bar"); 428 assert!(!encoder.is_eof()); 429 encoder.end::<()>().unwrap(); 430 431 let msg2 = b"baz".as_ref(); 432 let buf2 = encoder.encode(msg2); 433 dst.put(buf2); 434 435 assert_eq!(dst, b"foo barbaz"); 436 assert!(!encoder.is_eof()); 437 encoder.end::<()>().unwrap(); 438 } 439 } 440