Skip to content

Commit

Permalink
fix: use listen_addresses instead of local_address (#1044)
Browse files Browse the repository at this point in the history
* fix: use listen_addresses instead of local_address

listen_addresses are the actual listen addrs.

* WIP: use addresses of local_endpoints

the thing we painfully wait for...

* fix: increase delays in test_timer_abort_late

* fix: fix regex to allow for extremely slow transfers

* ref: completely remove listen_addresses

* fix: multiply more timer delays by 5

* fix(test): wait for provider to start up in cli_provide_addresses

...instead of a fixed delay. Also don't expect all ports to be 4333, just
at least 1.

* docs: add back doc comment for local_endpoint_addresses
  • Loading branch information
rklaehn authored May 24, 2023
1 parent 55d0211 commit c4a1890
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 112 deletions.
16 changes: 8 additions & 8 deletions src/hp/magicsock/timer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,12 @@ mod tests {
assert!(!val.load(Ordering::Relaxed));

let moved_val = val.clone();
let timer = Timer::after(Duration::from_millis(10), async move {
let timer = Timer::after(Duration::from_millis(50), async move {
moved_val.store(true, Ordering::Relaxed);
});

assert!(!val.load(Ordering::Relaxed));
time::sleep(Duration::from_millis(15)).await;
time::sleep(Duration::from_millis(75)).await;

assert!(!timer.stop().await);
assert!(val.load(Ordering::Relaxed));
Expand All @@ -140,24 +140,24 @@ mod tests {
assert!(!val.load(Ordering::Relaxed));

let moved_val = val.clone();
let timer = Timer::after(Duration::from_millis(10), async move {
let timer = Timer::after(Duration::from_millis(50), async move {
moved_val.store(true, Ordering::Relaxed);
});

assert!(!val.load(Ordering::Relaxed));
time::sleep(Duration::from_millis(5)).await;
time::sleep(Duration::from_millis(25)).await;

// not yet expired
assert!(!val.load(Ordering::Relaxed));
// reset for another 10ms
timer.reset(Duration::from_millis(20)).await;
// reset for another 100ms
timer.reset(Duration::from_millis(100)).await;

// would have expired if not reset
time::sleep(Duration::from_millis(5)).await;
time::sleep(Duration::from_millis(25)).await;
assert!(!val.load(Ordering::Relaxed));

// definitely expired now
time::sleep(Duration::from_millis(25)).await;
time::sleep(Duration::from_millis(125)).await;
assert!(val.load(Ordering::Relaxed));
}
}
20 changes: 10 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pub const IROH_BLOCK_SIZE: BlockSize = match BlockSize::new(4) {
mod tests {
use std::{
collections::BTreeMap,
net::{Ipv4Addr, SocketAddr},
net::{Ipv4Addr, Ipv6Addr, SocketAddr},
path::{Path, PathBuf},
time::{Duration, Instant},
};
Expand Down Expand Up @@ -282,7 +282,7 @@ mod tests {
events
});

let addrs = provider.listen_addresses()?;
let addrs = provider.local_endpoint_addresses().await?;
let opts = get::Options {
addrs,
peer_id: provider.peer_id(),
Expand Down Expand Up @@ -362,7 +362,7 @@ mod tests {
.spawn()
.await
.unwrap();
let provider_addr = provider.local_address().unwrap();
let provider_addr = provider.local_endpoint_addresses().await.unwrap();
let peer_id = provider.peer_id();

// This tasks closes the connection on the provider side as soon as the transfer
Expand Down Expand Up @@ -431,7 +431,7 @@ mod tests {
.bind_addr("127.0.0.1:0".parse().unwrap())
.spawn()
.await?;
let provider_addr = provider.local_address()?;
let provider_addr = provider.local_endpoint_addresses().await?;
let peer_id = provider.peer_id();

let timeout = tokio::time::timeout(std::time::Duration::from_secs(10), async move {
Expand Down Expand Up @@ -467,7 +467,7 @@ mod tests {
let readme = Path::new(env!("CARGO_MANIFEST_DIR")).join("README.md");
let (db, hash) = create_collection(vec![readme.into()]).await.unwrap();
let provider = match Provider::builder(db)
.bind_addr("[::1]:0".parse().unwrap())
.bind_addr((Ipv6Addr::UNSPECIFIED, 0).into())
.spawn()
.await
{
Expand All @@ -478,7 +478,7 @@ mod tests {
return;
}
};
let addrs = provider.local_address().unwrap();
let addrs = provider.local_endpoint_addresses().await.unwrap();
let peer_id = provider.peer_id();
tokio::time::timeout(Duration::from_secs(10), async move {
let request = get::run(
Expand Down Expand Up @@ -509,7 +509,7 @@ mod tests {
.await
.unwrap();
let _drop_guard = provider.cancel_token().drop_guard();
let ticket = provider.ticket(hash).unwrap();
let ticket = provider.ticket(hash).await.unwrap();
tokio::time::timeout(Duration::from_secs(10), async move {
let response =
get::run_ticket(&ticket, GetRequest::all(ticket.hash()).into(), true, None).await?;
Expand Down Expand Up @@ -587,7 +587,7 @@ mod tests {
return;
}
};
let addrs = provider.local_address().unwrap();
let addrs = provider.local_endpoint_addresses().await.unwrap();
let peer_id = provider.peer_id();
tokio::time::timeout(Duration::from_secs(10), async move {
let connection = dial_peer(get::Options {
Expand Down Expand Up @@ -668,7 +668,7 @@ mod tests {
.spawn()
.await
.unwrap();
let addrs = provider.local_address().unwrap();
let addrs = provider.local_endpoint_addresses().await.unwrap();
let peer_id = provider.peer_id();
tokio::time::timeout(Duration::from_secs(10), async move {
let request: AnyGetRequest = Bytes::from(&b"hello"[..]).into();
Expand Down Expand Up @@ -704,7 +704,7 @@ mod tests {
.spawn()
.await
.unwrap();
let addrs = provider.local_address().unwrap();
let addrs = provider.local_endpoint_addresses().await.unwrap();
let peer_id = provider.peer_id();
tokio::time::timeout(Duration::from_secs(10), async move {
let request: AnyGetRequest = Bytes::from(&b"hello"[..]).into();
Expand Down
2 changes: 1 addition & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ async fn main_impl() -> Result<()> {
let stream = controller.server_streaming(ProvideRequest { path }).await?;
let (hash, entries) = aggregate_add_response(stream).await?;
print_add_response(hash, entries);
let ticket = provider.ticket(hash)?;
let ticket = provider.ticket(hash).await?;
println!("All-in-one ticket: {ticket}");
anyhow::Ok(tmp_path)
})
Expand Down
57 changes: 1 addition & 56 deletions src/net/ip.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
//! IP address related utilities.
use std::net::{IpAddr, Ipv6Addr, SocketAddr};

use anyhow::{ensure, Result};
use tracing::debug;
use std::net::{IpAddr, Ipv6Addr};

const IFF_UP: u32 = 0x1;
const IFF_LOOPBACK: u32 = 0x8;
Expand Down Expand Up @@ -155,58 +152,6 @@ pub const fn is_unicast_link_local(addr: Ipv6Addr) -> bool {
(addr.segments()[0] & 0xffc0) == 0xfe80
}

/// Given a listen/bind address, finds all the local addresses for that address family.
pub(crate) fn find_local_addresses(listen_addrs: &[SocketAddr]) -> Result<Vec<SocketAddr>> {
debug!("find_local_address: {:?}", listen_addrs);

let mut addrs = Vec::new();
let mut local_addrs = None;

for addr in listen_addrs {
if addr.ip().is_unspecified() {
// Find all the local addresses for this address family.
if local_addrs.is_none() {
local_addrs = Some(LocalAddresses::new());
debug!("found local addresses: {:?}", local_addrs);
}
let local_addrs = local_addrs.as_ref().unwrap();
let port = addr.port();

match addr.ip() {
IpAddr::V4(_) => {
addrs.extend(
local_addrs
.regular
.iter()
.chain(local_addrs.loopback.iter())
.filter(|a| a.is_ipv4())
.map(|a| SocketAddr::new(*a, port)),
);
}
IpAddr::V6(_) => {
addrs.extend(
local_addrs
.regular
.iter()
.chain(local_addrs.loopback.iter())
.filter(|a| a.is_ipv6())
.map(|a| SocketAddr::new(*a, port)),
);
}
}
} else {
addrs.push(*addr);
}
}
// we might have added duplicates, make sure to remove them
addrs.sort();
addrs.dedup();

ensure!(!addrs.is_empty(), "No local addresses found");

Ok(addrs)
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
37 changes: 22 additions & 15 deletions src/provider/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ use walkdir::WalkDir;
use crate::blobs::Collection;
use crate::hp::cfg::Endpoint;
use crate::hp::derp::DerpMap;
use crate::net::ip::find_local_addresses;
use crate::protocol::{
read_lp, write_lp, Closed, GetRequest, Handshake, RangeSpec, Request, VERSION,
};
Expand Down Expand Up @@ -524,7 +523,7 @@ impl Provider {
/// The address on which the provider socket is bound.
///
/// Note that this could be an unspecified address, if you need an address on which you
/// can contact the provider consider using [`Provider::listen_addresses`]. However the
/// can contact the provider consider using [`Provider::local_endpoint_addresses`]. However the
/// port will always be the concrete port.
pub fn local_address(&self) -> Result<Vec<SocketAddr>> {
self.inner.local_address()
Expand All @@ -535,11 +534,9 @@ impl Provider {
self.inner.local_endpoints().await
}

/// Returns all addresses on which the provider is reachable.
///
/// This will never be empty.
pub fn listen_addresses(&self) -> Result<Vec<SocketAddr>> {
self.inner.listen_addresses()
/// Convenience method to get just the addr part of [`Provider::local_endpoints`].
pub async fn local_endpoint_addresses(&self) -> Result<Vec<SocketAddr>> {
self.inner.local_endpoint_addresses().await
}

/// Returns the [`PeerId`] of the provider.
Expand All @@ -563,9 +560,9 @@ impl Provider {
/// Return a single token containing everything needed to get a hash.
///
/// See [`Ticket`] for more details of how it can be used.
pub fn ticket(&self, hash: Hash) -> Result<Ticket> {
pub async fn ticket(&self, hash: Hash) -> Result<Ticket> {
// TODO: Verify that the hash exists in the db?
let addrs = self.listen_addresses()?;
let addrs = self.local_endpoint_addresses().await?;
Ticket::new(hash, self.peer_id(), addrs)
}

Expand All @@ -591,6 +588,11 @@ impl ProviderInner {
self.conn.local_endpoints().await
}

async fn local_endpoint_addresses(&self) -> Result<Vec<SocketAddr>> {
let endpoints = self.local_endpoints().await?;
Ok(endpoints.into_iter().map(|x| x.addr).collect())
}

fn local_address(&self) -> Result<Vec<SocketAddr>> {
let (v4, v6) = self.conn.local_addr()?;
let mut addrs = vec![v4];
Expand All @@ -599,9 +601,6 @@ impl ProviderInner {
}
Ok(addrs)
}
fn listen_addresses(&self) -> Result<Vec<SocketAddr>> {
find_local_addresses(&self.local_address()?)
}
}

/// The future completes when the spawned tokio task finishes.
Expand Down Expand Up @@ -680,13 +679,21 @@ impl RpcHandler {
async fn id(self, _: IdRequest) -> IdResponse {
IdResponse {
peer_id: Box::new(self.inner.keypair.public().into()),
listen_addrs: self.inner.listen_addresses().unwrap_or_default(),
listen_addrs: self
.inner
.local_endpoint_addresses()
.await
.unwrap_or_default(),
version: env!("CARGO_PKG_VERSION").to_string(),
}
}
async fn addrs(self, _: AddrsRequest) -> AddrsResponse {
AddrsResponse {
addrs: self.inner.listen_addresses().unwrap_or_default(),
addrs: self
.inner
.local_endpoint_addresses()
.await
.unwrap_or_default(),
}
}
async fn shutdown(self, request: ShutdownRequest) {
Expand Down Expand Up @@ -1358,7 +1365,7 @@ mod tests {
.await
.unwrap();
let _drop_guard = provider.cancel_token().drop_guard();
let ticket = provider.ticket(hash).unwrap();
let ticket = provider.ticket(hash).await.unwrap();
println!("addrs: {:?}", ticket.addrs());
assert!(!ticket.addrs().is_empty());
}
Expand Down
Loading

0 comments on commit c4a1890

Please sign in to comment.