diff --git a/src/tracing/error.rs b/src/tracing/error.rs index 0ec8c429f..757da29ca 100644 --- a/src/tracing/error.rs +++ b/src/tracing/error.rs @@ -23,4 +23,6 @@ pub enum TracerError { AddressNotAvailable(SocketAddr), #[error("invalid source IP address: {0}")] InvalidSourceAddr(IpAddr), + #[error("error: {0}")] + ErrorString(String), } diff --git a/src/tracing/net/ipv4.rs b/src/tracing/net/ipv4.rs index 9d05fc29f..4651e9ebd 100644 --- a/src/tracing/net/ipv4.rs +++ b/src/tracing/net/ipv4.rs @@ -154,7 +154,8 @@ pub fn dispatch_tcp_probe( PortDirection::FixedDest(dest_port) => (probe.sequence.0, dest_port.0), PortDirection::FixedBoth(_, _) | PortDirection::None => unimplemented!(), }; - let socket = platform::make_stream_socket_ipv4()?; + #[allow(unused_mut)] + let mut socket = platform::make_stream_socket_ipv4()?; let local_addr = SocketAddr::new(IpAddr::V4(src_addr), src_port); socket.bind(local_addr)?; socket.set_ttl(u32::from(probe.ttl.0))?; @@ -180,6 +181,7 @@ pub fn dispatch_tcp_probe( Ok(socket) } +#[cfg(unix)] pub fn recv_icmp_probe( recv_socket: &mut Socket, protocol: TracerProtocol, @@ -204,6 +206,20 @@ pub fn recv_icmp_probe( } } +#[cfg(windows)] +pub fn recv_icmp_probe( + recv_socket: &mut Socket, + protocol: TracerProtocol, + multipath_strategy: MultipathStrategy, + direction: PortDirection, +) -> TraceResult> { + let bytes = &recv_socket.buf_bytes(); + let ipv4 = Ipv4Packet::new_view(bytes).req()?; + // post the WSARecvFrom again, so that the next OVERLAPPED event can get triggered + recv_socket.recv_from()?; + extract_probe_resp(protocol, multipath_strategy, direction, &ipv4) +} + pub fn recv_tcp_socket( tcp_socket: &Socket, sequence: Sequence, diff --git a/src/tracing/net/ipv6.rs b/src/tracing/net/ipv6.rs index 99a238177..af8da164f 100644 --- a/src/tracing/net/ipv6.rs +++ b/src/tracing/net/ipv6.rs @@ -109,11 +109,12 @@ pub fn dispatch_tcp_probe( PortDirection::FixedDest(dest_port) => (probe.sequence.0, dest_port.0), PortDirection::FixedBoth(_, _) | PortDirection::None => unimplemented!(), }; - let socket = platform::make_stream_socket_ipv6()?; let local_addr = SocketAddr::new(IpAddr::V6(src_addr), src_port); + let remote_addr = SocketAddr::new(IpAddr::V6(dest_addr), dest_port); + #[allow(unused_mut)] + let mut socket = platform::make_stream_socket_ipv6()?; socket.bind(local_addr)?; socket.set_unicast_hops_v6(probe.ttl.0)?; - let remote_addr = SocketAddr::new(IpAddr::V6(dest_addr), dest_port); match socket.connect(remote_addr) { Ok(_) => {} Err(err) => { @@ -134,6 +135,7 @@ pub fn dispatch_tcp_probe( Ok(socket) } +#[cfg(unix)] pub fn recv_icmp_probe( recv_socket: &mut Socket, protocol: TracerProtocol, @@ -160,6 +162,25 @@ pub fn recv_icmp_probe( } } +#[cfg(windows)] +#[allow(unsafe_code)] +pub fn recv_icmp_probe( + recv_socket: &mut Socket, + protocol: TracerProtocol, + direction: PortDirection, +) -> TraceResult> { + let bytes = &recv_socket.buf_bytes(); + let icmp_v6 = IcmpPacket::new_view(bytes).req()?; + let addr = recv_socket.from()?; + // post the WSARecvFrom again, so that the next OVERLAPPED event can get triggered + recv_socket.recv_from()?; + if let IpAddr::V6(src_addr) = addr { + extract_probe_resp(protocol, direction, &icmp_v6, src_addr) + } else { + Err(TracerError::InvalidSourceAddr(addr)) + } +} + pub fn recv_tcp_socket( tcp_socket: &Socket, sequence: Sequence, diff --git a/src/tracing/net/platform/windows.rs b/src/tracing/net/platform/windows.rs index 59fc70597..32ccbfd06 100644 --- a/src/tracing/net/platform/windows.rs +++ b/src/tracing/net/platform/windows.rs @@ -1,216 +1,767 @@ use super::byte_order::PlatformIpv4FieldByteOrder; -use crate::tracing::error::TraceResult; -use std::io; -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; -use std::net::{Shutdown, SocketAddr}; +use crate::tracing::error::{TraceResult, TracerError}; +use crate::tracing::net::channel::MAX_PACKET_SIZE; +use core::convert; +use std::alloc::{alloc, dealloc, Layout}; +use std::ffi::c_void; +use std::fmt::{self}; +use std::io::{Error, ErrorKind, Result}; +use std::mem::MaybeUninit; +use std::mem::{align_of, size_of}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr, SocketAddrV4, SocketAddrV6}; use std::time::Duration; - -/// TODO -#[allow(clippy::unnecessary_wraps)] -pub fn for_address(_src_addr: IpAddr) -> TraceResult { - Ok(PlatformIpv4FieldByteOrder::Network) +use windows::core::PSTR; +use windows::Win32::Foundation::{ERROR_BUFFER_OVERFLOW, NO_ERROR, WAIT_FAILED, WAIT_TIMEOUT}; +use windows::Win32::NetworkManagement::IpHelper; +use windows::Win32::Networking::WinSock::{ + bind, closesocket, connect, getpeername, getsockopt, sendto, setsockopt, shutdown, socket, + WSACloseEvent, WSACreateEvent, WSAEventSelect, WSAGetOverlappedResult, WSAIoctl, WSARecvFrom, + WSAResetEvent, WSAStartup, ADDRESS_FAMILY, AF_INET, AF_INET6, FD_CONNECT, FD_WRITE, FIONBIO, + ICMP_ERROR_INFO, INVALID_SOCKET, IPPROTO, IPPROTO_ICMP, IPPROTO_ICMPV6, IPPROTO_IP, + IPPROTO_IPV6, IPPROTO_RAW, IPPROTO_TCP, IPPROTO_UDP, IPV6_UNICAST_HOPS, IP_HDRINCL, IP_TOS, + IP_TTL, SD_BOTH, SD_RECEIVE, SD_SEND, SIO_ROUTING_INTERFACE_QUERY, SOCKADDR_IN, SOCKADDR_IN6, + SOCKADDR_STORAGE, SOCKET, SOCKET_ERROR, SOCK_DGRAM, SOCK_RAW, SOCK_STREAM, SOL_SOCKET, + SO_ERROR, SO_PORT_SCALABILITY, TCP_FAIL_CONNECT_ON_ICMP_ERROR, TCP_ICMP_ERROR_INFO, WSABUF, + WSADATA, WSAECONNREFUSED, WSAEHOSTUNREACH, WSAEINPROGRESS, WSAEWOULDBLOCK, WSA_IO_INCOMPLETE, + WSA_IO_PENDING, +}; +use windows::Win32::System::Threading::WaitForSingleObject; +use windows::Win32::System::IO::OVERLAPPED; + +pub struct Socket { + s: SOCKET, + ol: Box, + wbuf: Box, + from: Box, } +impl Socket { + #[allow(unsafe_code)] + /// # Panics + /// + /// Will panic if `Layout` constructor fails to build a layout for `MAX_PACKET_SIZE` aligned on `WSABUF`. + fn create(af: ADDRESS_FAMILY, r#type: u16, protocol: IPPROTO) -> TraceResult { + let s = unsafe { socket(af.0.try_into().unwrap(), i32::from(r#type), protocol.0) }; + if s == INVALID_SOCKET { + return Err(TracerError::IoError(Error::last_os_error())); + } + let from = Box::::default(); + let layout = + Layout::from_size_align(MAX_PACKET_SIZE, std::mem::align_of::()).unwrap(); + let ptr = unsafe { alloc(layout) }; + let wbuf = Box::new(WSABUF { + len: MAX_PACKET_SIZE as u32, + buf: PSTR::from_raw(ptr), + }); + let ol = Box::::default(); + Ok(Self { s, ol, wbuf, from }) + } -#[allow(clippy::unnecessary_wraps)] -pub fn startup() -> TraceResult<()> { - Ok(()) -} + #[allow(unsafe_code)] + #[must_use] + pub fn buf_bytes(&self) -> Vec { + let buf = self.wbuf.buf.as_ptr(); + let slice = unsafe { std::slice::from_raw_parts(buf, self.wbuf.len as usize) }; + slice.to_owned() + } -/// TODO -pub fn lookup_interface_addr_ipv4(_name: &str) -> TraceResult { - unimplemented!() -} + pub fn from(&mut self) -> TraceResult { + sockaddrptr_to_ipaddr(std::ptr::addr_of_mut!(*self.from)) + } -/// TODO -pub fn lookup_interface_addr_ipv6(_name: &str) -> TraceResult { - unimplemented!() -} + #[allow(unsafe_code)] + fn create_event(&mut self) -> TraceResult<()> { + self.ol.hEvent = unsafe { WSACreateEvent() }; + if self.ol.hEvent.is_invalid() { + return Err(TracerError::IoError(Error::last_os_error())); + } + Ok(()) + } -/// TODO -pub fn make_icmp_send_socket_ipv4() -> TraceResult { - unimplemented!() -} + #[allow(unsafe_code)] + fn wait_for_event(&self, timeout: Duration) -> TraceResult { + let millis = timeout.as_millis() as u32; + let rc = unsafe { WaitForSingleObject(self.ol.hEvent, millis) }; + if rc == WAIT_TIMEOUT { + return Ok(false); + } else if rc == WAIT_FAILED { + return Err(TracerError::IoError(Error::last_os_error())); + } + Ok(true) + } -/// TODO -pub fn make_udp_send_socket_ipv4() -> TraceResult { - unimplemented!() -} + #[allow(unsafe_code)] + fn reset_event(&self) -> TraceResult<()> { + if !unsafe { WSAResetEvent(self.ol.hEvent) }.as_bool() { + return Err(TracerError::IoError(Error::last_os_error())); + } + Ok(()) + } -/// TODO -pub fn make_recv_socket_ipv4(_addr: Ipv4Addr) -> TraceResult { - unimplemented!() -} + pub fn udp_from(target: IpAddr) -> TraceResult { + let s = match target { + IpAddr::V4(_) => Self::create(AF_INET, SOCK_DGRAM, IPPROTO_UDP), + IpAddr::V6(_) => Self::create(AF_INET6, SOCK_DGRAM, IPPROTO_UDP), + }?; + Ok(s) + } -/// TODO -pub fn make_stream_socket_ipv4() -> TraceResult { - unimplemented!() -} + #[allow(unsafe_code)] + #[allow(clippy::cast_possible_wrap)] + pub fn bind(&mut self, source_socketaddr: SocketAddr) -> TraceResult<&Self> { + let (addr, addrlen) = socketaddr_to_sockaddr(source_socketaddr); + if unsafe { bind(self.s, std::ptr::addr_of!(addr).cast(), addrlen as i32) } == SOCKET_ERROR + { + return Err(TracerError::IoError(Error::last_os_error())); + } + self.create_event()?; + Ok(self) + } -/// TODO -pub fn make_udp_dgram_socket_ipv4() -> TraceResult { - unimplemented!() -} + #[allow(unsafe_code)] + #[allow(clippy::cast_possible_wrap)] + pub fn send_to(&self, packet: &[u8], dest_socketaddr: SocketAddr) -> TraceResult<()> { + let (addr, addrlen) = socketaddr_to_sockaddr(dest_socketaddr); + let rc = unsafe { + sendto( + self.s, + packet, + 0, + std::ptr::addr_of!(addr).cast(), + addrlen as i32, + ) + }; + if rc == SOCKET_ERROR { + return Err(TracerError::IoError(Error::last_os_error())); + } + Ok(()) + } -/// TODO -pub fn make_icmp_send_socket_ipv6() -> TraceResult { - unimplemented!() -} + #[allow(unsafe_code)] + pub fn close(&self) -> TraceResult<()> { + if unsafe { closesocket(self.s) } == SOCKET_ERROR { + return Err(TracerError::IoError(Error::last_os_error())); + } + Ok(()) + } -/// TODO -pub fn make_udp_send_socket_ipv6() -> TraceResult { - unimplemented!() -} + // NOTE FIONBIO is really unsigned (in WinSock2.h) + #[allow(clippy::cast_sign_loss)] + #[allow(unsafe_code)] + fn set_non_blocking(&self, is_non_blocking: bool) -> TraceResult<()> { + let non_blocking: u32 = u32::from(is_non_blocking); + if unsafe { + WSAIoctl( + self.s, + FIONBIO as u32, + Some(std::ptr::addr_of!(non_blocking).cast()), + size_of::().try_into().unwrap(), + None, + 0, + &mut 0, + None, + None, + ) + } == SOCKET_ERROR + { + return Err(TracerError::IoError(Error::last_os_error())); + } + Ok(()) + } -/// TODO -pub fn make_recv_socket_ipv6(_addr: Ipv6Addr) -> TraceResult { - unimplemented!() -} + fn setsockopt_bool(&self, level: i32, optname: i32, optval: bool) -> Result<()> { + self.setsockopt_u32(level, optname, u32::from(optval)) + } -/// TODO -pub fn make_udp_dgram_socket_ipv6() -> TraceResult { - unimplemented!() -} + #[allow(unsafe_code)] + fn setsockopt_u32(&self, level: i32, optname: i32, optval: u32) -> Result<()> { + let bytes_array = optval.to_ne_bytes(); + let bytes_slice_ref_option = Some(&bytes_array[..]); + if unsafe { setsockopt(self.s, level, optname, bytes_slice_ref_option) } == SOCKET_ERROR { + Err(Error::last_os_error()) + } else { + Ok(()) + } + } -/// TODO -pub fn make_stream_socket_ipv6() -> TraceResult { - unimplemented!() -} + #[allow(clippy::cast_possible_wrap)] + fn set_header_included(&self, is_header_included: bool) -> TraceResult<()> { + self.setsockopt_bool(IPPROTO_IP as _, IP_HDRINCL as _, is_header_included) + .map_err(TracerError::IoError) + } -/// TODO -pub fn is_readable(_sock: &Socket, _timeout: Duration) -> TraceResult { - unimplemented!() -} + #[allow(clippy::cast_possible_wrap)] + pub fn set_ttl(&self, ttl: u32) -> TraceResult<()> { + self.setsockopt_u32(IPPROTO_IP as _, IP_TTL as _, ttl) + .map_err(TracerError::IoError) + } -/// TODO -pub fn is_writable(_sock: &Socket) -> TraceResult { - unimplemented!() -} + #[allow(clippy::cast_possible_wrap)] + pub fn ttl(&self) -> Result { + self.getsockopt(IPPROTO_IP as _, IP_TTL as _) + } -/// TODO -pub fn is_not_in_progress_error(_code: i32) -> bool { - unimplemented!() -} + #[allow(clippy::cast_possible_wrap)] + pub fn set_tos(&self, tos: u32) -> TraceResult<()> { + self.setsockopt_u32(IPPROTO_IP as _, IP_TOS as _, tos) + .map_err(TracerError::IoError) + } -/// TODO -pub fn is_conn_refused_error(_code: i32) -> bool { - unimplemented!() -} + #[allow(clippy::cast_possible_wrap)] + fn set_reuse_port(&self, is_reuse_port: bool) -> TraceResult<()> { + self.setsockopt_bool(SOL_SOCKET as _, SO_PORT_SCALABILITY as _, is_reuse_port) + .map_err(TracerError::IoError) + } -#[must_use] -pub fn is_host_unreachable_error(_code: i32) -> bool { - unimplemented!() -} + #[allow(unsafe_code)] + #[allow(clippy::cast_possible_wrap)] + pub fn set_unicast_hops_v6(&self, max_hops: u8) -> TraceResult<&Self> { + if unsafe { + setsockopt( + self.s, + IPPROTO_IPV6.0, + IPV6_UNICAST_HOPS as i32, + Some(&[max_hops]), + ) + } == SOCKET_ERROR + { + return Err(TracerError::IoError(Error::last_os_error())); + } + Ok(self) + } -/// A network socket. -#[derive(Debug)] -pub struct Socket {} + #[allow(clippy::cast_possible_wrap)] + pub fn unicast_hops_v6(&self) -> Result { + self.getsockopt(IPPROTO_IPV6.0, IPV6_UNICAST_HOPS as _) + } -#[allow(clippy::unused_self)] -impl Socket { - /// TODO - #[allow(dead_code)] - pub fn new(_domain: (), _ty: (), _protocol: Option<()>) -> io::Result { - unimplemented!() + #[allow(unsafe_code)] + #[allow(clippy::cast_possible_wrap)] + pub fn recv_from(&mut self) -> TraceResult<()> { + let mut fromlen = std::mem::size_of::() as i32; + let ret = unsafe { + WSARecvFrom( + self.s, + &[*self.wbuf], + Some(&mut 0), + &mut 0, + Some(std::ptr::addr_of_mut!(*self.from).cast()), + Some(&mut fromlen), + Some(&mut *self.ol), + None, + ) + }; + if ret == SOCKET_ERROR { + if Error::last_os_error().raw_os_error() != Some(WSA_IO_PENDING.0) { + return Err(TracerError::IoError(Error::last_os_error())); + } + } else { + // TODO no need to wait for an event, recv succeeded immediately! This should be handled + } + Ok(()) } - /// TODO - pub fn bind(&self, _address: SocketAddr) -> io::Result<()> { - unimplemented!() + #[allow(unsafe_code)] + fn get_overlapped_result(&self) -> TraceResult<(u32, u32)> { + let mut bytes = 0; + let mut flags = 0; + let ol = *self.ol; + if unsafe { + WSAGetOverlappedResult(self, std::ptr::addr_of!(ol), &mut bytes, false, &mut flags) + } + .as_bool() + { + return Ok((bytes, flags)); + } + Err(TracerError::IoError(Error::from(ErrorKind::Other))) } - /// TODO - pub fn set_tos(&self, _tos: u32) -> io::Result<()> { - unimplemented!() + #[allow(unsafe_code)] + #[allow(clippy::cast_possible_wrap)] + pub fn connect(&self, dest_socketaddr: SocketAddr) -> Result<()> { + self.set_fail_connect_on_icmp_error(true)?; + if unsafe { WSAEventSelect(self.s, self.ol.hEvent, (FD_CONNECT | FD_WRITE) as _) } + == SOCKET_ERROR + { + eprintln!("WSAEventSelect failed: {}", Error::last_os_error()); + return Err(Error::last_os_error()); + } + let (addr, addrlen) = socketaddr_to_sockaddr(dest_socketaddr); + let rc = unsafe { connect(self.s, std::ptr::addr_of!(addr).cast(), addrlen as i32) }; + if rc == SOCKET_ERROR { + if Error::last_os_error().raw_os_error() != Some(WSAEWOULDBLOCK.0) { + return Err(Error::last_os_error()); + } + } else { + // TODO + } + Ok(()) } - /// TODO - pub fn set_ttl(&self, _ttl: u32) -> io::Result<()> { - unimplemented!() + #[allow(clippy::cast_possible_wrap)] + pub fn take_error(&self) -> Result> { + match self.getsockopt(SOL_SOCKET as _, SO_ERROR as _) { + Ok(0) => Ok(None), + Ok(errno) => Ok(Some(Error::from_raw_os_error(errno))), + Err(e) => Err(e), + } } - /// TODO - #[allow(dead_code)] - pub fn set_reuse_port(&self, _reuse: bool) -> io::Result<()> { - unimplemented!() + #[allow(unsafe_code)] + #[allow(clippy::cast_possible_wrap)] + pub fn icmp_error_info(&self) -> Result { + let icmp_error_info = + self.getsockopt::(IPPROTO_TCP.0 as _, TCP_ICMP_ERROR_INFO as _)?; + let src_addr = icmp_error_info.srcaddress; + match ADDRESS_FAMILY(u32::from(unsafe { src_addr.si_family })) { + AF_INET => Ok(IpAddr::V4(Ipv4Addr::from(unsafe { + src_addr.Ipv4.sin_addr.S_un.S_addr.to_ne_bytes() + }))), + AF_INET6 => Ok(IpAddr::V6(Ipv6Addr::from(unsafe { + src_addr.Ipv6.sin6_addr.u.Byte + }))), + _ => Err(Error::from(ErrorKind::AddrNotAvailable)), + } } - /// TODO - #[allow(dead_code)] - pub fn set_header_included(&self, _included: bool) -> io::Result<()> { - unimplemented!() + #[allow(unsafe_code)] + #[allow(clippy::cast_possible_wrap)] + fn getsockopt(&self, level: i32, optname: i32) -> Result { + let mut optval: MaybeUninit = MaybeUninit::uninit(); + let mut optlen = size_of::() as i32; + if unsafe { + getsockopt( + self.s, + level, + optname, + PSTR::from_raw(optval.as_mut_ptr().cast()), + &mut optlen, + ) + } == SOCKET_ERROR + { + return Err(Error::last_os_error()); + } + Ok(unsafe { optval.assume_init() }) } - /// TODO - #[allow(dead_code)] - pub fn set_nonblocking(&self, _nonblocking: bool) -> io::Result<()> { - unimplemented!() + #[allow(unsafe_code)] + #[allow(clippy::cast_possible_wrap)] + pub fn peer_addr(&self) -> Result> { + let mut name: MaybeUninit = MaybeUninit::uninit(); + let mut namelen = size_of::() as i32; + if unsafe { getpeername(self.s, name.as_mut_ptr().cast(), &mut namelen) } == SOCKET_ERROR { + return Err(Error::last_os_error()); + } + Ok(Some(sockaddr_to_socketaddr(unsafe { + &name.assume_init() + })?)) } - /// TODO - pub fn set_unicast_hops_v6(&self, _hops: u8) -> io::Result<()> { - unimplemented!() + #[allow(unsafe_code)] + #[allow(clippy::cast_possible_wrap)] + pub fn shutdown(&self, how: Shutdown) -> Result<()> { + let how = match how { + Shutdown::Both => SD_BOTH, + Shutdown::Read => SD_RECEIVE, + Shutdown::Write => SD_SEND, + } as i32; + if unsafe { shutdown(self.s, how) } == SOCKET_ERROR { + return Err(Error::last_os_error()); + } + self.cleanup() } - /// TODO - pub fn connect(&self, _address: SocketAddr) -> io::Result<()> { - unimplemented!() + #[allow(unsafe_code)] + fn _is_writable_select(&self) -> TraceResult { + use windows::Win32::Networking::WinSock::{__WSAFDIsSet, select, FD_SET, TIMEVAL}; + let mut fds = FD_SET::default(); + let timeout = TIMEVAL::default(); + fds.fd_array[0] = self.s; + fds.fd_count = 1; + let rc = unsafe { select(1, None, Some(&mut fds), None, Some(&timeout)) }; + if rc == SOCKET_ERROR { + return Err(TracerError::IoError(Error::last_os_error())); + } + let fdisset = unsafe { __WSAFDIsSet(self.s, &mut fds) }; + Ok(fdisset != 0) } - /// TODO - pub fn send_to(&self, _buf: &[u8], _addr: SocketAddr) -> io::Result { - unimplemented!() + fn is_writable_overlapped(&self) -> TraceResult { + if !self.wait_for_event(Duration::ZERO)? { + return Ok(false); + }; + while self.get_overlapped_result().is_err() { + if Error::last_os_error().raw_os_error() != Some(WSA_IO_INCOMPLETE.0) { + return Err(TracerError::IoError(Error::last_os_error())); + } + } + self.reset_event()?; + Ok(true) } - /// TODO - pub fn recv_from(&self, _buf: &mut [u8]) -> io::Result<(usize, Option)> { - unimplemented!() + #[allow(unsafe_code)] + fn cleanup(&self) -> Result<()> { + let layout = + Layout::from_size_align(MAX_PACKET_SIZE, std::mem::align_of::()).unwrap(); + if unsafe { closesocket(self.s) } == SOCKET_ERROR { + return Err(Error::last_os_error()); + } + if !self.ol.hEvent.is_invalid() && unsafe { WSACloseEvent(self.ol.hEvent) } == false { + return Err(Error::last_os_error()); + } + unsafe { dealloc(self.wbuf.buf.as_ptr(), layout) }; + // TODO should we cleanup sock.from too? + Ok(()) + } + + #[allow(clippy::cast_possible_wrap)] + fn set_fail_connect_on_icmp_error(&self, enabled: bool) -> Result<()> { + self.setsockopt_bool(IPPROTO_TCP.0, TCP_FAIL_CONNECT_ON_ICMP_ERROR as _, enabled) + } +} + +impl fmt::Debug for Socket { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Socket").field("s", &self.s).finish() + } +} + +impl convert::From for SOCKET { + fn from(sock: Socket) -> Self { + sock.s } +} +impl convert::From<&Socket> for SOCKET { + fn from(sock: &Socket) -> Self { + sock.s + } +} +impl convert::From<&mut Socket> for SOCKET { + fn from(sock: &mut Socket) -> Self { + sock.s + } +} - pub fn read(&mut self, _buf: &mut [u8]) -> io::Result { - unimplemented!() +#[allow(unsafe_code)] +pub fn startup() -> TraceResult<()> { + const WINSOCK_VERSION: u16 = 0x202; // 2.2 + let mut wsd = MaybeUninit::::zeroed(); + let rc = unsafe { WSAStartup(WINSOCK_VERSION, wsd.as_mut_ptr()) }; + // extracts the WSDATA to ensure it gets dropped (it's not used ATM) + unsafe { wsd.assume_init() }; + if rc == 0 { + Ok(()) + } else { + Err(TracerError::IoError(Error::last_os_error())) } +} + +#[allow(clippy::unnecessary_wraps)] +pub fn for_address(_src_addr: IpAddr) -> TraceResult { + Ok(PlatformIpv4FieldByteOrder::Network) +} - /// TODO - pub fn shutdown(&self, _how: Shutdown) -> io::Result<()> { - unimplemented!() +/// # Panics +/// +/// Will panic if `FriendlyName` or `FistUnicastAddress.Address.lpSockaddr` raw pointer members of the `IP_ADAPTER_ADDRESSES_LH` +/// linked list structure are null or misaligned. +// inspired by +#[allow(unsafe_code)] +fn lookup_interface_addr(family: ADDRESS_FAMILY, name: &str) -> TraceResult { + // Max tries allowed to call `GetAdaptersAddresses` on a loop basis + const MAX_TRIES: usize = 3; + let flags = IpHelper::GAA_FLAG_SKIP_ANYCAST + | IpHelper::GAA_FLAG_SKIP_MULTICAST + | IpHelper::GAA_FLAG_SKIP_DNS_SERVER; + // Initial buffer size is 15k per + let mut buf_len: u32 = 15000; + let mut layout; + let mut list_ptr; + let mut ip_adapter_address; + let mut res; + let mut i = 0; + + loop { + layout = match Layout::from_size_align( + buf_len as usize, + align_of::(), + ) { + Ok(layout) => layout, + Err(e) => { + return Err(TracerError::ErrorString(format!( + "Could not compute layout for {} words: {}", + buf_len, e + ))) + } + }; + list_ptr = unsafe { alloc(layout) }; + if list_ptr.is_null() { + return Err(TracerError::ErrorString(format!( + "Could not allocate {} words for layout {:?}", + buf_len, layout + ))); + } + ip_adapter_address = list_ptr.cast(); + + res = unsafe { + IpHelper::GetAdaptersAddresses( + family, + flags, + None, + Some(ip_adapter_address), + &mut buf_len, + ) + }; + i += 1; + + if res != ERROR_BUFFER_OVERFLOW.0 || i > MAX_TRIES { + break; + } + + unsafe { dealloc(list_ptr, layout) }; } - /// TODO - pub fn local_addr(&self) -> io::Result> { - unimplemented!() + if res != NO_ERROR.0 { + return Err(TracerError::ErrorString(format!( + "GetAdaptersAddresses returned error: {}", + Error::from_raw_os_error(res.try_into().unwrap()) + ))); } - /// TODO - #[allow(dead_code)] - pub fn as_raw_fd(&self) { - unimplemented!() + while !ip_adapter_address.is_null() { + let friendly_name = unsafe { (*ip_adapter_address).FriendlyName.to_string().unwrap() }; + if name == friendly_name { + // NOTE this really should be a while over the linked list of FistUnicastAddress, and current_unicast would then be mutable + // however, this is not supported by our function signature + let current_unicast = unsafe { (*ip_adapter_address).FirstUnicastAddress }; + // while !current_unicast.is_null() { + unsafe { + let socket_address = (*current_unicast).Address; + // let sockaddr = socket_address.lpSockaddr.as_ref().unwrap(); + let sockaddr = socket_address.lpSockaddr; + let ip_addr = sockaddrptr_to_ipaddr(sockaddr.cast()); + dealloc(list_ptr, layout); + return ip_addr; + } + // current_unicast = unsafe { (*current_unicast).Next }; + // } + } + ip_adapter_address = unsafe { (*ip_adapter_address).Next }; } - /// TODO - #[allow(dead_code)] - pub fn unicast_hops_v6(&self) -> io::Result { - unimplemented!() + unsafe { + dealloc(list_ptr, layout); } - /// TODO - pub fn peer_addr(&self) -> io::Result> { - unimplemented!() + Err(TracerError::UnknownInterface(format!( + "could not find address for {}", + name + ))) +} + +#[allow(unsafe_code)] +fn sockaddrptr_to_ipaddr(sockaddr: *mut SOCKADDR_STORAGE) -> TraceResult { + match sockaddr_to_socketaddr(unsafe { sockaddr.as_ref().unwrap() }) { + Err(e) => Err(TracerError::IoError(e)), + Ok(socketaddr) => match socketaddr { + SocketAddr::V4(socketaddrv4) => Ok(IpAddr::V4(*socketaddrv4.ip())), + SocketAddr::V6(socketaddrv6) => Ok(IpAddr::V6(*socketaddrv6.ip())), + }, } +} - /// TODO - pub fn take_error(&self) -> io::Result> { - unimplemented!() +#[allow(unsafe_code)] +pub fn sockaddr_to_socketaddr(sockaddr: &SOCKADDR_STORAGE) -> Result { + let ptr = sockaddr as *const SOCKADDR_STORAGE; + let af = u32::from(sockaddr.ss_family); + if af == AF_INET.0 { + let sockaddr_in_ptr = ptr.cast::(); + let sockaddr_in = unsafe { *sockaddr_in_ptr }; + let ipv4addr = sockaddr_in.sin_addr; + let port = sockaddr_in.sin_port; + Ok(SocketAddr::V4(SocketAddrV4::new( + Ipv4Addr::from(ipv4addr), + port, + ))) + } else if af == AF_INET6.0 { + #[allow(clippy::cast_ptr_alignment)] + let sockaddr_in6_ptr = ptr.cast::(); + let sockaddr_in6 = unsafe { *sockaddr_in6_ptr }; + let ipv6addr = sockaddr_in6.sin6_addr; + let port = sockaddr_in6.sin6_port; + Ok(SocketAddr::V6(SocketAddrV6::new( + Ipv6Addr::from(ipv6addr), + port, + sockaddr_in6.sin6_flowinfo, + unsafe { sockaddr_in6.Anonymous.sin6_scope_id }, + ))) + } else { + Err(Error::new( + ErrorKind::Unsupported, + format!("Unsupported address family: {af}"), + )) } +} - /// TODO - #[allow(clippy::unused_self)] - pub fn icmp_error_info(&self) -> io::Result { - unimplemented!() +#[allow(unsafe_code)] +#[must_use] +// TODO this allocate a SOCKADDR_STORAGE, should we drop it manually later? +fn socketaddr_to_sockaddr(socketaddr: SocketAddr) -> (SOCKADDR_STORAGE, u32) { + let (paddr, addrlen): (*const SOCKADDR_STORAGE, u32) = match socketaddr { + SocketAddr::V4(socketaddrv4) => { + let sa: SOCKADDR_IN = socketaddrv4.into(); + ( + std::ptr::addr_of!(sa).cast(), + size_of::() as u32, + ) + } + SocketAddr::V6(socketaddrv6) => { + let sa: SOCKADDR_IN6 = socketaddrv6.into(); + ( + std::ptr::addr_of!(sa).cast(), + size_of::() as u32, + ) + } + }; + (unsafe { *paddr }, addrlen) +} + +pub fn lookup_interface_addr_ipv4(name: &str) -> TraceResult { + lookup_interface_addr(AF_INET, name) +} + +pub fn lookup_interface_addr_ipv6(name: &str) -> TraceResult { + lookup_interface_addr(AF_INET6, name) +} + +#[allow(unsafe_code)] +pub fn routing_interface_query(target: IpAddr) -> TraceResult { + let src: *mut c_void = [0; 1024].as_mut_ptr().cast(); + let bytes = MaybeUninit::::uninit().as_mut_ptr(); + let s = Socket::udp_from(target)?; + let (dest, destlen) = socketaddr_to_sockaddr(SocketAddr::new(target, 0)); + let rc = unsafe { + WSAIoctl( + s, + SIO_ROUTING_INTERFACE_QUERY, + Some(std::ptr::addr_of!(dest).cast()), + destlen, + Some(src), + 1024, + bytes, + None, + None, + ) + }; + if rc == SOCKET_ERROR { + eprintln!( + "routing_interface_query: WSAIoctl failed: {}", + Error::last_os_error() + ); + return Err(TracerError::IoError(Error::last_os_error())); } + + // Note that the WSAIoctl call potentially returns multiple results (see + // ), + // TBD We choose the first one arbitrarily. + let sockaddr = src.cast::(); + sockaddrptr_to_ipaddr(sockaddr) +} + +pub fn make_icmp_send_socket_ipv4() -> TraceResult { + let sock = Socket::create(AF_INET, SOCK_RAW, IPPROTO_RAW)?; + sock.set_non_blocking(true)?; + sock.set_header_included(true)?; + Ok(sock) +} + +pub fn make_udp_send_socket_ipv4() -> TraceResult { + let sock = Socket::create(AF_INET, SOCK_RAW, IPPROTO_RAW)?; + sock.set_non_blocking(true)?; + sock.set_header_included(true)?; + Ok(sock) +} + +pub fn make_recv_socket_ipv4(src_addr: Ipv4Addr) -> TraceResult { + let mut sock = Socket::create(AF_INET, SOCK_RAW, IPPROTO_ICMP)?; + sock.bind(SocketAddr::new(IpAddr::V4(src_addr), 0))?; + sock.recv_from()?; + sock.set_non_blocking(true)?; + sock.set_header_included(true)?; + Ok(sock) } -impl io::Read for Socket { - fn read(&mut self, _buf: &mut [u8]) -> io::Result { - unimplemented!() +pub fn make_icmp_send_socket_ipv6() -> TraceResult { + let sock = Socket::create(AF_INET6, SOCK_RAW, IPPROTO_ICMPV6)?; + sock.set_non_blocking(true)?; + Ok(sock) +} + +pub fn make_udp_send_socket_ipv6() -> TraceResult { + let sock = Socket::create(AF_INET6, SOCK_RAW, IPPROTO_UDP)?; + sock.set_non_blocking(true)?; + Ok(sock) +} + +pub fn make_recv_socket_ipv6(src_addr: Ipv6Addr) -> TraceResult { + let mut sock = Socket::create(AF_INET6, SOCK_RAW, IPPROTO_ICMPV6)?; + sock.bind(SocketAddr::new(IpAddr::V6(src_addr), 0))?; + sock.recv_from()?; + sock.set_non_blocking(true)?; + Ok(sock) +} + +pub fn make_stream_socket_ipv4() -> TraceResult { + let sock = Socket::create(AF_INET, SOCK_STREAM, IPPROTO_TCP)?; + sock.set_non_blocking(true)?; + sock.set_reuse_port(true)?; + Ok(sock) +} + +#[allow(dead_code)] +pub fn make_udp_dgram_socket_ipv4() -> TraceResult { + Socket::create(AF_INET, SOCK_DGRAM, IPPROTO_UDP) +} + +#[allow(dead_code)] +pub fn make_udp_dgram_socket_ipv6() -> TraceResult { + Socket::create(AF_INET6, SOCK_DGRAM, IPPROTO_UDP) +} + +pub fn make_stream_socket_ipv6() -> TraceResult { + let sock = Socket::create(AF_INET6, SOCK_STREAM, IPPROTO_TCP)?; + sock.set_non_blocking(true)?; + sock.set_reuse_port(true)?; + Ok(sock) +} + +pub fn is_readable(sock: &Socket, timeout: Duration) -> TraceResult { + if !sock.wait_for_event(timeout)? { + return Ok(false); + }; + while sock.get_overlapped_result().is_err() { + if Error::last_os_error().raw_os_error() != Some(WSA_IO_INCOMPLETE.0) { + return Err(TracerError::IoError(Error::last_os_error())); + } } + sock.reset_event()?; + Ok(true) +} + +pub fn is_writable(sock: &Socket) -> TraceResult { + sock.is_writable_overlapped() +} + +#[must_use] +pub fn is_not_in_progress_error(code: i32) -> bool { + code != WSAEINPROGRESS.0 +} + +#[must_use] +pub fn is_conn_refused_error(code: i32) -> bool { + code == WSAECONNREFUSED.0 +} + +#[must_use] +pub fn is_host_unreachable_error(code: i32) -> bool { + code == WSAEHOSTUNREACH.0 } diff --git a/src/tracing/net/source.rs b/src/tracing/net/source.rs index 5d0b30cb1..05e2f97d4 100644 --- a/src/tracing/net/source.rs +++ b/src/tracing/net/source.rs @@ -1,10 +1,18 @@ use crate::tracing::error::TraceResult; +#[cfg(windows)] +use crate::tracing::error::TracerError; +#[cfg(unix)] use crate::tracing::error::TracerError::InvalidSourceAddr; use crate::tracing::net::platform; +#[cfg(windows)] +use crate::tracing::net::platform::routing_interface_query; use crate::tracing::net::platform::Socket; use crate::tracing::types::Port; +#[cfg(unix)] use crate::tracing::util::Required; use crate::tracing::PortDirection; +#[cfg(windows)] +use std::io::Error; use std::net::{IpAddr, SocketAddr}; /// The port used for local address discovery if not dest port is available. @@ -13,6 +21,8 @@ const DISCOVERY_PORT: Port = Port(80); /// Discover or validate a source address. pub struct SourceAddr; +// TODO remove platform specific code from here + impl SourceAddr { /// Discover the source `IpAddr`. pub fn discover( @@ -28,25 +38,54 @@ impl SourceAddr { } /// Validate that we can bind to the source address. + #[cfg(unix)] pub fn validate(source_addr: IpAddr) -> TraceResult { + #[cfg(unix)] let socket = udp_socket_for_addr_family(source_addr)?; + #[cfg(windows)] + let mut socket = udp_socket_for_addr_family(source_addr)?; let sock_addr = SocketAddr::new(source_addr, 0); match socket.bind(sock_addr) { Ok(_) => Ok(source_addr), Err(_) => Err(InvalidSourceAddr(sock_addr.ip())), } } + + #[cfg(windows)] + pub fn validate(source_addr: IpAddr) -> TraceResult { + match Socket::udp_from(source_addr)?.bind(SocketAddr::new(source_addr, 0)) { + Err(_) => Err(TracerError::IoError(Error::last_os_error())), + Ok(s) => { + s.close()?; + Ok(source_addr) + } + } + } } /// Discover the local `IpAddr` that will be used to communicate with the given target `IpAddr`. /// /// Note that no packets are transmitted by this method. +#[cfg(unix)] fn discover_local_addr(target_addr: IpAddr, port: u16) -> TraceResult { let socket = udp_socket_for_addr_family(target_addr)?; socket.connect(SocketAddr::new(target_addr, port))?; Ok(socket.local_addr()?.req()?.ip()) } +#[cfg(windows)] +fn discover_local_addr(target: IpAddr, _port: u16) -> TraceResult { + /* + NOTE under Windows, we cannot use a blind connect/getsockname as "If the socket + is using a connectionless protocol, the address may not be available until I/O + occurs on the socket." + We use SIO_ROUTING_INTERFACE_QUERY instead. + */ + + routing_interface_query(target) +} + +#[cfg(unix)] /// Create a socket suitable for a given address. fn udp_socket_for_addr_family(addr: IpAddr) -> TraceResult { Ok(match addr { diff --git a/src/tracing/probe.rs b/src/tracing/probe.rs index 7b761dd2d..bef164313 100644 --- a/src/tracing/probe.rs +++ b/src/tracing/probe.rs @@ -117,6 +117,7 @@ pub enum ProbeResponse { EchoReply(ProbeResponseData), TcpReply(ProbeResponseData), TcpRefused(ProbeResponseData), + TcpTimeExceeded(ProbeResponseData), } /// The data in the probe response. diff --git a/src/tracing/tracer.rs b/src/tracing/tracer.rs index d29dd2323..5f04ec9b5 100644 --- a/src/tracing/tracer.rs +++ b/src/tracing/tracer.rs @@ -152,7 +152,7 @@ impl)> Tracer { fn recv_response(&self, network: &mut N, st: &mut TracerState) -> TraceResult<()> { let next = network.recv_probe()?; match next { - Some(ProbeResponse::TimeExceeded(data)) => { + Some(ProbeResponse::TimeExceeded(data) | ProbeResponse::TcpTimeExceeded(data)) => { let sequence = Sequence(data.sequence); let received = data.recv; let host = data.addr;