From e6020c0fed28128fce85d69352b540ddc1bbaf69 Mon Sep 17 00:00:00 2001 From: Alice Ryhl Date: Thu, 30 Jun 2022 12:08:41 +0200 Subject: [PATCH] task: various small improvements to `LocalKey` (#4795) Co-authored-by: Kitsu --- tokio/src/task/task_local.rs | 296 +++++++++++++++++++++++++---------- tokio/tests/task_local.rs | 95 ++++++++++- 2 files changed, 301 insertions(+), 90 deletions(-) diff --git a/tokio/src/task/task_local.rs b/tokio/src/task/task_local.rs index 70fa967df5e..e92b05fdc98 100644 --- a/tokio/src/task/task_local.rs +++ b/tokio/src/task/task_local.rs @@ -1,11 +1,10 @@ -use pin_project_lite::pin_project; use std::cell::RefCell; use std::error::Error; use std::future::Future; use std::marker::PhantomPinned; use std::pin::Pin; use std::task::{Context, Poll}; -use std::{fmt, thread}; +use std::{fmt, mem, thread}; /// Declares a new task-local key of type [`tokio::task::LocalKey`]. /// @@ -79,7 +78,7 @@ macro_rules! __task_local_inner { /// A key for task-local data. /// -/// This type is generated by the `task_local!` macro. +/// This type is generated by the [`task_local!`] macro. /// /// Unlike [`std::thread::LocalKey`], `tokio::task::LocalKey` will /// _not_ lazily initialize the value on first access. Instead, the @@ -107,7 +106,9 @@ macro_rules! __task_local_inner { /// }).await; /// # } /// ``` +/// /// [`std::thread::LocalKey`]: struct@std::thread::LocalKey +/// [`task_local!`]: ../macro.task_local.html #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] pub struct LocalKey { #[doc(hidden)] @@ -119,6 +120,11 @@ impl LocalKey { /// /// On completion of `scope`, the task-local will be dropped. /// + /// ### Panics + /// + /// If you poll the returned future inside a call to [`with`] or + /// [`try_with`] on the same `LocalKey`, then the call to `poll` will panic. + /// /// ### Examples /// /// ``` @@ -132,6 +138,9 @@ impl LocalKey { /// }).await; /// # } /// ``` + /// + /// [`with`]: fn@Self::with + /// [`try_with`]: fn@Self::try_with pub fn scope(&'static self, value: T, f: F) -> TaskLocalFuture where F: Future, @@ -139,7 +148,7 @@ impl LocalKey { TaskLocalFuture { local: self, slot: Some(value), - future: f, + future: Some(f), _pinned: PhantomPinned, } } @@ -148,6 +157,11 @@ impl LocalKey { /// /// On completion of `scope`, the task-local will be dropped. /// + /// ### Panics + /// + /// This method panics if called inside a call to [`with`] or [`try_with`] + /// on the same `LocalKey`. + /// /// ### Examples /// /// ``` @@ -161,34 +175,80 @@ impl LocalKey { /// }); /// # } /// ``` + /// + /// [`with`]: fn@Self::with + /// [`try_with`]: fn@Self::try_with + #[track_caller] pub fn sync_scope(&'static self, value: T, f: F) -> R where F: FnOnce() -> R, { - let scope = TaskLocalFuture { - local: self, - slot: Some(value), - future: (), - _pinned: PhantomPinned, - }; - crate::pin!(scope); - scope.with_task(|_| f()) + let mut value = Some(value); + match self.scope_inner(&mut value, f) { + Ok(res) => res, + Err(err) => err.panic(), + } + } + + fn scope_inner(&'static self, slot: &mut Option, f: F) -> Result + where + F: FnOnce() -> R, + { + struct Guard<'a, T: 'static> { + local: &'static LocalKey, + slot: &'a mut Option, + } + + impl<'a, T: 'static> Drop for Guard<'a, T> { + fn drop(&mut self) { + // This should not panic. + // + // We know that the RefCell was not borrowed before the call to + // `scope_inner`, so the only way for this to panic is if the + // closure has created but not destroyed a RefCell guard. + // However, we never give user-code access to the guards, so + // there's no way for user-code to forget to destroy a guard. + // + // The call to `with` also should not panic, since the + // thread-local wasn't destroyed when we first called + // `scope_inner`, and it shouldn't have gotten destroyed since + // then. + self.local.inner.with(|inner| { + let mut ref_mut = inner.borrow_mut(); + mem::swap(self.slot, &mut *ref_mut); + }); + } + } + + self.inner.try_with(|inner| { + inner + .try_borrow_mut() + .map(|mut ref_mut| mem::swap(slot, &mut *ref_mut)) + })??; + + let guard = Guard { local: self, slot }; + + let res = f(); + + drop(guard); + + Ok(res) } /// Accesses the current task-local and runs the provided closure. /// /// # Panics /// - /// This function will panic if not called within the context - /// of a future containing a task-local with the corresponding key. + /// This function will panic if the task local doesn't have a value set. + #[track_caller] pub fn with(&'static self, f: F) -> R where F: FnOnce(&T) -> R, { - self.try_with(f).expect( - "cannot access a Task Local Storage value \ - without setting it via `LocalKey::set`", - ) + match self.try_with(f) { + Ok(res) => res, + Err(_) => panic!("cannot access a task-local storage value without setting it first"), + } } /// Accesses the current task-local and runs the provided closure. @@ -200,19 +260,31 @@ impl LocalKey { where F: FnOnce(&T) -> R, { - self.inner.with(|v| { - if let Some(val) = v.borrow().as_ref() { - Ok(f(val)) - } else { - Err(AccessError { _private: () }) - } - }) + // If called after the thread-local storing the task-local is destroyed, + // then we are outside of a closure where the task-local is set. + // + // Therefore, it is correct to return an AccessError if `try_with` + // returns an error. + let try_with_res = self.inner.try_with(|v| { + // This call to `borrow` cannot panic because no user-defined code + // runs while a `borrow_mut` call is active. + v.borrow().as_ref().map(f) + }); + + match try_with_res { + Ok(Some(res)) => Ok(res), + Ok(None) | Err(_) => Err(AccessError { _private: () }), + } } } impl LocalKey { /// Returns a copy of the task-local value /// if the task-local value implements `Copy`. + /// + /// # Panics + /// + /// This function will panic if the task local doesn't have a value set. pub fn get(&'static self) -> T { self.with(|v| *v) } @@ -224,76 +296,104 @@ impl fmt::Debug for LocalKey { } } -pin_project! { - /// A future that sets a value `T` of a task local for the future `F` during - /// its execution. - /// - /// The value of the task-local must be `'static` and will be dropped on the - /// completion of the future. - /// - /// Created by the function [`LocalKey::scope`](self::LocalKey::scope). - /// - /// ### Examples - /// - /// ``` - /// # async fn dox() { - /// tokio::task_local! { - /// static NUMBER: u32; - /// } - /// - /// NUMBER.scope(1, async move { - /// println!("task local value: {}", NUMBER.get()); - /// }).await; - /// # } - /// ``` - pub struct TaskLocalFuture - where - T: 'static - { - local: &'static LocalKey, - slot: Option, - #[pin] - future: F, - #[pin] - _pinned: PhantomPinned, - } +/// A future that sets a value `T` of a task local for the future `F` during +/// its execution. +/// +/// The value of the task-local must be `'static` and will be dropped on the +/// completion of the future. +/// +/// Created by the function [`LocalKey::scope`](self::LocalKey::scope). +/// +/// ### Examples +/// +/// ``` +/// # async fn dox() { +/// tokio::task_local! { +/// static NUMBER: u32; +/// } +/// +/// NUMBER.scope(1, async move { +/// println!("task local value: {}", NUMBER.get()); +/// }).await; +/// # } +/// ``` +// Doesn't use pin_project due to custom Drop. +pub struct TaskLocalFuture +where + T: 'static, +{ + local: &'static LocalKey, + slot: Option, + future: Option, + _pinned: PhantomPinned, } -impl TaskLocalFuture { - fn with_task) -> R, R>(self: Pin<&mut Self>, f: F2) -> R { - struct Guard<'a, T: 'static> { - local: &'static LocalKey, - slot: &'a mut Option, - prev: Option, - } - - impl Drop for Guard<'_, T> { - fn drop(&mut self) { - let value = self.local.inner.with(|c| c.replace(self.prev.take())); - *self.slot = value; - } - } +impl Future for TaskLocalFuture { + type Output = F::Output; - let project = self.project(); - let val = project.slot.take(); + #[track_caller] + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // safety: The TaskLocalFuture struct is `!Unpin` so there is no way to + // move `self.future` from now on. + let this = unsafe { Pin::into_inner_unchecked(self) }; + let mut future_opt = unsafe { Pin::new_unchecked(&mut this.future) }; - let prev = project.local.inner.with(|c| c.replace(val)); + let res = + this.local + .scope_inner(&mut this.slot, || match future_opt.as_mut().as_pin_mut() { + Some(fut) => { + let res = fut.poll(cx); + if res.is_ready() { + future_opt.set(None); + } + Some(res) + } + None => None, + }); - let _guard = Guard { - prev, - slot: project.slot, - local: *project.local, - }; + match res { + Ok(Some(res)) => res, + Ok(None) => panic!("`TaskLocalFuture` polled after completion"), + Err(err) => err.panic(), + } + } +} - f(project.future) +impl Drop for TaskLocalFuture { + fn drop(&mut self) { + if mem::needs_drop::() && self.future.is_some() { + // Drop the future while the task-local is set, if possible. Otherwise + // the future is dropped normally when the `Option` field drops. + let future = &mut self.future; + let _ = self.local.scope_inner(&mut self.slot, || { + *future = None; + }); + } } } -impl Future for TaskLocalFuture { - type Output = F::Output; +impl fmt::Debug for TaskLocalFuture +where + T: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + /// Format the Option without Some. + struct TransparentOption<'a, T> { + value: &'a Option, + } + impl<'a, T: fmt::Debug> fmt::Debug for TransparentOption<'a, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.value.as_ref() { + Some(value) => value.fmt(f), + // Hitting the None branch should not be possible. + None => f.pad(""), + } + } + } - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.with_task(|f| f.poll(cx)) + f.debug_struct("TaskLocalFuture") + .field("value", &TransparentOption { value: &self.slot }) + .finish() } } @@ -316,3 +416,29 @@ impl fmt::Display for AccessError { } impl Error for AccessError {} + +enum ScopeInnerErr { + BorrowError, + AccessError, +} + +impl ScopeInnerErr { + fn panic(&self) -> ! { + match self { + Self::BorrowError => panic!("cannot enter a task-local scope while the task-local storage is borrowed"), + Self::AccessError => panic!("cannot enter a task-local scope during or after destruction of the underlying thread-local"), + } + } +} + +impl From for ScopeInnerErr { + fn from(_: std::cell::BorrowMutError) -> Self { + Self::BorrowError + } +} + +impl From for ScopeInnerErr { + fn from(_: std::thread::AccessError) -> Self { + Self::AccessError + } +} diff --git a/tokio/tests/task_local.rs b/tokio/tests/task_local.rs index 811d63ea0f8..4e33f29be43 100644 --- a/tokio/tests/task_local.rs +++ b/tokio/tests/task_local.rs @@ -1,12 +1,16 @@ #![cfg(feature = "full")] - -tokio::task_local! { - static REQ_ID: u32; - pub static FOO: bool; -} +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::sync::oneshot; #[tokio::test(flavor = "multi_thread")] async fn local() { + tokio::task_local! { + static REQ_ID: u32; + pub static FOO: bool; + } + let j1 = tokio::spawn(REQ_ID.scope(1, async move { assert_eq!(REQ_ID.get(), 1); assert_eq!(REQ_ID.get(), 1); @@ -31,3 +35,84 @@ async fn local() { j2.await.unwrap(); j3.await.unwrap(); } + +#[tokio::test] +async fn task_local_available_on_abort() { + tokio::task_local! { + static KEY: u32; + } + + struct MyFuture { + tx_poll: Option>, + tx_drop: Option>, + } + impl Future for MyFuture { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> { + if let Some(tx_poll) = self.tx_poll.take() { + let _ = tx_poll.send(()); + } + Poll::Pending + } + } + impl Drop for MyFuture { + fn drop(&mut self) { + let _ = self.tx_drop.take().unwrap().send(KEY.get()); + } + } + + let (tx_drop, rx_drop) = oneshot::channel(); + let (tx_poll, rx_poll) = oneshot::channel(); + + let h = tokio::spawn(KEY.scope( + 42, + MyFuture { + tx_poll: Some(tx_poll), + tx_drop: Some(tx_drop), + }, + )); + + rx_poll.await.unwrap(); + h.abort(); + assert_eq!(rx_drop.await.unwrap(), 42); + + let err = h.await.unwrap_err(); + if !err.is_cancelled() { + if let Ok(panic) = err.try_into_panic() { + std::panic::resume_unwind(panic); + } else { + panic!(); + } + } +} + +#[tokio::test] +async fn task_local_available_on_completion_drop() { + tokio::task_local! { + static KEY: u32; + } + + struct MyFuture { + tx: Option>, + } + impl Future for MyFuture { + type Output = (); + + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> { + Poll::Ready(()) + } + } + impl Drop for MyFuture { + fn drop(&mut self) { + let _ = self.tx.take().unwrap().send(KEY.get()); + } + } + + let (tx, rx) = oneshot::channel(); + + let h = tokio::spawn(KEY.scope(42, MyFuture { tx: Some(tx) })); + + assert_eq!(rx.await.unwrap(), 42); + h.await.unwrap(); +}