Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Per-route middleware #399

Merged
merged 4 commits into from
Feb 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion src/endpoint.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::sync::Arc;

use async_std::future::Future;

use crate::middleware::Next;
use crate::utils::BoxFuture;
use crate::{response::IntoResponse, Request, Response};
use crate::{response::IntoResponse, Middleware, Request, Response};

/// An HTTP request handler.
///
Expand Down Expand Up @@ -63,3 +66,52 @@ where
Box::pin(async move { fut.await.into_response() })
}
}

pub struct MiddlewareEndpoint<E, State> {
endpoint: E,
middleware: Vec<Arc<dyn Middleware<State>>>,
}

impl<E: Clone, State> Clone for MiddlewareEndpoint<E, State> {
fn clone(&self) -> Self {
Self {
endpoint: self.endpoint.clone(),
middleware: self.middleware.clone(),
}
}
}

impl<E, State> std::fmt::Debug for MiddlewareEndpoint<E, State> {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
fmt,
"MiddlewareEndpoint (length: {})",
self.middleware.len(),
)
}
}

impl<E, State> MiddlewareEndpoint<E, State>
where
E: Endpoint<State>,
{
pub fn wrap_with_middleware(ep: E, middleware: &[Arc<dyn Middleware<State>>]) -> Self {
Self {
endpoint: ep,
middleware: middleware.to_vec(),
}
}
}

impl<E, State: 'static> Endpoint<State> for MiddlewareEndpoint<E, State>
where
E: Endpoint<State>,
{
fn call<'a>(&'a self, req: Request<State>) -> BoxFuture<'a, Response> {
let next = Next {
endpoint: &self.endpoint,
next_middleware: &self.middleware,
};
next.run(req)
}
}
10 changes: 5 additions & 5 deletions src/router.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use route_recognizer::{Match, Params, Router as MethodRouter};
use std::collections::HashMap;

use crate::endpoint::{DynEndpoint, Endpoint};
use crate::endpoint::DynEndpoint;
use crate::utils::BoxFuture;
use crate::{Request, Response};

Expand Down Expand Up @@ -29,15 +29,15 @@ impl<State: 'static> Router<State> {
}
}

pub(crate) fn add(&mut self, path: &str, method: http::Method, ep: impl Endpoint<State>) {
pub(crate) fn add(&mut self, path: &str, method: http::Method, ep: Box<DynEndpoint<State>>) {
self.method_map
.entry(method)
.or_insert_with(MethodRouter::new)
.add(path, Box::new(ep))
.add(path, ep)
}

pub(crate) fn add_all(&mut self, path: &str, ep: impl Endpoint<State>) {
self.all_method_router.add(path, Box::new(ep))
pub(crate) fn add_all(&mut self, path: &str, ep: Box<DynEndpoint<State>>) {
self.all_method_router.add(path, ep)
}

pub(crate) fn route(&self, path: &str, method: http::Method) -> Selection<'_, State> {
Expand Down
66 changes: 61 additions & 5 deletions src/server/route.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use std::sync::Arc;

use crate::endpoint::MiddlewareEndpoint;
use crate::utils::BoxFuture;
use crate::{router::Router, Endpoint, Response};
use crate::{router::Router, Endpoint, Middleware, Response};

/// A handle to a route.
///
Expand All @@ -13,6 +16,7 @@ use crate::{router::Router, Endpoint, Response};
pub struct Route<'a, State> {
router: &'a mut Router<State>,
path: String,
middleware: Vec<Arc<dyn Middleware<State>>>,
/// Indicates whether the path of current route is treated as a prefix. Set by
/// [`strip_prefix`].
///
Expand All @@ -25,6 +29,7 @@ impl<'a, State: 'static> Route<'a, State> {
Route {
router,
path,
middleware: Vec::new(),
prefix: false,
}
}
Expand All @@ -44,6 +49,7 @@ impl<'a, State: 'static> Route<'a, State> {
Route {
router: &mut self.router,
path: p,
middleware: self.middleware.clone(),
prefix: false,
}
}
Expand All @@ -60,6 +66,18 @@ impl<'a, State: 'static> Route<'a, State> {
self
}

/// Apply the given middleware to the current route.
pub fn middleware(&mut self, middleware: impl Middleware<State>) -> &mut Self {
self.middleware.push(Arc::new(middleware));
self
}

