//! Extractor that parses `multipart/form-data` requests (commonly used with file uploads).
2 //!
3 //! See [`Multipart`] for more details.
4
5 use super::{BodyStream, FromRequest};
6 use crate::body::{Bytes, HttpBody};
7 use crate::BoxError;
8 use async_trait::async_trait;
9 use axum_core::__composite_rejection as composite_rejection;
10 use axum_core::__define_rejection as define_rejection;
11 use axum_core::response::{IntoResponse, Response};
12 use axum_core::RequestExt;
13 use futures_util::stream::Stream;
14 use http::header::{HeaderMap, CONTENT_TYPE};
15 use http::{Request, StatusCode};
16 use std::error::Error;
17 use std::{
18 fmt,
19 pin::Pin,
20 task::{Context, Poll},
21 };
22
23 /// Extractor that parses `multipart/form-data` requests (commonly used with file uploads).
24 ///
25 /// ⚠️ Since extracting multipart form data from the request requires consuming the body, the
26 /// `Multipart` extractor must be *last* if there are multiple extractors in a handler.
27 /// See ["the order of extractors"][order-of-extractors]
28 ///
29 /// [order-of-extractors]: crate::extract#the-order-of-extractors
30 ///
31 /// # Example
32 ///
33 /// ```rust,no_run
34 /// use axum::{
35 /// extract::Multipart,
36 /// routing::post,
37 /// Router,
38 /// };
39 /// use futures_util::stream::StreamExt;
40 ///
41 /// async fn upload(mut multipart: Multipart) {
42 /// while let Some(mut field) = multipart.next_field().await.unwrap() {
43 /// let name = field.name().unwrap().to_string();
44 /// let data = field.bytes().await.unwrap();
45 ///
46 /// println!("Length of `{}` is {} bytes", name, data.len());
47 /// }
48 /// }
49 ///
50 /// let app = Router::new().route("/upload", post(upload));
51 /// # async {
52 /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
53 /// # };
54 /// ```
55 #[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
56 #[derive(Debug)]
57 pub struct Multipart {
58 inner: multer::Multipart<'static>,
59 }
60
61 #[async_trait]
62 impl<S, B> FromRequest<S, B> for Multipart
63 where
64 B: HttpBody + Send + 'static,
65 B::Data: Into<Bytes>,
66 B::Error: Into<BoxError>,
67 S: Send + Sync,
68 {
69 type Rejection = MultipartRejection;
70
from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection>71 async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
72 let boundary = parse_boundary(req.headers()).ok_or(InvalidBoundary)?;
73 let stream_result = match req.with_limited_body() {
74 Ok(limited) => BodyStream::from_request(limited, state).await,
75 Err(unlimited) => BodyStream::from_request(unlimited, state).await,
76 };
77 let stream = stream_result.unwrap_or_else(|err| match err {});
78 let multipart = multer::Multipart::new(stream, boundary);
79 Ok(Self { inner: multipart })
80 }
81 }
82
83 impl Multipart {
84 /// Yields the next [`Field`] if available.
next_field(&mut self) -> Result<Option<Field<'_>>, MultipartError>85 pub async fn next_field(&mut self) -> Result<Option<Field<'_>>, MultipartError> {
86 let field = self
87 .inner
88 .next_field()
89 .await
90 .map_err(MultipartError::from_multer)?;
91
92 if let Some(field) = field {
93 Ok(Some(Field {
94 inner: field,
95 _multipart: self,
96 }))
97 } else {
98 Ok(None)
99 }
100 }
101 }
102
/// A single field in a multipart stream.
///
/// Obtained from [`Multipart::next_field`]. Its data can be read with [`Field::bytes`],
/// [`Field::text`], [`Field::chunk`], or through its [`Stream`] implementation.
#[derive(Debug)]
pub struct Field<'a> {
    inner: multer::Field<'static>,
    // multer requires that only one `multer::Field` be live at any point. Holding a mutable
    // borrow of the parent `Multipart` enforces this statically: `next_field` cannot be called
    // again until this `Field` is gone. multer itself does not enforce this at compile time;
    // it returns an error instead.
    _multipart: &'a mut Multipart,
}
111
112 impl<'a> Stream for Field<'a> {
113 type Item = Result<Bytes, MultipartError>;
114
poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>115 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
116 Pin::new(&mut self.inner)
117 .poll_next(cx)
118 .map_err(MultipartError::from_multer)
119 }
120 }
121
122 impl<'a> Field<'a> {
123 /// The field name found in the
124 /// [`Content-Disposition`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition)
125 /// header.
name(&self) -> Option<&str>126 pub fn name(&self) -> Option<&str> {
127 self.inner.name()
128 }
129
130 /// The file name found in the
131 /// [`Content-Disposition`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition)
132 /// header.
file_name(&self) -> Option<&str>133 pub fn file_name(&self) -> Option<&str> {
134 self.inner.file_name()
135 }
136
137 /// Get the [content type](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Type) of the field.
content_type(&self) -> Option<&str>138 pub fn content_type(&self) -> Option<&str> {
139 self.inner.content_type().map(|m| m.as_ref())
140 }
141
142 /// Get a map of headers as [`HeaderMap`].
headers(&self) -> &HeaderMap143 pub fn headers(&self) -> &HeaderMap {
144 self.inner.headers()
145 }
146
147 /// Get the full data of the field as [`Bytes`].
bytes(self) -> Result<Bytes, MultipartError>148 pub async fn bytes(self) -> Result<Bytes, MultipartError> {
149 self.inner
150 .bytes()
151 .await
152 .map_err(MultipartError::from_multer)
153 }
154
155 /// Get the full field data as text.
text(self) -> Result<String, MultipartError>156 pub async fn text(self) -> Result<String, MultipartError> {
157 self.inner.text().await.map_err(MultipartError::from_multer)
158 }
159
160 /// Stream a chunk of the field data.
161 ///
162 /// When the field data has been exhausted, this will return [`None`].
163 ///
164 /// Note this does the same thing as `Field`'s [`Stream`] implementation.
165 ///
166 /// # Example
167 ///
168 /// ```
169 /// use axum::{
170 /// extract::Multipart,
171 /// routing::post,
172 /// response::IntoResponse,
173 /// http::StatusCode,
174 /// Router,
175 /// };
176 ///
177 /// async fn upload(mut multipart: Multipart) -> Result<(), (StatusCode, String)> {
178 /// while let Some(mut field) = multipart
179 /// .next_field()
180 /// .await
181 /// .map_err(|err| (StatusCode::BAD_REQUEST, err.to_string()))?
182 /// {
183 /// while let Some(chunk) = field
184 /// .chunk()
185 /// .await
186 /// .map_err(|err| (StatusCode::BAD_REQUEST, err.to_string()))?
187 /// {
188 /// println!("received {} bytes", chunk.len());
189 /// }
190 /// }
191 ///
192 /// Ok(())
193 /// }
194 ///
195 /// let app = Router::new().route("/upload", post(upload));
196 /// # let _: Router = app;
197 /// ```
chunk(&mut self) -> Result<Option<Bytes>, MultipartError>198 pub async fn chunk(&mut self) -> Result<Option<Bytes>, MultipartError> {
199 self.inner
200 .chunk()
201 .await
202 .map_err(MultipartError::from_multer)
203 }
204 }
205
/// Errors associated with parsing `multipart/form-data` requests.
///
/// Wraps the underlying [`multer::Error`], which is available through
/// [`std::error::Error::source`]. When turned into a response, the status code depends on the
/// kind of multer error (see [`MultipartError::status`]).
#[derive(Debug)]
pub struct MultipartError {
    // The multer error this value was created from.
    source: multer::Error,
}
211
212 impl MultipartError {
from_multer(multer: multer::Error) -> Self213 fn from_multer(multer: multer::Error) -> Self {
214 Self { source: multer }
215 }
216
217 /// Get the response body text used for this rejection.
body_text(&self) -> String218 pub fn body_text(&self) -> String {
219 self.source.to_string()
220 }
221
222 /// Get the status code used for this rejection.
status(&self) -> http::StatusCode223 pub fn status(&self) -> http::StatusCode {
224 status_code_from_multer_error(&self.source)
225 }
226 }
227
status_code_from_multer_error(err: &multer::Error) -> StatusCode228 fn status_code_from_multer_error(err: &multer::Error) -> StatusCode {
229 match err {
230 multer::Error::UnknownField { .. }
231 | multer::Error::IncompleteFieldData { .. }
232 | multer::Error::IncompleteHeaders
233 | multer::Error::ReadHeaderFailed(..)
234 | multer::Error::DecodeHeaderName { .. }
235 | multer::Error::DecodeContentType(..)
236 | multer::Error::NoBoundary
237 | multer::Error::DecodeHeaderValue { .. }
238 | multer::Error::NoMultipart
239 | multer::Error::IncompleteStream => StatusCode::BAD_REQUEST,
240 multer::Error::FieldSizeExceeded { .. } | multer::Error::StreamSizeExceeded { .. } => {
241 StatusCode::PAYLOAD_TOO_LARGE
242 }
243 multer::Error::StreamReadFailed(err) => {
244 if let Some(err) = err.downcast_ref::<multer::Error>() {
245 return status_code_from_multer_error(err);
246 }
247
248 if err
249 .downcast_ref::<crate::Error>()
250 .and_then(|err| err.source())
251 .and_then(|err| err.downcast_ref::<http_body::LengthLimitError>())
252 .is_some()
253 {
254 return StatusCode::PAYLOAD_TOO_LARGE;
255 }
256
257 StatusCode::INTERNAL_SERVER_ERROR
258 }
259 _ => StatusCode::INTERNAL_SERVER_ERROR,
260 }
261 }
262
263 impl fmt::Display for MultipartError {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result264 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
265 write!(f, "Error parsing `multipart/form-data` request")
266 }
267 }
268
269 impl std::error::Error for MultipartError {
source(&self) -> Option<&(dyn std::error::Error + 'static)>270 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
271 Some(&self.source)
272 }
273 }
274
275 impl IntoResponse for MultipartError {
into_response(self) -> Response276 fn into_response(self) -> Response {
277 axum_core::__log_rejection!(
278 rejection_type = Self,
279 body_text = self.body_text(),
280 status = self.status(),
281 );
282 (self.status(), self.body_text()).into_response()
283 }
284 }
285
parse_boundary(headers: &HeaderMap) -> Option<String>286 fn parse_boundary(headers: &HeaderMap) -> Option<String> {
287 let content_type = headers.get(CONTENT_TYPE)?.to_str().ok()?;
288 multer::parse_boundary(content_type).ok()
289 }
290
composite_rejection! {
    /// Rejection used for [`Multipart`].
    ///
    /// Contains one variant for each way the [`Multipart`] extractor can fail. Errors that occur
    /// later, while reading individual fields, are reported separately as [`MultipartError`].
    pub enum MultipartRejection {
        // The request's `Content-Type` header did not yield a usable multipart boundary.
        InvalidBoundary,
    }
}
299
define_rejection! {
    #[status = BAD_REQUEST]
    #[body = "Invalid `boundary` for `multipart/form-data` request"]
    /// Rejection type used if the `boundary` in a `multipart/form-data` request is
    /// missing or invalid.
    ///
    /// This happens when the `Content-Type` header is absent, cannot be read as a string, or
    /// does not contain a boundary that multer can parse. Responds with `400 Bad Request`.
    pub struct InvalidBoundary;
}
307
#[cfg(test)]
mod tests {
    use axum_core::extract::DefaultBodyLimit;

