1 #![warn(rust_2018_idioms)] 2 3 use tokio::io::AsyncWrite; 4 use tokio_test::{assert_ready, task}; 5 use tokio_util::codec::{Encoder, FramedWrite}; 6 7 use bytes::{BufMut, BytesMut}; 8 use futures_sink::Sink; 9 use std::collections::VecDeque; 10 use std::io::{self, Write}; 11 use std::pin::Pin; 12 use std::task::Poll::{Pending, Ready}; 13 use std::task::{Context, Poll}; 14 15 macro_rules! mock { 16 ($($x:expr,)*) => {{ 17 let mut v = VecDeque::new(); 18 v.extend(vec![$($x),*]); 19 Mock { calls: v } 20 }}; 21 } 22 23 macro_rules! pin { 24 ($id:ident) => { 25 Pin::new(&mut $id) 26 }; 27 } 28 29 struct U32Encoder; 30 31 impl Encoder<u32> for U32Encoder { 32 type Error = io::Error; 33 encode(&mut self, item: u32, dst: &mut BytesMut) -> io::Result<()>34 fn encode(&mut self, item: u32, dst: &mut BytesMut) -> io::Result<()> { 35 // Reserve space 36 dst.reserve(4); 37 dst.put_u32(item); 38 Ok(()) 39 } 40 } 41 42 struct U64Encoder; 43 44 impl Encoder<u64> for U64Encoder { 45 type Error = io::Error; 46 encode(&mut self, item: u64, dst: &mut BytesMut) -> io::Result<()>47 fn encode(&mut self, item: u64, dst: &mut BytesMut) -> io::Result<()> { 48 // Reserve space 49 dst.reserve(8); 50 dst.put_u64(item); 51 Ok(()) 52 } 53 } 54 55 #[test] write_multi_frame_in_packet()56 fn write_multi_frame_in_packet() { 57 let mut task = task::spawn(()); 58 let mock = mock! { 59 Ok(b"\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02".to_vec()), 60 }; 61 let mut framed = FramedWrite::new(mock, U32Encoder); 62 63 task.enter(|cx, _| { 64 assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); 65 assert!(pin!(framed).start_send(0).is_ok()); 66 assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); 67 assert!(pin!(framed).start_send(1).is_ok()); 68 assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); 69 assert!(pin!(framed).start_send(2).is_ok()); 70 71 // Nothing written yet 72 assert_eq!(1, framed.get_ref().calls.len()); 73 74 // Flush the writes 75 assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok()); 76 77 assert_eq!(0, framed.get_ref().calls.len()); 78 }); 79 } 80 81 #[test] write_multi_frame_after_codec_changed()82 fn write_multi_frame_after_codec_changed() { 83 let mut task = task::spawn(()); 84 let mock = mock! { 85 Ok(b"\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x08".to_vec()), 86 }; 87 let mut framed = FramedWrite::new(mock, U32Encoder); 88 89 task.enter(|cx, _| { 90 assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); 91 assert!(pin!(framed).start_send(0x04).is_ok()); 92 93 let mut framed = framed.map_encoder(|_| U64Encoder); 94 assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); 95 assert!(pin!(framed).start_send(0x08).is_ok()); 96 97 // Nothing written yet 98 assert_eq!(1, framed.get_ref().calls.len()); 99 100 // Flush the writes 101 assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok()); 102 103 assert_eq!(0, framed.get_ref().calls.len()); 104 }); 105 } 106 107 #[test] write_hits_backpressure()108 fn write_hits_backpressure() { 109 const ITER: usize = 2 * 1024; 110 111 let mut mock = mock! { 112 // Block the `ITER*2`th write 113 Err(io::Error::new(io::ErrorKind::WouldBlock, "not ready")), 114 Ok(b"".to_vec()), 115 }; 116 117 for i in 0..=ITER * 2 { 118 let mut b = BytesMut::with_capacity(4); 119 b.put_u32(i as u32); 120 121 // Append to the end 122 match mock.calls.back_mut().unwrap() { 123 Ok(ref mut data) => { 124 // Write in 2kb chunks 125 if data.len() < ITER { 126 data.extend_from_slice(&b[..]); 127 continue; 128 } // else fall through and create a new buffer 129 } 130 _ => unreachable!(), 131 } 132 133 // Push a new chunk 134 mock.calls.push_back(Ok(b[..].to_vec())); 135 } 136 // 1 'wouldblock', 8 * 2KB buffers, 1 b-byte buffer 137 assert_eq!(mock.calls.len(), 10); 138 139 let mut task = task::spawn(()); 140 let mut framed = FramedWrite::new(mock, U32Encoder); 141 framed.set_backpressure_boundary(ITER * 8); 142 task.enter(|cx, _| { 143 // Send 16KB. This fills up FramedWrite buffer 144 for i in 0..ITER * 2 { 145 assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); 146 assert!(pin!(framed).start_send(i as u32).is_ok()); 147 } 148 149 // Now we poll_ready which forces a flush. The mock pops the front message 150 // and decides to block. 151 assert!(pin!(framed).poll_ready(cx).is_pending()); 152 153 // We poll again, forcing another flush, which this time succeeds 154 // The whole 16KB buffer is flushed 155 assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); 156 157 // Send more data. This matches the final message expected by the mock 158 assert!(pin!(framed).start_send((ITER * 2) as u32).is_ok()); 159 160 // Flush the rest of the buffer 161 assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok()); 162 163 // Ensure the mock is empty 164 assert_eq!(0, framed.get_ref().calls.len()); 165 }) 166 } 167 168 // // ===== Mock ====== 169 170 struct Mock { 171 calls: VecDeque<io::Result<Vec<u8>>>, 172 } 173 174 impl Write for Mock { write(&mut self, src: &[u8]) -> io::Result<usize>175 fn write(&mut self, src: &[u8]) -> io::Result<usize> { 176 match self.calls.pop_front() { 177 Some(Ok(data)) => { 178 assert!(src.len() >= data.len()); 179 assert_eq!(&data[..], &src[..data.len()]); 180 Ok(data.len()) 181 } 182 Some(Err(e)) => Err(e), 183 None => panic!("unexpected write; {src:?}"), 184 } 185 } 186 flush(&mut self) -> io::Result<()>187 fn flush(&mut self) -> io::Result<()> { 188 Ok(()) 189 } 190 } 191 192 impl AsyncWrite for Mock { poll_write( self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8], ) -> Poll<Result<usize, io::Error>>193 fn poll_write( 194 self: Pin<&mut Self>, 195 _cx: &mut Context<'_>, 196 buf: &[u8], 197 ) -> Poll<Result<usize, io::Error>> { 198 match Pin::get_mut(self).write(buf) { 199 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Pending, 200 other => Ready(other), 201 } 202 } poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>203 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { 204 match Pin::get_mut(self).flush() { 205 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Pending, 206 other => Ready(other), 207 } 208 } poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>209 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { 210 unimplemented!() 211 } 212 } 213