From cbb3c155dd416f6e6a26be5e3b2ebc02853e4b62 Mon Sep 17 00:00:00 2001 From: Carl Lerche Date: Wed, 7 Jun 2023 08:47:58 -0700 Subject: [PATCH] rt: panic if `EnterGuard` dropped incorrect order (#5772) Calling `Handle::enter()` returns a `EnterGuard` value, which resets the thread-local context on drop. The drop implementation assumes that guards from nested `enter()` calls are dropped in reverse order. However, there is no static enforcement of this requirement. This patch checks that the guards are dropped in reverse order and panics otherwise. A future PR will deprecate `Handle::enter()` in favor of a method that takes a closure, ensuring the guard is dropped appropriately. --- tokio/src/runtime/context.rs | 5 +- tokio/src/runtime/context/current.rs | 70 ++++++++++++++++++++++------ tokio/src/runtime/handle.rs | 41 ++++++++++++++-- tokio/tests/rt_handle.rs | 67 ++++++++++++++++++++++++++ 4 files changed, 162 insertions(+), 21 deletions(-) create mode 100644 tokio/tests/rt_handle.rs diff --git a/tokio/src/runtime/context.rs b/tokio/src/runtime/context.rs index 008686d3fd4..5943e9aa977 100644 --- a/tokio/src/runtime/context.rs +++ b/tokio/src/runtime/context.rs @@ -21,7 +21,6 @@ cfg_rt! { use crate::runtime::{scheduler, task::Id}; - use std::cell::RefCell; use std::task::Waker; cfg_taskdump! { @@ -41,7 +40,7 @@ struct Context { /// Handle to the runtime scheduler running on the current thread. #[cfg(feature = "rt")] - handle: RefCell>, + current: current::HandleCell, /// Handle to the scheduler's internal "context" #[cfg(feature = "rt")] @@ -84,7 +83,7 @@ tokio_thread_local! { /// Tracks the current runtime handle to use when spawning, /// accessing drivers, etc... #[cfg(feature = "rt")] - handle: RefCell::new(None), + current: current::HandleCell::new(), /// Tracks the current scheduler internal context #[cfg(feature = "rt")] diff --git a/tokio/src/runtime/context/current.rs b/tokio/src/runtime/context/current.rs index a19a73224a7..c3dc5c89942 100644 --- a/tokio/src/runtime/context/current.rs +++ b/tokio/src/runtime/context/current.rs @@ -3,36 +3,42 @@ use super::{Context, CONTEXT}; use crate::runtime::{scheduler, TryCurrentError}; use crate::util::markers::SyncNotSend; +use std::cell::{Cell, RefCell}; use std::marker::PhantomData; #[derive(Debug)] #[must_use] pub(crate) struct SetCurrentGuard { - old_handle: Option, + // The previous handle + prev: Option, + + // The depth for this guard + depth: usize, + + // Don't let the type move across threads. _p: PhantomData, } +pub(super) struct HandleCell { + /// Current handle + handle: RefCell>, + + /// Tracks the number of nested calls to `try_set_current`. + depth: Cell, +} + /// Sets this [`Handle`] as the current active [`Handle`]. /// /// [`Handle`]: crate::runtime::scheduler::Handle pub(crate) fn try_set_current(handle: &scheduler::Handle) -> Option { - CONTEXT - .try_with(|ctx| { - let old_handle = ctx.handle.borrow_mut().replace(handle.clone()); - - SetCurrentGuard { - old_handle, - _p: PhantomData, - } - }) - .ok() + CONTEXT.try_with(|ctx| ctx.set_current(handle)).ok() } pub(crate) fn with_current(f: F) -> Result where F: FnOnce(&scheduler::Handle) -> R, { - match CONTEXT.try_with(|ctx| ctx.handle.borrow().as_ref().map(f)) { + match CONTEXT.try_with(|ctx| ctx.current.handle.borrow().as_ref().map(f)) { Ok(Some(ret)) => Ok(ret), Ok(None) => Err(TryCurrentError::new_no_context()), Err(_access_error) => Err(TryCurrentError::new_thread_local_destroyed()), @@ -41,19 +47,53 @@ where impl Context { pub(super) fn set_current(&self, handle: &scheduler::Handle) -> SetCurrentGuard { - let old_handle = self.handle.borrow_mut().replace(handle.clone()); + let old_handle = self.current.handle.borrow_mut().replace(handle.clone()); + let depth = self.current.depth.get(); + + if depth == usize::MAX { + panic!("reached max `enter` depth"); + } + + let depth = depth + 1; + self.current.depth.set(depth); SetCurrentGuard { - old_handle, + prev: old_handle, + depth, _p: PhantomData, } } } +impl HandleCell { + pub(super) const fn new() -> HandleCell { + HandleCell { + handle: RefCell::new(None), + depth: Cell::new(0), + } + } +} + impl Drop for SetCurrentGuard { fn drop(&mut self) { CONTEXT.with(|ctx| { - *ctx.handle.borrow_mut() = self.old_handle.take(); + let depth = ctx.current.depth.get(); + + if depth != self.depth { + if !std::thread::panicking() { + panic!( + "`EnterGuard` values dropped out of order. Guards returned by \ + `tokio::runtime::Handle::enter()` must be dropped in the reverse \ + order as they were acquired." + ); + } else { + // Just return... this will leave handles in a wonky state though... + return; + } + } + + *ctx.current.handle.borrow_mut() = self.prev.take(); + ctx.current.depth.set(depth - 1); }); } } diff --git a/tokio/src/runtime/handle.rs b/tokio/src/runtime/handle.rs index 1a292530269..0951d8a3736 100644 --- a/tokio/src/runtime/handle.rs +++ b/tokio/src/runtime/handle.rs @@ -35,9 +35,44 @@ pub struct EnterGuard<'a> { impl Handle { /// Enters the runtime context. This allows you to construct types that must - /// have an executor available on creation such as [`Sleep`] or [`TcpStream`]. - /// It will also allow you to call methods such as [`tokio::spawn`] and [`Handle::current`] - /// without panicking. + /// have an executor available on creation such as [`Sleep`] or + /// [`TcpStream`]. It will also allow you to call methods such as + /// [`tokio::spawn`] and [`Handle::current`] without panicking. + /// + /// # Panics + /// + /// When calling `Handle::enter` multiple times, the returned guards + /// **must** be dropped in the reverse order that they were acquired. + /// Failure to do so will result in a panic and possible memory leaks. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Runtime; + /// + /// let rt = Runtime::new().unwrap(); + /// + /// let _guard = rt.enter(); + /// tokio::spawn(async { + /// println!("Hello world!"); + /// }); + /// ``` + /// + /// Do **not** do the following, this shows a scenario that will result in a + /// panic and possible memory leak. + /// + /// ```should_panic + /// use tokio::runtime::Runtime; + /// + /// let rt1 = Runtime::new().unwrap(); + /// let rt2 = Runtime::new().unwrap(); + /// + /// let enter1 = rt1.enter(); + /// let enter2 = rt2.enter(); + /// + /// drop(enter1); + /// drop(enter2); + /// ``` /// /// [`Sleep`]: struct@crate::time::Sleep /// [`TcpStream`]: struct@crate::net::TcpStream diff --git a/tokio/tests/rt_handle.rs b/tokio/tests/rt_handle.rs new file mode 100644 index 00000000000..34c99cdaead --- /dev/null +++ b/tokio/tests/rt_handle.rs @@ -0,0 +1,67 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use tokio::runtime::Runtime; + +#[test] +fn basic_enter() { + let rt1 = rt(); + let rt2 = rt(); + + let enter1 = rt1.enter(); + let enter2 = rt2.enter(); + + drop(enter2); + drop(enter1); +} + +#[test] +#[should_panic] +fn interleave_enter_different_rt() { + let rt1 = rt(); + let rt2 = rt(); + + let enter1 = rt1.enter(); + let enter2 = rt2.enter(); + + drop(enter1); + drop(enter2); +} + +#[test] +#[should_panic] +fn interleave_enter_same_rt() { + let rt1 = rt(); + + let _enter1 = rt1.enter(); + let enter2 = rt1.enter(); + let enter3 = rt1.enter(); + + drop(enter2); + drop(enter3); +} + +#[test] +#[cfg(not(tokio_wasi))] +fn interleave_then_enter() { + let _ = std::panic::catch_unwind(|| { + let rt1 = rt(); + let rt2 = rt(); + + let enter1 = rt1.enter(); + let enter2 = rt2.enter(); + + drop(enter1); + drop(enter2); + }); + + // Can still enter + let rt3 = rt(); + let _enter = rt3.enter(); +} + +fn rt() -> Runtime { + tokio::runtime::Builder::new_current_thread() + .build() + .unwrap() +}