Skip to content

Commit

Permalink
Use our own local pool with proper drop impl
Browse files Browse the repository at this point in the history
  • Loading branch information
rklaehn committed Jul 16, 2024
1 parent d17ffa3 commit 8f3469c
Show file tree
Hide file tree
Showing 14 changed files with 415 additions and 58 deletions.
2 changes: 0 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion iroh-blobs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ smallvec = { version = "1.10.0", features = ["serde", "const_new"] }
tempfile = { version = "3.10.0", optional = true }
thiserror = "1"
tokio = { version = "1", features = ["fs"] }
tokio-util = { version = "0.7", features = ["io-util", "io", "rt"] }
tokio-util = { version = "0.7", features = ["io-util", "io"] }
tracing = "0.1"
tracing-futures = "0.2.5"

Expand Down
9 changes: 6 additions & 3 deletions iroh-blobs/examples/provide-bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@
//! cargo run --example provide-bytes collection
//! To provide a collection (multiple blobs)
use anyhow::Result;
use tokio_util::task::LocalPoolHandle;
use tracing::warn;
use tracing_subscriber::{prelude::*, EnvFilter};

use iroh_blobs::{format::collection::Collection, Hash};
use iroh_blobs::{
format::collection::Collection,
util::local_pool::{self, LocalPool},
Hash,
};

mod connect;
use connect::{make_and_write_certs, make_server_endpoint, CERT_PATH};
Expand Down Expand Up @@ -82,7 +85,7 @@ async fn main() -> Result<()> {
println!("\nfetch the content using a stream by running the following example:\n\ncargo run --example fetch-stream {hash} \"{addr}\" {format}\n");

// create a new local pool handle with 1 worker thread
let lp = LocalPoolHandle::new(1);
let lp = LocalPool::new(local_pool::Config::default());

let accept_task = tokio::spawn(async move {
while let Some(incoming) = endpoint.accept().await {
Expand Down
6 changes: 3 additions & 3 deletions iroh-blobs/src/downloader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@ use tokio::{
sync::{mpsc, oneshot},
task::JoinSet,
};
use tokio_util::{sync::CancellationToken, task::LocalPoolHandle, time::delay_queue};
use tokio_util::{sync::CancellationToken, time::delay_queue};
use tracing::{debug, error_span, trace, warn, Instrument};

use crate::{
get::{db::DownloadProgress, Stats},
store::Store,
util::progress::ProgressSender,
util::{local_pool::LocalPoolHandle, progress::ProgressSender},
};

mod get;
Expand Down Expand Up @@ -338,7 +338,7 @@ impl Downloader {

service.run().instrument(error_span!("downloader", %me))
};
rt.spawn_pinned(create_future);
let _ = rt.spawn_pinned(create_future);
Self {
next_id: Arc::new(AtomicU64::new(0)),
msg_tx,
Expand Down
55 changes: 36 additions & 19 deletions iroh-blobs/src/downloader/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ use iroh_net::key::SecretKey;

use crate::{
get::{db::BlobId, progress::TransferState},
util::progress::{FlumeProgressSender, IdGenerator},
util::{
local_pool::LocalPool,
progress::{FlumeProgressSender, IdGenerator},
},
};

use super::*;
Expand All @@ -23,7 +26,7 @@ impl Downloader {
dialer: dialer::TestingDialer,
getter: getter::TestingGetter,
concurrency_limits: ConcurrencyLimits,
) -> Self {
) -> (Self, LocalPool) {
Self::spawn_for_test_with_retry_config(
dialer,
getter,
Expand All @@ -37,21 +40,25 @@ impl Downloader {
getter: getter::TestingGetter,
concurrency_limits: ConcurrencyLimits,
retry_config: RetryConfig,
) -> Self {
) -> (Self, LocalPool) {
let (msg_tx, msg_rx) = mpsc::channel(super::SERVICE_CHANNEL_CAPACITY);

LocalPoolHandle::new(1).spawn_pinned(move || async move {
let lp = LocalPool::new(Default::default());
let _ = lp.spawn_pinned(move || async move {
// we want to see the logs of the service
let _guard = iroh_test::logging::setup();

let service = Service::new(getter, dialer, concurrency_limits, retry_config, msg_rx);
service.run().await
});

Downloader {
next_id: Arc::new(AtomicU64::new(0)),
msg_tx,
}
(
Downloader {
next_id: Arc::new(AtomicU64::new(0)),
msg_tx,
},
lp,
)
}
}

Expand All @@ -63,7 +70,8 @@ async fn smoke_test() {
let getter = getter::TestingGetter::default();
let concurrency_limits = ConcurrencyLimits::default();

let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);
let (downloader, _lp) =
Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);

// send a request and make sure the peer is requested the corresponding download
let peer = SecretKey::generate().public();
Expand All @@ -88,7 +96,8 @@ async fn deduplication() {
getter.set_request_duration(Duration::from_secs(1));
let concurrency_limits = ConcurrencyLimits::default();

let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);
let (downloader, _lp) =
Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);

let peer = SecretKey::generate().public();
let kind: DownloadKind = HashAndFormat::raw(Hash::new([0u8; 32])).into();
Expand Down Expand Up @@ -119,7 +128,8 @@ async fn cancellation() {
getter.set_request_duration(Duration::from_millis(500));
let concurrency_limits = ConcurrencyLimits::default();

let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);
let (downloader, _lp) =
Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);

let peer = SecretKey::generate().public();
let kind_1: DownloadKind = HashAndFormat::raw(Hash::new([0u8; 32])).into();
Expand Down Expand Up @@ -158,7 +168,8 @@ async fn max_concurrent_requests_total() {
..Default::default()
};

let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);
let (downloader, _lp) =
Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);

// send the downloads
let peer = SecretKey::generate().public();
Expand Down Expand Up @@ -201,7 +212,8 @@ async fn max_concurrent_requests_per_peer() {
..Default::default()
};

let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);
let (downloader, _lp) =
Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);

// send the downloads
let peer = SecretKey::generate().public();
Expand Down Expand Up @@ -257,7 +269,8 @@ async fn concurrent_progress() {
}
.boxed()
}));
let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), Default::default());
let (downloader, _lp) =
Downloader::spawn_for_test(dialer.clone(), getter.clone(), Default::default());

