Skip to content

Commit

Permalink
rt: switch enter to an RAII guard (#2954)
Browse files Browse the repository at this point in the history
  • Loading branch information
carllerche authored Oct 13, 2020
1 parent a249421 commit 00b6127
Show file tree
Hide file tree
Showing 11 changed files with 95 additions and 98 deletions.
3 changes: 2 additions & 1 deletion tokio-util/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ impl<F: Future> Future for TokioContext<'_, F> {
let handle = me.handle;
let fut = me.inner;

handle.enter(|| fut.poll(cx))
let _enter = handle.enter();
fut.poll(cx)
}
}

Expand Down
8 changes: 3 additions & 5 deletions tokio/src/runtime/blocking/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,9 @@ impl Spawner {
builder
.spawn(move || {
// Only the reference should be moved into the closure
let rt = &rt;
rt.enter(move || {
rt.blocking_spawner.inner.run(worker_id);
drop(shutdown_tx);
})
let _enter = crate::runtime::context::enter(rt.clone());
rt.blocking_spawner.inner.run(worker_id);
drop(shutdown_tx);
})
.unwrap()
}
Expand Down
3 changes: 2 additions & 1 deletion tokio/src/runtime/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,8 @@ cfg_rt_multi_thread! {
};

// Spawn the thread pool workers
handle.enter(|| launch.launch());
let _enter = crate::runtime::context::enter(handle.clone());
launch.launch();

