From 512e9decfb683d22f4a145459142542caa0894c9 Mon Sep 17 00:00:00 2001 From: Noah Kennedy Date: Sat, 12 Oct 2024 10:39:23 -0500 Subject: [PATCH] rt: add LocalRuntime (#6808) This change adds LocalRuntime, a new unstable runtime type which cannot be transferred across thread boundaries and supports spawn_local when called from the thread which owns the runtime. The initial set of docs for this are iffy. Documentation is absent right now at the module level, with the docs for the LocalRuntime struct itself being somewhat duplicative of those for the `Runtime` type. This can be addressed later as stabilization nears. This API has a few interesting implementation details: - because it was considered beneficial to reuse the same Handle as the normal runtime, it is possible to call spawn_local from a runtime context while on a different thread from the one which drives the runtime and owns it. This forces us to check the thread ID before attempting a local spawn. - An empty LocalOptions struct is passed into the build_local method in order to build the runtime. This will eventually have stuff in it like hooks. Relates to #6739. --- tokio/src/runtime/builder.rs | 81 +++- tokio/src/runtime/handle.rs | 29 +- tokio/src/runtime/local_runtime/mod.rs | 7 + tokio/src/runtime/local_runtime/options.rs | 12 + tokio/src/runtime/local_runtime/runtime.rs | 393 ++++++++++++++++++ tokio/src/runtime/mod.rs | 3 + .../runtime/scheduler/current_thread/mod.rs | 35 ++ tokio/src/runtime/scheduler/mod.rs | 43 ++ tokio/src/runtime/task/list.rs | 20 + tokio/src/task/local.rs | 52 ++- tokio/src/util/trace.rs | 1 + tokio/tests/rt_local.rs | 100 +++++ 12 files changed, 760 insertions(+), 16 deletions(-) create mode 100644 tokio/src/runtime/local_runtime/mod.rs create mode 100644 tokio/src/runtime/local_runtime/options.rs create mode 100644 tokio/src/runtime/local_runtime/runtime.rs create mode 100644 tokio/tests/rt_local.rs diff --git a/tokio/src/runtime/builder.rs b/tokio/src/runtime/builder.rs index b5bf35d69b4..4d35120b1f9 100644 --- a/tokio/src/runtime/builder.rs +++ b/tokio/src/runtime/builder.rs @@ -1,13 +1,16 @@ #![cfg_attr(loom, allow(unused_imports))] use crate::runtime::handle::Handle; -#[cfg(tokio_unstable)] -use crate::runtime::TaskMeta; use crate::runtime::{blocking, driver, Callback, HistogramBuilder, Runtime, TaskCallback}; +#[cfg(tokio_unstable)] +use crate::runtime::{LocalOptions, LocalRuntime, TaskMeta}; use crate::util::rand::{RngSeed, RngSeedGenerator}; +use crate::runtime::blocking::BlockingPool; +use crate::runtime::scheduler::CurrentThread; use std::fmt; use std::io; +use std::thread::ThreadId; use std::time::Duration; /// Builds Tokio Runtime with custom configuration values. @@ -800,6 +803,37 @@ impl Builder { } } + /// Creates the configured `LocalRuntime`. + /// + /// The returned `LocalRuntime` instance is ready to spawn tasks. + /// + /// # Panics + /// This will panic if `current_thread` is not the selected runtime flavor. + /// All other runtime flavors are unsupported by [`LocalRuntime`]. + /// + /// [`LocalRuntime`]: [crate::runtime::LocalRuntime] + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Builder; + /// + /// let rt = Builder::new_current_thread().build_local(&mut Default::default()).unwrap(); + /// + /// rt.block_on(async { + /// println!("Hello from the Tokio runtime"); + /// }); + /// ``` + #[allow(unused_variables, unreachable_patterns)] + #[cfg(tokio_unstable)] + #[cfg_attr(docsrs, doc(cfg(tokio_unstable)))] + pub fn build_local(&mut self, options: &LocalOptions) -> io::Result { + match &self.kind { + Kind::CurrentThread => self.build_current_thread_local_runtime(), + _ => panic!("Only current_thread is supported when building a local runtime"), + } + } + fn get_cfg(&self, workers: usize) -> driver::Cfg { driver::Cfg { enable_pause_time: match self.kind { @@ -1191,8 +1225,40 @@ impl Builder { } fn build_current_thread_runtime(&mut self) -> io::Result { - use crate::runtime::scheduler::{self, CurrentThread}; - use crate::runtime::{runtime::Scheduler, Config}; + use crate::runtime::runtime::Scheduler; + + let (scheduler, handle, blocking_pool) = + self.build_current_thread_runtime_components(None)?; + + Ok(Runtime::from_parts( + Scheduler::CurrentThread(scheduler), + handle, + blocking_pool, + )) + } + + #[cfg(tokio_unstable)] + fn build_current_thread_local_runtime(&mut self) -> io::Result { + use crate::runtime::local_runtime::LocalRuntimeScheduler; + + let tid = std::thread::current().id(); + + let (scheduler, handle, blocking_pool) = + self.build_current_thread_runtime_components(Some(tid))?; + + Ok(LocalRuntime::from_parts( + LocalRuntimeScheduler::CurrentThread(scheduler), + handle, + blocking_pool, + )) + } + + fn build_current_thread_runtime_components( + &mut self, + local_tid: Option, + ) -> io::Result<(CurrentThread, Handle, BlockingPool)> { + use crate::runtime::scheduler; + use crate::runtime::Config; let (driver, driver_handle) = driver::Driver::new(self.get_cfg(1))?; @@ -1227,17 +1293,14 @@ impl Builder { seed_generator: seed_generator_1, metrics_poll_count_histogram: self.metrics_poll_count_histogram_builder(), }, + local_tid, ); let handle = Handle { inner: scheduler::Handle::CurrentThread(handle), }; - Ok(Runtime::from_parts( - Scheduler::CurrentThread(scheduler), - handle, - blocking_pool, - )) + Ok((scheduler, handle, blocking_pool)) } fn metrics_poll_count_histogram_builder(&self) -> Option { diff --git a/tokio/src/runtime/handle.rs b/tokio/src/runtime/handle.rs index 9026e8773a0..752640d75bd 100644 --- a/tokio/src/runtime/handle.rs +++ b/tokio/src/runtime/handle.rs @@ -250,8 +250,8 @@ impl Handle { /// # Panics /// /// This function panics if the provided future panics, if called within an - /// asynchronous execution context, or if a timer future is executed on a - /// runtime that has been shut down. + /// asynchronous execution context, or if a timer future is executed on a runtime that has been + /// shut down. /// /// # Examples /// @@ -348,6 +348,31 @@ impl Handle { self.inner.spawn(future, id) } + #[track_caller] + #[allow(dead_code)] + pub(crate) unsafe fn spawn_local_named( + &self, + future: F, + _meta: SpawnMeta<'_>, + ) -> JoinHandle + where + F: Future + 'static, + F::Output: 'static, + { + let id = crate::runtime::task::Id::next(); + #[cfg(all( + tokio_unstable, + tokio_taskdump, + feature = "rt", + target_os = "linux", + any(target_arch = "aarch64", target_arch = "x86", target_arch = "x86_64") + ))] + let future = super::task::trace::Trace::root(future); + #[cfg(all(tokio_unstable, feature = "tracing"))] + let future = crate::util::trace::task(future, "task", _meta, id.as_u64()); + self.inner.spawn_local(future, id) + } + /// Returns the flavor of the current `Runtime`. /// /// # Examples diff --git a/tokio/src/runtime/local_runtime/mod.rs b/tokio/src/runtime/local_runtime/mod.rs new file mode 100644 index 00000000000..1ea7693f292 --- /dev/null +++ b/tokio/src/runtime/local_runtime/mod.rs @@ -0,0 +1,7 @@ +mod runtime; + +mod options; + +pub use options::LocalOptions; +pub use runtime::LocalRuntime; +pub(super) use runtime::LocalRuntimeScheduler; diff --git a/tokio/src/runtime/local_runtime/options.rs b/tokio/src/runtime/local_runtime/options.rs new file mode 100644 index 00000000000..ed25d9ccd44 --- /dev/null +++ b/tokio/src/runtime/local_runtime/options.rs @@ -0,0 +1,12 @@ +use std::marker::PhantomData; + +/// `LocalRuntime`-only config options +/// +/// Currently, there are no such options, but in the future, things like `!Send + !Sync` hooks may +/// be added. +#[derive(Default, Debug)] +#[non_exhaustive] +pub struct LocalOptions { + /// Marker used to make this !Send and !Sync. + _phantom: PhantomData<*mut u8>, +} diff --git a/tokio/src/runtime/local_runtime/runtime.rs b/tokio/src/runtime/local_runtime/runtime.rs new file mode 100644 index 00000000000..0f2b944e4eb --- /dev/null +++ b/tokio/src/runtime/local_runtime/runtime.rs @@ -0,0 +1,393 @@ +#![allow(irrefutable_let_patterns)] + +use crate::runtime::blocking::BlockingPool; +use crate::runtime::scheduler::CurrentThread; +use crate::runtime::{context, Builder, EnterGuard, Handle, BOX_FUTURE_THRESHOLD}; +use crate::task::JoinHandle; + +use crate::util::trace::SpawnMeta; +use std::future::Future; +use std::marker::PhantomData; +use std::mem; +use std::time::Duration; + +/// A local Tokio runtime. +/// +/// This runtime is capable of driving tasks which are not `Send + Sync` without the use of a +/// `LocalSet`, and thus supports `spawn_local` without the need for a `LocalSet` context. +/// +/// This runtime cannot be moved between threads or driven from different threads. +/// +/// This runtime is incompatible with `LocalSet`. You should not attempt to drive a `LocalSet` within a +/// `LocalRuntime`. +/// +/// Currently, this runtime supports one flavor, which is internally identical to `current_thread`, +/// save for the aforementioned differences related to `spawn_local`. +/// +/// For more general information on how to use runtimes, see the [module] docs. +/// +/// [runtime]: crate::runtime::Runtime +/// [module]: crate::runtime +#[derive(Debug)] +#[cfg_attr(docsrs, doc(cfg(tokio_unstable)))] +pub struct LocalRuntime { + /// Task scheduler + scheduler: LocalRuntimeScheduler, + + /// Handle to runtime, also contains driver handles + handle: Handle, + + /// Blocking pool handle, used to signal shutdown + blocking_pool: BlockingPool, + + /// Marker used to make this !Send and !Sync. + _phantom: PhantomData<*mut u8>, +} + +/// The runtime scheduler is always a `current_thread` scheduler right now. +#[derive(Debug)] +pub(crate) enum LocalRuntimeScheduler { + /// Execute all tasks on the current-thread. + CurrentThread(CurrentThread), +} + +impl LocalRuntime { + pub(crate) fn from_parts( + scheduler: LocalRuntimeScheduler, + handle: Handle, + blocking_pool: BlockingPool, + ) -> LocalRuntime { + LocalRuntime { + scheduler, + handle, + blocking_pool, + _phantom: Default::default(), + } + } + + /// Creates a new local runtime instance with default configuration values. + /// + /// This results in the scheduler, I/O driver, and time driver being + /// initialized. + /// + /// When a more complex configuration is necessary, the [runtime builder] may be used. + /// + /// See [module level][mod] documentation for more details. + /// + /// # Examples + /// + /// Creating a new `LocalRuntime` with default configuration values. + /// + /// ``` + /// use tokio::runtime::LocalRuntime; + /// + /// let rt = LocalRuntime::new() + /// .unwrap(); + /// + /// // Use the runtime... + /// ``` + /// + /// [mod]: crate::runtime + /// [runtime builder]: crate::runtime::Builder + pub fn new() -> std::io::Result { + Builder::new_current_thread() + .enable_all() + .build_local(&Default::default()) + } + + /// Returns a handle to the runtime's spawner. + /// + /// The returned handle can be used to spawn tasks that run on this runtime, and can + /// be cloned to allow moving the `Handle` to other threads. + /// + /// As the handle can be sent to other threads, it can only be used to spawn tasks that are `Send`. + /// + /// Calling [`Handle::block_on`] on a handle to a `LocalRuntime` is error-prone. + /// Refer to the documentation of [`Handle::block_on`] for more. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::LocalRuntime; + /// + /// let rt = LocalRuntime::new() + /// .unwrap(); + /// + /// let handle = rt.handle(); + /// + /// // Use the handle... + /// ``` + pub fn handle(&self) -> &Handle { + &self.handle + } + + /// Spawns a task on the runtime. + /// + /// This is analogous to the [`spawn`] method on the standard [`Runtime`], but works even if the task is not thread-safe. + /// + /// [`spawn`]: crate::runtime::Runtime::spawn + /// [`Runtime`]: crate::runtime::Runtime + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::LocalRuntime; + /// + /// # fn dox() { + /// // Create the runtime + /// let rt = LocalRuntime::new().unwrap(); + /// + /// // Spawn a future onto the runtime + /// rt.spawn_local(async { + /// println!("now running on a worker thread"); + /// }); + /// # } + /// ``` + #[track_caller] + pub fn spawn_local(&self, future: F) -> JoinHandle + where + F: Future + 'static, + F::Output: 'static, + { + let fut_size = std::mem::size_of::(); + let meta = SpawnMeta::new_unnamed(fut_size); + + // safety: spawn_local can only be called from `LocalRuntime`, which this is + unsafe { + if std::mem::size_of::() > BOX_FUTURE_THRESHOLD { + self.handle.spawn_local_named(Box::pin(future), meta) + } else { + self.handle.spawn_local_named(future, meta) + } + } + } + + /// Runs the provided function on a thread from a dedicated blocking thread pool. + /// + /// This function _will_ be run on another thread. + /// + /// See the documentation in the non-local runtime for more information. + /// + /// [Runtime]: crate::runtime::Runtime::spawn_blocking + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::LocalRuntime; + /// + /// # fn dox() { + /// // Create the runtime + /// let rt = LocalRuntime::new().unwrap(); + /// + /// // Spawn a blocking function onto the runtime + /// rt.spawn_blocking(|| { + /// println!("now running on a worker thread"); + /// }); + /// # } + /// ``` + #[track_caller] + pub fn spawn_blocking(&self, func: F) -> JoinHandle + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + self.handle.spawn_blocking(func) + } + + /// Runs a future to completion on the Tokio runtime. This is the + /// runtime's entry point. + /// + /// See the documentation for [the equivalent method on Runtime] for more information. + /// + /// [Runtime]: crate::runtime::Runtime::block_on + /// + /// # Examples + /// + /// ```no_run + /// use tokio::runtime::LocalRuntime; + /// + /// // Create the runtime + /// let rt = LocalRuntime::new().unwrap(); + /// + /// // Execute the future, blocking the current thread until completion + /// rt.block_on(async { + /// println!("hello"); + /// }); + /// ``` + #[track_caller] + pub fn block_on(&self, future: F) -> F::Output { + let fut_size = mem::size_of::(); + let meta = SpawnMeta::new_unnamed(fut_size); + + if std::mem::size_of::() > BOX_FUTURE_THRESHOLD { + self.block_on_inner(Box::pin(future), meta) + } else { + self.block_on_inner(future, meta) + } + } + + #[track_caller] + fn block_on_inner(&self, future: F, _meta: SpawnMeta<'_>) -> F::Output { + #[cfg(all( + tokio_unstable, + tokio_taskdump, + feature = "rt", + target_os = "linux", + any(target_arch = "aarch64", target_arch = "x86", target_arch = "x86_64") + ))] + let future = crate::runtime::task::trace::Trace::root(future); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let future = crate::util::trace::task( + future, + "block_on", + _meta, + crate::runtime::task::Id::next().as_u64(), + ); + + let _enter = self.enter(); + + if let LocalRuntimeScheduler::CurrentThread(exec) = &self.scheduler { + exec.block_on(&self.handle.inner, future) + } else { + unreachable!("LocalRuntime only supports current_thread") + } + } + + /// Enters the runtime context. + /// + /// This allows you to construct types that must have an executor + /// available on creation such as [`Sleep`] or [`TcpStream`]. It will + /// also allow you to call methods such as [`tokio::spawn`]. + /// + /// If this is a handle to a [`LocalRuntime`], and this function is being invoked from the same + /// thread that the runtime was created on, you will also be able to call + /// [`tokio::task::spawn_local`]. + /// + /// [`Sleep`]: struct@crate::time::Sleep + /// [`TcpStream`]: struct@crate::net::TcpStream + /// [`tokio::spawn`]: fn@crate::spawn + /// [`LocalRuntime`]: struct@crate::runtime::LocalRuntime + /// [`tokio::task::spawn_local`]: fn@crate::task::spawn_local + /// + /// # Example + /// + /// ``` + /// use tokio::runtime::LocalRuntime; + /// use tokio::task::JoinHandle; + /// + /// fn function_that_spawns(msg: String) -> JoinHandle<()> { + /// // Had we not used `rt.enter` below, this would panic. + /// tokio::spawn(async move { + /// println!("{}", msg); + /// }) + /// } + /// + /// fn main() { + /// let rt = LocalRuntime::new().unwrap(); + /// + /// let s = "Hello World!".to_string(); + /// + /// // By entering the context, we tie `tokio::spawn` to this executor. + /// let _guard = rt.enter(); + /// let handle = function_that_spawns(s); + /// + /// // Wait for the task before we end the test. + /// rt.block_on(handle).unwrap(); + /// } + /// ``` + pub fn enter(&self) -> EnterGuard<'_> { + self.handle.enter() + } + + /// Shuts down the runtime, waiting for at most `duration` for all spawned + /// work to stop. + /// + /// Note that `spawn_blocking` tasks, and only `spawn_blocking` tasks, can get left behind if + /// the timeout expires. + /// + /// See the [struct level documentation](LocalRuntime#shutdown) for more details. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::LocalRuntime; + /// use tokio::task; + /// + /// use std::thread; + /// use std::time::Duration; + /// + /// fn main() { + /// let runtime = LocalRuntime::new().unwrap(); + /// + /// runtime.block_on(async move { + /// task::spawn_blocking(move || { + /// thread::sleep(Duration::from_secs(10_000)); + /// }); + /// }); + /// + /// runtime.shutdown_timeout(Duration::from_millis(100)); + /// } + /// ``` + pub fn shutdown_timeout(mut self, duration: Duration) { + // Wakeup and shutdown all the worker threads + self.handle.inner.shutdown(); + self.blocking_pool.shutdown(Some(duration)); + } + + /// Shuts down the runtime, without waiting for any spawned work to stop. + /// + /// This can be useful if you want to drop a runtime from within another runtime. + /// Normally, dropping a runtime will block indefinitely for spawned blocking tasks + /// to complete, which would normally not be permitted within an asynchronous context. + /// By calling `shutdown_background()`, you can drop the runtime from such a context. + /// + /// Note however, that because we do not wait for any blocking tasks to complete, this + /// may result in a resource leak (in that any blocking tasks are still running until they + /// return. No other tasks will leak. + /// + /// See the [struct level documentation](LocalRuntime#shutdown) for more details. + /// + /// This function is equivalent to calling `shutdown_timeout(Duration::from_nanos(0))`. + /// + /// ``` + /// use tokio::runtime::LocalRuntime; + /// + /// fn main() { + /// let runtime = LocalRuntime::new().unwrap(); + /// + /// runtime.block_on(async move { + /// let inner_runtime = LocalRuntime::new().unwrap(); + /// // ... + /// inner_runtime.shutdown_background(); + /// }); + /// } + /// ``` + pub fn shutdown_background(self) { + self.shutdown_timeout(Duration::from_nanos(0)); + } + + /// Returns a view that lets you get information about how the runtime + /// is performing. + pub fn metrics(&self) -> crate::runtime::RuntimeMetrics { + self.handle.metrics() + } +} + +#[allow(clippy::single_match)] // there are comments in the error branch, so we don't want if-let +impl Drop for LocalRuntime { + fn drop(&mut self) { + if let LocalRuntimeScheduler::CurrentThread(current_thread) = &mut self.scheduler { + // This ensures that tasks spawned on the current-thread + // runtime are dropped inside the runtime's context. + let _guard = context::try_set_current(&self.handle.inner); + current_thread.shutdown(&self.handle.inner); + } else { + unreachable!("LocalRuntime only supports current-thread") + } + } +} + +impl std::panic::UnwindSafe for LocalRuntime {} + +impl std::panic::RefUnwindSafe for LocalRuntime {} diff --git a/tokio/src/runtime/mod.rs b/tokio/src/runtime/mod.rs index 3f2467f6dbc..c8efbe2f1cd 100644 --- a/tokio/src/runtime/mod.rs +++ b/tokio/src/runtime/mod.rs @@ -372,6 +372,9 @@ cfg_rt! { pub use self::builder::UnhandledPanic; pub use crate::util::rand::RngSeed; + + mod local_runtime; + pub use local_runtime::{LocalRuntime, LocalOptions}; } cfg_taskdump! { diff --git a/tokio/src/runtime/scheduler/current_thread/mod.rs b/tokio/src/runtime/scheduler/current_thread/mod.rs index 9959dff8e46..c66635e7bd6 100644 --- a/tokio/src/runtime/scheduler/current_thread/mod.rs +++ b/tokio/src/runtime/scheduler/current_thread/mod.rs @@ -18,6 +18,7 @@ use std::future::{poll_fn, Future}; use std::sync::atomic::Ordering::{AcqRel, Release}; use std::task::Poll::{Pending, Ready}; use std::task::Waker; +use std::thread::ThreadId; use std::time::Duration; use std::{fmt, thread}; @@ -47,6 +48,9 @@ pub(crate) struct Handle { /// User-supplied hooks to invoke for things pub(crate) task_hooks: TaskHooks, + + /// If this is a `LocalRuntime`, flags the owning thread ID. + pub(crate) local_tid: Option, } /// Data required for executing the scheduler. The struct is passed around to @@ -127,6 +131,7 @@ impl CurrentThread { blocking_spawner: blocking::Spawner, seed_generator: RngSeedGenerator, config: Config, + local_tid: Option, ) -> (CurrentThread, Arc) { let worker_metrics = WorkerMetrics::from_config(&config); worker_metrics.set_thread_id(thread::current().id()); @@ -152,6 +157,7 @@ impl CurrentThread { driver: driver_handle, blocking_spawner, seed_generator, + local_tid, }); let core = AtomicCell::new(Some(Box::new(Core { @@ -458,6 +464,35 @@ impl Handle { handle } + /// Spawn a task which isn't safe to send across thread boundaries onto the runtime. + /// + /// # Safety + /// This should only be used when this is a `LocalRuntime` or in another case where the runtime + /// provably cannot be driven from or moved to different threads from the one on which the task + /// is spawned. + pub(crate) unsafe fn spawn_local( + me: &Arc, + future: F, + id: crate::runtime::task::Id, + ) -> JoinHandle + where + F: crate::future::Future + 'static, + F::Output: 'static, + { + let (handle, notified) = me.shared.owned.bind_local(future, me.clone(), id); + + me.task_hooks.spawn(&TaskMeta { + id, + _phantom: Default::default(), + }); + + if let Some(notified) = notified { + me.schedule(notified); + } + + handle + } + /// Capture a snapshot of this runtime's state. #[cfg(all( tokio_unstable, diff --git a/tokio/src/runtime/scheduler/mod.rs b/tokio/src/runtime/scheduler/mod.rs index ada8efbad63..e0a1b20b5bc 100644 --- a/tokio/src/runtime/scheduler/mod.rs +++ b/tokio/src/runtime/scheduler/mod.rs @@ -113,6 +113,31 @@ cfg_rt! { match_flavor!(self, Handle(h) => &h.blocking_spawner) } + pub(crate) fn is_local(&self) -> bool { + match self { + Handle::CurrentThread(h) => h.local_tid.is_some(), + + #[cfg(feature = "rt-multi-thread")] + Handle::MultiThread(_) => false, + + #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] + Handle::MultiThreadAlt(_) => false, + } + } + + /// Returns true if this is a local runtime and the runtime is owned by the current thread. + pub(crate) fn can_spawn_local_on_local_runtime(&self) -> bool { + match self { + Handle::CurrentThread(h) => h.local_tid.map(|x| std::thread::current().id() == x).unwrap_or(false), + + #[cfg(feature = "rt-multi-thread")] + Handle::MultiThread(_) => false, + + #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] + Handle::MultiThreadAlt(_) => false, + } + } + pub(crate) fn spawn(&self, future: F, id: Id) -> JoinHandle where F: Future + Send + 'static, @@ -129,6 +154,24 @@ cfg_rt! { } } + /// Spawn a local task + /// + /// # Safety + /// This should only be called in `LocalRuntime` if the runtime has been verified to be owned + /// by the current thread. + #[allow(irrefutable_let_patterns)] + pub(crate) unsafe fn spawn_local(&self, future: F, id: Id) -> JoinHandle + where + F: Future + 'static, + F::Output: 'static, + { + if let Handle::CurrentThread(h) = self { + current_thread::Handle::spawn_local(h, future, id) + } else { + panic!("Only current_thread and LocalSet have spawn_local internals implemented") + } + } + pub(crate) fn shutdown(&self) { match *self { Handle::CurrentThread(_) => {}, diff --git a/tokio/src/runtime/task/list.rs b/tokio/src/runtime/task/list.rs index 988d422836d..54bfc01aafb 100644 --- a/tokio/src/runtime/task/list.rs +++ b/tokio/src/runtime/task/list.rs @@ -102,6 +102,26 @@ impl OwnedTasks { (join, notified) } + /// Bind a task that isn't safe to transfer across thread boundaries. + /// + /// # Safety + /// Only use this in `LocalRuntime` where the task cannot move + pub(crate) unsafe fn bind_local( + &self, + task: T, + scheduler: S, + id: super::Id, + ) -> (JoinHandle, Option>) + where + S: Schedule, + T: Future + 'static, + T::Output: 'static, + { + let (task, notified, join) = super::new_task(task, scheduler, id); + let notified = unsafe { self.bind_inner(task, notified) }; + (join, notified) + } + /// The part of `bind` that's the same for every type of future. unsafe fn bind_inner(&self, task: Task, notified: Notified) -> Option> where diff --git a/tokio/src/task/local.rs b/tokio/src/task/local.rs index d5341937893..edd02acbac0 100644 --- a/tokio/src/task/local.rs +++ b/tokio/src/task/local.rs @@ -322,7 +322,7 @@ impl<'a> Drop for LocalDataEnterGuard<'a> { } cfg_rt! { - /// Spawns a `!Send` future on the current [`LocalSet`]. + /// Spawns a `!Send` future on the current [`LocalSet`] or [`LocalRuntime`]. /// /// The spawned future will run on the same thread that called `spawn_local`. /// @@ -362,6 +362,7 @@ cfg_rt! { /// ``` /// /// [`LocalSet`]: struct@crate::task::LocalSet + /// [`LocalRuntime`]: struct@crate::runtime::LocalRuntime /// [`tokio::spawn`]: fn@crate::task::spawn #[track_caller] pub fn spawn_local(future: F) -> JoinHandle @@ -383,10 +384,51 @@ cfg_rt! { where F: Future + 'static, F::Output: 'static { - match CURRENT.with(|LocalData { ctx, .. }| ctx.get()) { - None => panic!("`spawn_local` called from outside of a `task::LocalSet`"), - Some(cx) => cx.spawn(future, meta) - } + use crate::runtime::{context, task}; + + let mut future = Some(future); + + let res = context::with_current(|handle| { + Some(if handle.is_local() { + if !handle.can_spawn_local_on_local_runtime() { + return None; + } + + let future = future.take().unwrap(); + + #[cfg(all( + tokio_unstable, + tokio_taskdump, + feature = "rt", + target_os = "linux", + any( + target_arch = "aarch64", + target_arch = "x86", + target_arch = "x86_64" + ) + ))] + let future = task::trace::Trace::root(future); + let id = task::Id::next(); + let task = crate::util::trace::task(future, "task", meta, id.as_u64()); + + // safety: we have verified that this is a `LocalRuntime` owned by the current thread + unsafe { handle.spawn_local(task, id) } + } else { + match CURRENT.with(|LocalData { ctx, .. }| ctx.get()) { + None => panic!("`spawn_local` called from outside of a `task::LocalSet` or LocalRuntime"), + Some(cx) => cx.spawn(future.take().unwrap(), meta) + } + }) + }); + + match res { + Ok(None) => panic!("Local tasks can only be spawned on a LocalRuntime from the thread the runtime was created on"), + Ok(Some(join_handle)) => join_handle, + Err(_) => match CURRENT.with(|LocalData { ctx, .. }| ctx.get()) { + None => panic!("`spawn_local` called from outside of a `task::LocalSet` or LocalRuntime"), + Some(cx) => cx.spawn(future.unwrap(), meta) + } + } } } diff --git a/tokio/src/util/trace.rs b/tokio/src/util/trace.rs index 97006df474e..b6eadba2205 100644 --- a/tokio/src/util/trace.rs +++ b/tokio/src/util/trace.rs @@ -1,6 +1,7 @@ cfg_rt! { use std::marker::PhantomData; + #[derive(Copy, Clone)] pub(crate) struct SpawnMeta<'a> { /// The name of the task #[cfg(all(tokio_unstable, feature = "tracing"))] diff --git a/tokio/tests/rt_local.rs b/tokio/tests/rt_local.rs new file mode 100644 index 00000000000..1f14f5444d3 --- /dev/null +++ b/tokio/tests/rt_local.rs @@ -0,0 +1,100 @@ +#![allow(unknown_lints, unexpected_cfgs)] +#![warn(rust_2018_idioms)] +#![cfg(all(feature = "full", tokio_unstable))] + +use tokio::runtime::LocalOptions; +use tokio::task::spawn_local; + +#[test] +fn test_spawn_local_in_runtime() { + let rt = rt(); + + let res = rt.block_on(async move { + let (tx, rx) = tokio::sync::oneshot::channel(); + + spawn_local(async { + tokio::task::yield_now().await; + tx.send(5).unwrap(); + }); + + rx.await.unwrap() + }); + + assert_eq!(res, 5); +} + +#[test] +fn test_spawn_from_handle() { + let rt = rt(); + + let (tx, rx) = tokio::sync::oneshot::channel(); + + rt.handle().spawn(async { + tokio::task::yield_now().await; + tx.send(5).unwrap(); + }); + + let res = rt.block_on(async move { rx.await.unwrap() }); + + assert_eq!(res, 5); +} + +#[test] +fn test_spawn_local_on_runtime_object() { + let rt = rt(); + + let (tx, rx) = tokio::sync::oneshot::channel(); + + rt.spawn_local(async { + tokio::task::yield_now().await; + tx.send(5).unwrap(); + }); + + let res = rt.block_on(async move { rx.await.unwrap() }); + + assert_eq!(res, 5); +} + +#[test] +fn test_spawn_local_from_guard() { + let rt = rt(); + + let (tx, rx) = tokio::sync::oneshot::channel(); + + let _guard = rt.enter(); + + spawn_local(async { + tokio::task::yield_now().await; + tx.send(5).unwrap(); + }); + + let res = rt.block_on(async move { rx.await.unwrap() }); + + assert_eq!(res, 5); +} + +#[test] +#[should_panic] +fn test_spawn_local_from_guard_other_thread() { + let (tx, rx) = std::sync::mpsc::channel(); + + std::thread::spawn(move || { + let rt = rt(); + let handle = rt.handle().clone(); + + tx.send(handle).unwrap(); + }); + + let handle = rx.recv().unwrap(); + + let _guard = handle.enter(); + + spawn_local(async {}); +} + +fn rt() -> tokio::runtime::LocalRuntime { + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build_local(&LocalOptions::default()) + .unwrap() +}