From c8693498c1e76f8fc54ae4044e6fa948495d4a02 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Mon, 19 Sep 2022 22:41:54 +0200 Subject: [PATCH] Add `DefaultBodyLimit::max` to change the body size limit (#1397) --- axum-core/CHANGELOG.md | 4 +- axum-core/src/extract/default_body_limit.rs | 61 ++++++++++++++++++--- axum-core/src/extract/request_parts.rs | 28 ++++++---- axum/CHANGELOG.md | 2 + axum/src/routing/tests/mod.rs | 25 +++++++++ 5 files changed, 99 insertions(+), 21 deletions(-) diff --git a/axum-core/CHANGELOG.md b/axum-core/CHANGELOG.md index dc34140b15..be59e33fc0 100644 --- a/axum-core/CHANGELOG.md +++ b/axum-core/CHANGELOG.md @@ -7,7 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased -- None. +- **added:** Add `DefaultBodyLimit::max` for changing the default body limit ([#1397]) + +[#1397]: https://github.com/tokio-rs/axum/pull/1397 # 0.2.8 (10. September, 2022) diff --git a/axum-core/src/extract/default_body_limit.rs b/axum-core/src/extract/default_body_limit.rs index 7f12bc9c8e..2830dee7a2 100644 --- a/axum-core/src/extract/default_body_limit.rs +++ b/axum-core/src/extract/default_body_limit.rs @@ -16,8 +16,15 @@ use tower_layer::Layer; /// [`Json`]: https://docs.rs/axum/0.5/axum/struct.Json.html /// [`Form`]: https://docs.rs/axum/0.5/axum/struct.Form.html #[derive(Debug, Clone)] -#[non_exhaustive] -pub struct DefaultBodyLimit; +pub struct DefaultBodyLimit { + kind: DefaultBodyLimitKind, +} + +#[derive(Debug, Clone, Copy)] +pub(crate) enum DefaultBodyLimitKind { + Disable, + Limit(usize), +} impl DefaultBodyLimit { /// Disable the default request body limit. @@ -53,7 +60,42 @@ impl DefaultBodyLimit { /// [`Json`]: https://docs.rs/axum/0.5/axum/struct.Json.html /// [`Form`]: https://docs.rs/axum/0.5/axum/struct.Form.html pub fn disable() -> Self { - Self + Self { + kind: DefaultBodyLimitKind::Disable, + } + } + + /// Set the default request body limit. + /// + /// By default the limit of request body sizes that [`Bytes::from_request`] (and other + /// extractors built on top of it such as `String`, [`Json`], and [`Form`]) is 2MB. This method + /// can be used to change that limit. + /// + /// # Example + /// + /// ``` + /// use axum::{ + /// Router, + /// routing::get, + /// body::{Bytes, Body}, + /// extract::DefaultBodyLimit, + /// }; + /// use tower_http::limit::RequestBodyLimitLayer; + /// use http_body::Limited; + /// + /// let app: Router<_, Limited> = Router::new() + /// .route("/", get(|body: Bytes| async {})) + /// // Replace the default of 2MB with 1024 bytes. + /// .layer(DefaultBodyLimit::max(1024)); + /// ``` + /// + /// [`Bytes::from_request`]: bytes::Bytes + /// [`Json`]: https://docs.rs/axum/0.6.0-rc.2/axum/struct.Json.html + /// [`Form`]: https://docs.rs/axum/0.6.0-rc.2/axum/struct.Form.html + pub fn max(limit: usize) -> Self { + Self { + kind: DefaultBodyLimitKind::Limit(limit), + } } } @@ -61,15 +103,15 @@ impl Layer for DefaultBodyLimit { type Service = DefaultBodyLimitService; fn layer(&self, inner: S) -> Self::Service { - DefaultBodyLimitService { inner } + DefaultBodyLimitService { + inner, + kind: self.kind, + } } } -#[derive(Copy, Clone, Debug)] -pub(crate) struct DefaultBodyLimitDisabled; - mod private { - use super::DefaultBodyLimitDisabled; + use super::DefaultBodyLimitKind; use http::Request; use std::task::Context; use tower_service::Service; @@ -77,6 +119,7 @@ mod private { #[derive(Debug, Clone, Copy)] pub struct DefaultBodyLimitService { pub(super) inner: S, + pub(super) kind: DefaultBodyLimitKind, } impl Service> for DefaultBodyLimitService @@ -94,7 +137,7 @@ mod private { #[inline] fn call(&mut self, mut req: Request) -> Self::Future { - req.extensions_mut().insert(DefaultBodyLimitDisabled); + req.extensions_mut().insert(self.kind); self.inner.call(req) } } diff --git a/axum-core/src/extract/request_parts.rs b/axum-core/src/extract/request_parts.rs index 66753c1be4..5d7442c840 100644 --- a/axum-core/src/extract/request_parts.rs +++ b/axum-core/src/extract/request_parts.rs @@ -1,6 +1,4 @@ -use super::{ - default_body_limit::DefaultBodyLimitDisabled, rejection::*, FromRequest, RequestParts, -}; +use super::{default_body_limit::DefaultBodyLimitKind, rejection::*, FromRequest, RequestParts}; use crate::BoxError; use async_trait::async_trait; use bytes::Bytes; @@ -100,15 +98,23 @@ where let body = take_body(req)?; - let bytes = if req.extensions().get::().is_some() { - crate::body::to_bytes(body) + let limit_kind = req.extensions().get::().copied(); + let bytes = match limit_kind { + Some(DefaultBodyLimitKind::Disable) => crate::body::to_bytes(body) .await - .map_err(FailedToBufferBody::from_err)? - } else { - let body = http_body::Limited::new(body, DEFAULT_LIMIT); - crate::body::to_bytes(body) - .await - .map_err(FailedToBufferBody::from_err)? + .map_err(FailedToBufferBody::from_err)?, + Some(DefaultBodyLimitKind::Limit(limit)) => { + let body = http_body::Limited::new(body, limit); + crate::body::to_bytes(body) + .await + .map_err(FailedToBufferBody::from_err)? + } + None => { + let body = http_body::Limited::new(body, DEFAULT_LIMIT); + crate::body::to_bytes(body) + .await + .map_err(FailedToBufferBody::from_err)? + } }; Ok(bytes) diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 484dbae740..108303a5f7 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -10,8 +10,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **fixed:** Annotate panicking functions with `#[track_caller]` so the error message points to where the user added the invalid router, rather than somewhere internally in axum ([#1248]) +- **added:** Add `DefaultBodyLimit::max` for changing the default body limit ([#1397]) [#1248]: https://github.com/tokio-rs/axum/pull/1248 +[#1397]: https://github.com/tokio-rs/axum/pull/1397 # 0.5.16 (10. September, 2022) diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index 311126f4fe..e138010fd1 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -767,6 +767,31 @@ async fn limited_body_with_content_length() { assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE); } +#[tokio::test] +async fn changing_the_default_limit() { + let new_limit = 2; + + let app = Router::new() + .route("/", post(|_: Bytes| async {})) + .layer(DefaultBodyLimit::max(new_limit)); + + let client = TestClient::new(app); + + let res = client + .post("/") + .body(Body::from("a".repeat(new_limit))) + .send() + .await; + assert_eq!(res.status(), StatusCode::OK); + + let res = client + .post("/") + .body(Body::from("a".repeat(new_limit + 1))) + .send() + .await; + assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE); +} + #[tokio::test] async fn limited_body_with_streaming_body() { const LIMIT: usize = 3;