Skip to content

Commit

Permalink
sync: use AtomicBool in broadcast channel future (#6298)
Browse files Browse the repository at this point in the history
  • Loading branch information
vnetserg authored Jan 27, 2024
1 parent b6d0c90 commit 7536132
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 31 deletions.
5 changes: 5 additions & 0 deletions benches/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ name = "spawn"
path = "spawn.rs"
harness = false

[[bench]]
name = "sync_broadcast"
path = "sync_broadcast.rs"
harness = false

[[bench]]
name = "sync_mpsc"
path = "sync_mpsc.rs"
Expand Down
82 changes: 82 additions & 0 deletions benches/sync_broadcast.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
use rand::{Rng, RngCore, SeedableRng};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::{broadcast, Notify};

use criterion::measurement::WallTime;
use criterion::{black_box, criterion_group, criterion_main, BenchmarkGroup, Criterion};

/// Builds the multi-threaded Tokio runtime that drives the benchmark tasks.
///
/// The runtime is configured with six worker threads.
///
/// # Panics
///
/// Panics if the runtime cannot be built.
fn rt() -> tokio::runtime::Runtime {
    let mut builder = tokio::runtime::Builder::new_multi_thread();
    builder.worker_threads(6);
    builder.build().unwrap()
}

/// Performs a small, deterministic-per-seed chunk of CPU work.
///
/// Formats ten random `f64` values drawn from `rng` into a string of the
/// form ` 1=<v1> 2=<v2> ...` and returns the wrapping sum of the string's
/// bytes, so the caller has a value that depends on all of the work done.
fn do_work(rng: &mut impl RngCore) -> u32 {
    use std::fmt::Write;

    let mut text = String::new();
    for i in 1..=10 {
        // The formatting result is deliberately ignored.
        let _ = write!(&mut text, " {i}={}", rng.gen::<f64>());
    }

    // Wrapping byte-sum of the formatted text.
    text.bytes()
        .fold(0u32, |sum, byte| sum.wrapping_add(u32::from(byte)))
}

/// Benchmarks a broadcast channel with `N_TASKS` concurrent receivers.
///
/// `N_TASKS` receiver tasks are spawned on a dedicated runtime. Each one
/// loops on `recv`, does a bit of CPU work per message and then decrements a
/// shared counter; the task that brings the counter to zero notifies the
/// sender. A single benchmark iteration sends `N_ITERS` messages one at a
/// time, waiting after each send until every receiver has processed it.
///
/// The benchmark is registered in `g` under the name `N_TASKS.to_string()`.
///
/// # Panics
///
/// Panics if the shared counter is not zero right before a send, or if
/// `send` reports that there are no receivers.
fn contention_impl<const N_TASKS: usize>(g: &mut BenchmarkGroup<WallTime>) {
    let rt = rt();

    let (tx, _rx) = broadcast::channel::<usize>(1000);
    // Wait group: the counter holds how many receivers still have to handle
    // the message in flight; the `Notify` wakes the sender once it hits zero.
    let wg = Arc::new((AtomicUsize::new(0), Notify::new()));

    for n in 0..N_TASKS {
        let wg = wg.clone();
        let mut rx = tx.subscribe();
        // Seed each task's RNG with its index.
        let mut rng = rand::rngs::StdRng::seed_from_u64(n as u64);
        rt.spawn(async move {
            // Keep receiving until `recv` returns an error.
            while rx.recv().await.is_ok() {
                let r = do_work(&mut rng);
                let _ = black_box(r);
                // `fetch_sub` returns the previous value, so `1` means this
                // task was the last one to finish the current message.
                if wg.0.fetch_sub(1, Ordering::Relaxed) == 1 {
                    wg.1.notify_one();
                }
            }
        });
    }

    const N_ITERS: usize = 100;

    g.bench_function(N_TASKS.to_string(), |b| {
        b.iter(|| {
            rt.block_on({
                let wg = wg.clone();
                let tx = tx.clone();
                async move {
                    for i in 0..N_ITERS {
                        // Arm the counter before sending; it must have been
                        // drained by the previous round.
                        assert_eq!(wg.0.fetch_add(N_TASKS, Ordering::Relaxed), 0);
                        tx.send(i).unwrap();
                        // Re-check the counter after every wake-up before
                        // moving on to the next message.
                        while wg.0.load(Ordering::Relaxed) > 0 {
                            wg.1.notified().await;
                        }
                    }
                }
            })
        })
    });
}

/// Registers the `contention` benchmark group.
///
/// Runs `contention_impl` for 10, 100, 500 and 1000 receiver tasks, in that
/// order, and then finishes the group.
fn bench_contention(c: &mut Criterion) {
    let mut contention = c.benchmark_group("contention");
    {
        // Borrow the group once for all receiver-count variants.
        let g = &mut contention;
        contention_impl::<10>(g);
        contention_impl::<100>(g);
        contention_impl::<500>(g);
        contention_impl::<1000>(g);
    }
    contention.finish();
}

// Bundle `bench_contention` into a Criterion benchmark group named `contention`.
criterion_group!(contention, bench_contention);

// Entry point for the benchmark binary; runs the `contention` group.
criterion_main!(contention);
89 changes: 58 additions & 31 deletions tokio/src/sync/broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
//! ```
use crate::loom::cell::UnsafeCell;
use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::atomic::{AtomicBool, AtomicUsize};
use crate::loom::sync::{Arc, Mutex, MutexGuard, RwLock, RwLockReadGuard};
use crate::util::linked_list::{self, GuardedLinkedList, LinkedList};
use crate::util::WakeList;
Expand All @@ -127,7 +127,7 @@ use std::future::Future;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::ptr::NonNull;
use std::sync::atomic::Ordering::SeqCst;
use std::sync::atomic::Ordering::{Acquire, Relaxed, Release, SeqCst};
use std::task::{Context, Poll, Waker};
use std::usize;

Expand Down Expand Up @@ -354,7 +354,7 @@ struct Slot<T> {
/// An entry in the wait queue.
struct Waiter {
/// True if queued.
queued: bool,
queued: AtomicBool,

/// Task waiting on the broadcast channel.
waker: Option<Waker>,
Expand All @@ -369,7 +369,7 @@ struct Waiter {
impl Waiter {
fn new() -> Self {
Self {
queued: false,
queued: AtomicBool::new(false),
waker: None,
pointers: linked_list::Pointers::new(),
_p: PhantomPinned,
Expand Down Expand Up @@ -897,15 +897,22 @@ impl<T> Shared<T> {
'outer: loop {
while wakers.can_push() {
match list.pop_back_locked(&mut tail) {
Some(mut waiter) => {
// Safety: `tail` lock is still held.
let waiter = unsafe { waiter.as_mut() };

assert!(waiter.queued);
waiter.queued = false;

if let Some(waker) = waiter.waker.take() {
wakers.push(waker);
Some(waiter) => {
unsafe {
// Safety: accessing `waker` is safe because
// the tail lock is held.
if let Some(waker) = (*waiter.as_ptr()).waker.take() {
wakers.push(waker);
}

// Safety: `queued` is atomic.
let queued = &(*waiter.as_ptr()).queued;
// `Relaxed` suffices because the tail lock is held.
assert!(queued.load(Relaxed));
// `Release` is needed to synchronize with `Recv::drop`.
// It is critical to set this variable **after** waker
// is extracted, otherwise we may data race with `Recv::drop`.
queued.store(false, Release);
}
}
None => {
Expand Down Expand Up @@ -1104,8 +1111,13 @@ impl<T> Receiver<T> {
}
}

if !(*ptr).queued {
(*ptr).queued = true;
// If the waiter is not already queued, enqueue it.
// `Relaxed` order suffices: we have synchronized with
// all writers through the tail lock that we hold.
if !(*ptr).queued.load(Relaxed) {
// `Relaxed` order suffices: all the readers will
// synchronize with this write through the tail lock.
(*ptr).queued.store(true, Relaxed);
tail.waiters.push_front(NonNull::new_unchecked(&mut *ptr));
}
});
Expand Down Expand Up @@ -1357,7 +1369,7 @@ impl<'a, T> Recv<'a, T> {
Recv {
receiver,
waiter: UnsafeCell::new(Waiter {
queued: false,
queued: AtomicBool::new(false),
waker: None,
pointers: linked_list::Pointers::new(),
_p: PhantomPinned,
Expand Down Expand Up @@ -1402,22 +1414,37 @@ where

impl<'a, T> Drop for Recv<'a, T> {
fn drop(&mut self) {
// Acquire the tail lock. This is required for safety before accessing
// the waiter node.
let mut tail = self.receiver.shared.tail.lock();

// safety: tail lock is held
let queued = self.waiter.with(|ptr| unsafe { (*ptr).queued });

// Safety: `waiter.queued` is atomic.
// Acquire ordering is required to synchronize with
// `Shared::notify_rx` before we drop the object.
let queued = self
.waiter
.with(|ptr| unsafe { (*ptr).queued.load(Acquire) });

// If the waiter is queued, we need to unlink it from the waiters list.
// If not, no further synchronization is required, since the waiter
// is not in the list and, as such, is not shared with any other threads.
if queued {
// Remove the node
//
// safety: tail lock is held and the wait node is verified to be in
// the list.
unsafe {
self.waiter.with_mut(|ptr| {
tail.waiters.remove((&mut *ptr).into());
});
// Acquire the tail lock. This is required for safety before accessing
// the waiter node.
let mut tail = self.receiver.shared.tail.lock();

// Safety: tail lock is held.
// `Relaxed` order suffices because we hold the tail lock.
let queued = self
.waiter
.with_mut(|ptr| unsafe { (*ptr).queued.load(Relaxed) });

if queued {
// Remove the node
//
// safety: tail lock is held and the wait node is verified to be in
// the list.
unsafe {
self.waiter.with_mut(|ptr| {
tail.waiters.remove((&mut *ptr).into());
});
}
}
}
}
Expand Down

0 comments on commit 7536132

Please sign in to comment.