1 use super::compression::{decompress, CompressionEncoding};
2 use super::{DecodeBuf, Decoder, DEFAULT_MAX_RECV_MESSAGE_SIZE, HEADER_SIZE};
3 use crate::{body::BoxBody, metadata::MetadataMap, Code, Status};
4 use bytes::{Buf, BufMut, BytesMut};
5 use http::StatusCode;
6 use http_body::Body;
7 use std::{
8     fmt, future,
9     pin::Pin,
10     task::ready,
11     task::{Context, Poll},
12 };
13 use tokio_stream::Stream;
14 use tracing::{debug, trace};
15 
// Initial capacity, in bytes (8 KiB), handed to `BytesMut::with_capacity` for
// the read buffer of every newly created `Streaming`.
const BUFFER_SIZE: usize = 8 * 1024;
17 
/// Streaming requests and responses.
///
/// This will wrap some inner [`Body`] and [`Decoder`] and provide an interface
/// to fetch the message stream and trailing metadata.
pub struct Streaming<T> {
    // Type-erased decoder that turns one framed message into a `T`.
    decoder: Box<dyn Decoder<Item = T, Error = Status> + Send + 'static>,
    // All state that does not depend on `T`: body, buffers, framing state.
    inner: StreamingInner,
}
26 
// Non-generic part of `Streaming`: owns the body and performs the framing.
struct StreamingInner {
    // Boxed body; data chunks are copied into `Bytes` and errors mapped to `Status`.
    body: BoxBody,
    // Where the framing state machine currently is (header, body, or failed).
    state: State,
    // Whether this stream decodes a request, a response, or an empty response.
    direction: Direction,
    // Bytes received from `body` that have not been handed to the decoder yet.
    buf: BytesMut,
    // Trailers cached by `poll_response`; handed out by `Streaming::trailers`.
    trailers: Option<MetadataMap>,
    // Scratch buffer that receives the decompressed bytes of one message.
    decompress_buf: BytesMut,
    // Encoding applied to messages whose compressed-flag is set, if any.
    encoding: Option<CompressionEncoding>,
    // Upper bound for a message length; `None` falls back to
    // `DEFAULT_MAX_RECV_MESSAGE_SIZE`.
    max_message_size: Option<usize>,
}
37 
// `Streaming<T>` is `Unpin` for every `T`; `message` and `poll_next` rely on
// this to build pins with `Pin::new`.
impl<T> Unpin for Streaming<T> {}
39 
// Framing state of the decoder.
#[derive(Debug, Clone, Copy)]
enum State {
    // Waiting for the message prefix: a one-byte compression flag followed by
    // a `u32` payload length.
    ReadHeader,
    // Prefix parsed; waiting until `len` payload bytes are buffered.
    ReadBody {
        // Encoding to decompress the payload with, `None` if it is not compressed.
        compression: Option<CompressionEncoding>,
        // Payload length announced by the prefix.
        len: usize,
    },
    // The body reported an error; `poll_next` yields `None` from now on.
    Error,
}
49 
// Which kind of stream is being decoded; selected by the `new_*` constructors.
#[derive(Debug, PartialEq, Eq)]
enum Direction {
    // Set by `new_request`. A `Cancelled` body error ends such a stream quietly.
    Request,
    // Set by `new_response`; carries the HTTP status code, which is used in
    // error messages and when inferring the gRPC status from the trailers.
    Response(StatusCode),
    // Set by `new_empty`; no trailers are polled for this direction.
    EmptyResponse,
}
56 
57 impl<T> Streaming<T> {
new_response<B, D>( decoder: D, body: B, status_code: StatusCode, encoding: Option<CompressionEncoding>, max_message_size: Option<usize>, ) -> Self where B: Body + Send + 'static, B::Error: Into<crate::Error>, D: Decoder<Item = T, Error = Status> + Send + 'static,58     pub(crate) fn new_response<B, D>(
59         decoder: D,
60         body: B,
61         status_code: StatusCode,
62         encoding: Option<CompressionEncoding>,
63         max_message_size: Option<usize>,
64     ) -> Self
65     where
66         B: Body + Send + 'static,
67         B::Error: Into<crate::Error>,
68         D: Decoder<Item = T, Error = Status> + Send + 'static,
69     {
70         Self::new(
71             decoder,
72             body,
73             Direction::Response(status_code),
74             encoding,
75             max_message_size,
76         )
77     }
78 
new_empty<B, D>(decoder: D, body: B) -> Self where B: Body + Send + 'static, B::Error: Into<crate::Error>, D: Decoder<Item = T, Error = Status> + Send + 'static,79     pub(crate) fn new_empty<B, D>(decoder: D, body: B) -> Self
80     where
81         B: Body + Send + 'static,
82         B::Error: Into<crate::Error>,
83         D: Decoder<Item = T, Error = Status> + Send + 'static,
84     {
85         Self::new(decoder, body, Direction::EmptyResponse, None, None)
86     }
87 
88     #[doc(hidden)]
new_request<B, D>( decoder: D, body: B, encoding: Option<CompressionEncoding>, max_message_size: Option<usize>, ) -> Self where B: Body + Send + 'static, B::Error: Into<crate::Error>, D: Decoder<Item = T, Error = Status> + Send + 'static,89     pub fn new_request<B, D>(
90         decoder: D,
91         body: B,
92         encoding: Option<CompressionEncoding>,
93         max_message_size: Option<usize>,
94     ) -> Self
95     where
96         B: Body + Send + 'static,
97         B::Error: Into<crate::Error>,
98         D: Decoder<Item = T, Error = Status> + Send + 'static,
99     {
100         Self::new(
101             decoder,
102             body,
103             Direction::Request,
104             encoding,
105             max_message_size,
106         )
107     }
108 
new<B, D>( decoder: D, body: B, direction: Direction, encoding: Option<CompressionEncoding>, max_message_size: Option<usize>, ) -> Self where B: Body + Send + 'static, B::Error: Into<crate::Error>, D: Decoder<Item = T, Error = Status> + Send + 'static,109     fn new<B, D>(
110         decoder: D,
111         body: B,
112         direction: Direction,
113         encoding: Option<CompressionEncoding>,
114         max_message_size: Option<usize>,
115     ) -> Self
116     where
117         B: Body + Send + 'static,
118         B::Error: Into<crate::Error>,
119         D: Decoder<Item = T, Error = Status> + Send + 'static,
120     {
121         Self {
122             decoder: Box::new(decoder),
123             inner: StreamingInner {
124                 body: body
125                     .map_data(|mut buf| buf.copy_to_bytes(buf.remaining()))
126                     .map_err(|err| Status::map_error(err.into()))
127                     .boxed_unsync(),
128                 state: State::ReadHeader,
129                 direction,
130                 buf: BytesMut::with_capacity(BUFFER_SIZE),
131                 trailers: None,
132                 decompress_buf: BytesMut::new(),
133                 encoding,
134                 max_message_size,
135             },
136         }
137     }
138 }
139 
140 impl StreamingInner {
decode_chunk(&mut self) -> Result<Option<DecodeBuf<'_>>, Status>141     fn decode_chunk(&mut self) -> Result<Option<DecodeBuf<'_>>, Status> {
142         if let State::ReadHeader = self.state {
143             if self.buf.remaining() < HEADER_SIZE {
144                 return Ok(None);
145             }
146 
147             let compression_encoding = match self.buf.get_u8() {
148                 0 => None,
149                 1 => {
150                     {
151                         if self.encoding.is_some() {
152                             self.encoding
153                         } else {
154                             // https://grpc.github.io/grpc/core/md_doc_compression.html
155                             // An ill-constructed message with its Compressed-Flag bit set but lacking a grpc-encoding
156                             // entry different from identity in its metadata MUST fail with INTERNAL status,
157                             // its associated description indicating the invalid Compressed-Flag condition.
158                             return Err(Status::new(Code::Internal, "protocol error: received message with compressed-flag but no grpc-encoding was specified"));
159                         }
160                     }
161                 }
162                 f => {
163                     trace!("unexpected compression flag");
164                     let message = if let Direction::Response(status) = self.direction {
165                         format!(
166                             "protocol error: received message with invalid compression flag: {} (valid flags are 0 and 1) while receiving response with status: {}",
167                             f, status
168                         )
169                     } else {
170                         format!("protocol error: received message with invalid compression flag: {} (valid flags are 0 and 1), while sending request", f)
171                     };
172                     return Err(Status::new(Code::Internal, message));
173                 }
174             };
175 
176             let len = self.buf.get_u32() as usize;
177             let limit = self
178                 .max_message_size
179                 .unwrap_or(DEFAULT_MAX_RECV_MESSAGE_SIZE);
180             if len > limit {
181                 return Err(Status::new(
182                     Code::OutOfRange,
183                     format!(
184                         "Error, message length too large: found {} bytes, the limit is: {} bytes",
185                         len, limit
186                     ),
187                 ));
188             }
189 
190             self.buf.reserve(len);
191 
192             self.state = State::ReadBody {
193                 compression: compression_encoding,
194                 len,
195             }
196         }
197 
198         if let State::ReadBody { len, compression } = self.state {
199             // if we haven't read enough of the message then return and keep
200             // reading
201             if self.buf.remaining() < len || self.buf.len() < len {
202                 return Ok(None);
203             }
204 
205             let decode_buf = if let Some(encoding) = compression {
206                 self.decompress_buf.clear();
207 
208                 if let Err(err) = decompress(encoding, &mut self.buf, &mut self.decompress_buf, len)
209                 {
210                     let message = if let Direction::Response(status) = self.direction {
211                         format!(
212                             "Error decompressing: {}, while receiving response with status: {}",
213                             err, status
214                         )
215                     } else {
216                         format!("Error decompressing: {}, while sending request", err)
217                     };
218                     return Err(Status::new(Code::Internal, message));
219                 }
220                 let decompressed_len = self.decompress_buf.len();
221                 DecodeBuf::new(&mut self.decompress_buf, decompressed_len)
222             } else {
223                 DecodeBuf::new(&mut self.buf, len)
224             };
225 
226             return Ok(Some(decode_buf));
227         }
228 
229         Ok(None)
230     }
231 
232     // Returns Some(()) if data was found or None if the loop in `poll_next` should break
poll_data(&mut self, cx: &mut Context<'_>) -> Poll<Result<Option<()>, Status>>233     fn poll_data(&mut self, cx: &mut Context<'_>) -> Poll<Result<Option<()>, Status>> {
234         let chunk = match ready!(Pin::new(&mut self.body).poll_data(cx)) {
235             Some(Ok(d)) => Some(d),
236             Some(Err(status)) => {
237                 if self.direction == Direction::Request && status.code() == Code::Cancelled {
238                     return Poll::Ready(Ok(None));
239                 }
240 
241                 let _ = std::mem::replace(&mut self.state, State::Error);
242                 debug!("decoder inner stream error: {:?}", status);
243                 return Poll::Ready(Err(status));
244             }
245             None => None,
246         };
247 
248         Poll::Ready(if let Some(data) = chunk {
249             self.buf.put(data);
250             Ok(Some(()))
251         } else {
252             // FIXME: improve buf usage.
253             if self.buf.has_remaining() {
254                 trace!("unexpected EOF decoding stream, state: {:?}", self.state);
255                 Err(Status::new(
256                     Code::Internal,
257                     "Unexpected EOF decoding stream.".to_string(),
258                 ))
259             } else {
260                 Ok(None)
261             }
262         })
263     }
264 
poll_response(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Status>>265     fn poll_response(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Status>> {
266         if let Direction::Response(status) = self.direction {
267             match ready!(Pin::new(&mut self.body).poll_trailers(cx)) {
268                 Ok(trailer) => {
269                     if let Err(e) = crate::status::infer_grpc_status(trailer.as_ref(), status) {
270                         if let Some(e) = e {
271                             return Poll::Ready(Err(e));
272                         } else {
273                             return Poll::Ready(Ok(()));
274                         }
275                     } else {
276                         self.trailers = trailer.map(MetadataMap::from_headers);
277                     }
278                 }
279                 Err(status) => {
280                     debug!("decoder inner trailers error: {:?}", status);
281                     return Poll::Ready(Err(status));
282                 }
283             }
284         }
285         Poll::Ready(Ok(()))
286     }
287 }
288 
289 impl<T> Streaming<T> {
290     /// Fetch the next message from this stream.
291     ///
292     /// # Return value
293     ///
294     /// - `Result::Err(val)` means a gRPC error was sent by the sender instead
295     /// of a valid response message. Refer to [`Status::code`] and
296     /// [`Status::message`] to examine possible error causes.
297     ///
298     /// - `Result::Ok(None)` means the stream was closed by the sender and no
299     /// more messages will be delivered. Further attempts to call
300     /// [`Streaming::message`] will result in the same return value.
301     ///
302     /// - `Result::Ok(Some(val))` means the sender streamed a valid response
303     /// message `val`.
304     ///
305     /// ```rust
306     /// # use tonic::{Streaming, Status, codec::Decoder};
307     /// # use std::fmt::Debug;
308     /// # async fn next_message_ex<T, D>(mut request: Streaming<T>) -> Result<(), Status>
309     /// # where T: Debug,
310     /// # D: Decoder<Item = T, Error = Status> + Send  + 'static,
311     /// # {
312     /// if let Some(next_message) = request.message().await? {
313     ///     println!("{:?}", next_message);
314     /// }
315     /// # Ok(())
316     /// # }
317     /// ```
message(&mut self) -> Result<Option<T>, Status>318     pub async fn message(&mut self) -> Result<Option<T>, Status> {
319         match future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await {
320             Some(Ok(m)) => Ok(Some(m)),
321             Some(Err(e)) => Err(e),
322             None => Ok(None),
323         }
324     }
325 
326     /// Fetch the trailing metadata.
327     ///
328     /// This will drain the stream of all its messages to receive the trailing
329     /// metadata. If [`Streaming::message`] returns `None` then this function
330     /// will not need to poll for trailers since the body was totally consumed.
331     ///
332     /// ```rust
333     /// # use tonic::{Streaming, Status};
334     /// # async fn trailers_ex<T>(mut request: Streaming<T>) -> Result<(), Status> {
335     /// if let Some(metadata) = request.trailers().await? {
336     ///     println!("{:?}", metadata);
337     /// }
338     /// # Ok(())
339     /// # }
340     /// ```
trailers(&mut self) -> Result<Option<MetadataMap>, Status>341     pub async fn trailers(&mut self) -> Result<Option<MetadataMap>, Status> {
342         // Shortcut to see if we already pulled the trailers in the stream step
343         // we need to do that so that the stream can error on trailing grpc-status
344         if let Some(trailers) = self.inner.trailers.take() {
345             return Ok(Some(trailers));
346         }
347 
348         // To fetch the trailers we must clear the body and drop it.
349         while self.message().await?.is_some() {}
350 
351         // Since we call poll_trailers internally on poll_next we need to
352         // check if it got cached again.
353         if let Some(trailers) = self.inner.trailers.take() {
354             return Ok(Some(trailers));
355         }
356 
357         // Trailers were not caught during poll_next and thus lets poll for
358         // them manually.
359         let map = future::poll_fn(|cx| Pin::new(&mut self.inner.body).poll_trailers(cx))
360             .await
361             .map_err(|e| Status::from_error(Box::new(e)));
362 
363         map.map(|x| x.map(MetadataMap::from_headers))
364     }
365 
decode_chunk(&mut self) -> Result<Option<T>, Status>366     fn decode_chunk(&mut self) -> Result<Option<T>, Status> {
367         match self.inner.decode_chunk()? {
368             Some(mut decode_buf) => match self.decoder.decode(&mut decode_buf)? {
369                 Some(msg) => {
370                     self.inner.state = State::ReadHeader;
371                     Ok(Some(msg))
372                 }
373                 None => Ok(None),
374             },
375             None => Ok(None),
376         }
377     }
378 }
379 
380 impl<T> Stream for Streaming<T> {
381     type Item = Result<T, Status>;
382 
poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>383     fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
384         loop {
385             if let State::Error = &self.inner.state {
386                 return Poll::Ready(None);
387             }
388 
389             // FIXME: implement the ability to poll trailers when we _know_ that
390             // the consumer of this stream will only poll for the first message.
391             // This means we skip the poll_trailers step.
392             if let Some(item) = self.decode_chunk()? {
393                 return Poll::Ready(Some(Ok(item)));
394             }
395 
396             match ready!(self.inner.poll_data(cx))? {
397                 Some(()) => (),
398                 None => break,
399             }
400         }
401 
402         Poll::Ready(match ready!(self.inner.poll_response(cx)) {
403             Ok(()) => None,
404             Err(err) => Some(Err(err)),
405         })
406     }
407 }
408 
409 impl<T> fmt::Debug for Streaming<T> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result410     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
411         f.debug_struct("Streaming").finish()
412     }
413 }
414 
// Compile-time check, in test builds only, that `Streaming<()>` is `Send`.
#[cfg(test)]
static_assertions::assert_impl_all!(Streaming<()>: Send);
417