1 //! Extractor that parses `multipart/form-data` requests commonly used with file uploads.
2 //!
3 //! See [`Multipart`] for more details.
4 
5 use super::{BodyStream, FromRequest};
6 use crate::body::{Bytes, HttpBody};
7 use crate::BoxError;
8 use async_trait::async_trait;
9 use axum_core::__composite_rejection as composite_rejection;
10 use axum_core::__define_rejection as define_rejection;
11 use axum_core::response::{IntoResponse, Response};
12 use axum_core::RequestExt;
13 use futures_util::stream::Stream;
14 use http::header::{HeaderMap, CONTENT_TYPE};
15 use http::{Request, StatusCode};
16 use std::error::Error;
17 use std::{
18     fmt,
19     pin::Pin,
20     task::{Context, Poll},
21 };
22 
23 /// Extractor that parses `multipart/form-data` requests (commonly used with file uploads).
24 ///
25 /// ⚠️ Since extracting multipart form data from the request requires consuming the body, the
26 /// `Multipart` extractor must be *last* if there are multiple extractors in a handler.
27 /// See ["the order of extractors"][order-of-extractors]
28 ///
29 /// [order-of-extractors]: crate::extract#the-order-of-extractors
30 ///
31 /// # Example
32 ///
33 /// ```rust,no_run
34 /// use axum::{
35 ///     extract::Multipart,
36 ///     routing::post,
37 ///     Router,
38 /// };
39 /// use futures_util::stream::StreamExt;
40 ///
41 /// async fn upload(mut multipart: Multipart) {
42 ///     while let Some(mut field) = multipart.next_field().await.unwrap() {
43 ///         let name = field.name().unwrap().to_string();
44 ///         let data = field.bytes().await.unwrap();
45 ///
46 ///         println!("Length of `{}` is {} bytes", name, data.len());
47 ///     }
48 /// }
49 ///
50 /// let app = Router::new().route("/upload", post(upload));
51 /// # async {
52 /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
53 /// # };
54 /// ```
55 #[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
56 #[derive(Debug)]
57 pub struct Multipart {
58     inner: multer::Multipart<'static>,
59 }
60 
61 #[async_trait]
62 impl<S, B> FromRequest<S, B> for Multipart
63 where
64     B: HttpBody + Send + 'static,
65     B::Data: Into<Bytes>,
66     B::Error: Into<BoxError>,
67     S: Send + Sync,
68 {
69     type Rejection = MultipartRejection;
70 
from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection>71     async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
72         let boundary = parse_boundary(req.headers()).ok_or(InvalidBoundary)?;
73         let stream_result = match req.with_limited_body() {
74             Ok(limited) => BodyStream::from_request(limited, state).await,
75             Err(unlimited) => BodyStream::from_request(unlimited, state).await,
76         };
77         let stream = stream_result.unwrap_or_else(|err| match err {});
78         let multipart = multer::Multipart::new(stream, boundary);
79         Ok(Self { inner: multipart })
80     }
81 }
82 
83 impl Multipart {
84     /// Yields the next [`Field`] if available.
next_field(&mut self) -> Result<Option<Field<'_>>, MultipartError>85     pub async fn next_field(&mut self) -> Result<Option<Field<'_>>, MultipartError> {
86         let field = self
87             .inner
88             .next_field()
89             .await
90             .map_err(MultipartError::from_multer)?;
91 
92         if let Some(field) = field {
93             Ok(Some(Field {
94                 inner: field,
95                 _multipart: self,
96             }))
97         } else {
98             Ok(None)
99         }
100     }
101 }
102 
103 /// A single field in a multipart stream.
104 #[derive(Debug)]
105 pub struct Field<'a> {
106     inner: multer::Field<'static>,
107     // multer requires there to only be one live `multer::Field` at any point. This enforces that
108     // statically, which multer does not do, it returns an error instead.
109     _multipart: &'a mut Multipart,
110 }
111 
112 impl<'a> Stream for Field<'a> {
113     type Item = Result<Bytes, MultipartError>;
114 
poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>115     fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
116         Pin::new(&mut self.inner)
117             .poll_next(cx)
118             .map_err(MultipartError::from_multer)
119     }
120 }
121 
122 impl<'a> Field<'a> {
123     /// The field name found in the
124     /// [`Content-Disposition`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition)
125     /// header.
name(&self) -> Option<&str>126     pub fn name(&self) -> Option<&str> {
127         self.inner.name()
128     }
129 
130     /// The file name found in the
131     /// [`Content-Disposition`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition)
132     /// header.
file_name(&self) -> Option<&str>133     pub fn file_name(&self) -> Option<&str> {
134         self.inner.file_name()
135     }
136 
137     /// Get the [content type](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Type) of the field.
content_type(&self) -> Option<&str>138     pub fn content_type(&self) -> Option<&str> {
139         self.inner.content_type().map(|m| m.as_ref())
140     }
141 
142     /// Get a map of headers as [`HeaderMap`].
headers(&self) -> &HeaderMap143     pub fn headers(&self) -> &HeaderMap {
144         self.inner.headers()
145     }
146 
147     /// Get the full data of the field as [`Bytes`].
bytes(self) -> Result<Bytes, MultipartError>148     pub async fn bytes(self) -> Result<Bytes, MultipartError> {
149         self.inner
150             .bytes()
151             .await
152             .map_err(MultipartError::from_multer)
153     }
154 
155     /// Get the full field data as text.
text(self) -> Result<String, MultipartError>156     pub async fn text(self) -> Result<String, MultipartError> {
157         self.inner.text().await.map_err(MultipartError::from_multer)
158     }
159 
160     /// Stream a chunk of the field data.
161     ///
162     /// When the field data has been exhausted, this will return [`None`].
163     ///
164     /// Note this does the same thing as `Field`'s [`Stream`] implementation.
165     ///
166     /// # Example
167     ///
168     /// ```
169     /// use axum::{
170     ///    extract::Multipart,
171     ///    routing::post,
172     ///    response::IntoResponse,
173     ///    http::StatusCode,
174     ///    Router,
175     /// };
176     ///
177     /// async fn upload(mut multipart: Multipart) -> Result<(), (StatusCode, String)> {
178     ///     while let Some(mut field) = multipart
179     ///         .next_field()
180     ///         .await
181     ///         .map_err(|err| (StatusCode::BAD_REQUEST, err.to_string()))?
182     ///     {
183     ///         while let Some(chunk) = field
184     ///             .chunk()
185     ///             .await
186     ///             .map_err(|err| (StatusCode::BAD_REQUEST, err.to_string()))?
187     ///         {
188     ///             println!("received {} bytes", chunk.len());
189     ///         }
190     ///     }
191     ///
192     ///     Ok(())
193     /// }
194     ///
195     /// let app = Router::new().route("/upload", post(upload));
196     /// # let _: Router = app;
197     /// ```
chunk(&mut self) -> Result<Option<Bytes>, MultipartError>198     pub async fn chunk(&mut self) -> Result<Option<Bytes>, MultipartError> {
199         self.inner
200             .chunk()
201             .await
202             .map_err(MultipartError::from_multer)
203     }
204 }
205 
206 /// Errors associated with parsing `multipart/form-data` requests.
207 #[derive(Debug)]
208 pub struct MultipartError {
209     source: multer::Error,
210 }
211 
212 impl MultipartError {
from_multer(multer: multer::Error) -> Self213     fn from_multer(multer: multer::Error) -> Self {
214         Self { source: multer }
215     }
216 
217     /// Get the response body text used for this rejection.
body_text(&self) -> String218     pub fn body_text(&self) -> String {
219         self.source.to_string()
220     }
221 
222     /// Get the status code used for this rejection.
status(&self) -> http::StatusCode223     pub fn status(&self) -> http::StatusCode {
224         status_code_from_multer_error(&self.source)
225     }
226 }
227 
status_code_from_multer_error(err: &multer::Error) -> StatusCode228 fn status_code_from_multer_error(err: &multer::Error) -> StatusCode {
229     match err {
230         multer::Error::UnknownField { .. }
231         | multer::Error::IncompleteFieldData { .. }
232         | multer::Error::IncompleteHeaders
233         | multer::Error::ReadHeaderFailed(..)
234         | multer::Error::DecodeHeaderName { .. }
235         | multer::Error::DecodeContentType(..)
236         | multer::Error::NoBoundary
237         | multer::Error::DecodeHeaderValue { .. }
238         | multer::Error::NoMultipart
239         | multer::Error::IncompleteStream => StatusCode::BAD_REQUEST,
240         multer::Error::FieldSizeExceeded { .. } | multer::Error::StreamSizeExceeded { .. } => {
241             StatusCode::PAYLOAD_TOO_LARGE
242         }
243         multer::Error::StreamReadFailed(err) => {
244             if let Some(err) = err.downcast_ref::<multer::Error>() {
245                 return status_code_from_multer_error(err);
246             }
247 
248             if err
249                 .downcast_ref::<crate::Error>()
250                 .and_then(|err| err.source())
251                 .and_then(|err| err.downcast_ref::<http_body::LengthLimitError>())
252                 .is_some()
253             {
254                 return StatusCode::PAYLOAD_TOO_LARGE;
255             }
256 
257             StatusCode::INTERNAL_SERVER_ERROR
258         }
259         _ => StatusCode::INTERNAL_SERVER_ERROR,
260     }
261 }
262 
263 impl fmt::Display for MultipartError {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result264     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
265         write!(f, "Error parsing `multipart/form-data` request")
266     }
267 }
268 
269 impl std::error::Error for MultipartError {
source(&self) -> Option<&(dyn std::error::Error + 'static)>270     fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
271         Some(&self.source)
272     }
273 }
274 
275 impl IntoResponse for MultipartError {
into_response(self) -> Response276     fn into_response(self) -> Response {
277         axum_core::__log_rejection!(
278             rejection_type = Self,
279             body_text = self.body_text(),
280             status = self.status(),
281         );
282         (self.status(), self.body_text()).into_response()
283     }
284 }
285 
parse_boundary(headers: &HeaderMap) -> Option<String>286 fn parse_boundary(headers: &HeaderMap) -> Option<String> {
287     let content_type = headers.get(CONTENT_TYPE)?.to_str().ok()?;
288     multer::parse_boundary(content_type).ok()
289 }
290 
composite_rejection! {
    /// Rejection used for [`Multipart`].
    ///
    /// Contains one variant for each way the [`Multipart`] extractor can fail.
    ///
    /// This only covers the extraction step itself. Errors that occur later, while reading
    /// fields, are reported as [`MultipartError`].
    pub enum MultipartRejection {
        // No usable `boundary` could be read from the `Content-Type` header.
        InvalidBoundary,
    }
}
299 
define_rejection! {
    #[status = BAD_REQUEST]
    #[body = "Invalid `boundary` for `multipart/form-data` request"]
    /// Rejection type used if the `boundary` of a `multipart/form-data` request's
    /// `Content-Type` header is missing or invalid.
    pub struct InvalidBoundary;
}
307 
#[cfg(test)]
mod tests {
    use axum_core::extract::DefaultBodyLimit;

