Add JoinHandle::into_join_future().
This allows spawned threads to be incorporated into `Future`-based
concurrency control without needing to add separate result-reporting
channels and an additional layer of `catch_unwind()` to the thread
functions. I believe this will be useful for async/blocking interop
and for various applications that want to manage parallel tasks in
a lightweight way.
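
As a rough illustration (not part of the patch), the kind of bridging boilerplate this replaces looks something like the sketch below, which assumes the external `futures` crate's oneshot channel as the async signalling primitive; the function names are made up:

```rust
use std::panic::catch_unwind;
use std::thread;

// Hypothetical "before": report the result (including a caught panic) back to
// async code through a oneshot channel.
async fn spawn_and_await_workaround() -> thread::Result<u64> {
    let (tx, rx) = futures::channel::oneshot::channel();
    thread::spawn(move || {
        // Extra `catch_unwind` layer so a panic is reported rather than lost.
        let result = catch_unwind(|| (0..1_000_000u64).sum::<u64>());
        let _ = tx.send(result);
    });
    // If the sender is dropped without sending, the receiver reports
    // cancellation; that case is unwrapped here for brevity.
    rx.await.expect("thread dropped the sender without reporting a result")
}

// Hypothetical "after": with the new (unstable, `thread_join_future`) API,
// neither the channel nor the extra `catch_unwind` is needed.
async fn spawn_and_await_new() -> thread::Result<u64> {
    thread::spawn(|| (0..1_000_000u64).sum::<u64>()).into_join_future().await
}
```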

There is a small additional cost which is paid even if the mechanism is
unused: the algorithm built into the shutdown of a spawned thread must
obtain and invoke a `Waker`, and the `Packet` internal struct is larger
by one `Mutex<Waker>`. In the future, this `Mutex` should be replaced by
something equivalent to `futures::task::AtomicWaker`, which will be more
efficient and eliminate deadlock and blocking hazards, but `std` doesn't
contain one of those yet.
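
In condensed form (a simplified standalone model, not the actual `std` internals in the diff below; the names here are illustrative, and `Waker::noop()` is the unstable API gated by `noop_waker` in this commit), the handshake works like this:

```rust
use std::sync::Mutex;
use std::task::Waker;

// Simplified stand-in for the `Mutex<Waker>` field added to `Packet`.
struct WakeOnExit {
    waker: Mutex<Waker>,
}

impl WakeOnExit {
    fn new() -> Self {
        // Starts as a no-op waker, so the wake at thread exit costs nothing
        // if nobody ever polls a `JoinFuture`.
        Self { waker: Mutex::new(Waker::noop().clone()) }
    }

    /// Called from the future's `poll()`: remember which task to wake.
    fn register(&self, waker: &Waker) {
        let mut guard = self.waker.lock().unwrap();
        if !waker.will_wake(&*guard) {
            *guard = waker.clone();
        }
    }

    /// Called as the spawned thread shuts down, after its result has been
    /// stored: take the registered waker (leaving a no-op behind) and wake it.
    fn wake_on_exit(&self) {
        let waker = std::mem::replace(&mut *self.waker.lock().unwrap(), Waker::noop().clone());
        waker.wake();
    }
}
```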

This is not an `impl IntoFuture for JoinHandle` so that it can avoid
being insta-stable; particularly because during the design discussion,
concerns were raised that a proper implementation should obey structured
concurrency via an `AsyncDrop` that forces waiting for the thread.
I personally think that would be a mistake, and structured spawning
should be its own thing, but this choice of API permits either option in
the future by keeping everything unstable, where a trait implementation
would not.
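
For reference, the avoided trait implementation would look roughly like the sketch below (written as it would appear inside `library/std/src/thread/mod.rs`, so it is not compilable as a standalone snippet). Because trait implementations cannot be hidden behind a feature gate for downstream users, landing it would immediately stabilize `.await` on a `JoinHandle`:

```rust
// Hypothetical sketch only; deliberately NOT part of this commit.
impl<T> crate::future::IntoFuture for JoinHandle<T> {
    type Output = Result<T>;
    type IntoFuture = JoinFuture<'static, T>;

    fn into_future(self) -> Self::IntoFuture {
        self.into_join_future()
    }
}
```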
kpreid committed Oct 30, 2024
1 parent 1e4f10b commit 54f6a78
Showing 4 changed files with 240 additions and 17 deletions.
1 change: 1 addition & 0 deletions library/std/src/lib.rs
@@ -349,6 +349,7 @@
#![feature(lazy_get)]
#![feature(maybe_uninit_slice)]
#![feature(maybe_uninit_write_slice)]
#![feature(noop_waker)]
#![feature(panic_can_unwind)]
#![feature(panic_internals)]
#![feature(pin_coerce_unsized_trait)]
186 changes: 171 additions & 15 deletions library/std/src/thread/mod.rs
@@ -164,17 +164,18 @@ use core::mem::MaybeUninit;

use crate::any::Any;
use crate::cell::UnsafeCell;
use crate::future::Future;
use crate::marker::PhantomData;
use crate::mem::{self, ManuallyDrop, forget};
use crate::num::NonZero;
use crate::pin::Pin;
use crate::sync::Arc;
use crate::sync::atomic::{AtomicUsize, Ordering};
use crate::sync::{Arc, Mutex, PoisonError};
use crate::sys::sync::Parker;
use crate::sys::thread as imp;
use crate::sys_common::{AsInner, IntoInner};
use crate::time::{Duration, Instant};
use crate::{env, fmt, io, panic, panicking, str};
use crate::{env, fmt, io, panic, panicking, str, task};

#[stable(feature = "scoped_threads", since = "1.63.0")]
mod scoped;
@@ -490,6 +491,7 @@ impl Builder {
let my_packet: Arc<Packet<'scope, T>> = Arc::new(Packet {
scope: scope_data,
result: UnsafeCell::new(None),
waker: Mutex::new(task::Waker::noop().clone()),
_marker: PhantomData,
});
let their_packet = my_packet.clone();
@@ -540,15 +542,35 @@ impl Builder {
let try_result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
crate::sys::backtrace::__rust_begin_short_backtrace(f)
}));

// Store the `Result` of the thread that the `JoinHandle` can retrieve.
//
// SAFETY: `their_packet` has been built just above and moved by the
// closure (it is an Arc<...>), and `my_packet` will be stored in the
// same `JoinInner` as this closure, meaning the mutation will be
// safe (it does not modify a value observed far away).
unsafe { *their_packet.result.get() = Some(try_result) };
// Here `their_packet` gets dropped, and if this is the last `Arc` for that packet that
// will call `decrement_num_running_threads` and therefore signal that this thread is
// done.

// Fetch the `Waker` from the packet; this is needed to support `.into_join_future()`.
// If unused, this just returns `Waker::noop()` which will do nothing.
let waker: task::Waker = {
let placeholder = task::Waker::noop().clone();
let mut guard = their_packet.waker.lock().unwrap_or_else(PoisonError::into_inner);
mem::replace(&mut *guard, placeholder)
};

// Here `their_packet` gets dropped, and if this is the last `Arc` for that packet
// (which happens if the `JoinHandle` has been dropped) that will call
// `decrement_num_running_threads` and therefore signal to the scope (if there is one)
// that this thread is done.
drop(their_packet);

// Now that we have become visibly “finished” by dropping the packet
// (`JoinInner::is_finished` will return true), we can use the `Waker` to signal
// any waiting `JoinFuture`. If instead we are being waited for by
// `JoinHandle::join()`, the actual platform thread termination will be the wakeup.
waker.wake();

// Here, the lifetime `'scope` can end. `main` keeps running for a bit
// after that before returning itself.
};
@@ -1192,8 +1214,6 @@ impl ThreadId {
}
}
} else {
use crate::sync::{Mutex, PoisonError};

static COUNTER: Mutex<u64> = Mutex::new(0);

let mut counter = COUNTER.lock().unwrap_or_else(PoisonError::into_inner);
@@ -1635,16 +1655,30 @@ impl fmt::Debug for Thread {
#[stable(feature = "rust1", since = "1.0.0")]
pub type Result<T> = crate::result::Result<T, Box<dyn Any + Send + 'static>>;

// This packet is used to communicate the return value between the spawned
// thread and the rest of the program. It is shared through an `Arc` and
// there's no need for a mutex here because synchronization happens with `join()`
// (the caller will never read this packet until the thread has exited).
//
// An Arc to the packet is stored into a `JoinInner` which in turns is placed
// in `JoinHandle`.
/// This packet is used to communicate the return value between the spawned
/// thread and the rest of the program. It is shared through an [`Arc`].
///
/// An Arc to the packet is stored into a [`JoinInner`] which in turn is placed
/// in [`JoinHandle`] or [`ScopedJoinHandle`].
struct Packet<'scope, T> {
/// Communication with the enclosing thread scope if there is one.
scope: Option<Arc<scoped::ScopeData>>,

/// Holds the return value.
///
/// Synchronization happens via reference counting: as long as the `Arc<Packet>`
/// has two or more references, this field is never read, and will only be written
/// once as the thread terminates. After that happens, either the packet is dropped,
/// or [`JoinInner::join()`] will `take()` the result value from here.
result: UnsafeCell<Option<Result<T>>>,

/// If a [`JoinFuture`] for this thread exists and has been polled,
/// this is the waker from that poll. If it does not exist or has not
/// been polled yet, this is [`task::Waker::noop()`].
// FIXME: This should be an `AtomicWaker` instead of a `Mutex`,
// to be cheaper and impossible to deadlock.
waker: Mutex<task::Waker>,

_marker: PhantomData<Option<&'scope scoped::ScopeData>>,
}

@@ -1698,6 +1732,10 @@ impl<'scope, T> JoinInner<'scope, T> {
self.native.join();
Arc::get_mut(&mut self.packet).unwrap().result.get_mut().take().unwrap()
}

fn is_finished(&self) -> bool {
Arc::strong_count(&self.packet) == 1
}
}

/// An owned permission to join on a thread (block on its termination).
@@ -1844,6 +1882,45 @@ impl<T> JoinHandle<T> {
self.0.join()
}

/// Returns a [`Future`] that resolves when the thread has finished.
///
/// Its [output](Future::Output) value is identical to that of [`JoinHandle::join()`];
/// this is the `async` equivalent of that blocking function.
///
/// If the returned future is dropped (cancelled), the thread will become *detached*;
/// there will be no way to observe or wait for the thread’s termination.
/// This is identical to the behavior of `JoinHandle` itself.
///
/// # Example
///
// FIXME: ideally we would actually run this example, with the help of a trivial async executor
/// ```no_run
/// #![feature(thread_join_future)]
/// use std::thread;
///
/// async fn do_some_heavy_tasks_in_parallel() -> thread::Result<()> {
/// let future_1 = thread::spawn(|| {
/// // ... do something ...
/// }).into_join_future();
/// let future_2 = thread::spawn(|| {
/// // ... do something else ...
/// }).into_join_future();
///
/// // Both threads have been started; now await the completion of both.
/// future_1.await?;
/// future_2.await?;
/// Ok(())
/// }
/// ```
#[unstable(feature = "thread_join_future", issue = "none")]
pub fn into_join_future(self) -> JoinFuture<'static, T> {
// The method is not named `into_future()` to avoid overlapping with the stable
// `IntoFuture::into_future()`. We're not implementing `IntoFuture` in order to
// keep this unstable and preserve the *option* of compatibly making this obey structured
// concurrency via an async-Drop that waits for the thread to end.
JoinFuture::new(self.0)
}

/// Checks if the associated thread has finished running its main function.
///
/// `is_finished` supports implementing a non-blocking join operation, by checking
@@ -1856,7 +1933,7 @@ impl<T> JoinHandle<T> {
/// to return quickly, without blocking for any significant amount of time.
#[stable(feature = "thread_is_running", since = "1.61.0")]
pub fn is_finished(&self) -> bool {
Arc::strong_count(&self.0.packet) == 1
self.0.is_finished()
}
}

@@ -1882,9 +1959,88 @@ impl<T> fmt::Debug for JoinHandle<T> {
fn _assert_sync_and_send() {
fn _assert_both<T: Send + Sync>() {}
_assert_both::<JoinHandle<()>>();
_assert_both::<JoinFuture<'static, ()>>();
_assert_both::<Thread>();
}

/// A [`Future`] that resolves when a thread has finished.
///
/// Its [output](Future::Output) value is identical to that of [`JoinHandle::join()`];
/// this is the `async` equivalent of that blocking function.
/// Obtain it by calling [`JoinHandle::into_join_future()`] or
/// [`ScopedJoinHandle::into_join_future()`].
///
/// If a `JoinFuture` is dropped (cancelled), and the thread does not belong to a [scope],
/// the associated thread will become *detached*;
/// there will be no way to observe or wait for the thread’s termination.
#[unstable(feature = "thread_join_future", issue = "none")]
pub struct JoinFuture<'scope, T>(Option<JoinInner<'scope, T>>);

impl<'scope, T> JoinFuture<'scope, T> {
fn new(inner: JoinInner<'scope, T>) -> Self {
Self(Some(inner))
}

/// Implements the “getting a result” part of joining/polling, without blocking or changing
/// the `Waker`. Part of the implementation of `poll()`.
///
/// If this returns `Some`, then `self.0` is now `None` and the future will panic
/// if polled again.
fn take_result(&mut self) -> Option<Result<T>> {
self.0.take_if(|i| i.is_finished()).map(JoinInner::join)
}
}

#[unstable(feature = "thread_join_future", issue = "none")]
impl<T> Future for JoinFuture<'_, T> {
type Output = Result<T>;
fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
if let Some(result) = self.take_result() {
return task::Poll::Ready(result);
}

// Update the `Waker` the thread should wake when it completes.
{
let Some(inner) = &mut self.0 else {
panic!("polled after complete");
};

let new_waker = cx.waker();

// Lock the mutex, and ignore the poison state because there are no meaningful ways
// the existing contents can be corrupted; they will be overwritten completely and the
// overwrite is atomic-in-the-database-sense.
let mut current_waker_guard =
inner.packet.waker.lock().unwrap_or_else(PoisonError::into_inner);

// Overwrite the waker. Note that we are executing the new waker’s clone and the old
// waker’s destructor; these could panic (which will merely poison the lock) or hang,
// which will hold the lock, but the most that can do is prevent the thread from
// exiting because it's trying to acquire `packet.waker`, which it won't do while
// holding any *other* locks (...unless the thread’s data includes a lock guard that
// the waker also wants).
if !new_waker.will_wake(&*current_waker_guard) {
*current_waker_guard = new_waker.clone();
}
}

// Check for completion again in case the thread finished while we were busy
// setting the waker, to prevent a lost wakeup in that case.
if let Some(result) = self.take_result() {
task::Poll::Ready(result)
} else {
task::Poll::Pending
}
}
}

#[unstable(feature = "thread_join_future", issue = "none")]
impl<T> fmt::Debug for JoinFuture<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("JoinHandle").finish_non_exhaustive()
}
}

/// Returns an estimate of the default amount of parallelism a program should use.
///
/// Parallelism is a resource. A given machine provides a certain capacity for
22 changes: 21 additions & 1 deletion library/std/src/thread/scoped.rs
@@ -313,6 +313,26 @@ impl<'scope, T> ScopedJoinHandle<'scope, T> {
self.0.join()
}

/// Returns a [`Future`] that resolves when the thread has finished.
///
/// Its [output] value is identical to that of [`ScopedJoinHandle::join()`];
/// this is the `async` equivalent of that blocking function.
///
/// Note that while this function allows waiting for a scoped thread from `async`
/// functions, the original [`scope()`] is still a blocking function which should
/// not be used in `async` functions.
///
/// [`Future`]: crate::future::Future
/// [output]: crate::future::Future::Output
#[unstable(feature = "thread_join_future", issue = "none")]
pub fn into_join_future(self) -> super::JoinFuture<'scope, T> {
// There is no `ScopedJoinFuture` because the only difference between `JoinHandle`
// and `ScopedJoinHandle` is that `JoinHandle` has no lifetime parameter, because
// it was introduced before scoped threads. `JoinFuture` is new enough that we don’t
// need to make two versions of it.
super::JoinFuture::new(self.0)
}

/// Checks if the associated thread has finished running its main function.
///
/// `is_finished` supports implementing a non-blocking join operation, by checking
@@ -325,7 +345,7 @@ impl<'scope, T> ScopedJoinHandle<'scope, T> {
/// to return quickly, without blocking for any significant amount of time.
#[stable(feature = "scoped_threads", since = "1.63.0")]
pub fn is_finished(&self) -> bool {
Arc::strong_count(&self.0.packet) == 1
self.0.is_finished()
}
}

48 changes: 47 additions & 1 deletion library/std/src/thread/tests.rs
@@ -1,12 +1,14 @@
use super::Builder;
use crate::any::Any;
use crate::assert_matches::assert_matches;
use crate::future::Future as _;
use crate::panic::panic_any;
use crate::sync::atomic::{AtomicBool, Ordering};
use crate::sync::mpsc::{Sender, channel};
use crate::sync::{Arc, Barrier};
use crate::thread::{self, Scope, ThreadId};
use crate::time::{Duration, Instant};
use crate::{mem, result};
use crate::{mem, result, task};

// !!! These tests are dangerous. If something is buggy, they will hang, !!!
// !!! instead of exiting cleanly. This might wedge the buildbots. !!!
@@ -410,3 +412,47 @@ fn test_minimal_thread_stack() {
assert_eq!(before, 0);
assert_eq!(COUNT.load(Ordering::Relaxed), 1);
}

fn join_future_test(scoped: bool) {
/// Simple `Waker` implementation.
/// If `std` ever gains a `block_on()`, we can consider replacing this with that.
struct MyWaker(Sender<()>);
impl task::Wake for MyWaker {
fn wake(self: Arc<Self>) {
_ = self.0.send(());
}
}

// Communication setup.
let (thread_delay_tx, thread_delay_rx) = channel();
let (waker_tx, waker_rx) = channel();
let waker = task::Waker::from(Arc::new(MyWaker(waker_tx)));
let ctx = &mut task::Context::from_waker(&waker);

thread::scope(|s| {
// Create the thread and the future under test
let thread_body = move || {
thread_delay_rx.recv().unwrap();
"hello"
};
let mut future = crate::pin::pin!(if scoped {
s.spawn(thread_body).into_join_future()
} else {
thread::spawn(thread_body).into_join_future()
});

// Actual test
assert_matches!(future.as_mut().poll(ctx), task::Poll::Pending);
thread_delay_tx.send(()).unwrap(); // Unblock the thread
waker_rx.recv().unwrap(); // Wait for waking (as an executor would)
assert_matches!(future.as_mut().poll(ctx), task::Poll::Ready(Ok("hello")));
});
}
#[test]
fn join_future_unscoped() {
join_future_test(false)
}
#[test]
fn join_future_scoped() {
join_future_test(true)
}
