1 use std::fmt;
2 use std::io::IoSlice;
3 
4 use bytes::buf::{Chain, Take};
5 use bytes::Buf;
6 use tracing::trace;
7 
8 use super::io::WriteBuf;
9 
/// A `'static` byte slice, used in this file for the fixed framing bytes
/// (chunk CRLF terminators and the final zero-length chunk).
type StaticBuf = &'static [u8];
11 
/// Encoders to handle different Transfer-Encodings.
///
/// An `Encoder` wraps body buffers in whatever framing its `Kind` requires
/// and, for length-delimited bodies, tracks how many bytes may still be written.
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct Encoder {
    /// Framing strategy, including the remaining byte budget for `Kind::Length`.
    kind: Kind,
    /// Flag stored by `set_last` and read back by `is_last`; when set,
    /// `encode_and_end` returns `false` even for a cleanly finished body.
    /// NOTE(review): presumably marks the last message on the connection —
    /// confirm against the caller.
    is_last: bool,
}
18 
/// A body buffer together with the framing chosen by an `Encoder`.
///
/// Implements `Buf` by delegating to whichever `BufKind` variant it holds.
#[derive(Debug)]
pub(crate) struct EncodedBuf<B> {
    /// The concrete wrapped representation.
    kind: BufKind<B>,
}
23 
/// Error returned by `Encoder::end` when a length-delimited body is ended
/// before all declared bytes were written; holds the number of bytes still expected.
#[derive(Debug)]
pub(crate) struct NotEof(u64);
26 
/// The framing strategy an `Encoder` applies to body data.
#[derive(Debug, PartialEq, Clone)]
enum Kind {
    /// An Encoder for when Transfer-Encoding includes `chunked`.
    Chunked,
    /// An Encoder for when Content-Length is set.
    ///
    /// Enforces that the body is not longer than the Content-Length header.
    /// The payload is the number of body bytes that may still be written.
    Length(u64),
    /// An Encoder for when neither Content-Length nor Chunked encoding is set.
    ///
    /// This is mostly only used with HTTP/1.0 when no length is declared. This
    /// kind requires the connection to be closed when the body is finished.
    #[cfg(feature = "server")]
    CloseDelimited,
}
42 
/// The concrete shapes an `EncodedBuf` can take.
#[derive(Debug)]
enum BufKind<B> {
    /// The body buffer passed through unchanged.
    Exact(B),
    /// The body buffer truncated to the remaining declared length.
    Limited(Take<B>),
    /// A chunk: size line, then the body buffer, then trailing static bytes.
    Chunked(Chain<Chain<ChunkSize, B>, StaticBuf>),
    /// Only the static bytes that terminate a chunked body.
    ChunkedEnd(StaticBuf),
}
50 
51 impl Encoder {
new(kind: Kind) -> Encoder52     fn new(kind: Kind) -> Encoder {
53         Encoder {
54             kind,
55             is_last: false,
56         }
57     }
chunked() -> Encoder58     pub(crate) fn chunked() -> Encoder {
59         Encoder::new(Kind::Chunked)
60     }
61 
length(len: u64) -> Encoder62     pub(crate) fn length(len: u64) -> Encoder {
63         Encoder::new(Kind::Length(len))
64     }
65 
66     #[cfg(feature = "server")]
close_delimited() -> Encoder67     pub(crate) fn close_delimited() -> Encoder {
68         Encoder::new(Kind::CloseDelimited)
69     }
70 
is_eof(&self) -> bool71     pub(crate) fn is_eof(&self) -> bool {
72         matches!(self.kind, Kind::Length(0))
73     }
74 
75     #[cfg(feature = "server")]
set_last(mut self, is_last: bool) -> Self76     pub(crate) fn set_last(mut self, is_last: bool) -> Self {
77         self.is_last = is_last;
78         self
79     }
80 
is_last(&self) -> bool81     pub(crate) fn is_last(&self) -> bool {
82         self.is_last
83     }
84 
is_close_delimited(&self) -> bool85     pub(crate) fn is_close_delimited(&self) -> bool {
86         match self.kind {
87             #[cfg(feature = "server")]
88             Kind::CloseDelimited => true,
89             _ => false,
90         }
91     }
92 
end<B>(&self) -> Result<Option<EncodedBuf<B>>, NotEof>93     pub(crate) fn end<B>(&self) -> Result<Option<EncodedBuf<B>>, NotEof> {
94         match self.kind {
95             Kind::Length(0) => Ok(None),
96             Kind::Chunked => Ok(Some(EncodedBuf {
97                 kind: BufKind::ChunkedEnd(b"0\r\n\r\n"),
98             })),
99             #[cfg(feature = "server")]
100             Kind::CloseDelimited => Ok(None),
101             Kind::Length(n) => Err(NotEof(n)),
102         }
103     }
104 
encode<B>(&mut self, msg: B) -> EncodedBuf<B> where B: Buf,105     pub(crate) fn encode<B>(&mut self, msg: B) -> EncodedBuf<B>
106     where
107         B: Buf,
108     {
109         let len = msg.remaining();
110         debug_assert!(len > 0, "encode() called with empty buf");
111 
112         let kind = match self.kind {
113             Kind::Chunked => {
114                 trace!("encoding chunked {}B", len);
115                 let buf = ChunkSize::new(len)
116                     .chain(msg)
117                     .chain(b"\r\n" as &'static [u8]);
118                 BufKind::Chunked(buf)
119             }
120             Kind::Length(ref mut remaining) => {
121                 trace!("sized write, len = {}", len);
122                 if len as u64 > *remaining {
123                     let limit = *remaining as usize;
124                     *remaining = 0;
125                     BufKind::Limited(msg.take(limit))
126                 } else {
127                     *remaining -= len as u64;
128                     BufKind::Exact(msg)
129                 }
130             }
131             #[cfg(feature = "server")]
132             Kind::CloseDelimited => {
133                 trace!("close delimited write {}B", len);
134                 BufKind::Exact(msg)
135             }
136         };
137         EncodedBuf { kind }
138     }
139 
encode_and_end<B>(&self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) -> bool where B: Buf,140     pub(super) fn encode_and_end<B>(&self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) -> bool
141     where
142         B: Buf,
143     {
144         let len = msg.remaining();
145         debug_assert!(len > 0, "encode() called with empty buf");
146 
147         match self.kind {
148             Kind::Chunked => {
149                 trace!("encoding chunked {}B", len);
150                 let buf = ChunkSize::new(len)
151                     .chain(msg)
152                     .chain(b"\r\n0\r\n\r\n" as &'static [u8]);
153                 dst.buffer(buf);
154                 !self.is_last
155             }
156             Kind::Length(remaining) => {
157                 use std::cmp::Ordering;
158 
159                 trace!("sized write, len = {}", len);
160                 match (len as u64).cmp(&remaining) {
161                     Ordering::Equal => {
162                         dst.buffer(msg);
163                         !self.is_last
164                     }
165                     Ordering::Greater => {
166                         dst.buffer(msg.take(remaining as usize));
167                         !self.is_last
168                     }
169                     Ordering::Less => {
170                         dst.buffer(msg);
171                         false
172                     }
173                 }
174             }
175             #[cfg(feature = "server")]
176             Kind::CloseDelimited => {
177                 trace!("close delimited write {}B", len);
178                 dst.buffer(msg);
179                 false
180             }
181         }
182     }
183 
184     /// Encodes the full body, without verifying the remaining length matches.
185     ///
186     /// This is used in conjunction with HttpBody::__hyper_full_data(), which
187     /// means we can trust that the buf has the correct size (the buf itself
188     /// was checked to make the headers).
danger_full_buf<B>(self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) where B: Buf,189     pub(super) fn danger_full_buf<B>(self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>)
190     where
191         B: Buf,
192     {
193         debug_assert!(msg.remaining() > 0, "encode() called with empty buf");
194         debug_assert!(
195             match self.kind {
196                 Kind::Length(len) => len == msg.remaining() as u64,
197                 _ => true,
198             },
199             "danger_full_buf length mismatches"
200         );
201 
202         match self.kind {
203             Kind::Chunked => {
204                 let len = msg.remaining();
205                 trace!("encoding chunked {}B", len);
206                 let buf = ChunkSize::new(len)
207                     .chain(msg)
208                     .chain(b"\r\n0\r\n\r\n" as &'static [u8]);
209                 dst.buffer(buf);
210             }
211             _ => {
212                 dst.buffer(msg);
213             }
214         }
215     }
216 }
217 
218 impl<B> Buf for EncodedBuf<B>
219 where
220     B: Buf,
221 {
222     #[inline]
remaining(&self) -> usize223     fn remaining(&self) -> usize {
224         match self.kind {
225             BufKind::Exact(ref b) => b.remaining(),
226             BufKind::Limited(ref b) => b.remaining(),
227             BufKind::Chunked(ref b) => b.remaining(),
228             BufKind::ChunkedEnd(ref b) => b.remaining(),
229         }
230     }
231 
232     #[inline]
chunk(&self) -> &[u8]233     fn chunk(&self) -> &[u8] {
234         match self.kind {
235             BufKind::Exact(ref b) => b.chunk(),
236             BufKind::Limited(ref b) => b.chunk(),
237             BufKind::Chunked(ref b) => b.chunk(),
238             BufKind::ChunkedEnd(ref b) => b.chunk(),
239         }
240     }
241 
242     #[inline]
advance(&mut self, cnt: usize)243     fn advance(&mut self, cnt: usize) {
244         match self.kind {
245             BufKind::Exact(ref mut b) => b.advance(cnt),
246             BufKind::Limited(ref mut b) => b.advance(cnt),
247             BufKind::Chunked(ref mut b) => b.advance(cnt),
248             BufKind::ChunkedEnd(ref mut b) => b.advance(cnt),
249         }
250     }
251 
252     #[inline]
chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize253     fn chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize {
254         match self.kind {
255             BufKind::Exact(ref b) => b.chunks_vectored(dst),
256             BufKind::Limited(ref b) => b.chunks_vectored(dst),
257             BufKind::Chunked(ref b) => b.chunks_vectored(dst),
258             BufKind::ChunkedEnd(ref b) => b.chunks_vectored(dst),
259         }
260     }
261 }
262 
/// Width of a `usize` in bytes on the compilation target.
///
/// Derived from the type itself so that every pointer width is covered,
/// rather than only the 32- and 64-bit targets.
const USIZE_BYTES: usize = std::mem::size_of::<usize>();

// each byte will become 2 hex
const CHUNK_SIZE_MAX_BYTES: usize = USIZE_BYTES * 2;
271 
/// Inline buffer holding a chunk-size line: the chunk length in uppercase
/// hexadecimal followed by CRLF.
#[derive(Clone, Copy)]
struct ChunkSize {
    /// Storage for the hex digits of any `usize` plus the two CRLF bytes.
    bytes: [u8; CHUNK_SIZE_MAX_BYTES + 2],
    /// Read cursor: index of the next byte to hand out through `Buf`.
    pos: u8,
    /// Number of bytes written into `bytes` so far.
    len: u8,
}
278 
279 impl ChunkSize {
new(len: usize) -> ChunkSize280     fn new(len: usize) -> ChunkSize {
281         use std::fmt::Write;
282         let mut size = ChunkSize {
283             bytes: [0; CHUNK_SIZE_MAX_BYTES + 2],
284             pos: 0,
285             len: 0,
286         };
287         write!(&mut size, "{:X}\r\n", len).expect("CHUNK_SIZE_MAX_BYTES should fit any usize");
288         size
289     }
290 }
291 
292 impl Buf for ChunkSize {
293     #[inline]
remaining(&self) -> usize294     fn remaining(&self) -> usize {
295         (self.len - self.pos).into()
296     }
297 
298     #[inline]
chunk(&self) -> &[u8]299     fn chunk(&self) -> &[u8] {
300         &self.bytes[self.pos.into()..self.len.into()]
301     }
302 
303     #[inline]
advance(&mut self, cnt: usize)304     fn advance(&mut self, cnt: usize) {
305         assert!(cnt <= self.remaining());
306         self.pos += cnt as u8; // just asserted cnt fits in u8
307     }
308 }
309 
310 impl fmt::Debug for ChunkSize {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result311     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
312         f.debug_struct("ChunkSize")
313             .field("bytes", &&self.bytes[..self.len.into()])
314             .field("pos", &self.pos)
315             .finish()
316     }
317 }
318 
319 impl fmt::Write for ChunkSize {
write_str(&mut self, num: &str) -> fmt::Result320     fn write_str(&mut self, num: &str) -> fmt::Result {
321         use std::io::Write;
322         (&mut self.bytes[self.len.into()..])
323             .write_all(num.as_bytes())
324             .expect("&mut [u8].write() cannot error");
325         self.len += num.len() as u8; // safe because bytes is never bigger than 256
326         Ok(())
327     }
328 }
329 
330 impl<B: Buf> From<B> for EncodedBuf<B> {
from(buf: B) -> Self331     fn from(buf: B) -> Self {
332         EncodedBuf {
333             kind: BufKind::Exact(buf),
334         }
335     }
336 }
337 
338 impl<B: Buf> From<Take<B>> for EncodedBuf<B> {
from(buf: Take<B>) -> Self339     fn from(buf: Take<B>) -> Self {
340         EncodedBuf {
341             kind: BufKind::Limited(buf),
342         }
343     }
344 }
345 
346 impl<B: Buf> From<Chain<Chain<ChunkSize, B>, StaticBuf>> for EncodedBuf<B> {
from(buf: Chain<Chain<ChunkSize, B>, StaticBuf>) -> Self347     fn from(buf: Chain<Chain<ChunkSize, B>, StaticBuf>) -> Self {
348         EncodedBuf {
349             kind: BufKind::Chunked(buf),
350         }
351     }
352 }
353 
354 impl fmt::Display for NotEof {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result355     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
356         write!(f, "early end, expected {} more bytes", self.0)
357     }
358 }
359 
// Marker impl: `NotEof` relies on the default `Error` methods (it wraps no source error).
impl std::error::Error for NotEof {}
361 
362 #[cfg(test)]
363 mod tests {
364     use bytes::BufMut;
365 
366     use super::super::io::Cursor;
367     use super::Encoder;
368 
369     #[test]
chunked()370     fn chunked() {
371         let mut encoder = Encoder::chunked();
372         let mut dst = Vec::new();
373 
374         let msg1 = b"foo bar".as_ref();
375         let buf1 = encoder.encode(msg1);
376         dst.put(buf1);
377         assert_eq!(dst, b"7\r\nfoo bar\r\n");
378 
379         let msg2 = b"baz quux herp".as_ref();
380         let buf2 = encoder.encode(msg2);
381         dst.put(buf2);
382 
383         assert_eq!(dst, b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n");
384 
385         let end = encoder.end::<Cursor<Vec<u8>>>().unwrap().unwrap();
386         dst.put(end);
387 
388         assert_eq!(
389             dst,
390             b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n0\r\n\r\n".as_ref()
391         );
392     }
393 
394     #[test]
length()395     fn length() {
396         let max_len = 8;
397         let mut encoder = Encoder::length(max_len as u64);
398         let mut dst = Vec::new();
399 
400         let msg1 = b"foo bar".as_ref();
401         let buf1 = encoder.encode(msg1);
402         dst.put(buf1);
403 
404         assert_eq!(dst, b"foo bar");
405         assert!(!encoder.is_eof());
406         encoder.end::<()>().unwrap_err();
407 
408         let msg2 = b"baz".as_ref();
409         let buf2 = encoder.encode(msg2);
410         dst.put(buf2);
411 
412         assert_eq!(dst.len(), max_len);
413         assert_eq!(dst, b"foo barb");
414         assert!(encoder.is_eof());
415         assert!(encoder.end::<()>().unwrap().is_none());
416     }
417 
418     #[test]
eof()419     fn eof() {
420         let mut encoder = Encoder::close_delimited();
421         let mut dst = Vec::new();
422 
423         let msg1 = b"foo bar".as_ref();
424         let buf1 = encoder.encode(msg1);
425         dst.put(buf1);
426 
427         assert_eq!(dst, b"foo bar");
428         assert!(!encoder.is_eof());
429         encoder.end::<()>().unwrap();
430 
431         let msg2 = b"baz".as_ref();
432         let buf2 = encoder.encode(msg2);
433         dst.put(buf2);
434 
435         assert_eq!(dst, b"foo barbaz");
436         assert!(!encoder.is_eof());
437         encoder.end::<()>().unwrap();
438     }
439 }
440