Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pollable mpsc::Sender #3490

Merged
merged 7 commits into from
Mar 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tokio-util/src/sync/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ pub use cancellation_token::{CancellationToken, WaitForCancellationFuture};

mod intrusive_double_linked_list;

mod mpsc;
pub use mpsc::PollSender;

mod poll_semaphore;
pub use poll_semaphore::PollSemaphore;

Expand Down
224 changes: 224 additions & 0 deletions tokio-util/src/sync/mpsc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
use futures_core::ready;
use futures_sink::Sink;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::sync::mpsc::{error::SendError, Sender};

use super::ReusableBoxFuture;

// This implementation was chosen over something based on permits because to get a
// `tokio::sync::mpsc::Permit` out of the `inner` future, you must transmute the
// lifetime on the permit to `'static`.

/// A wrapper around [`mpsc::Sender`] that can be polled.
///
/// [`mpsc::Sender`]: tokio::sync::mpsc::Sender
#[derive(Debug)]
pub struct PollSender<T> {
/// is none if closed
sender: Option<Arc<Sender<T>>>,
is_sending: bool,
inner: ReusableBoxFuture<Result<(), SendError<T>>>,
}

// By reusing the same async fn for both Some and None, we make sure every
// future passed to ReusableBoxFuture has the same underlying type, and hence
// the same size and alignment.
async fn make_future<T>(data: Option<(Arc<Sender<T>>, T)>) -> Result<(), SendError<T>> {
match data {
Some((sender, value)) => sender.send(value).await,
None => unreachable!(
"This future should not be pollable, as is_sending should be set to false."
),
}
}

impl<T: Send + 'static> PollSender<T> {
/// Create a new `PollSender`.
pub fn new(sender: Sender<T>) -> Self {
Self {
sender: Some(Arc::new(sender)),
is_sending: false,
inner: ReusableBoxFuture::new(make_future(None)),
}
}

/// Start sending a new item.
///
/// This method panics if a send is currently in progress. To ensure that no
/// send is in progress, call `poll_send_done` first until it returns
/// `Poll::Ready`.
Darksonn marked this conversation as resolved.
Show resolved Hide resolved
///
/// If this method returns an error, that indicates that the channel is
/// closed. Note that this method is not guaranteed to return an error if
/// the channel is closed, but in that case the error would be reported by
/// the first call to `poll_send_done`.
pub fn start_send(&mut self, value: T) -> Result<(), SendError<T>> {
if self.is_sending {
panic!("start_send called while not ready.");
}
match self.sender.clone() {
Some(sender) => {
self.inner.set(make_future(Some((sender, value))));
self.is_sending = true;
Ok(())
}
None => Err(SendError(value)),
}
}

/// If a send is in progress, poll for its completion. If no send is in progress,
/// this method returns `Poll::Ready(Ok(()))`.
///
/// This method can return the following values:
///
/// - `Poll::Ready(Ok(()))` if the in-progress send has been completed, or there is
/// no send in progress (even if the channel is closed).
/// - `Poll::Ready(Err(err))` if the in-progress send failed because the channel has
/// been closed.
/// - `Poll::Pending` if a send is in progress, but it could not complete now.
///
/// When this method returns `Poll::Pending`, the current task is scheduled
/// to receive a wakeup when the message is sent, or when the entire channel
/// is closed (but not if just this sender is closed by
/// `close_this_sender`). Note that on multiple calls to `poll_send_done`,
/// only the `Waker` from the `Context` passed to the most recent call is
/// scheduled to receive a wakeup.
///
/// If this method returns `Poll::Ready`, then `start_send` is guaranteed to
/// not panic.
pub fn poll_send_done(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), SendError<T>>> {
Darksonn marked this conversation as resolved.
Show resolved Hide resolved
if !self.is_sending {
return Poll::Ready(Ok(()));
}

let result = self.inner.poll(cx);
if result.is_ready() {
self.is_sending = false;
}
if let Poll::Ready(Err(_)) = &result {
self.sender = None;
}
result
}

/// Check whether the channel is ready to send more messages now.
///
/// If this method returns `true`, then `start_send` is guaranteed to not
/// panic.
///
/// If the channel is closed, this method returns `true`.
pub fn is_ready(&self) -> bool {
!self.is_sending
}

/// Check whether the channel has been closed.
pub fn is_closed(&self) -> bool {
match &self.sender {
Some(sender) => sender.is_closed(),
None => true,
}
}

/// Clone the underlying `Sender`.
///
/// If this method returns `None`, then the channel is closed. (But it is
/// not guaranteed to return `None` if the channel is closed.)
pub fn clone_inner(&self) -> Option<Sender<T>> {
match &self.sender {
Some(sender) => Some((&**sender).clone()),
None => None,
}
}
Comment on lines +124 to +133
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason we have this rather than just exposing accessors for the inner sender?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added an accessor too.


/// Access the underlying `Sender`.
///
/// If this method returns `None`, then the channel is closed. (But it is
/// not guaranteed to return `None` if the channel is closed.)
pub fn inner_ref(&self) -> Option<&Sender<T>> {
self.sender.as_deref()
}

