1 use super::{Extension, FromRequest, FromRequestParts};
2 use crate::{
3 body::{Body, Bytes, HttpBody},
4 BoxError, Error,
5 };
6 use async_trait::async_trait;
7 use futures_util::stream::Stream;
8 use http::{request::Parts, Request, Uri};
9 use std::{
10 convert::Infallible,
11 fmt,
12 pin::Pin,
13 task::{Context, Poll},
14 };
15 use sync_wrapper::SyncWrapper;
16
/// Extractor that gets the original request URI regardless of nesting.
///
/// This is necessary because [`Uri`](http::Uri), when used as an extractor, will
/// have the nesting prefix stripped if the handler is part of a nested service.
///
/// If no `OriginalUri` request extension is present, this extractor falls back to
/// the request's current URI, so extracting it never fails.
///
/// # Example
///
/// ```
/// use axum::{
///     routing::get,
///     Router,
///     extract::OriginalUri,
///     http::Uri
/// };
///
/// let api_routes = Router::new()
///     .route(
///         "/users",
///         get(|uri: Uri, OriginalUri(original_uri): OriginalUri| async {
///             // `uri` is `/users`
///             // `original_uri` is `/api/users`
///         }),
///     );
///
/// let app = Router::new().nest("/api", api_routes);
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// # Extracting via request extensions
///
/// `OriginalUri` can also be accessed from middleware via request extensions.
/// This is useful, for example, with [`Trace`](tower_http::trace::Trace) to
/// create a span that contains the full path if your service might be nested:
///
/// ```
/// use axum::{
///     Router,
///     extract::OriginalUri,
///     http::Request,
///     routing::get,
/// };
/// use tower_http::trace::TraceLayer;
///
/// let api_routes = Router::new()
///     .route("/users/:id", get(|| async { /* ... */ }))
///     .layer(
///         TraceLayer::new_for_http().make_span_with(|req: &Request<_>| {
///             let path = if let Some(path) = req.extensions().get::<OriginalUri>() {
///                 // This will include `/api`
///                 path.0.path().to_owned()
///             } else {
///                 // The `OriginalUri` extension will always be present if using
///                 // `Router` unless another extractor or middleware has removed it
///                 req.uri().path().to_owned()
///             };
///             tracing::info_span!("http-request", %path)
///         }),
///     );
///
/// let app = Router::new().nest("/api", api_routes);
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
#[cfg(feature = "original-uri")]
#[derive(Debug, Clone)]
pub struct OriginalUri(pub Uri);
86
#[cfg(feature = "original-uri")]
#[async_trait]
impl<S> FromRequestParts<S> for OriginalUri
where
    S: Send + Sync,
{
    type Rejection = Infallible;

    /// Returns the `OriginalUri` stored in the request extensions, or the request's
    /// current URI when no such extension has been inserted. Never rejects.
    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let original = match Extension::<Self>::from_request_parts(parts, state).await {
            Ok(Extension(stored)) => stored,
            // Without the extension, the current URI is the best available answer
            // (it may already have had a nesting prefix stripped).
            Err(_) => Self(parts.uri.clone()),
        };
        Ok(original)
    }
}
103
// NOTE(review): `__impl_deref!` is an axum-core internal helper macro; given the
// `OriginalUri: Uri` argument it presumably implements `Deref`/`DerefMut` to the
// wrapped `Uri` — confirm against axum-core.
#[cfg(feature = "original-uri")]
axum_core::__impl_deref!(OriginalUri: Uri);
106
/// Extractor that extracts the request body as a [`Stream`].
///
/// Since extracting the request body requires consuming it, the `BodyStream` extractor must be
/// *last* if there are multiple extractors in a handler.
/// See ["the order of extractors"][order-of-extractors].
///
/// The body is type-erased, so `BodyStream` works with any request body type (not only
/// [`body::Body`]) whose data can be converted into `Bytes` and whose errors can be converted
/// into `BoxError`. Chunks are yielded as `Bytes` and errors as `axum::Error`.
///
/// [order-of-extractors]: crate::extract#the-order-of-extractors
///
/// # Example
///
/// ```rust,no_run
/// use axum::{
///     extract::BodyStream,
///     routing::get,
///     Router,
/// };
/// use futures_util::StreamExt;
///
/// async fn handler(mut stream: BodyStream) {
///     while let Some(chunk) = stream.next().await {
///         // ...
///     }
/// }
///
/// let app = Router::new().route("/users", get(handler));
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// [`Stream`]: https://docs.rs/futures/latest/futures/stream/trait.Stream.html
/// [`body::Body`]: crate::body::Body
pub struct BodyStream(
    // The type-erased request body. A `dyn HttpBody + Send` trait object is not `Sync`
    // on its own; wrapping it in `SyncWrapper` is what lets `BodyStream` be `Sync`
    // (asserted by the `body_stream_traits` test).
    SyncWrapper<Pin<Box<dyn HttpBody<Data = Bytes, Error = Error> + Send + 'static>>>,
);
142
143 impl Stream for BodyStream {
144 type Item = Result<Bytes, Error>;
145
poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>146 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
147 Pin::new(self.0.get_mut()).poll_data(cx)
148 }
149 }
150
151 #[async_trait]
152 impl<S, B> FromRequest<S, B> for BodyStream
153 where
154 B: HttpBody + Send + 'static,
155 B::Data: Into<Bytes>,
156 B::Error: Into<BoxError>,
157 S: Send + Sync,
158 {
159 type Rejection = Infallible;
160
from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection>161 async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
162 let body = req
163 .into_body()
164 .map_data(Into::into)
165 .map_err(|err| Error::new(err.into()));
166 let stream = BodyStream(SyncWrapper::new(Box::pin(body)));
167 Ok(stream)
168 }
169 }
170
171 impl fmt::Debug for BodyStream {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result172 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
173 f.debug_tuple("BodyStream").finish()
174 }
175 }
176
#[test]
fn body_stream_traits() {
    use crate::test_helpers::{assert_send, assert_sync};

    // Compile-time checks: `BodyStream` must be both `Sync` and `Send`.
    assert_sync::<BodyStream>();
    assert_send::<BodyStream>();
}
182
/// Extractor that extracts the raw request body.
///
/// Since extracting the raw request body requires consuming it, the `RawBody` extractor must be
/// *last* if there are multiple extractors in a handler.
/// See ["the order of extractors"][order-of-extractors].
///
/// The body is handed over unchanged as the request's body type `B`, which defaults to
/// [`body::Body`].
///
/// [order-of-extractors]: crate::extract#the-order-of-extractors
///
/// # Example
///
/// ```rust,no_run
/// use axum::{
///     extract::RawBody,
///     routing::get,
///     Router,
/// };
///
/// async fn handler(RawBody(body): RawBody) {
///     // ...
/// }
///
/// let app = Router::new().route("/users", get(handler));
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// [`body::Body`]: crate::body::Body
#[derive(Debug, Default, Clone)]
pub struct RawBody<B = Body>(pub B);
213
214 #[async_trait]
215 impl<S, B> FromRequest<S, B> for RawBody<B>
216 where
217 B: Send,
218 S: Send + Sync,
219 {
220 type Rejection = Infallible;
221
from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection>222 async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
223 Ok(Self(req.into_body()))
224 }
225 }
226
// NOTE(review): `__impl_deref!` is an axum-core internal helper macro; it presumably
// implements `Deref`/`DerefMut` from `RawBody<B>` to the wrapped body `B` — confirm
// against axum-core.
axum_core::__impl_deref!(RawBody);
228
#[cfg(test)]
mod tests {
    use crate::{extract::Extension, routing::get, test_helpers::*, Router};
    use http::{Method, StatusCode};

    #[crate::test]
    async fn extract_request_parts() {
        /// Marker inserted via `Extension` so the handler can check extensions are forwarded.
        #[derive(Clone)]
        struct Marker;

        async fn handler(parts: http::request::Parts) {
            assert_eq!(parts.method, Method::GET);
            assert_eq!(parts.uri, "/");
            assert_eq!(parts.version, http::Version::HTTP_11);
            assert_eq!(parts.headers["x-foo"], "123");
            assert!(parts.extensions.get::<Marker>().is_some());
        }

        let app = Router::new()
            .route("/", get(handler))
            .layer(Extension(Marker));
        let client = TestClient::new(app);

        let response = client.get("/").header("x-foo", "123").send().await;
        assert_eq!(response.status(), StatusCode::OK);
    }
}
253