Skip to content

Commit

Permalink
Refactor TestClient usage (#3121)
Browse files Browse the repository at this point in the history
  • Loading branch information
Turbo87 authored Dec 27, 2024
1 parent 3497e5d commit 28d8d9b
Show file tree
Hide file tree
Showing 12 changed files with 68 additions and 65 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion axum-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ tower-http = { version = "0.6.0", optional = true, features = ["limit"] }
tracing = { version = "0.1.37", default-features = false, optional = true }

[dev-dependencies]
axum = { path = "../axum" }
axum = { path = "../axum", features = ["__private"] }
axum-extra = { path = "../axum-extra", features = ["typed-header"] }
axum-macros = { path = "../axum-macros", features = ["__private"] }
futures-util = { version = "0.3", default-features = false, features = ["alloc"] }
hyper = "1.0.0"
tokio = { version = "1.25.0", features = ["macros"] }
Expand Down
25 changes: 25 additions & 0 deletions axum-core/src/extract/request_parts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,28 @@ where
Ok(req.into_body())
}
}

#[cfg(test)]
mod tests {
use axum::{extract::Extension, routing::get, test_helpers::*, Router};
use http::{Method, StatusCode};

#[crate::test]
async fn extract_request_parts() {
#[derive(Clone)]
struct Ext;

async fn handler(parts: http::request::Parts) {
assert_eq!(parts.method, Method::GET);
assert_eq!(parts.uri, "/");
assert_eq!(parts.version, http::Version::HTTP_11);
assert_eq!(parts.headers["x-foo"], "123");
parts.extensions.get::<Ext>().unwrap();
}

let client = TestClient::new(Router::new().route("/", get(handler)).layer(Extension(Ext)));

let res = client.get("/").header("x-foo", "123").await;
assert_eq!(res.status(), StatusCode::OK);
}
}
3 changes: 3 additions & 0 deletions axum-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,6 @@ pub mod response;
pub type BoxError = Box<dyn std::error::Error + Send + Sync>;

pub use self::ext_traits::{request::RequestExt, request_parts::RequestPartsExt};

#[cfg(test)]
use axum_macros::__private_axum_test as test;
2 changes: 1 addition & 1 deletion axum-extra/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ tracing = { version = "0.1.37", default-features = false, optional = true }
typed-json = { version = "0.1.1", optional = true }

[dev-dependencies]
axum = { path = "../axum", features = ["macros"] }
axum = { path = "../axum", features = ["macros", "__private"] }
axum-macros = { path = "../axum-macros", features = ["__private"] }
hyper = "1.0.0"
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "multipart"] }
Expand Down
14 changes: 1 addition & 13 deletions axum-extra/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,4 @@ pub mod __private {
use axum_macros::__private_axum_test as test;

#[cfg(test)]
#[allow(unused_imports)]
pub(crate) mod test_helpers {
use axum::{extract::Request, response::Response, serve};

mod test_client {
#![allow(dead_code)]
include!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/../axum/src/test_helpers/test_client.rs"
));
}
pub(crate) use self::test_client::*;
}
pub(crate) use axum::test_helpers;
8 changes: 8 additions & 0 deletions axum/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ __private_docs = [
"tower/full", "dep:tower-http",
]

# This feature is used to enable private test helper usage
# in `axum-core` and `axum-extra`.
__private = ["tokio", "http1", "dep:reqwest"]

[dependencies]
axum-core = { path = "../axum-core", version = "0.5.0-rc.1" }
bytes = "1.0"
Expand Down Expand Up @@ -72,6 +76,7 @@ form_urlencoded = { version = "1.1.0", optional = true }
hyper = { version = "1.1.0", optional = true }
hyper-util = { version = "0.1.3", features = ["tokio", "server", "service"], optional = true }
multer = { version = "3.0.0", optional = true }
reqwest = { version = "0.12", optional = true, default-features = false, features = ["json", "stream", "multipart"] }
serde_json = { version = "1.0", features = ["raw_value"], optional = true }
serde_path_to_error = { version = "0.1.8", optional = true }
serde_urlencoded = { version = "0.7", optional = true }
Expand Down Expand Up @@ -214,6 +219,9 @@ allowed = [
"http_body",
"serde",
"tokio",

# for the `__private` feature
"reqwest",
]

[[bench]]
Expand Down
1 change: 0 additions & 1 deletion axum/src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ pub(crate) mod nested_path;
mod original_uri;
mod raw_form;
mod raw_query;
mod request_parts;
mod state;

#[doc(inline)]
Expand Down
27 changes: 0 additions & 27 deletions axum/src/extract/request_parts.rs

This file was deleted.

5 changes: 3 additions & 2 deletions axum/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -446,8 +446,9 @@ pub mod routing;
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
pub mod serve;

#[cfg(test)]
mod test_helpers;
#[cfg(any(test, feature = "__private"))]
#[allow(missing_docs, missing_debug_implementations, clippy::print_stdout)]
pub mod test_helpers;