let peer = SecretKey::generate().public();
let hash = Hash::new([0u8; 32]);
Expand Down Expand Up @@ -341,7 +354,8 @@ async fn long_queue() {
..Default::default()
};

let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);
let (downloader, _lp) =
Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);
// send the downloads
let nodes = [
SecretKey::generate().public(),
Expand Down Expand Up @@ -370,7 +384,8 @@ async fn fail_while_running() {
let _guard = iroh_test::logging::setup();
let dialer = dialer::TestingDialer::default();
let getter = getter::TestingGetter::default();
let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), Default::default());
let (downloader, _lp) =
Downloader::spawn_for_test(dialer.clone(), getter.clone(), Default::default());
let blob_fail = HashAndFormat::raw(Hash::new([1u8; 32]));
let blob_success = HashAndFormat::raw(Hash::new([2u8; 32]));

Expand Down Expand Up @@ -407,7 +422,8 @@ async fn retry_nodes_simple() {
let _guard = iroh_test::logging::setup();
let dialer = dialer::TestingDialer::default();
let getter = getter::TestingGetter::default();
let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), Default::default());
let (downloader, _lp) =
Downloader::spawn_for_test(dialer.clone(), getter.clone(), Default::default());
let node = SecretKey::generate().public();
let dial_attempts = Arc::new(AtomicUsize::new(0));
let dial_attempts2 = dial_attempts.clone();
Expand All @@ -432,7 +448,7 @@ async fn retry_nodes_fail() {
max_retries_per_node: 3,
};

