Skip to content

Commit

Permalink
make a safer abstraction for the main thread executor
Browse files Browse the repository at this point in the history
  • Loading branch information
hymm committed Nov 8, 2022
1 parent ba1d0e1 commit 36ccd85
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 51 deletions.
4 changes: 2 additions & 2 deletions crates/bevy_ecs/src/schedule/executor_parallel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ impl ParallelExecutor {
if system_data.is_send {
scope.spawn(task);
} else {
scope.spawn_on_scope(task);
scope.spawn_on_main(task);
}

#[cfg(test)]
Expand Down Expand Up @@ -271,7 +271,7 @@ impl ParallelExecutor {
if system_data.is_send {
scope.spawn(task);
} else {
scope.spawn_on_scope(task);
scope.spawn_on_main(task);
}
}
}
Expand Down
5 changes: 4 additions & 1 deletion crates/bevy_tasks/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ pub use single_threaded_task_pool::{Scope, TaskPool, TaskPoolBuilder};
mod usages;
#[cfg(not(target_arch = "wasm32"))]
pub use usages::tick_global_task_pools_on_main_thread;
pub use usages::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool, MainThreadExecutor};
pub use usages::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool};

mod main_thread_executor;
pub use main_thread_executor::MainThreadExecutor;

mod iter;
pub use iter::ParallelIterator;
Expand Down
75 changes: 75 additions & 0 deletions crates/bevy_tasks/src/main_thread_executor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use std::{marker::PhantomData, sync::Arc};

use async_executor::{Executor, Task};
use futures_lite::Future;
use is_main_thread::is_main_thread;
use once_cell::sync::OnceCell;

static MAIN_THREAD_EXECUTOR: OnceCell<MainThreadExecutor> = OnceCell::new();

/// Use to access the global main thread executor. Be aware that the main thread executor
/// only makes progress when it is ticked. This normally happens in `[TaskPool::scope]`.
#[derive(Debug)]
pub struct MainThreadExecutor(Arc<Executor<'static>>);

impl MainThreadExecutor {
/// Initializes the global `[MainThreadExecutor]` instance.
pub fn init() -> &'static Self {
MAIN_THREAD_EXECUTOR.get_or_init(|| Self(Arc::new(Executor::new())))
}

/// Gets the global [`MainThreadExecutor`] instance.
///
/// # Panics
/// Panics if no executor has been initialized yet.
pub fn get() -> &'static Self {
MAIN_THREAD_EXECUTOR.get().expect(
"A MainThreadExecutor has not been initialize yet. Please call \
MainThreadExecutor::init beforehand",
)
}

/// Gets the `[MainThreadSpawner]` for the global main thread executor.
/// Use this to spawn tasks on the main thread.
pub fn spawner(&self) -> MainThreadSpawner<'static> {
MainThreadSpawner(self.0.clone())
}

/// Gets the `[MainThreadTicker]` for the global main thread executor.
/// Use this to tick the main thread executor.
/// Returns None if called on not the main thread.
pub fn ticker(&self) -> Option<MainThreadTicker> {
if let Some(is_main) = is_main_thread() {
if is_main {
return Some(MainThreadTicker {
executor: self.0.clone(),
_marker: PhantomData::default(),
});
}
}
None
}
}

#[derive(Debug)]
pub struct MainThreadSpawner<'a>(Arc<Executor<'a>>);
impl<'a> MainThreadSpawner<'a> {
/// Spawn a task on the main thread
pub fn spawn<T: Send + 'a>(&self, future: impl Future<Output = T> + Send + 'a) -> Task<T> {
self.0.spawn(future)
}
}

