Skip to content

Commit

Permalink
ref: Move packet building into conn (#1016)
Browse files Browse the repository at this point in the history
As suggested by @ramfox in #830 (comment)
  • Loading branch information
rklaehn authored May 16, 2023
1 parent ee2bafb commit 3142912
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 265 deletions.
203 changes: 21 additions & 182 deletions src/hp/derp/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ impl<R: AsyncRead + Unpin> Client<R> {
/// Sends a packet to the node identified by `dstkey`
///
/// Errors if the packet is larger than [`super::MAX_PACKET_SIZE`]
pub async fn send(&self, dstkey: PublicKey, packet: Vec<Bytes>) -> Result<()> {
pub async fn send(&self, dstkey: PublicKey, packet: Bytes) -> Result<()> {
debug!("[DERP] -> {:?} ({}b)", dstkey, packet.len());

self.inner
Expand Down Expand Up @@ -304,7 +304,7 @@ impl<R: AsyncRead + Unpin> Client<R> {
#[derive(Debug)]
enum ClientWriterMessage {
/// Send a packet (addressed to the [`PublicKey`]) to the server
Packet((PublicKey, Vec<Bytes>)),
Packet((PublicKey, Bytes)),
/// Forward a packet from the src [`PublicKey`] to the dst [`PublicKey`] to the server
/// Should only be used for mesh clients.
FwdPacket((PublicKey, PublicKey, Bytes)),
Expand Down Expand Up @@ -345,7 +345,7 @@ impl<W: AsyncWrite + Unpin + Send + 'static> ClientWriter<W> {
Some(ClientWriterMessage::Packet((key, bytes))) => {
// TODO: the rate limiter is only used on this method, is it because it's the only method that
// theoretically sends a bunch of data, or is it an oversight? For example, the `forward_packet` method does not have a rate limiter, but _does_ have a timeout.
send_packets(&mut self.writer, &self.rate_limiter, key, &bytes).await?;
send_packet(&mut self.writer, &self.rate_limiter, key, &bytes).await?;
}
Some(ClientWriterMessage::FwdPacket((srckey, dstkey, bytes))) => {
tokio::time::timeout(
Expand Down Expand Up @@ -603,65 +603,30 @@ pub enum ReceivedMessage {
},
}

pub(crate) async fn send_packets<W: AsyncWrite + Unpin>(
pub(crate) async fn send_packet<W: AsyncWrite + Unpin>(
mut writer: W,
rate_limiter: &Option<RateLimiter>,
dstkey: PublicKey,
transmits: &[impl AsRef<[u8]>],
packet: &[u8],
) -> Result<()> {
let mut total_len: usize = 0;
// make sure no transmit is too big for a derp packet.
// also compute the total length of all transmits for logging.
for transmit in transmits {
let len = transmit.as_ref().len();
total_len += len;
ensure!(
len + PUBLIC_KEY_LENGTH <= MAX_PACKET_SIZE,
"packet too big: {}",
len
);
}
// disco packets must be sent as-is
if transmits.len() == 1
&& transmits[0]
.as_ref()
.starts_with(crate::hp::disco::MAGIC.as_bytes())
{
let packet = transmits[0].as_ref();
if let Some(rate_limiter) = rate_limiter {
let frame_len = PUBLIC_KEY_LENGTH + packet.len();
if rate_limiter.check_n(frame_len).is_err() {
tracing::warn!("dropping send: rate limit reached");
return Ok(());
}
}
write_frame(
&mut writer,
FrameType::SendPacket,
&[dstkey.as_bytes(), packet],
)
.await?;
} else {
tracing::trace!("send derp packets {} {}", transmits.len(), total_len);
const PAYLAOD_SIZE: usize = MAX_PACKET_SIZE - PUBLIC_KEY_LENGTH;
for packet in PacketizeIter::<_, PAYLAOD_SIZE>::new(transmits) {
// rate limit for each packet, but exit early if the rate limit is exceeded.
// it is unlikely to recover that quickly.
if let Some(rate_limiter) = rate_limiter {
let frame_len = PUBLIC_KEY_LENGTH + packet.len();
if rate_limiter.check_n(frame_len).is_err() {
tracing::warn!("dropping send: rate limit reached");
return Ok(());
}
}
write_frame(
&mut writer,
FrameType::SendPacket,
&[dstkey.as_bytes(), packet.as_ref()],
)
.await?;
ensure!(
packet.len() <= MAX_PACKET_SIZE,
"packet too big: {}",
packet.len()
);
let frame_len = PUBLIC_KEY_LENGTH + packet.len();
if let Some(rate_limiter) = rate_limiter {
if rate_limiter.check_n(frame_len).is_err() {
tracing::warn!("dropping send: rate limit reached");
return Ok(());
}
}
write_frame(
&mut writer,
FrameType::SendPacket,
&[dstkey.as_bytes(), packet],
)
.await?;
writer.flush().await?;
Ok(())
}
Expand Down Expand Up @@ -747,129 +712,3 @@ pub(crate) fn parse_recv_frame(frame: BytesMut) -> Result<(PublicKey, Bytes)> {
frame.freeze().slice(PUBLIC_KEY_LENGTH..),
))
}

/// Combines blobs into packets of at most MAX_PACKET_SIZE.
///
/// Each item in a packet has a little-endian 2-byte length prefix.
pub struct PacketizeIter<I: Iterator, const N: usize> {
iter: std::iter::Peekable<I>,
buffer: BytesMut,
}

impl<I: Iterator, const N: usize> PacketizeIter<I, N> {
/// Create a new new PacketizeIter from something that can be turned into an
/// iterator of slices, like a Vec<Bytes>.
pub fn new(iter: impl IntoIterator<IntoIter = I>) -> Self {
Self {
iter: iter.into_iter().peekable(),
buffer: BytesMut::with_capacity(N),
}
}
}

impl<I: Iterator, const N: usize> Iterator for PacketizeIter<I, N>
where
I::Item: AsRef<[u8]>,
{
type Item = Bytes;

fn next(&mut self) -> Option<Self::Item> {
use bytes::BufMut;
while let Some(next_bytes) = self.iter.peek() {
let next_bytes = next_bytes.as_ref();
assert!(next_bytes.len() + 2 <= N);
let next_length: u16 = next_bytes.len().try_into().expect("items < 64k size");
if self.buffer.len() + next_bytes.len() + 2 > N {
break;
}
self.buffer.put_u16_le(next_length);
self.buffer.put_slice(next_bytes);
self.iter.next();
}
if !self.buffer.is_empty() {
Some(self.buffer.split().freeze())
} else {
None
}
}
}

/// Splits a packet into its component items.
pub struct PacketSplitIter {
bytes: Bytes,
}

impl PacketSplitIter {
/// Create a new PacketSplitIter from a packet.
///
/// Returns an error if the packet is too big.
pub fn new(bytes: Bytes) -> Self {
Self { bytes }
}

#[cfg(test)]
pub fn split(packet: Bytes) -> std::io::Result<Vec<Bytes>> {
Self::new(packet).collect()
}

fn fail(&mut self) -> Option<std::io::Result<Bytes>> {
self.bytes.clear();
Some(Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"",
)))
}
}

impl Iterator for PacketSplitIter {
type Item = std::io::Result<Bytes>;

fn next(&mut self) -> Option<Self::Item> {
use bytes::Buf;
if self.bytes.has_remaining() {
if self.bytes.remaining() < 2 {
return self.fail();
}
let len = self.bytes.get_u16_le() as usize;
if self.bytes.remaining() < len {
return self.fail();
}
let item = self.bytes.split_to(len);
Some(Ok(item))
} else {
None
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_empty() {
let empty_vec: Vec<Bytes> = Vec::new();
let mut iter = PacketizeIter::<_, MAX_PACKET_SIZE>::new(empty_vec);
assert_eq!(None, iter.next());
}

#[test]
fn test_single_result() {
let single_vec = vec!["Hello"];
let iter = PacketizeIter::<_, MAX_PACKET_SIZE>::new(single_vec);
let result = iter.collect::<Vec<_>>();
assert_eq!(1, result.len());
assert_eq!(&[5, 0, b'H', b'e', b'l', b'l', b'o'], &result[0][..]);
}

#[test]
fn test_multiple_results() {
let spacer = vec![0u8; MAX_PACKET_SIZE - 10];
let multiple_vec = vec![&b"Hello"[..], &spacer, &b"World"[..]];
let iter = PacketizeIter::<_, MAX_PACKET_SIZE>::new(multiple_vec);
let result = iter.collect::<Vec<_>>();
assert_eq!(2, result.len());
assert_eq!(&[5, 0, b'H', b'e', b'l', b'l', b'o'], &result[0][..7]);
assert_eq!(&[5, 0, b'W', b'o', b'r', b'l', b'd'], &result[1][..]);
}
}
9 changes: 3 additions & 6 deletions src/hp/derp/client_conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -734,8 +734,6 @@ fn parse_send_packet(data: &[u8]) -> Result<(PublicKey, &[u8])> {

#[cfg(test)]
mod tests {
use crate::hp::derp::client::PacketSplitIter;

use super::*;
use anyhow::bail;
use std::sync::Arc;
Expand Down Expand Up @@ -808,14 +806,13 @@ mod tests {

// send packet
let data = b"hello world!";
crate::hp::derp::client::send_packets(&mut writer, &None, target.clone(), &[data]).await?;
crate::hp::derp::client::send_packet(&mut writer, &None, target.clone(), &data[..]).await?;
let msg = server_channel_r.recv().await.unwrap();
match msg {
ServerMessage::SendPacket((got_target, packet)) => {
let payload = PacketSplitIter::split(packet.bytes)?;
assert_eq!(target, got_target);
assert_eq!(key, packet.src);
assert_eq!(&data[..], &payload[0]);
assert_eq!(&data[..], &packet.bytes);
}
m => {
bail!("expected ServerMessage::SendPacket, got {m:?}");
Expand All @@ -827,7 +824,7 @@ mod tests {
let mut disco_data = crate::hp::disco::MAGIC.as_bytes().to_vec();
disco_data.extend_from_slice(target.as_bytes());
disco_data.extend_from_slice(data);
crate::hp::derp::client::send_packets(&mut writer, &None, target.clone(), &[&disco_data])
crate::hp::derp::client::send_packet(&mut writer, &None, target.clone(), &disco_data)
.await?;
let msg = server_channel_r.recv().await.unwrap();
match msg {
Expand Down
32 changes: 12 additions & 20 deletions src/hp/derp/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ mod tests {
use tokio_util::sync::CancellationToken;
use tracing_subscriber::{prelude::*, EnvFilter};

use crate::hp::derp::client::PacketSplitIter;
use crate::hp::derp::{DerpNode, DerpRegion, ReceivedMessage, UseIpv4, UseIpv6};
use crate::hp::{
derp::Server as DerpServer,
Expand Down Expand Up @@ -131,15 +130,15 @@ mod tests {

println!("sending message from a to b");
let msg = Bytes::from_static(b"hi there, client b!");
client_a.send(b_key.clone(), vec![msg.clone()]).await?;
client_a.send(b_key.clone(), msg.clone()).await?;
println!("waiting for message from a on b");
let (got_key, got_msg) = b_recv.recv().await.expect("expected message from client_a");
assert_eq!(a_key, got_key);
assert_eq!(msg, got_msg);

println!("sending message from b to a");
let msg = Bytes::from_static(b"right back at ya, client b!");
client_b.send(a_key.clone(), vec![msg.clone()]).await?;
client_b.send(a_key.clone(), msg.clone()).await?;
println!("waiting for message b on a");
let (got_key, got_msg) = a_recv.recv().await.expect("expected message from client_b");
assert_eq!(b_key, got_key);
Expand Down Expand Up @@ -181,23 +180,16 @@ mod tests {
Ok((msg, _)) => {
println!("got message on {:?}: {msg:?}", key.public_key());
if let ReceivedMessage::ReceivedPacket { source, data } = msg {
let iter = PacketSplitIter::new(data);
for packet in iter {
let Ok(data) = packet else {
tracing::warn!("error parsing packet");
return;
};
received_msg_s
.send((source.clone(), data))
.await
.unwrap_or_else(|err| {
panic!(
"client {:?}, error sending message over channel: {:?}",
key.public_key(),
err
)
});
}
received_msg_s
.send((source.clone(), data))
.await
.unwrap_or_else(|err| {
panic!(
"client {:?}, error sending message over channel: {:?}",
key.public_key(),
err
)
});
}
}
}
Expand Down
6 changes: 1 addition & 5 deletions src/hp/derp/http/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -595,11 +595,7 @@ impl Client {
///
/// If there is an error sending the packet, it closes the underlying derp connection before
/// returning.
pub async fn send(
&self,
dst_key: key::node::PublicKey,
b: Vec<Bytes>,
) -> Result<(), ClientError> {
pub async fn send(&self, dst_key: key::node::PublicKey, b: Bytes) -> Result<(), ClientError> {
debug!("send");
let (client, _) = self.connect().await?;
if client.send(dst_key, b).await.is_err() {
Expand Down
Loading

0 comments on commit 3142912

Please sign in to comment.