1 use crate::codec::UserError;
2 use crate::codec::UserError::*;
3 use crate::frame::{self, Frame, FrameSize};
4 use crate::hpack;
5 
6 use bytes::{Buf, BufMut, BytesMut};
7 use std::pin::Pin;
8 use std::task::{Context, Poll};
9 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
10 use tokio_util::io::poll_write_buf;
11 
12 use std::io::{self, Cursor};
13 
// Borrow the encoder's write buffer as a `BufMut` that accepts at most one
// frame's worth of bytes (`max_frame_size` payload plus the frame header).
//
// This is a macro rather than a method because a method would need to borrow
// all of `&mut self`, while callers only need the `buf` field.
macro_rules! limited_write_buf {
    ($self:expr) => {{
        let frame_cap = frame::HEADER_LEN + $self.max_frame_size();
        $self.buf.get_mut().limit(frame_cap)
    }};
}
21 
/// Encodes HTTP/2 frames into an internal buffer and writes them to an
/// upstream `AsyncWrite`.
#[derive(Debug)]
pub struct FramedWrite<T, B> {
    /// Upstream `AsyncWrite`
    inner: T,

    /// Frame encoder, including the write buffer and any partially written
    /// frame.
    encoder: Encoder<B>,
}
29 
/// Encodes frames into `buf` and tracks the frame, if any, that is still in
/// the middle of being written.
#[derive(Debug)]
struct Encoder<B> {
    /// HPACK encoder
    hpack: hpack::Encoder,

    /// Write buffer
    ///
    /// Wrapped in a `Cursor` whose position tracks how many buffered bytes
    /// have already been written upstream; it is reset once everything has
    /// been written.
    ///
    /// TODO: Should this be a ring buffer?
    buf: Cursor<BytesMut>,

    /// Next frame to encode
    ///
    /// For `Next::Data`, the frame head (and possibly the start of the
    /// payload) is already in `buf`, and the rest of the payload is written
    /// straight from the frame. For `Next::Continuation`, the frame is
    /// encoded once `buf` has been fully written.
    next: Option<Next<B>>,

    /// Last data frame
    ///
    /// The most recently written DATA frame, kept so the owner can take it
    /// back.
    last_data_frame: Option<frame::Data<B>>,

    /// Max frame size, this is specified by the peer
    max_frame_size: FrameSize,

    /// Chain payloads bigger than this.
    ///
    /// DATA payloads of at least this many bytes are written from the frame's
    /// own buffer (chained after `buf`) instead of being copied into `buf`.
    chain_threshold: usize,

