diff --git a/src/get.rs b/src/get.rs index 537bd49036..92a2fae95e 100644 --- a/src/get.rs +++ b/src/get.rs @@ -15,21 +15,18 @@ use crate::hp::cfg::DERP_MAGIC_IP; use crate::hp::derp::DerpMap; use crate::hp::hostinfo::Hostinfo; use crate::hp::{cfg, netmap}; -use crate::net::subnet::{same_subnet_v4, same_subnet_v6}; use crate::protocol::{write_lp, AnyGetRequest, Handshake, RangeSpecSeq}; use crate::provider::Ticket; use crate::tls::{self, Keypair, PeerId}; use crate::tokio_util::{TrackingReader, TrackingWriter}; use crate::util::pathbuf_from_name; use crate::IROH_BLOCK_SIZE; -use anyhow::{anyhow, Context, Result}; +use anyhow::{Context, Result}; use bao_tree::io::error::DecodeError; use bao_tree::io::DecodeResponseItem; use bao_tree::outboard::PreOrderMemOutboard; use bao_tree::{ByteNum, ChunkNum}; use bytes::BytesMut; -use default_net::Interface; -use futures::StreamExt; use postcard::experimental::max_size::MaxSize; use quinn::RecvStream; use range_collections::RangeSet2; @@ -45,12 +42,12 @@ pub use crate::util::Hash; pub const DEFAULT_PROVIDER_ADDR: (Ipv4Addr, u16) = crate::provider::DEFAULT_BIND_ADDR; /// Options for the client -#[derive(Clone, Debug, Default)] +#[derive(Clone, Debug)] pub struct Options { - /// The address to connect to - pub addr: Option, + /// The addresses to connect to. + pub addrs: Vec, /// The peer id to expect - pub peer_id: Option, + pub peer_id: PeerId, /// Whether to log the SSL keys when `SSLKEYLOGFILE` environment variable is set. pub keylog: bool, /// The configuration of the derp services. @@ -78,14 +75,15 @@ pub struct Options { /// connection. pub async fn make_client_endpoint( bind_addr: SocketAddr, - peer_id: Option, + peer_id: PeerId, alpn_protocols: Vec>, keylog: bool, derp_map: Option, ) -> Result<(quinn::Endpoint, crate::hp::magicsock::Conn)> { let keypair = Keypair::generate(); - let tls_client_config = tls::make_client_config(&keypair, peer_id, alpn_protocols, keylog)?; + let tls_client_config = + tls::make_client_config(&keypair, Some(peer_id), alpn_protocols, keylog)?; let mut client_config = quinn::ClientConfig::new(Arc::new(tls_client_config)); let conn = crate::hp::magicsock::Conn::new(crate::hp::magicsock::Options { @@ -113,9 +111,10 @@ pub async fn make_client_endpoint( /// Establishes a QUIC connection to the provided peer. pub async fn dial_peer(opts: Options) -> Result { - let bind_addr = match opts.addr.map(|a| a.is_ipv6()) { - Some(true) => SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0).into(), - Some(false) | None => SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0).into(), + let bind_addr = if opts.addrs.iter().any(|addr| addr.ip().is_ipv6()) { + SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0).into() + } else { + SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0).into() }; let (endpoint, magicsock) = make_client_endpoint( @@ -128,7 +127,7 @@ pub async fn dial_peer(opts: Options) -> Result { .await?; // Only a single peer in our network currently. - let peer_id = opts.peer_id.expect("need peer"); + let peer_id = opts.peer_id; let node_key: crate::hp::key::node::PublicKey = peer_id.into(); const DEFAULT_DERP_REGION: u16 = 1; @@ -136,9 +135,9 @@ pub async fn dial_peer(opts: Options) -> Result { let mut endpoints = Vec::new(); // Add the provided address as a starting point. - if let Some(addr) = opts.addr { + for addr in &opts.addrs { addresses.push(addr.ip()); - endpoints.push(addr); + endpoints.push(*addr); } magicsock .set_network_map(netmap::NetworkMap { @@ -164,7 +163,7 @@ pub async fn dial_peer(opts: Options) -> Result { .expect("just inserted"); debug!( "connecting to {}: (via {} - {:?})", - peer_id, addr, opts.addr + peer_id, addr, opts.addrs ); let connect = endpoint.connect(addr, "localhost")?; let connection = connect.await.context("failed connecting to provider")?; @@ -196,67 +195,17 @@ pub async fn run_ticket( ticket: &Ticket, request: AnyGetRequest, keylog: bool, - max_concurrent: u8, derp_map: Option, ) -> Result { - let connection = dial_ticket(ticket, keylog, max_concurrent.into(), derp_map).await?; - Ok(run_connection(connection, request)) -} - -async fn dial_ticket( - ticket: &Ticket, - keylog: bool, - max_concurrent: usize, - derp_map: Option, -) -> Result { - // Sort the interfaces to make sure local ones are at the front of the list. - let interfaces = default_net::get_interfaces(); - let (mut addrs, other_addrs) = ticket - .addrs() - .iter() - .partition::, _>(|addr| is_same_subnet(addr, &interfaces)); - addrs.extend(other_addrs); - - let mut conn_stream = futures::stream::iter(addrs) - .map(|addr| { - let opts = Options { - addr: Some(addr), - peer_id: Some(ticket.peer()), - keylog, - derp_map: derp_map.clone(), - }; - dial_peer(opts) - }) - .buffer_unordered(max_concurrent); - while let Some(res) = conn_stream.next().await { - match res { - Ok(conn) => return Ok(conn), - Err(_) => continue, - } - } - Err(anyhow!("Failed to establish connection to peer")) -} + let connection = dial_peer(Options { + addrs: ticket.addrs().to_vec(), + peer_id: ticket.peer(), + keylog, + derp_map, + }) + .await?; -fn is_same_subnet(addr: &SocketAddr, interfaces: &[Interface]) -> bool { - for interface in interfaces { - match addr { - SocketAddr::V4(peer_addr) => { - for net in interface.ipv4.iter() { - if same_subnet_v4(net.addr, *peer_addr.ip(), net.prefix_len) { - return true; - } - } - } - SocketAddr::V6(peer_addr) => { - for net in interface.ipv6.iter() { - if same_subnet_v6(net.addr, *peer_addr.ip(), net.prefix_len) { - return true; - } - } - } - } - } - false + Ok(run_connection(connection, request)) } /// Finite state machine for get responses diff --git a/src/lib.rs b/src/lib.rs index 6689a72323..a8107889b5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -174,13 +174,13 @@ mod tests { hash: Hash, file_hash: Hash, name: String, - addr: SocketAddr, + addrs: Vec, peer_id: PeerId, content: Vec, ) -> Result<()> { let opts = get::Options { - addr: Some(addr), - peer_id: Some(peer_id), + addrs, + peer_id, keylog: true, derp_map: None, }; @@ -201,7 +201,7 @@ mod tests { hash, expect_hash.into(), expect_name.clone(), - provider.local_address().unwrap()[0], + provider.local_address().unwrap(), provider.peer_id(), content.to_vec(), ))); @@ -294,10 +294,9 @@ mod tests { }); let addrs = provider.listen_addresses()?; - let addr = *addrs.first().unwrap(); let opts = get::Options { - addr: Some(addr), - peer_id: Some(provider.peer_id()), + addrs, + peer_id: provider.peer_id(), keylog: true, derp_map: None, }; @@ -406,8 +405,8 @@ mod tests { let response = get::run( GetRequest::all(hash).into(), get::Options { - addr: Some(provider_addr[0]), - peer_id: Some(peer_id), + addrs: provider_addr, + peer_id, keylog: true, derp_map: None, }, @@ -450,8 +449,8 @@ mod tests { let request = get::run( GetRequest::all(hash).into(), get::Options { - addr: Some(provider_addr[0]), - peer_id: Some(peer_id), + addrs: provider_addr, + peer_id, keylog: true, derp_map: None, }, @@ -489,13 +488,13 @@ mod tests { return; } }; - let addr = provider.local_address().unwrap(); - let peer_id = Some(provider.peer_id()); + let addrs = provider.local_address().unwrap(); + let peer_id = provider.peer_id(); tokio::time::timeout(Duration::from_secs(10), async move { let request = get::run( GetRequest::all(hash).into(), get::Options { - addr: Some(addr[0]), + addrs, peer_id, keylog: true, derp_map: None, @@ -522,14 +521,8 @@ mod tests { let _drop_guard = provider.cancel_token().drop_guard(); let ticket = provider.ticket(hash).unwrap(); tokio::time::timeout(Duration::from_secs(10), async move { - let response = get::run_ticket( - &ticket, - GetRequest::all(ticket.hash()).into(), - true, - 16, - None, - ) - .await?; + let response = + get::run_ticket(&ticket, GetRequest::all(ticket.hash()).into(), true, None).await?; aggregate_get_response(response).await }) .await @@ -601,11 +594,11 @@ mod tests { return; } }; - let addr = provider.local_address().unwrap(); - let peer_id = Some(provider.peer_id()); + let addrs = provider.local_address().unwrap(); + let peer_id = provider.peer_id(); tokio::time::timeout(Duration::from_secs(10), async move { let connection = dial_peer(get::Options { - addr: Some(addr[0]), + addrs, peer_id, keylog: true, derp_map: None, @@ -682,14 +675,14 @@ mod tests { .spawn() .await .unwrap(); - let addr = provider.local_address().unwrap(); - let peer_id = Some(provider.peer_id()); + let addrs = provider.local_address().unwrap(); + let peer_id = provider.peer_id(); tokio::time::timeout(Duration::from_secs(10), async move { let request: AnyGetRequest = Bytes::from(&b"hello"[..]).into(); let response = get::run( request, get::Options { - addr: Some(addr[0]), + addrs, peer_id, keylog: true, derp_map: None, @@ -718,14 +711,14 @@ mod tests { .spawn() .await .unwrap(); - let addr = provider.local_address().unwrap(); - let peer_id = Some(provider.peer_id()); + let addrs = provider.local_address().unwrap(); + let peer_id = provider.peer_id(); tokio::time::timeout(Duration::from_secs(10), async move { let request: AnyGetRequest = Bytes::from(&b"hello"[..]).into(); let response = get::run( request, get::Options { - addr: Some(addr[0]), + addrs, peer_id, keylog: true, derp_map: None, diff --git a/src/main.rs b/src/main.rs index ead401117e..3357ebc42d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -38,7 +38,6 @@ const DEFAULT_RPC_PORT: u16 = 0x1337; const RPC_ALPN: [u8; 17] = *b"n0/provider-rpc/1"; const MAX_RPC_CONNECTIONS: u32 = 16; const MAX_RPC_STREAMS: u64 = 1024; -const MAX_CONCURRENT_DIALS: u8 = 16; /// Send data. /// @@ -167,9 +166,9 @@ enum Commands { /// PeerId of the provider #[clap(long, short)] peer: PeerId, - /// Address of the provider + /// Addresses of the provider. #[clap(long, short)] - addr: Option, + addrs: Vec, /// Directory in which to save the file(s), defaults to writing to STDOUT #[clap(long, short)] out: Option, @@ -513,13 +512,13 @@ async fn main_impl() -> Result<()> { Commands::Get { hash, peer, - addr, + addrs, out, single, } => { let opts = get::Options { - addr, - peer_id: Some(peer), + addrs, + peer_id: peer, keylog: cli.keylog, derp_map: config.derp_map(), }; @@ -894,7 +893,7 @@ async fn get_to_dir(get: GetInteractive, out_dir: PathBuf) -> Result<()> { ticket, keylog, derp_map, - } => get::run_ticket(&ticket, request, keylog, MAX_CONCURRENT_DIALS, derp_map).await?, + } => get::run_ticket(&ticket, request, keylog, derp_map).await?, GetInteractive::Hash { opts, .. } => get::run(request, opts).await?, }; let connected = response.next().await?; @@ -1047,7 +1046,7 @@ async fn get_to_stdout(get: GetInteractive) -> Result<()> { ticket, keylog, derp_map, - } => get::run_ticket(&ticket, request, keylog, MAX_CONCURRENT_DIALS, derp_map).await?, + } => get::run_ticket(&ticket, request, keylog, derp_map).await?, GetInteractive::Hash { opts, .. } => get::run(request, opts).await?, }; let connected = response.next().await?; diff --git a/src/net.rs b/src/net.rs index 731e8c664b..d1dba33786 100644 --- a/src/net.rs +++ b/src/net.rs @@ -2,4 +2,3 @@ pub mod interfaces; pub mod ip; -pub mod subnet; diff --git a/src/net/subnet.rs b/src/net/subnet.rs deleted file mode 100644 index e8313bfdea..0000000000 --- a/src/net/subnet.rs +++ /dev/null @@ -1,41 +0,0 @@ -//! Same subnet logic. -//! -//! Tiny module because left/right shifting confuses emacs' rust-mode. So sad. - -use std::net::{Ipv4Addr, Ipv6Addr}; - -/// Checks if both addresses are on the same subnet given the `prefix_len`. -pub(crate) fn same_subnet_v4(addr_a: Ipv4Addr, addr_b: Ipv4Addr, prefix_len: u8) -> bool { - let mask = u32::MAX << (32 - prefix_len); - let a = u32::from(addr_a) & mask; - let b = u32::from(addr_b) & mask; - a == b -} - -pub(crate) fn same_subnet_v6(addr_a: Ipv6Addr, addr_b: Ipv6Addr, prefix_len: u8) -> bool { - let mask = u128::MAX << (128 - prefix_len); - let a = u128::from(addr_a) & mask; - let b = u128::from(addr_b) & mask; - a == b -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_same_subnet_v4() { - let a = Ipv4Addr::new(192, 168, 0, 5); - let b = Ipv4Addr::new(192, 168, 1, 6); - assert!(!same_subnet_v4(a, b, 24)); - assert!(same_subnet_v4(a, b, 16)); - } - - #[test] - fn test_same_subnet_v6() { - let a = Ipv6Addr::new(0xfd56, 0x5799, 0xd8f6, 0x3cc, 0x0, 0x0, 0x0, 0x1); - let b = Ipv6Addr::new(0xfd56, 0x5799, 0xd8f6, 0x3cd, 0x0, 0x0, 0x0, 0x2); - assert!(!same_subnet_v6(a, b, 64)); - assert!(same_subnet_v6(a, b, 48)); - } -} diff --git a/tests/cli.rs b/tests/cli.rs index 3a0157a0b8..a1829e6b9c 100644 --- a/tests/cli.rs +++ b/tests/cli.rs @@ -349,7 +349,10 @@ fn test_provide_get_loop(path: &Path, input: Input, output: Output) -> Result<() let home = testdir!(); let mut provider = make_provider(&path, &input, home, None, None)?; - // std::io::copy(&mut provider.child.stderr.take().unwrap(), &mut std::io::stderr())?; + // std::io::copy( + // &mut provider.child.stderr.take().unwrap(), + // &mut std::io::stderr(), + // )?; let stdout = provider.child.stdout.take().unwrap(); let stdout = BufReader::new(stdout);