use super::compression::{compress, CompressionEncoding, SingleMessageCompressionOverride};
use super::{EncodeBuf, Encoder, DEFAULT_MAX_SEND_MESSAGE_SIZE, HEADER_SIZE};
use crate::{Code, Status};
use bytes::{BufMut, Bytes, BytesMut};
use http::HeaderMap;
use http_body::Body;
use pin_project::pin_project;
use std::{
    pin::Pin,
    task::{ready, Context, Poll},
};
use tokio_stream::{Stream, StreamExt};

/// Default capacity, in bytes, of freshly allocated encode buffers
/// (`EncodedBytes` creates its buffers with this capacity).
pub(super) const BUFFER_SIZE: usize = 8 * 1024;

/// Once this many encoded bytes are buffered, `EncodedBytes` yields them
/// instead of pulling further messages from its source.
const YIELD_THRESHOLD: usize = 32 * 1024;

/// Builds a server response body from a stream of fallible messages.
///
/// Messages are encoded with `encoder`, optionally compressed with
/// `compression_encoding` (unless `compression_override` disables it), and
/// rejected when larger than `max_message_size`. Errors from `source` end the
/// data stream and are reported through the trailers by `EncodeBody`.
pub(crate) fn encode_server<T, U>(
    encoder: T,
    source: U,
    compression_encoding: Option<CompressionEncoding>,
    compression_override: SingleMessageCompressionOverride,
    max_message_size: Option<usize>,
) -> EncodeBody<impl Stream<Item = Result<Bytes, Status>>>
where
    T: Encoder<Error = Status>,
    U: Stream<Item = Result<T::Item, Status>>,
{
    EncodeBody::new_server(EncodedBytes::new(
        encoder,
        source.fuse(),
        compression_encoding,
        compression_override,
        max_message_size,
    ))
}

/// Builds a client request body from a stream of infallible messages.
///
/// Client messages never carry a per-message compression override, so the
/// default override is always used.
pub(crate) fn encode_client<T, U>(
    encoder: T,
    source: U,
    compression_encoding: Option<CompressionEncoding>,
    max_message_size: Option<usize>,
) -> EncodeBody<impl Stream<Item = Result<Bytes, Status>>>
where
    T: Encoder<Error = Status>,
    U: Stream<Item = T::Item>,
{
    // Lift every message into `Ok` so both directions share one pipeline.
    let fallible_source = source.fuse().map(Ok::<T::Item, Status>);
    EncodeBody::new_client(EncodedBytes::new(
        encoder,
        fallible_source,
        compression_encoding,
        SingleMessageCompressionOverride::default(),
        max_message_size,
    ))
}

