Skip to content

Commit

Permalink
util: fuse PollSemaphore (#3578)
Browse files Browse the repository at this point in the history
  • Loading branch information
Darksonn authored Mar 9, 2021
1 parent 05eeea5 commit db1d904
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 6 deletions.
13 changes: 7 additions & 6 deletions tokio-util/src/sync/poll_semaphore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,13 @@ impl PollSemaphore {
/// the `Waker` from the `Context` passed to the most recent call is
/// scheduled to receive a wakeup.
pub fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll<Option<OwnedSemaphorePermit>> {
match ready!(self.permit_fut.poll(cx)) {
Ok(permit) => {
let next_fut = Arc::clone(&self.semaphore).acquire_owned();
self.permit_fut.set(next_fut);
Poll::Ready(Some(permit))
}
let result = ready!(self.permit_fut.poll(cx));

let next_fut = Arc::clone(&self.semaphore).acquire_owned();
self.permit_fut.set(next_fut);

match result {
Ok(permit) => Poll::Ready(Some(permit)),
Err(_closed) => Poll::Ready(None),
}
}
Expand Down
36 changes: 36 additions & 0 deletions tokio-util/tests/poll_semaphore.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
use std::future::Future;
use std::sync::Arc;
use std::task::Poll;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tokio_util::sync::PollSemaphore;

type SemRet = Option<OwnedSemaphorePermit>;

fn semaphore_poll<'a>(
sem: &'a mut PollSemaphore,
) -> tokio_test::task::Spawn<impl Future<Output = SemRet> + 'a> {
let fut = futures::future::poll_fn(move |cx| sem.poll_acquire(cx));
tokio_test::task::spawn(fut)
}

#[tokio::test]
async fn it_works() {
let sem = Arc::new(Semaphore::new(1));
let mut poll_sem = PollSemaphore::new(sem.clone());

let permit = sem.acquire().await.unwrap();
let mut poll = semaphore_poll(&mut poll_sem);
assert!(poll.poll().is_pending());
drop(permit);

assert!(matches!(poll.poll(), Poll::Ready(Some(_))));
drop(poll);

sem.close();

assert!(semaphore_poll(&mut poll_sem).await.is_none());

// Check that it is fused.
assert!(semaphore_poll(&mut poll_sem).await.is_none());
assert!(semaphore_poll(&mut poll_sem).await.is_none());
}

0 comments on commit db1d904

Please sign in to comment.