Skip to content

Commit

Permalink
refactor: update connection logic to for magicsock
Browse files Browse the repository at this point in the history
  • Loading branch information
dignifiedquire committed May 15, 2023
1 parent f76d650 commit e13f663
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 157 deletions.
99 changes: 24 additions & 75 deletions src/get.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,18 @@ use crate::hp::cfg::DERP_MAGIC_IP;
use crate::hp::derp::DerpMap;
use crate::hp::hostinfo::Hostinfo;
use crate::hp::{cfg, netmap};
use crate::net::subnet::{same_subnet_v4, same_subnet_v6};
use crate::protocol::{write_lp, AnyGetRequest, Handshake, RangeSpecSeq};
use crate::provider::Ticket;
use crate::tls::{self, Keypair, PeerId};
use crate::tokio_util::{TrackingReader, TrackingWriter};
use crate::util::pathbuf_from_name;
use crate::IROH_BLOCK_SIZE;
use anyhow::{anyhow, Context, Result};
use anyhow::{Context, Result};
use bao_tree::io::error::DecodeError;
use bao_tree::io::DecodeResponseItem;
use bao_tree::outboard::PreOrderMemOutboard;
use bao_tree::{ByteNum, ChunkNum};
use bytes::BytesMut;
use default_net::Interface;
use futures::StreamExt;
use postcard::experimental::max_size::MaxSize;
use quinn::RecvStream;
use range_collections::RangeSet2;
Expand All @@ -45,12 +42,12 @@ pub use crate::util::Hash;
pub const DEFAULT_PROVIDER_ADDR: (Ipv4Addr, u16) = crate::provider::DEFAULT_BIND_ADDR;

