diff --git a/iroh/src/client.rs b/iroh/src/client.rs index 4148af526d..a5ee335909 100644 --- a/iroh/src/client.rs +++ b/iroh/src/client.rs @@ -3,22 +3,28 @@ //! TODO: Contains only iroh sync related methods. Add other methods. use std::collections::HashMap; +use std::io; +use std::pin::Pin; use std::result::Result as StdResult; +use std::task::{Context, Poll}; use anyhow::{anyhow, Result}; use bytes::Bytes; +use futures::stream::BoxStream; use futures::{Stream, StreamExt, TryStreamExt}; use iroh_bytes::Hash; use iroh_net::{key::PublicKey, magic_endpoint::ConnectionInfo}; use iroh_sync::{store::GetFilter, AuthorId, Entry, NamespaceId}; use quic_rpc::{RpcClient, ServiceConnection}; +use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf}; +use tokio_util::io::StreamReader; use crate::rpc_protocol::{ - AuthorCreateRequest, AuthorListRequest, BytesGetRequest, ConnectionInfoRequest, - ConnectionInfoResponse, ConnectionsRequest, CounterStats, DocCreateRequest, DocGetManyRequest, - DocGetOneRequest, DocImportRequest, DocInfoRequest, DocListRequest, DocSetRequest, - DocShareRequest, DocStartSyncRequest, DocStopSyncRequest, DocSubscribeRequest, DocTicket, - ProviderService, ShareMode, StatsGetRequest, + AuthorCreateRequest, AuthorListRequest, BytesGetRequest, BytesGetResponse, + ConnectionInfoRequest, ConnectionInfoResponse, ConnectionsRequest, CounterStats, + DocCreateRequest, DocGetManyRequest, DocGetOneRequest, DocImportRequest, DocInfoRequest, + DocListRequest, DocSetRequest, DocShareRequest, DocStartSyncRequest, DocStopSyncRequest, + DocSubscribeRequest, DocTicket, ProviderService, ShareMode, StatsGetRequest, }; use crate::sync_engine::{LiveEvent, LiveStatus, PeerSource}; @@ -102,10 +108,14 @@ where /// Get the bytes for a hash. /// /// Note: This reads the full blob into memory. - // TODO: add get_reader for streaming gets pub async fn get_bytes(&self, hash: Hash) -> Result { - let res = self.rpc.rpc(BytesGetRequest { hash }).await??; - Ok(res.data) + let mut stream = self.get_bytes_stream(hash).await?; + stream.read_to_end().await + } + + /// Get the bytes for a hash. + pub async fn get_bytes_stream(&self, hash: Hash) -> Result { + BlobReader::from_rpc(&self.rpc, hash).await } /// Get statistics of the running node. @@ -128,6 +138,74 @@ where } } +/// Data reader for a single blob. +/// +/// Implements [`AsyncRead`]. +pub struct BlobReader { + size: u64, + is_complete: bool, + stream: tokio_util::io::StreamReader>, Bytes>, +} +impl BlobReader { + fn new(size: u64, is_complete: bool, stream: BoxStream<'static, io::Result>) -> Self { + Self { + size, + is_complete, + stream: StreamReader::new(stream), + } + } + + async fn from_rpc>( + rpc: &RpcClient, + hash: Hash, + ) -> anyhow::Result { + let stream = rpc.server_streaming(BytesGetRequest { hash }).await?; + let mut stream = flatten(stream); + + let (size, is_complete) = match stream.next().await { + Some(Ok(BytesGetResponse::Entry { size, is_complete })) => (size, is_complete), + Some(Err(err)) => return Err(err.into()), + None | Some(Ok(_)) => return Err(anyhow!("Expected header frame")), + }; + + let stream = stream.map(|item| match item { + Ok(BytesGetResponse::Data { chunk }) => Ok(chunk), + Ok(_) => Err(io::Error::new(io::ErrorKind::Other, "Expected data frame")), + Err(err) => Err(io::Error::new(io::ErrorKind::Other, format!("{err}"))), + }); + Ok(Self::new(size, is_complete, stream.boxed())) + } + + /// Total size of this blob. + pub fn size(&self) -> u64 { + self.size + } + + /// Whether this blob has been downloaded completely. + /// + /// Returns false for partial blobs for which some chunks are missing. + pub fn is_complete(&self) -> bool { + self.is_complete + } + + /// Read all bytes of the blob. + pub async fn read_to_end(&mut self) -> anyhow::Result { + let mut buf = Vec::with_capacity(self.size() as usize); + AsyncReadExt::read_to_end(self, &mut buf).await?; + Ok(buf.into()) + } +} + +impl AsyncRead for BlobReader { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.stream).poll_read(cx, buf) + } +} + /// Document handle #[derive(Debug, Clone)] pub struct Doc { @@ -164,10 +242,14 @@ where } /// Get the contents of an entry as a byte array. - // TODO: add get_content_reader pub async fn get_content_bytes(&self, hash: Hash) -> Result { - let bytes = self.rpc.rpc(BytesGetRequest { hash }).await??; - Ok(bytes.data) + let mut stream = BlobReader::from_rpc(&self.rpc, hash).await?; + stream.read_to_end().await + } + + /// Get the contents of an entry as a [`BlobReader`]. + pub async fn get_content_reader(&self, hash: Hash) -> Result { + BlobReader::from_rpc(&self.rpc, hash).await } /// Get the latest entry for a key and author. diff --git a/iroh/src/node.rs b/iroh/src/node.rs index 928c051d2c..130c50471a 100644 --- a/iroh/src/node.rs +++ b/iroh/src/node.rs @@ -27,7 +27,7 @@ use iroh_bytes::get::Stats; use iroh_bytes::protocol::GetRequest; use iroh_bytes::provider::ShareProgress; use iroh_bytes::util::progress::{FlumeProgressSender, IdGenerator, ProgressSender}; -use iroh_bytes::util::{RpcError, RpcResult}; +use iroh_bytes::util::RpcResult; use iroh_bytes::{ protocol::{Closed, Request, RequestToken}, provider::{CustomGetHandler, ProvideProgress, RequestAuthorizationHandler}, @@ -35,7 +35,7 @@ use iroh_bytes::{ util::Hash, }; use iroh_gossip::net::{Gossip, GOSSIP_ALPN}; -use iroh_io::AsyncSliceReaderExt; +use iroh_io::{AsyncSliceReader, AsyncSliceReaderExt}; use iroh_net::defaults::default_derp_map; use iroh_net::magic_endpoint::get_alpn; use iroh_net::{ @@ -79,6 +79,11 @@ pub const DEFAULT_BIND_ADDR: (Ipv4Addr, u16) = (Ipv4Addr::LOCALHOST, 11204); /// How long we wait at most for some endpoints to be discovered. const ENDPOINT_WAIT: Duration = Duration::from_secs(5); +/// Chunk size for getting blobs over RPC +const RPC_BLOB_GET_CHUNK_SIZE: usize = 1024 * 64; +/// Channel cap for getting blobs over RPC +const RPC_BLOB_GET_CHANNEL_CAP: usize = 2; + /// Builder for the [`Node`]. /// /// You must supply a blob store. Various store implementations are available @@ -1141,29 +1146,48 @@ impl RpcHandler { }) } - // TODO: streaming - async fn bytes_get(self, req: BytesGetRequest) -> RpcResult { - let entry = self - .inner - .db - .get(&req.hash) - .ok_or_else(|| RpcError::from(anyhow!("not found")))?; - // TODO: size limit - // TODO: streaming - let data = self.inner.rt.local_pool().spawn_pinned(|| async move { - let data = entry - .data_reader() - .await - .map_err(anyhow::Error::from)? - .read_to_end() - .await - .map_err(anyhow::Error::from)?; - Result::<_, anyhow::Error>::Ok(data) + fn bytes_get( + self, + req: BytesGetRequest, + ) -> impl Stream> + Send + 'static { + let (tx, rx) = flume::bounded(RPC_BLOB_GET_CHANNEL_CAP); + let entry = self.inner.db.get(&req.hash); + self.inner.rt.local_pool().spawn_pinned(move || async move { + if let Err(err) = read_loop(entry, tx.clone(), RPC_BLOB_GET_CHUNK_SIZE).await { + tx.send_async(RpcResult::Err(err.into())).await.ok(); + } }); - let data = data - .await - .map_err(|_err| anyhow::anyhow!("task failed to complete"))??; - Ok(BytesGetResponse { data }) + + async fn read_loop( + entry: Option>, + tx: flume::Sender>, + chunk_size: usize, + ) -> anyhow::Result<()> { + let entry = entry.ok_or_else(|| anyhow!("Blob not found"))?; + let size = entry.size(); + tx.send_async(Ok(BytesGetResponse::Entry { + size, + is_complete: entry.is_complete(), + })) + .await?; + let mut reader = entry.data_reader().await?; + let mut offset = 0u64; + loop { + let chunk = reader.read_at(offset, chunk_size).await?; + let len = chunk.len(); + if !chunk.is_empty() { + tx.send_async(Ok(BytesGetResponse::Data { chunk })).await?; + } + if len < chunk_size { + break; + } else { + offset += len as u64; + } + } + Ok(()) + } + + rx.into_stream() } fn connections( @@ -1329,8 +1353,10 @@ fn handle_rpc_request< .await } ConnectionInfo(msg) => chan.rpc(msg, handler, RpcHandler::connection_info).await, - // TODO: make streaming - BytesGet(msg) => chan.rpc(msg, handler, RpcHandler::bytes_get).await, + BytesGet(msg) => { + chan.server_streaming(msg, handler, RpcHandler::bytes_get) + .await + } } }); } diff --git a/iroh/src/rpc_protocol.rs b/iroh/src/rpc_protocol.rs index 7015f572f4..d16b684393 100644 --- a/iroh/src/rpc_protocol.rs +++ b/iroh/src/rpc_protocol.rs @@ -615,15 +615,29 @@ pub struct BytesGetRequest { pub hash: Hash, } -impl RpcMsg for BytesGetRequest { +impl Msg for BytesGetRequest { + type Pattern = ServerStreaming; +} + +impl ServerStreamingMsg for BytesGetRequest { type Response = RpcResult; } /// Response to [`BytesGetRequest`] #[derive(Serialize, Deserialize, Debug)] -pub struct BytesGetResponse { - /// The blob data - pub data: Bytes, +pub enum BytesGetResponse { + /// The entry header. + Entry { + /// The size of the blob + size: u64, + /// Wether the blob is complete + is_complete: bool, + }, + /// Chunks of entry data. + Data { + /// The data chunk + chunk: Bytes, + }, } /// Get stats for the running Iroh node