// This operation is supported because it is required by the Sink trait.
/// Close this sender. No more messages can be sent from this sender.
///
/// Note that this only closes the channel from the view-point of this
/// sender. The channel remains open until all senders have gone away, or
/// until the [`Receiver`] closes the channel.
///
/// If there is a send in progress when this method is called, that send is
/// unaffected by this operation, and `poll_send_done` can still be called
/// to complete that send.
///
/// [`Receiver`]: tokio::sync::mpsc::Receiver
pub fn close_this_sender(&mut self) {
self.sender = None;
}

/// Abort the current in-progress send, if any.
///
/// Returns `true` if a send was aborted.
pub fn abort_send(&mut self) -> bool {
if self.is_sending {
self.inner.set(make_future(None));
self.is_sending = false;
true
} else {
false
}
}
}

impl<T> Clone for PollSender<T> {
/// Clones this `PollSender`. The resulting clone will not have any
/// in-progress send operations, even if the current `PollSender` does.
fn clone(&self) -> PollSender<T> {
Self {
sender: self.sender.clone(),
is_sending: false,
inner: ReusableBoxFuture::new(async { unreachable!() }),
}
}
}

impl<T: Send + 'static> Sink<T> for PollSender<T> {
type Error = SendError<T>;

/// This is equivalent to calling [`poll_send_done`].
///
/// [`poll_send_done`]: PollSender::poll_send_done
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::into_inner(self).poll_send_done(cx)
}

/// This is equivalent to calling [`poll_send_done`].
///
/// [`poll_send_done`]: PollSender::poll_send_done
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::into_inner(self).poll_send_done(cx)
}

/// This is equivalent to calling [`start_send`].
///
/// [`start_send`]: PollSender::start_send
fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
Pin::into_inner(self).start_send(item)
}

/// This method will first flush the `PollSender`, and then close it by
/// calling [`close_this_sender`].
///
/// If a send fails while flushing because the [`Receiver`] has gone away,
/// then this function returns an error. The channel is still successfully
/// closed in this situation.
///
/// [`close_this_sender`]: PollSender::close_this_sender
/// [`Receiver`]: tokio::sync::mpsc::Receiver
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
ready!(self.as_mut().poll_flush(cx))?;

Pin::into_inner(self).close_this_sender();
Poll::Ready(Ok(()))
}
}
95 changes: 95 additions & 0 deletions tokio-util/tests/mpsc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
use futures::future::poll_fn;
use tokio::sync::mpsc::channel;
use tokio_test::task::spawn;
use tokio_test::{assert_pending, assert_ready, assert_ready_err, assert_ready_ok};
use tokio_util::sync::PollSender;

#[tokio::test]
async fn test_simple() {
let (send, mut recv) = channel(3);
let mut send = PollSender::new(send);

for i in 1..=3i32 {
send.start_send(i).unwrap();
assert_ready_ok!(spawn(poll_fn(|cx| send.poll_send_done(cx))).poll());
}

send.start_send(4).unwrap();
let mut fourth_send = spawn(poll_fn(|cx| send.poll_send_done(cx)));
assert_pending!(fourth_send.poll());
assert_eq!(recv.recv().await.unwrap(), 1);
assert!(fourth_send.is_woken());
assert_ready_ok!(fourth_send.poll());

drop(recv);

// Here, start_send is not guaranteed to fail, but if it doesn't the first
// call to poll_send_done should.
if send.start_send(5).is_ok() {
assert_ready_err!(spawn(poll_fn(|cx| send.poll_send_done(cx))).poll());
}
}

#[tokio::test]
async fn test_abort() {
let (send, mut recv) = channel(3);
let mut send = PollSender::new(send);
let send2 = send.clone_inner().unwrap();

for i in 1..=3i32 {
send.start_send(i).unwrap();
assert_ready_ok!(spawn(poll_fn(|cx| send.poll_send_done(cx))).poll());
}

send.start_send(4).unwrap();
{
let mut fourth_send = spawn(poll_fn(|cx| send.poll_send_done(cx)));
assert_pending!(fourth_send.poll());
assert_eq!(recv.recv().await.unwrap(), 1);
assert!(fourth_send.is_woken());
}

let mut send2_send = spawn(send2.send(5));
assert_pending!(send2_send.poll());
send.abort_send();
assert!(send2_send.is_woken());
assert_ready_ok!(send2_send.poll());

assert_eq!(recv.recv().await.unwrap(), 2);
assert_eq!(recv.recv().await.unwrap(), 3);
assert_eq!(recv.recv().await.unwrap(), 5);
}

#[tokio::test]
async fn close_sender_last() {
let (send, mut recv) = channel::<i32>(3);
let mut send = PollSender::new(send);

let mut recv_task = spawn(recv.recv());
assert_pending!(recv_task.poll());

send.close_this_sender();

assert!(recv_task.is_woken());
assert!(assert_ready!(recv_task.poll()).is_none());
}

#[tokio::test]
async fn close_sender_not_last() {
let (send, mut recv) = channel::<i32>(3);
let send2 = send.clone();
let mut send = PollSender::new(send);

let mut recv_task = spawn(recv.recv());
assert_pending!(recv_task.poll());

send.close_this_sender();

assert!(!recv_task.is_woken());
assert_pending!(recv_task.poll());

drop(send2);

assert!(recv_task.is_woken());
assert!(assert_ready!(recv_task.poll()).is_none());
}