diff --git a/Cargo.toml b/Cargo.toml index 245aa277f8..fcde338e71 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -55,6 +55,9 @@ tokio-util = { version = "0.3", features = ["codec"] } tower-util = "0.3" url = "1.0" +[target.'cfg(any(target_os = "linux", target_os = "macos"))'.dev-dependencies] +pnet = "0.25.0" + [features] default = [ "runtime", diff --git a/src/client/connect/dns.rs b/src/client/connect/dns.rs index 613579174a..39ffc71e34 100644 --- a/src/client/connect/dns.rs +++ b/src/client/connect/dns.rs @@ -200,27 +200,33 @@ impl IpAddrs { None } - pub(super) fn split_by_preference(self, local_addr: Option) -> (IpAddrs, IpAddrs) { - if let Some(local_addr) = local_addr { - let preferred = self - .iter - .filter(|addr| addr.is_ipv6() == local_addr.is_ipv6()) - .collect(); - - (IpAddrs::new(preferred), IpAddrs::new(vec![])) - } else { - let preferring_v6 = self - .iter - .as_slice() - .first() - .map(SocketAddr::is_ipv6) - .unwrap_or(false); - - let (preferred, fallback) = self - .iter - .partition::, _>(|addr| addr.is_ipv6() == preferring_v6); - - (IpAddrs::new(preferred), IpAddrs::new(fallback)) + #[inline] + fn filter(self, predicate: impl FnMut(&SocketAddr) -> bool) -> IpAddrs { + IpAddrs::new(self.iter.filter(predicate).collect()) + } + + pub(super) fn split_by_preference( + self, + local_addr_ipv4: Option, + local_addr_ipv6: Option, + ) -> (IpAddrs, IpAddrs) { + match (local_addr_ipv4, local_addr_ipv6) { + (Some(_), None) => (self.filter(SocketAddr::is_ipv4), IpAddrs::new(vec![])), + (None, Some(_)) => (self.filter(SocketAddr::is_ipv6), IpAddrs::new(vec![])), + _ => { + let preferring_v6 = self + .iter + .as_slice() + .first() + .map(SocketAddr::is_ipv6) + .unwrap_or(false); + + let (preferred, fallback) = self + .iter + .partition::, _>(|addr| addr.is_ipv6() == preferring_v6); + + (IpAddrs::new(preferred), IpAddrs::new(fallback)) + } } } @@ -355,34 +361,50 @@ mod tests { #[test] fn test_ip_addrs_split_by_preference() { - let v4_addr = (Ipv4Addr::new(127, 0, 0, 1), 80).into(); - let v6_addr = (Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), 80).into(); + let ip_v4 = Ipv4Addr::new(127, 0, 0, 1); + let ip_v6 = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1); + let v4_addr = (ip_v4, 80).into(); + let v6_addr = (ip_v6, 80).into(); + + let (mut preferred, mut fallback) = IpAddrs { + iter: vec![v4_addr, v6_addr].into_iter(), + } + .split_by_preference(None, None); + assert!(preferred.next().unwrap().is_ipv4()); + assert!(fallback.next().unwrap().is_ipv6()); + + let (mut preferred, mut fallback) = IpAddrs { + iter: vec![v6_addr, v4_addr].into_iter(), + } + .split_by_preference(None, None); + assert!(preferred.next().unwrap().is_ipv6()); + assert!(fallback.next().unwrap().is_ipv4()); let (mut preferred, mut fallback) = IpAddrs { iter: vec![v4_addr, v6_addr].into_iter(), } - .split_by_preference(None); + .split_by_preference(Some(ip_v4), Some(ip_v6)); assert!(preferred.next().unwrap().is_ipv4()); assert!(fallback.next().unwrap().is_ipv6()); let (mut preferred, mut fallback) = IpAddrs { iter: vec![v6_addr, v4_addr].into_iter(), } - .split_by_preference(None); + .split_by_preference(Some(ip_v4), Some(ip_v6)); assert!(preferred.next().unwrap().is_ipv6()); assert!(fallback.next().unwrap().is_ipv4()); let (mut preferred, fallback) = IpAddrs { iter: vec![v4_addr, v6_addr].into_iter(), } - .split_by_preference(Some(v4_addr.ip())); + .split_by_preference(Some(ip_v4), None); assert!(preferred.next().unwrap().is_ipv4()); assert!(fallback.is_empty()); let (mut preferred, fallback) = IpAddrs { iter: vec![v4_addr, v6_addr].into_iter(), } - .split_by_preference(Some(v6_addr.ip())); + .split_by_preference(None, Some(ip_v6)); assert!(preferred.next().unwrap().is_ipv6()); assert!(fallback.is_empty()); } diff --git a/src/client/connect/http.rs b/src/client/connect/http.rs index d61dce3a6a..c1cdf4e129 100644 --- a/src/client/connect/http.rs +++ b/src/client/connect/http.rs @@ -3,7 +3,7 @@ use std::fmt; use std::future::Future; use std::io; use std::marker::PhantomData; -use std::net::{IpAddr, SocketAddr}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::pin::Pin; use std::sync::Arc; use std::task::{self, Poll}; @@ -72,7 +72,8 @@ struct Config { enforce_http: bool, happy_eyeballs_timeout: Option, keep_alive_timeout: Option, - local_address: Option, + local_address_ipv4: Option, + local_address_ipv6: Option, nodelay: bool, reuse_address: bool, send_buffer_size: Option, @@ -111,7 +112,8 @@ impl HttpConnector { enforce_http: true, happy_eyeballs_timeout: Some(Duration::from_millis(300)), keep_alive_timeout: None, - local_address: None, + local_address_ipv4: None, + local_address_ipv6: None, nodelay: false, reuse_address: false, send_buffer_size: None, @@ -166,7 +168,26 @@ impl HttpConnector { /// Default is `None`. #[inline] pub fn set_local_address(&mut self, addr: Option) { - self.config_mut().local_address = addr; + let (v4, v6) = match addr { + Some(IpAddr::V4(a)) => (Some(a), None), + Some(IpAddr::V6(a)) => (None, Some(a)), + _ => (None, None), + }; + + let cfg = self.config_mut(); + + cfg.local_address_ipv4 = v4; + cfg.local_address_ipv6 = v6; + } + + /// Set that all sockets are bound to the configured IPv4 or IPv6 address (depending on host's + /// preferences) before connection. + #[inline] + pub fn set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr) { + let cfg = self.config_mut(); + + cfg.local_address_ipv4 = Some(addr_ipv4); + cfg.local_address_ipv6 = Some(addr_ipv6); } /// Set the connect timeout. @@ -311,7 +332,8 @@ where }; let c = ConnectingTcp::new( - config.local_address, + config.local_address_ipv4, + config.local_address_ipv6, addrs, config.connect_timeout, config.happy_eyeballs_timeout, @@ -454,7 +476,8 @@ impl StdError for ConnectError { } struct ConnectingTcp { - local_addr: Option, + local_addr_ipv4: Option, + local_addr_ipv6: Option, preferred: ConnectingTcpRemote, fallback: Option, reuse_address: bool, @@ -462,17 +485,20 @@ struct ConnectingTcp { impl ConnectingTcp { fn new( - local_addr: Option, + local_addr_ipv4: Option, + local_addr_ipv6: Option, remote_addrs: dns::IpAddrs, connect_timeout: Option, fallback_timeout: Option, reuse_address: bool, ) -> ConnectingTcp { if let Some(fallback_timeout) = fallback_timeout { - let (preferred_addrs, fallback_addrs) = remote_addrs.split_by_preference(local_addr); + let (preferred_addrs, fallback_addrs) = + remote_addrs.split_by_preference(local_addr_ipv4, local_addr_ipv6); if fallback_addrs.is_empty() { return ConnectingTcp { - local_addr, + local_addr_ipv4, + local_addr_ipv6, preferred: ConnectingTcpRemote::new(preferred_addrs, connect_timeout), fallback: None, reuse_address, @@ -480,7 +506,8 @@ impl ConnectingTcp { } ConnectingTcp { - local_addr, + local_addr_ipv4, + local_addr_ipv6, preferred: ConnectingTcpRemote::new(preferred_addrs, connect_timeout), fallback: Some(ConnectingTcpFallback { delay: tokio::time::delay_for(fallback_timeout), @@ -490,7 +517,8 @@ impl ConnectingTcp { } } else { ConnectingTcp { - local_addr, + local_addr_ipv4, + local_addr_ipv6, preferred: ConnectingTcpRemote::new(remote_addrs, connect_timeout), fallback: None, reuse_address, @@ -523,13 +551,22 @@ impl ConnectingTcpRemote { impl ConnectingTcpRemote { async fn connect( &mut self, - local_addr: &Option, + local_addr_ipv4: &Option, + local_addr_ipv6: &Option, reuse_address: bool, ) -> io::Result { let mut err = None; for addr in &mut self.addrs { debug!("connecting to {}", addr); - match connect(&addr, local_addr, reuse_address, self.connect_timeout)?.await { + match connect( + &addr, + local_addr_ipv4, + local_addr_ipv6, + reuse_address, + self.connect_timeout, + )? + .await + { Ok(tcp) => { debug!("connected to {}", addr); return Ok(tcp); @@ -551,9 +588,38 @@ impl ConnectingTcpRemote { } } +fn bind_local_address( + socket: &socket2::Socket, + dst_addr: &SocketAddr, + local_addr_ipv4: &Option, + local_addr_ipv6: &Option, +) -> io::Result<()> { + match (*dst_addr, local_addr_ipv4, local_addr_ipv6) { + (SocketAddr::V4(_), Some(addr), _) => { + socket.bind(&SocketAddr::new(addr.clone().into(), 0).into())?; + } + (SocketAddr::V6(_), _, Some(addr)) => { + socket.bind(&SocketAddr::new(addr.clone().into(), 0).into())?; + } + _ => { + if cfg!(windows) { + // Windows requires a socket be bound before calling connect + let any: SocketAddr = match *dst_addr { + SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(), + SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(), + }; + socket.bind(&any.into())?; + } + } + } + + Ok(()) +} + fn connect( addr: &SocketAddr, - local_addr: &Option, + local_addr_ipv4: &Option, + local_addr_ipv6: &Option, reuse_address: bool, connect_timeout: Option, ) -> io::Result>> { @@ -568,17 +634,7 @@ fn connect( socket.set_reuse_address(true)?; } - if let Some(ref local_addr) = *local_addr { - // Caller has requested this socket be bound before calling connect - socket.bind(&SocketAddr::new(local_addr.clone(), 0).into())?; - } else if cfg!(windows) { - // Windows requires a socket be bound before calling connect - let any: SocketAddr = match *addr { - SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(), - SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(), - }; - socket.bind(&any.into())?; - } + bind_local_address(&socket, addr, local_addr_ipv4, local_addr_ipv6)?; let addr = *addr; @@ -600,17 +656,27 @@ fn connect( impl ConnectingTcp { async fn connect(mut self) -> io::Result { let Self { - ref local_addr, + ref local_addr_ipv4, + ref local_addr_ipv6, reuse_address, .. } = self; match self.fallback { - None => self.preferred.connect(local_addr, reuse_address).await, + None => { + self.preferred + .connect(local_addr_ipv4, local_addr_ipv6, reuse_address) + .await + } Some(mut fallback) => { - let preferred_fut = self.preferred.connect(local_addr, reuse_address); + let preferred_fut = + self.preferred + .connect(local_addr_ipv4, local_addr_ipv6, reuse_address); futures_util::pin_mut!(preferred_fut); - let fallback_fut = fallback.remote.connect(local_addr, reuse_address); + let fallback_fut = + fallback + .remote + .connect(local_addr_ipv4, local_addr_ipv6, reuse_address); futures_util::pin_mut!(fallback_fut); let (result, future) = @@ -666,6 +732,32 @@ mod tests { assert_eq!(&*err.msg, super::INVALID_NOT_HTTP); } + #[cfg(any(target_os = "linux", target_os = "macos"))] + fn get_local_ips() -> (Option, Option) { + use std::net::{IpAddr, TcpListener}; + + let mut ip_v4 = None; + let mut ip_v6 = None; + + let ips = pnet::datalink::interfaces() + .into_iter() + .flat_map(|i| i.ips.into_iter().map(|n| n.ip())); + + for ip in ips { + match ip { + IpAddr::V4(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v4 = Some(ip), + IpAddr::V6(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v6 = Some(ip), + _ => (), + } + + if ip_v4.is_some() && ip_v6.is_some() { + break; + } + } + + (ip_v4, ip_v6) + } + #[tokio::test] async fn test_errors_missing_scheme() { let dst = "example.domain".parse().unwrap(); @@ -676,6 +768,43 @@ mod tests { assert_eq!(&*err.msg, super::INVALID_MISSING_SCHEME); } + // NOTE: pnet crate that we use in this test doesn't compile on Windows + #[cfg(any(target_os = "linux", target_os = "macos"))] + #[tokio::test] + async fn local_address() { + use std::net::{IpAddr, TcpListener}; + + let (bind_ip_v4, bind_ip_v6) = get_local_ips(); + let server4 = TcpListener::bind("127.0.0.1:0").unwrap(); + let port = server4.local_addr().unwrap().port(); + let server6 = TcpListener::bind(&format!("[::1]:{}", port)).unwrap(); + + let assert_client_ip = |dst: String, server: TcpListener, expected_ip: IpAddr| async move { + let mut connector = HttpConnector::new(); + + match (bind_ip_v4, bind_ip_v6) { + (Some(v4), Some(v6)) => connector.set_local_addresses(v4, v6), + (Some(v4), None) => connector.set_local_address(Some(v4.into())), + (None, Some(v6)) => connector.set_local_address(Some(v6.into())), + _ => unreachable!(), + } + + connect(connector, dst.parse().unwrap()).await.unwrap(); + + let (_, client_addr) = server.accept().unwrap(); + + assert_eq!(client_addr.ip(), expected_ip); + }; + + if let Some(ip) = bind_ip_v4 { + assert_client_ip(format!("http://127.0.0.1:{}", port), server4, ip.into()).await; + } + + if let Some(ip) = bind_ip_v6 { + assert_client_ip(format!("http://[::1]:{}", port), server6, ip.into()).await; + } + } + #[test] #[cfg_attr(not(feature = "__internal_happy_eyeballs_tests"), ignore)] fn client_happy_eyeballs() { @@ -797,6 +926,7 @@ mod tests { .map(|host| (host.clone(), addr.port()).into()) .collect(); let connecting_tcp = ConnectingTcp::new( + None, None, dns::IpAddrs::new(addrs), None,