1 use super::compression::{compress, CompressionEncoding, SingleMessageCompressionOverride};
2 use super::{EncodeBuf, Encoder, DEFAULT_MAX_SEND_MESSAGE_SIZE, HEADER_SIZE};
3 use crate::{Code, Status};
4 use bytes::{BufMut, Bytes, BytesMut};
5 use http::HeaderMap;
6 use http_body::Body;
7 use pin_project::pin_project;
8 use std::{
9 pin::Pin,
10 task::{ready, Context, Poll},
11 };
12 use tokio_stream::{Stream, StreamExt};
13
14 pub(super) const BUFFER_SIZE: usize = 8 * 1024;
15 const YIELD_THRESHOLD: usize = 32 * 1024;
16
encode_server<T, U>( encoder: T, source: U, compression_encoding: Option<CompressionEncoding>, compression_override: SingleMessageCompressionOverride, max_message_size: Option<usize>, ) -> EncodeBody<impl Stream<Item = Result<Bytes, Status>>> where T: Encoder<Error = Status>, U: Stream<Item = Result<T::Item, Status>>,17 pub(crate) fn encode_server<T, U>(
18 encoder: T,
19 source: U,
20 compression_encoding: Option<CompressionEncoding>,
21 compression_override: SingleMessageCompressionOverride,
22 max_message_size: Option<usize>,
23 ) -> EncodeBody<impl Stream<Item = Result<Bytes, Status>>>
24 where
25 T: Encoder<Error = Status>,
26 U: Stream<Item = Result<T::Item, Status>>,
27 {
28 let stream = EncodedBytes::new(
29 encoder,
30 source.fuse(),
31 compression_encoding,
32 compression_override,
33 max_message_size,
34 );
35
36 EncodeBody::new_server(stream)
37 }
38
encode_client<T, U>( encoder: T, source: U, compression_encoding: Option<CompressionEncoding>, max_message_size: Option<usize>, ) -> EncodeBody<impl Stream<Item = Result<Bytes, Status>>> where T: Encoder<Error = Status>, U: Stream<Item = T::Item>,39 pub(crate) fn encode_client<T, U>(
40 encoder: T,
41 source: U,
42 compression_encoding: Option<CompressionEncoding>,
43 max_message_size: Option<usize>,
44 ) -> EncodeBody<impl Stream<Item = Result<Bytes, Status>>>
45 where
46 T: Encoder<Error = Status>,
47 U: Stream<Item = T::Item>,
48 {
49 let stream = EncodedBytes::new(
50 encoder,
51 source.fuse().map(Ok),
52 compression_encoding,
53 SingleMessageCompressionOverride::default(),
54 max_message_size,
55 );
56 EncodeBody::new_client(stream)
57 }
58
59 /// Combinator for efficient encoding of messages into reasonably sized buffers.
60 /// EncodedBytes encodes ready messages from its delegate stream into a BytesMut,
61 /// splitting off and yielding a buffer when either:
62 /// * The delegate stream polls as not ready, or
63 /// * The encoded buffer surpasses YIELD_THRESHOLD.
64 #[pin_project(project = EncodedBytesProj)]
65 #[derive(Debug)]
66 pub(crate) struct EncodedBytes<T, U>
67 where
68 T: Encoder<Error = Status>,
69 U: Stream<Item = Result<T::Item, Status>>,
70 {
71 #[pin]
72 source: U,
73 encoder: T,
74 compression_encoding: Option<CompressionEncoding>,
75 max_message_size: Option<usize>,
76 buf: BytesMut,
77 uncompression_buf: BytesMut,
78 }
79
80 impl<T, U> EncodedBytes<T, U>
81 where
82 T: Encoder<Error = Status>,
83 U: Stream<Item = Result<T::Item, Status>>,
84 {
85 // `source` should be fused stream.
new( encoder: T, source: U, compression_encoding: Option<CompressionEncoding>, compression_override: SingleMessageCompressionOverride, max_message_size: Option<usize>, ) -> Self86 fn new(
87 encoder: T,
88 source: U,
89 compression_encoding: Option<CompressionEncoding>,
90 compression_override: SingleMessageCompressionOverride,
91 max_message_size: Option<usize>,
92 ) -> Self {
93 let buf = BytesMut::with_capacity(BUFFER_SIZE);
94
95 let compression_encoding =
96 if compression_override == SingleMessageCompressionOverride::Disable {
97 None
98 } else {
99 compression_encoding
100 };
101
102 let uncompression_buf = if compression_encoding.is_some() {
103 BytesMut::with_capacity(BUFFER_SIZE)
104 } else {
105 BytesMut::new()
106 };
107
108 Self {
109 source,
110 encoder,
111 compression_encoding,
112 max_message_size,
113 buf,
114 uncompression_buf,
115 }
116 }
117 }
118
119 impl<T, U> Stream for EncodedBytes<T, U>
120 where
121 T: Encoder<Error = Status>,
122 U: Stream<Item = Result<T::Item, Status>>,
123 {
124 type Item = Result<Bytes, Status>;
125
poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>126 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
127 let EncodedBytesProj {
128 mut source,
129 encoder,
130 compression_encoding,
131 max_message_size,
132 buf,
133 uncompression_buf,
134 } = self.project();
135
136 loop {
137 match source.as_mut().poll_next(cx) {
138 Poll::Pending if buf.is_empty() => {
139 return Poll::Pending;
140 }
141 Poll::Ready(None) if buf.is_empty() => {
142 return Poll::Ready(None);
143 }
144 Poll::Pending | Poll::Ready(None) => {
145 return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze())));
146 }
147 Poll::Ready(Some(Ok(item))) => {
148 if let Err(status) = encode_item(
149 encoder,
150 buf,
151 uncompression_buf,
152 *compression_encoding,
153 *max_message_size,
154 item,
155 ) {
156 return Poll::Ready(Some(Err(status)));
157 }
158
159 if buf.len() >= YIELD_THRESHOLD {
160 return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze())));
161 }
162 }
163 Poll::Ready(Some(Err(status))) => {
164 return Poll::Ready(Some(Err(status)));
165 }
166 }
167 }
168 }
169 }
170
encode_item<T>( encoder: &mut T, buf: &mut BytesMut, uncompression_buf: &mut BytesMut, compression_encoding: Option<CompressionEncoding>, max_message_size: Option<usize>, item: T::Item, ) -> Result<(), Status> where T: Encoder<Error = Status>,171 fn encode_item<T>(
172 encoder: &mut T,
173 buf: &mut BytesMut,
174 uncompression_buf: &mut BytesMut,
175 compression_encoding: Option<CompressionEncoding>,
176 max_message_size: Option<usize>,
177 item: T::Item,
178 ) -> Result<(), Status>
179 where
180 T: Encoder<Error = Status>,
181 {
182 let offset = buf.len();
183
184 buf.reserve(HEADER_SIZE);
185 unsafe {
186 buf.advance_mut(HEADER_SIZE);
187 }
188
189 if let Some(encoding) = compression_encoding {
190 uncompression_buf.clear();
191
192 encoder
193 .encode(item, &mut EncodeBuf::new(uncompression_buf))
194 .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?;
195
196 let uncompressed_len = uncompression_buf.len();
197
198 compress(encoding, uncompression_buf, buf, uncompressed_len)
199 .map_err(|err| Status::internal(format!("Error compressing: {}", err)))?;
200 } else {
201 encoder
202 .encode(item, &mut EncodeBuf::new(buf))
203 .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?;
204 }
205
206 // now that we know length, we can write the header
207 finish_encoding(compression_encoding, max_message_size, &mut buf[offset..])
208 }
209
finish_encoding( compression_encoding: Option<CompressionEncoding>, max_message_size: Option<usize>, buf: &mut [u8], ) -> Result<(), Status>210 fn finish_encoding(
211 compression_encoding: Option<CompressionEncoding>,
212 max_message_size: Option<usize>,
213 buf: &mut [u8],
214 ) -> Result<(), Status> {
215 let len = buf.len() - HEADER_SIZE;
216 let limit = max_message_size.unwrap_or(DEFAULT_MAX_SEND_MESSAGE_SIZE);
217 if len > limit {
218 return Err(Status::new(
219 Code::OutOfRange,
220 format!(
221 "Error, message length too large: found {} bytes, the limit is: {} bytes",
222 len, limit
223 ),
224 ));
225 }
226
227 if len > std::u32::MAX as usize {
228 return Err(Status::resource_exhausted(format!(
229 "Cannot return body with more than 4GB of data but got {len} bytes"
230 )));
231 }
232 {
233 let mut buf = &mut buf[..HEADER_SIZE];
234 buf.put_u8(compression_encoding.is_some() as u8);
235 buf.put_u32(len as u32);
236 }
237
238 Ok(())
239 }
240
/// Which side of an RPC an `EncodeBody` encodes for.
///
/// The role decides how stream errors are surfaced and whether trailers are sent.
#[derive(Debug)]
enum Role {
    /// Request body: errors are returned from `poll_data` and no trailers are produced.
    Client,
    /// Response body: errors end the data stream and are reported through trailers.
    Server,
}
246
/// HTTP body that yields the encoded frames produced by `inner`.
///
/// On the server, the final status is sent in the trailers.
#[pin_project]
#[derive(Debug)]
pub(crate) struct EncodeBody<S> {
    /// Stream of encoded frames, or errors.
    #[pin]
    inner: S,
    /// Role-dependent bookkeeping for errors and end-of-stream.
    state: EncodeState,
}
254
/// Bookkeeping shared between `EncodeBody`'s data and trailer polling.
#[derive(Debug)]
struct EncodeState {
    /// Server-side error taken from the inner stream, to be reported in the trailers.
    error: Option<Status>,
    /// Whether this body encodes a client request or a server response.
    role: Role,
    /// Value reported by `Body::is_end_stream`; once set, `trailers` returns `Ok(None)`.
    is_end_stream: bool,
}
261
262 impl<S> EncodeBody<S>
263 where
264 S: Stream<Item = Result<Bytes, Status>>,
265 {
new_client(inner: S) -> Self266 pub(crate) fn new_client(inner: S) -> Self {
267 Self {
268 inner,
269 state: EncodeState {
270 error: None,
271 role: Role::Client,
272 is_end_stream: false,
273 },
274 }
275 }
276
new_server(inner: S) -> Self277 pub(crate) fn new_server(inner: S) -> Self {
278 Self {
279 inner,
280 state: EncodeState {
281 error: None,
282 role: Role::Server,
283 is_end_stream: false,
284 },
285 }
286 }
287 }
288
289 impl EncodeState {
trailers(&mut self) -> Result<Option<HeaderMap>, Status>290 fn trailers(&mut self) -> Result<Option<HeaderMap>, Status> {
291 match self.role {
292 Role::Client => Ok(None),
293 Role::Server => {
294 if self.is_end_stream {
295 return Ok(None);
296 }
297
298 let status = if let Some(status) = self.error.take() {
299 self.is_end_stream = true;
300 status
301 } else {
302 Status::new(Code::Ok, "")
303 };
304
305 Ok(Some(status.to_header_map()?))
306 }
307 }
308 }
309 }
310
311 impl<S> Body for EncodeBody<S>
312 where
313 S: Stream<Item = Result<Bytes, Status>>,
314 {
315 type Data = Bytes;
316 type Error = Status;
317
is_end_stream(&self) -> bool318 fn is_end_stream(&self) -> bool {
319 self.state.is_end_stream
320 }
321
poll_data( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Option<Result<Self::Data, Self::Error>>>322 fn poll_data(
323 self: Pin<&mut Self>,
324 cx: &mut Context<'_>,
325 ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
326 let self_proj = self.project();
327 match ready!(self_proj.inner.poll_next(cx)) {
328 Some(Ok(d)) => Some(Ok(d)).into(),
329 Some(Err(status)) => match self_proj.state.role {
330 Role::Client => Some(Err(status)).into(),
331 Role::Server => {
332 self_proj.state.error = Some(status);
333 None.into()
334 }
335 },
336 None => None.into(),
337 }
338 }
339
poll_trailers( self: Pin<&mut Self>, _cx: &mut Context<'_>, ) -> Poll<Result<Option<HeaderMap>, Status>>340 fn poll_trailers(
341 self: Pin<&mut Self>,
342 _cx: &mut Context<'_>,
343 ) -> Poll<Result<Option<HeaderMap>, Status>> {
344 Poll::Ready(self.project().state.trailers())
345 }
346 }
347