1 use futures_core::future::{FusedFuture, Future};
2 use futures_core::stream::{FusedStream, Stream};
3 use futures_io::{
4     self as io, AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, IoSlice, IoSliceMut, SeekFrom,
5 };
6 use futures_sink::Sink;
7 use pin_project::pin_project;
8 use std::{
9     pin::Pin,
10     task::{Context, Poll},
11 };
12 
/// Wrapper that interleaves [`Poll::Pending`] in calls to poll.
///
/// Polling alternates between two states: a poll made while the wrapper has
/// not yet "pended" wakes the task's waker and returns [`Poll::Pending`]
/// without touching the wrapped value; the following polls are forwarded to
/// the wrapped value until it reports readiness, after which the next poll
/// injects a [`Poll::Pending`] again.
///
/// See the `interleave_pending` methods on:
/// * [`FutureTestExt`](crate::future::FutureTestExt::interleave_pending)
/// * [`StreamTestExt`](crate::stream::StreamTestExt::interleave_pending)
/// * [`SinkTestExt`](crate::sink::SinkTestExt::interleave_pending_sink)
/// * [`AsyncReadTestExt`](crate::io::AsyncReadTestExt::interleave_pending)
/// * [`AsyncWriteTestExt`](crate::io::AsyncWriteTestExt::interleave_pending_write)
#[pin_project]
#[derive(Debug)]
pub struct InterleavePending<T> {
    /// The wrapped value; structurally pinned so it can be polled through a
    /// pinned projection.
    #[pin]
    inner: T,
    /// `true` once an artificial [`Poll::Pending`] has been returned and the
    /// next poll should be forwarded to `inner`; reset to `false` when a
    /// forwarded poll reports readiness.
    pended: bool,
}
28 
29 impl<T> InterleavePending<T> {
new(inner: T) -> Self30     pub(crate) fn new(inner: T) -> Self {
31         Self { inner, pended: false }
32     }
33 
34     /// Acquires a reference to the underlying I/O object that this adaptor is
35     /// wrapping.
get_ref(&self) -> &T36     pub fn get_ref(&self) -> &T {
37         &self.inner
38     }
39 
40     /// Acquires a mutable reference to the underlying I/O object that this
41     /// adaptor is wrapping.
get_mut(&mut self) -> &mut T42     pub fn get_mut(&mut self) -> &mut T {
43         &mut self.inner
44     }
45 
46     /// Acquires a pinned mutable reference to the underlying I/O object that
47     /// this adaptor is wrapping.
get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T>48     pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
49         self.project().inner
50     }
51 
52     /// Consumes this adaptor returning the underlying I/O object.
into_inner(self) -> T53     pub fn into_inner(self) -> T {
54         self.inner
55     }
56 
poll_with<'a, U>( self: Pin<&'a mut Self>, cx: &mut Context<'_>, f: impl FnOnce(Pin<&'a mut T>, &mut Context<'_>) -> Poll<U>, ) -> Poll<U>57     fn poll_with<'a, U>(
58         self: Pin<&'a mut Self>,
59         cx: &mut Context<'_>,
60         f: impl FnOnce(Pin<&'a mut T>, &mut Context<'_>) -> Poll<U>,
61     ) -> Poll<U> {
62         let this = self.project();
63         if *this.pended {
64             let next = f(this.inner, cx);
65             if next.is_ready() {
66                 *this.pended = false;
67             }
68             next
69         } else {
70             cx.waker().wake_by_ref();
71             *this.pended = true;
72             Poll::Pending
73         }
74     }
75 }
76 
77 impl<Fut: Future> Future for InterleavePending<Fut> {
78     type Output = Fut::Output;
79 
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>80     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
81         self.poll_with(cx, Fut::poll)
82     }
83 }
84 
85 impl<Fut: FusedFuture> FusedFuture for InterleavePending<Fut> {
is_terminated(&self) -> bool86     fn is_terminated(&self) -> bool {
87         self.inner.is_terminated()
88     }
89 }
90 
91 impl<St: Stream> Stream for InterleavePending<St> {
92     type Item = St::Item;
93 
poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>94     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
95         self.poll_with(cx, St::poll_next)
96     }
97 
size_hint(&self) -> (usize, Option<usize>)98     fn size_hint(&self) -> (usize, Option<usize>) {
99         self.inner.size_hint()
100     }
101 }
102 
103 impl<St: FusedStream> FusedStream for InterleavePending<St> {
is_terminated(&self) -> bool104     fn is_terminated(&self) -> bool {
105         self.inner.is_terminated()
106     }
107 }
108 
109 impl<Si: Sink<Item>, Item> Sink<Item> for InterleavePending<Si> {
110     type Error = Si::Error;
111 
poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>112     fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
113         self.poll_with(cx, Si::poll_ready)
114     }
115 
start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error>116     fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
117         self.project().inner.start_send(item)
118     }
119 
poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>120     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
121         self.poll_with(cx, Si::poll_flush)
122     }
123 
poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>124     fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
125         self.poll_with(cx, Si::poll_close)
126     }
127 }
128 
129 impl<R: AsyncRead> AsyncRead for InterleavePending<R> {
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll<io::Result<usize>>130     fn poll_read(
131         self: Pin<&mut Self>,
132         cx: &mut Context<'_>,
133         buf: &mut [u8],
134     ) -> Poll<io::Result<usize>> {
135         self.poll_with(cx, |r, cx| r.poll_read(cx, buf))
136     }
137 
poll_read_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &mut [IoSliceMut<'_>], ) -> Poll<io::Result<usize>>138     fn poll_read_vectored(
139         self: Pin<&mut Self>,
140         cx: &mut Context<'_>,
141         bufs: &mut [IoSliceMut<'_>],
142     ) -> Poll<io::Result<usize>> {
143         self.poll_with(cx, |r, cx| r.poll_read_vectored(cx, bufs))
144     }
145 }
146 
147 impl<W: AsyncWrite> AsyncWrite for InterleavePending<W> {
poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>148     fn poll_write(
149         self: Pin<&mut Self>,
150         cx: &mut Context<'_>,
151         buf: &[u8],
152     ) -> Poll<io::Result<usize>> {
153         self.poll_with(cx, |w, cx| w.poll_write(cx, buf))
154     }
155 
poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>], ) -> Poll<io::Result<usize>>156     fn poll_write_vectored(
157         self: Pin<&mut Self>,
158         cx: &mut Context<'_>,
159         bufs: &[IoSlice<'_>],
160     ) -> Poll<io::Result<usize>> {
161         self.poll_with(cx, |w, cx| w.poll_write_vectored(cx, bufs))
162     }
163 
poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>164     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
165         self.poll_with(cx, W::poll_flush)
166     }
167 
poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>168     fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
169         self.poll_with(cx, W::poll_close)
170     }
171 }
172 
173 impl<S: AsyncSeek> AsyncSeek for InterleavePending<S> {
poll_seek( self: Pin<&mut Self>, cx: &mut Context<'_>, pos: SeekFrom, ) -> Poll<io::Result<u64>>174     fn poll_seek(
175         self: Pin<&mut Self>,
176         cx: &mut Context<'_>,
177         pos: SeekFrom,
178     ) -> Poll<io::Result<u64>> {
179         self.poll_with(cx, |s, cx| s.poll_seek(cx, pos))
180     }
181 }
182 
183 impl<R: AsyncBufRead> AsyncBufRead for InterleavePending<R> {
poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>>184     fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
185         self.poll_with(cx, R::poll_fill_buf)
186     }
187 
consume(self: Pin<&mut Self>, amount: usize)188     fn consume(self: Pin<&mut Self>, amount: usize) {
189         self.project().inner.consume(amount)
190     }
191 }
192