Ok(Runtime {
kind: Kind::ThreadPool(scheduler),
Expand Down
32 changes: 14 additions & 18 deletions tokio/src/runtime/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,24 +60,20 @@ cfg_rt! {
/// Set this [`Handle`] as the current active [`Handle`].
///
/// [`Handle`]: Handle
pub(crate) fn enter<F, R>(new: Handle, f: F) -> R
where
F: FnOnce() -> R,
{
struct DropGuard(Option<Handle>);

impl Drop for DropGuard {
fn drop(&mut self) {
CONTEXT.with(|ctx| {
*ctx.borrow_mut() = self.0.take();
});
}
}

let _guard = CONTEXT.with(|ctx| {
pub(crate) fn enter(new: Handle) -> EnterGuard {
CONTEXT.with(|ctx| {
let old = ctx.borrow_mut().replace(new);
DropGuard(old)
});
EnterGuard(old)
})
}

#[derive(Debug)]
pub(crate) struct EnterGuard(Option<Handle>);

f()
impl Drop for EnterGuard {
fn drop(&mut self) {
CONTEXT.with(|ctx| {
*ctx.borrow_mut() = self.0.take();
});
}
}
20 changes: 10 additions & 10 deletions tokio/src/runtime/handle.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::runtime::{blocking, context, driver, Spawner};
use crate::runtime::{blocking, driver, Spawner};

/// Handle to the runtime.
///
Expand Down Expand Up @@ -27,13 +27,13 @@ pub(crate) struct Handle {
}

impl Handle {
/// Enter the runtime context. This allows you to construct types that must
/// have an executor available on creation such as [`Sleep`] or [`TcpStream`].
/// It will also allow you to call methods such as [`tokio::spawn`].
pub(crate) fn enter<F, R>(&self, f: F) -> R
where
F: FnOnce() -> R,
{
context::enter(self.clone(), f)
}
// /// Enter the runtime context. This allows you to construct types that must
// /// have an executor available on creation such as [`Sleep`] or [`TcpStream`].
// /// It will also allow you to call methods such as [`tokio::spawn`].
// pub(crate) fn enter<F, R>(&self, f: F) -> R
// where
// F: FnOnce() -> R,
// {
// context::enter(self.clone(), f)
// }
}
60 changes: 38 additions & 22 deletions tokio/src/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,16 @@ cfg_rt! {
blocking_pool: BlockingPool,
}

/// Runtime context guard.
///
/// Returned by [`Runtime::enter`], the context guard exits the runtime
/// context on drop.
#[derive(Debug)]
pub struct EnterGuard<'a> {
rt: &'a Runtime,
guard: context::EnterGuard,
}

/// The runtime executor is either a thread-pool or a current-thread executor.
#[derive(Debug)]
enum Kind {
Expand Down Expand Up @@ -356,25 +366,26 @@ cfg_rt! {
}
}

/// Run a future to completion on the Tokio runtime. This is the runtime's
/// entry point.
/// Run a future to completion on the Tokio runtime. This is the
/// runtime's entry point.
///
/// This runs the given future on the runtime, blocking until it is
/// complete, and yielding its resolved result. Any tasks or timers which
/// the future spawns internally will be executed on the runtime.
/// complete, and yielding its resolved result. Any tasks or timers
/// which the future spawns internally will be executed on the runtime.
///
/// When this runtime is configured with `core_threads = 0`, only the first call
/// to `block_on` will run the IO and timer drivers. Calls to other methods _before_ the first
/// `block_on` completes will just hook into the driver running on the thread
/// that first called `block_on`. This means that the driver may be passed
/// from thread to thread by the user between calls to `block_on`.
/// When this runtime is configured with `core_threads = 0`, only the
/// first call to `block_on` will run the IO and timer drivers. Calls to
/// other methods _before_ the first `block_on` completes will just hook
/// into the driver running on the thread that first called `block_on`.
/// This means that the driver may be passed from thread to thread by
/// the user between calls to `block_on`.
///
/// This method may not be called from an asynchronous context.
///
/// # Panics
///
/// This function panics if the provided future panics, or if called within an
/// asynchronous execution context.
/// This function panics if the provided future panics, or if called
/// within an asynchronous execution context.
///
/// # Examples
///
Expand All @@ -392,17 +403,21 @@ cfg_rt! {
///
/// [handle]: fn@Handle::block_on
pub fn block_on<F: Future>(&self, future: F) -> F::Output {
self.handle.enter(|| match &self.kind {
let _enter = self.enter();

match &self.kind {
#[cfg(feature = "rt")]
Kind::CurrentThread(exec) => exec.block_on(future),
#[cfg(feature = "rt-multi-thread")]
Kind::ThreadPool(exec) => exec.block_on(future),
})
}
}

/// Enter the runtime context. This allows you to construct types that must
/// have an executor available on creation such as [`Sleep`] or [`TcpStream`].
/// It will also allow you to call methods such as [`tokio::spawn`].
/// Enter the runtime context.
///
/// This allows you to construct types that must have an executor
/// available on creation such as [`Sleep`] or [`TcpStream`]. It will
/// also allow you to call methods such as [`tokio::spawn`].
///
/// [`Sleep`]: struct@crate::time::Sleep
/// [`TcpStream`]: struct@crate::net::TcpStream
Expand All @@ -426,14 +441,15 @@ cfg_rt! {
/// let s = "Hello World!".to_string();
///
/// // By entering the context, we tie `tokio::spawn` to this executor.
/// rt.enter(|| function_that_spawns(s));
/// let _guard = rt.enter();
/// function_that_spawns(s);
/// }
/// ```
pub fn enter<F, R>(&self, f: F) -> R
where
F: FnOnce() -> R,
{
self.handle.enter(f)
pub fn enter(&self) -> EnterGuard<'_> {
EnterGuard {
rt: self,
guard: context::enter(self.handle.clone()),
}
}

/// Shutdown the runtime, waiting for at most `duration` for all spawned
Expand Down
5 changes: 3 additions & 2 deletions tokio/src/runtime/tests/loom_blocking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@ fn blocking_shutdown() {
let v = Arc::new(());

let rt = mk_runtime(1);
rt.enter(|| {
{
let _enter = rt.enter();
for _ in 0..2 {
let v = v.clone();
crate::task::spawn_blocking(move || {
assert!(1 < Arc::strong_count(&v));
});
}
});
}

drop(rt);
assert_eq!(1, Arc::strong_count(&v));
Expand Down
21 changes: 10 additions & 11 deletions tokio/src/signal/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,21 +253,20 @@ mod tests {
#[test]
fn ctrl_c() {
let rt = rt();
let _enter = rt.enter();

rt.enter(|| {
let mut ctrl_c = task::spawn(crate::signal::ctrl_c());
let mut ctrl_c = task::spawn(crate::signal::ctrl_c());

assert_pending!(ctrl_c.poll());
assert_pending!(ctrl_c.poll());

// Windows doesn't have a good programmatic way of sending events
// like sending signals on Unix, so we'll stub out the actual OS
// integration and test that our handling works.
unsafe {
super::handler(CTRL_C_EVENT);
}
// Windows doesn't have a good programmatic way of sending events
// like sending signals on Unix, so we'll stub out the actual OS
// integration and test that our handling works.
unsafe {
super::handler(CTRL_C_EVENT);
}

assert_ready_ok!(ctrl_c.poll());
});
assert_ready_ok!(ctrl_c.poll());
}

#[test]
Expand Down
9 changes: 4 additions & 5 deletions tokio/tests/io_driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,10 @@ fn test_drop_on_notify() {
}));

{
rt.enter(|| {
let waker = waker_ref(&task);
let mut cx = Context::from_waker(&waker);
assert_pending!(task.future.lock().unwrap().as_mut().poll(&mut cx));
});
let _enter = rt.enter();
let waker = waker_ref(&task);
let mut cx = Context::from_waker(&waker);
assert_pending!(task.future.lock().unwrap().as_mut().poll(&mut cx));
}

// Get the address
Expand Down
10 changes: 6 additions & 4 deletions tokio/tests/io_driver_drop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ use tokio_test::{assert_err, assert_pending, assert_ready, task};
fn tcp_doesnt_block() {
let rt = rt();

let listener = rt.enter(|| {
let listener = {
let _enter = rt.enter();
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
TcpListener::from_std(listener).unwrap()
});
};

drop(rt);

Expand All @@ -27,10 +28,11 @@ fn tcp_doesnt_block() {
fn drop_wakes() {
let rt = rt();

let listener = rt.enter(|| {
let listener = {
let _enter = rt.enter();
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
TcpListener::from_std(listener).unwrap()
});
};

let mut task = task::spawn(async move {
assert_err!(listener.accept().await);
Expand Down
22 changes: 3 additions & 19 deletions tokio/tests/rt_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -554,23 +554,6 @@ rt_test! {
});
}

#[test]
fn spawn_blocking_after_shutdown() {
let rt = rt();
let handle = rt.clone();

// Shutdown
drop(rt);

handle.enter(|| {
let res = task::spawn_blocking(|| unreachable!());

// Avoid using a tokio runtime
let out = futures::executor::block_on(res);
assert!(out.is_err());
});
}

#[test]
fn always_active_parker() {
// This test it to show that we will always have
Expand Down Expand Up @@ -713,9 +696,10 @@ rt_test! {
#[test]
fn enter_and_spawn() {
let rt = rt();
let handle = rt.enter(|| {
let handle = {
let _enter = rt.enter();
tokio::spawn(async {})
});
};

assert_ok!(rt.block_on(handle));
}
Expand Down

0 comments on commit 00b6127

Please sign in to comment.