1 use super::compression::{compress, CompressionEncoding, SingleMessageCompressionOverride};
2 use super::{EncodeBuf, Encoder, DEFAULT_MAX_SEND_MESSAGE_SIZE, HEADER_SIZE};
3 use crate::{Code, Status};
4 use bytes::{BufMut, Bytes, BytesMut};
5 use http::HeaderMap;
6 use http_body::Body;
7 use pin_project::pin_project;
8 use std::{
9     pin::Pin,
10     task::{ready, Context, Poll},
11 };
12 use tokio_stream::{Stream, StreamExt};
13 
/// Initial capacity, in bytes, of the encode buffer; in this file it is also the
/// initial capacity of the compression scratch buffer when compression is enabled.
pub(super) const BUFFER_SIZE: usize = 8 * 1024;
/// Once the encode buffer holds at least this many bytes it is yielded, even if
/// the source stream still has more messages ready.
const YIELD_THRESHOLD: usize = 32 * 1024;
16 
encode_server<T, U>( encoder: T, source: U, compression_encoding: Option<CompressionEncoding>, compression_override: SingleMessageCompressionOverride, max_message_size: Option<usize>, ) -> EncodeBody<impl Stream<Item = Result<Bytes, Status>>> where T: Encoder<Error = Status>, U: Stream<Item = Result<T::Item, Status>>,17 pub(crate) fn encode_server<T, U>(
18     encoder: T,
19     source: U,
20     compression_encoding: Option<CompressionEncoding>,
21     compression_override: SingleMessageCompressionOverride,
22     max_message_size: Option<usize>,
23 ) -> EncodeBody<impl Stream<Item = Result<Bytes, Status>>>
24 where
25     T: Encoder<Error = Status>,
26     U: Stream<Item = Result<T::Item, Status>>,
27 {
28     let stream = EncodedBytes::new(
29         encoder,
30         source.fuse(),
31         compression_encoding,
32         compression_override,
33         max_message_size,
34     );
35 
36     EncodeBody::new_server(stream)
37 }
38 
encode_client<T, U>( encoder: T, source: U, compression_encoding: Option<CompressionEncoding>, max_message_size: Option<usize>, ) -> EncodeBody<impl Stream<Item = Result<Bytes, Status>>> where T: Encoder<Error = Status>, U: Stream<Item = T::Item>,39 pub(crate) fn encode_client<T, U>(
40     encoder: T,
41     source: U,
42     compression_encoding: Option<CompressionEncoding>,
43     max_message_size: Option<usize>,
44 ) -> EncodeBody<impl Stream<Item = Result<Bytes, Status>>>
45 where
46     T: Encoder<Error = Status>,
47     U: Stream<Item = T::Item>,
48 {
49     let stream = EncodedBytes::new(
50         encoder,
51         source.fuse().map(Ok),
52         compression_encoding,
53         SingleMessageCompressionOverride::default(),
54         max_message_size,
55     );
56     EncodeBody::new_client(stream)
57 }
58 
59 /// Combinator for efficient encoding of messages into reasonably sized buffers.
60 /// EncodedBytes encodes ready messages from its delegate stream into a BytesMut,
61 /// splitting off and yielding a buffer when either:
62 ///  * The delegate stream polls as not ready, or
63 ///  * The encoded buffer surpasses YIELD_THRESHOLD.
64 #[pin_project(project = EncodedBytesProj)]
65 #[derive(Debug)]
66 pub(crate) struct EncodedBytes<T, U>
67 where
68     T: Encoder<Error = Status>,
69     U: Stream<Item = Result<T::Item, Status>>,
70 {
71     #[pin]
72     source: U,
73     encoder: T,
74     compression_encoding: Option<CompressionEncoding>,
75     max_message_size: Option<usize>,
76     buf: BytesMut,
77     uncompression_buf: BytesMut,
78 }
79 
80 impl<T, U> EncodedBytes<T, U>
81 where
82     T: Encoder<Error = Status>,
83     U: Stream<Item = Result<T::Item, Status>>,
84 {
85     // `source` should be fused stream.
new( encoder: T, source: U, compression_encoding: Option<CompressionEncoding>, compression_override: SingleMessageCompressionOverride, max_message_size: Option<usize>, ) -> Self86     fn new(
87         encoder: T,
88         source: U,
89         compression_encoding: Option<CompressionEncoding>,
90         compression_override: SingleMessageCompressionOverride,
91         max_message_size: Option<usize>,
92     ) -> Self {
93         let buf = BytesMut::with_capacity(BUFFER_SIZE);
94 
95         let compression_encoding =
96             if compression_override == SingleMessageCompressionOverride::Disable {
97                 None
98             } else {
99                 compression_encoding
100             };
101 
102         let uncompression_buf = if compression_encoding.is_some() {
103             BytesMut::with_capacity(BUFFER_SIZE)
104         } else {
105             BytesMut::new()
106         };
107 
108         Self {
109             source,
110             encoder,
111             compression_encoding,
112             max_message_size,
113             buf,
114             uncompression_buf,
115         }
116     }
117 }
118 
119 impl<T, U> Stream for EncodedBytes<T, U>
120 where
121     T: Encoder<Error = Status>,
122     U: Stream<Item = Result<T::Item, Status>>,
123 {
124     type Item = Result<Bytes, Status>;
125 
poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>126     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
127         let EncodedBytesProj {
128             mut source,
129             encoder,
130             compression_encoding,
131             max_message_size,
132             buf,
133             uncompression_buf,
134         } = self.project();
135 
136         loop {
137             match source.as_mut().poll_next(cx) {
138                 Poll::Pending if buf.is_empty() => {
139                     return Poll::Pending;
140                 }
141                 Poll::Ready(None) if buf.is_empty() => {
142                     return Poll::Ready(None);
143                 }
144                 Poll::Pending | Poll::Ready(None) => {
145                     return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze())));
146                 }
147                 Poll::Ready(Some(Ok(item))) => {
148                     if let Err(status) = encode_item(
149                         encoder,
150                         buf,
151                         uncompression_buf,
152                         *compression_encoding,
153                         *max_message_size,
154                         item,
155                     ) {
156                         return Poll::Ready(Some(Err(status)));
157                     }
158 
159                     if buf.len() >= YIELD_THRESHOLD {
160                         return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze())));
161                     }
162                 }
163                 Poll::Ready(Some(Err(status))) => {
164                     return Poll::Ready(Some(Err(status)));
165                 }
166             }
167         }
168     }
169 }
170 
encode_item<T>( encoder: &mut T, buf: &mut BytesMut, uncompression_buf: &mut BytesMut, compression_encoding: Option<CompressionEncoding>, max_message_size: Option<usize>, item: T::Item, ) -> Result<(), Status> where T: Encoder<Error = Status>,171 fn encode_item<T>(
172     encoder: &mut T,
173     buf: &mut BytesMut,
174     uncompression_buf: &mut BytesMut,
175     compression_encoding: Option<CompressionEncoding>,
176     max_message_size: Option<usize>,
177     item: T::Item,
178 ) -> Result<(), Status>
179 where
180     T: Encoder<Error = Status>,
181 {
182     let offset = buf.len();
183 
184     buf.reserve(HEADER_SIZE);
185     unsafe {
186         buf.advance_mut(HEADER_SIZE);
187     }
188 
189     if let Some(encoding) = compression_encoding {
190         uncompression_buf.clear();
191 
192         encoder
193             .encode(item, &mut EncodeBuf::new(uncompression_buf))
194             .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?;
195 
196         let uncompressed_len = uncompression_buf.len();
197 
198         compress(encoding, uncompression_buf, buf, uncompressed_len)
199             .map_err(|err| Status::internal(format!("Error compressing: {}", err)))?;
200     } else {
201         encoder
202             .encode(item, &mut EncodeBuf::new(buf))
203             .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?;
204     }
205 
206     // now that we know length, we can write the header
207     finish_encoding(compression_encoding, max_message_size, &mut buf[offset..])
208 }
209 
finish_encoding( compression_encoding: Option<CompressionEncoding>, max_message_size: Option<usize>, buf: &mut [u8], ) -> Result<(), Status>210 fn finish_encoding(
211     compression_encoding: Option<CompressionEncoding>,
212     max_message_size: Option<usize>,
213     buf: &mut [u8],
214 ) -> Result<(), Status> {
215     let len = buf.len() - HEADER_SIZE;
216     let limit = max_message_size.unwrap_or(DEFAULT_MAX_SEND_MESSAGE_SIZE);
217     if len > limit {
218         return Err(Status::new(
219             Code::OutOfRange,
220             format!(
221                 "Error, message length too large: found {} bytes, the limit is: {} bytes",
222                 len, limit
223             ),
224         ));
225     }
226 
227     if len > std::u32::MAX as usize {
228         return Err(Status::resource_exhausted(format!(
229             "Cannot return body with more than 4GB of data but got {len} bytes"
230         )));
231     }
232     {
233         let mut buf = &mut buf[..HEADER_SIZE];
234         buf.put_u8(compression_encoding.is_some() as u8);
235         buf.put_u32(len as u32);
236     }
237 
238     Ok(())
239 }
240 
/// Which side of the RPC an `EncodeBody` encodes for; it decides how errors from
/// the inner stream are surfaced by `poll_data` and `EncodeState::trailers`.
#[derive(Debug)]
enum Role {
    /// Request body: stream errors are returned from `poll_data`; no trailers are produced.
    Client,
    /// Response body: stream errors end the data and are reported through the trailers.
    Server,
}
246 
/// An `http_body::Body` that forwards the encoded buffers produced by `inner`,
/// handling errors from `inner` according to its `Role`.
#[pin_project]
#[derive(Debug)]
pub(crate) struct EncodeBody<S> {
    /// Stream of encoded buffers (or errors) forwarded as body data.
    #[pin]
    inner: S,
    /// Role and end-of-stream bookkeeping used by the `Body` methods.
    state: EncodeState,
}

