1 //! Contains utilities for stdout and stderr.
2 use crate::io::AsyncWrite;
3 use std::pin::Pin;
4 use std::task::{Context, Poll};
/// [`AsyncWrite`] adapter that keeps oversized text writes from ending in the
/// middle of a UTF-8 character.
///
/// # Windows
///
/// A write longer than `DEFAULT_MAX_BUF_SIZE` is first cut down to that size.
/// If the beginning of the buffer looks like UTF-8, the end of the buffer is
/// then trimmed back to a character boundary and the remainder is left
/// unwritten, so text handed to the wrapped writer is never split inside a
/// character. Buffers that do not look like text are only cut down to
/// `DEFAULT_MAX_BUF_SIZE`.
///
/// # Other platforms
///
/// Data is passed to `inner` unchanged.
#[derive(Debug)]
pub(crate) struct SplitByUtf8BoundaryIfWindows<W> {
    /// The wrapped writer that receives the (possibly shortened) buffers.
    inner: W,
}
15 
16 impl<W> SplitByUtf8BoundaryIfWindows<W> {
new(inner: W) -> Self17     pub(crate) fn new(inner: W) -> Self {
18         Self { inner }
19     }
20 }
21 
/// Longest possible encoding of a single character, in bytes. The value is
/// fixed by the Unicode standard: UTF-8 never needs more than four bytes.
const MAX_BYTES_PER_CHAR: usize = 4;

