diff --git a/src/log/middleware.rs b/src/log/middleware.rs index 7ce87e723..716df4ebe 100644 --- a/src/log/middleware.rs +++ b/src/log/middleware.rs @@ -16,6 +16,8 @@ pub struct LogMiddleware { _priv: (), } +struct LogMiddlewareHasBeenRun; + impl LogMiddleware { /// Create a new instance of `LogMiddleware`. #[must_use] @@ -26,17 +28,22 @@ impl LogMiddleware { /// Log a request and a response. async fn log<'a, State: Clone + Send + Sync + 'static>( &'a self, - ctx: Request, + mut req: Request, next: Next<'a, State>, ) -> crate::Result { - let path = ctx.url().path().to_owned(); - let method = ctx.method().to_string(); + if req.ext::().is_some() { + return Ok(next.run(req).await); + } + req.set_ext(LogMiddlewareHasBeenRun); + + let path = req.url().path().to_owned(); + let method = req.method().to_string(); log::info!("<-- Request received", { method: method, path: path, }); let start = std::time::Instant::now(); - let response = next.run(ctx).await; + let response = next.run(req).await; let status = response.status(); if status.is_server_error() { if let Some(error) = response.error() { diff --git a/tests/log.rs b/tests/log.rs index 7c4fb6c7b..d6ddb9b74 100644 --- a/tests/log.rs +++ b/tests/log.rs @@ -2,11 +2,16 @@ use async_std::prelude::*; use std::time::Duration; mod test_utils; +use test_utils::ServerTestingExt; #[async_std::test] -async fn start_server_log() { +async fn log_tests() { let mut logger = logtest::start(); + test_server_listen(&mut logger).await; + test_only_log_once(&mut logger).await; +} +async fn test_server_listen(logger: &mut logtest::Logger) { let port = test_utils::find_port().await; let app = tide::new(); let res = app @@ -23,3 +28,31 @@ async fn start_server_log() { format!("Server listening on http://[::1]:{}", port) ); } + +async fn test_only_log_once(logger: &mut logtest::Logger) { + let mut app = tide::new(); + app.at("/").nest({ + let mut app = tide::new(); + app.at("/").get(|_| async { Ok("nested") }); + app + }); + app.get("/").await; + + let entries: Vec<_> = logger.collect(); + + assert_eq!( + 1, + entries + .iter() + .filter(|entry| entry.args() == "<-- Request received") + .count() + ); + + assert_eq!( + 1, + entries + .iter() + .filter(|entry| entry.args() == "--> Response sent") + .count() + ); +}