Skip to content

Commit

Permalink
feat: add handshake and reduce allocations
Browse files Browse the repository at this point in the history
  • Loading branch information
dignifiedquire committed Jan 18, 2023
1 parent 3d958e8 commit 6403bd1
Show file tree
Hide file tree
Showing 5 changed files with 203 additions and 86 deletions.
33 changes: 33 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ ed25519-dalek = "1.0.1"
futures = "0.3.25"
hex = "0.4.3"
indicatif = { version = "0.17.2", features = ["tokio"] }
postcard = { version = "1.0.2", default-features = false, features = ["alloc", "use-std"] }
postcard = { version = "1.0.2", default-features = false, features = ["alloc", "use-std", "experimental-derive"] }
rand = "0.7"
rcgen = "0.10.0"
ring = "0.16.20"
Expand Down
155 changes: 87 additions & 68 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ use std::{io::Read, net::SocketAddr, time::Instant};

use anyhow::{anyhow, bail, ensure, Result};
use bytes::BytesMut;
use postcard::experimental::max_size::MaxSize;
use s2n_quic::Connection;
use s2n_quic::{client::Connect, Client};
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tracing::debug;

use crate::protocol::{read_lp, write_lp, Request, Res, Response};
use crate::protocol::{read_lp_data, write_lp, Handshake, Request, Res, Response};
use crate::tls::{self, Keypair};

