Skip to content

Commit

Permalink
Make waiter.queued atomic
Browse files Browse the repository at this point in the history
  • Loading branch information
vnetserg committed Jan 18, 2024
1 parent 4f98a68 commit 21f8a44
Showing 1 changed file with 29 additions and 22 deletions.
51 changes: 29 additions & 22 deletions tokio/src/sync/broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ use std::future::Future;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::ptr::NonNull;
use std::sync::atomic::Ordering::SeqCst;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering::{Acquire, Relaxed, Release, SeqCst};
use std::task::{Context, Poll, Waker};
use std::usize;

Expand Down Expand Up @@ -354,7 +355,7 @@ struct Slot<T> {
/// An entry in the wait queue.
struct Waiter {
/// True if queued.
queued: bool,
queued: AtomicBool,

/// Task waiting on the broadcast channel.
waker: Option<Waker>,
Expand All @@ -369,7 +370,7 @@ struct Waiter {
impl Waiter {
fn new() -> Self {
Self {
queued: false,
queued: AtomicBool::new(false),
waker: None,
pointers: linked_list::Pointers::new(),
_p: PhantomPinned,
Expand Down Expand Up @@ -901,8 +902,7 @@ impl<T> Shared<T> {
// Safety: `tail` lock is still held.
let waiter = unsafe { waiter.as_mut() };

assert!(waiter.queued);
waiter.queued = false;
assert!(waiter.queued.swap(false, Release));

if let Some(waker) = waiter.waker.take() {
wakers.push(waker);
Expand Down Expand Up @@ -1104,8 +1104,7 @@ impl<T> Receiver<T> {
}
}

if !(*ptr).queued {
(*ptr).queued = true;
if !(*ptr).queued.swap(true, Relaxed) {
tail.waiters.push_front(NonNull::new_unchecked(&mut *ptr));
}
});
Expand Down Expand Up @@ -1357,7 +1356,7 @@ impl<'a, T> Recv<'a, T> {
Recv {
receiver,
waiter: UnsafeCell::new(Waiter {
queued: false,
queued: AtomicBool::new(false),
waker: None,
pointers: linked_list::Pointers::new(),
_p: PhantomPinned,
Expand Down Expand Up @@ -1402,22 +1401,30 @@ where

impl<'a, T> Drop for Recv<'a, T> {
fn drop(&mut self) {
// Acquire the tail lock. This is required for safety before accessing
// the waiter node.
let mut tail = self.receiver.shared.tail.lock();

// safety: tail lock is held
let queued = self.waiter.with(|ptr| unsafe { (*ptr).queued });
let queued = self
.waiter
.with(|ptr| unsafe { (*ptr).queued.load(Acquire) });

if queued {
// Remove the node
//
// safety: tail lock is held and the wait node is verified to be in
// the list.
unsafe {
self.waiter.with_mut(|ptr| {
tail.waiters.remove((&mut *ptr).into());
});
// Acquire the tail lock. This is required for safety before accessing
// the waiter node.
let mut tail = self.receiver.shared.tail.lock();

// safety: tail lock is held
let queued = self
.waiter
.with(|ptr| unsafe { (*ptr).queued.load(Relaxed) });

if queued {
// Remove the node
//
// safety: tail lock is held and the wait node is verified to be in
// the list.
unsafe {
self.waiter.with_mut(|ptr| {
tail.waiters.remove((&mut *ptr).into());
});
}
}
}
}
Expand Down

0 comments on commit 21f8a44

Please sign in to comment.