Skip to content

Commit

Permalink
Response: add optional error storage
Browse files Browse the repository at this point in the history
This allows errors to be propagated through the Tide middleware stack
while still keeping an existing Response intact. Effectively allowing
headers, body, etc to be passed along with an Error; while also opening
up options for nicer error handling via `res.downcast_error<>` checking.

Also contains APIs for turning an Error into a Response via Into, and
taking the error from / setting an error on an existing Response.

This is a breaking change in that `next.run().await` within middleware
no longer returns a tide::Result but rather always returns a
tide::Response.

Thanks to Jacob Rothstein for helping me get this to compile!

PR-URL: http-rs#570
  • Loading branch information
Fishrock123 committed Jul 7, 2020
1 parent e994b2e commit 1baff28
Show file tree
Hide file tree
Showing 17 changed files with 185 additions and 98 deletions.
8 changes: 3 additions & 5 deletions examples/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ fn user_loader<'a>(
if let Some(user) = request.state().find_user().await {
tide::log::trace!("user loaded", {user: user.name});
request.set_ext(user);
next.run(request).await
Ok(next.run(request).await)
// this middleware only needs to run before the endpoint, so
// it just passes through the result of Next
} else {
Expand Down Expand Up @@ -72,7 +72,7 @@ impl<State: Send + Sync + 'static> Middleware<State> for RequestCounterMiddlewar
tide::log::trace!("request counter", { count: count });
req.set_ext(RequestCount(count));

let mut res = next.run(req).await?;
let mut res = next.run(req).await;

res.insert_header("request-number", count.to_string());
Ok(res)
Expand Down Expand Up @@ -100,9 +100,7 @@ async fn main() -> Result<()> {
tide::log::start();
let mut app = tide::with_state(UserDatabase::default());

app.middleware(After(|result: Result| async move {
let response = result.unwrap_or_else(|e| Response::new(e.status()));

app.middleware(After(|response: Response| async move {
let response = match response.status() {
StatusCode::NotFound => Response::builder(404)
.content_type(mime::HTML)
Expand Down
2 changes: 1 addition & 1 deletion src/cookies/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ impl<State: Send + Sync + 'static> Middleware<State> for CookiesMiddleware {
content
};

let mut res = next.run(ctx).await?;
let mut res = next.run(ctx).await;

// Don't do anything if there are no cookies.
if res.cookie_events.is_empty() {
Expand Down
13 changes: 8 additions & 5 deletions src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,17 @@ use crate::{Middleware, Request, Response};
/// ```
///
/// Tide routes will also accept endpoints with `Fn` signatures of this form, but using the `async` keyword has better ergonomics.
pub trait Endpoint<State>: Send + Sync + 'static {
pub trait Endpoint<State: Send + Sync + 'static>: Send + Sync + 'static {
/// Invoke the endpoint within the given context
fn call<'a>(&'a self, req: Request<State>) -> BoxFuture<'a, crate::Result>;
}

pub(crate) type DynEndpoint<State> = dyn Endpoint<State>;

impl<State, F: Send + Sync + 'static, Fut, Res> Endpoint<State> for F
impl<State, F, Fut, Res> Endpoint<State> for F
where
F: Fn(Request<State>) -> Fut,
State: Send + Sync + 'static,
F: Send + Sync + 'static + Fn(Request<State>) -> Fut,
Fut: Future<Output = Result<Res>> + Send + 'static,
Res: Into<Response>,
{
Expand Down Expand Up @@ -92,6 +93,7 @@ impl<E, State> std::fmt::Debug for MiddlewareEndpoint<E, State> {

impl<E, State> MiddlewareEndpoint<E, State>
where
State: Send + Sync + 'static,
E: Endpoint<State>,
{
pub fn wrap_with_middleware(ep: E, middleware: &[Arc<dyn Middleware<State>>]) -> Self {
Expand All @@ -102,15 +104,16 @@ where
}
}

impl<E, State: 'static> Endpoint<State> for MiddlewareEndpoint<E, State>
impl<E, State> Endpoint<State> for MiddlewareEndpoint<E, State>
where
State: Send + Sync + 'static,
E: Endpoint<State>,
{
fn call<'a>(&'a self, req: Request<State>) -> BoxFuture<'a, crate::Result> {
let next = Next {
endpoint: &self.endpoint,
next_middleware: &self.middleware,
};
next.run(req)
Box::pin(async move { Ok(next.run(req).await) })
}
}
5 changes: 4 additions & 1 deletion src/fs/serve_dir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ impl ServeDir {
}
}

impl<State> Endpoint<State> for ServeDir {
impl<State> Endpoint<State> for ServeDir
where
State: Send + Sync + 'static,
{
fn call<'a>(&'a self, req: Request<State>) -> BoxFuture<'a, Result> {
let path = req.url().path();
let path = path.trim_start_matches(&self.prefix);
Expand Down
60 changes: 24 additions & 36 deletions src/log/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,43 +37,31 @@ impl LogMiddleware {
path: path,
});
let start = std::time::Instant::now();
match next.run(ctx).await {
Ok(res) => {
let status = res.status();
if status.is_server_error() {
log::error!("--> Response sent", {
method: method,
path: path,
status: status as u16,
duration: format!("{:?}", start.elapsed()),
});
} else if status.is_client_error() {
log::warn!("--> Response sent", {
method: method,
path: path,
status: status as u16,
duration: format!("{:?}", start.elapsed()),
});
} else {
log::info!("--> Response sent", {
method: method,
path: path,
status: status as u16,
duration: format!("{:?}", start.elapsed()),
});
}
Ok(res)
}
Err(err) => {
log::error!("{}", err.to_string(), {
method: method,
path: path,
status: err.status() as u16,
duration: format!("{:?}", start.elapsed()),
});
Err(err)
}
let response = next.run(ctx).await;
let status = response.status();
if status.is_server_error() {
log::error!("--> Response sent", {
method: method,
path: path,
status: status as u16,
duration: format!("{:?}", start.elapsed()),
});
} else if status.is_client_error() {
log::warn!("--> Response sent", {
method: method,
path: path,
status: status as u16,
duration: format!("{:?}", start.elapsed()),
});
} else {
log::info!("--> Response sent", {
method: method,
path: path,
status: status as u16,
duration: format!("{:?}", start.elapsed()),
});
}
Ok(response)
}
}

Expand Down
22 changes: 12 additions & 10 deletions src/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ use std::sync::Arc;

use crate::endpoint::DynEndpoint;
use crate::utils::BoxFuture;
use crate::Request;
use crate::{Request, Response};

/// Middleware that wraps around the remaining middleware chain.
pub trait Middleware<State>: 'static + Send + Sync {
pub trait Middleware<State>: Send + Sync + 'static {
/// Asynchronously handle the request, and return a response.
fn handle<'a>(
&'a self,
Expand Down Expand Up @@ -44,15 +44,17 @@ pub struct Next<'a, State> {
pub(crate) next_middleware: &'a [Arc<dyn Middleware<State>>],
}

impl<'a, State: 'static> Next<'a, State> {
impl<'a, State: Send + Sync + 'static> Next<'a, State> {
/// Asynchronously execute the remaining middleware chain.
#[must_use]
pub fn run(mut self, req: Request<State>) -> BoxFuture<'a, crate::Result> {
if let Some((current, next)) = self.next_middleware.split_first() {
self.next_middleware = next;
current.handle(req, self)
} else {
self.endpoint.call(req)
}
pub fn run(mut self, req: Request<State>) -> BoxFuture<'a, Response> {
Box::pin(async move {
if let Some((current, next)) = self.next_middleware.split_first() {
self.next_middleware = next;
current.handle(req, self).await.into()
} else {
self.endpoint.call(req).await.into()
}
})
}
}
1 change: 1 addition & 0 deletions src/redirect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ impl<T: AsRef<str>> Redirect<T> {

impl<State, T> Endpoint<State> for Redirect<T>
where
State: Send + Sync + 'static,
T: AsRef<str> + Send + Sync + 'static,
{
fn call<'a>(&'a self, _req: Request<State>) -> BoxFuture<'a, crate::Result<Response>> {
Expand Down
76 changes: 73 additions & 3 deletions src/response.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use std::convert::TryInto;
use std::fmt::Debug;
use std::fmt::{Debug, Display};
use std::ops::Index;

use crate::http::cookies::Cookie;
use crate::http::headers::{self, HeaderName, HeaderValues, ToHeaderValues};
use crate::http::Mime;
use crate::http::{self, Body, StatusCode};
use crate::http::{self, Body, Error, Mime, StatusCode};
use crate::ResponseBuilder;

#[derive(Debug)]
Expand All @@ -18,6 +17,7 @@ pub(crate) enum CookieEvent {
#[derive(Debug)]
pub struct Response {
pub(crate) res: http::Response,
pub(crate) error: Option<Error>,
// tracking here
pub(crate) cookie_events: Vec<CookieEvent>,
}
Expand All @@ -33,6 +33,7 @@ impl Response {
let res = http::Response::new(status);
Self {
res,
error: None,
cookie_events: vec![],
}
}
Expand Down Expand Up @@ -257,6 +258,54 @@ impl Response {
self.cookie_events.push(CookieEvent::Removed(cookie));
}

/// Returns an optional reference to an error if the response contains one.
pub fn error(&self) -> Option<&Error> {
self.error.as_ref()
}

/// Returns a reference to the original error associated with this response if there is one and
/// if it can be downcast to the specified type.
///
/// # Example
///
/// ```
/// # use std::io::ErrorKind;
/// # use async_std::task::block_on;
/// # fn main() -> Result<(), std::io::Error> { block_on(async {
/// #
/// use tide::Response;
///
/// let error = std::io::Error::new(ErrorKind::Other, "oh no!");
/// let error = tide::http::Error::from(error);
///
/// let mut res = Response::new(400);
/// res.set_error(error);
///
/// if let Some(err) = res.downcast_error::<std::io::Error>() {
/// // Do something with the `std::io::Error`.
/// }
/// # Ok(())
/// # })}
pub fn downcast_error<E>(&self) -> Option<&E>
where
E: Display + Debug + Send + Sync + 'static,
{
self.error.as_ref()?.downcast_ref()
}

/// Takes the error from the response if one exists, replacing it with `None`.
pub fn take_error(&mut self) -> Option<Error> {
self.error.take()
}

/// Sets the response's error, overwriting any existing error.
///
/// This is particularly useful for middleware which would like to notify further
/// middleware that an error has occured without overwriting the existing response.
pub fn set_error(&mut self, error: impl Into<Error>) {
self.error = Some(error.into());
}

/// Get a response scoped extension value.
#[must_use]
pub fn ext<T: Send + Sync + 'static>(&self) -> Option<&T> {
Expand All @@ -277,6 +326,7 @@ impl Response {
let res: http_types::Response = value.into();
Self {
res,
error: None,
cookie_events: vec![],
}
}
Expand Down Expand Up @@ -328,10 +378,30 @@ impl From<serde_json::Value> for Response {
}
}

impl From<Error> for Response {
fn from(err: Error) -> Self {
Self {
res: http::Response::new(err.status()),
error: Some(err),
cookie_events: vec![],
}
}
}

impl From<crate::Result> for Response {
fn from(result: crate::Result) -> Self {
match result {
Ok(res) => res,
Err(err) => err.into(),
}
}
}

impl From<http::Response> for Response {
fn from(res: http::Response) -> Self {
Self {
res,
error: None,
cookie_events: vec![],
}
}
Expand Down
8 changes: 6 additions & 2 deletions src/route.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub struct Route<'a, State> {
prefix: bool,
}

impl<'a, State: 'static> Route<'a, State> {
impl<'a, State: Send + Sync + 'static> Route<'a, State> {
pub(crate) fn new(router: &'a mut Router<State>, path: String) -> Route<'a, State> {
Route {
router,
Expand Down Expand Up @@ -274,7 +274,11 @@ impl<E> Clone for StripPrefixEndpoint<E> {
}
}

impl<State, E: Endpoint<State>> Endpoint<State> for StripPrefixEndpoint<E> {
impl<State, E> Endpoint<State> for StripPrefixEndpoint<E>
where
State: Send + Sync + 'static,
E: Endpoint<State>,
{
fn call<'a>(&'a self, req: crate::Request<State>) -> BoxFuture<'a, crate::Result> {
let crate::Request {
state,
Expand Down
10 changes: 7 additions & 3 deletions src/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub struct Selection<'a, State> {
pub(crate) params: Params,
}

impl<State: 'static> Router<State> {
impl<State: Send + Sync + 'static> Router<State> {
pub fn new() -> Self {
Router {
method_map: HashMap::default(),
Expand Down Expand Up @@ -82,10 +82,14 @@ impl<State: 'static> Router<State> {
}
}

fn not_found_endpoint<State>(_req: Request<State>) -> BoxFuture<'static, crate::Result> {
fn not_found_endpoint<State: Send + Sync + 'static>(
_req: Request<State>,
) -> BoxFuture<'static, crate::Result> {
Box::pin(async { Ok(Response::new(StatusCode::NotFound)) })
}

fn method_not_allowed<State>(_req: Request<State>) -> BoxFuture<'static, crate::Result> {
fn method_not_allowed<State: Send + Sync + 'static>(
_req: Request<State>,
) -> BoxFuture<'static, crate::Result> {
Box::pin(async { Ok(Response::new(StatusCode::MethodNotAllowed)) })
}
Loading

0 comments on commit 1baff28

Please sign in to comment.