Skip to content

Commit

Permalink
Remove wakers for cancelled tasks
Browse files Browse the repository at this point in the history
When an async function does:

```
async f(&mut self) {
    tokio::select! {
        v = self.subscriber().next => { /* do something with v */ }
        _ = std::future::ready() => {},
    }
}
```

then the future returned by `self.subscriber().next` is cancelled, but
the observed object stilled referenced the waker, preventing the future
(consequently, the function's closure) from being dropped even though
it won't be scheduled again.

This change is twofold:

1. `ObservableState` is now handed `Weak` references, so it does not
   keep futures alive, and a strong reference is kept by whichever
   object is held by the future awaiting it (`Subscriber` or `Next`)
2. `ObservableState` garbage-collects weak references from time to time,
   so its own vector of wakers does not grow unbounded
  • Loading branch information
progval committed Nov 11, 2023
1 parent 7ce1b78 commit 1fd6d5f
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 18 deletions.
72 changes: 64 additions & 8 deletions eyeball/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use std::{
hash::{Hash, Hasher},
mem,
sync::{
atomic::{AtomicU64, Ordering},
RwLock,
atomic::{AtomicUsize, AtomicU64, Ordering},
RwLock, Weak,
},
task::Waker,
};
Expand All @@ -27,12 +27,29 @@ pub struct ObservableState<T> {
/// locked for reading. This way, it is guaranteed that between a subscriber
/// reading the value and adding a waker because the value hasn't changed
/// yet, no updates to the value could have happened.
wakers: RwLock<Vec<Waker>>,
///
/// It contains weak references to wakers, so it does not keep references to
/// [`Subscriber`](crate::Subscriber) or [`Next`](crate::subscriber::Next)
/// that would otherwise be dropped and won't be awaited again (eg. as part
/// of a future being cancelled).
wakers: RwLock<Vec<Weak<Waker>>>,

/// Whenever wakers.len() reaches this size, iterate through it and remove
/// dangling weak references.
/// This is updated in order to only cleanup every time the list of wakers
/// doubled in size since the previous cleanup, allowing a O(1) amortized
/// time complexity.
next_wakers_cleanup_at_len: AtomicUsize,
}

impl<T> ObservableState<T> {
pub(crate) fn new(value: T) -> Self {
Self { value, version: AtomicU64::new(1), wakers: Default::default() }
Self {
value,
version: AtomicU64::new(1),
wakers: Default::default(),
next_wakers_cleanup_at_len: AtomicUsize::new(64), // Arbitrary constant
}
}

/// Get a reference to the inner value.
Expand All @@ -45,8 +62,34 @@ impl<T> ObservableState<T> {
self.version.load(Ordering::Acquire)
}

pub(crate) fn add_waker(&self, waker: Waker) {
self.wakers.write().unwrap().push(waker);
pub(crate) fn add_waker(&self, waker: Weak<Waker>) {
// TODO: clean up dangling Weak references in the vector if there are too many
let mut wakers = self.wakers.write().unwrap();
wakers.push(waker);
if wakers.len() >= self.next_wakers_cleanup_at_len.load(Ordering::Relaxed) {
// Remove dangling Weak references from the vector to free any
// cancelled future that awaited on a `Subscriber` of this
// observable.
let mut new_wakers = Vec::with_capacity(wakers.len());
for waker in wakers.iter() {
if waker.strong_count() > 0 {
new_wakers.push(waker.clone());
}
}
if new_wakers.len() == wakers.len() {
#[cfg(feature = "tracing")]
tracing::debug!("No dangling wakers among set of {}", wakers.len());
} else {
#[cfg(feature = "tracing")]
tracing::debug!(
"Removed {} dangling wakers from a set of {}",
wakers.len() - new_wakers.len(),
wakers.len()
);
std::mem::swap(&mut *wakers, &mut new_wakers);
}
self.next_wakers_cleanup_at_len.store(wakers.len() * 2, Ordering::Relaxed);
}
}

pub(crate) fn set(&mut self, value: T) -> T {
Expand Down Expand Up @@ -111,7 +154,7 @@ fn hash<T: Hash>(value: &T) -> u64 {

fn wake<I>(wakers: I)
where
I: IntoIterator<Item = Waker>,
I: IntoIterator<Item = Weak<Waker>>,
I::IntoIter: ExactSizeIterator,
{
let iter = wakers.into_iter();
Expand All @@ -124,7 +167,20 @@ where
tracing::debug!("No wakers");
}
}
let mut num_alive_wakers = 0;
for waker in iter {
waker.wake();
if let Some(waker) = waker.upgrade() {
num_alive_wakers += 1;
waker.wake_by_ref();
}
}

#[cfg(feature = "tracing")]
{
tracing::debug!("Woke up {num_alive_wakers} waiting subscribers");
}
#[cfg(not(feature = "tracing"))]
{
let _ = num_alive_wakers; // For Clippy
}
}
42 changes: 32 additions & 10 deletions eyeball/src/subscriber.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ use std::{
fmt,
future::{poll_fn, Future},
pin::Pin,
task::{Context, Poll},
sync::{Arc, Weak},
task::{Context, Poll, Waker},
};

use futures_core::Stream;
Expand All @@ -22,11 +23,14 @@ pub(crate) mod async_lock;
pub struct Subscriber<T, L: Lock = SyncLock> {
state: L::SubscriberState<T>,
observed_version: u64,
/// Prevent wakers from being dropped from `ObservableState` until this
/// `Subscriber` is dropped
wakers: Vec<Arc<Waker>>,
}

impl<T> Subscriber<T> {
pub(crate) fn new(state: readlock::SharedReadLock<ObservableState<T>>, version: u64) -> Self {
Self { state, observed_version: version }
Self { state, observed_version: version, wakers: Vec::new() }
}

/// Wait for an update and get a clone of the updated value.
Expand Down Expand Up @@ -87,7 +91,12 @@ impl<T> Subscriber<T> {
#[must_use]
pub async fn next_ref(&mut self) -> Option<ObservableReadGuard<'_, T>> {
// Unclear how to implement this as a named future.
poll_fn(|cx| self.poll_next_ref(cx).map(|opt| opt.map(|_| {}))).await?;
let mut waker = None;
poll_fn(|cx| {
waker = Some(Arc::new(cx.waker().clone()));
self.poll_next_ref(Arc::downgrade(waker.as_ref().unwrap())).map(|opt| opt.map(|_| {}))
})
.await?;
Some(self.next_ref_now())
}

Expand Down Expand Up @@ -120,7 +129,7 @@ impl<T> Subscriber<T> {
ObservableReadGuard::new(self.state.lock())
}

fn poll_next_ref(&mut self, cx: &Context<'_>) -> Poll<Option<ObservableReadGuard<'_, T>>> {
fn poll_next_ref(&mut self, waker: Weak<Waker>) -> Poll<Option<ObservableReadGuard<'_, T>>> {
let state = self.state.lock();
let version = state.version();
if version == 0 {
Expand All @@ -129,7 +138,7 @@ impl<T> Subscriber<T> {
self.observed_version = version;
Poll::Ready(Some(ObservableReadGuard::new(state)))
} else {
state.add_waker(cx.waker().clone());
state.add_waker(waker);
Poll::Pending
}
}
Expand Down Expand Up @@ -160,7 +169,7 @@ impl<T, L: Lock> Subscriber<T, L> {
where
L::SubscriberState<T>: Clone,
{
Self { state: self.state.clone(), observed_version: 0 }
Self { state: self.state.clone(), observed_version: 0, wakers: Vec::new() }
}
}

Expand All @@ -178,7 +187,11 @@ where
L::SubscriberState<T>: Clone,
{
fn clone(&self) -> Self {
Self { state: self.state.clone(), observed_version: self.observed_version }
Self {
state: self.state.clone(),
observed_version: self.observed_version,
wakers: Vec::new(),
}
}
}

Expand All @@ -198,7 +211,10 @@ impl<T: Clone> Stream for Subscriber<T> {
type Item = T;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.poll_next_ref(cx).map(opt_guard_to_owned)
let waker = Arc::new(cx.waker().clone());
let poll = self.poll_next_ref(Arc::downgrade(&waker)).map(opt_guard_to_owned);
self.wakers.push(waker);
poll
}
}

Expand All @@ -207,19 +223,25 @@ impl<T: Clone> Stream for Subscriber<T> {
#[allow(missing_debug_implementations)]
pub struct Next<'a, T, L: Lock = SyncLock> {
subscriber: &'a mut Subscriber<T, L>,
/// Prevent wakers from being dropped from `ObservableState` until this
/// `Next` is dropped
wakers: Vec<Arc<Waker>>,
}

impl<'a, T> Next<'a, T> {
fn new(subscriber: &'a mut Subscriber<T>) -> Self {
Self { subscriber }
Self { subscriber, wakers: Vec::new() }
}
}

impl<T: Clone> Future for Next<'_, T> {
type Output = Option<T>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.subscriber.poll_next_ref(cx).map(opt_guard_to_owned)
let waker = Arc::new(cx.waker().clone());
let poll = self.subscriber.poll_next_ref(Arc::downgrade(&waker)).map(opt_guard_to_owned);
self.wakers.push(waker);
poll
}
}

Expand Down

0 comments on commit 1fd6d5f

Please sign in to comment.