/// Stream adapter that packs ready messages from its source into reasonably
/// sized buffers.
///
/// Each ready message is encoded into a shared `BytesMut`; the accumulated
/// bytes are split off and yielded when:
/// * the source stream returns `Pending` or ends, or
/// * the buffered bytes reach `YIELD_THRESHOLD`.
#[pin_project(project = EncodedBytesProj)] #[derive(Debug)] pub(crate) struct EncodedBytes<T, U> where T: Encoder<Error = Status>, U: Stream<Item = Result<T::Item, Status>>, { #[pin] source: U, encoder: T, compression_encoding: Option<CompressionEncoding>, max_message_size: Option<usize>, buf: BytesMut, uncompression_buf: BytesMut, } impl<T, U> EncodedBytes<T, U> where T: Encoder<Error = Status>, U: Stream<Item = Result<T::Item, Status>>, { // `source` should be fused stream. fn new( encoder: T, source: U, compression_encoding: Option<CompressionEncoding>, compression_override: SingleMessageCompressionOverride, max_message_size: Option<usize>, ) -> Self { let buf = BytesMut::with_capacity(BUFFER_SIZE); let compression_encoding = if compression_override == SingleMessageCompressionOverride::Disable { None } else { compression_encoding }; let uncompression_buf = if compression_encoding.is_some() { BytesMut::with_capacity(BUFFER_SIZE) } else { BytesMut::new() }; Self { source, encoder, compression_encoding, max_message_size, buf, uncompression_buf, } } } impl<T, U> Stream for EncodedBytes<T, U> where T: Encoder<Error = Status>, U: Stream<Item = Result<T::Item, Status>>, { type Item = Result<Bytes, Status>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { let EncodedBytesProj { mut source, encoder, compression_encoding, max_message_size, buf, uncompression_buf, } = self.project(); loop { match source.as_mut().poll_next(cx) { Poll::Pending if buf.is_empty() => { return Poll::Pending; } Poll::Ready(None) if buf.is_empty() => { return Poll::Ready(None); } Poll::Pending | Poll::Ready(None) => { return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze()))); } Poll::Ready(Some(Ok(item))) => { if let Err(status) = encode_item( encoder, buf, uncompression_buf, *compression_encoding, *max_message_size, item, ) { return Poll::Ready(Some(Err(status))); } if buf.len() >= YIELD_THRESHOLD { return 
Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze()))); } } Poll::Ready(Some(Err(status))) => { return Poll::Ready(Some(Err(status))); } } } } } fn encode_item<T>( encoder: &mut T, buf: &mut BytesMut, uncompression_buf: &mut BytesMut, compression_encoding: Option<CompressionEncoding>, max_message_size: Option<usize>, item: T::Item, ) -> Result<(), Status> where T: Encoder<Error = Status>, { let offset = buf.len(); buf.reserve(HEADER_SIZE); unsafe { buf.advance_mut(HEADER_SIZE); } if let Some(encoding) = compression_encoding { uncompression_buf.clear(); encoder .encode(item, &mut EncodeBuf::new(uncompression_buf)) .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?; let uncompressed_len = uncompression_buf.len(); compress(encoding, uncompression_buf, buf, uncompressed_len) .map_err(|err| Status::internal(format!("Error compressing: {}", err)))?; } else { encoder .encode(item, &mut EncodeBuf::new(buf)) .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?; } // now that we know length, we can write the header finish_encoding(compression_encoding, max_message_size, &mut buf[offset..]) } fn finish_encoding( compression_encoding: Option<CompressionEncoding>, max_message_size: Option<usize>, buf: &mut [u8], ) -> Result<(), Status> { let len = buf.len() - HEADER_SIZE; let limit = max_message_size.unwrap_or(DEFAULT_MAX_SEND_MESSAGE_SIZE); if len > limit { return Err(Status::new( Code::OutOfRange, format!( "Error, message length too large: found {} bytes, the limit is: {} bytes", len, limit ), )); } if len > std::u32::MAX as usize { return Err(Status::resource_exhausted(format!( "Cannot return body with more than 4GB of data but got {len} bytes" ))); } { let mut buf = &mut buf[..HEADER_SIZE]; buf.put_u8(compression_encoding.is_some() as u8); buf.put_u32(len as u32); } Ok(()) } #[derive(Debug)] enum Role { Client, Server, } #[pin_project] #[derive(Debug)] pub(crate) struct EncodeBody<S> { #[pin] inner: S, state: EncodeState, } 
#[derive(Debug)] struct EncodeState { error: Option<Status>, role: Role, is_end_stream: bool, } impl<S> EncodeBody<S> where S: Stream<Item = Result<Bytes, Status>>, { pub(crate) fn new_client(inner: S) -> Self { Self { inner, state: EncodeState { error: None, role: Role::Client, is_end_stream: false, }, } } pub(crate) fn new_server(inner: S) -> Self { Self { inner, state: EncodeState { error: None, role: Role::Server, is_end_stream: false, }, } } } impl EncodeState { fn trailers(&mut self) -> Result<Option<HeaderMap>, Status> { match self.role { Role::Client => Ok(None), Role::Server => { if self.is_end_stream { return Ok(None); } let status = if let Some(status) = self.error.take() { self.is_end_stream = true; status } else { Status::new(Code::Ok, "") }; Ok(Some(status.to_header_map()?)) } } } } impl<S> Body for EncodeBody<S> where S: Stream<Item = Result<Bytes, Status>>, { type Data = Bytes; type Error = Status; fn is_end_stream(&self) -> bool { self.state.is_end_stream } fn poll_data( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Option<Result<Self::Data, Self::Error>>> { let self_proj = self.project(); match ready!(self_proj.inner.poll_next(cx)) { Some(Ok(d)) => Some(Ok(d)).into(), Some(Err(status)) => match self_proj.state.role { Role::Client => Some(Err(status)).into(), Role::Server => { self_proj.state.error = Some(status); None.into() } }, None => None.into(), } } fn poll_trailers( self: Pin<&mut Self>, _cx: &mut Context<'_>, ) -> Poll<Result<Option<HeaderMap>, Status>> { Poll::Ready(self.project().state.trailers()) } }