From 1c1e0e3fc9a463820955cf823b2d36b3162d746a Mon Sep 17 00:00:00 2001 From: Fuyang Liu Date: Fri, 5 Feb 2021 19:56:24 +0100 Subject: [PATCH] tokio-stream: add wrapper for broadcast and watch (#3384) --- examples/Cargo.toml | 2 +- tokio-stream/Cargo.toml | 3 +- tokio-stream/src/wrappers.rs | 7 +++ tokio-stream/src/wrappers/broadcast.rs | 62 ++++++++++++++++++++++++++ tokio-stream/src/wrappers/watch.rs | 60 +++++++++++++++++++++++++ 5 files changed, 132 insertions(+), 2 deletions(-) create mode 100644 tokio-stream/src/wrappers/broadcast.rs create mode 100644 tokio-stream/src/wrappers/watch.rs diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 1a86c81ef72..802930d820a 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -8,7 +8,7 @@ edition = "2018" # [dependencies] instead. [dev-dependencies] tokio = { version = "1.0.0", features = ["full", "tracing"] } -tokio-util = { version = "0.6.1", features = ["full"] } +tokio-util = { version = "0.6.3", features = ["full"] } tokio-stream = { version = "0.1" } async-stream = "0.3" diff --git a/tokio-stream/Cargo.toml b/tokio-stream/Cargo.toml index d662c387fec..0bc03ac0331 100644 --- a/tokio-stream/Cargo.toml +++ b/tokio-stream/Cargo.toml @@ -30,11 +30,12 @@ fs = ["tokio/fs"] futures-core = { version = "0.3.0" } pin-project-lite = "0.2.0" tokio = { version = "1.0", features = ["sync"] } +tokio-util = { version = "0.6.3" } [dev-dependencies] tokio = { version = "1.0", features = ["full", "test-util"] } -tokio-test = { path = "../tokio-test" } async-stream = "0.3" +tokio-test = { path = "../tokio-test" } futures = { version = "0.3", default-features = false } proptest = "0.10.0" diff --git a/tokio-stream/src/wrappers.rs b/tokio-stream/src/wrappers.rs index c0ffb234a09..405f35a5b61 100644 --- a/tokio-stream/src/wrappers.rs +++ b/tokio-stream/src/wrappers.rs @@ -6,6 +6,13 @@ pub use mpsc_bounded::ReceiverStream; mod mpsc_unbounded; pub use mpsc_unbounded::UnboundedReceiverStream; +mod broadcast; +pub use broadcast::BroadcastStream; +pub use broadcast::BroadcastStreamRecvError; + +mod watch; +pub use watch::WatchStream; + cfg_time! { mod interval; pub use interval::IntervalStream; diff --git a/tokio-stream/src/wrappers/broadcast.rs b/tokio-stream/src/wrappers/broadcast.rs new file mode 100644 index 00000000000..f3ff002355c --- /dev/null +++ b/tokio-stream/src/wrappers/broadcast.rs @@ -0,0 +1,62 @@ +use std::pin::Pin; +use tokio::sync::broadcast::error::RecvError; +use tokio::sync::broadcast::Receiver; + +use futures_core::Stream; +use tokio_util::sync::ReusableBoxFuture; + +use std::fmt; +use std::task::{Context, Poll}; + +/// A wrapper around [`tokio::sync::broadcast::Receiver`] that implements [`Stream`]. +/// +/// [`tokio::sync::broadcast::Receiver`]: struct@tokio::sync::broadcast::Receiver +/// [`Stream`]: trait@crate::Stream +pub struct BroadcastStream { + inner: ReusableBoxFuture<(Result, Receiver)>, +} + +/// An error returned from the inner stream of a [`BroadcastStream`]. +#[derive(Debug, PartialEq)] +pub enum BroadcastStreamRecvError { + /// The receiver lagged too far behind. Attempting to receive again will + /// return the oldest message still retained by the channel. + /// + /// Includes the number of skipped messages. + Lagged(u64), +} + +async fn make_future(mut rx: Receiver) -> (Result, Receiver) { + let result = rx.recv().await; + (result, rx) +} + +impl BroadcastStream { + /// Create a new `BroadcastStream`. + pub fn new(rx: Receiver) -> Self { + Self { + inner: ReusableBoxFuture::new(make_future(rx)), + } + } +} + +impl Stream for BroadcastStream { + type Item = Result; + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let (result, rx) = ready!(self.inner.poll(cx)); + self.inner.set(make_future(rx)); + match result { + Ok(item) => Poll::Ready(Some(Ok(item))), + Err(RecvError::Closed) => Poll::Ready(None), + Err(RecvError::Lagged(n)) => { + Poll::Ready(Some(Err(BroadcastStreamRecvError::Lagged(n)))) + } + } + } +} + +impl fmt::Debug for BroadcastStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("BroadcastStream").finish() + } +} diff --git a/tokio-stream/src/wrappers/watch.rs b/tokio-stream/src/wrappers/watch.rs new file mode 100644 index 00000000000..e58de918d29 --- /dev/null +++ b/tokio-stream/src/wrappers/watch.rs @@ -0,0 +1,60 @@ +use std::pin::Pin; +use tokio::sync::watch::Receiver; + +use futures_core::Stream; +use tokio_util::sync::ReusableBoxFuture; + +use std::fmt; +use std::task::{Context, Poll}; +use tokio::sync::watch::error::RecvError; + +/// A wrapper around [`tokio::sync::watch::Receiver`] that implements [`Stream`]. +/// +/// [`tokio::sync::watch::Receiver`]: struct@tokio::sync::watch::Receiver +/// [`Stream`]: trait@crate::Stream +pub struct WatchStream { + inner: ReusableBoxFuture<(Result<(), RecvError>, Receiver)>, +} + +async fn make_future( + mut rx: Receiver, +) -> (Result<(), RecvError>, Receiver) { + let result = rx.changed().await; + (result, rx) +} + +impl WatchStream { + /// Create a new `WatchStream`. + pub fn new(rx: Receiver) -> Self { + Self { + inner: ReusableBoxFuture::new(make_future(rx)), + } + } +} + +impl Stream for WatchStream { + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let (result, rx) = ready!(self.inner.poll(cx)); + match result { + Ok(_) => { + let received = (*rx.borrow()).clone(); + self.inner.set(make_future(rx)); + Poll::Ready(Some(received)) + } + Err(_) => { + self.inner.set(make_future(rx)); + Poll::Ready(None) + } + } + } +} + +impl Unpin for WatchStream {} + +impl fmt::Debug for WatchStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("WatchStream").finish() + } +}