//! gRPC interceptors, a lightweight kind of middleware.
2 //!
3 //! See [`Interceptor`] for more details.
4 
5 use crate::{
6     body::{boxed, BoxBody},
7     request::SanitizeHeaders,
8     Status,
9 };
10 use bytes::Bytes;
11 use pin_project::pin_project;
12 use std::{
13     fmt,
14     future::Future,
15     pin::Pin,
16     task::{Context, Poll},
17 };
18 use tower_layer::Layer;
19 use tower_service::Service;
20 
/// A gRPC interceptor.
///
/// gRPC interceptors are similar to middleware but have less flexibility. An interceptor allows
/// you to do two main things: add, remove, or check items in the `MetadataMap` (and extensions)
/// of each request, and cancel a request by returning a `Status`.
///
/// Any function that satisfies the bound `FnMut(Request<()>) -> Result<Request<()>, Status>` can be
/// used as an `Interceptor`.
///
/// An interceptor can be used on both the server and client side through the `tonic-build` crate's
/// generated structs.
///
/// See the [interceptor example][example] for more details.
///
/// If you need more powerful middleware, [tower] is the recommended approach. You can find
/// examples of how to use tower with tonic [here][tower-example].
///
/// Additionally, interceptors are not the recommended way to add logging to your service. For
/// that, a [tower] middleware is more appropriate since it can also act on the response. For
/// example, tower-http's [`Trace`](https://docs.rs/tower-http/latest/tower_http/trace/index.html)
/// middleware supports gRPC out of the box.
///
/// [tower]: https://crates.io/crates/tower
/// [example]: https://github.com/hyperium/tonic/tree/master/examples/src/interceptor
/// [tower-example]: https://github.com/hyperium/tonic/tree/master/examples/src/tower
pub trait Interceptor {
    /// Inspect or modify a request before it is passed on, optionally cancelling it.
    ///
    /// The request's message is replaced by `()` so that only metadata and extensions are
    /// visible. Return `Ok` with the (possibly modified) request to let it continue, or
    /// `Err(status)` to reject it with that `Status`.
    fn call(&mut self, request: crate::Request<()>) -> Result<crate::Request<()>, Status>;
}
50 
51 impl<F> Interceptor for F
52 where
53     F: FnMut(crate::Request<()>) -> Result<crate::Request<()>, Status>,
54 {
call(&mut self, request: crate::Request<()>) -> Result<crate::Request<()>, Status>55     fn call(&mut self, request: crate::Request<()>) -> Result<crate::Request<()>, Status> {
56         self(request)
57     }
58 }
59 
60 /// Create a new interceptor layer.
61 ///
62 /// See [`Interceptor`] for more details.
interceptor<F>(f: F) -> InterceptorLayer<F> where F: Interceptor,63 pub fn interceptor<F>(f: F) -> InterceptorLayer<F>
64 where
65     F: Interceptor,
66 {
67     InterceptorLayer { f }
68 }
69 
/// The [`Layer`] form of a gRPC interceptor, produced by [`interceptor`].
///
/// See [`Interceptor`] for more details.
#[derive(Clone, Copy, Debug)]
pub struct InterceptorLayer<F> {
    /// The interceptor; a clone of it is handed to every service this layer wraps.
    f: F,
}
78 
79 impl<S, F> Layer<S> for InterceptorLayer<F>
80 where
81     F: Interceptor + Clone,
82 {
83     type Service = InterceptedService<S, F>;
84 
layer(&self, service: S) -> Self::Service85     fn layer(&self, service: S) -> Self::Service {
86         InterceptedService::new(service, self.f.clone())
87     }
88 }
89 
/// A service that runs an interceptor on each request before forwarding it to `inner`.
///
/// See [`Interceptor`] for more details.
#[derive(Copy, Clone)]
pub struct InterceptedService<S, F> {
    /// The wrapped service that handles requests the interceptor lets through.
    inner: S,
    /// The interceptor invoked on every request.
    f: F,
}
98 
99 impl<S, F> InterceptedService<S, F> {
100     /// Create a new `InterceptedService` that wraps `S` and intercepts each request with the
101     /// function `F`.
new(service: S, f: F) -> Self where F: Interceptor,102     pub fn new(service: S, f: F) -> Self
103     where
104         F: Interceptor,
105     {
106         Self { inner: service, f }
107     }
108 }
109 
110 impl<S, F> fmt::Debug for InterceptedService<S, F>
111 where
112     S: fmt::Debug,
113 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result114     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
115         f.debug_struct("InterceptedService")
116             .field("inner", &self.inner)
117             .field("f", &format_args!("{}", std::any::type_name::<F>()))
118             .finish()
119     }
120 }
121 
122 impl<S, F, ReqBody, ResBody> Service<http::Request<ReqBody>> for InterceptedService<S, F>
123 where
124     ResBody: Default + http_body::Body<Data = Bytes> + Send + 'static,
125     F: Interceptor,
126     S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>>,
127     S::Error: Into<crate::Error>,
128     ResBody: http_body::Body<Data = bytes::Bytes> + Send + 'static,
129     ResBody::Error: Into<crate::Error>,
130 {
131     type Response = http::Response<BoxBody>;
132     type Error = S::Error;
133     type Future = ResponseFuture<S::Future>;
134 
135     #[inline]
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>136     fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
137         self.inner.poll_ready(cx)
138     }
139 
call(&mut self, req: http::Request<ReqBody>) -> Self::Future140     fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
141         // It is bad practice to modify the body (i.e. Message) of the request via an interceptor.
142         // To avoid exposing the body of the request to the interceptor function, we first remove it
143         // here, allow the interceptor to modify the metadata and extensions, and then recreate the
144         // HTTP request with the body. Tonic requests do not preserve the URI, HTTP version, and
145         // HTTP method of the HTTP request, so we extract them here and then add them back in below.
146         let uri = req.uri().clone();
147         let method = req.method().clone();
148         let version = req.version();
149         let req = crate::Request::from_http(req);
150         let (metadata, extensions, msg) = req.into_parts();
151 
152         match self
153             .f
154             .call(crate::Request::from_parts(metadata, extensions, ()))
155         {
156             Ok(req) => {
157                 let (metadata, extensions, _) = req.into_parts();
158                 let req = crate::Request::from_parts(metadata, extensions, msg);
159                 let req = req.into_http(uri, method, version, SanitizeHeaders::No);
160                 ResponseFuture::future(self.inner.call(req))
161             }
162             Err(status) => ResponseFuture::status(status),
163         }
164     }
165 }
166 
167 // required to use `InterceptedService` with `Router`
168 impl<S, F> crate::server::NamedService for InterceptedService<S, F>
169 where
170     S: crate::server::NamedService,
171 {
172     const NAME: &'static str = S::NAME;
173 }
174 
/// Response future for [`InterceptedService`].
///
/// Either drives the wrapped service's future or, when the interceptor rejected the request,
/// resolves immediately to the rejection `Status` converted into an HTTP response.
#[pin_project]
#[derive(Debug)]
pub struct ResponseFuture<F> {
    /// Which of the two outcomes this future represents.
    #[pin]
    kind: Kind<F>,
}
182 
183 impl<F> ResponseFuture<F> {
future(future: F) -> Self184     fn future(future: F) -> Self {
185         Self {
186             kind: Kind::Future(future),
187         }
188     }
189 
status(status: Status) -> Self190     fn status(status: Status) -> Self {
191         Self {
192             kind: Kind::Status(Some(status)),
193         }
194     }
195 }
196 
/// Internal state of [`ResponseFuture`]; `KindProj` is its pin projection used in `poll`.
#[pin_project(project = KindProj)]
#[derive(Debug)]
enum Kind<F> {
    /// The interceptor accepted the request; this is the wrapped service's future.
    Future(#[pin] F),
    /// The interceptor rejected the request. Holds the `Status` until the first poll takes it.
    Status(Option<Status>),
}
203 
204 impl<F, E, B> Future for ResponseFuture<F>
205 where
206     F: Future<Output = Result<http::Response<B>, E>>,
207     E: Into<crate::Error>,
208     B: Default + http_body::Body<Data = Bytes> + Send + 'static,
209     B::Error: Into<crate::Error>,
210 {
211     type Output = Result<http::Response<BoxBody>, E>;
212 
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>213     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
214         match self.project().kind.project() {
215             KindProj::Future(future) => future
216                 .poll(cx)
217                 .map(|result| result.map(|res| res.map(boxed))),
218             KindProj::Status(status) => {
219                 let response = status
220                     .take()
221                     .unwrap()
222                     .to_http()
223                     .map(|_| B::default())
224                     .map(boxed);
225                 Poll::Ready(Ok(response))
226             }
227         }
228     }
229 }
230 
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
    use http::header::HeaderMap;
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };
    use tower::ServiceExt;

    /// Minimal empty body used as both request and response body in the tests.
    #[derive(Debug, Default)]
    struct TestBody;

    impl http_body::Body for TestBody {
        type Data = Bytes;
        type Error = Status;

        fn poll_data(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
            Poll::Ready(None)
        }

        fn poll_trailers(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
            Poll::Ready(Ok(None))
        }
    }

    #[tokio::test]
    async fn doesnt_remove_headers_from_requests() {
        let svc = tower::service_fn(|request: http::Request<TestBody>| async move {
            assert_eq!(
                request
                    .headers()
                    .get("user-agent")
                    .expect("missing in leaf service"),
                "test-tonic"
            );

            Ok::<_, Status>(http::Response::new(TestBody))
        });

        let svc = InterceptedService::new(svc, |request: crate::Request<()>| {
            assert_eq!(
                request
                    .metadata()
                    .get("user-agent")
                    .expect("missing in interceptor"),
                "test-tonic"
            );

            Ok(request)
        });

        let request = http::Request::builder()
            .header("user-agent", "test-tonic")
            .body(TestBody)
            .unwrap();

        svc.oneshot(request).await.unwrap();
    }

    // Metadata inserted by the interceptor must reach the inner service as HTTP headers.
    #[tokio::test]
    async fn forwards_metadata_added_by_interceptor() {
        let svc = tower::service_fn(|request: http::Request<TestBody>| async move {
            assert_eq!(
                request
                    .headers()
                    .get("x-intercepted")
                    .expect("header added by interceptor is missing in leaf service"),
                "yes"
            );

            Ok::<_, Status>(http::Response::new(TestBody))
        });

        let svc = InterceptedService::new(svc, |mut request: crate::Request<()>| {
            request
                .metadata_mut()
                .insert("x-intercepted", "yes".parse().unwrap());

            Ok(request)
        });

        let request = http::Request::builder().body(TestBody).unwrap();

        svc.oneshot(request).await.unwrap();
    }

    #[tokio::test]
    async fn handles_intercepted_status_as_response() {
        let message = "Blocked by the interceptor";
        let expected = Status::permission_denied(message).to_http();

        let svc = tower::service_fn(|_: http::Request<TestBody>| async {
            Ok::<_, Status>(http::Response::new(TestBody))
        });

        let svc = InterceptedService::new(svc, |_: crate::Request<()>| {
            Err(Status::permission_denied(message))
        });

        let request = http::Request::builder().body(TestBody).unwrap();
        let response = svc.oneshot(request).await.unwrap();

        assert_eq!(expected.status(), response.status());
        assert_eq!(expected.version(), response.version());
        assert_eq!(expected.headers(), response.headers());
    }

    #[tokio::test]
    async fn doesnt_change_http_method() {
        let svc = tower::service_fn(|request: http::Request<hyper::Body>| async move {
            assert_eq!(request.method(), http::Method::OPTIONS);

            Ok::<_, hyper::Error>(hyper::Response::new(hyper::Body::empty()))
        });

        let svc = InterceptedService::new(svc, Ok);

        let request = http::Request::builder()
            .method(http::Method::OPTIONS)
            .body(hyper::Body::empty())
            .unwrap();

        svc.oneshot(request).await.unwrap();
    }
}
337