diff --git a/tower/src/util/boxed_clone/mod.rs b/tower/src/util/boxed_clone/mod.rs new file mode 100644 index 000000000..1573e48a7 --- /dev/null +++ b/tower/src/util/boxed_clone/mod.rs @@ -0,0 +1,7 @@ +mod sync; +mod unsync; + +#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 +pub use self::{ + sync::BoxCloneService, unsync::UnsyncBoxCloneService, +}; diff --git a/tower/src/util/boxed_clone.rs b/tower/src/util/boxed_clone/sync.rs similarity index 99% rename from tower/src/util/boxed_clone.rs rename to tower/src/util/boxed_clone/sync.rs index 1209fd2ef..b7844c889 100644 --- a/tower/src/util/boxed_clone.rs +++ b/tower/src/util/boxed_clone/sync.rs @@ -1,4 +1,4 @@ -use super::ServiceExt; +use crate::ServiceExt; use futures_util::future::BoxFuture; use std::{ fmt, diff --git a/tower/src/util/boxed_clone/unsync.rs b/tower/src/util/boxed_clone/unsync.rs new file mode 100644 index 000000000..60aae0eac --- /dev/null +++ b/tower/src/util/boxed_clone/unsync.rs @@ -0,0 +1,90 @@ +use crate::ServiceExt; +use futures_util::future::LocalBoxFuture; +use std::{ + fmt, + task::{Context, Poll}, +}; +use tower_layer::{layer_fn, LayerFn}; +use tower_service::Service; + +/// A boxed [`CloneService`] trait object. +/// +/// This type alias represents a boxed future that is *not* [`Send`] and must +/// remain on the current thread. +pub struct UnsyncBoxCloneService( + Box< + dyn UnsyncCloneService>> + >, +); + +impl UnsyncBoxCloneService { + /// Create a new `BoxCloneService`. + pub fn new(inner: S) -> Self + where + S: Service + Clone + 'static, + S::Future: 'static, + { + let inner = inner.map_future(|f| Box::pin(f) as _); + UnsyncBoxCloneService(Box::new(inner)) + } + + /// Returns a [`Layer`] for wrapping a [`Service`] in a [`BoxCloneService`] + /// middleware. + /// + /// [`Layer`]: crate::Layer + pub fn layer() -> LayerFn Self> + where + S: Service + Clone + 'static, + S::Future: 'static, + { + layer_fn(Self::new) + } +} + +impl Service for UnsyncBoxCloneService { + type Response = U; + type Error = E; + type Future = LocalBoxFuture<'static, Result>; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.0.poll_ready(cx) + } + + #[inline] + fn call(&mut self, request: T) -> Self::Future { + self.0.call(request) + } +} + +impl Clone for UnsyncBoxCloneService { + fn clone(&self) -> Self { + Self(self.0.clone_box()) + } +} + +trait UnsyncCloneService: Service { + fn clone_box( + &self, + ) -> Box< + dyn UnsyncCloneService, + >; +} + +impl UnsyncCloneService for T +where + T: Service + Clone + 'static, +{ + fn clone_box( + &self, + ) -> Box> + { + Box::new(self.clone()) + } +} + +impl fmt::Debug for UnsyncBoxCloneService { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt.debug_struct("BoxCloneService").finish() + } +} diff --git a/tower/src/util/mod.rs b/tower/src/util/mod.rs index c617a9e05..54e7da96c 100644 --- a/tower/src/util/mod.rs +++ b/tower/src/util/mod.rs @@ -24,7 +24,7 @@ pub mod rng; pub use self::{ and_then::{AndThen, AndThenLayer}, boxed::{BoxCloneServiceLayer, BoxLayer, BoxService, UnsyncBoxService}, - boxed_clone::BoxCloneService, + boxed_clone::{BoxCloneService, UnsyncBoxCloneService}, either::Either, future_service::{future_service, FutureService}, map_err::{MapErr, MapErrLayer}, diff --git a/tower/tests/util/service_fn.rs b/tower/tests/util/service_fn.rs index ac6bf06f3..10b6f54e2 100644 --- a/tower/tests/util/service_fn.rs +++ b/tower/tests/util/service_fn.rs @@ -1,5 +1,5 @@ use futures_util::future::ready; -use tower::util::service_fn; +use tower::util::{service_fn, UnsyncBoxCloneService}; use tower_service::Service; #[tokio::test(flavor = "current_thread")] @@ -10,3 +10,15 @@ async fn simple() { let answer = add_one.call(1).await.unwrap(); assert_eq!(answer, 2); } + +#[tokio::test(flavor = "current_thread")] +async fn boxed_clone() { + let _t = super::support::trace_init(); + let x = std::rc::Rc::new(1); + let mut add_one = service_fn(|req| ready(Ok::<_, ()>(req + 1))); + let mut cloned = UnsyncBoxCloneService::new(add_one); + let answer = cloned.call(1).await.unwrap(); + assert_eq!(answer, 2); + let answer = add_one.call(1).await.unwrap(); + assert_eq!(answer, 2); +}