Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consolidate together Bevy's TaskPools #12090

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions crates/bevy_asset/src/processor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::{
};
use bevy_ecs::prelude::*;
use bevy_log::{debug, error, trace, warn};
use bevy_tasks::IoTaskPool;
use bevy_tasks::ComputeTaskPool;
james7132 marked this conversation as resolved.
Show resolved Hide resolved
use bevy_utils::{BoxedFuture, HashMap, HashSet};
use futures_io::ErrorKind;
use futures_lite::{AsyncReadExt, AsyncWriteExt, StreamExt};
Expand Down Expand Up @@ -165,7 +165,7 @@ impl AssetProcessor {
pub fn process_assets(&self) {
let start_time = std::time::Instant::now();
debug!("Processing Assets");
IoTaskPool::get().scope(|scope| {
ComputeTaskPool::get().scope(|scope| {
scope.spawn(async move {
self.initialize().await.unwrap();
for source in self.sources().iter_processed() {
Expand Down Expand Up @@ -315,7 +315,7 @@ impl AssetProcessor {
#[cfg(any(target_arch = "wasm32", not(feature = "multi-threaded")))]
error!("AddFolder event cannot be handled in single threaded mode (or WASM) yet.");
#[cfg(all(not(target_arch = "wasm32"), feature = "multi-threaded"))]
IoTaskPool::get().scope(|scope| {
ComputeTaskPool::get().scope(|scope| {
scope.spawn(async move {
self.process_assets_internal(scope, source, path)
.await
Expand Down Expand Up @@ -457,7 +457,7 @@ impl AssetProcessor {
loop {
let mut check_reprocess_queue =
std::mem::take(&mut self.data.asset_infos.write().await.check_reprocess_queue);
IoTaskPool::get().scope(|scope| {
ComputeTaskPool::get().scope(|scope| {
for path in check_reprocess_queue.drain(..) {
let processor = self.clone();
let source = self.get_source(path.source()).unwrap();
Expand Down
4 changes: 2 additions & 2 deletions crates/bevy_asset/src/server/loaders.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
};
use async_broadcast::RecvError;
use bevy_log::{error, warn};
use bevy_tasks::IoTaskPool;
use bevy_tasks::ComputeTaskPool;
use bevy_utils::{HashMap, TypeIdMap};
use std::{any::TypeId, sync::Arc};
use thiserror::Error;
Expand Down Expand Up @@ -78,7 +78,7 @@ impl AssetLoaders {
match maybe_loader {
MaybeAssetLoader::Ready(_) => unreachable!(),
MaybeAssetLoader::Pending { sender, .. } => {
IoTaskPool::get()
ComputeTaskPool::get()
.spawn(async move {
let _ = sender.broadcast(loader).await;
})
Expand Down
10 changes: 5 additions & 5 deletions crates/bevy_asset/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::{
};
use bevy_ecs::prelude::*;
use bevy_log::{error, info};
use bevy_tasks::IoTaskPool;
use bevy_tasks::ComputeTaskPool;
use bevy_utils::{CowArc, HashSet};
use crossbeam_channel::{Receiver, Sender};
use futures_lite::StreamExt;
Expand Down Expand Up @@ -296,7 +296,7 @@ impl AssetServer {
if should_load {
let owned_handle = Some(handle.clone().untyped());
let server = self.clone();
IoTaskPool::get()
ComputeTaskPool::get()
.spawn(async move {
if let Err(err) = server.load_internal(owned_handle, path, false, None).await {
error!("{}", err);
Expand Down Expand Up @@ -366,7 +366,7 @@ impl AssetServer {
let id = handle.id().untyped();

let server = self.clone();
IoTaskPool::get()
ComputeTaskPool::get()
.spawn(async move {
let path_clone = path.clone();
match server.load_untyped_async(path).await {
Expand Down Expand Up @@ -551,7 +551,7 @@ impl AssetServer {
pub fn reload<'a>(&self, path: impl Into<AssetPath<'a>>) {
let server = self.clone();
let path = path.into().into_owned();
IoTaskPool::get()
ComputeTaskPool::get()
.spawn(async move {
let mut reloaded = false;

Expand Down Expand Up @@ -690,7 +690,7 @@ impl AssetServer {

let path = path.into_owned();
let server = self.clone();
IoTaskPool::get()
ComputeTaskPool::get()
.spawn(async move {
let Ok(source) = server.get_source(path.source()) else {
error!(
Expand Down
21 changes: 2 additions & 19 deletions crates/bevy_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,7 @@ fn register_math_types(app: &mut App) {
.register_type::<Vec<bevy_math::Vec3>>();
}

/// Setup of default task pools: [`AsyncComputeTaskPool`](bevy_tasks::AsyncComputeTaskPool),
/// [`ComputeTaskPool`](bevy_tasks::ComputeTaskPool), [`IoTaskPool`](bevy_tasks::IoTaskPool).
/// Setup of default task pool: [`ComputeTaskPool`](bevy_tasks::ComputeTaskPool).
#[derive(Default)]
pub struct TaskPoolPlugin {
/// Options for the [`TaskPool`](bevy_tasks::TaskPool) created at application start.
Expand Down Expand Up @@ -175,39 +174,23 @@ pub fn update_frame_count(mut frame_count: ResMut<FrameCount>) {
#[cfg(test)]
mod tests {
use super::*;
use bevy_tasks::prelude::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool};
use bevy_tasks::prelude::ComputeTaskPool;

#[test]
fn runs_spawn_local_tasks() {
let mut app = App::new();
app.add_plugins((TaskPoolPlugin::default(), TypeRegistrationPlugin));

let (async_tx, async_rx) = crossbeam_channel::unbounded();
AsyncComputeTaskPool::get()
.spawn_local(async move {
async_tx.send(()).unwrap();
})
.detach();

let (compute_tx, compute_rx) = crossbeam_channel::unbounded();
ComputeTaskPool::get()
.spawn_local(async move {
compute_tx.send(()).unwrap();
})
.detach();

let (io_tx, io_rx) = crossbeam_channel::unbounded();
IoTaskPool::get()
.spawn_local(async move {
io_tx.send(()).unwrap();
})
.detach();

app.run();

async_rx.try_recv().unwrap();
compute_rx.try_recv().unwrap();
io_rx.try_recv().unwrap();
}

#[test]
Expand Down
118 changes: 7 additions & 111 deletions crates/bevy_core/src/task_pool_options.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,6 @@
use bevy_tasks::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool, TaskPoolBuilder};
use bevy_tasks::{ComputeTaskPool, TaskPoolBuilder};
use bevy_utils::tracing::trace;

/// Defines a simple way to determine how many threads to use given the number of remaining cores
/// and number of total cores
#[derive(Clone, Debug)]
pub struct TaskPoolThreadAssignmentPolicy {
/// Force using at least this many threads
pub min_threads: usize,
/// Under no circumstance use more than this many threads for this pool
pub max_threads: usize,
/// Target using this percentage of total cores, clamped by min_threads and max_threads. It is
/// permitted to use 1.0 to try to use all remaining threads
pub percent: f32,
}

impl TaskPoolThreadAssignmentPolicy {
/// Determine the number of threads to use for this task pool
fn get_number_of_threads(&self, remaining_threads: usize, total_threads: usize) -> usize {
assert!(self.percent >= 0.0);
let mut desired = (total_threads as f32 * self.percent).round() as usize;

// Limit ourselves to the number of cores available
desired = desired.min(remaining_threads);

// Clamp by min_threads, max_threads. (This may result in us using more threads than are
// available, this is intended. An example case where this might happen is a device with
// <= 2 threads.
desired.clamp(self.min_threads, self.max_threads)
}
}

/// Helper for configuring and creating the default task pools. For end-users who want full control,
/// set up [`TaskPoolPlugin`](super::TaskPoolPlugin)
#[derive(Clone, Debug)]
Expand All @@ -40,13 +11,6 @@ pub struct TaskPoolOptions {
/// If the number of physical cores is greater than max_total_threads, force using
/// max_total_threads
pub max_total_threads: usize,

/// Used to determine number of IO threads to allocate
pub io: TaskPoolThreadAssignmentPolicy,
/// Used to determine number of async compute threads to allocate
pub async_compute: TaskPoolThreadAssignmentPolicy,
/// Used to determine number of compute threads to allocate
pub compute: TaskPoolThreadAssignmentPolicy,
}

impl Default for TaskPoolOptions {
Expand All @@ -55,27 +19,6 @@ impl Default for TaskPoolOptions {
// By default, use however many cores are available on the system
min_total_threads: 1,
max_total_threads: usize::MAX,

// Use 25% of cores for IO, at least 1, no more than 4
io: TaskPoolThreadAssignmentPolicy {
min_threads: 1,
max_threads: 4,
percent: 0.25,
},

// Use 25% of cores for async compute, at least 1, no more than 4
async_compute: TaskPoolThreadAssignmentPolicy {
min_threads: 1,
max_threads: 4,
percent: 0.25,
},

// Use all remaining cores for compute (at least 1)
compute: TaskPoolThreadAssignmentPolicy {
min_threads: 1,
max_threads: usize::MAX,
percent: 1.0, // This 1.0 here means "whatever is left over"
},
}
}
}
Expand All @@ -86,7 +29,6 @@ impl TaskPoolOptions {
TaskPoolOptions {
min_total_threads: thread_count,
max_total_threads: thread_count,
..Default::default()
}
}

Expand All @@ -96,57 +38,11 @@ impl TaskPoolOptions {
.clamp(self.min_total_threads, self.max_total_threads);
trace!("Assigning {} cores to default task pools", total_threads);

let mut remaining_threads = total_threads;

{
// Determine the number of IO threads we will use
let io_threads = self
.io
.get_number_of_threads(remaining_threads, total_threads);

trace!("IO Threads: {}", io_threads);
remaining_threads = remaining_threads.saturating_sub(io_threads);

IoTaskPool::get_or_init(|| {
TaskPoolBuilder::default()
.num_threads(io_threads)
.thread_name("IO Task Pool".to_string())
.build()
});
}

{
// Determine the number of async compute threads we will use
let async_compute_threads = self
.async_compute
.get_number_of_threads(remaining_threads, total_threads);

trace!("Async Compute Threads: {}", async_compute_threads);
remaining_threads = remaining_threads.saturating_sub(async_compute_threads);

AsyncComputeTaskPool::get_or_init(|| {
TaskPoolBuilder::default()
.num_threads(async_compute_threads)
.thread_name("Async Compute Task Pool".to_string())
.build()
});
}

{
// Determine the number of compute threads we will use
// This is intentionally last so that an end user can specify 1.0 as the percent
let compute_threads = self
.compute
.get_number_of_threads(remaining_threads, total_threads);

trace!("Compute Threads: {}", compute_threads);

ComputeTaskPool::get_or_init(|| {
TaskPoolBuilder::default()
.num_threads(compute_threads)
.thread_name("Compute Task Pool".to_string())
.build()
});
}
ComputeTaskPool::get_or_init(|| {
TaskPoolBuilder::default()
.num_threads(total_threads)
.thread_name("Compute Task Pool".to_string())
.build()
});
}
}
4 changes: 2 additions & 2 deletions crates/bevy_gltf/src/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ use bevy_render::{
};
use bevy_scene::Scene;
#[cfg(not(target_arch = "wasm32"))]
use bevy_tasks::IoTaskPool;
use bevy_tasks::ComputeTaskPool;
use bevy_transform::components::Transform;
use bevy_utils::{
smallvec::{smallvec, SmallVec},
Expand Down Expand Up @@ -348,7 +348,7 @@ async fn load_gltf<'a, 'b, 'c>(
}
} else {
#[cfg(not(target_arch = "wasm32"))]
IoTaskPool::get()
ComputeTaskPool::get()
.scope(|scope| {
gltf.textures().for_each(|gltf_texture| {
let parent_path = load_context.path().parent().unwrap();
Expand Down
2 changes: 1 addition & 1 deletion crates/bevy_render/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ impl Plugin for RenderPlugin {
};
// In wasm, spawn a task and detach it for execution
#[cfg(target_arch = "wasm32")]
bevy_tasks::IoTaskPool::get()
bevy_tasks::ComputeTaskPool::get()
.spawn_local(async_renderer)
.detach();
// Otherwise, just block for it to complete
Expand Down
17 changes: 9 additions & 8 deletions crates/bevy_render/src/render_resource/pipeline_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use bevy_utils::{
use naga::valid::Capabilities;
use std::{
borrow::Cow,
future::Future,
hash::Hash,
mem,
ops::Deref,
Expand Down Expand Up @@ -698,7 +697,7 @@ impl PipelineCache {
let shader_cache = self.shader_cache.clone();
let layout_cache = self.layout_cache.clone();
create_pipeline_task(
async move {
move || {
let mut shader_cache = shader_cache.lock().unwrap();
let mut layout_cache = layout_cache.lock().unwrap();

Expand Down Expand Up @@ -797,7 +796,7 @@ impl PipelineCache {
let shader_cache = self.shader_cache.clone();
let layout_cache = self.layout_cache.clone();
create_pipeline_task(
async move {
move || {
let mut shader_cache = shader_cache.lock().unwrap();
let mut layout_cache = layout_cache.lock().unwrap();

Expand Down Expand Up @@ -953,14 +952,16 @@ impl PipelineCache {
feature = "multi-threaded"
))]
fn create_pipeline_task(
task: impl Future<Output = Result<Pipeline, PipelineCacheError>> + Send + 'static,
task: impl FnOnce() -> Result<Pipeline, PipelineCacheError> + Send + 'static,
sync: bool,
) -> CachedPipelineState {
if !sync {
return CachedPipelineState::Creating(bevy_tasks::AsyncComputeTaskPool::get().spawn(task));
return CachedPipelineState::Creating(
bevy_tasks::ComputeTaskPool::get().spawn_blocking(task),
);
}

match futures_lite::future::block_on(task) {
match task() {
Ok(pipeline) => CachedPipelineState::Ok(pipeline),
Err(err) => CachedPipelineState::Err(err),
}
Expand All @@ -972,10 +973,10 @@ fn create_pipeline_task(
not(feature = "multi-threaded")
))]
fn create_pipeline_task(
task: impl Future<Output = Result<Pipeline, PipelineCacheError>> + Send + 'static,
task: impl FnOnce() -> Result<Pipeline, PipelineCacheError> + Send + 'static,
_sync: bool,
) -> CachedPipelineState {
match futures_lite::future::block_on(task) {
match task() {
Ok(pipeline) => CachedPipelineState::Ok(pipeline),
Err(err) => CachedPipelineState::Err(err),
}
Expand Down
Loading
Loading