1 use super::{Extension, FromRequest, FromRequestParts};
2 use crate::{
3     body::{Body, Bytes, HttpBody},
4     BoxError, Error,
5 };
6 use async_trait::async_trait;
7 use futures_util::stream::Stream;
8 use http::{request::Parts, Request, Uri};
9 use std::{
10     convert::Infallible,
11     fmt,
12     pin::Pin,
13     task::{Context, Poll},
14 };
15 use sync_wrapper::SyncWrapper;
16 
/// Extractor that gets the original request URI regardless of nesting.
///
/// This is necessary since [`Uri`](http::Uri), when used as an extractor, will
/// have the prefix stripped if used in a nested service.
///
/// Extraction never fails: if no `OriginalUri` extension can be extracted
/// from the request, a clone of the request's current URI is used instead.
///
/// This type is only compiled when the `original-uri` feature is enabled.
///
/// # Example
///
/// ```
/// use axum::{
///     routing::get,
///     Router,
///     extract::OriginalUri,
///     http::Uri
/// };
///
/// let api_routes = Router::new()
///     .route(
///         "/users",
///         get(|uri: Uri, OriginalUri(original_uri): OriginalUri| async {
///             // `uri` is `/users`
///             // `original_uri` is `/api/users`
///         }),
///     );
///
/// let app = Router::new().nest("/api", api_routes);
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// # Extracting via request extensions
///
/// `OriginalUri` can also be accessed from middleware via request extensions.
/// This is useful for example with [`Trace`](tower_http::trace::Trace) to
/// create a span that contains the full path, if your service might be nested:
///
/// ```
/// use axum::{
///     Router,
///     extract::OriginalUri,
///     http::Request,
///     routing::get,
/// };
/// use tower_http::trace::TraceLayer;
///
/// let api_routes = Router::new()
///     .route("/users/:id", get(|| async { /* ... */ }))
///     .layer(
///         TraceLayer::new_for_http().make_span_with(|req: &Request<_>| {
///             let path = if let Some(path) = req.extensions().get::<OriginalUri>() {
///                 // This will include `/api`
///                 path.0.path().to_owned()
///             } else {
///                 // The `OriginalUri` extension will always be present if using
///                 // `Router` unless another extractor or middleware has removed it
///                 req.uri().path().to_owned()
///             };
///             tracing::info_span!("http-request", %path)
///         }),
///     );
///
/// let app = Router::new().nest("/api", api_routes);
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
#[cfg(feature = "original-uri")]
#[derive(Debug, Clone)]
pub struct OriginalUri(pub Uri);
86 
#[cfg(feature = "original-uri")]
#[async_trait]
impl<S> FromRequestParts<S> for OriginalUri
where
    S: Send + Sync,
{
    /// Extraction cannot fail; a rejected extension lookup is replaced by a fallback.
    type Rejection = Infallible;

    /// Returns the `OriginalUri` obtained through the [`Extension`] extractor or,
    /// when that extractor rejects, one built from a clone of `parts.uri`.
    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let extracted = Extension::<Self>::from_request_parts(parts, state).await;
        let original = match extracted {
            Ok(Extension(uri)) => uri,
            // No usable extension: fall back to the URI the request currently carries.
            Err(_) => OriginalUri(parts.uri.clone()),
        };
        Ok(original)
    }
}
103 
// NOTE(review): judging by its name, this macro generates `Deref`-style access
// from `OriginalUri` to the wrapped `Uri`; the macro lives in `axum_core` and
// is not visible here, so confirm against its definition.
#[cfg(feature = "original-uri")]
axum_core::__impl_deref!(OriginalUri: Uri);
106 
/// Extractor that extracts the request body as a [`Stream`].
///
/// Each item produced by the stream is a `Result<Bytes, Error>`: when the
/// extractor is built, body data is converted into [`Bytes`] and body errors
/// are wrapped in this crate's [`Error`] type.
///
/// Since extracting the request body requires consuming it, the `BodyStream` extractor must be
/// *last* if there are multiple extractors in a handler.
/// See ["the order of extractors"][order-of-extractors]
///
/// [order-of-extractors]: crate::extract#the-order-of-extractors
///
/// # Example
///
/// ```rust,no_run
/// use axum::{
///     extract::BodyStream,
///     routing::get,
///     Router,
/// };
/// use futures_util::StreamExt;
///
/// async fn handler(mut stream: BodyStream) {
///     while let Some(chunk) = stream.next().await {
///         // ...
///     }
/// }
///
/// let app = Router::new().route("/users", get(handler));
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// [`Stream`]: https://docs.rs/futures/latest/futures/stream/trait.Stream.html
/// [`body::Body`]: crate::body::Body
pub struct BodyStream(
    // Type-erased, pinned body. The trait object is only required to be `Send`;
    // it is wrapped in `SyncWrapper`, and the `body_stream_traits` test in this
    // file asserts that `BodyStream` is both `Send` and `Sync`.
    SyncWrapper<Pin<Box<dyn HttpBody<Data = Bytes, Error = Error> + Send + 'static>>>,
);
142 
143 impl Stream for BodyStream {
144     type Item = Result<Bytes, Error>;
145 
poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>146     fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
147         Pin::new(self.0.get_mut()).poll_data(cx)
148     }
149 }
150 
151 #[async_trait]
152 impl<S, B> FromRequest<S, B> for BodyStream
153 where
154     B: HttpBody + Send + 'static,
155     B::Data: Into<Bytes>,
156     B::Error: Into<BoxError>,
157     S: Send + Sync,
158 {
159     type Rejection = Infallible;
160 
from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection>161     async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
162         let body = req
163             .into_body()
164             .map_data(Into::into)
165             .map_err(|err| Error::new(err.into()));
166         let stream = BodyStream(SyncWrapper::new(Box::pin(body)));
167         Ok(stream)
168     }
169 }
170 
171 impl fmt::Debug for BodyStream {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result172     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
173         f.debug_tuple("BodyStream").finish()
174     }
175 }
176 
/// Compile-time check that `BodyStream` is both `Send` and `Sync`.
#[test]
fn body_stream_traits() {
    use crate::test_helpers::{assert_send, assert_sync};

    assert_send::<BodyStream>();
    assert_sync::<BodyStream>();
}
182 
/// Extractor that extracts the raw request body.
///
/// The body is handed over exactly as the request carried it. The type
/// parameter `B` defaults to [`Body`](crate::body::Body).
///
/// Since extracting the raw request body requires consuming it, the `RawBody` extractor must be
/// *last* if there are multiple extractors in a handler. See ["the order of extractors"][order-of-extractors]
///
/// [order-of-extractors]: crate::extract#the-order-of-extractors
///
/// # Example
///
/// ```rust,no_run
/// use axum::{
///     extract::RawBody,
///     routing::get,
///     Router,
/// };
///
/// async fn handler(RawBody(body): RawBody) {
///     // ...
/// }
///
/// let app = Router::new().route("/users", get(handler));
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// [`body::Body`]: crate::body::Body
#[derive(Debug, Default, Clone)]
pub struct RawBody<B = Body>(pub B);
213 
214 #[async_trait]
215 impl<S, B> FromRequest<S, B> for RawBody<B>
216 where
217     B: Send,
218     S: Send + Sync,
219 {
220     type Rejection = Infallible;
221 
from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection>222     async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
223         Ok(Self(req.into_body()))
224     }
225 }
226 
// NOTE(review): judging by its name, this macro generates `Deref`-style access
// from `RawBody` to its single field; the macro lives in `axum_core` and is
// not visible here, so confirm against its definition.
axum_core::__impl_deref!(RawBody);
228 
#[cfg(test)]
mod tests {
    use crate::{extract::Extension, routing::get, test_helpers::*, Router};
    use http::{Method, StatusCode};

    /// Extracting `http::request::Parts` gives the handler the method, URI,
    /// version, headers and extensions of the incoming request.
    #[crate::test]
    async fn extract_request_parts() {
        // Marker value added by the `Extension` layer and looked up in the handler.
        #[derive(Clone)]
        struct Ext;

        async fn handler(parts: http::request::Parts) {
            let http::request::Parts {
                method,
                uri,
                version,
                headers,
                extensions,
                ..
            } = parts;

            assert_eq!(method, Method::GET);
            assert_eq!(uri, "/");
            assert_eq!(version, http::Version::HTTP_11);
            assert_eq!(headers["x-foo"], "123");
            extensions.get::<Ext>().unwrap();
        }

        let app = Router::new().route("/", get(handler)).layer(Extension(Ext));
        let client = TestClient::new(app);

        let res = client.get("/").header("x-foo", "123").send().await;
        assert_eq!(res.status(), StatusCode::OK);
    }
}
253