1 use crate::codec::compression::{
2     CompressionEncoding, EnabledCompressionEncodings, SingleMessageCompressionOverride,
3 };
4 use crate::{
5     body::BoxBody,
6     codec::{encode_server, Codec, Streaming},
7     server::{ClientStreamingService, ServerStreamingService, StreamingService, UnaryService},
8     Code, Request, Status,
9 };
10 use http_body::Body;
11 use std::fmt;
12 use tokio_stream::{Stream, StreamExt};
13 
// Unwraps a `Result`, yielding the success value to the surrounding code.
// On failure it returns early from the enclosing function with the error
// converted through its `to_http()` method.
macro_rules! t {
    ($result:expr) => {
        match $result {
            Err(failure) => return failure.to_http(),
            Ok(inner) => inner,
        }
    };
}
22 
23 /// A gRPC Server handler.
24 ///
25 /// This will wrap some inner [`Codec`] and provide utilities to handle
26 /// inbound unary, client side streaming, server side streaming, and
27 /// bi-directional streaming.
28 ///
29 /// Each request handler method accepts some service that implements the
30 /// corresponding service trait and a http request that contains some body that
31 /// implements some [`Body`].
32 pub struct Grpc<T> {
33     codec: T,
34     /// Which compression encodings does the server accept for requests?
35     accept_compression_encodings: EnabledCompressionEncodings,
36     /// Which compression encodings might the server use for responses.
37     send_compression_encodings: EnabledCompressionEncodings,
38     /// Limits the maximum size of a decoded message.
39     max_decoding_message_size: Option<usize>,
40     /// Limits the maximum size of an encoded message.
41     max_encoding_message_size: Option<usize>,
42 }
43 
44 impl<T> Grpc<T>
45 where
46     T: Codec,
47 {
48     /// Creates a new gRPC server with the provided [`Codec`].
new(codec: T) -> Self49     pub fn new(codec: T) -> Self {
50         Self {
51             codec,
52             accept_compression_encodings: EnabledCompressionEncodings::default(),
53             send_compression_encodings: EnabledCompressionEncodings::default(),
54             max_decoding_message_size: None,
55             max_encoding_message_size: None,
56         }
57     }
58 
59     /// Enable accepting compressed requests.
60     ///
61     /// If a request with an unsupported encoding is received the server will respond with
62     /// [`Code::UnUnimplemented`](crate::Code).
63     ///
64     /// # Example
65     ///
66     /// The most common way of using this is through a server generated by tonic-build:
67     ///
68     /// ```rust
69     /// # enum CompressionEncoding { Gzip }
70     /// # struct Svc;
71     /// # struct ExampleServer<T>(T);
72     /// # impl<T> ExampleServer<T> {
73     /// #     fn new(svc: T) -> Self { Self(svc) }
74     /// #     fn accept_compressed(self, _: CompressionEncoding) -> Self { self }
75     /// # }
76     /// # #[tonic::async_trait]
77     /// # trait Example {}
78     ///
79     /// #[tonic::async_trait]
80     /// impl Example for Svc {
81     ///     // ...
82     /// }
83     ///
84     /// let service = ExampleServer::new(Svc).accept_compressed(CompressionEncoding::Gzip);
85     /// ```
accept_compressed(mut self, encoding: CompressionEncoding) -> Self86     pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
87         self.accept_compression_encodings.enable(encoding);
88         self
89     }
90 
91     /// Enable sending compressed responses.
92     ///
93     /// Requires the client to also support receiving compressed responses.
94     ///
95     /// # Example
96     ///
97     /// The most common way of using this is through a server generated by tonic-build:
98     ///
99     /// ```rust
100     /// # enum CompressionEncoding { Gzip }
101     /// # struct Svc;
102     /// # struct ExampleServer<T>(T);
103     /// # impl<T> ExampleServer<T> {
104     /// #     fn new(svc: T) -> Self { Self(svc) }
105     /// #     fn send_compressed(self, _: CompressionEncoding) -> Self { self }
106     /// # }
107     /// # #[tonic::async_trait]
108     /// # trait Example {}
109     ///
110     /// #[tonic::async_trait]
111     /// impl Example for Svc {
112     ///     // ...
113     /// }
114     ///
115     /// let service = ExampleServer::new(Svc).send_compressed(CompressionEncoding::Gzip);
116     /// ```
send_compressed(mut self, encoding: CompressionEncoding) -> Self117     pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
118         self.send_compression_encodings.enable(encoding);
119         self
120     }
121 
122     /// Limits the maximum size of a decoded message.
123     ///
124     /// # Example
125     ///
126     /// The most common way of using this is through a server generated by tonic-build:
127     ///
128     /// ```rust
129     /// # struct Svc;
130     /// # struct ExampleServer<T>(T);
131     /// # impl<T> ExampleServer<T> {
132     /// #     fn new(svc: T) -> Self { Self(svc) }
133     /// #     fn max_decoding_message_size(self, _: usize) -> Self { self }
134     /// # }
135     /// # #[tonic::async_trait]
136     /// # trait Example {}
137     ///
138     /// #[tonic::async_trait]
139     /// impl Example for Svc {
140     ///     // ...
141     /// }
142     ///
143     /// // Set the limit to 2MB, Defaults to 4MB.
144     /// let limit = 2 * 1024 * 1024;
145     /// let service = ExampleServer::new(Svc).max_decoding_message_size(limit);
146     /// ```
max_decoding_message_size(mut self, limit: usize) -> Self147     pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
148         self.max_decoding_message_size = Some(limit);
149         self
150     }
151 
152     /// Limits the maximum size of a encoded message.
153     ///
154     /// # Example
155     ///
156     /// The most common way of using this is through a server generated by tonic-build:
157     ///
158     /// ```rust
159     /// # struct Svc;
160     /// # struct ExampleServer<T>(T);
161     /// # impl<T> ExampleServer<T> {
162     /// #     fn new(svc: T) -> Self { Self(svc) }
163     /// #     fn max_encoding_message_size(self, _: usize) -> Self { self }
164     /// # }
165     /// # #[tonic::async_trait]
166     /// # trait Example {}
167     ///
168     /// #[tonic::async_trait]
169     /// impl Example for Svc {
170     ///     // ...
171     /// }
172     ///
173     /// // Set the limit to 2MB, Defaults to 4MB.
174     /// let limit = 2 * 1024 * 1024;
175     /// let service = ExampleServer::new(Svc).max_encoding_message_size(limit);
176     /// ```
max_encoding_message_size(mut self, limit: usize) -> Self177     pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
178         self.max_encoding_message_size = Some(limit);
179         self
180     }
181 
182     #[doc(hidden)]
apply_compression_config( self, accept_encodings: EnabledCompressionEncodings, send_encodings: EnabledCompressionEncodings, ) -> Self183     pub fn apply_compression_config(
184         self,
185         accept_encodings: EnabledCompressionEncodings,
186         send_encodings: EnabledCompressionEncodings,
187     ) -> Self {
188         let mut this = self;
189 
190         for &encoding in CompressionEncoding::encodings() {
191             if accept_encodings.is_enabled(encoding) {
192                 this = this.accept_compressed(encoding);
193             }
194             if send_encodings.is_enabled(encoding) {
195                 this = this.send_compressed(encoding);
196             }
197         }
198 
199         this
200     }
201 
202     #[doc(hidden)]
apply_max_message_size_config( self, max_decoding_message_size: Option<usize>, max_encoding_message_size: Option<usize>, ) -> Self203     pub fn apply_max_message_size_config(
204         self,
205         max_decoding_message_size: Option<usize>,
206         max_encoding_message_size: Option<usize>,
207     ) -> Self {
208         let mut this = self;
209 
210         if let Some(limit) = max_decoding_message_size {
211             this = this.max_decoding_message_size(limit);
212         }
213         if let Some(limit) = max_encoding_message_size {
214             this = this.max_encoding_message_size(limit);
215         }
216 
217         this
218     }
219 
220     /// Handle a single unary gRPC request.
unary<S, B>( &mut self, mut service: S, req: http::Request<B>, ) -> http::Response<BoxBody> where S: UnaryService<T::Decode, Response = T::Encode>, B: Body + Send + 'static, B::Error: Into<crate::Error> + Send,221     pub async fn unary<S, B>(
222         &mut self,
223         mut service: S,
224         req: http::Request<B>,
225     ) -> http::Response<BoxBody>
226     where
227         S: UnaryService<T::Decode, Response = T::Encode>,
228         B: Body + Send + 'static,
229         B::Error: Into<crate::Error> + Send,
230     {
231         let accept_encoding = CompressionEncoding::from_accept_encoding_header(
232             req.headers(),
233             self.send_compression_encodings,
234         );
235 
236         let request = match self.map_request_unary(req).await {
237             Ok(r) => r,
238             Err(status) => {
239                 return self.map_response::<tokio_stream::Once<Result<T::Encode, Status>>>(
240                     Err(status),
241                     accept_encoding,
242                     SingleMessageCompressionOverride::default(),
243                     self.max_encoding_message_size,
244                 );
245             }
246         };
247 
248         let response = service
249             .call(request)
250             .await
251             .map(|r| r.map(|m| tokio_stream::once(Ok(m))));
252 
253         let compression_override = compression_override_from_response(&response);
254 
255         self.map_response(
256             response,
257             accept_encoding,
258             compression_override,
259             self.max_encoding_message_size,
260         )
261     }
262 
263     /// Handle a server side streaming request.
server_streaming<S, B>( &mut self, mut service: S, req: http::Request<B>, ) -> http::Response<BoxBody> where S: ServerStreamingService<T::Decode, Response = T::Encode>, S::ResponseStream: Send + 'static, B: Body + Send + 'static, B::Error: Into<crate::Error> + Send,264     pub async fn server_streaming<S, B>(
265         &mut self,
266         mut service: S,
267         req: http::Request<B>,
268     ) -> http::Response<BoxBody>
269     where
270         S: ServerStreamingService<T::Decode, Response = T::Encode>,
271         S::ResponseStream: Send + 'static,
272         B: Body + Send + 'static,
273         B::Error: Into<crate::Error> + Send,
274     {
275         let accept_encoding = CompressionEncoding::from_accept_encoding_header(
276             req.headers(),
277             self.send_compression_encodings,
278         );
279 
280         let request = match self.map_request_unary(req).await {
281             Ok(r) => r,
282             Err(status) => {
283                 return self.map_response::<S::ResponseStream>(
284                     Err(status),
285                     accept_encoding,
286                     SingleMessageCompressionOverride::default(),
287                     self.max_encoding_message_size,
288                 );
289             }
290         };
291 
292         let response = service.call(request).await;
293 
294         self.map_response(
295             response,
296             accept_encoding,
297             // disabling compression of individual stream items must be done on
298             // the items themselves
299             SingleMessageCompressionOverride::default(),
300             self.max_encoding_message_size,
301         )
302     }
303 
304     /// Handle a client side streaming gRPC request.
client_streaming<S, B>( &mut self, mut service: S, req: http::Request<B>, ) -> http::Response<BoxBody> where S: ClientStreamingService<T::Decode, Response = T::Encode>, B: Body + Send + 'static, B::Error: Into<crate::Error> + Send + 'static,305     pub async fn client_streaming<S, B>(
306         &mut self,
307         mut service: S,
308         req: http::Request<B>,
309     ) -> http::Response<BoxBody>
310     where
311         S: ClientStreamingService<T::Decode, Response = T::Encode>,
312         B: Body + Send + 'static,
313         B::Error: Into<crate::Error> + Send + 'static,
314     {
315         let accept_encoding = CompressionEncoding::from_accept_encoding_header(
316             req.headers(),
317             self.send_compression_encodings,
318         );
319 
320         let request = t!(self.map_request_streaming(req));
321 
322         let response = service
323             .call(request)
324             .await
325             .map(|r| r.map(|m| tokio_stream::once(Ok(m))));
326 
327         let compression_override = compression_override_from_response(&response);
328 
329         self.map_response(
330             response,
331             accept_encoding,
332             compression_override,
333             self.max_encoding_message_size,
334         )
335     }
336 
337     /// Handle a bi-directional streaming gRPC request.
streaming<S, B>( &mut self, mut service: S, req: http::Request<B>, ) -> http::Response<BoxBody> where S: StreamingService<T::Decode, Response = T::Encode> + Send, S::ResponseStream: Send + 'static, B: Body + Send + 'static, B::Error: Into<crate::Error> + Send,338     pub async fn streaming<S, B>(
339         &mut self,
340         mut service: S,
341         req: http::Request<B>,
342     ) -> http::Response<BoxBody>
343     where
344         S: StreamingService<T::Decode, Response = T::Encode> + Send,
345         S::ResponseStream: Send + 'static,
346         B: Body + Send + 'static,
347         B::Error: Into<crate::Error> + Send,
348     {
349         let accept_encoding = CompressionEncoding::from_accept_encoding_header(
350             req.headers(),
351             self.send_compression_encodings,
352         );
353 
354         let request = t!(self.map_request_streaming(req));
355 
356         let response = service.call(request).await;
357 
358         self.map_response(
359             response,
360             accept_encoding,
361             SingleMessageCompressionOverride::default(),
362             self.max_encoding_message_size,
363         )
364     }
365 
map_request_unary<B>( &mut self, request: http::Request<B>, ) -> Result<Request<T::Decode>, Status> where B: Body + Send + 'static, B::Error: Into<crate::Error> + Send,366     async fn map_request_unary<B>(
367         &mut self,
368         request: http::Request<B>,
369     ) -> Result<Request<T::Decode>, Status>
370     where
371         B: Body + Send + 'static,
372         B::Error: Into<crate::Error> + Send,
373     {
374         let request_compression_encoding = self.request_encoding_if_supported(&request)?;
375 
376         let (parts, body) = request.into_parts();
377 
378         let stream = Streaming::new_request(
379             self.codec.decoder(),
380             body,
381             request_compression_encoding,
382             self.max_decoding_message_size,
383         );
384 
385         tokio::pin!(stream);
386 
387         let message = stream
388             .try_next()
389             .await?
390             .ok_or_else(|| Status::new(Code::Internal, "Missing request message."))?;
391 
392         let mut req = Request::from_http_parts(parts, message);
393 
394         if let Some(trailers) = stream.trailers().await? {
395             req.metadata_mut().merge(trailers);
396         }
397 
398         Ok(req)
399     }
400 
map_request_streaming<B>( &mut self, request: http::Request<B>, ) -> Result<Request<Streaming<T::Decode>>, Status> where B: Body + Send + 'static, B::Error: Into<crate::Error> + Send,401     fn map_request_streaming<B>(
402         &mut self,
403         request: http::Request<B>,
404     ) -> Result<Request<Streaming<T::Decode>>, Status>
405     where
406         B: Body + Send + 'static,
407         B::Error: Into<crate::Error> + Send,
408     {
409         let encoding = self.request_encoding_if_supported(&request)?;
410 
411         let request = request.map(|body| {
412             Streaming::new_request(
413                 self.codec.decoder(),
414                 body,
415                 encoding,
416                 self.max_decoding_message_size,
417             )
418         });
419 
420         Ok(Request::from_http(request))
421     }
422 
map_response<B>( &mut self, response: Result<crate::Response<B>, Status>, accept_encoding: Option<CompressionEncoding>, compression_override: SingleMessageCompressionOverride, max_message_size: Option<usize>, ) -> http::Response<BoxBody> where B: Stream<Item = Result<T::Encode, Status>> + Send + 'static,423     fn map_response<B>(
424         &mut self,
425         response: Result<crate::Response<B>, Status>,
426         accept_encoding: Option<CompressionEncoding>,
427         compression_override: SingleMessageCompressionOverride,
428         max_message_size: Option<usize>,
429     ) -> http::Response<BoxBody>
430     where
431         B: Stream<Item = Result<T::Encode, Status>> + Send + 'static,
432     {
433         let response = match response {
434             Ok(r) => r,
435             Err(status) => return status.to_http(),
436         };
437 
438         let (mut parts, body) = response.into_http().into_parts();
439 
440         // Set the content type
441         parts.headers.insert(
442             http::header::CONTENT_TYPE,
443             http::header::HeaderValue::from_static("application/grpc"),
444         );
445 
446         #[cfg(any(feature = "gzip", feature = "zstd"))]
447         if let Some(encoding) = accept_encoding {
448             // Set the content encoding
449             parts.headers.insert(
450                 crate::codec::compression::ENCODING_HEADER,
451                 encoding.into_header_value(),
452             );
453         }
454 
455         let body = encode_server(
456             self.codec.encoder(),
457             body,
458             accept_encoding,
459             compression_override,
460             max_message_size,
461         );
462 
463         http::Response::from_parts(parts, BoxBody::new(body))
464     }
465 
request_encoding_if_supported<B>( &self, request: &http::Request<B>, ) -> Result<Option<CompressionEncoding>, Status>466     fn request_encoding_if_supported<B>(
467         &self,
468         request: &http::Request<B>,
469     ) -> Result<Option<CompressionEncoding>, Status> {
470         CompressionEncoding::from_encoding_header(
471             request.headers(),
472             self.accept_compression_encodings,
473         )
474     }
475 }
476 
477 impl<T: fmt::Debug> fmt::Debug for Grpc<T> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result478     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
479         let mut f = f.debug_struct("Grpc");
480 
481         f.field("codec", &self.codec);
482 
483         f.field(
484             "accept_compression_encodings",
485             &self.accept_compression_encodings,
486         );
487 
488         f.field(
489             "send_compression_encodings",
490             &self.send_compression_encodings,
491         );
492 
493         f.finish()
494     }
495 }
496 
compression_override_from_response<B, E>( res: &Result<crate::Response<B>, E>, ) -> SingleMessageCompressionOverride497 fn compression_override_from_response<B, E>(
498     res: &Result<crate::Response<B>, E>,
499 ) -> SingleMessageCompressionOverride {
500     res.as_ref()
501         .ok()
502         .and_then(|response| {
503             response
504                 .extensions()
505                 .get::<SingleMessageCompressionOverride>()
506                 .copied()
507         })
508         .unwrap_or_default()
509 }
510