    /// Min buffer required to attempt to write a frame
    min_buffer_capacity: usize,
}
55 
/// A frame whose writing could not be completed by a single copy into the
/// write buffer.
#[derive(Debug)]
enum Next<B> {
    /// A DATA frame whose remaining payload is written directly after the
    /// buffered bytes.
    Data(frame::Data<B>),
    /// Header-block remainder produced when a HEADERS / PUSH_PROMISE frame
    /// did not fit in one frame; encoded after the buffer is flushed.
    Continuation(frame::Continuation),
}
61 
/// Initial capacity of the connection's write buffer.
///
/// The minimum MAX_FRAME_SIZE is 16 KiB, so start with enough room for a
/// HEADERS frame of that size.
const DEFAULT_BUFFER_CAPACITY: usize = 16 * 1_024;
67 
/// Chain payloads bigger than this when vectored I/O is enabled.
///
/// This is far below any max frame size the remote can advertise: the spec
/// does not allow a max frame size smaller than 16 KiB.
const CHAIN_THRESHOLD: usize = 256;
72 
/// Chain payloads bigger than this when vectored I/O is **not** enabled.
///
/// A larger value in this scenario reduces the number of small, fragmented
/// writes, thereby improving throughput.
const CHAIN_THRESHOLD_WITHOUT_VECTORED_IO: usize = 1024;
77 
78 // TODO: Make generic
79 impl<T, B> FramedWrite<T, B>
80 where
81     T: AsyncWrite + Unpin,
82     B: Buf,
83 {
new(inner: T) -> FramedWrite<T, B>84     pub fn new(inner: T) -> FramedWrite<T, B> {
85         let chain_threshold = if inner.is_write_vectored() {
86             CHAIN_THRESHOLD
87         } else {
88             CHAIN_THRESHOLD_WITHOUT_VECTORED_IO
89         };
90         FramedWrite {
91             inner,
92             encoder: Encoder {
93                 hpack: hpack::Encoder::default(),
94                 buf: Cursor::new(BytesMut::with_capacity(DEFAULT_BUFFER_CAPACITY)),
95                 next: None,
96                 last_data_frame: None,
97                 max_frame_size: frame::DEFAULT_MAX_FRAME_SIZE,
98                 chain_threshold,
99                 min_buffer_capacity: chain_threshold + frame::HEADER_LEN,
100             },
101         }
102     }
103 
104     /// Returns `Ready` when `send` is able to accept a frame
105     ///
106     /// Calling this function may result in the current contents of the buffer
107     /// to be flushed to `T`.
poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>>108     pub fn poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
109         if !self.encoder.has_capacity() {
110             // Try flushing
111             ready!(self.flush(cx))?;
112 
113             if !self.encoder.has_capacity() {
114                 return Poll::Pending;
115             }
116         }
117 
118         Poll::Ready(Ok(()))
119     }
120 
121     /// Buffer a frame.
122     ///
123     /// `poll_ready` must be called first to ensure that a frame may be
124     /// accepted.
buffer(&mut self, item: Frame<B>) -> Result<(), UserError>125     pub fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> {
126         self.encoder.buffer(item)
127     }
128 
129     /// Flush buffered data to the wire
flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>>130     pub fn flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
131         let span = tracing::trace_span!("FramedWrite::flush");
132         let _e = span.enter();
133 
134         loop {
135             while !self.encoder.is_empty() {
136                 match self.encoder.next {
137                     Some(Next::Data(ref mut frame)) => {
138                         tracing::trace!(queued_data_frame = true);
139                         let mut buf = (&mut self.encoder.buf).chain(frame.payload_mut());
140                         ready!(poll_write_buf(Pin::new(&mut self.inner), cx, &mut buf))?
141                     }
142                     _ => {
143                         tracing::trace!(queued_data_frame = false);
144                         ready!(poll_write_buf(
145                             Pin::new(&mut self.inner),
146                             cx,
147                             &mut self.encoder.buf
148                         ))?
149                     }
150                 };
151             }
152 
153             match self.encoder.unset_frame() {
154                 ControlFlow::Continue => (),
155                 ControlFlow::Break => break,
156             }
157         }
158 
159         tracing::trace!("flushing buffer");
160         // Flush the upstream
161         ready!(Pin::new(&mut self.inner).poll_flush(cx))?;
162 
163         Poll::Ready(Ok(()))
164     }
165 
166     /// Close the codec
shutdown(&mut self, cx: &mut Context) -> Poll<io::Result<()>>167     pub fn shutdown(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
168         ready!(self.flush(cx))?;
169         Pin::new(&mut self.inner).poll_shutdown(cx)
170     }
171 }
172 
/// Result of `Encoder::unset_frame`: tells `flush` whether more bytes were
/// encoded into the write buffer and must be written.
#[must_use]
enum ControlFlow {
    /// A CONTINUATION frame was encoded into the buffer; write again.
    Continue,
    /// Nothing further is queued; stop writing.
    Break,
}
178 
179 impl<B> Encoder<B>
180 where
181     B: Buf,
182 {
unset_frame(&mut self) -> ControlFlow183     fn unset_frame(&mut self) -> ControlFlow {
184         // Clear internal buffer
185         self.buf.set_position(0);
186         self.buf.get_mut().clear();
187 
188         // The data frame has been written, so unset it
189         match self.next.take() {
190             Some(Next::Data(frame)) => {
191                 self.last_data_frame = Some(frame);
192                 debug_assert!(self.is_empty());
193                 ControlFlow::Break
194             }
195             Some(Next::Continuation(frame)) => {
196                 // Buffer the continuation frame, then try to write again
197                 let mut buf = limited_write_buf!(self);
198                 if let Some(continuation) = frame.encode(&mut buf) {
199                     self.next = Some(Next::Continuation(continuation));
200                 }
201                 ControlFlow::Continue
202             }
203             None => ControlFlow::Break,
204         }
205     }
206 
buffer(&mut self, item: Frame<B>) -> Result<(), UserError>207     fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> {
208         // Ensure that we have enough capacity to accept the write.
209         assert!(self.has_capacity());
210         let span = tracing::trace_span!("FramedWrite::buffer", frame = ?item);
211         let _e = span.enter();
212 
213         tracing::debug!(frame = ?item, "send");
214 
215         match item {
216             Frame::Data(mut v) => {
217                 // Ensure that the payload is not greater than the max frame.
218                 let len = v.payload().remaining();
219 
220                 if len > self.max_frame_size() {
221                     return Err(PayloadTooBig);
222                 }
223 
224                 if len >= self.chain_threshold {
225                     let head = v.head();
226 
227                     // Encode the frame head to the buffer
228                     head.encode(len, self.buf.get_mut());
229 
230                     if self.buf.get_ref().remaining() < self.chain_threshold {
231                         let extra_bytes = self.chain_threshold - self.buf.remaining();
232                         self.buf.get_mut().put(v.payload_mut().take(extra_bytes));
233                     }
234 
235                     // Save the data frame
236                     self.next = Some(Next::Data(v));
237                 } else {
238                     v.encode_chunk(self.buf.get_mut());
239 
240                     // The chunk has been fully encoded, so there is no need to
241                     // keep it around
242                     assert_eq!(v.payload().remaining(), 0, "chunk not fully encoded");
243 
244                     // Save off the last frame...
245                     self.last_data_frame = Some(v);
246                 }
247             }
248             Frame::Headers(v) => {
249                 let mut buf = limited_write_buf!(self);
250                 if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) {
251                     self.next = Some(Next::Continuation(continuation));
252                 }
253             }
254             Frame::PushPromise(v) => {
255                 let mut buf = limited_write_buf!(self);
256                 if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) {
257                     self.next = Some(Next::Continuation(continuation));
258                 }
259             }
260             Frame::Settings(v) => {
261                 v.encode(self.buf.get_mut());
262                 tracing::trace!(rem = self.buf.remaining(), "encoded settings");
263             }
264             Frame::GoAway(v) => {
265                 v.encode(self.buf.get_mut());
266                 tracing::trace!(rem = self.buf.remaining(), "encoded go_away");
267             }
268             Frame::Ping(v) => {
269                 v.encode(self.buf.get_mut());
270                 tracing::trace!(rem = self.buf.remaining(), "encoded ping");
271             }
272             Frame::WindowUpdate(v) => {
273                 v.encode(self.buf.get_mut());
274                 tracing::trace!(rem = self.buf.remaining(), "encoded window_update");
275             }
276 
277             Frame::Priority(_) => {
278                 /*
279                 v.encode(self.buf.get_mut());
280                 tracing::trace!("encoded priority; rem={:?}", self.buf.remaining());
281                 */
282                 unimplemented!();
283             }
284             Frame::Reset(v) => {
285                 v.encode(self.buf.get_mut());
286                 tracing::trace!(rem = self.buf.remaining(), "encoded reset");
287             }
288         }
289 
290         Ok(())
291     }
292 
has_capacity(&self) -> bool293     fn has_capacity(&self) -> bool {
294         self.next.is_none()
295             && (self.buf.get_ref().capacity() - self.buf.get_ref().len()
296                 >= self.min_buffer_capacity)
297     }
298 
is_empty(&self) -> bool299     fn is_empty(&self) -> bool {
300         match self.next {
301             Some(Next::Data(ref frame)) => !frame.payload().has_remaining(),
302             _ => !self.buf.has_remaining(),
303         }
304     }
305 }
306 
307 impl<B> Encoder<B> {
max_frame_size(&self) -> usize308     fn max_frame_size(&self) -> usize {
309         self.max_frame_size as usize
310     }
311 }
312 
313 impl<T, B> FramedWrite<T, B> {
314     /// Returns the max frame size that can be sent
max_frame_size(&self) -> usize315     pub fn max_frame_size(&self) -> usize {
316         self.encoder.max_frame_size()
317     }
318 
319     /// Set the peer's max frame size.
set_max_frame_size(&mut self, val: usize)320     pub fn set_max_frame_size(&mut self, val: usize) {
321         assert!(val <= frame::MAX_MAX_FRAME_SIZE as usize);
322         self.encoder.max_frame_size = val as FrameSize;
323     }
324 
325     /// Set the peer's header table size.
set_header_table_size(&mut self, val: usize)326     pub fn set_header_table_size(&mut self, val: usize) {
327         self.encoder.hpack.update_max_size(val);
328     }
329 
330     /// Retrieve the last data frame that has been sent
take_last_data_frame(&mut self) -> Option<frame::Data<B>>331     pub fn take_last_data_frame(&mut self) -> Option<frame::Data<B>> {
332         self.encoder.last_data_frame.take()
333     }
334 
get_mut(&mut self) -> &mut T335     pub fn get_mut(&mut self) -> &mut T {
336         &mut self.inner
337     }
338 }
339 
340 impl<T: AsyncRead + Unpin, B> AsyncRead for FramedWrite<T, B> {
poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf, ) -> Poll<io::Result<()>>341     fn poll_read(
342         mut self: Pin<&mut Self>,
343         cx: &mut Context<'_>,
344         buf: &mut ReadBuf,
345     ) -> Poll<io::Result<()>> {
346         Pin::new(&mut self.inner).poll_read(cx, buf)
347     }
348 }
349 
// `FramedWrite` is `Unpin` whenever `T` is, regardless of `B`, because we
// never project the Pin to `B`.
impl<T: Unpin, B> Unpin for FramedWrite<T, B> {}
352 
#[cfg(feature = "unstable")]
mod unstable {
    use super::*;

    impl<T, B> FramedWrite<T, B> {
        /// Returns a shared reference to the wrapped I/O object.
        pub fn get_ref(&self) -> &T {
            let Self { inner, .. } = self;
            inner
        }
    }
}
363