//! gRPC interceptors, which are a kind of middleware.
2 //!
3 //! See [`Interceptor`] for more details.
4
5 use crate::{
6 body::{boxed, BoxBody},
7 request::SanitizeHeaders,
8 Status,
9 };
10 use bytes::Bytes;
11 use pin_project::pin_project;
12 use std::{
13 fmt,
14 future::Future,
15 pin::Pin,
16 task::{Context, Poll},
17 };
18 use tower_layer::Layer;
19 use tower_service::Service;
20
21 /// A gRPC interceptor.
22 ///
23 /// gRPC interceptors are similar to middleware but have less flexibility. An interceptor allows
24 /// you to do two main things, one is to add/remove/check items in the `MetadataMap` of each
25 /// request. Two, cancel a request with a `Status`.
26 ///
27 /// Any function that satisfies the bound `FnMut(Request<()>) -> Result<Request<()>, Status>` can be
28 /// used as an `Interceptor`.
29 ///
30 /// An interceptor can be used on both the server and client side through the `tonic-build` crate's
31 /// generated structs.
32 ///
33 /// See the [interceptor example][example] for more details.
34 ///
35 /// If you need more powerful middleware, [tower] is the recommended approach. You can find
36 /// examples of how to use tower with tonic [here][tower-example].
37 ///
38 /// Additionally, interceptors is not the recommended way to add logging to your service. For that
39 /// a [tower] middleware is more appropriate since it can also act on the response. For example
40 /// tower-http's [`Trace`](https://docs.rs/tower-http/latest/tower_http/trace/index.html)
41 /// middleware supports gRPC out of the box.
42 ///
43 /// [tower]: https://crates.io/crates/tower
44 /// [example]: https://github.com/hyperium/tonic/tree/master/examples/src/interceptor
45 /// [tower-example]: https://github.com/hyperium/tonic/tree/master/examples/src/tower
46 pub trait Interceptor {
47 /// Intercept a request before it is sent, optionally cancelling it.
call(&mut self, request: crate::Request<()>) -> Result<crate::Request<()>, Status>48 fn call(&mut self, request: crate::Request<()>) -> Result<crate::Request<()>, Status>;
49 }
50
51 impl<F> Interceptor for F
52 where
53 F: FnMut(crate::Request<()>) -> Result<crate::Request<()>, Status>,
54 {
call(&mut self, request: crate::Request<()>) -> Result<crate::Request<()>, Status>55 fn call(&mut self, request: crate::Request<()>) -> Result<crate::Request<()>, Status> {
56 self(request)
57 }
58 }
59
60 /// Create a new interceptor layer.
61 ///
62 /// See [`Interceptor`] for more details.
interceptor<F>(f: F) -> InterceptorLayer<F> where F: Interceptor,63 pub fn interceptor<F>(f: F) -> InterceptorLayer<F>
64 where
65 F: Interceptor,
66 {
67 InterceptorLayer { f }
68 }
69
/// A gRPC interceptor that can be used as a [`Layer`].
///
/// Values of this type are created by calling [`interceptor`]. See [`Interceptor`] for more
/// details.
#[derive(Debug, Clone, Copy)]
pub struct InterceptorLayer<F> {
    /// Interceptor that is cloned into each service this layer wraps.
    f: F,
}
78
79 impl<S, F> Layer<S> for InterceptorLayer<F>
80 where
81 F: Interceptor + Clone,
82 {
83 type Service = InterceptedService<S, F>;
84
layer(&self, service: S) -> Self::Service85 fn layer(&self, service: S) -> Self::Service {
86 InterceptedService::new(service, self.f.clone())
87 }
88 }
89
/// A service wrapped in an interceptor middleware.
///
/// Each request passes through the interceptor before it reaches `inner`.
/// See [`Interceptor`] for more details.
#[derive(Clone, Copy)]
pub struct InterceptedService<S, F> {
    /// The wrapped service that receives requests accepted by the interceptor.
    inner: S,
    /// The interceptor run on every request.
    f: F,
}
98
99 impl<S, F> InterceptedService<S, F> {
100 /// Create a new `InterceptedService` that wraps `S` and intercepts each request with the
101 /// function `F`.
new(service: S, f: F) -> Self where F: Interceptor,102 pub fn new(service: S, f: F) -> Self
103 where
104 F: Interceptor,
105 {
106 Self { inner: service, f }
107 }
108 }
109
110 impl<S, F> fmt::Debug for InterceptedService<S, F>
111 where
112 S: fmt::Debug,
113 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result114 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
115 f.debug_struct("InterceptedService")
116 .field("inner", &self.inner)
117 .field("f", &format_args!("{}", std::any::type_name::<F>()))
118 .finish()
119 }
120 }
121
122 impl<S, F, ReqBody, ResBody> Service<http::Request<ReqBody>> for InterceptedService<S, F>
123 where
124 ResBody: Default + http_body::Body<Data = Bytes> + Send + 'static,
125 F: Interceptor,
126 S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>>,
127 S::Error: Into<crate::Error>,
128 ResBody: http_body::Body<Data = bytes::Bytes> + Send + 'static,
129 ResBody::Error: Into<crate::Error>,
130 {
131 type Response = http::Response<BoxBody>;
132 type Error = S::Error;
133 type Future = ResponseFuture<S::Future>;
134
135 #[inline]
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>136 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
137 self.inner.poll_ready(cx)
138 }
139
call(&mut self, req: http::Request<ReqBody>) -> Self::Future140 fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
141 // It is bad practice to modify the body (i.e. Message) of the request via an interceptor.
142 // To avoid exposing the body of the request to the interceptor function, we first remove it
143 // here, allow the interceptor to modify the metadata and extensions, and then recreate the
144 // HTTP request with the body. Tonic requests do not preserve the URI, HTTP version, and
145 // HTTP method of the HTTP request, so we extract them here and then add them back in below.
146 let uri = req.uri().clone();
147 let method = req.method().clone();
148 let version = req.version();
149 let req = crate::Request::from_http(req);
150 let (metadata, extensions, msg) = req.into_parts();
151
152 match self
153 .f
154 .call(crate::Request::from_parts(metadata, extensions, ()))
155 {
156 Ok(req) => {
157 let (metadata, extensions, _) = req.into_parts();
158 let req = crate::Request::from_parts(metadata, extensions, msg);
159 let req = req.into_http(uri, method, version, SanitizeHeaders::No);
160 ResponseFuture::future(self.inner.call(req))
161 }
162 Err(status) => ResponseFuture::status(status),
163 }
164 }
165 }
166
167 // required to use `InterceptedService` with `Router`
168 impl<S, F> crate::server::NamedService for InterceptedService<S, F>
169 where
170 S: crate::server::NamedService,
171 {
172 const NAME: &'static str = S::NAME;
173 }
174
175 /// Response future for [`InterceptedService`].
176 #[pin_project]
177 #[derive(Debug)]
178 pub struct ResponseFuture<F> {
179 #[pin]
180 kind: Kind<F>,
181 }
182
183 impl<F> ResponseFuture<F> {
future(future: F) -> Self184 fn future(future: F) -> Self {
185 Self {
186 kind: Kind::Future(future),
187 }
188 }
189
status(status: Status) -> Self190 fn status(status: Status) -> Self {
191 Self {
192 kind: Kind::Status(Some(status)),
193 }
194 }
195 }
196
197 #[pin_project(project = KindProj)]
198 #[derive(Debug)]
199 enum Kind<F> {
200 Future(#[pin] F),
201 Status(Option<Status>),
202 }
203
204 impl<F, E, B> Future for ResponseFuture<F>
205 where
206 F: Future<Output = Result<http::Response<B>, E>>,
207 E: Into<crate::Error>,
208 B: Default + http_body::Body<Data = Bytes> + Send + 'static,
209 B::Error: Into<crate::Error>,
210 {
211 type Output = Result<http::Response<BoxBody>, E>;
212
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>213 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
214 match self.project().kind.project() {
215 KindProj::Future(future) => future
216 .poll(cx)
217 .map(|result| result.map(|res| res.map(boxed))),
218 KindProj::Status(status) => {
219 let response = status
220 .take()
221 .unwrap()
222 .to_http()
223 .map(|_| B::default())
224 .map(boxed);
225 Poll::Ready(Ok(response))
226 }
227 }
228 }
229 }
230
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
    use http::header::HeaderMap;
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };
    use tower::ServiceExt;

    /// A body with no data frames and no trailers.
    #[derive(Debug, Default)]
    struct TestBody;

    impl http_body::Body for TestBody {
        type Data = Bytes;
        type Error = Status;

        fn poll_data(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
            Poll::Ready(None)
        }

        fn poll_trailers(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
            Poll::Ready(Ok(None))
        }
    }

    /// A request header must be visible to the interceptor as metadata and must still reach
    /// the wrapped service afterwards.
    #[tokio::test]
    async fn doesnt_remove_headers_from_requests() {
        let leaf = tower::service_fn(|req: http::Request<TestBody>| async move {
            let agent = req
                .headers()
                .get("user-agent")
                .expect("missing in leaf service");
            assert_eq!(agent, "test-tonic");

            Ok::<_, Status>(http::Response::new(TestBody))
        });

        let intercepted = InterceptedService::new(leaf, |req: crate::Request<()>| {
            let agent = req
                .metadata()
                .get("user-agent")
                .expect("missing in interceptor");
            assert_eq!(agent, "test-tonic");

            Ok(req)
        });

        let incoming = http::Request::builder()
            .header("user-agent", "test-tonic")
            .body(TestBody)
            .unwrap();

        intercepted.oneshot(incoming).await.unwrap();
    }

    /// A `Status` returned by the interceptor becomes the response, matching `Status::to_http`.
    #[tokio::test]
    async fn handles_intercepted_status_as_response() {
        let message = "Blocked by the interceptor";
        let expected = Status::permission_denied(message).to_http();

        let leaf = tower::service_fn(|_: http::Request<TestBody>| async {
            Ok::<_, Status>(http::Response::new(TestBody))
        });

        let intercepted = InterceptedService::new(leaf, |_: crate::Request<()>| {
            Err(Status::permission_denied(message))
        });

        let incoming = http::Request::builder().body(TestBody).unwrap();
        let actual = intercepted.oneshot(incoming).await.unwrap();

        assert_eq!(expected.status(), actual.status());
        assert_eq!(expected.version(), actual.version());
        assert_eq!(expected.headers(), actual.headers());
    }

    /// The HTTP method is restored on the request forwarded to the wrapped service.
    #[tokio::test]
    async fn doesnt_change_http_method() {
        let leaf = tower::service_fn(|req: http::Request<hyper::Body>| async move {
            assert_eq!(req.method(), http::Method::OPTIONS);

            Ok::<_, hyper::Error>(hyper::Response::new(hyper::Body::empty()))
        });

        let intercepted = InterceptedService::new(leaf, Ok);

        let incoming = http::Request::builder()
            .method(http::Method::OPTIONS)
            .body(hyper::Body::empty())
            .unwrap();

        intercepted.oneshot(incoming).await.unwrap();
    }
}
337