diff --git a/crates/bevy_asset/src/lib.rs b/crates/bevy_asset/src/lib.rs index 62808e38a42c24..5bb5cbc2070a69 100644 --- a/crates/bevy_asset/src/lib.rs +++ b/crates/bevy_asset/src/lib.rs @@ -32,6 +32,7 @@ use bevy_ecs::{ system::IntoSystem, }; use bevy_tasks::IoTaskPool; +use std::ops::Deref; /// The names of asset stages in an App Schedule #[derive(Debug, Hash, PartialEq, Eq, Clone, StageLabel)] @@ -83,7 +84,7 @@ impl Plugin for AssetPlugin { .world() .get_resource::() .expect("`IoTaskPool` resource not found.") - .0 + .deref() .clone(); let source = create_platform_default_asset_io(app); diff --git a/crates/bevy_core/src/task_pool_options.rs b/crates/bevy_core/src/task_pool_options.rs index 19c9dad5bfe2fc..08eb10e609a01f 100644 --- a/crates/bevy_core/src/task_pool_options.rs +++ b/crates/bevy_core/src/task_pool_options.rs @@ -109,12 +109,16 @@ impl DefaultTaskPoolOptions { trace!("IO Threads: {}", io_threads); remaining_threads = remaining_threads.saturating_sub(io_threads); - world.insert_resource(IoTaskPool( - TaskPoolBuilder::default() - .num_threads(io_threads) - .thread_name("IO Task Pool".to_string()) - .build(), - )); + let task_pool = TaskPoolBuilder::default() + .num_threads(io_threads) + .thread_name("IO Task Pool".to_string()) + .build(); + + let io_task_pool = IoTaskPool::init(task_pool) + .map(|pool| pool.clone()) + .unwrap_or_else(|_| IoTaskPool::get().clone()); + + world.insert_resource(io_task_pool); } if !world.contains_resource::() { @@ -126,12 +130,16 @@ impl DefaultTaskPoolOptions { trace!("Async Compute Threads: {}", async_compute_threads); remaining_threads = remaining_threads.saturating_sub(async_compute_threads); - world.insert_resource(AsyncComputeTaskPool( - TaskPoolBuilder::default() - .num_threads(async_compute_threads) - .thread_name("Async Compute Task Pool".to_string()) - .build(), - )); + let task_pool = TaskPoolBuilder::default() + .num_threads(async_compute_threads) + .thread_name("Async Compute Task Pool".to_string()) + .build(); + + let async_task_pool = AsyncComputeTaskPool::init(task_pool) + .map(|pool| pool.clone()) + .unwrap_or_else(|_| AsyncComputeTaskPool::get().clone()); + + world.insert_resource(async_task_pool); } if !world.contains_resource::() { @@ -142,12 +150,17 @@ impl DefaultTaskPoolOptions { .get_number_of_threads(remaining_threads, total_threads); trace!("Compute Threads: {}", compute_threads); - world.insert_resource(ComputeTaskPool( - TaskPoolBuilder::default() - .num_threads(compute_threads) - .thread_name("Compute Task Pool".to_string()) - .build(), - )); + + let task_pool = TaskPoolBuilder::default() + .num_threads(compute_threads) + .thread_name("Compute Task Pool".to_string()) + .build(); + + let compute_task_pool = ComputeTaskPool::init(task_pool) + .map(|pool| pool.clone()) + .unwrap_or_else(|_| ComputeTaskPool::get().clone()); + + world.insert_resource(compute_task_pool); } } } diff --git a/crates/bevy_ecs/src/schedule/executor_parallel.rs b/crates/bevy_ecs/src/schedule/executor_parallel.rs index af989ef9eecc10..2351a3dbe6bb93 100644 --- a/crates/bevy_ecs/src/schedule/executor_parallel.rs +++ b/crates/bevy_ecs/src/schedule/executor_parallel.rs @@ -115,7 +115,11 @@ impl ParallelSystemExecutor for ParallelExecutor { self.update_archetypes(systems, world); let compute_pool = world - .get_resource_or_insert_with(|| ComputeTaskPool(TaskPool::default())) + .get_resource_or_insert_with(|| { + ComputeTaskPool::init(TaskPool::default()) + .map(|pool| pool.clone()) + .unwrap_or_else(|_| ComputeTaskPool::get().clone()) + }) .clone(); compute_pool.scope(|scope| { self.prepare_systems(scope, systems, world); diff --git a/crates/bevy_tasks/Cargo.toml b/crates/bevy_tasks/Cargo.toml index f2c1bf41f396b3..08b0092c00fb4d 100644 --- a/crates/bevy_tasks/Cargo.toml +++ b/crates/bevy_tasks/Cargo.toml @@ -22,5 +22,7 @@ async-executor = "1.3.0" async-channel = "1.4.2" instant = { version = "0.1", features = ["wasm-bindgen"] } num_cpus = "1" +once_cell = "1.7" + [target.'cfg(target_arch = "wasm32")'.dependencies] wasm-bindgen-futures = "0.4" diff --git a/crates/bevy_tasks/src/usages.rs b/crates/bevy_tasks/src/usages.rs index 7adf60639b116c..08ecb05263905c 100644 --- a/crates/bevy_tasks/src/usages.rs +++ b/crates/bevy_tasks/src/usages.rs @@ -11,12 +11,40 @@ //! for consumption. (likely via channels) use super::TaskPool; +use once_cell::sync::OnceCell; use std::ops::Deref; +static COMPUTE_TASK_POOL: OnceCell = OnceCell::new(); +static ASYNC_COMPUTE_TASK_POOL: OnceCell = OnceCell::new(); +static IO_TASK_POOL: OnceCell = OnceCell::new(); + /// A newtype for a task pool for CPU-intensive work that must be completed to deliver the next /// frame #[derive(Clone, Debug)] -pub struct ComputeTaskPool(pub TaskPool); +pub struct ComputeTaskPool(TaskPool); + +impl ComputeTaskPool { + /// Initializes the global ComputeTaskPool instance. + /// + /// Returns the provided `[TaskPool]` if the global instance has already been initialized. + pub fn init(task_pool: TaskPool) -> Result<&'static Self, TaskPool> { + COMPUTE_TASK_POOL + .set(Self(task_pool)) + .map(|_| Self::get()) + .map_err(|pool| pool.0) + } + + /// Gets the global ComputeTaskPool instance. + /// + /// # Panics + /// Panics if no pool has been initialized yet. + pub fn get() -> &'static Self { + COMPUTE_TASK_POOL.get().expect( + "A ComputeTaskPool has not been initialized yet. Please call \ + ComputeTaskPool::init beforehand.", + ) + } +} impl Deref for ComputeTaskPool { type Target = TaskPool; @@ -28,7 +56,30 @@ impl Deref for ComputeTaskPool { /// A newtype for a task pool for CPU-intensive work that may span across multiple frames #[derive(Clone, Debug)] -pub struct AsyncComputeTaskPool(pub TaskPool); +pub struct AsyncComputeTaskPool(TaskPool); + +impl AsyncComputeTaskPool { + /// Initializes the global AsyncComputeTaskPool instance. + /// + /// Returns the provided `[TaskPool]` if the global instance has already been initialized. + pub fn init(task_pool: TaskPool) -> Result<&'static Self, TaskPool> { + ASYNC_COMPUTE_TASK_POOL + .set(Self(task_pool)) + .map(|_| Self::get()) + .map_err(|pool| pool.0) + } + + /// Gets the global AsyncComputeTaskPool instance. + /// + /// # Panics + /// Panics if no pool has been initialized yet. + pub fn get() -> &'static Self { + ASYNC_COMPUTE_TASK_POOL.get().expect( + "A AsyncComputeTaskPool has not been initialized yet. Please call \ + AsyncComputeTaskPool::init beforehand.", + ) + } +} impl Deref for AsyncComputeTaskPool { type Target = TaskPool; @@ -41,7 +92,30 @@ impl Deref for AsyncComputeTaskPool { /// A newtype for a task pool for IO-intensive work (i.e. tasks that spend very little time in a /// "woken" state) #[derive(Clone, Debug)] -pub struct IoTaskPool(pub TaskPool); +pub struct IoTaskPool(TaskPool); + +impl IoTaskPool { + /// Initializes the global IoTaskPool instance. + /// + /// Returns the provided `[TaskPool]` if the global instance has already been initialized. + pub fn init(task_pool: TaskPool) -> Result<&'static Self, TaskPool> { + IO_TASK_POOL + .set(Self(task_pool)) + .map(|_| Self::get()) + .map_err(|pool| pool.0) + } + + /// Gets the global IoTaskPool instance. + /// + /// # Panics + /// Panics if no pool has been initialized yet. + pub fn get() -> &'static Self { + IO_TASK_POOL.get().expect( + "A IoTaskPool has not been initialized yet. Please call \ + IoTaskPool::init beforehand.", + ) + } +} impl Deref for IoTaskPool { type Target = TaskPool;