    use super::*;
    use crate::{body::Body, response::IntoResponse, routing::post, test_helpers::*, Router};

    /// Small HTML document with non-ASCII text, used as the uploaded file contents.
    const HTML_DOC: &[u8] = "<!doctype html><title>��</title>".as_bytes();

    #[crate::test]
    async fn content_type_with_encoding() {
        const FILE_NAME: &str = "index.html";
        const CONTENT_TYPE: &str = "text/html; charset=utf-8";

        // Checks the single uploaded field, then that no further field follows it.
        async fn handle(mut multipart: Multipart) -> impl IntoResponse {
            let field = multipart.next_field().await.unwrap().unwrap();

            assert_eq!(field.file_name().unwrap(), FILE_NAME);
            assert_eq!(field.content_type().unwrap(), CONTENT_TYPE);
            assert_eq!(field.bytes().await.unwrap(), HTML_DOC);

            assert!(multipart.next_field().await.unwrap().is_none());
        }

        let client = TestClient::new(Router::new().route("/", post(handle)));

        let part = reqwest::multipart::Part::bytes(HTML_DOC)
            .file_name(FILE_NAME)
            .mime_str(CONTENT_TYPE)
            .unwrap();
        let form = reqwest::multipart::Form::new().part("file", part);

        client.post("/").multipart(form).send().await;
    }

    // No need for this to be a #[test], we just want to make sure it compiles
    fn _multipart_from_request_limited() {
        async fn handler(_: Multipart) {}
        let _app: Router<(), http_body::Limited<Body>> = Router::new().route("/", post(handler));
    }

    #[crate::test]
    async fn body_too_large() {
        async fn handle(mut multipart: Multipart) -> Result<(), MultipartError> {
            while let Some(field) = multipart.next_field().await? {
                field.bytes().await?;
            }
            Ok(())
        }

        // The encoded multipart body is larger than the file contents alone, so a limit one
        // byte below the contents is guaranteed to be exceeded.
        let limit = DefaultBodyLimit::max(HTML_DOC.len() - 1);
        let client = TestClient::new(Router::new().route("/", post(handle)).layer(limit));

        let part = reqwest::multipart::Part::bytes(HTML_DOC);
        let form = reqwest::multipart::Form::new().part("file", part);

        let res = client.post("/").multipart(form).send().await;
        assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }
}
376