1  use bytes::Buf;
2  use futures_core::stream::Stream;
3  use futures_sink::Sink;
4  use std::io;
5  use std::pin::Pin;
6  use std::task::{Context, Poll};
7  use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf};
8  
/// Convert a [`Stream`] of byte chunks into an [`AsyncRead`].
///
/// This type performs the inverse operation of [`ReaderStream`].
///
/// This type also implements the [`AsyncBufRead`] trait, so you can use it
/// to read a `Stream` of byte chunks line-by-line. See the examples below.
///
/// # Examples
///
/// ```
/// use bytes::Bytes;
/// use tokio::io::{AsyncReadExt, Result};
/// use tokio_util::io::StreamReader;
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() -> std::io::Result<()> {
///
/// // Create a stream from an iterator.
/// let stream = tokio_stream::iter(vec![
///     Result::Ok(Bytes::from_static(&[0, 1, 2, 3])),
///     Result::Ok(Bytes::from_static(&[4, 5, 6, 7])),
///     Result::Ok(Bytes::from_static(&[8, 9, 10, 11])),
/// ]);
///
/// // Convert it to an AsyncRead.
/// let mut read = StreamReader::new(stream);
///
/// // Read five bytes from the stream.
/// let mut buf = [0; 5];
/// read.read_exact(&mut buf).await?;
/// assert_eq!(buf, [0, 1, 2, 3, 4]);
///
/// // Read the rest of the current chunk.
/// assert_eq!(read.read(&mut buf).await?, 3);
/// assert_eq!(&buf[..3], [5, 6, 7]);
///
/// // Read the next chunk.
/// assert_eq!(read.read(&mut buf).await?, 4);
/// assert_eq!(&buf[..4], [8, 9, 10, 11]);
///
/// // We have now reached the end.
/// assert_eq!(read.read(&mut buf).await?, 0);
///
/// # Ok(())
/// # }
/// ```
///
/// If the stream produces errors which are not [`std::io::Error`],
/// the errors can be converted using [`StreamExt`] to map each
/// element.
///
/// ```
/// use bytes::Bytes;
/// use tokio::io::AsyncReadExt;
/// use tokio_util::io::StreamReader;
/// use tokio_stream::StreamExt;
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() -> std::io::Result<()> {
///
/// // Create a stream from an iterator, including an error.
/// let stream = tokio_stream::iter(vec![
///     Result::Ok(Bytes::from_static(&[0, 1, 2, 3])),
///     Result::Ok(Bytes::from_static(&[4, 5, 6, 7])),
///     Result::Err("Something bad happened!")
/// ]);
///
/// // Use StreamExt to map the stream and error to a std::io::Error
/// let stream = stream.map(|result| result.map_err(|err| {
///     std::io::Error::new(std::io::ErrorKind::Other, err)
/// }));
///
/// // Convert it to an AsyncRead.
/// let mut read = StreamReader::new(stream);
///
/// // Read five bytes from the stream.
/// let mut buf = [0; 5];
/// read.read_exact(&mut buf).await?;
/// assert_eq!(buf, [0, 1, 2, 3, 4]);
///
/// // Read the rest of the current chunk.
/// assert_eq!(read.read(&mut buf).await?, 3);
/// assert_eq!(&buf[..3], [5, 6, 7]);
///
/// // Reading the next chunk will produce an error
/// let error = read.read(&mut buf).await.unwrap_err();
/// assert_eq!(error.kind(), std::io::ErrorKind::Other);
/// assert_eq!(error.into_inner().unwrap().to_string(), "Something bad happened!");
///
/// // We have now reached the end.
/// assert_eq!(read.read(&mut buf).await?, 0);
///
/// # Ok(())
/// # }
/// ```
///
/// Using the [`AsyncBufRead`] impl, you can read a `Stream` of byte chunks
/// line-by-line. Note that you will usually also need to convert the error
/// type when doing this. See the second example for an explanation of how
/// to do this.
///
/// ```
/// use tokio::io::{Result, AsyncBufReadExt};
/// use tokio_util::io::StreamReader;
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() -> std::io::Result<()> {
///
/// // Create a stream of byte chunks.
/// let stream = tokio_stream::iter(vec![
///     Result::Ok(b"The first line.\n".as_slice()),
///     Result::Ok(b"The second line.".as_slice()),
///     Result::Ok(b"\nThe third".as_slice()),
///     Result::Ok(b" line.\nThe fourth line.\nThe fifth line.\n".as_slice()),
/// ]);
///
/// // Convert it to an AsyncRead.
/// let mut read = StreamReader::new(stream);
///
/// // Loop through the lines from the `StreamReader`.
/// let mut line = String::new();
/// let mut lines = Vec::new();
/// loop {
///     line.clear();
///     let len = read.read_line(&mut line).await?;
///     if len == 0 { break; }
///     lines.push(line.clone());
/// }
///
/// // Verify that we got the lines we expected.
/// assert_eq!(
///     lines,
///     vec![
///         "The first line.\n",
///         "The second line.\n",
///         "The third line.\n",
///         "The fourth line.\n",
///         "The fifth line.\n",
///     ]
/// );
/// # Ok(())
/// # }
/// ```
///
/// [`AsyncRead`]: tokio::io::AsyncRead
/// [`AsyncBufRead`]: tokio::io::AsyncBufRead
/// [`Stream`]: futures_core::Stream
/// [`ReaderStream`]: crate::io::ReaderStream
/// [`StreamExt`]: https://docs.rs/tokio-stream/latest/tokio_stream/trait.StreamExt.html
#[derive(Debug)]
pub struct StreamReader<S, B> {
    // The wrapped stream that chunks are pulled from. This field is pinned:
    // `project` only ever hands it out as `Pin<&mut S>`.
    inner: S,
    // The chunk currently being read from, if any. This field is not pinned:
    // `project` hands it out as a plain `&mut Option<B>`. It starts out as
    // `None`, and a `Some` may hold a chunk with no bytes remaining, so
    // emptiness has to be checked separately (see `has_chunk`).
    chunk: Option<B>,
}
162  
163  impl<S, B, E> StreamReader<S, B>
164  where
165      S: Stream<Item = Result<B, E>>,
166      B: Buf,
167      E: Into<std::io::Error>,
168  {
169      /// Convert a stream of byte chunks into an [`AsyncRead`].
170      ///
171      /// The item should be a [`Result`] with the ok variant being something that
172      /// implements the [`Buf`] trait (e.g. `Vec<u8>` or `Bytes`). The error
173      /// should be convertible into an [io error].
174      ///
175      /// [`Result`]: std::result::Result
176      /// [`Buf`]: bytes::Buf
177      /// [io error]: std::io::Error
new(stream: S) -> Self178      pub fn new(stream: S) -> Self {
179          Self {
180              inner: stream,
181              chunk: None,
182          }
183      }
184  
185      /// Do we have a chunk and is it non-empty?
has_chunk(&self) -> bool186      fn has_chunk(&self) -> bool {
187          if let Some(ref chunk) = self.chunk {
188              chunk.remaining() > 0
189          } else {
190              false
191          }
192      }
193  
194      /// Consumes this `StreamReader`, returning a Tuple consisting
195      /// of the underlying stream and an Option of the internal buffer,
196      /// which is Some in case the buffer contains elements.
into_inner_with_chunk(self) -> (S, Option<B>)197      pub fn into_inner_with_chunk(self) -> (S, Option<B>) {
198          if self.has_chunk() {
199              (self.inner, self.chunk)
200          } else {
201              (self.inner, None)
202          }
203      }
204  }
205  
206  impl<S, B> StreamReader<S, B> {
207      /// Gets a reference to the underlying stream.
208      ///
209      /// It is inadvisable to directly read from the underlying stream.
get_ref(&self) -> &S210      pub fn get_ref(&self) -> &S {
211          &self.inner
212      }
213  
214      /// Gets a mutable reference to the underlying stream.
215      ///
216      /// It is inadvisable to directly read from the underlying stream.
get_mut(&mut self) -> &mut S217      pub fn get_mut(&mut self) -> &mut S {
218          &mut self.inner
219      }
220  
221      /// Gets a pinned mutable reference to the underlying stream.
222      ///
223      /// It is inadvisable to directly read from the underlying stream.
get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S>224      pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
225          self.project().inner
226      }
227  
228      /// Consumes this `BufWriter`, returning the underlying stream.
229      ///
230      /// Note that any leftover data in the internal buffer is lost.
231      /// If you additionally want access to the internal buffer use
232      /// [`into_inner_with_chunk`].
233      ///
234      /// [`into_inner_with_chunk`]: crate::io::StreamReader::into_inner_with_chunk
into_inner(self) -> S235      pub fn into_inner(self) -> S {
236          self.inner
237      }
238  }
239  
240  impl<S, B, E> AsyncRead for StreamReader<S, B>
241  where
242      S: Stream<Item = Result<B, E>>,
243      B: Buf,
244      E: Into<std::io::Error>,
245  {
poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>246      fn poll_read(
247          mut self: Pin<&mut Self>,
248          cx: &mut Context<'_>,
249          buf: &mut ReadBuf<'_>,
250      ) -> Poll<io::Result<()>> {
251          if buf.remaining() == 0 {
252              return Poll::Ready(Ok(()));
253          }
254  
255          let inner_buf = match self.as_mut().poll_fill_buf(cx) {
256              Poll::Ready(Ok(buf)) => buf,
257              Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
258              Poll::Pending => return Poll::Pending,
259          };
260          let len = std::cmp::min(inner_buf.len(), buf.remaining());
261          buf.put_slice(&inner_buf[..len]);
262  
263          self.consume(len);
264          Poll::Ready(Ok(()))
265      }
266  }
267  
268  impl<S, B, E> AsyncBufRead for StreamReader<S, B>
269  where
270      S: Stream<Item = Result<B, E>>,
271      B: Buf,
272      E: Into<std::io::Error>,
273  {
poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>>274      fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
275          loop {
276              if self.as_mut().has_chunk() {
277                  // This unwrap is very sad, but it can't be avoided.
278                  let buf = self.project().chunk.as_ref().unwrap().chunk();
279                  return Poll::Ready(Ok(buf));
280              } else {
281                  match self.as_mut().project().inner.poll_next(cx) {
282                      Poll::Ready(Some(Ok(chunk))) => {
283                          // Go around the loop in case the chunk is empty.
284                          *self.as_mut().project().chunk = Some(chunk);
285                      }
286                      Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err.into())),
287                      Poll::Ready(None) => return Poll::Ready(Ok(&[])),
288                      Poll::Pending => return Poll::Pending,
289                  }
290              }
291          }
292      }
consume(self: Pin<&mut Self>, amt: usize)293      fn consume(self: Pin<&mut Self>, amt: usize) {
294          if amt > 0 {
295              self.project()
296                  .chunk
297                  .as_mut()
298                  .expect("No chunk present")
299                  .advance(amt);
300          }
301      }
302  }
303  
304  // The code below is a manual expansion of the code that pin-project-lite would
305  // generate. This is done because pin-project-lite fails by hitting the recursion
306  // limit on this struct. (Every line of documentation is handled recursively by
307  // the macro.)
308  
309  impl<S: Unpin, B> Unpin for StreamReader<S, B> {}
310  
// Projection of a pinned `StreamReader`, as returned by
// `StreamReader::project`: one reference per field, each with the pinning
// guarantee appropriate for that field.
struct StreamReaderProject<'a, S, B> {
    // Pinned access to the wrapped stream.
    inner: Pin<&'a mut S>,
    // Ordinary mutable access to the buffered chunk, which is not pinned.
    chunk: &'a mut Option<B>,
}
315  
316  impl<S, B> StreamReader<S, B> {
317      #[inline]
project(self: Pin<&mut Self>) -> StreamReaderProject<'_, S, B>318      fn project(self: Pin<&mut Self>) -> StreamReaderProject<'_, S, B> {
319          // SAFETY: We define that only `inner` should be pinned when `Self` is
320          // and have an appropriate `impl Unpin` for this.
321          let me = unsafe { Pin::into_inner_unchecked(self) };
322          StreamReaderProject {
323              inner: unsafe { Pin::new_unchecked(&mut me.inner) },
324              chunk: &mut me.chunk,
325          }
326      }
327  }
328  
329  impl<S: Sink<T, Error = E>, B, E, T> Sink<T> for StreamReader<S, B> {
330      type Error = E;
poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>331      fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
332          self.project().inner.poll_ready(cx)
333      }
334  
start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error>335      fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
336          self.project().inner.start_send(item)
337      }
338  
poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>339      fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
340          self.project().inner.poll_flush(cx)
341      }
342  
poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>343      fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
344          self.project().inner.poll_close(cx)
345      }
346  }
347