let downloader = Downloader::spawn_for_test_with_retry_config(
let (downloader, _lp) = Downloader::spawn_for_test_with_retry_config(
dialer.clone(),
getter.clone(),
Default::default(),
Expand Down Expand Up @@ -472,7 +488,8 @@ async fn retry_nodes_jump_queue() {
..Default::default()
};

let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);
let (downloader, _lp) =
Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);

let good_node = SecretKey::generate().public();
let bad_node = SecretKey::generate().public();
Expand Down
21 changes: 13 additions & 8 deletions iroh-blobs/src/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ use iroh_io::stats::{
use iroh_io::{AsyncSliceReader, AsyncStreamWriter, TokioStreamWriter};
use iroh_net::endpoint::{self, RecvStream, SendStream};
use serde::{Deserialize, Serialize};
use tokio_util::task::LocalPoolHandle;
use tracing::{debug, debug_span, info, trace, warn};
use tracing_futures::Instrument;

use crate::hashseq::parse_hash_seq;
use crate::protocol::{GetRequest, RangeSpec, Request};
use crate::store::*;
use crate::util::local_pool::LocalPoolHandle;
use crate::util::Tag;
use crate::{BlobFormat, Hash};

Expand Down Expand Up @@ -302,14 +302,19 @@ pub async fn handle_connection<D: Map, E: EventSender>(
};
events.send(Event::ClientConnected { connection_id }).await;
let db = db.clone();
rt.spawn_pinned(|| {
async move {
if let Err(err) = handle_stream(db, reader, writer).await {
warn!("error: {err:#?}",);
let res = rt
.spawn_pinned_detached(|| {
async move {
if let Err(err) = handle_stream(db, reader, writer).await {
warn!("error: {err:#?}",);
}
}
}
.instrument(span)
});
.instrument(span)
})
.await;
if res.is_err() {
break;
}
}
}
.instrument(span)
Expand Down
5 changes: 3 additions & 2 deletions iroh-blobs/src/store/bao_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -878,7 +878,8 @@ mod tests {
decode_response_into_batch, local, make_wire_data, random_test_data, trickle, validate,
};
use tokio::task::JoinSet;
use tokio_util::task::LocalPoolHandle;

use crate::util::local_pool::LocalPool;

use super::*;

Expand Down Expand Up @@ -957,7 +958,7 @@ mod tests {
)),
hash.into(),
);
let local = LocalPoolHandle::new(4);
let local = LocalPool::new(Default::default());
let mut tasks = Vec::new();
for i in 0..4 {
let file = handle.writer();
Expand Down
7 changes: 5 additions & 2 deletions iroh-blobs/src/store/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ use iroh_base::rpc::RpcError;
use iroh_io::AsyncSliceReader;
use serde::{Deserialize, Serialize};
use tokio::io::AsyncRead;
use tokio_util::task::LocalPoolHandle;

use crate::{
hashseq::parse_hash_seq,
protocol::RangeSpec,
util::{
local_pool::{self, LocalPool},
progress::{BoxedProgressSender, IdGenerator, ProgressSender},
Tag,
},
Expand Down Expand Up @@ -423,7 +423,10 @@ async fn validate_impl(
use futures_buffered::BufferedStreamExt;

let validate_parallelism: usize = num_cpus::get();
let lp = LocalPoolHandle::new(validate_parallelism);
let lp = LocalPool::new(local_pool::Config {
threads: validate_parallelism,
..Default::default()
});
let complete = store.blobs().await?.collect::<io::Result<Vec<_>>>()?;
let partial = store
.partial_blobs()
Expand Down
1 change: 1 addition & 0 deletions iroh-blobs/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub mod progress;
pub use mem_or_file::MemOrFile;
mod sparse_mem_file;
pub use sparse_mem_file::SparseMemFile;
pub mod local_pool;

/// A tag
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, From, Into)]
Expand Down
Loading

0 comments on commit 8f3469c

Please sign in to comment.