//! This module provides functionality to help route requests between [`Service`]s.
2  //!
3  //! # Example
4  //!
5  //! [`Steer`] can for example be used to create a router, akin to what you might find in web
6  //! frameworks.
7  //!
8  //! Here, `GET /` will be sent to the `root` service, while all other requests go to `not_found`.
9  //!
10  //! ```rust
11  //! # use std::task::{Context, Poll};
12  //! # use tower_service::Service;
13  //! # use futures_util::future::{ready, Ready, poll_fn};
14  //! # use tower::steer::Steer;
15  //! # use tower::service_fn;
16  //! # use tower::util::BoxService;
17  //! # use tower::ServiceExt;
18  //! # use std::convert::Infallible;
19  //! use http::{Request, Response, StatusCode, Method};
20  //!
21  //! # #[tokio::main]
22  //! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
23  //! // Service that responds to `GET /`
24  //! let root = service_fn(|req: Request<String>| async move {
25  //!     # assert_eq!(req.uri().path(), "/");
26  //!     let res = Response::new("Hello, World!".to_string());
27  //!     Ok::<_, Infallible>(res)
28  //! });
29  //! // We have to box the service so its type gets erased and we can put it in a `Vec` with other
30  //! // services
31  //! let root = BoxService::new(root);
32  //!
33  //! // Service that responds with `404 Not Found` to all requests
34  //! let not_found = service_fn(|req: Request<String>| async move {
35  //!     let res = Response::builder()
36  //!         .status(StatusCode::NOT_FOUND)
37  //!         .body(String::new())
38  //!         .expect("response is valid");
39  //!     Ok::<_, Infallible>(res)
40  //! });
41  //! // Box that as well
42  //! let not_found = BoxService::new(not_found);
43  //!
44  //! let mut svc = Steer::new(
45  //!     // All services we route between
46  //!     vec![root, not_found],
47  //!     // How we pick which service to send the request to
48  //!     |req: &Request<String>, _services: &[_]| {
49  //!         if req.method() == Method::GET && req.uri().path() == "/" {
50  //!             0 // Index of `root`
51  //!         } else {
52  //!             1 // Index of `not_found`
53  //!         }
54  //!     },
55  //! );
56  //!
57  //! // This request will get sent to `root`
58  //! let req = Request::get("/").body(String::new()).unwrap();
59  //! let res = svc.ready().await?.call(req).await?;
60  //! assert_eq!(res.into_body(), "Hello, World!");
61  //!
62  //! // This request will get sent to `not_found`
63  //! let req = Request::get("/does/not/exist").body(String::new()).unwrap();
64  //! let res = svc.ready().await?.call(req).await?;
65  //! assert_eq!(res.status(), StatusCode::NOT_FOUND);
66  //! assert_eq!(res.into_body(), "");
67  //! #
68  //! # Ok(())
69  //! # }
70  //! ```
71  use std::task::{Context, Poll};
72  use std::{collections::VecDeque, fmt, marker::PhantomData};
73  use tower_service::Service;
74  
/// Strategy used by [`Steer`] to decide which inner `Service` should receive a given `Req`.
pub trait Picker<S, Req> {
    /// Choose the destination service for `req`.
    ///
    /// `services` is the list that was handed to [`Steer::new`], in the same order, and the
    /// returned value is an index into it. [`Steer`] indexes its service list directly with this
    /// value, so returning an index `>= services.len()` makes the call panic.
    fn pick(&mut self, req: &Req, services: &[S]) -> usize;
}
80  
81  impl<S, F, Req> Picker<S, Req> for F
82  where
83      F: Fn(&Req, &[S]) -> usize,
84  {
pick(&mut self, r: &Req, services: &[S]) -> usize85      fn pick(&mut self, r: &Req, services: &[S]) -> usize {
86          self(r, services)
87      }
88  }
89  
/// [`Steer`] manages a list of [`Service`]s which all handle the same type of request.
///
/// An example use case is a sharded service.
/// It accepts new requests, then:
/// 1. Determines, via the provided [`Picker`], which [`Service`] the request corresponds to.
/// 2. Waits (in [`Service::poll_ready`]) for *all* services to be ready.
/// 3. Calls the correct [`Service`] with the request, and returns a future corresponding to the
///    call.
///
/// Note that [`Steer`] must wait for all services to be ready since it can't know ahead of time
/// which [`Service`] the next message will arrive for, and is unwilling to buffer items
/// indefinitely. This will cause head-of-line blocking unless paired with a [`Service`] that does
/// buffer items indefinitely, and thus always returns [`Poll::Ready`]. For example, wrapping each
/// component service with a [`Buffer`] with a high enough limit (the maximum number of concurrent
/// requests) will prevent head-of-line blocking in [`Steer`].
///
/// [`Buffer`]: crate::buffer::Buffer
pub struct Steer<S, F, Req> {
    /// The [`Picker`] that chooses a destination index for each request.
    router: F,
    /// The inner services, in the order they were supplied.
    services: Vec<S>,
    /// Indices of services that still need a successful `poll_ready` before the next `call`.
    not_ready: VecDeque<usize>,
    /// Marker for the request type; no `Req` value is ever stored.
    _phantom: PhantomData<Req>,
}
113  
114  impl<S, F, Req> Steer<S, F, Req> {
115      /// Make a new [`Steer`] with a list of [`Service`]'s and a [`Picker`].
116      ///
117      /// Note: the order of the [`Service`]'s is significant for [`Picker::pick`]'s return value.
new(services: impl IntoIterator<Item = S>, router: F) -> Self118      pub fn new(services: impl IntoIterator<Item = S>, router: F) -> Self {
119          let services: Vec<_> = services.into_iter().collect();
120          let not_ready: VecDeque<_> = services.iter().enumerate().map(|(i, _)| i).collect();
121          Self {
122              router,
123              services,
124              not_ready,
125              _phantom: PhantomData,
126          }
127      }
128  }
129  
130  impl<S, Req, F> Service<Req> for Steer<S, F, Req>
131  where
132      S: Service<Req>,
133      F: Picker<S, Req>,
134  {
135      type Response = S::Response;
136      type Error = S::Error;
137      type Future = S::Future;
138  
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>139      fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
140          loop {
141              // must wait for *all* services to be ready.
142              // this will cause head-of-line blocking unless the underlying services are always ready.
143              if self.not_ready.is_empty() {
144                  return Poll::Ready(Ok(()));
145              } else {
146                  if self.services[self.not_ready[0]]
147                      .poll_ready(cx)?
148                      .is_pending()
149                  {
150                      return Poll::Pending;
151                  }
152  
153                  self.not_ready.pop_front();
154              }
155          }
156      }
157  
call(&mut self, req: Req) -> Self::Future158      fn call(&mut self, req: Req) -> Self::Future {
159          assert!(
160              self.not_ready.is_empty(),
161              "Steer must wait for all services to be ready. Did you forget to call poll_ready()?"
162          );
163  
164          let idx = self.router.pick(&req, &self.services[..]);
165          let cl = &mut self.services[idx];
166          self.not_ready.push_back(idx);
167          cl.call(req)
168      }
169  }
170  
171  impl<S, F, Req> Clone for Steer<S, F, Req>
172  where
173      S: Clone,
174      F: Clone,
175  {
clone(&self) -> Self176      fn clone(&self) -> Self {
177          Self {
178              router: self.router.clone(),
179              services: self.services.clone(),
180              not_ready: self.not_ready.clone(),
181              _phantom: PhantomData,
182          }
183      }
184  }
185  
186  impl<S, F, Req> fmt::Debug for Steer<S, F, Req>
187  where
188      S: fmt::Debug,
189      F: fmt::Debug,
190  {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result191      fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
192          let Self {
193              router,
194              services,
195              not_ready,
196              _phantom,
197          } = self;
198          f.debug_struct("Steer")
199              .field("router", router)
200              .field("services", services)
201              .field("not_ready", not_ready)
202              .finish()
203      }
204  }
205