/// Unpinned bookkeeping for `EncodeBody`.
#[derive(Debug)]
struct EncodeState {
    /// Server role only: the error from `inner`, held until the trailers are produced.
    error: Option<Status>,
    /// Whether this body is a client request or a server response.
    role: Role,
    /// Returned by `Body::is_end_stream`; set to `true` by `EncodeState::trailers`
    /// on the server side.
    is_end_stream: bool,
}
261 
262 impl<S> EncodeBody<S>
263 where
264     S: Stream<Item = Result<Bytes, Status>>,
265 {
new_client(inner: S) -> Self266     pub(crate) fn new_client(inner: S) -> Self {
267         Self {
268             inner,
269             state: EncodeState {
270                 error: None,
271                 role: Role::Client,
272                 is_end_stream: false,
273             },
274         }
275     }
276 
new_server(inner: S) -> Self277     pub(crate) fn new_server(inner: S) -> Self {
278         Self {
279             inner,
280             state: EncodeState {
281                 error: None,
282                 role: Role::Server,
283                 is_end_stream: false,
284             },
285         }
286     }
287 }
288 
289 impl EncodeState {
trailers(&mut self) -> Result<Option<HeaderMap>, Status>290     fn trailers(&mut self) -> Result<Option<HeaderMap>, Status> {
291         match self.role {
292             Role::Client => Ok(None),
293             Role::Server => {
294                 if self.is_end_stream {
295                     return Ok(None);
296                 }
297 
298                 let status = if let Some(status) = self.error.take() {
299                     self.is_end_stream = true;
300                     status
301                 } else {
302                     Status::new(Code::Ok, "")
303                 };
304 
305                 Ok(Some(status.to_header_map()?))
306             }
307         }
308     }
309 }
310 
311 impl<S> Body for EncodeBody<S>
312 where
313     S: Stream<Item = Result<Bytes, Status>>,
314 {
315     type Data = Bytes;
316     type Error = Status;
317 
is_end_stream(&self) -> bool318     fn is_end_stream(&self) -> bool {
319         self.state.is_end_stream
320     }
321 
poll_data( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Option<Result<Self::Data, Self::Error>>>322     fn poll_data(
323         self: Pin<&mut Self>,
324         cx: &mut Context<'_>,
325     ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
326         let self_proj = self.project();
327         match ready!(self_proj.inner.poll_next(cx)) {
328             Some(Ok(d)) => Some(Ok(d)).into(),
329             Some(Err(status)) => match self_proj.state.role {
330                 Role::Client => Some(Err(status)).into(),
331                 Role::Server => {
332                     self_proj.state.error = Some(status);
333                     None.into()
334                 }
335             },
336             None => None.into(),
337         }
338     }
339 
poll_trailers( self: Pin<&mut Self>, _cx: &mut Context<'_>, ) -> Poll<Result<Option<HeaderMap>, Status>>340     fn poll_trailers(
341         self: Pin<&mut Self>,
342         _cx: &mut Context<'_>,
343     ) -> Poll<Result<Option<HeaderMap>, Status>> {
344         Poll::Ready(self.project().state.trailers())
345     }
346 }
347