Skip to content

Commit

Permalink
finish tcp server layer (first rough state)
Browse files Browse the repository at this point in the history
  • Loading branch information
glendc committed Nov 3, 2023
1 parent c545ebc commit 0ee6912
Show file tree
Hide file tree
Showing 11 changed files with 396 additions and 13 deletions.
7 changes: 7 additions & 0 deletions rama/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,10 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
http = "0.2.9"
pin-project-lite = "0.2.13"
tokio = { version = "1.33.0", features = ["net", "io-util"] }
tower-async-service = "0.1.1"

[dev-dependencies]
tokio-test = "0.4.3"
17 changes: 4 additions & 13 deletions rama/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,5 @@
pub fn add(left: usize, right: usize) -> usize {
left + right
}
#![feature(async_fn_in_trait)]

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn it_works() {
let result = add(2, 2);
assert_eq!(result, 4);
}
}
pub mod server;
pub mod state;
pub mod stream;
1 change: 1 addition & 0 deletions rama/src/server/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod tcp;
117 changes: 117 additions & 0 deletions rama/src/server/tcp/listener.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
use std::{
io,
net::{self, SocketAddr},
};

use tokio::net::{TcpListener as TokioTcpListener, ToSocketAddrs};

use super::TcpStream;

pub struct TcpListener {
inner: TokioTcpListener,
}

impl TcpListener {
/// Creates a new TcpListener, which will be bound to the specified address.
///
/// The returned listener is ready for accepting connections.
///
/// Binding with a port number of 0 will request that the OS assigns a port
/// to this listener. The port allocated can be queried via the `local_addr`
/// method.
pub async fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<TcpListener> {
let inner = TokioTcpListener::bind(addr).await?;
Ok(TcpListener { inner })
}

/// Creates new `TcpListener` from a `std::net::TcpListener`.
///
/// This function is intended to be used to wrap a TCP listener from the
/// standard library in the Tokio equivalent.
///
/// This API is typically paired with the `socket2` crate and the `Socket`
/// type to build up and customize a listener before it's shipped off to the
/// backing event loop. This allows configuration of options like
/// `SO_REUSEPORT`, binding to multiple addresses, etc.
///
/// # Notes
///
/// The caller is responsible for ensuring that the listener is in
/// non-blocking mode. Otherwise all I/O operations on the listener
/// will block the thread, which will cause unexpected behavior.
/// Non-blocking mode can be set using [`set_nonblocking`].
///
/// [`set_nonblocking`]: std::net::TcpListener::set_nonblocking
pub fn from_std(listener: net::TcpListener) -> io::Result<TcpListener> {
let inner = TokioTcpListener::from_std(listener)?;
Ok(TcpListener { inner })
}

/// Returns the local address that this listener is bound to.
///
/// This can be useful, for example, when binding to port 0 to figure out
/// which port was actually bound.
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.inner.local_addr()
}

/// Gets the value of the `IP_TTL` option for this socket.
///
/// For more information about this option, see [`set_ttl`].
///
/// [`set_ttl`]: method@Self::set_ttl
pub fn ttl(&self) -> io::Result<u32> {
self.inner.ttl()
}

/// Sets the value for the `IP_TTL` option on this socket.
///
/// This value sets the time-to-live field that is used in every packet sent
/// from this socket.
pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
self.inner.set_ttl(ttl)
}

/// Serve connections from this listener with the given service.
///
/// This method will block the current listener for each incoming connection,
/// the underlying service can choose to spawn a task to handle the accepted stream.
pub async fn serve<S, E>(self, mut service: S) -> TcpServeResult<E>
where
S: tower_async_service::Service<TcpStream, Response = (), Error = E>,
{
loop {
let (stream, _) = self.inner.accept().await?;
let stream = TcpStream::new(stream);
service.call(stream).await.map_err(TcpServeError::Service)?;
}
}
}

pub type TcpServeResult<E> = Result<(), TcpServeError<E>>;

#[derive(Debug)]
pub enum TcpServeError<E> {
Io(io::Error),
Service(E),
}

impl<E> From<io::Error> for TcpServeError<E> {
fn from(e: io::Error) -> Self {
TcpServeError::Io(e)
}
}

impl<E> std::fmt::Display for TcpServeError<E>
where
E: std::fmt::Display,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TcpServeError::Io(e) => write!(f, "IO error: {}", e),
TcpServeError::Service(e) => write!(f, "Service error: {}", e),
}
}
}

