From 2a7d47adda559764c7b4ba8dc64b44ead0385def Mon Sep 17 00:00:00 2001 From: Eliza Weisman Date: Wed, 6 Jan 2021 12:18:20 -0800 Subject: [PATCH] rewrite `tower::filter` (#508) ## Motivation It was pointed out that there is currently some overlap between the `try_with` `Service` combinator and `tower::filter` middleware (see https://github.com/tower-rs/tower/pull/499#discussion_r549522471 ). `try_with` synchronously maps from a `Request` -> `Result`, while `tower::filter` _asynchronously_ maps from a `&Request` to a `Result<(), Error>`. The key differences are: - `try_with` takes a request by value, and allows the predicate to return a *different* request value - `try_with` also permits changing the _type_ of the request - `try_with` is synchronous, while `tower::filter` is asynchronous - `tower::filter` has a `Predicate` trait, which can be implemented by more than just functions. For example, a struct with a `HashSet` could implement `Predicate` by failing requests that match the values in the hashset. It definitely seems like there's demand for both synchronous and asynchronous request filtering. However, the APIs we have currently differ pretty significantly. It would be nice to make them more consistent with each other. As an aside, `tower::filter` [does not seem all that widely used][1]. Meanwhile, `linkerd2-proxy` defines its own `RequestFilter` middleware, using a [predicate trait][2] that's essentially in between `tower::filter` and `ServiceExt::try_with`: - it's synchronous, like `try_with` - it allows modifying the type of the request, like `try_with` - it uses a trait for predicates, rather than a `Fn`, like `tower::filter` - it uses a similar naming scheme to `tower::filter` ("filtering" rather than "with"/"map"). [1]: https://github.com/search?l=&p=1&q=%22tower%3A%3Afilter%22+extension%3Ars&ref=advsearch&type=Code [2]: https://github.com/linkerd/linkerd2-proxy/blob/24bee8cbc5413b4587a14bea1e2714ce1f1f919a/linkerd/stack/src/request_filter.rs#L8-L12 ## Solution This branch rewrites `tower::filter` to make the following changes: * Predicates are synchronous by default. A separate `AsyncFilter` type and an `AsyncPredicate` trait are available for predicates returning futures. * Predicates may now return a new `Request` type, allowing `Filter` and `AsyncFilter` to subsume `try_map_request`. * Predicates may now return any error type, and errors are now converted into `BoxError`s. Closes #502 Signed-off-by: Eliza Weisman --- tower/Cargo.toml | 2 +- tower/src/builder/mod.rs | 57 +++++---- tower/src/filter/error.rs | 46 ------- tower/src/filter/future.rs | 83 ++++++------ tower/src/filter/layer.rs | 56 ++++++++- tower/src/filter/mod.rs | 118 +++++++++++++++--- tower/src/filter/predicate.rs | 56 +++++++-- tower/src/hedge/mod.rs | 19 +-- tower/src/lib.rs | 4 +- tower/src/macros.rs | 24 ++++ tower/src/util/mod.rs | 113 +++++++++++++++-- tower/src/util/try_map_request.rs | 70 ----------- .../tests/filter/{main.rs => async_filter.rs} | 6 +- 13 files changed, 423 insertions(+), 231 deletions(-) delete mode 100644 tower/src/filter/error.rs create mode 100644 tower/src/macros.rs delete mode 100644 tower/src/util/try_map_request.rs rename tower/tests/filter/{main.rs => async_filter.rs} (91%) diff --git a/tower/Cargo.toml b/tower/Cargo.toml index accf38851..d85d93a62 100644 --- a/tower/Cargo.toml +++ b/tower/Cargo.toml @@ -29,7 +29,7 @@ log = ["tracing/log"] balance = ["discover", "load", "ready-cache", "make", "rand", "slab", "tokio-stream"] buffer = ["tokio/sync", "tokio/rt", "tokio-stream"] discover = [] -filter = [] +filter = ["futures-util"] hedge = ["util", "filter", "futures-util", "hdrhistogram", "tokio/time"] limit = ["tokio/time", "tokio/sync"] load = ["tokio/time"] diff --git a/tower/src/builder/mod.rs b/tower/src/builder/mod.rs index 3efd091e7..03ea919aa 100644 --- a/tower/src/builder/mod.rs +++ b/tower/src/builder/mod.rs @@ -239,6 +239,42 @@ impl ServiceBuilder { self.layer(crate::timeout::TimeoutLayer::new(timeout)) } + /// Conditionally reject requests based on `predicate`. + /// + /// `predicate` must implement the [`Predicate`] trait. + /// + /// This wraps the inner service with an instance of the [`Filter`] + /// middleware. + /// + /// [`Filter`]: crate::filter + /// [`Predicate`]: crate::filter::Predicate + #[cfg(feature = "filter")] + #[cfg_attr(docsrs, doc(cfg(feature = "filter")))] + pub fn filter

( + self, + predicate: P, + ) -> ServiceBuilder, L>> { + self.layer(crate::filter::FilterLayer::new(predicate)) + } + + /// Conditionally reject requests based on an asynchronous `predicate`. + /// + /// `predicate` must implement the [`AsyncPredicate`] trait. + /// + /// This wraps the inner service with an instance of the [`AsyncFilter`] + /// middleware. + /// + /// [`AsyncFilter`]: crate::filter::AsyncFilter + /// [`AsyncPredicate`]: crate::filter::AsyncPredicate + #[cfg(feature = "filter")] + #[cfg_attr(docsrs, doc(cfg(feature = "filter")))] + pub fn filter_async

( + self, + predicate: P, + ) -> ServiceBuilder, L>> { + self.layer(crate::filter::AsyncFilterLayer::new(predicate)) + } + /// Map one request type to another. /// /// This wraps the inner service with an instance of the [`MapRequest`] @@ -311,27 +347,6 @@ impl ServiceBuilder { self.layer(crate::util::MapRequestLayer::new(f)) } - /// Fallibly one request type to another, or to an error. - /// - /// This wraps the inner service with an instance of the [`TryMapRequest`] - /// middleware. - /// - /// See the documentation for the [`try_map_request` combinator] for details. - /// - /// [`TryMapRequest`]: crate::util::MapResponse - /// [`try_map_request` combinator]: crate::util::ServiceExt::try_map_request - #[cfg(feature = "util")] - #[cfg_attr(docsrs, doc(cfg(feature = "util")))] - pub fn try_map_request( - self, - f: F, - ) -> ServiceBuilder, L>> - where - F: FnMut(R1) -> Result + Clone, - { - self.layer(crate::util::TryMapRequestLayer::new(f)) - } - /// Map one response type to another. /// /// This wraps the inner service with an instance of the [`MapResponse`] diff --git a/tower/src/filter/error.rs b/tower/src/filter/error.rs deleted file mode 100644 index 548100d31..000000000 --- a/tower/src/filter/error.rs +++ /dev/null @@ -1,46 +0,0 @@ -//! Error types - -use std::{error, fmt}; - -/// Error produced by `Filter` -#[derive(Debug)] -pub struct Error { - source: Option, -} - -impl Error { - /// Create a new `Error` representing a rejected request. - pub fn rejected() -> Error { - Error { source: None } - } - - /// Create a new `Error` representing an inner service error. - pub fn inner(source: E) -> Error - where - E: Into, - { - Error { - source: Some(source.into()), - } - } -} - -impl fmt::Display for Error { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - if self.source.is_some() { - write!(fmt, "inner service errored") - } else { - write!(fmt, "rejected") - } - } -} - -impl error::Error for Error { - fn source(&self) -> Option<&(dyn error::Error + 'static)> { - if let Some(ref err) = self.source { - Some(&**err) - } else { - None - } - } -} diff --git a/tower/src/filter/future.rs b/tower/src/filter/future.rs index d4b3aa8a1..7577887a1 100644 --- a/tower/src/filter/future.rs +++ b/tower/src/filter/future.rs @@ -1,6 +1,7 @@ //! Future types -use super::error::Error; +use super::AsyncPredicate; +use crate::BoxError; use futures_core::ready; use pin_project::pin_project; use std::{ @@ -10,79 +11,77 @@ use std::{ }; use tower_service::Service; -/// Filtered response future +/// Filtered response future from [`AsyncFilter`] services. +/// +/// [`AsyncFilter`]: crate::filter::AsyncFilter #[pin_project] #[derive(Debug)] -pub struct ResponseFuture +pub struct AsyncResponseFuture where - S: Service, + P: AsyncPredicate, + S: Service, { #[pin] - /// Response future state - state: State, - - #[pin] - /// Predicate future - check: T, + state: State, /// Inner service service: S, } +opaque_future! { + /// Filtered response future from [`Filter`] services. + /// + /// [`Filter`]: crate::filter::Filter + pub type ResponseFuture = + futures_util::future::Either< + futures_util::future::Ready>, + futures_util::future::ErrInto + >; +} + #[pin_project(project = StateProj)] #[derive(Debug)] -enum State { - Check(Option), - WaitResponse(#[pin] U), +enum State { + /// Waiting for the predicate future + Check(#[pin] F), + /// Waiting for the response future + WaitResponse(#[pin] G), } -impl ResponseFuture +impl AsyncResponseFuture where - F: Future>, - S: Service, - S::Error: Into, + P: AsyncPredicate, + S: Service, + S::Error: Into, { - pub(crate) fn new(request: Request, check: F, service: S) -> Self { - ResponseFuture { - state: State::Check(Some(request)), - check, + pub(crate) fn new(check: P::Future, service: S) -> Self { + Self { + state: State::Check(check), service, } } } -impl Future for ResponseFuture +impl Future for AsyncResponseFuture where - F: Future>, - S: Service, + P: AsyncPredicate, + S: Service, S::Error: Into, { - type Output = Result; + type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut this = self.project(); loop { match this.state.as_mut().project() { - StateProj::Check(request) => { - let request = request - .take() - .expect("we either give it back or leave State::Check once we take"); - - // Poll predicate - match this.check.as_mut().poll(cx)? { - Poll::Ready(_) => { - let response = this.service.call(request); - this.state.set(State::WaitResponse(response)); - } - Poll::Pending => { - this.state.set(State::Check(Some(request))); - return Poll::Pending; - } - } + StateProj::Check(mut check) => { + let request = ready!(check.as_mut().poll(cx))?; + let response = this.service.call(request); + this.state.set(State::WaitResponse(response)); } StateProj::WaitResponse(response) => { - return Poll::Ready(ready!(response.poll(cx)).map_err(Error::inner)); + return response.poll(cx).map_err(Into::into); } } } diff --git a/tower/src/filter/layer.rs b/tower/src/filter/layer.rs index 92f368e20..d86408954 100644 --- a/tower/src/filter/layer.rs +++ b/tower/src/filter/layer.rs @@ -1,16 +1,42 @@ -use super::Filter; +use super::{AsyncFilter, Filter}; use tower_layer::Layer; -/// Conditionally dispatch requests to the inner service based on a predicate. +/// Conditionally dispatch requests to the inner service based on a synchronous +/// [predicate]. +/// +/// This [`Layer`] produces instances of the [`Filter`] service. +/// +/// [predicate]: crate::filter::Predicate +/// [`Layer`]: crate::Layer +/// [`Filter`]: crate::filter::Filter #[derive(Debug)] pub struct FilterLayer { predicate: U, } +/// Conditionally dispatch requests to the inner service based on an asynchronous +/// [predicate]. +/// +/// This [`Layer`] produces instances of the [`AsyncFilter`] service. +/// +/// [predicate]: crate::filter::AsyncPredicate +/// [`Layer`]: crate::Layer +/// [`Filter`]: crate::filter::AsyncFilter +#[derive(Debug)] +pub struct AsyncFilterLayer { + predicate: U, +} + +// === impl FilterLayer === + impl FilterLayer { - #[allow(missing_docs)] + /// Returns a new layer that produces [`Filter`] services with the given + /// [`Predicate`]. + /// + /// [`Predicate`]: crate::filter::Predicate + /// [`Filter`]: crate::filter::Filter pub fn new(predicate: U) -> Self { - FilterLayer { predicate } + Self { predicate } } } @@ -22,3 +48,25 @@ impl Layer for FilterLayer { Filter::new(service, predicate) } } + +// === impl AsyncFilterLayer === + +impl AsyncFilterLayer { + /// Returns a new layer that produces [`AsyncFilter`] services with the given + /// [`AsyncPredicate`]. + /// + /// [`AsyncPredicate`]: crate::filter::AsyncPredicate + /// [`Filter`]: crate::filter::Filter + pub fn new(predicate: U) -> Self { + Self { predicate } + } +} + +impl Layer for AsyncFilterLayer { + type Service = AsyncFilter; + + fn layer(&self, service: S) -> Self::Service { + let predicate = self.predicate.clone(); + AsyncFilter::new(service, predicate) + } +} diff --git a/tower/src/filter/mod.rs b/tower/src/filter/mod.rs index 6613592fd..81f3788cf 100644 --- a/tower/src/filter/mod.rs +++ b/tower/src/filter/mod.rs @@ -1,55 +1,145 @@ //! Conditionally dispatch requests to the inner service based on the result of //! a predicate. - -pub mod error; +//! +//! A predicate takes some request type and returns a `Result`. +//! If the predicate returns `Ok`, the inner service is called with the request +//! returned by the predicate — which may be the original request or a +//! modified one. If the predicate returns `Err`, the request is rejected and +//! the inner service is not called. +//! +//! Predicates may either be synchronous (simple functions from a `Request` to +//! a `Result`) or asynchronous (functions returning `Future`s). Separate +//! traits, [`Predicate`] and [`AsyncPredicate`], represent these two types of +//! predicate. Note that when it is not necessary to await some other +//! asynchronous operation in the predicate, the synchronous predicate should be +//! preferred, as it introduces less overhead. +//! +//! The predicate traits are implemented for closures and function pointers. +//! However, users may also implement them for other types, such as when the +//! predicate requires some state carried between requests. For example, +//! `Predicate` could be implemented for a type that rejects a fixed set of +//! requests by checking if they are contained by a a [`HashSet`] or other +//! collection. +//! +//! [`HashSet`]: std::sync::HashSet pub mod future; mod layer; mod predicate; -pub use self::{layer::FilterLayer, predicate::Predicate}; +pub use self::{ + layer::{AsyncFilterLayer, FilterLayer}, + predicate::{AsyncPredicate, Predicate}, +}; -use self::{error::Error, future::ResponseFuture}; +use self::future::{AsyncResponseFuture, ResponseFuture}; +use crate::BoxError; use futures_core::ready; +use futures_util::{future::Either, TryFutureExt}; use std::task::{Context, Poll}; use tower_service::Service; -/// Conditionally dispatch requests to the inner service based on a predicate. +/// Conditionally dispatch requests to the inner service based on a [predicate]. +/// +/// [predicate]: Predicate #[derive(Clone, Debug)] pub struct Filter { inner: T, predicate: U, } +/// Conditionally dispatch requests to the inner service based on an +/// [asynchronous predicate]. +/// +/// [asynchronous predicate]: AsyncPredicate +#[derive(Clone, Debug)] +pub struct AsyncFilter { + inner: T, + predicate: U, +} + +// ==== impl Filter ==== + impl Filter { - #[allow(missing_docs)] + /// Returns a new `Filter` service wrapping `inner`. pub fn new(inner: T, predicate: U) -> Self { - Filter { inner, predicate } + Self { inner, predicate } + } + + /// Returns a new [`Layer`] that wraps services with a `Filter` service + /// with the given [`Predicate`]. + /// + /// [`Layer`]: crate::Layer + pub fn layer(predicate: U) -> FilterLayer { + FilterLayer::new(predicate) } } impl Service for Filter where - T: Service + Clone, - T::Error: Into, U: Predicate, + T: Service, + T::Error: Into, +{ + type Response = T::Response; + type Error = BoxError; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, request: Request) -> Self::Future { + ResponseFuture(match self.predicate.check(request) { + Ok(request) => Either::Right(self.inner.call(request).err_into()), + Err(e) => Either::Left(futures_util::future::ready(Err(e.into()))), + }) + } +} + +// ==== impl AsyncFilter ==== + +impl AsyncFilter { + /// Returns a new `AsyncFilter` service wrapping `inner`. + pub fn new(inner: T, predicate: U) -> Self { + Self { inner, predicate } + } + + /// Returns a new [`Layer`] that wraps services with an `AsyncFilter` + /// service with the given [`AsyncPredicate`]. + /// + /// [`Layer`]: crate::Layer + pub fn layer(predicate: U) -> FilterLayer { + FilterLayer::new(predicate) + } +} + +impl Service for AsyncFilter +where + U: AsyncPredicate, + T: Service + Clone, + T::Error: Into, { type Response = T::Response; - type Error = Error; - type Future = ResponseFuture; + type Error = BoxError; + type Future = AsyncResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - Poll::Ready(ready!(self.inner.poll_ready(cx)).map_err(error::Error::inner)) + self.inner.poll_ready(cx).map_err(Into::into) } fn call(&mut self, request: Request) -> Self::Future { use std::mem; let inner = self.inner.clone(); + // In case the inner service has state that's driven to readiness and + // not tracked by clones (such as `Buffer`), pass the version we have + // already called `poll_ready` on into the future, and leave its clone + // behind. let inner = mem::replace(&mut self.inner, inner); // Check the request - let check = self.predicate.check(&request); + let check = self.predicate.check(request); - ResponseFuture::new(request, check, inner) + AsyncResponseFuture::new(check, inner) } } diff --git a/tower/src/filter/predicate.rs b/tower/src/filter/predicate.rs index 52b3936aa..e1f5c220f 100644 --- a/tower/src/filter/predicate.rs +++ b/tower/src/filter/predicate.rs @@ -1,25 +1,59 @@ -use super::error::Error; +use crate::BoxError; use std::future::Future; -/// Checks a request -pub trait Predicate { +/// Checks a request asynchronously. +pub trait AsyncPredicate { /// The future returned by `check`. - type Future: Future>; + type Future: Future>; + + /// The type of requests returned by `check`. + /// + /// This request is forwarded to the inner service if the predicate + /// succeeds. + type Request; + + /// Check whether the given request should be forwarded. + /// + /// If the future resolves with `Ok`, the request is forwarded to the inner service. + fn check(&mut self, request: Request) -> Self::Future; +} +/// Checks a request synchronously. +pub trait Predicate { + /// The type of requests returned by `check`. + /// + /// This request is forwarded to the inner service if the predicate + /// succeeds. + type Request; /// Check whether the given request should be forwarded. /// /// If the future resolves with `Ok`, the request is forwarded to the inner service. - fn check(&mut self, request: &Request) -> Self::Future; + fn check(&mut self, request: Request) -> Result; +} + +impl AsyncPredicate for F +where + F: FnMut(T) -> U, + U: Future>, + E: Into, +{ + type Future = futures_util::future::ErrInto; + type Request = R; + + fn check(&mut self, request: T) -> Self::Future { + use futures_util::TryFutureExt; + self(request).err_into() + } } -impl Predicate for F +impl Predicate for F where - F: Fn(&T) -> U, - U: Future>, + F: FnMut(T) -> Result, + E: Into, { - type Future = U; + type Request = R; - fn check(&mut self, request: &T) -> Self::Future { - self(request) + fn check(&mut self, request: T) -> Result { + self(request).map_err(Into::into) } } diff --git a/tower/src/hedge/mod.rs b/tower/src/hedge/mod.rs index 35522404c..e4d134753 100644 --- a/tower/src/hedge/mod.rs +++ b/tower/src/hedge/mod.rs @@ -3,7 +3,7 @@ #![warn(missing_debug_implementations, missing_docs, unreachable_pub)] -use crate::filter::Filter; +use crate::filter::AsyncFilter; use futures_util::future; use pin_project::pin_project; use std::sync::{Arc, Mutex}; @@ -28,7 +28,7 @@ type Histo = Arc>; type Service = select::Select< SelectPolicy

, Latency, - Delay, PolicyPredicate

>>, + Delay, PolicyPredicate

>>, >; /// A middleware that pre-emptively retries requests which have been outstanding /// for longer than a given latency percentile. If either of the original @@ -138,7 +138,7 @@ impl Hedge { let recorded_b = Latency::new(histo.clone(), service); // Check policy to see if the hedge request should be issued. - let filtered = Filter::new(recorded_b, PolicyPredicate(policy.clone())); + let filtered = AsyncFilter::new(recorded_b, PolicyPredicate(policy.clone())); // Delay the second request by a percentile of the recorded request latency // histogram. @@ -213,18 +213,19 @@ impl latency::Record for Histo { } } -impl crate::filter::Predicate for PolicyPredicate

+impl crate::filter::AsyncPredicate for PolicyPredicate

where P: Policy, { type Future = future::Either< - future::Ready>, - future::Pending>, + future::Ready>, + future::Pending>, >; + type Request = Request; - fn check(&mut self, request: &Request) -> Self::Future { - if self.0.can_retry(request) { - future::Either::Left(future::ready(Ok(()))) + fn check(&mut self, request: Request) -> Self::Future { + if self.0.can_retry(&request) { + future::Either::Left(future::ready(Ok(request))) } else { // If the hedge retry should not be issued, we simply want to wait // for the result of the original request. Therefore we don't want diff --git a/tower/src/lib.rs b/tower/src/lib.rs index 4ef3554cd..9ed2ce661 100644 --- a/tower/src/lib.rs +++ b/tower/src/lib.rs @@ -88,6 +88,8 @@ //! [`tower-test`]: https://crates.io/crates/tower-test //! [`retry`]: crate::retry //! [`timeout`]: crate::timeout +#[macro_use] +pub(crate) mod macros; #[cfg(feature = "balance")] #[cfg_attr(docsrs, doc(cfg(feature = "balance")))] pub mod balance; @@ -112,6 +114,7 @@ pub mod load; #[cfg(feature = "load-shed")] #[cfg_attr(docsrs, doc(cfg(feature = "load-shed")))] pub mod load_shed; + #[cfg(feature = "make")] #[cfg_attr(docsrs, doc(cfg(feature = "make")))] pub mod make; @@ -155,7 +158,6 @@ pub use crate::make::MakeService; pub use tower_layer::Layer; #[doc(inline)] pub use tower_service::Service; - #[cfg(any(feature = "buffer", feature = "limit"))] mod semaphore; diff --git a/tower/src/macros.rs b/tower/src/macros.rs new file mode 100644 index 000000000..ea13204da --- /dev/null +++ b/tower/src/macros.rs @@ -0,0 +1,24 @@ +macro_rules! opaque_future { + ($(#[$m:meta])* pub type $name:ident<$($param:ident),+> = $actual:ty;) => { + #[pin_project::pin_project] + $(#[$m])* + pub struct $name<$($param),+>(#[pin] pub(crate) $actual); + + impl<$($param),+> std::fmt::Debug for $name<$($param),+> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple(stringify!($name)).field(&format_args!("...")).finish() + } + } + + impl<$($param),+> std::future::Future for $name<$($param),+> + where + $actual: std::future::Future, + { + type Output = <$actual as std::future::Future>::Output; + #[inline] + fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + self.project().0.poll(cx) + } + } + } +} diff --git a/tower/src/util/mod.rs b/tower/src/util/mod.rs index 7537779cd..24dde2377 100644 --- a/tower/src/util/mod.rs +++ b/tower/src/util/mod.rs @@ -15,7 +15,6 @@ mod optional; mod ready; mod service_fn; mod then; -mod try_map_request; pub use self::{ boxed::{BoxService, UnsyncBoxService}, @@ -30,7 +29,6 @@ pub use self::{ ready::{ReadyAnd, ReadyOneshot}, service_fn::{service_fn, ServiceFn}, then::{Then, ThenLayer}, - try_map_request::{TryMapRequest, TryMapRequestLayer}, }; pub use self::call_all::{CallAll, CallAllUnordered}; @@ -517,10 +515,11 @@ pub trait ServiceExt: tower_service::Service { MapRequest::new(self, f) } - /// Composes a fallible function *in front of* the service. + /// Composes this service with a [`Filter`] that conditionally accepts or + /// rejects requests based on a [predicate]. /// /// This adapter produces a new service that passes each value through the - /// given function `f` before sending it to `self`. + /// given function `predicate` before sending it to `self`. /// /// # Example /// ``` @@ -535,10 +534,14 @@ pub trait ServiceExt: tower_service::Service { /// # } /// # } /// # - /// # enum DbError { + /// # #[derive(Debug)] enum DbError { /// # Parse(std::num::ParseIntError) /// # } /// # + /// # impl std::fmt::Display for DbError { + /// # fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { std::fmt::Debug::fmt(self, f) } + /// # } + /// # impl std::error::Error for DbError {} /// # impl Service for DatabaseService { /// # type Response = String; /// # type Error = DbError; @@ -560,7 +563,7 @@ pub trait ServiceExt: tower_service::Service { /// /// // Fallibly map the request to a new request /// let mut new_service = service - /// .try_map_request(|id_str: &str| id_str.parse().map_err(DbError::Parse)); + /// .filter(|id_str: &str| id_str.parse().map_err(DbError::Parse)); /// /// // Call the new service /// let id = "13"; @@ -573,12 +576,104 @@ pub trait ServiceExt: tower_service::Service { /// # }; /// # } /// ``` - fn try_map_request(self, f: F) -> TryMapRequest + /// + /// [`Filter`]: crate::filter::Filter + /// [predicate]: crate::filter::Predicate + #[cfg(feature = "filter")] + #[cfg_attr(docsrs, doc(cfg(feature = "filter")))] + fn filter(self, filter: F) -> crate::filter::Filter + where + Self: Sized, + F: crate::filter::Predicate, + { + crate::filter::Filter::new(self, filter) + } + + /// Composes this service with an [`AsyncFilter`] that conditionally accepts or + /// rejects requests based on an [async predicate]. + /// + /// This adapter produces a new service that passes each value through the + /// given function `predicate` before sending it to `self`. + /// + /// # Example + /// ``` + /// # use std::convert::TryFrom; + /// # use std::task::{Poll, Context}; + /// # use tower::{Service, ServiceExt}; + /// # + /// # #[derive(Clone)] struct DatabaseService; + /// # impl DatabaseService { + /// # fn new(address: &str) -> Self { + /// # DatabaseService + /// # } + /// # } + /// # #[derive(Debug)] + /// # enum DbError { + /// # Rejected + /// # } + /// # impl std::fmt::Display for DbError { + /// # fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { std::fmt::Debug::fmt(self, f) } + /// # } + /// # impl std::error::Error for DbError {} + /// # + /// # impl Service for DatabaseService { + /// # type Response = String; + /// # type Error = DbError; + /// # type Future = futures_util::future::Ready>; + /// # + /// # fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + /// # Poll::Ready(Ok(())) + /// # } + /// # + /// # fn call(&mut self, request: u32) -> Self::Future { + /// # futures_util::future::ready(Ok(String::new())) + /// # } + /// # } + /// # + /// # fn main() { + /// # async { + /// // A service taking a u32 as a request and returning Result<_, DbError> + /// let service = DatabaseService::new("127.0.0.1:8080"); + /// + /// /// Returns `true` if we should query the database for an ID. + /// async fn should_query(id: u32) -> bool { + /// // ... + /// # true + /// } + /// + /// // Filter requests based on `should_query`. + /// let mut new_service = service + /// .filter_async(|id: u32| async move { + /// if should_query(id).await { + /// return Ok(id); + /// } + /// + /// Err(DbError::Rejected) + /// }); + /// + /// // Call the new service + /// let id = 13; + /// # let id: u32 = id; + /// let response = new_service + /// .ready_and() + /// .await? + /// .call(id) + /// .await; + /// # response + /// # }; + /// # } + /// ``` + /// + /// [`AsyncFilter`]: crate::filter::AsyncFilter + /// [asynchronous predicate]: crate::filter::AsyncPredicate + #[cfg(feature = "filter")] + #[cfg_attr(docsrs, doc(cfg(feature = "filter")))] + fn filter_async(self, filter: F) -> crate::filter::AsyncFilter where Self: Sized, - F: FnMut(NewRequest) -> Result, + F: crate::filter::AsyncPredicate, { - TryMapRequest::new(self, f) + crate::filter::AsyncFilter::new(self, filter) } /// Composes an asynchronous function *after* this service. diff --git a/tower/src/util/try_map_request.rs b/tower/src/util/try_map_request.rs deleted file mode 100644 index abcc51887..000000000 --- a/tower/src/util/try_map_request.rs +++ /dev/null @@ -1,70 +0,0 @@ -use futures_util::future::{ready, Either, Ready}; -use std::task::{Context, Poll}; -use tower_layer::Layer; -use tower_service::Service; - -/// Service returned by the [`try_map_request`] combinator. -/// -/// [`try_map_request`]: crate::util::ServiceExt::try_map_request -#[derive(Clone, Debug)] -pub struct TryMapRequest { - inner: S, - f: F, -} - -/// A [`Layer`] that produces a [`TryMapRequest`] service. -/// -/// [`Layer`]: tower_layer::Layer -#[derive(Debug)] -pub struct TryMapRequestLayer { - f: F, -} - -impl TryMapRequest { - /// Creates a new [`TryMapRequest`] service. - pub fn new(inner: S, f: F) -> Self { - TryMapRequest { inner, f } - } -} - -impl Service for TryMapRequest -where - S: Service, - F: FnMut(R1) -> Result, -{ - type Response = S::Response; - type Error = S::Error; - type Future = Either>>; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_ready(cx) - } - - fn call(&mut self, request: R1) -> Self::Future { - match (self.f)(request) { - Ok(ok) => Either::Left(self.inner.call(ok)), - Err(err) => Either::Right(ready(Err(err))), - } - } -} - -impl TryMapRequestLayer { - /// Creates a new [`TryMapRequestLayer`]. - pub fn new(f: F) -> Self { - TryMapRequestLayer { f } - } -} - -impl Layer for TryMapRequestLayer -where - F: Clone, -{ - type Service = TryMapRequest; - - fn layer(&self, inner: S) -> Self::Service { - TryMapRequest { - f: self.f.clone(), - inner, - } - } -} diff --git a/tower/tests/filter/main.rs b/tower/tests/filter/async_filter.rs similarity index 91% rename from tower/tests/filter/main.rs rename to tower/tests/filter/async_filter.rs index 2bcd7b0e5..4a0a5a917 100644 --- a/tower/tests/filter/main.rs +++ b/tower/tests/filter/async_filter.rs @@ -3,7 +3,7 @@ mod support; use futures_util::{future::poll_fn, pin_mut}; use std::future::Future; -use tower::filter::{error::Error, Filter}; +use tower::filter::{error::Error, AsyncFilter}; use tower_service::Service; use tower_test::{assert_request_eq, mock}; @@ -52,12 +52,12 @@ async fn rejected_sync() { type Mock = mock::Mock; type Handle = mock::Handle; -fn new_service(f: F) -> (Filter, Handle) +fn new_service(f: F) -> (AsyncFilter, Handle) where F: Fn(&String) -> U, U: Future>, { let (service, handle) = mock::pair(); - let service = Filter::new(service, f); + let service = AsyncFilter::new(service, f); (service, handle) }