1 //! Routing between [`Service`]s and handlers.
2
3 use self::{future::RouteFuture, not_found::NotFound, path_router::PathRouter};
4 #[cfg(feature = "tokio")]
5 use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
6 use crate::{
7 body::{Body, HttpBody},
8 boxed::BoxedIntoRoute,
9 handler::Handler,
10 util::try_downcast,
11 };
12 use axum_core::response::{IntoResponse, Response};
13 use http::Request;
14 use std::{
15 convert::Infallible,
16 fmt,
17 task::{Context, Poll},
18 };
19 use sync_wrapper::SyncWrapper;
20 use tower_layer::Layer;
21 use tower_service::Service;
22
23 pub mod future;
24 pub mod method_routing;
25
26 mod into_make_service;
27 mod method_filter;
28 mod not_found;
29 pub(crate) mod path_router;
30 mod route;
31 mod strip_prefix;
32 pub(crate) mod url_params;
33
34 #[cfg(test)]
35 mod tests;
36
37 pub use self::{into_make_service::IntoMakeService, method_filter::MethodFilter, route::Route};
38
39 pub use self::method_routing::{
40 any, any_service, delete, delete_service, get, get_service, head, head_service, on, on_service,
41 options, options_service, patch, patch_service, post, post_service, put, put_service, trace,
42 trace_service, MethodRouter,
43 };
44
// Unwrap a `Result`, panicking with the error's `Display` output when it is `Err`.
//
// This is a macro (not a helper fn or an `unwrap_or_else` closure) so that the `panic!`
// is expanded directly into the calling function: inside a `#[track_caller]` method the
// panic then reports the user's call site instead of a location inside this module.
macro_rules! panic_on_err {
    ($result:expr) => {
        match $result {
            Err(error) => panic!("{error}"),
            Ok(value) => value,
        }
    };
}
53
/// Numeric identifier for a route (presumably handed out by `PathRouter` when a route is
/// registered — confirm against `path_router`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub(crate) struct RouteId(u32);
56
57 /// The router type for composing handlers and services.
58 #[must_use]
59 pub struct Router<S = (), B = Body> {
60 path_router: PathRouter<S, B, false>,
61 fallback_router: PathRouter<S, B, true>,
62 default_fallback: bool,
63 catch_all_fallback: Fallback<S, B>,
64 }
65
66 impl<S, B> Clone for Router<S, B> {
clone(&self) -> Self67 fn clone(&self) -> Self {
68 Self {
69 path_router: self.path_router.clone(),
70 fallback_router: self.fallback_router.clone(),
71 default_fallback: self.default_fallback,
72 catch_all_fallback: self.catch_all_fallback.clone(),
73 }
74 }
75 }
76
77 impl<S, B> Default for Router<S, B>
78 where
79 B: HttpBody + Send + 'static,
80 S: Clone + Send + Sync + 'static,
81 {
default() -> Self82 fn default() -> Self {
83 Self::new()
84 }
85 }
86
87 impl<S, B> fmt::Debug for Router<S, B> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result88 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89 f.debug_struct("Router")
90 .field("path_router", &self.path_router)
91 .field("fallback_router", &self.fallback_router)
92 .field("default_fallback", &self.default_fallback)
93 .field("catch_all_fallback", &self.catch_all_fallback)
94 .finish()
95 }
96 }
97
/// Name of the wildcard parameter used for the remaining path of a nested router
/// (presumably consumed by the nesting logic in `path_router` — confirm there). The
/// `__private__` prefix makes a clash with user-chosen parameter names unlikely.
pub(crate) const NEST_TAIL_PARAM: &str = "__private__axum_nest_tail_param";
/// [`NEST_TAIL_PARAM`] written as a route-path catch-all segment (`/*` + name).
pub(crate) const NEST_TAIL_PARAM_CAPTURE: &str = "/*__private__axum_nest_tail_param";
/// Name of the wildcard parameter used by fallback routes (presumably registered by
/// `PathRouter::set_fallback` — confirm there).
pub(crate) const FALLBACK_PARAM: &str = "__private__axum_fallback";
/// [`FALLBACK_PARAM`] written as a route-path catch-all segment (`/*` + name).
pub(crate) const FALLBACK_PARAM_PATH: &str = "/*__private__axum_fallback";
102
103 impl<S, B> Router<S, B>
104 where
105 B: HttpBody + Send + 'static,
106 S: Clone + Send + Sync + 'static,
107 {
108 /// Create a new `Router`.
109 ///
110 /// Unless you add additional routes this will respond with `404 Not Found` to
111 /// all requests.
new() -> Self112 pub fn new() -> Self {
113 Self {
114 path_router: Default::default(),
115 fallback_router: PathRouter::new_fallback(),
116 default_fallback: true,
117 catch_all_fallback: Fallback::Default(Route::new(NotFound)),
118 }
119 }
120
121 #[doc = include_str!("../docs/routing/route.md")]
122 #[track_caller]
route(mut self, path: &str, method_router: MethodRouter<S, B>) -> Self123 pub fn route(mut self, path: &str, method_router: MethodRouter<S, B>) -> Self {
124 panic_on_err!(self.path_router.route(path, method_router));
125 self
126 }
127
128 #[doc = include_str!("../docs/routing/route_service.md")]
route_service<T>(mut self, path: &str, service: T) -> Self where T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static, T::Response: IntoResponse, T::Future: Send + 'static,129 pub fn route_service<T>(mut self, path: &str, service: T) -> Self
130 where
131 T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
132 T::Response: IntoResponse,
133 T::Future: Send + 'static,
134 {
135 let service = match try_downcast::<Router<S, B>, _>(service) {
136 Ok(_) => {
137 panic!(
138 "Invalid route: `Router::route_service` cannot be used with `Router`s. \
139 Use `Router::nest` instead"
140 );
141 }
142 Err(service) => service,
143 };
144
145 panic_on_err!(self.path_router.route_service(path, service));
146 self
147 }
148
149 #[doc = include_str!("../docs/routing/nest.md")]
150 #[track_caller]
nest(mut self, path: &str, router: Router<S, B>) -> Self151 pub fn nest(mut self, path: &str, router: Router<S, B>) -> Self {
152 let Router {
153 path_router,
154 fallback_router,
155 default_fallback,
156 // we don't need to inherit the catch-all fallback. It is only used for CONNECT
157 // requests with an empty path. If we were to inherit the catch-all fallback
158 // it would end up matching `/{path}/*` which doesn't match empty paths.
159 catch_all_fallback: _,
160 } = router;
161
162 panic_on_err!(self.path_router.nest(path, path_router));
163
164 if !default_fallback {
165 panic_on_err!(self.fallback_router.nest(path, fallback_router));
166 }
167
168 self
169 }
170
171 /// Like [`nest`](Self::nest), but accepts an arbitrary `Service`.
172 #[track_caller]
nest_service<T>(mut self, path: &str, service: T) -> Self where T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static, T::Response: IntoResponse, T::Future: Send + 'static,173 pub fn nest_service<T>(mut self, path: &str, service: T) -> Self
174 where
175 T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
176 T::Response: IntoResponse,
177 T::Future: Send + 'static,
178 {
179 panic_on_err!(self.path_router.nest_service(path, service));
180 self
181 }
182
183 #[doc = include_str!("../docs/routing/merge.md")]
184 #[track_caller]
merge<R>(mut self, other: R) -> Self where R: Into<Router<S, B>>,185 pub fn merge<R>(mut self, other: R) -> Self
186 where
187 R: Into<Router<S, B>>,
188 {
189 const PANIC_MSG: &str =
190 "Failed to merge fallbacks. This is a bug in axum. Please file an issue";
191
192 let Router {
193 path_router,
194 fallback_router: mut other_fallback,
195 default_fallback,
196 catch_all_fallback,
197 } = other.into();
198
199 panic_on_err!(self.path_router.merge(path_router));
200
201 match (self.default_fallback, default_fallback) {
202 // both have the default fallback
203 // use the one from other
204 (true, true) => {
205 self.fallback_router.merge(other_fallback).expect(PANIC_MSG);
206 }
207 // self has default fallback, other has a custom fallback
208 (true, false) => {
209 self.fallback_router.merge(other_fallback).expect(PANIC_MSG);
210 self.default_fallback = false;
211 }
212 // self has a custom fallback, other has a default
213 (false, true) => {
214 let fallback_router = std::mem::take(&mut self.fallback_router);
215 other_fallback.merge(fallback_router).expect(PANIC_MSG);
216 self.fallback_router = other_fallback;
217 }
218 // both have a custom fallback, not allowed
219 (false, false) => {
220 panic!("Cannot merge two `Router`s that both have a fallback")
221 }
222 };
223
224 self.catch_all_fallback = self
225 .catch_all_fallback
226 .merge(catch_all_fallback)
227 .unwrap_or_else(|| panic!("Cannot merge two `Router`s that both have a fallback"));
228
229 self
230 }
231
232 #[doc = include_str!("../docs/routing/layer.md")]
layer<L, NewReqBody>(self, layer: L) -> Router<S, NewReqBody> where L: Layer<Route<B>> + Clone + Send + 'static, L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static, <L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static, <L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static, <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static, NewReqBody: HttpBody + 'static,233 pub fn layer<L, NewReqBody>(self, layer: L) -> Router<S, NewReqBody>
234 where
235 L: Layer<Route<B>> + Clone + Send + 'static,
236 L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static,
237 <L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
238 <L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static,
239 <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
240 NewReqBody: HttpBody + 'static,
241 {
242 Router {
243 path_router: self.path_router.layer(layer.clone()),
244 fallback_router: self.fallback_router.layer(layer.clone()),
245 default_fallback: self.default_fallback,
246 catch_all_fallback: self.catch_all_fallback.map(|route| route.layer(layer)),
247 }
248 }
249
250 #[doc = include_str!("../docs/routing/route_layer.md")]
251 #[track_caller]
route_layer<L>(self, layer: L) -> Self where L: Layer<Route<B>> + Clone + Send + 'static, L::Service: Service<Request<B>> + Clone + Send + 'static, <L::Service as Service<Request<B>>>::Response: IntoResponse + 'static, <L::Service as Service<Request<B>>>::Error: Into<Infallible> + 'static, <L::Service as Service<Request<B>>>::Future: Send + 'static,252 pub fn route_layer<L>(self, layer: L) -> Self
253 where
254 L: Layer<Route<B>> + Clone + Send + 'static,
255 L::Service: Service<Request<B>> + Clone + Send + 'static,
256 <L::Service as Service<Request<B>>>::Response: IntoResponse + 'static,
257 <L::Service as Service<Request<B>>>::Error: Into<Infallible> + 'static,
258 <L::Service as Service<Request<B>>>::Future: Send + 'static,
259 {
260 Router {
261 path_router: self.path_router.route_layer(layer),
262 fallback_router: self.fallback_router,
263 default_fallback: self.default_fallback,
264 catch_all_fallback: self.catch_all_fallback,
265 }
266 }
267
268 #[track_caller]
269 #[doc = include_str!("../docs/routing/fallback.md")]
fallback<H, T>(mut self, handler: H) -> Self where H: Handler<T, S, B>, T: 'static,270 pub fn fallback<H, T>(mut self, handler: H) -> Self
271 where
272 H: Handler<T, S, B>,
273 T: 'static,
274 {
275 self.catch_all_fallback =
276 Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler.clone()));
277 self.fallback_endpoint(Endpoint::MethodRouter(any(handler)))
278 }
279
280 /// Add a fallback [`Service`] to the router.
281 ///
282 /// See [`Router::fallback`] for more details.
fallback_service<T>(mut self, service: T) -> Self where T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static, T::Response: IntoResponse, T::Future: Send + 'static,283 pub fn fallback_service<T>(mut self, service: T) -> Self
284 where
285 T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
286 T::Response: IntoResponse,
287 T::Future: Send + 'static,
288 {
289 let route = Route::new(service);
290 self.catch_all_fallback = Fallback::Service(route.clone());
291 self.fallback_endpoint(Endpoint::Route(route))
292 }
293
fallback_endpoint(mut self, endpoint: Endpoint<S, B>) -> Self294 fn fallback_endpoint(mut self, endpoint: Endpoint<S, B>) -> Self {
295 self.fallback_router.set_fallback(endpoint);
296 self.default_fallback = false;
297 self
298 }
299
300 #[doc = include_str!("../docs/routing/with_state.md")]
with_state<S2>(self, state: S) -> Router<S2, B>301 pub fn with_state<S2>(self, state: S) -> Router<S2, B> {
302 Router {
303 path_router: self.path_router.with_state(state.clone()),
304 fallback_router: self.fallback_router.with_state(state.clone()),
305 default_fallback: self.default_fallback,
306 catch_all_fallback: self.catch_all_fallback.with_state(state),
307 }
308 }
309
call_with_state( &mut self, mut req: Request<B>, state: S, ) -> RouteFuture<B, Infallible>310 pub(crate) fn call_with_state(
311 &mut self,
312 mut req: Request<B>,
313 state: S,
314 ) -> RouteFuture<B, Infallible> {
315 // required for opaque routers to still inherit the fallback
316 // TODO(david): remove this feature in 0.7
317 if !self.default_fallback {
318 req.extensions_mut().insert(SuperFallback(SyncWrapper::new(
319 self.fallback_router.clone(),
320 )));
321 }
322
323 match self.path_router.call_with_state(req, state) {
324 Ok(future) => future,
325 Err((mut req, state)) => {
326 let super_fallback = req
327 .extensions_mut()
328 .remove::<SuperFallback<S, B>>()
329 .map(|SuperFallback(path_router)| path_router.into_inner());
330
331 if let Some(mut super_fallback) = super_fallback {
332 match super_fallback.call_with_state(req, state) {
333 Ok(future) => return future,
334 Err((req, state)) => {
335 return self.catch_all_fallback.call_with_state(req, state);
336 }
337 }
338 }
339
340 match self.fallback_router.call_with_state(req, state) {
341 Ok(future) => future,
342 Err((req, state)) => self.catch_all_fallback.call_with_state(req, state),
343 }
344 }
345 }
346 }
347 }
348
349 impl<B> Router<(), B>
350 where
351 B: HttpBody + Send + 'static,
352 {
353 /// Convert this router into a [`MakeService`], that is a [`Service`] whose
354 /// response is another service.
355 ///
356 /// This is useful when running your application with hyper's
357 /// [`Server`](hyper::server::Server):
358 ///
359 /// ```
360 /// use axum::{
361 /// routing::get,
362 /// Router,
363 /// };
364 ///
365 /// let app = Router::new().route("/", get(|| async { "Hi!" }));
366 ///
367 /// # async {
368 /// axum::Server::bind(&"0.0.0.0:3000".parse().unwrap())
369 /// .serve(app.into_make_service())
370 /// .await
371 /// .expect("server failed");
372 /// # };
373 /// ```
374 ///
375 /// [`MakeService`]: tower::make::MakeService
into_make_service(self) -> IntoMakeService<Self>376 pub fn into_make_service(self) -> IntoMakeService<Self> {
377 // call `Router::with_state` such that everything is turned into `Route` eagerly
378 // rather than doing that per request
379 IntoMakeService::new(self.with_state(()))
380 }
381
382 #[doc = include_str!("../docs/routing/into_make_service_with_connect_info.md")]
383 #[cfg(feature = "tokio")]
into_make_service_with_connect_info<C>(self) -> IntoMakeServiceWithConnectInfo<Self, C>384 pub fn into_make_service_with_connect_info<C>(self) -> IntoMakeServiceWithConnectInfo<Self, C> {
385 // call `Router::with_state` such that everything is turned into `Route` eagerly
386 // rather than doing that per request
387 IntoMakeServiceWithConnectInfo::new(self.with_state(()))
388 }
389 }
390
391 impl<B> Service<Request<B>> for Router<(), B>
392 where
393 B: HttpBody + Send + 'static,
394 {
395 type Response = Response;
396 type Error = Infallible;
397 type Future = RouteFuture<B, Infallible>;
398
399 #[inline]
poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>>400 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
401 Poll::Ready(Ok(()))
402 }
403
404 #[inline]
call(&mut self, req: Request<B>) -> Self::Future405 fn call(&mut self, req: Request<B>) -> Self::Future {
406 self.call_with_state(req, ())
407 }
408 }
409
410 enum Fallback<S, B, E = Infallible> {
411 Default(Route<B, E>),
412 Service(Route<B, E>),
413 BoxedHandler(BoxedIntoRoute<S, B, E>),
414 }
415
416 impl<S, B, E> Fallback<S, B, E>
417 where
418 S: Clone,
419 {
merge(self, other: Self) -> Option<Self>420 fn merge(self, other: Self) -> Option<Self> {
421 match (self, other) {
422 (Self::Default(_), pick @ Self::Default(_)) => Some(pick),
423 (Self::Default(_), pick) | (pick, Self::Default(_)) => Some(pick),
424 _ => None,
425 }
426 }
427
map<F, B2, E2>(self, f: F) -> Fallback<S, B2, E2> where S: 'static, B: 'static, E: 'static, F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static, B2: HttpBody + 'static, E2: 'static,428 fn map<F, B2, E2>(self, f: F) -> Fallback<S, B2, E2>
429 where
430 S: 'static,
431 B: 'static,
432 E: 'static,
433 F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static,
434 B2: HttpBody + 'static,
435 E2: 'static,
436 {
437 match self {
438 Self::Default(route) => Fallback::Default(f(route)),
439 Self::Service(route) => Fallback::Service(f(route)),
440 Self::BoxedHandler(handler) => Fallback::BoxedHandler(handler.map(f)),
441 }
442 }
443
with_state<S2>(self, state: S) -> Fallback<S2, B, E>444 fn with_state<S2>(self, state: S) -> Fallback<S2, B, E> {
445 match self {
446 Fallback::Default(route) => Fallback::Default(route),
447 Fallback::Service(route) => Fallback::Service(route),
448 Fallback::BoxedHandler(handler) => Fallback::Service(handler.into_route(state)),
449 }
450 }
451
call_with_state(&mut self, req: Request<B>, state: S) -> RouteFuture<B, E>452 fn call_with_state(&mut self, req: Request<B>, state: S) -> RouteFuture<B, E> {
453 match self {
454 Fallback::Default(route) | Fallback::Service(route) => {
455 RouteFuture::from_future(route.oneshot_inner(req))
456 }
457 Fallback::BoxedHandler(handler) => {
458 let mut route = handler.clone().into_route(state);
459 RouteFuture::from_future(route.oneshot_inner(req))
460 }
461 }
462 }
463 }
464
465 impl<S, B, E> Clone for Fallback<S, B, E> {
clone(&self) -> Self466 fn clone(&self) -> Self {
467 match self {
468 Self::Default(inner) => Self::Default(inner.clone()),
469 Self::Service(inner) => Self::Service(inner.clone()),
470 Self::BoxedHandler(inner) => Self::BoxedHandler(inner.clone()),
471 }
472 }
473 }
474
475 impl<S, B, E> fmt::Debug for Fallback<S, B, E> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result476 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
477 match self {
478 Self::Default(inner) => f.debug_tuple("Default").field(inner).finish(),
479 Self::Service(inner) => f.debug_tuple("Service").field(inner).finish(),
480 Self::BoxedHandler(_) => f.debug_tuple("BoxedHandler").finish(),
481 }
482 }
483 }
484
485 #[allow(clippy::large_enum_variant)]
486 enum Endpoint<S, B> {
487 MethodRouter(MethodRouter<S, B>),
488 Route(Route<B>),
489 }
490
491 impl<S, B> Endpoint<S, B>
492 where
493 B: HttpBody + Send + 'static,
494 S: Clone + Send + Sync + 'static,
495 {
layer<L, NewReqBody>(self, layer: L) -> Endpoint<S, NewReqBody> where L: Layer<Route<B>> + Clone + Send + 'static, L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static, <L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static, <L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static, <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static, NewReqBody: HttpBody + 'static,496 fn layer<L, NewReqBody>(self, layer: L) -> Endpoint<S, NewReqBody>
497 where
498 L: Layer<Route<B>> + Clone + Send + 'static,
499 L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static,
500 <L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
501 <L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static,
502 <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
503 NewReqBody: HttpBody + 'static,
504 {
505 match self {
506 Endpoint::MethodRouter(method_router) => {
507 Endpoint::MethodRouter(method_router.layer(layer))
508 }
509 Endpoint::Route(route) => Endpoint::Route(route.layer(layer)),
510 }
511 }
512 }
513
514 impl<S, B> Clone for Endpoint<S, B> {
clone(&self) -> Self515 fn clone(&self) -> Self {
516 match self {
517 Self::MethodRouter(inner) => Self::MethodRouter(inner.clone()),
518 Self::Route(inner) => Self::Route(inner.clone()),
519 }
520 }
521 }
522
523 impl<S, B> fmt::Debug for Endpoint<S, B> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result524 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
525 match self {
526 Self::MethodRouter(method_router) => {
527 f.debug_tuple("MethodRouter").field(method_router).finish()
528 }
529 Self::Route(route) => f.debug_tuple("Route").field(route).finish(),
530 }
531 }
532 }
533
534 struct SuperFallback<S, B>(SyncWrapper<PathRouter<S, B, true>>);
535
// Compile-time guard that `Router<(), ()>` satisfies `assert_send`'s bound (presumably
// `T: Send`); the test does no work at runtime.
#[test]
#[allow(warnings)]
fn traits() {
    crate::test_helpers::assert_send::<Router<(), ()>>();
}
542