diff --git a/crates/bevy_tasks/src/task_pool.rs b/crates/bevy_tasks/src/task_pool.rs index 39308c2bfd7bc..20b5c518db58e 100644 --- a/crates/bevy_tasks/src/task_pool.rs +++ b/crates/bevy_tasks/src/task_pool.rs @@ -12,8 +12,18 @@ use futures_lite::{future, pin, FutureExt}; use crate::Task; +struct CallOnDrop(Option>); + +impl Drop for CallOnDrop { + fn drop(&mut self) { + if let Some(call) = self.0.as_ref() { + call(); + } + } +} + /// Used to create a [`TaskPool`] -#[derive(Debug, Default, Clone)] +#[derive(Default)] #[must_use] pub struct TaskPoolBuilder { /// If set, we'll set up the thread pool to use at most `num_threads` threads. @@ -24,6 +34,9 @@ pub struct TaskPoolBuilder { /// Allows customizing the name of the threads - helpful for debugging. If set, threads will /// be named (), i.e. "MyThreadPool (2)" thread_name: Option, + + on_thread_spawn: Option>, + on_thread_destroy: Option>, } impl TaskPoolBuilder { @@ -52,13 +65,27 @@ impl TaskPoolBuilder { self } + /// Sets a callback that is invoked once for every created thread as it starts. + /// + /// This is called on the thread itself and has access to all thread-local storage. + /// This will block running async tasks on the thread until the callback completes. + pub fn on_thread_spawn(mut self, f: impl Fn() + Send + Sync + 'static) -> Self { + self.on_thread_spawn = Some(Arc::new(f)); + self + } + + /// Sets a callback that is invoked once for every created thread as it terminates. + /// + /// This is called on the thread itself and has access to all thread-local storage. + /// This will block thread termination until the callback completes. + pub fn on_thread_destroy(mut self, f: impl Fn() + Send + Sync + 'static) -> Self { + self.on_thread_destroy = Some(Arc::new(f)); + self + } + /// Creates a new [`TaskPool`] based on the current options. pub fn build(self) -> TaskPool { - TaskPool::new_internal( - self.num_threads, - self.stack_size, - self.thread_name.as_deref(), - ) + TaskPool::new_internal(self) } } @@ -88,36 +115,42 @@ impl TaskPool { TaskPoolBuilder::new().build() } - fn new_internal( - num_threads: Option, - stack_size: Option, - thread_name: Option<&str>, - ) -> Self { + fn new_internal(builder: TaskPoolBuilder) -> Self { let (shutdown_tx, shutdown_rx) = async_channel::unbounded::<()>(); let executor = Arc::new(async_executor::Executor::new()); - let num_threads = num_threads.unwrap_or_else(crate::available_parallelism); + let num_threads = builder + .num_threads + .unwrap_or_else(crate::available_parallelism); let threads = (0..num_threads) .map(|i| { let ex = Arc::clone(&executor); let shutdown_rx = shutdown_rx.clone(); - let thread_name = if let Some(thread_name) = thread_name { + let thread_name = if let Some(thread_name) = builder.thread_name.as_deref() { format!("{thread_name} ({i})") } else { format!("TaskPool ({i})") }; let mut thread_builder = thread::Builder::new().name(thread_name); - if let Some(stack_size) = stack_size { + if let Some(stack_size) = builder.stack_size { thread_builder = thread_builder.stack_size(stack_size); } + let on_thread_spawn = builder.on_thread_spawn.clone(); + let on_thread_destroy = builder.on_thread_destroy.clone(); + thread_builder .spawn(move || { TaskPool::LOCAL_EXECUTOR.with(|local_executor| { + if let Some(on_thread_spawn) = on_thread_spawn { + on_thread_spawn(); + drop(on_thread_spawn); + } + let _destructor = CallOnDrop(on_thread_destroy); loop { let res = std::panic::catch_unwind(|| { let tick_forever = async move { @@ -451,6 +484,57 @@ mod tests { assert_eq!(count.load(Ordering::Relaxed), 100); } + #[test] + fn test_thread_callbacks() { + let counter = Arc::new(AtomicI32::new(0)); + let start_counter = counter.clone(); + { + let barrier = Arc::new(Barrier::new(11)); + let last_barrier = barrier.clone(); + // Build and immediately drop to terminate + let _pool = TaskPoolBuilder::new() + .num_threads(10) + .on_thread_spawn(move || { + start_counter.fetch_add(1, Ordering::Relaxed); + barrier.clone().wait(); + }) + .build(); + last_barrier.wait(); + assert_eq!(10, counter.load(Ordering::Relaxed)); + } + assert_eq!(10, counter.load(Ordering::Relaxed)); + let end_counter = counter.clone(); + { + let _pool = TaskPoolBuilder::new() + .num_threads(20) + .on_thread_destroy(move || { + end_counter.fetch_sub(1, Ordering::Relaxed); + }) + .build(); + assert_eq!(10, counter.load(Ordering::Relaxed)); + } + assert_eq!(-10, counter.load(Ordering::Relaxed)); + let start_counter = counter.clone(); + let end_counter = counter.clone(); + { + let barrier = Arc::new(Barrier::new(6)); + let last_barrier = barrier.clone(); + let _pool = TaskPoolBuilder::new() + .num_threads(5) + .on_thread_spawn(move || { + start_counter.fetch_add(1, Ordering::Relaxed); + barrier.wait(); + }) + .on_thread_destroy(move || { + end_counter.fetch_sub(1, Ordering::Relaxed); + }) + .build(); + last_barrier.wait(); + assert_eq!(-5, counter.load(Ordering::Relaxed)); + } + assert_eq!(-10, counter.load(Ordering::Relaxed)); + } + #[test] fn test_mixed_spawn_on_scope_and_spawn() { let pool = TaskPool::new();