From 79071f4d871598fb2390cc9d07fe2b48da85c06b Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Tue, 11 Jun 2024 18:37:05 +0200 Subject: [PATCH 01/33] feat: custom protocols --- iroh/src/node.rs | 16 +++++++++-- iroh/src/node/builder.rs | 60 ++++++++++++++++++++++++++++++++++++--- iroh/src/node/protocol.rs | 19 +++++++++++++ 3 files changed, 89 insertions(+), 6 deletions(-) create mode 100644 iroh/src/node/protocol.rs diff --git a/iroh/src/node.rs b/iroh/src/node.rs index 3b9173c706..0b4f3d9a0c 100644 --- a/iroh/src/node.rs +++ b/iroh/src/node.rs @@ -5,8 +5,8 @@ //! To shut down the node, call [`Node::shutdown`]. use std::fmt::Debug; use std::net::SocketAddr; -use std::path::Path; use std::sync::Arc; +use std::{any::Any, path::Path}; use anyhow::{anyhow, Result}; use futures_lite::StreamExt; @@ -23,14 +23,16 @@ use tokio_util::sync::CancellationToken; use tokio_util::task::LocalPoolHandle; use tracing::debug; -use crate::client::RpcService; +use crate::{client::RpcService, node::builder::ProtocolMap}; mod builder; +mod protocol; mod rpc; mod rpc_status; pub use self::builder::{Builder, DiscoveryConfig, GcPolicy, StorageConfig}; pub use self::rpc_status::RpcStatus; +pub use protocol::Protocol; /// A server which implements the iroh node. /// @@ -47,6 +49,7 @@ pub struct Node { inner: Arc>, task: Arc>, client: crate::client::MemIroh, + protocols: ProtocolMap, } #[derive(derive_more::Debug)] @@ -150,6 +153,15 @@ impl Node { self.inner.endpoint.my_relay() } + /// Returns the protocol handler for a alpn. + pub fn get_protocol(&self, alpn: &[u8]) -> Option> { + let protocols = self.protocols.read().unwrap(); + let protocol: Arc = protocols.get(alpn)?.clone(); + let protocol_any: Arc = protocol.as_arc_any(); + let protocol_ref = Arc::downcast(protocol_any).ok()?; + Some(protocol_ref) + } + /// Aborts the node. /// /// This does not gracefully terminate currently: all connections are closed and diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index db935479f2..2e1f38ed25 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -1,8 +1,8 @@ use std::{ - collections::BTreeSet, + collections::{BTreeSet, HashMap}, net::{Ipv4Addr, SocketAddrV4}, path::{Path, PathBuf}, - sync::Arc, + sync::{Arc, RwLock}, time::Duration, }; @@ -28,11 +28,13 @@ use quic_rpc::{ RpcServer, ServiceEndpoint, }; use serde::{Deserialize, Serialize}; +use tokio::sync::oneshot; use tokio_util::{sync::CancellationToken, task::LocalPoolHandle}; use tracing::{debug, error, error_span, info, trace, warn, Instrument}; use crate::{ client::RPC_ALPN, + node::Protocol, rpc_protocol::RpcService, util::{fs::load_secret_key, path::IrohPaths}, }; @@ -54,6 +56,9 @@ const DEFAULT_GC_INTERVAL: Duration = Duration::from_secs(60 * 5); const MAX_CONNECTIONS: u32 = 1024; const MAX_STREAMS: u64 = 10; +pub(super) type ProtocolMap = Arc>>>; +type ProtocolBuilders = Vec<(&'static [u8], Box) -> Arc>)>; + /// Builder for the [`Node`]. /// /// You must supply a blob store and a document store. @@ -84,6 +89,7 @@ where dns_resolver: Option, node_discovery: DiscoveryConfig, docs_store: iroh_docs::store::fs::Store, + protocols: ProtocolBuilders, #[cfg(any(test, feature = "test-utils"))] insecure_skip_relay_cert_verify: bool, /// Callback to register when a gc loop is done @@ -133,6 +139,7 @@ impl Default for Builder { rpc_endpoint: Default::default(), gc_policy: GcPolicy::Disabled, docs_store: iroh_docs::store::Store::memory(), + protocols: Default::default(), node_discovery: Default::default(), #[cfg(any(test, feature = "test-utils"))] insecure_skip_relay_cert_verify: false, @@ -160,6 +167,7 @@ impl Builder { gc_policy: GcPolicy::Disabled, docs_store, node_discovery: Default::default(), + protocols: Default::default(), #[cfg(any(test, feature = "test-utils"))] insecure_skip_relay_cert_verify: false, gc_done_callback: None, @@ -223,6 +231,7 @@ where gc_policy: self.gc_policy, docs_store, node_discovery: self.node_discovery, + protocols: Default::default(), #[cfg(any(test, feature = "test-utils"))] insecure_skip_relay_cert_verify: false, gc_done_callback: self.gc_done_callback, @@ -244,6 +253,7 @@ where gc_policy: self.gc_policy, docs_store: self.docs_store, node_discovery: self.node_discovery, + protocols: Default::default(), #[cfg(any(test, feature = "test-utils"))] insecure_skip_relay_cert_verify: self.insecure_skip_relay_cert_verify, gc_done_callback: self.gc_done_callback, @@ -270,6 +280,7 @@ where gc_policy: self.gc_policy, docs_store: self.docs_store, node_discovery: self.node_discovery, + protocols: Default::default(), #[cfg(any(test, feature = "test-utils"))] insecure_skip_relay_cert_verify: self.insecure_skip_relay_cert_verify, gc_done_callback: self.gc_done_callback, @@ -343,6 +354,16 @@ where self } + /// Accept a custom protocol. + pub fn accept( + mut self, + alpn: &'static [u8], + protocol: impl FnOnce(Node) -> Arc + 'static, + ) -> Self { + self.protocols.push((alpn, Box::new(protocol))); + self + } + /// Register a callback for when GC is done. #[cfg(any(test, feature = "test-utils"))] pub fn register_gc_done_cb(mut self, cb: Box) -> Self { @@ -481,6 +502,8 @@ where let (internal_rpc, controller) = quic_rpc::transport::flume::connection(1); let client = crate::client::Iroh::new(quic_rpc::RpcClient::new(controller.clone())); + let protocols = Arc::new(RwLock::new(HashMap::new())); + let inner = Arc::new(NodeInner { db: self.blobs_store, endpoint: endpoint.clone(), @@ -492,7 +515,9 @@ where sync, downloader, }); + let (ready_tx, ready_rx) = oneshot::channel(); let task = { + let protocols = Arc::clone(&protocols); let gossip = gossip.clone(); let handler = rpc::Handler { inner: inner.clone(), @@ -501,8 +526,11 @@ where let ep = endpoint.clone(); tokio::task::spawn( async move { + // Wait until the protocol builders have run. + ready_rx.await.expect("cannot fail"); Self::run( ep, + protocols, handler, self.rpc_endpoint, internal_rpc, @@ -518,8 +546,17 @@ where inner, task: Arc::new(task), client, + protocols, }; + for (alpn, p) in self.protocols { + let protocol = p(node.clone()); + node.protocols.write().unwrap().insert(alpn, protocol); + } + + // Notify the run task that the protocols are now built. + ready_tx.send(()).expect("cannot fail"); + // spawn a task that updates the gossip endpoints. // TODO: track task let mut stream = endpoint.local_endpoints(); @@ -545,6 +582,7 @@ where #[allow(clippy::too_many_arguments)] async fn run( server: Endpoint, + protocols: ProtocolMap, handler: rpc::Handler, rpc: E, internal_rpc: impl ServiceEndpoint, @@ -615,8 +653,9 @@ where let gossip = gossip.clone(); let inner = handler.inner.clone(); let sync = handler.inner.sync.clone(); + let protocols = protocols.clone(); tokio::task::spawn(async move { - if let Err(err) = handle_connection(connecting, alpn, inner, gossip, sync).await { + if let Err(err) = handle_connection(connecting, alpn, inner, gossip, sync, protocols).await { warn!("Handling incoming connection ended with error: {err}"); } }); @@ -738,6 +777,7 @@ async fn handle_connection( node: Arc>, gossip: Gossip, sync: DocsEngine, + protocols: ProtocolMap, ) -> Result<()> { match alpn.as_bytes() { GOSSIP_ALPN => gossip.handle_connection(connecting.await?).await?, @@ -752,7 +792,19 @@ async fn handle_connection( ) .await } - _ => bail!("ignoring connection: unsupported ALPN protocol"), + alpn => { + let protocol = { + let protocols = protocols.read().unwrap(); + protocols.get(alpn).cloned() + }; + if let Some(protocol) = protocol { + drop(protocols); + let connection = connecting.await?; + protocol.accept(connection).await?; + } else { + bail!("ignoring connection: unsupported ALPN protocol"); + } + } } Ok(()) } diff --git a/iroh/src/node/protocol.rs b/iroh/src/node/protocol.rs new file mode 100644 index 0000000000..4dc7dbb29d --- /dev/null +++ b/iroh/src/node/protocol.rs @@ -0,0 +1,19 @@ +use std::{any::Any, fmt, future::Future, pin::Pin, sync::Arc}; + +use iroh_net::endpoint::Connection; + +/// Trait for iroh protocol handlers. +pub trait Protocol: Sync + Send + Any + fmt::Debug + 'static { + /// Return `self` as `dyn Any`. + /// + /// Implementations can simply return `self` here. + fn as_arc_any(self: Arc) -> Arc; + + /// Accept an incoming connection. + /// + /// This runs on a freshly spawned tokio task so this can be long-running. + fn accept( + &self, + conn: Connection, + ) -> Pin> + 'static + Send + Sync>>; +} From aa1d78a861cbda50d217fba0e7e57e4ca8c87b57 Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Tue, 11 Jun 2024 19:02:02 +0200 Subject: [PATCH 02/33] fixes and add example --- iroh/Cargo.toml | 4 + iroh/examples/custom-protocol.rs | 134 +++++++++++++++++++++++++++++++ iroh/src/node/builder.rs | 8 +- iroh/src/node/protocol.rs | 11 ++- 4 files changed, 150 insertions(+), 7 deletions(-) create mode 100644 iroh/examples/custom-protocol.rs diff --git a/iroh/Cargo.toml b/iroh/Cargo.toml index aac6f9a645..5130f336c2 100644 --- a/iroh/Cargo.toml +++ b/iroh/Cargo.toml @@ -101,3 +101,7 @@ required-features = ["examples"] [[example]] name = "client" required-features = ["examples"] + +[[example]] +name = "custom-protocol" +required-features = ["examples"] diff --git a/iroh/examples/custom-protocol.rs b/iroh/examples/custom-protocol.rs new file mode 100644 index 0000000000..e67896f76b --- /dev/null +++ b/iroh/examples/custom-protocol.rs @@ -0,0 +1,134 @@ +use std::{any::Any, fmt, sync::Arc}; + +use anyhow::Result; +use clap::Parser; +use futures_lite::future::Boxed; +use iroh::{ + blobs::store::Store, + net::{ + endpoint::{get_remote_node_id, Connection}, + NodeId, + }, + node::{Node, Protocol}, +}; +use tracing_subscriber::{prelude::*, EnvFilter}; + +const EXAMPLE_ALPN: &'static [u8] = b"example-proto/0"; + +#[derive(Debug, Parser)] +pub struct Cli { + #[clap(subcommand)] + command: Command, +} + +#[derive(Debug, Parser)] +pub enum Command { + Accept, + Connect { node: NodeId }, +} + +#[tokio::main] +async fn main() -> Result<()> { + setup_logging(); + let args = Cli::parse(); + // create a new node + let node = iroh::node::Node::memory() + .accept(EXAMPLE_ALPN, |node| ExampleProtocol::build(node)) + .spawn() + .await?; + + // print the ticket if this is the accepting side + match args.command { + Command::Accept => { + let node_id = node.node_id(); + println!("node id: {node_id}"); + // wait until ctrl-c + tokio::signal::ctrl_c().await?; + } + Command::Connect { node: node_id } => { + let proto = ExampleProtocol::from_node(&node, EXAMPLE_ALPN).expect("it is registered"); + proto.connect(node_id).await?; + } + } + + node.shutdown().await?; + + Ok(()) +} + +#[derive(Debug)] +struct ExampleProtocol { + node: Node, +} + +impl Protocol for ExampleProtocol { + fn as_arc_any(self: Arc) -> Arc { + self + } + + fn accept(self: Arc, conn: quinn::Connection) -> Boxed> { + Box::pin(async move { self.handle_connection(conn).await }) + } +} + +impl ExampleProtocol { + fn build(node: Node) -> Arc { + Arc::new(Self { node }) + } + + fn from_node(node: &Node, alpn: &'static [u8]) -> Option> { + node.get_protocol::>(alpn) + } + + async fn handle_connection(&self, conn: Connection) -> Result<()> { + let remote_node_id = get_remote_node_id(&conn)?; + println!("accepting new connection from {remote_node_id}"); + let mut send_stream = conn.open_uni().await?; + println!("stream open!"); + // not that this is something that you wanted to do, but let's create a new blob for each + // incoming connection. this could be any mechanism, but we want to demonstrate how to use a + // custom protocol together with built-in iroh functionality + let content = format!("this blob is created for my beloved peer {remote_node_id} ♥"); + let hash = self + .node + .blobs() + .add_bytes(content.as_bytes().to_vec()) + .await?; + // send the hash over our custom proto + send_stream.write_all(hash.hash.as_bytes()).await?; + send_stream.finish().await?; + Ok(()) + } + + pub async fn connect(&self, remote_node_id: NodeId) -> Result<()> { + println!("connecting to {remote_node_id}"); + let conn = self + .node + .endpoint() + .connect_by_node_id(&remote_node_id, EXAMPLE_ALPN) + .await?; + let mut recv_stream = conn.accept_uni().await?; + let hash_bytes = recv_stream.read_to_end(32).await?; + let hash = iroh::blobs::Hash::from_bytes(*(&hash_bytes.try_into().unwrap())); + println!("received hash: {hash}"); + self.node + .blobs() + .download(hash, remote_node_id.into()) + .await? + .await?; + println!("blob downloaded"); + let content = self.node.blobs().read_to_bytes(hash).await?; + let message = String::from_utf8(content.to_vec())?; + println!("blob content: {message}"); + Ok(()) + } +} + +// set the RUST_LOG env var to one of {debug,info,warn} to see logging info +pub fn setup_logging() { + tracing_subscriber::registry() + .with(tracing_subscriber::fmt::layer().with_writer(std::io::stderr)) + .with(EnvFilter::from_default_env()) + .try_init() + .ok(); +} diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index 2e1f38ed25..c0668c7843 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -422,10 +422,16 @@ where } }; + let alpns = PROTOCOLS + .iter() + .chain(self.protocols.iter().map(|(alpn, _)| alpn)) + .map(|p| p.to_vec()) + .collect(); + let endpoint = Endpoint::builder() .secret_key(self.secret_key.clone()) .proxy_from_env() - .alpns(PROTOCOLS.iter().map(|p| p.to_vec()).collect()) + .alpns(alpns) .keylog(self.keylog) .transport_config(transport_config) .concurrent_connections(MAX_CONNECTIONS) diff --git a/iroh/src/node/protocol.rs b/iroh/src/node/protocol.rs index 4dc7dbb29d..55046952e6 100644 --- a/iroh/src/node/protocol.rs +++ b/iroh/src/node/protocol.rs @@ -1,9 +1,11 @@ -use std::{any::Any, fmt, future::Future, pin::Pin, sync::Arc}; +use std::{any::Any, fmt, sync::Arc}; +use anyhow::Result; +use futures_lite::future::Boxed; use iroh_net::endpoint::Connection; /// Trait for iroh protocol handlers. -pub trait Protocol: Sync + Send + Any + fmt::Debug + 'static { +pub trait Protocol: Send + Sync + Any + fmt::Debug + 'static { /// Return `self` as `dyn Any`. /// /// Implementations can simply return `self` here. @@ -12,8 +14,5 @@ pub trait Protocol: Sync + Send + Any + fmt::Debug + 'static { /// Accept an incoming connection. /// /// This runs on a freshly spawned tokio task so this can be long-running. - fn accept( - &self, - conn: Connection, - ) -> Pin> + 'static + Send + Sync>>; + fn accept(self: Arc, conn: Connection) -> Boxed>; } From 62dda4812798ca14b2daf976de7e94c4a4c088a7 Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Tue, 11 Jun 2024 19:08:35 +0200 Subject: [PATCH 03/33] improve example --- iroh/examples/custom-protocol.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/iroh/examples/custom-protocol.rs b/iroh/examples/custom-protocol.rs index e67896f76b..ded75add1a 100644 --- a/iroh/examples/custom-protocol.rs +++ b/iroh/examples/custom-protocol.rs @@ -82,9 +82,8 @@ impl ExampleProtocol { async fn handle_connection(&self, conn: Connection) -> Result<()> { let remote_node_id = get_remote_node_id(&conn)?; - println!("accepting new connection from {remote_node_id}"); + println!("accepted connection from {remote_node_id}"); let mut send_stream = conn.open_uni().await?; - println!("stream open!"); // not that this is something that you wanted to do, but let's create a new blob for each // incoming connection. this could be any mechanism, but we want to demonstrate how to use a // custom protocol together with built-in iroh functionality @@ -97,10 +96,12 @@ impl ExampleProtocol { // send the hash over our custom proto send_stream.write_all(hash.hash.as_bytes()).await?; send_stream.finish().await?; + println!("closing connection from {remote_node_id}"); Ok(()) } pub async fn connect(&self, remote_node_id: NodeId) -> Result<()> { + println!("our node id: {}", self.node.node_id()); println!("connecting to {remote_node_id}"); let conn = self .node From 2b149866018b326885edde4d57a0fd6daf5774d8 Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Tue, 11 Jun 2024 19:12:43 +0200 Subject: [PATCH 04/33] improve example --- iroh/examples/custom-protocol.rs | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/iroh/examples/custom-protocol.rs b/iroh/examples/custom-protocol.rs index ded75add1a..2516eef8d7 100644 --- a/iroh/examples/custom-protocol.rs +++ b/iroh/examples/custom-protocol.rs @@ -13,8 +13,6 @@ use iroh::{ }; use tracing_subscriber::{prelude::*, EnvFilter}; -const EXAMPLE_ALPN: &'static [u8] = b"example-proto/0"; - #[derive(Debug, Parser)] pub struct Cli { #[clap(subcommand)] @@ -33,7 +31,7 @@ async fn main() -> Result<()> { let args = Cli::parse(); // create a new node let node = iroh::node::Node::memory() - .accept(EXAMPLE_ALPN, |node| ExampleProtocol::build(node)) + .accept(ExampleProto::ALPN, |node| ExampleProto::build(node)) .spawn() .await?; @@ -46,7 +44,7 @@ async fn main() -> Result<()> { tokio::signal::ctrl_c().await?; } Command::Connect { node: node_id } => { - let proto = ExampleProtocol::from_node(&node, EXAMPLE_ALPN).expect("it is registered"); + let proto = ExampleProto::get_from_node(&node, EXAMPLE_ALPN).expect("it is registered"); proto.connect(node_id).await?; } } @@ -57,11 +55,11 @@ async fn main() -> Result<()> { } #[derive(Debug)] -struct ExampleProtocol { +struct ExampleProto { node: Node, } -impl Protocol for ExampleProtocol { +impl Protocol for ExampleProto { fn as_arc_any(self: Arc) -> Arc { self } @@ -71,13 +69,15 @@ impl Protocol for ExampleProtocol { } } -impl ExampleProtocol { +impl ExampleProto { + const ALPN: &'static [u8] = b"example-proto/0"; + fn build(node: Node) -> Arc { Arc::new(Self { node }) } - fn from_node(node: &Node, alpn: &'static [u8]) -> Option> { - node.get_protocol::>(alpn) + fn get_from_node(node: &Node, alpn: &'static [u8]) -> Option> { + node.get_protocol::>(alpn) } async fn handle_connection(&self, conn: Connection) -> Result<()> { @@ -100,7 +100,7 @@ impl ExampleProtocol { Ok(()) } - pub async fn connect(&self, remote_node_id: NodeId) -> Result<()> { + async fn connect(&self, remote_node_id: NodeId) -> Result<()> { println!("our node id: {}", self.node.node_id()); println!("connecting to {remote_node_id}"); let conn = self From 9261259f0bfb50a02187c0c29b553b61c2b04491 Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Tue, 11 Jun 2024 19:17:12 +0200 Subject: [PATCH 05/33] make builder send again --- iroh/src/node/builder.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index c0668c7843..b7f5a9aded 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -57,7 +57,10 @@ const MAX_CONNECTIONS: u32 = 1024; const MAX_STREAMS: u64 = 10; pub(super) type ProtocolMap = Arc>>>; -type ProtocolBuilders = Vec<(&'static [u8], Box) -> Arc>)>; +type ProtocolBuilders = Vec<( + &'static [u8], + Box) -> Arc + Send + 'static>, +)>; /// Builder for the [`Node`]. /// @@ -358,7 +361,7 @@ where pub fn accept( mut self, alpn: &'static [u8], - protocol: impl FnOnce(Node) -> Arc + 'static, + protocol: impl FnOnce(Node) -> Arc + Send + 'static, ) -> Self { self.protocols.push((alpn, Box::new(protocol))); self From 85e40de75b113bad4db4d4780ed1add6e51e735f Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Tue, 11 Jun 2024 19:22:17 +0200 Subject: [PATCH 06/33] fix & clippy --- iroh/examples/custom-protocol.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/iroh/examples/custom-protocol.rs b/iroh/examples/custom-protocol.rs index 2516eef8d7..f898b75ed3 100644 --- a/iroh/examples/custom-protocol.rs +++ b/iroh/examples/custom-protocol.rs @@ -31,7 +31,7 @@ async fn main() -> Result<()> { let args = Cli::parse(); // create a new node let node = iroh::node::Node::memory() - .accept(ExampleProto::ALPN, |node| ExampleProto::build(node)) + .accept(EXAMPLE_ALPN, |node| ExampleProto::build(node)) .spawn() .await?; @@ -54,6 +54,8 @@ async fn main() -> Result<()> { Ok(()) } +const EXAMPLE_ALPN: &'static [u8] = b"example-proto/0"; + #[derive(Debug)] struct ExampleProto { node: Node, @@ -70,8 +72,6 @@ impl Protocol for ExampleProto { } impl ExampleProto { - const ALPN: &'static [u8] = b"example-proto/0"; - fn build(node: Node) -> Arc { Arc::new(Self { node }) } @@ -110,7 +110,7 @@ impl ExampleProto { .await?; let mut recv_stream = conn.accept_uni().await?; let hash_bytes = recv_stream.read_to_end(32).await?; - let hash = iroh::blobs::Hash::from_bytes(*(&hash_bytes.try_into().unwrap())); + let hash = iroh::blobs::Hash::from_bytes(hash_bytes.try_into().unwrap()); println!("received hash: {hash}"); self.node .blobs() From ee043e5df29f91e767e2ae9c4c8b4b5d80a04583 Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Wed, 12 Jun 2024 17:49:34 +0200 Subject: [PATCH 07/33] cleanups and PR review --- iroh/examples/custom-protocol.rs | 10 +++++----- iroh/src/node/builder.rs | 4 +--- iroh/src/node/protocol.rs | 8 ++++---- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/iroh/examples/custom-protocol.rs b/iroh/examples/custom-protocol.rs index f898b75ed3..5507e3b098 100644 --- a/iroh/examples/custom-protocol.rs +++ b/iroh/examples/custom-protocol.rs @@ -2,11 +2,11 @@ use std::{any::Any, fmt, sync::Arc}; use anyhow::Result; use clap::Parser; -use futures_lite::future::Boxed; +use futures_lite::future::Boxed as BoxedFuture; use iroh::{ blobs::store::Store, net::{ - endpoint::{get_remote_node_id, Connection}, + endpoint::{get_remote_node_id, Connecting, Connection}, NodeId, }, node::{Node, Protocol}, @@ -54,7 +54,7 @@ async fn main() -> Result<()> { Ok(()) } -const EXAMPLE_ALPN: &'static [u8] = b"example-proto/0"; +const EXAMPLE_ALPN: &[u8] = b"example-proto/0"; #[derive(Debug)] struct ExampleProto { @@ -66,8 +66,8 @@ impl Protocol for ExampleProto { self } - fn accept(self: Arc, conn: quinn::Connection) -> Boxed> { - Box::pin(async move { self.handle_connection(conn).await }) + fn handle_connection(self: Arc, conn: Connecting) -> BoxedFuture> { + Box::pin(async move { self.handle_connection(conn.await?).await }) } } diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index b7f5a9aded..941ff8915f 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -807,9 +807,7 @@ async fn handle_connection( protocols.get(alpn).cloned() }; if let Some(protocol) = protocol { - drop(protocols); - let connection = connecting.await?; - protocol.accept(connection).await?; + protocol.handle_connection(connecting).await?; } else { bail!("ignoring connection: unsupported ALPN protocol"); } diff --git a/iroh/src/node/protocol.rs b/iroh/src/node/protocol.rs index 55046952e6..dd9db9d84c 100644 --- a/iroh/src/node/protocol.rs +++ b/iroh/src/node/protocol.rs @@ -1,8 +1,8 @@ use std::{any::Any, fmt, sync::Arc}; use anyhow::Result; -use futures_lite::future::Boxed; -use iroh_net::endpoint::Connection; +use futures_lite::future::Boxed as BoxedFuture; +use iroh_net::endpoint::Connecting; /// Trait for iroh protocol handlers. pub trait Protocol: Send + Sync + Any + fmt::Debug + 'static { @@ -11,8 +11,8 @@ pub trait Protocol: Send + Sync + Any + fmt::Debug + 'static { /// Implementations can simply return `self` here. fn as_arc_any(self: Arc) -> Arc; - /// Accept an incoming connection. + /// Handle an incoming connection. /// /// This runs on a freshly spawned tokio task so this can be long-running. - fn accept(self: Arc, conn: Connection) -> Boxed>; + fn handle_connection(self: Arc, conn: Connecting) -> BoxedFuture>; } From db3513638382c0fab7961406516032ab59138f28 Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Wed, 12 Jun 2024 18:07:01 +0200 Subject: [PATCH 08/33] improve code structure --- Cargo.lock | 1 + iroh/Cargo.toml | 1 + iroh/examples/custom-protocol.rs | 4 --- iroh/src/node.rs | 16 +++++----- iroh/src/node/builder.rs | 48 ++++++++++++---------------- iroh/src/node/protocol.rs | 54 +++++++++++++++++++++++++++----- 6 files changed, 76 insertions(+), 48 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a63e49d931..ef3fa7ca76 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2437,6 +2437,7 @@ dependencies = [ "iroh-quinn", "iroh-test", "num_cpus", + "once_cell", "parking_lot", "portable-atomic", "postcard", diff --git a/iroh/Cargo.toml b/iroh/Cargo.toml index 5130f336c2..5462a5f2ff 100644 --- a/iroh/Cargo.toml +++ b/iroh/Cargo.toml @@ -35,6 +35,7 @@ num_cpus = { version = "1.15.0" } portable-atomic = "1" iroh-docs = { version = "0.18.0", path = "../iroh-docs" } iroh-gossip = { version = "0.18.0", path = "../iroh-gossip" } +once_cell = "1.18.0" parking_lot = "0.12.1" postcard = { version = "1", default-features = false, features = ["alloc", "use-std", "experimental-derive"] } quic-rpc = { version = "0.10.0", default-features = false, features = ["flume-transport", "quinn-transport"] } diff --git a/iroh/examples/custom-protocol.rs b/iroh/examples/custom-protocol.rs index 5507e3b098..be68c7e44a 100644 --- a/iroh/examples/custom-protocol.rs +++ b/iroh/examples/custom-protocol.rs @@ -62,10 +62,6 @@ struct ExampleProto { } impl Protocol for ExampleProto { - fn as_arc_any(self: Arc) -> Arc { - self - } - fn handle_connection(self: Arc, conn: Connecting) -> BoxedFuture> { Box::pin(async move { self.handle_connection(conn.await?).await }) } diff --git a/iroh/src/node.rs b/iroh/src/node.rs index 0b4f3d9a0c..36cf4705a9 100644 --- a/iroh/src/node.rs +++ b/iroh/src/node.rs @@ -5,8 +5,8 @@ //! To shut down the node, call [`Node::shutdown`]. use std::fmt::Debug; use std::net::SocketAddr; +use std::path::Path; use std::sync::Arc; -use std::{any::Any, path::Path}; use anyhow::{anyhow, Result}; use futures_lite::StreamExt; @@ -16,6 +16,7 @@ use iroh_blobs::store::Store as BaoStore; use iroh_docs::engine::Engine; use iroh_net::util::AbortingJoinHandle; use iroh_net::{endpoint::LocalEndpointsStream, key::SecretKey, Endpoint}; +use once_cell::sync::OnceCell; use quic_rpc::transport::flume::FlumeConnection; use quic_rpc::RpcClient; use tokio::task::JoinHandle; @@ -23,7 +24,7 @@ use tokio_util::sync::CancellationToken; use tokio_util::task::LocalPoolHandle; use tracing::debug; -use crate::{client::RpcService, node::builder::ProtocolMap}; +use crate::{client::RpcService, node::protocol::ProtocolMap}; mod builder; mod protocol; @@ -47,7 +48,7 @@ pub use protocol::Protocol; #[derive(Debug, Clone)] pub struct Node { inner: Arc>, - task: Arc>, + task: Arc>>, client: crate::client::MemIroh, protocols: ProtocolMap, } @@ -155,11 +156,7 @@ impl Node { /// Returns the protocol handler for a alpn. pub fn get_protocol(&self, alpn: &[u8]) -> Option> { - let protocols = self.protocols.read().unwrap(); - let protocol: Arc = protocols.get(alpn)?.clone(); - let protocol_any: Arc = protocol.as_arc_any(); - let protocol_ref = Arc::downcast(protocol_any).ok()?; - Some(protocol_ref) + self.protocols.get(alpn) } /// Aborts the node. @@ -173,7 +170,8 @@ impl Node { pub async fn shutdown(self) -> Result<()> { self.inner.cancel_token.cancel(); - if let Ok(task) = Arc::try_unwrap(self.task) { + if let Ok(mut task) = Arc::try_unwrap(self.task) { + let task = task.take().expect("cannot be empty"); task.await?; } Ok(()) diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index 941ff8915f..34724b64e1 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -1,8 +1,8 @@ use std::{ - collections::{BTreeSet, HashMap}, + collections::BTreeSet, net::{Ipv4Addr, SocketAddrV4}, path::{Path, PathBuf}, - sync::{Arc, RwLock}, + sync::Arc, time::Duration, }; @@ -28,13 +28,12 @@ use quic_rpc::{ RpcServer, ServiceEndpoint, }; use serde::{Deserialize, Serialize}; -use tokio::sync::oneshot; use tokio_util::{sync::CancellationToken, task::LocalPoolHandle}; use tracing::{debug, error, error_span, info, trace, warn, Instrument}; use crate::{ client::RPC_ALPN, - node::Protocol, + node::{protocol::ProtocolMap, Protocol}, rpc_protocol::RpcService, util::{fs::load_secret_key, path::IrohPaths}, }; @@ -56,7 +55,6 @@ const DEFAULT_GC_INTERVAL: Duration = Duration::from_secs(60 * 5); const MAX_CONNECTIONS: u32 = 1024; const MAX_STREAMS: u64 = 10; -pub(super) type ProtocolMap = Arc>>>; type ProtocolBuilders = Vec<( &'static [u8], Box) -> Arc + Send + 'static>, @@ -511,7 +509,7 @@ where let (internal_rpc, controller) = quic_rpc::transport::flume::connection(1); let client = crate::client::Iroh::new(quic_rpc::RpcClient::new(controller.clone())); - let protocols = Arc::new(RwLock::new(HashMap::new())); + let protocols = ProtocolMap::default(); let inner = Arc::new(NodeInner { db: self.blobs_store, @@ -524,9 +522,21 @@ where sync, downloader, }); - let (ready_tx, ready_rx) = oneshot::channel(); + + let node = Node { + inner: inner.clone(), + task: Default::default(), + client, + protocols: protocols.clone(), + }; + + for (alpn, p) in self.protocols { + let protocol = p(node.clone()); + protocols.insert(alpn, protocol); + } + let task = { - let protocols = Arc::clone(&protocols); + let protocols = protocols.clone(); let gossip = gossip.clone(); let handler = rpc::Handler { inner: inner.clone(), @@ -535,8 +545,6 @@ where let ep = endpoint.clone(); tokio::task::spawn( async move { - // Wait until the protocol builders have run. - ready_rx.await.expect("cannot fail"); Self::run( ep, protocols, @@ -551,20 +559,7 @@ where ) }; - let node = Node { - inner, - task: Arc::new(task), - client, - protocols, - }; - - for (alpn, p) in self.protocols { - let protocol = p(node.clone()); - node.protocols.write().unwrap().insert(alpn, protocol); - } - - // Notify the run task that the protocols are now built. - ready_tx.send(()).expect("cannot fail"); + node.task.set(task).expect("was empty"); // spawn a task that updates the gossip endpoints. // TODO: track task @@ -802,10 +797,7 @@ async fn handle_connection( .await } alpn => { - let protocol = { - let protocols = protocols.read().unwrap(); - protocols.get(alpn).cloned() - }; + let protocol = protocols.get_any(alpn); if let Some(protocol) = protocol { protocol.handle_connection(connecting).await?; } else { diff --git a/iroh/src/node/protocol.rs b/iroh/src/node/protocol.rs index dd9db9d84c..0bba70cc73 100644 --- a/iroh/src/node/protocol.rs +++ b/iroh/src/node/protocol.rs @@ -1,18 +1,58 @@ -use std::{any::Any, fmt, sync::Arc}; +use std::{ + any::Any, + collections::HashMap, + fmt, + sync::{Arc, RwLock}, +}; use anyhow::Result; use futures_lite::future::Boxed as BoxedFuture; use iroh_net::endpoint::Connecting; /// Trait for iroh protocol handlers. -pub trait Protocol: Send + Sync + Any + fmt::Debug + 'static { - /// Return `self` as `dyn Any`. - /// - /// Implementations can simply return `self` here. - fn as_arc_any(self: Arc) -> Arc; - +pub trait Protocol: Send + Sync + IntoArcAny + fmt::Debug + 'static { /// Handle an incoming connection. /// /// This runs on a freshly spawned tokio task so this can be long-running. fn handle_connection(self: Arc, conn: Connecting) -> BoxedFuture>; } + +/// Helper trait to facilite casting from `Arc` to `Arc`. +/// +/// This trait has a blanket implementation so there is no need to implement this yourself. +pub trait IntoArcAny { + fn into_arc_any(self: Arc) -> Arc; +} + +impl IntoArcAny for T { + fn into_arc_any(self: Arc) -> Arc { + self + } +} + +/// Map of registered protocol handlers. +#[allow(clippy::type_complexity)] +#[derive(Debug, Clone, Default)] +pub struct ProtocolMap(Arc>>>); + +impl ProtocolMap { + /// Returns the registered protocol handler for an ALPN as a concrete type. + pub fn get(&self, alpn: &[u8]) -> Option> { + let protocols = self.0.read().unwrap(); + let protocol: Arc = protocols.get(alpn)?.clone(); + let protocol_any: Arc = protocol.into_arc_any(); + let protocol_ref = Arc::downcast(protocol_any).ok()?; + Some(protocol_ref) + } + + /// Returns the registered protocol handler for an ALPN as a `dyn Protocol`. + pub fn get_any(&self, alpn: &[u8]) -> Option> { + let protocols = self.0.read().unwrap(); + let protocol: Arc = protocols.get(alpn)?.clone(); + Some(protocol) + } + + pub(super) fn insert(&self, alpn: &'static [u8], protocol: Arc) { + self.0.write().unwrap().insert(alpn, protocol); + } +} From c7517ce8f99edae25d77b40eab1e6808dd7ec992 Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Wed, 12 Jun 2024 18:13:29 +0200 Subject: [PATCH 09/33] fixup --- iroh/examples/custom-protocol.rs | 45 ++++++++++++++++---------------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/iroh/examples/custom-protocol.rs b/iroh/examples/custom-protocol.rs index be68c7e44a..4553718142 100644 --- a/iroh/examples/custom-protocol.rs +++ b/iroh/examples/custom-protocol.rs @@ -1,4 +1,4 @@ -use std::{any::Any, fmt, sync::Arc}; +use std::{fmt, sync::Arc}; use anyhow::Result; use clap::Parser; @@ -6,7 +6,7 @@ use futures_lite::future::Boxed as BoxedFuture; use iroh::{ blobs::store::Store, net::{ - endpoint::{get_remote_node_id, Connecting, Connection}, + endpoint::{get_remote_node_id, Connecting}, NodeId, }, node::{Node, Protocol}, @@ -63,7 +63,26 @@ struct ExampleProto { impl Protocol for ExampleProto { fn handle_connection(self: Arc, conn: Connecting) -> BoxedFuture> { - Box::pin(async move { self.handle_connection(conn.await?).await }) + Box::pin(async move { + let conn = conn.await?; + let remote_node_id = get_remote_node_id(&conn)?; + println!("accepted connection from {remote_node_id}"); + let mut send_stream = conn.open_uni().await?; + // not that this is something that you wanted to do, but let's create a new blob for each + // incoming connection. this could be any mechanism, but we want to demonstrate how to use a + // custom protocol together with built-in iroh functionality + let content = format!("this blob is created for my beloved peer {remote_node_id} ♥"); + let hash = self + .node + .blobs() + .add_bytes(content.as_bytes().to_vec()) + .await?; + // send the hash over our custom proto + send_stream.write_all(hash.hash.as_bytes()).await?; + send_stream.finish().await?; + println!("closing connection from {remote_node_id}"); + Ok(()) + }) } } @@ -76,26 +95,6 @@ impl ExampleProto { node.get_protocol::>(alpn) } - async fn handle_connection(&self, conn: Connection) -> Result<()> { - let remote_node_id = get_remote_node_id(&conn)?; - println!("accepted connection from {remote_node_id}"); - let mut send_stream = conn.open_uni().await?; - // not that this is something that you wanted to do, but let's create a new blob for each - // incoming connection. this could be any mechanism, but we want to demonstrate how to use a - // custom protocol together with built-in iroh functionality - let content = format!("this blob is created for my beloved peer {remote_node_id} ♥"); - let hash = self - .node - .blobs() - .add_bytes(content.as_bytes().to_vec()) - .await?; - // send the hash over our custom proto - send_stream.write_all(hash.hash.as_bytes()).await?; - send_stream.finish().await?; - println!("closing connection from {remote_node_id}"); - Ok(()) - } - async fn connect(&self, remote_node_id: NodeId) -> Result<()> { println!("our node id: {}", self.node.node_id()); println!("connecting to {remote_node_id}"); From 11a609f51ee2a4115f6f3286ca685c14365924ce Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Wed, 12 Jun 2024 18:32:16 +0200 Subject: [PATCH 10/33] rename back to accept --- iroh/examples/custom-protocol.rs | 2 +- iroh/src/node/builder.rs | 2 +- iroh/src/node/protocol.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/iroh/examples/custom-protocol.rs b/iroh/examples/custom-protocol.rs index 4553718142..c973b22063 100644 --- a/iroh/examples/custom-protocol.rs +++ b/iroh/examples/custom-protocol.rs @@ -62,7 +62,7 @@ struct ExampleProto { } impl Protocol for ExampleProto { - fn handle_connection(self: Arc, conn: Connecting) -> BoxedFuture> { + fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { Box::pin(async move { let conn = conn.await?; let remote_node_id = get_remote_node_id(&conn)?; diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index 34724b64e1..d23732a08c 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -799,7 +799,7 @@ async fn handle_connection( alpn => { let protocol = protocols.get_any(alpn); if let Some(protocol) = protocol { - protocol.handle_connection(connecting).await?; + protocol.accept(connecting).await?; } else { bail!("ignoring connection: unsupported ALPN protocol"); } diff --git a/iroh/src/node/protocol.rs b/iroh/src/node/protocol.rs index 0bba70cc73..139ebbda8a 100644 --- a/iroh/src/node/protocol.rs +++ b/iroh/src/node/protocol.rs @@ -14,7 +14,7 @@ pub trait Protocol: Send + Sync + IntoArcAny + fmt::Debug + 'static { /// Handle an incoming connection. /// /// This runs on a freshly spawned tokio task so this can be long-running. - fn handle_connection(self: Arc, conn: Connecting) -> BoxedFuture>; + fn accept(self: Arc, conn: Connecting) -> BoxedFuture>; } /// Helper trait to facilite casting from `Arc` to `Arc`. From a3690480b08e45b119b039118b9c4600aa152e38 Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Wed, 12 Jun 2024 16:43:01 +0200 Subject: [PATCH 11/33] refactor: use new protocols api and allow to disable docs --- iroh-blobs/src/store/traits.rs | 2 +- iroh-docs/src/engine.rs | 2 +- iroh-docs/src/engine/live.rs | 4 +- iroh-docs/src/net.rs | 3 +- iroh/Cargo.toml | 2 +- iroh/examples/custom-protocol.rs | 4 +- iroh/src/client/authors.rs | 2 +- iroh/src/node.rs | 23 ++- iroh/src/node/builder.rs | 323 +++++++++++++++++-------------- iroh/src/node/protocol.rs | 63 +++++- iroh/src/node/rpc.rs | 221 ++++++++++++++++----- iroh/src/node/rpc/docs.rs | 55 +++--- iroh/src/rpc_protocol.rs | 4 +- iroh/tests/gc.rs | 24 +-- iroh/tests/provide.rs | 6 +- 15 files changed, 478 insertions(+), 260 deletions(-) diff --git a/iroh-blobs/src/store/traits.rs b/iroh-blobs/src/store/traits.rs index e0ec3e6b39..2a91d1c0f3 100644 --- a/iroh-blobs/src/store/traits.rs +++ b/iroh-blobs/src/store/traits.rs @@ -295,7 +295,7 @@ pub trait ReadableStore: Map { } /// The mutable part of a Bao store. -pub trait Store: ReadableStore + MapMut { +pub trait Store: ReadableStore + MapMut + std::fmt::Debug { /// This trait method imports a file from a local path. /// /// `data` is the path to the file. diff --git a/iroh-docs/src/engine.rs b/iroh-docs/src/engine.rs index b5345b0bea..c0867b644d 100644 --- a/iroh-docs/src/engine.rs +++ b/iroh-docs/src/engine.rs @@ -197,7 +197,7 @@ impl Engine { /// Handle an incoming iroh-docs connection. pub async fn handle_connection( &self, - conn: iroh_net::endpoint::Connecting, + conn: iroh_net::endpoint::Connection, ) -> anyhow::Result<()> { self.to_live_actor .send(ToLiveActor::HandleConnection { conn }) diff --git a/iroh-docs/src/engine/live.rs b/iroh-docs/src/engine/live.rs index 366379f4a3..86c7cedaba 100644 --- a/iroh-docs/src/engine/live.rs +++ b/iroh-docs/src/engine/live.rs @@ -76,7 +76,7 @@ pub enum ToLiveActor { reply: sync::oneshot::Sender>, }, HandleConnection { - conn: iroh_net::endpoint::Connecting, + conn: iroh_net::endpoint::Connection, }, AcceptSyncRequest { namespace: NamespaceId, @@ -749,7 +749,7 @@ impl LiveActor { } #[instrument("accept", skip_all)] - pub async fn handle_connection(&mut self, conn: iroh_net::endpoint::Connecting) { + pub async fn handle_connection(&mut self, conn: iroh_net::endpoint::Connection) { let to_actor_tx = self.sync_actor_tx.clone(); let accept_request_cb = move |namespace, peer| { let to_actor_tx = to_actor_tx.clone(); diff --git a/iroh-docs/src/net.rs b/iroh-docs/src/net.rs index a3f90032e1..cc29d3ec59 100644 --- a/iroh-docs/src/net.rs +++ b/iroh-docs/src/net.rs @@ -106,7 +106,7 @@ pub enum AcceptOutcome { /// Handle an iroh-docs connection and sync all shared documents in the replica store. pub async fn handle_connection( sync: SyncHandle, - connecting: iroh_net::endpoint::Connecting, + connection: iroh_net::endpoint::Connection, accept_cb: F, ) -> Result where @@ -114,7 +114,6 @@ where Fut: Future, { let t_start = Instant::now(); - let connection = connecting.await.map_err(AcceptError::connect)?; let peer = get_remote_node_id(&connection).map_err(AcceptError::connect)?; let (mut send_stream, mut recv_stream) = connection .accept_bi() diff --git a/iroh/Cargo.toml b/iroh/Cargo.toml index 5462a5f2ff..a8b92488f7 100644 --- a/iroh/Cargo.toml +++ b/iroh/Cargo.toml @@ -32,10 +32,10 @@ iroh-io = { version = "0.6.0", features = ["stats"] } iroh-metrics = { version = "0.18.0", path = "../iroh-metrics", optional = true } iroh-net = { version = "0.18.0", path = "../iroh-net" } num_cpus = { version = "1.15.0" } +once_cell = "1.17.0" portable-atomic = "1" iroh-docs = { version = "0.18.0", path = "../iroh-docs" } iroh-gossip = { version = "0.18.0", path = "../iroh-gossip" } -once_cell = "1.18.0" parking_lot = "0.12.1" postcard = { version = "1", default-features = false, features = ["alloc", "use-std", "experimental-derive"] } quic-rpc = { version = "0.10.0", default-features = false, features = ["flume-transport", "quinn-transport"] } diff --git a/iroh/examples/custom-protocol.rs b/iroh/examples/custom-protocol.rs index c973b22063..d8a5a54ac0 100644 --- a/iroh/examples/custom-protocol.rs +++ b/iroh/examples/custom-protocol.rs @@ -31,7 +31,9 @@ async fn main() -> Result<()> { let args = Cli::parse(); // create a new node let node = iroh::node::Node::memory() - .accept(EXAMPLE_ALPN, |node| ExampleProto::build(node)) + .accept(EXAMPLE_ALPN, |node| { + Box::pin(async move { Ok(ExampleProto::build(node)) }) + }) .spawn() .await?; diff --git a/iroh/src/client/authors.rs b/iroh/src/client/authors.rs index e6bddbb494..7cdd44ce72 100644 --- a/iroh/src/client/authors.rs +++ b/iroh/src/client/authors.rs @@ -40,7 +40,7 @@ where /// /// The default author can be set with [`Self::set_default`]. pub async fn default(&self) -> Result { - let res = self.rpc.rpc(AuthorGetDefaultRequest).await?; + let res = self.rpc.rpc(AuthorGetDefaultRequest).await??; Ok(res.author_id) } diff --git a/iroh/src/node.rs b/iroh/src/node.rs index 36cf4705a9..0915915989 100644 --- a/iroh/src/node.rs +++ b/iroh/src/node.rs @@ -6,7 +6,7 @@ use std::fmt::Debug; use std::net::SocketAddr; use std::path::Path; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use anyhow::{anyhow, Result}; use futures_lite::StreamExt; @@ -14,12 +14,11 @@ use iroh_base::key::PublicKey; use iroh_blobs::downloader::Downloader; use iroh_blobs::store::Store as BaoStore; use iroh_docs::engine::Engine; -use iroh_net::util::AbortingJoinHandle; use iroh_net::{endpoint::LocalEndpointsStream, key::SecretKey, Endpoint}; use once_cell::sync::OnceCell; use quic_rpc::transport::flume::FlumeConnection; use quic_rpc::RpcClient; -use tokio::task::JoinHandle; +use tokio::task::{JoinHandle, JoinSet}; use tokio_util::sync::CancellationToken; use tokio_util::task::LocalPoolHandle; use tracing::debug; @@ -31,7 +30,7 @@ mod protocol; mod rpc; mod rpc_status; -pub use self::builder::{Builder, DiscoveryConfig, GcPolicy, StorageConfig}; +pub use self::builder::{Builder, DiscoveryConfig, DocsStorage, GcPolicy, StorageConfig}; pub use self::rpc_status::RpcStatus; pub use protocol::Protocol; @@ -60,12 +59,10 @@ struct NodeInner { secret_key: SecretKey, cancel_token: CancellationToken, controller: FlumeConnection, - #[allow(dead_code)] - gc_task: Option>, #[debug("rt")] rt: LocalPoolHandle, - pub(crate) sync: DocsEngine, downloader: Downloader, + tasks: Mutex>>, } /// In memory node. @@ -156,7 +153,11 @@ impl Node { /// Returns the protocol handler for a alpn. pub fn get_protocol(&self, alpn: &[u8]) -> Option> { - self.protocols.get(alpn) + self.protocols.get::