/// Number of maximum-length characters' worth of bytes inspected at the start
/// of a buffer when guessing whether it holds text. This is a heuristic knob
/// and may be tweaked.
const MAGIC_CONST: usize = 8;
27 
28 impl<W> crate::io::AsyncWrite for SplitByUtf8BoundaryIfWindows<W>
29 where
30     W: AsyncWrite + Unpin,
31 {
poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, mut buf: &[u8], ) -> Poll<Result<usize, std::io::Error>>32     fn poll_write(
33         mut self: Pin<&mut Self>,
34         cx: &mut Context<'_>,
35         mut buf: &[u8],
36     ) -> Poll<Result<usize, std::io::Error>> {
37         // just a closure to avoid repetitive code
38         let mut call_inner = move |buf| Pin::new(&mut self.inner).poll_write(cx, buf);
39 
40         // 1. Only windows stdio can suffer from non-utf8.
41         // We also check for `test` so that we can write some tests
42         // for further code. Since `AsyncWrite` can always shrink
43         // buffer at its discretion, excessive (i.e. in tests) shrinking
44         // does not break correctness.
45         // 2. If buffer is small, it will not be shrunk.
46         // That's why, it's "textness" will not change, so we don't have
47         // to fixup it.
48         if cfg!(not(any(target_os = "windows", test)))
49             || buf.len() <= crate::io::blocking::DEFAULT_MAX_BUF_SIZE
50         {
51             return call_inner(buf);
52         }
53 
54         buf = &buf[..crate::io::blocking::DEFAULT_MAX_BUF_SIZE];
55 
56         // Now there are two possibilities.
57         // If caller gave is binary buffer, we **should not** shrink it
58         // anymore, because excessive shrinking hits performance.
59         // If caller gave as binary buffer, we  **must** additionally
60         // shrink it to strip incomplete char at the end of buffer.
61         // that's why check we will perform now is allowed to have
62         // false-positive.
63 
64         // Now let's look at the first MAX_BYTES_PER_CHAR * MAGIC_CONST bytes.
65         // if they are (possibly incomplete) utf8, then we can be quite sure
66         // that input buffer was utf8.
67 
68         let have_to_fix_up = match std::str::from_utf8(&buf[..MAX_BYTES_PER_CHAR * MAGIC_CONST]) {
69             Ok(_) => true,
70             Err(err) => {
71                 let incomplete_bytes = MAX_BYTES_PER_CHAR * MAGIC_CONST - err.valid_up_to();
72                 incomplete_bytes < MAX_BYTES_PER_CHAR
73             }
74         };
75 
76         if have_to_fix_up {
77             // We must pop several bytes at the end which form incomplete
78             // character. To achieve it, we exploit UTF8 encoding:
79             // for any code point, all bytes except first start with 0b10 prefix.
80             // see https://en.wikipedia.org/wiki/UTF-8#Encoding for details
81             let trailing_incomplete_char_size = buf
82                 .iter()
83                 .rev()
84                 .take(MAX_BYTES_PER_CHAR)
85                 .position(|byte| *byte < 0b1000_0000 || *byte >= 0b1100_0000)
86                 .unwrap_or(0)
87                 + 1;
88             buf = &buf[..buf.len() - trailing_incomplete_char_size];
89         }
90 
91         call_inner(buf)
92     }
93 
poll_flush( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Result<(), std::io::Error>>94     fn poll_flush(
95         mut self: Pin<&mut Self>,
96         cx: &mut Context<'_>,
97     ) -> Poll<Result<(), std::io::Error>> {
98         Pin::new(&mut self.inner).poll_flush(cx)
99     }
100 
poll_shutdown( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Result<(), std::io::Error>>101     fn poll_shutdown(
102         mut self: Pin<&mut Self>,
103         cx: &mut Context<'_>,
104     ) -> Poll<Result<(), std::io::Error>> {
105         Pin::new(&mut self.inner).poll_shutdown(cx)
106     }
107 }
108 
#[cfg(test)]
#[cfg(not(loom))]
mod tests {
    use crate::io::blocking::DEFAULT_MAX_BUF_SIZE;
    use crate::io::AsyncWriteExt;
    use std::io;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    /// Mock writer that accepts every write in full, panicking if a write is
    /// longer than `DEFAULT_MAX_BUF_SIZE` or is not valid UTF-8.
    struct TextMockWriter;

    impl crate::io::AsyncWrite for TextMockWriter {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let len = buf.len();
            assert!(len <= DEFAULT_MAX_BUF_SIZE);
            assert!(std::str::from_utf8(buf).is_ok());
            Poll::Ready(Ok(len))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    /// Mock writer that accepts every write in full and records the length of
    /// each one, panicking if a write is longer than `DEFAULT_MAX_BUF_SIZE`.
    struct LoggingMockWriter {
        /// Length of every buffer received so far, in order.
        write_history: Vec<usize>,
    }

    impl LoggingMockWriter {
        /// Creates a writer with an empty history.
        fn new() -> Self {
            Self {
                write_history: Vec::new(),
            }
        }
    }

    impl crate::io::AsyncWrite for LoggingMockWriter {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let len = buf.len();
            assert!(len <= DEFAULT_MAX_BUF_SIZE);
            self.write_history.push(len);
            Poll::Ready(Ok(len))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_splitter() {
        // Multi-byte text several times larger than the buffer limit; the
        // mock panics if any single write is not valid UTF-8.
        let text = "█".repeat(DEFAULT_MAX_BUF_SIZE);
        let mut splitter = super::SplitByUtf8BoundaryIfWindows::new(TextMockWriter);
        let rt = crate::runtime::Builder::new_current_thread()
            .build()
            .unwrap();
        rt.block_on(async move {
            splitter.write_all(text.as_bytes()).await.unwrap();
        });
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_pseudo_text() {
        // Binary data whose beginning nevertheless looks like text. Even in
        // this corner case the buffer must not be shrunk by much.
        let text_prefix_len = super::MAGIC_CONST * super::MAX_BYTES_PER_CHAR;
        let mut data = vec![b'a'; text_prefix_len];
        // Pad with continuation-style bytes to one byte past the limit.
        data.resize(DEFAULT_MAX_BUF_SIZE + 1, 0b1010_1010);

        let mut writer = LoggingMockWriter::new();
        {
            let mut splitter = super::SplitByUtf8BoundaryIfWindows::new(&mut writer);
            let rt = crate::runtime::Builder::new_current_thread()
                .build()
                .unwrap();
            rt.block_on(async {
                splitter.write_all(&data).await.unwrap();
            });
        }

        let history = &writer.write_history;
        // At most two writes were performed.
        assert!(history.len() <= 2);
        // Everything has been written.
        assert_eq!(history.iter().sum::<usize>(), data.len());
        // At most MAX_BYTES_PER_CHAR + 1 (i.e. 5) bytes were held back from
        // the first write: one byte that lies beyond DEFAULT_MAX_BUF_SIZE plus
        // up to one UTF-8 code point.
        assert!(data.len() - history[0] <= super::MAX_BYTES_PER_CHAR + 1);
    }
}
223