/// Reset the middleware chain for the current route, if any.
pub fn reset_middleware(&mut self) -> &mut Self {
self.middleware.clear();
self
}

/// Nest a [`Server`] at the current path.
///
/// [`Server`]: struct.Server.html
Expand All @@ -78,10 +96,29 @@ impl<'a, State: 'static> Route<'a, State> {
pub fn method(&mut self, method: http::Method, ep: impl Endpoint<State>) -> &mut Self {
if self.prefix {
let ep = StripPrefixEndpoint::new(ep);
self.router.add(&self.path, method.clone(), ep.clone());
let (ep1, ep2): (Box<dyn Endpoint<_>>, Box<dyn Endpoint<_>>) =
if self.middleware.is_empty() {
let ep = Box::new(ep);
(ep.clone(), ep)
} else {
let ep = Box::new(MiddlewareEndpoint::wrap_with_middleware(
ep,
&self.middleware,
));
(ep.clone(), ep)
};
self.router.add(&self.path, method.clone(), ep1);
let wildcard = self.at("*--tide-path-rest");
wildcard.router.add(&wildcard.path, method, ep);
wildcard.router.add(&wildcard.path, method, ep2);
} else {
let ep: Box<dyn Endpoint<_>> = if self.middleware.is_empty() {
Box::new(ep)
} else {
Box::new(MiddlewareEndpoint::wrap_with_middleware(
ep,
&self.middleware,
))
};
self.router.add(&self.path, method, ep);
}
self
Expand All @@ -93,10 +130,29 @@ impl<'a, State: 'static> Route<'a, State> {
pub fn all(&mut self, ep: impl Endpoint<State>) -> &mut Self {
if self.prefix {
let ep = StripPrefixEndpoint::new(ep);
self.router.add_all(&self.path, ep.clone());
let (ep1, ep2): (Box<dyn Endpoint<_>>, Box<dyn Endpoint<_>>) =
if self.middleware.is_empty() {
let ep = Box::new(ep);
(ep.clone(), ep)
} else {
let ep = Box::new(MiddlewareEndpoint::wrap_with_middleware(
ep,
&self.middleware,
));
(ep.clone(), ep)
};
self.router.add_all(&self.path, ep1);
let wildcard = self.at("*--tide-path-rest");
wildcard.router.add_all(&wildcard.path, ep);
wildcard.router.add_all(&wildcard.path, ep2);
} else {
let ep: Box<dyn Endpoint<_>> = if self.middleware.is_empty() {
Box::new(ep)
} else {
Box::new(MiddlewareEndpoint::wrap_with_middleware(
ep,
&self.middleware,
))
};
self.router.add_all(&self.path, ep);
}
self
Expand Down
150 changes: 150 additions & 0 deletions tests/route_middleware.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
use futures::future::BoxFuture;
use http_service::Body;
use http_service_mock::make_server;
use tide::Middleware;

struct TestMiddleware(&'static str, &'static str);

impl TestMiddleware {
fn with_header_name(name: &'static str, value: &'static str) -> Self {
Self(name, value)
}
}

impl<State: Send + Sync + 'static> Middleware<State> for TestMiddleware {
fn handle<'a>(
&'a self,
req: tide::Request<State>,
next: tide::Next<'a, State>,
) -> BoxFuture<'a, tide::Response> {
Box::pin(async move {
let res = next.run(req).await;
res.set_header(self.0, self.1)
})
}
}

async fn echo_path<State>(req: tide::Request<State>) -> String {
req.uri().path().to_string()
}

#[test]
fn route_middleware() {
let mut app = tide::new();
let mut foo_route = app.at("/foo");
foo_route // /foo
.middleware(TestMiddleware::with_header_name("X-Foo", "foo"))
.get(echo_path);
foo_route
.at("/bar") // nested, /foo/bar
.middleware(TestMiddleware::with_header_name("X-Bar", "bar"))
.get(echo_path);
foo_route // /foo
.post(echo_path)
.reset_middleware()
.put(echo_path);
let mut server = make_server(app.into_http_service()).unwrap();

let req = http::Request::get("/foo").body(Body::empty()).unwrap();
let res = server.simulate(req).unwrap();
assert_eq!(res.headers().get("X-Foo"), Some(&"foo".parse().unwrap()));

let req = http::Request::post("/foo").body(Body::empty()).unwrap();
let res = server.simulate(req).unwrap();
assert_eq!(res.headers().get("X-Foo"), Some(&"foo".parse().unwrap()));

let req = http::Request::put("/foo").body(Body::empty()).unwrap();
let res = server.simulate(req).unwrap();
assert_eq!(res.headers().get("X-Foo"), None);

let req = http::Request::get("/foo/bar").body(Body::empty()).unwrap();
let res = server.simulate(req).unwrap();
assert_eq!(res.headers().get("X-Foo"), Some(&"foo".parse().unwrap()));
assert_eq!(res.headers().get("X-Bar"), Some(&"bar".parse().unwrap()));
}