(alpn) + } + + fn downloader(&self) -> &Downloader { + &self.inner.downloader } /// Aborts the node. @@ -171,8 +172,10 @@ impl Node { self.inner.cancel_token.cancel(); if let Ok(mut task) = Arc::try_unwrap(self.task) { - let task = task.take().expect("cannot be empty"); - task.await?; + task.take().expect("cannot be empty").await?; + } + if let Some(mut tasks) = self.inner.tasks.lock().unwrap().take() { + tasks.abort_all(); } Ok(()) } diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index d23732a08c..aab60ad6d2 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -7,7 +7,7 @@ use std::{ }; use anyhow::{bail, Context, Result}; -use futures_lite::StreamExt; +use futures_lite::{future::Boxed, StreamExt}; use iroh_base::key::SecretKey; use iroh_blobs::{ downloader::Downloader, @@ -28,12 +28,16 @@ use quic_rpc::{ RpcServer, ServiceEndpoint, }; use serde::{Deserialize, Serialize}; +use tokio::task::JoinSet; use tokio_util::{sync::CancellationToken, task::LocalPoolHandle}; use tracing::{debug, error, error_span, info, trace, warn, Instrument}; use crate::{ client::RPC_ALPN, - node::{protocol::ProtocolMap, Protocol}, + node::{ + protocol::{BlobsProtocol, ProtocolMap}, + Protocol, + }, rpc_protocol::RpcService, util::{fs::load_secret_key, path::IrohPaths}, }; @@ -57,9 +61,18 @@ const MAX_STREAMS: u64 = 10; type ProtocolBuilders = Vec<( &'static [u8], - Box) -> Arc + Send + 'static>, + Box) -> Boxed>> + Send + 'static>, )>; +/// Storage backend for documents. +#[derive(Debug, Clone)] +pub enum DocsStorage { + /// In-memory storage. + Memory, + /// File-based persistent storage. + Persistent(PathBuf), +} + /// Builder for the [`Node`]. /// /// You must supply a blob store and a document store. @@ -89,7 +102,7 @@ where gc_policy: GcPolicy, dns_resolver: Option, node_discovery: DiscoveryConfig, - docs_store: iroh_docs::store::fs::Store, + docs_store: Option, protocols: ProtocolBuilders, #[cfg(any(test, feature = "test-utils"))] insecure_skip_relay_cert_verify: bool, @@ -139,7 +152,7 @@ impl Default for Builder { dns_resolver: None, rpc_endpoint: Default::default(), gc_policy: GcPolicy::Disabled, - docs_store: iroh_docs::store::Store::memory(), + docs_store: Some(DocsStorage::Memory), protocols: Default::default(), node_discovery: Default::default(), #[cfg(any(test, feature = "test-utils"))] @@ -153,7 +166,7 @@ impl Builder { /// Creates a new builder for [`Node`] using the given databases. pub fn with_db_and_store( blobs_store: D, - docs_store: iroh_docs::store::Store, + docs_store: DocsStorage, storage: StorageConfig, ) -> Self { Self { @@ -166,7 +179,7 @@ impl Builder { dns_resolver: None, rpc_endpoint: Default::default(), gc_policy: GcPolicy::Disabled, - docs_store, + docs_store: Some(docs_store), node_discovery: Default::default(), protocols: Default::default(), #[cfg(any(test, feature = "test-utils"))] @@ -193,8 +206,7 @@ where let blobs_store = iroh_blobs::store::fs::Store::load(&blob_dir) .await .with_context(|| format!("Failed to load iroh database from {}", blob_dir.display()))?; - let docs_store = - iroh_docs::store::fs::Store::persistent(IrohPaths::DocsDatabase.with_root(root))?; + let docs_store = DocsStorage::Persistent(IrohPaths::DocsDatabase.with_root(root)); let v0 = blobs_store .import_flat_store(iroh_blobs::store::fs::FlatStorePaths { @@ -230,7 +242,7 @@ where relay_mode: self.relay_mode, dns_resolver: self.dns_resolver, gc_policy: self.gc_policy, - docs_store, + docs_store: Some(docs_store), node_discovery: self.node_discovery, protocols: Default::default(), #[cfg(any(test, feature = "test-utils"))] @@ -296,6 +308,12 @@ where self } + /// Disables documents support on this node completely. + pub fn disable_docs(mut self) -> Self { + self.docs_store = None; + self + } + /// Sets the relay servers to assist in establishing connectivity. /// /// Relay servers are used to discover other nodes by `PublicKey` and also help @@ -359,7 +377,7 @@ where pub fn accept( mut self, alpn: &'static [u8], - protocol: impl FnOnce(Node) -> Arc + Send + 'static, + protocol: impl FnOnce(Node) -> Boxed>> + Send + 'static, ) -> Self { self.protocols.push((alpn, Box::new(protocol))); self @@ -387,10 +405,68 @@ where /// This will create the underlying network server and spawn a tokio task accepting /// connections. The returned [`Node`] can be used to control the task as well as /// get information about it. - pub async fn spawn(self) -> Result> { + pub async fn spawn(mut self) -> Result> { + // Register the core iroh protocols. + // Register blobs. + let lp = LocalPoolHandle::new(num_cpus::get()); + let blobs_proto = BlobsProtocol::new(self.blobs_store.clone(), lp.clone()); + self = self.accept(iroh_blobs::protocol::ALPN, move |_node| { + Box::pin(async move { + let blobs: Arc = Arc::new(blobs_proto); + Ok(blobs) + }) + }); + + // Register gossip. + self = self.accept(GOSSIP_ALPN, |node| { + Box::pin(async move { + let addr = node.endpoint().my_addr().await?; + let gossip = + Gossip::from_endpoint(node.endpoint().clone(), Default::default(), &addr.info); + let gossip: Arc = Arc::new(gossip); + Ok(gossip) + }) + }); + + if let Some(docs_store) = &self.docs_store { + // register the docs protocol. + let docs_store = match docs_store { + DocsStorage::Memory => iroh_docs::store::fs::Store::memory(), + DocsStorage::Persistent(path) => iroh_docs::store::fs::Store::persistent(path)?, + }; + // load or create the default author for documents + let default_author_storage = match self.storage { + StorageConfig::Persistent(ref root) => { + let path = IrohPaths::DefaultAuthor.with_root(root); + DefaultAuthorStorage::Persistent(path) + } + StorageConfig::Mem => DefaultAuthorStorage::Mem, + }; + let blobs_store = self.blobs_store.clone(); + self = self.accept(DOCS_ALPN, |node| { + Box::pin(async move { + let gossip = node + .get_protocol::(GOSSIP_ALPN) + .context("gossip not found")?; + let sync = Engine::spawn( + node.endpoint().clone(), + (*gossip).clone(), + docs_store, + blobs_store, + node.downloader().clone(), + default_author_storage, + ) + .await?; + let sync = DocsEngine(sync); + let sync: Arc = Arc::new(sync); + Ok(sync) + }) + }); + } + // We clone the blob store to shut it down in case the node fails to spawn. let blobs_store = self.blobs_store.clone(); - match self.spawn_inner().await { + match self.spawn_inner(lp).await { Ok(node) => Ok(node), Err(err) => { debug!("failed to spawn node, shutting down"); @@ -400,9 +476,8 @@ where } } - async fn spawn_inner(mut self) -> Result> { + async fn spawn_inner(mut self, lp: LocalPoolHandle) -> Result> { trace!("spawning node"); - let lp = LocalPoolHandle::new(num_cpus::get()); let mut transport_config = quinn::TransportConfig::default(); transport_config @@ -465,47 +540,12 @@ where debug!("rpc listening on: {:?}", self.rpc_endpoint.local_addr()); - let addr = endpoint.my_addr().await?; - - // initialize the gossip protocol - let gossip = Gossip::from_endpoint(endpoint.clone(), Default::default(), &addr.info); + let blobs_store = self.blobs_store.clone(); + let mut tasks = JoinSet::new(); // initialize the downloader let downloader = Downloader::new(self.blobs_store.clone(), endpoint.clone(), lp.clone()); - // load or create the default author for documents - let default_author_storage = match self.storage { - StorageConfig::Persistent(ref root) => { - let path = IrohPaths::DefaultAuthor.with_root(root); - DefaultAuthorStorage::Persistent(path) - } - StorageConfig::Mem => DefaultAuthorStorage::Mem, - }; - - // spawn the docs engine - let sync = Engine::spawn( - endpoint.clone(), - gossip.clone(), - self.docs_store, - self.blobs_store.clone(), - downloader.clone(), - default_author_storage, - ) - .await?; - let sync_db = sync.sync.clone(); - let sync = DocsEngine(sync); - - let gc_task = if let GcPolicy::Interval(gc_period) = self.gc_policy { - tracing::info!("Starting GC task with interval {:?}", gc_period); - let db = self.blobs_store.clone(); - let gc_done_callback = self.gc_done_callback.take(); - - let task = - lp.spawn_pinned(move || Self::gc_loop(db, sync_db, gc_period, gc_done_callback)); - Some(task.into()) - } else { - None - }; let (internal_rpc, controller) = quic_rpc::transport::flume::connection(1); let client = crate::client::Iroh::new(quic_rpc::RpcClient::new(controller.clone())); @@ -517,9 +557,8 @@ where secret_key: self.secret_key, controller, cancel_token, - gc_task, + tasks: Default::default(), rt: lp.clone(), - sync, downloader, }); @@ -531,47 +570,45 @@ where }; for (alpn, p) in self.protocols { - let protocol = p(node.clone()); + let protocol = p(node.clone()).await?; protocols.insert(alpn, protocol); } let task = { let protocols = protocols.clone(); - let gossip = gossip.clone(); - let handler = rpc::Handler { - inner: inner.clone(), - }; let me = endpoint.node_id().fmt_short(); - let ep = endpoint.clone(); + let inner = inner.clone(); tokio::task::spawn( - async move { - Self::run( - ep, - protocols, - handler, - self.rpc_endpoint, - internal_rpc, - gossip, - ) - .await - } - .instrument(error_span!("node", %me)), + async move { Self::run(inner, protocols, self.rpc_endpoint, internal_rpc).await } + .instrument(error_span!("node", %me)), ) }; - node.task.set(task).expect("was empty"); + if let GcPolicy::Interval(gc_period) = self.gc_policy { + tracing::info!("Starting GC task with interval {:?}", gc_period); + let db = blobs_store.clone(); + let gc_done_callback = self.gc_done_callback.take(); + let sync = protocols.get::(DOCS_ALPN); + + tasks.spawn_local(Self::gc_loop(db, sync, gc_period, gc_done_callback)); + } + // spawn a task that updates the gossip endpoints. - // TODO: track task let mut stream = endpoint.local_endpoints(); - tokio::task::spawn(async move { - while let Some(eps) = stream.next().await { - if let Err(err) = gossip.update_endpoints(&eps) { - warn!("Failed to update gossip endpoints: {err:?}"); + let gossip = protocols.get::(GOSSIP_ALPN); + if let Some(gossip) = gossip { + tasks.spawn(async move { + while let Some(eps) = stream.next().await { + if let Err(err) = gossip.update_endpoints(&eps) { + warn!("Failed to update gossip endpoints: {err:?}"); + } } - } - warn!("failed to retrieve local endpoints"); - }); + warn!("failed to retrieve local endpoints"); + }); + } + + *(node.inner.tasks.lock().unwrap()) = Some(tasks); // Wait for a single endpoint update, to make sure // we found some endpoints @@ -585,13 +622,17 @@ where #[allow(clippy::too_many_arguments)] async fn run( - server: Endpoint, + inner: Arc>, protocols: ProtocolMap, - handler: rpc::Handler, rpc: E, internal_rpc: impl ServiceEndpoint, - gossip: Gossip, ) { + let server = inner.endpoint.clone(); + let docs = protocols.get::(DOCS_ALPN); + let handler = rpc::Handler { + inner: inner.clone(), + docs, + }; let rpc = RpcServer::new(rpc); let internal_rpc = RpcServer::new(internal_rpc); let (ipv4, ipv6) = server.local_addr(); @@ -603,13 +644,16 @@ where let cancel_token = handler.inner.cancel_token.clone(); - // forward our initial endpoints to the gossip protocol - // it may happen the the first endpoint update callback is missed because the gossip cell - // is only initialized once the endpoint is fully bound - if let Some(local_endpoints) = server.local_endpoints().next().await { - debug!(me = ?server.node_id(), "gossip initial update: {local_endpoints:?}"); - gossip.update_endpoints(&local_endpoints).ok(); + if let Some(gossip) = protocols.get::(GOSSIP_ALPN) { + // forward our initial endpoints to the gossip protocol + // it may happen the the first endpoint update callback is missed because the gossip cell + // is only initialized once the endpoint is fully bound + if let Some(local_endpoints) = server.local_endpoints().next().await { + debug!(me = ?server.node_id(), "gossip initial update: {local_endpoints:?}"); + gossip.update_endpoints(&local_endpoints).ok(); + } } + loop { tokio::select! { biased; @@ -617,9 +661,23 @@ where // clean shutdown of the blobs db to close the write transaction handler.inner.db.shutdown().await; - if let Err(err) = handler.inner.sync.shutdown().await { - warn!("sync shutdown error: {:?}", err); + // We cannot hold the RwLockReadGuard over an await point, + // so we have to manually loop, clone each protocol, and drop the read guard + // before awaiting shutdown. + let mut i = 0; + loop { + let protocol = { + let protocols = protocols.read(); + if let Some(protocol) = protocols.values().nth(i) { + protocol.clone() + } else { + break; + } + }; + protocol.shutdown().await; + i += 1; } + break }, // handle rpc requests. This will do nothing if rpc is not configured, since @@ -654,12 +712,11 @@ where continue; } }; - let gossip = gossip.clone(); - let inner = handler.inner.clone(); - let sync = handler.inner.sync.clone(); let protocols = protocols.clone(); - tokio::task::spawn(async move { - if let Err(err) = handle_connection(connecting, alpn, inner, gossip, sync, protocols).await { + let mut tasks_guard = inner.tasks.lock().unwrap(); + let tasks = tasks_guard.as_mut().expect("only empty after shutdown"); + tasks.spawn(async move { + if let Err(err) = handle_connection(connecting, alpn, protocols).await { warn!("Handling incoming connection ended with error: {err}"); } }); @@ -681,7 +738,7 @@ where async fn gc_loop( db: D, - ds: iroh_docs::actor::SyncHandle, + ds: Option>, gc_period: Duration, done_cb: Option>, ) { @@ -698,22 +755,24 @@ where tokio::time::sleep(gc_period).await; tracing::debug!("Starting GC"); live.clear(); - let doc_hashes = match ds.content_hashes().await { - Ok(hashes) => hashes, - Err(err) => { - tracing::warn!("Error getting doc hashes: {}", err); - continue 'outer; - } - }; - for hash in doc_hashes { - match hash { - Ok(hash) => { - live.insert(hash); - } + if let Some(ds) = &ds { + let doc_hashes = match ds.sync.content_hashes().await { + Ok(hashes) => hashes, Err(err) => { - tracing::error!("Error getting doc hash: {}", err); + tracing::warn!("Error getting doc hashes: {}", err); continue 'outer; } + }; + for hash in doc_hashes { + match hash { + Ok(hash) => { + live.insert(hash); + } + Err(err) => { + tracing::error!("Error getting doc hash: {}", err); + continue 'outer; + } + } } } @@ -773,37 +832,16 @@ impl Default for GcPolicy { } } -// TODO: Restructure this code to not take all these arguments. -#[allow(clippy::too_many_arguments)] -async fn handle_connection( +async fn handle_connection( connecting: iroh_net::endpoint::Connecting, alpn: String, - node: Arc>, - gossip: Gossip, - sync: DocsEngine, protocols: ProtocolMap, ) -> Result<()> { - match alpn.as_bytes() { - GOSSIP_ALPN => gossip.handle_connection(connecting.await?).await?, - DOCS_ALPN => sync.handle_connection(connecting).await?, - alpn if alpn == iroh_blobs::protocol::ALPN => { - let connection = connecting.await?; - iroh_blobs::provider::handle_connection( - connection, - node.db.clone(), - MockEventSender, - node.rt.clone(), - ) - .await - } - alpn => { - let protocol = protocols.get_any(alpn); - if let Some(protocol) = protocol { - protocol.accept(connecting).await?; - } else { - bail!("ignoring connection: unsupported ALPN protocol"); - } - } + let protocol = protocols.get_any(alpn.as_bytes()).clone(); + if let Some(protocol) = protocol { + protocol.accept(connecting).await?; + } else { + bail!("ignoring connection: unsupported ALPN protocol"); } Ok(()) } @@ -855,12 +893,3 @@ fn make_rpc_endpoint( Ok((rpc_endpoint, actual_rpc_port)) } - -#[derive(Debug, Clone)] -struct MockEventSender; - -impl iroh_blobs::provider::EventSender for MockEventSender { - fn send(&self, _event: iroh_blobs::provider::Event) -> futures_lite::future::Boxed<()> { - Box::pin(std::future::ready(())) - } -} diff --git a/iroh/src/node/protocol.rs b/iroh/src/node/protocol.rs index 139ebbda8a..3a099620fe 100644 --- a/iroh/src/node/protocol.rs +++ b/iroh/src/node/protocol.rs @@ -9,12 +9,19 @@ use anyhow::Result; use futures_lite::future::Boxed as BoxedFuture; use iroh_net::endpoint::Connecting; +use crate::node::DocsEngine; + /// Trait for iroh protocol handlers. pub trait Protocol: Send + Sync + IntoArcAny + fmt::Debug + 'static { /// Handle an incoming connection. /// /// This runs on a freshly spawned tokio task so this can be long-running. fn accept(self: Arc, conn: Connecting) -> BoxedFuture>; + + /// Called when the node shuts down. + fn shutdown(self: Arc) -> BoxedFuture<()> { + Box::pin(async move {}) + } } /// Helper trait to facilite casting from `Arc` to `Arc`. @@ -30,8 +37,6 @@ impl IntoArcAny for T { } } -/// Map of registered protocol handlers. -#[allow(clippy::type_complexity)] #[derive(Debug, Clone, Default)] pub struct ProtocolMap(Arc>>>); @@ -55,4 +60,58 @@ impl ProtocolMap { pub(super) fn insert(&self, alpn: &'static [u8], protocol: Arc) { self.0.write().unwrap().insert(alpn, protocol); } + + pub(super) fn read( + &self, + ) -> std::sync::RwLockReadGuard>> { + self.0.read().unwrap() + } +} + +#[derive(Debug)] +pub(crate) struct BlobsProtocol { + rt: tokio_util::task::LocalPoolHandle, + store: S, +} + +impl BlobsProtocol { + pub fn new(store: S, rt: tokio_util::task::LocalPoolHandle) -> Self { + Self { rt, store } + } +} + +impl Protocol for BlobsProtocol { + fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { + Box::pin(async move { + iroh_blobs::provider::handle_connection( + conn.await?, + self.store.clone(), + MockEventSender, + self.rt.clone(), + ) + .await; + Ok(()) + }) + } +} + +#[derive(Debug, Clone)] +struct MockEventSender; + +impl iroh_blobs::provider::EventSender for MockEventSender { + fn send(&self, _event: iroh_blobs::provider::Event) -> futures_lite::future::Boxed<()> { + Box::pin(std::future::ready(())) + } +} + +impl Protocol for iroh_gossip::net::Gossip { + fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { + Box::pin(async move { self.handle_connection(conn.await?).await }) + } +} + +impl Protocol for DocsEngine { + fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { + Box::pin(async move { self.handle_connection(conn.await?).await }) + } } diff --git a/iroh/src/node/rpc.rs b/iroh/src/node/rpc.rs index 6382b50d6a..92dfade8fb 100644 --- a/iroh/src/node/rpc.rs +++ b/iroh/src/node/rpc.rs @@ -7,7 +7,7 @@ use anyhow::{anyhow, ensure, Result}; use futures_buffered::BufferedStreamExt; use futures_lite::{Stream, StreamExt}; use genawaiter::sync::{Co, Gen}; -use iroh_base::rpc::RpcResult; +use iroh_base::rpc::{RpcError, RpcResult}; use iroh_blobs::downloader::{DownloadRequest, Downloader}; use iroh_blobs::export::ExportProgress; use iroh_blobs::format::collection::Collection; @@ -32,21 +32,25 @@ use quic_rpc::{ use tokio_util::task::LocalPoolHandle; use tracing::{debug, info}; -use crate::client::blobs::{BlobInfo, DownloadMode, IncompleteBlobInfo, WrapOption}; -use crate::client::tags::TagInfo; -use crate::client::NodeStatus; use crate::rpc_protocol::{ BlobAddPathRequest, BlobAddPathResponse, BlobAddStreamRequest, BlobAddStreamResponse, BlobAddStreamUpdate, BlobConsistencyCheckRequest, BlobDeleteBlobRequest, BlobDownloadRequest, BlobDownloadResponse, BlobExportRequest, BlobExportResponse, BlobListIncompleteRequest, BlobListRequest, BlobReadAtRequest, BlobReadAtResponse, BlobValidateRequest, CreateCollectionRequest, CreateCollectionResponse, DeleteTagRequest, DocExportFileRequest, - DocExportFileResponse, DocImportFileRequest, DocImportFileResponse, DocSetHashRequest, - ListTagsRequest, NodeAddrRequest, NodeConnectionInfoRequest, NodeConnectionInfoResponse, - NodeConnectionsRequest, NodeConnectionsResponse, NodeIdRequest, NodeRelayRequest, - NodeShutdownRequest, NodeStatsRequest, NodeStatsResponse, NodeStatusRequest, NodeWatchRequest, - NodeWatchResponse, Request, RpcService, SetTagOption, + DocExportFileResponse, DocGetManyResponse, DocImportFileRequest, DocImportFileResponse, + DocListResponse, DocSetHashRequest, ListTagsRequest, NodeAddrRequest, + NodeConnectionInfoRequest, NodeConnectionInfoResponse, NodeConnectionsRequest, + NodeConnectionsResponse, NodeIdRequest, NodeRelayRequest, NodeShutdownRequest, + NodeStatsRequest, NodeStatsResponse, NodeStatusRequest, NodeWatchRequest, NodeWatchResponse, + Request, RpcService, SetTagOption, }; +use crate::{ + client::blobs::{BlobInfo, DownloadMode, IncompleteBlobInfo, WrapOption}, + node::DocsEngine, +}; +use crate::{client::tags::TagInfo, node::rpc::docs::ITER_CHANNEL_CAP}; +use crate::{client::NodeStatus, rpc_protocol::AuthorListResponse}; use super::NodeInner; @@ -61,6 +65,7 @@ const RPC_BLOB_GET_CHANNEL_CAP: usize = 2; #[derive(Debug, Clone)] pub(crate) struct Handler { pub(crate) inner: Arc>, + pub(crate) docs: Option>, } impl Handler { @@ -126,92 +131,164 @@ impl Handler { BlobAddStreamUpdate(_msg) => Err(RpcServerError::UnexpectedUpdateMessage), AuthorList(msg) => { chan.server_streaming(msg, handler, |handler, req| { - handler.inner.sync.author_list(req) + let (tx, rx) = flume::bounded(ITER_CHANNEL_CAP); + if let Some(docs) = handler.docs { + docs.author_list(req, tx); + } else { + tx.send(Err(anyhow!("docs are disabled"))) + .expect("has capacity"); + } + rx.into_stream().map(|r| { + r.map(|author_id| AuthorListResponse { author_id }) + .map_err(Into::into) + }) }) .await } AuthorCreate(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.author_create(req).await + if let Some(docs) = handler.docs { + docs.author_create(req).await + } else { + Err(docs_disabled()) + } }) .await } AuthorImport(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.author_import(req).await + if let Some(docs) = handler.docs { + docs.author_import(req).await + } else { + Err(docs_disabled()) + } }) .await } AuthorExport(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.author_export(req).await + if let Some(docs) = handler.docs { + docs.author_export(req).await + } else { + Err(docs_disabled()) + } }) .await } AuthorDelete(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.author_delete(req).await + if let Some(docs) = handler.docs { + docs.author_delete(req).await + } else { + Err(docs_disabled()) + } }) .await } AuthorGetDefault(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.author_default(req) + if let Some(docs) = handler.docs { + Ok(docs.author_default(req)) + } else { + Err(docs_disabled()) + } }) .await } AuthorSetDefault(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.author_set_default(req).await + if let Some(docs) = handler.docs { + docs.author_set_default(req).await + } else { + Err(docs_disabled()) + } }) .await } DocOpen(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_open(req).await + if let Some(docs) = handler.docs { + docs.doc_open(req).await + } else { + Err(docs_disabled()) + } }) .await } DocClose(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_close(req).await + if let Some(docs) = handler.docs { + docs.doc_close(req).await + } else { + Err(docs_disabled()) + } }) .await } DocStatus(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_status(req).await + if let Some(docs) = handler.docs { + docs.doc_status(req).await + } else { + Err(docs_disabled()) + } }) .await } DocList(msg) => { chan.server_streaming(msg, handler, |handler, req| { - handler.inner.sync.doc_list(req) + let (tx, rx) = flume::bounded(ITER_CHANNEL_CAP); + if let Some(docs) = handler.docs { + docs.doc_list(req, tx); + } else { + tx.send(Err(anyhow!("docs are disabled"))) + .expect("has capacity"); + } + rx.into_stream().map(|r| { + r.map(|(id, capability)| DocListResponse { id, capability }) + .map_err(Into::into) + }) }) .await } DocCreate(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_create(req).await + if let Some(docs) = handler.docs { + docs.doc_create(req).await + } else { + Err(docs_disabled()) + } }) .await } DocDrop(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_drop(req).await + if let Some(docs) = handler.docs { + docs.doc_drop(req).await + } else { + Err(docs_disabled()) + } }) .await } DocImport(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_import(req).await + if let Some(docs) = handler.docs { + docs.doc_import(req).await + } else { + Err(docs_disabled()) + } }) .await } DocSet(msg) => { let bao_store = handler.inner.db.clone(); chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_set(&bao_store, req).await + if let Some(docs) = handler.docs { + docs.doc_set(&bao_store, req).await + } else { + Err(docs_disabled()) + } }) .await } @@ -225,67 +302,117 @@ impl Handler { } DocDel(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_del(req).await + if let Some(docs) = handler.docs { + docs.doc_del(req).await + } else { + Err(docs_disabled()) + } }) .await } DocSetHash(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_set_hash(req).await + if let Some(docs) = handler.docs { + docs.doc_set_hash(req).await + } else { + Err(docs_disabled()) + } }) .await } DocGet(msg) => { chan.server_streaming(msg, handler, |handler, req| { - handler.inner.sync.doc_get_many(req) + let (tx, rx) = flume::bounded(ITER_CHANNEL_CAP); + if let Some(docs) = handler.docs { + docs.doc_get_many(req, tx); + } else { + tx.send(Err(anyhow!("docs are disabled"))) + .expect("has capacity"); + } + rx.into_stream().map(|r| { + r.map(|entry| DocGetManyResponse { entry }) + .map_err(Into::into) + }) }) .await } DocGetExact(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_get_exact(req).await + if let Some(docs) = handler.docs { + docs.doc_get_exact(req).await + } else { + Err(docs_disabled()) + } }) .await } DocStartSync(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_start_sync(req).await + if let Some(docs) = handler.docs { + docs.doc_start_sync(req).await + } else { + Err(docs_disabled()) + } }) .await } DocLeave(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_leave(req).await + if let Some(docs) = handler.docs { + docs.doc_leave(req).await + } else { + Err(docs_disabled()) + } }) .await } DocShare(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_share(req).await + if let Some(docs) = handler.docs { + docs.doc_share(req).await + } else { + Err(docs_disabled()) + } }) .await } DocSubscribe(msg) => { chan.try_server_streaming(msg, handler, |handler, req| async move { - handler.inner.sync.doc_subscribe(req).await + if let Some(docs) = handler.docs { + docs.doc_subscribe(req).await + } else { + Err(docs_disabled()) + } }) .await } DocSetDownloadPolicy(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_set_download_policy(req).await + if let Some(docs) = handler.docs { + docs.doc_set_download_policy(req).await + } else { + Err(docs_disabled()) + } }) .await } DocGetDownloadPolicy(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_get_download_policy(req).await + if let Some(docs) = handler.docs { + docs.doc_get_download_policy(req).await + } else { + Err(docs_disabled()) + } }) .await } DocGetSyncPeers(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_get_sync_peers(req).await + if let Some(docs) = handler.docs { + docs.doc_get_sync_peers(req).await + } else { + Err(docs_disabled()) + } }) .await } @@ -463,6 +590,7 @@ impl Handler { msg: DocImportFileRequest, progress: flume::Sender, ) -> anyhow::Result<()> { + let docs = self.docs.ok_or_else(|| anyhow!("docs are disabled"))?; use crate::client::docs::ImportProgress as DocImportProgress; use iroh_blobs::store::ImportMode; use std::collections::BTreeMap; @@ -515,16 +643,14 @@ impl Handler { let hash_and_format = temp_tag.inner(); let HashAndFormat { hash, .. } = *hash_and_format; - self.inner - .sync - .doc_set_hash(DocSetHashRequest { - doc_id, - author_id, - key: key.clone(), - hash, - size, - }) - .await?; + docs.doc_set_hash(DocSetHashRequest { + doc_id, + author_id, + key: key.clone(), + hash, + size, + }) + .await?; drop(temp_tag); progress.send(DocImportProgress::AllDone { key }).await?; Ok(()) @@ -549,6 +675,7 @@ impl Handler { msg: DocExportFileRequest, progress: flume::Sender, ) -> anyhow::Result<()> { + let _docs = self.docs.ok_or_else(|| anyhow!("docs are disabled"))?; let progress = FlumeProgressSender::new(progress); let DocExportFileRequest { entry, path, mode } = msg; let key = bytes::Bytes::from(entry.key().to_vec()); @@ -1118,3 +1245,7 @@ where res.map_err(Into::into) } + +fn docs_disabled() -> RpcError { + anyhow!("docs are disabled").into() +} diff --git a/iroh/src/node/rpc/docs.rs b/iroh/src/node/rpc/docs.rs index a0433a803e..00762945b4 100644 --- a/iroh/src/node/rpc/docs.rs +++ b/iroh/src/node/rpc/docs.rs @@ -3,7 +3,9 @@ use anyhow::anyhow; use futures_lite::Stream; use iroh_blobs::{store::Store as BaoStore, BlobFormat}; -use iroh_docs::{Author, DocTicket, NamespaceSecret}; +use iroh_docs::{ + Author, AuthorId, CapabilityKind, DocTicket, NamespaceId, NamespaceSecret, SignedEntry, +}; use tokio_stream::StreamExt; use crate::client::docs::ShareMode; @@ -11,21 +13,20 @@ use crate::node::DocsEngine; use crate::rpc_protocol::{ AuthorCreateRequest, AuthorCreateResponse, AuthorDeleteRequest, AuthorDeleteResponse, AuthorExportRequest, AuthorExportResponse, AuthorGetDefaultRequest, AuthorGetDefaultResponse, - AuthorImportRequest, AuthorImportResponse, AuthorListRequest, AuthorListResponse, - AuthorSetDefaultRequest, AuthorSetDefaultResponse, DocCloseRequest, DocCloseResponse, - DocCreateRequest, DocCreateResponse, DocDelRequest, DocDelResponse, DocDropRequest, - DocDropResponse, DocGetDownloadPolicyRequest, DocGetDownloadPolicyResponse, DocGetExactRequest, - DocGetExactResponse, DocGetManyRequest, DocGetManyResponse, DocGetSyncPeersRequest, - DocGetSyncPeersResponse, DocImportRequest, DocImportResponse, DocLeaveRequest, - DocLeaveResponse, DocListRequest, DocListResponse, DocOpenRequest, DocOpenResponse, - DocSetDownloadPolicyRequest, DocSetDownloadPolicyResponse, DocSetHashRequest, - DocSetHashResponse, DocSetRequest, DocSetResponse, DocShareRequest, DocShareResponse, - DocStartSyncRequest, DocStartSyncResponse, DocStatusRequest, DocStatusResponse, - DocSubscribeRequest, DocSubscribeResponse, RpcResult, + AuthorImportRequest, AuthorImportResponse, AuthorListRequest, AuthorSetDefaultRequest, + AuthorSetDefaultResponse, DocCloseRequest, DocCloseResponse, DocCreateRequest, + DocCreateResponse, DocDelRequest, DocDelResponse, DocDropRequest, DocDropResponse, + DocGetDownloadPolicyRequest, DocGetDownloadPolicyResponse, DocGetExactRequest, + DocGetExactResponse, DocGetManyRequest, DocGetSyncPeersRequest, DocGetSyncPeersResponse, + DocImportRequest, DocImportResponse, DocLeaveRequest, DocLeaveResponse, DocListRequest, + DocOpenRequest, DocOpenResponse, DocSetDownloadPolicyRequest, DocSetDownloadPolicyResponse, + DocSetHashRequest, DocSetHashResponse, DocSetRequest, DocSetResponse, DocShareRequest, + DocShareResponse, DocStartSyncRequest, DocStartSyncResponse, DocStatusRequest, + DocStatusResponse, DocSubscribeRequest, DocSubscribeResponse, RpcResult, }; /// Capacity for the flume channels to forward sync store iterators to async RPC streams. -const ITER_CHANNEL_CAP: usize = 64; +pub(super) const ITER_CHANNEL_CAP: usize = 64; #[allow(missing_docs)] impl DocsEngine { @@ -57,8 +58,8 @@ impl DocsEngine { pub fn author_list( &self, _req: AuthorListRequest, - ) -> impl Stream> { - let (tx, rx) = flume::bounded(ITER_CHANNEL_CAP); + tx: flume::Sender>, + ) { let sync = self.sync.clone(); // we need to spawn a task to send our request to the sync handle, because the method // itself must be sync. @@ -68,10 +69,6 @@ impl DocsEngine { tx2.send_async(Err(err)).await.ok(); } }); - rx.into_stream().map(|r| { - r.map(|author_id| AuthorListResponse { author_id }) - .map_err(Into::into) - }) } pub async fn author_import(&self, req: AuthorImportRequest) -> RpcResult { @@ -108,8 +105,12 @@ impl DocsEngine { Ok(DocDropResponse {}) } - pub fn doc_list(&self, _req: DocListRequest) -> impl Stream> { - let (tx, rx) = flume::bounded(ITER_CHANNEL_CAP); + pub fn doc_list( + &self, + _req: DocListRequest, + tx: flume::Sender>, + ) { + // let (tx, rx) = flume::bounded(ITER_CHANNEL_CAP); let sync = self.sync.clone(); // we need to spawn a task to send our request to the sync handle, because the method // itself must be sync. @@ -119,10 +120,6 @@ impl DocsEngine { tx2.send_async(Err(err)).await.ok(); } }); - rx.into_stream().map(|r| { - r.map(|(id, capability)| DocListResponse { id, capability }) - .map_err(Into::into) - }) } pub async fn doc_open(&self, req: DocOpenRequest) -> RpcResult { @@ -249,9 +246,9 @@ impl DocsEngine { pub fn doc_get_many( &self, req: DocGetManyRequest, - ) -> impl Stream> { + tx: flume::Sender>, + ) { let DocGetManyRequest { doc_id, query } = req; - let (tx, rx) = flume::bounded(ITER_CHANNEL_CAP); let sync = self.sync.clone(); // we need to spawn a task to send our request to the sync handle, because the method // itself must be sync. @@ -261,10 +258,6 @@ impl DocsEngine { tx2.send_async(Err(err)).await.ok(); } }); - rx.into_stream().map(|r| { - r.map(|entry| DocGetManyResponse { entry }) - .map_err(Into::into) - }) } pub async fn doc_get_exact(&self, req: DocGetExactRequest) -> RpcResult { diff --git a/iroh/src/rpc_protocol.rs b/iroh/src/rpc_protocol.rs index 8fe71e7d6a..8334590a11 100644 --- a/iroh/src/rpc_protocol.rs +++ b/iroh/src/rpc_protocol.rs @@ -439,7 +439,7 @@ pub struct AuthorCreateResponse { pub struct AuthorGetDefaultRequest; impl RpcMsg for AuthorGetDefaultRequest { - type Response = AuthorGetDefaultResponse; + type Response = RpcResult; } /// Response for [`AuthorGetDefaultRequest`] @@ -1153,7 +1153,7 @@ pub enum Response { AuthorList(RpcResult), AuthorCreate(RpcResult), - AuthorGetDefault(AuthorGetDefaultResponse), + AuthorGetDefault(RpcResult), AuthorSetDefault(RpcResult), AuthorImport(RpcResult), AuthorExport(RpcResult), diff --git a/iroh/tests/gc.rs b/iroh/tests/gc.rs index dcca0893b5..e032691df9 100644 --- a/iroh/tests/gc.rs +++ b/iroh/tests/gc.rs @@ -6,7 +6,7 @@ use std::{ use anyhow::Result; use bao_tree::{blake3, io::sync::Outboard, ChunkRanges}; use bytes::Bytes; -use iroh::node::{self, Node}; +use iroh::node::{self, DocsStorage, Node}; use rand::RngCore; use iroh_blobs::{ @@ -41,17 +41,19 @@ async fn wrap_in_node(bao_store: S, gc_period: Duration) -> (Node, flume:: where S: iroh_blobs::store::Store, { - let doc_store = iroh_docs::store::Store::memory(); let (gc_send, gc_recv) = flume::unbounded(); - let node = - node::Builder::with_db_and_store(bao_store, doc_store, iroh::node::StorageConfig::Mem) - .gc_policy(iroh::node::GcPolicy::Interval(gc_period)) - .register_gc_done_cb(Box::new(move || { - gc_send.send(()).ok(); - })) - .spawn() - .await - .unwrap(); + let node = node::Builder::with_db_and_store( + bao_store, + DocsStorage::Memory, + iroh::node::StorageConfig::Mem, + ) + .gc_policy(iroh::node::GcPolicy::Interval(gc_period)) + .register_gc_done_cb(Box::new(move || { + gc_send.send(()).ok(); + })) + .spawn() + .await + .unwrap(); (node, gc_recv) } diff --git a/iroh/tests/provide.rs b/iroh/tests/provide.rs index 13376273dd..7b9abf9648 100644 --- a/iroh/tests/provide.rs +++ b/iroh/tests/provide.rs @@ -8,7 +8,7 @@ use std::{ use anyhow::{Context, Result}; use bytes::Bytes; use futures_lite::FutureExt; -use iroh::node::Builder; +use iroh::node::{Builder, DocsStorage}; use iroh_base::node_addr::AddrInfoOptions; use iroh_net::{defaults::default_relay_map, key::SecretKey, NodeAddr, NodeId}; use quic_rpc::transport::misc::DummyServerEndpoint; @@ -40,8 +40,8 @@ async fn dial(secret_key: SecretKey, peer: NodeAddr) -> anyhow::Result(db: D) -> Builder { - let store = iroh_docs::store::Store::memory(); - iroh::node::Builder::with_db_and_store(db, store, iroh::node::StorageConfig::Mem).bind_port(0) + iroh::node::Builder::with_db_and_store(db, DocsStorage::Memory, iroh::node::StorageConfig::Mem) + .bind_port(0) } #[tokio::test] From 63fbcc007aeb68134ed7b83f886d3a6f0127441f Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Wed, 12 Jun 2024 18:41:20 +0200 Subject: [PATCH 12/33] improvements --- iroh/src/node.rs | 16 +++++++--------- iroh/src/node/builder.rs | 23 +++++++++++------------ iroh/src/node/protocol.rs | 1 + 3 files changed, 19 insertions(+), 21 deletions(-) diff --git a/iroh/src/node.rs b/iroh/src/node.rs index 0915915989..e384594531 100644 --- a/iroh/src/node.rs +++ b/iroh/src/node.rs @@ -47,9 +47,7 @@ pub use protocol::Protocol; #[derive(Debug, Clone)] pub struct Node { inner: Arc>, - task: Arc>>, client: crate::client::MemIroh, - protocols: ProtocolMap, } #[derive(derive_more::Debug)] @@ -62,7 +60,9 @@ struct NodeInner { #[debug("rt")] rt: LocalPoolHandle, downloader: Downloader, - tasks: Mutex>>, + task: OnceCell>, + protocols: ProtocolMap, + tasks: Mutex>, } /// In memory node. @@ -153,7 +153,7 @@ impl Node { /// Returns the protocol handler for a alpn. pub fn get_protocol(&self, alpn: &[u8]) -> Option> { - self.protocols.get::