#[derive(Debug)]
pub struct MainThreadTicker {
executor: Arc<Executor<'static>>,
// make type not send or sync
_marker: PhantomData<*const ()>,
}
impl MainThreadTicker {
/// Tick the main thread executor.
/// This needs to be called manually on the main thread if a `[TaskPool::scope]` is not active
pub fn tick<'a>(&'a self) -> impl Future<Output = ()> + 'a {
self.executor.tick()
}
}
39 changes: 15 additions & 24 deletions crates/bevy_tasks/src/task_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@ use std::{

use concurrent_queue::ConcurrentQueue;
use futures_lite::{future, FutureExt};
use is_main_thread::is_main_thread;

use crate::MainThreadExecutor;
use crate::Task;
use crate::{main_thread_executor::MainThreadSpawner, MainThreadExecutor};

/// Used to create a [`TaskPool`]
#[derive(Debug, Default, Clone)]
Expand Down Expand Up @@ -246,16 +245,16 @@ impl TaskPool {
// transmute the lifetimes to 'env here to appease the compiler as it is unable to validate safety.
let executor: &async_executor::Executor = &self.executor;
let executor: &'env async_executor::Executor = unsafe { mem::transmute(executor) };
let task_scope_executor = MainThreadExecutor::init();
let task_scope_executor: &'env async_executor::Executor =
unsafe { mem::transmute(task_scope_executor) };
let main_thread_spawner = MainThreadExecutor::init().spawner();
let main_thread_spawner: MainThreadSpawner<'env> =
unsafe { mem::transmute(main_thread_spawner) };
let spawned: ConcurrentQueue<async_executor::Task<T>> = ConcurrentQueue::unbounded();
let spawned_ref: &'env ConcurrentQueue<async_executor::Task<T>> =
unsafe { mem::transmute(&spawned) };

let scope = Scope {
executor,
task_scope_executor,
main_thread_spawner,
spawned: spawned_ref,
scope: PhantomData,
env: PhantomData,
Expand All @@ -278,20 +277,10 @@ impl TaskPool {
results
};

let is_main = if let Some(is_main) = is_main_thread() {
is_main
} else {
false
};

if is_main {
if let Some(main_thread_ticker) = MainThreadExecutor::get().ticker() {
let tick_forever = async move {
loop {
if let Some(is_main) = is_main_thread() {
if is_main {
task_scope_executor.tick().await;
}
}
main_thread_ticker.tick().await;
}
};

Expand Down Expand Up @@ -372,7 +361,7 @@ impl Drop for TaskPool {
#[derive(Debug)]
pub struct Scope<'scope, 'env: 'scope, T> {
executor: &'scope async_executor::Executor<'scope>,
task_scope_executor: &'scope async_executor::Executor<'scope>,
main_thread_spawner: MainThreadSpawner<'scope>,
spawned: &'scope ConcurrentQueue<async_executor::Task<T>>,
// make `Scope` invariant over 'scope and 'env
scope: PhantomData<&'scope mut &'scope ()>,
Expand Down Expand Up @@ -401,8 +390,10 @@ impl<'scope, 'env, T: Send + 'scope> Scope<'scope, 'env, T> {
/// [`Scope::spawn`] instead, unless the provided future needs to run on the scope's thread.
///
/// For more information, see [`TaskPool::scope`].
pub fn spawn_on_scope<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
let task = self.task_scope_executor.spawn(f);
pub fn spawn_on_main<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
let main_thread_spawner: &MainThreadSpawner<'scope> =
unsafe { mem::transmute(&self.main_thread_spawner) };
let task = main_thread_spawner.spawn(f);
// ConcurrentQueue only errors when closed or full, but we never
// close and use an unbouded queue, so it is safe to unwrap
self.spawned.push(task).unwrap();
Expand Down Expand Up @@ -473,7 +464,7 @@ mod tests {
});
} else {
let count_clone = local_count.clone();
scope.spawn_on_scope(async move {
scope.spawn_on_main(async move {
if *foo != 42 {
panic!("not 42!?!?")
} else {
Expand Down Expand Up @@ -514,7 +505,7 @@ mod tests {
});
let spawner = std::thread::current().id();
let inner_count_clone = count_clone.clone();
scope.spawn_on_scope(async move {
scope.spawn_on_main(async move {
inner_count_clone.fetch_add(1, Ordering::Release);
if std::thread::current().id() != spawner {
// NOTE: This check is using an atomic rather than simply panicing the
Expand Down Expand Up @@ -589,7 +580,7 @@ mod tests {
inner_count_clone.fetch_add(1, Ordering::Release);

// spawning on the scope from another thread runs the futures on the scope's thread
scope.spawn_on_scope(async move {
scope.spawn_on_main(async move {
inner_count_clone.fetch_add(1, Ordering::Release);
if std::thread::current().id() != spawner {
// NOTE: This check is using an atomic rather than simply panicing the
Expand Down
24 changes: 0 additions & 24 deletions crates/bevy_tasks/src/usages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ use std::ops::Deref;
static COMPUTE_TASK_POOL: OnceCell<ComputeTaskPool> = OnceCell::new();
static ASYNC_COMPUTE_TASK_POOL: OnceCell<AsyncComputeTaskPool> = OnceCell::new();
static IO_TASK_POOL: OnceCell<IoTaskPool> = OnceCell::new();
static MAIN_THREAD_EXECUTOR: OnceCell<MainThreadExecutor> = OnceCell::new();

/// A newtype for a task pool for CPU-intensive work that must be completed to deliver the next
/// frame
Expand Down Expand Up @@ -111,29 +110,6 @@ impl Deref for IoTaskPool {
}
}

pub struct MainThreadExecutor(async_executor::Executor<'static>);

impl MainThreadExecutor {
pub fn init() -> &'static Self {
MAIN_THREAD_EXECUTOR.get_or_init(|| Self(async_executor::Executor::new()))
}

pub fn get() -> &'static Self {
MAIN_THREAD_EXECUTOR.get().expect(
"A MainThreadExecutor has not been initialize yet. Please call \
MainThreadExecutor::init beforehand",
)
}
}

impl Deref for MainThreadExecutor {
type Target = async_executor::Executor<'static>;

fn deref(&self) -> &Self::Target {
&self.0
}
}

/// Used by `bevy_core` to tick the global tasks pools on the main thread.
/// This will run a maximum of 100 local tasks per executor per call to this function.
#[cfg(not(target_arch = "wasm32"))]
Expand Down

0 comments on commit 36ccd85

Please sign in to comment.