1 use crate::codec::UserError; 2 use crate::codec::UserError::*; 3 use crate::frame::{self, Frame, FrameSize}; 4 use crate::hpack; 5 6 use bytes::{Buf, BufMut, BytesMut}; 7 use std::pin::Pin; 8 use std::task::{Context, Poll}; 9 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; 10 use tokio_util::io::poll_write_buf; 11 12 use std::io::{self, Cursor}; 13 14 // A macro to get around a method needing to borrow &mut self 15 macro_rules! limited_write_buf { 16 ($self:expr) => {{ 17 let limit = $self.max_frame_size() + frame::HEADER_LEN; 18 $self.buf.get_mut().limit(limit) 19 }}; 20 } 21 22 #[derive(Debug)] 23 pub struct FramedWrite<T, B> { 24 /// Upstream `AsyncWrite` 25 inner: T, 26 27 encoder: Encoder<B>, 28 } 29 30 #[derive(Debug)] 31 struct Encoder<B> { 32 /// HPACK encoder 33 hpack: hpack::Encoder, 34 35 /// Write buffer 36 /// 37 /// TODO: Should this be a ring buffer? 38 buf: Cursor<BytesMut>, 39 40 /// Next frame to encode 41 next: Option<Next<B>>, 42 43 /// Last data frame 44 last_data_frame: Option<frame::Data<B>>, 45 46 /// Max frame size, this is specified by the peer 47 max_frame_size: FrameSize, 48 49 /// Chain payloads bigger than this. 50 chain_threshold: usize, 51 52 /// Min buffer required to attempt to write a frame 53 min_buffer_capacity: usize, 54 } 55 56 #[derive(Debug)] 57 enum Next<B> { 58 Data(frame::Data<B>), 59 Continuation(frame::Continuation), 60 } 61 62 /// Initialize the connection with this amount of write buffer. 63 /// 64 /// The minimum MAX_FRAME_SIZE is 16kb, so always be able to send a HEADERS 65 /// frame that big. 66 const DEFAULT_BUFFER_CAPACITY: usize = 16 * 1_024; 67 68 /// Chain payloads bigger than this when vectored I/O is enabled. The remote 69 /// will never advertise a max frame size less than this (well, the spec says 70 /// the max frame size can't be less than 16kb, so not even close). 71 const CHAIN_THRESHOLD: usize = 256; 72 73 /// Chain payloads bigger than this when vectored I/O is **not** enabled. 74 /// A larger value in this scenario will reduce the number of small and 75 /// fragmented data being sent, and hereby improve the throughput. 76 const CHAIN_THRESHOLD_WITHOUT_VECTORED_IO: usize = 1024; 77 78 // TODO: Make generic 79 impl<T, B> FramedWrite<T, B> 80 where 81 T: AsyncWrite + Unpin, 82 B: Buf, 83 { new(inner: T) -> FramedWrite<T, B>84 pub fn new(inner: T) -> FramedWrite<T, B> { 85 let chain_threshold = if inner.is_write_vectored() { 86 CHAIN_THRESHOLD 87 } else { 88 CHAIN_THRESHOLD_WITHOUT_VECTORED_IO 89 }; 90 FramedWrite { 91 inner, 92 encoder: Encoder { 93 hpack: hpack::Encoder::default(), 94 buf: Cursor::new(BytesMut::with_capacity(DEFAULT_BUFFER_CAPACITY)), 95 next: None, 96 last_data_frame: None, 97 max_frame_size: frame::DEFAULT_MAX_FRAME_SIZE, 98 chain_threshold, 99 min_buffer_capacity: chain_threshold + frame::HEADER_LEN, 100 }, 101 } 102 } 103 104 /// Returns `Ready` when `send` is able to accept a frame 105 /// 106 /// Calling this function may result in the current contents of the buffer 107 /// to be flushed to `T`. poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>>108 pub fn poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> { 109 if !self.encoder.has_capacity() { 110 // Try flushing 111 ready!(self.flush(cx))?; 112 113 if !self.encoder.has_capacity() { 114 return Poll::Pending; 115 } 116 } 117 118 Poll::Ready(Ok(())) 119 } 120 121 /// Buffer a frame. 122 /// 123 /// `poll_ready` must be called first to ensure that a frame may be 124 /// accepted. buffer(&mut self, item: Frame<B>) -> Result<(), UserError>125 pub fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> { 126 self.encoder.buffer(item) 127 } 128 129 /// Flush buffered data to the wire flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>>130 pub fn flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>> { 131 let span = tracing::trace_span!("FramedWrite::flush"); 132 let _e = span.enter(); 133 134 loop { 135 while !self.encoder.is_empty() { 136 match self.encoder.next { 137 Some(Next::Data(ref mut frame)) => { 138 tracing::trace!(queued_data_frame = true); 139 let mut buf = (&mut self.encoder.buf).chain(frame.payload_mut()); 140 ready!(poll_write_buf(Pin::new(&mut self.inner), cx, &mut buf))? 141 } 142 _ => { 143 tracing::trace!(queued_data_frame = false); 144 ready!(poll_write_buf( 145 Pin::new(&mut self.inner), 146 cx, 147 &mut self.encoder.buf 148 ))? 149 } 150 }; 151 } 152 153 match self.encoder.unset_frame() { 154 ControlFlow::Continue => (), 155 ControlFlow::Break => break, 156 } 157 } 158 159 tracing::trace!("flushing buffer"); 160 // Flush the upstream 161 ready!(Pin::new(&mut self.inner).poll_flush(cx))?; 162 163 Poll::Ready(Ok(())) 164 } 165 166 /// Close the codec shutdown(&mut self, cx: &mut Context) -> Poll<io::Result<()>>167 pub fn shutdown(&mut self, cx: &mut Context) -> Poll<io::Result<()>> { 168 ready!(self.flush(cx))?; 169 Pin::new(&mut self.inner).poll_shutdown(cx) 170 } 171 } 172 173 #[must_use] 174 enum ControlFlow { 175 Continue, 176 Break, 177 } 178 179 impl<B> Encoder<B> 180 where 181 B: Buf, 182 { unset_frame(&mut self) -> ControlFlow183 fn unset_frame(&mut self) -> ControlFlow { 184 // Clear internal buffer 185 self.buf.set_position(0); 186 self.buf.get_mut().clear(); 187 188 // The data frame has been written, so unset it 189 match self.next.take() { 190 Some(Next::Data(frame)) => { 191 self.last_data_frame = Some(frame); 192 debug_assert!(self.is_empty()); 193 ControlFlow::Break 194 } 195 Some(Next::Continuation(frame)) => { 196 // Buffer the continuation frame, then try to write again 197 let mut buf = limited_write_buf!(self); 198 if let Some(continuation) = frame.encode(&mut buf) { 199 self.next = Some(Next::Continuation(continuation)); 200 } 201 ControlFlow::Continue 202 } 203 None => ControlFlow::Break, 204 } 205 } 206 buffer(&mut self, item: Frame<B>) -> Result<(), UserError>207 fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> { 208 // Ensure that we have enough capacity to accept the write. 209 assert!(self.has_capacity()); 210 let span = tracing::trace_span!("FramedWrite::buffer", frame = ?item); 211 let _e = span.enter(); 212 213 tracing::debug!(frame = ?item, "send"); 214 215 match item { 216 Frame::Data(mut v) => { 217 // Ensure that the payload is not greater than the max frame. 218 let len = v.payload().remaining(); 219 220 if len > self.max_frame_size() { 221 return Err(PayloadTooBig); 222 } 223 224 if len >= self.chain_threshold { 225 let head = v.head(); 226 227 // Encode the frame head to the buffer 228 head.encode(len, self.buf.get_mut()); 229 230 if self.buf.get_ref().remaining() < self.chain_threshold { 231 let extra_bytes = self.chain_threshold - self.buf.remaining(); 232 self.buf.get_mut().put(v.payload_mut().take(extra_bytes)); 233 } 234 235 // Save the data frame 236 self.next = Some(Next::Data(v)); 237 } else { 238 v.encode_chunk(self.buf.get_mut()); 239 240 // The chunk has been fully encoded, so there is no need to 241 // keep it around 242 assert_eq!(v.payload().remaining(), 0, "chunk not fully encoded"); 243 244 // Save off the last frame... 245 self.last_data_frame = Some(v); 246 } 247 } 248 Frame::Headers(v) => { 249 let mut buf = limited_write_buf!(self); 250 if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) { 251 self.next = Some(Next::Continuation(continuation)); 252 } 253 } 254 Frame::PushPromise(v) => { 255 let mut buf = limited_write_buf!(self); 256 if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) { 257 self.next = Some(Next::Continuation(continuation)); 258 } 259 } 260 Frame::Settings(v) => { 261 v.encode(self.buf.get_mut()); 262 tracing::trace!(rem = self.buf.remaining(), "encoded settings"); 263 } 264 Frame::GoAway(v) => { 265 v.encode(self.buf.get_mut()); 266 tracing::trace!(rem = self.buf.remaining(), "encoded go_away"); 267 } 268 Frame::Ping(v) => { 269 v.encode(self.buf.get_mut()); 270 tracing::trace!(rem = self.buf.remaining(), "encoded ping"); 271 } 272 Frame::WindowUpdate(v) => { 273 v.encode(self.buf.get_mut()); 274 tracing::trace!(rem = self.buf.remaining(), "encoded window_update"); 275 } 276 277 Frame::Priority(_) => { 278 /* 279 v.encode(self.buf.get_mut()); 280 tracing::trace!("encoded priority; rem={:?}", self.buf.remaining()); 281 */ 282 unimplemented!(); 283 } 284 Frame::Reset(v) => { 285 v.encode(self.buf.get_mut()); 286 tracing::trace!(rem = self.buf.remaining(), "encoded reset"); 287 } 288 } 289 290 Ok(()) 291 } 292 has_capacity(&self) -> bool293 fn has_capacity(&self) -> bool { 294 self.next.is_none() 295 && (self.buf.get_ref().capacity() - self.buf.get_ref().len() 296 >= self.min_buffer_capacity) 297 } 298 is_empty(&self) -> bool299 fn is_empty(&self) -> bool { 300 match self.next { 301 Some(Next::Data(ref frame)) => !frame.payload().has_remaining(), 302 _ => !self.buf.has_remaining(), 303 } 304 } 305 } 306 307 impl<B> Encoder<B> { max_frame_size(&self) -> usize308 fn max_frame_size(&self) -> usize { 309 self.max_frame_size as usize 310 } 311 } 312 313 impl<T, B> FramedWrite<T, B> { 314 /// Returns the max frame size that can be sent max_frame_size(&self) -> usize315 pub fn max_frame_size(&self) -> usize { 316 self.encoder.max_frame_size() 317 } 318 319 /// Set the peer's max frame size. set_max_frame_size(&mut self, val: usize)320 pub fn set_max_frame_size(&mut self, val: usize) { 321 assert!(val <= frame::MAX_MAX_FRAME_SIZE as usize); 322 self.encoder.max_frame_size = val as FrameSize; 323 } 324 325 /// Set the peer's header table size. set_header_table_size(&mut self, val: usize)326 pub fn set_header_table_size(&mut self, val: usize) { 327 self.encoder.hpack.update_max_size(val); 328 } 329 330 /// Retrieve the last data frame that has been sent take_last_data_frame(&mut self) -> Option<frame::Data<B>>331 pub fn take_last_data_frame(&mut self) -> Option<frame::Data<B>> { 332 self.encoder.last_data_frame.take() 333 } 334 get_mut(&mut self) -> &mut T335 pub fn get_mut(&mut self) -> &mut T { 336 &mut self.inner 337 } 338 } 339 340 impl<T: AsyncRead + Unpin, B> AsyncRead for FramedWrite<T, B> { poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf, ) -> Poll<io::Result<()>>341 fn poll_read( 342 mut self: Pin<&mut Self>, 343 cx: &mut Context<'_>, 344 buf: &mut ReadBuf, 345 ) -> Poll<io::Result<()>> { 346 Pin::new(&mut self.inner).poll_read(cx, buf) 347 } 348 } 349 350 // We never project the Pin to `B`. 351 impl<T: Unpin, B> Unpin for FramedWrite<T, B> {} 352 353 #[cfg(feature = "unstable")] 354 mod unstable { 355 use super::*; 356 357 impl<T, B> FramedWrite<T, B> { get_ref(&self) -> &T358 pub fn get_ref(&self) -> &T { 359 &self.inner 360 } 361 } 362 } 363