Skip to content

Commit

Permalink
feat: connect based on PeerIds
Browse files Browse the repository at this point in the history
  • Loading branch information
dignifiedquire authored Jan 20, 2023
1 parent bcc7313 commit c57fa4f
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 55 deletions.
14 changes: 8 additions & 6 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,30 @@ use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tracing::debug;

use crate::protocol::{read_lp_data, write_lp, Handshake, Request, Res, Response};
use crate::tls::{self, Keypair};
use crate::tls::{self, Keypair, PeerId};

const MAX_DATA_SIZE: usize = 1024 * 1024 * 1024;

#[derive(Clone, Debug)]
pub struct Options {
pub addr: SocketAddr,
pub peer_id: Option<PeerId>,
}

impl Default for Options {
fn default() -> Self {
Options {
addr: "127.0.0.1:4433".parse().unwrap(),
peer_id: None,
}
}
}

/// Setup a QUIC connection to the provided server address
async fn setup(server_addr: SocketAddr) -> Result<(Client, Connection)> {
async fn setup(opts: Options) -> Result<(Client, Connection)> {
let keypair = Keypair::generate();

let client_config = tls::make_client_config(&keypair, None)?;
let client_config = tls::make_client_config(&keypair, opts.peer_id)?;
let tls = s2n_quic::provider::tls::rustls::Client::from(client_config);

let client = Client::builder()
Expand All @@ -41,8 +43,8 @@ async fn setup(server_addr: SocketAddr) -> Result<(Client, Connection)> {
.start()
.map_err(|e| anyhow!("{:?}", e))?;

debug!("client: connecting to {}", server_addr);
let connect = Connect::new(server_addr).with_server_name("localhost");
debug!("client: connecting to {}", opts.addr);
let connect = Connect::new(opts.addr).with_server_name("localhost");
let mut connection = client.connect(connect).await?;

connection.keep_alive(true)?;
Expand All @@ -69,7 +71,7 @@ pub fn run<D: AsyncWrite + Unpin>(
) -> impl Stream<Item = Result<Event>> {
async_stream::try_stream! {
let now = Instant::now();
let (_client, mut connection) = setup(opts.addr).await?;
let (_client, mut connection) = setup(opts).await?;

let stream = connection.open_bidirectional_stream().await?;
let (mut reader, mut writer) = stream.split();
Expand Down
46 changes: 38 additions & 8 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@ pub mod server;

mod tls;

pub use tls::{PeerId, PeerIdError};

#[cfg(test)]
mod tests {
use std::{net::SocketAddr, path::PathBuf};

use crate::tls::PeerId;

use super::*;
use anyhow::Result;
use futures::TryStreamExt;
Expand All @@ -23,12 +27,17 @@ mod tests {
let db = server::create_db(vec![&path]).await?;
let hash = *db.iter().next().unwrap().0;
let addr = "127.0.0.1:4443".parse().unwrap();
let mut server = server::Server::new(db);
let peer_id = server.peer_id();

tokio::task::spawn(async move {
server::run(db, server::Options { addr }).await.unwrap();
server.run(server::Options { addr }).await.unwrap();
});

let opts = client::Options { addr };
let opts = client::Options {
addr,
peer_id: Some(peer_id),
};
let (mut source, sink) = tokio::io::duplex(1024);
let events: Vec<_> = client::run(hash, opts, sink).try_collect().await?;
assert_eq!(events.len(), 3);
Expand Down Expand Up @@ -68,12 +77,17 @@ mod tests {

let db = server::create_db(vec![&path]).await?;
let hash = *db.iter().next().unwrap().0;
let mut server = server::Server::new(db);
let peer_id = server.peer_id();

let server_task = tokio::task::spawn(async move {
server::run(db, server::Options { addr }).await.unwrap();
server.run(server::Options { addr }).await.unwrap();
});

let opts = client::Options { addr };
let opts = client::Options {
addr,
peer_id: Some(peer_id),
};
let (mut source, sink) = tokio::io::duplex(size);
let events: Vec<_> = client::run(hash, opts, sink).try_collect().await?;
assert_eq!(events.len(), 3);
Expand All @@ -98,12 +112,23 @@ mod tests {
tokio::fs::write(&path, content).await?;
let db = server::create_db(vec![&path]).await?;
let hash = *db.iter().next().unwrap().0;
let mut server = server::Server::new(db);
let peer_id = server.peer_id();

tokio::task::spawn(async move {
server::run(db, server::Options { addr }).await.unwrap();
server.run(server::Options { addr }).await.unwrap();
});

async fn run_client(hash: bao::Hash, addr: SocketAddr, content: Vec<u8>) -> Result<()> {
let opts = client::Options { addr };
async fn run_client(
hash: bao::Hash,
addr: SocketAddr,
peer_id: PeerId,
content: Vec<u8>,
) -> Result<()> {
let opts = client::Options {
addr,
peer_id: Some(peer_id),
};
let (mut source, sink) = tokio::io::duplex(1024);
let events: Vec<_> = client::run(hash, opts, sink).try_collect().await?;
assert_eq!(events.len(), 3);
Expand All @@ -115,7 +140,12 @@ mod tests {

let mut tasks = Vec::new();
for _i in 0..3 {
tasks.push(tokio::task::spawn(run_client(hash, addr, content.to_vec())));
tasks.push(tokio::task::spawn(run_client(
hash,
addr,
peer_id,
content.to_vec(),
)));
}

for task in tasks {
Expand Down
23 changes: 19 additions & 4 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use futures::StreamExt;
use indicatif::{HumanDuration, ProgressBar, ProgressDrawTarget, ProgressState, ProgressStyle};
use tracing_subscriber::{fmt, prelude::*, EnvFilter};

use sendme::{client, server};
use sendme::{client, server, PeerId};

#[derive(Parser, Debug, Clone)]
#[clap(version, about, long_about = None)]
Expand All @@ -18,6 +18,7 @@ struct Cli {
}

#[derive(Subcommand, Debug, Clone)]
#[allow(clippy::large_enum_variant)]
enum Commands {
/// Serve the data from the given path
#[clap(about = "Serve the data from the given path")]
Expand All @@ -31,6 +32,9 @@ enum Commands {
#[clap(about = "Fetch the data from the hash")]
Client {
hash: bao::Hash,
#[clap(long)]
/// PeerId of the server.
peer_id: PeerId,
#[clap(long, short)]
/// Option address of the server, defaults to 127.0.0.1:4433.
addr: Option<SocketAddr>,
Expand All @@ -50,9 +54,17 @@ async fn main() -> Result<()> {
let cli = Cli::parse();

match cli.command {
Commands::Client { hash, addr, out } => {
Commands::Client {
hash,
peer_id,
addr,
out,
} => {
println!("Fetching: {}", hash.to_hex());
let mut opts = client::Options::default();
let mut opts = client::Options {
peer_id: Some(peer_id),
..Default::default()
};
if let Some(addr) = addr {
opts.addr = addr;
}
Expand Down Expand Up @@ -98,7 +110,10 @@ async fn main() -> Result<()> {
if let Some(addr) = addr {
opts.addr = addr;
}
server::run(db, opts).await?
let mut server = server::Server::new(db);

println!("Serving from {}", server.peer_id());
server.run(opts).await?
}
}

Expand Down
89 changes: 53 additions & 36 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ use std::{collections::HashMap, path::Path, sync::Arc};
use anyhow::{anyhow, bail, ensure, Result};
use bytes::{Bytes, BytesMut};
use s2n_quic::stream::BidirectionalStream;
use s2n_quic::Server;
use s2n_quic::Server as QuicServer;
use tokio::io::AsyncWriteExt;
use tracing::{debug, error};

use crate::protocol::{read_lp, write_lp, Handshake, Request, Res, Response, VERSION};
use crate::tls::{self, Keypair};
use crate::tls::{self, Keypair, PeerId};

#[derive(Clone, Debug)]
pub struct Options {
Expand All @@ -28,42 +28,59 @@ impl Default for Options {
const MAX_CLIENTS: u64 = 1024;
const MAX_STREAMS: u64 = 10;

pub async fn run(db: Arc<HashMap<bao::Hash, Data>>, opts: Options) -> Result<()> {
let keypair = Keypair::generate();
let server_config = tls::make_server_config(&keypair)?;
let tls = s2n_quic::provider::tls::rustls::Server::from(server_config);
let limits = s2n_quic::provider::limits::Limits::default()
.with_max_active_connection_ids(MAX_CLIENTS)?
.with_max_open_local_bidirectional_streams(MAX_STREAMS)?
.with_max_open_remote_bidirectional_streams(MAX_STREAMS)?;

let mut server = Server::builder()
.with_tls(tls)?
.with_io(opts.addr)?
.with_limits(limits)?
.start()
.map_err(|e| anyhow!("{:?}", e))?;

debug!("\nlistening at: {:#?}", server.local_addr().unwrap());

while let Some(mut connection) = server.accept().await {
let db = db.clone();
tokio::spawn(async move {
debug!("connection accepted from {:?}", connection.remote_addr());

while let Ok(Some(stream)) = connection.accept_bidirectional_stream().await {
let db = db.clone();
tokio::spawn(async move {
if let Err(err) = handle_stream(db, stream).await {
error!("error: {:#?}", err);
}
debug!("disconnected");
});
}
});
pub type Database = Arc<HashMap<bao::Hash, Data>>;

pub struct Server {
keypair: Keypair,
db: Database,
}

impl Server {
pub fn new(db: Database) -> Self {
let keypair = Keypair::generate();
Server { keypair, db }
}

Ok(())
pub fn peer_id(&self) -> PeerId {
self.keypair.public().into()
}

pub async fn run(&mut self, opts: Options) -> Result<()> {
let server_config = tls::make_server_config(&self.keypair)?;
let tls = s2n_quic::provider::tls::rustls::Server::from(server_config);
let limits = s2n_quic::provider::limits::Limits::default()
.with_max_active_connection_ids(MAX_CLIENTS)?
.with_max_open_local_bidirectional_streams(MAX_STREAMS)?
.with_max_open_remote_bidirectional_streams(MAX_STREAMS)?;

let mut server = QuicServer::builder()
.with_tls(tls)?
.with_io(opts.addr)?
.with_limits(limits)?
.start()
.map_err(|e| anyhow!("{:?}", e))?;

debug!("\nlistening at: {:#?}", server.local_addr().unwrap());

while let Some(mut connection) = server.accept().await {
let db = self.db.clone();
tokio::spawn(async move {
debug!("connection accepted from {:?}", connection.remote_addr());

while let Ok(Some(stream)) = connection.accept_bidirectional_stream().await {
let db = db.clone();
tokio::spawn(async move {
if let Err(err) = handle_stream(db, stream).await {
error!("error: {:#?}", err);
}
debug!("disconnected");
});
}
});
}

Ok(())
}
}

async fn handle_stream(
Expand Down
2 changes: 1 addition & 1 deletion src/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ impl Keypair {
}

// TODO: probably needs a version field
#[derive(Clone, PartialEq)]
#[derive(Clone, PartialEq, Copy)]
pub struct PeerId(PublicKey);

impl From<PublicKey> for PeerId {
Expand Down

0 comments on commit c57fa4f

Please sign in to comment.