#[doc(no_inline)]
pub use http;
Expand Down
6 changes: 5 additions & 1 deletion axum/src/test_helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@
use crate::{extract::Request, response::Response, serve};

mod test_client;
pub(crate) use self::test_client::*;
pub use self::test_client::*;

#[cfg(test)]
pub(crate) mod tracing_helpers;

#[cfg(test)]
pub(crate) mod counting_cloneable_state;

#[cfg(test)]
pub(crate) fn assert_send<T: Send>() {}
#[cfg(test)]
pub(crate) fn assert_sync<T: Sync>() {}

#[allow(dead_code)]
Expand Down
38 changes: 19 additions & 19 deletions axum/src/test_helpers/test_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ where
addr
}

pub(crate) struct TestClient {
pub struct TestClient {
client: reqwest::Client,
addr: SocketAddr,
}

impl TestClient {
pub(crate) fn new<S>(svc: S) -> Self
pub fn new<S>(svc: S) -> Self
where
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S::Future: Send,
Expand All @@ -50,63 +50,63 @@ impl TestClient {
TestClient { client, addr }
}

pub(crate) fn get(&self, url: &str) -> RequestBuilder {
pub fn get(&self, url: &str) -> RequestBuilder {
RequestBuilder {
builder: self.client.get(format!("http://{}{url}", self.addr)),
}
}

pub(crate) fn head(&self, url: &str) -> RequestBuilder {
pub fn head(&self, url: &str) -> RequestBuilder {
RequestBuilder {
builder: self.client.head(format!("http://{}{url}", self.addr)),
}
}

pub(crate) fn post(&self, url: &str) -> RequestBuilder {
pub fn post(&self, url: &str) -> RequestBuilder {
RequestBuilder {
builder: self.client.post(format!("http://{}{url}", self.addr)),
}
}

#[allow(dead_code)]
pub(crate) fn put(&self, url: &str) -> RequestBuilder {
pub fn put(&self, url: &str) -> RequestBuilder {
RequestBuilder {
builder: self.client.put(format!("http://{}{url}", self.addr)),
}
}

#[allow(dead_code)]
pub(crate) fn patch(&self, url: &str) -> RequestBuilder {
pub fn patch(&self, url: &str) -> RequestBuilder {
RequestBuilder {
builder: self.client.patch(format!("http://{}{url}", self.addr)),
}
}

#[allow(dead_code)]
pub(crate) fn server_port(&self) -> u16 {
pub fn server_port(&self) -> u16 {
self.addr.port()
}
}

pub(crate) struct RequestBuilder {
pub struct RequestBuilder {
builder: reqwest::RequestBuilder,
}

impl RequestBuilder {
pub(crate) fn body(mut self, body: impl Into<reqwest::Body>) -> Self {
pub fn body(mut self, body: impl Into<reqwest::Body>) -> Self {
self.builder = self.builder.body(body);
self
}

pub(crate) fn json<T>(mut self, json: &T) -> Self
pub fn json<T>(mut self, json: &T) -> Self
where
T: serde::Serialize,
{
self.builder = self.builder.json(json);
self
}

pub(crate) fn header<K, V>(mut self, key: K, value: V) -> Self
pub fn header<K, V>(mut self, key: K, value: V) -> Self
where
HeaderName: TryFrom<K>,
<HeaderName as TryFrom<K>>::Error: Into<http::Error>,
Expand All @@ -118,7 +118,7 @@ impl RequestBuilder {
}

#[allow(dead_code)]
pub(crate) fn multipart(mut self, form: reqwest::multipart::Form) -> Self {
pub fn multipart(mut self, form: reqwest::multipart::Form) -> Self {
self.builder = self.builder.multipart(form);
self
}
Expand All @@ -138,7 +138,7 @@ impl IntoFuture for RequestBuilder {
}

#[derive(Debug)]
pub(crate) struct TestResponse {
pub struct TestResponse {
response: reqwest::Response,
}

Expand All @@ -152,27 +152,27 @@ impl Deref for TestResponse {

impl TestResponse {
#[allow(dead_code)]
pub(crate) async fn bytes(self) -> Bytes {
pub async fn bytes(self) -> Bytes {
self.response.bytes().await.unwrap()
}

pub(crate) async fn text(self) -> String {
pub async fn text(self) -> String {
self.response.text().await.unwrap()
}

#[allow(dead_code)]
pub(crate) async fn json<T>(self) -> T
pub async fn json<T>(self) -> T
where
T: serde::de::DeserializeOwned,
{
self.response.json().await.unwrap()
}

pub(crate) async fn chunk(&mut self) -> Option<Bytes> {
pub async fn chunk(&mut self) -> Option<Bytes> {
self.response.chunk().await.unwrap()
}

pub(crate) async fn chunk_text(&mut self) -> Option<String> {
pub async fn chunk_text(&mut self) -> Option<String> {
let chunk = self.chunk().await?;
Some(String::from_utf8(chunk.to_vec()).unwrap())
}
Expand Down

0 comments on commit 28d8d9b

Please sign in to comment.