1  #![warn(rust_2018_idioms)]
2  
3  use tokio::io::AsyncWrite;
4  use tokio_test::{assert_ready, task};
5  use tokio_util::codec::{Encoder, FramedWrite};
6  
7  use bytes::{BufMut, BytesMut};
8  use futures_sink::Sink;
9  use std::collections::VecDeque;
10  use std::io::{self, Write};
11  use std::pin::Pin;
12  use std::task::Poll::{Pending, Ready};
13  use std::task::{Context, Poll};
14  
/// Builds a `Mock` writer whose `calls` queue replays the given outcomes in
/// order, one outcome per `write` call.
///
/// Entries are comma separated; a trailing comma after the last entry is
/// optional, and `mock! {}` yields an empty script.
macro_rules! mock {
    ($($x:expr),* $(,)?) => {{
        // Fully qualified so the macro does not rely on `VecDeque` being
        // imported at the call site.
        let queue = ::std::collections::VecDeque::from(vec![$($x),*]);
        Mock { calls: queue }
    }};
}
22  
/// Shorthand for pinning an `Unpin` local by mutable reference:
/// `pin!(x)` expands to `Pin::new(&mut x)`.
macro_rules! pin {
    ($name:ident) => {
        ::std::pin::Pin::new(&mut $name)
    };
}
28  
29  struct U32Encoder;
30  
31  impl Encoder<u32> for U32Encoder {
32      type Error = io::Error;
33  
encode(&mut self, item: u32, dst: &mut BytesMut) -> io::Result<()>34      fn encode(&mut self, item: u32, dst: &mut BytesMut) -> io::Result<()> {
35          // Reserve space
36          dst.reserve(4);
37          dst.put_u32(item);
38          Ok(())
39      }
40  }
41  
42  struct U64Encoder;
43  
44  impl Encoder<u64> for U64Encoder {
45      type Error = io::Error;
46  
encode(&mut self, item: u64, dst: &mut BytesMut) -> io::Result<()>47      fn encode(&mut self, item: u64, dst: &mut BytesMut) -> io::Result<()> {
48          // Reserve space
49          dst.reserve(8);
50          dst.put_u64(item);
51          Ok(())
52      }
53  }
54  
/// Three `u32` frames sent back to back stay buffered until `poll_flush`, and
/// then reach the writer as a single 12-byte write.
#[test]
fn write_multi_frame_in_packet() {
    let mut task = task::spawn(());
    let expected: &[u8] = b"\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02";
    let mut framed = FramedWrite::new(mock! { Ok(expected.to_vec()), }, U32Encoder);

    task.enter(|cx, _| {
        for frame in 0..3u32 {
            assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok());
            assert!(pin!(framed).start_send(frame).is_ok());
        }

        // Frames are only buffered so far: the scripted write is still queued.
        assert_eq!(1, framed.get_ref().calls.len());

        // Flushing emits every buffered frame in one write.
        assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok());

        assert_eq!(0, framed.get_ref().calls.len());
    });
}
80  
/// Replacing the encoder via `map_encoder` keeps bytes already buffered by the
/// old encoder: a `u32` frame and a following `u64` frame flush as one write.
#[test]
fn write_multi_frame_after_codec_changed() {
    let mut task = task::spawn(());
    // 4 bytes for the u32 frame 0x04, then 8 bytes for the u64 frame 0x08.
    let script = mock! {
        Ok(b"\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x08".to_vec()),
    };
    let mut narrow = FramedWrite::new(script, U32Encoder);

    task.enter(|cx, _| {
        assert!(assert_ready!(pin!(narrow).poll_ready(cx)).is_ok());
        assert!(pin!(narrow).start_send(0x04).is_ok());

        // Swap in the 64-bit encoder; the write buffer carries over.
        let mut wide = narrow.map_encoder(|_| U64Encoder);
        assert!(assert_ready!(pin!(wide).poll_ready(cx)).is_ok());
        assert!(pin!(wide).start_send(0x08).is_ok());

        // Both frames are still buffered.
        assert_eq!(1, wide.get_ref().calls.len());

        // One flush writes both frames.
        assert!(assert_ready!(pin!(wide).poll_flush(cx)).is_ok());

        assert_eq!(0, wide.get_ref().calls.len());
    });
}
106  
/// Exercises `set_backpressure_boundary`. Once the buffered bytes reach the
/// boundary, `poll_ready` must flush. It reports `Pending` while the writer
/// blocks, and becomes ready again after the whole buffer has been written.
#[test]
fn write_hits_backpressure() {
    // Used as the mock's chunk size in bytes (2KB). Doubled, it is the number
    // of `u32` frames sent before backpressure (4096 frames * 4 bytes = 16KB).
    const ITER: usize = 2 * 1024;

    let mut mock = mock! {
        // The first write attempt blocks. It is made once ITER * 2 frames
        // are buffered.
        Err(io::Error::new(io::ErrorKind::WouldBlock, "not ready")),
        // Seed chunk; the loop below fills it.
        Ok(b"".to_vec()),
    };

    // Script the expected writes: the encoding of frames 0..=ITER * 2, split
    // into chunks of at most ITER bytes.
    for i in 0..=ITER * 2 {
        // Same bytes U32Encoder emits (big-endian u32); i <= 4096 fits in u32.
        let frame = (i as u32).to_be_bytes();

        let tail = match mock.calls.back_mut() {
            Some(Ok(data)) => data,
            _ => unreachable!(),
        };

        if tail.len() < ITER {
            // Keep filling the current chunk until it holds ITER (2KB) bytes.
            tail.extend_from_slice(&frame);
        } else {
            // Current chunk is full: start a new one with this frame.
            mock.calls.push_back(Ok(frame.to_vec()));
        }
    }
    // 1 WouldBlock, 8 full 2KB chunks (16KB in total), 1 trailing 4-byte chunk.
    assert_eq!(mock.calls.len(), 10);

    let mut task = task::spawn(());
    let mut framed = FramedWrite::new(mock, U32Encoder);
    // ITER * 8 bytes = 16KB, exactly the size of ITER * 2 encoded frames.
    framed.set_backpressure_boundary(ITER * 8);
    task.enter(|cx, _| {
        // Send 16KB. This fills the FramedWrite buffer up to the boundary.
        for i in 0..ITER * 2 {
            assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok());
            assert!(pin!(framed).start_send(i as u32).is_ok());
        }

        // Now poll_ready forces a flush. The mock pops the front entry and
        // reports WouldBlock.
        assert!(pin!(framed).poll_ready(cx).is_pending());

        // Polling again forces another flush, which this time succeeds:
        // the whole 16KB buffer is written as eight 2KB chunks.
        assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok());

        // Send one more frame. It matches the final 4-byte chunk expected by the mock.
        assert!(pin!(framed).start_send((ITER * 2) as u32).is_ok());

        // Flush the rest of the buffer.
        assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok());

        // Every scripted write was consumed.
        assert_eq!(0, framed.get_ref().calls.len());
    })
}
167  
// ===== Mock =====
169  
/// Scripted writer for the tests. Each `write` call pops the next entry from `calls`:
/// - `Ok(bytes)`: asserts the caller is writing (at least) those bytes, and
///   reports exactly that many as written.
/// - `Err(e)`: returned unchanged.
///
/// A write with no entries left panics.
struct Mock {
    calls: VecDeque<io::Result<Vec<u8>>>,
}

impl Write for Mock {
    fn write(&mut self, src: &[u8]) -> io::Result<usize> {
        let data = match self.calls.pop_front() {
            None => panic!("unexpected write; {src:?}"),
            Some(Err(e)) => return Err(e),
            Some(Ok(data)) => data,
        };

        // The caller may offer more bytes than this step expects. Only the
        // expected prefix is checked and reported as written.
        assert!(src.len() >= data.len());
        assert_eq!(&data[..], &src[..data.len()]);
        Ok(data.len())
    }

    /// Nothing is buffered here, so flushing always succeeds.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
191  
192  impl AsyncWrite for Mock {
poll_write( self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8], ) -> Poll<Result<usize, io::Error>>193      fn poll_write(
194          self: Pin<&mut Self>,
195          _cx: &mut Context<'_>,
196          buf: &[u8],
197      ) -> Poll<Result<usize, io::Error>> {
198          match Pin::get_mut(self).write(buf) {
199              Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Pending,
200              other => Ready(other),
201          }
202      }
poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>203      fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
204          match Pin::get_mut(self).flush() {
205              Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Pending,
206              other => Ready(other),
207          }
208      }
poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>209      fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
210          unimplemented!()
211      }
212  }
213