1 use crate::codec::compression::{
2 CompressionEncoding, EnabledCompressionEncodings, SingleMessageCompressionOverride,
3 };
4 use crate::{
5 body::BoxBody,
6 codec::{encode_server, Codec, Streaming},
7 server::{ClientStreamingService, ServerStreamingService, StreamingService, UnaryService},
8 Code, Request, Status,
9 };
10 use http_body::Body;
11 use std::fmt;
12 use tokio_stream::{Stream, StreamExt};
13
/// Early-return helper for handlers that produce an HTTP response.
///
/// Evaluates `$res` (a `Result`). On `Ok` the macro yields the inner value;
/// on `Err` it returns from the enclosing function with the error's
/// `to_http()` conversion. In this file it is applied to `Result<_, Status>`
/// values inside functions returning `http::Response<BoxBody>`.
macro_rules! t {
    ($res:expr) => {
        match $res {
            Err(err) => return err.to_http(),
            Ok(inner) => inner,
        }
    };
}
22
23 /// A gRPC Server handler.
24 ///
25 /// This will wrap some inner [`Codec`] and provide utilities to handle
26 /// inbound unary, client side streaming, server side streaming, and
27 /// bi-directional streaming.
28 ///
29 /// Each request handler method accepts some service that implements the
30 /// corresponding service trait and a http request that contains some body that
31 /// implements some [`Body`].
32 pub struct Grpc<T> {
33 codec: T,
34 /// Which compression encodings does the server accept for requests?
35 accept_compression_encodings: EnabledCompressionEncodings,
36 /// Which compression encodings might the server use for responses.
37 send_compression_encodings: EnabledCompressionEncodings,
38 /// Limits the maximum size of a decoded message.
39 max_decoding_message_size: Option<usize>,
40 /// Limits the maximum size of an encoded message.
41 max_encoding_message_size: Option<usize>,
42 }
43
44 impl<T> Grpc<T>
45 where
46 T: Codec,
47 {
48 /// Creates a new gRPC server with the provided [`Codec`].
new(codec: T) -> Self49 pub fn new(codec: T) -> Self {
50 Self {
51 codec,
52 accept_compression_encodings: EnabledCompressionEncodings::default(),
53 send_compression_encodings: EnabledCompressionEncodings::default(),
54 max_decoding_message_size: None,
55 max_encoding_message_size: None,
56 }
57 }
58
59 /// Enable accepting compressed requests.
60 ///
61 /// If a request with an unsupported encoding is received the server will respond with
62 /// [`Code::UnUnimplemented`](crate::Code).
63 ///
64 /// # Example
65 ///
66 /// The most common way of using this is through a server generated by tonic-build:
67 ///
68 /// ```rust
69 /// # enum CompressionEncoding { Gzip }
70 /// # struct Svc;
71 /// # struct ExampleServer<T>(T);
72 /// # impl<T> ExampleServer<T> {
73 /// # fn new(svc: T) -> Self { Self(svc) }
74 /// # fn accept_compressed(self, _: CompressionEncoding) -> Self { self }
75 /// # }
76 /// # #[tonic::async_trait]
77 /// # trait Example {}
78 ///
79 /// #[tonic::async_trait]
80 /// impl Example for Svc {
81 /// // ...
82 /// }
83 ///
84 /// let service = ExampleServer::new(Svc).accept_compressed(CompressionEncoding::Gzip);
85 /// ```
accept_compressed(mut self, encoding: CompressionEncoding) -> Self86 pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
87 self.accept_compression_encodings.enable(encoding);
88 self
89 }
90
91 /// Enable sending compressed responses.
92 ///
93 /// Requires the client to also support receiving compressed responses.
94 ///
95 /// # Example
96 ///
97 /// The most common way of using this is through a server generated by tonic-build:
98 ///
99 /// ```rust
100 /// # enum CompressionEncoding { Gzip }
101 /// # struct Svc;
102 /// # struct ExampleServer<T>(T);
103 /// # impl<T> ExampleServer<T> {
104 /// # fn new(svc: T) -> Self { Self(svc) }
105 /// # fn send_compressed(self, _: CompressionEncoding) -> Self { self }
106 /// # }
107 /// # #[tonic::async_trait]
108 /// # trait Example {}
109 ///
110 /// #[tonic::async_trait]
111 /// impl Example for Svc {
112 /// // ...
113 /// }
114 ///
115 /// let service = ExampleServer::new(Svc).send_compressed(CompressionEncoding::Gzip);
116 /// ```
send_compressed(mut self, encoding: CompressionEncoding) -> Self117 pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
118 self.send_compression_encodings.enable(encoding);
119 self
120 }
121
122 /// Limits the maximum size of a decoded message.
123 ///
124 /// # Example
125 ///
126 /// The most common way of using this is through a server generated by tonic-build:
127 ///
128 /// ```rust
129 /// # struct Svc;
130 /// # struct ExampleServer<T>(T);
131 /// # impl<T> ExampleServer<T> {
132 /// # fn new(svc: T) -> Self { Self(svc) }
133 /// # fn max_decoding_message_size(self, _: usize) -> Self { self }
134 /// # }
135 /// # #[tonic::async_trait]
136 /// # trait Example {}
137 ///
138 /// #[tonic::async_trait]
139 /// impl Example for Svc {
140 /// // ...
141 /// }
142 ///
143 /// // Set the limit to 2MB, Defaults to 4MB.
144 /// let limit = 2 * 1024 * 1024;
145 /// let service = ExampleServer::new(Svc).max_decoding_message_size(limit);
146 /// ```
max_decoding_message_size(mut self, limit: usize) -> Self147 pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
148 self.max_decoding_message_size = Some(limit);
149 self
150 }
151
152 /// Limits the maximum size of a encoded message.
153 ///
154 /// # Example
155 ///
156 /// The most common way of using this is through a server generated by tonic-build:
157 ///
158 /// ```rust
159 /// # struct Svc;
160 /// # struct ExampleServer<T>(T);
161 /// # impl<T> ExampleServer<T> {
162 /// # fn new(svc: T) -> Self { Self(svc) }
163 /// # fn max_encoding_message_size(self, _: usize) -> Self { self }
164 /// # }
165 /// # #[tonic::async_trait]
166 /// # trait Example {}
167 ///
168 /// #[tonic::async_trait]
169 /// impl Example for Svc {
170 /// // ...
171 /// }
172 ///
173 /// // Set the limit to 2MB, Defaults to 4MB.
174 /// let limit = 2 * 1024 * 1024;
175 /// let service = ExampleServer::new(Svc).max_encoding_message_size(limit);
176 /// ```
max_encoding_message_size(mut self, limit: usize) -> Self177 pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
178 self.max_encoding_message_size = Some(limit);
179 self
180 }
181
182 #[doc(hidden)]
apply_compression_config( self, accept_encodings: EnabledCompressionEncodings, send_encodings: EnabledCompressionEncodings, ) -> Self183 pub fn apply_compression_config(
184 self,
185 accept_encodings: EnabledCompressionEncodings,
186 send_encodings: EnabledCompressionEncodings,
187 ) -> Self {
188 let mut this = self;
189
190 for &encoding in CompressionEncoding::encodings() {
191 if accept_encodings.is_enabled(encoding) {
192 this = this.accept_compressed(encoding);
193 }
194 if send_encodings.is_enabled(encoding) {
195 this = this.send_compressed(encoding);
196 }
197 }
198
199 this
200 }
201
202 #[doc(hidden)]
apply_max_message_size_config( self, max_decoding_message_size: Option<usize>, max_encoding_message_size: Option<usize>, ) -> Self203 pub fn apply_max_message_size_config(
204 self,
205 max_decoding_message_size: Option<usize>,
206 max_encoding_message_size: Option<usize>,
207 ) -> Self {
208 let mut this = self;
209
210 if let Some(limit) = max_decoding_message_size {
211 this = this.max_decoding_message_size(limit);
212 }
213 if let Some(limit) = max_encoding_message_size {
214 this = this.max_encoding_message_size(limit);
215 }
216
217 this
218 }
219
220 /// Handle a single unary gRPC request.
unary<S, B>( &mut self, mut service: S, req: http::Request<B>, ) -> http::Response<BoxBody> where S: UnaryService<T::Decode, Response = T::Encode>, B: Body + Send + 'static, B::Error: Into<crate::Error> + Send,221 pub async fn unary<S, B>(
222 &mut self,
223 mut service: S,
224 req: http::Request<B>,
225 ) -> http::Response<BoxBody>
226 where
227 S: UnaryService<T::Decode, Response = T::Encode>,
228 B: Body + Send + 'static,
229 B::Error: Into<crate::Error> + Send,
230 {
231 let accept_encoding = CompressionEncoding::from_accept_encoding_header(
232 req.headers(),
233 self.send_compression_encodings,
234 );
235
236 let request = match self.map_request_unary(req).await {
237 Ok(r) => r,
238 Err(status) => {
239 return self.map_response::<tokio_stream::Once<Result<T::Encode, Status>>>(
240 Err(status),
241 accept_encoding,
242 SingleMessageCompressionOverride::default(),
243 self.max_encoding_message_size,
244 );
245 }
246 };
247
248 let response = service
249 .call(request)
250 .await
251 .map(|r| r.map(|m| tokio_stream::once(Ok(m))));
252
253 let compression_override = compression_override_from_response(&response);
254
255 self.map_response(
256 response,
257 accept_encoding,
258 compression_override,
259 self.max_encoding_message_size,
260 )
261 }
262
263 /// Handle a server side streaming request.
server_streaming<S, B>( &mut self, mut service: S, req: http::Request<B>, ) -> http::Response<BoxBody> where S: ServerStreamingService<T::Decode, Response = T::Encode>, S::ResponseStream: Send + 'static, B: Body + Send + 'static, B::Error: Into<crate::Error> + Send,264 pub async fn server_streaming<S, B>(
265 &mut self,
266 mut service: S,
267 req: http::Request<B>,
268 ) -> http::Response<BoxBody>
269 where
270 S: ServerStreamingService<T::Decode, Response = T::Encode>,
271 S::ResponseStream: Send + 'static,
272 B: Body + Send + 'static,
273 B::Error: Into<crate::Error> + Send,
274 {
275 let accept_encoding = CompressionEncoding::from_accept_encoding_header(
276 req.headers(),
277 self.send_compression_encodings,
278 );
279
280 let request = match self.map_request_unary(req).await {
281 Ok(r) => r,
282 Err(status) => {
283 return self.map_response::<S::ResponseStream>(
284 Err(status),
285 accept_encoding,
286 SingleMessageCompressionOverride::default(),
287 self.max_encoding_message_size,
288 );
289 }
290 };
291
292 let response = service.call(request).await;
293
294 self.map_response(
295 response,
296 accept_encoding,
297 // disabling compression of individual stream items must be done on
298 // the items themselves
299 SingleMessageCompressionOverride::default(),
300 self.max_encoding_message_size,
301 )
302 }
303
304 /// Handle a client side streaming gRPC request.
client_streaming<S, B>( &mut self, mut service: S, req: http::Request<B>, ) -> http::Response<BoxBody> where S: ClientStreamingService<T::Decode, Response = T::Encode>, B: Body + Send + 'static, B::Error: Into<crate::Error> + Send + 'static,305 pub async fn client_streaming<S, B>(
306 &mut self,
307 mut service: S,
308 req: http::Request<B>,
309 ) -> http::Response<BoxBody>
310 where
311 S: ClientStreamingService<T::Decode, Response = T::Encode>,
312 B: Body + Send + 'static,
313 B::Error: Into<crate::Error> + Send + 'static,
314 {
315 let accept_encoding = CompressionEncoding::from_accept_encoding_header(
316 req.headers(),
317 self.send_compression_encodings,
318 );
319
320 let request = t!(self.map_request_streaming(req));
321
322 let response = service
323 .call(request)
324 .await
325 .map(|r| r.map(|m| tokio_stream::once(Ok(m))));
326
327 let compression_override = compression_override_from_response(&response);
328
329 self.map_response(
330 response,
331 accept_encoding,
332 compression_override,
333 self.max_encoding_message_size,
334 )
335 }
336
337 /// Handle a bi-directional streaming gRPC request.
streaming<S, B>( &mut self, mut service: S, req: http::Request<B>, ) -> http::Response<BoxBody> where S: StreamingService<T::Decode, Response = T::Encode> + Send, S::ResponseStream: Send + 'static, B: Body + Send + 'static, B::Error: Into<crate::Error> + Send,338 pub async fn streaming<S, B>(
339 &mut self,
340 mut service: S,
341 req: http::Request<B>,
342 ) -> http::Response<BoxBody>
343 where
344 S: StreamingService<T::Decode, Response = T::Encode> + Send,
345 S::ResponseStream: Send + 'static,
346 B: Body + Send + 'static,
347 B::Error: Into<crate::Error> + Send,
348 {
349 let accept_encoding = CompressionEncoding::from_accept_encoding_header(
350 req.headers(),
351 self.send_compression_encodings,
352 );
353
354 let request = t!(self.map_request_streaming(req));
355
356 let response = service.call(request).await;
357
358 self.map_response(
359 response,
360 accept_encoding,
361 SingleMessageCompressionOverride::default(),
362 self.max_encoding_message_size,
363 )
364 }
365
map_request_unary<B>( &mut self, request: http::Request<B>, ) -> Result<Request<T::Decode>, Status> where B: Body + Send + 'static, B::Error: Into<crate::Error> + Send,366 async fn map_request_unary<B>(
367 &mut self,
368 request: http::Request<B>,
369 ) -> Result<Request<T::Decode>, Status>
370 where
371 B: Body + Send + 'static,
372 B::Error: Into<crate::Error> + Send,
373 {
374 let request_compression_encoding = self.request_encoding_if_supported(&request)?;
375
376 let (parts, body) = request.into_parts();
377
378 let stream = Streaming::new_request(
379 self.codec.decoder(),
380 body,
381 request_compression_encoding,
382 self.max_decoding_message_size,
383 );
384
385 tokio::pin!(stream);
386
387 let message = stream
388 .try_next()
389 .await?
390 .ok_or_else(|| Status::new(Code::Internal, "Missing request message."))?;
391
392 let mut req = Request::from_http_parts(parts, message);
393
394 if let Some(trailers) = stream.trailers().await? {
395 req.metadata_mut().merge(trailers);
396 }
397
398 Ok(req)
399 }
400
map_request_streaming<B>( &mut self, request: http::Request<B>, ) -> Result<Request<Streaming<T::Decode>>, Status> where B: Body + Send + 'static, B::Error: Into<crate::Error> + Send,401 fn map_request_streaming<B>(
402 &mut self,
403 request: http::Request<B>,
404 ) -> Result<Request<Streaming<T::Decode>>, Status>
405 where
406 B: Body + Send + 'static,
407 B::Error: Into<crate::Error> + Send,
408 {
409 let encoding = self.request_encoding_if_supported(&request)?;
410
411 let request = request.map(|body| {
412 Streaming::new_request(
413 self.codec.decoder(),
414 body,
415 encoding,
416 self.max_decoding_message_size,
417 )
418 });
419
420 Ok(Request::from_http(request))
421 }
422
map_response<B>( &mut self, response: Result<crate::Response<B>, Status>, accept_encoding: Option<CompressionEncoding>, compression_override: SingleMessageCompressionOverride, max_message_size: Option<usize>, ) -> http::Response<BoxBody> where B: Stream<Item = Result<T::Encode, Status>> + Send + 'static,423 fn map_response<B>(
424 &mut self,
425 response: Result<crate::Response<B>, Status>,
426 accept_encoding: Option<CompressionEncoding>,
427 compression_override: SingleMessageCompressionOverride,
428 max_message_size: Option<usize>,
429 ) -> http::Response<BoxBody>
430 where
431 B: Stream<Item = Result<T::Encode, Status>> + Send + 'static,
432 {
433 let response = match response {
434 Ok(r) => r,
435 Err(status) => return status.to_http(),
436 };
437
438 let (mut parts, body) = response.into_http().into_parts();
439
440 // Set the content type
441 parts.headers.insert(
442 http::header::CONTENT_TYPE,
443 http::header::HeaderValue::from_static("application/grpc"),
444 );
445
446 #[cfg(any(feature = "gzip", feature = "zstd"))]
447 if let Some(encoding) = accept_encoding {
448 // Set the content encoding
449 parts.headers.insert(
450 crate::codec::compression::ENCODING_HEADER,
451 encoding.into_header_value(),
452 );
453 }
454
455 let body = encode_server(
456 self.codec.encoder(),
457 body,
458 accept_encoding,
459 compression_override,
460 max_message_size,
461 );
462
463 http::Response::from_parts(parts, BoxBody::new(body))
464 }
465
request_encoding_if_supported<B>( &self, request: &http::Request<B>, ) -> Result<Option<CompressionEncoding>, Status>466 fn request_encoding_if_supported<B>(
467 &self,
468 request: &http::Request<B>,
469 ) -> Result<Option<CompressionEncoding>, Status> {
470 CompressionEncoding::from_encoding_header(
471 request.headers(),
472 self.accept_compression_encodings,
473 )
474 }
475 }
476
477 impl<T: fmt::Debug> fmt::Debug for Grpc<T> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result478 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
479 let mut f = f.debug_struct("Grpc");
480
481 f.field("codec", &self.codec);
482
483 f.field(
484 "accept_compression_encodings",
485 &self.accept_compression_encodings,
486 );
487
488 f.field(
489 "send_compression_encodings",
490 &self.send_compression_encodings,
491 );
492
493 f.finish()
494 }
495 }
496
compression_override_from_response<B, E>( res: &Result<crate::Response<B>, E>, ) -> SingleMessageCompressionOverride497 fn compression_override_from_response<B, E>(
498 res: &Result<crate::Response<B>, E>,
499 ) -> SingleMessageCompressionOverride {
500 res.as_ref()
501 .ok()
502 .and_then(|response| {
503 response
504 .extensions()
505 .get::<SingleMessageCompressionOverride>()
506 .copied()
507 })
508 .unwrap_or_default()
509 }
510