const MAX_DATA_SIZE: usize = 1024 * 1024 * 1024;
Expand Down Expand Up @@ -61,88 +62,106 @@ pub async fn run<D: AsyncWrite + Unpin>(
let stream = connection.open_bidirectional_stream().await?;
let (mut reader, mut writer) = stream.split();

let req = Request {
id: 1,
name: hash.into(),
};
let mut out_buffer = BytesMut::zeroed(std::cmp::max(
Request::POSTCARD_MAX_SIZE,
Handshake::POSTCARD_MAX_SIZE,
));

let mut out_buffer = BytesMut::zeroed(15 + req.name.len());
let used = postcard::to_slice(&req, &mut out_buffer)?;
// 1. Send Handshake
{
debug!("sending handshake");
let handshake = Handshake::default();
let used = postcard::to_slice(&handshake, &mut out_buffer)?;
write_lp(&mut writer, used).await?;
}

write_lp(&mut writer, used).await?;
// 2. Send Request
{
debug!("sending request");
let req = Request {
id: 1,
name: hash.into(),
};

let used = postcard::to_slice(&req, &mut out_buffer)?;
write_lp(&mut writer, used).await?;
}

// read response
// 3. Read response
{
debug!("reading response");
let mut in_buffer = BytesMut::with_capacity(1024);

// read next message
match read_lp::<_, Response>(&mut reader, &mut in_buffer).await? {
Some((response, response_size)) => match response.data {
Res::Found { size, outboard } => {
// Need to read the message now
ensure!(
size <= MAX_DATA_SIZE,
"size too large: {} > {}",
size,
MAX_DATA_SIZE
);

let outboard = outboard.to_vec();
// TODO: avoid buffering

// remove response buffered data
let _ = in_buffer.split_to(response_size);
while in_buffer.len() < size {
reader.read_buf(&mut in_buffer).await?;
}
match read_lp_data(&mut reader, &mut in_buffer).await? {
Some(response_buffer) => {
let response: Response = postcard::from_bytes(&response_buffer)?;
match response.data {
Res::Found { size, outboard } => {
// Need to read the message now
ensure!(
size <= MAX_DATA_SIZE,
"size too large: {} > {}",
size,
MAX_DATA_SIZE
);

// TODO: avoid buffering

// remove response buffered data
while in_buffer.len() < size {
reader.read_buf(&mut in_buffer).await?;
}

debug!("client: received data: {}bytes", in_buffer.len());
ensure!(
size == in_buffer.len(),
"expected {} bytes, got {} bytes",
size,
in_buffer.len()
);

let mut decoder = bao::decode::Decoder::new_outboard(
std::io::Cursor::new(&in_buffer[..]),
&*outboard,
&hash,
);

{
let mut buf = [0u8; 1024];
loop {
// TODO: avoid blocking
let read = decoder.read(&mut buf)?;
if read == 0 {
break;
debug!("client: received data: {}bytes", in_buffer.len());
ensure!(
size == in_buffer.len(),
"expected {} bytes, got {} bytes",
size,
in_buffer.len()
);

let mut decoder = bao::decode::Decoder::new_outboard(
std::io::Cursor::new(&in_buffer[..]),
outboard,
&hash,
);

{
let mut buf = [0u8; 1024];
loop {
// TODO: avoid blocking
let read = decoder.read(&mut buf)?;
if read == 0 {
break;
}
dest.write_all(&buf[..read]).await?;
}
dest.write_all(&buf[..read]).await?;
}
}

// Shut down the stream
writer.close().await?;
// Shut down the stream
debug!("shutting down stream");
writer.close().await?;

let data_len = size;
let elapsed = now.elapsed();
let elapsed_s = elapsed.as_secs_f64();
let data_len_bit = data_len * 8;
let mbits = data_len_bit as f64 / (1000. * 1000.) / elapsed_s;
let data_len = size;
let elapsed = now.elapsed();
let elapsed_s = elapsed.as_secs_f64();
let data_len_bit = data_len * 8;
let mbits = data_len_bit as f64 / (1000. * 1000.) / elapsed_s;

let stats = Stats {
data_len,
elapsed,
mbits,
};
let stats = Stats {
data_len,
elapsed,
mbits,
};

Ok(stats)
}
Res::NotFound => {
bail!("data not found");
Ok(stats)
}
Res::NotFound => {
bail!("data not found");
}
}
},
}
None => {
bail!("server disconnected");
}
Expand Down
74 changes: 61 additions & 13 deletions src/protocol.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,27 @@
use anyhow::{ensure, Result};
use anyhow::{bail, ensure, Result};
use bytes::BytesMut;
use postcard::experimental::max_size::MaxSize;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tracing::debug;

/// Maximum message size is limited to 100MiB for now.
const MAX_MESSAGE_SIZE: usize = 1024 * 1024 * 100;

#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)]
pub const VERSION: u64 = 1;

#[derive(Deserialize, Serialize, Debug, PartialEq, Clone, MaxSize)]
pub struct Handshake {
pub version: u64,
}

impl Default for Handshake {
fn default() -> Self {
Handshake { version: VERSION }
}
}

#[derive(Deserialize, Serialize, Debug, PartialEq, Clone, MaxSize)]
pub struct Request {
pub id: u64,
/// blake3 hash
Expand Down Expand Up @@ -65,18 +79,52 @@ pub async fn read_lp<'a, R: AsyncRead + futures::io::AsyncRead + Unpin, T: Deser
buffer: &'a mut BytesMut,
) -> Result<Option<(T, usize)>> {
// read length prefix
if let Ok(size) = unsigned_varint::aio::read_u64(&mut reader).await {
let size = usize::try_from(size)?;
ensure!(size < MAX_MESSAGE_SIZE, "received message is too large");
let size = read_prefix(&mut reader, buffer).await?;

while buffer.len() < size {
reader.read_buf(buffer).await?;
}
let response: T = postcard::from_bytes(&buffer[..size])?;
debug!("read message of size {}", size);
while buffer.len() < size {
debug!("reading message {} {}", buffer.len(), size);
reader.read_buf(buffer).await?;
}

let response: T = postcard::from_bytes(&buffer[..size])?;
debug!("read message of size {}", size);

Ok(Some((response, size)))
} else {
Ok(None)
Ok(Some((response, size)))
}

/// Read and deserialize into the given type from the provided source, based on the length prefix.
pub async fn read_lp_data<R: AsyncRead + futures::io::AsyncRead + Unpin>(
mut reader: R,
buffer: &mut BytesMut,
) -> Result<Option<BytesMut>> {
// read length prefix
let size = read_prefix(&mut reader, buffer).await?;

while buffer.len() < size {
reader.read_buf(buffer).await?;
}
let response = buffer.split_to(size);
Ok(Some(response))
}

async fn read_prefix<R: AsyncRead + futures::io::AsyncRead + Unpin>(
mut reader: R,
buffer: &mut BytesMut,
) -> Result<usize> {
// read length prefix
let size = loop {
if let Ok((size, rest)) = unsigned_varint::decode::u64(&buffer[..]) {
let size = usize::try_from(size)?;
ensure!(size < MAX_MESSAGE_SIZE, "received message is too large");

let _ = buffer.split_to(buffer.len() - rest.len());
break size;
}

if reader.read_buf(buffer).await? == 0 {
bail!("no more data available");
}
};

Ok(size)
}
25 changes: 21 additions & 4 deletions src/server.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use std::{collections::HashMap, path::Path, sync::Arc};

use anyhow::{anyhow, ensure, Result};
use anyhow::{anyhow, bail, ensure, Result};
use bytes::{Bytes, BytesMut};
use s2n_quic::stream::BidirectionalStream;
use s2n_quic::Server;
use tokio::io::AsyncWriteExt;
use tracing::{debug, error};

use crate::protocol::{read_lp, write_lp, Request, Res, Response};
use crate::protocol::{read_lp, write_lp, Handshake, Request, Res, Response, VERSION};
use crate::tls::{self, Keypair};

#[derive(Clone, Debug, Default)]
Expand Down Expand Up @@ -64,12 +64,28 @@ async fn handle_stream(
let mut out_buffer = BytesMut::with_capacity(1024);
let mut in_buffer = BytesMut::with_capacity(1024);

// decode next message
// 1. Read Handshake
debug!("reading handshake");
if let Some((handshake, size)) = read_lp::<_, Handshake>(&mut reader, &mut in_buffer).await? {
ensure!(
handshake.version == VERSION,
"expected version {} but got {}",
VERSION,
handshake.version
);
let _ = in_buffer.split_to(size);
} else {
bail!("no valid handshake received");
}

// 2. Decode protocol messages.
loop {
in_buffer.clear();
debug!("reading request");
match read_lp::<_, Request>(&mut reader, &mut in_buffer).await? {
Some((request, _size)) => {
let name = bao::Hash::from(request.name);
debug!("got request({}): {}", request.id, name.to_hex());

let (data, piece) = if let Some(data) = db.get(&name) {
debug!("found {}", name.to_hex());
(
Expand Down Expand Up @@ -109,6 +125,7 @@ async fn handle_stream(
break;
}
}
in_buffer.clear();
}

Ok(())
Expand Down

0 comments on commit 6403bd1

Please sign in to comment.