impl<E> std::error::Error for TcpServeError<E> where E: std::error::Error {}
7 changes: 7 additions & 0 deletions rama/src/server/tcp/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
mod listener;
pub use listener::{TcpListener, TcpServeError, TcpServeResult};

mod stream;
pub use stream::TcpStream;

pub mod service;
73 changes: 73 additions & 0 deletions rama/src/server/tcp/service/echo.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use std::io::Error;

use tower_async_service::Service;

use crate::stream::Stream;

/// An async service which echoes the incoming bytes back on the same stream.
///
/// # Example
///
/// ```rust
/// use tower_async_service::Service;
/// use rama::server::tcp::service::EchoService;
///
/// # #[tokio::main]
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// # let stream = tokio_test::io::Builder::new().read(b"hello world").write(b"hello world").build();
/// let mut service = EchoService::new();
///
/// let bytes_copied = service.call(stream).await?;
/// # assert_eq!(bytes_copied, 11);
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone)]
pub struct EchoService {
_phantom: (),
}

impl EchoService {
/// Creates a new [`EchoService`],
pub fn new() -> Self {
Self { _phantom: () }
}
}

impl Default for EchoService {
fn default() -> Self {
Self::new()
}
}

impl<S> Service<S> for EchoService
where
S: Stream,
{
type Response = u64;
type Error = Error;

async fn call(&mut self, stream: S) -> Result<Self::Response, Self::Error> {
let (mut reader, mut writer) = tokio::io::split(stream);
tokio::io::copy(&mut reader, &mut writer).await
}
}

#[cfg(test)]
mod tests {
use super::*;

use tokio_test::io::Builder;

#[tokio::test]
async fn test_echo() {
let stream = Builder::new()
.read(b"one")
.write(b"one")
.read(b"two")
.write(b"two")
.build();

EchoService::new().call(stream).await.unwrap();
}
}
80 changes: 80 additions & 0 deletions rama/src/server/tcp/service/forward.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
use std::{io::Error, pin::Pin};

use tower_async_service::Service;

use crate::stream::Stream;

/// Async service which forwards the incoming connection bytes to the given destination,
/// and forwards the response back from the destination to the incoming connection.
///
/// # Example
///
/// ```rust
/// use tower_async_service::Service;
/// use rama::server::tcp::service::ForwardService;
///
/// # #[tokio::main]
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// # let destination = tokio_test::io::Builder::new().write(b"hello world").read(b"hello world").build();
/// # let stream = tokio_test::io::Builder::new().read(b"hello world").write(b"hello world").build();
/// let mut service = ForwardService::new(destination);
///
/// let (bytes_copied_to, bytes_copied_from) = service.call(stream).await?;
/// # assert_eq!(bytes_copied_to, 11);
/// # assert_eq!(bytes_copied_from, 11);
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
pub struct ForwardService<D> {
destination: Pin<Box<D>>,
}

impl<D> ForwardService<D> {
/// Creates a new [`ForwardService`],
pub fn new(destination: D) -> Self {
Self {
destination: Box::pin(destination),
}
}
}

impl<S, D> Service<S> for ForwardService<D>
where
S: Stream,
D: Stream,
{
type Response = (u64, u64);
type Error = Error;

async fn call(&mut self, source: S) -> Result<Self::Response, Self::Error> {
tokio::pin!(source);
tokio::io::copy_bidirectional(&mut source, &mut self.destination).await
}
}

#[cfg(test)]
mod tests {
use super::*;

use tokio_test::io::Builder;

#[tokio::test]
async fn test_forwarder() {
let destination = Builder::new()
.write(b"to(1)")
.read(b"from(1)")
.write(b"to(2)")
.wait(std::time::Duration::from_secs(1))
.read(b"from(2)")
.build();
let stream = Builder::new()
.read(b"to(1)")
.write(b"from(1)")
.read(b"to(2)")
.write(b"from(2)")
.build();

ForwardService::new(destination).call(stream).await.unwrap();
}
}
5 changes: 5 additions & 0 deletions rama/src/server/tcp/service/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
mod echo;
pub use echo::EchoService;

mod forward;
pub use forward::ForwardService;
Loading

0 comments on commit 0ee6912

Please sign in to comment.