#[test]
fn app_and_route_middleware() {
let mut app = tide::new();
app.middleware(TestMiddleware::with_header_name("X-Root", "root"));
app.at("/foo")
.middleware(TestMiddleware::with_header_name("X-Foo", "foo"))
.get(echo_path);
app.at("/bar")
.middleware(TestMiddleware::with_header_name("X-Bar", "bar"))
.get(echo_path);
let mut server = make_server(app.into_http_service()).unwrap();

let req = http::Request::get("/foo").body(Body::empty()).unwrap();
let res = server.simulate(req).unwrap();
assert_eq!(res.headers().get("X-Root"), Some(&"root".parse().unwrap()));
assert_eq!(res.headers().get("X-Foo"), Some(&"foo".parse().unwrap()));
assert_eq!(res.headers().get("X-Bar"), None);

let req = http::Request::get("/bar").body(Body::empty()).unwrap();
let res = server.simulate(req).unwrap();
assert_eq!(res.headers().get("X-Root"), Some(&"root".parse().unwrap()));
assert_eq!(res.headers().get("X-Foo"), None);
assert_eq!(res.headers().get("X-Bar"), Some(&"bar".parse().unwrap()));
}

#[test]
fn nested_app_with_route_middleware() {
let mut inner = tide::new();
inner.middleware(TestMiddleware::with_header_name("X-Inner", "inner"));
inner
.at("/baz")
.middleware(TestMiddleware::with_header_name("X-Baz", "baz"))
.get(echo_path);

let mut app = tide::new();
app.middleware(TestMiddleware::with_header_name("X-Root", "root"));
app.at("/foo")
.middleware(TestMiddleware::with_header_name("X-Foo", "foo"))
.get(echo_path);
app.at("/bar")
.middleware(TestMiddleware::with_header_name("X-Bar", "bar"))
.nest(inner);
let mut server = make_server(app.into_http_service()).unwrap();

let req = http::Request::get("/foo").body(Body::empty()).unwrap();
let res = server.simulate(req).unwrap();
assert_eq!(res.headers().get("X-Root"), Some(&"root".parse().unwrap()));
assert_eq!(res.headers().get("X-Inner"), None);
assert_eq!(res.headers().get("X-Foo"), Some(&"foo".parse().unwrap()));
assert_eq!(res.headers().get("X-Bar"), None);
assert_eq!(res.headers().get("X-Baz"), None);

let req = http::Request::get("/bar/baz").body(Body::empty()).unwrap();
let res = server.simulate(req).unwrap();
assert_eq!(res.headers().get("X-Root"), Some(&"root".parse().unwrap()));
assert_eq!(
res.headers().get("X-Inner"),
Some(&"inner".parse().unwrap())
);
assert_eq!(res.headers().get("X-Foo"), None);
assert_eq!(res.headers().get("X-Bar"), Some(&"bar".parse().unwrap()));
assert_eq!(res.headers().get("X-Baz"), Some(&"baz".parse().unwrap()));
}

#[test]
fn subroute_not_nested() {
let mut app = tide::new();
app.at("/parent") // /parent
.middleware(TestMiddleware::with_header_name("X-Parent", "Parent"))
.get(echo_path);
app.at("/parent/child") // /parent/child, not nested
.middleware(TestMiddleware::with_header_name("X-Child", "child"))
.get(echo_path);
let mut server = make_server(app.into_http_service()).unwrap();

let req = http::Request::get("/parent/child")
.body(Body::empty())
.unwrap();
let res = server.simulate(req).unwrap();
assert_eq!(res.headers().get("X-Parent"), None);
assert_eq!(
res.headers().get("X-Child"),
Some(&"child".parse().unwrap())
);
}