Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Merged by Bors] - Thread executor for running tasks on specific threads. #7087

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions crates/bevy_tasks/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ mod usages;
pub use usages::tick_global_task_pools_on_main_thread;
pub use usages::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool};

#[cfg(not(target_arch = "wasm32"))]
mod thread_executor;
#[cfg(not(target_arch = "wasm32"))]
pub use thread_executor::{ThreadExecutor, ThreadExecutorTicker};

mod iter;
pub use iter::ParallelIterator;

Expand Down
103 changes: 53 additions & 50 deletions crates/bevy_tasks/src/task_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use async_task::FallibleTask;
use concurrent_queue::ConcurrentQueue;
use futures_lite::{future, pin, FutureExt};

use crate::Task;
use crate::{thread_executor::ThreadExecutor, Task};

struct CallOnDrop(Option<Arc<dyn Fn() + Send + Sync + 'static>>);

Expand Down Expand Up @@ -108,6 +108,7 @@ pub struct TaskPool {
impl TaskPool {
thread_local! {
static LOCAL_EXECUTOR: async_executor::LocalExecutor<'static> = async_executor::LocalExecutor::new();
static THREAD_EXECUTOR: ThreadExecutor<'static> = ThreadExecutor::new();
}

/// Create a `TaskPool` with the default configuration.
Expand Down Expand Up @@ -271,59 +272,61 @@ impl TaskPool {
F: for<'scope> FnOnce(&'scope Scope<'scope, 'env, T>),
T: Send + 'static,
{
// SAFETY: This safety comment applies to all references transmuted to 'env.
// Any futures spawned with these references need to return before this function completes.
// This is guaranteed because we drive all the futures spawned onto the Scope
// to completion in this function. However, rust has no way of knowing this so we
// transmute the lifetimes to 'env here to appease the compiler as it is unable to validate safety.
let executor: &async_executor::Executor = &self.executor;
let executor: &'env async_executor::Executor = unsafe { mem::transmute(executor) };
let task_scope_executor = &async_executor::Executor::default();
let task_scope_executor: &'env async_executor::Executor =
unsafe { mem::transmute(task_scope_executor) };
let spawned: ConcurrentQueue<FallibleTask<T>> = ConcurrentQueue::unbounded();
let spawned_ref: &'env ConcurrentQueue<FallibleTask<T>> =
unsafe { mem::transmute(&spawned) };

let scope = Scope {
executor,
task_scope_executor,
spawned: spawned_ref,
scope: PhantomData,
env: PhantomData,
};

let scope_ref: &'env Scope<'_, 'env, T> = unsafe { mem::transmute(&scope) };

f(scope_ref);

if spawned.is_empty() {
Vec::new()
} else {
let get_results = async {
let mut results = Vec::with_capacity(spawned_ref.len());
while let Ok(task) = spawned_ref.pop() {
results.push(task.await.unwrap());
}

results
Self::THREAD_EXECUTOR.with(|thread_executor| {
// SAFETY: This safety comment applies to all references transmuted to 'env.
// Any futures spawned with these references need to return before this function completes.
// This is guaranteed because we drive all the futures spawned onto the Scope
// to completion in this function. However, rust has no way of knowing this so we
// transmute the lifetimes to 'env here to appease the compiler as it is unable to validate safety.
let executor: &async_executor::Executor = &self.executor;
let executor: &'env async_executor::Executor = unsafe { mem::transmute(executor) };
let thread_executor: &'env ThreadExecutor<'env> =
unsafe { mem::transmute(thread_executor) };
let spawned: ConcurrentQueue<FallibleTask<T>> = ConcurrentQueue::unbounded();
let spawned_ref: &'env ConcurrentQueue<FallibleTask<T>> =
unsafe { mem::transmute(&spawned) };

let scope = Scope {
executor,
thread_executor,
spawned: spawned_ref,
scope: PhantomData,
env: PhantomData,
};

// Pin the futures on the stack.
pin!(get_results);
let scope_ref: &'env Scope<'_, 'env, T> = unsafe { mem::transmute(&scope) };

f(scope_ref);

if spawned.is_empty() {
Vec::new()
} else {
let get_results = async {
let mut results = Vec::with_capacity(spawned_ref.len());
while let Ok(task) = spawned_ref.pop() {
results.push(task.await.unwrap());
}

loop {
if let Some(result) = future::block_on(future::poll_once(&mut get_results)) {
break result;
results
};

std::panic::catch_unwind(|| {
executor.try_tick();
task_scope_executor.try_tick();
})
.ok();
// Pin the futures on the stack.
pin!(get_results);

let thread_ticker = thread_executor.ticker().unwrap();
loop {
if let Some(result) = future::block_on(future::poll_once(&mut get_results)) {
break result;
};

std::panic::catch_unwind(|| {
executor.try_tick();
thread_ticker.try_tick();
hymm marked this conversation as resolved.
Show resolved Hide resolved
})
.ok();
}
}
}
})
}

/// Spawns a static future onto the thread pool. The returned Task is a future. It can also be
Expand Down Expand Up @@ -395,7 +398,7 @@ impl Drop for TaskPool {
#[derive(Debug)]
pub struct Scope<'scope, 'env: 'scope, T> {
executor: &'scope async_executor::Executor<'scope>,
task_scope_executor: &'scope async_executor::Executor<'scope>,
thread_executor: &'scope ThreadExecutor<'scope>,
spawned: &'scope ConcurrentQueue<FallibleTask<T>>,
// make `Scope` invariant over 'scope and 'env
scope: PhantomData<&'scope mut &'scope ()>,
Expand Down Expand Up @@ -425,7 +428,7 @@ impl<'scope, 'env, T: Send + 'scope> Scope<'scope, 'env, T> {
///
/// For more information, see [`TaskPool::scope`].
pub fn spawn_on_scope<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
let task = self.task_scope_executor.spawn(f).fallible();
let task = self.thread_executor.spawn(f).fallible();
// ConcurrentQueue only errors when closed or full, but we never
// close and use an unbounded queue, so it is safe to unwrap
self.spawned.push(task).unwrap();
Expand Down
128 changes: 128 additions & 0 deletions crates/bevy_tasks/src/thread_executor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
use std::{
marker::PhantomData,
thread::{self, ThreadId},
};

use async_executor::{Executor, Task};
use futures_lite::Future;

/// An executor that can only be ticked on the thread it was instantiated on. But
/// can spawn `Send` tasks from other threads.
///
/// # Example
/// ```rust
/// # use std::sync::{Arc, atomic::{AtomicI32, Ordering}};
/// use bevy_tasks::ThreadExecutor;
///
/// let thread_executor = ThreadExecutor::new();
/// let count = Arc::new(AtomicI32::new(0));
///
/// // create some owned values that can be moved into another thread
/// let count_clone = count.clone();
///
/// std::thread::scope(|scope| {
/// scope.spawn(|| {
/// // we cannot get the ticker from another thread
/// let not_thread_ticker = thread_executor.ticker();
/// assert!(not_thread_ticker.is_none());
///
/// // but we can spawn tasks from another thread
/// thread_executor.spawn(async move {
/// count_clone.fetch_add(1, Ordering::Relaxed);
/// }).detach();
/// });
/// });
///
/// // the tasks do not make progress unless the executor is manually ticked
/// assert_eq!(count.load(Ordering::Relaxed), 0);
///
/// // tick the ticker until task finishes
/// let thread_ticker = thread_executor.ticker().unwrap();
/// thread_ticker.try_tick();
/// assert_eq!(count.load(Ordering::Relaxed), 1);
/// ```
#[derive(Debug)]
pub struct ThreadExecutor<'task> {
executor: Executor<'task>,
thread_id: ThreadId,
}

impl<'task> Default for ThreadExecutor<'task> {
fn default() -> Self {
Self {
executor: Executor::new(),
thread_id: thread::current().id(),
}
}
}

impl<'task> ThreadExecutor<'task> {
/// create a new [`ThreadExecutor`]
pub fn new() -> Self {
Self::default()
}

/// Spawn a task on the thread executor
pub fn spawn<T: Send + 'task>(
&self,
future: impl Future<Output = T> + Send + 'task,
) -> Task<T> {
self.executor.spawn(future)
}

/// Gets the [`ThreadExecutorTicker`] for this executor.
/// Use this to tick the executor.
/// It only returns the ticker if it's on the thread the executor was created on
/// and returns `None` otherwise.
pub fn ticker<'ticker>(&'ticker self) -> Option<ThreadExecutorTicker<'task, 'ticker>> {
if thread::current().id() == self.thread_id {
return Some(ThreadExecutorTicker {
executor: &self.executor,
_marker: PhantomData::default(),
});
}
None
}
}

/// Used to tick the [`ThreadExecutor`]. The executor does not
/// make progress unless it is manually ticked on the thread it was
/// created on.
#[derive(Debug)]
pub struct ThreadExecutorTicker<'task, 'ticker> {
executor: &'ticker Executor<'task>,
// make type not send or sync
_marker: PhantomData<*const ()>,
}
impl<'task, 'ticker> ThreadExecutorTicker<'task, 'ticker> {
/// Tick the thread executor.
pub async fn tick(&self) {
self.executor.tick().await;
}

/// Synchronously try to tick a task on the executor.
/// Returns false if if does not find a task to tick.
pub fn try_tick(&self) -> bool {
self.executor.try_tick()
}
}

#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;

#[test]
fn test_ticker() {
let executor = Arc::new(ThreadExecutor::new());
let ticker = executor.ticker();
assert!(ticker.is_some());

std::thread::scope(|s| {
s.spawn(|| {
let ticker = executor.ticker();
assert!(ticker.is_none());
});
});
}
}