Skip to content

Commit

Permalink
rt: panic if EnterGuard dropped incorrect order (#5772)
Browse files Browse the repository at this point in the history
Calling `Handle::enter()` returns a `EnterGuard` value, which resets the
thread-local context on drop. The drop implementation assumes that
guards from nested `enter()` calls are dropped in reverse order.
However, there is no static enforcement of this requirement.

This patch checks that the guards are dropped in reverse order and
panics otherwise. A future PR will deprecate `Handle::enter()` in favor
of a method that takes a closure, ensuring the guard is dropped
appropriately.
  • Loading branch information
carllerche authored Jun 7, 2023
1 parent 038c4d9 commit cbb3c15
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 21 deletions.
5 changes: 2 additions & 3 deletions tokio/src/runtime/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ cfg_rt! {

use crate::runtime::{scheduler, task::Id};

use std::cell::RefCell;
use std::task::Waker;

cfg_taskdump! {
Expand All @@ -41,7 +40,7 @@ struct Context {

/// Handle to the runtime scheduler running on the current thread.
#[cfg(feature = "rt")]
handle: RefCell<Option<scheduler::Handle>>,
current: current::HandleCell,

/// Handle to the scheduler's internal "context"
#[cfg(feature = "rt")]
Expand Down Expand Up @@ -84,7 +83,7 @@ tokio_thread_local! {
/// Tracks the current runtime handle to use when spawning,
/// accessing drivers, etc...
#[cfg(feature = "rt")]
handle: RefCell::new(None),
current: current::HandleCell::new(),

/// Tracks the current scheduler internal context
#[cfg(feature = "rt")]
Expand Down
70 changes: 55 additions & 15 deletions tokio/src/runtime/context/current.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,36 +3,42 @@ use super::{Context, CONTEXT};
use crate::runtime::{scheduler, TryCurrentError};
use crate::util::markers::SyncNotSend;

use std::cell::{Cell, RefCell};
use std::marker::PhantomData;

#[derive(Debug)]
#[must_use]
pub(crate) struct SetCurrentGuard {
old_handle: Option<scheduler::Handle>,
// The previous handle
prev: Option<scheduler::Handle>,

// The depth for this guard
depth: usize,

// Don't let the type move across threads.
_p: PhantomData<SyncNotSend>,
}

pub(super) struct HandleCell {
/// Current handle
handle: RefCell<Option<scheduler::Handle>>,

/// Tracks the number of nested calls to `try_set_current`.
depth: Cell<usize>,
}

/// Sets this [`Handle`] as the current active [`Handle`].
///
/// [`Handle`]: crate::runtime::scheduler::Handle
pub(crate) fn try_set_current(handle: &scheduler::Handle) -> Option<SetCurrentGuard> {
CONTEXT
.try_with(|ctx| {
let old_handle = ctx.handle.borrow_mut().replace(handle.clone());

SetCurrentGuard {
old_handle,
_p: PhantomData,
}
})
.ok()
CONTEXT.try_with(|ctx| ctx.set_current(handle)).ok()
}

pub(crate) fn with_current<F, R>(f: F) -> Result<R, TryCurrentError>
where
F: FnOnce(&scheduler::Handle) -> R,
{
match CONTEXT.try_with(|ctx| ctx.handle.borrow().as_ref().map(f)) {
match CONTEXT.try_with(|ctx| ctx.current.handle.borrow().as_ref().map(f)) {
Ok(Some(ret)) => Ok(ret),
Ok(None) => Err(TryCurrentError::new_no_context()),
Err(_access_error) => Err(TryCurrentError::new_thread_local_destroyed()),
Expand All @@ -41,19 +47,53 @@ where

impl Context {
pub(super) fn set_current(&self, handle: &scheduler::Handle) -> SetCurrentGuard {
let old_handle = self.handle.borrow_mut().replace(handle.clone());
let old_handle = self.current.handle.borrow_mut().replace(handle.clone());
let depth = self.current.depth.get();

if depth == usize::MAX {
panic!("reached max `enter` depth");
}

let depth = depth + 1;
self.current.depth.set(depth);

SetCurrentGuard {
old_handle,
prev: old_handle,
depth,
_p: PhantomData,
}
}
}

impl HandleCell {
pub(super) const fn new() -> HandleCell {
HandleCell {
handle: RefCell::new(None),
depth: Cell::new(0),
}
}
}

impl Drop for SetCurrentGuard {
fn drop(&mut self) {
CONTEXT.with(|ctx| {
*ctx.handle.borrow_mut() = self.old_handle.take();
let depth = ctx.current.depth.get();

if depth != self.depth {
if !std::thread::panicking() {
panic!(
"`EnterGuard` values dropped out of order. Guards returned by \
`tokio::runtime::Handle::enter()` must be dropped in the reverse \
order as they were acquired."
);
} else {
// Just return... this will leave handles in a wonky state though...
return;
}
}

*ctx.current.handle.borrow_mut() = self.prev.take();
ctx.current.depth.set(depth - 1);
});
}
}
41 changes: 38 additions & 3 deletions tokio/src/runtime/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,44 @@ pub struct EnterGuard<'a> {

impl Handle {
/// Enters the runtime context. This allows you to construct types that must
/// have an executor available on creation such as [`Sleep`] or [`TcpStream`].
/// It will also allow you to call methods such as [`tokio::spawn`] and [`Handle::current`]
/// without panicking.
/// have an executor available on creation such as [`Sleep`] or
/// [`TcpStream`]. It will also allow you to call methods such as
/// [`tokio::spawn`] and [`Handle::current`] without panicking.
///
/// # Panics
///
/// When calling `Handle::enter` multiple times, the returned guards
/// **must** be dropped in the reverse order that they were acquired.
/// Failure to do so will result in a panic and possible memory leaks.
///
/// # Examples
///
/// ```
/// use tokio::runtime::Runtime;
///
/// let rt = Runtime::new().unwrap();
///
/// let _guard = rt.enter();
/// tokio::spawn(async {
/// println!("Hello world!");
/// });
/// ```
///
/// Do **not** do the following, this shows a scenario that will result in a
/// panic and possible memory leak.
///
/// ```should_panic
/// use tokio::runtime::Runtime;
///
/// let rt1 = Runtime::new().unwrap();
/// let rt2 = Runtime::new().unwrap();
///
/// let enter1 = rt1.enter();
/// let enter2 = rt2.enter();
///
/// drop(enter1);
/// drop(enter2);
/// ```
///
/// [`Sleep`]: struct@crate::time::Sleep
/// [`TcpStream`]: struct@crate::net::TcpStream
Expand Down
67 changes: 67 additions & 0 deletions tokio/tests/rt_handle.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]

use tokio::runtime::Runtime;

#[test]
fn basic_enter() {
let rt1 = rt();
let rt2 = rt();

let enter1 = rt1.enter();
let enter2 = rt2.enter();

drop(enter2);
drop(enter1);
}

#[test]
#[should_panic]
fn interleave_enter_different_rt() {
let rt1 = rt();
let rt2 = rt();

let enter1 = rt1.enter();
let enter2 = rt2.enter();

drop(enter1);
drop(enter2);
}

#[test]
#[should_panic]
fn interleave_enter_same_rt() {
let rt1 = rt();

let _enter1 = rt1.enter();
let enter2 = rt1.enter();
let enter3 = rt1.enter();

drop(enter2);
drop(enter3);
}

#[test]
#[cfg(not(tokio_wasi))]
fn interleave_then_enter() {
let _ = std::panic::catch_unwind(|| {
let rt1 = rt();
let rt2 = rt();

let enter1 = rt1.enter();
let enter2 = rt2.enter();

drop(enter1);
drop(enter2);
});

// Can still enter
let rt3 = rt();
let _enter = rt3.enter();
}

fn rt() -> Runtime {
tokio::runtime::Builder::new_current_thread()
.build()
.unwrap()
}

0 comments on commit cbb3c15

Please sign in to comment.