From 81f0248713148a794f330fc5d6b6e578e6cff811 Mon Sep 17 00:00:00 2001 From: Wonwoo Choi Date: Sat, 1 Feb 2020 21:12:17 +0900 Subject: [PATCH 1/4] Per-route middleware --- src/endpoint.rs | 54 +++++++++++++++++++++++++- src/router.rs | 10 ++--- src/server/route.rs | 68 +++++++++++++++++++++++++++++--- tests/route_middleware.rs | 81 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 202 insertions(+), 11 deletions(-) create mode 100644 tests/route_middleware.rs diff --git a/src/endpoint.rs b/src/endpoint.rs index 292c4b424..1fb4bc97c 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -1,7 +1,10 @@ +use std::sync::Arc; + use async_std::future::Future; +use crate::middleware::Next; use crate::utils::BoxFuture; -use crate::{response::IntoResponse, Request, Response}; +use crate::{response::IntoResponse, Middleware, Request, Response}; /// An HTTP request handler. /// @@ -63,3 +66,52 @@ where Box::pin(async move { fut.await.into_response() }) } } + +pub struct MiddlewareEndpoint { + endpoint: E, + middleware: Vec>>, +} + +impl Clone for MiddlewareEndpoint { + fn clone(&self) -> Self { + Self { + endpoint: self.endpoint.clone(), + middleware: self.middleware.clone(), + } + } +} + +impl std::fmt::Debug for MiddlewareEndpoint { + fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + fmt, + "MiddlewareEndpoint (length: {})", + self.middleware.len(), + ) + } +} + +impl MiddlewareEndpoint +where + E: Endpoint, +{ + pub fn wrap_with_middleware(ep: E, middleware: &[Arc>]) -> Self { + Self { + endpoint: ep, + middleware: middleware.to_vec(), + } + } +} + +impl Endpoint for MiddlewareEndpoint +where + E: Endpoint, +{ + fn call<'a>(&'a self, req: Request) -> BoxFuture<'a, Response> { + let next = Next { + endpoint: &self.endpoint, + next_middleware: &self.middleware, + }; + next.run(req) + } +} diff --git a/src/router.rs b/src/router.rs index 6b358a07b..5d10cacba 100644 --- a/src/router.rs +++ b/src/router.rs @@ -1,7 +1,7 @@ use route_recognizer::{Match, Params, Router as MethodRouter}; use std::collections::HashMap; -use crate::endpoint::{DynEndpoint, Endpoint}; +use crate::endpoint::DynEndpoint; use crate::utils::BoxFuture; use crate::{Request, Response}; @@ -29,15 +29,15 @@ impl Router { } } - pub(crate) fn add(&mut self, path: &str, method: http::Method, ep: impl Endpoint) { + pub(crate) fn add(&mut self, path: &str, method: http::Method, ep: Box>) { self.method_map .entry(method) .or_insert_with(MethodRouter::new) - .add(path, Box::new(ep)) + .add(path, ep) } - pub(crate) fn add_all(&mut self, path: &str, ep: impl Endpoint) { - self.all_method_router.add(path, Box::new(ep)) + pub(crate) fn add_all(&mut self, path: &str, ep: Box>) { + self.all_method_router.add(path, ep) } pub(crate) fn route(&self, path: &str, method: http::Method) -> Selection<'_, State> { diff --git a/src/server/route.rs b/src/server/route.rs index 3d6649918..e4c250aff 100644 --- a/src/server/route.rs +++ b/src/server/route.rs @@ -1,5 +1,8 @@ +use std::sync::Arc; + +use crate::endpoint::MiddlewareEndpoint; use crate::utils::BoxFuture; -use crate::{router::Router, Endpoint, Response}; +use crate::{router::Router, Endpoint, Middleware, Response}; /// A handle to a route. /// @@ -13,6 +16,7 @@ use crate::{router::Router, Endpoint, Response}; pub struct Route<'a, State> { router: &'a mut Router, path: String, + middleware: Vec>>, /// Indicates whether the path of current route is treated as a prefix. Set by /// [`strip_prefix`]. /// @@ -25,11 +29,14 @@ impl<'a, State: 'static> Route<'a, State> { Route { router, path, + middleware: Vec::new(), prefix: false, } } /// Extend the route with the given `path`. + /// + /// The returned route won't have any middleware applied. pub fn at<'b>(&'b mut self, path: &str) -> Route<'b, State> { let mut p = self.path.clone(); @@ -44,6 +51,7 @@ impl<'a, State: 'static> Route<'a, State> { Route { router: &mut self.router, path: p, + middleware: Vec::new(), prefix: false, } } @@ -60,6 +68,18 @@ impl<'a, State: 'static> Route<'a, State> { self } + /// Apply the given middleware to the current route. + pub fn middleware(&mut self, middleware: impl Middleware) -> &mut Self { + self.middleware.push(Arc::new(middleware)); + self + } + + /// Reset the middleware chain for the current route, if any. + pub fn reset_middleware(&mut self) -> &mut Self { + self.middleware.clear(); + self + } + /// Nest a [`Server`] at the current path. /// /// [`Server`]: struct.Server.html @@ -78,10 +98,29 @@ impl<'a, State: 'static> Route<'a, State> { pub fn method(&mut self, method: http::Method, ep: impl Endpoint) -> &mut Self { if self.prefix { let ep = StripPrefixEndpoint::new(ep); - self.router.add(&self.path, method.clone(), ep.clone()); + let (ep1, ep2): (Box>, Box>) = + if self.middleware.is_empty() { + let ep = Box::new(ep); + (ep.clone(), ep) + } else { + let ep = Box::new(MiddlewareEndpoint::wrap_with_middleware( + ep, + &self.middleware, + )); + (ep.clone(), ep) + }; + self.router.add(&self.path, method.clone(), ep1); let wildcard = self.at("*--tide-path-rest"); - wildcard.router.add(&wildcard.path, method, ep); + wildcard.router.add(&wildcard.path, method, ep2); } else { + let ep: Box> = if self.middleware.is_empty() { + Box::new(ep) + } else { + Box::new(MiddlewareEndpoint::wrap_with_middleware( + ep, + &self.middleware, + )) + }; self.router.add(&self.path, method, ep); } self @@ -93,10 +132,29 @@ impl<'a, State: 'static> Route<'a, State> { pub fn all(&mut self, ep: impl Endpoint) -> &mut Self { if self.prefix { let ep = StripPrefixEndpoint::new(ep); - self.router.add_all(&self.path, ep.clone()); + let (ep1, ep2): (Box>, Box>) = + if self.middleware.is_empty() { + let ep = Box::new(ep); + (ep.clone(), ep) + } else { + let ep = Box::new(MiddlewareEndpoint::wrap_with_middleware( + ep, + &self.middleware, + )); + (ep.clone(), ep) + }; + self.router.add_all(&self.path, ep1); let wildcard = self.at("*--tide-path-rest"); - wildcard.router.add_all(&wildcard.path, ep); + wildcard.router.add_all(&wildcard.path, ep2); } else { + let ep: Box> = if self.middleware.is_empty() { + Box::new(ep) + } else { + Box::new(MiddlewareEndpoint::wrap_with_middleware( + ep, + &self.middleware, + )) + }; self.router.add_all(&self.path, ep); } self diff --git a/tests/route_middleware.rs b/tests/route_middleware.rs new file mode 100644 index 000000000..2f49d71e1 --- /dev/null +++ b/tests/route_middleware.rs @@ -0,0 +1,81 @@ +use async_std::io::prelude::*; +use futures::executor::block_on; +use futures::future::BoxFuture; +use http_service::Body; +use http_service_mock::make_server; +use tide::Middleware; + +struct TestMiddleware(&'static str); + +impl Middleware for TestMiddleware { + fn handle<'a>( + &'a self, + req: tide::Request, + next: tide::Next<'a, State>, + ) -> BoxFuture<'a, tide::Response> { + Box::pin(async move { + let res = next.run(req).await; + res.set_header("X-Tide-Test", self.0) + }) + } +} + +async fn echo_path(req: tide::Request) -> String { + req.uri().path().to_string() +} + +#[test] +fn route_middleware() { + let mut app = tide::new(); + let mut foo_route = app.at("/foo"); + foo_route.middleware(TestMiddleware("foo")) + .get(echo_path); + foo_route.at("/bar") + .middleware(TestMiddleware("bar")) + .get(echo_path); + foo_route.post(echo_path) + .reset_middleware() + .put(echo_path); + let mut server = make_server(app.into_http_service()).unwrap(); + + let mut buf = Vec::new(); + let req = http::Request::get("/foo").body(Body::empty()).unwrap(); + let res = server.simulate(req).unwrap(); + assert_eq!( + res.headers().get("X-Tide-Test"), + Some(&"foo".parse().unwrap()) + ); + assert_eq!(res.status(), 200); + block_on(res.into_body().read_to_end(&mut buf)).unwrap(); + assert_eq!(&*buf, &*b"/foo"); + + buf.clear(); + let req = http::Request::post("/foo").body(Body::empty()).unwrap(); + let res = server.simulate(req).unwrap(); + assert_eq!( + res.headers().get("X-Tide-Test"), + Some(&"foo".parse().unwrap()) + ); + assert_eq!(res.status(), 200); + block_on(res.into_body().read_to_end(&mut buf)).unwrap(); + assert_eq!(&*buf, &*b"/foo"); + + buf.clear(); + let req = http::Request::put("/foo").body(Body::empty()).unwrap(); + let res = server.simulate(req).unwrap(); + assert_eq!(res.headers().get("X-Tide-Test"), None); + assert_eq!(res.status(), 200); + block_on(res.into_body().read_to_end(&mut buf)).unwrap(); + assert_eq!(&*buf, &*b"/foo"); + + buf.clear(); + let req = http::Request::get("/foo/bar").body(Body::empty()).unwrap(); + let res = server.simulate(req).unwrap(); + assert_eq!( + res.headers().get("X-Tide-Test"), + Some(&"bar".parse().unwrap()) + ); + assert_eq!(res.status(), 200); + block_on(res.into_body().read_to_end(&mut buf)).unwrap(); + assert_eq!(&*buf, &*b"/foo/bar"); +} From 7fc25778f242396993dadbd17e516247dcc21fbd Mon Sep 17 00:00:00 2001 From: Wonwoo Choi Date: Sat, 1 Feb 2020 22:26:11 +0900 Subject: [PATCH 2/4] cargo fmt --all --- tests/route_middleware.rs | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/route_middleware.rs b/tests/route_middleware.rs index 2f49d71e1..9425d3739 100644 --- a/tests/route_middleware.rs +++ b/tests/route_middleware.rs @@ -28,14 +28,12 @@ async fn echo_path(req: tide::Request) -> String { fn route_middleware() { let mut app = tide::new(); let mut foo_route = app.at("/foo"); - foo_route.middleware(TestMiddleware("foo")) - .get(echo_path); - foo_route.at("/bar") + foo_route.middleware(TestMiddleware("foo")).get(echo_path); + foo_route + .at("/bar") .middleware(TestMiddleware("bar")) .get(echo_path); - foo_route.post(echo_path) - .reset_middleware() - .put(echo_path); + foo_route.post(echo_path).reset_middleware().put(echo_path); let mut server = make_server(app.into_http_service()).unwrap(); let mut buf = Vec::new(); From a4ee7febd9b1fd2e6939495521cb46e539d4e1d2 Mon Sep 17 00:00:00 2001 From: Wonwoo Choi Date: Sun, 2 Feb 2020 16:29:13 +0900 Subject: [PATCH 3/4] Preserve per-route middleware on nested subroutes --- src/server/route.rs | 4 +- tests/route_middleware.rs | 77 +++++++++++++++++++++------------------ 2 files changed, 43 insertions(+), 38 deletions(-) diff --git a/src/server/route.rs b/src/server/route.rs index e4c250aff..237a49965 100644 --- a/src/server/route.rs +++ b/src/server/route.rs @@ -35,8 +35,6 @@ impl<'a, State: 'static> Route<'a, State> { } /// Extend the route with the given `path`. - /// - /// The returned route won't have any middleware applied. pub fn at<'b>(&'b mut self, path: &str) -> Route<'b, State> { let mut p = self.path.clone(); @@ -51,7 +49,7 @@ impl<'a, State: 'static> Route<'a, State> { Route { router: &mut self.router, path: p, - middleware: Vec::new(), + middleware: self.middleware.clone(), prefix: false, } } diff --git a/tests/route_middleware.rs b/tests/route_middleware.rs index 9425d3739..bc0ce3d6f 100644 --- a/tests/route_middleware.rs +++ b/tests/route_middleware.rs @@ -1,11 +1,15 @@ -use async_std::io::prelude::*; -use futures::executor::block_on; use futures::future::BoxFuture; use http_service::Body; use http_service_mock::make_server; use tide::Middleware; -struct TestMiddleware(&'static str); +struct TestMiddleware(&'static str, &'static str); + +impl TestMiddleware { + fn with_header_name(name: &'static str, value: &'static str) -> Self { + Self(name, value) + } +} impl Middleware for TestMiddleware { fn handle<'a>( @@ -15,7 +19,7 @@ impl Middleware for TestMiddleware { ) -> BoxFuture<'a, tide::Response> { Box::pin(async move { let res = next.run(req).await; - res.set_header("X-Tide-Test", self.0) + res.set_header(self.0, self.1) }) } } @@ -28,52 +32,55 @@ async fn echo_path(req: tide::Request) -> String { fn route_middleware() { let mut app = tide::new(); let mut foo_route = app.at("/foo"); - foo_route.middleware(TestMiddleware("foo")).get(echo_path); + foo_route // /foo + .middleware(TestMiddleware::with_header_name("X-Foo", "foo")) + .get(echo_path); foo_route - .at("/bar") - .middleware(TestMiddleware("bar")) + .at("/bar") // nested, /foo/bar + .middleware(TestMiddleware::with_header_name("X-Bar", "bar")) .get(echo_path); - foo_route.post(echo_path).reset_middleware().put(echo_path); + foo_route // /foo + .post(echo_path) + .reset_middleware() + .put(echo_path); let mut server = make_server(app.into_http_service()).unwrap(); - let mut buf = Vec::new(); let req = http::Request::get("/foo").body(Body::empty()).unwrap(); let res = server.simulate(req).unwrap(); - assert_eq!( - res.headers().get("X-Tide-Test"), - Some(&"foo".parse().unwrap()) - ); - assert_eq!(res.status(), 200); - block_on(res.into_body().read_to_end(&mut buf)).unwrap(); - assert_eq!(&*buf, &*b"/foo"); + assert_eq!(res.headers().get("X-Foo"), Some(&"foo".parse().unwrap())); - buf.clear(); let req = http::Request::post("/foo").body(Body::empty()).unwrap(); let res = server.simulate(req).unwrap(); - assert_eq!( - res.headers().get("X-Tide-Test"), - Some(&"foo".parse().unwrap()) - ); - assert_eq!(res.status(), 200); - block_on(res.into_body().read_to_end(&mut buf)).unwrap(); - assert_eq!(&*buf, &*b"/foo"); + assert_eq!(res.headers().get("X-Foo"), Some(&"foo".parse().unwrap())); - buf.clear(); let req = http::Request::put("/foo").body(Body::empty()).unwrap(); let res = server.simulate(req).unwrap(); - assert_eq!(res.headers().get("X-Tide-Test"), None); - assert_eq!(res.status(), 200); - block_on(res.into_body().read_to_end(&mut buf)).unwrap(); - assert_eq!(&*buf, &*b"/foo"); + assert_eq!(res.headers().get("X-Foo"), None); - buf.clear(); let req = http::Request::get("/foo/bar").body(Body::empty()).unwrap(); let res = server.simulate(req).unwrap(); + assert_eq!(res.headers().get("X-Foo"), Some(&"foo".parse().unwrap())); + assert_eq!(res.headers().get("X-Bar"), Some(&"bar".parse().unwrap())); +} + +#[test] +fn subroute_not_nested() { + let mut app = tide::new(); + app.at("/parent") // /parent + .middleware(TestMiddleware::with_header_name("X-Parent", "Parent")) + .get(echo_path); + app.at("/parent/child") // /parent/child, not nested + .middleware(TestMiddleware::with_header_name("X-Child", "child")) + .get(echo_path); + let mut server = make_server(app.into_http_service()).unwrap(); + + let req = http::Request::get("/parent/child") + .body(Body::empty()) + .unwrap(); + let res = server.simulate(req).unwrap(); + assert_eq!(res.headers().get("X-Parent"), None); assert_eq!( - res.headers().get("X-Tide-Test"), - Some(&"bar".parse().unwrap()) + res.headers().get("X-Child"), + Some(&"child".parse().unwrap()) ); - assert_eq!(res.status(), 200); - block_on(res.into_body().read_to_end(&mut buf)).unwrap(); - assert_eq!(&*buf, &*b"/foo/bar"); } From e4f2f2d93ef01ed14f0ddb52d1b9eafc3372eeb7 Mon Sep 17 00:00:00 2001 From: Wonwoo Choi Date: Sun, 2 Feb 2020 16:43:16 +0900 Subject: [PATCH 4/4] Add more tests --- tests/route_middleware.rs | 64 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/tests/route_middleware.rs b/tests/route_middleware.rs index bc0ce3d6f..6eca04f31 100644 --- a/tests/route_middleware.rs +++ b/tests/route_middleware.rs @@ -63,6 +63,70 @@ fn route_middleware() { assert_eq!(res.headers().get("X-Bar"), Some(&"bar".parse().unwrap())); } +#[test] +fn app_and_route_middleware() { + let mut app = tide::new(); + app.middleware(TestMiddleware::with_header_name("X-Root", "root")); + app.at("/foo") + .middleware(TestMiddleware::with_header_name("X-Foo", "foo")) + .get(echo_path); + app.at("/bar") + .middleware(TestMiddleware::with_header_name("X-Bar", "bar")) + .get(echo_path); + let mut server = make_server(app.into_http_service()).unwrap(); + + let req = http::Request::get("/foo").body(Body::empty()).unwrap(); + let res = server.simulate(req).unwrap(); + assert_eq!(res.headers().get("X-Root"), Some(&"root".parse().unwrap())); + assert_eq!(res.headers().get("X-Foo"), Some(&"foo".parse().unwrap())); + assert_eq!(res.headers().get("X-Bar"), None); + + let req = http::Request::get("/bar").body(Body::empty()).unwrap(); + let res = server.simulate(req).unwrap(); + assert_eq!(res.headers().get("X-Root"), Some(&"root".parse().unwrap())); + assert_eq!(res.headers().get("X-Foo"), None); + assert_eq!(res.headers().get("X-Bar"), Some(&"bar".parse().unwrap())); +} + +#[test] +fn nested_app_with_route_middleware() { + let mut inner = tide::new(); + inner.middleware(TestMiddleware::with_header_name("X-Inner", "inner")); + inner + .at("/baz") + .middleware(TestMiddleware::with_header_name("X-Baz", "baz")) + .get(echo_path); + + let mut app = tide::new(); + app.middleware(TestMiddleware::with_header_name("X-Root", "root")); + app.at("/foo") + .middleware(TestMiddleware::with_header_name("X-Foo", "foo")) + .get(echo_path); + app.at("/bar") + .middleware(TestMiddleware::with_header_name("X-Bar", "bar")) + .nest(inner); + let mut server = make_server(app.into_http_service()).unwrap(); + + let req = http::Request::get("/foo").body(Body::empty()).unwrap(); + let res = server.simulate(req).unwrap(); + assert_eq!(res.headers().get("X-Root"), Some(&"root".parse().unwrap())); + assert_eq!(res.headers().get("X-Inner"), None); + assert_eq!(res.headers().get("X-Foo"), Some(&"foo".parse().unwrap())); + assert_eq!(res.headers().get("X-Bar"), None); + assert_eq!(res.headers().get("X-Baz"), None); + + let req = http::Request::get("/bar/baz").body(Body::empty()).unwrap(); + let res = server.simulate(req).unwrap(); + assert_eq!(res.headers().get("X-Root"), Some(&"root".parse().unwrap())); + assert_eq!( + res.headers().get("X-Inner"), + Some(&"inner".parse().unwrap()) + ); + assert_eq!(res.headers().get("X-Foo"), None); + assert_eq!(res.headers().get("X-Bar"), Some(&"bar".parse().unwrap())); + assert_eq!(res.headers().get("X-Baz"), Some(&"baz".parse().unwrap())); +} + #[test] fn subroute_not_nested() { let mut app = tide::new();