From 36ccd85480c65362b6c1ccb4acfa3aecd136028f Mon Sep 17 00:00:00 2001 From: Michael Hsu Date: Sun, 6 Nov 2022 10:59:36 -0800 Subject: [PATCH] make a safer abstraction for the main thread executor --- .../src/schedule/executor_parallel.rs | 4 +- crates/bevy_tasks/src/lib.rs | 5 +- crates/bevy_tasks/src/main_thread_executor.rs | 75 +++++++++++++++++++ crates/bevy_tasks/src/task_pool.rs | 39 ++++------ crates/bevy_tasks/src/usages.rs | 24 ------ 5 files changed, 96 insertions(+), 51 deletions(-) create mode 100644 crates/bevy_tasks/src/main_thread_executor.rs diff --git a/crates/bevy_ecs/src/schedule/executor_parallel.rs b/crates/bevy_ecs/src/schedule/executor_parallel.rs index 68dd1f1ea798dc..8883b2922a386b 100644 --- a/crates/bevy_ecs/src/schedule/executor_parallel.rs +++ b/crates/bevy_ecs/src/schedule/executor_parallel.rs @@ -236,7 +236,7 @@ impl ParallelExecutor { if system_data.is_send { scope.spawn(task); } else { - scope.spawn_on_scope(task); + scope.spawn_on_main(task); } #[cfg(test)] @@ -271,7 +271,7 @@ impl ParallelExecutor { if system_data.is_send { scope.spawn(task); } else { - scope.spawn_on_scope(task); + scope.spawn_on_main(task); } } } diff --git a/crates/bevy_tasks/src/lib.rs b/crates/bevy_tasks/src/lib.rs index 3382a8b252996e..3be4008fb86e91 100644 --- a/crates/bevy_tasks/src/lib.rs +++ b/crates/bevy_tasks/src/lib.rs @@ -20,7 +20,10 @@ pub use single_threaded_task_pool::{Scope, TaskPool, TaskPoolBuilder}; mod usages; #[cfg(not(target_arch = "wasm32"))] pub use usages::tick_global_task_pools_on_main_thread; -pub use usages::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool, MainThreadExecutor}; +pub use usages::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool}; + +mod main_thread_executor; +pub use main_thread_executor::MainThreadExecutor; mod iter; pub use iter::ParallelIterator; diff --git a/crates/bevy_tasks/src/main_thread_executor.rs b/crates/bevy_tasks/src/main_thread_executor.rs new file mode 100644 index 00000000000000..b90a8abe03ea99 --- /dev/null +++ b/crates/bevy_tasks/src/main_thread_executor.rs @@ -0,0 +1,75 @@ +use std::{marker::PhantomData, sync::Arc}; + +use async_executor::{Executor, Task}; +use futures_lite::Future; +use is_main_thread::is_main_thread; +use once_cell::sync::OnceCell; + +static MAIN_THREAD_EXECUTOR: OnceCell = OnceCell::new(); + +/// Use to access the global main thread executor. Be aware that the main thread executor +/// only makes progress when it is ticked. This normally happens in `[TaskPool::scope]`. +#[derive(Debug)] +pub struct MainThreadExecutor(Arc>); + +impl MainThreadExecutor { + /// Initializes the global `[MainThreadExecutor]` instance. + pub fn init() -> &'static Self { + MAIN_THREAD_EXECUTOR.get_or_init(|| Self(Arc::new(Executor::new()))) + } + + /// Gets the global [`MainThreadExecutor`] instance. + /// + /// # Panics + /// Panics if no executor has been initialized yet. + pub fn get() -> &'static Self { + MAIN_THREAD_EXECUTOR.get().expect( + "A MainThreadExecutor has not been initialize yet. Please call \ + MainThreadExecutor::init beforehand", + ) + } + + /// Gets the `[MainThreadSpawner]` for the global main thread executor. + /// Use this to spawn tasks on the main thread. + pub fn spawner(&self) -> MainThreadSpawner<'static> { + MainThreadSpawner(self.0.clone()) + } + + /// Gets the `[MainThreadTicker]` for the global main thread executor. + /// Use this to tick the main thread executor. + /// Returns None if called on not the main thread. + pub fn ticker(&self) -> Option { + if let Some(is_main) = is_main_thread() { + if is_main { + return Some(MainThreadTicker { + executor: self.0.clone(), + _marker: PhantomData::default(), + }); + } + } + None + } +} + +#[derive(Debug)] +pub struct MainThreadSpawner<'a>(Arc>); +impl<'a> MainThreadSpawner<'a> { + /// Spawn a task on the main thread + pub fn spawn(&self, future: impl Future + Send + 'a) -> Task { + self.0.spawn(future) + } +} + +#[derive(Debug)] +pub struct MainThreadTicker { + executor: Arc>, + // make type not send or sync + _marker: PhantomData<*const ()>, +} +impl MainThreadTicker { + /// Tick the main thread executor. + /// This needs to be called manually on the main thread if a `[TaskPool::scope]` is not active + pub fn tick<'a>(&'a self) -> impl Future + 'a { + self.executor.tick() + } +} diff --git a/crates/bevy_tasks/src/task_pool.rs b/crates/bevy_tasks/src/task_pool.rs index ceea96cb4defab..dde21334ff3a36 100644 --- a/crates/bevy_tasks/src/task_pool.rs +++ b/crates/bevy_tasks/src/task_pool.rs @@ -8,10 +8,9 @@ use std::{ use concurrent_queue::ConcurrentQueue; use futures_lite::{future, FutureExt}; -use is_main_thread::is_main_thread; -use crate::MainThreadExecutor; use crate::Task; +use crate::{main_thread_executor::MainThreadSpawner, MainThreadExecutor}; /// Used to create a [`TaskPool`] #[derive(Debug, Default, Clone)] @@ -246,16 +245,16 @@ impl TaskPool { // transmute the lifetimes to 'env here to appease the compiler as it is unable to validate safety. let executor: &async_executor::Executor = &self.executor; let executor: &'env async_executor::Executor = unsafe { mem::transmute(executor) }; - let task_scope_executor = MainThreadExecutor::init(); - let task_scope_executor: &'env async_executor::Executor = - unsafe { mem::transmute(task_scope_executor) }; + let main_thread_spawner = MainThreadExecutor::init().spawner(); + let main_thread_spawner: MainThreadSpawner<'env> = + unsafe { mem::transmute(main_thread_spawner) }; let spawned: ConcurrentQueue> = ConcurrentQueue::unbounded(); let spawned_ref: &'env ConcurrentQueue> = unsafe { mem::transmute(&spawned) }; let scope = Scope { executor, - task_scope_executor, + main_thread_spawner, spawned: spawned_ref, scope: PhantomData, env: PhantomData, @@ -278,20 +277,10 @@ impl TaskPool { results }; - let is_main = if let Some(is_main) = is_main_thread() { - is_main - } else { - false - }; - - if is_main { + if let Some(main_thread_ticker) = MainThreadExecutor::get().ticker() { let tick_forever = async move { loop { - if let Some(is_main) = is_main_thread() { - if is_main { - task_scope_executor.tick().await; - } - } + main_thread_ticker.tick().await; } }; @@ -372,7 +361,7 @@ impl Drop for TaskPool { #[derive(Debug)] pub struct Scope<'scope, 'env: 'scope, T> { executor: &'scope async_executor::Executor<'scope>, - task_scope_executor: &'scope async_executor::Executor<'scope>, + main_thread_spawner: MainThreadSpawner<'scope>, spawned: &'scope ConcurrentQueue>, // make `Scope` invariant over 'scope and 'env scope: PhantomData<&'scope mut &'scope ()>, @@ -401,8 +390,10 @@ impl<'scope, 'env, T: Send + 'scope> Scope<'scope, 'env, T> { /// [`Scope::spawn`] instead, unless the provided future needs to run on the scope's thread. /// /// For more information, see [`TaskPool::scope`]. - pub fn spawn_on_scope + 'scope + Send>(&self, f: Fut) { - let task = self.task_scope_executor.spawn(f); + pub fn spawn_on_main + 'scope + Send>(&self, f: Fut) { + let main_thread_spawner: &MainThreadSpawner<'scope> = + unsafe { mem::transmute(&self.main_thread_spawner) }; + let task = main_thread_spawner.spawn(f); // ConcurrentQueue only errors when closed or full, but we never // close and use an unbouded queue, so it is safe to unwrap self.spawned.push(task).unwrap(); @@ -473,7 +464,7 @@ mod tests { }); } else { let count_clone = local_count.clone(); - scope.spawn_on_scope(async move { + scope.spawn_on_main(async move { if *foo != 42 { panic!("not 42!?!?") } else { @@ -514,7 +505,7 @@ mod tests { }); let spawner = std::thread::current().id(); let inner_count_clone = count_clone.clone(); - scope.spawn_on_scope(async move { + scope.spawn_on_main(async move { inner_count_clone.fetch_add(1, Ordering::Release); if std::thread::current().id() != spawner { // NOTE: This check is using an atomic rather than simply panicing the @@ -589,7 +580,7 @@ mod tests { inner_count_clone.fetch_add(1, Ordering::Release); // spawning on the scope from another thread runs the futures on the scope's thread - scope.spawn_on_scope(async move { + scope.spawn_on_main(async move { inner_count_clone.fetch_add(1, Ordering::Release); if std::thread::current().id() != spawner { // NOTE: This check is using an atomic rather than simply panicing the diff --git a/crates/bevy_tasks/src/usages.rs b/crates/bevy_tasks/src/usages.rs index e9f22af3fbb7bc..1d0c83b271c2f1 100644 --- a/crates/bevy_tasks/src/usages.rs +++ b/crates/bevy_tasks/src/usages.rs @@ -17,7 +17,6 @@ use std::ops::Deref; static COMPUTE_TASK_POOL: OnceCell = OnceCell::new(); static ASYNC_COMPUTE_TASK_POOL: OnceCell = OnceCell::new(); static IO_TASK_POOL: OnceCell = OnceCell::new(); -static MAIN_THREAD_EXECUTOR: OnceCell = OnceCell::new(); /// A newtype for a task pool for CPU-intensive work that must be completed to deliver the next /// frame @@ -111,29 +110,6 @@ impl Deref for IoTaskPool { } } -pub struct MainThreadExecutor(async_executor::Executor<'static>); - -impl MainThreadExecutor { - pub fn init() -> &'static Self { - MAIN_THREAD_EXECUTOR.get_or_init(|| Self(async_executor::Executor::new())) - } - - pub fn get() -> &'static Self { - MAIN_THREAD_EXECUTOR.get().expect( - "A MainThreadExecutor has not been initialize yet. Please call \ - MainThreadExecutor::init beforehand", - ) - } -} - -impl Deref for MainThreadExecutor { - type Target = async_executor::Executor<'static>; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - /// Used by `bevy_core` to tick the global tasks pools on the main thread. /// This will run a maximum of 100 local tasks per executor per call to this function. #[cfg(not(target_arch = "wasm32"))]