/// Options for the client
#[derive(Clone, Debug, Default)]
#[derive(Clone, Debug)]
pub struct Options {
/// The address to connect to
pub addr: Option<SocketAddr>,
/// The addresses to connect to.
pub addrs: Vec<SocketAddr>,
/// The peer id to expect
pub peer_id: Option<PeerId>,
pub peer_id: PeerId,
/// Whether to log the SSL keys when `SSLKEYLOGFILE` environment variable is set.
pub keylog: bool,
/// The configuration of the derp services.
Expand Down Expand Up @@ -78,14 +75,15 @@ pub struct Options {
/// connection.
pub async fn make_client_endpoint(
bind_addr: SocketAddr,
peer_id: Option<PeerId>,
peer_id: PeerId,
alpn_protocols: Vec<Vec<u8>>,
keylog: bool,
derp_map: Option<DerpMap>,
) -> Result<(quinn::Endpoint, crate::hp::magicsock::Conn)> {
let keypair = Keypair::generate();

let tls_client_config = tls::make_client_config(&keypair, peer_id, alpn_protocols, keylog)?;
let tls_client_config =
tls::make_client_config(&keypair, Some(peer_id), alpn_protocols, keylog)?;
let mut client_config = quinn::ClientConfig::new(Arc::new(tls_client_config));

let conn = crate::hp::magicsock::Conn::new(crate::hp::magicsock::Options {
Expand Down Expand Up @@ -113,9 +111,10 @@ pub async fn make_client_endpoint(

/// Establishes a QUIC connection to the provided peer.
pub async fn dial_peer(opts: Options) -> Result<quinn::Connection> {
let bind_addr = match opts.addr.map(|a| a.is_ipv6()) {
Some(true) => SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0).into(),
Some(false) | None => SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0).into(),
let bind_addr = if opts.addrs.iter().any(|addr| addr.ip().is_ipv6()) {
SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0).into()
} else {
SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0).into()
};

let (endpoint, magicsock) = make_client_endpoint(
Expand All @@ -128,17 +127,17 @@ pub async fn dial_peer(opts: Options) -> Result<quinn::Connection> {
.await?;

// Only a single peer in our network currently.
let peer_id = opts.peer_id.expect("need peer");
let peer_id = opts.peer_id;
let node_key: crate::hp::key::node::PublicKey = peer_id.into();
const DEFAULT_DERP_REGION: u16 = 1;

let mut addresses = Vec::new();
let mut endpoints = Vec::new();

// Add the provided address as a starting point.
if let Some(addr) = opts.addr {
for addr in &opts.addrs {
addresses.push(addr.ip());
endpoints.push(addr);
endpoints.push(*addr);
}
magicsock
.set_network_map(netmap::NetworkMap {
Expand All @@ -164,7 +163,7 @@ pub async fn dial_peer(opts: Options) -> Result<quinn::Connection> {
.expect("just inserted");
debug!(
"connecting to {}: (via {} - {:?})",
peer_id, addr, opts.addr
peer_id, addr, opts.addrs
);
let connect = endpoint.connect(addr, "localhost")?;
let connection = connect.await.context("failed connecting to provider")?;
Expand Down Expand Up @@ -196,67 +195,17 @@ pub async fn run_ticket(
ticket: &Ticket,
request: AnyGetRequest,
keylog: bool,
max_concurrent: u8,
derp_map: Option<DerpMap>,
) -> Result<get_response_machine::AtInitial> {
let connection = dial_ticket(ticket, keylog, max_concurrent.into(), derp_map).await?;
Ok(run_connection(connection, request))
}

async fn dial_ticket(
ticket: &Ticket,
keylog: bool,
max_concurrent: usize,
derp_map: Option<DerpMap>,
) -> Result<quinn::Connection> {
// Sort the interfaces to make sure local ones are at the front of the list.
let interfaces = default_net::get_interfaces();
let (mut addrs, other_addrs) = ticket
.addrs()
.iter()
.partition::<Vec<_>, _>(|addr| is_same_subnet(addr, &interfaces));
addrs.extend(other_addrs);

let mut conn_stream = futures::stream::iter(addrs)
.map(|addr| {
let opts = Options {
addr: Some(addr),
peer_id: Some(ticket.peer()),
keylog,
derp_map: derp_map.clone(),
};
dial_peer(opts)
})
.buffer_unordered(max_concurrent);
while let Some(res) = conn_stream.next().await {
match res {
Ok(conn) => return Ok(conn),
Err(_) => continue,
}
}
Err(anyhow!("Failed to establish connection to peer"))
}
let connection = dial_peer(Options {
addrs: ticket.addrs().to_vec(),
peer_id: ticket.peer(),
keylog,
derp_map,
})
.await?;

fn is_same_subnet(addr: &SocketAddr, interfaces: &[Interface]) -> bool {
for interface in interfaces {
match addr {
SocketAddr::V4(peer_addr) => {
for net in interface.ipv4.iter() {
if same_subnet_v4(net.addr, *peer_addr.ip(), net.prefix_len) {
return true;
}
}
}
SocketAddr::V6(peer_addr) => {
for net in interface.ipv6.iter() {
if same_subnet_v6(net.addr, *peer_addr.ip(), net.prefix_len) {
return true;
}
}
}
}
}
false
Ok(run_connection(connection, request))
}

/// Finite state machine for get responses
Expand Down
55 changes: 24 additions & 31 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,13 @@ mod tests {
hash: Hash,
file_hash: Hash,
name: String,
addr: SocketAddr,
addrs: Vec<SocketAddr>,
peer_id: PeerId,
content: Vec<u8>,
) -> Result<()> {
let opts = get::Options {
addr: Some(addr),
peer_id: Some(peer_id),
addrs,
peer_id,
keylog: true,
derp_map: None,
};
Expand All @@ -201,7 +201,7 @@ mod tests {
hash,
expect_hash.into(),
expect_name.clone(),
provider.local_address().unwrap()[0],
provider.local_address().unwrap(),
provider.peer_id(),
content.to_vec(),
)));
Expand Down Expand Up @@ -294,10 +294,9 @@ mod tests {
});

let addrs = provider.listen_addresses()?;
let addr = *addrs.first().unwrap();
let opts = get::Options {
addr: Some(addr),
peer_id: Some(provider.peer_id()),
addrs,
peer_id: provider.peer_id(),
keylog: true,
derp_map: None,
};
Expand Down Expand Up @@ -406,8 +405,8 @@ mod tests {
let response = get::run(
GetRequest::all(hash).into(),
get::Options {
addr: Some(provider_addr[0]),
peer_id: Some(peer_id),
addrs: provider_addr,
peer_id,
keylog: true,
derp_map: None,
},
Expand Down Expand Up @@ -450,8 +449,8 @@ mod tests {
let request = get::run(
GetRequest::all(hash).into(),
get::Options {
addr: Some(provider_addr[0]),
peer_id: Some(peer_id),
addrs: provider_addr,
peer_id,
keylog: true,
derp_map: None,
},
Expand Down Expand Up @@ -489,13 +488,13 @@ mod tests {
return;
}
};
let addr = provider.local_address().unwrap();
let peer_id = Some(provider.peer_id());
let addrs = provider.local_address().unwrap();
let peer_id = provider.peer_id();
tokio::time::timeout(Duration::from_secs(10), async move {
let request = get::run(
GetRequest::all(hash).into(),
get::Options {
addr: Some(addr[0]),
addrs,
peer_id,
keylog: true,
derp_map: None,
Expand All @@ -522,14 +521,8 @@ mod tests {
let _drop_guard = provider.cancel_token().drop_guard();
let ticket = provider.ticket(hash).unwrap();
tokio::time::timeout(Duration::from_secs(10), async move {
let response = get::run_ticket(
&ticket,
GetRequest::all(ticket.hash()).into(),
true,
16,
None,
)
.await?;
let response =
get::run_ticket(&ticket, GetRequest::all(ticket.hash()).into(), true, None).await?;
aggregate_get_response(response).await
})
.await
Expand Down Expand Up @@ -601,11 +594,11 @@ mod tests {
return;
}
};
let addr = provider.local_address().unwrap();
let peer_id = Some(provider.peer_id());
let addrs = provider.local_address().unwrap();
let peer_id = provider.peer_id();
tokio::time::timeout(Duration::from_secs(10), async move {
let connection = dial_peer(get::Options {
addr: Some(addr[0]),
addrs,
peer_id,
keylog: true,
derp_map: None,
Expand Down Expand Up @@ -682,14 +675,14 @@ mod tests {
.spawn()
.await
.unwrap();
let addr = provider.local_address().unwrap();
let peer_id = Some(provider.peer_id());
let addrs = provider.local_address().unwrap();
let peer_id = provider.peer_id();
tokio::time::timeout(Duration::from_secs(10), async move {
let request: AnyGetRequest = Bytes::from(&b"hello"[..]).into();
let response = get::run(
request,
get::Options {
addr: Some(addr[0]),
addrs,
peer_id,
keylog: true,
derp_map: None,
Expand Down Expand Up @@ -718,14 +711,14 @@ mod tests {
.spawn()
.await
.unwrap();
let addr = provider.local_address().unwrap();
let peer_id = Some(provider.peer_id());
let addrs = provider.local_address().unwrap();
let peer_id = provider.peer_id();
tokio::time::timeout(Duration::from_secs(10), async move {
let request: AnyGetRequest = Bytes::from(&b"hello"[..]).into();
let response = get::run(
request,
get::Options {
addr: Some(addr[0]),
addrs,
peer_id,
keylog: true,
derp_map: None,
Expand Down
15 changes: 7 additions & 8 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ const DEFAULT_RPC_PORT: u16 = 0x1337;
const RPC_ALPN: [u8; 17] = *b"n0/provider-rpc/1";
const MAX_RPC_CONNECTIONS: u32 = 16;
const MAX_RPC_STREAMS: u64 = 1024;
const MAX_CONCURRENT_DIALS: u8 = 16;

/// Send data.
///
Expand Down Expand Up @@ -167,9 +166,9 @@ enum Commands {
/// PeerId of the provider
#[clap(long, short)]
peer: PeerId,
/// Address of the provider
/// Addresses of the provider.
#[clap(long, short)]
addr: Option<SocketAddr>,
addrs: Vec<SocketAddr>,
/// Directory in which to save the file(s), defaults to writing to STDOUT
#[clap(long, short)]
out: Option<PathBuf>,
Expand Down Expand Up @@ -513,13 +512,13 @@ async fn main_impl() -> Result<()> {
Commands::Get {
hash,
peer,
addr,
addrs,
out,
single,
} => {
let opts = get::Options {
addr,
peer_id: Some(peer),
addrs,
peer_id: peer,
keylog: cli.keylog,
derp_map: config.derp_map(),
};
Expand Down Expand Up @@ -894,7 +893,7 @@ async fn get_to_dir(get: GetInteractive, out_dir: PathBuf) -> Result<()> {
ticket,
keylog,
derp_map,
} => get::run_ticket(&ticket, request, keylog, MAX_CONCURRENT_DIALS, derp_map).await?,
} => get::run_ticket(&ticket, request, keylog, derp_map).await?,
GetInteractive::Hash { opts, .. } => get::run(request, opts).await?,
};
let connected = response.next().await?;
Expand Down Expand Up @@ -1047,7 +1046,7 @@ async fn get_to_stdout(get: GetInteractive) -> Result<()> {
ticket,
keylog,
derp_map,
} => get::run_ticket(&ticket, request, keylog, MAX_CONCURRENT_DIALS, derp_map).await?,
} => get::run_ticket(&ticket, request, keylog, derp_map).await?,
GetInteractive::Hash { opts, .. } => get::run(request, opts).await?,
};
let connected = response.next().await?;
Expand Down
1 change: 0 additions & 1 deletion src/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,3 @@
pub mod interfaces;
pub mod ip;
pub mod subnet;
Loading

0 comments on commit e13f663

Please sign in to comment.