(alpn) + self.inner.protocols.get::

(alpn) } fn downloader(&self) -> &Downloader { @@ -171,11 +171,9 @@ impl Node { pub async fn shutdown(self) -> Result<()> { self.inner.cancel_token.cancel(); - if let Ok(mut task) = Arc::try_unwrap(self.task) { - task.take().expect("cannot be empty").await?; - } - if let Some(mut tasks) = self.inner.tasks.lock().unwrap().take() { - tasks.abort_all(); + if let Ok(mut inner) = Arc::try_unwrap(self.inner) { + inner.task.take().expect("cannot be empty").await?; + inner.tasks.lock().unwrap().abort_all(); } Ok(()) } diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index aab60ad6d2..d238c52f78 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -28,7 +28,6 @@ use quic_rpc::{ RpcServer, ServiceEndpoint, }; use serde::{Deserialize, Serialize}; -use tokio::task::JoinSet; use tokio_util::{sync::CancellationToken, task::LocalPoolHandle}; use tracing::{debug, error, error_span, info, trace, warn, Instrument}; @@ -541,7 +540,6 @@ where debug!("rpc listening on: {:?}", self.rpc_endpoint.local_addr()); let blobs_store = self.blobs_store.clone(); - let mut tasks = JoinSet::new(); // initialize the downloader let downloader = Downloader::new(self.blobs_store.clone(), endpoint.clone(), lp.clone()); @@ -560,13 +558,13 @@ where tasks: Default::default(), rt: lp.clone(), downloader, + task: Default::default(), + protocols: protocols.clone(), }); let node = Node { inner: inner.clone(), - task: Default::default(), client, - protocols: protocols.clone(), }; for (alpn, p) in self.protocols { @@ -583,7 +581,7 @@ where .instrument(error_span!("node", %me)), ) }; - node.task.set(task).expect("was empty"); + node.inner.task.set(task).expect("was empty"); if let GcPolicy::Interval(gc_period) = self.gc_policy { tracing::info!("Starting GC task with interval {:?}", gc_period); @@ -591,14 +589,19 @@ where let gc_done_callback = self.gc_done_callback.take(); let sync = protocols.get::(DOCS_ALPN); - tasks.spawn_local(Self::gc_loop(db, sync, gc_period, gc_done_callback)); + node.inner.tasks.lock().unwrap().spawn_local(Self::gc_loop( + db, + sync, + gc_period, + gc_done_callback, + )); } // spawn a task that updates the gossip endpoints. let mut stream = endpoint.local_endpoints(); let gossip = protocols.get::(GOSSIP_ALPN); if let Some(gossip) = gossip { - tasks.spawn(async move { + node.inner.tasks.lock().unwrap().spawn(async move { while let Some(eps) = stream.next().await { if let Err(err) = gossip.update_endpoints(&eps) { warn!("Failed to update gossip endpoints: {err:?}"); @@ -608,8 +611,6 @@ where }); } - *(node.inner.tasks.lock().unwrap()) = Some(tasks); - // Wait for a single endpoint update, to make sure // we found some endpoints tokio::time::timeout(ENDPOINT_WAIT, endpoint.local_endpoints().next()) @@ -713,9 +714,7 @@ where } }; let protocols = protocols.clone(); - let mut tasks_guard = inner.tasks.lock().unwrap(); - let tasks = tasks_guard.as_mut().expect("only empty after shutdown"); - tasks.spawn(async move { + inner.tasks.lock().unwrap().spawn(async move { if let Err(err) = handle_connection(connecting, alpn, protocols).await { warn!("Handling incoming connection ended with error: {err}"); } diff --git a/iroh/src/node/protocol.rs b/iroh/src/node/protocol.rs index 3a099620fe..c6e5ce2a8b 100644 --- a/iroh/src/node/protocol.rs +++ b/iroh/src/node/protocol.rs @@ -38,6 +38,7 @@ impl IntoArcAny for T { } #[derive(Debug, Clone, Default)] +#[allow(clippy::type_complexity)] pub struct ProtocolMap(Arc>>>); impl ProtocolMap { From 13d3bf61aaaf2c1515e615018395e7671c6fa01d Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Wed, 12 Jun 2024 18:43:53 +0200 Subject: [PATCH 13/33] simplify --- iroh-docs/src/engine.rs | 2 +- iroh-docs/src/engine/live.rs | 4 ++-- iroh-docs/src/net.rs | 3 ++- iroh/examples/custom-protocol.rs | 6 +++--- iroh/src/node/protocol.rs | 2 +- 5 files changed, 9 insertions(+), 8 deletions(-) diff --git a/iroh-docs/src/engine.rs b/iroh-docs/src/engine.rs index c0867b644d..b5345b0bea 100644 --- a/iroh-docs/src/engine.rs +++ b/iroh-docs/src/engine.rs @@ -197,7 +197,7 @@ impl Engine { /// Handle an incoming iroh-docs connection. pub async fn handle_connection( &self, - conn: iroh_net::endpoint::Connection, + conn: iroh_net::endpoint::Connecting, ) -> anyhow::Result<()> { self.to_live_actor .send(ToLiveActor::HandleConnection { conn }) diff --git a/iroh-docs/src/engine/live.rs b/iroh-docs/src/engine/live.rs index 86c7cedaba..366379f4a3 100644 --- a/iroh-docs/src/engine/live.rs +++ b/iroh-docs/src/engine/live.rs @@ -76,7 +76,7 @@ pub enum ToLiveActor { reply: sync::oneshot::Sender>, }, HandleConnection { - conn: iroh_net::endpoint::Connection, + conn: iroh_net::endpoint::Connecting, }, AcceptSyncRequest { namespace: NamespaceId, @@ -749,7 +749,7 @@ impl LiveActor { } #[instrument("accept", skip_all)] - pub async fn handle_connection(&mut self, conn: iroh_net::endpoint::Connection) { + pub async fn handle_connection(&mut self, conn: iroh_net::endpoint::Connecting) { let to_actor_tx = self.sync_actor_tx.clone(); let accept_request_cb = move |namespace, peer| { let to_actor_tx = to_actor_tx.clone(); diff --git a/iroh-docs/src/net.rs b/iroh-docs/src/net.rs index cc29d3ec59..a3f90032e1 100644 --- a/iroh-docs/src/net.rs +++ b/iroh-docs/src/net.rs @@ -106,7 +106,7 @@ pub enum AcceptOutcome { /// Handle an iroh-docs connection and sync all shared documents in the replica store. pub async fn handle_connection( sync: SyncHandle, - connection: iroh_net::endpoint::Connection, + connecting: iroh_net::endpoint::Connecting, accept_cb: F, ) -> Result where @@ -114,6 +114,7 @@ where Fut: Future, { let t_start = Instant::now(); + let connection = connecting.await.map_err(AcceptError::connect)?; let peer = get_remote_node_id(&connection).map_err(AcceptError::connect)?; let (mut send_stream, mut recv_stream) = connection .accept_bi() diff --git a/iroh/examples/custom-protocol.rs b/iroh/examples/custom-protocol.rs index d8a5a54ac0..75a9650889 100644 --- a/iroh/examples/custom-protocol.rs +++ b/iroh/examples/custom-protocol.rs @@ -1,4 +1,4 @@ -use std::{fmt, sync::Arc}; +use std::sync::Arc; use anyhow::Result; use clap::Parser; @@ -63,7 +63,7 @@ struct ExampleProto { node: Node, } -impl Protocol for ExampleProto { +impl Protocol for ExampleProto { fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { Box::pin(async move { let conn = conn.await?; @@ -88,7 +88,7 @@ impl Protocol for ExampleProto { } } -impl ExampleProto { +impl ExampleProto { fn build(node: Node) -> Arc { Arc::new(Self { node }) } diff --git a/iroh/src/node/protocol.rs b/iroh/src/node/protocol.rs index c6e5ce2a8b..3d46bb7584 100644 --- a/iroh/src/node/protocol.rs +++ b/iroh/src/node/protocol.rs @@ -113,6 +113,6 @@ impl Protocol for iroh_gossip::net::Gossip { impl Protocol for DocsEngine { fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { - Box::pin(async move { self.handle_connection(conn.await?).await }) + Box::pin(async move { self.handle_connection(conn).await }) } } From f56bb8ca86c7ba969ec1e5ac6e333950867fa77b Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Wed, 12 Jun 2024 20:19:32 +0200 Subject: [PATCH 14/33] fixup --- iroh/examples/custom-protocol.rs | 2 +- iroh/src/node/builder.rs | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/iroh/examples/custom-protocol.rs b/iroh/examples/custom-protocol.rs index 75a9650889..6d9be32bae 100644 --- a/iroh/examples/custom-protocol.rs +++ b/iroh/examples/custom-protocol.rs @@ -89,7 +89,7 @@ impl Protocol for ExampleProto { } impl ExampleProto { - fn build(node: Node) -> Arc { + fn build(node: Node) -> Arc { Arc::new(Self { node }) } diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index d238c52f78..7b18a2f678 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -373,12 +373,18 @@ where } /// Accept a custom protocol. + /// + /// Use this to register custom protocols onto the iroh node. Whenever a new connection for + /// `alpn` comes in, it is passed to this protocol handler. + /// + /// `protocol_builder` is a closure that returns a future which must resolve to a + /// `Arc`. pub fn accept( mut self, alpn: &'static [u8], - protocol: impl FnOnce(Node) -> Boxed>> + Send + 'static, + protocol_builder: impl FnOnce(Node) -> Boxed>> + Send + 'static, ) -> Self { - self.protocols.push((alpn, Box::new(protocol))); + self.protocols.push((alpn, Box::new(protocol_builder))); self } From 97ddd642b16cd457653c2e23aba0941f87714bbb Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Wed, 12 Jun 2024 20:32:59 +0200 Subject: [PATCH 15/33] fixup --- iroh/src/node.rs | 19 +++++++++++-------- iroh/src/node/builder.rs | 34 +++++++++++++++++++++++++++++++--- iroh/src/node/protocol.rs | 2 +- 3 files changed, 43 insertions(+), 12 deletions(-) diff --git a/iroh/src/node.rs b/iroh/src/node.rs index e384594531..1fa2aa1b7c 100644 --- a/iroh/src/node.rs +++ b/iroh/src/node.rs @@ -14,11 +14,12 @@ use iroh_base::key::PublicKey; use iroh_blobs::downloader::Downloader; use iroh_blobs::store::Store as BaoStore; use iroh_docs::engine::Engine; -use iroh_net::{endpoint::LocalEndpointsStream, key::SecretKey, Endpoint}; -use once_cell::sync::OnceCell; +use iroh_net::{ + endpoint::LocalEndpointsStream, key::SecretKey, util::SharedAbortingJoinHandle, Endpoint, +}; use quic_rpc::transport::flume::FlumeConnection; use quic_rpc::RpcClient; -use tokio::task::{JoinHandle, JoinSet}; +use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; use tokio_util::task::LocalPoolHandle; use tracing::debug; @@ -60,7 +61,7 @@ struct NodeInner { #[debug("rt")] rt: LocalPoolHandle, downloader: Downloader, - task: OnceCell>, + task: Mutex>>, protocols: ProtocolMap, tasks: Mutex>, } @@ -164,17 +165,19 @@ impl Node { /// /// This does not gracefully terminate currently: all connections are closed and /// anything in-transit is lost. The task will stop running. - /// If this is the last copy of the `Node`, this will finish once the task is + /// If this is the first call to this method, this will finish once the task is /// fully shutdown. /// /// The shutdown behaviour will become more graceful in the future. pub async fn shutdown(self) -> Result<()> { self.inner.cancel_token.cancel(); - if let Ok(mut inner) = Arc::try_unwrap(self.inner) { - inner.task.take().expect("cannot be empty").await?; - inner.tasks.lock().unwrap().abort_all(); + let task = self.inner.task.lock().unwrap().take(); + if let Some(task) = task { + task.await.map_err(|err| anyhow!(err))?; + self.inner.tasks.lock().unwrap().abort_all(); } + Ok(()) } diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index 7b18a2f678..ad611e57f8 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -377,8 +377,35 @@ where /// Use this to register custom protocols onto the iroh node. Whenever a new connection for /// `alpn` comes in, it is passed to this protocol handler. /// - /// `protocol_builder` is a closure that returns a future which must resolve to a - /// `Arc`. + /// The `protocol_builder` argument is a closure that returns a future which must resolve + /// to a protocol handler. The latter is a struct that implements [`Protocol`]. Note that the + /// closure must return `Arc`. Sometimes the Rust compiler will not be able to do + /// the cast automatically, so usually you will have to cast manually: + /// + /// ```rust + /// # use anyhow::Result; + /// # use futures_lite::future::Boxed as BoxedFuture; + /// + /// const MY_ALPN: &[u8] = "my-protocol/1"; + /// + /// #[derive(Debug)] + /// struct MyProtocol; + /// + /// impl Protocol for MyProtocol { + /// fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { + /// todo!() + /// } + /// } + /// + /// let node = Node::memory().accept(MY_ALPN |_node| Box::pin(async move { + /// let protocol = MyProtocol; + /// let protocol: Arc = Arc::new(protocol); + /// Ok(protocol) + /// })) + /// + /// ``` + /// + /// pub fn accept( mut self, alpn: &'static [u8], @@ -587,7 +614,8 @@ where .instrument(error_span!("node", %me)), ) }; - node.inner.task.set(task).expect("was empty"); + + *(node.inner.task.lock().unwrap()) = Some(task.into()); if let GcPolicy::Interval(gc_period) = self.gc_policy { tracing::info!("Starting GC task with interval {:?}", gc_period); diff --git a/iroh/src/node/protocol.rs b/iroh/src/node/protocol.rs index 3d46bb7584..2ec044d309 100644 --- a/iroh/src/node/protocol.rs +++ b/iroh/src/node/protocol.rs @@ -11,7 +11,7 @@ use iroh_net::endpoint::Connecting; use crate::node::DocsEngine; -/// Trait for iroh protocol handlers. +/// Handler for incoming connections. pub trait Protocol: Send + Sync + IntoArcAny + fmt::Debug + 'static { /// Handle an incoming connection. /// From f97826e80f16babc694bb595e6be1d8fff1bfb98 Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Thu, 13 Jun 2024 11:53:41 +0200 Subject: [PATCH 16/33] use JoinSet --- iroh/src/node.rs | 3 -- iroh/src/node/builder.rs | 72 ++++++++++++++++++++++++---------------- 2 files changed, 43 insertions(+), 32 deletions(-) diff --git a/iroh/src/node.rs b/iroh/src/node.rs index 1fa2aa1b7c..c561464c82 100644 --- a/iroh/src/node.rs +++ b/iroh/src/node.rs @@ -19,7 +19,6 @@ use iroh_net::{ }; use quic_rpc::transport::flume::FlumeConnection; use quic_rpc::RpcClient; -use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; use tokio_util::task::LocalPoolHandle; use tracing::debug; @@ -63,7 +62,6 @@ struct NodeInner { downloader: Downloader, task: Mutex>>, protocols: ProtocolMap, - tasks: Mutex>, } /// In memory node. @@ -175,7 +173,6 @@ impl Node { let task = self.inner.task.lock().unwrap().take(); if let Some(task) = task { task.await.map_err(|err| anyhow!(err))?; - self.inner.tasks.lock().unwrap().abort_all(); } Ok(()) diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index ad611e57f8..4182e78fae 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -28,6 +28,7 @@ use quic_rpc::{ RpcServer, ServiceEndpoint, }; use serde::{Deserialize, Serialize}; +use tokio::task::JoinSet; use tokio_util::{sync::CancellationToken, task::LocalPoolHandle}; use tracing::{debug, error, error_span, info, trace, warn, Instrument}; @@ -588,7 +589,6 @@ where secret_key: self.secret_key, controller, cancel_token, - tasks: Default::default(), rt: lp.clone(), downloader, task: Default::default(), @@ -605,46 +605,52 @@ where protocols.insert(alpn, protocol); } - let task = { - let protocols = protocols.clone(); - let me = endpoint.node_id().fmt_short(); - let inner = inner.clone(); - tokio::task::spawn( - async move { Self::run(inner, protocols, self.rpc_endpoint, internal_rpc).await } - .instrument(error_span!("node", %me)), - ) - }; - - *(node.inner.task.lock().unwrap()) = Some(task.into()); + let mut join_set = JoinSet::new(); + // spawn a task that for the garbage collection. if let GcPolicy::Interval(gc_period) = self.gc_policy { tracing::info!("Starting GC task with interval {:?}", gc_period); let db = blobs_store.clone(); let gc_done_callback = self.gc_done_callback.take(); let sync = protocols.get::(DOCS_ALPN); - - node.inner.tasks.lock().unwrap().spawn_local(Self::gc_loop( - db, - sync, - gc_period, - gc_done_callback, - )); + let handle = + lp.spawn_pinned(move || Self::gc_loop(db, sync, gc_period, gc_done_callback)); + // we cannot spawn tasks that run on the local pool directly into the join set, + // so instead we create a new task that supervises the local task. + join_set.spawn(async move { + if let Err(err) = handle.await { + return Err(anyhow::Error::from(err)); + } + Ok(()) + }); } // spawn a task that updates the gossip endpoints. - let mut stream = endpoint.local_endpoints(); let gossip = protocols.get::(GOSSIP_ALPN); if let Some(gossip) = gossip { - node.inner.tasks.lock().unwrap().spawn(async move { + let mut stream = endpoint.local_endpoints(); + join_set.spawn(async move { while let Some(eps) = stream.next().await { if let Err(err) = gossip.update_endpoints(&eps) { warn!("Failed to update gossip endpoints: {err:?}"); } } warn!("failed to retrieve local endpoints"); + Ok(()) }); } + let task = { + let me = endpoint.node_id().fmt_short(); + let inner = inner.clone(); + tokio::task::spawn( + async move { Self::run(inner, self.rpc_endpoint, internal_rpc, join_set).await } + .instrument(error_span!("node", %me)), + ) + }; + + *(node.inner.task.lock().unwrap()) = Some(task.into()); + // Wait for a single endpoint update, to make sure // we found some endpoints tokio::time::timeout(ENDPOINT_WAIT, endpoint.local_endpoints().next()) @@ -655,15 +661,14 @@ where Ok(node) } - #[allow(clippy::too_many_arguments)] async fn run( inner: Arc>, - protocols: ProtocolMap, rpc: E, internal_rpc: impl ServiceEndpoint, + mut join_set: JoinSet>, ) { let server = inner.endpoint.clone(); - let docs = protocols.get::(DOCS_ALPN); + let docs = inner.protocols.get::(DOCS_ALPN); let handler = rpc::Handler { inner: inner.clone(), docs, @@ -679,7 +684,7 @@ where let cancel_token = handler.inner.cancel_token.clone(); - if let Some(gossip) = protocols.get::(GOSSIP_ALPN) { + if let Some(gossip) = inner.protocols.get::(GOSSIP_ALPN) { // forward our initial endpoints to the gossip protocol // it may happen the the first endpoint update callback is missed because the gossip cell // is only initialized once the endpoint is fully bound @@ -694,7 +699,7 @@ where biased; _ = cancel_token.cancelled() => { // clean shutdown of the blobs db to close the write transaction - handler.inner.db.shutdown().await; + inner.db.shutdown().await; // We cannot hold the RwLockReadGuard over an await point, // so we have to manually loop, clone each protocol, and drop the read guard @@ -702,7 +707,7 @@ where let mut i = 0; loop { let protocol = { - let protocols = protocols.read(); + let protocols = inner.protocols.read(); if let Some(protocol) = protocols.values().nth(i) { protocol.clone() } else { @@ -747,17 +752,26 @@ where continue; } }; - let protocols = protocols.clone(); - inner.tasks.lock().unwrap().spawn(async move { + let protocols = inner.protocols.clone(); + join_set.spawn(async move { if let Err(err) = handle_connection(connecting, alpn, protocols).await { warn!("Handling incoming connection ended with error: {err}"); } + Ok(()) }); }, + res = join_set.join_next(), if !join_set.is_empty() => { + if let Some(Err(err)) = res { + error!("Task failed: {err:?}"); + break; + } + }, else => break, } } + join_set.shutdown().await; + // Closing the Endpoint is the equivalent of calling Connection::close on all // connections: Operations will immediately fail with // ConnectionError::LocallyClosed. All streams are interrupted, this is not From 1f108065cbfb7b08e02b5b6c0aa9aca5b68067ea Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Thu, 13 Jun 2024 11:56:14 +0200 Subject: [PATCH 17/33] improve shutdown --- iroh/src/node/builder.rs | 41 ++++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index 4182e78fae..af1044583d 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -698,26 +698,6 @@ where tokio::select! { biased; _ = cancel_token.cancelled() => { - // clean shutdown of the blobs db to close the write transaction - inner.db.shutdown().await; - - // We cannot hold the RwLockReadGuard over an await point, - // so we have to manually loop, clone each protocol, and drop the read guard - // before awaiting shutdown. - let mut i = 0; - loop { - let protocol = { - let protocols = inner.protocols.read(); - if let Some(protocol) = protocols.values().nth(i) { - protocol.clone() - } else { - break; - } - }; - protocol.shutdown().await; - i += 1; - } - break }, // handle rpc requests. This will do nothing if rpc is not configured, since @@ -770,6 +750,27 @@ where } } + // clean shutdown of the blobs db to close the write transaction + inner.db.shutdown().await; + + // We cannot hold the RwLockReadGuard over an await point, + // so we have to manually loop, clone each protocol, and drop the read guard + // before awaiting shutdown. + let mut i = 0; + loop { + let protocol = { + let protocols = inner.protocols.read(); + if let Some(protocol) = protocols.values().nth(i) { + protocol.clone() + } else { + break; + } + }; + protocol.shutdown().await; + i += 1; + } + + // force shutdown remaining tasks. join_set.shutdown().await; // Closing the Endpoint is the equivalent of calling Connection::close on all From 41bb9621c602b32ef9ed64543540eb057781aec8 Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Thu, 13 Jun 2024 16:41:42 +0200 Subject: [PATCH 18/33] fix shutdown --- iroh-blobs/src/store/fs.rs | 2 + iroh-docs/src/engine.rs | 6 ++- iroh-docs/src/engine/live.rs | 34 ++++++++++++---- iroh/src/node.rs | 5 +++ iroh/src/node/builder.rs | 79 +++++++++++++----------------------- iroh/src/node/protocol.rs | 35 ++++++++++++++++ 6 files changed, 101 insertions(+), 60 deletions(-) diff --git a/iroh-blobs/src/store/fs.rs b/iroh-blobs/src/store/fs.rs index 5febe54457..e9e113a603 100644 --- a/iroh-blobs/src/store/fs.rs +++ b/iroh-blobs/src/store/fs.rs @@ -1486,6 +1486,8 @@ impl Actor { let mut msgs = PeekableFlumeReceiver::new(self.state.msgs.clone()); while let Some(msg) = msgs.recv() { if let ActorMessage::Shutdown { tx } = msg { + // Make sure the database is dropped before we send the reply. + drop(self); if let Some(tx) = tx { tx.send(()).ok(); } diff --git a/iroh-docs/src/engine.rs b/iroh-docs/src/engine.rs index b5345b0bea..73bb215595 100644 --- a/iroh-docs/src/engine.rs +++ b/iroh-docs/src/engine.rs @@ -207,7 +207,11 @@ impl Engine { /// Shutdown the engine. pub async fn shutdown(&self) -> Result<()> { - self.to_live_actor.send(ToLiveActor::Shutdown).await?; + let (reply, reply_rx) = oneshot::channel(); + self.to_live_actor + .send(ToLiveActor::Shutdown { reply }) + .await?; + reply_rx.await?; Ok(()) } } diff --git a/iroh-docs/src/engine/live.rs b/iroh-docs/src/engine/live.rs index 366379f4a3..136b59fdac 100644 --- a/iroh-docs/src/engine/live.rs +++ b/iroh-docs/src/engine/live.rs @@ -67,7 +67,9 @@ pub enum ToLiveActor { #[debug("onsehot::Sender")] reply: sync::oneshot::Sender>, }, - Shutdown, + Shutdown { + reply: sync::oneshot::Sender<()>, + }, Subscribe { namespace: NamespaceId, #[debug("sender")] @@ -224,10 +226,18 @@ impl LiveActor { error!(?err, "Error during shutdown"); } gossip_handle.await?; - res + match res { + Ok(reply) => { + // If the shutdown is triggered from call to the shutdown method, + // trigger the reply to signal completion of the shutdown. + reply.send(()).ok(); + Ok(()) + } + Err(err) => Err(err), + } } - async fn run_inner(&mut self) -> Result<()> { + async fn run_inner(&mut self) -> Result> { let mut i = 0; loop { i += 1; @@ -237,8 +247,15 @@ impl LiveActor { msg = self.inbox.recv() => { let msg = msg.context("to_actor closed")?; trace!(?i, %msg, "tick: to_actor"); - if !self.on_actor_message(msg).await.context("on_actor_message")? { - break; + match msg { + ToLiveActor::Shutdown { reply } => { + // Return the oneshot reply to the upper-level run to send after + // shutdown is complete. + break Ok(reply); + } + msg => { + self.on_actor_message(msg).await.context("on_actor_message")?; + } } } event = self.replica_events_rx.recv_async() => { @@ -267,14 +284,13 @@ impl LiveActor { } } } - debug!("close (shutdown)"); - Ok(()) } async fn on_actor_message(&mut self, msg: ToLiveActor) -> anyhow::Result { match msg { - ToLiveActor::Shutdown => { - return Ok(false); + ToLiveActor::Shutdown { .. } => { + unreachable!("handled in run"); + // return Ok(false); } ToLiveActor::IncomingSyncReport { from, report } => { self.on_sync_report(from, report).await diff --git a/iroh/src/node.rs b/iroh/src/node.rs index c561464c82..00688fdf87 100644 --- a/iroh/src/node.rs +++ b/iroh/src/node.rs @@ -168,13 +168,18 @@ impl Node { /// /// The shutdown behaviour will become more graceful in the future. pub async fn shutdown(self) -> Result<()> { + // Trigger shutdown of the main run task by activating the cancel token. self.inner.cancel_token.cancel(); + // Wait for the main run task to terminate. let task = self.inner.task.lock().unwrap().take(); if let Some(task) = task { task.await.map_err(|err| anyhow!(err))?; } + // Give protocol handlers a chance to shutdown. + self.inner.protocols.shutdown().await; + Ok(()) } diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index af1044583d..310241d63e 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -205,7 +205,9 @@ where tokio::fs::create_dir_all(&blob_dir).await?; let blobs_store = iroh_blobs::store::fs::Store::load(&blob_dir) .await - .with_context(|| format!("Failed to load iroh database from {}", blob_dir.display()))?; + .with_context(|| { + format!("Failed to load blobs database from {}", blob_dir.display()) + })?; let docs_store = DocsStorage::Persistent(IrohPaths::DocsDatabase.with_root(root)); let v0 = blobs_store @@ -497,21 +499,6 @@ where }); } - // We clone the blob store to shut it down in case the node fails to spawn. - let blobs_store = self.blobs_store.clone(); - match self.spawn_inner(lp).await { - Ok(node) => Ok(node), - Err(err) => { - debug!("failed to spawn node, shutting down"); - blobs_store.shutdown().await; - Err(err) - } - } - } - - async fn spawn_inner(mut self, lp: LocalPoolHandle) -> Result> { - trace!("spawning node"); - let mut transport_config = quinn::TransportConfig::default(); transport_config .max_concurrent_bidi_streams(MAX_STREAMS.try_into()?) @@ -600,14 +587,22 @@ where client, }; + // Build the protocol handlers for the registered protocols. for (alpn, p) in self.protocols { - let protocol = p(node.clone()).await?; - protocols.insert(alpn, protocol); + let protocol = p(node.clone()).await; + match protocol { + Ok(protocol) => protocols.insert(alpn, protocol), + Err(err) => { + // Shutdown the protocols that were already built before returning the error. + protocols.shutdown().await; + return Err(err); + } + } } let mut join_set = JoinSet::new(); - // spawn a task that for the garbage collection. + // Spawn a task that for the garbage collection. if let GcPolicy::Interval(gc_period) = self.gc_policy { tracing::info!("Starting GC task with interval {:?}", gc_period); let db = blobs_store.clone(); @@ -615,7 +610,7 @@ where let sync = protocols.get::(DOCS_ALPN); let handle = lp.spawn_pinned(move || Self::gc_loop(db, sync, gc_period, gc_done_callback)); - // we cannot spawn tasks that run on the local pool directly into the join set, + // We cannot spawn tasks that run on the local pool directly into the join set, // so instead we create a new task that supervises the local task. join_set.spawn(async move { if let Err(err) = handle.await { @@ -625,7 +620,7 @@ where }); } - // spawn a task that updates the gossip endpoints. + // Spawn a task that updates the gossip endpoints. let gossip = protocols.get::(GOSSIP_ALPN); if let Some(gossip) = gossip { let mut stream = endpoint.local_endpoints(); @@ -667,7 +662,8 @@ where internal_rpc: impl ServiceEndpoint, mut join_set: JoinSet>, ) { - let server = inner.endpoint.clone(); + let endpoint = inner.endpoint.clone(); + let docs = inner.protocols.get::(DOCS_ALPN); let handler = rpc::Handler { inner: inner.clone(), @@ -675,7 +671,7 @@ where }; let rpc = RpcServer::new(rpc); let internal_rpc = RpcServer::new(internal_rpc); - let (ipv4, ipv6) = server.local_addr(); + let (ipv4, ipv6) = endpoint.local_addr(); debug!( "listening at: {}{}", ipv4, @@ -688,8 +684,8 @@ where // forward our initial endpoints to the gossip protocol // it may happen the the first endpoint update callback is missed because the gossip cell // is only initialized once the endpoint is fully bound - if let Some(local_endpoints) = server.local_endpoints().next().await { - debug!(me = ?server.node_id(), "gossip initial update: {local_endpoints:?}"); + if let Some(local_endpoints) = endpoint.local_endpoints().next().await { + debug!(me = ?endpoint.node_id(), "gossip initial update: {local_endpoints:?}"); gossip.update_endpoints(&local_endpoints).ok(); } } @@ -724,7 +720,7 @@ where } }, // handle incoming p2p connections - Some(mut connecting) = server.accept() => { + Some(mut connecting) = endpoint.accept() => { let alpn = match connecting.alpn().await { Ok(alpn) => alpn, Err(err) => { @@ -750,38 +746,21 @@ where } } - // clean shutdown of the blobs db to close the write transaction - inner.db.shutdown().await; - - // We cannot hold the RwLockReadGuard over an await point, - // so we have to manually loop, clone each protocol, and drop the read guard - // before awaiting shutdown. - let mut i = 0; - loop { - let protocol = { - let protocols = inner.protocols.read(); - if let Some(protocol) = protocols.values().nth(i) { - protocol.clone() - } else { - break; - } - }; - protocol.shutdown().await; - i += 1; - } - - // force shutdown remaining tasks. - join_set.shutdown().await; - // Closing the Endpoint is the equivalent of calling Connection::close on all // connections: Operations will immediately fail with // ConnectionError::LocallyClosed. All streams are interrupted, this is not // graceful. let error_code = Closed::ProviderTerminating; - server + endpoint .close(error_code.into(), error_code.reason()) .await .ok(); + + // Abort remaining tasks. + join_set.shutdown().await; + + // Shutdown of the DocsEngine and blobs store is handled in Node::shutdown through + // ProtocolMap::shutdown. } async fn gc_loop( diff --git a/iroh/src/node/protocol.rs b/iroh/src/node/protocol.rs index 2ec044d309..1206b0c6e3 100644 --- a/iroh/src/node/protocol.rs +++ b/iroh/src/node/protocol.rs @@ -2,12 +2,14 @@ use std::{ any::Any, collections::HashMap, fmt, + ops::Deref, sync::{Arc, RwLock}, }; use anyhow::Result; use futures_lite::future::Boxed as BoxedFuture; use iroh_net::endpoint::Connecting; +use tracing::warn; use crate::node::DocsEngine; @@ -67,6 +69,26 @@ impl ProtocolMap { ) -> std::sync::RwLockReadGuard>> { self.0.read().unwrap() } + + /// Shutdown the protocol handlers. + pub(super) async fn shutdown(&self) { + // We cannot hold the RwLockReadGuard over an await point, + // so we have to manually loop, clone each protocol, and drop the read guard + // before awaiting shutdown. + let mut i = 0; + loop { + let protocol = { + let protocols = self.read(); + if let Some(protocol) = protocols.values().nth(i) { + protocol.clone() + } else { + break; + } + }; + protocol.shutdown().await; + i += 1; + } + } } #[derive(Debug)] @@ -94,6 +116,12 @@ impl Protocol for BlobsProtocol { Ok(()) }) } + + fn shutdown(self: Arc) -> BoxedFuture<()> { + Box::pin(async move { + self.store.shutdown().await; + }) + } } #[derive(Debug, Clone)] @@ -115,4 +143,11 @@ impl Protocol for DocsEngine { fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { Box::pin(async move { self.handle_connection(conn).await }) } + fn shutdown(self: Arc) -> BoxedFuture<()> { + Box::pin(async move { + if let Err(err) = self.deref().shutdown().await { + warn!("Error while shutting down docs engine: {err:?}"); + } + }) + } } From 5703d6eea840b5437ea4b0b8badf2e24509071ae Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Fri, 14 Jun 2024 00:07:07 +0200 Subject: [PATCH 19/33] fix doctest --- iroh/examples/custom-protocol.rs | 32 ++++++++++++++-------------- iroh/src/node/builder.rs | 36 ++++++++++++++++++++++---------- 2 files changed, 41 insertions(+), 27 deletions(-) diff --git a/iroh/examples/custom-protocol.rs b/iroh/examples/custom-protocol.rs index 6d9be32bae..4ed0de72ac 100644 --- a/iroh/examples/custom-protocol.rs +++ b/iroh/examples/custom-protocol.rs @@ -64,40 +64,40 @@ struct ExampleProto { } impl Protocol for ExampleProto { - fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { + fn accept(self: Arc, connecting: Connecting) -> BoxedFuture> { Box::pin(async move { - let conn = conn.await?; - let remote_node_id = get_remote_node_id(&conn)?; - println!("accepted connection from {remote_node_id}"); - let mut send_stream = conn.open_uni().await?; - // not that this is something that you wanted to do, but let's create a new blob for each - // incoming connection. this could be any mechanism, but we want to demonstrate how to use a - // custom protocol together with built-in iroh functionality - let content = format!("this blob is created for my beloved peer {remote_node_id} ♥"); + let connection = connecting.await?; + let peer = get_remote_node_id(&connection)?; + println!("accepted connection from {peer}"); + let mut send_stream = connection.open_uni().await?; + // Let's create a new blob for each incoming connection. + // This functions as an example of using existing iroh functionality within a protocol + // (you likely don't want to create a new blob for each connection for real) + let content = format!("this blob is created for my beloved peer {peer} ♥"); let hash = self .node .blobs() .add_bytes(content.as_bytes().to_vec()) .await?; - // send the hash over our custom proto + // Send the hash over our custom protocol. send_stream.write_all(hash.hash.as_bytes()).await?; send_stream.finish().await?; - println!("closing connection from {remote_node_id}"); + println!("closing connection from {peer}"); Ok(()) }) } } impl ExampleProto { - fn build(node: Node) -> Arc { + pub fn build(node: Node) -> Arc { Arc::new(Self { node }) } - fn get_from_node(node: &Node, alpn: &'static [u8]) -> Option> { + pub fn get_from_node(node: &Node, alpn: &'static [u8]) -> Option> { node.get_protocol::>(alpn) } - async fn connect(&self, remote_node_id: NodeId) -> Result<()> { + pub async fn connect(&self, remote_node_id: NodeId) -> Result<()> { println!("our node id: {}", self.node.node_id()); println!("connecting to {remote_node_id}"); let conn = self @@ -122,8 +122,8 @@ impl ExampleProto { } } -// set the RUST_LOG env var to one of {debug,info,warn} to see logging info -pub fn setup_logging() { +/// Set the RUST_LOG env var to one of {debug,info,warn} to see logging. +fn setup_logging() { tracing_subscriber::registry() .with(tracing_subscriber::fmt::layer().with_writer(std::io::stderr)) .with(EnvFilter::from_default_env()) diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index 310241d63e..a273d60cbc 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -7,7 +7,7 @@ use std::{ }; use anyhow::{bail, Context, Result}; -use futures_lite::{future::Boxed, StreamExt}; +use futures_lite::{future::Boxed as BoxedFuture, StreamExt}; use iroh_base::key::SecretKey; use iroh_blobs::{ downloader::Downloader, @@ -61,7 +61,7 @@ const MAX_STREAMS: u64 = 10; type ProtocolBuilders = Vec<( &'static [u8], - Box) -> Boxed>> + Send + 'static>, + Box) -> BoxedFuture>> + Send + 'static>, )>; /// Storage backend for documents. @@ -386,33 +386,47 @@ where /// the cast automatically, so usually you will have to cast manually: /// /// ```rust + /// # use std::sync::Arc; /// # use anyhow::Result; /// # use futures_lite::future::Boxed as BoxedFuture; + /// # use iroh::{node::{Node, Protocol}, net::endpoint::Connecting}; + /// # + /// # #[tokio::main] + /// # async fn main() -> Result<()> { /// - /// const MY_ALPN: &[u8] = "my-protocol/1"; + /// const MY_ALPN: &[u8] = b"my-protocol/1"; /// /// #[derive(Debug)] /// struct MyProtocol; /// /// impl Protocol for MyProtocol { /// fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { - /// todo!() + /// todo!(); /// } /// } /// - /// let node = Node::memory().accept(MY_ALPN |_node| Box::pin(async move { - /// let protocol = MyProtocol; - /// let protocol: Arc = Arc::new(protocol); - /// Ok(protocol) - /// })) - /// + /// let node = Node::memory() + /// .accept(MY_ALPN, |_node| { + /// Box::pin(async move { + /// let protocol = MyProtocol; + /// let protocol: Arc = Arc::new(protocol); + /// Ok(protocol) + /// }) + /// }) + /// .spawn() + /// .await?; + /// # node.shutdown().await?; + /// # Ok(()) + /// # } /// ``` /// /// pub fn accept( mut self, alpn: &'static [u8], - protocol_builder: impl FnOnce(Node) -> Boxed>> + Send + 'static, + protocol_builder: impl FnOnce(Node) -> BoxedFuture>> + + Send + + 'static, ) -> Self { self.protocols.push((alpn, Box::new(protocol_builder))); self From f4abba5c1438e3f2b5a48dcb5a320bc8d5063002 Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Fri, 14 Jun 2024 00:40:03 +0200 Subject: [PATCH 20/33] cleanup --- iroh-docs/src/engine/live.rs | 1 - iroh/src/node/builder.rs | 79 +++++++++++++++--------------------- iroh/src/node/protocol.rs | 27 +++++++++++- 3 files changed, 59 insertions(+), 48 deletions(-) diff --git a/iroh-docs/src/engine/live.rs b/iroh-docs/src/engine/live.rs index 136b59fdac..67bd42edbd 100644 --- a/iroh-docs/src/engine/live.rs +++ b/iroh-docs/src/engine/live.rs @@ -290,7 +290,6 @@ impl LiveActor { match msg { ToLiveActor::Shutdown { .. } => { unreachable!("handled in run"); - // return Ok(false); } ToLiveActor::IncomingSyncReport { from, report } => { self.on_sync_report(from, report).await diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index a273d60cbc..2360799cad 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -574,18 +574,15 @@ where debug!("rpc listening on: {:?}", self.rpc_endpoint.local_addr()); - let blobs_store = self.blobs_store.clone(); - - // initialize the downloader + // Initialize the downloader let downloader = Downloader::new(self.blobs_store.clone(), endpoint.clone(), lp.clone()); + // Initialize the internal RPC connection. let (internal_rpc, controller) = quic_rpc::transport::flume::connection(1); let client = crate::client::Iroh::new(quic_rpc::RpcClient::new(controller.clone())); - let protocols = ProtocolMap::default(); - let inner = Arc::new(NodeInner { - db: self.blobs_store, + db: self.blobs_store.clone(), endpoint: endpoint.clone(), secret_key: self.secret_key, controller, @@ -593,37 +590,26 @@ where rt: lp.clone(), downloader, task: Default::default(), - protocols: protocols.clone(), + protocols: ProtocolMap::default(), }); - let node = Node { - inner: inner.clone(), - client, - }; + let node = Node { inner, client }; // Build the protocol handlers for the registered protocols. - for (alpn, p) in self.protocols { - let protocol = p(node.clone()).await; - match protocol { - Ok(protocol) => protocols.insert(alpn, protocol), - Err(err) => { - // Shutdown the protocols that were already built before returning the error. - protocols.shutdown().await; - return Err(err); - } - } - } + node.inner + .protocols + .build(node.clone(), self.protocols) + .await?; let mut join_set = JoinSet::new(); // Spawn a task that for the garbage collection. if let GcPolicy::Interval(gc_period) = self.gc_policy { tracing::info!("Starting GC task with interval {:?}", gc_period); - let db = blobs_store.clone(); - let gc_done_callback = self.gc_done_callback.take(); - let sync = protocols.get::(DOCS_ALPN); - let handle = - lp.spawn_pinned(move || Self::gc_loop(db, sync, gc_period, gc_done_callback)); + let docs = node.get_protocol::(DOCS_ALPN); + let handle = lp.spawn_pinned(move || { + Self::gc_loop(self.blobs_store, docs, gc_period, self.gc_done_callback) + }); // We cannot spawn tasks that run on the local pool directly into the join set, // so instead we create a new task that supervises the local task. join_set.spawn(async move { @@ -635,8 +621,7 @@ where } // Spawn a task that updates the gossip endpoints. - let gossip = protocols.get::(GOSSIP_ALPN); - if let Some(gossip) = gossip { + if let Some(gossip) = node.get_protocol::(GOSSIP_ALPN) { let mut stream = endpoint.local_endpoints(); join_set.spawn(async move { while let Some(eps) = stream.next().await { @@ -649,16 +634,17 @@ where }); } - let task = { - let me = endpoint.node_id().fmt_short(); - let inner = inner.clone(); - tokio::task::spawn( - async move { Self::run(inner, self.rpc_endpoint, internal_rpc, join_set).await } - .instrument(error_span!("node", %me)), + // Spawn the main task and store it in the node for structured termination in shutdown. + let task = tokio::task::spawn( + Self::run( + node.inner.clone(), + self.rpc_endpoint, + internal_rpc, + join_set, ) - }; - - *(node.inner.task.lock().unwrap()) = Some(task.into()); + .instrument(error_span!("node", me=%endpoint.node_id().fmt_short())), + ); + *node.inner.task.lock().unwrap() = Some(task.into()); // Wait for a single endpoint update, to make sure // we found some endpoints @@ -708,7 +694,7 @@ where tokio::select! { biased; _ = cancel_token.cancelled() => { - break + break; }, // handle rpc requests. This will do nothing if rpc is not configured, since // accept is just a pending future. @@ -733,7 +719,7 @@ where } } }, - // handle incoming p2p connections + // handle incoming p2p connections. Some(mut connecting) = endpoint.accept() => { let alpn = match connecting.alpn().await { Ok(alpn) => alpn, @@ -750,6 +736,7 @@ where Ok(()) }); }, + // handle task terminations and quit on panics. res = join_set.join_next(), if !join_set.is_empty() => { if let Some(Err(err)) = res { error!("Task failed: {err:?}"); @@ -770,16 +757,16 @@ where .await .ok(); + // Shutdown protocol handlers.. + inner.protocols.shutdown().await; + // Abort remaining tasks. join_set.shutdown().await; - - // Shutdown of the DocsEngine and blobs store is handled in Node::shutdown through - // ProtocolMap::shutdown. } async fn gc_loop( db: D, - ds: Option>, + docs: Option>, gc_period: Duration, done_cb: Option>, ) { @@ -796,8 +783,8 @@ where tokio::time::sleep(gc_period).await; tracing::debug!("Starting GC"); live.clear(); - if let Some(ds) = &ds { - let doc_hashes = match ds.sync.content_hashes().await { + if let Some(docs) = &docs { + let doc_hashes = match docs.sync.content_hashes().await { Ok(hashes) => hashes, Err(err) => { tracing::warn!("Error getting doc hashes: {}", err); diff --git a/iroh/src/node/protocol.rs b/iroh/src/node/protocol.rs index 1206b0c6e3..87a0ea7386 100644 --- a/iroh/src/node/protocol.rs +++ b/iroh/src/node/protocol.rs @@ -11,7 +11,7 @@ use futures_lite::future::Boxed as BoxedFuture; use iroh_net::endpoint::Connecting; use tracing::warn; -use crate::node::DocsEngine; +use crate::node::{DocsEngine, Node}; /// Handler for incoming connections. pub trait Protocol: Send + Sync + IntoArcAny + fmt::Debug + 'static { @@ -43,6 +43,11 @@ impl IntoArcAny for T { #[allow(clippy::type_complexity)] pub struct ProtocolMap(Arc>>>); +pub type ProtocolBuilders = Vec<( + &'static [u8], + Box) -> BoxedFuture>> + Send + 'static>, +)>; + impl ProtocolMap { /// Returns the registered protocol handler for an ALPN as a concrete type. pub fn get(&self, alpn: &[u8]) -> Option> { @@ -70,6 +75,26 @@ impl ProtocolMap { self.0.read().unwrap() } + /// Build the protocols from a list of builders. + pub(super) async fn build( + &self, + node: Node, + builders: ProtocolBuilders, + ) -> Result<()> { + for (alpn, builder) in builders { + let protocol = builder(node.clone()).await; + match protocol { + Ok(protocol) => self.insert(alpn, protocol), + Err(err) => { + // Shutdown the protocols that were already built before returning the error. + self.shutdown().await; + return Err(err); + } + } + } + Ok(()) + } + /// Shutdown the protocol handlers. pub(super) async fn shutdown(&self) { // We cannot hold the RwLockReadGuard over an await point, From 1d54b14ff1e8ff082d724950b2ec8ba02a10578f Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Fri, 14 Jun 2024 00:49:55 +0200 Subject: [PATCH 21/33] use OnceCell not Mutex --- iroh/src/node.rs | 21 +++++++++++---------- iroh/src/node/builder.rs | 2 +- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/iroh/src/node.rs b/iroh/src/node.rs index 00688fdf87..562bd87db7 100644 --- a/iroh/src/node.rs +++ b/iroh/src/node.rs @@ -6,7 +6,7 @@ use std::fmt::Debug; use std::net::SocketAddr; use std::path::Path; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use anyhow::{anyhow, Result}; use futures_lite::StreamExt; @@ -17,6 +17,7 @@ use iroh_docs::engine::Engine; use iroh_net::{ endpoint::LocalEndpointsStream, key::SecretKey, util::SharedAbortingJoinHandle, Endpoint, }; +use once_cell::sync::OnceCell; use quic_rpc::transport::flume::FlumeConnection; use quic_rpc::RpcClient; use tokio_util::sync::CancellationToken; @@ -60,7 +61,7 @@ struct NodeInner { #[debug("rt")] rt: LocalPoolHandle, downloader: Downloader, - task: Mutex>>, + task: OnceCell>, protocols: ProtocolMap, } @@ -171,14 +172,14 @@ impl Node { // Trigger shutdown of the main run task by activating the cancel token. self.inner.cancel_token.cancel(); - // Wait for the main run task to terminate. - let task = self.inner.task.lock().unwrap().take(); - if let Some(task) = task { - task.await.map_err(|err| anyhow!(err))?; - } - - // Give protocol handlers a chance to shutdown. - self.inner.protocols.shutdown().await; + // Wait for the main task to terminate. + self.inner + .task + .get() + .expect("is always set") + .clone() + .await + .map_err(|err| anyhow!(err))?; Ok(()) } diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index 2360799cad..c4ddfc7191 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -644,7 +644,7 @@ where ) .instrument(error_span!("node", me=%endpoint.node_id().fmt_short())), ); - *node.inner.task.lock().unwrap() = Some(task.into()); + node.inner.task.set(task.into()).expect("was empty"); // Wait for a single endpoint update, to make sure // we found some endpoints From 1c20c7b8b0c3d36f5de7395e689d36bb3995b057 Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Fri, 14 Jun 2024 00:53:53 +0200 Subject: [PATCH 22/33] docs --- iroh/src/node.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/iroh/src/node.rs b/iroh/src/node.rs index 562bd87db7..f8502fb9f8 100644 --- a/iroh/src/node.rs +++ b/iroh/src/node.rs @@ -160,14 +160,14 @@ impl Node { &self.inner.downloader } - /// Aborts the node. + /// Shutdown the node. /// /// This does not gracefully terminate currently: all connections are closed and - /// anything in-transit is lost. The task will stop running. - /// If this is the first call to this method, this will finish once the task is - /// fully shutdown. + /// anything in-transit is lost. The shutdown behaviour will become more graceful + /// in the future. /// - /// The shutdown behaviour will become more graceful in the future. + /// Returns a future that completes once all tasks terminated and all resources are closed. + /// The future resolves to an error if the main task panicked. pub async fn shutdown(self) -> Result<()> { // Trigger shutdown of the main run task by activating the cancel token. self.inner.cancel_token.cancel(); From 0da43759dcca5a7f673d67e223d5f8388c39bd55 Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Fri, 14 Jun 2024 01:07:12 +0200 Subject: [PATCH 23/33] further cleanup --- iroh/src/node/builder.rs | 40 +++++++++++++++++----------------------- 1 file changed, 17 insertions(+), 23 deletions(-) diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index c4ddfc7191..cc16124b01 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -6,7 +6,7 @@ use std::{ time::Duration, }; -use anyhow::{bail, Context, Result}; +use anyhow::{Context, Result}; use futures_lite::{future::Boxed as BoxedFuture, StreamExt}; use iroh_base::key::SecretKey; use iroh_blobs::{ @@ -720,19 +720,10 @@ where } }, // handle incoming p2p connections. - Some(mut connecting) = endpoint.accept() => { - let alpn = match connecting.alpn().await { - Ok(alpn) => alpn, - Err(err) => { - error!("invalid handshake: {:?}", err); - continue; - } - }; + Some(connecting) = endpoint.accept() => { let protocols = inner.protocols.clone(); join_set.spawn(async move { - if let Err(err) = handle_connection(connecting, alpn, protocols).await { - warn!("Handling incoming connection ended with error: {err}"); - } + handle_connection(connecting, protocols).await; Ok(()) }); }, @@ -860,18 +851,21 @@ impl Default for GcPolicy { } } -async fn handle_connection( - connecting: iroh_net::endpoint::Connecting, - alpn: String, - protocols: ProtocolMap, -) -> Result<()> { - let protocol = protocols.get_any(alpn.as_bytes()).clone(); - if let Some(protocol) = protocol { - protocol.accept(connecting).await?; - } else { - bail!("ignoring connection: unsupported ALPN protocol"); +async fn handle_connection(mut connecting: iroh_net::endpoint::Connecting, protocols: ProtocolMap) { + let alpn = match connecting.alpn().await { + Ok(alpn) => alpn, + Err(err) => { + warn!("Ignoring connection: invalid handshake: {:?}", err); + return; + } + }; + let Some(handler) = protocols.get_any(alpn.as_bytes()) else { + warn!("Ignoring connection: unsupported ALPN protocol"); + return; + }; + if let Err(err) = handler.accept(connecting).await { + warn!("Handling incoming connection ended with error: {err}"); } - Ok(()) } const DEFAULT_RPC_PORT: u16 = 0x1337; From f2f43e0e4791b9eea708fcbbff8c71f24c0b2460 Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Tue, 18 Jun 2024 16:26:16 +0200 Subject: [PATCH 24/33] feat(iroh-net): allow to change the accepted ALPNs --- iroh-net/src/endpoint.rs | 82 +++++++++++++++++++++++++++------------- 1 file changed, 56 insertions(+), 26 deletions(-) diff --git a/iroh-net/src/endpoint.rs b/iroh-net/src/endpoint.rs index dd1d219569..f34116f8e6 100644 --- a/iroh-net/src/endpoint.rs +++ b/iroh-net/src/endpoint.rs @@ -126,15 +126,12 @@ impl Builder { } }; let secret_key = self.secret_key.unwrap_or_else(SecretKey::generate); - let mut server_config = make_server_config( - &secret_key, - self.alpn_protocols, - self.transport_config, - self.keylog, - )?; - if let Some(c) = self.concurrent_connections { - server_config.concurrent_connections(c); - } + let static_config = StaticConfig { + transport_config: Arc::new(self.transport_config.unwrap_or_default()), + keylog: self.keylog, + concurrent_connections: self.concurrent_connections, + secret_key: secret_key.clone(), + }; let dns_resolver = self .dns_resolver .unwrap_or_else(|| default_resolver().clone()); @@ -150,7 +147,7 @@ impl Builder { #[cfg(any(test, feature = "test-utils"))] insecure_skip_relay_cert_verify: self.insecure_skip_relay_cert_verify, }; - Endpoint::bind(Some(server_config), msock_opts, self.keylog).await + Endpoint::bind(static_config, msock_opts, self.alpn_protocols).await } // # The very common methods everyone basically needs. @@ -297,17 +294,41 @@ impl Builder { } } +/// Configuration for a [`quinn::Endpoint`] that cannot be changed at runtime. +#[derive(Debug)] +struct StaticConfig { + secret_key: SecretKey, + transport_config: Arc, + keylog: bool, + concurrent_connections: Option, +} + +impl StaticConfig { + /// Build a [`quinn::ServerConfig`] with the specified ALPN protocols. + fn build(&self, alpn_protocols: Vec>) -> Result { + let mut server_config = make_server_config( + &self.secret_key, + alpn_protocols, + self.transport_config.clone(), + self.keylog, + )?; + if let Some(c) = self.concurrent_connections { + server_config.concurrent_connections(c); + } + Ok(server_config) + } +} + /// Creates a [`quinn::ServerConfig`] with the given secret key and limits. pub fn make_server_config( secret_key: &SecretKey, alpn_protocols: Vec>, - transport_config: Option, + transport_config: Arc, keylog: bool, ) -> Result { let tls_server_config = tls::make_server_config(secret_key, alpn_protocols, keylog)?; let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(tls_server_config)); - server_config.transport_config(Arc::new(transport_config.unwrap_or_default())); - + server_config.transport_config(transport_config); Ok(server_config) } @@ -335,12 +356,11 @@ pub fn make_server_config( /// [QUIC]: https://quicwg.org #[derive(Clone, Debug)] pub struct Endpoint { - secret_key: Arc, msock: Handle, endpoint: quinn::Endpoint, rtt_actor: Arc, - keylog: bool, cancel_token: CancellationToken, + static_config: Arc, } impl Endpoint { @@ -360,16 +380,17 @@ impl Endpoint { /// This is for internal use, the public interface is the [`Builder`] obtained from /// [Self::builder]. See the methods on the builder for documentation of the parameters. async fn bind( - server_config: Option, + static_config: StaticConfig, msock_opts: magicsock::Options, - keylog: bool, + initial_alpns: Vec>, ) -> Result { - let secret_key = msock_opts.secret_key.clone(); - let span = info_span!("magic_ep", me = %secret_key.public().fmt_short()); + let span = info_span!("magic_ep", me = %static_config.secret_key.public().fmt_short()); let _guard = span.enter(); let msock = magicsock::MagicSock::spawn(msock_opts).await?; trace!("created magicsock"); + let server_config = static_config.build(initial_alpns)?; + let mut endpoint_config = quinn::EndpointConfig::default(); // Setting this to false means that quinn will ignore packets that have the QUIC fixed bit // set to 0. The fixed bit is the 3rd bit of the first byte of a packet. @@ -380,22 +401,31 @@ impl Endpoint { let endpoint = quinn::Endpoint::new_with_abstract_socket( endpoint_config, - server_config, + Some(server_config), msock.clone(), Arc::new(quinn::TokioRuntime), )?; trace!("created quinn endpoint"); Ok(Self { - secret_key: Arc::new(secret_key), msock, endpoint, rtt_actor: Arc::new(rtt_actor::RttHandle::new()), - keylog, cancel_token: CancellationToken::new(), + static_config: Arc::new(static_config), }) } + /// Set the list of accepted ALPN protocols. + /// + /// This will only affect new incoming connections. + /// Note that this *overrides* the current list of ALPNs. + pub fn set_alpns(&self, alpns: Vec>) -> Result<()> { + let server_config = self.static_config.build(alpns)?; + self.endpoint.set_server_config(Some(server_config)); + Ok(()) + } + // # Methods for establishing connectivity. /// Connects to a remote [`Endpoint`]. @@ -481,10 +511,10 @@ impl Endpoint { let client_config = { let alpn_protocols = vec![alpn.to_vec()]; let tls_client_config = tls::make_client_config( - &self.secret_key, + &self.static_config.secret_key, Some(*node_id), alpn_protocols, - self.keylog, + self.static_config.keylog, )?; let mut client_config = quinn::ClientConfig::new(Arc::new(tls_client_config)); let mut transport_config = quinn::TransportConfig::default(); @@ -553,7 +583,7 @@ impl Endpoint { /// Returns the secret_key of this endpoint. pub fn secret_key(&self) -> &SecretKey { - &self.secret_key + &self.static_config.secret_key } /// Returns the node id of this endpoint. @@ -561,7 +591,7 @@ impl Endpoint { /// This ID is the unique addressing information of this node and other peers must know /// it to be able to connect to this node. pub fn node_id(&self) -> NodeId { - self.secret_key.public() + self.static_config.secret_key.public() } /// Returns the current [`NodeAddr`] for this endpoint. From b8165dca395622140ac86148d5501a373dc19f9f Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Tue, 18 Jun 2024 16:57:43 +0200 Subject: [PATCH 25/33] refactor: use two-stage build for node spawning and optional protocols --- iroh-docs/src/engine/live.rs | 2 - iroh/examples/custom-protocol.rs | 45 ++-- iroh/src/node.rs | 48 ++-- iroh/src/node/builder.rs | 444 +++++++++++++++++++++---------- iroh/src/node/docs.rs | 62 +++++ iroh/src/node/protocol.rs | 93 +------ iroh/src/node/rpc.rs | 213 +++++---------- iroh/src/node/rpc/docs.rs | 2 +- 8 files changed, 486 insertions(+), 423 deletions(-) create mode 100644 iroh/src/node/docs.rs diff --git a/iroh-docs/src/engine/live.rs b/iroh-docs/src/engine/live.rs index 67bd42edbd..0ca2194958 100644 --- a/iroh-docs/src/engine/live.rs +++ b/iroh-docs/src/engine/live.rs @@ -228,8 +228,6 @@ impl LiveActor { gossip_handle.await?; match res { Ok(reply) => { - // If the shutdown is triggered from call to the shutdown method, - // trigger the reply to signal completion of the shutdown. reply.send(()).ok(); Ok(()) } diff --git a/iroh/examples/custom-protocol.rs b/iroh/examples/custom-protocol.rs index 4ed0de72ac..78ae274838 100644 --- a/iroh/examples/custom-protocol.rs +++ b/iroh/examples/custom-protocol.rs @@ -4,12 +4,12 @@ use anyhow::Result; use clap::Parser; use futures_lite::future::Boxed as BoxedFuture; use iroh::{ - blobs::store::Store, + client::MemIroh, net::{ endpoint::{get_remote_node_id, Connecting}, - NodeId, + Endpoint, NodeId, }, - node::{Node, Protocol}, + node::Protocol, }; use tracing_subscriber::{prelude::*, EnvFilter}; @@ -30,10 +30,10 @@ async fn main() -> Result<()> { setup_logging(); let args = Cli::parse(); // create a new node - let node = iroh::node::Node::memory() - .accept(EXAMPLE_ALPN, |node| { - Box::pin(async move { Ok(ExampleProto::build(node)) }) - }) + let builder = iroh::node::Node::memory().build().await?; + let proto = ExampleProto::new(builder.client().clone(), builder.endpoint().clone()); + let node = builder + .accept(EXAMPLE_ALPN, Arc::new(proto.clone())) .spawn() .await?; @@ -46,7 +46,6 @@ async fn main() -> Result<()> { tokio::signal::ctrl_c().await?; } Command::Connect { node: node_id } => { - let proto = ExampleProto::get_from_node(&node, EXAMPLE_ALPN).expect("it is registered"); proto.connect(node_id).await?; } } @@ -58,12 +57,13 @@ async fn main() -> Result<()> { const EXAMPLE_ALPN: &[u8] = b"example-proto/0"; -#[derive(Debug)] -struct ExampleProto { - node: Node, +#[derive(Debug, Clone)] +struct ExampleProto { + client: MemIroh, + endpoint: Endpoint, } -impl Protocol for ExampleProto { +impl Protocol for ExampleProto { fn accept(self: Arc, connecting: Connecting) -> BoxedFuture> { Box::pin(async move { let connection = connecting.await?; @@ -75,7 +75,7 @@ impl Protocol for ExampleProto { // (you likely don't want to create a new blob for each connection for real) let content = format!("this blob is created for my beloved peer {peer} ♥"); let hash = self - .node + .client .blobs() .add_bytes(content.as_bytes().to_vec()) .await?; @@ -88,34 +88,29 @@ impl Protocol for ExampleProto { } } -impl ExampleProto { - pub fn build(node: Node) -> Arc { - Arc::new(Self { node }) - } - - pub fn get_from_node(node: &Node, alpn: &'static [u8]) -> Option> { - node.get_protocol::>(alpn) +impl ExampleProto { + pub fn new(client: MemIroh, endpoint: Endpoint) -> Self { + Self { client, endpoint } } pub async fn connect(&self, remote_node_id: NodeId) -> Result<()> { - println!("our node id: {}", self.node.node_id()); + println!("our node id: {}", self.endpoint.node_id()); println!("connecting to {remote_node_id}"); let conn = self - .node - .endpoint() + .endpoint .connect_by_node_id(&remote_node_id, EXAMPLE_ALPN) .await?; let mut recv_stream = conn.accept_uni().await?; let hash_bytes = recv_stream.read_to_end(32).await?; let hash = iroh::blobs::Hash::from_bytes(hash_bytes.try_into().unwrap()); println!("received hash: {hash}"); - self.node + self.client .blobs() .download(hash, remote_node_id.into()) .await? .await?; println!("blob downloaded"); - let content = self.node.blobs().read_to_bytes(hash).await?; + let content = self.client.blobs().read_to_bytes(hash).await?; let message = String::from_utf8(content.to_vec())?; println!("blob content: {message}"); Ok(()) diff --git a/iroh/src/node.rs b/iroh/src/node.rs index f8502fb9f8..84660f30c9 100644 --- a/iroh/src/node.rs +++ b/iroh/src/node.rs @@ -13,20 +13,23 @@ use futures_lite::StreamExt; use iroh_base::key::PublicKey; use iroh_blobs::downloader::Downloader; use iroh_blobs::store::Store as BaoStore; -use iroh_docs::engine::Engine; +use iroh_gossip::net::Gossip; use iroh_net::{ endpoint::LocalEndpointsStream, key::SecretKey, util::SharedAbortingJoinHandle, Endpoint, }; -use once_cell::sync::OnceCell; use quic_rpc::transport::flume::FlumeConnection; use quic_rpc::RpcClient; use tokio_util::sync::CancellationToken; use tokio_util::task::LocalPoolHandle; use tracing::debug; -use crate::{client::RpcService, node::protocol::ProtocolMap}; +use crate::{ + client::RpcService, + node::{docs::DocsEngine, protocol::ProtocolMap}, +}; mod builder; +mod docs; mod protocol; mod rpc; mod rpc_status; @@ -49,20 +52,22 @@ pub use protocol::Protocol; pub struct Node { inner: Arc>, client: crate::client::MemIroh, + task: SharedAbortingJoinHandle<()>, + protocols: Arc, } #[derive(derive_more::Debug)] struct NodeInner { db: D, + docs: Option, endpoint: Endpoint, + gossip: Gossip, secret_key: SecretKey, cancel_token: CancellationToken, controller: FlumeConnection, #[debug("rt")] rt: LocalPoolHandle, downloader: Downloader, - task: OnceCell>, - protocols: ProtocolMap, } /// In memory node. @@ -151,15 +156,6 @@ impl Node { self.inner.endpoint.my_relay() } - /// Returns the protocol handler for a alpn. - pub fn get_protocol(&self, alpn: &[u8]) -> Option> { - self.inner.protocols.get::

(alpn) - } - - fn downloader(&self) -> &Downloader { - &self.inner.downloader - } - /// Shutdown the node. /// /// This does not gracefully terminate currently: all connections are closed and @@ -173,13 +169,7 @@ impl Node { self.inner.cancel_token.cancel(); // Wait for the main task to terminate. - self.inner - .task - .get() - .expect("is always set") - .clone() - .await - .map_err(|err| anyhow!(err))?; + self.task.await.map_err(|err| anyhow!(err))?; Ok(()) } @@ -188,6 +178,11 @@ impl Node { pub fn cancel_token(&self) -> CancellationToken { self.inner.cancel_token.clone() } + + /// Get a protocol handler. + pub fn get_protocol(&self, alpn: &[u8]) -> Option> { + self.protocols.get_typed(alpn) + } } impl std::ops::Deref for Node { @@ -210,17 +205,6 @@ impl NodeInner { } } -/// Wrapper around [`Engine`] so that we can implement our RPC methods directly. -#[derive(Debug, Clone)] -pub(crate) struct DocsEngine(Engine); - -impl std::ops::Deref for DocsEngine { - type Target = Engine; - fn deref(&self) -> &Self::Target { - &self.0 - } -} - #[cfg(test)] mod tests { use std::time::Duration; diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index cc16124b01..39aa1e1668 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -14,7 +14,7 @@ use iroh_blobs::{ protocol::Closed, store::{GcMarkEvent, GcSweepEvent, Map, Store as BaoStore}, }; -use iroh_docs::engine::{DefaultAuthorStorage, Engine}; +use iroh_docs::engine::DefaultAuthorStorage; use iroh_docs::net::DOCS_ALPN; use iroh_gossip::net::{Gossip, GOSSIP_ALPN}; use iroh_net::{ @@ -24,7 +24,9 @@ use iroh_net::{ Endpoint, }; use quic_rpc::{ - transport::{misc::DummyServerEndpoint, quinn::QuinnServerEndpoint}, + transport::{ + flume::FlumeServerEndpoint, misc::DummyServerEndpoint, quinn::QuinnServerEndpoint, + }, RpcServer, ServiceEndpoint, }; use serde::{Deserialize, Serialize}; @@ -42,9 +44,7 @@ use crate::{ util::{fs::load_secret_key, path::IrohPaths}, }; -use super::{rpc, rpc_status::RpcStatus, DocsEngine, Node, NodeInner}; - -pub const PROTOCOLS: [&[u8]; 3] = [iroh_blobs::protocol::ALPN, GOSSIP_ALPN, DOCS_ALPN]; +use super::{docs::DocsEngine, rpc, rpc_status::RpcStatus, Node, NodeInner}; /// Default bind address for the node. /// 11204 is "iroh" in leetspeak @@ -120,6 +120,18 @@ pub enum StorageConfig { Persistent(PathBuf), } +impl StorageConfig { + fn default_author(&self) -> DefaultAuthorStorage { + match self { + StorageConfig::Persistent(ref root) => { + let path = IrohPaths::DefaultAuthor.with_root(root); + DefaultAuthorStorage::Persistent(path) + } + StorageConfig::Mem => DefaultAuthorStorage::Mem, + } + } +} + /// Configuration for node discovery. #[derive(Debug, Default)] pub enum DiscoveryConfig { @@ -454,65 +466,28 @@ where /// This will create the underlying network server and spawn a tokio task accepting /// connections. The returned [`Node`] can be used to control the task as well as /// get information about it. - pub async fn spawn(mut self) -> Result> { - // Register the core iroh protocols. - // Register blobs. - let lp = LocalPoolHandle::new(num_cpus::get()); - let blobs_proto = BlobsProtocol::new(self.blobs_store.clone(), lp.clone()); - self = self.accept(iroh_blobs::protocol::ALPN, move |_node| { - Box::pin(async move { - let blobs: Arc = Arc::new(blobs_proto); - Ok(blobs) - }) - }); - - // Register gossip. - self = self.accept(GOSSIP_ALPN, |node| { - Box::pin(async move { - let addr = node.endpoint().my_addr().await?; - let gossip = - Gossip::from_endpoint(node.endpoint().clone(), Default::default(), &addr.info); - let gossip: Arc = Arc::new(gossip); - Ok(gossip) - }) - }); + pub async fn spawn(self) -> Result> { + let unspawned_node = self.build().await?; + unspawned_node.spawn().await + } - if let Some(docs_store) = &self.docs_store { - // register the docs protocol. - let docs_store = match docs_store { - DocsStorage::Memory => iroh_docs::store::fs::Store::memory(), - DocsStorage::Persistent(path) => iroh_docs::store::fs::Store::persistent(path)?, - }; - // load or create the default author for documents - let default_author_storage = match self.storage { - StorageConfig::Persistent(ref root) => { - let path = IrohPaths::DefaultAuthor.with_root(root); - DefaultAuthorStorage::Persistent(path) - } - StorageConfig::Mem => DefaultAuthorStorage::Mem, - }; - let blobs_store = self.blobs_store.clone(); - self = self.accept(DOCS_ALPN, |node| { - Box::pin(async move { - let gossip = node - .get_protocol::(GOSSIP_ALPN) - .context("gossip not found")?; - let sync = Engine::spawn( - node.endpoint().clone(), - (*gossip).clone(), - docs_store, - blobs_store, - node.downloader().clone(), - default_author_storage, - ) - .await?; - let sync = DocsEngine(sync); - let sync: Arc = Arc::new(sync); - Ok(sync) - }) - }); + /// Build a node without spawning it. + /// + /// Returns an `UnspawnedNode`, on which custom protocols can be registered with + /// [`UnspawnedNode::accept`]. To spawn the node, call [`UnspawnedNode::spawn`]. + pub async fn build(self) -> Result> { + // Clone the blob store to shutdown in case of error. + let blobs_store = self.blobs_store.clone(); + match self.build_inner().await { + Ok(node) => Ok(node), + Err(err) => { + blobs_store.shutdown().await; + Err(err) + } } + } + async fn build_inner(self) -> Result> { let mut transport_config = quinn::TransportConfig::default(); transport_config .max_concurrent_bidi_streams(MAX_STREAMS.try_into()?) @@ -532,16 +507,10 @@ where } }; - let alpns = PROTOCOLS - .iter() - .chain(self.protocols.iter().map(|(alpn, _)| alpn)) - .map(|p| p.to_vec()) - .collect(); - let endpoint = Endpoint::builder() .secret_key(self.secret_key.clone()) .proxy_from_env() - .alpns(alpns) + .alpns(vec![]) .keylog(self.keylog) .transport_config(transport_config) .concurrent_connections(MAX_CONNECTIONS) @@ -570,88 +539,55 @@ where let endpoint = endpoint.bind(bind_port).await?; trace!("created quinn endpoint"); - let cancel_token = CancellationToken::new(); - - debug!("rpc listening on: {:?}", self.rpc_endpoint.local_addr()); + let addr = endpoint.my_addr().await?; - // Initialize the downloader + let lp = LocalPoolHandle::new(num_cpus::get()); + let gossip = Gossip::from_endpoint(endpoint.clone(), Default::default(), &addr.info); let downloader = Downloader::new(self.blobs_store.clone(), endpoint.clone(), lp.clone()); + let docs = if let Some(docs_store) = &self.docs_store { + let docs_engine = DocsEngine::spawn( + docs_store, + self.blobs_store.clone(), + self.storage.default_author(), + endpoint.clone(), + gossip.clone(), + downloader.clone(), + ) + .await?; + Some(docs_engine) + } else { + None + }; + // Initialize the internal RPC connection. let (internal_rpc, controller) = quic_rpc::transport::flume::connection(1); let client = crate::client::Iroh::new(quic_rpc::RpcClient::new(controller.clone())); + debug!("rpc listening on: {:?}", self.rpc_endpoint.local_addr()); let inner = Arc::new(NodeInner { db: self.blobs_store.clone(), + docs, endpoint: endpoint.clone(), secret_key: self.secret_key, controller, - cancel_token, - rt: lp.clone(), + cancel_token: CancellationToken::new(), + rt: lp, downloader, - task: Default::default(), - protocols: ProtocolMap::default(), + gossip, }); - let node = Node { inner, client }; - - // Build the protocol handlers for the registered protocols. - node.inner - .protocols - .build(node.clone(), self.protocols) - .await?; - - let mut join_set = JoinSet::new(); - - // Spawn a task that for the garbage collection. - if let GcPolicy::Interval(gc_period) = self.gc_policy { - tracing::info!("Starting GC task with interval {:?}", gc_period); - let docs = node.get_protocol::(DOCS_ALPN); - let handle = lp.spawn_pinned(move || { - Self::gc_loop(self.blobs_store, docs, gc_period, self.gc_done_callback) - }); - // We cannot spawn tasks that run on the local pool directly into the join set, - // so instead we create a new task that supervises the local task. - join_set.spawn(async move { - if let Err(err) = handle.await { - return Err(anyhow::Error::from(err)); - } - Ok(()) - }); - } - - // Spawn a task that updates the gossip endpoints. - if let Some(gossip) = node.get_protocol::(GOSSIP_ALPN) { - let mut stream = endpoint.local_endpoints(); - join_set.spawn(async move { - while let Some(eps) = stream.next().await { - if let Err(err) = gossip.update_endpoints(&eps) { - warn!("Failed to update gossip endpoints: {err:?}"); - } - } - warn!("failed to retrieve local endpoints"); - Ok(()) - }); - } - - // Spawn the main task and store it in the node for structured termination in shutdown. - let task = tokio::task::spawn( - Self::run( - node.inner.clone(), - self.rpc_endpoint, - internal_rpc, - join_set, - ) - .instrument(error_span!("node", me=%endpoint.node_id().fmt_short())), - ); - node.inner.task.set(task.into()).expect("was empty"); + let node = UnspawnedNode { + inner, + client, + protocols: Default::default(), + internal_rpc, + gc_policy: self.gc_policy, + gc_done_callback: self.gc_done_callback, + rpc_endpoint: self.rpc_endpoint, + }; - // Wait for a single endpoint update, to make sure - // we found some endpoints - tokio::time::timeout(ENDPOINT_WAIT, endpoint.local_endpoints().next()) - .await - .context("waiting for endpoint")? - .context("no endpoints")?; + let node = node.register_iroh_protocols(); Ok(node) } @@ -660,14 +596,13 @@ where inner: Arc>, rpc: E, internal_rpc: impl ServiceEndpoint, + protocols: Arc, mut join_set: JoinSet>, ) { let endpoint = inner.endpoint.clone(); - let docs = inner.protocols.get::(DOCS_ALPN); let handler = rpc::Handler { inner: inner.clone(), - docs, }; let rpc = RpcServer::new(rpc); let internal_rpc = RpcServer::new(internal_rpc); @@ -680,7 +615,7 @@ where let cancel_token = handler.inner.cancel_token.clone(); - if let Some(gossip) = inner.protocols.get::(GOSSIP_ALPN) { + if let Some(gossip) = protocols.get_typed::(GOSSIP_ALPN) { // forward our initial endpoints to the gossip protocol // it may happen the the first endpoint update callback is missed because the gossip cell // is only initialized once the endpoint is fully bound @@ -721,7 +656,7 @@ where }, // handle incoming p2p connections. Some(connecting) = endpoint.accept() => { - let protocols = inner.protocols.clone(); + let protocols = protocols.clone(); join_set.spawn(async move { handle_connection(connecting, protocols).await; Ok(()) @@ -749,7 +684,7 @@ where .ok(); // Shutdown protocol handlers.. - inner.protocols.shutdown().await; + protocols.shutdown().await; // Abort remaining tasks. join_set.shutdown().await; @@ -757,7 +692,7 @@ where async fn gc_loop( db: D, - docs: Option>, + docs: Option, gc_period: Duration, done_cb: Option>, ) { @@ -836,6 +771,226 @@ where } } +/// A node that is initialized but not yet spawned. +/// +/// This is returned from [`Builder::build`] and may be used to register custom protocols with +/// [`Self::accept`]. It provides access to the services which are already started, the node's +/// endpoint and a client to the node. +/// +/// Note that the client returned from [`Self::client`] can only be used after spawning the node, +/// until then all RPC calls will time out. +#[derive(derive_more::Debug)] +pub struct UnspawnedNode { + inner: Arc>, + client: crate::client::MemIroh, + internal_rpc: FlumeServerEndpoint, + rpc_endpoint: E, + protocols: ProtocolMap, + #[debug("callback")] + gc_done_callback: Option>, + gc_policy: GcPolicy, +} + +impl> UnspawnedNode { + /// Register a protocol handler for incoming connections. + /// + /// Use this to register custom protocols onto the iroh node. Whenever a new connection for + /// `alpn` comes in, it is passed to this protocol handler. + /// + /// See the [`Protocol`] trait for details. + /// + /// Example usage: + /// + /// ```rust + /// # use std::sync::Arc; + /// # use anyhow::Result; + /// # use futures_lite::future::Boxed as BoxedFuture; + /// # use iroh::{node::{Node, Protocol}, net::endpoint::Connecting, client::MemIroh}; + /// # + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// + /// const MY_ALPN: &[u8] = b"my-protocol/1"; + /// + /// #[derive(Debug)] + /// struct MyProtocol { + /// iroh: MemIroh + /// } + /// + /// impl Protocol for MyProtocol { + /// fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { + /// todo!(); + /// } + /// } + /// + /// let node = Node::memory() + /// .build() + /// .await? + /// .accept(MY_ALPN, |_node| Arc::new(MyProtocol::build(node.client()))) + /// .spawn() + /// .await?; + /// # node.shutdown().await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// + pub fn accept(mut self, alpn: &'static [u8], handler: Arc) -> Self { + self.protocols.insert(alpn, handler); + self + } + + /// Return a client to control this node over an in-memory channel. + /// + /// Note that the client can only be used after spawning the node, + /// until then all RPC calls will time out. + pub fn client(&self) -> &crate::client::MemIroh { + &self.client + } + + /// Returns the [`Endpoint`] of the node. + pub fn endpoint(&self) -> &Endpoint { + &self.inner.endpoint + } + + /// Returns the [`crate::blobs::store::Store`] used by the node. + pub fn blobs_db(&self) -> &D { + &self.inner.db + } + + /// Returns a reference to the used [`LocalPoolHandle`]. + pub fn local_pool_handle(&self) -> &LocalPoolHandle { + &self.inner.rt + } + + /// Returns a reference to the [`Downloader`] used by the node. + pub fn downloader(&self) -> &Downloader { + &self.inner.downloader + } + + /// Returns a reference to the [`Gossip`] handle used by the node. + pub fn gossip(&self) -> &Gossip { + &self.inner.gossip + } + + /// Register the core iroh protocols (blobs, gossip, optionally docs). + fn register_iroh_protocols(mut self) -> Self { + // Register blobs. + let blobs_proto = + BlobsProtocol::new(self.blobs_db().clone(), self.local_pool_handle().clone()); + self = self.accept(iroh_blobs::protocol::ALPN, Arc::new(blobs_proto)); + + // Register gossip. + let gossip = self.gossip().clone(); + self = self.accept(GOSSIP_ALPN, Arc::new(gossip)); + + // Register docs, if enabled. + if let Some(docs_engine) = self.inner.docs.clone() { + self = self.accept(DOCS_ALPN, Arc::new(docs_engine)); + } + + self + } + + /// Spawn the node and start accepting connections. + pub async fn spawn(self) -> Result> { + let Self { + inner, + client, + internal_rpc, + rpc_endpoint, + protocols, + gc_done_callback, + gc_policy, + } = self; + let protocols = Arc::new(protocols); + let protocols_clone = protocols.clone(); + + // Create the actual spawn future in an async block so that we can shutdown the protocols in case of + // error. + let node_fut = async move { + let mut join_set = JoinSet::new(); + + // Spawn a task for the garbage collection. + if let GcPolicy::Interval(gc_period) = gc_policy { + tracing::info!("Starting GC task with interval {:?}", gc_period); + let lp = inner.rt.clone(); + let docs = inner.docs.clone(); + let blobs_store = inner.db.clone(); + let handle = lp.spawn_pinned(move || { + Builder::::gc_loop(blobs_store, docs, gc_period, gc_done_callback) + }); + // We cannot spawn tasks that run on the local pool directly into the join set, + // so instead we create a new task that supervises the local task. + join_set.spawn(async move { + if let Err(err) = handle.await { + return Err(anyhow::Error::from(err)); + } + Ok(()) + }); + } + + // Spawn a task that updates the gossip endpoints. + if let Some(gossip) = protocols.get_typed::(GOSSIP_ALPN) { + let mut stream = inner.endpoint.local_endpoints(); + join_set.spawn(async move { + while let Some(eps) = stream.next().await { + if let Err(err) = gossip.update_endpoints(&eps) { + warn!("Failed to update gossip endpoints: {err:?}"); + } + } + warn!("failed to retrieve local endpoints"); + Ok(()) + }); + } + + // Update the endpoint with our alpns. + let alpns = protocols + .alpns() + .map(|alpn| alpn.to_vec()) + .collect::>(); + inner.endpoint.set_alpns(alpns)?; + + // Spawn the main task and store it in the node for structured termination in shutdown. + let task = tokio::task::spawn( + Builder::run( + inner.clone(), + rpc_endpoint, + internal_rpc, + protocols.clone(), + join_set, + ) + .instrument(error_span!("node", me=%inner.endpoint.node_id().fmt_short())), + ); + + let node = Node { + inner, + client, + protocols, + task: task.into(), + }; + + // Wait for a single endpoint update, to make sure + // we found some endpoints + tokio::time::timeout(ENDPOINT_WAIT, node.endpoint().local_endpoints().next()) + .await + .context("waiting for endpoint")? + .context("no endpoints")?; + + Ok(node) + }; + + match node_fut.await { + Ok(node) => Ok(node), + Err(err) => { + // Shutdown the protocols in case of error. + protocols_clone.shutdown().await; + Err(err) + } + } + } +} + /// Policy for garbage collection. #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] pub enum GcPolicy { @@ -851,7 +1006,10 @@ impl Default for GcPolicy { } } -async fn handle_connection(mut connecting: iroh_net::endpoint::Connecting, protocols: ProtocolMap) { +async fn handle_connection( + mut connecting: iroh_net::endpoint::Connecting, + protocols: Arc, +) { let alpn = match connecting.alpn().await { Ok(alpn) => alpn, Err(err) => { @@ -859,7 +1017,7 @@ async fn handle_connection(mut connecting: iroh_net::endpoint::Connecting, proto return; } }; - let Some(handler) = protocols.get_any(alpn.as_bytes()) else { + let Some(handler) = protocols.get(alpn.as_bytes()) else { warn!("Ignoring connection: unsupported ALPN protocol"); return; }; @@ -885,7 +1043,7 @@ fn make_rpc_endpoint( let mut server_config = iroh_net::endpoint::make_server_config( secret_key, vec![RPC_ALPN.to_vec()], - Some(transport_config), + Arc::new(transport_config), false, )?; server_config.concurrent_connections(MAX_RPC_CONNECTIONS); diff --git a/iroh/src/node/docs.rs b/iroh/src/node/docs.rs new file mode 100644 index 0000000000..08bd66658e --- /dev/null +++ b/iroh/src/node/docs.rs @@ -0,0 +1,62 @@ +use std::{ops::Deref, sync::Arc}; + +use anyhow::Result; +use futures_lite::future::Boxed as BoxedFuture; +use iroh_blobs::downloader::Downloader; +use iroh_gossip::net::Gossip; +use tracing::warn; + +use iroh_docs::engine::{DefaultAuthorStorage, Engine}; +use iroh_net::{endpoint::Connecting, Endpoint}; + +use crate::node::{DocsStorage, Protocol}; + +/// Wrapper around [`Engine`] so that we can implement our RPC methods directly. +#[derive(Debug, Clone)] +pub(crate) struct DocsEngine(Engine); + +impl DocsEngine { + pub async fn spawn( + storage: &DocsStorage, + blobs_store: S, + default_author_storage: DefaultAuthorStorage, + endpoint: Endpoint, + gossip: Gossip, + downloader: Downloader, + ) -> anyhow::Result { + let docs_store = match storage { + DocsStorage::Memory => iroh_docs::store::fs::Store::memory(), + DocsStorage::Persistent(path) => iroh_docs::store::fs::Store::persistent(path)?, + }; + let engine = Engine::spawn( + endpoint, + gossip, + docs_store, + blobs_store, + downloader, + default_author_storage, + ) + .await?; + Ok(DocsEngine(engine)) + } +} + +impl Deref for DocsEngine { + type Target = Engine; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Protocol for DocsEngine { + fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { + Box::pin(async move { self.handle_connection(conn).await }) + } + fn shutdown(self: Arc) -> BoxedFuture<()> { + Box::pin(async move { + if let Err(err) = self.deref().shutdown().await { + warn!("Error while shutting down docs engine: {err:?}"); + } + }) + } +} diff --git a/iroh/src/node/protocol.rs b/iroh/src/node/protocol.rs index 87a0ea7386..49a71b69a7 100644 --- a/iroh/src/node/protocol.rs +++ b/iroh/src/node/protocol.rs @@ -1,17 +1,8 @@ -use std::{ - any::Any, - collections::HashMap, - fmt, - ops::Deref, - sync::{Arc, RwLock}, -}; +use std::{any::Any, collections::BTreeMap, fmt, sync::Arc}; use anyhow::Result; use futures_lite::future::Boxed as BoxedFuture; use iroh_net::endpoint::Connecting; -use tracing::warn; - -use crate::node::{DocsEngine, Node}; /// Handler for incoming connections. pub trait Protocol: Send + Sync + IntoArcAny + fmt::Debug + 'static { @@ -40,78 +31,33 @@ impl IntoArcAny for T { } #[derive(Debug, Clone, Default)] -#[allow(clippy::type_complexity)] -pub struct ProtocolMap(Arc>>>); - -pub type ProtocolBuilders = Vec<( - &'static [u8], - Box) -> BoxedFuture>> + Send + 'static>, -)>; +pub(super) struct ProtocolMap(BTreeMap<&'static [u8], Arc>); impl ProtocolMap { /// Returns the registered protocol handler for an ALPN as a concrete type. - pub fn get(&self, alpn: &[u8]) -> Option> { - let protocols = self.0.read().unwrap(); - let protocol: Arc = protocols.get(alpn)?.clone(); + pub fn get_typed(&self, alpn: &[u8]) -> Option> { + let protocol: Arc = self.0.get(alpn)?.clone(); let protocol_any: Arc = protocol.into_arc_any(); let protocol_ref = Arc::downcast(protocol_any).ok()?; Some(protocol_ref) } - /// Returns the registered protocol handler for an ALPN as a `dyn Protocol`. - pub fn get_any(&self, alpn: &[u8]) -> Option> { - let protocols = self.0.read().unwrap(); - let protocol: Arc = protocols.get(alpn)?.clone(); - Some(protocol) + pub fn get(&self, alpn: &[u8]) -> Option> { + self.0.get(alpn).cloned() } - pub(super) fn insert(&self, alpn: &'static [u8], protocol: Arc) { - self.0.write().unwrap().insert(alpn, protocol); + pub fn insert(&mut self, alpn: &'static [u8], handler: Arc) { + self.0.insert(alpn, handler); } - pub(super) fn read( - &self, - ) -> std::sync::RwLockReadGuard>> { - self.0.read().unwrap() - } - - /// Build the protocols from a list of builders. - pub(super) async fn build( - &self, - node: Node, - builders: ProtocolBuilders, - ) -> Result<()> { - for (alpn, builder) in builders { - let protocol = builder(node.clone()).await; - match protocol { - Ok(protocol) => self.insert(alpn, protocol), - Err(err) => { - // Shutdown the protocols that were already built before returning the error. - self.shutdown().await; - return Err(err); - } - } - } - Ok(()) + pub fn alpns(&self) -> impl Iterator { + self.0.keys() } /// Shutdown the protocol handlers. - pub(super) async fn shutdown(&self) { - // We cannot hold the RwLockReadGuard over an await point, - // so we have to manually loop, clone each protocol, and drop the read guard - // before awaiting shutdown. - let mut i = 0; - loop { - let protocol = { - let protocols = self.read(); - if let Some(protocol) = protocols.values().nth(i) { - protocol.clone() - } else { - break; - } - }; - protocol.shutdown().await; - i += 1; + pub async fn shutdown(&self) { + for handler in self.0.values() { + handler.clone().shutdown().await; } } } @@ -163,16 +109,3 @@ impl Protocol for iroh_gossip::net::Gossip { Box::pin(async move { self.handle_connection(conn.await?).await }) } } - -impl Protocol for DocsEngine { - fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { - Box::pin(async move { self.handle_connection(conn).await }) - } - fn shutdown(self: Arc) -> BoxedFuture<()> { - Box::pin(async move { - if let Err(err) = self.deref().shutdown().await { - warn!("Error while shutting down docs engine: {err:?}"); - } - }) - } -} diff --git a/iroh/src/node/rpc.rs b/iroh/src/node/rpc.rs index 92dfade8fb..c9ab5630d6 100644 --- a/iroh/src/node/rpc.rs +++ b/iroh/src/node/rpc.rs @@ -47,7 +47,7 @@ use crate::rpc_protocol::{ }; use crate::{ client::blobs::{BlobInfo, DownloadMode, IncompleteBlobInfo, WrapOption}, - node::DocsEngine, + node::docs::DocsEngine, }; use crate::{client::tags::TagInfo, node::rpc::docs::ITER_CHANNEL_CAP}; use crate::{client::NodeStatus, rpc_protocol::AuthorListResponse}; @@ -65,10 +65,27 @@ const RPC_BLOB_GET_CHANNEL_CAP: usize = 2; #[derive(Debug, Clone)] pub(crate) struct Handler { pub(crate) inner: Arc>, - pub(crate) docs: Option>, } impl Handler { + fn docs(&self) -> Option<&DocsEngine> { + self.inner.docs.as_ref() + } + + async fn with_docs(self, f: F) -> RpcResult + where + T: Send + 'static, + F: FnOnce(DocsEngine) -> Fut, + Fut: std::future::Future>, + { + if let Some(docs) = self.docs() { + let docs = docs.clone(); + f(docs).await + } else { + Err(docs_disabled()) + } + } + pub(crate) fn handle_rpc_request>( &self, msg: Request, @@ -132,7 +149,7 @@ impl Handler { AuthorList(msg) => { chan.server_streaming(msg, handler, |handler, req| { let (tx, rx) = flume::bounded(ITER_CHANNEL_CAP); - if let Some(docs) = handler.docs { + if let Some(docs) = handler.docs() { docs.author_list(req, tx); } else { tx.send(Err(anyhow!("docs are disabled"))) @@ -146,99 +163,63 @@ impl Handler { .await } AuthorCreate(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - if let Some(docs) = handler.docs { - docs.author_create(req).await - } else { - Err(docs_disabled()) - } + chan.rpc(msg, handler, |handler, req| { + handler.with_docs(|docs| async move { docs.author_create(req).await }) }) .await } AuthorImport(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - if let Some(docs) = handler.docs { - docs.author_import(req).await - } else { - Err(docs_disabled()) - } + chan.rpc(msg, handler, |handler, req| { + handler.with_docs(|docs| async move { docs.author_import(req).await }) }) .await } AuthorExport(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - if let Some(docs) = handler.docs { - docs.author_export(req).await - } else { - Err(docs_disabled()) - } + chan.rpc(msg, handler, |handler, req| { + handler.with_docs(|docs| async move { docs.author_export(req).await }) }) .await } AuthorDelete(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - if let Some(docs) = handler.docs { - docs.author_delete(req).await - } else { - Err(docs_disabled()) - } + chan.rpc(msg, handler, |handler, req| { + handler.with_docs(|docs| async move { docs.author_delete(req).await }) }) .await } AuthorGetDefault(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - if let Some(docs) = handler.docs { - Ok(docs.author_default(req)) - } else { - Err(docs_disabled()) - } + chan.rpc(msg, handler, |handler, req| { + handler.with_docs(|docs| async move { Ok(docs.author_default(req)) }) }) .await } AuthorSetDefault(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - if let Some(docs) = handler.docs { - docs.author_set_default(req).await - } else { - Err(docs_disabled()) - } + chan.rpc(msg, handler, |handler, req| { + handler.with_docs(|docs| async move { docs.author_set_default(req).await }) }) .await } DocOpen(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - if let Some(docs) = handler.docs { - docs.doc_open(req).await - } else { - Err(docs_disabled()) - } + chan.rpc(msg, handler, |handler, req| { + handler.with_docs(|docs| async move { docs.doc_open(req).await }) }) .await } DocClose(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - if let Some(docs) = handler.docs { - docs.doc_close(req).await - } else { - Err(docs_disabled()) - } + chan.rpc(msg, handler, |handler, req| { + handler.with_docs(|docs| async move { docs.doc_close(req).await }) }) .await } DocStatus(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - if let Some(docs) = handler.docs { - docs.doc_status(req).await - } else { - Err(docs_disabled()) - } + chan.rpc(msg, handler, |handler, req| { + handler.with_docs(|docs| async move { docs.doc_status(req).await }) }) .await } DocList(msg) => { chan.server_streaming(msg, handler, |handler, req| { let (tx, rx) = flume::bounded(ITER_CHANNEL_CAP); - if let Some(docs) = handler.docs { + if let Some(docs) = handler.docs() { docs.doc_list(req, tx); } else { tx.send(Err(anyhow!("docs are disabled"))) @@ -252,43 +233,27 @@ impl Handler { .await } DocCreate(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - if let Some(docs) = handler.docs { - docs.doc_create(req).await - } else { - Err(docs_disabled()) - } + chan.rpc(msg, handler, |handler, req| { + handler.with_docs(|docs| async move { docs.doc_create(req).await }) }) .await } DocDrop(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - if let Some(docs) = handler.docs { - docs.doc_drop(req).await - } else { - Err(docs_disabled()) - } + chan.rpc(msg, handler, |handler, req| { + handler.with_docs(|docs| async move { docs.doc_drop(req).await }) }) .await } DocImport(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - if let Some(docs) = handler.docs { - docs.doc_import(req).await - } else { - Err(docs_disabled()) - } + chan.rpc(msg, handler, |handler, req| { + handler.with_docs(|docs| async move { docs.doc_import(req).await }) }) .await } DocSet(msg) => { let bao_store = handler.inner.db.clone(); - chan.rpc(msg, handler, |handler, req| async move { - if let Some(docs) = handler.docs { - docs.doc_set(&bao_store, req).await - } else { - Err(docs_disabled()) - } + chan.rpc(msg, handler, |handler, req| { + handler.with_docs(|docs| async move { docs.doc_set(&bao_store, req).await }) }) .await } @@ -301,29 +266,21 @@ impl Handler { .await } DocDel(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - if let Some(docs) = handler.docs { - docs.doc_del(req).await - } else { - Err(docs_disabled()) - } + chan.rpc(msg, handler, |handler, req| { + handler.with_docs(|docs| async move { docs.doc_del(req).await }) }) .await } DocSetHash(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - if let Some(docs) = handler.docs { - docs.doc_set_hash(req).await - } else { - Err(docs_disabled()) - } + chan.rpc(msg, handler, |handler, req| { + handler.with_docs(|docs| async move { docs.doc_set_hash(req).await }) }) .await } DocGet(msg) => { chan.server_streaming(msg, handler, |handler, req| { let (tx, rx) = flume::bounded(ITER_CHANNEL_CAP); - if let Some(docs) = handler.docs { + if let Some(docs) = handler.docs() { docs.doc_get_many(req, tx); } else { tx.send(Err(anyhow!("docs are disabled"))) @@ -337,48 +294,32 @@ impl Handler { .await } DocGetExact(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - if let Some(docs) = handler.docs { - docs.doc_get_exact(req).await - } else { - Err(docs_disabled()) - } + chan.rpc(msg, handler, |handler, req| { + handler.with_docs(|docs| async move { docs.doc_get_exact(req).await }) }) .await } DocStartSync(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - if let Some(docs) = handler.docs { - docs.doc_start_sync(req).await - } else { - Err(docs_disabled()) - } + chan.rpc(msg, handler, |handler, req| { + handler.with_docs(|docs| async move { docs.doc_start_sync(req).await }) }) .await } DocLeave(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - if let Some(docs) = handler.docs { - docs.doc_leave(req).await - } else { - Err(docs_disabled()) - } + chan.rpc(msg, handler, |handler, req| { + handler.with_docs(|docs| async move { docs.doc_leave(req).await }) }) .await } DocShare(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - if let Some(docs) = handler.docs { - docs.doc_share(req).await - } else { - Err(docs_disabled()) - } + chan.rpc(msg, handler, |handler, req| { + handler.with_docs(|docs| async move { docs.doc_share(req).await }) }) .await } DocSubscribe(msg) => { chan.try_server_streaming(msg, handler, |handler, req| async move { - if let Some(docs) = handler.docs { + if let Some(docs) = handler.docs() { docs.doc_subscribe(req).await } else { Err(docs_disabled()) @@ -387,32 +328,24 @@ impl Handler { .await } DocSetDownloadPolicy(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - if let Some(docs) = handler.docs { - docs.doc_set_download_policy(req).await - } else { - Err(docs_disabled()) - } + chan.rpc(msg, handler, |handler, req| { + handler.with_docs( + |docs| async move { docs.doc_set_download_policy(req).await }, + ) }) .await } DocGetDownloadPolicy(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - if let Some(docs) = handler.docs { - docs.doc_get_download_policy(req).await - } else { - Err(docs_disabled()) - } + chan.rpc(msg, handler, |handler, req| { + handler.with_docs( + |docs| async move { docs.doc_get_download_policy(req).await }, + ) }) .await } DocGetSyncPeers(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - if let Some(docs) = handler.docs { - docs.doc_get_sync_peers(req).await - } else { - Err(docs_disabled()) - } + chan.rpc(msg, handler, |handler, req| { + handler.with_docs(|docs| async move { docs.doc_get_sync_peers(req).await }) }) .await } @@ -590,7 +523,7 @@ impl Handler { msg: DocImportFileRequest, progress: flume::Sender, ) -> anyhow::Result<()> { - let docs = self.docs.ok_or_else(|| anyhow!("docs are disabled"))?; + let docs = self.docs().ok_or_else(|| anyhow!("docs are disabled"))?; use crate::client::docs::ImportProgress as DocImportProgress; use iroh_blobs::store::ImportMode; use std::collections::BTreeMap; @@ -675,7 +608,7 @@ impl Handler { msg: DocExportFileRequest, progress: flume::Sender, ) -> anyhow::Result<()> { - let _docs = self.docs.ok_or_else(|| anyhow!("docs are disabled"))?; + let _docs = self.docs().ok_or_else(|| anyhow!("docs are disabled"))?; let progress = FlumeProgressSender::new(progress); let DocExportFileRequest { entry, path, mode } = msg; let key = bytes::Bytes::from(entry.key().to_vec()); diff --git a/iroh/src/node/rpc/docs.rs b/iroh/src/node/rpc/docs.rs index 00762945b4..3a405b585d 100644 --- a/iroh/src/node/rpc/docs.rs +++ b/iroh/src/node/rpc/docs.rs @@ -9,7 +9,7 @@ use iroh_docs::{ use tokio_stream::StreamExt; use crate::client::docs::ShareMode; -use crate::node::DocsEngine; +use crate::node::docs::DocsEngine; use crate::rpc_protocol::{ AuthorCreateRequest, AuthorCreateResponse, AuthorDeleteRequest, AuthorDeleteResponse, AuthorExportRequest, AuthorExportResponse, AuthorGetDefaultRequest, AuthorGetDefaultResponse, From 4799ea188d693840903d85a4d1d3477cec42da8e Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Tue, 18 Jun 2024 17:09:44 +0200 Subject: [PATCH 26/33] revert making docs optional, move to separate PR --- Cargo.lock | 1 - iroh-docs/src/engine.rs | 6 +- iroh-docs/src/engine/live.rs | 31 ++--- iroh/Cargo.toml | 1 - iroh/src/client/authors.rs | 2 +- iroh/src/node.rs | 22 +++- iroh/src/node/builder.rs | 242 ++++++++++++----------------------- iroh/src/node/docs.rs | 62 --------- iroh/src/node/protocol.rs | 14 +- iroh/src/node/rpc.rs | 198 ++++++++++------------------ iroh/src/node/rpc/docs.rs | 57 +++++---- iroh/src/rpc_protocol.rs | 4 +- iroh/tests/gc.rs | 24 ++-- iroh/tests/provide.rs | 6 +- 14 files changed, 229 insertions(+), 441 deletions(-) delete mode 100644 iroh/src/node/docs.rs diff --git a/Cargo.lock b/Cargo.lock index ef3fa7ca76..a63e49d931 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2437,7 +2437,6 @@ dependencies = [ "iroh-quinn", "iroh-test", "num_cpus", - "once_cell", "parking_lot", "portable-atomic", "postcard", diff --git a/iroh-docs/src/engine.rs b/iroh-docs/src/engine.rs index 73bb215595..b5345b0bea 100644 --- a/iroh-docs/src/engine.rs +++ b/iroh-docs/src/engine.rs @@ -207,11 +207,7 @@ impl Engine { /// Shutdown the engine. pub async fn shutdown(&self) -> Result<()> { - let (reply, reply_rx) = oneshot::channel(); - self.to_live_actor - .send(ToLiveActor::Shutdown { reply }) - .await?; - reply_rx.await?; + self.to_live_actor.send(ToLiveActor::Shutdown).await?; Ok(()) } } diff --git a/iroh-docs/src/engine/live.rs b/iroh-docs/src/engine/live.rs index 0ca2194958..366379f4a3 100644 --- a/iroh-docs/src/engine/live.rs +++ b/iroh-docs/src/engine/live.rs @@ -67,9 +67,7 @@ pub enum ToLiveActor { #[debug("onsehot::Sender")] reply: sync::oneshot::Sender>, }, - Shutdown { - reply: sync::oneshot::Sender<()>, - }, + Shutdown, Subscribe { namespace: NamespaceId, #[debug("sender")] @@ -226,16 +224,10 @@ impl LiveActor { error!(?err, "Error during shutdown"); } gossip_handle.await?; - match res { - Ok(reply) => { - reply.send(()).ok(); - Ok(()) - } - Err(err) => Err(err), - } + res } - async fn run_inner(&mut self) -> Result> { + async fn run_inner(&mut self) -> Result<()> { let mut i = 0; loop { i += 1; @@ -245,15 +237,8 @@ impl LiveActor { msg = self.inbox.recv() => { let msg = msg.context("to_actor closed")?; trace!(?i, %msg, "tick: to_actor"); - match msg { - ToLiveActor::Shutdown { reply } => { - // Return the oneshot reply to the upper-level run to send after - // shutdown is complete. - break Ok(reply); - } - msg => { - self.on_actor_message(msg).await.context("on_actor_message")?; - } + if !self.on_actor_message(msg).await.context("on_actor_message")? { + break; } } event = self.replica_events_rx.recv_async() => { @@ -282,12 +267,14 @@ impl LiveActor { } } } + debug!("close (shutdown)"); + Ok(()) } async fn on_actor_message(&mut self, msg: ToLiveActor) -> anyhow::Result { match msg { - ToLiveActor::Shutdown { .. } => { - unreachable!("handled in run"); + ToLiveActor::Shutdown => { + return Ok(false); } ToLiveActor::IncomingSyncReport { from, report } => { self.on_sync_report(from, report).await diff --git a/iroh/Cargo.toml b/iroh/Cargo.toml index a8b92488f7..5130f336c2 100644 --- a/iroh/Cargo.toml +++ b/iroh/Cargo.toml @@ -32,7 +32,6 @@ iroh-io = { version = "0.6.0", features = ["stats"] } iroh-metrics = { version = "0.18.0", path = "../iroh-metrics", optional = true } iroh-net = { version = "0.18.0", path = "../iroh-net" } num_cpus = { version = "1.15.0" } -once_cell = "1.17.0" portable-atomic = "1" iroh-docs = { version = "0.18.0", path = "../iroh-docs" } iroh-gossip = { version = "0.18.0", path = "../iroh-gossip" } diff --git a/iroh/src/client/authors.rs b/iroh/src/client/authors.rs index 7cdd44ce72..e6bddbb494 100644 --- a/iroh/src/client/authors.rs +++ b/iroh/src/client/authors.rs @@ -40,7 +40,7 @@ where /// /// The default author can be set with [`Self::set_default`]. pub async fn default(&self) -> Result { - let res = self.rpc.rpc(AuthorGetDefaultRequest).await??; + let res = self.rpc.rpc(AuthorGetDefaultRequest).await?; Ok(res.author_id) } diff --git a/iroh/src/node.rs b/iroh/src/node.rs index 84660f30c9..4c3b0a60dc 100644 --- a/iroh/src/node.rs +++ b/iroh/src/node.rs @@ -13,6 +13,7 @@ use futures_lite::StreamExt; use iroh_base::key::PublicKey; use iroh_blobs::downloader::Downloader; use iroh_blobs::store::Store as BaoStore; +use iroh_docs::engine::Engine; use iroh_gossip::net::Gossip; use iroh_net::{ endpoint::LocalEndpointsStream, key::SecretKey, util::SharedAbortingJoinHandle, Endpoint, @@ -23,18 +24,14 @@ use tokio_util::sync::CancellationToken; use tokio_util::task::LocalPoolHandle; use tracing::debug; -use crate::{ - client::RpcService, - node::{docs::DocsEngine, protocol::ProtocolMap}, -}; +use crate::{client::RpcService, node::protocol::ProtocolMap}; mod builder; -mod docs; mod protocol; mod rpc; mod rpc_status; -pub use self::builder::{Builder, DiscoveryConfig, DocsStorage, GcPolicy, StorageConfig}; +pub use self::builder::{Builder, DiscoveryConfig, GcPolicy, StorageConfig}; pub use self::rpc_status::RpcStatus; pub use protocol::Protocol; @@ -59,7 +56,7 @@ pub struct Node { #[derive(derive_more::Debug)] struct NodeInner { db: D, - docs: Option, + sync: DocsEngine, endpoint: Endpoint, gossip: Gossip, secret_key: SecretKey, @@ -205,6 +202,17 @@ impl NodeInner { } } +/// Wrapper around [`Engine`] so that we can implement our RPC methods directly. +#[derive(Debug, Clone)] +pub(crate) struct DocsEngine(Engine); + +impl std::ops::Deref for DocsEngine { + type Target = Engine; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + #[cfg(test)] mod tests { use std::time::Duration; diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index 39aa1e1668..0ac4185a89 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -7,14 +7,14 @@ use std::{ }; use anyhow::{Context, Result}; -use futures_lite::{future::Boxed as BoxedFuture, StreamExt}; +use futures_lite::StreamExt; use iroh_base::key::SecretKey; use iroh_blobs::{ downloader::Downloader, protocol::Closed, store::{GcMarkEvent, GcSweepEvent, Map, Store as BaoStore}, }; -use iroh_docs::engine::DefaultAuthorStorage; +use iroh_docs::engine::{DefaultAuthorStorage, Engine}; use iroh_docs::net::DOCS_ALPN; use iroh_gossip::net::{Gossip, GOSSIP_ALPN}; use iroh_net::{ @@ -44,7 +44,7 @@ use crate::{ util::{fs::load_secret_key, path::IrohPaths}, }; -use super::{docs::DocsEngine, rpc, rpc_status::RpcStatus, Node, NodeInner}; +use super::{rpc, rpc_status::RpcStatus, DocsEngine, Node, NodeInner}; /// Default bind address for the node. /// 11204 is "iroh" in leetspeak @@ -59,20 +59,6 @@ const DEFAULT_GC_INTERVAL: Duration = Duration::from_secs(60 * 5); const MAX_CONNECTIONS: u32 = 1024; const MAX_STREAMS: u64 = 10; -type ProtocolBuilders = Vec<( - &'static [u8], - Box) -> BoxedFuture>> + Send + 'static>, -)>; - -/// Storage backend for documents. -#[derive(Debug, Clone)] -pub enum DocsStorage { - /// In-memory storage. - Memory, - /// File-based persistent storage. - Persistent(PathBuf), -} - /// Builder for the [`Node`]. /// /// You must supply a blob store and a document store. @@ -102,8 +88,7 @@ where gc_policy: GcPolicy, dns_resolver: Option, node_discovery: DiscoveryConfig, - docs_store: Option, - protocols: ProtocolBuilders, + docs_store: iroh_docs::store::Store, #[cfg(any(test, feature = "test-utils"))] insecure_skip_relay_cert_verify: bool, /// Callback to register when a gc loop is done @@ -120,18 +105,6 @@ pub enum StorageConfig { Persistent(PathBuf), } -impl StorageConfig { - fn default_author(&self) -> DefaultAuthorStorage { - match self { - StorageConfig::Persistent(ref root) => { - let path = IrohPaths::DefaultAuthor.with_root(root); - DefaultAuthorStorage::Persistent(path) - } - StorageConfig::Mem => DefaultAuthorStorage::Mem, - } - } -} - /// Configuration for node discovery. #[derive(Debug, Default)] pub enum DiscoveryConfig { @@ -164,8 +137,7 @@ impl Default for Builder { dns_resolver: None, rpc_endpoint: Default::default(), gc_policy: GcPolicy::Disabled, - docs_store: Some(DocsStorage::Memory), - protocols: Default::default(), + docs_store: iroh_docs::store::Store::memory(), node_discovery: Default::default(), #[cfg(any(test, feature = "test-utils"))] insecure_skip_relay_cert_verify: false, @@ -178,7 +150,7 @@ impl Builder { /// Creates a new builder for [`Node`] using the given databases. pub fn with_db_and_store( blobs_store: D, - docs_store: DocsStorage, + docs_store: iroh_docs::store::Store, storage: StorageConfig, ) -> Self { Self { @@ -191,9 +163,8 @@ impl Builder { dns_resolver: None, rpc_endpoint: Default::default(), gc_policy: GcPolicy::Disabled, - docs_store: Some(docs_store), + docs_store, node_discovery: Default::default(), - protocols: Default::default(), #[cfg(any(test, feature = "test-utils"))] insecure_skip_relay_cert_verify: false, gc_done_callback: None, @@ -220,7 +191,8 @@ where .with_context(|| { format!("Failed to load blobs database from {}", blob_dir.display()) })?; - let docs_store = DocsStorage::Persistent(IrohPaths::DocsDatabase.with_root(root)); + let docs_store = + iroh_docs::store::fs::Store::persistent(IrohPaths::DocsDatabase.with_root(root))?; let v0 = blobs_store .import_flat_store(iroh_blobs::store::fs::FlatStorePaths { @@ -256,9 +228,8 @@ where relay_mode: self.relay_mode, dns_resolver: self.dns_resolver, gc_policy: self.gc_policy, - docs_store: Some(docs_store), + docs_store, node_discovery: self.node_discovery, - protocols: Default::default(), #[cfg(any(test, feature = "test-utils"))] insecure_skip_relay_cert_verify: false, gc_done_callback: self.gc_done_callback, @@ -280,7 +251,6 @@ where gc_policy: self.gc_policy, docs_store: self.docs_store, node_discovery: self.node_discovery, - protocols: Default::default(), #[cfg(any(test, feature = "test-utils"))] insecure_skip_relay_cert_verify: self.insecure_skip_relay_cert_verify, gc_done_callback: self.gc_done_callback, @@ -307,7 +277,6 @@ where gc_policy: self.gc_policy, docs_store: self.docs_store, node_discovery: self.node_discovery, - protocols: Default::default(), #[cfg(any(test, feature = "test-utils"))] insecure_skip_relay_cert_verify: self.insecure_skip_relay_cert_verify, gc_done_callback: self.gc_done_callback, @@ -322,12 +291,6 @@ where self } - /// Disables documents support on this node completely. - pub fn disable_docs(mut self) -> Self { - self.docs_store = None; - self - } - /// Sets the relay servers to assist in establishing connectivity. /// /// Relay servers are used to discover other nodes by `PublicKey` and also help @@ -387,63 +350,6 @@ where self } - /// Accept a custom protocol. - /// - /// Use this to register custom protocols onto the iroh node. Whenever a new connection for - /// `alpn` comes in, it is passed to this protocol handler. - /// - /// The `protocol_builder` argument is a closure that returns a future which must resolve - /// to a protocol handler. The latter is a struct that implements [`Protocol`]. Note that the - /// closure must return `Arc`. Sometimes the Rust compiler will not be able to do - /// the cast automatically, so usually you will have to cast manually: - /// - /// ```rust - /// # use std::sync::Arc; - /// # use anyhow::Result; - /// # use futures_lite::future::Boxed as BoxedFuture; - /// # use iroh::{node::{Node, Protocol}, net::endpoint::Connecting}; - /// # - /// # #[tokio::main] - /// # async fn main() -> Result<()> { - /// - /// const MY_ALPN: &[u8] = b"my-protocol/1"; - /// - /// #[derive(Debug)] - /// struct MyProtocol; - /// - /// impl Protocol for MyProtocol { - /// fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { - /// todo!(); - /// } - /// } - /// - /// let node = Node::memory() - /// .accept(MY_ALPN, |_node| { - /// Box::pin(async move { - /// let protocol = MyProtocol; - /// let protocol: Arc = Arc::new(protocol); - /// Ok(protocol) - /// }) - /// }) - /// .spawn() - /// .await?; - /// # node.shutdown().await?; - /// # Ok(()) - /// # } - /// ``` - /// - /// - pub fn accept( - mut self, - alpn: &'static [u8], - protocol_builder: impl FnOnce(Node) -> BoxedFuture>> - + Send - + 'static, - ) -> Self { - self.protocols.push((alpn, Box::new(protocol_builder))); - self - } - /// Register a callback for when GC is done. #[cfg(any(test, feature = "test-utils"))] pub fn register_gc_done_cb(mut self, cb: Box) -> Self { @@ -488,6 +394,9 @@ where } async fn build_inner(self) -> Result> { + trace!("building node"); + let lp = LocalPoolHandle::new(num_cpus::get()); + let mut transport_config = quinn::TransportConfig::default(); transport_config .max_concurrent_bidi_streams(MAX_STREAMS.try_into()?) @@ -539,27 +448,37 @@ where let endpoint = endpoint.bind(bind_port).await?; trace!("created quinn endpoint"); + let cancel_token = CancellationToken::new(); + let addr = endpoint.my_addr().await?; - let lp = LocalPoolHandle::new(num_cpus::get()); + // initialize the gossip protocol let gossip = Gossip::from_endpoint(endpoint.clone(), Default::default(), &addr.info); + + // initialize the downloader let downloader = Downloader::new(self.blobs_store.clone(), endpoint.clone(), lp.clone()); - let docs = if let Some(docs_store) = &self.docs_store { - let docs_engine = DocsEngine::spawn( - docs_store, - self.blobs_store.clone(), - self.storage.default_author(), - endpoint.clone(), - gossip.clone(), - downloader.clone(), - ) - .await?; - Some(docs_engine) - } else { - None + // load or create the default author for documents + let default_author_storage = match self.storage { + StorageConfig::Persistent(ref root) => { + let path = IrohPaths::DefaultAuthor.with_root(root); + DefaultAuthorStorage::Persistent(path) + } + StorageConfig::Mem => DefaultAuthorStorage::Mem, }; + // spawn the docs engine + let sync = Engine::spawn( + endpoint.clone(), + gossip.clone(), + self.docs_store, + self.blobs_store.clone(), + downloader.clone(), + default_author_storage, + ) + .await?; + let sync = DocsEngine(sync); + // Initialize the internal RPC connection. let (internal_rpc, controller) = quic_rpc::transport::flume::connection(1); let client = crate::client::Iroh::new(quic_rpc::RpcClient::new(controller.clone())); @@ -567,11 +486,11 @@ where let inner = Arc::new(NodeInner { db: self.blobs_store.clone(), - docs, + sync, endpoint: endpoint.clone(), secret_key: self.secret_key, controller, - cancel_token: CancellationToken::new(), + cancel_token, rt: lp, downloader, gossip, @@ -615,14 +534,12 @@ where let cancel_token = handler.inner.cancel_token.clone(); - if let Some(gossip) = protocols.get_typed::(GOSSIP_ALPN) { - // forward our initial endpoints to the gossip protocol - // it may happen the the first endpoint update callback is missed because the gossip cell - // is only initialized once the endpoint is fully bound - if let Some(local_endpoints) = endpoint.local_endpoints().next().await { - debug!(me = ?endpoint.node_id(), "gossip initial update: {local_endpoints:?}"); - gossip.update_endpoints(&local_endpoints).ok(); - } + // forward our initial endpoints to the gossip protocol + // it may happen the the first endpoint update callback is missed because the gossip cell + // is only initialized once the endpoint is fully bound + if let Some(local_endpoints) = endpoint.local_endpoints().next().await { + debug!(me = ?endpoint.node_id(), "gossip initial update: {local_endpoints:?}"); + inner.gossip.update_endpoints(&local_endpoints).ok(); } loop { @@ -683,7 +600,11 @@ where .await .ok(); - // Shutdown protocol handlers.. + // Shutdown sync and blobs. + inner.sync.shutdown().await.ok(); + inner.db.shutdown().await; + + // Shutdown protocol handlers. protocols.shutdown().await; // Abort remaining tasks. @@ -692,7 +613,7 @@ where async fn gc_loop( db: D, - docs: Option, + ds: DocsEngine, gc_period: Duration, done_cb: Option>, ) { @@ -709,24 +630,23 @@ where tokio::time::sleep(gc_period).await; tracing::debug!("Starting GC"); live.clear(); - if let Some(docs) = &docs { - let doc_hashes = match docs.sync.content_hashes().await { - Ok(hashes) => hashes, + + let doc_hashes = match ds.sync.content_hashes().await { + Ok(hashes) => hashes, + Err(err) => { + tracing::warn!("Error getting doc hashes: {}", err); + continue 'outer; + } + }; + for hash in doc_hashes { + match hash { + Ok(hash) => { + live.insert(hash); + } Err(err) => { - tracing::warn!("Error getting doc hashes: {}", err); + tracing::error!("Error getting doc hash: {}", err); continue 'outer; } - }; - for hash in doc_hashes { - match hash { - Ok(hash) => { - live.insert(hash); - } - Err(err) => { - tracing::error!("Error getting doc hash: {}", err); - continue 'outer; - } - } } } @@ -873,7 +793,7 @@ impl> UnspawnedNode< &self.inner.gossip } - /// Register the core iroh protocols (blobs, gossip, optionally docs). + /// Register the core iroh protocols (blobs, gossip, docs). fn register_iroh_protocols(mut self) -> Self { // Register blobs. let blobs_proto = @@ -884,10 +804,9 @@ impl> UnspawnedNode< let gossip = self.gossip().clone(); self = self.accept(GOSSIP_ALPN, Arc::new(gossip)); - // Register docs, if enabled. - if let Some(docs_engine) = self.inner.docs.clone() { - self = self.accept(DOCS_ALPN, Arc::new(docs_engine)); - } + // Register docs. + let docs = self.inner.sync.clone(); + self = self.accept(DOCS_ALPN, Arc::new(docs)); self } @@ -915,7 +834,7 @@ impl> UnspawnedNode< if let GcPolicy::Interval(gc_period) = gc_policy { tracing::info!("Starting GC task with interval {:?}", gc_period); let lp = inner.rt.clone(); - let docs = inner.docs.clone(); + let docs = inner.sync.clone(); let blobs_store = inner.db.clone(); let handle = lp.spawn_pinned(move || { Builder::::gc_loop(blobs_store, docs, gc_period, gc_done_callback) @@ -931,18 +850,17 @@ impl> UnspawnedNode< } // Spawn a task that updates the gossip endpoints. - if let Some(gossip) = protocols.get_typed::(GOSSIP_ALPN) { - let mut stream = inner.endpoint.local_endpoints(); - join_set.spawn(async move { - while let Some(eps) = stream.next().await { - if let Err(err) = gossip.update_endpoints(&eps) { - warn!("Failed to update gossip endpoints: {err:?}"); - } + let mut stream = inner.endpoint.local_endpoints(); + let gossip = inner.gossip.clone(); + join_set.spawn(async move { + while let Some(eps) = stream.next().await { + if let Err(err) = gossip.update_endpoints(&eps) { + warn!("Failed to update gossip endpoints: {err:?}"); } - warn!("failed to retrieve local endpoints"); - Ok(()) - }); - } + } + warn!("failed to retrieve local endpoints"); + Ok(()) + }); // Update the endpoint with our alpns. let alpns = protocols diff --git a/iroh/src/node/docs.rs b/iroh/src/node/docs.rs deleted file mode 100644 index 08bd66658e..0000000000 --- a/iroh/src/node/docs.rs +++ /dev/null @@ -1,62 +0,0 @@ -use std::{ops::Deref, sync::Arc}; - -use anyhow::Result; -use futures_lite::future::Boxed as BoxedFuture; -use iroh_blobs::downloader::Downloader; -use iroh_gossip::net::Gossip; -use tracing::warn; - -use iroh_docs::engine::{DefaultAuthorStorage, Engine}; -use iroh_net::{endpoint::Connecting, Endpoint}; - -use crate::node::{DocsStorage, Protocol}; - -/// Wrapper around [`Engine`] so that we can implement our RPC methods directly. -#[derive(Debug, Clone)] -pub(crate) struct DocsEngine(Engine); - -impl DocsEngine { - pub async fn spawn( - storage: &DocsStorage, - blobs_store: S, - default_author_storage: DefaultAuthorStorage, - endpoint: Endpoint, - gossip: Gossip, - downloader: Downloader, - ) -> anyhow::Result { - let docs_store = match storage { - DocsStorage::Memory => iroh_docs::store::fs::Store::memory(), - DocsStorage::Persistent(path) => iroh_docs::store::fs::Store::persistent(path)?, - }; - let engine = Engine::spawn( - endpoint, - gossip, - docs_store, - blobs_store, - downloader, - default_author_storage, - ) - .await?; - Ok(DocsEngine(engine)) - } -} - -impl Deref for DocsEngine { - type Target = Engine; - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl Protocol for DocsEngine { - fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { - Box::pin(async move { self.handle_connection(conn).await }) - } - fn shutdown(self: Arc) -> BoxedFuture<()> { - Box::pin(async move { - if let Err(err) = self.deref().shutdown().await { - warn!("Error while shutting down docs engine: {err:?}"); - } - }) - } -} diff --git a/iroh/src/node/protocol.rs b/iroh/src/node/protocol.rs index 49a71b69a7..90df8de350 100644 --- a/iroh/src/node/protocol.rs +++ b/iroh/src/node/protocol.rs @@ -4,6 +4,8 @@ use anyhow::Result; use futures_lite::future::Boxed as BoxedFuture; use iroh_net::endpoint::Connecting; +use crate::node::DocsEngine; + /// Handler for incoming connections. pub trait Protocol: Send + Sync + IntoArcAny + fmt::Debug + 'static { /// Handle an incoming connection. @@ -87,12 +89,6 @@ impl Protocol for BlobsProtocol { Ok(()) }) } - - fn shutdown(self: Arc) -> BoxedFuture<()> { - Box::pin(async move { - self.store.shutdown().await; - }) - } } #[derive(Debug, Clone)] @@ -109,3 +105,9 @@ impl Protocol for iroh_gossip::net::Gossip { Box::pin(async move { self.handle_connection(conn.await?).await }) } } + +impl Protocol for DocsEngine { + fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { + Box::pin(async move { self.handle_connection(conn).await }) + } +} diff --git a/iroh/src/node/rpc.rs b/iroh/src/node/rpc.rs index c9ab5630d6..6382b50d6a 100644 --- a/iroh/src/node/rpc.rs +++ b/iroh/src/node/rpc.rs @@ -7,7 +7,7 @@ use anyhow::{anyhow, ensure, Result}; use futures_buffered::BufferedStreamExt; use futures_lite::{Stream, StreamExt}; use genawaiter::sync::{Co, Gen}; -use iroh_base::rpc::{RpcError, RpcResult}; +use iroh_base::rpc::RpcResult; use iroh_blobs::downloader::{DownloadRequest, Downloader}; use iroh_blobs::export::ExportProgress; use iroh_blobs::format::collection::Collection; @@ -32,25 +32,21 @@ use quic_rpc::{ use tokio_util::task::LocalPoolHandle; use tracing::{debug, info}; +use crate::client::blobs::{BlobInfo, DownloadMode, IncompleteBlobInfo, WrapOption}; +use crate::client::tags::TagInfo; +use crate::client::NodeStatus; use crate::rpc_protocol::{ BlobAddPathRequest, BlobAddPathResponse, BlobAddStreamRequest, BlobAddStreamResponse, BlobAddStreamUpdate, BlobConsistencyCheckRequest, BlobDeleteBlobRequest, BlobDownloadRequest, BlobDownloadResponse, BlobExportRequest, BlobExportResponse, BlobListIncompleteRequest, BlobListRequest, BlobReadAtRequest, BlobReadAtResponse, BlobValidateRequest, CreateCollectionRequest, CreateCollectionResponse, DeleteTagRequest, DocExportFileRequest, - DocExportFileResponse, DocGetManyResponse, DocImportFileRequest, DocImportFileResponse, - DocListResponse, DocSetHashRequest, ListTagsRequest, NodeAddrRequest, - NodeConnectionInfoRequest, NodeConnectionInfoResponse, NodeConnectionsRequest, - NodeConnectionsResponse, NodeIdRequest, NodeRelayRequest, NodeShutdownRequest, - NodeStatsRequest, NodeStatsResponse, NodeStatusRequest, NodeWatchRequest, NodeWatchResponse, - Request, RpcService, SetTagOption, + DocExportFileResponse, DocImportFileRequest, DocImportFileResponse, DocSetHashRequest, + ListTagsRequest, NodeAddrRequest, NodeConnectionInfoRequest, NodeConnectionInfoResponse, + NodeConnectionsRequest, NodeConnectionsResponse, NodeIdRequest, NodeRelayRequest, + NodeShutdownRequest, NodeStatsRequest, NodeStatsResponse, NodeStatusRequest, NodeWatchRequest, + NodeWatchResponse, Request, RpcService, SetTagOption, }; -use crate::{ - client::blobs::{BlobInfo, DownloadMode, IncompleteBlobInfo, WrapOption}, - node::docs::DocsEngine, -}; -use crate::{client::tags::TagInfo, node::rpc::docs::ITER_CHANNEL_CAP}; -use crate::{client::NodeStatus, rpc_protocol::AuthorListResponse}; use super::NodeInner; @@ -68,24 +64,6 @@ pub(crate) struct Handler { } impl Handler { - fn docs(&self) -> Option<&DocsEngine> { - self.inner.docs.as_ref() - } - - async fn with_docs(self, f: F) -> RpcResult - where - T: Send + 'static, - F: FnOnce(DocsEngine) -> Fut, - Fut: std::future::Future>, - { - if let Some(docs) = self.docs() { - let docs = docs.clone(); - f(docs).await - } else { - Err(docs_disabled()) - } - } - pub(crate) fn handle_rpc_request>( &self, msg: Request, @@ -148,112 +126,92 @@ impl Handler { BlobAddStreamUpdate(_msg) => Err(RpcServerError::UnexpectedUpdateMessage), AuthorList(msg) => { chan.server_streaming(msg, handler, |handler, req| { - let (tx, rx) = flume::bounded(ITER_CHANNEL_CAP); - if let Some(docs) = handler.docs() { - docs.author_list(req, tx); - } else { - tx.send(Err(anyhow!("docs are disabled"))) - .expect("has capacity"); - } - rx.into_stream().map(|r| { - r.map(|author_id| AuthorListResponse { author_id }) - .map_err(Into::into) - }) + handler.inner.sync.author_list(req) }) .await } AuthorCreate(msg) => { - chan.rpc(msg, handler, |handler, req| { - handler.with_docs(|docs| async move { docs.author_create(req).await }) + chan.rpc(msg, handler, |handler, req| async move { + handler.inner.sync.author_create(req).await }) .await } AuthorImport(msg) => { - chan.rpc(msg, handler, |handler, req| { - handler.with_docs(|docs| async move { docs.author_import(req).await }) + chan.rpc(msg, handler, |handler, req| async move { + handler.inner.sync.author_import(req).await }) .await } AuthorExport(msg) => { - chan.rpc(msg, handler, |handler, req| { - handler.with_docs(|docs| async move { docs.author_export(req).await }) + chan.rpc(msg, handler, |handler, req| async move { + handler.inner.sync.author_export(req).await }) .await } AuthorDelete(msg) => { - chan.rpc(msg, handler, |handler, req| { - handler.with_docs(|docs| async move { docs.author_delete(req).await }) + chan.rpc(msg, handler, |handler, req| async move { + handler.inner.sync.author_delete(req).await }) .await } AuthorGetDefault(msg) => { - chan.rpc(msg, handler, |handler, req| { - handler.with_docs(|docs| async move { Ok(docs.author_default(req)) }) + chan.rpc(msg, handler, |handler, req| async move { + handler.inner.sync.author_default(req) }) .await } AuthorSetDefault(msg) => { - chan.rpc(msg, handler, |handler, req| { - handler.with_docs(|docs| async move { docs.author_set_default(req).await }) + chan.rpc(msg, handler, |handler, req| async move { + handler.inner.sync.author_set_default(req).await }) .await } DocOpen(msg) => { - chan.rpc(msg, handler, |handler, req| { - handler.with_docs(|docs| async move { docs.doc_open(req).await }) + chan.rpc(msg, handler, |handler, req| async move { + handler.inner.sync.doc_open(req).await }) .await } DocClose(msg) => { - chan.rpc(msg, handler, |handler, req| { - handler.with_docs(|docs| async move { docs.doc_close(req).await }) + chan.rpc(msg, handler, |handler, req| async move { + handler.inner.sync.doc_close(req).await }) .await } DocStatus(msg) => { - chan.rpc(msg, handler, |handler, req| { - handler.with_docs(|docs| async move { docs.doc_status(req).await }) + chan.rpc(msg, handler, |handler, req| async move { + handler.inner.sync.doc_status(req).await }) .await } DocList(msg) => { chan.server_streaming(msg, handler, |handler, req| { - let (tx, rx) = flume::bounded(ITER_CHANNEL_CAP); - if let Some(docs) = handler.docs() { - docs.doc_list(req, tx); - } else { - tx.send(Err(anyhow!("docs are disabled"))) - .expect("has capacity"); - } - rx.into_stream().map(|r| { - r.map(|(id, capability)| DocListResponse { id, capability }) - .map_err(Into::into) - }) + handler.inner.sync.doc_list(req) }) .await } DocCreate(msg) => { - chan.rpc(msg, handler, |handler, req| { - handler.with_docs(|docs| async move { docs.doc_create(req).await }) + chan.rpc(msg, handler, |handler, req| async move { + handler.inner.sync.doc_create(req).await }) .await } DocDrop(msg) => { - chan.rpc(msg, handler, |handler, req| { - handler.with_docs(|docs| async move { docs.doc_drop(req).await }) + chan.rpc(msg, handler, |handler, req| async move { + handler.inner.sync.doc_drop(req).await }) .await } DocImport(msg) => { - chan.rpc(msg, handler, |handler, req| { - handler.with_docs(|docs| async move { docs.doc_import(req).await }) + chan.rpc(msg, handler, |handler, req| async move { + handler.inner.sync.doc_import(req).await }) .await } DocSet(msg) => { let bao_store = handler.inner.db.clone(); - chan.rpc(msg, handler, |handler, req| { - handler.with_docs(|docs| async move { docs.doc_set(&bao_store, req).await }) + chan.rpc(msg, handler, |handler, req| async move { + handler.inner.sync.doc_set(&bao_store, req).await }) .await } @@ -266,86 +224,68 @@ impl Handler { .await } DocDel(msg) => { - chan.rpc(msg, handler, |handler, req| { - handler.with_docs(|docs| async move { docs.doc_del(req).await }) + chan.rpc(msg, handler, |handler, req| async move { + handler.inner.sync.doc_del(req).await }) .await } DocSetHash(msg) => { - chan.rpc(msg, handler, |handler, req| { - handler.with_docs(|docs| async move { docs.doc_set_hash(req).await }) + chan.rpc(msg, handler, |handler, req| async move { + handler.inner.sync.doc_set_hash(req).await }) .await } DocGet(msg) => { chan.server_streaming(msg, handler, |handler, req| { - let (tx, rx) = flume::bounded(ITER_CHANNEL_CAP); - if let Some(docs) = handler.docs() { - docs.doc_get_many(req, tx); - } else { - tx.send(Err(anyhow!("docs are disabled"))) - .expect("has capacity"); - } - rx.into_stream().map(|r| { - r.map(|entry| DocGetManyResponse { entry }) - .map_err(Into::into) - }) + handler.inner.sync.doc_get_many(req) }) .await } DocGetExact(msg) => { - chan.rpc(msg, handler, |handler, req| { - handler.with_docs(|docs| async move { docs.doc_get_exact(req).await }) + chan.rpc(msg, handler, |handler, req| async move { + handler.inner.sync.doc_get_exact(req).await }) .await } DocStartSync(msg) => { - chan.rpc(msg, handler, |handler, req| { - handler.with_docs(|docs| async move { docs.doc_start_sync(req).await }) + chan.rpc(msg, handler, |handler, req| async move { + handler.inner.sync.doc_start_sync(req).await }) .await } DocLeave(msg) => { - chan.rpc(msg, handler, |handler, req| { - handler.with_docs(|docs| async move { docs.doc_leave(req).await }) + chan.rpc(msg, handler, |handler, req| async move { + handler.inner.sync.doc_leave(req).await }) .await } DocShare(msg) => { - chan.rpc(msg, handler, |handler, req| { - handler.with_docs(|docs| async move { docs.doc_share(req).await }) + chan.rpc(msg, handler, |handler, req| async move { + handler.inner.sync.doc_share(req).await }) .await } DocSubscribe(msg) => { chan.try_server_streaming(msg, handler, |handler, req| async move { - if let Some(docs) = handler.docs() { - docs.doc_subscribe(req).await - } else { - Err(docs_disabled()) - } + handler.inner.sync.doc_subscribe(req).await }) .await } DocSetDownloadPolicy(msg) => { - chan.rpc(msg, handler, |handler, req| { - handler.with_docs( - |docs| async move { docs.doc_set_download_policy(req).await }, - ) + chan.rpc(msg, handler, |handler, req| async move { + handler.inner.sync.doc_set_download_policy(req).await }) .await } DocGetDownloadPolicy(msg) => { - chan.rpc(msg, handler, |handler, req| { - handler.with_docs( - |docs| async move { docs.doc_get_download_policy(req).await }, - ) + chan.rpc(msg, handler, |handler, req| async move { + handler.inner.sync.doc_get_download_policy(req).await }) .await } DocGetSyncPeers(msg) => { - chan.rpc(msg, handler, |handler, req| { - handler.with_docs(|docs| async move { docs.doc_get_sync_peers(req).await }) + chan.rpc(msg, handler, |handler, req| async move { + handler.inner.sync.doc_get_sync_peers(req).await }) .await } @@ -523,7 +463,6 @@ impl Handler { msg: DocImportFileRequest, progress: flume::Sender, ) -> anyhow::Result<()> { - let docs = self.docs().ok_or_else(|| anyhow!("docs are disabled"))?; use crate::client::docs::ImportProgress as DocImportProgress; use iroh_blobs::store::ImportMode; use std::collections::BTreeMap; @@ -576,14 +515,16 @@ impl Handler { let hash_and_format = temp_tag.inner(); let HashAndFormat { hash, .. } = *hash_and_format; - docs.doc_set_hash(DocSetHashRequest { - doc_id, - author_id, - key: key.clone(), - hash, - size, - }) - .await?; + self.inner + .sync + .doc_set_hash(DocSetHashRequest { + doc_id, + author_id, + key: key.clone(), + hash, + size, + }) + .await?; drop(temp_tag); progress.send(DocImportProgress::AllDone { key }).await?; Ok(()) @@ -608,7 +549,6 @@ impl Handler { msg: DocExportFileRequest, progress: flume::Sender, ) -> anyhow::Result<()> { - let _docs = self.docs().ok_or_else(|| anyhow!("docs are disabled"))?; let progress = FlumeProgressSender::new(progress); let DocExportFileRequest { entry, path, mode } = msg; let key = bytes::Bytes::from(entry.key().to_vec()); @@ -1178,7 +1118,3 @@ where res.map_err(Into::into) } - -fn docs_disabled() -> RpcError { - anyhow!("docs are disabled").into() -} diff --git a/iroh/src/node/rpc/docs.rs b/iroh/src/node/rpc/docs.rs index 3a405b585d..a0433a803e 100644 --- a/iroh/src/node/rpc/docs.rs +++ b/iroh/src/node/rpc/docs.rs @@ -3,30 +3,29 @@ use anyhow::anyhow; use futures_lite::Stream; use iroh_blobs::{store::Store as BaoStore, BlobFormat}; -use iroh_docs::{ - Author, AuthorId, CapabilityKind, DocTicket, NamespaceId, NamespaceSecret, SignedEntry, -}; +use iroh_docs::{Author, DocTicket, NamespaceSecret}; use tokio_stream::StreamExt; use crate::client::docs::ShareMode; -use crate::node::docs::DocsEngine; +use crate::node::DocsEngine; use crate::rpc_protocol::{ AuthorCreateRequest, AuthorCreateResponse, AuthorDeleteRequest, AuthorDeleteResponse, AuthorExportRequest, AuthorExportResponse, AuthorGetDefaultRequest, AuthorGetDefaultResponse, - AuthorImportRequest, AuthorImportResponse, AuthorListRequest, AuthorSetDefaultRequest, - AuthorSetDefaultResponse, DocCloseRequest, DocCloseResponse, DocCreateRequest, - DocCreateResponse, DocDelRequest, DocDelResponse, DocDropRequest, DocDropResponse, - DocGetDownloadPolicyRequest, DocGetDownloadPolicyResponse, DocGetExactRequest, - DocGetExactResponse, DocGetManyRequest, DocGetSyncPeersRequest, DocGetSyncPeersResponse, - DocImportRequest, DocImportResponse, DocLeaveRequest, DocLeaveResponse, DocListRequest, - DocOpenRequest, DocOpenResponse, DocSetDownloadPolicyRequest, DocSetDownloadPolicyResponse, - DocSetHashRequest, DocSetHashResponse, DocSetRequest, DocSetResponse, DocShareRequest, - DocShareResponse, DocStartSyncRequest, DocStartSyncResponse, DocStatusRequest, - DocStatusResponse, DocSubscribeRequest, DocSubscribeResponse, RpcResult, + AuthorImportRequest, AuthorImportResponse, AuthorListRequest, AuthorListResponse, + AuthorSetDefaultRequest, AuthorSetDefaultResponse, DocCloseRequest, DocCloseResponse, + DocCreateRequest, DocCreateResponse, DocDelRequest, DocDelResponse, DocDropRequest, + DocDropResponse, DocGetDownloadPolicyRequest, DocGetDownloadPolicyResponse, DocGetExactRequest, + DocGetExactResponse, DocGetManyRequest, DocGetManyResponse, DocGetSyncPeersRequest, + DocGetSyncPeersResponse, DocImportRequest, DocImportResponse, DocLeaveRequest, + DocLeaveResponse, DocListRequest, DocListResponse, DocOpenRequest, DocOpenResponse, + DocSetDownloadPolicyRequest, DocSetDownloadPolicyResponse, DocSetHashRequest, + DocSetHashResponse, DocSetRequest, DocSetResponse, DocShareRequest, DocShareResponse, + DocStartSyncRequest, DocStartSyncResponse, DocStatusRequest, DocStatusResponse, + DocSubscribeRequest, DocSubscribeResponse, RpcResult, }; /// Capacity for the flume channels to forward sync store iterators to async RPC streams. -pub(super) const ITER_CHANNEL_CAP: usize = 64; +const ITER_CHANNEL_CAP: usize = 64; #[allow(missing_docs)] impl DocsEngine { @@ -58,8 +57,8 @@ impl DocsEngine { pub fn author_list( &self, _req: AuthorListRequest, - tx: flume::Sender>, - ) { + ) -> impl Stream> { + let (tx, rx) = flume::bounded(ITER_CHANNEL_CAP); let sync = self.sync.clone(); // we need to spawn a task to send our request to the sync handle, because the method // itself must be sync. @@ -69,6 +68,10 @@ impl DocsEngine { tx2.send_async(Err(err)).await.ok(); } }); + rx.into_stream().map(|r| { + r.map(|author_id| AuthorListResponse { author_id }) + .map_err(Into::into) + }) } pub async fn author_import(&self, req: AuthorImportRequest) -> RpcResult { @@ -105,12 +108,8 @@ impl DocsEngine { Ok(DocDropResponse {}) } - pub fn doc_list( - &self, - _req: DocListRequest, - tx: flume::Sender>, - ) { - // let (tx, rx) = flume::bounded(ITER_CHANNEL_CAP); + pub fn doc_list(&self, _req: DocListRequest) -> impl Stream> { + let (tx, rx) = flume::bounded(ITER_CHANNEL_CAP); let sync = self.sync.clone(); // we need to spawn a task to send our request to the sync handle, because the method // itself must be sync. @@ -120,6 +119,10 @@ impl DocsEngine { tx2.send_async(Err(err)).await.ok(); } }); + rx.into_stream().map(|r| { + r.map(|(id, capability)| DocListResponse { id, capability }) + .map_err(Into::into) + }) } pub async fn doc_open(&self, req: DocOpenRequest) -> RpcResult { @@ -246,9 +249,9 @@ impl DocsEngine { pub fn doc_get_many( &self, req: DocGetManyRequest, - tx: flume::Sender>, - ) { + ) -> impl Stream> { let DocGetManyRequest { doc_id, query } = req; + let (tx, rx) = flume::bounded(ITER_CHANNEL_CAP); let sync = self.sync.clone(); // we need to spawn a task to send our request to the sync handle, because the method // itself must be sync. @@ -258,6 +261,10 @@ impl DocsEngine { tx2.send_async(Err(err)).await.ok(); } }); + rx.into_stream().map(|r| { + r.map(|entry| DocGetManyResponse { entry }) + .map_err(Into::into) + }) } pub async fn doc_get_exact(&self, req: DocGetExactRequest) -> RpcResult { diff --git a/iroh/src/rpc_protocol.rs b/iroh/src/rpc_protocol.rs index 8334590a11..8fe71e7d6a 100644 --- a/iroh/src/rpc_protocol.rs +++ b/iroh/src/rpc_protocol.rs @@ -439,7 +439,7 @@ pub struct AuthorCreateResponse { pub struct AuthorGetDefaultRequest; impl RpcMsg for AuthorGetDefaultRequest { - type Response = RpcResult; + type Response = AuthorGetDefaultResponse; } /// Response for [`AuthorGetDefaultRequest`] @@ -1153,7 +1153,7 @@ pub enum Response { AuthorList(RpcResult), AuthorCreate(RpcResult), - AuthorGetDefault(RpcResult), + AuthorGetDefault(AuthorGetDefaultResponse), AuthorSetDefault(RpcResult), AuthorImport(RpcResult), AuthorExport(RpcResult), diff --git a/iroh/tests/gc.rs b/iroh/tests/gc.rs index e032691df9..dcca0893b5 100644 --- a/iroh/tests/gc.rs +++ b/iroh/tests/gc.rs @@ -6,7 +6,7 @@ use std::{ use anyhow::Result; use bao_tree::{blake3, io::sync::Outboard, ChunkRanges}; use bytes::Bytes; -use iroh::node::{self, DocsStorage, Node}; +use iroh::node::{self, Node}; use rand::RngCore; use iroh_blobs::{ @@ -41,19 +41,17 @@ async fn wrap_in_node(bao_store: S, gc_period: Duration) -> (Node, flume:: where S: iroh_blobs::store::Store, { + let doc_store = iroh_docs::store::Store::memory(); let (gc_send, gc_recv) = flume::unbounded(); - let node = node::Builder::with_db_and_store( - bao_store, - DocsStorage::Memory, - iroh::node::StorageConfig::Mem, - ) - .gc_policy(iroh::node::GcPolicy::Interval(gc_period)) - .register_gc_done_cb(Box::new(move || { - gc_send.send(()).ok(); - })) - .spawn() - .await - .unwrap(); + let node = + node::Builder::with_db_and_store(bao_store, doc_store, iroh::node::StorageConfig::Mem) + .gc_policy(iroh::node::GcPolicy::Interval(gc_period)) + .register_gc_done_cb(Box::new(move || { + gc_send.send(()).ok(); + })) + .spawn() + .await + .unwrap(); (node, gc_recv) } diff --git a/iroh/tests/provide.rs b/iroh/tests/provide.rs index 7b9abf9648..13376273dd 100644 --- a/iroh/tests/provide.rs +++ b/iroh/tests/provide.rs @@ -8,7 +8,7 @@ use std::{ use anyhow::{Context, Result}; use bytes::Bytes; use futures_lite::FutureExt; -use iroh::node::{Builder, DocsStorage}; +use iroh::node::Builder; use iroh_base::node_addr::AddrInfoOptions; use iroh_net::{defaults::default_relay_map, key::SecretKey, NodeAddr, NodeId}; use quic_rpc::transport::misc::DummyServerEndpoint; @@ -40,8 +40,8 @@ async fn dial(secret_key: SecretKey, peer: NodeAddr) -> anyhow::Result(db: D) -> Builder { - iroh::node::Builder::with_db_and_store(db, DocsStorage::Memory, iroh::node::StorageConfig::Mem) - .bind_port(0) + let store = iroh_docs::store::Store::memory(); + iroh::node::Builder::with_db_and_store(db, store, iroh::node::StorageConfig::Mem).bind_port(0) } #[tokio::test] From 766c8e5a4efce9922633f339ee79763efdd954ee Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Tue, 18 Jun 2024 23:23:34 +0200 Subject: [PATCH 27/33] refactor: concurrent shutdown --- iroh/src/node/builder.rs | 32 ++++++++++++++++---------------- iroh/src/node/protocol.rs | 13 +++++++++---- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index 0ac4185a89..45d01ec850 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -534,7 +534,7 @@ where let cancel_token = handler.inner.cancel_token.clone(); - // forward our initial endpoints to the gossip protocol + // forward the initial endpoints to the gossip protocol. // it may happen the the first endpoint update callback is missed because the gossip cell // is only initialized once the endpoint is fully bound if let Some(local_endpoints) = endpoint.local_endpoints().next().await { @@ -590,22 +590,22 @@ where } } - // Closing the Endpoint is the equivalent of calling Connection::close on all - // connections: Operations will immediately fail with - // ConnectionError::LocallyClosed. All streams are interrupted, this is not - // graceful. + // Shutdown the different parts of the node concurrently. let error_code = Closed::ProviderTerminating; - endpoint - .close(error_code.into(), error_code.reason()) - .await - .ok(); - - // Shutdown sync and blobs. - inner.sync.shutdown().await.ok(); - inner.db.shutdown().await; - - // Shutdown protocol handlers. - protocols.shutdown().await; + // We ignore all errors during shutdown. + let _ = tokio::join!( + // Close the endpoint. + // Closing the Endpoint is the equivalent of calling Connection::close on all + // connections: Operations will immediately fail with ConnectionError::LocallyClosed. + // All streams are interrupted, this is not graceful. + endpoint.close(error_code.into(), error_code.reason()), + // Shutdown sync engine. + inner.sync.shutdown(), + // Shutdown blobs store engine. + inner.db.shutdown(), + // Shutdown protocol handlers. + protocols.shutdown(), + ); // Abort remaining tasks. join_set.shutdown().await; diff --git a/iroh/src/node/protocol.rs b/iroh/src/node/protocol.rs index 90df8de350..528014437d 100644 --- a/iroh/src/node/protocol.rs +++ b/iroh/src/node/protocol.rs @@ -2,6 +2,7 @@ use std::{any::Any, collections::BTreeMap, fmt, sync::Arc}; use anyhow::Result; use futures_lite::future::Boxed as BoxedFuture; +use futures_util::future::join_all; use iroh_net::endpoint::Connecting; use crate::node::DocsEngine; @@ -44,23 +45,27 @@ impl ProtocolMap { Some(protocol_ref) } + /// Returns the registered protocol handler for an ALPN as a [`Arc`]. pub fn get(&self, alpn: &[u8]) -> Option> { self.0.get(alpn).cloned() } + /// Insert a protocol handler. pub fn insert(&mut self, alpn: &'static [u8], handler: Arc) { self.0.insert(alpn, handler); } + /// Returns an iterator of all registered ALPN protocol identifiers. pub fn alpns(&self) -> impl Iterator { self.0.keys() } - /// Shutdown the protocol handlers. + /// Shutdown all protocol handlers. + /// + /// Calls and awaits [`Protocol::shutdown`] for all registered handlers concurrently. pub async fn shutdown(&self) { - for handler in self.0.values() { - handler.clone().shutdown().await; - } + let handlers = self.0.values().cloned().map(Protocol::shutdown); + join_all(handlers).await; } } From d1c75e331e459b3848beec77d1a3783478787efc Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Tue, 18 Jun 2024 23:31:20 +0200 Subject: [PATCH 28/33] fix: no need for empty alpn call in endpoint builder --- iroh/src/node/builder.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index 45d01ec850..aa5d21142d 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -419,7 +419,6 @@ where let endpoint = Endpoint::builder() .secret_key(self.secret_key.clone()) .proxy_from_env() - .alpns(vec![]) .keylog(self.keylog) .transport_config(transport_config) .concurrent_connections(MAX_CONNECTIONS) From c1d0e5d918b488ad6f4573fa3d619c863d7b2488 Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Tue, 18 Jun 2024 23:54:46 +0200 Subject: [PATCH 29/33] fix: doctest --- iroh/src/node/builder.rs | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index c8dbc5ce85..420d693bc0 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -733,7 +733,7 @@ impl> UnspawnedNode< /// /// #[derive(Debug)] /// struct MyProtocol { - /// iroh: MemIroh + /// client: MemIroh /// } /// /// impl Protocol for MyProtocol { @@ -742,10 +742,15 @@ impl> UnspawnedNode< /// } /// } /// - /// let node = Node::memory() + /// let unspawned_node = Node::memory() /// .build() - /// .await? - /// .accept(MY_ALPN, |_node| Arc::new(MyProtocol::build(node.client()))) + /// .await?; + /// + /// let client = unspawned_node.client().clone(); + /// let handler = MyProtocol { client }; + /// + /// let node = unspawned_node + /// .accept(MY_ALPN, Arc::new(handler)) /// .spawn() /// .await?; /// # node.shutdown().await?; From 34ea612e732dfa02ddb68841f2d7c42cff11f0c9 Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Wed, 19 Jun 2024 12:13:31 +0200 Subject: [PATCH 30/33] address PR review --- iroh-net/src/endpoint.rs | 8 ++++---- iroh/src/node.rs | 5 ++++- iroh/src/node/builder.rs | 16 ++++++++++++---- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/iroh-net/src/endpoint.rs b/iroh-net/src/endpoint.rs index b8a81342de..b741f47178 100644 --- a/iroh-net/src/endpoint.rs +++ b/iroh-net/src/endpoint.rs @@ -303,8 +303,8 @@ struct StaticConfig { } impl StaticConfig { - /// Build a [`quinn::ServerConfig`] with the specified ALPN protocols. - fn build(&self, alpn_protocols: Vec>) -> Result { + /// Create a [`quinn::ServerConfig`] with the specified ALPN protocols. + fn create_server_config(&self, alpn_protocols: Vec>) -> Result { let mut server_config = make_server_config( &self.secret_key, alpn_protocols, @@ -388,7 +388,7 @@ impl Endpoint { let msock = magicsock::MagicSock::spawn(msock_opts).await?; trace!("created magicsock"); - let server_config = static_config.build(initial_alpns)?; + let server_config = static_config.create_server_config(initial_alpns)?; let mut endpoint_config = quinn::EndpointConfig::default(); // Setting this to false means that quinn will ignore packets that have the QUIC fixed bit @@ -420,7 +420,7 @@ impl Endpoint { /// This will only affect new incoming connections. /// Note that this *overrides* the current list of ALPNs. pub fn set_alpns(&self, alpns: Vec>) -> Result<()> { - let server_config = self.static_config.build(alpns)?; + let server_config = self.static_config.create_server_config(alpns)?; self.endpoint.set_server_config(Some(server_config)); Ok(()) } diff --git a/iroh/src/node.rs b/iroh/src/node.rs index c88de2aa9e..93f45001c3 100644 --- a/iroh/src/node.rs +++ b/iroh/src/node.rs @@ -176,7 +176,10 @@ impl Node { self.inner.cancel_token.clone() } - /// Get a protocol handler. + /// Returns a protocol handler for an ALPN. + /// + /// This downcasts to the concrete type and returns `None` if the handler registered for `alpn` + /// does not match the passed type. pub fn get_protocol(&self, alpn: &[u8]) -> Option> { self.protocols.get_typed(alpn) } diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index 420d693bc0..eb0ef4f9a2 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -696,8 +696,8 @@ where /// [`Self::accept`]. It provides access to the services which are already started, the node's /// endpoint and a client to the node. /// -/// Note that the client returned from [`Self::client`] can only be used after spawning the node, -/// until then all RPC calls will time out. +/// Note that RPC calls performed with client returned from [`Self::client`] will not complete +/// until the node is spawned. #[derive(derive_more::Debug)] pub struct UnspawnedNode { inner: Arc>, @@ -766,8 +766,8 @@ impl> UnspawnedNode< /// Return a client to control this node over an in-memory channel. /// - /// Note that the client can only be used after spawning the node, - /// until then all RPC calls will time out. + /// Note that RPC calls performed with the client will not complete until the node is + /// spawned. pub fn client(&self) -> &crate::client::MemIroh { &self.client } @@ -797,6 +797,14 @@ impl> UnspawnedNode< &self.inner.gossip } + /// Returns a protocol handler for an ALPN. + /// + /// This downcasts to the concrete type and returns `None` if the handler registered for `alpn` + /// does not match the passed type. + pub fn get_protocol(&self, alpn: &[u8]) -> Option> { + self.protocols.get_typed(alpn) + } + /// Register the core iroh protocols (blobs, gossip, docs). fn register_iroh_protocols(mut self) -> Self { // Register blobs. From 823005f10c8087b8ecbec1c68524c3915b6ea9d7 Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Wed, 19 Jun 2024 12:19:35 +0200 Subject: [PATCH 31/33] rename UnspawnedNode to ProtocolBuilder --- iroh/src/node/builder.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index eb0ef4f9a2..9f6c595cb3 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -379,9 +379,9 @@ where /// Build a node without spawning it. /// - /// Returns an `UnspawnedNode`, on which custom protocols can be registered with - /// [`UnspawnedNode::accept`]. To spawn the node, call [`UnspawnedNode::spawn`]. - pub async fn build(self) -> Result> { + /// Returns an `ProtocolBuilder`, on which custom protocols can be registered with + /// [`ProtocolBuilder::accept`]. To spawn the node, call [`ProtocolBuilder::spawn`]. + pub async fn build(self) -> Result> { // Clone the blob store to shutdown in case of error. let blobs_store = self.blobs_store.clone(); match self.build_inner().await { @@ -393,7 +393,7 @@ where } } - async fn build_inner(self) -> Result> { + async fn build_inner(self) -> Result> { trace!("building node"); let lp = LocalPoolHandle::new(num_cpus::get()); @@ -495,7 +495,7 @@ where gossip, }); - let node = UnspawnedNode { + let node = ProtocolBuilder { inner, client, protocols: Default::default(), @@ -699,7 +699,7 @@ where /// Note that RPC calls performed with client returned from [`Self::client`] will not complete /// until the node is spawned. #[derive(derive_more::Debug)] -pub struct UnspawnedNode { +pub struct ProtocolBuilder { inner: Arc>, client: crate::client::MemIroh, internal_rpc: FlumeServerEndpoint, @@ -710,7 +710,7 @@ pub struct UnspawnedNode { gc_policy: GcPolicy, } -impl> UnspawnedNode { +impl> ProtocolBuilder { /// Register a protocol handler for incoming connections. /// /// Use this to register custom protocols onto the iroh node. Whenever a new connection for From b58d79e67a384f51814ae431a1478856c8996254 Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Wed, 19 Jun 2024 12:34:45 +0200 Subject: [PATCH 32/33] expand docs --- iroh/src/node/protocol.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/iroh/src/node/protocol.rs b/iroh/src/node/protocol.rs index 528014437d..ac66b3b06e 100644 --- a/iroh/src/node/protocol.rs +++ b/iroh/src/node/protocol.rs @@ -8,6 +8,15 @@ use iroh_net::endpoint::Connecting; use crate::node::DocsEngine; /// Handler for incoming connections. +/// +/// An iroh node can accept connections for arbitrary ALPN protocols. By default, the iroh node +/// only accepts connections for the ALPNs of the core iroh protocols (blobs, gossip, docs). +/// +/// With this trait, you can handle incoming connections for custom protocols. +/// +/// Implement this trait on a struct that should handle incoming connections. +/// The protocol handler must then be registered on the node for an ALPN protocol with +/// [`crate::node::builder::ProtocolBuilder::accept`]. pub trait Protocol: Send + Sync + IntoArcAny + fmt::Debug + 'static { /// Handle an incoming connection. /// From e3b042d646989eecec1b394174d239398ed1b2b0 Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Wed, 19 Jun 2024 12:52:59 +0200 Subject: [PATCH 33/33] rename trait Protocol to ProtocolHandler --- iroh/examples/custom-protocol.rs | 4 ++-- iroh/src/node.rs | 4 ++-- iroh/src/node/builder.rs | 12 ++++++------ iroh/src/node/protocol.rs | 24 ++++++++++++------------ 4 files changed, 22 insertions(+), 22 deletions(-) diff --git a/iroh/examples/custom-protocol.rs b/iroh/examples/custom-protocol.rs index 78ae274838..4a12687725 100644 --- a/iroh/examples/custom-protocol.rs +++ b/iroh/examples/custom-protocol.rs @@ -9,7 +9,7 @@ use iroh::{ endpoint::{get_remote_node_id, Connecting}, Endpoint, NodeId, }, - node::Protocol, + node::ProtocolHandler, }; use tracing_subscriber::{prelude::*, EnvFilter}; @@ -63,7 +63,7 @@ struct ExampleProto { endpoint: Endpoint, } -impl Protocol for ExampleProto { +impl ProtocolHandler for ExampleProto { fn accept(self: Arc, connecting: Connecting) -> BoxedFuture> { Box::pin(async move { let connection = connecting.await?; diff --git a/iroh/src/node.rs b/iroh/src/node.rs index 93f45001c3..ae9a5ddb69 100644 --- a/iroh/src/node.rs +++ b/iroh/src/node.rs @@ -33,7 +33,7 @@ mod rpc_status; pub use self::builder::{Builder, DiscoveryConfig, GcPolicy, StorageConfig}; pub use self::rpc_status::RpcStatus; -pub use protocol::Protocol; +pub use protocol::ProtocolHandler; /// A server which implements the iroh node. /// @@ -180,7 +180,7 @@ impl Node { /// /// This downcasts to the concrete type and returns `None` if the handler registered for `alpn` /// does not match the passed type. - pub fn get_protocol(&self, alpn: &[u8]) -> Option> { + pub fn get_protocol(&self, alpn: &[u8]) -> Option> { self.protocols.get_typed(alpn) } } diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index 9f6c595cb3..5a266127cf 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -38,7 +38,7 @@ use crate::{ client::RPC_ALPN, node::{ protocol::{BlobsProtocol, ProtocolMap}, - Protocol, + ProtocolHandler, }, rpc_protocol::RpcService, util::{fs::load_secret_key, path::IrohPaths}, @@ -716,7 +716,7 @@ impl> ProtocolBuilde /// Use this to register custom protocols onto the iroh node. Whenever a new connection for /// `alpn` comes in, it is passed to this protocol handler. /// - /// See the [`Protocol`] trait for details. + /// See the [`ProtocolHandler`] trait for details. /// /// Example usage: /// @@ -724,7 +724,7 @@ impl> ProtocolBuilde /// # use std::sync::Arc; /// # use anyhow::Result; /// # use futures_lite::future::Boxed as BoxedFuture; - /// # use iroh::{node::{Node, Protocol}, net::endpoint::Connecting, client::MemIroh}; + /// # use iroh::{node::{Node, ProtocolHandler}, net::endpoint::Connecting, client::MemIroh}; /// # /// # #[tokio::main] /// # async fn main() -> Result<()> { @@ -736,7 +736,7 @@ impl> ProtocolBuilde /// client: MemIroh /// } /// - /// impl Protocol for MyProtocol { + /// impl ProtocolHandler for MyProtocol { /// fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { /// todo!(); /// } @@ -759,7 +759,7 @@ impl> ProtocolBuilde /// ``` /// /// - pub fn accept(mut self, alpn: &'static [u8], handler: Arc) -> Self { + pub fn accept(mut self, alpn: &'static [u8], handler: Arc) -> Self { self.protocols.insert(alpn, handler); self } @@ -801,7 +801,7 @@ impl> ProtocolBuilde /// /// This downcasts to the concrete type and returns `None` if the handler registered for `alpn` /// does not match the passed type. - pub fn get_protocol(&self, alpn: &[u8]) -> Option> { + pub fn get_protocol(&self, alpn: &[u8]) -> Option> { self.protocols.get_typed(alpn) } diff --git a/iroh/src/node/protocol.rs b/iroh/src/node/protocol.rs index ac66b3b06e..25106e7c38 100644 --- a/iroh/src/node/protocol.rs +++ b/iroh/src/node/protocol.rs @@ -17,7 +17,7 @@ use crate::node::DocsEngine; /// Implement this trait on a struct that should handle incoming connections. /// The protocol handler must then be registered on the node for an ALPN protocol with /// [`crate::node::builder::ProtocolBuilder::accept`]. -pub trait Protocol: Send + Sync + IntoArcAny + fmt::Debug + 'static { +pub trait ProtocolHandler: Send + Sync + IntoArcAny + fmt::Debug + 'static { /// Handle an incoming connection. /// /// This runs on a freshly spawned tokio task so this can be long-running. @@ -43,24 +43,24 @@ impl IntoArcAny for T { } #[derive(Debug, Clone, Default)] -pub(super) struct ProtocolMap(BTreeMap<&'static [u8], Arc>); +pub(super) struct ProtocolMap(BTreeMap<&'static [u8], Arc>); impl ProtocolMap { /// Returns the registered protocol handler for an ALPN as a concrete type. - pub fn get_typed(&self, alpn: &[u8]) -> Option> { - let protocol: Arc = self.0.get(alpn)?.clone(); + pub fn get_typed(&self, alpn: &[u8]) -> Option> { + let protocol: Arc = self.0.get(alpn)?.clone(); let protocol_any: Arc = protocol.into_arc_any(); let protocol_ref = Arc::downcast(protocol_any).ok()?; Some(protocol_ref) } - /// Returns the registered protocol handler for an ALPN as a [`Arc`]. - pub fn get(&self, alpn: &[u8]) -> Option> { + /// Returns the registered protocol handler for an ALPN as a [`Arc`]. + pub fn get(&self, alpn: &[u8]) -> Option> { self.0.get(alpn).cloned() } /// Insert a protocol handler. - pub fn insert(&mut self, alpn: &'static [u8], handler: Arc) { + pub fn insert(&mut self, alpn: &'static [u8], handler: Arc) { self.0.insert(alpn, handler); } @@ -71,9 +71,9 @@ impl ProtocolMap { /// Shutdown all protocol handlers. /// - /// Calls and awaits [`Protocol::shutdown`] for all registered handlers concurrently. + /// Calls and awaits [`ProtocolHandler::shutdown`] for all registered handlers concurrently. pub async fn shutdown(&self) { - let handlers = self.0.values().cloned().map(Protocol::shutdown); + let handlers = self.0.values().cloned().map(ProtocolHandler::shutdown); join_all(handlers).await; } } @@ -90,7 +90,7 @@ impl BlobsProtocol { } } -impl Protocol for BlobsProtocol { +impl ProtocolHandler for BlobsProtocol { fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { Box::pin(async move { iroh_blobs::provider::handle_connection( @@ -114,13 +114,13 @@ impl iroh_blobs::provider::EventSender for MockEventSender { } } -impl Protocol for iroh_gossip::net::Gossip { +impl ProtocolHandler for iroh_gossip::net::Gossip { fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { Box::pin(async move { self.handle_connection(conn.await?).await }) } } -impl Protocol for DocsEngine { +impl ProtocolHandler for DocsEngine { fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { Box::pin(async move { self.handle_connection(conn).await }) }