1 use futures_io::AsyncWrite;
2 use futures_sink::Sink;
3 use std::{
4     io::{self, IoSlice},
5     pin::Pin,
6     task::{Context, Poll},
7 };
8 
/// Async wrapper that tracks whether it has been closed.
///
/// The wrapper is considered closed only after the inner object's `poll_close`
/// has completed with `Ok(())`. From then on, the [`AsyncWrite`] methods return
/// an error instead of touching the inner object, and the [`Sink`] methods panic.
///
/// See the `track_closed` methods on:
/// * [`SinkTestExt`](crate::sink::SinkTestExt::track_closed)
/// * [`AsyncWriteTestExt`](crate::io::AsyncWriteTestExt::track_closed)
#[pin_project::pin_project]
#[derive(Debug)]
pub struct TrackClosed<T> {
    // The wrapped writer/sink; structurally pinned so it can be polled through `project()`.
    #[pin]
    inner: T,
    // Set to `true` once a `poll_close` on `inner` has returned `Poll::Ready(Ok(()))`.
    closed: bool,
}
21 
22 impl<T> TrackClosed<T> {
new(inner: T) -> Self23     pub(crate) fn new(inner: T) -> Self {
24         Self { inner, closed: false }
25     }
26 
27     /// Check whether this object has been closed.
is_closed(&self) -> bool28     pub fn is_closed(&self) -> bool {
29         self.closed
30     }
31 
32     /// Acquires a reference to the underlying object that this adaptor is
33     /// wrapping.
get_ref(&self) -> &T34     pub fn get_ref(&self) -> &T {
35         &self.inner
36     }
37 
38     /// Acquires a mutable reference to the underlying object that this
39     /// adaptor is wrapping.
get_mut(&mut self) -> &mut T40     pub fn get_mut(&mut self) -> &mut T {
41         &mut self.inner
42     }
43 
44     /// Acquires a pinned mutable reference to the underlying object that
45     /// this adaptor is wrapping.
get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T>46     pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
47         self.project().inner
48     }
49 
50     /// Consumes this adaptor returning the underlying object.
into_inner(self) -> T51     pub fn into_inner(self) -> T {
52         self.inner
53     }
54 }
55 
56 impl<T: AsyncWrite> AsyncWrite for TrackClosed<T> {
poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>57     fn poll_write(
58         self: Pin<&mut Self>,
59         cx: &mut Context<'_>,
60         buf: &[u8],
61     ) -> Poll<io::Result<usize>> {
62         if self.is_closed() {
63             return Poll::Ready(Err(io::Error::new(
64                 io::ErrorKind::Other,
65                 "Attempted to write after stream was closed",
66             )));
67         }
68         self.project().inner.poll_write(cx, buf)
69     }
70 
poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>71     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
72         if self.is_closed() {
73             return Poll::Ready(Err(io::Error::new(
74                 io::ErrorKind::Other,
75                 "Attempted to flush after stream was closed",
76             )));
77         }
78         assert!(!self.is_closed());
79         self.project().inner.poll_flush(cx)
80     }
81 
poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>82     fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
83         if self.is_closed() {
84             return Poll::Ready(Err(io::Error::new(
85                 io::ErrorKind::Other,
86                 "Attempted to close after stream was closed",
87             )));
88         }
89         let this = self.project();
90         match this.inner.poll_close(cx) {
91             Poll::Ready(Ok(())) => {
92                 *this.closed = true;
93                 Poll::Ready(Ok(()))
94             }
95             other => other,
96         }
97     }
98 
poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>], ) -> Poll<io::Result<usize>>99     fn poll_write_vectored(
100         self: Pin<&mut Self>,
101         cx: &mut Context<'_>,
102         bufs: &[IoSlice<'_>],
103     ) -> Poll<io::Result<usize>> {
104         if self.is_closed() {
105             return Poll::Ready(Err(io::Error::new(
106                 io::ErrorKind::Other,
107                 "Attempted to write after stream was closed",
108             )));
109         }
110         self.project().inner.poll_write_vectored(cx, bufs)
111     }
112 }
113 
114 impl<Item, T: Sink<Item>> Sink<Item> for TrackClosed<T> {
115     type Error = T::Error;
116 
poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>117     fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
118         assert!(!self.is_closed());
119         self.project().inner.poll_ready(cx)
120     }
121 
start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error>122     fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
123         assert!(!self.is_closed());
124         self.project().inner.start_send(item)
125     }
126 
poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>127     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
128         assert!(!self.is_closed());
129         self.project().inner.poll_flush(cx)
130     }
131 
poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>132     fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
133         assert!(!self.is_closed());
134         let this = self.project();
135         match this.inner.poll_close(cx) {
136             Poll::Ready(Ok(())) => {
137                 *this.closed = true;
138                 Poll::Ready(Ok(()))
139             }
140             other => other,
141         }
142     }
143 }
144