Skip to content

Commit

Permalink
fix: fail if an on_blob function does not read all the content (#139)
Browse files Browse the repository at this point in the history
closes #127
  • Loading branch information
ramfox authored Feb 13, 2023
1 parent fc8fbeb commit c266ab5
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 3 deletions.
9 changes: 7 additions & 2 deletions src/get.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use anyhow::{anyhow, bail, ensure, Result};
use bytes::BytesMut;
use futures::Future;
use postcard::experimental::max_size::MaxSize;
use tokio::io::{AsyncRead, ReadBuf};
use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf};
use tracing::debug;

use crate::bao_slice_decoder::AsyncSliceDecoder;
Expand Down Expand Up @@ -207,7 +207,12 @@ where
"downloaded more than {total_blobs_size}"
);
remaining_size -= size;
let blob_reader = on_blob(blob.hash, blob_reader, blob.name).await?;
let mut blob_reader =
on_blob(blob.hash, blob_reader, blob.name).await?;

if blob_reader.read_exact(&mut [0u8; 1]).await.is_ok() {
bail!("`on_blob` callback did not fully read the blob content")
}
reader = blob_reader.into_inner();
}
}
Expand Down
51 changes: 50 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ mod tests {
use rand::RngCore;
use testdir::testdir;
use tokio::fs;
use tokio::io::{self, AsyncReadExt};
use tokio::io::{self, AsyncReadExt, AsyncWriteExt};

use crate::protocol::AuthToken;
use crate::provider::{create_collection, Event, Provider};
Expand Down Expand Up @@ -329,4 +329,53 @@ mod tests {
// Unwrap the JoinHandle, then the result of the Provider
supervisor.await.unwrap().unwrap();
}

#[tokio::test]
async fn test_blob_reader_partial() -> Result<()> {
// Prepare a Provider transferring a file.
let dir = testdir!();
let src0 = dir.join("src0");
let src1 = dir.join("src1");
{
let content = vec![1u8; 1000];
let mut f = tokio::fs::File::create(&src0).await?;
for _ in 0..10 {
f.write_all(&content).await?;
}
}
fs::write(&src1, "hello world").await?;
let (db, hash) = create_collection(vec![src0.into(), src1.into()]).await?;
let provider = Provider::builder(db)
.bind_addr("127.0.0.1:0".parse().unwrap())
.spawn()?;
let auth_token = provider.auth_token();
let provider_addr = provider.listen_addr();

let timeout = tokio::time::timeout(
std::time::Duration::from_secs(10),
get::run(
hash,
auth_token,
get::Options {
addr: provider_addr,
peer_id: None,
},
|| async move { Ok(()) },
|_collection| async move { Ok(()) },
|_hash, stream, _name| async move {
// evil: do nothing with the stream!
Ok(stream)
},
),
)
.await;
provider.shutdown();

let err = timeout.expect(
"`get` function is hanging, make sure we are handling misbehaving `on_blob` functions",
);

err.expect_err("expected an error when passing in a misbehaving `on_blob` function");
Ok(())
}
}

0 comments on commit c266ab5

Please sign in to comment.