diff --git a/iroh/src/node.rs b/iroh/src/node.rs index c55f649010..3209f8b9fa 100644 --- a/iroh/src/node.rs +++ b/iroh/src/node.rs @@ -59,16 +59,18 @@ use iroh_net::{ endpoint::{DirectAddrsStream, RemoteInfo}, AddrInfo, Endpoint, NodeAddr, }; -use protocol::BlobsProtocol; +use protocol::blobs::BlobsProtocol; use quic_rpc::{transport::ServerEndpoint as _, RpcServer}; use tokio::task::{JoinError, JoinSet}; use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle}; use tracing::{debug, error, info, info_span, trace, warn, Instrument}; -use crate::node::{docs::DocsEngine, nodes_storage::store_node_addrs, protocol::ProtocolMap}; +use crate::node::{ + nodes_storage::store_node_addrs, + protocol::{docs::DocsProtocol, ProtocolMap}, +}; mod builder; -mod docs; mod nodes_storage; mod protocol; mod rpc; @@ -296,7 +298,7 @@ impl NodeInner { if let GcPolicy::Interval(gc_period) = gc_policy { let protocols = protocols.clone(); let handle = local_pool.spawn(move || async move { - let docs_engine = protocols.get_typed::(DOCS_ALPN); + let docs_engine = protocols.get_typed::(DOCS_ALPN); let blobs = protocols .get_typed::>(iroh_blobs::protocol::ALPN) .expect("missing blobs"); diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index c037f77b1c..9ed124c0ac 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -32,14 +32,12 @@ use tokio::task::JoinError; use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle}; use tracing::{debug, error_span, trace, Instrument}; -use super::{ - docs::DocsEngine, rpc_status::RpcStatus, IrohServerEndpoint, JoinErrToStr, Node, NodeInner, -}; +use super::{rpc_status::RpcStatus, IrohServerEndpoint, JoinErrToStr, Node, NodeInner}; use crate::{ client::RPC_ALPN, node::{ nodes_storage::load_node_addrs, - protocol::{BlobsProtocol, ProtocolMap}, + protocol::{blobs::BlobsProtocol, docs::DocsProtocol, ProtocolMap}, ProtocolHandler, }, rpc_protocol::RpcService, @@ -654,8 +652,8 @@ where let downloader = Downloader::new(self.blobs_store.clone(), endpoint.clone(), lp.clone()); // Spawn the docs engine, if enabled. - // This returns None for DocsStorage::Disabled, otherwise Some(DocsEngine). - let docs = DocsEngine::spawn( + // This returns None for DocsStorage::Disabled, otherwise Some(DocsProtocol). + let docs = DocsProtocol::spawn( self.docs_storage, self.blobs_store.clone(), self.storage.default_author_storage(), @@ -813,7 +811,7 @@ impl ProtocolBuilder { store: D, gossip: Gossip, downloader: Downloader, - docs: Option, + docs: Option, ) -> Self { // Register blobs. let blobs_proto = BlobsProtocol::new_with_events( diff --git a/iroh/src/node/protocol.rs b/iroh/src/node/protocol.rs index e7b0ddb3e7..2669d9ba1e 100644 --- a/iroh/src/node/protocol.rs +++ b/iroh/src/node/protocol.rs @@ -1,29 +1,13 @@ use std::{any::Any, collections::BTreeMap, fmt, sync::Arc}; -use anyhow::{anyhow, Result}; +use anyhow::Result; use futures_lite::future::Boxed as BoxedFuture; use futures_util::future::join_all; -use iroh_blobs::{ - downloader::{DownloadRequest, Downloader}, - get::{ - db::{DownloadProgress, GetState}, - Stats, - }, - provider::EventSender, - util::{ - local_pool::LocalPoolHandle, - progress::{AsyncChannelProgressSender, ProgressSender}, - SetTagOption, - }, - HashAndFormat, TempTag, -}; -use iroh_net::{endpoint::Connecting, Endpoint, NodeAddr}; -use tracing::{debug, warn}; +use iroh_net::endpoint::Connecting; -use crate::{ - client::blobs::DownloadMode, - rpc_protocol::blobs::{BatchId, DownloadRequest as BlobDownloadRequest}, -}; +pub(crate) mod blobs; +pub(crate) mod docs; +pub(crate) mod gossip; /// Handler for incoming connections. /// @@ -95,251 +79,3 @@ impl ProtocolMap { join_all(handlers).await; } } - -#[derive(Debug)] -pub(crate) struct BlobsProtocol { - rt: LocalPoolHandle, - store: S, - events: EventSender, - downloader: Downloader, - batches: tokio::sync::Mutex, -} - -/// Name used for logging when new node addresses are added from gossip. -const BLOB_DOWNLOAD_SOURCE_NAME: &str = "blob_download"; - -/// Keeps track of all the currently active batch operations of the blobs api. -#[derive(Debug, Default)] -pub(crate) struct BlobBatches { - /// Currently active batches - batches: BTreeMap, - /// Used to generate new batch ids. - max: u64, -} - -/// A single batch of blob operations -#[derive(Debug, Default)] -struct BlobBatch { - /// The tags in this batch. - tags: BTreeMap>, -} - -impl BlobBatches { - /// Create a new unique batch id. - pub(crate) fn create(&mut self) -> BatchId { - let id = self.max; - self.max += 1; - BatchId(id) - } - - /// Store a temp tag in a batch identified by a batch id. - pub(crate) fn store(&mut self, batch: BatchId, tt: TempTag) { - let entry = self.batches.entry(batch).or_default(); - entry.tags.entry(tt.hash_and_format()).or_default().push(tt); - } - - /// Remove a tag from a batch. - pub(crate) fn remove_one(&mut self, batch: BatchId, content: &HashAndFormat) -> Result<()> { - if let Some(batch) = self.batches.get_mut(&batch) { - if let Some(tags) = batch.tags.get_mut(content) { - tags.pop(); - if tags.is_empty() { - batch.tags.remove(content); - } - return Ok(()); - } - } - // this can happen if we try to upgrade a tag from an expired batch - anyhow::bail!("tag not found in batch"); - } - - /// Remove an entire batch. - pub(crate) fn remove(&mut self, batch: BatchId) { - self.batches.remove(&batch); - } -} - -impl BlobsProtocol { - pub(crate) fn new_with_events( - store: S, - rt: LocalPoolHandle, - events: EventSender, - downloader: Downloader, - ) -> Self { - Self { - rt, - store, - events, - downloader, - batches: Default::default(), - } - } - - pub(crate) fn store(&self) -> &S { - &self.store - } - - pub(crate) async fn batches(&self) -> tokio::sync::MutexGuard<'_, BlobBatches> { - self.batches.lock().await - } - - pub(crate) async fn download( - &self, - endpoint: Endpoint, - req: BlobDownloadRequest, - progress: AsyncChannelProgressSender, - ) -> Result<()> { - let BlobDownloadRequest { - hash, - format, - nodes, - tag, - mode, - } = req; - let hash_and_format = HashAndFormat { hash, format }; - let temp_tag = self.store.temp_tag(hash_and_format); - let stats = match mode { - DownloadMode::Queued => { - self.download_queued(endpoint, hash_and_format, nodes, progress.clone()) - .await? - } - DownloadMode::Direct => { - self.download_direct_from_nodes(endpoint, hash_and_format, nodes, progress.clone()) - .await? - } - }; - - progress.send(DownloadProgress::AllDone(stats)).await.ok(); - match tag { - SetTagOption::Named(tag) => { - self.store.set_tag(tag, Some(hash_and_format)).await?; - } - SetTagOption::Auto => { - self.store.create_tag(hash_and_format).await?; - } - } - drop(temp_tag); - - Ok(()) - } - - async fn download_queued( - &self, - endpoint: Endpoint, - hash_and_format: HashAndFormat, - nodes: Vec, - progress: AsyncChannelProgressSender, - ) -> Result { - let mut node_ids = Vec::with_capacity(nodes.len()); - let mut any_added = false; - for node in nodes { - node_ids.push(node.node_id); - if !node.info.is_empty() { - endpoint.add_node_addr_with_source(node, BLOB_DOWNLOAD_SOURCE_NAME)?; - any_added = true; - } - } - let can_download = !node_ids.is_empty() && (any_added || endpoint.discovery().is_some()); - anyhow::ensure!(can_download, "no way to reach a node for download"); - let req = DownloadRequest::new(hash_and_format, node_ids).progress_sender(progress); - let handle = self.downloader.queue(req).await; - let stats = handle.await?; - Ok(stats) - } - - #[tracing::instrument("download_direct", skip_all, fields(hash=%hash_and_format.hash.fmt_short()))] - async fn download_direct_from_nodes( - &self, - endpoint: Endpoint, - hash_and_format: HashAndFormat, - nodes: Vec, - progress: AsyncChannelProgressSender, - ) -> Result { - let mut last_err = None; - let mut remaining_nodes = nodes.len(); - let mut nodes_iter = nodes.into_iter(); - 'outer: loop { - match iroh_blobs::get::db::get_to_db_in_steps( - self.store.clone(), - hash_and_format, - progress.clone(), - ) - .await? - { - GetState::Complete(stats) => return Ok(stats), - GetState::NeedsConn(needs_conn) => { - let (conn, node_id) = 'inner: loop { - match nodes_iter.next() { - None => break 'outer, - Some(node) => { - remaining_nodes -= 1; - let node_id = node.node_id; - if node_id == endpoint.node_id() { - debug!( - ?remaining_nodes, - "skip node {} (it is the node id of ourselves)", - node_id.fmt_short() - ); - continue 'inner; - } - match endpoint.connect(node, iroh_blobs::protocol::ALPN).await { - Ok(conn) => break 'inner (conn, node_id), - Err(err) => { - debug!( - ?remaining_nodes, - "failed to connect to {}: {err}", - node_id.fmt_short() - ); - continue 'inner; - } - } - } - } - }; - match needs_conn.proceed(conn).await { - Ok(stats) => return Ok(stats), - Err(err) => { - warn!( - ?remaining_nodes, - "failed to download from {}: {err}", - node_id.fmt_short() - ); - last_err = Some(err); - } - } - } - } - } - match last_err { - Some(err) => Err(err.into()), - None => Err(anyhow!("No nodes to download from provided")), - } - } -} - -impl ProtocolHandler for BlobsProtocol { - fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { - Box::pin(async move { - iroh_blobs::provider::handle_connection( - conn.await?, - self.store.clone(), - self.events.clone(), - self.rt.clone(), - ) - .await; - Ok(()) - }) - } - - fn shutdown(self: Arc) -> BoxedFuture<()> { - Box::pin(async move { - self.store.shutdown().await; - }) - } -} - -impl ProtocolHandler for iroh_gossip::net::Gossip { - fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { - Box::pin(async move { self.handle_connection(conn.await?).await }) - } -} diff --git a/iroh/src/node/protocol/blobs.rs b/iroh/src/node/protocol/blobs.rs new file mode 100644 index 0000000000..f385c8361f --- /dev/null +++ b/iroh/src/node/protocol/blobs.rs @@ -0,0 +1,268 @@ +use std::{collections::BTreeMap, sync::Arc}; + +use anyhow::{anyhow, Result}; +use futures_lite::future::Boxed as BoxedFuture; +use iroh_blobs::{ + downloader::{DownloadRequest, Downloader}, + get::{ + db::{DownloadProgress, GetState}, + Stats, + }, + provider::EventSender, + util::{ + local_pool::LocalPoolHandle, + progress::{AsyncChannelProgressSender, ProgressSender}, + SetTagOption, + }, + HashAndFormat, TempTag, +}; +use iroh_net::{endpoint::Connecting, Endpoint, NodeAddr}; +use tracing::{debug, warn}; + +use super::ProtocolHandler; +use crate::{ + client::blobs::DownloadMode, + rpc_protocol::blobs::{BatchId, DownloadRequest as BlobDownloadRequest}, +}; + +#[derive(Debug)] +pub(crate) struct BlobsProtocol { + rt: LocalPoolHandle, + store: S, + events: EventSender, + downloader: Downloader, + batches: tokio::sync::Mutex, +} + +/// Name used for logging when new node addresses are added from gossip. +const BLOB_DOWNLOAD_SOURCE_NAME: &str = "blob_download"; + +/// Keeps track of all the currently active batch operations of the blobs api. +#[derive(Debug, Default)] +pub(crate) struct BlobBatches { + /// Currently active batches + batches: BTreeMap, + /// Used to generate new batch ids. + max: u64, +} + +/// A single batch of blob operations +#[derive(Debug, Default)] +struct BlobBatch { + /// The tags in this batch. + tags: BTreeMap>, +} + +impl BlobBatches { + /// Create a new unique batch id. + pub(crate) fn create(&mut self) -> BatchId { + let id = self.max; + self.max += 1; + BatchId(id) + } + + /// Store a temp tag in a batch identified by a batch id. + pub(crate) fn store(&mut self, batch: BatchId, tt: TempTag) { + let entry = self.batches.entry(batch).or_default(); + entry.tags.entry(tt.hash_and_format()).or_default().push(tt); + } + + /// Remove a tag from a batch. + pub(crate) fn remove_one(&mut self, batch: BatchId, content: &HashAndFormat) -> Result<()> { + if let Some(batch) = self.batches.get_mut(&batch) { + if let Some(tags) = batch.tags.get_mut(content) { + tags.pop(); + if tags.is_empty() { + batch.tags.remove(content); + } + return Ok(()); + } + } + // this can happen if we try to upgrade a tag from an expired batch + anyhow::bail!("tag not found in batch"); + } + + /// Remove an entire batch. + pub(crate) fn remove(&mut self, batch: BatchId) { + self.batches.remove(&batch); + } +} + +impl BlobsProtocol { + pub(crate) fn new_with_events( + store: S, + rt: LocalPoolHandle, + events: EventSender, + downloader: Downloader, + ) -> Self { + Self { + rt, + store, + events, + downloader, + batches: Default::default(), + } + } + + pub(crate) fn store(&self) -> &S { + &self.store + } + + pub(crate) async fn batches(&self) -> tokio::sync::MutexGuard<'_, BlobBatches> { + self.batches.lock().await + } + + pub(crate) async fn download( + &self, + endpoint: Endpoint, + req: BlobDownloadRequest, + progress: AsyncChannelProgressSender, + ) -> Result<()> { + let BlobDownloadRequest { + hash, + format, + nodes, + tag, + mode, + } = req; + let hash_and_format = HashAndFormat { hash, format }; + let temp_tag = self.store.temp_tag(hash_and_format); + let stats = match mode { + DownloadMode::Queued => { + self.download_queued(endpoint, hash_and_format, nodes, progress.clone()) + .await? + } + DownloadMode::Direct => { + self.download_direct_from_nodes(endpoint, hash_and_format, nodes, progress.clone()) + .await? + } + }; + + progress.send(DownloadProgress::AllDone(stats)).await.ok(); + match tag { + SetTagOption::Named(tag) => { + self.store.set_tag(tag, Some(hash_and_format)).await?; + } + SetTagOption::Auto => { + self.store.create_tag(hash_and_format).await?; + } + } + drop(temp_tag); + + Ok(()) + } + + async fn download_queued( + &self, + endpoint: Endpoint, + hash_and_format: HashAndFormat, + nodes: Vec, + progress: AsyncChannelProgressSender, + ) -> Result { + let mut node_ids = Vec::with_capacity(nodes.len()); + let mut any_added = false; + for node in nodes { + node_ids.push(node.node_id); + if !node.info.is_empty() { + endpoint.add_node_addr_with_source(node, BLOB_DOWNLOAD_SOURCE_NAME)?; + any_added = true; + } + } + let can_download = !node_ids.is_empty() && (any_added || endpoint.discovery().is_some()); + anyhow::ensure!(can_download, "no way to reach a node for download"); + let req = DownloadRequest::new(hash_and_format, node_ids).progress_sender(progress); + let handle = self.downloader.queue(req).await; + let stats = handle.await?; + Ok(stats) + } + + #[tracing::instrument("download_direct", skip_all, fields(hash=%hash_and_format.hash.fmt_short()))] + async fn download_direct_from_nodes( + &self, + endpoint: Endpoint, + hash_and_format: HashAndFormat, + nodes: Vec, + progress: AsyncChannelProgressSender, + ) -> Result { + let mut last_err = None; + let mut remaining_nodes = nodes.len(); + let mut nodes_iter = nodes.into_iter(); + 'outer: loop { + match iroh_blobs::get::db::get_to_db_in_steps( + self.store.clone(), + hash_and_format, + progress.clone(), + ) + .await? + { + GetState::Complete(stats) => return Ok(stats), + GetState::NeedsConn(needs_conn) => { + let (conn, node_id) = 'inner: loop { + match nodes_iter.next() { + None => break 'outer, + Some(node) => { + remaining_nodes -= 1; + let node_id = node.node_id; + if node_id == endpoint.node_id() { + debug!( + ?remaining_nodes, + "skip node {} (it is the node id of ourselves)", + node_id.fmt_short() + ); + continue 'inner; + } + match endpoint.connect(node, iroh_blobs::protocol::ALPN).await { + Ok(conn) => break 'inner (conn, node_id), + Err(err) => { + debug!( + ?remaining_nodes, + "failed to connect to {}: {err}", + node_id.fmt_short() + ); + continue 'inner; + } + } + } + } + }; + match needs_conn.proceed(conn).await { + Ok(stats) => return Ok(stats), + Err(err) => { + warn!( + ?remaining_nodes, + "failed to download from {}: {err}", + node_id.fmt_short() + ); + last_err = Some(err); + } + } + } + } + } + match last_err { + Some(err) => Err(err.into()), + None => Err(anyhow!("No nodes to download from provided")), + } + } +} + +impl ProtocolHandler for BlobsProtocol { + fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { + Box::pin(async move { + iroh_blobs::provider::handle_connection( + conn.await?, + self.store.clone(), + self.events.clone(), + self.rt.clone(), + ) + .await; + Ok(()) + }) + } + + fn shutdown(self: Arc) -> BoxedFuture<()> { + Box::pin(async move { + self.store.shutdown().await; + }) + } +} diff --git a/iroh/src/node/docs.rs b/iroh/src/node/protocol/docs.rs similarity index 91% rename from iroh/src/node/docs.rs rename to iroh/src/node/protocol/docs.rs index be6400c8ca..597e5ee864 100644 --- a/iroh/src/node/docs.rs +++ b/iroh/src/node/protocol/docs.rs @@ -11,9 +11,9 @@ use crate::node::{DocsStorage, ProtocolHandler}; /// Wrapper around [`Engine`] so that we can implement our RPC methods directly. #[derive(Debug, Clone)] -pub(crate) struct DocsEngine(Engine); +pub(crate) struct DocsProtocol(Engine); -impl DocsEngine { +impl DocsProtocol { pub async fn spawn( storage: DocsStorage, blobs_store: S, @@ -36,18 +36,18 @@ impl DocsEngine { default_author_storage, ) .await?; - Ok(Some(DocsEngine(engine))) + Ok(Some(DocsProtocol(engine))) } } -impl Deref for DocsEngine { +impl Deref for DocsProtocol { type Target = Engine; fn deref(&self) -> &Self::Target { &self.0 } } -impl ProtocolHandler for DocsEngine { +impl ProtocolHandler for DocsProtocol { fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { Box::pin(async move { self.handle_connection(conn).await }) } diff --git a/iroh/src/node/protocol/gossip.rs b/iroh/src/node/protocol/gossip.rs new file mode 100644 index 0000000000..980a9868bb --- /dev/null +++ b/iroh/src/node/protocol/gossip.rs @@ -0,0 +1,13 @@ +use std::sync::Arc; + +use anyhow::Result; +use futures_lite::future::Boxed as BoxedFuture; +use iroh_net::endpoint::Connecting; + +use super::ProtocolHandler; + +impl ProtocolHandler for iroh_gossip::net::Gossip { + fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { + Box::pin(async move { self.handle_connection(conn.await?).await }) + } +} diff --git a/iroh/src/node/rpc.rs b/iroh/src/node/rpc.rs index beae5bb0da..cb8b346b0a 100644 --- a/iroh/src/node/rpc.rs +++ b/iroh/src/node/rpc.rs @@ -43,7 +43,10 @@ use crate::{ tags::TagInfo, NodeStatus, }, - node::{docs::DocsEngine, protocol::BlobsProtocol, NodeInner}, + node::{ + protocol::{blobs::BlobsProtocol, docs::DocsProtocol}, + NodeInner, + }, rpc_protocol::{ authors, blobs, blobs::{ @@ -95,8 +98,8 @@ impl Handler { } impl Handler { - fn docs(&self) -> Option> { - self.protocols.get_typed::(DOCS_ALPN) + fn docs(&self) -> Option> { + self.protocols.get_typed::(DOCS_ALPN) } fn blobs(&self) -> Arc> { @@ -112,7 +115,7 @@ impl Handler { async fn with_docs(self, f: F) -> RpcResult where T: Send + 'static, - F: FnOnce(Arc) -> Fut, + F: FnOnce(Arc) -> Fut, Fut: std::future::Future>, { if let Some(docs) = self.docs() { @@ -125,7 +128,7 @@ impl Handler { fn with_docs_stream(self, f: F) -> impl Stream> where T: Send + 'static, - F: FnOnce(Arc) -> S, + F: FnOnce(Arc) -> S, S: Stream>, { if let Some(docs) = self.docs() { diff --git a/iroh/src/node/rpc/docs.rs b/iroh/src/node/rpc/docs.rs index a777aff452..ec2d04f645 100644 --- a/iroh/src/node/rpc/docs.rs +++ b/iroh/src/node/rpc/docs.rs @@ -1,4 +1,4 @@ -//! This module contains an impl block on [`DocsEngine`] with handlers for RPC requests +//! This module contains an impl block on [`DocsProtocol`] with handlers for RPC requests use anyhow::anyhow; use futures_lite::{Stream, StreamExt}; @@ -8,7 +8,7 @@ use iroh_docs::{Author, DocTicket, NamespaceSecret}; use crate::{ client::docs::ShareMode, - node::DocsEngine, + node::protocol::docs::DocsProtocol, rpc_protocol::{ authors::{ CreateRequest, CreateResponse, DeleteRequest, DeleteResponse, ExportRequest, @@ -35,7 +35,7 @@ use crate::{ const ITER_CHANNEL_CAP: usize = 64; #[allow(missing_docs)] -impl DocsEngine { +impl DocsProtocol { pub async fn author_create(&self, _req: CreateRequest) -> RpcResult { // TODO: pass rng let author = Author::new(&mut rand::rngs::OsRng {});