1 use super::compression::{decompress, CompressionEncoding}; 2 use super::{DecodeBuf, Decoder, DEFAULT_MAX_RECV_MESSAGE_SIZE, HEADER_SIZE}; 3 use crate::{body::BoxBody, metadata::MetadataMap, Code, Status}; 4 use bytes::{Buf, BufMut, BytesMut}; 5 use http::StatusCode; 6 use http_body::Body; 7 use std::{ 8 fmt, future, 9 pin::Pin, 10 task::ready, 11 task::{Context, Poll}, 12 }; 13 use tokio_stream::Stream; 14 use tracing::{debug, trace}; 15 16 const BUFFER_SIZE: usize = 8 * 1024; 17 18 /// Streaming requests and responses. 19 /// 20 /// This will wrap some inner [`Body`] and [`Decoder`] and provide an interface 21 /// to fetch the message stream and trailing metadata 22 pub struct Streaming<T> { 23 decoder: Box<dyn Decoder<Item = T, Error = Status> + Send + 'static>, 24 inner: StreamingInner, 25 } 26 27 struct StreamingInner { 28 body: BoxBody, 29 state: State, 30 direction: Direction, 31 buf: BytesMut, 32 trailers: Option<MetadataMap>, 33 decompress_buf: BytesMut, 34 encoding: Option<CompressionEncoding>, 35 max_message_size: Option<usize>, 36 } 37 38 impl<T> Unpin for Streaming<T> {} 39 40 #[derive(Debug, Clone, Copy)] 41 enum State { 42 ReadHeader, 43 ReadBody { 44 compression: Option<CompressionEncoding>, 45 len: usize, 46 }, 47 Error, 48 } 49 50 #[derive(Debug, PartialEq, Eq)] 51 enum Direction { 52 Request, 53 Response(StatusCode), 54 EmptyResponse, 55 } 56 57 impl<T> Streaming<T> { new_response<B, D>( decoder: D, body: B, status_code: StatusCode, encoding: Option<CompressionEncoding>, max_message_size: Option<usize>, ) -> Self where B: Body + Send + 'static, B::Error: Into<crate::Error>, D: Decoder<Item = T, Error = Status> + Send + 'static,58 pub(crate) fn new_response<B, D>( 59 decoder: D, 60 body: B, 61 status_code: StatusCode, 62 encoding: Option<CompressionEncoding>, 63 max_message_size: Option<usize>, 64 ) -> Self 65 where 66 B: Body + Send + 'static, 67 B::Error: Into<crate::Error>, 68 D: Decoder<Item = T, Error = Status> + Send + 'static, 69 { 70 Self::new( 71 decoder, 72 body, 73 Direction::Response(status_code), 74 encoding, 75 max_message_size, 76 ) 77 } 78 new_empty<B, D>(decoder: D, body: B) -> Self where B: Body + Send + 'static, B::Error: Into<crate::Error>, D: Decoder<Item = T, Error = Status> + Send + 'static,79 pub(crate) fn new_empty<B, D>(decoder: D, body: B) -> Self 80 where 81 B: Body + Send + 'static, 82 B::Error: Into<crate::Error>, 83 D: Decoder<Item = T, Error = Status> + Send + 'static, 84 { 85 Self::new(decoder, body, Direction::EmptyResponse, None, None) 86 } 87 88 #[doc(hidden)] new_request<B, D>( decoder: D, body: B, encoding: Option<CompressionEncoding>, max_message_size: Option<usize>, ) -> Self where B: Body + Send + 'static, B::Error: Into<crate::Error>, D: Decoder<Item = T, Error = Status> + Send + 'static,89 pub fn new_request<B, D>( 90 decoder: D, 91 body: B, 92 encoding: Option<CompressionEncoding>, 93 max_message_size: Option<usize>, 94 ) -> Self 95 where 96 B: Body + Send + 'static, 97 B::Error: Into<crate::Error>, 98 D: Decoder<Item = T, Error = Status> + Send + 'static, 99 { 100 Self::new( 101 decoder, 102 body, 103 Direction::Request, 104 encoding, 105 max_message_size, 106 ) 107 } 108 new<B, D>( decoder: D, body: B, direction: Direction, encoding: Option<CompressionEncoding>, max_message_size: Option<usize>, ) -> Self where B: Body + Send + 'static, B::Error: Into<crate::Error>, D: Decoder<Item = T, Error = Status> + Send + 'static,109 fn new<B, D>( 110 decoder: D, 111 body: B, 112 direction: Direction, 113 encoding: Option<CompressionEncoding>, 114 max_message_size: Option<usize>, 115 ) -> Self 116 where 117 B: Body + Send + 'static, 118 B::Error: Into<crate::Error>, 119 D: Decoder<Item = T, Error = Status> + Send + 'static, 120 { 121 Self { 122 decoder: Box::new(decoder), 123 inner: StreamingInner { 124 body: body 125 .map_data(|mut buf| buf.copy_to_bytes(buf.remaining())) 126 .map_err(|err| Status::map_error(err.into())) 127 .boxed_unsync(), 128 state: State::ReadHeader, 129 direction, 130 buf: BytesMut::with_capacity(BUFFER_SIZE), 131 trailers: None, 132 decompress_buf: BytesMut::new(), 133 encoding, 134 max_message_size, 135 }, 136 } 137 } 138 } 139 140 impl StreamingInner { decode_chunk(&mut self) -> Result<Option<DecodeBuf<'_>>, Status>141 fn decode_chunk(&mut self) -> Result<Option<DecodeBuf<'_>>, Status> { 142 if let State::ReadHeader = self.state { 143 if self.buf.remaining() < HEADER_SIZE { 144 return Ok(None); 145 } 146 147 let compression_encoding = match self.buf.get_u8() { 148 0 => None, 149 1 => { 150 { 151 if self.encoding.is_some() { 152 self.encoding 153 } else { 154 // https://grpc.github.io/grpc/core/md_doc_compression.html 155 // An ill-constructed message with its Compressed-Flag bit set but lacking a grpc-encoding 156 // entry different from identity in its metadata MUST fail with INTERNAL status, 157 // its associated description indicating the invalid Compressed-Flag condition. 158 return Err(Status::new(Code::Internal, "protocol error: received message with compressed-flag but no grpc-encoding was specified")); 159 } 160 } 161 } 162 f => { 163 trace!("unexpected compression flag"); 164 let message = if let Direction::Response(status) = self.direction { 165 format!( 166 "protocol error: received message with invalid compression flag: {} (valid flags are 0 and 1) while receiving response with status: {}", 167 f, status 168 ) 169 } else { 170 format!("protocol error: received message with invalid compression flag: {} (valid flags are 0 and 1), while sending request", f) 171 }; 172 return Err(Status::new(Code::Internal, message)); 173 } 174 }; 175 176 let len = self.buf.get_u32() as usize; 177 let limit = self 178 .max_message_size 179 .unwrap_or(DEFAULT_MAX_RECV_MESSAGE_SIZE); 180 if len > limit { 181 return Err(Status::new( 182 Code::OutOfRange, 183 format!( 184 "Error, message length too large: found {} bytes, the limit is: {} bytes", 185 len, limit 186 ), 187 )); 188 } 189 190 self.buf.reserve(len); 191 192 self.state = State::ReadBody { 193 compression: compression_encoding, 194 len, 195 } 196 } 197 198 if let State::ReadBody { len, compression } = self.state { 199 // if we haven't read enough of the message then return and keep 200 // reading 201 if self.buf.remaining() < len || self.buf.len() < len { 202 return Ok(None); 203 } 204 205 let decode_buf = if let Some(encoding) = compression { 206 self.decompress_buf.clear(); 207 208 if let Err(err) = decompress(encoding, &mut self.buf, &mut self.decompress_buf, len) 209 { 210 let message = if let Direction::Response(status) = self.direction { 211 format!( 212 "Error decompressing: {}, while receiving response with status: {}", 213 err, status 214 ) 215 } else { 216 format!("Error decompressing: {}, while sending request", err) 217 }; 218 return Err(Status::new(Code::Internal, message)); 219 } 220 let decompressed_len = self.decompress_buf.len(); 221 DecodeBuf::new(&mut self.decompress_buf, decompressed_len) 222 } else { 223 DecodeBuf::new(&mut self.buf, len) 224 }; 225 226 return Ok(Some(decode_buf)); 227 } 228 229 Ok(None) 230 } 231 232 // Returns Some(()) if data was found or None if the loop in `poll_next` should break poll_data(&mut self, cx: &mut Context<'_>) -> Poll<Result<Option<()>, Status>>233 fn poll_data(&mut self, cx: &mut Context<'_>) -> Poll<Result<Option<()>, Status>> { 234 let chunk = match ready!(Pin::new(&mut self.body).poll_data(cx)) { 235 Some(Ok(d)) => Some(d), 236 Some(Err(status)) => { 237 if self.direction == Direction::Request && status.code() == Code::Cancelled { 238 return Poll::Ready(Ok(None)); 239 } 240 241 let _ = std::mem::replace(&mut self.state, State::Error); 242 debug!("decoder inner stream error: {:?}", status); 243 return Poll::Ready(Err(status)); 244 } 245 None => None, 246 }; 247 248 Poll::Ready(if let Some(data) = chunk { 249 self.buf.put(data); 250 Ok(Some(())) 251 } else { 252 // FIXME: improve buf usage. 253 if self.buf.has_remaining() { 254 trace!("unexpected EOF decoding stream, state: {:?}", self.state); 255 Err(Status::new( 256 Code::Internal, 257 "Unexpected EOF decoding stream.".to_string(), 258 )) 259 } else { 260 Ok(None) 261 } 262 }) 263 } 264 poll_response(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Status>>265 fn poll_response(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Status>> { 266 if let Direction::Response(status) = self.direction { 267 match ready!(Pin::new(&mut self.body).poll_trailers(cx)) { 268 Ok(trailer) => { 269 if let Err(e) = crate::status::infer_grpc_status(trailer.as_ref(), status) { 270 if let Some(e) = e { 271 return Poll::Ready(Err(e)); 272 } else { 273 return Poll::Ready(Ok(())); 274 } 275 } else { 276 self.trailers = trailer.map(MetadataMap::from_headers); 277 } 278 } 279 Err(status) => { 280 debug!("decoder inner trailers error: {:?}", status); 281 return Poll::Ready(Err(status)); 282 } 283 } 284 } 285 Poll::Ready(Ok(())) 286 } 287 } 288 289 impl<T> Streaming<T> { 290 /// Fetch the next message from this stream. 291 /// 292 /// # Return value 293 /// 294 /// - `Result::Err(val)` means a gRPC error was sent by the sender instead 295 /// of a valid response message. Refer to [`Status::code`] and 296 /// [`Status::message`] to examine possible error causes. 297 /// 298 /// - `Result::Ok(None)` means the stream was closed by the sender and no 299 /// more messages will be delivered. Further attempts to call 300 /// [`Streaming::message`] will result in the same return value. 301 /// 302 /// - `Result::Ok(Some(val))` means the sender streamed a valid response 303 /// message `val`. 304 /// 305 /// ```rust 306 /// # use tonic::{Streaming, Status, codec::Decoder}; 307 /// # use std::fmt::Debug; 308 /// # async fn next_message_ex<T, D>(mut request: Streaming<T>) -> Result<(), Status> 309 /// # where T: Debug, 310 /// # D: Decoder<Item = T, Error = Status> + Send + 'static, 311 /// # { 312 /// if let Some(next_message) = request.message().await? { 313 /// println!("{:?}", next_message); 314 /// } 315 /// # Ok(()) 316 /// # } 317 /// ``` message(&mut self) -> Result<Option<T>, Status>318 pub async fn message(&mut self) -> Result<Option<T>, Status> { 319 match future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await { 320 Some(Ok(m)) => Ok(Some(m)), 321 Some(Err(e)) => Err(e), 322 None => Ok(None), 323 } 324 } 325 326 /// Fetch the trailing metadata. 327 /// 328 /// This will drain the stream of all its messages to receive the trailing 329 /// metadata. If [`Streaming::message`] returns `None` then this function 330 /// will not need to poll for trailers since the body was totally consumed. 331 /// 332 /// ```rust 333 /// # use tonic::{Streaming, Status}; 334 /// # async fn trailers_ex<T>(mut request: Streaming<T>) -> Result<(), Status> { 335 /// if let Some(metadata) = request.trailers().await? { 336 /// println!("{:?}", metadata); 337 /// } 338 /// # Ok(()) 339 /// # } 340 /// ``` trailers(&mut self) -> Result<Option<MetadataMap>, Status>341 pub async fn trailers(&mut self) -> Result<Option<MetadataMap>, Status> { 342 // Shortcut to see if we already pulled the trailers in the stream step 343 // we need to do that so that the stream can error on trailing grpc-status 344 if let Some(trailers) = self.inner.trailers.take() { 345 return Ok(Some(trailers)); 346 } 347 348 // To fetch the trailers we must clear the body and drop it. 349 while self.message().await?.is_some() {} 350 351 // Since we call poll_trailers internally on poll_next we need to 352 // check if it got cached again. 353 if let Some(trailers) = self.inner.trailers.take() { 354 return Ok(Some(trailers)); 355 } 356 357 // Trailers were not caught during poll_next and thus lets poll for 358 // them manually. 359 let map = future::poll_fn(|cx| Pin::new(&mut self.inner.body).poll_trailers(cx)) 360 .await 361 .map_err(|e| Status::from_error(Box::new(e))); 362 363 map.map(|x| x.map(MetadataMap::from_headers)) 364 } 365 decode_chunk(&mut self) -> Result<Option<T>, Status>366 fn decode_chunk(&mut self) -> Result<Option<T>, Status> { 367 match self.inner.decode_chunk()? { 368 Some(mut decode_buf) => match self.decoder.decode(&mut decode_buf)? { 369 Some(msg) => { 370 self.inner.state = State::ReadHeader; 371 Ok(Some(msg)) 372 } 373 None => Ok(None), 374 }, 375 None => Ok(None), 376 } 377 } 378 } 379 380 impl<T> Stream for Streaming<T> { 381 type Item = Result<T, Status>; 382 poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>383 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { 384 loop { 385 if let State::Error = &self.inner.state { 386 return Poll::Ready(None); 387 } 388 389 // FIXME: implement the ability to poll trailers when we _know_ that 390 // the consumer of this stream will only poll for the first message. 391 // This means we skip the poll_trailers step. 392 if let Some(item) = self.decode_chunk()? { 393 return Poll::Ready(Some(Ok(item))); 394 } 395 396 match ready!(self.inner.poll_data(cx))? { 397 Some(()) => (), 398 None => break, 399 } 400 } 401 402 Poll::Ready(match ready!(self.inner.poll_response(cx)) { 403 Ok(()) => None, 404 Err(err) => Some(Err(err)), 405 }) 406 } 407 } 408 409 impl<T> fmt::Debug for Streaming<T> { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result410 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 411 f.debug_struct("Streaming").finish() 412 } 413 } 414 415 #[cfg(test)] 416 static_assertions::assert_impl_all!(Streaming<()>: Send); 417