    use super::*;
    use crate::{body::Body, response::IntoResponse, routing::post, test_helpers::*, Router};

    // A field's file name, full content type (including the `charset` parameter) and bytes are
    // exposed unchanged, and the stream ends after the only field.
    #[crate::test]
    async fn content_type_with_encoding() {
        const BYTES: &[u8] = "<!doctype html><title></title>".as_bytes();
        const FILE_NAME: &str = "index.html";
        const CONTENT_TYPE: &str = "text/html; charset=utf-8";

        async fn handle(mut multipart: Multipart) -> impl IntoResponse {
            let field = multipart.next_field().await.unwrap().unwrap();

            assert_eq!(field.file_name().unwrap(), FILE_NAME);
            assert_eq!(field.content_type().unwrap(), CONTENT_TYPE);
            assert_eq!(field.bytes().await.unwrap(), BYTES);

            assert!(multipart.next_field().await.unwrap().is_none());
        }

        let app = Router::new().route("/", post(handle));

        let client = TestClient::new(app);

        let form = reqwest::multipart::Form::new().part(
            "file",
            reqwest::multipart::Part::bytes(BYTES)
                .file_name(FILE_NAME)
                .mime_str(CONTENT_TYPE)
                .unwrap(),
        );

        let res = client.post("/").multipart(form).send().await;
        // The assertions run inside the handler. Only a handler that completes normally
        // produces `200 OK`, so check it explicitly instead of relying on a dropped connection.
        assert_eq!(res.status(), StatusCode::OK);
    }

    // No need for this to be a #[test], we just want to make sure it compiles
    fn _multipart_from_request_limited() {
        async fn handler(_: Multipart) {}
        let _app: Router<(), http_body::Limited<Body>> = Router::new().route("/", post(handler));
    }

    // A body one byte over `DefaultBodyLimit` surfaces as `413 Payload Too Large` through
    // `MultipartError`'s response.
    #[crate::test]
    async fn body_too_large() {
        const BYTES: &[u8] = "<!doctype html><title></title>".as_bytes();

        async fn handle(mut multipart: Multipart) -> Result<(), MultipartError> {
            while let Some(field) = multipart.next_field().await? {
                field.bytes().await?;
            }
            Ok(())
        }

        let app = Router::new()
            .route("/", post(handle))
            .layer(DefaultBodyLimit::max(BYTES.len() - 1));

        let client = TestClient::new(app);

        let form =
            reqwest::multipart::Form::new().part("file", reqwest::multipart::Part::bytes(BYTES));

        let res = client.post("/").multipart(form).send().await;
        assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }
}
376