diff --git a/src/transport/util.rs b/src/transport/util.rs index a279e5d..c070444 100644 --- a/src/transport/util.rs +++ b/src/transport/util.rs @@ -1,189 +1,191 @@ -use std::{ - pin::Pin, - task::{self, Poll}, -}; - -use futures_lite::Stream; -use futures_sink::Sink; -use pin_project::pin_project; -use serde::{de::DeserializeOwned, Serialize}; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_util::codec::LengthDelimitedCodec; - -#[pin_project] -pub struct FramedPostcardRead( - #[pin] - tokio_serde::SymmetricallyFramed< - tokio_util::codec::FramedRead, - In, - tokio_serde_postcard::SymmetricalPostcard, - >, -); - -impl FramedPostcardRead { - /// Wrap a socket in a length delimited codec and postcard encoding - pub fn new(inner: T, max_frame_length: usize) -> Self { - // configure length delimited codec with max frame length - let framing = LengthDelimitedCodec::builder() - .max_frame_length(max_frame_length) - .new_codec(); - // create the actual framing. This turns the AsyncRead/AsyncWrite into a Stream/Sink of Bytes/BytesMut - let framed = tokio_util::codec::FramedRead::new(inner, framing); - let postcard = tokio_serde_postcard::Postcard::new(); - // create the actual framing. This turns the Stream/Sink of Bytes/BytesMut into a Stream/Sink of In/Out - let framed = tokio_serde::Framed::new(framed, postcard); - Self(framed) +mod deps { + use std::{ + pin::Pin, + task::{self, Poll}, + }; + + use futures_lite::Stream; + use futures_sink::Sink; + use pin_project::pin_project; + use serde::{de::DeserializeOwned, Serialize}; + use tokio::io::{AsyncRead, AsyncWrite}; + use tokio_util::codec::LengthDelimitedCodec; + + #[pin_project] + pub struct FramedPostcardRead( + #[pin] + tokio_serde::SymmetricallyFramed< + tokio_util::codec::FramedRead, + In, + tokio_serde_postcard::SymmetricalPostcard, + >, + ); + + impl FramedPostcardRead { + /// Wrap a socket in a length delimited codec and postcard encoding + pub fn new(inner: T, max_frame_length: usize) -> Self { + // configure length delimited codec with max frame length + let framing = LengthDelimitedCodec::builder() + .max_frame_length(max_frame_length) + .new_codec(); + // create the actual framing. This turns the AsyncRead/AsyncWrite into a Stream/Sink of Bytes/BytesMut + let framed = tokio_util::codec::FramedRead::new(inner, framing); + let postcard = tokio_serde_postcard::Postcard::new(); + // create the actual framing. This turns the Stream/Sink of Bytes/BytesMut into a Stream/Sink of In/Out + let framed = tokio_serde::Framed::new(framed, postcard); + Self(framed) + } } -} -impl FramedPostcardRead { - /// Get the underlying binary stream - /// - /// This can be useful if you want to drop the framing and use the underlying stream directly - /// after exchanging some messages. - pub fn into_inner(self) -> T { - self.0.into_inner().into_inner() + impl FramedPostcardRead { + /// Get the underlying binary stream + /// + /// This can be useful if you want to drop the framing and use the underlying stream directly + /// after exchanging some messages. + pub fn into_inner(self) -> T { + self.0.into_inner().into_inner() + } } -} -impl Stream for FramedPostcardRead { - type Item = Result; + impl Stream for FramedPostcardRead { + type Item = Result; - fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { - Pin::new(&mut self.project().0).poll_next(cx) + fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + Pin::new(&mut self.project().0).poll_next(cx) + } } -} -/// Wrapper that wraps a bidirectional binary stream in a length delimited codec and postcard encoding -/// to get a bidirectional stream of rpc Messages -#[pin_project] -pub struct FramedPostcardWrite( - #[pin] - tokio_serde::SymmetricallyFramed< - tokio_util::codec::FramedWrite, - Out, - tokio_serde_postcard::SymmetricalPostcard, - >, -); - -impl FramedPostcardWrite { - /// Wrap a socket in a length delimited codec and postcard encoding - pub fn new(inner: T, max_frame_length: usize) -> Self { - // configure length delimited codec with max frame length - let framing = LengthDelimitedCodec::builder() - .max_frame_length(max_frame_length) - .new_codec(); - // create the actual framing. This turns the AsyncRead/AsyncWrite into a Stream/Sink of Bytes/BytesMut - let framed = tokio_util::codec::FramedWrite::new(inner, framing); - let postcard = tokio_serde_postcard::SymmetricalPostcard::new(); - // create the actual framing. This turns the Stream/Sink of Bytes/BytesMut into a Stream/Sink of In/Out - let framed = tokio_serde::SymmetricallyFramed::new(framed, postcard); - Self(framed) + /// Wrapper that wraps a bidirectional binary stream in a length delimited codec and postcard encoding + /// to get a bidirectional stream of rpc Messages + #[pin_project] + pub struct FramedPostcardWrite( + #[pin] + tokio_serde::SymmetricallyFramed< + tokio_util::codec::FramedWrite, + Out, + tokio_serde_postcard::SymmetricalPostcard, + >, + ); + + impl FramedPostcardWrite { + /// Wrap a socket in a length delimited codec and postcard encoding + pub fn new(inner: T, max_frame_length: usize) -> Self { + // configure length delimited codec with max frame length + let framing = LengthDelimitedCodec::builder() + .max_frame_length(max_frame_length) + .new_codec(); + // create the actual framing. This turns the AsyncRead/AsyncWrite into a Stream/Sink of Bytes/BytesMut + let framed = tokio_util::codec::FramedWrite::new(inner, framing); + let postcard = tokio_serde_postcard::SymmetricalPostcard::new(); + // create the actual framing. This turns the Stream/Sink of Bytes/BytesMut into a Stream/Sink of In/Out + let framed = tokio_serde::SymmetricallyFramed::new(framed, postcard); + Self(framed) + } } -} -impl FramedPostcardWrite { - /// Get the underlying binary stream - /// - /// This can be useful if you want to drop the framing and use the underlying stream directly - /// after exchanging some messages. - pub fn into_inner(self) -> T { - self.0.into_inner().into_inner() + impl FramedPostcardWrite { + /// Get the underlying binary stream + /// + /// This can be useful if you want to drop the framing and use the underlying stream directly + /// after exchanging some messages. + pub fn into_inner(self) -> T { + self.0.into_inner().into_inner() + } } -} -impl Sink for FramedPostcardWrite { - type Error = std::io::Error; + impl Sink for FramedPostcardWrite { + type Error = std::io::Error; - fn poll_ready( - self: Pin<&mut Self>, - cx: &mut task::Context<'_>, - ) -> Poll> { - Pin::new(&mut self.project().0).poll_ready(cx) - } + fn poll_ready( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll> { + Pin::new(&mut self.project().0).poll_ready(cx) + } - fn start_send(self: Pin<&mut Self>, item: Out) -> Result<(), Self::Error> { - Pin::new(&mut self.project().0).start_send(item) - } + fn start_send(self: Pin<&mut Self>, item: Out) -> Result<(), Self::Error> { + Pin::new(&mut self.project().0).start_send(item) + } - fn poll_flush( - self: Pin<&mut Self>, - cx: &mut task::Context<'_>, - ) -> Poll> { - Pin::new(&mut self.project().0).poll_flush(cx) - } + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll> { + Pin::new(&mut self.project().0).poll_flush(cx) + } - fn poll_close( - self: Pin<&mut Self>, - cx: &mut task::Context<'_>, - ) -> Poll> { - Pin::new(&mut self.project().0).poll_close(cx) + fn poll_close( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll> { + Pin::new(&mut self.project().0).poll_close(cx) + } } -} -mod tokio_serde_postcard { - use { - bytes::{BufMut as _, Bytes, BytesMut}, - pin_project::pin_project, - serde::{Deserialize, Serialize}, - std::{io, marker::PhantomData, pin::Pin}, - tokio_serde::{Deserializer, Serializer}, - }; - - #[pin_project] - pub struct Postcard { - #[pin] - buffer: Box>, - _marker: PhantomData<(Item, SinkItem)>, - } + mod tokio_serde_postcard { + use { + bytes::{BufMut as _, Bytes, BytesMut}, + pin_project::pin_project, + serde::{Deserialize, Serialize}, + std::{io, marker::PhantomData, pin::Pin}, + tokio_serde::{Deserializer, Serializer}, + }; + + #[pin_project] + pub struct Postcard { + #[pin] + buffer: Box>, + _marker: PhantomData<(Item, SinkItem)>, + } - impl Default for Postcard { - fn default() -> Self { - Self::new() + impl Default for Postcard { + fn default() -> Self { + Self::new() + } } - } - impl Postcard { - pub fn new() -> Self { - Self { - buffer: Box::new(None), - _marker: PhantomData, + impl Postcard { + pub fn new() -> Self { + Self { + buffer: Box::new(None), + _marker: PhantomData, + } } } - } - pub type SymmetricalPostcard = Postcard; + pub type SymmetricalPostcard = Postcard; - impl Deserializer for Postcard - where - for<'a> Item: Deserialize<'a>, - { - type Error = io::Error; + impl Deserializer for Postcard + where + for<'a> Item: Deserialize<'a>, + { + type Error = io::Error; - fn deserialize(self: Pin<&mut Self>, src: &BytesMut) -> Result { - postcard::from_bytes(&src) - .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err)) + fn deserialize(self: Pin<&mut Self>, src: &BytesMut) -> Result { + postcard::from_bytes(&src) + .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err)) + } } - } - impl Serializer for Postcard - where - SinkItem: Serialize, - { - type Error = io::Error; - - fn serialize(self: Pin<&mut Self>, data: &SinkItem) -> Result { - let mut this = self.project(); - let buffer = this.buffer.take().unwrap_or_default(); - let mut buffer = postcard::to_io(data, buffer.writer()) - .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))? - .into_inner(); - if buffer.len() <= 1024 { - let res = buffer.split().freeze(); - this.buffer.replace(buffer); - Ok(res) - } else { - Ok(buffer.freeze()) + impl Serializer for Postcard + where + SinkItem: Serialize, + { + type Error = io::Error; + + fn serialize(self: Pin<&mut Self>, data: &SinkItem) -> Result { + let mut this = self.project(); + let buffer = this.buffer.take().unwrap_or_default(); + let mut buffer = postcard::to_io(data, buffer.writer()) + .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))? + .into_inner(); + if buffer.len() <= 1024 { + let res = buffer.split().freeze(); + this.buffer.replace(buffer); + Ok(res) + } else { + Ok(buffer.freeze()) + } } } } @@ -213,13 +215,13 @@ mod direct { _phantom: PhantomData, } - impl FramedPostcardWrite - where - W: AsyncWrite, - Out: Serialize, - { + impl FramedPostcardWrite { /// Creates a new `FramedPostcardWrite` with the provided `AsyncWrite`. - pub fn new(inner: W) -> Self { + pub fn new(inner: W, max_frame_size: usize) -> Self + where + W: AsyncWrite, + Out: Serialize, + { Self { inner, buffer: SmallVec::new(), @@ -227,6 +229,10 @@ mod direct { _phantom: PhantomData, } } + + pub fn into_inner(self) -> W { + self.inner + } } impl Sink for FramedPostcardWrite @@ -327,12 +333,9 @@ mod direct { ReadingData { len: usize, read: usize }, } - impl FramedPostcardRead - where - R: AsyncRead, - { + impl FramedPostcardRead { /// Creates a new `FramedPostcardRead` with the provided `AsyncRead`. - pub fn new(inner: R) -> Self { + pub fn new(inner: R, max_frame_size: usize) -> Self { Self { inner, buffer: SmallVec::new(), @@ -343,6 +346,10 @@ mod direct { _phantom: PhantomData, } } + + pub fn into_inner(self) -> R { + self.inner + } } impl Stream for FramedPostcardRead @@ -425,3 +432,5 @@ mod direct { } } } + +pub use direct::*;