1 //! Routing between [`Service`]s and handlers.
2 
3 use self::{future::RouteFuture, not_found::NotFound, path_router::PathRouter};
4 #[cfg(feature = "tokio")]
5 use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
6 use crate::{
7     body::{Body, HttpBody},
8     boxed::BoxedIntoRoute,
9     handler::Handler,
10     util::try_downcast,
11 };
12 use axum_core::response::{IntoResponse, Response};
13 use http::Request;
14 use std::{
15     convert::Infallible,
16     fmt,
17     task::{Context, Poll},
18 };
19 use sync_wrapper::SyncWrapper;
20 use tower_layer::Layer;
21 use tower_service::Service;
22 
23 pub mod future;
24 pub mod method_routing;
25 
26 mod into_make_service;
27 mod method_filter;
28 mod not_found;
29 pub(crate) mod path_router;
30 mod route;
31 mod strip_prefix;
32 pub(crate) mod url_params;
33 
34 #[cfg(test)]
35 mod tests;
36 
37 pub use self::{into_make_service::IntoMakeService, method_filter::MethodFilter, route::Route};
38 
39 pub use self::method_routing::{
40     any, any_service, delete, delete_service, get, get_service, head, head_service, on, on_service,
41     options, options_service, patch, patch_service, post, post_service, put, put_service, trace,
42     trace_service, MethodRouter,
43 };
44 
/// Unwraps a `Result`, panicking with the error's `Display` output on `Err`.
///
/// Kept as a macro with a plain `match` (rather than a helper function or
/// `unwrap_or_else` with a closure) so the `panic!` expands directly inside the
/// calling method; when that method is `#[track_caller]`, the reported panic
/// location is then the caller's call site.
macro_rules! panic_on_err {
    ($result:expr) => {
        match $result {
            Err(error) => panic!("{}", error),
            Ok(value) => value,
        }
    };
}
53 
/// Opaque numeric identifier for a route (a thin `u32` newtype; presumably
/// allocated by `PathRouter` — confirm there).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(crate) struct RouteId(u32);
56 
/// The router type for composing handlers and services.
#[must_use]
pub struct Router<S = (), B = Body> {
    /// Routes registered through `route`, `route_service`, `nest`,
    /// `nest_service` and `merge`; consulted first for every request.
    path_router: PathRouter<S, B, false>,
    /// Routes tried when `path_router` has no match; the fallback installed by
    /// `fallback`/`fallback_service` lives here.
    fallback_router: PathRouter<S, B, true>,
    /// `true` until a custom fallback is installed on this router or merged
    /// into it from another router.
    default_fallback: bool,
    /// Last-resort responder used when neither `path_router` nor
    /// `fallback_router` matches. Starts out as `NotFound`.
    catch_all_fallback: Fallback<S, B>,
}
65 
66 impl<S, B> Clone for Router<S, B> {
clone(&self) -> Self67     fn clone(&self) -> Self {
68         Self {
69             path_router: self.path_router.clone(),
70             fallback_router: self.fallback_router.clone(),
71             default_fallback: self.default_fallback,
72             catch_all_fallback: self.catch_all_fallback.clone(),
73         }
74     }
75 }
76 
77 impl<S, B> Default for Router<S, B>
78 where
79     B: HttpBody + Send + 'static,
80     S: Clone + Send + Sync + 'static,
81 {
default() -> Self82     fn default() -> Self {
83         Self::new()
84     }
85 }
86 
87 impl<S, B> fmt::Debug for Router<S, B> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result88     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89         f.debug_struct("Router")
90             .field("path_router", &self.path_router)
91             .field("fallback_router", &self.fallback_router)
92             .field("default_fallback", &self.default_fallback)
93             .field("catch_all_fallback", &self.catch_all_fallback)
94             .finish()
95     }
96 }
97 
/// Reserved wildcard parameter name, presumably used by nesting to capture the
/// remainder of the path (name-based inference — confirm in `path_router`).
pub(crate) const NEST_TAIL_PARAM: &str = "__private__axum_nest_tail_param";
/// `NEST_TAIL_PARAM` written as a `/*`-prefixed wildcard path segment.
pub(crate) const NEST_TAIL_PARAM_CAPTURE: &str = "/*__private__axum_nest_tail_param";
/// Reserved wildcard parameter name, presumably used by fallback routes
/// (name-based inference — confirm in `path_router`).
pub(crate) const FALLBACK_PARAM: &str = "__private__axum_fallback";
/// `FALLBACK_PARAM` written as a `/*`-prefixed wildcard path segment.
pub(crate) const FALLBACK_PARAM_PATH: &str = "/*__private__axum_fallback";
102 
103 impl<S, B> Router<S, B>
104 where
105     B: HttpBody + Send + 'static,
106     S: Clone + Send + Sync + 'static,
107 {
108     /// Create a new `Router`.
109     ///
110     /// Unless you add additional routes this will respond with `404 Not Found` to
111     /// all requests.
new() -> Self112     pub fn new() -> Self {
113         Self {
114             path_router: Default::default(),
115             fallback_router: PathRouter::new_fallback(),
116             default_fallback: true,
117             catch_all_fallback: Fallback::Default(Route::new(NotFound)),
118         }
119     }
120 
121     #[doc = include_str!("../docs/routing/route.md")]
122     #[track_caller]
route(mut self, path: &str, method_router: MethodRouter<S, B>) -> Self123     pub fn route(mut self, path: &str, method_router: MethodRouter<S, B>) -> Self {
124         panic_on_err!(self.path_router.route(path, method_router));
125         self
126     }
127 
128     #[doc = include_str!("../docs/routing/route_service.md")]
route_service<T>(mut self, path: &str, service: T) -> Self where T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static, T::Response: IntoResponse, T::Future: Send + 'static,129     pub fn route_service<T>(mut self, path: &str, service: T) -> Self
130     where
131         T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
132         T::Response: IntoResponse,
133         T::Future: Send + 'static,
134     {
135         let service = match try_downcast::<Router<S, B>, _>(service) {
136             Ok(_) => {
137                 panic!(
138                     "Invalid route: `Router::route_service` cannot be used with `Router`s. \
139                      Use `Router::nest` instead"
140                 );
141             }
142             Err(service) => service,
143         };
144 
145         panic_on_err!(self.path_router.route_service(path, service));
146         self
147     }
148 
149     #[doc = include_str!("../docs/routing/nest.md")]
150     #[track_caller]
nest(mut self, path: &str, router: Router<S, B>) -> Self151     pub fn nest(mut self, path: &str, router: Router<S, B>) -> Self {
152         let Router {
153             path_router,
154             fallback_router,
155             default_fallback,
156             // we don't need to inherit the catch-all fallback. It is only used for CONNECT
157             // requests with an empty path. If we were to inherit the catch-all fallback
158             // it would end up matching `/{path}/*` which doesn't match empty paths.
159             catch_all_fallback: _,
160         } = router;
161 
162         panic_on_err!(self.path_router.nest(path, path_router));
163 
164         if !default_fallback {
165             panic_on_err!(self.fallback_router.nest(path, fallback_router));
166         }
167 
168         self
169     }
170 
171     /// Like [`nest`](Self::nest), but accepts an arbitrary `Service`.
172     #[track_caller]
nest_service<T>(mut self, path: &str, service: T) -> Self where T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static, T::Response: IntoResponse, T::Future: Send + 'static,173     pub fn nest_service<T>(mut self, path: &str, service: T) -> Self
174     where
175         T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
176         T::Response: IntoResponse,
177         T::Future: Send + 'static,
178     {
179         panic_on_err!(self.path_router.nest_service(path, service));
180         self
181     }
182 
183     #[doc = include_str!("../docs/routing/merge.md")]
184     #[track_caller]
merge<R>(mut self, other: R) -> Self where R: Into<Router<S, B>>,185     pub fn merge<R>(mut self, other: R) -> Self
186     where
187         R: Into<Router<S, B>>,
188     {
189         const PANIC_MSG: &str =
190             "Failed to merge fallbacks. This is a bug in axum. Please file an issue";
191 
192         let Router {
193             path_router,
194             fallback_router: mut other_fallback,
195             default_fallback,
196             catch_all_fallback,
197         } = other.into();
198 
199         panic_on_err!(self.path_router.merge(path_router));
200 
201         match (self.default_fallback, default_fallback) {
202             // both have the default fallback
203             // use the one from other
204             (true, true) => {
205                 self.fallback_router.merge(other_fallback).expect(PANIC_MSG);
206             }
207             // self has default fallback, other has a custom fallback
208             (true, false) => {
209                 self.fallback_router.merge(other_fallback).expect(PANIC_MSG);
210                 self.default_fallback = false;
211             }
212             // self has a custom fallback, other has a default
213             (false, true) => {
214                 let fallback_router = std::mem::take(&mut self.fallback_router);
215                 other_fallback.merge(fallback_router).expect(PANIC_MSG);
216                 self.fallback_router = other_fallback;
217             }
218             // both have a custom fallback, not allowed
219             (false, false) => {
220                 panic!("Cannot merge two `Router`s that both have a fallback")
221             }
222         };
223 
224         self.catch_all_fallback = self
225             .catch_all_fallback
226             .merge(catch_all_fallback)
227             .unwrap_or_else(|| panic!("Cannot merge two `Router`s that both have a fallback"));
228 
229         self
230     }
231 
232     #[doc = include_str!("../docs/routing/layer.md")]
layer<L, NewReqBody>(self, layer: L) -> Router<S, NewReqBody> where L: Layer<Route<B>> + Clone + Send + 'static, L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static, <L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static, <L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static, <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static, NewReqBody: HttpBody + 'static,233     pub fn layer<L, NewReqBody>(self, layer: L) -> Router<S, NewReqBody>
234     where
235         L: Layer<Route<B>> + Clone + Send + 'static,
236         L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static,
237         <L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
238         <L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static,
239         <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
240         NewReqBody: HttpBody + 'static,
241     {
242         Router {
243             path_router: self.path_router.layer(layer.clone()),
244             fallback_router: self.fallback_router.layer(layer.clone()),
245             default_fallback: self.default_fallback,
246             catch_all_fallback: self.catch_all_fallback.map(|route| route.layer(layer)),
247         }
248     }
249 
250     #[doc = include_str!("../docs/routing/route_layer.md")]
251     #[track_caller]
route_layer<L>(self, layer: L) -> Self where L: Layer<Route<B>> + Clone + Send + 'static, L::Service: Service<Request<B>> + Clone + Send + 'static, <L::Service as Service<Request<B>>>::Response: IntoResponse + 'static, <L::Service as Service<Request<B>>>::Error: Into<Infallible> + 'static, <L::Service as Service<Request<B>>>::Future: Send + 'static,252     pub fn route_layer<L>(self, layer: L) -> Self
253     where
254         L: Layer<Route<B>> + Clone + Send + 'static,
255         L::Service: Service<Request<B>> + Clone + Send + 'static,
256         <L::Service as Service<Request<B>>>::Response: IntoResponse + 'static,
257         <L::Service as Service<Request<B>>>::Error: Into<Infallible> + 'static,
258         <L::Service as Service<Request<B>>>::Future: Send + 'static,
259     {
260         Router {
261             path_router: self.path_router.route_layer(layer),
262             fallback_router: self.fallback_router,
263             default_fallback: self.default_fallback,
264             catch_all_fallback: self.catch_all_fallback,
265         }
266     }
267 
268     #[track_caller]
269     #[doc = include_str!("../docs/routing/fallback.md")]
fallback<H, T>(mut self, handler: H) -> Self where H: Handler<T, S, B>, T: 'static,270     pub fn fallback<H, T>(mut self, handler: H) -> Self
271     where
272         H: Handler<T, S, B>,
273         T: 'static,
274     {
275         self.catch_all_fallback =
276             Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler.clone()));
277         self.fallback_endpoint(Endpoint::MethodRouter(any(handler)))
278     }
279 
280     /// Add a fallback [`Service`] to the router.
281     ///
282     /// See [`Router::fallback`] for more details.
fallback_service<T>(mut self, service: T) -> Self where T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static, T::Response: IntoResponse, T::Future: Send + 'static,283     pub fn fallback_service<T>(mut self, service: T) -> Self
284     where
285         T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
286         T::Response: IntoResponse,
287         T::Future: Send + 'static,
288     {
289         let route = Route::new(service);
290         self.catch_all_fallback = Fallback::Service(route.clone());
291         self.fallback_endpoint(Endpoint::Route(route))
292     }
293 
fallback_endpoint(mut self, endpoint: Endpoint<S, B>) -> Self294     fn fallback_endpoint(mut self, endpoint: Endpoint<S, B>) -> Self {
295         self.fallback_router.set_fallback(endpoint);
296         self.default_fallback = false;
297         self
298     }
299 
300     #[doc = include_str!("../docs/routing/with_state.md")]
with_state<S2>(self, state: S) -> Router<S2, B>301     pub fn with_state<S2>(self, state: S) -> Router<S2, B> {
302         Router {
303             path_router: self.path_router.with_state(state.clone()),
304             fallback_router: self.fallback_router.with_state(state.clone()),
305             default_fallback: self.default_fallback,
306             catch_all_fallback: self.catch_all_fallback.with_state(state),
307         }
308     }
309 
call_with_state( &mut self, mut req: Request<B>, state: S, ) -> RouteFuture<B, Infallible>310     pub(crate) fn call_with_state(
311         &mut self,
312         mut req: Request<B>,
313         state: S,
314     ) -> RouteFuture<B, Infallible> {
315         // required for opaque routers to still inherit the fallback
316         // TODO(david): remove this feature in 0.7
317         if !self.default_fallback {
318             req.extensions_mut().insert(SuperFallback(SyncWrapper::new(
319                 self.fallback_router.clone(),
320             )));
321         }
322 
323         match self.path_router.call_with_state(req, state) {
324             Ok(future) => future,
325             Err((mut req, state)) => {
326                 let super_fallback = req
327                     .extensions_mut()
328                     .remove::<SuperFallback<S, B>>()
329                     .map(|SuperFallback(path_router)| path_router.into_inner());
330 
331                 if let Some(mut super_fallback) = super_fallback {
332                     match super_fallback.call_with_state(req, state) {
333                         Ok(future) => return future,
334                         Err((req, state)) => {
335                             return self.catch_all_fallback.call_with_state(req, state);
336                         }
337                     }
338                 }
339 
340                 match self.fallback_router.call_with_state(req, state) {
341                     Ok(future) => future,
342                     Err((req, state)) => self.catch_all_fallback.call_with_state(req, state),
343                 }
344             }
345         }
346     }
347 }
348 
349 impl<B> Router<(), B>
350 where
351     B: HttpBody + Send + 'static,
352 {
353     /// Convert this router into a [`MakeService`], that is a [`Service`] whose
354     /// response is another service.
355     ///
356     /// This is useful when running your application with hyper's
357     /// [`Server`](hyper::server::Server):
358     ///
359     /// ```
360     /// use axum::{
361     ///     routing::get,
362     ///     Router,
363     /// };
364     ///
365     /// let app = Router::new().route("/", get(|| async { "Hi!" }));
366     ///
367     /// # async {
368     /// axum::Server::bind(&"0.0.0.0:3000".parse().unwrap())
369     ///     .serve(app.into_make_service())
370     ///     .await
371     ///     .expect("server failed");
372     /// # };
373     /// ```
374     ///
375     /// [`MakeService`]: tower::make::MakeService
into_make_service(self) -> IntoMakeService<Self>376     pub fn into_make_service(self) -> IntoMakeService<Self> {
377         // call `Router::with_state` such that everything is turned into `Route` eagerly
378         // rather than doing that per request
379         IntoMakeService::new(self.with_state(()))
380     }
381 
382     #[doc = include_str!("../docs/routing/into_make_service_with_connect_info.md")]
383     #[cfg(feature = "tokio")]
into_make_service_with_connect_info<C>(self) -> IntoMakeServiceWithConnectInfo<Self, C>384     pub fn into_make_service_with_connect_info<C>(self) -> IntoMakeServiceWithConnectInfo<Self, C> {
385         // call `Router::with_state` such that everything is turned into `Route` eagerly
386         // rather than doing that per request
387         IntoMakeServiceWithConnectInfo::new(self.with_state(()))
388     }
389 }
390 
391 impl<B> Service<Request<B>> for Router<(), B>
392 where
393     B: HttpBody + Send + 'static,
394 {
395     type Response = Response;
396     type Error = Infallible;
397     type Future = RouteFuture<B, Infallible>;
398 
399     #[inline]
poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>>400     fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
401         Poll::Ready(Ok(()))
402     }
403 
404     #[inline]
call(&mut self, req: Request<B>) -> Self::Future405     fn call(&mut self, req: Request<B>) -> Self::Future {
406         self.call_with_state(req, ())
407     }
408 }
409 
/// The router's last-resort responder, used when neither the path routes nor
/// the fallback routes match a request.
enum Fallback<S, B, E = Infallible> {
    /// The built-in fallback that `Router::new` installs (`Route::new(NotFound)`).
    Default(Route<B, E>),
    /// A ready-to-call route: a service from `Router::fallback_service`, or a
    /// handler whose state has already been supplied by `with_state`.
    Service(Route<B, E>),
    /// A handler from `Router::fallback` that still needs state before it can
    /// become a route.
    BoxedHandler(BoxedIntoRoute<S, B, E>),
}
415 
416 impl<S, B, E> Fallback<S, B, E>
417 where
418     S: Clone,
419 {
merge(self, other: Self) -> Option<Self>420     fn merge(self, other: Self) -> Option<Self> {
421         match (self, other) {
422             (Self::Default(_), pick @ Self::Default(_)) => Some(pick),
423             (Self::Default(_), pick) | (pick, Self::Default(_)) => Some(pick),
424             _ => None,
425         }
426     }
427 
map<F, B2, E2>(self, f: F) -> Fallback<S, B2, E2> where S: 'static, B: 'static, E: 'static, F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static, B2: HttpBody + 'static, E2: 'static,428     fn map<F, B2, E2>(self, f: F) -> Fallback<S, B2, E2>
429     where
430         S: 'static,
431         B: 'static,
432         E: 'static,
433         F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static,
434         B2: HttpBody + 'static,
435         E2: 'static,
436     {
437         match self {
438             Self::Default(route) => Fallback::Default(f(route)),
439             Self::Service(route) => Fallback::Service(f(route)),
440             Self::BoxedHandler(handler) => Fallback::BoxedHandler(handler.map(f)),
441         }
442     }
443 
with_state<S2>(self, state: S) -> Fallback<S2, B, E>444     fn with_state<S2>(self, state: S) -> Fallback<S2, B, E> {
445         match self {
446             Fallback::Default(route) => Fallback::Default(route),
447             Fallback::Service(route) => Fallback::Service(route),
448             Fallback::BoxedHandler(handler) => Fallback::Service(handler.into_route(state)),
449         }
450     }
451 
call_with_state(&mut self, req: Request<B>, state: S) -> RouteFuture<B, E>452     fn call_with_state(&mut self, req: Request<B>, state: S) -> RouteFuture<B, E> {
453         match self {
454             Fallback::Default(route) | Fallback::Service(route) => {
455                 RouteFuture::from_future(route.oneshot_inner(req))
456             }
457             Fallback::BoxedHandler(handler) => {
458                 let mut route = handler.clone().into_route(state);
459                 RouteFuture::from_future(route.oneshot_inner(req))
460             }
461         }
462     }
463 }
464 
465 impl<S, B, E> Clone for Fallback<S, B, E> {
clone(&self) -> Self466     fn clone(&self) -> Self {
467         match self {
468             Self::Default(inner) => Self::Default(inner.clone()),
469             Self::Service(inner) => Self::Service(inner.clone()),
470             Self::BoxedHandler(inner) => Self::BoxedHandler(inner.clone()),
471         }
472     }
473 }
474 
475 impl<S, B, E> fmt::Debug for Fallback<S, B, E> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result476     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
477         match self {
478             Self::Default(inner) => f.debug_tuple("Default").field(inner).finish(),
479             Self::Service(inner) => f.debug_tuple("Service").field(inner).finish(),
480             Self::BoxedHandler(_) => f.debug_tuple("BoxedHandler").finish(),
481         }
482     }
483 }
484 
/// What a router fallback resolves to: a `MethodRouter` (built by
/// `Router::fallback`) or an already-built `Route` (from `Router::fallback_service`).
// Clippy flags the size difference between the two variants; it is allowed
// deliberately (presumably to avoid boxing the `MethodRouter` — confirm).
#[allow(clippy::large_enum_variant)]
enum Endpoint<S, B> {
    MethodRouter(MethodRouter<S, B>),
    Route(Route<B>),
}
490 
491 impl<S, B> Endpoint<S, B>
492 where
493     B: HttpBody + Send + 'static,
494     S: Clone + Send + Sync + 'static,
495 {
layer<L, NewReqBody>(self, layer: L) -> Endpoint<S, NewReqBody> where L: Layer<Route<B>> + Clone + Send + 'static, L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static, <L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static, <L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static, <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static, NewReqBody: HttpBody + 'static,496     fn layer<L, NewReqBody>(self, layer: L) -> Endpoint<S, NewReqBody>
497     where
498         L: Layer<Route<B>> + Clone + Send + 'static,
499         L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static,
500         <L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
501         <L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static,
502         <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
503         NewReqBody: HttpBody + 'static,
504     {
505         match self {
506             Endpoint::MethodRouter(method_router) => {
507                 Endpoint::MethodRouter(method_router.layer(layer))
508             }
509             Endpoint::Route(route) => Endpoint::Route(route.layer(layer)),
510         }
511     }
512 }
513 
514 impl<S, B> Clone for Endpoint<S, B> {
clone(&self) -> Self515     fn clone(&self) -> Self {
516         match self {
517             Self::MethodRouter(inner) => Self::MethodRouter(inner.clone()),
518             Self::Route(inner) => Self::Route(inner.clone()),
519         }
520     }
521 }
522 
523 impl<S, B> fmt::Debug for Endpoint<S, B> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result524     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
525         match self {
526             Self::MethodRouter(method_router) => {
527                 f.debug_tuple("MethodRouter").field(method_router).finish()
528             }
529             Self::Route(route) => f.debug_tuple("Route").field(route).finish(),
530         }
531     }
532 }
533 
/// Request extension that `Router::call_with_state` inserts, carrying a clone
/// of a router's `fallback_router`. A router that fails to match a request
/// consults it before its own fallback routes.
// NOTE(review): `SyncWrapper` is presumably there to satisfy the `Sync` bound on
// request extensions — confirm.
struct SuperFallback<S, B>(SyncWrapper<PathRouter<S, B, true>>);
535 
/// Checks that `Router<(), ()>` satisfies the `Send` bound required by
/// `assert_send` (presumably a compile-time-only check).
#[test]
#[allow(warnings)]
fn traits() {
    crate::test_helpers::assert_send::<Router<(), ()>>();
}
542